diff --git a/benchmarks/testinstance/testinstance.go b/benchmarks/testinstance/testinstance.go index bfc72747..685cd264 100644 --- a/benchmarks/testinstance/testinstance.go +++ b/benchmarks/testinstance/testinstance.go @@ -164,7 +164,7 @@ func NewInstance(ctx context.Context, net tn.Network, tempDir string, diskBasedD linkSystem := storeutil.LinkSystemForBlockstore(bstore) gs := gsimpl.New(ctx, gsNet, linkSystem, gsimpl.RejectAllRequestsByDefault()) - transport := gstransport.NewTransport(p, gs, dtNet) + transport := gstransport.NewTransport(gs, dtNet) dt, err := dtimpl.NewDataTransfer(namespace.Wrap(dstore, datastore.NewKey("/data-transfers/transfers")), p, transport) if err != nil { return Instance{}, err diff --git a/channels/caches.go b/channels/caches.go deleted file mode 100644 index 2e2d75a9..00000000 --- a/channels/caches.go +++ /dev/null @@ -1,133 +0,0 @@ -package channels - -import ( - "sync" - "sync/atomic" - - datatransfer "github.com/filecoin-project/go-data-transfer/v2" -) - -type readIndexFn func(datatransfer.ChannelID) (int64, error) - -type cacheKey struct { - evt datatransfer.EventCode - chid datatransfer.ChannelID -} - -type blockIndexCache struct { - lk sync.RWMutex - values map[cacheKey]*int64 -} - -func newBlockIndexCache() *blockIndexCache { - return &blockIndexCache{ - values: make(map[cacheKey]*int64), - } -} - -func (bic *blockIndexCache) getValue(evt datatransfer.EventCode, chid datatransfer.ChannelID, readFromOriginal readIndexFn) (*int64, error) { - idxKey := cacheKey{evt, chid} - bic.lk.RLock() - value := bic.values[idxKey] - bic.lk.RUnlock() - if value != nil { - return value, nil - } - bic.lk.Lock() - defer bic.lk.Unlock() - value = bic.values[idxKey] - if value != nil { - return value, nil - } - newValue, err := readFromOriginal(chid) - if err != nil { - return nil, err - } - bic.values[idxKey] = &newValue - return &newValue, nil -} - -func (bic *blockIndexCache) updateIfGreater(evt datatransfer.EventCode, chid datatransfer.ChannelID, newIndex int64, readFromOriginal readIndexFn) (bool, error) { - value, err := bic.getValue(evt, chid, readFromOriginal) - if err != nil { - return false, err - } - for { - currentIndex := atomic.LoadInt64(value) - if newIndex <= currentIndex { - return false, nil - } - if atomic.CompareAndSwapInt64(value, currentIndex, newIndex) { - return true, nil - } - } -} - -type progressState struct { - dataLimit uint64 - progress *uint64 -} - -type readProgressFn func(datatransfer.ChannelID) (dataLimit uint64, progress uint64, err error) - -type progressCache struct { - lk sync.RWMutex - values map[datatransfer.ChannelID]progressState -} - -func newProgressCache() *progressCache { - return &progressCache{ - values: make(map[datatransfer.ChannelID]progressState), - } -} - -func (pc *progressCache) getValue(chid datatransfer.ChannelID, readProgress readProgressFn) (progressState, error) { - pc.lk.RLock() - value, ok := pc.values[chid] - pc.lk.RUnlock() - if ok { - return value, nil - } - pc.lk.Lock() - defer pc.lk.Unlock() - value, ok = pc.values[chid] - if ok { - return value, nil - } - dataLimit, progress, err := readProgress(chid) - if err != nil { - return progressState{}, err - } - newValue := progressState{ - dataLimit: dataLimit, - progress: &progress, - } - pc.values[chid] = newValue - return newValue, nil -} - -func (pc *progressCache) progress(chid datatransfer.ChannelID, additionalData uint64, readFromOriginal readProgressFn) (bool, error) { - state, err := pc.getValue(chid, readFromOriginal) - if err != nil { - return false, 
err - } - total := atomic.AddUint64(state.progress, additionalData) - return state.dataLimit != 0 && total >= state.dataLimit, nil -} - -func (pc *progressCache) setDataLimit(chid datatransfer.ChannelID, newLimit uint64) { - pc.lk.RLock() - value, ok := pc.values[chid] - pc.lk.RUnlock() - if !ok { - return - } - pc.lk.Lock() - defer pc.lk.Unlock() - value, ok = pc.values[chid] - if !ok { - return - } - value.dataLimit = newLimit - pc.values[chid] = value -} diff --git a/channels/channel_state.go b/channels/channel_state.go index fec337ba..efabd74d 100644 --- a/channels/channel_state.go +++ b/channels/channel_state.go @@ -50,22 +50,22 @@ func (c channelState) Voucher() datatransfer.TypedVoucher { return datatransfer.TypedVoucher{Voucher: ev.Voucher.Node, Type: ev.Type} } -// ReceivedCidsTotal returns the number of (non-unique) cids received so far -// on the channel - note that a block can exist in more than one place in the DAG -func (c channelState) ReceivedCidsTotal() int64 { - return c.ic.ReceivedBlocksTotal +// ReceivedIndex returns the index, a transport specific identifier for "where" +// we are in receiving data for a transfer +func (c channelState) ReceivedIndex() datamodel.Node { + return c.ic.ReceivedIndex.Node } -// QueuedCidsTotal returns the number of (non-unique) cids queued so far -// on the channel - note that a block can exist in more than one place in the DAG -func (c channelState) QueuedCidsTotal() int64 { - return c.ic.QueuedBlocksTotal +// QueuedIndex returns the index, a transport specific identifier for "where" +// we are in queueing data for a transfer +func (c channelState) QueuedIndex() datamodel.Node { + return c.ic.QueuedIndex.Node } -// SentCidsTotal returns the number of (non-unique) cids sent so far -// on the channel - note that a block can exist in more than one place in the DAG -func (c channelState) SentCidsTotal() int64 { - return c.ic.SentBlocksTotal +// SentIndex returns the index, a transport specific identifier for "where" +// we are in sending data for a transfer +func (c channelState) SentIndex() datamodel.Node { + return c.ic.SentIndex.Node } // Sender returns the peer id for the node that is sending data @@ -139,6 +139,25 @@ func (c channelState) RequiresFinalization() bool { return c.ic.RequiresFinalization } +func (c channelState) InitiatorPaused() bool { + return c.ic.InitiatorPaused +} + +func (c channelState) ResponderPaused() bool { + return c.ic.ResponderPaused || c.ic.Status == datatransfer.Finalizing +} + +func (c channelState) BothPaused() bool { + return c.InitiatorPaused() && c.ResponderPaused() +} + +func (c channelState) SelfPaused() bool { + if c.ic.SelfPeer == c.ic.Initiator { + return c.InitiatorPaused() + } + return c.ResponderPaused() +} + // Stages returns the current ChannelStages object, or an empty object. // It is unsafe for the caller to modify the return value, and changes may not // be persisted. It should be treated as immutable. 
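For orientation, a minimal sketch (not part of this diff) of how a caller might consume the new opaque index accessors. It assumes a transport whose index is a plain integer node, as the graphsync transport stores via basicnode.NewInt elsewhere in this change; the resumePoint helper and example package are hypothetical.

package example // hypothetical consumer of the new accessors

import (
	datatransfer "github.com/filecoin-project/go-data-transfer/v2"
)

// resumePoint recovers an integer resume offset from the opaque SentIndex
// node persisted on the channel state. A transport that encodes a richer
// index would decode its own node shape here instead. Illustrative only.
func resumePoint(chst datatransfer.ChannelState) (int64, error) {
	idx := chst.SentIndex()
	if idx == nil || idx.IsNull() {
		// nothing recorded yet; start from the beginning
		return 0, nil
	}
	return idx.AsInt()
}

Nothing in the channels package interprets the node; it is round-tripped verbatim through the FSM and the datastore, which is what lets each transport define "where" for itself.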
diff --git a/channels/channels.go b/channels/channels.go index 1d289aa2..bc00c420 100644 --- a/channels/channels.go +++ b/channels/channels.go @@ -29,8 +29,6 @@ var ErrWrongType = errors.New("Cannot change type of implementation specific dat // Channels is a thread safe list of channels type Channels struct { notifier Notifier - blockIndexCache *blockIndexCache - progressCache *progressCache stateMachines fsm.Group migrateStateMachines func(context.Context) error } @@ -48,8 +46,6 @@ func New(ds datastore.Batching, selfPeer peer.ID) (*Channels, error) { c := &Channels{notifier: notifier} - c.blockIndexCache = newBlockIndexCache() - c.progressCache = newProgressCache() channelMigrations, err := migrations.GetChannelStateMigrations(selfPeer) if err != nil { return nil, err @@ -62,7 +58,7 @@ func New(ds datastore.Batching, StateEntryFuncs: ChannelStateEntryFuncs, Notifier: c.dispatch, FinalityStates: ChannelFinalityStates, - }, channelMigrations, versioning.VersionKey("2")) + }, channelMigrations, versioning.VersionKey("3")) if err != nil { return nil, err } @@ -164,8 +160,8 @@ func (c *Channels) ChannelOpened(chid datatransfer.ChannelID) error { return c.send(chid, datatransfer.Opened) } -func (c *Channels) TransferRequestQueued(chid datatransfer.ChannelID) error { - return c.send(chid, datatransfer.TransferRequestQueued) +func (c *Channels) TransferInitiated(chid datatransfer.ChannelID) error { + return c.send(chid, datatransfer.TransferInitiated) } // Restart marks a data transfer as restarted @@ -173,63 +169,29 @@ func (c *Channels) Restart(chid datatransfer.ChannelID) error { return c.send(chid, datatransfer.Restart) } +// CompleteCleanupOnRestart tells a channel that the cleanup required for a restart has completed func (c *Channels) CompleteCleanupOnRestart(chid datatransfer.ChannelID) error { return c.send(chid, datatransfer.CompleteCleanupOnRestart) } -func (c *Channels) getQueuedIndex(chid datatransfer.ChannelID) (int64, error) { - chst, err := c.GetByID(context.TODO(), chid) - if err != nil { - return 0, err - } - return chst.QueuedCidsTotal(), nil +// DataSent records data being sent +func (c *Channels) DataSent(chid datatransfer.ChannelID, delta uint64, index datamodel.Node) error { + return c.fireProgressEvent(chid, datatransfer.DataSent, datatransfer.DataSentProgress, delta, index) } -func (c *Channels) getReceivedIndex(chid datatransfer.ChannelID) (int64, error) { - chst, err := c.GetByID(context.TODO(), chid) - if err != nil { - return 0, err - } - return chst.ReceivedCidsTotal(), nil +// DataQueued records data being queued +func (c *Channels) DataQueued(chid datatransfer.ChannelID, delta uint64, index datamodel.Node) error { + return c.fireProgressEvent(chid, datatransfer.DataQueued, datatransfer.DataQueuedProgress, delta, index) } -func (c *Channels) getSentIndex(chid datatransfer.ChannelID) (int64, error) { - chst, err := c.GetByID(context.TODO(), chid) - if err != nil { - return 0, err - } - return chst.SentCidsTotal(), nil +// DataReceived records data being received +func (c *Channels) DataReceived(chid datatransfer.ChannelID, delta uint64, index datamodel.Node) error { + return c.fireProgressEvent(chid, datatransfer.DataReceived, datatransfer.DataReceivedProgress, delta, index) } -func (c *Channels) getQueuedProgress(chid datatransfer.ChannelID) (uint64, uint64, error) { - chst, err := c.GetByID(context.TODO(), chid) - if err != nil { - return 0, 0, err - } - dataLimit := chst.DataLimit() - return dataLimit, chst.Queued(), nil -} - -func (c *Channels) getReceivedProgress(chid datatransfer.ChannelID) (uint64, 
uint64, error) { - chst, err := c.GetByID(context.TODO(), chid) - if err != nil { - return 0, 0, err - } - dataLimit := chst.DataLimit() - return dataLimit, chst.Received(), nil -} - -func (c *Channels) DataSent(chid datatransfer.ChannelID, k cid.Cid, delta uint64, index int64, unique bool) error { - return c.fireProgressEvent(chid, datatransfer.DataSent, datatransfer.DataSentProgress, delta, index, unique, c.getSentIndex, nil) -} - -func (c *Channels) DataQueued(chid datatransfer.ChannelID, k cid.Cid, delta uint64, index int64, unique bool) error { - return c.fireProgressEvent(chid, datatransfer.DataQueued, datatransfer.DataQueuedProgress, delta, index, unique, c.getQueuedIndex, c.getQueuedProgress) -} - -// Returns true if this is the first time the block has been received -func (c *Channels) DataReceived(chid datatransfer.ChannelID, k cid.Cid, delta uint64, index int64, unique bool) error { - return c.fireProgressEvent(chid, datatransfer.DataReceived, datatransfer.DataReceivedProgress, delta, index, unique, c.getReceivedIndex, c.getReceivedProgress) +// DataLimitExceeded records a data limit exceeded event +func (c *Channels) DataLimitExceeded(chid datatransfer.ChannelID) error { + return c.send(chid, datatransfer.DataLimitExceeded) } // PauseInitiator pauses the initiator of this channel @@ -329,9 +291,13 @@ func (c *Channels) ReceiveDataError(chid datatransfer.ChannelID, err error) erro return c.send(chid, datatransfer.ReceiveDataError, err) } +// SendMessageError indicates an error sending a message to the transport layer +func (c *Channels) SendMessageError(chid datatransfer.ChannelID, err error) error { + return c.send(chid, datatransfer.SendMessageError, err) +} + // SetDataLimit means a data limit has been set on this channel func (c *Channels) SetDataLimit(chid datatransfer.ChannelID, dataLimit uint64) error { - c.progressCache.setDataLimit(chid, dataLimit) return c.send(chid, datatransfer.SetDataLimit, dataLimit) } @@ -347,75 +313,19 @@ func (c *Channels) HasChannel(chid datatransfer.ChannelID) (bool, error) { // fireProgressEvent fires // - an event for queuing / sending / receiving blocks -// - a corresponding "progress" event if the block has not been seen before -// - a DataLimitExceeded event if the progress goes past the data limit -// For example, if a block is being sent for the first time, the method will -// fire both DataSent AND DataSentProgress. -// If a block is resent, the method will fire DataSent but not DataSentProgress. 
-// If a block is sent for the first time, and more data has been sent than the data limit, -// the method will fire DataSent AND DataProgress AND DataLimitExceeded AND it will return -// datatransfer.ErrPause as the error -func (c *Channels) fireProgressEvent(chid datatransfer.ChannelID, evt datatransfer.EventCode, progressEvt datatransfer.EventCode, delta uint64, index int64, unique bool, readFromOriginal readIndexFn, readProgress readProgressFn) error { +// - a corresponding "progress" event +func (c *Channels) fireProgressEvent(chid datatransfer.ChannelID, evt datatransfer.EventCode, progressEvt datatransfer.EventCode, delta uint64, index datamodel.Node) error { if err := c.checkChannelExists(chid, evt); err != nil { return err } - pause, progress, err := c.checkEvents(chid, evt, delta, index, unique, readFromOriginal, readProgress) - - if err != nil { + // Fire the progress event + if err := c.stateMachines.Send(chid, progressEvt, delta); err != nil { return err } - // Fire the progress event if there is progress - if progress { - if err := c.stateMachines.Send(chid, progressEvt, delta); err != nil { - return err - } - } - // Fire the regular event - if err := c.stateMachines.Send(chid, evt, index); err != nil { - return err - } - - // fire the pause event if we past our data limit - if pause { - // pause. Data limits only exist on the responder, so we always pause the responder - if err := c.stateMachines.Send(chid, datatransfer.DataLimitExceeded); err != nil { - return err - } - // return a pause error so the transfer knows to pause - return datatransfer.ErrPause - } - return nil -} - -func (c *Channels) checkEvents(chid datatransfer.ChannelID, evt datatransfer.EventCode, delta uint64, index int64, unique bool, readFromOriginal readIndexFn, readProgress readProgressFn) (pause bool, progress bool, err error) { - - // if this is not a unique block, no data progress is made, return - if !unique { - return - } - - // check if data progress is made - progress, err = c.blockIndexCache.updateIfGreater(evt, chid, index, readFromOriginal) - if err != nil { - return false, false, err - } - - // if no data progress, return - if !progress { - return - } - - // if we don't check data limits on this function, return - if readProgress == nil { - return - } - - // check if we're past our data limit - pause, err = c.progressCache.progress(chid, delta, readProgress) - return + return c.stateMachines.Send(chid, evt, index) } func (c *Channels) send(chid datatransfer.ChannelID, code datatransfer.EventCode, args ...interface{}) error { diff --git a/channels/channels_fsm.go b/channels/channels_fsm.go index 4a095054..51dfc6d5 100644 --- a/channels/channels_fsm.go +++ b/channels/channels_fsm.go @@ -2,6 +2,7 @@ package channels import ( logging "github.com/ipfs/go-log/v2" + "github.com/ipld/go-ipld-prime/datamodel" "github.com/filecoin-project/go-statemachine/fsm" @@ -11,16 +12,6 @@ import ( var log = logging.Logger("data-transfer") -var transferringStates = []fsm.StateKey{ - datatransfer.Requested, - datatransfer.Ongoing, - datatransfer.InitiatorPaused, - datatransfer.ResponderPaused, - datatransfer.BothPaused, - datatransfer.ResponderCompleted, - datatransfer.ResponderFinalizing, -} - // ChannelEvents describe the events that can happen in a data transfer var ChannelEvents = fsm.Events{ // Open a channel @@ -28,23 +19,32 @@ var ChannelEvents = fsm.Events{ chst.AddLog("") return nil }), + // Remote peer has accepted the Open channel request - 
fsm.Event(datatransfer.Accept).From(datatransfer.Requested).To(datatransfer.Ongoing).Action(func(chst *internal.ChannelState) error { - chst.AddLog("") - return nil - }), + fsm.Event(datatransfer.Accept). + From(datatransfer.Requested).To(datatransfer.Queued). + From(datatransfer.AwaitingAcceptance).To(datatransfer.Ongoing). + Action(func(chst *internal.ChannelState) error { + chst.AddLog("") + return nil + }), - fsm.Event(datatransfer.TransferRequestQueued).FromAny().ToJustRecord().Action(func(chst *internal.ChannelState) error { - chst.Message = "" - chst.AddLog("") - return nil - }), + // The transport has indicated it's begun sending/receiving data + fsm.Event(datatransfer.TransferInitiated). + From(datatransfer.Requested).To(datatransfer.AwaitingAcceptance). + From(datatransfer.Queued).To(datatransfer.Ongoing). + From(datatransfer.Ongoing).ToJustRecord(). + Action(func(chst *internal.ChannelState) error { + chst.AddLog("") + return nil + }), fsm.Event(datatransfer.Restart).FromAny().ToJustRecord().Action(func(chst *internal.ChannelState) error { chst.Message = "" chst.AddLog("") return nil }), + fsm.Event(datatransfer.Cancel).FromAny().To(datatransfer.Cancelling).Action(func(chst *internal.ChannelState) error { chst.AddLog("") return nil @@ -59,87 +59,90 @@ var ChannelEvents = fsm.Events{ return nil }), - fsm.Event(datatransfer.DataReceived).FromAny().ToNoChange(). - Action(func(chst *internal.ChannelState, rcvdBlocksTotal int64) error { - if rcvdBlocksTotal > chst.ReceivedBlocksTotal { - chst.ReceivedBlocksTotal = rcvdBlocksTotal - } + fsm.Event(datatransfer.DataReceived).FromMany(datatransfer.TransferringStates.AsFSMStates()...).ToNoChange(). + Action(func(chst *internal.ChannelState, receivedIndex datamodel.Node) error { + chst.ReceivedIndex = internal.CborGenCompatibleNode{Node: receivedIndex} chst.AddLog("") return nil }), - fsm.Event(datatransfer.DataReceivedProgress).FromMany(transferringStates...).ToNoChange(). + fsm.Event(datatransfer.DataReceivedProgress).FromMany(datatransfer.TransferringStates.AsFSMStates()...).ToNoChange(). Action(func(chst *internal.ChannelState, delta uint64) error { chst.Received += delta chst.AddLog("received data") return nil }), - fsm.Event(datatransfer.DataSent). - FromMany(transferringStates...).ToNoChange(). - From(datatransfer.TransferFinished).ToNoChange(). - Action(func(chst *internal.ChannelState, sentBlocksTotal int64) error { - if sentBlocksTotal > chst.SentBlocksTotal { - chst.SentBlocksTotal = sentBlocksTotal - } + fsm.Event(datatransfer.DataSent).FromMany(datatransfer.TransferringStates.AsFSMStates()...).ToNoChange(). + Action(func(chst *internal.ChannelState, sentIndex datamodel.Node) error { + chst.SentIndex = internal.CborGenCompatibleNode{Node: sentIndex} chst.AddLog("") return nil }), - fsm.Event(datatransfer.DataSentProgress).FromMany(transferringStates...).ToNoChange(). + fsm.Event(datatransfer.DataSentProgress).FromMany(datatransfer.TransferringStates.AsFSMStates()...).ToNoChange(). Action(func(chst *internal.ChannelState, delta uint64) error { chst.Sent += delta chst.AddLog("sending data") return nil }), - fsm.Event(datatransfer.DataQueued). - FromMany(transferringStates...).ToNoChange(). - From(datatransfer.TransferFinished).ToNoChange(). - Action(func(chst *internal.ChannelState, queuedBlocksTotal int64) error { - if queuedBlocksTotal > chst.QueuedBlocksTotal { - chst.QueuedBlocksTotal = queuedBlocksTotal - } + fsm.Event(datatransfer.DataQueued).FromMany(datatransfer.TransferringStates.AsFSMStates()...).ToNoChange(). 
+ Action(func(chst *internal.ChannelState, queuedIndex datamodel.Node) error { + chst.QueuedIndex = internal.CborGenCompatibleNode{Node: queuedIndex} chst.AddLog("") return nil }), - fsm.Event(datatransfer.DataQueuedProgress).FromMany(transferringStates...).ToNoChange(). + fsm.Event(datatransfer.DataQueuedProgress).FromMany(datatransfer.TransferringStates.AsFSMStates()...).ToNoChange(). Action(func(chst *internal.ChannelState, delta uint64) error { chst.Queued += delta chst.AddLog("") return nil }), + fsm.Event(datatransfer.SetDataLimit).FromAny().ToJustRecord(). Action(func(chst *internal.ChannelState, dataLimit uint64) error { chst.DataLimit = dataLimit chst.AddLog("") return nil }), + fsm.Event(datatransfer.SetRequiresFinalization).FromAny().ToJustRecord(). Action(func(chst *internal.ChannelState, RequiresFinalization bool) error { chst.RequiresFinalization = RequiresFinalization chst.AddLog("") return nil }), + fsm.Event(datatransfer.Disconnected).FromAny().ToNoChange().Action(func(chst *internal.ChannelState, err error) error { chst.Message = err.Error() chst.AddLog("data transfer disconnected: %s", chst.Message) return nil }), + fsm.Event(datatransfer.SendDataError).FromAny().ToNoChange().Action(func(chst *internal.ChannelState, err error) error { chst.Message = err.Error() chst.AddLog("data transfer send error: %s", chst.Message) return nil }), + fsm.Event(datatransfer.ReceiveDataError).FromAny().ToNoChange().Action(func(chst *internal.ChannelState, err error) error { chst.Message = err.Error() chst.AddLog("data transfer receive error: %s", chst.Message) return nil }), + + fsm.Event(datatransfer.SendMessageError).FromAny().ToNoChange().Action(func(chst *internal.ChannelState, err error) error { + chst.Message = err.Error() + chst.AddLog("data transfer errored sending message: %s", chst.Message) + return nil + }), + fsm.Event(datatransfer.RequestCancelled).FromAny().ToNoChange().Action(func(chst *internal.ChannelState, err error) error { chst.Message = err.Error() chst.AddLog("data transfer request cancelled: %s", chst.Message) return nil }), + fsm.Event(datatransfer.Error).FromAny().To(datatransfer.Failing).Action(func(chst *internal.ChannelState, err error) error { chst.Message = err.Error() chst.AddLog("data transfer erred: %s", chst.Message) @@ -152,6 +155,7 @@ var ChannelEvents = fsm.Events{ chst.AddLog("got new voucher") return nil }), + fsm.Event(datatransfer.NewVoucherResult).FromAny().ToNoChange(). Action(func(chst *internal.ChannelState, voucherResult datatransfer.TypedVoucher) error { chst.VoucherResults = append(chst.VoucherResults, @@ -160,46 +164,54 @@ var ChannelEvents = fsm.Events{ return nil }), + // TODO: There are four states from which the request can be "paused": request, queued, awaiting acceptance + // and ongoing. There are four states of being + // paused (no pause, initiator pause, responder pause, both paused). Until the state machine software + // supports orthogonal regions (https://en.wikipedia.org/wiki/UML_state_machine#Orthogonal_regions) + // we end up with a cartesian product of states and as you can see, fairly complicated state transitions. + // Previously, we had dealt with this by moving directly to the Ongoing state upon return from pause but this + // seems less than ideal. We need some kind of support for pausing being an independent aspect of state + // Possibly we should just remove whether a state is paused from the state entirely. fsm.Event(datatransfer.PauseInitiator). 
- FromMany(datatransfer.Requested, datatransfer.Ongoing).To(datatransfer.InitiatorPaused). - From(datatransfer.ResponderPaused).To(datatransfer.BothPaused). - FromAny().ToJustRecord().Action(func(chst *internal.ChannelState) error { - chst.AddLog("") - return nil - }), + FromMany(datatransfer.Ongoing, datatransfer.Requested, datatransfer.Queued, datatransfer.AwaitingAcceptance).ToJustRecord(). + Action(func(chst *internal.ChannelState) error { + chst.InitiatorPaused = true + chst.AddLog("") + return nil + }), fsm.Event(datatransfer.PauseResponder). - FromMany(datatransfer.Requested, datatransfer.Ongoing).To(datatransfer.ResponderPaused). - From(datatransfer.InitiatorPaused).To(datatransfer.BothPaused). - FromAny().ToJustRecord().Action(func(chst *internal.ChannelState) error { - chst.AddLog("") - return nil - }), + FromMany(datatransfer.Ongoing, datatransfer.Requested, datatransfer.Queued, datatransfer.AwaitingAcceptance, datatransfer.TransferFinished).ToJustRecord(). + Action(func(chst *internal.ChannelState) error { + chst.ResponderPaused = true + chst.AddLog("") + return nil + }), fsm.Event(datatransfer.DataLimitExceeded). - FromMany(datatransfer.Requested, datatransfer.Ongoing).To(datatransfer.ResponderPaused). - From(datatransfer.InitiatorPaused).To(datatransfer.BothPaused). - FromAny().ToJustRecord().Action(func(chst *internal.ChannelState) error { - chst.AddLog("") - return nil - }), + FromMany(datatransfer.Ongoing, datatransfer.Requested, datatransfer.Queued, datatransfer.AwaitingAcceptance, datatransfer.ResponderCompleted, datatransfer.ResponderFinalizing).ToJustRecord(). + Action(func(chst *internal.ChannelState) error { + chst.ResponderPaused = true + chst.AddLog("") + return nil + }), fsm.Event(datatransfer.ResumeInitiator). - From(datatransfer.InitiatorPaused).To(datatransfer.Ongoing). - From(datatransfer.BothPaused).To(datatransfer.ResponderPaused). - FromAny().ToJustRecord().Action(func(chst *internal.ChannelState) error { - chst.AddLog("") - return nil - }), + FromMany(datatransfer.Ongoing, datatransfer.Requested, datatransfer.Queued, datatransfer.AwaitingAcceptance, datatransfer.ResponderCompleted, datatransfer.ResponderFinalizing).ToJustRecord(). + Action(func(chst *internal.ChannelState) error { + chst.InitiatorPaused = false + chst.AddLog("") + return nil + }), fsm.Event(datatransfer.ResumeResponder). - From(datatransfer.ResponderPaused).To(datatransfer.Ongoing). - From(datatransfer.BothPaused).To(datatransfer.InitiatorPaused). + FromMany(datatransfer.Ongoing, datatransfer.Requested, datatransfer.Queued, datatransfer.AwaitingAcceptance, datatransfer.TransferFinished).ToJustRecord(). From(datatransfer.Finalizing).To(datatransfer.Completing). - FromAny().ToJustRecord().Action(func(chst *internal.ChannelState) error { - chst.AddLog("") - return nil - }), + Action(func(chst *internal.ChannelState) error { + chst.ResponderPaused = false + chst.AddLog("") + return nil + }), // The transfer has finished on the local node - all data was sent / received fsm.Event(datatransfer.FinishTransfer). @@ -207,10 +219,10 @@ var ChannelEvents = fsm.Events{ FromMany(datatransfer.Failing, datatransfer.Cancelling).ToJustRecord(). From(datatransfer.ResponderCompleted).To(datatransfer.Completing). From(datatransfer.ResponderFinalizing).To(datatransfer.ResponderFinalizingTransferFinished). 
- // If we are in the requested state, it means the other party simply never responded to our + // If we are in the AwaitingAcceptance state, it means the other party simply never responded to // our data transfer, or we never actually contacted them. In any case, it's safe to skip // the finalization process and complete the transfer - From(datatransfer.Requested).To(datatransfer.Completing). + From(datatransfer.AwaitingAcceptance).To(datatransfer.Completing). Action(func(chst *internal.ChannelState) error { chst.AddLog("") return nil @@ -229,9 +241,7 @@ var ChannelEvents = fsm.Events{ fsm.Event(datatransfer.ResponderCompletes). FromAny().To(datatransfer.ResponderCompleted). FromMany(datatransfer.Failing, datatransfer.Cancelling).ToJustRecord(). - From(datatransfer.ResponderPaused).To(datatransfer.ResponderFinalizing). From(datatransfer.TransferFinished).To(datatransfer.Completing). - From(datatransfer.ResponderFinalizing).To(datatransfer.ResponderCompleted). From(datatransfer.ResponderFinalizingTransferFinished).To(datatransfer.Completing).Action(func(chst *internal.ChannelState) error { chst.AddLog("") return nil diff --git a/channels/channels_test.go b/channels/channels_test.go index 23c16ef8..3e26af1d 100644 --- a/channels/channels_test.go +++ b/channels/channels_test.go @@ -1,11 +1,14 @@ package channels_test import ( + "bytes" "context" "errors" + "math/rand" "testing" "time" + "github.com/ipfs/go-cid" "github.com/ipfs/go-datastore" dss "github.com/ipfs/go-datastore/sync" basicnode "github.com/ipld/go-ipld-prime/node/basic" @@ -14,8 +17,13 @@ import ( "github.com/stretchr/testify/require" "golang.org/x/xerrors" + versioning "github.com/filecoin-project/go-ds-versioning/pkg" + versionedds "github.com/filecoin-project/go-ds-versioning/pkg/datastore" + datatransfer "github.com/filecoin-project/go-data-transfer/v2" "github.com/filecoin-project/go-data-transfer/v2/channels" + "github.com/filecoin-project/go-data-transfer/v2/channels/internal" + "github.com/filecoin-project/go-data-transfer/v2/channels/internal/migrations" "github.com/filecoin-project/go-data-transfer/v2/testutil" ) @@ -104,55 +112,23 @@ func TestChannels(t *testing.T) { err = channelList.Accept(datatransfer.ChannelID{Initiator: peers[0], Responder: peers[1], ID: tid1}) require.NoError(t, err) state = checkEvent(ctx, t, received, datatransfer.Accept) - require.Equal(t, state.Status(), datatransfer.Ongoing) + require.Equal(t, state.Status(), datatransfer.Queued) err = channelList.Accept(datatransfer.ChannelID{Initiator: peers[1], Responder: peers[0], ID: tid1}) require.True(t, errors.Is(err, datatransfer.ErrChannelNotFound)) }) - t.Run("transfer queued", func(t *testing.T) { + t.Run("transfer initiated", func(t *testing.T) { state, err := channelList.GetByID(ctx, datatransfer.ChannelID{Initiator: peers[0], Responder: peers[1], ID: tid1}) require.NoError(t, err) - require.Equal(t, state.Status(), datatransfer.Ongoing) + require.Equal(t, state.Status(), datatransfer.Queued) - err = channelList.TransferRequestQueued(datatransfer.ChannelID{Initiator: peers[0], Responder: peers[1], ID: tid1}) + err = channelList.TransferInitiated(datatransfer.ChannelID{Initiator: peers[0], Responder: peers[1], ID: tid1}) require.NoError(t, err) - state = checkEvent(ctx, t, received, datatransfer.TransferRequestQueued) + state = checkEvent(ctx, t, received, datatransfer.TransferInitiated) require.Equal(t, state.Status(), datatransfer.Ongoing) }) - t.Run("datasent/queued when transfer is already finished", func(t *testing.T) { - ds := 
dss.MutexWrap(datastore.NewMapDatastore()) - - channelList, err := channels.New(ds, notifier, &fakeEnv{}, peers[0]) - require.NoError(t, err) - err = channelList.Start(ctx) - require.NoError(t, err) - - chid, _, err := channelList.CreateNew(peers[0], tid1, cids[0], selector, fv1, peers[0], peers[0], peers[1]) - require.NoError(t, err) - checkEvent(ctx, t, received, datatransfer.Open) - require.NoError(t, channelList.Accept(chid)) - checkEvent(ctx, t, received, datatransfer.Accept) - - // move the channel to `TransferFinished` state. - require.NoError(t, channelList.FinishTransfer(chid)) - state := checkEvent(ctx, t, received, datatransfer.FinishTransfer) - require.Equal(t, datatransfer.TransferFinished, state.Status()) - - // send a data-sent event and ensure it's a no-op - err = channelList.DataSent(chid, cids[1], 1, 1, true) - require.NoError(t, err) - state = checkEvent(ctx, t, received, datatransfer.DataSent) - require.Equal(t, datatransfer.TransferFinished, state.Status()) - - // send a data-queued event and ensure it's a no-op. - err = channelList.DataQueued(chid, cids[1], 1, 1, true) - require.NoError(t, err) - state = checkEvent(ctx, t, received, datatransfer.DataQueued) - require.Equal(t, datatransfer.TransferFinished, state.Status()) - }) - t.Run("updating send/receive values", func(t *testing.T) { ds := dss.MutexWrap(datastore.NewMapDatastore()) @@ -168,51 +144,36 @@ func TestChannels(t *testing.T) { require.Equal(t, uint64(0), state.Received()) require.Equal(t, uint64(0), state.Sent()) - err = channelList.DataReceived(datatransfer.ChannelID{Initiator: peers[0], Responder: peers[1], ID: tid1}, cids[0], 50, 1, true) + err = channelList.TransferInitiated(datatransfer.ChannelID{Initiator: peers[0], Responder: peers[1], ID: tid1}) + require.NoError(t, err) + _ = checkEvent(ctx, t, received, datatransfer.TransferInitiated) + + err = channelList.DataReceived(datatransfer.ChannelID{Initiator: peers[0], Responder: peers[1], ID: tid1}, 50, basicnode.NewInt(1)) require.NoError(t, err) _ = checkEvent(ctx, t, received, datatransfer.DataReceivedProgress) state = checkEvent(ctx, t, received, datatransfer.DataReceived) require.Equal(t, uint64(50), state.Received()) require.Equal(t, uint64(0), state.Sent()) - err = channelList.DataSent(datatransfer.ChannelID{Initiator: peers[0], Responder: peers[1], ID: tid1}, cids[1], 100, 1, true) + err = channelList.DataSent(datatransfer.ChannelID{Initiator: peers[0], Responder: peers[1], ID: tid1}, 100, basicnode.NewInt(1)) require.NoError(t, err) _ = checkEvent(ctx, t, received, datatransfer.DataSentProgress) state = checkEvent(ctx, t, received, datatransfer.DataSent) require.Equal(t, uint64(50), state.Received()) require.Equal(t, uint64(100), state.Sent()) - // send block again has no effect - err = channelList.DataSent(datatransfer.ChannelID{Initiator: peers[0], Responder: peers[1], ID: tid1}, cids[1], 100, 1, true) - require.NoError(t, err) - state = checkEvent(ctx, t, received, datatransfer.DataSent) - require.Equal(t, uint64(50), state.Received()) - require.Equal(t, uint64(100), state.Sent()) - // errors if channel does not exist - err = channelList.DataReceived(datatransfer.ChannelID{Initiator: peers[1], Responder: peers[0], ID: tid1}, cids[1], 200, 2, true) + err = channelList.DataReceived(datatransfer.ChannelID{Initiator: peers[1], Responder: peers[0], ID: tid1}, 200, basicnode.NewInt(2)) require.True(t, errors.Is(err, datatransfer.ErrChannelNotFound)) - err = channelList.DataSent(datatransfer.ChannelID{Initiator: peers[1], Responder: peers[0], ID: 
tid1}, cids[1], 200, 2, true) + err = channelList.DataSent(datatransfer.ChannelID{Initiator: peers[1], Responder: peers[0], ID: tid1}, 200, basicnode.NewInt(2)) require.True(t, errors.Is(err, datatransfer.ErrChannelNotFound)) - err = channelList.DataReceived(datatransfer.ChannelID{Initiator: peers[0], Responder: peers[1], ID: tid1}, cids[1], 50, 2, true) + err = channelList.DataReceived(datatransfer.ChannelID{Initiator: peers[0], Responder: peers[1], ID: tid1}, 50, basicnode.NewInt(2)) require.NoError(t, err) _ = checkEvent(ctx, t, received, datatransfer.DataReceivedProgress) state = checkEvent(ctx, t, received, datatransfer.DataReceived) require.Equal(t, uint64(100), state.Received()) require.Equal(t, uint64(100), state.Sent()) - - err = channelList.DataSent(datatransfer.ChannelID{Initiator: peers[0], Responder: peers[1], ID: tid1}, cids[1], 25, 2, false) - require.NoError(t, err) - state = checkEvent(ctx, t, received, datatransfer.DataSent) - require.Equal(t, uint64(100), state.Received()) - require.Equal(t, uint64(100), state.Sent()) - - err = channelList.DataReceived(datatransfer.ChannelID{Initiator: peers[0], Responder: peers[1], ID: tid1}, cids[0], 50, 3, false) - require.NoError(t, err) - state = checkEvent(ctx, t, received, datatransfer.DataReceived) - require.Equal(t, uint64(100), state.Received()) - require.Equal(t, uint64(100), state.Sent()) }) t.Run("data limit", func(t *testing.T) { @@ -227,47 +188,34 @@ func TestChannels(t *testing.T) { require.NoError(t, err) state := checkEvent(ctx, t, received, datatransfer.Open) - err = channelList.DataQueued(datatransfer.ChannelID{Initiator: peers[1], Responder: peers[0], ID: tid1}, cids[0], 300, 1, true) - require.NoError(t, err) - _ = checkEvent(ctx, t, received, datatransfer.DataQueuedProgress) - state = checkEvent(ctx, t, received, datatransfer.DataQueued) - require.Equal(t, uint64(300), state.Queued()) - err = channelList.SetDataLimit(datatransfer.ChannelID{Initiator: peers[1], Responder: peers[0], ID: tid1}, 400) require.NoError(t, err) state = checkEvent(ctx, t, received, datatransfer.SetDataLimit) require.Equal(t, state.DataLimit(), uint64(400)) - // send block again has no effect - err = channelList.DataQueued(datatransfer.ChannelID{Initiator: peers[1], Responder: peers[0], ID: tid1}, cids[0], 300, 1, true) + err = channelList.DataLimitExceeded(datatransfer.ChannelID{Initiator: peers[1], Responder: peers[0], ID: tid1}) require.NoError(t, err) - state = checkEvent(ctx, t, received, datatransfer.DataQueued) - require.Equal(t, uint64(300), state.Queued()) - - err = channelList.DataQueued(datatransfer.ChannelID{Initiator: peers[1], Responder: peers[0], ID: tid1}, cids[1], 200, 2, true) - require.EqualError(t, err, datatransfer.ErrPause.Error()) - _ = checkEvent(ctx, t, received, datatransfer.DataQueuedProgress) - _ = checkEvent(ctx, t, received, datatransfer.DataQueued) state = checkEvent(ctx, t, received, datatransfer.DataLimitExceeded) - require.Equal(t, uint64(500), state.Queued()) + require.True(t, state.ResponderPaused()) err = channelList.SetDataLimit(datatransfer.ChannelID{Initiator: peers[1], Responder: peers[0], ID: tid1}, 700) require.NoError(t, err) state = checkEvent(ctx, t, received, datatransfer.SetDataLimit) require.Equal(t, state.DataLimit(), uint64(700)) - err = channelList.DataQueued(datatransfer.ChannelID{Initiator: peers[1], Responder: peers[0], ID: tid1}, cids[2], 150, 3, true) - require.NoError(t, err) - _ = checkEvent(ctx, t, received, datatransfer.DataQueuedProgress) - state = checkEvent(ctx, t, received, 
datatransfer.DataQueued) - require.Equal(t, uint64(650), state.Queued()) + err = channelList.ResumeResponder(datatransfer.ChannelID{Initiator: peers[1], Responder: peers[0], ID: tid1}) + require.NoError(t, err) + state = checkEvent(ctx, t, received, datatransfer.ResumeResponder) + require.False(t, state.ResponderPaused()) + + err = channelList.PauseInitiator(datatransfer.ChannelID{Initiator: peers[1], Responder: peers[0], ID: tid1}) + require.NoError(t, err) + state = checkEvent(ctx, t, received, datatransfer.PauseInitiator) + require.True(t, state.InitiatorPaused()) - err = channelList.DataQueued(datatransfer.ChannelID{Initiator: peers[1], Responder: peers[0], ID: tid1}, cids[3], 200, 4, true) - require.EqualError(t, err, datatransfer.ErrPause.Error()) - _ = checkEvent(ctx, t, received, datatransfer.DataQueuedProgress) - _ = checkEvent(ctx, t, received, datatransfer.DataQueued) + err = channelList.DataLimitExceeded(datatransfer.ChannelID{Initiator: peers[1], Responder: peers[0], ID: tid1}) + require.NoError(t, err) state = checkEvent(ctx, t, received, datatransfer.DataLimitExceeded) - require.Equal(t, uint64(850), state.Queued()) + require.True(t, state.BothPaused()) + }) t.Run("pause/resume", func(t *testing.T) { @@ -278,17 +226,19 @@ func TestChannels(t *testing.T) { err = channelList.PauseInitiator(datatransfer.ChannelID{Initiator: peers[0], Responder: peers[1], ID: tid1}) require.NoError(t, err) state = checkEvent(ctx, t, received, datatransfer.PauseInitiator) - require.Equal(t, datatransfer.InitiatorPaused, state.Status()) + require.True(t, state.InitiatorPaused()) + require.False(t, state.BothPaused()) err = channelList.PauseResponder(datatransfer.ChannelID{Initiator: peers[0], Responder: peers[1], ID: tid1}) require.NoError(t, err) state = checkEvent(ctx, t, received, datatransfer.PauseResponder) - require.Equal(t, datatransfer.BothPaused, state.Status()) + require.True(t, state.BothPaused()) err = channelList.ResumeInitiator(datatransfer.ChannelID{Initiator: peers[0], Responder: peers[1], ID: tid1}) require.NoError(t, err) state = checkEvent(ctx, t, received, datatransfer.ResumeInitiator) - require.Equal(t, datatransfer.ResponderPaused, state.Status()) + require.True(t, state.ResponderPaused()) + require.False(t, state.BothPaused()) err = channelList.ResumeResponder(datatransfer.ChannelID{Initiator: peers[0], Responder: peers[1], ID: tid1}) require.NoError(t, err) @@ -450,6 +400,153 @@ func TestIsChannelCleaningUp(t *testing.T) { require.False(t, channels.IsChannelCleaningUp(datatransfer.Cancelled)) } +func TestMigrations(t *testing.T) { + ctx := context.Background() + ctx, cancel := context.WithTimeout(ctx, 2*time.Second) + defer cancel() + + ds := dss.MutexWrap(datastore.NewMapDatastore()) + received := make(chan event) + notifier := func(evt datatransfer.Event, chst datatransfer.ChannelState) { + received <- event{evt, chst} + } + numChannels := 5 + transferIDs := make([]datatransfer.TransferID, numChannels) + initiators := make([]peer.ID, numChannels) + responders := make([]peer.ID, numChannels) + baseCids := make([]cid.Cid, numChannels) + + totalSizes := make([]uint64, numChannels) + sents := make([]uint64, numChannels) + receiveds := make([]uint64, numChannels) + + messages := make([]string, numChannels) + vouchers := make([]datatransfer.TypedVoucher, numChannels) + voucherResults := make([]datatransfer.TypedVoucher, numChannels) + sentIndex := make([]int64, numChannels) + receivedIndex := make([]int64, numChannels) + queuedIndex := make([]int64, numChannels) + allSelector := 
builder.NewSelectorSpecBuilder(basicnode.Prototype.Any).Matcher().Node() + selfPeer := testutil.GeneratePeers(1)[0] + + list, err := migrations.GetChannelStateMigrations(selfPeer) + require.NoError(t, err) + vds, up := versionedds.NewVersionedDatastore(ds, list, versioning.VersionKey("2")) + require.NoError(t, up(ctx)) + + initialStatuses := []datatransfer.Status{ + datatransfer.Requested, + datatransfer.InitiatorPaused, + datatransfer.ResponderPaused, + datatransfer.BothPaused, + datatransfer.Ongoing, + } + for i := 0; i < numChannels; i++ { + transferIDs[i] = datatransfer.TransferID(rand.Uint64()) + initiators[i] = testutil.GeneratePeers(1)[0] + responders[i] = testutil.GeneratePeers(1)[0] + baseCids[i] = testutil.GenerateCids(1)[0] + totalSizes[i] = rand.Uint64() + sents[i] = rand.Uint64() + receiveds[i] = rand.Uint64() + messages[i] = string(testutil.RandomBytes(20)) + vouchers[i] = testutil.NewTestTypedVoucher() + voucherResults[i] = testutil.NewTestTypedVoucher() + sentIndex[i] = rand.Int63() + receivedIndex[i] = rand.Int63() + queuedIndex[i] = rand.Int63() + channel := migrations.ChannelStateV2{ + TransferID: transferIDs[i], + Initiator: initiators[i], + Responder: responders[i], + BaseCid: baseCids[i], + Selector: internal.CborGenCompatibleNode{ + Node: allSelector, + }, + Sender: initiators[i], + Recipient: responders[i], + TotalSize: totalSizes[i], + Status: initialStatuses[i], + Sent: sents[i], + Received: receiveds[i], + Message: messages[i], + Vouchers: []internal.EncodedVoucher{ + { + Type: vouchers[i].Type, + Voucher: internal.CborGenCompatibleNode{ + Node: vouchers[i].Voucher, + }, + }, + }, + VoucherResults: []internal.EncodedVoucherResult{ + { + Type: voucherResults[i].Type, + VoucherResult: internal.CborGenCompatibleNode{ + Node: voucherResults[i].Voucher, + }, + }, + }, + SentBlocksTotal: sentIndex[i], + ReceivedBlocksTotal: receivedIndex[i], + QueuedBlocksTotal: queuedIndex[i], + SelfPeer: selfPeer, + } + buf := new(bytes.Buffer) + err = channel.MarshalCBOR(buf) + require.NoError(t, err) + err = vds.Put(ctx, datastore.NewKey(datatransfer.ChannelID{ + Initiator: initiators[i], + Responder: responders[i], + ID: transferIDs[i], + }.String()), buf.Bytes()) + require.NoError(t, err) + } + + channelList, err := channels.New(ds, notifier, &fakeEnv{}, selfPeer) + require.NoError(t, err) + err = channelList.Start(ctx) + require.NoError(t, err) + + expectedStatuses := []datatransfer.Status{ + datatransfer.Requested, + datatransfer.Ongoing, + datatransfer.Ongoing, + datatransfer.Ongoing, + datatransfer.Ongoing, + } + + expectedInitiatorPaused := []bool{false, true, false, true, false} + expectedResponderPaused := []bool{false, false, true, true, false} + for i := 0; i < numChannels; i++ { + + channel, err := channelList.GetByID(ctx, datatransfer.ChannelID{ + Initiator: initiators[i], + Responder: responders[i], + ID: transferIDs[i], + }) + require.NoError(t, err) + require.Equal(t, selfPeer, channel.SelfPeer()) + require.Equal(t, transferIDs[i], channel.TransferID()) + require.Equal(t, baseCids[i], channel.BaseCID()) + require.Equal(t, allSelector, channel.Selector()) + require.Equal(t, initiators[i], channel.Sender()) + require.Equal(t, responders[i], channel.Recipient()) + require.Equal(t, totalSizes[i], channel.TotalSize()) + require.Equal(t, sents[i], channel.Sent()) + require.Equal(t, receiveds[i], channel.Received()) + require.Equal(t, messages[i], channel.Message()) + require.Equal(t, vouchers[i], channel.LastVoucher()) + require.Equal(t, voucherResults[i], 
channel.LastVoucherResult()) + require.Equal(t, expectedStatuses[i], channel.Status()) + require.Equal(t, expectedInitiatorPaused[i], channel.InitiatorPaused()) + require.Equal(t, expectedResponderPaused[i], channel.ResponderPaused()) + require.Equal(t, basicnode.NewInt(sentIndex[i]), channel.SentIndex()) + require.Equal(t, basicnode.NewInt(receivedIndex[i]), channel.ReceivedIndex()) + require.Equal(t, basicnode.NewInt(queuedIndex[i]), channel.QueuedIndex()) + + } +} + type event struct { event datatransfer.Event state datatransfer.ChannelState diff --git a/channels/internal/internalchannel.go b/channels/internal/internalchannel.go index 4a9fabd8..4209edaa 100644 --- a/channels/internal/internalchannel.go +++ b/channels/internal/internalchannel.go @@ -104,19 +104,23 @@ type ChannelState struct { VoucherResults []EncodedVoucherResult - // Number of blocks that have been received, including blocks that are - // present in more than one place in the DAG - ReceivedBlocksTotal int64 + // ReceivedIndex is the index, a transport specific identifier for "where" + // we are in receiving data for a transfer + ReceivedIndex CborGenCompatibleNode - // Number of blocks that have been queued, including blocks that are - // present in more than one place in the DAG - QueuedBlocksTotal int64 + // QueuedIndex is the index, a transport specific identifier for "where" + // we are in queueing data for a transfer + QueuedIndex CborGenCompatibleNode - // Number of blocks that have been sent, including blocks that are - // present in more than one place in the DAG - SentBlocksTotal int64 + // SentIndex is the index, a transport specific identifier for "where" + // we are in sending data for a transfer + SentIndex CborGenCompatibleNode // DataLimit is the maximum data that can be transferred on this channel before // revalidation. 0 indicates no limit. DataLimit uint64 // RequiresFinalization indicates at the end of the transfer, the channel should // be left open for a final settlement RequiresFinalization bool + // ResponderPaused indicates whether the responder is in a paused state + ResponderPaused bool + // InitiatorPaused indicates whether the initiator is in a paused state + InitiatorPaused bool // Stages traces the execution of a data transfer. // // EXPERIMENTAL; subject to change. 
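Since the persisted indices above are plain datamodel.Node values, here is a short sketch (assumed caller code, not part of this diff) of how a transport hook might report progress through the reworked Channels API: the old (cid, index, unique) arguments are gone, and the caller passes a byte delta plus whatever node it wants stored as its index. The onBlockQueued helper and example package are hypothetical.

package example // hypothetical transport-side hook

import (
	basicnode "github.com/ipld/go-ipld-prime/node/basic"

	datatransfer "github.com/filecoin-project/go-data-transfer/v2"
	"github.com/filecoin-project/go-data-transfer/v2/channels"
)

// onBlockQueued mirrors the DataQueued signature introduced in channels.go
// above: size feeds the DataQueuedProgress event, and the index node is
// persisted verbatim as QueuedIndex.
func onBlockQueued(ch *channels.Channels, chid datatransfer.ChannelID, size uint64, blockIndex int64) error {
	return ch.DataQueued(chid, size, basicnode.NewInt(blockIndex))
}

Note that deduplication and data-limit enforcement no longer happen on this path: the block-index and progress caches were deleted outright, and a transport that overruns a limit is expected to fire DataLimitExceeded itself.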
diff --git a/channels/internal/internalchannel_cbor_gen.go b/channels/internal/internalchannel_cbor_gen.go index 58f43c1b..c7b1e958 100644 --- a/channels/internal/internalchannel_cbor_gen.go +++ b/channels/internal/internalchannel_cbor_gen.go @@ -23,7 +23,7 @@ func (t *ChannelState) MarshalCBOR(w io.Writer) error { _, err := w.Write(cbg.CborNull) return err } - if _, err := w.Write([]byte{182}); err != nil { + if _, err := w.Write([]byte{184, 24}); err != nil { return err } @@ -345,70 +345,52 @@ func (t *ChannelState) MarshalCBOR(w io.Writer) error { } } - // t.ReceivedBlocksTotal (int64) (int64) - if len("ReceivedBlocksTotal") > cbg.MaxLength { - return xerrors.Errorf("Value in field \"ReceivedBlocksTotal\" was too long") + // t.ReceivedIndex (internal.CborGenCompatibleNode) (struct) + if len("ReceivedIndex") > cbg.MaxLength { + return xerrors.Errorf("Value in field \"ReceivedIndex\" was too long") } - if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("ReceivedBlocksTotal"))); err != nil { + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("ReceivedIndex"))); err != nil { return err } - if _, err := io.WriteString(w, string("ReceivedBlocksTotal")); err != nil { + if _, err := io.WriteString(w, string("ReceivedIndex")); err != nil { return err } - if t.ReceivedBlocksTotal >= 0 { - if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajUnsignedInt, uint64(t.ReceivedBlocksTotal)); err != nil { - return err - } - } else { - if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajNegativeInt, uint64(-t.ReceivedBlocksTotal-1)); err != nil { - return err - } + if err := t.ReceivedIndex.MarshalCBOR(w); err != nil { + return err } - // t.QueuedBlocksTotal (int64) (int64) - if len("QueuedBlocksTotal") > cbg.MaxLength { - return xerrors.Errorf("Value in field \"QueuedBlocksTotal\" was too long") + // t.QueuedIndex (internal.CborGenCompatibleNode) (struct) + if len("QueuedIndex") > cbg.MaxLength { + return xerrors.Errorf("Value in field \"QueuedIndex\" was too long") } - if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("QueuedBlocksTotal"))); err != nil { + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("QueuedIndex"))); err != nil { return err } - if _, err := io.WriteString(w, string("QueuedBlocksTotal")); err != nil { + if _, err := io.WriteString(w, string("QueuedIndex")); err != nil { return err } - if t.QueuedBlocksTotal >= 0 { - if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajUnsignedInt, uint64(t.QueuedBlocksTotal)); err != nil { - return err - } - } else { - if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajNegativeInt, uint64(-t.QueuedBlocksTotal-1)); err != nil { - return err - } + if err := t.QueuedIndex.MarshalCBOR(w); err != nil { + return err } - // t.SentBlocksTotal (int64) (int64) - if len("SentBlocksTotal") > cbg.MaxLength { - return xerrors.Errorf("Value in field \"SentBlocksTotal\" was too long") + // t.SentIndex (internal.CborGenCompatibleNode) (struct) + if len("SentIndex") > cbg.MaxLength { + return xerrors.Errorf("Value in field \"SentIndex\" was too long") } - if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("SentBlocksTotal"))); err != nil { + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("SentIndex"))); err != nil { return err } - if _, err := io.WriteString(w, string("SentBlocksTotal")); err != nil { + if _, err := io.WriteString(w, string("SentIndex")); err != nil { 
return err } - if t.SentBlocksTotal >= 0 { - if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajUnsignedInt, uint64(t.SentBlocksTotal)); err != nil { - return err - } - } else { - if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajNegativeInt, uint64(-t.SentBlocksTotal-1)); err != nil { - return err - } + if err := t.SentIndex.MarshalCBOR(w); err != nil { + return err } // t.DataLimit (uint64) (uint64) @@ -443,6 +425,38 @@ func (t *ChannelState) MarshalCBOR(w io.Writer) error { return err } + // t.ResponderPaused (bool) (bool) + if len("ResponderPaused") > cbg.MaxLength { + return xerrors.Errorf("Value in field \"ResponderPaused\" was too long") + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("ResponderPaused"))); err != nil { + return err + } + if _, err := io.WriteString(w, string("ResponderPaused")); err != nil { + return err + } + + if err := cbg.WriteBool(w, t.ResponderPaused); err != nil { + return err + } + + // t.InitiatorPaused (bool) (bool) + if len("InitiatorPaused") > cbg.MaxLength { + return xerrors.Errorf("Value in field \"InitiatorPaused\" was too long") + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("InitiatorPaused"))); err != nil { + return err + } + if _, err := io.WriteString(w, string("InitiatorPaused")); err != nil { + return err + } + + if err := cbg.WriteBool(w, t.InitiatorPaused); err != nil { + return err + } + // t.Stages (datatransfer.ChannelStages) (struct) if len("Stages") > cbg.MaxLength { return xerrors.Errorf("Value in field \"Stages\" was too long") @@ -733,83 +747,35 @@ func (t *ChannelState) UnmarshalCBOR(r io.Reader) error { t.VoucherResults[i] = v } - // t.ReceivedBlocksTotal (int64) (int64) - case "ReceivedBlocksTotal": + // t.ReceivedIndex (internal.CborGenCompatibleNode) (struct) + case "ReceivedIndex": + { - maj, extra, err := cbg.CborReadHeaderBuf(br, scratch) - var extraI int64 - if err != nil { - return err - } - switch maj { - case cbg.MajUnsignedInt: - extraI = int64(extra) - if extraI < 0 { - return fmt.Errorf("int64 positive overflow") - } - case cbg.MajNegativeInt: - extraI = int64(extra) - if extraI < 0 { - return fmt.Errorf("int64 negative oveflow") - } - extraI = -1 - extraI - default: - return fmt.Errorf("wrong type for int64 field: %d", maj) + + if err := t.ReceivedIndex.UnmarshalCBOR(br); err != nil { + return xerrors.Errorf("unmarshaling t.ReceivedIndex: %w", err) } - t.ReceivedBlocksTotal = int64(extraI) } - // t.QueuedBlocksTotal (int64) (int64) - case "QueuedBlocksTotal": + // t.QueuedIndex (internal.CborGenCompatibleNode) (struct) + case "QueuedIndex": + { - maj, extra, err := cbg.CborReadHeaderBuf(br, scratch) - var extraI int64 - if err != nil { - return err - } - switch maj { - case cbg.MajUnsignedInt: - extraI = int64(extra) - if extraI < 0 { - return fmt.Errorf("int64 positive overflow") - } - case cbg.MajNegativeInt: - extraI = int64(extra) - if extraI < 0 { - return fmt.Errorf("int64 negative oveflow") - } - extraI = -1 - extraI - default: - return fmt.Errorf("wrong type for int64 field: %d", maj) + + if err := t.QueuedIndex.UnmarshalCBOR(br); err != nil { + return xerrors.Errorf("unmarshaling t.QueuedIndex: %w", err) } - t.QueuedBlocksTotal = int64(extraI) } - // t.SentBlocksTotal (int64) (int64) - case "SentBlocksTotal": + // t.SentIndex (internal.CborGenCompatibleNode) (struct) + case "SentIndex": + { - maj, extra, err := cbg.CborReadHeaderBuf(br, scratch) - var extraI int64 - if err != nil { - return err - } - switch maj { - 
case cbg.MajUnsignedInt: - extraI = int64(extra) - if extraI < 0 { - return fmt.Errorf("int64 positive overflow") - } - case cbg.MajNegativeInt: - extraI = int64(extra) - if extraI < 0 { - return fmt.Errorf("int64 negative oveflow") - } - extraI = -1 - extraI - default: - return fmt.Errorf("wrong type for int64 field: %d", maj) + + if err := t.SentIndex.UnmarshalCBOR(br); err != nil { + return xerrors.Errorf("unmarshaling t.SentIndex: %w", err) } - t.SentBlocksTotal = int64(extraI) } // t.DataLimit (uint64) (uint64) case "DataLimit": @@ -844,6 +810,42 @@ func (t *ChannelState) UnmarshalCBOR(r io.Reader) error { default: return fmt.Errorf("booleans are either major type 7, value 20 or 21 (got %d)", extra) } + // t.ResponderPaused (bool) (bool) + case "ResponderPaused": + + maj, extra, err = cbg.CborReadHeaderBuf(br, scratch) + if err != nil { + return err + } + if maj != cbg.MajOther { + return fmt.Errorf("booleans must be major type 7") + } + switch extra { + case 20: + t.ResponderPaused = false + case 21: + t.ResponderPaused = true + default: + return fmt.Errorf("booleans are either major type 7, value 20 or 21 (got %d)", extra) + } + // t.InitiatorPaused (bool) (bool) + case "InitiatorPaused": + + maj, extra, err = cbg.CborReadHeaderBuf(br, scratch) + if err != nil { + return err + } + if maj != cbg.MajOther { + return fmt.Errorf("booleans must be major type 7") + } + switch extra { + case 20: + t.InitiatorPaused = false + case 21: + t.InitiatorPaused = true + default: + return fmt.Errorf("booleans are either major type 7, value 20 or 21 (got %d)", extra) + } // t.Stages (datatransfer.ChannelStages) (struct) case "Stages": diff --git a/channels/internal/migrations/migrations.go b/channels/internal/migrations/migrations.go index b6a1ed6a..210dfc00 100644 --- a/channels/internal/migrations/migrations.go +++ b/channels/internal/migrations/migrations.go @@ -1,13 +1,119 @@ package migrations import ( + "github.com/ipfs/go-cid" + basicnode "github.com/ipld/go-ipld-prime/node/basic" peer "github.com/libp2p/go-libp2p-core/peer" versioning "github.com/filecoin-project/go-ds-versioning/pkg" "github.com/filecoin-project/go-ds-versioning/pkg/versioned" + + datatransfer "github.com/filecoin-project/go-data-transfer/v2" + "github.com/filecoin-project/go-data-transfer/v2/channels/internal" ) +//go:generate cbor-gen-for --map-encoding ChannelStateV2 + +// ChannelStateV2 is the internal representation on disk for the channel fsm, version 2 +type ChannelStateV2 struct { + // PeerId of the manager peer + SelfPeer peer.ID + // an identifier for this channel shared by requester and responder, set by requester through protocol + TransferID datatransfer.TransferID + // Initiator is the person who initiated this datatransfer request + Initiator peer.ID + // Responder is the person who is responding to this datatransfer request + Responder peer.ID + // base CID for the piece being transferred + BaseCid cid.Cid + // portion of Piece to return, specified by an IPLD selector + Selector internal.CborGenCompatibleNode + // the party that is sending the data (not who initiated the request) + Sender peer.ID + // the party that is receiving the data (not who initiated the request) + Recipient peer.ID + // expected amount of data to be transferred + TotalSize uint64 + // current status of this deal + Status datatransfer.Status + // total bytes read from this node and queued for sending (0 if receiver) + Queued uint64 + // total bytes sent from this node (0 if receiver) + Sent uint64 + // total bytes received by this node (0 
if sender) + Received uint64 + // more informative status on a channel + Message string + Vouchers []internal.EncodedVoucher + VoucherResults []internal.EncodedVoucherResult + // Number of blocks that have been received, including blocks that are + // present in more than one place in the DAG + ReceivedBlocksTotal int64 + // Number of blocks that have been queued, including blocks that are + // present in more than one place in the DAG + QueuedBlocksTotal int64 + // Number of blocks that have been sent, including blocks that are + // present in more than one place in the DAG + SentBlocksTotal int64 + // DataLimit is the maximum data that can be transferred on this channel before + // revalidation. 0 indicates no limit. + DataLimit uint64 + // RequiresFinalization indicates at the end of the transfer, the channel should + // be left open for a final settlement + RequiresFinalization bool + // Stages traces the execution of a data transfer. + // + // EXPERIMENTAL; subject to change. + Stages *datatransfer.ChannelStages +} + +func NoOpChannelState0To2(oldChannelState *ChannelStateV2) (*ChannelStateV2, error) { + return oldChannelState, nil +} + +func MigrateChannelState2To3(oldChannelState *ChannelStateV2) (*internal.ChannelState, error) { + receivedIndex := basicnode.NewInt(oldChannelState.ReceivedBlocksTotal) + sentIndex := basicnode.NewInt(oldChannelState.SentBlocksTotal) + queuedIndex := basicnode.NewInt(oldChannelState.QueuedBlocksTotal) + + responderPaused := oldChannelState.Status == datatransfer.ResponderPaused || oldChannelState.Status == datatransfer.BothPaused + initiatorPaused := oldChannelState.Status == datatransfer.InitiatorPaused || oldChannelState.Status == datatransfer.BothPaused + newStatus := oldChannelState.Status + if newStatus == datatransfer.ResponderPaused || newStatus == datatransfer.InitiatorPaused || newStatus == datatransfer.BothPaused { + newStatus = datatransfer.Ongoing + } + return &internal.ChannelState{ + SelfPeer: oldChannelState.SelfPeer, + TransferID: oldChannelState.TransferID, + Initiator: oldChannelState.Initiator, + Responder: oldChannelState.Responder, + BaseCid: oldChannelState.BaseCid, + Selector: oldChannelState.Selector, + Sender: oldChannelState.Sender, + Recipient: oldChannelState.Recipient, + TotalSize: oldChannelState.TotalSize, + Status: newStatus, + Queued: oldChannelState.Queued, + Sent: oldChannelState.Sent, + Received: oldChannelState.Received, + Message: oldChannelState.Message, + Vouchers: oldChannelState.Vouchers, + VoucherResults: oldChannelState.VoucherResults, + ReceivedIndex: internal.CborGenCompatibleNode{Node: receivedIndex}, + SentIndex: internal.CborGenCompatibleNode{Node: sentIndex}, + QueuedIndex: internal.CborGenCompatibleNode{Node: queuedIndex}, + DataLimit: oldChannelState.DataLimit, + RequiresFinalization: oldChannelState.RequiresFinalization, + InitiatorPaused: initiatorPaused, + ResponderPaused: responderPaused, + Stages: oldChannelState.Stages, + }, nil +} + // GetChannelStateMigrations returns a migration list for the channel states func GetChannelStateMigrations(selfPeer peer.ID) (versioning.VersionedMigrationList, error) { - return versioned.BuilderList{}.Build() + return versioned.BuilderList{ + versioned.NewVersionedBuilder(NoOpChannelState0To2, "2"), + versioned.NewVersionedBuilder(MigrateChannelState2To3, "3").OldVersion("2"), + }.Build() } diff --git a/channels/internal/migrations/migrations_cbor_gen.go b/channels/internal/migrations/migrations_cbor_gen.go new file mode 100644 index 00000000..c4ca74fd --- 
diff --git a/channels/internal/migrations/migrations_cbor_gen.go b/channels/internal/migrations/migrations_cbor_gen.go new file mode 100644 index 00000000..c4ca74fd --- /dev/null +++ b/channels/internal/migrations/migrations_cbor_gen.go @@ -0,0 +1,876 @@ +// Code generated by github.com/whyrusleeping/cbor-gen. DO NOT EDIT. + +package migrations + +import ( + "fmt" + "io" + "sort" + + datatransfer "github.com/filecoin-project/go-data-transfer/v2" + internal "github.com/filecoin-project/go-data-transfer/v2/channels/internal" + cid "github.com/ipfs/go-cid" + peer "github.com/libp2p/go-libp2p-core/peer" + cbg "github.com/whyrusleeping/cbor-gen" + xerrors "golang.org/x/xerrors" +) + +var _ = xerrors.Errorf +var _ = cid.Undef +var _ = sort.Sort + +func (t *ChannelStateV2) MarshalCBOR(w io.Writer) error { + if t == nil { + _, err := w.Write(cbg.CborNull) + return err + } + if _, err := w.Write([]byte{182}); err != nil { + return err + } + + scratch := make([]byte, 9) + + // t.SelfPeer (peer.ID) (string) + if len("SelfPeer") > cbg.MaxLength { + return xerrors.Errorf("Value in field \"SelfPeer\" was too long") + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("SelfPeer"))); err != nil { + return err + } + if _, err := io.WriteString(w, string("SelfPeer")); err != nil { + return err + } + + if len(t.SelfPeer) > cbg.MaxLength { + return xerrors.Errorf("Value in field t.SelfPeer was too long") + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len(t.SelfPeer))); err != nil { + return err + } + if _, err := io.WriteString(w, string(t.SelfPeer)); err != nil { + return err + } + + // t.TransferID (datatransfer.TransferID) (uint64) + if len("TransferID") > cbg.MaxLength { + return xerrors.Errorf("Value in field \"TransferID\" was too long") + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("TransferID"))); err != nil { + return err + } + if _, err := io.WriteString(w, string("TransferID")); err != nil { + return err + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajUnsignedInt, uint64(t.TransferID)); err != nil { + return err + } + + // t.Initiator (peer.ID) (string) + if len("Initiator") > cbg.MaxLength { + return xerrors.Errorf("Value in field \"Initiator\" was too long") + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("Initiator"))); err != nil { + return err + } + if _, err := io.WriteString(w, string("Initiator")); err != nil { + return err + } + + if len(t.Initiator) > cbg.MaxLength { + return xerrors.Errorf("Value in field t.Initiator was too long") + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len(t.Initiator))); err != nil { + return err + } + if _, err := io.WriteString(w, string(t.Initiator)); err != nil { + return err + } + + // t.Responder (peer.ID) (string) + if len("Responder") > cbg.MaxLength { + return xerrors.Errorf("Value in field \"Responder\" was too long") + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("Responder"))); err != nil { + return err + } + if _, err := io.WriteString(w, string("Responder")); err != nil { + return err + } + + if len(t.Responder) > cbg.MaxLength { + return xerrors.Errorf("Value in field t.Responder was too long") + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len(t.Responder))); err != nil { + return err + } + if _, err := io.WriteString(w, string(t.Responder)); err != nil { + return err + } + + // t.BaseCid (cid.Cid) (struct) + if len("BaseCid") > cbg.MaxLength { + return xerrors.Errorf("Value in field \"BaseCid\" was too long") + } + + if err := 
cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("BaseCid"))); err != nil { + return err + } + if _, err := io.WriteString(w, string("BaseCid")); err != nil { + return err + } + + if err := cbg.WriteCidBuf(scratch, w, t.BaseCid); err != nil { + return xerrors.Errorf("failed to write cid field t.BaseCid: %w", err) + } + + // t.Selector (internal.CborGenCompatibleNode) (struct) + if len("Selector") > cbg.MaxLength { + return xerrors.Errorf("Value in field \"Selector\" was too long") + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("Selector"))); err != nil { + return err + } + if _, err := io.WriteString(w, string("Selector")); err != nil { + return err + } + + if err := t.Selector.MarshalCBOR(w); err != nil { + return err + } + + // t.Sender (peer.ID) (string) + if len("Sender") > cbg.MaxLength { + return xerrors.Errorf("Value in field \"Sender\" was too long") + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("Sender"))); err != nil { + return err + } + if _, err := io.WriteString(w, string("Sender")); err != nil { + return err + } + + if len(t.Sender) > cbg.MaxLength { + return xerrors.Errorf("Value in field t.Sender was too long") + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len(t.Sender))); err != nil { + return err + } + if _, err := io.WriteString(w, string(t.Sender)); err != nil { + return err + } + + // t.Recipient (peer.ID) (string) + if len("Recipient") > cbg.MaxLength { + return xerrors.Errorf("Value in field \"Recipient\" was too long") + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("Recipient"))); err != nil { + return err + } + if _, err := io.WriteString(w, string("Recipient")); err != nil { + return err + } + + if len(t.Recipient) > cbg.MaxLength { + return xerrors.Errorf("Value in field t.Recipient was too long") + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len(t.Recipient))); err != nil { + return err + } + if _, err := io.WriteString(w, string(t.Recipient)); err != nil { + return err + } + + // t.TotalSize (uint64) (uint64) + if len("TotalSize") > cbg.MaxLength { + return xerrors.Errorf("Value in field \"TotalSize\" was too long") + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("TotalSize"))); err != nil { + return err + } + if _, err := io.WriteString(w, string("TotalSize")); err != nil { + return err + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajUnsignedInt, uint64(t.TotalSize)); err != nil { + return err + } + + // t.Status (datatransfer.Status) (uint64) + if len("Status") > cbg.MaxLength { + return xerrors.Errorf("Value in field \"Status\" was too long") + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("Status"))); err != nil { + return err + } + if _, err := io.WriteString(w, string("Status")); err != nil { + return err + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajUnsignedInt, uint64(t.Status)); err != nil { + return err + } + + // t.Queued (uint64) (uint64) + if len("Queued") > cbg.MaxLength { + return xerrors.Errorf("Value in field \"Queued\" was too long") + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("Queued"))); err != nil { + return err + } + if _, err := io.WriteString(w, string("Queued")); err != nil { + return err + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, 
cbg.MajUnsignedInt, uint64(t.Queued)); err != nil { + return err + } + + // t.Sent (uint64) (uint64) + if len("Sent") > cbg.MaxLength { + return xerrors.Errorf("Value in field \"Sent\" was too long") + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("Sent"))); err != nil { + return err + } + if _, err := io.WriteString(w, string("Sent")); err != nil { + return err + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajUnsignedInt, uint64(t.Sent)); err != nil { + return err + } + + // t.Received (uint64) (uint64) + if len("Received") > cbg.MaxLength { + return xerrors.Errorf("Value in field \"Received\" was too long") + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("Received"))); err != nil { + return err + } + if _, err := io.WriteString(w, string("Received")); err != nil { + return err + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajUnsignedInt, uint64(t.Received)); err != nil { + return err + } + + // t.Message (string) (string) + if len("Message") > cbg.MaxLength { + return xerrors.Errorf("Value in field \"Message\" was too long") + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("Message"))); err != nil { + return err + } + if _, err := io.WriteString(w, string("Message")); err != nil { + return err + } + + if len(t.Message) > cbg.MaxLength { + return xerrors.Errorf("Value in field t.Message was too long") + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len(t.Message))); err != nil { + return err + } + if _, err := io.WriteString(w, string(t.Message)); err != nil { + return err + } + + // t.Vouchers ([]internal.EncodedVoucher) (slice) + if len("Vouchers") > cbg.MaxLength { + return xerrors.Errorf("Value in field \"Vouchers\" was too long") + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("Vouchers"))); err != nil { + return err + } + if _, err := io.WriteString(w, string("Vouchers")); err != nil { + return err + } + + if len(t.Vouchers) > cbg.MaxLength { + return xerrors.Errorf("Slice value in field t.Vouchers was too long") + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajArray, uint64(len(t.Vouchers))); err != nil { + return err + } + for _, v := range t.Vouchers { + if err := v.MarshalCBOR(w); err != nil { + return err + } + } + + // t.VoucherResults ([]internal.EncodedVoucherResult) (slice) + if len("VoucherResults") > cbg.MaxLength { + return xerrors.Errorf("Value in field \"VoucherResults\" was too long") + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("VoucherResults"))); err != nil { + return err + } + if _, err := io.WriteString(w, string("VoucherResults")); err != nil { + return err + } + + if len(t.VoucherResults) > cbg.MaxLength { + return xerrors.Errorf("Slice value in field t.VoucherResults was too long") + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajArray, uint64(len(t.VoucherResults))); err != nil { + return err + } + for _, v := range t.VoucherResults { + if err := v.MarshalCBOR(w); err != nil { + return err + } + } + + // t.ReceivedBlocksTotal (int64) (int64) + if len("ReceivedBlocksTotal") > cbg.MaxLength { + return xerrors.Errorf("Value in field \"ReceivedBlocksTotal\" was too long") + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("ReceivedBlocksTotal"))); err != nil { + return err + } + if _, err := io.WriteString(w, 
string("ReceivedBlocksTotal")); err != nil { + return err + } + + if t.ReceivedBlocksTotal >= 0 { + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajUnsignedInt, uint64(t.ReceivedBlocksTotal)); err != nil { + return err + } + } else { + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajNegativeInt, uint64(-t.ReceivedBlocksTotal-1)); err != nil { + return err + } + } + + // t.QueuedBlocksTotal (int64) (int64) + if len("QueuedBlocksTotal") > cbg.MaxLength { + return xerrors.Errorf("Value in field \"QueuedBlocksTotal\" was too long") + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("QueuedBlocksTotal"))); err != nil { + return err + } + if _, err := io.WriteString(w, string("QueuedBlocksTotal")); err != nil { + return err + } + + if t.QueuedBlocksTotal >= 0 { + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajUnsignedInt, uint64(t.QueuedBlocksTotal)); err != nil { + return err + } + } else { + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajNegativeInt, uint64(-t.QueuedBlocksTotal-1)); err != nil { + return err + } + } + + // t.SentBlocksTotal (int64) (int64) + if len("SentBlocksTotal") > cbg.MaxLength { + return xerrors.Errorf("Value in field \"SentBlocksTotal\" was too long") + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("SentBlocksTotal"))); err != nil { + return err + } + if _, err := io.WriteString(w, string("SentBlocksTotal")); err != nil { + return err + } + + if t.SentBlocksTotal >= 0 { + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajUnsignedInt, uint64(t.SentBlocksTotal)); err != nil { + return err + } + } else { + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajNegativeInt, uint64(-t.SentBlocksTotal-1)); err != nil { + return err + } + } + + // t.DataLimit (uint64) (uint64) + if len("DataLimit") > cbg.MaxLength { + return xerrors.Errorf("Value in field \"DataLimit\" was too long") + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("DataLimit"))); err != nil { + return err + } + if _, err := io.WriteString(w, string("DataLimit")); err != nil { + return err + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajUnsignedInt, uint64(t.DataLimit)); err != nil { + return err + } + + // t.RequiresFinalization (bool) (bool) + if len("RequiresFinalization") > cbg.MaxLength { + return xerrors.Errorf("Value in field \"RequiresFinalization\" was too long") + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("RequiresFinalization"))); err != nil { + return err + } + if _, err := io.WriteString(w, string("RequiresFinalization")); err != nil { + return err + } + + if err := cbg.WriteBool(w, t.RequiresFinalization); err != nil { + return err + } + + // t.Stages (datatransfer.ChannelStages) (struct) + if len("Stages") > cbg.MaxLength { + return xerrors.Errorf("Value in field \"Stages\" was too long") + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("Stages"))); err != nil { + return err + } + if _, err := io.WriteString(w, string("Stages")); err != nil { + return err + } + + if err := t.Stages.MarshalCBOR(w); err != nil { + return err + } + return nil +} + +func (t *ChannelStateV2) UnmarshalCBOR(r io.Reader) error { + *t = ChannelStateV2{} + + br := cbg.GetPeeker(r) + scratch := make([]byte, 8) + + maj, extra, err := cbg.CborReadHeaderBuf(br, scratch) + if err != nil { + return err + } + if maj != cbg.MajMap { + return fmt.Errorf("cbor input should 
be of type map") + } + + if extra > cbg.MaxLength { + return fmt.Errorf("ChannelStateV2: map struct too large (%d)", extra) + } + + var name string + n := extra + + for i := uint64(0); i < n; i++ { + + { + sval, err := cbg.ReadStringBuf(br, scratch) + if err != nil { + return err + } + + name = string(sval) + } + + switch name { + // t.SelfPeer (peer.ID) (string) + case "SelfPeer": + + { + sval, err := cbg.ReadStringBuf(br, scratch) + if err != nil { + return err + } + + t.SelfPeer = peer.ID(sval) + } + // t.TransferID (datatransfer.TransferID) (uint64) + case "TransferID": + + { + + maj, extra, err = cbg.CborReadHeaderBuf(br, scratch) + if err != nil { + return err + } + if maj != cbg.MajUnsignedInt { + return fmt.Errorf("wrong type for uint64 field") + } + t.TransferID = datatransfer.TransferID(extra) + + } + // t.Initiator (peer.ID) (string) + case "Initiator": + + { + sval, err := cbg.ReadStringBuf(br, scratch) + if err != nil { + return err + } + + t.Initiator = peer.ID(sval) + } + // t.Responder (peer.ID) (string) + case "Responder": + + { + sval, err := cbg.ReadStringBuf(br, scratch) + if err != nil { + return err + } + + t.Responder = peer.ID(sval) + } + // t.BaseCid (cid.Cid) (struct) + case "BaseCid": + + { + + c, err := cbg.ReadCid(br) + if err != nil { + return xerrors.Errorf("failed to read cid field t.BaseCid: %w", err) + } + + t.BaseCid = c + + } + // t.Selector (internal.CborGenCompatibleNode) (struct) + case "Selector": + + { + + if err := t.Selector.UnmarshalCBOR(br); err != nil { + return xerrors.Errorf("unmarshaling t.Selector: %w", err) + } + + } + // t.Sender (peer.ID) (string) + case "Sender": + + { + sval, err := cbg.ReadStringBuf(br, scratch) + if err != nil { + return err + } + + t.Sender = peer.ID(sval) + } + // t.Recipient (peer.ID) (string) + case "Recipient": + + { + sval, err := cbg.ReadStringBuf(br, scratch) + if err != nil { + return err + } + + t.Recipient = peer.ID(sval) + } + // t.TotalSize (uint64) (uint64) + case "TotalSize": + + { + + maj, extra, err = cbg.CborReadHeaderBuf(br, scratch) + if err != nil { + return err + } + if maj != cbg.MajUnsignedInt { + return fmt.Errorf("wrong type for uint64 field") + } + t.TotalSize = uint64(extra) + + } + // t.Status (datatransfer.Status) (uint64) + case "Status": + + { + + maj, extra, err = cbg.CborReadHeaderBuf(br, scratch) + if err != nil { + return err + } + if maj != cbg.MajUnsignedInt { + return fmt.Errorf("wrong type for uint64 field") + } + t.Status = datatransfer.Status(extra) + + } + // t.Queued (uint64) (uint64) + case "Queued": + + { + + maj, extra, err = cbg.CborReadHeaderBuf(br, scratch) + if err != nil { + return err + } + if maj != cbg.MajUnsignedInt { + return fmt.Errorf("wrong type for uint64 field") + } + t.Queued = uint64(extra) + + } + // t.Sent (uint64) (uint64) + case "Sent": + + { + + maj, extra, err = cbg.CborReadHeaderBuf(br, scratch) + if err != nil { + return err + } + if maj != cbg.MajUnsignedInt { + return fmt.Errorf("wrong type for uint64 field") + } + t.Sent = uint64(extra) + + } + // t.Received (uint64) (uint64) + case "Received": + + { + + maj, extra, err = cbg.CborReadHeaderBuf(br, scratch) + if err != nil { + return err + } + if maj != cbg.MajUnsignedInt { + return fmt.Errorf("wrong type for uint64 field") + } + t.Received = uint64(extra) + + } + // t.Message (string) (string) + case "Message": + + { + sval, err := cbg.ReadStringBuf(br, scratch) + if err != nil { + return err + } + + t.Message = string(sval) + } + // t.Vouchers ([]internal.EncodedVoucher) (slice) + case 
"Vouchers": + + maj, extra, err = cbg.CborReadHeaderBuf(br, scratch) + if err != nil { + return err + } + + if extra > cbg.MaxLength { + return fmt.Errorf("t.Vouchers: array too large (%d)", extra) + } + + if maj != cbg.MajArray { + return fmt.Errorf("expected cbor array") + } + + if extra > 0 { + t.Vouchers = make([]internal.EncodedVoucher, extra) + } + + for i := 0; i < int(extra); i++ { + + var v internal.EncodedVoucher + if err := v.UnmarshalCBOR(br); err != nil { + return err + } + + t.Vouchers[i] = v + } + + // t.VoucherResults ([]internal.EncodedVoucherResult) (slice) + case "VoucherResults": + + maj, extra, err = cbg.CborReadHeaderBuf(br, scratch) + if err != nil { + return err + } + + if extra > cbg.MaxLength { + return fmt.Errorf("t.VoucherResults: array too large (%d)", extra) + } + + if maj != cbg.MajArray { + return fmt.Errorf("expected cbor array") + } + + if extra > 0 { + t.VoucherResults = make([]internal.EncodedVoucherResult, extra) + } + + for i := 0; i < int(extra); i++ { + + var v internal.EncodedVoucherResult + if err := v.UnmarshalCBOR(br); err != nil { + return err + } + + t.VoucherResults[i] = v + } + + // t.ReceivedBlocksTotal (int64) (int64) + case "ReceivedBlocksTotal": + { + maj, extra, err := cbg.CborReadHeaderBuf(br, scratch) + var extraI int64 + if err != nil { + return err + } + switch maj { + case cbg.MajUnsignedInt: + extraI = int64(extra) + if extraI < 0 { + return fmt.Errorf("int64 positive overflow") + } + case cbg.MajNegativeInt: + extraI = int64(extra) + if extraI < 0 { + return fmt.Errorf("int64 negative oveflow") + } + extraI = -1 - extraI + default: + return fmt.Errorf("wrong type for int64 field: %d", maj) + } + + t.ReceivedBlocksTotal = int64(extraI) + } + // t.QueuedBlocksTotal (int64) (int64) + case "QueuedBlocksTotal": + { + maj, extra, err := cbg.CborReadHeaderBuf(br, scratch) + var extraI int64 + if err != nil { + return err + } + switch maj { + case cbg.MajUnsignedInt: + extraI = int64(extra) + if extraI < 0 { + return fmt.Errorf("int64 positive overflow") + } + case cbg.MajNegativeInt: + extraI = int64(extra) + if extraI < 0 { + return fmt.Errorf("int64 negative oveflow") + } + extraI = -1 - extraI + default: + return fmt.Errorf("wrong type for int64 field: %d", maj) + } + + t.QueuedBlocksTotal = int64(extraI) + } + // t.SentBlocksTotal (int64) (int64) + case "SentBlocksTotal": + { + maj, extra, err := cbg.CborReadHeaderBuf(br, scratch) + var extraI int64 + if err != nil { + return err + } + switch maj { + case cbg.MajUnsignedInt: + extraI = int64(extra) + if extraI < 0 { + return fmt.Errorf("int64 positive overflow") + } + case cbg.MajNegativeInt: + extraI = int64(extra) + if extraI < 0 { + return fmt.Errorf("int64 negative oveflow") + } + extraI = -1 - extraI + default: + return fmt.Errorf("wrong type for int64 field: %d", maj) + } + + t.SentBlocksTotal = int64(extraI) + } + // t.DataLimit (uint64) (uint64) + case "DataLimit": + + { + + maj, extra, err = cbg.CborReadHeaderBuf(br, scratch) + if err != nil { + return err + } + if maj != cbg.MajUnsignedInt { + return fmt.Errorf("wrong type for uint64 field") + } + t.DataLimit = uint64(extra) + + } + // t.RequiresFinalization (bool) (bool) + case "RequiresFinalization": + + maj, extra, err = cbg.CborReadHeaderBuf(br, scratch) + if err != nil { + return err + } + if maj != cbg.MajOther { + return fmt.Errorf("booleans must be major type 7") + } + switch extra { + case 20: + t.RequiresFinalization = false + case 21: + t.RequiresFinalization = true + default: + return fmt.Errorf("booleans are 
either major type 7, value 20 or 21 (got %d)", extra) + } + // t.Stages (datatransfer.ChannelStages) (struct) + case "Stages": + + { + + b, err := br.ReadByte() + if err != nil { + return err + } + if b != cbg.CborNull[0] { + if err := br.UnreadByte(); err != nil { + return err + } + t.Stages = new(datatransfer.ChannelStages) + if err := t.Stages.UnmarshalCBOR(br); err != nil { + return xerrors.Errorf("unmarshaling t.Stages pointer: %w", err) + } + } + + } + + default: + // Field doesn't exist on this type, so ignore it + cbg.ScanForLinks(r, func(cid.Cid) {}) + } + } + + return nil +} diff --git a/errors.go b/errors.go index 0e9903f6..592444c0 100644 --- a/errors.go +++ b/errors.go @@ -17,14 +17,6 @@ const ErrHandlerNotSet = errorType("event handler has not been set") // ErrChannelNotFound means the channel this command was issued for does not exist const ErrChannelNotFound = errorType("channel not found") -// ErrPause is a special error that the DataReceived / DataSent hooks can -// use to pause the channel -const ErrPause = errorType("pause channel") - -// ErrResume is a special error that the RequestReceived / ResponseReceived hooks can -// use to resume the channel -const ErrResume = errorType("resume channel") - // ErrRejected indicates a request was not accepted const ErrRejected = errorType("response rejected") diff --git a/events.go b/events.go index 9ccd47c9..1cc42be8 100644 --- a/events.go +++ b/events.go @@ -61,7 +61,7 @@ const ( // initiator BeginFinalizing - // Disconnected emits when we are not able to connect to the other party + // DEPRECATED in favor of SendMessageError Disconnected // Complete is emitted when a data transfer is complete @@ -91,7 +91,7 @@ const ( // data has been received. DataReceivedProgress - // Deprecated in favour of RequestCancelled + // DEPRECATED in favour of RequestCancelled RequestTimedOut // SendDataError indicates that the transport layer had an error trying @@ -102,7 +102,7 @@ const ( // receiving data from the remote peer ReceiveDataError - // TransferRequestQueued indicates that a new data transfer request has been queued in the transport layer + // DEPRECATED in favor of TransferInitiated TransferRequestQueued // RequestCancelled indicates that a transport layer request was cancelled by the request opener @@ -123,6 +123,12 @@ const ( // pausing the responder, but is distinct from PauseResponder to indicate why the pause // happened DataLimitExceeded + + // TransferInitiated indicates the transport has begun transferring data + TransferInitiated + + // SendMessageError indicates an error sending a data transfer message + SendMessageError ) // Events are human readable names for data transfer events @@ -161,6 +167,8 @@ var Events = map[EventCode]string{ SetDataLimit: "SetDataLimit", SetRequiresFinalization: "SetRequiresFinalization", DataLimitExceeded: "DataLimitExceeded", + TransferInitiated: "TransferInitiated", + SendMessageError: "SendMessageError", } // Event is a struct containing information about a data transfer event diff --git a/go.mod b/go.mod index 972e065e..608fa39f 100644 --- a/go.mod +++ b/go.mod @@ -13,7 +13,7 @@ require ( github.com/ipfs/go-cid v0.2.0 github.com/ipfs/go-datastore v0.5.1 github.com/ipfs/go-ds-badger v0.3.0 - github.com/ipfs/go-graphsync v0.13.2-0.20220531040852-fa5a9f2d7a86 + github.com/ipfs/go-graphsync v0.13.3-0.20220625074430-a95496cf1534 github.com/ipfs/go-ipfs-blockstore v1.1.2 github.com/ipfs/go-ipfs-blocksutil v0.0.1 github.com/ipfs/go-ipfs-chunker v0.0.5 @@ -119,7 +119,6 @@ require ( 
github.com/spaolacci/murmur3 v1.1.0 // indirect github.com/urfave/cli/v2 v2.0.0 // indirect github.com/whyrusleeping/chunker v0.0.0-20181014151217-fe64bd25879f // indirect - go.uber.org/goleak v1.1.12 // indirect go.uber.org/multierr v1.8.0 // indirect go.uber.org/zap v1.21.0 // indirect golang.org/x/crypto v0.0.0-20220411220226-7b82a4e95df4 // indirect diff --git a/go.sum b/go.sum index d34f268a..7da44a75 100644 --- a/go.sum +++ b/go.sum @@ -126,7 +126,6 @@ github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDk github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc= github.com/cncf/udpa/go v0.0.0-20201120205902-5459f2c99403/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk= github.com/cncf/udpa/go v0.0.0-20210930031921-04548b0d99d4/go.mod h1:6pvJx4me5XPnfI9Z40ddWsdw2W/uZgQLFXToKeRcDiI= -github.com/cncf/xds/go v0.0.0-20210312221358-fbca930ec8ed/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= github.com/cncf/xds/go v0.0.0-20210805033703-aa0b78936158/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= github.com/cncf/xds/go v0.0.0-20210922020428-25de7278fc84/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= github.com/cncf/xds/go v0.0.0-20211011173535-cb28da3451f1/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= @@ -196,7 +195,6 @@ github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymF github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98= github.com/envoyproxy/go-control-plane v0.9.9-0.20201210154907-fd9021fe5dad/go.mod h1:cXg6YxExXjJnVBQHBLXeUAgxn2UodCpnH306RInaBQk= -github.com/envoyproxy/go-control-plane v0.9.9-0.20210512163311-63b5d3c536b0/go.mod h1:hliV/p42l8fGbc6Y9bQ70uLwIvmJyVE5k4iMKlh8wCQ= github.com/envoyproxy/go-control-plane v0.9.10-0.20210907150352-cf90f659a021/go.mod h1:AFq3mo9L8Lqqiid3OhADV3RfLJnjiw63cSpi+fDTRC0= github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5KwzbycvMj4= @@ -391,7 +389,6 @@ github.com/hashicorp/serf v0.8.2/go.mod h1:6hOLApaqBFA1NXqRQAsxw9QxuDEvNxSQRwA/J github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= github.com/hudl/fargo v1.3.0/go.mod h1:y3CKSmjA+wD2gak7sUSXTAoopbhU08POFhmITJgmKTg= github.com/huin/goupnp v1.0.0/go.mod h1:n9v9KO1tAxYH82qOn+UTIFQDmx5n1Zxd/ClZDMX7Bnc= -github.com/huin/goupnp v1.0.2/go.mod h1:0dxJBVBHqTMjIUMkESDTNgOOx/Mw5wYIfyFmdzSamkM= github.com/huin/goupnp v1.0.3 h1:N8No57ls+MnjlB+JPiCVSOyy/ot7MJTqlo7rn+NYSqQ= github.com/huin/goupnp v1.0.3/go.mod h1:ZxNlw5WqJj6wSsRK5+YfflQGXYfccj5VgQsMNixHM7Y= github.com/huin/goutil v0.0.0-20170803182201-1ca381bf3150/go.mod h1:PpLOETDnJ0o3iZrZfqZzyLl6l7F3c6L1oWn7OICBi6o= @@ -441,8 +438,8 @@ github.com/ipfs/go-ds-leveldb v0.0.1/go.mod h1:feO8V3kubwsEF22n0YRQCffeb79OOYIyk github.com/ipfs/go-ds-leveldb v0.4.1/go.mod h1:jpbku/YqBSsBc1qgME8BkWS4AxzF2cEu1Ii2r79Hh9s= github.com/ipfs/go-ds-leveldb v0.4.2/go.mod h1:jpbku/YqBSsBc1qgME8BkWS4AxzF2cEu1Ii2r79Hh9s= github.com/ipfs/go-ds-leveldb v0.5.0/go.mod h1:d3XG9RUDzQ6V4SHi8+Xgj9j1XuEk1z82lquxrVbml/Q= -github.com/ipfs/go-graphsync v0.13.2-0.20220531040852-fa5a9f2d7a86 h1:PVLY+D9dz9SwQADbEaxLF5Kc+xOVP+SltDw3GvSdHmk= -github.com/ipfs/go-graphsync v0.13.2-0.20220531040852-fa5a9f2d7a86/go.mod 
h1:y8e8G6CmZeL9Srvx1l15CtGiRdf3h5JdQuqPz/iYL0A= +github.com/ipfs/go-graphsync v0.13.3-0.20220625074430-a95496cf1534 h1:sn7viAPyx3qZVhfRpXhW23mPtzl9rjJKtJ/HM/HsyZg= +github.com/ipfs/go-graphsync v0.13.3-0.20220625074430-a95496cf1534/go.mod h1:RKAui2+/HmlIVnuAXJIn0jltvOAXkl7wz3SYysmYnPI= github.com/ipfs/go-ipfs-blockstore v0.2.1/go.mod h1:jGesd8EtCM3/zPgx+qr0/feTXGUeRai6adgwC+Q+JvE= github.com/ipfs/go-ipfs-blockstore v1.1.2 h1:WCXoZcMYnvOTmlpX+RSSnhVN0uCmbWTeepTGX5lgiXw= github.com/ipfs/go-ipfs-blockstore v1.1.2/go.mod h1:w51tNR9y5+QXB0wkNcHt4O2aSZjTdqaEWaQdSxEyUOY= @@ -616,7 +613,6 @@ github.com/libp2p/go-libp2p v0.7.0/go.mod h1:hZJf8txWeCduQRDC/WSqBGMxaTHCOYHt2xS github.com/libp2p/go-libp2p v0.7.4/go.mod h1:oXsBlTLF1q7pxr+9w6lqzS1ILpyHsaBPniVO7zIHGMw= github.com/libp2p/go-libp2p v0.8.1/go.mod h1:QRNH9pwdbEBpx5DTJYg+qxcVaDMAz3Ee/qDKwXujH5o= github.com/libp2p/go-libp2p v0.14.3/go.mod h1:d12V4PdKbpL0T1/gsUNN8DfgMuRPDX8bS2QxCZlwRH0= -github.com/libp2p/go-libp2p v0.16.0/go.mod h1:ump42BsirwAWxKzsCiFnTtN1Yc+DuPu76fyMX364/O4= github.com/libp2p/go-libp2p v0.19.4 h1:50YL0YwPhWKDd+qbZQDEdnsmVAAkaCQrWUjpdHv4hNA= github.com/libp2p/go-libp2p v0.19.4/go.mod h1:MIt8y481VDhUe4ErWi1a4bvt/CjjFfOq6kZTothWIXY= github.com/libp2p/go-libp2p-asn-util v0.1.0 h1:rABPCO77SjdbJ/eJ/ynIo8vWICy1VEnL5JAxJbQLo1E= @@ -626,7 +622,6 @@ github.com/libp2p/go-libp2p-autonat v0.2.0/go.mod h1:DX+9teU4pEEoZUqR1PiMlqliONQ github.com/libp2p/go-libp2p-autonat v0.2.1/go.mod h1:MWtAhV5Ko1l6QBsHQNSuM6b1sRkXrpk0/LqCr+vCVxI= github.com/libp2p/go-libp2p-autonat v0.2.2/go.mod h1:HsM62HkqZmHR2k1xgX34WuWDzk/nBwNHoeyyT4IWV6A= github.com/libp2p/go-libp2p-autonat v0.4.2/go.mod h1:YxaJlpr81FhdOv3W3BTconZPfhaYivRdf53g+S2wobk= -github.com/libp2p/go-libp2p-autonat v0.6.0/go.mod h1:bFC6kY8jwzNNWoqc8iGE57vsfwyJ/lP4O4DOV1e0B2o= github.com/libp2p/go-libp2p-blankhost v0.1.1/go.mod h1:pf2fvdLJPsC1FsVrNP3DUUvMzUts2dsLLBEpo1vW1ro= github.com/libp2p/go-libp2p-blankhost v0.1.4/go.mod h1:oJF0saYsAXQCSfDq254GMNmLNz6ZTHTOvtF4ZydUvwU= github.com/libp2p/go-libp2p-blankhost v0.2.0/go.mod h1:eduNKXGTioTuQAUcZ5epXi9vMl+t4d8ugUBRQ4SqaNQ= @@ -659,7 +654,6 @@ github.com/libp2p/go-libp2p-core v0.8.1/go.mod h1:FfewUH/YpvWbEB+ZY9AQRQ4TAD8sJB github.com/libp2p/go-libp2p-core v0.8.2/go.mod h1:FfewUH/YpvWbEB+ZY9AQRQ4TAD8sJBt/G1rVvhz5XT8= github.com/libp2p/go-libp2p-core v0.8.5/go.mod h1:FfewUH/YpvWbEB+ZY9AQRQ4TAD8sJBt/G1rVvhz5XT8= github.com/libp2p/go-libp2p-core v0.8.6/go.mod h1:dgHr0l0hIKfWpGpqAMbpo19pen9wJfdCGv51mTmdpmM= -github.com/libp2p/go-libp2p-core v0.9.0/go.mod h1:ESsbz31oC3C1AvMJoGx26RTuCkNhmkSRCqZ0kQtJ2/8= github.com/libp2p/go-libp2p-core v0.10.0/go.mod h1:ECdxehoYosLYHgDDFa2N4yE8Y7aQRAMf0sX9mf2sbGg= github.com/libp2p/go-libp2p-core v0.11.0/go.mod h1:ECdxehoYosLYHgDDFa2N4yE8Y7aQRAMf0sX9mf2sbGg= github.com/libp2p/go-libp2p-core v0.12.0/go.mod h1:ECdxehoYosLYHgDDFa2N4yE8Y7aQRAMf0sX9mf2sbGg= @@ -670,7 +664,6 @@ github.com/libp2p/go-libp2p-crypto v0.1.0/go.mod h1:sPUokVISZiy+nNuTTH/TY+leRSxn github.com/libp2p/go-libp2p-discovery v0.2.0/go.mod h1:s4VGaxYMbw4+4+tsoQTqh7wfxg97AEdo4GYBt6BadWg= github.com/libp2p/go-libp2p-discovery v0.3.0/go.mod h1:o03drFnz9BVAZdzC/QUQ+NeQOu38Fu7LJGEOK2gQltw= github.com/libp2p/go-libp2p-discovery v0.5.0/go.mod h1:+srtPIU9gDaBNu//UHvcdliKBIcr4SfDcm0/PfPJLug= -github.com/libp2p/go-libp2p-discovery v0.6.0/go.mod h1:/u1voHt0tKIe5oIA1RHBKQLVCWPna2dXmPNHc2zR9S8= github.com/libp2p/go-libp2p-loggables v0.1.0 h1:h3w8QFfCt2UJl/0/NW4K829HX/0S4KD31PQ7m8UXXO8= github.com/libp2p/go-libp2p-loggables v0.1.0/go.mod 
h1:EyumB2Y6PrYjr55Q3/tiJ/o3xoDasoRYM7nOzEpoa90= github.com/libp2p/go-libp2p-mplex v0.2.0/go.mod h1:Ejl9IyjvXJ0T9iqUTE1jpYATQ9NM3g+OtR+EMMODbKo= @@ -687,7 +680,6 @@ github.com/libp2p/go-libp2p-nat v0.1.0/go.mod h1:DQzAG+QbDYjN1/C3B6vXucLtz3u9rEo github.com/libp2p/go-libp2p-netutil v0.1.0 h1:zscYDNVEcGxyUpMd0JReUZTrpMfia8PmLKcKF72EAMQ= github.com/libp2p/go-libp2p-netutil v0.1.0/go.mod h1:3Qv/aDqtMLTUyQeundkKsA+YCThNdbQD54k3TqjpbFU= github.com/libp2p/go-libp2p-noise v0.2.0/go.mod h1:IEbYhBBzGyvdLBoxxULL/SGbJARhUeqlO8lVSREYu2Q= -github.com/libp2p/go-libp2p-noise v0.3.0/go.mod h1:JNjHbociDJKHD64KTkzGnzqJ0FEV5gHJa6AB00kbCNQ= github.com/libp2p/go-libp2p-noise v0.4.0 h1:khcMsGhHNdGqKE5LDLrnHwZvdGVMsrnD4GTkTWkwmLU= github.com/libp2p/go-libp2p-noise v0.4.0/go.mod h1:BzzY5pyzCYSyJbQy9oD8z5oP2idsafjt4/X42h9DjZU= github.com/libp2p/go-libp2p-peer v0.2.0/go.mod h1:RCffaCvUyW2CJmG2gAWVqwePwW7JMgxjsHm7+J5kjWY= @@ -705,7 +697,6 @@ github.com/libp2p/go-libp2p-pnet v0.2.0 h1:J6htxttBipJujEjz1y0a5+eYoiPcFHhSYHH6n github.com/libp2p/go-libp2p-pnet v0.2.0/go.mod h1:Qqvq6JH/oMZGwqs3N1Fqhv8NVhrdYcO0BW4wssv21LA= github.com/libp2p/go-libp2p-quic-transport v0.10.0/go.mod h1:RfJbZ8IqXIhxBRm5hqUEJqjiiY8xmEuq3HUDS993MkA= github.com/libp2p/go-libp2p-quic-transport v0.13.0/go.mod h1:39/ZWJ1TW/jx1iFkKzzUg00W6tDJh73FC0xYudjr7Hc= -github.com/libp2p/go-libp2p-quic-transport v0.15.0/go.mod h1:wv4uGwjcqe8Mhjj7N/Ic0aKjA+/10UnMlSzLO0yRpYQ= github.com/libp2p/go-libp2p-quic-transport v0.16.0/go.mod h1:1BXjVMzr+w7EkPfiHkKnwsWjPjtfaNT0q8RS3tGDvEQ= github.com/libp2p/go-libp2p-quic-transport v0.17.0 h1:yFh4Gf5MlToAYLuw/dRvuzYd1EnE2pX3Lq1N6KDiWRQ= github.com/libp2p/go-libp2p-quic-transport v0.17.0/go.mod h1:x4pw61P3/GRCcSLypcQJE/Q2+E9f4X+5aRcZLXf20LM= @@ -736,7 +727,6 @@ github.com/libp2p/go-libp2p-testing v0.1.1/go.mod h1:xaZWMJrPUM5GlDBxCeGUi7kI4eq github.com/libp2p/go-libp2p-testing v0.1.2-0.20200422005655-8775583591d8/go.mod h1:Qy8sAncLKpwXtS2dSnDOP8ktexIAHKu+J+pnZOFZLTc= github.com/libp2p/go-libp2p-testing v0.3.0/go.mod h1:efZkql4UZ7OVsEfaxNHZPzIehtsBXMrXnCfJIgDti5g= github.com/libp2p/go-libp2p-testing v0.4.0/go.mod h1:Q+PFXYoiYFN5CAEG2w3gLPEzotlKsNSbKQ/lImlOWF0= -github.com/libp2p/go-libp2p-testing v0.4.2/go.mod h1:Q+PFXYoiYFN5CAEG2w3gLPEzotlKsNSbKQ/lImlOWF0= github.com/libp2p/go-libp2p-testing v0.5.0/go.mod h1:QBk8fqIL1XNcno/l3/hhaIEn4aLRijpYOR+zVjjlh+A= github.com/libp2p/go-libp2p-testing v0.7.0/go.mod h1:OLbdn9DbgdMwv00v+tlp1l3oe2Cl+FAjoWIA2pa0X6E= github.com/libp2p/go-libp2p-testing v0.9.0/go.mod h1:Td7kbdkWqYTJYQGTwzlgXwaqldraIanyjuRiAbK/XQU= @@ -744,14 +734,12 @@ github.com/libp2p/go-libp2p-testing v0.9.2 h1:dCpODRtRaDZKF8HXT9qqqgON+OMEB423Kn github.com/libp2p/go-libp2p-testing v0.9.2/go.mod h1:Td7kbdkWqYTJYQGTwzlgXwaqldraIanyjuRiAbK/XQU= github.com/libp2p/go-libp2p-tls v0.1.3/go.mod h1:wZfuewxOndz5RTnCAxFliGjvYSDA40sKitV4c50uI1M= github.com/libp2p/go-libp2p-tls v0.3.0/go.mod h1:fwF5X6PWGxm6IDRwF3V8AVCCj/hOd5oFlg+wo2FxJDY= -github.com/libp2p/go-libp2p-tls v0.3.1/go.mod h1:fwF5X6PWGxm6IDRwF3V8AVCCj/hOd5oFlg+wo2FxJDY= github.com/libp2p/go-libp2p-tls v0.4.1 h1:1ByJUbyoMXvYXDoW6lLsMxqMViQNXmt+CfQqlnCpY+M= github.com/libp2p/go-libp2p-tls v0.4.1/go.mod h1:EKCixHEysLNDlLUoKxv+3f/Lp90O2EXNjTr0UQDnrIw= github.com/libp2p/go-libp2p-transport-upgrader v0.1.1/go.mod h1:IEtA6or8JUbsV07qPW4r01GnTenLW4oi3lOPbUMGJJA= github.com/libp2p/go-libp2p-transport-upgrader v0.2.0/go.mod h1:mQcrHj4asu6ArfSoMuyojOdjx73Q47cYD7s5+gZOlns= github.com/libp2p/go-libp2p-transport-upgrader v0.3.0/go.mod h1:i+SKzbRnvXdVbU3D1dwydnTmKRPXiAR/fyvi1dXuL4o= 
github.com/libp2p/go-libp2p-transport-upgrader v0.4.2/go.mod h1:NR8ne1VwfreD5VIWIU62Agt/J18ekORFU/j1i2y8zvk= -github.com/libp2p/go-libp2p-transport-upgrader v0.4.3/go.mod h1:bpkldbOWXMrXhpZbSV1mQxTrefOg2Fi+k1ClDSA4ppw= github.com/libp2p/go-libp2p-transport-upgrader v0.5.0/go.mod h1:Rc+XODlB3yce7dvFV4q/RmyJGsFcCZRkeZMu/Zdg0mo= github.com/libp2p/go-libp2p-transport-upgrader v0.7.0/go.mod h1:GIR2aTRp1J5yjVlkUoFqMkdobfob6RnAwYg/RZPhrzg= github.com/libp2p/go-libp2p-transport-upgrader v0.7.1 h1:MSMe+tUfxpC9GArTz7a4G5zQKQgGh00Vio87d3j3xIg= @@ -764,7 +752,6 @@ github.com/libp2p/go-libp2p-yamux v0.2.8/go.mod h1:/t6tDqeuZf0INZMTgd0WxIRbtK2Ez github.com/libp2p/go-libp2p-yamux v0.4.0/go.mod h1:+DWDjtFMzoAwYLVkNZftoucn7PelNoy5nm3tZ3/Zw30= github.com/libp2p/go-libp2p-yamux v0.5.0/go.mod h1:AyR8k5EzyM2QN9Bbdg6X1SkVVuqLwTGf0L4DFq9g6po= github.com/libp2p/go-libp2p-yamux v0.5.4/go.mod h1:tfrXbyaTqqSU654GTvK3ocnSZL3BuHoeTSqhcel1wsE= -github.com/libp2p/go-libp2p-yamux v0.6.0/go.mod h1:MRhd6mAYnFRnSISp4M8i0ClV/j+mWHo2mYLifWGw33k= github.com/libp2p/go-libp2p-yamux v0.8.0/go.mod h1:yTkPgN2ib8FHyU1ZcVD7aelzyAqXXwEPbyx+aSKm9h8= github.com/libp2p/go-libp2p-yamux v0.8.1/go.mod h1:rUozF8Jah2dL9LLGyBaBeTQeARdwhefMCTQVQt6QobE= github.com/libp2p/go-libp2p-yamux v0.9.1 h1:oplewiRix8s45SOrI30rCPZG5mM087YZp+VYhXAh4+c= @@ -782,7 +769,6 @@ github.com/libp2p/go-mplex v0.4.0/go.mod h1:y26Lx+wNVtMYMaPu300Cbot5LkEZ4tJaNYeH github.com/libp2p/go-msgio v0.0.2/go.mod h1:63lBBgOTDKQL6EWazRMCwXsEeEeK9O2Cd+0+6OOuipQ= github.com/libp2p/go-msgio v0.0.4/go.mod h1:63lBBgOTDKQL6EWazRMCwXsEeEeK9O2Cd+0+6OOuipQ= github.com/libp2p/go-msgio v0.0.6/go.mod h1:4ecVB6d9f4BDSL5fqvPiC4A3KivjWn+Venn/1ALLMWA= -github.com/libp2p/go-msgio v0.1.0/go.mod h1:eNlv2vy9V2X/kNldcZ+SShFE++o2Yjxwx6RAYsmgJnE= github.com/libp2p/go-msgio v0.2.0 h1:W6shmB+FeynDrUVl2dgFQvzfBZcXiyqY4VmpQLu9FqU= github.com/libp2p/go-msgio v0.2.0/go.mod h1:dBVM1gW3Jk9XqHkU4eKdGvVHdLa51hoGfll6jMJMSlY= github.com/libp2p/go-nat v0.0.4/go.mod h1:Nmw50VAvKuk38jUBcmNh6p9lUJLoODbJRvYAa/+KSDo= @@ -829,7 +815,6 @@ github.com/libp2p/go-tcp-transport v0.5.1/go.mod h1:UPPL0DIjQqiWRwVAb+CEQlaAG0rp github.com/libp2p/go-ws-transport v0.2.0/go.mod h1:9BHJz/4Q5A9ludYWKoGCFC5gUElzlHoKzu0yY9p/klM= github.com/libp2p/go-ws-transport v0.3.0/go.mod h1:bpgTJmRZAvVHrgHybCVyqoBmyLQ1fiZuEaBYusP5zsk= github.com/libp2p/go-ws-transport v0.4.0/go.mod h1:EcIEKqf/7GDjth6ksuS/6p7R49V4CBY6/E7R/iyhYUA= -github.com/libp2p/go-ws-transport v0.5.0/go.mod h1:I2juo1dNTbl8BKSBYo98XY85kU2xds1iamArLvl8kNg= github.com/libp2p/go-ws-transport v0.6.0 h1:326XBL6Q+5CQ2KtjXz32+eGu02W/Kz2+Fm4SpXdr0q4= github.com/libp2p/go-ws-transport v0.6.0/go.mod h1:dXqtI9e2JV9FtF1NOtWVZSKXh5zXvnuwPXfj8GPBbYU= github.com/libp2p/go-yamux v1.2.2/go.mod h1:FGTiPvoV/3DVdgWpX+tM0OW3tsM+W5bSE3gZwqQTcow= @@ -841,7 +826,6 @@ github.com/libp2p/go-yamux v1.4.0/go.mod h1:fr7aVgmdNGJK+N1g+b6DW6VxzbRCjCOejR/h github.com/libp2p/go-yamux v1.4.1 h1:P1Fe9vF4th5JOxxgQvfbOHkrGqIZniTLf+ddhZp8YTI= github.com/libp2p/go-yamux v1.4.1/go.mod h1:fr7aVgmdNGJK+N1g+b6DW6VxzbRCjCOejR/hkmpooHE= github.com/libp2p/go-yamux/v2 v2.2.0/go.mod h1:3So6P6TV6r75R9jiBpiIKgU/66lOarCZjqROGxzPpPQ= -github.com/libp2p/go-yamux/v2 v2.3.0/go.mod h1:iTU+lOIn/2h0AgKcL49clNTwfEw+WSfDYrXe05EyKIs= github.com/libp2p/go-yamux/v3 v3.0.1/go.mod h1:s2LsDhHbh+RfCsQoICSYt58U2f8ijtPANFD8BmE74Bo= github.com/libp2p/go-yamux/v3 v3.0.2/go.mod h1:s2LsDhHbh+RfCsQoICSYt58U2f8ijtPANFD8BmE74Bo= github.com/libp2p/go-yamux/v3 v3.1.1/go.mod h1:jeLEQgLXqE2YqX1ilAClIfCMDY+0uXQUKmmb/qp0gT4= @@ -852,7 +836,6 @@ 
github.com/lightstep/lightstep-tracer-common/golang/gogo v0.0.0-20190605223551-b github.com/lightstep/lightstep-tracer-go v0.18.1/go.mod h1:jlF1pusYV4pidLvZ+XD0UBX0ZE6WURAspgAczcDHrL4= github.com/lucas-clemente/quic-go v0.19.3/go.mod h1:ADXpNbTQjq1hIzCpB+y/k5iz4n4z4IwqoLb94Kh5Hu8= github.com/lucas-clemente/quic-go v0.23.0/go.mod h1:paZuzjXCE5mj6sikVLMvqXk8lJV2AsqtJ6bDhjEfxx0= -github.com/lucas-clemente/quic-go v0.24.0/go.mod h1:paZuzjXCE5mj6sikVLMvqXk8lJV2AsqtJ6bDhjEfxx0= github.com/lucas-clemente/quic-go v0.25.0/go.mod h1:YtzP8bxRVCBlO77yRanE264+fY/T2U9ZlW1AaHOsMOg= github.com/lucas-clemente/quic-go v0.27.0/go.mod h1:AzgQoPda7N+3IqMMMkywBKggIFo2KT6pfnlrQ2QieeI= github.com/lucas-clemente/quic-go v0.27.1 h1:sOw+4kFSVrdWOYmUjufQ9GBVPqZ+tu+jMtXxXNmRJyk= @@ -1094,7 +1077,6 @@ github.com/prometheus/common v0.10.0/go.mod h1:Tlit/dnDKsSWFlCLTWaA1cyBgKHSMdTB8 github.com/prometheus/common v0.15.0/go.mod h1:U+gB1OBLb1lF3O42bTCL+FK18tX9Oar16Clt/msog/s= github.com/prometheus/common v0.18.0/go.mod h1:U+gB1OBLb1lF3O42bTCL+FK18tX9Oar16Clt/msog/s= github.com/prometheus/common v0.26.0/go.mod h1:M7rCNAaPfAosfx8veZJCuw84e35h3Cfd9VFqTh1DIvc= -github.com/prometheus/common v0.30.0/go.mod h1:vu+V0TpY+O6vW9J44gczi3Ap/oXXR10b+M/gUGO4Hls= github.com/prometheus/common v0.32.1/go.mod h1:vu+V0TpY+O6vW9J44gczi3Ap/oXXR10b+M/gUGO4Hls= github.com/prometheus/common v0.33.0 h1:rHgav/0a6+uYgGdNt3jwz8FNSesO/Hsang3O0T9A5SE= github.com/prometheus/common v0.33.0/go.mod h1:gB3sOl7P0TvJabZpLY5uQMpUqRCPPCyRLCZYc7JZTNE= @@ -1295,7 +1277,6 @@ go.uber.org/zap v1.13.0/go.mod h1:zwrFLgMcdUuIBviXEYEH1YKNaOBnKXsx2IPda5bBwHM= go.uber.org/zap v1.14.1/go.mod h1:Mb2vm2krFEG5DV0W9qcHBYFtp/Wku1cvYaqPsS/WYfc= go.uber.org/zap v1.15.0/go.mod h1:Mb2vm2krFEG5DV0W9qcHBYFtp/Wku1cvYaqPsS/WYfc= go.uber.org/zap v1.16.0/go.mod h1:MA8QOfq0BHJwdXa996Y4dYkAqRKB8/1K1QMMZVaNZjQ= -go.uber.org/zap v1.19.0/go.mod h1:xg/QME4nWcxGxrpdeYfq7UvYrLh66cuVKdrbD1XF/NI= go.uber.org/zap v1.19.1/go.mod h1:j3DNczoxDZroyBnOT1L/Q79cfUMGZxlv/9dzN7SM1rI= go.uber.org/zap v1.21.0 h1:WefMeulhovoZ2sYXz7st6K0sLj7bBhpiFaud4r4zST8= go.uber.org/zap v1.21.0/go.mod h1:wjWOCqI0f2ZZrJF/UufIOkiC8ii6tm1iqIsLo76RfJw= @@ -1331,7 +1312,6 @@ golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm golang.org/x/crypto v0.0.0-20210506145944-38f3c27a63bf/go.mod h1:P+XmwS30IXTQdn5tA2iutPOUgjI07+tq3H3K9MVA1s8= golang.org/x/crypto v0.0.0-20210513164829-c07d793c2f9a/go.mod h1:P+XmwS30IXTQdn5tA2iutPOUgjI07+tq3H3K9MVA1s8= golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= -golang.org/x/crypto v0.0.0-20210813211128-0a44fdfbc16e/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20220411220226-7b82a4e95df4 h1:kUhD7nTDoI3fVd9G4ORWrbV5NY0liEs/Jg2pv5f+bBA= golang.org/x/crypto v0.0.0-20220411220226-7b82a4e95df4/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= @@ -1427,7 +1407,6 @@ golang.org/x/net v0.0.0-20210423184538-5f58ad60dda6/go.mod h1:OJAsFXCWl8Ukc7SiCT golang.org/x/net v0.0.0-20210428140749-89ef3d95e781/go.mod h1:OJAsFXCWl8Ukc7SiCT/9KSuxbyM7479/AVlXFRxuMCk= golang.org/x/net v0.0.0-20210525063256-abc453219eb5/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20210726213435-c6fcb2dbf985/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= -golang.org/x/net v0.0.0-20210813160813-60bc85c4be6d/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= 
golang.org/x/net v0.0.0-20211015210444-4f30a5c0130f/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= @@ -1535,7 +1514,6 @@ golang.org/x/sys v0.0.0-20210514084401-e8d321eab015/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20210603081109-ebe580a85c40/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20210816183151-1e6c022a8912/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20211019181941-9d821ace8654/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220114195835-da31bd327af9/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -1709,7 +1687,6 @@ google.golang.org/grpc v1.31.1/go.mod h1:N36X2cJ7JwdamYAgDz+s+rVMFjt3numwzf/HckM google.golang.org/grpc v1.33.1/go.mod h1:fr5YgcSWrqhRRxogOsw7RzIpsmvOZ6IcH4kBYTpR3n0= google.golang.org/grpc v1.33.2/go.mod h1:JMHMWHQWaTccqQQlmk3MJZS+GWXOdAesneDmEnv2fbc= google.golang.org/grpc v1.36.0/go.mod h1:qjiiYl8FncCW8feJPdyg3v6XW24KsRHe+dy9BAGRRjU= -google.golang.org/grpc v1.40.0/go.mod h1:ogyxbiOoUXAkP+4+xa6PZSE9DZgIHtSpzjDTB9KAK34= google.golang.org/grpc v1.45.0 h1:NEpgUqV3Z+ZjkqMsxMg11IaDrXY4RY6CQukSGK0uI1M= google.golang.org/grpc v1.45.0/go.mod h1:lN7owxKUQEqMfSyQikvvk5tf/6zMPsrK+ONuO11+0rQ= google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= diff --git a/impl/events.go b/impl/events.go index e81e0fee..af2085f1 100644 --- a/impl/events.go +++ b/impl/events.go @@ -2,13 +2,9 @@ package impl import ( "context" + "errors" "fmt" - "github.com/ipld/go-ipld-prime" - cidlink "github.com/ipld/go-ipld-prime/linking/cid" - "go.opentelemetry.io/otel" - "go.opentelemetry.io/otel/attribute" - "go.opentelemetry.io/otel/trace" "golang.org/x/xerrors" datatransfer "github.com/filecoin-project/go-data-transfer/v2" @@ -16,130 +12,62 @@ import ( "github.com/filecoin-project/go-data-transfer/v2/message" ) -// OnChannelOpened is called when we send a request for data to the other -// peer on the given channel ID -func (m *manager) OnChannelOpened(chid datatransfer.ChannelID) error { - log.Infof("channel %s: opened", chid) - - // Check if the channel is being tracked - has, err := m.channels.HasChannel(chid) +// OnTransportEvent is dispatched when an event occurs on the transport +func (m *manager) OnTransportEvent(chid datatransfer.ChannelID, evt datatransfer.TransportEvent) { + ctx, _ := m.spansIndex.SpanForChannel(context.TODO(), chid) + err := m.processTransferEvent(ctx, chid, evt) if err != nil { - return err - } - if !has { - return datatransfer.ErrChannelNotFound + log.Infof("error on channel: %s, closing channel", err) + err := m.closeChannelWithError(ctx, chid, err) + if err != nil { + log.Errorf("error closing channel: %s", err) + } } - - // Fire an event - return m.channels.ChannelOpened(chid) } -// OnDataReceived is called when the transport layer reports that it has -// received some data from the sender. 
-// It fires an event on the channel, updating the sum of received data and reports -// back a pause to the transport if the data limit is exceeded -func (m *manager) OnDataReceived(chid datatransfer.ChannelID, link ipld.Link, size uint64, index int64, unique bool) error { - ctx, _ := m.spansIndex.SpanForChannel(context.TODO(), chid) - _, span := otel.Tracer("data-transfer").Start(ctx, "dataReceived", trace.WithAttributes( - attribute.String("channelID", chid.String()), - attribute.String("link", link.String()), - attribute.Int64("index", index), - attribute.Int64("size", int64(size)), - )) - defer span.End() - - err := m.channels.DataReceived(chid, link.(cidlink.Link).Cid, size, index, unique) - // if this channel is now paused, send the pause message - if err == datatransfer.ErrPause { - msg := message.UpdateResponse(chid.ID, true) - ctx, _ := m.spansIndex.SpanForChannel(context.TODO(), chid) - if err := m.transport.SendMessage(ctx, chid, msg); err != nil { +func (m *manager) processTransferEvent(ctx context.Context, chid datatransfer.ChannelID, transportEvent datatransfer.TransportEvent) error { + switch evt := transportEvent.(type) { + case datatransfer.TransportOpenedChannel: + return m.channels.ChannelOpened(chid) + case datatransfer.TransportInitiatedTransfer: + return m.channels.TransferInitiated(chid) + case datatransfer.TransportReceivedData: + return m.channels.DataReceived(chid, evt.Size, evt.Index) + case datatransfer.TransportSentData: + return m.channels.DataSent(chid, evt.Size, evt.Index) + case datatransfer.TransportQueuedData: + return m.channels.DataQueued(chid, evt.Size, evt.Index) + case datatransfer.TransportReachedDataLimit: + if err := m.channels.DataLimitExceeded(chid); err != nil { return err } + msg := message.UpdateResponse(chid.ID, true) + return m.transport.SendMessage(ctx, chid, msg) + case datatransfer.TransportTransferCancelled: + log.Warnf("channel %+v was cancelled: %s", chid, evt.ErrorMessage) + return m.channels.RequestCancelled(chid, errors.New(evt.ErrorMessage)) + + case datatransfer.TransportErrorSendingData: + log.Debugf("channel %+v had transport send error: %s", chid, evt.ErrorMessage) + return m.channels.SendDataError(chid, errors.New(evt.ErrorMessage)) + case datatransfer.TransportErrorReceivingData: + log.Debugf("channel %+v had transport receive error: %s", chid, evt.ErrorMessage) + return m.channels.ReceiveDataError(chid, errors.New(evt.ErrorMessage)) + case datatransfer.TransportCompletedTransfer: + return m.channelCompleted(chid, evt.Success, evt.ErrorMessage) + case datatransfer.TransportReceivedRestartExistingChannelRequest: + return m.restartExistingChannelRequestReceived(chid) + case datatransfer.TransportErrorSendingMessage: + return m.channels.SendMessageError(chid, errors.New(evt.ErrorMessage)) + case datatransfer.TransportPaused: + return m.pause(chid) + case datatransfer.TransportResumed: + return m.resume(chid) } - - return err -} - -// OnDataQueued is called when the transport layer reports that it has queued -// up some data to be sent to the requester. -// It fires an event on the channel, updating the sum of queued data and reports -// back a pause to the transport if the data limit is exceeded -func (m *manager) OnDataQueued(chid datatransfer.ChannelID, link ipld.Link, size uint64, index int64, unique bool) (datatransfer.Message, error) { - // The transport layer reports that some data has been queued up to be sent - // to the requester, so fire a DataQueued event on the channels state - // machine. 
- - ctx, _ := m.spansIndex.SpanForChannel(context.TODO(), chid) - _, span := otel.Tracer("data-transfer").Start(ctx, "dataQueued", trace.WithAttributes( - attribute.String("channelID", chid.String()), - attribute.String("link", link.String()), - attribute.Int64("size", int64(size)), - )) - defer span.End() - - var msg datatransfer.Message - err := m.channels.DataQueued(chid, link.(cidlink.Link).Cid, size, index, unique) - // if this channel is now paused, send the pause message - if err == datatransfer.ErrPause { - msg = message.UpdateResponse(chid.ID, true) - } - - return msg, err -} - -// OnDataSent is called when the transport layer reports that it has finished -// sending data to the requester. -func (m *manager) OnDataSent(chid datatransfer.ChannelID, link ipld.Link, size uint64, index int64, unique bool) error { - - ctx, _ := m.spansIndex.SpanForChannel(context.TODO(), chid) - _, span := otel.Tracer("data-transfer").Start(ctx, "dataSent", trace.WithAttributes( - attribute.String("channelID", chid.String()), - attribute.String("link", link.String()), - attribute.Int64("size", int64(size)), - )) - defer span.End() - - return m.channels.DataSent(chid, link.(cidlink.Link).Cid, size, index, unique) -} - -// OnRequestReceived is called when a Request message is received from the initiator -// on the responder -func (m *manager) OnRequestReceived(chid datatransfer.ChannelID, request datatransfer.Request) (datatransfer.Response, error) { - - // if request is restart request, process as restart - if request.IsRestart() { - return m.receiveRestartRequest(chid, request) - } - - // if request is new request, process as new - if request.IsNew() { - return m.receiveNewRequest(chid, request) - } - - // if request is cancel request, process as cancel - if request.IsCancel() { - log.Infof("channel %s: received cancel request, cleaning up channel", chid) - - m.transport.CleanupChannel(chid) - return nil, m.channels.Cancel(chid) - } - - // if request contains a new voucher, process updated voucher - if request.IsVoucher() { - return m.processUpdateVoucher(chid, request) - } - - // otherwise process as an "update" message (i.e. 
a pause or resume) - return m.receiveUpdateRequest(chid, request) } -// OnTransferQueued is called when the transport layer receives a request but has not yet processed it -func (m *manager) OnTransferQueued(chid datatransfer.ChannelID) { - m.channels.TransferRequestQueued(chid) + return nil } -// OnRequestReceived is called when a Response message is received from the responder +// OnResponseReceived is called when a Response message is received from the responder // on the initiator func (m *manager) OnResponseReceived(chid datatransfer.ChannelID, response datatransfer.Response) error { @@ -217,34 +145,19 @@ func (m *manager) OnResponseReceived(chid datatransfer.ChannelID, response datat return m.resumeOther(chid) } -// OnRequestCancelled is called when a transport reports a channel is cancelled -func (m *manager) OnRequestCancelled(chid datatransfer.ChannelID, err error) error { - log.Warnf("channel %+v was cancelled: %s", chid, err) - return m.channels.RequestCancelled(chid, err) -} - -// OnRequestCancelled is called when a transport reports a channel disconnected -func (m *manager) OnRequestDisconnected(chid datatransfer.ChannelID, err error) error { - log.Warnf("channel %+v has stalled or disconnected: %s", chid, err) - return m.channels.Disconnected(chid, err) -} - -// OnSendDataError is called when a transport has a network error sending data -func (m *manager) OnSendDataError(chid datatransfer.ChannelID, err error) error { - log.Debugf("channel %+v had transport send error: %s", chid, err) - return m.channels.SendDataError(chid, err) -} - -// OnReceiveDataError is called when a transport has a network error receiving data -func (m *manager) OnReceiveDataError(chid datatransfer.ChannelID, err error) error { - log.Debugf("channel %+v had transport receive error: %s", chid, err) - return m.channels.ReceiveDataError(chid, err) +// OnContextAugment provides an opportunity for transports to have data transfer add data to their context (e.g. 
+// to tie into tracing) +func (m *manager) OnContextAugment(chid datatransfer.ChannelID) func(context.Context) context.Context { + return func(ctx context.Context) context.Context { + ctx, _ = m.spansIndex.SpanForChannel(ctx, chid) + return ctx + } } -// OnChannelCompleted is called +// channelCompleted is called // - by the requester when all data for a transfer has been received // - by the responder when all data for a transfer has been sent -func (m *manager) OnChannelCompleted(chid datatransfer.ChannelID, completeErr error) error { +func (m *manager) channelCompleted(chid datatransfer.ChannelID, success bool, errorMessage string) error { // read the channel state chst, err := m.channels.GetByID(context.TODO(), chid) @@ -253,10 +166,10 @@ func (m *manager) OnChannelCompleted(chid datatransfer.ChannelID, completeErr er } // If the transfer errored on completion - if completeErr != nil { - // send an error, but only if we haven't already errored for some reason - if chst.Status() != datatransfer.Failing && chst.Status() != datatransfer.Failed { - err := xerrors.Errorf("data transfer channel %s failed to transfer data: %w", chid, completeErr) + if !success { + // send an error, but only if we haven't already errored or finished the transfer for some reason + if !chst.Status().TransferComplete() { + err := fmt.Errorf("data transfer channel %s failed to transfer data: %s", chid, errorMessage) log.Warnf(err.Error()) return m.channels.Error(chid, err) } @@ -273,16 +186,13 @@ func (m *manager) OnChannelCompleted(chid datatransfer.ChannelID, completeErr er log.Infow("received OnChannelCompleted, will send completion message to initiator", "chid", chid) // generate and send the final status message - msg, err := message.CompleteResponse(chst.TransferID(), true, chst.RequiresFinalization(), nil) - if err != nil { - return err - } + msg := message.CompleteResponse(chst.TransferID(), true, chst.RequiresFinalization(), nil) log.Infow("sending completion message to initiator", "chid", chid) ctx, _ := m.spansIndex.SpanForChannel(context.Background(), chid) if err := m.transport.SendMessage(ctx, chid, msg); err != nil { err := xerrors.Errorf("channel %s: failed to send completion message to initiator: %w", chid, err) log.Warnw("failed to send completion message to initiator", "chid", chid, "err", err) - return m.OnRequestDisconnected(chid, err) + return m.channels.SendMessageError(chid, err) } log.Infow("successfully sent completion message to initiator", "chid", chid) @@ -293,16 +203,7 @@ func (m *manager) OnChannelCompleted(chid datatransfer.ChannelID, completeErr er return m.channels.Complete(chid) }
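The block of per-callback handlers removed above (OnRequestCancelled, OnRequestDisconnected, OnSendDataError, OnReceiveDataError, and channel completion) collapses into the single OnTransportEvent entry point introduced at the top of this file's diff, with processTransferEvent mapping each datatransfer.TransportEvent variant onto a channel FSM event. The sketch below is not part of the patch: the local interface is a stand-in for whatever handler type a transport actually holds, and only the OnTransportEvent signature and the Success/ErrorMessage fields read in processTransferEvent are taken from the diff.

package example // hypothetical package for the sketch

import (
	datatransfer "github.com/filecoin-project/go-data-transfer/v2"
)

// transportEventHandler is the slice of the manager's surface a transport
// needs here; in this patch the manager implements OnTransportEvent directly.
type transportEventHandler interface {
	OnTransportEvent(chid datatransfer.ChannelID, evt datatransfer.TransportEvent)
}

// reportCompletion shows how a transport would signal the end of a transfer
// under the new API: one event, no return value, with any failure carried as
// a message string rather than an error value.
func reportCompletion(h transportEventHandler, chid datatransfer.ChannelID, err error) {
	if err != nil {
		h.OnTransportEvent(chid, datatransfer.TransportCompletedTransfer{Success: false, ErrorMessage: err.Error()})
		return
	}
	h.OnTransportEvent(chid, datatransfer.TransportCompletedTransfer{Success: true})
}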
-// OnContextAugment provides an oppurtunity for transports to have data transfer add data to their context (i.e. -// to tie into tracing, etc) -func (m *manager) OnContextAugment(chid datatransfer.ChannelID) func(context.Context) context.Context { - return func(ctx context.Context) context.Context { - ctx, _ = m.spansIndex.SpanForChannel(ctx, chid) - return ctx - } -} - -func (m *manager) OnRestartExistingChannelRequestReceived(chid datatransfer.ChannelID) error { +func (m *manager) restartExistingChannelRequestReceived(chid datatransfer.ChannelID) error { ctx, _ := m.spansIndex.SpanForChannel(context.TODO(), chid) // validate channel exists -> in non-terminal state and that the sender matches channel, err := m.channels.GetByID(context.TODO(), chid) @@ -319,5 +220,6 @@ func (m *manager) OnRestartExistingChannelRequestReceived(chid datatransfer.Chan if err := m.openRestartChannel(ctx, channel); err != nil { return fmt.Errorf("failed to open restart channel %s: %s", chid, err) } + return nil } diff --git a/impl/impl.go b/impl/impl.go index 16a6f462..156869b6 100644 --- a/impl/impl.go +++ b/impl/impl.go @@ -252,15 +252,10 @@ func (m *manager) SendVoucher(ctx context.Context, channelID datatransfer.Channe span.SetStatus(codes.Error, err.Error()) return err } - updateRequest, err := message.VoucherRequest(channelID.ID, &voucher) - if err != nil { - span.RecordError(err) - span.SetStatus(codes.Error, err.Error()) - return err - } + updateRequest := message.VoucherRequest(channelID.ID, &voucher) if err := m.transport.SendMessage(ctx, channelID, updateRequest); err != nil { err = fmt.Errorf("Unable to send request: %w", err) - _ = m.OnRequestDisconnected(channelID, err) + _ = m.channels.SendMessageError(channelID, err) span.RecordError(err) span.SetStatus(codes.Error, err.Error()) return err @@ -287,19 +282,14 @@ func (m *manager) SendVoucherResult(ctx context.Context, channelID datatransfer. var updateResponse datatransfer.Response if chst.Status().InFinalization() { - updateResponse, err = message.CompleteResponse(channelID.ID, chst.Status().IsAccepted(), chst.Status().IsResponderPaused(), &voucherResult) + updateResponse = message.CompleteResponse(channelID.ID, chst.Status().IsAccepted(), chst.ResponderPaused(), &voucherResult) } else { - updateResponse, err = message.VoucherResultResponse(channelID.ID, chst.Status().IsAccepted(), chst.Status().IsResponderPaused(), &voucherResult) + updateResponse = message.VoucherResultResponse(channelID.ID, chst.Status().IsAccepted(), chst.ResponderPaused(), &voucherResult) } - if err != nil { - span.RecordError(err) - span.SetStatus(codes.Error, err.Error()) - return err - } if err := m.transport.SendMessage(ctx, channelID, updateResponse); err != nil { err = fmt.Errorf("Unable to send request: %w", err) - _ = m.OnRequestDisconnected(channelID, err) + _ = m.channels.SendMessageError(channelID, err) span.RecordError(err) span.SetStatus(codes.Error, err.Error()) return err
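One pattern worth noting across the two hunks above: message.VoucherRequest, message.CompleteResponse, and message.VoucherResultResponse now return the message directly rather than a (message, error) pair, leaving the network send as the only failure point at each call site. The hypothetical helper below (not part of the patch) illustrates the slimmer shape; it assumes a datatransfer.Transport value exposing the SendMessage call used throughout this diff.

package example // hypothetical package for the sketch

import (
	"context"

	datatransfer "github.com/filecoin-project/go-data-transfer/v2"
	"github.com/filecoin-project/go-data-transfer/v2/message"
)

// sendVoucher forwards a voucher update on an open channel. Message
// construction can no longer fail, so the send is the sole error path.
func sendVoucher(ctx context.Context, tp datatransfer.Transport, chid datatransfer.ChannelID, voucher datatransfer.TypedVoucher) error {
	updateRequest := message.VoucherRequest(chid.ID, &voucher)
	return tp.SendMessage(ctx, chid, updateRequest)
}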
return err } - // dispatch channel events and generate a response message - chst, response, err := m.processValidationUpdate(ctx, chid, result) - - // dispatch transport updates - return m.transport.UpdateChannel(ctx, chid, datatransfer.ChannelUpdate{ - Paused: result.LeaveRequestPaused(chst), - Closed: err != nil || !result.Accepted, - SendMessage: response, - }) -} - -func (m *manager) processValidationUpdate(ctx context.Context, chid datatransfer.ChannelID, result datatransfer.ValidationResult) (datatransfer.ChannelState, datatransfer.Response, error) { - // read the channel state chst, err := m.channels.GetByID(context.TODO(), chid) if err != nil { - return nil, nil, err + return err } - // if the request is now rejected, error the channel - if !result.Accepted { - err = m.recordRejectedValidationEvents(chid, result) - } else { - err = m.recordAcceptedValidationEvents(chst, result) - } + // dispatch channel events and generate a response message + err = m.processValidationUpdate(ctx, chst, result) if err != nil { - return nil, nil, err + return err } // generate a response message @@ -365,100 +338,105 @@ func (m *manager) processValidationUpdate(ctx context.Context, chid datatransfer if chst.Status() == datatransfer.Finalizing { messageType = types.CompleteMessage } - response, msgErr := message.ValidationResultResponse(messageType, chst.TransferID(), result, err, + response := message.ValidationResultResponse(messageType, chid.ID, result, err, result.LeaveRequestPaused(chst)) - if msgErr != nil { - return nil, nil, msgErr + + // dispatch transport updates + return m.transport.ChannelUpdated(ctx, chid, response) +} + +func (m *manager) processValidationUpdate(ctx context.Context, chst datatransfer.ChannelState, result datatransfer.ValidationResult) error { + // if the request is now rejected, error the channel + if !result.Accepted { + return m.recordRejectedValidationEvents(chst.ChannelID(), result) } + return m.recordAcceptedValidationEvents(chst, result) - // return the response message and any errors - return chst, response, nil } // close an open channel (effectively a cancel) func (m *manager) CloseDataTransferChannel(ctx context.Context, chid datatransfer.ChannelID) error { log.Infof("close channel %s", chid) - chst, err := m.channels.GetByID(ctx, chid) - if err != nil { - return err - } ctx, _ = m.spansIndex.SpanForChannel(ctx, chid) ctx, span := otel.Tracer("data-transfer").Start(ctx, "closeChannel", trace.WithAttributes( attribute.String("channelID", chid.String()), )) defer span.End() - // Close the channel on the local transport - err = m.transport.UpdateChannel(ctx, chid, datatransfer.ChannelUpdate{ - Paused: chst.Status().IsResponderPaused(), - Closed: true, - }) + + err := m.closeChannel(ctx, chid) if err != nil { span.RecordError(err) span.SetStatus(codes.Error, err.Error()) log.Warnf("unable to close channel %s: %s", chid, err) } + return err +} + +func (m *manager) closeChannel(ctx context.Context, chid datatransfer.ChannelID) error { + // Fire a cancel event + err := m.channels.Cancel(chid) + if err != nil { + return xerrors.Errorf("unable to send cancel to channel FSM: %w", err) + } + + // Close the channel on the local transport + err = m.transport.ChannelUpdated(ctx, chid, nil) // Send a cancel message to the remote peer async go func() { sctx, cancel := context.WithTimeout(context.Background(), cancelSendTimeout) defer cancel() - log.Infof("%s: sending cancel channel to %s for channel %s", m.peerID, chst.OtherPeer(), chid) + log.Infof("%s: sending cancel channel 
to %s for channel %s", m.peerID, m.otherPeer(chid), chid) err = m.transport.SendMessage(sctx, chid, m.cancelMessage(chid)) if err != nil { err = fmt.Errorf("unable to send cancel message for channel %s to peer %s: %w", chid, m.peerID, err) - _ = m.OnRequestDisconnected(chid, err) log.Warn(err) } }() - // Fire a cancel event - fsmerr := m.channels.Cancel(chid) - if fsmerr != nil { - return xerrors.Errorf("unable to send cancel to channel FSM: %w", fsmerr) - } - - return nil + return err } // close an open channel and fire an error event func (m *manager) CloseDataTransferChannelWithError(ctx context.Context, chid datatransfer.ChannelID, cherr error) error { log.Infof("close channel %s with error %s", chid, cherr) - chst, err := m.channels.GetByID(ctx, chid) - if err != nil { - return err - } ctx, _ = m.spansIndex.SpanForChannel(ctx, chid) ctx, span := otel.Tracer("data-transfer").Start(ctx, "closeChannel", trace.WithAttributes( attribute.String("channelID", chid.String()), )) defer span.End() + err := m.closeChannelWithError(ctx, chid, cherr) + if err != nil { + span.RecordError(err) + span.SetStatus(codes.Error, err.Error()) + log.Warnf("unable to close channel %s: %s", chid, err) + } + return err +} + +func (m *manager) closeChannelWithError(ctx context.Context, chid datatransfer.ChannelID, cherr error) error { + + // Fire an error event + if err := m.channels.Error(chid, cherr); err != nil { + return xerrors.Errorf("unable to send error %s to channel FSM: %w", cherr, err) + } + // Close transfport and try to send a cancel message to the remote peer. // It's quite likely we aren't able to send the message to the peer because // the channel is already in an error state, which is probably because of // connection issues, so if we cant send the message just log a warning. - log.Infof("%s: sending cancel channel to %s for channel %s", m.peerID, chst.OtherPeer(), chid) - err = m.transport.UpdateChannel(ctx, chid, datatransfer.ChannelUpdate{ - Paused: chst.Status().IsResponderPaused(), - Closed: true, - SendMessage: m.cancelMessage(chid), - }) - if err != nil { + log.Infof("%s: sending cancel channel to %s for channel %s", m.peerID, m.otherPeer(chid), chid) + + if err := m.transport.ChannelUpdated(ctx, chid, m.cancelMessage(chid)); err != nil { // Just log a warning here because it's important that we fire the // error event with the original error so that it doesn't get masked // by subsequent errors. 
log.Warnf("unable to close channel %s: %s", chid, err) } - - // Fire an error event - err = m.channels.Error(chid, cherr) - if err != nil { - return xerrors.Errorf("unable to send error %s to channel FSM: %w", cherr, err) - } - return nil } @@ -472,15 +450,16 @@ func (m *manager) PauseDataTransferChannel(ctx context.Context, chid datatransfe ctx, _ = m.spansIndex.SpanForChannel(ctx, chid) - err := m.transport.UpdateChannel(ctx, chid, datatransfer.ChannelUpdate{ - Paused: true, - SendMessage: m.pauseMessage(chid), - }) - if err != nil { - log.Warnf("Error attempting to pause at transport level: %s", err.Error()) + // fire the pause + if err := m.pause(chid); err != nil { + return err } - return m.pause(chid) + // update transport + if err := m.transport.ChannelUpdated(ctx, chid, m.pauseMessage(chid)); err != nil { + log.Warnf("Error attempting to pause at transport level: %s", err.Error()) + } + return nil } // resume a running data transfer channel @@ -493,15 +472,16 @@ func (m *manager) ResumeDataTransferChannel(ctx context.Context, chid datatransf ctx, _ = m.spansIndex.SpanForChannel(ctx, chid) - err := m.transport.UpdateChannel(ctx, chid, datatransfer.ChannelUpdate{ - Paused: false, - SendMessage: m.resumeMessage(chid), - }) - if err != nil { - log.Warnf("Error attempting to resume at transport level: %s", err.Error()) + // fire the resume + if err := m.resume(chid); err != nil { + return err } - return m.resume(chid) + // update transport + if err := m.transport.ChannelUpdated(ctx, chid, m.resumeMessage(chid)); err != nil { + log.Warnf("Error attempting to resume at transport level: %s", err.Error()) + } + return nil } // get channel state diff --git a/impl/initiating_test.go b/impl/initiating_test.go index f6c4f0ec..a8da316e 100644 --- a/impl/initiating_test.go +++ b/impl/initiating_test.go @@ -10,7 +10,7 @@ import ( "github.com/ipfs/go-datastore" dss "github.com/ipfs/go-datastore/sync" "github.com/ipld/go-ipld-prime/datamodel" - cidlink "github.com/ipld/go-ipld-prime/linking/cid" + "github.com/ipld/go-ipld-prime/node/basicnode" selectorparse "github.com/ipld/go-ipld-prime/traversal/selector/parse" "github.com/libp2p/go-libp2p-core/peer" "github.com/stretchr/testify/require" @@ -139,8 +139,7 @@ func TestDataTransferInitiating(t *testing.T) { channelID, err := h.dt.OpenPushDataChannel(h.ctx, h.peers[1], h.voucher, h.baseCid, h.stor) require.NoError(t, err) require.NotEmpty(t, channelID) - response, err := message.NewResponse(channelID.ID, true, false, nil) - require.NoError(t, err) + response := message.NewResponse(channelID.ID, true, false, nil) err = h.transport.EventHandler.OnResponseReceived(channelID, response) require.NoError(t, err) }, @@ -151,8 +150,7 @@ func TestDataTransferInitiating(t *testing.T) { channelID, err := h.dt.OpenPushDataChannel(h.ctx, h.peers[1], h.voucher, h.baseCid, h.stor) require.NoError(t, err) require.NotEmpty(t, channelID) - response, err := message.NewResponse(channelID.ID, true, false, &h.voucherResult) - require.NoError(t, err) + response := message.NewResponse(channelID.ID, true, false, &h.voucherResult) err = h.transport.EventHandler.OnResponseReceived(channelID, response) require.NoError(t, err) }, @@ -163,14 +161,13 @@ func TestDataTransferInitiating(t *testing.T) { channelID, err := h.dt.OpenPushDataChannel(h.ctx, h.peers[1], h.voucher, h.baseCid, h.stor) require.NoError(t, err) require.NotEmpty(t, channelID) - response, err := message.NewResponse(channelID.ID, true, false, nil) - require.NoError(t, err) + response := 
message.NewResponse(channelID.ID, true, false, nil) err = h.transport.EventHandler.OnResponseReceived(channelID, response) require.NoError(t, err) err = h.dt.PauseDataTransferChannel(h.ctx, channelID) require.NoError(t, err) - require.Len(t, h.transport.PausedChannels, 1) - require.Equal(t, h.transport.PausedChannels[0], channelID) + require.Len(t, h.transport.ChannelsUpdated, 1) + require.Equal(t, h.transport.ChannelsUpdated[0], channelID) require.Len(t, h.transport.OpenedChannels, 1) require.Len(t, h.transport.MessagesSent, 1) pauseMessage := h.transport.MessagesSent[0].Message @@ -181,8 +178,8 @@ func TestDataTransferInitiating(t *testing.T) { require.Equal(t, pauseMessage.TransferID(), channelID.ID) err = h.dt.ResumeDataTransferChannel(h.ctx, channelID) require.NoError(t, err) - require.Len(t, h.transport.ResumedChannels, 1) - resumedChannel := h.transport.ResumedChannels[0] + require.Len(t, h.transport.ChannelsUpdated, 2) + resumedChannel := h.transport.ChannelsUpdated[1] require.Equal(t, resumedChannel, channelID) require.Len(t, h.transport.MessagesSent, 2) resumeMessage := h.transport.MessagesSent[1].Message @@ -201,8 +198,8 @@ func TestDataTransferInitiating(t *testing.T) { require.NotEmpty(t, channelID) err = h.dt.CloseDataTransferChannel(h.ctx, channelID) require.NoError(t, err) - require.Len(t, h.transport.ClosedChannels, 1) - require.Equal(t, h.transport.ClosedChannels[0], channelID) + require.Len(t, h.transport.ChannelsUpdated, 1) + require.Equal(t, h.transport.ChannelsUpdated[0], channelID) require.Eventually(t, func() bool { return len(h.transport.MessagesSent) == 1 @@ -221,13 +218,12 @@ func TestDataTransferInitiating(t *testing.T) { channelID, err := h.dt.OpenPullDataChannel(h.ctx, h.peers[1], h.voucher, h.baseCid, h.stor) require.NoError(t, err) require.NotEmpty(t, channelID) - response, err := message.NewResponse(channelID.ID, true, false, nil) - require.NoError(t, err) + response := message.NewResponse(channelID.ID, true, false, nil) err = h.transport.EventHandler.OnResponseReceived(channelID, response) require.NoError(t, err) err = h.dt.PauseDataTransferChannel(h.ctx, channelID) require.NoError(t, err) - require.Len(t, h.transport.PausedChannels, 1) + require.Len(t, h.transport.ChannelsUpdated, 1) require.Len(t, h.transport.OpenedChannels, 1) require.Len(t, h.transport.MessagesSent, 1) pauseMessage := h.transport.MessagesSent[0].Message @@ -238,8 +234,8 @@ func TestDataTransferInitiating(t *testing.T) { require.Equal(t, pauseMessage.TransferID(), channelID.ID) err = h.dt.ResumeDataTransferChannel(h.ctx, channelID) require.NoError(t, err) - require.Len(t, h.transport.ResumedChannels, 1) - resumedChannel := h.transport.ResumedChannels[0] + require.Len(t, h.transport.ChannelsUpdated, 2) + resumedChannel := h.transport.ChannelsUpdated[1] require.Equal(t, resumedChannel, channelID) require.Len(t, h.transport.MessagesSent, 2) resumeMessage := h.transport.MessagesSent[1].Message @@ -258,8 +254,8 @@ func TestDataTransferInitiating(t *testing.T) { require.NotEmpty(t, channelID) err = h.dt.CloseDataTransferChannel(h.ctx, channelID) require.NoError(t, err) - require.Len(t, h.transport.ClosedChannels, 1) - require.Equal(t, h.transport.ClosedChannels[0], channelID) + require.Len(t, h.transport.ChannelsUpdated, 1) + require.Equal(t, h.transport.ChannelsUpdated[0], channelID) require.Eventually(t, func() bool { return len(h.transport.MessagesSent) == 1 @@ -353,7 +349,7 @@ func TestDataTransferRestartInitiating(t *testing.T) { verify func(t *testing.T, h *harness) }{ 
"RestartDataTransferChannel: Manager Peer Create Pull Restart works": { - expectedEvents: []datatransfer.EventCode{datatransfer.Open, datatransfer.DataReceivedProgress, datatransfer.DataReceived, datatransfer.DataReceivedProgress, datatransfer.DataReceived}, + expectedEvents: []datatransfer.EventCode{datatransfer.Open, datatransfer.TransferInitiated, datatransfer.DataReceivedProgress, datatransfer.DataReceived, datatransfer.DataReceivedProgress, datatransfer.DataReceived}, verify: func(t *testing.T, h *harness) { // open a pull channel channelID, err := h.dt.OpenPullDataChannel(h.ctx, h.peers[1], h.voucher, h.baseCid, h.stor) @@ -362,11 +358,11 @@ func TestDataTransferRestartInitiating(t *testing.T) { require.Len(t, h.transport.OpenedChannels, 1) // some cids should already be received - testCids := testutil.GenerateCids(2) ev, ok := h.dt.(datatransfer.EventsHandler) require.True(t, ok) - require.NoError(t, ev.OnDataReceived(channelID, cidlink.Link{Cid: testCids[0]}, 12345, 1, true)) - require.NoError(t, ev.OnDataReceived(channelID, cidlink.Link{Cid: testCids[1]}, 12345, 2, true)) + ev.OnTransportEvent(channelID, datatransfer.TransportInitiatedTransfer{}) + ev.OnTransportEvent(channelID, datatransfer.TransportReceivedData{Size: 12345, Index: basicnode.NewInt(1)}) + ev.OnTransportEvent(channelID, datatransfer.TransportReceivedData{Size: 12345, Index: basicnode.NewInt(2)}) // restart that pull channel err = h.dt.RestartDataTransferChannel(ctx, channelID) diff --git a/impl/receiving_requests.go b/impl/receiving_requests.go index 7665a930..825ed3f2 100644 --- a/impl/receiving_requests.go +++ b/impl/receiving_requests.go @@ -13,7 +13,31 @@ import ( "github.com/filecoin-project/go-data-transfer/v2/message/types" ) -// this file contains methods for processing incoming request messages +func (m *manager) OnRequestReceived(chid datatransfer.ChannelID, request datatransfer.Request) (datatransfer.Response, error) { + + // if request is restart request, process as restart + if request.IsRestart() { + return m.receiveRestartRequest(chid, request) + } + + // if request is new request, process as new + if request.IsNew() { + return m.receiveNewRequest(chid, request) + } + + // if request is cancel request, process as cancel + if request.IsCancel() { + log.Infof("channel %s: received cancel request, cleaning up channel", chid) + + return nil, m.channels.Cancel(chid) + } + // if request contains a new voucher, process updated voucher + if request.IsVoucher() { + return m.processUpdateVoucher(chid, request) + } + // otherwise process as an "update" message (i.e. 
a pause or resume)
+	return m.receiveUpdateRequest(chid, request)
+}
 
 // receiveNewRequest handles an incoming new request message
 func (m *manager) receiveNewRequest(chid datatransfer.ChannelID, incoming datatransfer.Request) (datatransfer.Response, error) {
@@ -23,13 +47,13 @@ func (m *manager) receiveNewRequest(chid datatransfer.ChannelID, incoming datatr
 	result, err := m.acceptRequest(chid, incoming)
 
 	// generate a response message
-	msg, msgErr := message.ValidationResultResponse(types.NewMessage, incoming.TransferID(), result, err, result.ForcePause)
-	if msgErr != nil {
-		return nil, msgErr
-	}
+	msg := message.ValidationResultResponse(types.NewMessage, incoming.TransferID(), result, err, result.ForcePause)
 
-	// return the response message and any errors
-	return msg, m.requestError(result, err, result.ForcePause)
+	// return the response message, erroring if the request was rejected
+	if err == nil && !result.Accepted {
+		err = datatransfer.ErrRejected
+	}
+	return msg, err
 }
 
 // acceptRequest performs processing (including validation) on a new incoming request
@@ -126,13 +150,13 @@ func (m *manager) receiveRestartRequest(chid datatransfer.ChannelID, incoming da
 	stayPaused, result, err := m.restartRequest(chid, incoming)
 
 	// generate a response message
-	msg, msgErr := message.ValidationResultResponse(types.RestartMessage, incoming.TransferID(), result, err, stayPaused)
-	if msgErr != nil {
-		return nil, msgErr
-	}
+	msg := message.ValidationResultResponse(types.RestartMessage, incoming.TransferID(), result, err, stayPaused)
 
 	// return the response message and any errors
-	return msg, m.requestError(result, err, result.ForcePause)
+	if err == nil && !result.Accepted {
+		err = datatransfer.ErrRejected
+	}
+	return msg, err
 }
 
 // restartRequest performs processing (including validation) on an incoming restart request
@@ -198,55 +222,6 @@ func (m *manager) restartRequest(chid datatransfer.ChannelID,
 	return stayPaused, result, nil
 }
 
-// processUpdateVoucher handles an incoming request message with an updated voucher
-func (m *manager) processUpdateVoucher(chid datatransfer.ChannelID, request datatransfer.Request) (datatransfer.Response, error) {
-	// decode the voucher and save it on the channel
-	voucher, err := request.TypedVoucher()
-	if err != nil {
-		return nil, err
-	}
-	return nil, m.channels.NewVoucher(chid, voucher)
-}
-
-// receiveUpdateRequest handles an incoming request message change in pause status
-func (m *manager) receiveUpdateRequest(chid datatransfer.ChannelID, request datatransfer.Request) (datatransfer.Response, error) {
-
-	if request.IsPaused() {
-		return nil, m.pauseOther(chid)
-	}
-
-	err := m.resumeOther(chid)
-	if err != nil {
-		return nil, err
-	}
-	chst, err := m.channels.GetByID(context.TODO(), chid)
-	if err != nil {
-		return nil, err
-	}
-	if chst.Status() == datatransfer.ResponderPaused ||
-		chst.Status() == datatransfer.ResponderFinalizing {
-		return nil, datatransfer.ErrPause
-	}
-	return nil, nil
-}
-
-// requestError generates an error message for the transport, adding
-// ErrPause / ErrResume based off the validation result
-// TODO: get away from using ErrPause/ErrResume to indicate pause resume,
-// which would remove the need for most of this method
-func (m *manager) requestError(result datatransfer.ValidationResult, resultErr error, stayPaused bool) error {
-	if resultErr != nil {
-		return resultErr
-	}
-	if !result.Accepted {
-		return datatransfer.ErrRejected
-	}
-	if stayPaused {
-		return datatransfer.ErrPause
-	}
-	return nil
-}
-
 // recordRejectedValidationEvents sends changes based on
a rejected validation to the state machine
 func (m *manager) recordRejectedValidationEvents(chid datatransfer.ChannelID, result datatransfer.ValidationResult) error {
 	if result.VoucherResult != nil {
@@ -288,14 +263,14 @@ func (m *manager) recordAcceptedValidationEvents(chst datatransfer.ChannelState,
 
 	// pause or resume the request as necessary
 	if result.LeaveRequestPaused(chst) {
-		if !chst.Status().IsResponderPaused() {
+		if !chst.ResponderPaused() {
 			err := m.channels.PauseResponder(chid)
 			if err != nil {
 				return err
 			}
 		}
 	} else {
-		if chst.Status().IsResponderPaused() {
+		if chst.ResponderPaused() {
 			err := m.channels.ResumeResponder(chid)
 			if err != nil {
 				return err
@@ -306,6 +281,24 @@ func (m *manager) recordAcceptedValidationEvents(chst datatransfer.ChannelState,
 	return nil
 }
 
+// processUpdateVoucher handles an incoming request message with an updated voucher
+func (m *manager) processUpdateVoucher(chid datatransfer.ChannelID, request datatransfer.Request) (datatransfer.Response, error) {
+	// decode the voucher and save it on the channel
+	voucher, err := request.TypedVoucher()
+	if err != nil {
+		return nil, err
+	}
+	return nil, m.channels.NewVoucher(chid, voucher)
+}
+
+// receiveUpdateRequest handles an incoming request message that changes the pause status
+func (m *manager) receiveUpdateRequest(chid datatransfer.ChannelID, request datatransfer.Request) (datatransfer.Response, error) {
+	if request.IsPaused() {
+		return nil, m.pauseOther(chid)
+	}
+	return nil, m.resumeOther(chid)
+}
+
 // validateRestart looks up the appropriate validator and validates a restart
 func (m *manager) validateRestart(chst datatransfer.ChannelState) (datatransfer.ValidationResult, error) {
 	chv := chst.Voucher()
diff --git a/impl/responding_test.go b/impl/responding_test.go
index 1537a485..92abd7e1 100644
--- a/impl/responding_test.go
+++ b/impl/responding_test.go
@@ -12,12 +12,11 @@ import (
 	dss "github.com/ipfs/go-datastore/sync"
 	"github.com/ipld/go-ipld-prime"
 	"github.com/ipld/go-ipld-prime/datamodel"
-	cidlink "github.com/ipld/go-ipld-prime/linking/cid"
+	"github.com/ipld/go-ipld-prime/node/basicnode"
 	selectorparse "github.com/ipld/go-ipld-prime/traversal/selector/parse"
 	"github.com/libp2p/go-libp2p-core/peer"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
-	"golang.org/x/xerrors"
 
 	datatransfer "github.com/filecoin-project/go-data-transfer/v2"
 	.
"github.com/filecoin-project/go-data-transfer/v2/impl" @@ -107,7 +106,7 @@ func TestDataTransferResponding(t *testing.T) { }, verify: func(t *testing.T, h *receiverHarness) { response, err := h.transport.EventHandler.OnRequestReceived(channelID(h.id, h.peers), h.pushRequest) - require.EqualError(t, err, datatransfer.ErrPause.Error()) + require.NoError(t, err) require.True(t, response.Accepted()) require.Equal(t, response.TransferID(), h.id) require.False(t, response.IsUpdate()) @@ -185,7 +184,7 @@ func TestDataTransferResponding(t *testing.T) { }, verify: func(t *testing.T, h *receiverHarness) { response, err := h.transport.EventHandler.OnRequestReceived(channelID(h.id, h.peers), h.pullRequest) - require.EqualError(t, err, datatransfer.ErrPause.Error()) + require.NoError(t, err) require.True(t, response.Accepted()) require.Equal(t, response.TransferID(), h.id) require.False(t, response.IsUpdate()) @@ -302,7 +301,7 @@ func TestDataTransferResponding(t *testing.T) { err = h.dt.PauseDataTransferChannel(h.ctx, channelID(h.id, h.peers)) require.NoError(t, err) _, err = h.transport.EventHandler.OnRequestReceived(channelID(h.id, h.peers), h.resumeUpdate) - require.EqualError(t, err, datatransfer.ErrPause.Error()) + require.NoError(t, err) }, }, "receive cancel": { @@ -322,8 +321,6 @@ func TestDataTransferResponding(t *testing.T) { _, _ = h.transport.EventHandler.OnRequestReceived(channelID(h.id, h.peers), h.pushRequest) _, err := h.transport.EventHandler.OnRequestReceived(channelID(h.id, h.peers), h.cancelUpdate) require.NoError(t, err) - require.Len(t, h.transport.CleanedUpChannels, 1) - require.Equal(t, channelID(h.id, h.peers), h.transport.CleanedUpChannels[0]) }, }, "validate and revalidate successfully, push": { @@ -332,6 +329,7 @@ func TestDataTransferResponding(t *testing.T) { datatransfer.Accept, datatransfer.NewVoucherResult, datatransfer.SetDataLimit, + datatransfer.TransferInitiated, datatransfer.DataReceivedProgress, datatransfer.DataReceived, datatransfer.DataLimitExceeded, @@ -347,8 +345,9 @@ func TestDataTransferResponding(t *testing.T) { }, verify: func(t *testing.T, h *receiverHarness) { _, _ = h.transport.EventHandler.OnRequestReceived(channelID(h.id, h.peers), h.pushRequest) - err := h.transport.EventHandler.OnDataReceived(channelID(h.id, h.peers), cidlink.Link{Cid: testutil.GenerateCids(1)[0]}, 12345, 1, true) - require.EqualError(t, err, datatransfer.ErrPause.Error()) + h.transport.EventHandler.OnTransportEvent(channelID(h.id, h.peers), datatransfer.TransportInitiatedTransfer{}) + h.transport.EventHandler.OnTransportEvent(channelID(h.id, h.peers), datatransfer.TransportReceivedData{Size: 12345, Index: basicnode.NewInt(1)}) + h.transport.EventHandler.OnTransportEvent(channelID(h.id, h.peers), datatransfer.TransportReachedDataLimit{}) require.Len(t, h.transport.MessagesSent, 1) response, ok := h.transport.MessagesSent[0].Message.(datatransfer.Response) require.True(t, ok) @@ -357,14 +356,14 @@ func TestDataTransferResponding(t *testing.T) { require.False(t, response.IsCancel()) require.True(t, response.IsPaused()) require.False(t, response.IsValidationResult()) - response, err = h.transport.EventHandler.OnRequestReceived(channelID(h.id, h.peers), h.voucherUpdate) - require.NoError(t, err, nil) + response, err := h.transport.EventHandler.OnRequestReceived(channelID(h.id, h.peers), h.voucherUpdate) + require.NoError(t, err) require.Nil(t, response) vr := testutil.NewTestTypedVoucher() err = h.dt.UpdateValidationStatus(ctx, channelID(h.id, h.peers), 
datatransfer.ValidationResult{Accepted: true, DataLimit: 50000, VoucherResult: &vr}) require.NoError(t, err) - require.Len(t, h.transport.ResumedChannels, 1) - require.Equal(t, h.transport.ResumedChannels[0], channelID(h.id, h.peers)) + require.Len(t, h.transport.ChannelsUpdated, 1) + require.Equal(t, h.transport.ChannelsUpdated[0], channelID(h.id, h.peers)) require.Len(t, h.transport.MessagesSent, 2) response, ok = h.transport.MessagesSent[1].Message.(datatransfer.Response) require.True(t, ok) @@ -383,6 +382,7 @@ func TestDataTransferResponding(t *testing.T) { datatransfer.Accept, datatransfer.NewVoucherResult, datatransfer.SetDataLimit, + datatransfer.TransferInitiated, datatransfer.DataReceivedProgress, datatransfer.DataReceived, datatransfer.DataLimitExceeded, @@ -398,8 +398,9 @@ func TestDataTransferResponding(t *testing.T) { }, verify: func(t *testing.T, h *receiverHarness) { _, _ = h.transport.EventHandler.OnRequestReceived(channelID(h.id, h.peers), h.pushRequest) - err := h.transport.EventHandler.OnDataReceived(channelID(h.id, h.peers), cidlink.Link{Cid: testutil.GenerateCids(1)[0]}, 12345, 1, true) - require.EqualError(t, err, datatransfer.ErrPause.Error()) + h.transport.EventHandler.OnTransportEvent(channelID(h.id, h.peers), datatransfer.TransportInitiatedTransfer{}) + h.transport.EventHandler.OnTransportEvent(channelID(h.id, h.peers), datatransfer.TransportReceivedData{Size: 12345, Index: basicnode.NewInt(1)}) + h.transport.EventHandler.OnTransportEvent(channelID(h.id, h.peers), datatransfer.TransportReachedDataLimit{}) require.Len(t, h.transport.MessagesSent, 1) response, ok := h.transport.MessagesSent[0].Message.(datatransfer.Response) require.True(t, ok) @@ -408,14 +409,14 @@ func TestDataTransferResponding(t *testing.T) { require.False(t, response.IsCancel()) require.True(t, response.IsPaused()) require.False(t, response.IsValidationResult()) - response, err = h.transport.EventHandler.OnRequestReceived(channelID(h.id, h.peers), h.voucherUpdate) - require.NoError(t, err, nil) + response, err := h.transport.EventHandler.OnRequestReceived(channelID(h.id, h.peers), h.voucherUpdate) + require.NoError(t, err) require.Nil(t, response) vr := testutil.NewTestTypedVoucher() err = h.dt.UpdateValidationStatus(ctx, channelID(h.id, h.peers), datatransfer.ValidationResult{Accepted: false, VoucherResult: &vr}) require.NoError(t, err) - require.Len(t, h.transport.ClosedChannels, 1) - require.Equal(t, h.transport.ClosedChannels[0], channelID(h.id, h.peers)) + require.Len(t, h.transport.ChannelsUpdated, 1) + require.Equal(t, h.transport.ChannelsUpdated[0], channelID(h.id, h.peers)) require.Len(t, h.transport.MessagesSent, 2) response, ok = h.transport.MessagesSent[1].Message.(datatransfer.Response) require.True(t, ok) @@ -434,6 +435,7 @@ func TestDataTransferResponding(t *testing.T) { datatransfer.Accept, datatransfer.NewVoucherResult, datatransfer.SetDataLimit, + datatransfer.TransferInitiated, datatransfer.DataQueuedProgress, datatransfer.DataQueued, datatransfer.DataLimitExceeded, @@ -450,12 +452,11 @@ func TestDataTransferResponding(t *testing.T) { verify: func(t *testing.T, h *receiverHarness) { _, err := h.transport.EventHandler.OnRequestReceived(channelID(h.id, h.peers), h.pullRequest) require.NoError(t, err) - msg, err := h.transport.EventHandler.OnDataQueued( - channelID(h.id, h.peers), - cidlink.Link{Cid: testutil.GenerateCids(1)[0]}, - 12345, 1, true) - require.EqualError(t, err, datatransfer.ErrPause.Error()) - response, ok := msg.(datatransfer.Response) + 
h.transport.EventHandler.OnTransportEvent(channelID(h.id, h.peers), datatransfer.TransportInitiatedTransfer{}) + h.transport.EventHandler.OnTransportEvent(channelID(h.id, h.peers), datatransfer.TransportQueuedData{Size: 12345, Index: basicnode.NewInt(1)}) + h.transport.EventHandler.OnTransportEvent(channelID(h.id, h.peers), datatransfer.TransportReachedDataLimit{}) + require.Len(t, h.transport.MessagesSent, 1) + response, ok := h.transport.MessagesSent[0].Message.(datatransfer.Response) require.True(t, ok) require.Equal(t, response.TransferID(), h.id) require.True(t, response.IsUpdate()) @@ -468,10 +469,10 @@ func TestDataTransferResponding(t *testing.T) { vr := testutil.NewTestTypedVoucher() err = h.dt.UpdateValidationStatus(ctx, channelID(h.id, h.peers), datatransfer.ValidationResult{Accepted: true, DataLimit: 50000, VoucherResult: &vr}) require.NoError(t, err) - require.Len(t, h.transport.ResumedChannels, 1) - require.Equal(t, h.transport.ResumedChannels[0], channelID(h.id, h.peers)) - require.Len(t, h.transport.MessagesSent, 1) - response, ok = h.transport.MessagesSent[0].Message.(datatransfer.Response) + require.Len(t, h.transport.ChannelsUpdated, 1) + require.Equal(t, h.transport.ChannelsUpdated[0], channelID(h.id, h.peers)) + require.Len(t, h.transport.MessagesSent, 2) + response, ok = h.transport.MessagesSent[1].Message.(datatransfer.Response) require.True(t, ok) require.True(t, response.Accepted()) require.Equal(t, response.TransferID(), h.id) @@ -488,6 +489,7 @@ func TestDataTransferResponding(t *testing.T) { datatransfer.Accept, datatransfer.NewVoucherResult, datatransfer.SetRequiresFinalization, + datatransfer.TransferInitiated, datatransfer.BeginFinalizing, datatransfer.NewVoucher, datatransfer.NewVoucherResult, @@ -503,8 +505,8 @@ func TestDataTransferResponding(t *testing.T) { verify: func(t *testing.T, h *receiverHarness) { _, err := h.transport.EventHandler.OnRequestReceived(channelID(h.id, h.peers), h.pullRequest) require.NoError(t, err) - err = h.transport.EventHandler.OnChannelCompleted(channelID(h.id, h.peers), nil) - require.NoError(t, err) + h.transport.EventHandler.OnTransportEvent(channelID(h.id, h.peers), datatransfer.TransportInitiatedTransfer{}) + h.transport.EventHandler.OnTransportEvent(channelID(h.id, h.peers), datatransfer.TransportCompletedTransfer{Success: true}) require.Len(t, h.transport.MessagesSent, 1) response, ok := h.transport.MessagesSent[0].Message.(datatransfer.Response) require.True(t, ok) @@ -544,8 +546,7 @@ func TestDataTransferResponding(t *testing.T) { verify: func(t *testing.T, h *receiverHarness) { _, err := h.transport.EventHandler.OnRequestReceived(channelID(h.id, h.peers), h.pullRequest) require.NoError(t, err) - err = h.transport.EventHandler.OnChannelCompleted(channelID(h.id, h.peers), xerrors.Errorf("err")) - require.NoError(t, err) + h.transport.EventHandler.OnTransportEvent(channelID(h.id, h.peers), datatransfer.TransportCompletedTransfer{Success: false, ErrorMessage: "something went wrong"}) }, }, "new push request, customized transport": { @@ -631,7 +632,7 @@ func TestDataTransferResponding(t *testing.T) { h.resumeUpdate = message.UpdateRequest(h.id, false) require.NoError(t, err) updateVoucher := testutil.NewTestTypedVoucher() - h.voucherUpdate, err = message.VoucherRequest(h.id, &updateVoucher) + h.voucherUpdate = message.VoucherRequest(h.id, &updateVoucher) h.cancelUpdate = message.CancelRequest(h.id) require.NoError(t, err) h.sv = testutil.NewStubbedValidator() @@ -662,8 +663,7 @@ func TestDataTransferRestartResponding(t 
*testing.T) { require.NoError(t, err) require.NotEmpty(t, channelID) - response, err := message.RestartResponse(channelID.ID, true, false, nil) - require.NoError(t, err) + response := message.RestartResponse(channelID.ID, true, false, nil) err = h.transport.EventHandler.OnResponseReceived(channelID, response) require.NoError(t, err) }, @@ -673,6 +673,7 @@ func TestDataTransferRestartResponding(t *testing.T) { datatransfer.Open, datatransfer.Accept, datatransfer.NewVoucherResult, + datatransfer.TransferInitiated, datatransfer.DataReceivedProgress, datatransfer.DataReceived, datatransfer.DataReceivedProgress, @@ -695,11 +696,9 @@ func TestDataTransferRestartResponding(t *testing.T) { // some cids are received chid := datatransfer.ChannelID{Initiator: h.peers[1], Responder: h.peers[0], ID: h.pushRequest.TransferID()} - testCids := testutil.GenerateCids(2) - ev, ok := h.dt.(datatransfer.EventsHandler) - require.True(t, ok) - require.NoError(t, ev.OnDataReceived(chid, cidlink.Link{Cid: testCids[0]}, 12345, 1, true)) - require.NoError(t, ev.OnDataReceived(chid, cidlink.Link{Cid: testCids[1]}, 12345, 2, true)) + h.transport.EventHandler.OnTransportEvent(chid, datatransfer.TransportInitiatedTransfer{}) + h.transport.EventHandler.OnTransportEvent(chid, datatransfer.TransportReceivedData{Size: 12345, Index: basicnode.NewInt(1)}) + h.transport.EventHandler.OnTransportEvent(chid, datatransfer.TransportReceivedData{Size: 12345, Index: basicnode.NewInt(2)}) // receive restart push request req, err := message.NewRequest(h.pushRequest.TransferID(), true, false, &h.voucher, h.baseCid, h.stor) @@ -894,6 +893,7 @@ func TestDataTransferRestartResponding(t *testing.T) { "ReceiveRestartExistingChannelRequest: Reopen Pull Channel": { expectedEvents: []datatransfer.EventCode{ datatransfer.Open, + datatransfer.TransferInitiated, datatransfer.DataReceivedProgress, datatransfer.DataReceived, datatransfer.DataReceivedProgress, @@ -908,15 +908,13 @@ func TestDataTransferRestartResponding(t *testing.T) { require.NotEmpty(t, channelID) // some cids should already be received - testCids := testutil.GenerateCids(2) - ev, ok := h.dt.(datatransfer.EventsHandler) - require.True(t, ok) - require.NoError(t, ev.OnDataReceived(channelID, cidlink.Link{Cid: testCids[0]}, 12345, 1, true)) - require.NoError(t, ev.OnDataReceived(channelID, cidlink.Link{Cid: testCids[1]}, 12345, 2, true)) + // some cids are received + h.transport.EventHandler.OnTransportEvent(channelID, datatransfer.TransportInitiatedTransfer{}) + h.transport.EventHandler.OnTransportEvent(channelID, datatransfer.TransportReceivedData{Size: 12345, Index: basicnode.NewInt(1)}) + h.transport.EventHandler.OnTransportEvent(channelID, datatransfer.TransportReceivedData{Size: 12345, Index: basicnode.NewInt(2)}) // send a request to restart the same pull channel - err = h.transport.EventHandler.OnRestartExistingChannelRequestReceived(channelID) - require.NoError(t, err) + h.transport.EventHandler.OnTransportEvent(channelID, datatransfer.TransportReceivedRestartExistingChannelRequest{}) require.Len(t, h.transport.OpenedChannels, 1) require.Len(t, h.transport.RestartedChannels, 1) @@ -927,7 +925,7 @@ func TestDataTransferRestartResponding(t *testing.T) { require.Equal(t, restartedChannel.Channel.Sender(), h.peers[1]) require.Equal(t, restartedChannel.Channel.BaseCID(), h.baseCid) require.Equal(t, restartedChannel.Channel.Selector(), h.stor) - require.EqualValues(t, len(testCids), restartedChannel.Channel.ReceivedCidsTotal()) + require.EqualValues(t, basicnode.NewInt(2), 
restartedChannel.Channel.ReceivedIndex())
 
 				// assert a restart request is in the channel
 				request := restartedChannel.Message
@@ -958,7 +956,7 @@ func TestDataTransferRestartResponding(t *testing.T) {
 				require.NotEmpty(t, channelID)
 
 				// send a request to restart the same push request
-				err = h.transport.EventHandler.OnRestartExistingChannelRequestReceived(channelID)
+				h.transport.EventHandler.OnTransportEvent(channelID, datatransfer.TransportReceivedRestartExistingChannelRequest{})
 				require.NoError(t, err)
 
 				require.Len(t, h.transport.OpenedChannels, 1)
diff --git a/impl/utils.go b/impl/utils.go
index 518b9e0c..4a74d838 100644
--- a/impl/utils.go
+++ b/impl/utils.go
@@ -35,6 +35,13 @@ func (m *manager) newRequest(ctx context.Context, selector datamodel.Node, isPul
 	return message.NewRequest(tid, false, isPull, &voucher, baseCid, selector)
 }
 
+func (m *manager) otherPeer(chid datatransfer.ChannelID) peer.ID {
+	if chid.Initiator == m.peerID {
+		return chid.Responder
+	}
+	return chid.Initiator
+}
+
 func (m *manager) resume(chid datatransfer.ChannelID) error {
 	if chid.Initiator == m.peerID {
 		return m.channels.ResumeInitiator(chid)
diff --git a/itest/gstestdata.go b/itest/gstestdata.go
index 02b2f7dc..2b1b808a 100644
--- a/itest/gstestdata.go
+++ b/itest/gstestdata.go
@@ -76,6 +76,8 @@ type GraphsyncTestingData struct {
 	GsNet2 gsnet.GraphSyncNetwork
 	DtNet1 network.DataTransferNetwork
 	DtNet2 network.DataTransferNetwork
+	Gs1 graphsync.GraphExchange
+	Gs2 graphsync.GraphExchange
 	OrigBytes []byte
 	TempDir1 string
 	TempDir2 string
@@ -151,13 +153,17 @@ func NewGraphsyncTestingData(ctx context.Context, t *testing.T, host1Protocols [
 
 // SetupGraphsyncHost1 sets up a new, real graphsync instance on top of the first host
 func (gsData *GraphsyncTestingData) SetupGraphsyncHost1() graphsync.GraphExchange {
+	if gsData.Gs1 != nil {
+		return gsData.Gs1
+	}
 	// setup graphsync
 	if gsData.gs1Cancel != nil {
 		gsData.gs1Cancel()
 	}
 	gsCtx, gsCancel := context.WithCancel(gsData.Ctx)
 	gsData.gs1Cancel = gsCancel
-	return gsimpl.New(gsCtx, gsData.GsNet1, gsData.LinkSystem1)
+	gsData.Gs1 = gsimpl.New(gsCtx, gsData.GsNet1, gsData.LinkSystem1)
+	return gsData.Gs1
 }
 
 // SetupGSTransportHost1 sets up a new graphsync transport over real graphsync on the first host
@@ -172,18 +178,22 @@ func (gsData *GraphsyncTestingData) SetupGSTransportHost1(opts ...gstransport.Op
 		opts = append(opts, gstransport.SupportedExtensions(supportedExtensions))
 	}
 
-	return gstransport.NewTransport(gsData.Host1.ID(), gs, gsData.DtNet1, opts...)
+	return gstransport.NewTransport(gs, gsData.DtNet1, opts...)
 }
 
 // SetupGraphsyncHost2 sets up a new, real graphsync instance on top of the second host
 func (gsData *GraphsyncTestingData) SetupGraphsyncHost2() graphsync.GraphExchange {
+	if gsData.Gs2 != nil {
+		return gsData.Gs2
+	}
 	// setup graphsync
 	if gsData.gs2Cancel != nil {
 		gsData.gs2Cancel()
 	}
 	gsCtx, gsCancel := context.WithCancel(gsData.Ctx)
 	gsData.gs2Cancel = gsCancel
-	return gsimpl.New(gsCtx, gsData.GsNet2, gsData.LinkSystem2)
+	gsData.Gs2 = gsimpl.New(gsCtx, gsData.GsNet2, gsData.LinkSystem2)
+	return gsData.Gs2
 }
 
 // SetupGSTransportHost2 sets up a new graphsync transport over real graphsync on the second host
@@ -197,7 +207,7 @@ func (gsData *GraphsyncTestingData) SetupGSTransportHost2(opts ...gstransport.Op
 	}
 		opts = append(opts, gstransport.SupportedExtensions(supportedExtensions))
 	}
-	return gstransport.NewTransport(gsData.Host2.ID(), gs, gsData.DtNet2, opts...)
+	return gstransport.NewTransport(gs, gsData.DtNet2, opts...)
} // LoadUnixFSFile loads a fixtures file we can test dag transfer with diff --git a/itest/integration_test.go b/itest/integration_test.go index f13e69e4..a98a67d8 100644 --- a/itest/integration_test.go +++ b/itest/integration_test.go @@ -440,7 +440,7 @@ func TestManyReceiversAtOnce(t *testing.T) { destDagService := merkledag.NewDAGService(blockservice.New(altBs, offline.Exchange(altBs))) gs := gsimpl.New(gsData.Ctx, gsnet, lsys) - gsTransport := tp.NewTransport(host.ID(), gs, dtnet) + gsTransport := tp.NewTransport(gs, dtnet) dtDs := namespace.Wrap(ds, datastore.NewKey("datatransfer")) @@ -1443,6 +1443,7 @@ func TestPauseAndResume(t *testing.T) { dt2, err := NewDataTransfer(gsData.DtDs2, gsData.Host2.ID(), tp2) require.NoError(t, err) testutil.StartAndWaitForReady(ctx, t, dt2) + finished := make(chan struct{}, 2) errChan := make(chan struct{}, 2) opened := make(chan struct{}, 2) @@ -1494,8 +1495,27 @@ func TestPauseAndResume(t *testing.T) { sv := testutil.NewStubbedValidator() sv.StubResult(datatransfer.ValidationResult{Accepted: true}) sv.StubRestartResult(datatransfer.ValidationResult{Accepted: true}) - var chid datatransfer.ChannelID + + gsData.Gs1.RegisterOutgoingBlockHook(func(p peer.ID, r graphsync.RequestData, block graphsync.BlockData, ha graphsync.OutgoingBlockHookActions) { + if block.Index() == 5 && block.BlockSizeOnWire() > 0 { + require.NoError(t, dt1.PauseDataTransferChannel(ctx, chid)) + go func() { + time.Sleep(100 * time.Millisecond) + require.NoError(t, dt1.ResumeDataTransferChannel(ctx, chid)) + }() + } + }) + gsData.Gs2.RegisterIncomingBlockHook(func(p peer.ID, r graphsync.ResponseData, block graphsync.BlockData, ha graphsync.IncomingBlockHookActions) { + if block.Index() == 5 { + require.NoError(t, dt2.PauseDataTransferChannel(ctx, chid)) + go func() { + time.Sleep(50 * time.Millisecond) + require.NoError(t, dt2.ResumeDataTransferChannel(ctx, chid)) + }() + } + }) + if isPull { sv.ExpectSuccessPull() require.NoError(t, dt1.RegisterVoucherType(testutil.TestVoucherType, sv)) @@ -1533,18 +1553,8 @@ func TestPauseAndResume(t *testing.T) { resumeResponders++ case sentIncrement := <-sent: sentIncrements = append(sentIncrements, sentIncrement) - if len(sentIncrements) == 5 { - require.NoError(t, dt1.PauseDataTransferChannel(ctx, chid)) - time.Sleep(100 * time.Millisecond) - require.NoError(t, dt1.ResumeDataTransferChannel(ctx, chid)) - } case receivedIncrement := <-received: receivedIncrements = append(receivedIncrements, receivedIncrement) - if len(receivedIncrements) == 10 { - require.NoError(t, dt2.PauseDataTransferChannel(ctx, chid)) - time.Sleep(100 * time.Millisecond) - require.NoError(t, dt2.ResumeDataTransferChannel(ctx, chid)) - } case <-errChan: t.Fatal("received error on data transfer") } @@ -1805,7 +1815,7 @@ func TestRespondingToPushGraphsyncRequests(t *testing.T) { } requestReceived := messageReceived.message.(datatransfer.Request) - response, err := message.NewResponse(requestReceived.TransferID(), true, false, &voucherResult) + response := message.NewResponse(requestReceived.TransferID(), true, false, &voucherResult) require.NoError(t, err) nd := response.ToIPLD() request := gsmsg.NewRequest(graphsync.NewRequestID(), link.(cidlink.Link).Cid, selectorparse.CommonSelector_ExploreAllRecursively, graphsync.Priority(rand.Int31()), graphsync.ExtensionData{ @@ -1823,7 +1833,7 @@ func TestRespondingToPushGraphsyncRequests(t *testing.T) { }) t.Run("when no request is initiated", func(t *testing.T) { - response, err := 
message.NewResponse(datatransfer.TransferID(rand.Uint32()), true, false, &voucher) + response := message.NewResponse(datatransfer.TransferID(rand.Uint32()), true, false, &voucher) require.NoError(t, err) nd := response.ToIPLD() request := gsmsg.NewRequest(graphsync.NewRequestID(), link.(cidlink.Link).Cid, selectorparse.CommonSelector_ExploreAllRecursively, graphsync.Priority(rand.Int31()), graphsync.ExtensionData{ @@ -1865,7 +1875,7 @@ func TestResponseHookWhenExtensionNotFound(t *testing.T) { gsData.GsNet2.SetDelegate(gsr) gs1 := gsData.SetupGraphsyncHost1() - tp1 := tp.NewTransport(host1.ID(), gs1, gsData.DtNet1) + tp1 := tp.NewTransport(gs1, gsData.DtNet1) dt1, err := NewDataTransfer(gsData.DtDs1, gsData.Host1.ID(), tp1) require.NoError(t, err) testutil.StartAndWaitForReady(ctx, t, dt1) diff --git a/manager.go b/manager.go index 822f76e5..b7a6f217 100644 --- a/manager.go +++ b/manager.go @@ -40,6 +40,9 @@ func (vr ValidationResult) Equals(vr2 ValidationResult) bool { // LeaveRequestPaused indicates whether all conditions are met to resume a request func (vr ValidationResult) LeaveRequestPaused(chst ChannelState) bool { + if chst == nil { + return false + } if vr.ForcePause { return true } diff --git a/message/message1_1prime/message.go b/message/message1_1prime/message.go index 641686bf..a6fccd71 100644 --- a/message/message1_1prime/message.go +++ b/message/message1_1prime/message.go @@ -76,7 +76,7 @@ func UpdateRequest(id datatransfer.TransferID, isPaused bool) datatransfer.Reque } // VoucherRequest generates a new request for the data transfer protocol -func VoucherRequest(id datatransfer.TransferID, voucher *datatransfer.TypedVoucher) (datatransfer.Request, error) { +func VoucherRequest(id datatransfer.TransferID, voucher *datatransfer.TypedVoucher) datatransfer.Request { if voucher == nil { voucher = &emptyTypedVoucher } @@ -85,11 +85,11 @@ func VoucherRequest(id datatransfer.TransferID, voucher *datatransfer.TypedVouch VoucherPtr: voucher.Voucher, VoucherTypeIdentifier: voucher.Type, TransferId: uint64(id), - }, nil + } } // RestartResponse builds a new Data Transfer response -func RestartResponse(id datatransfer.TransferID, accepted bool, isPaused bool, voucherResult *datatransfer.TypedVoucher) (datatransfer.Response, error) { +func RestartResponse(id datatransfer.TransferID, accepted bool, isPaused bool, voucherResult *datatransfer.TypedVoucher) datatransfer.Response { if voucherResult == nil { voucherResult = &emptyTypedVoucher } @@ -100,7 +100,7 @@ func RestartResponse(id datatransfer.TransferID, accepted bool, isPaused bool, v TransferId: uint64(id), VoucherResultPtr: voucherResult.Voucher, VoucherTypeIdentifier: voucherResult.Type, - }, nil + } } // ValidationResultResponse response generates a response based on a validation result @@ -110,7 +110,7 @@ func ValidationResultResponse( id datatransfer.TransferID, validationResult datatransfer.ValidationResult, validationErr error, - paused bool) (datatransfer.Response, error) { + paused bool) datatransfer.Response { voucherResult := &emptyTypedVoucher if validationResult.VoucherResult != nil { @@ -125,11 +125,11 @@ func ValidationResultResponse( TransferId: uint64(id), VoucherTypeIdentifier: voucherResult.Type, VoucherResultPtr: voucherResult.Voucher, - }, nil + } } // NewResponse builds a new Data Transfer response -func NewResponse(id datatransfer.TransferID, accepted bool, isPaused bool, voucherResult *datatransfer.TypedVoucher) (datatransfer.Response, error) { +func NewResponse(id datatransfer.TransferID, accepted bool, 
isPaused bool, voucherResult *datatransfer.TypedVoucher) datatransfer.Response { if voucherResult == nil { voucherResult = &emptyTypedVoucher } @@ -140,11 +140,11 @@ func NewResponse(id datatransfer.TransferID, accepted bool, isPaused bool, vouch TransferId: uint64(id), VoucherTypeIdentifier: voucherResult.Type, VoucherResultPtr: voucherResult.Voucher, - }, nil + } } // VoucherResultResponse builds a new response for a voucher result -func VoucherResultResponse(id datatransfer.TransferID, accepted bool, isPaused bool, voucherResult *datatransfer.TypedVoucher) (datatransfer.Response, error) { +func VoucherResultResponse(id datatransfer.TransferID, accepted bool, isPaused bool, voucherResult *datatransfer.TypedVoucher) datatransfer.Response { if voucherResult == nil { voucherResult = &emptyTypedVoucher } @@ -155,7 +155,7 @@ func VoucherResultResponse(id datatransfer.TransferID, accepted bool, isPaused b TransferId: uint64(id), VoucherTypeIdentifier: voucherResult.Type, VoucherResultPtr: voucherResult.Voucher, - }, nil + } } // UpdateResponse returns a new update response @@ -176,7 +176,7 @@ func CancelResponse(id datatransfer.TransferID) datatransfer.Response { } // CompleteResponse returns a new complete response message -func CompleteResponse(id datatransfer.TransferID, isAccepted bool, isPaused bool, voucherResult *datatransfer.TypedVoucher) (datatransfer.Response, error) { +func CompleteResponse(id datatransfer.TransferID, isAccepted bool, isPaused bool, voucherResult *datatransfer.TypedVoucher) datatransfer.Response { if voucherResult == nil { voucherResult = &emptyTypedVoucher } @@ -187,7 +187,7 @@ func CompleteResponse(id datatransfer.TransferID, isAccepted bool, isPaused bool VoucherTypeIdentifier: voucherResult.Type, VoucherResultPtr: voucherResult.Voucher, TransferId: uint64(id), - }, nil + } } // FromNet can read a network stream to deserialize a GraphSyncMessage diff --git a/message/message1_1prime/message_test.go b/message/message1_1prime/message_test.go index 33050deb..b84b39c0 100644 --- a/message/message1_1prime/message_test.go +++ b/message/message1_1prime/message_test.go @@ -159,8 +159,7 @@ func TestTransferRequest_UnmarshalCBOR(t *testing.T) { func TestResponses(t *testing.T) { id := datatransfer.TransferID(rand.Int31()) voucherResult := testutil.NewTestTypedVoucher() - response, err := message1_1.NewResponse(id, false, true, &voucherResult) // not accepted - require.NoError(t, err) + response := message1_1.NewResponse(id, false, true, &voucherResult) // not accepted assert.Equal(t, response.TransferID(), id) assert.False(t, response.Accepted()) assert.True(t, response.IsNew()) @@ -182,8 +181,7 @@ func TestResponses(t *testing.T) { func TestTransferResponse_MarshalCBOR(t *testing.T) { id := datatransfer.TransferID(rand.Int31()) voucherResult := testutil.NewTestTypedVoucher() - response, err := message1_1.NewResponse(id, true, false, &voucherResult) // accepted - require.NoError(t, err) + response := message1_1.NewResponse(id, true, false, &voucherResult) // accepted // sanity check that we can marshal data wbuf := new(bytes.Buffer) @@ -195,8 +193,7 @@ func TestTransferResponse_UnmarshalCBOR(t *testing.T) { t.Run("round-trip", func(t *testing.T) { id := datatransfer.TransferID(rand.Int31()) voucherResult := testutil.NewTestTypedVoucher() - response, err := message1_1.NewResponse(id, true, false, &voucherResult) // accepted - require.NoError(t, err) + response := message1_1.NewResponse(id, true, false, &voucherResult) // accepted wbuf := new(bytes.Buffer) require.NoError(t, 
response.ToNet(wbuf)) @@ -366,8 +363,7 @@ func TestCancelResponse(t *testing.T) { func TestCompleteResponse(t *testing.T) { id := datatransfer.TransferID(rand.Int31()) - response, err := message1_1.CompleteResponse(id, true, true, nil) - require.NoError(t, err) + response := message1_1.CompleteResponse(id, true, true, nil) assert.Equal(t, response.TransferID(), id) assert.False(t, response.IsNew()) assert.False(t, response.IsUpdate()) @@ -414,8 +410,7 @@ func TestToNetFromNetEquivalency(t *testing.T) { testutil.AssertEqualTestVoucher(t, request, deserializedRequest) testutil.AssertEqualSelector(t, request, deserializedRequest) - response, err := message1_1.NewResponse(id, accepted, false, &voucherResult) - require.NoError(t, err) + response := message1_1.NewResponse(id, accepted, false, &voucherResult) err = response.ToNet(buf) require.NoError(t, err) deserialized, err = message1_1.FromNet(buf) @@ -476,8 +471,7 @@ func TestToNetFromNetEquivalency(t *testing.T) { testutil.AssertEqualTestVoucher(t, request, deserializedRequest) testutil.AssertEqualSelector(t, request, deserializedRequest) - response, err := message1_1.NewResponse(id, accepted, false, &voucherResult) - require.NoError(t, err) + response := message1_1.NewResponse(id, accepted, false, &voucherResult) err = response.ToNet(buf) require.NoError(t, err) msg, _ = hex.DecodeString("a36449735271f46752657175657374f668526573706f6e7365a66454797065006441637074f46450617573f4665866657249441a4d65822164565265738178644204cb9a1e34c5f08e9b20aa76090e70020bb56c0ca3d3af7296cd1058a5112890fed218488f084d8df9e4835fb54ad045ffd936e3bf7261b0426c51352a097816ed74482bb9084b4a7ed8adc517f3371e0e0434b511625cd1a41792243dccdcfe88094b64565479706b54657374566f7563686572") @@ -541,8 +535,7 @@ func TestToNetFromNetEquivalency(t *testing.T) { testutil.AssertEqualTestVoucher(t, request, deserializedRequest) testutil.AssertEqualSelector(t, request, deserializedRequest) - response, err := message1_1.NewResponse(id, accepted, false, &voucherResult) - require.NoError(t, err) + response := message1_1.NewResponse(id, accepted, false, &voucherResult) wresponse := response.WrappedForTransport(transportID, transportVersion) err = wresponse.ToNet(buf) require.NoError(t, err) diff --git a/message/message1_1prime/transfer_response_test.go b/message/message1_1prime/transfer_response_test.go index 71fcf703..85741882 100644 --- a/message/message1_1prime/transfer_response_test.go +++ b/message/message1_1prime/transfer_response_test.go @@ -14,8 +14,7 @@ import ( func TestResponseMessageForVersion(t *testing.T) { id := datatransfer.TransferID(rand.Int31()) voucherResult := testutil.NewTestTypedVoucher() - response, err := message1_1.NewResponse(id, false, true, &voucherResult) // not accepted - require.NoError(t, err) + response := message1_1.NewResponse(id, false, true, &voucherResult) // not accepted // v1.2 new protocol out, err := response.MessageForVersion(datatransfer.DataTransfer1_2) diff --git a/statuses.go b/statuses.go index 32c277fb..3b263de0 100644 --- a/statuses.go +++ b/statuses.go @@ -1,5 +1,7 @@ package datatransfer +import "github.com/filecoin-project/go-statemachine/fsm" + // Status is the status of transfer for a given channel type Status uint64 @@ -40,13 +42,13 @@ const ( // Cancelled means the data transfer ended prematurely Cancelled - // InitiatorPaused means the data sender has paused the channel (only the sender can unpause this) + // DEPRECATED: Use InitiatorPaused() method on ChannelState InitiatorPaused - // ResponderPaused means the data receiver has paused 
the channel (only the receiver can unpause this)
+	// DEPRECATED: Use ResponderPaused() method on ChannelState
 	ResponderPaused
 
-	// BothPaused means both sender and receiver have paused the channel seperately (both must unpause)
+	// DEPRECATED: Use BothPaused() method on ChannelState
 	BothPaused
 
 	// ResponderFinalizing is a unique state where the responder is awaiting a final voucher
@@ -58,18 +60,81 @@ const (
 	// ChannelNotFoundError means the searched for data transfer does not exist
 	ChannelNotFoundError
+
+	// Queued indicates a data transfer request has been accepted, but is not actively transferring yet
+	Queued
+
+	// AwaitingAcceptance indicates a transfer request is actively being processed by the transport
+	// even if the remote has not yet responded that it has accepted the transfer. Such a state can
+	// occur, for example, in a requestor-initiated transfer that starts processing prior to receiving
+	// acceptance from the server.
+	AwaitingAcceptance
 )
 
-func (s Status) IsAccepted() bool {
-	return s != Requested && s != Cancelled && s != Cancelling && s != Failed && s != Failing && s != ChannelNotFoundError
+type statusList []Status
+
+func (sl statusList) Contains(s Status) bool {
+	for _, ts := range sl {
+		if ts == s {
+			return true
+		}
+	}
+	return false
 }
 
-func (s Status) IsResponderPaused() bool {
-	return s == ResponderPaused || s == BothPaused || s == Finalizing
+func (sl statusList) AsFSMStates() []fsm.StateKey {
+	sk := make([]fsm.StateKey, 0, len(sl))
+	for _, s := range sl {
+		sk = append(sk, s)
+	}
+	return sk
+}
+
+var NotAcceptedStates = statusList{
+	Requested,
+	AwaitingAcceptance,
+	Cancelled,
+	Cancelling,
+	Failed,
+	Failing,
+	ChannelNotFoundError}
+
+func (s Status) IsAccepted() bool {
+	return !NotAcceptedStates.Contains(s)
 }
 
+var FinalizationStatuses = statusList{Finalizing, Completed, Completing}
+
 func (s Status) InFinalization() bool {
-	return s == Finalizing || s == Completed || s == Completing
+	return FinalizationStatuses.Contains(s)
+}
+
+var TransferCompleteStates = statusList{
+	TransferFinished,
+	ResponderFinalizingTransferFinished,
+	Finalizing,
+	Completed,
+	Completing,
+	Failing,
+	Failed,
+	Cancelling,
+	Cancelled,
+	ChannelNotFoundError,
+}
+
+func (s Status) TransferComplete() bool {
+	return TransferCompleteStates.Contains(s)
+}
+
+var TransferringStates = statusList{
+	Ongoing,
+	ResponderCompleted,
+	ResponderFinalizing,
+	AwaitingAcceptance,
+}
+
+func (s Status) Transferring() bool {
+	return TransferringStates.Contains(s)
 }
 
 // Statuses are human readable names for data transfer states
@@ -92,4 +157,6 @@ var Statuses = map[Status]string{
 	ResponderFinalizing: "ResponderFinalizing",
 	ResponderFinalizingTransferFinished: "ResponderFinalizingTransferFinished",
 	ChannelNotFoundError: "ChannelNotFoundError",
+	Queued: "Queued",
+	AwaitingAcceptance: "AwaitingAcceptance",
 }
diff --git a/testutil/faketransport.go b/testutil/faketransport.go
index 7e0d20c8..58726448 100644
--- a/testutil/faketransport.go
+++ b/testutil/faketransport.go
@@ -18,7 +18,7 @@ type RestartedChannel struct {
 	Message datatransfer.Request
 }
 
-// Records a message sent
+// MessageSent records a message sent
 type MessageSent struct {
 	ChannelID datatransfer.ChannelID
 	Message datatransfer.Message
@@ -38,11 +38,9 @@ type FakeTransport struct {
 	OpenChannelErr error
 	RestartedChannels []RestartedChannel
 	RestartChannelErr error
-	ClosedChannels []datatransfer.ChannelID
-	PausedChannels []datatransfer.ChannelID
-	ResumedChannels []datatransfer.ChannelID
 	MessagesSent []MessageSent
diff --git a/testutil/faketransport.go b/testutil/faketransport.go
index 7e0d20c8..58726448 100644
--- a/testutil/faketransport.go
+++ b/testutil/faketransport.go
@@ -18,7 +18,7 @@ type RestartedChannel struct {
 	Message datatransfer.Request
 }
 
-// Records a message sent
+// MessageSent records a message sent
 type MessageSent struct {
 	ChannelID datatransfer.ChannelID
 	Message   datatransfer.Message
@@ -38,11 +38,9 @@ type FakeTransport struct {
 	OpenChannelErr      error
 	RestartedChannels   []RestartedChannel
 	RestartChannelErr   error
-	ClosedChannels      []datatransfer.ChannelID
-	PausedChannels      []datatransfer.ChannelID
-	ResumedChannels     []datatransfer.ChannelID
 	MessagesSent        []MessageSent
 	UpdateError         error
+	ChannelsUpdated     []datatransfer.ChannelID
 	CleanedUpChannels   []datatransfer.ChannelID
 	CustomizedTransfers []CustomizedTransfer
 	EventHandler        datatransfer.EventsHandler
@@ -89,24 +87,13 @@ func (ft *FakeTransport) RestartChannel(ctx context.Context, channelState datatr
 }
 
-// WithChannel takes actions on a channel
-func (ft *FakeTransport) UpdateChannel(ctx context.Context, chid datatransfer.ChannelID, update datatransfer.ChannelUpdate) error {
-	if update.SendMessage != nil {
-		ft.MessagesSent = append(ft.MessagesSent, MessageSent{chid, update.SendMessage})
+// ChannelUpdated records a channel update and any message sent with it
+func (ft *FakeTransport) ChannelUpdated(ctx context.Context, chid datatransfer.ChannelID, msg datatransfer.Message) error {
+	if msg != nil {
+		ft.MessagesSent = append(ft.MessagesSent, MessageSent{chid, msg})
 	}
-
-	if update.Closed {
-		ft.ClosedChannels = append(ft.ClosedChannels, chid)
-		return ft.UpdateError
-	}
-
-	if !update.Paused {
-		ft.ResumedChannels = append(ft.ResumedChannels, chid)
-	} else {
-		ft.PausedChannels = append(ft.PausedChannels, chid)
-	}
-
-	return ft.UpdateError
+	ft.ChannelsUpdated = append(ft.ChannelsUpdated, chid)
+	return nil
 }
 
 // SendMessage sends a data transfer message over the channel to the other peer
diff --git a/testutil/message.go b/testutil/message.go
deleted file mode 100644
index 14319cc3..00000000
--- a/testutil/message.go
+++ /dev/null
@@ -1,30 +0,0 @@
-package testutil
-
-import (
-	"testing"
-
-	basicnode "github.com/ipld/go-ipld-prime/node/basic"
-	"github.com/ipld/go-ipld-prime/traversal/selector/builder"
-	"github.com/stretchr/testify/require"
-
-	datatransfer "github.com/filecoin-project/go-data-transfer/v2"
-	"github.com/filecoin-project/go-data-transfer/v2/message"
-)
-
-// NewDTRequest makes a new DT Request message
-func NewDTRequest(t *testing.T, transferID datatransfer.TransferID) datatransfer.Request {
-	voucher := NewTestTypedVoucher()
-	baseCid := GenerateCids(1)[0]
-	selector := builder.NewSelectorSpecBuilder(basicnode.Prototype.Any).Matcher().Node()
-	r, err := message.NewRequest(transferID, false, false, &voucher, baseCid, selector)
-	require.NoError(t, err)
-	return r
-}
-
-// NewDTResponse makes a new DT Request message
-func NewDTResponse(t *testing.T, transferID datatransfer.TransferID) datatransfer.Response {
-	vresult := NewTestTypedVoucher()
-	r, err := message.NewResponse(transferID, false, false, &vresult)
-	require.NoError(t, err)
-	return r
-}
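[Editor's note: because the fake no longer tracks pause/resume/close separately, tests assert against the new ChannelsUpdated log instead. A hypothetical sketch of such an assertion; the test name and values are illustrative, not part of the patch:]

package testutil_test

import (
	"context"
	"testing"

	"github.com/stretchr/testify/require"

	datatransfer "github.com/filecoin-project/go-data-transfer/v2"
	"github.com/filecoin-project/go-data-transfer/v2/testutil"
)

func TestFakeTransportRecordsChannelUpdates(t *testing.T) {
	ft := &testutil.FakeTransport{}
	chid := datatransfer.ChannelID{ID: datatransfer.TransferID(1)}

	// a nil message still records the channel update, but sends nothing
	require.NoError(t, ft.ChannelUpdated(context.Background(), chid, nil))
	require.Equal(t, []datatransfer.ChannelID{chid}, ft.ChannelsUpdated)
	require.Empty(t, ft.MessagesSent)
}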
diff --git a/testutil/mockchannelstate.go b/testutil/mockchannelstate.go
index 2aa75e30..92fc5f06 100644
--- a/testutil/mockchannelstate.go
+++ b/testutil/mockchannelstate.go
@@ -2,6 +2,7 @@ package testutil
 
 import (
 	cid "github.com/ipfs/go-cid"
+	"github.com/ipld/go-ipld-prime"
 	"github.com/ipld/go-ipld-prime/datamodel"
 	"github.com/libp2p/go-libp2p-core/peer"
 
@@ -9,32 +10,60 @@ import (
 )
 
 type MockChannelStateParams struct {
-	ReceivedCids []cid.Cid
-	ChannelID    datatransfer.ChannelID
-	Queued       uint64
-	Sent         uint64
-	Received     uint64
-	Complete     bool
+	ReceivedIndex   datamodel.Node
+	SentIndex       datamodel.Node
+	QueuedIndex     datamodel.Node
+	ChannelID       datatransfer.ChannelID
+	Queued          uint64
+	Sent            uint64
+	Received        uint64
+	Complete        bool
+	BaseCID         cid.Cid
+	Selector        ipld.Node
+	Voucher         datatransfer.TypedVoucher
+	IsPull          bool
+	Self            peer.ID
+	DataLimit       uint64
+	InitiatorPaused bool
+	ResponderPaused bool
 }
 
 func NewMockChannelState(params MockChannelStateParams) *MockChannelState {
 	return &MockChannelState{
-		receivedCids: params.ReceivedCids,
-		chid:         params.ChannelID,
-		queued:       params.Queued,
-		sent:         params.Sent,
-		received:     params.Received,
-		complete:     params.Complete,
+		receivedIndex:   params.ReceivedIndex,
+		sentIndex:       params.SentIndex,
+		queuedIndex:     params.QueuedIndex,
+		dataLimit:       params.DataLimit,
+		chid:            params.ChannelID,
+		queued:          params.Queued,
+		sent:            params.Sent,
+		received:        params.Received,
+		complete:        params.Complete,
+		isPull:          params.IsPull,
+		self:            params.Self,
+		baseCID:         params.BaseCID,
+		selector:        params.Selector,
+		voucher:         params.Voucher,
+		initiatorPaused: params.InitiatorPaused,
+		responderPaused: params.ResponderPaused,
 	}
 }
 
 type MockChannelState struct {
-	receivedCids []cid.Cid
-	chid         datatransfer.ChannelID
-	queued       uint64
-	sent         uint64
-	received     uint64
-	complete     bool
+	receivedIndex   datamodel.Node
+	sentIndex       datamodel.Node
+	queuedIndex     datamodel.Node
+	dataLimit       uint64
+	chid            datatransfer.ChannelID
+	queued          uint64
+	sent            uint64
+	received        uint64
+	complete        bool
+	isPull          bool
+	baseCID         cid.Cid
+	selector        ipld.Node
+	voucher         datatransfer.TypedVoucher
+	self            peer.ID
+	initiatorPaused bool
+	responderPaused bool
 }
 
 var _ datatransfer.ChannelState = (*MockChannelState)(nil)
 
@@ -77,48 +106,67 @@ func (m *MockChannelState) Status() datatransfer.Status {
 	return datatransfer.Ongoing
 }
 
-func (m *MockChannelState) ReceivedCids() []cid.Cid {
-	return m.receivedCids
+func (m *MockChannelState) SetReceivedIndex(receivedIndex datamodel.Node) {
+	m.receivedIndex = receivedIndex
 }
 
-func (m *MockChannelState) ReceivedCidsLen() int {
-	return len(m.receivedCids)
+func (m *MockChannelState) ReceivedIndex() datamodel.Node {
+	if m.receivedIndex == nil {
+		return datamodel.Null
+	}
+	return m.receivedIndex
 }
 
-func (m *MockChannelState) ReceivedCidsTotal() int64 {
-	return (int64)(len(m.receivedCids))
+func (m *MockChannelState) QueuedIndex() datamodel.Node {
+	if m.queuedIndex == nil {
+		return datamodel.Null
+	}
+	return m.queuedIndex
 }
 
-func (m *MockChannelState) QueuedCidsTotal() int64 {
-	panic("implement me")
+func (m *MockChannelState) SetQueuedIndex(queuedIndex datamodel.Node) {
+	m.queuedIndex = queuedIndex
 }
 
-func (m *MockChannelState) SentCidsTotal() int64 {
-	panic("implement me")
+func (m *MockChannelState) SentIndex() datamodel.Node {
+	if m.sentIndex == nil {
+		return datamodel.Null
+	}
+	return m.sentIndex
+}
+
+func (m *MockChannelState) SetSentIndex(sentIndex datamodel.Node) {
+	m.sentIndex = sentIndex
 }
 
 func (m *MockChannelState) TransferID() datatransfer.TransferID {
-	panic("implement me")
+	return m.chid.ID
 }
 
 func (m *MockChannelState) BaseCID() cid.Cid {
-	panic("implement me")
+	return m.baseCID
 }
 
 func (m *MockChannelState) Selector() datamodel.Node {
-	panic("implement me")
+	return m.selector
 }
 
 func (m *MockChannelState) Voucher() datatransfer.TypedVoucher {
-	panic("implement me")
+	return m.voucher
 }
 
 func (m *MockChannelState) Sender() peer.ID {
-	panic("implement me")
+	if m.isPull {
+		return m.chid.Responder
+	}
+	return m.chid.Initiator
 }
 
 func (m *MockChannelState) Recipient() peer.ID {
-	panic("implement me")
+	if m.isPull {
+		return m.chid.Initiator
+	}
+	return m.chid.Responder
 }
 
 func (m *MockChannelState) TotalSize() uint64 {
@@ -126,15 +174,18 @@ func (m *MockChannelState) TotalSize() uint64 {
 }
 
 func (m *MockChannelState) IsPull() bool {
-	panic("implement me")
+	return m.isPull
 }
 
 func (m *MockChannelState) OtherPeer() peer.ID {
-	panic("implement me")
+	if m.self == m.chid.Initiator {
+		return m.chid.Responder
+	}
+	return m.chid.Initiator
 }
 
 func (m *MockChannelState) SelfPeer() peer.ID {
-	panic("implement me")
+	return m.self
 }
 
 func (m *MockChannelState) Message() string {
@@ -161,10 +212,41 @@
 func (m *MockChannelState) Stages() *datatransfer.ChannelStages {
 	panic("implement me")
 }
 
+func (m *MockChannelState) SetDataLimit(dataLimit uint64) {
+	m.dataLimit = dataLimit
+}
+
 func (m *MockChannelState) DataLimit() uint64 {
-	panic("implement me")
+	return m.dataLimit
 }
 
 func (m *MockChannelState) RequiresFinalization() bool {
 	panic("implement me")
 }
+
+func (m *MockChannelState) SetResponderPaused(responderPaused bool) {
+	m.responderPaused = responderPaused
+}
+
+func (m *MockChannelState) ResponderPaused() bool {
+	return m.responderPaused
+}
+
+func (m *MockChannelState) SetInitiatorPaused(initiatorPaused bool) {
+	m.initiatorPaused = initiatorPaused
+}
+
+func (m *MockChannelState) InitiatorPaused() bool {
+	return m.initiatorPaused
+}
+
+func (m *MockChannelState) BothPaused() bool {
+	return m.initiatorPaused && m.responderPaused
+}
+
+func (m *MockChannelState) SelfPaused() bool {
+	if m.self == m.chid.Initiator {
+		return m.initiatorPaused
+	}
+	return m.responderPaused
+}
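[Editor's note: with most accessors now backed by real fields instead of panics, tests can build channel fixtures declaratively. A hypothetical usage sketch; the test name and values are illustrative, not part of the patch:]

package testutil_test

import (
	"testing"

	basicnode "github.com/ipld/go-ipld-prime/node/basic"
	"github.com/stretchr/testify/require"

	datatransfer "github.com/filecoin-project/go-data-transfer/v2"
	"github.com/filecoin-project/go-data-transfer/v2/testutil"
)

func TestMockChannelStateAccessors(t *testing.T) {
	chid := datatransfer.ChannelID{ID: datatransfer.TransferID(1), Initiator: "init", Responder: "resp"}
	chst := testutil.NewMockChannelState(testutil.MockChannelStateParams{
		ChannelID:     chid,
		IsPull:        true,
		Self:          chid.Initiator,
		ReceivedIndex: basicnode.NewInt(7), // transport-specific "where am I" marker
		DataLimit:     1 << 20,
	})
	require.True(t, chst.IsPull())
	require.Equal(t, chid.Responder, chst.Sender()) // in a pull, the responder sends
	require.Equal(t, chid.Responder, chst.OtherPeer())
}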
diff --git a/transport.go b/transport.go
index 546c2574..91aebdfc 100644
--- a/transport.go
+++ b/transport.go
@@ -3,7 +3,7 @@ package datatransfer
 import (
 	"context"
 
-	ipld "github.com/ipld/go-ipld-prime"
+	"github.com/ipld/go-ipld-prime/datamodel"
 )
 
 // TransportID identifies a unique transport
@@ -17,75 +17,103 @@ const LegacyTransportID TransportID = "graphsync"
 // i.e. graphsync 1.0.0
 var LegacyTransportVersion Version = Version{1, 0, 0}
 
+type TransportEvent interface {
+	transportEvent()
+}
+
+// TransportOpenedChannel occurs when the transport begins processing the
+// request (prior to that it may simply be queued) -- only applies to initiator
+type TransportOpenedChannel struct{}
+
+// TransportInitiatedTransfer occurs when the transport actually begins sending/receiving data
+type TransportInitiatedTransfer struct{}
+
+// TransportReceivedData occurs when we receive data for the given channel ID;
+// index is a transport-dependent way of serializing "here's where I am in this transport"
+type TransportReceivedData struct {
+	Size  uint64
+	Index datamodel.Node
+}
+
+// TransportSentData occurs when we send data for the given channel ID;
+// index is a transport-dependent way of serializing "here's where I am in this transport"
+type TransportSentData struct {
+	Size  uint64
+	Index datamodel.Node
+}
+
+// TransportQueuedData occurs when data is queued for sending for the given channel ID;
+// index is a transport-dependent way of serializing "here's where I am in this transport"
+type TransportQueuedData struct {
+	Size  uint64
+	Index datamodel.Node
+}
+
+// TransportReachedDataLimit occurs when a channel hits a previously set data limit
+type TransportReachedDataLimit struct{}
+
+// TransportTransferCancelled occurs when a request we opened (with the given channel ID) to
+// receive data is cancelled by us.
+type TransportTransferCancelled struct {
+	ErrorMessage string
+}
+
+// TransportErrorSendingData occurs when a network error occurs trying to send data
+type TransportErrorSendingData struct {
+	ErrorMessage string
+}
+
+// TransportErrorReceivingData occurs when a network error occurs receiving data
+// at the transport layer
+type TransportErrorReceivingData struct {
+	ErrorMessage string
+}
+
+// TransportCompletedTransfer occurs when we finish transferring data for the given channel ID
+type TransportCompletedTransfer struct {
+	Success      bool
+	ErrorMessage string
+}
+
+// TransportReceivedRestartExistingChannelRequest occurs when we receive a request
+// to restart an existing channel
+type TransportReceivedRestartExistingChannelRequest struct{}
+
+// TransportErrorSendingMessage occurs when a network error occurs trying to send a message
+type TransportErrorSendingMessage struct {
+	ErrorMessage string
+}
+
+// TransportPaused occurs when the transport pauses the channel
+type TransportPaused struct{}
+
+// TransportResumed occurs when the transport resumes the channel
+type TransportResumed struct{}
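[Editor's note: consumers handle this closed union with a single type switch instead of the per-event callbacks removed below. A minimal, self-contained sketch; the printingEvents type and output format are illustrative, while the event types are the ones declared above:]

package main

import (
	"fmt"

	datatransfer "github.com/filecoin-project/go-data-transfer/v2"
)

type printingEvents struct{}

// OnTransportEvent routes a TransportEvent to whatever bookkeeping each
// variant needs; variants a consumer doesn't care about can simply be ignored.
func (printingEvents) OnTransportEvent(chid datatransfer.ChannelID, evt datatransfer.TransportEvent) {
	switch e := evt.(type) {
	case datatransfer.TransportReceivedData:
		fmt.Printf("channel %v: received %d bytes (index %v)\n", chid, e.Size, e.Index)
	case datatransfer.TransportCompletedTransfer:
		if !e.Success {
			fmt.Printf("channel %v: transfer failed: %s\n", chid, e.ErrorMessage)
		}
	case datatransfer.TransportErrorSendingMessage:
		fmt.Printf("channel %v: send error: %s\n", chid, e.ErrorMessage)
	}
}

func main() {
	var h printingEvents
	h.OnTransportEvent(datatransfer.ChannelID{}, datatransfer.TransportCompletedTransfer{Success: true})
}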
 // EventsHandler are semantic data transfer events that happen as a result of transport events
 type EventsHandler interface {
 	// ChannelState queries for the current channel state
 	ChannelState(ctx context.Context, chid ChannelID) (ChannelState, error)
 
-	// OnChannelOpened is called when we send a request for data to the other
-	// peer on the given channel ID
-	// return values are:
-	// - error = ignore incoming data for this channel
-	OnChannelOpened(chid ChannelID) error
-	// OnResponseReceived is called when we receive a response to a request
-	// - nil = continue receiving data
-	// - error = cancel this request
-	OnResponseReceived(chid ChannelID, msg Response) error
-	// OnDataReceive is called when we receive data for the given channel ID
-	// return values are:
-	// - nil = proceed with sending data
-	// - error = cancel this request
-	// - err == ErrPause - pause this request
-	OnDataReceived(chid ChannelID, link ipld.Link, size uint64, index int64, unique bool) error
-
-	// OnDataQueued is called when data is queued for sending for the given channel ID
-	// return values are:
-	// - nil = proceed with sending data
-	// - error = cancel this request
-	// - err == ErrPause - pause this request
-	OnDataQueued(chid ChannelID, link ipld.Link, size uint64, index int64, unique bool) (Message, error)
-
-	// OnDataSent is called when we send data for the given channel ID
-	OnDataSent(chid ChannelID, link ipld.Link, size uint64, index int64, unique bool) error
-
-	// OnTransferQueued is called when a new data transfer request is queued in the transport layer.
-	OnTransferQueued(chid ChannelID)
-
-	// OnRequestReceived is called when we receive a new request to send data
-	// for the given channel ID
-	// return values are:
-	// message = data transfer message along with reply
-	// err = error
-	//   - nil = proceed with sending data
-	//   - error = cancel this request
-	//   - err == ErrPause - pause this request (only for new requests)
-	//   - err == ErrResume - resume this request (only for update requests)
+	// OnTransportEvent is dispatched when an event occurs on the transport.
+	// It MAY be dispatched asynchronously by the transport relative to the time
+	// the event occurs.
+	// However, the other handler functions may ONLY be called on the same channel
+	// after all events are dispatched. In other words, the transport MUST allow
+	// the handler to process all events before calling the other functions which
+	// have a synchronous return
+	OnTransportEvent(chid ChannelID, event TransportEvent)
+
+	// OnRequestReceived occurs when we receive a request for the given channel ID.
+	// Return values are a message to send, or an error if the transport should be closed.
+	// TODO: in a future improvement, a received request should become just a
+	// TransportEvent, and should be handled asynchronously
 	OnRequestReceived(chid ChannelID, msg Request) (Response, error)
 
-	// OnChannelCompleted is called when we finish transferring data for the given channel ID
-	// Error returns are logged but otherwise have no effect
-	OnChannelCompleted(chid ChannelID, err error) error
-
-	// OnRequestCancelled is called when a request we opened (with the given channel Id) to
-	// receive data is cancelled by us.
-	// Error returns are logged but otherwise have no effect
-	OnRequestCancelled(chid ChannelID, err error) error
-
-	// OnRequestDisconnected is called when a network error occurs trying to send a request
-	OnRequestDisconnected(chid ChannelID, err error) error
-
-	// OnSendDataError is called when a network error occurs sending data
-	// at the transport layer
-	OnSendDataError(chid ChannelID, err error) error
-	// OnReceiveDataError is called when a network error occurs receiving data
-	// at the transport layer
-	OnReceiveDataError(chid ChannelID, err error) error
+	// OnResponseReceived occurs when we receive a response to a request.
+	// TODO: in a future improvement, a received response should become just a
+	// TransportEvent, and should be handled asynchronously
+	OnResponseReceived(chid ChannelID, msg Response) error
 
 	// OnContextAugment allows the transport to attach data transfer tracing information
 	// to its local context, in order to create a hierarchical trace
 	OnContextAugment(chid ChannelID) func(context.Context) context.Context
-
-	OnRestartExistingChannelRequestReceived(chid ChannelID) error
 }
 
 /*
@@ -121,12 +149,10 @@ type Transport interface {
 		req Request,
 	) error
 
-	// UpdateChannel sends one or more updates the transport channel at once,
-	// such as pausing/resuming, closing the transfer, or sending additional
-	// messages over the channel. Grouping the commands allows the transport
-	// the ability to plan how to execute these updates based on the capabilities
-	// and API of the underlying transport protocol and library
-	UpdateChannel(ctx context.Context, chid ChannelID, update ChannelUpdate) error
+	// ChannelUpdated notifies the transport that the state of the channel has been updated,
+	// along with an optional message to send over the transport to tell
+	// the other peer about the update
+	ChannelUpdated(ctx context.Context, chid ChannelID, message Message) error
 	// SetEventHandler sets the handler for events on channels
 	SetEventHandler(events EventsHandler) error
 	// CleanupChannel removes any associated data on a closed channel
@@ -151,13 +177,17 @@ type TransportCapabilities struct {
 	Pausable bool
 }
 
-// ChannelUpdate describes updates to a channel - changing it's paused status, closing the transfer,
-// and additional messages to send
-type ChannelUpdate struct {
-	// Paused sets the paused status of the channel.
If pause/resumes are not supported, this is a no op - Paused bool - // Closed sets whether the channel is closed - Closed bool - // SendMessage sends an additional message - SendMessage Message -} +func (TransportOpenedChannel) transportEvent() {} +func (TransportInitiatedTransfer) transportEvent() {} +func (TransportReceivedData) transportEvent() {} +func (TransportSentData) transportEvent() {} +func (TransportQueuedData) transportEvent() {} +func (TransportReachedDataLimit) transportEvent() {} +func (TransportTransferCancelled) transportEvent() {} +func (TransportErrorSendingData) transportEvent() {} +func (TransportErrorReceivingData) transportEvent() {} +func (TransportCompletedTransfer) transportEvent() {} +func (TransportReceivedRestartExistingChannelRequest) transportEvent() {} +func (TransportErrorSendingMessage) transportEvent() {} +func (TransportPaused) transportEvent() {} +func (TransportResumed) transportEvent() {} diff --git a/transport/graphsync/dtchannel/dtchannel.go b/transport/graphsync/dtchannel/dtchannel.go index e0a003be..81a9944f 100644 --- a/transport/graphsync/dtchannel/dtchannel.go +++ b/transport/graphsync/dtchannel/dtchannel.go @@ -33,6 +33,7 @@ const ( // Info needed to keep track of a data transfer channel type Channel struct { + isSender bool channelID datatransfer.ChannelID gs graphsync.GraphExchange @@ -46,7 +47,11 @@ type Channel struct { storeLk sync.RWMutex storeRegistered bool - receivedCidsTotal int64 + receivedIndex int64 + sentIndex int64 + queuedIndex int64 + dataLimit uint64 + progress uint64 } func NewChannel(channelID datatransfer.ChannelID, gs graphsync.GraphExchange) *Channel { @@ -94,8 +99,8 @@ func (c *Channel) Open( } // add do not send cids ext as needed - if c.receivedCidsTotal > 0 { - data := donotsendfirstblocks.EncodeDoNotSendFirstBlocks(c.receivedCidsTotal) + if c.receivedIndex > 0 { + data := donotsendfirstblocks.EncodeDoNotSendFirstBlocks(c.receivedIndex) exts = append(exts, graphsync.ExtensionData{ Name: graphsync.ExtensionsDoNotSendFirstBlocks, Data: data, @@ -118,8 +123,8 @@ func (c *Channel) Open( // Open a new graphsync request msg := fmt.Sprintf("Opening graphsync request to %s for root %s", dataSender, root) - if c.receivedCidsTotal > 0 { - msg += fmt.Sprintf(" with %d Blocks already received", c.receivedCidsTotal) + if c.receivedIndex > 0 { + msg += fmt.Sprintf(" with %d Blocks already received", c.receivedIndex) } log.Info(msg) c.requestID = &requestID @@ -158,7 +163,7 @@ func (c *Channel) GsReqOpened(sender peer.ID, requestID graphsync.RequestID, hoo // gsDataRequestRcvd is called when the transport receives an incoming request // for data. 
-func (c *Channel) GsDataRequestRcvd(sender peer.ID, requestID graphsync.RequestID, pauseRequest bool, hookActions graphsync.IncomingRequestHookActions) {
+func (c *Channel) GsDataRequestRcvd(sender peer.ID, requestID graphsync.RequestID, chst datatransfer.ChannelState, hookActions graphsync.IncomingRequestHookActions) {
 	c.lk.Lock()
 	defer c.lk.Unlock()
 	log.Debugf("%s: received request for data, req_id=%d", c.channelID, requestID)
@@ -184,11 +189,26 @@ func (c *Channel) GsDataRequestRcvd(sender peer.ID, requestID graphsync.RequestI
 	c.requestID = &requestID
 	log.Infow("incoming graphsync request", "peer", sender, "graphsync request id", requestID, "data transfer channel id", c.channelID)
 
-	if pauseRequest {
+	c.state = channelOpen
+
+	err := c.updateFromChannelState(chst)
+	if err != nil {
+		hookActions.TerminateWithError(err)
+		return
+	}
+
+	action := c.actionFromChannelState(chst)
+	switch action {
+	case Pause:
 		c.state = channelPaused
+		hookActions.PauseResponse()
+	case Close:
+		c.state = channelClosed
+		hookActions.TerminateWithError(datatransfer.ErrRejected)
 		return
+	default:
 	}
-	c.state = channelOpen
+
 	hookActions.ValidateRequest()
 }
 
 func (c *Channel) MarkPaused() {
@@ -262,6 +282,113 @@ func (c *Channel) Resume(ctx context.Context, extensions []graphsync.ExtensionDa
 	return c.gs.Unpause(ctx, *c.requestID, extensions...)
 }
 
+type Action string
+
+const (
+	NoAction Action = ""
+	Close    Action = "close"
+	Pause    Action = "pause"
+	Resume   Action = "resume"
+)
+
+// UpdateFromChannelState updates internal graphsync channel state from a datatransfer
+// channel state
+func (c *Channel) UpdateFromChannelState(chst datatransfer.ChannelState) error {
+	c.lk.Lock()
+	defer c.lk.Unlock()
+	return c.updateFromChannelState(chst)
+}
+
+func (c *Channel) updateFromChannelState(chst datatransfer.ChannelState) error {
+	// read the sent value
+	sentNode := chst.SentIndex()
+	if !sentNode.IsNull() {
+		sentIndex, err := sentNode.AsInt()
+		if err != nil {
+			return err
+		}
+		if sentIndex > c.sentIndex {
+			c.sentIndex = sentIndex
+		}
+	}
+
+	// read the received value
+	receivedNode := chst.ReceivedIndex()
+	if !receivedNode.IsNull() {
+		receivedIndex, err := receivedNode.AsInt()
+		if err != nil {
+			return err
+		}
+		if receivedIndex > c.receivedIndex {
+			c.receivedIndex = receivedIndex
+		}
+	}
+
+	// read the queued value
+	queuedNode := chst.QueuedIndex()
+	if !queuedNode.IsNull() {
+		queuedIndex, err := queuedNode.AsInt()
+		if err != nil {
+			return err
+		}
+		if queuedIndex > c.queuedIndex {
+			c.queuedIndex = queuedIndex
+		}
+	}
+
+	// set progress
+	var progress uint64
+	if chst.Sender() == chst.SelfPeer() {
+		progress = chst.Queued()
+	} else {
+		progress = chst.Received()
+	}
+	if progress > c.progress {
+		c.progress = progress
+	}
+
+	// set data limit
+	c.dataLimit = chst.DataLimit()
+	return nil
+}
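[Editor's note: the index nodes read above are, for the graphsync transport, plain IPLD integers; the tests later in this patch construct them with basicnode.NewInt. A minimal sketch of the round trip updateFromChannelState relies on, under that assumption:]

package main

import (
	"fmt"

	basicnode "github.com/ipld/go-ipld-prime/node/basic"
)

func main() {
	// what the transport publishes as an index (e.g. TransportReceivedData{Index: ...})
	idx := basicnode.NewInt(42)

	// what updateFromChannelState reads back before comparing against
	// its own monotonic counter
	n, err := idx.AsInt()
	fmt.Println(n, err, idx.IsNull()) // 42 <nil> false
}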
+// ActionFromChannelState compares internal graphsync channel state with the data transfer
+// state and determines what, if any, action should be taken on graphsync
+func (c *Channel) ActionFromChannelState(chst datatransfer.ChannelState) Action {
+	c.lk.Lock()
+	defer c.lk.Unlock()
+	return c.actionFromChannelState(chst)
+}
+
+func (c *Channel) actionFromChannelState(chst datatransfer.ChannelState) Action {
+	// if the state is closed, and we haven't closed, we need to close
+	if !c.requesterCancelled && c.state != channelClosed && chst.Status().TransferComplete() {
+		return Close
+	}
+
+	// if the transfer is running but we're paused, we need to resume
+	if c.requestID != nil && c.state == channelPaused && !chst.SelfPaused() {
+		return Resume
+	}
+
+	// if the transfer is paused but we're running, we need to pause
+	if c.requestID != nil && c.state == channelOpen && chst.SelfPaused() {
+		return Pause
+	}
+
+	return NoAction
+}
+
+func (c *Channel) ReconcileChannelState(chst datatransfer.ChannelState) (Action, error) {
+	c.lk.Lock()
+	defer c.lk.Unlock()
+	err := c.updateFromChannelState(chst)
+	if err != nil {
+		return NoAction, err
+	}
+	return c.actionFromChannelState(chst), nil
+}
+
 func (c *Channel) MarkTransferComplete() {
 	c.lk.Lock()
 	defer c.lk.Unlock()
@@ -300,12 +427,45 @@ func (c *Channel) UseStore(lsys ipld.LinkSystem) error {
 	return nil
 }
 
-func (c *Channel) UpdateReceivedCidsIfGreater(nextIdx int64) {
+func (c *Channel) UpdateReceivedIndexIfGreater(nextIdx int64) bool {
 	c.lk.Lock()
 	defer c.lk.Unlock()
-	if c.receivedCidsTotal < nextIdx {
-		c.receivedCidsTotal = nextIdx
+	if c.receivedIndex < nextIdx {
+		c.receivedIndex = nextIdx
+		return true
+	}
+	return false
+}
+
+func (c *Channel) UpdateQueuedIndexIfGreater(nextIdx int64) bool {
+	c.lk.Lock()
+	defer c.lk.Unlock()
+	if c.queuedIndex < nextIdx {
+		c.queuedIndex = nextIdx
+		return true
+	}
+	return false
+}
+
+func (c *Channel) UpdateSentIndexIfGreater(nextIdx int64) bool {
+	c.lk.Lock()
+	defer c.lk.Unlock()
+	if c.sentIndex < nextIdx {
+		c.sentIndex = nextIdx
+		return true
+	}
+	return false
+}
+
+func (c *Channel) UpdateProgress(additionalData uint64) bool {
+	c.lk.Lock()
+	defer c.lk.Unlock()
+	c.progress += additionalData
+	reachedLimit := c.dataLimit != 0 && c.progress >= c.dataLimit
+	if reachedLimit {
+		c.state = channelPaused
 	}
+	return reachedLimit
 }
 
 func (c *Channel) Cleanup() {
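[Editor's note: a self-contained toy sketch, not the patch's code, of the two invariants the channel bookkeeping above maintains: indexes only move forward, and crossing a non-zero data limit pauses the channel:]

package main

import "fmt"

type tracker struct {
	receivedIndex int64
	progress      uint64
	dataLimit     uint64
	paused        bool
}

// updateIfGreater mirrors UpdateReceivedIndexIfGreater: stale or duplicate
// notifications never move the index backwards.
func (t *tracker) updateIfGreater(next int64) bool {
	if t.receivedIndex < next {
		t.receivedIndex = next
		return true
	}
	return false
}

// addProgress mirrors UpdateProgress: returns true when the data limit is hit,
// at which point the channel flips to paused.
func (t *tracker) addProgress(n uint64) bool {
	t.progress += n
	reached := t.dataLimit != 0 && t.progress >= t.dataLimit
	if reached {
		t.paused = true
	}
	return reached
}

func main() {
	tr := &tracker{dataLimit: 100}
	fmt.Println(tr.updateIfGreater(5), tr.updateIfGreater(3)) // true false
	fmt.Println(tr.addProgress(60), tr.addProgress(60))       // false true
}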
diff --git a/transport/graphsync/exceptions_test.go b/transport/graphsync/exceptions_test.go
new file mode 100644
index 00000000..c4845721
--- /dev/null
+++ b/transport/graphsync/exceptions_test.go
@@ -0,0 +1,287 @@
+package graphsync_test
+
+import (
+	"context"
+	"errors"
+	"fmt"
+	"testing"
+	"time"
+
+	"github.com/ipfs/go-graphsync"
+	"github.com/ipld/go-ipld-prime/datamodel"
+	cidlink "github.com/ipld/go-ipld-prime/linking/cid"
+	basicnode "github.com/ipld/go-ipld-prime/node/basic"
+	"github.com/stretchr/testify/require"
+	"golang.org/x/exp/rand"
+
+	datatransfer "github.com/filecoin-project/go-data-transfer/v2"
+	"github.com/filecoin-project/go-data-transfer/v2/message"
+	"github.com/filecoin-project/go-data-transfer/v2/testutil"
+	"github.com/filecoin-project/go-data-transfer/v2/transport/graphsync/extension"
+	"github.com/filecoin-project/go-data-transfer/v2/transport/graphsync/testharness"
+)
+
+func TestTransferExceptions(t *testing.T) {
+	ctx := context.Background()
+	testCases := []struct {
+		name       string
+		parameters []testharness.Option
+		test       func(t *testing.T, th *testharness.GsTestHarness)
+	}{
+		{
+			name:       "error executing pull graphsync request",
+			parameters: []testharness.Option{testharness.PullRequest()},
+			test: func(t *testing.T, th *testharness.GsTestHarness) {
+				th.Transport.OpenChannel(th.Ctx, th.Channel, th.NewRequest(t))
+				receivedRequest := th.Fgs.ReceivedRequests[0]
+				close(receivedRequest.ResponseChan)
+				receivedRequest.ResponseErrChan <- errors.New("something went wrong")
+				close(receivedRequest.ResponseErrChan)
+				select {
+				case <-th.CompletedRequests:
+				case <-ctx.Done():
+					t.Fatalf("did not complete request")
+				}
+				th.Events.AssertTransportEventEventually(t, th.Channel.ChannelID(), datatransfer.TransportCompletedTransfer{Success: false, ErrorMessage: fmt.Sprintf("channel %s: graphsync request failed to complete: something went wrong", th.Channel.ChannelID())})
+			},
+		},
+		{
+			name:       "unrecognized outgoing pull request",
+			parameters: []testharness.Option{testharness.PullRequest()},
+			test: func(t *testing.T, th *testharness.GsTestHarness) {
+				// open a channel
+				th.Transport.OpenChannel(th.Ctx, th.Channel, th.NewRequest(t))
+				// configure a store
+				th.Transport.UseStore(th.Channel.ChannelID(), cidlink.DefaultLinkSystem())
+				// set up a separate request with a different request ID but a contained message
+				otherRequest := testharness.NewFakeRequest(graphsync.NewRequestID(), map[graphsync.ExtensionName]datamodel.Node{
+					extension.ExtensionDataTransfer1_1: th.NewRequest(t).ToIPLD(),
+				}, graphsync.RequestTypeNew)
+				// run outgoing request hook on this request
+				th.OutgoingRequestHook(otherRequest)
+				// no channel opened
+				th.Events.RefuteTransportEvent(t, th.Channel.ChannelID(), datatransfer.TransportOpenedChannel{})
+				// no store configuration
+				require.Empty(t, th.OutgoingRequestHookActions.PersistenceOption)
+				// run outgoing request processing listener
+				th.OutgoingRequestProcessingListener(otherRequest)
+				// no transfer initiated event
+				th.Events.RefuteTransportEvent(t, th.Channel.ChannelID(), datatransfer.TransportInitiatedTransfer{})
+				dtResponse := th.Response()
+				// create a response with the wrong request ID
+				otherResponse := testharness.NewFakeResponse(otherRequest.ID(), map[graphsync.ExtensionName]datamodel.Node{
+					extension.ExtensionIncomingRequest1_1: dtResponse.ToIPLD(),
+				}, graphsync.PartialResponse)
+				// run incoming response hook
+				th.IncomingResponseHook(otherResponse)
+				// no response received
+				require.Nil(t, th.Events.ReceivedResponse)
+				// run block hook
+				block := testharness.NewFakeBlockData(12345, 1, true)
+				th.IncomingBlockHook(otherResponse, block)
+				th.Events.RefuteTransportEvent(t, th.Channel.ChannelID(), datatransfer.TransportReceivedData{Size: block.BlockSize(), Index: basicnode.NewInt(block.Index())})
+			},
+		},
+		{
+			name:       "error cancelling on restart request",
+			parameters: []testharness.Option{testharness.PullRequest()},
+			test: func(t *testing.T, th *testharness.GsTestHarness) {
+				// open a channel
+				_ = th.Transport.OpenChannel(th.Ctx, th.Channel, th.NewRequest(t))
+				th.Fgs.ReturnedCancelError = errors.New("something went wrong")
+				err := th.Transport.RestartChannel(th.Ctx, th.Channel, th.RestartRequest(t))
+				require.EqualError(t, err, fmt.Sprintf("%s: restarting graphsync request: cancelling graphsync request for channel %s: %s", th.Channel.ChannelID(), th.Channel.ChannelID(), "something went wrong"))
+			},
+		},
+		{
+			name:       "error reconnecting during restart",
+			parameters: []testharness.Option{testharness.PullRequest()},
+			test: func(t *testing.T, th *testharness.GsTestHarness) {
+				// open a channel
+				_ = th.Transport.OpenChannel(th.Ctx, th.Channel, th.NewRequest(t))
+				expectedErr := errors.New("something went wrong")
+				th.DtNet.ReturnedConnectWithRetryError = expectedErr
+				err := th.Transport.RestartChannel(th.Ctx, th.Channel, th.RestartRequest(t))
+				require.ErrorIs(t, err, expectedErr)
+			},
+		},
+		{
+			name: "unrecognized incoming graphsync request dt response",
+			test: func(t *testing.T, th *testharness.GsTestHarness) {
+				dtResponse := th.Response()
+				requestID := graphsync.NewRequestID()
+				request := testharness.NewFakeRequest(requestID, map[graphsync.ExtensionName]datamodel.Node{extension.ExtensionDataTransfer1_1: dtResponse.ToIPLD()}, graphsync.RequestTypeNew)
+				th.IncomingRequestHook(request)
+				require.False(t,
th.IncomingRequestHookActions.Validated) + require.Error(t, th.IncomingRequestHookActions.TerminationError) + require.Equal(t, th.Events.ReceivedResponse, dtResponse) + }, + }, + { + name: "incoming graphsync request w/ dt response gets OnResponseReceived error", + test: func(t *testing.T, th *testharness.GsTestHarness) { + _ = th.Transport.OpenChannel(th.Ctx, th.Channel, th.NewRequest(t)) + dtResponse := th.Response() + requestID := graphsync.NewRequestID() + request := testharness.NewFakeRequest(requestID, map[graphsync.ExtensionName]datamodel.Node{extension.ExtensionDataTransfer1_1: dtResponse.ToIPLD()}, graphsync.RequestTypeNew) + th.Events.ReturnedResponseReceivedError = errors.New("something went wrong") + th.IncomingRequestHook(request) + require.False(t, th.IncomingRequestHookActions.Validated) + require.EqualError(t, th.IncomingRequestHookActions.TerminationError, "something went wrong") + require.Equal(t, th.Events.ReceivedResponse, dtResponse) + }, + }, + { + name: "pull request cancelled", + parameters: []testharness.Option{testharness.PullRequest()}, + test: func(t *testing.T, th *testharness.GsTestHarness) { + _ = th.Transport.OpenChannel(th.Ctx, th.Channel, th.NewRequest(t)) + require.Len(t, th.Fgs.ReceivedRequests, 1) + receivedRequest := th.Fgs.ReceivedRequests[0] + close(receivedRequest.ResponseChan) + receivedRequest.ResponseErrChan <- graphsync.RequestClientCancelledErr{} + close(receivedRequest.ResponseErrChan) + th.Events.AssertTransportEventEventually(t, th.Channel.ChannelID(), datatransfer.TransportTransferCancelled{ + ErrorMessage: "graphsync request cancelled", + }) + }, + }, + { + name: "error opening sending push message", + test: func(t *testing.T, th *testharness.GsTestHarness) { + th.DtNet.ReturnedSendMessageError = errors.New("something went wrong") + err := th.Transport.OpenChannel(th.Ctx, th.Channel, th.NewRequest(t)) + require.EqualError(t, err, "something went wrong") + }, + }, + { + name: "unrecognized incoming graphsync push request", + test: func(t *testing.T, th *testharness.GsTestHarness) { + // open a channel + th.Transport.OpenChannel(th.Ctx, th.Channel, th.NewRequest(t)) + // configure a store + th.Transport.UseStore(th.Channel.ChannelID(), cidlink.DefaultLinkSystem()) + voucherResult := testutil.NewTestTypedVoucher() + otherRequest := testharness.NewFakeRequest(graphsync.NewRequestID(), map[graphsync.ExtensionName]datamodel.Node{ + extension.ExtensionDataTransfer1_1: message.NewResponse(datatransfer.TransferID(rand.Uint64()), true, false, &voucherResult).ToIPLD(), + }, graphsync.RequestTypeNew) + // run incoming request hook on new request + th.IncomingRequestHook(otherRequest) + // should error + require.Error(t, th.IncomingRequestHookActions.TerminationError) + // run incoming request processing listener + th.IncomingRequestProcessingListener(otherRequest) + // no transfer initiated event + th.Events.RefuteTransportEvent(t, th.Channel.ChannelID(), datatransfer.TransportInitiatedTransfer{}) + // run block queued hook + block := testharness.NewFakeBlockData(12345, 1, true) + th.OutgoingBlockHook(otherRequest, block) + // no block queued event + th.Events.RefuteTransportEvent(t, th.Channel.ChannelID(), datatransfer.TransportQueuedData{Size: block.BlockSize(), Index: basicnode.NewInt(block.Index())}) + // run block sent hook + th.BlockSentListener(otherRequest, block) + // no block sent event + th.Events.RefuteTransportEvent(t, th.Channel.ChannelID(), datatransfer.TransportSentData{Size: block.BlockSize(), Index: basicnode.NewInt(block.Index())}) 
+				// run complete listener
+				th.ResponseCompletedListener(otherRequest, graphsync.RequestCompletedFull)
+				// no complete event
+				th.Events.RefuteTransportEvent(t, th.Channel.ChannelID(), datatransfer.TransportCompletedTransfer{Success: true})
+			},
+		},
+		{
+			name: "channel update on unrecognized channel",
+			test: func(t *testing.T, th *testharness.GsTestHarness) {
+				err := th.Transport.ChannelUpdated(th.Ctx, th.Channel.ChannelID(), th.NewRequest(t))
+				require.Error(t, err)
+			},
+		},
+		{
+			name:       "incoming request errors in OnRequestReceived",
+			parameters: []testharness.Option{testharness.PullRequest(), testharness.Responder()},
+			test: func(t *testing.T, th *testharness.GsTestHarness) {
+				th.Transport.UseStore(th.Channel.ChannelID(), cidlink.DefaultLinkSystem())
+				th.Fgs.AssertHasPersistenceOption(t, fmt.Sprintf("data-transfer-%s", th.Channel.ChannelID().String()))
+				requestID := graphsync.NewRequestID()
+				dtRequest := th.NewRequest(t)
+				request := testharness.NewFakeRequest(requestID, map[graphsync.ExtensionName]datamodel.Node{extension.ExtensionDataTransfer1_1: dtRequest.ToIPLD()}, graphsync.RequestTypeNew)
+				voucherResult := testutil.NewTestTypedVoucher()
+				dtResponse := message.NewResponse(th.Channel.TransferID(), false, false, &voucherResult)
+				th.Events.ReturnedRequestReceivedResponse = dtResponse
+				th.Events.ReturnedRequestReceivedError = errors.New("something went wrong")
+				th.Events.ReturnedOnContextAugmentFunc = func(ctx context.Context) context.Context {
+					return context.WithValue(ctx, ctxKey{}, "applesauce")
+				}
+				th.IncomingRequestHook(request)
+				require.Equal(t, dtRequest, th.Events.ReceivedRequest)
+				require.Empty(t, th.DtNet.ProtectedPeers)
+				require.Empty(t, th.IncomingRequestHookActions.PersistenceOption)
+				require.False(t, th.IncomingRequestHookActions.Validated)
+				require.False(t, th.IncomingRequestHookActions.Paused)
+				require.EqualError(t, th.IncomingRequestHookActions.TerminationError, "something went wrong")
+				sentResponse := th.IncomingRequestHookActions.DTMessage(t)
+				require.Equal(t, dtResponse, sentResponse)
+				th.IncomingRequestHookActions.RefuteAugmentedContextKey(t, ctxKey{})
+			},
+		},
+		{
+			name:       "incoming gs request with contained push request errors",
+			parameters: []testharness.Option{testharness.Responder()},
+			test: func(t *testing.T, th *testharness.GsTestHarness) {
+				requestID := graphsync.NewRequestID()
+				dtRequest := th.NewRequest(t)
+				request := testharness.NewFakeRequest(requestID, map[graphsync.ExtensionName]datamodel.Node{extension.ExtensionDataTransfer1_1: dtRequest.ToIPLD()}, graphsync.RequestTypeNew)
+				dtResponse := th.Response()
+				th.Events.ReturnedRequestReceivedResponse = dtResponse
+				th.IncomingRequestHook(request)
+				require.EqualError(t, th.IncomingRequestHookActions.TerminationError, datatransfer.ErrUnsupported.Error())
+			},
+		},
+		{
+			name:       "incoming request completes with error code for graphsync",
+			parameters: []testharness.Option{testharness.PullRequest(), testharness.Responder()},
+			test: func(t *testing.T, th *testharness.GsTestHarness) {
+				requestID := graphsync.NewRequestID()
+				dtRequest := th.NewRequest(t)
+				request := testharness.NewFakeRequest(requestID, map[graphsync.ExtensionName]datamodel.Node{extension.ExtensionDataTransfer1_1: dtRequest.ToIPLD()}, graphsync.RequestTypeNew)
+				dtResponse := th.Response()
+				th.Events.ReturnedRequestReceivedResponse = dtResponse
+				th.IncomingRequestHook(request)
+
+				th.ResponseCompletedListener(request, graphsync.RequestFailedUnknown)
+				select {
+				case <-th.CompletedResponses:
+				case <-ctx.Done():
+					t.Fatalf("did not complete request")
+				}
+				th.Events.AssertTransportEvent(t, th.Channel.ChannelID(), datatransfer.TransportCompletedTransfer{Success: false, ErrorMessage: fmt.Sprintf("graphsync response to peer %s did not complete: response status code %s", th.Channel.Recipient(), graphsync.RequestFailedUnknown.String())})
+
+			},
+		},
+		{
+			name:       "incoming push request message errors in OnRequestReceived",
+			parameters: []testharness.Option{testharness.Responder()},
+			test: func(t *testing.T, th *testharness.GsTestHarness) {
+				th.Transport.UseStore(th.Channel.ChannelID(), cidlink.DefaultLinkSystem())
+				th.Fgs.AssertHasPersistenceOption(t, fmt.Sprintf("data-transfer-%s", th.Channel.ChannelID().String()))
+				voucherResult := testutil.NewTestTypedVoucher()
+				dtResponse := message.NewResponse(th.Channel.TransferID(), false, false, &voucherResult)
+				th.Events.ReturnedRequestReceivedResponse = dtResponse
+				th.Events.ReturnedRequestReceivedError = errors.New("something went wrong")
+				th.DtNet.Delegates[0].Receiver.ReceiveRequest(ctx, th.Channel.OtherPeer(), th.NewRequest(t))
+				require.Equal(t, th.NewRequest(t), th.Events.ReceivedRequest)
+				require.Empty(t, th.DtNet.ProtectedPeers)
+				require.Empty(t, th.Fgs.ReceivedRequests)
+				require.Len(t, th.DtNet.SentMessages, 1)
+				require.Equal(t, testharness.FakeSentMessage{Message: dtResponse, TransportID: "graphsync", PeerID: th.Channel.OtherPeer()}, th.DtNet.SentMessages[0])
+			},
+		},
+	}
+	for _, testCase := range testCases {
+		t.Run(testCase.name, func(t *testing.T) {
+			ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
+			defer cancel()
+			th := testharness.SetupHarness(ctx, testCase.parameters...)
+			testCase.test(t, th)
+		})
+	}
+}
diff --git a/transport/graphsync/executor/executor.go b/transport/graphsync/executor/executor.go
index 0bc23e4f..b0a8b4c0 100644
--- a/transport/graphsync/executor/executor.go
+++ b/transport/graphsync/executor/executor.go
@@ -13,8 +13,7 @@ var log = logging.Logger("dt_graphsync")
 
 // EventsHandler are the data transfer events that can be dispatched by the executor
 type EventsHandler interface {
-	OnRequestCancelled(datatransfer.ChannelID, error) error
-	OnChannelCompleted(datatransfer.ChannelID, error) error
+	OnTransportEvent(datatransfer.ChannelID, datatransfer.TransportEvent)
 }
 
 // Executor handles consuming channels on an outgoing GraphSync request
@@ -74,9 +73,7 @@ func (e *Executor) executeRequest(
 	if _, ok := lastError.(graphsync.RequestClientCancelledErr); ok {
 		terr := fmt.Errorf("graphsync request cancelled")
 		log.Warnf("channel %s: %s", e.channelID, terr)
-		if err := events.OnRequestCancelled(e.channelID, terr); err != nil {
-			log.Error(err)
-		}
+		events.OnTransportEvent(e.channelID, datatransfer.TransportTransferCancelled{ErrorMessage: terr.Error()})
 		return
 	}
 
@@ -102,8 +99,11 @@ func (e *Executor) executeRequest(
 	if completedRequestListener != nil {
 		completedRequestListener(e.channelID)
 	}
-	err := events.OnChannelCompleted(e.channelID, completeErr)
-	if err != nil {
-		log.Errorf("channel %s: processing OnChannelCompleted: %s", e.channelID, err)
+
+	if completeErr == nil {
+		events.OnTransportEvent(e.channelID, datatransfer.TransportCompletedTransfer{Success: true})
+	} else {
+		events.OnTransportEvent(e.channelID, datatransfer.TransportCompletedTransfer{Success: false, ErrorMessage: completeErr.Error()})
 	}
+
 }
diff --git a/transport/graphsync/executor/executor_test.go b/transport/graphsync/executor/executor_test.go
index 9e5523c1..089a151d 100644
--- a/transport/graphsync/executor/executor_test.go
+++ b/transport/graphsync/executor/executor_test.go @@ -34,7 +34,7 @@ func TestExecutor(t *testing.T) { responseErrors: []error{errors.New("something went wrong")}, expectedEventRecord: fakeEvents{ completedChannel: chid, - completedError: fmt.Errorf("channel %s: graphsync request failed to complete: %w", chid, errors.New("something went wrong")), + completedError: fmt.Errorf("channel %s: graphsync request failed to complete: %s", chid, errors.New("something went wrong")), }, }, "client cancelled request error, no listener": { @@ -119,16 +119,17 @@ type fakeEvents struct { cancelledErr error } -func (fe *fakeEvents) OnChannelCompleted(chid datatransfer.ChannelID, err error) error { - fe.completedChannel = chid - fe.completedError = err - return nil -} - -func (fe *fakeEvents) OnRequestCancelled(chid datatransfer.ChannelID, err error) error { - fe.cancelledChannel = chid - fe.cancelledErr = err - return nil +func (fe *fakeEvents) OnTransportEvent(chid datatransfer.ChannelID, transportEvent datatransfer.TransportEvent) { + switch evt := transportEvent.(type) { + case datatransfer.TransportCompletedTransfer: + fe.completedChannel = chid + if !evt.Success { + fe.completedError = errors.New(evt.ErrorMessage) + } + case datatransfer.TransportTransferCancelled: + fe.cancelledChannel = chid + fe.cancelledErr = errors.New(evt.ErrorMessage) + } } type fakeCompletedRequestListener struct { diff --git a/transport/graphsync/graphsync.go b/transport/graphsync/graphsync.go index 07fe307e..41d1e085 100644 --- a/transport/graphsync/graphsync.go +++ b/transport/graphsync/graphsync.go @@ -36,12 +36,10 @@ var defaultSupportedExtensions = []graphsync.ExtensionName{ var incomingReqExtensions = []graphsync.ExtensionName{ extension.ExtensionIncomingRequest1_1, - extension.ExtensionDataTransfer1_1, } var outgoingBlkExtensions = []graphsync.ExtensionName{ extension.ExtensionOutgoingBlock1_1, - extension.ExtensionDataTransfer1_1, } // Option is an option for setting up the graphsync transport @@ -91,11 +89,11 @@ type Transport struct { } // NewTransport makes a new hooks manager with the given hook events interface -func NewTransport(peerID peer.ID, gs graphsync.GraphExchange, dtNet network.DataTransferNetwork, options ...Option) *Transport { +func NewTransport(gs graphsync.GraphExchange, dtNet network.DataTransferNetwork, options ...Option) *Transport { t := &Transport{ gs: gs, dtNet: dtNet, - peerID: peerID, + peerID: dtNet.ID(), supportedExtensions: defaultSupportedExtensions, dtChannels: make(map[datatransfer.ChannelID]*dtchannel.Channel), requestIDToChannelID: newRequestIDToChannelIDMap(), @@ -128,6 +126,7 @@ func (t *Transport) OpenChannel( channel datatransfer.Channel, req datatransfer.Request) error { t.dtNet.Protect(channel.OtherPeer(), channel.ChannelID().String()) + t.trackDTChannel(channel.ChannelID()) if channel.IsPull() { return t.openRequest(ctx, channel.Sender(), @@ -157,8 +156,13 @@ func (t *Transport) RestartChannel( t.dtNet.Protect(channelState.OtherPeer(), channelState.ChannelID().String()) ch := t.trackDTChannel(channelState.ChannelID()) - ch.UpdateReceivedCidsIfGreater(channelState.ReceivedCidsTotal()) + err = ch.UpdateFromChannelState(channelState) + if err != nil { + return err + } + if channelState.IsPull() { + return t.openRequest(ctx, channelState.Sender(), channelState.ChannelID(), @@ -203,49 +207,61 @@ func (t *Transport) openRequest( return nil } -// UpdateChannel sends one or more updates the transport channel at once, -// such as pausing/resuming, closing the transfer, or sending 
additional
-// messages over the channel. Grouping the commands allows the transport
-// the ability to plan how to execute these updates
-func (t *Transport) UpdateChannel(ctx context.Context, chid datatransfer.ChannelID, update datatransfer.ChannelUpdate) error {
-
+func (t *Transport) reconcileChannelStates(ctx context.Context, chid datatransfer.ChannelID) (*dtchannel.Channel, dtchannel.Action, error) {
+	chst, err := t.events.ChannelState(ctx, chid)
+	if err != nil {
+		return nil, dtchannel.NoAction, err
+	}
 	ch, err := t.getDTChannel(chid)
 	if err != nil {
-		if update.SendMessage != nil && !update.Closed {
-			return t.dtNet.SendMessage(ctx, t.otherPeer(chid), transportID, update.SendMessage)
+		return nil, dtchannel.NoAction, err
+	}
+	action, err := ch.ReconcileChannelState(chst)
+	return ch, action, err
+}
+
+// ChannelUpdated notifies the transport that the state of the channel has been updated,
+// along with an optional message to send over the transport to tell
+// the other peer about the update
+func (t *Transport) ChannelUpdated(ctx context.Context, chid datatransfer.ChannelID, message datatransfer.Message) error {
+	ch, action, err := t.reconcileChannelStates(ctx, chid)
+	if err != nil {
+		if message != nil {
+			if sendErr := t.dtNet.SendMessage(ctx, t.otherPeer(chid), transportID, message); sendErr != nil {
+				return sendErr
+			}
 		}
 		return err
 	}
+	return t.processAction(ctx, chid, ch, action, message)
+}
 
-	if !update.Paused && ch.Paused() {
-
+func (t *Transport) processAction(ctx context.Context, chid datatransfer.ChannelID, ch *dtchannel.Channel, action dtchannel.Action, message datatransfer.Message) error {
+	if action == dtchannel.Resume {
 		var extensions []graphsync.ExtensionData
-		if update.SendMessage != nil {
+		if message != nil {
 			var err error
-			extensions, err = extension.ToExtensionData(update.SendMessage, t.supportedExtensions)
+			extensions, err = extension.ToExtensionData(message, t.supportedExtensions)
 			if err != nil {
 				return err
 			}
 		}
-
 		return ch.Resume(ctx, extensions)
 	}
 
-	if update.SendMessage != nil {
-		if err := t.dtNet.SendMessage(ctx, t.otherPeer(chid), transportID, update.SendMessage); err != nil {
+	if message != nil {
+		if err := t.dtNet.SendMessage(ctx, t.otherPeer(chid), transportID, message); err != nil {
 			return err
 		}
 	}
-
-	if update.Closed {
+	switch action {
+	case dtchannel.Close:
 		return ch.Close(ctx)
-	}
-
-	if update.Paused && !ch.Paused() {
+	case dtchannel.Pause:
 		return ch.Pause(ctx)
+	default:
+		return nil
 	}
-
-	return nil
 }
 
 // SendMessage sends a data transfer message over the channel to the other peer
@@ -284,15 +300,16 @@ func (t *Transport) SetEventHandler(events datatransfer.EventsHandler) error {
 	}
 	t.events = events
 
-	t.unregisterFuncs = append(t.unregisterFuncs, t.gs.RegisterIncomingRequestQueuedHook(t.gsReqQueuedHook))
+	t.unregisterFuncs = append(t.unregisterFuncs, t.gs.RegisterIncomingRequestProcessingListener(t.gsRequestProcessingListener))
+	t.unregisterFuncs = append(t.unregisterFuncs, t.gs.RegisterOutgoingRequestProcessingListener(t.gsRequestProcessingListener))
 	t.unregisterFuncs = append(t.unregisterFuncs, t.gs.RegisterIncomingRequestHook(t.gsReqRecdHook))
 	t.unregisterFuncs = append(t.unregisterFuncs, t.gs.RegisterCompletedResponseListener(t.gsCompletedResponseListener))
 	t.unregisterFuncs = append(t.unregisterFuncs, t.gs.RegisterIncomingBlockHook(t.gsIncomingBlockHook))
 	t.unregisterFuncs = append(t.unregisterFuncs, t.gs.RegisterOutgoingBlockHook(t.gsOutgoingBlockHook))
-	t.unregisterFuncs = append(t.unregisterFuncs,
t.gs.RegisterBlockSentListener(t.gsBlockSentHook)) + t.unregisterFuncs = append(t.unregisterFuncs, t.gs.RegisterBlockSentListener(t.gsBlockSentListener)) t.unregisterFuncs = append(t.unregisterFuncs, t.gs.RegisterOutgoingRequestHook(t.gsOutgoingRequestHook)) t.unregisterFuncs = append(t.unregisterFuncs, t.gs.RegisterIncomingResponseHook(t.gsIncomingResponseHook)) - t.unregisterFuncs = append(t.unregisterFuncs, t.gs.RegisterRequestUpdatedHook(t.gsRequestUpdatedHook)) + //t.unregisterFuncs = append(t.unregisterFuncs, t.gs.RegisterRequestUpdatedHook(t.gsRequestUpdatedHook)) t.unregisterFuncs = append(t.unregisterFuncs, t.gs.RegisterRequestorCancelledListener(t.gsRequestorCancelledListener)) t.unregisterFuncs = append(t.unregisterFuncs, t.gs.RegisterNetworkErrorListener(t.gsNetworkSendErrorListener)) t.unregisterFuncs = append(t.unregisterFuncs, t.gs.RegisterReceiverNetworkErrorListener(t.gsNetworkReceiveErrorListener)) diff --git a/transport/graphsync/graphsync_test.go b/transport/graphsync/graphsync_test.go deleted file mode 100644 index 5eccc205..00000000 --- a/transport/graphsync/graphsync_test.go +++ /dev/null @@ -1,1362 +0,0 @@ -package graphsync_test - -/* -import ( - "context" - "errors" - "io" - "math/rand" - "testing" - "time" - "github.com/ipfs/go-graphsync" - "github.com/ipfs/go-graphsync/donotsendfirstblocks" - "github.com/ipld/go-ipld-prime" - "github.com/ipld/go-ipld-prime/datamodel" - cidlink "github.com/ipld/go-ipld-prime/linking/cid" - "github.com/ipld/go-ipld-prime/node/basicnode" - peer "github.com/libp2p/go-libp2p-core/peer" - "github.com/libp2p/go-libp2p-core/protocol" - "github.com/stretchr/testify/require" - datatransfer "github.com/filecoin-project/go-data-transfer/v2" - "github.com/filecoin-project/go-data-transfer/v2/message" - "github.com/filecoin-project/go-data-transfer/v2/testutil" - . 
"github.com/filecoin-project/go-data-transfer/v2/transport/graphsync" - "github.com/filecoin-project/go-data-transfer/v2/transport/graphsync/extension" -) - -func TestManager(t *testing.T) { - testCases := map[string]struct { - requestConfig gsRequestConfig - responseConfig gsResponseConfig - updatedConfig gsRequestConfig - events fakeEvents - action func(gsData *harness) - check func(t *testing.T, events *fakeEvents, gsData *harness) - protocol protocol.ID - }{ - "gs outgoing request with recognized dt pull channel will record incoming blocks": { - action: func(gsData *harness) { - gsData.outgoingRequestHook() - gsData.incomingBlockHook() - }, - check: func(t *testing.T, events *fakeEvents, gsData *harness) { - require.Equal(t, events.ChannelOpenedChannelID, datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.other, Initiator: gsData.self}) - require.True(t, events.OnDataReceivedCalled) - require.NoError(t, gsData.incomingBlockHookActions.TerminationError) - }, - }, - "gs outgoing request with recognized dt push channel will record incoming blocks": { - requestConfig: gsRequestConfig{ - dtIsResponse: true, - }, - action: func(gsData *harness) { - gsData.outgoingRequestHook() - gsData.incomingBlockHook() - }, - check: func(t *testing.T, events *fakeEvents, gsData *harness) { - require.Equal(t, events.ChannelOpenedChannelID, datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.self, Initiator: gsData.other}) - require.True(t, events.OnDataReceivedCalled) - require.NoError(t, gsData.incomingBlockHookActions.TerminationError) - }, - }, - "non-data-transfer gs request will not record incoming blocks and send updates": { - requestConfig: gsRequestConfig{ - dtExtensionMissing: true, - }, - action: func(gsData *harness) { - gsData.outgoingRequestHook() - gsData.incomingBlockHook() - }, - check: func(t *testing.T, events *fakeEvents, gsData *harness) { - require.Equal(t, events.ChannelOpenedChannelID, datatransfer.ChannelID{}) - require.False(t, events.OnDataReceivedCalled) - require.NoError(t, gsData.incomingBlockHookActions.TerminationError) - }, - }, - "gs request unrecognized opened channel will not record incoming blocks": { - events: fakeEvents{ - OnChannelOpenedError: errors.New("Not recognized"), - }, - action: func(gsData *harness) { - gsData.outgoingRequestHook() - gsData.incomingBlockHook() - }, - check: func(t *testing.T, events *fakeEvents, gsData *harness) { - require.Equal(t, events.ChannelOpenedChannelID, datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.other, Initiator: gsData.self}) - require.False(t, events.OnDataReceivedCalled) - require.NoError(t, gsData.incomingBlockHookActions.TerminationError) - }, - }, - "gs incoming block with data receive error will halt request": { - events: fakeEvents{ - OnDataReceivedError: errors.New("something went wrong"), - }, - action: func(gsData *harness) { - gsData.outgoingRequestHook() - gsData.incomingBlockHook() - }, - check: func(t *testing.T, events *fakeEvents, gsData *harness) { - require.Equal(t, events.ChannelOpenedChannelID, datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.other, Initiator: gsData.self}) - require.True(t, events.OnDataReceivedCalled) - require.Error(t, gsData.incomingBlockHookActions.TerminationError) - }, - }, - "outgoing gs request with recognized dt request can receive gs response": { - responseConfig: gsResponseConfig{ - dtIsResponse: true, - }, - action: func(gsData *harness) { - gsData.outgoingRequestHook() - gsData.incomingResponseHOok() - }, - check: func(t 
*testing.T, events *fakeEvents, gsData *harness) { - require.Equal(t, events.ChannelOpenedChannelID, datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.other, Initiator: gsData.self}) - require.Equal(t, 0, events.OnRequestReceivedCallCount) - require.Equal(t, 1, events.OnResponseReceivedCallCount) - require.NoError(t, gsData.incomingResponseHookActions.TerminationError) - }, - }, - "outgoing gs request with recognized dt request cannot receive gs response with dt request": { - action: func(gsData *harness) { - gsData.outgoingRequestHook() - gsData.incomingResponseHOok() - }, - check: func(t *testing.T, events *fakeEvents, gsData *harness) { - require.Equal(t, events.ChannelOpenedChannelID, datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.other, Initiator: gsData.self}) - require.Equal(t, 0, events.OnRequestReceivedCallCount) - require.Equal(t, 0, events.OnResponseReceivedCallCount) - require.Error(t, gsData.incomingResponseHookActions.TerminationError) - }, - }, - "outgoing gs request with recognized dt response can receive gs response": { - requestConfig: gsRequestConfig{ - dtIsResponse: true, - }, - action: func(gsData *harness) { - gsData.outgoingRequestHook() - gsData.incomingResponseHOok() - }, - check: func(t *testing.T, events *fakeEvents, gsData *harness) { - require.Equal(t, events.ChannelOpenedChannelID, datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.self, Initiator: gsData.other}) - require.Equal(t, 1, events.OnRequestReceivedCallCount) - require.Equal(t, 0, events.OnResponseReceivedCallCount) - require.NoError(t, gsData.incomingResponseHookActions.TerminationError) - }, - }, - "outgoing gs request with recognized dt response cannot receive gs response with dt response": { - requestConfig: gsRequestConfig{ - dtIsResponse: true, - }, - responseConfig: gsResponseConfig{ - dtIsResponse: true, - }, - action: func(gsData *harness) { - gsData.outgoingRequestHook() - gsData.incomingResponseHOok() - }, - check: func(t *testing.T, events *fakeEvents, gsData *harness) { - require.Equal(t, events.ChannelOpenedChannelID, datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.self, Initiator: gsData.other}) - require.Equal(t, 0, events.OnResponseReceivedCallCount) - require.Equal(t, 0, events.OnRequestReceivedCallCount) - require.Error(t, gsData.incomingResponseHookActions.TerminationError) - }, - }, - "outgoing gs request with recognized dt request will error with malformed update": { - responseConfig: gsResponseConfig{ - dtExtensionMalformed: true, - }, - action: func(gsData *harness) { - gsData.outgoingRequestHook() - gsData.incomingResponseHOok() - }, - check: func(t *testing.T, events *fakeEvents, gsData *harness) { - require.Equal(t, events.ChannelOpenedChannelID, datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.other, Initiator: gsData.self}) - require.Equal(t, 0, events.OnRequestReceivedCallCount) - require.Equal(t, 0, events.OnResponseReceivedCallCount) - require.Error(t, gsData.incomingResponseHookActions.TerminationError) - }, - }, - "outgoing gs request with recognized dt request will ignore non-data-transfer update": { - responseConfig: gsResponseConfig{ - dtExtensionMissing: true, - }, - action: func(gsData *harness) { - gsData.outgoingRequestHook() - gsData.incomingResponseHOok() - }, - check: func(t *testing.T, events *fakeEvents, gsData *harness) { - require.Equal(t, events.ChannelOpenedChannelID, datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.other, Initiator: gsData.self}) - 
require.Equal(t, 0, events.OnRequestReceivedCallCount) - require.Equal(t, 0, events.OnResponseReceivedCallCount) - require.NoError(t, gsData.incomingResponseHookActions.TerminationError) - }, - }, - "outgoing gs request with recognized dt response can send message on update": { - events: fakeEvents{ - RequestReceivedResponse: testutil.NewDTResponse(t, datatransfer.TransferID(rand.Uint32())), - }, - requestConfig: gsRequestConfig{ - dtIsResponse: true, - }, - action: func(gsData *harness) { - gsData.outgoingRequestHook() - gsData.incomingResponseHOok() - }, - check: func(t *testing.T, events *fakeEvents, gsData *harness) { - require.Equal(t, events.ChannelOpenedChannelID, datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.self, Initiator: gsData.other}) - require.Equal(t, 1, events.OnRequestReceivedCallCount) - require.Equal(t, 0, events.OnResponseReceivedCallCount) - require.NoError(t, gsData.incomingResponseHookActions.TerminationError) - assertHasOutgoingMessage(t, gsData.incomingResponseHookActions.SentExtensions, - events.RequestReceivedResponse) - }, - }, - "outgoing gs request with recognized dt response err will error": { - requestConfig: gsRequestConfig{ - dtIsResponse: true, - }, - events: fakeEvents{ - OnRequestReceivedErrors: []error{errors.New("something went wrong")}, - }, - action: func(gsData *harness) { - gsData.outgoingRequestHook() - gsData.incomingResponseHOok() - }, - check: func(t *testing.T, events *fakeEvents, gsData *harness) { - require.Equal(t, 0, events.OnResponseReceivedCallCount) - require.Equal(t, 1, events.OnRequestReceivedCallCount) - require.Error(t, gsData.incomingResponseHookActions.TerminationError) - }, - }, - "incoming gs request with recognized dt request will validate gs request & send dt response": { - action: func(gsData *harness) { - gsData.incomingRequestHook() - }, - events: fakeEvents{ - RequestReceivedResponse: testutil.NewDTResponse(t, datatransfer.TransferID(rand.Uint32())), - }, - check: func(t *testing.T, events *fakeEvents, gsData *harness) { - require.Equal(t, 1, events.OnRequestReceivedCallCount) - require.Equal(t, 0, events.OnResponseReceivedCallCount) - require.Equal(t, events.RequestReceivedChannelID, datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.self, Initiator: gsData.other}) - dtRequestData, _ := gsData.request.Extension(extension.ExtensionDataTransfer1_1) - assertDecodesToMessage(t, dtRequestData, events.RequestReceivedRequest) - require.True(t, gsData.incomingRequestHookActions.Validated) - assertHasExtensionMessage(t, extension.ExtensionDataTransfer1_1, gsData.incomingRequestHookActions.SentExtensions, events.RequestReceivedResponse) - require.NoError(t, gsData.incomingRequestHookActions.TerminationError) - - channelsForPeer := gsData.transport.ChannelsForPeer(gsData.other) - require.Equal(t, channelsForPeer, ChannelsForPeer{ - SendingChannels: map[datatransfer.ChannelID]ChannelGraphsyncRequests{ - events.RequestReceivedChannelID: { - Current: gsData.request.ID(), - }, - }, - ReceivingChannels: map[datatransfer.ChannelID]ChannelGraphsyncRequests{}, - }) - }, - }, - "incoming gs request with recognized dt response will validate gs request": { - requestConfig: gsRequestConfig{ - dtIsResponse: true, - }, - action: func(gsData *harness) { - gsData.incomingRequestHook() - }, - check: func(t *testing.T, events *fakeEvents, gsData *harness) { - require.Equal(t, 0, events.OnRequestReceivedCallCount) - require.Equal(t, 1, events.OnResponseReceivedCallCount) - require.Equal(t, 
events.ResponseReceivedChannelID, datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.other, Initiator: gsData.self}) - dtResponseData, _ := gsData.request.Extension(extension.ExtensionDataTransfer1_1) - assertDecodesToMessage(t, dtResponseData, events.ResponseReceivedResponse) - require.True(t, gsData.incomingRequestHookActions.Validated) - require.NoError(t, gsData.incomingRequestHookActions.TerminationError) - }, - }, - "malformed data transfer extension on incoming request will terminate": { - requestConfig: gsRequestConfig{ - dtExtensionMalformed: true, - }, - action: func(gsData *harness) { - gsData.incomingRequestHook() - }, - check: func(t *testing.T, events *fakeEvents, gsData *harness) { - require.Equal(t, 0, events.OnRequestReceivedCallCount) - require.False(t, gsData.incomingRequestHookActions.Validated) - require.Error(t, gsData.incomingRequestHookActions.TerminationError) - }, - }, - "unrecognized incoming dt request will terminate but send response": { - events: fakeEvents{ - RequestReceivedResponse: testutil.NewDTResponse(t, datatransfer.TransferID(rand.Uint32())), - OnRequestReceivedErrors: []error{errors.New("something went wrong")}, - }, - action: func(gsData *harness) { - gsData.incomingRequestHook() - }, - check: func(t *testing.T, events *fakeEvents, gsData *harness) { - require.Equal(t, 1, events.OnRequestReceivedCallCount) - require.Equal(t, 0, events.OnResponseReceivedCallCount) - require.Equal(t, events.RequestReceivedChannelID, datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.self, Initiator: gsData.other}) - dtRequestData, _ := gsData.request.Extension(extension.ExtensionDataTransfer1_1) - assertDecodesToMessage(t, dtRequestData, events.RequestReceivedRequest) - require.False(t, gsData.incomingRequestHookActions.Validated) - assertHasExtensionMessage(t, extension.ExtensionIncomingRequest1_1, gsData.incomingRequestHookActions.SentExtensions, events.RequestReceivedResponse) - require.Error(t, gsData.incomingRequestHookActions.TerminationError) - }, - }, - "incoming gs request with recognized dt request will record outgoing blocks": { - action: func(gsData *harness) { - gsData.incomingRequestHook() - gsData.outgoingBlockHook() - }, - check: func(t *testing.T, events *fakeEvents, gsData *harness) { - require.Equal(t, 1, events.OnRequestReceivedCallCount) - require.True(t, events.OnDataQueuedCalled) - require.NoError(t, gsData.outgoingBlockHookActions.TerminationError) - }, - }, - - "incoming gs request with recognized dt response will record outgoing blocks": { - requestConfig: gsRequestConfig{ - dtIsResponse: true, - }, - action: func(gsData *harness) { - gsData.incomingRequestHook() - gsData.outgoingBlockHook() - }, - check: func(t *testing.T, events *fakeEvents, gsData *harness) { - require.Equal(t, 1, events.OnResponseReceivedCallCount) - require.True(t, events.OnDataQueuedCalled) - require.NoError(t, gsData.outgoingBlockHookActions.TerminationError) - }, - }, - "non-data-transfer request will not record outgoing blocks": { - requestConfig: gsRequestConfig{ - dtExtensionMissing: true, - }, - action: func(gsData *harness) { - gsData.incomingRequestHook() - gsData.outgoingBlockHook() - }, - check: func(t *testing.T, events *fakeEvents, gsData *harness) { - require.Equal(t, 0, events.OnRequestReceivedCallCount) - require.False(t, events.OnDataQueuedCalled) - }, - }, - "outgoing data queued error will terminate request": { - events: fakeEvents{ - OnDataQueuedError: errors.New("something went wrong"), - }, - action: func(gsData *harness) { 
- gsData.incomingRequestHook() - gsData.outgoingBlockHook() - }, - check: func(t *testing.T, events *fakeEvents, gsData *harness) { - require.Equal(t, 1, events.OnRequestReceivedCallCount) - require.True(t, events.OnDataQueuedCalled) - require.Error(t, gsData.outgoingBlockHookActions.TerminationError) - }, - }, - "outgoing data queued error == pause will pause request": { - events: fakeEvents{ - OnDataQueuedError: datatransfer.ErrPause, - }, - action: func(gsData *harness) { - gsData.incomingRequestHook() - gsData.outgoingBlockHook() - }, - check: func(t *testing.T, events *fakeEvents, gsData *harness) { - require.Equal(t, 1, events.OnRequestReceivedCallCount) - require.True(t, events.OnDataQueuedCalled) - require.True(t, gsData.outgoingBlockHookActions.Paused) - require.NoError(t, gsData.outgoingBlockHookActions.TerminationError) - }, - }, - "incoming gs request with recognized dt request will send updates": { - action: func(gsData *harness) { - gsData.incomingRequestHook() - gsData.outgoingBlockHook() - }, - events: fakeEvents{ - OnDataQueuedMessage: testutil.NewDTResponse(t, datatransfer.TransferID(rand.Uint32())), - }, - check: func(t *testing.T, events *fakeEvents, gsData *harness) { - require.Equal(t, 1, events.OnRequestReceivedCallCount) - require.True(t, events.OnDataQueuedCalled) - require.NoError(t, gsData.outgoingBlockHookActions.TerminationError) - assertHasExtensionMessage(t, extension.ExtensionOutgoingBlock1_1, gsData.outgoingBlockHookActions.SentExtensions, - events.OnDataQueuedMessage) - }, - }, - "incoming gs request with recognized dt request can receive update": { - action: func(gsData *harness) { - gsData.incomingRequestHook() - gsData.requestUpdatedHook() - }, - check: func(t *testing.T, events *fakeEvents, gsData *harness) { - require.Equal(t, 2, events.OnRequestReceivedCallCount) - require.NoError(t, gsData.requestUpdatedHookActions.TerminationError) - }, - }, - "incoming gs request with recognized dt request cannot receive update with dt response": { - updatedConfig: gsRequestConfig{ - dtIsResponse: true, - }, - action: func(gsData *harness) { - gsData.incomingRequestHook() - gsData.requestUpdatedHook() - }, - check: func(t *testing.T, events *fakeEvents, gsData *harness) { - require.Equal(t, 1, events.OnRequestReceivedCallCount) - require.Equal(t, 0, events.OnResponseReceivedCallCount) - require.Error(t, gsData.requestUpdatedHookActions.TerminationError) - }, - }, - "incoming gs request with recognized dt response can receive update": { - requestConfig: gsRequestConfig{ - dtIsResponse: true, - }, - updatedConfig: gsRequestConfig{ - dtIsResponse: true, - }, - action: func(gsData *harness) { - gsData.incomingRequestHook() - gsData.requestUpdatedHook() - }, - check: func(t *testing.T, events *fakeEvents, gsData *harness) { - require.Equal(t, 2, events.OnResponseReceivedCallCount) - require.NoError(t, gsData.requestUpdatedHookActions.TerminationError) - }, - }, - "incoming gs request with recognized dt response cannot receive update with dt request": { - requestConfig: gsRequestConfig{ - dtIsResponse: true, - }, - action: func(gsData *harness) { - gsData.incomingRequestHook() - gsData.requestUpdatedHook() - }, - check: func(t *testing.T, events *fakeEvents, gsData *harness) { - require.Equal(t, 1, events.OnResponseReceivedCallCount) - require.Equal(t, 0, events.OnRequestReceivedCallCount) - require.Error(t, gsData.requestUpdatedHookActions.TerminationError) - }, - }, - "incoming gs request with recognized dt request will error with malformed update": { - updatedConfig: 
gsRequestConfig{ - dtExtensionMalformed: true, - }, - action: func(gsData *harness) { - gsData.incomingRequestHook() - gsData.requestUpdatedHook() - }, - check: func(t *testing.T, events *fakeEvents, gsData *harness) { - require.Equal(t, 1, events.OnRequestReceivedCallCount) - require.Error(t, gsData.requestUpdatedHookActions.TerminationError) - }, - }, - "incoming gs request with recognized dt request will ignore non-data-transfer update": { - updatedConfig: gsRequestConfig{ - dtExtensionMissing: true, - }, - action: func(gsData *harness) { - gsData.incomingRequestHook() - gsData.requestUpdatedHook() - }, - check: func(t *testing.T, events *fakeEvents, gsData *harness) { - require.Equal(t, 1, events.OnRequestReceivedCallCount) - require.NoError(t, gsData.requestUpdatedHookActions.TerminationError) - }, - }, - "incoming gs request with recognized dt request can send message on update": { - events: fakeEvents{ - RequestReceivedResponse: testutil.NewDTResponse(t, datatransfer.TransferID(rand.Uint32())), - }, - action: func(gsData *harness) { - gsData.incomingRequestHook() - gsData.requestUpdatedHook() - }, - check: func(t *testing.T, events *fakeEvents, gsData *harness) { - require.Equal(t, 2, events.OnRequestReceivedCallCount) - require.NoError(t, gsData.requestUpdatedHookActions.TerminationError) - assertHasOutgoingMessage(t, gsData.requestUpdatedHookActions.SentExtensions, - events.RequestReceivedResponse) - }, - }, - "recognized incoming request will record successful request completion": { - responseConfig: gsResponseConfig{ - status: graphsync.RequestCompletedFull, - }, - action: func(gsData *harness) { - gsData.incomingRequestHook() - gsData.responseCompletedListener() - }, - check: func(t *testing.T, events *fakeEvents, gsData *harness) { - require.Equal(t, 1, events.OnRequestReceivedCallCount) - require.True(t, events.OnChannelCompletedCalled) - require.True(t, events.ChannelCompletedSuccess) - }, - }, - - "recognized incoming request will record unsuccessful request completion": { - responseConfig: gsResponseConfig{ - status: graphsync.RequestCompletedPartial, - }, - action: func(gsData *harness) { - gsData.incomingRequestHook() - gsData.responseCompletedListener() - }, - check: func(t *testing.T, events *fakeEvents, gsData *harness) { - require.Equal(t, 1, events.OnRequestReceivedCallCount) - require.True(t, events.OnChannelCompletedCalled) - require.False(t, events.ChannelCompletedSuccess) - }, - }, - "recognized incoming request will not record request cancellation": { - responseConfig: gsResponseConfig{ - status: graphsync.RequestCancelled, - }, - action: func(gsData *harness) { - gsData.incomingRequestHook() - gsData.responseCompletedListener() - }, - check: func(t *testing.T, events *fakeEvents, gsData *harness) { - require.Equal(t, 1, events.OnRequestReceivedCallCount) - require.False(t, events.OnChannelCompletedCalled) - }, - }, - "non-data-transfer request will not record request completed": { - requestConfig: gsRequestConfig{ - dtExtensionMissing: true, - }, - responseConfig: gsResponseConfig{ - status: graphsync.RequestCompletedPartial, - }, - action: func(gsData *harness) { - gsData.incomingRequestHook() - gsData.responseCompletedListener() - }, - check: func(t *testing.T, events *fakeEvents, gsData *harness) { - require.Equal(t, 0, events.OnRequestReceivedCallCount) - require.False(t, events.OnChannelCompletedCalled) - }, - }, - "recognized incoming request can be closed": { - action: func(gsData *harness) { - gsData.incomingRequestHook() - }, - check: func(t 
*testing.T, events *fakeEvents, gsData *harness) { - err := gsData.transport.CloseChannel(gsData.ctx, datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.self, Initiator: gsData.other}) - require.NoError(t, err) - require.Equal(t, 1, events.OnRequestReceivedCallCount) - gsData.fgs.AssertCancelReceived(gsData.ctx, t) - }, - }, - "unrecognized request cannot be closed": { - check: func(t *testing.T, events *fakeEvents, gsData *harness) { - err := gsData.transport.CloseChannel(gsData.ctx, datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.self, Initiator: gsData.other}) - require.Error(t, err) - }, - }, - "recognized incoming request that requestor cancelled will not close via graphsync": { - action: func(gsData *harness) { - gsData.incomingRequestHook() - gsData.requestorCancelledListener() - }, - check: func(t *testing.T, events *fakeEvents, gsData *harness) { - err := gsData.transport.CloseChannel(gsData.ctx, datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.self, Initiator: gsData.other}) - require.NoError(t, err) - require.Equal(t, 1, events.OnRequestReceivedCallCount) - gsData.fgs.AssertNoCancelReceived(t) - }, - }, - "recognized incoming request can be paused": { - action: func(gsData *harness) { - gsData.incomingRequestHook() - }, - check: func(t *testing.T, events *fakeEvents, gsData *harness) { - err := gsData.transport.PauseChannel(gsData.ctx, datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.self, Initiator: gsData.other}) - require.NoError(t, err) - require.Equal(t, 1, events.OnRequestReceivedCallCount) - gsData.fgs.AssertPauseReceived(gsData.ctx, t) - }, - }, - "unrecognized request cannot be paused": { - check: func(t *testing.T, events *fakeEvents, gsData *harness) { - err := gsData.transport.PauseChannel(gsData.ctx, datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.self, Initiator: gsData.other}) - require.Error(t, err) - }, - }, - "recognized incoming request that requestor cancelled will not pause via graphsync": { - action: func(gsData *harness) { - gsData.incomingRequestHook() - gsData.requestorCancelledListener() - }, - check: func(t *testing.T, events *fakeEvents, gsData *harness) { - err := gsData.transport.PauseChannel(gsData.ctx, datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.self, Initiator: gsData.other}) - require.NoError(t, err) - require.Equal(t, 1, events.OnRequestReceivedCallCount) - gsData.fgs.AssertNoPauseReceived(t) - }, - }, - - "incoming request can be queued": { - action: func(gsData *harness) { - gsData.incomingRequestQueuedHook() - }, - check: func(t *testing.T, events *fakeEvents, gsData *harness) { - require.True(t, events.TransferQueuedCalled) - require.Equal(t, datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.self, Initiator: gsData.other}, - events.TransferQueuedChannelID) - }, - }, - - "incoming request with dtResponse can be queued": { - requestConfig: gsRequestConfig{ - dtIsResponse: true, - }, - responseConfig: gsResponseConfig{ - dtIsResponse: true, - }, - action: func(gsData *harness) { - gsData.incomingRequestQueuedHook() - }, - check: func(t *testing.T, events *fakeEvents, gsData *harness) { - require.True(t, events.TransferQueuedCalled) - require.Equal(t, datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.other, Initiator: gsData.self}, - events.TransferQueuedChannelID) - }, - }, - - "recognized incoming request can be resumed": { - action: func(gsData *harness) { - gsData.incomingRequestHook() - }, - check: func(t *testing.T, 
events *fakeEvents, gsData *harness) { - err := gsData.transport.ResumeChannel(gsData.ctx, - gsData.incoming, - datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.self, Initiator: gsData.other}, - ) - require.NoError(t, err) - require.Equal(t, 1, events.OnRequestReceivedCallCount) - gsData.fgs.AssertResumeReceived(gsData.ctx, t) - }, - }, - - "unrecognized request cannot be resumed": { - check: func(t *testing.T, events *fakeEvents, gsData *harness) { - err := gsData.transport.ResumeChannel(gsData.ctx, - gsData.incoming, - datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.self, Initiator: gsData.other}, - ) - require.Error(t, err) - }, - }, - "recognized incoming request that requestor cancelled will not resume via graphsync but will resume otherwise": { - action: func(gsData *harness) { - gsData.incomingRequestHook() - gsData.requestorCancelledListener() - }, - check: func(t *testing.T, events *fakeEvents, gsData *harness) { - err := gsData.transport.ResumeChannel(gsData.ctx, - gsData.incoming, - datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.self, Initiator: gsData.other}, - ) - require.NoError(t, err) - require.Equal(t, 1, events.OnRequestReceivedCallCount) - gsData.fgs.AssertNoResumeReceived(t) - gsData.incomingRequestHook() - assertHasOutgoingMessage(t, gsData.incomingRequestHookActions.SentExtensions, gsData.incoming) - }, - }, - "recognized incoming request will record network send error": { - action: func(gsData *harness) { - gsData.incomingRequestHook() - gsData.networkErrorListener(errors.New("something went wrong")) - }, - check: func(t *testing.T, events *fakeEvents, gsData *harness) { - require.Equal(t, 1, events.OnRequestReceivedCallCount) - require.True(t, events.OnSendDataErrorCalled) - }, - }, - "recognized outgoing request will record network send error": { - action: func(gsData *harness) { - gsData.outgoingRequestHook() - gsData.networkErrorListener(errors.New("something went wrong")) - }, - check: func(t *testing.T, events *fakeEvents, gsData *harness) { - require.True(t, events.OnSendDataErrorCalled) - }, - }, - "recognized incoming request will record network receive error": { - action: func(gsData *harness) { - gsData.incomingRequestHook() - gsData.receiverNetworkErrorListener(errors.New("something went wrong")) - }, - check: func(t *testing.T, events *fakeEvents, gsData *harness) { - require.Equal(t, 1, events.OnRequestReceivedCallCount) - require.True(t, events.OnReceiveDataErrorCalled) - }, - }, - "recognized outgoing request will record network receive error": { - action: func(gsData *harness) { - gsData.outgoingRequestHook() - gsData.receiverNetworkErrorListener(errors.New("something went wrong")) - }, - check: func(t *testing.T, events *fakeEvents, gsData *harness) { - require.True(t, events.OnReceiveDataErrorCalled) - }, - }, - "open channel adds block count to the DoNotSendFirstBlocks extension for v1.2 protocol": { - action: func(gsData *harness) { - cids := testutil.GenerateCids(2) - channel := testutil.NewMockChannelState(testutil.MockChannelStateParams{ReceivedCids: cids}) - stor, _ := gsData.outgoing.Selector() - - go gsData.outgoingRequestHook() - _ = gsData.transport.OpenChannel( - gsData.ctx, - gsData.other, - datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.other, Initiator: gsData.self}, - cidlink.Link{Cid: gsData.outgoing.BaseCid()}, - stor, - channel, - gsData.outgoing) - }, - check: func(t *testing.T, events *fakeEvents, gsData *harness) { - requestReceived := 
gsData.fgs.AssertRequestReceived(gsData.ctx, t) - - ext := requestReceived.Extensions - require.Len(t, ext, 2) - doNotSend := ext[1] - - name := doNotSend.Name - require.Equal(t, graphsync.ExtensionsDoNotSendFirstBlocks, name) - data := doNotSend.Data - blockCount, err := donotsendfirstblocks.DecodeDoNotSendFirstBlocks(data) - require.NoError(t, err) - require.EqualValues(t, blockCount, 2) - }, - }, - "ChannelsForPeer when request is open": { - action: func(gsData *harness) { - cids := testutil.GenerateCids(2) - channel := testutil.NewMockChannelState(testutil.MockChannelStateParams{ReceivedCids: cids}) - stor, _ := gsData.outgoing.Selector() - - go gsData.outgoingRequestHook() - _ = gsData.transport.OpenChannel( - gsData.ctx, - gsData.other, - datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.other, Initiator: gsData.self}, - cidlink.Link{Cid: gsData.outgoing.BaseCid()}, - stor, - channel, - gsData.outgoing) - }, - check: func(t *testing.T, events *fakeEvents, gsData *harness) { - gsData.fgs.AssertRequestReceived(gsData.ctx, t) - - channelsForPeer := gsData.transport.ChannelsForPeer(gsData.other) - require.Equal(t, channelsForPeer, ChannelsForPeer{ - ReceivingChannels: map[datatransfer.ChannelID]ChannelGraphsyncRequests{ - events.ChannelOpenedChannelID: { - Current: gsData.request.ID(), - }, - }, - SendingChannels: map[datatransfer.ChannelID]ChannelGraphsyncRequests{}, - }) - }, - }, - "open channel cancels an existing request with the same channel ID": { - action: func(gsData *harness) { - cids := testutil.GenerateCids(2) - channel := testutil.NewMockChannelState(testutil.MockChannelStateParams{ReceivedCids: cids}) - stor, _ := gsData.outgoing.Selector() - go gsData.outgoingRequestHook() - _ = gsData.transport.OpenChannel( - gsData.ctx, - gsData.other, - datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.other, Initiator: gsData.self}, - cidlink.Link{Cid: gsData.outgoing.BaseCid()}, - stor, - channel, - gsData.outgoing) - - go gsData.altOutgoingRequestHook() - _ = gsData.transport.OpenChannel( - gsData.ctx, - gsData.other, - datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.other, Initiator: gsData.self}, - cidlink.Link{Cid: gsData.outgoing.BaseCid()}, - stor, - channel, - gsData.outgoing) - }, - check: func(t *testing.T, events *fakeEvents, gsData *harness) { - gsData.fgs.AssertRequestReceived(gsData.ctx, t) - gsData.fgs.AssertRequestReceived(gsData.ctx, t) - - ctxt, cancel := context.WithTimeout(context.Background(), 2*time.Second) - defer cancel() - gsData.fgs.AssertCancelReceived(ctxt, t) - - channelsForPeer := gsData.transport.ChannelsForPeer(gsData.other) - require.Equal(t, channelsForPeer, ChannelsForPeer{ - ReceivingChannels: map[datatransfer.ChannelID]ChannelGraphsyncRequests{ - events.ChannelOpenedChannelID: { - Current: gsData.altRequest.ID(), - Previous: []graphsync.RequestID{gsData.request.ID()}, - }, - }, - SendingChannels: map[datatransfer.ChannelID]ChannelGraphsyncRequests{}, - }) - }, - }, - "OnChannelCompleted called when outgoing request completes successfully": { - action: func(gsData *harness) { - gsData.fgs.LeaveRequestsOpen() - stor, _ := gsData.outgoing.Selector() - - go gsData.outgoingRequestHook() - _ = gsData.transport.OpenChannel( - gsData.ctx, - gsData.other, - datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.other, Initiator: gsData.self}, - cidlink.Link{Cid: gsData.outgoing.BaseCid()}, - stor, - nil, - gsData.outgoing) - }, - check: func(t *testing.T, events *fakeEvents, gsData *harness) { - 
requestReceived := gsData.fgs.AssertRequestReceived(gsData.ctx, t) - close(requestReceived.ResponseChan) - close(requestReceived.ResponseErrChan) - - require.Eventually(t, func() bool { - return events.OnChannelCompletedCalled == true - }, 2*time.Second, 100*time.Millisecond) - require.True(t, events.ChannelCompletedSuccess) - }, - }, - "OnChannelCompleted called when outgoing request completes with error": { - action: func(gsData *harness) { - gsData.fgs.LeaveRequestsOpen() - stor, _ := gsData.outgoing.Selector() - - go gsData.outgoingRequestHook() - _ = gsData.transport.OpenChannel( - gsData.ctx, - gsData.other, - datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.other, Initiator: gsData.self}, - cidlink.Link{Cid: gsData.outgoing.BaseCid()}, - stor, - nil, - gsData.outgoing) - }, - check: func(t *testing.T, events *fakeEvents, gsData *harness) { - requestReceived := gsData.fgs.AssertRequestReceived(gsData.ctx, t) - close(requestReceived.ResponseChan) - requestReceived.ResponseErrChan <- graphsync.RequestFailedUnknownErr{} - close(requestReceived.ResponseErrChan) - - require.Eventually(t, func() bool { - return events.OnChannelCompletedCalled == true - }, 2*time.Second, 100*time.Millisecond) - require.False(t, events.ChannelCompletedSuccess) - }, - }, - "OnChannelComplete when outgoing request cancelled by caller": { - action: func(gsData *harness) { - gsData.fgs.LeaveRequestsOpen() - stor, _ := gsData.outgoing.Selector() - - go gsData.outgoingRequestHook() - _ = gsData.transport.OpenChannel( - gsData.ctx, - gsData.other, - datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.other, Initiator: gsData.self}, - cidlink.Link{Cid: gsData.outgoing.BaseCid()}, - stor, - nil, - gsData.outgoing) - }, - check: func(t *testing.T, events *fakeEvents, gsData *harness) { - requestReceived := gsData.fgs.AssertRequestReceived(gsData.ctx, t) - extensions := make(map[graphsync.ExtensionName]datamodel.Node) - for _, ext := range requestReceived.Extensions { - extensions[ext.Name] = ext.Data - } - request := testutil.NewFakeRequest(graphsync.NewRequestID(), extensions) - gsData.fgs.OutgoingRequestHook(gsData.other, request, gsData.outgoingRequestHookActions) - _ = gsData.transport.CloseChannel(gsData.ctx, datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.other, Initiator: gsData.self}) - ctxt, cancel := context.WithTimeout(context.Background(), 2*time.Second) - defer cancel() - gsData.fgs.AssertCancelReceived(ctxt, t) - }, - }, - "request times out if we get request context cancelled error": { - action: func(gsData *harness) { - gsData.fgs.LeaveRequestsOpen() - stor, _ := gsData.outgoing.Selector() - - go gsData.outgoingRequestHook() - _ = gsData.transport.OpenChannel( - gsData.ctx, - gsData.other, - datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.other, Initiator: gsData.self}, - cidlink.Link{Cid: gsData.outgoing.BaseCid()}, - stor, - nil, - gsData.outgoing) - }, - check: func(t *testing.T, events *fakeEvents, gsData *harness) { - requestReceived := gsData.fgs.AssertRequestReceived(gsData.ctx, t) - close(requestReceived.ResponseChan) - requestReceived.ResponseErrChan <- graphsync.RequestClientCancelledErr{} - close(requestReceived.ResponseErrChan) - - require.Eventually(t, func() bool { - return events.OnRequestCancelledCalled == true - }, 2*time.Second, 100*time.Millisecond) - require.Equal(t, datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.other, Initiator: gsData.self}, events.OnRequestCancelledChannelId) - }, - }, - "request 
cancelled out if transport shuts down": { - action: func(gsData *harness) { - gsData.fgs.LeaveRequestsOpen() - stor, _ := gsData.outgoing.Selector() - - go gsData.outgoingRequestHook() - _ = gsData.transport.OpenChannel( - gsData.ctx, - gsData.other, - datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.other, Initiator: gsData.self}, - cidlink.Link{Cid: gsData.outgoing.BaseCid()}, - stor, - nil, - gsData.outgoing) - }, - check: func(t *testing.T, events *fakeEvents, gsData *harness) { - gsData.fgs.AssertRequestReceived(gsData.ctx, t) - - gsData.transport.Shutdown(gsData.ctx) - - ctxt, cancel := context.WithTimeout(context.Background(), 2*time.Second) - defer cancel() - gsData.fgs.AssertCancelReceived(ctxt, t) - - require.Nil(t, gsData.fgs.IncomingRequestHook) - require.Nil(t, gsData.fgs.CompletedResponseListener) - require.Nil(t, gsData.fgs.IncomingBlockHook) - require.Nil(t, gsData.fgs.OutgoingBlockHook) - require.Nil(t, gsData.fgs.BlockSentListener) - require.Nil(t, gsData.fgs.OutgoingRequestHook) - require.Nil(t, gsData.fgs.IncomingResponseHook) - require.Nil(t, gsData.fgs.RequestUpdatedHook) - require.Nil(t, gsData.fgs.RequestorCancelledListener) - require.Nil(t, gsData.fgs.NetworkErrorListener) - }, - }, - "request pause works even if called when request is still pending": { - action: func(gsData *harness) { - gsData.fgs.LeaveRequestsOpen() - stor, _ := gsData.outgoing.Selector() - - go gsData.outgoingRequestHook() - _ = gsData.transport.OpenChannel( - gsData.ctx, - gsData.other, - datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.other, Initiator: gsData.self}, - cidlink.Link{Cid: gsData.outgoing.BaseCid()}, - stor, - nil, - gsData.outgoing) - - }, - check: func(t *testing.T, events *fakeEvents, gsData *harness) { - requestReceived := gsData.fgs.AssertRequestReceived(gsData.ctx, t) - assertHasOutgoingMessage(t, requestReceived.Extensions, gsData.outgoing) - completed := make(chan struct{}) - go func() { - err := gsData.transport.PauseChannel(context.Background(), datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.other, Initiator: gsData.self}) - require.NoError(t, err) - close(completed) - }() - time.Sleep(100 * time.Millisecond) - extensions := make(map[graphsync.ExtensionName]datamodel.Node) - for _, ext := range requestReceived.Extensions { - extensions[ext.Name] = ext.Data - } - request := testutil.NewFakeRequest(graphsync.NewRequestID(), extensions) - gsData.fgs.OutgoingRequestHook(gsData.other, request, gsData.outgoingRequestHookActions) - select { - case <-gsData.ctx.Done(): - t.Fatal("never paused channel") - case <-completed: - } - }, - }, - "UseStore can change store used for outgoing requests": { - action: func(gsData *harness) { - lsys := cidlink.DefaultLinkSystem() - lsys.StorageReadOpener = func(ipld.LinkContext, ipld.Link) (io.Reader, error) { - return nil, nil - } - lsys.StorageWriteOpener = func(ipld.LinkContext) (io.Writer, ipld.BlockWriteCommitter, error) { - return nil, nil, nil - } - _ = gsData.transport.UseStore(datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.other, Initiator: gsData.self}, lsys) - gsData.outgoingRequestHook() - }, - check: func(t *testing.T, events *fakeEvents, gsData *harness) { - expectedChannel := "data-transfer-" + datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.other, Initiator: gsData.self}.String() - gsData.fgs.AssertHasPersistenceOption(t, expectedChannel) - require.Equal(t, expectedChannel, gsData.outgoingRequestHookActions.PersistenceOption) - 
gsData.transport.CleanupChannel(datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.other, Initiator: gsData.self}) - gsData.fgs.AssertDoesNotHavePersistenceOption(t, expectedChannel) - }, - }, - "UseStore can change store used for incoming requests": { - action: func(gsData *harness) { - lsys := cidlink.DefaultLinkSystem() - lsys.StorageReadOpener = func(ipld.LinkContext, ipld.Link) (io.Reader, error) { - return nil, nil - } - lsys.StorageWriteOpener = func(ipld.LinkContext) (io.Writer, ipld.BlockWriteCommitter, error) { - return nil, nil, nil - } - _ = gsData.transport.UseStore(datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.self, Initiator: gsData.other}, lsys) - gsData.incomingRequestHook() - }, - check: func(t *testing.T, events *fakeEvents, gsData *harness) { - expectedChannel := "data-transfer-" + datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.self, Initiator: gsData.other}.String() - gsData.fgs.AssertHasPersistenceOption(t, expectedChannel) - require.Equal(t, expectedChannel, gsData.incomingRequestHookActions.PersistenceOption) - gsData.transport.CleanupChannel(datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.self, Initiator: gsData.other}) - gsData.fgs.AssertDoesNotHavePersistenceOption(t, expectedChannel) - }, - }, - } - - ctx := context.Background() - for testCase, data := range testCases { - t.Run(testCase, func(t *testing.T) { - ctx, cancel := context.WithTimeout(ctx, 5*time.Second) - defer cancel() - peers := testutil.GeneratePeers(2) - transferID := datatransfer.TransferID(rand.Uint32()) - requestID := graphsync.NewRequestID() - request := data.requestConfig.makeRequest(t, transferID, requestID) - altRequest := data.requestConfig.makeRequest(t, transferID, graphsync.NewRequestID()) - response := data.responseConfig.makeResponse(t, transferID, requestID) - updatedRequest := data.updatedConfig.makeRequest(t, transferID, requestID) - block := testutil.NewFakeBlockData() - fgs := testutil.NewFakeGraphSync() - outgoing := testutil.NewDTRequest(t, transferID) - incoming := testutil.NewDTResponse(t, transferID) - transport := NewTransport(peers[0], fgs) - gsData := &harness{ - ctx: ctx, - outgoing: outgoing, - incoming: incoming, - transport: transport, - fgs: fgs, - self: peers[0], - transferID: transferID, - other: peers[1], - altRequest: altRequest, - request: request, - response: response, - updatedRequest: updatedRequest, - block: block, - outgoingRequestHookActions: &testutil.FakeOutgoingRequestHookActions{}, - outgoingBlockHookActions: &testutil.FakeOutgoingBlockHookActions{}, - incomingBlockHookActions: &testutil.FakeIncomingBlockHookActions{}, - incomingRequestHookActions: &testutil.FakeIncomingRequestHookActions{}, - requestUpdatedHookActions: &testutil.FakeRequestUpdatedActions{}, - incomingResponseHookActions: &testutil.FakeIncomingResponseHookActions{}, - requestQueuedHookActions: &testutil.FakeRequestQueuedHookActions{}, - } - require.NoError(t, transport.SetEventHandler(&data.events)) - if data.action != nil { - data.action(gsData) - } - data.check(t, &data.events, gsData) - }) - } -} - -type fakeEvents struct { - ChannelOpenedChannelID datatransfer.ChannelID - RequestReceivedChannelID datatransfer.ChannelID - ResponseReceivedChannelID datatransfer.ChannelID - OnChannelOpenedError error - OnDataReceivedCalled bool - OnDataReceivedError error - OnDataSentCalled bool - OnRequestReceivedCallCount int - OnRequestReceivedErrors []error - OnResponseReceivedCallCount int - OnResponseReceivedErrors []error - 
OnChannelCompletedCalled bool - OnChannelCompletedErr error - OnDataQueuedCalled bool - OnDataQueuedMessage datatransfer.Message - OnDataQueuedError error - - OnRequestCancelledCalled bool - OnRequestCancelledChannelId datatransfer.ChannelID - OnSendDataErrorCalled bool - OnSendDataErrorChannelID datatransfer.ChannelID - OnReceiveDataErrorCalled bool - OnReceiveDataErrorChannelID datatransfer.ChannelID - OnContextAugmentFunc func(context.Context) context.Context - TransferQueuedCalled bool - TransferQueuedChannelID datatransfer.ChannelID - - ChannelCompletedSuccess bool - RequestReceivedRequest datatransfer.Request - RequestReceivedResponse datatransfer.Response - ResponseReceivedResponse datatransfer.Response -} - -func (fe *fakeEvents) OnDataQueued(chid datatransfer.ChannelID, link ipld.Link, size uint64, index int64, unique bool) (datatransfer.Message, error) { - fe.OnDataQueuedCalled = true - - return fe.OnDataQueuedMessage, fe.OnDataQueuedError -} - -func (fe *fakeEvents) OnRequestCancelled(chid datatransfer.ChannelID, err error) error { - fe.OnRequestCancelledCalled = true - fe.OnRequestCancelledChannelId = chid - - return nil -} - -func (fe *fakeEvents) OnTransferQueued(chid datatransfer.ChannelID) { - fe.TransferQueuedCalled = true - fe.TransferQueuedChannelID = chid -} - -func (fe *fakeEvents) OnRequestDisconnected(chid datatransfer.ChannelID, err error) error { - return nil -} - -func (fe *fakeEvents) OnSendDataError(chid datatransfer.ChannelID, err error) error { - fe.OnSendDataErrorCalled = true - fe.OnSendDataErrorChannelID = chid - return nil -} - -func (fe *fakeEvents) OnReceiveDataError(chid datatransfer.ChannelID, err error) error { - fe.OnReceiveDataErrorCalled = true - fe.OnReceiveDataErrorChannelID = chid - return nil -} - -func (fe *fakeEvents) OnChannelOpened(chid datatransfer.ChannelID) error { - fe.ChannelOpenedChannelID = chid - return fe.OnChannelOpenedError -} - -func (fe *fakeEvents) OnDataReceived(chid datatransfer.ChannelID, link ipld.Link, size uint64, index int64, unique bool) error { - fe.OnDataReceivedCalled = true - return fe.OnDataReceivedError -} - -func (fe *fakeEvents) OnDataSent(chid datatransfer.ChannelID, link ipld.Link, size uint64, index int64, unique bool) error { - fe.OnDataSentCalled = true - return nil -} - -func (fe *fakeEvents) OnRequestReceived(chid datatransfer.ChannelID, request datatransfer.Request) (datatransfer.Response, error) { - fe.OnRequestReceivedCallCount++ - fe.RequestReceivedChannelID = chid - fe.RequestReceivedRequest = request - var err error - if len(fe.OnRequestReceivedErrors) > 0 { - err, fe.OnRequestReceivedErrors = fe.OnRequestReceivedErrors[0], fe.OnRequestReceivedErrors[1:] - } - return fe.RequestReceivedResponse, err -} - -func (fe *fakeEvents) OnResponseReceived(chid datatransfer.ChannelID, response datatransfer.Response) error { - fe.OnResponseReceivedCallCount++ - fe.ResponseReceivedResponse = response - fe.ResponseReceivedChannelID = chid - var err error - if len(fe.OnResponseReceivedErrors) > 0 { - err, fe.OnResponseReceivedErrors = fe.OnResponseReceivedErrors[0], fe.OnResponseReceivedErrors[1:] - } - return err -} - -func (fe *fakeEvents) OnChannelCompleted(chid datatransfer.ChannelID, completeErr error) error { - fe.OnChannelCompletedCalled = true - fe.ChannelCompletedSuccess = completeErr == nil - return fe.OnChannelCompletedErr -} - -func (fe *fakeEvents) OnContextAugment(chid datatransfer.ChannelID) func(context.Context) context.Context { - return fe.OnContextAugmentFunc -} - -type harness struct { - 
outgoing datatransfer.Request - incoming datatransfer.Response - ctx context.Context - transport *Transport - fgs *testutil.FakeGraphSync - transferID datatransfer.TransferID - self peer.ID - other peer.ID - block graphsync.BlockData - request graphsync.RequestData - altRequest graphsync.RequestData - response graphsync.ResponseData - updatedRequest graphsync.RequestData - outgoingRequestHookActions *testutil.FakeOutgoingRequestHookActions - incomingBlockHookActions *testutil.FakeIncomingBlockHookActions - outgoingBlockHookActions *testutil.FakeOutgoingBlockHookActions - incomingRequestHookActions *testutil.FakeIncomingRequestHookActions - requestUpdatedHookActions *testutil.FakeRequestUpdatedActions - incomingResponseHookActions *testutil.FakeIncomingResponseHookActions - requestQueuedHookActions *testutil.FakeRequestQueuedHookActions -} - -func (ha *harness) outgoingRequestHook() { - ha.fgs.OutgoingRequestHook(ha.other, ha.request, ha.outgoingRequestHookActions) -} - -func (ha *harness) altOutgoingRequestHook() { - ha.fgs.OutgoingRequestHook(ha.other, ha.altRequest, ha.outgoingRequestHookActions) -} - -func (ha *harness) incomingBlockHook() { - ha.fgs.IncomingBlockHook(ha.other, ha.response, ha.block, ha.incomingBlockHookActions) -} -func (ha *harness) outgoingBlockHook() { - ha.fgs.OutgoingBlockHook(ha.other, ha.request, ha.block, ha.outgoingBlockHookActions) -} - -func (ha *harness) incomingRequestHook() { - ha.fgs.IncomingRequestHook(ha.other, ha.request, ha.incomingRequestHookActions) -} - -func (ha *harness) incomingRequestQueuedHook() { - ha.fgs.IncomingRequestQueuedHook(ha.other, ha.request, ha.requestQueuedHookActions) -} - -func (ha *harness) requestUpdatedHook() { - ha.fgs.RequestUpdatedHook(ha.other, ha.request, ha.updatedRequest, ha.requestUpdatedHookActions) -} -func (ha *harness) incomingResponseHOok() { - ha.fgs.IncomingResponseHook(ha.other, ha.response, ha.incomingResponseHookActions) -} -func (ha *harness) responseCompletedListener() { - ha.fgs.CompletedResponseListener(ha.other, ha.request, ha.response.Status()) -} -func (ha *harness) requestorCancelledListener() { - ha.fgs.RequestorCancelledListener(ha.other, ha.request) -} -func (ha *harness) networkErrorListener(err error) { - ha.fgs.NetworkErrorListener(ha.other, ha.request, err) -} -func (ha *harness) receiverNetworkErrorListener(err error) { - ha.fgs.ReceiverNetworkErrorListener(ha.other, err) -} - -type dtConfig struct { - dtExtensionMissing bool - dtIsResponse bool - dtExtensionMalformed bool -} - -func (dtc *dtConfig) extensions(t *testing.T, transferID datatransfer.TransferID, extName graphsync.ExtensionName) map[graphsync.ExtensionName]datamodel.Node { - extensions := make(map[graphsync.ExtensionName]datamodel.Node) - if !dtc.dtExtensionMissing { - if dtc.dtExtensionMalformed { - extensions[extName] = basicnode.NewInt(10) - } else { - var msg datatransfer.Message - if dtc.dtIsResponse { - msg = testutil.NewDTResponse(t, transferID) - } else { - msg = testutil.NewDTRequest(t, transferID) - } - nd := msg.ToIPLD() - extensions[extName] = nd - } - } - return extensions -} - -type gsRequestConfig struct { - dtExtensionMissing bool - dtIsResponse bool - dtExtensionMalformed bool -} - -func (grc *gsRequestConfig) makeRequest(t *testing.T, transferID datatransfer.TransferID, requestID graphsync.RequestID) graphsync.RequestData { - dtConfig := dtConfig{ - dtExtensionMissing: grc.dtExtensionMissing, - dtIsResponse: grc.dtIsResponse, - dtExtensionMalformed: grc.dtExtensionMalformed, - } - extensions := 
dtConfig.extensions(t, transferID, extension.ExtensionDataTransfer1_1)
-	return testutil.NewFakeRequest(requestID, extensions)
-}
-
-type gsResponseConfig struct {
-	dtExtensionMissing   bool
-	dtIsResponse         bool
-	dtExtensionMalformed bool
-	status               graphsync.ResponseStatusCode
-}
-
-func (grc *gsResponseConfig) makeResponse(t *testing.T, transferID datatransfer.TransferID, requestID graphsync.RequestID) graphsync.ResponseData {
-	dtConfig := dtConfig{
-		dtExtensionMissing:   grc.dtExtensionMissing,
-		dtIsResponse:         grc.dtIsResponse,
-		dtExtensionMalformed: grc.dtExtensionMalformed,
-	}
-	extensions := dtConfig.extensions(t, transferID, extension.ExtensionDataTransfer1_1)
-	return testutil.NewFakeResponse(requestID, extensions, grc.status)
-}
-
-func assertDecodesToMessage(t *testing.T, data datamodel.Node, expected datatransfer.Message) {
-	actual, err := message.FromIPLD(data)
-	require.NoError(t, err)
-	require.Equal(t, expected, actual)
-}
-
-func assertHasOutgoingMessage(t *testing.T, extensions []graphsync.ExtensionData, expected datatransfer.Message) {
-	nd := expected.ToIPLD()
-	found := false
-	for _, e := range extensions {
-		if e.Name == extension.ExtensionDataTransfer1_1 {
-			require.True(t, ipld.DeepEqual(nd, e.Data), "data matches")
-			found = true
-		}
-	}
-	if !found {
-		require.Fail(t, "extension not found")
-	}
-}
-
-func assertHasExtensionMessage(t *testing.T, name graphsync.ExtensionName, extensions []graphsync.ExtensionData, expected datatransfer.Message) {
-	nd := expected.ToIPLD()
-	found := false
-	for _, e := range extensions {
-		if e.Name == name {
-			require.True(t, ipld.DeepEqual(nd, e.Data), "data matches")
-			found = true
-		}
-	}
-	if !found {
-		require.Fail(t, "extension not found")
-	}
-}
-*/
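The hooks.go rewrite below drops the `ErrPause`/`OnDataQueued` round-trips through the event handler in favor of per-channel index and progress state (note the new `dtchannel` import). As a rough mental model of the semantics the new hooks rely on, here is a minimal, self-contained sketch; the struct and method names are illustrative assumptions, not the dtchannel package's actual API:

```go
package main

import "fmt"

// channelProgress is an illustrative stand-in for the per-channel state the
// rewritten hooks consult via ch.Update*IfGreater and ch.UpdateProgress.
type channelProgress struct {
	receivedIndex int64  // highest block index observed so far
	progress      uint64 // bytes actually received on the wire
	dataLimit     uint64 // 0 means no limit is set
}

// updateReceivedIndexIfGreater reports true only when the index advances, so
// DAG revisits of a block and restart replays fire no duplicate events.
func (c *channelProgress) updateReceivedIndexIfGreater(idx int64) bool {
	if idx <= c.receivedIndex {
		return false
	}
	c.receivedIndex = idx
	return true
}

// updateProgress accumulates wire bytes and reports whether the configured
// data limit has been reached, at which point the hook pauses graphsync.
func (c *channelProgress) updateProgress(sizeOnWire uint64) bool {
	c.progress += sizeOnWire
	return c.dataLimit != 0 && c.progress >= c.dataLimit
}

func main() {
	ch := channelProgress{dataLimit: 2048}
	fmt.Println(ch.updateReceivedIndexIfGreater(1), ch.updateProgress(1024)) // true false
	fmt.Println(ch.updateReceivedIndexIfGreater(1))                          // false: duplicate index
	fmt.Println(ch.updateReceivedIndexIfGreater(2), ch.updateProgress(1024)) // true true: limit hit
}
```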
diff --git a/transport/graphsync/hooks.go b/transport/graphsync/hooks.go
index 021fea4e..1e90301c 100644
--- a/transport/graphsync/hooks.go
+++ b/transport/graphsync/hooks.go
@@ -1,47 +1,31 @@
 package graphsync
 
 import (
+	"context"
 	"errors"
+	"fmt"
 
 	"github.com/ipfs/go-graphsync"
+	basicnode "github.com/ipld/go-ipld-prime/node/basic"
 	peer "github.com/libp2p/go-libp2p-core/peer"
 	"golang.org/x/xerrors"
 
 	datatransfer "github.com/filecoin-project/go-data-transfer/v2"
+	"github.com/filecoin-project/go-data-transfer/v2/transport/graphsync/dtchannel"
 	"github.com/filecoin-project/go-data-transfer/v2/transport/graphsync/extension"
 )
 
 // gsOutgoingRequestHook is called when a graphsync request is made
 func (t *Transport) gsOutgoingRequestHook(p peer.ID, request graphsync.RequestData, hookActions graphsync.OutgoingRequestHookActions) {
-	message, _ := extension.GetTransferData(request, t.supportedExtensions)
-	// extension not found; probably not our request.
-	if message == nil {
+	chid, ok := t.requestIDToChannelID.load(request.ID())
+	if !ok {
 		return
 	}
 
-	// A graphsync request is made when either
-	// - The local node opens a data-transfer pull channel, so the local node
-	//   sends a graphsync request to ask the remote peer for the data
-	// - The remote peer opened a data-transfer push channel, and in response
-	//   the local node sends a graphsync request to ask for the data
-	var initiator peer.ID
-	var responder peer.ID
-	if message.IsRequest() {
-		// This is a pull request so the data-transfer initiator is the local node
-		initiator = t.peerID
-		responder = p
-	} else {
-		// This is a push response so the data-transfer initiator is the remote
-		// peer: They opened the push channel, we respond by sending a
-		// graphsync request for the data
-		initiator = p
-		responder = t.peerID
-	}
-	chid := datatransfer.ChannelID{Initiator: initiator, Responder: responder, ID: message.TransferID()}
+	// Start tracking the channel if we're not already
+	ch, err := t.getDTChannel(chid)
 
-	// A data transfer channel was opened
-	err := t.events.OnChannelOpened(chid)
 	if err != nil {
 		// There was an error opening the channel, bail out
 		log.Errorf("processing OnChannelOpened for %s: %s", chid, err)
@@ -49,8 +33,8 @@ func (t *Transport) gsOutgoingRequestHook(p peer.ID, request graphsync.RequestDa
 		return
 	}
 
-	// Start tracking the channel if we're not already
-	ch := t.trackDTChannel(chid)
+	// A data transfer channel was opened
+	t.events.OnTransportEvent(chid, datatransfer.TransportOpenedChannel{})
 
 	// Signal that the channel has been opened
 	ch.GsReqOpened(p, request.ID(), hookActions)
@@ -69,49 +53,36 @@ func (t *Transport) gsIncomingBlockHook(p peer.ID, response graphsync.ResponseDa
 		return
 	}
 
-	ch.UpdateReceivedCidsIfGreater(block.Index())
+	if ch.UpdateReceivedIndexIfGreater(block.Index()) && block.BlockSizeOnWire() != 0 {
 
-	err = t.events.OnDataReceived(chid, block.Link(), block.BlockSize(), block.Index(), block.BlockSizeOnWire() != 0)
-	if err != nil && err != datatransfer.ErrPause {
-		hookActions.TerminateWithError(err)
-		return
-	}
+		t.events.OnTransportEvent(chid, datatransfer.TransportReceivedData{Size: block.BlockSize(), Index: basicnode.NewInt(block.Index())})
 
-	if err == datatransfer.ErrPause {
-		ch.MarkPaused()
-		hookActions.PauseRequest()
+		if ch.UpdateProgress(block.BlockSizeOnWire()) {
+			t.events.OnTransportEvent(chid, datatransfer.TransportReachedDataLimit{})
+			hookActions.PauseRequest()
+		}
 	}
+
 }
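+
+// The block hooks in this file now share one rule: each channel keeps a
+// monotonic index per direction, Update*IfGreater reports true only when the
+// index advances, and an event fires only when it advances AND the block
+// actually crossed the wire (BlockSizeOnWire() != 0). Restart replays and
+// repeated DAG visits therefore advance indexes without emitting events.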
-func (t *Transport) gsBlockSentHook(p peer.ID, request graphsync.RequestData, block graphsync.BlockData) {
-	// When a data transfer is restarted, the requester sends a list of CIDs
-	// that it already has. Graphsync calls the sent hook for all blocks even
-	// if they are in the list (meaning, they aren't actually sent over the
-	// wire). So here we check if the block was actually sent
-	// over the wire before firing the data sent event.
-	if block.BlockSizeOnWire() == 0 {
+func (t *Transport) gsBlockSentListener(p peer.ID, request graphsync.RequestData, block graphsync.BlockData) {
+	chid, ok := t.requestIDToChannelID.load(request.ID())
+	if !ok {
 		return
 	}
 
-	chid, ok := t.requestIDToChannelID.load(request.ID())
-	if !ok {
+	ch, err := t.getDTChannel(chid)
+	if err != nil {
+		log.Errorf("sent hook error: %s, for channel %s", err, chid)
 		return
 	}
 
-	if err := t.events.OnDataSent(chid, block.Link(), block.BlockSize(), block.Index(), block.BlockSizeOnWire() != 0); err != nil {
-		log.Errorf("failed to process data sent: %+v", err)
+	if ch.UpdateSentIndexIfGreater(block.Index()) && block.BlockSizeOnWire() != 0 {
+		t.events.OnTransportEvent(chid, datatransfer.TransportSentData{Size: block.BlockSize(), Index: basicnode.NewInt(block.Index())})
 	}
 }
 
 func (t *Transport) gsOutgoingBlockHook(p peer.ID, request graphsync.RequestData, block graphsync.BlockData, hookActions graphsync.OutgoingBlockHookActions) {
-	// When a data transfer is restarted, the requester sends a list of CIDs
-	// that it already has. Graphsync calls the outgoing block hook for all
-	// blocks even if they are in the list (meaning, they aren't actually going
-	// to be sent over the wire). So here we check if the block is actually
-	// going to be sent over the wire before firing the data queued event.
-	if block.BlockSizeOnWire() == 0 {
-		return
-	}
 
 	chid, ok := t.requestIDToChannelID.load(request.ID())
 	if !ok {
@@ -124,75 +95,25 @@ func (t *Transport) gsOutgoingBlockHook(p peer.ID, request graphsync.RequestData
 		return
 	}
 
-	// OnDataQueued is called when a block is queued to be sent to the remote
-	// peer. It can return ErrPause to pause the response (eg if payment is
-	// required) and it can return a message that will be sent with the block
-	// (eg to ask for payment).
-	msg, err := t.events.OnDataQueued(chid, block.Link(), block.BlockSize(), block.Index(), block.BlockSizeOnWire() != 0)
-	if err != nil && err != datatransfer.ErrPause {
-		hookActions.TerminateWithError(err)
-		return
-	}
+	if ch.UpdateQueuedIndexIfGreater(block.Index()) && block.BlockSizeOnWire() != 0 {
+		t.events.OnTransportEvent(chid, datatransfer.TransportQueuedData{Size: block.BlockSize(), Index: basicnode.NewInt(block.Index())})
 
-	if err == datatransfer.ErrPause {
-		ch.MarkPaused()
-		hookActions.PauseResponse()
-	}
-
-	if msg != nil {
-		// gsOutgoingBlockHook uses a unique extension name so it can be attached with data from a different hook
-		// outgoingBlkExtensions also includes the default extension name so it remains compatible with all data-transfer protocol versions out there
-		extensions, err := extension.ToExtensionData(msg, outgoingBlkExtensions)
-		if err != nil {
-			hookActions.TerminateWithError(err)
-			return
-		}
-		for _, extension := range extensions {
-			hookActions.SendExtensionData(extension)
+		if ch.UpdateProgress(block.BlockSizeOnWire()) {
+			t.events.OnTransportEvent(chid, datatransfer.TransportReachedDataLimit{})
+			hookActions.PauseResponse()
 		}
 	}
 }
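+
+// Note the pause symmetry: when the data limit is reached, the receive side
+// (gsIncomingBlockHook above) calls hookActions.PauseRequest(), while the
+// send side (gsOutgoingBlockHook) calls hookActions.PauseResponse(); both
+// surface to the handler as the same datatransfer.TransportReachedDataLimit{}.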
 
 // gsReqQueuedHook is called when graphsync enqueues an incoming request for data
-func (t *Transport) gsReqQueuedHook(p peer.ID, request graphsync.RequestData, hookActions graphsync.RequestQueuedHookActions) {
-	msg, err := extension.GetTransferData(request, t.supportedExtensions)
-	if err != nil {
-		log.Errorf("failed GetTransferData, req=%+v, err=%s", request, err)
-	}
-	// extension not found; probably not our request.
-	if msg == nil {
+func (t *Transport) gsRequestProcessingListener(p peer.ID, request graphsync.RequestData, requestCount int) {
+	chid, ok := t.requestIDToChannelID.load(request.ID())
+	if !ok {
 		return
 	}
 
-	var chid datatransfer.ChannelID
-	if msg.IsRequest() {
-		// when a data transfer request comes in on graphsync, the remote peer
-		// initiated a pull
-		chid = datatransfer.ChannelID{ID: msg.TransferID(), Initiator: p, Responder: t.peerID}
-		dtRequest := msg.(datatransfer.Request)
-		if dtRequest.IsNew() {
-			log.Infof("%s, pull request queued, req_id=%d", chid, request.ID())
-			t.events.OnTransferQueued(chid)
-		} else {
-			log.Infof("%s, pull restart request queued, req_id=%d", chid, request.ID())
-		}
-	} else {
-		// when a data transfer response comes in on graphsync, this node
-		// initiated a push, and the remote peer responded with a request
-		// for data
-		chid = datatransfer.ChannelID{ID: msg.TransferID(), Initiator: t.peerID, Responder: p}
-		response := msg.(datatransfer.Response)
-		if response.IsNew() {
-			log.Infof("%s, GS pull request queued in response to our push, req_id=%d", chid, request.ID())
-			t.events.OnTransferQueued(chid)
-		} else {
-			log.Infof("%s, GS pull request queued in response to our restart push, req_id=%d", chid, request.ID())
-		}
-	}
-	augmentContext := t.events.OnContextAugment(chid)
-	if augmentContext != nil {
-		hookActions.AugmentContext(augmentContext)
-	}
+	t.events.OnTransportEvent(chid, datatransfer.TransportInitiatedTransfer{})
 }
 
 // gsReqRecdHook is called when graphsync receives an incoming request for data
@@ -225,7 +146,21 @@ func (t *Transport) gsReqRecdHook(p peer.ID, request graphsync.RequestData, hook
 		log.Debugf("%s: received request for data (pull), req_id=%d", chid, request.ID())
 		request := msg.(datatransfer.Request)
+
+		// graphsync never receives dt push requests as new graphsync requests -- if so, we should error
+		isNewOrRestart := (request.IsNew() || request.IsRestart())
+		if isNewOrRestart && !request.IsPull() {
+			hookActions.TerminateWithError(datatransfer.ErrUnsupported)
+			return
+		}
+
 		responseMessage, err = t.events.OnRequestReceived(chid, request)
+
+		// if we're accepting this new/restart request, protect the connection
+		if isNewOrRestart && err == nil {
+			t.dtNet.Protect(p, chid.String())
+		}
+
 	} else {
 		// when a data transfer response comes in on graphsync, this node
 		// initiated a push, and the remote peer responded with a request
@@ -253,26 +188,30 @@ func (t *Transport) gsReqRecdHook(p peer.ID, request graphsync.RequestData, hook
 		}
 	}
 
-	if err != nil && err != datatransfer.ErrPause {
+	if err != nil {
 		hookActions.TerminateWithError(err)
 		return
 	}
 
-	// Check if the callback indicated that the channel should be paused
-	// immediately (eg because data is still being unsealed)
-	paused := false
-	if err == datatransfer.ErrPause {
-		log.Debugf("%s: pausing graphsync response", chid)
+	hookActions.AugmentContext(t.events.OnContextAugment(chid))
 
-		paused = true
-		hookActions.PauseResponse()
+	chst, err := t.events.ChannelState(context.TODO(), chid)
+	if err != nil {
+		hookActions.TerminateWithError(err)
+		return
 	}
 
-	ch := t.trackDTChannel(chid)
-	t.requestIDToChannelID.set(request.ID(), true, chid)
-	ch.GsDataRequestRcvd(p, request.ID(), paused, hookActions)
-
-	hookActions.ValidateRequest()
+	var ch *dtchannel.Channel
+	if msg.IsRequest() {
+		ch = t.trackDTChannel(chid)
+	} else {
+		ch, err = t.getDTChannel(chid)
+		if err != nil {
+			hookActions.TerminateWithError(err)
+			return
+		}
+	}
+	t.requestIDToChannelID.set(request.ID(), true, chid)
+	ch.GsDataRequestRcvd(p, request.ID(), chst, hookActions)
 }
 
 // gsCompletedResponseListener is a graphsync.OnCompletedResponseListener. We use it to learn when the data transfer is complete
@@ -293,9 +232,11 @@ func (t *Transport) gsCompletedResponseListener(p peer.ID, request graphsync.Req
 	}
 
 	ch.MarkTransferComplete()
 
-	var completeErr error
-	if status != graphsync.RequestCompletedFull {
-		completeErr = xerrors.Errorf("graphsync response to peer %s did not complete: response status code %s", p, status.String())
+	var completeEvent datatransfer.TransportCompletedTransfer
+	if status == graphsync.RequestCompletedFull {
+		completeEvent.Success = true
+	} else {
+		completeEvent.ErrorMessage = fmt.Sprintf("graphsync response to peer %s did not complete: response status code %s", p, status.String())
 	}
 
 	// Used by the tests to listen for when a response completes
@@ -303,35 +244,7 @@ func (t *Transport) gsCompletedResponseListener(p peer.ID, request graphsync.Req
 		t.completedResponseListener(chid)
 	}
 
-	err = t.events.OnChannelCompleted(chid, completeErr)
-	if err != nil {
-		log.Error(err)
-	}
-
-}
-
-func (t *Transport) gsRequestUpdatedHook(p peer.ID, request graphsync.RequestData, update graphsync.RequestData, hookActions graphsync.RequestUpdatedHookActions) {
-	chid, ok := t.requestIDToChannelID.load(request.ID())
-	if !ok {
-		return
-	}
-
-	responseMessage, err := t.processExtension(chid, update, p, t.supportedExtensions)
-
-	if responseMessage != nil {
-		extensions, extensionErr := extension.ToExtensionData(responseMessage, t.supportedExtensions)
-		if extensionErr != nil {
-			hookActions.TerminateWithError(err)
-			return
-		}
-		for _, extension := range extensions {
-			hookActions.SendExtensionData(extension)
-		}
-	}
-
-	if err != nil && err != datatransfer.ErrPause {
-		hookActions.TerminateWithError(err)
-	}
+	t.events.OnTransportEvent(chid, completeEvent)
 }
 
@@ -344,14 +257,7 @@ func (t *Transport) gsIncomingResponseHook(p peer.ID, response graphsync.Respons
 	responseMessage, err := t.processExtension(chid, response, p, incomingReqExtensions)
 
 	if responseMessage != nil {
-		extensions, extensionErr := extension.ToExtensionData(responseMessage, t.supportedExtensions)
-		if extensionErr != nil {
-			hookActions.TerminateWithError(err)
-			return
-		}
-		for _, extension := range extensions {
-			hookActions.UpdateRequestWithExtensions(extension)
-		}
+		t.dtNet.SendMessage(context.TODO(), p, transportID, responseMessage)
 	}
 
 	if err != nil {
@@ -361,7 +267,11 @@ func (t *Transport) gsIncomingResponseHook(p peer.ID, response graphsync.Respons
 	// In a case where the transfer sends blocks immediately this extension may contain both a
 	// response message and a revalidation request so we trigger OnResponseReceived again for this
 	// specific extension name
-	_, err = t.processExtension(chid, response, p, []graphsync.ExtensionName{extension.ExtensionOutgoingBlock1_1})
+	responseMessage, err = t.processExtension(chid, response, p, outgoingBlkExtensions)
+
+	if responseMessage != nil {
+		t.dtNet.SendMessage(context.TODO(), p, transportID, responseMessage)
+	}
 
 	if err != nil {
 		hookActions.TerminateWithError(err)
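+
+// Replies produced while processing a graphsync response no longer ride back
+// to the peer as graphsync extension updates: they are sent out-of-band as
+// data-transfer network messages via t.dtNet.SendMessage, tagged with this
+// transport's ID so the remote side can route them.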
@@ -375,14 +285,12 @@ func (t *Transport) processExtension(chid datatransfer.ChannelID, gsMsg extensio
 	if err != nil {
 		return nil, err
 	}
-	// extension not found; probably not our request.
 	if msg == nil {
 		return nil, nil
 	}
 
 	if msg.IsRequest() {
-
-		// only accept request message updates when original message was also request
 		if (chid != datatransfer.ChannelID{ID: msg.TransferID(), Initiator: p, Responder: t.peerID}) {
 			return nil, errors.New("received request on response channel")
@@ -397,6 +305,7 @@ func (t *Transport) processExtension(chid datatransfer.ChannelID, gsMsg extensio
 	}
 
 	dtResponse := msg.(datatransfer.Response)
+
 	return nil, t.events.OnResponseReceived(chid, dtResponse)
 }
 
@@ -426,10 +335,7 @@ func (t *Transport) gsNetworkSendErrorListener(p peer.ID, request graphsync.Requ
 		return
 	}
 
-	err := t.events.OnSendDataError(chid, gserr)
-	if err != nil {
-		log.Errorf("failed to fire transport send error %s: %s", gserr, err)
-	}
+	t.events.OnTransportEvent(chid, datatransfer.TransportErrorSendingData{ErrorMessage: gserr.Error()})
 }
 
 // Called when there is a graphsync error receiving data
@@ -441,9 +347,6 @@ func (t *Transport) gsNetworkReceiveErrorListener(p peer.ID, gserr error) {
 		return
 	}
 
-		err := t.events.OnReceiveDataError(chid, gserr)
-		if err != nil {
-			log.Errorf("failed to fire transport receive error %s: %s", gserr, err)
-		}
+		t.events.OnTransportEvent(chid, datatransfer.TransportErrorReceivingData{ErrorMessage: gserr.Error()})
 	})
 }
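Taken together, the hooks above now talk to the data-transfer layer exclusively through `OnTransportEvent`. As a sketch of what a consumer of these events might look like: the handler shape and the `TransportEvent` interface name are assumptions for illustration; only the event structs and their fields come from this change set.

```go
// handleTransportEvent is a hypothetical consumer of the events emitted by
// the rewritten hooks; the switch itself is illustrative, not library API.
func handleTransportEvent(chid datatransfer.ChannelID, evt datatransfer.TransportEvent) {
	switch e := evt.(type) {
	case datatransfer.TransportOpenedChannel:
		// fired from gsOutgoingRequestHook once the channel is tracked
	case datatransfer.TransportInitiatedTransfer:
		// fired from gsRequestProcessingListener when graphsync starts work
	case datatransfer.TransportReceivedData:
		_ = e.Size  // bytes on the wire for this block
		_ = e.Index // transport-specific position, carried as an IPLD node
	case datatransfer.TransportReachedDataLimit:
		// the transport has already paused itself; record the pause
	case datatransfer.TransportCompletedTransfer:
		if !e.Success {
			log.Errorf("%s failed: %s", chid, e.ErrorMessage)
		}
	}
}
```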
datatransfer.TransportOpenedChannel{}) + }) + t.Run("receives outgoing processing listener", func(t *testing.T) { + th.OutgoingRequestProcessingListener(request) + th.Events.AssertTransportEvent(t, th.Channel.ChannelID(), datatransfer.TransportInitiatedTransfer{}) + }) + dtResponse := th.Response() + response := receivedRequest.Response(t, dtResponse, nil, graphsync.PartialResponse) + t.Run("receives response", func(t *testing.T) { + th.IncomingResponseHook(response) + require.Equal(t, th.Events.ReceivedResponse, dtResponse) + }) + + t.Run("received block", func(t *testing.T) { + block := testharness.NewFakeBlockData(12345, 1, true) + th.IncomingBlockHook(response, block) + th.Events.AssertTransportEvent(t, th.Channel.ChannelID(), datatransfer.TransportReceivedData{Size: block.BlockSize(), Index: basicnode.NewInt(block.Index())}) + block = testharness.NewFakeBlockData(12345, 2, true) + th.IncomingBlockHook(response, block) + th.Events.AssertTransportEvent(t, th.Channel.ChannelID(), datatransfer.TransportReceivedData{Size: block.BlockSize(), Index: basicnode.NewInt(block.Index())}) + // block not on wire has no effect + block = testharness.NewFakeBlockData(12345, 3, false) + th.IncomingBlockHook(response, block) + th.Events.RefuteTransportEvent(t, th.Channel.ChannelID(), datatransfer.TransportReceivedData{Size: block.BlockSize(), Index: basicnode.NewInt(block.Index())}) + // block with lower index has no effect + block = testharness.NewFakeBlockData(67890, 1, true) + th.IncomingBlockHook(response, block) + th.Events.RefuteTransportEvent(t, th.Channel.ChannelID(), datatransfer.TransportReceivedData{Size: block.BlockSize(), Index: basicnode.NewInt(block.Index())}) + }) + + t.Run("receive pause", func(t *testing.T) { + dtPauseResponse := th.UpdateResponse(true) + pauseResponse := receivedRequest.Response(t, nil, dtPauseResponse, graphsync.RequestPaused) + th.IncomingResponseHook(pauseResponse) + require.Equal(t, th.Events.ReceivedResponse, dtPauseResponse) + }) + + t.Run("send update", func(t *testing.T) { + vRequest := th.VoucherRequest() + th.Transport.SendMessage(ctx, th.Channel.ChannelID(), vRequest) + th.DtNet.AssertSentMessage(t, testharness.FakeSentMessage{PeerID: th.Channel.OtherPeer(), TransportID: "graphsync", Message: vRequest}) + }) + + t.Run("receive resume", func(t *testing.T) { + dtResumeResponse := th.UpdateResponse(false) + resumeResponse := receivedRequest.Response(t, nil, dtResumeResponse, graphsync.PartialResponse) + th.IncomingResponseHook(resumeResponse) + require.Equal(t, th.Events.ReceivedResponse, dtResumeResponse) + }) + + t.Run("pause", func(t *testing.T) { + th.Channel.SetInitiatorPaused(true) + th.Transport.ChannelUpdated(ctx, th.Channel.ChannelID(), th.UpdateRequest(true)) + th.DtNet.AssertSentMessage(t, testharness.FakeSentMessage{PeerID: th.Channel.OtherPeer(), TransportID: "graphsync", Message: th.UpdateRequest(true)}) + require.Len(t, th.Fgs.Pauses, 1) + require.Equal(t, th.Fgs.Pauses[0], request.ID()) + }) + t.Run("pause again", func(t *testing.T) { + th.Transport.ChannelUpdated(ctx, th.Channel.ChannelID(), th.UpdateRequest(true)) + // should send message again + th.DtNet.AssertSentMessage(t, testharness.FakeSentMessage{PeerID: th.Channel.OtherPeer(), TransportID: "graphsync", Message: th.UpdateRequest(true)}) + // should not pause again + require.Len(t, th.Fgs.Pauses, 1) + }) + t.Run("resume", func(t *testing.T) { + th.Channel.SetInitiatorPaused(false) + th.Transport.ChannelUpdated(ctx, th.Channel.ChannelID(), th.UpdateRequest(false)) + require.Len(t, 
th.Fgs.Resumes, 1) + resume := th.Fgs.Resumes[0] + require.Equal(t, request.ID(), resume.RequestID) + msg := resume.DTMessage(t) + require.Equal(t, msg, th.UpdateRequest(false)) + }) + t.Run("resume again", func(t *testing.T) { + th.Transport.ChannelUpdated(ctx, th.Channel.ChannelID(), th.UpdateRequest(false)) + // should send message again + th.DtNet.AssertSentMessage(t, testharness.FakeSentMessage{PeerID: th.Channel.OtherPeer(), TransportID: "graphsync", Message: th.UpdateRequest(false)}) + // should not resume again + require.Len(t, th.Fgs.Resumes, 1) + }) + + t.Run("restart request", func(t *testing.T) { + restartIndex := int64(5) + th.Channel.SetReceivedIndex(basicnode.NewInt(restartIndex)) + err := th.Transport.RestartChannel(ctx, th.Channel, th.RestartRequest(t)) + require.NoError(t, err) + require.Len(t, th.DtNet.ProtectedPeers, 2) + require.Equal(t, th.DtNet.ProtectedPeers[1], testharness.TaggedPeer{th.Channel.OtherPeer(), th.Channel.ChannelID().String()}) + require.Len(t, th.DtNet.ConnectWithRetryAttempts, 1) + require.Equal(t, th.DtNet.ConnectWithRetryAttempts[0], testharness.ConnectWithRetryAttempt{th.Channel.OtherPeer(), "graphsync"}) + require.Len(t, th.Fgs.Cancels, 1) + th.Events.AssertTransportEvent(t, th.Channel.ChannelID(), datatransfer.TransportTransferCancelled{ErrorMessage: "graphsync request cancelled"}) + require.Equal(t, request.ID(), th.Fgs.Cancels[0]) + require.Len(t, th.Fgs.ReceivedRequests, 2) + receivedRequest = th.Fgs.ReceivedRequests[1] + request = receivedRequest.ToRequestData(t) + msg, err := extension.GetTransferData(request, []graphsync.ExtensionName{ + extension.ExtensionDataTransfer1_1, + }) + require.NoError(t, err) + require.Equal(t, th.RestartRequest(t), msg) + nd, has := request.Extension(graphsync.ExtensionsDoNotSendFirstBlocks) + require.True(t, has) + val, err := nd.AsInt() + require.NoError(t, err) + require.Equal(t, restartIndex, val) + }) + + t.Run("complete request", func(t *testing.T) { + close(receivedRequest.ResponseChan) + close(receivedRequest.ResponseErrChan) + select { + case <-th.CompletedRequests: + case <-ctx.Done(): + t.Fatalf("did not complete request") + } + th.Events.AssertTransportEventEventually(t, th.Channel.ChannelID(), datatransfer.TransportCompletedTransfer{Success: true}) + }) + + t.Run("cleanup request", func(t *testing.T) { + th.Transport.CleanupChannel(th.Channel.ChannelID()) + require.Len(t, th.DtNet.UnprotectedPeers, 1) + require.Equal(t, th.DtNet.UnprotectedPeers[0], testharness.TaggedPeer{th.Channel.OtherPeer(), th.Channel.ChannelID().String()}) + }) +} + +type ctxKey struct{} + +func TestInitiatingPushRequestSuccessFlow(t *testing.T) { + ctx := context.Background() + ctx, cancel := context.WithTimeout(ctx, 5*time.Second) + defer cancel() + th := testharness.SetupHarness(ctx) + t.Run("opens successfully", func(t *testing.T) { + err := th.Transport.OpenChannel(th.Ctx, th.Channel, th.NewRequest(t)) + require.NoError(t, err) + require.Len(t, th.DtNet.ProtectedPeers, 1) + require.Equal(t, th.DtNet.ProtectedPeers[0], testharness.TaggedPeer{th.Channel.OtherPeer(), th.Channel.ChannelID().String()}) + th.DtNet.AssertSentMessage(t, testharness.FakeSentMessage{PeerID: th.Channel.OtherPeer(), TransportID: "graphsync", Message: th.NewRequest(t)}) + }) + t.Run("configures persistence", func(t *testing.T) { + th.Transport.UseStore(th.Channel.ChannelID(), cidlink.DefaultLinkSystem()) + th.Fgs.AssertHasPersistenceOption(t, fmt.Sprintf("data-transfer-%s", th.Channel.ChannelID().String())) + }) + dtResponse := th.Response() + 
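// in a push transfer the responder initiates the graphsync request, so we simulate the incoming graphsync request that carries the data transfer response as an extension + 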
requestID := graphsync.NewRequestID() + request := testharness.NewFakeRequest(requestID, map[graphsync.ExtensionName]datamodel.Node{extension.ExtensionDataTransfer1_1: dtResponse.ToIPLD()}, graphsync.RequestTypeNew) + t.Run("receives incoming request hook", func(t *testing.T) { + th.Events.ReturnedOnContextAugmentFunc = func(ctx context.Context) context.Context { + return context.WithValue(ctx, ctxKey{}, "applesauce") + } + th.IncomingRequestHook(request) + require.Equal(t, fmt.Sprintf("data-transfer-%s", th.Channel.ChannelID().String()), th.IncomingRequestHookActions.PersistenceOption) + require.True(t, th.IncomingRequestHookActions.Validated) + require.False(t, th.IncomingRequestHookActions.Paused) + require.NoError(t, th.IncomingRequestHookActions.TerminationError) + th.IncomingRequestHookActions.AssertAugmentedContextKey(t, ctxKey{}, "applesauce") + require.Equal(t, th.Events.ReceivedResponse, dtResponse) + }) + + t.Run("receives incoming processing listener", func(t *testing.T) { + th.IncomingRequestProcessingListener(request) + th.Events.AssertTransportEvent(t, th.Channel.ChannelID(), datatransfer.TransportInitiatedTransfer{}) + }) + + t.Run("queued block", func(t *testing.T) { + block := testharness.NewFakeBlockData(12345, 1, true) + th.OutgoingBlockHook(request, block) + th.Events.AssertTransportEvent(t, th.Channel.ChannelID(), datatransfer.TransportQueuedData{Size: block.BlockSize(), Index: basicnode.NewInt(block.Index())}) + block = testharness.NewFakeBlockData(12345, 2, true) + th.OutgoingBlockHook(request, block) + th.Events.AssertTransportEvent(t, th.Channel.ChannelID(), datatransfer.TransportQueuedData{Size: block.BlockSize(), Index: basicnode.NewInt(block.Index())}) + // block not on wire has no effect + block = testharness.NewFakeBlockData(12345, 3, false) + th.OutgoingBlockHook(request, block) + th.Events.RefuteTransportEvent(t, th.Channel.ChannelID(), datatransfer.TransportQueuedData{Size: block.BlockSize(), Index: basicnode.NewInt(block.Index())}) + // block with lower index has no effect + block = testharness.NewFakeBlockData(67890, 1, true) + th.OutgoingBlockHook(request, block) + th.Events.RefuteTransportEvent(t, th.Channel.ChannelID(), datatransfer.TransportQueuedData{Size: block.BlockSize(), Index: basicnode.NewInt(block.Index())}) + }) + + t.Run("sent block", func(t *testing.T) { + block := testharness.NewFakeBlockData(12345, 1, true) + th.BlockSentListener(request, block) + th.Events.AssertTransportEvent(t, th.Channel.ChannelID(), datatransfer.TransportSentData{Size: block.BlockSize(), Index: basicnode.NewInt(block.Index())}) + block = testharness.NewFakeBlockData(12345, 2, true) + th.BlockSentListener(request, block) + th.Events.AssertTransportEvent(t, th.Channel.ChannelID(), datatransfer.TransportSentData{Size: block.BlockSize(), Index: basicnode.NewInt(block.Index())}) + // block not on wire has no effect + block = testharness.NewFakeBlockData(12345, 3, false) + th.BlockSentListener(request, block) + th.Events.RefuteTransportEvent(t, th.Channel.ChannelID(), datatransfer.TransportSentData{Size: block.BlockSize(), Index: basicnode.NewInt(block.Index())}) + // block with lower index has no effect + block = testharness.NewFakeBlockData(67890, 1, true) + th.BlockSentListener(request, block) + th.Events.RefuteTransportEvent(t, th.Channel.ChannelID(), datatransfer.TransportSentData{Size: block.BlockSize(), Index: basicnode.NewInt(block.Index())}) + }) + + 
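// a remote pause surfaces as a graphsync requestor-cancelled notice, with the pause itself arriving as an update response over the data transfer network + 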
t.Run("receive pause", func(t *testing.T) { + th.RequestorCancelledListener(request) + dtPauseResponse := th.UpdateResponse(true) + th.DtNet.Delegates[0].Receiver.ReceiveResponse(ctx, th.Channel.OtherPeer(), dtPauseResponse) + require.Equal(t, th.Events.ReceivedResponse, dtPauseResponse) + }) + + t.Run("send update", func(t *testing.T) { + vRequest := th.VoucherRequest() + th.Transport.SendMessage(ctx, th.Channel.ChannelID(), vRequest) + th.DtNet.AssertSentMessage(t, testharness.FakeSentMessage{PeerID: th.Channel.OtherPeer(), TransportID: "graphsync", Message: vRequest}) + }) + + t.Run("receive resume", func(t *testing.T) { + dtResumeResponse := th.UpdateResponse(false) + request = testharness.NewFakeRequest(requestID, map[graphsync.ExtensionName]datamodel.Node{extension.ExtensionDataTransfer1_1: dtResumeResponse.ToIPLD()}, graphsync.RequestTypeNew) + // reset hook behavior + th.IncomingRequestHookActions = &testharness.FakeIncomingRequestHookActions{} + th.IncomingRequestHook(request) + require.Equal(t, fmt.Sprintf("data-transfer-%s", th.Channel.ChannelID().String()), th.IncomingRequestHookActions.PersistenceOption) + require.True(t, th.IncomingRequestHookActions.Validated) + require.False(t, th.IncomingRequestHookActions.Paused) + require.NoError(t, th.IncomingRequestHookActions.TerminationError) + th.IncomingRequestHookActions.AssertAugmentedContextKey(t, ctxKey{}, "applesauce") + require.Equal(t, th.Events.ReceivedResponse, dtResumeResponse) + }) + + t.Run("pause", func(t *testing.T) { + th.Channel.SetInitiatorPaused(true) + th.Transport.ChannelUpdated(ctx, th.Channel.ChannelID(), th.UpdateRequest(true)) + th.DtNet.AssertSentMessage(t, testharness.FakeSentMessage{PeerID: th.Channel.OtherPeer(), TransportID: "graphsync", Message: th.UpdateRequest(true)}) + require.Len(t, th.Fgs.Pauses, 1) + require.Equal(t, th.Fgs.Pauses[0], request.ID()) + }) + t.Run("pause again", func(t *testing.T) { + th.Transport.ChannelUpdated(ctx, th.Channel.ChannelID(), th.UpdateRequest(true)) + // should send message again + th.DtNet.AssertSentMessage(t, testharness.FakeSentMessage{PeerID: th.Channel.OtherPeer(), TransportID: "graphsync", Message: th.UpdateRequest(true)}) + // should not pause again + require.Len(t, th.Fgs.Pauses, 1) + }) + t.Run("resume", func(t *testing.T) { + th.Channel.SetInitiatorPaused(false) + th.Transport.ChannelUpdated(ctx, th.Channel.ChannelID(), th.UpdateRequest(false)) + require.Len(t, th.Fgs.Resumes, 1) + resume := th.Fgs.Resumes[0] + require.Equal(t, request.ID(), resume.RequestID) + msg := resume.DTMessage(t) + require.Equal(t, msg, th.UpdateRequest(false)) + }) + t.Run("resume again", func(t *testing.T) { + th.Transport.ChannelUpdated(ctx, th.Channel.ChannelID(), th.UpdateRequest(false)) + // should send message again + th.DtNet.AssertSentMessage(t, testharness.FakeSentMessage{PeerID: th.Channel.OtherPeer(), TransportID: "graphsync", Message: th.UpdateRequest(false)}) + // should not resume again + require.Len(t, th.Fgs.Resumes, 1) + }) + + t.Run("restart request", func(t *testing.T) { + err := th.Transport.RestartChannel(ctx, th.Channel, th.RestartRequest(t)) + require.NoError(t, err) + require.Len(t, th.DtNet.ProtectedPeers, 2) + require.Equal(t, th.DtNet.ProtectedPeers[1], testharness.TaggedPeer{th.Channel.OtherPeer(), th.Channel.ChannelID().String()}) + require.Len(t, th.DtNet.ConnectWithRetryAttempts, 1) + require.Equal(t, th.DtNet.ConnectWithRetryAttempts[0], testharness.ConnectWithRetryAttempt{th.Channel.OtherPeer(), "graphsync"}) + th.DtNet.AssertSentMessage(t, testharness.FakeSentMessage{PeerID: 
th.Channel.OtherPeer(), TransportID: "graphsync", Message: th.NewRequest(t)}) + }) + + t.Run("complete request", func(t *testing.T) { + th.ResponseCompletedListener(request, graphsync.RequestCompletedFull) + select { + case <-th.CompletedResponses: + case <-ctx.Done(): + t.Fatalf("did not complete request") + } + th.Events.AssertTransportEvent(t, th.Channel.ChannelID(), datatransfer.TransportCompletedTransfer{Success: true}) + }) + + t.Run("cleanup request", func(t *testing.T) { + th.Transport.CleanupChannel(th.Channel.ChannelID()) + require.Len(t, th.DtNet.UnprotectedPeers, 1) + require.Equal(t, th.DtNet.UnprotectedPeers[0], testharness.TaggedPeer{th.Channel.OtherPeer(), th.Channel.ChannelID().String()}) + }) +} + +/* "gs outgoing request with recognized dt push channel will record incoming blocks": { + requestConfig: gsRequestConfig{ + dtIsResponse: true, + }, + action: func(gsData *harness) { + gsData.outgoingRequestHook() + gsData.incomingBlockHook() + }, + check: func(t *testing.T, events *fakeEvents, gsData *harness) { + require.Equal(t, events.ChannelOpenedChannelID, datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.self, Initiator: gsData.other}) + require.True(t, events.OnDataReceivedCalled) + require.NoError(t, gsData.incomingBlockHookActions.TerminationError) + }, + }, + "non-data-transfer gs request will not record incoming blocks and send updates": { + requestConfig: gsRequestConfig{ + dtExtensionMissing: true, + }, + action: func(gsData *harness) { + gsData.outgoingRequestHook() + gsData.incomingBlockHook() + }, + check: func(t *testing.T, events *fakeEvents, gsData *harness) { + require.Equal(t, events.ChannelOpenedChannelID, datatransfer.ChannelID{}) + require.False(t, events.OnDataReceivedCalled) + require.NoError(t, gsData.incomingBlockHookActions.TerminationError) + }, + }, + "gs request unrecognized opened channel will not record incoming blocks": { + events: fakeEvents{ + OnChannelOpenedError: errors.New("Not recognized"), + }, + action: func(gsData *harness) { + gsData.outgoingRequestHook() + gsData.incomingBlockHook() + }, + check: func(t *testing.T, events *fakeEvents, gsData *harness) { + require.Equal(t, events.ChannelOpenedChannelID, datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.other, Initiator: gsData.self}) + require.False(t, events.OnDataReceivedCalled) + require.NoError(t, gsData.incomingBlockHookActions.TerminationError) + }, + }, + "gs incoming block with data receive error will halt request": { + events: fakeEvents{ + OnDataReceivedError: errors.New("something went wrong"), + }, + action: func(gsData *harness) { + gsData.outgoingRequestHook() + gsData.incomingBlockHook() + }, + check: func(t *testing.T, events *fakeEvents, gsData *harness) { + require.Equal(t, events.ChannelOpenedChannelID, datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.other, Initiator: gsData.self}) + require.True(t, events.OnDataReceivedCalled) + require.Error(t, gsData.incomingBlockHookActions.TerminationError) + }, + }, + "outgoing gs request with recognized dt request can receive gs response": { + responseConfig: gsResponseConfig{ + dtIsResponse: true, + }, + action: func(gsData *harness) { + gsData.outgoingRequestHook() + gsData.incomingResponseHOok() + }, + check: func(t *testing.T, events *fakeEvents, gsData *harness) { + require.Equal(t, events.ChannelOpenedChannelID, datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.other, Initiator: gsData.self}) + require.Equal(t, 0, events.OnRequestReceivedCallCount) + 
require.Equal(t, 1, events.OnResponseReceivedCallCount) + require.NoError(t, gsData.incomingResponseHookActions.TerminationError) + }, + }, + "outgoing gs request with recognized dt request cannot receive gs response with dt request": { + action: func(gsData *harness) { + gsData.outgoingRequestHook() + gsData.incomingResponseHOok() + }, + check: func(t *testing.T, events *fakeEvents, gsData *harness) { + require.Equal(t, events.ChannelOpenedChannelID, datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.other, Initiator: gsData.self}) + require.Equal(t, 0, events.OnRequestReceivedCallCount) + require.Equal(t, 0, events.OnResponseReceivedCallCount) + require.Error(t, gsData.incomingResponseHookActions.TerminationError) + }, + }, + "outgoing gs request with recognized dt response can receive gs response": { + requestConfig: gsRequestConfig{ + dtIsResponse: true, + }, + action: func(gsData *harness) { + gsData.outgoingRequestHook() + gsData.incomingResponseHOok() + }, + check: func(t *testing.T, events *fakeEvents, gsData *harness) { + require.Equal(t, events.ChannelOpenedChannelID, datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.self, Initiator: gsData.other}) + require.Equal(t, 1, events.OnRequestReceivedCallCount) + require.Equal(t, 0, events.OnResponseReceivedCallCount) + require.NoError(t, gsData.incomingResponseHookActions.TerminationError) + }, + }, + "outgoing gs request with recognized dt response cannot receive gs response with dt response": { + requestConfig: gsRequestConfig{ + dtIsResponse: true, + }, + responseConfig: gsResponseConfig{ + dtIsResponse: true, + }, + action: func(gsData *harness) { + gsData.outgoingRequestHook() + gsData.incomingResponseHOok() + }, + check: func(t *testing.T, events *fakeEvents, gsData *harness) { + require.Equal(t, events.ChannelOpenedChannelID, datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.self, Initiator: gsData.other}) + require.Equal(t, 0, events.OnResponseReceivedCallCount) + require.Equal(t, 0, events.OnRequestReceivedCallCount) + require.Error(t, gsData.incomingResponseHookActions.TerminationError) + }, + }, + "outgoing gs request with recognized dt request will error with malformed update": { + responseConfig: gsResponseConfig{ + dtExtensionMalformed: true, + }, + action: func(gsData *harness) { + gsData.outgoingRequestHook() + gsData.incomingResponseHOok() + }, + check: func(t *testing.T, events *fakeEvents, gsData *harness) { + require.Equal(t, events.ChannelOpenedChannelID, datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.other, Initiator: gsData.self}) + require.Equal(t, 0, events.OnRequestReceivedCallCount) + require.Equal(t, 0, events.OnResponseReceivedCallCount) + require.Error(t, gsData.incomingResponseHookActions.TerminationError) + }, + }, + "outgoing gs request with recognized dt request will ignore non-data-transfer update": { + responseConfig: gsResponseConfig{ + dtExtensionMissing: true, + }, + action: func(gsData *harness) { + gsData.outgoingRequestHook() + gsData.incomingResponseHOok() + }, + check: func(t *testing.T, events *fakeEvents, gsData *harness) { + require.Equal(t, events.ChannelOpenedChannelID, datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.other, Initiator: gsData.self}) + require.Equal(t, 0, events.OnRequestReceivedCallCount) + require.Equal(t, 0, events.OnResponseReceivedCallCount) + require.NoError(t, gsData.incomingResponseHookActions.TerminationError) + }, + }, + "outgoing gs request with recognized dt response can send message on 
update": { + events: fakeEvents{ + RequestReceivedResponse: testutil.NewDTResponse(t, datatransfer.TransferID(rand.Uint32())), + }, + requestConfig: gsRequestConfig{ + dtIsResponse: true, + }, + action: func(gsData *harness) { + gsData.outgoingRequestHook() + gsData.incomingResponseHOok() + }, + check: func(t *testing.T, events *fakeEvents, gsData *harness) { + require.Equal(t, events.ChannelOpenedChannelID, datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.self, Initiator: gsData.other}) + require.Equal(t, 1, events.OnRequestReceivedCallCount) + require.Equal(t, 0, events.OnResponseReceivedCallCount) + require.NoError(t, gsData.incomingResponseHookActions.TerminationError) + assertHasOutgoingMessage(t, gsData.incomingResponseHookActions.SentExtensions, + events.RequestReceivedResponse) + }, + }, + "outgoing gs request with recognized dt response err will error": { + requestConfig: gsRequestConfig{ + dtIsResponse: true, + }, + events: fakeEvents{ + OnRequestReceivedErrors: []error{errors.New("something went wrong")}, + }, + action: func(gsData *harness) { + gsData.outgoingRequestHook() + gsData.incomingResponseHOok() + }, + check: func(t *testing.T, events *fakeEvents, gsData *harness) { + require.Equal(t, 0, events.OnResponseReceivedCallCount) + require.Equal(t, 1, events.OnRequestReceivedCallCount) + require.Error(t, gsData.incomingResponseHookActions.TerminationError) + }, + }, + "incoming gs request with recognized dt request will validate gs request & send dt response": { + action: func(gsData *harness) { + gsData.incomingRequestHook() + }, + events: fakeEvents{ + RequestReceivedResponse: testutil.NewDTResponse(t, datatransfer.TransferID(rand.Uint32())), + }, + check: func(t *testing.T, events *fakeEvents, gsData *harness) { + require.Equal(t, 1, events.OnRequestReceivedCallCount) + require.Equal(t, 0, events.OnResponseReceivedCallCount) + require.Equal(t, events.RequestReceivedChannelID, datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.self, Initiator: gsData.other}) + dtRequestData, _ := gsData.request.Extension(extension.ExtensionDataTransfer1_1) + assertDecodesToMessage(t, dtRequestData, events.RequestReceivedRequest) + require.True(t, gsData.incomingRequestHookActions.Validated) + assertHasExtensionMessage(t, extension.ExtensionDataTransfer1_1, gsData.incomingRequestHookActions.SentExtensions, events.RequestReceivedResponse) + require.NoError(t, gsData.incomingRequestHookActions.TerminationError) + + channelsForPeer := gsData.transport.ChannelsForPeer(gsData.other) + require.Equal(t, channelsForPeer, ChannelsForPeer{ + SendingChannels: map[datatransfer.ChannelID]ChannelGraphsyncRequests{ + events.RequestReceivedChannelID: { + Current: gsData.request.ID(), + }, + }, + ReceivingChannels: map[datatransfer.ChannelID]ChannelGraphsyncRequests{}, + }) + }, + }, + "incoming gs request with recognized dt response will validate gs request": { + requestConfig: gsRequestConfig{ + dtIsResponse: true, + }, + action: func(gsData *harness) { + gsData.incomingRequestHook() + }, + check: func(t *testing.T, events *fakeEvents, gsData *harness) { + require.Equal(t, 0, events.OnRequestReceivedCallCount) + require.Equal(t, 1, events.OnResponseReceivedCallCount) + require.Equal(t, events.ResponseReceivedChannelID, datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.other, Initiator: gsData.self}) + dtResponseData, _ := gsData.request.Extension(extension.ExtensionDataTransfer1_1) + assertDecodesToMessage(t, dtResponseData, events.ResponseReceivedResponse) + 
require.True(t, gsData.incomingRequestHookActions.Validated) + require.NoError(t, gsData.incomingRequestHookActions.TerminationError) + }, + }, + "malformed data transfer extension on incoming request will terminate": { + requestConfig: gsRequestConfig{ + dtExtensionMalformed: true, + }, + action: func(gsData *harness) { + gsData.incomingRequestHook() + }, + check: func(t *testing.T, events *fakeEvents, gsData *harness) { + require.Equal(t, 0, events.OnRequestReceivedCallCount) + require.False(t, gsData.incomingRequestHookActions.Validated) + require.Error(t, gsData.incomingRequestHookActions.TerminationError) + }, + }, + "unrecognized incoming dt request will terminate but send response": { + events: fakeEvents{ + RequestReceivedResponse: testutil.NewDTResponse(t, datatransfer.TransferID(rand.Uint32())), + OnRequestReceivedErrors: []error{errors.New("something went wrong")}, + }, + action: func(gsData *harness) { + gsData.incomingRequestHook() + }, + check: func(t *testing.T, events *fakeEvents, gsData *harness) { + require.Equal(t, 1, events.OnRequestReceivedCallCount) + require.Equal(t, 0, events.OnResponseReceivedCallCount) + require.Equal(t, events.RequestReceivedChannelID, datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.self, Initiator: gsData.other}) + dtRequestData, _ := gsData.request.Extension(extension.ExtensionDataTransfer1_1) + assertDecodesToMessage(t, dtRequestData, events.RequestReceivedRequest) + require.False(t, gsData.incomingRequestHookActions.Validated) + assertHasExtensionMessage(t, extension.ExtensionIncomingRequest1_1, gsData.incomingRequestHookActions.SentExtensions, events.RequestReceivedResponse) + require.Error(t, gsData.incomingRequestHookActions.TerminationError) + }, + }, + "incoming gs request with recognized dt request will record outgoing blocks": { + action: func(gsData *harness) { + gsData.incomingRequestHook() + gsData.outgoingBlockHook() + }, + check: func(t *testing.T, events *fakeEvents, gsData *harness) { + require.Equal(t, 1, events.OnRequestReceivedCallCount) + require.True(t, events.OnDataQueuedCalled) + require.NoError(t, gsData.outgoingBlockHookActions.TerminationError) + }, + }, + + "incoming gs request with recognized dt response will record outgoing blocks": { + requestConfig: gsRequestConfig{ + dtIsResponse: true, + }, + action: func(gsData *harness) { + gsData.incomingRequestHook() + gsData.outgoingBlockHook() + }, + check: func(t *testing.T, events *fakeEvents, gsData *harness) { + require.Equal(t, 1, events.OnResponseReceivedCallCount) + require.True(t, events.OnDataQueuedCalled) + require.NoError(t, gsData.outgoingBlockHookActions.TerminationError) + }, + }, + "non-data-transfer request will not record outgoing blocks": { + requestConfig: gsRequestConfig{ + dtExtensionMissing: true, + }, + action: func(gsData *harness) { + gsData.incomingRequestHook() + gsData.outgoingBlockHook() + }, + check: func(t *testing.T, events *fakeEvents, gsData *harness) { + require.Equal(t, 0, events.OnRequestReceivedCallCount) + require.False(t, events.OnDataQueuedCalled) + }, + }, + "outgoing data queued error will terminate request": { + events: fakeEvents{ + OnDataQueuedError: errors.New("something went wrong"), + }, + action: func(gsData *harness) { + gsData.incomingRequestHook() + gsData.outgoingBlockHook() + }, + check: func(t *testing.T, events *fakeEvents, gsData *harness) { + require.Equal(t, 1, events.OnRequestReceivedCallCount) + require.True(t, events.OnDataQueuedCalled) + require.Error(t, 
gsData.outgoingBlockHookActions.TerminationError) + }, + }, + "outgoing data queued error == pause will pause request": { + events: fakeEvents{ + OnDataQueuedError: datatransfer.ErrPause, + }, + action: func(gsData *harness) { + gsData.incomingRequestHook() + gsData.outgoingBlockHook() + }, + check: func(t *testing.T, events *fakeEvents, gsData *harness) { + require.Equal(t, 1, events.OnRequestReceivedCallCount) + require.True(t, events.OnDataQueuedCalled) + require.True(t, gsData.outgoingBlockHookActions.Paused) + require.NoError(t, gsData.outgoingBlockHookActions.TerminationError) + }, + }, + "incoming gs request with recognized dt request will send updates": { + action: func(gsData *harness) { + gsData.incomingRequestHook() + gsData.outgoingBlockHook() + }, + events: fakeEvents{ + OnDataQueuedMessage: testutil.NewDTResponse(t, datatransfer.TransferID(rand.Uint32())), + }, + check: func(t *testing.T, events *fakeEvents, gsData *harness) { + require.Equal(t, 1, events.OnRequestReceivedCallCount) + require.True(t, events.OnDataQueuedCalled) + require.NoError(t, gsData.outgoingBlockHookActions.TerminationError) + assertHasExtensionMessage(t, extension.ExtensionOutgoingBlock1_1, gsData.outgoingBlockHookActions.SentExtensions, + events.OnDataQueuedMessage) + }, + }, + "incoming gs request with recognized dt request can receive update": { + action: func(gsData *harness) { + gsData.incomingRequestHook() + gsData.requestUpdatedHook() + }, + check: func(t *testing.T, events *fakeEvents, gsData *harness) { + require.Equal(t, 2, events.OnRequestReceivedCallCount) + require.NoError(t, gsData.requestUpdatedHookActions.TerminationError) + }, + }, + "incoming gs request with recognized dt request cannot receive update with dt response": { + updatedConfig: gsRequestConfig{ + dtIsResponse: true, + }, + action: func(gsData *harness) { + gsData.incomingRequestHook() + gsData.requestUpdatedHook() + }, + check: func(t *testing.T, events *fakeEvents, gsData *harness) { + require.Equal(t, 1, events.OnRequestReceivedCallCount) + require.Equal(t, 0, events.OnResponseReceivedCallCount) + require.Error(t, gsData.requestUpdatedHookActions.TerminationError) + }, + }, + "incoming gs request with recognized dt response can receive update": { + requestConfig: gsRequestConfig{ + dtIsResponse: true, + }, + updatedConfig: gsRequestConfig{ + dtIsResponse: true, + }, + action: func(gsData *harness) { + gsData.incomingRequestHook() + gsData.requestUpdatedHook() + }, + check: func(t *testing.T, events *fakeEvents, gsData *harness) { + require.Equal(t, 2, events.OnResponseReceivedCallCount) + require.NoError(t, gsData.requestUpdatedHookActions.TerminationError) + }, + }, + "incoming gs request with recognized dt response cannot receive update with dt request": { + requestConfig: gsRequestConfig{ + dtIsResponse: true, + }, + action: func(gsData *harness) { + gsData.incomingRequestHook() + gsData.requestUpdatedHook() + }, + check: func(t *testing.T, events *fakeEvents, gsData *harness) { + require.Equal(t, 1, events.OnResponseReceivedCallCount) + require.Equal(t, 0, events.OnRequestReceivedCallCount) + require.Error(t, gsData.requestUpdatedHookActions.TerminationError) + }, + }, + "incoming gs request with recognized dt request will error with malformed update": { + updatedConfig: gsRequestConfig{ + dtExtensionMalformed: true, + }, + action: func(gsData *harness) { + gsData.incomingRequestHook() + gsData.requestUpdatedHook() + }, + check: func(t *testing.T, events *fakeEvents, gsData *harness) { + require.Equal(t, 1, 
events.OnRequestReceivedCallCount) + require.Error(t, gsData.requestUpdatedHookActions.TerminationError) + }, + }, + "incoming gs request with recognized dt request will ignore non-data-transfer update": { + updatedConfig: gsRequestConfig{ + dtExtensionMissing: true, + }, + action: func(gsData *harness) { + gsData.incomingRequestHook() + gsData.requestUpdatedHook() + }, + check: func(t *testing.T, events *fakeEvents, gsData *harness) { + require.Equal(t, 1, events.OnRequestReceivedCallCount) + require.NoError(t, gsData.requestUpdatedHookActions.TerminationError) + }, + }, + "incoming gs request with recognized dt request can send message on update": { + events: fakeEvents{ + RequestReceivedResponse: testutil.NewDTResponse(t, datatransfer.TransferID(rand.Uint32())), + }, + action: func(gsData *harness) { + gsData.incomingRequestHook() + gsData.requestUpdatedHook() + }, + check: func(t *testing.T, events *fakeEvents, gsData *harness) { + require.Equal(t, 2, events.OnRequestReceivedCallCount) + require.NoError(t, gsData.requestUpdatedHookActions.TerminationError) + assertHasOutgoingMessage(t, gsData.requestUpdatedHookActions.SentExtensions, + events.RequestReceivedResponse) + }, + }, + "recognized incoming request will record successful request completion": { + responseConfig: gsResponseConfig{ + status: graphsync.RequestCompletedFull, + }, + action: func(gsData *harness) { + gsData.incomingRequestHook() + gsData.responseCompletedListener() + }, + check: func(t *testing.T, events *fakeEvents, gsData *harness) { + require.Equal(t, 1, events.OnRequestReceivedCallCount) + require.True(t, events.OnChannelCompletedCalled) + require.True(t, events.ChannelCompletedSuccess) + }, + }, + + "recognized incoming request will record unsuccessful request completion": { + responseConfig: gsResponseConfig{ + status: graphsync.RequestCompletedPartial, + }, + action: func(gsData *harness) { + gsData.incomingRequestHook() + gsData.responseCompletedListener() + }, + check: func(t *testing.T, events *fakeEvents, gsData *harness) { + require.Equal(t, 1, events.OnRequestReceivedCallCount) + require.True(t, events.OnChannelCompletedCalled) + require.False(t, events.ChannelCompletedSuccess) + }, + }, + "recognized incoming request will not record request cancellation": { + responseConfig: gsResponseConfig{ + status: graphsync.RequestCancelled, + }, + action: func(gsData *harness) { + gsData.incomingRequestHook() + gsData.responseCompletedListener() + }, + check: func(t *testing.T, events *fakeEvents, gsData *harness) { + require.Equal(t, 1, events.OnRequestReceivedCallCount) + require.False(t, events.OnChannelCompletedCalled) + }, + }, + "non-data-transfer request will not record request completed": { + requestConfig: gsRequestConfig{ + dtExtensionMissing: true, + }, + responseConfig: gsResponseConfig{ + status: graphsync.RequestCompletedPartial, + }, + action: func(gsData *harness) { + gsData.incomingRequestHook() + gsData.responseCompletedListener() + }, + check: func(t *testing.T, events *fakeEvents, gsData *harness) { + require.Equal(t, 0, events.OnRequestReceivedCallCount) + require.False(t, events.OnChannelCompletedCalled) + }, + }, + "recognized incoming request can be closed": { + action: func(gsData *harness) { + gsData.incomingRequestHook() + }, + check: func(t *testing.T, events *fakeEvents, gsData *harness) { + err := gsData.transport.CloseChannel(gsData.ctx, datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.self, Initiator: gsData.other}) + require.NoError(t, err) + require.Equal(t, 1, 
events.OnRequestReceivedCallCount) + gsData.fgs.AssertCancelReceived(gsData.ctx, t) + }, + }, + "unrecognized request cannot be closed": { + check: func(t *testing.T, events *fakeEvents, gsData *harness) { + err := gsData.transport.CloseChannel(gsData.ctx, datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.self, Initiator: gsData.other}) + require.Error(t, err) + }, + }, + "recognized incoming request that requestor cancelled will not close via graphsync": { + action: func(gsData *harness) { + gsData.incomingRequestHook() + gsData.requestorCancelledListener() + }, + check: func(t *testing.T, events *fakeEvents, gsData *harness) { + err := gsData.transport.CloseChannel(gsData.ctx, datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.self, Initiator: gsData.other}) + require.NoError(t, err) + require.Equal(t, 1, events.OnRequestReceivedCallCount) + gsData.fgs.AssertNoCancelReceived(t) + }, + }, + "recognized incoming request can be paused": { + action: func(gsData *harness) { + gsData.incomingRequestHook() + }, + check: func(t *testing.T, events *fakeEvents, gsData *harness) { + err := gsData.transport.PauseChannel(gsData.ctx, datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.self, Initiator: gsData.other}) + require.NoError(t, err) + require.Equal(t, 1, events.OnRequestReceivedCallCount) + gsData.fgs.AssertPauseReceived(gsData.ctx, t) + }, + }, + "unrecognized request cannot be paused": { + check: func(t *testing.T, events *fakeEvents, gsData *harness) { + err := gsData.transport.PauseChannel(gsData.ctx, datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.self, Initiator: gsData.other}) + require.Error(t, err) + }, + }, + "recognized incoming request that requestor cancelled will not pause via graphsync": { + action: func(gsData *harness) { + gsData.incomingRequestHook() + gsData.requestorCancelledListener() + }, + check: func(t *testing.T, events *fakeEvents, gsData *harness) { + err := gsData.transport.PauseChannel(gsData.ctx, datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.self, Initiator: gsData.other}) + require.NoError(t, err) + require.Equal(t, 1, events.OnRequestReceivedCallCount) + gsData.fgs.AssertNoPauseReceived(t) + }, + }, + + "incoming request can be queued": { + action: func(gsData *harness) { + gsData.incomingRequestQueuedHook() + }, + check: func(t *testing.T, events *fakeEvents, gsData *harness) { + require.True(t, events.TransferQueuedCalled) + require.Equal(t, datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.self, Initiator: gsData.other}, + events.TransferQueuedChannelID) + }, + }, + + "incoming request with dtResponse can be queued": { + requestConfig: gsRequestConfig{ + dtIsResponse: true, + }, + responseConfig: gsResponseConfig{ + dtIsResponse: true, + }, + action: func(gsData *harness) { + gsData.incomingRequestQueuedHook() + }, + check: func(t *testing.T, events *fakeEvents, gsData *harness) { + require.True(t, events.TransferQueuedCalled) + require.Equal(t, datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.other, Initiator: gsData.self}, + events.TransferQueuedChannelID) + }, + }, + + "recognized incoming request can be resumed": { + action: func(gsData *harness) { + gsData.incomingRequestHook() + }, + check: func(t *testing.T, events *fakeEvents, gsData *harness) { + err := gsData.transport.ResumeChannel(gsData.ctx, + gsData.incoming, + datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.self, Initiator: gsData.other}, + ) + require.NoError(t, err) + 
require.Equal(t, 1, events.OnRequestReceivedCallCount) + gsData.fgs.AssertResumeReceived(gsData.ctx, t) + }, + }, + + "unrecognized request cannot be resumed": { + check: func(t *testing.T, events *fakeEvents, gsData *harness) { + err := gsData.transport.ResumeChannel(gsData.ctx, + gsData.incoming, + datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.self, Initiator: gsData.other}, + ) + require.Error(t, err) + }, + }, + "recognized incoming request that requestor cancelled will not resume via graphsync but will resume otherwise": { + action: func(gsData *harness) { + gsData.incomingRequestHook() + gsData.requestorCancelledListener() + }, + check: func(t *testing.T, events *fakeEvents, gsData *harness) { + err := gsData.transport.ResumeChannel(gsData.ctx, + gsData.incoming, + datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.self, Initiator: gsData.other}, + ) + require.NoError(t, err) + require.Equal(t, 1, events.OnRequestReceivedCallCount) + gsData.fgs.AssertNoResumeReceived(t) + gsData.incomingRequestHook() + assertHasOutgoingMessage(t, gsData.incomingRequestHookActions.SentExtensions, gsData.incoming) + }, + }, + "recognized incoming request will record network send error": { + action: func(gsData *harness) { + gsData.incomingRequestHook() + gsData.networkErrorListener(errors.New("something went wrong")) + }, + check: func(t *testing.T, events *fakeEvents, gsData *harness) { + require.Equal(t, 1, events.OnRequestReceivedCallCount) + require.True(t, events.OnSendDataErrorCalled) + }, + }, + "recognized outgoing request will record network send error": { + action: func(gsData *harness) { + gsData.outgoingRequestHook() + gsData.networkErrorListener(errors.New("something went wrong")) + }, + check: func(t *testing.T, events *fakeEvents, gsData *harness) { + require.True(t, events.OnSendDataErrorCalled) + }, + }, + "recognized incoming request will record network receive error": { + action: func(gsData *harness) { + gsData.incomingRequestHook() + gsData.receiverNetworkErrorListener(errors.New("something went wrong")) + }, + check: func(t *testing.T, events *fakeEvents, gsData *harness) { + require.Equal(t, 1, events.OnRequestReceivedCallCount) + require.True(t, events.OnReceiveDataErrorCalled) + }, + }, + "recognized outgoing request will record network receive error": { + action: func(gsData *harness) { + gsData.outgoingRequestHook() + gsData.receiverNetworkErrorListener(errors.New("something went wrong")) + }, + check: func(t *testing.T, events *fakeEvents, gsData *harness) { + require.True(t, events.OnReceiveDataErrorCalled) + }, + }, + "open channel adds block count to the DoNotSendFirstBlocks extension for v1.2 protocol": { + action: func(gsData *harness) { + cids := testutil.GenerateCids(2) + channel := testutil.NewMockChannelState(testutil.MockChannelStateParams{ReceivedCids: cids}) + stor, _ := gsData.outgoing.Selector() + + go gsData.outgoingRequestHook() + _ = gsData.transport.OpenChannel( + gsData.ctx, + gsData.other, + datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.other, Initiator: gsData.self}, + cidlink.Link{Cid: gsData.outgoing.BaseCid()}, + stor, + channel, + gsData.outgoing) + }, + check: func(t *testing.T, events *fakeEvents, gsData *harness) { + requestReceived := gsData.fgs.AssertRequestReceived(gsData.ctx, t) + + ext := requestReceived.Extensions + require.Len(t, ext, 2) + doNotSend := ext[1] + + name := doNotSend.Name + require.Equal(t, graphsync.ExtensionsDoNotSendFirstBlocks, name) + data := doNotSend.Data + blockCount, 
err := donotsendfirstblocks.DecodeDoNotSendFirstBlocks(data) + require.NoError(t, err) + require.EqualValues(t, blockCount, 2) + }, + }, + "ChannelsForPeer when request is open": { + action: func(gsData *harness) { + cids := testutil.GenerateCids(2) + channel := testutil.NewMockChannelState(testutil.MockChannelStateParams{ReceivedCids: cids}) + stor, _ := gsData.outgoing.Selector() + + go gsData.outgoingRequestHook() + _ = gsData.transport.OpenChannel( + gsData.ctx, + gsData.other, + datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.other, Initiator: gsData.self}, + cidlink.Link{Cid: gsData.outgoing.BaseCid()}, + stor, + channel, + gsData.outgoing) + }, + check: func(t *testing.T, events *fakeEvents, gsData *harness) { + gsData.fgs.AssertRequestReceived(gsData.ctx, t) + + channelsForPeer := gsData.transport.ChannelsForPeer(gsData.other) + require.Equal(t, channelsForPeer, ChannelsForPeer{ + ReceivingChannels: map[datatransfer.ChannelID]ChannelGraphsyncRequests{ + events.ChannelOpenedChannelID: { + Current: gsData.request.ID(), + }, + }, + SendingChannels: map[datatransfer.ChannelID]ChannelGraphsyncRequests{}, + }) + }, + }, + "open channel cancels an existing request with the same channel ID": { + action: func(gsData *harness) { + cids := testutil.GenerateCids(2) + channel := testutil.NewMockChannelState(testutil.MockChannelStateParams{ReceivedCids: cids}) + stor, _ := gsData.outgoing.Selector() + go gsData.outgoingRequestHook() + _ = gsData.transport.OpenChannel( + gsData.ctx, + gsData.other, + datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.other, Initiator: gsData.self}, + cidlink.Link{Cid: gsData.outgoing.BaseCid()}, + stor, + channel, + gsData.outgoing) + + go gsData.altOutgoingRequestHook() + _ = gsData.transport.OpenChannel( + gsData.ctx, + gsData.other, + datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.other, Initiator: gsData.self}, + cidlink.Link{Cid: gsData.outgoing.BaseCid()}, + stor, + channel, + gsData.outgoing) + }, + check: func(t *testing.T, events *fakeEvents, gsData *harness) { + gsData.fgs.AssertRequestReceived(gsData.ctx, t) + gsData.fgs.AssertRequestReceived(gsData.ctx, t) + + ctxt, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + gsData.fgs.AssertCancelReceived(ctxt, t) + + channelsForPeer := gsData.transport.ChannelsForPeer(gsData.other) + require.Equal(t, channelsForPeer, ChannelsForPeer{ + ReceivingChannels: map[datatransfer.ChannelID]ChannelGraphsyncRequests{ + events.ChannelOpenedChannelID: { + Current: gsData.altRequest.ID(), + Previous: []graphsync.RequestID{gsData.request.ID()}, + }, + }, + SendingChannels: map[datatransfer.ChannelID]ChannelGraphsyncRequests{}, + }) + }, + }, + "OnChannelCompleted called when outgoing request completes successfully": { + action: func(gsData *harness) { + gsData.fgs.LeaveRequestsOpen() + stor, _ := gsData.outgoing.Selector() + + go gsData.outgoingRequestHook() + _ = gsData.transport.OpenChannel( + gsData.ctx, + gsData.other, + datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.other, Initiator: gsData.self}, + cidlink.Link{Cid: gsData.outgoing.BaseCid()}, + stor, + nil, + gsData.outgoing) + }, + check: func(t *testing.T, events *fakeEvents, gsData *harness) { + requestReceived := gsData.fgs.AssertRequestReceived(gsData.ctx, t) + close(requestReceived.ResponseChan) + close(requestReceived.ResponseErrChan) + + require.Eventually(t, func() bool { + return events.OnChannelCompletedCalled == true + }, 2*time.Second, 100*time.Millisecond) 
+ require.True(t, events.ChannelCompletedSuccess) + }, + }, + "OnChannelCompleted called when outgoing request completes with error": { + action: func(gsData *harness) { + gsData.fgs.LeaveRequestsOpen() + stor, _ := gsData.outgoing.Selector() + + go gsData.outgoingRequestHook() + _ = gsData.transport.OpenChannel( + gsData.ctx, + gsData.other, + datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.other, Initiator: gsData.self}, + cidlink.Link{Cid: gsData.outgoing.BaseCid()}, + stor, + nil, + gsData.outgoing) + }, + check: func(t *testing.T, events *fakeEvents, gsData *harness) { + requestReceived := gsData.fgs.AssertRequestReceived(gsData.ctx, t) + close(requestReceived.ResponseChan) + requestReceived.ResponseErrChan <- graphsync.RequestFailedUnknownErr{} + close(requestReceived.ResponseErrChan) + + require.Eventually(t, func() bool { + return events.OnChannelCompletedCalled == true + }, 2*time.Second, 100*time.Millisecond) + require.False(t, events.ChannelCompletedSuccess) + }, + }, + "OnChannelComplete when outgoing request cancelled by caller": { + action: func(gsData *harness) { + gsData.fgs.LeaveRequestsOpen() + stor, _ := gsData.outgoing.Selector() + + go gsData.outgoingRequestHook() + _ = gsData.transport.OpenChannel( + gsData.ctx, + gsData.other, + datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.other, Initiator: gsData.self}, + cidlink.Link{Cid: gsData.outgoing.BaseCid()}, + stor, + nil, + gsData.outgoing) + }, + check: func(t *testing.T, events *fakeEvents, gsData *harness) { + requestReceived := gsData.fgs.AssertRequestReceived(gsData.ctx, t) + extensions := make(map[graphsync.ExtensionName]datamodel.Node) + for _, ext := range requestReceived.Extensions { + extensions[ext.Name] = ext.Data + } + request := testutil.NewFakeRequest(graphsync.NewRequestID(), extensions) + gsData.fgs.OutgoingRequestHook(gsData.other, request, gsData.outgoingRequestHookActions) + _ = gsData.transport.CloseChannel(gsData.ctx, datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.other, Initiator: gsData.self}) + ctxt, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + gsData.fgs.AssertCancelReceived(ctxt, t) + }, + }, + "request times out if we get request context cancelled error": { + action: func(gsData *harness) { + gsData.fgs.LeaveRequestsOpen() + stor, _ := gsData.outgoing.Selector() + + go gsData.outgoingRequestHook() + _ = gsData.transport.OpenChannel( + gsData.ctx, + gsData.other, + datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.other, Initiator: gsData.self}, + cidlink.Link{Cid: gsData.outgoing.BaseCid()}, + stor, + nil, + gsData.outgoing) + }, + check: func(t *testing.T, events *fakeEvents, gsData *harness) { + requestReceived := gsData.fgs.AssertRequestReceived(gsData.ctx, t) + close(requestReceived.ResponseChan) + requestReceived.ResponseErrChan <- graphsync.RequestClientCancelledErr{} + close(requestReceived.ResponseErrChan) + + require.Eventually(t, func() bool { + return events.OnRequestCancelledCalled == true + }, 2*time.Second, 100*time.Millisecond) + require.Equal(t, datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.other, Initiator: gsData.self}, events.OnRequestCancelledChannelId) + }, + }, + "request cancelled out if transport shuts down": { + action: func(gsData *harness) { + gsData.fgs.LeaveRequestsOpen() + stor, _ := gsData.outgoing.Selector() + + go gsData.outgoingRequestHook() + _ = gsData.transport.OpenChannel( + gsData.ctx, + gsData.other, + datatransfer.ChannelID{ID: 
gsData.transferID, Responder: gsData.other, Initiator: gsData.self}, + cidlink.Link{Cid: gsData.outgoing.BaseCid()}, + stor, + nil, + gsData.outgoing) + }, + check: func(t *testing.T, events *fakeEvents, gsData *harness) { + gsData.fgs.AssertRequestReceived(gsData.ctx, t) + + gsData.transport.Shutdown(gsData.ctx) + + ctxt, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + gsData.fgs.AssertCancelReceived(ctxt, t) + + require.Nil(t, gsData.fgs.IncomingRequestHook) + require.Nil(t, gsData.fgs.CompletedResponseListener) + require.Nil(t, gsData.fgs.IncomingBlockHook) + require.Nil(t, gsData.fgs.OutgoingBlockHook) + require.Nil(t, gsData.fgs.BlockSentListener) + require.Nil(t, gsData.fgs.OutgoingRequestHook) + require.Nil(t, gsData.fgs.IncomingResponseHook) + require.Nil(t, gsData.fgs.RequestUpdatedHook) + require.Nil(t, gsData.fgs.RequestorCancelledListener) + require.Nil(t, gsData.fgs.NetworkErrorListener) + }, + }, + "request pause works even if called when request is still pending": { + action: func(gsData *harness) { + gsData.fgs.LeaveRequestsOpen() + stor, _ := gsData.outgoing.Selector() + + go gsData.outgoingRequestHook() + _ = gsData.transport.OpenChannel( + gsData.ctx, + gsData.other, + datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.other, Initiator: gsData.self}, + cidlink.Link{Cid: gsData.outgoing.BaseCid()}, + stor, + nil, + gsData.outgoing) + + }, + check: func(t *testing.T, events *fakeEvents, gsData *harness) { + requestReceived := gsData.fgs.AssertRequestReceived(gsData.ctx, t) + assertHasOutgoingMessage(t, requestReceived.Extensions, gsData.outgoing) + completed := make(chan struct{}) + go func() { + err := gsData.transport.PauseChannel(context.Background(), datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.other, Initiator: gsData.self}) + require.NoError(t, err) + close(completed) + }() + time.Sleep(100 * time.Millisecond) + extensions := make(map[graphsync.ExtensionName]datamodel.Node) + for _, ext := range requestReceived.Extensions { + extensions[ext.Name] = ext.Data + } + request := testutil.NewFakeRequest(graphsync.NewRequestID(), extensions) + gsData.fgs.OutgoingRequestHook(gsData.other, request, gsData.outgoingRequestHookActions) + select { + case <-gsData.ctx.Done(): + t.Fatal("never paused channel") + case <-completed: + } + }, + }, + "UseStore can change store used for outgoing requests": { + action: func(gsData *harness) { + lsys := cidlink.DefaultLinkSystem() + lsys.StorageReadOpener = func(ipld.LinkContext, ipld.Link) (io.Reader, error) { + return nil, nil + } + lsys.StorageWriteOpener = func(ipld.LinkContext) (io.Writer, ipld.BlockWriteCommitter, error) { + return nil, nil, nil + } + _ = gsData.transport.UseStore(datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.other, Initiator: gsData.self}, lsys) + gsData.outgoingRequestHook() + }, + check: func(t *testing.T, events *fakeEvents, gsData *harness) { + expectedChannel := "data-transfer-" + datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.other, Initiator: gsData.self}.String() + gsData.fgs.AssertHasPersistenceOption(t, expectedChannel) + require.Equal(t, expectedChannel, gsData.outgoingRequestHookActions.PersistenceOption) + gsData.transport.CleanupChannel(datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.other, Initiator: gsData.self}) + gsData.fgs.AssertDoesNotHavePersistenceOption(t, expectedChannel) + }, + }, + "UseStore can change store used for incoming requests": { + action: func(gsData *harness) { 
+ lsys := cidlink.DefaultLinkSystem() + lsys.StorageReadOpener = func(ipld.LinkContext, ipld.Link) (io.Reader, error) { + return nil, nil + } + lsys.StorageWriteOpener = func(ipld.LinkContext) (io.Writer, ipld.BlockWriteCommitter, error) { + return nil, nil, nil + } + _ = gsData.transport.UseStore(datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.self, Initiator: gsData.other}, lsys) + gsData.incomingRequestHook() + }, + check: func(t *testing.T, events *fakeEvents, gsData *harness) { + expectedChannel := "data-transfer-" + datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.self, Initiator: gsData.other}.String() + gsData.fgs.AssertHasPersistenceOption(t, expectedChannel) + require.Equal(t, expectedChannel, gsData.incomingRequestHookActions.PersistenceOption) + gsData.transport.CleanupChannel(datatransfer.ChannelID{ID: gsData.transferID, Responder: gsData.self, Initiator: gsData.other}) + gsData.fgs.AssertDoesNotHavePersistenceOption(t, expectedChannel) + }, + },*/ diff --git a/transport/graphsync/receiver.go b/transport/graphsync/receiver.go index ba9b1639..2391fe1f 100644 --- a/transport/graphsync/receiver.go +++ b/transport/graphsync/receiver.go @@ -3,7 +3,6 @@ package graphsync import ( "context" - "github.com/ipfs/go-graphsync" cidlink "github.com/ipld/go-ipld-prime/linking/cid" "github.com/libp2p/go-libp2p-core/peer" "go.opentelemetry.io/otel" @@ -11,7 +10,6 @@ import ( "go.opentelemetry.io/otel/trace" datatransfer "github.com/filecoin-project/go-data-transfer/v2" - "github.com/filecoin-project/go-data-transfer/v2/transport/graphsync/extension" ) type receiver struct { @@ -32,7 +30,10 @@ func (r *receiver) ReceiveRequest( func (r *receiver) receiveRequest(ctx context.Context, initiator peer.ID, incoming datatransfer.Request) error { chid := datatransfer.ChannelID{Initiator: initiator, Responder: r.transport.peerID, ID: incoming.TransferID()} - ctx = r.transport.events.OnContextAugment(chid)(ctx) + ctxAugment := r.transport.events.OnContextAugment(chid) + if ctxAugment != nil { + ctx = ctxAugment(ctx) + } ctx, span := otel.Tracer("gs-data-transfer").Start(ctx, "receiveRequest", trace.WithAttributes( attribute.String("channelID", chid.String()), attribute.String("baseCid", incoming.BaseCid().String()), @@ -43,63 +44,60 @@ func (r *receiver) receiveRequest(ctx context.Context, initiator peer.ID, incomi attribute.Bool("isPaused", incoming.IsPaused()), )) defer span.End() + isNewOrRestart := incoming.IsNew() || incoming.IsRestart() + // a graphsync pull request MUST come in via graphsync + if isNewOrRestart && incoming.IsPull() { + return datatransfer.ErrUnsupported + } response, receiveErr := r.transport.events.OnRequestReceived(chid, incoming) + initiateGraphsyncRequest := isNewOrRestart && response != nil && receiveErr == nil ch, err := r.transport.getDTChannel(chid) - initiateGraphsyncRequest := (response != nil) && (response.IsNew() || response.IsRestart()) && response.Accepted() && !incoming.IsPull() if err != nil { - if !initiateGraphsyncRequest { + if !initiateGraphsyncRequest || receiveErr != nil { if response != nil { - return r.transport.dtNet.SendMessage(ctx, initiator, transportID, response) + if sendErr := r.transport.dtNet.SendMessage(ctx, initiator, transportID, response); sendErr != nil { + return sendErr + } + return receiveErr } return receiveErr } ch = r.transport.trackDTChannel(chid) } - if receiveErr == datatransfer.ErrResume && ch.Paused() { - - var extensions []graphsync.ExtensionData + if receiveErr != nil { if response != nil { - var 
err error - extensions, err = extension.ToExtensionData(response, r.transport.supportedExtensions) - if err != nil { + if err := r.transport.dtNet.SendMessage(ctx, initiator, transportID, response); err != nil { return err } + _ = ch.Close(ctx) + return receiveErr } - - return ch.Resume(ctx, extensions) } - if response != nil { - if initiateGraphsyncRequest { - stor, _ := incoming.Selector() - if response.IsRestart() { - channel, err := r.transport.events.ChannelState(ctx, chid) - if err != nil { - return err - } - ch.UpdateReceivedCidsIfGreater(channel.ReceivedCidsTotal()) - } - if err := r.transport.openRequest(ctx, initiator, chid, cidlink.Link{Cid: incoming.BaseCid()}, stor, response); err != nil { - return err - } - } else { - if err := r.transport.dtNet.SendMessage(ctx, initiator, transportID, response); err != nil { - return err - } - } + if isNewOrRestart { + r.transport.dtNet.Protect(initiator, chid.String()) + } + chst, err := r.transport.events.ChannelState(ctx, chid) + if err != nil { + return err } - if receiveErr == datatransfer.ErrPause { - return ch.Pause(ctx) + err = ch.UpdateFromChannelState(chst) + if err != nil { + return err } - if receiveErr != nil { - _ = ch.Close(ctx) - return receiveErr + if initiateGraphsyncRequest { + stor, _ := incoming.Selector() + if err := r.transport.openRequest(ctx, initiator, chid, cidlink.Link{Cid: incoming.BaseCid()}, stor, response); err != nil { + return err + } + response = nil } - return nil + action := ch.ActionFromChannelState(chst) + return r.transport.processAction(ctx, chid, ch, action, response) } // ReceiveResponse handles responses to our Push or Pull data transfer request. @@ -135,9 +133,6 @@ func (r *receiver) receiveResponse( if err != nil { return err } - if receiveErr == datatransfer.ErrPause { - return ch.Pause(ctx) - } if receiveErr != nil { log.Warnf("closing channel %s after getting error processing response from %s: %s", chid, sender, err) @@ -180,9 +175,6 @@ func (r *receiver) ReceiveRestartExistingChannelRequest(ctx context.Context, return } - err = r.transport.events.OnRestartExistingChannelRequestReceived(ch) - if err != nil { - log.Errorf(err.Error()) - } + r.transport.events.OnTransportEvent(ch, datatransfer.TransportReceivedRestartExistingChannelRequest{}) return } diff --git a/transport/graphsync/responding_test.go b/transport/graphsync/responding_test.go new file mode 100644 index 00000000..93343574 --- /dev/null +++ b/transport/graphsync/responding_test.go @@ -0,0 +1,413 @@ +package graphsync_test + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/ipfs/go-graphsync" + "github.com/ipld/go-ipld-prime/datamodel" + cidlink "github.com/ipld/go-ipld-prime/linking/cid" + basicnode "github.com/ipld/go-ipld-prime/node/basic" + "github.com/stretchr/testify/require" + + datatransfer "github.com/filecoin-project/go-data-transfer/v2" + "github.com/filecoin-project/go-data-transfer/v2/transport/graphsync/extension" + "github.com/filecoin-project/go-data-transfer/v2/transport/graphsync/testharness" +) + +func TestRespondingPullSuccessFlow(t *testing.T) { + ctx := context.Background() + ctx, cancel := context.WithTimeout(ctx, 5*time.Second) + defer cancel() + th := testharness.SetupHarness(ctx, testharness.PullRequest(), testharness.Responder()) + + // this actually happens in the request received event handler itself in a real life case, but here we just run it before + t.Run("configures persistence", func(t *testing.T) { + th.Transport.UseStore(th.Channel.ChannelID(), cidlink.DefaultLinkSystem()) + 
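// UseStore should register a persistence option with graphsync, scoped to this channel's ID + 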
th.Fgs.AssertHasPersistenceOption(t, fmt.Sprintf("data-transfer-%s", th.Channel.ChannelID().String()))
+ })
+ requestID := graphsync.NewRequestID()
+ dtRequest := th.NewRequest(t)
+ request := testharness.NewFakeRequest(requestID, map[graphsync.ExtensionName]datamodel.Node{extension.ExtensionDataTransfer1_1: dtRequest.ToIPLD()}, graphsync.RequestTypeNew)
+
+ // this is the actual start of request processing
+ t.Run("received and responds successfully", func(t *testing.T) {
+ dtResponse := th.Response()
+ th.Events.ReturnedRequestReceivedResponse = dtResponse
+ th.Channel.SetResponderPaused(true)
+ th.Channel.SetDataLimit(10000)
+ th.Events.ReturnedOnContextAugmentFunc = func(ctx context.Context) context.Context {
+ return context.WithValue(ctx, ctxKey{}, "applesauce")
+ }
+ th.IncomingRequestHook(request)
+ require.Equal(t, dtRequest, th.Events.ReceivedRequest)
+ require.Len(t, th.DtNet.ProtectedPeers, 1)
+ require.Equal(t, th.DtNet.ProtectedPeers[0], testharness.TaggedPeer{th.Channel.OtherPeer(), th.Channel.ChannelID().String()})
+ require.Equal(t, fmt.Sprintf("data-transfer-%s", th.Channel.ChannelID().String()), th.IncomingRequestHookActions.PersistenceOption)
+ require.True(t, th.IncomingRequestHookActions.Validated)
+ require.True(t, th.IncomingRequestHookActions.Paused)
+ require.NoError(t, th.IncomingRequestHookActions.TerminationError)
+ sentResponse := th.IncomingRequestHookActions.DTMessage(t)
+ require.Equal(t, dtResponse, sentResponse)
+ th.IncomingRequestHookActions.AssertAugmentedContextKey(t, ctxKey{}, "applesauce")
+ })
+
+ t.Run("receives incoming processing listener", func(t *testing.T) {
+ th.IncomingRequestProcessingListener(request)
+ th.Events.AssertTransportEvent(t, th.Channel.ChannelID(), datatransfer.TransportInitiatedTransfer{})
+ })
+
+ t.Run("unpause request", func(t *testing.T) {
+ th.Channel.SetResponderPaused(false)
+ dtValidationResponse := th.ValidationResultResponse(false)
+ th.Transport.ChannelUpdated(ctx, th.Channel.ChannelID(), dtValidationResponse)
+ require.Len(t, th.Fgs.Resumes, 1)
+ require.Equal(t, dtValidationResponse, th.Fgs.Resumes[0].DTMessage(t))
+ })
+
+ t.Run("queued block / data limits", func(t *testing.T) {
+ // consume first block
+ block := testharness.NewFakeBlockData(8000, 1, true)
+ th.OutgoingBlockHook(request, block)
+ th.Events.AssertTransportEvent(t, th.Channel.ChannelID(), datatransfer.TransportQueuedData{Size: block.BlockSize(), Index: basicnode.NewInt(block.Index())})
+
+ // consume second block -- should hit data limit
+ block = testharness.NewFakeBlockData(3000, 2, true)
+ th.OutgoingBlockHook(request, block)
+ require.True(t, th.OutgoingBlockHookActions.Paused)
+ th.Events.AssertTransportEvent(t, th.Channel.ChannelID(), datatransfer.TransportQueuedData{Size: block.BlockSize(), Index: basicnode.NewInt(block.Index())})
+ th.Events.AssertTransportEvent(t, th.Channel.ChannelID(), datatransfer.TransportReachedDataLimit{})
+
+ // reset data limit
+ th.Channel.SetResponderPaused(false)
+ th.Channel.SetDataLimit(20000)
+ dtValidationResponse := th.ValidationResultResponse(false)
+ th.Transport.ChannelUpdated(ctx, th.Channel.ChannelID(), dtValidationResponse)
+ require.Len(t, th.Fgs.Resumes, 2)
+ require.Equal(t, dtValidationResponse, th.Fgs.Resumes[1].DTMessage(t))
+
+ // block not on wire has no effect
+ block = testharness.NewFakeBlockData(12345, 3, false)
+ th.OutgoingBlockHook(request, block)
+ th.Events.RefuteTransportEvent(t, th.Channel.ChannelID(), datatransfer.TransportQueuedData{Size: block.BlockSize(), Index: basicnode.NewInt(block.Index())})
+ // block with lower index has no effect
+ block = testharness.NewFakeBlockData(67890, 1, true)
+ th.OutgoingBlockHook(request, block)
+ th.Events.RefuteTransportEvent(t, th.Channel.ChannelID(), datatransfer.TransportQueuedData{Size: block.BlockSize(), Index: basicnode.NewInt(block.Index())})
+
+ // consume third block
+ block = testharness.NewFakeBlockData(5000, 4, true)
+ th.OutgoingBlockHook(request, block)
+ th.Events.AssertTransportEvent(t, th.Channel.ChannelID(), datatransfer.TransportQueuedData{Size: block.BlockSize(), Index: basicnode.NewInt(block.Index())})
+
+ // consuming the fourth block should hit the data limit again
+ block = testharness.NewFakeBlockData(5000, 5, true)
+ th.OutgoingBlockHook(request, block)
+ require.True(t, th.OutgoingBlockHookActions.Paused)
+ th.Events.AssertTransportEvent(t, th.Channel.ChannelID(), datatransfer.TransportQueuedData{Size: block.BlockSize(), Index: basicnode.NewInt(block.Index())})
+ th.Events.AssertTransportEvent(t, th.Channel.ChannelID(), datatransfer.TransportReachedDataLimit{})
+
+ })
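+
+ // Note on the assertions above: queued-data accounting is driven entirely by
+ // the outgoing block hook. Each on-wire block produces a TransportQueuedData
+ // event carrying its size and its index (expressed as an IPLD int node);
+ // blocks not on the wire, or replaying an already-seen index, are ignored,
+ // and crossing the configured data limit pauses the graphsync response and
+ // emits TransportReachedDataLimit.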
+
+ t.Run("sent block", func(t *testing.T) {
+ block := testharness.NewFakeBlockData(12345, 1, true)
+ th.BlockSentListener(request, block)
+ th.Events.AssertTransportEvent(t, th.Channel.ChannelID(), datatransfer.TransportSentData{Size: block.BlockSize(), Index: basicnode.NewInt(block.Index())})
+ block = testharness.NewFakeBlockData(12345, 2, true)
+ th.BlockSentListener(request, block)
+ th.Events.AssertTransportEvent(t, th.Channel.ChannelID(), datatransfer.TransportSentData{Size: block.BlockSize(), Index: basicnode.NewInt(block.Index())})
+ // block not on wire has no effect
+ block = testharness.NewFakeBlockData(12345, 3, false)
+ th.BlockSentListener(request, block)
+ th.Events.RefuteTransportEvent(t, th.Channel.ChannelID(), datatransfer.TransportSentData{Size: block.BlockSize(), Index: basicnode.NewInt(block.Index())})
+ // block with lower index has no effect
+ block = testharness.NewFakeBlockData(67890, 1, true)
+ th.BlockSentListener(request, block)
+ th.Events.RefuteTransportEvent(t, th.Channel.ChannelID(), datatransfer.TransportSentData{Size: block.BlockSize(), Index: basicnode.NewInt(block.Index())})
+ })
+
+ t.Run("receive pause", func(t *testing.T) {
+ th.RequestorCancelledListener(request)
+ dtPauseRequest := th.UpdateRequest(true)
+ th.Events.ReturnedRequestReceivedResponse = nil
+ th.DtNet.Delegates[0].Receiver.ReceiveRequest(ctx, th.Channel.OtherPeer(), dtPauseRequest)
+ require.Equal(t, th.Events.ReceivedRequest, dtPauseRequest)
+ })
+
+ t.Run("receive resume", func(t *testing.T) {
+ dtResumeRequest := th.UpdateRequest(false)
+ request = testharness.NewFakeRequest(requestID, map[graphsync.ExtensionName]datamodel.Node{extension.ExtensionDataTransfer1_1: dtResumeRequest.ToIPLD()}, graphsync.RequestTypeNew)
+ // reset hook behavior
+ th.IncomingRequestHookActions = &testharness.FakeIncomingRequestHookActions{}
+ th.IncomingRequestHook(request)
+ // only protect on new and restart requests
+ require.Len(t, th.DtNet.ProtectedPeers, 1)
+ require.Equal(t, th.DtNet.ProtectedPeers[0], testharness.TaggedPeer{th.Channel.OtherPeer(), th.Channel.ChannelID().String()})
+ require.Equal(t, fmt.Sprintf("data-transfer-%s", th.Channel.ChannelID().String()), th.IncomingRequestHookActions.PersistenceOption)
+ require.True(t, th.IncomingRequestHookActions.Validated)
+ require.False(t, th.IncomingRequestHookActions.Paused)
+ require.NoError(t, th.IncomingRequestHookActions.TerminationError)
+ th.IncomingRequestHookActions.AssertAugmentedContextKey(t, ctxKey{}, "applesauce")
+ require.Equal(t, th.Events.ReceivedRequest, dtResumeRequest)
+ })
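+
+ // The pause/resume subtests below go through Transport.ChannelUpdated: the
+ // data transfer layer hands the transport an update message reflecting the
+ // (mocked) channel pause state, and the transport is expected to always
+ // forward the message but to pause or unpause the underlying graphsync
+ // request only on an actual state change.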
+
+ t.Run("pause", func(t *testing.T) {
+ th.Channel.SetResponderPaused(true)
+ th.Transport.ChannelUpdated(ctx, th.Channel.ChannelID(), th.UpdateResponse(true))
+ th.DtNet.AssertSentMessage(t, testharness.FakeSentMessage{PeerID: th.Channel.OtherPeer(), TransportID: "graphsync", Message: th.UpdateResponse(true)})
+ require.Len(t, th.Fgs.Pauses, 1)
+ require.Equal(t, th.Fgs.Pauses[0], request.ID())
+ })
+
+ t.Run("pause again", func(t *testing.T) {
+ th.Transport.ChannelUpdated(ctx, th.Channel.ChannelID(), th.UpdateResponse(true))
+ // should send message again
+ th.DtNet.AssertSentMessage(t, testharness.FakeSentMessage{PeerID: th.Channel.OtherPeer(), TransportID: "graphsync", Message: th.UpdateResponse(true)})
+ // should not pause again
+ require.Len(t, th.Fgs.Pauses, 1)
+ })
+
+ t.Run("resume", func(t *testing.T) {
+ th.Channel.SetResponderPaused(false)
+ th.Transport.ChannelUpdated(ctx, th.Channel.ChannelID(), th.UpdateResponse(false))
+ require.Len(t, th.Fgs.Resumes, 3)
+ resume := th.Fgs.Resumes[2]
+ require.Equal(t, request.ID(), resume.RequestID)
+ msg := resume.DTMessage(t)
+ require.Equal(t, msg, th.UpdateResponse(false))
+ })
+ t.Run("resume again", func(t *testing.T) {
+ th.Transport.ChannelUpdated(ctx, th.Channel.ChannelID(), th.UpdateResponse(false))
+ // should send message again
+ th.DtNet.AssertSentMessage(t, testharness.FakeSentMessage{PeerID: th.Channel.OtherPeer(), TransportID: "graphsync", Message: th.UpdateResponse(false)})
+ // should not resume again
+ require.Len(t, th.Fgs.Resumes, 3)
+ })
+
+ t.Run("restart request", func(t *testing.T) {
+ dtRestartRequest := th.RestartRequest(t)
+ request := testharness.NewFakeRequest(requestID, map[graphsync.ExtensionName]datamodel.Node{extension.ExtensionDataTransfer1_1: dtRestartRequest.ToIPLD()}, graphsync.RequestTypeNew)
+ th.IncomingRequestHook(request)
+ // protect again for a restart
+ require.Len(t, th.DtNet.ProtectedPeers, 2)
+ require.Equal(t, th.DtNet.ProtectedPeers[1], testharness.TaggedPeer{th.Channel.OtherPeer(), th.Channel.ChannelID().String()})
+ require.Equal(t, fmt.Sprintf("data-transfer-%s", th.Channel.ChannelID().String()), th.IncomingRequestHookActions.PersistenceOption)
+ require.True(t, th.IncomingRequestHookActions.Validated)
+ require.False(t, th.IncomingRequestHookActions.Paused)
+ require.NoError(t, th.IncomingRequestHookActions.TerminationError)
+ th.IncomingRequestHookActions.AssertAugmentedContextKey(t, ctxKey{}, "applesauce")
+ })
+
+ t.Run("complete request", func(t *testing.T) {
+ th.ResponseCompletedListener(request, graphsync.RequestCompletedFull)
+ select {
+ case <-th.CompletedResponses:
+ case <-ctx.Done():
+ t.Fatalf("did not complete request")
+ }
+ th.Events.AssertTransportEvent(t, th.Channel.ChannelID(), datatransfer.TransportCompletedTransfer{Success: true})
+ })
+
+ t.Run("cleanup request", func(t *testing.T) {
+ th.Transport.CleanupChannel(th.Channel.ChannelID())
+ require.Len(t, th.DtNet.UnprotectedPeers, 1)
+ require.Equal(t, th.DtNet.UnprotectedPeers[0], testharness.TaggedPeer{th.Channel.OtherPeer(), th.Channel.ChannelID().String()})
+ })
+}
+
+func TestRespondingPushSuccessFlow(t *testing.T) {
+ ctx := context.Background()
+ ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
+ defer cancel()
+ th := testharness.SetupHarness(ctx, testharness.Responder())
+ var receivedRequest testharness.ReceivedGraphSyncRequest
+ var request 
graphsync.RequestData + + contextAugmentedCalls := []struct{}{} + th.Events.ReturnedOnContextAugmentFunc = func(ctx context.Context) context.Context { + contextAugmentedCalls = append(contextAugmentedCalls, struct{}{}) + return ctx + } + t.Run("configures persistence", func(t *testing.T) { + th.Transport.UseStore(th.Channel.ChannelID(), cidlink.DefaultLinkSystem()) + th.Fgs.AssertHasPersistenceOption(t, fmt.Sprintf("data-transfer-%s", th.Channel.ChannelID().String())) + }) + t.Run("receive new request", func(t *testing.T) { + dtResponse := th.Response() + th.Events.ReturnedRequestReceivedResponse = dtResponse + th.Channel.SetResponderPaused(true) + th.Channel.SetDataLimit(10000) + + th.DtNet.Delegates[0].Receiver.ReceiveRequest(ctx, th.Channel.OtherPeer(), th.NewRequest(t)) + require.Equal(t, th.NewRequest(t), th.Events.ReceivedRequest) + require.Len(t, th.DtNet.ProtectedPeers, 1) + require.Equal(t, th.DtNet.ProtectedPeers[0], testharness.TaggedPeer{th.Channel.OtherPeer(), th.Channel.ChannelID().String()}) + require.Len(t, th.Fgs.ReceivedRequests, 1) + receivedRequest = th.Fgs.ReceivedRequests[0] + request = receivedRequest.ToRequestData(t) + msg, err := extension.GetTransferData(request, []graphsync.ExtensionName{ + extension.ExtensionDataTransfer1_1, + }) + require.NoError(t, err) + require.Equal(t, dtResponse, msg) + require.Len(t, th.Fgs.Pauses, 1) + require.Equal(t, request.ID(), th.Fgs.Pauses[0]) + require.Len(t, contextAugmentedCalls, 1) + }) + + t.Run("unpause request", func(t *testing.T) { + th.Channel.SetResponderPaused(false) + dtValidationResponse := th.ValidationResultResponse(false) + th.Transport.ChannelUpdated(ctx, th.Channel.ChannelID(), dtValidationResponse) + require.Len(t, th.Fgs.Resumes, 1) + require.Equal(t, dtValidationResponse, th.Fgs.Resumes[0].DTMessage(t)) + }) + + t.Run("receives outgoing request hook", func(t *testing.T) { + th.OutgoingRequestHook(request) + require.Equal(t, fmt.Sprintf("data-transfer-%s", th.Channel.ChannelID().String()), th.OutgoingRequestHookActions.PersistenceOption) + th.Events.AssertTransportEvent(t, th.Channel.ChannelID(), datatransfer.TransportOpenedChannel{}) + }) + + t.Run("receives outgoing processing listener", func(t *testing.T) { + th.OutgoingRequestProcessingListener(request) + th.Events.AssertTransportEvent(t, th.Channel.ChannelID(), datatransfer.TransportInitiatedTransfer{}) + }) + response := receivedRequest.Response(t, nil, nil, graphsync.PartialResponse) + t.Run("received block / data limits", func(t *testing.T) { + th.IncomingResponseHook(response) + // consume first block + block := testharness.NewFakeBlockData(8000, 1, true) + th.IncomingBlockHook(response, block) + th.Events.AssertTransportEvent(t, th.Channel.ChannelID(), datatransfer.TransportReceivedData{Size: block.BlockSize(), Index: basicnode.NewInt(block.Index())}) + + // consume second block -- should hit data limit + block = testharness.NewFakeBlockData(3000, 2, true) + th.IncomingBlockHook(response, block) + require.True(t, th.IncomingBlockHookActions.Paused) + th.Events.AssertTransportEvent(t, th.Channel.ChannelID(), datatransfer.TransportReceivedData{Size: block.BlockSize(), Index: basicnode.NewInt(block.Index())}) + th.Events.AssertTransportEvent(t, th.Channel.ChannelID(), datatransfer.TransportReachedDataLimit{}) + + // reset data limit + th.Channel.SetResponderPaused(false) + th.Channel.SetDataLimit(20000) + dtValidationResponse := th.ValidationResultResponse(false) + th.Transport.ChannelUpdated(ctx, th.Channel.ChannelID(), dtValidationResponse) + 
require.Len(t, th.Fgs.Resumes, 2)
+ require.Equal(t, dtValidationResponse, th.Fgs.Resumes[1].DTMessage(t))
+
+ // block not on wire has no effect
+ block = testharness.NewFakeBlockData(12345, 3, false)
+ th.IncomingBlockHook(response, block)
+ th.Events.RefuteTransportEvent(t, th.Channel.ChannelID(), datatransfer.TransportReceivedData{Size: block.BlockSize(), Index: basicnode.NewInt(block.Index())})
+ // block with lower index has no effect
+ block = testharness.NewFakeBlockData(67890, 1, true)
+ th.IncomingBlockHook(response, block)
+ th.Events.RefuteTransportEvent(t, th.Channel.ChannelID(), datatransfer.TransportReceivedData{Size: block.BlockSize(), Index: basicnode.NewInt(block.Index())})
+
+ // consume third block
+ block = testharness.NewFakeBlockData(5000, 4, true)
+ th.IncomingBlockHook(response, block)
+ th.Events.AssertTransportEvent(t, th.Channel.ChannelID(), datatransfer.TransportReceivedData{Size: block.BlockSize(), Index: basicnode.NewInt(block.Index())})
+
+ // consuming the fourth block should hit the data limit again
+ block = testharness.NewFakeBlockData(5000, 5, true)
+ th.IncomingBlockHook(response, block)
+ require.True(t, th.IncomingBlockHookActions.Paused)
+ th.Events.AssertTransportEvent(t, th.Channel.ChannelID(), datatransfer.TransportReceivedData{Size: block.BlockSize(), Index: basicnode.NewInt(block.Index())})
+ th.Events.AssertTransportEvent(t, th.Channel.ChannelID(), datatransfer.TransportReachedDataLimit{})
+
+ })
+
+ t.Run("receive pause", func(t *testing.T) {
+ dtPauseRequest := th.UpdateRequest(true)
+ pauseResponse := receivedRequest.Response(t, nil, dtPauseRequest, graphsync.RequestPaused)
+ th.IncomingResponseHook(pauseResponse)
+ th.Events.ReturnedRequestReceivedResponse = nil
+ require.Equal(t, th.Events.ReceivedRequest, dtPauseRequest)
+ })
+
+ t.Run("receive resume", func(t *testing.T) {
+ dtResumeRequest := th.UpdateRequest(false)
+ pauseResponse := receivedRequest.Response(t, nil, dtResumeRequest, graphsync.PartialResponse)
+ th.IncomingResponseHook(pauseResponse)
+ require.Equal(t, th.Events.ReceivedRequest, dtResumeRequest)
+ })
+
+ t.Run("pause", func(t *testing.T) {
+ th.Channel.SetResponderPaused(true)
+ th.Transport.ChannelUpdated(ctx, th.Channel.ChannelID(), th.UpdateResponse(true))
+ th.DtNet.AssertSentMessage(t, testharness.FakeSentMessage{PeerID: th.Channel.OtherPeer(), TransportID: "graphsync", Message: th.UpdateResponse(true)})
+ require.Len(t, th.Fgs.Pauses, 1)
+ require.Equal(t, th.Fgs.Pauses[0], request.ID())
+ })
+ t.Run("pause again", func(t *testing.T) {
+ th.Transport.ChannelUpdated(ctx, th.Channel.ChannelID(), th.UpdateResponse(true))
+ // should send message again
+ th.DtNet.AssertSentMessage(t, testharness.FakeSentMessage{PeerID: th.Channel.OtherPeer(), TransportID: "graphsync", Message: th.UpdateResponse(true)})
+ // should not pause again
+ require.Len(t, th.Fgs.Pauses, 1)
+ })
+ t.Run("resume", func(t *testing.T) {
+ th.Channel.SetResponderPaused(false)
+ th.Transport.ChannelUpdated(ctx, th.Channel.ChannelID(), th.UpdateResponse(false))
+ require.Len(t, th.Fgs.Resumes, 3)
+ resume := th.Fgs.Resumes[2]
+ require.Equal(t, request.ID(), resume.RequestID)
+ msg := resume.DTMessage(t)
+ require.Equal(t, msg, th.UpdateResponse(false))
+ })
+ t.Run("resume again", func(t *testing.T) {
+ th.Transport.ChannelUpdated(ctx, th.Channel.ChannelID(), th.UpdateResponse(false))
+ // should send message again
+ th.DtNet.AssertSentMessage(t, testharness.FakeSentMessage{PeerID: th.Channel.OtherPeer(), TransportID: "graphsync", Message: 
th.UpdateResponse(false)}) + // should not resume again + require.Len(t, th.Fgs.Resumes, 3) + }) + + t.Run("restart request", func(t *testing.T) { + restartIndex := int64(5) + th.Channel.SetReceivedIndex(basicnode.NewInt(restartIndex)) + dtResponse := th.RestartResponse(false) + th.Events.ReturnedRequestReceivedResponse = dtResponse + th.DtNet.Delegates[0].Receiver.ReceiveRequest(ctx, th.Channel.OtherPeer(), th.NewRequest(t)) + require.Len(t, th.DtNet.ProtectedPeers, 2) + require.Equal(t, th.DtNet.ProtectedPeers[1], testharness.TaggedPeer{th.Channel.OtherPeer(), th.Channel.ChannelID().String()}) + require.Len(t, th.Fgs.Cancels, 1) + th.Events.AssertTransportEvent(t, th.Channel.ChannelID(), datatransfer.TransportTransferCancelled{ErrorMessage: "graphsync request cancelled"}) + require.Equal(t, request.ID(), th.Fgs.Cancels[0]) + require.Len(t, th.Fgs.ReceivedRequests, 2) + receivedRequest = th.Fgs.ReceivedRequests[1] + request = receivedRequest.ToRequestData(t) + msg, err := extension.GetTransferData(request, []graphsync.ExtensionName{ + extension.ExtensionDataTransfer1_1, + }) + require.NoError(t, err) + require.Equal(t, dtResponse, msg) + nd, has := request.Extension(graphsync.ExtensionsDoNotSendFirstBlocks) + require.True(t, has) + val, err := nd.AsInt() + require.NoError(t, err) + require.Equal(t, restartIndex, val) + require.Len(t, contextAugmentedCalls, 2) + }) + + t.Run("complete request", func(t *testing.T) { + close(receivedRequest.ResponseChan) + close(receivedRequest.ResponseErrChan) + select { + case <-th.CompletedRequests: + case <-ctx.Done(): + t.Fatalf("did not complete request") + } + th.Events.AssertTransportEvent(t, th.Channel.ChannelID(), datatransfer.TransportCompletedTransfer{Success: true}) + }) + + t.Run("cleanup request", func(t *testing.T) { + th.Transport.CleanupChannel(th.Channel.ChannelID()) + require.Len(t, th.DtNet.UnprotectedPeers, 1) + require.Equal(t, th.DtNet.UnprotectedPeers[0], testharness.TaggedPeer{th.Channel.OtherPeer(), th.Channel.ChannelID().String()}) + }) +} diff --git a/transport/graphsync/testharness/events.go b/transport/graphsync/testharness/events.go new file mode 100644 index 00000000..398a3ba8 --- /dev/null +++ b/transport/graphsync/testharness/events.go @@ -0,0 +1,74 @@ +package testharness + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/require" + + datatransfer "github.com/filecoin-project/go-data-transfer/v2" +) + +type ReceivedTransportEvent struct { + ChannelID datatransfer.ChannelID + TransportEvent datatransfer.TransportEvent +} + +type FakeEvents struct { + // function return value parameters + ReturnedRequestReceivedResponse datatransfer.Response + ReturnedRequestReceivedError error + ReturnedResponseReceivedError error + ReturnedChannelState datatransfer.ChannelState + ReturnedOnContextAugmentFunc func(context.Context) context.Context + + // recording of actions + OnRequestReceivedCalled bool + ReceivedRequest datatransfer.Request + OnResponseReceivedCalled bool + ReceivedResponse datatransfer.Response + ReceivedTransportEvents []ReceivedTransportEvent +} + +func (fe *FakeEvents) OnTransportEvent(chid datatransfer.ChannelID, evt datatransfer.TransportEvent) { + fe.ReceivedTransportEvents = append(fe.ReceivedTransportEvents, ReceivedTransportEvent{chid, evt}) +} + +func (fe *FakeEvents) AssertTransportEvent(t *testing.T, chid datatransfer.ChannelID, evt datatransfer.TransportEvent) { + require.Contains(t, fe.ReceivedTransportEvents, ReceivedTransportEvent{chid, evt}) +} + +func (fe *FakeEvents) 
AssertTransportEventEventually(t *testing.T, chid datatransfer.ChannelID, evt datatransfer.TransportEvent) { + require.Eventually(t, func() bool { + for _, receivedEvent := range fe.ReceivedTransportEvents { + if (receivedEvent == ReceivedTransportEvent{chid, evt}) { + return true + } + } + return false + }, time.Second, time.Millisecond) +} + +func (fe *FakeEvents) RefuteTransportEvent(t *testing.T, chid datatransfer.ChannelID, evt datatransfer.TransportEvent) { + require.NotContains(t, fe.ReceivedTransportEvents, ReceivedTransportEvent{chid, evt}) +} +func (fe *FakeEvents) OnRequestReceived(chid datatransfer.ChannelID, request datatransfer.Request) (datatransfer.Response, error) { + fe.OnRequestReceivedCalled = true + fe.ReceivedRequest = request + return fe.ReturnedRequestReceivedResponse, fe.ReturnedRequestReceivedError +} + +func (fe *FakeEvents) OnResponseReceived(chid datatransfer.ChannelID, response datatransfer.Response) error { + fe.OnResponseReceivedCalled = true + fe.ReceivedResponse = response + return fe.ReturnedResponseReceivedError +} + +func (fe *FakeEvents) OnContextAugment(chid datatransfer.ChannelID) func(context.Context) context.Context { + return fe.ReturnedOnContextAugmentFunc +} + +func (fe *FakeEvents) ChannelState(ctx context.Context, chid datatransfer.ChannelID) (datatransfer.ChannelState, error) { + return fe.ReturnedChannelState, nil +} diff --git a/testutil/fakegraphsync.go b/transport/graphsync/testharness/fakegraphsync.go similarity index 71% rename from testutil/fakegraphsync.go rename to transport/graphsync/testharness/fakegraphsync.go index 758a3f40..5f31b454 100644 --- a/testutil/fakegraphsync.go +++ b/transport/graphsync/testharness/fakegraphsync.go @@ -1,4 +1,4 @@ -package testutil +package testharness import ( "context" @@ -19,13 +19,14 @@ import ( datatransfer "github.com/filecoin-project/go-data-transfer/v2" "github.com/filecoin-project/go-data-transfer/v2/message" + "github.com/filecoin-project/go-data-transfer/v2/testutil" "github.com/filecoin-project/go-data-transfer/v2/transport/graphsync/extension" ) -func matchDtMessage(t *testing.T, extensions []graphsync.ExtensionData) datatransfer.Message { +func matchDtMessage(t *testing.T, extensions []graphsync.ExtensionData, extName graphsync.ExtensionName) datatransfer.Message { var matchedExtension *graphsync.ExtensionData for _, ext := range extensions { - if ext.Name == extension.ExtensionDataTransfer1_1 { + if ext.Name == extName { matchedExtension = &ext break } @@ -47,9 +48,37 @@ type ReceivedGraphSyncRequest struct { ResponseErrChan chan error } +func (gsRequest ReceivedGraphSyncRequest) ToRequestData(t *testing.T) graphsync.RequestData { + extensions := make(map[graphsync.ExtensionName]datamodel.Node) + for _, extension := range gsRequest.Extensions { + extensions[extension.Name] = extension.Data + } + requestID, ok := gsRequest.requestID() + require.True(t, ok) + return NewFakeRequest(requestID, extensions, graphsync.RequestTypeNew) +} + +func (gsRequest ReceivedGraphSyncRequest) Response(t *testing.T, incomingRequestMsg datatransfer.Message, blockMessage datatransfer.Message, code graphsync.ResponseStatusCode) graphsync.ResponseData { + extensions := make(map[graphsync.ExtensionName]datamodel.Node) + if incomingRequestMsg != nil { + extensions[extension.ExtensionIncomingRequest1_1] = incomingRequestMsg.ToIPLD() + } + if blockMessage != nil { + extensions[extension.ExtensionOutgoingBlock1_1] = blockMessage.ToIPLD() + } + requestID, ok := gsRequest.requestID() + require.True(t, ok) + return 
NewFakeResponse(requestID, extensions, code) +} + +func (gsRequest ReceivedGraphSyncRequest) requestID() (graphsync.RequestID, bool) { + request, ok := gsRequest.Ctx.Value(graphsync.RequestIDContextKey{}).(graphsync.RequestID) + return request, ok +} + // DTMessage returns the data transfer message among the graphsync extensions sent with this request func (gsRequest ReceivedGraphSyncRequest) DTMessage(t *testing.T) datatransfer.Message { - return matchDtMessage(t, gsRequest.Extensions) + return matchDtMessage(t, gsRequest.Extensions, extension.ExtensionDataTransfer1_1) } type Resume struct { @@ -59,7 +88,7 @@ type Resume struct { // DTMessage returns the data transfer message among the graphsync extensions sent with this request func (resume Resume) DTMessage(t *testing.T) datatransfer.Message { - return matchDtMessage(t, resume.Extensions) + return matchDtMessage(t, resume.Extensions, extension.ExtensionDataTransfer1_1) } type Update struct { @@ -69,41 +98,41 @@ type Update struct { // DTMessage returns the data transfer message among the graphsync extensions sent with this request func (update Update) DTMessage(t *testing.T) datatransfer.Message { - return matchDtMessage(t, update.Extensions) + return matchDtMessage(t, update.Extensions, extension.ExtensionDataTransfer1_1) } // FakeGraphSync implements a GraphExchange but does nothing type FakeGraphSync struct { - requests chan ReceivedGraphSyncRequest // records calls to fakeGraphSync.Request - pauses chan graphsync.RequestID - resumes chan Resume - cancels chan graphsync.RequestID - updates chan Update - persistenceOptionsLk sync.RWMutex - persistenceOptions map[string]ipld.LinkSystem - leaveRequestsOpen bool - OutgoingRequestHook graphsync.OnOutgoingRequestHook - IncomingBlockHook graphsync.OnIncomingBlockHook - OutgoingBlockHook graphsync.OnOutgoingBlockHook - IncomingRequestQueuedHook graphsync.OnIncomingRequestQueuedHook - IncomingRequestHook graphsync.OnIncomingRequestHook - CompletedResponseListener graphsync.OnResponseCompletedListener - RequestUpdatedHook graphsync.OnRequestUpdatedHook - IncomingResponseHook graphsync.OnIncomingResponseHook - RequestorCancelledListener graphsync.OnRequestorCancelledListener - BlockSentListener graphsync.OnBlockSentListener - NetworkErrorListener graphsync.OnNetworkErrorListener - ReceiverNetworkErrorListener graphsync.OnReceiverNetworkErrorListener + ReceivedRequests []ReceivedGraphSyncRequest // records calls to fakeGraphSync.Request + Pauses []graphsync.RequestID + Resumes []Resume + Cancels []graphsync.RequestID + Updates []Update + persistenceOptionsLk sync.RWMutex + persistenceOptions map[string]ipld.LinkSystem + leaveRequestsOpen bool + OutgoingRequestHook graphsync.OnOutgoingRequestHook + IncomingBlockHook graphsync.OnIncomingBlockHook + OutgoingBlockHook graphsync.OnOutgoingBlockHook + IncomingRequestProcessingListener graphsync.OnRequestProcessingListener + OutgoingRequestProcessingListener graphsync.OnRequestProcessingListener + IncomingRequestHook graphsync.OnIncomingRequestHook + CompletedResponseListener graphsync.OnResponseCompletedListener + RequestUpdatedHook graphsync.OnRequestUpdatedHook + IncomingResponseHook graphsync.OnIncomingResponseHook + RequestorCancelledListener graphsync.OnRequestorCancelledListener + BlockSentListener graphsync.OnBlockSentListener + NetworkErrorListener graphsync.OnNetworkErrorListener + ReceiverNetworkErrorListener graphsync.OnReceiverNetworkErrorListener + ReturnedCancelError error + ReturnedPauseError error + ReturnedResumeError error + 
ReturnedSendUpdateError error } // NewFakeGraphSync returns a new fake graphsync implementation func NewFakeGraphSync() *FakeGraphSync { return &FakeGraphSync{ - requests: make(chan ReceivedGraphSyncRequest, 2), - pauses: make(chan graphsync.RequestID, 1), - resumes: make(chan Resume, 1), - cancels: make(chan graphsync.RequestID, 1), - updates: make(chan Update, 1), persistenceOptions: make(map[string]ipld.LinkSystem), } } @@ -112,70 +141,6 @@ func (fgs *FakeGraphSync) LeaveRequestsOpen() { fgs.leaveRequestsOpen = true } -// AssertNoRequestReceived asserts that no requests should ahve been received by this graphsync implementation -func (fgs *FakeGraphSync) AssertNoRequestReceived(t *testing.T) { - require.Empty(t, fgs.requests, "should not receive request") -} - -// AssertRequestReceived asserts a request should be received before the context closes (and returns said request) -func (fgs *FakeGraphSync) AssertRequestReceived(ctx context.Context, t *testing.T) ReceivedGraphSyncRequest { - var requestReceived ReceivedGraphSyncRequest - select { - case <-ctx.Done(): - t.Fatal("did not receive message sent") - case requestReceived = <-fgs.requests: - } - return requestReceived -} - -// AssertNoPauseReceived asserts that no pause requests should ahve been received by this graphsync implementation -func (fgs *FakeGraphSync) AssertNoPauseReceived(t *testing.T) { - require.Empty(t, fgs.pauses, "should not receive pause request") -} - -// AssertPauseReceived asserts a pause request should be received before the context closes (and returns said request) -func (fgs *FakeGraphSync) AssertPauseReceived(ctx context.Context, t *testing.T) graphsync.RequestID { - var pauseReceived graphsync.RequestID - select { - case <-ctx.Done(): - t.Fatal("did not receive message sent") - case pauseReceived = <-fgs.pauses: - } - return pauseReceived -} - -// AssertNoResumeReceived asserts that no resume requests should ahve been received by this graphsync implementation -func (fgs *FakeGraphSync) AssertNoResumeReceived(t *testing.T) { - require.Empty(t, fgs.resumes, "should not receive resume request") -} - -// AssertResumeReceived asserts a resume request should be received before the context closes (and returns said request) -func (fgs *FakeGraphSync) AssertResumeReceived(ctx context.Context, t *testing.T) Resume { - var resumeReceived Resume - select { - case <-ctx.Done(): - t.Fatal("did not receive message sent") - case resumeReceived = <-fgs.resumes: - } - return resumeReceived -} - -// AssertNoCancelReceived asserts that no requests were cancelled by thiss graphsync implementation -func (fgs *FakeGraphSync) AssertNoCancelReceived(t *testing.T) { - require.Empty(t, fgs.cancels, "should not cancel request") -} - -// AssertCancelReceived asserts a requests was cancelled before the context closes (and returns said request id) -func (fgs *FakeGraphSync) AssertCancelReceived(ctx context.Context, t *testing.T) graphsync.RequestID { - var cancelReceived graphsync.RequestID - select { - case <-ctx.Done(): - t.Fatal("did not receive message sent") - case cancelReceived = <-fgs.cancels: - } - return cancelReceived -} - // AssertHasPersistenceOption verifies that a persistence option was registered func (fgs *FakeGraphSync) AssertHasPersistenceOption(t *testing.T, name string) ipld.LinkSystem { fgs.persistenceOptionsLk.RLock() @@ -195,9 +160,9 @@ func (fgs *FakeGraphSync) AssertDoesNotHavePersistenceOption(t *testing.T, name // Request initiates a new GraphSync request to the given peer using the given selector spec. 
func (fgs *FakeGraphSync) Request(ctx context.Context, p peer.ID, root ipld.Link, selector datamodel.Node, extensions ...graphsync.ExtensionData) (<-chan graphsync.ResponseProgress, <-chan error) { - errors := make(chan error) - responses := make(chan graphsync.ResponseProgress) - fgs.requests <- ReceivedGraphSyncRequest{ctx, p, root, selector, extensions, responses, errors} + errors := make(chan error, 1) + responses := make(chan graphsync.ResponseProgress, 1) + fgs.ReceivedRequests = append(fgs.ReceivedRequests, ReceivedGraphSyncRequest{ctx, p, root, selector, extensions, responses, errors}) if !fgs.leaveRequestsOpen { close(responses) close(errors) @@ -233,11 +198,11 @@ func (fgs *FakeGraphSync) RegisterIncomingRequestHook(hook graphsync.OnIncomingR } } -// RegisterIncomingRequestQueuedHook adds a hook that runs when an incoming GS request is queued. -func (fgs *FakeGraphSync) RegisterIncomingRequestQueuedHook(hook graphsync.OnIncomingRequestQueuedHook) graphsync.UnregisterHookFunc { - fgs.IncomingRequestQueuedHook = hook +// RegisterIncomingRequestProcessingListener adds a hook that runs when an incoming GS request begins processing +func (fgs *FakeGraphSync) RegisterIncomingRequestProcessingListener(hook graphsync.OnRequestProcessingListener) graphsync.UnregisterHookFunc { + fgs.IncomingRequestProcessingListener = hook return func() { - fgs.IncomingRequestQueuedHook = nil + fgs.IncomingRequestProcessingListener = nil } } @@ -291,19 +256,29 @@ func (fgs *FakeGraphSync) RegisterCompletedResponseListener(listener graphsync.O // Unpause unpauses a request that was paused in a block hook based on request ID func (fgs *FakeGraphSync) Unpause(ctx context.Context, requestID graphsync.RequestID, extensions ...graphsync.ExtensionData) error { - fgs.resumes <- Resume{requestID, extensions} - return nil + fgs.Resumes = append(fgs.Resumes, Resume{requestID, extensions}) + return fgs.ReturnedResumeError } // Pause pauses a request based on request ID func (fgs *FakeGraphSync) Pause(ctx context.Context, requestID graphsync.RequestID) error { - fgs.pauses <- requestID - return nil + fgs.Pauses = append(fgs.Pauses, requestID) + return fgs.ReturnedPauseError } func (fgs *FakeGraphSync) Cancel(ctx context.Context, requestID graphsync.RequestID) error { - fgs.cancels <- requestID - return nil + if fgs.leaveRequestsOpen { + for _, rr := range fgs.ReceivedRequests { + existingRequestID, has := rr.requestID() + if has && requestID.String() == existingRequestID.String() { + close(rr.ResponseChan) + rr.ResponseErrChan <- graphsync.RequestClientCancelledErr{} + close(rr.ResponseErrChan) + } + } + } + fgs.Cancels = append(fgs.Cancels, requestID) + return fgs.ReturnedCancelError } // RegisterRequestorCancelledListener adds a listener on the responder for requests cancelled by the requestor @@ -342,22 +317,25 @@ func (fgs *FakeGraphSync) Stats() graphsync.Stats { return graphsync.Stats{} } -func (fgs *FakeGraphSync) RegisterOutgoingRequestProcessingListener(graphsync.OnOutgoingRequestProcessingListener) graphsync.UnregisterHookFunc { - // TODO: just a stub for now, hopefully nobody needs this - return func() {} +func (fgs *FakeGraphSync) RegisterOutgoingRequestProcessingListener(listener graphsync.OnRequestProcessingListener) graphsync.UnregisterHookFunc { + fgs.OutgoingRequestProcessingListener = listener + return func() { + fgs.OutgoingRequestProcessingListener = nil + } } func (fgs *FakeGraphSync) SendUpdate(ctx context.Context, id graphsync.RequestID, extensions ...graphsync.ExtensionData) error { - 
fgs.updates <- Update{RequestID: id, Extensions: extensions} - return nil + fgs.Updates = append(fgs.Updates, Update{RequestID: id, Extensions: extensions}) + return fgs.ReturnedSendUpdateError } var _ graphsync.GraphExchange = &FakeGraphSync{} type fakeBlkData struct { - link ipld.Link - size uint64 - index int64 + link ipld.Link + size uint64 + onWire bool + index int64 } func (fbd fakeBlkData) Link() ipld.Link { @@ -369,7 +347,10 @@ func (fbd fakeBlkData) BlockSize() uint64 { } func (fbd fakeBlkData) BlockSizeOnWire() uint64 { - return fbd.size + if fbd.onWire { + return fbd.size + } + return 0 } func (fbd fakeBlkData) Index() int64 { @@ -377,11 +358,12 @@ func (fbd fakeBlkData) Index() int64 { } // NewFakeBlockData returns a fake block that matches the block data interface -func NewFakeBlockData() graphsync.BlockData { +func NewFakeBlockData(size uint64, index int64, onWire bool) graphsync.BlockData { return &fakeBlkData{ - link: cidlink.Link{Cid: GenerateCids(1)[0]}, - size: rand.Uint64(), - index: int64(rand.Uint32()), + link: cidlink.Link{Cid: testutil.GenerateCids(1)[0]}, + size: size, + index: index, + onWire: onWire, } } @@ -427,14 +409,14 @@ func (fr *fakeRequest) Type() graphsync.RequestType { } // NewFakeRequest returns a fake request that matches the request data interface -func NewFakeRequest(id graphsync.RequestID, extensions map[graphsync.ExtensionName]datamodel.Node) graphsync.RequestData { +func NewFakeRequest(id graphsync.RequestID, extensions map[graphsync.ExtensionName]datamodel.Node, requestType graphsync.RequestType) graphsync.RequestData { return &fakeRequest{ id: id, - root: GenerateCids(1)[0], + root: testutil.GenerateCids(1)[0], selector: selectorparse.CommonSelector_ExploreAllRecursively, priority: graphsync.Priority(rand.Int()), extensions: extensions, - requestType: graphsync.RequestTypeNew, + requestType: requestType, } } @@ -533,6 +515,7 @@ type FakeIncomingRequestHookActions struct { Validated bool SentExtensions []graphsync.ExtensionData Paused bool + CtxAugFuncs []func(context.Context) context.Context } func (fa *FakeIncomingRequestHookActions) SendExtensionData(ext graphsync.ExtensionData) { @@ -558,6 +541,30 @@ func (fa *FakeIncomingRequestHookActions) PauseResponse() { fa.Paused = true } +func (fa *FakeIncomingRequestHookActions) AugmentContext(ctxAugFunc func(reqCtx context.Context) context.Context) { + fa.CtxAugFuncs = append(fa.CtxAugFuncs, ctxAugFunc) +} + +func (fa *FakeIncomingRequestHookActions) AssertAugmentedContextKey(t *testing.T, key interface{}, value interface{}) { + ctx := context.Background() + for _, f := range fa.CtxAugFuncs { + ctx = f(ctx) + } + require.Equal(t, value, ctx.Value(key)) +} + +func (fa *FakeIncomingRequestHookActions) RefuteAugmentedContextKey(t *testing.T, key interface{}) { + ctx := context.Background() + for _, f := range fa.CtxAugFuncs { + ctx = f(ctx) + } + require.Nil(t, ctx.Value(key)) +} + +func (fa *FakeIncomingRequestHookActions) DTMessage(t *testing.T) datatransfer.Message { + return matchDtMessage(t, fa.SentExtensions, extension.ExtensionIncomingRequest1_1) +} + var _ graphsync.IncomingRequestHookActions = &FakeIncomingRequestHookActions{} type FakeRequestUpdatedActions struct { @@ -594,13 +601,3 @@ func (fa *FakeIncomingResponseHookActions) UpdateRequestWithExtensions(extension } var _ graphsync.IncomingResponseHookActions = &FakeIncomingResponseHookActions{} - -type FakeRequestQueuedHookActions struct { - ctxAugFuncs []func(context.Context) context.Context -} - -func (fa *FakeRequestQueuedHookActions) 
AugmentContext(ctxAugFunc func(reqCtx context.Context) context.Context) { - fa.ctxAugFuncs = append(fa.ctxAugFuncs, ctxAugFunc) -} - -var _ graphsync.RequestQueuedHookActions = &FakeRequestQueuedHookActions{} diff --git a/transport/graphsync/testharness/harness.go b/transport/graphsync/testharness/harness.go new file mode 100644 index 00000000..c545ad8b --- /dev/null +++ b/transport/graphsync/testharness/harness.go @@ -0,0 +1,285 @@ +package testharness + +import ( + "context" + "math/rand" + "testing" + + "github.com/ipfs/go-graphsync" + "github.com/ipld/go-ipld-prime" + "github.com/ipld/go-ipld-prime/datamodel" + "github.com/ipld/go-ipld-prime/node/basicnode" + "github.com/ipld/go-ipld-prime/traversal/selector/builder" + "github.com/stretchr/testify/require" + + datatransfer "github.com/filecoin-project/go-data-transfer/v2" + "github.com/filecoin-project/go-data-transfer/v2/message" + "github.com/filecoin-project/go-data-transfer/v2/message/types" + "github.com/filecoin-project/go-data-transfer/v2/testutil" + dtgs "github.com/filecoin-project/go-data-transfer/v2/transport/graphsync" + "github.com/filecoin-project/go-data-transfer/v2/transport/graphsync/extension" +) + +type harnessConfig struct { + isPull bool + isResponder bool + makeEvents func(gsData *GsTestHarness) *FakeEvents + makeNetwork func(gsData *GsTestHarness) *FakeNetwork + transportOptions []dtgs.Option +} + +type Option func(*harnessConfig) + +func PullRequest() Option { + return func(hc *harnessConfig) { + hc.isPull = true + } +} + +func Responder() Option { + return func(hc *harnessConfig) { + hc.isResponder = true + } +} + +func Events(makeEvents func(gsData *GsTestHarness) *FakeEvents) Option { + return func(hc *harnessConfig) { + hc.makeEvents = makeEvents + } +} + +func Network(makeNetwork func(gsData *GsTestHarness) *FakeNetwork) Option { + return func(hc *harnessConfig) { + hc.makeNetwork = makeNetwork + } +} + +func TransportOptions(options []dtgs.Option) Option { + return func(hc *harnessConfig) { + hc.transportOptions = options + } +} + +func SetupHarness(ctx context.Context, options ...Option) *GsTestHarness { + hc := &harnessConfig{} + for _, option := range options { + option(hc) + } + peers := testutil.GeneratePeers(2) + transferID := datatransfer.TransferID(rand.Uint32()) + fgs := NewFakeGraphSync() + fgs.LeaveRequestsOpen() + voucher := testutil.NewTestTypedVoucher() + baseCid := testutil.GenerateCids(1)[0] + selector := builder.NewSelectorSpecBuilder(basicnode.Prototype.Any).Matcher().Node() + chid := datatransfer.ChannelID{Initiator: peers[0], Responder: peers[1], ID: transferID} + if hc.isResponder { + chid = datatransfer.ChannelID{Initiator: peers[1], Responder: peers[0], ID: transferID} + } + channel := testutil.NewMockChannelState(testutil.MockChannelStateParams{ + BaseCID: baseCid, + Voucher: voucher, + Selector: selector, + IsPull: hc.isPull, + Self: peers[0], + ChannelID: chid, + }) + gsData := &GsTestHarness{ + Ctx: ctx, + Fgs: fgs, + Channel: channel, + CompletedRequests: make(chan datatransfer.ChannelID, 16), + CompletedResponses: make(chan datatransfer.ChannelID, 16), + OutgoingRequestHookActions: &FakeOutgoingRequestHookActions{}, + OutgoingBlockHookActions: &FakeOutgoingBlockHookActions{}, + IncomingBlockHookActions: &FakeIncomingBlockHookActions{}, + IncomingRequestHookActions: &FakeIncomingRequestHookActions{}, + RequestUpdateHookActions: &FakeRequestUpdatedActions{}, + IncomingResponseHookActions: &FakeIncomingResponseHookActions{}, + } + if hc.makeEvents != nil { + gsData.Events = 
hc.makeEvents(gsData)
+ } else {
+ gsData.Events = &FakeEvents{
+ ReturnedChannelState: channel,
+ }
+ }
+ if hc.makeNetwork != nil {
+ gsData.DtNet = hc.makeNetwork(gsData)
+ } else {
+ gsData.DtNet = NewFakeNetwork(peers[0])
+ }
+ gsData.Transport = dtgs.NewTransport(gsData.Fgs, gsData.DtNet,
+ append(hc.transportOptions,
+ dtgs.RegisterCompletedRequestListener(gsData.completedRequestListener),
+ dtgs.RegisterCompletedResponseListener(gsData.completedResponseListener))...)
+ gsData.Transport.SetEventHandler(gsData.Events)
+ return gsData
+}
+
+type GsTestHarness struct {
+ Ctx context.Context
+ Fgs *FakeGraphSync
+ Channel *testutil.MockChannelState
+ RequestID graphsync.RequestID
+ AltRequestID graphsync.RequestID
+ Events *FakeEvents
+ DtNet *FakeNetwork
+ OutgoingRequestHookActions *FakeOutgoingRequestHookActions
+ IncomingBlockHookActions *FakeIncomingBlockHookActions
+ OutgoingBlockHookActions *FakeOutgoingBlockHookActions
+ IncomingRequestHookActions *FakeIncomingRequestHookActions
+ RequestUpdateHookActions *FakeRequestUpdatedActions
+ IncomingResponseHookActions *FakeIncomingResponseHookActions
+ Transport *dtgs.Transport
+ CompletedRequests chan datatransfer.ChannelID
+ CompletedResponses chan datatransfer.ChannelID
+}
+
+func (th *GsTestHarness) completedRequestListener(chid datatransfer.ChannelID) {
+ th.CompletedRequests <- chid
+}
+func (th *GsTestHarness) completedResponseListener(chid datatransfer.ChannelID) {
+ th.CompletedResponses <- chid
+}
+
+func (th *GsTestHarness) NewRequest(t *testing.T) datatransfer.Request {
+ vouch := th.Channel.Voucher()
+ message, err := message.NewRequest(th.Channel.TransferID(), false, th.Channel.IsPull(), &vouch, th.Channel.BaseCID(), th.Channel.Selector())
+ require.NoError(t, err)
+ return message
+}
+
+func (th *GsTestHarness) RestartRequest(t *testing.T) datatransfer.Request {
+ vouch := th.Channel.Voucher()
+ message, err := message.NewRequest(th.Channel.TransferID(), true, th.Channel.IsPull(), &vouch, th.Channel.BaseCID(), th.Channel.Selector())
+ require.NoError(t, err)
+ return message
+}
+
+func (th *GsTestHarness) VoucherRequest() datatransfer.Request {
+ newVouch := testutil.NewTestTypedVoucher()
+ return message.VoucherRequest(th.Channel.TransferID(), &newVouch)
+}
+
+func (th *GsTestHarness) UpdateRequest(pause bool) datatransfer.Request {
+ return message.UpdateRequest(th.Channel.TransferID(), pause)
+}
+
+func (th *GsTestHarness) Response() datatransfer.Response {
+ voucherResult := testutil.NewTestTypedVoucher()
+ return message.NewResponse(th.Channel.TransferID(), true, false, &voucherResult)
+}
+
+func (th *GsTestHarness) ValidationResultResponse(pause bool) datatransfer.Response {
+ voucherResult := testutil.NewTestTypedVoucher()
+ return message.ValidationResultResponse(types.VoucherResultMessage, th.Channel.TransferID(), datatransfer.ValidationResult{VoucherResult: &voucherResult, Accepted: true}, nil, pause)
+}
+
+func (th *GsTestHarness) RestartResponse(pause bool) datatransfer.Response {
+ voucherResult := testutil.NewTestTypedVoucher()
+ return message.ValidationResultResponse(types.RestartMessage, th.Channel.TransferID(), datatransfer.ValidationResult{VoucherResult: &voucherResult, Accepted: true}, nil, pause)
+}
+
+func (th *GsTestHarness) UpdateResponse(paused bool) datatransfer.Response {
+ return message.UpdateResponse(th.Channel.TransferID(), paused)
+}
+
+func (th *GsTestHarness) OutgoingRequestHook(request graphsync.RequestData) {
+ th.Fgs.OutgoingRequestHook(th.Channel.OtherPeer(), request, 
th.OutgoingRequestHookActions) +} + +func (th *GsTestHarness) OutgoingRequestProcessingListener(request graphsync.RequestData) { + th.Fgs.OutgoingRequestProcessingListener(th.Channel.OtherPeer(), request, 0) +} + +func (th *GsTestHarness) IncomingBlockHook(response graphsync.ResponseData, block graphsync.BlockData) { + th.Fgs.IncomingBlockHook(th.Channel.OtherPeer(), response, block, th.IncomingBlockHookActions) +} + +func (th *GsTestHarness) OutgoingBlockHook(request graphsync.RequestData, block graphsync.BlockData) { + th.Fgs.OutgoingBlockHook(th.Channel.OtherPeer(), request, block, th.OutgoingBlockHookActions) +} + +func (th *GsTestHarness) IncomingRequestHook(request graphsync.RequestData) { + th.Fgs.IncomingRequestHook(th.Channel.OtherPeer(), request, th.IncomingRequestHookActions) +} + +func (th *GsTestHarness) IncomingRequestProcessingListener(request graphsync.RequestData) { + th.Fgs.IncomingRequestProcessingListener(th.Channel.OtherPeer(), request, 1) +} + +func (th *GsTestHarness) IncomingResponseHook(response graphsync.ResponseData) { + th.Fgs.IncomingResponseHook(th.Channel.OtherPeer(), response, th.IncomingResponseHookActions) +} + +func (th *GsTestHarness) ResponseCompletedListener(request graphsync.RequestData, code graphsync.ResponseStatusCode) { + th.Fgs.CompletedResponseListener(th.Channel.OtherPeer(), request, code) +} + +func (th *GsTestHarness) RequestorCancelledListener(request graphsync.RequestData) { + th.Fgs.RequestorCancelledListener(th.Channel.OtherPeer(), request) +} + +/* +func (ha *GsTestHarness) networkErrorListener(err error) { + ha.Fgs.NetworkErrorListener(ha.other, ha.request, err) +} +func (ha *GsTestHarness) receiverNetworkErrorListener(err error) { + ha.Fgs.ReceiverNetworkErrorListener(ha.other, err) +} +*/ + +func (th *GsTestHarness) BlockSentListener(request graphsync.RequestData, block graphsync.BlockData) { + th.Fgs.BlockSentListener(th.Channel.OtherPeer(), request, block) +} + +func (ha *GsTestHarness) makeRequest(requestID graphsync.RequestID, messageNode datamodel.Node, requestType graphsync.RequestType) graphsync.RequestData { + extensions := make(map[graphsync.ExtensionName]datamodel.Node) + if messageNode != nil { + extensions[extension.ExtensionDataTransfer1_1] = messageNode + } + return NewFakeRequest(requestID, extensions, requestType) +} + +func (ha *GsTestHarness) makeResponse(requestID graphsync.RequestID, messageNode datamodel.Node, responseCode graphsync.ResponseStatusCode) graphsync.ResponseData { + extensions := make(map[graphsync.ExtensionName]datamodel.Node) + if messageNode != nil { + extensions[extension.ExtensionDataTransfer1_1] = messageNode + } + return NewFakeResponse(requestID, extensions, responseCode) +} + +func assertDecodesToMessage(t *testing.T, data datamodel.Node, expected datatransfer.Message) { + actual, err := message.FromIPLD(data) + require.NoError(t, err) + require.Equal(t, expected, actual) +} + +func assertHasOutgoingMessage(t *testing.T, extensions []graphsync.ExtensionData, expected datatransfer.Message) { + nd := expected.ToIPLD() + found := false + for _, e := range extensions { + if e.Name == extension.ExtensionDataTransfer1_1 { + require.True(t, ipld.DeepEqual(nd, e.Data), "data matches") + found = true + } + } + if !found { + require.Fail(t, "extension not found") + } +} + +func assertHasExtensionMessage(t *testing.T, name graphsync.ExtensionName, extensions []graphsync.ExtensionData, expected datatransfer.Message) { + nd := expected.ToIPLD() + found := false + for _, e := range extensions { + if e.Name 
== name {
+ require.True(t, ipld.DeepEqual(nd, e.Data), "data matches")
+ found = true
+ }
+ }
+ if !found {
+ require.Fail(t, "extension not found")
+ }
+}
diff --git a/transport/graphsync/testharness/testnet.go b/transport/graphsync/testharness/testnet.go
new file mode 100644
index 00000000..14ab118b
--- /dev/null
+++ b/transport/graphsync/testharness/testnet.go
@@ -0,0 +1,106 @@
+package testharness
+
+import (
+ "context"
+ "testing"
+
+ "github.com/libp2p/go-libp2p-core/peer"
+ "github.com/stretchr/testify/require"
+
+ datatransfer "github.com/filecoin-project/go-data-transfer/v2"
+ "github.com/filecoin-project/go-data-transfer/v2/transport/helpers/network"
+)
+
+// FakeSentMessage is a recording of a message sent on the FakeNetwork
+type FakeSentMessage struct {
+ PeerID peer.ID
+ TransportID datatransfer.TransportID
+ Message datatransfer.Message
+}
+
+type FakeDelegates struct {
+ TransportID datatransfer.TransportID
+ Versions []datatransfer.Version
+ Receiver network.Receiver
+}
+
+type ConnectWithRetryAttempt struct {
+ PeerID peer.ID
+ TransportID datatransfer.TransportID
+}
+
+type TaggedPeer struct {
+ PeerID peer.ID
+ Tag string
+}
+
+// FakeNetwork is a network that satisfies the DataTransferNetwork interface but
+// records calls for later assertions instead of performing any real network operations
+type FakeNetwork struct {
+ SentMessages []FakeSentMessage
+ Delegates []FakeDelegates
+ ConnectWithRetryAttempts []ConnectWithRetryAttempt
+ ProtectedPeers []TaggedPeer
+ UnprotectedPeers []TaggedPeer
+
+ ReturnedPeerDescription network.ProtocolDescription
+ ReturnedPeerID peer.ID
+ ReturnedSendMessageError error
+ ReturnedConnectWithRetryError error
+}
+
+// NewFakeNetwork returns a new fake data transfer network instance
+func NewFakeNetwork(id peer.ID) *FakeNetwork {
+ return &FakeNetwork{ReturnedPeerID: id}
+}
+
+var _ network.DataTransferNetwork = (*FakeNetwork)(nil)
+
+// SendMessage records a data transfer message sent to a peer.
+func (fn *FakeNetwork) SendMessage(ctx context.Context,
+ p peer.ID,
+ t datatransfer.TransportID,
+ m datatransfer.Message) error {
+ fn.SentMessages = append(fn.SentMessages, FakeSentMessage{p, t, m})
+ return fn.ReturnedSendMessageError
+}
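+
+// A test can inject a send failure by presetting ReturnedSendMessageError
+// before exercising the transport, along these lines (illustrative only;
+// the peer id and error value are whatever the test needs):
+//
+//   fn := NewFakeNetwork(examplePeerID)
+//   fn.ReturnedSendMessageError = errors.New("example send failure")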
+
+// SetDelegate registers the Receiver to handle messages received from the
+// network.
+func (fn *FakeNetwork) SetDelegate(t datatransfer.TransportID, v []datatransfer.Version, r network.Receiver) {
+ fn.Delegates = append(fn.Delegates, FakeDelegates{t, v, r})
+}
+
+// ConnectTo is a no-op on the fake network
+func (fn *FakeNetwork) ConnectTo(_ context.Context, _ peer.ID) error {
+ return nil
+}
+
+// ConnectWithRetry records an attempted connection to the given peer over the given transport
+func (fn *FakeNetwork) ConnectWithRetry(ctx context.Context, p peer.ID, transportID datatransfer.TransportID) error {
+ fn.ConnectWithRetryAttempts = append(fn.ConnectWithRetryAttempts, ConnectWithRetryAttempt{p, transportID})
+ return fn.ReturnedConnectWithRetryError
+}
+
+// ID returns the stubbed peer id for the host of this network
+func (fn *FakeNetwork) ID() peer.ID {
+ return fn.ReturnedPeerID
+}
+
+// Protect records a request to protect the connection to the given peer
+func (fn *FakeNetwork) Protect(id peer.ID, tag string) {
+ fn.ProtectedPeers = append(fn.ProtectedPeers, TaggedPeer{id, tag})
+}
+
+// Unprotect records a request to remove protection from the connection to the
+// given peer; it always reports that the connection is no longer protected
+func (fn *FakeNetwork) Unprotect(id peer.ID, tag string) bool {
+ fn.UnprotectedPeers = append(fn.UnprotectedPeers, TaggedPeer{id, tag})
+ return false
+}
+
+// Protocol returns the stubbed protocol description
+func (fn *FakeNetwork) Protocol(ctx context.Context, id peer.ID, transportID datatransfer.TransportID) (network.ProtocolDescription, error) {
+ return fn.ReturnedPeerDescription, nil
+}
+
+// AssertSentMessage verifies that the given message was recorded as sent
+func (fn *FakeNetwork) AssertSentMessage(t *testing.T, sentMessage FakeSentMessage) {
+ require.Contains(t, fn.SentMessages, sentMessage)
+}
diff --git a/transport/helpers/network/libp2p_impl_test.go b/transport/helpers/network/libp2p_impl_test.go
index 76b59ef6..1b689004 100644
--- a/transport/helpers/network/libp2p_impl_test.go
+++ b/transport/helpers/network/libp2p_impl_test.go
@@ -132,8 +132,7 @@ func TestMessageSendAndReceive(t *testing.T) {
 accepted := false
 id := datatransfer.TransferID(rand.Int31())
 voucherResult := testutil.NewTestTypedVoucher()
- response, err := message.ValidationResultResponse(types.NewMessage, id, datatransfer.ValidationResult{Accepted: accepted, VoucherResult: &voucherResult}, nil, false)
- require.NoError(t, err)
+ response := message.ValidationResultResponse(types.NewMessage, id, datatransfer.ValidationResult{Accepted: accepted, VoucherResult: &voucherResult}, nil, false)
 require.NoError(t, dtnet2.SendMessage(ctx, host1.ID(), "graphsync", response))
 
 select {
diff --git a/types.go b/types.go
index 1181d21e..e97c8038 100644
--- a/types.go
+++ b/types.go
@@ -123,17 +123,17 @@ type ChannelState interface {
 // LastVoucherResult returns the last voucher result sent on the channel
 LastVoucherResult() TypedVoucher
 
- // ReceivedCidsTotal returns the number of (non-unique) cids received so far
- // on the channel - note that a block can exist in more than one place in the DAG
- ReceivedCidsTotal() int64
+ // ReceivedIndex returns the index, a transport specific identifier for "where"
+ // we are in receiving data for a transfer
+ ReceivedIndex() datamodel.Node
 
- // QueuedCidsTotal returns the number of (non-unique) cids queued so far
- // on the channel - note that a block can exist in more than one place in the DAG
- QueuedCidsTotal() int64
+ // QueuedIndex returns the index, a transport specific identifier for "where"
+ // we are in queuing data for a transfer
+ QueuedIndex() datamodel.Node
 
- // SentCidsTotal returns the number of (non-unique) cids sent so far
- // on the channel - note that a block can exist in more than one place in the DAG
- SentCidsTotal() int64
+ // SentIndex returns the index, a transport specific identifier for "where"
+ // we are in sending data for a transfer
+ 
SentIndex() datamodel.Node // Queued returns the number of bytes read from the node and queued for sending Queued() uint64 @@ -146,6 +146,18 @@ type ChannelState interface { // be left open for a final settlement RequiresFinalization() bool + // InitiatorPaused indicates whether the initiator of this channel is in a paused state + InitiatorPaused() bool + + // ResponderPaused indicates whether the responder of this channel is in a paused state + ResponderPaused() bool + + // BothPaused indicates both sides of the transfer have paused the transfer + BothPaused() bool + + // SelfPaused indicates whether the local peer for this channel is in a paused state + SelfPaused() bool + // Stages returns the timeline of events this data transfer has gone through, // for observability purposes. //
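A minimal consumer-side sketch, assuming only the ChannelState pause accessors declared in the types.go hunk above; the package name and the pauseSummary helper are illustrative, not part of this changeset:

package example

import (
	datatransfer "github.com/filecoin-project/go-data-transfer/v2"
)

// pauseSummary reports a human-readable pause state for a channel, branching
// on the pause accessors the ChannelState interface now exposes.
func pauseSummary(chst datatransfer.ChannelState) string {
	switch {
	case chst.BothPaused():
		return "both sides paused"
	case chst.InitiatorPaused():
		return "initiator paused"
	case chst.ResponderPaused():
		return "responder paused"
	default:
		return "transferring"
	}
}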