From e8b6eabc46ac9cb491b110c594dacc4c93987083 Mon Sep 17 00:00:00 2001 From: Kyle Conroy Date: Wed, 29 Jul 2020 10:17:26 -0700 Subject: [PATCH] compiler: Support calling functions with defaults Fix a number of bugs when resolving function calls. Switch to using functions generated from a default PostgreSQL instance. Add a new test case ripped from the PostgreSQL docs. --- internal/codegen/golang/postgresql_type.go | 14 +-- internal/compiler/output_columns.go | 2 +- internal/compiler/resolve.go | 45 +++++++--- .../testdata/func_args/go/query.sql.go | 16 ++-- .../testdata/generate_series/go/query.sql.go | 6 +- .../pg_advisory_xact_lock/go/query.sql.go | 4 +- .../sql_syntax_calling_funcs/go/db.go | 29 ++++++ .../sql_syntax_calling_funcs/go/models.go | 5 ++ .../sql_syntax_calling_funcs/go/query.sql.go | 74 ++++++++++++++++ .../sql_syntax_calling_funcs/query.sql | 29 ++++++ .../sql_syntax_calling_funcs/sqlc.json | 12 +++ internal/engine/postgresql/catalog.go | 2 +- internal/sql/catalog/public.go | 88 ++++++++++++++++--- internal/sql/sqlerr/errors.go | 8 ++ internal/sql/validate/func_call.go | 31 +------ 15 files changed, 290 insertions(+), 75 deletions(-) create mode 100644 internal/endtoend/testdata/sql_syntax_calling_funcs/go/db.go create mode 100644 internal/endtoend/testdata/sql_syntax_calling_funcs/go/models.go create mode 100644 internal/endtoend/testdata/sql_syntax_calling_funcs/go/query.sql.go create mode 100644 internal/endtoend/testdata/sql_syntax_calling_funcs/query.sql create mode 100644 internal/endtoend/testdata/sql_syntax_calling_funcs/sqlc.json diff --git a/internal/codegen/golang/postgresql_type.go b/internal/codegen/golang/postgresql_type.go index efb028fab2..5124d5cf8c 100644 --- a/internal/codegen/golang/postgresql_type.go +++ b/internal/codegen/golang/postgresql_type.go @@ -13,19 +13,19 @@ func postgresType(r *compiler.Result, col *compiler.Column, settings config.Comb notNull := col.NotNull || col.IsArray switch columnType { - case "serial", "pg_catalog.serial4": + case "serial", "serial4", "pg_catalog.serial4": if notNull { return "int32" } return "sql.NullInt32" - case "bigserial", "pg_catalog.serial8": + case "bigserial", "serial8", "pg_catalog.serial8": if notNull { return "int64" } return "sql.NullInt64" - case "smallserial", "pg_catalog.serial2": + case "smallserial", "serial2", "pg_catalog.serial2": return "int16" case "integer", "int", "int4", "pg_catalog.int4": @@ -43,19 +43,19 @@ func postgresType(r *compiler.Result, col *compiler.Column, settings config.Comb case "smallint", "int2", "pg_catalog.int2": return "int16" - case "float", "double precision", "pg_catalog.float8": + case "float", "double precision", "float8", "pg_catalog.float8": if notNull { return "float64" } return "sql.NullFloat64" - case "real", "pg_catalog.float4": + case "real", "float4", "pg_catalog.float4": if notNull { return "float32" } return "sql.NullFloat64" // TODO: Change to sql.NullFloat32 after updating the go.mod file - case "pg_catalog.numeric", "money": + case "numeric", "pg_catalog.numeric", "money": // Since the Go standard library does not have a decimal type, lib/pq // returns numerics as strings. // @@ -121,7 +121,7 @@ func postgresType(r *compiler.Result, col *compiler.Column, settings config.Comb } return "sql.NullString" - case "pg_catalog.interval": + case "interval", "pg_catalog.interval": if notNull { return "int64" } diff --git a/internal/compiler/output_columns.go b/internal/compiler/output_columns.go index 09164eead3..5b20fc9af8 100644 --- a/internal/compiler/output_columns.go +++ b/internal/compiler/output_columns.go @@ -149,7 +149,7 @@ func outputColumns(qc *QueryCatalog, node ast.Node) ([]*Column, error) { if res.Name != nil { name = *res.Name } - fun, err := qc.catalog.GetFuncN(rel, len(n.Args.Items)) + fun, err := qc.catalog.ResolveFuncCall(n) if err == nil { cols = append(cols, &Column{Name: name, DataType: dataType(fun.ReturnType), NotNull: true}) } else { diff --git a/internal/compiler/resolve.go b/internal/compiler/resolve.go index b492ed179b..bc265e3b38 100644 --- a/internal/compiler/resolve.go +++ b/internal/compiler/resolve.go @@ -170,24 +170,25 @@ func resolveCatalogRefs(c *catalog.Catalog, rvs []*pg.RangeVar, args []paramRef, } case *ast.FuncCall: - fun, err := c.GetFuncN(n.Func, len(n.Args.Items)) + fun, err := c.ResolveFuncCall(n) if err != nil { + // Synthesize a function on the fly to avoid returning with an error + // for an unknown Postgres function (e.g. defined in an extension) var args []*catalog.Argument for range n.Args.Items { args = append(args, &catalog.Argument{ Type: &ast.TypeName{Name: "any"}, }) } - // Synthesize a function on the fly to avoid returning with an error - // for an unknown Postgres function (e.g. defined in an extension) - fun = catalog.Function{ + fun = &catalog.Function{ Name: n.Func.Name, Args: args, ReturnType: &ast.TypeName{Name: "any"}, } } for i, item := range n.Args.Items { - defaultName := fun.Name + funcName := fun.Name + var argName string switch inode := item.(type) { case *pg.ParamRef: if inode.Number != ref.ref.Number { @@ -210,11 +211,15 @@ func resolveCatalogRefs(c *catalog.Catalog, rvs []*pg.RangeVar, args []paramRef, continue } if inode.Name != nil { - defaultName = *inode.Name + argName = *inode.Name } } if fun.Args == nil { + defaultName := funcName + if argName != "" { + defaultName = argName + } a = append(a, Parameter{ Number: ref.ref.Number, Column: &Column{ @@ -225,19 +230,31 @@ func resolveCatalogRefs(c *catalog.Catalog, rvs []*pg.RangeVar, args []paramRef, continue } - if i >= len(fun.Args) { - return nil, fmt.Errorf("incorrect number of arguments to %s", fun.Name) + var paramName string + var paramType *ast.TypeName + if argName == "" { + paramName = fun.Args[i].Name + paramType = fun.Args[i].Type + } else { + paramName = argName + for _, arg := range fun.Args { + if arg.Name == argName { + paramType = arg.Type + } + } + if paramType == nil { + panic(fmt.Sprintf("named argument %s has no type", paramName)) + } } - arg := fun.Args[i] - name := arg.Name - if name == "" { - name = defaultName + if paramName == "" { + paramName = funcName } + a = append(a, Parameter{ Number: ref.ref.Number, Column: &Column{ - Name: parameterName(ref.ref.Number, name), - DataType: dataType(arg.Type), + Name: parameterName(ref.ref.Number, paramName), + DataType: dataType(paramType), NotNull: true, }, }) diff --git a/internal/endtoend/testdata/func_args/go/query.sql.go b/internal/endtoend/testdata/func_args/go/query.sql.go index aabb4c6bab..a2e2e83f51 100644 --- a/internal/endtoend/testdata/func_args/go/query.sql.go +++ b/internal/endtoend/testdata/func_args/go/query.sql.go @@ -11,9 +11,9 @@ const makeIntervalDays = `-- name: MakeIntervalDays :one SELECT make_interval(days => $1::int) ` -func (q *Queries) MakeIntervalDays(ctx context.Context, dollar_1 int32) (interface{}, error) { +func (q *Queries) MakeIntervalDays(ctx context.Context, dollar_1 int32) (int64, error) { row := q.db.QueryRowContext(ctx, makeIntervalDays, dollar_1) - var make_interval interface{} + var make_interval int64 err := row.Scan(&make_interval) return make_interval, err } @@ -22,9 +22,9 @@ const makeIntervalMonths = `-- name: MakeIntervalMonths :one SELECT make_interval(months => $1::int) ` -func (q *Queries) MakeIntervalMonths(ctx context.Context, months int32) (interface{}, error) { +func (q *Queries) MakeIntervalMonths(ctx context.Context, months int32) (int64, error) { row := q.db.QueryRowContext(ctx, makeIntervalMonths, months) - var make_interval interface{} + var make_interval int64 err := row.Scan(&make_interval) return make_interval, err } @@ -33,9 +33,9 @@ const makeIntervalSecs = `-- name: MakeIntervalSecs :one SELECT make_interval(secs => $1) ` -func (q *Queries) MakeIntervalSecs(ctx context.Context, secs interface{}) (interface{}, error) { +func (q *Queries) MakeIntervalSecs(ctx context.Context, secs float64) (int64, error) { row := q.db.QueryRowContext(ctx, makeIntervalSecs, secs) - var make_interval interface{} + var make_interval int64 err := row.Scan(&make_interval) return make_interval, err } @@ -45,12 +45,12 @@ SELECT plus(b => $2, a => $1) ` type PlusParams struct { - B int32 A int32 + B int32 } func (q *Queries) Plus(ctx context.Context, arg PlusParams) (int32, error) { - row := q.db.QueryRowContext(ctx, plus, arg.B, arg.A) + row := q.db.QueryRowContext(ctx, plus, arg.A, arg.B) var plus int32 err := row.Scan(&plus) return plus, err diff --git a/internal/endtoend/testdata/generate_series/go/query.sql.go b/internal/endtoend/testdata/generate_series/go/query.sql.go index 2a56e1a056..63d113dc25 100644 --- a/internal/endtoend/testdata/generate_series/go/query.sql.go +++ b/internal/endtoend/testdata/generate_series/go/query.sql.go @@ -17,15 +17,15 @@ type GenerateSeriesParams struct { Column2 time.Time `json:"column_2"` } -func (q *Queries) GenerateSeries(ctx context.Context, arg GenerateSeriesParams) ([]interface{}, error) { +func (q *Queries) GenerateSeries(ctx context.Context, arg GenerateSeriesParams) ([]string, error) { rows, err := q.db.QueryContext(ctx, generateSeries, arg.Column1, arg.Column2) if err != nil { return nil, err } defer rows.Close() - var items []interface{} + var items []string for rows.Next() { - var generate_series interface{} + var generate_series string if err := rows.Scan(&generate_series); err != nil { return nil, err } diff --git a/internal/endtoend/testdata/pg_advisory_xact_lock/go/query.sql.go b/internal/endtoend/testdata/pg_advisory_xact_lock/go/query.sql.go index a7e4cfb7ae..9959c4e03d 100644 --- a/internal/endtoend/testdata/pg_advisory_xact_lock/go/query.sql.go +++ b/internal/endtoend/testdata/pg_advisory_xact_lock/go/query.sql.go @@ -11,8 +11,8 @@ const advisoryLock = `-- name: AdvisoryLock :many SELECT pg_advisory_unlock($1) ` -func (q *Queries) AdvisoryLock(ctx context.Context, key int64) ([]bool, error) { - rows, err := q.db.QueryContext(ctx, advisoryLock, key) +func (q *Queries) AdvisoryLock(ctx context.Context, pgAdvisoryUnlock int64) ([]bool, error) { + rows, err := q.db.QueryContext(ctx, advisoryLock, pgAdvisoryUnlock) if err != nil { return nil, err } diff --git a/internal/endtoend/testdata/sql_syntax_calling_funcs/go/db.go b/internal/endtoend/testdata/sql_syntax_calling_funcs/go/db.go new file mode 100644 index 0000000000..6a99519302 --- /dev/null +++ b/internal/endtoend/testdata/sql_syntax_calling_funcs/go/db.go @@ -0,0 +1,29 @@ +// Code generated by sqlc. DO NOT EDIT. + +package querytest + +import ( + "context" + "database/sql" +) + +type DBTX interface { + ExecContext(context.Context, string, ...interface{}) (sql.Result, error) + PrepareContext(context.Context, string) (*sql.Stmt, error) + QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) + QueryRowContext(context.Context, string, ...interface{}) *sql.Row +} + +func New(db DBTX) *Queries { + return &Queries{db: db} +} + +type Queries struct { + db DBTX +} + +func (q *Queries) WithTx(tx *sql.Tx) *Queries { + return &Queries{ + db: tx, + } +} diff --git a/internal/endtoend/testdata/sql_syntax_calling_funcs/go/models.go b/internal/endtoend/testdata/sql_syntax_calling_funcs/go/models.go new file mode 100644 index 0000000000..4e2b892600 --- /dev/null +++ b/internal/endtoend/testdata/sql_syntax_calling_funcs/go/models.go @@ -0,0 +1,5 @@ +// Code generated by sqlc. DO NOT EDIT. + +package querytest + +import () diff --git a/internal/endtoend/testdata/sql_syntax_calling_funcs/go/query.sql.go b/internal/endtoend/testdata/sql_syntax_calling_funcs/go/query.sql.go new file mode 100644 index 0000000000..2cc5285ddd --- /dev/null +++ b/internal/endtoend/testdata/sql_syntax_calling_funcs/go/query.sql.go @@ -0,0 +1,74 @@ +// Code generated by sqlc. DO NOT EDIT. +// source: query.sql + +package querytest + +import ( + "context" +) + +const mixedNotation = `-- name: MixedNotation :one +SELECT concat_lower_or_upper('Hello', 'World', uppercase => true) +` + +func (q *Queries) MixedNotation(ctx context.Context) (string, error) { + row := q.db.QueryRowContext(ctx, mixedNotation) + var concat_lower_or_upper string + err := row.Scan(&concat_lower_or_upper) + return concat_lower_or_upper, err +} + +const namedAnyOrder = `-- name: NamedAnyOrder :one +SELECT concat_lower_or_upper(a => 'Hello', b => 'World', uppercase => true) +` + +func (q *Queries) NamedAnyOrder(ctx context.Context) (string, error) { + row := q.db.QueryRowContext(ctx, namedAnyOrder) + var concat_lower_or_upper string + err := row.Scan(&concat_lower_or_upper) + return concat_lower_or_upper, err +} + +const namedNotation = `-- name: NamedNotation :one +SELECT concat_lower_or_upper(a => 'Hello', b => 'World') +` + +func (q *Queries) NamedNotation(ctx context.Context) (string, error) { + row := q.db.QueryRowContext(ctx, namedNotation) + var concat_lower_or_upper string + err := row.Scan(&concat_lower_or_upper) + return concat_lower_or_upper, err +} + +const namedOtherOrder = `-- name: NamedOtherOrder :one +SELECT concat_lower_or_upper(a => 'Hello', uppercase => true, b => 'World') +` + +func (q *Queries) NamedOtherOrder(ctx context.Context) (string, error) { + row := q.db.QueryRowContext(ctx, namedOtherOrder) + var concat_lower_or_upper string + err := row.Scan(&concat_lower_or_upper) + return concat_lower_or_upper, err +} + +const positionalNoDefaault = `-- name: PositionalNoDefaault :one +SELECT concat_lower_or_upper('Hello', 'World') +` + +func (q *Queries) PositionalNoDefaault(ctx context.Context) (string, error) { + row := q.db.QueryRowContext(ctx, positionalNoDefaault) + var concat_lower_or_upper string + err := row.Scan(&concat_lower_or_upper) + return concat_lower_or_upper, err +} + +const positionalNotation = `-- name: PositionalNotation :one +SELECT concat_lower_or_upper('Hello', 'World', true) +` + +func (q *Queries) PositionalNotation(ctx context.Context) (string, error) { + row := q.db.QueryRowContext(ctx, positionalNotation) + var concat_lower_or_upper string + err := row.Scan(&concat_lower_or_upper) + return concat_lower_or_upper, err +} diff --git a/internal/endtoend/testdata/sql_syntax_calling_funcs/query.sql b/internal/endtoend/testdata/sql_syntax_calling_funcs/query.sql new file mode 100644 index 0000000000..59a2c35138 --- /dev/null +++ b/internal/endtoend/testdata/sql_syntax_calling_funcs/query.sql @@ -0,0 +1,29 @@ +-- https://www.postgresql.org/docs/current/sql-syntax-calling-funcs.html +CREATE FUNCTION concat_lower_or_upper(a text, b text, uppercase boolean DEFAULT false) +RETURNS text +AS +$$ + SELECT CASE + WHEN $3 THEN UPPER($1 || ' ' || $2) + ELSE LOWER($1 || ' ' || $2) + END; +$$ +LANGUAGE SQL IMMUTABLE STRICT; + +-- name: PositionalNotation :one +SELECT concat_lower_or_upper('Hello', 'World', true); + +-- name: PositionalNoDefaault :one +SELECT concat_lower_or_upper('Hello', 'World'); + +-- name: NamedNotation :one +SELECT concat_lower_or_upper(a => 'Hello', b => 'World'); + +-- name: NamedAnyOrder :one +SELECT concat_lower_or_upper(a => 'Hello', b => 'World', uppercase => true); + +-- name: NamedOtherOrder :one +SELECT concat_lower_or_upper(a => 'Hello', uppercase => true, b => 'World'); + +-- name: MixedNotation :one +SELECT concat_lower_or_upper('Hello', 'World', uppercase => true); diff --git a/internal/endtoend/testdata/sql_syntax_calling_funcs/sqlc.json b/internal/endtoend/testdata/sql_syntax_calling_funcs/sqlc.json new file mode 100644 index 0000000000..de427d069f --- /dev/null +++ b/internal/endtoend/testdata/sql_syntax_calling_funcs/sqlc.json @@ -0,0 +1,12 @@ +{ + "version": "1", + "packages": [ + { + "engine": "postgresql", + "path": "go", + "name": "querytest", + "schema": "query.sql", + "queries": "query.sql" + } + ] +} diff --git a/internal/engine/postgresql/catalog.go b/internal/engine/postgresql/catalog.go index 0a4c5a01cd..e506cace14 100644 --- a/internal/engine/postgresql/catalog.go +++ b/internal/engine/postgresql/catalog.go @@ -5,7 +5,7 @@ import "github.com/kyleconroy/sqlc/internal/sql/catalog" func NewCatalog() *catalog.Catalog { c := catalog.New("public") c.Schemas = append(c.Schemas, pgTemp()) - c.Schemas = append(c.Schemas, pgCatalog()) + c.Schemas = append(c.Schemas, genPGCatalog()) c.SearchPath = []string{"pg_catalog"} return c } diff --git a/internal/sql/catalog/public.go b/internal/sql/catalog/public.go index 25f0a8be20..4d3d929fa9 100644 --- a/internal/sql/catalog/public.go +++ b/internal/sql/catalog/public.go @@ -1,7 +1,11 @@ package catalog import ( + "fmt" + "strings" + "github.com/kyleconroy/sqlc/internal/sql/ast" + "github.com/kyleconroy/sqlc/internal/sql/ast/pg" "github.com/kyleconroy/sqlc/internal/sql/sqlerr" ) @@ -28,23 +32,83 @@ func (c *Catalog) ListFuncsByName(rel *ast.FuncName) ([]Function, error) { return funcs, nil } -func (c *Catalog) GetFuncN(rel *ast.FuncName, n int) (Function, error) { - for _, ns := range c.schemasToSearch(rel.Schema) { - s, err := c.getSchema(ns) - if err != nil { - return Function{}, err +func (c *Catalog) ResolveFuncCall(call *ast.FuncCall) (*Function, error) { + // Do not validate unknown functions + funs, err := c.ListFuncsByName(call.Func) + if err != nil || len(funs) == 0 { + return nil, sqlerr.FunctionNotFound(call.Func.Name) + } + + // https://www.postgresql.org/docs/current/sql-syntax-calling-funcs.html + var positional []ast.Node + var named []*pg.NamedArgExpr + + if call.Args != nil { + for _, arg := range call.Args.Items { + if narg, ok := arg.(*pg.NamedArgExpr); ok { + named = append(named, narg) + } else { + // The mixed notation combines positional and named notation. + // However, as already mentioned, named arguments cannot precede + // positional arguments. + if len(named) > 0 { + return nil, &sqlerr.Error{ + Code: "", + Message: "positional argument cannot follow named argument", + Location: call.Pos(), + } + } + positional = append(positional, arg) + } } - for i := range s.Funcs { - if s.Funcs[i].Name != rel.Name { - continue + } + + for _, fun := range funs { + args := fun.InArgs() + var defaults int + known := map[string]struct{}{} + for _, arg := range args { + if arg.HasDefault { + defaults += 1 } - args := s.Funcs[i].InArgs() - if len(args) == n { - return *s.Funcs[i], nil + if arg.Name != "" { + known[arg.Name] = struct{}{} } } + if (len(named) + len(positional)) > len(args) { + continue + } + if (len(named) + len(positional)) < (len(args) - defaults) { + continue + } + + // Validate that the provided named arguments exist in the function + var unknownArgName bool + for _, expr := range named { + if expr.Name != nil { + if _, found := known[*expr.Name]; !found { + unknownArgName = true + } + } + } + if unknownArgName { + continue + } + + return &fun, nil + } + + var sig []string + for range call.Args.Items { + sig = append(sig, "unknown") + } + + return nil, &sqlerr.Error{ + Code: "42883", + Message: fmt.Sprintf("function %s(%s) does not exist", call.Func.Name, strings.Join(sig, ", ")), + Location: call.Pos(), + // Hint: "No function matches the given name and argument types. You might need to add explicit type casts.", } - return Function{}, sqlerr.RelationNotFound(rel.Name) } func (c *Catalog) GetTable(rel *ast.TableName) (Table, error) { diff --git a/internal/sql/sqlerr/errors.go b/internal/sql/sqlerr/errors.go index a7fe204442..13c38ef1ae 100644 --- a/internal/sql/sqlerr/errors.go +++ b/internal/sql/sqlerr/errors.go @@ -95,6 +95,14 @@ func TypeNotFound(typ string) *Error { } } +func FunctionNotFound(fun string) *Error { + return &Error{ + Err: NotFound, + Code: "42704", + Message: fmt.Sprintf("function \"%s\"", fun), + } +} + func FunctionNotUnique(fn string) *Error { return &Error{ Err: NotUnique, diff --git a/internal/sql/validate/func_call.go b/internal/sql/validate/func_call.go index 289babcdfd..06b0ddccd9 100644 --- a/internal/sql/validate/func_call.go +++ b/internal/sql/validate/func_call.go @@ -1,8 +1,7 @@ package validate import ( - "fmt" - "strings" + "errors" "github.com/kyleconroy/sqlc/internal/sql/ast" "github.com/kyleconroy/sqlc/internal/sql/astutils" @@ -28,34 +27,12 @@ func (v *funcCallVisitor) Visit(node ast.Node) astutils.Visitor { return v } - // Do not validate unknown functions - funs, err := v.catalog.ListFuncsByName(call.Func) - if err != nil || len(funs) == 0 { + fun, err := v.catalog.ResolveFuncCall(call) + if fun != nil || errors.Is(err, sqlerr.NotFound) { return v } - var args int - if call.Args != nil { - args = len(call.Args.Items) - } - for _, fun := range funs { - if len(fun.InArgs()) == args { - return v - } - } - - var sig []string - for range call.Args.Items { - sig = append(sig, "unknown") - } - - v.err = &sqlerr.Error{ - Code: "42883", - Message: fmt.Sprintf("function %s(%s) does not exist", call.Func.Name, strings.Join(sig, ", ")), - Location: call.Pos(), - // Hint: "No function matches the given name and argument types. You might need to add explicit type casts.", - } - + v.err = err return nil }