diff --git a/trino-aws-proxy-spi/src/main/java/io/trino/aws/proxy/spi/rest/RequestHeaders.java b/trino-aws-proxy-spi/src/main/java/io/trino/aws/proxy/spi/rest/RequestHeaders.java index 2aa6e498..330d19ee 100644 --- a/trino-aws-proxy-spi/src/main/java/io/trino/aws/proxy/spi/rest/RequestHeaders.java +++ b/trino-aws-proxy-spi/src/main/java/io/trino/aws/proxy/spi/rest/RequestHeaders.java @@ -26,4 +26,9 @@ public record RequestHeaders( passthroughHeaders = ImmutableMultiMap.copyOfCaseInsensitive(passthroughHeaders); unmodifiedHeaders = ImmutableMultiMap.copyOfCaseInsensitive(unmodifiedHeaders); } + + public RequestHeaders withPassthroughHeaders(MultiMap newPassthroughHeaders) + { + return new RequestHeaders(newPassthroughHeaders, unmodifiedHeaders); + } } diff --git a/trino-aws-proxy-spi/src/main/java/io/trino/aws/proxy/spi/rest/S3RequestRewriter.java b/trino-aws-proxy-spi/src/main/java/io/trino/aws/proxy/spi/rest/S3RequestRewriter.java index 5ce179c4..4d497631 100644 --- a/trino-aws-proxy-spi/src/main/java/io/trino/aws/proxy/spi/rest/S3RequestRewriter.java +++ b/trino-aws-proxy-spi/src/main/java/io/trino/aws/proxy/spi/rest/S3RequestRewriter.java @@ -15,6 +15,8 @@ import io.trino.aws.proxy.spi.credentials.Identity; import io.trino.aws.proxy.spi.signing.SigningMetadata; +import io.trino.aws.proxy.spi.util.ImmutableMultiMap; +import io.trino.aws.proxy.spi.util.MultiMap; import java.util.Optional; @@ -24,12 +26,24 @@ public interface S3RequestRewriter { S3RequestRewriter NOOP = (_, _, _) -> Optional.empty(); - record S3RewriteResult(String finalRequestBucket, String finalRequestKey) + record S3RewriteResult( + String finalRequestBucket, + String finalRequestKey, + Optional finalRequestHeaders, + Optional finalQueryParameters) { + public S3RewriteResult(String finalRequestBucket, String finalRequestKey) + { + this(finalRequestBucket, finalRequestKey, Optional.empty(), Optional.empty()); + } + public S3RewriteResult { requireNonNull(finalRequestBucket, "finalRequestBucket is null"); requireNonNull(finalRequestKey, "finalRequestKey is null"); + requireNonNull(finalRequestHeaders, "finalRequestHeaders is null"); + requireNonNull(finalQueryParameters, "finalQueryParameters is null"); + finalQueryParameters = finalQueryParameters.map(ImmutableMultiMap::copyOf); } } diff --git a/trino-aws-proxy-spi/src/main/java/io/trino/aws/proxy/spi/util/ImmutableMultiMap.java b/trino-aws-proxy-spi/src/main/java/io/trino/aws/proxy/spi/util/ImmutableMultiMap.java index 46b6a6cc..b64eb4b1 100644 --- a/trino-aws-proxy-spi/src/main/java/io/trino/aws/proxy/spi/util/ImmutableMultiMap.java +++ b/trino-aws-proxy-spi/src/main/java/io/trino/aws/proxy/spi/util/ImmutableMultiMap.java @@ -142,6 +142,13 @@ private static ImmutableMultiMap copyOf(Set data; diff --git a/trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/rest/TrinoS3ProxyClient.java b/trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/rest/TrinoS3ProxyClient.java index d74dde6a..9d6f3eda 100644 --- a/trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/rest/TrinoS3ProxyClient.java +++ b/trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/rest/TrinoS3ProxyClient.java @@ -25,6 +25,7 @@ import io.trino.aws.proxy.spi.credentials.Identity; import io.trino.aws.proxy.spi.rest.ParsedS3Request; import io.trino.aws.proxy.spi.rest.RequestContent; +import io.trino.aws.proxy.spi.rest.RequestHeaders; import io.trino.aws.proxy.spi.rest.S3RequestRewriter; import io.trino.aws.proxy.spi.rest.S3RequestRewriter.S3RewriteResult; import io.trino.aws.proxy.spi.security.SecurityResponse; @@ -131,9 +132,15 @@ public void proxyRequest(Optional identity, SigningMetadata signingMet .map(S3RewriteResult::finalRequestKey) .map(SdkHttpUtils::urlEncodeIgnoreSlashes) .orElse(request.rawPath()); + RequestHeaders rewrittenRequestHeaders = rewriteResult + .flatMap(S3RewriteResult::finalRequestHeaders) + .orElse(request.requestHeaders()); + MultiMap rewrittenQueryParameters = rewriteResult + .flatMap(S3RewriteResult::finalQueryParameters) + .orElse(request.queryParameters()); RemoteRequestWithPresignedURIs remoteRequest = remoteS3ConnectionController.withRemoteConnection(signingMetadata, identity, request, (remoteCredential, remoteS3Facade) -> { - URI remoteUri = remoteS3Facade.buildEndpoint(uriBuilder(request.queryParameters()), targetKey, targetBucket, request.requestAuthorization().region()); + URI remoteUri = remoteS3Facade.buildEndpoint(uriBuilder(rewrittenQueryParameters), targetKey, targetBucket, request.requestAuthorization().region()); Request.Builder remoteRequestBuilder = new Request.Builder() .setMethod(request.httpVerb()) @@ -147,7 +154,7 @@ public void proxyRequest(Optional identity, SigningMetadata signingMet ImmutableMultiMap.Builder remoteRequestHeadersBuilder = ImmutableMultiMap.builder(false); Instant targetRequestTimestamp = Instant.now(); - request.requestHeaders().passthroughHeaders().forEach(remoteRequestHeadersBuilder::addAll); + rewrittenRequestHeaders.passthroughHeaders().forEach(remoteRequestHeadersBuilder::addAll); remoteRequestHeadersBuilder.putOrReplaceSingle("Host", buildRemoteHost(remoteUri)); // Use now for the remote request @@ -182,7 +189,7 @@ public void proxyRequest(Optional identity, SigningMetadata signingMet Optional.empty(), remoteUri, remoteRequestHeaders, - request.queryParameters(), + rewrittenQueryParameters, request.httpVerb()).signingAuthorization().authorization(); // remoteRequestHeaders now has correct values, copy to the remote request diff --git a/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/AbstractTestPresignedRequests.java b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/AbstractTestPresignedRequests.java index 90637940..2509d3da 100644 --- a/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/AbstractTestPresignedRequests.java +++ b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/AbstractTestPresignedRequests.java @@ -44,6 +44,7 @@ import software.amazon.awssdk.services.s3.model.HeadObjectResponse; import software.amazon.awssdk.services.s3.model.PutObjectRequest; import software.amazon.awssdk.services.s3.model.PutObjectResponse; +import software.amazon.awssdk.services.s3.model.Tag; import software.amazon.awssdk.services.s3.model.UploadPartRequest; import software.amazon.awssdk.services.s3.presigner.S3Presigner; import software.amazon.awssdk.services.s3.presigner.model.CompleteMultipartUploadPresignRequest; @@ -65,6 +66,7 @@ import java.nio.file.Files; import java.nio.file.Path; import java.time.Duration; +import java.util.List; import java.util.Optional; import static io.airlift.http.client.Request.Builder.prepareDelete; @@ -302,12 +304,56 @@ void uploadFileToStorage(String bucketName, String key, Path filePath) String getFileFromStorage(String bucketName, String key) throws IOException { - String dataFromProxy = TestingUtil.getFileFromStorage(internalClient, bucketName, key); - String dataFromStorage = TestingUtil.getFileFromStorage(storageClient, requestRewriteController.getTargetBucket(bucketName, key), requestRewriteController.getTargetKey(bucketName, key)); + return getFileFromStorage(internalClient, bucketName, key, Optional.empty()); + } + + String getFileFromStorage(S3Client s3Client, String bucketName, String key, Optional credential) + throws IOException + { + String dataFromProxy = TestingUtil.getFileFromStorage(s3Client, bucketName, key); + String dataFromStorage = credential + .map(actualCredential -> { + try { + return TestingUtil.getFileFromStorage( + storageClient, + requestRewriteController.getTargetBucket(actualCredential.accessKey(), bucketName, key), + requestRewriteController.getTargetKey(actualCredential.accessKey(), bucketName, key)); + } + catch (IOException e) { + throw new RuntimeException(e); + } + }) + .orElseGet(() -> { + try { + return TestingUtil.getFileFromStorage( + storageClient, + requestRewriteController.getTargetBucket(bucketName, key), + requestRewriteController.getTargetKey(bucketName, key)); + } + catch (IOException e) { + throw new RuntimeException(e); + } + }); assertThat(dataFromProxy).isEqualTo(dataFromStorage); return dataFromStorage; } + List getObjectTagging(S3Client s3Client, String bucketName, String key, Optional credential) + { + List tagsFromProxy = TestingUtil.getObjectTagging(s3Client, bucketName, key); + List tagsFromStorage = credential + .map(actualCredential -> TestingUtil.getObjectTagging( + storageClient, + requestRewriteController.getTargetBucket(actualCredential.accessKey(), bucketName, key), + requestRewriteController.getTargetKey(actualCredential.accessKey(), bucketName, key))) + .orElseGet(() -> TestingUtil.getObjectTagging( + storageClient, + requestRewriteController.getTargetBucket(bucketName, key), + requestRewriteController.getTargetKey(bucketName, key))); + assertThat(tagsFromProxy).isEqualTo(tagsFromStorage); + return tagsFromStorage; + } + T executeHttpRequest(SdkHttpRequest sdkRequest, ResponseHandler responseHandler) { return executeHttpRequest(sdkRequest, Optional.empty(), responseHandler); diff --git a/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/TestPresignedRequestsWithRewrite.java b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/TestPresignedRequestsWithRewrite.java index 0985737d..deb04d69 100644 --- a/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/TestPresignedRequestsWithRewrite.java +++ b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/TestPresignedRequestsWithRewrite.java @@ -16,6 +16,7 @@ import com.fasterxml.jackson.dataformat.xml.XmlMapper; import com.google.inject.Inject; import io.airlift.http.client.HttpClient; +import io.airlift.http.client.StatusResponseHandler; import io.airlift.http.client.StringResponseHandler; import io.airlift.http.server.testing.TestingHttpServer; import io.trino.aws.proxy.server.testing.RequestRewriteUtil; @@ -27,27 +28,40 @@ import io.trino.aws.proxy.server.testing.harness.TrinoAwsProxyTestCommonModules.WithTestingHttpClient; import io.trino.aws.proxy.spi.credentials.IdentityCredential; import org.junit.jupiter.api.Test; +import software.amazon.awssdk.auth.credentials.AwsBasicCredentials; import software.amazon.awssdk.services.s3.S3Client; import software.amazon.awssdk.services.s3.model.GetObjectRequest; +import software.amazon.awssdk.services.s3.model.PutObjectRequest; import software.amazon.awssdk.services.s3.presigner.S3Presigner; import software.amazon.awssdk.services.s3.presigner.model.GetObjectPresignRequest; import software.amazon.awssdk.services.s3.presigner.model.PresignedGetObjectRequest; +import software.amazon.awssdk.services.s3.presigner.model.PresignedPutObjectRequest; +import software.amazon.awssdk.services.s3.presigner.model.PutObjectPresignRequest; import java.io.IOException; +import java.net.URI; +import java.nio.charset.StandardCharsets; import java.nio.file.Files; import java.time.Duration; +import java.util.Optional; +import static io.airlift.http.client.StatusResponseHandler.createStatusResponseHandler; import static io.airlift.http.client.StringResponseHandler.createStringResponseHandler; import static io.trino.aws.proxy.server.testing.RequestRewriteUtil.CREDENTIAL_TO_REDIRECT; import static io.trino.aws.proxy.server.testing.RequestRewriteUtil.TEST_CREDENTIAL_REDIRECT_BUCKET; import static io.trino.aws.proxy.server.testing.RequestRewriteUtil.TEST_CREDENTIAL_REDIRECT_KEY; +import static io.trino.aws.proxy.server.testing.RequestRewriteUtil.TEST_REWRITTEN_TAGS; import static io.trino.aws.proxy.server.testing.TestingUtil.TEST_FILE; +import static io.trino.aws.proxy.server.testing.TestingUtil.clientBuilder; import static org.assertj.core.api.Assertions.assertThat; @TrinoAwsProxyTest(filters = {WithConfiguredBuckets.class, WithTestingHttpClient.class, RequestRewriteUtil.Filter.class}) public class TestPresignedRequestsWithRewrite extends AbstractTestPresignedRequests { + private final URI baseUri; + private final String relativePath; + @Inject public TestPresignedRequestsWithRewrite( @ForTesting HttpClient httpClient, @@ -60,6 +74,8 @@ public TestPresignedRequestsWithRewrite( TestingS3RequestRewriteController requestRewriteController) { super(httpClient, internalClient, storageClient, testingCredentials, httpServer, s3ProxyConfig, xmlMapper, requestRewriteController); + this.baseUri = httpServer.getBaseUrl(); + this.relativePath = s3ProxyConfig.getS3Path(); } @Test @@ -85,4 +101,37 @@ public void testPresignedRedirectBasedOnIdentity() assertThat(response.getBody()).isEqualTo(Files.readString(TEST_FILE)); } } + + @Test + public void testPresignedUploadRedirectAndModifyHeadersBasedOnIdentity() + throws IOException + { + String testBucket = "two"; + String testKey = "does-not-matter"; + String fileContents = Files.readString(TEST_FILE, StandardCharsets.UTF_8); + + try ( + S3Presigner presigner = buildPresigner(CREDENTIAL_TO_REDIRECT); + S3Client testS3Client = clientBuilder(baseUri, Optional.of(relativePath)) + .credentialsProvider(() -> AwsBasicCredentials.create(CREDENTIAL_TO_REDIRECT.accessKey(), CREDENTIAL_TO_REDIRECT.secretKey())) + .build()) { + PutObjectRequest request = PutObjectRequest + .builder() + .bucket(testBucket) + .key(testKey) + .contentEncoding("gzip") + .contentType("text/plain;charset=UTF-8") + .build(); + PutObjectPresignRequest presignRequest = PutObjectPresignRequest.builder() + .signatureDuration(Duration.ofDays(1)) + .putObjectRequest(request) + .build(); + PresignedPutObjectRequest presignedRequest = presigner.presignPutObject(presignRequest); + + StatusResponseHandler.StatusResponse response = executeHttpRequest(presignedRequest.httpRequest(), fileContents, createStatusResponseHandler()); + assertThat(response.getStatusCode()).isEqualTo(200); + assertThat(getFileFromStorage(testS3Client, testBucket, testKey, Optional.of(CREDENTIAL_TO_REDIRECT))).isEqualTo(fileContents); + assertThat(getObjectTagging(testS3Client, testBucket, testKey, Optional.of(CREDENTIAL_TO_REDIRECT))).isEqualTo(TEST_REWRITTEN_TAGS); + } + } } diff --git a/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/TestProxiedRequestsWithRewrite.java b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/TestProxiedRequestsWithRewrite.java index 96b25a33..a179b159 100644 --- a/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/TestProxiedRequestsWithRewrite.java +++ b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/TestProxiedRequestsWithRewrite.java @@ -23,6 +23,8 @@ import org.junit.jupiter.api.Test; import software.amazon.awssdk.auth.credentials.AwsBasicCredentials; import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.s3.model.ListObjectsV2Request; +import software.amazon.awssdk.services.s3.model.ListObjectsV2Response; import software.amazon.awssdk.services.s3.model.PutObjectRequest; import software.amazon.awssdk.services.s3.model.PutObjectResponse; @@ -34,10 +36,14 @@ import static io.trino.aws.proxy.server.testing.RequestRewriteUtil.CREDENTIAL_TO_REDIRECT; import static io.trino.aws.proxy.server.testing.RequestRewriteUtil.TEST_CREDENTIAL_REDIRECT_BUCKET; import static io.trino.aws.proxy.server.testing.RequestRewriteUtil.TEST_CREDENTIAL_REDIRECT_KEY; +import static io.trino.aws.proxy.server.testing.RequestRewriteUtil.TEST_REWRITE_PREFIX_QUERY_PARAM_BUCKET; +import static io.trino.aws.proxy.server.testing.RequestRewriteUtil.TEST_REWRITTEN_PREFIX_QUERY_PARAM; +import static io.trino.aws.proxy.server.testing.RequestRewriteUtil.TEST_REWRITTEN_TAGS; import static io.trino.aws.proxy.server.testing.TestingUtil.TEST_FILE; import static io.trino.aws.proxy.server.testing.TestingUtil.assertFileNotInS3; import static io.trino.aws.proxy.server.testing.TestingUtil.clientBuilder; import static io.trino.aws.proxy.server.testing.TestingUtil.getFileFromStorage; +import static io.trino.aws.proxy.server.testing.TestingUtil.getObjectTagging; import static org.assertj.core.api.Assertions.assertThat; @TrinoAwsProxyTest(filters = {WithConfiguredBuckets.class, RequestRewriteUtil.Filter.class}) @@ -74,8 +80,26 @@ public void testRewriteBasedOnIdentity() assertThat(uploadResponse.sdkHttpResponse().statusCode()).isEqualTo(200); assertThat(getFileFromStorage(testS3Client, testBucket, testKey)).isEqualTo(Files.readString(TEST_FILE)); + assertThat(getObjectTagging(testS3Client, testBucket, testKey)).isEqualTo(TEST_REWRITTEN_TAGS); } assertThat(getFileFromStorage(remoteClient, TEST_CREDENTIAL_REDIRECT_BUCKET, TEST_CREDENTIAL_REDIRECT_KEY)).isEqualTo(Files.readString(TEST_FILE)); + assertThat(getObjectTagging(remoteClient, TEST_CREDENTIAL_REDIRECT_BUCKET, TEST_CREDENTIAL_REDIRECT_KEY)).isEqualTo(TEST_REWRITTEN_TAGS); assertFileNotInS3(remoteClient, testBucket, testKey); } + + @Test + public void testRewriteRequestParamsBasedOnIdentity() + { + try (S3Client testS3Client = clientBuilder(baseUri, Optional.of(relativePath)) + .credentialsProvider(() -> AwsBasicCredentials.create(CREDENTIAL_TO_REDIRECT.accessKey(), CREDENTIAL_TO_REDIRECT.secretKey())) + .build()) { + ListObjectsV2Request listObjectsV2Request = ListObjectsV2Request.builder() + .bucket(TEST_REWRITE_PREFIX_QUERY_PARAM_BUCKET) + .prefix("original-prefix/") + .build(); + ListObjectsV2Response listObjectsV2Response = testS3Client.listObjectsV2(listObjectsV2Request); + assertThat(listObjectsV2Response.sdkHttpResponse().statusCode()).isEqualTo(200); + assertThat(listObjectsV2Response.prefix()).isEqualTo(TEST_REWRITTEN_PREFIX_QUERY_PARAM); + } + } } diff --git a/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/testing/RequestRewriteUtil.java b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/testing/RequestRewriteUtil.java index 0376c0a5..3e23f9ee 100644 --- a/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/testing/RequestRewriteUtil.java +++ b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/testing/RequestRewriteUtil.java @@ -13,18 +13,27 @@ */ package io.trino.aws.proxy.server.testing; +import com.google.common.collect.ImmutableList; import com.google.inject.Inject; import com.google.inject.Scopes; import io.trino.aws.proxy.server.testing.containers.S3Container.ForS3Container; import io.trino.aws.proxy.server.testing.harness.BuilderFilter; import io.trino.aws.proxy.spi.credentials.Credential; import io.trino.aws.proxy.spi.credentials.IdentityCredential; +import io.trino.aws.proxy.spi.rest.RequestHeaders; +import io.trino.aws.proxy.spi.util.ImmutableMultiMap; +import io.trino.aws.proxy.spi.util.MultiMap; import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.s3.model.Tag; +import software.amazon.awssdk.utils.http.SdkHttpUtils; import java.util.List; +import java.util.Map; import java.util.Optional; import java.util.UUID; import java.util.concurrent.atomic.AtomicInteger; +import java.util.stream.Collectors; +import java.util.stream.Stream; import static com.google.inject.multibindings.OptionalBinder.newOptionalBinder; import static io.trino.aws.proxy.spi.plugin.TrinoAwsProxyServerBinding.s3RequestRewriterModule; @@ -34,6 +43,12 @@ public final class RequestRewriteUtil public static final String TEST_CREDENTIAL_REDIRECT_BUCKET = "redirected-bucket-for-credential"; public static final String TEST_CREDENTIAL_REDIRECT_KEY = "redirected-key-for-credential"; public static final Credential CREDENTIAL_TO_REDIRECT = new Credential("credential-to-redirect", UUID.randomUUID().toString()); + public static final String TAGGING_HEADER = "x-amz-tagging"; + public static final List TEST_REWRITTEN_TAGS = ImmutableList.of( + Tag.builder().key("rewritten-tag-1").value("rewritten-value-1").build(), + Tag.builder().key("rewritten-tag-2").value("rewritten-value-2").build()); + public static final String TEST_REWRITE_PREFIX_QUERY_PARAM_BUCKET = "rewrite-prefix-query-param"; + public static final String TEST_REWRITTEN_PREFIX_QUERY_PARAM = "rewritten-prefix/"; private RequestRewriteUtil() {} @@ -76,15 +91,47 @@ public int getCallCount() return callCount.get(); } + /** + * If the credential is CREDENTIAL_TO_REDIRECT: + *
    + *
  • Redirect to bucket TEST_CREDENTIAL_REDIRECT_BUCKET
  • + *
  • Redirect key (if any) to TEST_CREDENTIAL_REDIRECT_KEY
  • + *
  • Set "x-amz-tagging" header such that uploaded files (if applicable) have tags TEST_REWRITTEN_TAGS
  • + *
  • Only if the bucket in the incoming request is TEST_REWRITE_PREFIX_QUERY_PARAM_BUCKET, set "prefix" query parameter to value TEST_REWRITTEN_PREFIX_QUERY_PARAM
  • + *
  • Note that this means we should only run tests checking List operations (e.g., ListObjectsV2) against this bucket, as the query param doesn't make sense for other operations
  • + *
+ * Otherwise: + *
    + *
  • Redirect all buckets to prepend "redirected-" in front
  • + *
  • Redirect all non-empty keys to prepend "redirected-" in front
  • + *
  • Empty keys are not changed
  • + *
+ */ @Override - public Optional testRewrite(String accessKey, String bucketName, String keyName) + public Optional testRewrite(String accessKey, String bucketName, String keyName, Optional requestHeaders) { callCount.incrementAndGet(); boolean redirectForTestCredential = accessKey.equalsIgnoreCase(CREDENTIAL_TO_REDIRECT.accessKey()); if (redirectForTestCredential) { - return Optional.of(new S3RewriteResult(TEST_CREDENTIAL_REDIRECT_BUCKET, keyName.isEmpty() ? "" : TEST_CREDENTIAL_REDIRECT_KEY)); + Optional rewrittenRequestHeaders = requestHeaders + .flatMap(headers -> Optional.of(headers.withPassthroughHeaders( + ImmutableMultiMap + .copyOfCaseInsensitive(headers.passthroughHeaders()) + .toBuilder() + .putOrReplaceSingle(TAGGING_HEADER, getTaggingStringFromTags(TEST_REWRITTEN_TAGS)) + .build()))); + + Optional rewrittenQueryParams = Optional.empty(); + if (bucketName.equals(TEST_REWRITE_PREFIX_QUERY_PARAM_BUCKET)) { + rewrittenQueryParams = Optional.of(ImmutableMultiMap + .builder(false) + .putOrReplaceSingle("prefix", TEST_REWRITTEN_PREFIX_QUERY_PARAM) + .build()); + } + + return Optional.of(new S3RewriteResult(TEST_CREDENTIAL_REDIRECT_BUCKET, keyName.isEmpty() ? "" : TEST_CREDENTIAL_REDIRECT_KEY, rewrittenRequestHeaders, rewrittenQueryParams)); } - return Optional.of(new S3RewriteResult(getTargetName(bucketName), getTargetName(keyName))); + return Optional.of(new S3RewriteResult(getTargetName(bucketName), getTargetName(keyName), requestHeaders, Optional.empty())); } } @@ -92,4 +139,16 @@ private static String getTargetName(String name) { return name.isEmpty() ? "" : "redirected-%s".formatted(name); } + + public static String getTaggingStringFromTags(List tags) + { + Map> queryParams = tags + .stream() + .collect(Collectors.toMap( + Tag::key, + tag -> ImmutableList.of(tag.value()), + (list1, list2) -> Stream.concat(list1.stream(), list2.stream()).collect(Collectors.toList()))); + + return SdkHttpUtils.encodeAndFlattenQueryParameters(queryParams).orElse(""); + } } diff --git a/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/testing/TestingS3RequestRewriteController.java b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/testing/TestingS3RequestRewriteController.java index 5b22a7c1..1a8fe256 100644 --- a/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/testing/TestingS3RequestRewriteController.java +++ b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/testing/TestingS3RequestRewriteController.java @@ -18,6 +18,8 @@ import io.trino.aws.proxy.spi.credentials.Credential; import io.trino.aws.proxy.spi.rest.S3RequestRewriter.S3RewriteResult; +import java.util.Optional; + import static java.util.Objects.requireNonNull; public class TestingS3RequestRewriteController @@ -34,7 +36,7 @@ public TestingS3RequestRewriteController(TestingS3RequestRewriter rewriter, @For private S3RewriteResult rewriteOrNoop(String accessKey, String bucket, String key) { - return s3RequestRewriter.testRewrite(accessKey, bucket, key).orElseGet(() -> new S3RewriteResult(bucket, key)); + return s3RequestRewriter.testRewrite(accessKey, bucket, key, Optional.empty()).orElseGet(() -> new S3RewriteResult(bucket, key)); } public String getTargetBucket(String accessKey, String bucket, String key) diff --git a/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/testing/TestingS3RequestRewriter.java b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/testing/TestingS3RequestRewriter.java index 74e365b9..c3bec733 100644 --- a/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/testing/TestingS3RequestRewriter.java +++ b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/testing/TestingS3RequestRewriter.java @@ -15,6 +15,7 @@ import io.trino.aws.proxy.spi.credentials.Identity; import io.trino.aws.proxy.spi.rest.ParsedS3Request; +import io.trino.aws.proxy.spi.rest.RequestHeaders; import io.trino.aws.proxy.spi.rest.S3RequestRewriter; import io.trino.aws.proxy.spi.signing.SigningMetadata; @@ -24,13 +25,13 @@ public interface TestingS3RequestRewriter extends S3RequestRewriter { - TestingS3RequestRewriter NOOP = (_, _, _) -> Optional.empty(); + TestingS3RequestRewriter NOOP = (_, _, _, _) -> Optional.empty(); - Optional testRewrite(String accessKey, String bucketName, String keyName); + Optional testRewrite(String accessKey, String bucketName, String keyName, Optional requestHeaders); @Override default Optional rewrite(Optional identity, SigningMetadata signingMetadata, ParsedS3Request request) { - return testRewrite(signingMetadata.credential().accessKey(), request.bucketName(), request.keyInBucket()); + return testRewrite(signingMetadata.credential().accessKey(), request.bucketName(), request.keyInBucket(), Optional.of(request.requestHeaders())); } } diff --git a/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/testing/TestingUtil.java b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/testing/TestingUtil.java index aa99f69d..c770d17e 100644 --- a/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/testing/TestingUtil.java +++ b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/testing/TestingUtil.java @@ -30,11 +30,13 @@ import software.amazon.awssdk.services.s3.model.Delete; import software.amazon.awssdk.services.s3.model.DeleteBucketRequest; import software.amazon.awssdk.services.s3.model.GetObjectRequest; +import software.amazon.awssdk.services.s3.model.GetObjectTaggingRequest; import software.amazon.awssdk.services.s3.model.HeadObjectRequest; import software.amazon.awssdk.services.s3.model.HeadObjectResponse; import software.amazon.awssdk.services.s3.model.ObjectIdentifier; import software.amazon.awssdk.services.s3.model.S3Exception; import software.amazon.awssdk.services.s3.model.S3Object; +import software.amazon.awssdk.services.s3.model.Tag; import java.io.ByteArrayOutputStream; import java.io.File; @@ -112,6 +114,13 @@ public static String getFileFromStorage(S3Client storageClient, String bucketNam return readContents.toString(); } + public static List getObjectTagging(S3Client storageClient, String bucketName, String key) + { + GetObjectTaggingRequest getObjectTaggingRequest = GetObjectTaggingRequest + .builder().bucket(bucketName).key(key).build(); + return storageClient.getObjectTagging(getObjectTaggingRequest).tagSet(); + } + public static HeadObjectResponse headObjectInStorage(S3Client storageClient, String bucketName, String key) { return storageClient.headObject(HeadObjectRequest.builder().bucket(bucketName).key(key).build());