Skip to content

Commit

Permalink
Allow caller to set log fields (#1094)
Browse files Browse the repository at this point in the history
Co-authored-by: Piotr Fus <[email protected]>
  • Loading branch information
madisonchamberlain and sfc-gh-pfus authored Apr 19, 2024
1 parent 088150c commit 9207458
Show file tree
Hide file tree
Showing 19 changed files with 208 additions and 78 deletions.
35 changes: 35 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,41 @@ Install [jq](https://stedolan.github.io/jq) so that the parameters can get parse
make test
```

## Customizing Logging Tags

If you would like to ensure that certain tags are always present in the logs, `RegisterClientLogContextHook` can be used in your init function. See example below.
```go
import (
	"context"

	"github.com/snowflakedb/gosnowflake"
)

func init() {
	// Each time the logger is used, the logs will contain a REQUEST_ID field whose value is the
	// request ID extracted from the context.
	gosnowflake.RegisterClientLogContextHook("REQUEST_ID", func(ctx context.Context) interface{} {
		return requestIdFromContext(ctx)
	})
}
```

## Setting Log Level
If you want to change the log level, `SetLogLevel` can be used in your init function like this:
```go
import "github.com/snowflakedb/gosnowflake"

func init() {
	// The following line changes the log level to debug.
	_ = gosnowflake.GetLogger().SetLogLevel("debug")
}
```
The following is a list of options you can pass in to set the level, from least to most verbose:
- `"OFF"`
- `"error"`
- `"warn"`
- `"print"`
- `"info"`
- `"debug"`
- `"trace"`


## Capturing Code Coverage

Configure your testing environment as described above and run ``make cov``. The coverage percentage will be printed on the console when the testing completes.
Expand Down
24 changes: 12 additions & 12 deletions auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ func postAuth(
params.Add(requestGUIDKey, NewUUID().String())

fullURL := sr.getFullURL(loginRequestPath, params)
logger.Infof("full URL: %v", fullURL)
logger.WithContext(ctx).Infof("full URL: %v", fullURL)
resp, err := sr.FuncAuthPost(ctx, client, fullURL, headers, bodyCreator, timeout, sr.MaxRetryCount)
if err != nil {
return nil, err
Expand All @@ -235,7 +235,7 @@ func postAuth(
var respd authResponse
err = json.NewDecoder(resp.Body).Decode(&respd)
if err != nil {
logger.Errorf("failed to decode JSON. err: %v", err)
logger.WithContext(ctx).Errorf("failed to decode JSON. err: %v", err)
return nil, err
}
return &respd, nil
Expand All @@ -260,11 +260,11 @@ func postAuth(
}
b, err := io.ReadAll(resp.Body)
if err != nil {
logger.Errorf("failed to extract HTTP response body. err: %v", err)
logger.WithContext(ctx).Errorf("failed to extract HTTP response body. err: %v", err)
return nil, err
}
logger.Infof("HTTP: %v, URL: %v, Body: %v", resp.StatusCode, fullURL, b)
logger.Infof("Header: %v", resp.Header)
logger.WithContext(ctx).Infof("HTTP: %v, URL: %v, Body: %v", resp.StatusCode, fullURL, b)
logger.WithContext(ctx).Infof("Header: %v", resp.Header)
return nil, &SnowflakeError{
Number: ErrFailedToAuth,
SQLState: SQLStateConnectionRejected,
Expand Down Expand Up @@ -293,7 +293,7 @@ func authenticate(
proofKey []byte,
) (resp *authResponseMain, err error) {
if sc.cfg.Authenticator == AuthTypeTokenAccessor {
logger.Info("Bypass authentication using existing token from token accessor")
logger.WithContext(ctx).Info("Bypass authentication using existing token from token accessor")
sessionInfo := authResponseSessionInfo{
DatabaseName: sc.cfg.Database,
SchemaName: sc.cfg.Schema,
Expand Down Expand Up @@ -350,15 +350,15 @@ func authenticate(
params.Add("roleName", sc.cfg.Role)
}

logger.WithContext(sc.ctx).Infof("PARAMS for Auth: %v, %v, %v, %v, %v, %v",
logger.WithContext(ctx).WithContext(sc.ctx).Infof("PARAMS for Auth: %v, %v, %v, %v, %v, %v",
params, sc.rest.Protocol, sc.rest.Host, sc.rest.Port, sc.rest.LoginTimeout, sc.cfg.Authenticator.String())

respd, err := sc.rest.FuncPostAuth(ctx, sc.rest, sc.rest.getClientFor(sc.cfg.Authenticator), params, headers, bodyCreator, sc.rest.LoginTimeout)
if err != nil {
return nil, err
}
if !respd.Success {
logger.Errorln("Authentication FAILED")
logger.WithContext(ctx).Errorln("Authentication FAILED")
sc.rest.TokenAccessor.SetTokens("", "", -1)
if sessionParameters[clientRequestMfaToken] == true {
deleteCredential(sc, mfaToken)
Expand All @@ -377,7 +377,7 @@ func authenticate(
Message: respd.Message,
}).exceptionTelemetry(sc)
}
logger.Info("Authentication SUCCESS")
logger.WithContext(ctx).Info("Authentication SUCCESS")
sc.rest.TokenAccessor.SetTokens(respd.Data.Token, respd.Data.MasterToken, respd.Data.SessionID)
if sessionParameters[clientRequestMfaToken] == true {
token := respd.Data.MfaToken
Expand Down Expand Up @@ -439,7 +439,7 @@ func createRequestBody(sc *snowflakeConn, sessionParameters map[string]interface
}
requestMain.Token = jwtTokenString
case AuthTypeSnowflake:
logger.Info("Username and password")
logger.WithContext(sc.ctx).Info("Username and password")
requestMain.LoginName = sc.cfg.User
requestMain.Password = sc.cfg.Password
switch {
Expand All @@ -450,7 +450,7 @@ func createRequestBody(sc *snowflakeConn, sessionParameters map[string]interface
requestMain.ExtAuthnDuoMethod = "passcode"
}
case AuthTypeUsernamePasswordMFA:
logger.Info("Username and password MFA")
logger.WithContext(sc.ctx).Info("Username and password MFA")
requestMain.LoginName = sc.cfg.User
requestMain.Password = sc.cfg.Password
if sc.cfg.MfaToken != "" {
Expand Down Expand Up @@ -527,7 +527,7 @@ func authenticateWithConfig(sc *snowflakeConn) error {
}
}

logger.Infof("Authenticating via %v", sc.cfg.Authenticator.String())
logger.WithContext(sc.ctx).Infof("Authenticating via %v", sc.cfg.Authenticator.String())
switch sc.cfg.Authenticator {
case AuthTypeExternalBrowser:
if sc.cfg.IDToken == "" {
Expand Down
4 changes: 2 additions & 2 deletions authexternalbrowser.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ func getIdpURLProofKey(
return "", "", err
}
if !respd.Success {
logger.Errorln("Authentication FAILED")
logger.WithContext(ctx).Errorln("Authentication FAILED")
sr.TokenAccessor.SetTokens("", "", -1)
code, err := strconv.Atoi(respd.Code)
if err != nil {
Expand Down Expand Up @@ -287,7 +287,7 @@ func doAuthenticateByExternalBrowser(
n, err := c.Read(b)
if err != nil {
if err != io.EOF {
logger.Infof("error reading from socket. err: %v", err)
logger.WithContext(ctx).Infof("error reading from socket. err: %v", err)
errAccept = &SnowflakeError{
Number: ErrFailedToGetExternalBrowserResponse,
SQLState: SQLStateConnectionRejected,
Expand Down
8 changes: 4 additions & 4 deletions authokta.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ func authenticateBySAML(
return nil, err
}
if !respd.Success {
logger.Errorln("Authentication FAILED")
logger.WithContext(ctx).Errorln("Authentication FAILED")
sr.TokenAccessor.SetTokens("", "", -1)
code, err := strconv.Atoi(respd.Code)
if err != nil {
Expand Down Expand Up @@ -215,7 +215,7 @@ func postAuthSAML(
params.Add(requestIDKey, getOrGenerateRequestIDFromContext(ctx).String())
fullURL := sr.getFullURL(authenticatorRequestPath, params)

logger.Infof("fullURL: %v", fullURL)
logger.WithContext(ctx).Infof("fullURL: %v", fullURL)
resp, err := sr.FuncPost(ctx, sr, fullURL, headers, body, timeout, defaultTimeProvider, nil)
if err != nil {
return nil, err
Expand Down Expand Up @@ -269,7 +269,7 @@ func postAuthOKTA(
fullURL string,
timeout time.Duration) (
data *authOKTAResponse, err error) {
logger.Infof("fullURL: %v", fullURL)
logger.WithContext(ctx).Infof("fullURL: %v", fullURL)
targetURL, err := url.Parse(fullURL)
if err != nil {
return nil, err
Expand All @@ -290,7 +290,7 @@ func postAuthOKTA(
}
_, err = io.ReadAll(resp.Body)
if err != nil {
logger.Errorf("failed to extract HTTP response body. err: %v", err)
logger.WithContext(ctx).Errorf("failed to extract HTTP response body. err: %v", err)
return nil, err
}
logger.WithContext(ctx).Infof("HTTP: %v, URL: %v", resp.StatusCode, fullURL)
Expand Down
2 changes: 1 addition & 1 deletion bind_uploader.go
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ func (bu *bindUploader) createCSVRecord(data []interface{}) []byte {
if ok {
b.WriteString(escapeForCSV(value))
} else if !reflect.ValueOf(data[i]).IsNil() {
logger.Debugf("Cannot convert value to string in createCSVRecord. value: %v", data[i])
logger.WithContext(bu.ctx).Debugf("Cannot convert value to string in createCSVRecord. value: %v", data[i])
}
}
b.WriteString("\n")
Expand Down
38 changes: 19 additions & 19 deletions chunk_downloader.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,16 +124,16 @@ func (scd *snowflakeChunkDownloader) start() error {
// start downloading chunks if exists
chunkMetaLen := len(scd.ChunkMetas)
if chunkMetaLen > 0 {
logger.Debugf("MaxChunkDownloadWorkers: %v", MaxChunkDownloadWorkers)
logger.Debugf("chunks: %v, total bytes: %d", chunkMetaLen, scd.totalUncompressedSize())
logger.WithContext(scd.ctx).Debugf("MaxChunkDownloadWorkers: %v", MaxChunkDownloadWorkers)
logger.WithContext(scd.ctx).Debugf("chunks: %v, total bytes: %d", chunkMetaLen, scd.totalUncompressedSize())
scd.ChunksMutex = &sync.Mutex{}
scd.DoneDownloadCond = sync.NewCond(scd.ChunksMutex)
scd.Chunks = make(map[int][]chunkRowType)
scd.ChunksChan = make(chan int, chunkMetaLen)
scd.ChunksError = make(chan *chunkError, MaxChunkDownloadWorkers)
for i := 0; i < chunkMetaLen; i++ {
chunk := scd.ChunkMetas[i]
logger.Debugf("add chunk to channel ChunksChan: %v, URL: %v, RowCount: %v, UncompressedSize: %v, ChunkResultFormat: %v",
logger.WithContext(scd.ctx).Debugf("add chunk to channel ChunksChan: %v, URL: %v, RowCount: %v, UncompressedSize: %v, ChunkResultFormat: %v",
i+1, chunk.URL, chunk.RowCount, chunk.UncompressedSize, scd.QueryResultFormat)
scd.ChunksChan <- i
}
Expand All @@ -147,11 +147,11 @@ func (scd *snowflakeChunkDownloader) start() error {
func (scd *snowflakeChunkDownloader) schedule() {
select {
case nextIdx := <-scd.ChunksChan:
logger.Infof("schedule chunk: %v", nextIdx+1)
logger.WithContext(scd.ctx).Infof("schedule chunk: %v", nextIdx+1)
go scd.FuncDownload(scd.ctx, scd, nextIdx)
default:
// no more download
logger.Info("no more download")
logger.WithContext(scd.ctx).Info("no more download")
}
}

Expand All @@ -164,15 +164,15 @@ func (scd *snowflakeChunkDownloader) checkErrorRetry() (err error) {
// add the index to the chunks channel so that the download will be retried.
go scd.FuncDownload(scd.ctx, scd, errc.Index)
scd.ChunksErrorCounter++
logger.Warningf("chunk idx: %v, err: %v. retrying (%v/%v)...",
logger.WithContext(scd.ctx).Warningf("chunk idx: %v, err: %v. retrying (%v/%v)...",
errc.Index, errc.Error, scd.ChunksErrorCounter, maxChunkDownloaderErrorCounter)
} else {
scd.ChunksFinalErrors = append(scd.ChunksFinalErrors, errc)
logger.Warningf("chunk idx: %v, err: %v. no further retry", errc.Index, errc.Error)
logger.WithContext(scd.ctx).Warningf("chunk idx: %v, err: %v. no further retry", errc.Index, errc.Error)
return errc.Error
}
default:
logger.Info("no error is detected.")
logger.WithContext(scd.ctx).Info("no error is detected.")
}
return nil
}
Expand All @@ -195,7 +195,7 @@ func (scd *snowflakeChunkDownloader) next() (chunkRowType, error) {
}

for scd.Chunks[scd.CurrentChunkIndex] == nil {
logger.Debugf("waiting for chunk idx: %v/%v",
logger.WithContext(scd.ctx).Debugf("waiting for chunk idx: %v/%v",
scd.CurrentChunkIndex+1, len(scd.ChunkMetas))

if err := scd.checkErrorRetry(); err != nil {
Expand All @@ -207,7 +207,7 @@ func (scd *snowflakeChunkDownloader) next() (chunkRowType, error) {
// 1) one chunk download finishes or 2) an error occurs.
scd.DoneDownloadCond.Wait()
}
logger.Debugf("ready: chunk %v", scd.CurrentChunkIndex+1)
logger.WithContext(scd.ctx).Debugf("ready: chunk %v", scd.CurrentChunkIndex+1)
scd.CurrentChunk = scd.Chunks[scd.CurrentChunkIndex]
scd.ChunksMutex.Unlock()
scd.CurrentChunkSize = len(scd.CurrentChunk)
Expand All @@ -216,7 +216,7 @@ func (scd *snowflakeChunkDownloader) next() (chunkRowType, error) {
scd.schedule()
}

logger.Debugf("no more data")
logger.WithContext(scd.ctx).Debugf("no more data")
if len(scd.ChunkMetas) > 0 {
close(scd.ChunksError)
close(scd.ChunksChan)
Expand Down Expand Up @@ -342,11 +342,11 @@ func (r *largeResultSetReader) Read(p []byte) (n int, err error) {
}

func downloadChunk(ctx context.Context, scd *snowflakeChunkDownloader, idx int) {
logger.Infof("download start chunk: %v", idx+1)
logger.WithContext(ctx).Infof("download start chunk: %v", idx+1)
defer scd.DoneDownloadCond.Broadcast()

if err := scd.FuncDownloadHelper(ctx, scd, idx); err != nil {
logger.Errorf(
logger.WithContext(ctx).Errorf(
"failed to extract HTTP response body. URL: %v, err: %v", scd.ChunkMetas[idx].URL, err)
scd.ChunksError <- &chunkError{Index: idx, Error: err}
} else if scd.ctx.Err() == context.Canceled || scd.ctx.Err() == context.DeadlineExceeded {
Expand All @@ -357,9 +357,9 @@ func downloadChunk(ctx context.Context, scd *snowflakeChunkDownloader, idx int)
func downloadChunkHelper(ctx context.Context, scd *snowflakeChunkDownloader, idx int) error {
headers := make(map[string]string)
if len(scd.ChunkHeader) > 0 {
logger.Debug("chunk header is provided.")
logger.WithContext(ctx).Debug("chunk header is provided.")
for k, v := range scd.ChunkHeader {
logger.Debugf("adding header: %v, value: %v", k, v)
logger.WithContext(ctx).Debugf("adding header: %v, value: %v", k, v)

headers[k] = v
}
Expand All @@ -374,14 +374,14 @@ func downloadChunkHelper(ctx context.Context, scd *snowflakeChunkDownloader, idx
}
bufStream := bufio.NewReader(resp.Body)
defer resp.Body.Close()
logger.Debugf("response returned chunk: %v for URL: %v", idx+1, scd.ChunkMetas[idx].URL)
logger.WithContext(ctx).Debugf("response returned chunk: %v for URL: %v", idx+1, scd.ChunkMetas[idx].URL)
if resp.StatusCode != http.StatusOK {
b, err := io.ReadAll(bufStream)
if err != nil {
return err
}
logger.Infof("HTTP: %v, URL: %v, Body: %v", resp.StatusCode, scd.ChunkMetas[idx].URL, b)
logger.Infof("Header: %v", resp.Header)
logger.WithContext(ctx).Infof("HTTP: %v, URL: %v, Body: %v", resp.StatusCode, scd.ChunkMetas[idx].URL, b)
logger.WithContext(ctx).Infof("Header: %v", resp.Header)
return &SnowflakeError{
Number: ErrFailedToGetChunk,
SQLState: SQLStateConnectionFailure,
Expand Down Expand Up @@ -463,7 +463,7 @@ func decodeChunk(ctx context.Context, scd *snowflakeChunkDownloader, idx int, bu
return err
}
}
logger.Debugf(
logger.WithContext(scd.ctx).Debugf(
"decoded %d rows w/ %d bytes in %s (chunk %v)",
scd.ChunkMetas[idx].RowCount,
scd.ChunkMetas[idx].UncompressedSize,
Expand Down
16 changes: 8 additions & 8 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ func (sc *snowflakeConn) exec(

queryContext, err := buildQueryContext(sc.queryContextCache)
if err != nil {
logger.Errorf("error while building query context: %v", err)
logger.WithContext(ctx).Errorf("error while building query context: %v", err)
}
req := execRequest{
SQLText: query,
Expand Down Expand Up @@ -163,7 +163,7 @@ func (sc *snowflakeConn) exec(
if !sc.cfg.DisableQueryContextCache && data.Data.QueryContext != nil {
queryContext, err := extractQueryContext(data)
if err != nil {
logger.Errorf("error while decoding query context: ", err)
logger.WithContext(ctx).Errorf("error while decoding query context: %v", err)
} else {
sc.queryContextCache.add(sc, queryContext.Entries...)
}
Expand Down Expand Up @@ -272,7 +272,7 @@ func (sc *snowflakeConn) Close() (err error) {

if sc.cfg != nil && !sc.cfg.KeepSessionAlive {
if err = sc.rest.FuncCloseSession(sc.ctx, sc.rest, sc.rest.RequestTimeout); err != nil {
logger.Error(err)
logger.WithContext(sc.ctx).Error(err)
}
}
return nil
Expand Down Expand Up @@ -350,7 +350,7 @@ func (sc *snowflakeConn) ExecContext(
}
return driver.ResultNoRows, nil
}
logger.Debug("DDL")
logger.WithContext(ctx).Debug("DDL")
if isStatementContext(ctx) {
return &snowflakeResultNoRows{queryID: data.Data.QueryID}, nil
}
Expand Down Expand Up @@ -571,7 +571,7 @@ func (w *wrapReader) Close() error {
func (asb *ArrowStreamBatch) downloadChunkStreamHelper(ctx context.Context) error {
headers := make(map[string]string)
if len(asb.scd.ChunkHeader) > 0 {
logger.Debug("chunk header is provided")
logger.WithContext(ctx).Debug("chunk header is provided")
for k, v := range asb.scd.ChunkHeader {
logger.Debugf("adding header: %v, value: %v", k, v)

Expand All @@ -586,16 +586,16 @@ func (asb *ArrowStreamBatch) downloadChunkStreamHelper(ctx context.Context) erro
if err != nil {
return err
}
logger.Debugf("response returned chunk: %v for URL: %v", asb.idx+1, asb.scd.ChunkMetas[asb.idx].URL)
logger.WithContext(ctx).Debugf("response returned chunk: %v for URL: %v", asb.idx+1, asb.scd.ChunkMetas[asb.idx].URL)
if resp.StatusCode != http.StatusOK {
defer resp.Body.Close()
b, err := io.ReadAll(resp.Body)
if err != nil {
return err
}

logger.Infof("HTTP: %v, URL: %v, Body: %v", resp.StatusCode, asb.scd.ChunkMetas[asb.idx].URL, b)
logger.Infof("Header: %v", resp.Header)
logger.WithContext(ctx).Infof("HTTP: %v, URL: %v, Body: %v", resp.StatusCode, asb.scd.ChunkMetas[asb.idx].URL, b)
logger.WithContext(ctx).Infof("Header: %v", resp.Header)
return &SnowflakeError{
Number: ErrFailedToGetChunk,
SQLState: SQLStateConnectionFailure,
Expand Down
Loading

0 comments on commit 9207458

Please sign in to comment.