diff --git a/__tests__/rpcProvider.test.ts b/__tests__/rpcProvider.test.ts index 3da84184c..f54069776 100644 --- a/__tests__/rpcProvider.test.ts +++ b/__tests__/rpcProvider.test.ts @@ -6,6 +6,7 @@ import { CallData, Contract, RPC, + RpcProvider, TransactionExecutionStatus, stark, waitForTransactionOptions, @@ -334,4 +335,25 @@ describeIfRpc('RPCProvider', () => { expect(syncingStats).toMatchSchemaRef('GetSyncingStatsResponse'); }); }); + + describeIfRpc('Fallback node', () => { + beforeAll(() => {}); + test('Ensure fallback node is used when base node fails', async () => { + const fallbackProvider: RpcProvider = new RpcProvider({ + nodeUrl: 'Incorrect URL', + fallbackNodeUrls: [process.env.TEST_RPC_URL!], + }); + const blockNumber = await fallbackProvider.getBlockNumber(); + expect(typeof blockNumber).toBe('number'); + }); + }); + + test('Ensure fallback nodes are run until any of them succeeds', async () => { + const fallbackProvider: RpcProvider = new RpcProvider({ + nodeUrl: 'Incorrect URL', + fallbackNodeUrls: ['Another incorrect URL', process.env.TEST_RPC_URL!], + }); + const blockNumber = await fallbackProvider.getBlockNumber(); + expect(typeof blockNumber).toBe('number'); + }); }); diff --git a/src/channel/rpc_0_6.ts b/src/channel/rpc_0_6.ts index 05952e2bf..d1872b2ae 100644 --- a/src/channel/rpc_0_6.ts +++ b/src/channel/rpc_0_6.ts @@ -19,6 +19,7 @@ import { waitForTransactionOptions, } from '../types'; import { ETransactionVersion } from '../types/api'; +import assert from '../utils/assert'; import { CallData } from '../utils/calldata'; import { isSierra } from '../utils/contract'; import fetch from '../utils/fetchPonyfill'; @@ -36,8 +37,6 @@ const defaultOptions = { }; export class RpcChannel { - public nodeUrl: string; - public headers: object; readonly retries: number; @@ -52,15 +51,18 @@ export class RpcChannel { readonly waitMode: Boolean; // behave like web2 rpc and return when tx is processed + public nodeUrls: string[]; + constructor(optionsOrProvider?: RpcProviderOptions) { - const { nodeUrl, retries, headers, blockIdentifier, chainId, waitMode } = + const { nodeUrl, retries, headers, blockIdentifier, chainId, waitMode, fallbackNodeUrls } = optionsOrProvider || {}; + let primaryNode; if (Object.values(NetworkName).includes(nodeUrl as NetworkName)) { - this.nodeUrl = getDefaultNodeUrl(nodeUrl as NetworkName, optionsOrProvider?.default); + primaryNode = getDefaultNodeUrl(nodeUrl as NetworkName, optionsOrProvider?.default); } else if (nodeUrl) { - this.nodeUrl = nodeUrl; + primaryNode = nodeUrl; } else { - this.nodeUrl = getDefaultNodeUrl(undefined, optionsOrProvider?.default); + primaryNode = getDefaultNodeUrl(undefined, optionsOrProvider?.default); } this.retries = retries || defaultOptions.retries; this.headers = { ...defaultOptions.headers, ...headers }; @@ -68,22 +70,69 @@ export class RpcChannel { this.chainId = chainId; this.waitMode = waitMode || false; this.requestId = 0; + this.nodeUrls = [primaryNode, ...(fallbackNodeUrls || [])]; + } + + get nodeUrl() { + return this.nodeUrls[0]; + } + + set nodeUrl(url) { + this.nodeUrls[0] = url; } - public fetch(method: string, params?: object, id: string | number = 0) { + public fetch(url: string, method: string, params?: object, id: string | number = 0) { const rpcRequestBody: RPC.JRPC.RequestBody = { id, jsonrpc: '2.0', method, ...(params && { params }), }; - return fetch(this.nodeUrl, { + return fetch(url, { method: 'POST', body: stringify(rpcRequestBody), headers: this.headers as Record, }); } + protected async setPrimaryNode(node: string, index: number) { + // eslint-disable-next-line prefer-destructuring + this.nodeUrls[index] = this.nodeUrls[0]; + this.nodeUrls[0] = node; + } + + protected async fetchResponse(method: string, params?: object) { + const nodes = [...this.nodeUrls]; + const lastNode = nodes.pop(); + assert(lastNode !== undefined); + let response; + for (let i = 0; i < nodes.length - 1; i += 1) { + try { + // eslint-disable-next-line no-await-in-loop + response = await this.fetch(nodes[i], method, params); + + if (response.ok) { + this.setPrimaryNode(nodes[i], i); + return response; + } + } catch (error: any) { + /* empty */ + } + } + + // If all nodes fail return anything the last one returned + try { + response = await this.fetch(lastNode, method, params); + if (response.ok) { + this.setPrimaryNode(lastNode, this.nodeUrls.length - 1); + } + return response; + } catch (error: any) { + this.errorHandler(method, params, error?.response?.data, error); + throw error; + } + } + protected errorHandler(method: string, params: any, rpcError?: RPC.JRPC.Error, otherError?: any) { if (rpcError) { const { code, message, data } = rpcError; @@ -104,9 +153,10 @@ export class RpcChannel { method: T, params?: RPC.Methods[T]['params'] ): Promise { + const response = await this.fetchResponse(method, params); + try { - const rawResult = await this.fetch(method, params, (this.requestId += 1)); - const { error, result } = await rawResult.json(); + const { error, result } = await response.json(); this.errorHandler(method, params, error); return result as RPC.Methods[T]['result']; } catch (error: any) { diff --git a/src/provider/rpc.ts b/src/provider/rpc.ts index f71ee8f48..8cc82311f 100644 --- a/src/provider/rpc.ts +++ b/src/provider/rpc.ts @@ -43,7 +43,7 @@ export class RpcProvider implements ProviderInterface { } public fetch(method: string, params?: object, id: string | number = 0) { - return this.channel.fetch(method, params, id); + return this.channel.fetch(this.channel.nodeUrl, method, params, id); } public async getChainId() { diff --git a/src/types/provider/configuration.ts b/src/types/provider/configuration.ts index b4d614f6d..ce0bf46c8 100644 --- a/src/types/provider/configuration.ts +++ b/src/types/provider/configuration.ts @@ -11,4 +11,5 @@ export type RpcProviderOptions = { chainId?: StarknetChainId; default?: boolean; waitMode?: boolean; + fallbackNodeUrls?: string[]; };