Skip to content

Commit

Permalink
Add supportTaskTypes for agentservice without write it in config twic…
Browse files Browse the repository at this point in the history
…e. (flyteorg#398)

Signed-off-by: Future Outlier <[email protected]>
Co-authored-by: Future Outlier <[email protected]>
  • Loading branch information
Future-Outlier and Future Outlier authored Sep 22, 2023
1 parent 08dd79b commit 2598c96
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 10 deletions.
4 changes: 0 additions & 4 deletions go/tasks/plugins/webapi/agent/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ var (
Insecure: true,
DefaultTimeout: config.Duration{Duration: 10 * time.Second},
},
SupportedTaskTypes: []string{"task_type_1", "task_type_2"},
}

configSection = pluginsConfig.MustRegisterSubSection("agent-service", &defaultConfig)
Expand All @@ -66,9 +65,6 @@ type Config struct {

// Maps task types to their agents. {TaskType: AgentId}
AgentForTaskTypes map[string]string `json:"agentForTaskTypes" pflag:"-,"`

// SupportedTaskTypes is a list of task types that are supported by this plugin.
SupportedTaskTypes []string `json:"supportedTaskTypes" pflag:"-,Defines a list of task types that are supported by this plugin."`
}

type Agent struct {
Expand Down
2 changes: 1 addition & 1 deletion go/tasks/plugins/webapi/agent/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ func TestEndToEnd(t *testing.T) {
tr.OnRead(context.Background()).Return(nil, fmt.Errorf("read fail"))
tCtx.OnTaskReader().Return(tr)

agentPlugin := newAgentPlugin()
agentPlugin := newAgentPlugin(SupportedTaskTypes{})
pluginEntry := pluginmachinery.CreateRemotePlugin(agentPlugin)
plugin, err := pluginEntry.LoadPlugin(context.TODO(), newFakeSetupContext("test3"))
assert.NoError(t, err)
Expand Down
12 changes: 8 additions & 4 deletions go/tasks/plugins/webapi/agent/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ import (

type GetClientFunc func(ctx context.Context, agent *Agent, connectionCache map[*Agent]*grpc.ClientConn) (service.AsyncAgentServiceClient, error)

type TaskType = string
type SupportedTaskTypes []TaskType
type Plugin struct {
metricScope promutils.Scope
cfg *Config
Expand Down Expand Up @@ -296,8 +298,10 @@ func getFinalContext(ctx context.Context, operation string, agent *Agent) (conte
return context.WithTimeout(ctx, timeout)
}

func newAgentPlugin() webapi.PluginEntry {
supportedTaskTypes := GetConfig().SupportedTaskTypes
func newAgentPlugin(supportedTaskTypes SupportedTaskTypes) webapi.PluginEntry {
if len(supportedTaskTypes) == 0 {
supportedTaskTypes = SupportedTaskTypes{"default_supported_task_type"}
}

return webapi.PluginEntry{
ID: "agent-service",
Expand All @@ -313,9 +317,9 @@ func newAgentPlugin() webapi.PluginEntry {
}
}

func RegisterAgentPlugin() {
func RegisterAgentPlugin(supportedTaskTypes SupportedTaskTypes) {
gob.Register(ResourceMetaWrapper{})
gob.Register(ResourceWrapper{})

pluginmachinery.PluginRegistry().RegisterRemotePlugin(newAgentPlugin())
pluginmachinery.PluginRegistry().RegisterRemotePlugin(newAgentPlugin(supportedTaskTypes))
}
2 changes: 1 addition & 1 deletion go/tasks/plugins/webapi/agent/plugin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ func TestPlugin(t *testing.T) {
})

t.Run("test newAgentPlugin", func(t *testing.T) {
p := newAgentPlugin()
p := newAgentPlugin(SupportedTaskTypes{})
assert.NotNil(t, p)
assert.Equal(t, "agent-service", p.ID)
assert.NotNil(t, p.PluginLoader)
Expand Down

0 comments on commit 2598c96

Please sign in to comment.