Skip to content

Commit

Permalink
fixes context override
Browse files Browse the repository at this point in the history
A customer reported flags being ignored even when
they were being provided via CLI arguments.

The behavior of overriding current context via
CLI arguments was inconsistent across the codebase.
This centralizes the logic so all the various commands
use the same logic.
  • Loading branch information
vroldanbet committed Sep 18, 2024
1 parent c03fcf6 commit f5626a8
Show file tree
Hide file tree
Showing 9 changed files with 286 additions and 78 deletions.
116 changes: 100 additions & 16 deletions internal/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import (
"github.com/mitchellh/go-homedir"
"github.com/rs/zerolog/log"
"github.com/spf13/cobra"
grpc "google.golang.org/grpc"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"

zgrpcutil "github.com/authzed/zed/internal/grpcutil"
Expand All @@ -28,20 +28,17 @@ type Client interface {
}

// NewClient defines an (overridable) means of creating a new client.
var NewClient = newGRPCClient
var (
NewClient = newClientForCurrentContext
NewClientForContext = newClientForContext
)

func newGRPCClient(cmd *cobra.Command) (Client, error) {
func newClientForCurrentContext(cmd *cobra.Command) (Client, error) {
configStore, secretStore := DefaultStorage()
token, err := storage.DefaultToken(
cobrautil.MustGetString(cmd, "endpoint"),
cobrautil.MustGetString(cmd, "token"),
configStore,
secretStore,
)
token, err := GetCurrentTokenWithCLIOverride(cmd, configStore, secretStore)
if err != nil {
return nil, err
}
log.Trace().Interface("token", token).Send()

dialOpts, err := DialOptsFromFlags(cmd, token)
if err != nil {
Expand All @@ -56,28 +53,115 @@ func newGRPCClient(cmd *cobra.Command) (Client, error) {
return client, err
}

func newClientForContext(cmd *cobra.Command, contextName string, secretStore storage.SecretStore) (*authzed.Client, error) {
currentToken, err := storage.GetToken(contextName, secretStore)
if err != nil {
return nil, err
}

token, err := GetTokenWithCLIOverride(cmd, currentToken)
if err != nil {
return nil, err
}

dialOpts, err := DialOptsFromFlags(cmd, token)
if err != nil {
return nil, err
}

return authzed.NewClient(token.Endpoint, dialOpts...)
}

// GetCurrentTokenWithCLIOverride returns the current token, but overridden by any parameter specified via CLI args
func GetCurrentTokenWithCLIOverride(cmd *cobra.Command, configStore storage.ConfigStore, secretStore storage.SecretStore) (storage.Token, error) {
token, err := storage.CurrentToken(
configStore,
secretStore,
)
if err != nil {
return storage.Token{}, err
}

return GetTokenWithCLIOverride(cmd, token)
}

// GetTokenWithCLIOverride returns the provided token, but overridden by any parameter specified explicitly via command
// flags
func GetTokenWithCLIOverride(cmd *cobra.Command, token storage.Token) (storage.Token, error) {
overrideToken, err := tokenFromCli(cmd)
if err != nil {
return storage.Token{}, err
}

result, err := storage.TokenWithOverride(
overrideToken,
token,
)
if err != nil {
return storage.Token{}, err
}

log.Trace().Bool("context-override-via-cli", overrideToken.AnyValue()).Interface("context", result).Send()
return result, nil
}

func tokenFromCli(cmd *cobra.Command) (storage.Token, error) {
certPath := cobrautil.MustGetStringExpanded(cmd, "certificate-path")
var certBytes []byte
var err error
if certPath != "" {
certBytes, err = os.ReadFile(certPath)
if err != nil {
return storage.Token{}, fmt.Errorf("failed to read ceritficate: %w", err)
}
}

explicitInsecure := cmd.Flags().Changed("insecure")
var notSecure *bool
if explicitInsecure {
i := cobrautil.MustGetBool(cmd, "insecure")
notSecure = &i
}

explicitNoVerifyCA := cmd.Flags().Changed("no-verify-ca")
var notVerifyCA *bool
if explicitNoVerifyCA {
nvc := cobrautil.MustGetBool(cmd, "no-verify-ca")
notVerifyCA = &nvc
}
overrideToken := storage.Token{
APIToken: cobrautil.MustGetString(cmd, "token"),
Endpoint: cobrautil.MustGetString(cmd, "endpoint"),
Insecure: notSecure,
NoVerifyCA: notVerifyCA,
CACert: certBytes,
}
return overrideToken, nil
}

// DefaultStorage returns the default configured config store and secret store.
func DefaultStorage() (storage.ConfigStore, storage.SecretStore) {
var home string
if xdg := os.Getenv("XDG_CONFIG_HOME"); xdg != "" {
home = filepath.Join(xdg, "zed")
} else {
homedir, _ := homedir.Dir()
home = filepath.Join(homedir, ".zed")
hmdir, _ := homedir.Dir()
home = filepath.Join(hmdir, ".zed")
}
return &storage.JSONConfigStore{ConfigPath: home},
&storage.KeychainSecretStore{ConfigPath: home}
}

func certOption(cmd *cobra.Command, token storage.Token) (opt grpc.DialOption, err error) {
func certOption(token storage.Token) (opt grpc.DialOption, err error) {
verification := grpcutil.VerifyCA
if cobrautil.MustGetBool(cmd, "no-verify-ca") || token.HasNoVerifyCA() {
if token.HasNoVerifyCA() {
verification = grpcutil.SkipVerifyCA
}

if certBytes, ok := token.Certificate(); ok {
return grpcutil.WithCustomCertBytes(verification, certBytes)
}

return grpcutil.WithSystemCerts(verification)
}

Expand All @@ -96,12 +180,12 @@ func DialOptsFromFlags(cmd *cobra.Command, token storage.Token) ([]grpc.DialOpti
grpc.WithChainStreamInterceptor(zgrpcutil.StreamLogDispatchTrailers),
}

if cobrautil.MustGetBool(cmd, "insecure") || (token.IsInsecure()) {
if token.IsInsecure() {
opts = append(opts, grpc.WithTransportCredentials(insecure.NewCredentials()))
opts = append(opts, grpcutil.WithInsecureBearerToken(token.APIToken))
} else {
opts = append(opts, grpcutil.WithBearerToken(token.APIToken))
certOpt, err := certOption(cmd, token)
certOpt, err := certOption(token)
if err != nil {
return nil, fmt.Errorf("failed to configure TLS cert: %w", err)
}
Expand Down
62 changes: 62 additions & 0 deletions internal/client/client_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
package client_test

import (
"os"
"testing"

"github.com/authzed/zed/internal/client"
"github.com/authzed/zed/internal/storage"
zedtesting "github.com/authzed/zed/internal/testing"

"github.com/stretchr/testify/require"
)

func TestGetTokenWithCLIOverride(t *testing.T) {
testCert, err := os.CreateTemp("", "")
require.NoError(t, err)
_, err = testCert.Write([]byte("hi"))
require.NoError(t, err)
cmd := zedtesting.CreateTestCobraCommandWithFlagValue(t,
zedtesting.StringFlag{FlagName: "token", FlagValue: "t1", Changed: true},
zedtesting.StringFlag{FlagName: "certificate-path", FlagValue: testCert.Name(), Changed: true},
zedtesting.StringFlag{FlagName: "endpoint", FlagValue: "e1", Changed: true},
zedtesting.BoolFlag{FlagName: "insecure", FlagValue: true, Changed: true},
zedtesting.BoolFlag{FlagName: "no-verify-ca", FlagValue: true, Changed: true},
)

bTrue := true
bFalse := false

// cli args take precedence when defined
to, err := client.GetTokenWithCLIOverride(cmd, storage.Token{})
require.NoError(t, err)
require.True(t, to.AnyValue())
require.Equal(t, "t1", to.APIToken)
require.Equal(t, "e1", to.Endpoint)
require.Equal(t, []byte("hi"), to.CACert)
require.Equal(t, &bTrue, to.Insecure)
require.Equal(t, &bTrue, to.NoVerifyCA)

// storage token takes precedence when defined
cmd = zedtesting.CreateTestCobraCommandWithFlagValue(t,
zedtesting.StringFlag{FlagName: "token", FlagValue: "", Changed: false},
zedtesting.StringFlag{FlagName: "certificate-path", FlagValue: "", Changed: false},
zedtesting.StringFlag{FlagName: "endpoint", FlagValue: "", Changed: false},
zedtesting.BoolFlag{FlagName: "insecure", FlagValue: true, Changed: false},
zedtesting.BoolFlag{FlagName: "no-verify-ca", FlagValue: true, Changed: false},
)
to, err = client.GetTokenWithCLIOverride(cmd, storage.Token{
APIToken: "t2",
Endpoint: "e2",
CACert: []byte("bye"),
Insecure: &bFalse,
NoVerifyCA: &bFalse,
})
require.NoError(t, err)
require.True(t, to.AnyValue())
require.Equal(t, "t2", to.APIToken)
require.Equal(t, "e2", to.Endpoint)
require.Equal(t, []byte("bye"), to.CACert)
require.Equal(t, &bFalse, to.Insecure)
require.Equal(t, &bFalse, to.NoVerifyCA)
}
22 changes: 3 additions & 19 deletions internal/cmd/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import (
"strings"

v1 "github.com/authzed/authzed-go/proto/authzed/api/v1"
"github.com/authzed/authzed-go/v1"
"github.com/authzed/spicedb/pkg/schemadsl/compiler"
"github.com/authzed/spicedb/pkg/schemadsl/generator"
"github.com/authzed/spicedb/pkg/schemadsl/input"
Expand All @@ -23,7 +22,6 @@ import (
"github.com/authzed/zed/internal/client"
"github.com/authzed/zed/internal/commands"
"github.com/authzed/zed/internal/console"
"github.com/authzed/zed/internal/storage"
)

func registerAdditionalSchemaCmds(schemaCmd *cobra.Command) {
Expand Down Expand Up @@ -52,28 +50,14 @@ var schemaCopyCmd = &cobra.Command{
RunE: schemaCopyCmdFunc,
}

// TODO(jschorr): support this in the client package
func clientForContext(cmd *cobra.Command, contextName string, secretStore storage.SecretStore) (*authzed.Client, error) {
token, err := storage.GetToken(contextName, secretStore)
if err != nil {
return nil, err
}
log.Trace().Interface("token", token).Send()

dialOpts, err := client.DialOptsFromFlags(cmd, token)
if err != nil {
return nil, err
}
return authzed.NewClient(token.Endpoint, dialOpts...)
}

func schemaCopyCmdFunc(cmd *cobra.Command, args []string) error {
_, secretStore := client.DefaultStorage()
srcClient, err := clientForContext(cmd, args[0], secretStore)
srcClient, err := client.NewClientForContext(cmd, args[0], secretStore)
if err != nil {
return err
}
destClient, err := clientForContext(cmd, args[1], secretStore)

destClient, err := client.NewClientForContext(cmd, args[1], secretStore)
if err != nil {
return err
}
Expand Down
22 changes: 2 additions & 20 deletions internal/cmd/version.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,12 @@ import (
"github.com/gookit/color"
"github.com/jzelinskie/cobrautil/v2"
"github.com/mattn/go-isatty"
"github.com/rs/zerolog/log"
"github.com/spf13/cobra"
"google.golang.org/grpc"
"google.golang.org/grpc/metadata"

"github.com/authzed/zed/internal/client"
"github.com/authzed/zed/internal/console"
"github.com/authzed/zed/internal/storage"
)

func versionCmdFunc(cmd *cobra.Command, _ []string) error {
Expand All @@ -26,14 +24,9 @@ func versionCmdFunc(cmd *cobra.Command, _ []string) error {

includeRemoteVersion := cobrautil.MustGetBool(cmd, "include-remote-version")
hasContext := false
configStore, secretStore := client.DefaultStorage()
if includeRemoteVersion {
_, err := storage.DefaultToken(
cobrautil.MustGetString(cmd, "endpoint"),
cobrautil.MustGetString(cmd, "token"),
configStore,
secretStore,
)
configStore, secretStore := client.DefaultStorage()
_, err := client.GetCurrentTokenWithCLIOverride(cmd, configStore, secretStore)
hasContext = err == nil
}

Expand All @@ -45,17 +38,6 @@ func versionCmdFunc(cmd *cobra.Command, _ []string) error {
console.Println(cobrautil.UsageVersion("zed", cobrautil.MustGetBool(cmd, "include-deps")))

if hasContext && includeRemoteVersion {
token, err := storage.DefaultToken(
cobrautil.MustGetString(cmd, "endpoint"),
cobrautil.MustGetString(cmd, "token"),
configStore,
secretStore,
)
if err != nil {
return err
}
log.Trace().Interface("token", token).Send()

client, err := client.NewClient(cmd)
if err != nil {
return err
Expand Down
Loading

0 comments on commit f5626a8

Please sign in to comment.