From 55f10fa186515d58f0cd3625961f287b7a7d50cb Mon Sep 17 00:00:00 2001 From: Daniel Bedrood Date: Tue, 11 Jun 2024 15:54:03 +0200 Subject: [PATCH] feat: Replacing custom database ddtrace with dd-trace-go library --- go.mod | 2 +- go.sum | 18 ++++ pkg/database/gorm.go | 69 ++++++++++--- pkg/instrumentation/database.go | 148 +-------------------------- pkg/instrumentation/database_test.go | 146 ++++---------------------- 5 files changed, 98 insertions(+), 285 deletions(-) diff --git a/go.mod b/go.mod index 907163b..aeb2997 100644 --- a/go.mod +++ b/go.mod @@ -13,6 +13,7 @@ require ( github.com/aws/aws-sdk-go-v2/service/sts v1.28.6 github.com/getsentry/sentry-go v0.12.0 github.com/go-kit/kit v0.9.0 + github.com/go-sql-driver/mysql v1.7.0 github.com/google/uuid v1.6.0 github.com/grpc-ecosystem/go-grpc-middleware v1.0.0 github.com/magefile/mage v1.15.0 @@ -70,7 +71,6 @@ require ( github.com/ebitengine/purego v0.6.0-alpha.5 // indirect github.com/fsnotify/fsnotify v1.5.1 // indirect github.com/go-logfmt/logfmt v0.6.0 // indirect - github.com/go-sql-driver/mysql v1.7.0 // indirect github.com/go-stack/stack v1.8.1 // indirect github.com/golang/protobuf v1.5.3 // indirect github.com/google/pprof v0.0.0-20230817174616-7a8ec2ada47b // indirect diff --git a/go.sum b/go.sum index 093b999..5aa7835 100644 --- a/go.sum +++ b/go.sum @@ -109,6 +109,8 @@ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSs github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/denisenkom/go-mssqldb v0.11.0 h1:9rHa233rhdOyrz2GcP9NM+gi2psgJZ4GWDpL/7ND8HI= +github.com/denisenkom/go-mssqldb v0.11.0/go.mod h1:xbL0rPBG9cCiLr28tMa8zpbdarY27NDyej4t/EjAShU= github.com/dgraph-io/badger v1.6.0/go.mod h1:zwt7syl517jmP8s94KqSxTlM6IMsdhYy6psNgSztDR4= github.com/dgryski/go-farm v0.0.0-20190423205320-6a90982ecee2 h1:tdlZCpZ/P9DhczCTSixgIKmwPv6+wP5DGjqLYw5SUiA= github.com/dgryski/go-farm v0.0.0-20190423205320-6a90982ecee2/go.mod h1:SqUrOPUnsFjfmXRMNPybcSiG0BgUW2AuFH8PAnS2iTw= @@ -149,6 +151,10 @@ github.com/gobwas/ws v1.0.2/go.mod h1:szmBTxLgaFppYjEmNtny/v3w89xOydFnnZMcgRRu/E github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= +github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 h1:au07oEsX2xN0ktxqI+Sida1w446QrXBRJ0nee3SNZlA= +github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9/go.mod h1:8vg3r2VgvsThLBIFL93Qb5yWzgyZWhEmBwUJWevAkK0= +github.com/golang-sql/sqlexp v0.1.0 h1:ZCD6MBpcuOVfGVqsEmY5/4FtYiKz6tSyUv9LPEDei6A= +github.com/golang-sql/sqlexp v0.1.0/go.mod h1:J4ad9Vo8ZCWQ2GMrC4UCQy1JpCbwU9m3EOqtpKwwwHI= github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= @@ -193,6 +199,12 @@ github.com/iris-contrib/go.uuid v2.0.0+incompatible/go.mod h1:iz2lgM/1UnEf1kP0L/ github.com/iris-contrib/jade v1.1.3/go.mod h1:H/geBymxJhShH5kecoiOCSssPX7QWYH7UaeZTSWddIk= github.com/iris-contrib/pongo2 v0.0.1/go.mod h1:Ssh+00+3GAZqSQb30AvBRNxBx7rf0GqwkjqxNd0u65g= github.com/iris-contrib/schema v0.0.1/go.mod h1:urYA3uvUNG1TIIjOSCzHr9/LmbQo8LrOcOqfqxa4hXw= +github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= +github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= +github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk= +github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= +github.com/jackc/pgx/v5 v5.4.2 h1:u1gmGDwbdRUZiwisBm/Ky2M14uQyUP65bG8+20nnyrg= +github.com/jackc/pgx/v5 v5.4.2/go.mod h1:q6iHT8uDNXWiFNOlRqJzBTaSH3+2xCXkokxHZC5qWFY= github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= github.com/jinzhu/now v1.1.4/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= @@ -248,6 +260,8 @@ github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxU github.com/mattn/goveralls v0.0.2/go.mod h1:8d1ZMHsd7fW6IRPKQh46F2WRpyib5/X4FOpevwGNQEw= github.com/mediocregopher/radix/v3 v3.4.2/go.mod h1:8FL3F6UQRXHXIBSPUs5h0RybMF8i4n7wVopoX3x7Bv8= github.com/microcosm-cc/bluemonday v1.0.2/go.mod h1:iVP4YcDBq+n/5fb23BhYFvIMq/leAFZyRl6bYmGDlGc= +github.com/microsoft/go-mssqldb v0.21.0 h1:p2rpHIL7TlSv1QrbXJUAcbyRKnIT0C9rRkH2E4OjLn8= +github.com/microsoft/go-mssqldb v0.21.0/go.mod h1:+4wZTUnz/SV6nffv+RRRB/ss8jPng5Sho2SmM1l2ts4= github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0= github.com/mitchellh/mapstructure v1.1.2/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y= github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY= @@ -508,8 +522,12 @@ gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gorm.io/driver/mysql v1.4.6 h1:5zS3vIKcyb46byXZNcYxaT9EWNIhXzu0gPuvvVrwZ8s= gorm.io/driver/mysql v1.4.6/go.mod h1:SxzItlnT1cb6e1e4ZRpgJN2VYtcqJgqnHxWr4wsP8oc= +gorm.io/driver/postgres v1.4.6 h1:1FPESNXqIKG5JmraaH2bfCVlMQ7paLoCreFxDtqzwdc= +gorm.io/driver/postgres v1.4.6/go.mod h1:UJChCNLFKeBqQRE+HrkFUbKbq9idPXmTOk2u4Wok8S4= gorm.io/driver/sqlite v1.4.4 h1:gIufGoR0dQzjkyqDyYSCvsYR6fba1Gw5YKDqKeChxFc= gorm.io/driver/sqlite v1.4.4/go.mod h1:0Aq3iPO+v9ZKbcdiz8gLWRw5VOPcBOPUQJFLq5e2ecI= +gorm.io/driver/sqlserver v1.4.2 h1:nMtEeKqv2R/vv9FoHUFWfXfP6SskAgRar0TPlZV1stk= +gorm.io/driver/sqlserver v1.4.2/go.mod h1:XHwBuB4Tlh7DqO0x7Ema8dmyWsQW7wi38VQOAFkrbXY= gorm.io/gorm v1.23.8/go.mod h1:l2lP/RyAtc1ynaTjFksBde/O8v9oOGIApu2/xRitmZk= gorm.io/gorm v1.24.0/go.mod h1:DVrVomtaYTbqs7gB/x2uVvqnXzv0nqjB396B8cG4dBA= gorm.io/gorm v1.25.3 h1:zi4rHZj1anhZS2EuEODMhDisGy+Daq9jtPrNGgbQYD8= diff --git a/pkg/database/gorm.go b/pkg/database/gorm.go index 8ff05a6..f3775c4 100644 --- a/pkg/database/gorm.go +++ b/pkg/database/gorm.go @@ -1,18 +1,60 @@ package database import ( + "fmt" "strconv" "time" "github.com/DATA-DOG/go-txdb" - "gorm.io/driver/mysql" + "github.com/go-sql-driver/mysql" + sqltrace "gopkg.in/DataDog/dd-trace-go.v1/contrib/database/sql" + gormtrace "gopkg.in/DataDog/dd-trace-go.v1/contrib/gorm.io/gorm.v1" + mysqlgorm "gorm.io/driver/mysql" "gorm.io/gorm" ) const testEnv = "test" -// NewConnection returns a new Gorm database connection. +// NewInstrumentedConnection returns a new instrumented Gorm database connection. +func NewInstrumentedConnection(config *Config, environment, appName string) (*gorm.DB, error) { + return newConnection(config, environment, appName, true) +} + +// NewConnection returns a new uninstrumented Gorm database connection. func NewConnection(config *Config, environment string) (*gorm.DB, error) { + return newConnection(config, environment, "", false) +} + +func newConnection( + config *Config, environment, appName string, instrumented bool, +) (*gorm.DB, error) { + dialector := getDialectorFromConfig(config, environment) + + var db *gorm.DB + var err error + if instrumented { + serviceName := fmt.Sprintf("%s-mysql", appName) + + // Register augments the provided driver with tracing, enabling it to be loaded by gormtrace.Open. + sqltrace.Register("mysql", &mysql.MySQLDriver{}, sqltrace.WithServiceName(serviceName)) + + if db, err = gormtrace.Open(dialector, nil, gormtrace.WithServiceName(serviceName)); err != nil { + return nil, err + } + } else { + if db, err = gorm.Open(dialector); err != nil { + return nil, err + } + } + + if err := databasePoolSettings(db, config); err != nil { + return nil, err + } + + return db, nil +} + +func getDialectorFromConfig(config *Config, environment string) gorm.Dialector { connectionDetails := NewConnectionDetails(config) connectionString := connectionDetails.String() @@ -29,25 +71,22 @@ func NewConnection(config *Config, environment string) (*gorm.DB, error) { connectionString = testDriverName } - dialector := mysql.New(mysql.Config{ + return mysqlgorm.New(mysqlgorm.Config{ DSN: connectionString, DriverName: driverName, }) +} - db, err := gorm.Open(dialector) - if err != nil { - return nil, err - } - - sqlDB, err := db.DB() +func databasePoolSettings(gormDB *gorm.DB, config *Config) error { + db, err := gormDB.DB() if err != nil { - return nil, err + return err } - sqlDB.SetMaxIdleConns(config.Pool) - sqlDB.SetMaxOpenConns(config.MaxOpenConnections) - sqlDB.SetConnMaxIdleTime(config.ConnectionMaxIdleTime) - sqlDB.SetConnMaxLifetime(config.ConnectionMaxLifetime) + db.SetMaxIdleConns(config.Pool) + db.SetMaxOpenConns(config.MaxOpenConnections) + db.SetConnMaxIdleTime(config.ConnectionMaxIdleTime) + db.SetConnMaxLifetime(config.ConnectionMaxLifetime) - return db, nil + return nil } diff --git a/pkg/instrumentation/database.go b/pkg/instrumentation/database.go index 931f217..0c21f23 100644 --- a/pkg/instrumentation/database.go +++ b/pkg/instrumentation/database.go @@ -2,159 +2,21 @@ package instrumentation import ( "context" - "fmt" - "strings" - "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/ext" - ddtrace "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer" + "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer" "gorm.io/gorm" ) -type spanContextKey string - -var ( - // ParentSpanGormKey is the name of the parent span key - parentSpanGormKey = spanContextKey("trancingParentSpan") - // SpanGormKey is the name of the span key - spanGormKey = spanContextKey("tracingSpan") -) - -// TraceDatabase sets span to gorm settings, returns cloned DB +// TraceDatabase fetches the span from the context and injects it to a new +// database session's statement context. func TraceDatabase(ctx context.Context, db *gorm.DB) *gorm.DB { if ctx == nil { return db } - parentSpan, _ := ddtrace.SpanFromContext(ctx) + parentSpan, _ := tracer.SpanFromContext(ctx) return db.Session(&gorm.Session{ - Context: context.WithValue(db.Statement.Context, parentSpanGormKey, parentSpan), + Context: tracer.ContextWithSpan(db.Statement.Context, parentSpan), }) } - -// InstrumentDatabase adds callbacks for tracing, call TraceDatabase to make it work -func InstrumentDatabase(db *gorm.DB, appName string) { - callbacks := newCallbacks(appName) - - registerCallbacks(db, "create", callbacks) - registerCallbacks(db, "query", callbacks) - registerCallbacks(db, "update", callbacks) - registerCallbacks(db, "delete", callbacks) - registerCallbacks(db, "row", callbacks) -} - -type callbacks struct { - serviceName string -} - -func newCallbacks(appName string) *callbacks { - return &callbacks{ - serviceName: fmt.Sprintf("%s-%s", appName, "mysql"), - } -} - -func (c *callbacks) beforeCreate(db *gorm.DB) { c.before(db, "INSERT", c.serviceName) } -func (c *callbacks) afterCreate(db *gorm.DB) { c.after(db) } -func (c *callbacks) beforeQuery(db *gorm.DB) { c.before(db, "SELECT", c.serviceName) } -func (c *callbacks) afterQuery(db *gorm.DB) { c.after(db) } -func (c *callbacks) beforeUpdate(db *gorm.DB) { c.before(db, "UPDATE", c.serviceName) } -func (c *callbacks) afterUpdate(db *gorm.DB) { c.after(db) } -func (c *callbacks) beforeDelete(db *gorm.DB) { c.before(db, "DELETE", c.serviceName) } -func (c *callbacks) afterDelete(db *gorm.DB) { c.after(db) } -func (c *callbacks) beforeRow(db *gorm.DB) { c.before(db, "", c.serviceName) } -func (c *callbacks) afterRow(db *gorm.DB) { c.after(db) } -func (c *callbacks) before(db *gorm.DB, operationName string, serviceName string) { - if db.Statement == nil || db.Statement.Context == nil { - return - } - - parentSpan, ok := db.Statement.Context.Value(parentSpanGormKey).(ddtrace.Span) - if !ok { - return - } - - spanOpts := []ddtrace.StartSpanOption{ - ddtrace.ChildOf(parentSpan.Context()), - ddtrace.SpanType(ext.SpanTypeSQL), - ddtrace.ServiceName(serviceName), - } - if operationName == "" { - operationName = strings.Split(db.Statement.SQL.String(), " ")[0] - } - sp := ddtrace.StartSpan(operationName, spanOpts...) - db.Statement.Context = context.WithValue(db.Statement.Context, spanGormKey, sp) -} - -func (c *callbacks) after(db *gorm.DB) { - if db.Statement == nil || db.Statement.Context == nil { - return - } - - sp, ok := db.Statement.Context.Value(spanGormKey).(ddtrace.Span) - if !ok { - return - } - - sp.SetTag(ext.ResourceName, strings.ToUpper(db.Statement.SQL.String())) - sp.SetTag("db.table", db.Statement.Table) - sp.SetTag("db.query", strings.ToUpper(db.Statement.SQL.String())) - sp.SetTag("db.err", db.Error) - sp.SetTag("db.count", db.RowsAffected) - sp.Finish() -} - -func registerCallbacks(db *gorm.DB, name string, c *callbacks) { - var err error - - beforeName := fmt.Sprintf("tracing:%v_before", name) - afterName := fmt.Sprintf("tracing:%v_after", name) - gormCallbackName := fmt.Sprintf("gorm:%v", name) - // gorm does some magic, if you pass CallbackProcessor here - nothing works - switch name { - case "create": - err = db.Callback().Create().Before(gormCallbackName).Register(beforeName, c.beforeCreate) - if err != nil { - return - } - err = db.Callback().Create().After(gormCallbackName).Register(afterName, c.afterCreate) - if err != nil { - return - } - case "query": - err = db.Callback().Query().Before(gormCallbackName).Register(beforeName, c.beforeQuery) - if err != nil { - return - } - err = db.Callback().Query().After(gormCallbackName).Register(afterName, c.afterQuery) - if err != nil { - return - } - case "update": - err = db.Callback().Update().Before(gormCallbackName).Register(beforeName, c.beforeUpdate) - if err != nil { - return - } - err = db.Callback().Update().After(gormCallbackName).Register(afterName, c.afterUpdate) - if err != nil { - return - } - case "delete": - err = db.Callback().Delete().Before(gormCallbackName).Register(beforeName, c.beforeDelete) - if err != nil { - return - } - err = db.Callback().Delete().After(gormCallbackName).Register(afterName, c.afterDelete) - if err != nil { - return - } - case "row": - err = db.Callback().Row().Before(gormCallbackName).Register(beforeName, c.beforeRow) - if err != nil { - return - } - err = db.Callback().Row().After(gormCallbackName).Register(afterName, c.afterRow) - if err != nil { - return - } - } -} diff --git a/pkg/instrumentation/database_test.go b/pkg/instrumentation/database_test.go index 674a647..a8c38be 100644 --- a/pkg/instrumentation/database_test.go +++ b/pkg/instrumentation/database_test.go @@ -5,146 +5,40 @@ import ( "path" "testing" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/mocktracer" + "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer" "gorm.io/driver/sqlite" "gorm.io/gorm" ) -type TestRecord struct { - ID int - Name string -} - -func TestInstrumentDatabase(t *testing.T) { - expectedSpans := []struct { - name string - tags map[string]interface{} - }{ - { - name: "SELECT", - tags: map[string]interface{}{ - "db.count": int64(-1), - "db.err": nil, - "db.table": "", - "service.name": "test_app_name-mysql", - "span.type": "sql", - "db.query": `SELECT COUNT(*) FROM SQLITE_MASTER WHERE TYPE='TABLE' AND NAME=?`, - "resource.name": `SELECT COUNT(*) FROM SQLITE_MASTER WHERE TYPE='TABLE' AND NAME=?`, - }, - }, - { - name: "INSERT", - tags: map[string]interface{}{ - "db.count": int64(1), - "db.err": nil, - "db.table": "test_records", - "service.name": "test_app_name-mysql", - "span.type": "sql", - "db.query": "INSERT INTO `TEST_RECORDS` (`NAME`,`ID`) VALUES (?,?) RETURNING `ID`", - "resource.name": "INSERT INTO `TEST_RECORDS` (`NAME`,`ID`) VALUES (?,?) RETURNING `ID`", - }, - }, - { - name: "UPDATE", - tags: map[string]interface{}{ - "db.count": int64(1), - "db.err": nil, - "db.table": "test_records", - "service.name": "test_app_name-mysql", - "span.type": "sql", - "db.query": "UPDATE `TEST_RECORDS` SET `ID`=?,`NAME`=? WHERE `ID` = ? RETURNING `ID`", - "resource.name": "UPDATE `TEST_RECORDS` SET `ID`=?,`NAME`=? WHERE `ID` = ? RETURNING `ID`", - }, - }, - { - name: "SELECT", - tags: map[string]interface{}{ - "db.count": int64(1), - "db.err": nil, - "db.table": "test_records", - "service.name": "test_app_name-mysql", - "span.type": "sql", - "db.query": "SELECT * FROM `TEST_RECORDS` WHERE `ID` = ? ORDER BY `TEST_RECORDS`.`ID` LIMIT 1", - "resource.name": "SELECT * FROM `TEST_RECORDS` WHERE `ID` = ? ORDER BY `TEST_RECORDS`.`ID` LIMIT 1", - }, - }, - { - name: "DELETE", - tags: map[string]interface{}{ - "db.count": int64(1), - "db.err": nil, - "db.table": "test_records", - "service.name": "test_app_name-mysql", - "span.type": "sql", - "db.query": "DELETE FROM `TEST_RECORDS` WHERE `ID` = ? AND `TEST_RECORDS`.`ID` = ? RETURNING `ID`", - "resource.name": "DELETE FROM `TEST_RECORDS` WHERE `ID` = ? AND `TEST_RECORDS`.`ID` = ? RETURNING `ID`", - }, - }, - } - - mt := mocktracer.Start() - defer mt.Stop() +func TestTraceDatabase(t *testing.T) { + const ( + mockOperationName = "test" + mockDBName = "test_db" + ) - dbFile := path.Join(t.TempDir(), "test_db") + dbFile := path.Join(t.TempDir(), mockDBName) db, err := gorm.Open(sqlite.Open(dbFile)) if err != nil { t.Fatalf("Failed to open DB: %s", err) } - InstrumentDatabase(db, "test_app_name") - db = TraceDatabase(context.Background(), db) + mockTracer := mocktracer.Start() + defer mockTracer.Stop() - var ( - testRecord = TestRecord{ID: 1, Name: "test_name"} - updatedRecord = TestRecord{ID: 1, Name: "new_test_name"} - readRecord = &TestRecord{} - ) + mockSpan := tracer.StartSpan(mockOperationName) + defer mockSpan.Finish() - if err = db.AutoMigrate(TestRecord{}); err != nil { - t.Fatalf("Failed to migrate DB: %s", err) - } - - err = db.Begin().Create(&testRecord). - Save(&updatedRecord). - First(&readRecord). - Delete(&testRecord). - Commit().Error - if err != nil { - t.Fatalf("Failed to commit changes on test record: %s", err) - } + mockContext := tracer.ContextWithSpan(context.Background(), mockSpan) + db = TraceDatabase(mockContext, db) - spans := mt.FinishedSpans() - if len(spans) != 5 { - t.Fatalf("Unexpected number of spans: %d", len(spans)) + _, ok := tracer.SpanFromContext(db.Statement.Context) + if !ok { + require.True(t, ok) } - for i := range spans { - actualName := spans[i].OperationName() - actualTags := spans[i].Tags() - - expectedName := expectedSpans[i].name - expectedTags := expectedSpans[i].tags - - if actualName != expectedName { - t.Errorf("Got span: %s, expected: %s", actualName, expectedName) - } - - assert.Equal(t, expectedTags, actualTags, "database tags didn't match") - } -} - -func TestTraceDatabase(t *testing.T) { - dbFile := path.Join(t.TempDir(), "test_db") - db, err := gorm.Open(sqlite.Open(dbFile)) - if err != nil { - t.Fatalf("Failed to open DB: %s", err) - } - - InstrumentDatabase(db, "test_app_name") - db = TraceDatabase(context.Background(), db) - - if sp := db.Statement.Context.Value(parentSpanGormKey); sp == nil { - t.Error("Parent span not set on statement") - } + openSpans := mockTracer.OpenSpans() + require.Len(t, openSpans, 1) + require.Equal(t, openSpans[0].OperationName(), mockOperationName) }