From dcc85d8df4da9022e7f468a46522d15db09b2c2e Mon Sep 17 00:00:00 2001 From: Andrey Litvitski Date: Wed, 22 Oct 2025 18:03:01 +0300 Subject: [PATCH] Add generic request validator for refresh token Signed-off-by: Andrey Litvitski --- ...uth2RefreshTokenAuthenticationContext.java | 111 +++++++++++++++++ ...th2RefreshTokenAuthenticationProvider.java | 71 ++++------- ...h2RefreshTokenAuthenticationValidator.java | 114 ++++++++++++++++++ 3 files changed, 246 insertions(+), 50 deletions(-) create mode 100644 oauth2/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2RefreshTokenAuthenticationContext.java create mode 100644 oauth2/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2RefreshTokenAuthenticationValidator.java diff --git a/oauth2/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2RefreshTokenAuthenticationContext.java b/oauth2/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2RefreshTokenAuthenticationContext.java new file mode 100644 index 00000000000..93299fc418c --- /dev/null +++ b/oauth2/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2RefreshTokenAuthenticationContext.java @@ -0,0 +1,111 @@ +/* + * Copyright 2004-present the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.server.authorization.authentication; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.function.Consumer; + +import org.jspecify.annotations.Nullable; + +import org.springframework.security.oauth2.jwt.Jwt; +import org.springframework.security.oauth2.server.authorization.OAuth2Authorization; +import org.springframework.util.Assert; + +/** + * An {@link OAuth2AuthenticationContext} that holds an + * {@link OAuth2RefreshTokenAuthenticationToken} and additional information and is used + * when validating the OAuth 2.0 Refresh Token Grant Request. + *

+ * This context provides access to the current {@link OAuth2Authorization}, + * {@link OAuth2ClientAuthenticationToken}, and optionally a DPoP {@link Jwt} proof. + *

+ * + * @author Andrey Litvitski + * @since 7.0.0 + * @see OAuth2AuthenticationContext + * @see OAuth2RefreshTokenAuthenticationProvider#setAuthenticationValidator(Consumer) + */ +public final class OAuth2RefreshTokenAuthenticationContext implements OAuth2AuthenticationContext { + + private final Map context; + + private OAuth2RefreshTokenAuthenticationContext(Map context) { + this.context = Collections.unmodifiableMap(new HashMap<>(context)); + } + + @SuppressWarnings("unchecked") + @Nullable + @Override + public V get(Object key) { + return hasKey(key) ? (V) this.context.get(key) : null; + } + + @Override + public boolean hasKey(Object key) { + Assert.notNull(key, "key cannot be null"); + return this.context.containsKey(key); + } + + public OAuth2Authorization getAuthorization() { + return get(OAuth2Authorization.class); + } + + public OAuth2ClientAuthenticationToken getClientPrincipal() { + return get(OAuth2ClientAuthenticationToken.class); + } + + @Nullable public Jwt getDPoPProof() { + return get(Jwt.class); + } + + public static Builder with(OAuth2RefreshTokenAuthenticationToken authentication) { + return new Builder(authentication); + } + + public static final class Builder extends AbstractBuilder { + + private Builder(OAuth2RefreshTokenAuthenticationToken authentication) { + super(authentication); + } + + public Builder authorization(OAuth2Authorization authorization) { + return put(OAuth2Authorization.class, authorization); + } + + public Builder clientPrincipal(OAuth2ClientAuthenticationToken clientPrincipal) { + return put(OAuth2ClientAuthenticationToken.class, clientPrincipal); + } + + public Builder dPoPProof(@Nullable Jwt dPoPProof) { + if (dPoPProof != null) { + put(Jwt.class, dPoPProof); + } + return this; + } + + @Override + public OAuth2RefreshTokenAuthenticationContext build() { + Assert.notNull(get(OAuth2Authorization.class), "authorization cannot be null"); + Assert.notNull(get(OAuth2ClientAuthenticationToken.class), "clientPrincipal cannot be null"); + return new OAuth2RefreshTokenAuthenticationContext(getContext()); + } + + } + +} diff --git a/oauth2/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2RefreshTokenAuthenticationProvider.java b/oauth2/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2RefreshTokenAuthenticationProvider.java index c6c50a32155..1f66c60c342 100644 --- a/oauth2/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2RefreshTokenAuthenticationProvider.java +++ b/oauth2/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2RefreshTokenAuthenticationProvider.java @@ -21,8 +21,8 @@ import java.util.HashMap; import java.util.Map; import java.util.Set; +import java.util.function.Consumer; -import com.nimbusds.jose.jwk.JWK; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; @@ -31,8 +31,6 @@ import org.springframework.security.core.Authentication; import org.springframework.security.core.AuthenticationException; import org.springframework.security.oauth2.core.AuthorizationGrantType; -import org.springframework.security.oauth2.core.ClaimAccessor; -import org.springframework.security.oauth2.core.ClientAuthenticationMethod; import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.OAuth2AuthenticationException; import org.springframework.security.oauth2.core.OAuth2Error; @@ -52,7 +50,6 @@ import org.springframework.security.oauth2.server.authorization.token.OAuth2TokenContext; import org.springframework.security.oauth2.server.authorization.token.OAuth2TokenGenerator; import org.springframework.util.Assert; -import org.springframework.util.CollectionUtils; /** * An {@link AuthenticationProvider} implementation for the OAuth 2.0 Refresh Token Grant. @@ -60,6 +57,7 @@ * @author Alexey Nesterov * @author Joe Grandja * @author Anoop Garlapati + * @author Andrey Litvitski * @since 7.0 * @see OAuth2RefreshTokenAuthenticationToken * @see OAuth2AccessTokenAuthenticationToken @@ -84,6 +82,8 @@ public final class OAuth2RefreshTokenAuthenticationProvider implements Authentic private final OAuth2TokenGenerator tokenGenerator; + private Consumer authenticationValidator = new OAuth2RefreshTokenAuthenticationValidator(); + /** * Constructs an {@code OAuth2RefreshTokenAuthenticationProvider} using the provided * parameters. @@ -164,13 +164,14 @@ public Authentication authenticate(Authentication authentication) throws Authent // Verify the DPoP Proof (if available) Jwt dPoPProof = DPoPProofVerifier.verifyIfAvailable(refreshTokenAuthentication); - if (dPoPProof != null - && clientPrincipal.getClientAuthenticationMethod().equals(ClientAuthenticationMethod.NONE)) { - // For public clients, verify the DPoP Proof public key is same as (current) - // access token public key binding - Map accessTokenClaims = authorization.getAccessToken().getClaims(); - verifyDPoPProofPublicKey(dPoPProof, () -> accessTokenClaims); - } + OAuth2RefreshTokenAuthenticationContext context = OAuth2RefreshTokenAuthenticationContext + .with(refreshTokenAuthentication) + .authorization(authorization) + .clientPrincipal(clientPrincipal) + .dPoPProof(dPoPProof) + .build(); + + this.authenticationValidator.accept(context); if (this.logger.isTraceEnabled()) { this.logger.trace("Validated token request parameters"); @@ -292,45 +293,15 @@ public boolean supports(Class authentication) { return OAuth2RefreshTokenAuthenticationToken.class.isAssignableFrom(authentication); } - private static void verifyDPoPProofPublicKey(Jwt dPoPProof, ClaimAccessor accessTokenClaims) { - JWK jwk = null; - @SuppressWarnings("unchecked") - Map jwkJson = (Map) dPoPProof.getHeaders().get("jwk"); - try { - jwk = JWK.parse(jwkJson); - } - catch (Exception ignored) { - } - if (jwk == null) { - OAuth2Error error = new OAuth2Error(OAuth2ErrorCodes.INVALID_DPOP_PROOF, - "jwk header is missing or invalid.", null); - throw new OAuth2AuthenticationException(error); - } - - String jwkThumbprint; - try { - jwkThumbprint = jwk.computeThumbprint().toString(); - } - catch (Exception ex) { - OAuth2Error error = new OAuth2Error(OAuth2ErrorCodes.INVALID_DPOP_PROOF, - "Failed to compute SHA-256 Thumbprint for jwk.", null); - throw new OAuth2AuthenticationException(error); - } - - String jwkThumbprintClaim = null; - Map confirmationMethodClaim = accessTokenClaims.getClaimAsMap("cnf"); - if (!CollectionUtils.isEmpty(confirmationMethodClaim) && confirmationMethodClaim.containsKey("jkt")) { - jwkThumbprintClaim = (String) confirmationMethodClaim.get("jkt"); - } - if (jwkThumbprintClaim == null) { - OAuth2Error error = new OAuth2Error(OAuth2ErrorCodes.INVALID_DPOP_PROOF, "jkt claim is missing.", null); - throw new OAuth2AuthenticationException(error); - } - - if (!jwkThumbprint.equals(jwkThumbprintClaim)) { - OAuth2Error error = new OAuth2Error(OAuth2ErrorCodes.INVALID_DPOP_PROOF, "jwk header is invalid.", null); - throw new OAuth2AuthenticationException(error); - } + /** + * Sets the {@code Consumer} responsible for validating the OAuth 2.0 Refresh Token + * Grant Request using the provided {@link OAuth2RefreshTokenAuthenticationContext}. + *

+ * The default validator performs DPoP proof verification if present. + */ + public void setAuthenticationValidator(Consumer authenticationValidator) { + Assert.notNull(authenticationValidator, "authenticationValidator cannot be null"); + this.authenticationValidator = authenticationValidator; } } diff --git a/oauth2/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2RefreshTokenAuthenticationValidator.java b/oauth2/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2RefreshTokenAuthenticationValidator.java new file mode 100644 index 00000000000..2be598da501 --- /dev/null +++ b/oauth2/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2RefreshTokenAuthenticationValidator.java @@ -0,0 +1,114 @@ +/* + * Copyright 2004-present the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.server.authorization.authentication; + +import java.util.Map; +import java.util.function.Consumer; + +import com.nimbusds.jose.jwk.JWK; + +import org.springframework.security.oauth2.core.ClaimAccessor; +import org.springframework.security.oauth2.core.ClientAuthenticationMethod; +import org.springframework.security.oauth2.core.OAuth2AuthenticationException; +import org.springframework.security.oauth2.core.OAuth2Error; +import org.springframework.security.oauth2.core.OAuth2ErrorCodes; +import org.springframework.security.oauth2.jwt.Jwt; +import org.springframework.util.CollectionUtils; + +/** + * A {@code Consumer} that validates an {@link OAuth2RefreshTokenAuthenticationContext} + * and acts as the default + * {@link OAuth2RefreshTokenAuthenticationProvider#setAuthenticationValidator(Consumer) + * authentication validator} for the Refresh Token grant. + *

+ * The default implementation validates a DPoP proof if present and throws + * {@link OAuth2AuthenticationException} on failure. + *

+ * + * @author Andrey Litvitski + * @since 7.0.0 + * @see OAuth2RefreshTokenAuthenticationContext + * @see OAuth2RefreshTokenAuthenticationProvider#setAuthenticationValidator(Consumer) + */ +public final class OAuth2RefreshTokenAuthenticationValidator + implements Consumer { + + public static final Consumer DEFAULT_VALIDATOR = OAuth2RefreshTokenAuthenticationValidator::validateDefault; + + private final Consumer authenticationValidator = DEFAULT_VALIDATOR; + + @Override + public void accept(OAuth2RefreshTokenAuthenticationContext context) { + this.authenticationValidator.accept(context); + } + + private static void validateDefault(OAuth2RefreshTokenAuthenticationContext context) { + Jwt dPoPProof; + if (context.getDPoPProof() == null) { + dPoPProof = DPoPProofVerifier.verifyIfAvailable(context.getAuthentication()); + } + else { + dPoPProof = context.getDPoPProof(); + } + if (dPoPProof == null || !context.getClientPrincipal() + .getClientAuthenticationMethod() + .equals(ClientAuthenticationMethod.NONE)) { + return; + } + JWK jwk = null; + @SuppressWarnings("unchecked") + Map jwkJson = (Map) dPoPProof.getHeaders().get("jwk"); + try { + jwk = JWK.parse(jwkJson); + } + catch (Exception ignored) { + } + if (jwk == null) { + OAuth2Error error = new OAuth2Error(OAuth2ErrorCodes.INVALID_DPOP_PROOF, + "jwk header is missing or invalid.", null); + throw new OAuth2AuthenticationException(error); + } + + String jwkThumbprint; + try { + jwkThumbprint = jwk.computeThumbprint().toString(); + } + catch (Exception ex) { + OAuth2Error error = new OAuth2Error(OAuth2ErrorCodes.INVALID_DPOP_PROOF, + "Failed to compute SHA-256 Thumbprint for jwk.", null); + throw new OAuth2AuthenticationException(error); + } + + String jwkThumbprintClaim = null; + Map accessTokenClaimsMap = context.getAuthorization().getAccessToken().getClaims(); + ClaimAccessor accessTokenClaims = () -> accessTokenClaimsMap; + Map confirmationMethodClaim = accessTokenClaims.getClaimAsMap("cnf"); + if (!CollectionUtils.isEmpty(confirmationMethodClaim) && confirmationMethodClaim.containsKey("jkt")) { + jwkThumbprintClaim = (String) confirmationMethodClaim.get("jkt"); + } + if (jwkThumbprintClaim == null) { + OAuth2Error error = new OAuth2Error(OAuth2ErrorCodes.INVALID_DPOP_PROOF, "jkt claim is missing.", null); + throw new OAuth2AuthenticationException(error); + } + + if (!jwkThumbprint.equals(jwkThumbprintClaim)) { + OAuth2Error error = new OAuth2Error(OAuth2ErrorCodes.INVALID_DPOP_PROOF, "jwk header is invalid.", null); + throw new OAuth2AuthenticationException(error); + } + } + +}