Skip to content

Commit

Permalink
refactor(auth): decouple auth Keeper API from Msg and Query server im…
Browse files Browse the repository at this point in the history
…plementation (#15985)

Co-authored-by: unknown unknown <unknown@unknown>
  • Loading branch information
testinginprod and unknown unknown authored May 1, 2023
1 parent 7c59ead commit 3ada275
Show file tree
Hide file tree
Showing 7 changed files with 47 additions and 40 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ Ref: https://keepachangelog.com/en/1.0.0/

### API Breaking Changes

* (x/auth) [#15985](https://github.com/cosmos/cosmos-sdk/pull/15985) The `AccountKeeper` does not expose the `QueryServer` and `MsgServer` APIs anymore.
* (x/authz) [#15962](https://github.com/cosmos/cosmos-sdk/issues/15962) `NewKeeper` now takes a `KVStoreService` instead of a `StoreKey`, methods in the `Keeper` now take a `context.Context` instead of a `sdk.Context`. The `Authorization` interface's `Accept` method now takes a `context.Context` instead of a `sdk.Context`.
* (x/distribution) [#15948](https://github.com/cosmos/cosmos-sdk/issues/15948) `NewKeeper` now takes a `KVStoreService` instead of a `StoreKey` and methods in the `Keeper` now take a `context.Context` instead of a `sdk.Context`. Keeper methods also now return an `error`.
* (x/bank) [#15891](https://github.com/cosmos/cosmos-sdk/issues/15891) `NewKeeper` now takes a `KVStoreService` instead of a `StoreKey` and methods in the `Keeper` now take a `context.Context` instead of a `sdk.Context`. Also `FundAccount` and `FundModuleAccount` from the `testutil` package accept a `context.Context` instead of a `sdk.Context`, and it's position was moved to the first place.
Expand Down
4 changes: 2 additions & 2 deletions x/auth/keeper/deterministic_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ func (suite *DeterministicTestSuite) SetupTest() {
)

queryHelper := baseapp.NewQueryServerTestHelper(suite.ctx, suite.encCfg.InterfaceRegistry)
types.RegisterQueryServer(queryHelper, suite.accountKeeper)
types.RegisterQueryServer(queryHelper, keeper.NewQueryServer(suite.accountKeeper))
suite.queryClient = types.NewQueryClient(queryHelper)

suite.key = key
Expand Down Expand Up @@ -239,7 +239,7 @@ func (suite *DeterministicTestSuite) TestGRPCQueryAccountInfo() {

func (suite *DeterministicTestSuite) createAndReturnQueryClient(ak keeper.AccountKeeper) types.QueryClient {
queryHelper := baseapp.NewQueryServerTestHelper(suite.ctx, suite.encCfg.InterfaceRegistry)
types.RegisterQueryServer(queryHelper, ak)
types.RegisterQueryServer(queryHelper, keeper.NewQueryServer(ak))
return types.NewQueryClient(queryHelper)
}

Expand Down
60 changes: 33 additions & 27 deletions x/auth/keeper/grpc_query.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,15 @@ import (
"github.com/cosmos/cosmos-sdk/x/auth/types"
)

var _ types.QueryServer = AccountKeeper{}
var _ types.QueryServer = queryServer{}

func (ak AccountKeeper) AccountAddressByID(c context.Context, req *types.QueryAccountAddressByIDRequest) (*types.QueryAccountAddressByIDResponse, error) {
func NewQueryServer(k AccountKeeper) types.QueryServer {
return queryServer{k: k}
}

type queryServer struct{ k AccountKeeper }

func (s queryServer) AccountAddressByID(c context.Context, req *types.QueryAccountAddressByIDRequest) (*types.QueryAccountAddressByIDResponse, error) {
if req == nil {
return nil, status.Errorf(codes.InvalidArgument, "empty request")
}
Expand All @@ -33,25 +39,25 @@ func (ak AccountKeeper) AccountAddressByID(c context.Context, req *types.QueryAc
accID := req.AccountId

ctx := sdk.UnwrapSDKContext(c)
address := ak.GetAccountAddressByID(ctx, accID)
address := s.k.GetAccountAddressByID(ctx, accID)
if len(address) == 0 {
return nil, status.Errorf(codes.NotFound, "account address not found with account number %d", req.Id)
}

return &types.QueryAccountAddressByIDResponse{AccountAddress: address}, nil
}

func (ak AccountKeeper) Accounts(ctx context.Context, req *types.QueryAccountsRequest) (*types.QueryAccountsResponse, error) {
func (s queryServer) Accounts(ctx context.Context, req *types.QueryAccountsRequest) (*types.QueryAccountsResponse, error) {
if req == nil {
return nil, status.Error(codes.InvalidArgument, "empty request")
}

store := ak.storeService.OpenKVStore(ctx)
store := s.k.storeService.OpenKVStore(ctx)
accountsStore := prefix.NewStore(runtime.KVStoreAdapter(store), types.AddressStoreKeyPrefix)

var accounts []*codectypes.Any
pageRes, err := query.Paginate(accountsStore, req.Pagination, func(key, value []byte) error {
account := ak.decodeAccount(value)
account := s.k.decodeAccount(value)
any, err := codectypes.NewAnyWithValue(account)
if err != nil {
return err
Expand All @@ -68,7 +74,7 @@ func (ak AccountKeeper) Accounts(ctx context.Context, req *types.QueryAccountsRe
}

// Account returns account details based on address
func (ak AccountKeeper) Account(c context.Context, req *types.QueryAccountRequest) (*types.QueryAccountResponse, error) {
func (s queryServer) Account(c context.Context, req *types.QueryAccountRequest) (*types.QueryAccountResponse, error) {
if req == nil {
return nil, status.Errorf(codes.InvalidArgument, "empty request")
}
Expand All @@ -78,11 +84,11 @@ func (ak AccountKeeper) Account(c context.Context, req *types.QueryAccountReques
}

ctx := sdk.UnwrapSDKContext(c)
addr, err := ak.StringToBytes(req.Address)
addr, err := s.k.StringToBytes(req.Address)
if err != nil {
return nil, err
}
account := ak.GetAccount(ctx, addr)
account := s.k.GetAccount(ctx, addr)
if account == nil {
return nil, status.Errorf(codes.NotFound, "account %s not found", req.Address)
}
Expand All @@ -96,35 +102,35 @@ func (ak AccountKeeper) Account(c context.Context, req *types.QueryAccountReques
}

// Params returns parameters of auth module
func (ak AccountKeeper) Params(c context.Context, req *types.QueryParamsRequest) (*types.QueryParamsResponse, error) {
func (s queryServer) Params(c context.Context, req *types.QueryParamsRequest) (*types.QueryParamsResponse, error) {
if req == nil {
return nil, status.Error(codes.InvalidArgument, "empty request")
}
ctx := sdk.UnwrapSDKContext(c)
params := ak.GetParams(ctx)
params := s.k.GetParams(ctx)

return &types.QueryParamsResponse{Params: params}, nil
}

// ModuleAccounts returns all the existing Module Accounts
func (ak AccountKeeper) ModuleAccounts(c context.Context, req *types.QueryModuleAccountsRequest) (*types.QueryModuleAccountsResponse, error) {
func (s queryServer) ModuleAccounts(c context.Context, req *types.QueryModuleAccountsRequest) (*types.QueryModuleAccountsResponse, error) {
if req == nil {
return nil, status.Error(codes.InvalidArgument, "empty request")
}

ctx := sdk.UnwrapSDKContext(c)

// For deterministic output, sort the permAddrs by module name.
sortedPermAddrs := make([]string, 0, len(ak.permAddrs))
for moduleName := range ak.permAddrs {
sortedPermAddrs := make([]string, 0, len(s.k.permAddrs))
for moduleName := range s.k.permAddrs {
sortedPermAddrs = append(sortedPermAddrs, moduleName)
}
sort.Strings(sortedPermAddrs)

modAccounts := make([]*codectypes.Any, 0, len(ak.permAddrs))
modAccounts := make([]*codectypes.Any, 0, len(s.k.permAddrs))

for _, moduleName := range sortedPermAddrs {
account := ak.GetModuleAccount(ctx, moduleName)
account := s.k.GetModuleAccount(ctx, moduleName)
if account == nil {
return nil, status.Errorf(codes.NotFound, "account %s not found", moduleName)
}
Expand All @@ -139,7 +145,7 @@ func (ak AccountKeeper) ModuleAccounts(c context.Context, req *types.QueryModule
}

// ModuleAccountByName returns module account by module name
func (ak AccountKeeper) ModuleAccountByName(c context.Context, req *types.QueryModuleAccountByNameRequest) (*types.QueryModuleAccountByNameResponse, error) {
func (s queryServer) ModuleAccountByName(c context.Context, req *types.QueryModuleAccountByNameRequest) (*types.QueryModuleAccountByNameResponse, error) {
if req == nil {
return nil, status.Errorf(codes.InvalidArgument, "empty request")
}
Expand All @@ -151,7 +157,7 @@ func (ak AccountKeeper) ModuleAccountByName(c context.Context, req *types.QueryM
ctx := sdk.UnwrapSDKContext(c)
moduleName := req.Name

account := ak.GetModuleAccount(ctx, moduleName)
account := s.k.GetModuleAccount(ctx, moduleName)
if account == nil {
return nil, status.Errorf(codes.NotFound, "account %s not found", moduleName)
}
Expand All @@ -164,8 +170,8 @@ func (ak AccountKeeper) ModuleAccountByName(c context.Context, req *types.QueryM
}

// Bech32Prefix returns the keeper internally stored bech32 prefix.
func (ak AccountKeeper) Bech32Prefix(ctx context.Context, req *types.Bech32PrefixRequest) (*types.Bech32PrefixResponse, error) {
bech32Prefix, err := ak.getBech32Prefix()
func (s queryServer) Bech32Prefix(ctx context.Context, req *types.Bech32PrefixRequest) (*types.Bech32PrefixResponse, error) {
bech32Prefix, err := s.k.getBech32Prefix()
if err != nil {
return nil, err
}
Expand All @@ -175,7 +181,7 @@ func (ak AccountKeeper) Bech32Prefix(ctx context.Context, req *types.Bech32Prefi

// AddressBytesToString converts an address from bytes to string, using the
// keeper's bech32 prefix.
func (ak AccountKeeper) AddressBytesToString(ctx context.Context, req *types.AddressBytesToStringRequest) (*types.AddressBytesToStringResponse, error) {
func (s queryServer) AddressBytesToString(ctx context.Context, req *types.AddressBytesToStringRequest) (*types.AddressBytesToStringResponse, error) {
if req == nil {
return nil, status.Error(codes.InvalidArgument, "empty request")
}
Expand All @@ -184,7 +190,7 @@ func (ak AccountKeeper) AddressBytesToString(ctx context.Context, req *types.Add
return nil, errors.New("empty address bytes is not allowed")
}

text, err := ak.BytesToString(req.AddressBytes)
text, err := s.k.BytesToString(req.AddressBytes)
if err != nil {
return nil, err
}
Expand All @@ -194,7 +200,7 @@ func (ak AccountKeeper) AddressBytesToString(ctx context.Context, req *types.Add

// AddressStringToBytes converts an address from string to bytes, using the
// keeper's bech32 prefix.
func (ak AccountKeeper) AddressStringToBytes(ctx context.Context, req *types.AddressStringToBytesRequest) (*types.AddressStringToBytesResponse, error) {
func (s queryServer) AddressStringToBytes(ctx context.Context, req *types.AddressStringToBytesRequest) (*types.AddressStringToBytesResponse, error) {
if req == nil {
return nil, status.Error(codes.InvalidArgument, "empty request")
}
Expand All @@ -203,7 +209,7 @@ func (ak AccountKeeper) AddressStringToBytes(ctx context.Context, req *types.Add
return nil, errors.New("empty address string is not allowed")
}

bz, err := ak.StringToBytes(req.AddressString)
bz, err := s.k.StringToBytes(req.AddressString)
if err != nil {
return nil, err
}
Expand All @@ -212,7 +218,7 @@ func (ak AccountKeeper) AddressStringToBytes(ctx context.Context, req *types.Add
}

// AccountInfo implements the AccountInfo query.
func (ak AccountKeeper) AccountInfo(goCtx context.Context, req *types.QueryAccountInfoRequest) (*types.QueryAccountInfoResponse, error) {
func (s queryServer) AccountInfo(goCtx context.Context, req *types.QueryAccountInfoRequest) (*types.QueryAccountInfoResponse, error) {
if req == nil {
return nil, status.Errorf(codes.InvalidArgument, "empty request")
}
Expand All @@ -222,12 +228,12 @@ func (ak AccountKeeper) AccountInfo(goCtx context.Context, req *types.QueryAccou
}

ctx := sdk.UnwrapSDKContext(goCtx)
addr, err := ak.StringToBytes(req.Address)
addr, err := s.k.StringToBytes(req.Address)
if err != nil {
return nil, err
}

account := ak.GetAccount(ctx, addr)
account := s.k.GetAccount(ctx, addr)
if account == nil {
return nil, status.Errorf(codes.NotFound, "account %s not found", req.Address)
}
Expand Down
8 changes: 4 additions & 4 deletions x/auth/keeper/keeper.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ type AccountKeeper struct {
authority string

// State
ParamsState collections.Item[types.Params] // NOTE: name is this because it conflicts with the Params gRPC method impl
Params collections.Item[types.Params]
AccountNumber collections.Sequence
}

Expand Down Expand Up @@ -107,7 +107,7 @@ func NewAccountKeeper(
cdc: cdc,
permAddrs: permAddrs,
authority: authority,
ParamsState: collections.NewItem(sb, types.ParamsKey, "params", codec.CollValue[types.Params](cdc)),
Params: collections.NewItem(sb, types.ParamsKey, "params", codec.CollValue[types.Params](cdc)),
AccountNumber: collections.NewSequence(sb, types.GlobalAccountNumberKey, "account_number"),
}
}
Expand Down Expand Up @@ -265,12 +265,12 @@ func (ak AccountKeeper) getBech32Prefix() (string, error) {
// SetParams sets the auth module's parameters.
// CONTRACT: This method performs no validation of the parameters.
func (ak AccountKeeper) SetParams(ctx context.Context, params types.Params) error {
return ak.ParamsState.Set(ctx, params)
return ak.Params.Set(ctx, params)
}

// GetParams gets the auth module's parameters.
func (ak AccountKeeper) GetParams(ctx context.Context) (params types.Params) {
params, err := ak.ParamsState.Get(ctx)
params, err := ak.Params.Get(ctx)
if err != nil && !errors.Is(err, collections.ErrNotFound) {
panic(err)
}
Expand Down
2 changes: 1 addition & 1 deletion x/auth/keeper/keeper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ func (suite *KeeperTestSuite) SetupTest() {
)
suite.msgServer = keeper.NewMsgServerImpl(suite.accountKeeper)
queryHelper := baseapp.NewQueryServerTestHelper(suite.ctx, suite.encCfg.InterfaceRegistry)
types.RegisterQueryServer(queryHelper, suite.accountKeeper)
types.RegisterQueryServer(queryHelper, keeper.NewQueryServer(suite.accountKeeper))
suite.queryClient = types.NewQueryClient(queryHelper)
}

Expand Down
10 changes: 5 additions & 5 deletions x/auth/keeper/msg_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,27 +13,27 @@ import (
var _ types.MsgServer = msgServer{}

type msgServer struct {
AccountKeeper
ak AccountKeeper
}

// NewMsgServerImpl returns an implementation of the x/auth MsgServer interface.
func NewMsgServerImpl(ak AccountKeeper) types.MsgServer {
return &msgServer{
AccountKeeper: ak,
ak: ak,
}
}

func (ms msgServer) UpdateParams(goCtx context.Context, msg *types.MsgUpdateParams) (*types.MsgUpdateParamsResponse, error) {
if ms.authority != msg.Authority {
return nil, errors.Wrapf(govtypes.ErrInvalidSigner, "invalid authority; expected %s, got %s", ms.authority, msg.Authority)
if ms.ak.authority != msg.Authority {
return nil, errors.Wrapf(govtypes.ErrInvalidSigner, "invalid authority; expected %s, got %s", ms.ak.authority, msg.Authority)
}

if err := msg.Params.Validate(); err != nil {
return nil, err
}

ctx := sdk.UnwrapSDKContext(goCtx)
if err := ms.SetParams(ctx, msg.Params); err != nil {
if err := ms.ak.SetParams(ctx, msg.Params); err != nil {
return nil, err
}

Expand Down
2 changes: 1 addition & 1 deletion x/auth/module.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ func (AppModule) Name() string {
// module-specific GRPC queries.
func (am AppModule) RegisterServices(cfg module.Configurator) {
types.RegisterMsgServer(cfg.MsgServer(), keeper.NewMsgServerImpl(am.accountKeeper))
types.RegisterQueryServer(cfg.QueryServer(), am.accountKeeper)
types.RegisterQueryServer(cfg.QueryServer(), keeper.NewQueryServer(am.accountKeeper))

m := keeper.NewMigrator(am.accountKeeper, cfg.QueryServer(), am.legacySubspace)
if err := cfg.RegisterMigration(types.ModuleName, 1, m.Migrate1to2); err != nil {
Expand Down

0 comments on commit 3ada275

Please sign in to comment.