Skip to content

Commit

Permalink
Fixed HTTP response not adjusted based on request
Browse files Browse the repository at this point in the history
  • Loading branch information
xiaokangwang committed Jun 3, 2020
1 parent 38e89bd commit 087a62e
Show file tree
Hide file tree
Showing 4 changed files with 322 additions and 45 deletions.
122 changes: 83 additions & 39 deletions transport/internet/headers/http/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package http
//go:generate errorgen

import (
"bufio"
"bytes"
"context"
"io"
Expand All @@ -28,6 +29,8 @@ const (

var (
ErrHeaderToLong = newError("Header too long.")

ErrHeaderMisMatch = newError("Header Mismatch.")
)

type Reader interface {
Expand All @@ -51,39 +54,83 @@ func (NoOpWriter) Write(io.Writer) error {
}

type HeaderReader struct {
req *http.Request
expectedHeader *RequestConfig
}

func (h *HeaderReader) ExpectThisRequest(expectedHeader *RequestConfig) *HeaderReader {
h.expectedHeader = expectedHeader
return h
}

func (*HeaderReader) Read(reader io.Reader) (*buf.Buffer, error) {
func (h *HeaderReader) Read(reader io.Reader) (*buf.Buffer, error) {
buffer := buf.New()
totalBytes := int32(0)
endingDetected := false

var headerBuf bytes.Buffer

for totalBytes < maxHeaderLength {
_, err := buffer.ReadFrom(reader)
if err != nil {
buffer.Release()
return nil, err
}
if n := bytes.Index(buffer.Bytes(), []byte(ENDING)); n != -1 {
headerBuf.Write(buffer.BytesRange(0, int32(n+len(ENDING))))
buffer.Advance(int32(n + len(ENDING)))
endingDetected = true
break
}
lenEnding := int32(len(ENDING))
if buffer.Len() >= lenEnding {
totalBytes += buffer.Len() - lenEnding
headerBuf.Write(buffer.BytesRange(0, buffer.Len()-lenEnding))
leftover := buffer.BytesFrom(-lenEnding)
buffer.Clear()
copy(buffer.Extend(lenEnding), leftover)
}
}
if buffer.IsEmpty() {
buffer.Release()
return nil, nil
}

if !endingDetected {
buffer.Release()
return nil, ErrHeaderToLong
}

if h.expectedHeader == nil {
if buffer.IsEmpty() {
buffer.Release()
return nil, nil
}
return buffer, nil
}

//Parse the request

if req, err := readRequest(bufio.NewReader(bytes.NewReader(headerBuf.Bytes())), false); err != nil {
return nil, err
} else {
h.req = req
}

//Check req
path := h.req.URL.Path
hasThisUri := false
for _, u := range h.expectedHeader.Uri {
if u == path {
hasThisUri = true
}
}

if hasThisUri == false {
return nil, ErrHeaderMisMatch
}

if buffer.IsEmpty() {
buffer.Release()
return nil, nil
}

return buffer, nil
}

Expand All @@ -110,25 +157,32 @@ func (w *HeaderWriter) Write(writer io.Writer) error {
type HttpConn struct {
net.Conn

readBuffer *buf.Buffer
oneTimeReader Reader
oneTimeWriter Writer
errorWriter Writer
readBuffer *buf.Buffer
oneTimeReader Reader
oneTimeWriter Writer
errorWriter Writer
errorMismatchWriter Writer
errorTooLongWriter Writer

errReason error
}

func NewHttpConn(conn net.Conn, reader Reader, writer Writer, errorWriter Writer) *HttpConn {
func NewHttpConn(conn net.Conn, reader Reader, writer Writer, errorWriter Writer, errorMismatchWriter Writer, errorTooLongWriter Writer) *HttpConn {
return &HttpConn{
Conn: conn,
oneTimeReader: reader,
oneTimeWriter: writer,
errorWriter: errorWriter,
Conn: conn,
oneTimeReader: reader,
oneTimeWriter: writer,
errorWriter: errorWriter,
errorMismatchWriter: errorMismatchWriter,
errorTooLongWriter: errorTooLongWriter,
}
}

func (c *HttpConn) Read(b []byte) (int, error) {
if c.oneTimeReader != nil {
buffer, err := c.oneTimeReader.Read(c.Conn)
if err != nil {
c.errReason = err
return 0, err
}
c.readBuffer = buffer
Expand Down Expand Up @@ -165,7 +219,16 @@ func (c *HttpConn) Close() error {
if c.oneTimeWriter != nil && c.errorWriter != nil {
// Connection is being closed but header wasn't sent. This means the client request
// is probably not valid. Sending back a server error header in this case.
c.errorWriter.Write(c.Conn)

//Write response based on error reason

if c.errReason == ErrHeaderMisMatch {
c.errorMismatchWriter.Write(c.Conn)
} else if c.errReason == ErrHeaderToLong {
c.errorTooLongWriter.Write(c.Conn)
} else {
c.errorWriter.Write(c.Conn)
}
}

return c.Conn.Close()
Expand Down Expand Up @@ -230,36 +293,17 @@ func (a HttpAuthenticator) Client(conn net.Conn) net.Conn {
if a.config.Response != nil {
writer = a.GetClientWriter()
}
return NewHttpConn(conn, reader, writer, NoOpWriter{})
return NewHttpConn(conn, reader, writer, NoOpWriter{}, NoOpWriter{}, NoOpWriter{})
}

func (a HttpAuthenticator) Server(conn net.Conn) net.Conn {
if a.config.Request == nil && a.config.Response == nil {
return conn
}
return NewHttpConn(conn, new(HeaderReader), a.GetServerWriter(), formResponseHeader(&ResponseConfig{
Version: &Version{
Value: "1.1",
},
Status: &Status{
Code: "500",
Reason: "Internal Server Error",
},
Header: []*Header{
{
Name: "Connection",
Value: []string{"close"},
},
{
Name: "Cache-Control",
Value: []string{"private"},
},
{
Name: "Content-Length",
Value: []string{"0"},
},
},
}))
return NewHttpConn(conn, new(HeaderReader).ExpectThisRequest(a.config.Request), a.GetServerWriter(),
formResponseHeader(resp400),
formResponseHeader(resp404),
formResponseHeader(resp400))
}

func NewHttpAuthenticator(ctx context.Context, config *Config) (HttpAuthenticator, error) {
Expand Down
Loading

0 comments on commit 087a62e

Please sign in to comment.