import time from botocore.awsrequest import AWSRequest from botocore.client import ClientMeta from botocore.discovery import ( EndpointDiscoveryHandler, EndpointDiscoveryManager, EndpointDiscoveryRefreshFailed, EndpointDiscoveryRequired, block_endpoint_discovery_required_operations, ) from botocore.exceptions import ConnectionError from botocore.handlers import inject_api_version_header_if_needed from botocore.hooks import HierarchicalEmitter from botocore.model import ServiceModel from tests import mock, unittest class BaseEndpointDiscoveryTest(unittest.TestCase): def setUp(self): self.service_description = { 'version': '2.0', 'metadata': { 'apiVersion': '2018-08-31', 'endpointPrefix': 'fooendpoint', 'jsonVersion': '1.1', 'protocol': 'json', 'serviceAbbreviation': 'FooService', 'serviceId': 'FooService', 'serviceFullName': 'AwsFooService', 'signatureVersion': 'v4', 'signingName': 'awsfooservice', 'targetPrefix': 'awsfooservice', }, 'operations': { 'DescribeEndpoints': { 'name': 'DescribeEndpoints', 'http': {'method': 'POST', 'requestUri': '/'}, 'input': {'shape': 'DescribeEndpointsRequest'}, 'output': {'shape': 'DescribeEndpointsResponse'}, 'endpointoperation': True, }, 'TestDiscoveryRequired': { 'name': 'TestDiscoveryRequired', 'http': {'method': 'POST', 'requestUri': '/'}, 'input': {'shape': 'TestDiscoveryIdsRequest'}, 'output': {'shape': 'EmptyStruct'}, 'endpointdiscovery': {'required': True}, }, 'TestDiscoveryOptional': { 'name': 'TestDiscoveryOptional', 'http': {'method': 'POST', 'requestUri': '/'}, 'input': {'shape': 'TestDiscoveryIdsRequest'}, 'output': {'shape': 'EmptyStruct'}, 'endpointdiscovery': {}, }, 'TestDiscovery': { 'name': 'TestDiscovery', 'http': {'method': 'POST', 'requestUri': '/'}, 'input': {'shape': 'EmptyStruct'}, 'output': {'shape': 'EmptyStruct'}, 'endpointdiscovery': {}, }, }, 'shapes': { 'Boolean': {'type': 'boolean'}, 'DescribeEndpointsRequest': { 'type': 'structure', 'members': { 'Operation': {'shape': 'String'}, 'Identifiers': {'shape': 'Identifiers'}, }, }, 'DescribeEndpointsResponse': { 'type': 'structure', 'required': ['Endpoints'], 'members': {'Endpoints': {'shape': 'Endpoints'}}, }, 'Endpoint': { 'type': 'structure', 'required': ['Address', 'CachePeriodInMinutes'], 'members': { 'Address': {'shape': 'String'}, 'CachePeriodInMinutes': {'shape': 'Long'}, }, }, 'Endpoints': {'type': 'list', 'member': {'shape': 'Endpoint'}}, 'Identifiers': { 'type': 'map', 'key': {'shape': 'String'}, 'value': {'shape': 'String'}, }, 'Long': {'type': 'long'}, 'String': {'type': 'string'}, 'TestDiscoveryIdsRequest': { 'type': 'structure', 'required': ['Foo', 'Nested'], 'members': { 'Foo': { 'shape': 'String', 'endpointdiscoveryid': True, }, 'Baz': {'shape': 'String'}, 'Nested': {'shape': 'Nested'}, }, }, 'EmptyStruct': {'type': 'structure', 'members': {}}, 'Nested': { 'type': 'structure', 'required': 'Bar', 'members': { 'Bar': { 'shape': 'String', 'endpointdiscoveryid': True, } }, }, }, } class TestEndpointDiscoveryManager(BaseEndpointDiscoveryTest): def setUp(self): super().setUp() self.construct_manager() def construct_manager(self, cache=None, time=None, side_effect=None): self.service_model = ServiceModel(self.service_description) self.meta = mock.Mock(spec=ClientMeta) self.meta.service_model = self.service_model self.client = mock.Mock() if side_effect is None: side_effect = [ { 'Endpoints': [ { 'Address': 'new.com', 'CachePeriodInMinutes': 2, } ] } ] self.client.describe_endpoints.side_effect = side_effect self.client.meta = self.meta self.manager = EndpointDiscoveryManager( self.client, cache=cache, current_time=time ) def test_injects_api_version_if_endpoint_operation(self): model = self.service_model.operation_model('DescribeEndpoints') params = {'headers': {}} inject_api_version_header_if_needed(model, params) self.assertEqual( params['headers'].get('x-amz-api-version'), '2018-08-31' ) def test_no_inject_api_version_if_not_endpoint_operation(self): model = self.service_model.operation_model('TestDiscoveryRequired') params = {'headers': {}} inject_api_version_header_if_needed(model, params) self.assertNotIn('x-amz-api-version', params['headers']) def test_gather_identifiers(self): params = {'Foo': 'value1', 'Nested': {'Bar': 'value2'}} operation = self.service_model.operation_model('TestDiscoveryRequired') ids = self.manager.gather_identifiers(operation, params) self.assertEqual(ids, {'Foo': 'value1', 'Bar': 'value2'}) def test_gather_identifiers_none(self): operation = self.service_model.operation_model('TestDiscovery') ids = self.manager.gather_identifiers(operation, {}) self.assertEqual(ids, {}) def test_describe_endpoint(self): kwargs = { 'Operation': 'FooBar', 'Identifiers': {'Foo': 'value1', 'Bar': 'value2'}, } self.manager.describe_endpoint(**kwargs) self.client.describe_endpoints.assert_called_with(**kwargs) def test_describe_endpoint_no_input(self): describe = self.service_description['operations']['DescribeEndpoints'] del describe['input'] self.construct_manager() self.manager.describe_endpoint(Operation='FooBar', Identifiers={}) self.client.describe_endpoints.assert_called_with() def test_describe_endpoint_empty_input(self): describe = self.service_description['operations']['DescribeEndpoints'] describe['input'] = {'shape': 'EmptyStruct'} self.construct_manager() self.manager.describe_endpoint(Operation='FooBar', Identifiers={}) self.client.describe_endpoints.assert_called_with() def test_describe_endpoint_ids_and_operation(self): cache = {} self.construct_manager(cache=cache) ids = {'Foo': 'value1', 'Bar': 'value2'} kwargs = { 'Operation': 'TestDiscoveryRequired', 'Identifiers': ids, } self.manager.describe_endpoint(**kwargs) self.client.describe_endpoints.assert_called_with(**kwargs) key = ((('Bar', 'value2'), ('Foo', 'value1')), 'TestDiscoveryRequired') self.assertIn(key, cache) self.assertEqual(cache[key][0]['Address'], 'new.com') self.manager.describe_endpoint(**kwargs) call_count = self.client.describe_endpoints.call_count self.assertEqual(call_count, 1) def test_describe_endpoint_no_ids_or_operation(self): cache = {} describe = self.service_description['operations']['DescribeEndpoints'] describe['input'] = {'shape': 'EmptyStruct'} self.construct_manager(cache=cache) self.manager.describe_endpoint( Operation='TestDiscoveryRequired', Identifiers={} ) self.client.describe_endpoints.assert_called_with() key = () self.assertIn(key, cache) self.assertEqual(cache[key][0]['Address'], 'new.com') self.manager.describe_endpoint( Operation='TestDiscoveryRequired', Identifiers={} ) call_count = self.client.describe_endpoints.call_count self.assertEqual(call_count, 1) def test_describe_endpoint_expired_entry(self): current_time = time.time() key = () cache = { key: [{'Address': 'old.com', 'Expiration': current_time - 10}] } self.construct_manager(cache=cache) kwargs = { 'Identifiers': {}, 'Operation': 'TestDiscoveryRequired', } self.manager.describe_endpoint(**kwargs) self.client.describe_endpoints.assert_called_with() self.assertIn(key, cache) self.assertEqual(cache[key][0]['Address'], 'new.com') self.manager.describe_endpoint(**kwargs) call_count = self.client.describe_endpoints.call_count self.assertEqual(call_count, 1) def test_describe_endpoint_cache_expiration(self): def _time(): return float(0) cache = {} self.construct_manager(cache=cache, time=_time) self.manager.describe_endpoint( Operation='TestDiscoveryRequired', Identifiers={} ) key = () self.assertIn(key, cache) self.assertEqual(cache[key][0]['Expiration'], float(120)) def test_delete_endpoints_present(self): key = () cache = {key: [{'Address': 'old.com', 'Expiration': 0}]} self.construct_manager(cache=cache) kwargs = { 'Identifiers': {}, 'Operation': 'TestDiscoveryRequired', } self.manager.delete_endpoints(**kwargs) self.assertEqual(cache, {}) def test_delete_endpoints_absent(self): cache = {} self.construct_manager(cache=cache) kwargs = { 'Identifiers': {}, 'Operation': 'TestDiscoveryRequired', } self.manager.delete_endpoints(**kwargs) self.assertEqual(cache, {}) def test_describe_endpoint_optional_fails_no_cache(self): side_effect = [ConnectionError(error=None)] self.construct_manager(side_effect=side_effect) kwargs = {'Operation': 'TestDiscoveryOptional'} endpoint = self.manager.describe_endpoint(**kwargs) self.assertIsNone(endpoint) # This second call should be blocked as we just failed endpoint = self.manager.describe_endpoint(**kwargs) self.assertIsNone(endpoint) self.client.describe_endpoints.call_args_list == [mock.call()] def test_describe_endpoint_optional_fails_stale_cache(self): key = () cache = {key: [{'Address': 'old.com', 'Expiration': 0}]} side_effect = [ConnectionError(error=None)] * 2 self.construct_manager(cache=cache, side_effect=side_effect) kwargs = {'Operation': 'TestDiscoveryOptional'} endpoint = self.manager.describe_endpoint(**kwargs) self.assertEqual(endpoint, 'old.com') # This second call shouldn't go through as we just failed endpoint = self.manager.describe_endpoint(**kwargs) self.assertEqual(endpoint, 'old.com') self.client.describe_endpoints.call_args_list == [mock.call()] def test_describe_endpoint_required_fails_no_cache(self): side_effect = [ConnectionError(error=None)] * 2 self.construct_manager(side_effect=side_effect) kwargs = {'Operation': 'TestDiscoveryRequired'} with self.assertRaises(EndpointDiscoveryRefreshFailed): self.manager.describe_endpoint(**kwargs) # This second call should go through, as we have no cache with self.assertRaises(EndpointDiscoveryRefreshFailed): self.manager.describe_endpoint(**kwargs) describe_count = self.client.describe_endpoints.call_count self.assertEqual(describe_count, 2) def test_describe_endpoint_required_fails_stale_cache(self): key = () cache = {key: [{'Address': 'old.com', 'Expiration': 0}]} side_effect = [ConnectionError(error=None)] * 2 self.construct_manager(cache=cache, side_effect=side_effect) kwargs = {'Operation': 'TestDiscoveryRequired'} endpoint = self.manager.describe_endpoint(**kwargs) self.assertEqual(endpoint, 'old.com') # We have a stale endpoint, so this shouldn't fail or force a refresh endpoint = self.manager.describe_endpoint(**kwargs) self.assertEqual(endpoint, 'old.com') self.client.describe_endpoints.call_args_list == [mock.call()] def test_describe_endpoint_required_force_refresh_success(self): side_effect = [ ConnectionError(error=None), { 'Endpoints': [ { 'Address': 'new.com', 'CachePeriodInMinutes': 2, } ] }, ] self.construct_manager(side_effect=side_effect) kwargs = {'Operation': 'TestDiscoveryRequired'} # First call will fail with self.assertRaises(EndpointDiscoveryRefreshFailed): self.manager.describe_endpoint(**kwargs) self.client.describe_endpoints.call_args_list == [mock.call()] # Force a refresh if the cache is empty but discovery is required endpoint = self.manager.describe_endpoint(**kwargs) self.assertEqual(endpoint, 'new.com') def test_describe_endpoint_retries_after_failing(self): fake_time = mock.Mock() fake_time.side_effect = [0, 100, 200] side_effect = [ ConnectionError(error=None), { 'Endpoints': [ { 'Address': 'new.com', 'CachePeriodInMinutes': 2, } ] }, ] self.construct_manager(side_effect=side_effect, time=fake_time) kwargs = {'Operation': 'TestDiscoveryOptional'} endpoint = self.manager.describe_endpoint(**kwargs) self.assertIsNone(endpoint) self.client.describe_endpoints.call_args_list == [mock.call()] # Second time should try again as enough time has elapsed endpoint = self.manager.describe_endpoint(**kwargs) self.assertEqual(endpoint, 'new.com') class TestEndpointDiscoveryHandler(BaseEndpointDiscoveryTest): def setUp(self): super().setUp() self.manager = mock.Mock(spec=EndpointDiscoveryManager) self.handler = EndpointDiscoveryHandler(self.manager) self.service_model = ServiceModel(self.service_description) def test_register_handler(self): events = mock.Mock(spec=HierarchicalEmitter) self.handler.register(events, 'foo-bar') events.register.assert_any_call( 'before-parameter-build.foo-bar', self.handler.gather_identifiers ) events.register.assert_any_call( 'needs-retry.foo-bar', self.handler.handle_retries ) events.register_first.assert_called_with( 'request-created.foo-bar', self.handler.discover_endpoint ) def test_discover_endpoint(self): request = AWSRequest() request.context = {'discovery': {'identifiers': {}}} self.manager.describe_endpoint.return_value = 'https://new.foo' self.handler.discover_endpoint(request, 'TestOperation') self.assertEqual(request.url, 'https://new.foo') self.manager.describe_endpoint.assert_called_with( Operation='TestOperation', Identifiers={} ) def test_discover_endpoint_fails(self): request = AWSRequest() request.context = {'discovery': {'identifiers': {}}} request.url = 'old.com' self.manager.describe_endpoint.return_value = None self.handler.discover_endpoint(request, 'TestOperation') self.assertEqual(request.url, 'old.com') self.manager.describe_endpoint.assert_called_with( Operation='TestOperation', Identifiers={} ) def test_discover_endpoint_no_protocol(self): request = AWSRequest() request.context = {'discovery': {'identifiers': {}}} self.manager.describe_endpoint.return_value = 'new.foo' self.handler.discover_endpoint(request, 'TestOperation') self.assertEqual(request.url, 'https://new.foo') self.manager.describe_endpoint.assert_called_with( Operation='TestOperation', Identifiers={} ) def test_inject_no_context(self): request = AWSRequest(url='https://original.foo') self.handler.discover_endpoint(request, 'TestOperation') self.assertEqual(request.url, 'https://original.foo') self.manager.describe_endpoint.assert_not_called() def test_gather_identifiers(self): context = {} params = {'Foo': 'value1', 'Nested': {'Bar': 'value2'}} ids = {'Foo': 'value1', 'Bar': 'value2'} model = self.service_model.operation_model('TestDiscoveryRequired') self.manager.gather_identifiers.return_value = ids self.handler.gather_identifiers(params, model, context) self.assertEqual(context['discovery']['identifiers'], ids) def test_gather_identifiers_not_discoverable(self): context = {} model = self.service_model.operation_model('DescribeEndpoints') self.handler.gather_identifiers({}, model, context) self.assertEqual(context, {}) def test_discovery_disabled_but_required(self): model = self.service_model.operation_model('TestDiscoveryRequired') with self.assertRaises(EndpointDiscoveryRequired): block_endpoint_discovery_required_operations(model) def test_discovery_disabled_but_optional(self): context = {} model = self.service_model.operation_model('TestDiscoveryOptional') block_endpoint_discovery_required_operations(model, context=context) self.assertEqual(context, {}) def test_does_not_retry_no_response(self): retry = self.handler.handle_retries(None, None, None) self.assertIsNone(retry) def test_does_not_retry_other_errors(self): parsed_response = {'ResponseMetadata': {'HTTPStatusCode': 200}} response = (None, parsed_response) retry = self.handler.handle_retries(None, response, None) self.assertIsNone(retry) def test_does_not_retry_if_no_context(self): request_dict = {'context': {}} parsed_response = {'ResponseMetadata': {'HTTPStatusCode': 421}} response = (None, parsed_response) retry = self.handler.handle_retries(request_dict, response, None) self.assertIsNone(retry) def _assert_retries(self, parsed_response): request_dict = {'context': {'discovery': {'identifiers': {}}}} response = (None, parsed_response) model = self.service_model.operation_model('TestDiscoveryOptional') retry = self.handler.handle_retries(request_dict, response, model) self.assertEqual(retry, 0) self.manager.delete_endpoints.assert_called_with( Operation='TestDiscoveryOptional', Identifiers={} ) def test_retries_421_status_code(self): parsed_response = {'ResponseMetadata': {'HTTPStatusCode': 421}} self._assert_retries(parsed_response) def test_retries_invalid_endpoint_exception(self): parsed_response = {'Error': {'Code': 'InvalidEndpointException'}} self._assert_retries(parsed_response)