Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(agents-api): All tests pass (again) #701

Merged
merged 6 commits into from
Oct 18, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion agents-api/agents_api/activities/execute_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,15 @@
from ..routers.docs.search_docs import search_agent_docs, search_user_docs


# FIXME: This is a total mess. Should be refactored.

@auto_blob_store
@beartype
async def execute_system(
context: StepContext,
system: SystemDef,
) -> Any:
arguments = system.arguments
arguments: dict[str, Any] = system.arguments or {}
arguments["developer_id"] = context.execution_input.developer_id

# Unbox all the arguments
Expand Down
3 changes: 3 additions & 0 deletions agents-api/agents_api/common/interceptors.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from typing import Optional, Type

from temporalio.activity import _CompleteAsyncError as CompleteAsyncError
from temporalio.exceptions import ApplicationError, FailureError, TemporalError
from temporalio.service import RPCError
from temporalio.worker import (
Expand Down Expand Up @@ -42,6 +43,7 @@ async def execute_activity(self, input: ExecuteActivityInput):
ReadOnlyContextError,
NondeterminismError,
RPCError,
CompleteAsyncError,
TemporalError,
FailureError,
):
Expand Down Expand Up @@ -73,6 +75,7 @@ async def execute_workflow(self, input: ExecuteWorkflowInput):
ReadOnlyContextError,
NondeterminismError,
RPCError,
CompleteAsyncError,
TemporalError,
FailureError,
):
Expand Down
2 changes: 1 addition & 1 deletion agents-api/agents_api/common/nlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def extract_keywords(text: str, top_n: int = 10, clean: bool = True) -> list[str
combined = entities + nouns

# Normalize and count frequency
normalized = [re.sub(r"\s+", " ", kw).strip().lower() for kw in combined]
normalized = [re.sub(r"\s+", " ", kw).strip() for kw in combined]
freq = Counter(normalized)

# Get top_n keywords
Expand Down
10 changes: 7 additions & 3 deletions agents-api/agents_api/common/protocol/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,9 @@ def __save_item(self, item: Any) -> Any:

return store_in_blob_store_if_large(item)

def __getitem__(self, index: int | slice) -> Any:
def __getitem__(
self, index: int | slice
) -> Any: # pytype: disable=signature-mismatch
if isinstance(index, slice):
# Obtain the slice without triggering __getitem__ recursively
sliced_items = super().__getitem__(
Expand Down Expand Up @@ -162,7 +164,9 @@ def _extend_without_processing(self, items: list[Any]) -> None:
"""
super().extend(items)

def __setitem__(self, index: int | slice, value: Any) -> None:
def __setitem__(
self, index: int | slice, value: Any
) -> None: # pytype: disable=signature-mismatch
if isinstance(index, slice):
# Handle slice assignment without processing existing RemoteObjects
processed_values = [self.__save_item(v) for v in value]
Expand Down Expand Up @@ -231,7 +235,7 @@ def extend(self, iterable: list[Any]) -> None:
for item in iterable:
self.append(item)

def __iter__(self) -> Iterator[Any]:
def __iter__(self) -> Iterator[Any]: # pytype: disable=signature-mismatch
for index in range(len(self)):
yield self.__getitem__(index)

Expand Down
12 changes: 9 additions & 3 deletions agents-api/agents_api/common/storage_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@


def store_in_blob_store_if_large(x: Any) -> RemoteObject | Any:
if not use_blob_store_for_temporal:
return x

s3.setup()

serialized = serialize(x)
Expand All @@ -28,6 +31,9 @@ def store_in_blob_store_if_large(x: Any) -> RemoteObject | Any:


def load_from_blob_store_if_remote(x: Any | RemoteObject) -> Any:
if not use_blob_store_for_temporal:
return x

s3.setup()

if isinstance(x, RemoteObject):
Expand All @@ -45,8 +51,8 @@ def load_from_blob_store_if_remote(x: Any | RemoteObject) -> Any:
def auto_blob_store(f: Callable | None = None, *, deep: bool = False) -> Callable:
def auto_blob_store_decorator(f: Callable) -> Callable:
def load_args(
args: list[Any], kwargs: dict[str, Any]
) -> tuple[list[Any], dict[str, Any]]:
args: list | tuple, kwargs: dict[str, Any]
) -> tuple[list | tuple, dict[str, Any]]:
new_args = [load_from_blob_store_if_remote(arg) for arg in args]
new_kwargs = {
k: load_from_blob_store_if_remote(v) for k, v in kwargs.items()
Expand Down Expand Up @@ -143,4 +149,4 @@ async def wrapper(*args, **kwargs) -> Any:

return result

return wrapper
return wrapper if use_blob_store_for_temporal else f
6 changes: 6 additions & 0 deletions agents-api/agents_api/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,12 @@
temporal_worker_url=temporal_worker_url,
temporal_namespace=temporal_namespace,
embedding_model_id=embedding_model_id,
use_blob_store_for_temporal=use_blob_store_for_temporal,
blob_store_bucket=blob_store_bucket,
blob_store_cutoff_kb=blob_store_cutoff_kb,
s3_endpoint=s3_endpoint,
s3_access_key=s3_access_key,
s3_secret_key=s3_secret_key,
testing=testing,
)

Expand Down
6 changes: 5 additions & 1 deletion agents-api/agents_api/models/docs/get_doc.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,11 @@
one=True,
transform=lambda d: {
"content": [s[1] for s in sorted(d["snippet_data"], key=lambda x: x[0])],
"embeddings": [s[2] for s in sorted(d["snippet_data"], key=lambda x: x[0])],
"embeddings": [
s[2]
for s in sorted(d["snippet_data"], key=lambda x: x[0])
if s[2] is not None
],
**d,
},
)
Expand Down
6 changes: 5 additions & 1 deletion agents-api/agents_api/models/docs/list_docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,11 @@
Doc,
transform=lambda d: {
"content": [s[1] for s in sorted(d["snippet_data"], key=lambda x: x[0])],
"embeddings": [s[2] for s in sorted(d["snippet_data"], key=lambda x: x[0])],
"embeddings": [
s[2]
for s in sorted(d["snippet_data"], key=lambda x: x[0])
if s[2] is not None
],
**d,
},
)
Expand Down
6 changes: 4 additions & 2 deletions agents-api/agents_api/models/docs/search_docs_by_text.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""This module contains functions for searching documents in the CozoDB based on embedding queries."""

import re
from typing import Any, Literal, TypeVar
from uuid import UUID

Expand Down Expand Up @@ -62,9 +63,10 @@ def search_docs_by_text(
[owner_type, str(owner_id)] for owner_type, owner_id in owners
]

# Need to use NEAR/3($query) to search for arbitrary text within 3 words of each other
# See: https://docs.cozodb.org/en/latest/vector.html#full-text-search-fts
fts_queries = paragraph_to_custom_queries(query)
fts_queries = paragraph_to_custom_queries(query) or [
re.sub(r"[^\w\s\-_]+", "", query)
]

# Construct the datalog query for searching document snippets
search_query = f"""
Expand Down
2 changes: 1 addition & 1 deletion agents-api/agents_api/models/execution/get_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
transform=lambda d: {
**d,
"output": d["output"][OUTPUT_UNNEST_KEY]
if OUTPUT_UNNEST_KEY in d["output"]
if isinstance(d["output"], dict) and OUTPUT_UNNEST_KEY in d["output"]
else d["output"],
},
)
Expand Down
6 changes: 4 additions & 2 deletions agents-api/agents_api/models/execution/list_executions.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@
transform=lambda d: {
**d,
"output": d["output"][OUTPUT_UNNEST_KEY]
if OUTPUT_UNNEST_KEY in d["output"]
else d["output"],
if isinstance(d.get("output"), dict) and OUTPUT_UNNEST_KEY in d["output"]
else d.get("output"),
},
)
@cozo_query
Expand All @@ -58,6 +58,7 @@ def list_executions(
task_id,
status,
input,
output,
session_id,
metadata,
created_at,
Expand All @@ -68,6 +69,7 @@ def list_executions(
execution_id: id,
status,
input,
output,
session_id,
metadata,
created_at,
Expand Down
6 changes: 3 additions & 3 deletions agents-api/agents_api/models/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from pydantic import BaseModel

from ..common.utils.cozo import uuid_int_list_to_uuid4
from ..env import debug, do_verify_developer, do_verify_developer_owns_resource
from ..env import do_verify_developer, do_verify_developer_owns_resource

P = ParamSpec("P")
T = TypeVar("T")
Expand Down Expand Up @@ -185,8 +185,8 @@ def make_cozo_json_query(fields):

def cozo_query(
func: Callable[P, tuple[str | list[str | None], dict]] | None = None,
debug: bool | None = debug,
only_on_error: bool = True,
debug: bool | None = None,
only_on_error: bool = False,
):
def cozo_query_dec(func: Callable[P, tuple[str | list[Any], dict]]):
"""
Expand Down
4 changes: 0 additions & 4 deletions agents-api/agents_api/routers/tasks/create_or_update_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,6 @@ async def create_or_update_task(
# TODO: Do thorough validation of the task spec
# SCRUM-10

# FIXME: There is also some subtle bug here that prevents us from
# starting executions from tasks created via this endpoint
# SCRUM-9

# Validate the input schema
try:
if data.input_schema is not None:
Expand Down
1 change: 0 additions & 1 deletion agents-api/agents_api/routers/tasks/create_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ async def create_task(
) -> ResourceCreatedResponse:
# TODO: Do thorough validation of the task spec
# SCRUM-10
# TODO: Validate the jinja templates

# Validate the input schema
try:
Expand Down
2 changes: 1 addition & 1 deletion agents-api/agents_api/routers/tasks/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ async def body(self) -> bytes:
"application/yaml",
"text/yaml",
]:
body = yaml.load(body, yaml.CSafeLoader)
body = yaml.load(body)
creatorrr marked this conversation as resolved.
Show resolved Hide resolved

self._body = body

Expand Down
Loading
Loading