Skip to content

Commit

Permalink
changes after review
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-jcieslak committed Mar 27, 2024
1 parent 214f86e commit 4495b27
Show file tree
Hide file tree
Showing 4 changed files with 154 additions and 39 deletions.
36 changes: 24 additions & 12 deletions pkg/resources/task.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package resources

import (
"context"
"errors"
"fmt"
"log"
"slices"
Expand Down Expand Up @@ -285,7 +284,7 @@ func ReadTask(d *schema.ResourceData, meta interface{}) error {
}

// CreateTask implements schema.CreateFunc.
func CreateTask(d *schema.ResourceData, meta interface{}) (returnedErr error) {
func CreateTask(d *schema.ResourceData, meta interface{}) error {
client := meta.(*provider.Context).Client
ctx := context.Background()

Expand Down Expand Up @@ -349,11 +348,16 @@ func CreateTask(d *schema.ResourceData, meta interface{}) (returnedErr error) {
precedingTasks := make([]sdk.SchemaObjectIdentifier, 0)
for _, dep := range after {
precedingTaskId := sdk.NewSchemaObjectIdentifier(databaseName, schemaName, dep)
resumeSuspended, err := client.Tasks.TemporarilySuspendRootTasks(ctx, precedingTaskId, taskId)
defer func() { returnedErr = errors.Join(returnedErr, resumeSuspended()) }()
tasksToResume, err := client.Tasks.SuspendRootTasks(ctx, precedingTaskId, taskId)
defer func() {
if err := client.Tasks.ResumeTasks(ctx, tasksToResume); err != nil {
log.Printf("[WARN] failed to resume tasks: %s", err)
}
}()
if err != nil {
return err
}

precedingTasks = append(precedingTasks, precedingTaskId)
}
createRequest.WithAfter(precedingTasks)
Expand Down Expand Up @@ -397,14 +401,18 @@ func waitForTaskStart(ctx context.Context, client *sdk.Client, id sdk.SchemaObje
}

// UpdateTask implements schema.UpdateFunc.
func UpdateTask(d *schema.ResourceData, meta interface{}) (returnedErr error) {
func UpdateTask(d *schema.ResourceData, meta interface{}) error {
client := meta.(*provider.Context).Client
ctx := context.Background()

taskId := helpers.DecodeSnowflakeID(d.Id()).(sdk.SchemaObjectIdentifier)

resumeSuspended, err := client.Tasks.TemporarilySuspendRootTasks(ctx, taskId, taskId)
defer func() { returnedErr = errors.Join(returnedErr, resumeSuspended()) }()
tasksToResume, err := client.Tasks.SuspendRootTasks(ctx, taskId, taskId)
defer func() {
if err := client.Tasks.ResumeTasks(ctx, tasksToResume); err != nil {
log.Printf("[WARN] failed to resume tasks: %s", err)
}
}()
if err != nil {
return err
}
Expand Down Expand Up @@ -495,8 +503,12 @@ func UpdateTask(d *schema.ResourceData, meta interface{}) (returnedErr error) {
}
if len(toAdd) > 0 {
for _, depId := range toAdd {
resumeSuspended, err := client.Tasks.TemporarilySuspendRootTasks(ctx, depId, taskId)
defer func() { returnedErr = errors.Join(returnedErr, resumeSuspended()) }()
tasksToResume, err := client.Tasks.SuspendRootTasks(ctx, depId, taskId)
defer func() {
if err := client.Tasks.ResumeTasks(ctx, tasksToResume); err != nil {
log.Printf("[WARN] failed to resume tasks: %s", err)
}
}()
if err != nil {
return err
}
Expand Down Expand Up @@ -664,10 +676,10 @@ func DeleteTask(d *schema.ResourceData, meta interface{}) error {

taskId := helpers.DecodeSnowflakeID(d.Id()).(sdk.SchemaObjectIdentifier)

resumeSuspended, err := client.Tasks.TemporarilySuspendRootTasks(ctx, taskId, taskId)
tasksToResume, err := client.Tasks.SuspendRootTasks(ctx, taskId, taskId)
defer func() {
if err := resumeSuspended(); err != nil {
log.Printf("[WARN] failed to resume suspended task: %s", err)
if err := client.Tasks.ResumeTasks(ctx, tasksToResume); err != nil {
log.Printf("[WARN] failed to resume tasks: %s", err)
}
}()
if err != nil {
Expand Down
3 changes: 2 additions & 1 deletion pkg/sdk/tasks_gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ type Tasks interface {
ShowByID(ctx context.Context, id SchemaObjectIdentifier) (*Task, error)
Describe(ctx context.Context, id SchemaObjectIdentifier) (*Task, error)
Execute(ctx context.Context, request *ExecuteTaskRequest) error
TemporarilySuspendRootTasks(ctx context.Context, depId SchemaObjectIdentifier, id SchemaObjectIdentifier) (func() error, error)
SuspendRootTasks(ctx context.Context, depId SchemaObjectIdentifier, id SchemaObjectIdentifier) ([]SchemaObjectIdentifier, error)
ResumeTasks(ctx context.Context, ids []SchemaObjectIdentifier) error
}

// CreateTaskOptions is based on https://docs.snowflake.com/en/sql-reference/sql/create-task.
Expand Down
37 changes: 17 additions & 20 deletions pkg/sdk/tasks_impl_gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,17 +67,11 @@ func (v *tasks) Execute(ctx context.Context, request *ExecuteTaskRequest) error
return validateAndExec(v.client, ctx, opts)
}

// TemporarilySuspendRootTasks takes in the depId for which root tasks will be searched. Then, for all root tasks,
// check if the task is started. If it is, then suspend it and add to the list of tasks to resume only if the root task name
// is not that same as taskId name.
//
// Returns:
// A callback function to resume all the suspended root tasks.
// An error joined from all the suspend calls, nil if no error was returned during by task suspending calls.
func (v *tasks) TemporarilySuspendRootTasks(ctx context.Context, depId SchemaObjectIdentifier, taskId SchemaObjectIdentifier) (func() error, error) {
// TODO(SNOW-1277135): See if depId is necessary or could be removed
func (v *tasks) SuspendRootTasks(ctx context.Context, depId SchemaObjectIdentifier, id SchemaObjectIdentifier) ([]SchemaObjectIdentifier, error) {
rootTasks, err := GetRootTasks(v.client.Tasks, ctx, depId)
if err != nil {
return func() error { return nil }, err
return nil, err
}

tasksToResume := make([]SchemaObjectIdentifier, 0)
Expand All @@ -89,27 +83,30 @@ func (v *tasks) TemporarilySuspendRootTasks(ctx context.Context, depId SchemaObj
err := v.client.Tasks.Alter(ctx, NewAlterTaskRequest(rootTask.ID()).WithSuspend(Bool(true)))
if err != nil {
log.Printf("[WARN] failed to suspend task %s", rootTask.ID().FullyQualifiedName())
suspendErrs = append(suspendErrs, err)
}

// Resume the task after modifications are complete as long as it is not a standalone task
if rootTask.Name != taskId.Name() {
// TODO(SNOW-1277135): Document the purpose of this check and why we need different value for GetRootTasks (depId).
if rootTask.Name != id.Name() {
tasksToResume = append(tasksToResume, rootTask.ID())
}
suspendErrs = append(suspendErrs, err)
}
}

return func() error {
resumeErrs := make([]error, 0)
for _, taskId := range tasksToResume {
err := v.client.Tasks.Alter(ctx, NewAlterTaskRequest(taskId).WithResume(Bool(true)))
if err != nil {
log.Printf("[WARN] failed to resume task %s", taskId.FullyQualifiedName())
}
return tasksToResume, errors.Join(suspendErrs...)
}

func (v *tasks) ResumeTasks(ctx context.Context, ids []SchemaObjectIdentifier) error {
resumeErrs := make([]error, 0)
for _, id := range ids {
err := v.client.Tasks.Alter(ctx, NewAlterTaskRequest(id).WithResume(Bool(true)))
if err != nil {
log.Printf("[WARN] failed to resume task %s", id.FullyQualifiedName())
resumeErrs = append(resumeErrs, err)
}
return errors.Join(resumeErrs...)
}, errors.Join(suspendErrs...)
}
return errors.Join(resumeErrs...)
}

// GetRootTasks is a way to get all root tasks for the given tasks.
Expand Down
117 changes: 111 additions & 6 deletions pkg/sdk/testint/tasks_gen_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -605,17 +605,122 @@ func TestInt_Tasks(t *testing.T) {
require.NoError(t, client.Tasks.Alter(ctx, sdk.NewAlterTaskRequest(rootTask.ID()).WithSuspend(sdk.Bool(true))))
})

resumeRoots, err := client.Tasks.TemporarilySuspendRootTasks(ctx, task.ID(), task.ID())
tasksToResume, err := client.Tasks.SuspendRootTasks(ctx, task.ID(), task.ID())
require.NoError(t, err)
require.NotEmpty(t, tasksToResume)

rt, err := client.Tasks.ShowByID(ctx, rootTask.ID())
rootTaskStatus, err := client.Tasks.ShowByID(ctx, rootTask.ID())
require.NoError(t, err)
require.Equal(t, sdk.TaskStateSuspended, rt.State)
require.Equal(t, sdk.TaskStateSuspended, rootTaskStatus.State)

require.NoError(t, resumeRoots())
require.NoError(t, client.Tasks.ResumeTasks(ctx, tasksToResume))

rt, err = client.Tasks.ShowByID(ctx, rootTask.ID())
rootTaskStatus, err = client.Tasks.ShowByID(ctx, rootTask.ID())
require.NoError(t, err)
require.Equal(t, sdk.TaskStateStarted, rt.State)
require.Equal(t, sdk.TaskStateStarted, rootTaskStatus.State)
})

t.Run("resume root tasks within a graph containing more than one root task", func(t *testing.T) {
rootTaskId := sdk.NewSchemaObjectIdentifier(testDb(t).Name, testSchema(t).Name, random.String())
rootTask := createTaskWithRequest(t, sdk.NewCreateTaskRequest(rootTaskId, sql).WithSchedule(sdk.String("60 minutes")))

secondRootTaskId := sdk.NewSchemaObjectIdentifier(testDb(t).Name, testSchema(t).Name, random.String())
secondRootTask := createTaskWithRequest(t, sdk.NewCreateTaskRequest(secondRootTaskId, sql).WithSchedule(sdk.String("60 minutes")))

id := sdk.NewSchemaObjectIdentifier(testDb(t).Name, testSchema(t).Name, random.String())
_ = createTaskWithRequest(t, sdk.NewCreateTaskRequest(id, sql).WithAfter([]sdk.SchemaObjectIdentifier{rootTask.ID(), secondRootTask.ID()}))

require.ErrorContains(t, client.Tasks.Alter(ctx, sdk.NewAlterTaskRequest(rootTask.ID()).WithResume(sdk.Bool(true))), "The graph has more than one root task (one without predecessors)")
require.ErrorContains(t, client.Tasks.Alter(ctx, sdk.NewAlterTaskRequest(secondRootTask.ID()).WithResume(sdk.Bool(true))), "The graph has more than one root task (one without predecessors)")
})

t.Run("suspend root tasks temporarily with three sequentially connected tasks - last in DAG", func(t *testing.T) {
rootTaskId := sdk.NewSchemaObjectIdentifier(testDb(t).Name, testSchema(t).Name, random.String())
rootTask := createTaskWithRequest(t, sdk.NewCreateTaskRequest(rootTaskId, sql).WithSchedule(sdk.String("60 minutes")))

middleTaskId := sdk.NewSchemaObjectIdentifier(testDb(t).Name, testSchema(t).Name, random.String())
middleTask := createTaskWithRequest(t, sdk.NewCreateTaskRequest(middleTaskId, sql).WithAfter([]sdk.SchemaObjectIdentifier{rootTask.ID()}))

id := sdk.NewSchemaObjectIdentifier(testDb(t).Name, testSchema(t).Name, random.String())
task := createTaskWithRequest(t, sdk.NewCreateTaskRequest(id, sql).WithAfter([]sdk.SchemaObjectIdentifier{middleTask.ID()}))

require.NoError(t, client.Tasks.Alter(ctx, sdk.NewAlterTaskRequest(middleTask.ID()).WithResume(sdk.Bool(true))))
t.Cleanup(func() {
require.NoError(t, client.Tasks.Alter(ctx, sdk.NewAlterTaskRequest(middleTask.ID()).WithSuspend(sdk.Bool(true))))
})

require.NoError(t, client.Tasks.Alter(ctx, sdk.NewAlterTaskRequest(rootTask.ID()).WithResume(sdk.Bool(true))))
t.Cleanup(func() {
require.NoError(t, client.Tasks.Alter(ctx, sdk.NewAlterTaskRequest(rootTask.ID()).WithSuspend(sdk.Bool(true))))
})

tasksToResume, err := client.Tasks.SuspendRootTasks(ctx, task.ID(), task.ID())
require.NoError(t, err)
require.NotEmpty(t, tasksToResume)
require.Contains(t, tasksToResume, rootTask.ID())

rootTaskStatus, err := client.Tasks.ShowByID(ctx, rootTask.ID())
require.NoError(t, err)
require.Equal(t, sdk.TaskStateSuspended, rootTaskStatus.State)

middleTaskStatus, err := client.Tasks.ShowByID(ctx, middleTask.ID())
require.NoError(t, err)
require.Equal(t, sdk.TaskStateStarted, middleTaskStatus.State)

require.NoError(t, client.Tasks.ResumeTasks(ctx, tasksToResume))

rootTaskStatus, err = client.Tasks.ShowByID(ctx, rootTask.ID())
require.NoError(t, err)
require.Equal(t, sdk.TaskStateStarted, rootTaskStatus.State)

middleTaskStatus, err = client.Tasks.ShowByID(ctx, middleTask.ID())
require.NoError(t, err)
require.Equal(t, sdk.TaskStateStarted, middleTaskStatus.State)
})

t.Run("suspend root tasks temporarily with three sequentially connected tasks - middle in DAG", func(t *testing.T) {
rootTaskId := sdk.NewSchemaObjectIdentifier(testDb(t).Name, testSchema(t).Name, random.String())
rootTask := createTaskWithRequest(t, sdk.NewCreateTaskRequest(rootTaskId, sql).WithSchedule(sdk.String("60 minutes")))

middleTaskId := sdk.NewSchemaObjectIdentifier(testDb(t).Name, testSchema(t).Name, random.String())
middleTask := createTaskWithRequest(t, sdk.NewCreateTaskRequest(middleTaskId, sql).WithAfter([]sdk.SchemaObjectIdentifier{rootTask.ID()}))

childTaskId := sdk.NewSchemaObjectIdentifier(testDb(t).Name, testSchema(t).Name, random.String())
childTask := createTaskWithRequest(t, sdk.NewCreateTaskRequest(childTaskId, sql).WithAfter([]sdk.SchemaObjectIdentifier{middleTask.ID()}))

require.NoError(t, client.Tasks.Alter(ctx, sdk.NewAlterTaskRequest(childTask.ID()).WithResume(sdk.Bool(true))))
t.Cleanup(func() {
require.NoError(t, client.Tasks.Alter(ctx, sdk.NewAlterTaskRequest(childTask.ID()).WithSuspend(sdk.Bool(true))))
})

require.NoError(t, client.Tasks.Alter(ctx, sdk.NewAlterTaskRequest(rootTask.ID()).WithResume(sdk.Bool(true))))
t.Cleanup(func() {
require.NoError(t, client.Tasks.Alter(ctx, sdk.NewAlterTaskRequest(rootTask.ID()).WithSuspend(sdk.Bool(true))))
})

tasksToResume, err := client.Tasks.SuspendRootTasks(ctx, middleTask.ID(), middleTask.ID())
require.NoError(t, err)
require.NotEmpty(t, tasksToResume)

rootTaskStatus, err := client.Tasks.ShowByID(ctx, rootTask.ID())
require.NoError(t, err)
require.Equal(t, sdk.TaskStateSuspended, rootTaskStatus.State)

childTaskStatus, err := client.Tasks.ShowByID(ctx, childTask.ID())
require.NoError(t, err)
require.Equal(t, sdk.TaskStateStarted, childTaskStatus.State)

require.NoError(t, client.Tasks.ResumeTasks(ctx, tasksToResume))

rootTaskStatus, err = client.Tasks.ShowByID(ctx, rootTask.ID())
require.NoError(t, err)
require.Equal(t, sdk.TaskStateStarted, rootTaskStatus.State)

childTaskStatus, err = client.Tasks.ShowByID(ctx, childTask.ID())
require.NoError(t, err)
require.Equal(t, sdk.TaskStateStarted, childTaskStatus.State)
})

// TODO(SNOW-1277135): Create more tests with different sets of roots/children and see if the current implementation
// acts correctly in certain situations/edge cases.
}

0 comments on commit 4495b27

Please sign in to comment.