From 1242ab7ad4d39f925dce00f5f0fa2dac46735f41 Mon Sep 17 00:00:00 2001 From: Sale Djenic Date: Fri, 30 Aug 2024 11:03:52 +0200 Subject: [PATCH] feat_: check for balances after each fees update --- services/wallet/router/router.go | 89 ++++++++++++++---------- services/wallet/router/router_test.go | 3 +- services/wallet/router/router_updates.go | 4 +- 3 files changed, 56 insertions(+), 40 deletions(-) diff --git a/services/wallet/router/router.go b/services/wallet/router/router.go index a5d19d45b95..601f86d36a7 100644 --- a/services/wallet/router/router.go +++ b/services/wallet/router/router.go @@ -74,6 +74,8 @@ type Router struct { pathProcessors map[string]pathprocessor.PathProcessor scheduler *async.Scheduler + activeBalanceMap sync.Map // map[string]*big.Int + activeRoutesMutex sync.Mutex activeRoutes *SuggestedRoutes @@ -119,6 +121,12 @@ func (r *Router) GetPathProcessors() map[string]pathprocessor.PathProcessor { return r.pathProcessors } +func (r *Router) SetTestBalanceMap(balanceMap map[string]*big.Int) { + for k, v := range balanceMap { + r.activeBalanceMap.Store(k, v) + } +} + func newSuggestedRoutes( uuid string, amountIn *big.Int, @@ -236,31 +244,29 @@ func (r *Router) SuggestedRoutes(ctx context.Context, input *requests.RouteInput return nil, errors.CreateErrorResponseFromError(err) } - balanceMap, err := r.getBalanceMapForTokenOnChains(ctx, input, selectedFromChains) + err = r.prepareBalanceMapForTokenOnChains(ctx, input, selectedFromChains) // return only if there are no balances, otherwise try to resolve the candidates for chains we know the balances for - if len(balanceMap) == 0 { + noBalanceOnAnyChain := true + r.activeBalanceMap.Range(func(key, value interface{}) bool { + if value.(*big.Int).Cmp(pathprocessor.ZeroBigIntValue) > 0 { + noBalanceOnAnyChain = false + return false + } + return true + }) + if noBalanceOnAnyChain { if err != nil { return nil, errors.CreateErrorResponseFromError(err) } - } else { - noBalanceOnAnyChain := true - for _, value := range balanceMap { - if value.Cmp(pathprocessor.ZeroBigIntValue) > 0 { - noBalanceOnAnyChain = false - break - } - } - if noBalanceOnAnyChain { - return nil, ErrNoPositiveBalance - } + return nil, ErrNoPositiveBalance } - candidates, processorErrors, err := r.resolveCandidates(ctx, input, selectedFromChains, selectedToChains, balanceMap) + candidates, processorErrors, err := r.resolveCandidates(ctx, input, selectedFromChains, selectedToChains) if err != nil { return nil, errors.CreateErrorResponseFromError(err) } - suggestedRoutes, err = r.resolveRoutes(ctx, input, candidates, balanceMap) + suggestedRoutes, err = r.resolveRoutes(ctx, input, candidates) if err == nil && (suggestedRoutes == nil || len(suggestedRoutes.Best) == 0) { // No best route found, but no error given. @@ -300,15 +306,16 @@ func (r *Router) SuggestedRoutes(ctx context.Context, input *requests.RouteInput return suggestedRoutes, mapError(err) } -// getBalanceMapForTokenOnChains returns the balance map for passed address, where the key is in format "chainID-tokenSymbol" and +// prepareBalanceMapForTokenOnChains prepares the balance map for passed address, where the key is in format "chainID-tokenSymbol" and // value is the balance of the token. Native token (EHT) is always added to the balance map. -func (r *Router) getBalanceMapForTokenOnChains(ctx context.Context, input *requests.RouteInputParams, selectedFromChains []*params.Network) (balanceMap map[string]*big.Int, err error) { +func (r *Router) prepareBalanceMapForTokenOnChains(ctx context.Context, input *requests.RouteInputParams, selectedFromChains []*params.Network) (err error) { if input.TestsMode { - return input.TestParams.BalanceMap, nil + for k, v := range input.TestParams.BalanceMap { + r.activeBalanceMap.Store(k, v) + } + return nil } - balanceMap = make(map[string]*big.Int) - chainError := func(chainId uint64, token string, intErr error) { if err == nil { err = fmt.Errorf("chain %d, token %s: %w", chainId, token, intErr) @@ -348,7 +355,7 @@ func (r *Router) getBalanceMapForTokenOnChains(ctx context.Context, input *reque } // add only if balance is not nil if tokenBalance != nil { - balanceMap[makeBalanceKey(chain.ChainID, token.Symbol)] = tokenBalance + r.activeBalanceMap.Store(makeBalanceKey(chain.ChainID, token.Symbol), tokenBalance) } if token.IsNative() { @@ -362,7 +369,7 @@ func (r *Router) getBalanceMapForTokenOnChains(ctx context.Context, input *reque } // add only if balance is not nil if nativeBalance != nil { - balanceMap[makeBalanceKey(chain.ChainID, nativeToken.Symbol)] = nativeBalance + r.activeBalanceMap.Store(makeBalanceKey(chain.ChainID, nativeToken.Symbol), nativeBalance) } } @@ -383,7 +390,7 @@ func (r *Router) getSelectedUnlockedChains(input *requests.RouteInputParams, pro } func (r *Router) getOptionsForAmoutToSplitAccrossChainsForProcessingChain(input *requests.RouteInputParams, amountToSplit *big.Int, processingChain *params.Network, - selectedFromChains []*params.Network, balanceMap map[string]*big.Int) map[uint64][]amountOption { + selectedFromChains []*params.Network) map[uint64][]amountOption { selectedButNotLockedChains := r.getSelectedUnlockedChains(input, processingChain, selectedFromChains) crossChainAmountOptions := make(map[uint64][]amountOption) @@ -393,7 +400,12 @@ func (r *Router) getOptionsForAmoutToSplitAccrossChainsForProcessingChain(input tokenBalance *big.Int ) - if tokenBalance, ok = balanceMap[makeBalanceKey(chain.ChainID, input.TokenID)]; !ok { + value, ok := r.activeBalanceMap.Load(makeBalanceKey(chain.ChainID, input.TokenID)) + if !ok { + continue + } + tokenBalance, ok = value.(*big.Int) + if !ok { continue } @@ -419,8 +431,7 @@ func (r *Router) getOptionsForAmoutToSplitAccrossChainsForProcessingChain(input return crossChainAmountOptions } -func (r *Router) getCrossChainsOptionsForSendingAmount(input *requests.RouteInputParams, selectedFromChains []*params.Network, - balanceMap map[string]*big.Int) map[uint64][]amountOption { +func (r *Router) getCrossChainsOptionsForSendingAmount(input *requests.RouteInputParams, selectedFromChains []*params.Network) map[uint64][]amountOption { // All we do in this block we're free to do, because of the validateInputData function which checks if the locked amount // was properly set and if there is something unexpected it will return an error and we will not reach this point finalCrossChainAmountOptions := make(map[uint64][]amountOption) // represents all possible amounts that can be sent from the "from" chain @@ -471,7 +482,7 @@ func (r *Router) getCrossChainsOptionsForSendingAmount(input *requests.RouteInpu // was properly set and if there is something unexpected it will return an error and we will not reach this point amountToSplitAccrossChains := new(big.Int).Set(amountToSend) - crossChainAmountOptions := r.getOptionsForAmoutToSplitAccrossChainsForProcessingChain(input, amountToSend, selectedFromChain, selectedFromChains, balanceMap) + crossChainAmountOptions := r.getOptionsForAmoutToSplitAccrossChainsForProcessingChain(input, amountToSend, selectedFromChain, selectedFromChains) // sum up all the allocated amounts accorss all chains allocatedAmount := big.NewInt(0) @@ -494,10 +505,9 @@ func (r *Router) getCrossChainsOptionsForSendingAmount(input *requests.RouteInpu return finalCrossChainAmountOptions } -func (r *Router) findOptionsForSendingAmount(input *requests.RouteInputParams, selectedFromChains []*params.Network, - balanceMap map[string]*big.Int) (map[uint64][]amountOption, error) { +func (r *Router) findOptionsForSendingAmount(input *requests.RouteInputParams, selectedFromChains []*params.Network) (map[uint64][]amountOption, error) { - crossChainAmountOptions := r.getCrossChainsOptionsForSendingAmount(input, selectedFromChains, balanceMap) + crossChainAmountOptions := r.getCrossChainsOptionsForSendingAmount(input, selectedFromChains) // filter out duplicates values for the same chain for chainID, amountOptions := range crossChainAmountOptions { @@ -540,14 +550,14 @@ func (r *Router) getSelectedChains(input *requests.RouteInputParams) (selectedFr } func (r *Router) resolveCandidates(ctx context.Context, input *requests.RouteInputParams, selectedFromChains []*params.Network, - selectedToChains []*params.Network, balanceMap map[string]*big.Int) (candidates routes.Route, processorErrors []*ProcessorError, err error) { + selectedToChains []*params.Network) (candidates routes.Route, processorErrors []*ProcessorError, err error) { var ( testsMode = input.TestsMode && input.TestParams != nil group = async.NewAtomicGroup(ctx) mu sync.Mutex ) - crossChainAmountOptions, err := r.findOptionsForSendingAmount(input, selectedFromChains, balanceMap) + crossChainAmountOptions, err := r.findOptionsForSendingAmount(input, selectedFromChains) if err != nil { return nil, nil, errors.CreateErrorResponseFromError(err) } @@ -774,10 +784,13 @@ func (r *Router) resolveCandidates(ctx context.Context, input *requests.RouteInp return candidates, processorErrors, nil } -func (r *Router) checkBalancesForTheBestRoute(ctx context.Context, bestRoute routes.Route, balanceMap map[string]*big.Int) (hasPositiveBalance bool, err error) { - balanceMapCopy := walletCommon.CopyMapGeneric(balanceMap, func(v interface{}) interface{} { - return new(big.Int).Set(v.(*big.Int)) - }).(map[string]*big.Int) +func (r *Router) checkBalancesForTheBestRoute(ctx context.Context, bestRoute routes.Route) (hasPositiveBalance bool, err error) { + // make a copy of the active balance map + balanceMapCopy := make(map[string]*big.Int) + r.activeBalanceMap.Range(func(k, v interface{}) bool { + balanceMapCopy[k.(string)] = new(big.Int).Set(v.(*big.Int)) + return true + }) if balanceMapCopy == nil { return false, ErrCannotCheckBalance } @@ -830,7 +843,7 @@ func (r *Router) checkBalancesForTheBestRoute(ctx context.Context, bestRoute rou return hasPositiveBalance, nil } -func (r *Router) resolveRoutes(ctx context.Context, input *requests.RouteInputParams, candidates routes.Route, balanceMap map[string]*big.Int) (suggestedRoutes *SuggestedRoutes, err error) { +func (r *Router) resolveRoutes(ctx context.Context, input *requests.RouteInputParams, candidates routes.Route) (suggestedRoutes *SuggestedRoutes, err error) { var prices map[string]float64 if input.TestsMode { prices = input.TestParams.TokenPrices @@ -866,7 +879,7 @@ func (r *Router) resolveRoutes(ctx context.Context, input *requests.RouteInputPa for len(allRoutes) > 0 { bestRoute = routes.FindBestRoute(allRoutes, tokenPrice, nativeTokenPrice) var hasPositiveBalance bool - hasPositiveBalance, err = r.checkBalancesForTheBestRoute(ctx, bestRoute, balanceMap) + hasPositiveBalance, err = r.checkBalancesForTheBestRoute(ctx, bestRoute) if err != nil { // If it's about transfer or bridge and there is more routes, but on the best (cheapest) one there is not enugh balance diff --git a/services/wallet/router/router_test.go b/services/wallet/router/router_test.go index 0de7e1aa434..8990539a65e 100644 --- a/services/wallet/router/router_test.go +++ b/services/wallet/router/router_test.go @@ -268,7 +268,8 @@ func TestAmountOptions(t *testing.T) { selectedFromChains, _, err := router.getSelectedChains(tt.input) assert.NoError(t, err) - amountOptions, err := router.findOptionsForSendingAmount(tt.input, selectedFromChains, tt.input.TestParams.BalanceMap) + router.SetTestBalanceMap(tt.input.TestParams.BalanceMap) + amountOptions, err := router.findOptionsForSendingAmount(tt.input, selectedFromChains) assert.NoError(t, err) assert.Equal(t, len(tt.expectedAmountOptions), len(amountOptions)) diff --git a/services/wallet/router/router_updates.go b/services/wallet/router/router_updates.go index 207c583425d..f916c041cc3 100644 --- a/services/wallet/router/router_updates.go +++ b/services/wallet/router/router_updates.go @@ -102,7 +102,9 @@ func (r *Router) subscribeForUdates(chainID uint64) error { } } - sendRouterResult(uuid, r.activeRoutes, nil) + _, err = r.checkBalancesForTheBestRoute(ctx, r.activeRoutes.Best) + + sendRouterResult(uuid, r.activeRoutes, err) } r.activeRoutesMutex.Unlock() }