diff --git a/packages/auth/__tests__/providers/cognito/fetchAuthSession.test.ts b/packages/auth/__tests__/providers/cognito/fetchAuthSession.test.ts index a04c6ccdce3..b3af67039d1 100644 --- a/packages/auth/__tests__/providers/cognito/fetchAuthSession.test.ts +++ b/packages/auth/__tests__/providers/cognito/fetchAuthSession.test.ts @@ -75,8 +75,10 @@ describe('fetchAuthSession behavior for IdentityPools only', () => { }); describe('fetchAuthSession behavior for UserPools only', () => { + let getTokensSpy: jest.SpyInstance; + beforeAll(() => { - jest + getTokensSpy = jest .spyOn(cognitoUserPoolsTokenProvider, 'getTokens') .mockImplementation(async () => { return { @@ -136,4 +138,28 @@ describe('fetchAuthSession behavior for UserPools only', () => { userSub: '1234567890', }); }); + + test('should pass clientMetadata option to token provider', async () => { + Amplify.configure( + { + Auth: { + Cognito: { + userPoolClientId: 'userPoolCliendIdValue', + userPoolId: 'userpoolIdvalue', + }, + }, + }, + { + Auth: { + credentialsProvider: cognitoCredentialsProvider, + tokenProvider: cognitoUserPoolsTokenProvider, + }, + }, + ); + + const clientMetadata = { 'app-version': '1.0.0' }; + await fetchAuthSession({ clientMetadata }); + + expect(getTokensSpy).toHaveBeenCalledWith({ clientMetadata }); + }); }); diff --git a/packages/auth/__tests__/providers/cognito/refreshToken.test.ts b/packages/auth/__tests__/providers/cognito/refreshToken.test.ts index a298e1aa377..c84dfefe0d7 100644 --- a/packages/auth/__tests__/providers/cognito/refreshToken.test.ts +++ b/packages/auth/__tests__/providers/cognito/refreshToken.test.ts @@ -1,3 +1,6 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + import { decodeJWT } from '@aws-amplify/core/internals/utils'; import { refreshAuthTokens } from '../../../src/providers/cognito/utils/refreshAuthTokens'; @@ -60,6 +63,7 @@ describe('refreshToken', () => { }); it('should refresh token', async () => { + const clientMetadata = { 'app-version': '1.0.0' }; const expectedOutput = { accessToken: decodeJWT(mockAccessToken), idToken: decodeJWT(mockAccessToken), @@ -82,6 +86,7 @@ describe('refreshToken', () => { }, }, username: mockedUsername, + clientMetadata, }); // stringify and re-parse for JWT equality @@ -93,6 +98,7 @@ describe('refreshToken', () => { expect.objectContaining({ ClientId: 'aaaaaaaaaaaa', RefreshToken: mockedRefreshToken, + ClientMetadata: clientMetadata, }), ); }); diff --git a/packages/auth/__tests__/providers/cognito/tokenOrchestrator.test.ts b/packages/auth/__tests__/providers/cognito/tokenOrchestrator.test.ts index 8906d8d7eed..704f08175d4 100644 --- a/packages/auth/__tests__/providers/cognito/tokenOrchestrator.test.ts +++ b/packages/auth/__tests__/providers/cognito/tokenOrchestrator.test.ts @@ -151,4 +151,81 @@ describe('TokenOrchestrator', () => { expect(tokens?.accessToken).toEqual(validAuthTokens.accessToken); }); }); + + describe('setClientMetadataProvider', () => { + it('should use clientMetadataProvider for token refresh', async () => { + const clientMetadata = { 'app-version': '1.0.0' }; + const clientMetadataProvider = () => Promise.resolve(clientMetadata); + + mockTokenRefresher.mockResolvedValue({ + accessToken: { payload: {} }, + idToken: { payload: {} }, + clockDrift: 0, + refreshToken: 'newRefreshToken', + username: 'testuser', + }); + + tokenOrchestrator.setTokenRefresher(mockTokenRefresher); + tokenOrchestrator.setAuthTokenStore(mockAuthTokenStore); + tokenOrchestrator.setClientMetadataProvider(clientMetadataProvider); + + mockAuthTokenStore.loadTokens.mockResolvedValue({ + accessToken: { payload: { exp: 1 } }, + idToken: { payload: { exp: 1 } }, + clockDrift: 0, + refreshToken: 'refreshToken', + username: 'testuser', + }); + mockAuthTokenStore.getLastAuthUser.mockResolvedValue('testuser'); + + await tokenOrchestrator.getTokens({ forceRefresh: true }); + + expect(mockTokenRefresher).toHaveBeenCalledWith( + expect.objectContaining({ + clientMetadata, + }), + ); + }); + + it('should prioritize clientMetadata from options over clientMetadataProvider', async () => { + const providerMetadata = { 'app-version': '1.0.0' }; + const optionsMetadata = { + 'app-version': '2.0.0', + 'device-id': 'test-device', + }; + const clientMetadataProvider = () => Promise.resolve(providerMetadata); + + mockTokenRefresher.mockResolvedValue({ + accessToken: { payload: {} }, + idToken: { payload: {} }, + clockDrift: 0, + refreshToken: 'newRefreshToken', + username: 'testuser', + }); + + tokenOrchestrator.setTokenRefresher(mockTokenRefresher); + tokenOrchestrator.setAuthTokenStore(mockAuthTokenStore); + tokenOrchestrator.setClientMetadataProvider(clientMetadataProvider); + + mockAuthTokenStore.loadTokens.mockResolvedValue({ + accessToken: { payload: { exp: 1 } }, + idToken: { payload: { exp: 1 } }, + clockDrift: 0, + refreshToken: 'refreshToken', + username: 'testuser', + }); + mockAuthTokenStore.getLastAuthUser.mockResolvedValue('testuser'); + + await tokenOrchestrator.getTokens({ + forceRefresh: true, + clientMetadata: optionsMetadata, + }); + + expect(mockTokenRefresher).toHaveBeenCalledWith( + expect.objectContaining({ + clientMetadata: optionsMetadata, + }), + ); + }); + }); }); diff --git a/packages/auth/src/providers/cognito/tokenProvider/CognitoUserPoolsTokenProvider.ts b/packages/auth/src/providers/cognito/tokenProvider/CognitoUserPoolsTokenProvider.ts index 43f0f8a2d8c..527917ee531 100644 --- a/packages/auth/src/providers/cognito/tokenProvider/CognitoUserPoolsTokenProvider.ts +++ b/packages/auth/src/providers/cognito/tokenProvider/CognitoUserPoolsTokenProvider.ts @@ -4,6 +4,7 @@ import { AuthConfig, AuthTokens, + ClientMetadataProvider, FetchAuthSessionOptions, KeyValueStorageInterface, defaultStorage, @@ -28,16 +29,20 @@ export class CognitoUserPoolsTokenProvider this.tokenOrchestrator.setTokenRefresher(refreshAuthTokens); } - getTokens( - { forceRefresh }: FetchAuthSessionOptions = { forceRefresh: false }, - ): Promise { - return this.tokenOrchestrator.getTokens({ forceRefresh }); + getTokens(options: FetchAuthSessionOptions = {}): Promise { + return this.tokenOrchestrator.getTokens(options); } setKeyValueStorage(keyValueStorage: KeyValueStorageInterface): void { this.authTokenStore.setKeyValueStorage(keyValueStorage); } + setClientMetadataProvider( + clientMetadataProvider: ClientMetadataProvider, + ): void { + this.tokenOrchestrator.setClientMetadataProvider(clientMetadataProvider); + } + setAuthConfig(authConfig: AuthConfig) { this.authTokenStore.setAuthConfig(authConfig); this.tokenOrchestrator.setAuthConfig(authConfig); diff --git a/packages/auth/src/providers/cognito/tokenProvider/TokenOrchestrator.ts b/packages/auth/src/providers/cognito/tokenProvider/TokenOrchestrator.ts index 3f8027d2596..851db00846f 100644 --- a/packages/auth/src/providers/cognito/tokenProvider/TokenOrchestrator.ts +++ b/packages/auth/src/providers/cognito/tokenProvider/TokenOrchestrator.ts @@ -3,6 +3,7 @@ import { AuthConfig, AuthTokens, + ClientMetadataProvider, CognitoUserPoolConfig, FetchAuthSessionOptions, Hub, @@ -19,7 +20,7 @@ import { assertServiceError } from '../../../errors/utils/assertServiceError'; import { AuthError } from '../../../errors/AuthError'; import { oAuthStore } from '../utils/oauth/oAuthStore'; import { addInflightPromise } from '../utils/oauth/inflightPromise'; -import { CognitoAuthSignInDetails } from '../types'; +import { ClientMetadata, CognitoAuthSignInDetails } from '../types'; import { AuthTokenOrchestrator, @@ -32,6 +33,7 @@ import { export class TokenOrchestrator implements AuthTokenOrchestrator { private authConfig?: AuthConfig; + clientMetadataProvider?: ClientMetadataProvider; tokenStore?: AuthTokenStore; tokenRefresher?: TokenRefresher; inflightPromise: Promise | undefined; @@ -94,6 +96,12 @@ export class TokenOrchestrator implements AuthTokenOrchestrator { return this.tokenRefresher; } + setClientMetadataProvider( + clientMetadataProvider: ClientMetadataProvider, + ): void { + this.clientMetadataProvider = clientMetadataProvider; + } + async getTokens( options?: FetchAuthSessionOptions, ): Promise< @@ -130,6 +138,8 @@ export class TokenOrchestrator implements AuthTokenOrchestrator { tokens = await this.refreshTokens({ tokens, username, + clientMetadata: + options?.clientMetadata ?? (await this.clientMetadataProvider?.()), }); if (tokens === null) { @@ -147,9 +157,11 @@ export class TokenOrchestrator implements AuthTokenOrchestrator { private async refreshTokens({ tokens, username, + clientMetadata, }: { tokens: CognitoAuthTokens; username: string; + clientMetadata?: ClientMetadata; }): Promise { try { const { signInDetails } = tokens; @@ -157,6 +169,7 @@ export class TokenOrchestrator implements AuthTokenOrchestrator { tokens, authConfig: this.authConfig, username, + clientMetadata, }); newTokens.signInDetails = signInDetails; await this.setTokens({ tokens: newTokens }); diff --git a/packages/auth/src/providers/cognito/tokenProvider/types.ts b/packages/auth/src/providers/cognito/tokenProvider/types.ts index 5f381b42016..4ac6973b60e 100644 --- a/packages/auth/src/providers/cognito/tokenProvider/types.ts +++ b/packages/auth/src/providers/cognito/tokenProvider/types.ts @@ -3,21 +3,24 @@ import { AuthConfig, AuthTokens, + ClientMetadataProvider, FetchAuthSessionOptions, KeyValueStorageInterface, TokenProvider, } from '@aws-amplify/core'; -import { CognitoAuthSignInDetails } from '../types'; +import { ClientMetadata, CognitoAuthSignInDetails } from '../types'; export type TokenRefresher = ({ tokens, authConfig, username, + clientMetadata, }: { tokens: CognitoAuthTokens; authConfig?: AuthConfig; username: string; + clientMetadata?: ClientMetadata; }) => Promise; export type AuthKeys = Record; @@ -66,6 +69,9 @@ export interface AuthTokenOrchestrator { export interface CognitoUserPoolTokenProviderType extends TokenProvider { setKeyValueStorage(keyValueStorage: KeyValueStorageInterface): void; setAuthConfig(authConfig: AuthConfig): void; + setClientMetadataProvider( + clientMetadataProvider: ClientMetadataProvider, + ): void; } export type CognitoAuthTokens = AuthTokens & { diff --git a/packages/auth/src/providers/cognito/types/models.ts b/packages/auth/src/providers/cognito/types/models.ts index 1b113ef1720..8bd212d117a 100644 --- a/packages/auth/src/providers/cognito/types/models.ts +++ b/packages/auth/src/providers/cognito/types/models.ts @@ -38,7 +38,7 @@ export const cognitoHostedUIIdentityProviderMap: Record = /** * Arbitrary key/value pairs that may be passed as part of certain Cognito requests */ -export type ClientMetadata = Record; +export type { ClientMetadata } from '@aws-amplify/core'; /** * Allowed values for preferredChallenge diff --git a/packages/auth/src/providers/cognito/utils/refreshAuthTokens.ts b/packages/auth/src/providers/cognito/utils/refreshAuthTokens.ts index e29bd460b19..54f30f7a796 100644 --- a/packages/auth/src/providers/cognito/utils/refreshAuthTokens.ts +++ b/packages/auth/src/providers/cognito/utils/refreshAuthTokens.ts @@ -14,15 +14,18 @@ import { assertAuthTokensWithRefreshToken } from '../utils/types'; import { AuthError } from '../../../errors/AuthError'; import { createCognitoUserPoolEndpointResolver } from '../factories'; import { createGetTokensFromRefreshTokenClient } from '../../../foundation/factories/serviceClients/cognitoIdentityProvider'; +import { ClientMetadata } from '../types'; const refreshAuthTokensFunction: TokenRefresher = async ({ tokens, authConfig, username, + clientMetadata, }: { tokens: CognitoAuthTokens; authConfig?: AuthConfig; username: string; + clientMetadata?: ClientMetadata; }): Promise => { assertTokenProviderConfig(authConfig?.Cognito); const { userPoolId, userPoolClientId, userPoolEndpoint } = authConfig.Cognito; @@ -41,6 +44,7 @@ const refreshAuthTokensFunction: TokenRefresher = async ({ ClientId: userPoolClientId, RefreshToken: tokens.refreshToken, DeviceKey: tokens.deviceMetadata?.deviceKey, + ClientMetadata: clientMetadata, }, ); diff --git a/packages/core/src/index.ts b/packages/core/src/index.ts index b37829169b3..d65eac33003 100644 --- a/packages/core/src/index.ts +++ b/packages/core/src/index.ts @@ -20,6 +20,8 @@ export { OAuthConfig, CognitoUserPoolConfig, JWT, + ClientMetadata, + ClientMetadataProvider, } from './singleton/Auth/types'; export { decodeJWT } from './singleton/Auth/utils'; export { diff --git a/packages/core/src/singleton/Auth/types.ts b/packages/core/src/singleton/Auth/types.ts index 8fa811251b1..18dafd9761b 100644 --- a/packages/core/src/singleton/Auth/types.ts +++ b/packages/core/src/singleton/Auth/types.ts @@ -4,6 +4,16 @@ import { StrictUnion } from '../../types'; import { AtLeastOne } from '../types'; +/** + * Arbitrary key/value pairs that may be passed as part of certain Cognito requests + */ +export type ClientMetadata = Record; + +/** + * Function type for providing client metadata for Cognito operations + */ +export type ClientMetadataProvider = () => Promise; + // From https://github.com/awslabs/aws-jwt-verify/blob/main/src/safe-json-parse.ts // From https://github.com/awslabs/aws-jwt-verify/blob/main/src/jwt-model.ts interface JwtPayloadStandardFields { @@ -66,6 +76,7 @@ export interface TokenProvider { export interface FetchAuthSessionOptions { forceRefresh?: boolean; + clientMetadata?: ClientMetadata; } export interface AuthTokens {