Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fixes context override #417

Merged
merged 1 commit into from
Sep 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 100 additions & 16 deletions internal/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import (
"github.com/mitchellh/go-homedir"
"github.com/rs/zerolog/log"
"github.com/spf13/cobra"
grpc "google.golang.org/grpc"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"

zgrpcutil "github.com/authzed/zed/internal/grpcutil"
Expand All @@ -28,20 +28,17 @@ type Client interface {
}

// NewClient defines an (overridable) means of creating a new client.
var NewClient = newGRPCClient
var (
NewClient = newClientForCurrentContext
NewClientForContext = newClientForContext
)

func newGRPCClient(cmd *cobra.Command) (Client, error) {
func newClientForCurrentContext(cmd *cobra.Command) (Client, error) {
configStore, secretStore := DefaultStorage()
token, err := storage.DefaultToken(
cobrautil.MustGetString(cmd, "endpoint"),
cobrautil.MustGetString(cmd, "token"),
configStore,
secretStore,
)
token, err := GetCurrentTokenWithCLIOverride(cmd, configStore, secretStore)
if err != nil {
return nil, err
}
log.Trace().Interface("token", token).Send()

dialOpts, err := DialOptsFromFlags(cmd, token)
if err != nil {
Expand All @@ -56,28 +53,115 @@ func newGRPCClient(cmd *cobra.Command) (Client, error) {
return client, err
}

func newClientForContext(cmd *cobra.Command, contextName string, secretStore storage.SecretStore) (*authzed.Client, error) {
currentToken, err := storage.GetToken(contextName, secretStore)
if err != nil {
return nil, err
}

token, err := GetTokenWithCLIOverride(cmd, currentToken)
if err != nil {
return nil, err
}

dialOpts, err := DialOptsFromFlags(cmd, token)
if err != nil {
return nil, err
}

return authzed.NewClient(token.Endpoint, dialOpts...)
}

// GetCurrentTokenWithCLIOverride returns the current token, but overridden by any parameter specified via CLI args
func GetCurrentTokenWithCLIOverride(cmd *cobra.Command, configStore storage.ConfigStore, secretStore storage.SecretStore) (storage.Token, error) {
token, err := storage.CurrentToken(
configStore,
secretStore,
)
if err != nil {
return storage.Token{}, err
}

return GetTokenWithCLIOverride(cmd, token)
}

// GetTokenWithCLIOverride returns the provided token, but overridden by any parameter specified explicitly via command
// flags
func GetTokenWithCLIOverride(cmd *cobra.Command, token storage.Token) (storage.Token, error) {
overrideToken, err := tokenFromCli(cmd)
if err != nil {
return storage.Token{}, err
}

result, err := storage.TokenWithOverride(
overrideToken,
token,
)
if err != nil {
return storage.Token{}, err
}

log.Trace().Bool("context-override-via-cli", overrideToken.AnyValue()).Interface("context", result).Send()
return result, nil
}

func tokenFromCli(cmd *cobra.Command) (storage.Token, error) {
certPath := cobrautil.MustGetStringExpanded(cmd, "certificate-path")
var certBytes []byte
var err error
if certPath != "" {
certBytes, err = os.ReadFile(certPath)
if err != nil {
return storage.Token{}, fmt.Errorf("failed to read ceritficate: %w", err)
}
}

explicitInsecure := cmd.Flags().Changed("insecure")
var notSecure *bool
if explicitInsecure {
i := cobrautil.MustGetBool(cmd, "insecure")
notSecure = &i
}

explicitNoVerifyCA := cmd.Flags().Changed("no-verify-ca")
var notVerifyCA *bool
if explicitNoVerifyCA {
nvc := cobrautil.MustGetBool(cmd, "no-verify-ca")
notVerifyCA = &nvc
}
overrideToken := storage.Token{
APIToken: cobrautil.MustGetString(cmd, "token"),
Endpoint: cobrautil.MustGetString(cmd, "endpoint"),
Insecure: notSecure,
NoVerifyCA: notVerifyCA,
CACert: certBytes,
}
return overrideToken, nil
}

// DefaultStorage returns the default configured config store and secret store.
func DefaultStorage() (storage.ConfigStore, storage.SecretStore) {
var home string
if xdg := os.Getenv("XDG_CONFIG_HOME"); xdg != "" {
home = filepath.Join(xdg, "zed")
} else {
homedir, _ := homedir.Dir()
home = filepath.Join(homedir, ".zed")
hmdir, _ := homedir.Dir()
home = filepath.Join(hmdir, ".zed")
}
return &storage.JSONConfigStore{ConfigPath: home},
&storage.KeychainSecretStore{ConfigPath: home}
}

func certOption(cmd *cobra.Command, token storage.Token) (opt grpc.DialOption, err error) {
func certOption(token storage.Token) (opt grpc.DialOption, err error) {
verification := grpcutil.VerifyCA
if cobrautil.MustGetBool(cmd, "no-verify-ca") || token.HasNoVerifyCA() {
if token.HasNoVerifyCA() {
verification = grpcutil.SkipVerifyCA
}

if certBytes, ok := token.Certificate(); ok {
return grpcutil.WithCustomCertBytes(verification, certBytes)
}

return grpcutil.WithSystemCerts(verification)
}

Expand All @@ -96,12 +180,12 @@ func DialOptsFromFlags(cmd *cobra.Command, token storage.Token) ([]grpc.DialOpti
grpc.WithChainStreamInterceptor(zgrpcutil.StreamLogDispatchTrailers),
}

if cobrautil.MustGetBool(cmd, "insecure") || (token.IsInsecure()) {
if token.IsInsecure() {
opts = append(opts, grpc.WithTransportCredentials(insecure.NewCredentials()))
opts = append(opts, grpcutil.WithInsecureBearerToken(token.APIToken))
} else {
opts = append(opts, grpcutil.WithBearerToken(token.APIToken))
certOpt, err := certOption(cmd, token)
certOpt, err := certOption(token)
if err != nil {
return nil, fmt.Errorf("failed to configure TLS cert: %w", err)
}
Expand Down
62 changes: 62 additions & 0 deletions internal/client/client_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
package client_test

import (
"os"
"testing"

"github.com/authzed/zed/internal/client"
"github.com/authzed/zed/internal/storage"
zedtesting "github.com/authzed/zed/internal/testing"

"github.com/stretchr/testify/require"
)

func TestGetTokenWithCLIOverride(t *testing.T) {
testCert, err := os.CreateTemp("", "")
require.NoError(t, err)
_, err = testCert.Write([]byte("hi"))
require.NoError(t, err)
cmd := zedtesting.CreateTestCobraCommandWithFlagValue(t,
zedtesting.StringFlag{FlagName: "token", FlagValue: "t1", Changed: true},
zedtesting.StringFlag{FlagName: "certificate-path", FlagValue: testCert.Name(), Changed: true},
zedtesting.StringFlag{FlagName: "endpoint", FlagValue: "e1", Changed: true},
zedtesting.BoolFlag{FlagName: "insecure", FlagValue: true, Changed: true},
zedtesting.BoolFlag{FlagName: "no-verify-ca", FlagValue: true, Changed: true},
)

bTrue := true
bFalse := false

// cli args take precedence when defined
to, err := client.GetTokenWithCLIOverride(cmd, storage.Token{})
require.NoError(t, err)
require.True(t, to.AnyValue())
require.Equal(t, "t1", to.APIToken)
require.Equal(t, "e1", to.Endpoint)
require.Equal(t, []byte("hi"), to.CACert)
require.Equal(t, &bTrue, to.Insecure)
require.Equal(t, &bTrue, to.NoVerifyCA)

// storage token takes precedence when defined
cmd = zedtesting.CreateTestCobraCommandWithFlagValue(t,
zedtesting.StringFlag{FlagName: "token", FlagValue: "", Changed: false},
zedtesting.StringFlag{FlagName: "certificate-path", FlagValue: "", Changed: false},
zedtesting.StringFlag{FlagName: "endpoint", FlagValue: "", Changed: false},
zedtesting.BoolFlag{FlagName: "insecure", FlagValue: true, Changed: false},
zedtesting.BoolFlag{FlagName: "no-verify-ca", FlagValue: true, Changed: false},
)
to, err = client.GetTokenWithCLIOverride(cmd, storage.Token{
APIToken: "t2",
Endpoint: "e2",
CACert: []byte("bye"),
Insecure: &bFalse,
NoVerifyCA: &bFalse,
})
require.NoError(t, err)
require.True(t, to.AnyValue())
require.Equal(t, "t2", to.APIToken)
require.Equal(t, "e2", to.Endpoint)
require.Equal(t, []byte("bye"), to.CACert)
require.Equal(t, &bFalse, to.Insecure)
require.Equal(t, &bFalse, to.NoVerifyCA)
}
22 changes: 3 additions & 19 deletions internal/cmd/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import (
"strings"

v1 "github.com/authzed/authzed-go/proto/authzed/api/v1"
"github.com/authzed/authzed-go/v1"
"github.com/authzed/spicedb/pkg/schemadsl/compiler"
"github.com/authzed/spicedb/pkg/schemadsl/generator"
"github.com/authzed/spicedb/pkg/schemadsl/input"
Expand All @@ -23,7 +22,6 @@ import (
"github.com/authzed/zed/internal/client"
"github.com/authzed/zed/internal/commands"
"github.com/authzed/zed/internal/console"
"github.com/authzed/zed/internal/storage"
)

func registerAdditionalSchemaCmds(schemaCmd *cobra.Command) {
Expand Down Expand Up @@ -52,28 +50,14 @@ var schemaCopyCmd = &cobra.Command{
RunE: schemaCopyCmdFunc,
}

// TODO(jschorr): support this in the client package
func clientForContext(cmd *cobra.Command, contextName string, secretStore storage.SecretStore) (*authzed.Client, error) {
token, err := storage.GetToken(contextName, secretStore)
if err != nil {
return nil, err
}
log.Trace().Interface("token", token).Send()

dialOpts, err := client.DialOptsFromFlags(cmd, token)
if err != nil {
return nil, err
}
return authzed.NewClient(token.Endpoint, dialOpts...)
}

func schemaCopyCmdFunc(cmd *cobra.Command, args []string) error {
_, secretStore := client.DefaultStorage()
srcClient, err := clientForContext(cmd, args[0], secretStore)
srcClient, err := client.NewClientForContext(cmd, args[0], secretStore)
if err != nil {
return err
}
destClient, err := clientForContext(cmd, args[1], secretStore)

destClient, err := client.NewClientForContext(cmd, args[1], secretStore)
if err != nil {
return err
}
Expand Down
22 changes: 2 additions & 20 deletions internal/cmd/version.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,12 @@ import (
"github.com/gookit/color"
"github.com/jzelinskie/cobrautil/v2"
"github.com/mattn/go-isatty"
"github.com/rs/zerolog/log"
"github.com/spf13/cobra"
"google.golang.org/grpc"
"google.golang.org/grpc/metadata"

"github.com/authzed/zed/internal/client"
"github.com/authzed/zed/internal/console"
"github.com/authzed/zed/internal/storage"
)

func versionCmdFunc(cmd *cobra.Command, _ []string) error {
Expand All @@ -26,14 +24,9 @@ func versionCmdFunc(cmd *cobra.Command, _ []string) error {

includeRemoteVersion := cobrautil.MustGetBool(cmd, "include-remote-version")
hasContext := false
configStore, secretStore := client.DefaultStorage()
if includeRemoteVersion {
_, err := storage.DefaultToken(
cobrautil.MustGetString(cmd, "endpoint"),
cobrautil.MustGetString(cmd, "token"),
configStore,
secretStore,
)
configStore, secretStore := client.DefaultStorage()
_, err := client.GetCurrentTokenWithCLIOverride(cmd, configStore, secretStore)
hasContext = err == nil
}

Expand All @@ -45,17 +38,6 @@ func versionCmdFunc(cmd *cobra.Command, _ []string) error {
console.Println(cobrautil.UsageVersion("zed", cobrautil.MustGetBool(cmd, "include-deps")))

if hasContext && includeRemoteVersion {
token, err := storage.DefaultToken(
cobrautil.MustGetString(cmd, "endpoint"),
cobrautil.MustGetString(cmd, "token"),
configStore,
secretStore,
)
if err != nil {
return err
}
log.Trace().Interface("token", token).Send()

client, err := client.NewClient(cmd)
if err != nil {
return err
Expand Down
Loading
Loading