Skip to content

Commit

Permalink
rpcenc: Test early close, add reader.MustRedirect
Browse files Browse the repository at this point in the history
  • Loading branch information
magik6k committed Jul 30, 2021
1 parent 555c402 commit e7470ed
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 24 deletions.
79 changes: 61 additions & 18 deletions lib/rpcenc/reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package rpcenc
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"io/ioutil"
Expand Down Expand Up @@ -60,8 +61,8 @@ var client = func() *http.Client {
be serialized as JSON, and sent as jsonrpc request parameter
3.1. If the reader is of type `*sealing.NullReader`, the resulting object
is `ReaderStream{ Type: "null", Info: "[base 10 number of bytes]" }`
3.2. If the reader is of type `*rpcReader`, and it wasn't read from, we
notify that rpcReader to go a different push endpoint, and return
3.2. If the reader is of type `*RpcReader`, and it wasn't read from, we
notify that RpcReader to go a different push endpoint, and return
a `ReaderStream` object like in 3.4.
3.3. In remaining cases we start a goroutine which:
3.3.1. Makes a HEAD request to the server push endpoint
Expand All @@ -73,13 +74,13 @@ var client = func() *http.Client {
`ReaderStream{ Type: "push", Info: "[UUID string]" }`
4. If the reader wasn't a NullReader, the server will receive a HEAD (or
POST in case of older clients) request to the push endpoint.
4.1. The server gets or registers an `*rpcReader` in the `readers` map.
4.1. The server gets or registers an `*RpcReader` in the `readers` map.
4.2. It waits for a request to a matching push endpoint to be opened
4.3. After the request is opened, it returns the `*rpcReader` to
4.3. After the request is opened, it returns the `*RpcReader` to
go-jsonrpc, which will pass it as the io.Reader parameter to the
rpc method implementation
4.4. If the first request made to the push endpoint was a POST, the
returned `*rpcReader` acts as a simple reader reading the POST
returned `*RpcReader` acts as a simple reader reading the POST
request body
4.5. If the first request made to the push endpoint was a HEAD
4.5.1. On the first call to Read or Close the server responds with
Expand Down Expand Up @@ -111,7 +112,7 @@ func ReaderParamEncoder(addr string) jsonrpc.Option {
}
u.Path = path.Join(u.Path, reqID.String())

rpcReader, redir := r.(*rpcReader)
rpcReader, redir := r.(*RpcReader)
if redir {
// if we have an rpc stream, redirect instead of proxying all the data
redir = rpcReader.redirect(u.String())
Expand Down Expand Up @@ -191,6 +192,7 @@ type resType int
const (
resStart resType = iota // send on first read after HEAD
resRedirect // send on redirect before first read after HEAD
resError
// done/closed = close res channel
)

Expand All @@ -199,22 +201,53 @@ type readRes struct {
meta string
}

// rpcReader watches the ReadCloser and closes the res channel when
// RpcReader watches the ReadCloser and closes the res channel when
// either: (1) the ReaderCloser fails on Read (including with a benign error
// like EOF), or (2) when Close is called.
//
// Use it be notified of terminal states, in situations where a Read failure (or
// EOF) is considered a terminal state too (besides Close).
type rpcReader struct {
postBody io.ReadCloser // nil on initial head request
next chan *rpcReader // on head will get us the postBody after sending resStart
type RpcReader struct {
postBody io.ReadCloser // nil on initial head request
next chan *RpcReader // on head will get us the postBody after sending resStart
mustRedirect bool

res chan readRes
beginOnce *sync.Once
closeOnce sync.Once
}

func (w *rpcReader) beginPost() {
var ErrHasBody = errors.New("RPCReader has body, either already read from or from a client with no redirect support")
var ErrMustRedirect = errors.New("reader can't be read directly; marked as MustRedirect")

// MustRedirect marks the reader as required to be redirected. Will make local
// calls Read fail. MUST be called before this reader is used in any goroutine.
// If the reader can't be redirected will return ErrHasBody
func (w *RpcReader) MustRedirect() error {
if w.postBody != nil {
w.closeOnce.Do(func() {
w.res <- readRes{
rt: resError,
}
close(w.res)
})

return ErrHasBody
}

w.mustRedirect = true
return nil
}

func (w *RpcReader) beginPost() {
if w.mustRedirect {
w.res <- readRes{
rt: resError,
}
close(w.res)
return
}

if w.postBody == nil {
w.res <- readRes{
rt: resStart,
Expand All @@ -228,11 +261,15 @@ func (w *rpcReader) beginPost() {
}
}

func (w *rpcReader) Read(p []byte) (int, error) {
func (w *RpcReader) Read(p []byte) (int, error) {
w.beginOnce.Do(func() {
w.beginPost()
})

if w.mustRedirect {
return 0, ErrMustRedirect
}

if w.postBody == nil {
return 0, xerrors.Errorf("reader already closed or redirected")
}
Expand All @@ -246,14 +283,18 @@ func (w *rpcReader) Read(p []byte) (int, error) {
return n, err
}

func (w *rpcReader) Close() error {
func (w *RpcReader) Close() error {
w.beginOnce.Do(func() {})
w.closeOnce.Do(func() {
close(w.res)
})
if w.postBody == nil {
return nil
}
return w.postBody.Close()
}

func (w *rpcReader) redirect(to string) bool {
func (w *RpcReader) redirect(to string) bool {
if w.postBody != nil {
return false
}
Expand All @@ -277,7 +318,7 @@ func (w *rpcReader) redirect(to string) bool {

func ReaderParamDecoder() (http.HandlerFunc, jsonrpc.ServerOption) {
var readersLk sync.Mutex
readers := map[uuid.UUID]chan *rpcReader{}
readers := map[uuid.UUID]chan *RpcReader{}

// runs on the rpc server side, called by the client before making the jsonrpc request
hnd := func(resp http.ResponseWriter, req *http.Request) {
Expand All @@ -291,12 +332,12 @@ func ReaderParamDecoder() (http.HandlerFunc, jsonrpc.ServerOption) {
readersLk.Lock()
ch, found := readers[u]
if !found {
ch = make(chan *rpcReader)
ch = make(chan *RpcReader)
readers[u] = ch
}
readersLk.Unlock()

wr := &rpcReader{
wr := &RpcReader{
res: make(chan readRes),
next: ch,
beginOnce: &sync.Once{},
Expand Down Expand Up @@ -341,6 +382,8 @@ func ReaderParamDecoder() (http.HandlerFunc, jsonrpc.ServerOption) {
http.Redirect(resp, req, res.meta, http.StatusFound)
case resStart: // responding to HEAD, request POST with reader data
resp.WriteHeader(http.StatusOK)
case resError:
resp.WriteHeader(500)
default:
log.Errorf("unknown res.rt")
resp.WriteHeader(500)
Expand Down Expand Up @@ -378,7 +421,7 @@ func ReaderParamDecoder() (http.HandlerFunc, jsonrpc.ServerOption) {
readersLk.Lock()
ch, found := readers[u]
if !found {
ch = make(chan *rpcReader)
ch = make(chan *RpcReader)
readers[u] = ch
}
readersLk.Unlock()
Expand Down
37 changes: 31 additions & 6 deletions lib/rpcenc/reader_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,22 @@ type ReaderHandler struct {
readApi func(ctx context.Context, r io.Reader) ([]byte, error)
}

func (h *ReaderHandler) ReadAllApi(ctx context.Context, r io.Reader) ([]byte, error) {
func (h *ReaderHandler) ReadAllApi(ctx context.Context, r io.Reader, mustRedir bool) ([]byte, error) {
if mustRedir {
if err := r.(*RpcReader).MustRedirect(); err != nil {
return nil, err
}
}
return h.readApi(ctx, r)
}

func (h *ReaderHandler) ReadStartAndApi(ctx context.Context, r io.Reader) ([]byte, error) {
func (h *ReaderHandler) ReadStartAndApi(ctx context.Context, r io.Reader, mustRedir bool) ([]byte, error) {
if mustRedir {
if err := r.(*RpcReader).MustRedirect(); err != nil {
return nil, err
}
}

n, err := r.Read([]byte{0})
if err != nil {
return nil, err
Expand All @@ -36,6 +47,10 @@ func (h *ReaderHandler) ReadStartAndApi(ctx context.Context, r io.Reader) ([]byt
return h.readApi(ctx, r)
}

func (h *ReaderHandler) CloseReader(ctx context.Context, r io.Reader) error {
return r.(io.Closer).Close()
}

func (h *ReaderHandler) ReadAll(ctx context.Context, r io.Reader) ([]byte, error) {
return ioutil.ReadAll(r)
}
Expand Down Expand Up @@ -133,8 +148,9 @@ func TestReaderRedirect(t *testing.T) {
}

var redirClient struct {
ReadAllApi func(ctx context.Context, r io.Reader) ([]byte, error)
ReadStartAndApi func(ctx context.Context, r io.Reader) ([]byte, error)
ReadAllApi func(ctx context.Context, r io.Reader, mustRedir bool) ([]byte, error)
ReadStartAndApi func(ctx context.Context, r io.Reader, mustRedir bool) ([]byte, error)
CloseReader func(ctx context.Context, r io.Reader) error
}

{
Expand All @@ -158,12 +174,21 @@ func TestReaderRedirect(t *testing.T) {
}

// redirect
read, err := redirClient.ReadAllApi(context.TODO(), strings.NewReader("rediracted pooooootato"))
read, err := redirClient.ReadAllApi(context.TODO(), strings.NewReader("rediracted pooooootato"), true)
require.NoError(t, err)
require.Equal(t, "rediracted pooooootato", string(read), "potatoes weren't equal")

// proxy (because we started reading locally)
read, err = redirClient.ReadStartAndApi(context.TODO(), strings.NewReader("rediracted pooooootato"))
read, err = redirClient.ReadStartAndApi(context.TODO(), strings.NewReader("rediracted pooooootato"), false)
require.NoError(t, err)
require.Equal(t, "ediracted pooooootato", string(read), "otatoes weren't equal")

// check mustredir check; proxy (because we started reading locally)
read, err = redirClient.ReadStartAndApi(context.TODO(), strings.NewReader("rediracted pooooootato"), true)
require.Error(t, err)
require.Contains(t, err.Error(), ErrMustRedirect.Error())
require.Empty(t, read)

err = redirClient.CloseReader(context.TODO(), strings.NewReader("rediracted pooooootato"))
require.NoError(t, err)
}

0 comments on commit e7470ed

Please sign in to comment.