Skip to content

Commit

Permalink
fix(wallet): prevent cointype mismatch in getTokenBalancesRegistry (u…
Browse files Browse the repository at this point in the history
…plift to 1.60.x) (#20608)

fix(wallet): prevent cointype mismatch in getTokenBalancesRegistry (#20589)

* fix(wallet): prevent cointype mismatch in getTokenBalancesRegistry

* review(JL): fixes
  • Loading branch information
onyb authored Oct 25, 2023
1 parent 477eeed commit 6b4bcf2
Show file tree
Hide file tree
Showing 5 changed files with 134 additions and 192 deletions.
42 changes: 11 additions & 31 deletions components/brave_wallet_ui/common/hooks/use-balances-fetcher.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -2,60 +2,40 @@
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at https://mozilla.org/MPL/2.0/.
import { useMemo } from 'react'
import { skipToken } from '@reduxjs/toolkit/query/react'

// hooks
import { useSafeWalletSelector } from './use-safe-selector'
import { useGetTokenBalancesRegistryQuery } from '../slices/api.slice'
import { querySubscriptionOptions60s } from '../slices/constants'

// Types / constants
import { BraveWallet, CoinTypes } from '../../constants/types'
// Types
import { BraveWallet } from '../../constants/types'
import { WalletSelectors } from '../selectors'

// Utils
import { networkSupportsAccount } from '../../utils/network-utils'

interface Arg {
networks: BraveWallet.NetworkInfo[]
accounts: BraveWallet.AccountInfo[]
}

const coinTypesMapping = {
[BraveWallet.CoinType.SOL]: CoinTypes.SOL,
[BraveWallet.CoinType.ETH]: CoinTypes.ETH,
[BraveWallet.CoinType.FIL]: CoinTypes.FIL,
[BraveWallet.CoinType.BTC]: CoinTypes.BTC,
}

export const useBalancesFetcher = (arg: Arg | typeof skipToken) => {
// redux
const isWalletLocked = useSafeWalletSelector(WalletSelectors.isWalletLocked)
const isWalletCreated = useSafeWalletSelector(WalletSelectors.isWalletCreated)
const hasInitialized = useSafeWalletSelector(WalletSelectors.hasInitialized)

const args = useMemo(() => arg !== skipToken && arg.accounts && arg.networks
? arg.accounts.flatMap(account =>
arg.networks
.filter(network => networkSupportsAccount(network, account.accountId))
.map(network => ({
accountId: account.accountId,
chainId: network.chainId,
coin: coinTypesMapping[network.coin]
}))
.filter(({ coin }) => coin !== undefined)
)
: skipToken,
[arg]
)

return useGetTokenBalancesRegistryQuery(
args !== skipToken &&
arg !== skipToken &&
!isWalletLocked &&
isWalletCreated &&
hasInitialized
? args
hasInitialized &&
arg.accounts.length &&
arg.networks.length
? {
accounts: arg.accounts.map(account => account.accountId),
networks: arg.networks
.map(({ chainId, coin }) => ({ chainId, coin }))
}
: skipToken,
querySubscriptionOptions60s
)
Expand Down
1 change: 0 additions & 1 deletion components/brave_wallet_ui/common/slices/api-base.slice.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ export function createWalletApiBase () {
...cacher.defaultTags,
'AccountInfos',
'AccountTokenCurrentBalance',
'CombinedTokenBalanceForAllAccounts',
'TokenBalancesForChainId',
'TokenBalances',
'HardwareAccountDiscoveryBalance',
Expand Down
236 changes: 107 additions & 129 deletions components/brave_wallet_ui/common/slices/api.slice.ts
Original file line number Diff line number Diff line change
Expand Up @@ -110,16 +110,15 @@ import { coingeckoEndpoints } from './endpoints/coingecko-endpoints'
import {
tokenSuggestionsEndpoints //
} from './endpoints/token_suggestions.endpoints'
import {
coinTypesMapping //
} from './constants'

type GetAccountTokenCurrentBalanceArg = {
accountId: BraveWallet.AccountId
token: GetBlockchainTokenIdArg
}

type GetCombinedTokenBalanceForAllAccountsArg =
GetAccountTokenCurrentBalanceArg['token'] &
Pick<BraveWallet.BlockchainToken, 'coin'>

type GetSPLTokenBalancesForAddressAndChainIdArg = {
accountId: BraveWallet.AccountId
chainId: string
Expand All @@ -145,13 +144,14 @@ type GetTokenBalancesForChainIdArg =
| GetTokenBalancesForAddressAndChainIdArg

type GetTokenBalancesRegistryArg = {
accountId: BraveWallet.AccountId,
chainId: string,
coin:
| typeof CoinTypes.ETH
| typeof CoinTypes.SOL
| typeof CoinTypes.FIL
| typeof CoinTypes.BTC
accounts: BraveWallet.AccountId[]
networks: Array<
Pick<
BraveWallet.NetworkInfo,
| 'chainId'
| 'coin'
>
>
}

type GetHardwareAccountDiscoveryBalanceArg = {
Expand Down Expand Up @@ -914,7 +914,6 @@ export function createWalletApi () {
token.chainId,
accountId
)

if (errorMessage || balance === null) {
console.log(`getBalance error: ${errorMessage}`)
return {
Expand Down Expand Up @@ -990,64 +989,6 @@ export function createWalletApi () {
token
)
}),
getCombinedTokenBalanceForAllAccounts: query<
string,
GetCombinedTokenBalanceForAllAccountsArg
>({
queryFn: async (asset, { dispatch }, extraOptions, baseQuery) => {
const { cache } = baseQuery(undefined)
const { accounts } = await cache.getAllAccounts()

const accountsForAssetCoinType = accounts.filter(
(account) => account.accountId.coin === asset.coin
)

const accountTokenBalancesForChainId: string[] = await mapLimit(
accountsForAssetCoinType,
10,
async (account: BraveWallet.AccountInfo) => {
const balance = await dispatch(
walletApi.endpoints.getAccountTokenCurrentBalance.initiate({
accountId: account.accountId,
token: {
chainId: asset.chainId,
coin: asset.coin,
contractAddress: asset.contractAddress,
isErc721: asset.isErc721,
isNft: asset.isNft,
tokenId: asset.tokenId
}
})
).unwrap()

return balance ?? ''
}
)

// return a '0' balance until user has created a FIL or SOL account
if (accountTokenBalancesForChainId.length === 0) {
return {
data: '0'
}
}

const aggregatedAmount = accountTokenBalancesForChainId.reduce(
function (totalBalance, itemBalance) {
return itemBalance !== ''
? new Amount(totalBalance).plus(itemBalance).format()
: itemBalance ?? '0'
},
'0'
)

return {
data: aggregatedAmount
}
},
providesTags: cacher.cacheByBlockchainTokenArg(
'CombinedTokenBalanceForAllAccounts'
)
}),
getTokenBalancesForChainId: query<
TokenBalancesRegistry,
GetTokenBalancesForChainIdArg[]
Expand Down Expand Up @@ -1283,7 +1224,7 @@ export function createWalletApi () {
}),
getTokenBalancesRegistry: query<
TokenBalancesRegistry,
GetTokenBalancesRegistryArg[]
GetTokenBalancesRegistryArg
>({
queryFn: async (
args,
Expand All @@ -1292,64 +1233,100 @@ export function createWalletApi () {
baseQuery
) => {
try {
const { getUserTokensRegistry, getTokenBalancesForChainId } =
walletApi.endpoints

const userTokens = await dispatch(
walletApi.endpoints.getUserTokensRegistry.initiate()
getUserTokensRegistry.initiate()
).unwrap()

const registryArray = await mapLimit(
args,
10,
async (arg: GetTokenBalancesRegistryArg) => {
const partialRegistry: TokenBalancesRegistry = await dispatch(
walletApi.endpoints.getTokenBalancesForChainId.initiate(
arg.coin === CoinTypes.SOL
? [
{
accountId: arg.accountId,
coin: arg.coin,
chainId: arg.chainId
}
]
: [
{
accountId: arg.accountId,
coin: arg.coin,
chainId: arg.chainId,
tokens: getEntitiesListFromEntityState(
userTokens,
userTokens.idsByChainId[
networkEntityAdapter.selectId({
coin: arg.coin,
chainId: arg.chainId
})
]
)
}
],
{
forceRefetch: true
const tokenBalancesRegistryArray = await mapLimit(
args.accounts,
3,
async (accountId: BraveWallet.AccountId) => {
const networks = args.networks.filter(
(network) => network.coin === accountId.coin)

if (networks.length === 0) {
return {}
}

const registryArray = await mapLimit(
networks,
3,
async (
network: Pick<
BraveWallet.NetworkInfo,
'coin' | 'chainId'
>
) => {
const partialRegistryQuery = dispatch(
getTokenBalancesForChainId.initiate(
network.coin === CoinTypes.SOL
? [
{
accountId,
coin: CoinTypes.SOL,
chainId: network.chainId
}
]
: coinTypesMapping[network.coin] ? [
{
accountId,
coin: coinTypesMapping[network.coin],
chainId: network.chainId,
tokens:
getEntitiesListFromEntityState(
userTokens,
userTokens.idsByChainId[
networkEntityAdapter.selectId({
coin: network.coin,
chainId: network.chainId
})
]
)
}
]: [],
{
forceRefetch: true
}
)
)

try {
const partialRegistry: TokenBalancesRegistry =
await partialRegistryQuery.unwrap()
return partialRegistry
} catch (error) {
console.error(error)
return {}
}
)
).unwrap()
}
)

return partialRegistry
return registryArray.reduce((acc, curr) => {
for (const [uniqueKey, chainIds] of Object.entries(
curr
)) {
if (!acc.hasOwnProperty(uniqueKey)) {
acc[uniqueKey] = chainIds
} else {
acc[uniqueKey] = {
...acc[uniqueKey],
...chainIds
}
}
}
return acc
}, {})
}
)

return {
data: registryArray.reduce((acc, curr) => {
for (const [uniqueKey, chainIds] of Object.entries(curr)) {
if (!acc.hasOwnProperty(uniqueKey)) {
acc[uniqueKey] = chainIds
} else {
acc[uniqueKey] = {
...acc[uniqueKey],
...chainIds
}
}
}
return acc
}, {})
data: tokenBalancesRegistryArray.reduce(
(acc, curr) => ({ ...acc, ...curr }),
{}
)
}
} catch (error) {
return handleEndpointError(
Expand All @@ -1359,15 +1336,18 @@ export function createWalletApi () {
)
}
},
providesTags: (result, err, args) =>
err
? ['TokenBalances', 'UNKNOWN_ERROR']
: args.map((arg) => ({
providesTags: (result, err, args) => err
? ['TokenBalances', 'UNKNOWN_ERROR']
: args.accounts.flatMap((accountId) => {
const networkKeys = args.networks
.filter((network) => accountId.coin === network.coin)
.map((network) => network.chainId)
const accountKey = getAccountBalancesKey(accountId)
return networkKeys.map((networkKey) => ({
type: 'TokenBalances',
id: `${getAccountBalancesKey(arg.accountId)}-${arg.coin}-${
arg.chainId
}`
id: `${accountKey}-${networkKey}`
}))
})
}),
getHardwareAccountDiscoveryBalance: query<
string,
Expand Down Expand Up @@ -3095,7 +3075,6 @@ export const {
useGetAutopinEnabledQuery,
useGetBuyUrlQuery,
useGetCoingeckoIdQuery,
useGetCombinedTokenBalanceForAllAccountsQuery,
useGetDefaultFiatCurrencyQuery,
useGetEthTokenDecimalsQuery,
useGetEthTokenSymbolQuery,
Expand Down Expand Up @@ -3141,7 +3120,6 @@ export const {
useLazyGetAccountTokenCurrentBalanceQuery,
useLazyGetAddressByteCodeQuery,
useLazyGetBuyUrlQuery,
useLazyGetCombinedTokenBalanceForAllAccountsQuery,
useLazyGetDefaultFiatCurrencyQuery,
useLazyGetERC721MetadataQuery,
useLazyGetEVMTransactionSimulationQuery,
Expand Down
Loading

0 comments on commit 6b4bcf2

Please sign in to comment.