diff --git a/cmd/csi-sanity/sanity_test.go b/cmd/csi-sanity/sanity_test.go index e07b324a..6c8f5e2e 100644 --- a/cmd/csi-sanity/sanity_test.go +++ b/cmd/csi-sanity/sanity_test.go @@ -36,6 +36,7 @@ var ( func init() { flag.StringVar(&config.Address, prefix+"endpoint", "", "CSI endpoint") + flag.StringVar(&config.ControllerAddress, prefix+"controllerendpoint", "", "CSI controller endpoint") flag.BoolVar(&version, prefix+"version", false, "Version of this program") flag.StringVar(&config.TargetPath, prefix+"mountdir", os.TempDir()+"/csi", "Mount point for NodePublish") flag.StringVar(&config.StagingPath, prefix+"stagingdir", os.TempDir()+"/csi", "Mount point for NodeStage if staging is supported") diff --git a/cmd/mock-driver/main.go b/cmd/mock-driver/main.go index 1b1be639..8ea891db 100644 --- a/cmd/mock-driver/main.go +++ b/cmd/mock-driver/main.go @@ -48,51 +48,136 @@ func main() { os.Exit(1) } + controllerEndpoint := os.Getenv("CSI_CONTROLLER_ENDPOINT") + if len(controllerEndpoint) == 0 { + // If empty, set to the common endpoint. + controllerEndpoint = endpoint + } + if strings.Contains(controllerEndpoint, ":") { + fmt.Println("CSI_CONTROLLER_ENDPOINT must be a unix path") + os.Exit(1) + } + // Create mock driver s := service.New(config) - servers := &driver.CSIDriverServers{ - Controller: s, - Identity: s, - Node: s, - } - d := driver.NewCSIDriver(servers) - // If creds is enabled, set the default creds. - setCreds := os.Getenv("CSI_ENABLE_CREDS") - if len(setCreds) > 0 && setCreds == "true" { - d.SetDefaultCreds() - } + if endpoint == controllerEndpoint { + servers := &driver.CSIDriverServers{ + Controller: s, + Identity: s, + Node: s, + } + d := driver.NewCSIDriver(servers) - // Listen - os.Remove(endpoint) - l, err := net.Listen("unix", endpoint) - if err != nil { - fmt.Printf("Error: Unable to listen on %s socket: %v\n", - endpoint, - err) - os.Exit(1) - } - defer os.Remove(endpoint) + // If creds is enabled, set the default creds. + setCreds := os.Getenv("CSI_ENABLE_CREDS") + if len(setCreds) > 0 && setCreds == "true" { + d.SetDefaultCreds() + } - // Start server - if err := d.Start(l); err != nil { - fmt.Printf("Error: Unable to start mock CSI server: %v\n", - err) - os.Exit(1) - } - fmt.Println("mock driver started") - - // Wait for signal - sigc := make(chan os.Signal, 1) - sigs := []os.Signal{ - syscall.SIGTERM, - syscall.SIGHUP, - syscall.SIGINT, - syscall.SIGQUIT, - } - signal.Notify(sigc, sigs...) + // Listen + os.Remove(endpoint) + os.Remove(controllerEndpoint) + l, err := net.Listen("unix", endpoint) + if err != nil { + fmt.Printf("Error: Unable to listen on %s socket: %v\n", + endpoint, + err) + os.Exit(1) + } + defer os.Remove(endpoint) - <-sigc - d.Stop() - fmt.Println("mock driver stopped") + // Start server + if err := d.Start(l); err != nil { + fmt.Printf("Error: Unable to start mock CSI server: %v\n", + err) + os.Exit(1) + } + fmt.Println("mock driver started") + + // Wait for signal + sigc := make(chan os.Signal, 1) + sigs := []os.Signal{ + syscall.SIGTERM, + syscall.SIGHUP, + syscall.SIGINT, + syscall.SIGQUIT, + } + signal.Notify(sigc, sigs...) + + <-sigc + d.Stop() + fmt.Println("mock driver stopped") + } else { + controllerServer := &driver.CSIDriverControllerServer{ + Controller: s, + Identity: s, + } + dc := driver.NewCSIDriverController(controllerServer) + + nodeServer := &driver.CSIDriverNodeServer{ + Node: s, + Identity: s, + } + dn := driver.NewCSIDriverNode(nodeServer) + + setCreds := os.Getenv("CSI_ENABLE_CREDS") + if len(setCreds) > 0 && setCreds == "true" { + dc.SetDefaultCreds() + dn.SetDefaultCreds() + } + + // Listen controller. + os.Remove(controllerEndpoint) + l, err := net.Listen("unix", controllerEndpoint) + if err != nil { + fmt.Printf("Error: Unable to listen on %s socket: %v\n", + controllerEndpoint, + err) + os.Exit(1) + } + defer os.Remove(controllerEndpoint) + + // Start controller server. + if err = dc.Start(l); err != nil { + fmt.Printf("Error: Unable to start mock CSI controller server: %v\n", + err) + os.Exit(1) + } + fmt.Println("mock controller driver started") + + // Listen node. + os.Remove(endpoint) + l, err = net.Listen("unix", endpoint) + if err != nil { + fmt.Printf("Error: Unable to listen on %s socket: %v\n", + endpoint, + err) + os.Exit(1) + } + defer os.Remove(endpoint) + + // Start node server. + if err = dn.Start(l); err != nil { + fmt.Printf("Error: Unable to start mock CSI node server: %v\n", + err) + os.Exit(1) + } + fmt.Println("mock node driver started") + + // Wait for signal + sigc := make(chan os.Signal, 1) + sigs := []os.Signal{ + syscall.SIGTERM, + syscall.SIGHUP, + syscall.SIGINT, + syscall.SIGQUIT, + } + signal.Notify(sigc, sigs...) + + <-sigc + dc.Stop() + dn.Stop() + fmt.Println("mock drivers stopped") + } } diff --git a/driver/driver-controller.go b/driver/driver-controller.go new file mode 100644 index 00000000..1d8d2bd7 --- /dev/null +++ b/driver/driver-controller.go @@ -0,0 +1,110 @@ +/* +Copyright 2019 Kubernetes Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package driver + +import ( + "context" + "net" + "sync" + + "google.golang.org/grpc/reflection" + + csi "github.com/container-storage-interface/spec/lib/go/csi" + "google.golang.org/grpc" +) + +// CSIDriverControllerServer is the Controller service component of the driver. +type CSIDriverControllerServer struct { + Controller csi.ControllerServer + Identity csi.IdentityServer +} + +// CSIDriverController is the CSI Driver Controller backend. +type CSIDriverController struct { + listener net.Listener + server *grpc.Server + controllerServer *CSIDriverControllerServer + wg sync.WaitGroup + running bool + lock sync.Mutex + creds *CSICreds +} + +func NewCSIDriverController(controllerServer *CSIDriverControllerServer) *CSIDriverController { + return &CSIDriverController{ + controllerServer: controllerServer, + } +} + +func (c *CSIDriverController) goServe(started chan<- bool) { + goServe(c.server, &c.wg, c.listener, started) +} + +func (c *CSIDriverController) Address() string { + return c.listener.Addr().String() +} + +func (c *CSIDriverController) Start(l net.Listener) error { + c.lock.Lock() + defer c.lock.Unlock() + + // Set listener. + c.listener = l + + // Create a new grpc server. + c.server = grpc.NewServer( + grpc.UnaryInterceptor(c.callInterceptor), + ) + + if c.controllerServer.Controller != nil { + csi.RegisterControllerServer(c.server, c.controllerServer.Controller) + } + if c.controllerServer.Identity != nil { + csi.RegisterIdentityServer(c.server, c.controllerServer.Identity) + } + + reflection.Register(c.server) + + waitForServer := make(chan bool) + c.goServe(waitForServer) + <-waitForServer + c.running = true + return nil +} + +func (c *CSIDriverController) Stop() { + stop(&c.lock, &c.wg, c.server, c.running) +} + +func (c *CSIDriverController) Close() { + c.server.Stop() +} + +func (c *CSIDriverController) IsRunning() bool { + c.lock.Lock() + defer c.lock.Unlock() + + return c.running +} + +func (c *CSIDriverController) SetDefaultCreds() { + setDefaultCreds(c.creds) +} + +func (c *CSIDriverController) callInterceptor(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { + return callInterceptor(ctx, c.creds, req, info, handler) +} diff --git a/driver/driver-node.go b/driver/driver-node.go new file mode 100644 index 00000000..7720bfc4 --- /dev/null +++ b/driver/driver-node.go @@ -0,0 +1,109 @@ +/* +Copyright 2019 Kubernetes Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package driver + +import ( + context "context" + "net" + "sync" + + csi "github.com/container-storage-interface/spec/lib/go/csi" + "google.golang.org/grpc" + "google.golang.org/grpc/reflection" +) + +// CSIDriverNodeServer is the Node service component of the driver. +type CSIDriverNodeServer struct { + Node csi.NodeServer + Identity csi.IdentityServer +} + +// CSIDriverNode is the CSI Driver Node backend. +type CSIDriverNode struct { + listener net.Listener + server *grpc.Server + nodeServer *CSIDriverNodeServer + wg sync.WaitGroup + running bool + lock sync.Mutex + creds *CSICreds +} + +func NewCSIDriverNode(nodeServer *CSIDriverNodeServer) *CSIDriverNode { + return &CSIDriverNode{ + nodeServer: nodeServer, + } +} + +func (c *CSIDriverNode) goServe(started chan<- bool) { + goServe(c.server, &c.wg, c.listener, started) +} + +func (c *CSIDriverNode) Address() string { + return c.listener.Addr().String() +} + +func (c *CSIDriverNode) Start(l net.Listener) error { + c.lock.Lock() + defer c.lock.Unlock() + + // Set listener. + c.listener = l + + // Create a new grpc server. + c.server = grpc.NewServer( + grpc.UnaryInterceptor(c.callInterceptor), + ) + + if c.nodeServer.Node != nil { + csi.RegisterNodeServer(c.server, c.nodeServer.Node) + } + if c.nodeServer.Identity != nil { + csi.RegisterIdentityServer(c.server, c.nodeServer.Identity) + } + + reflection.Register(c.server) + + waitForServer := make(chan bool) + c.goServe(waitForServer) + <-waitForServer + c.running = true + return nil +} + +func (c *CSIDriverNode) Stop() { + stop(&c.lock, &c.wg, c.server, c.running) +} + +func (c *CSIDriverNode) Close() { + c.server.Stop() +} + +func (c *CSIDriverNode) IsRunning() bool { + c.lock.Lock() + defer c.lock.Unlock() + + return c.running +} + +func (c *CSIDriverNode) SetDefaultCreds() { + setDefaultCreds(c.creds) +} + +func (c *CSIDriverNode) callInterceptor(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { + return callInterceptor(ctx, c.creds, req, info, handler) +} diff --git a/driver/driver.go b/driver/driver.go index 01224a3a..102bbb40 100644 --- a/driver/driver.go +++ b/driver/driver.go @@ -41,6 +41,8 @@ var ( ErrAuthFailed = errors.New("authentication failed") ) +// CSIDriverServers is a unified driver component with both Controller and Node +// services. type CSIDriverServers struct { Controller csi.ControllerServer Identity csi.IdentityServer @@ -81,15 +83,7 @@ func NewCSIDriver(servers *CSIDriverServers) *CSIDriver { } func (c *CSIDriver) goServe(started chan<- bool) { - c.wg.Add(1) - go func() { - defer c.wg.Done() - started <- true - err := c.server.Serve(c.listener) - if err != nil { - panic(err.Error()) - } - }() + goServe(c.server, &c.wg, c.listener, started) } func (c *CSIDriver) Address() string { @@ -128,15 +122,7 @@ func (c *CSIDriver) Start(l net.Listener) error { } func (c *CSIDriver) Stop() { - c.lock.Lock() - defer c.lock.Unlock() - - if !c.running { - return - } - - c.server.Stop() - c.wg.Wait() + stop(&c.lock, &c.wg, c.server, c.running) } func (c *CSIDriver) Close() { @@ -152,7 +138,42 @@ func (c *CSIDriver) IsRunning() bool { // SetDefaultCreds sets the default secrets for CSI creds. func (c *CSIDriver) SetDefaultCreds() { - c.creds = &CSICreds{ + setDefaultCreds(c.creds) +} + +func (c *CSIDriver) callInterceptor(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { + return callInterceptor(ctx, c.creds, req, info, handler) +} + +// goServe starts a grpc server. +func goServe(server *grpc.Server, wg *sync.WaitGroup, listener net.Listener, started chan<- bool) { + wg.Add(1) + go func() { + defer wg.Done() + started <- true + err := server.Serve(listener) + if err != nil { + panic(err.Error()) + } + }() +} + +// stop stops a grpc server. +func stop(lock *sync.Mutex, wg *sync.WaitGroup, server *grpc.Server, running bool) { + lock.Lock() + defer lock.Unlock() + + if !running { + return + } + + server.Stop() + wg.Wait() +} + +// setDefaultCreds sets the default credentials, given a CSICreds instance. +func setDefaultCreds(creds *CSICreds) { + creds = &CSICreds{ CreateVolumeSecret: "secretval1", DeleteVolumeSecret: "secretval2", ControllerPublishVolumeSecret: "secretval3", @@ -164,8 +185,8 @@ func (c *CSIDriver) SetDefaultCreds() { } } -func (c *CSIDriver) callInterceptor(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { - err := c.authInterceptor(req) +func callInterceptor(ctx context.Context, creds *CSICreds, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { + err := authInterceptor(creds, req) if err != nil { logGRPC(info.FullMethod, req, nil, err) return nil, err @@ -175,9 +196,9 @@ func (c *CSIDriver) callInterceptor(ctx context.Context, req interface{}, info * return rsp, err } -func (c *CSIDriver) authInterceptor(req interface{}) error { - if c.creds != nil { - authenticated, authErr := isAuthenticated(req, c.creds) +func authInterceptor(creds *CSICreds, req interface{}) error { + if creds != nil { + authenticated, authErr := isAuthenticated(req, creds) if !authenticated { if authErr == ErrNoCredentials { return status.Error(codes.InvalidArgument, authErr.Error()) diff --git a/hack/e2e.sh b/hack/e2e.sh index 5ab7f7bb..419405c4 100755 --- a/hack/e2e.sh +++ b/hack/e2e.sh @@ -2,6 +2,8 @@ TESTARGS=$@ UDS="/tmp/e2e-csi-sanity.sock" +UDS_NODE="/tmp/e2e-csi-sanity-node.sock" +UDS_CONTROLLER="/tmp/e2e-csi-sanity-ctrl.sock" CSI_ENDPOINTS="$CSI_ENDPOINTS ${UDS}" CSI_MOCK_VERSION="master" @@ -22,6 +24,19 @@ runTest() fi } +runTestWithDifferentAddresses() +{ + CSI_ENDPOINT=$1 CSI_CONTROLLER_ENDPOINT=$2 ./bin/mock-driver & + local pid=$! + + ./cmd/csi-sanity/csi-sanity $TESTARGS --csi.endpoint=$1 --csi.controllerendpoint=$2; ret=$? + kill -9 $pid + + if [ $ret -ne 0 ] ; then + exit $ret + fi +} + runTestWithCreds() { CSI_ENDPOINT=$1 CSI_ENABLE_CREDS=true ./bin/mock-driver & @@ -69,4 +84,8 @@ rm -f $UDS runTestAPI "${UDS}" rm -f $UDS +runTestWithDifferentAddresses "${UDS_NODE}" "${UDS_CONTROLLER}" +rm -f $UDS_NODE +rm -f $UDS_CONTROLLER + exit 0 diff --git a/pkg/sanity/controller.go b/pkg/sanity/controller.go index 99b738bd..6cbabc67 100644 --- a/pkg/sanity/controller.go +++ b/pkg/sanity/controller.go @@ -89,7 +89,7 @@ var _ = DescribeSanity("Controller Service", func(sc *SanityContext) { ) BeforeEach(func() { - c = csi.NewControllerClient(sc.Conn) + c = csi.NewControllerClient(sc.ControllerConn) n = csi.NewNodeClient(sc.Conn) cl = &Cleanup{ @@ -1259,7 +1259,7 @@ var _ = DescribeSanity("ListSnapshots [Controller Server]", func(sc *SanityConte ) BeforeEach(func() { - c = csi.NewControllerClient(sc.Conn) + c = csi.NewControllerClient(sc.ControllerConn) if !isControllerCapabilitySupported(c, csi.ControllerServiceCapability_RPC_LIST_SNAPSHOTS) { Skip("ListSnapshots not supported") @@ -1512,7 +1512,7 @@ var _ = DescribeSanity("DeleteSnapshot [Controller Server]", func(sc *SanityCont ) BeforeEach(func() { - c = csi.NewControllerClient(sc.Conn) + c = csi.NewControllerClient(sc.ControllerConn) if !isControllerCapabilitySupported(c, csi.ControllerServiceCapability_RPC_CREATE_DELETE_SNAPSHOT) { Skip("DeleteSnapshot not supported") @@ -1575,7 +1575,7 @@ var _ = DescribeSanity("CreateSnapshot [Controller Server]", func(sc *SanityCont ) BeforeEach(func() { - c = csi.NewControllerClient(sc.Conn) + c = csi.NewControllerClient(sc.ControllerConn) if !isControllerCapabilitySupported(c, csi.ControllerServiceCapability_RPC_CREATE_DELETE_SNAPSHOT) { Skip("CreateSnapshot not supported") diff --git a/pkg/sanity/node.go b/pkg/sanity/node.go index 7cb570e3..bd706c51 100644 --- a/pkg/sanity/node.go +++ b/pkg/sanity/node.go @@ -80,7 +80,7 @@ var _ = DescribeSanity("Node Service", func(sc *SanityContext) { BeforeEach(func() { c = csi.NewNodeClient(sc.Conn) - s = csi.NewControllerClient(sc.Conn) + s = csi.NewControllerClient(sc.ControllerConn) controllerPublishSupported = isControllerCapabilitySupported( s, diff --git a/pkg/sanity/sanity.go b/pkg/sanity/sanity.go index 0a621011..f6ff707d 100644 --- a/pkg/sanity/sanity.go +++ b/pkg/sanity/sanity.go @@ -48,10 +48,11 @@ type CSISecrets struct { // Config provides the configuration for the sanity tests. It // needs to be initialized by the user of the sanity package. type Config struct { - TargetPath string - StagingPath string - Address string - SecretsFile string + TargetPath string + StagingPath string + Address string + ControllerAddress string + SecretsFile string TestVolumeSize int64 TestVolumeParametersFile string @@ -64,11 +65,13 @@ type Config struct { // SanityContext holds the variables that each test can depend on. It // gets initialized before each test block runs. type SanityContext struct { - Config *Config - Conn *grpc.ClientConn - Secrets *CSISecrets + Config *Config + Conn *grpc.ClientConn + ControllerConn *grpc.ClientConn + Secrets *CSISecrets - connAddress string + connAddress string + controllerConnAddress string } // Test will test the CSI driver at the specified address by @@ -135,6 +138,20 @@ func (sc *SanityContext) setup() { By(fmt.Sprintf("reusing connection to CSI driver at %s", sc.connAddress)) } + if sc.ControllerConn == nil || sc.controllerConnAddress != sc.Config.ControllerAddress { + // If controller address is empty, use the common connection. + if sc.Config.ControllerAddress == "" { + sc.ControllerConn = sc.Conn + sc.controllerConnAddress = sc.Config.Address + } else { + sc.ControllerConn, err = utils.Connect(sc.Config.ControllerAddress) + Expect(err).NotTo(HaveOccurred()) + sc.controllerConnAddress = sc.Config.ControllerAddress + } + } else { + By(fmt.Sprintf("reusing connection to CSI driver controller at %s", sc.controllerConnAddress)) + } + By("creating mount and staging directories") err = createMountTargetLocation(sc.Config.TargetPath) Expect(err).NotTo(HaveOccurred())