diff --git a/requirements.txt b/requirements.txt
index 63e9e8888..976055383 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,6 +1,7 @@
agent==0.1.2
aiohttp==0.18.2
git+https://github.com/felliott/boto.git@feature/gen-url-query-params-6#egg=boto
+boto3==1.7.36
celery==3.1.17
furl==0.4.2
google-auth==1.4.1
@@ -15,7 +16,7 @@ raven==5.27.0
setuptools==37.0.0
stevedore==1.2.0
tornado==4.3
-xmltodict==0.9.0
+xmltodict==0.11.0
# Issue: certifi-2015.9.6.1 and 2015.9.6.2 fail verification (https://github.com/certifi/python-certifi/issues/26)
certifi==2015.4.28
diff --git a/waterbutler/auth/osf/handler.py b/waterbutler/auth/osf/handler.py
index 6f976976f..174e230cb 100644
--- a/waterbutler/auth/osf/handler.py
+++ b/waterbutler/auth/osf/handler.py
@@ -40,6 +40,7 @@ def build_payload(self, bundle, view_only=None, cookie=None):
return query_params
async def make_request(self, params, headers, cookies):
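+        # Define `data` up front so the exception handler below can
+        # reference it even if the request fails before a body is read.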
+ data = None
try:
response = await aiohttp.request(
'get',
diff --git a/waterbutler/core/streams/http.py b/waterbutler/core/streams/http.py
index fb2955344..957e9525c 100644
--- a/waterbutler/core/streams/http.py
+++ b/waterbutler/core/streams/http.py
@@ -179,6 +179,7 @@ def __init__(self, request, inner):
super().__init__()
self.inner = inner
self.request = request
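+        # Number of bytes consumed so far; reported by tell() below.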
+ self.offset = 0
@property
def size(self):
@@ -187,12 +188,16 @@ def size(self):
def at_eof(self):
return self.inner.at_eof()
+ def tell(self):
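+        """Return the number of bytes consumed so far, file-object style."""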
+ return self.offset
+
async def _read(self, size):
if self.inner.at_eof():
return b''
        if size < 0:
-            return (await self.inner.read(size))
+            data = await self.inner.read(size)
+            self.offset += len(data)
+            return data
        try:
-            return (await self.inner.readexactly(size))
+            data = await self.inner.readexactly(size)
        except asyncio.IncompleteReadError as e:
-            return e.partial
+            data = e.partial
+        # Count the bytes actually returned rather than the bytes requested,
+        # so tell() stays correct on a short (EOF-truncated) read.
+        self.offset += len(data)
+        return data
diff --git a/waterbutler/core/utils.py b/waterbutler/core/utils.py
index 2e953418e..8be49cbab 100644
--- a/waterbutler/core/utils.py
+++ b/waterbutler/core/utils.py
@@ -105,15 +105,17 @@ async def send_signed_request(method, url, payload):
))
-def normalize_datetime(date_string):
- if date_string is None:
+def normalize_datetime(date):
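+    """Return an ISO-8601 timestamp in UTC, with microseconds dropped.
+
+    ``date`` may be a datetime or any string ``dateutil`` can parse.
+    """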
+ if date is None:
return None
- parsed_datetime = dateutil.parser.parse(date_string)
- if not parsed_datetime.tzinfo:
- parsed_datetime = parsed_datetime.replace(tzinfo=pytz.UTC)
- parsed_datetime = parsed_datetime.astimezone(tz=pytz.UTC)
- parsed_datetime = parsed_datetime.replace(microsecond=0)
- return parsed_datetime.isoformat()
+ if isinstance(date, str):
+ date = dateutil.parser.parse(date)
+ if not date.tzinfo:
+ date = date.replace(tzinfo=pytz.UTC)
+ date = date.astimezone(tz=pytz.UTC)
+ date = date.replace(microsecond=0)
+ return date.isoformat()
+
class ZipStreamGenerator:
diff --git a/waterbutler/providers/s3/metadata.py b/waterbutler/providers/s3/metadata.py
index bcdbdea4a..9e9acde45 100644
--- a/waterbutler/providers/s3/metadata.py
+++ b/waterbutler/providers/s3/metadata.py
@@ -1,6 +1,7 @@
import os
from waterbutler.core import metadata
+from waterbutler.core import utils
class S3Metadata(metadata.BaseMetadata):
@@ -16,11 +17,13 @@ def name(self):
class S3FileMetadataHeaders(S3Metadata, metadata.BaseFileMetadata):
-    def __init__(self, path, headers):
+    def __init__(self, path, s3_object=None, headers=None):
        self._path = path
+        self.obj = s3_object
+        # Memoized by the `etag` property on first access.
+        self._etag = None
        # Cast to dict to clone as the headers will
        # be destroyed when the request leaves scope
-        super().__init__(dict(headers))
+        super().__init__(dict(headers) if headers is not None else s3_object)
@property
def path(self):
@@ -44,16 +47,17 @@ def created_utc(self):
@property
def etag(self):
- return self.raw['ETAG'].replace('"', '')
+ if self._etag is None:
+ self._etag = self.raw['ETAG'].replace('"', '')
+ return self._etag
@property
def extra(self):
- md5 = self.raw['ETAG'].replace('"', '')
return {
- 'md5': md5,
+ 'md5': self.etag,
'encryption': self.raw.get('X-AMZ-SERVER-SIDE-ENCRYPTION', ''),
'hashes': {
- 'md5': md5,
+ 'md5': self.etag,
},
}
@@ -132,7 +136,7 @@ def version(self):
@property
def modified(self):
- return self.raw['LastModified']
+ return utils.normalize_datetime(self.raw['LastModified'])
@property
def extra(self):
diff --git a/waterbutler/providers/s3/provider.py b/waterbutler/providers/s3/provider.py
index b09c7d07a..42157a086 100644
--- a/waterbutler/providers/s3/provider.py
+++ b/waterbutler/providers/s3/provider.py
@@ -1,12 +1,16 @@
+import asyncio
import os
+import itertools
import hashlib
import functools
from urllib import parse
+import logging
import xmltodict
import xml.sax.saxutils
+# import aioboto3
from boto import config as boto_config
from boto.compat import BytesIO # type: ignore
from boto.utils import compute_md5
@@ -14,18 +18,100 @@
from boto.s3.connection import S3Connection
from boto.s3.connection import OrdinaryCallingFormat
+
+import boto3
+from botocore.awsrequest import prepare_request_dict
+from botocore.client import Config
+from botocore.exceptions import (
+ ClientError,
+ UnknownClientMethodError
+)
+from botocore.signers import _should_use_global_endpoint
+
+
from waterbutler.core import streams
from waterbutler.core import provider
from waterbutler.core import exceptions
from waterbutler.core.path import WaterButlerPath
from waterbutler.providers.s3 import settings
+from waterbutler.providers.s3.streams import S3ResponseBodyStream
from waterbutler.providers.s3.metadata import S3Revision
from waterbutler.providers.s3.metadata import S3FileMetadata
from waterbutler.providers.s3.metadata import S3FolderMetadata
from waterbutler.providers.s3.metadata import S3FolderKeyMetadata
from waterbutler.providers.s3.metadata import S3FileMetadataHeaders
+logger = logging.getLogger(__name__)
+
+
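+# NOTE: adapted from botocore's `generate_presigned_url`, extended with a
+# `Headers` argument so extra headers (e.g. Content-MD5 on bulk deletes)
+# can be folded into the signature. It is patched onto the client below.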
+def generate_presigned_url(self, ClientMethod, Params=None, Headers=None, ExpiresIn=3600,
+ HttpMethod=None):
+ """Generate a presigned url given a client, its method, and arguments
+ :type ClientMethod: string
+ :param ClientMethod: The client method to presign for
+ :type Params: dict
+ :param Params: The parameters normally passed to
+        ``ClientMethod``.
+    :type Headers: dict
+    :param Headers: Additional headers to include when signing the
+        request, so they can be sent along with the presigned url.
+ :type ExpiresIn: int
+ :param ExpiresIn: The number of seconds the presigned url is valid
+ for. By default it expires in an hour (3600 seconds)
+ :type HttpMethod: string
+ :param HttpMethod: The http method to use on the generated url. By
+ default, the http method is whatever is used in the method's model.
+ :returns: The presigned url
+ """
+ client_method = ClientMethod
+ params = Params
+ if params is None:
+ params = {}
+    # `Headers` is the addition over stock botocore; everything else
+    # mirrors the upstream implementation.
+    headers = Headers
+    expires_in = ExpiresIn
+    http_method = HttpMethod
+ context = {
+ 'is_presign_request': True,
+ 'use_global_endpoint': _should_use_global_endpoint(self),
+ }
+
+ request_signer = self._request_signer
+ serializer = self._serializer
+
+ try:
+ operation_name = self._PY_TO_OP_NAME[client_method]
+ except KeyError:
+ raise UnknownClientMethodError(method_name=client_method)
+
+ operation_model = self.meta.service_model.operation_model(
+ operation_name)
+
+ params = self._emit_api_params(params, operation_model, context)
+
+ # Create a request dict based on the params to serialize.
+ request_dict = serializer.serialize_to_request(
+ params, operation_model)
+
+    logger.debug('presign extra headers: %r', headers)
+    logger.debug('presign request dict: %r', request_dict)
+
+ # Switch out the http method if user specified it.
+ if http_method is not None:
+ request_dict['method'] = http_method
+
+    # Merge caller-supplied headers into the request so they are covered
+    # by the signature.
+    if headers is not None:
+        request_dict['headers'].update(headers)
+ #
+ # Prepare the request dict by including the client's endpoint url.
+ prepare_request_dict(
+ request_dict, endpoint_url=self.meta.endpoint_url, context=context)
+
+ # Generate the presigned url.
+ return request_signer.generate_presigned_url(
+ request_dict=request_dict, expires_in=expires_in,
+ operation_name=operation_name)
+
class S3Provider(provider.BaseProvider):
"""Provider for Amazon's S3 cloud storage service.
@@ -43,9 +129,12 @@ class S3Provider(provider.BaseProvider):
* A GET prefix query against a non-existent path returns 200
"""
NAME = 's3'
+ _s3 = None
+ _client = None
+ _region = None
def __init__(self, auth, credentials, settings):
- """
+ """Initialize S3Provider
.. note::
Neither `S3Connection#__init__` nor `S3Connection#get_bucket`
@@ -57,48 +146,77 @@ def __init__(self, auth, credentials, settings):
"""
super().__init__(auth, credentials, settings)
- self.connection = S3Connection(credentials['access_key'],
- credentials['secret_key'], calling_format=OrdinaryCallingFormat())
- self.bucket = self.connection.get_bucket(settings['bucket'], validate=False)
+ self.credentials = credentials
+ self.bucket_name = settings['bucket']
self.encrypt_uploads = self.settings.get('encrypt_uploads', False)
- self.region = None
+
+    # TODO Move client creation to `__aenter__`
+ @property
+ async def client(self):
+ if self._client is None:
+ # In order to make a client that we can use on any region, we need
+ # to supply the client with a string of the region name. First we
+ # make a temporary client in order to get that string. We put the
+ # client creation inside a lmabda so its easier to call twice.
+ # This must be a lambda an *not* a partial, because we want the
+ # expression reevaluated each time.
+ _make_client = lambda: boto3.client(
+ 's3',
+ region_name=self._region,
+ aws_access_key_id=self.credentials['access_key'],
+ aws_secret_access_key=self.credentials['secret_key'],
+ endpoint_url='http{}://{}:{}'.format(
+ 's' if self.credentials['encrypted'] else '',
+ self.credentials['host'],
+ self.credentials['port']
+ ) if self.credentials['host'] != 's3.amazonaws.com' else None
+ )
+ self._region = _make_client().get_bucket_location(
+ Bucket=self.bucket_name
+ ).get('LocationConstraint', None)
+ # Remake client after getting bucket location
+ self._client = _make_client()
+ # Put the patched version of the url signer on the client.
+ self._client.__class__.generate_presigned_url = generate_presigned_url
+ return self._client
+
+ @property
+ def s3(self):
+ if self._s3 is None:
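+            # NOTE: unlike `client`, this resource is built with the default
+            # credential chain rather than self.credentials.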
+ self._s3 = boto3.resource('s3')
+ return self._s3
+
+ @property
+ async def region(self):
+ # Awaiting self.client ensures the region is set properly; if we have a
+ # client set on our provider, we know the region is correct because we
+ # need the region in order to make the client.
+ await self.client
+ return self._region
async def validate_v1_path(self, path, **kwargs):
- await self._check_region()
+ """Validates a waterbutler path
+ """
+ wb_path = WaterButlerPath(path)
if path == '/':
- return WaterButlerPath(path)
+ return wb_path
implicit_folder = path.endswith('/')
if implicit_folder:
- params = {'prefix': path, 'delimiter': '/'}
- resp = await self.make_request(
- 'GET',
- functools.partial(self.bucket.generate_url, settings.TEMP_URL_SECS, 'GET', query_parameters=params),
- params=params,
- expects=(200, 404),
- throws=exceptions.MetadataError,
- )
+ await self._metadata_folder(wb_path.path)
else:
- resp = await self.make_request(
- 'HEAD',
- functools.partial(self.bucket.new_key(path).generate_url, settings.TEMP_URL_SECS, 'HEAD'),
- expects=(200, 404),
- throws=exceptions.MetadataError,
- )
-
- await resp.release()
-
- if resp.status == 404:
- raise exceptions.NotFoundError(str(path))
+ await self._metadata_file(wb_path.path)
- return WaterButlerPath(path)
+ return wb_path
+    # TODO: Is this called anywhere? Could callers use the constructor directly?
async def validate_path(self, path, **kwargs):
return WaterButlerPath(path)
def can_duplicate_names(self):
+ # TODO This should be a class attribute
return True
def can_intra_copy(self, dest_provider, path=None):
@@ -111,28 +229,34 @@ async def intra_copy(self, dest_provider, source_path, dest_path):
"""Copy key from one S3 bucket to another. The credentials specified in
`dest_provider` must have read access to `source.bucket`.
"""
- await self._check_region()
exists = await dest_provider.exists(dest_path)
- dest_key = dest_provider.bucket.new_key(dest_path.path)
+ # TODO move this to `__aenter__`
+ client = await self.client
# ensure no left slash when joining paths
source_path = '/' + os.path.join(self.settings['bucket'], source_path.path)
headers = {'x-amz-copy-source': parse.quote(source_path)}
- url = functools.partial(
- dest_key.generate_url,
- settings.TEMP_URL_SECS,
- 'PUT',
- headers=headers,
+
+ sign_url = lambda: client.generate_presigned_url(
+ 'copy_object',
+ Params={
+ 'Bucket': self.bucket_name,
+ 'CopySource': source_path,
+ 'Key': dest_path.path
+ },
+ ExpiresIn=settings.TEMP_URL_SECS,
)
- resp = await self.make_request(
- 'PUT', url,
+ response = await self.make_request(
+ 'PUT',
+ sign_url,
+ expects={200},
skip_auto_headers={'CONTENT-TYPE'},
headers=headers,
- expects=(200, ),
- throws=exceptions.IntraCopyError,
+ throws=exceptions.IntraCopyError
)
- await resp.release()
+ await response.release()
+
return (await dest_provider.metadata(dest_path)), not exists
async def download(self, path, accept_url=False, revision=None, range=None, **kwargs):
@@ -144,40 +268,43 @@ async def download(self, path, accept_url=False, revision=None, range=None, **kw
:rtype: :class:`waterbutler.core.streams.ResponseStreamReader`
:raises: :class:`waterbutler.core.exceptions.DownloadError`
"""
- await self._check_region()
+ get_kwargs = {}
if not path.is_file:
raise exceptions.DownloadError('No file specified for download', code=400)
- if not revision or revision.lower() == 'latest':
- query_parameters = None
- else:
- query_parameters = {'versionId': revision}
+        if range:
+            # `range` is a (start, end) tuple; make_request() below also
+            # receives it and sets the actual Range header.
+            get_kwargs['Range'] = 'bytes={}-{}'.format(range[0], range[1])
- if kwargs.get('displayName'):
- response_headers = {'response-content-disposition': 'attachment; filename*=UTF-8\'\'{}'.format(parse.quote(kwargs['displayName']))}
- else:
- response_headers = {'response-content-disposition': 'attachment'}
+ # if kwargs.get('displayName'):
+ # get_kwargs['ResponseContentDisposition'] = 'attachment; filename*=UTF-8\'\'{}'.format(parse.quote(kwargs['displayName']))
+ # else:
+ # get_kwargs['ResponseContentDisposition'] = 'attachment'
- url = functools.partial(
- self.bucket.new_key(path.path).generate_url,
- settings.TEMP_URL_SECS,
- query_parameters=query_parameters,
- response_headers=response_headers
- )
+ if revision:
+ get_kwargs['VersionId'] = revision
- if accept_url:
- return url()
+ # TODO move this to `__aenter__`
+ client = await self.client
- resp = await self.make_request(
+ sign_url = lambda: client.generate_presigned_url(
+ 'get_object',
+ Params={
+ 'Bucket': self.bucket_name,
+                'Key': path.path,
+                **get_kwargs
+ },
+ ExpiresIn=settings.TEMP_URL_SECS,
+ HttpMethod='GET'
+ )
+
+ response = await self.make_request(
'GET',
- url,
+ sign_url,
range=range,
- expects=(200, 206),
- throws=exceptions.DownloadError,
+ expects={200, 206},
+ throws=exceptions.DownloadError
)
-
- return streams.ResponseStreamReader(resp)
+ return streams.ResponseStreamReader(response)
async def upload(self, stream, path, conflict='replace', **kwargs):
"""Uploads the given stream to S3
@@ -187,11 +314,11 @@ async def upload(self, stream, path, conflict='replace', **kwargs):
:rtype: dict, bool
"""
- await self._check_region()
-
path, exists = await self.handle_name_conflict(path, conflict=conflict)
stream.add_writer('md5', streams.HashStreamWriter(hashlib.md5))
+ # TODO move this to `__aenter__`
+ client = await self.client
headers = {'Content-Length': str(stream.size)}
# this is usually set in boto.s3.key.generate_url, but do it here
@@ -199,26 +326,32 @@ async def upload(self, stream, path, conflict='replace', **kwargs):
if self.encrypt_uploads:
headers['x-amz-server-side-encryption'] = 'AES256'
- upload_url = functools.partial(
- self.bucket.new_key(path.path).generate_url,
- settings.TEMP_URL_SECS,
- 'PUT',
- headers=headers,
+ sign_url = lambda: client.generate_presigned_url(
+ 'put_object',
+ Params={
+ 'Bucket': self.bucket_name,
+ 'Key': path.path,
+ 'ContentLength': stream.size,
+ **({'ServerSideEncryption': 'AES256'} if self.encrypt_uploads else {})
+ },
+ ExpiresIn=settings.TEMP_URL_SECS,
)
- resp = await self.make_request(
+
+ response = await self.make_request(
'PUT',
- upload_url,
+ sign_url,
data=stream,
skip_auto_headers={'CONTENT-TYPE'},
headers=headers,
- expects=(200, 201, ),
- throws=exceptions.UploadError,
+            expects={200, 201},
+            throws=exceptions.UploadError
)
+ await response.release()
+
# md5 is returned as ETag header as long as server side encryption is not used.
- if stream.writers['md5'].hexdigest != resp.headers['ETag'].replace('"', ''):
+ if stream.writers['md5'].hexdigest != response.headers['ETag'].replace('"', ''):
raise exceptions.UploadChecksumMismatchError()
- await resp.release()
return (await self.metadata(path, **kwargs)), not exists
async def delete(self, path, confirm_delete=0, **kwargs):
@@ -227,8 +360,6 @@ async def delete(self, path, confirm_delete=0, **kwargs):
:param str path: The path of the key to delete
:param int confirm_delete: Must be 1 to confirm root folder delete
"""
- await self._check_region()
-
if path.is_root:
if not confirm_delete == 1:
raise exceptions.DeleteError(
@@ -237,23 +368,40 @@ async def delete(self, path, confirm_delete=0, **kwargs):
)
if path.is_file:
- resp = await self.make_request(
- 'DELETE',
- self.bucket.new_key(path.path).generate_url(settings.TEMP_URL_SECS, 'DELETE'),
- expects=(200, 204, ),
- throws=exceptions.DeleteError,
- )
- await resp.release()
+ await self._delete_file(path, **kwargs)
else:
await self._delete_folder(path, **kwargs)
+ async def _delete_file(self, path, **kwargs):
+ """Deletes a single object located at a certain key.
+
+ Called from: func: delete if path.is_file
+ """
+ # TODO move this to `__aenter__`
+ client = await self.client
+ sign_url = lambda: client.generate_presigned_url(
+ 'delete_object',
+ Params={
+ 'Bucket': self.bucket_name,
+ 'Key': path.path
+ },
+ ExpiresIn=settings.TEMP_URL_SECS,
+ HttpMethod='DELETE',
+ )
+ resp = await self.make_request(
+ 'DELETE',
+ sign_url,
+ expects={200, 204},
+ throws=exceptions.DeleteError,
+ )
+ await resp.release()
+
async def _delete_folder(self, path, **kwargs):
"""Query for recursive contents of folder and delete in batches of 1000
Called from: func: delete if not path.is_file
- Calls: func: self._check_region
- func: self.make_request
+        Calls: func: self.make_request
-              func: self.bucket.generate_url
+               func: self.client.generate_presigned_url
:param *ProviderPath path: Path to be deleted
@@ -264,81 +412,97 @@ async def _delete_folder(self, path, **kwargs):
To fully delete an occupied folder, we must delete all of the comprising
objects. Amazon provides a bulk delete operation to simplify this.
"""
- await self._check_region()
+ # TODO move this to `__aenter__`
+ client = await self.client
+
+    # Needs to be a lambda, *not* a partial, so that `marker` is
+    # reevaluated each time it's called. This lets us reuse a single
+    # callable for every request we make.
- more_to_come = True
- content_keys = []
- query_params = {'prefix': path.path}
+    # The weirdness with using ** on the Params is because boto3 is
+    # strict about argument types; passing None for this param raises
+    # an error that the None object is not of type string.
+
+ # `marker` needs to be defined before the url signer so that it exists
+ # when the url signer is defined. It is used for pagination, to
+ # determine which page is returned by the request.
marker = None
+ sign_list_url = lambda: client.generate_presigned_url(
+        # list_objects (v1) matches the Marker/NextMarker pagination and the
+        # ListBucketResult XML parsed below. No Delimiter: the listing must
+        # recurse into subfolders so their keys are deleted too.
+        'list_objects',
+        Params={
+            'Bucket': self.bucket_name,
+            'Prefix': path.path,
+ **({'Marker': marker} if marker is not None else {})
+ },
+ ExpiresIn=settings.TEMP_URL_SECS
+ )
- while more_to_come:
- if marker is not None:
- query_params['marker'] = marker
+ objects_to_delete = []
+
+    # Starts empty (bytes, since compute_md5 wants a bytes buffer); rebuilt
+    # from each page of keys before the delete request is signed.
+    delete_payload = b''
+ sign_delete_url = lambda: client.generate_presigned_url(
+ 'delete_objects',
+ Params={
+ 'Bucket': self.bucket_name,
+ 'Delete': {
+                'Objects': [{'Key': obj['Key']} for obj in objects_to_delete]
+ }
+ },
+ Headers={
+ 'Content-Length': str(len(delete_payload)),
+ 'Content-MD5': compute_md5(BytesIO(delete_payload))[1],
+ 'Content-Type': 'text/xml'
+ }
+ )
- resp = await self.make_request(
+    # S3 'truncates' responses that would list over 1000 objects. The
+    # response's 'IsTruncated' element is 'true' when more objects remain.
+    # Before the first request we assume the list is truncated, so that at
+    # least one request is made.
+ truncated = True
+ while truncated:
+ list_response = await self.make_request(
'GET',
- self.bucket.generate_url(settings.TEMP_URL_SECS, 'GET', query_parameters=query_params),
- params=query_params,
- expects=(200, ),
+ sign_list_url,
+ expects={200, 204},
throws=exceptions.MetadataError,
)
+ page = xmltodict.parse(
+ await list_response.read(),
+ strip_whitespace=False,
+ force_list={'Contents'}
+ )
+        marker = page['ListBucketResult'].get('NextMarker', None)
+        if marker is None and page['ListBucketResult'].get('Contents'):
+            # Without a Delimiter S3 omits NextMarker, so page from the last
+            # key we saw (as the old boto-based code did).
+            marker = page['ListBucketResult']['Contents'][-1]['Key']
+ truncated = page['ListBucketResult'].get('IsTruncated', 'false') != 'false'
- contents = await resp.read()
- parsed = xmltodict.parse(contents, strip_whitespace=False)['ListBucketResult']
- more_to_come = parsed.get('IsTruncated') == 'true'
- contents = parsed.get('Contents', [])
-
- if isinstance(contents, dict):
- contents = [contents]
-
- content_keys.extend([content['Key'] for content in contents])
- if len(content_keys) > 0:
- marker = content_keys[-1]
-
- # Query against non-existant folder does not return 404
- if len(content_keys) == 0:
- raise exceptions.NotFoundError(str(path))
-
- while len(content_keys) > 0:
- key_batch = content_keys[:1000]
- del content_keys[:1000]
-
-            payload = '<?xml version="1.0" encoding="UTF-8"?>'
-            payload += '<Delete>'
-            payload += ''.join(map(
-                lambda x: '<Object><Key>{}</Key></Object>'.format(xml.sax.saxutils.escape(x)),
-                key_batch
-            ))
-            payload += '</Delete>'
-            payload = payload.encode('utf-8')
- md5 = compute_md5(BytesIO(payload))
-
- query_params = {'delete': ''}
- headers = {
- 'Content-Length': str(len(payload)),
- 'Content-MD5': md5[1],
- 'Content-Type': 'text/xml',
- }
+ objects_to_delete = page['ListBucketResult'].get('Contents', [])
- # We depend on a customized version of boto that can make query parameters part of
- # the signature.
- url = functools.partial(
- self.bucket.generate_url,
- settings.TEMP_URL_SECS,
- 'POST',
- query_parameters=query_params,
- headers=headers,
- )
- resp = await self.make_request(
+        delete_payload = '<?xml version="1.0" encoding="UTF-8"?><Delete>{}</Delete>'.format(
+            ''.join([
+                '<Object><Key>{}</Key></Object>'.format(xml.sax.saxutils.escape(obj['Key']))
+                for obj in objects_to_delete
+            ])
+        ).encode('utf-8')
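+        # e.g. b'<?xml version="1.0" encoding="UTF-8"?><Delete>
+        #       <Object><Key>folder/a.txt</Key></Object></Delete>'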
+
+ # TODO Don't wait for the delete to finish before requesting the
+ # next batch, or sending that delete request.
+ delete_response = await self.make_request(
'POST',
- url,
- params=query_params,
- data=payload,
- headers=headers,
- expects=(200, 204, ),
- throws=exceptions.DeleteError,
+ sign_delete_url,
+ data=delete_payload,
+ headers={
+ 'Content-Length': str(len(delete_payload)),
+ 'Content-MD5': compute_md5(BytesIO(delete_payload))[1],
+ 'Content-Type': 'text/xml'
+ }
)
- await resp.release()
+ await delete_response.release()
+ del delete_response
+ del list_response
+
+ # TODO Put the delete requests in a list of tasks and wait for all of them to
+ # finish here, before returning
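+    # A sketch of that, assuming make_request is safe to run concurrently:
+    #     tasks.append(asyncio.ensure_future(self.make_request('POST', ...)))
+    #     ...
+    #     await asyncio.gather(*tasks)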
async def revisions(self, path, **kwargs):
"""Get past versions of the requested key
@@ -346,22 +510,28 @@ async def revisions(self, path, **kwargs):
:param str path: The path to a key
:rtype list:
"""
- await self._check_region()
+ # TODO move this to `__aenter__`
+ client = await self.client
+        sign_url = lambda: client.generate_presigned_url(
+            'list_object_versions',
+            Params={
+                'Bucket': self.bucket_name,
+                'Delimiter': '/',
+                'Prefix': path.path
+            },
+            ExpiresIn=settings.TEMP_URL_SECS
+        )
- query_params = {'prefix': path.path, 'delimiter': '/', 'versions': ''}
- url = functools.partial(self.bucket.generate_url, settings.TEMP_URL_SECS, 'GET', query_parameters=query_params)
- resp = await self.make_request(
- 'GET',
- url,
- params=query_params,
- expects=(200, ),
- throws=exceptions.MetadataError,
+ response = await self.make_request(
+            'GET',
+ sign_url,
+ expects={200},
+ throws=exceptions.MetadataError
)
- content = await resp.read()
- versions = xmltodict.parse(content)['ListVersionsResult'].get('Version') or []
- if isinstance(versions, dict):
- versions = [versions]
+ versions = xmltodict.parse(
+            await response.read(),
+ force_list={'Version'}
+ )['ListVersionsResult'].get('Version', [])
return [
S3Revision(item)
@@ -369,148 +539,188 @@ async def revisions(self, path, **kwargs):
if item['Key'] == path.path
]
- async def metadata(self, path, revision=None, **kwargs):
- """Get Metadata about the requested file or folder
-
- :param WaterButlerPath path: The path to a key or folder
- :rtype: dict or list
- """
- await self._check_region()
-
- if path.is_dir:
- return (await self._metadata_folder(path))
-
- return (await self._metadata_file(path, revision=revision))
-
async def create_folder(self, path, folder_precheck=True, **kwargs):
- """
+ """Create an empty object on the bucket that contains a trailing slash
:param str path: The path to create a folder at
"""
- await self._check_region()
-
+ # TODO move this to `__aenter__`
+ client = await self.client
WaterButlerPath.validate_folder(path)
if folder_precheck:
if (await self.exists(path)):
raise exceptions.FolderNamingConflict(path.name)
+ sign_url = lambda: client.generate_presigned_url(
+ 'put_object',
+ Params={
+ 'Bucket': self.bucket_name,
+                'Key': path.path,
+ },
+ ExpiresIn=settings.TEMP_URL_SECS,
+ )
async with self.request(
'PUT',
- functools.partial(self.bucket.new_key(path.path).generate_url, settings.TEMP_URL_SECS, 'PUT'),
+ sign_url,
skip_auto_headers={'CONTENT-TYPE'},
- expects=(200, 201),
- throws=exceptions.CreateFolderError
+ expects={200, 201},
+ throws=exceptions.CreateFolderError,
):
return S3FolderMetadata({'Prefix': path.path})
- async def _metadata_file(self, path, revision=None):
- await self._check_region()
+ async def metadata(self, path, revision=None, **kwargs):
+ """Get Metadata about the requested file or folder
+
+ :param WaterButlerPath path: The path to a key or folder
+ :rtype: dict or list
+ """
+ if path.is_dir:
+ return (await self._metadata_folder(path.path))
+        # TODO: consider caching a hash of these args and the result in redis
- if revision == 'Latest':
+ return (await self._metadata_file(path.path, revision=revision))
+
+ async def _metadata_file(self, path, revision=None):
+ """Load metadata for a single object in the bucket.
+ """
+ # TODO move this to `__aenter__`
+ client = await self.client
+
+        # Normalize the revision: None, '', and 'Latest' all mean the
+        # current version.
+        if not revision or revision == 'Latest':
+            revision = None
- resp = await self.make_request(
+
+ sign_url = lambda: client.generate_presigned_url(
+ 'head_object',
+ Params={
+ 'Bucket': self.bucket_name,
+ 'Key': path,
+ **({'VersionId': revision} if revision is not None else {})
+ },
+ ExpiresIn=settings.TEMP_URL_SECS,
+ )
+ response = await self.make_request(
'HEAD',
- functools.partial(
- self.bucket.new_key(path.path).generate_url,
- settings.TEMP_URL_SECS,
- 'HEAD',
- query_parameters={'versionId': revision} if revision else None
- ),
- expects=(200, ),
+ sign_url,
+ expects={200, 204},
throws=exceptions.MetadataError,
)
- await resp.release()
- return S3FileMetadataHeaders(path.path, resp.headers)
-
- async def _metadata_folder(self, path):
- await self._check_region()
+ await response.release()
- params = {'prefix': path.path, 'delimiter': '/'}
- resp = await self.make_request(
- 'GET',
- functools.partial(self.bucket.generate_url, settings.TEMP_URL_SECS, 'GET', query_parameters=params),
- params=params,
- expects=(200, ),
- throws=exceptions.MetadataError,
+ return S3FileMetadataHeaders(
+ path,
+            headers=response.headers  # TODO Fix S3FileMetadataHeaders
)
- contents = await resp.read()
+ async def _metadata_folder(self, path):
+ """Get metadata about the contents of a bucket. This is either the
+ contents at the root of the bucket, or a folder has
+ been selected as a prefix by the user
+ """
+ # TODO move this to `__aenter__`
+ client = await self.client
+
+        # Needs to be a lambda, *not* a partial, so that `marker` is
+        # reevaluated each time it's called. This lets us reuse a single
+        # callable for every request we make.
- parsed = xmltodict.parse(contents, strip_whitespace=False)['ListBucketResult']
+        # The weirdness with using ** on the Params is because boto3 is
+        # strict about argument types; passing None for this param raises
+        # an error that the None object is not of type string.
- contents = parsed.get('Contents', [])
- prefixes = parsed.get('CommonPrefixes', [])
+ # `marker` needs to be defined before the url signer so that it exists
+ # when the url signer is defined. It is used for pagination, to
+ # determine which page is returned by the request.
+ marker = None
+ sign_url = lambda: client.generate_presigned_url(
+            # list_objects (v1) matches the Marker/NextMarker pagination
+            # parsed below; list_objects_v2 paginates with ContinuationToken.
+            'list_objects',
+ Params={
+ 'Bucket': self.bucket_name,
+ 'Prefix': path,
+ 'Delimiter': '/',
+ **({'Marker': marker} if marker is not None else {})
+ },
+ ExpiresIn=settings.TEMP_URL_SECS
+ )
- if not contents and not prefixes and not path.is_root:
- # If contents and prefixes are empty then this "folder"
- # must exist as a key with a / at the end of the name
- # if the path is root there is no need to test if it exists
- resp = await self.make_request(
+        # S3 'truncates' responses that would list over 1000 objects. The
+        # response's 'IsTruncated' element is 'true' when more objects remain.
+        # Before the first request we assume the list is truncated, so that
+        # at least one request is made.
+ truncated = True
+
+ # Each request will return 0 or more 'contents' and 'common prefixes'.
+ # Contents contains keys that begin with 'prefix' and contain no
+ # delimiter characters after the characters that match the prefix.
+ # Common prefixes match any keys that do contain a delimiter after the
+ # characters that match the prefix. Each request extends the `contents`
+ # and `prefixes` arrays with the respective contents and prefixes that
+ # were returned in the request.
+ contents = []
+ prefixes = []
+
+ while truncated:
+ response = await self.make_request(
+ 'GET',
+ sign_url,
+ expects={200, 204},
+ throws=exceptions.MetadataError,
+ )
+ page = xmltodict.parse(
+ await response.read(),
+ strip_whitespace=False,
+ force_list={'CommonPrefixes', 'Contents'}
+ )
+ prefixes.extend(page['ListBucketResult'].get('CommonPrefixes', []))
+ contents.extend(page['ListBucketResult'].get('Contents', []))
+
+ marker = page['ListBucketResult'].get('NextMarker', None)
+ truncated = page['ListBucketResult'].get('IsTruncated', 'false') != 'false'
+ del response
+
+ del sign_url
+
+ items = []
+ # If there are keys that have the provided prefix...
+ if contents or prefixes:
+ # Prefixes represent 'folders'
+ items.extend([S3FolderMetadata(prefix) for prefix in prefixes])
+
+ for content in contents:
+ # Only care about items that are not the same as where the
+ # addon is mounted.
+ if content['Key'] != path:
+ items.append(
+ S3FolderKeyMetadata(content)
+ if content['Key'].endswith('/')
+ else S3FileMetadata(content)
+ )
+
+ # If contents and prefixes are empty, but this is not the root
+ # path, then this "folder" must exist as a key with a / at the
+ # end of the name.
+ elif not path == "":
+ sign_url = lambda: client.generate_presigned_url(
+ 'head_object',
+ Params={
+ 'Bucket': self.bucket_name,
+ 'Key': path,
+ },
+ ExpiresIn=settings.TEMP_URL_SECS
+ )
+ response = await self.make_request(
'HEAD',
- functools.partial(self.bucket.new_key(path.path).generate_url, settings.TEMP_URL_SECS, 'HEAD'),
- expects=(200, ),
+ sign_url,
+ expects={200, 204},
throws=exceptions.MetadataError,
)
- await resp.release()
-
- if isinstance(contents, dict):
- contents = [contents]
-
- if isinstance(prefixes, dict):
- prefixes = [prefixes]
-
- items = [
- S3FolderMetadata(item)
- for item in prefixes
- ]
-
- for content in contents:
- if content['Key'] == path.path:
- continue
-
- if content['Key'].endswith('/'):
- items.append(S3FolderKeyMetadata(content))
- else:
- items.append(S3FileMetadata(content))
+ del sign_url
+ del response
return items
- async def _check_region(self):
- """Lookup the region via bucket name, then update the host to match.
-
- Manually constructing the connection hostname allows us to use OrdinaryCallingFormat
- instead of SubdomainCallingFormat, which can break on buckets with periods in their name.
- The default region, US East (N. Virginia), is represented by the empty string and does not
- require changing the host. Ireland is represented by the string 'EU', with the host
- parameter 'eu-west-1'. All other regions return the host parameter as the region name.
-
- Region Naming: http://docs.aws.amazon.com/general/latest/gr/rande.html#s3_region
- """
- if self.region is None:
- self.region = await self._get_bucket_region()
- if self.region == 'EU':
- self.region = 'eu-west-1'
-
- if self.region != '':
- self.connection.host = self.connection.host.replace('s3.', 's3-' + self.region + '.', 1)
- self.connection._auth_handler = get_auth_handler(
- self.connection.host, boto_config, self.connection.provider, self.connection._required_auth_capability())
-
- self.metrics.add('region', self.region)
-
- async def _get_bucket_region(self):
- """Bucket names are unique across all regions.
-
- Endpoint doc:
- http://docs.aws.amazon.com/AmazonS3/latest/API/RESTBucketGETlocation.html
- """
- resp = await self.make_request(
- 'GET',
- functools.partial(self.bucket.generate_url, settings.TEMP_URL_SECS, 'GET', query_parameters={'location': ''}),
- expects=(200, ),
- throws=exceptions.MetadataError,
- )
- contents = await resp.read()
- parsed = xmltodict.parse(contents, strip_whitespace=False)
- return parsed['LocationConstraint'].get('#text', '')
diff --git a/waterbutler/providers/s3/streams.py b/waterbutler/providers/s3/streams.py
new file mode 100644
index 000000000..5d403d4bd
--- /dev/null
+++ b/waterbutler/providers/s3/streams.py
@@ -0,0 +1,29 @@
+import asyncio
+
+from botocore.response import StreamingBody
+
+from waterbutler.core.streams.base import BaseStream
+
+
+class S3ResponseBodyStream(BaseStream):
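+    """Wrap the botocore ``StreamingBody`` of a ``get_object`` response
+    in WaterButler's ``BaseStream`` interface."""
+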
+ def __init__(self, data):
+ super().__init__()
+
+ if not isinstance(data['Body'], StreamingBody):
+            raise TypeError('Body must be a StreamingBody, found {!r}'.format(type(data['Body'])))
+
+ self.content_type = data['ContentType']
+ self._size = data['ContentLength']
+ self.streaming_body = data['Body']
+
+ @property
+ def size(self):
+ return self._size
+
+ async def _read(self, n=None):
+ n = self._size if n is None else n
+
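+        # NOTE: StreamingBody.read() is synchronous, so large reads will
+        # block the event loop; an executor (or aioboto3) would avoid that.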
+ chunk = self.streaming_body.read(amt=n)
+ if not chunk:
+ self.feed_eof()
+ return chunk
diff --git a/waterbutler/server/api/v1/provider/__init__.py b/waterbutler/server/api/v1/provider/__init__.py
index b31c9e6d0..c85cbac6b 100644
--- a/waterbutler/server/api/v1/provider/__init__.py
+++ b/waterbutler/server/api/v1/provider/__init__.py
@@ -32,11 +32,17 @@ def list_or_value(value):
@tornado.web.stream_request_body
class ProviderHandler(core.BaseHandler, CreateMixin, MetadataMixin, MoveCopyMixin):
+ """ProviderHandler
+ Handler for provider operations. Inherits from provider handler mixins
+ Create, Metadata, and MoveCopy
+ """
PRE_VALIDATORS = {'put': 'prevalidate_put', 'post': 'prevalidate_post'}
POST_VALIDATORS = {'put': 'postvalidate_put'}
    PATTERN = r'/resources/(?P<resource>(?:\w|\d)+)/providers/(?P<provider>(?:\w|\d)+)(?P<path>/.*/?)'
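+    # Matches e.g. /resources/abc123/providers/s3/some/path.txt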
async def prepare(self, *args, **kwargs):
+ """Prepare to handle request
+ """
method = self.request.method.lower()
# TODO Find a nicer way to handle this
@@ -174,6 +180,7 @@ def on_finish(self):
self._send_hook(action)
+
def _send_hook(self, action):
source = None
destination = None