From 6bb95efb7997692a52c321e787e633a5045b21f8 Mon Sep 17 00:00:00 2001 From: Varun Naik Date: Tue, 13 Dec 2022 20:21:21 -0800 Subject: [PATCH] feat(spanner): add database roles (#5701) * feat(spanner): add database roles * tests * make tests pass * respond to PR comments * add test for ListDatabaseRoles * skip emulator tests * Add check for nil Co-authored-by: rahul2393 Co-authored-by: rahul2393 --- spanner/client.go | 6 +- spanner/client_test.go | 29 + spanner/integration_test.go | 569 ++++++++++++++++++ .../internal/testutil/inmem_spanner_server.go | 12 +- spanner/read_test.go | 1 + spanner/sessionclient.go | 8 +- spanner/sessionclient_test.go | 71 +++ 7 files changed, 690 insertions(+), 6 deletions(-) diff --git a/spanner/client.go b/spanner/client.go index 155fcf7410dd..4acf0c8b3467 100644 --- a/spanner/client.go +++ b/spanner/client.go @@ -143,6 +143,10 @@ type ClientConfig struct { // Recommended format: ``application-or-tool-ID/major.minor.version``. UserAgent string + // DatabaseRole specifies the role to be assumed for all operations on the + // database by this client. + DatabaseRole string + // Logger is the logger to use for this client. If it is nil, all logging // will be directed to the standard logger. Logger *log.Logger @@ -220,7 +224,7 @@ func NewClientWithConfig(ctx context.Context, database string, config ClientConf config.incStep = DefaultSessionPoolConfig.incStep } // Create a session client. - sc := newSessionClient(pool, database, config.UserAgent, sessionLabels, metadata.Pairs(resourcePrefixHeader, database), config.Logger, config.CallOptions) + sc := newSessionClient(pool, database, config.UserAgent, sessionLabels, config.DatabaseRole, metadata.Pairs(resourcePrefixHeader, database), config.Logger, config.CallOptions) // Create a session pool. config.SessionPoolConfig.sessionLabels = sessionLabels sp, err := newSessionPool(sc, config.SessionPoolConfig) diff --git a/spanner/client_test.go b/spanner/client_test.go index 1dc4ebe5d367..051012a37d15 100644 --- a/spanner/client_test.go +++ b/spanner/client_test.go @@ -1146,6 +1146,35 @@ func TestClient_ReadWriteTransaction_DoNotLeakSessionOnPanic(t *testing.T) { } } +func TestClient_SessionContainsDatabaseRole(t *testing.T) { + // Make sure that there is always only one session in the pool. + sc := SessionPoolConfig{ + MinOpened: 1, + MaxOpened: 1, + } + server, client, teardown := setupMockedTestServerWithConfig(t, ClientConfig{SessionPoolConfig: sc, DatabaseRole: "test"}) + defer teardown() + + // Wait until all sessions have been created, so we know that those requests will not interfere with the test. + sp := client.idleSessions + waitFor(t, func() error { + sp.mu.Lock() + defer sp.mu.Unlock() + if uint64(sp.idleList.Len()) != 1 { + return fmt.Errorf("num open sessions mismatch.\nGot: %d\nWant: %d", sp.numOpened, sp.MinOpened) + } + return nil + }) + + resp, err := server.TestSpanner.GetSession(context.Background(), &sppb.GetSessionRequest{Name: client.idleSessions.idleList.Front().Value.(*session).id}) + if err != nil { + t.Fatalf("Failed to get session unexpectedly: %v", err) + } + if g, w := resp.CreatorRole, "test"; g != w { + t.Fatalf("database role mismatch.\nGot: %v\nWant: %v", g, w) + } +} + func TestClient_SessionNotFound(t *testing.T) { // Ensure we always have at least one session in the pool. sc := SessionPoolConfig{ diff --git a/spanner/integration_test.go b/spanner/integration_test.go index a3f046e5ea74..5817263443fb 100644 --- a/spanner/integration_test.go +++ b/spanner/integration_test.go @@ -2485,6 +2485,560 @@ func TestIntegration_TransactionRunner(t *testing.T) { } } +func TestIntegration_QueryWithRoles(t *testing.T) { + t.Parallel() + // Database roles are not currently available in emulator and PG dialect + skipEmulatorTest(t) + skipUnsupportedPGTest(t) + + // Set up testing environment. + var ( + row *Row + client, clientWithRole *Client + iter *RowIterator + err error + id int64 + firstName, lastName string + ) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) + defer cancel() + stmts := []string{ + `CREATE TABLE Singers ( + SingerId INT64 NOT NULL, + FirstName STRING(1024), + LastName STRING(1024), + SingerInfo BYTES(MAX) + ) PRIMARY KEY (SingerId)`, + `CREATE ROLE singers_reader`, + `CREATE ROLE singers_unauthorized`, + `CREATE ROLE singers_reader_revoked`, + `CREATE ROLE dropped`, + `GRANT SELECT(SingerId, FirstName, LastName) ON TABLE Singers TO ROLE singers_reader`, + `GRANT SELECT(SingerId, FirstName) ON TABLE Singers TO ROLE singers_unauthorized`, + `GRANT SELECT(SingerId, FirstName, LastName) ON TABLE Singers TO ROLE singers_reader_revoked`, + `REVOKE SELECT(LastName) ON TABLE Singers FROM ROLE singers_reader_revoked`, + `DROP ROLE dropped`, + } + client, dbPath, cleanup := prepareIntegrationTest(ctx, t, DefaultSessionPoolConfig, stmts) + defer cleanup() + + singerColumns := []string{"SingerId", "FirstName", "LastName"} + var ms = []*Mutation{ + InsertOrUpdate("Singers", singerColumns, []interface{}{1, "Marc", "Richards"}), + } + if _, err := client.Apply(ctx, ms, ApplyAtLeastOnce()); err != nil { + t.Fatalf("Could not insert rows to table. Got error %v", err) + } + queryStmt := Statement{SQL: `SELECT SingerId, FirstName, LastName FROM Singers`} + + // A request with sufficient privileges should return all rows + for _, dbRole := range []string{ + "", + "singers_reader", + } { + if clientWithRole, err = createClientWithRole(ctx, dbPath, SessionPoolConfig{}, dbRole); err != nil { + t.Fatal(err) + } + defer clientWithRole.Close() + iter = clientWithRole.Single().Query(ctx, queryStmt) + defer iter.Stop() + + row, err = iter.Next() + if err != nil { + t.Fatalf("Could not read row. Got error %v", err) + } + if err = row.Columns(&id, &firstName, &lastName); err != nil { + t.Fatalf("failed to parse row %v", err) + } + if id != 1 || firstName != "Marc" || lastName != "Richards" { + t.Fatalf("execution didn't return expected values") + } + + _, err = iter.Next() + if err != iterator.Done { + t.Fatalf("got results from iterator, want none: %#v, err = %v\n", row, err) + } + } + + // A request with insufficient privileges should return permission denied + for _, test := range []struct { + dbRole string + errMsg string + }{ + { + "singers_unauthorized", + "Role singers_unauthorized does not have required privileges on table Singers.", + }, + { + "singers_reader_revoked", + "Role singers_reader_revoked does not have required privileges on table Singers.", + }, + { + "nonexistent", + "Role not found: nonexistent.", + }, + { + "dropped", + "Role not found: dropped.", + }, + } { + if clientWithRole, err = createClientWithRole(ctx, dbPath, SessionPoolConfig{}, test.dbRole); err != nil { + t.Fatal(err) + } + defer clientWithRole.Close() + iter = clientWithRole.Single().Query(ctx, queryStmt) + defer iter.Stop() + + _, err = iter.Next() + if err == nil { + t.Fatal("expected err, got nil") + } + if msg, ok := matchError(err, codes.PermissionDenied, test.errMsg); !ok { + t.Fatal(msg) + } + } +} + +func TestIntegration_ReadWithRoles(t *testing.T) { + t.Parallel() + // Database roles are not currently available in emulator and PG dialect + skipEmulatorTest(t) + skipUnsupportedPGTest(t) + + // Set up testing environment. + var ( + row *Row + client, clientWithRole *Client + iter *RowIterator + err error + id int64 + firstName, lastName string + ) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) + defer cancel() + stmts := []string{ + `CREATE TABLE Singers ( + SingerId INT64 NOT NULL, + FirstName STRING(1024), + LastName STRING(1024), + SingerInfo BYTES(MAX) + ) PRIMARY KEY (SingerId)`, + `CREATE ROLE singers_reader`, + `CREATE ROLE singers_unauthorized`, + `CREATE ROLE singers_reader_revoked`, + `CREATE ROLE dropped`, + `GRANT SELECT(SingerId, FirstName, LastName) ON TABLE Singers TO ROLE singers_reader`, + `GRANT SELECT(SingerId, FirstName) ON TABLE Singers TO ROLE singers_unauthorized`, + `GRANT SELECT(SingerId, FirstName, LastName) ON TABLE Singers TO ROLE singers_reader_revoked`, + `REVOKE SELECT(LastName) ON TABLE Singers FROM ROLE singers_reader_revoked`, + `DROP ROLE dropped`, + } + client, dbPath, cleanup := prepareIntegrationTest(ctx, t, DefaultSessionPoolConfig, stmts) + defer cleanup() + + singerColumns := []string{"SingerId", "FirstName", "LastName"} + var ms = []*Mutation{ + InsertOrUpdate("Singers", singerColumns, []interface{}{1, "Marc", "Richards"}), + } + if _, err := client.Apply(ctx, ms, ApplyAtLeastOnce()); err != nil { + t.Fatalf("Could not insert rows to table. Got error %v", err) + } + + // A request with sufficient privileges should return all rows + for _, dbRole := range []string{ + "", + "singers_reader", + } { + if clientWithRole, err = createClientWithRole(ctx, dbPath, SessionPoolConfig{}, dbRole); err != nil { + t.Fatal(err) + } + defer clientWithRole.Close() + iter = clientWithRole.Single().Read(ctx, "Singers", AllKeys(), singerColumns) + defer iter.Stop() + + row, err = iter.Next() + if err != nil { + t.Fatalf("Could not read row. Got error %v", err) + } + if err = row.Columns(&id, &firstName, &lastName); err != nil { + t.Fatalf("failed to parse row %v", err) + } + if id != 1 || firstName != "Marc" || lastName != "Richards" { + t.Fatalf("execution didn't return expected values") + } + + _, err = iter.Next() + if err != iterator.Done { + t.Fatalf("got results from iterator, want none: %#v, err = %v\n", row, err) + } + } + + // A request with insufficient privileges should return permission denied + for _, test := range []struct { + dbRole string + errMsg string + }{ + { + "singers_unauthorized", + "Role singers_unauthorized does not have required privileges on table Singers.", + }, + { + "singers_reader_revoked", + "Role singers_reader_revoked does not have required privileges on table Singers.", + }, + { + "nonexistent", + "Role not found: nonexistent.", + }, + { + "dropped", + "Role not found: dropped.", + }, + } { + if clientWithRole, err = createClientWithRole(ctx, dbPath, SessionPoolConfig{}, test.dbRole); err != nil { + t.Fatal(err) + } + defer clientWithRole.Close() + iter = clientWithRole.Single().Read(ctx, "Singers", AllKeys(), singerColumns) + defer iter.Stop() + + _, err = iter.Next() + if err == nil { + t.Fatal("expected err, got nil") + } + if msg, ok := matchError(err, codes.PermissionDenied, test.errMsg); !ok { + t.Fatal(msg) + } + } +} + +func TestIntegration_DMLWithRoles(t *testing.T) { + t.Parallel() + // Database roles are not currently available in emulator and PG dialect + skipEmulatorTest(t) + skipUnsupportedPGTest(t) + + // Set up testing environment. + var ( + client, clientWithRole *Client + err error + ) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) + defer cancel() + stmts := []string{ + `CREATE TABLE Singers ( + SingerId INT64 NOT NULL, + FirstName STRING(1024), + LastName STRING(1024), + SingerInfo BYTES(MAX) + ) PRIMARY KEY (SingerId)`, + `CREATE ROLE singers_updater`, + `CREATE ROLE singers_unauthorized`, + `CREATE ROLE singers_creator`, + `CREATE ROLE singers_deleter`, + `GRANT SELECT(SingerId), UPDATE(FirstName, LastName) ON TABLE Singers TO ROLE singers_updater`, + `GRANT SELECT(SingerId), UPDATE(FirstName) ON TABLE Singers TO ROLE singers_unauthorized`, + `GRANT INSERT(SingerId, FirstName, LastName) ON TABLE Singers TO ROLE singers_creator`, + `GRANT SELECT(SingerId), DELETE ON TABLE Singers TO ROLE singers_deleter`, + } + client, dbPath, cleanup := prepareIntegrationTest(ctx, t, DefaultSessionPoolConfig, stmts) + defer cleanup() + + singerColumns := []string{"SingerId", "FirstName", "LastName"} + var ms = []*Mutation{ + InsertOrUpdate("Singers", singerColumns, []interface{}{1, "Marc", "Richards"}), + } + if _, err := client.Apply(ctx, ms, ApplyAtLeastOnce()); err != nil { + t.Fatalf("Could not insert rows to table. Got error %v", err) + } + updateStmt := Statement{SQL: `UPDATE Singers SET FirstName = "Mark", LastName = "Richards" WHERE SingerId = 1`} + + // A request with sufficient privileges should update the row + for _, dbRole := range []string{ + "", + "singers_updater", + } { + if clientWithRole, err = createClientWithRole(ctx, dbPath, SessionPoolConfig{}, dbRole); err != nil { + t.Fatal(err) + } + defer clientWithRole.Close() + _, err = clientWithRole.ReadWriteTransaction(ctx, func(ctx context.Context, txn *ReadWriteTransaction) error { + _, err := txn.Update(ctx, updateStmt) + return err + }) + if err != nil { + t.Fatalf("Could not update row. Got error %v", err) + } + } + + // A request with insufficient privileges should return permission denied + for _, test := range []struct { + dbRole string + errMsg string + }{ + { + "singers_unauthorized", + "Role singers_unauthorized does not have required privileges on table Singers.", + }, + { + "nonexistent", + "Role not found: nonexistent.", + }, + } { + if clientWithRole, err = createClientWithRole(ctx, dbPath, SessionPoolConfig{}, test.dbRole); err != nil { + t.Fatal(err) + } + defer clientWithRole.Close() + _, err = clientWithRole.ReadWriteTransaction(ctx, func(ctx context.Context, txn *ReadWriteTransaction) error { + _, err := txn.Update(ctx, updateStmt) + return err + }) + + if err == nil { + t.Fatal("expected err, got nil") + } + if msg, ok := matchError(err, codes.PermissionDenied, test.errMsg); !ok { + t.Fatal(msg) + } + } + + // A request with sufficient privileges should insert the row + getInsertStmt := func(vals []interface{}) Statement { + sql := fmt.Sprintf(`INSERT INTO Singers (SingerId, FirstName, LastName) VALUES (%d, "%s", "%s")`, vals...) + return Statement{SQL: sql} + } + for _, test := range []struct { + dbRole string + vals []interface{} + }{ + { + "", + []interface{}{2, "Catalina", "Smith"}, + }, + { + "singers_creator", + []interface{}{3, "Alice", "Trentor"}, + }, + } { + if clientWithRole, err = createClientWithRole(ctx, dbPath, SessionPoolConfig{}, test.dbRole); err != nil { + t.Fatal(err) + } + defer clientWithRole.Close() + _, err = clientWithRole.ReadWriteTransaction(ctx, func(ctx context.Context, txn *ReadWriteTransaction) error { + _, err := txn.Update(ctx, getInsertStmt(test.vals)) + return err + }) + if err != nil { + t.Fatalf("Could not insert row. Got error %v", err) + } + } + + // A request with sufficient privileges should delete the row + deleteStmt := Statement{SQL: `DELETE FROM Singers WHERE TRUE`} + for _, dbRole := range []string{ + "", + "singers_deleter", + } { + if clientWithRole, err = createClientWithRole(ctx, dbPath, SessionPoolConfig{}, dbRole); err != nil { + t.Fatal(err) + } + defer clientWithRole.Close() + _, err = clientWithRole.ReadWriteTransaction(ctx, func(ctx context.Context, txn *ReadWriteTransaction) error { + _, err := txn.Update(ctx, deleteStmt) + return err + }) + if err != nil { + t.Fatalf("Could not delete row. Got error %v", err) + } + } +} + +func TestIntegration_MutationWithRoles(t *testing.T) { + t.Parallel() + // Database roles are not currently available in emulator and PG dialect + skipEmulatorTest(t) + skipUnsupportedPGTest(t) + + // Set up testing environment. + var ( + client, clientWithRole *Client + err error + ) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) + defer cancel() + stmts := []string{ + `CREATE TABLE Singers ( + SingerId INT64 NOT NULL, + FirstName STRING(1024), + LastName STRING(1024), + SingerInfo BYTES(MAX) + ) PRIMARY KEY (SingerId)`, + `CREATE ROLE singers_updater`, + `CREATE ROLE singers_unauthorized`, + `CREATE ROLE singers_creator`, + `CREATE ROLE singers_deleter`, + `GRANT SELECT(SingerId), UPDATE(SingerId, FirstName, LastName) ON TABLE Singers TO ROLE singers_updater`, + `GRANT SELECT(SingerId), UPDATE(SingerId, FirstName) ON TABLE Singers TO ROLE singers_unauthorized`, + `GRANT INSERT(SingerId, FirstName, LastName) ON TABLE Singers TO ROLE singers_creator`, + `GRANT SELECT(SingerId), DELETE ON TABLE Singers TO ROLE singers_deleter`, + } + client, dbPath, cleanup := prepareIntegrationTest(ctx, t, DefaultSessionPoolConfig, stmts) + defer cleanup() + + singerColumns := []string{"SingerId", "FirstName", "LastName"} + var ms = []*Mutation{ + InsertOrUpdate("Singers", singerColumns, []interface{}{1, "Marc", "Richards"}), + } + if _, err := client.Apply(ctx, ms, ApplyAtLeastOnce()); err != nil { + t.Fatalf("Could not insert rows to table. Got error %v", err) + } + + // A request with sufficient privileges should update the row + for _, dbRole := range []string{ + "", + "singers_updater", + } { + if clientWithRole, err = createClientWithRole(ctx, dbPath, SessionPoolConfig{}, dbRole); err != nil { + t.Fatal(err) + } + defer clientWithRole.Close() + _, err = clientWithRole.Apply(ctx, []*Mutation{ + Update("Singers", singerColumns, []interface{}{1, "Mark", "Richards"}), + }) + if err != nil { + t.Fatalf("Could not update row. Got error %v", err) + } + } + + // A request with insufficient privileges should return permission denied + for _, test := range []struct { + dbRole string + errMsg string + }{ + { + "singers_unauthorized", + "Role singers_unauthorized does not have required privileges on table Singers.", + }, + { + "nonexistent", + "Role not found: nonexistent.", + }, + } { + if clientWithRole, err = createClientWithRole(ctx, dbPath, SessionPoolConfig{}, test.dbRole); err != nil { + t.Fatal(err) + } + defer clientWithRole.Close() + _, err = clientWithRole.Apply(ctx, []*Mutation{ + Update("Singers", singerColumns, []interface{}{1, "Mark", "Richards"}), + }) + + if err == nil { + t.Fatal("expected err, got nil") + } + if msg, ok := matchError(err, codes.PermissionDenied, test.errMsg); !ok { + t.Fatal(msg) + } + } + + // A request with sufficient privileges should insert the row + for _, test := range []struct { + dbRole string + vals []interface{} + }{ + { + "", + []interface{}{2, "Catalina", "Smith"}, + }, + { + "singers_creator", + []interface{}{3, "Alice", "Trentor"}, + }, + } { + if clientWithRole, err = createClientWithRole(ctx, dbPath, SessionPoolConfig{}, test.dbRole); err != nil { + t.Fatal(err) + } + defer clientWithRole.Close() + _, err = clientWithRole.Apply(ctx, []*Mutation{ + Insert("Singers", singerColumns, test.vals), + }) + if err != nil { + t.Fatalf("Could not insert row. Got error %v", err) + } + } + + // A request with sufficient privileges should delete the row + for _, dbRole := range []string{ + "", + "singers_deleter", + } { + if clientWithRole, err = createClientWithRole(ctx, dbPath, SessionPoolConfig{}, dbRole); err != nil { + t.Fatal(err) + } + defer clientWithRole.Close() + _, err = clientWithRole.Apply(ctx, []*Mutation{ + Delete("Singers", Key{1}), + }) + if err != nil { + t.Fatalf("Could not delete row. Got error %v", err) + } + } +} + +func TestIntegration_ListDatabaseRoles(t *testing.T) { + t.Parallel() + // Database roles are not currently available in emulator and PG dialect + skipEmulatorTest(t) + skipUnsupportedPGTest(t) + + // Set up testing environment. + var ( + err error + iter *database.DatabaseRoleIterator + ) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) + defer cancel() + stmts := []string{ + `CREATE ROLE a`, + `CREATE ROLE z`, + } + _, dbPath, cleanup := prepareIntegrationTest(ctx, t, DefaultSessionPoolConfig, stmts) + defer cleanup() + + iter = databaseAdmin.ListDatabaseRoles(ctx, &adminpb.ListDatabaseRolesRequest{ + Parent: dbPath, + }) + roles, err := readDatabaseRoles(iter) + if err != nil { + t.Fatalf("cannot list database roles in %v: %v", dbPath, err) + } + var got []string + rolePrefix := dbPath + "/databaseRoles/" + for _, role := range roles { + if !strings.HasPrefix(role.Name, rolePrefix) { + t.Fatalf("Role %v does not have prefix %v", role.Name, rolePrefix) + } + got = append(got, strings.TrimPrefix(role.Name, rolePrefix)) + } + want := []string{"a", "public", "spanner_info_reader", "spanner_sys_reader", "z"} + if !testEqual(got, want) { + t.Fatalf("Database role mismatch\nGot: %v, Want: %v", got, want) + } +} + +func readDatabaseRoles(iter *database.DatabaseRoleIterator) ([]*adminpb.DatabaseRole, error) { + var vals []*adminpb.DatabaseRole + for { + v, err := iter.Next() + if err == iterator.Done { + return vals, nil + } + if err != nil { + return nil, err + } + vals = append(vals, v) + } +} + // Test PartitionQuery of BatchReadOnlyTransaction, create partitions then // serialize and deserialize both transaction and partition to be used in // execution on another client, and compare results. @@ -4005,6 +4559,21 @@ func createClient(ctx context.Context, dbPath string, spc SessionPoolConfig) (cl return client, nil } +func createClientWithRole(ctx context.Context, dbPath string, spc SessionPoolConfig, role string) (client *Client, err error) { + opts := grpcHeaderChecker.CallOptions() + if spannerHost != "" { + opts = append(opts, option.WithEndpoint(spannerHost)) + } + if dpConfig.attemptDirectPath { + opts = append(opts, option.WithGRPCDialOption(grpc.WithDefaultCallOptions(grpc.Peer(peerInfo)))) + } + client, err = NewClientWithConfig(ctx, dbPath, ClientConfig{SessionPoolConfig: spc, DatabaseRole: role}, opts...) + if err != nil { + return nil, fmt.Errorf("cannot create data client on DB %v: %v", dbPath, err) + } + return client, nil +} + // populate prepares the database with some data. func populate(ctx context.Context, client *Client) error { // Populate data diff --git a/spanner/internal/testutil/inmem_spanner_server.go b/spanner/internal/testutil/inmem_spanner_server.go index 9152b01d8737..2939434a77b3 100644 --- a/spanner/internal/testutil/inmem_spanner_server.go +++ b/spanner/internal/testutil/inmem_spanner_server.go @@ -653,7 +653,11 @@ func (s *inMemSpannerServer) CreateSession(ctx context.Context, req *spannerpb.C } sessionName := s.generateSessionNameLocked(req.Database) ts := getCurrentTimestamp() - session := &spannerpb.Session{Name: sessionName, CreateTime: ts, ApproximateLastUseTime: ts} + var creatorRole string + if req.Session != nil { + creatorRole = req.Session.CreatorRole + } + session := &spannerpb.Session{Name: sessionName, CreateTime: ts, ApproximateLastUseTime: ts, CreatorRole: creatorRole} s.totalSessionsCreated++ s.sessions[sessionName] = session return session, nil @@ -685,7 +689,11 @@ func (s *inMemSpannerServer) BatchCreateSessions(ctx context.Context, req *spann for i := int32(0); i < sessionsToCreate; i++ { sessionName := s.generateSessionNameLocked(req.Database) ts := getCurrentTimestamp() - sessions[i] = &spannerpb.Session{Name: sessionName, CreateTime: ts, ApproximateLastUseTime: ts} + var creatorRole string + if req.SessionTemplate != nil { + creatorRole = req.SessionTemplate.CreatorRole + } + sessions[i] = &spannerpb.Session{Name: sessionName, CreateTime: ts, ApproximateLastUseTime: ts, CreatorRole: creatorRole} s.totalSessionsCreated++ s.sessions[sessionName] = sessions[i] } diff --git a/spanner/read_test.go b/spanner/read_test.go index 1f37abd2e1d8..03c7c65ed435 100644 --- a/spanner/read_test.go +++ b/spanner/read_test.go @@ -1778,6 +1778,7 @@ func createSession(client *vkit.Client) (*sppb.Session, error) { var formattedDatabase string = fmt.Sprintf("projects/%s/instances/%s/databases/%s", "[PROJECT]", "[INSTANCE]", "[DATABASE]") var request = &sppb.CreateSessionRequest{ Database: formattedDatabase, + Session: &sppb.Session{}, } return client.CreateSession(context.Background(), request) } diff --git a/spanner/sessionclient.go b/spanner/sessionclient.go index 28e4e7a5a9d0..aef2ab9e100f 100644 --- a/spanner/sessionclient.go +++ b/spanner/sessionclient.go @@ -93,6 +93,7 @@ type sessionClient struct { id string userAgent string sessionLabels map[string]string + databaseRole string md metadata.MD batchTimeout time.Duration logger *log.Logger @@ -100,13 +101,14 @@ type sessionClient struct { } // newSessionClient creates a session client to use for a database. -func newSessionClient(connPool gtransport.ConnPool, database, userAgent string, sessionLabels map[string]string, md metadata.MD, logger *log.Logger, callOptions *vkit.CallOptions) *sessionClient { +func newSessionClient(connPool gtransport.ConnPool, database, userAgent string, sessionLabels map[string]string, databaseRole string, md metadata.MD, logger *log.Logger, callOptions *vkit.CallOptions) *sessionClient { return &sessionClient{ connPool: connPool, database: database, userAgent: userAgent, id: cidGen.nextID(database), sessionLabels: sessionLabels, + databaseRole: databaseRole, md: md, batchTimeout: time.Minute, logger: logger, @@ -138,7 +140,7 @@ func (sc *sessionClient) createSession(ctx context.Context) (*session, error) { var md metadata.MD sid, err := client.CreateSession(ctx, &sppb.CreateSessionRequest{ Database: sc.database, - Session: &sppb.Session{Labels: sc.sessionLabels}, + Session: &sppb.Session{Labels: sc.sessionLabels, CreatorRole: sc.databaseRole}, }, gax.WithGRPCOptions(grpc.Header(&md))) if getGFELatencyMetricsFlag() && md != nil { @@ -260,7 +262,7 @@ func (sc *sessionClient) executeBatchCreateSessions(client *vkit.Client, createC response, err := client.BatchCreateSessions(ctx, &sppb.BatchCreateSessionsRequest{ SessionCount: remainingCreateCount, Database: sc.database, - SessionTemplate: &sppb.Session{Labels: labels}, + SessionTemplate: &sppb.Session{Labels: labels, CreatorRole: sc.databaseRole}, }, gax.WithGRPCOptions(grpc.Header(&mdForGFELatency))) if getGFELatencyMetricsFlag() && mdForGFELatency != nil { diff --git a/spanner/sessionclient_test.go b/spanner/sessionclient_test.go index f06b175c36e1..309ae2d70c9e 100644 --- a/spanner/sessionclient_test.go +++ b/spanner/sessionclient_test.go @@ -26,6 +26,7 @@ import ( vkit "cloud.google.com/go/spanner/apiv1" . "cloud.google.com/go/spanner/internal/testutil" gax "github.com/googleapis/gax-go/v2" + sppb "google.golang.org/genproto/googleapis/spanner/v1" "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" @@ -148,6 +149,41 @@ func TestCreateAndCloseSession(t *testing.T) { } } +func TestCreateSessionWithDatabaseRole(t *testing.T) { + // Make sure that there is always only one session in the pool. + sc := SessionPoolConfig{ + MinOpened: 0, + MaxOpened: 1, + } + server, client, teardown := setupMockedTestServerWithConfig(t, ClientConfig{SessionPoolConfig: sc, DatabaseRole: "test"}) + defer teardown() + ctx := context.Background() + + s, err := client.sc.createSession(ctx) + if err != nil { + t.Fatalf("batch.next() return error mismatch\ngot: %v\nwant: nil", err) + } + if s == nil { + t.Fatalf("batch.next() return value mismatch\ngot: %v\nwant: any session", s) + } + if g, w := server.TestSpanner.TotalSessionsCreated(), uint(1); g != w { + t.Fatalf("number of sessions created mismatch\ngot: %v\nwant: %v", g, w) + } + + resp, err := server.TestSpanner.GetSession(ctx, &sppb.GetSessionRequest{Name: s.id}) + if err != nil { + t.Fatalf("Failed to get session unexpectedly: %v", err) + } + if g, w := resp.CreatorRole, "test"; g != w { + t.Fatalf("database role mismatch.\nGot: %v\nWant: %v", g, w) + } + + s.delete(ctx) + if g, w := server.TestSpanner.TotalSessionsDeleted(), uint(1); g != w { + t.Fatalf("number of sessions deleted mismatch\ngot: %v\nwant: %v", g, w) + } +} + func TestBatchCreateAndCloseSession(t *testing.T) { t.Parallel() @@ -201,6 +237,41 @@ func TestBatchCreateAndCloseSession(t *testing.T) { } } +func TestBatchCreateSessionsWithDatabaseRole(t *testing.T) { + // Make sure that there is always only one session in the pool. + sc := SessionPoolConfig{ + MinOpened: 0, + MaxOpened: 1, + } + server, client, teardown := setupMockedTestServerWithConfig(t, ClientConfig{SessionPoolConfig: sc, DatabaseRole: "test"}) + defer teardown() + ctx := context.Background() + + consumer := newTestConsumer(1) + client.sc.batchCreateSessions(1, true, consumer) + <-consumer.receivedAll + if g, w := len(consumer.sessions), 1; g != w { + t.Fatalf("returned number of sessions mismatch\ngot: %v\nwant: %v", g, w) + } + if g, w := server.TestSpanner.TotalSessionsCreated(), uint(1); g != w { + t.Fatalf("number of sessions created mismatch\ngot: %v\nwant: %v", g, w) + } + s := consumer.sessions[0] + + resp, err := server.TestSpanner.GetSession(ctx, &sppb.GetSessionRequest{Name: s.id}) + if err != nil { + t.Fatalf("Failed to get session unexpectedly: %v", err) + } + if g, w := resp.CreatorRole, "test"; g != w { + t.Fatalf("database role mismatch.\nGot: %v\nWant: %v", g, w) + } + + s.delete(ctx) + if g, w := server.TestSpanner.TotalSessionsDeleted(), uint(1); g != w { + t.Fatalf("number of sessions deleted mismatch\ngot: %v\nwant: %v", g, w) + } +} + func TestBatchCreateSessionsWithExceptions(t *testing.T) { t.Parallel()