diff --git a/README.md b/README.md index eb56c8d3..d40663c3 100644 --- a/README.md +++ b/README.md @@ -197,15 +197,17 @@ dynamo.ListTablesWithContext(ctx, &dynamodb.ListTablesInput{}) **SQL** -Any `db/sql` calls can be traced with X-Ray by replacing the `sql.Open` call with `xray.SQL`. It is recommended to use URLs instead of configuration strings if possible. +Any `database/sql` calls can be traced with X-Ray by replacing the `sql.Open` call with `xray.SQLContext`. It is recommended to use URLs instead of configuration strings if possible. ```go func main() { - db := xray.SQL("postgres", "postgres://user:password@host:port/db") - row, _ := db.QueryRow("SELECT 1") // Use as normal + db, err := xray.SQLContext("postgres", "postgres://user:password@host:port/db") + row, err := db.QueryRowContext(ctx, "SELECT 1") // Use as normal } ``` +Note that the `xray.SQL` are deprecated and will be remove when the SDK becomes GA. + **Lambda** ``` diff --git a/xray/sql.go b/xray/sql.go index 14125cac..adcf67f4 100644 --- a/xray/sql.go +++ b/xray/sql.go @@ -9,300 +9,128 @@ package xray import ( - "bytes" "context" "database/sql" - "database/sql/driver" - "fmt" - "net/url" - "reflect" - "strings" - "time" ) // SQL opens a normalized and traced wrapper around an *sql.DB connection. // It uses `sql.Open` internally and shares the same function signature. // To ensure passwords are filtered, it is HIGHLY RECOMMENDED that your DSN // follows the format: `://:@:/` +// +// Deprecated: SQL exists for historical compatibility. +// Use SQLContext insted of SQL, it can be used +// as a drop-in replacement of the database/sql package. func SQL(driver, dsn string) (*DB, error) { - rawDB, err := sql.Open(driver, dsn) + db, err := SQLContext(driver, dsn) if err != nil { return nil, err } - - db := &DB{db: rawDB} - - // Detect if DSN is a URL or not, set appropriate attribute - urlDsn := dsn - if !strings.Contains(dsn, "//") { - urlDsn = "//" + urlDsn - } - // Here we're trying to detect things like `host:port/database` as a URL, which is pretty hard - // So we just assume that if it's got a scheme, a user, or a query that it's probably a URL - if u, err := url.Parse(urlDsn); err == nil && (u.Scheme != "" || u.User != nil || u.RawQuery != "" || strings.Contains(u.Path, "@")) { - // Check that this isn't in the form of user/pass@host:port/db, as that will shove the host into the path - if strings.Contains(u.Path, "@") { - u, err = url.Parse(fmt.Sprintf("%s//%s%%2F%s", u.Scheme, u.Host, u.Path[1:])) - if err != nil { - return nil, err - } - } - - // Strip password from user:password pair in address - if u.User != nil { - uname := u.User.Username() - - // Some drivers use "user/pass@host:port" instead of "user:pass@host:port" - // So we must manually attempt to chop off a potential password. - // But we can skip this if we already found the password. - if _, ok := u.User.Password(); !ok { - uname = strings.Split(uname, "/")[0] - } - - u.User = url.User(uname) - } - - // Strip password from query parameters - q := u.Query() - q.Del("password") - u.RawQuery = q.Encode() - - db.url = u.String() - if !strings.Contains(dsn, "//") { - db.url = db.url[2:] - } - } else { - // We don't *think* it's a URL, so now we have to try our best to strip passwords from - // some unknown DSL. We attempt to detect whether it's space-delimited or semicolon-delimited - // then remove any keys with the name "password" or "pwd". This won't catch everything, but - // from surveying the current (Jan 2017) landscape of drivers it should catch most. - db.connectionString = stripPasswords(dsn) - } - - // Detect database type and use that to populate attributes - var detectors []func(*DB) error - switch driver { - case "postgres": - detectors = append(detectors, postgresDetector) - case "mysql": - detectors = append(detectors, mysqlDetector) - default: - detectors = append(detectors, postgresDetector, mysqlDetector, mssqlDetector, oracleDetector) - } - for _, detector := range detectors { - if detector(db) == nil { - break - } - db.databaseType = "Unknown" - db.databaseVersion = "Unknown" - db.user = "Unknown" - db.dbname = "Unknown" - } - - // There's no standard to get SQL driver version information - // So we invent an interface by which drivers can provide us this data - type versionedDriver interface { - Version() string - } - - d := db.db.Driver() - if vd, ok := d.(versionedDriver); ok { - db.driverVersion = vd.Version() - } else { - t := reflect.TypeOf(d) - for t.Kind() == reflect.Ptr { - t = t.Elem() - } - db.driverVersion = t.PkgPath() - } - - return db, nil + return &DB{DB: db}, nil } // DB copies the interface of sql.DB but adds X-Ray tracing. // It must be created with xray.SQL. type DB struct { - db *sql.DB - - connectionString string - url string - databaseType string - databaseVersion string - driverVersion string - user string - dbname string + *sql.DB } -// Close closes a database and returns error if any. -func (db *DB) Close() error { return db.db.Close() } - -// Driver returns database's underlying driver. -func (db *DB) Driver() driver.Driver { return db.db.Driver() } - -// Stats returns database statistics. -func (db *DB) Stats() sql.DBStats { return db.db.Stats() } - -// SetConnMaxLifetime sets the maximum amount of time a connection may be reused. -func (db *DB) SetConnMaxLifetime(d time.Duration) { db.db.SetConnMaxLifetime(d) } - -// SetMaxIdleConns sets the maximum number of connections in the idle connection pool. -func (db *DB) SetMaxIdleConns(n int) { db.db.SetMaxIdleConns(n) } +// Begin starts a transaction. +func (db *DB) Begin(ctx context.Context, opts *sql.TxOptions) (*Tx, error) { + tx, err := db.BeginTx(ctx, opts) + if err != nil { + return nil, err + } + return &Tx{Tx: tx}, nil +} -// SetMaxOpenConns sets the maximum number of open connections to the database. -func (db *DB) SetMaxOpenConns(n int) { db.db.SetMaxOpenConns(n) } +// Prepare creates a prepared statement for later queries or executions. +func (db *DB) Prepare(ctx context.Context, query string) (*Stmt, error) { + stmt, err := db.PrepareContext(ctx, query) + if err != nil { + return nil, err + } + return &Stmt{Stmt: stmt}, nil +} -func (db *DB) populate(ctx context.Context, query string) { - seg := GetSegment(ctx) +// Exec captures executing a query without returning any rows and +// adds corresponding information into subsegment. +func (db *DB) Exec(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { + return db.ExecContext(ctx, query, args...) +} - if seg == nil { - processNilSegment(ctx) - return - } +// Query captures executing a query that returns rows and adds corresponding information into subsegment. +func (db *DB) Query(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) { + return db.QueryContext(ctx, query, args) +} - seg.Lock() - seg.Namespace = "remote" - seg.GetSQL().ConnectionString = db.connectionString - seg.GetSQL().URL = db.url - seg.GetSQL().DatabaseType = db.databaseType - seg.GetSQL().DatabaseVersion = db.databaseVersion - seg.GetSQL().DriverVersion = db.driverVersion - seg.GetSQL().User = db.user - seg.GetSQL().SanitizedQuery = query - seg.Unlock() +// QueryRow captures executing a query that is expected to return at most one row +// and adds corresponding information into subsegment. +func (db *DB) QueryRow(ctx context.Context, query string, args ...interface{}) *sql.Row { + return db.QueryRowContext(ctx, query, args...) } // Tx copies the interface of sql.Tx but adds X-Ray tracing. // It must be created with xray.DB.Begin. type Tx struct { - db *DB - tx *sql.Tx -} - -// Commit commits the transaction. -func (tx *Tx) Commit() error { return tx.tx.Commit() } - -// Rollback aborts the transaction. -func (tx *Tx) Rollback() error { return tx.tx.Rollback() } - -// Stmt copies the interface of sql.Stmt but adds X-Ray tracing. -// It must be created with xray.DB.Prepare or xray.Tx.Stmt. -type Stmt struct { - db *DB - stmt *sql.Stmt - query string + *sql.Tx } -// Close closes the statement. -func (stmt *Stmt) Close() error { return stmt.stmt.Close() } - -func (stmt *Stmt) populate(ctx context.Context, query string) { - stmt.db.populate(ctx, query) - - seg := GetSegment(ctx) - - if seg == nil { - processNilSegment(ctx) - return +// Prepare creates a prepared statement for later queries or executions. +func (tx *Tx) Prepare(ctx context.Context, query string) (*Stmt, error) { + stmt, err := tx.Tx.PrepareContext(ctx, query) + if err != nil { + return nil, err } + return &Stmt{Stmt: stmt}, err +} - seg.Lock() - seg.GetSQL().Preparation = "statement" - seg.Unlock() +// Exec captures executing a query that doesn't return rows and adds +// corresponding information into subsegment. +func (tx *Tx) Exec(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { + return tx.Tx.ExecContext(ctx, query, args...) } -func postgresDetector(db *DB) error { - db.databaseType = "Postgres" - row := db.db.QueryRow("SELECT version(), current_user, current_database()") - return row.Scan(&db.databaseVersion, &db.user, &db.dbname) +// Query captures executing a query that returns rows and adds corresponding information into subsegment. +func (tx *Tx) Query(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) { + return tx.Tx.QueryContext(ctx, query, args...) } -func mysqlDetector(db *DB) error { - db.databaseType = "MySQL" - row := db.db.QueryRow("SELECT version(), current_user(), database()") - return row.Scan(&db.databaseVersion, &db.user, &db.dbname) +// QueryRow captures executing a query that is expected to return at most one row and adds +// corresponding information into subsegment. +func (tx *Tx) QueryRow(ctx context.Context, query string, args ...interface{}) *sql.Row { + return tx.QueryRowContext(ctx, query, args...) } -func mssqlDetector(db *DB) error { - db.databaseType = "MS SQL" - row := db.db.QueryRow("SELECT @@version, current_user, db_name()") - return row.Scan(&db.databaseVersion, &db.user, &db.dbname) +// Stmt returns a transaction-specific prepared statement from an existing statement. +func (tx *Tx) Stmt(ctx context.Context, stmt *Stmt) *Stmt { + return &Stmt{ + Stmt: tx.StmtContext(ctx, stmt.Stmt), + } } -func oracleDetector(db *DB) error { - db.databaseType = "Oracle" - row := db.db.QueryRow("SELECT version FROM v$instance UNION SELECT user, ora_database_name FROM dual") - return row.Scan(&db.databaseVersion, &db.user, &db.dbname) +// Stmt copies the interface of sql.Stmt but adds X-Ray tracing. +// It must be created with xray.DB.Prepare or xray.Tx.Stmt. +type Stmt struct { + *sql.Stmt } -func stripPasswords(dsn string) string { - var ( - tok bytes.Buffer - res bytes.Buffer - isPassword bool - inBraces bool - delimiter byte = ' ' - ) - flush := func() { - if inBraces { - return - } - if !isPassword { - res.Write(tok.Bytes()) - } - tok.Reset() - isPassword = false - } - if strings.Count(dsn, ";") > strings.Count(dsn, " ") { - delimiter = ';' - } +// Exec captures executing a prepared statement with the given arguments and +// returning a Result summarizing the effect of the statement and adds corresponding +// information into subsegment. +func (stmt *Stmt) Exec(ctx context.Context, args ...interface{}) (sql.Result, error) { + return stmt.ExecContext(ctx, args...) +} - buf := strings.NewReader(dsn) - for c, err := buf.ReadByte(); err == nil; c, err = buf.ReadByte() { - tok.WriteByte(c) - switch c { - case ':', delimiter: - flush() - case '=': - tokStr := strings.ToLower(tok.String()) - isPassword = `password=` == tokStr || `pwd=` == tokStr - if b, err := buf.ReadByte(); err == nil && b == '{' { - inBraces = true - } - buf.UnreadByte() - case '}': - b, err := buf.ReadByte() - if err != nil { - break - } - if b == '}' { - tok.WriteByte(b) - } else { - inBraces = false - buf.UnreadByte() - } - case '@': - if strings.Contains(res.String(), ":") { - resLen := res.Len() - if resLen > 0 && res.Bytes()[resLen-1] == ':' { - res.Truncate(resLen - 1) - } - isPassword = true - flush() - res.WriteByte(c) - } - } - } - inBraces = false - flush() - return res.String() +// Query captures executing a prepared query statement with the given arguments +// and returning the query results as a *Rows and adds corresponding information +// into subsegment. +func (stmt *Stmt) Query(ctx context.Context, args ...interface{}) (*sql.Rows, error) { + return stmt.QueryContext(ctx, args...) } -func processNilSegment(ctx context.Context) { - cfg := GetRecorder(ctx) - failedMessage := "failed to get segment from context since segment is nil" - if cfg != nil && cfg.ContextMissingStrategy != nil { - cfg.ContextMissingStrategy.ContextMissing(failedMessage) - } else { - globalCfg.ContextMissingStrategy().ContextMissing(failedMessage) - } +// QueryRow captures executing a prepared query statement with the given arguments and +// adds corresponding information into subsegment. +func (stmt *Stmt) QueryRow(ctx context.Context, args ...interface{}) *sql.Row { + return stmt.QueryRowContext(ctx, args...) } diff --git a/xray/sql_context.go b/xray/sql_context.go new file mode 100644 index 00000000..e7a3f2e7 --- /dev/null +++ b/xray/sql_context.go @@ -0,0 +1,758 @@ +// Copyright 2017-2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance with the License. A copy of the License is located at +// +// http://aws.amazon.com/apache2.0/ +// +// or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. + +package xray + +import ( + "bytes" + "context" + "database/sql" + "database/sql/driver" + "errors" + "fmt" + "net/url" + "reflect" + "strconv" + "strings" + "sync" + "time" +) + +// we can't know that the original driver will return driver.ErrSkip in advance. +// so we add this message to the query if it returns driver.ErrSkip. +const msgErrSkip = " -- skip fast-path; continue as if unimplemented" + +// namedValueChecker is the same as driver.NamedValueChecker. +// Copied from database/sql/driver/driver.go for supporting Go 1.8. +type namedValueChecker interface { + // CheckNamedValue is called before passing arguments to the driver + // and is called in place of any ColumnConverter. CheckNamedValue must do type + // validation and conversion as appropriate for the driver. + CheckNamedValue(*driver.NamedValue) error +} + +var ( + muInitializedDrivers sync.Mutex + initializedDrivers map[string]struct{} + attrHook func(attr *dbAttribute) // for testing +) + +func initXRayDriver(driver, dsn string) error { + muInitializedDrivers.Lock() + defer muInitializedDrivers.Unlock() + + if initializedDrivers == nil { + initializedDrivers = map[string]struct{}{} + } + if _, ok := initializedDrivers[driver]; ok { + return nil + } + + db, err := sql.Open(driver, dsn) + if err != nil { + return err + } + sql.Register(driver+":xray", &driverDriver{ + Driver: db.Driver(), + baseName: driver, + }) + initializedDrivers[driver] = struct{}{} + db.Close() + return nil +} + +// SQLContext opens a normalized and traced wrapper around an *sql.DB connection. +// It uses `sql.Open` internally and shares the same function signature. +// To ensure passwords are filtered, it is HIGHLY RECOMMENDED that your DSN +// follows the format: `://:@:/` +func SQLContext(driver, dsn string) (*sql.DB, error) { + if err := initXRayDriver(driver, dsn); err != nil { + return nil, err + } + return sql.Open(driver+":xray", dsn) +} + +type driverDriver struct { + driver.Driver + baseName string // the name of the base driver +} + +func (d *driverDriver) Open(dsn string) (driver.Conn, error) { + rawConn, err := d.Driver.Open(dsn) + if err != nil { + return nil, err + } + attr, err := newDBAttribute(context.Background(), d.baseName, d.Driver, rawConn, dsn, false) + if err != nil { + rawConn.Close() + return nil, err + } + + conn := &driverConn{ + Conn: rawConn, + attr: attr, + } + return conn, nil +} + +type driverConn struct { + driver.Conn + attr *dbAttribute +} + +func (conn *driverConn) Ping(ctx context.Context) error { + return Capture(ctx, conn.attr.dbname, func(ctx context.Context) error { + conn.attr.populate(ctx, "PING") + if p, ok := conn.Conn.(driver.Pinger); ok { + return p.Ping(ctx) + } + return nil + }) +} + +func (conn *driverConn) Prepare(query string) (driver.Stmt, error) { + panic("not supported") +} + +func (conn *driverConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { + var stmt driver.Stmt + var err error + if connCtx, ok := conn.Conn.(driver.ConnPrepareContext); ok { + stmt, err = connCtx.PrepareContext(ctx, query) + } else { + stmt, err = conn.Conn.Prepare(query) + if err == nil { + select { + default: + case <-ctx.Done(): + stmt.Close() + return nil, ctx.Err() + } + } + } + if err != nil { + return nil, err + } + return &driverStmt{ + Stmt: stmt, + attr: conn.attr, + query: query, + conn: conn, + }, nil +} + +func (conn *driverConn) Begin() (driver.Tx, error) { + panic("not supported") +} + +func (conn *driverConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { + var tx driver.Tx + var err error + if connCtx, ok := conn.Conn.(driver.ConnBeginTx); ok { + tx, err = connCtx.BeginTx(ctx, opts) + } else { + if opts.Isolation != driver.IsolationLevel(sql.LevelDefault) { + return nil, errors.New("xray: driver does not support non-default isolation level") + } + if opts.ReadOnly { + return nil, errors.New("xray: driver does not support read-only transactions") + } + tx, err = conn.Conn.Begin() + if err == nil { + select { + default: + case <-ctx.Done(): + tx.Rollback() + return nil, ctx.Err() + } + } + } + if err != nil { + return nil, err + } + return &driverTx{Tx: tx}, nil +} + +func (conn *driverConn) Exec(query string, args []driver.Value) (driver.Result, error) { + panic("not supported") +} + +func (conn *driverConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { + execer, ok := conn.Conn.(driver.Execer) + if !ok { + return nil, driver.ErrSkip + } + + var err error + var result driver.Result + if execerCtx, ok := conn.Conn.(driver.ExecerContext); ok { + Capture(ctx, conn.attr.dbname, func(ctx context.Context) error { + result, err = execerCtx.ExecContext(ctx, query, args) + if err == driver.ErrSkip { + conn.attr.populate(ctx, query+msgErrSkip) + return nil + } + conn.attr.populate(ctx, query) + return err + }) + } else { + select { + default: + case <-ctx.Done(): + return nil, ctx.Err() + } + dargs, err0 := namedValuesToValues(args) + if err0 != nil { + return nil, err0 + } + Capture(ctx, conn.attr.dbname, func(ctx context.Context) error { + var err error + result, err = execer.Exec(query, dargs) + if err == driver.ErrSkip { + conn.attr.populate(ctx, query+msgErrSkip) + return nil + } + conn.attr.populate(ctx, query) + return err + }) + } + return result, err +} + +func (conn *driverConn) Query(query string, args []driver.Value) (driver.Rows, error) { + panic("not supported") +} + +func (conn *driverConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { + queryer, ok := conn.Conn.(driver.Queryer) + if !ok { + return nil, driver.ErrSkip + } + + var err error + var rows driver.Rows + if queryerCtx, ok := conn.Conn.(driver.QueryerContext); ok { + Capture(ctx, conn.attr.dbname, func(ctx context.Context) error { + rows, err = queryerCtx.QueryContext(ctx, query, args) + if err == driver.ErrSkip { + conn.attr.populate(ctx, query+msgErrSkip) + return nil + } + conn.attr.populate(ctx, query) + return err + }) + } else { + select { + default: + case <-ctx.Done(): + return nil, ctx.Err() + } + dargs, err0 := namedValuesToValues(args) + if err0 != nil { + return nil, err0 + } + err = Capture(ctx, conn.attr.dbname, func(ctx context.Context) error { + rows, err = queryer.Query(query, dargs) + if err == driver.ErrSkip { + conn.attr.populate(ctx, query+msgErrSkip) + return nil + } + conn.attr.populate(ctx, query) + return err + }) + } + return rows, err +} + +func (conn *driverConn) Close() error { + return conn.Conn.Close() +} + +// copied from https://github.com/golang/go/blob/e6ebbe0d20fe877b111cf4ccf8349cba129d6d3a/src/database/sql/convert.go#L93-L99 +// defaultCheckNamedValue wraps the default ColumnConverter to have the same +// function signature as the CheckNamedValue in the driver.NamedValueChecker +// interface. +func defaultCheckNamedValue(nv *driver.NamedValue) (err error) { + nv.Value, err = driver.DefaultParameterConverter.ConvertValue(nv.Value) + return err +} + +// CheckNamedValue for implementing driver.NamedValueChecker +// This function may be unnecessary because `proxy.Stmt` already implements `NamedValueChecker`, +// but it is implemented just in case. +func (conn *driverConn) CheckNamedValue(nv *driver.NamedValue) (err error) { + if nvc, ok := conn.Conn.(namedValueChecker); ok { + return nvc.CheckNamedValue(nv) + } + // fallback to default + return defaultCheckNamedValue(nv) +} + +type dbAttribute struct { + connectionString string + url string + databaseType string + databaseVersion string + driverVersion string + user string + dbname string +} + +func newDBAttribute(ctx context.Context, driverName string, d driver.Driver, conn driver.Conn, dsn string, filtered bool) (*dbAttribute, error) { + var attr dbAttribute + + // Detect if DSN is a URL or not, set appropriate attribute + urlDsn := dsn + if !strings.Contains(dsn, "//") { + urlDsn = "//" + urlDsn + } + // Here we're trying to detect things like `host:port/database` as a URL, which is pretty hard + // So we just assume that if it's got a scheme, a user, or a query that it's probably a URL + if u, err := url.Parse(urlDsn); err == nil && (u.Scheme != "" || u.User != nil || u.RawQuery != "" || strings.Contains(u.Path, "@")) { + // Check that this isn't in the form of user/pass@host:port/db, as that will shove the host into the path + if strings.Contains(u.Path, "@") { + u, err = url.Parse(fmt.Sprintf("%s//%s%%2F%s", u.Scheme, u.Host, u.Path[1:])) + if err != nil { + return nil, err + } + } + + // Strip password from user:password pair in address + if u.User != nil { + uname := u.User.Username() + + // Some drivers use "user/pass@host:port" instead of "user:pass@host:port" + // So we must manually attempt to chop off a potential password. + // But we can skip this if we already found the password. + if _, ok := u.User.Password(); !ok { + uname = strings.Split(uname, "/")[0] + } + + u.User = url.User(uname) + } + + // Strip password from query parameters + q := u.Query() + q.Del("password") + u.RawQuery = q.Encode() + + attr.url = u.String() + if !strings.Contains(dsn, "//") { + attr.url = attr.url[2:] + } + } else { + // We don't *think* it's a URL, so now we have to try our best to strip passwords from + // some unknown DSL. We attempt to detect whether it's space-delimited or semicolon-delimited + // then remove any keys with the name "password" or "pwd". This won't catch everything, but + // from surveying the current (Jan 2017) landscape of drivers it should catch most. + if filtered { + attr.connectionString = dsn + } else { + attr.connectionString = stripPasswords(dsn) + } + } + + // Detect database type and use that to populate attributes + var detectors []func(ctx context.Context, conn driver.Conn, attr *dbAttribute) error + switch driverName { + case "postgres": + detectors = append(detectors, postgresDetector) + case "mysql": + detectors = append(detectors, mysqlDetector) + default: + detectors = append(detectors, postgresDetector, mysqlDetector, mssqlDetector, oracleDetector) + } + for _, detector := range detectors { + if detector(ctx, conn, &attr) == nil { + break + } + attr.databaseType = "Unknown" + attr.databaseVersion = "Unknown" + attr.user = "Unknown" + attr.dbname = "Unknown" + } + + // There's no standard to get SQL driver version information + // So we invent an interface by which drivers can provide us this data + type versionedDriver interface { + Version() string + } + + if vd, ok := d.(versionedDriver); ok { + attr.driverVersion = vd.Version() + } else { + t := reflect.TypeOf(d) + for t.Kind() == reflect.Ptr { + t = t.Elem() + } + attr.driverVersion = t.PkgPath() + } + + if attrHook != nil { + attrHook(&attr) + } + return &attr, nil +} + +func postgresDetector(ctx context.Context, conn driver.Conn, attr *dbAttribute) error { + attr.databaseType = "Postgres" + return queryRow( + ctx, conn, + "SELECT version(), current_user, current_database()", + &attr.databaseVersion, &attr.user, &attr.dbname, + ) +} + +func mysqlDetector(ctx context.Context, conn driver.Conn, attr *dbAttribute) error { + attr.databaseType = "MySQL" + return queryRow( + ctx, conn, + "SELECT version(), current_user(), database()", + &attr.databaseVersion, &attr.user, &attr.dbname, + ) +} + +func mssqlDetector(ctx context.Context, conn driver.Conn, attr *dbAttribute) error { + attr.databaseType = "MS SQL" + return queryRow( + ctx, conn, + "SELECT @@version, current_user, db_name()", + &attr.databaseVersion, &attr.user, &attr.dbname, + ) +} + +func oracleDetector(ctx context.Context, conn driver.Conn, attr *dbAttribute) error { + attr.databaseType = "Oracle" + return queryRow( + ctx, conn, + "SELECT version FROM v$instance UNION SELECT user, ora_database_name FROM dual", + &attr.databaseVersion, &attr.user, &attr.dbname, + ) +} + +// minimum implementation of (*sql.DB).QueryRow +func queryRow(ctx context.Context, conn driver.Conn, query string, dest ...*string) error { + var err error + + // prepare + var stmt driver.Stmt + if connCtx, ok := conn.(driver.ConnPrepareContext); ok { + stmt, err = connCtx.PrepareContext(ctx, query) + } else { + stmt, err = conn.Prepare(query) + if err == nil { + select { + default: + case <-ctx.Done(): + stmt.Close() + return ctx.Err() + } + } + } + if err != nil { + return err + } + defer stmt.Close() + + // execute query + var rows driver.Rows + if queryCtx, ok := stmt.(driver.StmtQueryContext); ok { + rows, err = queryCtx.QueryContext(ctx, []driver.NamedValue{}) + } else { + select { + default: + case <-ctx.Done(): + return ctx.Err() + } + rows, err = stmt.Query([]driver.Value{}) + } + if err != nil { + return err + } + defer rows.Close() + + // scan + if len(dest) != len(rows.Columns()) { + return fmt.Errorf("xray: expected %d destination arguments in Scan, not %d", len(rows.Columns()), len(dest)) + } + cols := make([]driver.Value, len(rows.Columns())) + if err := rows.Next(cols); err != nil { + return err + } + for i, src := range cols { + d := dest[i] + switch s := src.(type) { + case string: + *d = s + case []byte: + *d = string(s) + case time.Time: + *d = s.Format(time.RFC3339Nano) + case int64: + *d = strconv.FormatInt(s, 10) + case float64: + *d = strconv.FormatFloat(s, 'g', -1, 64) + case bool: + *d = strconv.FormatBool(s) + default: + return fmt.Errorf("sql: Scan error on column index %d, name %q: type missmatch", i, rows.Columns()[i]) + } + } + + return nil +} + +func (attr *dbAttribute) populate(ctx context.Context, query string) { + seg := GetSegment(ctx) + + if seg == nil { + processNilSegment(ctx) + return + } + + seg.Lock() + seg.Namespace = "remote" + seg.GetSQL().ConnectionString = attr.connectionString + seg.GetSQL().URL = attr.url + seg.GetSQL().DatabaseType = attr.databaseType + seg.GetSQL().DatabaseVersion = attr.databaseVersion + seg.GetSQL().DriverVersion = attr.driverVersion + seg.GetSQL().User = attr.user + seg.GetSQL().SanitizedQuery = query + seg.Unlock() +} + +type driverTx struct { + driver.Tx +} + +func (tx *driverTx) Commit() error { + return tx.Tx.Commit() +} + +func (tx *driverTx) Rollback() error { + return tx.Tx.Rollback() +} + +type driverStmt struct { + driver.Stmt + conn *driverConn + attr *dbAttribute + query string +} + +func (stmt *driverStmt) Close() error { + return stmt.Stmt.Close() +} + +func (stmt *driverStmt) NumInput() int { + return stmt.Stmt.NumInput() +} + +func (stmt *driverStmt) Exec(args []driver.Value) (driver.Result, error) { + panic("not supported") +} + +func (stmt *driverStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { + var result driver.Result + var err error + if execerContext, ok := stmt.Stmt.(driver.StmtExecContext); ok { + err = Capture(ctx, stmt.attr.dbname, func(ctx context.Context) error { + stmt.populate(ctx) + var err error + result, err = execerContext.ExecContext(ctx, args) + return err + }) + } else { + select { + default: + case <-ctx.Done(): + return nil, ctx.Err() + } + dargs, err0 := namedValuesToValues(args) + if err0 != nil { + return nil, err0 + } + err = Capture(ctx, stmt.attr.dbname, func(ctx context.Context) error { + stmt.populate(ctx) + var err error + result, err = stmt.Stmt.Exec(dargs) + return err + }) + } + if err != nil { + return nil, err + } + return result, nil +} + +func (stmt *driverStmt) Query(args []driver.Value) (driver.Rows, error) { + panic("not supported") +} + +func (stmt *driverStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { + var result driver.Rows + var err error + if queryCtx, ok := stmt.Stmt.(driver.StmtQueryContext); ok { + err = Capture(ctx, stmt.attr.dbname, func(ctx context.Context) error { + stmt.populate(ctx) + var err error + result, err = queryCtx.QueryContext(ctx, args) + return err + }) + } else { + select { + default: + case <-ctx.Done(): + return nil, ctx.Err() + } + dargs, err0 := namedValuesToValues(args) + if err0 != nil { + return nil, err0 + } + err = Capture(ctx, stmt.attr.dbname, func(ctx context.Context) error { + stmt.populate(ctx) + var err error + result, err = stmt.Stmt.Query(dargs) + return err + }) + } + if err != nil { + return nil, err + } + return result, nil +} + +func (stmt *driverStmt) ColumnConverter(idx int) driver.ValueConverter { + if conv, ok := stmt.Stmt.(driver.ColumnConverter); ok { + return conv.ColumnConverter(idx) + } + return driver.DefaultParameterConverter +} + +func (stmt *driverStmt) populate(ctx context.Context) { + stmt.attr.populate(ctx, stmt.query) + + seg := GetSegment(ctx) + + if seg == nil { + processNilSegment(ctx) + return + } + + seg.Lock() + seg.GetSQL().Preparation = "statement" + seg.Unlock() +} + +// CheckNamedValue for implementing NamedValueChecker +func (stmt *driverStmt) CheckNamedValue(nv *driver.NamedValue) (err error) { + if nvc, ok := stmt.Stmt.(namedValueChecker); ok { + return nvc.CheckNamedValue(nv) + } + // When converting data in sql/driver/convert.go, it is checked first whether the `stmt` + // implements `NamedValueChecker`, and then checks if `conn` implements NamedValueChecker. + // In the case of "go-sql-proxy", the `proxy.Stmt` "implements" `CheckNamedValue` here, + // so we also check both `stmt` and `conn` inside here. + if nvc, ok := stmt.conn.Conn.(namedValueChecker); ok { + return nvc.CheckNamedValue(nv) + } + // fallback to default + return defaultCheckNamedValue(nv) +} + +func namedValuesToValues(args []driver.NamedValue) ([]driver.Value, error) { + var err error + ret := make([]driver.Value, len(args)) + for _, arg := range args { + if len(arg.Name) > 0 { + err = errors.New("xray: driver does not support the use of Named Parameters") + } + ret[arg.Ordinal-1] = arg.Value + } + return ret, err +} + +func stripPasswords(dsn string) string { + var ( + tok bytes.Buffer + res bytes.Buffer + isPassword bool + inBraces bool + delimiter byte = ' ' + ) + flush := func() { + if inBraces { + return + } + if !isPassword { + res.Write(tok.Bytes()) + } + tok.Reset() + isPassword = false + } + if strings.Count(dsn, ";") > strings.Count(dsn, " ") { + delimiter = ';' + } + + buf := strings.NewReader(dsn) + for c, err := buf.ReadByte(); err == nil; c, err = buf.ReadByte() { + tok.WriteByte(c) + switch c { + case ':', delimiter: + flush() + case '=': + tokStr := strings.ToLower(tok.String()) + isPassword = `password=` == tokStr || `pwd=` == tokStr + if b, err := buf.ReadByte(); err != nil { + break + } else { + inBraces = b == '{' + } + if err := buf.UnreadByte(); err != nil { + panic(err) + } + case '}': + b, err := buf.ReadByte() + if err != nil { + break + } + if b == '}' { + tok.WriteByte(b) + } else { + inBraces = false + if err := buf.UnreadByte(); err != nil { + panic(err) + } + } + case '@': + if strings.Contains(res.String(), ":") { + resLen := res.Len() + if resLen > 0 && res.Bytes()[resLen-1] == ':' { + res.Truncate(resLen - 1) + } + isPassword = true + flush() + res.WriteByte(c) + } + } + } + inBraces = false + flush() + return res.String() +} + +func processNilSegment(ctx context.Context) { + cfg := GetRecorder(ctx) + failedMessage := "failed to get segment from context since segment is nil" + if cfg != nil && cfg.ContextMissingStrategy != nil { + cfg.ContextMissingStrategy.ContextMissing(failedMessage) + } else { + globalCfg.ContextMissingStrategy().ContextMissing(failedMessage) + } +} diff --git a/xray/sql_go110.go b/xray/sql_go110.go new file mode 100644 index 00000000..a034b80d --- /dev/null +++ b/xray/sql_go110.go @@ -0,0 +1,152 @@ +// Copyright 2017-2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance with the License. A copy of the License is located at +// +// http://aws.amazon.com/apache2.0/ +// +// or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. + +// +build go1.10 + +package xray + +import ( + "context" + "database/sql/driver" + "sync" +) + +// SQLConnector wraps the connector, and traces SQL executions. +// Unlike SQLContext, SQLConnector doesn't filter the password of the dsn. +// So, you have to filter the password before passing the dsn to SQLConnector. +func SQLConnector(dsn string, connector driver.Connector) driver.Connector { + d := &driverDriver{ + Driver: connector.Driver(), + baseName: "unknown", + } + return &driverConnector{ + Connector: connector, + driver: d, + name: dsn, + filtered: true, + // initialized attr lazy because we have no context here. + } +} + +func (conn *driverConn) ResetSession(ctx context.Context) error { + if sr, ok := conn.Conn.(driver.SessionResetter); ok { + return sr.ResetSession(ctx) + } + return nil +} + +type driverConnector struct { + driver.Connector + driver *driverDriver + filtered bool + name string + + mu sync.RWMutex + attr *dbAttribute +} + +func (c *driverConnector) Connect(ctx context.Context) (driver.Conn, error) { + var rawConn driver.Conn + attr, err := c.getAttr(ctx) + if err != nil { + return nil, err + } + err = Capture(ctx, attr.dbname, func(ctx context.Context) error { + attr.populate(ctx, "CONNECT") + var err error + rawConn, err = c.Connector.Connect(ctx) + return err + }) + if err != nil { + return nil, err + } + + conn := &driverConn{ + Conn: rawConn, + attr: attr, + } + return conn, nil +} + +func (c *driverConnector) getAttr(ctx context.Context) (*dbAttribute, error) { + c.mu.RLock() + attr := c.attr + c.mu.RUnlock() + if attr != nil { + return attr, nil + } + + c.mu.Lock() + defer c.mu.Unlock() + if c.attr != nil { + return c.attr, nil + } + conn, err := c.Connector.Connect(ctx) + if err != nil { + return nil, err + } + defer conn.Close() + + attr, err = newDBAttribute(ctx, c.driver.baseName, c.driver.Driver, conn, c.name, c.filtered) + if err != nil { + return nil, err + } + c.attr = attr + return attr, nil +} + +func (c *driverConnector) Driver() driver.Driver { + return c.driver +} + +type fallbackConnector struct { + driver driver.Driver + name string +} + +func (c *fallbackConnector) Connect(ctx context.Context) (driver.Conn, error) { + conn, err := c.driver.Open(c.name) + if err != nil { + return nil, err + } + select { + default: + case <-ctx.Done(): + conn.Close() + return nil, ctx.Err() + } + return conn, nil +} + +func (c *fallbackConnector) Driver() driver.Driver { + return c.driver +} + +func (d *driverDriver) OpenConnector(name string) (driver.Connector, error) { + var c driver.Connector + if dctx, ok := d.Driver.(driver.DriverContext); ok { + var err error + c, err = dctx.OpenConnector(name) + if err != nil { + return nil, err + } + } else { + c = &fallbackConnector{ + driver: d.Driver, + name: name, + } + } + c = &driverConnector{ + Connector: c, + driver: d, + filtered: false, + name: name, + // initialized attr lazy because we have no context here. + } + return c, nil +} diff --git a/xray/sql_go110_test.go b/xray/sql_go110_test.go new file mode 100644 index 00000000..a72ec597 --- /dev/null +++ b/xray/sql_go110_test.go @@ -0,0 +1,76 @@ +// Copyright 2017-2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance with the License. A copy of the License is located at +// +// http://aws.amazon.com/apache2.0/ +// +// or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. + +// +build go1.10 + +package xray + +import ( + "database/sql" + "database/sql/driver" + "encoding/json" + "testing" + + "github.com/DATA-DOG/go-sqlmock" + "github.com/stretchr/testify/assert" +) + +type versionedDriver struct { + driver.Driver + version string +} + +func (d *versionedDriver) Version() string { + return d.version +} + +func TestDriverVersion(t *testing.T) { + dsn := "test-versioned-driver" + db, mock, err := sqlmock.NewWithDSN(dsn) + if err != nil { + t.Fatal(err) + } + defer db.Close() + mockPostgreSQL(mock, nil) + + // implement versionedDriver + driver := &versionedDriver{ + Driver: db.Driver(), + version: "3.1415926535", + } + connector := &fallbackConnector{ + driver: driver, + name: dsn, + } + sqlConnector := SQLConnector("sanitized-dsn", connector) + db = sql.OpenDB(sqlConnector) + defer db.Close() + + ctx, td := NewTestDaemon() + defer td.Close() + + // Execute SQL + ctx, root := BeginSegment(ctx, "test") + if err := db.PingContext(ctx); err != nil { + t.Fatal(err) + } + root.Close(nil) + assert.NoError(t, mock.ExpectationsWereMet()) + + // assertion + seg, err := td.Recv() + if err != nil { + t.Fatal(err) + } + var subseg *Segment + if err := json.Unmarshal(seg.Subsegments[0], &subseg); err != nil { + t.Fatal(err) + } + assert.Equal(t, "sanitized-dsn", subseg.SQL.ConnectionString) + assert.Equal(t, "3.1415926535", subseg.SQL.DriverVersion) +} diff --git a/xray/sql_go111_test.go b/xray/sql_go111_test.go index ea81e011..e2b990a8 100644 --- a/xray/sql_go111_test.go +++ b/xray/sql_go111_test.go @@ -10,34 +10,55 @@ package xray -func (s *sqlTestSuite) TestMySQLPasswordConnectionString() { - s.mockDB("username:password@protocol(address:1234)/dbname?param=value") - s.mockMySQL(nil) - s.connect() - - s.Require().NoError(s.mock.ExpectationsWereMet()) - if s.db.connectionString != "" { - s.Equal("username@protocol(address:1234)/dbname?param=value", s.db.connectionString) - s.Equal("", s.db.url) - } - if s.db.url != "" { - s.Equal("username@protocol(address:1234)/dbname?param=value", s.db.url) - s.Equal("", s.db.connectionString) - } -} +import ( + "testing" -func (s *sqlTestSuite) TestMySQLPasswordlessConnectionString() { - s.mockDB("username@protocol(address:1234)/dbname?param=value") - s.mockMySQL(nil) - s.connect() + "github.com/DATA-DOG/go-sqlmock" + "github.com/stretchr/testify/assert" +) - s.Require().NoError(s.mock.ExpectationsWereMet()) - if s.db.connectionString != "" { - s.Equal("username@protocol(address:1234)/dbname?param=value", s.db.connectionString) - s.Equal("", s.db.url) +func TestMySQLPasswordConnectionString(t *testing.T) { + tc := []struct { + dsn string + url string + str string + }{ + { + dsn: "username:password@protocol(address:1234)/dbname?param=value", + url: "username@protocol(address:1234)/dbname?param=value", + str: "username@protocol(address:1234)/dbname?param=value", + }, + { + dsn: "username@protocol(address:1234)/dbname?param=value", + url: "username@protocol(address:1234)/dbname?param=value", + str: "username@protocol(address:1234)/dbname?param=value", + }, } - if s.db.url != "" { - s.Equal("username@protocol(address:1234)/dbname?param=value", s.db.url) - s.Equal("", s.db.connectionString) + + for _, tt := range tc { + tt := tt + t.Run(tt.dsn, func(t *testing.T) { + db, mock, err := sqlmock.NewWithDSN(tt.dsn) + if err != nil { + t.Fatal(err) + } + defer db.Close() + mockMySQL(mock, nil) + + subseg, err := capturePing(tt.dsn) + if err != nil { + t.Fatal(err) + } + assert.NoError(t, mock.ExpectationsWereMet()) + + assert.Equal(t, "remote", subseg.Namespace) + assert.Equal(t, "MySQL", subseg.SQL.DatabaseType) + if subseg.SQL.URL != "" { + assert.Equal(t, tt.url, subseg.SQL.URL) + } + if subseg.SQL.ConnectionString != "" { + assert.Equal(t, tt.str, subseg.SQL.ConnectionString) + } + }) } } diff --git a/xray/sql_go18.go b/xray/sql_go18.go deleted file mode 100644 index 0067207a..00000000 --- a/xray/sql_go18.go +++ /dev/null @@ -1,189 +0,0 @@ -// Copyright 2017-2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance with the License. A copy of the License is located at -// -// http://aws.amazon.com/apache2.0/ -// -// or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. - -// +build go1.8 - -package xray - -import ( - "context" - "database/sql" -) - -// Begin starts a transaction. -func (db *DB) Begin(ctx context.Context, opts *sql.TxOptions) (*Tx, error) { - tx, err := db.db.BeginTx(ctx, opts) - return &Tx{db, tx}, err -} - -// Prepare creates a prepared statement for later queries or executions. -func (db *DB) Prepare(ctx context.Context, query string) (*Stmt, error) { - stmt, err := db.db.PrepareContext(ctx, query) - return &Stmt{db, stmt, query}, err -} - -// Ping traces verifying a connection to the database is still alive, -// establishing a connection if necessary and adds corresponding information into subsegment. -func (db *DB) Ping(ctx context.Context) error { - return Capture(ctx, db.dbname, func(ctx context.Context) error { - db.populate(ctx, "PING") - return db.db.PingContext(ctx) - }) -} - -// Exec captures executing a query without returning any rows and -// adds corresponding information into subsegment. -func (db *DB) Exec(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { - var res sql.Result - - err := Capture(ctx, db.dbname, func(ctx context.Context) error { - db.populate(ctx, query) - - var err error - res, err = db.db.ExecContext(ctx, query, args...) - return err - }) - - return res, err -} - -// Query captures executing a query that returns rows and adds corresponding information into subsegment. -func (db *DB) Query(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) { - var res *sql.Rows - - err := Capture(ctx, db.dbname, func(ctx context.Context) error { - db.populate(ctx, query) - - var err error - res, err = db.db.QueryContext(ctx, query, args...) - return err - }) - - return res, err -} - -// QueryRow captures executing a query that is expected to return at most one row -// and adds corresponding information into subsegment. -func (db *DB) QueryRow(ctx context.Context, query string, args ...interface{}) *sql.Row { - var res *sql.Row - - Capture(ctx, db.dbname, func(ctx context.Context) error { - db.populate(ctx, query) - - res = db.db.QueryRowContext(ctx, query, args...) - return nil - }) - - return res -} - -// Prepare creates a prepared statement for later queries or executions. -func (tx *Tx) Prepare(ctx context.Context, query string) (*Stmt, error) { - stmt, err := tx.tx.PrepareContext(ctx, query) - return &Stmt{tx.db, stmt, query}, err -} - -// Exec captures executing a query that doesn't return rows and adds -// corresponding information into subsegment. -func (tx *Tx) Exec(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { - var res sql.Result - - err := Capture(ctx, tx.db.dbname, func(ctx context.Context) error { - tx.db.populate(ctx, query) - - var err error - res, err = tx.tx.ExecContext(ctx, query, args...) - return err - }) - - return res, err -} - -// Query captures executing a query that returns rows and adds corresponding information into subsegment. -func (tx *Tx) Query(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) { - var res *sql.Rows - - err := Capture(ctx, tx.db.dbname, func(ctx context.Context) error { - tx.db.populate(ctx, query) - - var err error - res, err = tx.tx.QueryContext(ctx, query, args...) - return err - }) - - return res, err -} - -// QueryRow captures executing a query that is expected to return at most one row and adds -// corresponding information into subsegment. -func (tx *Tx) QueryRow(ctx context.Context, query string, args ...interface{}) *sql.Row { - var res *sql.Row - - Capture(ctx, tx.db.dbname, func(ctx context.Context) error { - tx.db.populate(ctx, query) - - res = tx.tx.QueryRowContext(ctx, query, args...) - return nil - }) - - return res -} - -// Stmt returns a transaction-specific prepared statement from an existing statement. -func (tx *Tx) Stmt(ctx context.Context, stmt *Stmt) *Stmt { - return &Stmt{stmt.db, tx.tx.StmtContext(ctx, stmt.stmt), stmt.query} -} - -// Exec captures executing a prepared statement with the given arguments and -// returning a Result summarizing the effect of the statement and adds corresponding -// information into subsegment. -func (stmt *Stmt) Exec(ctx context.Context, args ...interface{}) (sql.Result, error) { - var res sql.Result - - err := Capture(ctx, stmt.db.dbname, func(ctx context.Context) error { - stmt.populate(ctx, stmt.query) - - var err error - res, err = stmt.stmt.ExecContext(ctx, args...) - return err - }) - - return res, err -} - -// Query captures executing a prepared query statement with the given arguments -// and returning the query results as a *Rows and adds corresponding information -// into subsegment. -func (stmt *Stmt) Query(ctx context.Context, args ...interface{}) (*sql.Rows, error) { - var res *sql.Rows - - err := Capture(ctx, stmt.db.dbname, func(ctx context.Context) error { - stmt.populate(ctx, stmt.query) - - var err error - res, err = stmt.stmt.QueryContext(ctx, args...) - return err - }) - - return res, err -} - -// QueryRow captures executing a prepared query statement with the given arguments and -// adds corresponding information into subsegment. -func (stmt *Stmt) QueryRow(ctx context.Context, args ...interface{}) *sql.Row { - var res *sql.Row - - Capture(ctx, stmt.db.dbname, func(ctx context.Context) error { - stmt.populate(ctx, stmt.query) - - res = stmt.stmt.QueryRowContext(ctx, args...) - return nil - }) - - return res -} diff --git a/xray/sql_prego111_test.go b/xray/sql_prego111_test.go index 13d20e55..3d934550 100644 --- a/xray/sql_prego111_test.go +++ b/xray/sql_prego111_test.go @@ -10,22 +10,49 @@ package xray -func (s *sqlTestSuite) TestMySQLPasswordConnectionStringPreGo111() { - s.mockDB("username:password@protocol(address:1234)/dbname?param=value") - s.mockMySQL(nil) - s.connect() - - s.Require().NoError(s.mock.ExpectationsWereMet()) - s.Equal("", s.db.connectionString) - s.Equal("username@protocol(address:1234)/dbname?param=value", s.db.url) -} +import ( + "testing" + + "github.com/DATA-DOG/go-sqlmock" + "github.com/stretchr/testify/assert" +) + +func TestMySQLPasswordConnectionString(t *testing.T) { + tc := []struct { + dsn string + url string + str string + }{ + { + dsn: "username:password@protocol(address:1234)/dbname?param=value", + url: "username@protocol(address:1234)/dbname?param=value", + }, + { + dsn: "username@protocol(address:1234)/dbname?param=value", + url: "username@protocol(address:1234)/dbname?param=value", + }, + } + + for _, tt := range tc { + tt := tt + t.Run(tt.dsn, func(t *testing.T) { + db, mock, err := sqlmock.NewWithDSN(tt.dsn) + if err != nil { + t.Fatal(err) + } + defer db.Close() + mockMySQL(mock, nil) -func (s *sqlTestSuite) TestMySQLPasswordlessConnectionStringPreGo111() { - s.mockDB("username@protocol(address:1234)/dbname?param=value") - s.mockMySQL(nil) - s.connect() + subseg, err := capturePing(tt.dsn) + if err != nil { + t.Fatal(err) + } + assert.NoError(t, mock.ExpectationsWereMet()) - s.Require().NoError(s.mock.ExpectationsWereMet()) - s.Equal("", s.db.connectionString) - s.Equal("username@protocol(address:1234)/dbname?param=value", s.db.url) + assert.Equal(t, "remote", subseg.Namespace) + assert.Equal(t, "MySQL", subseg.SQL.DatabaseType) + assert.Equal(t, tt.url, subseg.SQL.URL) + assert.Equal(t, tt.str, subseg.SQL.ConnectionString) + }) + } } diff --git a/xray/sql_test.go b/xray/sql_test.go deleted file mode 100644 index 4680900e..00000000 --- a/xray/sql_test.go +++ /dev/null @@ -1,261 +0,0 @@ -// Copyright 2017-2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance with the License. A copy of the License is located at -// -// http://aws.amazon.com/apache2.0/ -// -// or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. - -package xray - -import ( - "crypto/rand" - "errors" - "testing" - - "github.com/DATA-DOG/go-sqlmock" - "github.com/stretchr/testify/suite" -) - -func TestSQL(t *testing.T) { - suite.Run(t, &sqlTestSuite{ - dbs: map[string]sqlmock.Sqlmock{}, - }) -} - -type sqlTestSuite struct { - suite.Suite - - dbs map[string]sqlmock.Sqlmock - - dsn string - db *DB - mock sqlmock.Sqlmock -} - -func (s *sqlTestSuite) mockDB(dsn string) { - if dsn == "" { - b := make([]byte, 32) - rand.Read(b) - dsn = string(b) - } - - var err error - s.dsn = dsn - if mock, ok := s.dbs[dsn]; ok { - s.mock = mock - } else { - _, s.mock, err = sqlmock.NewWithDSN(dsn) - s.Require().NoError(err) - s.dbs[dsn] = s.mock - } -} - -func (s *sqlTestSuite) connect() { - var err error - s.db, err = SQL("sqlmock", s.dsn) - s.Require().NoError(err) -} - -func (s *sqlTestSuite) mockPSQL(err error) { - row := sqlmock.NewRows([]string{"version()", "current_user", "current_database()"}). - AddRow("test version", "test user", "test database"). - RowError(0, err) - s.mock.ExpectQuery(`SELECT version\(\), current_user, current_database\(\)`).WillReturnRows(row) -} -func (s *sqlTestSuite) mockMySQL(err error) { - row := sqlmock.NewRows([]string{"version()", "current_user()", "database()"}). - AddRow("test version", "test user", "test database"). - RowError(0, err) - s.mock.ExpectQuery(`SELECT version\(\), current_user\(\), database\(\)`).WillReturnRows(row) -} -func (s *sqlTestSuite) mockMSSQL(err error) { - row := sqlmock.NewRows([]string{"@@version", "current_user", "db_name()"}). - AddRow("test version", "test user", "test database"). - RowError(0, err) - s.mock.ExpectQuery(`SELECT @@version, current_user, db_name\(\)`).WillReturnRows(row) -} -func (s *sqlTestSuite) mockOracle(err error) { - row := sqlmock.NewRows([]string{"version", "user", "ora_database_name"}). - AddRow("test version", "test user", "test database"). - RowError(0, err) - s.mock.ExpectQuery(`SELECT version FROM v\$instance UNION SELECT user, ora_database_name FROM dual`).WillReturnRows(row) -} - -func (s *sqlTestSuite) TestPasswordlessURL() { - s.mockDB("postgres://user@host:1234/database") - s.mockPSQL(nil) - s.connect() - - s.Require().NoError(s.mock.ExpectationsWereMet()) - s.Equal("", s.db.connectionString) - s.Equal("postgres://user@host:1234/database", s.db.url) -} - -func (s *sqlTestSuite) TestPasswordURL() { - s.mockDB("postgres://user:password@host:1234/database") - s.mockPSQL(nil) - s.connect() - - s.Require().NoError(s.mock.ExpectationsWereMet()) - s.Equal("", s.db.connectionString) - s.Equal("postgres://user@host:1234/database", s.db.url) -} - -func (s *sqlTestSuite) TestPasswordURLQuery() { - s.mockDB("postgres://host:1234/database?password=password") - s.mockPSQL(nil) - s.connect() - - s.Require().NoError(s.mock.ExpectationsWereMet()) - s.Equal("", s.db.connectionString) - s.Equal("postgres://host:1234/database", s.db.url) -} - -func (s *sqlTestSuite) TestPasswordURLSchemaless() { - s.mockDB("user:password@host:1234/database") - s.mockPSQL(nil) - s.connect() - - s.Require().NoError(s.mock.ExpectationsWereMet()) - s.Equal("", s.db.connectionString) - s.Equal("user@host:1234/database", s.db.url) -} - -func (s *sqlTestSuite) TestPasswordURLSchemalessUserlessQuery() { - s.mockDB("host:1234/database?password=password") - s.mockPSQL(nil) - s.connect() - - s.Require().NoError(s.mock.ExpectationsWereMet()) - s.Equal("", s.db.connectionString) - s.Equal("host:1234/database", s.db.url) -} - -func (s *sqlTestSuite) TestWeirdPasswordURL() { - s.mockDB("user%2Fpassword@host:1234/database") - s.mockPSQL(nil) - s.connect() - - s.Require().NoError(s.mock.ExpectationsWereMet()) - s.Equal("", s.db.connectionString) - s.Equal("user@host:1234/database", s.db.url) -} - -func (s *sqlTestSuite) TestWeirderPasswordURL() { - s.mockDB("user/password@host:1234/database") - s.mockPSQL(nil) - s.connect() - - s.Require().NoError(s.mock.ExpectationsWereMet()) - s.Equal("", s.db.connectionString) - s.Equal("user@host:1234/database", s.db.url) -} - -func (s *sqlTestSuite) TestPasswordlessConnectionString() { - s.mockDB("user=user database=database") - s.mockPSQL(nil) - s.connect() - - s.Require().NoError(s.mock.ExpectationsWereMet()) - s.Equal("user=user database=database", s.db.connectionString) - s.Equal("", s.db.url) -} - -func (s *sqlTestSuite) TestPasswordConnectionString() { - s.mockDB("user=user password=password database=database") - s.mockPSQL(nil) - s.connect() - - s.Require().NoError(s.mock.ExpectationsWereMet()) - s.Equal("user=user database=database", s.db.connectionString) - s.Equal("", s.db.url) -} - -func (s *sqlTestSuite) TestSemicolonPasswordConnectionString() { - s.mockDB("odbc:server=localhost;user id=sa;password={foo}};bar};otherthing=thing") - s.mockPSQL(nil) - s.connect() - - s.Require().NoError(s.mock.ExpectationsWereMet()) - s.Equal("odbc:server=localhost;user id=sa;otherthing=thing", s.db.connectionString) - s.Equal("", s.db.url) -} - -func (s *sqlTestSuite) TestPSQL() { - s.mockDB("") - s.mockPSQL(nil) - s.connect() - - s.Require().NoError(s.mock.ExpectationsWereMet()) - s.Equal("Postgres", s.db.databaseType) - s.Equal("test version", s.db.databaseVersion) - s.Equal("test user", s.db.user) - s.Equal("test database", s.db.dbname) -} - -func (s *sqlTestSuite) TestMySQL() { - s.mockDB("") - s.mockPSQL(errors.New("")) - s.mockMySQL(nil) - s.connect() - - s.Require().NoError(s.mock.ExpectationsWereMet()) - s.Equal("MySQL", s.db.databaseType) - s.Equal("test version", s.db.databaseVersion) - s.Equal("test user", s.db.user) - s.Equal("test database", s.db.dbname) -} - -func (s *sqlTestSuite) TestMSSQL() { - s.mockDB("") - s.mockPSQL(errors.New("")) - s.mockMySQL(errors.New("")) - s.mockMSSQL(nil) - s.connect() - - s.Require().NoError(s.mock.ExpectationsWereMet()) - s.Equal("MS SQL", s.db.databaseType) - s.Equal("test version", s.db.databaseVersion) - s.Equal("test user", s.db.user) - s.Equal("test database", s.db.dbname) -} - -func (s *sqlTestSuite) TestOracle() { - s.mockDB("") - s.mockPSQL(errors.New("")) - s.mockMySQL(errors.New("")) - s.mockMSSQL(errors.New("")) - s.mockOracle(nil) - s.connect() - - s.Require().NoError(s.mock.ExpectationsWereMet()) - s.Equal("Oracle", s.db.databaseType) - s.Equal("test version", s.db.databaseVersion) - s.Equal("test user", s.db.user) - s.Equal("test database", s.db.dbname) -} - -func (s *sqlTestSuite) TestUnknownDatabase() { - s.mockDB("") - s.mockPSQL(errors.New("")) - s.mockMySQL(errors.New("")) - s.mockMSSQL(errors.New("")) - s.mockOracle(errors.New("")) - s.connect() - - s.Require().NoError(s.mock.ExpectationsWereMet()) - s.Equal("Unknown", s.db.databaseType) - s.Equal("Unknown", s.db.databaseVersion) - s.Equal("Unknown", s.db.user) - s.Equal("Unknown", s.db.dbname) -} - -func (s *sqlTestSuite) TestDriverVersionPackage() { - s.mockDB("") - s.mockPSQL(nil) - s.connect() - - s.Require().NoError(s.mock.ExpectationsWereMet()) - //s.Equal("gopkg.in/DATA-DOG/go-sqlmock.v1", s.db.driverVersion) -} diff --git a/xray/sqlcontext_test.go b/xray/sqlcontext_test.go new file mode 100644 index 00000000..389719cd --- /dev/null +++ b/xray/sqlcontext_test.go @@ -0,0 +1,352 @@ +// Copyright 2017-2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance with the License. A copy of the License is located at +// +// http://aws.amazon.com/apache2.0/ +// +// or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. + +package xray + +import ( + "encoding/json" + "errors" + "testing" + + "github.com/DATA-DOG/go-sqlmock" + "github.com/stretchr/testify/assert" +) + +// utility functions for testing SQL + +func mockPostgreSQL(mock sqlmock.Sqlmock, err error) { + row := sqlmock.NewRows([]string{"version()", "current_user", "current_database()"}). + AddRow("test version", "test user", "test database"). + RowError(0, err) + mock.ExpectPrepare(`SELECT version\(\), current_user, current_database\(\)`).ExpectQuery().WillReturnRows(row) +} +func mockMySQL(mock sqlmock.Sqlmock, err error) { + row := sqlmock.NewRows([]string{"version()", "current_user()", "database()"}). + AddRow("test version", "test user", "test database"). + RowError(0, err) + mock.ExpectPrepare(`SELECT version\(\), current_user\(\), database\(\)`).ExpectQuery().WillReturnRows(row) +} +func mockMSSQL(mock sqlmock.Sqlmock, err error) { + row := sqlmock.NewRows([]string{"@@version", "current_user", "db_name()"}). + AddRow("test version", "test user", "test database"). + RowError(0, err) + mock.ExpectPrepare(`SELECT @@version, current_user, db_name\(\)`).ExpectQuery().WillReturnRows(row) +} +func mockOracle(mock sqlmock.Sqlmock, err error) { + row := sqlmock.NewRows([]string{"version", "user", "ora_database_name"}). + AddRow("test version", "test user", "test database"). + RowError(0, err) + mock.ExpectPrepare(`SELECT version FROM v\$instance UNION SELECT user, ora_database_name FROM dual`).ExpectQuery().WillReturnRows(row) +} + +func capturePing(dsn string) (*Segment, error) { + ctx, td := NewTestDaemon() + defer td.Close() + + db, err := SQLContext("sqlmock", dsn) + if err != nil { + return nil, err + } + defer db.Close() + + ctx, root := BeginSegment(ctx, "test") + if err := db.PingContext(ctx); err != nil { + return nil, err + } + root.Close(nil) + + seg, err := td.Recv() + if err != nil { + return nil, err + } + var subseg *Segment + if err := json.Unmarshal(seg.Subsegments[0], &subseg); err != nil { + return nil, err + } + + return subseg, nil +} + +func TestDSN(t *testing.T) { + tc := []struct { + dsn string + url string + str string + }{ + { + dsn: "postgres://user@host:5432/database", + url: "postgres://user@host:5432/database", + }, + { + dsn: "postgres://user:password@host:5432/database", + url: "postgres://user@host:5432/database", + }, + { + dsn: "postgres://host:5432/database?password=password", + url: "postgres://host:5432/database", + }, + { + dsn: "user:password@host:5432/database", + url: "user@host:5432/database", + }, + { + dsn: "host:5432/database?password=password", + url: "host:5432/database", + }, + { + dsn: "user%2Fpassword@host:5432/database", + url: "user@host:5432/database", + }, + { + dsn: "user/password@host:5432/database", + url: "user@host:5432/database", + }, + { + dsn: "user=user database=database", + str: "user=user database=database", + }, + { + dsn: "user=user password=password database=database", + str: "user=user database=database", + }, + { + dsn: "odbc:server=localhost;user id=sa;password={foo}};bar};otherthing=thing", + str: "odbc:server=localhost;user id=sa;otherthing=thing", + }, + } + + for _, tt := range tc { + tt := tt + t.Run(tt.dsn, func(t *testing.T) { + db, mock, err := sqlmock.NewWithDSN(tt.dsn) + if err != nil { + t.Fatal(err) + } + defer db.Close() + mockPostgreSQL(mock, nil) + + subseg, err := capturePing(tt.dsn) + if err != nil { + t.Fatal(err) + } + assert.NoError(t, mock.ExpectationsWereMet()) + + assert.Equal(t, "remote", subseg.Namespace) + assert.Equal(t, "Postgres", subseg.SQL.DatabaseType) + assert.Equal(t, tt.url, subseg.SQL.URL) + assert.Equal(t, tt.str, subseg.SQL.ConnectionString) + assert.Equal(t, "test version", subseg.SQL.DatabaseVersion) + assert.Equal(t, "test user", subseg.SQL.User) + assert.False(t, subseg.Throttle) + assert.False(t, subseg.Error) + assert.False(t, subseg.Fault) + }) + } +} + +func TestPostgreSQL(t *testing.T) { + dsn := "test-postgre" + db, mock, err := sqlmock.NewWithDSN(dsn) + if err != nil { + t.Fatal(err) + } + defer db.Close() + mockPostgreSQL(mock, nil) + + subseg, err := capturePing(dsn) + if err != nil { + t.Fatal(err) + } + assert.NoError(t, mock.ExpectationsWereMet()) + + assert.Equal(t, "remote", subseg.Namespace) + assert.Equal(t, "Postgres", subseg.SQL.DatabaseType) + assert.Equal(t, "", subseg.SQL.URL) + assert.Equal(t, dsn, subseg.SQL.ConnectionString) + assert.Equal(t, "test version", subseg.SQL.DatabaseVersion) + assert.Equal(t, "test user", subseg.SQL.User) + assert.False(t, subseg.Throttle) + assert.False(t, subseg.Error) + assert.False(t, subseg.Fault) +} + +func TestMySQL(t *testing.T) { + dsn := "test-mysql" + db, mock, err := sqlmock.NewWithDSN(dsn) + if err != nil { + t.Fatal(err) + } + defer db.Close() + mockPostgreSQL(mock, errors.New("syntax error")) + mockMySQL(mock, nil) + + subseg, err := capturePing(dsn) + if err != nil { + t.Fatal(err) + } + assert.NoError(t, mock.ExpectationsWereMet()) + + assert.Equal(t, "remote", subseg.Namespace) + assert.Equal(t, "MySQL", subseg.SQL.DatabaseType) + assert.Equal(t, "", subseg.SQL.URL) + assert.Equal(t, dsn, subseg.SQL.ConnectionString) + assert.Equal(t, "test version", subseg.SQL.DatabaseVersion) + assert.Equal(t, "test user", subseg.SQL.User) + assert.False(t, subseg.Throttle) + assert.False(t, subseg.Error) + assert.False(t, subseg.Fault) +} + +func TestMSSQL(t *testing.T) { + dsn := "test-mssql" + db, mock, err := sqlmock.NewWithDSN(dsn) + if err != nil { + t.Fatal(err) + } + defer db.Close() + mockPostgreSQL(mock, errors.New("syntax error")) + mockMySQL(mock, errors.New("syntax error")) + mockMSSQL(mock, nil) + + subseg, err := capturePing(dsn) + if err != nil { + t.Fatal(err) + } + assert.NoError(t, mock.ExpectationsWereMet()) + + assert.Equal(t, "remote", subseg.Namespace) + assert.Equal(t, "MS SQL", subseg.SQL.DatabaseType) + assert.Equal(t, "", subseg.SQL.URL) + assert.Equal(t, dsn, subseg.SQL.ConnectionString) + assert.Equal(t, "test version", subseg.SQL.DatabaseVersion) + assert.Equal(t, "test user", subseg.SQL.User) + assert.False(t, subseg.Throttle) + assert.False(t, subseg.Error) + assert.False(t, subseg.Fault) +} + +func TestOracle(t *testing.T) { + dsn := "test-oracle" + db, mock, err := sqlmock.NewWithDSN(dsn) + if err != nil { + t.Fatal(err) + } + defer db.Close() + mockPostgreSQL(mock, errors.New("syntax error")) + mockMySQL(mock, errors.New("syntax error")) + mockMSSQL(mock, errors.New("syntax error")) + mockOracle(mock, nil) + + subseg, err := capturePing(dsn) + if err != nil { + t.Fatal(err) + } + assert.NoError(t, mock.ExpectationsWereMet()) + + assert.Equal(t, "remote", subseg.Namespace) + assert.Equal(t, "Oracle", subseg.SQL.DatabaseType) + assert.Equal(t, "", subseg.SQL.URL) + assert.Equal(t, dsn, subseg.SQL.ConnectionString) + assert.Equal(t, "test version", subseg.SQL.DatabaseVersion) + assert.Equal(t, "test user", subseg.SQL.User) + assert.False(t, subseg.Throttle) + assert.False(t, subseg.Error) + assert.False(t, subseg.Fault) +} + +func TestUnknownDatabase(t *testing.T) { + dsn := "test-unknown" + db, mock, err := sqlmock.NewWithDSN(dsn) + if err != nil { + t.Fatal(err) + } + defer db.Close() + mockPostgreSQL(mock, errors.New("syntax error")) + mockMySQL(mock, errors.New("syntax error")) + mockMSSQL(mock, errors.New("syntax error")) + mockOracle(mock, errors.New("syntax error")) + + subseg, err := capturePing(dsn) + if err != nil { + t.Fatal(err) + } + assert.NoError(t, mock.ExpectationsWereMet()) + + assert.Equal(t, "remote", subseg.Namespace) + assert.Equal(t, "Unknown", subseg.SQL.DatabaseType) + assert.Equal(t, "", subseg.SQL.URL) + assert.Equal(t, dsn, subseg.SQL.ConnectionString) + assert.Equal(t, "Unknown", subseg.SQL.DatabaseVersion) + assert.Equal(t, "Unknown", subseg.SQL.User) + assert.False(t, subseg.Throttle) + assert.False(t, subseg.Error) + assert.False(t, subseg.Fault) +} + +func TestStripPasswords(t *testing.T) { + tc := []struct { + in string + want string + }{ + { + in: "user=user database=database", + want: "user=user database=database", + }, + { + in: "user=user password=password database=database", + want: "user=user database=database", + }, + { + in: "odbc:server=localhost;user id=sa;password={foo}};bar};otherthing=thing", + want: "odbc:server=localhost;user id=sa;otherthing=thing", + }, + + // see https://github.com/aws/aws-xray-sdk-go/issues/181 + { + in: "password=", + want: "", + }, + { + in: "pwd=", + want: "", + }, + + // test cases for https://github.com/go-sql-driver/mysql + { + in: "user:password@tcp(localhost:5555)/dbname?tls=skip-verify&autocommit=true", + want: "user@tcp(localhost:5555)/dbname?tls=skip-verify&autocommit=true", + }, + + { + in: "user@tcp(localhost:5555)/dbname?tls=skip-verify&autocommit=true", + want: "user@tcp(localhost:5555)/dbname?tls=skip-verify&autocommit=true", + }, + + { + in: "user:password@tcp([de:ad:be:ef::ca:fe]:80)/dbname?timeout=90s&collation=utf8mb4_unicode_ci", + want: "user@tcp([de:ad:be:ef::ca:fe]:80)/dbname?timeout=90s&collation=utf8mb4_unicode_ci", + }, + + { + in: "user@tcp([de:ad:be:ef::ca:fe]:80)/dbname?timeout=90s&collation=utf8mb4_unicode_ci", + want: "user@tcp([de:ad:be:ef::ca:fe]:80)/dbname?timeout=90s&collation=utf8mb4_unicode_ci", + }, + + { + in: "user:password@/", + want: "user@/", + }, + } + + for _, tt := range tc { + got := stripPasswords(tt.in) + if got != tt.want { + t.Errorf("%s: want %s, got %s", tt.in, tt.want, got) + } + } +}