diff --git a/go.mod b/go.mod index 0556056e..2cc14aae 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module github.com/ipfs/go-graphsync go 1.12 require ( + github.com/filecoin-project/go-data-transfer v0.0.0-20200408061858-82c58b423ca6 github.com/filecoin-project/go-fil-markets v0.0.0-20200408062434-d92f329a6428 github.com/gogo/protobuf v1.3.1 github.com/ipfs/go-block-format v0.0.2 diff --git a/go.sum b/go.sum index 1ec6e071..0b3c0bea 100644 --- a/go.sum +++ b/go.sum @@ -58,6 +58,7 @@ github.com/filecoin-project/go-amt-ipld/v2 v2.0.1-0.20200131012142-05d80eeccc5e github.com/filecoin-project/go-amt-ipld/v2 v2.0.1-0.20200131012142-05d80eeccc5e/go.mod h1:boRtQhzmxNocrMxOXo1NYn4oUc1NGvR8tEa79wApNXg= github.com/filecoin-project/go-cbor-util v0.0.0-20191219014500-08c40a1e63a2/go.mod h1:pqTiPHobNkOVM5thSRsHYjyQfq7O5QSCMhvuu9JoDlg= github.com/filecoin-project/go-crypto v0.0.0-20191218222705-effae4ea9f03/go.mod h1:+viYnvGtUTgJRdy6oaeF4MTFKAfatX071MPDPBL11EQ= +github.com/filecoin-project/go-data-transfer v0.0.0-20200408061858-82c58b423ca6 h1:CIQ7RlW7I3E+JBxfKiK0ZWO9HPSBqlI5aeA/sdwyVTc= github.com/filecoin-project/go-data-transfer v0.0.0-20200408061858-82c58b423ca6/go.mod h1:7b5/sG9Jj33aWqft8XsH8yIdxZBACqS5tx9hv4uj2Ck= github.com/filecoin-project/go-fil-commcid v0.0.0-20200208005934-2b8bd03caca5/go.mod h1:JbkIgFF/Z9BDlvrJO1FuKkaWsH673/UdFaiVS6uIHlA= github.com/filecoin-project/go-fil-markets v0.0.0-20200408062434-d92f329a6428 h1:y8P10ZwfmsKMVHrqcU1L8Bgj2q42O6LzaySI+XaogXE= diff --git a/graphsync.go b/graphsync.go index 0e35f236..12ca81e3 100644 --- a/graphsync.go +++ b/graphsync.go @@ -59,6 +59,9 @@ const ( // PartialResponse may include blocks and metadata about the in progress response // in extra. PartialResponse = ResponseStatusCode(14) + // RequestPaused indicates a request is paused and will not send any more data + // until unpaused + RequestPaused = ResponseStatusCode(15) // Success Response Codes (request terminated) @@ -157,19 +160,34 @@ type IncomingRequestHookActions interface { ValidateRequest() } +// OutgoingBlockHookActions are actions that an outgoing block hook can take to +// change the execution of a request +type OutgoingBlockHookActions interface { + SendExtensionData(ExtensionData) + TerminateWithError(error) + PauseResponse() +} + // OutgoingRequestHookActions are actions that an outgoing request hook can take -// to change the execution of this request +// to change the execution of a request type OutgoingRequestHookActions interface { UsePersistenceOption(name string) UseNodeBuilderChooser(traversal.NodeBuilderChooser) } -// OutgoingBlockHookActions are actions that an outgoing block hook can take to -// change the execution of this request -type OutgoingBlockHookActions interface { - SendExtensionData(ExtensionData) +// IncomingResponseHookActions are actions that incoming response hook can take +// to change the execution of a request +type IncomingResponseHookActions interface { TerminateWithError(error) - PauseResponse() + UpdateRequestWithExtensions(...ExtensionData) +} + +// RequestUpdatedHookActions are actions that can be taken in a request updated hook to +// change execution of the response +type RequestUpdatedHookActions interface { + TerminateWithError(error) + SendExtensionData(ExtensionData) + UnpauseResponse() } // OnIncomingRequestHook is a hook that runs each time a new request is received. @@ -180,7 +198,7 @@ type OnIncomingRequestHook func(p peer.ID, request RequestData, hookActions Inco // OnIncomingResponseHook is a hook that runs each time a new response is received. // It receives the peer that sent the response and all data about the response. // If it returns an error processing is halted and the original request is cancelled. -type OnIncomingResponseHook func(p peer.ID, responseData ResponseData) error +type OnIncomingResponseHook func(p peer.ID, responseData ResponseData, hookActions IncomingResponseHookActions) // OnOutgoingRequestHook is a hook that runs immediately prior to sending a request // It receives the peer we're sending a request to and all the data aobut the request @@ -194,6 +212,11 @@ type OnOutgoingRequestHook func(p peer.ID, request RequestData, hookActions Outg // It receives an interface for taking further action on the response type OnOutgoingBlockHook func(p peer.ID, request RequestData, block BlockData, hookActions OutgoingBlockHookActions) +// OnRequestUpdatedHook is a hook that runs when an update to a request is received +// It receives the peer we're sending to, the original request, the request update +// It receives an interface to taking further action on the response +type OnRequestUpdatedHook func(p peer.ID, request RequestData, updateRequest RequestData, hookActions RequestUpdatedHookActions) + // UnregisterHookFunc is a function call to unregister a hook that was previously registered type UnregisterHookFunc func() @@ -217,6 +240,9 @@ type GraphExchange interface { // RegisterOutgoingBlockHook adds a hook that runs every time a block is sent from a responder RegisterOutgoingBlockHook(hook OnOutgoingBlockHook) UnregisterHookFunc + // RegisterRequestUpdatedHook adds a hook that runs every time an update to a request is received + RegisterRequestUpdatedHook(hook OnRequestUpdatedHook) UnregisterHookFunc + // UnpauseResponse unpauses a response that was paused in a block hook based on peer ID and request ID UnpauseResponse(peer.ID, RequestID) error } diff --git a/impl/graphsync.go b/impl/graphsync.go index 992bfc8f..742057da 100644 --- a/impl/graphsync.go +++ b/impl/graphsync.go @@ -10,11 +10,11 @@ import ( "github.com/ipfs/go-graphsync/peermanager" "github.com/ipfs/go-graphsync/requestmanager" "github.com/ipfs/go-graphsync/requestmanager/asyncloader" + requestorhooks "github.com/ipfs/go-graphsync/requestmanager/hooks" "github.com/ipfs/go-graphsync/responsemanager" - "github.com/ipfs/go-graphsync/responsemanager/blockhooks" + responderhooks "github.com/ipfs/go-graphsync/responsemanager/hooks" "github.com/ipfs/go-graphsync/responsemanager/peerresponsemanager" "github.com/ipfs/go-graphsync/responsemanager/persistenceoptions" - "github.com/ipfs/go-graphsync/responsemanager/requesthooks" "github.com/ipfs/go-graphsync/selectorvalidator" logging "github.com/ipfs/go-log" "github.com/ipfs/go-peertaskqueue" @@ -38,8 +38,11 @@ type GraphSync struct { peerResponseManager *peerresponsemanager.PeerResponseManager peerTaskQueue *peertaskqueue.PeerTaskQueue peerManager *peermanager.PeerMessageManager - incomingRequestHooks *requesthooks.IncomingRequestHooks - outgoingBlockHooks *blockhooks.OutgoingBlockHooks + incomingRequestHooks *responderhooks.IncomingRequestHooks + outgoingBlockHooks *responderhooks.OutgoingBlockHooks + requestUpdatedHooks *responderhooks.RequestUpdatedHooks + incomingResponseHooks *requestorhooks.IncomingResponseHooks + outgoingRequestHooks *requestorhooks.OutgoingRequestHooks persistenceOptions *persistenceoptions.PersistenceOptions ctx context.Context cancel context.CancelFunc @@ -69,16 +72,19 @@ func New(parent context.Context, network gsnet.GraphSyncNetwork, } peerManager := peermanager.NewMessageManager(ctx, createMessageQueue) asyncLoader := asyncloader.New(ctx, loader, storer) - requestManager := requestmanager.New(ctx, asyncLoader) + incomingResponseHooks := requestorhooks.NewResponseHooks() + outgoingRequestHooks := requestorhooks.NewRequestHooks() + requestManager := requestmanager.New(ctx, asyncLoader, outgoingRequestHooks, incomingResponseHooks) peerTaskQueue := peertaskqueue.New() createdResponseQueue := func(ctx context.Context, p peer.ID) peerresponsemanager.PeerResponseSender { return peerresponsemanager.NewResponseSender(ctx, p, peerManager) } peerResponseManager := peerresponsemanager.New(ctx, createdResponseQueue) persistenceOptions := persistenceoptions.New() - incomingRequestHooks := requesthooks.New(persistenceOptions) - outgoingBlockHooks := blockhooks.New() - responseManager := responsemanager.New(ctx, loader, peerResponseManager, peerTaskQueue, incomingRequestHooks, outgoingBlockHooks) + incomingRequestHooks := responderhooks.NewRequestHooks(persistenceOptions) + outgoingBlockHooks := responderhooks.NewBlockHooks() + requestUpdatedHooks := responderhooks.NewUpdateHooks() + responseManager := responsemanager.New(ctx, loader, peerResponseManager, peerTaskQueue, incomingRequestHooks, outgoingBlockHooks, requestUpdatedHooks) unregisterDefaultValidator := incomingRequestHooks.Register(selectorvalidator.SelectorValidator(maxRecursionDepth)) graphSync := &GraphSync{ network: network, @@ -90,6 +96,9 @@ func New(parent context.Context, network gsnet.GraphSyncNetwork, persistenceOptions: persistenceOptions, incomingRequestHooks: incomingRequestHooks, outgoingBlockHooks: outgoingBlockHooks, + requestUpdatedHooks: requestUpdatedHooks, + incomingResponseHooks: incomingResponseHooks, + outgoingRequestHooks: outgoingRequestHooks, peerTaskQueue: peerTaskQueue, peerResponseManager: peerResponseManager, responseManager: responseManager, @@ -125,12 +134,12 @@ func (gs *GraphSync) RegisterIncomingRequestHook(hook graphsync.OnIncomingReques // RegisterIncomingResponseHook adds a hook that runs when a response is received func (gs *GraphSync) RegisterIncomingResponseHook(hook graphsync.OnIncomingResponseHook) graphsync.UnregisterHookFunc { - return gs.requestManager.RegisterResponseHook(hook) + return gs.incomingResponseHooks.Register(hook) } // RegisterOutgoingRequestHook adds a hook that runs immediately prior to sending a new request func (gs *GraphSync) RegisterOutgoingRequestHook(hook graphsync.OnOutgoingRequestHook) graphsync.UnregisterHookFunc { - return gs.requestManager.RegisterRequestHook(hook) + return gs.outgoingRequestHooks.Register(hook) } // RegisterPersistenceOption registers an alternate loader/storer combo that can be substituted for the default @@ -147,6 +156,11 @@ func (gs *GraphSync) RegisterOutgoingBlockHook(hook graphsync.OnOutgoingBlockHoo return gs.outgoingBlockHooks.Register(hook) } +// RegisterRequestUpdatedHook registers a hook that runs when an update to a request is received +func (gs *GraphSync) RegisterRequestUpdatedHook(hook graphsync.OnRequestUpdatedHook) graphsync.UnregisterHookFunc { + return gs.requestUpdatedHooks.Register(hook) +} + // UnpauseResponse unpauses a response that was paused in a block hook based on peer ID and request ID func (gs *GraphSync) UnpauseResponse(p peer.ID, requestID graphsync.RequestID) error { return gs.responseManager.UnpauseResponse(p, requestID) diff --git a/impl/graphsync_test.go b/impl/graphsync_test.go index d70d5656..0e1b8543 100644 --- a/impl/graphsync_test.go +++ b/impl/graphsync_test.go @@ -191,12 +191,11 @@ func TestGraphsyncRoundTrip(t *testing.T) { var receivedRequestData []byte requestor.RegisterIncomingResponseHook( - func(p peer.ID, responseData graphsync.ResponseData) error { + func(p peer.ID, responseData graphsync.ResponseData, hookActions graphsync.IncomingResponseHookActions) { data, has := responseData.Extension(td.extensionName) if has { receivedResponseData = data } - return nil }) responder.RegisterIncomingRequestHook(func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.IncomingRequestHookActions) { @@ -272,6 +271,65 @@ func TestPauseResume(t *testing.T) { } +func TestPauseResumeViaUpdate(t *testing.T) { + // create network + ctx := context.Background() + ctx, cancel := context.WithTimeout(ctx, 2*time.Second) + defer cancel() + td := newGsTestData(ctx, t) + + var receivedReponseData []byte + var receivedUpdateData []byte + // initialize graphsync on first node to make requests + requestor := td.GraphSyncHost1() + + requestor.RegisterIncomingResponseHook(func(p peer.ID, response graphsync.ResponseData, hookActions graphsync.IncomingResponseHookActions) { + if response.Status() == graphsync.RequestPaused { + var has bool + receivedReponseData, has = response.Extension(td.extensionName) + if has { + hookActions.UpdateRequestWithExtensions(td.extensionUpdate) + } + } + }) + + // setup receiving peer to just record message coming in + blockChainLength := 100 + blockChain := testutil.SetupBlockChain(ctx, t, td.loader2, td.storer2, 100, blockChainLength) + + // initialize graphsync on second node to response to requests + responder := td.GraphSyncHost2() + stopPoint := 50 + blocksSent := 0 + responder.RegisterOutgoingBlockHook(func(p peer.ID, requestData graphsync.RequestData, blockData graphsync.BlockData, hookActions graphsync.OutgoingBlockHookActions) { + _, has := requestData.Extension(td.extensionName) + if has { + blocksSent++ + if blocksSent == stopPoint { + hookActions.SendExtensionData(td.extensionResponse) + hookActions.PauseResponse() + } + } else { + hookActions.TerminateWithError(errors.New("should have sent extension")) + } + }) + responder.RegisterRequestUpdatedHook(func(p peer.ID, request graphsync.RequestData, update graphsync.RequestData, hookActions graphsync.RequestUpdatedHookActions) { + var has bool + receivedUpdateData, has = update.Extension(td.extensionName) + if has { + hookActions.UnpauseResponse() + } + }) + progressChan, errChan := requestor.Request(ctx, td.host2.ID(), blockChain.TipLink, blockChain.Selector(), td.extension) + + blockChain.VerifyWholeChain(ctx, progressChan) + testutil.VerifyEmptyErrors(ctx, t, errChan) + require.Len(t, td.blockStore1, blockChainLength, "did not store all blocks") + + require.Equal(t, td.extensionResponseData, receivedReponseData, "did not receive correct extension response data") + require.Equal(t, td.extensionUpdateData, receivedUpdateData, "did not receive correct extension update data") +} + func TestGraphsyncRoundTripAlternatePersistenceAndNodes(t *testing.T) { // create network ctx := context.Background() @@ -539,6 +597,8 @@ type gsTestData struct { extension graphsync.ExtensionData extensionResponseData []byte extensionResponse graphsync.ExtensionData + extensionUpdateData []byte + extensionUpdate graphsync.ExtensionData } func newGsTestData(ctx context.Context, t *testing.T) *gsTestData { @@ -571,6 +631,11 @@ func newGsTestData(ctx context.Context, t *testing.T) *gsTestData { Name: td.extensionName, Data: td.extensionResponseData, } + td.extensionUpdateData = testutil.RandomBytes(100) + td.extensionUpdate = graphsync.ExtensionData{ + Name: td.extensionName, + Data: td.extensionUpdateData, + } return td } diff --git a/message/message.go b/message/message.go index d3ddb375..d62675c8 100644 --- a/message/message.go +++ b/message/message.go @@ -74,6 +74,7 @@ type GraphSyncRequest struct { id graphsync.RequestID extensions map[string][]byte isCancel bool + isUpdate bool } // GraphSyncResponse is an struct to capture data on a response sent back @@ -110,12 +111,17 @@ func NewRequest(id graphsync.RequestID, priority graphsync.Priority, extensions ...graphsync.ExtensionData) GraphSyncRequest { - return newRequest(id, root, selector, priority, false, toExtensionsMap(extensions)) + return newRequest(id, root, selector, priority, false, false, toExtensionsMap(extensions)) } // CancelRequest request generates a request to cancel an in progress request func CancelRequest(id graphsync.RequestID) GraphSyncRequest { - return newRequest(id, cid.Cid{}, nil, 0, true, nil) + return newRequest(id, cid.Cid{}, nil, 0, true, false, nil) +} + +// UpdateRequest generates a new request to update an in progress request with the given extensions +func UpdateRequest(id graphsync.RequestID, extensions ...graphsync.ExtensionData) GraphSyncRequest { + return newRequest(id, cid.Cid{}, nil, 0, false, true, toExtensionsMap(extensions)) } func toExtensionsMap(extensions []graphsync.ExtensionData) (extensionsMap map[string][]byte) { @@ -133,6 +139,7 @@ func newRequest(id graphsync.RequestID, selector ipld.Node, priority graphsync.Priority, isCancel bool, + isUpdate bool, extensions map[string][]byte) GraphSyncRequest { return GraphSyncRequest{ id: id, @@ -140,6 +147,7 @@ func newRequest(id graphsync.RequestID, selector: selector, priority: priority, isCancel: isCancel, + isUpdate: isUpdate, extensions: extensions, } } @@ -162,15 +170,23 @@ func newResponse(requestID graphsync.RequestID, func newMessageFromProto(pbm pb.Message) (GraphSyncMessage, error) { gsm := newMsg() for _, req := range pbm.Requests { - root, err := cid.Cast(req.Root) - if err != nil { - return nil, err + var root cid.Cid + var err error + if !req.Cancel && !req.Update { + root, err = cid.Cast(req.Root) + if err != nil { + return nil, err + } } - selector, err := ipldutil.DecodeNode(req.Selector) - if err != nil { - return nil, err + + var selector ipld.Node + if !req.Cancel && !req.Update { + selector, err = ipldutil.DecodeNode(req.Selector) + if err != nil { + return nil, err + } } - gsm.AddRequest(newRequest(graphsync.RequestID(req.Id), root, selector, graphsync.Priority(req.Priority), req.Cancel, req.GetExtensions())) + gsm.AddRequest(newRequest(graphsync.RequestID(req.Id), root, selector, graphsync.Priority(req.Priority), req.Cancel, req.Update, req.GetExtensions())) } for _, res := range pbm.Responses { @@ -273,6 +289,7 @@ func (gsm *graphSyncMessage) ToProto() (*pb.Message, error) { Selector: selector, Priority: int32(request.priority), Cancel: request.isCancel, + Update: request.isUpdate, Extensions: request.extensions, }) } @@ -349,6 +366,9 @@ func (gsr GraphSyncRequest) Extension(name graphsync.ExtensionName) ([]byte, boo // IsCancel returns true if this particular request is being cancelled func (gsr GraphSyncRequest) IsCancel() bool { return gsr.isCancel } +// IsUpdate returns true if this particular request is being updated +func (gsr GraphSyncRequest) IsUpdate() bool { return gsr.isUpdate } + // RequestID returns the request ID for this response func (gsr GraphSyncResponse) RequestID() graphsync.RequestID { return gsr.requestID } diff --git a/message/message_test.go b/message/message_test.go index b47cc6d1..3cb6d0aa 100644 --- a/message/message_test.go +++ b/message/message_test.go @@ -51,6 +51,7 @@ func TestAppendingRequests(t *testing.T) { require.Equal(t, int32(id), pbRequest.Id) require.Equal(t, int32(priority), pbRequest.Priority) require.False(t, pbRequest.Cancel) + require.False(t, pbRequest.Update) require.Equal(t, root.Bytes(), pbRequest.Root) require.Equal(t, selectorEncoded, pbRequest.Selector) require.Equal(t, map[string][]byte{"graphsync/awesome": extension.Data}, pbRequest.Extensions) @@ -64,6 +65,7 @@ func TestAppendingRequests(t *testing.T) { extensionData, found = deserializedRequest.Extension(extensionName) require.Equal(t, id, deserializedRequest.ID()) require.False(t, deserializedRequest.IsCancel()) + require.False(t, deserializedRequest.IsUpdate()) require.Equal(t, priority, deserializedRequest.Priority()) require.Equal(t, root.String(), deserializedRequest.Root().String()) require.Equal(t, selector, deserializedRequest.Selector()) @@ -158,6 +160,59 @@ func TestRequestCancel(t *testing.T) { request := requests[0] require.Equal(t, id, request.ID()) require.True(t, request.IsCancel()) + + buf := new(bytes.Buffer) + err := gsm.ToNet(buf) + require.NoError(t, err, "did not serialize protobuf message") + deserialized, err := FromNet(buf) + require.NoError(t, err, "did not deserialize protobuf message") + deserializedRequests := deserialized.Requests() + require.Len(t, deserializedRequests, 1, "did not add request to deserialized message") + deserializedRequest := deserializedRequests[0] + require.Equal(t, request.ID(), deserializedRequest.ID()) + require.Equal(t, request.IsCancel(), deserializedRequest.IsCancel()) +} + +func TestRequestUpdate(t *testing.T) { + + id := graphsync.RequestID(rand.Int31()) + extensionName := graphsync.ExtensionName("graphsync/awesome") + extension := graphsync.ExtensionData{ + Name: extensionName, + Data: testutil.RandomBytes(100), + } + + gsm := New() + gsm.AddRequest(UpdateRequest(id, extension)) + + requests := gsm.Requests() + require.Len(t, requests, 1, "did not add cancel request") + request := requests[0] + require.Equal(t, id, request.ID()) + require.True(t, request.IsUpdate()) + require.False(t, request.IsCancel()) + extensionData, found := request.Extension(extensionName) + require.True(t, found) + require.Equal(t, extension.Data, extensionData) + + buf := new(bytes.Buffer) + err := gsm.ToNet(buf) + require.NoError(t, err, "did not serialize protobuf message") + deserialized, err := FromNet(buf) + require.NoError(t, err, "did not deserialize protobuf message") + + deserializedRequests := deserialized.Requests() + require.Len(t, deserializedRequests, 1, "did not add request to deserialized message") + deserializedRequest := deserializedRequests[0] + extensionData, found = deserializedRequest.Extension(extensionName) + require.Equal(t, request.ID(), deserializedRequest.ID()) + require.Equal(t, request.IsCancel(), deserializedRequest.IsCancel()) + require.Equal(t, request.IsUpdate(), deserializedRequest.IsUpdate()) + require.Equal(t, request.Priority(), deserializedRequest.Priority()) + require.Equal(t, request.Root().String(), deserializedRequest.Root().String()) + require.Equal(t, request.Selector(), deserializedRequest.Selector()) + require.True(t, found) + require.Equal(t, extension.Data, extensionData) } func TestToNetFromNetEquivalency(t *testing.T) { @@ -197,6 +252,7 @@ func TestToNetFromNetEquivalency(t *testing.T) { extensionData, found := deserializedRequest.Extension(extensionName) require.Equal(t, request.ID(), deserializedRequest.ID()) require.False(t, deserializedRequest.IsCancel()) + require.False(t, deserializedRequest.IsUpdate()) require.Equal(t, request.Priority(), deserializedRequest.Priority()) require.Equal(t, request.Root().String(), deserializedRequest.Root().String()) require.Equal(t, request.Selector(), deserializedRequest.Selector()) @@ -221,7 +277,7 @@ func TestToNetFromNetEquivalency(t *testing.T) { } for _, b := range gsm.Blocks() { - _, ok := keys[b.Cid()]; + _, ok := keys[b.Cid()] require.True(t, ok) } } diff --git a/message/pb/message.pb.go b/message/pb/message.pb.go index 257fc2bf..50dd15aa 100644 --- a/message/pb/message.pb.go +++ b/message/pb/message.pb.go @@ -33,7 +33,7 @@ func (m *Message) Reset() { *m = Message{} } func (m *Message) String() string { return proto.CompactTextString(m) } func (*Message) ProtoMessage() {} func (*Message) Descriptor() ([]byte, []int) { - return fileDescriptor_message_c5788c4e9f6c17be, []int{0} + return fileDescriptor_message_5de5dd65106cd0db, []int{0} } func (m *Message) XXX_Unmarshal(b []byte) error { return m.Unmarshal(b) @@ -97,13 +97,14 @@ type Message_Request struct { Extensions map[string][]byte `protobuf:"bytes,4,rep,name=extensions" json:"extensions,omitempty" protobuf_key:"bytes,1,opt,name=key,proto3" protobuf_val:"bytes,2,opt,name=value,proto3"` Priority int32 `protobuf:"varint,5,opt,name=priority,proto3" json:"priority,omitempty"` Cancel bool `protobuf:"varint,6,opt,name=cancel,proto3" json:"cancel,omitempty"` + Update bool `protobuf:"varint,7,opt,name=update,proto3" json:"update,omitempty"` } func (m *Message_Request) Reset() { *m = Message_Request{} } func (m *Message_Request) String() string { return proto.CompactTextString(m) } func (*Message_Request) ProtoMessage() {} func (*Message_Request) Descriptor() ([]byte, []int) { - return fileDescriptor_message_c5788c4e9f6c17be, []int{0, 0} + return fileDescriptor_message_5de5dd65106cd0db, []int{0, 0} } func (m *Message_Request) XXX_Unmarshal(b []byte) error { return m.Unmarshal(b) @@ -174,6 +175,13 @@ func (m *Message_Request) GetCancel() bool { return false } +func (m *Message_Request) GetUpdate() bool { + if m != nil { + return m.Update + } + return false +} + type Message_Response struct { Id int32 `protobuf:"varint,1,opt,name=id,proto3" json:"id,omitempty"` Status int32 `protobuf:"varint,2,opt,name=status,proto3" json:"status,omitempty"` @@ -184,7 +192,7 @@ func (m *Message_Response) Reset() { *m = Message_Response{} } func (m *Message_Response) String() string { return proto.CompactTextString(m) } func (*Message_Response) ProtoMessage() {} func (*Message_Response) Descriptor() ([]byte, []int) { - return fileDescriptor_message_c5788c4e9f6c17be, []int{0, 1} + return fileDescriptor_message_5de5dd65106cd0db, []int{0, 1} } func (m *Message_Response) XXX_Unmarshal(b []byte) error { return m.Unmarshal(b) @@ -243,7 +251,7 @@ func (m *Message_Block) Reset() { *m = Message_Block{} } func (m *Message_Block) String() string { return proto.CompactTextString(m) } func (*Message_Block) ProtoMessage() {} func (*Message_Block) Descriptor() ([]byte, []int) { - return fileDescriptor_message_c5788c4e9f6c17be, []int{0, 2} + return fileDescriptor_message_5de5dd65106cd0db, []int{0, 2} } func (m *Message_Block) XXX_Unmarshal(b []byte) error { return m.Unmarshal(b) @@ -428,6 +436,16 @@ func (m *Message_Request) MarshalTo(dAtA []byte) (int, error) { } i++ } + if m.Update { + dAtA[i] = 0x38 + i++ + if m.Update { + dAtA[i] = 1 + } else { + dAtA[i] = 0 + } + i++ + } return i, nil } @@ -586,6 +604,9 @@ func (m *Message_Request) Size() (n int) { if m.Cancel { n += 2 } + if m.Update { + n += 2 + } return n } @@ -1077,6 +1098,26 @@ func (m *Message_Request) Unmarshal(dAtA []byte) error { } } m.Cancel = bool(v != 0) + case 7: + if wireType != 0 { + return fmt.Errorf("proto: wrong wireType = %d for field Update", wireType) + } + var v int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowMessage + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + v |= (int(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + m.Update = bool(v != 0) default: iNdEx = preIndex skippy, err := skipMessage(dAtA[iNdEx:]) @@ -1522,36 +1563,37 @@ var ( ErrIntOverflowMessage = fmt.Errorf("proto: integer overflow") ) -func init() { proto.RegisterFile("message.proto", fileDescriptor_message_c5788c4e9f6c17be) } +func init() { proto.RegisterFile("message.proto", fileDescriptor_message_5de5dd65106cd0db) } -var fileDescriptor_message_c5788c4e9f6c17be = []byte{ - // 447 bytes of a gzipped FileDescriptorProto - 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xa4, 0x53, 0x5d, 0x6b, 0xd4, 0x40, - 0x14, 0xdd, 0x64, 0x37, 0xe9, 0xf6, 0x5a, 0x3f, 0x18, 0x4b, 0x19, 0xf2, 0x10, 0x17, 0x45, 0xd9, - 0x17, 0x53, 0xb1, 0x28, 0x22, 0xf4, 0x65, 0xa1, 0x08, 0xa2, 0x2f, 0x03, 0xfa, 0x9e, 0xcd, 0xde, - 0xa6, 0x43, 0xb3, 0x99, 0x38, 0x33, 0x91, 0xe6, 0x5f, 0x08, 0xfe, 0x07, 0x7f, 0x4b, 0x7d, 0xeb, - 0xa3, 0x4f, 0x22, 0xbb, 0x7f, 0x44, 0x72, 0x33, 0xc6, 0xaf, 0x52, 0x17, 0x7c, 0xbb, 0x67, 0x67, - 0xce, 0x99, 0x73, 0xee, 0xd9, 0xc0, 0xf5, 0x25, 0x1a, 0x93, 0xe6, 0x98, 0x54, 0x5a, 0x59, 0xc5, - 0x76, 0x73, 0x9d, 0x56, 0x27, 0xa6, 0x29, 0xb3, 0xa4, 0x3f, 0x98, 0x47, 0x0f, 0x73, 0x69, 0x4f, - 0xea, 0x79, 0x92, 0xa9, 0xe5, 0x7e, 0xae, 0x72, 0xb5, 0x4f, 0x97, 0xe7, 0xf5, 0x31, 0x21, 0x02, - 0x34, 0x75, 0x22, 0x77, 0x3f, 0x85, 0xb0, 0xf5, 0xba, 0x63, 0xb3, 0x47, 0x70, 0x3b, 0x53, 0xcb, - 0xaa, 0x40, 0x8b, 0x02, 0xdf, 0xd5, 0x68, 0xec, 0x2b, 0x69, 0x2c, 0xf7, 0x26, 0xde, 0x74, 0x2c, - 0x2e, 0x3b, 0x62, 0x2f, 0x60, 0xac, 0x3b, 0x68, 0xb8, 0x3f, 0x19, 0x4e, 0xaf, 0x3d, 0xbe, 0x9f, - 0x5c, 0xe6, 0x2a, 0x71, 0x4f, 0x24, 0x8e, 0x3c, 0x1b, 0x9d, 0x7f, 0xbd, 0x33, 0x10, 0x3d, 0x99, - 0xbd, 0x84, 0x6d, 0x8d, 0xa6, 0x52, 0xa5, 0x41, 0xc3, 0x87, 0xa4, 0xf4, 0xe0, 0x5f, 0x4a, 0xdd, - 0x75, 0x27, 0xf5, 0x93, 0xce, 0x0e, 0x61, 0xb4, 0x48, 0x6d, 0xca, 0x47, 0x24, 0x73, 0xef, 0x6a, - 0x99, 0x59, 0xa1, 0xb2, 0x53, 0xa7, 0x41, 0xb4, 0xe8, 0xa3, 0x0f, 0x5b, 0xce, 0x26, 0xbb, 0x01, - 0xbe, 0x5c, 0xd0, 0x02, 0x02, 0xe1, 0xcb, 0x05, 0x63, 0x30, 0xd2, 0x4a, 0x59, 0xee, 0x4f, 0xbc, - 0xe9, 0x8e, 0xa0, 0x99, 0x45, 0x30, 0x36, 0x58, 0x60, 0x66, 0x95, 0xe6, 0x43, 0xfa, 0xbd, 0xc7, - 0xec, 0x0d, 0x00, 0x9e, 0x59, 0x2c, 0x8d, 0x54, 0xa5, 0x71, 0x86, 0x9e, 0x6c, 0xb4, 0xa1, 0xe4, - 0xa8, 0xe7, 0x1d, 0x95, 0x56, 0x37, 0xe2, 0x17, 0xa1, 0xf6, 0xc9, 0x4a, 0x4b, 0xa5, 0xa5, 0x6d, - 0x78, 0x40, 0xe6, 0x7a, 0xcc, 0xf6, 0x20, 0xcc, 0xd2, 0x32, 0xc3, 0x82, 0x87, 0xd4, 0x9b, 0x43, - 0xd1, 0x21, 0xdc, 0xfc, 0x43, 0x92, 0xdd, 0x82, 0xe1, 0x29, 0x36, 0x14, 0x6f, 0x5b, 0xb4, 0x23, - 0xdb, 0x85, 0xe0, 0x7d, 0x5a, 0xd4, 0xe8, 0x02, 0x76, 0xe0, 0xb9, 0xff, 0xcc, 0x8b, 0x3e, 0x7b, - 0x30, 0xfe, 0xb1, 0xf2, 0xbf, 0xd6, 0xb2, 0x07, 0xa1, 0xb1, 0xa9, 0xad, 0x0d, 0xf1, 0x02, 0xe1, - 0x10, 0x7b, 0xfb, 0x5b, 0xfc, 0xae, 0xd6, 0xa7, 0x9b, 0xd5, 0x7a, 0x55, 0xfe, 0xff, 0xcd, 0x72, - 0x00, 0x01, 0xd5, 0xde, 0xfa, 0xae, 0x34, 0x1e, 0xcb, 0x33, 0xe2, 0xed, 0x08, 0x87, 0xda, 0x9a, - 0xe9, 0x1f, 0xe4, 0x6a, 0x6e, 0xe7, 0x19, 0x3f, 0x5f, 0xc5, 0xde, 0xc5, 0x2a, 0xf6, 0xbe, 0xad, - 0x62, 0xef, 0xc3, 0x3a, 0x1e, 0x5c, 0xac, 0xe3, 0xc1, 0x97, 0x75, 0x3c, 0x98, 0x87, 0xf4, 0x25, - 0x1d, 0x7c, 0x0f, 0x00, 0x00, 0xff, 0xff, 0x8a, 0x3b, 0xe8, 0x33, 0x9f, 0x03, 0x00, 0x00, +var fileDescriptor_message_5de5dd65106cd0db = []byte{ + // 458 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xa4, 0x53, 0xcb, 0x6e, 0xd3, 0x40, + 0x14, 0x8d, 0x9d, 0xd8, 0x49, 0x2f, 0xe5, 0xa1, 0xa1, 0xaa, 0x46, 0x5e, 0x98, 0x08, 0x04, 0xca, + 0x06, 0x17, 0x51, 0x81, 0x10, 0x52, 0x37, 0x91, 0x2a, 0x24, 0x04, 0x9b, 0x91, 0x60, 0xef, 0x38, + 0xb7, 0xae, 0x55, 0xc7, 0x63, 0x66, 0xc6, 0xa8, 0xfe, 0x0b, 0xfe, 0x83, 0x7f, 0x60, 0x5d, 0x76, + 0x5d, 0xb2, 0x42, 0x28, 0xf9, 0x11, 0xe4, 0x3b, 0x83, 0x79, 0x55, 0xa5, 0x12, 0xbb, 0x7b, 0xee, + 0xcc, 0x39, 0xf7, 0x71, 0x66, 0xe0, 0xfa, 0x0a, 0xb5, 0x4e, 0x73, 0x4c, 0x6a, 0x25, 0x8d, 0x64, + 0x3b, 0xb9, 0x4a, 0xeb, 0x63, 0xdd, 0x56, 0x59, 0xd2, 0x1f, 0x2c, 0xa2, 0x87, 0x79, 0x61, 0x8e, + 0x9b, 0x45, 0x92, 0xc9, 0xd5, 0x5e, 0x2e, 0x73, 0xb9, 0x47, 0x97, 0x17, 0xcd, 0x11, 0x21, 0x02, + 0x14, 0x59, 0x91, 0xbb, 0x9f, 0x42, 0x18, 0xbf, 0xb6, 0x6c, 0xf6, 0x08, 0x6e, 0x67, 0x72, 0x55, + 0x97, 0x68, 0x50, 0xe0, 0xbb, 0x06, 0xb5, 0x79, 0x55, 0x68, 0xc3, 0xbd, 0xa9, 0x37, 0x9b, 0x88, + 0x8b, 0x8e, 0xd8, 0x0b, 0x98, 0x28, 0x0b, 0x35, 0xf7, 0xa7, 0xc3, 0xd9, 0xb5, 0xc7, 0xf7, 0x93, + 0x8b, 0xba, 0x4a, 0x5c, 0x89, 0xc4, 0x91, 0xe7, 0xa3, 0xb3, 0xaf, 0x77, 0x06, 0xa2, 0x27, 0xb3, + 0x97, 0xb0, 0xa5, 0x50, 0xd7, 0xb2, 0xd2, 0xa8, 0xf9, 0x90, 0x94, 0x1e, 0xfc, 0x4b, 0xc9, 0x5e, + 0x77, 0x52, 0x3f, 0xe9, 0xec, 0x00, 0x46, 0xcb, 0xd4, 0xa4, 0x7c, 0x44, 0x32, 0xf7, 0x2e, 0x97, + 0x99, 0x97, 0x32, 0x3b, 0x71, 0x1a, 0x44, 0x8b, 0x3e, 0xfa, 0x30, 0x76, 0x6d, 0xb2, 0x1b, 0xe0, + 0x17, 0x4b, 0x5a, 0x40, 0x20, 0xfc, 0x62, 0xc9, 0x18, 0x8c, 0x94, 0x94, 0x86, 0xfb, 0x53, 0x6f, + 0xb6, 0x2d, 0x28, 0x66, 0x11, 0x4c, 0x34, 0x96, 0x98, 0x19, 0xa9, 0xf8, 0x90, 0xf2, 0x3d, 0x66, + 0x6f, 0x00, 0xf0, 0xd4, 0x60, 0xa5, 0x0b, 0x59, 0x69, 0xd7, 0xd0, 0x93, 0x2b, 0x6d, 0x28, 0x39, + 0xec, 0x79, 0x87, 0x95, 0x51, 0xad, 0xf8, 0x45, 0xa8, 0x2b, 0x59, 0xab, 0x42, 0xaa, 0xc2, 0xb4, + 0x3c, 0xa0, 0xe6, 0x7a, 0xcc, 0x76, 0x21, 0xcc, 0xd2, 0x2a, 0xc3, 0x92, 0x87, 0xe4, 0x9b, 0x43, + 0x5d, 0xbe, 0xa9, 0x97, 0xa9, 0x41, 0x3e, 0xb6, 0x79, 0x8b, 0xa2, 0x03, 0xb8, 0xf9, 0x47, 0x29, + 0x76, 0x0b, 0x86, 0x27, 0xd8, 0xd2, 0xd8, 0x5b, 0xa2, 0x0b, 0xd9, 0x0e, 0x04, 0xef, 0xd3, 0xb2, + 0x41, 0x37, 0xb8, 0x05, 0xcf, 0xfd, 0x67, 0x5e, 0xf4, 0xd9, 0x83, 0xc9, 0x0f, 0x2b, 0xfe, 0x5a, + 0xd7, 0x2e, 0x84, 0xda, 0xa4, 0xa6, 0xd1, 0xc4, 0x0b, 0x84, 0x43, 0xec, 0xed, 0x6f, 0x6b, 0xb1, + 0x76, 0x3f, 0xbd, 0x9a, 0xdd, 0x97, 0xed, 0xe5, 0x7f, 0x67, 0xd9, 0x87, 0x80, 0x9e, 0x43, 0xd7, + 0x77, 0xad, 0xf0, 0xa8, 0x38, 0x25, 0xde, 0xb6, 0x70, 0xa8, 0xb3, 0x9f, 0x5e, 0x96, 0xb3, 0xbf, + 0x8b, 0xe7, 0xfc, 0x6c, 0x1d, 0x7b, 0xe7, 0xeb, 0xd8, 0xfb, 0xb6, 0x8e, 0xbd, 0x0f, 0x9b, 0x78, + 0x70, 0xbe, 0x89, 0x07, 0x5f, 0x36, 0xf1, 0x60, 0x11, 0xd2, 0x0f, 0xdb, 0xff, 0x1e, 0x00, 0x00, + 0xff, 0xff, 0xaa, 0x5e, 0x01, 0x7e, 0xb7, 0x03, 0x00, 0x00, } diff --git a/message/pb/message.proto b/message/pb/message.proto index 173db223..7aba3c13 100644 --- a/message/pb/message.proto +++ b/message/pb/message.proto @@ -13,6 +13,7 @@ message Message { map extensions = 4; // aux information. useful for other protocols int32 priority = 5; // the priority (normalized). default to 1 bool cancel = 6; // whether this cancels a request + bool update = 7; // whether this requests resumes a previous request } message Response { diff --git a/requestmanager/hooks/hooks_test.go b/requestmanager/hooks/hooks_test.go new file mode 100644 index 00000000..d36d65f7 --- /dev/null +++ b/requestmanager/hooks/hooks_test.go @@ -0,0 +1,182 @@ +package hooks_test + +import ( + "errors" + "math/rand" + "testing" + + "github.com/ipfs/go-graphsync" + gsmsg "github.com/ipfs/go-graphsync/message" + "github.com/ipfs/go-graphsync/requestmanager/hooks" + "github.com/ipfs/go-graphsync/testutil" + "github.com/ipld/go-ipld-prime" + ipldfree "github.com/ipld/go-ipld-prime/impl/free" + "github.com/ipld/go-ipld-prime/traversal/selector/builder" + peer "github.com/libp2p/go-libp2p-core/peer" + "github.com/stretchr/testify/require" +) + +func TestRequestHookProcessing(t *testing.T) { + fakeChooser := func(ipld.Link, ipld.LinkContext) (ipld.NodeBuilder, error) { + return ipldfree.NodeBuilder(), nil + } + extensionData := testutil.RandomBytes(100) + extensionName := graphsync.ExtensionName("AppleSauce/McGee") + extension := graphsync.ExtensionData{ + Name: extensionName, + Data: extensionData, + } + + root := testutil.GenerateCids(1)[0] + requestID := graphsync.RequestID(rand.Int31()) + ssb := builder.NewSelectorSpecBuilder(ipldfree.NodeBuilder()) + request := gsmsg.NewRequest(requestID, root, ssb.Matcher().Node(), graphsync.Priority(0), extension) + p := testutil.GeneratePeers(1)[0] + testCases := map[string]struct { + configure func(t *testing.T, hooks *hooks.OutgoingRequestHooks) + assert func(t *testing.T, result hooks.RequestResult) + }{ + "no hooks": { + assert: func(t *testing.T, result hooks.RequestResult) { + require.Nil(t, result.CustomChooser) + require.Empty(t, result.PersistenceOption) + }, + }, + "hooks alter chooser": { + configure: func(t *testing.T, hooks *hooks.OutgoingRequestHooks) { + hooks.Register(func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.OutgoingRequestHookActions) { + if _, found := requestData.Extension(extensionName); found { + hookActions.UseNodeBuilderChooser(fakeChooser) + } + }) + }, + assert: func(t *testing.T, result hooks.RequestResult) { + require.NotNil(t, result.CustomChooser) + require.Empty(t, result.PersistenceOption) + }, + }, + "hooks alter persistence option": { + configure: func(t *testing.T, hooks *hooks.OutgoingRequestHooks) { + hooks.Register(func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.OutgoingRequestHookActions) { + if _, found := requestData.Extension(extensionName); found { + hookActions.UsePersistenceOption("chainstore") + } + }) + }, + assert: func(t *testing.T, result hooks.RequestResult) { + require.Nil(t, result.CustomChooser) + require.Equal(t, "chainstore", result.PersistenceOption) + }, + }, + "hooks unregistered": { + configure: func(t *testing.T, hooks *hooks.OutgoingRequestHooks) { + unregister := hooks.Register(func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.OutgoingRequestHookActions) { + if _, found := requestData.Extension(extensionName); found { + hookActions.UsePersistenceOption("chainstore") + } + }) + unregister() + }, + assert: func(t *testing.T, result hooks.RequestResult) { + require.Nil(t, result.CustomChooser) + require.Empty(t, result.PersistenceOption) + }, + }, + } + for testCase, data := range testCases { + t.Run(testCase, func(t *testing.T) { + hooks := hooks.NewRequestHooks() + if data.configure != nil { + data.configure(t, hooks) + } + result := hooks.ProcessRequestHooks(p, request) + if data.assert != nil { + data.assert(t, result) + } + }) + } +} + +func TestResponseHookProcessing(t *testing.T) { + + extensionResponseData := testutil.RandomBytes(100) + extensionName := graphsync.ExtensionName("AppleSauce/McGee") + extensionResponse := graphsync.ExtensionData{ + Name: extensionName, + Data: extensionResponseData, + } + extensionUpdateData := testutil.RandomBytes(100) + extensionUpdate := graphsync.ExtensionData{ + Name: extensionName, + Data: extensionUpdateData, + } + requestID := graphsync.RequestID(rand.Int31()) + response := gsmsg.NewResponse(requestID, graphsync.PartialResponse, extensionResponse) + + p := testutil.GeneratePeers(1)[0] + testCases := map[string]struct { + configure func(t *testing.T, hooks *hooks.IncomingResponseHooks) + assert func(t *testing.T, result hooks.ResponseResult) + }{ + "no hooks": { + assert: func(t *testing.T, result hooks.ResponseResult) { + require.Empty(t, result.Extensions) + require.NoError(t, result.Err) + }, + }, + "short circuit on error": { + configure: func(t *testing.T, hooks *hooks.IncomingResponseHooks) { + hooks.Register(func(p peer.ID, responseData graphsync.ResponseData, hookActions graphsync.IncomingResponseHookActions) { + hookActions.TerminateWithError(errors.New("something went wrong")) + }) + hooks.Register(func(p peer.ID, responseData graphsync.ResponseData, hookActions graphsync.IncomingResponseHookActions) { + hookActions.UpdateRequestWithExtensions(extensionUpdate) + }) + }, + assert: func(t *testing.T, result hooks.ResponseResult) { + require.Empty(t, result.Extensions) + require.EqualError(t, result.Err, "something went wrong") + }, + }, + "hooks update with extensions": { + configure: func(t *testing.T, hooks *hooks.IncomingResponseHooks) { + hooks.Register(func(p peer.ID, responseData graphsync.ResponseData, hookActions graphsync.IncomingResponseHookActions) { + if _, found := responseData.Extension(extensionName); found { + hookActions.UpdateRequestWithExtensions(extensionUpdate) + } + }) + }, + assert: func(t *testing.T, result hooks.ResponseResult) { + require.Len(t, result.Extensions, 1) + require.Equal(t, extensionUpdate, result.Extensions[0]) + require.NoError(t, result.Err) + }, + }, + "hooks unregistered": { + configure: func(t *testing.T, hooks *hooks.IncomingResponseHooks) { + unregister := hooks.Register(func(p peer.ID, responseData graphsync.ResponseData, hookActions graphsync.IncomingResponseHookActions) { + if _, found := responseData.Extension(extensionName); found { + hookActions.UpdateRequestWithExtensions(extensionUpdate) + } + }) + unregister() + }, + assert: func(t *testing.T, result hooks.ResponseResult) { + require.Empty(t, result.Extensions) + require.NoError(t, result.Err) + }, + }, + } + for testCase, data := range testCases { + t.Run(testCase, func(t *testing.T) { + hooks := hooks.NewResponseHooks() + if data.configure != nil { + data.configure(t, hooks) + } + result := hooks.ProcessResponseHooks(p, response) + if data.assert != nil { + data.assert(t, result) + } + }) + } +} diff --git a/requestmanager/hooks/requesthooks.go b/requestmanager/hooks/requesthooks.go new file mode 100644 index 00000000..e9533f6b --- /dev/null +++ b/requestmanager/hooks/requesthooks.go @@ -0,0 +1,82 @@ +package hooks + +import ( + "sync" + + "github.com/ipfs/go-graphsync" + "github.com/ipld/go-ipld-prime/traversal" + peer "github.com/libp2p/go-libp2p-core/peer" +) + +type requestHook struct { + key uint64 + hook graphsync.OnOutgoingRequestHook +} + +// OutgoingRequestHooks is a set of incoming request hooks that can be processed +type OutgoingRequestHooks struct { + nextKey uint64 + hooksLk sync.RWMutex + hooks []requestHook +} + +// NewRequestHooks returns a new list of incoming request hooks +func NewRequestHooks() *OutgoingRequestHooks { + return &OutgoingRequestHooks{} +} + +// Register registers an extension to process outgoing requests +func (orh *OutgoingRequestHooks) Register(hook graphsync.OnOutgoingRequestHook) graphsync.UnregisterHookFunc { + orh.hooksLk.Lock() + rh := requestHook{orh.nextKey, hook} + orh.nextKey++ + orh.hooks = append(orh.hooks, rh) + orh.hooksLk.Unlock() + return func() { + orh.hooksLk.Lock() + defer orh.hooksLk.Unlock() + for i, matchHook := range orh.hooks { + if rh.key == matchHook.key { + orh.hooks = append(orh.hooks[:i], orh.hooks[i+1:]...) + return + } + } + } +} + +// RequestResult is the outcome of running requesthooks +type RequestResult struct { + PersistenceOption string + CustomChooser traversal.NodeBuilderChooser +} + +// ProcessRequestHooks runs request hooks against an outgoing request +func (orh *OutgoingRequestHooks) ProcessRequestHooks(p peer.ID, request graphsync.RequestData) RequestResult { + orh.hooksLk.RLock() + defer orh.hooksLk.RUnlock() + rha := &requestHookActions{} + for _, requestHook := range orh.hooks { + requestHook.hook(p, request, rha) + } + return rha.result() +} + +type requestHookActions struct { + persistenceOption string + nodeBuilderChooser traversal.NodeBuilderChooser +} + +func (rha *requestHookActions) result() RequestResult { + return RequestResult{ + PersistenceOption: rha.persistenceOption, + CustomChooser: rha.nodeBuilderChooser, + } +} + +func (rha *requestHookActions) UsePersistenceOption(name string) { + rha.persistenceOption = name +} + +func (rha *requestHookActions) UseNodeBuilderChooser(nodeBuilderChooser traversal.NodeBuilderChooser) { + rha.nodeBuilderChooser = nodeBuilderChooser +} diff --git a/requestmanager/hooks/responsehooks.go b/requestmanager/hooks/responsehooks.go new file mode 100644 index 00000000..fad43068 --- /dev/null +++ b/requestmanager/hooks/responsehooks.go @@ -0,0 +1,89 @@ +package hooks + +import ( + "sync" + + "github.com/libp2p/go-libp2p-core/peer" + + "github.com/ipfs/go-graphsync" +) + +type responseHook struct { + key uint64 + hook graphsync.OnIncomingResponseHook +} + +// IncomingResponseHooks is a set of incoming response hooks that can be processed +type IncomingResponseHooks struct { + nextKey uint64 + hooksLk sync.RWMutex + hooks []responseHook +} + +// NewResponseHooks returns a new list of incoming request hooks +func NewResponseHooks() *IncomingResponseHooks { + return &IncomingResponseHooks{} +} + +// Register registers an extension to process incoming responses +func (irh *IncomingResponseHooks) Register(hook graphsync.OnIncomingResponseHook) graphsync.UnregisterHookFunc { + irh.hooksLk.Lock() + rh := responseHook{irh.nextKey, hook} + irh.nextKey++ + irh.hooks = append(irh.hooks, rh) + irh.hooksLk.Unlock() + return func() { + irh.hooksLk.Lock() + defer irh.hooksLk.Unlock() + for i, matchHook := range irh.hooks { + if rh.key == matchHook.key { + irh.hooks = append(irh.hooks[:i], irh.hooks[i+1:]...) + return + } + } + } +} + +// ResponseResult is the outcome of running response hooks +type ResponseResult struct { + Err error + Extensions []graphsync.ExtensionData +} + +// ProcessResponseHooks runs response hooks against an incoming response +func (irh *IncomingResponseHooks) ProcessResponseHooks(p peer.ID, response graphsync.ResponseData) ResponseResult { + irh.hooksLk.Lock() + defer irh.hooksLk.Unlock() + rha := &responseHookActions{} + for _, responseHooks := range irh.hooks { + responseHooks.hook(p, response, rha) + if rha.hasError() { + break + } + } + return rha.result() +} + +type responseHookActions struct { + err error + extensions []graphsync.ExtensionData +} + +func (rha *responseHookActions) result() ResponseResult { + return ResponseResult{ + Err: rha.err, + Extensions: rha.extensions, + } +} + +func (rha *responseHookActions) hasError() bool { + return rha.err != nil +} + +func (rha *responseHookActions) TerminateWithError(err error) { + rha.err = err +} + +func (rha *responseHookActions) UpdateRequestWithExtensions(extensions ...graphsync.ExtensionData) { + rha.extensions = append(rha.extensions, extensions...) +} diff --git a/requestmanager/requestmanager.go b/requestmanager/requestmanager.go index 4bb82b83..110ca2aa 100644 --- a/requestmanager/requestmanager.go +++ b/requestmanager/requestmanager.go @@ -4,6 +4,8 @@ import ( "context" "fmt" + "github.com/ipfs/go-graphsync/requestmanager/hooks" + blocks "github.com/ipfs/go-block-format" "github.com/ipfs/go-graphsync" ipldutil "github.com/ipfs/go-graphsync/ipldutil" @@ -33,16 +35,6 @@ type inProgressRequestStatus struct { networkError chan error } -type responseHook struct { - key uint64 - hook graphsync.OnIncomingResponseHook -} - -type requestHook struct { - key uint64 - hook graphsync.OnOutgoingRequestHook -} - // PeerHandler is an interface that can send requests to peers type PeerHandler interface { SendRequest(p peer.ID, graphSyncRequest gsmsg.GraphSyncRequest) @@ -71,9 +63,8 @@ type RequestManager struct { // dont touch out side of run loop nextRequestID graphsync.RequestID inProgressRequestStatuses map[graphsync.RequestID]*inProgressRequestStatus - hooksNextKey uint64 - responseHooks []responseHook - requestHooks []requestHook + requestHooks *hooks.OutgoingRequestHooks + responseHooks *hooks.IncomingResponseHooks } type requestManagerMessage interface { @@ -81,7 +72,10 @@ type requestManagerMessage interface { } // New generates a new request manager from a context, network, and selectorQuerier -func New(ctx context.Context, asyncLoader AsyncLoader) *RequestManager { +func New(ctx context.Context, + asyncLoader AsyncLoader, + requestHooks *hooks.OutgoingRequestHooks, + responseHooks *hooks.IncomingResponseHooks) *RequestManager { ctx, cancel := context.WithCancel(ctx) return &RequestManager{ ctx: ctx, @@ -90,6 +84,8 @@ func New(ctx context.Context, asyncLoader AsyncLoader) *RequestManager { rc: newResponseCollector(ctx), messages: make(chan requestManagerMessage, 16), inProgressRequestStatuses: make(map[graphsync.RequestID]*inProgressRequestStatus), + requestHooks: requestHooks, + responseHooks: responseHooks, } } @@ -209,49 +205,6 @@ func (rm *RequestManager) ProcessResponses(p peer.ID, responses []gsmsg.GraphSyn } } -type registerRequestHookMessage struct { - hook graphsync.OnOutgoingRequestHook - unregisterHookChan chan graphsync.UnregisterHookFunc -} - -type registerResponseHookMessage struct { - hook graphsync.OnIncomingResponseHook - unregisterHookChan chan graphsync.UnregisterHookFunc -} - -// RegisterRequestHook registers an extension to process outgoing requests -func (rm *RequestManager) RegisterRequestHook(hook graphsync.OnOutgoingRequestHook) graphsync.UnregisterHookFunc { - response := make(chan graphsync.UnregisterHookFunc) - select { - case rm.messages <- ®isterRequestHookMessage{hook, response}: - case <-rm.ctx.Done(): - return nil - } - select { - case unregister := <-response: - return unregister - case <-rm.ctx.Done(): - return nil - } -} - -// RegisterResponseHook registers an extension to process incoming responses -func (rm *RequestManager) RegisterResponseHook( - hook graphsync.OnIncomingResponseHook) graphsync.UnregisterHookFunc { - response := make(chan graphsync.UnregisterHookFunc) - select { - case rm.messages <- ®isterResponseHookMessage{hook, response}: - case <-rm.ctx.Done(): - return nil - } - select { - case unregister := <-response: - return unregister - case <-rm.ctx.Done(): - return nil - } -} - // Startup starts processing for the WantManager. func (rm *RequestManager) Startup() { go rm.run() @@ -327,40 +280,6 @@ func (prm *processResponseMessage) handle(rm *RequestManager) { rm.processTerminations(filteredResponses) } -func (rhm *registerRequestHookMessage) handle(rm *RequestManager) { - rh := requestHook{rm.hooksNextKey, rhm.hook} - rm.hooksNextKey++ - rm.requestHooks = append(rm.requestHooks, rh) - select { - case rhm.unregisterHookChan <- func() { - for i, matchHook := range rm.requestHooks { - if rh.key == matchHook.key { - rm.requestHooks = append(rm.requestHooks[:i], rm.requestHooks[i+1:]...) - return - } - } - }: - case <-rm.ctx.Done(): - } -} - -func (rhm *registerResponseHookMessage) handle(rm *RequestManager) { - rh := responseHook{rm.hooksNextKey, rhm.hook} - rm.hooksNextKey++ - rm.responseHooks = append(rm.responseHooks, rh) - select { - case rhm.unregisterHookChan <- func() { - for i, matchHook := range rm.responseHooks { - if rh.key == matchHook.key { - rm.responseHooks = append(rm.responseHooks[:i], rm.responseHooks[i+1:]...) - return - } - } - }: - case <-rm.ctx.Done(): - } -} - func (rm *RequestManager) filterResponsesForPeer(responses []gsmsg.GraphSyncResponse, p peer.ID) []gsmsg.GraphSyncResponse { responsesForPeer := make([]gsmsg.GraphSyncResponse, 0, len(responses)) for _, response := range responses { @@ -385,18 +304,20 @@ func (rm *RequestManager) processExtensions(responses []gsmsg.GraphSyncResponse, } func (rm *RequestManager) processExtensionsForResponse(p peer.ID, response gsmsg.GraphSyncResponse) bool { - for _, responseHook := range rm.responseHooks { - err := responseHook.hook(p, response) - if err != nil { - requestStatus := rm.inProgressRequestStatuses[response.RequestID()] - responseError := rm.generateResponseErrorFromStatus(graphsync.RequestFailedUnknown) - select { - case requestStatus.networkError <- responseError: - case <-requestStatus.ctx.Done(): - } - requestStatus.cancelFn() - return false + result := rm.responseHooks.ProcessResponseHooks(p, response) + if len(result.Extensions) > 0 { + updateRequest := gsmsg.UpdateRequest(response.RequestID(), result.Extensions...) + rm.peerHandler.SendRequest(p, updateRequest) + } + if result.Err != nil { + requestStatus := rm.inProgressRequestStatuses[response.RequestID()] + responseError := rm.generateResponseErrorFromStatus(graphsync.RequestFailedUnknown) + select { + case requestStatus.networkError <- responseError: + case <-requestStatus.ctx.Done(): } + requestStatus.cancelFn() + return false } return true } @@ -434,19 +355,6 @@ func (rm *RequestManager) generateResponseErrorFromStatus(status graphsync.Respo } } -type hookActions struct { - persistenceOption string - nodeBuilderChooser traversal.NodeBuilderChooser -} - -func (ha *hookActions) UsePersistenceOption(name string) { - ha.persistenceOption = name -} - -func (ha *hookActions) UseNodeBuilderChooser(nodeBuilderChooser traversal.NodeBuilderChooser) { - ha.nodeBuilderChooser = nodeBuilderChooser -} - func (rm *RequestManager) setupRequest(requestID graphsync.RequestID, p peer.ID, root ipld.Link, selectorSpec ipld.Node, extensions []graphsync.ExtensionData) (chan graphsync.ResponseProgress, chan error) { _, err := ipldutil.EncodeNode(selectorSpec) if err != nil { @@ -466,16 +374,13 @@ func (rm *RequestManager) setupRequest(requestID graphsync.RequestID, p peer.ID, ctx, cancel, p, networkErrorChan, } request := gsmsg.NewRequest(requestID, asCidLink.Cid, selectorSpec, defaultPriority, extensions...) - ha := &hookActions{} - for _, hook := range rm.requestHooks { - hook.hook(p, request, ha) - } - err = rm.asyncLoader.StartRequest(requestID, ha.persistenceOption) + hooksResult := rm.requestHooks.ProcessRequestHooks(p, request) + err = rm.asyncLoader.StartRequest(requestID, hooksResult.PersistenceOption) if err != nil { return rm.singleErrorResponse(err) } rm.peerHandler.SendRequest(p, request) - return rm.executeTraversal(ctx, requestID, root, selector, ha.nodeBuilderChooser, networkErrorChan) + return rm.executeTraversal(ctx, requestID, root, selector, hooksResult.CustomChooser, networkErrorChan) } func (rm *RequestManager) executeTraversal( diff --git a/requestmanager/requestmanager_test.go b/requestmanager/requestmanager_test.go index 85328a4f..053c9e5e 100644 --- a/requestmanager/requestmanager_test.go +++ b/requestmanager/requestmanager_test.go @@ -9,6 +9,7 @@ import ( "time" "github.com/ipfs/go-graphsync" + "github.com/ipfs/go-graphsync/requestmanager/hooks" "github.com/ipfs/go-graphsync/requestmanager/types" "github.com/libp2p/go-libp2p-core/peer" "github.com/stretchr/testify/require" @@ -185,7 +186,9 @@ func TestNormalSimultaneousFetch(t *testing.T) { fph := &fakePeerHandler{requestRecordChan} ctx := context.Background() fal := newFakeAsyncLoader() - requestManager := New(ctx, fal) + requestHooks := hooks.NewRequestHooks() + responseHooks := hooks.NewResponseHooks() + requestManager := New(ctx, fal, requestHooks, responseHooks) requestManager.SetDelegate(fph) requestManager.Startup() @@ -272,7 +275,9 @@ func TestCancelRequestInProgress(t *testing.T) { fph := &fakePeerHandler{requestRecordChan} ctx := context.Background() fal := newFakeAsyncLoader() - requestManager := New(ctx, fal) + requestHooks := hooks.NewRequestHooks() + responseHooks := hooks.NewResponseHooks() + requestManager := New(ctx, fal, requestHooks, responseHooks) requestManager.SetDelegate(fph) requestManager.Startup() requestCtx, cancel := context.WithTimeout(ctx, time.Second) @@ -331,7 +336,9 @@ func TestCancelManagerExitsGracefully(t *testing.T) { ctx := context.Background() managerCtx, managerCancel := context.WithCancel(ctx) fal := newFakeAsyncLoader() - requestManager := New(managerCtx, fal) + requestHooks := hooks.NewRequestHooks() + responseHooks := hooks.NewResponseHooks() + requestManager := New(managerCtx, fal, requestHooks, responseHooks) requestManager.SetDelegate(fph) requestManager.Startup() requestCtx, cancel := context.WithTimeout(ctx, time.Second) @@ -372,7 +379,9 @@ func TestUnencodableSelector(t *testing.T) { fph := &fakePeerHandler{requestRecordChan} ctx := context.Background() fal := newFakeAsyncLoader() - requestManager := New(ctx, fal) + requestHooks := hooks.NewRequestHooks() + responseHooks := hooks.NewResponseHooks() + requestManager := New(ctx, fal, requestHooks, responseHooks) requestManager.SetDelegate(fph) requestManager.Startup() @@ -393,7 +402,9 @@ func TestFailedRequest(t *testing.T) { fph := &fakePeerHandler{requestRecordChan} ctx := context.Background() fal := newFakeAsyncLoader() - requestManager := New(ctx, fal) + requestHooks := hooks.NewRequestHooks() + responseHooks := hooks.NewResponseHooks() + requestManager := New(ctx, fal, requestHooks, responseHooks) requestManager.SetDelegate(fph) requestManager.Startup() @@ -422,7 +433,9 @@ func TestLocallyFulfilledFirstRequestFailsLater(t *testing.T) { fph := &fakePeerHandler{requestRecordChan} ctx := context.Background() fal := newFakeAsyncLoader() - requestManager := New(ctx, fal) + requestHooks := hooks.NewRequestHooks() + responseHooks := hooks.NewResponseHooks() + requestManager := New(ctx, fal, requestHooks, responseHooks) requestManager.SetDelegate(fph) requestManager.Startup() @@ -458,7 +471,9 @@ func TestLocallyFulfilledFirstRequestSucceedsLater(t *testing.T) { fph := &fakePeerHandler{requestRecordChan} ctx := context.Background() fal := newFakeAsyncLoader() - requestManager := New(ctx, fal) + requestHooks := hooks.NewRequestHooks() + responseHooks := hooks.NewResponseHooks() + requestManager := New(ctx, fal, requestHooks, responseHooks) requestManager.SetDelegate(fph) requestManager.Startup() @@ -493,7 +508,9 @@ func TestRequestReturnsMissingBlocks(t *testing.T) { fph := &fakePeerHandler{requestRecordChan} ctx := context.Background() fal := newFakeAsyncLoader() - requestManager := New(ctx, fal) + requestHooks := hooks.NewRequestHooks() + responseHooks := hooks.NewResponseHooks() + requestManager := New(ctx, fal, requestHooks, responseHooks) requestManager.SetDelegate(fph) requestManager.Startup() @@ -526,7 +543,9 @@ func TestEncodingExtensions(t *testing.T) { fph := &fakePeerHandler{requestRecordChan} ctx := context.Background() fal := newFakeAsyncLoader() - requestManager := New(ctx, fal) + requestHooks := hooks.NewRequestHooks() + responseHooks := hooks.NewResponseHooks() + requestManager := New(ctx, fal, requestHooks, responseHooks) requestManager.SetDelegate(fph) requestManager.Startup() @@ -553,13 +572,21 @@ func TestEncodingExtensions(t *testing.T) { expectedError := make(chan error, 2) receivedExtensionData := make(chan []byte, 2) - hook := func(p peer.ID, responseData graphsync.ResponseData) error { + expectedUpdateChan := make(chan []graphsync.ExtensionData, 2) + hook := func(p peer.ID, responseData graphsync.ResponseData, hookActions graphsync.IncomingResponseHookActions) { data, has := responseData.Extension(extensionName1) require.True(t, has, "did not receive extension data in response") receivedExtensionData <- data - return <-expectedError + err := <-expectedError + if err != nil { + hookActions.TerminateWithError(err) + } + update := <-expectedUpdateChan + if len(update) > 0 { + hookActions.UpdateRequestWithExtensions(update...) + } } - requestManager.RegisterResponseHook(hook) + responseHooks.Register(hook) returnedResponseChan, returnedErrorChan := requestManager.SendRequest(requestCtx, peers[0], blockChain.TipLink, blockChain.Selector(), extension1, extension2) rr := readNNetworkRequests(requestCtx, t, requestRecordChan, 1)[0] @@ -575,6 +602,7 @@ func TestEncodingExtensions(t *testing.T) { t.Run("responding to extensions", func(t *testing.T) { expectedData := testutil.RandomBytes(100) + expectedUpdate := testutil.RandomBytes(100) firstResponses := []gsmsg.GraphSyncResponse{ gsmsg.NewResponse(gsr.ID(), graphsync.PartialResponse, graphsync.ExtensionData{ @@ -588,11 +616,25 @@ func TestEncodingExtensions(t *testing.T) { ), } expectedError <- nil + expectedUpdateChan <- []graphsync.ExtensionData{ + { + Name: extensionName1, + Data: expectedUpdate, + }, + } requestManager.ProcessResponses(peers[0], firstResponses, nil) var received []byte testutil.AssertReceive(ctx, t, receivedExtensionData, &received, "did not receive extension data") require.Equal(t, expectedData, received, "did not receive correct extension data from resposne") + + rr = readNNetworkRequests(requestCtx, t, requestRecordChan, 1)[0] + receivedUpdateData, has := rr.gsr.Extension(extensionName1) + require.True(t, has) + require.Equal(t, expectedUpdate, receivedUpdateData, "should have updated with correct extension") + nextExpectedData := testutil.RandomBytes(100) + nextExpectedUpdate1 := testutil.RandomBytes(100) + nextExpectedUpdate2 := testutil.RandomBytes(100) secondResponses := []gsmsg.GraphSyncResponse{ gsmsg.NewResponse(gsr.ID(), @@ -607,9 +649,28 @@ func TestEncodingExtensions(t *testing.T) { ), } expectedError <- errors.New("a terrible thing happened") + expectedUpdateChan <- []graphsync.ExtensionData{ + { + Name: extensionName1, + Data: nextExpectedUpdate1, + }, + { + Name: extensionName2, + Data: nextExpectedUpdate2, + }, + } requestManager.ProcessResponses(peers[0], secondResponses, nil) testutil.AssertReceive(ctx, t, receivedExtensionData, &received, "did not receive extension data") require.Equal(t, nextExpectedData, received, "did not receive correct extension data from resposne") + + rr = readNNetworkRequests(requestCtx, t, requestRecordChan, 1)[0] + receivedUpdateData, has = rr.gsr.Extension(extensionName1) + require.True(t, has) + require.Equal(t, nextExpectedUpdate1, receivedUpdateData, "should have updated with correct extension") + receivedUpdateData, has = rr.gsr.Extension(extensionName2) + require.True(t, has) + require.Equal(t, nextExpectedUpdate2, receivedUpdateData, "should have updated with correct extension") + testutil.VerifySingleTerminalError(requestCtx, t, returnedErrorChan) testutil.VerifyEmptyResponse(requestCtx, t, returnedResponseChan) }) @@ -620,7 +681,9 @@ func TestOutgoingRequestHooks(t *testing.T) { fph := &fakePeerHandler{requestRecordChan} ctx := context.Background() fal := newFakeAsyncLoader() - requestManager := New(ctx, fal) + requestHooks := hooks.NewRequestHooks() + responseHooks := hooks.NewResponseHooks() + requestManager := New(ctx, fal, requestHooks, responseHooks) requestManager.SetDelegate(fph) requestManager.Startup() @@ -645,7 +708,7 @@ func TestOutgoingRequestHooks(t *testing.T) { ha.UsePersistenceOption("chainstore") } } - requestManager.RegisterRequestHook(hook) + requestHooks.Register(hook) returnedResponseChan1, returnedErrorChan1 := requestManager.SendRequest(requestCtx, peers[0], blockChain.TipLink, blockChain.Selector(), extension1) returnedResponseChan2, returnedErrorChan2 := requestManager.SendRequest(requestCtx, peers[0], blockChain.TipLink, blockChain.Selector()) diff --git a/responsemanager/blockhooks/blookhooks_test.go b/responsemanager/blockhooks/blookhooks_test.go deleted file mode 100644 index f9dcfddc..00000000 --- a/responsemanager/blockhooks/blookhooks_test.go +++ /dev/null @@ -1,200 +0,0 @@ -package blockhooks_test - -import ( - "errors" - "math/rand" - "testing" - - "github.com/ipfs/go-graphsync" - gsmsg "github.com/ipfs/go-graphsync/message" - "github.com/ipfs/go-graphsync/responsemanager/blockhooks" - "github.com/ipfs/go-graphsync/testutil" - "github.com/ipld/go-ipld-prime" - ipldfree "github.com/ipld/go-ipld-prime/impl/free" - cidlink "github.com/ipld/go-ipld-prime/linking/cid" - "github.com/ipld/go-ipld-prime/traversal/selector/builder" - peer "github.com/libp2p/go-libp2p-core/peer" - "github.com/stretchr/testify/require" -) - -type fakeBlkData struct { - link ipld.Link - size uint64 -} - -func (fbd fakeBlkData) Link() ipld.Link { - return fbd.link -} - -func (fbd fakeBlkData) BlockSize() uint64 { - return fbd.size -} - -func (fbd fakeBlkData) BlockSizeOnWire() uint64 { - return fbd.size -} - -func TestBlockHookProcessing(t *testing.T) { - extensionData := testutil.RandomBytes(100) - extensionName := graphsync.ExtensionName("AppleSauce/McGee") - extension := graphsync.ExtensionData{ - Name: extensionName, - Data: extensionData, - } - extensionResponseData := testutil.RandomBytes(100) - extensionResponse := graphsync.ExtensionData{ - Name: extensionName, - Data: extensionResponseData, - } - - root := testutil.GenerateCids(1)[0] - requestID := graphsync.RequestID(rand.Int31()) - ssb := builder.NewSelectorSpecBuilder(ipldfree.NodeBuilder()) - request := gsmsg.NewRequest(requestID, root, ssb.Matcher().Node(), graphsync.Priority(0), extension) - p := testutil.GeneratePeers(1)[0] - blockData := &fakeBlkData{ - link: cidlink.Link{Cid: testutil.GenerateCids(1)[0]}, - size: rand.Uint64(), - } - testCases := map[string]struct { - configure func(t *testing.T, blockHooks *blockhooks.OutgoingBlockHooks) - assert func(t *testing.T, result blockhooks.Result) - }{ - "no hooks": { - assert: func(t *testing.T, result blockhooks.Result) { - require.Empty(t, result.Extensions) - require.NoError(t, result.Err) - }, - }, - "send extension data": { - configure: func(t *testing.T, blockHooks *blockhooks.OutgoingBlockHooks) { - blockHooks.Register(func(p peer.ID, requestData graphsync.RequestData, blockData graphsync.BlockData, hookActions graphsync.OutgoingBlockHookActions) { - hookActions.SendExtensionData(extensionResponse) - }) - }, - assert: func(t *testing.T, result blockhooks.Result) { - require.Len(t, result.Extensions, 1) - require.Contains(t, result.Extensions, extensionResponse) - require.NoError(t, result.Err) - }, - }, - "terminate with error": { - configure: func(t *testing.T, blockHooks *blockhooks.OutgoingBlockHooks) { - blockHooks.Register(func(p peer.ID, requestData graphsync.RequestData, blockData graphsync.BlockData, hookActions graphsync.OutgoingBlockHookActions) { - hookActions.TerminateWithError(errors.New("failed")) - }) - }, - assert: func(t *testing.T, result blockhooks.Result) { - require.Empty(t, result.Extensions) - require.EqualError(t, result.Err, "failed") - }, - }, - "pause response": { - configure: func(t *testing.T, blockHooks *blockhooks.OutgoingBlockHooks) { - blockHooks.Register(func(p peer.ID, requestData graphsync.RequestData, blockData graphsync.BlockData, hookActions graphsync.OutgoingBlockHookActions) { - hookActions.PauseResponse() - }) - }, - assert: func(t *testing.T, result blockhooks.Result) { - require.Empty(t, result.Extensions) - require.EqualError(t, result.Err, blockhooks.ErrPaused.Error()) - }, - }, - } - for testCase, data := range testCases { - t.Run(testCase, func(t *testing.T) { - blockHooks := blockhooks.New() - if data.configure != nil { - data.configure(t, blockHooks) - } - result := blockHooks.ProcessBlockHooks(p, request, blockData) - if data.assert != nil { - data.assert(t, result) - } - }) - } -} - -/* - - t.Run("test block hook processing", func(t *testing.T) { - t.Run("can send extension data", func(t *testing.T) { - td := newTestData(t) - defer td.cancel() - responseManager := New(td.ctx, td.loader, td.peerManager, td.queryQueue, td.requestHooks, td.blockHooks) - responseManager.Startup() - td.requestHooks.Register(func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.IncomingRequestHookActions) { - hookActions.ValidateRequest() - }) - td.blockHooks.Register(func(p peer.ID, requestData graphsync.RequestData, blockData graphsync.BlockData, hookActions graphsync.OutgoingBlockHookActions) { - hookActions.SendExtensionData(td.extensionResponse) - }) - responseManager.ProcessRequests(td.ctx, td.p, td.requests) - var lastRequest completedRequest - testutil.AssertReceive(td.ctx, t, td.completedRequestChan, &lastRequest, "should complete request") - require.True(t, gsmsg.IsTerminalSuccessCode(lastRequest.result), "request should succeed") - for i := 0; i < td.blockChainLength; i++ { - var receivedExtension sentExtension - testutil.AssertReceive(td.ctx, t, td.sentExtensions, &receivedExtension, "should send extension response") - require.Equal(t, td.extensionResponse, receivedExtension.extension, "incorrect extension response sent") - } - }) - - t.Run("can send errors", func(t *testing.T) { - td := newTestData(t) - defer td.cancel() - responseManager := New(td.ctx, td.loader, td.peerManager, td.queryQueue, td.requestHooks, td.blockHooks) - responseManager.Startup() - td.requestHooks.Register(func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.IncomingRequestHookActions) { - hookActions.ValidateRequest() - }) - td.blockHooks.Register(func(p peer.ID, requestData graphsync.RequestData, blockData graphsync.BlockData, hookActions graphsync.OutgoingBlockHookActions) { - hookActions.TerminateWithError(errors.New("failed")) - }) - responseManager.ProcessRequests(td.ctx, td.p, td.requests) - var lastRequest completedRequest - testutil.AssertReceive(td.ctx, t, td.completedRequestChan, &lastRequest, "should complete request") - require.True(t, gsmsg.IsTerminalFailureCode(lastRequest.result), "request should succeed") - }) - - t.Run("can pause/unpause", func(t *testing.T) { - td := newTestData(t) - defer td.cancel() - responseManager := New(td.ctx, td.loader, td.peerManager, td.queryQueue, td.requestHooks, td.blockHooks) - responseManager.Startup() - td.requestHooks.Register(func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.IncomingRequestHookActions) { - hookActions.ValidateRequest() - }) - blkIndex := 1 - blockCount := 3 - var hasPaused bool - td.blockHooks.Register(func(p peer.ID, requestData graphsync.RequestData, blockData graphsync.BlockData, hookActions graphsync.OutgoingBlockHookActions) { - if blkIndex >= blockCount && !hasPaused { - hookActions.PauseResponse() - hasPaused = true - } - blkIndex++ - }) - responseManager.ProcessRequests(td.ctx, td.p, td.requests) - timer := time.NewTimer(500 * time.Millisecond) - testutil.AssertDoesReceiveFirst(t, timer.C, "should not complete request while paused", td.completedRequestChan) - var sentResponses []sentResponse - nomoreresponses: - for { - select { - case sentResponse := <-td.sentResponses: - sentResponses = append(sentResponses, sentResponse) - default: - break nomoreresponses - } - } - require.LessOrEqual(t, len(sentResponses), blockCount) - err := responseManager.UnpauseResponse(td.p, td.requestID) - require.NoError(t, err) - var lastRequest completedRequest - testutil.AssertReceive(td.ctx, t, td.completedRequestChan, &lastRequest, "should complete request") - require.True(t, gsmsg.IsTerminalSuccessCode(lastRequest.result), "request should succeed") - }) - - }) -*/ diff --git a/responsemanager/blockhooks/blockhooks.go b/responsemanager/hooks/blockhooks.go similarity index 63% rename from responsemanager/blockhooks/blockhooks.go rename to responsemanager/hooks/blockhooks.go index 713fbe49..13b15060 100644 --- a/responsemanager/blockhooks/blockhooks.go +++ b/responsemanager/hooks/blockhooks.go @@ -1,4 +1,4 @@ -package blockhooks +package hooks import ( "errors" @@ -18,47 +18,47 @@ type blockHook struct { // OutgoingBlockHooks is a set of outgoing block hooks that can be processed type OutgoingBlockHooks struct { - blockHooksLk sync.RWMutex - blockHooksNextKey uint64 - blockHooks []blockHook + hooksLk sync.RWMutex + nextKey uint64 + hooks []blockHook } -// New returns a new list of outgoing block hooks -func New() *OutgoingBlockHooks { +// NewBlockHooks returns a new list of outgoing block hooks +func NewBlockHooks() *OutgoingBlockHooks { return &OutgoingBlockHooks{} } // Register registers an hook to process outgoing blocks in a response func (obh *OutgoingBlockHooks) Register(hook graphsync.OnOutgoingBlockHook) graphsync.UnregisterHookFunc { - obh.blockHooksLk.Lock() - bh := blockHook{obh.blockHooksNextKey, hook} - obh.blockHooksNextKey++ - obh.blockHooks = append(obh.blockHooks, bh) - obh.blockHooksLk.Unlock() + obh.hooksLk.Lock() + bh := blockHook{obh.nextKey, hook} + obh.nextKey++ + obh.hooks = append(obh.hooks, bh) + obh.hooksLk.Unlock() return func() { - obh.blockHooksLk.Lock() - defer obh.blockHooksLk.Unlock() - for i, matchHook := range obh.blockHooks { + obh.hooksLk.Lock() + defer obh.hooksLk.Unlock() + for i, matchHook := range obh.hooks { if bh.key == matchHook.key { - obh.blockHooks = append(obh.blockHooks[:i], obh.blockHooks[i+1:]...) + obh.hooks = append(obh.hooks[:i], obh.hooks[i+1:]...) return } } } } -// Result is the result of processing block hooks -type Result struct { +// BlockResult is the result of processing block hooks +type BlockResult struct { Err error Extensions []graphsync.ExtensionData } // ProcessBlockHooks runs block hooks against a request and block data -func (obh *OutgoingBlockHooks) ProcessBlockHooks(p peer.ID, request graphsync.RequestData, blockData graphsync.BlockData) Result { - obh.blockHooksLk.RLock() - defer obh.blockHooksLk.RUnlock() +func (obh *OutgoingBlockHooks) ProcessBlockHooks(p peer.ID, request graphsync.RequestData, blockData graphsync.BlockData) BlockResult { + obh.hooksLk.RLock() + defer obh.hooksLk.RUnlock() bha := &blockHookActions{} - for _, bh := range obh.blockHooks { + for _, bh := range obh.hooks { bh.hook(p, request, blockData, bha) if bha.hasError() { break @@ -76,8 +76,8 @@ func (bha *blockHookActions) hasError() bool { return bha.err != nil } -func (bha *blockHookActions) result() Result { - return Result{bha.err, bha.extensions} +func (bha *blockHookActions) result() BlockResult { + return BlockResult{bha.err, bha.extensions} } func (bha *blockHookActions) SendExtensionData(data graphsync.ExtensionData) { diff --git a/responsemanager/hooks/hooks_test.go b/responsemanager/hooks/hooks_test.go new file mode 100644 index 00000000..ecb397d7 --- /dev/null +++ b/responsemanager/hooks/hooks_test.go @@ -0,0 +1,395 @@ +package hooks_test + +import ( + "errors" + "io" + "math/rand" + "testing" + + "github.com/ipfs/go-graphsync" + gsmsg "github.com/ipfs/go-graphsync/message" + "github.com/ipfs/go-graphsync/responsemanager/hooks" + "github.com/ipfs/go-graphsync/testutil" + "github.com/ipld/go-ipld-prime" + ipldfree "github.com/ipld/go-ipld-prime/impl/free" + cidlink "github.com/ipld/go-ipld-prime/linking/cid" + "github.com/ipld/go-ipld-prime/traversal/selector/builder" + peer "github.com/libp2p/go-libp2p-core/peer" + "github.com/stretchr/testify/require" +) + +type fakePersistenceOptions struct { + po map[string]ipld.Loader +} + +func (fpo *fakePersistenceOptions) GetLoader(name string) (ipld.Loader, bool) { + loader, ok := fpo.po[name] + return loader, ok +} + +func TestRequestHookProcessing(t *testing.T) { + fakeChooser := func(ipld.Link, ipld.LinkContext) (ipld.NodeBuilder, error) { + return ipldfree.NodeBuilder(), nil + } + fakeLoader := func(link ipld.Link, lnkCtx ipld.LinkContext) (io.Reader, error) { + return nil, nil + } + fpo := &fakePersistenceOptions{ + po: map[string]ipld.Loader{ + "chainstore": fakeLoader, + }, + } + extensionData := testutil.RandomBytes(100) + extensionName := graphsync.ExtensionName("AppleSauce/McGee") + extension := graphsync.ExtensionData{ + Name: extensionName, + Data: extensionData, + } + extensionResponseData := testutil.RandomBytes(100) + extensionResponse := graphsync.ExtensionData{ + Name: extensionName, + Data: extensionResponseData, + } + + root := testutil.GenerateCids(1)[0] + requestID := graphsync.RequestID(rand.Int31()) + ssb := builder.NewSelectorSpecBuilder(ipldfree.NodeBuilder()) + request := gsmsg.NewRequest(requestID, root, ssb.Matcher().Node(), graphsync.Priority(0), extension) + p := testutil.GeneratePeers(1)[0] + testCases := map[string]struct { + configure func(t *testing.T, requestHooks *hooks.IncomingRequestHooks) + assert func(t *testing.T, result hooks.RequestResult) + }{ + "no hooks": { + assert: func(t *testing.T, result hooks.RequestResult) { + require.False(t, result.IsValidated) + require.Empty(t, result.Extensions) + require.Nil(t, result.CustomChooser) + require.Nil(t, result.CustomLoader) + require.NoError(t, result.Err) + }, + }, + "sending extension data, no validation": { + configure: func(t *testing.T, requestHooks *hooks.IncomingRequestHooks) { + requestHooks.Register(func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.IncomingRequestHookActions) { + hookActions.SendExtensionData(extensionResponse) + }) + }, + assert: func(t *testing.T, result hooks.RequestResult) { + require.False(t, result.IsValidated) + require.Len(t, result.Extensions, 1) + require.Contains(t, result.Extensions, extensionResponse) + require.Nil(t, result.CustomChooser) + require.Nil(t, result.CustomLoader) + require.NoError(t, result.Err) + }, + }, + "sending extension data, with validation": { + configure: func(t *testing.T, requestHooks *hooks.IncomingRequestHooks) { + requestHooks.Register(func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.IncomingRequestHookActions) { + hookActions.ValidateRequest() + hookActions.SendExtensionData(extensionResponse) + }) + }, + assert: func(t *testing.T, result hooks.RequestResult) { + require.True(t, result.IsValidated) + require.Len(t, result.Extensions, 1) + require.Contains(t, result.Extensions, extensionResponse) + require.Nil(t, result.CustomChooser) + require.Nil(t, result.CustomLoader) + require.NoError(t, result.Err) + }, + }, + "short circuit on error": { + configure: func(t *testing.T, requestHooks *hooks.IncomingRequestHooks) { + requestHooks.Register(func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.IncomingRequestHookActions) { + hookActions.TerminateWithError(errors.New("something went wrong")) + }) + requestHooks.Register(func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.IncomingRequestHookActions) { + hookActions.ValidateRequest() + hookActions.SendExtensionData(extensionResponse) + }) + }, + assert: func(t *testing.T, result hooks.RequestResult) { + require.False(t, result.IsValidated) + require.Empty(t, result.Extensions) + require.Nil(t, result.CustomChooser) + require.Nil(t, result.CustomLoader) + require.EqualError(t, result.Err, "something went wrong") + }, + }, + "hooks unregistered": { + configure: func(t *testing.T, requestHooks *hooks.IncomingRequestHooks) { + unregister := requestHooks.Register(func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.IncomingRequestHookActions) { + hookActions.ValidateRequest() + hookActions.SendExtensionData(extensionResponse) + }) + unregister() + }, + assert: func(t *testing.T, result hooks.RequestResult) { + require.False(t, result.IsValidated) + require.Empty(t, result.Extensions) + require.Nil(t, result.CustomChooser) + require.Nil(t, result.CustomLoader) + require.NoError(t, result.Err) + }, + }, + "hooks alter the loader": { + configure: func(t *testing.T, requestHooks *hooks.IncomingRequestHooks) { + requestHooks.Register(func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.IncomingRequestHookActions) { + if _, found := requestData.Extension(extensionName); found { + hookActions.UsePersistenceOption("chainstore") + hookActions.SendExtensionData(extensionResponse) + } + }) + }, + assert: func(t *testing.T, result hooks.RequestResult) { + require.False(t, result.IsValidated) + require.Len(t, result.Extensions, 1) + require.Contains(t, result.Extensions, extensionResponse) + require.Nil(t, result.CustomChooser) + require.NotNil(t, result.CustomLoader) + require.NoError(t, result.Err) + }, + }, + "hooks alter to non-existent loader": { + configure: func(t *testing.T, requestHooks *hooks.IncomingRequestHooks) { + requestHooks.Register(func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.IncomingRequestHookActions) { + if _, found := requestData.Extension(extensionName); found { + hookActions.UsePersistenceOption("applesauce") + hookActions.SendExtensionData(extensionResponse) + } + }) + }, + assert: func(t *testing.T, result hooks.RequestResult) { + require.False(t, result.IsValidated) + require.Len(t, result.Extensions, 1) + require.Contains(t, result.Extensions, extensionResponse) + require.Nil(t, result.CustomChooser) + require.Nil(t, result.CustomLoader) + require.EqualError(t, result.Err, "unknown loader option") + }, + }, + "hooks alter the node builder chooser": { + configure: func(t *testing.T, requestHooks *hooks.IncomingRequestHooks) { + requestHooks.Register(func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.IncomingRequestHookActions) { + if _, found := requestData.Extension(extensionName); found { + hookActions.UseNodeBuilderChooser(fakeChooser) + hookActions.SendExtensionData(extensionResponse) + } + }) + }, + assert: func(t *testing.T, result hooks.RequestResult) { + require.False(t, result.IsValidated) + require.Len(t, result.Extensions, 1) + require.Contains(t, result.Extensions, extensionResponse) + require.NotNil(t, result.CustomChooser) + require.Nil(t, result.CustomLoader) + require.NoError(t, result.Err) + }, + }, + } + for testCase, data := range testCases { + t.Run(testCase, func(t *testing.T) { + requestHooks := hooks.NewRequestHooks(fpo) + if data.configure != nil { + data.configure(t, requestHooks) + } + result := requestHooks.ProcessRequestHooks(p, request) + if data.assert != nil { + data.assert(t, result) + } + }) + } +} + +type fakeBlkData struct { + link ipld.Link + size uint64 +} + +func (fbd fakeBlkData) Link() ipld.Link { + return fbd.link +} + +func (fbd fakeBlkData) BlockSize() uint64 { + return fbd.size +} + +func (fbd fakeBlkData) BlockSizeOnWire() uint64 { + return fbd.size +} + +func TestBlockHookProcessing(t *testing.T) { + extensionData := testutil.RandomBytes(100) + extensionName := graphsync.ExtensionName("AppleSauce/McGee") + extension := graphsync.ExtensionData{ + Name: extensionName, + Data: extensionData, + } + extensionResponseData := testutil.RandomBytes(100) + extensionResponse := graphsync.ExtensionData{ + Name: extensionName, + Data: extensionResponseData, + } + + root := testutil.GenerateCids(1)[0] + requestID := graphsync.RequestID(rand.Int31()) + ssb := builder.NewSelectorSpecBuilder(ipldfree.NodeBuilder()) + request := gsmsg.NewRequest(requestID, root, ssb.Matcher().Node(), graphsync.Priority(0), extension) + p := testutil.GeneratePeers(1)[0] + blockData := &fakeBlkData{ + link: cidlink.Link{Cid: testutil.GenerateCids(1)[0]}, + size: rand.Uint64(), + } + testCases := map[string]struct { + configure func(t *testing.T, blockHooks *hooks.OutgoingBlockHooks) + assert func(t *testing.T, result hooks.BlockResult) + }{ + "no hooks": { + assert: func(t *testing.T, result hooks.BlockResult) { + require.Empty(t, result.Extensions) + require.NoError(t, result.Err) + }, + }, + "send extension data": { + configure: func(t *testing.T, blockHooks *hooks.OutgoingBlockHooks) { + blockHooks.Register(func(p peer.ID, requestData graphsync.RequestData, blockData graphsync.BlockData, hookActions graphsync.OutgoingBlockHookActions) { + hookActions.SendExtensionData(extensionResponse) + }) + }, + assert: func(t *testing.T, result hooks.BlockResult) { + require.Len(t, result.Extensions, 1) + require.Contains(t, result.Extensions, extensionResponse) + require.NoError(t, result.Err) + }, + }, + "terminate with error": { + configure: func(t *testing.T, blockHooks *hooks.OutgoingBlockHooks) { + blockHooks.Register(func(p peer.ID, requestData graphsync.RequestData, blockData graphsync.BlockData, hookActions graphsync.OutgoingBlockHookActions) { + hookActions.TerminateWithError(errors.New("failed")) + }) + }, + assert: func(t *testing.T, result hooks.BlockResult) { + require.Empty(t, result.Extensions) + require.EqualError(t, result.Err, "failed") + }, + }, + "pause response": { + configure: func(t *testing.T, blockHooks *hooks.OutgoingBlockHooks) { + blockHooks.Register(func(p peer.ID, requestData graphsync.RequestData, blockData graphsync.BlockData, hookActions graphsync.OutgoingBlockHookActions) { + hookActions.PauseResponse() + }) + }, + assert: func(t *testing.T, result hooks.BlockResult) { + require.Empty(t, result.Extensions) + require.EqualError(t, result.Err, hooks.ErrPaused.Error()) + }, + }, + } + for testCase, data := range testCases { + t.Run(testCase, func(t *testing.T) { + blockHooks := hooks.NewBlockHooks() + if data.configure != nil { + data.configure(t, blockHooks) + } + result := blockHooks.ProcessBlockHooks(p, request, blockData) + if data.assert != nil { + data.assert(t, result) + } + }) + } +} + +func TestUpdateHookProcessing(t *testing.T) { + extensionData := testutil.RandomBytes(100) + extensionName := graphsync.ExtensionName("AppleSauce/McGee") + extension := graphsync.ExtensionData{ + Name: extensionName, + Data: extensionData, + } + extensionUpdateData := testutil.RandomBytes(100) + extensionUpdate := graphsync.ExtensionData{ + Name: extensionName, + Data: extensionUpdateData, + } + extensionResponseData := testutil.RandomBytes(100) + extensionResponse := graphsync.ExtensionData{ + Name: extensionName, + Data: extensionResponseData, + } + + root := testutil.GenerateCids(1)[0] + requestID := graphsync.RequestID(rand.Int31()) + ssb := builder.NewSelectorSpecBuilder(ipldfree.NodeBuilder()) + request := gsmsg.NewRequest(requestID, root, ssb.Matcher().Node(), graphsync.Priority(0), extension) + update := gsmsg.UpdateRequest(requestID, extensionUpdate) + p := testutil.GeneratePeers(1)[0] + testCases := map[string]struct { + configure func(t *testing.T, updateHooks *hooks.RequestUpdatedHooks) + assert func(t *testing.T, result hooks.UpdateResult) + }{ + "no hooks": { + assert: func(t *testing.T, result hooks.UpdateResult) { + require.Empty(t, result.Extensions) + require.NoError(t, result.Err) + require.False(t, result.Unpause) + }, + }, + "send extension data": { + configure: func(t *testing.T, updateHooks *hooks.RequestUpdatedHooks) { + updateHooks.Register(func(p peer.ID, requestData graphsync.RequestData, updateData graphsync.RequestData, hookActions graphsync.RequestUpdatedHookActions) { + _, found := requestData.Extension(extensionName) + _, updateFound := updateData.Extension(extensionName) + if found && updateFound { + hookActions.SendExtensionData(extensionResponse) + } + }) + }, + assert: func(t *testing.T, result hooks.UpdateResult) { + require.Len(t, result.Extensions, 1) + require.Contains(t, result.Extensions, extensionResponse) + require.NoError(t, result.Err) + require.False(t, result.Unpause) + + }, + }, + "terminate with error": { + configure: func(t *testing.T, updateHooks *hooks.RequestUpdatedHooks) { + updateHooks.Register(func(p peer.ID, requestData graphsync.RequestData, updateData graphsync.RequestData, hookActions graphsync.RequestUpdatedHookActions) { + hookActions.TerminateWithError(errors.New("failed")) + }) + }, + assert: func(t *testing.T, result hooks.UpdateResult) { + require.Empty(t, result.Extensions) + require.EqualError(t, result.Err, "failed") + require.False(t, result.Unpause) + + }, + }, + "unpause response": { + configure: func(t *testing.T, updateHooks *hooks.RequestUpdatedHooks) { + updateHooks.Register(func(p peer.ID, requestData graphsync.RequestData, updateData graphsync.RequestData, hookActions graphsync.RequestUpdatedHookActions) { + hookActions.UnpauseResponse() + }) + }, + assert: func(t *testing.T, result hooks.UpdateResult) { + require.Empty(t, result.Extensions) + require.NoError(t, result.Err) + require.True(t, result.Unpause) + }, + }, + } + for testCase, data := range testCases { + t.Run(testCase, func(t *testing.T) { + updateHooks := hooks.NewUpdateHooks() + if data.configure != nil { + data.configure(t, updateHooks) + } + result := updateHooks.ProcessUpdateHooks(p, request, update) + if data.assert != nil { + data.assert(t, result) + } + }) + } +} diff --git a/responsemanager/requesthooks/requesthooks.go b/responsemanager/hooks/requesthook.go similarity index 60% rename from responsemanager/requesthooks/requesthooks.go rename to responsemanager/hooks/requesthook.go index f45ed99a..27669b30 100644 --- a/responsemanager/requesthooks/requesthooks.go +++ b/responsemanager/hooks/requesthook.go @@ -1,4 +1,4 @@ -package requesthooks +package hooks import ( "errors" @@ -23,13 +23,13 @@ type PersistenceOptions interface { // IncomingRequestHooks is a set of incoming request hooks that can be processed type IncomingRequestHooks struct { persistenceOptions PersistenceOptions - requestHooksLk sync.RWMutex - requestHookNextKey uint64 - requestHooks []requestHook + hooksLk sync.RWMutex + nextKey uint64 + hooks []requestHook } -// New returns a new list of incoming request hooks -func New(persistenceOptions PersistenceOptions) *IncomingRequestHooks { +// NewRequestHooks returns a new list of incoming request hooks +func NewRequestHooks(persistenceOptions PersistenceOptions) *IncomingRequestHooks { return &IncomingRequestHooks{ persistenceOptions: persistenceOptions, } @@ -37,25 +37,25 @@ func New(persistenceOptions PersistenceOptions) *IncomingRequestHooks { // Register registers an extension to process new incoming requests func (irh *IncomingRequestHooks) Register(hook graphsync.OnIncomingRequestHook) graphsync.UnregisterHookFunc { - irh.requestHooksLk.Lock() - rh := requestHook{irh.requestHookNextKey, hook} - irh.requestHookNextKey++ - irh.requestHooks = append(irh.requestHooks, rh) - irh.requestHooksLk.Unlock() + irh.hooksLk.Lock() + rh := requestHook{irh.nextKey, hook} + irh.nextKey++ + irh.hooks = append(irh.hooks, rh) + irh.hooksLk.Unlock() return func() { - irh.requestHooksLk.Lock() - defer irh.requestHooksLk.Unlock() - for i, matchHook := range irh.requestHooks { + irh.hooksLk.Lock() + defer irh.hooksLk.Unlock() + for i, matchHook := range irh.hooks { if rh.key == matchHook.key { - irh.requestHooks = append(irh.requestHooks[:i], irh.requestHooks[i+1:]...) + irh.hooks = append(irh.hooks[:i], irh.hooks[i+1:]...) return } } } } -// Result is the outcome of running requesthooks -type Result struct { +// RequestResult is the outcome of running requesthooks +type RequestResult struct { IsValidated bool CustomLoader ipld.Loader CustomChooser traversal.NodeBuilderChooser @@ -64,13 +64,13 @@ type Result struct { } // ProcessRequestHooks runs request hooks against an incoming request -func (irh *IncomingRequestHooks) ProcessRequestHooks(p peer.ID, request graphsync.RequestData) Result { - irh.requestHooksLk.RLock() - defer irh.requestHooksLk.RUnlock() - ha := &hookActions{ +func (irh *IncomingRequestHooks) ProcessRequestHooks(p peer.ID, request graphsync.RequestData) RequestResult { + irh.hooksLk.RLock() + defer irh.hooksLk.RUnlock() + ha := &requestHookActions{ persistenceOptions: irh.persistenceOptions, } - for _, requestHook := range irh.requestHooks { + for _, requestHook := range irh.hooks { requestHook.hook(p, request, ha) if ha.hasError() { break @@ -79,7 +79,7 @@ func (irh *IncomingRequestHooks) ProcessRequestHooks(p peer.ID, request graphsyn return ha.result() } -type hookActions struct { +type requestHookActions struct { persistenceOptions PersistenceOptions isValidated bool err error @@ -88,12 +88,12 @@ type hookActions struct { extensions []graphsync.ExtensionData } -func (ha *hookActions) hasError() bool { +func (ha *requestHookActions) hasError() bool { return ha.err != nil } -func (ha *hookActions) result() Result { - return Result{ +func (ha *requestHookActions) result() RequestResult { + return RequestResult{ IsValidated: ha.isValidated, CustomLoader: ha.loader, CustomChooser: ha.chooser, @@ -102,19 +102,19 @@ func (ha *hookActions) result() Result { } } -func (ha *hookActions) SendExtensionData(ext graphsync.ExtensionData) { +func (ha *requestHookActions) SendExtensionData(ext graphsync.ExtensionData) { ha.extensions = append(ha.extensions, ext) } -func (ha *hookActions) TerminateWithError(err error) { +func (ha *requestHookActions) TerminateWithError(err error) { ha.err = err } -func (ha *hookActions) ValidateRequest() { +func (ha *requestHookActions) ValidateRequest() { ha.isValidated = true } -func (ha *hookActions) UsePersistenceOption(name string) { +func (ha *requestHookActions) UsePersistenceOption(name string) { loader, ok := ha.persistenceOptions.GetLoader(name) if !ok { ha.TerminateWithError(errors.New("unknown loader option")) @@ -123,6 +123,6 @@ func (ha *hookActions) UsePersistenceOption(name string) { ha.loader = loader } -func (ha *hookActions) UseNodeBuilderChooser(chooser traversal.NodeBuilderChooser) { +func (ha *requestHookActions) UseNodeBuilderChooser(chooser traversal.NodeBuilderChooser) { ha.chooser = chooser } diff --git a/responsemanager/hooks/requestupdatehooks.go b/responsemanager/hooks/requestupdatehooks.go new file mode 100644 index 00000000..0999a09f --- /dev/null +++ b/responsemanager/hooks/requestupdatehooks.go @@ -0,0 +1,91 @@ +package hooks + +import ( + "sync" + + "github.com/ipfs/go-graphsync" + peer "github.com/libp2p/go-libp2p-core/peer" +) + +type requestUpdatedHook struct { + key uint64 + hook graphsync.OnRequestUpdatedHook +} + +// RequestUpdatedHooks manages and runs hooks for request updates +type RequestUpdatedHooks struct { + nextKey uint64 + hooksLk sync.RWMutex + hooks []requestUpdatedHook +} + +// NewUpdateHooks returns a new list of request updated hooks +func NewUpdateHooks() *RequestUpdatedHooks { + return &RequestUpdatedHooks{} +} + +// Register registers an hook to process updates to requests +func (ruh *RequestUpdatedHooks) Register(hook graphsync.OnRequestUpdatedHook) graphsync.UnregisterHookFunc { + ruh.hooksLk.Lock() + rh := requestUpdatedHook{ruh.nextKey, hook} + ruh.nextKey++ + ruh.hooks = append(ruh.hooks, rh) + ruh.hooksLk.Unlock() + return func() { + ruh.hooksLk.Lock() + defer ruh.hooksLk.Unlock() + for i, matchHook := range ruh.hooks { + if rh.key == matchHook.key { + ruh.hooks = append(ruh.hooks[:i], ruh.hooks[i+1:]...) + return + } + } + } +} + +// UpdateResult is the result of running update hooks +type UpdateResult struct { + Err error + Unpause bool + Extensions []graphsync.ExtensionData +} + +// ProcessUpdateHooks runs request hooks against an incoming request +func (ruh *RequestUpdatedHooks) ProcessUpdateHooks(p peer.ID, request graphsync.RequestData, update graphsync.RequestData) UpdateResult { + ruh.hooksLk.RLock() + defer ruh.hooksLk.RUnlock() + ha := &updateHookActions{} + for _, updateHook := range ruh.hooks { + updateHook.hook(p, request, update, ha) + if ha.hasError() { + break + } + } + return ha.result() +} + +type updateHookActions struct { + err error + unpause bool + extensions []graphsync.ExtensionData +} + +func (uha *updateHookActions) hasError() bool { + return uha.err != nil +} + +func (uha *updateHookActions) result() UpdateResult { + return UpdateResult{uha.err, uha.unpause, uha.extensions} +} + +func (uha *updateHookActions) SendExtensionData(data graphsync.ExtensionData) { + uha.extensions = append(uha.extensions, data) +} + +func (uha *updateHookActions) TerminateWithError(err error) { + uha.err = err +} + +func (uha *updateHookActions) UnpauseResponse() { + uha.unpause = true +} diff --git a/responsemanager/peerresponsemanager/peerresponsesender.go b/responsemanager/peerresponsemanager/peerresponsesender.go index be8d406c..f4fb0a68 100644 --- a/responsemanager/peerresponsemanager/peerresponsesender.go +++ b/responsemanager/peerresponsemanager/peerresponsesender.go @@ -57,6 +57,7 @@ type PeerResponseSender interface { SendExtensionData(graphsync.RequestID, graphsync.ExtensionData) FinishRequest(requestID graphsync.RequestID) FinishWithError(requestID graphsync.RequestID, status graphsync.ResponseStatusCode) + PauseRequest(requestID graphsync.RequestID) } // NewResponseSender generates a new PeerResponseSender for the given context, peer ID, @@ -168,9 +169,13 @@ func (prm *peerResponseSender) FinishWithError(requestID graphsync.RequestID, st prm.finish(requestID, status) } +func (prm *peerResponseSender) PauseRequest(requestID graphsync.RequestID) { + prm.finish(requestID, graphsync.RequestPaused) +} + func (prm *peerResponseSender) finish(requestID graphsync.RequestID, status graphsync.ResponseStatusCode) { if prm.buildResponse(0, func(responseBuilder *responsebuilder.ResponseBuilder) { - responseBuilder.AddCompletedRequest(requestID, status) + responseBuilder.AddResponseCode(requestID, status) }) { prm.signalWork() } diff --git a/responsemanager/requesthooks/requesthooks_test.go b/responsemanager/requesthooks/requesthooks_test.go deleted file mode 100644 index 79f1647c..00000000 --- a/responsemanager/requesthooks/requesthooks_test.go +++ /dev/null @@ -1,203 +0,0 @@ -package requesthooks_test - -import ( - "errors" - "io" - "math/rand" - "testing" - - "github.com/ipfs/go-graphsync" - gsmsg "github.com/ipfs/go-graphsync/message" - "github.com/ipfs/go-graphsync/responsemanager/requesthooks" - "github.com/ipfs/go-graphsync/testutil" - "github.com/ipld/go-ipld-prime" - ipldfree "github.com/ipld/go-ipld-prime/impl/free" - "github.com/ipld/go-ipld-prime/traversal/selector/builder" - peer "github.com/libp2p/go-libp2p-core/peer" - "github.com/stretchr/testify/require" -) - -type fakePersistenceOptions struct { - po map[string]ipld.Loader -} - -func (fpo *fakePersistenceOptions) GetLoader(name string) (ipld.Loader, bool) { - loader, ok := fpo.po[name] - return loader, ok -} - -func TestRequestHookProcessing(t *testing.T) { - fakeChooser := func(ipld.Link, ipld.LinkContext) (ipld.NodeBuilder, error) { - return ipldfree.NodeBuilder(), nil - } - fakeLoader := func(link ipld.Link, lnkCtx ipld.LinkContext) (io.Reader, error) { - return nil, nil - } - fpo := &fakePersistenceOptions{ - po: map[string]ipld.Loader{ - "chainstore": fakeLoader, - }, - } - extensionData := testutil.RandomBytes(100) - extensionName := graphsync.ExtensionName("AppleSauce/McGee") - extension := graphsync.ExtensionData{ - Name: extensionName, - Data: extensionData, - } - extensionResponseData := testutil.RandomBytes(100) - extensionResponse := graphsync.ExtensionData{ - Name: extensionName, - Data: extensionResponseData, - } - - root := testutil.GenerateCids(1)[0] - requestID := graphsync.RequestID(rand.Int31()) - ssb := builder.NewSelectorSpecBuilder(ipldfree.NodeBuilder()) - request := gsmsg.NewRequest(requestID, root, ssb.Matcher().Node(), graphsync.Priority(0), extension) - p := testutil.GeneratePeers(1)[0] - testCases := map[string]struct { - configure func(t *testing.T, requestHooks *requesthooks.IncomingRequestHooks) - assert func(t *testing.T, result requesthooks.Result) - }{ - "no hooks": { - assert: func(t *testing.T, result requesthooks.Result) { - require.False(t, result.IsValidated) - require.Empty(t, result.Extensions) - require.Nil(t, result.CustomChooser) - require.Nil(t, result.CustomLoader) - require.NoError(t, result.Err) - }, - }, - "sending extension data, no validation": { - configure: func(t *testing.T, requestHooks *requesthooks.IncomingRequestHooks) { - requestHooks.Register(func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.IncomingRequestHookActions) { - hookActions.SendExtensionData(extensionResponse) - }) - }, - assert: func(t *testing.T, result requesthooks.Result) { - require.False(t, result.IsValidated) - require.Len(t, result.Extensions, 1) - require.Contains(t, result.Extensions, extensionResponse) - require.Nil(t, result.CustomChooser) - require.Nil(t, result.CustomLoader) - require.NoError(t, result.Err) - }, - }, - "sending extension data, with validation": { - configure: func(t *testing.T, requestHooks *requesthooks.IncomingRequestHooks) { - requestHooks.Register(func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.IncomingRequestHookActions) { - hookActions.ValidateRequest() - hookActions.SendExtensionData(extensionResponse) - }) - }, - assert: func(t *testing.T, result requesthooks.Result) { - require.True(t, result.IsValidated) - require.Len(t, result.Extensions, 1) - require.Contains(t, result.Extensions, extensionResponse) - require.Nil(t, result.CustomChooser) - require.Nil(t, result.CustomLoader) - require.NoError(t, result.Err) - }, - }, - "short circuit on error": { - configure: func(t *testing.T, requestHooks *requesthooks.IncomingRequestHooks) { - requestHooks.Register(func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.IncomingRequestHookActions) { - hookActions.TerminateWithError(errors.New("something went wrong")) - }) - requestHooks.Register(func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.IncomingRequestHookActions) { - hookActions.ValidateRequest() - hookActions.SendExtensionData(extensionResponse) - }) - }, - assert: func(t *testing.T, result requesthooks.Result) { - require.False(t, result.IsValidated) - require.Empty(t, result.Extensions) - require.Nil(t, result.CustomChooser) - require.Nil(t, result.CustomLoader) - require.EqualError(t, result.Err, "something went wrong") - }, - }, - "hooks unregistered": { - configure: func(t *testing.T, requestHooks *requesthooks.IncomingRequestHooks) { - unregister := requestHooks.Register(func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.IncomingRequestHookActions) { - hookActions.ValidateRequest() - hookActions.SendExtensionData(extensionResponse) - }) - unregister() - }, - assert: func(t *testing.T, result requesthooks.Result) { - require.False(t, result.IsValidated) - require.Empty(t, result.Extensions) - require.Nil(t, result.CustomChooser) - require.Nil(t, result.CustomLoader) - require.NoError(t, result.Err) - }, - }, - "hooks alter the loader": { - configure: func(t *testing.T, requestHooks *requesthooks.IncomingRequestHooks) { - requestHooks.Register(func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.IncomingRequestHookActions) { - if _, found := requestData.Extension(extensionName); found { - hookActions.UsePersistenceOption("chainstore") - hookActions.SendExtensionData(extensionResponse) - } - }) - }, - assert: func(t *testing.T, result requesthooks.Result) { - require.False(t, result.IsValidated) - require.Len(t, result.Extensions, 1) - require.Contains(t, result.Extensions, extensionResponse) - require.Nil(t, result.CustomChooser) - require.NotNil(t, result.CustomLoader) - require.NoError(t, result.Err) - }, - }, - "hooks alter to non-existent loader": { - configure: func(t *testing.T, requestHooks *requesthooks.IncomingRequestHooks) { - requestHooks.Register(func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.IncomingRequestHookActions) { - if _, found := requestData.Extension(extensionName); found { - hookActions.UsePersistenceOption("applesauce") - hookActions.SendExtensionData(extensionResponse) - } - }) - }, - assert: func(t *testing.T, result requesthooks.Result) { - require.False(t, result.IsValidated) - require.Len(t, result.Extensions, 1) - require.Contains(t, result.Extensions, extensionResponse) - require.Nil(t, result.CustomChooser) - require.Nil(t, result.CustomLoader) - require.EqualError(t, result.Err, "unknown loader option") - }, - }, - "hooks alter the node builder chooser": { - configure: func(t *testing.T, requestHooks *requesthooks.IncomingRequestHooks) { - requestHooks.Register(func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.IncomingRequestHookActions) { - if _, found := requestData.Extension(extensionName); found { - hookActions.UseNodeBuilderChooser(fakeChooser) - hookActions.SendExtensionData(extensionResponse) - } - }) - }, - assert: func(t *testing.T, result requesthooks.Result) { - require.False(t, result.IsValidated) - require.Len(t, result.Extensions, 1) - require.Contains(t, result.Extensions, extensionResponse) - require.NotNil(t, result.CustomChooser) - require.Nil(t, result.CustomLoader) - require.NoError(t, result.Err) - }, - }, - } - for testCase, data := range testCases { - t.Run(testCase, func(t *testing.T) { - requestHooks := requesthooks.New(fpo) - if data.configure != nil { - data.configure(t, requestHooks) - } - result := requestHooks.ProcessRequestHooks(p, request) - if data.assert != nil { - data.assert(t, result) - } - }) - } -} diff --git a/responsemanager/responsebuilder/responsebuilder.go b/responsemanager/responsebuilder/responsebuilder.go index 86865ea5..7ca378f6 100644 --- a/responsemanager/responsebuilder/responsebuilder.go +++ b/responsemanager/responsebuilder/responsebuilder.go @@ -50,10 +50,10 @@ func (rb *ResponseBuilder) AddLink(requestID graphsync.RequestID, link ipld.Link rb.outgoingResponses[requestID] = append(rb.outgoingResponses[requestID], metadata.Item{Link: link, BlockPresent: blockPresent}) } -// AddCompletedRequest marks the given request as completed in the response, +// AddResponseCode marks the given request as completed in the response, // as well as whether the graphsync request responded with complete or partial // data. -func (rb *ResponseBuilder) AddCompletedRequest(requestID graphsync.RequestID, status graphsync.ResponseStatusCode) { +func (rb *ResponseBuilder) AddResponseCode(requestID graphsync.RequestID, status graphsync.ResponseStatusCode) { rb.completedResponses[requestID] = status // make sure this completion goes out in next response even if no links are sent _, ok := rb.outgoingResponses[requestID] diff --git a/responsemanager/responsebuilder/responsebuilder_test.go b/responsemanager/responsebuilder/responsebuilder_test.go index f537ba5e..6c880576 100644 --- a/responsemanager/responsebuilder/responsebuilder_test.go +++ b/responsemanager/responsebuilder/responsebuilder_test.go @@ -30,18 +30,18 @@ func TestMessageBuilding(t *testing.T) { rb.AddLink(requestID1, links[1], false) rb.AddLink(requestID1, links[2], true) - rb.AddCompletedRequest(requestID1, graphsync.RequestCompletedPartial) + rb.AddResponseCode(requestID1, graphsync.RequestCompletedPartial) rb.AddLink(requestID2, links[1], true) rb.AddLink(requestID2, links[2], true) rb.AddLink(requestID2, links[1], true) - rb.AddCompletedRequest(requestID2, graphsync.RequestCompletedFull) + rb.AddResponseCode(requestID2, graphsync.RequestCompletedFull) rb.AddLink(requestID3, links[0], true) rb.AddLink(requestID3, links[1], true) - rb.AddCompletedRequest(requestID4, graphsync.RequestCompletedFull) + rb.AddResponseCode(requestID4, graphsync.RequestCompletedFull) for _, block := range blocks { rb.AddBlock(block) diff --git a/responsemanager/responsemanager.go b/responsemanager/responsemanager.go index d97ea8e3..bddd4110 100644 --- a/responsemanager/responsemanager.go +++ b/responsemanager/responsemanager.go @@ -6,13 +6,12 @@ import ( "math" "time" - "github.com/ipfs/go-graphsync/responsemanager/blockhooks" + "github.com/ipfs/go-graphsync/responsemanager/hooks" "github.com/ipfs/go-graphsync" "github.com/ipfs/go-graphsync/ipldutil" gsmsg "github.com/ipfs/go-graphsync/message" "github.com/ipfs/go-graphsync/responsemanager/peerresponsemanager" - "github.com/ipfs/go-graphsync/responsemanager/requesthooks" "github.com/ipfs/go-graphsync/responsemanager/runtraversal" logging "github.com/ipfs/go-log" "github.com/ipfs/go-peertaskqueue/peertask" @@ -29,12 +28,14 @@ const ( ) type inProgressResponseStatus struct { - ctx context.Context - cancelFn func() - request gsmsg.GraphSyncRequest - loader ipld.Loader - traverser ipldutil.Traverser - isPaused bool + ctx context.Context + cancelFn func() + request gsmsg.GraphSyncRequest + loader ipld.Loader + traverser ipldutil.Traverser + updateSignal chan struct{} + updates []gsmsg.GraphSyncRequest + isPaused bool } type responseKey struct { @@ -43,10 +44,11 @@ type responseKey struct { } type responseTaskData struct { - ctx context.Context - request gsmsg.GraphSyncRequest - loader ipld.Loader - traverser ipldutil.Traverser + ctx context.Context + request gsmsg.GraphSyncRequest + loader ipld.Loader + traverser ipldutil.Traverser + updateSignal chan struct{} } // QueryQueue is an interface that can receive new selector query tasks @@ -61,12 +63,17 @@ type QueryQueue interface { // RequestHooks is an interface for processing request hooks type RequestHooks interface { - ProcessRequestHooks(p peer.ID, request graphsync.RequestData) requesthooks.Result + ProcessRequestHooks(p peer.ID, request graphsync.RequestData) hooks.RequestResult } // BlockHooks is an interface for processing block hooks type BlockHooks interface { - ProcessBlockHooks(p peer.ID, request graphsync.RequestData, blockData graphsync.BlockData) blockhooks.Result + ProcessBlockHooks(p peer.ID, request graphsync.RequestData, blockData graphsync.BlockData) hooks.BlockResult +} + +// UpdateHooks is an interface for processing update hooks +type UpdateHooks interface { + ProcessUpdateHooks(p peer.ID, request graphsync.RequestData, update graphsync.RequestData) hooks.UpdateResult } // PeerManager is an interface that returns sender interfaces for peer responses. @@ -81,14 +88,14 @@ type responseManagerMessage interface { // ResponseManager handles incoming requests from the network, initiates selector // traversals, and transmits responses type ResponseManager struct { - ctx context.Context - cancelFn context.CancelFunc - loader ipld.Loader - peerManager PeerManager - queryQueue QueryQueue - requestHooks RequestHooks - blockHooks BlockHooks - + ctx context.Context + cancelFn context.CancelFunc + loader ipld.Loader + peerManager PeerManager + queryQueue QueryQueue + requestHooks RequestHooks + blockHooks BlockHooks + updateHooks UpdateHooks messages chan responseManagerMessage workSignal chan struct{} ticker *time.Ticker @@ -102,7 +109,8 @@ func New(ctx context.Context, peerManager PeerManager, queryQueue QueryQueue, requestHooks RequestHooks, - blockHooks BlockHooks) *ResponseManager { + blockHooks BlockHooks, + updateHooks UpdateHooks) *ResponseManager { ctx, cancelFn := context.WithCancel(ctx) return &ResponseManager{ ctx: ctx, @@ -112,6 +120,7 @@ func New(ctx context.Context, queryQueue: queryQueue, requestHooks: requestHooks, blockHooks: blockHooks, + updateHooks: updateHooks, messages: make(chan responseManagerMessage, 16), workSignal: make(chan struct{}, 1), ticker: time.NewTicker(thawSpeed), @@ -188,6 +197,11 @@ type setResponseDataRequest struct { traverser ipldutil.Traverser } +type responseUpdateRequest struct { + key responseKey + updateChan chan []gsmsg.GraphSyncRequest +} + func (rm *ResponseManager) processQueriesWorker() { const targetWork = 1 taskDataChan := make(chan *responseTaskData) @@ -244,7 +258,7 @@ func (rm *ResponseManager) executeTask(key responseKey, taskData *responseTaskDa case rm.messages <- &setResponseDataRequest{key, loader, traverser}: } } - return rm.executeQuery(key.p, taskData.request, loader, traverser) + return rm.executeQuery(key.p, taskData.request, loader, traverser, taskData.updateSignal) } func (rm *ResponseManager) prepareQuery(ctx context.Context, @@ -272,12 +286,19 @@ func (rm *ResponseManager) prepareQuery(ctx context.Context, return loader, traverser, nil } -func (rm *ResponseManager) executeQuery(p peer.ID, +func (rm *ResponseManager) executeQuery( + p peer.ID, request gsmsg.GraphSyncRequest, loader ipld.Loader, - traverser ipldutil.Traverser) error { + traverser ipldutil.Traverser, + updateSignal chan struct{}) error { + updateChan := make(chan []gsmsg.GraphSyncRequest) peerResponseSender := rm.peerManager.SenderForPeer(p) err := runtraversal.RunTraversal(loader, traverser, func(link ipld.Link, data []byte) error { + err := rm.checkForUpdates(p, request, updateSignal, updateChan, peerResponseSender) + if err != nil { + return err + } blockData := peerResponseSender.SendResponse(request.ID(), link, data) if blockData.BlockSize() > 0 { result := rm.blockHooks.ProcessBlockHooks(p, request, blockData) @@ -291,8 +312,10 @@ func (rm *ResponseManager) executeQuery(p peer.ID, return nil }) if err != nil { - if err != blockhooks.ErrPaused { + if err != hooks.ErrPaused { peerResponseSender.FinishWithError(request.ID(), graphsync.RequestFailedUnknown) + } else { + peerResponseSender.PauseRequest(request.ID()) } return err } @@ -300,6 +323,36 @@ func (rm *ResponseManager) executeQuery(p peer.ID, return nil } +func (rm *ResponseManager) checkForUpdates( + p peer.ID, + request gsmsg.GraphSyncRequest, + updateSignal chan struct{}, + updateChan chan []gsmsg.GraphSyncRequest, + peerResponseSender peerresponsemanager.PeerResponseSender) error { + select { + case <-updateSignal: + select { + case rm.messages <- &responseUpdateRequest{responseKey{p, request.ID()}, updateChan}: + case <-rm.ctx.Done(): + } + select { + case updates := <-updateChan: + for _, update := range updates { + result := rm.updateHooks.ProcessUpdateHooks(p, request, update) + for _, extension := range result.Extensions { + peerResponseSender.SendExtensionData(request.ID(), extension) + } + if result.Err != nil { + return result.Err + } + } + case <-rm.ctx.Done(): + } + default: + } + return nil +} + // Startup starts processing for the WantManager. func (rm *ResponseManager) Startup() { go rm.run() @@ -335,35 +388,92 @@ func (rm *ResponseManager) run() { func (prm *processRequestMessage) handle(rm *ResponseManager) { for _, request := range prm.requests { key := responseKey{p: prm.p, requestID: request.ID()} - if !request.IsCancel() { - ctx, cancelFn := context.WithCancel(rm.ctx) - rm.inProgressResponses[key] = - &inProgressResponseStatus{ - ctx: ctx, - cancelFn: cancelFn, - request: request, - } - // TODO: Use a better work estimation metric. - rm.queryQueue.PushTasks(prm.p, peertask.Task{Topic: key, Priority: int(request.Priority()), Work: 1}) - select { - case rm.workSignal <- struct{}{}: - default: - } - } else { + if request.IsCancel() { rm.queryQueue.Remove(key, key.p) response, ok := rm.inProgressResponses[key] if ok { response.cancelFn() + delete(rm.inProgressResponses, key) + } + continue + } + if request.IsUpdate() { + rm.processUpdate(key, request) + continue + } + ctx, cancelFn := context.WithCancel(rm.ctx) + rm.inProgressResponses[key] = + &inProgressResponseStatus{ + ctx: ctx, + cancelFn: cancelFn, + request: request, + updateSignal: make(chan struct{}, 1), } + // TODO: Use a better work estimation metric. + rm.queryQueue.PushTasks(prm.p, peertask.Task{Topic: key, Priority: int(request.Priority()), Work: 1}) + select { + case rm.workSignal <- struct{}{}: + default: + } + } +} + +func (rm *ResponseManager) processUpdate(key responseKey, update gsmsg.GraphSyncRequest) { + response, ok := rm.inProgressResponses[key] + if !ok { + log.Warnf("received update for non existent request, peer %s, request ID %d", key.p.Pretty(), key.requestID) + return + } + if !response.isPaused { + response.updates = append(response.updates, update) + select { + case response.updateSignal <- struct{}{}: + default: + } + return + } + result := rm.updateHooks.ProcessUpdateHooks(key.p, response.request, update) + peerResponseSender := rm.peerManager.SenderForPeer(key.p) + for _, extension := range result.Extensions { + peerResponseSender.SendExtensionData(key.requestID, extension) + } + if result.Err != nil { + peerResponseSender.FinishWithError(key.requestID, graphsync.RequestFailedUnknown) + delete(rm.inProgressResponses, key) + response.cancelFn() + return + } + if result.Unpause { + err := rm.unpauseRequest(key.p, key.requestID) + if err != nil { + log.Warnf("error unpausing request: %s", err.Error()) } } } +func (rm *ResponseManager) unpauseRequest(p peer.ID, requestID graphsync.RequestID) error { + key := responseKey{p, requestID} + inProgressResponse, ok := rm.inProgressResponses[key] + if !ok { + return errors.New("could not find request") + } + if !inProgressResponse.isPaused { + return errors.New("request is not paused") + } + inProgressResponse.isPaused = false + rm.queryQueue.PushTasks(p, peertask.Task{Topic: key, Priority: math.MaxInt32, Work: 1}) + select { + case rm.workSignal <- struct{}{}: + default: + } + return nil +} + func (rdr *responseDataRequest) handle(rm *ResponseManager) { response, ok := rm.inProgressResponses[rdr.key] var taskData *responseTaskData if ok { - taskData = &responseTaskData{response.ctx, response.request, response.loader, response.traverser} + taskData = &responseTaskData{response.ctx, response.request, response.loader, response.traverser, response.updateSignal} } else { taskData = nil } @@ -378,7 +488,7 @@ func (ftr *finishTaskRequest) handle(rm *ResponseManager) { if !ok { return } - if ftr.err == blockhooks.ErrPaused { + if ftr.err == hooks.ErrPaused { response.isPaused = true return } @@ -398,33 +508,30 @@ func (srdr *setResponseDataRequest) handle(rm *ResponseManager) { response.traverser = srdr.traverser } -func (sm *synchronizeMessage) handle(rm *ResponseManager) { +func (rur *responseUpdateRequest) handle(rm *ResponseManager) { + response, ok := rm.inProgressResponses[rur.key] + var updates []gsmsg.GraphSyncRequest + if ok { + updates = response.updates + response.updates = nil + } else { + updates = nil + } select { case <-rm.ctx.Done(): - case sm.sync <- struct{}{}: + case rur.updateChan <- updates: } } -func (urm *unpauseRequestMessage) unpauseRequest(rm *ResponseManager) error { - key := responseKey{urm.p, urm.requestID} - inProgressResponse, ok := rm.inProgressResponses[key] - if !ok { - return errors.New("could not find request") - } - if !inProgressResponse.isPaused { - return errors.New("request is not paused") - } - inProgressResponse.isPaused = false - rm.queryQueue.PushTasks(urm.p, peertask.Task{Topic: key, Priority: math.MaxInt32, Work: 1}) +func (sm *synchronizeMessage) handle(rm *ResponseManager) { select { - case rm.workSignal <- struct{}{}: - default: + case <-rm.ctx.Done(): + case sm.sync <- struct{}{}: } - return nil } func (urm *unpauseRequestMessage) handle(rm *ResponseManager) { - err := urm.unpauseRequest(rm) + err := rm.unpauseRequest(urm.p, urm.requestID) select { case <-rm.ctx.Done(): case urm.response <- err: diff --git a/responsemanager/responsemanager_test.go b/responsemanager/responsemanager_test.go index 3dadf223..a88b60cd 100644 --- a/responsemanager/responsemanager_test.go +++ b/responsemanager/responsemanager_test.go @@ -10,10 +10,9 @@ import ( "github.com/ipfs/go-graphsync" gsmsg "github.com/ipfs/go-graphsync/message" - "github.com/ipfs/go-graphsync/responsemanager/blockhooks" + "github.com/ipfs/go-graphsync/responsemanager/hooks" "github.com/ipfs/go-graphsync/responsemanager/peerresponsemanager" "github.com/ipfs/go-graphsync/responsemanager/persistenceoptions" - "github.com/ipfs/go-graphsync/responsemanager/requesthooks" "github.com/ipfs/go-graphsync/selectorvalidator" "github.com/ipfs/go-graphsync/testutil" "github.com/ipfs/go-peertaskqueue/peertask" @@ -97,10 +96,15 @@ type completedRequest struct { requestID graphsync.RequestID result graphsync.ResponseStatusCode } +type pausedRequest struct { + requestID graphsync.RequestID +} + type fakePeerResponseSender struct { sentResponses chan sentResponse sentExtensions chan sentExtension lastCompletedRequest chan completedRequest + pausedRequests chan pausedRequest } func (fprs *fakePeerResponseSender) Startup() {} @@ -147,12 +151,16 @@ func (fprs *fakePeerResponseSender) FinishWithError(requestID graphsync.RequestI fprs.lastCompletedRequest <- completedRequest{requestID, status} } +func (fprs *fakePeerResponseSender) PauseRequest(requestID graphsync.RequestID) { + fprs.pausedRequests <- pausedRequest{requestID} +} + func TestIncomingQuery(t *testing.T) { td := newTestData(t) defer td.cancel() blks := td.blockChain.AllBlocks() - responseManager := New(td.ctx, td.loader, td.peerManager, td.queryQueue, td.requestHooks, td.blockHooks) + responseManager := New(td.ctx, td.loader, td.peerManager, td.queryQueue, td.requestHooks, td.blockHooks, td.updateHooks) td.requestHooks.Register(selectorvalidator.SelectorValidator(100)) responseManager.Startup() @@ -173,7 +181,7 @@ func TestCancellationQueryInProgress(t *testing.T) { td := newTestData(t) defer td.cancel() blks := td.blockChain.AllBlocks() - responseManager := New(td.ctx, td.loader, td.peerManager, td.queryQueue, td.requestHooks, td.blockHooks) + responseManager := New(td.ctx, td.loader, td.peerManager, td.queryQueue, td.requestHooks, td.blockHooks, td.updateHooks) td.requestHooks.Register(selectorvalidator.SelectorValidator(100)) responseManager.Startup() responseManager.ProcessRequests(td.ctx, td.p, td.requests) @@ -220,7 +228,7 @@ func TestEarlyCancellation(t *testing.T) { td := newTestData(t) defer td.cancel() td.queryQueue.popWait.Add(1) - responseManager := New(td.ctx, td.loader, td.peerManager, td.queryQueue, td.requestHooks, td.blockHooks) + responseManager := New(td.ctx, td.loader, td.peerManager, td.queryQueue, td.requestHooks, td.blockHooks, td.updateHooks) responseManager.Startup() responseManager.ProcessRequests(td.ctx, td.p, td.requests) @@ -244,7 +252,7 @@ func TestValidationAndExtensions(t *testing.T) { t.Run("on its own, should fail validation", func(t *testing.T) { td := newTestData(t) defer td.cancel() - responseManager := New(td.ctx, td.loader, td.peerManager, td.queryQueue, td.requestHooks, td.blockHooks) + responseManager := New(td.ctx, td.loader, td.peerManager, td.queryQueue, td.requestHooks, td.blockHooks, td.updateHooks) responseManager.Startup() responseManager.ProcessRequests(td.ctx, td.p, td.requests) var lastRequest completedRequest @@ -255,7 +263,7 @@ func TestValidationAndExtensions(t *testing.T) { t.Run("if non validating hook succeeds, does not pass validation", func(t *testing.T) { td := newTestData(t) defer td.cancel() - responseManager := New(td.ctx, td.loader, td.peerManager, td.queryQueue, td.requestHooks, td.blockHooks) + responseManager := New(td.ctx, td.loader, td.peerManager, td.queryQueue, td.requestHooks, td.blockHooks, td.updateHooks) responseManager.Startup() td.requestHooks.Register(func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.IncomingRequestHookActions) { hookActions.SendExtensionData(td.extensionResponse) @@ -272,7 +280,7 @@ func TestValidationAndExtensions(t *testing.T) { t.Run("if validating hook succeeds, should pass validation", func(t *testing.T) { td := newTestData(t) defer td.cancel() - responseManager := New(td.ctx, td.loader, td.peerManager, td.queryQueue, td.requestHooks, td.blockHooks) + responseManager := New(td.ctx, td.loader, td.peerManager, td.queryQueue, td.requestHooks, td.blockHooks, td.updateHooks) responseManager.Startup() td.requestHooks.Register(func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.IncomingRequestHookActions) { hookActions.ValidateRequest() @@ -290,7 +298,7 @@ func TestValidationAndExtensions(t *testing.T) { t.Run("if any hook fails, should fail", func(t *testing.T) { td := newTestData(t) defer td.cancel() - responseManager := New(td.ctx, td.loader, td.peerManager, td.queryQueue, td.requestHooks, td.blockHooks) + responseManager := New(td.ctx, td.loader, td.peerManager, td.queryQueue, td.requestHooks, td.blockHooks, td.updateHooks) responseManager.Startup() td.requestHooks.Register(func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.IncomingRequestHookActions) { hookActions.ValidateRequest() @@ -311,7 +319,7 @@ func TestValidationAndExtensions(t *testing.T) { t.Run("hooks can be unregistered", func(t *testing.T) { td := newTestData(t) defer td.cancel() - responseManager := New(td.ctx, td.loader, td.peerManager, td.queryQueue, td.requestHooks, td.blockHooks) + responseManager := New(td.ctx, td.loader, td.peerManager, td.queryQueue, td.requestHooks, td.blockHooks, td.updateHooks) responseManager.Startup() unregister := td.requestHooks.Register(func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.IncomingRequestHookActions) { hookActions.ValidateRequest() @@ -341,7 +349,7 @@ func TestValidationAndExtensions(t *testing.T) { defer td.cancel() obs := make(map[ipld.Link][]byte) oloader, _ := testutil.NewTestStore(obs) - responseManager := New(td.ctx, oloader, td.peerManager, td.queryQueue, td.requestHooks, td.blockHooks) + responseManager := New(td.ctx, oloader, td.peerManager, td.queryQueue, td.requestHooks, td.blockHooks, td.updateHooks) responseManager.Startup() // add validating hook -- so the request SHOULD succeed td.requestHooks.Register(func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.IncomingRequestHookActions) { @@ -375,7 +383,7 @@ func TestValidationAndExtensions(t *testing.T) { t.Run("hooks can alter the node builder chooser", func(t *testing.T) { td := newTestData(t) defer td.cancel() - responseManager := New(td.ctx, td.loader, td.peerManager, td.queryQueue, td.requestHooks, td.blockHooks) + responseManager := New(td.ctx, td.loader, td.peerManager, td.queryQueue, td.requestHooks, td.blockHooks, td.updateHooks) responseManager.Startup() customChooserCallCount := 0 @@ -418,7 +426,7 @@ func TestValidationAndExtensions(t *testing.T) { t.Run("can send extension data", func(t *testing.T) { td := newTestData(t) defer td.cancel() - responseManager := New(td.ctx, td.loader, td.peerManager, td.queryQueue, td.requestHooks, td.blockHooks) + responseManager := New(td.ctx, td.loader, td.peerManager, td.queryQueue, td.requestHooks, td.blockHooks, td.updateHooks) responseManager.Startup() td.requestHooks.Register(func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.IncomingRequestHookActions) { hookActions.ValidateRequest() @@ -440,7 +448,7 @@ func TestValidationAndExtensions(t *testing.T) { t.Run("can send errors", func(t *testing.T) { td := newTestData(t) defer td.cancel() - responseManager := New(td.ctx, td.loader, td.peerManager, td.queryQueue, td.requestHooks, td.blockHooks) + responseManager := New(td.ctx, td.loader, td.peerManager, td.queryQueue, td.requestHooks, td.blockHooks, td.updateHooks) responseManager.Startup() td.requestHooks.Register(func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.IncomingRequestHookActions) { hookActions.ValidateRequest() @@ -451,48 +459,245 @@ func TestValidationAndExtensions(t *testing.T) { responseManager.ProcessRequests(td.ctx, td.p, td.requests) var lastRequest completedRequest testutil.AssertReceive(td.ctx, t, td.completedRequestChan, &lastRequest, "should complete request") - require.True(t, gsmsg.IsTerminalFailureCode(lastRequest.result), "request should succeed") + require.True(t, gsmsg.IsTerminalFailureCode(lastRequest.result), "request should fail") }) t.Run("can pause/unpause", func(t *testing.T) { td := newTestData(t) defer td.cancel() - responseManager := New(td.ctx, td.loader, td.peerManager, td.queryQueue, td.requestHooks, td.blockHooks) + responseManager := New(td.ctx, td.loader, td.peerManager, td.queryQueue, td.requestHooks, td.blockHooks, td.updateHooks) responseManager.Startup() td.requestHooks.Register(func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.IncomingRequestHookActions) { hookActions.ValidateRequest() }) - blkIndex := 1 + blkIndex := 0 blockCount := 3 - var hasPaused bool td.blockHooks.Register(func(p peer.ID, requestData graphsync.RequestData, blockData graphsync.BlockData, hookActions graphsync.OutgoingBlockHookActions) { - if blkIndex >= blockCount && !hasPaused { + blkIndex++ + if blkIndex == blockCount { hookActions.PauseResponse() - hasPaused = true } + }) + responseManager.ProcessRequests(td.ctx, td.p, td.requests) + timer := time.NewTimer(500 * time.Millisecond) + testutil.AssertDoesReceiveFirst(t, timer.C, "should not complete request while paused", td.completedRequestChan) + for i := 0; i < blockCount; i++ { + testutil.AssertDoesReceive(td.ctx, t, td.sentResponses, "should sent block") + } + testutil.AssertChannelEmpty(t, td.sentResponses, "should not send more blocks") + var pausedRequest pausedRequest + testutil.AssertReceive(td.ctx, t, td.pausedRequests, &pausedRequest, "should pause request") + err := responseManager.UnpauseResponse(td.p, td.requestID) + require.NoError(t, err) + var lastRequest completedRequest + testutil.AssertReceive(td.ctx, t, td.completedRequestChan, &lastRequest, "should complete request") + require.True(t, gsmsg.IsTerminalSuccessCode(lastRequest.result), "request should succeed") + }) + + }) + + t.Run("test update hook processing", func(t *testing.T) { + + t.Run("can pause/unpause", func(t *testing.T) { + td := newTestData(t) + defer td.cancel() + responseManager := New(td.ctx, td.loader, td.peerManager, td.queryQueue, td.requestHooks, td.blockHooks, td.updateHooks) + responseManager.Startup() + td.requestHooks.Register(func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.IncomingRequestHookActions) { + hookActions.ValidateRequest() + }) + blkIndex := 0 + blockCount := 3 + td.blockHooks.Register(func(p peer.ID, requestData graphsync.RequestData, blockData graphsync.BlockData, hookActions graphsync.OutgoingBlockHookActions) { blkIndex++ + if blkIndex == blockCount { + hookActions.PauseResponse() + } + }) + td.updateHooks.Register(func(p peer.ID, requestData graphsync.RequestData, updateData graphsync.RequestData, hookActions graphsync.RequestUpdatedHookActions) { + if _, found := updateData.Extension(td.extensionName); found { + hookActions.UnpauseResponse() + } }) responseManager.ProcessRequests(td.ctx, td.p, td.requests) timer := time.NewTimer(500 * time.Millisecond) testutil.AssertDoesReceiveFirst(t, timer.C, "should not complete request while paused", td.completedRequestChan) var sentResponses []sentResponse - nomoreresponses: - for { - select { - case sentResponse := <-td.sentResponses: - sentResponses = append(sentResponses, sentResponse) - default: - break nomoreresponses - } + for i := 0; i < blockCount; i++ { + testutil.AssertDoesReceive(td.ctx, t, td.sentResponses, "should sent block") } + testutil.AssertChannelEmpty(t, td.sentResponses, "should not send more blocks") + var pausedRequest pausedRequest + testutil.AssertReceive(td.ctx, t, td.pausedRequests, &pausedRequest, "should pause request") require.LessOrEqual(t, len(sentResponses), blockCount) - err := responseManager.UnpauseResponse(td.p, td.requestID) - require.NoError(t, err) + responseManager.ProcessRequests(td.ctx, td.p, td.updateRequests) var lastRequest completedRequest testutil.AssertReceive(td.ctx, t, td.completedRequestChan, &lastRequest, "should complete request") require.True(t, gsmsg.IsTerminalSuccessCode(lastRequest.result), "request should succeed") }) + t.Run("can send extension data", func(t *testing.T) { + t.Run("when unpaused", func(t *testing.T) { + td := newTestData(t) + defer td.cancel() + responseManager := New(td.ctx, td.loader, td.peerManager, td.queryQueue, td.requestHooks, td.blockHooks, td.updateHooks) + responseManager.Startup() + td.requestHooks.Register(func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.IncomingRequestHookActions) { + hookActions.ValidateRequest() + }) + blkIndex := 0 + blockCount := 3 + wait := make(chan struct{}) + sent := make(chan struct{}) + td.blockHooks.Register(func(p peer.ID, requestData graphsync.RequestData, blockData graphsync.BlockData, hookActions graphsync.OutgoingBlockHookActions) { + blkIndex++ + if blkIndex == blockCount { + close(sent) + <-wait + } + }) + td.updateHooks.Register(func(p peer.ID, requestData graphsync.RequestData, updateData graphsync.RequestData, hookActions graphsync.RequestUpdatedHookActions) { + if _, found := updateData.Extension(td.extensionName); found { + hookActions.SendExtensionData(td.extensionResponse) + } + }) + responseManager.ProcessRequests(td.ctx, td.p, td.requests) + testutil.AssertDoesReceive(td.ctx, t, sent, "sends blocks") + responseManager.ProcessRequests(td.ctx, td.p, td.updateRequests) + responseManager.synchronize() + close(wait) + var lastRequest completedRequest + testutil.AssertReceive(td.ctx, t, td.completedRequestChan, &lastRequest, "should complete request") + require.True(t, gsmsg.IsTerminalSuccessCode(lastRequest.result), "request should succeed") + var receivedExtension sentExtension + testutil.AssertReceive(td.ctx, t, td.sentExtensions, &receivedExtension, "should send extension response") + require.Equal(t, td.extensionResponse, receivedExtension.extension, "incorrect extension response sent") + }) + + t.Run("when paused", func(t *testing.T) { + td := newTestData(t) + defer td.cancel() + responseManager := New(td.ctx, td.loader, td.peerManager, td.queryQueue, td.requestHooks, td.blockHooks, td.updateHooks) + responseManager.Startup() + td.requestHooks.Register(func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.IncomingRequestHookActions) { + hookActions.ValidateRequest() + }) + blkIndex := 0 + blockCount := 3 + td.blockHooks.Register(func(p peer.ID, requestData graphsync.RequestData, blockData graphsync.BlockData, hookActions graphsync.OutgoingBlockHookActions) { + blkIndex++ + if blkIndex == blockCount { + hookActions.PauseResponse() + } + }) + td.updateHooks.Register(func(p peer.ID, requestData graphsync.RequestData, updateData graphsync.RequestData, hookActions graphsync.RequestUpdatedHookActions) { + if _, found := updateData.Extension(td.extensionName); found { + hookActions.SendExtensionData(td.extensionResponse) + } + }) + responseManager.ProcessRequests(td.ctx, td.p, td.requests) + var sentResponses []sentResponse + for i := 0; i < blockCount; i++ { + testutil.AssertDoesReceive(td.ctx, t, td.sentResponses, "should sent block") + } + testutil.AssertChannelEmpty(t, td.sentResponses, "should not send more blocks") + var pausedRequest pausedRequest + testutil.AssertReceive(td.ctx, t, td.pausedRequests, &pausedRequest, "should pause request") + require.LessOrEqual(t, len(sentResponses), blockCount) + + // send update + responseManager.ProcessRequests(td.ctx, td.p, td.updateRequests) + + // receive data + var receivedExtension sentExtension + testutil.AssertReceive(td.ctx, t, td.sentExtensions, &receivedExtension, "should send extension response") + + // should still be paused + timer := time.NewTimer(500 * time.Millisecond) + testutil.AssertDoesReceiveFirst(t, timer.C, "should not complete request while paused", td.completedRequestChan) + }) + }) + + t.Run("can send errors", func(t *testing.T) { + t.Run("when unpaused", func(t *testing.T) { + td := newTestData(t) + defer td.cancel() + responseManager := New(td.ctx, td.loader, td.peerManager, td.queryQueue, td.requestHooks, td.blockHooks, td.updateHooks) + responseManager.Startup() + td.requestHooks.Register(func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.IncomingRequestHookActions) { + hookActions.ValidateRequest() + }) + blkIndex := 0 + blockCount := 3 + wait := make(chan struct{}) + sent := make(chan struct{}) + td.blockHooks.Register(func(p peer.ID, requestData graphsync.RequestData, blockData graphsync.BlockData, hookActions graphsync.OutgoingBlockHookActions) { + blkIndex++ + if blkIndex == blockCount { + close(sent) + <-wait + } + }) + td.updateHooks.Register(func(p peer.ID, requestData graphsync.RequestData, updateData graphsync.RequestData, hookActions graphsync.RequestUpdatedHookActions) { + if _, found := updateData.Extension(td.extensionName); found { + hookActions.TerminateWithError(errors.New("something went wrong")) + } + }) + responseManager.ProcessRequests(td.ctx, td.p, td.requests) + testutil.AssertDoesReceive(td.ctx, t, sent, "sends blocks") + responseManager.ProcessRequests(td.ctx, td.p, td.updateRequests) + responseManager.synchronize() + close(wait) + var lastRequest completedRequest + testutil.AssertReceive(td.ctx, t, td.completedRequestChan, &lastRequest, "should complete request") + require.True(t, gsmsg.IsTerminalFailureCode(lastRequest.result), "request should fail") + }) + + t.Run("when paused", func(t *testing.T) { + td := newTestData(t) + defer td.cancel() + responseManager := New(td.ctx, td.loader, td.peerManager, td.queryQueue, td.requestHooks, td.blockHooks, td.updateHooks) + responseManager.Startup() + td.requestHooks.Register(func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.IncomingRequestHookActions) { + hookActions.ValidateRequest() + }) + blkIndex := 0 + blockCount := 3 + td.blockHooks.Register(func(p peer.ID, requestData graphsync.RequestData, blockData graphsync.BlockData, hookActions graphsync.OutgoingBlockHookActions) { + blkIndex++ + if blkIndex == blockCount { + hookActions.PauseResponse() + } + }) + td.updateHooks.Register(func(p peer.ID, requestData graphsync.RequestData, updateData graphsync.RequestData, hookActions graphsync.RequestUpdatedHookActions) { + if _, found := updateData.Extension(td.extensionName); found { + hookActions.TerminateWithError(errors.New("something went wrong")) + } + }) + responseManager.ProcessRequests(td.ctx, td.p, td.requests) + var sentResponses []sentResponse + for i := 0; i < blockCount; i++ { + testutil.AssertDoesReceive(td.ctx, t, td.sentResponses, "should sent block") + } + testutil.AssertChannelEmpty(t, td.sentResponses, "should not send more blocks") + var pausedRequest pausedRequest + testutil.AssertReceive(td.ctx, t, td.pausedRequests, &pausedRequest, "should pause request") + require.LessOrEqual(t, len(sentResponses), blockCount) + + // send update + responseManager.ProcessRequests(td.ctx, td.p, td.updateRequests) + + // should terminate + var lastRequest completedRequest + testutil.AssertReceive(td.ctx, t, td.completedRequestChan, &lastRequest, "should complete request") + require.True(t, gsmsg.IsTerminalFailureCode(lastRequest.result), "request should fail") + + // cannot unpause + err := responseManager.UnpauseResponse(td.p, td.requestID) + require.Error(t, err) + }) + }) + }) } @@ -507,6 +712,7 @@ type testData struct { completedRequestChan chan completedRequest sentResponses chan sentResponse sentExtensions chan sentExtension + pausedRequests chan pausedRequest peerManager *fakePeerManager queryQueue *fakeQueryQueue extensionData []byte @@ -514,12 +720,16 @@ type testData struct { extension graphsync.ExtensionData extensionResponseData []byte extensionResponse graphsync.ExtensionData + extensionUpdateData []byte + extensionUpdate graphsync.ExtensionData requestID graphsync.RequestID requests []gsmsg.GraphSyncRequest + updateRequests []gsmsg.GraphSyncRequest p peer.ID peristenceOptions *persistenceoptions.PersistenceOptions - requestHooks *requesthooks.IncomingRequestHooks - blockHooks *blockhooks.OutgoingBlockHooks + requestHooks *hooks.IncomingRequestHooks + blockHooks *hooks.OutgoingBlockHooks + updateHooks *hooks.RequestUpdatedHooks } func newTestData(t *testing.T) testData { @@ -535,7 +745,8 @@ func newTestData(t *testing.T) testData { td.completedRequestChan = make(chan completedRequest, 1) td.sentResponses = make(chan sentResponse, td.blockChainLength*2) td.sentExtensions = make(chan sentExtension, td.blockChainLength*2) - fprs := &fakePeerResponseSender{lastCompletedRequest: td.completedRequestChan, sentResponses: td.sentResponses, sentExtensions: td.sentExtensions} + td.pausedRequests = make(chan pausedRequest, 1) + fprs := &fakePeerResponseSender{lastCompletedRequest: td.completedRequestChan, sentResponses: td.sentResponses, sentExtensions: td.sentExtensions, pausedRequests: td.pausedRequests} td.peerManager = &fakePeerManager{peerResponseSender: fprs} td.queryQueue = &fakeQueryQueue{} @@ -550,14 +761,22 @@ func newTestData(t *testing.T) testData { Name: td.extensionName, Data: td.extensionResponseData, } - + td.extensionUpdateData = testutil.RandomBytes(100) + td.extensionUpdate = graphsync.ExtensionData{ + Name: td.extensionName, + Data: td.extensionUpdateData, + } td.requestID = graphsync.RequestID(rand.Int31()) td.requests = []gsmsg.GraphSyncRequest{ gsmsg.NewRequest(td.requestID, td.blockChain.TipLink.(cidlink.Link).Cid, td.blockChain.Selector(), graphsync.Priority(0), td.extension), } + td.updateRequests = []gsmsg.GraphSyncRequest{ + gsmsg.UpdateRequest(td.requestID, td.extensionUpdate), + } td.p = testutil.GeneratePeers(1)[0] td.peristenceOptions = persistenceoptions.New() - td.requestHooks = requesthooks.New(td.peristenceOptions) - td.blockHooks = blockhooks.New() + td.requestHooks = hooks.NewRequestHooks(td.peristenceOptions) + td.blockHooks = hooks.NewBlockHooks() + td.updateHooks = hooks.NewUpdateHooks() return td }