diff --git a/dev/docker/env b/dev/docker/env index f4b6cead..c3440a8a 100755 --- a/dev/docker/env +++ b/dev/docker/env @@ -2,5 +2,5 @@ set -e function docker_compose() { - docker-compose -f dev/docker/docker-compose.yml -p xmtpd "$@" + docker compose -f dev/docker/docker-compose.yml -p xmtpd "$@" } diff --git a/dev/e2e/docker/env b/dev/e2e/docker/env index bb308135..5064cdbf 100755 --- a/dev/e2e/docker/env +++ b/dev/e2e/docker/env @@ -2,5 +2,5 @@ set -e function docker_compose() { - docker-compose -f dev/e2e/docker/docker-compose.yml -p xmtpd-e2e "$@" + docker compose -f dev/e2e/docker/docker-compose.yml -p xmtpd-e2e "$@" } diff --git a/dev/lint b/dev/lint index 13ae7d66..a01906bf 100755 --- a/dev/lint +++ b/dev/lint @@ -7,6 +7,6 @@ if [[ $(gofmt -l .) ]]; then echo "gofmt errors, run 'gofmt -w .' and commit" fi -golangci-lint --config dev/.golangci.yaml run ./... --deadline=5m +golangci-lint --config dev/.golangci.yaml run ./... protolint . diff --git a/pkg/api/interceptor.go b/pkg/api/interceptor.go index f99c051d..f96012d3 100644 --- a/pkg/api/interceptor.go +++ b/pkg/api/interceptor.go @@ -96,17 +96,17 @@ func (wa *WalletAuthorizer) requiresAuthorization(req interface{}) bool { func (wa *WalletAuthorizer) getWallet(ctx context.Context) (types.WalletAddr, error) { md, ok := metadata.FromIncomingContext(ctx) if !ok { - return "", status.Errorf(codes.Unauthenticated, "metadata is not provided") + return "", status.Error(codes.Unauthenticated, "metadata is not provided") } values := md.Get(authorizationMetadataKey) if len(values) == 0 { - return "", status.Errorf(codes.Unauthenticated, "authorization token is not provided") + return "", status.Error(codes.Unauthenticated, "authorization token is not provided") } words := strings.SplitN(values[0], " ", 2) if len(words) != 2 { - return "", status.Errorf(codes.Unauthenticated, "invalid authorization header") + return "", status.Error(codes.Unauthenticated, "invalid authorization header") } if scheme := strings.TrimSpace(words[0]); scheme != "Bearer" { return "", status.Errorf(codes.Unauthenticated, "unrecognized authorization scheme %s", scheme) @@ -127,14 +127,14 @@ func (wa *WalletAuthorizer) authorize(ctx context.Context, req interface{}, wall if pub, isPublish := req.(*messagev1.PublishRequest); isPublish { for _, env := range pub.Envelopes { if !wa.privilegedAddresses[wallet] && !allowedToPublish(env.ContentTopic, wallet) { - return status.Errorf(codes.PermissionDenied, "publishing to restricted topic") + return status.Error(codes.PermissionDenied, "publishing to restricted topic") } } } if wa.AllowLists { if wa.AllowLister.IsDenyListed(wallet.String()) { wa.Log.Debug("wallet deny listed", logging.WalletAddress(wallet.String())) - return status.Errorf(codes.PermissionDenied, ErrDenyListed.Error()) + return status.Error(codes.PermissionDenied, ErrDenyListed.Error()) } } return nil @@ -185,7 +185,8 @@ func (wa *WalletAuthorizer) applyLimits(ctx context.Context, fullMethod string, logging.String("method", method), logging.String("limit", string(limitType)), logging.Int("cost", cost)) - return status.Errorf(codes.ResourceExhausted, err.Error()) + + return status.Error(codes.ResourceExhausted, err.Error()) } const ( diff --git a/pkg/api/message/v1/service.go b/pkg/api/message/v1/service.go index 0c6eacfc..fff37b8e 100644 --- a/pkg/api/message/v1/service.go +++ b/pkg/api/message/v1/service.go @@ -130,17 +130,17 @@ func (s *Service) Publish(ctx context.Context, req *proto.PublishRequest) (*prot log.Debug("received message") if len(env.ContentTopic) > MaxContentTopicNameSize { - return nil, status.Errorf(codes.InvalidArgument, "topic length too big") + return nil, status.Error(codes.InvalidArgument, "topic length too big") } if len(env.Message) > MaxMessageSize { - return nil, status.Errorf(codes.InvalidArgument, "message too big") + return nil, status.Error(codes.InvalidArgument, "message too big") } if !topic.IsEphemeral(env.ContentTopic) { _, err := s.store.InsertMessage(env) if err != nil { - return nil, status.Errorf(codes.Internal, err.Error()) + return nil, status.Error(codes.Internal, err.Error()) } } @@ -150,7 +150,7 @@ func (s *Service) Publish(ctx context.Context, req *proto.PublishRequest) (*prot Payload: env.Message, }) if err != nil { - return nil, status.Errorf(codes.Internal, err.Error()) + return nil, status.Error(codes.Internal, err.Error()) } metrics.EmitPublishedEnvelope(ctx, log, env) @@ -393,7 +393,7 @@ func (s *Service) BatchQuery(ctx context.Context, req *proto.BatchQueryRequest) // We execute the query using the existing Query API resp, err := s.Query(ctx, query) if err != nil { - return nil, status.Errorf(codes.Internal, err.Error()) + return nil, status.Error(codes.Internal, err.Error()) } responses = append(responses, resp) } diff --git a/pkg/migrations/mls/20240814032323_remove-inbox-id-from-installation.down.sql b/pkg/migrations/mls/20240814032323_remove-inbox-id-from-installation.down.sql new file mode 100644 index 00000000..cc72f6bf --- /dev/null +++ b/pkg/migrations/mls/20240814032323_remove-inbox-id-from-installation.down.sql @@ -0,0 +1,7 @@ +SET statement_timeout = 0; + +--bun:split +ALTER TABLE installations + ADD COLUMN inbox_id BYTEA NOT NULL, + ADD COLUMN expiration BIGINT NOT NULL; + diff --git a/pkg/migrations/mls/20240814032323_remove-inbox-id-from-installation.up.sql b/pkg/migrations/mls/20240814032323_remove-inbox-id-from-installation.up.sql new file mode 100644 index 00000000..9b12bee1 --- /dev/null +++ b/pkg/migrations/mls/20240814032323_remove-inbox-id-from-installation.up.sql @@ -0,0 +1,7 @@ +SET statement_timeout = 0; + +--bun:split +ALTER TABLE installations + DROP COLUMN IF EXISTS inbox_id, + DROP COLUMN IF EXISTS expiration; + diff --git a/pkg/mls/api/v1/service.go b/pkg/mls/api/v1/service.go index 56c18c67..8dfafb3e 100644 --- a/pkg/mls/api/v1/service.go +++ b/pkg/mls/api/v1/service.go @@ -113,6 +113,11 @@ func (s *Service) HandleIncomingWakuRelayMessage(wakuMsg *wakupb.WakuMessage) er return nil } +/* +* +DEPRECATED: Use UploadKeyPackage instead +* +*/ func (s *Service) RegisterInstallation(ctx context.Context, req *mlsv1.RegisterInstallationRequest) (*mlsv1.RegisterInstallationResponse, error) { if err := validateRegisterInstallationRequest(req); err != nil { return nil, err @@ -126,9 +131,9 @@ func (s *Service) RegisterInstallation(ctx context.Context, req *mlsv1.RegisterI if len(results) != 1 { return nil, status.Errorf(codes.Internal, "unexpected number of results: %d", len(results)) } + installationKey := results[0].InstallationKey - credential := results[0].Credential - if err = s.store.CreateInstallation(ctx, installationKey, credential.InboxId, req.KeyPackage.KeyPackageTlsSerialized, results[0].Expiration); err != nil { + if err = s.store.CreateOrUpdateInstallation(ctx, installationKey, req.KeyPackage.KeyPackageTlsSerialized); err != nil { return nil, err } return &mlsv1.RegisterInstallationResponse{ @@ -152,7 +157,7 @@ func (s *Service) FetchKeyPackages(ctx context.Context, req *mlsv1.FetchKeyPacka idx, ok := keyPackageMap[string(installation.ID)] if !ok { - return nil, status.Errorf(codes.Internal, "could not find key package for installation") + return nil, status.Error(codes.Internal, "could not find key package for installation") } resPackages[idx] = &mlsv1.FetchKeyPackagesResponse_KeyPackage{ @@ -178,9 +183,8 @@ func (s *Service) UploadKeyPackage(ctx context.Context, req *mlsv1.UploadKeyPack } installationId := validationResults[0].InstallationKey - expiration := validationResults[0].Expiration - if err = s.store.UpdateKeyPackage(ctx, installationId, keyPackageBytes, expiration); err != nil { + if err = s.store.CreateOrUpdateInstallation(ctx, installationId, keyPackageBytes); err != nil { return nil, status.Errorf(codes.Internal, "failed to insert key packages: %s", err) } @@ -188,11 +192,11 @@ func (s *Service) UploadKeyPackage(ctx context.Context, req *mlsv1.UploadKeyPack } func (s *Service) RevokeInstallation(ctx context.Context, req *mlsv1.RevokeInstallationRequest) (*emptypb.Empty, error) { - return nil, status.Errorf(codes.Unimplemented, "unimplemented") + return nil, status.Error(codes.Unimplemented, "unimplemented") } func (s *Service) GetIdentityUpdates(ctx context.Context, req *mlsv1.GetIdentityUpdatesRequest) (res *mlsv1.GetIdentityUpdatesResponse, err error) { - return nil, status.Errorf(codes.Unimplemented, "unimplemented") + return nil, status.Error(codes.Unimplemented, "unimplemented") } func (s *Service) SendGroupMessages(ctx context.Context, req *mlsv1.SendGroupMessagesRequest) (res *emptypb.Empty, err error) { @@ -521,11 +525,11 @@ func buildNatsSubjectForWelcomeMessages(installationId []byte) string { func validateSendGroupMessagesRequest(req *mlsv1.SendGroupMessagesRequest) error { if req == nil || len(req.Messages) == 0 { - return status.Errorf(codes.InvalidArgument, "no group messages to send") + return status.Error(codes.InvalidArgument, "no group messages to send") } for _, input := range req.Messages { if input == nil || input.GetV1() == nil { - return status.Errorf(codes.InvalidArgument, "invalid group message") + return status.Error(codes.InvalidArgument, "invalid group message") } } return nil @@ -537,12 +541,12 @@ func validateSendWelcomeMessagesRequest(req *mlsv1.SendWelcomeMessagesRequest) e } for _, input := range req.Messages { if input == nil || input.GetV1() == nil { - return status.Errorf(codes.InvalidArgument, "invalid welcome message") + return status.Error(codes.InvalidArgument, "invalid welcome message") } v1 := input.GetV1() if len(v1.Data) == 0 || len(v1.InstallationKey) == 0 || len(v1.HpkePublicKey) == 0 { - return status.Errorf(codes.InvalidArgument, "invalid welcome message") + return status.Error(codes.InvalidArgument, "invalid welcome message") } } return nil @@ -550,24 +554,24 @@ func validateSendWelcomeMessagesRequest(req *mlsv1.SendWelcomeMessagesRequest) e func validateRegisterInstallationRequest(req *mlsv1.RegisterInstallationRequest) error { if req == nil || req.KeyPackage == nil { - return status.Errorf(codes.InvalidArgument, "no key package") + return status.Error(codes.InvalidArgument, "no key package") } return nil } func validateUploadKeyPackageRequest(req *mlsv1.UploadKeyPackageRequest) error { if req == nil || req.KeyPackage == nil { - return status.Errorf(codes.InvalidArgument, "no key package") + return status.Error(codes.InvalidArgument, "no key package") } return nil } func requireReadyToSend(groupId string, message []byte) error { if len(groupId) == 0 { - return status.Errorf(codes.InvalidArgument, "group id is empty") + return status.Error(codes.InvalidArgument, "group id is empty") } if len(message) == 0 { - return status.Errorf(codes.InvalidArgument, "message is empty") + return status.Error(codes.InvalidArgument, "message is empty") } return nil } diff --git a/pkg/mls/api/v1/service_test.go b/pkg/mls/api/v1/service_test.go index 479bdd78..a02615ad 100644 --- a/pkg/mls/api/v1/service_test.go +++ b/pkg/mls/api/v1/service_test.go @@ -81,13 +81,13 @@ func TestRegisterInstallation(t *testing.T) { defer cleanup() installationId := test.RandomBytes(32) - inboxId := test.RandomInboxId() + keyPackage := []byte("test") - mockValidateInboxIdKeyPackages(mlsValidationService, installationId, inboxId) + mockValidateInboxIdKeyPackages(mlsValidationService, installationId, test.RandomInboxId()) res, err := svc.RegisterInstallation(ctx, &mlsv1.RegisterInstallationRequest{ KeyPackage: &mlsv1.KeyPackageUpload{ - KeyPackageTlsSerialized: []byte("test"), + KeyPackageTlsSerialized: keyPackage, }, IsInboxIdCredential: false, }) @@ -98,7 +98,8 @@ func TestRegisterInstallation(t *testing.T) { installation, err := queries.New(mlsDb.DB).GetInstallation(ctx, installationId) require.NoError(t, err) - require.Equal(t, inboxId, installation.InboxID) + require.Equal(t, installationId, installation.ID) + require.Equal(t, []byte("test"), installation.KeyPackage) } func TestRegisterInstallationError(t *testing.T) { diff --git a/pkg/mls/store/queries.sql b/pkg/mls/store/queries.sql index 561c69ae..d54a7e45 100644 --- a/pkg/mls/store/queries.sql +++ b/pkg/mls/store/queries.sql @@ -83,33 +83,24 @@ WHERE (address, inbox_id, association_sequence_id) =( address, inbox_id); --- name: CreateInstallation :exec -INSERT INTO installations(id, created_at, updated_at, inbox_id, key_package, expiration) - VALUES (@id, @created_at, @updated_at, decode(@inbox_id, 'hex'), @key_package, @expiration); +-- name: CreateOrUpdateInstallation :exec +INSERT INTO installations(id, created_at, updated_at, key_package) + VALUES (@id, @created_at, @updated_at, @key_package) +ON CONFLICT (id) + DO UPDATE SET + key_package = @key_package, updated_at = @updated_at; -- name: GetInstallation :one SELECT id, created_at, updated_at, - encode(inbox_id, 'hex') AS inbox_id, - key_package, - expiration + key_package FROM installations WHERE id = $1; --- name: UpdateKeyPackage :execrows -UPDATE - installations -SET - key_package = @key_package, - updated_at = @updated_at, - expiration = @expiration -WHERE - id = @id; - -- name: FetchKeyPackages :many SELECT id, diff --git a/pkg/mls/store/queries/models.go b/pkg/mls/store/queries/models.go index 462220c6..3943a55a 100644 --- a/pkg/mls/store/queries/models.go +++ b/pkg/mls/store/queries/models.go @@ -35,9 +35,7 @@ type Installation struct { ID []byte CreatedAt int64 UpdatedAt int64 - InboxID []byte KeyPackage []byte - Expiration int64 } type WelcomeMessage struct { diff --git a/pkg/mls/store/queries/queries.sql.go b/pkg/mls/store/queries/queries.sql.go index 33428a08..90b51b14 100644 --- a/pkg/mls/store/queries/queries.sql.go +++ b/pkg/mls/store/queries/queries.sql.go @@ -13,28 +13,27 @@ import ( "github.com/lib/pq" ) -const createInstallation = `-- name: CreateInstallation :exec -INSERT INTO installations(id, created_at, updated_at, inbox_id, key_package, expiration) - VALUES ($1, $2, $3, decode($4, 'hex'), $5, $6) +const createOrUpdateInstallation = `-- name: CreateOrUpdateInstallation :exec +INSERT INTO installations(id, created_at, updated_at, key_package) + VALUES ($1, $2, $3, $4) +ON CONFLICT (id) + DO UPDATE SET + key_package = $4, updated_at = $3 ` -type CreateInstallationParams struct { +type CreateOrUpdateInstallationParams struct { ID []byte CreatedAt int64 UpdatedAt int64 - InboxID string KeyPackage []byte - Expiration int64 } -func (q *Queries) CreateInstallation(ctx context.Context, arg CreateInstallationParams) error { - _, err := q.db.ExecContext(ctx, createInstallation, +func (q *Queries) CreateOrUpdateInstallation(ctx context.Context, arg CreateOrUpdateInstallationParams) error { + _, err := q.db.ExecContext(ctx, createOrUpdateInstallation, arg.ID, arg.CreatedAt, arg.UpdatedAt, - arg.InboxID, arg.KeyPackage, - arg.Expiration, ) return err } @@ -305,34 +304,21 @@ SELECT id, created_at, updated_at, - encode(inbox_id, 'hex') AS inbox_id, - key_package, - expiration + key_package FROM installations WHERE id = $1 ` -type GetInstallationRow struct { - ID []byte - CreatedAt int64 - UpdatedAt int64 - InboxID string - KeyPackage []byte - Expiration int64 -} - -func (q *Queries) GetInstallation(ctx context.Context, id []byte) (GetInstallationRow, error) { +func (q *Queries) GetInstallation(ctx context.Context, id []byte) (Installation, error) { row := q.db.QueryRowContext(ctx, getInstallation, id) - var i GetInstallationRow + var i Installation err := row.Scan( &i.ID, &i.CreatedAt, &i.UpdatedAt, - &i.InboxID, &i.KeyPackage, - &i.Expiration, ) return i, err } @@ -786,34 +772,3 @@ func (q *Queries) RevokeAddressFromLog(ctx context.Context, arg RevokeAddressFro _, err := q.db.ExecContext(ctx, revokeAddressFromLog, arg.RevocationSequenceID, arg.Address, arg.InboxID) return err } - -const updateKeyPackage = `-- name: UpdateKeyPackage :execrows -UPDATE - installations -SET - key_package = $1, - updated_at = $2, - expiration = $3 -WHERE - id = $4 -` - -type UpdateKeyPackageParams struct { - KeyPackage []byte - UpdatedAt int64 - Expiration int64 - ID []byte -} - -func (q *Queries) UpdateKeyPackage(ctx context.Context, arg UpdateKeyPackageParams) (int64, error) { - result, err := q.db.ExecContext(ctx, updateKeyPackage, - arg.KeyPackage, - arg.UpdatedAt, - arg.Expiration, - arg.ID, - ) - if err != nil { - return 0, err - } - return result.RowsAffected() -} diff --git a/pkg/mls/store/store.go b/pkg/mls/store/store.go index 0aa62af4..75cfec87 100644 --- a/pkg/mls/store/store.go +++ b/pkg/mls/store/store.go @@ -39,8 +39,7 @@ type IdentityStore interface { type MlsStore interface { IdentityStore - CreateInstallation(ctx context.Context, installationId []byte, inboxId string, keyPackage []byte, expiration uint64) error - UpdateKeyPackage(ctx context.Context, installationId, keyPackage []byte, expiration uint64) error + CreateOrUpdateInstallation(ctx context.Context, installationId []byte, keyPackage []byte) error FetchKeyPackages(ctx context.Context, installationIds [][]byte) ([]queries.FetchKeyPackagesRow, error) InsertGroupMessage(ctx context.Context, groupId []byte, data []byte) (*queries.GroupMessage, error) InsertWelcomeMessage(ctx context.Context, installationId []byte, data []byte, hpkePublicKey []byte) (*queries.WelcomeMessage, error) @@ -246,38 +245,17 @@ func (s *Store) GetInboxLogs(ctx context.Context, batched_req *identity.GetIdent } // Creates the installation and last resort key package -func (s *Store) CreateInstallation(ctx context.Context, installationId []byte, inboxId string, keyPackage []byte, expiration uint64) error { - createdAt := nowNs() +func (s *Store) CreateOrUpdateInstallation(ctx context.Context, installationId []byte, keyPackage []byte) error { + now := nowNs() - return s.queries.CreateInstallation(ctx, queries.CreateInstallationParams{ + return s.queries.CreateOrUpdateInstallation(ctx, queries.CreateOrUpdateInstallationParams{ ID: installationId, - CreatedAt: createdAt, - InboxID: inboxId, + CreatedAt: now, + UpdatedAt: now, KeyPackage: keyPackage, - Expiration: int64(expiration), }) } -// Insert a new key package, ignoring any that may already exist -func (s *Store) UpdateKeyPackage(ctx context.Context, installationId, keyPackage []byte, expiration uint64) error { - rowsUpdated, err := s.queries.UpdateKeyPackage(ctx, queries.UpdateKeyPackageParams{ - ID: installationId, - UpdatedAt: nowNs(), - KeyPackage: keyPackage, - Expiration: int64(expiration), - }) - - if err != nil { - return err - } - - if rowsUpdated == 0 { - return errors.New("installation id unknown") - } - - return nil -} - func (s *Store) FetchKeyPackages(ctx context.Context, installationIds [][]byte) ([]queries.FetchKeyPackagesRow, error) { return s.queries.FetchKeyPackages(ctx, installationIds) } diff --git a/pkg/mls/store/store_test.go b/pkg/mls/store/store_test.go index 1ff9f73d..8cc32ac2 100644 --- a/pkg/mls/store/store_test.go +++ b/pkg/mls/store/store_test.go @@ -169,9 +169,8 @@ func TestCreateInstallation(t *testing.T) { ctx := context.Background() installationId := test.RandomBytes(32) - inboxId := test.RandomInboxId() - err := store.CreateInstallation(ctx, installationId, inboxId, test.RandomBytes(32), 0) + err := store.CreateOrUpdateInstallation(ctx, installationId, test.RandomBytes(32)) require.NoError(t, err) installationFromDb, err := store.queries.GetInstallation(ctx, installationId) @@ -185,21 +184,23 @@ func TestUpdateKeyPackage(t *testing.T) { ctx := context.Background() installationId := test.RandomBytes(32) - inboxId := test.RandomInboxId() keyPackage := test.RandomBytes(32) - err := store.CreateInstallation(ctx, installationId, inboxId, keyPackage, 0) + err := store.CreateOrUpdateInstallation(ctx, installationId, keyPackage) + require.NoError(t, err) + afterCreate, err := store.queries.GetInstallation(ctx, installationId) require.NoError(t, err) keyPackage2 := test.RandomBytes(32) - err = store.UpdateKeyPackage(ctx, installationId, keyPackage2, 1) + err = store.CreateOrUpdateInstallation(ctx, installationId, keyPackage2) require.NoError(t, err) installationFromDb, err := store.queries.GetInstallation(ctx, installationId) require.NoError(t, err) require.Equal(t, keyPackage2, installationFromDb.KeyPackage) - require.Equal(t, int64(1), installationFromDb.Expiration) + require.Greater(t, installationFromDb.UpdatedAt, afterCreate.UpdatedAt) + require.Equal(t, installationFromDb.CreatedAt, afterCreate.CreatedAt) } func TestConsumeLastResortKeyPackage(t *testing.T) { @@ -209,9 +210,8 @@ func TestConsumeLastResortKeyPackage(t *testing.T) { ctx := context.Background() installationId := test.RandomBytes(32) keyPackage := test.RandomBytes(32) - inboxId := test.RandomInboxId() - err := store.CreateInstallation(ctx, installationId, inboxId, keyPackage, 0) + err := store.CreateOrUpdateInstallation(ctx, installationId, keyPackage) require.NoError(t, err) fetchResult, err := store.FetchKeyPackages(ctx, [][]byte{installationId}) diff --git a/pkg/mlsvalidate/service.go b/pkg/mlsvalidate/service.go index da16820a..af03ba84 100644 --- a/pkg/mlsvalidate/service.go +++ b/pkg/mlsvalidate/service.go @@ -87,11 +87,6 @@ func (s *MLSValidationServiceImpl) GetAssociationState(ctx context.Context, oldU }, nil } -type KeyAndExpiration struct { - InstallationId []byte - Expiration uint64 -} - func (s *MLSValidationServiceImpl) ValidateInboxIdKeyPackages(ctx context.Context, keyPackages [][]byte) ([]InboxIdValidationResult, error) { req := makeValidateKeyPackageRequest(keyPackages, true) @@ -100,69 +95,20 @@ func (s *MLSValidationServiceImpl) ValidateInboxIdKeyPackages(ctx context.Contex return nil, err } - keyPackageCredential := make(map[string]KeyAndExpiration, len(response.Responses)) - identityUpdatesRequest := make([]*identity.GetIdentityUpdatesRequest_Request, len(response.Responses)) - for i, response := range response.Responses { - if !response.IsOk { - return nil, fmt.Errorf("validation failed with error %s", response.ErrorMessage) - } - keyPackageCredential[response.Credential.InboxId] = KeyAndExpiration{ - InstallationId: response.InstallationPublicKey, - Expiration: response.Expiration, - } - identityUpdatesRequest[i] = &identity.GetIdentityUpdatesRequest_Request{ - InboxId: response.Credential.InboxId, - SequenceId: 0, - } - } - - // TODO: do we need to take sequence ID Into account? - request := &identity.GetIdentityUpdatesRequest{Requests: identityUpdatesRequest} - identity_updates, err := s.identityStore.GetInboxLogs(ctx, request) - if err != nil { - return nil, err - } - - validation_requests := make([]*svc.ValidateInboxIdsRequest_ValidationRequest, len(identity_updates.Responses)) - for i, response := range identity_updates.Responses { - validation_requests[i] = makeValidationRequest(response, keyPackageCredential) - } - - validation_request := svc.ValidateInboxIdsRequest{Requests: validation_requests} - validate_inbox_response, err := s.grpcClient.ValidateInboxIds(ctx, &validation_request) - if err != nil { - return nil, err - } - out := make([]InboxIdValidationResult, len(response.Responses)) - for i, response := range validate_inbox_response.Responses { + for i, response := range response.Responses { if !response.IsOk { return nil, fmt.Errorf("validation failed with error %s", response.ErrorMessage) } out[i] = InboxIdValidationResult{ - InstallationKey: keyPackageCredential[response.InboxId].InstallationId, - Credential: &identity_proto.MlsCredential{InboxId: response.InboxId}, - Expiration: keyPackageCredential[response.InboxId].Expiration, + InstallationKey: response.InstallationPublicKey, + Credential: nil, + Expiration: response.Expiration, } } return out, nil } -func makeValidationRequest(update *identity.GetIdentityUpdatesResponse_Response, pub_keys map[string]KeyAndExpiration) *svc.ValidateInboxIdsRequest_ValidationRequest { - identity_updates := make([]*associations.IdentityUpdate, len(update.Updates)) - for i, identity_log := range update.Updates { - identity_updates[i] = identity_log.Update - } - - out := svc.ValidateInboxIdsRequest_ValidationRequest{ - Credential: &identity_proto.MlsCredential{InboxId: update.InboxId}, - InstallationPublicKey: pub_keys[update.InboxId].InstallationId, - IdentityUpdates: identity_updates, - } - - return &out -} - func (s *MLSValidationServiceImpl) ValidateV3KeyPackages(ctx context.Context, keyPackages [][]byte) ([]IdentityValidationResult, error) { req := makeValidateKeyPackageRequest(keyPackages, false) diff --git a/pkg/mlsvalidate/service_test.go b/pkg/mlsvalidate/service_test.go index b0bf06ca..15bcbcfa 100644 --- a/pkg/mlsvalidate/service_test.go +++ b/pkg/mlsvalidate/service_test.go @@ -76,6 +76,37 @@ func TestValidateKeyPackages(t *testing.T) { assert.Equal(t, []byte("456"), res[0].CredentialIdentity) } -func TestValidateKeyPackagesError(t *testing.T) { +func TestValidateInboxIdKeyPackages(t *testing.T) { + mockGrpc, service := getMockedService() + + ctx := context.Background() + installationKey := []byte("key") + firstResponse := svc.ValidateInboxIdKeyPackagesResponse_Response{ + IsOk: true, + Credential: nil, + InstallationPublicKey: installationKey, + ErrorMessage: "", + } + mockGrpc.On("ValidateInboxIdKeyPackages", ctx, mock.Anything).Return(&svc.ValidateInboxIdKeyPackagesResponse{Responses: []*svc.ValidateInboxIdKeyPackagesResponse_Response{&firstResponse}}, nil) + + res, err := service.ValidateInboxIdKeyPackages(ctx, [][]byte{[]byte("123")}) + assert.NoError(t, err) + assert.Equal(t, res[0].InstallationKey, installationKey) +} + +func TestValidateInboxIdKeyPackagesError(t *testing.T) { + mockGrpc, service := getMockedService() + + ctx := context.Background() + firstResponse := svc.ValidateInboxIdKeyPackagesResponse_Response{ + IsOk: false, + Credential: nil, + InstallationPublicKey: []byte("foo"), + ErrorMessage: "DERP", + } + mockGrpc.On("ValidateInboxIdKeyPackages", ctx, mock.Anything).Return(&svc.ValidateInboxIdKeyPackagesResponse{Responses: []*svc.ValidateInboxIdKeyPackagesResponse_Response{&firstResponse}}, nil) + res, err := service.ValidateInboxIdKeyPackages(ctx, [][]byte{[]byte("123")}) + assert.Error(t, err) + assert.Nil(t, res) }