From 20fd526c6e37a798e2ed01651f2241a24e00bce6 Mon Sep 17 00:00:00 2001 From: Brendan Dougherty Date: Fri, 6 Jan 2023 15:03:46 +0000 Subject: [PATCH] VTGate: Set immediate caller id from gRPC static auth username Signed-off-by: Brendan Dougherty --- .../grpc_server_auth_static/main_test.go | 212 ++++++++++++++++++ go/vt/servenv/grpc_server_auth_static.go | 23 +- go/vt/vtgate/grpcvtgateservice/server.go | 11 +- test/config.json | 9 + 4 files changed, 252 insertions(+), 3 deletions(-) create mode 100644 go/test/endtoend/vtgate/grpc_server_auth_static/main_test.go diff --git a/go/test/endtoend/vtgate/grpc_server_auth_static/main_test.go b/go/test/endtoend/vtgate/grpc_server_auth_static/main_test.go new file mode 100644 index 00000000000..3590ba491ff --- /dev/null +++ b/go/test/endtoend/vtgate/grpc_server_auth_static/main_test.go @@ -0,0 +1,212 @@ +/* +Copyright 2023 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package grpcserverauthstatic + +import ( + "context" + "flag" + "fmt" + "os" + "path" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/grpc" + + "vitess.io/vitess/go/test/endtoend/cluster" + "vitess.io/vitess/go/vt/grpcclient" + "vitess.io/vitess/go/vt/vtgate/grpcvtgateconn" + "vitess.io/vitess/go/vt/vtgate/vtgateconn" +) + +var ( + clusterInstance *cluster.LocalProcessCluster + vtgateGrpcAddress string + hostname = "localhost" + keyspaceName = "ks" + cell = "zone1" + sqlSchema = ` + create table test_table ( + id bigint, + val varchar(128), + primary key(id) + ) Engine=InnoDB; +` + grpcServerAuthStaticJSON = ` + [ + { + "Username": "user_with_access", + "Password": "test_password" + }, + { + "Username": "user_no_access", + "Password": "test_password" + } + ] +` + tableACLJSON = ` + { + "table_groups": [ + { + "name": "default", + "table_names_or_prefixes": ["%"], + "readers": ["user_with_access"], + "writers": ["user_with_access"], + "admins": ["user_with_access"] + } + ] + } +` +) + +func TestMain(m *testing.M) { + defer cluster.PanicHandler(nil) + flag.Parse() + + exitcode := func() int { + clusterInstance = cluster.NewCluster(cell, hostname) + defer clusterInstance.Teardown() + + // Start topo server + if err := clusterInstance.StartTopo(); err != nil { + return 1 + } + + // Directory for authn / authz config files + authDirectory := path.Join(clusterInstance.TmpDirectory, "auth") + if err := os.Mkdir(authDirectory, 0700); err != nil { + return 1 + } + + // Create grpc_server_auth_static.json file + grpcServerAuthStaticPath := path.Join(authDirectory, "grpc_server_auth_static.json") + if err := createFile(grpcServerAuthStaticPath, grpcServerAuthStaticJSON); err != nil { + return 1 + } + + // Create table_acl.json file + tableACLPath := path.Join(authDirectory, "table_acl.json") + if err := createFile(tableACLPath, tableACLJSON); err != nil { + return 1 + } + + // Configure vtgate to use static auth + clusterInstance.VtGateExtraArgs = []string{ + "--grpc_auth_mode", "static", + "--grpc_auth_static_password_file", grpcServerAuthStaticPath, + } + + // Configure vttablet to use table ACL + clusterInstance.VtTabletExtraArgs = []string{ + "--enforce-tableacl-config", + "--queryserver-config-strict-table-acl", + "--table-acl-config", tableACLPath, + } + + // Start keyspace + keyspace := &cluster.Keyspace{ + Name: keyspaceName, + SchemaSQL: sqlSchema, + } + if err := clusterInstance.StartUnshardedKeyspace(*keyspace, 1, false); err != nil { + return 1 + } + + // Start vtgate + if err := clusterInstance.StartVtgate(); err != nil { + clusterInstance.VtgateProcess = cluster.VtgateProcess{} + return 1 + } + vtgateGrpcAddress = fmt.Sprintf("%s:%d", clusterInstance.Hostname, clusterInstance.VtgateGrpcPort) + + return m.Run() + }() + os.Exit(exitcode) +} + +func TestAuthenticatedUserWithAccess(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + vtgateConn, err := dialVTGate(ctx, t, "user_with_access", "test_password") + if err != nil { + t.Fatal(err) + } + defer vtgateConn.Close() + + session := vtgateConn.Session(keyspaceName+"@primary", nil) + query := "SELECT id FROM test_table" + _, err = session.Execute(ctx, query, nil) + assert.NoError(t, err) +} + +func TestAuthenticatedUserNoAccess(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + vtgateConn, err := dialVTGate(ctx, t, "user_no_access", "test_password") + if err != nil { + t.Fatal(err) + } + defer vtgateConn.Close() + + session := vtgateConn.Session(keyspaceName+"@primary", nil) + query := "SELECT id FROM test_table" + _, err = session.Execute(ctx, query, nil) + require.Error(t, err) + assert.Contains(t, err.Error(), "Select command denied to user") + assert.Contains(t, err.Error(), "for table 'test_table' (ACL check error)") +} + +func TestUnauthenticatedUser(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + vtgateConn, err := dialVTGate(ctx, t, "", "") + if err != nil { + t.Fatal(err) + } + defer vtgateConn.Close() + + session := vtgateConn.Session(keyspaceName+"@primary", nil) + query := "SELECT id FROM test_table" + _, err = session.Execute(ctx, query, nil) + require.Error(t, err) + assert.Contains(t, err.Error(), "invalid credentials") +} + +func dialVTGate(ctx context.Context, t *testing.T, username string, password string) (*vtgateconn.VTGateConn, error) { + clientCreds := &grpcclient.StaticAuthClientCreds{Username: username, Password: password} + creds := grpc.WithPerRPCCredentials(clientCreds) + dialerFunc := grpcvtgateconn.DialWithOpts(ctx, creds) + dialerName := t.Name() + vtgateconn.RegisterDialer(dialerName, dialerFunc) + return vtgateconn.DialProtocol(ctx, dialerName, vtgateGrpcAddress) +} + +func createFile(path string, contents string) error { + f, err := os.Create(path) + if err != nil { + return err + } + _, err = f.WriteString(contents) + if err != nil { + return err + } + return f.Close() +} diff --git a/go/vt/servenv/grpc_server_auth_static.go b/go/vt/servenv/grpc_server_auth_static.go index ffcd8a72c56..b7c7142508a 100644 --- a/go/vt/servenv/grpc_server_auth_static.go +++ b/go/vt/servenv/grpc_server_auth_static.go @@ -36,6 +36,14 @@ var ( _ Authenticator = (*StaticAuthPlugin)(nil) ) +// The datatype for static auth Context keys +type staticAuthKey int + +const ( + // Internal Context key for the authenticated username + staticAuthUsername staticAuthKey = 0 +) + func registerGRPCServerAuthStaticFlags(fs *pflag.FlagSet) { fs.StringVar(&credsFile, "grpc_auth_static_password_file", credsFile, "JSON File to read the users/passwords from.") } @@ -66,7 +74,7 @@ func (sa *StaticAuthPlugin) Authenticate(ctx context.Context, fullMethod string) password := md["password"][0] for _, authEntry := range sa.entries { if username == authEntry.Username && password == authEntry.Password { - return ctx, nil + return newStaticAuthContext(ctx, username), nil } } return nil, status.Errorf(codes.PermissionDenied, "auth failure: caller %q provided invalid credentials", username) @@ -74,6 +82,19 @@ func (sa *StaticAuthPlugin) Authenticate(ctx context.Context, fullMethod string) return nil, status.Errorf(codes.Unauthenticated, "username and password must be provided") } +// StaticAuthUsernameFromContext returns the username authenticated by the static auth plugin and stored in the Context, if any +func StaticAuthUsernameFromContext(ctx context.Context) string { + username, ok := ctx.Value(staticAuthUsername).(string) + if ok { + return username + } + return "" +} + +func newStaticAuthContext(ctx context.Context, username string) context.Context { + return context.WithValue(ctx, staticAuthUsername, username) +} + func staticAuthPluginInitializer() (Authenticator, error) { entries := make([]StaticAuthConfigEntry, 0) if credsFile == "" { diff --git a/go/vt/vtgate/grpcvtgateservice/server.go b/go/vt/vtgate/grpcvtgateservice/server.go index edf9659c283..d012786d6eb 100644 --- a/go/vt/vtgate/grpcvtgateservice/server.go +++ b/go/vt/vtgate/grpcvtgateservice/server.go @@ -66,13 +66,13 @@ type VTGate struct { server vtgateservice.VTGateService } -// immediateCallerID tries to extract the common name as well as the (domain) subject +// immediateCallerIDFromCert tries to extract the common name as well as the (domain) subject // alternative names of the certificate that was used to connect to vtgate. // If it fails for any reason, it will return "". // That immediate caller id is then inserted into a Context, // and will be used when talking to vttablet. // vttablet in turn can use table ACLs to validate access is authorized. -func immediateCallerID(ctx context.Context) (string, []string) { +func immediateCallerIDFromCert(ctx context.Context) (string, []string) { p, ok := peer.FromContext(ctx) if !ok { return "", nil @@ -94,6 +94,13 @@ func immediateCallerID(ctx context.Context) (string, []string) { return cert.Subject.CommonName, cert.DNSNames } +func immediateCallerID(ctx context.Context) (string, []string) { + if immediate := servenv.StaticAuthUsernameFromContext(ctx); immediate != "" { + return immediate, nil + } + return immediateCallerIDFromCert(ctx) +} + // withCallerIDContext creates a context that extracts what we need // from the incoming call and can be forwarded for use when talking to vttablet. func withCallerIDContext(ctx context.Context, effectiveCallerID *vtrpcpb.CallerID) context.Context { diff --git a/test/config.json b/test/config.json index 247bb04bb8d..c8e6b18bfaf 100644 --- a/test/config.json +++ b/test/config.json @@ -900,6 +900,15 @@ "RetryMax": 1, "Tags": [] }, + "vtgate_grpc_server_auth_static": { + "File": "unused.go", + "Args": ["vitess.io/vitess/go/test/endtoend/vtgate/grpc_server_auth_static"], + "Command": [], + "Manual": false, + "Shard": "vtgate_general_heavy", + "RetryMax": 1, + "Tags": [] + }, "topo_zk2": { "File": "unused.go", "Args": ["vitess.io/vitess/go/test/endtoend/topotest/zk2", "--topo-flavor=zk2"],