From 19c0f1372b39f637fc101ce453c07126cb6886f9 Mon Sep 17 00:00:00 2001 From: Diwank Tomer Date: Thu, 29 Aug 2024 14:04:36 -0400 Subject: [PATCH] fix(agents-api): Fix prompt render, codec and task execution Signed-off-by: Diwank Tomer --- .../agents_api/common/utils/template.py | 112 +++++++----------- .../agents_api/dependencies/developer_id.py | 5 +- .../agents_api/models/agent/delete_agent.py | 4 +- .../agents_api/models/agent/get_agent.py | 8 +- agents-api/agents_api/models/user/get_user.py | 4 +- agents-api/agents_api/worker/codec.py | 17 +-- .../agents_api/workflows/task_execution.py | 5 +- agents-api/pyproject.toml | 2 +- agents-api/tests/fixtures.py | 12 +- agents-api/tests/test_execution_workflow.py | 1 + 10 files changed, 80 insertions(+), 90 deletions(-) diff --git a/agents-api/agents_api/common/utils/template.py b/agents-api/agents_api/common/utils/template.py index 35ae2c350..a846a5d28 100644 --- a/agents-api/agents_api/common/utils/template.py +++ b/agents-api/agents_api/common/utils/template.py @@ -1,6 +1,10 @@ -from typing import List +import json +from typing import List, TypeVar import arrow +import re2 +import yaml +from beartype import beartype from jinja2.sandbox import ImmutableSandboxedEnvironment from jinja2schema import infer, to_json_schema from jsonschema import validate @@ -20,10 +24,18 @@ ) # Add arrow to jinja + +jinja_env.globals["dump_yaml"] = yaml.dump +jinja_env.globals["match_regex"] = lambda pattern, string: bool( + re2.fullmatch(pattern, string) +) +jinja_env.globals["search_regex"] = lambda pattern, string: re2.search(pattern, string) +jinja_env.globals["dump_json"] = json.dumps jinja_env.globals["arrow"] = arrow # Funcs +@beartype async def render_template_string( template_string: str, variables: dict, @@ -42,64 +54,42 @@ async def render_template_string( return rendered -async def render_template_chatml( - messages: list[dict], variables: dict, check: bool = False -) -> list[dict]: - # Parse template - # FIXME: should template_strings contain a list of ChatMLTextContentPart? Should we handle it somehow? - templates = [jinja_env.from_string(msg["content"]) for msg in messages] - - # If check is required, get required vars from template and validate variables - if check: - for template in templates: - schema = to_json_schema(infer(template)) - validate(instance=variables, schema=schema) - - # Render - rendered = [ - ({**msg, "content": await template.render_async(**variables)}) - for template, msg in zip(templates, messages) - ] - - return rendered - - -async def render_template_parts( - template_strings: list[dict], variables: dict, check: bool = False -) -> list[dict]: - # Parse template - # FIXME: should template_strings contain a list of ChatMLTextContentPart? Should we handle it somehow? - templates = [ - ( - jinja_env.from_string(msg["content"]["text"]) - if msg["content"]["type"] == "text" - else None - ) - for msg in template_strings - ] - - # If check is required, get required vars from template and validate variables - if check: - for template in templates: - if template is None: - continue - - schema = to_json_schema(infer(template)) - validate(instance=variables, schema=schema) +# A render function that can render arbitrarily nested lists of dicts +# only render keys: content, text, image_url +# and only render values that are strings +T = TypeVar("T", str, dict, list[dict | list[dict]]) - # Render - rendered = [ - ( - {"type": "text", "text": await template.render_async(**variables)} - if template is not None - else msg - ) - for template, msg in zip(templates, template_strings) - ] - return rendered +@beartype +async def render_template_nested( + input: T, + variables: dict, + check: bool = False, + whitelist: list[str] = ["content", "text", "image_url"], +) -> T: + match input: + case str(): + return await render_template_string(input, variables, check) + + case dict(): + return { + k: ( + await render_template_nested(v, variables, check, whitelist) + if k in whitelist + else v + ) + for k, v in input.items() + } + case list(): + return [ + await render_template_nested(v, variables, check, whitelist) + for v in input + ] + case _: + raise ValueError(f"Invalid input type: {type(input)}") +@beartype async def render_template( input: str | list[dict], variables: dict, @@ -112,14 +102,4 @@ async def render_template( if not (skip_vars is not None and isinstance(name, str) and name in skip_vars) } - match input: - case str(): - future = render_template_string(input, variables, check) - - case [{"content": str()}, *_]: - future = render_template_chatml(input, variables, check) - - case _: - future = render_template_parts(input, variables, check) - - return await future + return await render_template_nested(input, variables, check) diff --git a/agents-api/agents_api/dependencies/developer_id.py b/agents-api/agents_api/dependencies/developer_id.py index 735349fbd..976ee471e 100644 --- a/agents-api/agents_api/dependencies/developer_id.py +++ b/agents-api/agents_api/dependencies/developer_id.py @@ -4,7 +4,7 @@ from fastapi import Header from ..common.protocol.developers import Developer -from ..env import multi_tenant_mode +from ..env import multi_tenant_mode, testing from ..models.developer.get_developer import get_developer, verify_developer from .exceptions import InvalidHeaderFormat @@ -13,9 +13,6 @@ async def get_developer_id( x_developer_id: Annotated[UUID | None, Header(include_in_schema=False)] = None, ) -> UUID: if not multi_tenant_mode: - assert ( - not x_developer_id - ), "X-Developer-Id header not allowed in multi-tenant mode" return UUID("00000000-0000-0000-0000-000000000000") if not x_developer_id: diff --git a/agents-api/agents_api/models/agent/delete_agent.py b/agents-api/agents_api/models/agent/delete_agent.py index 409c755d3..926007bdf 100644 --- a/agents-api/agents_api/models/agent/delete_agent.py +++ b/agents-api/agents_api/models/agent/delete_agent.py @@ -28,8 +28,8 @@ @rewrap_exceptions( { lambda e: isinstance(e, QueryException) - and "asserted to return some results, but returned none" - in str(e): lambda *_: HTTPException( + and "Developer does not own resource" + in e.resp["display"]: lambda *_: HTTPException( detail="developer not found or doesnt own resource", status_code=404 ), QueryException: partialclass(HTTPException, status_code=400), diff --git a/agents-api/agents_api/models/agent/get_agent.py b/agents-api/agents_api/models/agent/get_agent.py index c977fa614..956fa46a5 100644 --- a/agents-api/agents_api/models/agent/get_agent.py +++ b/agents-api/agents_api/models/agent/get_agent.py @@ -24,12 +24,12 @@ { lambda e: isinstance(e, QueryException) and "Developer not found" in str(e): lambda *_: HTTPException( - detail="developer does not exist", status_code=403 + detail="Developer does not exist", status_code=403 ), lambda e: isinstance(e, QueryException) - and "asserted to return some results, but returned none" - in str(e): lambda *_: HTTPException( - detail="developer doesnt own resource", status_code=404 + and "Developer does not own resource" + in e.resp["display"]: lambda *_: HTTPException( + detail="Developer does not own resource", status_code=404 ), QueryException: partialclass(HTTPException, status_code=400), ValidationError: partialclass(HTTPException, status_code=400), diff --git a/agents-api/agents_api/models/user/get_user.py b/agents-api/agents_api/models/user/get_user.py index 181bf05f0..cc7c6f970 100644 --- a/agents-api/agents_api/models/user/get_user.py +++ b/agents-api/agents_api/models/user/get_user.py @@ -27,8 +27,8 @@ detail="developer does not exist", status_code=403 ), lambda e: isinstance(e, QueryException) - and "asserted to return some results, but returned none" - in str(e): lambda *_: HTTPException( + and "Developer does not own resource" + in e.resp["display"]: lambda *_: HTTPException( detail="developer doesnt own resource", status_code=404 ), QueryException: partialclass(HTTPException, status_code=400), diff --git a/agents-api/agents_api/worker/codec.py b/agents-api/agents_api/worker/codec.py index d56b81de1..d99652687 100644 --- a/agents-api/agents_api/worker/codec.py +++ b/agents-api/agents_api/worker/codec.py @@ -4,7 +4,6 @@ ### The codec is used to serialize/deserialize the data ### But this code is quite brittle. Be careful when changing it - import dataclasses import logging import pickle @@ -82,12 +81,16 @@ class PydanticEncodingPayloadConverter(EncodingPayloadConverter): b_encoding = encoding.encode() def to_payload(self, value: Any) -> Optional[Payload]: - return Payload( - metadata={ - "encoding": self.b_encoding, - }, - data=serialize(value), - ) + try: + return Payload( + metadata={ + "encoding": self.b_encoding, + }, + data=serialize(value), + ) + except Exception as e: + logging.warning(f"WARNING: Could not encode {value}: {e}") + return None def from_payload(self, payload: Payload, type_hint: Optional[Type] = None) -> Any: assert payload.metadata["encoding"] == self.b_encoding diff --git a/agents-api/agents_api/workflows/task_execution.py b/agents-api/agents_api/workflows/task_execution.py index ff0d6f692..327329932 100644 --- a/agents-api/agents_api/workflows/task_execution.py +++ b/agents-api/agents_api/workflows/task_execution.py @@ -167,7 +167,9 @@ async def run( ) except Exception as e: - await transition(state, context, type="error", output=dict(error=e)) + await transition( + state, context, type="error", output=dict(error=str(e)) + ) raise ApplicationError(f"Activity {activity} threw error: {e}") from e # --- @@ -376,6 +378,7 @@ async def run( ): await transition( state, + context, output=output, type=yield_transition_type, next=yield_next_target, diff --git a/agents-api/pyproject.toml b/agents-api/pyproject.toml index eee68f124..5e2ae68da 100644 --- a/agents-api/pyproject.toml +++ b/agents-api/pyproject.toml @@ -99,5 +99,5 @@ datamodel-codegen \ --disable-timestamp""" [tool.poe.tasks.test] -env = { AGENTS_API_TESTING = "true" } +env = { AGENTS_API_TESTING = "true", PYTHONPATH = "{PYTHONPATH}:." } cmd = "ward test" diff --git a/agents-api/tests/fixtures.py b/agents-api/tests/fixtures.py index c6a98ee99..e5e51dbad 100644 --- a/agents-api/tests/fixtures.py +++ b/agents-api/tests/fixtures.py @@ -1,4 +1,4 @@ -from uuid import uuid4 +from uuid import UUID, uuid4 from cozo_migrate.api import apply, init from fastapi.testclient import TestClient @@ -16,7 +16,7 @@ CreateTransitionRequest, CreateUserRequest, ) -from agents_api.env import api_key, api_key_header_name +from agents_api.env import api_key, api_key_header_name, multi_tenant_mode from agents_api.models.agent.create_agent import create_agent from agents_api.models.agent.delete_agent import delete_agent from agents_api.models.developer.get_developer import get_developer @@ -56,6 +56,10 @@ def cozo_client(migrations_dir: str = "./migrations"): @fixture(scope="global") def test_developer_id(cozo_client=cozo_client): + if not multi_tenant_mode: + yield UUID(int=0) + return + developer_id = uuid4() cozo_client.run( @@ -347,10 +351,12 @@ def _make_request(method, url, **kwargs): headers = kwargs.pop("headers", {}) headers = { **headers, - "X-Developer-Id": str(developer_id), api_key_header_name: api_key, } + if multi_tenant_mode: + headers["X-Developer-Id"] = str(developer_id) + return client.request(method, url, headers=headers, **kwargs) return _make_request diff --git a/agents-api/tests/test_execution_workflow.py b/agents-api/tests/test_execution_workflow.py index 65ca7d143..a0ddc22e9 100644 --- a/agents-api/tests/test_execution_workflow.py +++ b/agents-api/tests/test_execution_workflow.py @@ -707,6 +707,7 @@ async def _( mock_run_task_execution_workflow.assert_called_once() result = await handle.result() + result = result["choices"][0]["message"] assert result["content"] == "Hello, world!" assert result["role"] == "assistant"