From 41bf189e6b00f43b8a8d201ac7cef1f1368af47e Mon Sep 17 00:00:00 2001 From: Dan Sosedoff Date: Thu, 2 Feb 2023 16:13:14 -0600 Subject: [PATCH] Local queries (#641) * Read local queries from pgweb home directory * Refactor local query functionality * Allow picking local query in the query tab * WIP * Disable local query dropdown during execution * Only allow local queries running in a single session mode * Add middleware to enforce local query endpoint availability * Fix query check * Add query store tests * Make query store errors portable * Skip building specific tests on windows --- data/lc_example1.sql | 2 + data/lc_example2.sql | 5 ++ data/lc_invalid_meta.sql | 2 + data/lc_no_meta.sql | 1 + pkg/api/api.go | 80 +++++++++++++++++++ pkg/api/middleware.go | 11 +++ pkg/api/routes.go | 3 + pkg/api/types.go | 8 ++ pkg/cli/cli.go | 36 +++++++-- pkg/client/client.go | 41 ++++++++++ pkg/client/client_test.go | 10 +++ pkg/command/options.go | 18 ++++- pkg/queries/field.go | 43 ++++++++++ pkg/queries/field_test.go | 41 ++++++++++ pkg/queries/metadata.go | 148 +++++++++++++++++++++++++++++++++++ pkg/queries/metadata_test.go | 146 ++++++++++++++++++++++++++++++++++ pkg/queries/query.go | 23 ++++++ pkg/queries/query_test.go | 77 ++++++++++++++++++ pkg/queries/store.go | 88 +++++++++++++++++++++ pkg/queries/store_test.go | 71 +++++++++++++++++ static/css/app.css | 1 + static/index.html | 9 ++- static/js/app.js | 32 +++++++- 23 files changed, 884 insertions(+), 12 deletions(-) create mode 100644 data/lc_example1.sql create mode 100644 data/lc_example2.sql create mode 100644 data/lc_invalid_meta.sql create mode 100644 data/lc_no_meta.sql create mode 100644 pkg/api/types.go create mode 100644 pkg/queries/field.go create mode 100644 pkg/queries/field_test.go create mode 100644 pkg/queries/metadata.go create mode 100644 pkg/queries/metadata_test.go create mode 100644 pkg/queries/query.go create mode 100644 pkg/queries/query_test.go create mode 100644 pkg/queries/store.go create mode 100644 pkg/queries/store_test.go diff --git a/data/lc_example1.sql b/data/lc_example1.sql new file mode 100644 index 000000000..64f981f38 --- /dev/null +++ b/data/lc_example1.sql @@ -0,0 +1,2 @@ +-- pgweb: host="localhost" +select 'foo' diff --git a/data/lc_example2.sql b/data/lc_example2.sql new file mode 100644 index 000000000..92a16d99d --- /dev/null +++ b/data/lc_example2.sql @@ -0,0 +1,5 @@ +-- pgweb: host="localhost" +-- some comment +-- pgweb: user="foo" + +select 'foo' diff --git a/data/lc_invalid_meta.sql b/data/lc_invalid_meta.sql new file mode 100644 index 000000000..5bf1217f2 --- /dev/null +++ b/data/lc_invalid_meta.sql @@ -0,0 +1,2 @@ +-- pgweb: host="localhost" mode="foo" +select 'foo' diff --git a/data/lc_no_meta.sql b/data/lc_no_meta.sql new file mode 100644 index 000000000..43bd89112 --- /dev/null +++ b/data/lc_no_meta.sql @@ -0,0 +1 @@ +select 'foo' diff --git a/pkg/api/api.go b/pkg/api/api.go index 364a2a0a3..7e5a68495 100644 --- a/pkg/api/api.go +++ b/pkg/api/api.go @@ -17,6 +17,7 @@ import ( "github.com/sosedoff/pgweb/pkg/command" "github.com/sosedoff/pgweb/pkg/connection" "github.com/sosedoff/pgweb/pkg/metrics" + "github.com/sosedoff/pgweb/pkg/queries" "github.com/sosedoff/pgweb/pkg/shared" "github.com/sosedoff/pgweb/static" ) @@ -27,6 +28,9 @@ var ( // DbSessions represents the mapping for client connections DbSessions *SessionManager + + // QueryStore reads the SQL queries stores in the home directory + QueryStore *queries.Store ) // DB returns a database connection from the client context @@ -555,6 +559,7 @@ func GetInfo(c *gin.Context) { "features": gin.H{ "session_lock": command.Opts.LockSession, "query_timeout": command.Opts.QueryTimeout, + "local_queries": QueryStore != nil, }, }) } @@ -606,3 +611,78 @@ func GetFunction(c *gin.Context) { res, err := DB(c).Function(c.Param("id")) serveResult(c, res, err) } + +func GetLocalQueries(c *gin.Context) { + connCtx, err := DB(c).GetConnContext() + if err != nil { + badRequest(c, err) + return + } + + storeQueries, err := QueryStore.ReadAll() + if err != nil { + badRequest(c, err) + return + } + + queries := []localQuery{} + for _, q := range storeQueries { + if !q.IsPermitted(connCtx.Host, connCtx.User, connCtx.Database, connCtx.Mode) { + continue + } + + queries = append(queries, localQuery{ + ID: q.ID, + Title: q.Meta.Title, + Description: q.Meta.Description, + Query: cleanQuery(q.Data), + }) + } + + successResponse(c, queries) +} + +func RunLocalQuery(c *gin.Context) { + query, err := QueryStore.Read(c.Param("id")) + if err != nil { + if err == queries.ErrQueryFileNotExist { + query = nil + } else { + badRequest(c, err) + return + } + } + if query == nil { + errorResponse(c, 404, "query not found") + return + } + + connCtx, err := DB(c).GetConnContext() + if err != nil { + badRequest(c, err) + return + } + + if !query.IsPermitted(connCtx.Host, connCtx.User, connCtx.Database, connCtx.Mode) { + errorResponse(c, 404, "query not found") + return + } + + if c.Request.Method == http.MethodGet { + successResponse(c, localQuery{ + ID: query.ID, + Title: query.Meta.Title, + Description: query.Meta.Description, + Query: query.Data, + }) + return + } + + statement := cleanQuery(query.Data) + if statement == "" { + badRequest(c, errQueryRequired) + return + } + + HandleQuery(statement, c) +} diff --git a/pkg/api/middleware.go b/pkg/api/middleware.go index 83a33ad7c..14b3c2872 100644 --- a/pkg/api/middleware.go +++ b/pkg/api/middleware.go @@ -56,3 +56,14 @@ func corsMiddleware() gin.HandlerFunc { c.Header("Access-Control-Allow-Origin", command.Opts.CorsOrigin) } } + +func requireLocalQueries() gin.HandlerFunc { + return func(c *gin.Context) { + if QueryStore == nil { + badRequest(c, "local queries are disabled") + return + } + + c.Next() + } +} diff --git a/pkg/api/routes.go b/pkg/api/routes.go index f8ef344df..5b22695d6 100644 --- a/pkg/api/routes.go +++ b/pkg/api/routes.go @@ -54,6 +54,9 @@ func SetupRoutes(router *gin.Engine) { api.GET("/history", GetHistory) api.GET("/bookmarks", GetBookmarks) api.GET("/export", DataExport) + api.GET("/local_queries", requireLocalQueries(), GetLocalQueries) + api.GET("/local_queries/:id", requireLocalQueries(), RunLocalQuery) + api.POST("/local_queries/:id", requireLocalQueries(), RunLocalQuery) } func SetupMetrics(engine *gin.Engine) { diff --git a/pkg/api/types.go b/pkg/api/types.go new file mode 100644 index 000000000..143ec6e61 --- /dev/null +++ b/pkg/api/types.go @@ -0,0 +1,8 @@ +package api + +type localQuery struct { + ID string `json:"id"` + Title string `json:"title,omitempty"` + Description string `json:"description,omitempty"` + Query string `json:"query"` +} diff --git a/pkg/cli/cli.go b/pkg/cli/cli.go index 9bdbda3e6..9c61820ab 100644 --- a/pkg/cli/cli.go +++ b/pkg/cli/cli.go @@ -1,6 +1,7 @@ package cli import ( + "errors" "fmt" "os" "os/exec" @@ -20,6 +21,7 @@ import ( "github.com/sosedoff/pgweb/pkg/command" "github.com/sosedoff/pgweb/pkg/connection" "github.com/sosedoff/pgweb/pkg/metrics" + "github.com/sosedoff/pgweb/pkg/queries" "github.com/sosedoff/pgweb/pkg/util" ) @@ -28,11 +30,11 @@ var ( options command.Options readonlyWarning = ` ------------------------------------------------------- -SECURITY WARNING: You are running pgweb in read-only mode. -This mode is designed for environments where users could potentially delete / change data. -For proper read-only access please follow postgresql role management documentation. -------------------------------------------------------` +-------------------------------------------------------------------------------- +SECURITY WARNING: You are running Pgweb in read-only mode. +This mode is designed for environments where users could potentially delete or change data. +For proper read-only access please follow PostgreSQL role management documentation. +--------------------------------------------------------------------------------` regexErrConnectionRefused = regexp.MustCompile(`(connection|actively) refused`) regexErrAuthFailed = regexp.MustCompile(`authentication failed`) @@ -157,9 +159,33 @@ func initOptions() { } } + configureLocalQueryStore() printVersion() } +func configureLocalQueryStore() { + if options.Sessions || options.QueriesDir == "" { + return + } + + stat, err := os.Stat(options.QueriesDir) + if err != nil { + if errors.Is(err, os.ErrNotExist) { + logger.Debugf("local queries directory %q does not exist, disabling feature", options.QueriesDir) + } else { + logger.Debugf("local queries feature disabled due to error: %v", err) + } + return + } + + if !stat.IsDir() { + logger.Debugf("local queries path %q is not a directory", options.QueriesDir) + return + } + + api.QueryStore = queries.NewStore(options.QueriesDir) +} + func configureLogger(opts command.Options) error { if options.Debug { logger.SetLevel(logrus.DebugLevel) diff --git a/pkg/client/client.go b/pkg/client/client.go index 10e2256b2..e7157404c 100644 --- a/pkg/client/client.go +++ b/pkg/client/client.go @@ -585,3 +585,44 @@ func (client *Client) hasHistoryRecord(query string) bool { return result } + +type ConnContext struct { + Host string + User string + Database string + Mode string +} + +func (c ConnContext) String() string { + return fmt.Sprintf( + "host=%q user=%q database=%q mode=%q", + c.Host, c.User, c.Database, c.Mode, + ) +} + +// ConnContext returns information about current database connection +func (client *Client) GetConnContext() (*ConnContext, error) { + url, err := neturl.Parse(client.ConnectionString) + if err != nil { + return nil, err + } + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) + defer cancel() + + connCtx := ConnContext{ + Host: url.Hostname(), + Mode: "default", + } + + if command.Opts.ReadOnly { + connCtx.Mode = "readonly" + } + + row := client.db.QueryRowContext(ctx, "SELECT current_user, current_database()") + if err := row.Scan(&connCtx.User, &connCtx.Database); err != nil { + return nil, err + } + + return &connCtx, nil +} diff --git a/pkg/client/client_test.go b/pkg/client/client_test.go index 3583be9fd..a139f778e 100644 --- a/pkg/client/client_test.go +++ b/pkg/client/client_test.go @@ -660,6 +660,15 @@ func testTablesStats(t *testing.T) { assert.Equal(t, columns, result.Columns) } +func testConnContext(t *testing.T) { + result, err := testClient.GetConnContext() + assert.NoError(t, err) + assert.Equal(t, "localhost", result.Host) + assert.Equal(t, "postgres", result.User) + assert.Equal(t, "booktown", result.Database) + assert.Equal(t, "default", result.Mode) +} + func TestAll(t *testing.T) { if onWindows() { t.Log("Unit testing on Windows platform is not supported.") @@ -698,6 +707,7 @@ func TestAll(t *testing.T) { testReadOnlyMode(t) testDumpExport(t) testTablesStats(t) + testConnContext(t) teardownClient() teardown(t, true) diff --git a/pkg/command/options.go b/pkg/command/options.go index f0a578681..5d87cbebc 100644 --- a/pkg/command/options.go +++ b/pkg/command/options.go @@ -48,6 +48,7 @@ type Options struct { LockSession bool `long:"lock-session" description:"Lock session to a single database connection"` Bookmark string `short:"b" long:"bookmark" description:"Bookmark to use for connection. Bookmark files are stored under $HOME/.pgweb/bookmarks/*.toml" default:""` BookmarksDir string `long:"bookmarks-dir" description:"Overrides default directory for bookmark files to search" default:""` + QueriesDir string `long:"queries-dir" description:"Overrides default directory for local queries"` DisablePrettyJSON bool `long:"no-pretty-json" description:"Disable JSON formatting feature for result export"` DisableSSH bool `long:"no-ssh" description:"Disable database connections via SSH"` ConnectBackend string `long:"connect-backend" description:"Enable database authentication through a third party backend"` @@ -159,10 +160,19 @@ func ParseOptions(args []string) (Options, error) { } } - if opts.BookmarksDir == "" { - path, err := homedir.Dir() - if err == nil { - opts.BookmarksDir = filepath.Join(path, ".pgweb/bookmarks") + homePath, err := homedir.Dir() + if err != nil { + fmt.Fprintf(os.Stderr, "[WARN] cant detect home dir: %v", err) + homePath = os.Getenv("HOME") + } + + if homePath != "" { + if opts.BookmarksDir == "" { + opts.BookmarksDir = filepath.Join(homePath, ".pgweb/bookmarks") + } + + if opts.QueriesDir == "" { + opts.QueriesDir = filepath.Join(homePath, ".pgweb/queries") } } diff --git a/pkg/queries/field.go b/pkg/queries/field.go new file mode 100644 index 000000000..abd3e5a51 --- /dev/null +++ b/pkg/queries/field.go @@ -0,0 +1,43 @@ +package queries + +import ( + "fmt" + "regexp" + "strings" +) + +type field struct { + value string + re *regexp.Regexp +} + +func (f field) String() string { + return f.value +} + +func (f field) matches(input string) bool { + if f.re != nil { + return f.re.MatchString(input) + } + return f.value == input +} + +func newField(value string) (field, error) { + f := field{value: value} + + if value == "*" { // match everything + f.re = reMatchAll + } else if reExpression.MatchString(value) { // match by given expression + // Make writing expressions easier for values like "foo_*" + if strings.Count(value, "*") == 1 { + value = strings.Replace(value, "*", "(.+)", 1) + } + re, err := regexp.Compile(fmt.Sprintf("^%s$", value)) + if err != nil { + return f, err + } + f.re = re + } + + return f, nil +} diff --git a/pkg/queries/field_test.go b/pkg/queries/field_test.go new file mode 100644 index 000000000..544c999be --- /dev/null +++ b/pkg/queries/field_test.go @@ -0,0 +1,41 @@ +package queries + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func Test_field(t *testing.T) { + field, err := newField("val") + assert.NoError(t, err) + assert.Equal(t, "val", field.value) + assert.Equal(t, true, field.matches("val")) + assert.Equal(t, false, field.matches("value")) + + field, err = newField("*") + assert.NoError(t, err) + assert.Equal(t, "*", field.value) + assert.NotNil(t, field.re) + assert.Equal(t, true, field.matches("val")) + assert.Equal(t, true, field.matches("value")) + + field, err = newField("(.+") + assert.EqualError(t, err, "error parsing regexp: missing closing ): `^(.+$`") + assert.NotNil(t, field) + + field, err = newField("foo_*") + assert.NoError(t, err) + assert.Equal(t, "foo_*", field.value) + assert.NotNil(t, field.re) + assert.Equal(t, false, field.matches("foo")) + assert.Equal(t, true, field.matches("foo_bar")) + assert.Equal(t, true, field.matches("foo_bar_widget")) + +} + +func Test_fieldString(t *testing.T) { + field, err := newField("val") + assert.NoError(t, err) + assert.Equal(t, "val", field.String()) +} diff --git a/pkg/queries/metadata.go b/pkg/queries/metadata.go new file mode 100644 index 000000000..811bb9554 --- /dev/null +++ b/pkg/queries/metadata.go @@ -0,0 +1,148 @@ +package queries + +import ( + "fmt" + "regexp" + "strconv" + "strings" + "time" +) + +var ( + reMetaPrefix = regexp.MustCompile(`(?m)^\s*--\s*pgweb:\s*(.+)`) + reMetaContent = regexp.MustCompile(`([\w]+)\s*=\s*"([^"]+)"`) + reMatchAll = regexp.MustCompile(`^(.+)$`) + reExpression = regexp.MustCompile(`[\[\]\(\)\+\*]+`) + + allowedKeys = []string{"title", "description", "host", "user", "database", "mode", "timeout"} + allowedModes = map[string]bool{"readonly": true, "*": true} +) + +type Metadata struct { + Title string + Description string + Host field + User field + Database field + Mode field + Timeout *time.Duration +} + +func parseMetadata(input string) (*Metadata, error) { + fields, err := parseFields(input) + if err != nil { + return nil, err + } + if fields == nil { + return nil, nil + } + + // Host must be set to limit queries availability + if fields["host"] == "" { + return nil, fmt.Errorf("host field must be set") + } + + // Allow matching for any user, database and mode by default + if fields["user"] == "" { + fields["user"] = "*" + } + if fields["database"] == "" { + fields["database"] = "*" + } + if fields["mode"] == "" { + fields["mode"] = "*" + } + + hostField, err := newField(fields["host"]) + if err != nil { + return nil, fmt.Errorf(`error initializing "host" field: %w`, err) + } + + userField, err := newField(fields["user"]) + if err != nil { + return nil, fmt.Errorf(`error initializing "user" field: %w`, err) + } + + dbField, err := newField(fields["database"]) + if err != nil { + return nil, fmt.Errorf(`error initializing "database" field: %w`, err) + } + + if !allowedModes[fields["mode"]] { + return nil, fmt.Errorf(`invalid "mode" field value: %q`, fields["mode"]) + } + modeField, err := newField(fields["mode"]) + if err != nil { + return nil, fmt.Errorf(`error initializing "mode" field: %w`, err) + } + + var timeout *time.Duration + if fields["timeout"] != "" { + timeoutSec, err := strconv.Atoi(fields["timeout"]) + if err != nil { + return nil, fmt.Errorf(`error initializing "timeout" field: %w`, err) + } + timeoutVal := time.Duration(timeoutSec) * time.Second + timeout = &timeoutVal + } + + return &Metadata{ + Title: fields["title"], + Description: fields["description"], + Host: hostField, + User: userField, + Database: dbField, + Mode: modeField, + Timeout: timeout, + }, nil +} + +func parseFields(input string) (map[string]string, error) { + result := map[string]string{} + seenKeys := map[string]bool{} + + allowed := map[string]bool{} + for _, key := range allowedKeys { + allowed[key] = true + } + + matches := reMetaPrefix.FindAllStringSubmatch(input, -1) + if len(matches) == 0 { + return nil, nil + } + + for _, match := range matches { + content := reMetaContent.FindAllStringSubmatch(match[1], -1) + if len(content) == 0 { + continue + } + + for _, field := range content { + key := field[1] + value := field[2] + + if !allowed[key] { + return result, fmt.Errorf("unknown key: %q", key) + } + if seenKeys[key] { + return result, fmt.Errorf("duplicate key: %q", key) + } + + seenKeys[key] = true + result[key] = value + } + } + + return result, nil +} + +func sanitizeMetadata(input string) string { + lines := []string{} + for _, line := range strings.Split(input, "\n") { + line = reMetaPrefix.ReplaceAllString(line, "") + if len(line) > 0 { + lines = append(lines, line) + } + } + return strings.Join(lines, "\n") +} diff --git a/pkg/queries/metadata_test.go b/pkg/queries/metadata_test.go new file mode 100644 index 000000000..0b55fec55 --- /dev/null +++ b/pkg/queries/metadata_test.go @@ -0,0 +1,146 @@ +package queries + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func Test_parseFields(t *testing.T) { + examples := []struct { + input string + err error + vals map[string]string + }{ + {input: "", err: nil, vals: nil}, + {input: "foobar", err: nil, vals: nil}, + {input: "-- no pgweb meta", err: nil, vals: nil}, + { + input: `--pgweb: foo=bar`, + err: nil, + vals: map[string]string{}, + }, + { + input: `--pgweb: host="localhost"`, + err: nil, + vals: map[string]string{"host": "localhost"}, + }, + { + input: `--pgweb: host="*" user="admin" database ="mydb"; mode = "readonly"`, + err: nil, + vals: map[string]string{ + "host": "*", + "database": "mydb", + "user": "admin", + "mode": "readonly", + }, + }, + } + + for _, ex := range examples { + t.Run(ex.input, func(t *testing.T) { + fields, err := parseFields(ex.input) + assert.Equal(t, ex.err, err) + assert.Equal(t, ex.vals, fields) + }) + } +} + +func Test_parseMetadata(t *testing.T) { + examples := []struct { + input string + err string + check func(meta *Metadata) bool + }{ + { + input: `--pgweb: `, + err: `host field must be set`, + }, + { + input: `--pgweb: hello="world"`, + err: `unknown key: "hello"`, + }, + { + input: `--pgweb: host="localhost" user="anyuser" database="anydb" mode="foo"`, + err: `invalid "mode" field value: "foo"`, + }, + { + input: "--pgweb2:", + check: func(m *Metadata) bool { + return m == nil + }, + }, + { + input: `--pgweb: host="localhost"`, + check: func(m *Metadata) bool { + return m.Host.value == "localhost" && + m.User.value == "*" && + m.Database.value == "*" && + m.Mode.value == "*" && + m.Timeout == nil + }, + }, + { + input: `--pgweb: host="localhost" user="anyuser" database="anydb" mode="*"`, + check: func(m *Metadata) bool { + return m.Host.value == "localhost" && + m.Host.re == nil && + m.User.value == "anyuser" && + m.Database.value == "anydb" && + m.Mode.value == "*" && + m.Timeout == nil + }, + }, + { + input: `--pgweb: host="localhost" timeout="foo"`, + err: `error initializing "timeout" field: strconv.Atoi: parsing "foo": invalid syntax`, + }, + { + input: `-- pgweb: host="local(host|dev)"`, + check: func(m *Metadata) bool { + return m.Host.value == "local(host|dev)" && m.Host.re != nil && + m.Host.matches("localhost") && m.Host.matches("localdev") && + !m.Host.matches("localfoo") && !m.Host.matches("superlocaldev") + }, + }, + } + + for _, ex := range examples { + t.Run(ex.input, func(t *testing.T) { + meta, err := parseMetadata(ex.input) + if ex.err != "" { + assert.Contains(t, err.Error(), ex.err) + } + if ex.check != nil { + assert.Equal(t, true, ex.check(meta)) + } + }) + } +} + +func Test_sanitizeMetadata(t *testing.T) { + examples := []struct { + input string + output string + }{ + {input: "", output: ""}, + {input: "foo", output: "foo"}, + { + input: ` +-- pgweb: metadata +query1 +-- pgweb: more metadata + +query2 + +`, + output: "query1\nquery2", + }, + } + + for _, ex := range examples { + t.Run(ex.input, func(t *testing.T) { + assert.Equal(t, ex.output, sanitizeMetadata(ex.input)) + }) + } +} diff --git a/pkg/queries/query.go b/pkg/queries/query.go new file mode 100644 index 000000000..fe9295658 --- /dev/null +++ b/pkg/queries/query.go @@ -0,0 +1,23 @@ +package queries + +type Query struct { + ID string + Path string + Meta *Metadata + Data string +} + +// IsPermitted returns true if a query is allowed to execute for a given db context +func (q Query) IsPermitted(host, user, database, mode string) bool { + // All fields must be provided for matching + if q.Meta == nil || host == "" || user == "" || database == "" || mode == "" { + return false + } + + meta := q.Meta + + return meta.Host.matches(host) && + meta.User.matches(user) && + meta.Database.matches(database) && + meta.Mode.matches(mode) +} diff --git a/pkg/queries/query_test.go b/pkg/queries/query_test.go new file mode 100644 index 000000000..245e7a978 --- /dev/null +++ b/pkg/queries/query_test.go @@ -0,0 +1,77 @@ +package queries + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestQueryIsPermitted(t *testing.T) { + examples := []struct { + name string + query Query + args []string + expected bool + }{ + { + name: "no input provided", + query: makeQuery("localhost", "someuser", "somedb", "default"), + args: makeArgs("", "", "", ""), + expected: false, + }, + { + name: "match on host", + query: makeQuery("localhost", "*", "*", "*"), + args: makeArgs("localhost", "user", "db", "default"), + expected: true, + }, + { + name: "match on full set", + query: makeQuery("localhost", "user", "database", "mode"), + args: makeArgs("localhost", "someuser", "somedb", "default"), + expected: false, + }, + { + name: "match on partial database", + query: makeQuery("localhost", "*", "myapp_*", "*"), + args: makeArgs("localhost", "user", "myapp_development", "default"), + expected: true, + }, + { + name: "match on full set but not mode", + query: makeQuery("localhost", "*", "*", "readonly"), + args: makeArgs("localhost", "user", "db", "default"), + expected: false, + }, + } + + for _, ex := range examples { + t.Run(ex.name, func(t *testing.T) { + result := ex.query.IsPermitted(ex.args[0], ex.args[1], ex.args[2], ex.args[3]) + assert.Equal(t, ex.expected, result) + }) + } +} + +func makeArgs(vals ...string) []string { + return vals +} + +func makeQuery(host, user, database, mode string) Query { + mustfield := func(input string) field { + f, err := newField(input) + if err != nil { + panic(err) + } + return f + } + + return Query{ + Meta: &Metadata{ + Host: mustfield(host), + User: mustfield(user), + Database: mustfield(database), + Mode: mustfield(mode), + }, + } +} diff --git a/pkg/queries/store.go b/pkg/queries/store.go new file mode 100644 index 000000000..bfcfd215a --- /dev/null +++ b/pkg/queries/store.go @@ -0,0 +1,88 @@ +package queries + +import ( + "errors" + "fmt" + "os" + "path/filepath" + "strings" +) + +var ( + ErrQueryDirNotExist = errors.New("queries directory does not exist") + ErrQueryFileNotExist = errors.New("query file does not exist") +) + +type Store struct { + dir string +} + +func NewStore(dir string) *Store { + return &Store{ + dir: dir, + } +} + +func (s Store) Read(id string) (*Query, error) { + path := filepath.Join(s.dir, fmt.Sprintf("%s.sql", id)) + return readQuery(path) +} + +func (s Store) ReadAll() ([]Query, error) { + entries, err := os.ReadDir(s.dir) + if err != nil { + if errors.Is(err, os.ErrNotExist) { + err = ErrQueryDirNotExist + } + return nil, err + } + + queries := []Query{} + + for _, entry := range entries { + name := entry.Name() + if filepath.Ext(name) != ".sql" { + continue + } + + path := filepath.Join(s.dir, name) + query, err := readQuery(path) + if err != nil { + fmt.Fprintf(os.Stderr, "[WARN] skipping %q query file due to error: %v\n", name, err) + continue + } + if query == nil { + continue + } + + queries = append(queries, *query) + } + + return queries, nil +} + +func readQuery(path string) (*Query, error) { + data, err := os.ReadFile(path) + if err != nil { + if errors.Is(err, os.ErrNotExist) { + return nil, ErrQueryFileNotExist + } + return nil, err + } + dataStr := string(data) + + meta, err := parseMetadata(dataStr) + if err != nil { + return nil, err + } + if meta == nil { + return nil, nil + } + + return &Query{ + ID: strings.Replace(filepath.Base(path), ".sql", "", 1), + Path: path, + Meta: meta, + Data: sanitizeMetadata(dataStr), + }, nil +} diff --git a/pkg/queries/store_test.go b/pkg/queries/store_test.go new file mode 100644 index 000000000..635a3ee5f --- /dev/null +++ b/pkg/queries/store_test.go @@ -0,0 +1,71 @@ +//go:build !windows + +package queries + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestStoreReadAll(t *testing.T) { + t.Run("valid dir", func(t *testing.T) { + queries, err := NewStore("../../data").ReadAll() + assert.NoError(t, err) + assert.Equal(t, 2, len(queries)) + }) + + t.Run("invalid dir", func(t *testing.T) { + queries, err := NewStore("../../data2").ReadAll() + assert.Equal(t, err.Error(), "queries directory does not exist") + assert.Equal(t, 0, len(queries)) + }) +} + +func TestStoreRead(t *testing.T) { + examples := []struct { + id string + err string + check func(*testing.T, *Query) + }{ + {id: "foo", err: "query file does not exist"}, + {id: "lc_no_meta"}, + {id: "lc_invalid_meta", err: `invalid "mode" field value: "foo"`}, + { + id: "lc_example1", + check: func(t *testing.T, q *Query) { + assert.Equal(t, "lc_example1", q.ID) + assert.Equal(t, "../../data/lc_example1.sql", q.Path) + assert.Equal(t, "select 'foo'", q.Data) + assert.Equal(t, "localhost", q.Meta.Host.String()) + assert.Equal(t, "*", q.Meta.User.String()) + assert.Equal(t, "*", q.Meta.Database.String()) + }, + }, + { + id: "lc_example2", + check: func(t *testing.T, q *Query) { + assert.Equal(t, "lc_example2", q.ID) + assert.Equal(t, "../../data/lc_example2.sql", q.Path) + assert.Equal(t, "-- some comment\nselect 'foo'", q.Data) + assert.Equal(t, "localhost", q.Meta.Host.String()) + assert.Equal(t, "foo", q.Meta.User.String()) + assert.Equal(t, "*", q.Meta.Database.String()) + }, + }, + } + + store := NewStore("../../data") + + for _, ex := range examples { + t.Run(ex.id, func(t *testing.T) { + query, err := store.Read(ex.id) + if ex.err != "" || err != nil { + assert.Equal(t, ex.err, err.Error()) + } + if ex.check != nil { + ex.check(t, query) + } + }) + } +} diff --git a/static/css/app.css b/static/css/app.css index d4b565ccb..b3aaebdfe 100644 --- a/static/css/app.css +++ b/static/css/app.css @@ -357,6 +357,7 @@ #input .actions #query_progress { display: none; float: left; + font-size: 12px; line-height: 30px; height: 30px; color: #aaa; diff --git a/static/index.html b/static/index.html index 406ed7bae..13fca7fbb 100644 --- a/static/index.html +++ b/static/index.html @@ -17,7 +17,7 @@ - + @@ -87,6 +87,13 @@
  • Analyze Query
  • +
    Please wait, query is executing...
    diff --git a/static/js/app.js b/static/js/app.js index fa6782919..37bb68e31 100644 --- a/static/js/app.js +++ b/static/js/app.js @@ -178,6 +178,33 @@ function buildSchemaSection(name, objects) { return section; } +function loadLocalQueries() { + if (!appFeatures.local_queries) return; + + $("body").on("click", "a.load-local-query", function(e) { + var id = $(this).data("id"); + + apiCall("get", "/local_queries/" + id, {}, function(resp) { + editor.setValue(resp.query); + editor.clearSelection(); + }); + }); + + apiCall("get", "/local_queries", {}, function(resp) { + if (resp.error) return; + + var container = $("#load-query-dropdown").find(".dropdown-menu"); + + resp.forEach(function(item) { + var title = item.title || item.id; + $("
  • " + title + "
  • ").appendTo(container); + }); + + if (resp.length > 0) $("#load-local-query").prop("disabled", ""); + $("#load-query-dropdown").show(); + }); +} + function loadSchemas() { $("#objects").html(""); @@ -738,13 +765,13 @@ function showActivityPanel() { } function showQueryProgressMessage() { - $("#run, #explain-dropdown-toggle, #csv, #json, #xml").prop("disabled", true); + $("#run, #explain-dropdown-toggle, #csv, #json, #xml, #load-local-query").prop("disabled", true); $("#explain-dropdown").removeClass("open"); $("#query_progress").show(); } function hideQueryProgressMessage() { - $("#run, #explain-dropdown-toggle, #csv, #json, #xml").prop("disabled", false); + $("#run, #explain-dropdown-toggle, #csv, #json, #xml, #load-local-query").prop("disabled", false); $("#query_progress").hide(); } @@ -1810,6 +1837,7 @@ $(document).ready(function() { connected = true; loadSchemas(); + loadLocalQueries(); $("#current_database").text(resp.current_database); $("#main").show();