From 207a38df70d69be52b461ac124b818d81601c601 Mon Sep 17 00:00:00 2001 From: "M. Ryan Rigdon" Date: Wed, 10 Jan 2024 09:19:09 -0500 Subject: [PATCH] Fix issue in AcceptedError handling for UploadSarif (#3047) Fixes: #3036. --- github/code-scanning.go | 18 +++++++++++++++--- github/code-scanning_test.go | 14 ++++++++++++-- 2 files changed, 27 insertions(+), 5 deletions(-) diff --git a/github/code-scanning.go b/github/code-scanning.go index 74a7b6c9b47..a8fca98a92d 100644 --- a/github/code-scanning.go +++ b/github/code-scanning.go @@ -7,6 +7,8 @@ package github import ( "context" + "encoding/json" + "errors" "fmt" "strconv" "strings" @@ -389,11 +391,21 @@ func (s *CodeScanningService) UploadSarif(ctx context.Context, owner, repo strin return nil, nil, err } - sarifID := new(SarifID) - resp, err := s.client.Do(ctx, req, sarifID) - if err != nil { + // This will always return an error without unmarshalling the data + resp, err := s.client.Do(ctx, req, nil) + // Even though there was an error, we still return the response + // in case the caller wants to inspect it further. + // However, if the error is AcceptedError, decode it below before + // returning from this function and closing the response body. + var acceptedError *AcceptedError + if !errors.As(err, &acceptedError) { return nil, resp, err } + sarifID := new(SarifID) + decErr := json.Unmarshal(acceptedError.Raw, sarifID) + if decErr != nil { + return nil, resp, decErr + } return sarifID, resp, nil } diff --git a/github/code-scanning_test.go b/github/code-scanning_test.go index 4fbe45e7813..a081a6ba5f1 100644 --- a/github/code-scanning_test.go +++ b/github/code-scanning_test.go @@ -58,6 +58,11 @@ func TestCodeScanningService_UploadSarif(t *testing.T) { client, mux, _, teardown := setup() defer teardown() + expectedSarifID := &SarifID{ + ID: String("testid"), + URL: String("https://example.com/testurl"), + } + mux.HandleFunc("/repos/o/r/code-scanning/sarifs", func(w http.ResponseWriter, r *http.Request) { v := new(SarifAnalysis) assertNilError(t, json.NewDecoder(r.Body).Decode(v)) @@ -67,15 +72,20 @@ func TestCodeScanningService_UploadSarif(t *testing.T) { t.Errorf("Request body = %+v, want %+v", v, want) } - fmt.Fprint(w, `{"commit_sha":"abc","ref":"ref/head/main","sarif":"abc"}`) + w.WriteHeader(http.StatusAccepted) + respBody, _ := json.Marshal(expectedSarifID) + _, _ = w.Write(respBody) }) ctx := context.Background() sarifAnalysis := &SarifAnalysis{CommitSHA: String("abc"), Ref: String("ref/head/main"), Sarif: String("abc"), CheckoutURI: String("uri"), StartedAt: &Timestamp{time.Date(2006, time.January, 02, 15, 04, 05, 0, time.UTC)}, ToolName: String("codeql-cli")} - _, _, err := client.CodeScanning.UploadSarif(ctx, "o", "r", sarifAnalysis) + respSarifID, _, err := client.CodeScanning.UploadSarif(ctx, "o", "r", sarifAnalysis) if err != nil { t.Errorf("CodeScanning.UploadSarif returned error: %v", err) } + if !cmp.Equal(expectedSarifID, respSarifID) { + t.Errorf("Sarif response = %+v, want %+v", respSarifID, expectedSarifID) + } const methodName = "UploadSarif" testBadOptions(t, methodName, func() (err error) {