diff --git a/conn.go b/conn.go
index fa26406..580705a 100644
--- a/conn.go
+++ b/conn.go
@@ -310,7 +310,18 @@ func (c *Conn) handleMessage(msg Message) error {
         return ns.events.fireEvent(ns, msg)
     }
 
-    if msg.isWait(c.IsClient()) {
+    isClient := c.IsClient()
+    if !isClient {
+        c.server.waitingMessagesMutex.RLock()
+        ch, ok := c.server.waitingMessages[msg.wait]
+        c.server.waitingMessagesMutex.RUnlock()
+        if ok {
+            ch <- msg
+            return nil
+        }
+    }
+
+    if msg.isWait(isClient) {
         c.waitingMessagesMutex.RLock()
         ch, ok := c.waitingMessages[msg.wait]
         c.waitingMessagesMutex.RUnlock()
@@ -710,10 +721,7 @@ func (c *Conn) write(b []byte, binary bool) bool {
     return true
 }
 
-// Write method sends a message to the remote side,
-// reports whether the connection is still available
-// or when this message is not allowed to be sent to the remote side.
-func (c *Conn) Write(msg Message) bool {
+func (c *Conn) canWrite(msg Message) bool {
     if c.IsClosed() {
         return false
     }
@@ -771,6 +779,17 @@ func (c *Conn) Write(msg Message) bool {
         return false
     }
 
+    return true
+}
+
+// Write method sends a message to the remote side,
+// reports whether the connection is still available
+// or when this message is not allowed to be sent to the remote side.
+func (c *Conn) Write(msg Message) bool {
+    if !c.canWrite(msg) {
+        return false
+    }
+
     msg.FromExplicit = ""
     b := serializeMessage(nil, msg)
     return c.write(b, msg.SetBinary)
diff --git a/server.go b/server.go
index 09afa5a..1cf5769 100644
--- a/server.go
+++ b/server.go
@@ -1,6 +1,7 @@
 package neffos
 
 import (
+    "context"
     "errors"
     "fmt"
     "net/http"
@@ -62,6 +63,11 @@ type Server struct {
     disconnect  chan *Conn
     actions     chan action
     broadcaster *broadcaster
+    // messages for which this server must wait
+    // for a reply from one of its own connections (see `waitMessage`)
+    // or, TODO, from the cloud (see `StackExchange.PublishAndWait`).
+    waitingMessages      map[string]chan Message
+    waitingMessagesMutex sync.RWMutex
 
     closed uint32
 
@@ -88,17 +94,18 @@ func New(upgrader Upgrader, connHandler ConnHandler) *Server {
     readTimeout, writeTimeout := getTimeouts(connHandler)
     namespaces := connHandler.GetNamespaces()
     s := &Server{
-        uuid:         uuid.Must(uuid.NewV4()).String(),
-        upgrader:     upgrader,
-        namespaces:   namespaces,
-        readTimeout:  readTimeout,
-        writeTimeout: writeTimeout,
-        connections:  make(map[*Conn]struct{}),
-        connect:      make(chan *Conn, 1),
-        disconnect:   make(chan *Conn),
-        actions:      make(chan action),
-        broadcaster:  newBroadcaster(),
-        IDGenerator:  DefaultIDGenerator,
+        uuid:            uuid.Must(uuid.NewV4()).String(),
+        upgrader:        upgrader,
+        namespaces:      namespaces,
+        readTimeout:     readTimeout,
+        writeTimeout:    writeTimeout,
+        connections:     make(map[*Conn]struct{}),
+        connect:         make(chan *Conn, 1),
+        disconnect:      make(chan *Conn),
+        actions:         make(chan action),
+        broadcaster:     newBroadcaster(),
+        waitingMessages: make(map[string]chan Message),
+        IDGenerator:     DefaultIDGenerator,
     }
 
     // s.broadcastCond = sync.NewCond(&s.broadcastMu)
@@ -301,13 +308,12 @@ func (s *Server) Upgrade(
         c.ReconnectTries, _ = strconv.Atoi(retriesHeaderValue)
     }
 
-    if !s.usesStackExchange() {
-        // fire neffos broadcaster when no exchangers are registered.
-        go func(c *Conn) {
-            for s.waitMessage(c) {
-            }
-        }(c)
-    }
+    // TODO: uncomment when ask on cloud is implemented:
+    // if !s.usesStackExchange() {
+    go func(c *Conn) {
+        for s.waitMessage(c) {
+        }
+    }(c)
 
     s.connect <- c
 
@@ -480,6 +486,49 @@ func (s *Server) Broadcast(exceptSender fmt.Stringer, msg Message) {
     s.broadcaster.broadcast(msg)
 }
 
+// Ask is like `Broadcast` but it blocks until it receives a response,
+// either from the specific connection that "msg.To" points to, if filled,
+// or from the first connection which replies to this "msg".
+//
+// It accepts a context for its deadline as the first input argument.
+// The second argument is the request message,
+// which should target a specific namespace:event,
+// like `Conn.Ask`.
+// Note: currently this expects the remote responder
+// to be connected to this neffos server instance;
+// StackExchange does not handle this feature yet.
+func (s *Server) Ask(ctx context.Context, msg Message) (Message, error) {
+    msg.wait = genWait(false)
+
+    if ctx == nil {
+        ctx = context.TODO()
+    } else {
+        if deadline, has := ctx.Deadline(); has {
+            if deadline.Before(time.Now().Add(-1 * time.Second)) {
+                return Message{}, context.DeadlineExceeded
+            }
+        }
+    }
+
+    ch := make(chan Message)
+    s.waitingMessagesMutex.Lock()
+    s.waitingMessages[msg.wait] = ch
+    s.waitingMessagesMutex.Unlock()
+
+    s.Broadcast(nil, msg)
+
+    select {
+    case <-ctx.Done():
+        return Message{}, ctx.Err()
+    case receive := <-ch:
+        s.waitingMessagesMutex.Lock()
+        delete(s.waitingMessages, msg.wait)
+        s.waitingMessagesMutex.Unlock()
+
+        return receive, receive.Err
+    }
+}
+
 // GetConnectionsByNamespace can be used as an alternative way to retrieve
 // all connected connections to a specific "namespace" on a specific time point.
 // Do not use this function frequently, it is not designed to be fast or cheap, use it for debugging or logging every 'x' time.
diff --git a/server_test.go b/server_test.go
index 90e1ea5..a4ae279 100644
--- a/server_test.go
+++ b/server_test.go
@@ -112,3 +112,86 @@ func TestServerBroadcastTo(t *testing.T) {
 
     wg.Wait()
 }
+
+func TestServerAsk(t *testing.T) {
+    // We fire up two connections, one with the ID "conn_ID" and the other with the default uuid ID generator.
+    // The server asks the connection with the ID "conn_ID" and expects a reply only from that connection.
+
+    var (
+        wg             sync.WaitGroup
+        namespace      = "default"
+        body           = []byte("data")
+        expectResponse = append(body, []byte("ok")...)
+        to             = "conn_ID"
+        clientEvents   = neffos.Namespaces{
+            namespace: neffos.Events{
+                "ask": func(c *neffos.NSConn, msg neffos.Message) error {
+                    return neffos.Reply(expectResponse)
+                },
+            },
+        }
+    )
+
+    teardownServer := runTestServer("localhost:8080", neffos.Namespaces{namespace: neffos.Events{}}, func(wsServer *neffos.Server) {
+        once := new(uint32)
+        wsServer.IDGenerator = func(w http.ResponseWriter, r *http.Request) string {
+            if atomic.CompareAndSwapUint32(once, 0, 1) {
+                return to // set the "to" ID only on the first conn for this test.
+            }
+
+            return neffos.DefaultIDGenerator(w, r)
+        }
+
+        wgWaitToAllConnect := new(sync.WaitGroup)
+        wgWaitToAllConnect.Add(2)
+        wsServer.OnConnect = func(c *neffos.Conn) error {
+            wgWaitToAllConnect.Done()
+            return nil
+        }
+
+        go func(wsServer *neffos.Server) {
+            wgWaitToAllConnect.Wait()
+
+            response, err := wsServer.Ask(nil, neffos.Message{
+                Namespace: "default",
+                Event:     "ask",
+                To:        to,
+            })
+
+            if err != nil {
+                t.Fatal(err)
+            }
+
+            if !bytes.Equal(response.Body, expectResponse) {
+                t.Fatalf("expected response with body: %s but got: %s", string(expectResponse), string(response.Body))
+            }
+
+            wg.Done()
+        }(wsServer)
+
+    })
+    defer teardownServer()
+
+    wg.Add(2) // two servers, a gorilla and gobwas.
+
+    teardownClient1 := runTestClient("localhost:8080", clientEvents,
+        func(dialer string, client *neffos.Client) {
+            _, err := client.Connect(nil, namespace)
+            if err != nil {
+                t.Fatal(err)
+            }
+        })
+
+    defer teardownClient1()
+
+    teardownClient2 := runTestClient("localhost:8080", clientEvents,
+        func(dialer string, client *neffos.Client) {
+            _, err := client.Connect(nil, namespace)
+            if err != nil {
+                t.Fatal(err)
+            }
+        })
+    defer teardownClient2()
+
+    wg.Wait()
+}
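
Usage sketch for the new `Server.Ask` API introduced above. This is an illustration, not part of the diff: the `askConn` helper name, the package name, and the five-second timeout are assumptions, while the "default" namespace, the "ask" event, and the `To` connection ID mirror the test. On the remote side an event handler replies the same way the test's client does, by returning `neffos.Reply(body)` from its "ask" handler.

// Package example is an illustrative sketch only; the helper name,
// package name, and timeout below are assumptions, not part of the change.
package example

import (
    "context"
    "fmt"
    "time"

    "github.com/kataras/neffos"
)

// askConn broadcasts a request to the connection identified by connID
// and blocks until that connection replies or the deadline expires.
func askConn(wsServer *neffos.Server, connID string) error {
    ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
    defer cancel()

    response, err := wsServer.Ask(ctx, neffos.Message{
        Namespace: "default", // namespace and event names taken from the test above.
        Event:     "ask",
        To:        connID, // ask this specific connection; leave empty to accept the first reply.
    })
    if err != nil {
        return err
    }

    fmt.Printf("reply from %s: %s\n", connID, response.Body)
    return nil
}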