diff --git a/cmd/zed/import.go b/cmd/zed/import.go index b4d68514..45dbaf34 100644 --- a/cmd/zed/import.go +++ b/cmd/zed/import.go @@ -9,16 +9,10 @@ import ( v1 "github.com/authzed/authzed-go/proto/authzed/api/v1" "github.com/authzed/authzed-go/v1" - "github.com/authzed/spicedb/pkg/schemadsl/compiler" - "github.com/authzed/spicedb/pkg/schemadsl/generator" - "github.com/authzed/spicedb/pkg/schemadsl/input" "github.com/authzed/spicedb/pkg/tuple" "github.com/jzelinskie/cobrautil" - "github.com/jzelinskie/stringz" "github.com/rs/zerolog/log" "github.com/spf13/cobra" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" "github.com/authzed/zed/internal/decode" "github.com/authzed/zed/internal/storage" @@ -84,54 +78,6 @@ func importCmdFunc(cmd *cobra.Command, args []string) error { return err } - // Read the existing schema (if any) to get the prefix. - prefix := cobrautil.MustGetString(cmd, "schema-definition-prefix") - if prefix == "" { - request := &v1.ReadSchemaRequest{} - log.Trace().Interface("request", request).Msg("requesting schema read") - - resp, err := client.ReadSchema(context.Background(), request) - if err != nil { - // If the schema was not found, then just use the empty prefix. - errStatus, ok := status.FromError(err) - if !ok || errStatus.Code() != codes.NotFound { - return err - } - - log.Debug().Msg("no schema defined") - } else { - empty := "" - found, err := compiler.Compile([]compiler.InputSchema{ - {Source: input.Source("schema"), SchemaString: resp.SchemaText}, - }, &empty) - if err != nil { - return err - } - - foundPrefixes := make([]string, 0, len(found)) - for _, def := range found { - if strings.Contains(def.Name, "/") { - parts := strings.Split(def.Name, "/") - foundPrefixes = append(foundPrefixes, parts[0]) - } else { - foundPrefixes = append(foundPrefixes, "") - } - } - - prefixes := stringz.Dedup(foundPrefixes) - if len(prefixes) == 0 { - return fmt.Errorf("found no schema definition prefixes") - } - - if len(prefixes) > 1 { - return fmt.Errorf("found multiple schema definition prefixes: %v", prefixes) - } - - prefix = prefixes[0] - log.Debug().Str("prefix", prefix).Msg("found schema definition prefix") - } - } - u, err := url.Parse(args[0]) if err != nil { return err @@ -146,6 +92,11 @@ func importCmdFunc(cmd *cobra.Command, args []string) error { return err } + prefix, err := determinePrefixForSchema(cobrautil.MustGetString(cmd, "schema-definition-prefix"), client, nil) + if err != nil { + return err + } + if cobrautil.MustGetBool(cmd, "schema") { if err := importSchema(client, p.Schema, prefix); err != nil { return err @@ -165,21 +116,11 @@ func importSchema(client *authzed.Client, schema string, definitionPrefix string log.Info().Msg("importing schema") // Recompile the schema with the specified prefix. - nsDefs, err := compiler.Compile([]compiler.InputSchema{ - {Source: input.Source("schema"), SchemaString: schema}, - }, &definitionPrefix) + schemaText, err := rewriteSchema(schema, definitionPrefix) if err != nil { return err } - objectDefs := make([]string, 0, len(nsDefs)) - for _, nsDef := range nsDefs { - objectDef, _ := generator.GenerateSource(nsDef) - objectDefs = append(objectDefs, objectDef) - } - - schemaText := strings.Join(objectDefs, "\n\n") - // Write the recompiled and regenerated schema. request := &v1.WriteSchemaRequest{Schema: schemaText} log.Trace().Interface("request", request).Str("schema", schemaText).Msg("writing schema") diff --git a/cmd/zed/schema.go b/cmd/zed/schema.go index efccd015..7121b8aa 100644 --- a/cmd/zed/schema.go +++ b/cmd/zed/schema.go @@ -7,15 +7,21 @@ import ( "fmt" "io/ioutil" "os" + "strings" "github.com/TylerBrock/colorjson" v1 "github.com/authzed/authzed-go/proto/authzed/api/v1" "github.com/authzed/authzed-go/v1" + "github.com/authzed/spicedb/pkg/schemadsl/compiler" + "github.com/authzed/spicedb/pkg/schemadsl/generator" + "github.com/authzed/spicedb/pkg/schemadsl/input" "github.com/jzelinskie/cobrautil" "github.com/jzelinskie/stringz" "github.com/rs/zerolog/log" "github.com/spf13/cobra" "golang.org/x/term" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" "google.golang.org/protobuf/encoding/protojson" "google.golang.org/protobuf/proto" @@ -30,9 +36,11 @@ func registerSchemaCmd(rootCmd *cobra.Command) { schemaCmd.AddCommand(schemaWriteCmd) schemaWriteCmd.Flags().Bool("json", false, "output as JSON") + schemaWriteCmd.Flags().String("schema-definition-prefix", "", "prefix to add to the schema's definition(s) before writing") schemaCmd.AddCommand(schemaCopyCmd) schemaCopyCmd.Flags().Bool("json", false, "output as JSON") + schemaCopyCmd.Flags().String("schema-definition-prefix", "", "prefix to add to the schema's definition(s) before writing") } var ( @@ -147,7 +155,17 @@ func schemaWriteCmdFunc(cmd *cobra.Command, args []string) error { return errors.New("attempted to write empty schema") } - request := &v1.WriteSchemaRequest{Schema: string(schemaBytes)} + prefix, err := determinePrefixForSchema(cobrautil.MustGetString(cmd, "schema-definition-prefix"), client, nil) + if err != nil { + return err + } + + schemaText, err := rewriteSchema(string(schemaBytes), prefix) + if err != nil { + return err + } + + request := &v1.WriteSchemaRequest{Schema: schemaText} log.Trace().Interface("request", request).Msg("writing schema") resp, err := client.WriteSchema(context.Background(), request) @@ -220,7 +238,17 @@ func schemaCopyCmdFunc(cmd *cobra.Command, args []string) error { } log.Trace().Interface("response", readResp).Msg("read schema") - writeRequest := &v1.WriteSchemaRequest{Schema: readResp.SchemaText} + prefix, err := determinePrefixForSchema(cobrautil.MustGetString(cmd, "schema-definition-prefix"), nil, &readResp.SchemaText) + if err != nil { + return err + } + + schemaText, err := rewriteSchema(readResp.SchemaText, prefix) + if err != nil { + return err + } + + writeRequest := &v1.WriteSchemaRequest{Schema: schemaText} log.Trace().Interface("request", writeRequest).Msg("writing schema") resp, err := destClient.WriteSchema(context.Background(), writeRequest) @@ -241,3 +269,99 @@ func schemaCopyCmdFunc(cmd *cobra.Command, args []string) error { return nil } + +// rewriteSchema rewrites the given existing schema to include the specified prefix on all definitions. +func rewriteSchema(existingSchemaText string, definitionPrefix string) (string, error) { + nsDefs, err := compiler.Compile([]compiler.InputSchema{ + {Source: input.Source("schema"), SchemaString: existingSchemaText}, + }, &definitionPrefix) + if err != nil { + return "", err + } + + objectDefs := make([]string, 0, len(nsDefs)) + for _, nsDef := range nsDefs { + objectDef, _ := generator.GenerateSource(nsDef) + objectDefs = append(objectDefs, objectDef) + } + + return strings.Join(objectDefs, "\n\n"), nil +} + +// readSchema calls read schema for the client and returns the schema found. +func readSchema(client *authzed.Client) (string, error) { + request := &v1.ReadSchemaRequest{} + log.Trace().Interface("request", request).Msg("requesting schema read") + + resp, err := client.ReadSchema(context.Background(), request) + if err != nil { + errStatus, ok := status.FromError(err) + if !ok || errStatus.Code() != codes.NotFound { + return "", err + } + + log.Debug().Msg("no schema defined") + return "", nil + } + + return resp.SchemaText, nil +} + +// determinePrefixForSchema determines the prefix to be applied to a schema that will be written. +// +// If specifiedPrefix is non-empty, it is returned immediately. +// If existingSchema is non-nil, it is parsed for the prefix. +// Otherwise, the client is used to retrieve the existing schema (if any), and the prefix is retrieved from there. +func determinePrefixForSchema(specifiedPrefix string, client *authzed.Client, existingSchema *string) (string, error) { + if specifiedPrefix != "" { + return specifiedPrefix, nil + } + + var schemaText string + if existingSchema != nil { + schemaText = *existingSchema + } else { + readSchemaText, err := readSchema(client) + if err != nil { + return "", nil + } + schemaText = readSchemaText + } + + // If there is no schema found, return the empty string. + if schemaText == "" { + return "", nil + } + + // Otherwise, compile the schema and grab the prefixes of the namespaces defined. + empty := "" + found, err := compiler.Compile([]compiler.InputSchema{ + {Source: input.Source("schema"), SchemaString: schemaText}, + }, &empty) + if err != nil { + return "", err + } + + foundPrefixes := make([]string, 0, len(found)) + for _, def := range found { + if strings.Contains(def.Name, "/") { + parts := strings.Split(def.Name, "/") + foundPrefixes = append(foundPrefixes, parts[0]) + } else { + foundPrefixes = append(foundPrefixes, "") + } + } + + prefixes := stringz.Dedup(foundPrefixes) + if len(prefixes) == 0 { + return "", fmt.Errorf("found no schema definition prefixes") + } + + if len(prefixes) > 1 { + return "", fmt.Errorf("found multiple schema definition prefixes: %v", prefixes) + } + + prefix := prefixes[0] + log.Debug().Str("prefix", prefix).Msg("found schema definition prefix") + return prefix, nil +}