Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Lazy load gRPC plugin (#353)
Browse files: browse the repository at this point in the history
* Lazy load grpc plugin

Signed-off-by: Kevin Su <[email protected]>

* rename

Signed-off-by: Kevin Su <[email protected]>

* rename

Signed-off-by: Kevin Su <[email protected]>

* nit

Signed-off-by: Kevin Su <[email protected]>

* rename

Signed-off-by: Kevin Su <[email protected]>

* lint

Signed-off-by: Kevin Su <[email protected]>

* nit

Signed-off-by: Kevin Su <[email protected]>

---------

Signed-off-by: Kevin Su <[email protected]>
  • Loading branch information
pingsutw authored Jun 5, 2023
1 parent c499a48 commit 06866ee
Show file tree
Hide file tree
Showing 7 changed files with 68 additions and 65 deletions.
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ require (
github.com/aws/aws-sdk-go-v2/service/athena v1.0.0
github.com/bstadlbauer/dask-k8s-operator-go-client v0.1.0
github.com/coocood/freecache v1.1.1
github.com/flyteorg/flyteidl v1.5.2
github.com/flyteorg/flyteidl v1.5.10
github.com/flyteorg/flytestdlib v1.0.15
github.com/go-test/deep v1.0.7
github.com/golang/protobuf v1.5.2
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -232,8 +232,8 @@ github.com/evanphx/json-patch v4.12.0+incompatible/go.mod h1:50XU6AFN0ol/bzJsmQL
github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5KwzbycvMj4=
github.com/fatih/color v1.13.0 h1:8LOYc1KYPPmyKMuN8QV2DNRWNbLo6LZ0iLs8+mlH53w=
github.com/fatih/color v1.13.0/go.mod h1:kLAiJbzzSOZDVNGyDpeOxJ47H46qBXwg5ILebYFFOfk=
github.com/flyteorg/flyteidl v1.5.2 h1:DZPzYkTg92qA4e17fd0ZW1M+gh1gJKh/VOK+F4bYgM8=
github.com/flyteorg/flyteidl v1.5.2/go.mod h1:ckLjB51moX4L0oQml+WTCrPK50zrJf6IZJ6LPC0RB4I=
github.com/flyteorg/flyteidl v1.5.10 h1:SHeiaWRt8EAVuFsat+BJswtc07HTZ4DqhfTEYSm621k=
github.com/flyteorg/flyteidl v1.5.10/go.mod h1:EtE/muM2lHHgBabjYcxqe9TWeJSL0kXwbI0RgVwI4Og=
github.com/flyteorg/flytestdlib v1.0.15 h1:kv9jDQmytbE84caY+pkZN8trJU2ouSAmESzpTEhfTt0=
github.com/flyteorg/flytestdlib v1.0.15/go.mod h1:ghw/cjY0sEWIIbyCtcJnL/Gt7ZS7gf9SUi0CCPhbz3s=
github.com/flyteorg/stow v0.3.6 h1:jt50ciM14qhKBaIrB+ppXXY+SXB59FNREFgTJqCyqIk=
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package grpc
package agent

import (
"time"
Expand Down Expand Up @@ -39,22 +39,22 @@ var (
Value: 50,
},
},
DefaultGrpcEndpoint: "dns:///external-plugin-service.flyte.svc.cluster.local:80",
DefaultGrpcEndpoint: "dns:///flyte-agent.flyte.svc.cluster.local:80",
SupportedTaskTypes: []string{"task_type_1", "task_type_2"},
}

configSection = pluginsConfig.MustRegisterSubSection("external-plugin-service", &defaultConfig)
configSection = pluginsConfig.MustRegisterSubSection("agent-service", &defaultConfig)
)

// Config is config for 'databricks' plugin
// Config is config for 'agent' plugin
type Config struct {
// WebAPI defines config for the base WebAPI plugin
WebAPI webapi.PluginConfig `json:"webApi" pflag:",Defines config for the base WebAPI plugin."`

// ResourceConstraints defines resource constraints on how many executions to be created per project/overall at any given time
ResourceConstraints core.ResourceConstraintsSpec `json:"resourceConstraints" pflag:"-,Defines resource constraints on how many executions to be created per project/overall at any given time."`

DefaultGrpcEndpoint string `json:"defaultGrpcEndpoint" pflag:",The default grpc endpoint of external plugin service."`
DefaultGrpcEndpoint string `json:"defaultGrpcEndpoint" pflag:",The default grpc endpoint of agent service."`

// Maps endpoint to their plugin handler. {TaskType: Endpoint}
EndpointForTaskTypes map[string]string `json:"endpointForTaskTypes" pflag:"-,"`
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package grpc
package agent

import (
"testing"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package grpc
package agent

import (
"context"
Expand All @@ -14,6 +14,7 @@ import (
ioMocks "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io/mocks"

"github.com/flyteorg/flyteidl/clients/go/coreutils"
"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin"
"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/plugins"
"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/service"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery"
Expand All @@ -38,27 +39,27 @@ type MockPlugin struct {
type MockClient struct {
}

func (m *MockClient) CreateTask(_ context.Context, _ *service.TaskCreateRequest, _ ...grpc.CallOption) (*service.TaskCreateResponse, error) {
return &service.TaskCreateResponse{JobId: "job-id"}, nil
func (m *MockClient) CreateTask(_ context.Context, _ *admin.CreateTaskRequest, _ ...grpc.CallOption) (*admin.CreateTaskResponse, error) {
return &admin.CreateTaskResponse{ResourceMeta: []byte{1, 2, 3, 4}}, nil
}

func (m *MockClient) GetTask(_ context.Context, _ *service.TaskGetRequest, _ ...grpc.CallOption) (*service.TaskGetResponse, error) {
return &service.TaskGetResponse{State: service.State_SUCCEEDED, Outputs: &flyteIdlCore.LiteralMap{
func (m *MockClient) GetTask(_ context.Context, _ *admin.GetTaskRequest, _ ...grpc.CallOption) (*admin.GetTaskResponse, error) {
return &admin.GetTaskResponse{Resource: &admin.Resource{State: admin.State_SUCCEEDED, Outputs: &flyteIdlCore.LiteralMap{
Literals: map[string]*flyteIdlCore.Literal{
"arr": coreutils.MustMakeLiteral([]interface{}{[]interface{}{"a", "b"}, []interface{}{1, 2}}),
},
}}, nil
}}}, nil
}

func (m *MockClient) DeleteTask(_ context.Context, _ *service.TaskDeleteRequest, _ ...grpc.CallOption) (*service.TaskDeleteResponse, error) {
return &service.TaskDeleteResponse{}, nil
func (m *MockClient) DeleteTask(_ context.Context, _ *admin.DeleteTaskRequest, _ ...grpc.CallOption) (*admin.DeleteTaskResponse, error) {
return &admin.DeleteTaskResponse{}, nil
}

func mockGetClientFunc(_ context.Context, _ string, _ map[string]*grpc.ClientConn) (service.ExternalPluginServiceClient, error) {
func mockGetClientFunc(_ context.Context, _ string, _ map[string]*grpc.ClientConn) (service.AsyncAgentServiceClient, error) {
return &MockClient{}, nil
}

func mockGetBadClientFunc(_ context.Context, _ string, _ map[string]*grpc.ClientConn) (service.ExternalPluginServiceClient, error) {
func mockGetBadClientFunc(_ context.Context, _ string, _ map[string]*grpc.ClientConn) (service.AsyncAgentServiceClient, error) {
return nil, fmt.Errorf("error")
}

Expand Down Expand Up @@ -98,7 +99,7 @@ func TestEndToEnd(t *testing.T) {
basePrefix := storage.DataReference("fake://bucket/prefix/")

t.Run("run a job", func(t *testing.T) {
pluginEntry := pluginmachinery.CreateRemotePlugin(newMockGrpcPlugin())
pluginEntry := pluginmachinery.CreateRemotePlugin(newMockAgentPlugin())
plugin, err := pluginEntry.LoadPlugin(context.TODO(), newFakeSetupContext("test1"))
assert.NoError(t, err)

Expand All @@ -107,8 +108,8 @@ func TestEndToEnd(t *testing.T) {
})

t.Run("failed to create a job", func(t *testing.T) {
grpcPlugin := newMockGrpcPlugin()
grpcPlugin.PluginLoader = func(ctx context.Context, iCtx webapi.PluginSetupContext) (webapi.AsyncPlugin, error) {
agentPlugin := newMockAgentPlugin()
agentPlugin.PluginLoader = func(ctx context.Context, iCtx webapi.PluginSetupContext) (webapi.AsyncPlugin, error) {
return &MockPlugin{
Plugin{
metricScope: iCtx.MetricsScope(),
Expand All @@ -117,7 +118,7 @@ func TestEndToEnd(t *testing.T) {
},
}, nil
}
pluginEntry := pluginmachinery.CreateRemotePlugin(grpcPlugin)
pluginEntry := pluginmachinery.CreateRemotePlugin(agentPlugin)
plugin, err := pluginEntry.LoadPlugin(context.TODO(), newFakeSetupContext("test2"))
assert.NoError(t, err)

Expand All @@ -144,8 +145,8 @@ func TestEndToEnd(t *testing.T) {
tr.OnRead(context.Background()).Return(nil, fmt.Errorf("read fail"))
tCtx.OnTaskReader().Return(tr)

grpcPlugin := newMockGrpcPlugin()
pluginEntry := pluginmachinery.CreateRemotePlugin(grpcPlugin)
agentPlugin := newAgentPlugin()
pluginEntry := pluginmachinery.CreateRemotePlugin(agentPlugin)
plugin, err := pluginEntry.LoadPlugin(context.TODO(), newFakeSetupContext("test3"))
assert.NoError(t, err)

Expand All @@ -165,8 +166,8 @@ func TestEndToEnd(t *testing.T) {
inputReader.OnGetMatch(mock.Anything).Return(nil, fmt.Errorf("read fail"))
tCtx.OnInputReader().Return(inputReader)

grpcPlugin := newMockGrpcPlugin()
pluginEntry := pluginmachinery.CreateRemotePlugin(grpcPlugin)
agentPlugin := newMockAgentPlugin()
pluginEntry := pluginmachinery.CreateRemotePlugin(agentPlugin)
plugin, err := pluginEntry.LoadPlugin(context.TODO(), newFakeSetupContext("test4"))
assert.NoError(t, err)

Expand Down Expand Up @@ -239,9 +240,9 @@ func getTaskContext(t *testing.T) *pluginCoreMocks.TaskExecutionContext {
return tCtx
}

func newMockGrpcPlugin() webapi.PluginEntry {
func newMockAgentPlugin() webapi.PluginEntry {
return webapi.PluginEntry{
ID: "external-plugin-service",
ID: "agent-service",
SupportedTaskTypes: []core.TaskType{"bigquery_query_job_task"},
PluginLoader: func(ctx context.Context, iCtx webapi.PluginSetupContext) (webapi.AsyncPlugin, error) {
return &MockPlugin{
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
package grpc
package agent

import (
"context"
"encoding/gob"
"fmt"

"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin"

"google.golang.org/grpc/grpclog"

flyteIdl "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core"
Expand All @@ -19,7 +21,7 @@ import (
"google.golang.org/grpc"
)

type GetClientFunc func(ctx context.Context, endpoint string, connectionCache map[string]*grpc.ClientConn) (service.ExternalPluginServiceClient, error)
type GetClientFunc func(ctx context.Context, endpoint string, connectionCache map[string]*grpc.ClientConn) (service.AsyncAgentServiceClient, error)

type Plugin struct {
metricScope promutils.Scope
Expand All @@ -29,15 +31,15 @@ type Plugin struct {
}

type ResourceWrapper struct {
State service.State
State admin.State
Outputs *flyteIdl.LiteralMap
}

type ResourceMetaWrapper struct {
OutputPrefix string
Token string
JobID string
TaskType string
OutputPrefix string
Token string
AgentResourceMeta []byte
TaskType string
}

func (p Plugin) GetConfig() webapi.PluginConfig {
Expand Down Expand Up @@ -67,20 +69,20 @@ func (p Plugin) Create(ctx context.Context, taskCtx webapi.TaskExecutionContextR
endpoint := getFinalEndpoint(taskTemplate.Type, p.cfg.DefaultGrpcEndpoint, p.cfg.EndpointForTaskTypes)
client, err := p.getClient(ctx, endpoint, p.connectionCache)
if err != nil {
return nil, nil, fmt.Errorf("failed to connect external plugin service with error: %v", err)
return nil, nil, fmt.Errorf("failed to connect to agent with error: %v", err)
}

res, err := client.CreateTask(ctx, &service.TaskCreateRequest{Inputs: inputs, Template: taskTemplate, OutputPrefix: outputPrefix})
res, err := client.CreateTask(ctx, &admin.CreateTaskRequest{Inputs: inputs, Template: taskTemplate, OutputPrefix: outputPrefix})
if err != nil {
return nil, nil, err
}

return &ResourceMetaWrapper{
OutputPrefix: outputPrefix,
JobID: res.GetJobId(),
Token: "",
TaskType: taskTemplate.Type,
}, &ResourceWrapper{State: service.State_RUNNING}, nil
OutputPrefix: outputPrefix,
AgentResourceMeta: res.GetResourceMeta(),
Token: "",
TaskType: taskTemplate.Type,
}, &ResourceWrapper{State: admin.State_RUNNING}, nil
}

func (p Plugin) Get(ctx context.Context, taskCtx webapi.GetContext) (latest webapi.Resource, err error) {
Expand All @@ -89,17 +91,17 @@ func (p Plugin) Get(ctx context.Context, taskCtx webapi.GetContext) (latest weba
endpoint := getFinalEndpoint(metadata.TaskType, p.cfg.DefaultGrpcEndpoint, p.cfg.EndpointForTaskTypes)
client, err := p.getClient(ctx, endpoint, p.connectionCache)
if err != nil {
return nil, fmt.Errorf("failed to connect external plugin service with error: %v", err)
return nil, fmt.Errorf("failed to connect to agent with error: %v", err)
}

res, err := client.GetTask(ctx, &service.TaskGetRequest{TaskType: metadata.TaskType, JobId: metadata.JobID})
res, err := client.GetTask(ctx, &admin.GetTaskRequest{TaskType: metadata.TaskType, ResourceMeta: metadata.AgentResourceMeta})
if err != nil {
return nil, err
}

return &ResourceWrapper{
State: res.State,
Outputs: res.Outputs,
State: res.Resource.State,
Outputs: res.Resource.Outputs,
}, nil
}

Expand All @@ -112,10 +114,10 @@ func (p Plugin) Delete(ctx context.Context, taskCtx webapi.DeleteContext) error
endpoint := getFinalEndpoint(metadata.TaskType, p.cfg.DefaultGrpcEndpoint, p.cfg.EndpointForTaskTypes)
client, err := p.getClient(ctx, endpoint, p.connectionCache)
if err != nil {
return fmt.Errorf("failed to connect external plugin service with error: %v", err)
return fmt.Errorf("failed to connect to agent with error: %v", err)
}

_, err = client.DeleteTask(ctx, &service.TaskDeleteRequest{TaskType: metadata.TaskType, JobId: metadata.JobID})
_, err = client.DeleteTask(ctx, &admin.DeleteTaskRequest{TaskType: metadata.TaskType, ResourceMeta: metadata.AgentResourceMeta})
return err
}

Expand All @@ -124,13 +126,13 @@ func (p Plugin) Status(ctx context.Context, taskCtx webapi.StatusContext) (phase
taskInfo := &core.TaskInfo{}

switch resource.State {
case service.State_RUNNING:
case admin.State_RUNNING:
return core.PhaseInfoRunning(pluginsCore.DefaultPhaseVersion, taskInfo), nil
case service.State_PERMANENT_FAILURE:
case admin.State_PERMANENT_FAILURE:
return core.PhaseInfoFailure(pluginErrors.TaskFailedWithError, "failed to run the job", taskInfo), nil
case service.State_RETRYABLE_FAILURE:
case admin.State_RETRYABLE_FAILURE:
return core.PhaseInfoRetryableFailure(pluginErrors.TaskFailedWithError, "failed to run the job", taskInfo), nil
case service.State_SUCCEEDED:
case admin.State_SUCCEEDED:
if resource.Outputs != nil {
err := taskCtx.OutputWriter().Put(ctx, ioutils.NewInMemoryOutputReader(resource.Outputs, nil, nil))
if err != nil {
Expand All @@ -150,10 +152,10 @@ func getFinalEndpoint(taskType, defaultEndpoint string, endpointForTaskTypes map
return defaultEndpoint
}

func getClientFunc(ctx context.Context, endpoint string, connectionCache map[string]*grpc.ClientConn) (service.ExternalPluginServiceClient, error) {
func getClientFunc(ctx context.Context, endpoint string, connectionCache map[string]*grpc.ClientConn) (service.AsyncAgentServiceClient, error) {
conn, ok := connectionCache[endpoint]
if ok {
return service.NewExternalPluginServiceClient(conn), nil
return service.NewAsyncAgentServiceClient(conn), nil
}
var opts []grpc.DialOption
var err error
Expand All @@ -178,14 +180,14 @@ func getClientFunc(ctx context.Context, endpoint string, connectionCache map[str
}
}()
}()
return service.NewExternalPluginServiceClient(conn), nil
return service.NewAsyncAgentServiceClient(conn), nil
}

func newGrpcPlugin() webapi.PluginEntry {
func newAgentPlugin() webapi.PluginEntry {
supportedTaskTypes := GetConfig().SupportedTaskTypes

return webapi.PluginEntry{
ID: "external-plugin-service",
ID: "agent-service",
SupportedTaskTypes: supportedTaskTypes,
PluginLoader: func(ctx context.Context, iCtx webapi.PluginSetupContext) (webapi.AsyncPlugin, error) {
return &Plugin{
Expand All @@ -198,9 +200,9 @@ func newGrpcPlugin() webapi.PluginEntry {
}
}

func init() {
func RegisterAgentPlugin() {
gob.Register(ResourceMetaWrapper{})
gob.Register(ResourceWrapper{})

pluginmachinery.PluginRegistry().RegisterRemotePlugin(newGrpcPlugin())
pluginmachinery.PluginRegistry().RegisterRemotePlugin(newAgentPlugin())
}
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package grpc
package agent

import (
"context"
Expand All @@ -25,7 +25,7 @@ func TestPlugin(t *testing.T) {
cfg := defaultConfig
cfg.WebAPI.Caching.Workers = 1
cfg.WebAPI.Caching.ResyncInterval.Duration = 5 * time.Second
cfg.DefaultGrpcEndpoint = "test-service.flyte.svc.cluster.local:80"
cfg.DefaultGrpcEndpoint = "test-agent.flyte.svc.cluster.local:80"
cfg.EndpointForTaskTypes = map[string]string{"spark": "localhost:80"}
err := SetConfig(&cfg)
assert.NoError(t, err)
Expand All @@ -38,10 +38,10 @@ func TestPlugin(t *testing.T) {
assert.Equal(t, plugin.cfg.ResourceConstraints, constraints)
})

t.Run("tet newGrpcPlugin", func(t *testing.T) {
p := newGrpcPlugin()
t.Run("tet newAgentPlugin", func(t *testing.T) {
p := newAgentPlugin()
assert.NotNil(t, p)
assert.Equal(t, p.ID, "external-plugin-service")
assert.Equal(t, p.ID, "agent-service")
assert.NotNil(t, p.PluginLoader)
})

Expand Down

0 comments on commit 06866ee

Please sign in to comment.