diff --git a/tuf/ngclient/_internal/download.py b/tuf/ngclient/_internal/download.py
deleted file mode 100644
index 31b59f6630..0000000000
--- a/tuf/ngclient/_internal/download.py
+++ /dev/null
@@ -1,243 +0,0 @@
-#!/usr/bin/env python
-
-# Copyright 2012 - 2017, New York University and the TUF contributors
-# SPDX-License-Identifier: MIT OR Apache-2.0
-
-"""
-<Program Name>
-  download.py
-
-<Started>
-  February 21, 2012. Based on previous version by Geremy Condra.
-
-<Author>
-  Konstantin Andrianov
-  Vladimir Diaz
-
-<Copyright>
-  See LICENSE-MIT OR LICENSE for licensing information.
-
-<Purpose>
-  Download metadata and target files and check their validity. The hash and
-  length of a downloaded file has to match the hash and length supplied by the
-  metadata of that file.
-"""
-
-import logging
-import tempfile
-import timeit
-from urllib import parse
-
-from securesystemslib import formats as sslib_formats
-
-import tuf
-from tuf import exceptions, formats
-
-# See 'log.py' to learn how logging is handled in TUF.
-logger = logging.getLogger(__name__)
-
-
-def download_file(url, required_length, fetcher, strict_required_length=True):
-    """
-
-    Given the url and length of the desired file, this function opens a
-    connection to 'url' and downloads the file while ensuring its length
-    matches 'required_length' if 'STRICT_REQUIRED_LENGH' is True (If False,
-    the file's length is not checked and a slow retrieval exception is raised
-    if the downloaded rate falls below the acceptable rate).
-
-    url:
-      A URL string that represents the location of the file.
-
-    required_length:
-      An integer value representing the length of the file.
-
-    strict_required_length:
-      A Boolean indicator used to signal whether we should perform strict
-      checking of required_length. True by default. We explicitly set this to
-      False when we know that we want to turn this off for downloading the
-      timestamp metadata, which has no signed required_length.
-
-    A file object is created on disk to store the contents of 'url'.
-
-    exceptions.DownloadLengthMismatchError, if there was a
-    mismatch of observed vs expected lengths while downloading the file.
-
-    securesystemslib.exceptions.FormatError, if any of the arguments are
-    improperly formatted.
-
-    Any other unforeseen runtime exception.
-
-    A file object that points to the contents of 'url'.
-    """
-    # Do all of the arguments have the appropriate format?
-    # Raise 'securesystemslib.exceptions.FormatError' if there is a mismatch.
-    sslib_formats.URL_SCHEMA.check_match(url)
-    formats.LENGTH_SCHEMA.check_match(required_length)
-
-    # 'url.replace('\\', '/')' is needed for compatibility with Windows-based
-    # systems, because they might use back-slashes in place of forward-slashes.
-    # This converts it to the common format. unquote() replaces %xx escapes in
-    # a url with their single-character equivalent. A back-slash may be
-    # encoded as %5c in the url, which should also be replaced with a forward
-    # slash.
-    url = parse.unquote(url).replace("\\", "/")
-    logger.info("Downloading: %s", url)
-
-    # This is the temporary file that we will return to contain the contents of
-    # the downloaded file.
-    temp_file = tempfile.TemporaryFile()  # pylint: disable=consider-using-with
-
-    average_download_speed = 0
-    number_of_bytes_received = 0
-
-    try:
-        chunks = fetcher.fetch(url, required_length)
-        start_time = timeit.default_timer()
-        for chunk in chunks:
-
-            stop_time = timeit.default_timer()
-            temp_file.write(chunk)
-
-            # Measure the average download speed.
-            number_of_bytes_received += len(chunk)
-            seconds_spent_receiving = stop_time - start_time
-            average_download_speed = (
-                number_of_bytes_received / seconds_spent_receiving
-            )
-
-            if average_download_speed < tuf.settings.MIN_AVERAGE_DOWNLOAD_SPEED:
-                logger.debug(
-                    "The average download speed dropped below the minimum"
-                    " average download speed set in tuf.settings.py."
-                    " Stopping the download!"
-                )
-                break
-
-            logger.debug(
-                "The average download speed has not dipped below the"
-                " minimum average download speed set in tuf.settings.py."
-            )
-
-        # Does the total number of downloaded bytes match the required length?
-        _check_downloaded_length(
-            number_of_bytes_received,
-            required_length,
-            strict_required_length=strict_required_length,
-            average_download_speed=average_download_speed,
-        )
-
-    except Exception:
-        # Close 'temp_file'. Any written data is lost.
-        temp_file.close()
-        logger.debug("Could not download URL: %s", url)
-        raise
-
-    else:
-        temp_file.seek(0)
-        return temp_file
-
-
-def download_bytes(url, required_length, fetcher, strict_required_length=True):
-    """Download bytes from given url
-
-    Returns the downloaded bytes, otherwise like download_file()
-    """
-    with download_file(
-        url, required_length, fetcher, strict_required_length
-    ) as dl_file:
-        return dl_file.read()
-
-
-def _check_downloaded_length(
-    total_downloaded,
-    required_length,
-    strict_required_length=True,
-    average_download_speed=None,
-):
-    """
-
-    A helper function which checks whether the total number of downloaded
-    bytes matches our expectation.
-
-    total_downloaded:
-      The total number of bytes supposedly downloaded for the file in
-      question.
-
-    required_length:
-      The total number of bytes expected of the file as seen from its metadata
-      The Timestamp role is always downloaded without a known file length, and
-      the Root role when the client cannot download any of the required
-      top-level roles. In both cases, 'required_length' is actually an upper
-      limit on the length of the downloaded file.
-
-    strict_required_length:
-      A Boolean indicator used to signal whether we should perform strict
-      checking of required_length. True by default. We explicitly set this to
-      False when we know that we want to turn this off for downloading the
-      timestamp metadata, which has no signed required_length.
-
-    average_download_speed:
-      The average download speed for the downloaded file.
-
-    None.
-
-    securesystemslib.exceptions.DownloadLengthMismatchError, if
-    strict_required_length is True and total_downloaded is not equal
-    required_length.
-
-    exceptions.SlowRetrievalError, if the total downloaded was
-    done in less than the acceptable download speed (as set in
-    tuf.settings.py).
-
-    None.
-    """
-
-    if total_downloaded == required_length:
-        logger.info("Downloaded %d bytes as expected.", total_downloaded)
-
-    else:
-        # What we downloaded is not equal to the required length, but did we ask
-        # for strict checking of required length?
-        if strict_required_length:
-            logger.info(
-                "Downloaded %d bytes, but expected %d bytes",
-                total_downloaded,
-                required_length,
-            )
-
-            # If the average download speed is below a certain threshold, we
-            # flag this as a possible slow-retrieval attack.
-            if average_download_speed < tuf.settings.MIN_AVERAGE_DOWNLOAD_SPEED:
-                raise exceptions.SlowRetrievalError(average_download_speed)
-
-            raise exceptions.DownloadLengthMismatchError(
-                required_length, total_downloaded
-            )
-
-        # We specifically disabled strict checking of required length, but
-        # we will log a warning anyway. This is useful when we wish to
-        # download the Timestamp or Root metadata, for which we have no
-        # signed metadata; so, we must guess a reasonable required_length
-        # for it.
-        if average_download_speed < tuf.settings.MIN_AVERAGE_DOWNLOAD_SPEED:
-            raise exceptions.SlowRetrievalError(average_download_speed)
-
-        logger.debug(
-            "Good average download speed: %f bytes per second",
-            average_download_speed,
-        )
-
-        logger.info(
-            "Downloaded %d bytes out of upper limit of %d bytes.",
-            total_downloaded,
-            required_length,
-        )
diff --git a/tuf/ngclient/_internal/requests_fetcher.py b/tuf/ngclient/_internal/requests_fetcher.py
index 216153b1f9..a26231c5bb 100644
--- a/tuf/ngclient/_internal/requests_fetcher.py
+++ b/tuf/ngclient/_internal/requests_fetcher.py
@@ -7,7 +7,7 @@
 import logging
 import time
-from typing import Optional
+from typing import Iterator, Optional
 from urllib import parse
 
 # Imports
@@ -21,7 +21,7 @@
 # Globals
 logger = logging.getLogger(__name__)
 
-# Classess
+# Classes
 class RequestsFetcher(FetcherInterface):
     """A concrete implementation of FetcherInterface based on the Requests
     library.
@@ -53,15 +53,15 @@ def __init__(self):
         self.chunk_size: int = 400000  # bytes
         self.sleep_before_round: Optional[int] = None
 
-    def fetch(self, url, required_length):
+    def fetch(self, url: str, max_length: int) -> Iterator[bytes]:
         """Fetches the contents of HTTP/HTTPS url from a remote server.
 
-        Ensures the length of the downloaded data is up to 'required_length'.
+        Ensures the length of the downloaded data is up to 'max_length'.
 
         Arguments:
             url: A URL string that represents a file location.
-            required_length: An integer value representing the file length in
-                bytes.
+            max_length: An integer value representing the maximum
+                number of bytes to be downloaded.
 
         Raises:
             exceptions.SlowRetrievalError: A timeout occurs while receiving
@@ -90,58 +90,61 @@ def fetch(self, url, required_length):
             status = e.response.status_code
             raise exceptions.FetcherHTTPError(str(e), status)
 
-        # Define a generator function to be returned by fetch. This way the
-        # caller of fetch can differentiate between connection and actual data
-        # download and measure download times accordingly.
-        def chunks():
-            try:
-                bytes_received = 0
-                while True:
-                    # We download a fixed chunk of data in every round. This is
-                    # so that we can defend against slow retrieval attacks.
-                    # Furthermore, we do not wish to download an extremely
-                    # large file in one shot. Before beginning the round, sleep
-                    # (if set) for a short amount of time so that the CPU is not
-                    # hogged in the while loop.
-                    if self.sleep_before_round:
-                        time.sleep(self.sleep_before_round)
-
-                    read_amount = min(
-                        self.chunk_size,
-                        required_length - bytes_received,
-                    )
-
-                    # NOTE: This may not handle some servers adding a
-                    # Content-Encoding header, which may cause urllib3 to
-                    # misbehave:
-                    # https://github.com/pypa/pip/blob/404838abcca467648180b358598c597b74d568c9/src/pip/_internal/download.py#L547-L582
-                    data = response.raw.read(read_amount)
-                    bytes_received += len(data)
-
-                    # We might have no more data to read. Check number of bytes
-                    # downloaded.
-                    if not data:
-                        logger.debug(
-                            "Downloaded %d out of %d bytes",
-                            bytes_received,
-                            required_length,
-                        )
-
-                        # Finally, we signal that the download is complete.
-                        break
-
-                    yield data
-
-                    if bytes_received >= required_length:
-                        break
-
-            except urllib3.exceptions.ReadTimeoutError as e:
-                raise exceptions.SlowRetrievalError(str(e))
-
-            finally:
-                response.close()
-
-        return chunks()
+        return self._chunks(response, max_length)
+
+    def _chunks(
+        self, response: "requests.Response", max_length: int
+    ) -> Iterator[bytes]:
+        """A generator function to be returned by fetch. This way the
+        caller of fetch can differentiate between connection and actual data
+        download."""
+
+        try:
+            bytes_received = 0
+            while True:
+                # We download a fixed chunk of data in every round. This is
+                # so that we can defend against slow retrieval attacks.
+                # Furthermore, we do not wish to download an extremely
+                # large file in one shot. Before beginning the round, sleep
+                # (if set) for a short amount of time so that the CPU is not
+                # hogged in the while loop.
+                if self.sleep_before_round:
+                    time.sleep(self.sleep_before_round)
+
+                read_amount = min(
+                    self.chunk_size,
+                    max_length - bytes_received,
+                )
+
+                # NOTE: This may not handle some servers adding a
+                # Content-Encoding header, which may cause urllib3 to
+                # misbehave:
+                # https://github.com/pypa/pip/blob/404838abcca467648180b358598c597b74d568c9/src/pip/_internal/download.py#L547-L582
+                data = response.raw.read(read_amount)
+                bytes_received += len(data)
+
+                # We might have no more data to read. Check number of bytes
+                # downloaded.
+                if not data:
+                    # Finally, we signal that the download is complete.
+                    break
+
+                yield data
+
+                if bytes_received >= max_length:
+                    break
+
+            logger.debug(
+                "Downloaded %d out of %d bytes",
+                bytes_received,
+                max_length,
+            )
+
+        except urllib3.exceptions.ReadTimeoutError as e:
+            raise exceptions.SlowRetrievalError(str(e))
+
+        finally:
+            response.close()
 
     def _get_session(self, url):
         """Returns a different customized requests.Session per schema+hostname
@@ -157,10 +160,6 @@ def _get_session(self, url):
         )
 
         session_index = parsed_url.scheme + "+" + parsed_url.hostname
-
-        logger.debug("url: %s", url)
-        logger.debug("session index: %s", session_index)
-
         session = self._sessions.get(session_index)
 
         if not session:
diff --git a/tuf/ngclient/fetcher.py b/tuf/ngclient/fetcher.py
index 8a6cae34d7..89d5a98473 100644
--- a/tuf/ngclient/fetcher.py
+++ b/tuf/ngclient/fetcher.py
@@ -6,6 +6,15 @@
 
 # Imports
 import abc
+import logging
+import tempfile
+from contextlib import contextmanager
+from typing import IO, Iterator
+from urllib import parse
+
+from tuf import exceptions
+
+logger = logging.getLogger(__name__)
 
 
 # Classes
@@ -20,15 +29,15 @@ class FetcherInterface:
     __metaclass__ = abc.ABCMeta
 
     @abc.abstractmethod
-    def fetch(self, url, required_length):
+    def fetch(self, url: str, max_length: int) -> Iterator[bytes]:
         """Fetches the contents of HTTP/HTTPS url from a remote server.
 
-        Ensures the length of the downloaded data is up to 'required_length'.
+        Ensures the length of the downloaded data is up to 'max_length'.
 
         Arguments:
             url: A URL string that represents a file location.
-            required_length: An integer value representing the file length in
-                bytes.
+            max_length: An integer value representing the maximum
+                number of bytes to be downloaded.
 
         Raises:
             tuf.exceptions.SlowRetrievalError: A timeout occurs while receiving
@@ -39,3 +48,50 @@ def fetch(self, url, required_length):
             A bytes iterator
         """
         raise NotImplementedError  # pragma: no cover
+
+    @contextmanager
+    def download_file(self, url: str, max_length: int) -> Iterator[IO]:
+        """Opens a connection to 'url' and downloads the content
+        up to 'max_length'.
+
+        Args:
+            url: a URL string that represents the location of the file.
+            max_length: an integer value representing the length of
+                the file or an upper bound.
+
+        Raises:
+            DownloadLengthMismatchError: downloaded bytes exceed 'max_length'.
+
+        Yields:
+            A TemporaryFile object that points to the contents of 'url'.
+        """
+        # 'url.replace('\\', '/')' is needed for compatibility with
+        # Windows-based systems, because they might use back-slashes in place
+        # of forward-slashes. This converts it to the common format.
+        # unquote() replaces %xx escapes in a url with their single-character
+        # equivalent. A back-slash may be encoded as %5c in the url, which
+        # should also be replaced with a forward slash.
+        url = parse.unquote(url).replace("\\", "/")
+        logger.debug("Downloading: %s", url)
+
+        number_of_bytes_received = 0
+
+        with tempfile.TemporaryFile() as temp_file:
+            chunks = self.fetch(url, max_length)
+            for chunk in chunks:
+                temp_file.write(chunk)
+                number_of_bytes_received += len(chunk)
+                if number_of_bytes_received > max_length:
+                    raise exceptions.DownloadLengthMismatchError(
+                        max_length, number_of_bytes_received
+                    )
+            temp_file.seek(0)
+            yield temp_file
+
+    def download_bytes(self, url: str, max_length: int) -> bytes:
+        """Download bytes from given url
+
+        Returns the downloaded bytes, otherwise like download_file()
+        """
+        with self.download_file(url, max_length) as dl_file:
+            return dl_file.read()
diff --git a/tuf/ngclient/updater.py b/tuf/ngclient/updater.py
index b054abdaf3..de3af53cff 100644
--- a/tuf/ngclient/updater.py
+++ b/tuf/ngclient/updater.py
@@ -14,11 +14,7 @@
 from securesystemslib import util as sslib_util
 
 from tuf import exceptions
-from tuf.ngclient._internal import (
-    download,
-    requests_fetcher,
-    trusted_metadata_set,
-)
+from tuf.ngclient._internal import requests_fetcher, trusted_metadata_set
 from tuf.ngclient.config import UpdaterConfig
 from tuf.ngclient.fetcher import FetcherInterface
 
@@ -27,7 +23,7 @@
 class Updater:
     """
-    An implemetation of the TUF client workflow.
+    An implementation of the TUF client workflow.
 
     Provides a public API for integration in client applications.
     """
 
@@ -193,8 +189,8 @@ def download_target(
         target_fileinfo: "TargetFile" = targetinfo["fileinfo"]
         full_url = parse.urljoin(target_base_url, target_filepath)
 
-        with download.download_file(
-            full_url, target_fileinfo.length, self._fetcher
+        with self._fetcher.download_file(
+            full_url, target_fileinfo.length
         ) as target_file:
             try:
                 target_fileinfo.verify_length_and_hashes(target_file)
@@ -215,12 +211,7 @@ def _download_metadata(
         else:
            filename = f"{version}.{rolename}.json"
         url = parse.urljoin(self._metadata_base_url, filename)
-        return download.download_bytes(
-            url,
-            length,
-            self._fetcher,
-            strict_required_length=False,
-        )
+        return self._fetcher.download_bytes(url, length)
 
     def _load_local_metadata(self, rolename: str) -> bytes:
         with open(os.path.join(self._dir, f"{rolename}.json"), "rb") as f:
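
A minimal usage sketch of the relocated helpers, assuming the patched tuf.ngclient.fetcher shown above is installed. StaticFetcher, its canned file table, and the example URL are hypothetical illustrations only; download_file() and download_bytes() are the concrete FetcherInterface methods added by this diff.

from typing import Dict, Iterator

from tuf.ngclient.fetcher import FetcherInterface


class StaticFetcher(FetcherInterface):
    """Toy fetcher that serves canned bytes instead of doing real HTTP I/O."""

    def __init__(self, files: Dict[str, bytes]) -> None:
        self._files = files

    def fetch(self, url: str, max_length: int) -> Iterator[bytes]:
        # Yield the canned content in small chunks; the inherited
        # download_file()/download_bytes() enforce the max_length bound.
        data = self._files[url]
        for i in range(0, len(data), 4):
            yield data[i : i + 4]


fetcher = StaticFetcher({"https://example.com/demo.txt": b"hello world"})

# download_bytes() is a concrete method inherited from FetcherInterface.
content = fetcher.download_bytes("https://example.com/demo.txt", 100)
assert content == b"hello world"

# download_file() is a context manager yielding a temporary file object.
with fetcher.download_file("https://example.com/demo.txt", 100) as dl_file:
    print(dl_file.read())

Because the helpers now live on FetcherInterface itself, any fetcher implementation (including RequestsFetcher) inherits the upper-bound length check performed by download_file(), while exact length and hash verification stays with the metadata API (verify_length_and_hashes() in the updater).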