From 996dcde4ebfa7b8d1d4825345aaaec645610144d Mon Sep 17 00:00:00 2001 From: James Gregory Date: Mon, 30 May 2022 19:42:03 +1000 Subject: [PATCH] fix: include groups in tokens Fixes #176 --- src/__tests__/mockUserPoolService.ts | 1 + src/bin/start.ts | 2 +- src/services/tokenGenerator.test.ts | 74 ++++++++++++++++++++++ src/services/tokenGenerator.ts | 63 ++++++++++-------- src/services/userPoolService.test.ts | 50 ++++++++++++++- src/services/userPoolService.ts | 21 ++++++ src/targets/adminInitiateAuth.test.ts | 4 ++ src/targets/adminInitiateAuth.ts | 6 ++ src/targets/initiateAuth.test.ts | 6 ++ src/targets/initiateAuth.ts | 6 ++ src/targets/respondToAuthChallenge.test.ts | 4 ++ src/targets/respondToAuthChallenge.ts | 3 + 12 files changed, 211 insertions(+), 29 deletions(-) diff --git a/src/__tests__/mockUserPoolService.ts b/src/__tests__/mockUserPoolService.ts index 56d6690e..860aded9 100644 --- a/src/__tests__/mockUserPoolService.ts +++ b/src/__tests__/mockUserPoolService.ts @@ -14,6 +14,7 @@ export const newMockUserPoolService = ( getUserByRefreshToken: jest.fn(), getUserByUsername: jest.fn(), listGroups: jest.fn(), + listUserGroupMembership: jest.fn(), listUsers: jest.fn(), options: config, removeUserFromGroup: jest.fn(), diff --git a/src/bin/start.ts b/src/bin/start.ts index fabec0b0..8547af88 100644 --- a/src/bin/start.ts +++ b/src/bin/start.ts @@ -14,7 +14,7 @@ const logger = Pino( singleLine: true, messageFormat: (log, messageKey) => `${log["reqId"] ?? "NONE"} ${log["target"] ?? "NONE"} ${log[messageKey]}`, - }) as any + }) as any // eslint-disable-line @typescript-eslint/no-explicit-any ); createDefaultServer(logger) diff --git a/src/services/tokenGenerator.test.ts b/src/services/tokenGenerator.test.ts index 5efa38ef..fa5ee92a 100644 --- a/src/services/tokenGenerator.test.ts +++ b/src/services/tokenGenerator.test.ts @@ -45,6 +45,7 @@ describe("JwtTokenGenerator", () => { const tokens = await tokenGenerator.generate( TestContext, user, + [], TDB.appClient(), { client: "metadata" }, "RefreshTokens" @@ -80,6 +81,7 @@ describe("JwtTokenGenerator", () => { const tokens = await tokenGenerator.generate( TestContext, user, + [], TDB.appClient(), { client: "metadata" }, "RefreshTokens" @@ -112,6 +114,7 @@ describe("JwtTokenGenerator", () => { const tokens = await tokenGenerator.generate( TestContext, user, + [], TDB.appClient(), { client: "metadata" }, "RefreshTokens" @@ -162,6 +165,7 @@ describe("JwtTokenGenerator", () => { const tokens = await tokenGenerator.generate( TestContext, user, + [], TDB.appClient(), { client: "metadata" }, "RefreshTokens" @@ -183,6 +187,7 @@ describe("JwtTokenGenerator", () => { const tokens = await tokenGenerator.generate( TestContext, user, + [], userPoolClient, { client: "metadata" }, "RefreshTokens" @@ -243,6 +248,7 @@ describe("JwtTokenGenerator", () => { const tokens = await tokenGenerator.generate( TestContext, user, + [], userPoolClient, { client: "metadata" }, "RefreshTokens" @@ -278,6 +284,7 @@ describe("JwtTokenGenerator", () => { const tokens = await tokenGenerator.generate( TestContext, user, + [], userPoolClient, { client: "metadata" }, "RefreshTokens" @@ -309,6 +316,7 @@ describe("JwtTokenGenerator", () => { const tokens = await tokenGenerator.generate( TestContext, user, + [], userPoolClient, { client: "metadata" }, "RefreshTokens" @@ -344,6 +352,7 @@ describe("JwtTokenGenerator", () => { const tokens = await tokenGenerator.generate( TestContext, user, + [], userPoolClient, { client: "metadata" }, "RefreshTokens" @@ -361,4 +370,69 @@ describe("JwtTokenGenerator", () => { }); }); }); + + describe("groups", () => { + it("does not include a cognito:groups claim if the user has no groups", async () => { + mockTriggers.enabled.mockReturnValue(false); + + const userPoolClient = TDB.appClient({ + AccessTokenValidity: 10, + IdTokenValidity: 20, + RefreshTokenValidity: 30, + TokenValidityUnits: { + AccessToken: "seconds", + IdToken: "minutes", + RefreshToken: "hours", + }, + }); + + const tokens = await tokenGenerator.generate( + TestContext, + user, + [], + userPoolClient, + { client: "metadata" }, + "RefreshTokens" + ); + + expect( + (jwt.decode(tokens.AccessToken) as any)["cognito:groups"] + ).toBeUndefined(); + expect( + (jwt.decode(tokens.IdToken) as any)["cognito:groups"] + ).toBeUndefined(); + }); + + it("includes a cognito:groups claim with the user's groups", async () => { + mockTriggers.enabled.mockReturnValue(false); + + const userPoolClient = TDB.appClient({ + AccessTokenValidity: 10, + IdTokenValidity: 20, + RefreshTokenValidity: 30, + TokenValidityUnits: { + AccessToken: "seconds", + IdToken: "minutes", + RefreshToken: "hours", + }, + }); + + const tokens = await tokenGenerator.generate( + TestContext, + user, + ["group1", "group2"], + userPoolClient, + { client: "metadata" }, + "RefreshTokens" + ); + + expect((jwt.decode(tokens.AccessToken) as any)["cognito:groups"]).toEqual( + ["group1", "group2"] + ); + expect((jwt.decode(tokens.IdToken) as any)["cognito:groups"]).toEqual([ + "group1", + "group2", + ]); + }); + }); }); diff --git a/src/services/tokenGenerator.ts b/src/services/tokenGenerator.ts index 6c3dcfac..a9527937 100644 --- a/src/services/tokenGenerator.ts +++ b/src/services/tokenGenerator.ts @@ -58,10 +58,15 @@ const RESERVED_CLAIMS = [ "token_use", ]; +type RawToken = Record< + string, + string | number | boolean | undefined | readonly string[] +>; + const applyTokenOverrides = ( - token: Record, + token: RawToken, overrides: TokenOverrides -): Record => { +): RawToken => { // TODO: support group overrides const claimsToSuppress = (overrides?.claimsToSuppress ?? []).filter( @@ -89,6 +94,7 @@ export interface TokenGenerator { generate( ctx: Context, user: User, + userGroups: readonly string[], userPoolClient: AppClient, clientMetadata: Record | undefined, source: @@ -124,6 +130,7 @@ export class JwtTokenGenerator implements TokenGenerator { public async generate( ctx: Context, user: User, + userGroups: readonly string[], userPoolClient: AppClient, clientMetadata: Record | undefined, source: @@ -137,7 +144,18 @@ export class JwtTokenGenerator implements TokenGenerator { const authTime = Math.floor(this.clock.get().getTime() / 1000); const sub = attributeValue("sub", user.Attributes); - let idToken: Record = { + const accessToken: RawToken = { + auth_time: authTime, + client_id: userPoolClient.ClientId, + event_id: eventId, + iat: authTime, + jti: uuid.v4(), + scope: "aws.cognito.signin.user.admin", // TODO: scopes + sub, + token_use: "access", + username: user.Username, + }; + let idToken: RawToken = { "cognito:username": user.Username, auth_time: authTime, email: attributeValue("email", user.Attributes), @@ -152,6 +170,11 @@ export class JwtTokenGenerator implements TokenGenerator { ...attributesToRecord(customAttributes(user.Attributes)), }; + if (userGroups.length) { + accessToken["cognito:groups"] = userGroups; + idToken["cognito:groups"] = userGroups; + } + if (this.triggers.enabled("PreTokenGeneration")) { const result = await this.triggers.preTokenGeneration(ctx, { clientId: userPoolClient.ClientId, @@ -174,30 +197,16 @@ export class JwtTokenGenerator implements TokenGenerator { const issuer = `${this.tokenConfig.IssuerDomain}/${userPoolClient.UserPoolId}`; return { - AccessToken: jwt.sign( - { - auth_time: authTime, - client_id: userPoolClient.ClientId, - event_id: eventId, - iat: authTime, - jti: uuid.v4(), - scope: "aws.cognito.signin.user.admin", // TODO: scopes - sub, - token_use: "access", - username: user.Username, - }, - PrivateKey.pem, - { - algorithm: "RS256", - issuer, - expiresIn: formatExpiration( - userPoolClient.AccessTokenValidity, - userPoolClient.TokenValidityUnits?.AccessToken ?? "hours", - "24h" - ), - keyid: "CognitoLocal", - } - ), + AccessToken: jwt.sign(accessToken, PrivateKey.pem, { + algorithm: "RS256", + issuer, + expiresIn: formatExpiration( + userPoolClient.AccessTokenValidity, + userPoolClient.TokenValidityUnits?.AccessToken ?? "hours", + "24h" + ), + keyid: "CognitoLocal", + }), IdToken: jwt.sign(idToken, PrivateKey.pem, { algorithm: "RS256", issuer, diff --git a/src/services/userPoolService.test.ts b/src/services/userPoolService.test.ts index 380215a4..75db5fba 100644 --- a/src/services/userPoolService.test.ts +++ b/src/services/userPoolService.test.ts @@ -204,7 +204,7 @@ describe("User Pool Service", () => { ds.get.mockImplementation((ctx, key) => { if (key === "Groups") { - return Promise.resolve([]); + return Promise.resolve({}); } return Promise.resolve(null); @@ -774,4 +774,52 @@ describe("User Pool Service", () => { ); }); }); + + describe("listUserGroupMembership", () => { + it("returns all the groups that the user is a member", async () => { + const ds = newMockDataStore(); + const userPool = new UserPoolServiceImpl( + mockClientsDataStore, + clock, + ds, + { + Id: "test", + } + ); + + const user = TDB.user(); + const group1 = TDB.group({ + GroupName: "Group1", + members: [user.Username], + }); + const group2 = TDB.group({ + GroupName: "Group2", + members: [user.Username], + }); + const group3 = TDB.group({ + GroupName: "Group3", + members: [], + }); + const groups = { + [group1.GroupName]: group1, + [group2.GroupName]: group2, + [group3.GroupName]: group3, + }; + + ds.get.mockImplementation((ctx, key) => { + if (key === "Groups") { + return Promise.resolve(groups); + } + + return Promise.resolve(null); + }); + + const groupMembership = await userPool.listUserGroupMembership( + TestContext, + user + ); + + expect(groupMembership).toEqual([group1.GroupName, group2.GroupName]); + }); + }); }); diff --git a/src/services/userPoolService.ts b/src/services/userPoolService.ts index f83e73c0..d920b450 100644 --- a/src/services/userPoolService.ts +++ b/src/services/userPoolService.ts @@ -151,6 +151,7 @@ export interface UserPoolService { ): Promise; listGroups(ctx: Context): Promise; listUsers(ctx: Context): Promise; + listUserGroupMembership(ctx: Context, user: User): Promise; updateOptions(ctx: Context, userPool: UserPool): Promise; removeUserFromGroup(ctx: Context, group: Group, user: User): Promise; saveGroup(ctx: Context, group: Group): Promise; @@ -410,6 +411,26 @@ export class UserPoolServiceImpl implements UserPoolService { await this.dataStore.set(ctx, ["Groups", group.GroupName], group); } + async listUserGroupMembership( + ctx: Context, + user: User + ): Promise { + ctx.logger.debug( + { username: user.Username }, + "UserPoolServiceImpl.listUserGroupMembership" + ); + + // could optimise this by dual-writing group membership to both the group and + // the user records, but for an initial version this is probably fine unless + // you have a lot of groups + const groups = await this.listGroups(ctx); + + return groups + .filter((x) => x.members?.includes(user.Username)) + .map((x) => x.GroupName) + .sort((a, b) => a.localeCompare(b)); + } + async storeRefreshToken( ctx: Context, refreshToken: string, diff --git a/src/targets/adminInitiateAuth.test.ts b/src/targets/adminInitiateAuth.test.ts index 4c910ff4..6164d103 100644 --- a/src/targets/adminInitiateAuth.test.ts +++ b/src/targets/adminInitiateAuth.test.ts @@ -45,6 +45,7 @@ describe("AdminInitiateAuth target", () => { const existingUser = TDB.user(); mockUserPoolService.getUserByUsername.mockResolvedValue(existingUser); + mockUserPoolService.listUserGroupMembership.mockResolvedValue([]); const response = await adminInitiateAuth(TestContext, { AuthFlow: "ADMIN_USER_PASSWORD_AUTH", @@ -72,6 +73,7 @@ describe("AdminInitiateAuth target", () => { expect(mockTokenGenerator.generate).toHaveBeenCalledWith( TestContext, existingUser, + [], userPoolClient, { client: "metadata", @@ -92,6 +94,7 @@ describe("AdminInitiateAuth target", () => { }); mockUserPoolService.getUserByRefreshToken.mockResolvedValue(existingUser); + mockUserPoolService.listUserGroupMembership.mockResolvedValue([]); const response = await adminInitiateAuth(TestContext, { AuthFlow: "REFRESH_TOKEN_AUTH", @@ -120,6 +123,7 @@ describe("AdminInitiateAuth target", () => { expect(mockTokenGenerator.generate).toHaveBeenCalledWith( TestContext, existingUser, + [], userPoolClient, { client: "metadata", diff --git a/src/targets/adminInitiateAuth.ts b/src/targets/adminInitiateAuth.ts index 52425742..d8db177a 100644 --- a/src/targets/adminInitiateAuth.ts +++ b/src/targets/adminInitiateAuth.ts @@ -71,9 +71,12 @@ const adminUserPasswordAuthFlow = async ( throw new InvalidPasswordError(); } + const userGroups = await userPool.listUserGroupMembership(ctx, user); + const tokens = await services.tokenGenerator.generate( ctx, user, + userGroups, userPoolClient, req.ClientMetadata, "Authentication" @@ -124,9 +127,12 @@ const refreshTokenAuthFlow = async ( throw new NotAuthorizedError(); } + const userGroups = await userPool.listUserGroupMembership(ctx, user); + const tokens = await services.tokenGenerator.generate( ctx, user, + userGroups, userPoolClient, req.ClientMetadata, "RefreshTokens" diff --git a/src/targets/initiateAuth.test.ts b/src/targets/initiateAuth.test.ts index 128b13a9..340b592c 100644 --- a/src/targets/initiateAuth.test.ts +++ b/src/targets/initiateAuth.test.ts @@ -365,6 +365,7 @@ describe("InitiateAuth target", () => { IdToken: "id", RefreshToken: "refresh", }); + mockUserPoolService.listUserGroupMembership.mockResolvedValue([]); const output = await initiateAuth(TestContext, { ClientId: userPoolClient.ClientId, @@ -389,6 +390,7 @@ describe("InitiateAuth target", () => { expect(mockTokenGenerator.generate).toHaveBeenCalledWith( TestContext, user, + [], userPoolClient, undefined, "Authentication" @@ -411,6 +413,7 @@ describe("InitiateAuth target", () => { IdToken: "id", RefreshToken: "refresh", }); + mockUserPoolService.listUserGroupMembership.mockResolvedValue([]); const output = await initiateAuth(TestContext, { ClientId: userPoolClient.ClientId, @@ -433,6 +436,7 @@ describe("InitiateAuth target", () => { expect(mockTokenGenerator.generate).toHaveBeenCalledWith( TestContext, user, + [], userPoolClient, undefined, "Authentication" @@ -539,6 +543,7 @@ describe("InitiateAuth target", () => { }); mockUserPoolService.getUserByRefreshToken.mockResolvedValue(existingUser); + mockUserPoolService.listUserGroupMembership.mockResolvedValue([]); const response = await initiateAuth(TestContext, { AuthFlow: "REFRESH_TOKEN_AUTH", @@ -560,6 +565,7 @@ describe("InitiateAuth target", () => { expect(mockTokenGenerator.generate).toHaveBeenCalledWith( TestContext, existingUser, + [], userPoolClient, undefined, "RefreshTokens" diff --git a/src/targets/initiateAuth.ts b/src/targets/initiateAuth.ts index 28427f06..498846e9 100644 --- a/src/targets/initiateAuth.ts +++ b/src/targets/initiateAuth.ts @@ -97,9 +97,12 @@ const verifyPasswordChallenge = async ( userPoolClient: AppClient, services: InitiateAuthServices ): Promise => { + const userGroups = await userPool.listUserGroupMembership(ctx, user); + const tokens = await services.tokenGenerator.generate( ctx, user, + userGroups, userPoolClient, // The docs for the pre-token generation trigger only say that the ClientMetadata is passed as part of the // AdminRespondToAuthChallenge and RespondToAuthChallenge triggers. @@ -236,9 +239,12 @@ const refreshTokenAuthFlow = async ( throw new NotAuthorizedError(); } + const userGroups = await userPool.listUserGroupMembership(ctx, user); + const tokens = await services.tokenGenerator.generate( ctx, user, + userGroups, userPoolClient, // The docs for the pre-token generation trigger only say that the ClientMetadata is passed as part of the // AdminRespondToAuthChallenge and RespondToAuthChallenge triggers. diff --git a/src/targets/respondToAuthChallenge.test.ts b/src/targets/respondToAuthChallenge.test.ts index 32d90685..274b22b0 100644 --- a/src/targets/respondToAuthChallenge.test.ts +++ b/src/targets/respondToAuthChallenge.test.ts @@ -139,6 +139,7 @@ describe("RespondToAuthChallenge target", () => { IdToken: "id", RefreshToken: "refresh", }); + mockUserPoolService.listUserGroupMembership.mockResolvedValue([]); const output = await respondToAuthChallenge(TestContext, { ClientId: userPoolClient.ClientId, @@ -162,6 +163,7 @@ describe("RespondToAuthChallenge target", () => { expect(mockTokenGenerator.generate).toHaveBeenCalledWith( TestContext, user, + [], userPoolClient, { client: "metadata", @@ -274,6 +276,7 @@ describe("RespondToAuthChallenge target", () => { IdToken: "id", RefreshToken: "refresh", }); + mockUserPoolService.listUserGroupMembership.mockResolvedValue([]); const output = await respondToAuthChallenge(TestContext, { ClientId: userPoolClient.ClientId, @@ -297,6 +300,7 @@ describe("RespondToAuthChallenge target", () => { expect(mockTokenGenerator.generate).toHaveBeenCalledWith( TestContext, user, + [], userPoolClient, { client: "metadata" }, "Authentication" diff --git a/src/targets/respondToAuthChallenge.ts b/src/targets/respondToAuthChallenge.ts index 43ffb8c9..485da3ea 100644 --- a/src/targets/respondToAuthChallenge.ts +++ b/src/targets/respondToAuthChallenge.ts @@ -93,11 +93,14 @@ export const RespondToAuthChallenge = }); } + const userGroups = await userPool.listUserGroupMembership(ctx, user); + return { ChallengeParameters: {}, AuthenticationResult: await tokenGenerator.generate( ctx, user, + userGroups, userPoolClient, req.ClientMetadata, "Authentication"