From 78af9f7ee72e9c008c26841c57b4dfd2ae4afe27 Mon Sep 17 00:00:00 2001 From: Michael Hudak Date: Thu, 23 Dec 2021 21:49:06 +0100 Subject: [PATCH] handling of dollar signs in function names by removing the char from codegen output adding a e2e case --- internal/codegen/golang/result.go | 6 +++- internal/compiler/output_columns.go | 13 +++++++-- internal/compiler/query.go | 1 + .../testdata/identifier_dollar_sign/db/db.go | 29 +++++++++++++++++++ .../identifier_dollar_sign/db/models.go | 5 ++++ .../identifier_dollar_sign/db/query.sql.go | 19 ++++++++++++ .../testdata/identifier_dollar_sign/query.sql | 4 +++ .../testdata/identifier_dollar_sign/sqlc.json | 11 +++++++ 8 files changed, 85 insertions(+), 3 deletions(-) create mode 100644 internal/endtoend/testdata/identifier_dollar_sign/db/db.go create mode 100644 internal/endtoend/testdata/identifier_dollar_sign/db/models.go create mode 100644 internal/endtoend/testdata/identifier_dollar_sign/db/query.sql.go create mode 100644 internal/endtoend/testdata/identifier_dollar_sign/query.sql create mode 100644 internal/endtoend/testdata/identifier_dollar_sign/sqlc.json diff --git a/internal/codegen/golang/result.go b/internal/codegen/golang/result.go index b213826eb9..e4f65443ba 100644 --- a/internal/codegen/golang/result.go +++ b/internal/codegen/golang/result.go @@ -193,8 +193,12 @@ func buildQueries(r *compiler.Result, settings config.CombinedSettings, structs if len(query.Columns) == 1 { c := query.Columns[0] + name := columnName(c, 0) + if c.IsFuncCall { + name = strings.Replace(name, "$", "_", -1) + } gq.Ret = QueryValue{ - Name: columnName(c, 0), + Name: name, Typ: goType(r, c, settings), SQLPackage: sqlpkg, } diff --git a/internal/compiler/output_columns.go b/internal/compiler/output_columns.go index 16a96c40a8..0b0e770025 100644 --- a/internal/compiler/output_columns.go +++ b/internal/compiler/output_columns.go @@ -210,9 +210,18 @@ func outputColumns(qc *QueryCatalog, node ast.Node) ([]*Column, error) { } fun, err := qc.catalog.ResolveFuncCall(n) if err == nil { - cols = append(cols, &Column{Name: name, DataType: dataType(fun.ReturnType), NotNull: !fun.ReturnTypeNullable}) + cols = append(cols, &Column{ + Name: name, + DataType: dataType(fun.ReturnType), + NotNull: !fun.ReturnTypeNullable, + IsFuncCall: true, + }) } else { - cols = append(cols, &Column{Name: name, DataType: "any"}) + cols = append(cols, &Column{ + Name: name, + DataType: "any", + IsFuncCall: true, + }) } case *ast.SubLink: diff --git a/internal/compiler/query.go b/internal/compiler/query.go index d2eb1d2fd7..46abd01e9b 100644 --- a/internal/compiler/query.go +++ b/internal/compiler/query.go @@ -22,6 +22,7 @@ type Column struct { Comment string Length *int IsNamedParam bool + IsFuncCall bool // XXX: Figure out what PostgreSQL calls `foo.id` Scope string diff --git a/internal/endtoend/testdata/identifier_dollar_sign/db/db.go b/internal/endtoend/testdata/identifier_dollar_sign/db/db.go new file mode 100644 index 0000000000..c3c034ae37 --- /dev/null +++ b/internal/endtoend/testdata/identifier_dollar_sign/db/db.go @@ -0,0 +1,29 @@ +// Code generated by sqlc. DO NOT EDIT. + +package db + +import ( + "context" + "database/sql" +) + +type DBTX interface { + ExecContext(context.Context, string, ...interface{}) (sql.Result, error) + PrepareContext(context.Context, string) (*sql.Stmt, error) + QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) + QueryRowContext(context.Context, string, ...interface{}) *sql.Row +} + +func New(db DBTX) *Queries { + return &Queries{db: db} +} + +type Queries struct { + db DBTX +} + +func (q *Queries) WithTx(tx *sql.Tx) *Queries { + return &Queries{ + db: tx, + } +} diff --git a/internal/endtoend/testdata/identifier_dollar_sign/db/models.go b/internal/endtoend/testdata/identifier_dollar_sign/db/models.go new file mode 100644 index 0000000000..7822a4a49b --- /dev/null +++ b/internal/endtoend/testdata/identifier_dollar_sign/db/models.go @@ -0,0 +1,5 @@ +// Code generated by sqlc. DO NOT EDIT. + +package db + +import () diff --git a/internal/endtoend/testdata/identifier_dollar_sign/db/query.sql.go b/internal/endtoend/testdata/identifier_dollar_sign/db/query.sql.go new file mode 100644 index 0000000000..0227595966 --- /dev/null +++ b/internal/endtoend/testdata/identifier_dollar_sign/db/query.sql.go @@ -0,0 +1,19 @@ +// Code generated by sqlc. DO NOT EDIT. +// source: query.sql + +package db + +import ( + "context" +) + +const fn = `-- name: Fn :one +SELECT f$n() +` + +func (q *Queries) Fn(ctx context.Context) (int32, error) { + row := q.db.QueryRowContext(ctx, fn) + var f_n int32 + err := row.Scan(&f_n) + return f_n, err +} diff --git a/internal/endtoend/testdata/identifier_dollar_sign/query.sql b/internal/endtoend/testdata/identifier_dollar_sign/query.sql new file mode 100644 index 0000000000..ee805d9459 --- /dev/null +++ b/internal/endtoend/testdata/identifier_dollar_sign/query.sql @@ -0,0 +1,4 @@ +CREATE FUNCTION f$n() RETURNS integer AS 'SELECT 1'; + +-- name: Fn :one +SELECT f$n(); diff --git a/internal/endtoend/testdata/identifier_dollar_sign/sqlc.json b/internal/endtoend/testdata/identifier_dollar_sign/sqlc.json new file mode 100644 index 0000000000..ff443fe1b9 --- /dev/null +++ b/internal/endtoend/testdata/identifier_dollar_sign/sqlc.json @@ -0,0 +1,11 @@ +{ + "version": "1", + "packages": [ + { + "path": "db", + "engine": "postgresql", + "schema": "query.sql", + "queries": "query.sql" + } + ] +}