From 83747c5325cded7d220aa5bd556f83ff3034d16f Mon Sep 17 00:00:00 2001 From: "mergify[bot]" <37929162+mergify[bot]@users.noreply.github.com> Date: Tue, 14 May 2024 13:59:32 +0000 Subject: [PATCH] feat(client): overwrite client context instead of setting new one (backport #20356) (#20383) Co-authored-by: Shude Li Co-authored-by: Julien Robert --- CHANGELOG.md | 1 + client/cmd.go | 13 ++++++++----- client/cmd_test.go | 26 +++++++++++++++++++++++--- 3 files changed, 32 insertions(+), 8 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index fcc161c0a906..5c29d5c23d1e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -42,6 +42,7 @@ Ref: https://keepachangelog.com/en/1.0.0/ * (debug) [#20328](https://github.com/cosmos/cosmos-sdk/pull/20328) Add consensus address for debug cmd. * (runtime) [#20264](https://github.com/cosmos/cosmos-sdk/pull/20264) Expose grpc query router via depinject. +* (client) [#20356](https://github.com/cosmos/cosmos-sdk/pull/20356) Overwrite client context when available in `SetCmdClientContext`. ### Bug Fixes diff --git a/client/cmd.go b/client/cmd.go index 947df1ceaac1..34b433ec5f3a 100644 --- a/client/cmd.go +++ b/client/cmd.go @@ -359,14 +359,17 @@ func GetClientContextFromCmd(cmd *cobra.Command) Context { // SetCmdClientContext sets a command's Context value to the provided argument. // If the context has not been set, set the given context as the default. func SetCmdClientContext(cmd *cobra.Command, clientCtx Context) error { - var cmdCtx context.Context - - if cmd.Context() == nil { + cmdCtx := cmd.Context() + if cmdCtx == nil { cmdCtx = context.Background() + } + + v := cmd.Context().Value(ClientContextKey) + if clientCtxPtr, ok := v.(*Context); ok { + *clientCtxPtr = clientCtx } else { - cmdCtx = cmd.Context() + cmd.SetContext(context.WithValue(cmdCtx, ClientContextKey, &clientCtx)) } - cmd.SetContext(context.WithValue(cmdCtx, ClientContextKey, &clientCtx)) return nil } diff --git a/client/cmd_test.go b/client/cmd_test.go index 81d5719ccfb6..559d31d39f40 100644 --- a/client/cmd_test.go +++ b/client/cmd_test.go @@ -79,11 +79,13 @@ func TestSetCmdClientContextHandler(t *testing.T) { name string expectedContext client.Context args []string + ctx context.Context }{ { "no flags set", initClientCtx, []string{}, + context.WithValue(context.Background(), client.ClientContextKey, &client.Context{}), }, { "flags set", @@ -91,6 +93,7 @@ func TestSetCmdClientContextHandler(t *testing.T) { []string{ fmt.Sprintf("--%s=new-chain-id", flags.FlagChainID), }, + context.WithValue(context.Background(), client.ClientContextKey, &client.Context{}), }, { "flags set with space", @@ -99,6 +102,25 @@ func TestSetCmdClientContextHandler(t *testing.T) { fmt.Sprintf("--%s", flags.FlagHome), "/tmp/dir", }, + context.Background(), + }, + { + "no context provided", + initClientCtx.WithHomeDir("/tmp/noctx"), + []string{ + fmt.Sprintf("--%s", flags.FlagHome), + "/tmp/noctx", + }, + nil, + }, + { + "with invalid client value in the context", + initClientCtx.WithHomeDir("/tmp/invalid"), + []string{ + fmt.Sprintf("--%s", flags.FlagHome), + "/tmp/invalid", + }, + context.WithValue(context.Background(), client.ClientContextKey, "invalid"), }, } @@ -106,13 +128,11 @@ func TestSetCmdClientContextHandler(t *testing.T) { tc := tc t.Run(tc.name, func(t *testing.T) { - ctx := context.WithValue(context.Background(), client.ClientContextKey, &client.Context{}) - cmd := newCmd() _ = testutil.ApplyMockIODiscardOutErr(cmd) cmd.SetArgs(tc.args) - require.NoError(t, cmd.ExecuteContext(ctx)) + require.NoError(t, cmd.ExecuteContext(tc.ctx)) clientCtx := client.GetClientContextFromCmd(cmd) require.Equal(t, tc.expectedContext, clientCtx)