diff --git a/README.md b/README.md index 62bcf63e8..2256979db 100644 --- a/README.md +++ b/README.md @@ -95,6 +95,41 @@ Install [jq](https://stedolan.github.io/jq) so that the parameters can get parse make test ``` +## customizing Logging Tags + +If you would like to ensure that certain tags are always present in the logs, `RegisterClientLogContextHook` can be used in your init function. See example below. +``` +import "github.com/snowflakedb/gosnowflake" + +func init() { + // each time the logger is used, the logs will contain a REQUEST_ID field with requestID the value extracted + // from the context + gosnowflake.RegisterClientLogContextHook("REQUEST_ID", func(ctx context.Context) interface{} { + return requestIdFromContext(ctx) + }) +} +``` + +## Setting Log Level +If you want to change the log level, `SetLogLevel` can be used in your init function like this: +``` +import "github.com/snowflakedb/gosnowflake" + +func init() { + // The following line changes the log level to debug + _ = gosnowflake.GetLogger().SetLogLevel("debug") +} +``` +The following is a list of options you can pass in to set the level from least to most verbose: +- `"OFF"` +- `"error"` +- `"warn"` +- `"print"` +- `"trace"` +- `"debug"` +- `"info"` + + ## Capturing Code Coverage Configure your testing environment as described above and run ``make cov``. The coverage percentage will be printed on the console when the testing completes. diff --git a/auth.go b/auth.go index cc839ca2c..ec70fb1dd 100644 --- a/auth.go +++ b/auth.go @@ -225,7 +225,7 @@ func postAuth( params.Add(requestGUIDKey, NewUUID().String()) fullURL := sr.getFullURL(loginRequestPath, params) - logger.Infof("full URL: %v", fullURL) + logger.WithContext(ctx).Infof("full URL: %v", fullURL) resp, err := sr.FuncAuthPost(ctx, client, fullURL, headers, bodyCreator, timeout, sr.MaxRetryCount) if err != nil { return nil, err @@ -235,7 +235,7 @@ func postAuth( var respd authResponse err = json.NewDecoder(resp.Body).Decode(&respd) if err != nil { - logger.Errorf("failed to decode JSON. err: %v", err) + logger.WithContext(ctx).Errorf("failed to decode JSON. err: %v", err) return nil, err } return &respd, nil @@ -260,11 +260,11 @@ func postAuth( } b, err := io.ReadAll(resp.Body) if err != nil { - logger.Errorf("failed to extract HTTP response body. err: %v", err) + logger.WithContext(ctx).Errorf("failed to extract HTTP response body. err: %v", err) return nil, err } - logger.Infof("HTTP: %v, URL: %v, Body: %v", resp.StatusCode, fullURL, b) - logger.Infof("Header: %v", resp.Header) + logger.WithContext(ctx).Infof("HTTP: %v, URL: %v, Body: %v", resp.StatusCode, fullURL, b) + logger.WithContext(ctx).Infof("Header: %v", resp.Header) return nil, &SnowflakeError{ Number: ErrFailedToAuth, SQLState: SQLStateConnectionRejected, @@ -293,7 +293,7 @@ func authenticate( proofKey []byte, ) (resp *authResponseMain, err error) { if sc.cfg.Authenticator == AuthTypeTokenAccessor { - logger.Info("Bypass authentication using existing token from token accessor") + logger.WithContext(ctx).Info("Bypass authentication using existing token from token accessor") sessionInfo := authResponseSessionInfo{ DatabaseName: sc.cfg.Database, SchemaName: sc.cfg.Schema, @@ -350,7 +350,7 @@ func authenticate( params.Add("roleName", sc.cfg.Role) } - logger.WithContext(sc.ctx).Infof("PARAMS for Auth: %v, %v, %v, %v, %v, %v", + logger.WithContext(ctx).WithContext(sc.ctx).Infof("PARAMS for Auth: %v, %v, %v, %v, %v, %v", params, sc.rest.Protocol, sc.rest.Host, sc.rest.Port, sc.rest.LoginTimeout, sc.cfg.Authenticator.String()) respd, err := sc.rest.FuncPostAuth(ctx, sc.rest, sc.rest.getClientFor(sc.cfg.Authenticator), params, headers, bodyCreator, sc.rest.LoginTimeout) @@ -358,7 +358,7 @@ func authenticate( return nil, err } if !respd.Success { - logger.Errorln("Authentication FAILED") + logger.WithContext(ctx).Errorln("Authentication FAILED") sc.rest.TokenAccessor.SetTokens("", "", -1) if sessionParameters[clientRequestMfaToken] == true { deleteCredential(sc, mfaToken) @@ -377,7 +377,7 @@ func authenticate( Message: respd.Message, }).exceptionTelemetry(sc) } - logger.Info("Authentication SUCCESS") + logger.WithContext(ctx).Info("Authentication SUCCESS") sc.rest.TokenAccessor.SetTokens(respd.Data.Token, respd.Data.MasterToken, respd.Data.SessionID) if sessionParameters[clientRequestMfaToken] == true { token := respd.Data.MfaToken @@ -439,7 +439,7 @@ func createRequestBody(sc *snowflakeConn, sessionParameters map[string]interface } requestMain.Token = jwtTokenString case AuthTypeSnowflake: - logger.Info("Username and password") + logger.WithContext(sc.ctx).Info("Username and password") requestMain.LoginName = sc.cfg.User requestMain.Password = sc.cfg.Password switch { @@ -450,7 +450,7 @@ func createRequestBody(sc *snowflakeConn, sessionParameters map[string]interface requestMain.ExtAuthnDuoMethod = "passcode" } case AuthTypeUsernamePasswordMFA: - logger.Info("Username and password MFA") + logger.WithContext(sc.ctx).Info("Username and password MFA") requestMain.LoginName = sc.cfg.User requestMain.Password = sc.cfg.Password if sc.cfg.MfaToken != "" { @@ -527,7 +527,7 @@ func authenticateWithConfig(sc *snowflakeConn) error { } } - logger.Infof("Authenticating via %v", sc.cfg.Authenticator.String()) + logger.WithContext(sc.ctx).Infof("Authenticating via %v", sc.cfg.Authenticator.String()) switch sc.cfg.Authenticator { case AuthTypeExternalBrowser: if sc.cfg.IDToken == "" { diff --git a/authexternalbrowser.go b/authexternalbrowser.go index ac53c3707..cc76c8c4c 100644 --- a/authexternalbrowser.go +++ b/authexternalbrowser.go @@ -131,7 +131,7 @@ func getIdpURLProofKey( return "", "", err } if !respd.Success { - logger.Errorln("Authentication FAILED") + logger.WithContext(ctx).Errorln("Authentication FAILED") sr.TokenAccessor.SetTokens("", "", -1) code, err := strconv.Atoi(respd.Code) if err != nil { @@ -287,7 +287,7 @@ func doAuthenticateByExternalBrowser( n, err := c.Read(b) if err != nil { if err != io.EOF { - logger.Infof("error reading from socket. err: %v", err) + logger.WithContext(ctx).Infof("error reading from socket. err: %v", err) errAccept = &SnowflakeError{ Number: ErrFailedToGetExternalBrowserResponse, SQLState: SQLStateConnectionRejected, diff --git a/authokta.go b/authokta.go index 818753af8..d3ea90fce 100644 --- a/authokta.go +++ b/authokta.go @@ -92,7 +92,7 @@ func authenticateBySAML( return nil, err } if !respd.Success { - logger.Errorln("Authentication FAILED") + logger.WithContext(ctx).Errorln("Authentication FAILED") sr.TokenAccessor.SetTokens("", "", -1) code, err := strconv.Atoi(respd.Code) if err != nil { @@ -215,7 +215,7 @@ func postAuthSAML( params.Add(requestIDKey, getOrGenerateRequestIDFromContext(ctx).String()) fullURL := sr.getFullURL(authenticatorRequestPath, params) - logger.Infof("fullURL: %v", fullURL) + logger.WithContext(ctx).Infof("fullURL: %v", fullURL) resp, err := sr.FuncPost(ctx, sr, fullURL, headers, body, timeout, defaultTimeProvider, nil) if err != nil { return nil, err @@ -269,7 +269,7 @@ func postAuthOKTA( fullURL string, timeout time.Duration) ( data *authOKTAResponse, err error) { - logger.Infof("fullURL: %v", fullURL) + logger.WithContext(ctx).Infof("fullURL: %v", fullURL) targetURL, err := url.Parse(fullURL) if err != nil { return nil, err @@ -290,7 +290,7 @@ func postAuthOKTA( } _, err = io.ReadAll(resp.Body) if err != nil { - logger.Errorf("failed to extract HTTP response body. err: %v", err) + logger.WithContext(ctx).Errorf("failed to extract HTTP response body. err: %v", err) return nil, err } logger.WithContext(ctx).Infof("HTTP: %v, URL: %v", resp.StatusCode, fullURL) diff --git a/bind_uploader.go b/bind_uploader.go index 740290957..09ce44fd3 100644 --- a/bind_uploader.go +++ b/bind_uploader.go @@ -175,7 +175,7 @@ func (bu *bindUploader) createCSVRecord(data []interface{}) []byte { if ok { b.WriteString(escapeForCSV(value)) } else if !reflect.ValueOf(data[i]).IsNil() { - logger.Debugf("Cannot convert value to string in createCSVRecord. value: %v", data[i]) + logger.WithContext(bu.ctx).Debugf("Cannot convert value to string in createCSVRecord. value: %v", data[i]) } } b.WriteString("\n") diff --git a/chunk_downloader.go b/chunk_downloader.go index c167e4e0c..77a04bedb 100644 --- a/chunk_downloader.go +++ b/chunk_downloader.go @@ -124,8 +124,8 @@ func (scd *snowflakeChunkDownloader) start() error { // start downloading chunks if exists chunkMetaLen := len(scd.ChunkMetas) if chunkMetaLen > 0 { - logger.Debugf("MaxChunkDownloadWorkers: %v", MaxChunkDownloadWorkers) - logger.Debugf("chunks: %v, total bytes: %d", chunkMetaLen, scd.totalUncompressedSize()) + logger.WithContext(scd.ctx).Debugf("MaxChunkDownloadWorkers: %v", MaxChunkDownloadWorkers) + logger.WithContext(scd.ctx).Debugf("chunks: %v, total bytes: %d", chunkMetaLen, scd.totalUncompressedSize()) scd.ChunksMutex = &sync.Mutex{} scd.DoneDownloadCond = sync.NewCond(scd.ChunksMutex) scd.Chunks = make(map[int][]chunkRowType) @@ -133,7 +133,7 @@ func (scd *snowflakeChunkDownloader) start() error { scd.ChunksError = make(chan *chunkError, MaxChunkDownloadWorkers) for i := 0; i < chunkMetaLen; i++ { chunk := scd.ChunkMetas[i] - logger.Debugf("add chunk to channel ChunksChan: %v, URL: %v, RowCount: %v, UncompressedSize: %v, ChunkResultFormat: %v", + logger.WithContext(scd.ctx).Debugf("add chunk to channel ChunksChan: %v, URL: %v, RowCount: %v, UncompressedSize: %v, ChunkResultFormat: %v", i+1, chunk.URL, chunk.RowCount, chunk.UncompressedSize, scd.QueryResultFormat) scd.ChunksChan <- i } @@ -147,11 +147,11 @@ func (scd *snowflakeChunkDownloader) start() error { func (scd *snowflakeChunkDownloader) schedule() { select { case nextIdx := <-scd.ChunksChan: - logger.Infof("schedule chunk: %v", nextIdx+1) + logger.WithContext(scd.ctx).Infof("schedule chunk: %v", nextIdx+1) go scd.FuncDownload(scd.ctx, scd, nextIdx) default: // no more download - logger.Info("no more download") + logger.WithContext(scd.ctx).Info("no more download") } } @@ -164,15 +164,15 @@ func (scd *snowflakeChunkDownloader) checkErrorRetry() (err error) { // add the index to the chunks channel so that the download will be retried. go scd.FuncDownload(scd.ctx, scd, errc.Index) scd.ChunksErrorCounter++ - logger.Warningf("chunk idx: %v, err: %v. retrying (%v/%v)...", + logger.WithContext(scd.ctx).Warningf("chunk idx: %v, err: %v. retrying (%v/%v)...", errc.Index, errc.Error, scd.ChunksErrorCounter, maxChunkDownloaderErrorCounter) } else { scd.ChunksFinalErrors = append(scd.ChunksFinalErrors, errc) - logger.Warningf("chunk idx: %v, err: %v. no further retry", errc.Index, errc.Error) + logger.WithContext(scd.ctx).Warningf("chunk idx: %v, err: %v. no further retry", errc.Index, errc.Error) return errc.Error } default: - logger.Info("no error is detected.") + logger.WithContext(scd.ctx).Info("no error is detected.") } return nil } @@ -195,7 +195,7 @@ func (scd *snowflakeChunkDownloader) next() (chunkRowType, error) { } for scd.Chunks[scd.CurrentChunkIndex] == nil { - logger.Debugf("waiting for chunk idx: %v/%v", + logger.WithContext(scd.ctx).Debugf("waiting for chunk idx: %v/%v", scd.CurrentChunkIndex+1, len(scd.ChunkMetas)) if err := scd.checkErrorRetry(); err != nil { @@ -207,7 +207,7 @@ func (scd *snowflakeChunkDownloader) next() (chunkRowType, error) { // 1) one chunk download finishes or 2) an error occurs. scd.DoneDownloadCond.Wait() } - logger.Debugf("ready: chunk %v", scd.CurrentChunkIndex+1) + logger.WithContext(scd.ctx).Debugf("ready: chunk %v", scd.CurrentChunkIndex+1) scd.CurrentChunk = scd.Chunks[scd.CurrentChunkIndex] scd.ChunksMutex.Unlock() scd.CurrentChunkSize = len(scd.CurrentChunk) @@ -216,7 +216,7 @@ func (scd *snowflakeChunkDownloader) next() (chunkRowType, error) { scd.schedule() } - logger.Debugf("no more data") + logger.WithContext(scd.ctx).Debugf("no more data") if len(scd.ChunkMetas) > 0 { close(scd.ChunksError) close(scd.ChunksChan) @@ -342,11 +342,11 @@ func (r *largeResultSetReader) Read(p []byte) (n int, err error) { } func downloadChunk(ctx context.Context, scd *snowflakeChunkDownloader, idx int) { - logger.Infof("download start chunk: %v", idx+1) + logger.WithContext(ctx).Infof("download start chunk: %v", idx+1) defer scd.DoneDownloadCond.Broadcast() if err := scd.FuncDownloadHelper(ctx, scd, idx); err != nil { - logger.Errorf( + logger.WithContext(ctx).Errorf( "failed to extract HTTP response body. URL: %v, err: %v", scd.ChunkMetas[idx].URL, err) scd.ChunksError <- &chunkError{Index: idx, Error: err} } else if scd.ctx.Err() == context.Canceled || scd.ctx.Err() == context.DeadlineExceeded { @@ -357,9 +357,9 @@ func downloadChunk(ctx context.Context, scd *snowflakeChunkDownloader, idx int) func downloadChunkHelper(ctx context.Context, scd *snowflakeChunkDownloader, idx int) error { headers := make(map[string]string) if len(scd.ChunkHeader) > 0 { - logger.Debug("chunk header is provided.") + logger.WithContext(ctx).Debug("chunk header is provided.") for k, v := range scd.ChunkHeader { - logger.Debugf("adding header: %v, value: %v", k, v) + logger.WithContext(ctx).Debugf("adding header: %v, value: %v", k, v) headers[k] = v } @@ -374,14 +374,14 @@ func downloadChunkHelper(ctx context.Context, scd *snowflakeChunkDownloader, idx } bufStream := bufio.NewReader(resp.Body) defer resp.Body.Close() - logger.Debugf("response returned chunk: %v for URL: %v", idx+1, scd.ChunkMetas[idx].URL) + logger.WithContext(ctx).Debugf("response returned chunk: %v for URL: %v", idx+1, scd.ChunkMetas[idx].URL) if resp.StatusCode != http.StatusOK { b, err := io.ReadAll(bufStream) if err != nil { return err } - logger.Infof("HTTP: %v, URL: %v, Body: %v", resp.StatusCode, scd.ChunkMetas[idx].URL, b) - logger.Infof("Header: %v", resp.Header) + logger.WithContext(ctx).Infof("HTTP: %v, URL: %v, Body: %v", resp.StatusCode, scd.ChunkMetas[idx].URL, b) + logger.WithContext(ctx).Infof("Header: %v", resp.Header) return &SnowflakeError{ Number: ErrFailedToGetChunk, SQLState: SQLStateConnectionFailure, @@ -463,7 +463,7 @@ func decodeChunk(ctx context.Context, scd *snowflakeChunkDownloader, idx int, bu return err } } - logger.Debugf( + logger.WithContext(scd.ctx).Debugf( "decoded %d rows w/ %d bytes in %s (chunk %v)", scd.ChunkMetas[idx].RowCount, scd.ChunkMetas[idx].UncompressedSize, diff --git a/connection.go b/connection.go index 01d8aba83..beb198f77 100644 --- a/connection.go +++ b/connection.go @@ -98,7 +98,7 @@ func (sc *snowflakeConn) exec( queryContext, err := buildQueryContext(sc.queryContextCache) if err != nil { - logger.Errorf("error while building query context: %v", err) + logger.WithContext(ctx).Errorf("error while building query context: %v", err) } req := execRequest{ SQLText: query, @@ -163,7 +163,7 @@ func (sc *snowflakeConn) exec( if !sc.cfg.DisableQueryContextCache && data.Data.QueryContext != nil { queryContext, err := extractQueryContext(data) if err != nil { - logger.Errorf("error while decoding query context: ", err) + logger.WithContext(ctx).Errorf("error while decoding query context: %v", err) } else { sc.queryContextCache.add(sc, queryContext.Entries...) } @@ -272,7 +272,7 @@ func (sc *snowflakeConn) Close() (err error) { if sc.cfg != nil && !sc.cfg.KeepSessionAlive { if err = sc.rest.FuncCloseSession(sc.ctx, sc.rest, sc.rest.RequestTimeout); err != nil { - logger.Error(err) + logger.WithContext(sc.ctx).Error(err) } } return nil @@ -350,7 +350,7 @@ func (sc *snowflakeConn) ExecContext( } return driver.ResultNoRows, nil } - logger.Debug("DDL") + logger.WithContext(ctx).Debug("DDL") if isStatementContext(ctx) { return &snowflakeResultNoRows{queryID: data.Data.QueryID}, nil } @@ -571,7 +571,7 @@ func (w *wrapReader) Close() error { func (asb *ArrowStreamBatch) downloadChunkStreamHelper(ctx context.Context) error { headers := make(map[string]string) if len(asb.scd.ChunkHeader) > 0 { - logger.Debug("chunk header is provided") + logger.WithContext(ctx).Debug("chunk header is provided") for k, v := range asb.scd.ChunkHeader { logger.Debugf("adding header: %v, value: %v", k, v) @@ -586,7 +586,7 @@ func (asb *ArrowStreamBatch) downloadChunkStreamHelper(ctx context.Context) erro if err != nil { return err } - logger.Debugf("response returned chunk: %v for URL: %v", asb.idx+1, asb.scd.ChunkMetas[asb.idx].URL) + logger.WithContext(ctx).Debugf("response returned chunk: %v for URL: %v", asb.idx+1, asb.scd.ChunkMetas[asb.idx].URL) if resp.StatusCode != http.StatusOK { defer resp.Body.Close() b, err := io.ReadAll(resp.Body) @@ -594,8 +594,8 @@ func (asb *ArrowStreamBatch) downloadChunkStreamHelper(ctx context.Context) erro return err } - logger.Infof("HTTP: %v, URL: %v, Body: %v", resp.StatusCode, asb.scd.ChunkMetas[asb.idx].URL, b) - logger.Infof("Header: %v", resp.Header) + logger.WithContext(ctx).Infof("HTTP: %v, URL: %v, Body: %v", resp.StatusCode, asb.scd.ChunkMetas[asb.idx].URL, b) + logger.WithContext(ctx).Infof("Header: %v", resp.Header) return &SnowflakeError{ Number: ErrFailedToGetChunk, SQLState: SQLStateConnectionFailure, diff --git a/connection_util.go b/connection_util.go index 54390522a..4ac008ea8 100644 --- a/connection_util.go +++ b/connection_util.go @@ -159,7 +159,7 @@ func (sc *snowflakeConn) populateSessionParameters(parameters []nameValueParamet v = vv } } - logger.Debugf("parameter. name: %v, value: %v", param.Name, v) + logger.WithContext(sc.ctx).Debugf("parameter. name: %v, value: %v", param.Name, v) paramsMutex.Lock() sc.cfg.Params[strings.ToLower(param.Name)] = &v paramsMutex.Unlock() @@ -288,12 +288,12 @@ func populateChunkDownloader( func (sc *snowflakeConn) setupOCSPPrivatelink(app string, host string) error { ocspCacheServer := fmt.Sprintf("http://ocsp.%v/ocsp_response_cache.json", host) - logger.Debugf("OCSP Cache Server for Privatelink: %v\n", ocspCacheServer) + logger.WithContext(sc.ctx).Debugf("OCSP Cache Server for Privatelink: %v\n", ocspCacheServer) if err := os.Setenv(cacheServerURLEnv, ocspCacheServer); err != nil { return err } ocspRetryHostTemplate := fmt.Sprintf("http://ocsp.%v/retry/", host) + "%v/%v" - logger.Debugf("OCSP Retry URL for Privatelink: %v\n", ocspRetryHostTemplate) + logger.WithContext(sc.ctx).Debugf("OCSP Retry URL for Privatelink: %v\n", ocspRetryHostTemplate) if err := os.Setenv(ocspRetryURLEnv, ocspRetryHostTemplate); err != nil { return err } diff --git a/converter.go b/converter.go index c00069736..a8d2c5692 100644 --- a/converter.go +++ b/converter.go @@ -137,7 +137,7 @@ func snowflakeTypeToGo(ctx context.Context, dbtype snowflakeType, scale int64, f return reflect.TypeOf("") } if len(fields) != 1 { - logger.Warn("Unexpected fields number: " + strconv.Itoa(len(fields))) + logger.WithContext(ctx).Warn("Unexpected fields number: " + strconv.Itoa(len(fields))) return reflect.TypeOf("") } switch getSnowflakeType(fields[0].Type) { @@ -173,7 +173,7 @@ func snowflakeTypeToGo(ctx context.Context, dbtype snowflakeType, scale int64, f } return reflect.TypeOf(map[any]any{}) } - logger.Errorf("unsupported dbtype is specified. %v", dbtype) + logger.WithContext(ctx).Errorf("unsupported dbtype is specified. %v", dbtype) return reflect.TypeOf("") } @@ -200,7 +200,7 @@ func snowflakeTypeToGoForMaps[K comparable](ctx context.Context, valueMetadata f case timeType, dateType, timestampTzType, timestampNtzType, timestampLtzType: return reflect.TypeOf(map[K]time.Time{}) } - logger.Errorf("unsupported dbtype is specified for map value") + logger.WithContext(ctx).Errorf("unsupported dbtype is specified for map value") return reflect.TypeOf("") } @@ -2084,7 +2084,7 @@ func arrowToRecord(ctx context.Context, record arrow.Record, pool memory.Allocat if col.(*array.String).IsValid(i) { stringValue := col.(*array.String).Value(i) if !utf8.ValidString(stringValue) { - logger.Error("Invalid UTF-8 characters detected while reading query response, column: ", srcColumnMeta.Name) + logger.WithContext(ctx).Error("Invalid UTF-8 characters detected while reading query response, column: ", srcColumnMeta.Name) stringValue = strings.ToValidUTF8(stringValue, "�") } tb.Append(stringValue) diff --git a/driver.go b/driver.go index 78cb69f4d..263a1394a 100644 --- a/driver.go +++ b/driver.go @@ -35,7 +35,7 @@ func (d SnowflakeDriver) OpenWithConfig(ctx context.Context, config Config) (dri if config.Tracing != "" { logger.SetLogLevel(config.Tracing) } - logger.Info("OpenWithConfig") + logger.WithContext(ctx).Info("OpenWithConfig") sc, err := buildSnowflakeConn(ctx, config) if err != nil { return nil, err diff --git a/errors.go b/errors.go index b41ad2723..73543ac30 100644 --- a/errors.go +++ b/errors.go @@ -75,7 +75,7 @@ func (se *SnowflakeError) sendExceptionTelemetry(sc *snowflakeConn, data *teleme func (se *SnowflakeError) exceptionTelemetry(sc *snowflakeConn) *SnowflakeError { data := se.generateTelemetryExceptionData() if err := se.sendExceptionTelemetry(sc, data); err != nil { - logger.Debugf("failed to log to telemetry: %v", data) + logger.WithContext(sc.ctx).Debugf("failed to log to telemetry: %v", data) } return se } diff --git a/file_transfer_agent.go b/file_transfer_agent.go index 368b5c059..7d894b3d9 100644 --- a/file_transfer_agent.go +++ b/file_transfer_agent.go @@ -686,13 +686,13 @@ func (sfa *snowflakeFileTransferAgent) upload( } if len(smallFileMetadata) > 0 { - logger.Infof("uploading %v small files", len(smallFileMetadata)) + logger.WithContext(sfa.sc.ctx).Infof("uploading %v small files", len(smallFileMetadata)) if err = sfa.uploadFilesParallel(smallFileMetadata); err != nil { return err } } if len(largeFileMetadata) > 0 { - logger.Infof("uploading %v large files", len(largeFileMetadata)) + logger.WithContext(sfa.sc.ctx).Infof("uploading %v large files", len(largeFileMetadata)) if err = sfa.uploadFilesSequential(largeFileMetadata); err != nil { return err } diff --git a/log.go b/log.go index b48294cb6..cf1774c5f 100644 --- a/log.go +++ b/log.go @@ -20,7 +20,23 @@ const SFSessionIDKey contextKey = "LOG_SESSION_ID" // SFSessionUserKey is context key of user id of a session const SFSessionUserKey contextKey = "LOG_USER" -// LogKeys these keys in context should be included in logging messages when using logger.WithContext +// map which stores a string which will be used as a log key to the function which +// will be called to get the log value out of the context +var clientLogContextHooks = map[string]ClientLogContextHook{} + +// ClientLogContextHook is a client-defined hook that can be used to insert log +// fields based on the Context. +type ClientLogContextHook func(context.Context) string + +// RegisterLogContextHook registers a hook that can be used to extract fields +// from the Context and associated with log messages using the provided key. This +// function is not thread-safe and should only be called on startup. +func RegisterLogContextHook(contextKey string, ctxExtractor ClientLogContextHook) { + clientLogContextHooks[contextKey] = ctxExtractor +} + +// LogKeys registers string-typed context keys to be written to the logs when +// logger.WithContext is used var LogKeys = [...]contextKey{SFSessionIDKey, SFSessionUserKey} // SFLogger Snowflake logger interface to expose FieldLogger defined in logrus @@ -425,5 +441,12 @@ func context2Fields(ctx context.Context) *rlog.Fields { fields[string(LogKeys[i])] = ctx.Value(LogKeys[i]) } } + + for key, hook := range clientLogContextHooks { + if value := hook(ctx); value != "" { + fields[key] = value + } + } + return &fields } diff --git a/log_test.go b/log_test.go index e4ab43b88..727bb494b 100644 --- a/log_test.go +++ b/log_test.go @@ -4,7 +4,9 @@ package gosnowflake import ( "bytes" + "context" "errors" + "fmt" "strings" "testing" "time" @@ -251,3 +253,73 @@ func TestLogLevelFunctions(t *testing.T) { t.Fatalf("unexpected output in log: %v", strbuf) } } + +type testRequestIDCtxKey struct{} + +func TestLogKeysDefault(t *testing.T) { + logger := CreateDefaultLogger() + buf := &bytes.Buffer{} + logger.SetOutput(buf) + + ctx := context.Background() + + // set the sessionID on the context to see if we have it in the logs + sessionIDContextValue := "sessionID" + ctx = context.WithValue(ctx, SFSessionIDKey, sessionIDContextValue) + + userContextValue := "madison" + ctx = context.WithValue(ctx, SFSessionUserKey, userContextValue) + + // base case (not using RegisterContextVariableToLog to add additional types ) + logger.WithContext(ctx).Info("test") + var strbuf = buf.String() + if !strings.Contains(strbuf, fmt.Sprintf("%s=%s", SFSessionIDKey, sessionIDContextValue)) { + t.Fatalf("expected that sfSessionIdKey would be in logs if logger.WithContext was used, but got: %v", strbuf) + } + if !strings.Contains(strbuf, fmt.Sprintf("%s=%s", SFSessionUserKey, userContextValue)) { + t.Fatalf("expected that SFSessionUserKey would be in logs if logger.WithContext was used, but got: %v", strbuf) + } +} + +func TestLogKeysWithRegisterContextVariableToLog(t *testing.T) { + logger := CreateDefaultLogger() + buf := &bytes.Buffer{} + logger.SetOutput(buf) + + ctx := context.Background() + + // set the sessionID on the context to see if we have it in the logs + sessionIDContextValue := "sessionID" + ctx = context.WithValue(ctx, SFSessionIDKey, sessionIDContextValue) + + userContextValue := "testUser" + ctx = context.WithValue(ctx, SFSessionUserKey, userContextValue) + + // test that RegisterContextVariableToLog works with non string keys + logKey := "REQUEST_ID" + contextIntVal := 123 + ctx = context.WithValue(ctx, testRequestIDCtxKey{}, contextIntVal) + + getRequestKeyFunc := func(ctx context.Context) string { + if requestContext, ok := ctx.Value(testRequestIDCtxKey{}).(int); ok { + return fmt.Sprint(requestContext) + } + return "" + } + + RegisterLogContextHook(logKey, getRequestKeyFunc) + + // base case (not using RegisterContextVariableToLog to add additional types ) + logger.WithContext(ctx).Info("test") + var strbuf = buf.String() + + if !strings.Contains(strbuf, fmt.Sprintf("%s=%s", SFSessionIDKey, sessionIDContextValue)) { + t.Fatalf("expected that sfSessionIdKey would be in logs if logger.WithContext and RegisterContextVariableToLog was used, but got: %v", strbuf) + } + if !strings.Contains(strbuf, fmt.Sprintf("%s=%s", SFSessionUserKey, userContextValue)) { + t.Fatalf("expected that SFSessionUserKey would be in logs if logger.WithContext and RegisterContextVariableToLog was used, but got: %v", strbuf) + } + if !strings.Contains(strbuf, fmt.Sprintf("%s=%s", logKey, fmt.Sprint(contextIntVal))) { + t.Fatalf("expected that REQUEST_ID would be in logs if logger.WithContext and RegisterContextVariableToLog was used, but got: %v", strbuf) + } +} diff --git a/multistatement.go b/multistatement.go index ce9d9910b..51a10f6bc 100644 --- a/multistatement.go +++ b/multistatement.go @@ -51,7 +51,7 @@ func (sc *snowflakeConn) handleMultiExec( if isDml(childResultType) { childData, err := sc.getQueryResultResp(ctx, resultPath) if err != nil { - logger.Errorf("error: %v", err) + logger.WithContext(ctx).Errorf("error: %v", err) return nil, err } if childData != nil && !childData.Success { diff --git a/ocsp.go b/ocsp.go index b700a416d..89c5d8ab6 100644 --- a/ocsp.go +++ b/ocsp.go @@ -382,28 +382,28 @@ func checkOCSPCacheServer( headers := make(map[string]string) res, err := newRetryHTTP(ctx, client, req, ocspServerHost, headers, totalTimeout, defaultMaxRetryCount, defaultTimeProvider, nil).execute() if err != nil { - logger.Errorf("failed to get OCSP cache from OCSP Cache Server. %v", err) + logger.WithContext(ctx).Errorf("failed to get OCSP cache from OCSP Cache Server. %v", err) return nil, &ocspStatus{ code: ocspFailedSubmit, err: err, } } defer res.Body.Close() - logger.Debugf("StatusCode from OCSP Cache Server: %v", res.StatusCode) + logger.WithContext(ctx).Debugf("StatusCode from OCSP Cache Server: %v", res.StatusCode) if res.StatusCode != http.StatusOK { return nil, &ocspStatus{ code: ocspFailedResponse, err: fmt.Errorf("HTTP code is not OK. %v: %v", res.StatusCode, res.Status), } } - logger.Debugf("reading contents") + logger.WithContext(ctx).Debugf("reading contents") dec := json.NewDecoder(res.Body) for { if err := dec.Decode(&respd); err == io.EOF { break } else if err != nil { - logger.Errorf("failed to decode OCSP cache. %v", err) + logger.WithContext(ctx).Errorf("failed to decode OCSP cache. %v", err) return nil, &ocspStatus{ code: ocspFailedExtractResponse, err: err, @@ -451,7 +451,7 @@ func retryOCSP( } } defer res.Body.Close() - logger.Debugf("StatusCode from OCSP Server: %v\n", res.StatusCode) + logger.WithContext(ctx).Debugf("StatusCode from OCSP Server: %v\n", res.StatusCode) if res.StatusCode != http.StatusOK { return ocspRes, ocspResBytes, &ocspStatus{ code: ocspFailedResponse, @@ -467,12 +467,12 @@ func retryOCSP( } ocspRes, err = ocsp.ParseResponse(ocspResBytes, issuer) if err != nil { - logger.Warnf("error when parsing ocsp response: %v", err) - logger.Warnf("performing GET fallback request to OCSP") + logger.WithContext(ctx).Warnf("error when parsing ocsp response: %v", err) + logger.WithContext(ctx).Warnf("performing GET fallback request to OCSP") return fallbackRetryOCSPToGETRequest(ctx, client, req, ocspHost, headers, issuer, totalTimeout) } - logger.Debugf("OCSP Status from server: %v", printStatus(ocspRes)) + logger.WithContext(ctx).Debugf("OCSP Status from server: %v", printStatus(ocspRes)) return ocspRes, ocspResBytes, &ocspStatus{ code: ocspSuccess, } @@ -504,7 +504,7 @@ func fallbackRetryOCSPToGETRequest( } } defer res.Body.Close() - logger.Debugf("GET fallback StatusCode from OCSP Server: %v", res.StatusCode) + logger.WithContext(ctx).Debugf("GET fallback StatusCode from OCSP Server: %v", res.StatusCode) if res.StatusCode != http.StatusOK { return ocspRes, ocspResBytes, &ocspStatus{ code: ocspFailedResponse, @@ -526,7 +526,7 @@ func fallbackRetryOCSPToGETRequest( } } - logger.Debugf("GET fallback OCSP Status from server: %v", printStatus(ocspRes)) + logger.WithContext(ctx).Debugf("GET fallback OCSP Status from server: %v", printStatus(ocspRes)) return ocspRes, ocspResBytes, &ocspStatus{ code: ocspSuccess, } @@ -558,7 +558,7 @@ func fullOCSPURL(url *url.URL) string { // getRevocationStatus checks the certificate revocation status for subject using issuer certificate. func getRevocationStatus(ctx context.Context, subject, issuer *x509.Certificate) *ocspStatus { - logger.Infof("Subject: %v, Issuer: %v", subject.Subject, issuer.Subject) + logger.WithContext(ctx).Infof("Subject: %v, Issuer: %v", subject.Subject, issuer.Subject) status, ocspReq, encodedCertID := validateWithCache(subject, issuer) if isValidOCSPStatus(status.code) { @@ -567,8 +567,8 @@ func getRevocationStatus(ctx context.Context, subject, issuer *x509.Certificate) if ocspReq == nil || encodedCertID == nil { return status } - logger.Infof("cache missed") - logger.Infof("OCSP Server: %v", subject.OCSPServer) + logger.WithContext(ctx).Infof("cache missed") + logger.WithContext(ctx).Infof("OCSP Server: %v", subject.OCSPServer) if len(subject.OCSPServer) == 0 || isTestNoOCSPURL() { return &ocspStatus{ code: ocspNoServer, @@ -607,8 +607,8 @@ func getRevocationStatus(ctx context.Context, subject, issuer *x509.Certificate) } } - logger.Debugf("Fetching OCSP response from server: %v", u) - logger.Debugf("Host in headers: %v", hostname) + logger.WithContext(ctx).Debugf("Fetching OCSP response from server: %v", u) + logger.WithContext(ctx).Debugf("Host in headers: %v", hostname) headers := make(map[string]string) headers[httpHeaderContentType] = "application/ocsp-request" diff --git a/restful.go b/restful.go index 9b69c6700..efc3fb8af 100644 --- a/restful.go +++ b/restful.go @@ -233,7 +233,7 @@ func postRestfulQueryHelper( requestID UUID, cfg *Config) ( data *execResponse, err error) { - logger.Infof("params: %v", params) + logger.WithContext(ctx).Infof("params: %v", params) params.Add(requestIDKey, requestID.String()) params.Add(requestGUIDKey, NewUUID().String()) token, _, _ := sr.TokenAccessor.GetTokens() @@ -281,7 +281,7 @@ func postRestfulQueryHelper( fullURL = sr.getFullURL(respd.Data.GetResultURL, nil) } - logger.Info("ping pong") + logger.WithContext(ctx).Info("ping pong") token, _, _ = sr.TokenAccessor.GetTokens() headers[headerAuthorizationKey] = fmt.Sprintf(headerSnowflakeToken, token) diff --git a/retry.go b/retry.go index 80ceb56b2..59d2fdf60 100644 --- a/retry.go +++ b/retry.go @@ -306,7 +306,7 @@ func (r *retryHTTP) execute() (res *http.Response, err error) { var retryReasonUpdater retryReasonUpdater for { - logger.Debugf("retry count: %v", retryCounter) + logger.WithContext(r.ctx).Debugf("retry count: %v", retryCounter) body, err := r.bodyCreator() if err != nil { return nil, err diff --git a/rows.go b/rows.go index 6637f5ae9..b70e4baeb 100644 --- a/rows.go +++ b/rows.go @@ -137,7 +137,7 @@ func (rows *snowflakeRows) Columns() []string { if err := rows.waitForAsyncQueryStatus(); err != nil { return make([]string, 0) } - logger.Debug("Rows.Columns") + logger.WithContext(rows.ctx).Debug("Rows.Columns") ret := make([]string, len(rows.ChunkDownloader.getRowType())) for i, n := 0, len(rows.ChunkDownloader.getRowType()); i < n; i++ { ret[i] = rows.ChunkDownloader.getRowType()[i].Name