Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ChatGPT Agent V2 #2086

Closed
wants to merge 77 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
77 commits
Select commit Hold shift + click to select a range
b525b5d
Add SyncAgentBase
pingsutw Jan 30, 2024
eaf4a97
Merged master
pingsutw Feb 10, 2024
3f99c0e
Cleanup
pingsutw Feb 10, 2024
d0862e9
Use task type in registry
pingsutw Feb 10, 2024
b55a005
Update idl
pingsutw Feb 10, 2024
46d9ee6
fix tests
pingsutw Feb 10, 2024
e521f3a
ExecuteTaskSync
pingsutw Feb 10, 2024
bb8ce05
update idl
pingsutw Feb 10, 2024
8ea23d8
nit
pingsutw Feb 10, 2024
a988135
Merge branch 'master' of github.com:flyteorg/flytekit into sync-agent…
pingsutw Feb 10, 2024
b58201d
Add unit tests
pingsutw Feb 12, 2024
5963b2c
fix tests
pingsutw Feb 12, 2024
00258d3
more unit tests
pingsutw Feb 13, 2024
dfc0f6d
fixed tests
pingsutw Feb 13, 2024
04a063a
wip
pingsutw Feb 13, 2024
4c9095d
fix integration tests
pingsutw Feb 13, 2024
3399d5e
fix tests
pingsutw Feb 13, 2024
6b50c26
fix integration tests
pingsutw Feb 13, 2024
e39e6dc
fix integration tests
pingsutw Feb 13, 2024
5ca81be
fix tests
pingsutw Feb 13, 2024
d688ee2
fix tests
pingsutw Feb 13, 2024
c496f5b
fix tests
pingsutw Feb 13, 2024
d22bc3c
fix tests
pingsutw Feb 13, 2024
c5ac926
wip
pingsutw Feb 17, 2024
575def9
merged master
pingsutw Feb 18, 2024
baf827d
update idl
pingsutw Feb 18, 2024
b0e4024
fix tests
pingsutw Feb 18, 2024
a10745a
fix tests
pingsutw Feb 18, 2024
f34e56c
init
Jan 3, 2024
33c9f08
ugly version
Jan 3, 2024
32fbcd3
succeed v1
Jan 3, 2024
43a0608
succeed v2
Jan 3, 2024
5d98a82
remove unused code
Jan 3, 2024
7bd870c
add test_agent
Jan 4, 2024
b70aa12
add dev-requirements
Jan 4, 2024
39d1f36
openai lazy module import
Jan 4, 2024
2cae899
use dataclass
Jan 4, 2024
20664a6
rename config to task_config
Jan 4, 2024
288c36d
change dir name from openai-chatgpt to openai and add chatgpt_task test
Jan 4, 2024
d967aba
add more test
Jan 4, 2024
e310156
add README
Jan 4, 2024
b73497a
better generic type
Jan 6, 2024
8cf9756
change idl version
Jan 9, 2024
9c94018
install latest flyteidl
Jan 9, 2024
ebc7f8f
remove setup.py and add chatgpt to agent's dockerfile
Jan 11, 2024
44f3933
fix tests
pingsutw Feb 19, 2024
c19221a
support openai latest api, test is not complete yet
Future-Outlier Feb 19, 2024
a68f9f3
nit
pingsutw Feb 19, 2024
87845e8
nit
pingsutw Feb 19, 2024
101a0c6
finish test
Future-Outlier Feb 20, 2024
ffb5ae1
make setup, revert
Future-Outlier Feb 20, 2024
1bf859d
Add metric
pingsutw Feb 20, 2024
139fbc0
fix sensor
pingsutw Feb 20, 2024
6983f87
nit
pingsutw Feb 20, 2024
1f0908b
nit
pingsutw Feb 20, 2024
507e3a2
bump idl
pingsutw Feb 20, 2024
2b0c956
solve conflict
Future-Outlier Feb 20, 2024
e1f404c
merge sync_agent_base, implement method
Future-Outlier Feb 20, 2024
cba504b
nit
pingsutw Feb 20, 2024
267accb
Fix airflow tests
pingsutw Feb 20, 2024
1fdd372
Fix airflow tests
pingsutw Feb 21, 2024
65d7cd5
solve conflict
Future-Outlier Feb 21, 2024
b058a98
new interface
Future-Outlier Feb 21, 2024
e610799
add flytekit-chatgpt in pythonbuild.yml
Future-Outlier Feb 21, 2024
8c580eb
fix tests
pingsutw Feb 21, 2024
6e7bd7a
fix tests
pingsutw Feb 21, 2024
a96e08d
fix tests
pingsutw Feb 21, 2024
8a2343e
Merge branch 'sync-agent-base' into chagpt-agent
Future-Outlier Feb 21, 2024
440eaf2
nit
pingsutw Feb 21, 2024
326db36
lint
pingsutw Feb 21, 2024
3698771
nit
pingsutw Feb 21, 2024
519d972
nit
pingsutw Feb 21, 2024
a31befc
resource_meta
pingsutw Feb 21, 2024
f6fae67
nit
pingsutw Feb 21, 2024
8933b82
Merge branch 'sync-agent-base' into chagpt-agent
Future-Outlier Feb 22, 2024
045675f
add SyncAgentBase Error Handling
Future-Outlier Feb 23, 2024
6c36d35
change OPENAI SECRET NAME, fix chatgpt agent test and add error handling
Future-Outlier Feb 23, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/monodocs_build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ jobs:
working-directory: ${{ github.workspace }}/flyte
run: |
conda activate monodocs-env
pip install -e ./flyteidl
pip install git+https://github.com/flyteorg/flyte.git@b5c1d99c7f85764775c9924d5b98948e4d8478d0#subdirectory=flyteidl
conda info
conda list
conda config --show-sources
Expand Down
2 changes: 2 additions & 0 deletions .github/workflows/pythonbuild.yml
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,7 @@ jobs:
# onnx-tensorflow needs a version of tensorflow that does not work with protobuf>4.
# The issue is being tracked on the tensorflow side in https://github.com/tensorflow/tensorflow/issues/53234#issuecomment-1330111693
# flytekit-onnx-tensorflow
- flytekit-chatgpt
- flytekit-pandera
- flytekit-papermill
- flytekit-polars
Expand Down Expand Up @@ -275,6 +276,7 @@ jobs:
pip freeze
- name: Test with coverage
run: |
pip install git+https://github.com/flyteorg/flyte.git@b5c1d99c7f85764775c9924d5b98948e4d8478d0#subdirectory=flyteidl
cd plugins/${{ matrix.plugin-names }}
# onnx plugins does not support protobuf>4 yet (in fact it is tensorflow that
# does not support that yet). More details in https://github.com/onnx/onnx/issues/4239.
Expand Down
1 change: 1 addition & 0 deletions Dockerfile.agent
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ RUN pip install --no-cache-dir -U flytekit==$VERSION \
flytekitplugins-mmcloud==$VERSION \
flytekitplugins-spark==$VERSION \
flytekitplugins-snowflake==$VERSION \
flytekitplugins-chatgpt==$VERSION \
&& apt-get clean autoclean \
&& apt-get autoremove --yes \
&& rm -rf /var/lib/{apt,dpkg,cache,log}/ \
Expand Down
3 changes: 1 addition & 2 deletions Dockerfile.dev
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,7 @@ COPY . /flytekit
# 4. Create a non-root user 'flytekit' and set appropriate permissions for directories.
RUN apt-get update && apt-get install build-essential vim libmagic1 git -y \
&& pip install --no-cache-dir -U --pre \
flyteidl \
-e /flytekit \
git+https://github.com/flyteorg/flyte.git@a46196d42e80f055efc9a013cb95c20342c72622#subdirectory=flyteidl \
-e /flytekit/plugins/flytekit-k8s-pod \
-e /flytekit/plugins/flytekit-deck-standard \
-e /flytekit/plugins/flytekit-flyteinteractive \
Expand Down
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@ update_boilerplate:

.PHONY: setup
setup: install-piptools ## Install requirements
pip install git+https://github.com/flyteorg/flyte.git@b5c1d99c7f85764775c9924d5b98948e4d8478d0#subdirectory=flyteidl
pip install --pre -r dev-requirements.in


.PHONY: fmt
fmt:
pre-commit run ruff --all-files || true
Expand Down
4 changes: 3 additions & 1 deletion flytekit/clis/sdk_in_container/serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from flyteidl.service.agent_pb2_grpc import (
add_AgentMetadataServiceServicer_to_server,
add_AsyncAgentServiceServicer_to_server,
add_SyncAgentServiceServicer_to_server,
)
from grpc import aio

Expand Down Expand Up @@ -52,7 +53,7 @@ def agent(_: click.Context, port, worker, timeout):

async def _start_grpc_server(port: int, worker: int, timeout: int):
click.secho("Starting up the server to expose the prometheus metrics...", fg="blue")
from flytekit.extend.backend.agent_service import AgentMetadataService, AsyncAgentService
from flytekit.extend.backend.agent_service import AgentMetadataService, AsyncAgentService, SyncAgentService

try:
from prometheus_client import start_http_server
Expand All @@ -64,6 +65,7 @@ async def _start_grpc_server(port: int, worker: int, timeout: int):
server = aio.server(futures.ThreadPoolExecutor(max_workers=worker))

add_AsyncAgentServiceServicer_to_server(AsyncAgentService(), server)
add_SyncAgentServiceServicer_to_server(SyncAgentService(), server)
add_AgentMetadataServiceServicer_to_server(AgentMetadataService(), server)

server.add_insecure_port(f"[::]:{port}")
Expand Down
11 changes: 11 additions & 0 deletions flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from typing import Dict, List, NamedTuple, Optional, Type, cast

from dataclasses_json import DataClassJsonMixin, dataclass_json
from flyteidl.core import literals_pb2
from google.protobuf import json_format as _json_format
from google.protobuf import struct_pb2 as _struct
from google.protobuf.json_format import MessageToDict as _MessageToDict
Expand Down Expand Up @@ -1210,6 +1211,16 @@ def dict_to_literal_map(
raise user_exceptions.FlyteTypeException(type(v), python_type, received_value=v)
return LiteralMap(literal_map)

@classmethod
def dict_to_literal_map_pb(
cls,
ctx: FlyteContext,
d: typing.Dict[str, typing.Any],
type_hints: Optional[typing.Dict[str, type]] = None,
) -> Optional[literals_pb2.LiteralMap]:
literal_map = cls.dict_to_literal_map(ctx, d, type_hints)
return literal_map.to_flyte_idl()

@classmethod
def get_available_transformers(cls) -> typing.KeysView[Type]:
"""
Expand Down
137 changes: 102 additions & 35 deletions flytekit/extend/backend/agent_service.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,42 @@
import typing
from http import HTTPStatus

import grpc
from flyteidl.admin.agent_pb2 import (
CreateTaskRequest,
CreateTaskResponse,
DeleteTaskRequest,
DeleteTaskResponse,
ExecuteTaskSyncRequest,
ExecuteTaskSyncResponse,
ExecuteTaskSyncResponseHeader,
GetAgentRequest,
GetAgentResponse,
GetTaskRequest,
GetTaskResponse,
ListAgentsRequest,
ListAgentsResponse,
Resource,
)
from flyteidl.service.agent_pb2_grpc import (
AgentMetadataServiceServicer,
AsyncAgentServiceServicer,
SyncAgentServiceServicer,
)
from flyteidl.service.agent_pb2_grpc import AgentMetadataServiceServicer, AsyncAgentServiceServicer
from prometheus_client import Counter, Summary

from flytekit import logger
from flytekit import FlyteContext, logger
from flytekit.core.type_engine import TypeEngine
from flytekit.exceptions.system import FlyteAgentNotFound
from flytekit.extend.backend.base_agent import AgentRegistry, mirror_async_methods
from flytekit.extend.backend.base_agent import AgentRegistry, SyncAgentBase, mirror_async_methods
from flytekit.models.literals import LiteralMap
from flytekit.models.task import TaskTemplate

metric_prefix = "flyte_agent_"
create_operation = "create"
get_operation = "get"
delete_operation = "delete"
do_operation = "do"

# Follow the naming convention. https://prometheus.io/docs/practices/naming/
request_success_count = Counter(
Expand All @@ -46,7 +57,24 @@
input_literal_size = Summary(f"{metric_prefix}input_literal_bytes", "Size of input literal", ["task_type"])


def agent_exception_handler(func: typing.Callable):
def _handle_exception(e: Exception, context: grpc.ServicerContext, task_type: str, operation: str):
if isinstance(e, FlyteAgentNotFound):
error_message = f"Cannot find agent for task type: {task_type}."
logger.error(error_message)
context.set_code(grpc.StatusCode.NOT_FOUND)
context.set_details(error_message)
request_failure_count.labels(task_type=task_type, operation=operation, error_code=HTTPStatus.NOT_FOUND).inc()

Check warning on line 66 in flytekit/extend/backend/agent_service.py

View check run for this annotation

Codecov / codecov/patch

flytekit/extend/backend/agent_service.py#L62-L66

Added lines #L62 - L66 were not covered by tests
else:
error_message = f"failed to {operation} {task_type} task with error: {e}."
logger.error(error_message)
context.set_code(grpc.StatusCode.INTERNAL)
context.set_details(error_message)
request_failure_count.labels(

Check warning on line 72 in flytekit/extend/backend/agent_service.py

View check run for this annotation

Codecov / codecov/patch

flytekit/extend/backend/agent_service.py#L68-L72

Added lines #L68 - L72 were not covered by tests
task_type=task_type, operation=operation, error_code=HTTPStatus.INTERNAL_SERVER_ERROR
).inc()


def record_agent_metrics(func: typing.Callable):
async def wrapper(
self,
request: typing.Union[CreateTaskRequest, GetTaskRequest, DeleteTaskRequest],
Expand All @@ -60,10 +88,10 @@
if request.inputs:
input_literal_size.labels(task_type=task_type).observe(request.inputs.ByteSize())
elif isinstance(request, GetTaskRequest):
task_type = request.task_type
task_type = request.task_type or request.task_category.name
operation = get_operation
elif isinstance(request, DeleteTaskRequest):
task_type = request.task_type
task_type = request.task_category or request.task_category.name
operation = delete_operation
else:
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
Expand All @@ -75,51 +103,90 @@
res = await func(self, request, context, *args, **kwargs)
request_success_count.labels(task_type=task_type, operation=operation).inc()
return res
except FlyteAgentNotFound:
error_message = f"Cannot find agent for task type: {task_type}."
logger.error(error_message)
context.set_code(grpc.StatusCode.NOT_FOUND)
context.set_details(error_message)
request_failure_count.labels(task_type=task_type, operation=operation, error_code="404").inc()
except Exception as e:
error_message = f"failed to {operation} {task_type} task with error {e}."
logger.error(error_message)
context.set_code(grpc.StatusCode.INTERNAL)
context.set_details(error_message)
request_failure_count.labels(task_type=task_type, operation=operation, error_code="500").inc()
_handle_exception(e, context, task_type, operation)

Check warning on line 107 in flytekit/extend/backend/agent_service.py

View check run for this annotation

Codecov / codecov/patch

flytekit/extend/backend/agent_service.py#L107

Added line #L107 was not covered by tests

return wrapper


class AsyncAgentService(AsyncAgentServiceServicer):
@agent_exception_handler
@record_agent_metrics
async def CreateTask(self, request: CreateTaskRequest, context: grpc.ServicerContext) -> CreateTaskResponse:
tmp = TaskTemplate.from_flyte_idl(request.template)
template = TaskTemplate.from_flyte_idl(request.template)
inputs = LiteralMap.from_flyte_idl(request.inputs) if request.inputs else None
agent = AgentRegistry.get_agent(tmp.type)
agent = AgentRegistry.get_agent(template.type, template.task_type_version)

logger.info(f"{tmp.type} agent start creating the job")
return await mirror_async_methods(
agent.create, output_prefix=request.output_prefix, task_template=tmp, inputs=inputs
)
logger.info(f"{agent.name} start creating the job")
resource_mata = await mirror_async_methods(agent.create, task_template=template, inputs=inputs)
return CreateTaskResponse(resource_meta=resource_mata.encode())

@agent_exception_handler
@record_agent_metrics
async def GetTask(self, request: GetTaskRequest, context: grpc.ServicerContext) -> GetTaskResponse:
agent = AgentRegistry.get_agent(request.task_type)
logger.info(f"{agent.task_type} agent start checking the status of the job")
return await mirror_async_methods(agent.get, resource_meta=request.resource_meta)
if request.task_category:
agent = AgentRegistry.get_agent(request.task_category.name, request.task_category.version)
else:
agent = AgentRegistry.get_agent(request.task_type)

Check warning on line 128 in flytekit/extend/backend/agent_service.py

View check run for this annotation

Codecov / codecov/patch

flytekit/extend/backend/agent_service.py#L128

Added line #L128 was not covered by tests
logger.info(f"{agent.name} start checking the status of the job")
res = await mirror_async_methods(agent.get, resource_meta=agent.metadata_type.decode(request.resource_meta))

if res.outputs is None:
outputs = None
elif isinstance(res.outputs, LiteralMap):
outputs = res.outputs.to_flyte_idl()

Check warning on line 135 in flytekit/extend/backend/agent_service.py

View check run for this annotation

Codecov / codecov/patch

flytekit/extend/backend/agent_service.py#L135

Added line #L135 was not covered by tests
else:
ctx = FlyteContext.current_context()
outputs = TypeEngine.dict_to_literal_map_pb(ctx, res.outputs)

Check warning on line 138 in flytekit/extend/backend/agent_service.py

View check run for this annotation

Codecov / codecov/patch

flytekit/extend/backend/agent_service.py#L137-L138

Added lines #L137 - L138 were not covered by tests
return GetTaskResponse(
resource=Resource(phase=res.phase, log_links=res.log_links, message=res.message, outputs=outputs)
)

@agent_exception_handler
@record_agent_metrics
async def DeleteTask(self, request: DeleteTaskRequest, context: grpc.ServicerContext) -> DeleteTaskResponse:
agent = AgentRegistry.get_agent(request.task_type)
logger.info(f"{agent.task_type} agent start deleting the job")
return await mirror_async_methods(agent.delete, resource_meta=request.resource_meta)
if request.task_category:
agent = AgentRegistry.get_agent(request.task_category.name, request.task_category.version)
else:
agent = AgentRegistry.get_agent(request.task_type)

Check warning on line 148 in flytekit/extend/backend/agent_service.py

View check run for this annotation

Codecov / codecov/patch

flytekit/extend/backend/agent_service.py#L148

Added line #L148 was not covered by tests
logger.info(f"{agent.name} start deleting the job")
return await mirror_async_methods(agent.delete, resource_meta=agent.metadata_type.decode(request.resource_meta))


class SyncAgentService(SyncAgentServiceServicer):
async def ExecuteTaskSync(
self, request_iterator: typing.AsyncIterator[ExecuteTaskSyncRequest], context: grpc.ServicerContext
) -> typing.AsyncIterator[ExecuteTaskSyncResponse]:
request = await request_iterator.__anext__()
template = TaskTemplate.from_flyte_idl(request.header.template)
task_type = template.type
try:
with request_latency.labels(task_type=task_type, operation=do_operation).time():
agent = AgentRegistry.get_agent(task_type, template.task_type_version)
if not isinstance(agent, SyncAgentBase):
raise ValueError(f"[{agent.name}] does not support sync execution")

Check warning on line 164 in flytekit/extend/backend/agent_service.py

View check run for this annotation

Codecov / codecov/patch

flytekit/extend/backend/agent_service.py#L164

Added line #L164 was not covered by tests

request = await request_iterator.__anext__()
literal_map = LiteralMap.from_flyte_idl(request.inputs) if request.inputs else None
res = await mirror_async_methods(agent.do, task_template=template, inputs=literal_map)

if res.outputs is None:
outputs = None

Check warning on line 171 in flytekit/extend/backend/agent_service.py

View check run for this annotation

Codecov / codecov/patch

flytekit/extend/backend/agent_service.py#L171

Added line #L171 was not covered by tests
elif isinstance(res.outputs, LiteralMap):
outputs = res.outputs.to_flyte_idl()

Check warning on line 173 in flytekit/extend/backend/agent_service.py

View check run for this annotation

Codecov / codecov/patch

flytekit/extend/backend/agent_service.py#L173

Added line #L173 was not covered by tests
else:
ctx = FlyteContext.current_context()
outputs = TypeEngine.dict_to_literal_map_pb(ctx, res.outputs)

header = ExecuteTaskSyncResponseHeader(
resource=Resource(phase=res.phase, log_links=res.log_links, message=res.message, outputs=outputs)
)
yield ExecuteTaskSyncResponse(header=header)
request_success_count.labels(task_type=task_type, operation=do_operation).inc()
except Exception as e:
_handle_exception(e, context, template.type, do_operation)

Check warning on line 184 in flytekit/extend/backend/agent_service.py

View check run for this annotation

Codecov / codecov/patch

flytekit/extend/backend/agent_service.py#L182-L184

Added lines #L182 - L184 were not covered by tests


class AgentMetadataService(AgentMetadataServiceServicer):
async def GetAgent(self, request: GetAgentRequest, context: grpc.ServicerContext) -> GetAgentResponse:
return GetAgentResponse(agent=AgentRegistry._METADATA[request.name])
return GetAgentResponse(agent=AgentRegistry.METADATA[request.name])

async def ListAgents(self, request: ListAgentsRequest, context: grpc.ServicerContext) -> ListAgentsResponse:
agents = [agent for agent in AgentRegistry._METADATA.values()]
return ListAgentsResponse(agents=agents)
return ListAgentsResponse(agents=AgentRegistry.list_agents())
Loading
Loading