Skip to content

Commit

Permalink
chatgpt succeed
Browse files Browse the repository at this point in the history
Signed-off-by: Future Outlier <[email protected]>
  • Loading branch information
Future Outlier committed Oct 2, 2023
1 parent b08924f commit adc1c2c
Show file tree
Hide file tree
Showing 7 changed files with 145 additions and 41 deletions.
69 changes: 48 additions & 21 deletions flyteplugins/go/tasks/pluginmachinery/internal/webapi/core.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,14 @@ import (
"fmt"
"time"

"github.com/flyteorg/flyte/flytestdlib/cache"
stdErrs "github.com/flyteorg/flyte/flytestdlib/errors"
"k8s.io/utils/clock"

stdErrs "github.com/flyteorg/flytestdlib/errors"

"github.com/flyteorg/flytestdlib/cache"

"github.com/flyteorg/flyteplugins/go/tasks/errors"
"github.com/flyteorg/flytestdlib/logger"

"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/webapi"
"github.com/flyteorg/flyte/flyteplugins/go/tasks/errors"
"github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core"
"github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/webapi"
"github.com/flyteorg/flyte/flytestdlib/logger"
)

const (
Expand Down Expand Up @@ -69,26 +66,55 @@ func (c CorePlugin) GetProperties() core.PluginProperties {
return core.PluginProperties{}
}

// syncHandle
// TODO: ADD Sync Handle
func (c CorePlugin) Handle(ctx context.Context, tCtx core.TaskExecutionContext) (core.Transition, error) {
func (c CorePlugin) syncHandle(ctx context.Context, tCtx core.TaskExecutionContext) (core.Transition, error) {
incomingState, err := c.unmarshalState(ctx, tCtx.PluginStateReader())
if err != nil {
return core.UnknownTransition, err
}

var state *State
state = &incomingState
// TODO: This is incoming State
newCacheItem := CacheItem{
State: *state,
}
cacheItemID := tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName()
item, err := c.cache.GetOrCreate(cacheItemID, newCacheItem)
cacheItem, ok := item.(CacheItem)
if !ok {
logger.Errorf(ctx, "Error casting cache object into ExecutionState")
return core.UnknownTransition, err
}

res, err := c.sp.Do(ctx, tCtx)
if err != nil {
return core.UnknownTransition, err
}

cacheItem.Resource = res

phase := PhaseSucceeded
cacheItem.Phase = phase
err = c.cache.DeleteDelayed(cacheItemID)

if err := tCtx.PluginStateWriter().Put(pluginStateVersion, cacheItem.State); err != nil {
return core.UnknownTransition, err
}

taskInfo := &core.TaskInfo{}
return core.DoTransition(core.PhaseInfoSuccess(taskInfo)), nil
}

func (c CorePlugin) Handle(ctx context.Context, tCtx core.TaskExecutionContext) (core.Transition, error) {
taskTemplate, err := tCtx.TaskReader().Read(ctx)

if taskTemplate.Type == "dispatcher" {
res, err := c.sp.Do(ctx, tCtx)
if err != nil {
return core.UnknownTransition, err
}
logger.Infof(ctx, "@@@ SyncPlugin [%v] returned result: %v", c.GetID(), res)
// if err := tCtx.PluginStateWriter().Put(pluginStateVersion, nextState); err != nil {
// return core.UnknownTransition, err
// }
return core.DoTransition(core.PhaseInfoSuccess(nil)), nil
return c.syncHandle(ctx, tCtx)
}

incomingState, err := c.unmarshalState(ctx, tCtx.PluginStateReader())
if err != nil {
return core.UnknownTransition, err
}

var nextState *State
Expand Down Expand Up @@ -163,6 +189,7 @@ func validateRangeFloat64(fieldName string, min, max, provided float64) error {

return nil
}

func validateConfig(cfg webapi.PluginConfig) error {
errs := stdErrs.ErrorCollection{}
errs.Append(validateRangeInt("cache size", minCacheSize, maxCacheSize, cfg.Caching.Size))
Expand Down
3 changes: 2 additions & 1 deletion flyteplugins/go/tasks/pluginmachinery/webapi/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import (

// A Lazy loading function, that will load the plugin. Plugins should be initialized in this method. It is guaranteed
// that the plugin loader will be called before any Handle/Abort/Finalize functions are invoked
type PluginLoader func(ctx context.Context, iCtx PluginSetupContext) (AsyncPlugin, error)
type PluginLoader func(ctx context.Context, iCtx PluginSetupContext) (AsyncPlugin, SyncPlugin, error)

// PluginEntry is a structure that is used to indicate to the system a K8s plugin
type PluginEntry struct {
Expand Down Expand Up @@ -151,4 +151,5 @@ type SyncPlugin interface {

// Do performs the action associated with this plugin.
Do(ctx context.Context, tCtx TaskExecutionContext) (phase pluginsCore.PhaseInfo, err error)
Do(ctx context.Context, tCtx TaskExecutionContext) (latest Resource, err error)
}
90 changes: 83 additions & 7 deletions flyteplugins/go/tasks/plugins/webapi/agent/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,77 @@ func (p Plugin) ResourceRequirements(_ context.Context, _ webapi.TaskExecutionCo
return "default", p.cfg.ResourceConstraints, nil
}

func (p Plugin) Do(ctx context.Context, taskCtx webapi.TaskExecutionContext) (latest webapi.Resource, err error) {
// write the resource here
taskTemplate, err := taskCtx.TaskReader().Read(ctx)
if err != nil {
return nil, err
}

inputs, err := taskCtx.InputReader().Get(ctx)
if err != nil {
return nil, err
}

if taskTemplate.GetContainer() != nil {
templateParameters := template.Parameters{
TaskExecMetadata: taskCtx.TaskExecutionMetadata(),
Inputs: taskCtx.InputReader(),
OutputPath: taskCtx.OutputWriter(),
Task: taskCtx.TaskReader(),
}
modifiedArgs, err := template.Render(ctx, taskTemplate.GetContainer().Args, templateParameters)
if err != nil {
return nil, err
}
taskTemplate.GetContainer().Args = modifiedArgs
}

agent, err := getFinalAgent(taskTemplate.Type, p.cfg)
if err != nil {
return nil, fmt.Errorf("failed to find agent agent with error: %v", err)
}

client, err := p.getClient(ctx, agent, p.connectionCache)
if err != nil {
return nil, fmt.Errorf("failed to connect to agent with error: %v", err)
}

finalCtx, cancel := getFinalContext(ctx, "DoTask", agent)

defer cancel()

res, err := client.DoTask(finalCtx, &admin.DoTaskRequest{Inputs: inputs, Template: taskTemplate})
if err != nil {
return nil, err
}

resource := &ResourceWrapper{
State: res.Resource.State,
Outputs: res.Resource.Outputs,
}

// Write the output
if taskTemplate.Interface == nil || taskTemplate.Interface.Outputs == nil || taskTemplate.Interface.Outputs.Variables == nil {
logger.Debugf(ctx, "The task declares no outputs. Skipping writing the outputs.")
return resource, nil
}

var opReader io.OutputReader
if resource.Outputs != nil {
logger.Debugf(ctx, "Agent returned an output.")
opReader = ioutils.NewInMemoryOutputReader(resource.Outputs, nil, nil)
} else {
logger.Debugf(ctx, "Agent didn't return any output, assuming file based outputs.")
opReader = ioutils.NewRemoteFileOutputReader(ctx, taskCtx.DataStore(), taskCtx.OutputWriter(), taskCtx.MaxDatasetSizeBytes())
}

return &ResourceWrapper{
State: res.Resource.State,
Outputs: res.Resource.Outputs,
}, taskCtx.OutputWriter().Put(ctx, opReader)
}

func (p Plugin) Create(ctx context.Context, taskCtx webapi.TaskExecutionContextReader) (webapi.ResourceMeta,
webapi.Resource, error) {
taskTemplate, err := taskCtx.TaskReader().Read(ctx)
Expand Down Expand Up @@ -203,7 +274,7 @@ func writeOutput(ctx context.Context, taskCtx webapi.StatusContext, resource *Re

var opReader io.OutputReader
if resource.Outputs != nil {
logger.Debugf(ctx, "Agent returned an output")
logger.Debugf(ctx, "Agent returned an output.")
opReader = ioutils.NewInMemoryOutputReader(resource.Outputs, nil, nil)
} else {
logger.Debugf(ctx, "Agent didn't return any output, assuming file based outputs.")
Expand Down Expand Up @@ -306,13 +377,18 @@ func newAgentPlugin(supportedTaskTypes SupportedTaskTypes) webapi.PluginEntry {
return webapi.PluginEntry{
ID: "agent-service",
SupportedTaskTypes: supportedTaskTypes,
PluginLoader: func(ctx context.Context, iCtx webapi.PluginSetupContext) (webapi.AsyncPlugin, error) {
PluginLoader: func(ctx context.Context, iCtx webapi.PluginSetupContext) (webapi.AsyncPlugin, webapi.SyncPlugin, error) {
return &Plugin{
metricScope: iCtx.MetricsScope(),
cfg: GetConfig(),
getClient: getClientFunc,
connectionCache: make(map[*Agent]*grpc.ClientConn),
}, nil
metricScope: iCtx.MetricsScope(),
cfg: GetConfig(),
getClient: getClientFunc,
connectionCache: make(map[*Agent]*grpc.ClientConn),
}, &Plugin{
metricScope: iCtx.MetricsScope(),
cfg: GetConfig(),
getClient: getClientFunc,
connectionCache: make(map[*Agent]*grpc.ClientConn),
}, nil
},
}
}
Expand Down
8 changes: 4 additions & 4 deletions flyteplugins/go/tasks/plugins/webapi/athena/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -200,25 +200,25 @@ func createTaskInfo(queryID string, cfg awsSdk.Config) *core.TaskInfo {
}
}

func NewPlugin(_ context.Context, cfg *Config, awsConfig *aws.Config, metricScope promutils.Scope) (Plugin, error) {
func NewPlugin(_ context.Context, cfg *Config, awsConfig *aws.Config, metricScope promutils.Scope) (Plugin, webapi.SyncPlugin, error) {
sdkCfg, err := awsConfig.GetSdkConfig()
if err != nil {
return Plugin{}, err
return Plugin{}, nil, err
}

return Plugin{
metricScope: metricScope,
client: athena.NewFromConfig(sdkCfg),
cfg: cfg,
awsConfig: sdkCfg,
}, nil
}, nil, nil
}

func init() {
pluginmachinery.PluginRegistry().RegisterRemotePlugin(webapi.PluginEntry{
ID: "athena",
SupportedTaskTypes: []core.TaskType{"hive", "presto"},
PluginLoader: func(ctx context.Context, iCtx webapi.PluginSetupContext) (webapi.AsyncPlugin, error) {
PluginLoader: func(ctx context.Context, iCtx webapi.PluginSetupContext) (webapi.AsyncPlugin, webapi.SyncPlugin, error) {
return NewPlugin(ctx, GetConfig(), aws.GetConfig(), iCtx.MetricsScope())
},
})
Expand Down
8 changes: 4 additions & 4 deletions flyteplugins/go/tasks/plugins/webapi/bigquery/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -547,25 +547,25 @@ func (p Plugin) newBigQueryClient(ctx context.Context, identity google.Identity)
return bigquery.NewService(ctx, options...)
}

func NewPlugin(cfg *Config, metricScope promutils.Scope) (*Plugin, error) {
func NewPlugin(cfg *Config, metricScope promutils.Scope) (*Plugin, webapi.SyncPlugin, error) {
googleTokenSource, err := google.NewTokenSourceFactory(cfg.GoogleTokenSource)

if err != nil {
return nil, pluginErrors.Wrapf(pluginErrors.PluginInitializationFailed, err, "failed to get google token source")
return nil, nil, pluginErrors.Wrapf(pluginErrors.PluginInitializationFailed, err, "failed to get google token source")
}

return &Plugin{
metricScope: metricScope,
cfg: cfg,
googleTokenSource: googleTokenSource,
}, nil
}, nil, nil
}

func newBigQueryJobTaskPlugin() webapi.PluginEntry {
return webapi.PluginEntry{
ID: "bigquery",
SupportedTaskTypes: []core.TaskType{bigqueryQueryJobTask},
PluginLoader: func(ctx context.Context, iCtx webapi.PluginSetupContext) (webapi.AsyncPlugin, error) {
PluginLoader: func(ctx context.Context, iCtx webapi.PluginSetupContext) (webapi.AsyncPlugin, webapi.SyncPlugin, error) {
cfg := GetConfig()

return NewPlugin(cfg, iCtx.MetricsScope())
Expand Down
4 changes: 2 additions & 2 deletions flyteplugins/go/tasks/plugins/webapi/databricks/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -339,12 +339,12 @@ func newDatabricksJobTaskPlugin() webapi.PluginEntry {
return webapi.PluginEntry{
ID: "databricks",
SupportedTaskTypes: []core.TaskType{"spark"},
PluginLoader: func(ctx context.Context, iCtx webapi.PluginSetupContext) (webapi.AsyncPlugin, error) {
PluginLoader: func(ctx context.Context, iCtx webapi.PluginSetupContext) (webapi.AsyncPlugin, webapi.SyncPlugin, error) {
return &Plugin{
metricScope: iCtx.MetricsScope(),
cfg: GetConfig(),
client: &http.Client{},
}, nil
}, nil, nil
},
}
}
Expand Down
4 changes: 2 additions & 2 deletions flyteplugins/go/tasks/plugins/webapi/snowflake/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -275,12 +275,12 @@ func newSnowflakeJobTaskPlugin() webapi.PluginEntry {
return webapi.PluginEntry{
ID: "snowflake",
SupportedTaskTypes: []core.TaskType{"snowflake"},
PluginLoader: func(ctx context.Context, iCtx webapi.PluginSetupContext) (webapi.AsyncPlugin, error) {
PluginLoader: func(ctx context.Context, iCtx webapi.PluginSetupContext) (webapi.AsyncPlugin, webapi.SyncPlugin, error) {
return &Plugin{
metricScope: iCtx.MetricsScope(),
cfg: GetConfig(),
client: &http.Client{},
}, nil
}, nil, nil
},
}
}
Expand Down

0 comments on commit adc1c2c

Please sign in to comment.