diff --git a/mocks/utils/fetcher_helper.go b/mocks/utils/fetcher_helper.go index 60745d45..6269c422 100644 --- a/mocks/utils/fetcher_helper.go +++ b/mocks/utils/fetcher_helper.go @@ -16,13 +16,13 @@ type FetcherHelper struct { mock.Mock } -// AccountBalanceRetry provides a mock function with given fields: ctx, network, account, block -func (_m *FetcherHelper) AccountBalanceRetry(ctx context.Context, network *types.NetworkIdentifier, account *types.AccountIdentifier, block *types.PartialBlockIdentifier) (*types.BlockIdentifier, []*types.Amount, []*types.Coin, map[string]interface{}, *fetcher.Error) { - ret := _m.Called(ctx, network, account, block) +// AccountBalanceRetry provides a mock function with given fields: ctx, network, account, block, currencies +func (_m *FetcherHelper) AccountBalanceRetry(ctx context.Context, network *types.NetworkIdentifier, account *types.AccountIdentifier, block *types.PartialBlockIdentifier, currencies []*types.Currency) (*types.BlockIdentifier, []*types.Amount, map[string]interface{}, *fetcher.Error) { + ret := _m.Called(ctx, network, account, block, currencies) var r0 *types.BlockIdentifier - if rf, ok := ret.Get(0).(func(context.Context, *types.NetworkIdentifier, *types.AccountIdentifier, *types.PartialBlockIdentifier) *types.BlockIdentifier); ok { - r0 = rf(ctx, network, account, block) + if rf, ok := ret.Get(0).(func(context.Context, *types.NetworkIdentifier, *types.AccountIdentifier, *types.PartialBlockIdentifier, []*types.Currency) *types.BlockIdentifier); ok { + r0 = rf(ctx, network, account, block, currencies) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*types.BlockIdentifier) @@ -30,42 +30,33 @@ func (_m *FetcherHelper) AccountBalanceRetry(ctx context.Context, network *types } var r1 []*types.Amount - if rf, ok := ret.Get(1).(func(context.Context, *types.NetworkIdentifier, *types.AccountIdentifier, *types.PartialBlockIdentifier) []*types.Amount); ok { - r1 = rf(ctx, network, account, block) + if rf, ok := ret.Get(1).(func(context.Context, *types.NetworkIdentifier, *types.AccountIdentifier, *types.PartialBlockIdentifier, []*types.Currency) []*types.Amount); ok { + r1 = rf(ctx, network, account, block, currencies) } else { if ret.Get(1) != nil { r1 = ret.Get(1).([]*types.Amount) } } - var r2 []*types.Coin - if rf, ok := ret.Get(2).(func(context.Context, *types.NetworkIdentifier, *types.AccountIdentifier, *types.PartialBlockIdentifier) []*types.Coin); ok { - r2 = rf(ctx, network, account, block) + var r2 map[string]interface{} + if rf, ok := ret.Get(2).(func(context.Context, *types.NetworkIdentifier, *types.AccountIdentifier, *types.PartialBlockIdentifier, []*types.Currency) map[string]interface{}); ok { + r2 = rf(ctx, network, account, block, currencies) } else { if ret.Get(2) != nil { - r2 = ret.Get(2).([]*types.Coin) + r2 = ret.Get(2).(map[string]interface{}) } } - var r3 map[string]interface{} - if rf, ok := ret.Get(3).(func(context.Context, *types.NetworkIdentifier, *types.AccountIdentifier, *types.PartialBlockIdentifier) map[string]interface{}); ok { - r3 = rf(ctx, network, account, block) + var r3 *fetcher.Error + if rf, ok := ret.Get(3).(func(context.Context, *types.NetworkIdentifier, *types.AccountIdentifier, *types.PartialBlockIdentifier, []*types.Currency) *fetcher.Error); ok { + r3 = rf(ctx, network, account, block, currencies) } else { if ret.Get(3) != nil { - r3 = ret.Get(3).(map[string]interface{}) + r3 = ret.Get(3).(*fetcher.Error) } } - var r4 *fetcher.Error - if rf, ok := ret.Get(4).(func(context.Context, *types.NetworkIdentifier, *types.AccountIdentifier, *types.PartialBlockIdentifier) *fetcher.Error); ok { - r4 = rf(ctx, network, account, block) - } else { - if ret.Get(4) != nil { - r4 = ret.Get(4).(*fetcher.Error) - } - } - - return r0, r1, r2, r3, r4 + return r0, r1, r2, r3 } // NetworkList provides a mock function with given fields: ctx, metadata diff --git a/utils/utils.go b/utils/utils.go index b1a2ff99..84898e39 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -183,7 +183,8 @@ type FetcherHelper interface { network *types.NetworkIdentifier, account *types.AccountIdentifier, block *types.PartialBlockIdentifier, - ) (*types.BlockIdentifier, []*types.Amount, []*types.Coin, map[string]interface{}, *fetcher.Error) + currencies []*types.Currency, + ) (*types.BlockIdentifier, []*types.Amount, map[string]interface{}, *fetcher.Error) } // CheckNetworkSupported checks if a Rosetta implementation supports a given @@ -312,7 +313,7 @@ func CurrencyBalance( account *types.AccountIdentifier, currency *types.Currency, index int64, -) (*types.Amount, *types.BlockIdentifier, []*types.Coin, error) { +) (*types.Amount, *types.BlockIdentifier, error) { var lookupBlock *types.PartialBlockIdentifier if index >= 0 { lookupBlock = &types.PartialBlockIdentifier{ @@ -320,14 +321,15 @@ func CurrencyBalance( } } - liveBlock, liveBalances, liveCoins, _, fetchErr := helper.AccountBalanceRetry( + liveBlock, liveBalances, _, fetchErr := helper.AccountBalanceRetry( ctx, network, account, lookupBlock, + []*types.Currency{currency}, ) if fetchErr != nil { - return nil, nil, nil, fetchErr.Err + return nil, nil, fetchErr.Err } liveAmount, err := types.ExtractAmount(liveBalances, currency) @@ -339,10 +341,10 @@ func CurrencyBalance( types.PrettyPrintStruct(account), ) - return nil, nil, nil, formattedError + return nil, nil, formattedError } - return liveAmount, liveBlock, liveCoins, nil + return liveAmount, liveBlock, nil } // AccountBalanceRequest defines the required information @@ -372,7 +374,7 @@ func GetAccountBalances( ) ([]*AccountBalance, error) { var accountBalances []*AccountBalance for _, balanceRequest := range balanceRequests { - amount, block, coins, err := CurrencyBalance( + amount, block, err := CurrencyBalance( ctx, balanceRequest.Network, fetcher, @@ -388,7 +390,6 @@ func GetAccountBalances( accountBalance := &AccountBalance{ Account: balanceRequest.Account, Amount: amount, - Coins: coins, Block: block, } diff --git a/utils/utils_test.go b/utils/utils_test.go index 43391c77..54ec5f6d 100644 --- a/utils/utils_test.go +++ b/utils/utils_test.go @@ -270,23 +270,6 @@ var ( Currency: currency, } - accountCoins = []*types.Coin{ - { - CoinIdentifier: &types.CoinIdentifier{Identifier: "coin1"}, - Amount: &types.Amount{ - Value: "30", - Currency: currency, - }, - }, - { - CoinIdentifier: &types.CoinIdentifier{Identifier: "coin2"}, - Amount: &types.Amount{ - Value: "30", - Currency: currency, - }, - }, - } - accountBalance = &types.AccountIdentifier{ Address: "test2", } @@ -305,7 +288,6 @@ var ( accBalanceResp1 = &AccountBalance{ Account: accountCoin, Amount: amountCoins, - Coins: accountCoins, Block: blockIdentifier, } @@ -333,10 +315,10 @@ func TestGetAccountBalances(t *testing.T) { network, accountCoin, (*types.PartialBlockIdentifier)(nil), + []*types.Currency{currency}, ).Return( blockIdentifier, []*types.Amount{amountCoins}, - accountCoins, nil, nil, ).Once() @@ -347,12 +329,12 @@ func TestGetAccountBalances(t *testing.T) { network, accountBalance, (*types.PartialBlockIdentifier)(nil), + []*types.Currency{currency}, ).Return( blockIdentifier, []*types.Amount{amountBalance}, nil, nil, - nil, ).Once() accBalances, err := GetAccountBalances( @@ -372,11 +354,11 @@ func TestGetAccountBalances(t *testing.T) { network, accountBalance, (*types.PartialBlockIdentifier)(nil), + []*types.Currency{currency}, ).Return( nil, nil, nil, - nil, &fetcher.Error{ Err: fmt.Errorf("invalid account balance"), },