Commit 74f00c3

feat: sync chat_ctx for openai RealtimeModel from and to the remote realtime session (#1015)

longcw authored Nov 12, 2024
1 parent ba74250 commit 74f00c3

Showing 13 changed files with 884 additions and 35 deletions.
6 changes: 6 additions & 0 deletions .changeset/smooth-moons-joke.md
@@ -0,0 +1,6 @@
---
"livekit-plugins-openai": minor
"livekit-agents": minor
---

sync the Realtime API conversation items and add set_chat_ctx
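
In practice (a hedged sketch of the surface this changeset describes, matching the diffs below), the agent keeps its local chat context and the remote Realtime conversation in sync, and exposes `chat_ctx_copy()` / `set_chat_ctx()` to read and replace that history:

```python
# Sketch only, based on the methods added in this commit; `agent` is assumed to be
# a started multimodal.MultimodalAgent (see examples/multimodal_agent.py below).
async def trim_history(agent) -> None:
    ctx = agent.chat_ctx_copy()        # snapshot of the context synced from the remote session
    ctx.messages = ctx.messages[-10:]  # edit it locally, e.g. keep only the last 10 items
    await agent.set_chat_ctx(ctx)      # replicate the edited context back to the Realtime API
```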
25 changes: 25 additions & 0 deletions examples/multimodal_agent.py
@@ -1,5 +1,6 @@
from __future__ import annotations

import asyncio
import logging
from typing import Annotated

@@ -67,6 +68,19 @@ async def get_weather(
# fnc_ctx=fnc_ctx,
# )

# create a chat context with chat history
chat_ctx = llm.ChatContext()
chat_ctx.append(text="I'm planning a trip to Paris next month.", role="user")
chat_ctx.append(
text="How exciting! Paris is a beautiful city. I'd be happy to suggest some must-visit places and help you plan your trip.",
role="assistant",
)
chat_ctx.append(text="What are the must-visit places in Paris?", role="user")
chat_ctx.append(
text="The must-visit places in Paris are the Eiffel Tower, Louvre Museum, Notre-Dame Cathedral, and Montmartre.",
role="assistant",
)

agent = multimodal.MultimodalAgent(
model=openai.realtime.RealtimeModel(
voice="alloy",
@@ -77,9 +91,20 @@ async def get_weather(
),
),
fnc_ctx=fnc_ctx,
chat_ctx=chat_ctx,
)
agent.start(ctx.room, participant)

@agent.on("agent_speech_committed")
@agent.on("agent_speech_interrupted")
def _on_agent_speech_created(msg: llm.ChatMessage):
# example of truncating the chat context
max_ctx_len = 10
chat_ctx = agent.chat_ctx_copy()
if len(chat_ctx.messages) > max_ctx_len:
chat_ctx.messages = chat_ctx.messages[-max_ctx_len:]
asyncio.create_task(agent.set_chat_ctx(chat_ctx))


if __name__ == "__main__":
cli.run_app(WorkerOptions(entrypoint_fnc=entrypoint, worker_type=WorkerType.ROOM))
10 changes: 9 additions & 1 deletion livekit-agents/livekit/agents/llm/__init__.py
@@ -1,5 +1,12 @@
from . import _oai_api
from .chat_context import ChatAudio, ChatContext, ChatImage, ChatMessage, ChatRole
from .chat_context import (
ChatAudio,
ChatContent,
ChatContext,
ChatImage,
ChatMessage,
ChatRole,
)
from .function_context import (
USE_DOCSTRING,
CalledFunction,
@@ -27,6 +34,7 @@
"ChatMessage",
"ChatAudio",
"ChatImage",
"ChatContent",
"ChatContext",
"ChoiceDelta",
"Choice",
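
`ChatContent` is now re-exported from `livekit.agents.llm`. A minimal, hedged illustration of what that alias covers (per chat_context.py below, a message's content can be a single part or a list mixing text, audio, and image parts); the image URL is a placeholder:

```python
# Illustrative sketch; the ChatImage constructor argument follows the dataclass
# in chat_context.py, and the URL is purely hypothetical.
from livekit.agents import llm

msg = llm.ChatMessage.create(
    text="What's in this photo?",
    images=[llm.ChatImage(image="https://example.com/photo.jpg")],
    role="user",
)
parts: list[llm.ChatContent] = msg.content  # mixed text + image parts
```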
17 changes: 13 additions & 4 deletions livekit-agents/livekit/agents/llm/chat_context.py
@@ -17,6 +17,7 @@
from typing import Any, Literal, Union

from livekit import rtc
from livekit.agents import utils

from . import function_context

@@ -45,7 +46,9 @@ class ChatAudio:
@dataclass
class ChatMessage:
role: ChatRole
id: str | None = None # used by the OAI realtime API
id: str = field(
default_factory=lambda: utils.shortuuid("item_")
) # used by the OAI realtime API
name: str | None = None
content: ChatContent | list[ChatContent] | None = None
tool_calls: list[function_context.FunctionCallInfo] | None = None
@@ -86,10 +89,15 @@ def create_tool_calls(

@staticmethod
def create(
*, text: str = "", images: list[ChatImage] = [], role: ChatRole = "system"
*,
text: str = "",
images: list[ChatImage] = [],
role: ChatRole = "system",
id: str | None = None,
) -> "ChatMessage":
id = id or utils.shortuuid("item_")
if len(images) == 0:
return ChatMessage(role=role, content=text)
return ChatMessage(role=role, content=text, id=id)
else:
content: list[ChatContent] = []
if text:
@@ -98,7 +106,7 @@ def create(
if len(images) > 0:
content.extend(images)

return ChatMessage(role=role, content=content)
return ChatMessage(role=role, content=content, id=id)

def copy(self):
content = self.content
@@ -111,6 +119,7 @@ def copy(self):

copied_msg = ChatMessage(
role=self.role,
id=self.id,
name=self.name,
content=content,
tool_calls=tool_calls,
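
The net effect of the chat_context.py changes, as a short hedged sketch: every `ChatMessage` now carries an item id (auto-generated via `utils.shortuuid("item_")` unless one is supplied), and `copy()` preserves it, so local messages can be matched against remote Realtime conversation items:

```python
# Sketch based on the diff above; apart from the "item_" prefix, the exact id
# format is an assumption.
from livekit.agents import llm

auto = llm.ChatMessage.create(text="hello", role="user")
print(auto.id)  # e.g. "item_...", generated by utils.shortuuid("item_")

pinned = llm.ChatMessage.create(text="hi there", role="assistant", id="item_from_remote")
assert pinned.copy().id == "item_from_remote"  # copy() now keeps the id
```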
54 changes: 51 additions & 3 deletions livekit-agents/livekit/agents/multimodal/multimodal_agent.py
@@ -22,6 +22,8 @@
"user_speech_committed",
"agent_speech_committed",
"agent_speech_interrupted",
"function_calls_collected",
"function_calls_finished",
]


@@ -108,6 +110,12 @@ def fnc_ctx(self) -> llm.FunctionContext | None:
def fnc_ctx(self, value: llm.FunctionContext | None) -> None:
self._session.fnc_ctx = value

def chat_ctx_copy(self) -> llm.ChatContext:
return self._session.chat_ctx_copy()

async def set_chat_ctx(self, ctx: llm.ChatContext) -> None:
await self._session.set_chat_ctx(ctx)

def start(
self, room: rtc.Room, participant: rtc.RemoteParticipant | str | None = None
) -> None:
@@ -134,12 +142,30 @@ def start(
self._session = self._model.session(
chat_ctx=self._chat_ctx, fnc_ctx=self._fnc_ctx
)
self._main_atask = asyncio.create_task(self._main_task())

# Create a task to wait for initialization and start the main task
async def _init_and_start():
try:
await self._session._init_sync_task
logger.info("Session initialized with chat context")
self._main_atask = asyncio.create_task(self._main_task())
except Exception as e:
logger.exception("Failed to initialize session")
raise e

# Schedule the initialization and start task
asyncio.create_task(_init_and_start())

from livekit.plugins.openai import realtime

@self._session.on("response_content_added")
def _on_content_added(message: realtime.RealtimeContent):
if message.content_type == "text":
logger.warning(
"The realtime API returned a text content part, which is not supported"
)
return

tr_fwd = transcription.TTSSegmentsForwarder(
room=self._room,
participant=self._room.local_participant,
@@ -176,7 +202,13 @@ def _input_speech_transcription_completed(
alternatives=[stt.SpeechData(language="", text=ev.transcript)],
)
)
user_msg = ChatMessage.create(text=ev.transcript, role="user")
user_msg = ChatMessage.create(
text=ev.transcript, role="user", id=ev.item_id
)
self._session._update_converstation_item_content(
ev.item_id, user_msg.content
)

self.emit("user_speech_committed", user_msg)
logger.debug(
"committed user speech",
@@ -200,6 +232,14 @@ def _input_speech_started():
def _input_speech_stopped():
self.emit("user_stopped_speaking")

@self._session.on("function_calls_collected")
def _function_calls_collected(fnc_call_infos: list[llm.FunctionCallInfo]):
self.emit("function_calls_collected", fnc_call_infos)

@self._session.on("function_calls_finished")
def _function_calls_finished(called_fncs: list[llm.CalledFunction]):
self.emit("function_calls_finished", called_fncs)

def _update_state(self, state: AgentState, delay: float = 0.0):
"""Set the current state of the agent"""

@@ -238,7 +278,15 @@ def _on_playout_stopped(interrupted: bool) -> None:
if interrupted:
collected_text += "..."

msg = ChatMessage.create(text=collected_text, role="assistant")
msg = ChatMessage.create(
text=collected_text,
role="assistant",
id=self._playing_handle.item_id,
)
self._session._update_converstation_item_content(
self._playing_handle.item_id, msg.content
)

if interrupted:
self.emit("agent_speech_interrupted", msg)
else:
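
Since the session's function-call events are now re-emitted by `MultimodalAgent`, application code can observe tool use without reaching into the session. A hedged sketch; only the event names and payload types come from the diff above, the handlers themselves are illustrative:

```python
from livekit.agents import llm, multimodal

def register_function_call_logging(agent: multimodal.MultimodalAgent) -> None:
    @agent.on("function_calls_collected")
    def _on_calls_collected(fnc_call_infos: list[llm.FunctionCallInfo]):
        # fired when the realtime session has gathered the calls it wants to make
        print(f"model requested {len(fnc_call_infos)} function call(s)")

    @agent.on("function_calls_finished")
    def _on_calls_finished(called_fncs: list[llm.CalledFunction]):
        # fired after the local function implementations have run
        print(f"{len(called_fncs)} function call(s) finished")
```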
2 changes: 2 additions & 0 deletions livekit-agents/livekit/agents/utils/__init__.py
@@ -1,6 +1,7 @@
from livekit import rtc

from . import aio, audio, codecs, http_context, hw, images
from ._message_change import compute_changes as _compute_changes # keep internal
from .audio import AudioBuffer, combine_frames, merge_frames
from .exp_filter import ExpFilter
from .log import log_exceptions
@@ -25,4 +26,5 @@
"audio",
"aio",
"hw",
"_compute_changes",
]
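
The newly exported `_compute_changes` helper points at a diff-based sync: work out which remote conversation items to delete and which local messages to create so the remote state matches the local `ChatContext`. The sketch below is purely illustrative of that idea; it is not the library's implementation and ignores ordering:

```python
# Hypothetical illustration only; the real helper lives in
# livekit.agents.utils._message_change and is not part of this diff excerpt.
from dataclasses import dataclass, field

from livekit.agents import llm


@dataclass
class MessageChanges:
    to_delete: list[str] = field(default_factory=list)              # remote item ids to remove
    to_create: list[llm.ChatMessage] = field(default_factory=list)  # local messages to add


def naive_compute_changes(
    remote: list[llm.ChatMessage], local: list[llm.ChatMessage]
) -> MessageChanges:
    remote_ids = {m.id for m in remote}
    local_ids = {m.id for m in local}
    return MessageChanges(
        to_delete=[m.id for m in remote if m.id not in local_ids],
        to_create=[m for m in local if m.id not in remote_ids],
    )
```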
(diffs for the remaining 7 changed files were not loaded)