From 463a07a6969c07e52ed1f67c8a5e8aea1f5c2900 Mon Sep 17 00:00:00 2001 From: Kyle Conroy Date: Sat, 4 Apr 2020 19:15:56 -0700 Subject: [PATCH] Implement ALTER TYPE RENAME VALUE / ADD VALUE (#433) * postgresql: Implement ALTER TYPE RENAME/ADD VALUE * catalog: Implement ALTER TYPE RENAME/ADD VALUE This is a backport of the implementation from sql/catalog. --- internal/catalog/build.go | 50 +++++++++++++++++++ internal/catalog/build_test.go | 58 ++++++++++++++++++++++ internal/postgresql/catalog_test.go | 67 +++++++++++++++++++++++++- internal/postgresql/parse.go | 19 ++++++++ internal/sql/ast/ast.go | 20 ++++++++ internal/sql/catalog/catalog.go | 4 ++ internal/sql/catalog/types.go | 74 +++++++++++++++++++++++++++++ 7 files changed, 290 insertions(+), 2 deletions(-) diff --git a/internal/catalog/build.go b/internal/catalog/build.go index 62359c1fc9..ab03eab7a6 100644 --- a/internal/catalog/build.go +++ b/internal/catalog/build.go @@ -99,6 +99,56 @@ func Update(c *pg.Catalog, stmt nodes.Node) error { switch n := raw.Stmt.(type) { + case nodes.AlterEnumStmt: + fqn, err := ParseList(n.TypeName) + if err != nil { + return err + } + schema, exists := c.Schemas[fqn.Schema] + if !exists { + return wrap(pg.ErrorSchemaDoesNotExist(fqn.Schema), raw.StmtLocation) + } + typ, exists := schema.Types[fqn.Rel] + if !exists { + return wrap(pg.ErrorRelationDoesNotExist(fqn.Rel), raw.StmtLocation) + } + enum, ok := typ.(pg.Enum) + if !ok { + return wrap(pg.ErrorRelationDoesNotExist(fqn.Rel), raw.StmtLocation) + } + oldIndex := -1 + newIndex := -1 + for i, val := range enum.Vals { + if n.OldVal != nil && val == *n.OldVal { + oldIndex = i + } + if n.NewVal != nil && val == *n.NewVal { + newIndex = i + } + } + if n.OldVal != nil { + // RENAME TYPE + if oldIndex < 0 { + return fmt.Errorf("type %s does not have value %s", fqn.Rel, *n.OldVal) + } + if newIndex >= 0 { + return fmt.Errorf("type %s already has value %s", fqn.Rel, *n.NewVal) + } + enum.Vals[oldIndex] = *n.NewVal + schema.Types[fqn.Rel] = enum + } else { + // ADD VALUE + if newIndex >= 0 { + if !n.SkipIfNewValExists { + return fmt.Errorf("type %s already has value %s", fqn.Rel, *n.NewVal) + } else { + return nil + } + } + enum.Vals = append(enum.Vals, *n.NewVal) + schema.Types[fqn.Rel] = enum + } + case nodes.AlterObjectSchemaStmt: switch n.ObjectType { diff --git a/internal/catalog/build_test.go b/internal/catalog/build_test.go index f2df039d7f..da3bd628f0 100644 --- a/internal/catalog/build_test.go +++ b/internal/catalog/build_test.go @@ -45,6 +45,43 @@ func TestUpdate(t *testing.T) { }, }, }, + { + ` + CREATE TYPE status AS ENUM ('open', 'closed'); + ALTER TYPE status RENAME VALUE 'closed' TO 'shut'; + `, + pg.Catalog{ + Schemas: map[string]pg.Schema{ + "public": { + Types: map[string]pg.Type{ + "status": pg.Enum{ + Name: "status", + Vals: []string{"open", "shut"}, + }, + }, + }, + }, + }, + }, + { + ` + CREATE TYPE status AS ENUM ('open', 'closed'); + ALTER TYPE status ADD VALUE 'unknown'; + ALTER TYPE status ADD VALUE IF NOT EXISTS 'unknown'; + `, + pg.Catalog{ + Schemas: map[string]pg.Schema{ + "public": { + Types: map[string]pg.Type{ + "status": pg.Enum{ + Name: "status", + Vals: []string{"open", "closed", "unknown"}, + }, + }, + }, + }, + }, + }, { "CREATE TABLE venues ();", pg.Catalog{ @@ -240,6 +277,27 @@ func TestUpdate(t *testing.T) { }, }, }, + /* + { + ` + CREATE TYPE status AS ENUM ('open', 'closed'); + ALTER TYPE status RENAME VALUE 'closed' TO 'shut'; + `, + pg.Catalog{ + Schemas: map[string]pg.Schema{ + "public": { + Types: map[string]pg.Type{ + "status": pg.Enum{ + Name: "status", + Vals: []string{"open", "shut"}, + }, + }, + Tables: map[string]pg.Table{}, + }, + }, + }, + }, + */ { ` CREATE TABLE venues (); diff --git a/internal/postgresql/catalog_test.go b/internal/postgresql/catalog_test.go index 61619794fa..abc1a26477 100644 --- a/internal/postgresql/catalog_test.go +++ b/internal/postgresql/catalog_test.go @@ -17,11 +17,13 @@ import ( func TestUpdate(t *testing.T) { p := NewParser() - for i, tc := range []struct { + for _, tc := range []struct { + name string stmt string s *catalog.Schema }{ { + "create-enum", "CREATE TYPE status AS ENUM ('open', 'closed');", &catalog.Schema{ Name: "public", @@ -34,6 +36,40 @@ func TestUpdate(t *testing.T) { }, }, { + "alter-type-rename-value", + ` + CREATE TYPE status AS ENUM ('open', 'closed'); + ALTER TYPE status RENAME VALUE 'closed' TO 'shut'; + `, + &catalog.Schema{ + Name: "public", + Types: []catalog.Type{ + &catalog.Enum{ + Name: "status", + Vals: []string{"open", "shut"}, + }, + }, + }, + }, + { + "alter-type-add-value", + ` + CREATE TYPE status AS ENUM ('open', 'closed'); + ALTER TYPE status ADD VALUE 'unknown'; + ALTER TYPE status ADD VALUE IF NOT EXISTS 'unknown'; + `, + &catalog.Schema{ + Name: "public", + Types: []catalog.Type{ + &catalog.Enum{ + Name: "status", + Vals: []string{"open", "closed", "unknown"}, + }, + }, + }, + }, + { + "create-table", "CREATE TABLE venues ();", &catalog.Schema{ Name: "public", @@ -45,6 +81,7 @@ func TestUpdate(t *testing.T) { }, }, { + "alter-table-drop-column", ` CREATE TABLE foo (); ALTER TABLE foo ADD COLUMN bar text; @@ -60,6 +97,7 @@ func TestUpdate(t *testing.T) { }, }, { + "alter-table-drop-column-if-exists", ` CREATE TABLE foo (); ALTER TABLE foo DROP COLUMN IF EXISTS bar; @@ -74,6 +112,7 @@ func TestUpdate(t *testing.T) { }, }, { + "alter-table-set-not-null", ` CREATE TABLE foo (bar text); ALTER TABLE foo ALTER bar SET NOT NULL; @@ -116,6 +155,7 @@ func TestUpdate(t *testing.T) { }, */ { + "alter-table-drop-not-null", ` CREATE TABLE foo (bar text NOT NULL); ALTER TABLE foo ALTER bar DROP NOT NULL; @@ -136,6 +176,7 @@ func TestUpdate(t *testing.T) { }, }, { + "alter-table-column-drop-not-null", ` CREATE TABLE foo (bar text NOT NULL); ALTER TABLE foo ALTER COLUMN bar DROP NOT NULL; @@ -156,6 +197,7 @@ func TestUpdate(t *testing.T) { }, }, { + "alter-table-rename-column", ` CREATE TABLE foo (bar text); ALTER TABLE foo RENAME bar TO baz; @@ -176,6 +218,7 @@ func TestUpdate(t *testing.T) { }, }, { + "alter-table-set-data-type", ` CREATE TABLE foo (bar text); ALTER TABLE foo ALTER bar SET DATA TYPE bool; @@ -196,6 +239,7 @@ func TestUpdate(t *testing.T) { }, }, { + "alter-table-set-schema", ` CREATE SCHEMA foo; CREATE TABLE bar (); @@ -211,6 +255,7 @@ func TestUpdate(t *testing.T) { }, }, { + "drop-table", ` CREATE TABLE venues (); DROP TABLE venues; @@ -218,6 +263,7 @@ func TestUpdate(t *testing.T) { nil, }, { + "drop-table-if-exists", ` CREATE TABLE venues (); DROP TABLE IF EXISTS venues; @@ -226,6 +272,7 @@ func TestUpdate(t *testing.T) { nil, }, { + "alter-table-rename", ` CREATE TABLE venues (); ALTER TABLE venues RENAME TO arenas; @@ -240,6 +287,7 @@ func TestUpdate(t *testing.T) { }, }, { + "drop-type", ` CREATE TYPE status AS ENUM ('open', 'closed'); DROP TYPE status; @@ -247,6 +295,7 @@ func TestUpdate(t *testing.T) { nil, }, { + "drop-type-if-exists", ` CREATE TYPE status AS ENUM ('open', 'closed'); DROP TYPE IF EXISTS status; @@ -255,6 +304,7 @@ func TestUpdate(t *testing.T) { nil, }, { + "drop-table-in-schema", ` CREATE TABLE venues (); DROP TABLE public.venues; @@ -262,6 +312,7 @@ func TestUpdate(t *testing.T) { nil, }, { + "drop-type-in-schema", ` CREATE TYPE status AS ENUM ('open', 'closed'); DROP TYPE public.status; @@ -269,6 +320,7 @@ func TestUpdate(t *testing.T) { nil, }, { + "drop-schema", ` CREATE SCHEMA foo; DROP SCHEMA foo; @@ -276,12 +328,14 @@ func TestUpdate(t *testing.T) { nil, }, { + "drop-schema-if-exists", ` DROP SCHEMA IF EXISTS foo; `, nil, }, { + "drop-function-if-exists", ` DROP FUNCTION IF EXISTS bar; DROP FUNCTION IF EXISTS bar(); @@ -289,6 +343,7 @@ func TestUpdate(t *testing.T) { nil, }, { + "alter-table-drop-constraint", ` CREATE TABLE venues (id SERIAL PRIMARY KEY); ALTER TABLE venues DROP CONSTRAINT venues_id_pkey; @@ -310,6 +365,7 @@ func TestUpdate(t *testing.T) { }, }, { + "create-function", ` CREATE FUNCTION foo(TEXT) RETURNS bool AS $$ SELECT true $$ LANGUAGE sql; `, @@ -329,6 +385,7 @@ func TestUpdate(t *testing.T) { }, }, { + "create-function-args", ` CREATE FUNCTION foo(bar TEXT) RETURNS bool AS $$ SELECT true $$ LANGUAGE sql; CREATE FUNCTION foo(bar TEXT, baz TEXT) RETURNS TEXT AS $$ SELECT "baz" $$ LANGUAGE sql; @@ -364,6 +421,7 @@ func TestUpdate(t *testing.T) { }, }, { + "create-function-types", ` CREATE FUNCTION foo(bar TEXT) RETURNS bool AS $$ SELECT true $$ LANGUAGE sql; CREATE FUNCTION foo(bar INTEGER) RETURNS TEXT AS $$ SELECT "baz" $$ LANGUAGE sql; @@ -395,6 +453,7 @@ func TestUpdate(t *testing.T) { }, }, { + "create-function-return", ` CREATE FUNCTION foo(bar TEXT, baz TEXT="baz") RETURNS bool AS $$ SELECT true $$ LANGUAGE sql; `, @@ -420,6 +479,7 @@ func TestUpdate(t *testing.T) { }, }, { + "drop-function-args", ` CREATE FUNCTION foo(bar text) RETURNS bool AS $$ SELECT true $$ LANGUAGE sql; DROP FUNCTION foo(text); @@ -427,6 +487,7 @@ func TestUpdate(t *testing.T) { nil, }, { + "drop-function", ` CREATE FUNCTION foo(bar text) RETURNS bool AS $$ SELECT true $$ LANGUAGE sql; DROP FUNCTION foo; @@ -437,6 +498,7 @@ func TestUpdate(t *testing.T) { // CREATE FUNCTION foo() RETURNS bool AS $$ SELECT true $$ LANGUAGE sql; // DROP FUNCTION foo -- FAIL { + "pg_temp", ` CREATE TABLE pg_temp.migrate (val SERIAL); INSERT INTO pg_temp.migrate (val) SELECT val FROM old; @@ -458,6 +520,7 @@ func TestUpdate(t *testing.T) { }, }, { + "comment", ` CREATE SCHEMA foo; CREATE TABLE foo.bar (baz text); @@ -494,7 +557,7 @@ func TestUpdate(t *testing.T) { }, } { test := tc - t.Run(strconv.Itoa(i), func(t *testing.T) { + t.Run(test.name, func(t *testing.T) { stmts, err := p.Parse(strings.NewReader(test.stmt)) if err != nil { t.Log(test.stmt) diff --git a/internal/postgresql/parse.go b/internal/postgresql/parse.go index 89951be550..9502217cfe 100644 --- a/internal/postgresql/parse.go +++ b/internal/postgresql/parse.go @@ -188,6 +188,25 @@ func (p *Parser) Parse(r io.Reader) ([]ast.Statement, error) { func translate(node nodes.Node) (ast.Node, error) { switch n := node.(type) { + case nodes.AlterEnumStmt: + name, err := parseTypeName(n.TypeName) + if err != nil { + return nil, err + } + if n.OldVal != nil { + return &ast.AlterTypeRenameValueStmt{ + Type: name, + OldValue: n.OldVal, + NewValue: n.NewVal, + }, nil + } else { + return &ast.AlterTypeAddValueStmt{ + Type: name, + NewValue: n.NewVal, + SkipIfNewValExists: n.SkipIfNewValExists, + }, nil + } + case nodes.AlterObjectSchemaStmt: switch n.ObjectType { diff --git a/internal/sql/ast/ast.go b/internal/sql/ast/ast.go index 3262fc49c4..c8562e1f84 100644 --- a/internal/sql/ast/ast.go +++ b/internal/sql/ast/ast.go @@ -230,6 +230,26 @@ func (n *RenameColumnStmt) Pos() int { return 0 } +type AlterTypeRenameValueStmt struct { + Type *TypeName + OldValue *string + NewValue *string +} + +func (n *AlterTypeRenameValueStmt) Pos() int { + return 0 +} + +type AlterTypeAddValueStmt struct { + Type *TypeName + NewValue *string + SkipIfNewValExists bool +} + +func (n *AlterTypeAddValueStmt) Pos() int { + return 0 +} + type RenameTableStmt struct { Table *TableName NewName *string diff --git a/internal/sql/catalog/catalog.go b/internal/sql/catalog/catalog.go index 38b839b89a..aaac0f43b4 100644 --- a/internal/sql/catalog/catalog.go +++ b/internal/sql/catalog/catalog.go @@ -226,6 +226,10 @@ func (c *Catalog) Build(stmts []ast.Statement) error { err = c.alterTable(n) case *ast.AlterTableSetSchemaStmt: err = c.alterTableSetSchema(n) + case *ast.AlterTypeAddValueStmt: + err = c.alterTypeAddValue(n) + case *ast.AlterTypeRenameValueStmt: + err = c.alterTypeRenameValue(n) case *ast.CommentOnColumnStmt: err = c.commentOnColumn(n) case *ast.CommentOnSchemaStmt: diff --git a/internal/sql/catalog/types.go b/internal/sql/catalog/types.go index 2f5be4d9b8..359a64768c 100644 --- a/internal/sql/catalog/types.go +++ b/internal/sql/catalog/types.go @@ -2,6 +2,7 @@ package catalog import ( "errors" + "fmt" "github.com/kyleconroy/sqlc/internal/sql/ast" "github.com/kyleconroy/sqlc/internal/sql/sqlerr" @@ -36,6 +37,79 @@ func (c *Catalog) createEnum(stmt *ast.CreateEnumStmt) error { return nil } +func (c *Catalog) alterTypeRenameValue(stmt *ast.AlterTypeRenameValueStmt) error { + ns := stmt.Type.Schema + if ns == "" { + ns = c.DefaultSchema + } + schema, err := c.getSchema(ns) + if err != nil { + return err + } + typ, _, err := schema.getType(stmt.Type) + if err != nil { + return err + } + enum, ok := typ.(*Enum) + if !ok { + return fmt.Errorf("type is not an enum: %s", stmt.Type) + } + + oldIndex := -1 + newIndex := -1 + for i, val := range enum.Vals { + if val == *stmt.OldValue { + oldIndex = i + } + if val == *stmt.NewValue { + newIndex = i + } + } + if oldIndex < 0 { + return fmt.Errorf("type %s does not have value %s", stmt.Type, *stmt.OldValue) + } + if newIndex >= 0 { + return fmt.Errorf("type %s already has value %s", stmt.Type, *stmt.NewValue) + } + enum.Vals[oldIndex] = *stmt.NewValue + return nil +} + +func (c *Catalog) alterTypeAddValue(stmt *ast.AlterTypeAddValueStmt) error { + ns := stmt.Type.Schema + if ns == "" { + ns = c.DefaultSchema + } + schema, err := c.getSchema(ns) + if err != nil { + return err + } + typ, _, err := schema.getType(stmt.Type) + if err != nil { + return err + } + enum, ok := typ.(*Enum) + if !ok { + return fmt.Errorf("type is not an enum: %s", stmt.Type) + } + + newIndex := -1 + for i, val := range enum.Vals { + if val == *stmt.NewValue { + newIndex = i + } + } + if newIndex >= 0 { + if !stmt.SkipIfNewValExists { + return fmt.Errorf("type %s already has value %s", stmt.Type, *stmt.NewValue) + } else { + return nil + } + } + enum.Vals = append(enum.Vals, *stmt.NewValue) + return nil +} + func (c *Catalog) dropType(stmt *ast.DropTypeStmt) error { for _, name := range stmt.Types { ns := name.Schema