Skip to content

Commit

Permalink
add exec timeout func
Browse files Browse the repository at this point in the history
  • Loading branch information
umagnus committed Mar 15, 2024
1 parent ac7bf03 commit 23c3fab
Show file tree
Hide file tree
Showing 4 changed files with 107 additions and 32 deletions.
56 changes: 25 additions & 31 deletions pkg/blob/controllerserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,6 @@ const (
MSI = "MSI"
SPN = "SPN"
authorizationPermissionMismatch = "AuthorizationPermissionMismatch"

waitForAzCopyInterval = 2 * time.Second
)

// CreateVolume provisions a volume
Expand All @@ -85,7 +83,7 @@ func (d *Driver) CreateVolume(ctx context.Context, req *csi.CreateVolumeRequest)
// logging the job status if it's volume cloning
if req.GetVolumeContentSource() != nil {
jobState, percent, err := d.azcopy.GetAzcopyJob(volName, []string{})
klog.V(2).Infof("azcopy job status: %s, copy percent: %s%%, error: %v", jobState, percent, err)
return nil, status.Errorf(codes.Aborted, volumeOperationAlreadyExistsWithAzcopyFmt, volName, jobState, percent, err)
}
return nil, status.Errorf(codes.Aborted, volumeOperationAlreadyExistsFmt, volName)
}
Expand Down Expand Up @@ -759,43 +757,39 @@ func (d *Driver) copyBlobContainer(req *csi.CreateVolumeRequest, accountSasToken
return fmt.Errorf("srcContainerName(%s) or dstContainerName(%s) is empty", srcContainerName, dstContainerName)
}

timeAfter := time.After(time.Duration(d.waitForAzCopyTimeoutMinutes) * time.Minute)
timeTick := time.Tick(waitForAzCopyInterval)
srcPath := fmt.Sprintf("https://%s.blob.%s/%s%s", accountName, storageEndpointSuffix, srcContainerName, accountSasToken)
dstPath := fmt.Sprintf("https://%s.blob.%s/%s%s", accountName, storageEndpointSuffix, dstContainerName, accountSasToken)

jobState, percent, err := d.azcopy.GetAzcopyJob(dstContainerName, authAzcopyEnv)
klog.V(2).Infof("azcopy job status: %s, copy percent: %s%%, error: %v", jobState, percent, err)
if jobState == util.AzcopyJobError || jobState == util.AzcopyJobCompleted {
switch jobState {
case util.AzcopyJobError, util.AzcopyJobCompleted:
return err
}
klog.V(2).Infof("begin to copy blob container %s to %s", srcContainerName, dstContainerName)
for {
select {
case <-timeTick:
jobState, percent, err := d.azcopy.GetAzcopyJob(dstContainerName, authAzcopyEnv)
klog.V(2).Infof("azcopy job status: %s, copy percent: %s%%, error: %v", jobState, percent, err)
switch jobState {
case util.AzcopyJobError, util.AzcopyJobCompleted:
return err
case util.AzcopyJobNotFound:
klog.V(2).Infof("copy blob container %s to %s", srcContainerName, dstContainerName)
cmd := exec.Command("azcopy", "copy", srcPath, dstPath, "--recursive", "--check-length=false")
if len(authAzcopyEnv) > 0 {
cmd.Env = append(os.Environ(), authAzcopyEnv...)
}
out, copyErr := cmd.CombinedOutput()
if copyErr != nil {
klog.Warningf("CopyBlobContainer(%s, %s, %s) failed with error(%v): %v", resourceGroupName, accountName, dstPath, copyErr, string(out))
} else {
klog.V(2).Infof("copied blob container %s to %s successfully", srcContainerName, dstContainerName)
}
return copyErr
case util.AzcopyJobRunning:
return fmt.Errorf("an existed azcopy job is running, copy percent: %s%%, please wait for it to complete", percent)
case util.AzcopyJobNotFound:
klog.V(2).Infof("copy blob container %s to %s", srcContainerName, dstContainerName)
copyErr := util.WaitForExecCompletion(time.Duration(d.waitForAzCopyTimeoutMinutes)*time.Minute, func() error {
cmd := exec.Command("azcopy", "copy", srcPath, dstPath, "--recursive", "--check-length=false")
if len(authAzcopyEnv) > 0 {
cmd.Env = append(os.Environ(), authAzcopyEnv...)
}
case <-timeAfter:
return fmt.Errorf("timeout waiting for copy blob container %s to %s succeed", srcContainerName, dstContainerName)
if out, err := cmd.CombinedOutput(); err != nil {
return fmt.Errorf("exec error: %v, output: %v", err, string(out))
}
return nil
}, func() error {
_, percent, _ := d.azcopy.GetAzcopyJob(dstContainerName, authAzcopyEnv)
return fmt.Errorf("timeout waiting for copy blob container %s to %s succeed, copy percent: %s%%", srcContainerName, dstContainerName, percent)
})
if copyErr != nil {
klog.Warningf("CopyBlobContainer(%s, %s, %s) failed with error: %v", resourceGroupName, accountName, dstPath, copyErr)
} else {
klog.V(2).Infof("copied blob container %s to %s successfully", srcContainerName, dstContainerName)
}
return copyErr
}
return err
}

// copyVolume copies a volume from a volume or snapshot; snapshot is not supported now
Expand Down
3 changes: 2 additions & 1 deletion pkg/blob/volume_lock.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ import (
)

const (
volumeOperationAlreadyExistsFmt = "An operation with the given Volume ID %s already exists"
volumeOperationAlreadyExistsFmt = "An operation with the given Volume ID %s already exists"
volumeOperationAlreadyExistsWithAzcopyFmt = "An operation using azcopy with the given Volume ID %s already exists. Azcopy job status: %s, copy percent: %s%%, error: %v"
)

// VolumeLocks implements a map with atomic operations. It stores a set of all volume IDs
Expand Down
28 changes: 28 additions & 0 deletions pkg/util/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"strconv"
"strings"
"sync"
"time"

"github.com/go-ini/ini"
"github.com/pkg/errors"
Expand Down Expand Up @@ -387,3 +388,30 @@ func SetVolumeOwnership(path, gid, policy string) error {
}
return volume.SetVolumeOwnership(&VolumeMounter{path: path}, path, &gidInt64, &fsGroupChangePolicy, nil)
}

// ExecFunc is an operation run by WaitForExecCompletion. It returns nil on
// success or the error that made it fail.
type ExecFunc func() (err error)

// TimeoutFunc is invoked by WaitForExecCompletion when the ExecFunc has not
// finished within the timeout. The error it returns is handed back to the
// caller in place of the ExecFunc result.
type TimeoutFunc func() (err error)

// WaitForExecCompletion runs execFunc in its own goroutine and waits at most
// timeout for it to finish.
//
// If execFunc finishes in time, its error (possibly nil) is returned. If the
// timeout elapses first, timeoutFunc is called and its error is returned
// instead. execFunc is not interrupted on timeout: it keeps running in the
// background and its eventual result is discarded.
func WaitForExecCompletion(timeout time.Duration, execFunc ExecFunc, timeoutFunc TimeoutFunc) error {
	// The channel is buffered so the goroutine can always deliver its result
	// and exit, even when the timeout has already fired and nobody receives
	// any more. An unbuffered channel would block that goroutine forever.
	done := make(chan error, 1)
	go func() {
		done <- execFunc()
	}()

	// Use an explicit timer so it is released as soon as execFunc finishes
	// rather than lingering until the full timeout has elapsed.
	timer := time.NewTimer(timeout)
	defer timer.Stop()

	select {
	case err := <-done:
		return err
	case <-timer.C:
		return timeoutFunc()
	}
}
52 changes: 52 additions & 0 deletions pkg/util/util_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -656,3 +656,55 @@ func TestSetVolumeOwnership(t *testing.T) {
}
}
}

// TestWaitForExecCompletion checks the three outcomes of
// WaitForExecCompletion: the exec function fails, the exec function exceeds
// the timeout, and the exec function succeeds.
func TestWaitForExecCompletion(t *testing.T) {
	tests := []struct {
		desc        string
		timeout     time.Duration
		execFunc    ExecFunc
		timeoutFunc TimeoutFunc
		expectedErr error
	}{
		{
			desc:    "execFunc returns error",
			timeout: 1 * time.Second,
			execFunc: func() error {
				return fmt.Errorf("execFunc error")
			},
			timeoutFunc: func() error {
				return fmt.Errorf("timeout error")
			},
			expectedErr: fmt.Errorf("execFunc error"),
		},
		{
			desc:    "execFunc timeout",
			timeout: 1 * time.Second,
			execFunc: func() error {
				time.Sleep(2 * time.Second)
				return nil
			},
			timeoutFunc: func() error {
				return fmt.Errorf("timeout error")
			},
			expectedErr: fmt.Errorf("timeout error"),
		},
		{
			desc:    "execFunc completed successfully",
			timeout: 1 * time.Second,
			execFunc: func() error {
				return nil
			},
			timeoutFunc: func() error {
				return fmt.Errorf("timeout error")
			},
			expectedErr: nil,
		},
	}

	for _, test := range tests {
		t.Run(test.desc, func(t *testing.T) {
			err := WaitForExecCompletion(test.timeout, test.execFunc, test.timeoutFunc)

			// Compare nil-ness first: a missing error must fail the case, and
			// an unexpected error must not dereference a nil expectedErr.
			switch {
			case err == nil && test.expectedErr == nil:
				// Both nil: the outcome matches.
			case err == nil || test.expectedErr == nil || err.Error() != test.expectedErr.Error():
				t.Errorf("unexpected error: %v, expected error: %v", err, test.expectedErr)
			}
		})
	}
}

0 comments on commit 23c3fab

Please sign in to comment.