Skip to content

Commit

Permalink
feat(agents-api): Add jinja templates support
Browse files Browse the repository at this point in the history
Signed-off-by: Diwank Singh Tomer <[email protected]>
  • Loading branch information
creatorrr committed Apr 26, 2024
1 parent 5922144 commit 3e6ccac
Show file tree
Hide file tree
Showing 16 changed files with 263 additions and 167 deletions.
4 changes: 4 additions & 0 deletions agents-api/agents_api/common/protocol/sessions.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,7 @@ class SessionData(BaseModel):
created_at: float
model: str
default_settings: SessionSettings
render_templates: bool = False
metadata: dict = {}
user_metadata: dict = {}
agent_metadata: dict = {}
12 changes: 2 additions & 10 deletions agents-api/agents_api/common/utils/template.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,14 @@
# flake8: noqa
# This is so ruff doesn't remove * imports

import inspect

import arrow
from jinja2 import Environment, ImmutableSandboxedEnvironment
from jinja2.sandbox import ImmutableSandboxedEnvironment
from jinja2schema import infer, to_json_schema
from jsonschema import validate

from .args import get_fn_signature
from .lang import inflect

__all__ = [
"render_template",
]

# jinja environment
jinja_env: Environment = ImmutableSandboxedEnvironment(
jinja_env = ImmutableSandboxedEnvironment(
autoescape=False,
trim_blocks=True,
lstrip_blocks=True,
Expand Down
7 changes: 6 additions & 1 deletion agents-api/agents_api/models/session/create_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ def create_session_query(
user_id: UUID | None,
situation: str | None,
metadata: dict = {},
render_templates: bool = False,
) -> tuple[str, dict]:
"""
Constructs and executes a datalog query to create a new session in the database.
Expand All @@ -28,6 +29,7 @@ def create_session_query(
- user_id (UUID | None): The unique identifier for the user, if applicable.
- situation (str | None): The situation/context of the session.
- metadata (dict): Additional metadata for the session.
- render_templates (bool): Specifies whether to render templates.
Returns:
- pd.DataFrame: The result of the query execution.
Expand All @@ -52,18 +54,20 @@ def create_session_query(
}
} {
# Insert the new session data into the 'session' table with the specified columns.
?[session_id, developer_id, situation, metadata] <- [[
?[session_id, developer_id, situation, metadata, render_templates] <- [[
$session_id,
$developer_id,
$situation,
$metadata,
$render_templates,
]]
:insert sessions {
developer_id,
session_id,
situation,
metadata,
render_templates,
}
# Specify the data to return after the query execution, typically the newly created session's ID.
:returning
Expand All @@ -79,5 +83,6 @@ def create_session_query(
"developer_id": str(developer_id),
"situation": situation,
"metadata": metadata,
"render_templates": render_templates,
},
)
2 changes: 2 additions & 0 deletions agents-api/agents_api/models/session/get_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def get_session_query(
updated_at,
created_at,
metadata,
render_templates,
] := input[developer_id, id],
*sessions{
developer_id,
Expand All @@ -49,6 +50,7 @@ def get_session_query(
created_at,
updated_at: validity,
metadata,
render_templates,
@ "NOW"
},
*session_lookup{
Expand Down
4 changes: 4 additions & 0 deletions agents-api/agents_api/models/session/patch_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@
]


# TODO: Add support for updating `render_templates` field


@cozo_query
def patch_session_query(
session_id: UUID,
Expand All @@ -41,6 +44,7 @@ def patch_session_query(
},
session_id = to_uuid($session_id),
developer_id = to_uuid($developer_id),
# Assertion to ensure the session exists before updating.
:assert some
"""
Expand Down
14 changes: 8 additions & 6 deletions agents-api/agents_api/models/session/session_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,10 @@ def session_data_query(
agent_about,
model,
default_settings,
session_metadata,
users_metadata,
agents_metadata,
metadata,
render_templates,
user_metadata,
agent_metadata,
] := input[developer_id, session_id],
*sessions{
developer_id,
Expand All @@ -53,7 +54,8 @@ def session_data_query(
summary,
created_at,
updated_at: validity,
metadata: session_metadata,
metadata,
render_templates,
@ "NOW"
},
*session_lookup{
Expand All @@ -65,14 +67,14 @@ def session_data_query(
user_id,
name: user_name,
about: user_about,
metadata: users_metadata,
metadata: user_metadata,
},
*agents{
agent_id,
name: agent_name,
about: agent_about,
model,
metadata: agents_metadata,
metadata: agent_metadata,
},
*agent_default_settings {
agent_id,
Expand Down
2 changes: 2 additions & 0 deletions agents-api/agents_api/models/session/update_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
"developer_id",
]

# TODO: Add support for updating `render_templates` field


@cozo_query
def update_session_query(
Expand Down
1 change: 1 addition & 0 deletions agents-api/agents_api/routers/sessions/routers.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ async def create_session(
user_id=request.user_id,
situation=request.situation,
metadata=request.metadata or {},
render_templates=request.render_templates or False,
)

return ResourceCreatedResponse(
Expand Down
67 changes: 52 additions & 15 deletions agents-api/agents_api/routers/sessions/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,24 +8,26 @@
from openai.types.chat.chat_completion import ChatCompletion
from pydantic import UUID4

from agents_api.clients.embed import embed
from agents_api.env import summarization_tokens_threshold
from agents_api.clients.temporal import run_summarization_task
from agents_api.models.entry.add_entries import add_entries_query
from agents_api.common.protocol.entries import Entry
from agents_api.common.exceptions.sessions import SessionNotFoundError
from agents_api.clients.worker.types import ChatML
from agents_api.models.session.session_data import get_session_data
from agents_api.models.entry.proc_mem_context import proc_mem_context_query
from agents_api.autogen.openapi_model import InputChatMLMessage, Tool
from agents_api.model_registry import (
from ...autogen.openapi_model import InputChatMLMessage, Tool
from ...clients.embed import embed
from ...clients.temporal import run_summarization_task
from ...clients.worker.types import ChatML
from ...common.exceptions.sessions import SessionNotFoundError
from ...common.protocol.entries import Entry
from ...common.protocol.sessions import SessionData
from ...common.utils.template import render_template
from ...env import summarization_tokens_threshold
from ...model_registry import (
get_extra_settings,
get_model_client,
load_context,
)
from ...common.protocol.sessions import SessionData
from .protocol import Settings
from ...models.entry.add_entries import add_entries_query
from ...models.entry.proc_mem_context import proc_mem_context_query
from ...models.session.session_data import get_session_data

from .exceptions import InputTooBigError
from .protocol import Settings


THOUGHTS_STRIP_LEN = 2
Expand Down Expand Up @@ -118,18 +120,22 @@ async def run(
self, new_input, settings: Settings
) -> tuple[ChatCompletion, Entry, Callable | None]:
# TODO: implement locking at some point

# Get session data
session_data = get_session_data(self.developer_id, self.session_id)
if session_data is None:
raise SessionNotFoundError(self.developer_id, self.session_id)

# Assemble context
init_context, final_settings = await self.forward(
session_data, new_input, settings
)

# Generate response
response = await self.generate(
self.truncate(init_context, summarization_tokens_threshold), final_settings
)

# Save response to session
# if final_settings.get("remember"):
# await self.add_to_session(new_input, response)
Expand Down Expand Up @@ -195,10 +201,11 @@ async def forward(
)

entries: list[Entry] = []
instructions = "IMPORTANT INSTRUCTIONS:\n\n"
instructions = "Instructions:\n\n"
first_instruction_idx = -1
first_instruction_created_at = 0
tools = []

for idx, row in proc_mem_context_query(
session_id=self.session_id,
tool_query_embedding=tool_query_embedding,
Expand All @@ -224,7 +231,7 @@ async def forward(
first_instruction_idx = idx
first_instruction_created_at = row["created_at"]

instructions += f"- {row['content']}\n"
instructions += f"{row['content']}\n\n"

continue

Expand Down Expand Up @@ -266,6 +273,36 @@ async def forward(
if e.content
]

# If render_templates=True, render the templates
if session_data is not None and session_data.render_templates:

template_data = {
"session": {
"id": session_data.session_id,
"situation": session_data.situation,
"metadata": session_data.metadata,
},
"user": {
"id": session_data.user_id,
"name": session_data.user_name,
"about": session_data.user_about,
"metadata": session_data.user_metadata,
},
"agent": {
"id": session_data.agent_id,
"name": session_data.agent_name,
"about": session_data.agent_about,
"metadata": session_data.agent_metadata,
},
}

for i, msg in enumerate(messages):
# Only render templates for system/assistant messages
if msg.role not in ["system", "assistant"]:
continue

messages[i].content = await render_template(msg.content, template_data)

# FIXME: This sometimes returns "The model `` does not exist."
if session_data is not None:
settings.model = session_data.model
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -206,13 +206,7 @@ def run(client, *queries):
query = joiner.join(queries)
query = f"{{\n{query}\n}}"

try:
client.run(query)
except Exception as error:
print(error)
import pdb

pdb.set_trace()
client.run(query)


def up(client):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
summary,
created_at,
developer_id
}, metadata = {},
},
metadata = {},
render_templates = false
:replace sessions {
Expand Down
5 changes: 5 additions & 0 deletions sdks/python/julep/managers/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,15 @@ class SessionCreateArgs(TypedDict):
agent_id: Union[str, UUID]
situation: Optional[str] = None
metadata: Dict[str, Any] = {}
render_templates: bool = False


class SessionUpdateArgs(TypedDict):
session_id: Union[str, UUID]
situation: Optional[str] = None
metadata: Optional[Dict[str, Any]] = None
overwrite: bool = False
render_templates: bool = False


class BaseSessionsManager(BaseManager):
Expand Down Expand Up @@ -179,6 +181,7 @@ def _create(
user_id: Optional[Union[str, UUID]] = None,
situation: Optional[str] = None,
metadata: Dict[str, Any] = {},
render_templates: bool = False,
) -> Union[ResourceCreatedResponse, Awaitable[ResourceCreatedResponse]]:
# Cast instructions to a list of Instruction objects
"""
Expand All @@ -191,6 +194,7 @@ def _create(
user_id (Optional[Union[str, UUID]]): The user's identifier which could be a string or a UUID object.
situation (Optional[str], optional): An optional description of the situation.
metadata (Dict[str, Any])
render_templates (bool, optional): Whether to render templates in the metadata. Defaults to False.
Returns:
Union[ResourceCreatedResponse, Awaitable[ResourceCreatedResponse]]: The response from the API client upon successful session creation, which can be a synchronous `ResourceCreatedResponse` or an asynchronous `Awaitable` of it.
Expand All @@ -208,6 +212,7 @@ def _create(
agent_id=agent_id,
situation=situation,
metadata=metadata,
render_templates=render_templates,
)

def _list_items(
Expand Down
21 changes: 21 additions & 0 deletions sdks/python/tests/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,12 @@
"metadata": {"type": "test"},
}

mock_session_with_template = {
"situation": "Say 'hello {{ session.metadata.arg }}'",
"metadata": {"type": "test", "arg": "banana"},
"render_templates": True,
}

mock_session_update = {
"situation": "updated situation",
"metadata": {"type": "test"},
Expand Down Expand Up @@ -189,6 +195,21 @@ def test_session(client=client, user=test_user, agent=test_agent) -> Session:
client.sessions.delete(session.id)


@fixture
def test_session_with_template(
client=client, user=test_user, agent=test_agent
) -> Session:
session = client.sessions.create(
user_id=user.id,
agent_id=agent.id,
**mock_session_with_template,
)

yield session

client.sessions.delete(session.id)


@fixture
def test_session_agent_user(client=client, user=test_user, agent=test_agent) -> Session:
session = client.sessions.create(
Expand Down
Loading

0 comments on commit 3e6ccac

Please sign in to comment.