Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,9 @@ public record RequestHeaders(
passthroughHeaders = ImmutableMultiMap.copyOfCaseInsensitive(passthroughHeaders);
unmodifiedHeaders = ImmutableMultiMap.copyOfCaseInsensitive(unmodifiedHeaders);
}

public RequestHeaders withPassthroughHeaders(MultiMap newPassthroughHeaders)
{
return new RequestHeaders(newPassthroughHeaders, unmodifiedHeaders);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@

import io.trino.aws.proxy.spi.credentials.Identity;
import io.trino.aws.proxy.spi.signing.SigningMetadata;
import io.trino.aws.proxy.spi.util.ImmutableMultiMap;
import io.trino.aws.proxy.spi.util.MultiMap;

import java.util.Optional;

Expand All @@ -24,12 +26,24 @@ public interface S3RequestRewriter
{
S3RequestRewriter NOOP = (_, _, _) -> Optional.empty();

record S3RewriteResult(String finalRequestBucket, String finalRequestKey)
record S3RewriteResult(
String finalRequestBucket,
String finalRequestKey,
Optional<RequestHeaders> finalRequestHeaders,
Optional<MultiMap> finalQueryParameters)
{
public S3RewriteResult(String finalRequestBucket, String finalRequestKey)
{
this(finalRequestBucket, finalRequestKey, Optional.empty(), Optional.empty());
}

public S3RewriteResult
{
requireNonNull(finalRequestBucket, "finalRequestBucket is null");
requireNonNull(finalRequestKey, "finalRequestKey is null");
requireNonNull(finalRequestHeaders, "finalRequestHeaders is null");
requireNonNull(finalQueryParameters, "finalQueryParameters is null");
finalQueryParameters = finalQueryParameters.map(ImmutableMultiMap::copyOf);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit - we don't need the requireNonNull if we are invoking a method on the object already

}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,13 @@ private static ImmutableMultiMap copyOf(Set<? extends Map.Entry<String, ? extend
return builder.build();
}

public Builder toBuilder()
{
Builder builder = builder(isCaseSensitiveKeys());
forEach(builder::addAll);
return builder;
}

public static class Builder
{
private final Multimap<String, String> data;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import io.trino.aws.proxy.spi.credentials.Identity;
import io.trino.aws.proxy.spi.rest.ParsedS3Request;
import io.trino.aws.proxy.spi.rest.RequestContent;
import io.trino.aws.proxy.spi.rest.RequestHeaders;
import io.trino.aws.proxy.spi.rest.S3RequestRewriter;
import io.trino.aws.proxy.spi.rest.S3RequestRewriter.S3RewriteResult;
import io.trino.aws.proxy.spi.security.SecurityResponse;
Expand Down Expand Up @@ -131,9 +132,15 @@ public void proxyRequest(Optional<Identity> identity, SigningMetadata signingMet
.map(S3RewriteResult::finalRequestKey)
.map(SdkHttpUtils::urlEncodeIgnoreSlashes)
.orElse(request.rawPath());
RequestHeaders rewrittenRequestHeaders = rewriteResult
.flatMap(S3RewriteResult::finalRequestHeaders)
.orElse(request.requestHeaders());
MultiMap rewrittenQueryParameters = rewriteResult
.flatMap(S3RewriteResult::finalQueryParameters)
.orElse(request.queryParameters());

RemoteRequestWithPresignedURIs remoteRequest = remoteS3ConnectionController.withRemoteConnection(signingMetadata, identity, request, (remoteCredential, remoteS3Facade) -> {
URI remoteUri = remoteS3Facade.buildEndpoint(uriBuilder(request.queryParameters()), targetKey, targetBucket, request.requestAuthorization().region());
URI remoteUri = remoteS3Facade.buildEndpoint(uriBuilder(rewrittenQueryParameters), targetKey, targetBucket, request.requestAuthorization().region());

Request.Builder remoteRequestBuilder = new Request.Builder()
.setMethod(request.httpVerb())
Expand All @@ -147,7 +154,7 @@ public void proxyRequest(Optional<Identity> identity, SigningMetadata signingMet

ImmutableMultiMap.Builder remoteRequestHeadersBuilder = ImmutableMultiMap.builder(false);
Instant targetRequestTimestamp = Instant.now();
request.requestHeaders().passthroughHeaders().forEach(remoteRequestHeadersBuilder::addAll);
rewrittenRequestHeaders.passthroughHeaders().forEach(remoteRequestHeadersBuilder::addAll);
remoteRequestHeadersBuilder.putOrReplaceSingle("Host", buildRemoteHost(remoteUri));

// Use now for the remote request
Expand Down Expand Up @@ -182,7 +189,7 @@ public void proxyRequest(Optional<Identity> identity, SigningMetadata signingMet
Optional.empty(),
remoteUri,
remoteRequestHeaders,
request.queryParameters(),
rewrittenQueryParameters,
request.httpVerb()).signingAuthorization().authorization();

// remoteRequestHeaders now has correct values, copy to the remote request
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
import software.amazon.awssdk.services.s3.model.HeadObjectResponse;
import software.amazon.awssdk.services.s3.model.PutObjectRequest;
import software.amazon.awssdk.services.s3.model.PutObjectResponse;
import software.amazon.awssdk.services.s3.model.Tag;
import software.amazon.awssdk.services.s3.model.UploadPartRequest;
import software.amazon.awssdk.services.s3.presigner.S3Presigner;
import software.amazon.awssdk.services.s3.presigner.model.CompleteMultipartUploadPresignRequest;
Expand All @@ -65,6 +66,7 @@
import java.nio.file.Files;
import java.nio.file.Path;
import java.time.Duration;
import java.util.List;
import java.util.Optional;

import static io.airlift.http.client.Request.Builder.prepareDelete;
Expand Down Expand Up @@ -302,12 +304,56 @@ void uploadFileToStorage(String bucketName, String key, Path filePath)
String getFileFromStorage(String bucketName, String key)
throws IOException
{
String dataFromProxy = TestingUtil.getFileFromStorage(internalClient, bucketName, key);
String dataFromStorage = TestingUtil.getFileFromStorage(storageClient, requestRewriteController.getTargetBucket(bucketName, key), requestRewriteController.getTargetKey(bucketName, key));
return getFileFromStorage(internalClient, bucketName, key, Optional.empty());
}

String getFileFromStorage(S3Client s3Client, String bucketName, String key, Optional<Credential> credential)
throws IOException
{
String dataFromProxy = TestingUtil.getFileFromStorage(s3Client, bucketName, key);
String dataFromStorage = credential
.map(actualCredential -> {
try {
return TestingUtil.getFileFromStorage(
storageClient,
requestRewriteController.getTargetBucket(actualCredential.accessKey(), bucketName, key),
requestRewriteController.getTargetKey(actualCredential.accessKey(), bucketName, key));
}
catch (IOException e) {
throw new RuntimeException(e);
}
})
.orElseGet(() -> {
try {
return TestingUtil.getFileFromStorage(
storageClient,
requestRewriteController.getTargetBucket(bucketName, key),
requestRewriteController.getTargetKey(bucketName, key));
}
catch (IOException e) {
throw new RuntimeException(e);
}
});
assertThat(dataFromProxy).isEqualTo(dataFromStorage);
return dataFromStorage;
}

List<Tag> getObjectTagging(S3Client s3Client, String bucketName, String key, Optional<Credential> credential)
{
List<Tag> tagsFromProxy = TestingUtil.getObjectTagging(s3Client, bucketName, key);
List<Tag> tagsFromStorage = credential
.map(actualCredential -> TestingUtil.getObjectTagging(
storageClient,
requestRewriteController.getTargetBucket(actualCredential.accessKey(), bucketName, key),
requestRewriteController.getTargetKey(actualCredential.accessKey(), bucketName, key)))
.orElseGet(() -> TestingUtil.getObjectTagging(
storageClient,
requestRewriteController.getTargetBucket(bucketName, key),
requestRewriteController.getTargetKey(bucketName, key)));
assertThat(tagsFromProxy).isEqualTo(tagsFromStorage);
return tagsFromStorage;
}

<T> T executeHttpRequest(SdkHttpRequest sdkRequest, ResponseHandler<T, RuntimeException> responseHandler)
{
return executeHttpRequest(sdkRequest, Optional.empty(), responseHandler);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import com.fasterxml.jackson.dataformat.xml.XmlMapper;
import com.google.inject.Inject;
import io.airlift.http.client.HttpClient;
import io.airlift.http.client.StatusResponseHandler;
import io.airlift.http.client.StringResponseHandler;
import io.airlift.http.server.testing.TestingHttpServer;
import io.trino.aws.proxy.server.testing.RequestRewriteUtil;
Expand All @@ -27,27 +28,40 @@
import io.trino.aws.proxy.server.testing.harness.TrinoAwsProxyTestCommonModules.WithTestingHttpClient;
import io.trino.aws.proxy.spi.credentials.IdentityCredential;
import org.junit.jupiter.api.Test;
import software.amazon.awssdk.auth.credentials.AwsBasicCredentials;
import software.amazon.awssdk.services.s3.S3Client;
import software.amazon.awssdk.services.s3.model.GetObjectRequest;
import software.amazon.awssdk.services.s3.model.PutObjectRequest;
import software.amazon.awssdk.services.s3.presigner.S3Presigner;
import software.amazon.awssdk.services.s3.presigner.model.GetObjectPresignRequest;
import software.amazon.awssdk.services.s3.presigner.model.PresignedGetObjectRequest;
import software.amazon.awssdk.services.s3.presigner.model.PresignedPutObjectRequest;
import software.amazon.awssdk.services.s3.presigner.model.PutObjectPresignRequest;

import java.io.IOException;
import java.net.URI;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.time.Duration;
import java.util.Optional;

import static io.airlift.http.client.StatusResponseHandler.createStatusResponseHandler;
import static io.airlift.http.client.StringResponseHandler.createStringResponseHandler;
import static io.trino.aws.proxy.server.testing.RequestRewriteUtil.CREDENTIAL_TO_REDIRECT;
import static io.trino.aws.proxy.server.testing.RequestRewriteUtil.TEST_CREDENTIAL_REDIRECT_BUCKET;
import static io.trino.aws.proxy.server.testing.RequestRewriteUtil.TEST_CREDENTIAL_REDIRECT_KEY;
import static io.trino.aws.proxy.server.testing.RequestRewriteUtil.TEST_REWRITTEN_TAGS;
import static io.trino.aws.proxy.server.testing.TestingUtil.TEST_FILE;
import static io.trino.aws.proxy.server.testing.TestingUtil.clientBuilder;
import static org.assertj.core.api.Assertions.assertThat;

@TrinoAwsProxyTest(filters = {WithConfiguredBuckets.class, WithTestingHttpClient.class, RequestRewriteUtil.Filter.class})
public class TestPresignedRequestsWithRewrite
extends AbstractTestPresignedRequests
{
private final URI baseUri;
private final String relativePath;

@Inject
public TestPresignedRequestsWithRewrite(
@ForTesting HttpClient httpClient,
Expand All @@ -60,6 +74,8 @@ public TestPresignedRequestsWithRewrite(
TestingS3RequestRewriteController requestRewriteController)
{
super(httpClient, internalClient, storageClient, testingCredentials, httpServer, s3ProxyConfig, xmlMapper, requestRewriteController);
this.baseUri = httpServer.getBaseUrl();
this.relativePath = s3ProxyConfig.getS3Path();
}

@Test
Expand All @@ -85,4 +101,37 @@ public void testPresignedRedirectBasedOnIdentity()
assertThat(response.getBody()).isEqualTo(Files.readString(TEST_FILE));
}
}

@Test
public void testPresignedUploadRedirectAndModifyHeadersBasedOnIdentity()
throws IOException
{
String testBucket = "two";
String testKey = "does-not-matter";
String fileContents = Files.readString(TEST_FILE, StandardCharsets.UTF_8);

try (
S3Presigner presigner = buildPresigner(CREDENTIAL_TO_REDIRECT);
S3Client testS3Client = clientBuilder(baseUri, Optional.of(relativePath))
.credentialsProvider(() -> AwsBasicCredentials.create(CREDENTIAL_TO_REDIRECT.accessKey(), CREDENTIAL_TO_REDIRECT.secretKey()))
.build()) {
PutObjectRequest request = PutObjectRequest
.builder()
.bucket(testBucket)
.key(testKey)
.contentEncoding("gzip")
.contentType("text/plain;charset=UTF-8")
.build();
PutObjectPresignRequest presignRequest = PutObjectPresignRequest.builder()
.signatureDuration(Duration.ofDays(1))
.putObjectRequest(request)
.build();
PresignedPutObjectRequest presignedRequest = presigner.presignPutObject(presignRequest);

StatusResponseHandler.StatusResponse response = executeHttpRequest(presignedRequest.httpRequest(), fileContents, createStatusResponseHandler());
assertThat(response.getStatusCode()).isEqualTo(200);
assertThat(getFileFromStorage(testS3Client, testBucket, testKey, Optional.of(CREDENTIAL_TO_REDIRECT))).isEqualTo(fileContents);
assertThat(getObjectTagging(testS3Client, testBucket, testKey, Optional.of(CREDENTIAL_TO_REDIRECT))).isEqualTo(TEST_REWRITTEN_TAGS);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
import org.junit.jupiter.api.Test;
import software.amazon.awssdk.auth.credentials.AwsBasicCredentials;
import software.amazon.awssdk.services.s3.S3Client;
import software.amazon.awssdk.services.s3.model.ListObjectsV2Request;
import software.amazon.awssdk.services.s3.model.ListObjectsV2Response;
import software.amazon.awssdk.services.s3.model.PutObjectRequest;
import software.amazon.awssdk.services.s3.model.PutObjectResponse;

Expand All @@ -34,10 +36,14 @@
import static io.trino.aws.proxy.server.testing.RequestRewriteUtil.CREDENTIAL_TO_REDIRECT;
import static io.trino.aws.proxy.server.testing.RequestRewriteUtil.TEST_CREDENTIAL_REDIRECT_BUCKET;
import static io.trino.aws.proxy.server.testing.RequestRewriteUtil.TEST_CREDENTIAL_REDIRECT_KEY;
import static io.trino.aws.proxy.server.testing.RequestRewriteUtil.TEST_REWRITE_PREFIX_QUERY_PARAM_BUCKET;
import static io.trino.aws.proxy.server.testing.RequestRewriteUtil.TEST_REWRITTEN_PREFIX_QUERY_PARAM;
import static io.trino.aws.proxy.server.testing.RequestRewriteUtil.TEST_REWRITTEN_TAGS;
import static io.trino.aws.proxy.server.testing.TestingUtil.TEST_FILE;
import static io.trino.aws.proxy.server.testing.TestingUtil.assertFileNotInS3;
import static io.trino.aws.proxy.server.testing.TestingUtil.clientBuilder;
import static io.trino.aws.proxy.server.testing.TestingUtil.getFileFromStorage;
import static io.trino.aws.proxy.server.testing.TestingUtil.getObjectTagging;
import static org.assertj.core.api.Assertions.assertThat;

@TrinoAwsProxyTest(filters = {WithConfiguredBuckets.class, RequestRewriteUtil.Filter.class})
Expand Down Expand Up @@ -74,8 +80,26 @@ public void testRewriteBasedOnIdentity()
assertThat(uploadResponse.sdkHttpResponse().statusCode()).isEqualTo(200);

assertThat(getFileFromStorage(testS3Client, testBucket, testKey)).isEqualTo(Files.readString(TEST_FILE));
assertThat(getObjectTagging(testS3Client, testBucket, testKey)).isEqualTo(TEST_REWRITTEN_TAGS);
}
assertThat(getFileFromStorage(remoteClient, TEST_CREDENTIAL_REDIRECT_BUCKET, TEST_CREDENTIAL_REDIRECT_KEY)).isEqualTo(Files.readString(TEST_FILE));
assertThat(getObjectTagging(remoteClient, TEST_CREDENTIAL_REDIRECT_BUCKET, TEST_CREDENTIAL_REDIRECT_KEY)).isEqualTo(TEST_REWRITTEN_TAGS);
assertFileNotInS3(remoteClient, testBucket, testKey);
}

@Test
public void testRewriteRequestParamsBasedOnIdentity()
{
try (S3Client testS3Client = clientBuilder(baseUri, Optional.of(relativePath))
.credentialsProvider(() -> AwsBasicCredentials.create(CREDENTIAL_TO_REDIRECT.accessKey(), CREDENTIAL_TO_REDIRECT.secretKey()))
.build()) {
ListObjectsV2Request listObjectsV2Request = ListObjectsV2Request.builder()
.bucket(TEST_REWRITE_PREFIX_QUERY_PARAM_BUCKET)
.prefix("original-prefix/")
.build();
ListObjectsV2Response listObjectsV2Response = testS3Client.listObjectsV2(listObjectsV2Request);
assertThat(listObjectsV2Response.sdkHttpResponse().statusCode()).isEqualTo(200);
assertThat(listObjectsV2Response.prefix()).isEqualTo(TEST_REWRITTEN_PREFIX_QUERY_PARAM);
}
}
}
Loading