Introduce function calling to OpenAI Assistants (#710)
Co-authored-by: Théo Monnom <[email protected]>
keepingitneil and theomonnom authored Sep 9, 2024
1 parent 2124888 commit 9b55f2b
Showing 4 changed files with 144 additions and 49 deletions.
5 changes: 5 additions & 0 deletions .changeset/little-timers-kiss.md
@@ -0,0 +1,5 @@
---
"livekit-plugins-openai": patch
---

Introduce function calling to OpenAI Assistants
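Note: the sketch below shows how the new function-calling path fits together end to end. It is a minimal example assembled from this commit's diff and tests, assuming the standard livekit-agents FunctionContext / ai_callable API; the get_weather function and its body are illustrative, not part of the commit.

import asyncio
import uuid
from typing import Annotated

from livekit.agents import llm
from livekit.plugins import openai


class FncCtx(llm.FunctionContext):
    # illustrative function: anything registered on the context is advertised
    # to the Assistants run as an OpenAI "function" tool (see the diff below)
    @llm.ai_callable(description="Get the current weather for a location")
    def get_weather(
        self,
        location: Annotated[str, llm.TypeInfo(description="The city to query")],
    ) -> str:
        return f"The weather in {location} is sunny."


async def main() -> None:
    assistant_llm = openai.beta.AssistantLLM(
        assistant_opts=openai.beta.AssistantOptions(
            create_options=openai.beta.AssistantCreateOptions(
                name=f"demo-{uuid.uuid4()}",
                instructions="You are a basic assistant",
                model="gpt-4o",
            )
        )
    )
    chat_ctx = llm.ChatContext().append(text="What's the weather in Paris?")
    stream = assistant_llm.chat(chat_ctx=chat_ctx, fnc_ctx=FncCtx())
    async for _ in stream:
        pass  # text and tool-call chunks stream here
    # execute any function calls the model requested
    await asyncio.gather(*[f.task for f in stream.execute_functions()])


asyncio.run(main())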
@@ -15,8 +15,9 @@
from __future__ import annotations

import asyncio
import json
import uuid
from dataclasses import dataclass, field
from dataclasses import dataclass
from typing import Any, Dict, MutableSet, Union

import httpx
@@ -25,10 +26,15 @@
from openai import AsyncAssistantEventHandler, AsyncClient
from openai.types.beta.threads import Text, TextDelta
from openai.types.beta.threads.run_create_params import AdditionalMessage
from openai.types.beta.threads.runs import ToolCall, ToolCallDelta
from openai.types.beta.threads.runs import (
CodeInterpreterToolCall,
FileSearchToolCall,
FunctionToolCall,
ToolCall,
)

from ..log import logger
from ..models import AssistantTools, ChatModels
from ..models import ChatModels
from ..utils import build_oai_message

DEFAULT_MODEL = "gpt-4o"
@@ -56,13 +62,15 @@ class AssistantCreateOptions:
instructions: str
model: ChatModels
temperature: float | None = None
tools: list[AssistantTools] = field(default_factory=list)
# TODO: when we implement code_interpreter and file_search tools
# tool_resources: ToolResources | None = None
# tools: list[AssistantTools] = field(default_factory=list)


@dataclass
class AssistantLoadOptions:
assistant_id: str
thread_id: str
thread_id: str | None


class AssistantLLM(llm.LLM):
@@ -95,7 +103,13 @@ def __init__(
self._assistant_opts = assistant_opts
self._running_fncs: MutableSet[asyncio.Task[Any]] = set()

self._sync_openai_task = asyncio.create_task(self._sync_openai())
self._sync_openai_task: asyncio.Task[AssistantLoadOptions] | None = None
try:
self._sync_openai_task = asyncio.create_task(self._sync_openai())
except Exception:
logger.error(
"failed to create sync openai task. This can happen when instantiating without a running asyncio event loop (such has when running tests)"
)
self._done_futures = list[asyncio.Future[None]]()

async def _sync_openai(self) -> AssistantLoadOptions:
@@ -104,17 +118,26 @@ async def _sync_openai(self) -> AssistantLoadOptions:
"model": self._assistant_opts.create_options.model,
"name": self._assistant_opts.create_options.name,
"instructions": self._assistant_opts.create_options.instructions,
"tools": [
{"type": t} for t in self._assistant_opts.create_options.tools
],
# "tools": [
# {"type": t} for t in self._assistant_opts.create_options.tools
# ],
# "tool_resources": self._assistant_opts.create_options.tool_resources,
}
# TODO when we implement code_interpreter and file_search tools
# if self._assistant_opts.create_options.tool_resources:
# kwargs["tool_resources"] = (
# self._assistant_opts.create_options.tool_resources
# )
if self._assistant_opts.create_options.temperature:
kwargs["temperature"] = self._assistant_opts.create_options.temperature
assistant = await self._client.beta.assistants.create(**kwargs)

thread = await self._client.beta.threads.create()
return AssistantLoadOptions(assistant_id=assistant.id, thread_id=thread.id)
elif self._assistant_opts.load_options:
if not self._assistant_opts.load_options.thread_id:
thread = await self._client.beta.threads.create()
self._assistant_opts.load_options.thread_id = thread.id
return self._assistant_opts.load_options

raise Exception("One of create_options or load_options must be set")
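Note: AssistantLoadOptions.thread_id is now optional; when loading an existing assistant without a thread, _sync_openai lazily creates one. A minimal sketch of the load path, assuming AssistantLoadOptions is exported from openai.beta alongside AssistantOptions (the assistant id is a placeholder):

assistant_llm = openai.beta.AssistantLLM(
    assistant_opts=openai.beta.AssistantOptions(
        load_options=openai.beta.AssistantLoadOptions(
            assistant_id="asst_...",  # placeholder: an existing assistant id
            thread_id=None,  # a fresh thread is created on first sync
        )
    )
)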
@@ -136,6 +159,9 @@ def chat(
"OpenAI Assistants does not support the 'parallel_tool_calls' parameter"
)

if not self._sync_openai_task:
self._sync_openai_task = asyncio.create_task(self._sync_openai())

return AssistantLLMStream(
temperature=temperature,
assistant_llm=self,
@@ -150,12 +176,16 @@ class AssistantLLMStream(llm.LLMStream):
class EventHandler(AsyncAssistantEventHandler):
def __init__(
self,
llm_stream: AssistantLLMStream,
output_queue: asyncio.Queue[llm.ChatChunk | Exception | None],
chat_ctx: llm.ChatContext,
fnc_ctx: llm.FunctionContext | None = None,
):
super().__init__()
self._llm_stream = llm_stream
self._chat_ctx = chat_ctx
self._output_queue = output_queue
self._fnc_ctx = fnc_ctx

async def on_text_delta(self, delta: TextDelta, snapshot: Text):
self._output_queue.put_nowait(
@@ -171,9 +201,36 @@ async def on_text_delta(self, delta: TextDelta, snapshot: Text):
async def on_tool_call_created(self, tool_call: ToolCall):
print(f"\nassistant > {tool_call.type}\n", flush=True)

async def on_tool_call_delta(self, delta: ToolCallDelta, snapshot: ToolCall):
if delta.type == "code_interpreter":
pass
async def on_tool_call_done(
self,
tool_call: CodeInterpreterToolCall | FileSearchToolCall | FunctionToolCall,
) -> None:
if tool_call.type == "code_interpreter":
logger.warning("code interpreter tool call not yet implemented")
elif tool_call.type == "file_search":
logger.warning("file_search tool call not yet implemented")
elif tool_call.type == "function":
if not self._fnc_ctx:
logger.error("function tool called without function context")
return

fnc = llm.FunctionCallInfo(
function_info=self._fnc_ctx.ai_functions[tool_call.function.name],
arguments=json.loads(tool_call.function.arguments),
tool_call_id=tool_call.id,
raw_arguments=tool_call.function.arguments,
)

self._llm_stream._function_calls_info.append(fnc)
chunk = llm.ChatChunk(
choices=[
llm.Choice(
delta=llm.ChoiceDelta(role="assistant", tool_calls=[fnc]),
index=0,
)
]
)
self._output_queue.put_nowait(chunk)
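Note: a completed function tool call is repackaged as an llm.ChatChunk whose delta carries tool_calls, the same shape the Chat Completions stream emits, so downstream consumers need no Assistants-specific handling. A hedged consumer sketch, inside an async function and reusing the names from the sketch near the top of this page (execute_functions is the helper the tests below rely on):

stream = assistant_llm.chat(chat_ctx=chat_ctx, fnc_ctx=FncCtx())
async for chunk in stream:
    for choice in chunk.choices:
        if choice.delta.tool_calls:
            # each entry is the llm.FunctionCallInfo built above
            print("tool call:", choice.delta.tool_calls[0].tool_call_id)
await asyncio.gather(*[f.task for f in stream.execute_functions()])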

def __init__(
self,
@@ -250,6 +307,7 @@ async def _create_stream(self) -> None:
msg_id = msg._metadata.get(OPENAI_MESSAGE_ID_KEY, {}).get(
load_options.thread_id
)
assert load_options.thread_id
if msg_id and msg_id not in openai_addded_messages_set:
await self._client.beta.threads.messages.delete(
thread_id=load_options.thread_id,
@@ -289,15 +347,26 @@ async def _create_stream(self) -> None:
lk_msg_id_dict[load_options.thread_id] = msg_id

eh = AssistantLLMStream.EventHandler(
self._output_queue, chat_ctx=self._chat_ctx
output_queue=self._output_queue,
chat_ctx=self._chat_ctx,
fnc_ctx=self._fnc_ctx,
llm_stream=self,
)
async with self._client.beta.threads.runs.stream(
additional_messages=additional_messages,
thread_id=load_options.thread_id,
assistant_id=load_options.assistant_id,
event_handler=eh,
temperature=self._temperature,
) as stream:
assert load_options.thread_id
kwargs: dict[str, Any] = {
"additional_messages": additional_messages,
"thread_id": load_options.thread_id,
"assistant_id": load_options.assistant_id,
"event_handler": eh,
"temperature": self._temperature,
}
if self._fnc_ctx:
kwargs["tools"] = [
llm._oai_api.build_oai_function_description(f)
for f in self._fnc_ctx.ai_functions.values()
]

async with self._client.beta.threads.runs.stream(**kwargs) as stream:
await stream.until_done()

await self._output_queue.put(None)
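Note: when a FunctionContext is set, the run now advertises each registered function via llm._oai_api.build_oai_function_description. Each entry should follow the OpenAI function-tool schema, roughly like the illustrative dict below (the helper's exact output may differ):

{
    "type": "function",
    "function": {
        "name": "get_weather",
        "description": "Get the current weather for a location",
        "parameters": {
            "type": "object",
            "properties": {
                "location": {"type": "string", "description": "The city to query"}
            },
            "required": ["location"],
        },
    },
}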
@@ -335,9 +404,6 @@ async def _create_stream(self) -> None:
finally:
self._done_future.set_result(None)

async def aclose(self) -> None:
pass

async def __anext__(self):
item = await self._output_queue.get()
if item is None:
@@ -33,7 +33,7 @@
"text-embedding-ada-002", "text-embedding-3-small", "text-embedding-3-large"
]

AssistantTools = Literal["code_interpreter"]
AssistantTools = Literal["code_interpreter", "file_search", "function"]

# adapters for OpenAI-compatible LLMs

74 changes: 49 additions & 25 deletions tests/test_llm.py
@@ -1,8 +1,9 @@
from __future__ import annotations

import asyncio
import uuid
from enum import Enum
from typing import Annotated, Optional
from typing import Annotated, Callable, Optional

import pytest
from livekit.agents import llm
@@ -87,26 +88,37 @@ def test_hashable_typeinfo():
hash(typeinfo)


LLMS = [
LLMS: list[llm.LLM | Callable[[], llm.LLM]] = [
openai.LLM(),
lambda: openai.beta.AssistantLLM(
assistant_opts=openai.beta.AssistantOptions(
create_options=openai.beta.AssistantCreateOptions(
name=f"test-{uuid.uuid4()}",
instructions="You are a basic assistant",
model="gpt-4o",
)
)
),
# anthropic.LLM(),
]


@pytest.mark.parametrize("llm", LLMS)
async def test_chat(llm: llm.LLM):
@pytest.mark.parametrize("input_llm", LLMS)
async def test_chat(input_llm: llm.LLM | Callable[[], llm.LLM]):
if not isinstance(input_llm, llm.LLM):
input_llm = input_llm()
chat_ctx = ChatContext().append(
text='You are an assistant at a drive-thru restaurant "Live-Burger". Ask the customer what they would like to order.'
)

# Anthropics LLM requires at least one message (system messages don't count)
if isinstance(llm, anthropic.LLM):
if isinstance(input_llm, anthropic.LLM):
chat_ctx.append(
text="Hello",
role="user",
)

stream = llm.chat(chat_ctx=chat_ctx)
stream = input_llm.chat(chat_ctx=chat_ctx)
text = ""
async for chunk in stream:
content = chunk.choices[0].delta.content
@@ -116,12 +128,14 @@ async def test_chat(llm: llm.LLM):
assert len(text) > 0


@pytest.mark.parametrize("llm", LLMS)
async def test_basic_fnc_calls(llm: llm.LLM):
@pytest.mark.parametrize("input_llm", LLMS)
async def test_basic_fnc_calls(input_llm: Callable[[], llm.LLM] | llm.LLM):
if not isinstance(input_llm, llm.LLM):
input_llm = input_llm()
fnc_ctx = FncCtx()

stream = await _request_fnc_call(
llm,
input_llm,
"What's the weather in San Francisco and what's the weather Paris?",
fnc_ctx,
)
@@ -131,8 +145,10 @@ async def test_basic_fnc_calls(llm: llm.LLM):
assert len(calls) == 2, "get_weather should be called twice"


@pytest.mark.parametrize("llm", LLMS)
async def test_runtime_addition(llm: llm.LLM):
@pytest.mark.parametrize("input_llm", LLMS)
async def test_runtime_addition(input_llm: Callable[[], llm.LLM] | llm.LLM):
if not isinstance(input_llm, llm.LLM):
input_llm = input_llm()
fnc_ctx = FncCtx()
called_msg = ""

@@ -144,7 +160,7 @@ async def show_message(
called_msg = message

stream = await _request_fnc_call(
llm, "Can you show 'Hello LiveKit!' on the screen?", fnc_ctx
input_llm, "Can you show 'Hello LiveKit!' on the screen?", fnc_ctx
)
fns = stream.execute_functions()
await asyncio.gather(*[f.task for f in fns])
@@ -153,15 +169,17 @@ async def show_message(
assert called_msg == "Hello LiveKit!", "send_message should be called"


@pytest.mark.parametrize("llm", LLMS)
async def test_cancelled_calls(llm: llm.LLM):
@pytest.mark.parametrize("input_llm", LLMS)
async def test_cancelled_calls(input_llm: Callable[[], llm.LLM] | llm.LLM):
if not isinstance(input_llm, llm.LLM):
input_llm = input_llm()
fnc_ctx = FncCtx()

stream = await _request_fnc_call(
llm, "Turn off the lights in the Theo's bedroom", fnc_ctx
input_llm, "Turn off the lights in the Theo's bedroom", fnc_ctx
)
calls = stream.execute_functions()
await asyncio.sleep(0) # wait for the loop executor to start the task
await asyncio.sleep(0.2) # wait for the loop executor to start the task

# don't wait for gather_function_results and directly close (this should cancel the ongoing calls)
await stream.aclose()
@@ -172,12 +190,14 @@ async def test_cancelled_calls(llm: llm.LLM):
), "toggle_light should have been cancelled"


@pytest.mark.parametrize("llm", LLMS)
async def test_calls_arrays(llm: llm.LLM):
@pytest.mark.parametrize("input_llm", LLMS)
async def test_calls_arrays(input_llm: Callable[[], llm.LLM] | llm.LLM):
if not isinstance(input_llm, llm.LLM):
input_llm = input_llm()
fnc_ctx = FncCtx()

stream = await _request_fnc_call(
llm,
input_llm,
"Can you select all currencies in Europe at once?",
fnc_ctx,
temperature=0.2,
@@ -196,11 +216,13 @@ async def test_calls_arrays(llm: llm.LLM):
), "select_currencies should have eur, gbp, sek"


@pytest.mark.parametrize("llm", LLMS)
async def test_calls_choices(llm: llm.LLM):
@pytest.mark.parametrize("input_llm", LLMS)
async def test_calls_choices(input_llm: Callable[[], llm.LLM] | llm.LLM):
if not isinstance(input_llm, llm.LLM):
input_llm = input_llm()
fnc_ctx = FncCtx()

stream = await _request_fnc_call(llm, "Set the volume to 30", fnc_ctx)
stream = await _request_fnc_call(input_llm, "Set the volume to 30", fnc_ctx)
calls = stream.execute_functions()
await asyncio.gather(*[f.task for f in calls])
await stream.aclose()
@@ -212,12 +234,14 @@ async def test_calls_choices(llm: llm.LLM):
assert volume == 30, "change_volume should have been called with volume 30"


@pytest.mark.parametrize("llm", LLMS)
async def test_optional_args(llm: llm.LLM):
@pytest.mark.parametrize("input_llm", LLMS)
async def test_optional_args(input_llm: Callable[[], llm.LLM] | llm.LLM):
if not isinstance(input_llm, llm.LLM):
input_llm = input_llm()
fnc_ctx = FncCtx()

stream = await _request_fnc_call(
llm, "Can you update my information? My name is Theo", fnc_ctx
input_llm, "Can you update my information? My name is Theo", fnc_ctx
)

calls = stream.execute_functions()