
Commit

Merge pull request #202 from AI21Labs/chat-model-with-tools-docs-and-response-format

feat: Chat model with tools docs and response format

amirai21 committed Aug 20, 2024
2 parents e20978a + d95fc69 commit 9f4ed9b
Showing 30 changed files with 673 additions and 51 deletions.
README.md (45 changes: 38 additions & 7 deletions)

@@ -177,7 +177,7 @@ messages = [

chat_completions = client.chat.completions.create(
messages=messages,
model="jamba-instruct-preview",
model="jamba-1.5-mini",
)
```

@@ -207,7 +207,7 @@ client = AsyncAI21Client(
async def main():
response = await client.chat.completions.create(
messages=messages,
model="jamba-instruct-preview",
model="jamba-1.5-mini",
)

print(response)
@@ -227,8 +227,9 @@ A more detailed example can be found [here](examples/studio/chat/chat_completion
### Supported Models:

- j2-light
- j2-mid
- j2-ultra
- [j2-ultra](#Chat)
- [j2-mid](#Completion)
- [jamba-instruct](#Chat-Completion)

You can read more about the models [here](https://docs.ai21.com/reference/j2-complete-api-ref#jurassic-2-models).

@@ -270,6 +271,36 @@ completion_response = client.completion.create(
)
```

### Chat Completion

```python
from ai21 import AI21Client
from ai21.models.chat import ChatMessage

system = "You're a support engineer in a SaaS company"
messages = [
    ChatMessage(content=system, role="system"),
    ChatMessage(content="Hello, I need help with a signup process.", role="user"),
    ChatMessage(content="Hi Alice, I can help you with that. What seems to be the problem?", role="assistant"),
    ChatMessage(content="I am having trouble signing up for your product with my Google account.", role="user"),
]

client = AI21Client()

response = client.chat.completions.create(
    messages=messages,
    model="jamba-instruct",
    max_tokens=100,
    temperature=0.7,
    top_p=1.0,
    stop=["\n"],
)

print(response)
```

Note that jamba-instruct supports async streaming as well.

</details>

For a more detailed example, see the completion [examples](examples/studio/completion.py).
@@ -290,7 +321,7 @@ client = AI21Client()

response = client.chat.completions.create(
messages=messages,
model="jamba-instruct-preview",
model="jamba-instruct",
stream=True,
)
for chunk in response:
@@ -314,7 +345,7 @@ client = AsyncAI21Client()
async def main():
response = await client.chat.completions.create(
messages=messages,
model="jamba-instruct-preview",
model="jamba-1.5-mini",
stream=True,
)
async for chunk in response:
@@ -700,7 +731,7 @@ messages = [
]

response = client.chat.completions.create(
model="jamba-instruct",
model="jamba-1.5-mini",
messages=messages,
)
```
ai21/clients/studio/resources/chat/async_chat_completions.py (17 changes: 16 additions & 1 deletion)

@@ -6,6 +6,9 @@
from ai21.clients.studio.resources.chat.base_chat_completions import BaseChatCompletions
from ai21.models import ChatMessage as J2ChatMessage
from ai21.models.chat import ChatMessage, ChatCompletionResponse, ChatCompletionChunk
from ai21.models.chat.document_schema import DocumentSchema
from ai21.models.chat.response_format import ResponseFormat
from ai21.models.chat.tool_defintions import ToolDefinition
from ai21.stream.async_stream import AsyncStream
from ai21.types import NotGiven, NOT_GIVEN

@@ -23,7 +26,10 @@ async def create(
top_p: float | NotGiven = NOT_GIVEN,
stop: str | List[str] | NotGiven = NOT_GIVEN,
n: int | NotGiven = NOT_GIVEN,
stream: Optional[Literal[False]] | NotGiven = NOT_GIVEN,
stream: Optional[False] | NotGiven = NOT_GIVEN,
tools: List[ToolDefinition] | NotGiven = NOT_GIVEN,
response_format: ResponseFormat | NotGiven = NOT_GIVEN,
documents: List[DocumentSchema] | NotGiven = NOT_GIVEN,
**kwargs: Any,
) -> ChatCompletionResponse:
pass
@@ -39,6 +45,9 @@ async def create(
top_p: float | NotGiven = NOT_GIVEN,
stop: str | List[str] | NotGiven = NOT_GIVEN,
n: int | NotGiven = NOT_GIVEN,
tools: List[ToolDefinition] | NotGiven = NOT_GIVEN,
response_format: ResponseFormat | NotGiven = NOT_GIVEN,
documents: List[DocumentSchema] | NotGiven = NOT_GIVEN,
**kwargs: Any,
) -> AsyncStream[ChatCompletionChunk]:
pass
@@ -53,6 +62,9 @@ async def create(
stop: str | List[str] | NotGiven = NOT_GIVEN,
n: int | NotGiven = NOT_GIVEN,
stream: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN,
tools: List[ToolDefinition] | NotGiven = NOT_GIVEN,
response_format: ResponseFormat | NotGiven = NOT_GIVEN,
documents: List[DocumentSchema] | NotGiven = NOT_GIVEN,
**kwargs: Any,
) -> ChatCompletionResponse | AsyncStream[ChatCompletionChunk]:
if any(isinstance(item, J2ChatMessage) for item in messages):
@@ -70,6 +82,9 @@ async def create(
top_p=top_p,
n=n,
stream=stream or False,
tools=tools,
response_format=response_format,
documents=documents,
**kwargs,
)

ai21/clients/studio/resources/chat/base_chat_completions.py (13 changes: 13 additions & 0 deletions)

@@ -5,6 +5,9 @@
from typing import List, Optional, Union, Any, Dict, Literal

from ai21.models.chat import ChatMessage
from ai21.models.chat.document_schema import DocumentSchema
from ai21.models.chat.response_format import ResponseFormat
from ai21.models.chat.tool_defintions import ToolDefinition
from ai21.types import NotGiven
from ai21.utils.typing import remove_not_given
from ai21.models._pydantic_compatibility import _to_dict
@@ -40,6 +43,9 @@ def _create_body(
stop: Optional[Union[str, List[str]]] | NotGiven,
n: Optional[int] | NotGiven,
stream: Literal[False] | Literal[True] | NotGiven,
tools: List[ToolDefinition] | NotGiven,
response_format: ResponseFormat | NotGiven,
documents: List[DocumentSchema] | NotGiven,
**kwargs: Any,
) -> Dict[str, Any]:
return remove_not_given(
@@ -52,6 +58,13 @@
"stop": stop,
"n": n,
"stream": stream,
"tools": tools,
"response_format": (
_to_dict(response_format) if not isinstance(response_format, NotGiven) else response_format
),
"documents": (
[_to_dict(document) for document in documents] if not isinstance(documents, NotGiven) else documents
),
**kwargs,
}
)
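For orientation (not code from this commit), the two conversions above come down to calling `_to_dict` on the Pydantic models before they join the request body; a minimal sketch with hypothetical field values:

```python
from ai21.models._pydantic_compatibility import _to_dict
from ai21.models.chat import DocumentSchema, ResponseFormat

# _to_dict flattens the Pydantic models into plain dicts for the request body;
# the values below are hypothetical.
print(_to_dict(ResponseFormat(type="json_object")))
# roughly {'type': 'json_object'}

print([_to_dict(doc) for doc in [DocumentSchema(content="Order 42 shipped.", id="order-42")]])
# roughly [{'content': 'Order 42 shipped.', 'id': 'order-42'}]; unset fields such as
# metadata may or may not appear, depending on the serialization settings
```

The `tools` entries are TypedDicts, so they pass through `remove_not_given` unchanged.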
ai21/clients/studio/resources/chat/chat_completions.py (15 changes: 15 additions & 0 deletions)

@@ -6,6 +6,9 @@
from ai21.clients.studio.resources.chat.base_chat_completions import BaseChatCompletions
from ai21.models import ChatMessage as J2ChatMessage
from ai21.models.chat import ChatMessage, ChatCompletionResponse, ChatCompletionChunk
from ai21.models.chat.document_schema import DocumentSchema
from ai21.models.chat.response_format import ResponseFormat
from ai21.models.chat.tool_defintions import ToolDefinition
from ai21.stream.stream import Stream
from ai21.types import NotGiven, NOT_GIVEN

@@ -24,6 +27,9 @@ def create(
stop: str | List[str] | NotGiven = NOT_GIVEN,
n: int | NotGiven = NOT_GIVEN,
stream: Optional[Literal[False]] | NotGiven = NOT_GIVEN,
tools: List[ToolDefinition] | NotGiven = NOT_GIVEN,
response_format: ResponseFormat | NotGiven = NOT_GIVEN,
documents: List[DocumentSchema] | NotGiven = NOT_GIVEN,
**kwargs: Any,
) -> ChatCompletionResponse:
pass
@@ -39,6 +45,9 @@ def create(
top_p: float | NotGiven = NOT_GIVEN,
stop: str | List[str] | NotGiven = NOT_GIVEN,
n: int | NotGiven = NOT_GIVEN,
tools: List[ToolDefinition] | NotGiven = NOT_GIVEN,
response_format: ResponseFormat | NotGiven = NOT_GIVEN,
documents: List[DocumentSchema] | NotGiven = NOT_GIVEN,
**kwargs: Any,
) -> Stream[ChatCompletionChunk]:
pass
@@ -53,6 +62,9 @@ def create(
stop: str | List[str] | NotGiven = NOT_GIVEN,
n: int | NotGiven = NOT_GIVEN,
stream: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN,
tools: List[ToolDefinition] | NotGiven = NOT_GIVEN,
response_format: ResponseFormat | NotGiven = NOT_GIVEN,
documents: List[DocumentSchema] | NotGiven = NOT_GIVEN,
**kwargs: Any,
) -> ChatCompletionResponse | Stream[ChatCompletionChunk]:
if any(isinstance(item, J2ChatMessage) for item in messages):
@@ -70,6 +82,9 @@ def create(
top_p=top_p,
n=n,
stream=stream or False,
tools=tools,
response_format=response_format,
documents=documents,
**kwargs,
)

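As an illustrative sketch (not part of this diff), the new `documents` argument added to the synchronous `create` might be used for grounding like this; the document text, id, and metadata are made up:

```python
from ai21 import AI21Client
from ai21.models.chat import ChatMessage, DocumentSchema

client = AI21Client()

# Hypothetical grounding document; the fields mirror DocumentSchema below.
documents = [
    DocumentSchema(
        content="Order 42 was shipped on 2024-08-01 and should arrive within 5 business days.",
        id="order-42",
        metadata={"source": "orders-db"},
    ),
]

response = client.chat.completions.create(
    model="jamba-1.5-mini",
    messages=[ChatMessage(role="user", content="When will order 42 arrive?")],
    documents=documents,
)

print(response.choices[0].message.content)
```

The `tools` and `response_format` arguments follow the same pattern; see the sketches next to their model classes below.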
ai21/http_client/http_client.py (3 changes: 2 additions & 1 deletion)

@@ -99,7 +99,8 @@ def execute_http_request(

if response.status_code != httpx.codes.OK:
_logger.error(
f"Calling {method} {self._base_url} failed with a non-200 response code: {response.status_code}"
f"Calling {method} {self._base_url} failed with a non-200 "
f"response code: {response.status_code} headers: {response.headers}"
)
handle_non_success_response(response.status_code, response.text)

ai21/models/chat/__init__.py (21 changes: 20 additions & 1 deletion)

@@ -2,9 +2,15 @@

from .chat_completion_response import ChatCompletionResponse
from .chat_completion_response import ChatCompletionResponseChoice
from .chat_message import ChatMessage
from .chat_message import ChatMessage, AssistantMessage, ToolMessage, UserMessage, SystemMessage
from .document_schema import DocumentSchema
from .function_tool_definition import FunctionToolDefinition
from .response_format import ResponseFormat
from .role_type import RoleType as RoleType
from .chat_completion_chunk import ChatCompletionChunk, ChoicesChunk, ChoiceDelta
from .tool_call import ToolCall
from .tool_defintions import ToolDefinition
from .tool_function import ToolFunction

__all__ = [
"ChatCompletionResponse",
@@ -14,4 +20,17 @@
"ChatCompletionChunk",
"ChoicesChunk",
"ChoiceDelta",
"AssistantMessage",
"ToolMessage",
"UserMessage",
"SystemMessage",
"DocumentSchema",
"FunctionToolDefinition",
"ResponseFormat",
"ToolCall",
"ToolDefinition",
"ToolFunction",
"ToolParameters",
]

from .tool_parameters import ToolParameters
ai21/models/chat/chat_completion_response.py (4 changes: 2 additions & 2 deletions)

@@ -3,12 +3,12 @@
from ai21.models.ai21_base_model import AI21BaseModel
from ai21.models.logprobs import Logprobs
from ai21.models.usage_info import UsageInfo
from .chat_message import ChatMessage
from .chat_message import AssistantMessage


class ChatCompletionResponseChoice(AI21BaseModel):
index: int
message: ChatMessage
message: AssistantMessage
logprobs: Optional[Logprobs] = None
finish_reason: Optional[str] = None

ai21/models/chat/chat_message.py (22 changes: 21 additions & 1 deletion)

@@ -1,8 +1,28 @@
from __future__ import annotations
from typing import Literal, List, Optional

from ai21.models.ai21_base_model import AI21BaseModel
from ai21.models.chat.tool_call import ToolCall


class ChatMessage(AI21BaseModel):
    role: str
    content: str


class AssistantMessage(ChatMessage):
    role: Literal["assistant"] = "assistant"
    tool_calls: Optional[List[ToolCall]] = None
    content: Optional[str] = None


class ToolMessage(ChatMessage):
    role: Literal["tool"] = "tool"
    tool_call_id: str


class UserMessage(ChatMessage):
    role: Literal["user"] = "user"


class SystemMessage(ChatMessage):
    role: Literal["system"] = "system"
ai21/models/chat/document_schema.py (9 changes: 9 additions & 0 deletions)

@@ -0,0 +1,9 @@
from typing import Optional, Dict

from ai21.models.ai21_base_model import AI21BaseModel


class DocumentSchema(AI21BaseModel):
    content: str
    id: Optional[str] = None
    metadata: Optional[Dict[str, str]] = None
ai21/models/chat/function_tool_definition.py (9 changes: 9 additions & 0 deletions)

@@ -0,0 +1,9 @@
from typing_extensions import TypedDict, Required

from ai21.models.chat.tool_parameters import ToolParameters


class FunctionToolDefinition(TypedDict, total=False):
    name: Required[str]
    description: str
    parameters: ToolParameters
ai21/models/chat/response_format.py (7 changes: 7 additions & 0 deletions)

@@ -0,0 +1,7 @@
from typing import Literal

from ai21.models.ai21_base_model import AI21BaseModel


class ResponseFormat(AI21BaseModel):
type: Literal["text", "json_object"]
ai21/models/chat/role_type.py (2 changes: 2 additions & 0 deletions)

@@ -4,3 +4,5 @@
class RoleType(str, Enum):
USER = "user"
ASSISTANT = "assistant"
TOOL = "tool"
SYSTEM = "system"
ai21/models/chat/tool_call.py (10 changes: 10 additions & 0 deletions)

@@ -0,0 +1,10 @@
from typing import Literal

from ai21.models.ai21_base_model import AI21BaseModel
from ai21.models.chat.tool_function import ToolFunction


class ToolCall(AI21BaseModel):
    id: str
    function: ToolFunction
    type: Literal["function"] = "function"
ai21/models/chat/tool_defintions.py (8 changes: 8 additions & 0 deletions)

@@ -0,0 +1,8 @@
from typing_extensions import Literal, TypedDict, Required

from ai21.models.chat import FunctionToolDefinition


class ToolDefinition(TypedDict, total=False):
    type: Required[Literal["function"]]
    function: Required[FunctionToolDefinition]
ai21/models/chat/tool_function.py (6 changes: 6 additions & 0 deletions)

@@ -0,0 +1,6 @@
from ai21.models.ai21_base_model import AI21BaseModel


class ToolFunction(AI21BaseModel):
    name: str
    arguments: str
ai21/models/chat/tool_parameters.py (7 changes: 7 additions & 0 deletions)

@@ -0,0 +1,7 @@
from typing_extensions import Literal, Any, Dict, List, TypedDict, Required


class ToolParameters(TypedDict, total=False):
type: Literal["object"]
properties: Required[Dict[str, Any]]
required: List[str]
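To tie the tool models together, a hedged end-to-end sketch (not part of this commit): a ToolDefinition built from FunctionToolDefinition and ToolParameters (all TypedDicts, so plain dicts of the same shape also work), passed through the new `tools` argument, with the resulting tool calls answered via ToolMessage. The tool name, its local implementation, and the model's behavior are assumptions:

```python
import json

from ai21 import AI21Client
from ai21.models.chat import (
    ChatMessage,
    FunctionToolDefinition,
    ToolDefinition,
    ToolMessage,
    ToolParameters,
)

client = AI21Client()

# Hypothetical tool exposed to the model.
get_order_status_tool = ToolDefinition(
    type="function",
    function=FunctionToolDefinition(
        name="get_order_status",
        description="Look up the shipping status of an order by its id",
        parameters=ToolParameters(
            type="object",
            properties={"order_id": {"type": "string"}},
            required=["order_id"],
        ),
    ),
)

messages = [ChatMessage(role="user", content="Where is order 42?")]

response = client.chat.completions.create(
    model="jamba-1.5-mini",
    messages=messages,
    tools=[get_order_status_tool],
)

assistant = response.choices[0].message  # AssistantMessage, may carry tool_calls
messages.append(assistant)

for tool_call in assistant.tool_calls or []:
    if tool_call.function.name == "get_order_status":
        args = json.loads(tool_call.function.arguments)
        # Hypothetical local lookup standing in for a real backend call.
        result = json.dumps({"order_id": args["order_id"], "status": "shipped"})
        messages.append(ToolMessage(content=result, tool_call_id=tool_call.id))

# Send the tool result back so the model can produce a grounded final answer.
final = client.chat.completions.create(
    model="jamba-1.5-mini",
    messages=messages,
    tools=[get_order_status_tool],
)
print(final.choices[0].message.content)
```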
examples/studio/chat/async_chat_completions.py (2 changes: 1 addition & 1 deletion)

@@ -18,7 +18,7 @@
async def main():
response = await client.chat.completions.create(
messages=messages,
model="jamba-instruct-preview",
model="jamba-1.5-mini",
max_tokens=100,
temperature=0.7,
top_p=1.0,