Skip to content

Commit

Permalink
feat_: check for balances after each fees update
Browse files Browse the repository at this point in the history
  • Loading branch information
saledjenic committed Sep 4, 2024
1 parent 6167508 commit 1242ab7
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 40 deletions.
89 changes: 51 additions & 38 deletions services/wallet/router/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@ type Router struct {
pathProcessors map[string]pathprocessor.PathProcessor
scheduler *async.Scheduler

activeBalanceMap sync.Map // map[string]*big.Int

activeRoutesMutex sync.Mutex
activeRoutes *SuggestedRoutes

Expand Down Expand Up @@ -119,6 +121,12 @@ func (r *Router) GetPathProcessors() map[string]pathprocessor.PathProcessor {
return r.pathProcessors
}

func (r *Router) SetTestBalanceMap(balanceMap map[string]*big.Int) {
for k, v := range balanceMap {
r.activeBalanceMap.Store(k, v)
}
}

func newSuggestedRoutes(
uuid string,
amountIn *big.Int,
Expand Down Expand Up @@ -236,31 +244,29 @@ func (r *Router) SuggestedRoutes(ctx context.Context, input *requests.RouteInput
return nil, errors.CreateErrorResponseFromError(err)
}

balanceMap, err := r.getBalanceMapForTokenOnChains(ctx, input, selectedFromChains)
err = r.prepareBalanceMapForTokenOnChains(ctx, input, selectedFromChains)
// return only if there are no balances, otherwise try to resolve the candidates for chains we know the balances for
if len(balanceMap) == 0 {
noBalanceOnAnyChain := true
r.activeBalanceMap.Range(func(key, value interface{}) bool {
if value.(*big.Int).Cmp(pathprocessor.ZeroBigIntValue) > 0 {
noBalanceOnAnyChain = false
return false
}
return true
})
if noBalanceOnAnyChain {
if err != nil {
return nil, errors.CreateErrorResponseFromError(err)
}
} else {
noBalanceOnAnyChain := true
for _, value := range balanceMap {
if value.Cmp(pathprocessor.ZeroBigIntValue) > 0 {
noBalanceOnAnyChain = false
break
}
}
if noBalanceOnAnyChain {
return nil, ErrNoPositiveBalance
}
return nil, ErrNoPositiveBalance
}

candidates, processorErrors, err := r.resolveCandidates(ctx, input, selectedFromChains, selectedToChains, balanceMap)
candidates, processorErrors, err := r.resolveCandidates(ctx, input, selectedFromChains, selectedToChains)
if err != nil {
return nil, errors.CreateErrorResponseFromError(err)
}

suggestedRoutes, err = r.resolveRoutes(ctx, input, candidates, balanceMap)
suggestedRoutes, err = r.resolveRoutes(ctx, input, candidates)

if err == nil && (suggestedRoutes == nil || len(suggestedRoutes.Best) == 0) {
// No best route found, but no error given.
Expand Down Expand Up @@ -300,15 +306,16 @@ func (r *Router) SuggestedRoutes(ctx context.Context, input *requests.RouteInput
return suggestedRoutes, mapError(err)
}

// getBalanceMapForTokenOnChains returns the balance map for passed address, where the key is in format "chainID-tokenSymbol" and
// prepareBalanceMapForTokenOnChains prepares the balance map for passed address, where the key is in format "chainID-tokenSymbol" and
// value is the balance of the token. Native token (EHT) is always added to the balance map.
func (r *Router) getBalanceMapForTokenOnChains(ctx context.Context, input *requests.RouteInputParams, selectedFromChains []*params.Network) (balanceMap map[string]*big.Int, err error) {
func (r *Router) prepareBalanceMapForTokenOnChains(ctx context.Context, input *requests.RouteInputParams, selectedFromChains []*params.Network) (err error) {
if input.TestsMode {
return input.TestParams.BalanceMap, nil
for k, v := range input.TestParams.BalanceMap {
r.activeBalanceMap.Store(k, v)
}
return nil
}

balanceMap = make(map[string]*big.Int)

chainError := func(chainId uint64, token string, intErr error) {
if err == nil {
err = fmt.Errorf("chain %d, token %s: %w", chainId, token, intErr)
Expand Down Expand Up @@ -348,7 +355,7 @@ func (r *Router) getBalanceMapForTokenOnChains(ctx context.Context, input *reque
}
// add only if balance is not nil
if tokenBalance != nil {
balanceMap[makeBalanceKey(chain.ChainID, token.Symbol)] = tokenBalance
r.activeBalanceMap.Store(makeBalanceKey(chain.ChainID, token.Symbol), tokenBalance)
}

if token.IsNative() {
Expand All @@ -362,7 +369,7 @@ func (r *Router) getBalanceMapForTokenOnChains(ctx context.Context, input *reque
}
// add only if balance is not nil
if nativeBalance != nil {
balanceMap[makeBalanceKey(chain.ChainID, nativeToken.Symbol)] = nativeBalance
r.activeBalanceMap.Store(makeBalanceKey(chain.ChainID, nativeToken.Symbol), nativeBalance)
}
}

Expand All @@ -383,7 +390,7 @@ func (r *Router) getSelectedUnlockedChains(input *requests.RouteInputParams, pro
}

func (r *Router) getOptionsForAmoutToSplitAccrossChainsForProcessingChain(input *requests.RouteInputParams, amountToSplit *big.Int, processingChain *params.Network,
selectedFromChains []*params.Network, balanceMap map[string]*big.Int) map[uint64][]amountOption {
selectedFromChains []*params.Network) map[uint64][]amountOption {
selectedButNotLockedChains := r.getSelectedUnlockedChains(input, processingChain, selectedFromChains)

crossChainAmountOptions := make(map[uint64][]amountOption)
Expand All @@ -393,7 +400,12 @@ func (r *Router) getOptionsForAmoutToSplitAccrossChainsForProcessingChain(input
tokenBalance *big.Int
)

if tokenBalance, ok = balanceMap[makeBalanceKey(chain.ChainID, input.TokenID)]; !ok {
value, ok := r.activeBalanceMap.Load(makeBalanceKey(chain.ChainID, input.TokenID))
if !ok {
continue
}
tokenBalance, ok = value.(*big.Int)
if !ok {
continue
}

Expand All @@ -419,8 +431,7 @@ func (r *Router) getOptionsForAmoutToSplitAccrossChainsForProcessingChain(input
return crossChainAmountOptions
}

func (r *Router) getCrossChainsOptionsForSendingAmount(input *requests.RouteInputParams, selectedFromChains []*params.Network,
balanceMap map[string]*big.Int) map[uint64][]amountOption {
func (r *Router) getCrossChainsOptionsForSendingAmount(input *requests.RouteInputParams, selectedFromChains []*params.Network) map[uint64][]amountOption {
// All we do in this block we're free to do, because of the validateInputData function which checks if the locked amount
// was properly set and if there is something unexpected it will return an error and we will not reach this point
finalCrossChainAmountOptions := make(map[uint64][]amountOption) // represents all possible amounts that can be sent from the "from" chain
Expand Down Expand Up @@ -471,7 +482,7 @@ func (r *Router) getCrossChainsOptionsForSendingAmount(input *requests.RouteInpu
// was properly set and if there is something unexpected it will return an error and we will not reach this point
amountToSplitAccrossChains := new(big.Int).Set(amountToSend)

crossChainAmountOptions := r.getOptionsForAmoutToSplitAccrossChainsForProcessingChain(input, amountToSend, selectedFromChain, selectedFromChains, balanceMap)
crossChainAmountOptions := r.getOptionsForAmoutToSplitAccrossChainsForProcessingChain(input, amountToSend, selectedFromChain, selectedFromChains)

// sum up all the allocated amounts accorss all chains
allocatedAmount := big.NewInt(0)
Expand All @@ -494,10 +505,9 @@ func (r *Router) getCrossChainsOptionsForSendingAmount(input *requests.RouteInpu
return finalCrossChainAmountOptions
}

func (r *Router) findOptionsForSendingAmount(input *requests.RouteInputParams, selectedFromChains []*params.Network,
balanceMap map[string]*big.Int) (map[uint64][]amountOption, error) {
func (r *Router) findOptionsForSendingAmount(input *requests.RouteInputParams, selectedFromChains []*params.Network) (map[uint64][]amountOption, error) {

crossChainAmountOptions := r.getCrossChainsOptionsForSendingAmount(input, selectedFromChains, balanceMap)
crossChainAmountOptions := r.getCrossChainsOptionsForSendingAmount(input, selectedFromChains)

// filter out duplicates values for the same chain
for chainID, amountOptions := range crossChainAmountOptions {
Expand Down Expand Up @@ -540,14 +550,14 @@ func (r *Router) getSelectedChains(input *requests.RouteInputParams) (selectedFr
}

func (r *Router) resolveCandidates(ctx context.Context, input *requests.RouteInputParams, selectedFromChains []*params.Network,
selectedToChains []*params.Network, balanceMap map[string]*big.Int) (candidates routes.Route, processorErrors []*ProcessorError, err error) {
selectedToChains []*params.Network) (candidates routes.Route, processorErrors []*ProcessorError, err error) {
var (
testsMode = input.TestsMode && input.TestParams != nil
group = async.NewAtomicGroup(ctx)
mu sync.Mutex
)

crossChainAmountOptions, err := r.findOptionsForSendingAmount(input, selectedFromChains, balanceMap)
crossChainAmountOptions, err := r.findOptionsForSendingAmount(input, selectedFromChains)
if err != nil {
return nil, nil, errors.CreateErrorResponseFromError(err)
}
Expand Down Expand Up @@ -774,10 +784,13 @@ func (r *Router) resolveCandidates(ctx context.Context, input *requests.RouteInp
return candidates, processorErrors, nil
}

func (r *Router) checkBalancesForTheBestRoute(ctx context.Context, bestRoute routes.Route, balanceMap map[string]*big.Int) (hasPositiveBalance bool, err error) {
balanceMapCopy := walletCommon.CopyMapGeneric(balanceMap, func(v interface{}) interface{} {
return new(big.Int).Set(v.(*big.Int))
}).(map[string]*big.Int)
func (r *Router) checkBalancesForTheBestRoute(ctx context.Context, bestRoute routes.Route) (hasPositiveBalance bool, err error) {
// make a copy of the active balance map
balanceMapCopy := make(map[string]*big.Int)
r.activeBalanceMap.Range(func(k, v interface{}) bool {
balanceMapCopy[k.(string)] = new(big.Int).Set(v.(*big.Int))
return true
})
if balanceMapCopy == nil {
return false, ErrCannotCheckBalance
}
Expand Down Expand Up @@ -830,7 +843,7 @@ func (r *Router) checkBalancesForTheBestRoute(ctx context.Context, bestRoute rou
return hasPositiveBalance, nil
}

func (r *Router) resolveRoutes(ctx context.Context, input *requests.RouteInputParams, candidates routes.Route, balanceMap map[string]*big.Int) (suggestedRoutes *SuggestedRoutes, err error) {
func (r *Router) resolveRoutes(ctx context.Context, input *requests.RouteInputParams, candidates routes.Route) (suggestedRoutes *SuggestedRoutes, err error) {
var prices map[string]float64
if input.TestsMode {
prices = input.TestParams.TokenPrices
Expand Down Expand Up @@ -866,7 +879,7 @@ func (r *Router) resolveRoutes(ctx context.Context, input *requests.RouteInputPa
for len(allRoutes) > 0 {
bestRoute = routes.FindBestRoute(allRoutes, tokenPrice, nativeTokenPrice)
var hasPositiveBalance bool
hasPositiveBalance, err = r.checkBalancesForTheBestRoute(ctx, bestRoute, balanceMap)
hasPositiveBalance, err = r.checkBalancesForTheBestRoute(ctx, bestRoute)

if err != nil {
// If it's about transfer or bridge and there is more routes, but on the best (cheapest) one there is not enugh balance
Expand Down
3 changes: 2 additions & 1 deletion services/wallet/router/router_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,8 @@ func TestAmountOptions(t *testing.T) {
selectedFromChains, _, err := router.getSelectedChains(tt.input)
assert.NoError(t, err)

amountOptions, err := router.findOptionsForSendingAmount(tt.input, selectedFromChains, tt.input.TestParams.BalanceMap)
router.SetTestBalanceMap(tt.input.TestParams.BalanceMap)
amountOptions, err := router.findOptionsForSendingAmount(tt.input, selectedFromChains)
assert.NoError(t, err)

assert.Equal(t, len(tt.expectedAmountOptions), len(amountOptions))
Expand Down
4 changes: 3 additions & 1 deletion services/wallet/router/router_updates.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,9 @@ func (r *Router) subscribeForUdates(chainID uint64) error {
}
}

sendRouterResult(uuid, r.activeRoutes, nil)
_, err = r.checkBalancesForTheBestRoute(ctx, r.activeRoutes.Best)

sendRouterResult(uuid, r.activeRoutes, err)
}
r.activeRoutesMutex.Unlock()
}
Expand Down

0 comments on commit 1242ab7

Please sign in to comment.