Skip to content

Commit

Permalink
feat: include user attributes in user migration lambda call
Browse files Browse the repository at this point in the history
  • Loading branch information
jagregory committed Apr 13, 2020
1 parent bc27b86 commit dabed92
Show file tree
Hide file tree
Showing 6 changed files with 154 additions and 57 deletions.
78 changes: 44 additions & 34 deletions src/services/lambda.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -40,46 +40,54 @@ describe("Lambda function invoker", () => {
triggerSource: "UserMigration_Authentication",
username: "username",
userPoolId: "userPoolId",
userAttributes: {},
})
).rejects.toEqual(new Error("UserMigration trigger not configured"));
});

it("invokes the lambda", async () => {
const response = Promise.resolve({
StatusCode: 200,
Payload: '{ "some": "json" }',
});
mockLambdaClient.invoke.mockReturnValue({
promise: () => response,
} as any);
const lambda = createLambda(
{
UserMigration: "MyLambdaName",
},
mockLambdaClient
);

await lambda.invoke("UserMigration", {
clientId: "clientId",
password: "password",
triggerSource: "UserMigration_Authentication",
username: "username",
userPoolId: "userPoolId",
});
describe("UserMigration_Authentication", () => {
it("invokes the lambda", async () => {
const response = Promise.resolve({
StatusCode: 200,
Payload: '{ "some": "json" }',
});
mockLambdaClient.invoke.mockReturnValue({
promise: () => response,
} as any);
const lambda = createLambda(
{
UserMigration: "MyLambdaName",
},
mockLambdaClient
);

expect(mockLambdaClient.invoke).toHaveBeenCalledWith({
FunctionName: "MyLambdaName",
InvocationType: "RequestResponse",
Payload: JSON.stringify({
version: 0,
userName: "username",
callerContext: { awsSdkVersion: "2.656.0", clientId: "clientId" },
region: "local",
userPoolId: "userPoolId",
await lambda.invoke("UserMigration", {
clientId: "clientId",
password: "password",
triggerSource: "UserMigration_Authentication",
request: { userAttributes: {} },
response: {},
}),
username: "username",
userPoolId: "userPoolId",
userAttributes: {},
});

expect(mockLambdaClient.invoke).toHaveBeenCalledWith({
FunctionName: "MyLambdaName",
InvocationType: "RequestResponse",
Payload: JSON.stringify({
version: 0,
userName: "username",
callerContext: { awsSdkVersion: "2.656.0", clientId: "clientId" },
region: "local",
userPoolId: "userPoolId",
triggerSource: "UserMigration_Authentication",
request: {
userAttributes: {},
password: "password",
validationData: {},
},
response: {},
}),
});
});
});

Expand All @@ -105,6 +113,7 @@ describe("Lambda function invoker", () => {
triggerSource: "UserMigration_Authentication",
username: "username",
userPoolId: "userPoolId",
userAttributes: {},
});

expect(result).toEqual("value");
Expand All @@ -131,6 +140,7 @@ describe("Lambda function invoker", () => {
triggerSource: "UserMigration_Authentication",
username: "username",
userPoolId: "userPoolId",
userAttributes: {},
});

expect(result).toEqual("value");
Expand Down
8 changes: 7 additions & 1 deletion src/services/lambda.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ interface UserMigrationEvent {
clientId: string;
username: string;
password: string;
userAttributes: Record<string, string>;
triggerSource: "UserMigration_Authentication";
}

Expand Down Expand Up @@ -49,11 +50,16 @@ export const createLambda: CreateLambda = (config, lambdaClient) => ({
userPoolId: event.userPoolId,
triggerSource: event.triggerSource,
request: {
userAttributes: {},
userAttributes: event.userAttributes,
},
response: {},
};

if (event.triggerSource === "UserMigration_Authentication") {
lambdaEvent.request.password = event.password;
lambdaEvent.request.validationData = {};
}

console.log(
`Invoking "${lambdaName}" with event`,
JSON.stringify(lambdaEvent, undefined, 2)
Expand Down
12 changes: 12 additions & 0 deletions src/services/triggers/userMigration.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ describe("UserMigration trigger", () => {
clientId: "clientId",
username: "username",
password: "password",
userAttributes: [],
})
).rejects.toBeInstanceOf(NotAuthorizedError);
});
Expand All @@ -51,6 +52,16 @@ describe("UserMigration trigger", () => {
clientId: "clientId",
username: "[email protected]",
password: "password",
userAttributes: [{ Name: "email", Value: "[email protected]" }],
});

expect(mockLambda.invoke).toHaveBeenCalledWith("UserMigration", {
clientId: "clientId",
password: "password",
triggerSource: "UserMigration_Authentication",
userAttributes: { email: "[email protected]" },
userPoolId: "userPoolId",
username: "[email protected]",
});

expect(user).not.toBeNull();
Expand All @@ -73,6 +84,7 @@ describe("UserMigration trigger", () => {
clientId: "clientId",
username: "[email protected]",
password: "password",
userAttributes: [],
});

expect(user).not.toBeNull();
Expand Down
7 changes: 5 additions & 2 deletions src/services/triggers/userMigration.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@ import * as uuid from "uuid";
import { NotAuthorizedError } from "../../errors";
import { UserPool } from "../index";
import { CognitoUserPoolResponse, Lambda } from "../lambda";
import { User } from "../userPool";
import { attributesToRecord, User, UserAttribute } from "../userPool";

export type UserMigrationTrigger = (params: {
userPoolId: string;
clientId: string;
username: string;
password: string;
userAttributes: readonly UserAttribute[];
}) => Promise<User>;

export const UserMigration = ({
Expand All @@ -22,6 +23,7 @@ export const UserMigration = ({
clientId,
username,
password,
userAttributes,
}): Promise<User> => {
let result: CognitoUserPoolResponse;

Expand All @@ -32,13 +34,14 @@ export const UserMigration = ({
username,
password,
triggerSource: "UserMigration_Authentication",
userAttributes: attributesToRecord(userAttributes),
});
} catch (ex) {
throw new NotAuthorizedError();
}

const user: User = {
Attributes: [{ Name: "email", Value: username }],
Attributes: userAttributes,
Enabled: true,
Password: password,
UserCreateDate: new Date().getTime(),
Expand Down
55 changes: 54 additions & 1 deletion src/services/userPool.test.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
import { CreateDataStore, DataStore } from "./dataStore";
import { createUserPool, UserPool } from "./userPool";
import {
attributesInclude,
attributesIncludeMatch,
attributesToRecord,
createUserPool,
UserAttribute,
UserPool,
} from "./userPool";

describe("User Pool", () => {
let mockDataStore: jest.Mocked<DataStore>;
Expand Down Expand Up @@ -154,4 +161,50 @@ describe("User Pool", () => {
}
);
});

describe("attributes", () => {
const attributes: readonly UserAttribute[] = [
{ Name: "sub", Value: "uuid" },
{ Name: "email", Value: "[email protected]" },
];

describe("attributesIncludeMatch", () => {
it("returns true if attribute exists in collection with matching name and value", () => {
expect(
attributesIncludeMatch("email", "[email protected]", attributes)
).toBe(true);
});

it("returns false if attribute exists in collection with matching name but not matching value", () => {
expect(attributesIncludeMatch("email", "invalid", attributes)).toBe(
false
);
});

it("returns false if attribute does not exist in collection", () => {
expect(attributesIncludeMatch("invalid", "invalid", attributes)).toBe(
false
);
});
});

describe("attributesInclude", () => {
it("returns true if attribute exists in collection with matching name", () => {
expect(attributesInclude("email", attributes)).toBe(true);
});

it("returns false if attribute does not exist in collection", () => {
expect(attributesInclude("invalid", attributes)).toBe(false);
});
});

describe("attributesToRecord", () => {
it("converts the attributes to a record", () => {
expect(attributesToRecord(attributes)).toEqual({
email: "[email protected]",
sub: "uuid",
});
});
});
});
});
51 changes: 32 additions & 19 deletions src/services/userPool.ts
Original file line number Diff line number Diff line change
@@ -1,15 +1,36 @@
import { CreateDataStore } from "./dataStore";

export interface UserAttribute {
Name: "sub" | "email" | "phone_number" | "preferred_username" | string;
Value: string;
}

export const attributesIncludeMatch = (
attributeName: string,
attributeValue: string,
attributes: readonly UserAttribute[]
) =>
!!attributes.find(
(x) => x.Name === attributeName && x.Value === attributeValue
);

export const attributesInclude = (
attributeName: string,
attributes: readonly UserAttribute[]
) => !!attributes.find((x) => x.Name === attributeName);

export const attributesToRecord = (
attributes: readonly UserAttribute[]
): Record<string, string> =>
attributes.reduce((acc, attr) => ({ ...acc, [attr.Name]: attr.Value }), {});

export interface User {
Username: string;
UserCreateDate: number;
UserLastModifiedDate: number;
Enabled: boolean;
UserStatus: "CONFIRMED" | "UNCONFIRMED" | "RESET_REQUIRED";
Attributes: readonly {
Name: "sub" | "email" | "phone_number" | "preferred_username" | string;
Value: string;
}[];
Attributes: readonly UserAttribute[];

// extra attributes for Cognito Local
Password: string;
Expand Down Expand Up @@ -38,17 +59,6 @@ export const createUserPool = async (
Options: options,
});

const attributeEquals = (
attributeName: string,
attributeValue: string,
user: User
) =>
!!user.Attributes.find(
(x) => x.Name === attributeName && x.Value === attributeValue
);
const hasAttribute = (attributeName: string, user: User) =>
!!user.Attributes.find((x) => x.Name === attributeName);

return {
async getUserPoolIdForClientId() {
// TODO: support user pool to client mapping
Expand All @@ -69,17 +79,20 @@ export const createUserPool = async (
const users = await dataStore.get<Record<string, User>>("Users");

for (const user of Object.values(users ?? {})) {
if (attributeEquals("sub", username, user)) {
if (attributesIncludeMatch("sub", username, user.Attributes)) {
return user;
}

if (aliasEmailEnabled && attributeEquals("email", username, user)) {
if (
aliasEmailEnabled &&
attributesIncludeMatch("email", username, user.Attributes)
) {
return user;
}

if (
aliasPhoneNumberEnabled &&
attributeEquals("phone_number", username, user)
attributesIncludeMatch("phone_number", username, user.Attributes)
) {
return user;
}
Expand All @@ -91,7 +104,7 @@ export const createUserPool = async (
async saveUser(user) {
console.log("saveUser", user);

const attributes = hasAttribute("sub", user)
const attributes = attributesInclude("sub", user.Attributes)
? user.Attributes
: [{ Name: "sub", Value: user.Username }, ...user.Attributes];

Expand Down

0 comments on commit dabed92

Please sign in to comment.