-
Notifications
You must be signed in to change notification settings - Fork 656
/
echo.go
187 lines (157 loc) · 5.98 KB
/
echo.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
package testing
import (
"context"
"fmt"
"sync"
"time"
idlcore "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core"
"github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery"
"github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core"
"github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/ioutils"
"github.com/flyteorg/flyte/flytestdlib/logger"
"github.com/flyteorg/flyte/flytestdlib/storage"
)
// echoTaskType is both the plugin ID and the single task type the echo plugin handles.
const echoTaskType = "echo"
// EchoPlugin is the core plugin registered for the "echo" task type. It copies
// a task's inputs to its outputs, optionally after waiting for the configured
// SleepDuration.
type EchoPlugin struct {
	// Embedded so Lock/Unlock are promoted onto EchoPlugin; held by addTask and
	// removeTask around every access to taskStartTimes.
	sync.Mutex

	// taskStartTimes maps a task execution's generated name to the time the
	// plugin first handled that execution.
	taskStartTimes map[string]time.Time

	// enqueueOwner is invoked with the task's owner ID once the sleep duration
	// has elapsed, so the task gets evaluated again.
	enqueueOwner core.EnqueueOwner
}
// GetID reports the plugin identifier, which is the echo task type itself.
func (e *EchoPlugin) GetID() string { return echoTaskType }
// GetProperties returns zero-valued PluginProperties; the echo plugin does not
// customize any of them.
func (e *EchoPlugin) GetProperties() core.PluginProperties {
	var properties core.PluginProperties
	return properties
}
// addTask returns the time the task execution identified by tCtx was first seen.
//
// On the first call for a given task execution it records time.Now() as the
// start time and schedules a single call to enqueueOwner for the task's owner
// once the configured SleepDuration has elapsed, so the task is re-evaluated.
// Later calls for the same execution return the recorded start time and do not
// schedule anything further.
func (e *EchoPlugin) addTask(ctx context.Context, tCtx core.TaskExecutionContext) time.Time {
	taskExecutionMetadata := tCtx.TaskExecutionMetadata()
	taskExecutionID := taskExecutionMetadata.GetTaskExecutionID().GetGeneratedName()

	e.Lock()
	defer e.Unlock()

	if startTime, exists := e.taskStartTimes[taskExecutionID]; exists {
		return startTime
	}

	startTime := time.Now()
	e.taskStartTimes[taskExecutionID] = startTime

	// Resolve everything the timer callback needs now, so the callback never
	// touches tCtx (or re-reads the config) after Handle has returned.
	ownerID := taskExecutionMetadata.GetOwnerID()
	sleepDuration := ConfigSection.GetConfig().(*Config).SleepDuration.Duration

	// time.AfterFunc avoids parking a goroutine in time.Sleep for the whole duration.
	time.AfterFunc(sleepDuration, func() {
		if err := e.enqueueOwner(ownerID); err != nil {
			logger.Warnf(ctx, "failed to enqueue owner [%s]: %v", ownerID, err)
		}
	})

	return startTime
}
// removeTask forgets the start time recorded for taskExecutionID. Removing an
// ID that was never recorded is a no-op.
func (e *EchoPlugin) removeTask(taskExecutionID string) {
	e.Lock()
	delete(e.taskStartTimes, taskExecutionID)
	e.Unlock()
}
// Handle drives an echo task through one evaluation round.
//
// If the configured SleepDuration is zero or negative, the inputs are copied to
// the outputs immediately. Otherwise the first call records a start time and
// schedules a re-evaluation (see addTask). The task reports Running until
// SleepDuration has elapsed since that start time. After that, the inputs are
// copied to the outputs and the task succeeds.
func (e *EchoPlugin) Handle(ctx context.Context, tCtx core.TaskExecutionContext) (core.Transition, error) {
	sleepDuration := ConfigSection.GetConfig().(*Config).SleepDuration.Duration

	// A non-positive duration means there is nothing to wait for. Skip the
	// start-time bookkeeping and the re-enqueue timer instead of scheduling a
	// spurious enqueue for a task that is about to succeed.
	if sleepDuration <= 0 {
		return copyInputsToOutputs(ctx, tCtx)
	}

	startTime := e.addTask(ctx, tCtx)
	if time.Since(startTime) >= sleepDuration {
		return copyInputsToOutputs(ctx, tCtx)
	}

	return core.DoTransition(core.PhaseInfoRunning(core.DefaultPhaseVersion, nil)), nil
}
// Abort does nothing and always succeeds. The task's recorded start time is
// cleared in Finalize rather than here.
func (e *EchoPlugin) Abort(ctx context.Context, tCtx core.TaskExecutionContext) error {
	return nil
}
// Finalize drops the start time recorded for this task execution, keyed by its
// generated name, and always returns nil.
func (e *EchoPlugin) Finalize(ctx context.Context, tCtx core.TaskExecutionContext) error {
	e.removeTask(tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName())
	return nil
}
// copyInputsToOutputs echoes the task's input literals to its output location
// and returns a success transition.
//
// The input-to-output pairing comes from compileInputToOutputVariableMappings;
// its errors are returned unchanged. When the pairing is empty, nothing is
// written. Otherwise the paired literals are written as a LiteralMap to the
// OutputWriter's output path, and a remote-file reader over that location is
// registered with the OutputWriter.
func copyInputsToOutputs(ctx context.Context, tCtx core.TaskExecutionContext) (core.Transition, error) {
	inputToOutputVariableMappings, err := compileInputToOutputVariableMappings(ctx, tCtx)
	if err != nil {
		return core.UnknownTransition, err
	}

	if len(inputToOutputVariableMappings) == 0 {
		return core.DoTransition(core.PhaseInfoSuccess(nil)), nil
	}

	inputLiterals, err := tCtx.InputReader().Get(ctx)
	if err != nil {
		return core.UnknownTransition, err
	}

	// GetLiterals is nil-safe: a nil LiteralMap yields a nil map instead of a panic.
	inputLiteralsByName := inputLiterals.GetLiterals()
	outputLiterals := make(map[string]*idlcore.Literal, len(inputToOutputVariableMappings))
	for inputVariableName, outputVariableName := range inputToOutputVariableMappings {
		outputLiterals[outputVariableName] = inputLiteralsByName[inputVariableName]
	}

	outputLiteralMap := &idlcore.LiteralMap{
		Literals: outputLiterals,
	}

	outputFile := tCtx.OutputWriter().GetOutputPath()
	if err := tCtx.DataStore().WriteProtobuf(ctx, outputFile, storage.Options{}, outputLiteralMap); err != nil {
		return core.UnknownTransition, err
	}

	outputReader := ioutils.NewRemoteFileOutputReader(ctx, tCtx.DataStore(), tCtx.OutputWriter(), 0)
	if err := tCtx.OutputWriter().Put(ctx, outputReader); err != nil {
		return core.UnknownTransition, err
	}

	return core.DoTransition(core.PhaseInfoSuccess(nil)), nil
}
// compileInputToOutputVariableMappings pairs each input variable declared in the
// task template's interface with the output variable that should receive its
// value. It returns a map from input variable name to output variable name.
//
// The number of declared inputs and outputs must match, and at most one of each
// is supported; anything else is returned as an error. A template without an
// interface, or with no inputs and outputs, yields an empty mapping. Type
// compatibility between the paired variables is not checked yet (see the TODO
// below).
func compileInputToOutputVariableMappings(ctx context.Context, tCtx core.TaskExecutionContext) (map[string]string, error) {
	// validate outputs are castable from inputs otherwise error as this plugin is not applicable
	taskTemplate, err := tCtx.TaskReader().Read(ctx)
	if err != nil {
		return nil, fmt.Errorf("failed to read TaskTemplate: [%w]", err)
	}

	// The generated protobuf getters are nil-safe, so a nil template, Interface,
	// Inputs or Outputs simply produces a nil variable map instead of a panic.
	inputs := taskTemplate.GetInterface().GetInputs().GetVariables()
	outputs := taskTemplate.GetInterface().GetOutputs().GetVariables()

	switch {
	case len(inputs) != len(outputs):
		return nil, fmt.Errorf("the number of input [%d] and output [%d] variables does not match", len(inputs), len(outputs))
	case len(inputs) > 1:
		return nil, fmt.Errorf("this plugin does not currently support more than one input variable")
	}

	inputToOutputVariableMappings := make(map[string]string, len(inputs))
	outputVariableNameUsed := make(map[string]struct{}, len(outputs))
	for inputVariableName := range inputs {
		firstCastableOutputName := ""
		for outputVariableName := range outputs {
			// TODO - need to check if types are castable to support multiple values
			if _, used := outputVariableNameUsed[outputVariableName]; !used {
				firstCastableOutputName = outputVariableName
				break
			}
		}

		if firstCastableOutputName == "" {
			return nil, fmt.Errorf("no castable output variable found for input variable [%s]", inputVariableName)
		}

		outputVariableNameUsed[firstCastableOutputName] = struct{}{}
		inputToOutputVariableMappings[inputVariableName] = firstCastableOutputName
	}

	return inputToOutputVariableMappings, nil
}
func init() {
pluginmachinery.PluginRegistry().RegisterCorePlugin(
core.PluginEntry{
ID: echoTaskType,
RegisteredTaskTypes: []core.TaskType{echoTaskType},
LoadPlugin: func(ctx context.Context, iCtx core.SetupContext) (core.Plugin, error) {
return &EchoPlugin{
enqueueOwner: iCtx.EnqueueOwner(),
taskStartTimes: make(map[string]time.Time),
}, nil
},
IsDefault: false,
},
)
}