Skip to content

Commit

Permalink
feat: Replacing custom database ddtrace with dd-trace-go library
Browse files Browse the repository at this point in the history
  • Loading branch information
laynax committed Jun 11, 2024
1 parent bc49b52 commit 06382a5
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 272 deletions.
148 changes: 5 additions & 143 deletions pkg/instrumentation/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,159 +2,21 @@ package instrumentation

import (
"context"
"fmt"
"strings"

"gopkg.in/DataDog/dd-trace-go.v1/ddtrace/ext"
ddtrace "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer"
"gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer"
"gorm.io/gorm"
)

type spanContextKey string

var (
// ParentSpanGormKey is the name of the parent span key
parentSpanGormKey = spanContextKey("trancingParentSpan")
// SpanGormKey is the name of the span key
spanGormKey = spanContextKey("tracingSpan")
)

// TraceDatabase sets span to gorm settings, returns cloned DB
// TraceDatabase fetches the span from the context and injects it to a new
// database session's statement context.
func TraceDatabase(ctx context.Context, db *gorm.DB) *gorm.DB {
if ctx == nil {
return db
}

parentSpan, _ := ddtrace.SpanFromContext(ctx)
parentSpan, _ := tracer.SpanFromContext(ctx)

return db.Session(&gorm.Session{
Context: context.WithValue(db.Statement.Context, parentSpanGormKey, parentSpan),
Context: tracer.ContextWithSpan(db.Statement.Context, parentSpan),
})
}

// InstrumentDatabase adds callbacks for tracing, call TraceDatabase to make it work
func InstrumentDatabase(db *gorm.DB, appName string) {
callbacks := newCallbacks(appName)

registerCallbacks(db, "create", callbacks)
registerCallbacks(db, "query", callbacks)
registerCallbacks(db, "update", callbacks)
registerCallbacks(db, "delete", callbacks)
registerCallbacks(db, "row", callbacks)
}

type callbacks struct {
serviceName string
}

func newCallbacks(appName string) *callbacks {
return &callbacks{
serviceName: fmt.Sprintf("%s-%s", appName, "mysql"),
}
}

func (c *callbacks) beforeCreate(db *gorm.DB) { c.before(db, "INSERT", c.serviceName) }
func (c *callbacks) afterCreate(db *gorm.DB) { c.after(db) }
func (c *callbacks) beforeQuery(db *gorm.DB) { c.before(db, "SELECT", c.serviceName) }
func (c *callbacks) afterQuery(db *gorm.DB) { c.after(db) }
func (c *callbacks) beforeUpdate(db *gorm.DB) { c.before(db, "UPDATE", c.serviceName) }
func (c *callbacks) afterUpdate(db *gorm.DB) { c.after(db) }
func (c *callbacks) beforeDelete(db *gorm.DB) { c.before(db, "DELETE", c.serviceName) }
func (c *callbacks) afterDelete(db *gorm.DB) { c.after(db) }
func (c *callbacks) beforeRow(db *gorm.DB) { c.before(db, "", c.serviceName) }
func (c *callbacks) afterRow(db *gorm.DB) { c.after(db) }
func (c *callbacks) before(db *gorm.DB, operationName string, serviceName string) {
if db.Statement == nil || db.Statement.Context == nil {
return
}

parentSpan, ok := db.Statement.Context.Value(parentSpanGormKey).(ddtrace.Span)
if !ok {
return
}

spanOpts := []ddtrace.StartSpanOption{
ddtrace.ChildOf(parentSpan.Context()),
ddtrace.SpanType(ext.SpanTypeSQL),
ddtrace.ServiceName(serviceName),
}
if operationName == "" {
operationName = strings.Split(db.Statement.SQL.String(), " ")[0]
}
sp := ddtrace.StartSpan(operationName, spanOpts...)
db.Statement.Context = context.WithValue(db.Statement.Context, spanGormKey, sp)
}

func (c *callbacks) after(db *gorm.DB) {
if db.Statement == nil || db.Statement.Context == nil {
return
}

sp, ok := db.Statement.Context.Value(spanGormKey).(ddtrace.Span)
if !ok {
return
}

sp.SetTag(ext.ResourceName, strings.ToUpper(db.Statement.SQL.String()))
sp.SetTag("db.table", db.Statement.Table)
sp.SetTag("db.query", strings.ToUpper(db.Statement.SQL.String()))
sp.SetTag("db.err", db.Error)
sp.SetTag("db.count", db.RowsAffected)
sp.Finish()
}

func registerCallbacks(db *gorm.DB, name string, c *callbacks) {
var err error

beforeName := fmt.Sprintf("tracing:%v_before", name)
afterName := fmt.Sprintf("tracing:%v_after", name)
gormCallbackName := fmt.Sprintf("gorm:%v", name)
// gorm does some magic, if you pass CallbackProcessor here - nothing works
switch name {
case "create":
err = db.Callback().Create().Before(gormCallbackName).Register(beforeName, c.beforeCreate)
if err != nil {
return
}
err = db.Callback().Create().After(gormCallbackName).Register(afterName, c.afterCreate)
if err != nil {
return
}
case "query":
err = db.Callback().Query().Before(gormCallbackName).Register(beforeName, c.beforeQuery)
if err != nil {
return
}
err = db.Callback().Query().After(gormCallbackName).Register(afterName, c.afterQuery)
if err != nil {
return
}
case "update":
err = db.Callback().Update().Before(gormCallbackName).Register(beforeName, c.beforeUpdate)
if err != nil {
return
}
err = db.Callback().Update().After(gormCallbackName).Register(afterName, c.afterUpdate)
if err != nil {
return
}
case "delete":
err = db.Callback().Delete().Before(gormCallbackName).Register(beforeName, c.beforeDelete)
if err != nil {
return
}
err = db.Callback().Delete().After(gormCallbackName).Register(afterName, c.afterDelete)
if err != nil {
return
}
case "row":
err = db.Callback().Row().Before(gormCallbackName).Register(beforeName, c.beforeRow)
if err != nil {
return
}
err = db.Callback().Row().After(gormCallbackName).Register(afterName, c.afterRow)
if err != nil {
return
}
}
}
146 changes: 20 additions & 126 deletions pkg/instrumentation/database_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,146 +5,40 @@ import (
"path"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gopkg.in/DataDog/dd-trace-go.v1/ddtrace/mocktracer"
"gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
)

type TestRecord struct {
ID int
Name string
}

func TestInstrumentDatabase(t *testing.T) {
expectedSpans := []struct {
name string
tags map[string]interface{}
}{
{
name: "SELECT",
tags: map[string]interface{}{
"db.count": int64(-1),
"db.err": nil,
"db.table": "",
"service.name": "test_app_name-mysql",
"span.type": "sql",
"db.query": `SELECT COUNT(*) FROM SQLITE_MASTER WHERE TYPE='TABLE' AND NAME=?`,
"resource.name": `SELECT COUNT(*) FROM SQLITE_MASTER WHERE TYPE='TABLE' AND NAME=?`,
},
},
{
name: "INSERT",
tags: map[string]interface{}{
"db.count": int64(1),
"db.err": nil,
"db.table": "test_records",
"service.name": "test_app_name-mysql",
"span.type": "sql",
"db.query": "INSERT INTO `TEST_RECORDS` (`NAME`,`ID`) VALUES (?,?) RETURNING `ID`",
"resource.name": "INSERT INTO `TEST_RECORDS` (`NAME`,`ID`) VALUES (?,?) RETURNING `ID`",
},
},
{
name: "UPDATE",
tags: map[string]interface{}{
"db.count": int64(1),
"db.err": nil,
"db.table": "test_records",
"service.name": "test_app_name-mysql",
"span.type": "sql",
"db.query": "UPDATE `TEST_RECORDS` SET `ID`=?,`NAME`=? WHERE `ID` = ? RETURNING `ID`",
"resource.name": "UPDATE `TEST_RECORDS` SET `ID`=?,`NAME`=? WHERE `ID` = ? RETURNING `ID`",
},
},
{
name: "SELECT",
tags: map[string]interface{}{
"db.count": int64(1),
"db.err": nil,
"db.table": "test_records",
"service.name": "test_app_name-mysql",
"span.type": "sql",
"db.query": "SELECT * FROM `TEST_RECORDS` WHERE `ID` = ? ORDER BY `TEST_RECORDS`.`ID` LIMIT 1",
"resource.name": "SELECT * FROM `TEST_RECORDS` WHERE `ID` = ? ORDER BY `TEST_RECORDS`.`ID` LIMIT 1",
},
},
{
name: "DELETE",
tags: map[string]interface{}{
"db.count": int64(1),
"db.err": nil,
"db.table": "test_records",
"service.name": "test_app_name-mysql",
"span.type": "sql",
"db.query": "DELETE FROM `TEST_RECORDS` WHERE `ID` = ? AND `TEST_RECORDS`.`ID` = ? RETURNING `ID`",
"resource.name": "DELETE FROM `TEST_RECORDS` WHERE `ID` = ? AND `TEST_RECORDS`.`ID` = ? RETURNING `ID`",
},
},
}

mt := mocktracer.Start()
defer mt.Stop()
func TestTraceDatabase(t *testing.T) {
const (
mockOperationName = "test"
mockDBName = "test_db"
)

dbFile := path.Join(t.TempDir(), "test_db")
dbFile := path.Join(t.TempDir(), mockDBName)
db, err := gorm.Open(sqlite.Open(dbFile))
if err != nil {
t.Fatalf("Failed to open DB: %s", err)
}

InstrumentDatabase(db, "test_app_name")
db = TraceDatabase(context.Background(), db)
mockTracer := mocktracer.Start()
defer mockTracer.Stop()

var (
testRecord = TestRecord{ID: 1, Name: "test_name"}
updatedRecord = TestRecord{ID: 1, Name: "new_test_name"}
readRecord = &TestRecord{}
)
mockSpan := tracer.StartSpan(mockOperationName)
defer mockSpan.Finish()

if err = db.AutoMigrate(TestRecord{}); err != nil {
t.Fatalf("Failed to migrate DB: %s", err)
}

err = db.Begin().Create(&testRecord).
Save(&updatedRecord).
First(&readRecord).
Delete(&testRecord).
Commit().Error
if err != nil {
t.Fatalf("Failed to commit changes on test record: %s", err)
}
mockContext := tracer.ContextWithSpan(context.Background(), mockSpan)
db = TraceDatabase(mockContext, db)

spans := mt.FinishedSpans()
if len(spans) != 5 {
t.Fatalf("Unexpected number of spans: %d", len(spans))
_, ok := tracer.SpanFromContext(db.Statement.Context)
if !ok {
require.True(t, ok)
}

for i := range spans {
actualName := spans[i].OperationName()
actualTags := spans[i].Tags()

expectedName := expectedSpans[i].name
expectedTags := expectedSpans[i].tags

if actualName != expectedName {
t.Errorf("Got span: %s, expected: %s", actualName, expectedName)
}

assert.Equal(t, expectedTags, actualTags, "database tags didn't match")
}
}

func TestTraceDatabase(t *testing.T) {
dbFile := path.Join(t.TempDir(), "test_db")
db, err := gorm.Open(sqlite.Open(dbFile))
if err != nil {
t.Fatalf("Failed to open DB: %s", err)
}

InstrumentDatabase(db, "test_app_name")
db = TraceDatabase(context.Background(), db)

if sp := db.Statement.Context.Value(parentSpanGormKey); sp == nil {
t.Error("Parent span not set on statement")
}
openSpans := mockTracer.OpenSpans()
require.Len(t, openSpans, 1)
require.Equal(t, openSpans[0].OperationName(), mockOperationName)
}
4 changes: 1 addition & 3 deletions pkg/middleware/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"net/http"

sdkdatabasecontext "github.com/scribd/go-sdk/pkg/context/database"
sdkinstrumentation "github.com/scribd/go-sdk/pkg/instrumentation"

"gorm.io/gorm"
)
Expand All @@ -27,8 +26,7 @@ func NewDatabaseMiddleware(d *gorm.DB) DatabaseMiddleware {
// connection pool to the request context.
func (dm DatabaseMiddleware) Handler(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
db := sdkinstrumentation.TraceDatabase(r.Context(), dm.Database)
ctx := sdkdatabasecontext.ToContext(r.Context(), db)
ctx := sdkdatabasecontext.ToContext(r.Context(), dm.Database)
next.ServeHTTP(w, r.WithContext(ctx))
})
}

0 comments on commit 06382a5

Please sign in to comment.