Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Local queries #641

Merged
merged 11 commits into from
Feb 2, 2023
2 changes: 2 additions & 0 deletions data/lc_example1.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
-- pgweb: host="localhost"
select 'foo'
5 changes: 5 additions & 0 deletions data/lc_example2.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
-- pgweb: host="localhost"
-- some comment
-- pgweb: user="foo"

select 'foo'
2 changes: 2 additions & 0 deletions data/lc_invalid_meta.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
-- pgweb: host="localhost" mode="foo"
select 'foo'
1 change: 1 addition & 0 deletions data/lc_no_meta.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
select 'foo'
80 changes: 80 additions & 0 deletions pkg/api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
"github.com/sosedoff/pgweb/pkg/command"
"github.com/sosedoff/pgweb/pkg/connection"
"github.com/sosedoff/pgweb/pkg/metrics"
"github.com/sosedoff/pgweb/pkg/queries"
"github.com/sosedoff/pgweb/pkg/shared"
"github.com/sosedoff/pgweb/static"
)
Expand All @@ -27,6 +28,9 @@ var (

// DbSessions represents the mapping for client connections
DbSessions *SessionManager

// QueryStore reads the SQL queries stores in the home directory
QueryStore *queries.Store
)

// DB returns a database connection from the client context
Expand Down Expand Up @@ -555,6 +559,7 @@ func GetInfo(c *gin.Context) {
"features": gin.H{
"session_lock": command.Opts.LockSession,
"query_timeout": command.Opts.QueryTimeout,
"local_queries": QueryStore != nil,
},
})
}
Expand Down Expand Up @@ -606,3 +611,78 @@ func GetFunction(c *gin.Context) {
res, err := DB(c).Function(c.Param("id"))
serveResult(c, res, err)
}

func GetLocalQueries(c *gin.Context) {
connCtx, err := DB(c).GetConnContext()
if err != nil {
badRequest(c, err)
return
}

storeQueries, err := QueryStore.ReadAll()
if err != nil {
badRequest(c, err)
return
}

queries := []localQuery{}
for _, q := range storeQueries {
if !q.IsPermitted(connCtx.Host, connCtx.User, connCtx.Database, connCtx.Mode) {
continue
}

queries = append(queries, localQuery{
ID: q.ID,
Title: q.Meta.Title,
Description: q.Meta.Description,
Query: cleanQuery(q.Data),
})
}

successResponse(c, queries)
}

func RunLocalQuery(c *gin.Context) {
query, err := QueryStore.Read(c.Param("id"))
if err != nil {
if err == queries.ErrQueryFileNotExist {
query = nil
} else {
badRequest(c, err)
return
}
}
if query == nil {
errorResponse(c, 404, "query not found")
return
}

connCtx, err := DB(c).GetConnContext()
if err != nil {
badRequest(c, err)
return
}

if !query.IsPermitted(connCtx.Host, connCtx.User, connCtx.Database, connCtx.Mode) {
errorResponse(c, 404, "query not found")
return
}

if c.Request.Method == http.MethodGet {
successResponse(c, localQuery{
ID: query.ID,
Title: query.Meta.Title,
Description: query.Meta.Description,
Query: query.Data,
})
return
}

statement := cleanQuery(query.Data)
if statement == "" {
badRequest(c, errQueryRequired)
return
}

HandleQuery(statement, c)
}
11 changes: 11 additions & 0 deletions pkg/api/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,3 +56,14 @@ func corsMiddleware() gin.HandlerFunc {
c.Header("Access-Control-Allow-Origin", command.Opts.CorsOrigin)
}
}

func requireLocalQueries() gin.HandlerFunc {
return func(c *gin.Context) {
if QueryStore == nil {
badRequest(c, "local queries are disabled")
return
}

c.Next()
}
}
3 changes: 3 additions & 0 deletions pkg/api/routes.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,9 @@ func SetupRoutes(router *gin.Engine) {
api.GET("/history", GetHistory)
api.GET("/bookmarks", GetBookmarks)
api.GET("/export", DataExport)
api.GET("/local_queries", requireLocalQueries(), GetLocalQueries)
api.GET("/local_queries/:id", requireLocalQueries(), RunLocalQuery)
api.POST("/local_queries/:id", requireLocalQueries(), RunLocalQuery)
}

func SetupMetrics(engine *gin.Engine) {
Expand Down
8 changes: 8 additions & 0 deletions pkg/api/types.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package api

type localQuery struct {
ID string `json:"id"`
Title string `json:"title,omitempty"`
Description string `json:"description,omitempty"`
Query string `json:"query"`
}
36 changes: 31 additions & 5 deletions pkg/cli/cli.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package cli

import (
"errors"
"fmt"
"os"
"os/exec"
Expand All @@ -20,6 +21,7 @@ import (
"github.com/sosedoff/pgweb/pkg/command"
"github.com/sosedoff/pgweb/pkg/connection"
"github.com/sosedoff/pgweb/pkg/metrics"
"github.com/sosedoff/pgweb/pkg/queries"
"github.com/sosedoff/pgweb/pkg/util"
)

Expand All @@ -28,11 +30,11 @@ var (
options command.Options

readonlyWarning = `
------------------------------------------------------
SECURITY WARNING: You are running pgweb in read-only mode.
This mode is designed for environments where users could potentially delete / change data.
For proper read-only access please follow postgresql role management documentation.
------------------------------------------------------`
--------------------------------------------------------------------------------
SECURITY WARNING: You are running Pgweb in read-only mode.
This mode is designed for environments where users could potentially delete or change data.
For proper read-only access please follow PostgreSQL role management documentation.
--------------------------------------------------------------------------------`

regexErrConnectionRefused = regexp.MustCompile(`(connection|actively) refused`)
regexErrAuthFailed = regexp.MustCompile(`authentication failed`)
Expand Down Expand Up @@ -157,9 +159,33 @@ func initOptions() {
}
}

configureLocalQueryStore()
printVersion()
}

func configureLocalQueryStore() {
if options.Sessions || options.QueriesDir == "" {
return
}

stat, err := os.Stat(options.QueriesDir)
if err != nil {
if errors.Is(err, os.ErrNotExist) {
logger.Debugf("local queries directory %q does not exist, disabling feature", options.QueriesDir)
} else {
logger.Debugf("local queries feature disabled due to error: %v", err)
}
return
}

if !stat.IsDir() {
logger.Debugf("local queries path %q is not a directory", options.QueriesDir)
return
}

api.QueryStore = queries.NewStore(options.QueriesDir)
}

func configureLogger(opts command.Options) error {
if options.Debug {
logger.SetLevel(logrus.DebugLevel)
Expand Down
41 changes: 41 additions & 0 deletions pkg/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -585,3 +585,44 @@ func (client *Client) hasHistoryRecord(query string) bool {

return result
}

type ConnContext struct {
Host string
User string
Database string
Mode string
}

func (c ConnContext) String() string {
return fmt.Sprintf(
"host=%q user=%q database=%q mode=%q",
c.Host, c.User, c.Database, c.Mode,
)
}

// ConnContext returns information about current database connection
func (client *Client) GetConnContext() (*ConnContext, error) {
url, err := neturl.Parse(client.ConnectionString)
if err != nil {
return nil, err
}

ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
defer cancel()

connCtx := ConnContext{
Host: url.Hostname(),
Mode: "default",
}

if command.Opts.ReadOnly {
connCtx.Mode = "readonly"
}

row := client.db.QueryRowContext(ctx, "SELECT current_user, current_database()")
if err := row.Scan(&connCtx.User, &connCtx.Database); err != nil {
return nil, err
}

return &connCtx, nil
}
10 changes: 10 additions & 0 deletions pkg/client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -660,6 +660,15 @@ func testTablesStats(t *testing.T) {
assert.Equal(t, columns, result.Columns)
}

func testConnContext(t *testing.T) {
result, err := testClient.GetConnContext()
assert.NoError(t, err)
assert.Equal(t, "localhost", result.Host)
assert.Equal(t, "postgres", result.User)
assert.Equal(t, "booktown", result.Database)
assert.Equal(t, "default", result.Mode)
}

func TestAll(t *testing.T) {
if onWindows() {
t.Log("Unit testing on Windows platform is not supported.")
Expand Down Expand Up @@ -698,6 +707,7 @@ func TestAll(t *testing.T) {
testReadOnlyMode(t)
testDumpExport(t)
testTablesStats(t)
testConnContext(t)

teardownClient()
teardown(t, true)
Expand Down
18 changes: 14 additions & 4 deletions pkg/command/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ type Options struct {
LockSession bool `long:"lock-session" description:"Lock session to a single database connection"`
Bookmark string `short:"b" long:"bookmark" description:"Bookmark to use for connection. Bookmark files are stored under $HOME/.pgweb/bookmarks/*.toml" default:""`
BookmarksDir string `long:"bookmarks-dir" description:"Overrides default directory for bookmark files to search" default:""`
QueriesDir string `long:"queries-dir" description:"Overrides default directory for local queries"`
DisablePrettyJSON bool `long:"no-pretty-json" description:"Disable JSON formatting feature for result export"`
DisableSSH bool `long:"no-ssh" description:"Disable database connections via SSH"`
ConnectBackend string `long:"connect-backend" description:"Enable database authentication through a third party backend"`
Expand Down Expand Up @@ -159,10 +160,19 @@ func ParseOptions(args []string) (Options, error) {
}
}

if opts.BookmarksDir == "" {
path, err := homedir.Dir()
if err == nil {
opts.BookmarksDir = filepath.Join(path, ".pgweb/bookmarks")
homePath, err := homedir.Dir()
if err != nil {
fmt.Fprintf(os.Stderr, "[WARN] cant detect home dir: %v", err)
homePath = os.Getenv("HOME")
}

if homePath != "" {
if opts.BookmarksDir == "" {
opts.BookmarksDir = filepath.Join(homePath, ".pgweb/bookmarks")
}

if opts.QueriesDir == "" {
opts.QueriesDir = filepath.Join(homePath, ".pgweb/queries")
}
}

Expand Down
43 changes: 43 additions & 0 deletions pkg/queries/field.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
package queries

import (
"fmt"
"regexp"
"strings"
)

type field struct {
value string
re *regexp.Regexp
}

func (f field) String() string {
return f.value
}

func (f field) matches(input string) bool {
if f.re != nil {
return f.re.MatchString(input)
}
return f.value == input
}

func newField(value string) (field, error) {
f := field{value: value}

if value == "*" { // match everything
f.re = reMatchAll
} else if reExpression.MatchString(value) { // match by given expression
// Make writing expressions easier for values like "foo_*"
if strings.Count(value, "*") == 1 {
value = strings.Replace(value, "*", "(.+)", 1)
}
re, err := regexp.Compile(fmt.Sprintf("^%s$", value))
if err != nil {
return f, err
}
f.re = re
}

return f, nil
}
Loading