python-botocore/tests/functional/csm/test_monitoring.py

209 lines
6.6 KiB
Python
Raw Normal View History

2018-12-28 08:05:06 +01:00
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
import contextlib
import copy
import json
import logging
import os
import socket
import threading
2021-10-04 18:33:37 +02:00
import pytest
2018-12-28 08:05:06 +01:00
2021-09-22 22:53:42 +02:00
from tests import mock
2018-12-28 08:05:06 +01:00
from tests import temporary_file
from tests import ClientHTTPStubber
from botocore import xform_name
import botocore.session
import botocore.config
import botocore.exceptions
logger = logging.getLogger(__name__)

# Fixture locations, resolved relative to this test module's directory.
_TEST_DIR = os.path.dirname(__file__)
CASES_FILE = os.path.join(_TEST_DIR, 'cases.json')
DATA_DIR = os.path.join(_TEST_DIR, 'data/')
class RetryableException(botocore.exceptions.EndpointConnectionError):
    """Stubbed SDK failure injected for attempts flagged ``isRetryable``.

    Derives from ``EndpointConnectionError`` -- presumably so botocore's
    retry handling treats it like a connection failure and retries the
    attempt (NOTE(review): confirm against the retry handler).
    """

    # Exception text is just the ``message`` keyword given at construction
    # (assumes botocore formats ``fmt`` with the constructor kwargs).
    fmt = '{message}'
class NonRetryableException(Exception):
    """Stubbed SDK failure injected for attempts not flagged ``isRetryable``."""
# Exceptions a stubbed API call is allowed to raise during a test case;
# _make_api_call swallows these so the remaining calls still run.
EXPECTED_EXCEPTIONS_THROWN = (
    botocore.exceptions.ClientError,
    NonRetryableException,
    RetryableException,
)
def _load_test_cases():
    """Load and normalise the client-side monitoring test cases.

    Reads ``CASES_FILE`` (a JSON object with ``defaults`` and ``cases``
    keys), merges every case over a deep copy of the defaults, and replaces
    the ``ANY_STR``/``ANY_INT`` placeholders in each case's expected
    monitoring events with ``mock.ANY``.

    :returns: list of fully-populated test-case dicts.
    """
    # JSON text is UTF-8; decode explicitly instead of relying on the
    # platform's locale-dependent default encoding.
    with open(CASES_FILE, encoding='utf-8') as f:
        loaded_tests = json.load(f)
    test_cases = _get_cases_with_defaults(loaded_tests)
    _replace_expected_anys(test_cases)
    return test_cases
def _get_cases_with_defaults(loaded_tests):
cases = []
defaults = loaded_tests['defaults']
for case in loaded_tests['cases']:
base = copy.deepcopy(defaults)
base.update(case)
cases.append(base)
return cases
def _replace_expected_anys(test_cases):
for case in test_cases:
for expected_event in case['expectedMonitoringEvents']:
for entry, value in expected_event.items():
if value in ['ANY_STR', 'ANY_INT']:
expected_event[entry] = mock.ANY
2021-10-04 18:33:37 +02:00
@pytest.mark.parametrize(argnames='test_case', argvalues=_load_test_cases())
def test_client_monitoring(test_case):
    """Replay one case from cases.json and check the emitted CSM events."""
    _run_test_case(test_case)
2018-12-28 08:05:06 +01:00
@contextlib.contextmanager
def _configured_session(case_configuration, listener_port):
    """Yield a botocore Session set up for a single CSM test case.

    While the context is active, ``os.environ`` is replaced by a plain dict
    holding a fixed secret key, the case's access key and region (plus the
    session token when present), ``AWS_DATA_PATH`` pointing at ``DATA_DIR``,
    ``AWS_CSM_PORT`` pointing at the listener, the case's
    ``environmentVariables``, and ``AWS_CONFIG_FILE`` naming a temporary
    shared config file written from ``sharedConfigFile``. If the case sets
    ``maxRetries``, it is installed as the session's default retry limit.

    :param case_configuration: the ``configuration`` dict of a test case.
    :param listener_port: UDP port the monitoring listener is bound to.
    """
    environ = {
        'AWS_ACCESS_KEY_ID': case_configuration['accessKey'],
        'AWS_SECRET_ACCESS_KEY': 'secret-key',
        'AWS_DEFAULT_REGION': case_configuration['region'],
        'AWS_DATA_PATH': DATA_DIR,
        # A real process environment only holds strings; stringify the port
        # so botocore's env-var parsing path is exercised as in production.
        'AWS_CSM_PORT': str(listener_port),
    }
    if 'sessionToken' in case_configuration:
        environ['AWS_SESSION_TOKEN'] = case_configuration['sessionToken']
    # Applied last so case-specific variables can override the values above.
    environ.update(case_configuration['environmentVariables'])
    with temporary_file('w') as f:
        _setup_shared_config(
            f, case_configuration['sharedConfigFile'], environ)
        with mock.patch('os.environ', environ):
            session = botocore.session.Session()
            if 'maxRetries' in case_configuration:
                _setup_max_retry_attempts(session, case_configuration)
            yield session
def _setup_shared_config(fileobj, shared_config_options, environ):
fileobj.write('[default]\n')
for key, value in shared_config_options.items():
fileobj.write('%s = %s\n' % (key, value))
fileobj.flush()
environ['AWS_CONFIG_FILE'] = fileobj.name
def _setup_max_retry_attempts(session, case_configuration):
    """Install a default client config using the case's ``maxRetries``.

    The value is passed as ``retries['max_attempts']`` of a
    ``botocore.config.Config`` set via ``set_default_client_config``.
    """
    retry_settings = {'max_attempts': case_configuration['maxRetries']}
    default_config = botocore.config.Config(retries=retry_settings)
    session.set_default_client_config(default_config)
def _run_test_case(case):
    """Replay ``case['apiCalls']`` and compare the captured monitoring events.

    The comparison runs only after the listener context has exited. Its
    ``__exit__`` sends an empty sentinel datagram and joins the receiver
    thread, so every event received before the sentinel is recorded by then.
    """
    with MonitoringListener() as listener, _configured_session(
            case['configuration'], listener.port) as session:
        for call_spec in case['apiCalls']:
            _make_api_call(session, call_spec)
    expected_events = case['expectedMonitoringEvents']
    assert listener.received_events == expected_events
2018-12-28 08:05:06 +01:00
def _make_api_call(session, api_call):
    """Invoke one API call described by ``api_call`` against stubbed HTTP.

    The client's service name is ``serviceId`` lower-cased with spaces
    removed. The operation name goes through ``xform_name`` to obtain the
    client method, and ``attemptResponses`` supply the canned HTTP results.
    Exceptions listed in ``EXPECTED_EXCEPTIONS_THROWN`` are ignored, since
    the test checks the emitted events, not the outcome of the call.
    """
    service_name = api_call['serviceId'].lower().replace(' ', '')
    client = session.create_client(service_name)
    method_name = xform_name(api_call['operationName'])
    operation = getattr(client, method_name)
    with _stubbed_http_layer(client, api_call['attemptResponses']):
        with contextlib.suppress(*EXPECTED_EXCEPTIONS_THROWN):
            operation(**api_call['params'])
@contextlib.contextmanager
def _stubbed_http_layer(client, attempt_responses):
    """Serve ``attempt_responses`` to ``client`` for the duration of the block.

    A ``ClientHTTPStubber`` wraps the client, and the responses are queued
    on it in order before control returns to the caller.
    """
    stubber_context = ClientHTTPStubber(client)
    with stubber_context as http_stubber:
        _add_stubbed_responses(http_stubber, attempt_responses)
        yield
def _add_stubbed_responses(stubber, attempt_responses):
    """Queue one stubbed outcome per attempt, preserving order.

    An attempt containing ``sdkException`` becomes a raised exception.
    Any other attempt becomes an HTTP response.
    """
    for attempt in attempt_responses:
        if 'sdkException' not in attempt:
            _add_stubbed_response(stubber, attempt)
            continue
        exception_spec = attempt['sdkException']
        _add_sdk_exception(
            stubber,
            exception_spec['message'],
            exception_spec['isRetryable'],
        )
def _add_sdk_exception(stubber, message, is_retryable):
    """Append a stubbed exception carrying ``message`` to the stubber.

    ``RetryableException`` is used when ``is_retryable`` is truthy, and
    ``NonRetryableException`` otherwise.
    """
    if is_retryable:
        exception = RetryableException(message=message)
    else:
        exception = NonRetryableException(message)
    stubber.responses.append(exception)
def _add_stubbed_response(stubber, attempt_response):
headers = attempt_response['responseHeaders']
status_code = attempt_response['httpStatus']
if 'errorCode' in attempt_response:
error = {
'__type': attempt_response['errorCode'],
'message': attempt_response['errorMessage']
}
content = json.dumps(error).encode('utf-8')
else:
content = b'{}'
stubber.add_response(status=status_code, headers=headers, body=content)
class MonitoringListener(threading.Thread):
    """UDP listener thread that records client-side monitoring events.

    Used as a context manager. ``__enter__`` binds a datagram socket on
    127.0.0.1 and starts the receiver thread. Port 0 lets the OS pick a free
    port, which is then exposed as ``self.port``. ``__exit__`` sends an empty
    datagram to that port as a shutdown sentinel, joins the thread and closes
    the socket. Every non-empty datagram is decoded as UTF-8 JSON and
    appended to ``received_events``.
    """

    # recv() buffer size; a longer datagram would not be read in full.
    _PACKET_SIZE = 1024 * 8

    def __init__(self, port=0):
        threading.Thread.__init__(self)
        self._socket = None
        self.port = port
        self.received_events = []

    def __enter__(self):
        self._socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        try:
            self._socket.bind(('127.0.0.1', self.port))
            # The socket may have been assigned to an unused port so we
            # reset the port member after binding.
            self.port = self._socket.getsockname()[1]
            self.start()
        except BaseException:
            # __exit__ is not invoked when __enter__ raises (e.g. the port is
            # already in use), so release the socket here instead of leaking.
            self._socket.close()
            self._socket = None
            raise
        return self

    def __exit__(self, *args):
        # An empty datagram tells run() to stop; wait for it before closing.
        self._socket.sendto(b'', ('127.0.0.1', self.port))
        self.join()
        self._socket.close()

    def run(self):
        logger.debug('Started listener')
        while True:
            data = self._socket.recv(self._PACKET_SIZE)
            if not data:
                # Shutdown sentinel sent by __exit__.
                return
            # Decode once and reuse the text for logging and parsing.
            payload = data.decode('utf-8')
            logger.debug('Received: %s', payload)
            self.received_events.append(json.loads(payload))