diff --git a/trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/rest/TrinoS3ProxyClient.java b/trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/rest/TrinoS3ProxyClient.java index d74dde6a..1dc32141 100644 --- a/trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/rest/TrinoS3ProxyClient.java +++ b/trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/rest/TrinoS3ProxyClient.java @@ -112,28 +112,44 @@ public void shutDown() } } - public void proxyRequest(Optional identity, SigningMetadata signingMetadata, ParsedS3Request request, AsyncResponse asyncResponse, + public void proxyRequest(Optional identity, SigningMetadata signingMetadata, ParsedS3Request originalRequest, AsyncResponse asyncResponse, RequestLoggingSession requestLoggingSession) { - SecurityResponse securityResponse = s3SecurityController.apply(request, identity); + SecurityResponse securityResponse = s3SecurityController.apply(originalRequest, identity); if (securityResponse instanceof Failure(var error)) { - log.debug("SecurityController check failed. AccessKey: %s, Request: %s, SecurityResponse: %s", signingMetadata.credential().accessKey(), request, securityResponse); + log.debug("SecurityController check failed. AccessKey: %s, Request: %s, SecurityResponse: %s", signingMetadata.credential().accessKey(), originalRequest, securityResponse); requestLoggingSession.logError("request.security.fail.credentials", signingMetadata.credential()); - requestLoggingSession.logError("request.security.fail.request", request); + requestLoggingSession.logError("request.security.fail.request", originalRequest); requestLoggingSession.logError("request.security.fail.error", error); throw new WebApplicationException(Response.Status.UNAUTHORIZED); } - Optional rewriteResult = s3RequestRewriter.rewrite(identity, signingMetadata, request); - String targetBucket = rewriteResult.map(S3RewriteResult::finalRequestBucket).orElse(request.bucketName()); - String targetKey = rewriteResult + Optional rewriteResult = s3RequestRewriter.rewrite(identity, signingMetadata, originalRequest); + String bucket = rewriteResult.map(S3RewriteResult::finalRequestBucket).orElse(originalRequest.bucketName()); + String key = rewriteResult + .map(S3RewriteResult::finalRequestKey) + .orElse(originalRequest.keyInBucket()); + String path = rewriteResult .map(S3RewriteResult::finalRequestKey) .map(SdkHttpUtils::urlEncodeIgnoreSlashes) - .orElse(request.rawPath()); + .orElse(originalRequest.rawPath()); + + ParsedS3Request request = new ParsedS3Request( + originalRequest.requestId(), + originalRequest.requestAuthorization(), + originalRequest.requestDate(), + bucket, + key, + originalRequest.requestHeaders(), + originalRequest.queryParameters(), + originalRequest.httpVerb(), + path, + originalRequest.rawQuery(), + originalRequest.requestContent()); RemoteRequestWithPresignedURIs remoteRequest = remoteS3ConnectionController.withRemoteConnection(signingMetadata, identity, request, (remoteCredential, remoteS3Facade) -> { - URI remoteUri = remoteS3Facade.buildEndpoint(uriBuilder(request.queryParameters()), targetKey, targetBucket, request.requestAuthorization().region()); + URI remoteUri = remoteS3Facade.buildEndpoint(uriBuilder(request.queryParameters()), request.rawPath(), request.bucketName(), request.requestAuthorization().region()); Request.Builder remoteRequestBuilder = new Request.Builder() .setMethod(request.httpVerb()) @@ -161,7 +177,8 @@ public void proxyRequest(Optional identity, SigningMetadata signingMet Map presignedUrls; if (generatePresignedUrlsOnHead && request.httpVerb().equalsIgnoreCase("HEAD")) { - presignedUrls = s3PresignController.buildPresignedRemoteUrls(identity, remoteSigningMetadata, request, targetRequestTimestamp, remoteUri); + // Presigned URLs are generated for the ORIGINAL key and bucket, not the rewritten ones + presignedUrls = s3PresignController.buildPresignedRemoteUrls(identity, remoteSigningMetadata, originalRequest, targetRequestTimestamp, remoteUri); } else { presignedUrls = ImmutableMap.of(); diff --git a/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/remote/provider/TestRemoteS3ConnectionProviderWithRewriter.java b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/remote/provider/TestRemoteS3ConnectionProviderWithRewriter.java new file mode 100644 index 00000000..c588ca68 --- /dev/null +++ b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/remote/provider/TestRemoteS3ConnectionProviderWithRewriter.java @@ -0,0 +1,143 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.aws.proxy.server.remote.provider; + +import com.google.common.collect.ImmutableList; +import com.google.inject.Inject; +import io.trino.aws.proxy.server.testing.RequestRewriteUtil; +import io.trino.aws.proxy.server.testing.TestingTrinoAwsProxyServer; +import io.trino.aws.proxy.server.testing.containers.S3Container.ForS3Container; +import io.trino.aws.proxy.server.testing.harness.BuilderFilter; +import io.trino.aws.proxy.server.testing.harness.TrinoAwsProxyTest; +import io.trino.aws.proxy.server.testing.harness.TrinoAwsProxyTestCommonModules.WithConfiguredBuckets; +import io.trino.aws.proxy.spi.credentials.Identity; +import io.trino.aws.proxy.spi.remote.RemoteS3Connection; +import io.trino.aws.proxy.spi.remote.RemoteS3Connection.StaticRemoteS3Connection; +import io.trino.aws.proxy.spi.remote.RemoteS3ConnectionProvider; +import io.trino.aws.proxy.spi.rest.ParsedS3Request; +import io.trino.aws.proxy.spi.signing.SigningMetadata; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Test; +import software.amazon.awssdk.core.ResponseInputStream; +import software.amazon.awssdk.core.sync.RequestBody; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.s3.model.GetObjectRequest; +import software.amazon.awssdk.services.s3.model.GetObjectResponse; +import software.amazon.awssdk.services.s3.model.PutObjectRequest; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; + +import static com.google.inject.Scopes.SINGLETON; +import static io.trino.aws.proxy.server.testing.TestingUtil.LOREM_IPSUM; +import static io.trino.aws.proxy.server.testing.TestingUtil.TESTING_REMOTE_CREDENTIAL; +import static io.trino.aws.proxy.spi.plugin.TrinoAwsProxyServerBinding.remoteS3ConnectionProviderModule; +import static java.util.Objects.requireNonNull; +import static org.assertj.core.api.Assertions.assertThat; + +@TrinoAwsProxyTest(filters = {WithConfiguredBuckets.class, RequestRewriteUtil.Filter.class, TestRemoteS3ConnectionProviderWithRewriter.Filter.class}) +public class TestRemoteS3ConnectionProviderWithRewriter +{ + public static class Filter + implements BuilderFilter + { + @Override + public TestingTrinoAwsProxyServer.Builder filter(TestingTrinoAwsProxyServer.Builder builder) + { + return builder + .withoutTestingRemoteS3ConnectionProvider() + .addModule(remoteS3ConnectionProviderModule("rewrite-test", DelegateRemoteS3ConnectionProvider.class, + binder -> binder.bind(DelegateRemoteS3ConnectionProvider.class).in(SINGLETON))) + .withProperty("remote-s3-connection-provider.type", "rewrite-test"); + } + } + + public static class DelegateRemoteS3ConnectionProvider + implements RemoteS3ConnectionProvider + { + private RemoteS3ConnectionProvider delegate; + + private final List callArgs = new ArrayList<>(); + + @Override + public Optional remoteConnection(SigningMetadata signingMetadata, Optional identity, ParsedS3Request request) + { + callArgs.add(new RemoteS3ConnectionProviderArgs(signingMetadata, identity, request)); + return delegate.remoteConnection(signingMetadata, identity, request); + } + + public void setDelegate(RemoteS3ConnectionProvider delegate) + { + this.delegate = requireNonNull(delegate, "delegate is null"); + } + + public List getCallArgs() + { + return callArgs; + } + + public void reset() + { + callArgs.clear(); + delegate = null; + } + } + + public record RemoteS3ConnectionProviderArgs(SigningMetadata signingMetadata, Optional identity, ParsedS3Request request) {} + + private final S3Client s3Client; + private final S3Client storageClient; + private final DelegateRemoteS3ConnectionProvider delegateRemoteS3ConnectionProvider; + private final List buckets; + + @Inject + public TestRemoteS3ConnectionProviderWithRewriter( + S3Client s3Client, + @ForS3Container S3Client storageClient, + DelegateRemoteS3ConnectionProvider delegateRemoteS3ConnectionProvider, + @ForS3Container List buckets) + { + this.s3Client = requireNonNull(s3Client, "s3Client is null"); + this.storageClient = requireNonNull(storageClient, "storageClient is null"); + this.delegateRemoteS3ConnectionProvider = requireNonNull(delegateRemoteS3ConnectionProvider, "delegateRemoteS3ConnectionProvider is null"); + this.buckets = ImmutableList.copyOf(buckets); + } + + @AfterEach + public void cleanup() + { + delegateRemoteS3ConnectionProvider.reset(); + } + + @Test + public void testRemoteS3ConnectionRetrievedWithRewrittenRequest() + throws IOException + { + String bucket = buckets.getFirst(); + + storageClient.putObject(PutObjectRequest.builder().bucket("redirected-" + bucket).key("redirected-test_key_1337").build(), RequestBody.fromString(LOREM_IPSUM)); + + delegateRemoteS3ConnectionProvider.setDelegate((_, _, _) -> Optional.of(new StaticRemoteS3Connection(TESTING_REMOTE_CREDENTIAL))); + + ResponseInputStream resp = s3Client.getObject(GetObjectRequest.builder().bucket(bucket).key("test_key_1337").build()); + assertThat(resp.readAllBytes()).asString().isEqualTo(LOREM_IPSUM); + + assertThat(delegateRemoteS3ConnectionProvider.getCallArgs()).hasSize(1).first().satisfies(args -> { + assertThat(args.request().bucketName()).isEqualTo("redirected-" + bucket); + assertThat(args.request().keyInBucket()).isEqualTo("redirected-test_key_1337"); + }); + } +} diff --git a/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/testing/TestingCredentialsRolesProvider.java b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/testing/TestingCredentialsRolesProvider.java index 960ec7ad..57582380 100644 --- a/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/testing/TestingCredentialsRolesProvider.java +++ b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/testing/TestingCredentialsRolesProvider.java @@ -13,6 +13,7 @@ */ package io.trino.aws.proxy.server.testing; +import com.google.inject.Inject; import io.trino.aws.proxy.spi.credentials.AssumedRoleProvider; import io.trino.aws.proxy.spi.credentials.Credential; import io.trino.aws.proxy.spi.credentials.CredentialsProvider; @@ -20,6 +21,7 @@ import io.trino.aws.proxy.spi.credentials.Identity; import io.trino.aws.proxy.spi.credentials.IdentityCredential; import io.trino.aws.proxy.spi.remote.RemoteS3Connection; +import io.trino.aws.proxy.spi.remote.RemoteS3Connection.StaticRemoteS3Connection; import io.trino.aws.proxy.spi.remote.RemoteS3ConnectionProvider; import io.trino.aws.proxy.spi.rest.ParsedS3Request; import io.trino.aws.proxy.spi.signing.SigningMetadata; @@ -33,6 +35,8 @@ import java.util.concurrent.atomic.AtomicInteger; import static com.google.common.base.Preconditions.checkState; +import static io.trino.aws.proxy.server.testing.TestingUtil.TESTING_IDENTITY_CREDENTIAL; +import static io.trino.aws.proxy.server.testing.TestingUtil.TESTING_REMOTE_CREDENTIAL; import static java.util.Objects.requireNonNull; /** @@ -60,6 +64,13 @@ private record Session(Credential sessionCredential, String originalEmulatedAcce } } + @Inject + public TestingCredentialsRolesProvider() + { + addCredentials(TESTING_IDENTITY_CREDENTIAL); + setDefaultRemoteConnection(new StaticRemoteS3Connection(TESTING_REMOTE_CREDENTIAL)); + } + @Override public Optional remoteConnection(SigningMetadata signingMetadata, Optional identity, ParsedS3Request request) { diff --git a/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/testing/TestingTrinoAwsProxyServer.java b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/testing/TestingTrinoAwsProxyServer.java index 7904f5b9..658a4bb1 100644 --- a/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/testing/TestingTrinoAwsProxyServer.java +++ b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/testing/TestingTrinoAwsProxyServer.java @@ -16,7 +16,6 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; -import com.google.inject.Inject; import com.google.inject.Injector; import com.google.inject.Key; import com.google.inject.Module; @@ -41,7 +40,6 @@ import io.trino.aws.proxy.server.testing.containers.S3Container.ForS3Container; import io.trino.aws.proxy.spi.credentials.Credential; import io.trino.aws.proxy.spi.credentials.IdentityCredential; -import io.trino.aws.proxy.spi.remote.RemoteS3Connection.StaticRemoteS3Connection; import java.io.Closeable; import java.util.Collection; @@ -98,6 +96,7 @@ public static class Builder private boolean v4PySparkContainerAdded; private boolean opaContainerAdded; private boolean addTestingCredentialsRoleProviders = true; + private boolean addTestingRemoteS3CredentialsProvider = true; public Builder addModule(Module module) { @@ -200,6 +199,12 @@ public Builder withoutTestingCredentialsRoleProviders() return this; } + public Builder withoutTestingRemoteS3ConnectionProvider() + { + addTestingRemoteS3CredentialsProvider = false; + return this; + } + public Builder withOpaContainer() { if (opaContainerAdded) { @@ -214,16 +219,15 @@ public Builder withOpaContainer() public TestingTrinoAwsProxyServer buildAndStart() { if (addTestingCredentialsRoleProviders) { - if (mockS3ContainerAdded) { - modules.add(binder -> binder.bind(TestingCredentialsInitializer.class).asEagerSingleton()); - } - addModule(credentialsProviderModule("testing", TestingCredentialsRolesProvider.class, (binder) -> binder.bind(TestingCredentialsRolesProvider.class).in(Scopes.SINGLETON))); withProperty("credentials-provider.type", "testing"); addModule(assumedRoleProviderModule("testing", TestingCredentialsRolesProvider.class, (binder) -> binder.bind(TestingCredentialsRolesProvider.class).in(Scopes.SINGLETON))); withProperty("assumed-role-provider.type", "testing"); + } + + if (addTestingRemoteS3CredentialsProvider) { addModule(remoteS3ConnectionProviderModule("testing", TestingCredentialsRolesProvider.class, - binder -> binder.bind(TestingCredentialsInitializer.class).in(Scopes.SINGLETON))); + binder -> binder.bind(TestingCredentialsRolesProvider.class).in(Scopes.SINGLETON))); withProperty("remote-s3-connection-provider.type", "testing"); } @@ -231,16 +235,6 @@ public TestingTrinoAwsProxyServer buildAndStart() } } - static class TestingCredentialsInitializer - { - @Inject - TestingCredentialsInitializer(TestingCredentialsRolesProvider credentialsController) - { - credentialsController.addCredentials(TESTING_IDENTITY_CREDENTIAL); - credentialsController.setDefaultRemoteConnection(new StaticRemoteS3Connection(TESTING_REMOTE_CREDENTIAL)); - } - } - private static TestingTrinoAwsProxyServer start(Collection extraModules, Map properties) { ImmutableList.Builder modules = ImmutableList.builder()