python-botocore/botocore/httpchecksum.py
2022-05-25 15:10:07 -07:00

462 lines
15 KiB
Python

# Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
""" The interfaces in this module are not intended for public use.
This module defines interfaces for applying checksums to HTTP requests within
the context of botocore. This involves both resolving the checksum to be used
based on client configuration and environment, as well as application of the
checksum to the request.
"""
import base64
import io
import logging
from binascii import crc32
from hashlib import sha1, sha256
from botocore.compat import HAS_CRT
from botocore.exceptions import AwsChunkedWrapperError, FlexibleChecksumError
from botocore.response import StreamingBody
from botocore.utils import (
conditionally_calculate_md5,
determine_content_length,
)
# awscrt is only imported when botocore.compat reports it is available
# (HAS_CRT). Otherwise crt_checksums is None, and the Crt* checksum classes
# below must not be instantiated.
if HAS_CRT:
    from awscrt import checksums as crt_checksums
else:
    crt_checksums = None
logger = logging.getLogger(__name__)
class BaseChecksum:
    """Base class for the incremental checksum implementations in this module.

    Subclasses override :meth:`update` and :meth:`digest`. This class adds
    base64 encoding and helpers that checksum either an in-memory buffer or a
    seekable binary file-like object.
    """

    # Number of bytes read per iteration when checksumming a file object.
    _CHUNK_SIZE = 1024 * 1024

    def update(self, chunk):
        """Feed ``chunk`` into the running checksum (overridden by subclasses)."""

    def digest(self):
        """Return the raw checksum bytes (overridden by subclasses)."""

    def b64digest(self):
        """Return the checksum as a base64-encoded ASCII string."""
        bs = self.digest()
        return base64.b64encode(bs).decode("ascii")

    def _handle_fileobj(self, fileobj):
        """Checksum ``fileobj`` from its current position to EOF.

        The original stream position is restored afterwards. This also
        happens when a read or update raises, so the caller never gets back
        a stream left at an arbitrary offset.
        """
        start_position = fileobj.tell()
        try:
            for chunk in iter(lambda: fileobj.read(self._CHUNK_SIZE), b""):
                self.update(chunk)
        finally:
            fileobj.seek(start_position)

    def handle(self, body):
        """Checksum ``body`` and return the base64-encoded digest.

        :param body: ``bytes``/``bytearray``, or a seekable binary file-like
            object. A file-like object is read from its current position to
            EOF, and its position is then restored.
        :returns: the base64-encoded digest as a ``str``.
        """
        if isinstance(body, (bytes, bytearray)):
            self.update(body)
        else:
            self._handle_fileobj(body)
        return self.b64digest()
class Crc32Checksum(BaseChecksum):
    """CRC32 checksum computed with the standard library's ``binascii.crc32``."""

    def __init__(self):
        # Running CRC value, kept masked to an unsigned 32-bit integer.
        self._int_crc32 = 0

    def update(self, chunk):
        """Fold ``chunk`` into the running CRC32."""
        running = crc32(chunk, self._int_crc32)
        self._int_crc32 = running & 0xFFFFFFFF

    def digest(self):
        """Return the CRC32 as 4 big-endian bytes."""
        return int.to_bytes(self._int_crc32, 4, "big")
class CrtCrc32Checksum(BaseChecksum):
    """CRC32 checksum backed by ``awscrt.checksums.crc32``.

    Only used when the CRT is available (``crt_checksums`` is None otherwise).
    """

    def __init__(self):
        # Running CRC value, kept masked to an unsigned 32-bit integer.
        self._int_crc32 = 0

    def update(self, chunk):
        """Fold ``chunk`` into the running CRC32 via the CRT."""
        previous = self._int_crc32
        self._int_crc32 = crt_checksums.crc32(chunk, previous) & 0xFFFFFFFF

    def digest(self):
        """Return the CRC32 as 4 big-endian bytes."""
        return int.to_bytes(self._int_crc32, 4, "big")
class CrtCrc32cChecksum(BaseChecksum):
    """CRC32C (Castagnoli) checksum backed by ``awscrt.checksums.crc32c``.

    Only used when the CRT is available (``crt_checksums`` is None otherwise).
    """

    def __init__(self):
        # Running CRC value, kept masked to an unsigned 32-bit integer.
        self._int_crc32c = 0

    def update(self, chunk):
        """Fold ``chunk`` into the running CRC32C via the CRT."""
        previous = self._int_crc32c
        self._int_crc32c = crt_checksums.crc32c(chunk, previous) & 0xFFFFFFFF

    def digest(self):
        """Return the CRC32C as 4 big-endian bytes."""
        return int.to_bytes(self._int_crc32c, 4, "big")
class Sha1Checksum(BaseChecksum):
    """SHA-1 checksum backed by :func:`hashlib.sha1`."""

    def __init__(self):
        self._checksum = sha1()

    def update(self, chunk):
        """Feed ``chunk`` into the SHA-1 hash."""
        hasher = self._checksum
        hasher.update(chunk)

    def digest(self):
        """Return the raw SHA-1 digest bytes."""
        hasher = self._checksum
        return hasher.digest()
class Sha256Checksum(BaseChecksum):
    """SHA-256 checksum backed by :func:`hashlib.sha256`."""

    def __init__(self):
        self._checksum = sha256()

    def update(self, chunk):
        """Feed ``chunk`` into the SHA-256 hash."""
        hasher = self._checksum
        hasher.update(chunk)

    def digest(self):
        """Return the raw SHA-256 digest bytes."""
        hasher = self._checksum
        return hasher.digest()
class AwsChunkedWrapper:
    """File-like wrapper that encodes ``raw`` with ``aws-chunked`` framing.

    Each chunk is emitted as ``<hex length>\\r\\n<data>\\r\\n``, and the stream
    ends with a zero-length chunk. When ``checksum_cls`` is given, the
    checksum of all raw bytes follows that final chunk as a trailer line
    ``<checksum_name>:<base64 digest>``.
    """

    _DEFAULT_CHUNK_SIZE = 1024 * 1024

    def __init__(
        self,
        raw,
        checksum_cls=None,
        checksum_name="x-amz-checksum",
        chunk_size=None,
    ):
        """
        :param raw: binary file-like object supplying the un-encoded body.
            Must support ``read``, plus ``seek`` if the wrapper is rewound.
        :param checksum_cls: optional zero-argument factory for an object
            with ``update`` and ``b64digest``. When given, its digest is sent
            as a trailer.
        :param checksum_name: name used for the checksum trailer.
        :param chunk_size: bytes requested from ``raw`` per chunk
            (defaults to 1 MiB).
        """
        self._raw = raw
        self._checksum_name = checksum_name
        self._checksum_cls = checksum_cls
        self._reset()
        if chunk_size is None:
            chunk_size = self._DEFAULT_CHUNK_SIZE
        self._chunk_size = chunk_size

    def _reset(self):
        """Drop buffered output, clear completion and restart the checksum."""
        self._remaining = b""
        self._complete = False
        self._checksum = None
        if self._checksum_cls:
            self._checksum = self._checksum_cls()

    def seek(self, offset, whence=0):
        """Rewind to the start of the stream; no other seek is supported.

        :raises AwsChunkedWrapperError: if ``offset``/``whence`` request any
            position other than the start of the stream.
        """
        if offset != 0 or whence != 0:
            raise AwsChunkedWrapperError(
                error_msg="Can only seek to start of stream"
            )
        self._reset()
        self._raw.seek(0)

    def read(self, size=None):
        """Return up to ``size`` bytes of encoded output.

        ``None`` or a non-positive ``size`` reads everything that remains.
        Returns ``b""`` once the stream is exhausted.
        """
        # Normalize "read all" size values to None
        if size is not None and size <= 0:
            size = None
        # If the underlying body is done and we have nothing left then
        # end the stream
        if self._complete and not self._remaining:
            return b""
        # Collect chunks in a list and join once. Appending to a bytes
        # buffer with += copies the whole buffer each time, which is
        # quadratic when a large body is read in a single call.
        parts = [self._remaining]
        buffered = len(self._remaining)
        while not self._complete and (size is None or size > buffered):
            chunk = self._make_chunk()
            parts.append(chunk)
            buffered += len(chunk)
        data = b"".join(parts)
        # If size was None, we want to return everything
        if size is None:
            size = len(data)
        self._remaining = data[size:]
        return data[:size]

    def _make_chunk(self):
        """Read one chunk from ``raw`` and return it aws-chunked encoded.

        An empty read marks the stream complete. The terminating zero-length
        chunk then carries the checksum trailer, if a checksum is configured.
        """
        # NOTE: Chunk size is not deterministic as read could return less. This
        # means we cannot know the content length of the encoded aws-chunked
        # stream ahead of time without ensuring a consistent chunk size
        raw_chunk = self._raw.read(self._chunk_size)
        self._complete = not raw_chunk
        if self._checksum:
            self._checksum.update(raw_chunk)
        if self._checksum and self._complete:
            name = self._checksum_name.encode("ascii")
            checksum = self._checksum.b64digest().encode("ascii")
            return b"0\r\n%s:%s\r\n\r\n" % (name, checksum)
        hex_len = b"%x" % len(raw_chunk)
        return b"%s\r\n%s\r\n" % (hex_len, raw_chunk)

    def __iter__(self):
        """Yield encoded output chunk by chunk until the stream is exhausted.

        Bytes already buffered by an earlier partial :meth:`read` are yielded
        first, so mixing ``read`` and iteration never drops data.
        """
        if self._remaining:
            pending, self._remaining = self._remaining, b""
            yield pending
        while not self._complete:
            yield self._make_chunk()
class StreamingChecksumBody(StreamingBody):
    """StreamingBody that checksums data as it is read and validates it.

    :param raw_stream: underlying stream, passed to ``StreamingBody``.
    :param content_length: passed through to ``StreamingBody``.
    :param checksum: checksum object exposing ``update``, ``digest`` and
        ``b64digest``.
    :param expected: expected checksum, base64-encoded.
    """

    def __init__(self, raw_stream, content_length, checksum, expected):
        super().__init__(raw_stream, content_length)
        self._checksum = checksum
        self._expected = expected

    def read(self, amt=None):
        """Read like ``StreamingBody.read`` and validate once fully consumed.

        Validation runs after a read-all call (``amt is None``). It also runs
        when a positive-size read returns no data, which is treated as end of
        stream.

        :raises FlexibleChecksumError: if the computed checksum differs from
            the expected one.
        """
        data = super().read(amt=amt)
        self._checksum.update(data)
        finished = amt is None or (not data and amt > 0)
        if finished:
            self._validate_checksum()
        return data

    def _validate_checksum(self):
        """Raise FlexibleChecksumError unless the digests match."""
        matches = self._checksum.digest() == base64.b64decode(self._expected)
        if matches:
            return
        error_msg = (
            f"Expected checksum {self._expected} did not match calculated "
            f"checksum: {self._checksum.b64digest()}"
        )
        raise FlexibleChecksumError(error_msg=error_msg)
def resolve_checksum_context(request, operation_model, params):
    """Resolve both request and response checksum settings for ``request``.

    The request algorithm is resolved first, then the response algorithms.
    Results are stored under ``request["context"]["checksum"]``.
    """
    for resolver in (
        resolve_request_checksum_algorithm,
        resolve_response_checksum_algorithms,
    ):
        resolver(request, operation_model, params)
def resolve_request_checksum_algorithm(
    request,
    operation_model,
    params,
    supported_algorithms=None,
):
    """Decide how, if at all, the outgoing request body is checksummed.

    The decision is stored in
    ``request["context"]["checksum"]["request_algorithm"]``:

    * a dict with ``algorithm``, ``in`` (``"header"``/``"trailer"``) and
      ``name``, when the caller set the operation's request-algorithm member;
    * the string ``"conditional-md5"``, when the operation only requires a
      checksum (legacy Content-MD5 behavior);
    * nothing, otherwise, or when the checksum header is already present.

    :param supported_algorithms: lowercase algorithm names that are accepted.
        Defaults to every algorithm this module can compute.
    :raises FlexibleChecksumError: if the requested algorithm is unsupported.
    """
    http_checksum = operation_model.http_checksum
    member = http_checksum.get("requestAlgorithmMember")

    if not (member and member in params):
        # No flexible checksum requested: fall back to the old http checksum
        # behavior (Content-MD5) when the operation requires a checksum.
        if operation_model.http_checksum_required or http_checksum.get(
            "requestChecksumRequired"
        ):
            context = request["context"].get("checksum", {})
            context["request_algorithm"] = "conditional-md5"
            request["context"]["checksum"] = context
        return

    # The client opted into flexible checksums and the request supports them,
    # so they take precedence over the checksum-required behavior.
    if supported_algorithms is None:
        supported_algorithms = _SUPPORTED_CHECKSUM_ALGORITHMS
    algorithm_name = params[member].lower()
    if algorithm_name not in supported_algorithms:
        raise FlexibleChecksumError(
            error_msg="Unsupported checksum algorithm: %s" % algorithm_name
        )

    # Operations with streaming input must support trailers. Only unsigned
    # trailer checksums are implemented, and they disable payload signing,
    # so trailers are used only over TLS.
    location_type = "header"
    streaming = operation_model.has_streaming_input
    if streaming and request["url"].startswith("https:"):
        location_type = "trailer"

    algorithm = {
        "algorithm": algorithm_name,
        "in": location_type,
        "name": "x-amz-checksum-%s" % algorithm_name,
    }
    if algorithm["name"] in request["headers"]:
        # The caller already supplied this checksum header; skip calculation.
        return

    context = request["context"].get("checksum", {})
    context["request_algorithm"] = algorithm
    request["context"]["checksum"] = context
def apply_request_checksum(request):
    """Apply the checksum chosen by ``resolve_request_checksum_algorithm``.

    Reads ``request["context"]["checksum"]["request_algorithm"]`` and acts on
    it as follows:

    * ``"conditional-md5"``: calls ``conditionally_calculate_md5``;
    * a dict with ``in == "header"``: sets a checksum header;
    * a dict with ``in == "trailer"``: wraps the body with a checksum trailer.

    Does nothing when no algorithm was resolved.

    :raises FlexibleChecksumError: for any other ``in`` value.
    """
    algorithm = (
        request.get("context", {}).get("checksum", {}).get("request_algorithm")
    )
    if not algorithm:
        return
    if algorithm == "conditional-md5":
        # Special case to handle the http checksum required trait
        conditionally_calculate_md5(request)
        return
    location = algorithm["in"]
    if location == "header":
        _apply_request_header_checksum(request)
    elif location == "trailer":
        _apply_request_trailer_checksum(request)
    else:
        raise FlexibleChecksumError(
            error_msg="Unknown checksum variant: %s" % location
        )
def _apply_request_header_checksum(request):
    """Compute the request body's checksum and send it in a header.

    The algorithm and header name come from
    ``request["context"]["checksum"]["request_algorithm"]``. A header the
    caller already set is left untouched.

    :raises FlexibleChecksumError: if no checksum implementation is
        registered for the resolved algorithm.
    """
    checksum_context = request.get("context", {}).get("checksum", {})
    algorithm = checksum_context.get("request_algorithm")
    location_name = algorithm["name"]
    if location_name in request["headers"]:
        # If the header is already set by the customer, skip calculation
        return
    checksum_cls = _CHECKSUM_CLS.get(algorithm["algorithm"])
    if checksum_cls is None:
        # Resolution only validates against `supported_algorithms`, which a
        # caller may override (or which may name crc32c without the CRT).
        # Fail with a clear error instead of "'NoneType' is not callable".
        raise FlexibleChecksumError(
            error_msg="Unsupported checksum algorithm: %s"
            % algorithm["algorithm"]
        )
    digest = checksum_cls().handle(request["body"])
    request["headers"][location_name] = digest
def _apply_request_trailer_checksum(request):
    """Wrap the request body in aws-chunked encoding with a checksum trailer.

    Sets ``Transfer-Encoding: chunked`` and ``Content-Encoding: aws-chunked``,
    and announces the trailer name via ``X-Amz-Trailer``. When the length can
    be determined, it also sends the un-encoded length as
    ``X-Amz-Decoded-Content-Length``. If the caller already set the checksum
    header, nothing is changed.

    :raises FlexibleChecksumError: if no checksum implementation is
        registered for the resolved algorithm.
    """
    checksum_context = request.get("context", {}).get("checksum", {})
    algorithm = checksum_context.get("request_algorithm")
    location_name = algorithm["name"]
    headers = request["headers"]
    if location_name in headers:
        # If the header is already set by the customer, skip calculation
        return

    checksum_cls = _CHECKSUM_CLS.get(algorithm["algorithm"])
    if checksum_cls is None:
        # Without a checksum class the wrapper would emit no trailer, even
        # though X-Amz-Trailer announces one. Fail loudly instead.
        raise FlexibleChecksumError(
            error_msg="Unsupported checksum algorithm: %s"
            % algorithm["algorithm"]
        )

    body = request["body"]
    if body is None:
        body = b""
    elif isinstance(body, str):
        # The chunked encoder and the checksums operate on bytes. Encode
        # text up front so the decoded length below counts bytes.
        body = body.encode("utf-8")

    headers["Transfer-Encoding"] = "chunked"
    headers["Content-Encoding"] = "aws-chunked"
    headers["X-Amz-Trailer"] = location_name
    content_length = determine_content_length(body)
    if content_length is not None:
        # Send the decoded content length if we can determine it. Some
        # services such as S3 may require the decoded content length
        headers["X-Amz-Decoded-Content-Length"] = str(content_length)

    if isinstance(body, (bytes, bytearray)):
        body = io.BytesIO(body)
    request["body"] = AwsChunkedWrapper(
        body,
        checksum_cls=checksum_cls,
        checksum_name=location_name,
    )
def resolve_response_checksum_algorithms(
    request, operation_model, params, supported_algorithms=None
):
    """Record which response checksums may be validated, in priority order.

    This only runs when the caller set the operation's validation-mode member.
    The result is stored in
    ``request["context"]["checksum"]["response_algorithms"]``. It contains
    the operation's ``responseAlgorithms`` (lowercased) that are also in
    ``supported_algorithms``, ordered by ``_ALGORITHMS_PRIORITY_LIST``.
    """
    http_checksum = operation_model.http_checksum
    mode_member = http_checksum.get("requestValidationModeMember")
    if not mode_member or mode_member not in params:
        return
    if supported_algorithms is None:
        supported_algorithms = _SUPPORTED_CHECKSUM_ALGORITHMS
    offered = {
        name.lower() for name in http_checksum.get("responseAlgorithms", [])
    }
    usable_algorithms = [
        name
        for name in _ALGORITHMS_PRIORITY_LIST
        if name in offered and name in supported_algorithms
    ]
    context = request["context"].get("checksum", {})
    context["response_algorithms"] = usable_algorithms
    request["context"]["checksum"] = context
def handle_checksum_body(http_response, response, context, operation_model):
    """Attach checksum validation to a response body when possible.

    Each algorithm in ``context["checksum"]["response_algorithms"]`` is tried
    in order. The first one whose ``x-amz-checksum-<algorithm>`` header is
    present and contains no ``-`` is used:

    * streaming outputs get a body that validates as it is read;
    * other bodies are validated immediately.

    The chosen algorithm is recorded in
    ``response["context"]["checksum"]["response_algorithm"]``. If no usable
    header is found, validation is skipped and an info message is logged.

    :raises FlexibleChecksumError: for a non-streaming body whose checksum
        does not match.
    """
    headers = response["headers"]
    checksum_context = context.get("checksum", {})
    algorithms = checksum_context.get("response_algorithms")
    if not algorithms:
        return
    for algorithm in algorithms:
        header_name = "x-amz-checksum-%s" % algorithm
        # If the header is not found, check the next algorithm
        if header_name not in headers:
            continue
        # If a - is in the checksum this is not valid Base64. S3 returns
        # checksums that include a -# suffix to indicate a checksum derived
        # from the hash of all part checksums. We cannot wrap this response
        if "-" in headers[header_name]:
            continue
        if operation_model.has_streaming_output:
            response["body"] = _handle_streaming_response(
                http_response, response, algorithm
            )
        else:
            response["body"] = _handle_bytes_response(
                http_response, response, algorithm
            )
        # Expose metadata that the checksum check actually occurred
        checksum_context = response["context"].get("checksum", {})
        checksum_context["response_algorithm"] = algorithm
        response["context"]["checksum"] = checksum_context
        return
    # Lazy %-style arguments: the message is only formatted if INFO is enabled.
    logger.info(
        "Skipping checksum validation. Response did not contain one of the "
        "following algorithms: %s.",
        algorithms,
    )
def _handle_streaming_response(http_response, response, algorithm):
    """Wrap the raw HTTP stream so its checksum is validated while reading.

    :returns: a ``StreamingChecksumBody`` over ``http_response.raw``. The
        expected value comes from the ``x-amz-checksum-<algorithm>`` header.
    """
    checksum_cls = _CHECKSUM_CLS.get(algorithm)
    resp_headers = response["headers"]
    return StreamingChecksumBody(
        raw_stream=http_response.raw,
        content_length=resp_headers.get("content-length"),
        checksum=checksum_cls(),
        expected=resp_headers["x-amz-checksum-%s" % algorithm],
    )
def _handle_bytes_response(http_response, response, algorithm):
    """Validate a fully-buffered response body against its checksum header.

    :returns: ``http_response.content`` unchanged when the checksum matches.
    :raises FlexibleChecksumError: when the computed checksum differs from
        the ``x-amz-checksum-<algorithm>`` header value.
    """
    body = http_response.content
    checksum = _CHECKSUM_CLS.get(algorithm)()
    checksum.update(body)
    expected = response["headers"]["x-amz-checksum-%s" % algorithm]
    if checksum.digest() == base64.b64decode(expected):
        return body
    raise FlexibleChecksumError(
        error_msg=(
            f"Expected checksum {expected} did not match calculated "
            f"checksum: {checksum.b64digest()}"
        )
    )
# Maps lowercase algorithm names to the checksum class that implements them.
_CHECKSUM_CLS = {
    "crc32": Crc32Checksum,
    "sha1": Sha1Checksum,
    "sha256": Sha256Checksum,
}
if HAS_CRT:
    # Use CRT checksum implementations if available. This replaces the
    # stdlib crc32 and adds crc32c, which is only available with the CRT.
    _CHECKSUM_CLS.update(
        {"crc32": CrtCrc32Checksum, "crc32c": CrtCrc32cChecksum}
    )
# Default set of algorithms accepted when resolving request/response
# checksums (every algorithm registered above).
_SUPPORTED_CHECKSUM_ALGORITHMS = list(_CHECKSUM_CLS.keys())
# Preference order used when choosing which response checksum to validate.
_ALGORITHMS_PRIORITY_LIST = ['crc32c', 'crc32', 'sha1', 'sha256']