Skip to content

Commit

Permalink
lib: add HTTPWithContext cel.EnvOption
Browse files Browse the repository at this point in the history
This passes the provided context.Context to all network requests. This
is an addition to the HEAD, GET and POST methods and a completion of the
intended behaviour for do_request.
  • Loading branch information
efd6 committed Dec 7, 2022
1 parent fa5b7a0 commit 2f0f287
Showing 1 changed file with 38 additions and 5 deletions.
43 changes: 38 additions & 5 deletions lib/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,12 @@ import (
// line=25&page=2"
//
func HTTP(client *http.Client, limit *rate.Limiter) cel.EnvOption {
return HTTPWithContext(context.Background(), client, limit)
}

// HTTP returns a cel.EnvOption to configure extended functions for HTTP
// requests that include a context.Context in network requests.
func HTTPWithContext(ctx context.Context, client *http.Client, limit *rate.Limiter) cel.EnvOption {
if client == nil {
client = http.DefaultClient
}
Expand All @@ -260,12 +266,14 @@ func HTTP(client *http.Client, limit *rate.Limiter) cel.EnvOption {
return cel.Lib(httpLib{
client: client,
limit: limit,
ctx: ctx,
})
}

type httpLib struct {
client *http.Client
limit *rate.Limiter
ctx context.Context
}

func (httpLib) CompileOptions() []cel.EnvOption {
Expand Down Expand Up @@ -468,7 +476,7 @@ func (l httpLib) doHead(arg ref.Val) ref.Val {
if err != nil {
return types.NewErr("%s", err)
}
resp, err := l.client.Head(string(url))
resp, err := l.head(url)
if err != nil {
return types.NewErr("%s", err)
}
Expand All @@ -479,6 +487,14 @@ func (l httpLib) doHead(arg ref.Val) ref.Val {
return types.DefaultTypeAdapter.NativeToValue(rm)
}

func (l httpLib) head(url types.String) (*http.Response, error) {
req, err := http.NewRequestWithContext(l.ctx, http.MethodHead, string(url), nil)
if err != nil {
return nil, err
}
return l.client.Do(req)
}

func (l httpLib) doGet(arg ref.Val) ref.Val {
url, ok := arg.(types.String)
if !ok {
Expand All @@ -488,7 +504,7 @@ func (l httpLib) doGet(arg ref.Val) ref.Val {
if err != nil {
return types.NewErr("%s", err)
}
resp, err := l.client.Get(string(url))
resp, err := l.get(url)
if err != nil {
return types.NewErr("%s", err)
}
Expand All @@ -499,6 +515,14 @@ func (l httpLib) doGet(arg ref.Val) ref.Val {
return types.DefaultTypeAdapter.NativeToValue(rm)
}

func (l httpLib) get(url types.String) (*http.Response, error) {
req, err := http.NewRequestWithContext(l.ctx, http.MethodGet, string(url), nil)
if err != nil {
return nil, err
}
return l.client.Do(req)
}

func newGetRequest(url ref.Val) ref.Val {
return newRequestBody(types.String("GET"), url)
}
Expand Down Expand Up @@ -532,7 +556,7 @@ func (l httpLib) doPost(args ...ref.Val) ref.Val {
if err != nil {
return types.NewErr("%s", err)
}
resp, err := l.client.Post(string(url), string(content), body)
resp, err := l.post(url, content, body)
if err != nil {
return types.NewErr("%s", err)
}
Expand All @@ -543,6 +567,15 @@ func (l httpLib) doPost(args ...ref.Val) ref.Val {
return types.DefaultTypeAdapter.NativeToValue(rm)
}

func (l httpLib) post(url, content types.String, body io.Reader) (*http.Response, error) {
req, err := http.NewRequestWithContext(l.ctx, http.MethodPost, string(url), body)
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", string(content))
return l.client.Do(req)
}

func newPostRequest(args ...ref.Val) ref.Val {
if len(args) != 3 {
return types.NewErr("no such overload for post request")
Expand Down Expand Up @@ -703,8 +736,8 @@ func (l httpLib) doRequest(arg ref.Val) ref.Val {
return types.NewErr("%s", err)
}
// Recover the context lost during serialisation to JSON.
req = req.WithContext(context.Background())
err = l.limit.Wait(context.TODO())
req = req.WithContext(l.ctx)
err = l.limit.Wait(l.ctx)
if err != nil {
return types.NewErr("%s", err)
}
Expand Down

0 comments on commit 2f0f287

Please sign in to comment.