From b2a3ff35c0d3a9abcb10175aa0e4e187f094efbc Mon Sep 17 00:00:00 2001 From: Yevgeniy Firsov Date: Wed, 17 May 2023 10:16:11 -0700 Subject: [PATCH] feat: Support user specified collection names This allows to support multiple collections using the same model type. --- tigris/database.go | 44 +++++++++++++++++++++++++++++++++++------ tigris/database_test.go | 36 +++++++++++++++++++++++++++++++++ 2 files changed, 74 insertions(+), 6 deletions(-) diff --git a/tigris/database.go b/tigris/database.go index caa68d0..3a45b85 100644 --- a/tigris/database.go +++ b/tigris/database.go @@ -32,14 +32,14 @@ import ( // top level GetTxCollection(ctx, tx) function should be used // instead of method of Tx interface. type Database struct { - name string driver driver.Driver + name string } -func newDatabase(name string, driver driver.Driver) *Database { +func newDatabase(name string, drv driver.Driver) *Database { return &Database{ name: name, - driver: driver, + driver: drv, } } @@ -55,10 +55,34 @@ func (db *Database) CreateCollections(ctx context.Context, model schema.Model, m return db.createCollectionsFromSchemas(ctx, schemas) } +// CreateCollection creates collection in the Database using provided collection model and optional name. +// This method is only needed if collection need to be created dynamically, +// all static collections are created by OpenDatabase. +func (db *Database) CreateCollection(ctx context.Context, model schema.Model, name ...string) error { + schemas, err := schema.FromCollectionModels(SchemaVersion, schema.Documents, model) + if err != nil { + return fmt.Errorf("error parsing model schema: %w", err) + } + + if len(name) > 1 { + return fmt.Errorf("only one name parameter allowed") + } + + if len(name) == 1 { + // there only one schema + for _, v := range schemas { + v.Name = name[0] + } + } + + return db.createCollectionsFromSchemas(ctx, schemas) +} + func (db *Database) createCollectionsFromSchemasLow(ctx context.Context, tx driver.Tx, inSchemas map[string]*schema.Schema) error { schemas := make([]driver.Schema, len(inSchemas)) var i int + for _, v := range inSchemas { sch, err := schema.Build(v) if err != nil { @@ -70,11 +94,13 @@ func (db *Database) createCollectionsFromSchemasLow(ctx context.Context, tx driv } var err error + if tx != nil { _, err = tx.CreateOrUpdateCollections(ctx, schemas) } else { _, err = db.driver.UseDatabase(db.name).CreateOrUpdateCollections(ctx, schemas) } + if err != nil { return err } @@ -171,10 +197,16 @@ func MustOpenDatabase(ctx context.Context, cfg *Config, models ...schema.Model, } // GetCollection returns collection object corresponding to collection model T. -func GetCollection[T schema.Model](db *Database) *Collection[T] { +func GetCollection[T schema.Model](db *Database, name ...string) *Collection[T] { var m T - name := schema.ModelName(&m) - return getNamedCollection[T](db, name) + + nm := schema.ModelName(&m) + + if len(name) > 0 { + nm = name[0] + } + + return getNamedCollection[T](db, nm) } func getNamedCollection[T schema.Model](db *Database, name string) *Collection[T] { diff --git a/tigris/database_test.go b/tigris/database_test.go index 5a7948a..8c92b4d 100644 --- a/tigris/database_test.go +++ b/tigris/database_test.go @@ -88,4 +88,40 @@ func TestDatabase(t *testing.T) { require.Nil(t, resp) require.Equal(t, &driver.Error{TigrisError: &api.TigrisError{Code: api.Code_NOT_FOUND, Message: "branch does not exist"}}, err1) }) + + t.Run("create collection", func(t *testing.T) { + mc.EXPECT().CreateOrUpdateCollections(gomock.Any(), + pm(&api.CreateOrUpdateCollectionsRequest{ + Project: "db1", + Branch: "staging", + Schemas: [][]byte{[]byte(`{"title":"coll_1","properties":{"Key1":{"type":"string"}},"primary_key":["Key1"],"collection_type":"documents"}`)}, + Options: &api.CollectionOptions{}, + })).Do(func(ctx context.Context, r *api.CreateOrUpdateCollectionsRequest) { + }).Return(&api.CreateOrUpdateCollectionsResponse{}, nil) + + type Coll1 struct { + Key1 string `tigris:"primary_key:1"` + } + + err1 := client.GetDatabase().CreateCollection(ctx, &Coll1{}) + require.NoError(t, err1) + + mc.EXPECT().CreateOrUpdateCollections(gomock.Any(), + pm(&api.CreateOrUpdateCollectionsRequest{ + Project: "db1", + Branch: "staging", + Schemas: [][]byte{[]byte(`{"title":"other_coll_1","properties":{"Key1":{"type":"string"}},"primary_key":["Key1"],"collection_type":"documents"}`)}, + Options: &api.CollectionOptions{}, + })).Do(func(ctx context.Context, r *api.CreateOrUpdateCollectionsRequest) { + }).Return(&api.CreateOrUpdateCollectionsResponse{}, nil) + + err1 = client.GetDatabase().CreateCollection(ctx, &Coll1{}, "other_coll_1") + require.NoError(t, err1) + + err1 = client.GetDatabase().CreateCollection(ctx, &Coll1{}, "other_coll_1", "multi_parameter_not_allowed") + require.Error(t, err1) + + coll := GetCollection[Coll1](client.GetDatabase(), "other_coll_1") + require.Equal(t, "other_coll_1", coll.name) + }) }