Skip to content

Commit

Permalink
fix(agents-api): Fixed tests
Browse files Browse the repository at this point in the history
Signed-off-by: Diwank Tomer <[email protected]>
  • Loading branch information
Diwank Tomer committed Aug 9, 2024
1 parent 440b5f7 commit d305468
Show file tree
Hide file tree
Showing 25 changed files with 711 additions and 551 deletions.
32 changes: 2 additions & 30 deletions agents-api/agents_api/autogen/Tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,35 +46,7 @@ class FunctionDef(BaseModel):
model_config = ConfigDict(
populate_by_name=True,
)
name: Annotated[str, Field("overriden", pattern="^[^\\W0-9]\\w*$")]
"""
DO NOT USE: This will be overriden by the tool name. Here only for compatibility reasons.
"""
description: Annotated[
str | None,
Field(
None,
pattern="^[\\p{L}\\p{Nl}\\p{Pattern_Syntax}\\p{Pattern_White_Space}]+[\\p{ID_Start}\\p{Mn}\\p{Mc}\\p{Nd}\\p{Pc}\\p{Pattern_Syntax}\\p{Pattern_White_Space}]*$",
),
]
"""
Description of the function
"""
parameters: dict[str, Any]
"""
The parameters the function accepts
"""


class FunctionDefUpdate(BaseModel):
"""
Function definition
"""

model_config = ConfigDict(
populate_by_name=True,
)
name: Annotated[str, Field("overriden", pattern="^[^\\W0-9]\\w*$")]
name: Any | None = None
"""
DO NOT USE: This will be overriden by the tool name. Here only for compatibility reasons.
"""
Expand Down Expand Up @@ -124,7 +96,7 @@ class PatchToolRequest(BaseModel):
"""
Name of the tool (must be unique for this agent and a valid python identifier string )
"""
function: FunctionDefUpdate | None = None
function: FunctionDef | None = None
integration: Any | None = None
system: Any | None = None
api_call: Any | None = None
Expand Down
4 changes: 2 additions & 2 deletions agents-api/agents_api/models/agent/create_or_update_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def create_or_update_agent(
developer_id: UUID,
agent_id: UUID,
data: CreateOrUpdateAgentRequest,
) -> tuple[list[str], dict]:
) -> tuple[list[str | None], dict]:
"""
Constructs and executes a datalog query to create a new agent in the database.
Expand Down Expand Up @@ -123,7 +123,7 @@ def create_or_update_agent(

queries = [
verify_developer_id_query(developer_id),
default_settings and default_settings_query,
default_settings_query if default_settings else None,
agent_query,
]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,9 @@
TypeError: partialclass(HTTPException, status_code=400),
}
)
@wrap_in_class(Transition, transform=lambda d: {"id": d["transition_id"], **d})
@wrap_in_class(
Transition, transform=lambda d: {"id": d["transition_id"], **d}, one=True
)
@cozo_query
@beartype
def create_execution_transition(
Expand Down
33 changes: 22 additions & 11 deletions agents-api/agents_api/models/tools/patch_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
@cozo_query
@beartype
def patch_tool(
*, developer_id: UUID, agent_id: UUID, tool_id: UUID, patch_tool: PatchToolRequest
*, developer_id: UUID, agent_id: UUID, tool_id: UUID, data: PatchToolRequest
) -> tuple[list[str], dict]:
"""
# Execute the datalog query and return the results as a DataFrame
Expand All @@ -41,14 +41,17 @@ def patch_tool(
Parameters:
- agent_id (UUID): The unique identifier of the agent.
- tool_id (UUID): The unique identifier of the tool to be updated.
- patch_tool (PatchToolRequest): The request payload containing the updated tool information.
- data (PatchToolRequest): The request payload containing the updated tool information.
Returns:
- ResourceUpdatedResponse: The updated tool data.
"""

agent_id = str(agent_id)
tool_id = str(tool_id)

# Extract the tool data from the payload
patch_data = patch_tool.model_dump(exclude_none=True)
patch_data = data.model_dump(exclude_none=True)

# Assert that only one of the tool type fields is present
tool_specs = [
Expand All @@ -64,28 +67,33 @@ def patch_tool(
patch_data["type"] = patch_data.get("type", tool_type)
assert patch_data["type"] == tool_type, "Invalid tool update"

if tool_spec is not None:
# Rename the tool definition to 'spec'
patch_data["spec"] = tool_spec
tool_spec = tool_spec or {}
if tool_spec:
del patch_data[tool_type]

tool_cols, tool_vals = cozo_process_mutate_data(
{
**patch_data,
"agent_id": str(agent_id),
"tool_id": str(tool_id),
"agent_id": agent_id,
"tool_id": tool_id,
}
)

# Construct the datalog query for updating the tool information
patch_query = f"""
input[{tool_cols}] <- $input
?[{tool_cols}, updated_at] :=
?[{tool_cols}, spec, updated_at] :=
*tools {{
agent_id: to_uuid($agent_id),
tool_id: to_uuid($tool_id),
spec: old_spec,
}},
input[{tool_cols}],
spec = concat(old_spec, $spec),
updated_at = now()
:update tools {{ {tool_cols}, updated_at }}
:update tools {{ {tool_cols}, spec, updated_at }}
:returning
"""

Expand All @@ -95,4 +103,7 @@ def patch_tool(
patch_query,
]

return (queries, dict(input=tool_vals))
return (
queries,
dict(input=tool_vals, spec=tool_spec, agent_id=agent_id, tool_id=tool_id),
)
93 changes: 83 additions & 10 deletions agents-api/agents_api/models/tools/update_tool.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,38 @@
from uuid import UUID

from beartype import beartype
from fastapi import HTTPException
from pycozo.client import QueryException
from pydantic import ValidationError

from ...autogen.openapi_model import (
PatchToolRequest,
ResourceUpdatedResponse,
UpdateToolRequest,
)
from .patch_tool import patch_tool
from ...common.utils.cozo import cozo_process_mutate_data
from ..utils import (
cozo_query,
partialclass,
rewrap_exceptions,
verify_developer_id_query,
verify_developer_owns_resource_query,
wrap_in_class,
)


@rewrap_exceptions(
{
QueryException: partialclass(HTTPException, status_code=400),
ValidationError: partialclass(HTTPException, status_code=400),
TypeError: partialclass(HTTPException, status_code=400),
}
)
@wrap_in_class(
ResourceUpdatedResponse,
one=True,
transform=lambda d: {"id": d["tool_id"], "jobs": [], **d},
)
@cozo_query
@beartype
def update_tool(
*,
Expand All @@ -18,12 +41,62 @@ def update_tool(
tool_id: UUID,
data: UpdateToolRequest,
**kwargs,
) -> ResourceUpdatedResponse:
# Same as patch_tool_query, but with a different request payload
return patch_tool(
developer_id=developer_id,
agent_id=agent_id,
tool_id=tool_id,
patch_tool=PatchToolRequest(**data.model_dump()),
**kwargs,
) -> tuple[list[str], dict]:
agent_id = str(agent_id)
tool_id = str(tool_id)

# Extract the tool data from the payload
update_data = data.model_dump(exclude_none=True)

# Assert that only one of the tool type fields is present
tool_specs = [
(tool_type, update_data.get(tool_type))
for tool_type in ["function", "integration", "system", "api_call"]
if update_data.get(tool_type) is not None
]

assert len(tool_specs) <= 1, "Invalid tool update"
tool_type, tool_spec = tool_specs[0] if tool_specs else (None, None)

if tool_type is not None:
update_data["type"] = update_data.get("type", tool_type)
assert update_data["type"] == tool_type, "Invalid tool update"

update_data["spec"] = tool_spec
del update_data[tool_type]

tool_cols, tool_vals = cozo_process_mutate_data(
{
**update_data,
"agent_id": agent_id,
"tool_id": tool_id,
}
)

# Construct the datalog query for updating the tool information
patch_query = f"""
input[{tool_cols}] <- $input
?[{tool_cols}, created_at, updated_at] :=
*tools {{
agent_id: to_uuid($agent_id),
tool_id: to_uuid($tool_id),
created_at
}},
input[{tool_cols}],
updated_at = now()
:put tools {{ {tool_cols}, created_at, updated_at }}
:returning
"""

queries = [
verify_developer_id_query(developer_id),
verify_developer_owns_resource_query(developer_id, "agents", agent_id=agent_id),
patch_query,
]

return (
queries,
dict(input=tool_vals, spec=tool_spec, agent_id=agent_id, tool_id=tool_id),
)
4 changes: 2 additions & 2 deletions agents-api/agents_api/models/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def cozo_query(
func: Callable[P, tuple[str | list[str], dict]] | None = None,
debug: bool | None = None,
):
def cozo_query_dec(func: Callable[P, tuple[str | list[str], dict]]):
def cozo_query_dec(func: Callable[P, tuple[str | list[Any], dict]]):
"""
Decorator that wraps a function that takes arbitrary arguments, and
returns a (query string, variables) tuple.
Expand All @@ -135,7 +135,7 @@ def wrapper(
if isinstance(queries, str):
query = queries
else:
queries = [query for query in queries if query]
queries = [str(query) for query in queries if query]
query = "}\n\n{\n".join(queries)
query = f"{{ {query} }}"

Expand Down
Loading

0 comments on commit d305468

Please sign in to comment.