Skip to content
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,10 @@ describe('fetchAuthSession behavior for IdentityPools only', () => {
});

describe('fetchAuthSession behavior for UserPools only', () => {
let getTokensSpy: jest.SpyInstance;

beforeAll(() => {
jest
getTokensSpy = jest
.spyOn(cognitoUserPoolsTokenProvider, 'getTokens')
.mockImplementation(async () => {
return {
Expand Down Expand Up @@ -136,4 +138,28 @@ describe('fetchAuthSession behavior for UserPools only', () => {
userSub: '1234567890',
});
});

test('should pass clientMetadata option to token provider', async () => {
Amplify.configure(
{
Auth: {
Cognito: {
userPoolClientId: 'userPoolCliendIdValue',
userPoolId: 'userpoolIdvalue',
},
},
},
{
Auth: {
credentialsProvider: cognitoCredentialsProvider,
tokenProvider: cognitoUserPoolsTokenProvider,
},
},
);

const clientMetadata = { 'app-version': '1.0.0' };
await fetchAuthSession({ clientMetadata });

expect(getTokensSpy).toHaveBeenCalledWith({ clientMetadata });
});
});
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0

import { decodeJWT } from '@aws-amplify/core/internals/utils';

import { refreshAuthTokens } from '../../../src/providers/cognito/utils/refreshAuthTokens';
Expand Down Expand Up @@ -60,6 +63,7 @@ describe('refreshToken', () => {
});

it('should refresh token', async () => {
const clientMetadata = { 'app-version': '1.0.0' };
const expectedOutput = {
accessToken: decodeJWT(mockAccessToken),
idToken: decodeJWT(mockAccessToken),
Expand All @@ -82,6 +86,7 @@ describe('refreshToken', () => {
},
},
username: mockedUsername,
clientMetadata,
});

// stringify and re-parse for JWT equality
Expand All @@ -93,6 +98,7 @@ describe('refreshToken', () => {
expect.objectContaining({
ClientId: 'aaaaaaaaaaaa',
RefreshToken: mockedRefreshToken,
ClientMetadata: clientMetadata,
}),
);
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -151,4 +151,81 @@ describe('TokenOrchestrator', () => {
expect(tokens?.accessToken).toEqual(validAuthTokens.accessToken);
});
});

describe('setClientMetadataProvider', () => {
it('should use clientMetadataProvider for token refresh', async () => {
const clientMetadata = { 'app-version': '1.0.0' };
const clientMetadataProvider = () => Promise.resolve(clientMetadata);

mockTokenRefresher.mockResolvedValue({
accessToken: { payload: {} },
idToken: { payload: {} },
clockDrift: 0,
refreshToken: 'newRefreshToken',
username: 'testuser',
});

tokenOrchestrator.setTokenRefresher(mockTokenRefresher);
tokenOrchestrator.setAuthTokenStore(mockAuthTokenStore);
tokenOrchestrator.setClientMetadataProvider(clientMetadataProvider);

mockAuthTokenStore.loadTokens.mockResolvedValue({
accessToken: { payload: { exp: 1 } },
idToken: { payload: { exp: 1 } },
clockDrift: 0,
refreshToken: 'refreshToken',
username: 'testuser',
});
mockAuthTokenStore.getLastAuthUser.mockResolvedValue('testuser');

await tokenOrchestrator.getTokens({ forceRefresh: true });

expect(mockTokenRefresher).toHaveBeenCalledWith(
expect.objectContaining({
clientMetadata,
}),
);
});

it('should prioritize clientMetadata from options over clientMetadataProvider', async () => {
const providerMetadata = { 'app-version': '1.0.0' };
const optionsMetadata = {
'app-version': '2.0.0',
'device-id': 'test-device',
};
const clientMetadataProvider = () => Promise.resolve(providerMetadata);

mockTokenRefresher.mockResolvedValue({
accessToken: { payload: {} },
idToken: { payload: {} },
clockDrift: 0,
refreshToken: 'newRefreshToken',
username: 'testuser',
});

tokenOrchestrator.setTokenRefresher(mockTokenRefresher);
tokenOrchestrator.setAuthTokenStore(mockAuthTokenStore);
tokenOrchestrator.setClientMetadataProvider(clientMetadataProvider);

mockAuthTokenStore.loadTokens.mockResolvedValue({
accessToken: { payload: { exp: 1 } },
idToken: { payload: { exp: 1 } },
clockDrift: 0,
refreshToken: 'refreshToken',
username: 'testuser',
});
mockAuthTokenStore.getLastAuthUser.mockResolvedValue('testuser');

await tokenOrchestrator.getTokens({
forceRefresh: true,
clientMetadata: optionsMetadata,
});

expect(mockTokenRefresher).toHaveBeenCalledWith(
expect.objectContaining({
clientMetadata: optionsMetadata,
}),
);
});
});
});
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import {
AuthConfig,
AuthTokens,
ClientMetadataProvider,
FetchAuthSessionOptions,
KeyValueStorageInterface,
defaultStorage,
Expand All @@ -28,16 +29,20 @@ export class CognitoUserPoolsTokenProvider
this.tokenOrchestrator.setTokenRefresher(refreshAuthTokens);
}

getTokens(
{ forceRefresh }: FetchAuthSessionOptions = { forceRefresh: false },
): Promise<AuthTokens | null> {
return this.tokenOrchestrator.getTokens({ forceRefresh });
getTokens(options: FetchAuthSessionOptions = {}): Promise<AuthTokens | null> {
return this.tokenOrchestrator.getTokens(options);
}

setKeyValueStorage(keyValueStorage: KeyValueStorageInterface): void {
this.authTokenStore.setKeyValueStorage(keyValueStorage);
}

setClientMetadataProvider(
clientMetadataProvider: ClientMetadataProvider,
): void {
this.tokenOrchestrator.setClientMetadataProvider(clientMetadataProvider);
}

setAuthConfig(authConfig: AuthConfig) {
this.authTokenStore.setAuthConfig(authConfig);
this.tokenOrchestrator.setAuthConfig(authConfig);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import {
AuthConfig,
AuthTokens,
ClientMetadataProvider,
CognitoUserPoolConfig,
FetchAuthSessionOptions,
Hub,
Expand All @@ -19,7 +20,7 @@ import { assertServiceError } from '../../../errors/utils/assertServiceError';
import { AuthError } from '../../../errors/AuthError';
import { oAuthStore } from '../utils/oauth/oAuthStore';
import { addInflightPromise } from '../utils/oauth/inflightPromise';
import { CognitoAuthSignInDetails } from '../types';
import { ClientMetadata, CognitoAuthSignInDetails } from '../types';

import {
AuthTokenOrchestrator,
Expand All @@ -32,6 +33,7 @@ import {

export class TokenOrchestrator implements AuthTokenOrchestrator {
private authConfig?: AuthConfig;
clientMetadataProvider?: ClientMetadataProvider;
tokenStore?: AuthTokenStore;
tokenRefresher?: TokenRefresher;
inflightPromise: Promise<void> | undefined;
Expand Down Expand Up @@ -94,6 +96,12 @@ export class TokenOrchestrator implements AuthTokenOrchestrator {
return this.tokenRefresher;
}

setClientMetadataProvider(
clientMetadataProvider: ClientMetadataProvider,
): void {
this.clientMetadataProvider = clientMetadataProvider;
}

async getTokens(
options?: FetchAuthSessionOptions,
): Promise<
Expand Down Expand Up @@ -130,6 +138,8 @@ export class TokenOrchestrator implements AuthTokenOrchestrator {
tokens = await this.refreshTokens({
tokens,
username,
clientMetadata:
options?.clientMetadata ?? (await this.clientMetadataProvider?.()),
});

if (tokens === null) {
Expand All @@ -147,16 +157,19 @@ export class TokenOrchestrator implements AuthTokenOrchestrator {
private async refreshTokens({
tokens,
username,
clientMetadata,
}: {
tokens: CognitoAuthTokens;
username: string;
clientMetadata?: ClientMetadata;
}): Promise<CognitoAuthTokens | null> {
try {
const { signInDetails } = tokens;
const newTokens = await this.getTokenRefresher()({
tokens,
authConfig: this.authConfig,
username,
clientMetadata,
});
newTokens.signInDetails = signInDetails;
await this.setTokens({ tokens: newTokens });
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,24 @@
import {
AuthConfig,
AuthTokens,
ClientMetadataProvider,
FetchAuthSessionOptions,
KeyValueStorageInterface,
TokenProvider,
} from '@aws-amplify/core';

import { CognitoAuthSignInDetails } from '../types';
import { ClientMetadata, CognitoAuthSignInDetails } from '../types';

export type TokenRefresher = ({
tokens,
authConfig,
username,
clientMetadata,
}: {
tokens: CognitoAuthTokens;
authConfig?: AuthConfig;
username: string;
clientMetadata?: ClientMetadata;
}) => Promise<CognitoAuthTokens>;

export type AuthKeys<AuthKey extends string> = Record<AuthKey, string>;
Expand Down Expand Up @@ -66,6 +69,9 @@ export interface AuthTokenOrchestrator {
export interface CognitoUserPoolTokenProviderType extends TokenProvider {
setKeyValueStorage(keyValueStorage: KeyValueStorageInterface): void;
setAuthConfig(authConfig: AuthConfig): void;
setClientMetadataProvider(
clientMetadataProvider: ClientMetadataProvider,
): void;
}

export type CognitoAuthTokens = AuthTokens & {
Expand Down
2 changes: 1 addition & 1 deletion packages/auth/src/providers/cognito/types/models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ export const cognitoHostedUIIdentityProviderMap: Record<AuthProvider, string> =
/**
* Arbitrary key/value pairs that may be passed as part of certain Cognito requests
*/
export type ClientMetadata = Record<string, string>;
export type { ClientMetadata } from '@aws-amplify/core';

/**
* Allowed values for preferredChallenge
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,18 @@ import { assertAuthTokensWithRefreshToken } from '../utils/types';
import { AuthError } from '../../../errors/AuthError';
import { createCognitoUserPoolEndpointResolver } from '../factories';
import { createGetTokensFromRefreshTokenClient } from '../../../foundation/factories/serviceClients/cognitoIdentityProvider';
import { ClientMetadata } from '../types';

const refreshAuthTokensFunction: TokenRefresher = async ({
tokens,
authConfig,
username,
clientMetadata,
}: {
tokens: CognitoAuthTokens;
authConfig?: AuthConfig;
username: string;
clientMetadata?: ClientMetadata;
}): Promise<CognitoAuthTokens> => {
assertTokenProviderConfig(authConfig?.Cognito);
const { userPoolId, userPoolClientId, userPoolEndpoint } = authConfig.Cognito;
Expand All @@ -41,6 +44,7 @@ const refreshAuthTokensFunction: TokenRefresher = async ({
ClientId: userPoolClientId,
RefreshToken: tokens.refreshToken,
DeviceKey: tokens.deviceMetadata?.deviceKey,
ClientMetadata: clientMetadata,
},
);

Expand Down
2 changes: 2 additions & 0 deletions packages/core/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ export {
OAuthConfig,
CognitoUserPoolConfig,
JWT,
ClientMetadata,
ClientMetadataProvider,
} from './singleton/Auth/types';
export { decodeJWT } from './singleton/Auth/utils';
export {
Expand Down
11 changes: 11 additions & 0 deletions packages/core/src/singleton/Auth/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,16 @@
import { StrictUnion } from '../../types';
import { AtLeastOne } from '../types';

/**
* Arbitrary key/value pairs that may be passed as part of certain Cognito requests
*/
export type ClientMetadata = Record<string, string>;

/**
* Function type for providing client metadata for Cognito operations
*/
export type ClientMetadataProvider = () => Promise<ClientMetadata>;

// From https://github.com/awslabs/aws-jwt-verify/blob/main/src/safe-json-parse.ts
// From https://github.com/awslabs/aws-jwt-verify/blob/main/src/jwt-model.ts
interface JwtPayloadStandardFields {
Expand Down Expand Up @@ -66,6 +76,7 @@ export interface TokenProvider {

export interface FetchAuthSessionOptions {
forceRefresh?: boolean;
clientMetadata?: ClientMetadata;
}

export interface AuthTokens {
Expand Down
Loading