diff --git a/graphsync.go b/graphsync.go index 17b8e934..bae45ed8 100644 --- a/graphsync.go +++ b/graphsync.go @@ -171,9 +171,22 @@ type RequestData interface { Extension(name ExtensionName) (datamodel.Node, bool) // IsCancel returns true if this particular request is being cancelled - IsCancel() bool + Type() RequestType } +type RequestType string + +const ( + // RequestTypeNew means a new request + RequestTypeNew = RequestType("New") + + // RequestTypeCancel means cancel the request referenced by request ID + RequestTypeCancel = RequestType("Cancel") + + // RequestTypeUpdate means the extensions contain an update about this request + RequestTypeUpdate = RequestType("Update") +) + // ResponseData describes a received Graphsync response type ResponseData interface { // RequestID returns the request ID for this response diff --git a/message/ipldbind/message.go b/message/ipldbind/message.go index 849100af..4c273609 100644 --- a/message/ipldbind/message.go +++ b/message/ipldbind/message.go @@ -45,12 +45,11 @@ func (gse GraphSyncExtensions) ToExtensionsList() []graphsync.ExtensionData { type GraphSyncRequest struct { Id []byte - Root *cid.Cid - Selector *datamodel.Node - Extensions GraphSyncExtensions - Priority graphsync.Priority - Cancel bool - Update bool + Root *cid.Cid + Selector *datamodel.Node + Extensions GraphSyncExtensions + Priority graphsync.Priority + RequestType graphsync.RequestType } // GraphSyncResponse is an struct to capture data on a response sent back diff --git a/message/ipldbind/schema.ipldsch b/message/ipldbind/schema.ipldsch index 0cbdaebe..7d49bf81 100644 --- a/message/ipldbind/schema.ipldsch +++ b/message/ipldbind/schema.ipldsch @@ -50,14 +50,25 @@ type GraphSyncResponseStatusCode enum { | RequestCancelled ("35") } representation int +type GraphSyncRequestType enum { + # New means a new request + | New ("n") + # Cancel means cancel the request referenced by request ID + | Cancel ("c") + # Update means the extensions contain an update about this request + | Update ("u") + # Restart means restart this request from the begging, respecting the any DoNotSendCids/DoNotSendBlocks contained + # in the extensions -- essentially a cancel followed by a new + # TODO: | Restart ("r") +} representation string + type GraphSyncRequest struct { - id GraphSyncRequestID (rename "ID") # unique id set on the requester side - root optional Link (rename "Root") # a CID for the root node in the query - selector optional Any (rename "Sel") # see https://github.com/ipld/specs/blob/master/selectors/selectors.md - extensions GraphSyncExtensions (rename "Ext") # side channel information - priority GraphSyncPriority (rename "Pri") # the priority (normalized). default to 1 - cancel Bool (rename "Canc") # whether this cancels a request - update Bool (rename "Updt") # whether this is an update to an in progress request + id GraphSyncRequestID (rename "ID") # unique id set on the requester side + root optional Link (rename "Root") # a CID for the root node in the query + selector optional Any (rename "Sel") # see https://github.com/ipld/specs/blob/master/selectors/selectors.md + extensions GraphSyncExtensions (rename "Ext") # side channel information + priority GraphSyncPriority (rename "Pri") # the priority (normalized). default to 1 + requestType GraphSyncRequestType (rename "Typ") # the request type } representation map type GraphSyncResponse struct { diff --git a/message/message.go b/message/message.go index c28f946c..9fa471e9 100644 --- a/message/message.go +++ b/message/message.go @@ -35,13 +35,12 @@ type MessagePartWithExtensions interface { // GraphSyncRequest is a struct to capture data on a request contained in a // GraphSyncMessage. type GraphSyncRequest struct { - root cid.Cid - selector ipld.Node - priority graphsync.Priority - id graphsync.RequestID - extensions map[string]datamodel.Node - isCancel bool - isUpdate bool + root cid.Cid + selector ipld.Node + priority graphsync.Priority + id graphsync.RequestID + extensions map[string]datamodel.Node + requestType graphsync.RequestType } // String returns a human-readable form of a GraphSyncRequest @@ -57,13 +56,12 @@ func (gsr GraphSyncRequest) String() string { extStr.WriteString(string(name)) extStr.WriteString("|") } - return fmt.Sprintf("GraphSyncRequest", + return fmt.Sprintf("GraphSyncRequest", gsr.root.String(), sel, gsr.priority, gsr.id.String(), - gsr.isCancel, - gsr.isUpdate, + gsr.requestType, extStr.String(), ) } @@ -146,17 +144,17 @@ func NewRequest(id graphsync.RequestID, priority graphsync.Priority, extensions ...graphsync.ExtensionData) GraphSyncRequest { - return newRequest(id, root, selector, priority, false, false, toExtensionsMap(extensions)) + return newRequest(id, root, selector, priority, graphsync.RequestTypeNew, toExtensionsMap(extensions)) } // NewCancelRequest request generates a request to cancel an in progress request func NewCancelRequest(id graphsync.RequestID) GraphSyncRequest { - return newRequest(id, cid.Cid{}, nil, 0, true, false, nil) + return newRequest(id, cid.Cid{}, nil, 0, graphsync.RequestTypeCancel, nil) } // NewUpdateRequest generates a new request to update an in progress request with the given extensions func NewUpdateRequest(id graphsync.RequestID, extensions ...graphsync.ExtensionData) GraphSyncRequest { - return newRequest(id, cid.Cid{}, nil, 0, false, true, toExtensionsMap(extensions)) + return newRequest(id, cid.Cid{}, nil, 0, graphsync.RequestTypeUpdate, toExtensionsMap(extensions)) } // NewLinkMetadata generates a new graphsync.LinkMetadata compatible object, @@ -179,18 +177,16 @@ func newRequest(id graphsync.RequestID, root cid.Cid, selector ipld.Node, priority graphsync.Priority, - isCancel bool, - isUpdate bool, + requestType graphsync.RequestType, extensions map[string]datamodel.Node) GraphSyncRequest { return GraphSyncRequest{ - id: id, - root: root, - selector: selector, - priority: priority, - isCancel: isCancel, - isUpdate: isUpdate, - extensions: extensions, + id: id, + root: root, + selector: selector, + priority: priority, + requestType: requestType, + extensions: extensions, } } @@ -310,11 +306,8 @@ func (gsr GraphSyncRequest) ExtensionNames() []graphsync.ExtensionName { return extNames } -// IsCancel returns true if this particular request is being cancelled -func (gsr GraphSyncRequest) IsCancel() bool { return gsr.isCancel } - -// IsUpdate returns true if this particular request is being updated -func (gsr GraphSyncRequest) IsUpdate() bool { return gsr.isUpdate } +// RequestType returns the type of this request (new, cancel, update, etc.) +func (gsr GraphSyncRequest) Type() graphsync.RequestType { return gsr.requestType } // RequestID returns the request ID for this response func (gsr GraphSyncResponse) RequestID() graphsync.RequestID { return gsr.requestID } @@ -379,7 +372,7 @@ func (gsr GraphSyncRequest) ReplaceExtensions(extensions []graphsync.ExtensionDa // the result func (gsr GraphSyncRequest) MergeExtensions(extensions []graphsync.ExtensionData, mergeFunc func(name graphsync.ExtensionName, oldData datamodel.Node, newData datamodel.Node) (datamodel.Node, error)) (GraphSyncRequest, error) { if gsr.extensions == nil { - return newRequest(gsr.id, gsr.root, gsr.selector, gsr.priority, gsr.isCancel, gsr.isUpdate, toExtensionsMap(extensions)), nil + return newRequest(gsr.id, gsr.root, gsr.selector, gsr.priority, gsr.requestType, toExtensionsMap(extensions)), nil } newExtensionMap := toExtensionsMap(extensions) combinedExtensions := make(map[string]datamodel.Node) @@ -403,5 +396,5 @@ func (gsr GraphSyncRequest) MergeExtensions(extensions []graphsync.ExtensionData } combinedExtensions[name] = oldData } - return newRequest(gsr.id, gsr.root, gsr.selector, gsr.priority, gsr.isCancel, gsr.isUpdate, combinedExtensions), nil + return newRequest(gsr.id, gsr.root, gsr.selector, gsr.priority, gsr.requestType, combinedExtensions), nil } diff --git a/message/v1/message.go b/message/v1/message.go index 713b4580..0554c355 100644 --- a/message/v1/message.go +++ b/message/v1/message.go @@ -119,13 +119,14 @@ func (mh *MessageHandler) ToProto(p peer.ID, gsm message.GraphSyncMessage) (*pb. if err != nil { return nil, err } + pbm.Requests = append(pbm.Requests, &pb.Message_Request{ Id: rid, Root: request.Root().Bytes(), Selector: selector, Priority: int32(request.Priority()), - Cancel: request.IsCancel(), - Update: request.IsUpdate(), + Cancel: request.Type() == graphsync.RequestTypeCancel, + Update: request.Type() == graphsync.RequestTypeUpdate, Extensions: ext, }) } diff --git a/message/v1/message_test.go b/message/v1/message_test.go index e19acf5b..22399b8a 100644 --- a/message/v1/message_test.go +++ b/message/v1/message_test.go @@ -41,7 +41,7 @@ func TestAppendingRequests(t *testing.T) { request := requests[0] extensionData, found := request.Extension(extensionName) require.Equal(t, id, request.ID()) - require.False(t, request.IsCancel()) + require.Equal(t, request.Type(), graphsync.RequestTypeNew) require.Equal(t, priority, request.Priority()) require.Equal(t, root.String(), request.Root().String()) require.Equal(t, selector, request.Selector()) @@ -76,8 +76,7 @@ func TestAppendingRequests(t *testing.T) { deserializedRequest := deserializedRequests[0] extensionData, found = deserializedRequest.Extension(extensionName) require.Equal(t, id, deserializedRequest.ID()) - require.False(t, deserializedRequest.IsCancel()) - require.False(t, deserializedRequest.IsUpdate()) + require.Equal(t, deserializedRequest.Type(), graphsync.RequestTypeNew) require.Equal(t, priority, deserializedRequest.Priority()) require.Equal(t, root.String(), deserializedRequest.Root().String()) require.Equal(t, selector, deserializedRequest.Selector()) @@ -179,7 +178,7 @@ func TestRequestCancel(t *testing.T) { require.Len(t, requests, 1, "did not add cancel request") request := requests[0] require.Equal(t, id, request.ID()) - require.True(t, request.IsCancel()) + require.Equal(t, request.Type(), graphsync.RequestTypeCancel) mh := NewMessageHandler() @@ -192,7 +191,7 @@ func TestRequestCancel(t *testing.T) { require.Len(t, deserializedRequests, 1, "did not add request to deserialized message") deserializedRequest := deserializedRequests[0] require.Equal(t, request.ID(), deserializedRequest.ID()) - require.Equal(t, request.IsCancel(), deserializedRequest.IsCancel()) + require.Equal(t, request.Type(), deserializedRequest.Type()) } func TestRequestUpdate(t *testing.T) { @@ -213,8 +212,7 @@ func TestRequestUpdate(t *testing.T) { require.Len(t, requests, 1, "did not add cancel request") request := requests[0] require.Equal(t, id, request.ID()) - require.True(t, request.IsUpdate()) - require.False(t, request.IsCancel()) + require.Equal(t, request.Type(), graphsync.RequestTypeUpdate) extensionData, found := request.Extension(extensionName) require.True(t, found) require.Equal(t, extension.Data, extensionData) @@ -232,8 +230,7 @@ func TestRequestUpdate(t *testing.T) { deserializedRequest := deserializedRequests[0] extensionData, found = deserializedRequest.Extension(extensionName) require.Equal(t, request.ID(), deserializedRequest.ID()) - require.Equal(t, request.IsCancel(), deserializedRequest.IsCancel()) - require.Equal(t, request.IsUpdate(), deserializedRequest.IsUpdate()) + require.Equal(t, request.Type(), deserializedRequest.Type()) require.Equal(t, request.Priority(), deserializedRequest.Priority()) require.Equal(t, request.Root().String(), deserializedRequest.Root().String()) require.Equal(t, request.Selector(), deserializedRequest.Selector()) @@ -281,8 +278,7 @@ func TestToNetFromNetEquivalency(t *testing.T) { deserializedRequest := deserializedRequests[0] extensionData, found := deserializedRequest.Extension(extensionName) require.Equal(t, request.ID(), deserializedRequest.ID()) - require.False(t, deserializedRequest.IsCancel()) - require.False(t, deserializedRequest.IsUpdate()) + require.Equal(t, deserializedRequest.Type(), graphsync.RequestTypeNew) require.Equal(t, request.Priority(), deserializedRequest.Priority()) require.Equal(t, request.Root().String(), deserializedRequest.Root().String()) require.Equal(t, request.Selector(), deserializedRequest.Selector()) diff --git a/message/v2/message.go b/message/v2/message.go index 29098d24..fa04d755 100644 --- a/message/v2/message.go +++ b/message/v2/message.go @@ -70,13 +70,12 @@ func (mh *MessageHandler) toIPLD(gsm message.GraphSyncMessage) (*ipldbind.GraphS rootPtr = nil } ibm.Requests = append(ibm.Requests, ipldbind.GraphSyncRequest{ - Id: request.ID().Bytes(), - Root: rootPtr, - Selector: selPtr, - Priority: request.Priority(), - Cancel: request.IsCancel(), - Update: request.IsUpdate(), - Extensions: ipldbind.NewGraphSyncExtensions(request), + Id: request.ID().Bytes(), + Root: rootPtr, + Selector: selPtr, + Priority: request.Priority(), + RequestType: request.Type(), + Extensions: ipldbind.NewGraphSyncExtensions(request), }) } @@ -142,12 +141,12 @@ func (mh *MessageHandler) fromIPLD(ibm *ipldbind.GraphSyncMessage) (message.Grap return message.GraphSyncMessage{}, err } - if req.Cancel { + if req.RequestType == graphsync.RequestTypeCancel { requests[id] = message.NewCancelRequest(id) continue } - if req.Update { + if req.RequestType == graphsync.RequestTypeUpdate { requests[id] = message.NewUpdateRequest(id, req.Extensions.ToExtensionsList()...) continue } diff --git a/message/v2/message_test.go b/message/v2/message_test.go index b9c3c740..7e9acad8 100644 --- a/message/v2/message_test.go +++ b/message/v2/message_test.go @@ -40,7 +40,7 @@ func TestAppendingRequests(t *testing.T) { request := requests[0] extensionData, found := request.Extension(extensionName) require.Equal(t, id, request.ID()) - require.False(t, request.IsCancel()) + require.Equal(t, request.Type(), graphsync.RequestTypeNew) require.Equal(t, priority, request.Priority()) require.Equal(t, root.String(), request.Root().String()) require.Equal(t, selector, request.Selector()) @@ -55,8 +55,7 @@ func TestAppendingRequests(t *testing.T) { gsrIpld := gsmIpld.Requests[0] require.Equal(t, priority, gsrIpld.Priority) - require.False(t, gsrIpld.Cancel) - require.False(t, gsrIpld.Update) + require.Equal(t, request.Type(), graphsync.RequestTypeNew) require.Equal(t, root, *gsrIpld.Root) require.Equal(t, selector, *gsrIpld.Selector) require.Equal(t, 1, len(gsrIpld.Extensions.Keys)) @@ -72,8 +71,7 @@ func TestAppendingRequests(t *testing.T) { deserializedRequest := deserializedRequests[0] extensionData, found = deserializedRequest.Extension(extensionName) require.Equal(t, id, deserializedRequest.ID()) - require.False(t, deserializedRequest.IsCancel()) - require.False(t, deserializedRequest.IsUpdate()) + require.Equal(t, request.Type(), graphsync.RequestTypeNew) require.Equal(t, priority, deserializedRequest.Priority()) require.Equal(t, root.String(), deserializedRequest.Root().String()) require.Equal(t, selector, deserializedRequest.Selector()) @@ -174,7 +172,7 @@ func TestRequestCancel(t *testing.T) { require.Len(t, requests, 1, "did not add cancel request") request := requests[0] require.Equal(t, id, request.ID()) - require.True(t, request.IsCancel()) + require.Equal(t, request.Type(), graphsync.RequestTypeCancel) mh := NewMessageHandler() @@ -187,7 +185,7 @@ func TestRequestCancel(t *testing.T) { require.Len(t, deserializedRequests, 1, "did not add request to deserialized message") deserializedRequest := deserializedRequests[0] require.Equal(t, request.ID(), deserializedRequest.ID()) - require.Equal(t, request.IsCancel(), deserializedRequest.IsCancel()) + require.Equal(t, request.Type(), deserializedRequest.Type()) } func TestRequestUpdate(t *testing.T) { @@ -208,8 +206,7 @@ func TestRequestUpdate(t *testing.T) { require.Len(t, requests, 1, "did not add cancel request") request := requests[0] require.Equal(t, id, request.ID()) - require.True(t, request.IsUpdate()) - require.False(t, request.IsCancel()) + require.Equal(t, request.Type(), graphsync.RequestTypeUpdate) extensionData, found := request.Extension(extensionName) require.True(t, found) require.Equal(t, extension.Data, extensionData) @@ -227,8 +224,7 @@ func TestRequestUpdate(t *testing.T) { deserializedRequest := deserializedRequests[0] extensionData, found = deserializedRequest.Extension(extensionName) require.Equal(t, request.ID(), deserializedRequest.ID()) - require.Equal(t, request.IsCancel(), deserializedRequest.IsCancel()) - require.Equal(t, request.IsUpdate(), deserializedRequest.IsUpdate()) + require.Equal(t, request.Type(), deserializedRequest.Type()) require.Equal(t, request.Priority(), deserializedRequest.Priority()) require.Equal(t, request.Root().String(), deserializedRequest.Root().String()) require.Equal(t, request.Selector(), deserializedRequest.Selector()) @@ -276,8 +272,7 @@ func TestToNetFromNetEquivalency(t *testing.T) { deserializedRequest := deserializedRequests[0] extensionData, found := deserializedRequest.Extension(extensionName) require.Equal(t, request.ID(), deserializedRequest.ID()) - require.False(t, deserializedRequest.IsCancel()) - require.False(t, deserializedRequest.IsUpdate()) + require.Equal(t, request.Type(), graphsync.RequestTypeNew) require.Equal(t, request.Priority(), deserializedRequest.Priority()) require.Equal(t, request.Root().String(), deserializedRequest.Root().String()) require.Equal(t, request.Selector(), deserializedRequest.Selector()) diff --git a/messagequeue/messagequeue_test.go b/messagequeue/messagequeue_test.go index f182f594..bf5b8ec5 100644 --- a/messagequeue/messagequeue_test.go +++ b/messagequeue/messagequeue_test.go @@ -231,7 +231,7 @@ func TestDedupingMessages(t *testing.T) { require.Len(t, requests, 1, "number of requests in first message was not 1") request := requests[0] require.Equal(t, id, request.ID()) - require.False(t, request.IsCancel()) + require.Equal(t, request.Type(), graphsync.RequestTypeNew) require.Equal(t, priority, request.Priority()) require.Equal(t, selector, request.Selector()) @@ -241,11 +241,11 @@ func TestDedupingMessages(t *testing.T) { require.Len(t, requests, 2, "number of requests in second message was not 2") for _, request := range requests { if request.ID() == id2 { - require.False(t, request.IsCancel()) + require.Equal(t, request.Type(), graphsync.RequestTypeNew) require.Equal(t, priority2, request.Priority()) require.Equal(t, selector2, request.Selector()) } else if request.ID() == id3 { - require.False(t, request.IsCancel()) + require.Equal(t, request.Type(), graphsync.RequestTypeNew) require.Equal(t, priority3, request.Priority()) require.Equal(t, selector3, request.Selector()) } else { diff --git a/network/libp2p_impl_test.go b/network/libp2p_impl_test.go index a7cdebc3..7a67dafe 100644 --- a/network/libp2p_impl_test.go +++ b/network/libp2p_impl_test.go @@ -107,8 +107,7 @@ func TestMessageSendAndReceive(t *testing.T) { require.Len(t, receivedRequests, 1, "did not add request to received message") receivedRequest := receivedRequests[0] require.Equal(t, sentRequest.ID(), receivedRequest.ID()) - require.Equal(t, sentRequest.IsCancel(), receivedRequest.IsCancel()) - require.Equal(t, sentRequest.Priority(), receivedRequest.Priority()) + require.Equal(t, sentRequest.Type(), receivedRequest.Type()) require.Equal(t, sentRequest.Root().String(), receivedRequest.Root().String()) require.Equal(t, sentRequest.Selector(), receivedRequest.Selector()) diff --git a/peermanager/peermessagemanager_test.go b/peermanager/peermessagemanager_test.go index d08144c1..9a13c877 100644 --- a/peermanager/peermessagemanager_test.go +++ b/peermanager/peermessagemanager_test.go @@ -92,7 +92,7 @@ func TestSendingMessagesToPeers(t *testing.T) { require.Equal(t, tp[0], firstMessage.p, "first message sent to incorrect peer") request = firstMessage.message.Requests()[0] require.Equal(t, id, request.ID()) - require.False(t, request.IsCancel()) + require.Equal(t, request.Type(), graphsync.RequestTypeNew) require.Equal(t, priority, request.Priority()) require.Equal(t, selector, request.Selector()) @@ -101,7 +101,7 @@ func TestSendingMessagesToPeers(t *testing.T) { require.Equal(t, tp[1], secondMessage.p, "second message sent to incorrect peer") request = secondMessage.message.Requests()[0] require.Equal(t, id, request.ID()) - require.False(t, request.IsCancel()) + require.Equal(t, request.Type(), graphsync.RequestTypeNew) require.Equal(t, priority, request.Priority()) require.Equal(t, selector, request.Selector()) @@ -111,7 +111,7 @@ func TestSendingMessagesToPeers(t *testing.T) { require.Equal(t, tp[0], thirdMessage.p, "third message sent to incorrect peer") request = thirdMessage.message.Requests()[0] require.Equal(t, id, request.ID()) - require.True(t, request.IsCancel()) + require.Equal(t, request.Type(), graphsync.RequestTypeCancel) connectedPeers := peerManager.ConnectedPeers() require.Len(t, connectedPeers, 2) diff --git a/requestmanager/executor/executor_test.go b/requestmanager/executor/executor_test.go index fc0c57f3..b25adb98 100644 --- a/requestmanager/executor/executor_test.go +++ b/requestmanager/executor/executor_test.go @@ -71,7 +71,7 @@ func TestRequestExecutionBlockChain(t *testing.T) { require.Regexp(t, "something went wrong", receivedErrors[0].Error()) require.Len(t, ree.requestsSent, 2) require.Equal(t, ree.request, ree.requestsSent[0].request) - require.True(t, ree.requestsSent[1].request.IsCancel()) + require.Equal(t, ree.requestsSent[1].request.Type(), graphsync.RequestTypeCancel) require.Len(t, ree.blookHooksCalled, 6) require.EqualError(t, ree.terminalError, "something went wrong") }, @@ -97,7 +97,7 @@ func TestRequestExecutionBlockChain(t *testing.T) { require.Empty(t, receivedErrors) require.Len(t, ree.requestsSent, 2) require.Equal(t, ree.request, ree.requestsSent[0].request) - require.True(t, ree.requestsSent[1].request.IsCancel()) + require.Equal(t, ree.requestsSent[1].request.Type(), graphsync.RequestTypeCancel) require.Len(t, ree.blookHooksCalled, 6) require.EqualError(t, ree.terminalError, hooks.ErrPaused{}.Error()) }, @@ -129,7 +129,7 @@ func TestRequestExecutionBlockChain(t *testing.T) { tbc.VerifyResponseRangeSync(responses, 0, 6) require.Empty(t, receivedErrors) require.Equal(t, ree.request, ree.requestsSent[0].request) - require.True(t, ree.requestsSent[1].request.IsCancel()) + require.Equal(t, ree.requestsSent[1].request.Type(), graphsync.RequestTypeCancel) require.Len(t, ree.blookHooksCalled, 6) require.EqualError(t, ree.terminalError, hooks.ErrPaused{}.Error()) }, @@ -165,7 +165,7 @@ func TestRequestExecutionBlockChain(t *testing.T) { require.Regexp(t, "something went wrong", receivedErrors[0].Error()) require.Len(t, ree.requestsSent, 2) require.Equal(t, ree.request, ree.requestsSent[0].request) - require.True(t, ree.requestsSent[1].request.IsCancel()) + require.Equal(t, ree.requestsSent[1].request.Type(), graphsync.RequestTypeCancel) require.Len(t, ree.blookHooksCalled, 6) require.EqualError(t, ree.terminalError, "something went wrong") }, @@ -179,7 +179,7 @@ func TestRequestExecutionBlockChain(t *testing.T) { require.Empty(t, receivedErrors) require.Len(t, ree.requestsSent, 2) require.Equal(t, ree.request, ree.requestsSent[0].request) - require.True(t, ree.requestsSent[1].request.IsUpdate()) + require.Equal(t, ree.requestsSent[1].request.Type(), graphsync.RequestTypeUpdate) data, has := ree.requestsSent[1].request.Extension("something") require.True(t, has) str, _ := data.AsString() @@ -326,7 +326,7 @@ func (ree *requestExecutionEnv) GetRequestTask(_ peer.ID, _ *peertask.Task, requ func (ree *requestExecutionEnv) SendRequest(p peer.ID, request gsmsg.GraphSyncRequest) { ree.requestsSent = append(ree.requestsSent, requestSent{p, request}) - if !request.IsCancel() && !request.IsUpdate() { + if request.Type() == graphsync.RequestTypeNew { if ree.customRemoteBehavior == nil { ree.fal.SuccessResponseOn(p, request.ID(), ree.tbc.Blocks(ree.loadLocallyUntil, len(ree.tbc.AllBlocks()))) } else { diff --git a/requestmanager/requestmanager_test.go b/requestmanager/requestmanager_test.go index e69f3ed7..6e5fce3c 100644 --- a/requestmanager/requestmanager_test.go +++ b/requestmanager/requestmanager_test.go @@ -49,8 +49,7 @@ func TestNormalSimultaneousFetch(t *testing.T) { td.tcm.AssertProtectedWithTags(t, peers[0], requestRecords[0].gsr.ID().Tag(), requestRecords[1].gsr.ID().Tag()) require.Equal(t, peers[0], requestRecords[0].p) require.Equal(t, peers[0], requestRecords[1].p) - require.False(t, requestRecords[0].gsr.IsCancel()) - require.False(t, requestRecords[1].gsr.IsCancel()) + require.Equal(t, requestRecords[0].gsr.Type(), graphsync.RequestTypeNew) require.Equal(t, defaultPriority, requestRecords[0].gsr.Priority()) require.Equal(t, defaultPriority, requestRecords[1].gsr.Priority()) @@ -137,7 +136,7 @@ func TestCancelRequestInProgress(t *testing.T) { cancel1() rr := readNNetworkRequests(requestCtx, t, td, 1)[0] - require.True(t, rr.gsr.IsCancel()) + require.Equal(t, rr.gsr.Type(), graphsync.RequestTypeCancel) require.Equal(t, requestRecords[0].gsr.ID(), rr.gsr.ID()) moreBlocks := td.blockChain.RemainderBlocks(3) @@ -204,7 +203,7 @@ func TestCancelRequestImperativeNoMoreBlocks(t *testing.T) { rr := readNNetworkRequests(requestCtx, t, td, 1)[0] - require.True(t, rr.gsr.IsCancel()) + require.Equal(t, rr.gsr.Type(), graphsync.RequestTypeCancel) require.Equal(t, requestRecords[0].gsr.ID(), rr.gsr.ID()) td.tcm.RefuteProtected(t, peers[0]) @@ -827,7 +826,7 @@ func TestPauseResume(t *testing.T) { // read the outgoing cancel request pauseCancel := readNNetworkRequests(requestCtx, t, td, 1)[0] - require.True(t, pauseCancel.gsr.IsCancel()) + require.Equal(t, pauseCancel.gsr.Type(), graphsync.RequestTypeCancel) // verify no further responses come through time.Sleep(100 * time.Millisecond) @@ -902,7 +901,7 @@ func TestPauseResumeExternal(t *testing.T) { // read the outgoing cancel request pauseCancel := readNNetworkRequests(requestCtx, t, td, 1)[0] - require.True(t, pauseCancel.gsr.IsCancel()) + require.Equal(t, pauseCancel.gsr.Type(), graphsync.RequestTypeCancel) // verify no further responses come through time.Sleep(100 * time.Millisecond) diff --git a/responsemanager/server.go b/responsemanager/server.go index f71808bb..80758d04 100644 --- a/responsemanager/server.go +++ b/responsemanager/server.go @@ -190,13 +190,14 @@ func (rm *ResponseManager) processRequests(p peer.ID, requests []gsmsg.GraphSync for _, request := range requests { key := responseKey{p: p, requestID: request.ID()} - if request.IsCancel() { + switch request.Type() { + case graphsync.RequestTypeCancel: _ = rm.abortRequest(ctx, p, request.ID(), ipldutil.ContextCancelError{}) continue - } - if request.IsUpdate() { + case graphsync.RequestTypeUpdate: rm.processUpdate(ctx, key, request) continue + default: } rm.connManager.Protect(p, request.ID().Tag()) // don't use `ctx` which has the "message" trace, but rm.ctx for a fresh trace which allows