Skip to content

Commit

Permalink
Propagate custom_info Dict through agent Resource
Browse files Browse the repository at this point in the history
 - The agent defines a Resource return type with values:

   * outputs
   * message
   * log_links
   * phase

   These are all a part of the underlying protobuf contract defined in
   flyteidl.

   However, the message field custom_info from the protobuf is not here

   google.protobuf.Struct custom_info

   https://github.com/flyteorg/flyte/blob/519080b6e4e53fc0e216b5715ad9b5b5270f35c0/flyteidl/protos/flyteidl/admin/agent.proto#L140

   This field was added in flyteorg/flyte#4874
   but never made it into the corresponding flytekit PR
   flyteorg#2146

 - It's useful for agents to return additional metadata about the job,
   and it looks like custom_info is the intended location

 - Make a minor refactor to how the agent responds to requests that
   return Resource by implementing to_flyte_idl / from_flyte_idl
   directly

Signed-off-by: ddl-ebrown <[email protected]>
  • Loading branch information
ddl-ebrown committed May 19, 2024
1 parent 7213723 commit 2fd654a
Show file tree
Hide file tree
Showing 3 changed files with 128 additions and 36 deletions.
27 changes: 3 additions & 24 deletions flytekit/extend/backend/agent_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
GetTaskResponse,
ListAgentsRequest,
ListAgentsResponse,
Resource,
)
from flyteidl.service.agent_pb2_grpc import (
AgentMetadataServiceServicer,
Expand All @@ -25,8 +24,7 @@
)
from prometheus_client import Counter, Summary

from flytekit import FlyteContext, logger
from flytekit.core.type_engine import TypeEngine
from flytekit import logger
from flytekit.exceptions.system import FlyteAgentNotFound
from flytekit.extend.backend.base_agent import AgentRegistry, SyncAgentBase, mirror_async_methods
from flytekit.models.literals import LiteralMap
Expand Down Expand Up @@ -136,16 +134,7 @@ async def GetTask(self, request: GetTaskRequest, context: grpc.ServicerContext)
logger.info(f"{agent.name} start checking the status of the job")
res = await mirror_async_methods(agent.get, resource_meta=agent.metadata_type.decode(request.resource_meta))

if res.outputs is None:
outputs = None
elif isinstance(res.outputs, LiteralMap):
outputs = res.outputs.to_flyte_idl()
else:
ctx = FlyteContext.current_context()
outputs = TypeEngine.dict_to_literal_map_pb(ctx, res.outputs)
return GetTaskResponse(
resource=Resource(phase=res.phase, log_links=res.log_links, message=res.message, outputs=outputs)
)
return GetTaskResponse(resource=res.to_flyte_idl())

@record_agent_metrics
async def DeleteTask(self, request: DeleteTaskRequest, context: grpc.ServicerContext) -> DeleteTaskResponse:
Expand Down Expand Up @@ -175,17 +164,7 @@ async def ExecuteTaskSync(
literal_map = LiteralMap.from_flyte_idl(request.inputs) if request.inputs else None
res = await mirror_async_methods(agent.do, task_template=template, inputs=literal_map)

if res.outputs is None:
outputs = None
elif isinstance(res.outputs, LiteralMap):
outputs = res.outputs.to_flyte_idl()
else:
ctx = FlyteContext.current_context()
outputs = TypeEngine.dict_to_literal_map_pb(ctx, res.outputs)

header = ExecuteTaskSyncResponseHeader(
resource=Resource(phase=res.phase, log_links=res.log_links, message=res.message, outputs=outputs)
)
header = ExecuteTaskSyncResponseHeader(resource=res.to_flyte_idl())
yield ExecuteTaskSyncResponse(header=header)
request_success_count.labels(task_type=task_type, operation=do_operation).inc()
except Exception as e:
Expand Down
36 changes: 35 additions & 1 deletion flytekit/extend/backend/base_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,12 @@
from typing import Any, Dict, List, Optional, Union

from flyteidl.admin.agent_pb2 import Agent
from flyteidl.admin.agent_pb2 import Resource as _Resource
from flyteidl.admin.agent_pb2 import TaskCategory as _TaskCategory
from flyteidl.core import literals_pb2
from flyteidl.core.execution_pb2 import TaskExecution, TaskLog
from google.protobuf import json_format
from google.protobuf.struct_pb2 import Struct
from rich.logging import RichHandler
from rich.progress import Progress

Expand All @@ -27,6 +30,7 @@
from flytekit.exceptions.user import FlyteUserException
from flytekit.extend.backend.utils import is_terminal_phase, mirror_async_methods, render_task_template
from flytekit.loggers import set_flytekit_log_properties
from flytekit.models import common
from flytekit.models.literals import LiteralMap
from flytekit.models.task import TaskExecutionMetadata, TaskTemplate

Expand Down Expand Up @@ -75,7 +79,7 @@ def decode(cls, data: bytes) -> "ResourceMeta":


@dataclass
class Resource:
class Resource(common.FlyteIdlEntity):
"""
This is the output resource of the job.
Expand All @@ -90,6 +94,36 @@ class Resource:
message: Optional[str] = None
log_links: Optional[List[TaskLog]] = None
outputs: Optional[Union[LiteralMap, typing.Dict[str, Any]]] = None
custom_info: Optional[typing.Dict[str, Any]] = None

def to_flyte_idl(self) -> _Resource:
if self.outputs is None:
outputs = None
elif isinstance(self.outputs, LiteralMap):
outputs = self.outputs.to_flyte_idl()
else:
ctx = FlyteContext.current_context()
outputs = TypeEngine.dict_to_literal_map_pb(ctx, self.outputs)

return _Resource(
phase=self.phase,
message=self.message,
log_links=self.log_links,
outputs=outputs,
custom_info=(json_format.Parse(json.dumps(self.custom_info), Struct()) if self.custom_info else None),
)

@classmethod
def from_flyte_idl(cls, pb2_object: _Resource):
return cls(
phase=pb2_object.phase,
message=pb2_object.message,
log_links=pb2_object.log_links,
outputs=(LiteralMap.from_flyte_idl(pb2_object.outputs) if pb2_object.outputs else None),
custom_info=(
json_format.MessageToDict(pb2_object.custom_info) if pb2_object.HasField("custom_info") else None
),
)


class AgentBase(ABC):
Expand Down
101 changes: 90 additions & 11 deletions tests/flytekit/unit/extend/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,20 @@

from flytekit import PythonFunctionTask, task
from flytekit.clis.sdk_in_container.serve import print_agents_metadata
from flytekit.configuration import FastSerializationSettings, Image, ImageConfig, SerializationSettings
from flytekit.configuration import (
FastSerializationSettings,
Image,
ImageConfig,
SerializationSettings,
)
from flytekit.core.base_task import PythonTask, kwtypes
from flytekit.core.interface import Interface
from flytekit.exceptions.system import FlyteAgentNotFound
from flytekit.extend.backend.agent_service import AgentMetadataService, AsyncAgentService, SyncAgentService
from flytekit.extend.backend.agent_service import (
AgentMetadataService,
AsyncAgentService,
SyncAgentService,
)
from flytekit.extend.backend.base_agent import (
AgentRegistry,
AsyncAgentBase,
Expand Down Expand Up @@ -70,7 +79,11 @@ def create(self, task_template: TaskTemplate, inputs: typing.Optional[LiteralMap
return DummyMetadata(job_id=dummy_id)

def get(self, resource_meta: DummyMetadata, **kwargs) -> Resource:
return Resource(phase=TaskExecution.SUCCEEDED, log_links=[TaskLog(name="console", uri="localhost:3000")])
return Resource(
phase=TaskExecution.SUCCEEDED,
log_links=[TaskLog(name="console", uri="localhost:3000")],
custom_info={"custom": "info", "num": 1},
)

def delete(self, resource_meta: DummyMetadata, **kwargs):
...
Expand All @@ -95,7 +108,11 @@ async def create(
return DummyMetadata(job_id=dummy_id, output_path=output_path, task_name=task_name)

async def get(self, resource_meta: DummyMetadata, **kwargs) -> Resource:
return Resource(phase=TaskExecution.SUCCEEDED, log_links=[TaskLog(name="console", uri="localhost:3000")])
return Resource(
phase=TaskExecution.SUCCEEDED,
log_links=[TaskLog(name="console", uri="localhost:3000")],
custom_info={"custom": "info", "num": 1},
)

async def delete(self, resource_meta: DummyMetadata, **kwargs):
...
Expand All @@ -107,7 +124,12 @@ class MockOpenAIAgent(SyncAgentBase):
def __init__(self):
super().__init__(task_type_name="openai")

def do(self, task_template: TaskTemplate, inputs: typing.Optional[LiteralMap] = None, **kwargs) -> Resource:
def do(
self,
task_template: TaskTemplate,
inputs: typing.Optional[LiteralMap] = None,
**kwargs,
) -> Resource:
assert inputs.literals["a"].scalar.primitive.integer == 1
return Resource(phase=TaskExecution.SUCCEEDED, outputs={"o0": 1})

Expand Down Expand Up @@ -172,6 +194,8 @@ def test_dummy_agent():
assert resource.phase == TaskExecution.SUCCEEDED
assert resource.log_links[0].name == "console"
assert resource.log_links[0].uri == "localhost:3000"
assert resource.custom_info["custom"] == "info"
assert resource.custom_info["num"] == 1
assert agent.delete(metadata) is None

class DummyTask(AsyncAgentExecutorMixin, PythonFunctionTask):
Expand All @@ -187,7 +211,9 @@ def __init__(self, **kwargs):


@pytest.mark.parametrize(
"agent,consume_metadata", [(DummyAgent(), False), (AsyncDummyAgent(), True)], ids=["sync", "async"]
"agent,consume_metadata",
[(DummyAgent(), False), (AsyncDummyAgent(), True)],
ids=["sync", "async"],
)
@pytest.mark.asyncio
async def test_async_agent_service(agent, consume_metadata):
Expand Down Expand Up @@ -220,7 +246,10 @@ async def test_async_agent_service(agent, consume_metadata):
assert res.resource_meta == metadata_bytes
res = await service.GetTask(GetTaskRequest(task_category=task_category, resource_meta=metadata_bytes), ctx)
assert res.resource.phase == TaskExecution.SUCCEEDED
res = await service.DeleteTask(DeleteTaskRequest(task_category=task_category, resource_meta=metadata_bytes), ctx)
res = await service.DeleteTask(
DeleteTaskRequest(task_category=task_category, resource_meta=metadata_bytes),
ctx,
)
assert res is None

agent_metadata = AgentRegistry.get_agent_metadata(agent.name)
Expand Down Expand Up @@ -267,7 +296,9 @@ def test_openai_agent():
class OpenAITask(SyncAgentExecutorMixin, PythonTask):
def __init__(self, **kwargs):
super().__init__(
task_type="openai", interface=Interface(inputs=kwtypes(a=int), outputs=kwtypes(o0=int)), **kwargs
task_type="openai",
interface=Interface(inputs=kwtypes(a=int), outputs=kwtypes(o0=int)),
**kwargs,
)

t = OpenAITask(task_config={}, name="openai task")
Expand Down Expand Up @@ -391,10 +422,14 @@ def test_render_task_template():
@pytest.fixture
def sample_agents():
async_agent = Agent(
name="Sensor", is_sync=False, supported_task_categories=[TaskCategory(name="sensor", version=0)]
name="Sensor",
is_sync=False,
supported_task_categories=[TaskCategory(name="sensor", version=0)],
)
sync_agent = Agent(
name="ChatGPT Agent", is_sync=True, supported_task_categories=[TaskCategory(name="chatgpt", version=0)]
name="ChatGPT Agent",
is_sync=True,
supported_task_categories=[TaskCategory(name="chatgpt", version=0)],
)
return [async_agent, sync_agent]

Expand All @@ -406,7 +441,51 @@ def test_print_agents_metadata_output(list_agents_mock, mock_secho, sample_agent
print_agents_metadata()
expected_calls = [
(("Starting Sensor that supports task categories ['sensor']",), {"fg": "blue"}),
(("Starting ChatGPT Agent that supports task categories ['chatgpt']",), {"fg": "blue"}),
(
("Starting ChatGPT Agent that supports task categories ['chatgpt']",),
{"fg": "blue"},
),
]
mock_secho.assert_has_calls(expected_calls, any_order=True)
assert mock_secho.call_count == len(expected_calls)


def test_resource_type():
o = Resource(
phase=TaskExecution.SUCCEEDED,
)
v = o.to_flyte_idl()
assert v
assert v.phase == TaskExecution.SUCCEEDED
assert len(v.log_links) == 0
assert v.message == ""
assert len(v.outputs.literals) == 0
assert len(v.custom_info) == 0

o2 = Resource.from_flyte_idl(v)
assert o2

o = Resource(
phase=TaskExecution.SUCCEEDED,
log_links=[TaskLog(name="console", uri="localhost:3000")],
message="foo",
outputs={"o0": 1},
custom_info={"custom": "info", "num": 1},
)
v = o.to_flyte_idl()
assert v
assert v.phase == TaskExecution.SUCCEEDED
assert v.log_links[0].name == "console"
assert v.log_links[0].uri == "localhost:3000"
assert v.message == "foo"
assert v.outputs.literals["o0"].scalar.primitive.integer == 1
assert v.custom_info["custom"] == "info"
assert v.custom_info["num"] == 1

o2 = Resource.from_flyte_idl(v)
assert o2.phase == o.phase
assert o2.log_links == o.log_links
assert o2.message == o.message
# round-tripping creates a literal map out of outputs
assert o2.outputs.literals["o0"].scalar.primitive.integer == 1
assert o2.custom_info == o.custom_info

0 comments on commit 2fd654a

Please sign in to comment.