From 77370ed40896861752932480a5e4edd1967e58db Mon Sep 17 00:00:00 2001 From: Steven Normore Date: Fri, 19 Jan 2024 10:51:00 -0500 Subject: [PATCH] Decouple MLS messages from messagev1 (#333) * Separate MLS messages and implement service methods * fix: group id and installation id are bytes * fix: idempotent send group/welcome messages via uniquness in db * fix: hex decode group id from mls validation service * fix: s/Cursor/IdCursor * fix: pass message data only in send group message request * refactor: add mls {Group,Welcome}MessageInput types for send requests * refactor: s/installation_id/installation_key in mls/api * fix: clean up mls query page size logic * feat: implement mls subscribe group/welcome messages * Hex encode group ID * fix: remove duplicate import * fix: return grpc invalidargument on invalid group id --------- Co-authored-by: Nicholas Molnar <65710+neekolas@users.noreply.github.com> --- dev/generate | 2 + dev/migrate-mls | 4 + dev/run | 2 + dev/up | 2 +- go.mod | 5 +- go.sum | 6 +- pkg/api/authentication_test.go | 18 +- pkg/api/config.go | 3 +- pkg/api/interceptor.go | 18 +- pkg/api/message/v1/service.go | 67 +-- pkg/api/server.go | 65 ++- pkg/authn/authn.pb.go | 2 +- .../mls/20240109001927_add-messages.down.sql | 6 + .../mls/20240109001927_add-messages.up.sql | 37 ++ pkg/mls/api/v1/mock.gen.go | 257 ++++++++++ pkg/mls/api/v1/service.go | 470 +++++++++++++----- pkg/mls/api/v1/service_test.go | 301 ++++++++--- pkg/mls/store/config.go | 2 + pkg/mls/store/models.go | 24 +- pkg/mls/store/store.go | 222 ++++++++- pkg/mls/store/store_test.go | 424 +++++++++++++++- pkg/mlsvalidate/service.go | 13 +- pkg/mlsvalidate/service_test.go | 2 +- pkg/store/query_test.go | 31 -- pkg/store/store.go | 72 --- pkg/topic/mls.go | 28 ++ pkg/topic/topic.go | 9 - tools.go | 2 +- 28 files changed, 1683 insertions(+), 411 deletions(-) create mode 100755 dev/migrate-mls create mode 100644 pkg/migrations/mls/20240109001927_add-messages.down.sql create mode 100644 pkg/migrations/mls/20240109001927_add-messages.up.sql create mode 100644 pkg/mls/api/v1/mock.gen.go create mode 100644 pkg/topic/mls.go diff --git a/dev/generate b/dev/generate index 2cd08848..871f9e74 100755 --- a/dev/generate +++ b/dev/generate @@ -2,3 +2,5 @@ set -e go generate ./... + +mockgen -package api github.com/xmtp/proto/v3/go/mls/api/v1 MlsApi_SubscribeGroupMessagesServer,MlsApi_SubscribeWelcomeMessagesServer > pkg/mls/api/v1/mock.gen.go diff --git a/dev/migrate-mls b/dev/migrate-mls new file mode 100755 index 00000000..1af15f75 --- /dev/null +++ b/dev/migrate-mls @@ -0,0 +1,4 @@ +#!/bin/bash +set -e + +dev/run --create-mls-migration "$@" diff --git a/dev/run b/dev/run index e90f2126..b108c144 100755 --- a/dev/run +++ b/dev/run @@ -2,6 +2,7 @@ set -e MESSAGE_DB_DSN="postgres://postgres:xmtp@localhost:15432/postgres?sslmode=disable" +MLS_DB_DSN="postgres://postgres:xmtp@localhost:15432/postgres?sslmode=disable" AUTHZ_DB_DSN="postgres://postgres:xmtp@localhost:15432/postgres?sslmode=disable" NODE_KEY="8a30dcb604b0b53627a5adc054dbf434b446628d4bd1eccc681d223f0550ce67" @@ -13,6 +14,7 @@ go run cmd/xmtpd/main.go \ --store.db-connection-string "${MESSAGE_DB_DSN}" \ --store.reader-db-connection-string "${MESSAGE_DB_DSN}" \ --store.metrics-period 5s \ + --mls-store.db-connection-string "${MESSAGE_DB_DSN}" \ --authz-db-connection-string "${AUTHZ_DB_DSN}" \ --go-profiling \ "$@" diff --git a/dev/up b/dev/up index 2f09fa23..5a107cc3 100755 --- a/dev/up +++ b/dev/up @@ -8,7 +8,7 @@ if ! which golangci-lint &>/dev/null; then brew install golangci-lint; fi if ! which shellcheck &>/dev/null; then brew install shellcheck; fi if ! which protoc &>/dev/null; then brew install protobuf; fi if ! which protoc-gen-go &>/dev/null; then go install google.golang.org/protobuf/cmd/protoc-gen-go@latest; fi -if ! which mockgen &>/dev/null; then go install github.com/golang/mock/mockgen@latest; fi +if ! which mockgen &>/dev/null; then go install go.uber.org/mock/mockgen@latest; fi if ! which protolint &>/dev/null; then go install github.com/yoheimuta/protolint/cmd/protolint@latest; fi dev/generate diff --git a/go.mod b/go.mod index ba8eb007..3079b0da 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,6 @@ go 1.20 require ( github.com/ethereum/go-ethereum v1.10.26 - github.com/golang/mock v1.6.0 github.com/google/go-cmp v0.5.9 github.com/google/uuid v1.3.0 github.com/grpc-ecosystem/go-grpc-middleware v1.3.0 @@ -30,8 +29,9 @@ require ( github.com/uptrace/bun/driver/pgdriver v1.1.16 github.com/waku-org/go-waku v0.8.0 github.com/xmtp/go-msgio v0.2.1-0.20220510223757-25a701b79cd3 - github.com/xmtp/proto/v3 v3.36.3-0.20240111132545-4e1d2b1b2399 + github.com/xmtp/proto/v3 v3.37.1-0.20240112125235-f02fe8d0f1a0 github.com/yoheimuta/protolint v0.39.0 + go.uber.org/mock v0.4.0 go.uber.org/zap v1.24.0 golang.org/x/sync v0.3.0 google.golang.org/grpc v1.53.0 @@ -75,6 +75,7 @@ require ( github.com/godbus/dbus/v5 v5.1.0 // indirect github.com/gogo/protobuf v1.3.2 // indirect github.com/golang/glog v1.0.0 // indirect + github.com/golang/mock v1.6.0 // indirect github.com/golang/protobuf v1.5.3 // indirect github.com/golang/snappy v0.0.4 // indirect github.com/google/gopacket v1.1.19 // indirect diff --git a/go.sum b/go.sum index b6aef526..4c90c10e 100644 --- a/go.sum +++ b/go.sum @@ -1146,8 +1146,8 @@ github.com/xdg/stringprep v1.0.0/go.mod h1:Jhud4/sHMO4oL310DaZAKk9ZaJ08SJfe+sJh0 github.com/xlab/treeprint v0.0.0-20180616005107-d6fb6747feb6/go.mod h1:ce1O1j6UtZfjr22oyGxGLbauSBp2YVXpARAosm7dHBg= github.com/xmtp/go-msgio v0.2.1-0.20220510223757-25a701b79cd3 h1:wzUffJGCTBGXIDyNU+1UBu1fn2Nzo+OQzM1pLrheh58= github.com/xmtp/go-msgio v0.2.1-0.20220510223757-25a701b79cd3/go.mod h1:bJREWk+NDnZYjgLQdAi8SUWuq/5pkMme4GqiffEhUF4= -github.com/xmtp/proto/v3 v3.36.3-0.20240111132545-4e1d2b1b2399 h1:i5qynxHZRn7mIXQPt8M7c6ac0NBb+MEn2g2qKzvRTyM= -github.com/xmtp/proto/v3 v3.36.3-0.20240111132545-4e1d2b1b2399/go.mod h1:NF2zAjtNpVIhS4tFG19g4L1tJcPZHm81oeDFXltmOiY= +github.com/xmtp/proto/v3 v3.37.1-0.20240112125235-f02fe8d0f1a0 h1:eGNiXDTiXcXTf5ne4HACbqbHaQrVlRz2hwcn05E7v8U= +github.com/xmtp/proto/v3 v3.37.1-0.20240112125235-f02fe8d0f1a0/go.mod h1:NF2zAjtNpVIhS4tFG19g4L1tJcPZHm81oeDFXltmOiY= github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673 h1:bAn7/zixMGCfxrRTfdpNzjtPYqr8smhKouy9mxVdGPU= github.com/yoheimuta/go-protoparser/v4 v4.6.0 h1:uvz1e9/5Ihsm4Ku8AJeDImTpirKmIxubZdSn0QJNdnw= github.com/yoheimuta/go-protoparser/v4 v4.6.0/go.mod h1:AHNNnSWnb0UoL4QgHPiOAg2BniQceFscPI5X/BZNHl8= @@ -1182,6 +1182,8 @@ go.uber.org/fx v1.20.0 h1:ZMC/pnRvhsthOZh9MZjMq5U8Or3mA9zBSPaLnzs3ihQ= go.uber.org/fx v1.20.0/go.mod h1:qCUj0btiR3/JnanEr1TYEePfSw6o/4qYJscgvzQ5Ub0= go.uber.org/goleak v1.1.11-0.20210813005559-691160354723/go.mod h1:cwTWslyiVhfpKIDGSZEM2HlOvcqm+tG4zioyIeLoqMQ= go.uber.org/goleak v1.1.12 h1:gZAh5/EyT/HQwlpkCy6wTpqfH9H8Lz8zbm3dZh+OyzA= +go.uber.org/mock v0.4.0 h1:VcM4ZOtdbR4f6VXfiOpwpVJDL6lCReaZ6mw31wqh7KU= +go.uber.org/mock v0.4.0/go.mod h1:a6FSlNadKUHUa9IP5Vyt1zh4fC7uAwxMutEAscFbkZc= go.uber.org/multierr v1.1.0/go.mod h1:wR5kodmAFQ0UK8QlbwjlSNy0Z68gJhDJUG5sjR94q/0= go.uber.org/multierr v1.3.0/go.mod h1:VgVr7evmIr6uPjLBxg28wmKNXyqE9akIJ5XnfpiKl+4= go.uber.org/multierr v1.5.0/go.mod h1:FeouvMocqHpRaaGuG9EjoKcStLC43Zu/fmqdUMPcKYU= diff --git a/pkg/api/authentication_test.go b/pkg/api/authentication_test.go index bd86e232..dd3f04c8 100644 --- a/pkg/api/authentication_test.go +++ b/pkg/api/authentication_test.go @@ -19,7 +19,7 @@ func Test_AuthnNoToken(t *testing.T) { }) } -func Test_AuthnNoTokenNonMLS(t *testing.T) { +func Test_AuthnNoTokenNonV0(t *testing.T) { ctx := context.Background() testGRPCAndHTTP(t, ctx, func(t *testing.T, client messageclient.Client, server *Server) { _, err := client.Publish(ctx, &messageV1.PublishRequest{ @@ -36,22 +36,6 @@ func Test_AuthnNoTokenNonMLS(t *testing.T) { }) } -func Test_AuthnNoTokenMLS(t *testing.T) { - ctx := context.Background() - testGRPCAndHTTP(t, ctx, func(t *testing.T, client messageclient.Client, server *Server) { - _, err := client.Publish(ctx, &messageV1.PublishRequest{ - Envelopes: []*messageV1.Envelope{ - { - ContentTopic: "/xmtp/mls/1/m-0x1234/proto", - TimestampNs: 0, - Message: []byte{}, - }, - }, - }) - require.NoError(t, err) - }) -} - func Test_AuthnNoTokenMixedV0MLS(t *testing.T) { ctx := context.Background() testGRPCAndHTTP(t, ctx, func(t *testing.T, client messageclient.Client, server *Server) { diff --git a/pkg/api/config.go b/pkg/api/config.go index daaaa5e2..880b824a 100644 --- a/pkg/api/config.go +++ b/pkg/api/config.go @@ -50,8 +50,7 @@ type AuthnOptions struct { Authenticated requests will be permitted according to the rules of the request type, (i.e. you can't publish into other wallets' contact and private topics). */ - Enable bool `long:"enable" description:"require client authentication via wallet tokens"` - EnableMLS bool `long:"enable-mls" description:"require client authentication for MLS"` + Enable bool `long:"enable" description:"require client authentication via wallet tokens"` /* Ratelimits enables request rate limiting. diff --git a/pkg/api/interceptor.go b/pkg/api/interceptor.go index c4860b5b..7b6c32c6 100644 --- a/pkg/api/interceptor.go +++ b/pkg/api/interceptor.go @@ -87,23 +87,9 @@ func (wa *WalletAuthorizer) Stream() grpc.StreamServerInterceptor { } } -func (wa *WalletAuthorizer) isProtocolMLS(request *messagev1.PublishRequest) bool { - envelopes := request.Envelopes - if len(envelopes) == 0 { - return false - } - // If any of the envelopes are not for a v3 topic, then we treat the request as non-v3 - for _, envelope := range envelopes { - if !strings.HasPrefix(envelope.ContentTopic, "/xmtp/mls/") { - return false - } - } - return true -} - func (wa *WalletAuthorizer) requiresAuthorization(req interface{}) bool { - publishRequest, isPublish := req.(*messagev1.PublishRequest) - return isPublish && (!wa.isProtocolMLS(publishRequest) || wa.AuthnConfig.EnableMLS) + _, isPublish := req.(*messagev1.PublishRequest) + return isPublish } func (wa *WalletAuthorizer) getWallet(ctx context.Context) (types.WalletAddr, error) { diff --git a/pkg/api/message/v1/service.go b/pkg/api/message/v1/service.go index 28cccdad..9963903f 100644 --- a/pkg/api/message/v1/service.go +++ b/pkg/api/message/v1/service.go @@ -13,16 +13,13 @@ import ( "github.com/nats-io/nats-server/v2/server" "github.com/nats-io/nats.go" "github.com/pkg/errors" - wakunode "github.com/waku-org/go-waku/waku/v2/node" wakupb "github.com/waku-org/go-waku/waku/v2/protocol/pb" - wakurelay "github.com/waku-org/go-waku/waku/v2/protocol/relay" proto "github.com/xmtp/proto/v3/go/message_api/v1" apicontext "github.com/xmtp/xmtp-node-go/pkg/api/message/v1/context" "github.com/xmtp/xmtp-node-go/pkg/logging" "github.com/xmtp/xmtp-node-go/pkg/metrics" "github.com/xmtp/xmtp-node-go/pkg/store" "github.com/xmtp/xmtp-node-go/pkg/topic" - "github.com/xmtp/xmtp-node-go/pkg/tracing" "go.uber.org/zap" "google.golang.org/grpc/codes" "google.golang.org/grpc/metadata" @@ -45,24 +42,24 @@ type Service struct { // Configured as constructor options. log *zap.Logger - waku *wakunode.WakuNode store *store.Store + publishToWakuRelay func(context.Context, *wakupb.WakuMessage) error + // Configured internally. ctx context.Context ctxCancel func() wg sync.WaitGroup - relaySub *wakurelay.Subscription ns *server.Server nc *nats.Conn } -func NewService(node *wakunode.WakuNode, logger *zap.Logger, store *store.Store) (s *Service, err error) { +func NewService(log *zap.Logger, store *store.Store, publishToWakuRelay func(context.Context, *wakupb.WakuMessage) error) (s *Service, err error) { s = &Service{ - waku: node, - log: logger.Named("message/v1"), - store: store, + log: log.Named("message/v1"), + store: store, + publishToWakuRelay: publishToWakuRelay, } s.ctx, s.ctxCancel = context.WithCancel(context.Background()) @@ -82,44 +79,11 @@ func NewService(node *wakunode.WakuNode, logger *zap.Logger, store *store.Store) return nil, err } - // Initialize waku relay subscription. - s.relaySub, err = s.waku.Relay().Subscribe(s.ctx) - if err != nil { - return nil, errors.Wrap(err, "subscribing to relay") - } - tracing.GoPanicWrap(s.ctx, &s.wg, "broadcast", func(ctx context.Context) { - for { - select { - case <-ctx.Done(): - return - case wakuEnv := <-s.relaySub.Ch: - if wakuEnv == nil { - continue - } - env := buildEnvelope(wakuEnv.Message()) - - envB, err := pb.Marshal(env) - if err != nil { - s.log.Error("marshalling envelope", zap.Error(err)) - continue - } - err = s.nc.Publish(buildNatsSubject(env.ContentTopic), envB) - if err != nil { - s.log.Error("publishing envelope to local nats", zap.Error(err)) - continue - } - } - } - }) - return s, nil } func (s *Service) Close() { s.log.Info("closing") - if s.relaySub != nil { - s.relaySub.Unsubscribe() - } if s.ctxCancel != nil { s.ctxCancel() @@ -136,6 +100,22 @@ func (s *Service) Close() { s.log.Info("closed") } +func (s *Service) HandleIncomingWakuRelayMessage(msg *wakupb.WakuMessage) error { + env := buildEnvelope(msg) + + envB, err := pb.Marshal(env) + if err != nil { + return err + } + + err = s.nc.Publish(buildNatsSubject(env.ContentTopic), envB) + if err != nil { + return err + } + + return nil +} + func (s *Service) Publish(ctx context.Context, req *proto.PublishRequest) (*proto.PublishResponse, error) { for _, env := range req.Envelopes { log := s.log.Named("publish").With(zap.String("content_topic", env.ContentTopic)) @@ -156,7 +136,7 @@ func (s *Service) Publish(ctx context.Context, req *proto.PublishRequest) (*prot } } - _, err := s.waku.Relay().Publish(ctx, &wakupb.WakuMessage{ + err := s.publishToWakuRelay(ctx, &wakupb.WakuMessage{ ContentTopic: env.ContentTopic, Timestamp: int64(env.TimestampNs), Payload: env.Message, @@ -164,6 +144,7 @@ func (s *Service) Publish(ctx context.Context, req *proto.PublishRequest) (*prot if err != nil { return nil, status.Errorf(codes.Internal, err.Error()) } + metrics.EmitPublishedEnvelope(ctx, log, env) } return &proto.PublishResponse{}, nil diff --git a/pkg/api/server.go b/pkg/api/server.go index 460a5e41..ea9fabf7 100644 --- a/pkg/api/server.go +++ b/pkg/api/server.go @@ -14,10 +14,13 @@ import ( "github.com/grpc-ecosystem/grpc-gateway/v2/runtime" "github.com/pkg/errors" swgui "github.com/swaggest/swgui/v3" + wakupb "github.com/waku-org/go-waku/waku/v2/protocol/pb" + wakurelay "github.com/waku-org/go-waku/waku/v2/protocol/relay" proto "github.com/xmtp/proto/v3/go/message_api/v1" mlsv1pb "github.com/xmtp/proto/v3/go/mls/api/v1" messagev1openapi "github.com/xmtp/proto/v3/openapi/message_api/v1" "github.com/xmtp/xmtp-node-go/pkg/ratelimiter" + "github.com/xmtp/xmtp-node-go/pkg/topic" "github.com/xmtp/xmtp-node-go/pkg/tracing" "google.golang.org/grpc/health" healthgrpc "google.golang.org/grpc/health/grpc_health_v1" @@ -48,6 +51,8 @@ type Server struct { mlsv1 *mlsv1.Service wg sync.WaitGroup ctx context.Context + ctxCancel func() + wakuRelaySub *wakurelay.Subscription authorizer *WalletAuthorizer } @@ -61,7 +66,7 @@ func New(config *Config) (*Server, error) { Config: config, } - s.ctx = context.Background() + s.ctx, s.ctxCancel = context.WithCancel(context.Background()) // Start gRPC services. err := s.startGRPC() @@ -123,7 +128,12 @@ func (s *Server) startGRPC() error { healthcheck := health.NewServer() healthgrpc.RegisterHealthServer(grpcServer, healthcheck) - s.messagev1, err = messagev1.NewService(s.Waku, s.Log, s.Store) + publishToWakuRelay := func(ctx context.Context, msg *wakupb.WakuMessage) error { + _, err := s.Waku.Relay().Publish(ctx, msg) + return err + } + + s.messagev1, err = messagev1.NewService(s.Log, s.Store, publishToWakuRelay) if err != nil { return errors.Wrap(err, "creating message service") } @@ -131,12 +141,49 @@ func (s *Server) startGRPC() error { // Enable the MLS server if a store is provided if s.Config.MLSStore != nil && s.Config.MLSValidator != nil && s.Config.EnableMls { - s.mlsv1, err = mlsv1.NewService(s.Waku, s.Log, s.Store, s.Config.MLSStore, s.Config.MLSValidator) + s.mlsv1, err = mlsv1.NewService(s.Log, s.Config.MLSStore, s.Config.MLSValidator, publishToWakuRelay) if err != nil { return errors.Wrap(err, "creating mls service") } mlsv1pb.RegisterMlsApiServer(grpcServer, s.mlsv1) } + + // Initialize waku relay subscription. + s.wakuRelaySub, err = s.Waku.Relay().Subscribe(s.ctx) + if err != nil { + return errors.Wrap(err, "subscribing to relay") + } + tracing.GoPanicWrap(s.ctx, &s.wg, "broadcast", func(ctx context.Context) { + for { + select { + case <-ctx.Done(): + return + case wakuEnv := <-s.wakuRelaySub.Ch: + if wakuEnv == nil || wakuEnv.Message() == nil { + continue + } + wakuMsg := wakuEnv.Message() + + if topic.IsMLSV1(wakuMsg.ContentTopic) { + if s.mlsv1 != nil { + err := s.mlsv1.HandleIncomingWakuRelayMessage(wakuEnv.Message()) + if err != nil { + s.Log.Error("error handling waku relay message by mlsv1 service", zap.Error(err)) + } + } + } else { + if s.messagev1 != nil { + err := s.messagev1.HandleIncomingWakuRelayMessage(wakuEnv.Message()) + if err != nil { + s.Log.Error("error handling waku relay message by messagev1 service", zap.Error(err)) + } + } + } + + } + } + }) + prometheus.Register(grpcServer) tracing.GoPanicWrap(s.ctx, &s.wg, "grpc", func(ctx context.Context) { @@ -215,9 +262,21 @@ func (s *Server) startHTTP() error { func (s *Server) Close() { s.Log.Info("closing") + + if s.ctxCancel != nil { + s.ctxCancel() + } + + if s.wakuRelaySub != nil { + s.wakuRelaySub.Unsubscribe() + } + if s.messagev1 != nil { s.messagev1.Close() } + if s.mlsv1 != nil { + s.mlsv1.Close() + } if s.httpListener != nil { err := s.httpListener.Close() diff --git a/pkg/authn/authn.pb.go b/pkg/authn/authn.pb.go index e7e88e3f..dc370405 100644 --- a/pkg/authn/authn.pb.go +++ b/pkg/authn/authn.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.28.1 -// protoc v3.21.8 +// protoc v4.25.1 // source: authn.proto package authn diff --git a/pkg/migrations/mls/20240109001927_add-messages.down.sql b/pkg/migrations/mls/20240109001927_add-messages.down.sql new file mode 100644 index 00000000..a64da5e3 --- /dev/null +++ b/pkg/migrations/mls/20240109001927_add-messages.down.sql @@ -0,0 +1,6 @@ +SET statement_timeout = 0; + +--bun:split + +DROP TABLE IF EXISTS messages; + diff --git a/pkg/migrations/mls/20240109001927_add-messages.up.sql b/pkg/migrations/mls/20240109001927_add-messages.up.sql new file mode 100644 index 00000000..6e71eb8b --- /dev/null +++ b/pkg/migrations/mls/20240109001927_add-messages.up.sql @@ -0,0 +1,37 @@ +SET statement_timeout = 0; + +--bun:split + +CREATE TABLE group_messages ( + id BIGSERIAL PRIMARY KEY, + created_at TIMESTAMP NOT NULL DEFAULT NOW(), + group_id BYTEA NOT NULL, + data BYTEA NOT NULL, + group_id_data_hash BYTEA NOT NULL +); + +--bun:split + +CREATE INDEX idx_group_messages_group_id_created_at ON group_messages(group_id, created_at); + +--bun:split + +CREATE UNIQUE INDEX idx_group_messages_group_id_data_hash ON group_messages (group_id_data_hash); + +--bun:split + +CREATE TABLE welcome_messages ( + id BIGSERIAL PRIMARY KEY, + created_at TIMESTAMP NOT NULL DEFAULT NOW(), + installation_key BYTEA NOT NULL, + data BYTEA NOT NULL, + installation_key_data_hash BYTEA NOT NULL +); + +--bun:split + +CREATE INDEX idx_welcome_messages_installation_key_created_at ON welcome_messages(installation_key, created_at); + +--bun:split + +CREATE UNIQUE INDEX idx_welcome_messages_group_key_data_hash ON welcome_messages (installation_key_data_hash); diff --git a/pkg/mls/api/v1/mock.gen.go b/pkg/mls/api/v1/mock.gen.go new file mode 100644 index 00000000..846dbf10 --- /dev/null +++ b/pkg/mls/api/v1/mock.gen.go @@ -0,0 +1,257 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/xmtp/proto/v3/go/mls/api/v1 (interfaces: MlsApi_SubscribeGroupMessagesServer,MlsApi_SubscribeWelcomeMessagesServer) +// +// Generated by this command: +// +// mockgen -package api github.com/xmtp/proto/v3/go/mls/api/v1 MlsApi_SubscribeGroupMessagesServer,MlsApi_SubscribeWelcomeMessagesServer +// + +// Package api is a generated GoMock package. +package api + +import ( + context "context" + reflect "reflect" + + v1 "github.com/xmtp/proto/v3/go/mls/api/v1" + gomock "go.uber.org/mock/gomock" + metadata "google.golang.org/grpc/metadata" +) + +// MockMlsApi_SubscribeGroupMessagesServer is a mock of MlsApi_SubscribeGroupMessagesServer interface. +type MockMlsApi_SubscribeGroupMessagesServer struct { + ctrl *gomock.Controller + recorder *MockMlsApi_SubscribeGroupMessagesServerMockRecorder +} + +// MockMlsApi_SubscribeGroupMessagesServerMockRecorder is the mock recorder for MockMlsApi_SubscribeGroupMessagesServer. +type MockMlsApi_SubscribeGroupMessagesServerMockRecorder struct { + mock *MockMlsApi_SubscribeGroupMessagesServer +} + +// NewMockMlsApi_SubscribeGroupMessagesServer creates a new mock instance. +func NewMockMlsApi_SubscribeGroupMessagesServer(ctrl *gomock.Controller) *MockMlsApi_SubscribeGroupMessagesServer { + mock := &MockMlsApi_SubscribeGroupMessagesServer{ctrl: ctrl} + mock.recorder = &MockMlsApi_SubscribeGroupMessagesServerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockMlsApi_SubscribeGroupMessagesServer) EXPECT() *MockMlsApi_SubscribeGroupMessagesServerMockRecorder { + return m.recorder +} + +// Context mocks base method. +func (m *MockMlsApi_SubscribeGroupMessagesServer) Context() context.Context { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Context") + ret0, _ := ret[0].(context.Context) + return ret0 +} + +// Context indicates an expected call of Context. +func (mr *MockMlsApi_SubscribeGroupMessagesServerMockRecorder) Context() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Context", reflect.TypeOf((*MockMlsApi_SubscribeGroupMessagesServer)(nil).Context)) +} + +// RecvMsg mocks base method. +func (m *MockMlsApi_SubscribeGroupMessagesServer) RecvMsg(arg0 any) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RecvMsg", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// RecvMsg indicates an expected call of RecvMsg. +func (mr *MockMlsApi_SubscribeGroupMessagesServerMockRecorder) RecvMsg(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RecvMsg", reflect.TypeOf((*MockMlsApi_SubscribeGroupMessagesServer)(nil).RecvMsg), arg0) +} + +// Send mocks base method. +func (m *MockMlsApi_SubscribeGroupMessagesServer) Send(arg0 *v1.GroupMessage) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Send", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// Send indicates an expected call of Send. +func (mr *MockMlsApi_SubscribeGroupMessagesServerMockRecorder) Send(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Send", reflect.TypeOf((*MockMlsApi_SubscribeGroupMessagesServer)(nil).Send), arg0) +} + +// SendHeader mocks base method. +func (m *MockMlsApi_SubscribeGroupMessagesServer) SendHeader(arg0 metadata.MD) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SendHeader", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// SendHeader indicates an expected call of SendHeader. +func (mr *MockMlsApi_SubscribeGroupMessagesServerMockRecorder) SendHeader(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendHeader", reflect.TypeOf((*MockMlsApi_SubscribeGroupMessagesServer)(nil).SendHeader), arg0) +} + +// SendMsg mocks base method. +func (m *MockMlsApi_SubscribeGroupMessagesServer) SendMsg(arg0 any) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SendMsg", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// SendMsg indicates an expected call of SendMsg. +func (mr *MockMlsApi_SubscribeGroupMessagesServerMockRecorder) SendMsg(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendMsg", reflect.TypeOf((*MockMlsApi_SubscribeGroupMessagesServer)(nil).SendMsg), arg0) +} + +// SetHeader mocks base method. +func (m *MockMlsApi_SubscribeGroupMessagesServer) SetHeader(arg0 metadata.MD) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SetHeader", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// SetHeader indicates an expected call of SetHeader. +func (mr *MockMlsApi_SubscribeGroupMessagesServerMockRecorder) SetHeader(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetHeader", reflect.TypeOf((*MockMlsApi_SubscribeGroupMessagesServer)(nil).SetHeader), arg0) +} + +// SetTrailer mocks base method. +func (m *MockMlsApi_SubscribeGroupMessagesServer) SetTrailer(arg0 metadata.MD) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "SetTrailer", arg0) +} + +// SetTrailer indicates an expected call of SetTrailer. +func (mr *MockMlsApi_SubscribeGroupMessagesServerMockRecorder) SetTrailer(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetTrailer", reflect.TypeOf((*MockMlsApi_SubscribeGroupMessagesServer)(nil).SetTrailer), arg0) +} + +// MockMlsApi_SubscribeWelcomeMessagesServer is a mock of MlsApi_SubscribeWelcomeMessagesServer interface. +type MockMlsApi_SubscribeWelcomeMessagesServer struct { + ctrl *gomock.Controller + recorder *MockMlsApi_SubscribeWelcomeMessagesServerMockRecorder +} + +// MockMlsApi_SubscribeWelcomeMessagesServerMockRecorder is the mock recorder for MockMlsApi_SubscribeWelcomeMessagesServer. +type MockMlsApi_SubscribeWelcomeMessagesServerMockRecorder struct { + mock *MockMlsApi_SubscribeWelcomeMessagesServer +} + +// NewMockMlsApi_SubscribeWelcomeMessagesServer creates a new mock instance. +func NewMockMlsApi_SubscribeWelcomeMessagesServer(ctrl *gomock.Controller) *MockMlsApi_SubscribeWelcomeMessagesServer { + mock := &MockMlsApi_SubscribeWelcomeMessagesServer{ctrl: ctrl} + mock.recorder = &MockMlsApi_SubscribeWelcomeMessagesServerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockMlsApi_SubscribeWelcomeMessagesServer) EXPECT() *MockMlsApi_SubscribeWelcomeMessagesServerMockRecorder { + return m.recorder +} + +// Context mocks base method. +func (m *MockMlsApi_SubscribeWelcomeMessagesServer) Context() context.Context { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Context") + ret0, _ := ret[0].(context.Context) + return ret0 +} + +// Context indicates an expected call of Context. +func (mr *MockMlsApi_SubscribeWelcomeMessagesServerMockRecorder) Context() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Context", reflect.TypeOf((*MockMlsApi_SubscribeWelcomeMessagesServer)(nil).Context)) +} + +// RecvMsg mocks base method. +func (m *MockMlsApi_SubscribeWelcomeMessagesServer) RecvMsg(arg0 any) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RecvMsg", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// RecvMsg indicates an expected call of RecvMsg. +func (mr *MockMlsApi_SubscribeWelcomeMessagesServerMockRecorder) RecvMsg(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RecvMsg", reflect.TypeOf((*MockMlsApi_SubscribeWelcomeMessagesServer)(nil).RecvMsg), arg0) +} + +// Send mocks base method. +func (m *MockMlsApi_SubscribeWelcomeMessagesServer) Send(arg0 *v1.WelcomeMessage) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Send", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// Send indicates an expected call of Send. +func (mr *MockMlsApi_SubscribeWelcomeMessagesServerMockRecorder) Send(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Send", reflect.TypeOf((*MockMlsApi_SubscribeWelcomeMessagesServer)(nil).Send), arg0) +} + +// SendHeader mocks base method. +func (m *MockMlsApi_SubscribeWelcomeMessagesServer) SendHeader(arg0 metadata.MD) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SendHeader", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// SendHeader indicates an expected call of SendHeader. +func (mr *MockMlsApi_SubscribeWelcomeMessagesServerMockRecorder) SendHeader(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendHeader", reflect.TypeOf((*MockMlsApi_SubscribeWelcomeMessagesServer)(nil).SendHeader), arg0) +} + +// SendMsg mocks base method. +func (m *MockMlsApi_SubscribeWelcomeMessagesServer) SendMsg(arg0 any) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SendMsg", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// SendMsg indicates an expected call of SendMsg. +func (mr *MockMlsApi_SubscribeWelcomeMessagesServerMockRecorder) SendMsg(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendMsg", reflect.TypeOf((*MockMlsApi_SubscribeWelcomeMessagesServer)(nil).SendMsg), arg0) +} + +// SetHeader mocks base method. +func (m *MockMlsApi_SubscribeWelcomeMessagesServer) SetHeader(arg0 metadata.MD) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SetHeader", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// SetHeader indicates an expected call of SetHeader. +func (mr *MockMlsApi_SubscribeWelcomeMessagesServerMockRecorder) SetHeader(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetHeader", reflect.TypeOf((*MockMlsApi_SubscribeWelcomeMessagesServer)(nil).SetHeader), arg0) +} + +// SetTrailer mocks base method. +func (m *MockMlsApi_SubscribeWelcomeMessagesServer) SetTrailer(arg0 metadata.MD) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "SetTrailer", arg0) +} + +// SetTrailer indicates an expected call of SetTrailer. +func (mr *MockMlsApi_SubscribeWelcomeMessagesServerMockRecorder) SetTrailer(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetTrailer", reflect.TypeOf((*MockMlsApi_SubscribeWelcomeMessagesServer)(nil).SetTrailer), arg0) +} diff --git a/pkg/mls/api/v1/service.go b/pkg/mls/api/v1/service.go index 97731ec9..7695cb88 100644 --- a/pkg/mls/api/v1/service.go +++ b/pkg/mls/api/v1/service.go @@ -2,45 +2,125 @@ package api import ( "context" - - wakunode "github.com/waku-org/go-waku/waku/v2/node" + "encoding/hex" + "errors" + "fmt" + "hash/fnv" + "sync" + "time" + + "github.com/nats-io/nats-server/v2/server" + "github.com/nats-io/nats.go" wakupb "github.com/waku-org/go-waku/waku/v2/protocol/pb" - proto "github.com/xmtp/proto/v3/go/mls/api/v1" - "github.com/xmtp/xmtp-node-go/pkg/metrics" + mlsv1 "github.com/xmtp/proto/v3/go/mls/api/v1" mlsstore "github.com/xmtp/xmtp-node-go/pkg/mls/store" "github.com/xmtp/xmtp-node-go/pkg/mlsvalidate" - "github.com/xmtp/xmtp-node-go/pkg/store" "github.com/xmtp/xmtp-node-go/pkg/topic" "go.uber.org/zap" "google.golang.org/grpc/codes" + "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" + pb "google.golang.org/protobuf/proto" emptypb "google.golang.org/protobuf/types/known/emptypb" ) type Service struct { - proto.UnimplementedMlsApiServer + mlsv1.UnimplementedMlsApiServer log *zap.Logger - waku *wakunode.WakuNode - messageStore *store.Store - mlsStore mlsstore.MlsStore + store mlsstore.MlsStore validationService mlsvalidate.MLSValidationService + + publishToWakuRelay func(context.Context, *wakupb.WakuMessage) error + + ns *server.Server + nc *nats.Conn + + ctx context.Context + ctxCancel func() } -func NewService(node *wakunode.WakuNode, logger *zap.Logger, messageStore *store.Store, mlsStore mlsstore.MlsStore, validationService mlsvalidate.MLSValidationService) (s *Service, err error) { +func NewService(log *zap.Logger, store mlsstore.MlsStore, validationService mlsvalidate.MLSValidationService, publishToWakuRelay func(context.Context, *wakupb.WakuMessage) error) (s *Service, err error) { s = &Service{ - log: logger.Named("mls/v1"), - waku: node, - messageStore: messageStore, - mlsStore: mlsStore, - validationService: validationService, + log: log.Named("mls/v1"), + store: store, + validationService: validationService, + publishToWakuRelay: publishToWakuRelay, + } + s.ctx, s.ctxCancel = context.WithCancel(context.Background()) + + // Initialize nats for subscriptions. + s.ns, err = server.NewServer(&server.Options{ + Port: server.RANDOM_PORT, + }) + if err != nil { + return nil, err + } + go s.ns.Start() + if !s.ns.ReadyForConnections(4 * time.Second) { + return nil, errors.New("nats not ready") + } + s.nc, err = nats.Connect(s.ns.ClientURL()) + if err != nil { + return nil, err } s.log.Info("Starting MLS service") return s, nil } -func (s *Service) RegisterInstallation(ctx context.Context, req *proto.RegisterInstallationRequest) (*proto.RegisterInstallationResponse, error) { +func (s *Service) Close() { + s.log.Info("closing") + + if s.ctxCancel != nil { + s.ctxCancel() + } + + if s.nc != nil { + s.nc.Close() + } + if s.ns != nil { + s.ns.Shutdown() + } + + s.log.Info("closed") +} + +func (s *Service) HandleIncomingWakuRelayMessage(wakuMsg *wakupb.WakuMessage) error { + if topic.IsMLSV1Group(wakuMsg.ContentTopic) { + var msg mlsv1.GroupMessage + err := pb.Unmarshal(wakuMsg.Payload, &msg) + if err != nil { + return err + } + if msg.GetV1() == nil { + return nil + } + err = s.nc.Publish(buildNatsSubjectForGroupMessages(msg.GetV1().GroupId), wakuMsg.Payload) + if err != nil { + return err + } + } else if topic.IsMLSV1Welcome(wakuMsg.ContentTopic) { + var msg mlsv1.WelcomeMessage + err := pb.Unmarshal(wakuMsg.Payload, &msg) + if err != nil { + return err + } + if msg.GetV1() == nil { + return nil + } + err = s.nc.Publish(buildNatsSubjectForWelcomeMessages(msg.GetV1().InstallationKey), wakuMsg.Payload) + if err != nil { + return err + } + } else { + s.log.Info("received unknown mls message type from waku relay", zap.String("topic", wakuMsg.ContentTopic)) + } + + return nil +} + +func (s *Service) RegisterInstallation(ctx context.Context, req *mlsv1.RegisterInstallationRequest) (*mlsv1.RegisterInstallationResponse, error) { if err := validateRegisterInstallationRequest(req); err != nil { return nil, err } @@ -53,22 +133,22 @@ func (s *Service) RegisterInstallation(ctx context.Context, req *proto.RegisterI return nil, status.Errorf(codes.Internal, "unexpected number of results: %d", len(results)) } - installationId := results[0].InstallationId + installationId := results[0].InstallationKey accountAddress := results[0].AccountAddress credentialIdentity := results[0].CredentialIdentity - if err = s.mlsStore.CreateInstallation(ctx, installationId, accountAddress, credentialIdentity, req.KeyPackage.KeyPackageTlsSerialized, results[0].Expiration); err != nil { + if err = s.store.CreateInstallation(ctx, installationId, accountAddress, credentialIdentity, req.KeyPackage.KeyPackageTlsSerialized, results[0].Expiration); err != nil { return nil, err } - return &proto.RegisterInstallationResponse{ - InstallationId: installationId, + return &mlsv1.RegisterInstallationResponse{ + InstallationKey: installationId, }, nil } -func (s *Service) FetchKeyPackages(ctx context.Context, req *proto.FetchKeyPackagesRequest) (*proto.FetchKeyPackagesResponse, error) { - ids := req.InstallationIds - installations, err := s.mlsStore.FetchKeyPackages(ctx, ids) +func (s *Service) FetchKeyPackages(ctx context.Context, req *mlsv1.FetchKeyPackagesRequest) (*mlsv1.FetchKeyPackagesResponse, error) { + ids := req.InstallationKeys + installations, err := s.store.FetchKeyPackages(ctx, ids) if err != nil { return nil, status.Errorf(codes.Internal, "failed to fetch key packages: %s", err) } @@ -77,7 +157,7 @@ func (s *Service) FetchKeyPackages(ctx context.Context, req *proto.FetchKeyPacka keyPackageMap[string(id)] = idx } - resPackages := make([]*proto.FetchKeyPackagesResponse_KeyPackage, len(ids)) + resPackages := make([]*mlsv1.FetchKeyPackagesResponse_KeyPackage, len(ids)) for _, installation := range installations { idx, ok := keyPackageMap[string(installation.ID)] @@ -85,158 +165,292 @@ func (s *Service) FetchKeyPackages(ctx context.Context, req *proto.FetchKeyPacka return nil, status.Errorf(codes.Internal, "could not find key package for installation") } - resPackages[idx] = &proto.FetchKeyPackagesResponse_KeyPackage{ + resPackages[idx] = &mlsv1.FetchKeyPackagesResponse_KeyPackage{ KeyPackageTlsSerialized: installation.KeyPackage, } } - return &proto.FetchKeyPackagesResponse{ + return &mlsv1.FetchKeyPackagesResponse{ KeyPackages: resPackages, }, nil } -func (s *Service) PublishToGroup(ctx context.Context, req *proto.PublishToGroupRequest) (res *emptypb.Empty, err error) { - if err = validatePublishToGroupRequest(req); err != nil { +func (s *Service) UploadKeyPackage(ctx context.Context, req *mlsv1.UploadKeyPackageRequest) (res *emptypb.Empty, err error) { + if err = validateUploadKeyPackageRequest(req); err != nil { return nil, err } + // Extract the key packages from the request + keyPackageBytes := req.KeyPackage.KeyPackageTlsSerialized + + validationResults, err := s.validationService.ValidateKeyPackages(ctx, [][]byte{keyPackageBytes}) + if err != nil { + // TODO: Differentiate between validation errors and internal errors + return nil, status.Errorf(codes.InvalidArgument, "invalid identity: %s", err) + } + installationId := validationResults[0].InstallationKey + expiration := validationResults[0].Expiration - messages := make([][]byte, len(req.Messages)) - for i, message := range req.Messages { - v1 := message.GetV1() - if v1 == nil { - return nil, status.Errorf(codes.InvalidArgument, "message must be v1") - } - messages[i] = v1.MlsMessageTlsSerialized + if err = s.store.UpdateKeyPackage(ctx, installationId, keyPackageBytes, expiration); err != nil { + return nil, status.Errorf(codes.Internal, "failed to insert key packages: %s", err) + } + + return &emptypb.Empty{}, nil +} + +func (s *Service) RevokeInstallation(ctx context.Context, req *mlsv1.RevokeInstallationRequest) (*emptypb.Empty, error) { + return nil, status.Errorf(codes.Unimplemented, "unimplemented") +} + +func (s *Service) GetIdentityUpdates(ctx context.Context, req *mlsv1.GetIdentityUpdatesRequest) (res *mlsv1.GetIdentityUpdatesResponse, err error) { + if err = validateGetIdentityUpdatesRequest(req); err != nil { + return nil, err } - validationResults, err := s.validationService.ValidateGroupMessages(ctx, messages) + accountAddresses := req.AccountAddresses + updates, err := s.store.GetIdentityUpdates(ctx, req.AccountAddresses, int64(req.StartTimeNs)) if err != nil { - // TODO: Separate validation errors from internal errors - return nil, status.Errorf(codes.InvalidArgument, "invalid group message: %s", err) + return nil, status.Errorf(codes.Internal, "failed to get identity updates: %s", err) } - for i, result := range validationResults { - message := messages[i] + resUpdates := make([]*mlsv1.GetIdentityUpdatesResponse_WalletUpdates, len(accountAddresses)) + for i, accountAddress := range accountAddresses { + walletUpdates := updates[accountAddress] - if err = requireReadyToSend(result.GroupId, message); err != nil { - return nil, err + resUpdates[i] = &mlsv1.GetIdentityUpdatesResponse_WalletUpdates{ + Updates: []*mlsv1.GetIdentityUpdatesResponse_Update{}, } - // TODO: Wrap this in a transaction so publishing is all or nothing - if err = s.publishMessage(ctx, topic.BuildGroupTopic(result.GroupId), message); err != nil { - return nil, status.Errorf(codes.Internal, "failed to publish message: %s", err) + for _, walletUpdate := range walletUpdates { + resUpdates[i].Updates = append(resUpdates[i].Updates, buildIdentityUpdate(walletUpdate)) } } - return &emptypb.Empty{}, nil + return &mlsv1.GetIdentityUpdatesResponse{ + Updates: resUpdates, + }, nil } -func (s *Service) publishMessage(ctx context.Context, contentTopic string, message []byte) error { - log := s.log.Named("publish-mls").With(zap.String("content_topic", contentTopic)) - env, err := s.messageStore.InsertMLSMessage(ctx, contentTopic, message) - if err != nil { - return status.Errorf(codes.Internal, "failed to insert message: %s", err) +func (s *Service) SendGroupMessages(ctx context.Context, req *mlsv1.SendGroupMessagesRequest) (res *emptypb.Empty, err error) { + if err = validateSendGroupMessagesRequest(req); err != nil { + return nil, err } - if _, err = s.waku.Relay().Publish(ctx, &wakupb.WakuMessage{ - ContentTopic: contentTopic, - Timestamp: int64(env.TimestampNs), - Payload: message, - }); err != nil { - return status.Errorf(codes.Internal, "failed to publish message: %s", err) + validationResults, err := s.validationService.ValidateGroupMessages(ctx, req.Messages) + if err != nil { + // TODO: Separate validation errors from internal errors + return nil, status.Errorf(codes.InvalidArgument, "invalid group message: %s", err) } - metrics.EmitPublishedEnvelope(ctx, log, env) + for i, result := range validationResults { + input := req.Messages[i] - return nil -} + if err = requireReadyToSend(result.GroupId, input.GetV1().Data); err != nil { + return nil, err + } -func (s *Service) PublishWelcomes(ctx context.Context, req *proto.PublishWelcomesRequest) (res *emptypb.Empty, err error) { - if err = validatePublishWelcomesRequest(req); err != nil { - return nil, err - } + // TODO: Wrap this in a transaction so publishing is all or nothing + decodedGroupId, err := hex.DecodeString(result.GroupId) + if err != nil { + return nil, status.Error(codes.InvalidArgument, "invalid group id") + } + msg, err := s.store.InsertGroupMessage(ctx, decodedGroupId, input.GetV1().Data) + if err != nil { + if mlsstore.IsAlreadyExistsError(err) { + continue + } + return nil, status.Errorf(codes.Internal, "failed to insert message: %s", err) + } - // TODO: Wrap this in a transaction so publishing is all or nothing - for _, welcome := range req.WelcomeMessages { - contentTopic := topic.BuildWelcomeTopic(welcome.InstallationId) - if err = s.publishMessage(ctx, contentTopic, welcome.WelcomeMessage.GetV1().WelcomeMessageTlsSerialized); err != nil { - return nil, status.Errorf(codes.Internal, "failed to publish welcome message: %s", err) + msgB, err := pb.Marshal(&mlsv1.GroupMessage{ + Version: &mlsv1.GroupMessage_V1_{ + V1: &mlsv1.GroupMessage_V1{ + Id: msg.Id, + CreatedNs: uint64(msg.CreatedAt.UnixNano()), + GroupId: msg.GroupId, + Data: msg.Data, + }, + }, + }) + if err != nil { + return nil, err + } + + err = s.publishToWakuRelay(ctx, &wakupb.WakuMessage{ + ContentTopic: topic.BuildMLSV1GroupTopic(decodedGroupId), + Timestamp: msg.CreatedAt.UnixNano(), + Payload: msgB, + }) + if err != nil { + return nil, err } } + return &emptypb.Empty{}, nil } -func (s *Service) UploadKeyPackage(ctx context.Context, req *proto.UploadKeyPackageRequest) (res *emptypb.Empty, err error) { - if err = validateUploadKeyPackageRequest(req); err != nil { +func (s *Service) SendWelcomeMessages(ctx context.Context, req *mlsv1.SendWelcomeMessagesRequest) (res *emptypb.Empty, err error) { + if err = validateSendWelcomeMessagesRequest(req); err != nil { return nil, err } - // Extract the key packages from the request - keyPackageBytes := req.KeyPackage.KeyPackageTlsSerialized - validationResults, err := s.validationService.ValidateKeyPackages(ctx, [][]byte{keyPackageBytes}) - if err != nil { - // TODO: Differentiate between validation errors and internal errors - return nil, status.Errorf(codes.InvalidArgument, "invalid identity: %s", err) - } - installationId := validationResults[0].InstallationId - expiration := validationResults[0].Expiration + // TODO: Wrap this in a transaction so publishing is all or nothing + for _, input := range req.Messages { + msg, err := s.store.InsertWelcomeMessage(ctx, input.GetV1().InstallationKey, input.GetV1().Data) + if err != nil { + if mlsstore.IsAlreadyExistsError(err) { + continue + } + return nil, status.Errorf(codes.Internal, "failed to insert message: %s", err) + } - if err = s.mlsStore.UpdateKeyPackage(ctx, installationId, keyPackageBytes, expiration); err != nil { - return nil, status.Errorf(codes.Internal, "failed to insert key packages: %s", err) - } + msgB, err := pb.Marshal(&mlsv1.WelcomeMessage{ + Version: &mlsv1.WelcomeMessage_V1_{ + V1: &mlsv1.WelcomeMessage_V1{ + Id: msg.Id, + CreatedNs: uint64(msg.CreatedAt.UnixNano()), + InstallationKey: msg.InstallationKey, + Data: msg.Data, + }, + }, + }) + if err != nil { + return nil, err + } + err = s.publishToWakuRelay(ctx, &wakupb.WakuMessage{ + ContentTopic: topic.BuildMLSV1WelcomeTopic(input.GetV1().InstallationKey), + Timestamp: msg.CreatedAt.UnixNano(), + Payload: msgB, + }) + if err != nil { + return nil, err + } + } return &emptypb.Empty{}, nil } -func (s *Service) RevokeInstallation(ctx context.Context, req *proto.RevokeInstallationRequest) (*emptypb.Empty, error) { - return nil, status.Errorf(codes.Unimplemented, "unimplemented") +func (s *Service) QueryGroupMessages(ctx context.Context, req *mlsv1.QueryGroupMessagesRequest) (*mlsv1.QueryGroupMessagesResponse, error) { + return s.store.QueryGroupMessagesV1(ctx, req) } -func (s *Service) GetIdentityUpdates(ctx context.Context, req *proto.GetIdentityUpdatesRequest) (res *proto.GetIdentityUpdatesResponse, err error) { - if err = validateGetIdentityUpdatesRequest(req); err != nil { - return nil, err - } +func (s *Service) QueryWelcomeMessages(ctx context.Context, req *mlsv1.QueryWelcomeMessagesRequest) (*mlsv1.QueryWelcomeMessagesResponse, error) { + return s.store.QueryWelcomeMessagesV1(ctx, req) +} - accountAddresses := req.AccountAddresses - updates, err := s.mlsStore.GetIdentityUpdates(ctx, req.AccountAddresses, int64(req.StartTimeNs)) - if err != nil { - return nil, status.Errorf(codes.Internal, "failed to get identity updates: %s", err) +func (s *Service) SubscribeGroupMessages(req *mlsv1.SubscribeGroupMessagesRequest, stream mlsv1.MlsApi_SubscribeGroupMessagesServer) error { + log := s.log.Named("subscribe-group-messages").With(zap.Int("filters", len(req.Filters))) + + // Send a header (any header) to fix an issue with Tonic based GRPC clients. + // See: https://github.com/xmtp/libxmtp/pull/58 + _ = stream.SendHeader(metadata.Pairs("subscribed", "true")) + + var streamLock sync.Mutex + for _, filter := range req.Filters { + natsSubject := buildNatsSubjectForGroupMessages(filter.GroupId) + sub, err := s.nc.Subscribe(natsSubject, func(natsMsg *nats.Msg) { + var msg mlsv1.GroupMessage + err := pb.Unmarshal(natsMsg.Data, &msg) + if err != nil { + log.Error("parsing group message from bytes", zap.Error(err)) + return + } + func() { + streamLock.Lock() + defer streamLock.Unlock() + err := stream.Send(&msg) + if err != nil { + log.Error("sending group message to subscribe", zap.Error(err)) + } + }() + }) + if err != nil { + log.Error("error subscribing to group messages", zap.Error(err)) + return err + } + defer func() { + _ = sub.Unsubscribe() + }() } - resUpdates := make([]*proto.GetIdentityUpdatesResponse_WalletUpdates, len(accountAddresses)) - for i, accountAddress := range accountAddresses { - walletUpdates := updates[accountAddress] + select { + case <-stream.Context().Done(): + return nil + case <-s.ctx.Done(): + return nil + } +} - resUpdates[i] = &proto.GetIdentityUpdatesResponse_WalletUpdates{ - Updates: []*proto.GetIdentityUpdatesResponse_Update{}, +func (s *Service) SubscribeWelcomeMessages(req *mlsv1.SubscribeWelcomeMessagesRequest, stream mlsv1.MlsApi_SubscribeWelcomeMessagesServer) error { + log := s.log.Named("subscribe-welcome-messages").With(zap.Int("filters", len(req.Filters))) + + // Send a header (any header) to fix an issue with Tonic based GRPC clients. + // See: https://github.com/xmtp/libxmtp/pull/58 + _ = stream.SendHeader(metadata.Pairs("subscribed", "true")) + + var streamLock sync.Mutex + for _, filter := range req.Filters { + natsSubject := buildNatsSubjectForWelcomeMessages(filter.InstallationKey) + sub, err := s.nc.Subscribe(natsSubject, func(natsMsg *nats.Msg) { + var msg mlsv1.WelcomeMessage + err := pb.Unmarshal(natsMsg.Data, &msg) + if err != nil { + log.Error("parsing welcome message from bytes", zap.Error(err)) + return + } + func() { + streamLock.Lock() + defer streamLock.Unlock() + err := stream.Send(&msg) + if err != nil { + log.Error("sending welcome message to subscribe", zap.Error(err)) + } + }() + }) + if err != nil { + log.Error("error subscribing to welcome messages", zap.Error(err)) + return err } + defer func() { + _ = sub.Unsubscribe() + }() + } - for _, walletUpdate := range walletUpdates { - resUpdates[i].Updates = append(resUpdates[i].Updates, buildIdentityUpdate(walletUpdate)) - } + select { + case <-stream.Context().Done(): + return nil + case <-s.ctx.Done(): + return nil } +} - return &proto.GetIdentityUpdatesResponse{ - Updates: resUpdates, - }, nil +func buildNatsSubjectForGroupMessages(groupId []byte) string { + hasher := fnv.New64a() + hasher.Write(groupId) + return fmt.Sprintf("gm-%x", hasher.Sum64()) } -func buildIdentityUpdate(update mlsstore.IdentityUpdate) *proto.GetIdentityUpdatesResponse_Update { - base := proto.GetIdentityUpdatesResponse_Update{ +func buildNatsSubjectForWelcomeMessages(installationId []byte) string { + hasher := fnv.New64a() + hasher.Write(installationId) + return fmt.Sprintf("wm-%x", hasher.Sum64()) +} + +func buildIdentityUpdate(update mlsstore.IdentityUpdate) *mlsv1.GetIdentityUpdatesResponse_Update { + base := mlsv1.GetIdentityUpdatesResponse_Update{ TimestampNs: update.TimestampNs, } switch update.Kind { case mlsstore.Create: - base.Kind = &proto.GetIdentityUpdatesResponse_Update_NewInstallation{ - NewInstallation: &proto.GetIdentityUpdatesResponse_NewInstallationUpdate{ - InstallationId: update.InstallationId, + base.Kind = &mlsv1.GetIdentityUpdatesResponse_Update_NewInstallation{ + NewInstallation: &mlsv1.GetIdentityUpdatesResponse_NewInstallationUpdate{ + InstallationKey: update.InstallationKey, CredentialIdentity: update.CredentialIdentity, }, } case mlsstore.Revoke: - base.Kind = &proto.GetIdentityUpdatesResponse_Update_RevokedInstallation{ - RevokedInstallation: &proto.GetIdentityUpdatesResponse_RevokedInstallationUpdate{ - InstallationId: update.InstallationId, + base.Kind = &mlsv1.GetIdentityUpdatesResponse_Update_RevokedInstallation{ + RevokedInstallation: &mlsv1.GetIdentityUpdatesResponse_RevokedInstallationUpdate{ + InstallationKey: update.InstallationKey, }, } } @@ -244,45 +458,45 @@ func buildIdentityUpdate(update mlsstore.IdentityUpdate) *proto.GetIdentityUpdat return &base } -func validatePublishToGroupRequest(req *proto.PublishToGroupRequest) error { +func validateSendGroupMessagesRequest(req *mlsv1.SendGroupMessagesRequest) error { if req == nil || len(req.Messages) == 0 { - return status.Errorf(codes.InvalidArgument, "no messages to publish") + return status.Errorf(codes.InvalidArgument, "no group messages to send") + } + for _, input := range req.Messages { + if input == nil || input.GetV1() == nil { + return status.Errorf(codes.InvalidArgument, "invalid group message") + } } return nil } -func validatePublishWelcomesRequest(req *proto.PublishWelcomesRequest) error { - if req == nil || len(req.WelcomeMessages) == 0 { - return status.Errorf(codes.InvalidArgument, "no welcome messages to publish") +func validateSendWelcomeMessagesRequest(req *mlsv1.SendWelcomeMessagesRequest) error { + if req == nil || len(req.Messages) == 0 { + return status.Errorf(codes.InvalidArgument, "no welcome messages to send") } - for _, welcome := range req.WelcomeMessages { - if welcome == nil || welcome.WelcomeMessage == nil { - return status.Errorf(codes.InvalidArgument, "invalid welcome message") - } - - v1 := welcome.WelcomeMessage.GetV1() - if v1 == nil || len(v1.WelcomeMessageTlsSerialized) == 0 { + for _, input := range req.Messages { + if input == nil || input.GetV1() == nil { return status.Errorf(codes.InvalidArgument, "invalid welcome message") } } return nil } -func validateRegisterInstallationRequest(req *proto.RegisterInstallationRequest) error { +func validateRegisterInstallationRequest(req *mlsv1.RegisterInstallationRequest) error { if req == nil || req.KeyPackage == nil { return status.Errorf(codes.InvalidArgument, "no key package") } return nil } -func validateUploadKeyPackageRequest(req *proto.UploadKeyPackageRequest) error { +func validateUploadKeyPackageRequest(req *mlsv1.UploadKeyPackageRequest) error { if req == nil || req.KeyPackage == nil { return status.Errorf(codes.InvalidArgument, "no key package") } return nil } -func validateGetIdentityUpdatesRequest(req *proto.GetIdentityUpdatesRequest) error { +func validateGetIdentityUpdatesRequest(req *mlsv1.GetIdentityUpdatesRequest) error { if req == nil || len(req.AccountAddresses) == 0 { return status.Errorf(codes.InvalidArgument, "no wallet addresses to get updates for") } @@ -290,7 +504,7 @@ func validateGetIdentityUpdatesRequest(req *proto.GetIdentityUpdatesRequest) err } func requireReadyToSend(groupId string, message []byte) error { - if groupId == "" { + if len(groupId) == 0 { return status.Errorf(codes.InvalidArgument, "group id is empty") } if len(message) == 0 { diff --git a/pkg/mls/api/v1/service_test.go b/pkg/mls/api/v1/service_test.go index 8ad15ba3..d725318b 100644 --- a/pkg/mls/api/v1/service_test.go +++ b/pkg/mls/api/v1/service_test.go @@ -5,17 +5,19 @@ import ( "errors" "fmt" "testing" + "time" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" "github.com/uptrace/bun" - v1 "github.com/xmtp/proto/v3/go/message_api/v1" - proto "github.com/xmtp/proto/v3/go/mls/api/v1" - messageContents "github.com/xmtp/proto/v3/go/mls/message_contents" + wakupb "github.com/waku-org/go-waku/waku/v2/protocol/pb" + mlsv1 "github.com/xmtp/proto/v3/go/mls/api/v1" mlsstore "github.com/xmtp/xmtp-node-go/pkg/mls/store" "github.com/xmtp/xmtp-node-go/pkg/mlsvalidate" - "github.com/xmtp/xmtp-node-go/pkg/store" test "github.com/xmtp/xmtp-node-go/pkg/testing" + "github.com/xmtp/xmtp-node-go/pkg/topic" + "go.uber.org/mock/gomock" + "google.golang.org/protobuf/proto" ) type mockedMLSValidationService struct { @@ -33,7 +35,7 @@ func (m *mockedMLSValidationService) ValidateKeyPackages(ctx context.Context, ke return response.([]mlsvalidate.IdentityValidationResult), args.Error(1) } -func (m *mockedMLSValidationService) ValidateGroupMessages(ctx context.Context, groupMessages [][]byte) ([]mlsvalidate.GroupMessageValidationResult, error) { +func (m *mockedMLSValidationService) ValidateGroupMessages(ctx context.Context, groupMessages []*mlsv1.GroupMessageInput) ([]mlsvalidate.GroupMessageValidationResult, error) { args := m.Called(ctx, groupMessages) return args.Get(0).([]mlsvalidate.GroupMessageValidationResult), args.Error(1) @@ -46,7 +48,7 @@ func newMockedValidationService() *mockedMLSValidationService { func (m *mockedMLSValidationService) mockValidateKeyPackages(installationId []byte, accountAddress string) *mock.Call { return m.On("ValidateKeyPackages", mock.Anything, mock.Anything).Return([]mlsvalidate.IdentityValidationResult{ { - InstallationId: installationId, + InstallationKey: installationId, AccountAddress: accountAddress, CredentialIdentity: []byte("test"), Expiration: 0, @@ -54,41 +56,32 @@ func (m *mockedMLSValidationService) mockValidateKeyPackages(installationId []by }, nil) } -func (m *mockedMLSValidationService) mockValidateGroupMessages(groupId string) *mock.Call { +func (m *mockedMLSValidationService) mockValidateGroupMessages(groupId []byte) *mock.Call { return m.On("ValidateGroupMessages", mock.Anything, mock.Anything).Return([]mlsvalidate.GroupMessageValidationResult{ { - GroupId: groupId, + GroupId: fmt.Sprintf("%x", groupId), }, }, nil) } func newTestService(t *testing.T, ctx context.Context) (*Service, *bun.DB, *mockedMLSValidationService, func()) { log := test.NewLog(t) - mlsDb, _, mlsDbCleanup := test.NewMLSDB(t) - mlsStore, err := mlsstore.New(ctx, mlsstore.Config{ + db, _, mlsDbCleanup := test.NewMLSDB(t) + store, err := mlsstore.New(ctx, mlsstore.Config{ Log: log, - DB: mlsDb, + DB: db, }) require.NoError(t, err) - messageDb, _, messageDbCleanup := test.NewDB(t) - messageStore, err := store.New(&store.Config{ - Log: log, - DB: messageDb, - ReaderDB: messageDb, - CleanerDB: messageDb, - }) - require.NoError(t, err) - node, nodeCleanup := test.NewNode(t) mlsValidationService := newMockedValidationService() - svc, err := NewService(node, log, messageStore, mlsStore, mlsValidationService) + svc, err := NewService(log, store, mlsValidationService, func(ctx context.Context, wm *wakupb.WakuMessage) error { + return nil + }) require.NoError(t, err) - return svc, mlsDb, mlsValidationService, func() { - messageStore.Close() + return svc, db, mlsValidationService, func() { + svc.Close() mlsDbCleanup() - messageDbCleanup() - nodeCleanup() } } @@ -102,14 +95,14 @@ func TestRegisterInstallation(t *testing.T) { mlsValidationService.mockValidateKeyPackages(installationId, accountAddress) - res, err := svc.RegisterInstallation(ctx, &proto.RegisterInstallationRequest{ - KeyPackage: &proto.KeyPackageUpload{ + res, err := svc.RegisterInstallation(ctx, &mlsv1.RegisterInstallationRequest{ + KeyPackage: &mlsv1.KeyPackageUpload{ KeyPackageTlsSerialized: []byte("test"), }, }) require.NoError(t, err) - require.Equal(t, installationId, res.InstallationId) + require.Equal(t, installationId, res.InstallationKey) installations := []mlsstore.Installation{} err = mlsDb.NewSelect().Model(&installations).Where("id = ?", installationId).Scan(ctx) @@ -126,8 +119,8 @@ func TestRegisterInstallationError(t *testing.T) { mlsValidationService.On("ValidateKeyPackages", ctx, mock.Anything).Return(nil, errors.New("error validating")) - res, err := svc.RegisterInstallation(ctx, &proto.RegisterInstallationRequest{ - KeyPackage: &proto.KeyPackageUpload{ + res, err := svc.RegisterInstallation(ctx, &mlsv1.RegisterInstallationRequest{ + KeyPackage: &mlsv1.KeyPackageUpload{ KeyPackageTlsSerialized: []byte("test"), }, }) @@ -145,16 +138,16 @@ func TestUploadKeyPackage(t *testing.T) { mlsValidationService.mockValidateKeyPackages(installationId, accountAddress) - res, err := svc.RegisterInstallation(ctx, &proto.RegisterInstallationRequest{ - KeyPackage: &proto.KeyPackageUpload{ + res, err := svc.RegisterInstallation(ctx, &mlsv1.RegisterInstallationRequest{ + KeyPackage: &mlsv1.KeyPackageUpload{ KeyPackageTlsSerialized: []byte("test"), }, }) require.NoError(t, err) require.NotNil(t, res) - uploadRes, err := svc.UploadKeyPackage(ctx, &proto.UploadKeyPackageRequest{ - KeyPackage: &proto.KeyPackageUpload{ + uploadRes, err := svc.UploadKeyPackage(ctx, &mlsv1.UploadKeyPackageRequest{ + KeyPackage: &mlsv1.KeyPackageUpload{ KeyPackageTlsSerialized: []byte("test2"), }, }) @@ -176,8 +169,8 @@ func TestFetchKeyPackages(t *testing.T) { mockCall := mlsValidationService.mockValidateKeyPackages(installationId1, accountAddress1) - res, err := svc.RegisterInstallation(ctx, &proto.RegisterInstallationRequest{ - KeyPackage: &proto.KeyPackageUpload{ + res, err := svc.RegisterInstallation(ctx, &mlsv1.RegisterInstallationRequest{ + KeyPackage: &mlsv1.KeyPackageUpload{ KeyPackageTlsSerialized: []byte("test"), }, }) @@ -191,16 +184,16 @@ func TestFetchKeyPackages(t *testing.T) { mockCall.Unset() mlsValidationService.mockValidateKeyPackages(installationId2, accountAddress2) - res, err = svc.RegisterInstallation(ctx, &proto.RegisterInstallationRequest{ - KeyPackage: &proto.KeyPackageUpload{ + res, err = svc.RegisterInstallation(ctx, &mlsv1.RegisterInstallationRequest{ + KeyPackage: &mlsv1.KeyPackageUpload{ KeyPackageTlsSerialized: []byte("test2"), }, }) require.NoError(t, err) require.NotNil(t, res) - consumeRes, err := svc.FetchKeyPackages(ctx, &proto.FetchKeyPackagesRequest{ - InstallationIds: [][]byte{installationId1, installationId2}, + consumeRes, err := svc.FetchKeyPackages(ctx, &mlsv1.FetchKeyPackagesRequest{ + InstallationKeys: [][]byte{installationId1, installationId2}, }) require.NoError(t, err) require.NotNil(t, consumeRes) @@ -209,8 +202,8 @@ func TestFetchKeyPackages(t *testing.T) { require.Equal(t, []byte("test2"), consumeRes.KeyPackages[1].KeyPackageTlsSerialized) // Now do it with the installationIds reversed - consumeRes, err = svc.FetchKeyPackages(ctx, &proto.FetchKeyPackagesRequest{ - InstallationIds: [][]byte{installationId2, installationId1}, + consumeRes, err = svc.FetchKeyPackages(ctx, &mlsv1.FetchKeyPackagesRequest{ + InstallationKeys: [][]byte{installationId2, installationId1}, }) require.NoError(t, err) @@ -226,40 +219,72 @@ func TestFetchKeyPackagesFail(t *testing.T) { svc, _, _, cleanup := newTestService(t, ctx) defer cleanup() - consumeRes, err := svc.FetchKeyPackages(ctx, &proto.FetchKeyPackagesRequest{ - InstallationIds: [][]byte{test.RandomBytes(32)}, + consumeRes, err := svc.FetchKeyPackages(ctx, &mlsv1.FetchKeyPackagesRequest{ + InstallationKeys: [][]byte{test.RandomBytes(32)}, }) require.Nil(t, err) - require.Equal(t, []*proto.FetchKeyPackagesResponse_KeyPackage{nil}, consumeRes.KeyPackages) + require.Equal(t, []*mlsv1.FetchKeyPackagesResponse_KeyPackage{nil}, consumeRes.KeyPackages) } -func TestPublishToGroup(t *testing.T) { +func TestSendGroupMessages(t *testing.T) { ctx := context.Background() svc, _, mlsValidationService, cleanup := newTestService(t, ctx) defer cleanup() - groupId := test.RandomString(32) + groupId := []byte(test.RandomString(32)) mlsValidationService.mockValidateGroupMessages(groupId) - _, err := svc.PublishToGroup(ctx, &proto.PublishToGroupRequest{ - Messages: []*messageContents.GroupMessage{{ - Version: &messageContents.GroupMessage_V1_{ - V1: &messageContents.GroupMessage_V1{ - MlsMessageTlsSerialized: []byte("test"), + _, err := svc.SendGroupMessages(ctx, &mlsv1.SendGroupMessagesRequest{ + Messages: []*mlsv1.GroupMessageInput{ + { + Version: &mlsv1.GroupMessageInput_V1_{ + V1: &mlsv1.GroupMessageInput_V1{ + Data: []byte("test"), + }, }, }, - }}, + }, }) require.NoError(t, err) - results, err := svc.messageStore.Query(&v1.QueryRequest{ - ContentTopics: []string{fmt.Sprintf("/xmtp/mls/1/g-%s/proto", groupId)}, + resp, err := svc.store.QueryGroupMessagesV1(ctx, &mlsv1.QueryGroupMessagesRequest{ + GroupId: groupId, }) require.NoError(t, err) - require.Len(t, results.Envelopes, 1) - require.Equal(t, results.Envelopes[0].Message, []byte("test")) - require.NotNil(t, results.Envelopes[0].TimestampNs) + require.Len(t, resp.Messages, 1) + require.Equal(t, resp.Messages[0].GetV1().Data, []byte("test")) + require.NotEmpty(t, resp.Messages[0].GetV1().CreatedNs) +} + +func TestSendWelcomeMessages(t *testing.T) { + ctx := context.Background() + svc, _, _, cleanup := newTestService(t, ctx) + defer cleanup() + + installationId := []byte(test.RandomString(32)) + + _, err := svc.SendWelcomeMessages(ctx, &mlsv1.SendWelcomeMessagesRequest{ + Messages: []*mlsv1.WelcomeMessageInput{ + { + Version: &mlsv1.WelcomeMessageInput_V1_{ + V1: &mlsv1.WelcomeMessageInput_V1{ + InstallationKey: []byte(installationId), + Data: []byte("test"), + }, + }, + }, + }, + }) + require.NoError(t, err) + + resp, err := svc.store.QueryWelcomeMessagesV1(ctx, &mlsv1.QueryWelcomeMessagesRequest{ + InstallationKey: installationId, + }) + require.NoError(t, err) + require.Len(t, resp.Messages, 1) + require.Equal(t, resp.Messages[0].GetV1().Data, []byte("test")) + require.NotEmpty(t, resp.Messages[0].GetV1().CreatedNs) } func TestGetIdentityUpdates(t *testing.T) { @@ -272,41 +297,185 @@ func TestGetIdentityUpdates(t *testing.T) { mockCall := mlsValidationService.mockValidateKeyPackages(installationId, accountAddress) - _, err := svc.RegisterInstallation(ctx, &proto.RegisterInstallationRequest{ - KeyPackage: &proto.KeyPackageUpload{ + _, err := svc.RegisterInstallation(ctx, &mlsv1.RegisterInstallationRequest{ + KeyPackage: &mlsv1.KeyPackageUpload{ KeyPackageTlsSerialized: []byte("test"), }, }) require.NoError(t, err) - identityUpdates, err := svc.GetIdentityUpdates(ctx, &proto.GetIdentityUpdatesRequest{ + identityUpdates, err := svc.GetIdentityUpdates(ctx, &mlsv1.GetIdentityUpdatesRequest{ AccountAddresses: []string{accountAddress}, }) require.NoError(t, err) require.NotNil(t, identityUpdates) require.Len(t, identityUpdates.Updates, 1) - require.Equal(t, identityUpdates.Updates[0].Updates[0].GetNewInstallation().InstallationId, installationId) + require.Equal(t, identityUpdates.Updates[0].Updates[0].GetNewInstallation().InstallationKey, installationId) require.Equal(t, identityUpdates.Updates[0].Updates[0].GetNewInstallation().CredentialIdentity, []byte("test")) for _, walletUpdate := range identityUpdates.Updates { for _, update := range walletUpdate.Updates { - require.Equal(t, installationId, update.GetNewInstallation().InstallationId) + require.Equal(t, installationId, update.GetNewInstallation().InstallationKey) } } mockCall.Unset() mlsValidationService.mockValidateKeyPackages(test.RandomBytes(32), accountAddress) - _, err = svc.RegisterInstallation(ctx, &proto.RegisterInstallationRequest{ - KeyPackage: &proto.KeyPackageUpload{ + _, err = svc.RegisterInstallation(ctx, &mlsv1.RegisterInstallationRequest{ + KeyPackage: &mlsv1.KeyPackageUpload{ KeyPackageTlsSerialized: []byte("test"), }, }) require.NoError(t, err) - identityUpdates, err = svc.GetIdentityUpdates(ctx, &proto.GetIdentityUpdatesRequest{ + identityUpdates, err = svc.GetIdentityUpdates(ctx, &mlsv1.GetIdentityUpdatesRequest{ AccountAddresses: []string{accountAddress}, }) require.NoError(t, err) require.Len(t, identityUpdates.Updates, 1) require.Len(t, identityUpdates.Updates[0].Updates, 2) } + +func TestSubscribeGroupMessages(t *testing.T) { + ctx := context.Background() + svc, _, _, cleanup := newTestService(t, ctx) + defer cleanup() + + groupId := []byte(test.RandomString(32)) + + msgs := make([]*mlsv1.GroupMessage, 10) + for i := 0; i < 10; i++ { + msgs[i] = &mlsv1.GroupMessage{ + Version: &mlsv1.GroupMessage_V1_{ + V1: &mlsv1.GroupMessage_V1{ + Id: uint64(i + 1), + CreatedNs: uint64(i + 1), + GroupId: groupId, + Data: []byte(fmt.Sprintf("data%d", i+1)), + }, + }, + } + } + + ctrl := gomock.NewController(t) + stream := NewMockMlsApi_SubscribeGroupMessagesServer(ctrl) + stream.EXPECT().SendHeader(map[string][]string{"subscribed": {"true"}}) + for _, msg := range msgs { + stream.EXPECT().Send(newGroupMessageEqualsMatcher(msg)).Return(nil).Times(1) + } + stream.EXPECT().Context().Return(ctx) + + go func() { + err := svc.SubscribeGroupMessages(&mlsv1.SubscribeGroupMessagesRequest{ + Filters: []*mlsv1.SubscribeGroupMessagesRequest_Filter{ + { + GroupId: groupId, + }, + }, + }, stream) + require.NoError(t, err) + }() + time.Sleep(50 * time.Millisecond) + + for _, msg := range msgs { + msgB, err := proto.Marshal(msg) + require.NoError(t, err) + + err = svc.HandleIncomingWakuRelayMessage(&wakupb.WakuMessage{ + ContentTopic: topic.BuildMLSV1GroupTopic(msg.GetV1().GroupId), + Timestamp: int64(msg.GetV1().CreatedNs), + Payload: msgB, + }) + require.NoError(t, err) + } + + require.Eventually(t, ctrl.Satisfied, 5*time.Second, 100*time.Millisecond) +} + +func TestSubscribeWelcomeMessages(t *testing.T) { + ctx := context.Background() + svc, _, _, cleanup := newTestService(t, ctx) + defer cleanup() + + installationKey := []byte(test.RandomString(32)) + + msgs := make([]*mlsv1.WelcomeMessage, 10) + for i := 0; i < 10; i++ { + msgs[i] = &mlsv1.WelcomeMessage{ + Version: &mlsv1.WelcomeMessage_V1_{ + V1: &mlsv1.WelcomeMessage_V1{ + Id: uint64(i + 1), + CreatedNs: uint64(i + 1), + InstallationKey: installationKey, + Data: []byte(fmt.Sprintf("data%d", i+1)), + }, + }, + } + } + + ctrl := gomock.NewController(t) + stream := NewMockMlsApi_SubscribeWelcomeMessagesServer(ctrl) + stream.EXPECT().SendHeader(map[string][]string{"subscribed": {"true"}}) + for _, msg := range msgs { + stream.EXPECT().Send(newWelcomeMessageEqualsMatcher(msg)).Return(nil).Times(1) + } + stream.EXPECT().Context().Return(ctx) + + go func() { + err := svc.SubscribeWelcomeMessages(&mlsv1.SubscribeWelcomeMessagesRequest{ + Filters: []*mlsv1.SubscribeWelcomeMessagesRequest_Filter{ + { + InstallationKey: installationKey, + }, + }, + }, stream) + require.NoError(t, err) + }() + time.Sleep(50 * time.Millisecond) + + for _, msg := range msgs { + msgB, err := proto.Marshal(msg) + require.NoError(t, err) + + err = svc.HandleIncomingWakuRelayMessage(&wakupb.WakuMessage{ + ContentTopic: topic.BuildMLSV1WelcomeTopic(msg.GetV1().InstallationKey), + Timestamp: int64(msg.GetV1().CreatedNs), + Payload: msgB, + }) + require.NoError(t, err) + } + + require.Eventually(t, ctrl.Satisfied, 5*time.Second, 100*time.Millisecond) +} + +type groupMessageEqualsMatcher struct { + obj *mlsv1.GroupMessage +} + +func newGroupMessageEqualsMatcher(obj *mlsv1.GroupMessage) *groupMessageEqualsMatcher { + return &groupMessageEqualsMatcher{obj} +} + +func (m *groupMessageEqualsMatcher) Matches(obj interface{}) bool { + return proto.Equal(m.obj, obj.(*mlsv1.GroupMessage)) +} + +func (m *groupMessageEqualsMatcher) String() string { + return m.obj.String() +} + +type welcomeMessageEqualsMatcher struct { + obj *mlsv1.WelcomeMessage +} + +func newWelcomeMessageEqualsMatcher(obj *mlsv1.WelcomeMessage) *welcomeMessageEqualsMatcher { + return &welcomeMessageEqualsMatcher{obj} +} + +func (m *welcomeMessageEqualsMatcher) Matches(obj interface{}) bool { + return proto.Equal(m.obj, obj.(*mlsv1.WelcomeMessage)) +} + +func (m *welcomeMessageEqualsMatcher) String() string { + return m.obj.String() +} diff --git a/pkg/mls/store/config.go b/pkg/mls/store/config.go index 6da5323a..064767c9 100644 --- a/pkg/mls/store/config.go +++ b/pkg/mls/store/config.go @@ -17,4 +17,6 @@ type StoreOptions struct { type Config struct { Log *zap.Logger DB *bun.DB + + now func() time.Time } diff --git a/pkg/mls/store/models.go b/pkg/mls/store/models.go index 02ec32cc..d9ceccea 100644 --- a/pkg/mls/store/models.go +++ b/pkg/mls/store/models.go @@ -1,6 +1,10 @@ package store -import "github.com/uptrace/bun" +import ( + "time" + + "github.com/uptrace/bun" +) type Installation struct { bun.BaseModel `bun:"table:installations"` @@ -15,3 +19,21 @@ type Installation struct { KeyPackage []byte `bun:"key_package,notnull,type:bytea"` Expiration uint64 `bun:"expiration,notnull"` } + +type GroupMessage struct { + bun.BaseModel `bun:"table:group_messages"` + + Id uint64 `bun:",pk,notnull"` + CreatedAt time.Time `bun:",notnull"` + GroupId []byte `bun:",notnull,type:bytea"` + Data []byte `bun:",notnull,type:bytea"` +} + +type WelcomeMessage struct { + bun.BaseModel `bun:"table:welcome_messages"` + + Id uint64 `bun:",pk,notnull"` + CreatedAt time.Time `bun:",notnull"` + InstallationKey []byte `bun:",notnull,type:bytea"` + Data []byte `bun:",notnull,type:bytea"` +} diff --git a/pkg/mls/store/store.go b/pkg/mls/store/store.go index c338fdf3..8e64318b 100644 --- a/pkg/mls/store/store.go +++ b/pkg/mls/store/store.go @@ -2,16 +2,21 @@ package store import ( "context" + "crypto/sha256" "errors" "sort" + "strings" "time" "github.com/uptrace/bun" "github.com/uptrace/bun/migrate" - mlsMigrations "github.com/xmtp/xmtp-node-go/pkg/migrations/mls" + mlsv1 "github.com/xmtp/proto/v3/go/mls/api/v1" + migrations "github.com/xmtp/xmtp-node-go/pkg/migrations/mls" "go.uber.org/zap" ) +const maxPageSize = 100 + type Store struct { config Config log *zap.Logger @@ -23,9 +28,18 @@ type MlsStore interface { UpdateKeyPackage(ctx context.Context, installationId, keyPackage []byte, expiration uint64) error FetchKeyPackages(ctx context.Context, installationIds [][]byte) ([]*Installation, error) GetIdentityUpdates(ctx context.Context, walletAddresses []string, startTimeNs int64) (map[string]IdentityUpdateList, error) + InsertGroupMessage(ctx context.Context, groupId []byte, data []byte) (*GroupMessage, error) + InsertWelcomeMessage(ctx context.Context, installationId []byte, data []byte) (*WelcomeMessage, error) + QueryGroupMessagesV1(ctx context.Context, query *mlsv1.QueryGroupMessagesRequest) (*mlsv1.QueryGroupMessagesResponse, error) + QueryWelcomeMessagesV1(ctx context.Context, query *mlsv1.QueryWelcomeMessagesRequest) (*mlsv1.QueryWelcomeMessagesResponse, error) } func New(ctx context.Context, config Config) (*Store, error) { + if config.now == nil { + config.now = func() time.Time { + return time.Now().UTC() + } + } s := &Store{ log: config.Log.Named("mlsstore"), db: config.DB, @@ -125,16 +139,16 @@ func (s *Store) GetIdentityUpdates(ctx context.Context, walletAddresses []string if installation.CreatedAt > startTimeNs { out[installation.WalletAddress] = append(out[installation.WalletAddress], IdentityUpdate{ Kind: Create, - InstallationId: installation.ID, + InstallationKey: installation.ID, CredentialIdentity: installation.CredentialIdentity, TimestampNs: uint64(installation.CreatedAt), }) } if installation.RevokedAt != nil && *installation.RevokedAt > startTimeNs { out[installation.WalletAddress] = append(out[installation.WalletAddress], IdentityUpdate{ - Kind: Revoke, - InstallationId: installation.ID, - TimestampNs: uint64(*installation.RevokedAt), + Kind: Revoke, + InstallationKey: installation.ID, + TimestampNs: uint64(*installation.RevokedAt), }) } } @@ -157,8 +171,185 @@ func (s *Store) RevokeInstallation(ctx context.Context, installationId []byte) e return err } +func (s *Store) InsertGroupMessage(ctx context.Context, groupId []byte, data []byte) (*GroupMessage, error) { + message := GroupMessage{ + Data: data, + } + + var id uint64 + err := s.db.QueryRow("INSERT INTO group_messages (group_id, data, group_id_data_hash) VALUES (?, ?, ?) RETURNING id", groupId, data, sha256.Sum256(append(groupId, data...))).Scan(&id) + if err != nil { + if strings.Contains(err.Error(), "duplicate key value violates unique constraint") { + return nil, NewAlreadyExistsError(err) + } + return nil, err + } + + err = s.db.NewSelect().Model(&message).Where("id = ?", id).Scan(ctx) + if err != nil { + return nil, err + } + + return &message, nil +} + +func (s *Store) InsertWelcomeMessage(ctx context.Context, installationId []byte, data []byte) (*WelcomeMessage, error) { + message := WelcomeMessage{ + Data: data, + } + + var id uint64 + err := s.db.QueryRow("INSERT INTO welcome_messages (installation_key, data, installation_key_data_hash) VALUES (?, ?, ?) RETURNING id", installationId, data, sha256.Sum256(append(installationId, data...))).Scan(&id) + if err != nil { + if strings.Contains(err.Error(), "duplicate key value violates unique constraint") { + return nil, NewAlreadyExistsError(err) + } + return nil, err + } + + err = s.db.NewSelect().Model(&message).Where("id = ?", id).Scan(ctx) + if err != nil { + return nil, err + } + + return &message, nil +} + +func (s *Store) QueryGroupMessagesV1(ctx context.Context, req *mlsv1.QueryGroupMessagesRequest) (*mlsv1.QueryGroupMessagesResponse, error) { + msgs := make([]*GroupMessage, 0) + + if len(req.GroupId) == 0 { + return nil, errors.New("group is required") + } + + q := s.db.NewSelect(). + Model(&msgs). + Where("group_id = ?", req.GroupId) + + direction := mlsv1.SortDirection_SORT_DIRECTION_DESCENDING + if req.PagingInfo != nil && req.PagingInfo.Direction != mlsv1.SortDirection_SORT_DIRECTION_UNSPECIFIED { + direction = req.PagingInfo.Direction + } + switch direction { + case mlsv1.SortDirection_SORT_DIRECTION_DESCENDING: + q = q.Order("id DESC") + case mlsv1.SortDirection_SORT_DIRECTION_ASCENDING: + q = q.Order("id ASC") + } + + pageSize := maxPageSize + if req.PagingInfo != nil && req.PagingInfo.Limit > 0 && req.PagingInfo.Limit <= maxPageSize { + pageSize = int(req.PagingInfo.Limit) + } + q = q.Limit(pageSize) + + if req.PagingInfo != nil && req.PagingInfo.IdCursor != 0 { + if direction == mlsv1.SortDirection_SORT_DIRECTION_ASCENDING { + q = q.Where("id > ?", req.PagingInfo.IdCursor) + } else { + q = q.Where("id < ?", req.PagingInfo.IdCursor) + } + } + + err := q.Scan(ctx) + if err != nil { + return nil, err + } + + messages := make([]*mlsv1.GroupMessage, 0, len(msgs)) + for _, msg := range msgs { + messages = append(messages, &mlsv1.GroupMessage{ + Version: &mlsv1.GroupMessage_V1_{ + V1: &mlsv1.GroupMessage_V1{ + Id: msg.Id, + CreatedNs: uint64(msg.CreatedAt.UnixNano()), + GroupId: msg.GroupId, + Data: msg.Data, + }, + }, + }) + } + + pagingInfo := &mlsv1.PagingInfo{Limit: uint32(pageSize), IdCursor: 0, Direction: direction} + if len(messages) >= pageSize { + lastMsg := msgs[len(messages)-1] + pagingInfo.IdCursor = lastMsg.Id + } + + return &mlsv1.QueryGroupMessagesResponse{ + Messages: messages, + PagingInfo: pagingInfo, + }, nil +} + +func (s *Store) QueryWelcomeMessagesV1(ctx context.Context, req *mlsv1.QueryWelcomeMessagesRequest) (*mlsv1.QueryWelcomeMessagesResponse, error) { + msgs := make([]*WelcomeMessage, 0) + + if len(req.InstallationKey) == 0 { + return nil, errors.New("installation is required") + } + + q := s.db.NewSelect(). + Model(&msgs). + Where("installation_key = ?", req.InstallationKey) + + direction := mlsv1.SortDirection_SORT_DIRECTION_DESCENDING + if req.PagingInfo != nil && req.PagingInfo.Direction != mlsv1.SortDirection_SORT_DIRECTION_UNSPECIFIED { + direction = req.PagingInfo.Direction + } + switch direction { + case mlsv1.SortDirection_SORT_DIRECTION_DESCENDING: + q = q.Order("id DESC") + case mlsv1.SortDirection_SORT_DIRECTION_ASCENDING: + q = q.Order("id ASC") + } + + pageSize := maxPageSize + if req.PagingInfo != nil && req.PagingInfo.Limit > 0 && req.PagingInfo.Limit <= maxPageSize { + pageSize = int(req.PagingInfo.Limit) + } + q = q.Limit(pageSize) + + if req.PagingInfo != nil && req.PagingInfo.IdCursor != 0 { + if direction == mlsv1.SortDirection_SORT_DIRECTION_ASCENDING { + q = q.Where("id > ?", req.PagingInfo.IdCursor) + } else { + q = q.Where("id < ?", req.PagingInfo.IdCursor) + } + } + + err := q.Scan(ctx) + if err != nil { + return nil, err + } + + messages := make([]*mlsv1.WelcomeMessage, 0, len(msgs)) + for _, msg := range msgs { + messages = append(messages, &mlsv1.WelcomeMessage{ + Version: &mlsv1.WelcomeMessage_V1_{ + V1: &mlsv1.WelcomeMessage_V1{ + Id: msg.Id, + CreatedNs: uint64(msg.CreatedAt.UnixNano()), + Data: msg.Data, + }, + }, + }) + } + + pagingInfo := &mlsv1.PagingInfo{Limit: uint32(pageSize), IdCursor: 0, Direction: direction} + if len(messages) >= pageSize { + lastMsg := msgs[len(messages)-1] + pagingInfo.IdCursor = lastMsg.Id + } + + return &mlsv1.QueryWelcomeMessagesResponse{ + Messages: messages, + PagingInfo: pagingInfo, + }, nil +} + func (s *Store) migrate(ctx context.Context) error { - migrator := migrate.NewMigrator(s.db, mlsMigrations.Migrations) + migrator := migrate.NewMigrator(s.db, migrations.Migrations) err := migrator.Init(ctx) if err != nil { return err @@ -189,7 +380,7 @@ const ( type IdentityUpdate struct { Kind IdentityUpdateKind - InstallationId []byte + InstallationKey []byte CredentialIdentity []byte TimestampNs uint64 } @@ -208,3 +399,20 @@ func (a IdentityUpdateList) Swap(i, j int) { func (a IdentityUpdateList) Less(i, j int) bool { return a[i].TimestampNs < a[j].TimestampNs } + +type AlreadyExistsError struct { + Err error +} + +func (e *AlreadyExistsError) Error() string { + return e.Err.Error() +} + +func NewAlreadyExistsError(err error) *AlreadyExistsError { + return &AlreadyExistsError{err} +} + +func IsAlreadyExistsError(err error) bool { + _, ok := err.(*AlreadyExistsError) + return ok +} diff --git a/pkg/mls/store/store_test.go b/pkg/mls/store/store_test.go index 703cf9e8..569b09c0 100644 --- a/pkg/mls/store/store_test.go +++ b/pkg/mls/store/store_test.go @@ -4,8 +4,10 @@ import ( "context" "sort" "testing" + "time" "github.com/stretchr/testify/require" + mlsv1 "github.com/xmtp/proto/v3/go/mls/api/v1" test "github.com/xmtp/xmtp-node-go/pkg/testing" ) @@ -104,9 +106,9 @@ func TestGetIdentityUpdates(t *testing.T) { identityUpdates, err := store.GetIdentityUpdates(ctx, []string{walletAddress}, 0) require.NoError(t, err) require.Len(t, identityUpdates[walletAddress], 2) - require.Equal(t, identityUpdates[walletAddress][0].InstallationId, installationId1) + require.Equal(t, identityUpdates[walletAddress][0].InstallationKey, installationId1) require.Equal(t, identityUpdates[walletAddress][0].Kind, Create) - require.Equal(t, identityUpdates[walletAddress][1].InstallationId, installationId2) + require.Equal(t, identityUpdates[walletAddress][1].InstallationKey, installationId2) // Make sure that date filtering works identityUpdates, err = store.GetIdentityUpdates(ctx, []string{walletAddress}, nowNs()+1000000) @@ -171,3 +173,421 @@ func TestIdentityUpdateSort(t *testing.T) { require.Equal(t, updates[1].TimestampNs, uint64(2)) require.Equal(t, updates[2].TimestampNs, uint64(3)) } + +func TestInsertGroupMessage_Single(t *testing.T) { + started := time.Now().UTC().Add(-time.Minute) + store, cleanup := NewTestStore(t) + defer cleanup() + + ctx := context.Background() + msg, err := store.InsertGroupMessage(ctx, []byte("group"), []byte("data")) + require.NoError(t, err) + require.NotNil(t, msg) + require.Equal(t, uint64(1), msg.Id) + require.True(t, msg.CreatedAt.Before(time.Now().UTC()) && msg.CreatedAt.After(started)) + require.Equal(t, []byte("group"), msg.GroupId) + require.Equal(t, []byte("data"), msg.Data) + + msgs := make([]*GroupMessage, 0) + err = store.db.NewSelect().Model(&msgs).Scan(ctx) + require.NoError(t, err) + require.Len(t, msgs, 1) + require.Equal(t, msg, msgs[0]) +} + +func TestInsertGroupMessage_Duplicate(t *testing.T) { + store, cleanup := NewTestStore(t) + defer cleanup() + + ctx := context.Background() + msg, err := store.InsertGroupMessage(ctx, []byte("group"), []byte("data")) + require.NoError(t, err) + require.NotNil(t, msg) + + msg, err = store.InsertGroupMessage(ctx, []byte("group"), []byte("data")) + require.Nil(t, msg) + require.IsType(t, &AlreadyExistsError{}, err) + require.True(t, IsAlreadyExistsError(err)) +} + +func TestInsertGroupMessage_ManyOrderedByTime(t *testing.T) { + store, cleanup := NewTestStore(t) + defer cleanup() + + ctx := context.Background() + _, err := store.InsertGroupMessage(ctx, []byte("group"), []byte("data1")) + require.NoError(t, err) + _, err = store.InsertGroupMessage(ctx, []byte("group"), []byte("data2")) + require.NoError(t, err) + _, err = store.InsertGroupMessage(ctx, []byte("group"), []byte("data3")) + require.NoError(t, err) + + msgs := make([]*GroupMessage, 0) + err = store.db.NewSelect().Model(&msgs).Order("created_at DESC").Scan(ctx) + require.NoError(t, err) + require.Len(t, msgs, 3) + require.Equal(t, []byte("data3"), msgs[0].Data) + require.Equal(t, []byte("data2"), msgs[1].Data) + require.Equal(t, []byte("data1"), msgs[2].Data) +} + +func TestInsertWelcomeMessage_Single(t *testing.T) { + started := time.Now().UTC().Add(-time.Minute) + store, cleanup := NewTestStore(t) + defer cleanup() + + ctx := context.Background() + msg, err := store.InsertWelcomeMessage(ctx, []byte("installation"), []byte("data")) + require.NoError(t, err) + require.NotNil(t, msg) + require.Equal(t, uint64(1), msg.Id) + require.True(t, msg.CreatedAt.Before(time.Now().UTC()) && msg.CreatedAt.After(started)) + require.Equal(t, []byte("installation"), msg.InstallationKey) + require.Equal(t, []byte("data"), msg.Data) + + msgs := make([]*WelcomeMessage, 0) + err = store.db.NewSelect().Model(&msgs).Scan(ctx) + require.NoError(t, err) + require.Len(t, msgs, 1) + require.Equal(t, msg, msgs[0]) +} + +func TestInsertWelcomeMessage_Duplicate(t *testing.T) { + store, cleanup := NewTestStore(t) + defer cleanup() + + ctx := context.Background() + msg, err := store.InsertWelcomeMessage(ctx, []byte("installation"), []byte("data")) + require.NoError(t, err) + require.NotNil(t, msg) + + msg, err = store.InsertWelcomeMessage(ctx, []byte("installation"), []byte("data")) + require.Nil(t, msg) + require.IsType(t, &AlreadyExistsError{}, err) + require.True(t, IsAlreadyExistsError(err)) +} + +func TestInsertWelcomeMessage_ManyOrderedByTime(t *testing.T) { + store, cleanup := NewTestStore(t) + defer cleanup() + + ctx := context.Background() + _, err := store.InsertWelcomeMessage(ctx, []byte("installation"), []byte("data1")) + require.NoError(t, err) + _, err = store.InsertWelcomeMessage(ctx, []byte("installation"), []byte("data2")) + require.NoError(t, err) + _, err = store.InsertWelcomeMessage(ctx, []byte("installation"), []byte("data3")) + require.NoError(t, err) + + msgs := make([]*WelcomeMessage, 0) + err = store.db.NewSelect().Model(&msgs).Order("created_at DESC").Scan(ctx) + require.NoError(t, err) + require.Len(t, msgs, 3) + require.Equal(t, []byte("data3"), msgs[0].Data) + require.Equal(t, []byte("data2"), msgs[1].Data) + require.Equal(t, []byte("data1"), msgs[2].Data) +} + +func TestQueryGroupMessagesV1_MissingGroup(t *testing.T) { + store, cleanup := NewTestStore(t) + defer cleanup() + + ctx := context.Background() + + resp, err := store.QueryGroupMessagesV1(ctx, &mlsv1.QueryGroupMessagesRequest{}) + require.EqualError(t, err, "group is required") + require.Nil(t, resp) +} + +func TestQueryWelcomeMessagesV1_MissingInstallation(t *testing.T) { + store, cleanup := NewTestStore(t) + defer cleanup() + + ctx := context.Background() + + resp, err := store.QueryWelcomeMessagesV1(ctx, &mlsv1.QueryWelcomeMessagesRequest{}) + require.EqualError(t, err, "installation is required") + require.Nil(t, resp) +} + +func TestQueryGroupMessagesV1_Filter(t *testing.T) { + store, cleanup := NewTestStore(t) + defer cleanup() + + ctx := context.Background() + _, err := store.InsertGroupMessage(ctx, []byte("group1"), []byte("data1")) + require.NoError(t, err) + _, err = store.InsertGroupMessage(ctx, []byte("group2"), []byte("data2")) + require.NoError(t, err) + _, err = store.InsertGroupMessage(ctx, []byte("group1"), []byte("data3")) + require.NoError(t, err) + + resp, err := store.QueryGroupMessagesV1(ctx, &mlsv1.QueryGroupMessagesRequest{ + GroupId: []byte("unknown"), + }) + require.NoError(t, err) + require.Len(t, resp.Messages, 0) + + resp, err = store.QueryGroupMessagesV1(ctx, &mlsv1.QueryGroupMessagesRequest{ + GroupId: []byte("group1"), + }) + require.NoError(t, err) + require.Len(t, resp.Messages, 2) + require.Equal(t, []byte("data3"), resp.Messages[0].GetV1().Data) + require.Equal(t, []byte("data1"), resp.Messages[1].GetV1().Data) + + resp, err = store.QueryGroupMessagesV1(ctx, &mlsv1.QueryGroupMessagesRequest{ + GroupId: []byte("group2"), + }) + require.NoError(t, err) + require.Len(t, resp.Messages, 1) + require.Equal(t, []byte("data2"), resp.Messages[0].GetV1().Data) + + // Sort ascending + resp, err = store.QueryGroupMessagesV1(ctx, &mlsv1.QueryGroupMessagesRequest{ + GroupId: []byte("group1"), + PagingInfo: &mlsv1.PagingInfo{ + Direction: mlsv1.SortDirection_SORT_DIRECTION_ASCENDING, + }, + }) + require.NoError(t, err) + require.Len(t, resp.Messages, 2) + require.Equal(t, []byte("data1"), resp.Messages[0].GetV1().Data) + require.Equal(t, []byte("data3"), resp.Messages[1].GetV1().Data) +} + +func TestQueryWelcomeMessagesV1_Filter(t *testing.T) { + store, cleanup := NewTestStore(t) + defer cleanup() + + ctx := context.Background() + _, err := store.InsertWelcomeMessage(ctx, []byte("installation1"), []byte("data1")) + require.NoError(t, err) + _, err = store.InsertWelcomeMessage(ctx, []byte("installation2"), []byte("data2")) + require.NoError(t, err) + _, err = store.InsertWelcomeMessage(ctx, []byte("installation1"), []byte("data3")) + require.NoError(t, err) + + resp, err := store.QueryWelcomeMessagesV1(ctx, &mlsv1.QueryWelcomeMessagesRequest{ + InstallationKey: []byte("unknown"), + }) + require.NoError(t, err) + require.Len(t, resp.Messages, 0) + + resp, err = store.QueryWelcomeMessagesV1(ctx, &mlsv1.QueryWelcomeMessagesRequest{ + InstallationKey: []byte("installation1"), + }) + require.NoError(t, err) + require.Len(t, resp.Messages, 2) + require.Equal(t, []byte("data3"), resp.Messages[0].GetV1().Data) + require.Equal(t, []byte("data1"), resp.Messages[1].GetV1().Data) + + resp, err = store.QueryWelcomeMessagesV1(ctx, &mlsv1.QueryWelcomeMessagesRequest{ + InstallationKey: []byte("installation2"), + }) + require.NoError(t, err) + require.Len(t, resp.Messages, 1) + require.Equal(t, []byte("data2"), resp.Messages[0].GetV1().Data) + + // Sort ascending + resp, err = store.QueryWelcomeMessagesV1(ctx, &mlsv1.QueryWelcomeMessagesRequest{ + InstallationKey: []byte("installation1"), + PagingInfo: &mlsv1.PagingInfo{ + Direction: mlsv1.SortDirection_SORT_DIRECTION_ASCENDING, + }, + }) + require.NoError(t, err) + require.Len(t, resp.Messages, 2) + require.Equal(t, []byte("data1"), resp.Messages[0].GetV1().Data) + require.Equal(t, []byte("data3"), resp.Messages[1].GetV1().Data) +} + +func TestQueryGroupMessagesV1_Paginate(t *testing.T) { + store, cleanup := NewTestStore(t) + defer cleanup() + + ctx := context.Background() + _, err := store.InsertGroupMessage(ctx, []byte("group1"), []byte("content1")) + require.NoError(t, err) + _, err = store.InsertGroupMessage(ctx, []byte("group2"), []byte("content2")) + require.NoError(t, err) + _, err = store.InsertGroupMessage(ctx, []byte("group1"), []byte("content3")) + require.NoError(t, err) + _, err = store.InsertGroupMessage(ctx, []byte("group2"), []byte("content4")) + require.NoError(t, err) + _, err = store.InsertGroupMessage(ctx, []byte("group1"), []byte("content5")) + require.NoError(t, err) + _, err = store.InsertGroupMessage(ctx, []byte("group1"), []byte("content6")) + require.NoError(t, err) + _, err = store.InsertGroupMessage(ctx, []byte("group1"), []byte("content7")) + require.NoError(t, err) + _, err = store.InsertGroupMessage(ctx, []byte("group1"), []byte("content8")) + require.NoError(t, err) + + resp, err := store.QueryGroupMessagesV1(ctx, &mlsv1.QueryGroupMessagesRequest{ + GroupId: []byte("group1"), + }) + require.NoError(t, err) + require.Len(t, resp.Messages, 6) + require.Equal(t, []byte("content8"), resp.Messages[0].GetV1().Data) + require.Equal(t, []byte("content7"), resp.Messages[1].GetV1().Data) + require.Equal(t, []byte("content6"), resp.Messages[2].GetV1().Data) + require.Equal(t, []byte("content5"), resp.Messages[3].GetV1().Data) + require.Equal(t, []byte("content3"), resp.Messages[4].GetV1().Data) + require.Equal(t, []byte("content1"), resp.Messages[5].GetV1().Data) + + thirdMsg := resp.Messages[2] + fifthMsg := resp.Messages[4] + + resp, err = store.QueryGroupMessagesV1(ctx, &mlsv1.QueryGroupMessagesRequest{ + GroupId: []byte("group1"), + PagingInfo: &mlsv1.PagingInfo{ + Limit: 2, + }, + }) + require.NoError(t, err) + require.Len(t, resp.Messages, 2) + require.Equal(t, []byte("content8"), resp.Messages[0].GetV1().Data) + require.Equal(t, []byte("content7"), resp.Messages[1].GetV1().Data) + + // Order descending by default + resp, err = store.QueryGroupMessagesV1(ctx, &mlsv1.QueryGroupMessagesRequest{ + GroupId: []byte("group1"), + PagingInfo: &mlsv1.PagingInfo{ + Limit: 2, + IdCursor: thirdMsg.GetV1().Id, + }, + }) + require.NoError(t, err) + require.Len(t, resp.Messages, 2) + require.Equal(t, []byte("content5"), resp.Messages[0].GetV1().Data) + require.Equal(t, []byte("content3"), resp.Messages[1].GetV1().Data) + + // Next page from previous response + resp, err = store.QueryGroupMessagesV1(ctx, &mlsv1.QueryGroupMessagesRequest{ + GroupId: []byte("group1"), + PagingInfo: resp.PagingInfo, + }) + require.NoError(t, err) + require.Len(t, resp.Messages, 1) + require.Equal(t, []byte("content1"), resp.Messages[0].GetV1().Data) + + // Order ascending + resp, err = store.QueryGroupMessagesV1(ctx, &mlsv1.QueryGroupMessagesRequest{ + GroupId: []byte("group1"), + PagingInfo: &mlsv1.PagingInfo{ + Limit: 2, + Direction: mlsv1.SortDirection_SORT_DIRECTION_ASCENDING, + IdCursor: fifthMsg.GetV1().Id, + }, + }) + require.NoError(t, err) + require.Len(t, resp.Messages, 2) + require.Equal(t, []byte("content5"), resp.Messages[0].GetV1().Data) + require.Equal(t, []byte("content6"), resp.Messages[1].GetV1().Data) + + // Next page from previous response + resp, err = store.QueryGroupMessagesV1(ctx, &mlsv1.QueryGroupMessagesRequest{ + GroupId: []byte("group1"), + PagingInfo: resp.PagingInfo, + }) + require.NoError(t, err) + require.Len(t, resp.Messages, 2) + require.Equal(t, []byte("content7"), resp.Messages[0].GetV1().Data) + require.Equal(t, []byte("content8"), resp.Messages[1].GetV1().Data) +} + +func TestQueryWelcomeMessagesV1_Paginate(t *testing.T) { + store, cleanup := NewTestStore(t) + defer cleanup() + + ctx := context.Background() + _, err := store.InsertWelcomeMessage(ctx, []byte("installation1"), []byte("content1")) + require.NoError(t, err) + _, err = store.InsertWelcomeMessage(ctx, []byte("installation2"), []byte("content2")) + require.NoError(t, err) + _, err = store.InsertWelcomeMessage(ctx, []byte("installation1"), []byte("content3")) + require.NoError(t, err) + _, err = store.InsertWelcomeMessage(ctx, []byte("installation2"), []byte("content4")) + require.NoError(t, err) + _, err = store.InsertWelcomeMessage(ctx, []byte("installation1"), []byte("content5")) + require.NoError(t, err) + _, err = store.InsertWelcomeMessage(ctx, []byte("installation1"), []byte("content6")) + require.NoError(t, err) + _, err = store.InsertWelcomeMessage(ctx, []byte("installation1"), []byte("content7")) + require.NoError(t, err) + _, err = store.InsertWelcomeMessage(ctx, []byte("installation1"), []byte("content8")) + require.NoError(t, err) + + resp, err := store.QueryWelcomeMessagesV1(ctx, &mlsv1.QueryWelcomeMessagesRequest{ + InstallationKey: []byte("installation1"), + }) + require.NoError(t, err) + require.Len(t, resp.Messages, 6) + require.Equal(t, []byte("content8"), resp.Messages[0].GetV1().Data) + require.Equal(t, []byte("content7"), resp.Messages[1].GetV1().Data) + require.Equal(t, []byte("content6"), resp.Messages[2].GetV1().Data) + require.Equal(t, []byte("content5"), resp.Messages[3].GetV1().Data) + require.Equal(t, []byte("content3"), resp.Messages[4].GetV1().Data) + require.Equal(t, []byte("content1"), resp.Messages[5].GetV1().Data) + + thirdMsg := resp.Messages[2] + fifthMsg := resp.Messages[4] + + resp, err = store.QueryWelcomeMessagesV1(ctx, &mlsv1.QueryWelcomeMessagesRequest{ + InstallationKey: []byte("installation1"), + PagingInfo: &mlsv1.PagingInfo{ + Limit: 2, + }, + }) + require.NoError(t, err) + require.Len(t, resp.Messages, 2) + require.Equal(t, []byte("content8"), resp.Messages[0].GetV1().Data) + require.Equal(t, []byte("content7"), resp.Messages[1].GetV1().Data) + + // Order descending by default + resp, err = store.QueryWelcomeMessagesV1(ctx, &mlsv1.QueryWelcomeMessagesRequest{ + InstallationKey: []byte("installation1"), + PagingInfo: &mlsv1.PagingInfo{ + Limit: 2, + IdCursor: thirdMsg.GetV1().Id, + }, + }) + require.NoError(t, err) + require.Len(t, resp.Messages, 2) + require.Equal(t, []byte("content5"), resp.Messages[0].GetV1().Data) + require.Equal(t, []byte("content3"), resp.Messages[1].GetV1().Data) + + // Next page from previous response + resp, err = store.QueryWelcomeMessagesV1(ctx, &mlsv1.QueryWelcomeMessagesRequest{ + InstallationKey: []byte("installation1"), + PagingInfo: resp.PagingInfo, + }) + require.NoError(t, err) + require.Len(t, resp.Messages, 1) + require.Equal(t, []byte("content1"), resp.Messages[0].GetV1().Data) + + // Order ascending + resp, err = store.QueryWelcomeMessagesV1(ctx, &mlsv1.QueryWelcomeMessagesRequest{ + InstallationKey: []byte("installation1"), + PagingInfo: &mlsv1.PagingInfo{ + Limit: 2, + Direction: mlsv1.SortDirection_SORT_DIRECTION_ASCENDING, + IdCursor: fifthMsg.GetV1().Id, + }, + }) + require.NoError(t, err) + require.Len(t, resp.Messages, 2) + require.Equal(t, []byte("content5"), resp.Messages[0].GetV1().Data) + require.Equal(t, []byte("content6"), resp.Messages[1].GetV1().Data) + + // Next page from previous response + resp, err = store.QueryWelcomeMessagesV1(ctx, &mlsv1.QueryWelcomeMessagesRequest{ + InstallationKey: []byte("installation1"), + PagingInfo: resp.PagingInfo, + }) + require.NoError(t, err) + require.Len(t, resp.Messages, 2) + require.Equal(t, []byte("content7"), resp.Messages[0].GetV1().Data) + require.Equal(t, []byte("content8"), resp.Messages[1].GetV1().Data) +} diff --git a/pkg/mlsvalidate/service.go b/pkg/mlsvalidate/service.go index 0391abb1..ed9f5b11 100644 --- a/pkg/mlsvalidate/service.go +++ b/pkg/mlsvalidate/service.go @@ -4,6 +4,7 @@ import ( "context" "fmt" + mlsv1 "github.com/xmtp/proto/v3/go/mls/api/v1" svc "github.com/xmtp/proto/v3/go/mls_validation/v1" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" @@ -11,7 +12,7 @@ import ( type IdentityValidationResult struct { AccountAddress string - InstallationId []byte + InstallationKey []byte CredentialIdentity []byte Expiration uint64 } @@ -27,7 +28,7 @@ type IdentityInput struct { type MLSValidationService interface { ValidateKeyPackages(ctx context.Context, keyPackages [][]byte) ([]IdentityValidationResult, error) - ValidateGroupMessages(ctx context.Context, groupMessages [][]byte) ([]GroupMessageValidationResult, error) + ValidateGroupMessages(ctx context.Context, groupMessages []*mlsv1.GroupMessageInput) ([]GroupMessageValidationResult, error) } type MLSValidationServiceImpl struct { @@ -57,7 +58,7 @@ func (s *MLSValidationServiceImpl) ValidateKeyPackages(ctx context.Context, keyP } out[i] = IdentityValidationResult{ AccountAddress: response.AccountAddress, - InstallationId: response.InstallationId, + InstallationKey: response.InstallationId, CredentialIdentity: response.CredentialIdentityBytes, Expiration: response.Expiration, } @@ -77,7 +78,7 @@ func makeValidateKeyPackageRequest(keyPackageBytes [][]byte) *svc.ValidateKeyPac } } -func (s *MLSValidationServiceImpl) ValidateGroupMessages(ctx context.Context, groupMessages [][]byte) ([]GroupMessageValidationResult, error) { +func (s *MLSValidationServiceImpl) ValidateGroupMessages(ctx context.Context, groupMessages []*mlsv1.GroupMessageInput) ([]GroupMessageValidationResult, error) { req := makeValidateGroupMessagesRequest(groupMessages) response, err := s.grpcClient.ValidateGroupMessages(ctx, req) @@ -98,11 +99,11 @@ func (s *MLSValidationServiceImpl) ValidateGroupMessages(ctx context.Context, gr return out, nil } -func makeValidateGroupMessagesRequest(groupMessages [][]byte) *svc.ValidateGroupMessagesRequest { +func makeValidateGroupMessagesRequest(groupMessages []*mlsv1.GroupMessageInput) *svc.ValidateGroupMessagesRequest { groupMessageRequests := make([]*svc.ValidateGroupMessagesRequest_GroupMessage, len(groupMessages)) for i, groupMessage := range groupMessages { groupMessageRequests[i] = &svc.ValidateGroupMessagesRequest_GroupMessage{ - GroupMessageBytesTlsSerialized: groupMessage, + GroupMessageBytesTlsSerialized: groupMessage.GetV1().Data, } } return &svc.ValidateGroupMessagesRequest{ diff --git a/pkg/mlsvalidate/service_test.go b/pkg/mlsvalidate/service_test.go index 33b5cb95..6e0a7332 100644 --- a/pkg/mlsvalidate/service_test.go +++ b/pkg/mlsvalidate/service_test.go @@ -56,7 +56,7 @@ func TestValidateKeyPackages(t *testing.T) { assert.NoError(t, err) assert.Equal(t, 1, len(res)) assert.Equal(t, "0x123", res[0].AccountAddress) - assert.Equal(t, []byte("123"), res[0].InstallationId) + assert.Equal(t, []byte("123"), res[0].InstallationKey) assert.Equal(t, []byte("456"), res[0].CredentialIdentity) } diff --git a/pkg/store/query_test.go b/pkg/store/query_test.go index d3c2c028..5d084e4a 100644 --- a/pkg/store/query_test.go +++ b/pkg/store/query_test.go @@ -1,7 +1,6 @@ package store import ( - "context" "testing" "time" @@ -285,33 +284,3 @@ func TestPageSizeOne(t *testing.T) { loops++ } } - -func TestMlsMessagePublish(t *testing.T) { - store, cleanup, _ := createAndFillDb(t) - defer cleanup() - - message := []byte{1, 2, 3} - contentTopic := "foo" - ctx := context.Background() - - env, err := store.InsertMLSMessage(ctx, contentTopic, message) - require.NoError(t, err) - - require.Equal(t, env.ContentTopic, contentTopic) - require.Equal(t, env.Message, message) - - response, err := store.Query(&messagev1.QueryRequest{ - ContentTopics: []string{contentTopic}, - }) - require.NoError(t, err) - require.Len(t, response.Envelopes, 1) - require.Equal(t, response.Envelopes[0].Message, message) - require.Equal(t, response.Envelopes[0].ContentTopic, contentTopic) - require.NotNil(t, response.Envelopes[0].TimestampNs) - - parsedTime := time.Unix(0, int64(response.Envelopes[0].TimestampNs)) - // Sanity check to ensure that the timestamps are reasonable - require.True(t, time.Since(parsedTime) < 10*time.Second || time.Since(parsedTime) > -10*time.Second) - - require.Equal(t, env.TimestampNs, response.Envelopes[0].TimestampNs) -} diff --git a/pkg/store/store.go b/pkg/store/store.go index 4a309044..71e17d9a 100644 --- a/pkg/store/store.go +++ b/pkg/store/store.go @@ -2,7 +2,6 @@ package store import ( "context" - "fmt" "strings" "sync" "time" @@ -21,22 +20,6 @@ import ( const maxPageSize = 100 -const timestampGeneratorSql = `( - ( - EXTRACT( - EPOCH - FROM - clock_timestamp() - ) :: bigint * 1000000000 - ) + ( - EXTRACT( - MICROSECONDS - FROM - clock_timestamp() - ) :: bigint * 1000 - ) -)` - type Store struct { config *Config ctx context.Context @@ -144,61 +127,6 @@ func (s *Store) InsertMessage(env *messagev1.Envelope) (bool, error) { return stored, err } -func (s *Store) InsertMLSMessage(ctx context.Context, contentTopic string, data []byte) (*messagev1.Envelope, error) { - tmpEnvelope := &messagev1.Envelope{ - ContentTopic: contentTopic, - Message: data, - } - digest := computeDigest(tmpEnvelope) - var envelope messagev1.Envelope - - err := tracing.Wrap(s.ctx, s.log, "storing mls message", func(ctx context.Context, log *zap.Logger, span tracing.Span) error { - tracing.SpanResource(span, "store") - tracing.SpanType(span, "db") - - stmnt := fmt.Sprintf(`INSERT INTO - message ( - id, - receiverTimestamp, - senderTimestamp, - contentTopic, - pubsubTopic, - payload, - version, - should_expire - ) - VALUES - ( - $1, - %s, - %s, - $2, - $3, - $4, - $5, - $6 - ) RETURNING senderTimestamp`, timestampGeneratorSql, timestampGeneratorSql) - - var senderTimestamp uint64 - err := s.config.DB.QueryRowContext(ctx, stmnt, digest, contentTopic, "", data, 0, false).Scan(&senderTimestamp) - if err != nil { - return err - } - envelope = messagev1.Envelope{ - ContentTopic: contentTopic, - TimestampNs: senderTimestamp, - Message: data, - } - return err - }) - - if err != nil { - return nil, err - } - - return &envelope, nil -} - func (s *Store) insertMessage(env *messagev1.Envelope, receiverTimestamp int64) error { digest := computeDigest(env) shouldExpire := !isXMTP(env.ContentTopic) diff --git a/pkg/topic/mls.go b/pkg/topic/mls.go new file mode 100644 index 00000000..33f3ecbf --- /dev/null +++ b/pkg/topic/mls.go @@ -0,0 +1,28 @@ +package topic + +import ( + "fmt" + "strings" +) + +const mlsv1Prefix = "/xmtp/mls/1/" + +func IsMLSV1(topic string) bool { + return strings.HasPrefix(topic, mlsv1Prefix) +} + +func IsMLSV1Group(topic string) bool { + return strings.HasPrefix(topic, mlsv1Prefix+"g-") +} + +func IsMLSV1Welcome(topic string) bool { + return strings.HasPrefix(topic, mlsv1Prefix+"w-") +} + +func BuildMLSV1GroupTopic(groupId []byte) string { + return fmt.Sprintf("%sg-%x/proto", mlsv1Prefix, groupId) +} + +func BuildMLSV1WelcomeTopic(installationId []byte) string { + return fmt.Sprintf("%sw-%x/proto", mlsv1Prefix, installationId) +} diff --git a/pkg/topic/topic.go b/pkg/topic/topic.go index 21b4f3f5..f2460045 100644 --- a/pkg/topic/topic.go +++ b/pkg/topic/topic.go @@ -1,7 +1,6 @@ package topic import ( - "fmt" "strings" ) @@ -38,11 +37,3 @@ func Category(contentTopic string) string { } return "invalid" } - -func BuildGroupTopic(groupId string) string { - return fmt.Sprintf("/xmtp/mls/1/g-%s/proto", groupId) -} - -func BuildWelcomeTopic(installationId []byte) string { - return fmt.Sprintf("/xmtp/mls/1/w-%x/proto", installationId) -} diff --git a/tools.go b/tools.go index 4f5979a8..5a194c14 100644 --- a/tools.go +++ b/tools.go @@ -4,7 +4,7 @@ package tools import ( - _ "github.com/golang/mock/mockgen" _ "github.com/yoheimuta/protolint/cmd/protolint" + _ "go.uber.org/mock/mockgen" _ "google.golang.org/protobuf/cmd/protoc-gen-go" )