Innovenergy_trunk/NodeRed/NodeRedFiles/pika-0.13.1/tests/unit/connection_parameters_tests.py

744 lines
26 KiB
Python

"""
Test `pika.connection.Parameters`, `pika.connection.ConnectionParameters`, and
`pika.connection.URLParameters`
"""
import copy
import unittest
import warnings
import pika
from pika.compat import urlencode, url_quote, dict_iteritems
from pika import channel, connection, credentials, spec
# disable missing-docstring
# pylint: disable=C0111
# disable invalid-name
# pylint: disable=C0103
# Unordered sequence of connection.Parameters's property getters
_ALL_PUBLIC_PARAMETERS_PROPERTIES = tuple(
attr for attr in vars(connection.Parameters) if not attr.startswith('_')
and issubclass(type(getattr(connection.Parameters, attr)), property))
class _ParametersTestsBase(unittest.TestCase):
def setUp(self):
warnings.resetwarnings()
self.addCleanup(warnings.resetwarnings)
def get_default_properties(self):
"""
:returns: a dict of expected public property names and default values
for `pika.connection.Parameters`
"""
kls = connection.Parameters
defaults = {
'backpressure_detection':
kls.DEFAULT_BACKPRESSURE_DETECTION,
'blocked_connection_timeout':
kls.DEFAULT_BLOCKED_CONNECTION_TIMEOUT,
'channel_max':
kls.DEFAULT_CHANNEL_MAX,
'client_properties':
kls.DEFAULT_CLIENT_PROPERTIES,
'connection_attempts':
kls.DEFAULT_CONNECTION_ATTEMPTS,
'credentials':
credentials.PlainCredentials(kls.DEFAULT_USERNAME,
kls.DEFAULT_PASSWORD),
'frame_max':
kls.DEFAULT_FRAME_MAX,
'heartbeat':
kls.DEFAULT_HEARTBEAT_TIMEOUT,
'host':
kls.DEFAULT_HOST,
'locale':
kls.DEFAULT_LOCALE,
'port':
kls.DEFAULT_PORT,
'retry_delay':
kls.DEFAULT_RETRY_DELAY,
'socket_timeout':
kls.DEFAULT_SOCKET_TIMEOUT,
'ssl':
kls.DEFAULT_SSL,
'ssl_options':
kls.DEFAULT_SSL_OPTIONS,
'virtual_host':
kls.DEFAULT_VIRTUAL_HOST,
'tcp_options':
kls.DEFAULT_TCP_OPTIONS
}
# Make sure we didn't miss anything
self.assertSequenceEqual(
sorted(defaults), sorted(_ALL_PUBLIC_PARAMETERS_PROPERTIES))
return defaults
def assert_default_parameter_values(self, params):
"""Assert that the given parameters object has the default parameter
values.
:param params: verify that the given params instance has all default
property values
:type params: one of the classes based on `pika.connection.Parameters`
"""
for name, expected_value in dict_iteritems(
self.get_default_properties()):
value = getattr(params, name)
self.assertEqual(
value,
expected_value,
msg='Expected %s=%r, but got %r' % (name, expected_value,
value))
class ParametersTests(_ParametersTestsBase):
"""Test `pika.connection.Parameters`"""
def test_default_property_values(self):
self.assert_default_parameter_values(connection.Parameters())
def test_backpressure_detection(self):
params = connection.Parameters()
params.backpressure_detection = False
self.assertEqual(params.backpressure_detection, False)
params.backpressure_detection = True
self.assertEqual(params.backpressure_detection, True)
with self.assertRaises(TypeError):
params.backpressure_detection = 1
with self.assertRaises(TypeError):
params.backpressure_detection = 'true'
with self.assertRaises(TypeError):
params.backpressure_detection = 'f'
def test_blocked_connection_timeout(self):
params = connection.Parameters()
params.blocked_connection_timeout = 60
self.assertEqual(params.blocked_connection_timeout, 60)
params.blocked_connection_timeout = 93.5
self.assertEqual(params.blocked_connection_timeout, 93.5)
params.blocked_connection_timeout = None
self.assertIsNone(params.blocked_connection_timeout)
with self.assertRaises(TypeError):
params.blocked_connection_timeout = '1.5'
with self.assertRaises(ValueError):
params.blocked_connection_timeout = -40
def test_channel_max(self):
params = connection.Parameters()
params.channel_max = 1
self.assertEqual(params.channel_max, 1)
params.channel_max = channel.MAX_CHANNELS
self.assertEqual(params.channel_max, channel.MAX_CHANNELS)
with self.assertRaises(TypeError):
params.channel_max = 1.5
with self.assertRaises(TypeError):
params.channel_max = '99'
with self.assertRaises(ValueError):
params.channel_max = 0
with self.assertRaises(ValueError):
params.channel_max = -40
with self.assertRaises(ValueError):
params.channel_max = channel.MAX_CHANNELS + 1
def test_connection_attempts(self):
params = connection.Parameters()
params.connection_attempts = 1
self.assertEqual(params.connection_attempts, 1)
params.connection_attempts = 10
self.assertEqual(params.connection_attempts, 10)
with self.assertRaises(TypeError):
params.connection_attempts = 1.5
with self.assertRaises(TypeError):
params.connection_attempts = '99'
with self.assertRaises(ValueError):
params.connection_attempts = 0
with self.assertRaises(ValueError):
params.connection_attempts = -40
def test_credentials(self):
params = connection.Parameters()
plain_cred = credentials.PlainCredentials('very', 'reliable')
params.credentials = plain_cred
self.assertEqual(params.credentials, plain_cred)
ext_cred = credentials.ExternalCredentials()
params.credentials = ext_cred
self.assertEqual(params.credentials, ext_cred)
with self.assertRaises(TypeError):
params.credentials = connection.Parameters()
with self.assertRaises(TypeError):
params.credentials = repr(plain_cred)
def test_frame_max(self):
params = connection.Parameters()
params.frame_max = spec.FRAME_MIN_SIZE
self.assertEqual(params.frame_max, spec.FRAME_MIN_SIZE)
params.frame_max = spec.FRAME_MIN_SIZE + 1
self.assertEqual(params.frame_max, spec.FRAME_MIN_SIZE + 1)
params.frame_max = spec.FRAME_MAX_SIZE
self.assertEqual(params.frame_max, spec.FRAME_MAX_SIZE)
params.frame_max = spec.FRAME_MAX_SIZE - 1
self.assertEqual(params.frame_max, spec.FRAME_MAX_SIZE - 1)
with self.assertRaises(TypeError):
params.frame_max = 10000.9
with self.assertRaises(TypeError):
params.frame_max = '10000'
with self.assertRaises(ValueError):
params.frame_max = spec.FRAME_MIN_SIZE - 1
with self.assertRaises(ValueError):
params.frame_max = spec.FRAME_MAX_SIZE + 1
def test_heartbeat(self):
params = connection.Parameters()
params.heartbeat = None
self.assertIsNone(params.heartbeat)
params.heartbeat = 0
self.assertEqual(params.heartbeat, 0)
params.heartbeat = 600
self.assertEqual(params.heartbeat, 600)
with self.assertRaises(TypeError):
params.heartbeat = 1.5
with self.assertRaises(TypeError):
params.heartbeat = '99'
with self.assertRaises(ValueError):
params.heartbeat = -1
def heartbeat_callback(_conn, _val):
return 1
params.heartbeat = heartbeat_callback
self.assertTrue(callable(params.heartbeat))
self.assertIs(params.heartbeat, heartbeat_callback)
def test_host(self):
params = connection.Parameters()
params.host = '127.0.0.1'
self.assertEqual(params.host, '127.0.0.1')
params.host = 'my.server.com'
self.assertEqual(params.host, 'my.server.com')
params.host = u'my.server.com'
self.assertEqual(params.host, u'my.server.com')
with self.assertRaises(TypeError):
params.host = 127
with self.assertRaises(TypeError):
params.host = ('127.0.0.1', 5672)
def test_locale(self):
params = connection.Parameters()
params.locale = 'en_UK'
self.assertEqual(params.locale, 'en_UK')
params.locale = u'en_UK'
self.assertEqual(params.locale, u'en_UK')
with self.assertRaises(TypeError):
params.locale = 127
def test_port(self):
params = connection.Parameters()
params.port = 0
self.assertEqual(params.port, 0)
params.port = 5672
self.assertEqual(params.port, 5672)
params.port = '5672'
self.assertEqual(params.port, 5672)
params.port = 1.5
self.assertEqual(params.port, 1)
with self.assertRaises(TypeError):
params.port = '50a'
with self.assertRaises(TypeError):
params.port = []
def test_retry_delay(self):
params = connection.Parameters()
params.retry_delay = 0
self.assertEqual(params.retry_delay, 0)
params.retry_delay = 0.1
self.assertEqual(params.retry_delay, 0.1)
with self.assertRaises(TypeError):
params.retry_delay = '0.1'
def test_socket_timeout(self):
params = connection.Parameters()
params.socket_timeout = 1
self.assertEqual(params.socket_timeout, 1)
params.socket_timeout = 0.5
self.assertEqual(params.socket_timeout, 0.5)
params.socket_timeout = 60.5
self.assertEqual(params.socket_timeout, 60.5)
params.socket_timeout = None
self.assertEqual(params.socket_timeout, None)
with self.assertRaises(TypeError):
params.socket_timeout = '60.5'
with self.assertRaises(ValueError):
params.socket_timeout = -1
with self.assertRaises(ValueError):
params.socket_timeout = 0
def test_ssl(self):
params = connection.Parameters()
params.ssl = False
self.assertEqual(params.ssl, False)
params.ssl = True
self.assertEqual(params.ssl, True)
with self.assertRaises(TypeError):
params.backpressure_detection = 1
with self.assertRaises(TypeError):
params.ssl = 'True'
with self.assertRaises(TypeError):
params.ssl = 'f'
def test_ssl_options(self):
params = connection.Parameters()
opt = dict(key='value', key2=2, key3=dict(a=1))
params.ssl_options = copy.deepcopy(opt)
self.assertEqual(params.ssl_options, opt)
params.ssl_options = None
self.assertIsNone(params.ssl_options)
with self.assertRaises(TypeError):
params.ssl_options = str(opt)
def test_virtual_host(self):
params = connection.Parameters()
params.virtual_host = '/'
self.assertEqual(params.virtual_host, '/')
params.virtual_host = u'/'
self.assertEqual(params.virtual_host, u'/')
params.virtual_host = 'test-vhost'
self.assertEqual(params.virtual_host, 'test-vhost')
with self.assertRaises(TypeError):
params.virtual_host = 99
def test_tcp_options(self):
params = connection.Parameters()
opt = dict(
TCP_KEEPIDLE=60,
TCP_KEEPINTVL=2,
TCP_KEEPCNT=1,
TCP_USER_TIMEOUT=1000)
params.tcp_options = copy.deepcopy(opt)
self.assertEqual(params.tcp_options, opt)
params.tcp_options = None
self.assertIsNone(params.tcp_options)
with self.assertRaises(TypeError):
params.tcp_options = str(opt)
class ConnectionParametersTests(_ParametersTestsBase):
def test_default_property_values(self):
self.assert_default_parameter_values(connection.ConnectionParameters())
def test_explicit_ssl_with_default_port(self):
params = connection.ConnectionParameters(ssl=True)
self.assertEqual(params.ssl, True)
self.assertEqual(params.port, params.DEFAULT_SSL_PORT)
def test_explicit_ssl_with_explict_port(self):
params = connection.ConnectionParameters(ssl=True, port=99)
self.assertEqual(params.ssl, True)
self.assertEqual(params.port, 99)
def test_explicit_non_ssl_with_default_port(self):
params = connection.ConnectionParameters(ssl=False)
self.assertEqual(params.ssl, False)
self.assertEqual(params.port, params.DEFAULT_PORT)
def test_explicit_non_ssl_with_explict_port(self):
params = connection.ConnectionParameters(ssl=False, port=100)
self.assertEqual(params.ssl, False)
self.assertEqual(params.port, 100)
def test_good_connection_parameters(self):
"""make sure connection kwargs get set correctly"""
kwargs = {
'backpressure_detection': False,
'blocked_connection_timeout': 10.5,
'channel_max': 3,
'client_properties': {
'good': 'day'
},
'connection_attempts': 2,
'credentials': credentials.PlainCredentials('very', 'secure'),
'frame_max': 40000,
'heartbeat': 7,
'host': 'https://www.test.com',
'locale': 'en',
'port': 5678,
'retry_delay': 3,
'socket_timeout': 100.5,
'ssl': True,
'ssl_options': {
'ssl': 'options'
},
'virtual_host': u'vvhost',
'tcp_options': {
'TCP_USER_TIMEOUT': 1000
}
}
params = connection.ConnectionParameters(**kwargs)
# Verify
expected_values = copy.copy(kwargs)
# Make sure we're testing all public properties
self.assertSequenceEqual(
sorted(expected_values), sorted(_ALL_PUBLIC_PARAMETERS_PROPERTIES))
# Check property values
for t_param in expected_values:
value = getattr(params, t_param)
self.assertEqual(
expected_values[t_param],
value,
msg='Expected %s=%r, but got %r' %
(t_param, expected_values[t_param], value))
def test_deprecated_heartbeat_interval(self):
with warnings.catch_warnings(record=True) as warnings_list:
warnings.simplefilter('always')
params = connection.ConnectionParameters(heartbeat_interval=999)
self.assertEqual(params.heartbeat, 999)
# Check that a warning was generated
self.assertEqual(len(warnings_list), 1)
self.assertIs(warnings_list[0].category, DeprecationWarning)
def test_callable_heartbeat(self):
def heartbeat_callback(_connection, _broker_val):
return 1
parameters = pika.ConnectionParameters(heartbeat=heartbeat_callback)
self.assertIs(parameters.heartbeat, heartbeat_callback)
def test_bad_type_connection_parameters(self):
"""test connection kwargs type checks throw errors for bad input"""
kwargs = {
'host': 'https://www.test.com',
'port': 5678,
'virtual_host': 'vvhost',
'channel_max': 3,
'frame_max': 40000,
'heartbeat': 7,
'backpressure_detection': False,
'ssl': True,
'blocked_connection_timeout': 10.5
}
# Test Type Errors
for bad_field, bad_value in (('host', 15672), ('port', '5672a'),
('virtual_host',
True), ('channel_max',
'4'), ('frame_max',
'5'), ('credentials',
'bad'), ('locale',
1),
('heartbeat',
'6'), ('socket_timeout',
'42'), ('retry_delay', 'two'),
('backpressure_detection',
'true'), ('ssl', {
'ssl': 'dict'
}), ('ssl_options', True),
('connection_attempts',
'hello'), ('blocked_connection_timeout',
set())):
bkwargs = copy.deepcopy(kwargs)
bkwargs[bad_field] = bad_value
self.assertRaises(TypeError, connection.ConnectionParameters,
**bkwargs)
def test_parameters_accepts_plain_string_virtualhost(self):
parameters = pika.ConnectionParameters(virtual_host='prtfqpeo')
self.assertEqual(parameters.virtual_host, 'prtfqpeo')
def test_parameters_accepts_unicode_string_virtualhost(self):
parameters = pika.ConnectionParameters(virtual_host=u'prtfqpeo')
self.assertEqual(parameters.virtual_host, 'prtfqpeo')
def test_parameters_accept_plain_string_locale(self):
parameters = pika.ConnectionParameters(locale='en_US')
self.assertEqual(parameters.locale, 'en_US')
def test_parameters_accept_unicode_locale(self):
parameters = pika.ConnectionParameters(locale=u'en_US')
self.assertEqual(parameters.locale, 'en_US')
class URLParametersTests(_ParametersTestsBase):
def test_default_property_values(self):
params = connection.URLParameters('')
self.assert_default_parameter_values(params)
self.assertEqual(params.ssl, False)
self.assertEqual(params.port, params.DEFAULT_PORT)
def test_no_ssl(self):
params = connection.URLParameters('http://')
self.assertEqual(params.ssl, False)
self.assertEqual(params.port, params.DEFAULT_PORT)
self.assert_default_parameter_values(params)
params = connection.URLParameters('amqp://')
self.assertEqual(params.ssl, False)
self.assertEqual(params.port, params.DEFAULT_PORT)
self.assert_default_parameter_values(params)
def test_ssl(self):
params = connection.URLParameters('https://')
self.assertEqual(params.ssl, True)
self.assertEqual(params.port, params.DEFAULT_SSL_PORT)
params = connection.URLParameters('amqps://')
self.assertEqual(params.ssl, True)
self.assertEqual(params.port, params.DEFAULT_SSL_PORT)
# Make sure the other parameters unrelated to SSL have default values
params = connection.URLParameters('amqps://')
params.ssl = False
params.port = params.DEFAULT_PORT
self.assert_default_parameter_values(params)
def test_no_url_scheme_defaults_to_plaintext(self):
params = connection.URLParameters('//')
self.assertEqual(params.ssl, False)
self.assertEqual(params.port, params.DEFAULT_PORT)
def test_good_parameters(self):
"""test for the different query stings checked by process url"""
query_args = {
'blocked_connection_timeout': 10.5,
'channel_max': 3,
'connection_attempts': 2,
'frame_max': 40000,
'heartbeat': 7,
'locale': 'en_UK',
'retry_delay': 3,
'socket_timeout': 100.5,
'ssl_options': None,
'tcp_options': {
'TCP_USER_TIMEOUT': 1000,
'TCP_KEEPIDLE': 60
}
}
for backpressure in ('t', 'f'):
test_params = copy.deepcopy(query_args)
test_params['backpressure_detection'] = backpressure
virtual_host = '/'
query_string = urlencode(test_params)
test_url = ('amqp://myuser:mypass@www.test.com:5678/%s?%s' % (
url_quote(virtual_host, safe=''),
query_string,
))
params = connection.URLParameters(test_url)
# check all value from query string
for t_param in query_args:
expected_value = query_args[t_param]
actual_value = getattr(params, t_param)
self.assertEqual(
actual_value,
expected_value,
msg='Expected %s=%r, but got %r' %
(t_param, expected_value, actual_value))
self.assertEqual(params.backpressure_detection,
backpressure == 't')
# check all values from base URL
self.assertEqual(params.credentials.username, 'myuser')
self.assertEqual(params.credentials.password, 'mypass')
self.assertEqual(params.host, 'www.test.com')
self.assertEqual(params.port, 5678)
self.assertEqual(params.virtual_host, virtual_host)
def test_deprecated_heartbeat_interval(self):
with warnings.catch_warnings(record=True) as warnings_list:
warnings.simplefilter('always')
params = pika.URLParameters(
'amqp://prtfqpeo:oihdglkhcp0@myserver.'
'mycompany.com:5672/prtfqpeo?heartbeat_interval=999')
self.assertEqual(params.heartbeat, 999)
# Check that a warning was generated
self.assertEqual(len(warnings_list), 1)
self.assertIs(warnings_list[0].category, DeprecationWarning)
def test_accepts_plain_string(self):
parameters = pika.URLParameters('amqp://prtfqpeo:oihdglkhcp0@myserver.'
'mycompany.com:5672/prtfqpeo?locale='
'en_US')
self.assertEqual(parameters.port, 5672)
self.assertEqual(parameters.virtual_host, 'prtfqpeo')
self.assertEqual(parameters.credentials.password, 'oihdglkhcp0')
self.assertEqual(parameters.credentials.username, 'prtfqpeo')
self.assertEqual(parameters.locale, 'en_US')
def test_accepts_unicode_string(self):
parameters = pika.URLParameters(u'amqp://prtfqpeo:oihdglkhcp0@myserver'
u'.mycompany.com:5672/prtfqpeo?locale='
u'en_US')
self.assertEqual(parameters.port, 5672)
self.assertEqual(parameters.virtual_host, 'prtfqpeo')
self.assertEqual(parameters.credentials.password, 'oihdglkhcp0')
self.assertEqual(parameters.credentials.username, 'prtfqpeo')
self.assertEqual(parameters.locale, 'en_US')
def test_uses_default_port_if_not_specified(self):
parameters = pika.URLParameters('amqp://myserver.mycompany.com')
self.assertEqual(parameters.port, pika.URLParameters.DEFAULT_PORT)
def test_uses_default_virtual_host_if_not_specified(self):
parameters = pika.URLParameters('amqp://myserver.mycompany.com')
self.assertEqual(parameters.virtual_host,
pika.URLParameters.DEFAULT_VIRTUAL_HOST)
def test_uses_default_virtual_host_if_only_slash_is_specified(self):
parameters = pika.URLParameters('amqp://myserver.mycompany.com/')
self.assertEqual(parameters.virtual_host,
pika.URLParameters.DEFAULT_VIRTUAL_HOST)
def test_uses_default_virtual_host_via_encoded_slash_upcase(self):
parameters = pika.URLParameters('amqp://myserver.mycompany.com/%2F')
self.assertEqual(parameters.virtual_host,
pika.URLParameters.DEFAULT_VIRTUAL_HOST)
def test_uses_default_virtual_host_via_encoded_slash_downcase(self):
parameters = pika.URLParameters('amqp://myserver.mycompany.com/%2f')
self.assertEqual(parameters.virtual_host,
pika.URLParameters.DEFAULT_VIRTUAL_HOST)
def test_uses_default_virtual_host_via_encoded_slash_upcase_ending_with_slash(self):
parameters = pika.URLParameters('amqp://myserver.mycompany.com/%2F/')
self.assertEqual(parameters.virtual_host,
pika.URLParameters.DEFAULT_VIRTUAL_HOST)
def test_uses_default_virtual_host_via_encoded_slash_downcase_ending_with_slash(self):
parameters = pika.URLParameters('amqp://myserver.mycompany.com/%2f/')
self.assertEqual(parameters.virtual_host,
pika.URLParameters.DEFAULT_VIRTUAL_HOST)
def test_uses_default_virtual_host_if_only_parameters_provided(self):
parameters = pika.URLParameters(
'amqp://myserver.mycompany.com?frame_max=8192&locale=utf8')
self.assertEqual(parameters.virtual_host,
pika.URLParameters.DEFAULT_VIRTUAL_HOST)
self.assertEqual(parameters.frame_max, 8192)
self.assertEqual(parameters.locale, 'utf8')
def test_uses_default_username_and_password_if_not_specified(self):
parameters = pika.URLParameters('amqp://myserver.mycompany.com')
self.assertEqual(parameters.credentials.username,
pika.URLParameters.DEFAULT_USERNAME)
self.assertEqual(parameters.credentials.password,
pika.URLParameters.DEFAULT_PASSWORD)
def test_accepts_blank_username_and_password(self):
parameters = pika.URLParameters('amqp://:@myserver.mycompany.com')
self.assertEqual(parameters.credentials.username, '')
self.assertEqual(parameters.credentials.password, '')
def test_url_decodes_username_and_password(self):
username = '@@@@'
password = '////'
parameters = pika.URLParameters(
'amqp://%40%40%40%40:%2F%2F%2F%2F@myserver.mycompany.com')
self.assertEqual(parameters.credentials.username, username)
self.assertEqual(parameters.credentials.password, password)