110 lines
4.3 KiB
Python
110 lines
4.3 KiB
Python
import random
|
|
import time
|
|
import threading
|
|
from tests import unittest
|
|
|
|
from botocore.retries import bucket
|
|
|
|
|
|
class InstrumentedTokenBucket(bucket.TokenBucket):
    """``bucket.TokenBucket`` that checks a capacity invariant on every grab.

    Each ``_acquire`` call is delegated unchanged to the parent class. Afterwards
    the bucket asserts that ``_current_capacity`` has not gone negative. An
    over-allocation therefore surfaces as an ``AssertionError`` in whichever
    thread performed it, unless Python runs with ``-O``, which strips asserts.
    """

    def _acquire(self, amount, block):
        parent = super(InstrumentedTokenBucket, self)
        acquired = parent._acquire(amount, block)
        # NOTE(review): presumably _acquire is only reached through acquire()
        # while the bucket's lock is held -- confirm, otherwise this read of
        # _current_capacity could race with other threads.
        assert self._current_capacity >= 0
        return acquired
|
|
|
|
|
|
class TestTokenBucketThreading(unittest.TestCase):
    """Concurrency tests for ``bucket.TokenBucket``.

    Worker threads share state with the test through plain attributes on
    ``self``:

    * ``shutdown_threads`` tells worker loops to exit.
    * ``caught_exceptions`` collects anything a worker raised. An exception in
      a worker thread does not fail the test by itself.
    * ``acquisitions_by_thread`` counts successful ``acquire()`` calls per
      thread name.
    """

    # Upper bound, in seconds, for joining one worker thread after shutdown
    # has been requested. Workers are daemon threads, so one that misses this
    # deadline is reported as a test failure instead of hanging the run.
    THREAD_JOIN_TIMEOUT = 10

    def setUp(self):
        self.shutdown_threads = False
        self.caught_exceptions = []
        self.acquisitions_by_thread = {}

    def run_in_thread(self):
        # NOTE(review): no test in this class calls this method. It reads
        # self.max_capacity, self.retry_quota and self.seen_capacities, none of
        # which setUp assigns, so invoking it as-is raises AttributeError.
        # It looks like a leftover from a retry-quota test; consider deleting
        # it. Left in place so the class interface is unchanged.
        while not self.shutdown_threads:
            capacity = random.randint(1, self.max_capacity)
            self.retry_quota.acquire(capacity)
            self.seen_capacities.append(self.retry_quota.available_capacity)
            self.retry_quota.release(capacity)
            self.seen_capacities.append(self.retry_quota.available_capacity)

    def create_clock(self):
        """Return the clock handed to every bucket built by these tests."""
        # Kept as a method, presumably so a subclass can substitute a
        # different clock implementation.
        return bucket.Clock()

    def test_can_change_max_rate_while_blocking(self):
        # This isn't a stress test, we just want to verify we can change
        # the rate at which we acquire a token.
        min_rate = 0.1
        max_rate = 1
        token_bucket = bucket.TokenBucket(
            min_rate=min_rate, max_rate=max_rate,
            clock=self.create_clock(),
        )
        # First we'll set the max_rate to 0.1 (min_rate). This means that
        # it will take 10 seconds to accumulate a single token. We'll start
        # a thread and have it acquire() a token.
        # Then in the main thread we'll change the max_rate to something
        # really quick (e.g. 100). We should immediately get a token back.
        # This is going to be timing sensitive, but we can verify that
        # as long as it doesn't take 10 seconds to get a token, we were
        # able to update the rate as needed.
        slow_token_delay = 1.0 / min_rate
        thread = threading.Thread(target=token_bucket.acquire)
        # Daemon thread: if acquire() never wakes up, the stuck thread must
        # not keep the interpreter alive after this test has failed.
        thread.daemon = True
        token_bucket.max_rate = min_rate
        start_time = time.time()
        thread.start()
        # This shouldn't block the main thread.
        token_bucket.max_rate = 100
        # Bounded join: a bare join() would hang the whole suite, rather
        # than fail this test, if the rate change were never observed.
        thread.join(slow_token_delay)
        end_time = time.time()
        self.assertFalse(
            thread.is_alive(),
            'acquire() was still blocked after %s seconds' % slow_token_delay)
        self.assertLessEqual(end_time - start_time, slow_token_delay)

    def acquire_in_loop(self, token_bucket):
        """Acquire tokens from ``token_bucket`` until shutdown is requested.

        Each successful acquisition is counted under the current thread's
        name in ``acquisitions_by_thread``, and that key must already exist.
        Any exception, including a failed assertion, is appended to
        ``caught_exceptions`` so the main thread can report it.
        """
        # Nothing renames this thread while it runs, so look its name up once.
        thread_name = threading.current_thread().name
        while not self.shutdown_threads:
            try:
                self.assertTrue(token_bucket.acquire())
                self.acquisitions_by_thread[thread_name] += 1
            except Exception as e:
                self.caught_exceptions.append(e)

    def randomly_set_max_rate(self, token_bucket, min_val, max_val):
        """Reassign ``token_bucket.max_rate`` until shutdown is requested.

        Each new rate is a random integer in ``[min_val, max_val]``, and a new
        value is set roughly every 10 ms.
        """
        while not self.shutdown_threads:
            new_rate = random.randint(min_val, max_val)
            token_bucket.max_rate = new_rate
            time.sleep(0.01)

    def test_stress_test_token_bucket(self):
        token_bucket = InstrumentedTokenBucket(
            max_rate=10,
            clock=self.create_clock(),
        )
        all_threads = []
        for _ in range(2):
            all_threads.append(
                threading.Thread(target=self.randomly_set_max_rate,
                                 args=(token_bucket, 30, 200))
            )
        for _ in range(10):
            t = threading.Thread(target=self.acquire_in_loop,
                                 args=(token_bucket,))
            self.acquisitions_by_thread[t.name] = 0
            all_threads.append(t)
        # join() raises RuntimeError on a thread that was never started, so
        # only the threads that actually started are tracked for joining.
        started_threads = []
        try:
            # Starting inside the try ensures that, if a later start() fails,
            # the threads already running are still told to shut down.
            for thread in all_threads:
                thread.daemon = True
                thread.start()
                started_threads.append(thread)
            # If you're working on this code you can bump this number way
            # up to stress test it more locally.
            time.sleep(3)
        finally:
            self.shutdown_threads = True
            for thread in started_threads:
                thread.join(self.THREAD_JOIN_TIMEOUT)
        self.assertEqual(self.caught_exceptions, [])
        stuck = [thread.name for thread in started_threads
                 if thread.is_alive()]
        self.assertEqual(stuck, [], 'Threads did not shut down: %s' % stuck)
        distribution = list(self.acquisitions_by_thread.values())
        mean = sum(distribution) / float(len(distribution))
        # We can't really rely on any guarantees about evenly distributing
        # thread acquisition(), e.g. must be within a 2 stddev range, but we
        # can sanity check that our implementation isn't drastically
        # starving a thread. So we'll arbitrarily say that a thread
        # can't have less than 30% of the mean allocations per thread.
        # A zero mean would make the 30% check pass vacuously, so rule it
        # out first.
        self.assertGreater(mean, 0, 'No thread acquired any tokens')
        starved = {
            name: count
            for name, count in self.acquisitions_by_thread.items()
            if count < 0.3 * mean
        }
        self.assertEqual(
            starved, {},
            'Threads starved of tokens (mean %.1f): %s' % (mean, starved))
|