Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(agents-api): Fix prompt render, codec and task execution #478

Merged
merged 1 commit into from
Aug 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 46 additions & 66 deletions agents-api/agents_api/common/utils/template.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
from typing import List
import json
from typing import List, TypeVar

import arrow
import re2
import yaml
from beartype import beartype
from jinja2.sandbox import ImmutableSandboxedEnvironment
from jinja2schema import infer, to_json_schema
from jsonschema import validate
Expand All @@ -20,10 +24,18 @@
)

# Register helper globals on the jinja environment (YAML/JSON dumping, regex helpers, arrow)

jinja_env.globals["dump_yaml"] = yaml.dump
jinja_env.globals["match_regex"] = lambda pattern, string: bool(
re2.fullmatch(pattern, string)
)
jinja_env.globals["search_regex"] = lambda pattern, string: re2.search(pattern, string)
jinja_env.globals["dump_json"] = json.dumps
jinja_env.globals["arrow"] = arrow


# Funcs
@beartype
async def render_template_string(
template_string: str,
variables: dict,
Expand All @@ -42,64 +54,42 @@ async def render_template_string(
return rendered


async def render_template_chatml(
    messages: list[dict], variables: dict, check: bool = False
) -> list[dict]:
    """Render the ``content`` field of each chat message as a Jinja template.

    Args:
        messages: Message dicts; each ``msg["content"]`` is compiled as a
            template string.
        variables: Values passed as keyword arguments to every template.
        check: When true, a JSON schema is inferred from each template and
            ``variables`` is validated against it before rendering.

    Returns:
        A new list of shallow-copied messages whose ``content`` has been
        replaced by the rendered text; all other keys are carried over.
    """
    # FIXME: should content that is a list of ChatMLTextContentPart be handled here?
    compiled = []
    for message in messages:
        compiled.append(jinja_env.from_string(message["content"]))

    # Validate up front so nothing is rendered if any template's needs are unmet.
    if check:
        for tpl in compiled:
            validate(instance=variables, schema=to_json_schema(infer(tpl)))

    output = []
    for tpl, message in zip(compiled, messages):
        text = await tpl.render_async(**variables)
        output.append({**message, "content": text})

    return output


async def render_template_parts(
template_strings: list[dict], variables: dict, check: bool = False
) -> list[dict]:
# Parse template
# FIXME: should template_strings contain a list of ChatMLTextContentPart? Should we handle it somehow?
templates = [
(
jinja_env.from_string(msg["content"]["text"])
if msg["content"]["type"] == "text"
else None
)
for msg in template_strings
]

# If check is required, get required vars from template and validate variables
if check:
for template in templates:
if template is None:
continue

schema = to_json_schema(infer(template))
validate(instance=variables, schema=schema)
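# Render templates found inside arbitrarily nested lists and dicts.
# Inside a dict, only values under whitelisted keys (by default: content,
# text, image_url) are visited; values under any other key pass through as-is.
# Strings are the leaves that actually get rendered as templates.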
T = TypeVar("T", str, dict, list[dict | list[dict]])

# Render
rendered = [
(
{"type": "text", "text": await template.render_async(**variables)}
if template is not None
else msg
)
for template, msg in zip(templates, template_strings)
]

return rendered
@beartype
async def render_template_nested(
input: T,
variables: dict,
check: bool = False,
whitelist: list[str] = ["content", "text", "image_url"],
) -> T:
match input:
case str():
return await render_template_string(input, variables, check)

case dict():
return {
k: (
await render_template_nested(v, variables, check, whitelist)
if k in whitelist
else v
)
for k, v in input.items()
}
case list():
return [
await render_template_nested(v, variables, check, whitelist)
for v in input
]
case _:
raise ValueError(f"Invalid input type: {type(input)}")


@beartype
async def render_template(
input: str | list[dict],
variables: dict,
Expand All @@ -112,14 +102,4 @@ async def render_template(
if not (skip_vars is not None and isinstance(name, str) and name in skip_vars)
}

match input:
case str():
future = render_template_string(input, variables, check)

case [{"content": str()}, *_]:
future = render_template_chatml(input, variables, check)

case _:
future = render_template_parts(input, variables, check)

return await future
return await render_template_nested(input, variables, check)
5 changes: 1 addition & 4 deletions agents-api/agents_api/dependencies/developer_id.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from fastapi import Header

from ..common.protocol.developers import Developer
from ..env import multi_tenant_mode
from ..env import multi_tenant_mode, testing
from ..models.developer.get_developer import get_developer, verify_developer
from .exceptions import InvalidHeaderFormat

Expand All @@ -13,9 +13,6 @@ async def get_developer_id(
x_developer_id: Annotated[UUID | None, Header(include_in_schema=False)] = None,
) -> UUID:
if not multi_tenant_mode:
assert (
not x_developer_id
), "X-Developer-Id header not allowed in multi-tenant mode"
return UUID("00000000-0000-0000-0000-000000000000")

if not x_developer_id:
Expand Down
4 changes: 2 additions & 2 deletions agents-api/agents_api/models/agent/delete_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@
@rewrap_exceptions(
{
lambda e: isinstance(e, QueryException)
and "asserted to return some results, but returned none"
in str(e): lambda *_: HTTPException(
and "Developer does not own resource"
in e.resp["display"]: lambda *_: HTTPException(
detail="developer not found or doesnt own resource", status_code=404
),
QueryException: partialclass(HTTPException, status_code=400),
Expand Down
8 changes: 4 additions & 4 deletions agents-api/agents_api/models/agent/get_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,12 @@
{
lambda e: isinstance(e, QueryException)
and "Developer not found" in str(e): lambda *_: HTTPException(
detail="developer does not exist", status_code=403
detail="Developer does not exist", status_code=403
),
lambda e: isinstance(e, QueryException)
and "asserted to return some results, but returned none"
in str(e): lambda *_: HTTPException(
detail="developer doesnt own resource", status_code=404
and "Developer does not own resource"
in e.resp["display"]: lambda *_: HTTPException(
detail="Developer does not own resource", status_code=404
),
QueryException: partialclass(HTTPException, status_code=400),
ValidationError: partialclass(HTTPException, status_code=400),
Expand Down
4 changes: 2 additions & 2 deletions agents-api/agents_api/models/user/get_user.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@
detail="developer does not exist", status_code=403
),
lambda e: isinstance(e, QueryException)
and "asserted to return some results, but returned none"
in str(e): lambda *_: HTTPException(
and "Developer does not own resource"
in e.resp["display"]: lambda *_: HTTPException(
detail="developer doesnt own resource", status_code=404
),
QueryException: partialclass(HTTPException, status_code=400),
Expand Down
17 changes: 10 additions & 7 deletions agents-api/agents_api/worker/codec.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
### The codec is used to serialize/deserialize the data
### But this code is quite brittle. Be careful when changing it


import dataclasses
import logging
import pickle
Expand Down Expand Up @@ -82,12 +81,16 @@ class PydanticEncodingPayloadConverter(EncodingPayloadConverter):
b_encoding = encoding.encode()

def to_payload(self, value: Any) -> Optional[Payload]:
    """Serialize ``value`` into a payload tagged with this converter's encoding.

    Returns:
        A ``Payload`` whose metadata carries ``self.b_encoding`` and whose
        data is ``serialize(value)``, or ``None`` when serialization fails.
    """
    # The guarded form is the only body: an unconditional return placed
    # before the try block would make the fallback below unreachable.
    try:
        return Payload(
            metadata={
                "encoding": self.b_encoding,
            },
            data=serialize(value),
        )
    except Exception as e:
        # Deliberately broad: any serialization failure is logged and
        # reported as "not converted" (None) instead of propagating.
        logging.warning("WARNING: Could not encode %s: %s", value, e)
        return None

def from_payload(self, payload: Payload, type_hint: Optional[Type] = None) -> Any:
assert payload.metadata["encoding"] == self.b_encoding
Expand Down
5 changes: 4 additions & 1 deletion agents-api/agents_api/workflows/task_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,9 @@ async def run(
)

except Exception as e:
await transition(state, context, type="error", output=dict(error=e))
await transition(
state, context, type="error", output=dict(error=str(e))
)
raise ApplicationError(f"Activity {activity} threw error: {e}") from e

# ---
Expand Down Expand Up @@ -376,6 +378,7 @@ async def run(
):
await transition(
state,
context,
output=output,
type=yield_transition_type,
next=yield_next_target,
Expand Down
2 changes: 1 addition & 1 deletion agents-api/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -99,5 +99,5 @@ datamodel-codegen \
--disable-timestamp"""

[tool.poe.tasks.test]
env = { AGENTS_API_TESTING = "true" }
env = { AGENTS_API_TESTING = "true", PYTHONPATH = "{PYTHONPATH}:." }
cmd = "ward test"
12 changes: 9 additions & 3 deletions agents-api/tests/fixtures.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from uuid import uuid4
from uuid import UUID, uuid4

from cozo_migrate.api import apply, init
from fastapi.testclient import TestClient
Expand All @@ -16,7 +16,7 @@
CreateTransitionRequest,
CreateUserRequest,
)
from agents_api.env import api_key, api_key_header_name
from agents_api.env import api_key, api_key_header_name, multi_tenant_mode
from agents_api.models.agent.create_agent import create_agent
from agents_api.models.agent.delete_agent import delete_agent
from agents_api.models.developer.get_developer import get_developer
Expand Down Expand Up @@ -56,6 +56,10 @@ def cozo_client(migrations_dir: str = "./migrations"):

@fixture(scope="global")
def test_developer_id(cozo_client=cozo_client):
if not multi_tenant_mode:
yield UUID(int=0)
return

developer_id = uuid4()

cozo_client.run(
Expand Down Expand Up @@ -347,10 +351,12 @@ def _make_request(method, url, **kwargs):
headers = kwargs.pop("headers", {})
headers = {
**headers,
"X-Developer-Id": str(developer_id),
api_key_header_name: api_key,
}

if multi_tenant_mode:
headers["X-Developer-Id"] = str(developer_id)

return client.request(method, url, headers=headers, **kwargs)

return _make_request
1 change: 1 addition & 0 deletions agents-api/tests/test_execution_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -707,6 +707,7 @@ async def _(
mock_run_task_execution_workflow.assert_called_once()

result = await handle.result()
result = result["choices"][0]["message"]
assert result["content"] == "Hello, world!"
assert result["role"] == "assistant"

Expand Down
Loading