diff --git a/requirements.txt b/requirements.txt
index 63e9e8888..976055383 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,6 +1,7 @@
 agent==0.1.2
 aiohttp==0.18.2
 git+https://github.com/felliott/boto.git@feature/gen-url-query-params-6#egg=boto
+boto3==1.7.36
 celery==3.1.17
 furl==0.4.2
 google-auth==1.4.1
@@ -15,7 +16,7 @@ raven==5.27.0
 setuptools==37.0.0
 stevedore==1.2.0
 tornado==4.3
-xmltodict==0.9.0
+xmltodict==0.11.0
 
 # Issue: certifi-2015.9.6.1 and 2015.9.6.2 fail verification (https://github.com/certifi/python-certifi/issues/26)
 certifi==2015.4.28
diff --git a/waterbutler/auth/osf/handler.py b/waterbutler/auth/osf/handler.py
index 6f976976f..174e230cb 100644
--- a/waterbutler/auth/osf/handler.py
+++ b/waterbutler/auth/osf/handler.py
@@ -40,6 +40,7 @@ def build_payload(self, bundle, view_only=None, cookie=None):
         return query_params
 
     async def make_request(self, params, headers, cookies):
+        data = None
         try:
             response = await aiohttp.request(
                 'get',
diff --git a/waterbutler/core/streams/http.py b/waterbutler/core/streams/http.py
index fb2955344..957e9525c 100644
--- a/waterbutler/core/streams/http.py
+++ b/waterbutler/core/streams/http.py
@@ -179,6 +179,7 @@ def __init__(self, request, inner):
         super().__init__()
         self.inner = inner
         self.request = request
+        self.offset = 0
 
     @property
     def size(self):
@@ -187,12 +188,16 @@ def size(self):
     def at_eof(self):
         return self.inner.at_eof()
 
+    def tell(self):
+        return self.offset
+
     async def _read(self, size):
         if self.inner.at_eof():
             return b''
         if size < 0:
-            return (await self.inner.read(size))
+            data = await self.inner.read(size)
+            self.offset += len(data)
+            return data
         try:
-            return (await self.inner.readexactly(size))
+            data = await self.inner.readexactly(size)
         except asyncio.IncompleteReadError as e:
-            return e.partial
+            data = e.partial
+        # Advance by the number of bytes actually read, not the number
+        # requested, so tell() stays accurate on a short final read.
+        self.offset += len(data)
+        return data
diff --git a/waterbutler/core/utils.py b/waterbutler/core/utils.py
index 2e953418e..8be49cbab 100644
--- a/waterbutler/core/utils.py
+++ b/waterbutler/core/utils.py
@@ -105,15 +105,17 @@ async def send_signed_request(method, url, payload):
     ))
 
 
-def normalize_datetime(date_string):
-    if date_string is None:
+def normalize_datetime(date):
+    if date is None:
         return None
-    parsed_datetime = dateutil.parser.parse(date_string)
-    if not parsed_datetime.tzinfo:
-        parsed_datetime = parsed_datetime.replace(tzinfo=pytz.UTC)
-    parsed_datetime = parsed_datetime.astimezone(tz=pytz.UTC)
-    parsed_datetime = parsed_datetime.replace(microsecond=0)
-    return parsed_datetime.isoformat()
+    if isinstance(date, str):
+        date = dateutil.parser.parse(date)
+    if not date.tzinfo:
+        date = date.replace(tzinfo=pytz.UTC)
+    date = date.astimezone(tz=pytz.UTC)
+    date = date.replace(microsecond=0)
+    return date.isoformat()
 
 
 class ZipStreamGenerator:
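
For reference, the reworked `normalize_datetime` now accepts either a string or a `datetime`; both normalize to a second-precision UTC ISO 8601 string. A quick illustration (sample values invented):

    import datetime

    normalize_datetime('2018-06-01T12:34:56.789+02:00')
    # -> '2018-06-01T10:34:56+00:00'  (parsed, converted to UTC, microseconds dropped)

    normalize_datetime(datetime.datetime(2018, 6, 1, 12, 34, 56, 789))
    # -> '2018-06-01T12:34:56+00:00'  (naive datetimes are assumed to already be UTC)
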
diff --git a/waterbutler/providers/s3/metadata.py b/waterbutler/providers/s3/metadata.py
index bcdbdea4a..9e9acde45 100644
--- a/waterbutler/providers/s3/metadata.py
+++ b/waterbutler/providers/s3/metadata.py
@@ -1,6 +1,7 @@
 import os
 
 from waterbutler.core import metadata
+from waterbutler.core import utils
 
 
 class S3Metadata(metadata.BaseMetadata):
@@ -16,11 +17,13 @@ def name(self):
 
 class S3FileMetadataHeaders(S3Metadata, metadata.BaseFileMetadata):
 
-    def __init__(self, path, headers):
+    def __init__(self, path, s3_object=None, headers=None):
         self._path = path
-        # Cast to dict to clone as the headers will
-        # be destroyed when the request leaves scope
-        super().__init__(dict(headers))
+        self.obj = s3_object
+        self._etag = None
+        super().__init__(headers)
 
     @property
     def path(self):
@@ -44,16 +47,17 @@ def created_utc(self):
 
     @property
     def etag(self):
-        return self.raw['ETAG'].replace('"', '')
+        if self._etag is None:
+            self._etag = self.raw['ETAG'].replace('"', '')
+        return self._etag
 
     @property
     def extra(self):
-        md5 = self.raw['ETAG'].replace('"', '')
         return {
-            'md5': md5,
+            'md5': self.etag,
             'encryption': self.raw.get('X-AMZ-SERVER-SIDE-ENCRYPTION', ''),
             'hashes': {
-                'md5': md5,
+                'md5': self.etag,
             },
         }
 
@@ -132,7 +136,7 @@ def version(self):
 
     @property
     def modified(self):
-        return self.raw['LastModified']
+        return utils.normalize_datetime(self.raw['LastModified'])
 
     @property
     def extra(self):
diff --git a/waterbutler/providers/s3/provider.py b/waterbutler/providers/s3/provider.py
index b09c7d07a..42157a086 100644
--- a/waterbutler/providers/s3/provider.py
+++ b/waterbutler/providers/s3/provider.py
@@ -1,12 +1,16 @@
+import asyncio
 import os
+import itertools
 import hashlib
 import functools
 from urllib import parse
+import logging
 
 import xmltodict
 
 import xml.sax.saxutils
 
+#import aioboto3
 from boto import config as boto_config
 from boto.compat import BytesIO  # type: ignore
 from boto.utils import compute_md5
@@ -14,18 +18,100 @@
 from boto.s3.connection import S3Connection
 from boto.s3.connection import OrdinaryCallingFormat
 
+
+import boto3
+from botocore.awsrequest import prepare_request_dict
+from botocore.client import Config
+from botocore.exceptions import (
+    ClientError,
+    UnknownClientMethodError
+)
+from botocore.signers import _should_use_global_endpoint
+
+
 from waterbutler.core import streams
 from waterbutler.core import provider
 from waterbutler.core import exceptions
 from waterbutler.core.path import WaterButlerPath
 
 from waterbutler.providers.s3 import settings
+from waterbutler.providers.s3.streams import S3ResponseBodyStream
 from waterbutler.providers.s3.metadata import S3Revision
 from waterbutler.providers.s3.metadata import S3FileMetadata
 from waterbutler.providers.s3.metadata import S3FolderMetadata
 from waterbutler.providers.s3.metadata import S3FolderKeyMetadata
 from waterbutler.providers.s3.metadata import S3FileMetadataHeaders
 
+logger = logging.getLogger(__name__)
+
+
+def generate_presigned_url(self, ClientMethod, Params=None, Headers=None, ExpiresIn=3600,
+                           HttpMethod=None):
+    """Generate a presigned url given a client, its method, and arguments.
+
+    :type ClientMethod: string
+    :param ClientMethod: The client method to presign for
+
+    :type Params: dict
+    :param Params: The parameters normally passed to ``ClientMethod``.
+
+    :type Headers: dict
+    :param Headers: Extra headers to sign into the url.  This is the only
+        addition over botocore's stock implementation.
+
+    :type ExpiresIn: int
+    :param ExpiresIn: The number of seconds the presigned url is valid
+        for. By default it expires in an hour (3600 seconds)
+
+    :type HttpMethod: string
+    :param HttpMethod: The http method to use on the generated url. By
+        default, the http method is whatever is used in the method's model.
+
+    :returns: The presigned url
+    """
+    client_method = ClientMethod
+    params = Params
+    if params is None:
+        params = {}
+    headers = Headers
+    expires_in = ExpiresIn
+    http_method = HttpMethod
+    context = {
+        'is_presign_request': True,
+        'use_global_endpoint': _should_use_global_endpoint(self),
+    }
+
+    request_signer = self._request_signer
+    serializer = self._serializer
+
+    try:
+        operation_name = self._PY_TO_OP_NAME[client_method]
+    except KeyError:
+        raise UnknownClientMethodError(method_name=client_method)
+
+    operation_model = self.meta.service_model.operation_model(
+        operation_name)
+
+    params = self._emit_api_params(params, operation_model, context)
+
+    # Create a request dict based on the params to serialize.
+    request_dict = serializer.serialize_to_request(
+        params, operation_model)
+
+    logger.debug(headers)
+    logger.debug(request_dict)
+
+    # Switch out the http method if user specified it.
+    if http_method is not None:
+        request_dict['method'] = http_method
+
+    # Merge the caller-supplied headers into the request dict so that they
+    # become part of the signature.
+    if headers is not None:
+        request_dict['headers'].update(headers)
+
+    # Prepare the request dict by including the client's endpoint url.
+    prepare_request_dict(
+        request_dict, endpoint_url=self.meta.endpoint_url, context=context)
+
+    # Generate the presigned url.
+    return request_signer.generate_presigned_url(
+        request_dict=request_dict, expires_in=expires_in,
+        operation_name=operation_name)
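
The point of this vendored copy of botocore's `generate_presigned_url` is the extra `Headers` argument, which lets WaterButler sign arbitrary headers into the url. A hypothetical call against a client patched this way (bucket and key names invented):

    url = client.generate_presigned_url(
        'put_object',
        Params={'Bucket': 'my-bucket', 'Key': 'folder/file.txt'},
        Headers={'x-amz-server-side-encryption': 'AES256'},  # signed, so the PUT must send it
        ExpiresIn=3600,
        HttpMethod='PUT',
    )
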
 
 
 class S3Provider(provider.BaseProvider):
     """Provider for Amazon's S3 cloud storage service.
@@ -43,9 +129,12 @@ class S3Provider(provider.BaseProvider):
     * A GET prefix query against a non-existent path returns 200
     """
     NAME = 's3'
+    _s3 = None
+    _client = None
+    _region = None
 
     def __init__(self, auth, credentials, settings):
-        """
+        """Initialize S3Provider
 
         .. note::
 
             Neither `S3Connection#__init__` nor `S3Connection#get_bucket`
@@ -57,48 +146,77 @@ def __init__(self, auth, credentials, settings):
         """
         super().__init__(auth, credentials, settings)
 
-        self.connection = S3Connection(credentials['access_key'],
-                credentials['secret_key'], calling_format=OrdinaryCallingFormat())
-        self.bucket = self.connection.get_bucket(settings['bucket'], validate=False)
+        self.credentials = credentials
+        self.bucket_name = settings['bucket']
         self.encrypt_uploads = self.settings.get('encrypt_uploads', False)
-        self.region = None
+
+    # TODO Move client creation to `__aenter__`
+    @property
+    async def client(self):
+        if self._client is None:
+            # In order to make a client that we can use on any region, we need
+            # to supply the client with a string of the region name. First we
+            # make a temporary client in order to get that string. We put the
+            # client creation inside a lambda so it's easier to call twice.
+            # This must be a lambda and *not* a partial, because we want the
+            # expression reevaluated each time.
+            _make_client = lambda: boto3.client(
+                's3',
+                region_name=self._region,
+                aws_access_key_id=self.credentials['access_key'],
+                aws_secret_access_key=self.credentials['secret_key'],
+                endpoint_url='http{}://{}:{}'.format(
+                    's' if self.credentials['encrypted'] else '',
+                    self.credentials['host'],
+                    self.credentials['port']
+                ) if self.credentials['host'] != 's3.amazonaws.com' else None
+            )
+            self._region = _make_client().get_bucket_location(
+                Bucket=self.bucket_name
+            ).get('LocationConstraint', None)
+            # Remake client after getting bucket location
+            self._client = _make_client()
+            # Put the patched version of the url signer on the client.
+            self._client.__class__.generate_presigned_url = generate_presigned_url
+        return self._client
+
+    @property
+    def s3(self):
+        if self._s3 is None:
+            self._s3 = boto3.resource('s3')
+        return self._s3
+
+    @property
+    async def region(self):
+        # Awaiting self.client ensures the region is set properly; if we have a
+        # client set on our provider, we know the region is correct because we
+        # need the region in order to make the client.
+        await self.client
+        return self._region
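
Since `client` and `region` are async properties, call sites must await the attribute access itself. A minimal sketch of the expected flow (`provider` is a hypothetical S3Provider instance):

    client = await provider.client   # first access resolves the region, then builds the real client
    region = await provider.region   # subsequent accesses hit the cached client; no extra request
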
 
     async def validate_v1_path(self, path, **kwargs):
-        await self._check_region()
+        """Validate a WaterButler v1 path, confirming it exists in the bucket.
+        """
+        wb_path = WaterButlerPath(path)
 
         if path == '/':
-            return WaterButlerPath(path)
+            return wb_path
 
         implicit_folder = path.endswith('/')
 
         if implicit_folder:
-            params = {'prefix': path, 'delimiter': '/'}
-            resp = await self.make_request(
-                'GET',
-                functools.partial(self.bucket.generate_url, settings.TEMP_URL_SECS, 'GET', query_parameters=params),
-                params=params,
-                expects=(200, 404),
-                throws=exceptions.MetadataError,
-            )
+            await self._metadata_folder(wb_path.path)
         else:
-            resp = await self.make_request(
-                'HEAD',
-                functools.partial(self.bucket.new_key(path).generate_url, settings.TEMP_URL_SECS, 'HEAD'),
-                expects=(200, 404),
-                throws=exceptions.MetadataError,
-            )
-
-        await resp.release()
-
-        if resp.status == 404:
-            raise exceptions.NotFoundError(str(path))
 
-        return WaterButlerPath(path)
+        return wb_path
 
+    # Do we call this anywhere, and why can't we just use the constructor?
     async def validate_path(self, path, **kwargs):
         return WaterButlerPath(path)
 
     def can_duplicate_names(self):
+        # TODO This should be a class attribute
         return True
 
     def can_intra_copy(self, dest_provider, path=None):
@@ -111,28 +229,34 @@ async def intra_copy(self, dest_provider, source_path, dest_path):
         """Copy key from one S3 bucket to another. The credentials specified in
         `dest_provider` must have read access to `source.bucket`.
         """
-        await self._check_region()
 
         exists = await dest_provider.exists(dest_path)
-        dest_key = dest_provider.bucket.new_key(dest_path.path)
+        # TODO move this to `__aenter__`
+        client = await self.client
 
         # ensure no left slash when joining paths
         source_path = '/' + os.path.join(self.settings['bucket'], source_path.path)
         headers = {'x-amz-copy-source': parse.quote(source_path)}
-        url = functools.partial(
-            dest_key.generate_url,
-            settings.TEMP_URL_SECS,
-            'PUT',
-            headers=headers,
+
+        sign_url = lambda: client.generate_presigned_url(
+            'copy_object',
+            Params={
+                'Bucket': self.bucket_name,
+                'CopySource': source_path,
+                'Key': dest_path.path
+            },
+            ExpiresIn=settings.TEMP_URL_SECS,
         )
-        resp = await self.make_request(
-            'PUT', url,
+        response = await self.make_request(
+            'PUT',
+            sign_url,
+            expects={200},
             skip_auto_headers={'CONTENT-TYPE'},
             headers=headers,
-            expects=(200, ),
-            throws=exceptions.IntraCopyError,
+            throws=exceptions.IntraCopyError
         )
-        await resp.release()
+        await response.release()
+
         return (await dest_provider.metadata(dest_path)), not exists
 
     async def download(self, path, accept_url=False, revision=None, range=None, **kwargs):
         """Returns a ResponseStreamReader (Stream) for the specified path
@@ -144,40 +268,43 @@ async def download(self, path, accept_url=False, revision=None, range=None, **kw
         :rtype: :class:`waterbutler.core.streams.ResponseStreamReader`
         :raises: :class:`waterbutler.core.exceptions.DownloadError`
         """
-        await self._check_region()
+        get_kwargs = {}
 
         if not path.is_file:
             raise exceptions.DownloadError('No file specified for download', code=400)
 
-        if not revision or revision.lower() == 'latest':
-            query_parameters = None
-        else:
-            query_parameters = {'versionId': revision}
+        if range:
+            # WaterButler hands us range as a (start, end) tuple; boto3 wants
+            # the header-style string.  make_request(range=range) below also
+            # sets the same Range header on the actual request.
+            get_kwargs['Range'] = 'bytes={}-{}'.format(range[0], range[1])
 
-        if kwargs.get('displayName'):
-            response_headers = {'response-content-disposition': 'attachment; filename*=UTF-8\'\'{}'.format(parse.quote(kwargs['displayName']))}
-        else:
-            response_headers = {'response-content-disposition': 'attachment'}
+        # if kwargs.get('displayName'):
+        #     get_kwargs['ResponseContentDisposition'] = 'attachment; filename*=UTF-8\'\'{}'.format(parse.quote(kwargs['displayName']))
+        # else:
+        #     get_kwargs['ResponseContentDisposition'] = 'attachment'
 
-        url = functools.partial(
-            self.bucket.new_key(path.path).generate_url,
-            settings.TEMP_URL_SECS,
-            query_parameters=query_parameters,
-            response_headers=response_headers
-        )
+        if revision:
+            get_kwargs['VersionId'] = revision
 
-        if accept_url:
-            return url()
+        # TODO move this to `__aenter__`
+        client = await self.client
 
-        resp = await self.make_request(
+        sign_url = lambda: client.generate_presigned_url(
+            'get_object',
+            Params={
+                'Bucket': self.bucket_name,
+                'Key': path.path,
+                **get_kwargs
+            },
+            ExpiresIn=settings.TEMP_URL_SECS,
+            HttpMethod='GET'
+        )
+
+        response = await self.make_request(
             'GET',
-            url,
+            sign_url,
             range=range,
-            expects=(200, 206),
-            throws=exceptions.DownloadError,
+            expects={200, 206},
+            throws=exceptions.DownloadError
         )
-
-        return streams.ResponseStreamReader(resp)
+        return streams.ResponseStreamReader(response)
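
A sketch of a call site for the new presigned download path, assuming WaterButler's usual (start, end) range tuple:

    stream = await provider.download(path, range=(0, 1023))  # signs a GET url, sends Range: bytes=0-1023
    first_kib = await stream.read(1024)
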
 
     async def upload(self, stream, path, conflict='replace', **kwargs):
         """Uploads the given stream to S3
@@ -187,11 +314,11 @@ async def upload(self, stream, path, conflict='replace', **kwargs):
 
         :rtype: dict, bool
         """
-        await self._check_region()
-
         path, exists = await self.handle_name_conflict(path, conflict=conflict)
         stream.add_writer('md5', streams.HashStreamWriter(hashlib.md5))
 
+        # TODO move this to `__aenter__`
+        client = await self.client
         headers = {'Content-Length': str(stream.size)}
 
         # this is usually set in boto.s3.key.generate_url, but do it here
         # do be explicit about our header payloads for signing purposes
         if self.encrypt_uploads:
             headers['x-amz-server-side-encryption'] = 'AES256'
 
-        upload_url = functools.partial(
-            self.bucket.new_key(path.path).generate_url,
-            settings.TEMP_URL_SECS,
-            'PUT',
-            headers=headers,
+        sign_url = lambda: client.generate_presigned_url(
+            'put_object',
+            Params={
+                'Bucket': self.bucket_name,
+                'Key': path.path,
+                'ContentLength': stream.size,
+                **({'ServerSideEncryption': 'AES256'} if self.encrypt_uploads else {})
+            },
+            ExpiresIn=settings.TEMP_URL_SECS,
         )
-        resp = await self.make_request(
+
+        response = await self.make_request(
             'PUT',
-            upload_url,
+            sign_url,
             data=stream,
             skip_auto_headers={'CONTENT-TYPE'},
             headers=headers,
-            expects=(200, 201, ),
-            throws=exceptions.UploadError,
+            expects={200, 201},
+            throws=exceptions.UploadError
         )
+        await response.release()
 
+        # md5 is returned as the ETag header as long as server side encryption is not used.
-        if stream.writers['md5'].hexdigest != resp.headers['ETag'].replace('"', ''):
+        if stream.writers['md5'].hexdigest != response.headers['ETag'].replace('"', ''):
             raise exceptions.UploadChecksumMismatchError()
 
-        await resp.release()
         return (await self.metadata(path, **kwargs)), not exists
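
The checksum comparison relies on S3 echoing the object's MD5 as the quoted ETag header, which holds for single-part, unencrypted uploads. Illustrative values:

    # response.headers['ETag']         == '"9e107d9d372bb6826bd81d3542a419d6"'
    # stream.writers['md5'].hexdigest  == '9e107d9d372bb6826bd81d3542a419d6'
    assert response.headers['ETag'].replace('"', '') == stream.writers['md5'].hexdigest
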
 
     async def delete(self, path, confirm_delete=0, **kwargs):
@@ -227,8 +360,6 @@ async def delete(self, path, confirm_delete=0, **kwargs):
         """Deletes the key at the specified path
 
         :param str path: The path of the key to delete
         :param int confirm_delete: Must be 1 to confirm root folder delete
         """
-        await self._check_region()
-
         if path.is_root:
             if not confirm_delete == 1:
                 raise exceptions.DeleteError(
                     'confirm_delete=1 is required for deleting root provider folder',
                     code=400
                 )
@@ -237,23 +368,40 @@ async def delete(self, path, confirm_delete=0, **kwargs):
 
         if path.is_file:
-            resp = await self.make_request(
-                'DELETE',
-                self.bucket.new_key(path.path).generate_url(settings.TEMP_URL_SECS, 'DELETE'),
-                expects=(200, 204, ),
-                throws=exceptions.DeleteError,
-            )
-            await resp.release()
+            await self._delete_file(path, **kwargs)
         else:
             await self._delete_folder(path, **kwargs)
 
+    async def _delete_file(self, path, **kwargs):
+        """Deletes a single object located at a certain key.
+
+        Called from: func: delete if path.is_file
+        """
+        # TODO move this to `__aenter__`
+        client = await self.client
+        sign_url = lambda: client.generate_presigned_url(
+            'delete_object',
+            Params={
+                'Bucket': self.bucket_name,
+                'Key': path.path
+            },
+            ExpiresIn=settings.TEMP_URL_SECS,
+            HttpMethod='DELETE',
+        )
+        resp = await self.make_request(
+            'DELETE',
+            sign_url,
+            expects={200, 204},
+            throws=exceptions.DeleteError,
+        )
+        await resp.release()
+
     async def _delete_folder(self, path, **kwargs):
         """Query for recursive contents of folder and delete in batches of 1000
 
         Called from: func: delete if not path.is_file
 
-        Calls: func: self._check_region
-               func: self.make_request
-               func: self.bucket.generate_url
+        Calls: func: self.make_request
+               func: client.generate_presigned_url
 
@@ -264,81 +412,97 @@ async def _delete_folder(self, path, **kwargs):
         :param ProviderPath path: Path to be deleted
 
         On S3, folders are not first-class objects, but are instead inferred
         from the names of their children.  A regular DELETE request issued
         against a folder will not work unless that folder is completely empty.
         To fully delete an occupied folder, we must delete all of the comprising
         objects.  Amazon provides a bulk delete operation to simplify this.
         """
-        await self._check_region()
+        # TODO move this to `__aenter__`
+        client = await self.client
+
+        # Needs to be a lambda, *not* a partial, so that `marker` is
+        # reevaluated each time it's called. This is done so that we don't
+        # need to create a new callable object for each request we make.
+
+        # The weirdness with using ** on the Params is because boto3 is
+        # rather draconian about arguments; passing None for this param results
+        # in an error that the None object is not of type string.
+
+        # `marker` needs to be defined before the url signer so that it exists
+        # when the url signer is defined. It is used for pagination, to
+        # determine which page is returned by the request.
+        marker = None
+        sign_list_url = lambda: client.generate_presigned_url(
+            'list_objects_v2',
+            Params={
+                'Bucket': self.bucket_name,
+                'Prefix': path.path,
+                'Delimiter': '/',
+                # ListObjectsV2 paginates with continuation tokens, not the v1 'Marker'.
+                **({'ContinuationToken': marker} if marker is not None else {})
+            },
+            ExpiresIn=settings.TEMP_URL_SECS
+        )
 
-        while more_to_come:
-            if marker is not None:
-                query_params['marker'] = marker
+        objects_to_delete = []
+
+        delete_payload = b''
+        sign_delete_url = lambda: client.generate_presigned_url(
+            'delete_objects',
+            Params={
+                'Bucket': self.bucket_name,
+                'Delete': {
+                    'Objects': [{'Key': obj['Key']} for obj in objects_to_delete]
+                }
+            },
+            Headers={
+                'Content-Length': str(len(delete_payload)),
+                'Content-MD5': compute_md5(BytesIO(delete_payload))[1],
+                'Content-Type': 'text/xml'
+            }
+        )
 
-            resp = await self.make_request(
+        # S3 'truncates' responses that would list over 1000 objects. The
+        # response will contain a key, 'IsTruncated', if there were more than
+        # 1000 objects. Before the first request, we assume the list is
+        # truncated, so that at least one request will be made.
+        truncated = True
+        while truncated:
+            list_response = await self.make_request(
                 'GET',
-                self.bucket.generate_url(settings.TEMP_URL_SECS, 'GET', query_parameters=query_params),
-                params=query_params,
-                expects=(200, ),
+                sign_list_url,
+                expects={200, 204},
                 throws=exceptions.MetadataError,
             )
+            page = xmltodict.parse(
+                await list_response.read(),
+                strip_whitespace=False,
+                force_list={'Contents'}
+            )
+            marker = page['ListBucketResult'].get('NextContinuationToken', None)
+            truncated = page['ListBucketResult'].get('IsTruncated', 'false') != 'false'
 
-            contents = await resp.read()
-            parsed = xmltodict.parse(contents, strip_whitespace=False)['ListBucketResult']
-            more_to_come = parsed.get('IsTruncated') == 'true'
-            contents = parsed.get('Contents', [])
-
-            if isinstance(contents, dict):
-                contents = [contents]
-
-            content_keys.extend([content['Key'] for content in contents])
-            if len(content_keys) > 0:
-                marker = content_keys[-1]
-
-        # Query against non-existant folder does not return 404
-        if len(content_keys) == 0:
-            raise exceptions.NotFoundError(str(path))
-
-        while len(content_keys) > 0:
-            key_batch = content_keys[:1000]
-            del content_keys[:1000]
-
-            payload = '<?xml version="1.0" encoding="UTF-8"?>'
-            payload += '<Delete>'
-            payload += ''.join(map(
-                lambda x: '<Object><Key>{}</Key></Object>'.format(xml.sax.saxutils.escape(x)),
-                key_batch
-            ))
-            payload += '</Delete>'
-            payload = payload.encode('utf-8')
-            md5 = compute_md5(BytesIO(payload))
-
-            query_params = {'delete': ''}
-            headers = {
-                'Content-Length': str(len(payload)),
-                'Content-MD5': md5[1],
-                'Content-Type': 'text/xml',
-            }
+            objects_to_delete = page['ListBucketResult'].get('Contents', [])
 
-            # We depend on a customized version of boto that can make query parameters part of
-            # the signature.
-            url = functools.partial(
-                self.bucket.generate_url,
-                settings.TEMP_URL_SECS,
-                'POST',
-                query_parameters=query_params,
-                headers=headers,
-            )
-            resp = await self.make_request(
+            delete_payload = '<Delete>{}</Delete>'.format(
+                ''.join([
+                    '<Object><Key>{}</Key></Object>'.format(xml.sax.saxutils.escape(obj['Key']))
+                    for obj in objects_to_delete
+                ])
+            ).encode('utf-8')
+
+            # TODO Don't wait for the delete to finish before requesting the
+            # next batch, or sending that delete request.
+            delete_response = await self.make_request(
                 'POST',
-                url,
-                params=query_params,
-                data=payload,
-                headers=headers,
-                expects=(200, 204, ),
-                throws=exceptions.DeleteError,
+                sign_delete_url,
+                data=delete_payload,
+                headers={
+                    'Content-Length': str(len(delete_payload)),
+                    'Content-MD5': compute_md5(BytesIO(delete_payload))[1],
+                    'Content-Type': 'text/xml'
+                },
+                expects={200, 204},
+                throws=exceptions.DeleteError,
             )
-            await resp.release()
+            await delete_response.release()
+            del delete_response
+            del list_response
+
+        # TODO Put the delete requests in a list of tasks and wait for all of
+        # them to finish here, before returning
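
The two TODOs above describe firing the batched deletes concurrently instead of serially. A minimal sketch of that shape (the `_delete_batch` helper is hypothetical, not part of this patch; each batch needs its own signer and payload, since the lambdas above close over mutable state and would race if the requests were deferred):

    delete_tasks = []
    while truncated:
        # ... list one page as above, build objects_to_delete ...
        delete_tasks.append(asyncio.ensure_future(self._delete_batch(objects_to_delete)))
    for response in await asyncio.gather(*delete_tasks):
        await response.release()
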
 
@@ -346,22 +510,28 @@ async def revisions(self, path, **kwargs):
     async def revisions(self, path, **kwargs):
         """Get past versions of the requested key
 
         :param str path: The path to a key
         :rtype list:
         """
-        await self._check_region()
+        # TODO move this to `__aenter__`
+        client = await self.client
+        sign_url = lambda: client.generate_presigned_url(
+            'list_object_versions',
+            Params={
+                'Bucket': self.bucket_name,
+                'Delimiter': '/',
+                'Prefix': path.path
+            },
+            ExpiresIn=settings.TEMP_URL_SECS
+        )
 
-        query_params = {'prefix': path.path, 'delimiter': '/', 'versions': ''}
-        url = functools.partial(self.bucket.generate_url, settings.TEMP_URL_SECS, 'GET', query_parameters=query_params)
-        resp = await self.make_request(
-            'GET',
-            url,
-            params=query_params,
-            expects=(200, ),
-            throws=exceptions.MetadataError,
+        response = await self.make_request(
+            'GET',
+            sign_url,
+            expects={200},
+            throws=exceptions.MetadataError
         )
 
-        content = await resp.read()
-        versions = xmltodict.parse(content)['ListVersionsResult'].get('Version') or []
-
-        if isinstance(versions, dict):
-            versions = [versions]
+        versions = xmltodict.parse(
+            await response.read(),
+            force_list={'Version'}
+        )['ListVersionsResult'].get('Version', [])
 
         return [
             S3Revision(item)
             for item in versions
             if item['Key'] == path.path
         ]
 
@@ -369,148 +539,188 @@ async def revisions(self, path, **kwargs):
-    async def metadata(self, path, revision=None, **kwargs):
-        """Get Metadata about the requested file or folder
-
-        :param WaterButlerPath path: The path to a key or folder
-        :rtype: dict or list
-        """
-        await self._check_region()
-
-        if path.is_dir:
-            return (await self._metadata_folder(path))
-
-        return (await self._metadata_file(path, revision=revision))
-
     async def create_folder(self, path, folder_precheck=True, **kwargs):
-        """
+        """Create an empty object on the bucket that contains a trailing slash
+
         :param str path: The path to create a folder at
         """
-        await self._check_region()
-
+        # TODO move this to `__aenter__`
+        client = await self.client
         WaterButlerPath.validate_folder(path)
 
         if folder_precheck:
             if (await self.exists(path)):
                 raise exceptions.FolderNamingConflict(path.name)
 
+        sign_url = lambda: client.generate_presigned_url(
+            'put_object',
+            Params={
+                'Bucket': self.bucket_name,
+                'Key': path.path,
+            },
+            ExpiresIn=settings.TEMP_URL_SECS,
+        )
         async with self.request(
             'PUT',
-            functools.partial(self.bucket.new_key(path.path).generate_url, settings.TEMP_URL_SECS, 'PUT'),
+            sign_url,
             skip_auto_headers={'CONTENT-TYPE'},
-            expects=(200, 201),
-            throws=exceptions.CreateFolderError
+            expects={200, 201},
+            throws=exceptions.CreateFolderError,
         ):
             return S3FolderMetadata({'Prefix': path.path})
 
+    async def metadata(self, path, revision=None, **kwargs):
+        """Get Metadata about the requested file or folder
+
+        :param WaterButlerPath path: The path to a key or folder
+        :rtype: dict or list
+        """
+        if path.is_dir:
+            return (await self._metadata_folder(path.path))
+
+        # store a hash of these args and the result in redis?
+        return (await self._metadata_file(path.path, revision=revision))
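
A note on the `force_list` arguments showing up in these xmltodict calls: without it, a single element collapses to a dict rather than a one-element list, which is what the removed `isinstance(..., dict)` shims were guarding against (and presumably why requirements.txt bumps xmltodict, since `force_list` is a later addition). Illustratively:

    import xmltodict

    xmltodict.parse('<R><V>a</V></R>')['R']['V']                    # 'a'
    xmltodict.parse('<R><V>a</V></R>', force_list={'V'})['R']['V']  # ['a']
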
 
-    async def _metadata_file(self, path, revision=None):
-        await self._check_region()
-
-        if revision == 'Latest':
-            revision = None
-
-        resp = await self.make_request(
-            'HEAD',
-            functools.partial(
-                self.bucket.new_key(path.path).generate_url,
-                settings.TEMP_URL_SECS,
-                'HEAD',
-                query_parameters={'versionId': revision} if revision else None
-            ),
-            expects=(200, ),
-            throws=exceptions.MetadataError,
-        )
-        await resp.release()
-        return S3FileMetadataHeaders(path.path, resp.headers)
-
+    async def _metadata_file(self, path, revision=None):
+        """Load metadata for a single object in the bucket.
+        """
+        # TODO move this to `__aenter__`
+        client = await self.client
+
+        # Homogenise any weird version ids: 'Latest', the empty string, and
+        # None all mean the current version.
+        if revision in (None, '', 'Latest'):
+            revision = None
+
+        sign_url = lambda: client.generate_presigned_url(
+            'head_object',
+            Params={
+                'Bucket': self.bucket_name,
+                'Key': path,
+                **({'VersionId': revision} if revision is not None else {})
+            },
+            ExpiresIn=settings.TEMP_URL_SECS,
+        )
+        response = await self.make_request(
+            'HEAD',
+            sign_url,
+            expects={200, 204},
+            throws=exceptions.MetadataError,
+        )
+        await response.release()
+
+        return S3FileMetadataHeaders(
+            path,
+            headers=response.headers  # TODO Fix S3FileMetadataHeaders
+        )
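
Both the folder delete above and the folder listing below page through ListObjectsV2 results. The loop shape, reduced to plain boto3 calls (bucket and prefix invented):

    token = None
    while True:
        params = {'Bucket': 'my-bucket', 'Prefix': 'folder/', 'Delimiter': '/'}
        if token is not None:
            params['ContinuationToken'] = token
        page = client.list_objects_v2(**params)
        # ... consume page.get('Contents', []) and page.get('CommonPrefixes', []) ...
        if not page.get('IsTruncated'):
            break
        token = page['NextContinuationToken']
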
 
     async def _metadata_folder(self, path):
-        await self._check_region()
+        """Get metadata about the contents of a bucket. This is either the
+        contents at the root of the bucket, or of a folder that the user has
+        selected as a prefix.
+        """
+        # TODO move this to `__aenter__`
+        client = await self.client
 
-        params = {'prefix': path.path, 'delimiter': '/'}
-        resp = await self.make_request(
-            'GET',
-            functools.partial(self.bucket.generate_url, settings.TEMP_URL_SECS, 'GET', query_parameters=params),
-            params=params,
-            expects=(200, ),
-            throws=exceptions.MetadataError,
-        )
+        # Needs to be a lambda, *not* a partial, so that `marker` is
+        # reevaluated each time it's called. This is done so that we don't
+        # need to create a new callable object for each request we make.
 
-        contents = await resp.read()
+        # The weirdness with using ** on the Params is because boto3 is
+        # rather draconian about arguments; passing None for this param results
+        # in an error that the None object is not of type string.
 
-        parsed = xmltodict.parse(contents, strip_whitespace=False)['ListBucketResult']
+        # `marker` needs to be defined before the url signer so that it exists
+        # when the url signer is defined. It is used for pagination, to
+        # determine which page is returned by the request.
+        marker = None
+        sign_url = lambda: client.generate_presigned_url(
+            'list_objects_v2',
+            Params={
+                'Bucket': self.bucket_name,
+                'Prefix': path,
+                'Delimiter': '/',
+                **({'ContinuationToken': marker} if marker is not None else {})
+            },
+            ExpiresIn=settings.TEMP_URL_SECS
+        )
 
-        contents = parsed.get('Contents', [])
-        prefixes = parsed.get('CommonPrefixes', [])
+        # S3 'truncates' responses that would list over 1000 objects. The
+        # response will contain a key, 'IsTruncated', if there were more than
+        # 1000 objects. Before the first request, we assume the list is
+        # truncated, so that at least one request will be made.
+        truncated = True
+
+        # Each request will return 0 or more 'contents' and 'common prefixes'.
+        # Contents contains keys that begin with 'prefix' and contain no
+        # delimiter characters after the characters that match the prefix.
+        # Common prefixes match any keys that do contain a delimiter after the
+        # characters that match the prefix. Each request extends the `contents`
+        # and `prefixes` lists with the respective contents and prefixes that
+        # were returned in the request.
+        contents = []
+        prefixes = []
+
+        while truncated:
+            response = await self.make_request(
+                'GET',
+                sign_url,
+                expects={200, 204},
+                throws=exceptions.MetadataError,
+            )
+            page = xmltodict.parse(
+                await response.read(),
+                strip_whitespace=False,
+                force_list={'CommonPrefixes', 'Contents'}
+            )
+            prefixes.extend(page['ListBucketResult'].get('CommonPrefixes', []))
+            contents.extend(page['ListBucketResult'].get('Contents', []))
+
+            marker = page['ListBucketResult'].get('NextContinuationToken', None)
+            truncated = page['ListBucketResult'].get('IsTruncated', 'false') != 'false'
+            del response
+
+        del sign_url
+
+        items = []
+        # If there are keys that have the provided prefix...
+        if contents or prefixes:
+            # Prefixes represent 'folders'
+            items.extend([S3FolderMetadata(prefix) for prefix in prefixes])
+
+            for content in contents:
+                # Only care about items that are not the same as where the
+                # addon is mounted.
+                if content['Key'] != path:
+                    items.append(
+                        S3FolderKeyMetadata(content)
+                        if content['Key'].endswith('/')
+                        else S3FileMetadata(content)
+                    )
 
-        if not contents and not prefixes and not path.is_root:
-            # If contents and prefixes are empty then this "folder"
-            # must exist as a key with a / at the end of the name
-            # if the path is root there is no need to test if it exists
-            resp = await self.make_request(
+        # If contents and prefixes are empty, but this is not the root
+        # path, then this "folder" must exist as a key with a / at the
+        # end of the name.
+        elif path != '':
+            sign_url = lambda: client.generate_presigned_url(
+                'head_object',
+                Params={
+                    'Bucket': self.bucket_name,
+                    'Key': path,
+                },
+                ExpiresIn=settings.TEMP_URL_SECS
+            )
+            response = await self.make_request(
                 'HEAD',
-                functools.partial(self.bucket.new_key(path.path).generate_url, settings.TEMP_URL_SECS, 'HEAD'),
-                expects=(200, ),
+                sign_url,
+                expects={200, 204},
                 throws=exceptions.MetadataError,
             )
-            await resp.release()
-
-        if isinstance(contents, dict):
-            contents = [contents]
-
-        if isinstance(prefixes, dict):
-            prefixes = [prefixes]
-
-        items = [
-            S3FolderMetadata(item)
-            for item in prefixes
-        ]
-
-        for content in contents:
-            if content['Key'] == path.path:
-                continue
-
-            if content['Key'].endswith('/'):
-                items.append(S3FolderKeyMetadata(content))
-            else:
-                items.append(S3FileMetadata(content))
+            # The request was only made to confirm the key exists, so release
+            # the response before discarding it.
+            await response.release()
+            del sign_url
+            del response
 
         return items
 
-    async def _check_region(self):
-        """Lookup the region via bucket name, then update the host to match.
-
-        Manually constructing the connection hostname allows us to use OrdinaryCallingFormat
-        instead of SubdomainCallingFormat, which can break on buckets with periods in their name.
-        The default region, US East (N. Virginia), is represented by the empty string and does not
-        require changing the host.  Ireland is represented by the string 'EU', with the host
-        parameter 'eu-west-1'.  All other regions return the host parameter as the region name.
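
The removed region bookkeeping below is subsumed by the `client` property above, which calls `get_bucket_location` instead. Per the S3 API docs, the constraint is empty for us-east-1 and may come back as the legacy token 'EU' for older eu-west-1 buckets, so the responses look roughly like (values illustrative):

    client.get_bucket_location(Bucket='bucket-in-virginia')  # {'LocationConstraint': None}
    client.get_bucket_location(Bucket='bucket-in-ireland')   # {'LocationConstraint': 'EU'}
    # NB: the old _check_region translated 'EU' to 'eu-west-1' before using it;
    # the new client property passes the constraint to boto3 as-is.
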
-
-        Region Naming: http://docs.aws.amazon.com/general/latest/gr/rande.html#s3_region
-        """
-        if self.region is None:
-            self.region = await self._get_bucket_region()
-            if self.region == 'EU':
-                self.region = 'eu-west-1'
-
-            if self.region != '':
-                self.connection.host = self.connection.host.replace(
-                    's3.', 's3-' + self.region + '.', 1)
-                self.connection._auth_handler = get_auth_handler(
-                    self.connection.host, boto_config, self.connection.provider,
-                    self.connection._required_auth_capability())
-
-        self.metrics.add('region', self.region)
-
-    async def _get_bucket_region(self):
-        """Bucket names are unique across all regions.
-
-        Endpoint doc:
-        http://docs.aws.amazon.com/AmazonS3/latest/API/RESTBucketGETlocation.html
-        """
-        resp = await self.make_request(
-            'GET',
-            functools.partial(self.bucket.generate_url, settings.TEMP_URL_SECS, 'GET', query_parameters={'location': ''}),
-            expects=(200, ),
-            throws=exceptions.MetadataError,
-        )
-        contents = await resp.read()
-        parsed = xmltodict.parse(contents, strip_whitespace=False)
-        return parsed['LocationConstraint'].get('#text', '')
diff --git a/waterbutler/providers/s3/streams.py b/waterbutler/providers/s3/streams.py
new file mode 100644
index 000000000..5d403d4bd
--- /dev/null
+++ b/waterbutler/providers/s3/streams.py
@@ -0,0 +1,29 @@
+import asyncio
+
+from botocore.response import StreamingBody
+
+from waterbutler.core.streams.base import BaseStream
+
+
+class S3ResponseBodyStream(BaseStream):
+    def __init__(self, data):
+        super().__init__()
+
+        if not isinstance(data['Body'], StreamingBody):
+            raise TypeError('Data must be a StreamingBody, found {!r}'.format(type(data['Body'])))
+
+        self.content_type = data['ContentType']
+        self._size = data['ContentLength']
+        self.streaming_body = data['Body']
+
+    @property
+    def size(self):
+        return self._size
+
+    async def _read(self, n=None):
+        n = self._size if n is None else n
+
+        chunk = self.streaming_body.read(amt=n)
+        if not chunk:
+            self.feed_eof()
+        return chunk
diff --git a/waterbutler/server/api/v1/provider/__init__.py b/waterbutler/server/api/v1/provider/__init__.py
index b31c9e6d0..c85cbac6b 100644
--- a/waterbutler/server/api/v1/provider/__init__.py
+++ b/waterbutler/server/api/v1/provider/__init__.py
@@ -32,11 +32,17 @@ def list_or_value(value):
 
 @tornado.web.stream_request_body
 class ProviderHandler(core.BaseHandler, CreateMixin, MetadataMixin, MoveCopyMixin):
+    """Handler for provider operations.  Inherits the Create, Metadata, and
+    MoveCopy provider handler mixins.
+    """
     PRE_VALIDATORS = {'put': 'prevalidate_put', 'post': 'prevalidate_post'}
     POST_VALIDATORS = {'put': 'postvalidate_put'}
     PATTERN = r'/resources/(?P<resource>(?:\w|\d)+)/providers/(?P<provider>(?:\w|\d)+)(?P<path>/.*/?)'
 
     async def prepare(self, *args, **kwargs):
+        """Prepare to handle the request
+        """
         method = self.request.method.lower()
 
         # TODO Find a nicer way to handle this
@@ -174,6 +180,7 @@ def on_finish(self):
 
         self._send_hook(action)
 
+
     def _send_hook(self, action):
         source = None
         destination = None
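
`S3ResponseBodyStream` is not wired into the provider yet (the import in provider.py is currently unused); a hypothetical eventual use with a boto3 `get_object` response:

    data = client.get_object(Bucket='my-bucket', Key='folder/file.txt')
    stream = S3ResponseBodyStream(data)   # validates that data['Body'] is a StreamingBody
    chunk = await stream.read(8192)       # BaseStream.read delegates to _read above
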