diff --git a/lib/msal-node/src/cache/distributed/DistributedCachePlugin.ts b/lib/msal-node/src/cache/distributed/DistributedCachePlugin.ts index 9157a5b816..9ab6c4817f 100644 --- a/lib/msal-node/src/cache/distributed/DistributedCachePlugin.ts +++ b/lib/msal-node/src/cache/distributed/DistributedCachePlugin.ts @@ -52,17 +52,21 @@ export class DistributedCachePlugin implements ICachePlugin { AccountEntity.isAccountEntity(value as object) ); + let partitionKey: string; + if (accountEntities.length > 0) { const accountEntity = accountEntities[0] as AccountEntity; - const partitionKey = await this.partitionManager.extractKey( + partitionKey = await this.partitionManager.extractKey( accountEntity ); - - await this.client.set( - partitionKey, - cacheContext.tokenCache.serialize() - ); + } else { + partitionKey = await this.partitionManager.getKey(); } + + await this.client.set( + partitionKey, + cacheContext.tokenCache.serialize() + ); } } } diff --git a/lib/msal-node/test/cache/cacheConstants.ts b/lib/msal-node/test/cache/cacheConstants.ts index afef211915..fbf17099d2 100644 --- a/lib/msal-node/test/cache/cacheConstants.ts +++ b/lib/msal-node/test/cache/cacheConstants.ts @@ -7,9 +7,6 @@ import { CacheHelpers, } from "@azure/msal-common"; -export const MOCK_PARTITION_KEY = "mock_partition_key"; -export const MOCK_CACHE_STRING = "mock_cache_string"; - // mock tokens export const mockAccessTokenEntity_1: AccessTokenEntity = { homeAccountId: "uid.utid", @@ -140,3 +137,30 @@ export const MockCache = { amdt: mockCache.createMockAmdt(), amdtKey: CacheHelpers.generateAppMetadataKey(mockCache.createMockAmdt()), }; + +// mock cache storage +export const MOCK_PARTITION_KEY = MockCache.acc.homeAccountId; +export const MOCK_CACHE_STORAGE = { + [MOCK_PARTITION_KEY]: { + Account: { + [`${MOCK_PARTITION_KEY}-login.windows.net`]: mockAccountEntity, + }, + IdToken: { + [`${MOCK_PARTITION_KEY}-login.windows.net-idtoken`]: + mockIdTokenEntity, + }, + AccessToken: { + [`${MOCK_PARTITION_KEY}-login.windows.net-accesstoken`]: + mockAccessTokenEntity_1, + }, + RefreshToken: { + [`${MOCK_PARTITION_KEY}-login.windows.net-refreshtoken`]: + mockRefreshTokenEntity, + }, + AppMetadata: { + [`${MOCK_PARTITION_KEY}-login.windows.net-appmetadata`]: + mockAppMetaDataEntity, + }, + }, +}; +export const MOCK_CACHE_STRING = () => JSON.stringify(MOCK_CACHE_STORAGE); diff --git a/lib/msal-node/test/cache/distributed/DistributedCachePlugin.spec.ts b/lib/msal-node/test/cache/distributed/DistributedCachePlugin.spec.ts index 2fbb405e8e..fc492a494d 100644 --- a/lib/msal-node/test/cache/distributed/DistributedCachePlugin.spec.ts +++ b/lib/msal-node/test/cache/distributed/DistributedCachePlugin.spec.ts @@ -1,32 +1,50 @@ -import { DistributedCachePlugin } from "../../../src/cache/distributed/DistributedCachePlugin"; +import { DistributedCachePlugin } from "../../../src/cache/distributed/DistributedCachePlugin.js"; import { AccountEntity, ICachePlugin, TokenCacheContext, } from "@azure/msal-common"; -import { TokenCache } from "../../../src/cache/TokenCache"; +import { TokenCache } from "../../../src/cache/TokenCache.js"; import { MockCache, MOCK_CACHE_STRING, MOCK_PARTITION_KEY, -} from "../cacheConstants"; -import { IPartitionManager } from "../../../src/cache/distributed/IPartitionManager"; -import { ICacheClient } from "../../../src/cache/distributed/ICacheClient"; + MOCK_CACHE_STORAGE, +} from "../cacheConstants.js"; +import { IPartitionManager } from "../../../src/cache/distributed/IPartitionManager.js"; +import { ICacheClient } from "../../../src/cache/distributed/ICacheClient.js"; describe("Distributed Cache Plugin Tests for msal-node", () => { let distributedCachePluginInstance: ICachePlugin; + let cacheHasChanged = true; const tokenCache = { - serialize: jest - .fn() - .mockImplementation((): string => MOCK_CACHE_STRING), + serialize: jest.fn().mockImplementation((): string => { + cacheHasChanged = false; + return MOCK_CACHE_STRING(); + }), deserialize: jest.fn(), getKVStore: jest.fn().mockImplementation(() => ({ [MockCache.idTKey]: MockCache.idT, [MockCache.accKey]: MockCache.acc, })), + getAllAccounts: jest + .fn() + .mockImplementation(async () => [MockCache.acc]), + removeAccount: jest.fn().mockImplementation(async () => { + const cacheStorage = MOCK_CACHE_STORAGE; + + (cacheStorage[MOCK_PARTITION_KEY].Account as any) = {}; + (cacheStorage[MOCK_PARTITION_KEY].IdToken as any) = {}; + (cacheStorage[MOCK_PARTITION_KEY].AccessToken as any) = {}; + (cacheStorage[MOCK_PARTITION_KEY].RefreshToken as any) = {}; + (cacheStorage[MOCK_PARTITION_KEY].AppMetadata as any) = {}; + + cacheHasChanged = true; + }), + hasChanged: jest.fn().mockImplementation(() => cacheHasChanged), } as unknown as TokenCache; const tokenCacheContext = { - cacheHasChanged: true, + cacheHasChanged, tokenCache, } as unknown as TokenCacheContext; const partitionManager = { @@ -46,7 +64,7 @@ describe("Distributed Cache Plugin Tests for msal-node", () => { get: jest .fn() .mockImplementation( - async (_: string): Promise => MOCK_CACHE_STRING + async (_: string): Promise => MOCK_CACHE_STRING() ), set: jest .fn() @@ -75,7 +93,9 @@ describe("Distributed Cache Plugin Tests for msal-node", () => { // Confirm the intended effects expect(partitionManager.getKey).toHaveBeenCalled(); expect(cacheClient.get).toHaveBeenCalledWith(MOCK_PARTITION_KEY); - expect(tokenCache.deserialize).toHaveBeenCalledWith(MOCK_CACHE_STRING); + expect(tokenCache.deserialize).toHaveBeenCalledWith( + MOCK_CACHE_STRING() + ); }); it("properly handles afterCacheAccess", async () => { @@ -90,7 +110,29 @@ describe("Distributed Cache Plugin Tests for msal-node", () => { expect(tokenCache.serialize).toHaveBeenCalled(); expect(cacheClient.set).toHaveBeenCalledWith( MockCache.acc.homeAccountId, - MOCK_CACHE_STRING + MOCK_CACHE_STRING() ); }); + + it("removes the specified account from the cache", async () => { + const accounts = await tokenCache.getAllAccounts(); + await tokenCache.removeAccount(accounts[0]); + expect(tokenCache.hasChanged()).toEqual(true); + + const tokenCacheAfterSerialization = JSON.parse(tokenCache.serialize()); + + expect(tokenCache.hasChanged()).toEqual(false); + expect( + tokenCacheAfterSerialization[MOCK_PARTITION_KEY].Account + ).toEqual({}); + expect( + tokenCacheAfterSerialization[MOCK_PARTITION_KEY].RefreshToken + ).toEqual({}); + expect( + tokenCacheAfterSerialization[MOCK_PARTITION_KEY].AccessToken + ).toEqual({}); + expect( + tokenCacheAfterSerialization[MOCK_PARTITION_KEY].IdToken + ).toEqual({}); + }); });