diff --git a/.github/workflows/monodocs_build.yml b/.github/workflows/monodocs_build.yml index 7b30ef957d..d729ff107d 100644 --- a/.github/workflows/monodocs_build.yml +++ b/.github/workflows/monodocs_build.yml @@ -38,7 +38,7 @@ jobs: working-directory: ${{ github.workspace }}/flyte run: | conda activate monodocs-env - pip install -e ./flyteidl + pip install git+https://github.com/flyteorg/flyte.git@b5c1d99c7f85764775c9924d5b98948e4d8478d0#subdirectory=flyteidl conda info conda list conda config --show-sources diff --git a/.github/workflows/pythonbuild.yml b/.github/workflows/pythonbuild.yml index afa22ddde7..9ecd6a1a60 100644 --- a/.github/workflows/pythonbuild.yml +++ b/.github/workflows/pythonbuild.yml @@ -210,6 +210,7 @@ jobs: # onnx-tensorflow needs a version of tensorflow that does not work with protobuf>4. # The issue is being tracked on the tensorflow side in https://github.com/tensorflow/tensorflow/issues/53234#issuecomment-1330111693 # flytekit-onnx-tensorflow + - flytekit-chatgpt - flytekit-pandera - flytekit-papermill - flytekit-polars @@ -275,6 +276,7 @@ jobs: pip freeze - name: Test with coverage run: | + pip install git+https://github.com/flyteorg/flyte.git@b5c1d99c7f85764775c9924d5b98948e4d8478d0#subdirectory=flyteidl cd plugins/${{ matrix.plugin-names }} # onnx plugins does not support protobuf>4 yet (in fact it is tensorflow that # does not support that yet). More details in https://github.com/onnx/onnx/issues/4239. diff --git a/Dockerfile.agent b/Dockerfile.agent index fe4ce56290..8f3d5b52e9 100644 --- a/Dockerfile.agent +++ b/Dockerfile.agent @@ -14,6 +14,7 @@ RUN pip install --no-cache-dir -U flytekit==$VERSION \ flytekitplugins-mmcloud==$VERSION \ flytekitplugins-spark==$VERSION \ flytekitplugins-snowflake==$VERSION \ + flytekitplugins-chatgpt==$VERSION \ && apt-get clean autoclean \ && apt-get autoremove --yes \ && rm -rf /var/lib/{apt,dpkg,cache,log}/ \ diff --git a/Dockerfile.dev b/Dockerfile.dev index 2b85a5f7d2..1a2a40ff6f 100644 --- a/Dockerfile.dev +++ b/Dockerfile.dev @@ -29,8 +29,7 @@ COPY . /flytekit # 4. Create a non-root user 'flytekit' and set appropriate permissions for directories. RUN apt-get update && apt-get install build-essential vim libmagic1 git -y \ && pip install --no-cache-dir -U --pre \ - flyteidl \ - -e /flytekit \ + git+https://github.com/flyteorg/flyte.git@a46196d42e80f055efc9a013cb95c20342c72622#subdirectory=flyteidl \ -e /flytekit/plugins/flytekit-k8s-pod \ -e /flytekit/plugins/flytekit-deck-standard \ -e /flytekit/plugins/flytekit-flyteinteractive \ diff --git a/Makefile b/Makefile index 321bd34241..cefb7a68a9 100644 --- a/Makefile +++ b/Makefile @@ -24,9 +24,9 @@ update_boilerplate: .PHONY: setup setup: install-piptools ## Install requirements + pip install git+https://github.com/flyteorg/flyte.git@b5c1d99c7f85764775c9924d5b98948e4d8478d0#subdirectory=flyteidl pip install --pre -r dev-requirements.in - .PHONY: fmt fmt: pre-commit run ruff --all-files || true diff --git a/flytekit/clis/sdk_in_container/serve.py b/flytekit/clis/sdk_in_container/serve.py index 87f008b084..2783afc727 100644 --- a/flytekit/clis/sdk_in_container/serve.py +++ b/flytekit/clis/sdk_in_container/serve.py @@ -4,6 +4,7 @@ from flyteidl.service.agent_pb2_grpc import ( add_AgentMetadataServiceServicer_to_server, add_AsyncAgentServiceServicer_to_server, + add_SyncAgentServiceServicer_to_server, ) from grpc import aio @@ -52,7 +53,7 @@ def agent(_: click.Context, port, worker, timeout): async def _start_grpc_server(port: int, worker: int, timeout: int): click.secho("Starting up the server to expose the prometheus metrics...", fg="blue") - from flytekit.extend.backend.agent_service import AgentMetadataService, AsyncAgentService + from flytekit.extend.backend.agent_service import AgentMetadataService, AsyncAgentService, SyncAgentService try: from prometheus_client import start_http_server @@ -64,6 +65,7 @@ async def _start_grpc_server(port: int, worker: int, timeout: int): server = aio.server(futures.ThreadPoolExecutor(max_workers=worker)) add_AsyncAgentServiceServicer_to_server(AsyncAgentService(), server) + add_SyncAgentServiceServicer_to_server(SyncAgentService(), server) add_AgentMetadataServiceServicer_to_server(AgentMetadataService(), server) server.add_insecure_port(f"[::]:{port}") diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index 76f750233b..0be04dac32 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -16,6 +16,7 @@ from typing import Dict, List, NamedTuple, Optional, Type, cast from dataclasses_json import DataClassJsonMixin, dataclass_json +from flyteidl.core import literals_pb2 from google.protobuf import json_format as _json_format from google.protobuf import struct_pb2 as _struct from google.protobuf.json_format import MessageToDict as _MessageToDict @@ -1210,6 +1211,16 @@ def dict_to_literal_map( raise user_exceptions.FlyteTypeException(type(v), python_type, received_value=v) return LiteralMap(literal_map) + @classmethod + def dict_to_literal_map_pb( + cls, + ctx: FlyteContext, + d: typing.Dict[str, typing.Any], + type_hints: Optional[typing.Dict[str, type]] = None, + ) -> Optional[literals_pb2.LiteralMap]: + literal_map = cls.dict_to_literal_map(ctx, d, type_hints) + return literal_map.to_flyte_idl() + @classmethod def get_available_transformers(cls) -> typing.KeysView[Type]: """ diff --git a/flytekit/extend/backend/agent_service.py b/flytekit/extend/backend/agent_service.py index 2d4246c6c1..ce77eb5ba1 100644 --- a/flytekit/extend/backend/agent_service.py +++ b/flytekit/extend/backend/agent_service.py @@ -1,4 +1,5 @@ import typing +from http import HTTPStatus import grpc from flyteidl.admin.agent_pb2 import ( @@ -6,19 +7,28 @@ CreateTaskResponse, DeleteTaskRequest, DeleteTaskResponse, + ExecuteTaskSyncRequest, + ExecuteTaskSyncResponse, + ExecuteTaskSyncResponseHeader, GetAgentRequest, GetAgentResponse, GetTaskRequest, GetTaskResponse, ListAgentsRequest, ListAgentsResponse, + Resource, +) +from flyteidl.service.agent_pb2_grpc import ( + AgentMetadataServiceServicer, + AsyncAgentServiceServicer, + SyncAgentServiceServicer, ) -from flyteidl.service.agent_pb2_grpc import AgentMetadataServiceServicer, AsyncAgentServiceServicer from prometheus_client import Counter, Summary -from flytekit import logger +from flytekit import FlyteContext, logger +from flytekit.core.type_engine import TypeEngine from flytekit.exceptions.system import FlyteAgentNotFound -from flytekit.extend.backend.base_agent import AgentRegistry, mirror_async_methods +from flytekit.extend.backend.base_agent import AgentRegistry, SyncAgentBase, mirror_async_methods from flytekit.models.literals import LiteralMap from flytekit.models.task import TaskTemplate @@ -26,6 +36,7 @@ create_operation = "create" get_operation = "get" delete_operation = "delete" +do_operation = "do" # Follow the naming convention. https://prometheus.io/docs/practices/naming/ request_success_count = Counter( @@ -46,7 +57,24 @@ input_literal_size = Summary(f"{metric_prefix}input_literal_bytes", "Size of input literal", ["task_type"]) -def agent_exception_handler(func: typing.Callable): +def _handle_exception(e: Exception, context: grpc.ServicerContext, task_type: str, operation: str): + if isinstance(e, FlyteAgentNotFound): + error_message = f"Cannot find agent for task type: {task_type}." + logger.error(error_message) + context.set_code(grpc.StatusCode.NOT_FOUND) + context.set_details(error_message) + request_failure_count.labels(task_type=task_type, operation=operation, error_code=HTTPStatus.NOT_FOUND).inc() + else: + error_message = f"failed to {operation} {task_type} task with error: {e}." + logger.error(error_message) + context.set_code(grpc.StatusCode.INTERNAL) + context.set_details(error_message) + request_failure_count.labels( + task_type=task_type, operation=operation, error_code=HTTPStatus.INTERNAL_SERVER_ERROR + ).inc() + + +def record_agent_metrics(func: typing.Callable): async def wrapper( self, request: typing.Union[CreateTaskRequest, GetTaskRequest, DeleteTaskRequest], @@ -60,10 +88,10 @@ async def wrapper( if request.inputs: input_literal_size.labels(task_type=task_type).observe(request.inputs.ByteSize()) elif isinstance(request, GetTaskRequest): - task_type = request.task_type + task_type = request.task_type or request.task_category.name operation = get_operation elif isinstance(request, DeleteTaskRequest): - task_type = request.task_type + task_type = request.task_category or request.task_category.name operation = delete_operation else: context.set_code(grpc.StatusCode.UNIMPLEMENTED) @@ -75,51 +103,90 @@ async def wrapper( res = await func(self, request, context, *args, **kwargs) request_success_count.labels(task_type=task_type, operation=operation).inc() return res - except FlyteAgentNotFound: - error_message = f"Cannot find agent for task type: {task_type}." - logger.error(error_message) - context.set_code(grpc.StatusCode.NOT_FOUND) - context.set_details(error_message) - request_failure_count.labels(task_type=task_type, operation=operation, error_code="404").inc() except Exception as e: - error_message = f"failed to {operation} {task_type} task with error {e}." - logger.error(error_message) - context.set_code(grpc.StatusCode.INTERNAL) - context.set_details(error_message) - request_failure_count.labels(task_type=task_type, operation=operation, error_code="500").inc() + _handle_exception(e, context, task_type, operation) return wrapper class AsyncAgentService(AsyncAgentServiceServicer): - @agent_exception_handler + @record_agent_metrics async def CreateTask(self, request: CreateTaskRequest, context: grpc.ServicerContext) -> CreateTaskResponse: - tmp = TaskTemplate.from_flyte_idl(request.template) + template = TaskTemplate.from_flyte_idl(request.template) inputs = LiteralMap.from_flyte_idl(request.inputs) if request.inputs else None - agent = AgentRegistry.get_agent(tmp.type) + agent = AgentRegistry.get_agent(template.type, template.task_type_version) - logger.info(f"{tmp.type} agent start creating the job") - return await mirror_async_methods( - agent.create, output_prefix=request.output_prefix, task_template=tmp, inputs=inputs - ) + logger.info(f"{agent.name} start creating the job") + resource_mata = await mirror_async_methods(agent.create, task_template=template, inputs=inputs) + return CreateTaskResponse(resource_meta=resource_mata.encode()) - @agent_exception_handler + @record_agent_metrics async def GetTask(self, request: GetTaskRequest, context: grpc.ServicerContext) -> GetTaskResponse: - agent = AgentRegistry.get_agent(request.task_type) - logger.info(f"{agent.task_type} agent start checking the status of the job") - return await mirror_async_methods(agent.get, resource_meta=request.resource_meta) + if request.task_category: + agent = AgentRegistry.get_agent(request.task_category.name, request.task_category.version) + else: + agent = AgentRegistry.get_agent(request.task_type) + logger.info(f"{agent.name} start checking the status of the job") + res = await mirror_async_methods(agent.get, resource_meta=agent.metadata_type.decode(request.resource_meta)) + + if res.outputs is None: + outputs = None + elif isinstance(res.outputs, LiteralMap): + outputs = res.outputs.to_flyte_idl() + else: + ctx = FlyteContext.current_context() + outputs = TypeEngine.dict_to_literal_map_pb(ctx, res.outputs) + return GetTaskResponse( + resource=Resource(phase=res.phase, log_links=res.log_links, message=res.message, outputs=outputs) + ) - @agent_exception_handler + @record_agent_metrics async def DeleteTask(self, request: DeleteTaskRequest, context: grpc.ServicerContext) -> DeleteTaskResponse: - agent = AgentRegistry.get_agent(request.task_type) - logger.info(f"{agent.task_type} agent start deleting the job") - return await mirror_async_methods(agent.delete, resource_meta=request.resource_meta) + if request.task_category: + agent = AgentRegistry.get_agent(request.task_category.name, request.task_category.version) + else: + agent = AgentRegistry.get_agent(request.task_type) + logger.info(f"{agent.name} start deleting the job") + return await mirror_async_methods(agent.delete, resource_meta=agent.metadata_type.decode(request.resource_meta)) + + +class SyncAgentService(SyncAgentServiceServicer): + async def ExecuteTaskSync( + self, request_iterator: typing.AsyncIterator[ExecuteTaskSyncRequest], context: grpc.ServicerContext + ) -> typing.AsyncIterator[ExecuteTaskSyncResponse]: + request = await request_iterator.__anext__() + template = TaskTemplate.from_flyte_idl(request.header.template) + task_type = template.type + try: + with request_latency.labels(task_type=task_type, operation=do_operation).time(): + agent = AgentRegistry.get_agent(task_type, template.task_type_version) + if not isinstance(agent, SyncAgentBase): + raise ValueError(f"[{agent.name}] does not support sync execution") + + request = await request_iterator.__anext__() + literal_map = LiteralMap.from_flyte_idl(request.inputs) if request.inputs else None + res = await mirror_async_methods(agent.do, task_template=template, inputs=literal_map) + + if res.outputs is None: + outputs = None + elif isinstance(res.outputs, LiteralMap): + outputs = res.outputs.to_flyte_idl() + else: + ctx = FlyteContext.current_context() + outputs = TypeEngine.dict_to_literal_map_pb(ctx, res.outputs) + + header = ExecuteTaskSyncResponseHeader( + resource=Resource(phase=res.phase, log_links=res.log_links, message=res.message, outputs=outputs) + ) + yield ExecuteTaskSyncResponse(header=header) + request_success_count.labels(task_type=task_type, operation=do_operation).inc() + except Exception as e: + _handle_exception(e, context, template.type, do_operation) class AgentMetadataService(AgentMetadataServiceServicer): async def GetAgent(self, request: GetAgentRequest, context: grpc.ServicerContext) -> GetAgentResponse: - return GetAgentResponse(agent=AgentRegistry._METADATA[request.name]) + return GetAgentResponse(agent=AgentRegistry.METADATA[request.name]) async def ListAgents(self, request: ListAgentsRequest, context: grpc.ServicerContext) -> ListAgentsResponse: - agents = [agent for agent in AgentRegistry._METADATA.values()] - return ListAgentsResponse(agents=agents) + return ListAgentsResponse(agents=AgentRegistry.list_agents()) diff --git a/flytekit/extend/backend/base_agent.py b/flytekit/extend/backend/base_agent.py index 5a6e5cd3bf..375b56beb6 100644 --- a/flytekit/extend/backend/base_agent.py +++ b/flytekit/extend/backend/base_agent.py @@ -1,72 +1,161 @@ import asyncio -import inspect +import json import signal import sys import time import typing -from abc import ABC +from abc import ABC, abstractmethod from collections import OrderedDict +from dataclasses import asdict, dataclass from functools import partial from types import FrameType, coroutine +from typing import Any, Dict, List, Optional, Union -from flyteidl.admin.agent_pb2 import ( - Agent, - CreateTaskResponse, - DeleteTaskResponse, - GetTaskResponse, -) +from flyteidl.admin.agent_pb2 import Agent +from flyteidl.admin.agent_pb2 import TaskCategory as _TaskCategory from flyteidl.core import literals_pb2 -from flyteidl.core.execution_pb2 import TaskExecution -from flyteidl.core.tasks_pb2 import TaskTemplate +from flyteidl.core.execution_pb2 import TaskExecution, TaskLog from rich.progress import Progress -import flytekit from flytekit import FlyteContext, PythonFunctionTask, logger from flytekit.configuration import ImageConfig, SerializationSettings from flytekit.core import utils from flytekit.core.base_task import PythonTask -from flytekit.core.type_engine import TypeEngine +from flytekit.core.type_engine import TypeEngine, dataclass_from_dict from flytekit.exceptions.system import FlyteAgentNotFound from flytekit.exceptions.user import FlyteUserException +from flytekit.extend.backend.utils import is_terminal_phase, mirror_async_methods, render_task_template from flytekit.models.literals import LiteralMap +from flytekit.models.task import TaskTemplate + + +class TaskCategory: + def __init__(self, name: str, version: int = 0): + self._name = name + self._version = version + + def __hash__(self): + return hash((self._name, self._version)) + + def __eq__(self, other: "TaskCategory"): + return self._name == other._name and self._version == other._version + + @property + def name(self) -> str: + return self._name + + @property + def version(self) -> int: + return self._version + + def __str__(self): + return f"{self._name}_v{self._version}" + + +@dataclass +class ResourceMeta: + """ + This is the metadata for the job. For example, the id of the job. + """ + + def encode(self) -> bytes: + """ + Encode the resource meta to bytes. + """ + return json.dumps(asdict(self)).encode("utf-8") + + @classmethod + def decode(cls, data: bytes) -> "ResourceMeta": + """ + Decode the resource meta from bytes. + """ + return dataclass_from_dict(cls, json.loads(data.decode("utf-8"))) + + +@dataclass +class Resource: + """ + This is the output resource of the job. + + Args: + phase: The phase of the job. + message: The return message from the job. + log_links: The log links of the job. For example, the link to the BigQuery Console. + outputs: The outputs of the job. If return python native types, the agent will convert them to flyte literals. + """ + + phase: TaskExecution.Phase + message: Optional[str] = None + log_links: Optional[List[TaskLog]] = None + outputs: Optional[Union[LiteralMap, typing.Dict[str, Any]]] = None + + +T = typing.TypeVar("T", bound=ResourceMeta) class AgentBase(ABC): + name = "Base Agent" + + def __init__(self, task_type_name: str, task_type_version: int = 0, **kwargs): + self._task_category = TaskCategory(name=task_type_name, version=task_type_version) + + @property + def task_category(self) -> TaskCategory: + """ + task category that the agent supports + """ + return self._task_category + + +class SyncAgentBase(AgentBase): + """ + This is the base class for all sync agents. It defines the interface that all agents must implement. + The agent service is responsible for invoking agents. + Propeller sends a request to agent service, and gets a response in the same call. + + All the agents should be registered in the AgentRegistry. Agent Service + will look up the agent based on the task type. Every task type can only have one agent. """ - This is the base class for all agents. It defines the interface that all agents must implement. - The agent service will be run either locally or in a pod, and will be responsible for - invoking agents. The propeller will communicate with the agent service + + name = "Base Sync Agent" + + @abstractmethod + def do(self, task_template: TaskTemplate, inputs: Optional[LiteralMap], **kwargs) -> Resource: + """ + This is the method that the agent will run. + """ + raise NotImplementedError + + +class AsyncAgentBase(AgentBase, typing.Generic[T]): + """ + This is the base class for all async agents. It defines the interface that all agents must implement. + The agent service is responsible for invoking agents. The propeller will communicate with the agent service to create tasks, get the status of tasks, and delete tasks. All the agents should be registered in the AgentRegistry. Agent Service will look up the agent based on the task type. Every task type can only have one agent. """ - name = "Base Agent" + name = "Base Async Agent" - def __init__(self, task_type: str, **kwargs): - self._task_type = task_type + def __init__(self, metadata_type: typing.Type[T], **kwargs): + super().__init__(**kwargs) + self._metadata_type = metadata_type @property - def task_type(self) -> str: - """ - task_type is the name of the task type that this agent supports. - """ - return self._task_type - - def create( - self, - output_prefix: str, - task_template: TaskTemplate, - inputs: typing.Optional[LiteralMap] = None, - **kwargs, - ) -> CreateTaskResponse: + def metadata_type(self) -> ResourceMeta: + return self._metadata_type + + @abstractmethod + def create(self, task_template: TaskTemplate, inputs: Optional[LiteralMap], **kwargs) -> T: """ - Return a Unique ID for the task that was created. It should return error code if the task creation failed. + Return a resource meta that can be used to get the status of the task. """ raise NotImplementedError - def get(self, resource_meta: bytes, **kwargs) -> GetTaskResponse: + @abstractmethod + def get(self, resource_meta: T, **kwargs) -> Resource: """ Return the status of the task, and return the outputs in some cases. For example, bigquery job can't write the structured dataset to the output location, so it returns the output literals to the propeller, @@ -74,9 +163,10 @@ def get(self, resource_meta: bytes, **kwargs) -> GetTaskResponse: """ raise NotImplementedError - def delete(self, resource_meta: bytes, **kwargs) -> DeleteTaskResponse: + @abstractmethod + def delete(self, resource_meta: T, **kwargs): """ - Delete the task. This call should be idempotent. + Delete the task. This call should be idempotent. It should raise an error if fails to delete the task. """ raise NotImplementedError @@ -88,127 +178,131 @@ class AgentRegistry(object): The agent metadata service will look up the agent metadata based on the agent name. """ - _REGISTRY: typing.Dict[str, AgentBase] = {} - _METADATA: typing.Dict[str, Agent] = {} + _REGISTRY: Dict[TaskCategory, Union[AsyncAgentBase, SyncAgentBase]] = {} + METADATA: Dict[str, Agent] = {} @staticmethod - def register(agent: AgentBase): - if agent.task_type in AgentRegistry._REGISTRY: - raise ValueError(f"Duplicate agent for task type {agent.task_type}") - AgentRegistry._REGISTRY[agent.task_type] = agent - - if agent.name in AgentRegistry._METADATA: - agent_metadata = AgentRegistry._METADATA[agent.name] - agent_metadata.supported_task_types.append(agent.task_type) + def register(agent: Union[AsyncAgentBase, SyncAgentBase], override: bool = False): + if agent.task_category in AgentRegistry._REGISTRY and override is False: + raise ValueError(f"Duplicate agent for task type: {agent.task_category}") + AgentRegistry._REGISTRY[agent.task_category] = agent + + task_category = _TaskCategory(name=agent.task_category.name, version=agent.task_category.version) + + if agent.name in AgentRegistry.METADATA: + agent_metadata = AgentRegistry.get_agent_metadata(agent.name) + agent_metadata.supported_task_categories.append(task_category) + agent_metadata.supported_task_types.append(task_category.name) else: - agent_metadata = Agent(name=agent.name, supported_task_types=[agent.task_type]) - AgentRegistry._METADATA[agent.name] = agent_metadata + agent_metadata = Agent( + name=agent.name, + supported_task_types=[task_category.name], + supported_task_categories=[task_category], + is_sync=isinstance(agent, SyncAgentBase), + ) + AgentRegistry.METADATA[agent.name] = agent_metadata + + logger.info(f"Registering {agent.name} for task type: {agent.task_category}") - logger.info(f"Registering an agent for task type: {agent.task_type}, name: {agent.name}") + @staticmethod + def get_agent(task_type_name: str, task_type_version: int = 0) -> Union[SyncAgentBase, AsyncAgentBase]: + task_category = TaskCategory(name=task_type_name, version=task_type_version) + if task_category not in AgentRegistry._REGISTRY: + raise FlyteAgentNotFound(f"Cannot find agent for task category: {task_category}.") + return AgentRegistry._REGISTRY[task_category] @staticmethod - def get_agent(task_type: str) -> AgentBase: - if task_type not in AgentRegistry._REGISTRY: - raise FlyteAgentNotFound(f"Cannot find agent for task type: {task_type}.") - return AgentRegistry._REGISTRY[task_type] + def list_agents() -> List[Agent]: + return list(AgentRegistry.METADATA.values()) @staticmethod def get_agent_metadata(name: str) -> Agent: - if name not in AgentRegistry._METADATA: + if name not in AgentRegistry.METADATA: raise FlyteAgentNotFound(f"Cannot find agent for name: {name}.") - return AgentRegistry._METADATA[name] - - -def mirror_async_methods(func: typing.Callable, **kwargs) -> typing.Coroutine: - if inspect.iscoroutinefunction(func): - return func(**kwargs) - args = [v for _, v in kwargs.items()] - return asyncio.get_running_loop().run_in_executor(None, func, *args) + return AgentRegistry.METADATA[name] -def convert_to_flyte_phase(state: str) -> TaskExecution.Phase: - """ - Convert the state from the agent to the phase in flyte. - """ - state = state.lower() - # timedout is the state of Databricks job. https://docs.databricks.com/en/workflows/jobs/jobs-2.0-api.html#runresultstate - if state in ["failed", "timeout", "timedout", "canceled"]: - return TaskExecution.FAILED - elif state in ["done", "succeeded", "success"]: - return TaskExecution.SUCCEEDED - elif state in ["running"]: - return TaskExecution.RUNNING - raise ValueError(f"Unrecognized state: {state}") - - -def is_terminal_phase(phase: TaskExecution.Phase) -> bool: +class SyncAgentExecutorMixin: """ - Return true if the phase is terminal. + This mixin class is used to run the sync task locally, and it's only used for local execution. + Task should inherit from this class if the task can be run in the agent. + + Synchronous tasks run quickly and can return their results instantly. + Sending a prompt to ChatGPT and getting a response, or retrieving some metadata from a backend system. """ - return phase in [TaskExecution.SUCCEEDED, TaskExecution.ABORTED, TaskExecution.FAILED] + T = typing.TypeVar("T", "SyncAgentExecutorMixin", PythonTask) + + def execute(self: T, **kwargs) -> LiteralMap: + from flytekit.tools.translator import get_serializable + + ctx = FlyteContext.current_context() + ss = ctx.serialization_settings or SerializationSettings(ImageConfig()) + task_template = get_serializable(OrderedDict(), ss, self).template + + agent = AgentRegistry.get_agent(task_template.type, task_template.task_type_version) -def get_agent_secret(secret_key: str) -> str: - return flytekit.current_context().secrets.get(secret_key) + resource = asyncio.run(self._do(agent, task_template, kwargs)) + if resource.phase != TaskExecution.SUCCEEDED: + raise FlyteUserException(f"Failed to run the task {agent.name} with error: {resource.message}") + + if resource.outputs and not isinstance(resource.outputs, LiteralMap): + return TypeEngine.dict_to_literal_map(ctx, resource.outputs) + return resource.outputs + + async def _do(self: T, agent: SyncAgentBase, template: TaskTemplate, inputs: Dict[str, Any] = None) -> Resource: + ctx = FlyteContext.current_context() + literal_map = TypeEngine.dict_to_literal_map(ctx, inputs or {}, self.get_input_types()) + return await mirror_async_methods(agent.do, task_template=template, inputs=literal_map) class AsyncAgentExecutorMixin: """ - This mixin class is used to run the agent task locally, and it's only used for local execution. + This mixin class is used to run the async task locally, and it's only used for local execution. Task should inherit from this class if the task can be run in the agent. - It can handle asynchronous tasks and synchronous tasks. + Asynchronous tasks are tasks that take a long time to complete, such as running a query. - Synchronous tasks run quickly and can return their results instantly. Sending a prompt to ChatGPT and getting a response, or retrieving some metadata from a backend system. """ + T = typing.TypeVar("T", "AsyncAgentExecutorMixin", PythonTask) + _clean_up_task: coroutine = None - _agent: AgentBase = None - _entity: PythonTask = None + _agent: AsyncAgentBase = None - def execute(self, **kwargs) -> typing.Any: + def execute(self: T, **kwargs) -> LiteralMap: ctx = FlyteContext.current_context() ss = ctx.serialization_settings or SerializationSettings(ImageConfig()) output_prefix = ctx.file_access.get_random_remote_directory() from flytekit.tools.translator import get_serializable - self._entity = typing.cast(PythonTask, self) - task_template = get_serializable(OrderedDict(), ss, self._entity).template - self._agent = AgentRegistry.get_agent(task_template.type) - - res = asyncio.run(self._create(task_template, output_prefix, kwargs)) - - # If the task is synchronous, the agent will return the output from the resource literals. - if res.HasField("resource"): - if res.resource.phase != TaskExecution.SUCCEEDED: - raise FlyteUserException(f"Failed to run the task {self._entity.name}") - return LiteralMap.from_flyte_idl(res.resource.outputs) + task_template = get_serializable(OrderedDict(), ss, self).template + self._agent = AgentRegistry.get_agent(task_template.type, task_template.task_type_version) - res = asyncio.run(self._get(resource_meta=res.resource_meta)) + resource_mata = asyncio.run(self._create(task_template, output_prefix, kwargs)) + resource = asyncio.run(self._get(resource_meta=resource_mata)) - if res.resource.phase != TaskExecution.SUCCEEDED: - raise FlyteUserException(f"Failed to run the task {self._entity.name}") + if resource.phase != TaskExecution.SUCCEEDED: + raise FlyteUserException(f"Failed to run the task {self.name} with error: {resource.message}") - # Read the literals from a remote file, if agent doesn't return the output literals. - if task_template.interface.outputs and len(res.resource.outputs.literals) == 0: + # Read the literals from a remote file if the agent doesn't return the output literals. + if task_template.interface.outputs and resource.outputs and len(resource.outputs.literals) == 0: local_outputs_file = ctx.file_access.get_random_local_path() - ctx.file_access.get_data(f"{output_prefix}/output/outputs.pb", local_outputs_file) + ctx.file_access.get_data(f"{output_prefix}/outputs.pb", local_outputs_file) output_proto = utils.load_proto_from_file(literals_pb2.LiteralMap, local_outputs_file) return LiteralMap.from_flyte_idl(output_proto) - return LiteralMap.from_flyte_idl(res.resource.outputs) + if resource.outputs and not isinstance(resource.outputs, LiteralMap): + return TypeEngine.dict_to_literal_map(ctx, resource.outputs) + + return resource.outputs async def _create( - self, task_template: TaskTemplate, output_prefix: str, inputs: typing.Dict[str, typing.Any] = None - ) -> CreateTaskResponse: + self: T, task_template: TaskTemplate, output_prefix: str, inputs: Dict[str, Any] = None + ) -> ResourceMeta: ctx = FlyteContext.current_context() - # Convert python inputs to literals - literals = inputs or {} - for k, v in inputs.items(): - literals[k] = TypeEngine.to_literal(ctx, v, type(v), self._entity.interface.inputs[k].type) - literal_map = LiteralMap(literals) - + literal_map = TypeEngine.dict_to_literal_map(ctx, inputs or {}, self.get_input_types()) if isinstance(self, PythonFunctionTask): # Write the inputs to a remote file, so that the remote task can read the inputs from this file. path = ctx.file_access.get_random_local_path() @@ -216,58 +310,47 @@ async def _create( ctx.file_access.put_data(path, f"{output_prefix}/inputs.pb") task_template = render_task_template(task_template, output_prefix) - res = await mirror_async_methods( + resource_meta = await mirror_async_methods( self._agent.create, - output_prefix=output_prefix, task_template=task_template, inputs=literal_map, ) - signal.signal(signal.SIGINT, partial(self.signal_handler, res.resource_meta)) # type: ignore - return res + signal.signal(signal.SIGINT, partial(self.signal_handler, resource_meta)) # type: ignore + return resource_meta - async def _get(self, resource_meta: bytes) -> GetTaskResponse: + async def _get(self: T, resource_meta: ResourceMeta) -> Resource: phase = TaskExecution.RUNNING progress = Progress(transient=True) - task = progress.add_task(f"[cyan]Running Task {self._entity.name}...", total=None) + task = progress.add_task(f"[cyan]Running Task {self.name}...", total=None) task_phase = progress.add_task("[cyan]Task phase: RUNNING, Phase message: ", total=None, visible=False) task_log_links = progress.add_task("[cyan]Log Links: ", total=None, visible=False) with progress: while not is_terminal_phase(phase): progress.start_task(task) time.sleep(1) - res = await mirror_async_methods(self._agent.get, resource_meta=resource_meta) + resource = await mirror_async_methods(self._agent.get, resource_meta=resource_meta) if self._clean_up_task: await self._clean_up_task sys.exit(1) - phase = res.resource.phase + phase = resource.phase progress.update( task_phase, - description=f"[cyan]Task phase: {TaskExecution.Phase.Name(phase)}, Phase message: {res.resource.message}", + description=f"[cyan]Task phase: {TaskExecution.Phase.Name(phase)}, Phase message: {resource.message}", visible=True, ) - log_links = "" - for link in res.log_links: - log_links += f"{link.name}: {link.uri}\n" - if log_links: - progress.update(task_log_links, description=f"[cyan]{log_links}", visible=True) + if resource.log_links: + log_links = "" + for link in resource.log_links: + log_links += f"{link.name}: {link.uri}\n" + if log_links: + progress.update(task_log_links, description=f"[cyan]{log_links}", visible=True) - return res + return resource - def signal_handler(self, resource_meta: bytes, signum: int, frame: FrameType) -> typing.Any: + def signal_handler(self, resource_meta: ResourceMeta, signum: int, frame: FrameType) -> Any: if self._clean_up_task is None: co = mirror_async_methods(self._agent.delete, resource_meta=resource_meta) self._clean_up_task = asyncio.create_task(co) - - -def render_task_template(tt: TaskTemplate, file_prefix: str) -> TaskTemplate: - args = tt.container.args - for i in range(len(args)): - tt.container.args[i] = args[i].replace("{{.input}}", f"{file_prefix}/inputs.pb") - tt.container.args[i] = args[i].replace("{{.outputPrefix}}", f"{file_prefix}/output") - tt.container.args[i] = args[i].replace("{{.rawOutputDataPrefix}}", f"{file_prefix}/raw_output") - tt.container.args[i] = args[i].replace("{{.checkpointOutputPrefix}}", f"{file_prefix}/checkpoint_output") - tt.container.args[i] = args[i].replace("{{.prevCheckpointPrefix}}", f"{file_prefix}/prev_checkpoint") - return tt diff --git a/flytekit/extend/backend/utils.py b/flytekit/extend/backend/utils.py new file mode 100644 index 0000000000..b20c9fdf66 --- /dev/null +++ b/flytekit/extend/backend/utils.py @@ -0,0 +1,52 @@ +import asyncio +import inspect +from typing import Callable, Coroutine + +from flyteidl.core.execution_pb2 import TaskExecution + +import flytekit +from flytekit.models.task import TaskTemplate + + +def mirror_async_methods(func: Callable, **kwargs) -> Coroutine: + if inspect.iscoroutinefunction(func): + return func(**kwargs) + args = [v for _, v in kwargs.items()] + return asyncio.get_running_loop().run_in_executor(None, func, *args) + + +def convert_to_flyte_phase(state: str) -> TaskExecution.Phase: + """ + Convert the state from the agent to the phase in flyte. + """ + state = state.lower() + # timedout is the state of Databricks job. https://docs.databricks.com/en/workflows/jobs/jobs-2.0-api.html#runresultstate + if state in ["failed", "timeout", "timedout", "canceled"]: + return TaskExecution.FAILED + elif state in ["done", "succeeded", "success"]: + return TaskExecution.SUCCEEDED + elif state in ["running"]: + return TaskExecution.RUNNING + raise ValueError(f"Unrecognized state: {state}") + + +def is_terminal_phase(phase: TaskExecution.Phase) -> bool: + """ + Return true if the phase is terminal. + """ + return phase in [TaskExecution.SUCCEEDED, TaskExecution.ABORTED, TaskExecution.FAILED] + + +def get_agent_secret(secret_key: str) -> str: + return flytekit.current_context().secrets.get(secret_key) + + +def render_task_template(tt: TaskTemplate, file_prefix: str) -> TaskTemplate: + args = tt.container.args + for i in range(len(args)): + tt.container.args[i] = args[i].replace("{{.input}}", f"{file_prefix}/inputs.pb") + tt.container.args[i] = args[i].replace("{{.outputPrefix}}", f"{file_prefix}") + tt.container.args[i] = args[i].replace("{{.rawOutputDataPrefix}}", f"{file_prefix}/raw_output") + tt.container.args[i] = args[i].replace("{{.checkpointOutputPrefix}}", f"{file_prefix}/checkpoint_output") + tt.container.args[i] = args[i].replace("{{.prevCheckpointPrefix}}", f"{file_prefix}/prev_checkpoint") + return tt diff --git a/flytekit/sensor/base_sensor.py b/flytekit/sensor/base_sensor.py index 0e40055ea5..3392f77009 100644 --- a/flytekit/sensor/base_sensor.py +++ b/flytekit/sensor/base_sensor.py @@ -1,26 +1,48 @@ import collections import inspect +import typing from abc import abstractmethod +from dataclasses import asdict, dataclass from typing import Any, Dict, Optional, TypeVar -import jsonpickle -from typing_extensions import get_type_hints +from typing_extensions import Protocol, get_type_hints, runtime_checkable from flytekit.configuration import SerializationSettings from flytekit.core.base_task import PythonTask from flytekit.core.interface import Interface -from flytekit.extend.backend.base_agent import AsyncAgentExecutorMixin +from flytekit.extend.backend.base_agent import AsyncAgentExecutorMixin, ResourceMeta -T = TypeVar("T") -SENSOR_MODULE = "sensor_module" -SENSOR_NAME = "sensor_name" -SENSOR_CONFIG_PKL = "sensor_config_pkl" -INPUTS = "inputs" + +@runtime_checkable +class SensorConfig(Protocol): + def to_dict(self) -> typing.Dict[str, Any]: + """ + Serialize the sensor config to a dictionary. + """ + raise NotImplementedError + + @classmethod + def from_dict(cls, d: typing.Dict[str, Any]) -> "SensorConfig": + """ + Deserialize the sensor config from a dictionary. + """ + raise NotImplementedError + + +@dataclass +class SensorMetadata(ResourceMeta): + sensor_module: str + sensor_name: str + sensor_config: Optional[dict] = None + inputs: Optional[dict] = None + + +T = TypeVar("T", bound=SensorConfig) class BaseSensor(AsyncAgentExecutorMixin, PythonTask): """ - Base class for all sensors. Sensors are tasks that are designed to run forever, and periodically check for some + Base class for all sensors. Sensors are tasks that are designed to run forever and periodically check for some condition to be met. When the condition is met, the sensor will complete. Sensors are designed to be run by the sensor agent, and not by the Flyte engine. """ @@ -57,10 +79,9 @@ async def poke(self, **kwargs) -> bool: raise NotImplementedError def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]: - cfg = { - SENSOR_MODULE: type(self).__module__, - SENSOR_NAME: type(self).__name__, - } - if self._sensor_config is not None: - cfg[SENSOR_CONFIG_PKL] = jsonpickle.encode(self._sensor_config) - return cfg + sensor_config = self._sensor_config.to_dict() if self._sensor_config else None + return asdict( + SensorMetadata( + sensor_module=type(self).__module__, sensor_name=type(self).__name__, sensor_config=sensor_config + ) + ) diff --git a/flytekit/sensor/file_sensor.py b/flytekit/sensor/file_sensor.py index 2fb3d64ec1..f894546927 100644 --- a/flytekit/sensor/file_sensor.py +++ b/flytekit/sensor/file_sensor.py @@ -1,14 +1,10 @@ -from typing import Optional, TypeVar - from flytekit import FlyteContextManager from flytekit.sensor.base_sensor import BaseSensor -T = TypeVar("T") - class FileSensor(BaseSensor): - def __init__(self, name: str, config: Optional[T] = None, **kwargs): - super().__init__(name=name, sensor_config=config, **kwargs) + def __init__(self, name: str, **kwargs): + super().__init__(name=name, **kwargs) async def poke(self, path: str) -> bool: file_access = FlyteContextManager.current_context().file_access diff --git a/flytekit/sensor/sensor_engine.py b/flytekit/sensor/sensor_engine.py index 816360715a..ac718abe35 100644 --- a/flytekit/sensor/sensor_engine.py +++ b/flytekit/sensor/sensor_engine.py @@ -1,62 +1,49 @@ import importlib -import typing from typing import Optional -import cloudpickle -import jsonpickle -from flyteidl.admin.agent_pb2 import ( - CreateTaskResponse, - DeleteTaskResponse, - GetTaskResponse, - Resource, -) from flyteidl.core.execution_pb2 import TaskExecution from flytekit import FlyteContextManager from flytekit.core.type_engine import TypeEngine -from flytekit.extend.backend.base_agent import AgentBase, AgentRegistry +from flytekit.extend.backend.base_agent import AgentRegistry, AsyncAgentBase, Resource from flytekit.models.literals import LiteralMap from flytekit.models.task import TaskTemplate -from flytekit.sensor.base_sensor import INPUTS, SENSOR_CONFIG_PKL, SENSOR_MODULE, SENSOR_NAME +from flytekit.sensor.base_sensor import SensorMetadata -T = typing.TypeVar("T") - -class SensorEngine(AgentBase): +class SensorEngine(AsyncAgentBase): name = "Sensor" def __init__(self): - super().__init__(task_type="sensor") - - async def create( - self, output_prefix: str, task_template: TaskTemplate, inputs: Optional[LiteralMap] = None, **kwargs - ) -> CreateTaskResponse: - python_interface_inputs = { - name: TypeEngine.guess_python_type(lt.type) for name, lt in task_template.interface.inputs.items() - } - ctx = FlyteContextManager.current_context() + super().__init__(task_type_name="sensor", metadata_type=SensorMetadata) + + async def create(self, task_template: TaskTemplate, inputs: Optional[LiteralMap] = None, **kwarg) -> SensorMetadata: + sensor_metadata = SensorMetadata(**task_template.custom) + if inputs: + ctx = FlyteContextManager.current_context() + python_interface_inputs = { + name: TypeEngine.guess_python_type(lt.type) for name, lt in task_template.interface.inputs.items() + } native_inputs = TypeEngine.literal_map_to_kwargs(ctx, inputs, python_interface_inputs) - task_template.custom[INPUTS] = native_inputs - return CreateTaskResponse(resource_meta=cloudpickle.dumps(task_template.custom)) + sensor_metadata.inputs = native_inputs - async def get(self, resource_meta: bytes, **kwargs) -> GetTaskResponse: - meta = cloudpickle.loads(resource_meta) + return sensor_metadata - sensor_module = importlib.import_module(name=meta[SENSOR_MODULE]) - sensor_def = getattr(sensor_module, meta[SENSOR_NAME]) - sensor_config = jsonpickle.decode(meta[SENSOR_CONFIG_PKL]) if meta.get(SENSOR_CONFIG_PKL) else None + async def get(self, resource_meta: SensorMetadata, **kwargs) -> Resource: + sensor_module = importlib.import_module(name=resource_meta.sensor_module) + sensor_def = getattr(sensor_module, resource_meta.sensor_name) - inputs = meta.get(INPUTS, {}) + inputs = resource_meta.inputs cur_phase = ( TaskExecution.SUCCEEDED - if await sensor_def("sensor", config=sensor_config).poke(**inputs) + if await sensor_def("sensor", config=resource_meta.sensor_config).poke(**inputs) else TaskExecution.RUNNING ) - return GetTaskResponse(resource=Resource(phase=cur_phase, outputs=None)) + return Resource(phase=cur_phase, outputs=None) - async def delete(self, resource_meta: bytes, **kwargs) -> DeleteTaskResponse: - return DeleteTaskResponse() + async def delete(self, resource_meta: SensorMetadata, **kwargs): + return AgentRegistry.register(SensorEngine()) diff --git a/plugins/flytekit-airflow/flytekitplugins/airflow/agent.py b/plugins/flytekit-airflow/flytekitplugins/airflow/agent.py index e52453d7bb..2ff0d0e9a5 100644 --- a/plugins/flytekit-airflow/flytekitplugins/airflow/agent.py +++ b/plugins/flytekit-airflow/flytekitplugins/airflow/agent.py @@ -5,12 +5,6 @@ import cloudpickle import jsonpickle -from flyteidl.admin.agent_pb2 import ( - CreateTaskResponse, - DeleteTaskResponse, - GetTaskResponse, - Resource, -) from flyteidl.core.execution_pb2 import TaskExecution from flytekitplugins.airflow.task import AirflowObj, _get_airflow_instance @@ -21,13 +15,13 @@ from airflow.utils.context import Context from flytekit import logger from flytekit.exceptions.user import FlyteUserException -from flytekit.extend.backend.base_agent import AgentBase, AgentRegistry +from flytekit.extend.backend.base_agent import AgentRegistry, AsyncAgentBase, Resource, ResourceMeta from flytekit.models.literals import LiteralMap from flytekit.models.task import TaskTemplate @dataclass -class ResourceMetadata: +class AirflowMetadata(ResourceMeta): """ This class is used to store the Airflow task configuration. It is serialized and returned to FlytePropeller. """ @@ -37,8 +31,15 @@ class ResourceMetadata: airflow_trigger_callback: str = field(default=None) job_id: typing.Optional[str] = field(default=None) + def encode(self) -> bytes: + return cloudpickle.dumps(self) -class AirflowAgent(AgentBase): + @classmethod + def decode(cls, data: bytes) -> "AirflowMetadata": + return cloudpickle.loads(data) + + +class AirflowAgent(AsyncAgentBase): """ It is used to run Airflow tasks. It is registered as an agent in the AgentRegistry. There are three kinds of Airflow tasks: AirflowOperator, AirflowSensor, and AirflowHook. @@ -62,22 +63,18 @@ class AirflowAgent(AgentBase): name = "Airflow Agent" def __init__(self): - super().__init__(task_type="airflow") + super().__init__(task_type_name="airflow", metadata_type=AirflowMetadata) async def create( - self, - output_prefix: str, - task_template: TaskTemplate, - inputs: Optional[LiteralMap] = None, - **kwargs, - ) -> CreateTaskResponse: + self, task_template: TaskTemplate, inputs: Optional[LiteralMap] = None, **kwargs + ) -> AirflowMetadata: airflow_obj = jsonpickle.decode(task_template.custom["task_config_pkl"]) airflow_instance = _get_airflow_instance(airflow_obj) - resource_meta = ResourceMetadata(airflow_operator=airflow_obj) + resource_meta = AirflowMetadata(airflow_operator=airflow_obj) if isinstance(airflow_instance, BaseOperator) and not isinstance(airflow_instance, BaseSensorOperator): try: - resource_meta = ResourceMetadata(airflow_operator=airflow_obj) + resource_meta = AirflowMetadata(airflow_operator=airflow_obj) airflow_instance.execute(context=Context()) except TaskDeferred as td: parameters = td.trigger.__dict__.copy() @@ -90,12 +87,13 @@ async def create( ) resource_meta.airflow_trigger_callback = td.method_name - return CreateTaskResponse(resource_meta=cloudpickle.dumps(resource_meta)) + return resource_meta - async def get(self, resource_meta: bytes, **kwargs) -> GetTaskResponse: - meta = cloudpickle.loads(resource_meta) - airflow_operator_instance = _get_airflow_instance(meta.airflow_operator) - airflow_trigger_instance = _get_airflow_instance(meta.airflow_trigger) if meta.airflow_trigger else None + async def get(self, resource_meta: AirflowMetadata, **kwargs) -> Resource: + airflow_operator_instance = _get_airflow_instance(resource_meta.airflow_operator) + airflow_trigger_instance = ( + _get_airflow_instance(resource_meta.airflow_trigger) if resource_meta.airflow_trigger else None + ) airflow_ctx = Context() message = None cur_phase = TaskExecution.RUNNING @@ -107,7 +105,7 @@ async def get(self, resource_meta: bytes, **kwargs) -> GetTaskResponse: if airflow_trigger_instance: try: # Airflow trigger returns immediately when - # 1. Failed to get the task status + # 1. Failed to get task status # 2. Task succeeded or failed # succeeded or failed: returns a TriggerEvent with payload # running: runs forever, so set a default timeout (2 seconds) here. @@ -115,7 +113,7 @@ async def get(self, resource_meta: bytes, **kwargs) -> GetTaskResponse: event = await asyncio.wait_for(airflow_trigger_instance.run().__anext__(), 2) try: # Trigger callback will check the status of the task in the payload, and raise AirflowException if failed. - trigger_callback = getattr(airflow_operator_instance, meta.airflow_trigger_callback) + trigger_callback = getattr(airflow_operator_instance, resource_meta.airflow_trigger_callback) trigger_callback(context=airflow_ctx, event=typing.cast(TriggerEvent, event).payload) cur_phase = TaskExecution.SUCCEEDED except AirflowException as e: @@ -136,10 +134,10 @@ async def get(self, resource_meta: bytes, **kwargs) -> GetTaskResponse: else: raise FlyteUserException("Only sensor and operator are supported.") - return GetTaskResponse(resource=Resource(phase=cur_phase, message=message)) + return Resource(phase=cur_phase, message=message) - async def delete(self, resource_meta: bytes, **kwargs) -> DeleteTaskResponse: - return DeleteTaskResponse() + async def delete(self, resource_meta: AirflowMetadata, **kwargs): + return AgentRegistry.register(AirflowAgent()) diff --git a/plugins/flytekit-airflow/tests/test_agent.py b/plugins/flytekit-airflow/tests/test_agent.py index dc4d167b10..57999d5c59 100644 --- a/plugins/flytekit-airflow/tests/test_agent.py +++ b/plugins/flytekit-airflow/tests/test_agent.py @@ -5,10 +5,9 @@ from airflow.operators.python import PythonOperator from airflow.sensors.bash import BashSensor from airflow.sensors.time_sensor import TimeSensor -from flyteidl.admin.agent_pb2 import DeleteTaskResponse from flyteidl.core.execution_pb2 import TaskExecution from flytekitplugins.airflow import AirflowObj -from flytekitplugins.airflow.agent import AirflowAgent, ResourceMetadata +from flytekitplugins.airflow.agent import AirflowAgent, AirflowMetadata from flytekit import workflow from flytekit.interfaces.cli_identifiers import Identifier @@ -44,7 +43,7 @@ def test_resource_metadata(): parameters={"task_id": "id", "bash_command": "echo 'hello world'"}, ) trigger_cfg = AirflowObj(module="airflow.trigger.file", name="FileTrigger", parameters={"filepath": "file.txt"}) - meta = ResourceMetadata( + meta = AirflowMetadata( airflow_operator=task_cfg, airflow_trigger=trigger_cfg, airflow_trigger_callback="execute_complete", @@ -89,10 +88,9 @@ async def test_airflow_agent(): ) agent = AirflowAgent() - res = await agent.create("/tmp", dummy_template, None) - metadata = res.resource_meta - res = await agent.get(metadata) - assert res.resource.phase == TaskExecution.SUCCEEDED - assert res.resource.message == "" + metadata = await agent.create(dummy_template, None) + resource = await agent.get(metadata) + assert resource.phase == TaskExecution.SUCCEEDED + assert resource.message is None res = await agent.delete(metadata) - assert res == DeleteTaskResponse() + assert res is None diff --git a/plugins/flytekit-bigquery/flytekitplugins/bigquery/agent.py b/plugins/flytekit-bigquery/flytekitplugins/bigquery/agent.py index 4c34285793..0275162f72 100644 --- a/plugins/flytekit-bigquery/flytekitplugins/bigquery/agent.py +++ b/plugins/flytekit-bigquery/flytekitplugins/bigquery/agent.py @@ -1,25 +1,16 @@ import datetime -import json -from dataclasses import asdict, dataclass +from dataclasses import dataclass from typing import Dict, Optional -from flyteidl.admin.agent_pb2 import ( - CreateTaskResponse, - DeleteTaskResponse, - GetTaskResponse, - Resource, -) -from flyteidl.core.execution_pb2 import TaskExecution +from flyteidl.core.execution_pb2 import TaskExecution, TaskLog from google.cloud import bigquery from flytekit import FlyteContextManager, StructuredDataset, logger from flytekit.core.type_engine import TypeEngine -from flytekit.extend.backend.base_agent import AgentBase, AgentRegistry, convert_to_flyte_phase -from flytekit.models import literals -from flytekit.models.core.execution import TaskLog +from flytekit.extend.backend.base_agent import AgentRegistry, AsyncAgentBase, Resource, ResourceMeta +from flytekit.extend.backend.utils import convert_to_flyte_phase from flytekit.models.literals import LiteralMap from flytekit.models.task import TaskTemplate -from flytekit.models.types import LiteralType, StructuredDatasetType pythonTypeToBigQueryType: Dict[type, str] = { # https://cloud.google.com/bigquery/docs/reference/standard-sql/data-types#data_type_sizes @@ -34,25 +25,24 @@ @dataclass -class Metadata: +class BigQueryMetadata(ResourceMeta): job_id: str project: str location: str -class BigQueryAgent(AgentBase): +class BigQueryAgent(AsyncAgentBase[BigQueryMetadata]): name = "Bigquery Agent" def __init__(self): - super().__init__(task_type="bigquery_query_job_task") + super().__init__(task_type_name="bigquery_query_job_task", metadata_type=BigQueryMetadata) def create( self, - output_prefix: str, task_template: TaskTemplate, inputs: Optional[LiteralMap] = None, **kwargs, - ) -> CreateTaskResponse: + ) -> BigQueryMetadata: job_config = None if inputs: ctx = FlyteContextManager.current_context() @@ -73,54 +63,36 @@ def create( location = custom["Location"] client = bigquery.Client(project=project, location=location) query_job = client.query(task_template.sql.statement, job_config=job_config) - metadata = Metadata(job_id=str(query_job.job_id), location=location, project=project) - return CreateTaskResponse(resource_meta=json.dumps(asdict(metadata)).encode("utf-8")) + return BigQueryMetadata(job_id=str(query_job.job_id), location=location, project=project) - def get(self, resource_meta: bytes, **kwargs) -> GetTaskResponse: + def get(self, resource_meta: BigQueryMetadata, **kwargs) -> Resource: client = bigquery.Client() - metadata = Metadata(**json.loads(resource_meta.decode("utf-8"))) - log_links = [ - TaskLog( - uri=f"https://console.cloud.google.com/bigquery?project={metadata.project}&j=bq:{metadata.location}:{metadata.job_id}&page=queryresults", - name="BigQuery Console", - ).to_flyte_idl() - ] - - job = client.get_job(metadata.job_id, metadata.project, metadata.location) + log_link = TaskLog( + uri=f"https://console.cloud.google.com/bigquery?project={resource_meta.project}&j=bq:{resource_meta.location}:{resource_meta.job_id}&page=queryresults", + name="BigQuery Console", + ) + + job = client.get_job(resource_meta.job_id, resource_meta.project, resource_meta.location) if job.errors: logger.error("failed to run BigQuery job with error:", job.errors.__str__()) - return GetTaskResponse( - resource=Resource(state=TaskExecution.FAILED, message=job.errors.__str__()), log_links=log_links - ) + return Resource(phase=TaskExecution.FAILED, message=job.errors.__str__(), log_links=[log_link]) cur_phase = convert_to_flyte_phase(str(job.state)) res = None if cur_phase == TaskExecution.SUCCEEDED: - ctx = FlyteContextManager.current_context() - if job.destination: - output_location = ( - f"bq://{job.destination.project}:{job.destination.dataset_id}.{job.destination.table_id}" - ) - res = literals.LiteralMap( - { - "results": TypeEngine.to_literal( - ctx, - StructuredDataset(uri=output_location), - StructuredDataset, - LiteralType(structured_dataset_type=StructuredDatasetType(format="")), - ) - } - ).to_flyte_idl() - - return GetTaskResponse(resource=Resource(phase=cur_phase, outputs=res), log_links=log_links) - - def delete(self, resource_meta: bytes, **kwargs) -> DeleteTaskResponse: + dst = job.destination + if dst: + ctx = FlyteContextManager.current_context() + output_location = f"bq://{dst.project}:{dst.dataset_id}.{dst.table_id}" + res = TypeEngine.dict_to_literal_map(ctx, {"results": StructuredDataset(uri=output_location)}) + + return Resource(phase=cur_phase, message=job.state, log_links=[log_link], outputs=res) + + def delete(self, resource_meta: BigQueryMetadata, **kwargs): client = bigquery.Client() - metadata = Metadata(**json.loads(resource_meta.decode("utf-8"))) - client.cancel_job(metadata.job_id, metadata.project, metadata.location) - return DeleteTaskResponse() + client.cancel_job(resource_meta.job_id, resource_meta.project, resource_meta.location) AgentRegistry.register(BigQueryAgent()) diff --git a/plugins/flytekit-bigquery/tests/test_agent.py b/plugins/flytekit-bigquery/tests/test_agent.py index dc2af4ab80..5897b4b468 100644 --- a/plugins/flytekit-bigquery/tests/test_agent.py +++ b/plugins/flytekit-bigquery/tests/test_agent.py @@ -1,10 +1,8 @@ -import json -from dataclasses import asdict from datetime import timedelta from unittest import mock from flyteidl.core.execution_pb2 import TaskExecution -from flytekitplugins.bigquery.agent import Metadata +from flytekitplugins.bigquery.agent import BigQueryMetadata import flytekit.models.interface as interface_models from flytekit.extend.backend.base_agent import AgentRegistry @@ -86,20 +84,18 @@ def __init__(self): sql=Sql("SELECT 1"), ) - metadata_bytes = json.dumps( - asdict(Metadata(job_id="dummy_id", project="dummy_project", location="us-central1")) - ).encode("utf-8") - assert agent.create("/tmp", dummy_template, task_inputs).resource_meta == metadata_bytes - res = agent.get(metadata_bytes) - assert res.resource.phase == TaskExecution.SUCCEEDED + metadata = BigQueryMetadata(job_id="dummy_id", project="dummy_project", location="us-central1") + assert agent.create(dummy_template, task_inputs) == metadata + resource = agent.get(metadata) + assert resource.phase == TaskExecution.SUCCEEDED assert ( - res.resource.outputs.literals["results"].scalar.structured_dataset.uri + resource.outputs.literals["results"].scalar.structured_dataset.uri == "bq://dummy_project:dummy_dataset.dummy_table" ) - assert res.log_links[0].name == "BigQuery Console" + assert resource.log_links[0].name == "BigQuery Console" assert ( - res.log_links[0].uri + resource.log_links[0].uri == "https://console.cloud.google.com/bigquery?project=dummy_project&j=bq:us-central1:dummy_id&page=queryresults" ) - agent.delete(metadata_bytes) + agent.delete(metadata) mock_instance.cancel_job.assert_called() diff --git a/plugins/flytekit-mmcloud/flytekitplugins/mmcloud/agent.py b/plugins/flytekit-mmcloud/flytekitplugins/mmcloud/agent.py index 285be4e88b..e0dbceada2 100644 --- a/plugins/flytekit-mmcloud/flytekitplugins/mmcloud/agent.py +++ b/plugins/flytekit-mmcloud/flytekitplugins/mmcloud/agent.py @@ -1,30 +1,29 @@ import json import shlex import subprocess -from dataclasses import asdict, dataclass +from dataclasses import dataclass from tempfile import NamedTemporaryFile from typing import Optional -from flyteidl.admin.agent_pb2 import CreateTaskResponse, DeleteTaskResponse, GetTaskResponse, Resource from flytekitplugins.mmcloud.utils import async_check_output, mmcloud_status_to_flyte_phase from flytekit import current_context -from flytekit.extend.backend.base_agent import AgentBase, AgentRegistry +from flytekit.extend.backend.base_agent import AgentRegistry, AsyncAgentBase, Resource, ResourceMeta from flytekit.loggers import logger from flytekit.models.literals import LiteralMap from flytekit.models.task import TaskTemplate @dataclass -class Metadata: +class MMCloudMetadata(ResourceMeta): job_id: str -class MMCloudAgent(AgentBase): +class MMCloudAgent(AsyncAgentBase): name = "MMCloud Agent" def __init__(self): - super().__init__(task_type="mmcloud_task", asynchronous=True) + super().__init__(task_type_name="mmcloud_task", metadata_type=MMCloudMetadata) self._response_format = ["--format", "json"] async def async_login(self): @@ -57,10 +56,10 @@ async def async_login(self): logger.info("Logged in to OpCenter") async def create( - self, output_prefix: str, task_template: TaskTemplate, inputs: Optional[LiteralMap] = None, **kwargs - ) -> CreateTaskResponse: + self, task_template: TaskTemplate, inputs: Optional[LiteralMap] = None, **kwargs + ) -> MMCloudMetadata: """ - Submit Flyte task as MMCloud job to the OpCenter, and return the job UID for the task. + Submit a Flyte task as MMCloud job to the OpCenter, and return the job UID for the task. """ submit_command = [ "float", @@ -128,16 +127,13 @@ async def create( logger.exception("Cannot open job script for writing") raise - metadata = Metadata(job_id=job_id) + return MMCloudMetadata(job_id=job_id) - return CreateTaskResponse(resource_meta=json.dumps(asdict(metadata)).encode("utf-8")) - - async def async_get(self, resource_meta: bytes, **kwargs) -> GetTaskResponse: + async def get(self, resource_meta: MMCloudMetadata, **kwargs) -> Resource: """ Return the status of the task, and return the outputs on success. """ - metadata = Metadata(**json.loads(resource_meta.decode("utf-8"))) - job_id = metadata.job_id + job_id = resource_meta.job_id show_command = [ "float", @@ -173,14 +169,13 @@ async def async_get(self, resource_meta: bytes, **kwargs) -> GetTaskResponse: logger.info(f"Obtained status for MMCloud job {job_id}: {job_status}") logger.debug(f"OpCenter response: {show_response}") - return GetTaskResponse(resource=Resource(phase=task_phase)) + return Resource(phase=task_phase) - async def async_delete(self, resource_meta: bytes, **kwargs) -> DeleteTaskResponse: + async def delete(self, resource_meta: MMCloudMetadata, **kwargs): """ Delete the task. This call should be idempotent. """ - metadata = Metadata(**json.loads(resource_meta.decode("utf-8"))) - job_id = metadata.job_id + job_id = resource_meta.job_id cancel_command = [ "float", @@ -203,7 +198,5 @@ async def async_delete(self, resource_meta: bytes, **kwargs) -> DeleteTaskRespon logger.info(f"Submitted cancel request for MMCloud job: {job_id}") - return DeleteTaskResponse() - AgentRegistry.register(MMCloudAgent()) diff --git a/plugins/flytekit-openai/README.md b/plugins/flytekit-openai/README.md new file mode 100644 index 0000000000..21c5553ce7 --- /dev/null +++ b/plugins/flytekit-openai/README.md @@ -0,0 +1,44 @@ +# Flytekit ChatGPT Plugin +ChatGPT plugin allows you to run ChatGPT tasks in the Flyte workflow without changing any code. + +## Example +```python +from flytekit import task, workflow +from flytekitplugins.chatgpt import ChatGPTTask, ChatGPTConfig + +chatgpt_small_job = ChatGPTTask( + name="chatgpt gpt-3.5-turbo", + openai_organization="org-NayNG68kGnVXMJ8Ak4PMgQv7", + chatgpt_config={ + "model": "gpt-3.5-turbo", + "temperature": 0.7, + }, +) + +chatgpt_big_job = ChatGPTTask( + name="chatgpt gpt-4", + openai_organization="org-NayNG68kGnVXMJ8Ak4PMgQv7", + chatgpt_config={ + "model": "gpt-4", + "temperature": 0.7, + }, +) + + +@workflow +def wf(message: str) -> str: + message = chatgpt_small_job(message=message) + message = chatgpt_big_job(message=message) + return message + + +if __name__ == "__main__": + print(wf(message="hi")) +``` + + +To install the plugin, run the following command: + +```bash +pip install flytekitplugins-chatgpt +``` diff --git a/plugins/flytekit-openai/dev-requirements.in b/plugins/flytekit-openai/dev-requirements.in new file mode 100644 index 0000000000..2d73dba5b4 --- /dev/null +++ b/plugins/flytekit-openai/dev-requirements.in @@ -0,0 +1 @@ +pytest-asyncio diff --git a/plugins/flytekit-openai/dev-requirements.txt b/plugins/flytekit-openai/dev-requirements.txt new file mode 100644 index 0000000000..6ee1144cb8 --- /dev/null +++ b/plugins/flytekit-openai/dev-requirements.txt @@ -0,0 +1,20 @@ +# +# This file is autogenerated by pip-compile with Python 3.9 +# by the following command: +# +# pip-compile dev-requirements.in +# +exceptiongroup==1.2.0 + # via pytest +iniconfig==2.0.0 + # via pytest +packaging==23.2 + # via pytest +pluggy==1.3.0 + # via pytest +pytest==7.4.4 + # via pytest-asyncio +pytest-asyncio==0.23.3 + # via -r dev-requirements.in +tomli==2.0.1 + # via pytest diff --git a/plugins/flytekit-openai/flytekitplugins/chatgpt/__init__.py b/plugins/flytekit-openai/flytekitplugins/chatgpt/__init__.py new file mode 100644 index 0000000000..1de5b544d1 --- /dev/null +++ b/plugins/flytekit-openai/flytekitplugins/chatgpt/__init__.py @@ -0,0 +1,15 @@ +""" +.. currentmodule:: flytekitplugins.chatgpt + +This package contains things that are useful when extending Flytekit. + +.. autosummary:: + :template: custom.rst + :toctree: generated/ + + ChatGPTAgent + ChatGPTTask +""" + +from .agent import ChatGPTAgent +from .task import ChatGPTTask diff --git a/plugins/flytekit-openai/flytekitplugins/chatgpt/agent.py b/plugins/flytekit-openai/flytekitplugins/chatgpt/agent.py new file mode 100644 index 0000000000..cb3e8c46c3 --- /dev/null +++ b/plugins/flytekit-openai/flytekitplugins/chatgpt/agent.py @@ -0,0 +1,57 @@ +import asyncio +import logging +from typing import Optional + +from flyteidl.core.execution_pb2 import TaskExecution + +from flytekit import FlyteContextManager, lazy_module +from flytekit.core.type_engine import TypeEngine +from flytekit.extend.backend.base_agent import AgentRegistry, Resource, SyncAgentBase +from flytekit.extend.backend.utils import get_agent_secret +from flytekit.models.literals import LiteralMap +from flytekit.models.task import TaskTemplate + +openai = lazy_module("openai") + +TIMEOUT_SECONDS = 10 +OPENAI_API_KEY = "FLYTE_OPENAI_API_KEY" + + +class ChatGPTAgent(SyncAgentBase): + name = "ChatGPT Agent" + + def __init__(self): + super().__init__(task_type_name="chatgpt") + + async def do( + self, + task_template: TaskTemplate, + inputs: Optional[LiteralMap] = None, + ) -> Resource: + ctx = FlyteContextManager.current_context() + input_python_value = TypeEngine.literal_map_to_kwargs(ctx, inputs, {"message": str}) + message = input_python_value["message"] + + custom = task_template.custom + custom["chatgpt_config"]["messages"] = [{"role": "user", "content": message}] + client = openai.AsyncOpenAI( + organization=custom["openai_organization"], + api_key=get_agent_secret(secret_key=OPENAI_API_KEY), + ) + + logger = logging.getLogger("httpx") + logger.setLevel(logging.WARNING) + + try: + completion = await asyncio.wait_for( + client.chat.completions.create(**custom["chatgpt_config"]), TIMEOUT_SECONDS + ) + message = completion.choices[0].message.content + outputs = {"o0": message} + + return Resource(phase=TaskExecution.SUCCEEDED, outputs=outputs) + except Exception as error_message: + return Resource(phase=TaskExecution.FAILED, message=str(error_message)) + + +AgentRegistry.register(ChatGPTAgent()) diff --git a/plugins/flytekit-openai/flytekitplugins/chatgpt/task.py b/plugins/flytekit-openai/flytekitplugins/chatgpt/task.py new file mode 100644 index 0000000000..c37a40650d --- /dev/null +++ b/plugins/flytekit-openai/flytekitplugins/chatgpt/task.py @@ -0,0 +1,44 @@ +from typing import Any, Dict + +from flytekit.configuration import SerializationSettings +from flytekit.core.base_task import PythonTask +from flytekit.core.interface import Interface +from flytekit.extend.backend.base_agent import SyncAgentExecutorMixin + + +class ChatGPTTask(SyncAgentExecutorMixin, PythonTask): + """ + This is the simplest form of a ChatGPT Task, you can define the model and the input you want. + """ + + _TASK_TYPE = "chatgpt" + + def __init__(self, name: str, openai_organization: str, chatgpt_config: Dict[str, Any], **kwargs): + """ + Args: + name: Name of this task, should be unique in the project + openai_organization: OpenAI Organization. String can be found here. https://platform.openai.com/docs/api-reference/organization-optional + chatgpt_config: ChatGPT job configuration. Config structure can be found here. https://platform.openai.com/docs/api-reference/completions/create + """ + + if "model" not in chatgpt_config: + raise ValueError("The 'model' configuration variable is required in chatgpt_config") + + task_config = {"openai_organization": openai_organization, "chatgpt_config": chatgpt_config} + + inputs = {"message": str} + outputs = {"o0": str} + + super().__init__( + task_type=self._TASK_TYPE, + name=name, + task_config=task_config, + interface=Interface(inputs=inputs, outputs=outputs), + **kwargs, + ) + + def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]: + return { + "openai_organization": self.task_config["openai_organization"], + "chatgpt_config": self.task_config["chatgpt_config"], + } diff --git a/plugins/flytekit-openai/setup.py b/plugins/flytekit-openai/setup.py new file mode 100644 index 0000000000..53b3e1e212 --- /dev/null +++ b/plugins/flytekit-openai/setup.py @@ -0,0 +1,37 @@ +from setuptools import setup + +PLUGIN_NAME = "chatgpt" + +microlib_name = f"flytekitplugins-{PLUGIN_NAME}" + +plugin_requires = ["flytekit>=1.10.7", "openai>=1.12.0", "flyteidl>=1.10.7"] + +__version__ = "0.0.0+develop" + +setup( + name=microlib_name, + version=__version__, + author="flyteorg", + author_email="admin@flyte.org", + description="This package holds the ChatGPT plugins for flytekit", + namespace_packages=["flytekitplugins"], + packages=[f"flytekitplugins.{PLUGIN_NAME}"], + install_requires=plugin_requires, + license="apache2", + python_requires=">=3.8", + classifiers=[ + "Intended Audience :: Science/Research", + "Intended Audience :: Developers", + "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Topic :: Scientific/Engineering", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Software Development", + "Topic :: Software Development :: Libraries", + "Topic :: Software Development :: Libraries :: Python Modules", + ], + entry_points={"flytekit.plugins": [f"{PLUGIN_NAME}=flytekitplugins.{PLUGIN_NAME}"]}, +) diff --git a/plugins/flytekit-openai/tests/test_agent.py b/plugins/flytekit-openai/tests/test_agent.py new file mode 100644 index 0000000000..dd340bd1a7 --- /dev/null +++ b/plugins/flytekit-openai/tests/test_agent.py @@ -0,0 +1,69 @@ +from datetime import timedelta +from unittest import mock + +import pytest +from flyteidl.core.execution_pb2 import TaskExecution + +from flytekit.extend.backend.base_agent import AgentRegistry +from flytekit.interfaces.cli_identifiers import Identifier +from flytekit.models import literals +from flytekit.models.core.identifier import ResourceType +from flytekit.models.literals import LiteralMap +from flytekit.models.task import RuntimeMetadata, TaskMetadata, TaskTemplate + + +async def mock_acreate(*args, **kwargs) -> str: + mock_response = mock.MagicMock() + mock_choice = mock.MagicMock() + mock_choice.message.content = "mocked_message" + mock_response.choices = [mock_choice] + return mock_response + + +@pytest.mark.asyncio +async def test_chatgpt_agent(): + agent = AgentRegistry.get_agent("chatgpt") + task_id = Identifier( + resource_type=ResourceType.TASK, project="project", domain="domain", name="name", version="version" + ) + task_config = { + "openai_organization": "test-openai-orgnization-id", + "chatgpt_config": {"model": "gpt-3.5-turbo", "temperature": 0.7}, + } + task_metadata = TaskMetadata( + True, + RuntimeMetadata(RuntimeMetadata.RuntimeType.FLYTE_SDK, "1.0.0", "python"), + timedelta(days=1), + literals.RetryStrategy(3), + True, + "0.1.1b0", + "This is deprecated!", + True, + "A", + ) + tmp = TaskTemplate( + id=task_id, + custom=task_config, + metadata=task_metadata, + interface=None, + type="chatgpt", + ) + + task_inputs = LiteralMap( + { + "message": literals.Literal( + scalar=literals.Scalar(primitive=literals.Primitive(string_value="Test ChatGPT Plugin")) + ), + }, + ) + message = "mocked_message" + mocked_token = "mocked_openai_api_key" + mocked_context = mock.patch("flytekit.current_context", autospec=True).start() + mocked_context.return_value.secrets.get.return_value = mocked_token + + with mock.patch("openai.resources.chat.completions.AsyncCompletions.create", new=mock_acreate): + # Directly await the coroutine without using asyncio.run + response = await agent.do(tmp, task_inputs) + + assert response.phase == TaskExecution.SUCCEEDED + assert response.outputs == {"o0": message} diff --git a/plugins/flytekit-openai/tests/test_chatgpt.py b/plugins/flytekit-openai/tests/test_chatgpt.py new file mode 100644 index 0000000000..f85f94cc7b --- /dev/null +++ b/plugins/flytekit-openai/tests/test_chatgpt.py @@ -0,0 +1,42 @@ +from collections import OrderedDict + +from flytekitplugins.chatgpt import ChatGPTTask + +from flytekit.configuration import Image, ImageConfig, SerializationSettings +from flytekit.extend import get_serializable +from flytekit.models.types import SimpleType + + +def test_chatgpt_task(): + chatgpt_task = ChatGPTTask( + name="chatgpt", + openai_organization="TEST ORGANIZATION ID", + chatgpt_config={ + "model": "gpt-3.5-turbo", + "temperature": 0.7, + }, + ) + + assert len(chatgpt_task.interface.inputs) == 1 + assert len(chatgpt_task.interface.outputs) == 1 + + default_img = Image(name="default", fqn="test", tag="tag") + serialization_settings = SerializationSettings( + project="proj", + domain="dom", + version="123", + image_config=ImageConfig(default_image=default_img, images=[default_img]), + env={}, + ) + + chatgpt_task_spec = get_serializable(OrderedDict(), serialization_settings, chatgpt_task) + custom = chatgpt_task_spec.template.custom + assert custom["openai_organization"] == "TEST ORGANIZATION ID" + assert custom["chatgpt_config"]["model"] == "gpt-3.5-turbo" + assert custom["chatgpt_config"]["temperature"] == 0.7 + + assert len(chatgpt_task_spec.template.interface.inputs) == 1 + assert len(chatgpt_task_spec.template.interface.outputs) == 1 + + assert chatgpt_task_spec.template.interface.inputs["message"].type.simple == SimpleType.STRING + assert chatgpt_task_spec.template.interface.outputs["o0"].type.simple == SimpleType.STRING diff --git a/plugins/flytekit-snowflake/flytekitplugins/snowflake/agent.py b/plugins/flytekit-snowflake/flytekitplugins/snowflake/agent.py index d06bc68085..f350661a03 100644 --- a/plugins/flytekit-snowflake/flytekitplugins/snowflake/agent.py +++ b/plugins/flytekit-snowflake/flytekitplugins/snowflake/agent.py @@ -1,18 +1,12 @@ -import json -from dataclasses import asdict, dataclass +from dataclasses import dataclass from typing import Optional -from flyteidl.admin.agent_pb2 import ( - CreateTaskResponse, - DeleteTaskResponse, - GetTaskResponse, - Resource, -) from flyteidl.core.execution_pb2 import TaskExecution from flytekit import FlyteContextManager, StructuredDataset, lazy_module, logger from flytekit.core.type_engine import TypeEngine -from flytekit.extend.backend.base_agent import AgentBase, AgentRegistry, convert_to_flyte_phase +from flytekit.extend.backend.base_agent import AgentRegistry, AsyncAgentBase, Resource, ResourceMeta +from flytekit.extend.backend.utils import convert_to_flyte_phase from flytekit.models import literals from flytekit.models.literals import LiteralMap from flytekit.models.task import TaskTemplate @@ -25,7 +19,7 @@ @dataclass -class Metadata: +class SnowflakeJobMetadata(ResourceMeta): user: str account: str database: str @@ -53,7 +47,7 @@ def get_private_key(): return pkb -def get_connection(metadata: Metadata) -> snowflake_connector: +def get_connection(metadata: SnowflakeJobMetadata) -> snowflake_connector: return snowflake_connector.connect( user=metadata.user, account=metadata.account, @@ -64,13 +58,13 @@ def get_connection(metadata: Metadata) -> snowflake_connector: ) -class SnowflakeAgent(AgentBase): +class SnowflakeAgent(AsyncAgentBase): def __init__(self): - super().__init__(task_type=TASK_TYPE) + super().__init__(task_type_name=TASK_TYPE, metadata_type=SnowflakeJobMetadata) async def create( - self, output_prefix: str, task_template: TaskTemplate, inputs: Optional[LiteralMap] = None, **kwargs - ) -> CreateTaskResponse: + self, task_template: TaskTemplate, inputs: Optional[LiteralMap] = None, **kwargs + ) -> SnowflakeJobMetadata: params = None if inputs: ctx = FlyteContextManager.current_context() @@ -95,7 +89,7 @@ async def create( cs = conn.cursor() cs.execute_async(task_template.sql.statement, params=params) - metadata = Metadata( + return SnowflakeJobMetadata( user=config["user"], account=config["account"], database=config["database"], @@ -105,22 +99,19 @@ async def create( query_id=str(cs.sfqid), ) - return CreateTaskResponse(resource_meta=json.dumps(asdict(metadata)).encode("utf-8")) - - async def get(self, resource_meta: bytes, **kwargs) -> GetTaskResponse: - metadata = Metadata(**json.loads(resource_meta.decode("utf-8"))) - conn = get_connection(metadata) + async def get(self, resource_meta: SnowflakeJobMetadata, **kwargs) -> Resource: + conn = get_connection(resource_meta) try: - query_status = conn.get_query_status_throw_if_error(metadata.query_id) + query_status = conn.get_query_status_throw_if_error(resource_meta.query_id) except snowflake_connector.ProgrammingError as err: logger.error("Failed to get snowflake job status with error:", err.msg) - return GetTaskResponse(resource=Resource(state=TaskExecution.FAILED)) + return Resource(phase=TaskExecution.FAILED) cur_phase = convert_to_flyte_phase(str(query_status.name)) res = None if cur_phase == TaskExecution.SUCCEEDED: ctx = FlyteContextManager.current_context() - output_metadata = f"snowflake://{metadata.user}:{metadata.account}/{metadata.warehouse}/{metadata.database}/{metadata.schema}/{metadata.table}" + output_metadata = f"snowflake://{resource_meta.user}:{resource_meta.account}/{resource_meta.warehouse}/{resource_meta.database}/{resource_meta.schema}/{resource_meta.table}" res = literals.LiteralMap( { "results": TypeEngine.to_literal( @@ -132,19 +123,17 @@ async def get(self, resource_meta: bytes, **kwargs) -> GetTaskResponse: } ).to_flyte_idl() - return GetTaskResponse(resource=Resource(phase=cur_phase, outputs=res)) + return Resource(phase=cur_phase, outputs=res) - async def delete(self, resource_meta: bytes, **kwargs) -> DeleteTaskResponse: - metadata = Metadata(**json.loads(resource_meta.decode("utf-8"))) - conn = get_connection(metadata) + async def delete(self, resource_meta: SnowflakeJobMetadata, **kwargs): + conn = get_connection(resource_meta) cs = conn.cursor() try: - cs.execute(f"SELECT SYSTEM$CANCEL_QUERY('{metadata.query_id}')") + cs.execute(f"SELECT SYSTEM$CANCEL_QUERY('{resource_meta.query_id}')") cs.fetchall() finally: cs.close() conn.close() - return DeleteTaskResponse() AgentRegistry.register(SnowflakeAgent()) diff --git a/plugins/flytekit-snowflake/tests/test_agent.py b/plugins/flytekit-snowflake/tests/test_agent.py index 017297704e..f3dcb0686d 100644 --- a/plugins/flytekit-snowflake/tests/test_agent.py +++ b/plugins/flytekit-snowflake/tests/test_agent.py @@ -1,13 +1,10 @@ -import json -from dataclasses import asdict from datetime import timedelta from unittest import mock from unittest.mock import MagicMock import pytest -from flyteidl.admin.agent_pb2 import DeleteTaskResponse from flyteidl.core.execution_pb2 import TaskExecution -from flytekitplugins.snowflake.agent import Metadata +from flytekitplugins.snowflake.agent import SnowflakeJobMetadata import flytekit.models.interface as interface_models from flytekit import lazy_module @@ -30,8 +27,11 @@ async def test_snowflake_agent(mock_get_private_key): mock_conn_instance = snowflake_connector.connect.return_value mock_conn_instance.get_query_status_throw_if_error.return_value = query_status_mock - agent = AgentRegistry.get_agent("snowflake") + mock_cursor = MagicMock() + mock_cursor.sfqid = "dummy_id" + mock_conn_instance.cursor.return_value = mock_cursor + agent = AgentRegistry.get_agent("snowflake") task_id = Identifier( resource_type=ResourceType.TASK, project="project", domain="domain", name="name", version="version" ) @@ -82,32 +82,28 @@ async def test_snowflake_agent(mock_get_private_key): sql=Sql("SELECT 1"), ) - metadata = Metadata( + snowflake_metadata = SnowflakeJobMetadata( user="dummy_user", account="dummy_account", table="dummy_table", database="dummy_database", schema="dummy_schema", warehouse="dummy_warehouse", - query_id="dummy_query_id", + query_id="dummy_id", ) - res = await agent.create("/tmp", dummy_template, task_inputs) - metadata.query_id = Metadata(**json.loads(res.resource_meta.decode("utf-8"))).query_id - metadata_bytes = json.dumps(asdict(metadata)).encode("utf-8") - assert res.resource_meta == metadata_bytes + metadata = await agent.create(dummy_template, task_inputs) + assert metadata == snowflake_metadata - res = await agent.get(metadata_bytes) - assert res.resource.phase == TaskExecution.SUCCEEDED + resource = await agent.get(metadata) + assert resource.phase == TaskExecution.SUCCEEDED assert ( - res.resource.outputs.literals["results"].scalar.structured_dataset.uri + resource.outputs.literals["results"].scalar.structured_dataset.uri == "snowflake://dummy_user:dummy_account/dummy_warehouse/dummy_database/dummy_schema/dummy_table" ) - delete_response = await agent.delete(metadata_bytes) - - # Assert the response - assert isinstance(delete_response, DeleteTaskResponse) + delete_response = await agent.delete(snowflake_metadata) + assert delete_response is None # Verify that the expected methods were called on the mock cursor mock_cursor = mock_conn_instance.cursor.return_value diff --git a/plugins/flytekit-spark/flytekitplugins/spark/agent.py b/plugins/flytekit-spark/flytekitplugins/spark/agent.py index 2fe442182a..8200263ac3 100644 --- a/plugins/flytekit-spark/flytekitplugins/spark/agent.py +++ b/plugins/flytekit-spark/flytekitplugins/spark/agent.py @@ -1,15 +1,14 @@ import http import json -import pickle import typing from dataclasses import dataclass from typing import Optional -from flyteidl.admin.agent_pb2 import CreateTaskResponse, DeleteTaskResponse, GetTaskResponse, Resource from flyteidl.core.execution_pb2 import TaskExecution from flytekit import lazy_module -from flytekit.extend.backend.base_agent import AgentBase, AgentRegistry, convert_to_flyte_phase, get_agent_secret +from flytekit.extend.backend.base_agent import AgentRegistry, AsyncAgentBase, Resource, ResourceMeta +from flytekit.extend.backend.utils import convert_to_flyte_phase, get_agent_secret from flytekit.models.core.execution import TaskLog from flytekit.models.literals import LiteralMap from flytekit.models.task import TaskTemplate @@ -20,24 +19,20 @@ @dataclass -class Metadata: +class DatabricksJobMetadata(ResourceMeta): databricks_instance: str run_id: str -class DatabricksAgent(AgentBase): +class DatabricksAgent(AsyncAgentBase): name = "Databricks Agent" def __init__(self): - super().__init__(task_type="spark", asynchronous=True) + super().__init__(task_type_name="spark", metadata_type=DatabricksJobMetadata) async def create( - self, - output_prefix: str, - task_template: TaskTemplate, - inputs: Optional[LiteralMap] = None, - **kwargs, - ) -> CreateTaskResponse: + self, task_template: TaskTemplate, inputs: Optional[LiteralMap] = None, **kwargs + ) -> DatabricksJobMetadata: custom = task_template.custom container = task_template.container databricks_job = custom["databricksConf"] @@ -72,21 +67,18 @@ async def create( if resp.status != http.HTTPStatus.OK: raise Exception(f"Failed to create databricks job with error: {response}") - metadata = Metadata( - databricks_instance=databricks_instance, - run_id=str(response["run_id"]), - ) - return CreateTaskResponse(resource_meta=pickle.dumps(metadata)) + return DatabricksJobMetadata(databricks_instance=databricks_instance, run_id=str(response["run_id"])) - async def get(self, resource_meta: bytes, **kwargs) -> GetTaskResponse: - metadata = pickle.loads(resource_meta) - databricks_instance = metadata.databricks_instance - databricks_url = f"https://{databricks_instance}{DATABRICKS_API_ENDPOINT}/runs/get?run_id={metadata.run_id}" + async def get(self, resource_meta: DatabricksJobMetadata, **kwargs) -> Resource: + databricks_instance = resource_meta.databricks_instance + databricks_url = ( + f"https://{databricks_instance}{DATABRICKS_API_ENDPOINT}/runs/get?run_id={resource_meta.run_id}" + ) async with aiohttp.ClientSession() as session: async with session.get(databricks_url, headers=get_header()) as resp: if resp.status != http.HTTPStatus.OK: - raise Exception(f"Failed to get databricks job {metadata.run_id} with error: {resp.reason}") + raise Exception(f"Failed to get databricks job {resource_meta.run_id} with error: {resp.reason}") response = await resp.json() cur_phase = TaskExecution.RUNNING @@ -99,25 +91,21 @@ async def get(self, resource_meta: bytes, **kwargs) -> GetTaskResponse: message = state["state_message"] job_id = response.get("job_id") - databricks_console_url = f"https://{databricks_instance}/#job/{job_id}/run/{metadata.run_id}" + databricks_console_url = f"https://{databricks_instance}/#job/{job_id}/run/{resource_meta.run_id}" log_links = [TaskLog(uri=databricks_console_url, name="Databricks Console").to_flyte_idl()] - return GetTaskResponse(resource=Resource(phase=cur_phase, message=message), log_links=log_links) + return Resource(phase=cur_phase, message=message, log_links=log_links) - async def delete(self, resource_meta: bytes, **kwargs) -> DeleteTaskResponse: - metadata = pickle.loads(resource_meta) - - databricks_url = f"https://{metadata.databricks_instance}{DATABRICKS_API_ENDPOINT}/runs/cancel" - data = json.dumps({"run_id": metadata.run_id}) + async def delete(self, resource_meta: DatabricksJobMetadata, **kwargs): + databricks_url = f"https://{resource_meta.databricks_instance}{DATABRICKS_API_ENDPOINT}/runs/cancel" + data = json.dumps({"run_id": resource_meta.run_id}) async with aiohttp.ClientSession() as session: async with session.post(databricks_url, headers=get_header(), data=data) as resp: if resp.status != http.HTTPStatus.OK: - raise Exception(f"Failed to cancel databricks job {metadata.run_id} with error: {resp.reason}") + raise Exception(f"Failed to cancel databricks job {resource_meta.run_id} with error: {resp.reason}") await resp.json() - return DeleteTaskResponse() - def get_header() -> typing.Dict[str, str]: token = get_agent_secret("FLYTE_DATABRICKS_ACCESS_TOKEN") diff --git a/plugins/flytekit-spark/tests/test_agent.py b/plugins/flytekit-spark/tests/test_agent.py index 5d19b3402f..80f91c5c76 100644 --- a/plugins/flytekit-spark/tests/test_agent.py +++ b/plugins/flytekit-spark/tests/test_agent.py @@ -1,12 +1,11 @@ import http -import pickle from datetime import timedelta from unittest import mock import pytest from aioresponses import aioresponses from flyteidl.core.execution_pb2 import TaskExecution -from flytekitplugins.spark.agent import DATABRICKS_API_ENDPOINT, Metadata, get_header +from flytekitplugins.spark.agent import DATABRICKS_API_ENDPOINT, DatabricksJobMetadata, get_header from flytekit.extend.backend.base_agent import AgentRegistry from flytekit.interfaces.cli_identifiers import Identifier @@ -103,11 +102,9 @@ async def test_databricks_agent(): mocked_context = mock.patch("flytekit.current_context", autospec=True).start() mocked_context.return_value.secrets.get.return_value = mocked_token - metadata_bytes = pickle.dumps( - Metadata( - databricks_instance="test-account.cloud.databricks.com", - run_id="123", - ) + databricks_metadata = DatabricksJobMetadata( + databricks_instance="test-account.cloud.databricks.com", + run_id="123", ) mock_create_response = {"run_id": "123"} @@ -118,19 +115,19 @@ async def test_databricks_agent(): delete_url = f"https://test-account.cloud.databricks.com{DATABRICKS_API_ENDPOINT}/runs/cancel" with aioresponses() as mocked: mocked.post(create_url, status=http.HTTPStatus.OK, payload=mock_create_response) - res = await agent.create("/tmp", dummy_template, None) - assert res.resource_meta == metadata_bytes + res = await agent.create(dummy_template, None) + assert res == databricks_metadata mocked.get(get_url, status=http.HTTPStatus.OK, payload=mock_get_response) - res = await agent.get(metadata_bytes) - assert res.resource.phase == TaskExecution.SUCCEEDED - assert res.resource.outputs == literals.LiteralMap({}).to_flyte_idl() - assert res.resource.message == "OK" - assert res.log_links[0].name == "Databricks Console" - assert res.log_links[0].uri == "https://test-account.cloud.databricks.com/#job/1/run/123" + resource = await agent.get(databricks_metadata) + assert resource.phase == TaskExecution.SUCCEEDED + assert resource.outputs is None + assert resource.message == "OK" + assert resource.log_links[0].name == "Databricks Console" + assert resource.log_links[0].uri == "https://test-account.cloud.databricks.com/#job/1/run/123" mocked.post(delete_url, status=http.HTTPStatus.OK, payload=mock_delete_response) - await agent.delete(metadata_bytes) + await agent.delete(databricks_metadata) assert get_header() == {"Authorization": f"Bearer {mocked_token}", "content-type": "application/json"} diff --git a/tests/flytekit/integration/remote/test_remote.py b/tests/flytekit/integration/remote/test_remote.py index f92489e9c4..a1c137dc48 100644 --- a/tests/flytekit/integration/remote/test_remote.py +++ b/tests/flytekit/integration/remote/test_remote.py @@ -26,9 +26,10 @@ @pytest.fixture(scope="session") def register(): - subprocess.run( + out = subprocess.run( [ "pyflyte", + "--verbose", "-c", CONFIG, "register", @@ -43,6 +44,7 @@ def register(): MODULE_PATH, ] ) + assert out.returncode == 0 def test_fetch_execute_launch_plan(register): @@ -52,7 +54,7 @@ def test_fetch_execute_launch_plan(register): assert execution.outputs["o0"] == "hello world" -def fetch_execute_launch_plan_with_args(register): +def test_fetch_execute_launch_plan_with_args(register): remote = FlyteRemote(Config.auto(config_file=CONFIG), PROJECT, DOMAIN) flyte_launch_plan = remote.fetch_launch_plan(name="basic.basic_workflow.my_wf", version=VERSION) execution = remote.execute(flyte_launch_plan, inputs={"a": 10, "b": "foobar"}, wait=True) diff --git a/tests/flytekit/unit/extend/test_agent.py b/tests/flytekit/unit/extend/test_agent.py index 68dee74b3f..bfabead499 100644 --- a/tests/flytekit/unit/extend/test_agent.py +++ b/tests/flytekit/unit/extend/test_agent.py @@ -1,111 +1,107 @@ -import asyncio -import json import typing from collections import OrderedDict -from dataclasses import asdict, dataclass +from dataclasses import dataclass from unittest.mock import MagicMock, patch import grpc import pytest from flyteidl.admin.agent_pb2 import ( + CreateRequestHeader, CreateTaskRequest, - CreateTaskResponse, DeleteTaskRequest, - DeleteTaskResponse, + ExecuteTaskSyncRequest, + GetAgentRequest, GetTaskRequest, - GetTaskResponse, ListAgentsRequest, ListAgentsResponse, - Resource, + TaskCategory, ) -from flyteidl.core.execution_pb2 import TaskExecution +from flyteidl.core.execution_pb2 import TaskExecution, TaskLog from flytekit import PythonFunctionTask, task from flytekit.configuration import FastSerializationSettings, Image, ImageConfig, SerializationSettings -from flytekit.extend.backend.agent_service import AgentMetadataService, AsyncAgentService +from flytekit.core.base_task import PythonTask, kwtypes +from flytekit.core.interface import Interface +from flytekit.extend.backend.agent_service import AgentMetadataService, AsyncAgentService, SyncAgentService from flytekit.extend.backend.base_agent import ( - AgentBase, AgentRegistry, + AsyncAgentBase, AsyncAgentExecutorMixin, - convert_to_flyte_phase, - get_agent_secret, + Resource, + ResourceMeta, + SyncAgentBase, + SyncAgentExecutorMixin, is_terminal_phase, render_task_template, ) +from flytekit.extend.backend.utils import convert_to_flyte_phase, get_agent_secret from flytekit.models import literals -from flytekit.models.core.execution import TaskLog from flytekit.models.literals import LiteralMap from flytekit.models.task import TaskTemplate from flytekit.tools.translator import get_serializable dummy_id = "dummy_id" -loop = asyncio.get_event_loop() @dataclass -class Metadata: +class DummyMetadata(ResourceMeta): job_id: str -class DummyAgent(AgentBase): +class DummyAgent(AsyncAgentBase): name = "Dummy Agent" def __init__(self): - super().__init__(task_type="dummy", asynchronous=False) + super().__init__(task_type_name="dummy", metadata_type=DummyMetadata) - def create( - self, output_prefix: str, task_template: TaskTemplate, inputs: typing.Optional[LiteralMap] = None, **kwargs - ) -> CreateTaskResponse: - return CreateTaskResponse(resource_meta=json.dumps(asdict(Metadata(job_id=dummy_id))).encode("utf-8")) + def create(self, task_template: TaskTemplate, inputs: typing.Optional[LiteralMap], **kwargs) -> DummyMetadata: + return DummyMetadata(job_id=dummy_id) - def get(self, resource_meta: bytes, **kwargs) -> GetTaskResponse: - return GetTaskResponse( - resource=Resource(phase=TaskExecution.SUCCEEDED), - log_links=[TaskLog(name="console", uri="localhost:3000").to_flyte_idl()], - ) + def get(self, resource_meta: DummyMetadata, **kwargs) -> Resource: + return Resource(phase=TaskExecution.SUCCEEDED, log_links=[TaskLog(name="console", uri="localhost:3000")]) - def delete(self, resource_meta: bytes, **kwargs) -> DeleteTaskResponse: - return DeleteTaskResponse() + def delete(self, resource_meta: DummyMetadata, **kwargs): + ... -class AsyncDummyAgent(AgentBase): +class AsyncDummyAgent(AsyncAgentBase): name = "Async Dummy Agent" def __init__(self): - super().__init__(task_type="async_dummy") + super().__init__(task_type_name="async_dummy", metadata_type=DummyMetadata) async def create( - self, - output_prefix: str, - task_template: TaskTemplate, - inputs: typing.Optional[LiteralMap] = None, - **kwargs, - ) -> CreateTaskResponse: - return CreateTaskResponse(resource_meta=json.dumps(asdict(Metadata(job_id=dummy_id))).encode("utf-8")) + self, task_template: TaskTemplate, inputs: typing.Optional[LiteralMap] = None, **kwargs + ) -> DummyMetadata: + return DummyMetadata(job_id=dummy_id) + + async def get(self, resource_meta: DummyMetadata, **kwargs) -> Resource: + return Resource(phase=TaskExecution.SUCCEEDED, log_links=[TaskLog(name="console", uri="localhost:3000")]) + + async def delete(self, resource_meta: DummyMetadata, **kwargs): + ... - async def get(self, resource_meta: bytes, **kwargs) -> GetTaskResponse: - return GetTaskResponse(resource=Resource(phase=TaskExecution.SUCCEEDED)) - async def delete(self, resource_meta: bytes, **kwargs) -> DeleteTaskResponse: - return DeleteTaskResponse() +class MockOpenAIAgent(SyncAgentBase): + name = "mock openAI Agent" + def __init__(self): + super().__init__(task_type_name="openai") + + def do(self, task_template: TaskTemplate, inputs: typing.Optional[LiteralMap] = None, **kwargs) -> Resource: + assert inputs.literals["a"].scalar.primitive.integer == 1 + return Resource(phase=TaskExecution.SUCCEEDED, outputs={"o0": 1}) -class SyncDummyAgent(AgentBase): - name = "Sync Dummy Agent" + +class MockAsyncOpenAIAgent(SyncAgentBase): + name = "mock async openAI Agent" def __init__(self): - super().__init__(task_type="sync_dummy") + super().__init__(task_type_name="async_openai") - def create( - self, - output_prefix: str, - task_template: TaskTemplate, - inputs: typing.Optional[LiteralMap] = None, - **kwargs, - ) -> CreateTaskResponse: - return CreateTaskResponse( - resource=Resource(phase=TaskExecution.SUCCEEDED, outputs=LiteralMap({}).to_flyte_idl()) - ) + async def do(self, task_template: TaskTemplate, inputs: LiteralMap = None, **kwargs) -> Resource: + assert inputs.literals["a"].scalar.primitive.integer == 1 + return Resource(phase=TaskExecution.SUCCEEDED, outputs={"o0": 1}) def get_task_template(task_type: str) -> TaskTemplate: @@ -134,115 +130,140 @@ def simple_task(i: int): ) -dummy_template = get_task_template("dummy") -async_dummy_template = get_task_template("async_dummy") -sync_dummy_template = get_task_template("sync_dummy") - - def test_dummy_agent(): - AgentRegistry.register(DummyAgent()) + AgentRegistry.register(DummyAgent(), override=True) agent = AgentRegistry.get_agent("dummy") - metadata_bytes = json.dumps(asdict(Metadata(job_id=dummy_id))).encode("utf-8") - assert agent.create("/tmp", dummy_template, task_inputs).resource_meta == metadata_bytes - res = agent.get(metadata_bytes) - assert res.resource.phase == TaskExecution.SUCCEEDED - assert res.log_links[0].name == "console" - assert res.log_links[0].uri == "localhost:3000" - assert agent.delete(metadata_bytes) == DeleteTaskResponse() + template = get_task_template("dummy") + metadata = DummyMetadata(job_id=dummy_id) + assert agent.create(template, task_inputs) == DummyMetadata(job_id=dummy_id) + resource = agent.get(metadata) + assert resource.phase == TaskExecution.SUCCEEDED + assert resource.log_links[0].name == "console" + assert resource.log_links[0].uri == "localhost:3000" + assert agent.delete(metadata) is None class DummyTask(AsyncAgentExecutorMixin, PythonFunctionTask): def __init__(self, **kwargs): - super().__init__( - task_type="dummy", - **kwargs, - ) + super().__init__(task_type="dummy", **kwargs) t = DummyTask(task_config={}, task_function=lambda: None, container_image="dummy") t.execute() t._task_type = "non-exist-type" - with pytest.raises(Exception, match="Cannot find agent for task type: non-exist-type."): + with pytest.raises(Exception, match="Cannot find agent for task category: non-exist-type."): t.execute() - agent_metadata = AgentRegistry.get_agent_metadata("Dummy Agent") - assert agent_metadata.name == "Dummy Agent" - assert agent_metadata.supported_task_types == ["dummy"] - +@pytest.mark.parametrize("agent", [DummyAgent(), AsyncDummyAgent()], ids=["sync", "async"]) @pytest.mark.asyncio -async def test_async_dummy_agent(): - AgentRegistry.register(AsyncDummyAgent()) - agent = AgentRegistry.get_agent("async_dummy") - metadata_bytes = json.dumps(asdict(Metadata(job_id=dummy_id))).encode("utf-8") - res = await agent.create("/tmp", async_dummy_template, task_inputs) +async def test_async_agent_service(agent): + AgentRegistry.register(agent, override=True) + service = AsyncAgentService() + ctx = MagicMock(spec=grpc.ServicerContext) + + inputs_proto = task_inputs.to_flyte_idl() + output_prefix = "/tmp" + metadata_bytes = DummyMetadata(job_id=dummy_id).encode() + + tmp = get_task_template(agent.task_category.name).to_flyte_idl() + task_category = TaskCategory(name=agent.task_category.name, version=0) + req = CreateTaskRequest(inputs=inputs_proto, output_prefix=output_prefix, template=tmp) + + res = await service.CreateTask(req, ctx) assert res.resource_meta == metadata_bytes - res = await agent.get(metadata_bytes) + res = await service.GetTask(GetTaskRequest(task_category=task_category, resource_meta=metadata_bytes), ctx) assert res.resource.phase == TaskExecution.SUCCEEDED - res = await agent.delete(metadata_bytes) - assert res == DeleteTaskResponse() + res = await service.DeleteTask(DeleteTaskRequest(task_category=task_category, resource_meta=metadata_bytes), ctx) + assert res is None - agent_metadata = AgentRegistry.get_agent_metadata("Async Dummy Agent") - assert agent_metadata.name == "Async Dummy Agent" - assert agent_metadata.supported_task_types == ["async_dummy"] + agent_metadata = AgentRegistry.get_agent_metadata(agent.name) + assert agent_metadata.supported_task_types[0] == agent.task_category.name + assert agent_metadata.supported_task_categories[0].name == agent.task_category.name -@pytest.mark.asyncio -async def test_sync_dummy_agent(): - AgentRegistry.register(SyncDummyAgent()) - agent = AgentRegistry.get_agent("sync_dummy") - res = agent.create("/tmp", sync_dummy_template, task_inputs) - assert res.resource.phase == TaskExecution.SUCCEEDED - assert res.resource.outputs == LiteralMap({}).to_flyte_idl() +def test_register_agent(): + agent = DummyAgent() + AgentRegistry.register(agent, override=True) + assert AgentRegistry.get_agent("dummy").name == agent.name - agent_metadata = AgentRegistry.get_agent_metadata("Sync Dummy Agent") - assert agent_metadata.name == "Sync Dummy Agent" - assert agent_metadata.supported_task_types == ["sync_dummy"] + with pytest.raises(ValueError, match="Duplicate agent for task type: dummy_v0"): + AgentRegistry.register(agent) @pytest.mark.asyncio -async def run_agent_server(): - service = AsyncAgentService() +async def test_agent_metadata_service(): + agent = DummyAgent() + AgentRegistry.register(agent, override=True) + ctx = MagicMock(spec=grpc.ServicerContext) - request = CreateTaskRequest( - inputs=task_inputs.to_flyte_idl(), output_prefix="/tmp", template=dummy_template.to_flyte_idl() - ) - async_request = CreateTaskRequest( - inputs=task_inputs.to_flyte_idl(), output_prefix="/tmp", template=async_dummy_template.to_flyte_idl() - ) - sync_request = CreateTaskRequest( - inputs=task_inputs.to_flyte_idl(), output_prefix="/tmp", template=sync_dummy_template.to_flyte_idl() - ) - fake_agent = "fake" - metadata_bytes = json.dumps(asdict(Metadata(job_id=dummy_id))).encode("utf-8") + metadata_service = AgentMetadataService() + res = await metadata_service.ListAgents(ListAgentsRequest(), ctx) + assert isinstance(res, ListAgentsResponse) + res = await metadata_service.GetAgent(GetAgentRequest(name="Dummy Agent"), ctx) + assert res.agent.name == agent.name + assert res.agent.supported_task_types[0] == agent.task_category.name + assert res.agent.supported_task_categories[0].name == agent.task_category.name - res = await service.CreateTask(request, ctx) - assert res.resource_meta == metadata_bytes - res = await service.GetTask(GetTaskRequest(task_type="dummy", resource_meta=metadata_bytes), ctx) - assert res.resource.phase == TaskExecution.SUCCEEDED - res = await service.DeleteTask(DeleteTaskRequest(task_type="dummy", resource_meta=metadata_bytes), ctx) - assert isinstance(res, DeleteTaskResponse) - res = await service.CreateTask(async_request, ctx) - assert res.resource_meta == metadata_bytes - res = await service.GetTask(GetTaskRequest(task_type="async_dummy", resource_meta=metadata_bytes), ctx) - assert res.resource.phase == TaskExecution.SUCCEEDED - res = await service.DeleteTask(DeleteTaskRequest(task_type="async_dummy", resource_meta=metadata_bytes), ctx) - assert isinstance(res, DeleteTaskResponse) +def test_openai_agent(): + AgentRegistry.register(MockOpenAIAgent(), override=True) - res = await service.CreateTask(sync_request, ctx) - assert res.resource.phase == TaskExecution.SUCCEEDED - assert res.resource.outputs == LiteralMap({}).to_flyte_idl() + class OpenAITask(SyncAgentExecutorMixin, PythonTask): + def __init__(self, **kwargs): + super().__init__( + task_type="openai", interface=Interface(inputs=kwtypes(a=int), outputs=kwtypes(o0=int)), **kwargs + ) - res = await service.GetTask(GetTaskRequest(task_type=fake_agent, resource_meta=metadata_bytes), ctx) - assert res is None + t = OpenAITask(task_config={}, name="openai task") + res = t(a=1) + assert res == 1 + + +def test_async_openai_agent(): + AgentRegistry.register(MockAsyncOpenAIAgent(), override=True) + + class OpenAITask(SyncAgentExecutorMixin, PythonTask): + def __init__(self, **kwargs): + super().__init__( + task_type="async_openai", + interface=Interface(inputs=kwtypes(a=int), outputs=kwtypes(o0=int)), + **kwargs, + ) + + t = OpenAITask(task_config={}, name="openai task") + res = t(a=1) + assert res == 1 + + +async def get_request_iterator(task_type: str): + inputs_proto = task_inputs.to_flyte_idl() + template = get_task_template(task_type).to_flyte_idl() + header = CreateRequestHeader(template=template, output_prefix="/tmp") + yield ExecuteTaskSyncRequest(header=header) + yield ExecuteTaskSyncRequest(inputs=inputs_proto) - metadata_service = AgentMetadataService() - res = await metadata_service.ListAgent(ListAgentsRequest(), ctx) - assert isinstance(res, ListAgentsResponse) +@pytest.mark.asyncio +async def test_sync_agent_service(): + AgentRegistry.register(MockOpenAIAgent(), override=True) + ctx = MagicMock(spec=grpc.ServicerContext) + + service = SyncAgentService() + res = await service.ExecuteTaskSync(get_request_iterator("openai"), ctx).__anext__() + assert res.header.resource.phase == TaskExecution.SUCCEEDED + assert res.header.resource.outputs.literals["o0"].scalar.primitive.integer == 1 + + +@pytest.mark.asyncio +async def test_sync_agent_service_with_asyncio(): + AgentRegistry.register(MockAsyncOpenAIAgent(), override=True) + AgentRegistry.register(DummyAgent(), override=True) + ctx = MagicMock(spec=grpc.ServicerContext) -def test_agent_server(): - loop.run_in_executor(None, run_agent_server) + service = SyncAgentService() + res = await service.ExecuteTaskSync(get_request_iterator("async_openai"), ctx).__anext__() + assert res.header.resource.phase == TaskExecution.SUCCEEDED + assert res.header.resource.outputs.literals["o0"].scalar.primitive.integer == 1 def test_is_terminal_phase(): @@ -276,7 +297,8 @@ def test_get_agent_secret(mocked_context): def test_render_task_template(): - tt = render_task_template(dummy_template, "s3://becket") + template = get_task_template("dummy") + tt = render_task_template(template, "s3://becket") assert tt.container.args == [ "pyflyte-fast-execute", "--additional-distribution", @@ -288,7 +310,7 @@ def test_render_task_template(): "--inputs", "s3://becket/inputs.pb", "--output-prefix", - "s3://becket/output", + "s3://becket", "--raw-output-data-prefix", "s3://becket/raw_output", "--checkpoint-path", diff --git a/tests/flytekit/unit/sensor/test_file_sensor.py b/tests/flytekit/unit/sensor/test_file_sensor.py index f6a50836be..bb0553dc27 100644 --- a/tests/flytekit/unit/sensor/test_file_sensor.py +++ b/tests/flytekit/unit/sensor/test_file_sensor.py @@ -16,7 +16,12 @@ def test_sensor_task(): env={"FOO": "baz"}, image_config=ImageConfig(default_image=default_img, images=[default_img]), ) - assert sensor.get_custom(settings) == {"sensor_module": "flytekit.sensor.file_sensor", "sensor_name": "FileSensor"} + assert sensor.get_custom(settings) == { + "sensor_module": "flytekit.sensor.file_sensor", + "sensor_name": "FileSensor", + "sensor_config": None, + "inputs": None, + } tmp_file = tempfile.NamedTemporaryFile() @task() diff --git a/tests/flytekit/unit/sensor/test_sensor_engine.py b/tests/flytekit/unit/sensor/test_sensor_engine.py index b5353b61b4..4a12aed877 100644 --- a/tests/flytekit/unit/sensor/test_sensor_engine.py +++ b/tests/flytekit/unit/sensor/test_sensor_engine.py @@ -1,20 +1,20 @@ import tempfile +from dataclasses import asdict -import cloudpickle import pytest -from flyteidl.admin.agent_pb2 import DeleteTaskResponse from flyteidl.core.execution_pb2 import TaskExecution import flytekit.models.interface as interface_models from flytekit.extend.backend.base_agent import AgentRegistry from flytekit.models import literals, types from flytekit.sensor import FileSensor -from flytekit.sensor.base_sensor import SENSOR_MODULE, SENSOR_NAME +from flytekit.sensor.base_sensor import SensorMetadata from tests.flytekit.unit.extend.test_agent import get_task_template @pytest.mark.asyncio async def test_sensor_engine(): + file = tempfile.NamedTemporaryFile() interfaces = interface_models.TypedInterface( { "path": interface_models.Variable(types.LiteralType(types.SimpleType.STRING), "description1"), @@ -22,12 +22,10 @@ async def test_sensor_engine(): {}, ) tmp = get_task_template("sensor") - tmp._custom = { - SENSOR_MODULE: FileSensor.__module__, - SENSOR_NAME: FileSensor.__name__, - } - file = tempfile.NamedTemporaryFile() - + sensor_metadata = SensorMetadata( + sensor_module=FileSensor.__module__, sensor_name=FileSensor.__name__, inputs={"path": file.name} + ) + tmp._custom = asdict(sensor_metadata) tmp._interface = interfaces task_inputs = literals.LiteralMap( @@ -37,11 +35,10 @@ async def test_sensor_engine(): ) agent = AgentRegistry.get_agent("sensor") - res = await agent.create("/tmp", tmp, task_inputs) + res = await agent.create(tmp, task_inputs) - metadata_bytes = cloudpickle.dumps(tmp.custom) - assert res.resource_meta == metadata_bytes - res = await agent.get(metadata_bytes) - assert res.resource.phase == TaskExecution.SUCCEEDED - res = await agent.delete(metadata_bytes) - assert res == DeleteTaskResponse() + assert res == sensor_metadata + resource = await agent.get(sensor_metadata) + assert resource.phase == TaskExecution.SUCCEEDED + res = await agent.delete(sensor_metadata) + assert res is None