Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Cache generated responses #336

Merged
merged 5 commits into from
May 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion agents-api/agents_api/clients/worker/types.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from typing import Callable, Literal, Optional, Protocol
from uuid import UUID

from pydantic import BaseModel


Expand Down
11 changes: 11 additions & 0 deletions agents-api/agents_api/models/session/get_cached_response.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from ..utils import cozo_query


@cozo_query
def get_cached_response(key: str) -> tuple[str, dict]:
    """Build the query that looks up a cached response by its key.

    Args:
        key: Cache key to look up in the ``session_cache`` relation.

    Returns:
        A ``(query, params)`` tuple: the query text and the ``{"key": key}``
        bindings for its ``$key`` placeholder.

    NOTE(review): the ``cozo_query`` decorator is defined in ``..utils`` and
    is presumably what executes the query; confirm the shape of the value
    callers actually receive.
    """
    query = """
        input[key] <- [[$key]]
        ?[key, value] := input[key], *session_cache{key, value}
    """

    return (query, {"key": key})
16 changes: 16 additions & 0 deletions agents-api/agents_api/models/session/set_cached_response.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from ..utils import cozo_query


@cozo_query
def set_cached_response(key: str, value: dict) -> tuple[str, dict]:
    """Build the query that stores a response in ``session_cache``.

    Args:
        key: Cache key under which the response is stored.
        value: JSON-serializable response payload.

    Returns:
        A ``(query, params)`` tuple: the query text and the bindings for its
        ``$key`` and ``$value`` placeholders.
    """
    # `:put` (upsert) rather than `:insert`: `:insert` raises when the key is
    # already present, which would fail the second of two identical requests
    # that both missed the cache and then both try to store their result.
    query = """
        ?[key, value] <- [[$key, $value]]

        :put session_cache {
            key => value
        }

        :returning
    """

    return (query, {"key": key, "value": value})
30 changes: 30 additions & 0 deletions agents-api/agents_api/routers/sessions/session.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json
import xxhash
from functools import reduce
from json import JSONDecodeError
from typing import Callable
Expand All @@ -19,6 +20,7 @@
from ...common.protocol.entries import Entry
from ...common.protocol.sessions import SessionData
from ...common.utils.template import render_template
from ...common.utils.json import CustomJSONEncoder
from ...env import (
summarization_tokens_threshold,
docs_embedding_service_url,
Expand All @@ -32,6 +34,8 @@
from ...models.entry.add_entries import add_entries_query
from ...models.entry.proc_mem_context import proc_mem_context_query
from ...models.session.session_data import get_session_data
from ...models.session.get_cached_response import get_cached_response
from ...models.session.set_cached_response import set_cached_response

from .exceptions import InputTooBigError
from .protocol import Settings
Expand All @@ -52,6 +56,30 @@
)


def cache(f):
    """Wrap an async ``(self, init_context, settings)`` method with a cache.

    The cache key is the xxh64 hex digest of the JSON serialization of
    ``init_context`` and ``settings``. On a hit, the stored value is used to
    rebuild a ``ChatCompletion``. On a miss, the wrapped coroutine is awaited,
    its ``model_dump()`` is stored under the key, and the fresh response is
    returned.

    Args:
        f: The coroutine function to wrap.

    Returns:
        The caching wrapper, carrying ``f``'s name and docstring.
    """
    # Local import: the file-level functools import only brings in `reduce`.
    from functools import wraps

    @wraps(f)
    async def wrapper(
        self, init_context: list[ChatML], settings: Settings
    ) -> ChatCompletion:
        payload = json.dumps(
            {
                "init_context": [c.model_dump() for c in init_context],
                "settings": settings.model_dump(),
            },
            cls=CustomJSONEncoder,
            default_empty_value="",
        )
        key = xxhash.xxh64(payload).hexdigest()

        # NOTE(review): both cache helpers are called without `await`; if they
        # perform blocking I/O they will stall the event loop -- confirm.
        result = get_cached_response(key=key)
        if not result.size:
            # Cache miss: generate, store, and return the fresh response.
            resp = await f(self, init_context, settings)
            set_cached_response(key=key, value=resp.model_dump())
            return resp

        # Cache hit: the first row's "value" column holds the dumped response.
        choices = result.iloc[0].to_dict()["value"]
        return ChatCompletion(**choices)

    return wrapper


@dataclass
class BaseSession:
session_id: UUID4
Expand Down Expand Up @@ -333,6 +361,7 @@ async def forward(

return messages, settings, doc_ids

@cache
async def generate(
self, init_context: list[ChatML], settings: Settings
) -> ChatCompletion:
Expand Down Expand Up @@ -370,6 +399,7 @@ async def generate(
api_key=api_key,
**extra_body,
)

alt-glitch marked this conversation as resolved.
Show resolved Hide resolved
return res

async def backward(
Expand Down
33 changes: 33 additions & 0 deletions agents-api/migrations/migrate_1716013793_session_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
#!/usr/bin/env python3

MIGRATION_ID = "session_cache"
CREATED_AT = 1716013793.746602


# Migration queries for the response cache: "up" creates the `session_cache`
# stored relation (String key => Json value); "down" removes it.
session_cache = {
    "up": """
:create session_cache {
key: String,
=>
value: Json,
}
""",
    "down": """
::remove session_cache
""",
}


# Ordered list of migration steps; `up` walks it forwards, `down` backwards.
queries_to_run = [session_cache]


def up(client):
    """Run the "up" query of every entry in ``queries_to_run``, in order.

    Args:
        client: Object exposing a ``run(query)`` method.
    """
    for migration in queries_to_run:
        client.run(migration["up"])


def down(client):
    """Run the "down" query of every entry in ``queries_to_run``, last first.

    Args:
        client: Object exposing a ``run(query)`` method.
    """
    for migration in reversed(queries_to_run):
        client.run(migration["down"])
Loading
Loading