# Copyright 2015 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
import threading
import os
import math
import time
import mock
import tempfile
import shutil
from datetime import datetime, timedelta
import sys

from botocore.vendored import requests
from dateutil.tz import tzlocal
from botocore.exceptions import CredentialRetrievalError

from tests import unittest, IntegerRefresher, BaseEnvVar, random_chars
from tests import temporary_file
from botocore.credentials import EnvProvider, ContainerProvider
from botocore.credentials import InstanceMetadataProvider
from botocore.credentials import Credentials, ReadOnlyCredentials
from botocore.credentials import AssumeRoleProvider
from botocore.credentials import CanonicalNameCredentialSourcer
from botocore.session import Session
from botocore.exceptions import InvalidConfigError, InfiniteLoopConfigError
from botocore.stub import Stubber


class TestCredentialRefreshRaces(unittest.TestCase):
    """Check that concurrent readers never see torn credentials.

    Many threads call ``get_frozen_credentials()`` on one IntegerRefresher
    while it refreshes. Every observed (access, secret, token) triple must
    come from a single refresh.
    """

    def assert_consistent_credentials_seen(self, creds, func):
        """Run ``func`` in 20 threads and validate everything it collected.

        :param creds: The refreshable credentials under test. It is not
            read here; in the tests below ``func`` reaches it via a closure.
        :param func: Callable taking one shared list argument, to which it
            appends ``(access_key, secret_key, token)`` tuples.
        """
        collected = []
        threads = [
            threading.Thread(target=func, args=(collected,))
            for _ in range(20)
        ]
        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()
        for seen in collected:
            # During testing, the refresher uses its current
            # refresh count as the values for the access, secret, and
            # token value. This means that at any given point in time,
            # the credentials should be something like:
            #
            # ReadOnlyCredentials('1', '1', '1')
            # ReadOnlyCredentials('2', '2', '2')
            # ...
            # ReadOnlyCredentials('30', '30', '30')
            #
            # This makes it really easy to verify we see a consistent
            # set of credentials from the same time period. We just
            # check if all the credential values are the same. If
            # we ever see something like:
            #
            # ReadOnlyCredentials('1', '2', '1')
            #
            # We fail. This is because we're using the access_key
            # from the first refresh ('1'), the secret key from
            # the second refresh ('2'), and the token from the
            # first refresh ('1').
            self.assertTrue(seen[0] == seen[1] == seen[2], seen)

    def test_has_no_race_conditions(self):
        # Consistency under load, plus a bound on how often we refresh.
        creds = IntegerRefresher(
            creds_last_for=2,
            advisory_refresh=1,
            mandatory_refresh=0
        )

        def _run_in_thread(collected):
            for _ in range(4000):
                frozen = creds.get_frozen_credentials()
                collected.append((frozen.access_key,
                                  frozen.secret_key,
                                  frozen.token))

        start = time.time()
        self.assert_consistent_credentials_seen(creds, _run_in_thread)
        end = time.time()
        elapsed = end - start
        # creds_last_for = 2 seconds (from above)
        # So, for example, if execution time took 6.1 seconds, then
        # we should see a maximum number of refreshes being
        # ceil(6.1 / 2.0) + 1 = 5
        max_calls_allowed = math.ceil(elapsed / 2.0) + 1
        self.assertLessEqual(
            creds.refresh_counter, max_calls_allowed,
            "Too many cred refreshes, max: %s, actual: %s, "
            "time_delta: %.4f" % (max_calls_allowed,
                                  creds.refresh_counter,
                                  elapsed))

    def test_no_race_for_immediate_advisory_expiration(self):
        # advisory_refresh equals creds_last_for, so presumably the
        # credentials fall inside the advisory window as soon as issued.
        creds = IntegerRefresher(
            creds_last_for=1,
            advisory_refresh=1,
            mandatory_refresh=0
        )

        def _run_in_thread(collected):
            for _ in range(100):
                frozen = creds.get_frozen_credentials()
                collected.append((frozen.access_key,
                                  frozen.secret_key,
                                  frozen.token))

        self.assert_consistent_credentials_seen(creds, _run_in_thread)


class TestAssumeRole(BaseEnvVar):
    """Assume-role credential resolution through a real Session.

    The env, container and instance-metadata providers are replaced by
    mocks, and STS AssumeRole calls are answered by a Stubber. The config
    file lives in a per-test temporary directory.
    """

    def setUp(self):
        super(TestAssumeRole, self).setUp()
        self.tempdir = tempfile.mkdtemp()
        self.config_file = os.path.join(self.tempdir, 'config')
        self.environ['AWS_CONFIG_FILE'] = self.config_file
        self.environ['AWS_ACCESS_KEY_ID'] = 'access_key'
        self.environ['AWS_SECRET_ACCESS_KEY'] = 'secret_key'

        self.metadata_provider = self.mock_provider(InstanceMetadataProvider)
        self.env_provider = self.mock_provider(EnvProvider)
        self.container_provider = self.mock_provider(ContainerProvider)

    def mock_provider(self, provider_cls):
        """Return a spec'd mock of ``provider_cls`` whose ``load()`` -> None.

        ``METHOD`` is copied because ``create_session`` matches providers
        by it. ``CANONICAL_NAME`` is copied because, presumably,
        CanonicalNameCredentialSourcer matches ``credential_source`` values
        against it.
        """
        mock_instance = mock.Mock(spec=provider_cls)
        mock_instance.load.return_value = None
        mock_instance.METHOD = provider_cls.METHOD
        mock_instance.CANONICAL_NAME = provider_cls.CANONICAL_NAME
        return mock_instance

    def tearDown(self):
        try:
            shutil.rmtree(self.tempdir)
        finally:
            # setUp chains to BaseEnvVar.setUp (which provides
            # self.environ), so its paired tearDown must run as well.
            # Otherwise that state leaks into subsequent tests.
            super(TestAssumeRole, self).tearDown()

    def create_session(self, profile=None):
        """Build a Session whose credential chain uses the test providers.

        :param profile: Profile name for both the Session and the
            AssumeRoleProvider.
        :returns: ``(session, stubber)``. The stubber is already activated
            on the STS client that the assume-role provider will use.
        """
        session = Session(profile=profile)

        # We have to set bogus credentials here or otherwise we'll trigger
        # an early credential chain resolution.
        sts = session.create_client(
            'sts',
            aws_access_key_id='spam',
            aws_secret_access_key='eggs',
        )
        stubber = Stubber(sts)
        stubber.activate()
        assume_role_provider = AssumeRoleProvider(
            load_config=lambda: session.full_config,
            client_creator=lambda *args, **kwargs: sts,
            cache={},
            profile_name=profile,
            credential_sourcer=CanonicalNameCredentialSourcer([
                self.env_provider, self.container_provider,
                self.metadata_provider
            ])
        )

        component_name = 'credential_provider'
        resolver = session.get_component(component_name)
        available_methods = [p.METHOD for p in resolver.providers]
        replacements = {
            'env': self.env_provider,
            'iam-role': self.metadata_provider,
            'container-role': self.container_provider,
            'assume-role': assume_role_provider
        }
        # Swap each default provider for its test double in place so the
        # resolver keeps its original ordering.
        for name, provider in replacements.items():
            try:
                index = available_methods.index(name)
            except ValueError:
                # The provider isn't in the session
                continue

            resolver.providers[index] = provider

        session.register_component(component_name, resolver)
        return session, stubber

    def create_assume_role_response(self, credentials, expiration=None):
        """Build a stubbed STS AssumeRole response carrying ``credentials``.

        :param credentials: Object exposing ``access_key``, ``secret_key``
            and ``token``.
        :param expiration: Reported expiration; defaults to 24 hours from
            now via ``some_future_time``.
        """
        if expiration is None:
            expiration = self.some_future_time()

        response = {
            'Credentials': {
                'AccessKeyId': credentials.access_key,
                'SecretAccessKey': credentials.secret_key,
                'SessionToken': credentials.token,
                'Expiration': expiration
            },
            'AssumedRoleUser': {
                'AssumedRoleId': 'myroleid',
                'Arn': 'arn:aws:iam::1234567890:user/myuser'
            }
        }
        return response

    def create_random_credentials(self):
        """Return Credentials with random 'fake-' prefixed values."""
        return Credentials(
            'fake-%s' % random_chars(15),
            'fake-%s' % random_chars(35),
            'fake-%s' % random_chars(45)
        )

    def some_future_time(self):
        """Return a timezone-aware datetime 24 hours from now."""
        timeobj = datetime.now(tzlocal())
        return timeobj + timedelta(hours=24)

    def write_config(self, config):
        """Overwrite the test's config file with ``config``."""
        with open(self.config_file, 'w') as f:
            f.write(config)

    def assert_creds_equal(self, c1, c2):
        """Assert two credential objects hold the same frozen values.

        Either argument may already be a ReadOnlyCredentials; anything else
        is frozen via ``get_frozen_credentials()`` before comparison.
        """
        c1_frozen = c1
        if not isinstance(c1_frozen, ReadOnlyCredentials):
            c1_frozen = c1.get_frozen_credentials()
        c2_frozen = c2
        if not isinstance(c2_frozen, ReadOnlyCredentials):
            c2_frozen = c2.get_frozen_credentials()
        self.assertEqual(c1_frozen, c2_frozen)

    def test_assume_role(self):
        config = (
            '[profile A]\n'
            'role_arn = arn:aws:iam::123456789:role/RoleA\n'
            'source_profile = B\n\n'
            '[profile B]\n'
            'aws_access_key_id = abc123\n'
            'aws_secret_access_key = def456\n'
        )
        self.write_config(config)

        expected_creds = self.create_random_credentials()
        response = self.create_assume_role_response(expected_creds)
        session, stubber = self.create_session(profile='A')
        stubber.add_response('assume_role', response)

        actual_creds = session.get_credentials()
        self.assert_creds_equal(actual_creds, expected_creds)
        stubber.assert_no_pending_responses()

    def test_environment_credential_source(self):
        config = (
            '[profile A]\n'
            'role_arn = arn:aws:iam::123456789:role/RoleA\n'
            'credential_source = Environment\n'
        )
        self.write_config(config)

        environment_creds = self.create_random_credentials()
        self.env_provider.load.return_value = environment_creds

        expected_creds = self.create_random_credentials()
        response = self.create_assume_role_response(expected_creds)
        session, stubber = self.create_session(profile='A')
        stubber.add_response('assume_role', response)

        actual_creds = session.get_credentials()
        self.assert_creds_equal(actual_creds, expected_creds)
        stubber.assert_no_pending_responses()
        self.assertEqual(self.env_provider.load.call_count, 1)

    def test_instance_metadata_credential_source(self):
        config = (
            '[profile A]\n'
            'role_arn = arn:aws:iam::123456789:role/RoleA\n'
            'credential_source = Ec2InstanceMetadata\n'
        )
        self.write_config(config)

        metadata_creds = self.create_random_credentials()
        self.metadata_provider.load.return_value = metadata_creds

        expected_creds = self.create_random_credentials()
        response = self.create_assume_role_response(expected_creds)
        session, stubber = self.create_session(profile='A')
        stubber.add_response('assume_role', response)

        actual_creds = session.get_credentials()
        self.assert_creds_equal(actual_creds, expected_creds)
        stubber.assert_no_pending_responses()
        self.assertEqual(self.metadata_provider.load.call_count, 1)

    def test_container_credential_source(self):
        config = (
            '[profile A]\n'
            'role_arn = arn:aws:iam::123456789:role/RoleA\n'
            'credential_source = EcsContainer\n'
        )
        self.write_config(config)

        container_creds = self.create_random_credentials()
        self.container_provider.load.return_value = container_creds

        expected_creds = self.create_random_credentials()
        response = self.create_assume_role_response(expected_creds)
        session, stubber = self.create_session(profile='A')
        stubber.add_response('assume_role', response)

        actual_creds = session.get_credentials()
        self.assert_creds_equal(actual_creds, expected_creds)
        stubber.assert_no_pending_responses()
        self.assertEqual(self.container_provider.load.call_count, 1)

    def test_invalid_credential_source(self):
        config = (
            '[profile A]\n'
            'role_arn = arn:aws:iam::123456789:role/RoleA\n'
            'credential_source = CustomInvalidProvider\n'
        )
        self.write_config(config)

        with self.assertRaises(InvalidConfigError):
            session, _ = self.create_session(profile='A')
            session.get_credentials()

    def test_misconfigured_source_profile(self):
        config = (
            '[profile A]\n'
            'role_arn = arn:aws:iam::123456789:role/RoleA\n'
            'source_profile = B\n'
            '[profile B]\n'
            'credential_process = command\n'
        )
        self.write_config(config)

        with self.assertRaises(InvalidConfigError):
            session, _ = self.create_session(profile='A')
            session.get_credentials()

    def test_recursive_assume_role(self):
        config = (
            '[profile A]\n'
            'role_arn = arn:aws:iam::123456789:role/RoleA\n'
            'source_profile = B\n\n'
            '[profile B]\n'
            'role_arn = arn:aws:iam::123456789:role/RoleB\n'
            'source_profile = C\n\n'
            '[profile C]\n'
            'aws_access_key_id = abc123\n'
            'aws_secret_access_key = def456\n'
        )
        self.write_config(config)

        profile_b_creds = self.create_random_credentials()
        profile_b_response = self.create_assume_role_response(profile_b_creds)
        profile_a_creds = self.create_random_credentials()
        profile_a_response = self.create_assume_role_response(profile_a_creds)

        session, stubber = self.create_session(profile='A')
        # B's role is assumed first (from C's static keys), then A's.
        stubber.add_response('assume_role', profile_b_response)
        stubber.add_response('assume_role', profile_a_response)

        actual_creds = session.get_credentials()
        self.assert_creds_equal(actual_creds, profile_a_creds)
        stubber.assert_no_pending_responses()

    def test_recursive_assume_role_stops_at_static_creds(self):
        config = (
            '[profile A]\n'
            'role_arn = arn:aws:iam::123456789:role/RoleA\n'
            'source_profile = B\n\n'
            '[profile B]\n'
            'aws_access_key_id = abc123\n'
            'aws_secret_access_key = def456\n'
            'role_arn = arn:aws:iam::123456789:role/RoleB\n'
            'source_profile = C\n\n'
            '[profile C]\n'
            'aws_access_key_id = abc123\n'
            'aws_secret_access_key = def456\n'
        )
        self.write_config(config)

        profile_a_creds = self.create_random_credentials()
        profile_a_response = self.create_assume_role_response(profile_a_creds)
        session, stubber = self.create_session(profile='A')
        # Only one AssumeRole call is expected: B's static keys are used
        # directly instead of following B's own role_arn.
        stubber.add_response('assume_role', profile_a_response)

        actual_creds = session.get_credentials()
        self.assert_creds_equal(actual_creds, profile_a_creds)
        stubber.assert_no_pending_responses()

    def test_infinitely_recursive_assume_role(self):
        config = (
            '[profile A]\n'
            'role_arn = arn:aws:iam::123456789:role/RoleA\n'
            'source_profile = A\n'
        )
        self.write_config(config)

        with self.assertRaises(InfiniteLoopConfigError):
            session, _ = self.create_session(profile='A')
            session.get_credentials()

    def test_self_referential_profile(self):
        config = (
            '[profile A]\n'
            'role_arn = arn:aws:iam::123456789:role/RoleA\n'
            'source_profile = A\n'
            'aws_access_key_id = abc123\n'
            'aws_secret_access_key = def456\n'
        )
        self.write_config(config)

        expected_creds = self.create_random_credentials()
        response = self.create_assume_role_response(expected_creds)
        session, stubber = self.create_session(profile='A')
        stubber.add_response('assume_role', response)

        actual_creds = session.get_credentials()
        self.assert_creds_equal(actual_creds, expected_creds)
        stubber.assert_no_pending_responses()


class TestProcessProvider(unittest.TestCase):
    """Tests for ``credential_process`` using the utils/credentialprocess.py
    helper script, run with the current interpreter."""

    def setUp(self):
        current_dir = os.path.dirname(os.path.abspath(__file__))
        credential_process = os.path.join(
            current_dir, 'utils', 'credentialprocess.py'
        )
        self.credential_process = '%s %s' % (
            sys.executable, credential_process
        )
        # Patch os.environ with a copy so per-test changes are discarded.
        self.environ = os.environ.copy()
        self.environ_patch = mock.patch('os.environ', self.environ)
        self.environ_patch.start()

    def tearDown(self):
        self.environ_patch.stop()

    def test_credential_process(self):
        config = (
            '[profile processcreds]\n'
            'credential_process = %s\n'
        )
        config = config % self.credential_process
        with temporary_file('w') as f:
            f.write(config)
            f.flush()
            self.environ['AWS_CONFIG_FILE'] = f.name

            credentials = Session(profile='processcreds').get_credentials()
            self.assertEqual(credentials.access_key, 'spam')
            self.assertEqual(credentials.secret_key, 'eggs')

    def test_credential_process_returns_error(self):
        config = (
            '[profile processcreds]\n'
            'credential_process = %s --raise-error\n'
        )
        config = config % self.credential_process
        with temporary_file('w') as f:
            f.write(config)
            f.flush()
            self.environ['AWS_CONFIG_FILE'] = f.name

            session = Session(profile='processcreds')

            # This regex validates that there is no substring: b'
            # The reason why we want to validate that is that we want to
            # make sure that stderr is actually decoded so that in
            # exceptional cases the error is properly formatted.
            # As for how the regex works:
            # `(?!b').` is a negative lookahead followed by `.`: it
            # consumes one character only if the text starting at that
            # position is not `b'`. `((?!b').)*` does that zero or more
            # times. The final pattern adds `^` and `$` to anchor the
            # beginning and end of the string so we know the whole string
            # is consumed. Finally `(?s)` at the beginning makes dots match
            # newlines so we can handle a multi-line string.
            reg = r"(?s)^((?!b').)*$"
            # assertRaisesRegexp is a deprecated alias that Python 3.12
            # removed; prefer the modern name and fall back on Python 2.
            assert_raises_regex = (
                getattr(self, 'assertRaisesRegex', None) or
                self.assertRaisesRegexp
            )
            with assert_raises_regex(CredentialRetrievalError, reg):
                session.get_credentials()