Skip to content

Commit

Permalink
feat(agents-api): Add jinja templates support (#300)
Browse files Browse the repository at this point in the history
* feat(agents-api): Add 'render_templates' field to sessions relation

Signed-off-by: Diwank Singh Tomer <[email protected]>

* feat(agents-api): Add jinja env

Signed-off-by: Diwank Singh Tomer <[email protected]>

* feat(agents-api): Add jinja templates support

Signed-off-by: Diwank Singh Tomer <[email protected]>

---------

Signed-off-by: Diwank Singh Tomer <[email protected]>
  • Loading branch information
creatorrr authored Apr 26, 2024
1 parent de291e1 commit 2850cf4
Show file tree
Hide file tree
Showing 16 changed files with 364 additions and 156 deletions.
4 changes: 4 additions & 0 deletions agents-api/agents_api/common/protocol/sessions.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,7 @@ class SessionData(BaseModel):
created_at: float
model: str
default_settings: SessionSettings
render_templates: bool = False
metadata: dict = {}
user_metadata: dict = {}
agent_metadata: dict = {}
38 changes: 38 additions & 0 deletions agents-api/agents_api/common/utils/template.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import arrow
from jinja2.sandbox import ImmutableSandboxedEnvironment
from jinja2schema import infer, to_json_schema
from jsonschema import validate

__all__ = [
"render_template",
]

# jinja environment
jinja_env = ImmutableSandboxedEnvironment(
autoescape=False,
trim_blocks=True,
lstrip_blocks=True,
auto_reload=False,
enable_async=True,
loader=None,
)

# Add arrow to jinja
jinja_env.globals["arrow"] = arrow


# Funcs
async def render_template(
template_string: str, variables: dict, check: bool = False
) -> str:
# Parse template
template = jinja_env.from_string(template_string)

# If check is required, get required vars from template and validate variables
if check:
schema = to_json_schema(infer(template_string))
validate(instance=variables, schema=schema)

# Render
rendered = await template.render_async(**variables)
return rendered
7 changes: 6 additions & 1 deletion agents-api/agents_api/models/session/create_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ def create_session_query(
user_id: UUID | None,
situation: str | None,
metadata: dict = {},
render_templates: bool = False,
) -> tuple[str, dict]:
"""
Constructs and executes a datalog query to create a new session in the database.
Expand All @@ -28,6 +29,7 @@ def create_session_query(
- user_id (UUID | None): The unique identifier for the user, if applicable.
- situation (str | None): The situation/context of the session.
- metadata (dict): Additional metadata for the session.
- render_templates (bool): Specifies whether to render templates.
Returns:
- pd.DataFrame: The result of the query execution.
Expand All @@ -52,18 +54,20 @@ def create_session_query(
}
} {
# Insert the new session data into the 'session' table with the specified columns.
?[session_id, developer_id, situation, metadata] <- [[
?[session_id, developer_id, situation, metadata, render_templates] <- [[
$session_id,
$developer_id,
$situation,
$metadata,
$render_templates,
]]
:insert sessions {
developer_id,
session_id,
situation,
metadata,
render_templates,
}
# Specify the data to return after the query execution, typically the newly created session's ID.
:returning
Expand All @@ -79,5 +83,6 @@ def create_session_query(
"developer_id": str(developer_id),
"situation": situation,
"metadata": metadata,
"render_templates": render_templates,
},
)
2 changes: 2 additions & 0 deletions agents-api/agents_api/models/session/get_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def get_session_query(
updated_at,
created_at,
metadata,
render_templates,
] := input[developer_id, id],
*sessions{
developer_id,
Expand All @@ -49,6 +50,7 @@ def get_session_query(
created_at,
updated_at: validity,
metadata,
render_templates,
@ "NOW"
},
*session_lookup{
Expand Down
4 changes: 4 additions & 0 deletions agents-api/agents_api/models/session/patch_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@
]


# TODO: Add support for updating `render_templates` field


@cozo_query
def patch_session_query(
session_id: UUID,
Expand All @@ -41,6 +44,7 @@ def patch_session_query(
},
session_id = to_uuid($session_id),
developer_id = to_uuid($developer_id),
# Assertion to ensure the session exists before updating.
:assert some
"""
Expand Down
14 changes: 8 additions & 6 deletions agents-api/agents_api/models/session/session_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,10 @@ def session_data_query(
agent_about,
model,
default_settings,
session_metadata,
users_metadata,
agents_metadata,
metadata,
render_templates,
user_metadata,
agent_metadata,
] := input[developer_id, session_id],
*sessions{
developer_id,
Expand All @@ -53,7 +54,8 @@ def session_data_query(
summary,
created_at,
updated_at: validity,
metadata: session_metadata,
metadata,
render_templates,
@ "NOW"
},
*session_lookup{
Expand All @@ -65,14 +67,14 @@ def session_data_query(
user_id,
name: user_name,
about: user_about,
metadata: users_metadata,
metadata: user_metadata,
},
*agents{
agent_id,
name: agent_name,
about: agent_about,
model,
metadata: agents_metadata,
metadata: agent_metadata,
},
*agent_default_settings {
agent_id,
Expand Down
2 changes: 2 additions & 0 deletions agents-api/agents_api/models/session/update_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
"developer_id",
]

# TODO: Add support for updating `render_templates` field


@cozo_query
def update_session_query(
Expand Down
1 change: 1 addition & 0 deletions agents-api/agents_api/routers/sessions/routers.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ async def create_session(
user_id=request.user_id,
situation=request.situation,
metadata=request.metadata or {},
render_templates=request.render_templates or False,
)

return ResourceCreatedResponse(
Expand Down
67 changes: 52 additions & 15 deletions agents-api/agents_api/routers/sessions/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,24 +8,26 @@
from openai.types.chat.chat_completion import ChatCompletion
from pydantic import UUID4

from agents_api.clients.embed import embed
from agents_api.env import summarization_tokens_threshold
from agents_api.clients.temporal import run_summarization_task
from agents_api.models.entry.add_entries import add_entries_query
from agents_api.common.protocol.entries import Entry
from agents_api.common.exceptions.sessions import SessionNotFoundError
from agents_api.clients.worker.types import ChatML
from agents_api.models.session.session_data import get_session_data
from agents_api.models.entry.proc_mem_context import proc_mem_context_query
from agents_api.autogen.openapi_model import InputChatMLMessage, Tool
from agents_api.model_registry import (
from ...autogen.openapi_model import InputChatMLMessage, Tool
from ...clients.embed import embed
from ...clients.temporal import run_summarization_task
from ...clients.worker.types import ChatML
from ...common.exceptions.sessions import SessionNotFoundError
from ...common.protocol.entries import Entry
from ...common.protocol.sessions import SessionData
from ...common.utils.template import render_template
from ...env import summarization_tokens_threshold
from ...model_registry import (
get_extra_settings,
get_model_client,
load_context,
)
from ...common.protocol.sessions import SessionData
from .protocol import Settings
from ...models.entry.add_entries import add_entries_query
from ...models.entry.proc_mem_context import proc_mem_context_query
from ...models.session.session_data import get_session_data

from .exceptions import InputTooBigError
from .protocol import Settings


THOUGHTS_STRIP_LEN = 2
Expand Down Expand Up @@ -118,18 +120,22 @@ async def run(
self, new_input, settings: Settings
) -> tuple[ChatCompletion, Entry, Callable | None]:
# TODO: implement locking at some point

# Get session data
session_data = get_session_data(self.developer_id, self.session_id)
if session_data is None:
raise SessionNotFoundError(self.developer_id, self.session_id)

# Assemble context
init_context, final_settings = await self.forward(
session_data, new_input, settings
)

# Generate response
response = await self.generate(
self.truncate(init_context, summarization_tokens_threshold), final_settings
)

# Save response to session
# if final_settings.get("remember"):
# await self.add_to_session(new_input, response)
Expand Down Expand Up @@ -195,10 +201,11 @@ async def forward(
)

entries: list[Entry] = []
instructions = "IMPORTANT INSTRUCTIONS:\n\n"
instructions = "Instructions:\n\n"
first_instruction_idx = -1
first_instruction_created_at = 0
tools = []

for idx, row in proc_mem_context_query(
session_id=self.session_id,
tool_query_embedding=tool_query_embedding,
Expand All @@ -224,7 +231,7 @@ async def forward(
first_instruction_idx = idx
first_instruction_created_at = row["created_at"]

instructions += f"- {row['content']}\n"
instructions += f"{row['content']}\n\n"

continue

Expand Down Expand Up @@ -266,6 +273,36 @@ async def forward(
if e.content
]

# If render_templates=True, render the templates
if session_data is not None and session_data.render_templates:

template_data = {
"session": {
"id": session_data.session_id,
"situation": session_data.situation,
"metadata": session_data.metadata,
},
"user": {
"id": session_data.user_id,
"name": session_data.user_name,
"about": session_data.user_about,
"metadata": session_data.user_metadata,
},
"agent": {
"id": session_data.agent_id,
"name": session_data.agent_name,
"about": session_data.agent_about,
"metadata": session_data.agent_metadata,
},
}

for i, msg in enumerate(messages):
# Only render templates for system/assistant messages
if msg.role not in ["system", "assistant"]:
continue

messages[i].content = await render_template(msg.content, template_data)

# FIXME: This sometimes returns "The model `` does not exist."
if session_data is not None:
settings.model = session_data.model
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -206,13 +206,7 @@ def run(client, *queries):
query = joiner.join(queries)
query = f"{{\n{query}\n}}"

try:
client.run(query)
except Exception as error:
print(error)
import pdb

pdb.set_trace()
client.run(query)


def up(client):
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
# /usr/bin/env python3

MIGRATION_ID = "session_render_templates"
CREATED_AT = 1714119679.493182

extend_sessions = {
"up": """
?[render_templates, developer_id, session_id, updated_at, situation, summary, created_at, developer_id, metadata] := *sessions{
session_id,
updated_at,
situation,
summary,
created_at,
developer_id
},
metadata = {},
render_templates = false
:replace sessions {
developer_id: Uuid,
session_id: Uuid,
updated_at: Validity default [floor(now()), true],
=>
situation: String,
summary: String? default null,
created_at: Float default now(),
metadata: Json default {},
render_templates: Bool default false,
}
""",
"down": """
?[developer_id, session_id, updated_at, situation, summary, created_at, developer_id, metadata] := *sessions{
session_id,
updated_at,
situation,
summary,
created_at,
developer_id
}, metadata = {}
:replace sessions {
developer_id: Uuid,
session_id: Uuid,
updated_at: Validity default [floor(now()), true],
=>
situation: String,
summary: String? default null,
created_at: Float default now(),
metadata: Json default {},
}
""",
}


queries_to_run = [
extend_sessions,
]


def up(client):
for q in queries_to_run:
client.run(q["up"])


def down(client):
for q in reversed(queries_to_run):
client.run(q["down"])
Loading

0 comments on commit 2850cf4

Please sign in to comment.