diff --git a/swarm/network/kademlia.go b/swarm/network/kademlia.go index cd94741be329..5fda51e3ebe1 100644 --- a/swarm/network/kademlia.go +++ b/swarm/network/kademlia.go @@ -81,14 +81,15 @@ func NewKadParams() *KadParams { // Kademlia is a table of live peers and a db of known peers (node records) type Kademlia struct { lock sync.RWMutex - *KadParams // Kademlia configuration parameters - base []byte // immutable baseaddress of the table - addrs *pot.Pot // pots container for known peer addresses - conns *pot.Pot // pots container for live peer connections - depth uint8 // stores the last current depth of saturation - nDepth int // stores the last neighbourhood depth - nDepthC chan int // returned by DepthC function to signal neighbourhood depth change - addrCountC chan int // returned by AddrCountC function to signal peer count change + *KadParams // Kademlia configuration parameters + base []byte // immutable baseaddress of the table + addrs *pot.Pot // pots container for known peer addresses + conns *pot.Pot // pots container for live peer connections + depth uint8 // stores the last current depth of saturation + nDepth int // stores the last neighbourhood depth + nDepthC chan int // returned by DepthC function to signal neighbourhood depth change + addrCountC chan int // returned by AddrCountC function to signal peer count change + Pof func(pot.Val, pot.Val, int) (int, bool) // function for calculating kademlia routing distance between two addresses } // NewKademlia creates a Kademlia table for base address addr @@ -103,6 +104,7 @@ func NewKademlia(addr []byte, params *KadParams) *Kademlia { KadParams: params, addrs: pot.NewPot(nil, 0), conns: pot.NewPot(nil, 0), + Pof: pof, } } @@ -289,6 +291,7 @@ func (k *Kademlia) On(p *Peer) (uint8, bool) { // neighbourhood depth on each change. // Not receiving from the returned channel will block On function // when the neighbourhood depth is changed. +// TODO: Why is this exported, and if it should be; why can't we have more subscribers than one? func (k *Kademlia) NeighbourhoodDepthC() <-chan int { k.lock.Lock() defer k.lock.Unlock() @@ -429,7 +432,12 @@ func (k *Kademlia) eachAddr(base []byte, o int, f func(*BzzAddr, int, bool) bool // neighbourhoodDepth returns the proximity order that defines the distance of // the nearest neighbour set with cardinality >= MinProxBinSize // if there is altogether less than MinProxBinSize peers it returns 0 -// caller must hold the lock +func (k *Kademlia) NeighbourhoodDepth() (depth int) { + k.lock.RLock() + defer k.lock.RUnlock() + return k.neighbourhoodDepth() +} + func (k *Kademlia) neighbourhoodDepth() (depth int) { if k.conns.Size() < k.MinProxBinSize { return 0 diff --git a/swarm/pss/api.go b/swarm/pss/api.go index eba7bb722c0c..dd55b2a702a7 100644 --- a/swarm/pss/api.go +++ b/swarm/pss/api.go @@ -51,7 +51,7 @@ func NewAPI(ps *Pss) *API { // // All incoming messages to the node matching this topic will be encapsulated in the APIMsg // struct and sent to the subscriber -func (pssapi *API) Receive(ctx context.Context, topic Topic) (*rpc.Subscription, error) { +func (pssapi *API) Receive(ctx context.Context, topic Topic, raw bool, prox bool) (*rpc.Subscription, error) { notifier, supported := rpc.NotifierFromContext(ctx) if !supported { return nil, fmt.Errorf("Subscribe not supported") @@ -59,7 +59,7 @@ func (pssapi *API) Receive(ctx context.Context, topic Topic) (*rpc.Subscription, psssub := notifier.CreateSubscription() - handler := func(msg []byte, p *p2p.Peer, asymmetric bool, keyid string) error { + hndlr := NewHandler(func(msg []byte, p *p2p.Peer, asymmetric bool, keyid string) error { apimsg := &APIMsg{ Msg: hexutil.Bytes(msg), Asymmetric: asymmetric, @@ -69,9 +69,15 @@ func (pssapi *API) Receive(ctx context.Context, topic Topic) (*rpc.Subscription, log.Warn(fmt.Sprintf("notification on pss sub topic rpc (sub %v) msg %v failed!", psssub.ID, msg)) } return nil + }) + if raw { + hndlr.caps.raw = true + } + if prox { + hndlr.caps.prox = true } - deregf := pssapi.Register(&topic, handler) + deregf := pssapi.Register(&topic, hndlr) go func() { defer deregf() select { diff --git a/swarm/pss/client/client.go b/swarm/pss/client/client.go index d541081d3ed2..5ee387aa7918 100644 --- a/swarm/pss/client/client.go +++ b/swarm/pss/client/client.go @@ -236,7 +236,7 @@ func (c *Client) RunProtocol(ctx context.Context, proto *p2p.Protocol) error { topichex := topicobj.String() msgC := make(chan pss.APIMsg) c.peerPool[topicobj] = make(map[string]*pssRPCRW) - sub, err := c.rpc.Subscribe(ctx, "pss", msgC, "receive", topichex) + sub, err := c.rpc.Subscribe(ctx, "pss", msgC, "receive", topichex, false, false) if err != nil { return fmt.Errorf("pss event subscription failed: %v", err) } diff --git a/swarm/pss/handshake.go b/swarm/pss/handshake.go index e3ead77d0492..5486abafa9fb 100644 --- a/swarm/pss/handshake.go +++ b/swarm/pss/handshake.go @@ -486,7 +486,7 @@ func (api *HandshakeAPI) Handshake(pubkeyid string, topic Topic, sync bool, flus // Activate handshake functionality on a topic func (api *HandshakeAPI) AddHandshake(topic Topic) error { - api.ctrl.deregisterFuncs[topic] = api.ctrl.pss.Register(&topic, api.ctrl.handler) + api.ctrl.deregisterFuncs[topic] = api.ctrl.pss.Register(&topic, NewHandler(api.ctrl.handler)) return nil } diff --git a/swarm/pss/notify/notify.go b/swarm/pss/notify/notify.go index 3731fb9dbbbb..d3c89058b2dc 100644 --- a/swarm/pss/notify/notify.go +++ b/swarm/pss/notify/notify.go @@ -113,7 +113,7 @@ func NewController(ps *pss.Pss) *Controller { notifiers: make(map[string]*notifier), subscriptions: make(map[string]*subscription), } - ctrl.pss.Register(&controlTopic, ctrl.Handler) + ctrl.pss.Register(&controlTopic, pss.NewHandler(ctrl.Handler)) return ctrl } @@ -336,7 +336,7 @@ func (c *Controller) handleNotifyWithKeyMsg(msg *Msg) error { // \TODO keep track of and add actual address updaterAddr := pss.PssAddress([]byte{}) c.pss.SetSymmetricKey(symkey, topic, &updaterAddr, true) - c.pss.Register(&topic, c.Handler) + c.pss.Register(&topic, pss.NewHandler(c.Handler)) return c.subscriptions[msg.namestring].handler(msg.namestring, msg.Payload[:len(msg.Payload)-symKeyLength]) } diff --git a/swarm/pss/notify/notify_test.go b/swarm/pss/notify/notify_test.go index d4d383a6ba2c..6100195b09e4 100644 --- a/swarm/pss/notify/notify_test.go +++ b/swarm/pss/notify/notify_test.go @@ -121,7 +121,7 @@ func TestStart(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), time.Second*2) defer cancel() rmsgC := make(chan *pss.APIMsg) - rightSub, err := rightRpc.Subscribe(ctx, "pss", rmsgC, "receive", controlTopic) + rightSub, err := rightRpc.Subscribe(ctx, "pss", rmsgC, "receive", controlTopic, false, false) if err != nil { t.Fatal(err) } @@ -174,7 +174,7 @@ func TestStart(t *testing.T) { t.Fatalf("expected payload length %d, have %d", len(updateMsg)+symKeyLength, len(dMsg.Payload)) } - rightSubUpdate, err := rightRpc.Subscribe(ctx, "pss", rmsgC, "receive", rsrcTopic) + rightSubUpdate, err := rightRpc.Subscribe(ctx, "pss", rmsgC, "receive", rsrcTopic, false, false) if err != nil { t.Fatal(err) } diff --git a/swarm/pss/protocol_test.go b/swarm/pss/protocol_test.go index 4ef3e90a04a1..520c48a2024c 100644 --- a/swarm/pss/protocol_test.go +++ b/swarm/pss/protocol_test.go @@ -92,7 +92,7 @@ func testProtocol(t *testing.T) { lmsgC := make(chan APIMsg) lctx, cancel := context.WithTimeout(context.Background(), time.Second*10) defer cancel() - lsub, err := clients[0].Subscribe(lctx, "pss", lmsgC, "receive", topic) + lsub, err := clients[0].Subscribe(lctx, "pss", lmsgC, "receive", topic, false, false) if err != nil { t.Fatal(err) } @@ -100,7 +100,7 @@ func testProtocol(t *testing.T) { rmsgC := make(chan APIMsg) rctx, cancel := context.WithTimeout(context.Background(), time.Second*10) defer cancel() - rsub, err := clients[1].Subscribe(rctx, "pss", rmsgC, "receive", topic) + rsub, err := clients[1].Subscribe(rctx, "pss", rmsgC, "receive", topic, false, false) if err != nil { t.Fatal(err) } @@ -130,6 +130,7 @@ func testProtocol(t *testing.T) { log.Debug("lnode ok") case cerr := <-lctx.Done(): t.Fatalf("test message timed out: %v", cerr) + return } select { case <-rmsgC: diff --git a/swarm/pss/pss.go b/swarm/pss/pss.go index e1e24e1f543a..d0986d280b59 100644 --- a/swarm/pss/pss.go +++ b/swarm/pss/pss.go @@ -23,11 +23,13 @@ import ( "crypto/rand" "errors" "fmt" + "hash" "sync" "time" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/crypto" + "github.com/ethereum/go-ethereum/crypto/sha3" "github.com/ethereum/go-ethereum/metrics" "github.com/ethereum/go-ethereum/p2p" "github.com/ethereum/go-ethereum/p2p/enode" @@ -136,10 +138,10 @@ type Pss struct { symKeyDecryptCacheCapacity int // max amount of symkeys to keep. // message handling - handlers map[Topic]map[*Handler]bool // topic and version based pss payload handlers. See pss.Handle() - handlersMu sync.RWMutex - allowRaw bool - hashPool sync.Pool + handlers map[Topic]map[*handler]bool // topic and version based pss payload handlers. See pss.Handle() + handlersMu sync.RWMutex + hashPool sync.Pool + topicHandlerCaps map[Topic]*handlerCaps // caches capabilities of each topic's handlers (see handlerCap* consts in types.go) // process quitC chan struct{} @@ -180,11 +182,12 @@ func NewPss(k *network.Kademlia, params *PssParams) (*Pss, error) { symKeyDecryptCache: make([]*string, params.SymKeyCacheCapacity), symKeyDecryptCacheCapacity: params.SymKeyCacheCapacity, - handlers: make(map[Topic]map[*Handler]bool), - allowRaw: params.AllowRaw, + handlers: make(map[Topic]map[*handler]bool), + topicHandlerCaps: make(map[Topic]*handlerCaps), + hashPool: sync.Pool{ New: func() interface{} { - return storage.MakeHashFunc(storage.DefaultHash)() + return sha3.NewKeccak256() }, }, } @@ -313,30 +316,54 @@ func (p *Pss) PublicKey() *ecdsa.PublicKey { // // Returns a deregister function which needs to be called to // deregister the handler, -func (p *Pss) Register(topic *Topic, handler Handler) func() { +func (p *Pss) Register(topic *Topic, hndlr *handler) func() { p.handlersMu.Lock() defer p.handlersMu.Unlock() handlers := p.handlers[*topic] if handlers == nil { - handlers = make(map[*Handler]bool) + handlers = make(map[*handler]bool) p.handlers[*topic] = handlers + log.Debug("registered handler", "caps", hndlr.caps) + } + if hndlr.caps == nil { + hndlr.caps = &handlerCaps{} + } + handlers[hndlr] = true + if _, ok := p.topicHandlerCaps[*topic]; !ok { + p.topicHandlerCaps[*topic] = &handlerCaps{} } - handlers[&handler] = true - return func() { p.deregister(topic, &handler) } + if hndlr.caps.raw { + p.topicHandlerCaps[*topic].raw = true + } + if hndlr.caps.prox { + p.topicHandlerCaps[*topic].prox = true + } + return func() { p.deregister(topic, hndlr) } } -func (p *Pss) deregister(topic *Topic, h *Handler) { +func (p *Pss) deregister(topic *Topic, hndlr *handler) { p.handlersMu.Lock() defer p.handlersMu.Unlock() handlers := p.handlers[*topic] - if len(handlers) == 1 { + if len(handlers) > 1 { delete(p.handlers, *topic) + // topic caps might have changed now that a handler is gone + caps := &handlerCaps{} + for h := range handlers { + if h.caps.raw { + caps.raw = true + } + if h.caps.prox { + caps.prox = true + } + } + p.topicHandlerCaps[*topic] = caps return } - delete(handlers, h) + delete(handlers, hndlr) } // get all registered handlers for respective topics -func (p *Pss) getHandlers(topic Topic) map[*Handler]bool { +func (p *Pss) getHandlers(topic Topic) map[*handler]bool { p.handlersMu.RLock() defer p.handlersMu.RUnlock() return p.handlers[topic] @@ -348,12 +375,11 @@ func (p *Pss) getHandlers(topic Topic) map[*Handler]bool { // Only passes error to pss protocol handler if payload is not valid pssmsg func (p *Pss) handlePssMsg(ctx context.Context, msg interface{}) error { metrics.GetOrRegisterCounter("pss.handlepssmsg", nil).Inc(1) - pssmsg, ok := msg.(*PssMsg) - if !ok { return fmt.Errorf("invalid message type. Expected *PssMsg, got %T ", msg) } + log.Trace("handler", "self", label(p.Kademlia.BaseAddr()), "topic", label(pssmsg.Payload.Topic[:])) if int64(pssmsg.Expire) < time.Now().Unix() { metrics.GetOrRegisterCounter("pss.expire", nil).Inc(1) log.Warn("pss filtered expired message", "from", common.ToHex(p.Kademlia.BaseAddr()), "to", common.ToHex(pssmsg.To)) @@ -365,13 +391,34 @@ func (p *Pss) handlePssMsg(ctx context.Context, msg interface{}) error { } p.addFwdCache(pssmsg) - if !p.isSelfPossibleRecipient(pssmsg) { - log.Trace("pss was for someone else :'( ... forwarding", "pss", common.ToHex(p.BaseAddr())) + psstopic := Topic(pssmsg.Payload.Topic) + + // raw is simplest handler contingency to check, so check that first + var isRaw bool + if pssmsg.isRaw() { + if !p.topicHandlerCaps[psstopic].raw { + log.Debug("No handler for raw message", "topic", psstopic) + return nil + } + isRaw = true + } + + // check if we can be recipient: + // - no prox handler on message and partial address matches + // - prox handler on message and we are in prox regardless of partial address match + // store this result so we don't calculate again on every handler + var isProx bool + if _, ok := p.topicHandlerCaps[psstopic]; ok { + isProx = p.topicHandlerCaps[psstopic].prox + } + isRecipient := p.isSelfPossibleRecipient(pssmsg, isProx) + if !isRecipient { + log.Trace("pss was for someone else :'( ... forwarding", "pss", common.ToHex(p.BaseAddr()), "prox", isProx) return p.enqueue(pssmsg) } - log.Trace("pss for us, yay! ... let's process!", "pss", common.ToHex(p.BaseAddr())) - if err := p.process(pssmsg); err != nil { + log.Trace("pss for us, yay! ... let's process!", "pss", common.ToHex(p.BaseAddr()), "prox", isProx, "raw", isRaw, "topic", label(pssmsg.Payload.Topic[:])) + if err := p.process(pssmsg, isRaw, isProx); err != nil { qerr := p.enqueue(pssmsg) if qerr != nil { return fmt.Errorf("process fail: processerr %v, queueerr: %v", err, qerr) @@ -384,7 +431,7 @@ func (p *Pss) handlePssMsg(ctx context.Context, msg interface{}) error { // Entry point to processing a message for which the current node can be the intended recipient. // Attempts symmetric and asymmetric decryption with stored keys. // Dispatches message to all handlers matching the message topic -func (p *Pss) process(pssmsg *PssMsg) error { +func (p *Pss) process(pssmsg *PssMsg, raw bool, prox bool) error { metrics.GetOrRegisterCounter("pss.process", nil).Inc(1) var err error @@ -397,10 +444,8 @@ func (p *Pss) process(pssmsg *PssMsg) error { envelope := pssmsg.Payload psstopic := Topic(envelope.Topic) - if pssmsg.isRaw() { - if !p.allowRaw { - return errors.New("raw message support disabled") - } + + if raw { payload = pssmsg.Payload.Data } else { if pssmsg.isSym() { @@ -422,19 +467,27 @@ func (p *Pss) process(pssmsg *PssMsg) error { return err } } - p.executeHandlers(psstopic, payload, from, asymmetric, keyid) + p.executeHandlers(psstopic, payload, from, raw, prox, asymmetric, keyid) return nil } -func (p *Pss) executeHandlers(topic Topic, payload []byte, from *PssAddress, asymmetric bool, keyid string) { +func (p *Pss) executeHandlers(topic Topic, payload []byte, from *PssAddress, raw bool, prox bool, asymmetric bool, keyid string) { handlers := p.getHandlers(topic) peer := p2p.NewPeer(enode.ID{}, fmt.Sprintf("%x", from), []p2p.Cap{}) - for f := range handlers { - err := (*f)(payload, peer, asymmetric, keyid) + for h := range handlers { + if !h.caps.raw && raw { + log.Warn("norawhandler") + continue + } + if !h.caps.prox && prox { + log.Warn("noproxhandler") + continue + } + err := (h.f)(payload, peer, asymmetric, keyid) if err != nil { - log.Warn("Pss handler %p failed: %v", f, err) + log.Warn("Pss handler failed", "err", err) } } } @@ -445,9 +498,23 @@ func (p *Pss) isSelfRecipient(msg *PssMsg) bool { } // test match of leftmost bytes in given message to node's Kademlia address -func (p *Pss) isSelfPossibleRecipient(msg *PssMsg) bool { +func (p *Pss) isSelfPossibleRecipient(msg *PssMsg, prox bool) bool { local := p.Kademlia.BaseAddr() - return bytes.Equal(msg.To, local[:len(msg.To)]) + + // if a partial address matches we are possible recipient regardless of prox + // if not and prox is not set, we are surely not + if bytes.Equal(msg.To, local[:len(msg.To)]) { + + return true + } else if !prox { + return false + } + + depth := p.Kademlia.NeighbourhoodDepth() + po, _ := p.Kademlia.Pof(p.Kademlia.BaseAddr(), msg.To, 0) + log.Trace("selfpossible", "po", po, "depth", depth) + + return depth <= po } ///////////////////////////////////////////////////////////////////// @@ -684,9 +751,6 @@ func (p *Pss) enqueue(msg *PssMsg) error { // // Will fail if raw messages are disallowed func (p *Pss) SendRaw(address PssAddress, topic Topic, msg []byte) error { - if !p.allowRaw { - return errors.New("Raw messages not enabled") - } pssMsgParams := &msgParams{ raw: true, } @@ -699,7 +763,17 @@ func (p *Pss) SendRaw(address PssAddress, topic Topic, msg []byte) error { pssMsg.Expire = uint32(time.Now().Add(p.msgTTL).Unix()) pssMsg.Payload = payload p.addFwdCache(pssMsg) - return p.enqueue(pssMsg) + err := p.enqueue(pssMsg) + if err != nil { + return err + } + + // if we have a proxhandler on this topic + // also deliver message to ourselves + if p.isSelfPossibleRecipient(pssMsg, true) && p.topicHandlerCaps[topic].prox { + return p.process(pssMsg, true, true) + } + return nil } // Send a message using symmetric encryption @@ -800,7 +874,16 @@ func (p *Pss) send(to []byte, topic Topic, msg []byte, asymmetric bool, key []by pssMsg.To = to pssMsg.Expire = uint32(time.Now().Add(p.msgTTL).Unix()) pssMsg.Payload = envelope - return p.enqueue(pssMsg) + err = p.enqueue(pssMsg) + if err != nil { + return err + } + if _, ok := p.topicHandlerCaps[topic]; ok { + if p.isSelfPossibleRecipient(pssMsg, true) && p.topicHandlerCaps[topic].prox { + return p.process(pssMsg, true, true) + } + } + return nil } // Forwards a pss message to the peer(s) closest to the to recipient address in the PssMsg struct @@ -895,6 +978,10 @@ func (p *Pss) cleanFwdCache() { } } +func label(b []byte) string { + return fmt.Sprintf("%04x", b[:2]) +} + // add a message to the cache func (p *Pss) addFwdCache(msg *PssMsg) error { metrics.GetOrRegisterCounter("pss.addfwdcache", nil).Inc(1) @@ -934,10 +1021,14 @@ func (p *Pss) checkFwdCache(msg *PssMsg) bool { // Digest of message func (p *Pss) digest(msg *PssMsg) pssDigest { - hasher := p.hashPool.Get().(storage.SwarmHash) + return p.digestBytes(msg.serialize()) +} + +func (p *Pss) digestBytes(msg []byte) pssDigest { + hasher := p.hashPool.Get().(hash.Hash) defer p.hashPool.Put(hasher) hasher.Reset() - hasher.Write(msg.serialize()) + hasher.Write(msg) digest := pssDigest{} key := hasher.Sum(nil) copy(digest[:], key[:digestLength]) diff --git a/swarm/pss/pss_test.go b/swarm/pss/pss_test.go index 66a90be6207a..32404aaaf96d 100644 --- a/swarm/pss/pss_test.go +++ b/swarm/pss/pss_test.go @@ -48,20 +48,23 @@ import ( "github.com/ethereum/go-ethereum/p2p/simulations/adapters" "github.com/ethereum/go-ethereum/rpc" "github.com/ethereum/go-ethereum/swarm/network" + "github.com/ethereum/go-ethereum/swarm/pot" "github.com/ethereum/go-ethereum/swarm/state" whisper "github.com/ethereum/go-ethereum/whisper/whisperv5" ) var ( - initOnce = sync.Once{} - debugdebugflag = flag.Bool("vv", false, "veryverbose") - debugflag = flag.Bool("v", false, "verbose") - longrunning = flag.Bool("longrunning", false, "do run long-running tests") - w *whisper.Whisper - wapi *whisper.PublicWhisperAPI - psslogmain log.Logger - pssprotocols map[string]*protoCtrl - useHandshake bool + initOnce = sync.Once{} + loglevel = flag.Int("loglevel", 2, "logging verbosity") + longrunning = flag.Bool("longrunning", false, "do run long-running tests") + w *whisper.Whisper + wapi *whisper.PublicWhisperAPI + psslogmain log.Logger + pssprotocols map[string]*protoCtrl + useHandshake bool + noopHandlerFunc = func(msg []byte, p *p2p.Peer, asymmetric bool, keyid string) error { + return nil + } ) func init() { @@ -75,16 +78,9 @@ func init() { func initTest() { initOnce.Do( func() { - loglevel := log.LvlInfo - if *debugflag { - loglevel = log.LvlDebug - } else if *debugdebugflag { - loglevel = log.LvlTrace - } - psslogmain = log.New("psslog", "*") hs := log.StreamHandler(os.Stderr, log.TerminalFormat(true)) - hf := log.LvlFilterHandler(loglevel, hs) + hf := log.LvlFilterHandler(log.Lvl(*loglevel), hs) h := log.CallerFileHandler(hf) log.Root().SetHandler(h) @@ -280,15 +276,14 @@ func TestAddressMatch(t *testing.T) { } pssmsg := &PssMsg{ - To: remoteaddr, - Payload: &whisper.Envelope{}, + To: remoteaddr, } // differ from first byte if ps.isSelfRecipient(pssmsg) { t.Fatalf("isSelfRecipient true but %x != %x", remoteaddr, localaddr) } - if ps.isSelfPossibleRecipient(pssmsg) { + if ps.isSelfPossibleRecipient(pssmsg, false) { t.Fatalf("isSelfPossibleRecipient true but %x != %x", remoteaddr[:8], localaddr[:8]) } @@ -297,7 +292,7 @@ func TestAddressMatch(t *testing.T) { if ps.isSelfRecipient(pssmsg) { t.Fatalf("isSelfRecipient true but %x != %x", remoteaddr, localaddr) } - if !ps.isSelfPossibleRecipient(pssmsg) { + if !ps.isSelfPossibleRecipient(pssmsg, false) { t.Fatalf("isSelfPossibleRecipient false but %x == %x", remoteaddr[:8], localaddr[:8]) } @@ -306,13 +301,342 @@ func TestAddressMatch(t *testing.T) { if !ps.isSelfRecipient(pssmsg) { t.Fatalf("isSelfRecipient false but %x == %x", remoteaddr, localaddr) } - if !ps.isSelfPossibleRecipient(pssmsg) { + if !ps.isSelfPossibleRecipient(pssmsg, false) { t.Fatalf("isSelfPossibleRecipient false but %x == %x", remoteaddr[:8], localaddr[:8]) } + } -// -func TestHandlerConditions(t *testing.T) { +// test that message is handled by sender if a prox handler exists and sender is in prox of message +func TestProxShortCircuit(t *testing.T) { + + // sender node address + localAddr := network.RandomAddr().Over() + localPotAddr := pot.NewAddressFromBytes(localAddr) + + // set up kademlia + kadParams := network.NewKadParams() + kad := network.NewKademlia(localAddr, kadParams) + peerCount := kad.MinBinSize + 1 + + // set up pss + privKey, err := crypto.GenerateKey() + pssp := NewPssParams().WithPrivateKey(privKey) + ps, err := NewPss(kad, pssp) + if err != nil { + t.Fatal(err.Error()) + } + + // create kademlia peers, so we have peers both inside and outside minproxlimit + var peers []*network.Peer + proxMessageAddress := pot.RandomAddressAt(localPotAddr, peerCount).Bytes() + distantMessageAddress := pot.RandomAddressAt(localPotAddr, 0).Bytes() + + for i := 0; i < peerCount; i++ { + rw := &p2p.MsgPipeRW{} + ptpPeer := p2p.NewPeer(enode.ID{}, "wanna be with me? [ ] yes [ ] no", []p2p.Cap{}) + protoPeer := protocols.NewPeer(ptpPeer, rw, &protocols.Spec{}) + peerAddr := pot.RandomAddressAt(localPotAddr, i) + bzzPeer := &network.BzzPeer{ + Peer: protoPeer, + BzzAddr: &network.BzzAddr{ + OAddr: peerAddr.Bytes(), + UAddr: []byte(fmt.Sprintf("%x", peerAddr[:])), + }, + } + peer := network.NewPeer(bzzPeer, kad) + kad.On(peer) + peers = append(peers, peer) + } + + // register it marking prox capability + delivered := make(chan struct{}) + rawHandlerFunc := func(msg []byte, p *p2p.Peer, asymmetric bool, keyid string) error { + log.Trace("in allowraw handler") + delivered <- struct{}{} + return nil + } + topic := BytesToTopic([]byte{0x2a}) + hndlrProxDereg := ps.Register(&topic, &handler{ + f: rawHandlerFunc, + caps: &handlerCaps{ + raw: true, + prox: true, + }, + }) + defer hndlrProxDereg() + + // send message too far away for sender to be in prox + // reception of this message should time out + errC := make(chan error) + go func() { + err := ps.SendRaw(distantMessageAddress, topic, []byte("foo")) + if err != nil { + errC <- err + } + }() + + ctx, cancel := context.WithTimeout(context.TODO(), time.Second) + defer cancel() + select { + case <-delivered: + t.Fatal("raw distant message delivered") + case err := <-errC: + t.Fatal(err) + case <-ctx.Done(): + } + + // send message that should be within sender prox + // this message should be delivered + go func() { + err := ps.SendRaw(proxMessageAddress, topic, []byte("bar")) + if err != nil { + errC <- err + } + }() + + ctx, cancel = context.WithTimeout(context.TODO(), time.Second) + defer cancel() + select { + case <-delivered: + case err := <-errC: + t.Fatal(err) + case <-ctx.Done(): + t.Fatal("raw timeout") + } + + // try the same prox message with sym and asym send + proxAddrPss := PssAddress(proxMessageAddress) + symKeyId, err := ps.GenerateSymmetricKey(topic, &proxAddrPss, true) + go func() { + err := ps.SendSym(symKeyId, topic, []byte("baz")) + if err != nil { + errC <- err + } + }() + ctx, cancel = context.WithTimeout(context.TODO(), time.Second) + defer cancel() + select { + case <-delivered: + case err := <-errC: + t.Fatal(err) + case <-ctx.Done(): + t.Fatal("sym timeout") + } + + err = ps.SetPeerPublicKey(&privKey.PublicKey, topic, &proxAddrPss) + if err != nil { + t.Fatal(err) + } + pubKeyId := hexutil.Encode(crypto.FromECDSAPub(&privKey.PublicKey)) + go func() { + err := ps.SendAsym(pubKeyId, topic, []byte("xyzzy")) + if err != nil { + errC <- err + } + }() + ctx, cancel = context.WithTimeout(context.TODO(), time.Second) + defer cancel() + select { + case <-delivered: + case err := <-errC: + t.Fatal(err) + case <-ctx.Done(): + t.Fatal("asym timeout") + } +} + +// verify that node can be set as recipient regardless of explicit message address match if minimum one handler of a topic is explicitly set to allow it +// note that in these tests we use the raw capability on handlers for convenience +func TestAddressMatchProx(t *testing.T) { + + // recipient node address + localAddr := network.RandomAddr().Over() + localPotAddr := pot.NewAddressFromBytes(localAddr) + + // set up kademlia + kadparams := network.NewKadParams() + kad := network.NewKademlia(localAddr, kadparams) + nnPeerCount := kad.MinBinSize + peerCount := nnPeerCount + 2 + + // set up pss + privKey, err := crypto.GenerateKey() + pssp := NewPssParams().WithPrivateKey(privKey) + ps, err := NewPss(kad, pssp) + if err != nil { + t.Fatal(err.Error()) + } + + // create kademlia peers, so we have peers both inside and outside minproxlimit + var peers []*network.Peer + for i := 0; i < peerCount; i++ { + rw := &p2p.MsgPipeRW{} + ptpPeer := p2p.NewPeer(enode.ID{}, "362436 call me anytime", []p2p.Cap{}) + protoPeer := protocols.NewPeer(ptpPeer, rw, &protocols.Spec{}) + peerAddr := pot.RandomAddressAt(localPotAddr, i) + bzzPeer := &network.BzzPeer{ + Peer: protoPeer, + BzzAddr: &network.BzzAddr{ + OAddr: peerAddr.Bytes(), + UAddr: []byte(fmt.Sprintf("%x", peerAddr[:])), + }, + } + peer := network.NewPeer(bzzPeer, kad) + kad.On(peer) + peers = append(peers, peer) + } + + // TODO: create a test in the network package to make a table with n peers where n-m are proxpeers + // meanwhile test regression for kademlia since we are compiling the test parameters from different packages + var proxes int + var conns int + kad.EachConn(nil, peerCount, func(p *network.Peer, po int, prox bool) bool { + conns++ + if prox { + proxes++ + } + log.Trace("kadconn", "po", po, "peer", p, "prox", prox) + return true + }) + if proxes != nnPeerCount { + t.Fatalf("expected %d proxpeers, have %d", nnPeerCount, proxes) + } else if conns != peerCount { + t.Fatalf("expected %d peers total, have %d", peerCount, proxes) + } + + // remote address distances from localAddr to try and the expected outcomes if we use prox handler + remoteDistances := []int{ + 255, + nnPeerCount + 1, + nnPeerCount, + nnPeerCount - 1, + 0, + } + expects := []bool{ + true, + true, + true, + false, + false, + } + + // first the unit test on the method that calculates possible receipient using prox + for i, distance := range remoteDistances { + pssMsg := newPssMsg(&msgParams{}) + pssMsg.To = make([]byte, len(localAddr)) + copy(pssMsg.To, localAddr) + var byteIdx = distance / 8 + pssMsg.To[byteIdx] ^= 1 << uint(7-(distance%8)) + log.Trace(fmt.Sprintf("addrmatch %v", bytes.Equal(pssMsg.To, localAddr))) + if ps.isSelfPossibleRecipient(pssMsg, true) != expects[i] { + t.Fatalf("expected distance %d to be %v", distance, expects[i]) + } + } + + // we move up to higher level and test the actual message handler + // for each distance check if we are possible recipient when prox variant is used is set + + // this handler will increment a counter for every message that gets passed to the handler + var receives int + rawHandlerFunc := func(msg []byte, p *p2p.Peer, asymmetric bool, keyid string) error { + log.Trace("in allowraw handler") + receives++ + return nil + } + + // register it marking prox capability + topic := BytesToTopic([]byte{0x2a}) + hndlrProxDereg := ps.Register(&topic, &handler{ + f: rawHandlerFunc, + caps: &handlerCaps{ + raw: true, + prox: true, + }, + }) + + // test the distances + var prevReceive int + for i, distance := range remoteDistances { + remotePotAddr := pot.RandomAddressAt(localPotAddr, distance) + remoteAddr := remotePotAddr.Bytes() + + var data [32]byte + rand.Read(data[:]) + pssMsg := newPssMsg(&msgParams{raw: true}) + pssMsg.To = remoteAddr + pssMsg.Expire = uint32(time.Now().Unix() + 4200) + pssMsg.Payload = &whisper.Envelope{ + Topic: whisper.TopicType(topic), + Data: data[:], + } + + log.Trace("withprox addrs", "local", localAddr, "remote", remoteAddr) + ps.handlePssMsg(context.TODO(), pssMsg) + if (!expects[i] && prevReceive != receives) || (expects[i] && prevReceive == receives) { + t.Fatalf("expected distance %d recipient %v when prox is set for handler", distance, expects[i]) + } + prevReceive = receives + } + + // now add a non prox-capable handler and test + ps.Register(&topic, &handler{ + f: rawHandlerFunc, + caps: &handlerCaps{ + raw: true, + }, + }) + receives = 0 + prevReceive = 0 + for i, distance := range remoteDistances { + remotePotAddr := pot.RandomAddressAt(localPotAddr, distance) + remoteAddr := remotePotAddr.Bytes() + + var data [32]byte + rand.Read(data[:]) + pssMsg := newPssMsg(&msgParams{raw: true}) + pssMsg.To = remoteAddr + pssMsg.Expire = uint32(time.Now().Unix() + 4200) + pssMsg.Payload = &whisper.Envelope{ + Topic: whisper.TopicType(topic), + Data: data[:], + } + + log.Trace("withprox addrs", "local", localAddr, "remote", remoteAddr) + ps.handlePssMsg(context.TODO(), pssMsg) + if (!expects[i] && prevReceive != receives) || (expects[i] && prevReceive == receives) { + t.Fatalf("expected distance %d recipient %v when prox is set for handler", distance, expects[i]) + } + prevReceive = receives + } + + // now deregister the prox capable handler, now none of the messages will be handled + hndlrProxDereg() + receives = 0 + + for _, distance := range remoteDistances { + remotePotAddr := pot.RandomAddressAt(localPotAddr, distance) + remoteAddr := remotePotAddr.Bytes() + + pssMsg := newPssMsg(&msgParams{raw: true}) + pssMsg.To = remoteAddr + pssMsg.Expire = uint32(time.Now().Unix() + 4200) + pssMsg.Payload = &whisper.Envelope{ + Topic: whisper.TopicType(topic), + Data: []byte(remotePotAddr.String()), + } + + log.Trace("noprox addrs", "local", localAddr, "remote", remoteAddr) + ps.handlePssMsg(context.TODO(), pssMsg) + if receives != 0 { + t.Fatalf("expected distance %d to not be recipient when prox is not set for handler", distance) + } + + } +} + +// verify that message queueing happens when it should, and that expired and corrupt messages are dropped +func TestMessageProcessing(t *testing.T) { t.Skip("Disabled due to probable faulty logic for outbox expectations") // setup @@ -326,13 +650,12 @@ func TestHandlerConditions(t *testing.T) { ps := newTestPss(privkey, network.NewKademlia(addr, network.NewKadParams()), NewPssParams()) // message should pass - msg := &PssMsg{ - To: addr, - Expire: uint32(time.Now().Add(time.Second * 60).Unix()), - Payload: &whisper.Envelope{ - Topic: [4]byte{}, - Data: []byte{0x66, 0x6f, 0x6f}, - }, + msg := newPssMsg(&msgParams{}) + msg.To = addr + msg.Expire = uint32(time.Now().Add(time.Second * 60).Unix()) + msg.Payload = &whisper.Envelope{ + Topic: [4]byte{}, + Data: []byte{0x66, 0x6f, 0x6f}, } if err := ps.handlePssMsg(context.TODO(), msg); err != nil { t.Fatal(err.Error()) @@ -498,6 +821,7 @@ func TestKeys(t *testing.T) { } } +// check that we can retrieve previously added public key entires per topic and peer func TestGetPublickeyEntries(t *testing.T) { privkey, err := crypto.GenerateKey() @@ -557,7 +881,7 @@ OUTER: } // forwarding should skip peers that do not have matching pss capabilities -func TestMismatch(t *testing.T) { +func TestPeerCapabilityMismatch(t *testing.T) { // create privkey for forwarder node privkey, err := crypto.GenerateKey() @@ -615,6 +939,76 @@ func TestMismatch(t *testing.T) { } +// verifies that message handlers for raw messages only are invoked when minimum one handler for the topic exists in which raw messages are explicitly allowed +func TestRawAllow(t *testing.T) { + + // set up pss like so many times before + privKey, err := crypto.GenerateKey() + if err != nil { + t.Fatal(err) + } + baseAddr := network.RandomAddr() + kad := network.NewKademlia((baseAddr).Over(), network.NewKadParams()) + ps := newTestPss(privKey, kad, nil) + topic := BytesToTopic([]byte{0x2a}) + + // create handler innards that increments every time a message hits it + var receives int + rawHandlerFunc := func(msg []byte, p *p2p.Peer, asymmetric bool, keyid string) error { + log.Trace("in allowraw handler") + receives++ + return nil + } + + // wrap this handler function with a handler without raw capability and register it + hndlrNoRaw := &handler{ + f: rawHandlerFunc, + } + ps.Register(&topic, hndlrNoRaw) + + // test it with a raw message, should be poo-poo + pssMsg := newPssMsg(&msgParams{ + raw: true, + }) + pssMsg.To = baseAddr.OAddr + pssMsg.Expire = uint32(time.Now().Unix() + 4200) + pssMsg.Payload = &whisper.Envelope{ + Topic: whisper.TopicType(topic), + } + ps.handlePssMsg(context.TODO(), pssMsg) + if receives > 0 { + t.Fatalf("Expected handler not to be executed with raw cap off") + } + + // now wrap the same handler function with raw capabilities and register it + hndlrRaw := &handler{ + f: rawHandlerFunc, + caps: &handlerCaps{ + raw: true, + }, + } + deregRawHandler := ps.Register(&topic, hndlrRaw) + + // should work now + pssMsg.Payload.Data = []byte("Raw Deal") + ps.handlePssMsg(context.TODO(), pssMsg) + if receives == 0 { + t.Fatalf("Expected handler to be executed with raw cap on") + } + + // now deregister the raw capable handler + prevReceives := receives + deregRawHandler() + + // check that raw messages fail again + pssMsg.Payload.Data = []byte("Raw Trump") + ps.handlePssMsg(context.TODO(), pssMsg) + if receives != prevReceives { + t.Fatalf("Expected handler not to be executed when raw handler is retracted") + } +} + +// verifies that nodes can send and receive raw (verbatim) messages func TestSendRaw(t *testing.T) { t.Run("32", testSendRaw) t.Run("8", testSendRaw) @@ -658,13 +1052,13 @@ func testSendRaw(t *testing.T) { lmsgC := make(chan APIMsg) lctx, lcancel := context.WithTimeout(context.Background(), time.Second*10) defer lcancel() - lsub, err := clients[0].Subscribe(lctx, "pss", lmsgC, "receive", topic) + lsub, err := clients[0].Subscribe(lctx, "pss", lmsgC, "receive", topic, true, false) log.Trace("lsub", "id", lsub) defer lsub.Unsubscribe() rmsgC := make(chan APIMsg) rctx, rcancel := context.WithTimeout(context.Background(), time.Second*10) defer rcancel() - rsub, err := clients[1].Subscribe(rctx, "pss", rmsgC, "receive", topic) + rsub, err := clients[1].Subscribe(rctx, "pss", rmsgC, "receive", topic, true, false) log.Trace("rsub", "id", rsub) defer rsub.Unsubscribe() @@ -757,13 +1151,13 @@ func testSendSym(t *testing.T) { lmsgC := make(chan APIMsg) lctx, lcancel := context.WithTimeout(context.Background(), time.Second*10) defer lcancel() - lsub, err := clients[0].Subscribe(lctx, "pss", lmsgC, "receive", topic) + lsub, err := clients[0].Subscribe(lctx, "pss", lmsgC, "receive", topic, false, false) log.Trace("lsub", "id", lsub) defer lsub.Unsubscribe() rmsgC := make(chan APIMsg) rctx, rcancel := context.WithTimeout(context.Background(), time.Second*10) defer rcancel() - rsub, err := clients[1].Subscribe(rctx, "pss", rmsgC, "receive", topic) + rsub, err := clients[1].Subscribe(rctx, "pss", rmsgC, "receive", topic, false, false) log.Trace("rsub", "id", rsub) defer rsub.Unsubscribe() @@ -872,13 +1266,13 @@ func testSendAsym(t *testing.T) { lmsgC := make(chan APIMsg) lctx, lcancel := context.WithTimeout(context.Background(), time.Second*10) defer lcancel() - lsub, err := clients[0].Subscribe(lctx, "pss", lmsgC, "receive", topic) + lsub, err := clients[0].Subscribe(lctx, "pss", lmsgC, "receive", topic, false, false) log.Trace("lsub", "id", lsub) defer lsub.Unsubscribe() rmsgC := make(chan APIMsg) rctx, rcancel := context.WithTimeout(context.Background(), time.Second*10) defer rcancel() - rsub, err := clients[1].Subscribe(rctx, "pss", rmsgC, "receive", topic) + rsub, err := clients[1].Subscribe(rctx, "pss", rmsgC, "receive", topic, false, false) log.Trace("rsub", "id", rsub) defer rsub.Unsubscribe() @@ -1037,7 +1431,7 @@ func testNetwork(t *testing.T) { msgC := make(chan APIMsg) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() - sub, err := rpcclient.Subscribe(ctx, "pss", msgC, "receive", topic) + sub, err := rpcclient.Subscribe(ctx, "pss", msgC, "receive", topic, false, false) if err != nil { t.Fatal(err) } @@ -1209,7 +1603,7 @@ func TestDeduplication(t *testing.T) { rmsgC := make(chan APIMsg) rctx, cancel := context.WithTimeout(context.Background(), time.Second*1) defer cancel() - rsub, err := clients[1].Subscribe(rctx, "pss", rmsgC, "receive", topic) + rsub, err := clients[1].Subscribe(rctx, "pss", rmsgC, "receive", topic, false, false) log.Trace("rsub", "id", rsub) defer rsub.Unsubscribe() @@ -1392,8 +1786,8 @@ func benchmarkSymkeyBruteforceChangeaddr(b *testing.B) { if err != nil { b.Fatalf("could not generate whisper envelope: %v", err) } - ps.Register(&topic, func(msg []byte, p *p2p.Peer, asymmetric bool, keyid string) error { - return nil + ps.Register(&topic, &handler{ + f: noopHandlerFunc, }) pssmsgs = append(pssmsgs, &PssMsg{ To: to, @@ -1402,7 +1796,7 @@ func benchmarkSymkeyBruteforceChangeaddr(b *testing.B) { } b.ResetTimer() for i := 0; i < b.N; i++ { - if err := ps.process(pssmsgs[len(pssmsgs)-(i%len(pssmsgs))-1]); err != nil { + if err := ps.process(pssmsgs[len(pssmsgs)-(i%len(pssmsgs))-1], false, false); err != nil { b.Fatalf("pss processing failed: %v", err) } } @@ -1476,15 +1870,15 @@ func benchmarkSymkeyBruteforceSameaddr(b *testing.B) { if err != nil { b.Fatalf("could not generate whisper envelope: %v", err) } - ps.Register(&topic, func(msg []byte, p *p2p.Peer, asymmetric bool, keyid string) error { - return nil + ps.Register(&topic, &handler{ + f: noopHandlerFunc, }) pssmsg := &PssMsg{ To: addr[len(addr)-1][:], Payload: env, } for i := 0; i < b.N; i++ { - if err := ps.process(pssmsg); err != nil { + if err := ps.process(pssmsg, false, false); err != nil { b.Fatalf("pss processing failed: %v", err) } } @@ -1581,7 +1975,12 @@ func newServices(allowRaw bool) adapters.Services { if useHandshake { SetHandshakeController(ps, NewHandshakeParams()) } - ps.Register(&PingTopic, pp.Handle) + ps.Register(&PingTopic, &handler{ + f: pp.Handle, + caps: &handlerCaps{ + raw: true, + }, + }) ps.addAPI(rpc.API{ Namespace: "psstest", Version: "0.3", diff --git a/swarm/pss/types.go b/swarm/pss/types.go index 56c2c51dc0c0..ba963067cb85 100644 --- a/swarm/pss/types.go +++ b/swarm/pss/types.go @@ -159,9 +159,39 @@ func (msg *PssMsg) String() string { } // Signature for a message handler function for a PssMsg -// // Implementations of this type are passed to Pss.Register together with a topic, -type Handler func(msg []byte, p *p2p.Peer, asymmetric bool, keyid string) error +type HandlerFunc func(msg []byte, p *p2p.Peer, asymmetric bool, keyid string) error + +type handlerCaps struct { + raw bool + prox bool +} + +// Handler defines code to be executed upon reception of content. +type handler struct { + f HandlerFunc + caps *handlerCaps +} + +// NewHandler returns a new message handler +func NewHandler(f HandlerFunc) *handler { + return &handler{ + f: f, + caps: &handlerCaps{}, + } +} + +// WithRaw is a chainable method that allows raw messages to be handled. +func (h *handler) WithRaw() *handler { + h.caps.raw = true + return h +} + +// WithProxBin is a chainable method that allows sending messages with full addresses to neighbourhoods using the kademlia depth as reference +func (h *handler) WithProxBin() *handler { + h.caps.prox = true + return h +} // the stateStore handles saving and loading PSS peers and their corresponding keys // it is currently unimplemented