diff --git a/docs/fast-snapshot-restores.md b/docs/fast-snapshot-restores.md new file mode 100644 index 0000000000..fe45a04f61 --- /dev/null +++ b/docs/fast-snapshot-restores.md @@ -0,0 +1,39 @@ +# Fast Snapshot Restores + +The EBS CSI Driver provides support for [Fast Snapshot Restores(FSR)](https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/ebs-fast-snapshot-restore.html) via `VolumeSnapshotClass.parameters.fastSnapshotRestoreAvailabilityZones`. + +Amazon EBS fast snapshot restore (FSR) enables you to create a volume from a snapshot that is fully initialized at creation. This eliminates the latency of I/O operations on a block when it is accessed for the first time. Volumes that are created using fast snapshot restore instantly deliver all of their provisioned performance. + +Availability zones are specified as a comma separated list. + +**Example** +``` +apiVersion: snapshot.storage.k8s.io/v1 +kind: VolumeSnapshotClass +metadata: + name: csi-aws-vsc +driver: ebs.csi.aws.com +deletionPolicy: Delete +parameters: + fastSnapshotRestoreAvailabilityZones: "us-east-1a, us-east-1b" +``` + +## Prerequisites + +- Install the [Kubernetes Volume Snapshot CRDs](https://github.com/kubernetes-csi/external-snapshotter/tree/master/client/config/crd) and external-snapshotter sidecar. For installation instructions, see [CSI Snapshotter Usage](https://github.com/kubernetes-csi/external-snapshotter#usage). + +- The EBS CSI Driver must be given permission to access the [`EnableFastSnapshotRestores` EC2 API](https://docs.aws.amazon.com/AWSEC2/latest/APIReference/API_EnableFastSnapshotRestores.html). This example snippet can be used in an IAM policy to grant access to `EnableFastSnapshotRestores`: + +```json +{ + "Effect": "Allow", + "Action": [ + "ec2:EnableFastSnapshotRestores" + ], + "Resource": "*" +} +``` + +## Failure Mode + +The driver will attempt to check if the availability zones provided are supported for fast snapshot restore before attempting to create the snapshot. If the `EnableFastSnapshotRestores` API call fails, the driver will hard-fail the request and delete the snapshot. This is to ensure that the snapshot is not left in an inconsistent state. diff --git a/pkg/cloud/cloud.go b/pkg/cloud/cloud.go index a54a0671cb..9e61c93b5f 100644 --- a/pkg/cloud/cloud.go +++ b/pkg/cloud/cloud.go @@ -874,6 +874,24 @@ func (c *cloud) ec2SnapshotResponseToStruct(ec2Snapshot *ec2.Snapshot) *Snapshot return snapshot } +func (c *cloud) EnableFastSnapshotRestores(ctx context.Context, availabilityZones []string, snapshotID string) (*ec2.EnableFastSnapshotRestoresOutput, error) { + request := &ec2.EnableFastSnapshotRestoresInput{ + AvailabilityZones: aws.StringSlice(availabilityZones), + SourceSnapshotIds: []*string{ + aws.String(snapshotID), + }, + } + klog.V(4).InfoS("Creating Fast Snapshot Restores", "snapshotID", snapshotID, "availabilityZones", availabilityZones) + response, err := c.ec2.EnableFastSnapshotRestoresWithContext(ctx, request) + if err != nil { + return nil, err + } + if len(response.Unsuccessful) > 0 { + return response, fmt.Errorf("failed to create fast snapshot restores for snapshot %s: %v", snapshotID, response.Unsuccessful) + } + return response, nil +} + func (c *cloud) getVolume(ctx context.Context, request *ec2.DescribeVolumesInput) (*ec2.Volume, error) { var volumes []*ec2.Volume var nextToken *string @@ -1236,6 +1254,19 @@ func (c *cloud) randomAvailabilityZone(ctx context.Context) (string, error) { return zones[0], nil } +// AvailabilityZones returns availability zones from the given region +func (c *cloud) AvailabilityZones(ctx context.Context) (map[string]struct{}, error) { + response, err := c.ec2.DescribeAvailabilityZonesWithContext(ctx, &ec2.DescribeAvailabilityZonesInput{}) + if err != nil { + return nil, fmt.Errorf("error describing availability zones: %w", err) + } + zones := make(map[string]struct{}) + for _, zone := range response.AvailabilityZones { + zones[*zone.ZoneName] = struct{}{} + } + return zones, nil +} + func volumeModificationDone(state string) bool { if state == ec2.VolumeModificationStateCompleted || state == ec2.VolumeModificationStateOptimizing { return true diff --git a/pkg/cloud/cloud_interface.go b/pkg/cloud/cloud_interface.go index aedf6ae654..c36d4b6ae1 100644 --- a/pkg/cloud/cloud_interface.go +++ b/pkg/cloud/cloud_interface.go @@ -21,4 +21,6 @@ type Cloud interface { GetSnapshotByName(ctx context.Context, name string) (snapshot *Snapshot, err error) GetSnapshotByID(ctx context.Context, snapshotID string) (snapshot *Snapshot, err error) ListSnapshots(ctx context.Context, volumeID string, maxResults int64, nextToken string) (listSnapshotsResponse *ListSnapshotsResponse, err error) + EnableFastSnapshotRestores(ctx context.Context, availabilityZones []string, snapshotID string) (*ec2.EnableFastSnapshotRestoresOutput, error) + AvailabilityZones(ctx context.Context) (map[string]struct{}, error) } diff --git a/pkg/cloud/cloud_test.go b/pkg/cloud/cloud_test.go index 975ca36c72..2c2c9a8a03 100644 --- a/pkg/cloud/cloud_test.go +++ b/pkg/cloud/cloud_test.go @@ -1061,6 +1061,142 @@ func TestCreateSnapshot(t *testing.T) { } } +func TestEnableFastSnapshotRestores(t *testing.T) { + testCases := []struct { + name string + snapshotID string + availabilityZones []string + expOutput *ec2.EnableFastSnapshotRestoresOutput + expErr error + }{ + { + name: "success: normal", + snapshotID: "snap-test-id", + availabilityZones: []string{"us-west-2a", "us-west-2b"}, + expOutput: &ec2.EnableFastSnapshotRestoresOutput{ + Successful: []*ec2.EnableFastSnapshotRestoreSuccessItem{{ + AvailabilityZone: aws.String("us-west-2a,us-west-2b"), + SnapshotId: aws.String("snap-test-id")}}, + Unsuccessful: []*ec2.EnableFastSnapshotRestoreErrorItem{}, + }, + expErr: nil, + }, + { + name: "fail: unsuccessful response", + snapshotID: "snap-test-id", + availabilityZones: []string{"us-west-2a", "invalid-zone"}, + expOutput: &ec2.EnableFastSnapshotRestoresOutput{ + Unsuccessful: []*ec2.EnableFastSnapshotRestoreErrorItem{{ + SnapshotId: aws.String("snap-test-id"), + FastSnapshotRestoreStateErrors: []*ec2.EnableFastSnapshotRestoreStateErrorItem{ + {AvailabilityZone: aws.String("us-west-2a,invalid-zone"), + Error: &ec2.EnableFastSnapshotRestoreStateError{ + Message: aws.String("failed to create fast snapshot restore")}}, + }, + }}, + }, + expErr: fmt.Errorf("failed to create fast snapshot restores for snapshot"), + }, + { + name: "fail: error", + snapshotID: "", + availabilityZones: nil, + expOutput: nil, + expErr: fmt.Errorf("EnableFastSnapshotRestores error"), + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + mockCtrl := gomock.NewController(t) + mockEC2 := NewMockEC2(mockCtrl) + c := newCloud(mockEC2) + + ctx := context.Background() + mockEC2.EXPECT().EnableFastSnapshotRestoresWithContext(gomock.Eq(ctx), gomock.Any()).Return(tc.expOutput, tc.expErr).AnyTimes() + + response, err := c.EnableFastSnapshotRestores(ctx, tc.availabilityZones, tc.snapshotID) + + if err != nil { + if tc.expErr == nil { + t.Fatalf("EnableFastSnapshotRestores() failed: expected no error, got: %v", err) + } + if err.Error() != tc.expErr.Error() { + t.Fatalf("EnableFastSnapshotRestores() failed: expected error %v, got %v", tc.expErr, err) + } + } else { + if tc.expErr != nil { + t.Fatalf("EnableFastSnapshotRestores() failed: expected error %v, got nothing", tc.expErr) + } + if len(response.Successful) == 0 || len(response.Unsuccessful) > 0 { + t.Fatalf("EnableFastSnapshotRestores() failed: expected successful response, got %v", response) + } + if *response.Successful[0].SnapshotId != tc.snapshotID { + t.Fatalf("EnableFastSnapshotRestores() failed: expected successful response to have SnapshotId %s, got %s", tc.snapshotID, *response.Successful[0].SnapshotId) + } + az := strings.Split(*response.Successful[0].AvailabilityZone, ",") + if !reflect.DeepEqual(az, tc.availabilityZones) { + t.Fatalf("EnableFastSnapshotRestores() failed: expected successful response to have AvailabilityZone %v, got %v", az, tc.availabilityZones) + } + } + + mockCtrl.Finish() + }) + } +} + +func TestAvailabilityZones(t *testing.T) { + testCases := []struct { + name string + availabilityZone string + expOutput *ec2.DescribeAvailabilityZonesOutput + expErr error + }{ + { + name: "success: normal", + availabilityZone: expZone, + expOutput: &ec2.DescribeAvailabilityZonesOutput{ + AvailabilityZones: []*ec2.AvailabilityZone{ + {ZoneName: aws.String(expZone)}, + }}, + expErr: nil, + }, + { + name: "fail: error", + availabilityZone: "", + expOutput: nil, + expErr: fmt.Errorf("TestAvailabilityZones error"), + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + mockCtrl := gomock.NewController(t) + mockEC2 := NewMockEC2(mockCtrl) + c := newCloud(mockEC2) + + ctx := context.Background() + mockEC2.EXPECT().DescribeAvailabilityZonesWithContext(gomock.Eq(ctx), gomock.Any()).Return(tc.expOutput, tc.expErr).AnyTimes() + + az, err := c.AvailabilityZones(ctx) + if err != nil { + if tc.expErr == nil { + t.Fatalf("AvailabilityZones() failed: expected no error, got: %v", err) + } + } else { + if tc.expErr != nil { + t.Fatalf("AvailabilityZones() failed: expected error, got nothing") + } + if val, ok := az[tc.availabilityZone]; !ok { + t.Fatalf("AvailabilityZones() failed: expected to find %s, got %v", tc.availabilityZone, val) + } + } + + mockCtrl.Finish() + }) + } +} + func TestDeleteSnapshot(t *testing.T) { testCases := []struct { name string diff --git a/pkg/cloud/ec2_interface.go b/pkg/cloud/ec2_interface.go index 65a2fb3bbd..c3f808cba9 100644 --- a/pkg/cloud/ec2_interface.go +++ b/pkg/cloud/ec2_interface.go @@ -38,4 +38,5 @@ type EC2 interface { DescribeVolumesModificationsWithContext(ctx aws.Context, input *ec2.DescribeVolumesModificationsInput, opts ...request.Option) (*ec2.DescribeVolumesModificationsOutput, error) DescribeAvailabilityZonesWithContext(ctx aws.Context, input *ec2.DescribeAvailabilityZonesInput, opts ...request.Option) (*ec2.DescribeAvailabilityZonesOutput, error) CreateTagsWithContext(ctx aws.Context, input *ec2.CreateTagsInput, opts ...request.Option) (*ec2.CreateTagsOutput, error) + EnableFastSnapshotRestoresWithContext(ctx aws.Context, input *ec2.EnableFastSnapshotRestoresInput, opts ...request.Option) (*ec2.EnableFastSnapshotRestoresOutput, error) } diff --git a/pkg/cloud/mock_cloud.go b/pkg/cloud/mock_cloud.go index e0d6ac261d..edea607d0d 100644 --- a/pkg/cloud/mock_cloud.go +++ b/pkg/cloud/mock_cloud.go @@ -50,6 +50,21 @@ func (mr *MockCloudMockRecorder) AttachDisk(ctx, volumeID, nodeID interface{}) * return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AttachDisk", reflect.TypeOf((*MockCloud)(nil).AttachDisk), ctx, volumeID, nodeID) } +// AvailabilityZones mocks base method. +func (m *MockCloud) AvailabilityZones(ctx context.Context) (map[string]struct{}, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "AvailabilityZones", ctx) + ret0, _ := ret[0].(map[string]struct{}) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// AvailabilityZones indicates an expected call of AvailabilityZones. +func (mr *MockCloudMockRecorder) AvailabilityZones(ctx interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AvailabilityZones", reflect.TypeOf((*MockCloud)(nil).AvailabilityZones), ctx) +} + // CreateDisk mocks base method. func (m *MockCloud) CreateDisk(ctx context.Context, volumeName string, diskOptions *DiskOptions) (*Disk, error) { m.ctrl.T.Helper() @@ -124,6 +139,21 @@ func (mr *MockCloudMockRecorder) DetachDisk(ctx, volumeID, nodeID interface{}) * return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DetachDisk", reflect.TypeOf((*MockCloud)(nil).DetachDisk), ctx, volumeID, nodeID) } +// EnableFastSnapshotRestores mocks base method. +func (m *MockCloud) EnableFastSnapshotRestores(ctx context.Context, availabilityZones []string, snapshotID string) (*ec2.EnableFastSnapshotRestoresOutput, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "EnableFastSnapshotRestores", ctx, availabilityZones, snapshotID) + ret0, _ := ret[0].(*ec2.EnableFastSnapshotRestoresOutput) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// EnableFastSnapshotRestores indicates an expected call of EnableFastSnapshotRestores. +func (mr *MockCloudMockRecorder) EnableFastSnapshotRestores(ctx, availabilityZones, snapshotID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "EnableFastSnapshotRestores", reflect.TypeOf((*MockCloud)(nil).EnableFastSnapshotRestores), ctx, availabilityZones, snapshotID) +} + // GetDiskByID mocks base method. func (m *MockCloud) GetDiskByID(ctx context.Context, volumeID string) (*Disk, error) { m.ctrl.T.Helper() diff --git a/pkg/cloud/mock_ec2.go b/pkg/cloud/mock_ec2.go index 514437fb05..56752b0e1e 100644 --- a/pkg/cloud/mock_ec2.go +++ b/pkg/cloud/mock_ec2.go @@ -276,6 +276,26 @@ func (mr *MockEC2MockRecorder) DetachVolumeWithContext(ctx, input interface{}, o return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DetachVolumeWithContext", reflect.TypeOf((*MockEC2)(nil).DetachVolumeWithContext), varargs...) } +// EnableFastSnapshotRestoresWithContext mocks base method. +func (m *MockEC2) EnableFastSnapshotRestoresWithContext(ctx aws.Context, input *ec2.EnableFastSnapshotRestoresInput, opts ...request.Option) (*ec2.EnableFastSnapshotRestoresOutput, error) { + m.ctrl.T.Helper() + varargs := []interface{}{ctx, input} + for _, a := range opts { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "EnableFastSnapshotRestoresWithContext", varargs...) + ret0, _ := ret[0].(*ec2.EnableFastSnapshotRestoresOutput) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// EnableFastSnapshotRestoresWithContext indicates an expected call of EnableFastSnapshotRestoresWithContext. +func (mr *MockEC2MockRecorder) EnableFastSnapshotRestoresWithContext(ctx, input interface{}, opts ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{ctx, input}, opts...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "EnableFastSnapshotRestoresWithContext", reflect.TypeOf((*MockEC2)(nil).EnableFastSnapshotRestoresWithContext), varargs...) +} + // ModifyVolumeWithContext mocks base method. func (m *MockEC2) ModifyVolumeWithContext(ctx aws.Context, input *ec2.ModifyVolumeInput, opts ...request.Option) (*ec2.ModifyVolumeOutput, error) { m.ctrl.T.Helper() diff --git a/pkg/driver/constants.go b/pkg/driver/constants.go index 5b262259ed..ccbc07d630 100644 --- a/pkg/driver/constants.go +++ b/pkg/driver/constants.go @@ -80,6 +80,12 @@ const ( TagKeyPrefix = "tagSpecification" ) +// constants of keys in snapshot parameters +const ( + // FastSnapShotRestoreAvailabilityZones represents key for fast snapshot restore availability zones + FastSnapshotRestoreAvailabilityZones = "fastsnapshotrestoreavailabilityzones" +) + // constants for volume tags and their values const ( // ResourceLifecycleTagPrefix is prefix of tag for provisioned EBS volume that diff --git a/pkg/driver/controller.go b/pkg/driver/controller.go index aa6093eef9..ae5cdf80dd 100644 --- a/pkg/driver/controller.go +++ b/pkg/driver/controller.go @@ -605,11 +605,18 @@ func (d *controllerService) CreateSnapshot(ctx context.Context, req *csi.CreateS } var vscTags []string + var fsrAvailabilityZones []string for key, value := range req.GetParameters() { - if strings.HasPrefix(key, TagKeyPrefix) { - vscTags = append(vscTags, value) - } else { - return nil, status.Errorf(codes.InvalidArgument, "Invalid parameter key %s for CreateSnapshot", key) + switch strings.ToLower(key) { + case FastSnapshotRestoreAvailabilityZones: + f := strings.ReplaceAll(value, " ", "") + fsrAvailabilityZones = strings.Split(f, ",") + default: + if strings.HasPrefix(key, TagKeyPrefix) { + vscTags = append(vscTags, value) + } else { + return nil, status.Errorf(codes.InvalidArgument, "Invalid parameter key %s for CreateSnapshot", key) + } } } @@ -639,11 +646,35 @@ func (d *controllerService) CreateSnapshot(ctx context.Context, req *csi.CreateS Tags: snapshotTags, } - snapshot, err = d.cloud.CreateSnapshot(ctx, volumeID, opts) + // Check if the availability zone is supported for fast snapshot restore + if len(fsrAvailabilityZones) > 0 { + zones, error := d.cloud.AvailabilityZones(ctx) + if error != nil { + klog.ErrorS(error, "failed to get availability zones") + } else { + klog.V(4).InfoS("Availability Zones", "zone", zones) + for _, az := range fsrAvailabilityZones { + if _, ok := zones[az]; !ok { + return nil, status.Errorf(codes.InvalidArgument, "Availability zone %s is not supported for fast snapshot restore", az) + } + } + } + } + snapshot, err = d.cloud.CreateSnapshot(ctx, volumeID, opts) if err != nil { return nil, status.Errorf(codes.Internal, "Could not create snapshot %q: %v", snapshotName, err) } + + if len(fsrAvailabilityZones) > 0 { + _, err := d.cloud.EnableFastSnapshotRestores(ctx, fsrAvailabilityZones, snapshot.SnapshotID) + if err != nil { + if _, err = d.cloud.DeleteSnapshot(ctx, snapshot.SnapshotID); err != nil { + return nil, status.Errorf(codes.Internal, "Could not delete snapshot ID %q: %v", snapshotName, err) + } + return nil, status.Errorf(codes.Internal, "Failed to create Fast Snapshot Restores for snapshot ID %q: %v", snapshotName, err) + } + } return newCreateSnapshotResponse(snapshot) } diff --git a/pkg/driver/controller_test.go b/pkg/driver/controller_test.go index de957e9258..8176c88433 100644 --- a/pkg/driver/controller_test.go +++ b/pkg/driver/controller_test.go @@ -27,7 +27,9 @@ import ( "testing" "time" + "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/arn" + "github.com/aws/aws-sdk-go/service/ec2" "github.com/container-storage-interface/spec/lib/go/csi" "github.com/golang/mock/gomock" "github.com/kubernetes-sigs/aws-ebs-csi-driver/pkg/cloud" @@ -2461,6 +2463,289 @@ func TestCreateSnapshot(t *testing.T) { } }, }, + { + name: "success with EnableFastSnapshotRestore - normal", + testFunc: func(t *testing.T) { + const ( + snapshotName = "test-snapshot" + ) + + req := &csi.CreateSnapshotRequest{ + Name: snapshotName, + Parameters: map[string]string{ + "fastSnapshotRestoreAvailabilityZones": "us-east-1a, us-east-1f", + }, + SourceVolumeId: "vol-test", + } + expSnapshot := &csi.Snapshot{ + ReadyToUse: true, + } + + ctx := context.Background() + mockSnapshot := &cloud.Snapshot{ + SnapshotID: fmt.Sprintf("snapshot-%d", rand.New(rand.NewSource(time.Now().UnixNano())).Uint64()), + SourceVolumeID: req.SourceVolumeId, + Size: 1, + CreationTime: time.Now(), + } + mockCtl := gomock.NewController(t) + defer mockCtl.Finish() + + snapshotOptions := &cloud.SnapshotOptions{ + Tags: map[string]string{ + cloud.SnapshotNameTagKey: snapshotName, + cloud.AwsEbsDriverTagKey: isManagedByDriver, + }, + } + + expOutput := &ec2.EnableFastSnapshotRestoresOutput{ + Successful: []*ec2.EnableFastSnapshotRestoreSuccessItem{{ + AvailabilityZone: aws.String("us-east-1a,us-east-1f"), + SnapshotId: aws.String("snap-test-id")}}, + Unsuccessful: []*ec2.EnableFastSnapshotRestoreErrorItem{}, + } + + mockCloud := cloud.NewMockCloud(mockCtl) + mockCloud.EXPECT().GetSnapshotByName(gomock.Eq(ctx), gomock.Eq(req.GetName())).Return(nil, cloud.ErrNotFound).AnyTimes() + mockCloud.EXPECT().AvailabilityZones(gomock.Eq(ctx)).Return(map[string]struct{}{ + "us-east-1a": {}, "us-east-1f": {}}, nil).AnyTimes() + mockCloud.EXPECT().CreateSnapshot(gomock.Eq(ctx), gomock.Eq(req.SourceVolumeId), gomock.Eq(snapshotOptions)).Return(mockSnapshot, nil).AnyTimes() + mockCloud.EXPECT().EnableFastSnapshotRestores(gomock.Eq(ctx), gomock.Eq([]string{"us-east-1a", "us-east-1f"}), gomock.Eq(mockSnapshot.SnapshotID)).Return(expOutput, nil).AnyTimes() + + awsDriver := controllerService{ + cloud: mockCloud, + inFlight: internal.NewInFlight(), + driverOptions: &DriverOptions{}, + } + + resp, err := awsDriver.CreateSnapshot(context.Background(), req) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + if snap := resp.GetSnapshot(); snap == nil { + t.Fatalf("Expected snapshot %v, got nil", expSnapshot) + } + }, + }, + { + name: "success with EnableFastSnapshotRestore - failed to get availability zones", + testFunc: func(t *testing.T) { + const ( + snapshotName = "test-snapshot" + ) + + req := &csi.CreateSnapshotRequest{ + Name: snapshotName, + Parameters: map[string]string{ + "fastSnapshotRestoreAvailabilityZones": "us-east-1a, us-east-1f", + }, + SourceVolumeId: "vol-test", + } + expSnapshot := &csi.Snapshot{ + ReadyToUse: true, + } + + ctx := context.Background() + mockSnapshot := &cloud.Snapshot{ + SnapshotID: fmt.Sprintf("snapshot-%d", rand.New(rand.NewSource(time.Now().UnixNano())).Uint64()), + SourceVolumeID: req.SourceVolumeId, + Size: 1, + CreationTime: time.Now(), + } + mockCtl := gomock.NewController(t) + defer mockCtl.Finish() + + snapshotOptions := &cloud.SnapshotOptions{ + Tags: map[string]string{ + cloud.SnapshotNameTagKey: snapshotName, + cloud.AwsEbsDriverTagKey: isManagedByDriver, + }, + } + + expOutput := &ec2.EnableFastSnapshotRestoresOutput{ + Successful: []*ec2.EnableFastSnapshotRestoreSuccessItem{{ + AvailabilityZone: aws.String("us-east-1a,us-east-1f"), + SnapshotId: aws.String("snap-test-id")}}, + Unsuccessful: []*ec2.EnableFastSnapshotRestoreErrorItem{}, + } + + mockCloud := cloud.NewMockCloud(mockCtl) + mockCloud.EXPECT().GetSnapshotByName(gomock.Eq(ctx), gomock.Eq(req.GetName())).Return(nil, cloud.ErrNotFound).AnyTimes() + mockCloud.EXPECT().AvailabilityZones(gomock.Eq(ctx)).Return(nil, fmt.Errorf("error describing availability zones")).AnyTimes() + mockCloud.EXPECT().CreateSnapshot(gomock.Eq(ctx), gomock.Eq(req.SourceVolumeId), gomock.Eq(snapshotOptions)).Return(mockSnapshot, nil).AnyTimes() + mockCloud.EXPECT().EnableFastSnapshotRestores(gomock.Eq(ctx), gomock.Eq([]string{"us-east-1a", "us-east-1f"}), gomock.Eq(mockSnapshot.SnapshotID)).Return(expOutput, nil).AnyTimes() + + awsDriver := controllerService{ + cloud: mockCloud, + inFlight: internal.NewInFlight(), + driverOptions: &DriverOptions{}, + } + + resp, err := awsDriver.CreateSnapshot(context.Background(), req) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + if snap := resp.GetSnapshot(); snap == nil { + t.Fatalf("Expected snapshot %v, got nil", expSnapshot) + } + }, + }, + { + name: "fail with EnableFastSnapshotRestore - call to enable FSR failed", + testFunc: func(t *testing.T) { + const ( + snapshotName = "test-snapshot" + ) + + req := &csi.CreateSnapshotRequest{ + Name: snapshotName, + Parameters: map[string]string{ + "fastSnapshotRestoreAvailabilityZones": "us-west-1a, us-east-1f", + }, + SourceVolumeId: "vol-test", + } + + ctx := context.Background() + mockSnapshot := &cloud.Snapshot{ + SnapshotID: fmt.Sprintf("snapshot-%d", rand.New(rand.NewSource(time.Now().UnixNano())).Uint64()), + SourceVolumeID: req.SourceVolumeId, + Size: 1, + CreationTime: time.Now(), + } + mockCtl := gomock.NewController(t) + defer mockCtl.Finish() + + snapshotOptions := &cloud.SnapshotOptions{ + Tags: map[string]string{ + cloud.SnapshotNameTagKey: snapshotName, + cloud.AwsEbsDriverTagKey: isManagedByDriver, + }, + } + expOutput := &ec2.EnableFastSnapshotRestoresOutput{ + Successful: []*ec2.EnableFastSnapshotRestoreSuccessItem{}, + Unsuccessful: []*ec2.EnableFastSnapshotRestoreErrorItem{{ + SnapshotId: aws.String("snap-test-id"), + FastSnapshotRestoreStateErrors: []*ec2.EnableFastSnapshotRestoreStateErrorItem{ + { + AvailabilityZone: aws.String("us-west-1a,us-east-1f"), + Error: &ec2.EnableFastSnapshotRestoreStateError{ + Message: aws.String("failed to create fast snapshot restore"), + }}, + }, + }}, + } + + mockCloud := cloud.NewMockCloud(mockCtl) + mockCloud.EXPECT().GetSnapshotByName(gomock.Eq(ctx), gomock.Eq(req.GetName())).Return(nil, cloud.ErrNotFound).AnyTimes() + mockCloud.EXPECT().AvailabilityZones(gomock.Eq(ctx)).Return(nil, fmt.Errorf("error describing availability zones")).AnyTimes() + mockCloud.EXPECT().CreateSnapshot(gomock.Eq(ctx), gomock.Eq(req.SourceVolumeId), gomock.Eq(snapshotOptions)).Return(mockSnapshot, nil).AnyTimes() + mockCloud.EXPECT().EnableFastSnapshotRestores(gomock.Eq(ctx), gomock.Eq([]string{"us-west-1a", "us-east-1f"}), gomock.Eq(mockSnapshot.SnapshotID)). + Return(expOutput, fmt.Errorf("Failed to create Fast Snapshot Restores")).AnyTimes() + mockCloud.EXPECT().DeleteSnapshot(gomock.Eq(ctx), gomock.Eq(mockSnapshot.SnapshotID)).Return(true, nil).AnyTimes() + + awsDriver := controllerService{ + cloud: mockCloud, + inFlight: internal.NewInFlight(), + driverOptions: &DriverOptions{}, + } + + _, err := awsDriver.CreateSnapshot(context.Background(), req) + if err == nil { + t.Fatalf("Expected error, got nil") + } + }, + }, + { + name: "fail with EnableFastSnapshotRestore - invalid availability zones", + testFunc: func(t *testing.T) { + const ( + snapshotName = "test-snapshot" + ) + + req := &csi.CreateSnapshotRequest{ + Name: snapshotName, + Parameters: map[string]string{ + "fastSnapshotRestoreAvailabilityZones": "invalid-az, us-east-1b", + }, + SourceVolumeId: "vol-test", + } + + ctx := context.Background() + mockCtl := gomock.NewController(t) + defer mockCtl.Finish() + + mockCloud := cloud.NewMockCloud(mockCtl) + mockCloud.EXPECT().GetSnapshotByName(gomock.Eq(ctx), gomock.Eq(req.GetName())).Return(nil, cloud.ErrNotFound).AnyTimes() + mockCloud.EXPECT().AvailabilityZones(gomock.Eq(ctx)).Return(map[string]struct{}{ + "us-east-1a": {}, "us-east-1b": {}}, nil).AnyTimes() + + awsDriver := controllerService{ + cloud: mockCloud, + inFlight: internal.NewInFlight(), + driverOptions: &DriverOptions{}, + } + + _, err := awsDriver.CreateSnapshot(context.Background(), req) + if err == nil { + t.Fatalf("Expected error, got nil") + } + }, + }, + { + name: "fail with EnableFastSnapshotRestore", + testFunc: func(t *testing.T) { + const ( + snapshotName = "test-snapshot" + ) + + req := &csi.CreateSnapshotRequest{ + Name: snapshotName, + Parameters: map[string]string{ + "fastSnapshotRestoreAvailabilityZones": "us-east-1a, us-east-1f", + }, + SourceVolumeId: "vol-test", + } + + ctx := context.Background() + mockSnapshot := &cloud.Snapshot{ + SnapshotID: fmt.Sprintf("snapshot-%d", rand.New(rand.NewSource(time.Now().UnixNano())).Uint64()), + SourceVolumeID: req.SourceVolumeId, + Size: 1, + CreationTime: time.Now(), + } + mockCtl := gomock.NewController(t) + defer mockCtl.Finish() + + snapshotOptions := &cloud.SnapshotOptions{ + Tags: map[string]string{ + cloud.SnapshotNameTagKey: snapshotName, + cloud.AwsEbsDriverTagKey: isManagedByDriver, + }, + } + + mockCloud := cloud.NewMockCloud(mockCtl) + mockCloud.EXPECT().GetSnapshotByName(gomock.Eq(ctx), gomock.Eq(req.GetName())).Return(nil, cloud.ErrNotFound).AnyTimes() + mockCloud.EXPECT().AvailabilityZones(gomock.Eq(ctx)).Return(map[string]struct{}{ + "us-east-1a": {}, "us-east-1f": {}}, nil).AnyTimes() + mockCloud.EXPECT().CreateSnapshot(gomock.Eq(ctx), gomock.Eq(req.SourceVolumeId), gomock.Eq(snapshotOptions)).Return(mockSnapshot, nil).AnyTimes() + mockCloud.EXPECT().EnableFastSnapshotRestores(gomock.Eq(ctx), gomock.Eq([]string{"us-east-1a", "us-east-1f"}), + gomock.Eq(mockSnapshot.SnapshotID)).Return(nil, fmt.Errorf("error")).AnyTimes() + mockCloud.EXPECT().DeleteSnapshot(gomock.Eq(ctx), gomock.Eq(mockSnapshot.SnapshotID)).Return(true, nil).AnyTimes() + + awsDriver := controllerService{ + cloud: mockCloud, + inFlight: internal.NewInFlight(), + driverOptions: &DriverOptions{}, + } + + _, err := awsDriver.CreateSnapshot(context.Background(), req) + if err == nil { + t.Fatalf("Expected error, got nil") + } + }, + }, } for _, tc := range testCases { diff --git a/pkg/driver/sanity_test.go b/pkg/driver/sanity_test.go index 22e75e199c..165fed7325 100644 --- a/pkg/driver/sanity_test.go +++ b/pkg/driver/sanity_test.go @@ -259,6 +259,10 @@ func (c *fakeCloudProvider) GetSnapshotByID(ctx context.Context, snapshotID stri return ret.Snapshot, nil } +func (c *fakeCloudProvider) AvailabilityZones(ctx context.Context) (map[string]struct{}, error) { + return nil, nil +} + func (c *fakeCloudProvider) ListSnapshots(ctx context.Context, volumeID string, maxResults int64, nextToken string) (listSnapshotsResponse *cloud.ListSnapshotsResponse, err error) { var snapshots []*cloud.Snapshot var retToken string @@ -284,6 +288,10 @@ func (c *fakeCloudProvider) ListSnapshots(ctx context.Context, volumeID string, } +func (c *fakeCloudProvider) EnableFastSnapshotRestores(ctx context.Context, availabilityZones []string, snapshotID string) (*ec2.EnableFastSnapshotRestoresOutput, error) { + return nil, nil +} + func (c *fakeCloudProvider) ResizeDisk(ctx context.Context, volumeID string, newSize int64) (int64, error) { for volName, f := range c.disks { if f.Disk.VolumeID == volumeID {