# Copyright 2012-2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"). You # may not use this file except in compliance with the License. A copy of # the License is located at # # http://aws.amazon.com/apache2.0/ # # or in the "license" file accompanying this file. This file is # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. import json from functools import lru_cache from pathlib import Path import pytest from botocore import xform_name from botocore.compat import HAS_CRT from botocore.config import Config from botocore.endpoint_provider import EndpointProvider from botocore.exceptions import ( BotoCoreError, ClientError, EndpointResolutionError, ) from botocore.loaders import Loader from botocore.parsers import ResponseParserError from tests import ClientHTTPStubber ENDPOINT_TESTDATA_DIR = Path(__file__).parent / 'endpoint-rules' LOADER = Loader() # For the purpose of the tests in this file, only services for which an # endpoint ruleset file exists matter. 
# The existence of required endpoint ruleset files is asserted in
# tests/functional/test_model_completeness.py
ALL_SERVICES = list(
    LOADER.list_available_services(type_name='endpoint-rule-set-1')
)


@pytest.fixture(scope='module')
def partitions():
    """Return the botocore partitions data used to build EndpointProviders."""
    return LOADER.load_data('partitions')


# maxsize=None: the key space is bounded by ALL_SERVICES, and the generators
# below sweep every service in the same order several times. With the default
# maxsize=128, whenever more than 128 services have rulesets each sequential
# sweep evicts every entry before it is reused, so the cache would never hit.
@lru_cache(maxsize=None)
def get_endpoint_tests_for_service(service_name):
    """Load the parsed endpoint-tests-1.json document for a service.

    :param service_name: Name of a service that has an endpoint ruleset.
    :returns: The decoded JSON document. Callers read its ``testCases`` list.
    :raises FileNotFoundError: If the service has no endpoint tests file.
    """
    file_path = ENDPOINT_TESTDATA_DIR / service_name / 'endpoint-tests-1.json'
    if not file_path.is_file():
        raise FileNotFoundError(
            f'Cannot find endpoint tests file for "{service_name}" at '
            f'path {file_path}'
        )
    # JSON text is UTF-8, so do not depend on the platform's locale encoding.
    with file_path.open('r', encoding='utf-8') as f:
        return json.load(f)


@pytest.mark.parametrize("service_name", ALL_SERVICES)
def test_all_endpoint_tests_exist(service_name):
    """Tests the existence of endpoint-tests-1.json for each service that
    has a ruleset and verifies that content is present."""
    data = get_endpoint_tests_for_service(service_name)
    assert len(data['testCases']) > 0


def assert_all_signing_region_sets_have_length_one(rule):
    """Assert that every ``signingRegionSet`` under ``rule`` has one entry.

    Helper for test_all_signing_region_sets_have_length_one(). ``rule`` may
    be a whole ruleset or a single rule. The rules listed under its ``rules``
    key are checked recursively.
    """
    if 'endpoint' in rule:
        auth_schemes = (
            rule['endpoint'].get('properties', {}).get('authSchemes', [])
        )
        for auth_scheme in auth_schemes:
            if 'signingRegionSet' in auth_scheme:
                assert len(auth_scheme['signingRegionSet']) == 1
    for sub_rule in rule.get('rules', []):
        assert_all_signing_region_sets_have_length_one(sub_rule)


@pytest.mark.parametrize("service_name", ALL_SERVICES)
def test_all_signing_region_sets_have_length_one(service_name):
    """Checks all endpoint rulesets for endpoints that contain an authSchemes
    property with a `signingRegionSet` and asserts that it is a list of
    length 1.

    In theory, `signingRegionSet` could have >1 entries. As of writing this
    test, no service uses >1 entry, the meaning of >1 entry is poorly defined,
    and botocore cannot handle >1 entry. This test exists specifically to
    fail if a ruleset ever uses >1 entry.

    The test also fails for empty lists. While botocore would handle these
    gracefully, the expected behavior for empty `signingRegionSet` lists is
    not defined.
    """
    ruleset = LOADER.load_service_model(service_name, 'endpoint-rule-set-1')
    assert_all_signing_region_sets_have_length_one(ruleset)


def test_assert_all_signing_region_sets_have_length_one():
    """Negative test to confirm that
    assert_all_signing_region_sets_have_length_one() actually fails when two
    signingRegionSet entries are present."""
    with pytest.raises(AssertionError):
        assert_all_signing_region_sets_have_length_one(
            {
                "version": "1.0",
                "parameters": {},
                "rules": [
                    {
                        "conditions": [],
                        "endpoint": {
                            "url": "https://foo",
                            "properties": {
                                "authSchemes": [
                                    {
                                        "name": "sigv4a",
                                        "disableDoubleEncoding": True,
                                        "signingRegionSet": ["*", "abc"],
                                        "signingName": "myservice",
                                    }
                                ]
                            },
                            "headers": {},
                        },
                        "type": "endpoint",
                    }
                ],
            }
        )


def iter_all_test_cases():
    """Yield ``(service_name, test_case)`` for every ruleset test case."""
    for service_name in ALL_SERVICES:
        test_data = get_endpoint_tests_for_service(service_name)
        for test_case in test_data['testCases']:
            yield service_name, test_case


def iter_provider_test_cases_that_produce(endpoints=False, errors=False):
    """Yield ``(service_name, params, expected)`` for provider-level tests.

    :param endpoints: Include test cases whose expectation is an endpoint.
    :param errors: Include test cases whose expectation is an error.
    """
    for service_name, test in iter_all_test_cases():
        input_params = test.get('params', {})
        expected_object = test['expect']
        if endpoints and 'endpoint' in expected_object:
            yield service_name, input_params, expected_object['endpoint']
        if errors and 'error' in expected_object:
            yield service_name, input_params, expected_object['error']


def iter_e2e_test_cases_that_produce(endpoints=False, errors=False):
    """Yield pytest params for test cases that include operation inputs.

    Each entry of a test case's ``operationInputs`` becomes a separate param
    ``(service_name, op_name, op_params, builtins, expected)``. Endpoint
    params that expect a ``sigv4a`` auth scheme are skipped without the CRT.

    :param endpoints: Include test cases whose expectation is an endpoint.
    :param errors: Include test cases whose expectation is an error.
    """
    for service_name, test in iter_all_test_cases():
        # Not all test cases contain operation inputs for end-to-end tests.
        if 'operationInputs' not in test:
            continue
        # Each test case can contain a list of input sets for the same
        # expected result.
        for op_inputs in test['operationInputs']:
            op_params = op_inputs.get('operationParams', {})
            # Test cases that use invalid bucket names as inputs fail in
            # botocore because botocore validates bucket names before running
            # endpoint resolution.
            if op_params.get('Bucket') in ['bucket name', 'example.com#']:
                continue
            op_name = op_inputs['operationName']
            builtins = op_inputs.get('builtInParams', {})
            expected_object = test['expect']
            if endpoints and 'endpoint' in expected_object:
                expected_endpoint = expected_object['endpoint']
                expected_props = expected_endpoint.get('properties', {})
                expected_authschemes = [
                    auth_scheme['name']
                    for auth_scheme in expected_props.get('authSchemes', [])
                ]
                yield pytest.param(
                    service_name,
                    op_name,
                    op_params,
                    builtins,
                    expected_endpoint,
                    marks=pytest.mark.skipif(
                        'sigv4a' in expected_authschemes and not HAS_CRT,
                        reason="Test case expects sigv4a which requires CRT",
                    ),
                )
            if errors and 'error' in expected_object:
                yield pytest.param(
                    service_name,
                    op_name,
                    op_params,
                    builtins,
                    expected_object['error'],
                )


@pytest.mark.parametrize(
    'service_name, input_params, expected_endpoint',
    iter_provider_test_cases_that_produce(endpoints=True),
)
def test_endpoint_provider_test_cases_yielding_endpoints(
    partitions, service_name, input_params, expected_endpoint
):
    """Resolve via EndpointProvider and compare url, properties and headers."""
    ruleset = LOADER.load_service_model(service_name, 'endpoint-rule-set-1')
    endpoint_provider = EndpointProvider(ruleset, partitions)
    endpoint = endpoint_provider.resolve_endpoint(**input_params)
    assert endpoint.url == expected_endpoint['url']
    assert endpoint.properties == expected_endpoint.get('properties', {})
    assert endpoint.headers == expected_endpoint.get('headers', {})


@pytest.mark.parametrize(
    'service_name, input_params, expected_error',
    iter_provider_test_cases_that_produce(errors=True),
)
def test_endpoint_provider_test_cases_yielding_errors(
    partitions, service_name, input_params, expected_error
):
    """Resolve via EndpointProvider and compare the error message exactly."""
    ruleset = LOADER.load_service_model(service_name, 'endpoint-rule-set-1')
    endpoint_provider = EndpointProvider(ruleset, partitions)
    with pytest.raises(EndpointResolutionError) as exc_info:
        endpoint_provider.resolve_endpoint(**input_params)
    assert str(exc_info.value) == expected_error


def _make_builtin_overwriter(builtin_params):
    """Return a before-endpoint-resolution handler that installs
    ``builtin_params`` as the complete set of builtins."""

    def builtin_overwriter_handler(builtins, **kwargs):
        # Must edit the builtins dict in place, but also erase every existing
        # entry so that only the test case's builtins remain.
        builtins.clear()
        builtins.update(builtin_params)

    return builtin_overwriter_handler


@pytest.mark.parametrize(
    'service_name, op_name, op_params, builtin_params, expected_endpoint',
    iter_e2e_test_cases_that_produce(endpoints=True),
)
def test_end_to_end_test_cases_yielding_endpoints(
    patched_session,
    service_name,
    op_name,
    op_params,
    builtin_params,
    expected_endpoint,
):
    """Call the operation through a stubbed client and check the request URL
    starts with the expected endpoint URL."""
    region = builtin_params.get('AWS::Region', 'us-east-1')
    client = patched_session.create_client(
        service_name,
        region_name=region,
        # endpoint ruleset test cases do not account for host prefixes from the
        # operation model
        config=Config(inject_host_prefix=False),
    )
    client.meta.events.register_last(
        'before-endpoint-resolution', _make_builtin_overwriter(builtin_params)
    )
    with ClientHTTPStubber(client, strict=True) as http_stubber:
        http_stubber.add_response(status=418)
        op_fn = getattr(client, xform_name(op_name))
        # Only the outgoing request matters here. Errors raised while handling
        # the stubbed 418 response are expected and ignored.
        try:
            op_fn(**op_params)
        except (ClientError, ResponseParserError):
            pass
        assert len(http_stubber.requests) > 0
        actual_url = http_stubber.requests[0].url
        assert actual_url.startswith(
            expected_endpoint['url']
        ), f"{actual_url} does not start with {expected_endpoint['url']}"


@pytest.mark.parametrize(
    'service_name, op_name, op_params, builtin_params, expected_error',
    iter_e2e_test_cases_that_produce(errors=True),
)
def test_end_to_end_test_cases_yielding_errors(
    patched_session,
    service_name,
    op_name,
    op_params,
    builtin_params,
    expected_error,
):
    """Call the operation through a stubbed client and check that a
    BotoCoreError is raised before any request is sent.

    NOTE(review): ``expected_error`` is not compared here. Only the exception
    type and the absence of requests are checked.
    """
    region = builtin_params.get('AWS::Region', 'us-east-1')
    client = patched_session.create_client(service_name, region_name=region)
    client.meta.events.register_last(
        'before-endpoint-resolution', _make_builtin_overwriter(builtin_params)
    )
    with ClientHTTPStubber(client, strict=True) as http_stubber:
        http_stubber.add_response(status=418)
        op_fn = getattr(client, xform_name(op_name))
        with pytest.raises(BotoCoreError):
            # Errors from handling the stubbed 418 response are swallowed, so
            # pytest.raises is only satisfied by a BotoCoreError.
            try:
                op_fn(**op_params)
            except (ClientError, ResponseParserError):
                pass
        assert len(http_stubber.requests) == 0