Skip to content

Commit

Permalink
Add options lambdaurl.WithDetectContentType and lambda.WithContextVal…
Browse files Browse the repository at this point in the history
…ue (#516)
  • Loading branch information
bmoffatt authored Dec 1, 2023
1 parent 1dca084 commit 752114b
Show file tree
Hide file tree
Showing 6 changed files with 252 additions and 31 deletions.
1 change: 1 addition & 0 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ jobs:
name: run tests
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
go:
- "1.21"
Expand Down
22 changes: 22 additions & 0 deletions lambda/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ type Handler interface {
type handlerOptions struct {
handlerFunc
baseContext context.Context
contextValues map[interface{}]interface{}
jsonRequestUseNumber bool
jsonRequestDisallowUnknownFields bool
jsonResponseEscapeHTML bool
Expand Down Expand Up @@ -50,6 +51,23 @@ func WithContext(ctx context.Context) Option {
})
}

// WithContextValue adds a value to the handler context.
// If a base context was set using WithContext, that base is used as the parent.
//
// Usage:
//
// lambda.StartWithOptions(
// func (ctx context.Context) (string, error) {
// return ctx.Value("foo"), nil
// },
// lambda.WithContextValue("foo", "bar")
// )
func WithContextValue(key interface{}, value interface{}) Option {
return Option(func(h *handlerOptions) {
h.contextValues[key] = value
})
}

// WithSetEscapeHTML sets the SetEscapeHTML argument on the underlying json encoder
//
// Usage:
Expand Down Expand Up @@ -211,13 +229,17 @@ func newHandler(handlerFunc interface{}, options ...Option) *handlerOptions {
}
h := &handlerOptions{
baseContext: context.Background(),
contextValues: map[interface{}]interface{}{},
jsonResponseEscapeHTML: false,
jsonResponseIndentPrefix: "",
jsonResponseIndentValue: "",
}
for _, option := range options {
option(h)
}
for k, v := range h.contextValues {
h.baseContext = context.WithValue(h.baseContext, k, v)
}
if h.enableSIGTERM {
enableSIGTERM(h.sigtermCallbacks)
}
Expand Down
12 changes: 7 additions & 5 deletions lambda/sigterm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"os"
"os/exec"
"path"
"strconv"
"strings"
"testing"
"time"
Expand All @@ -17,10 +18,6 @@ import (
"github.com/stretchr/testify/require"
)

const (
rieInvokeAPI = "http://localhost:8080/2015-03-31/functions/function/invocations"
)

func TestEnableSigterm(t *testing.T) {
if _, err := exec.LookPath("aws-lambda-rie"); err != nil {
t.Skipf("%v - install from https://github.com/aws/aws-lambda-runtime-interface-emulator/", err)
Expand All @@ -34,6 +31,7 @@ func TestEnableSigterm(t *testing.T) {
handlerBuild.Stdout = os.Stderr
require.NoError(t, handlerBuild.Run())

portI := 0
for name, opts := range map[string]struct {
envVars []string
assertLogs func(t *testing.T, logs string)
Expand All @@ -53,8 +51,12 @@ func TestEnableSigterm(t *testing.T) {
},
} {
t.Run(name, func(t *testing.T) {
portI += 1
addr1 := "localhost:" + strconv.Itoa(8000+portI)
addr2 := "localhost:" + strconv.Itoa(9000+portI)
rieInvokeAPI := "http://" + addr1 + "/2015-03-31/functions/function/invocations"
// run the runtime interface emulator, capture the logs for assertion
cmd := exec.Command("aws-lambda-rie", "sigterm.handler")
cmd := exec.Command("aws-lambda-rie", "--runtime-interface-emulator-address", addr1, "--runtime-api-address", addr2, "sigterm.handler")
cmd.Env = append([]string{
"PATH=" + testDir,
"AWS_LAMBDA_FUNCTION_TIMEOUT=2",
Expand Down
91 changes: 76 additions & 15 deletions lambdaurl/http_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,24 +18,76 @@ import (
"github.com/aws/aws-lambda-go/lambda"
)

type detectContentTypeContextKey struct{}

// WithDetectContentType sets the behavior of content type detection when the Content-Type header is not already provided.
// When true, the first Write call will pass the intial bytes to http.DetectContentType.
// When false, and if no Content-Type is provided, no Content-Type will be sent back to Lambda,
// and the Lambda Function URL will fallback to it's default.
//
// Note: The http.ResponseWriter passed to the handler is unbuffered.
// This may result in different Content-Type headers in the Function URL response when compared to http.ListenAndServe.
//
// Usage:
//
// lambdaurl.Start(
// http.HandlerFunc(func (w http.ResponseWriter, r *http.Request) {
// w.Write("<!DOCTYPE html><html></html>")
// }),
// lambdaurl.WithDetectContentType(true)
// )
func WithDetectContentType(detectContentType bool) lambda.Option {
return lambda.WithContextValue(detectContentTypeContextKey{}, detectContentType)
}

type httpResponseWriter struct {
detectContentType bool
header http.Header
writer io.Writer
once sync.Once
ready chan<- header
}

type header struct {
code int
header http.Header
writer io.Writer
once sync.Once
status chan<- int
}

func (w *httpResponseWriter) Header() http.Header {
if w.header == nil {
w.header = http.Header{}
}
return w.header
}

func (w *httpResponseWriter) Write(p []byte) (int, error) {
w.once.Do(func() { w.status <- http.StatusOK })
w.writeHeader(http.StatusOK, p)
return w.writer.Write(p)
}

func (w *httpResponseWriter) WriteHeader(statusCode int) {
w.once.Do(func() { w.status <- statusCode })
w.writeHeader(statusCode, nil)
}

func (w *httpResponseWriter) writeHeader(statusCode int, initialPayload []byte) {
w.once.Do(func() {
if w.detectContentType {
if w.Header().Get("Content-Type") == "" {
w.Header().Set("Content-Type", detectContentType(initialPayload))
}
}
w.ready <- header{code: statusCode, header: w.header}
})
}

func detectContentType(p []byte) string {
// http.DetectContentType returns "text/plain; charset=utf-8" for nil and zero-length byte slices.
// This is a weird behavior, since otherwise it defaults to "application/octet-stream"! So we'll do that.
// This differs from http.ListenAndServe, which set no Content-Type when the initial Flush body is empty.
if len(p) == 0 {
return "application/octet-stream"
}
return http.DetectContentType(p)
}

type requestContextKey struct{}
Expand All @@ -46,11 +98,13 @@ func RequestFromContext(ctx context.Context) (*events.LambdaFunctionURLRequest,
return req, ok
}

// Wrap converts an http.Handler into a lambda request handler.
// Wrap converts an http.Handler into a Lambda request handler.
//
// Only Lambda Function URLs configured with `InvokeMode: RESPONSE_STREAM` are supported with the returned handler.
// The response body of the handler will conform to the content-type `application/vnd.awslambda.http-integration-response`
// The response body of the handler will conform to the content-type `application/vnd.awslambda.http-integration-response`.
func Wrap(handler http.Handler) func(context.Context, *events.LambdaFunctionURLRequest) (*events.LambdaFunctionURLStreamingResponse, error) {
return func(ctx context.Context, request *events.LambdaFunctionURLRequest) (*events.LambdaFunctionURLStreamingResponse, error) {

var body io.Reader = strings.NewReader(request.Body)
if request.IsBase64Encoded {
body = base64.NewDecoder(base64.StdEncoding, body)
Expand All @@ -67,21 +121,28 @@ func Wrap(handler http.Handler) func(context.Context, *events.LambdaFunctionURLR
for k, v := range request.Headers {
httpRequest.Header.Add(k, v)
}
status := make(chan int) // Signals when it's OK to start returning the response body to Lambda
header := http.Header{}

ready := make(chan header) // Signals when it's OK to start returning the response body to Lambda
r, w := io.Pipe()
responseWriter := &httpResponseWriter{writer: w, ready: ready}
if detectContentType, ok := ctx.Value(detectContentTypeContextKey{}).(bool); ok {
responseWriter.detectContentType = detectContentType
}
go func() {
defer close(status)
defer close(ready)
defer w.Close() // TODO: recover and CloseWithError the any panic value once the runtime API client supports plumbing fatal errors through the reader
handler.ServeHTTP(&httpResponseWriter{writer: w, header: header, status: status}, httpRequest)
//nolint:errcheck
defer responseWriter.Write(nil) // force default status, headers, content type detection, if none occured during the execution of the handler
handler.ServeHTTP(responseWriter, httpRequest)
}()
header := <-ready
response := &events.LambdaFunctionURLStreamingResponse{
Body: r,
StatusCode: <-status,
StatusCode: header.code,
}
if len(header) > 0 {
response.Headers = make(map[string]string, len(header))
for k, v := range header {
if len(header.header) > 0 {
response.Headers = make(map[string]string, len(header.header))
for k, v := range header.header {
if k == "Set-Cookie" {
response.Cookies = v
} else {
Expand Down
Loading

0 comments on commit 752114b

Please sign in to comment.