diff --git a/instrumentation/logging/logger.go b/instrumentation/logging/logger.go index 20fb3c27..3f27e8eb 100644 --- a/instrumentation/logging/logger.go +++ b/instrumentation/logging/logger.go @@ -8,7 +8,6 @@ import ( "regexp" "sync" "time" - "unsafe" "github.com/opentracing/opentracing-go" "github.com/opentracing/opentracing-go/log" @@ -26,7 +25,7 @@ type ( logRecordsMutex sync.RWMutex logRecords []opentracing.LogRecord regex *regexp.Regexp - parentLogger *stdlog.Logger + ctx context.Context } logItem struct { time time.Time @@ -35,7 +34,7 @@ type ( message string } loggerPatchInfo struct { - current io.Writer + current *otWriter previous io.Writer } ) @@ -44,15 +43,12 @@ var ( patchedLoggersMutex sync.Mutex patchedLoggers = map[io.Writer]loggerPatchInfo{} stdLoggerWriter io.Writer - - loggerContextMutex sync.RWMutex - loggerContext = map[uintptr]context.Context{} ) // Patch the standard logger func PatchStandardLogger() { stdLoggerWriter := getStdLoggerWriter() - otWriter := newInstrumentedWriter(nil, stdlog.Prefix()) + otWriter := &otWriter{regex: regexp.MustCompile(fmt.Sprintf(logRegexTemplate, stdlog.Prefix()))} stdlog.SetOutput(io.MultiWriter(stdLoggerWriter, otWriter)) recorders = append(recorders, otWriter) } @@ -64,49 +60,21 @@ func UnpatchStandardLogger() { // Patch a logger func PatchLogger(logger *stdlog.Logger) { - patchedLoggersMutex.Lock() - defer patchedLoggersMutex.Unlock() - currentWriter := getLoggerWriter(logger) - if patchInfo, ok := patchedLoggers[currentWriter]; ok { - currentWriter = patchInfo.previous - } - otWriter := newInstrumentedWriter(logger, logger.Prefix()) - mWriter := io.MultiWriter(currentWriter, otWriter) - logger.SetOutput(mWriter) - recorders = append(recorders, otWriter) - patchedLoggers[mWriter] = loggerPatchInfo{ - current: otWriter, - previous: currentWriter, - } + patchLogger(logger, nil) } // Unpatch a logger func UnpatchLogger(logger *stdlog.Logger) { - patchedLoggersMutex.Lock() - defer patchedLoggersMutex.Unlock() - currentWriter := getLoggerWriter(logger) - if logInfo, ok := patchedLoggers[currentWriter]; ok { - logger.SetOutput(logInfo.previous) - delete(patchedLoggers, currentWriter) - } + unpatchLogger(logger) } // Create a new logger with a context func WithContext(logger *stdlog.Logger, ctx context.Context) *stdlog.Logger { rLogger := stdlog.New(getLoggerWriter(logger), logger.Prefix(), logger.Flags()) - setLoggerContext(rLogger, ctx) - PatchLogger(rLogger) + patchLogger(rLogger, ctx) return rLogger } -// Create a new instrumented writer for loggers -func newInstrumentedWriter(logger *stdlog.Logger, prefix string) *otWriter { - return &otWriter{ - regex: regexp.MustCompile(fmt.Sprintf(logRegexTemplate, prefix)), - parentLogger: logger, - } -} - // Write data to the channel and the base writer func (w *otWriter) Write(p []byte) (n int, err error) { w.process(p) @@ -128,6 +96,42 @@ func (w *otWriter) GetRecords() []opentracing.LogRecord { return w.logRecords } +// Patch logger with optional context +func patchLogger(logger *stdlog.Logger, ctx context.Context) { + unpatchLogger(logger) + + patchedLoggersMutex.Lock() + defer patchedLoggersMutex.Unlock() + + otWriter := &otWriter{ + regex: regexp.MustCompile(fmt.Sprintf(logRegexTemplate, logger.Prefix())), + ctx: ctx, + } + + currentWriter := getLoggerWriter(logger) + newWriter := io.MultiWriter(currentWriter, otWriter) + patchedLoggers[newWriter] = loggerPatchInfo{ + current: otWriter, + previous: currentWriter, + } + + recorders = append(recorders, otWriter) + logger.SetOutput(newWriter) +} + +// Unpatch logger +func unpatchLogger(logger *stdlog.Logger) { + patchedLoggersMutex.Lock() + defer patchedLoggersMutex.Unlock() + + currentWriter := getLoggerWriter(logger) + + if logInfo, ok := patchedLoggers[currentWriter]; ok { + logger.SetOutput(logInfo.previous) + delete(patchedLoggers, currentWriter) + } +} + // Process bytes and create new log items struct to store func (w *otWriter) process(p []byte) { if len(p) == 0 { @@ -185,33 +189,19 @@ func (w *otWriter) storeLogRecord(item *logItem) { fields = append(fields, log.String(tags.EventSource, fmt.Sprintf("%s:%s", item.file, item.lineNumber))) } - if span := opentracing.SpanFromContext(getLoggerContext(w.parentLogger)); span != nil { - span.LogFields(fields...) - return + // If context is found, we try to find the a span from the context and write the logs + if w.ctx != nil { + if span := opentracing.SpanFromContext(w.ctx); span != nil { + span.LogFields(fields...) + return + } } - w.appendLogRecords(opentracing.LogRecord{ - Timestamp: item.time, - Fields: fields, - }) -} -func (w *otWriter) appendLogRecords(record opentracing.LogRecord) { + // If no context, we store the log records for future extraction w.logRecordsMutex.Lock() defer w.logRecordsMutex.Unlock() - w.logRecords = append(w.logRecords, record) -} - -func setLoggerContext(logger *stdlog.Logger, ctx context.Context) { - loggerContextMutex.Lock() - defer loggerContextMutex.Unlock() - loggerContext[uintptr(unsafe.Pointer(logger))] = ctx -} - -func getLoggerContext(logger *stdlog.Logger) context.Context { - loggerContextMutex.RLock() - defer loggerContextMutex.RUnlock() - if ctx, ok := loggerContext[uintptr(unsafe.Pointer(logger))]; ok { - return ctx - } - return context.TODO() + w.logRecords = append(w.logRecords, opentracing.LogRecord{ + Timestamp: item.time, + Fields: fields, + }) }