python-botocore/tests/functional/test_endpoint_rulesets.py
2022-12-12 08:14:19 -08:00

308 lines
11 KiB
Python

# Copyright 2012-2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
import json
from functools import lru_cache
from pathlib import Path
import pytest
from botocore import xform_name
from botocore.compat import HAS_CRT
from botocore.config import Config
from botocore.endpoint_provider import EndpointProvider
from botocore.exceptions import (
BotoCoreError,
ClientError,
EndpointResolutionError,
)
from botocore.loaders import Loader
from botocore.parsers import ResponseParserError
from tests import ClientHTTPStubber
# Root of the per-service endpoint test fixtures; each service's tests live at
# <ENDPOINT_TESTDATA_DIR>/<service_name>/endpoint-tests-1.json.
ENDPOINT_TESTDATA_DIR = Path(__file__).parent / 'endpoint-rules'
# Shared loader used to read rulesets, partitions and the service list.
LOADER = Loader()
# For the purpose of the tests in this file, only services for which an
# endpoint ruleset file exists matter. The existence of required endpoint
# ruleset files is asserted in tests/functional/test_model_completeness.py
ALL_SERVICES = list(
    LOADER.list_available_services(type_name='endpoint-rule-set-1')
)
@pytest.fixture(scope='module')
def partitions():
    """Provide the loaded ``partitions`` data, loaded once per test module."""
    partition_data = LOADER.load_data('partitions')
    return partition_data
@lru_cache(maxsize=None)
def get_endpoint_tests_for_service(service_name):
    """Load and return the parsed ``endpoint-tests-1.json`` for a service.

    The cache is unbounded on purpose. The key space is limited to the
    number of services, and several module-level generators walk every
    service in the same order while pytest collects parameters. A bounded
    LRU cache smaller than the service count would evict each entry before
    it is reused, so every file would be re-read and re-parsed on every pass.

    The returned object is shared between callers through the cache and must
    not be mutated.

    :param service_name: Name of the service directory under
        ``ENDPOINT_TESTDATA_DIR``.
    :returns: The decoded JSON document.
    :raises FileNotFoundError: If the service has no endpoint tests file.
    """
    file_path = ENDPOINT_TESTDATA_DIR / service_name / 'endpoint-tests-1.json'
    if not file_path.is_file():
        # Both halves are f-strings so the actual path appears in the message.
        raise FileNotFoundError(
            f'Cannot find endpoint tests file for "{service_name}" at '
            f'path {file_path}'
        )
    with file_path.open('r', encoding='utf-8') as f:
        return json.load(f)
@pytest.mark.parametrize("service_name", ALL_SERVICES)
def test_all_endpoint_tests_exist(service_name):
    """Every service that ships an endpoint ruleset must also ship an
    endpoint-tests-1.json file containing at least one test case."""
    test_cases = get_endpoint_tests_for_service(service_name)['testCases']
    assert len(test_cases) > 0
def assert_all_signing_region_sets_have_length_one(rule):
    """Assert that each ``signingRegionSet`` under ``rule`` has exactly one
    entry.

    Helper for test_all_signing_region_sets_have_length_one(). ``rule`` and
    every rule nested under its ``rules`` key are visited depth-first in
    pre-order. For each one that has an ``endpoint``, every auth scheme in
    ``endpoint.properties.authSchemes`` that defines ``signingRegionSet`` is
    checked.

    :raises AssertionError: On the first ``signingRegionSet`` whose length is
        not 1.
    """
    pending = [rule]
    while pending:
        current = pending.pop()
        if 'endpoint' in current:
            properties = current['endpoint'].get('properties', {})
            for scheme in properties.get('authSchemes', []):
                if 'signingRegionSet' in scheme:
                    assert len(scheme['signingRegionSet']) == 1
        # Push children in reverse so they are popped in their listed order,
        # which gives the same visiting order as a recursive walk.
        pending.extend(reversed(list(current.get('rules', []))))
@pytest.mark.parametrize("service_name", ALL_SERVICES)
def test_all_signing_region_sets_have_length_one(service_name):
    """Check that every ``signingRegionSet`` in a service's endpoint ruleset
    is a single-element list.

    In principle ``signingRegionSet`` may hold several entries. When this
    test was written, no service used more than one, the meaning of several
    entries was poorly defined, and botocore could not handle them. This test
    exists to fail as soon as a ruleset starts using more than one entry.

    Empty lists also fail. botocore would cope with them, but the expected
    behavior for an empty ``signingRegionSet`` is undefined.
    """
    assert_all_signing_region_sets_have_length_one(
        LOADER.load_service_model(service_name, 'endpoint-rule-set-1')
    )
def test_assert_all_signing_region_sets_have_length_one():
    """Negative test to confirm that
    assert_all_signing_region_sets_have_length_one() actually fails when a
    signingRegionSet has two entries, and also when it is empty."""
    for signing_region_set in (["*", "abc"], []):
        ruleset = {
            "version": "1.0",
            "parameters": {},
            "rules": [
                {
                    "conditions": [],
                    "endpoint": {
                        "url": "https://foo",
                        "properties": {
                            "authSchemes": [
                                {
                                    "name": "sigv4a",
                                    "disableDoubleEncoding": True,
                                    "signingRegionSet": signing_region_set,
                                    "signingName": "myservice",
                                }
                            ]
                        },
                        "headers": {},
                    },
                    "type": "endpoint",
                }
            ],
        }
        with pytest.raises(AssertionError):
            assert_all_signing_region_sets_have_length_one(ruleset)
def iter_all_test_cases():
    """Yield ``(service_name, test_case)`` for each test case of each service.

    Services are visited in ``ALL_SERVICES`` order, and each service's test
    file is loaded only when iteration reaches it.
    """
    for name in ALL_SERVICES:
        cases = get_endpoint_tests_for_service(name)['testCases']
        yield from ((name, case) for case in cases)
def iter_provider_test_cases_that_produce(endpoints=False, errors=False):
    """Yield ``(service_name, params, expected)`` for EndpointProvider tests.

    :param endpoints: If true, yield cases whose ``expect`` has an
        ``endpoint``, with that endpoint as ``expected``.
    :param errors: If true, yield cases whose ``expect`` has an ``error``,
        with that error message as ``expected``.

    ``params`` is the test case's ``params`` mapping, or ``{}`` if absent.
    """
    for service, case in iter_all_test_cases():
        params = case.get('params', {})
        expect = case['expect']
        # The endpoint variant is checked before the error variant.
        for enabled, key in ((endpoints, 'endpoint'), (errors, 'error')):
            if enabled and key in expect:
                yield service, params, expect[key]
def iter_e2e_test_cases_that_produce(endpoints=False, errors=False):
    """Yield ``pytest.param`` objects for end-to-end (client-level) tests.

    Each param holds ``(service_name, op_name, op_params, builtins,
    expected)``. ``expected`` is the expected endpoint when ``endpoints`` is
    true, or the expected error message when ``errors`` is true. One param
    is produced per entry in a test case's ``operationInputs``. Test cases
    without ``operationInputs`` are ignored.

    Endpoint params whose expected auth schemes include ``sigv4a`` are
    marked to be skipped when the CRT is unavailable.
    """
    # Test cases that use invalid bucket names as inputs fail in botocore
    # because botocore validates bucket names before running endpoint
    # resolution.
    skipped_bucket_names = ('bucket name', 'example.com#')
    for service_name, test in iter_all_test_cases():
        # Not all test cases contain operation inputs for end-to-end tests.
        if 'operationInputs' not in test:
            continue
        # The expected result is the same for every input set of this test
        # case, so look it up once.
        expected_object = test['expect']
        # Each test case can contain a list of input sets for the same
        # expected result.
        for op_inputs in test['operationInputs']:
            op_params = op_inputs.get('operationParams', {})
            if op_params.get('Bucket') in skipped_bucket_names:
                continue
            op_name = op_inputs['operationName']
            builtins = op_inputs.get('builtInParams', {})
            if endpoints and 'endpoint' in expected_object:
                expected_endpoint = expected_object['endpoint']
                expected_props = expected_endpoint.get('properties', {})
                expected_authschemes = [
                    auth_scheme['name']
                    for auth_scheme in expected_props.get('authSchemes', [])
                ]
                yield pytest.param(
                    service_name,
                    op_name,
                    op_params,
                    builtins,
                    expected_endpoint,
                    marks=pytest.mark.skipif(
                        'sigv4a' in expected_authschemes and not HAS_CRT,
                        reason="Test case expects sigv4a which requires CRT",
                    ),
                )
            if errors and 'error' in expected_object:
                yield pytest.param(
                    service_name,
                    op_name,
                    op_params,
                    builtins,
                    expected_object['error'],
                )
@pytest.mark.parametrize(
    'service_name, input_params, expected_endpoint',
    iter_provider_test_cases_that_produce(endpoints=True),
)
def test_endpoint_provider_test_cases_yielding_endpoints(
    partitions, service_name, input_params, expected_endpoint
):
    """Resolve an endpoint with EndpointProvider and compare its URL,
    properties and headers to the ruleset test case's expectation.

    Missing ``properties`` / ``headers`` in the expectation mean ``{}``.
    """
    provider = EndpointProvider(
        LOADER.load_service_model(service_name, 'endpoint-rule-set-1'),
        partitions,
    )
    resolved = provider.resolve_endpoint(**input_params)
    assert resolved.url == expected_endpoint['url']
    assert resolved.properties == expected_endpoint.get('properties', {})
    assert resolved.headers == expected_endpoint.get('headers', {})
@pytest.mark.parametrize(
    'service_name, input_params, expected_error',
    iter_provider_test_cases_that_produce(errors=True),
)
def test_endpoint_provider_test_cases_yielding_errors(
    partitions, service_name, input_params, expected_error
):
    """Check that EndpointProvider raises EndpointResolutionError whose
    message exactly equals the test case's expected error string."""
    provider = EndpointProvider(
        LOADER.load_service_model(service_name, 'endpoint-rule-set-1'),
        partitions,
    )
    with pytest.raises(EndpointResolutionError) as raised:
        provider.resolve_endpoint(**input_params)
    assert str(raised.value) == expected_error
@pytest.mark.parametrize(
    'service_name, op_name, op_params, builtin_params, expected_endpoint',
    iter_e2e_test_cases_that_produce(endpoints=True),
)
def test_end_to_end_test_cases_yielding_endpoints(
    patched_session,
    service_name,
    op_name,
    op_params,
    builtin_params,
    expected_endpoint,
):
    """Call an operation through a client and check the request URL.

    The builtins seen by endpoint resolution are replaced with the test
    case's ``builtInParams``. The HTTP layer is stubbed with a single 418
    response. Only the URL of the first outgoing request is checked, so
    client and parser errors caused by the stubbed response are ignored.
    """

    def builtin_overwriter_handler(builtins, **kwargs):
        # Must edit the builtins dict in place (the caller keeps its
        # reference), replacing every existing entry with the test case's
        # builtInParams.
        builtins.clear()
        builtins.update(builtin_params)

    region = builtin_params.get('AWS::Region', 'us-east-1')
    client = patched_session.create_client(
        service_name,
        region_name=region,
        # endpoint ruleset test cases do not account for host prefixes from
        # the operation model
        config=Config(inject_host_prefix=False),
    )
    # register_last so this handler runs after other handlers of the event.
    client.meta.events.register_last(
        'before-endpoint-resolution', builtin_overwriter_handler
    )
    with ClientHTTPStubber(client, strict=True) as http_stubber:
        http_stubber.add_response(status=418)
        op_fn = getattr(client, xform_name(op_name))
        try:
            op_fn(**op_params)
        except (ClientError, ResponseParserError):
            # Expected: the stubbed 418 response is not a valid reply.
            pass
        assert len(http_stubber.requests) > 0, (
            f"No HTTP request was sent for {service_name}.{op_name}"
        )
        actual_url = http_stubber.requests[0].url
        assert actual_url.startswith(
            expected_endpoint['url']
        ), f"{actual_url} does not start with {expected_endpoint['url']}"
@pytest.mark.parametrize(
    'service_name, op_name, op_params, builtin_params, expected_error',
    iter_e2e_test_cases_that_produce(errors=True),
)
def test_end_to_end_test_cases_yielding_errors(
    patched_session,
    service_name,
    op_name,
    op_params,
    builtin_params,
    expected_error,
):
    """Check that an operation fails with a BotoCoreError before any HTTP
    request is sent, for test cases whose ruleset expects an error.

    The builtins seen by endpoint resolution are replaced with the test
    case's ``builtInParams``. ``expected_error`` is accepted for
    parametrization, but its text is not compared here.
    """

    def builtin_overwriter_handler(builtins, **kwargs):
        # Must edit the builtins dict in place (the caller keeps its
        # reference), replacing every existing entry with the test case's
        # builtInParams.
        builtins.clear()
        builtins.update(builtin_params)

    region = builtin_params.get('AWS::Region', 'us-east-1')
    client = patched_session.create_client(service_name, region_name=region)
    # register_last so this handler runs after other handlers of the event.
    client.meta.events.register_last(
        'before-endpoint-resolution', builtin_overwriter_handler
    )
    with ClientHTTPStubber(client, strict=True) as http_stubber:
        http_stubber.add_response(status=418)
        op_fn = getattr(client, xform_name(op_name))
        with pytest.raises(BotoCoreError):
            try:
                op_fn(**op_params)
            except (ClientError, ResponseParserError):
                # If a request reached the stub, swallow the resulting
                # error so pytest.raises reports that no BotoCoreError was
                # raised.
                pass
        assert len(http_stubber.requests) == 0, (
            f"Unexpected HTTP request sent for {service_name}.{op_name}"
        )