Skip to content

Commit

Permalink
fixes context override
Browse files Browse the repository at this point in the history
A customer reported flags being ignored even when
they were being provided via CLI arguments.

The behavior of overriding current context via
CLI arguments was inconsistent across the codebase.
This centralizes the logic so all the various commands
use the same logic.
  • Loading branch information
vroldanbet committed Sep 18, 2024
1 parent c03fcf6 commit 93e0023
Show file tree
Hide file tree
Showing 9 changed files with 286 additions and 79 deletions.
114 changes: 98 additions & 16 deletions internal/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import (
"github.com/mitchellh/go-homedir"
"github.com/rs/zerolog/log"
"github.com/spf13/cobra"
grpc "google.golang.org/grpc"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"

zgrpcutil "github.com/authzed/zed/internal/grpcutil"
Expand All @@ -28,20 +28,15 @@ type Client interface {
}

// NewClient defines an (overridable) means of creating a new client.
var NewClient = newGRPCClient
var NewClient = newClientForCurrentContext
var NewClientForContext = newClientForContext

func newGRPCClient(cmd *cobra.Command) (Client, error) {
func newClientForCurrentContext(cmd *cobra.Command) (Client, error) {
configStore, secretStore := DefaultStorage()
token, err := storage.DefaultToken(
cobrautil.MustGetString(cmd, "endpoint"),
cobrautil.MustGetString(cmd, "token"),
configStore,
secretStore,
)
token, err := GetCurrentTokenWithCLIOverride(cmd, configStore, secretStore)
if err != nil {
return nil, err
}
log.Trace().Interface("token", token).Send()

dialOpts, err := DialOptsFromFlags(cmd, token)
if err != nil {
Expand All @@ -56,28 +51,115 @@ func newGRPCClient(cmd *cobra.Command) (Client, error) {
return client, err
}

func newClientForContext(cmd *cobra.Command, contextName string, secretStore storage.SecretStore) (*authzed.Client, error) {
currentToken, err := storage.GetToken(contextName, secretStore)
if err != nil {
return nil, err
}

token, err := GetTokenWithCLIOverride(cmd, currentToken)
if err != nil {
return nil, err
}

dialOpts, err := DialOptsFromFlags(cmd, token)
if err != nil {
return nil, err
}

return authzed.NewClient(token.Endpoint, dialOpts...)
}

// GetCurrentTokenWithCLIOverride returns the current token, but overridden by any parameter specified via CLI args
func GetCurrentTokenWithCLIOverride(cmd *cobra.Command, configStore storage.ConfigStore, secretStore storage.SecretStore) (storage.Token, error) {
token, err := storage.CurrentToken(
configStore,
secretStore,
)
if err != nil {
return storage.Token{}, err
}

return GetTokenWithCLIOverride(cmd, token)
}

// GetTokenWithCLIOverride returns the provided token, but overridden by any parameter specified explicitly via command
// flags
func GetTokenWithCLIOverride(cmd *cobra.Command, token storage.Token) (storage.Token, error) {
overrideToken, err := tokenFromCli(cmd)
if err != nil {
return storage.Token{}, err
}

result, err := storage.TokenWithOverride(
overrideToken,
token,
)
if err != nil {
return storage.Token{}, err
}

log.Trace().Bool("context-override-via-cli", overrideToken.AnyValue()).Interface("context", result).Send()
return result, nil
}

func tokenFromCli(cmd *cobra.Command) (storage.Token, error) {
certPath := cobrautil.MustGetStringExpanded(cmd, "certificate-path")
var certBytes []byte
var err error
if certPath != "" {
certBytes, err = os.ReadFile(certPath)
if err != nil {
return storage.Token{}, fmt.Errorf("failed to read ceritficate: %w", err)
}
}

explicitInsecure := cmd.Flags().Changed("insecure")
var notSecure *bool
if explicitInsecure {
i := cobrautil.MustGetBool(cmd, "insecure")
notSecure = &i
}

explicitNoVerifyCA := cmd.Flags().Changed("no-verify-ca")
var notVerifyCA *bool
if explicitNoVerifyCA {
nvc := cobrautil.MustGetBool(cmd, "no-verify-ca")
notVerifyCA = &nvc
}
overrideToken := storage.Token{
APIToken: cobrautil.MustGetString(cmd, "token"),
Endpoint: cobrautil.MustGetString(cmd, "endpoint"),
Insecure: notSecure,
NoVerifyCA: notVerifyCA,
CACert: certBytes,
}
return overrideToken, nil
}

// DefaultStorage returns the default configured config store and secret store.
func DefaultStorage() (storage.ConfigStore, storage.SecretStore) {
var home string
if xdg := os.Getenv("XDG_CONFIG_HOME"); xdg != "" {
home = filepath.Join(xdg, "zed")
} else {
homedir, _ := homedir.Dir()
home = filepath.Join(homedir, ".zed")
hmdir, _ := homedir.Dir()
home = filepath.Join(hmdir, ".zed")
}
return &storage.JSONConfigStore{ConfigPath: home},
&storage.KeychainSecretStore{ConfigPath: home}
}

func certOption(cmd *cobra.Command, token storage.Token) (opt grpc.DialOption, err error) {
func certOption(token storage.Token) (opt grpc.DialOption, err error) {
verification := grpcutil.VerifyCA
if cobrautil.MustGetBool(cmd, "no-verify-ca") || token.HasNoVerifyCA() {
if token.HasNoVerifyCA() {
verification = grpcutil.SkipVerifyCA
}

if certBytes, ok := token.Certificate(); ok {
return grpcutil.WithCustomCertBytes(verification, certBytes)
}

return grpcutil.WithSystemCerts(verification)
}

Expand All @@ -96,12 +178,12 @@ func DialOptsFromFlags(cmd *cobra.Command, token storage.Token) ([]grpc.DialOpti
grpc.WithChainStreamInterceptor(zgrpcutil.StreamLogDispatchTrailers),
}

if cobrautil.MustGetBool(cmd, "insecure") || (token.IsInsecure()) {
if token.IsInsecure() {
opts = append(opts, grpc.WithTransportCredentials(insecure.NewCredentials()))
opts = append(opts, grpcutil.WithInsecureBearerToken(token.APIToken))
} else {
opts = append(opts, grpcutil.WithBearerToken(token.APIToken))
certOpt, err := certOption(cmd, token)
certOpt, err := certOption(token)
if err != nil {
return nil, fmt.Errorf("failed to configure TLS cert: %w", err)
}
Expand Down
63 changes: 63 additions & 0 deletions internal/client/client_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
package client_test

Check failure on line 1 in internal/client/client_test.go

View workflow job for this annotation

GitHub Actions / Lint Go

Please run gofumpt. diff --git a/internal/client/client_test.go b/internal/client/client_test.go index bad0f4d..e1a0d92 100644 --- a/internal/client/client_test.go +++ b/internal/client/client_test.go @@ -59,5 +59,4 @@ func TestGetTokenWithCLIOverride(t *testing.T) { require.Equal(t, []byte("bye"), to.CACert) require.Equal(t, &bFalse, to.Insecure) require.Equal(t, &bFalse, to.NoVerifyCA) - }

import (
"os"
"testing"

"github.com/authzed/zed/internal/client"
"github.com/authzed/zed/internal/storage"
zedtesting "github.com/authzed/zed/internal/testing"

"github.com/stretchr/testify/require"
)

func TestGetTokenWithCLIOverride(t *testing.T) {
testCert, err := os.CreateTemp("", "")
require.NoError(t, err)
_, err = testCert.Write([]byte("hi"))
require.NoError(t, err)
cmd := zedtesting.CreateTestCobraCommandWithFlagValue(t,
zedtesting.StringFlag{FlagName: "token", FlagValue: "t1", Changed: true},
zedtesting.StringFlag{FlagName: "certificate-path", FlagValue: testCert.Name(), Changed: true},
zedtesting.StringFlag{FlagName: "endpoint", FlagValue: "e1", Changed: true},
zedtesting.BoolFlag{FlagName: "insecure", FlagValue: true, Changed: true},
zedtesting.BoolFlag{FlagName: "no-verify-ca", FlagValue: true, Changed: true},
)

bTrue := true
bFalse := false

// cli args take precedence when defined
to, err := client.GetTokenWithCLIOverride(cmd, storage.Token{})
require.NoError(t, err)
require.True(t, to.AnyValue())
require.Equal(t, "t1", to.APIToken)
require.Equal(t, "e1", to.Endpoint)
require.Equal(t, []byte("hi"), to.CACert)
require.Equal(t, &bTrue, to.Insecure)
require.Equal(t, &bTrue, to.NoVerifyCA)

// storage token takes precedence when defined
cmd = zedtesting.CreateTestCobraCommandWithFlagValue(t,
zedtesting.StringFlag{FlagName: "token", FlagValue: "", Changed: false},
zedtesting.StringFlag{FlagName: "certificate-path", FlagValue: "", Changed: false},
zedtesting.StringFlag{FlagName: "endpoint", FlagValue: "", Changed: false},
zedtesting.BoolFlag{FlagName: "insecure", FlagValue: true, Changed: false},
zedtesting.BoolFlag{FlagName: "no-verify-ca", FlagValue: true, Changed: false},
)
to, err = client.GetTokenWithCLIOverride(cmd, storage.Token{
APIToken: "t2",
Endpoint: "e2",
CACert: []byte("bye"),
Insecure: &bFalse,
NoVerifyCA: &bFalse,
})
require.NoError(t, err)
require.True(t, to.AnyValue())
require.Equal(t, "t2", to.APIToken)
require.Equal(t, "e2", to.Endpoint)
require.Equal(t, []byte("bye"), to.CACert)
require.Equal(t, &bFalse, to.Insecure)
require.Equal(t, &bFalse, to.NoVerifyCA)

}
22 changes: 3 additions & 19 deletions internal/cmd/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import (
"strings"

v1 "github.com/authzed/authzed-go/proto/authzed/api/v1"
"github.com/authzed/authzed-go/v1"
"github.com/authzed/spicedb/pkg/schemadsl/compiler"
"github.com/authzed/spicedb/pkg/schemadsl/generator"
"github.com/authzed/spicedb/pkg/schemadsl/input"
Expand All @@ -23,7 +22,6 @@ import (
"github.com/authzed/zed/internal/client"
"github.com/authzed/zed/internal/commands"
"github.com/authzed/zed/internal/console"
"github.com/authzed/zed/internal/storage"
)

func registerAdditionalSchemaCmds(schemaCmd *cobra.Command) {
Expand Down Expand Up @@ -52,28 +50,14 @@ var schemaCopyCmd = &cobra.Command{
RunE: schemaCopyCmdFunc,
}

// TODO(jschorr): support this in the client package
func clientForContext(cmd *cobra.Command, contextName string, secretStore storage.SecretStore) (*authzed.Client, error) {
token, err := storage.GetToken(contextName, secretStore)
if err != nil {
return nil, err
}
log.Trace().Interface("token", token).Send()

dialOpts, err := client.DialOptsFromFlags(cmd, token)
if err != nil {
return nil, err
}
return authzed.NewClient(token.Endpoint, dialOpts...)
}

func schemaCopyCmdFunc(cmd *cobra.Command, args []string) error {
_, secretStore := client.DefaultStorage()
srcClient, err := clientForContext(cmd, args[0], secretStore)
srcClient, err := client.NewClientForContext(cmd, args[0], secretStore)
if err != nil {
return err
}
destClient, err := clientForContext(cmd, args[1], secretStore)

destClient, err := client.NewClientForContext(cmd, args[1], secretStore)
if err != nil {
return err
}
Expand Down
24 changes: 3 additions & 21 deletions internal/cmd/version.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,12 @@ import (
"github.com/gookit/color"
"github.com/jzelinskie/cobrautil/v2"
"github.com/mattn/go-isatty"
"github.com/rs/zerolog/log"
"github.com/spf13/cobra"
"google.golang.org/grpc"
"google.golang.org/grpc/metadata"

"github.com/authzed/zed/internal/client"
"github.com/authzed/zed/internal/console"
"github.com/authzed/zed/internal/storage"
)

func versionCmdFunc(cmd *cobra.Command, _ []string) error {
Expand All @@ -26,14 +24,9 @@ func versionCmdFunc(cmd *cobra.Command, _ []string) error {

includeRemoteVersion := cobrautil.MustGetBool(cmd, "include-remote-version")
hasContext := false
configStore, secretStore := client.DefaultStorage()
if includeRemoteVersion {
_, err := storage.DefaultToken(
cobrautil.MustGetString(cmd, "endpoint"),
cobrautil.MustGetString(cmd, "token"),
configStore,
secretStore,
)
configStore, secretStore := client.DefaultStorage()
_, err := client.GetCurrentTokenWithCLIOverride(cmd, configStore, secretStore)
hasContext = err == nil
}

Expand All @@ -45,17 +38,6 @@ func versionCmdFunc(cmd *cobra.Command, _ []string) error {
console.Println(cobrautil.UsageVersion("zed", cobrautil.MustGetBool(cmd, "include-deps")))

if hasContext && includeRemoteVersion {
token, err := storage.DefaultToken(
cobrautil.MustGetString(cmd, "endpoint"),
cobrautil.MustGetString(cmd, "token"),
configStore,
secretStore,
)
if err != nil {
return err
}
log.Trace().Interface("token", token).Send()

client, err := client.NewClient(cmd)
if err != nil {
return err
Expand All @@ -65,7 +47,7 @@ func versionCmdFunc(cmd *cobra.Command, _ []string) error {
// the client being unable to connect, etc. We just treat all such cases as an unknown
// version.
var headerMD metadata.MD
_, _ = client.ReadSchema(cmd.Context(), &v1.ReadSchemaRequest{}, grpc.Header(&headerMD))
_, err = client.ReadSchema(cmd.Context(), &v1.ReadSchemaRequest{}, grpc.Header(&headerMD))
version := headerMD.Get(string(responsemeta.ServerVersion))

blue := color.FgLightBlue.Render
Expand Down
Loading

0 comments on commit 93e0023

Please sign in to comment.