From 46628ed6318c3234513afb3006171d0c79a03571 Mon Sep 17 00:00:00 2001 From: Daniel Bedrood Date: Tue, 11 Jun 2024 15:54:03 +0200 Subject: [PATCH] feat: Replacing custom database ddtrace with dd-trace-go library --- pkg/instrumentation/database.go | 148 +-------------------------- pkg/instrumentation/database_test.go | 146 ++++---------------------- 2 files changed, 25 insertions(+), 269 deletions(-) diff --git a/pkg/instrumentation/database.go b/pkg/instrumentation/database.go index 931f217..0c21f23 100644 --- a/pkg/instrumentation/database.go +++ b/pkg/instrumentation/database.go @@ -2,159 +2,21 @@ package instrumentation import ( "context" - "fmt" - "strings" - "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/ext" - ddtrace "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer" + "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer" "gorm.io/gorm" ) -type spanContextKey string - -var ( - // ParentSpanGormKey is the name of the parent span key - parentSpanGormKey = spanContextKey("trancingParentSpan") - // SpanGormKey is the name of the span key - spanGormKey = spanContextKey("tracingSpan") -) - -// TraceDatabase sets span to gorm settings, returns cloned DB +// TraceDatabase fetches the span from the context and injects it to a new +// database session's statement context. func TraceDatabase(ctx context.Context, db *gorm.DB) *gorm.DB { if ctx == nil { return db } - parentSpan, _ := ddtrace.SpanFromContext(ctx) + parentSpan, _ := tracer.SpanFromContext(ctx) return db.Session(&gorm.Session{ - Context: context.WithValue(db.Statement.Context, parentSpanGormKey, parentSpan), + Context: tracer.ContextWithSpan(db.Statement.Context, parentSpan), }) } - -// InstrumentDatabase adds callbacks for tracing, call TraceDatabase to make it work -func InstrumentDatabase(db *gorm.DB, appName string) { - callbacks := newCallbacks(appName) - - registerCallbacks(db, "create", callbacks) - registerCallbacks(db, "query", callbacks) - registerCallbacks(db, "update", callbacks) - registerCallbacks(db, "delete", callbacks) - registerCallbacks(db, "row", callbacks) -} - -type callbacks struct { - serviceName string -} - -func newCallbacks(appName string) *callbacks { - return &callbacks{ - serviceName: fmt.Sprintf("%s-%s", appName, "mysql"), - } -} - -func (c *callbacks) beforeCreate(db *gorm.DB) { c.before(db, "INSERT", c.serviceName) } -func (c *callbacks) afterCreate(db *gorm.DB) { c.after(db) } -func (c *callbacks) beforeQuery(db *gorm.DB) { c.before(db, "SELECT", c.serviceName) } -func (c *callbacks) afterQuery(db *gorm.DB) { c.after(db) } -func (c *callbacks) beforeUpdate(db *gorm.DB) { c.before(db, "UPDATE", c.serviceName) } -func (c *callbacks) afterUpdate(db *gorm.DB) { c.after(db) } -func (c *callbacks) beforeDelete(db *gorm.DB) { c.before(db, "DELETE", c.serviceName) } -func (c *callbacks) afterDelete(db *gorm.DB) { c.after(db) } -func (c *callbacks) beforeRow(db *gorm.DB) { c.before(db, "", c.serviceName) } -func (c *callbacks) afterRow(db *gorm.DB) { c.after(db) } -func (c *callbacks) before(db *gorm.DB, operationName string, serviceName string) { - if db.Statement == nil || db.Statement.Context == nil { - return - } - - parentSpan, ok := db.Statement.Context.Value(parentSpanGormKey).(ddtrace.Span) - if !ok { - return - } - - spanOpts := []ddtrace.StartSpanOption{ - ddtrace.ChildOf(parentSpan.Context()), - ddtrace.SpanType(ext.SpanTypeSQL), - ddtrace.ServiceName(serviceName), - } - if operationName == "" { - operationName = strings.Split(db.Statement.SQL.String(), " ")[0] - } - sp := ddtrace.StartSpan(operationName, spanOpts...) - db.Statement.Context = context.WithValue(db.Statement.Context, spanGormKey, sp) -} - -func (c *callbacks) after(db *gorm.DB) { - if db.Statement == nil || db.Statement.Context == nil { - return - } - - sp, ok := db.Statement.Context.Value(spanGormKey).(ddtrace.Span) - if !ok { - return - } - - sp.SetTag(ext.ResourceName, strings.ToUpper(db.Statement.SQL.String())) - sp.SetTag("db.table", db.Statement.Table) - sp.SetTag("db.query", strings.ToUpper(db.Statement.SQL.String())) - sp.SetTag("db.err", db.Error) - sp.SetTag("db.count", db.RowsAffected) - sp.Finish() -} - -func registerCallbacks(db *gorm.DB, name string, c *callbacks) { - var err error - - beforeName := fmt.Sprintf("tracing:%v_before", name) - afterName := fmt.Sprintf("tracing:%v_after", name) - gormCallbackName := fmt.Sprintf("gorm:%v", name) - // gorm does some magic, if you pass CallbackProcessor here - nothing works - switch name { - case "create": - err = db.Callback().Create().Before(gormCallbackName).Register(beforeName, c.beforeCreate) - if err != nil { - return - } - err = db.Callback().Create().After(gormCallbackName).Register(afterName, c.afterCreate) - if err != nil { - return - } - case "query": - err = db.Callback().Query().Before(gormCallbackName).Register(beforeName, c.beforeQuery) - if err != nil { - return - } - err = db.Callback().Query().After(gormCallbackName).Register(afterName, c.afterQuery) - if err != nil { - return - } - case "update": - err = db.Callback().Update().Before(gormCallbackName).Register(beforeName, c.beforeUpdate) - if err != nil { - return - } - err = db.Callback().Update().After(gormCallbackName).Register(afterName, c.afterUpdate) - if err != nil { - return - } - case "delete": - err = db.Callback().Delete().Before(gormCallbackName).Register(beforeName, c.beforeDelete) - if err != nil { - return - } - err = db.Callback().Delete().After(gormCallbackName).Register(afterName, c.afterDelete) - if err != nil { - return - } - case "row": - err = db.Callback().Row().Before(gormCallbackName).Register(beforeName, c.beforeRow) - if err != nil { - return - } - err = db.Callback().Row().After(gormCallbackName).Register(afterName, c.afterRow) - if err != nil { - return - } - } -} diff --git a/pkg/instrumentation/database_test.go b/pkg/instrumentation/database_test.go index 674a647..a8c38be 100644 --- a/pkg/instrumentation/database_test.go +++ b/pkg/instrumentation/database_test.go @@ -5,146 +5,40 @@ import ( "path" "testing" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/mocktracer" + "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer" "gorm.io/driver/sqlite" "gorm.io/gorm" ) -type TestRecord struct { - ID int - Name string -} - -func TestInstrumentDatabase(t *testing.T) { - expectedSpans := []struct { - name string - tags map[string]interface{} - }{ - { - name: "SELECT", - tags: map[string]interface{}{ - "db.count": int64(-1), - "db.err": nil, - "db.table": "", - "service.name": "test_app_name-mysql", - "span.type": "sql", - "db.query": `SELECT COUNT(*) FROM SQLITE_MASTER WHERE TYPE='TABLE' AND NAME=?`, - "resource.name": `SELECT COUNT(*) FROM SQLITE_MASTER WHERE TYPE='TABLE' AND NAME=?`, - }, - }, - { - name: "INSERT", - tags: map[string]interface{}{ - "db.count": int64(1), - "db.err": nil, - "db.table": "test_records", - "service.name": "test_app_name-mysql", - "span.type": "sql", - "db.query": "INSERT INTO `TEST_RECORDS` (`NAME`,`ID`) VALUES (?,?) RETURNING `ID`", - "resource.name": "INSERT INTO `TEST_RECORDS` (`NAME`,`ID`) VALUES (?,?) RETURNING `ID`", - }, - }, - { - name: "UPDATE", - tags: map[string]interface{}{ - "db.count": int64(1), - "db.err": nil, - "db.table": "test_records", - "service.name": "test_app_name-mysql", - "span.type": "sql", - "db.query": "UPDATE `TEST_RECORDS` SET `ID`=?,`NAME`=? WHERE `ID` = ? RETURNING `ID`", - "resource.name": "UPDATE `TEST_RECORDS` SET `ID`=?,`NAME`=? WHERE `ID` = ? RETURNING `ID`", - }, - }, - { - name: "SELECT", - tags: map[string]interface{}{ - "db.count": int64(1), - "db.err": nil, - "db.table": "test_records", - "service.name": "test_app_name-mysql", - "span.type": "sql", - "db.query": "SELECT * FROM `TEST_RECORDS` WHERE `ID` = ? ORDER BY `TEST_RECORDS`.`ID` LIMIT 1", - "resource.name": "SELECT * FROM `TEST_RECORDS` WHERE `ID` = ? ORDER BY `TEST_RECORDS`.`ID` LIMIT 1", - }, - }, - { - name: "DELETE", - tags: map[string]interface{}{ - "db.count": int64(1), - "db.err": nil, - "db.table": "test_records", - "service.name": "test_app_name-mysql", - "span.type": "sql", - "db.query": "DELETE FROM `TEST_RECORDS` WHERE `ID` = ? AND `TEST_RECORDS`.`ID` = ? RETURNING `ID`", - "resource.name": "DELETE FROM `TEST_RECORDS` WHERE `ID` = ? AND `TEST_RECORDS`.`ID` = ? RETURNING `ID`", - }, - }, - } - - mt := mocktracer.Start() - defer mt.Stop() +func TestTraceDatabase(t *testing.T) { + const ( + mockOperationName = "test" + mockDBName = "test_db" + ) - dbFile := path.Join(t.TempDir(), "test_db") + dbFile := path.Join(t.TempDir(), mockDBName) db, err := gorm.Open(sqlite.Open(dbFile)) if err != nil { t.Fatalf("Failed to open DB: %s", err) } - InstrumentDatabase(db, "test_app_name") - db = TraceDatabase(context.Background(), db) + mockTracer := mocktracer.Start() + defer mockTracer.Stop() - var ( - testRecord = TestRecord{ID: 1, Name: "test_name"} - updatedRecord = TestRecord{ID: 1, Name: "new_test_name"} - readRecord = &TestRecord{} - ) + mockSpan := tracer.StartSpan(mockOperationName) + defer mockSpan.Finish() - if err = db.AutoMigrate(TestRecord{}); err != nil { - t.Fatalf("Failed to migrate DB: %s", err) - } - - err = db.Begin().Create(&testRecord). - Save(&updatedRecord). - First(&readRecord). - Delete(&testRecord). - Commit().Error - if err != nil { - t.Fatalf("Failed to commit changes on test record: %s", err) - } + mockContext := tracer.ContextWithSpan(context.Background(), mockSpan) + db = TraceDatabase(mockContext, db) - spans := mt.FinishedSpans() - if len(spans) != 5 { - t.Fatalf("Unexpected number of spans: %d", len(spans)) + _, ok := tracer.SpanFromContext(db.Statement.Context) + if !ok { + require.True(t, ok) } - for i := range spans { - actualName := spans[i].OperationName() - actualTags := spans[i].Tags() - - expectedName := expectedSpans[i].name - expectedTags := expectedSpans[i].tags - - if actualName != expectedName { - t.Errorf("Got span: %s, expected: %s", actualName, expectedName) - } - - assert.Equal(t, expectedTags, actualTags, "database tags didn't match") - } -} - -func TestTraceDatabase(t *testing.T) { - dbFile := path.Join(t.TempDir(), "test_db") - db, err := gorm.Open(sqlite.Open(dbFile)) - if err != nil { - t.Fatalf("Failed to open DB: %s", err) - } - - InstrumentDatabase(db, "test_app_name") - db = TraceDatabase(context.Background(), db) - - if sp := db.Statement.Context.Value(parentSpanGormKey); sp == nil { - t.Error("Parent span not set on statement") - } + openSpans := mockTracer.OpenSpans() + require.Len(t, openSpans, 1) + require.Equal(t, openSpans[0].OperationName(), mockOperationName) }