diff --git a/framework/src/modules/fee/cc_method.ts b/framework/src/modules/fee/cc_method.ts index c88eabc8ba4..986b5455637 100644 --- a/framework/src/modules/fee/cc_method.ts +++ b/framework/src/modules/fee/cc_method.ts @@ -14,7 +14,7 @@ import { BaseCCMethod } from '../interoperability/base_cc_method'; import { CrossChainMessageContext } from '../interoperability/types'; -import { InteroperabilityMethod, TokenMethod } from './types'; +import { InteroperabilityMethod, ModuleConfig, TokenMethod } from './types'; import { NamedRegistry } from '../named_registry'; import { CONTEXT_STORE_KEY_AVAILABLE_CCM_FEE } from './constants'; import { getContextStoreBigInt } from '../../state_machine'; @@ -27,12 +27,17 @@ export class FeeInteroperableMethod extends BaseCCMethod { private _interopMethod!: InteroperabilityMethod; private _tokenMethod!: TokenMethod; + private _feePoolAddress?: Buffer; public constructor(stores: NamedRegistry, events: NamedRegistry, moduleName: string) { super(stores, events); this._moduleName = moduleName; } + public init(config: ModuleConfig): void { + this._feePoolAddress = config.feePoolAddress; + } + public addDependencies(interoperabilityMethod: InteroperabilityMethod, tokenMethod: TokenMethod) { this._interopMethod = interoperabilityMethod; this._tokenMethod = tokenMethod; @@ -71,12 +76,27 @@ export class FeeInteroperableMethod extends BaseCCMethod { CONTEXT_STORE_KEY_AVAILABLE_CCM_FEE, ); const burntAmount = ctx.ccm.fee - availableFee; - await this._tokenMethod.burn( - ctx.getMethodContext(), - ctx.transaction.senderAddress, - messageTokenID, - burntAmount, - ); + + if ( + this._feePoolAddress && + (await this._tokenMethod.userAccountExists(ctx, this._feePoolAddress, messageTokenID)) + ) { + await this._tokenMethod.transfer( + ctx.getMethodContext(), + ctx.transaction.senderAddress, + this._feePoolAddress, + messageTokenID, + burntAmount, + ); + } else { + await this._tokenMethod.burn( + ctx.getMethodContext(), + ctx.transaction.senderAddress, + messageTokenID, + burntAmount, + ); + } + const { ccmID } = getEncodedCCMAndID(ctx.ccm); this.events.get(RelayerFeeProcessedEvent).log(ctx, { diff --git a/framework/src/modules/fee/module.ts b/framework/src/modules/fee/module.ts index c4cf27478e0..937c16e21f2 100644 --- a/framework/src/modules/fee/module.ts +++ b/framework/src/modules/fee/module.ts @@ -96,6 +96,7 @@ export class FeeModule extends BaseInteroperableModule { }; this.method.init(moduleConfig); this.endpoint.init(moduleConfig); + this.crossChainMethod.init(moduleConfig); this._tokenID = moduleConfig.feeTokenID; this._minFeePerByte = moduleConfig.minFeePerByte; diff --git a/framework/test/unit/modules/fee/cc_method.spec.ts b/framework/test/unit/modules/fee/cc_method.spec.ts index 99bd46d21a9..ba6c96a622f 100644 --- a/framework/test/unit/modules/fee/cc_method.spec.ts +++ b/framework/test/unit/modules/fee/cc_method.spec.ts @@ -33,12 +33,14 @@ describe('FeeInteroperableMethod', () => { status: 1, }; const messageFeeTokenID = Buffer.from('0000000000000002', 'hex'); + const feePoolAddress = utils.getRandomBytes(20); let feeMethod: FeeInteroperableMethod; let context: CrossChainMessageContext; beforeEach(() => { feeMethod = new FeeInteroperableMethod(feeModule.stores, feeModule.events, feeModule.name); + feeMethod.init({ feePoolAddress } as any); feeMethod.addDependencies( { getMessageFeeTokenID: jest.fn().mockResolvedValue(messageFeeTokenID), @@ -83,10 +85,11 @@ describe('FeeInteroperableMethod', () => { beforeEach(async () => { jest.spyOn(feeMethod['events'].get(RelayerFeeProcessedEvent), 'log'); context.contextStore.set(CONTEXT_STORE_KEY_AVAILABLE_CCM_FEE, availableFee); - await feeMethod.afterCrossChainCommandExecute(context); }); - it('should unlock ccm fee from sender', () => { + it('should unlock ccm fee from sender', async () => { + await feeMethod.afterCrossChainCommandExecute(context); + expect(feeMethod['_tokenMethod'].unlock).toHaveBeenCalledWith( expect.anything(), context.transaction.senderAddress, @@ -96,7 +99,23 @@ describe('FeeInteroperableMethod', () => { ); }); - it('should burn the used fee', () => { + it('should transfer the used fee to fee pool address if it exists and is initialized for the cross-chain message fee token', async () => { + feeMethod['_tokenMethod'].userAccountExists = jest.fn().mockResolvedValue(true); + await feeMethod.afterCrossChainCommandExecute(context); + + expect(feeMethod['_tokenMethod'].transfer).toHaveBeenCalledWith( + expect.anything(), + context.transaction.senderAddress, + feePoolAddress, + messageFeeTokenID, + ccm.fee - availableFee, + ); + }); + + it('should burn the used fee', async () => { + feeMethod['_tokenMethod'].userAccountExists = jest.fn().mockResolvedValue(false); + await feeMethod.afterCrossChainCommandExecute(context); + expect(feeMethod['_tokenMethod'].burn).toHaveBeenCalledWith( expect.anything(), context.transaction.senderAddress, @@ -105,7 +124,9 @@ describe('FeeInteroperableMethod', () => { ); }); - it('should log event', () => { + it('should log event', async () => { + await feeMethod.afterCrossChainCommandExecute(context); + expect(context.eventQueue.getEvents()).toHaveLength(1); expect(feeMethod['events'].get(RelayerFeeProcessedEvent).log).toHaveBeenCalledWith( expect.anything(), @@ -118,7 +139,9 @@ describe('FeeInteroperableMethod', () => { ); }); - it('should reset the context store', () => { + it('should reset the context store', async () => { + await feeMethod.afterCrossChainCommandExecute(context); + expect(context.contextStore.get(CONTEXT_STORE_KEY_AVAILABLE_CCM_FEE)).toBeUndefined(); }); });