diff --git a/http/agent.go b/http/agent.go index 674f081..c2dd76f 100644 --- a/http/agent.go +++ b/http/agent.go @@ -121,7 +121,7 @@ func (a *Agent) Get(url string) (content []byte, err error) { } defer request.Body.Close() - return a.readResponse(request) + return a.readResponseToByteArray(request) } // GetRequest sends a GET request to a URL and returns the request and response @@ -156,7 +156,7 @@ func (a *Agent) Post(url string, postData []byte) (content []byte, err error) { } defer response.Body.Close() - return a.readResponse(response) + return a.readResponseToByteArray(response) } // PostRequest sends the postData in a POST request to a URL and returns the request object @@ -209,26 +209,54 @@ func (impl *defaultAgentImplementation) SendGetRequest(client *http.Client, url return response, nil } -// readResponse read an dinterpret the http request -func (a *Agent) readResponse(response *http.Response) (body []byte, err error) { +// readResponseToByteArray returns the contents of an http response as a byte array +func (a *Agent) readResponseToByteArray(response *http.Response) ([]byte, error) { + var b bytes.Buffer + if err := a.readResponse(response, &b); err != nil { + return nil, fmt.Errorf("reading") + } + return b.Bytes(), nil +} + +// readResponse reads and interprets the response to an HTTP request to an io.Writer. +// If the response status is < 200 or >= 300 and FailOnHTTPError is set, the function +// will return an error. +// +// This function will close the response body reader. +func (a *Agent) readResponse(response *http.Response, w io.Writer) (err error) { // Read the response body defer response.Body.Close() - body, err = io.ReadAll(response.Body) - if err != nil { - return nil, fmt.Errorf( - "reading the response body from %s: %w", - response.Request.URL, err, - ) + if _, err := io.Copy(w, response.Body); err != nil { + return fmt.Errorf("reading response: %w", err) } // Check the https response code if response.StatusCode < 200 || response.StatusCode >= 300 { if a.options.FailOnHTTPError { - return nil, fmt.Errorf( + return fmt.Errorf( "HTTP error %s for %s", response.Status, response.Request.URL, ) } logrus.Warnf("Got HTTP error but FailOnHTTPError not set: %s", response.Status) } - return body, err + return err +} + +// GetToWriter sends a get request and writes the response to an io.Writer +func (a *Agent) GetToWriter(w io.Writer, url string) error { + resp, err := a.AgentImplementation.SendGetRequest(a.Client(), url) + if err != nil { + return fmt.Errorf("sending GET request: %w", err) + } + + return a.readResponse(resp, w) +} + +// PostToWriter sends a request to a url and writes the response to an io.Writer +func (a *Agent) PostToWriter(w io.Writer, url string, postData []byte) error { + resp, err := a.AgentImplementation.SendPostRequest(a.Client(), url, postData, a.options.PostContentType) + if err != nil { + return fmt.Errorf("sending POST request: %w", err) + } + return a.readResponse(resp, w) } diff --git a/http/http_test.go b/http/http_test.go index ddd9224..778f532 100644 --- a/http/http_test.go +++ b/http/http_test.go @@ -88,15 +88,7 @@ func NewTestAgent() *khttp.Agent { func TestAgentPost(t *testing.T) { agent := NewTestAgent() - - resp := &http.Response{ - Status: "200 OK", - StatusCode: http.StatusOK, - Body: io.NopCloser(bytes.NewReader([]byte("hello sig-release!"))), - ContentLength: 18, - Close: true, - Request: &http.Request{}, - } + resp := getTestResponse() defer resp.Body.Close() // First simulate a successful request @@ -118,14 +110,7 @@ func TestAgentPost(t *testing.T) { func TestAgentGet(t *testing.T) { agent := NewTestAgent() - resp := &http.Response{ - Status: "200 OK", - StatusCode: http.StatusOK, - Body: io.NopCloser(bytes.NewReader([]byte("hello sig-release!"))), - ContentLength: 18, - Close: true, - Request: &http.Request{}, - } + resp := getTestResponse() defer resp.Body.Close() // First simulate a successful request @@ -144,6 +129,98 @@ func TestAgentGet(t *testing.T) { require.NotNil(t, err) } +func TestAgentGetToWriter(t *testing.T) { + agent := NewTestAgent() + for _, tc := range []struct { + n string + prepare func(*httpfakes.FakeAgentImplementation, *http.Response) + mustErr bool + }{ + { + n: "success", + prepare: func(fake *httpfakes.FakeAgentImplementation, resp *http.Response) { + fake.SendGetRequestReturns(resp, nil) + }, + }, + { + n: "fail", + prepare: func(fake *httpfakes.FakeAgentImplementation, resp *http.Response) { + fake.SendGetRequestReturns(resp, errors.New("HTTP Post error")) + }, + mustErr: true, + }, + } { + t.Run(tc.n, func(t *testing.T) { + // First simulate a successful request + fake := &httpfakes.FakeAgentImplementation{} + resp := getTestResponse() + defer resp.Body.Close() + tc.prepare(fake, resp) + var buf bytes.Buffer + + agent.SetImplementation(fake) + err := agent.GetToWriter(&buf, "http://www.example.com/") + if tc.mustErr { + require.Error(t, err) + return + } + require.NoError(t, err) + require.Equal(t, buf.Bytes(), []byte("hello sig-release!")) + }) + } +} + +func getTestResponse() *http.Response { + return &http.Response{ + Status: "200 OK", + StatusCode: http.StatusOK, + Body: io.NopCloser(bytes.NewReader([]byte("hello sig-release!"))), + ContentLength: 18, + Close: true, + Request: &http.Request{}, + } +} + +func TestAgentPostToWriter(t *testing.T) { + for _, tc := range []struct { + n string + prepare func(*httpfakes.FakeAgentImplementation, *http.Response) + mustErr bool + }{ + { + n: "success", + prepare: func(fake *httpfakes.FakeAgentImplementation, resp *http.Response) { + fake.SendPostRequestReturns(resp, nil) + }, + }, + { + n: "fail", + prepare: func(fake *httpfakes.FakeAgentImplementation, resp *http.Response) { + fake.SendPostRequestReturns(resp, errors.New("HTTP Post error")) + }, + mustErr: true, + }, + } { + t.Run(tc.n, func(t *testing.T) { + agent := NewTestAgent() + // First simulate a successful request + fake := &httpfakes.FakeAgentImplementation{} + resp := getTestResponse() + defer resp.Body.Close() + tc.prepare(fake, resp) + var buf bytes.Buffer + agent.SetImplementation(fake) + err := agent.PostToWriter(&buf, "http://www.example.com/", []byte{}) + if tc.mustErr { + require.Error(t, err) + return + } + require.NoError(t, err) + require.Equal(t, buf.Bytes(), []byte("hello sig-release!")) + }) + } +} + func TestAgentOptions(t *testing.T) { agent := NewTestAgent() fake := &httpfakes.FakeAgentImplementation{}