Skip to content

Commit

Permalink
Assume KeyPackages are always last resort. (#329)
Browse files Browse the repository at this point in the history
* Assume KeyPackages are always last resort.

* Missing KeyPackages don't cause entire request to fail.

* Update go.mod

* UpdateKeyPackage returns an error if installation is unknown.
  • Loading branch information
Bren2010 authored Jan 9, 2024
1 parent 0853311 commit 0210936
Show file tree
Hide file tree
Showing 11 changed files with 124 additions and 276 deletions.
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ require (
github.com/uptrace/bun/driver/pgdriver v1.1.16
github.com/waku-org/go-waku v0.8.0
github.com/xmtp/go-msgio v0.2.1-0.20220510223757-25a701b79cd3
github.com/xmtp/proto/v3 v3.36.1-0.20231219054634-2ff03b7d5090
github.com/xmtp/proto/v3 v3.36.1
github.com/yoheimuta/protolint v0.39.0
go.uber.org/zap v1.24.0
golang.org/x/sync v0.3.0
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -1146,8 +1146,8 @@ github.com/xdg/stringprep v1.0.0/go.mod h1:Jhud4/sHMO4oL310DaZAKk9ZaJ08SJfe+sJh0
github.com/xlab/treeprint v0.0.0-20180616005107-d6fb6747feb6/go.mod h1:ce1O1j6UtZfjr22oyGxGLbauSBp2YVXpARAosm7dHBg=
github.com/xmtp/go-msgio v0.2.1-0.20220510223757-25a701b79cd3 h1:wzUffJGCTBGXIDyNU+1UBu1fn2Nzo+OQzM1pLrheh58=
github.com/xmtp/go-msgio v0.2.1-0.20220510223757-25a701b79cd3/go.mod h1:bJREWk+NDnZYjgLQdAi8SUWuq/5pkMme4GqiffEhUF4=
github.com/xmtp/proto/v3 v3.36.1-0.20231219054634-2ff03b7d5090 h1:+0KTgQiUfu5UxgLjP18VL4BtG6hJMJYL0n1mVXtf3Ss=
github.com/xmtp/proto/v3 v3.36.1-0.20231219054634-2ff03b7d5090/go.mod h1:NF2zAjtNpVIhS4tFG19g4L1tJcPZHm81oeDFXltmOiY=
github.com/xmtp/proto/v3 v3.36.1 h1:eBUsWlA/jbfhxfbDtbAofBpi8Q+TIXPMz84e64o1XTE=
github.com/xmtp/proto/v3 v3.36.1/go.mod h1:NF2zAjtNpVIhS4tFG19g4L1tJcPZHm81oeDFXltmOiY=
github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673 h1:bAn7/zixMGCfxrRTfdpNzjtPYqr8smhKouy9mxVdGPU=
github.com/yoheimuta/go-protoparser/v4 v4.6.0 h1:uvz1e9/5Ihsm4Ku8AJeDImTpirKmIxubZdSn0QJNdnw=
github.com/yoheimuta/go-protoparser/v4 v4.6.0/go.mod h1:AHNNnSWnb0UoL4QgHPiOAg2BniQceFscPI5X/BZNHl8=
Expand Down
54 changes: 24 additions & 30 deletions pkg/api/message/v3/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ func (s *Service) RegisterInstallation(ctx context.Context, req *proto.RegisterI
return nil, err
}

results, err := s.validationService.ValidateKeyPackages(ctx, [][]byte{req.LastResortKeyPackage.KeyPackageTlsSerialized})
results, err := s.validationService.ValidateKeyPackages(ctx, [][]byte{req.KeyPackage.KeyPackageTlsSerialized})
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid identity: %s", err)
}
Expand All @@ -57,7 +57,7 @@ func (s *Service) RegisterInstallation(ctx context.Context, req *proto.RegisterI
accountAddress := results[0].AccountAddress
credentialIdentity := results[0].CredentialIdentity

if err = s.mlsStore.CreateInstallation(ctx, installationId, accountAddress, req.LastResortKeyPackage.KeyPackageTlsSerialized, credentialIdentity); err != nil {
if err = s.mlsStore.CreateInstallation(ctx, installationId, accountAddress, credentialIdentity, req.KeyPackage.KeyPackageTlsSerialized, results[0].Expiration); err != nil {
return nil, err
}

Expand All @@ -66,31 +66,31 @@ func (s *Service) RegisterInstallation(ctx context.Context, req *proto.RegisterI
}, nil
}

func (s *Service) ConsumeKeyPackages(ctx context.Context, req *proto.ConsumeKeyPackagesRequest) (*proto.ConsumeKeyPackagesResponse, error) {
func (s *Service) FetchKeyPackages(ctx context.Context, req *proto.FetchKeyPackagesRequest) (*proto.FetchKeyPackagesResponse, error) {
ids := req.InstallationIds
keyPackages, err := s.mlsStore.ConsumeKeyPackages(ctx, ids)
installations, err := s.mlsStore.FetchKeyPackages(ctx, ids)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to consume key packages: %s", err)
return nil, status.Errorf(codes.Internal, "failed to fetch key packages: %s", err)
}
keyPackageMap := make(map[string]int)
for idx, id := range ids {
keyPackageMap[string(id)] = idx
}

resPackages := make([]*proto.ConsumeKeyPackagesResponse_KeyPackage, len(keyPackages))
for _, keyPackage := range keyPackages {
resPackages := make([]*proto.FetchKeyPackagesResponse_KeyPackage, len(ids))
for _, installation := range installations {

idx, ok := keyPackageMap[string(keyPackage.InstallationId)]
idx, ok := keyPackageMap[string(installation.ID)]
if !ok {
return nil, status.Errorf(codes.Internal, "could not find key package for installation")
}

resPackages[idx] = &proto.ConsumeKeyPackagesResponse_KeyPackage{
KeyPackageTlsSerialized: keyPackage.Data,
resPackages[idx] = &proto.FetchKeyPackagesResponse_KeyPackage{
KeyPackageTlsSerialized: installation.KeyPackage,
}
}

return &proto.ConsumeKeyPackagesResponse{
return &proto.FetchKeyPackagesResponse{
KeyPackages: resPackages,
}, nil
}
Expand Down Expand Up @@ -166,28 +166,22 @@ func (s *Service) PublishWelcomes(ctx context.Context, req *proto.PublishWelcome
return &emptypb.Empty{}, nil
}

func (s *Service) UploadKeyPackages(ctx context.Context, req *proto.UploadKeyPackagesRequest) (res *emptypb.Empty, err error) {
if err = validateUploadKeyPackagesRequest(req); err != nil {
func (s *Service) UploadKeyPackage(ctx context.Context, req *proto.UploadKeyPackageRequest) (res *emptypb.Empty, err error) {
if err = validateUploadKeyPackageRequest(req); err != nil {
return nil, err
}
// Extract the key packages from the request
keyPackageBytes := make([][]byte, len(req.KeyPackages))
for i, keyPackage := range req.KeyPackages {
keyPackageBytes[i] = keyPackage.KeyPackageTlsSerialized
}
validationResults, err := s.validationService.ValidateKeyPackages(ctx, keyPackageBytes)
keyPackageBytes := req.KeyPackage.KeyPackageTlsSerialized

validationResults, err := s.validationService.ValidateKeyPackages(ctx, [][]byte{keyPackageBytes})
if err != nil {
// TODO: Differentiate between validation errors and internal errors
return nil, status.Errorf(codes.InvalidArgument, "invalid identity: %s", err)
}
installationId := validationResults[0].InstallationId
expiration := validationResults[0].Expiration

keyPackageModels := make([]*mlsstore.KeyPackage, len(validationResults))
for i, validationResult := range validationResults {
kp := mlsstore.NewKeyPackage(validationResult.InstallationId, keyPackageBytes[i], false)
keyPackageModels[i] = kp
}

if err = s.mlsStore.InsertKeyPackages(ctx, keyPackageModels); err != nil {
if err = s.mlsStore.UpdateKeyPackage(ctx, installationId, keyPackageBytes, expiration); err != nil {
return nil, status.Errorf(codes.Internal, "failed to insert key packages: %s", err)
}

Expand Down Expand Up @@ -275,15 +269,15 @@ func validatePublishWelcomesRequest(req *proto.PublishWelcomesRequest) error {
}

func validateRegisterInstallationRequest(req *proto.RegisterInstallationRequest) error {
if req == nil || req.LastResortKeyPackage == nil {
return status.Errorf(codes.InvalidArgument, "no last resort key package")
if req == nil || req.KeyPackage == nil {
return status.Errorf(codes.InvalidArgument, "no key package")
}
return nil
}

func validateUploadKeyPackagesRequest(req *proto.UploadKeyPackagesRequest) error {
if req == nil || len(req.KeyPackages) == 0 {
return status.Errorf(codes.InvalidArgument, "no key packages to upload")
func validateUploadKeyPackageRequest(req *proto.UploadKeyPackageRequest) error {
if req == nil || req.KeyPackage == nil {
return status.Errorf(codes.InvalidArgument, "no key package")
}
return nil
}
Expand Down
44 changes: 22 additions & 22 deletions pkg/api/message/v3/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ func (m *mockedMLSValidationService) mockValidateKeyPackages(installationId []by
InstallationId: installationId,
AccountAddress: accountAddress,
CredentialIdentity: []byte("test"),
Expiration: 0,
},
}, nil)
}
Expand Down Expand Up @@ -102,7 +103,7 @@ func TestRegisterInstallation(t *testing.T) {
mlsValidationService.mockValidateKeyPackages(installationId, accountAddress)

res, err := svc.RegisterInstallation(ctx, &proto.RegisterInstallationRequest{
LastResortKeyPackage: &proto.KeyPackageUpload{
KeyPackage: &proto.KeyPackageUpload{
KeyPackageTlsSerialized: []byte("test"),
},
})
Expand All @@ -126,15 +127,15 @@ func TestRegisterInstallationError(t *testing.T) {
mlsValidationService.On("ValidateKeyPackages", ctx, mock.Anything).Return(nil, errors.New("error validating"))

res, err := svc.RegisterInstallation(ctx, &proto.RegisterInstallationRequest{
LastResortKeyPackage: &proto.KeyPackageUpload{
KeyPackage: &proto.KeyPackageUpload{
KeyPackageTlsSerialized: []byte("test"),
},
})
require.Error(t, err)
require.Nil(t, res)
}

func TestUploadKeyPackages(t *testing.T) {
func TestUploadKeyPackage(t *testing.T) {
ctx := context.Background()
svc, mlsDb, mlsValidationService, cleanup := newTestService(t, ctx)
defer cleanup()
Expand All @@ -145,28 +146,27 @@ func TestUploadKeyPackages(t *testing.T) {
mlsValidationService.mockValidateKeyPackages(installationId, accountAddress)

res, err := svc.RegisterInstallation(ctx, &proto.RegisterInstallationRequest{
LastResortKeyPackage: &proto.KeyPackageUpload{
KeyPackage: &proto.KeyPackageUpload{
KeyPackageTlsSerialized: []byte("test"),
},
})
require.NoError(t, err)
require.NotNil(t, res)

uploadRes, err := svc.UploadKeyPackages(ctx, &proto.UploadKeyPackagesRequest{
KeyPackages: []*proto.KeyPackageUpload{
{KeyPackageTlsSerialized: []byte("test2")},
uploadRes, err := svc.UploadKeyPackage(ctx, &proto.UploadKeyPackageRequest{
KeyPackage: &proto.KeyPackageUpload{
KeyPackageTlsSerialized: []byte("test2"),
},
})
require.NoError(t, err)
require.NotNil(t, uploadRes)

keyPackages := []mlsstore.KeyPackage{}
err = mlsDb.NewSelect().Model(&keyPackages).Where("installation_id = ?", installationId).Scan(ctx)
installation := &mlsstore.Installation{}
err = mlsDb.NewSelect().Model(installation).Where("id = ?", installationId).Scan(ctx)
require.NoError(t, err)
require.Len(t, keyPackages, 2)
}

func TestConsumeKeyPackages(t *testing.T) {
func TestFetchKeyPackages(t *testing.T) {
ctx := context.Background()
svc, _, mlsValidationService, cleanup := newTestService(t, ctx)
defer cleanup()
Expand All @@ -177,7 +177,7 @@ func TestConsumeKeyPackages(t *testing.T) {
mockCall := mlsValidationService.mockValidateKeyPackages(installationId1, accountAddress1)

res, err := svc.RegisterInstallation(ctx, &proto.RegisterInstallationRequest{
LastResortKeyPackage: &proto.KeyPackageUpload{
KeyPackage: &proto.KeyPackageUpload{
KeyPackageTlsSerialized: []byte("test"),
},
})
Expand All @@ -192,14 +192,14 @@ func TestConsumeKeyPackages(t *testing.T) {
mlsValidationService.mockValidateKeyPackages(installationId2, accountAddress2)

res, err = svc.RegisterInstallation(ctx, &proto.RegisterInstallationRequest{
LastResortKeyPackage: &proto.KeyPackageUpload{
KeyPackage: &proto.KeyPackageUpload{
KeyPackageTlsSerialized: []byte("test2"),
},
})
require.NoError(t, err)
require.NotNil(t, res)

consumeRes, err := svc.ConsumeKeyPackages(ctx, &proto.ConsumeKeyPackagesRequest{
consumeRes, err := svc.FetchKeyPackages(ctx, &proto.FetchKeyPackagesRequest{
InstallationIds: [][]byte{installationId1, installationId2},
})
require.NoError(t, err)
Expand All @@ -209,7 +209,7 @@ func TestConsumeKeyPackages(t *testing.T) {
require.Equal(t, []byte("test2"), consumeRes.KeyPackages[1].KeyPackageTlsSerialized)

// Now do it with the installationIds reversed
consumeRes, err = svc.ConsumeKeyPackages(ctx, &proto.ConsumeKeyPackagesRequest{
consumeRes, err = svc.FetchKeyPackages(ctx, &proto.FetchKeyPackagesRequest{
InstallationIds: [][]byte{installationId2, installationId1},
})

Expand All @@ -220,17 +220,17 @@ func TestConsumeKeyPackages(t *testing.T) {
require.Equal(t, []byte("test"), consumeRes.KeyPackages[1].KeyPackageTlsSerialized)
}

// Trying to consume key packages that don't exist should fail
func TestConsumeKeyPackagesFail(t *testing.T) {
// Trying to fetch key packages that don't exist should return nil
func TestFetchKeyPackagesFail(t *testing.T) {
ctx := context.Background()
svc, _, _, cleanup := newTestService(t, ctx)
defer cleanup()

consumeRes, err := svc.ConsumeKeyPackages(ctx, &proto.ConsumeKeyPackagesRequest{
consumeRes, err := svc.FetchKeyPackages(ctx, &proto.FetchKeyPackagesRequest{
InstallationIds: [][]byte{test.RandomBytes(32)},
})
require.Error(t, err)
require.Nil(t, consumeRes)
require.Nil(t, err)
require.Equal(t, []*proto.FetchKeyPackagesResponse_KeyPackage{nil}, consumeRes.KeyPackages)
}

func TestPublishToGroup(t *testing.T) {
Expand Down Expand Up @@ -273,7 +273,7 @@ func TestGetIdentityUpdates(t *testing.T) {
mockCall := mlsValidationService.mockValidateKeyPackages(installationId, accountAddress)

_, err := svc.RegisterInstallation(ctx, &proto.RegisterInstallationRequest{
LastResortKeyPackage: &proto.KeyPackageUpload{
KeyPackage: &proto.KeyPackageUpload{
KeyPackageTlsSerialized: []byte("test"),
},
})
Expand All @@ -297,7 +297,7 @@ func TestGetIdentityUpdates(t *testing.T) {
mockCall.Unset()
mlsValidationService.mockValidateKeyPackages(test.RandomBytes(32), accountAddress)
_, err = svc.RegisterInstallation(ctx, &proto.RegisterInstallationRequest{
LastResortKeyPackage: &proto.KeyPackageUpload{
KeyPackage: &proto.KeyPackageUpload{
KeyPackageTlsSerialized: []byte("test"),
},
})
Expand Down
6 changes: 5 additions & 1 deletion pkg/authn/authn.pb.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 0 additions & 3 deletions pkg/migrations/mls/20231023050806_init-schema.down.sql
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,3 @@ SET

--bun:split
DROP TABLE IF EXISTS installations;

--bun:split
DROP TABLE IF EXISTS key_packages;
29 changes: 4 additions & 25 deletions pkg/migrations/mls/20231023050806_init-schema.up.sql
Original file line number Diff line number Diff line change
Expand Up @@ -6,21 +6,12 @@ CREATE TABLE installations (
id BYTEA PRIMARY KEY,
wallet_address TEXT NOT NULL,
created_at BIGINT NOT NULL,
updated_at BIGINT NOT NULL,
credential_identity BYTEA NOT NULL,
revoked_at BIGINT
);
revoked_at BIGINT,

--bun:split
CREATE TABLE key_packages (
id TEXT PRIMARY KEY,
installation_id BYTEA NOT NULL,
created_at BIGINT NOT NULL,
consumed_at BIGINT,
not_consumed BOOLEAN DEFAULT TRUE NOT NULL,
is_last_resort BOOLEAN NOT NULL,
data BYTEA NOT NULL,
-- Add a foreign key constraint to ensure key packages cannot be added for unregistered installations
CONSTRAINT fk_installation_id FOREIGN KEY (installation_id) REFERENCES installations (id)
key_package BYTEA NOT NULL,
expiration BIGINT NOT NULL
);

--bun:split
Expand All @@ -31,15 +22,3 @@ CREATE INDEX idx_installations_created_at ON installations(created_at);

--bun:split
CREATE INDEX idx_installations_revoked_at ON installations(revoked_at);

--bun:split
-- Adding indexes for the key_packages table
CREATE INDEX idx_key_packages_installation_id_not_consumed_is_last_resort_created_at ON key_packages(
installation_id,
not_consumed,
is_last_resort,
created_at
);

--bun:split
CREATE INDEX idx_key_packages_is_last_resort_id ON key_packages(is_last_resort, id);
14 changes: 3 additions & 11 deletions pkg/mlsstore/models.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,10 @@ type Installation struct {
ID []byte `bun:",pk,type:bytea"`
WalletAddress string `bun:"wallet_address,notnull"`
CreatedAt int64 `bun:"created_at,notnull"`
UpdatedAt int64 `bun:"updated_at,notnull"`
RevokedAt *int64 `bun:"revoked_at"`
CredentialIdentity []byte `bun:"credential_identity,notnull,type:bytea"`
}

type KeyPackage struct {
bun.BaseModel `bun:"table:key_packages"`

ID string `bun:",pk"` // ID is the hash of the data field
InstallationId []byte `bun:"installation_id,notnull,type:bytea"`
CreatedAt int64 `bun:"created_at,notnull"`
ConsumedAt *int64 `bun:"consumed_at"`
NotConsumed bool `bun:"not_consumed,default:true"`
IsLastResort bool `bun:"is_last_resort,notnull"`
Data []byte `bun:"data,notnull,type:bytea"`
KeyPackage []byte `bun:"key_package,notnull,type:bytea"`
Expiration uint64 `bun:"expiration,notnull"`
}
Loading

0 comments on commit 0210936

Please sign in to comment.