# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"). You # may not use this file except in compliance with the License. A copy of # the License is located at # # http://aws.amazon.com/apache2.0/ # # or in the "license" file accompanying this file. This file is # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under thimport mock import contextlib import copy import json import logging import os import socket import threading import mock from nose.tools import assert_equal from tests import temporary_file from tests import ClientHTTPStubber from botocore import xform_name import botocore.session import botocore.config import botocore.exceptions logger = logging.getLogger(__name__) CASES_FILE = os.path.join(os.path.dirname(__file__), 'cases.json') DATA_DIR = os.path.join(os.path.dirname(__file__), 'data/') class RetryableException(botocore.exceptions.EndpointConnectionError): fmt = '{message}' class NonRetryableException(Exception): pass EXPECTED_EXCEPTIONS_THROWN = ( botocore.exceptions.ClientError, NonRetryableException, RetryableException) def test_client_monitoring(): test_cases = _load_test_cases() for case in test_cases: yield _run_test_case, case def _load_test_cases(): with open(CASES_FILE) as f: loaded_tests = json.loads(f.read()) test_cases = _get_cases_with_defaults(loaded_tests) _replace_expected_anys(test_cases) return test_cases def _get_cases_with_defaults(loaded_tests): cases = [] defaults = loaded_tests['defaults'] for case in loaded_tests['cases']: base = copy.deepcopy(defaults) base.update(case) cases.append(base) return cases def _replace_expected_anys(test_cases): for case in test_cases: for expected_event in case['expectedMonitoringEvents']: for entry, value in expected_event.items(): if value in ['ANY_STR', 'ANY_INT']: expected_event[entry] = mock.ANY @contextlib.contextmanager def _configured_session(case_configuration, listener_port): environ = { 'AWS_ACCESS_KEY_ID': case_configuration['accessKey'], 'AWS_SECRET_ACCESS_KEY': 'secret-key', 'AWS_DEFAULT_REGION': case_configuration['region'], 'AWS_DATA_PATH': DATA_DIR, 'AWS_CSM_PORT': listener_port } if 'sessionToken' in case_configuration: environ['AWS_SESSION_TOKEN'] = case_configuration['sessionToken'] environ.update(case_configuration['environmentVariables']) with temporary_file('w') as f: _setup_shared_config( f, case_configuration['sharedConfigFile'], environ) with mock.patch('os.environ', environ): session = botocore.session.Session() if 'maxRetries' in case_configuration: _setup_max_retry_attempts(session, case_configuration) yield session def _setup_shared_config(fileobj, shared_config_options, environ): fileobj.write('[default]\n') for key, value in shared_config_options.items(): fileobj.write('%s = %s\n' % (key, value)) fileobj.flush() environ['AWS_CONFIG_FILE'] = fileobj.name def _setup_max_retry_attempts(session, case_configuration): config = botocore.config.Config( retries={'max_attempts': case_configuration['maxRetries']}) session.set_default_client_config(config) def _run_test_case(case): with MonitoringListener() as listener: with _configured_session( case['configuration'], listener.port) as session: for api_call in case['apiCalls']: _make_api_call(session, api_call) assert_equal( listener.received_events, case['expectedMonitoringEvents']) def _make_api_call(session, api_call): client = session.create_client( api_call['serviceId'].lower().replace(' ', '')) operation_name = api_call['operationName'] client_method = getattr(client, xform_name(operation_name)) with _stubbed_http_layer(client, api_call['attemptResponses']): try: client_method(**api_call['params']) except EXPECTED_EXCEPTIONS_THROWN: pass @contextlib.contextmanager def _stubbed_http_layer(client, attempt_responses): with ClientHTTPStubber(client) as stubber: _add_stubbed_responses(stubber, attempt_responses) yield def _add_stubbed_responses(stubber, attempt_responses): for attempt_response in attempt_responses: if 'sdkException' in attempt_response: sdk_exception = attempt_response['sdkException'] _add_sdk_exception( stubber, sdk_exception['message'], sdk_exception['isRetryable'] ) else: _add_stubbed_response(stubber, attempt_response) def _add_sdk_exception(stubber, message, is_retryable): if is_retryable: stubber.responses.append(RetryableException(message=message)) else: stubber.responses.append(NonRetryableException(message)) def _add_stubbed_response(stubber, attempt_response): headers = attempt_response['responseHeaders'] status_code = attempt_response['httpStatus'] if 'errorCode' in attempt_response: error = { '__type': attempt_response['errorCode'], 'message': attempt_response['errorMessage'] } content = json.dumps(error).encode('utf-8') else: content = b'{}' stubber.add_response(status=status_code, headers=headers, body=content) class MonitoringListener(threading.Thread): _PACKET_SIZE = 1024 * 8 def __init__(self, port=0): threading.Thread.__init__(self) self._socket = None self.port = port self.received_events = [] def __enter__(self): self._socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) self._socket.bind(('127.0.0.1', self.port)) # The socket may have been assigned to an unused port so we # reset the port member after binding. self.port = self._socket.getsockname()[1] self.start() return self def __exit__(self, *args): self._socket.sendto(b'', ('127.0.0.1', self.port)) self.join() self._socket.close() def run(self): logger.debug('Started listener') while True: data = self._socket.recv(self._PACKET_SIZE) logger.debug('Received: %s', data.decode('utf-8')) if not data: return self.received_events.append(json.loads(data.decode('utf-8')))