diff --git a/async-openai/Cargo.toml b/async-openai/Cargo.toml index 1f6b82e..398a8a9 100644 --- a/async-openai/Cargo.toml +++ b/async-openai/Cargo.toml @@ -22,6 +22,7 @@ rustls-webpki-roots = ["reqwest/rustls-tls-webpki-roots"] native-tls = ["reqwest/native-tls"] # Remove dependency on OpenSSL native-tls-vendored = ["reqwest/native-tls-vendored"] +realtime = ["dep:tokio-tungstenite"] [dependencies] backoff = { version = "0.4.0", features = ["tokio"] } @@ -46,6 +47,11 @@ async-convert = "1.0.0" secrecy = { version = "0.8.0", features = ["serde"] } bytes = "1.6.0" eventsource-stream = "0.2.3" +tokio-tungstenite = { version = "0.24.0", optional = true, default-features = false } [dev-dependencies] tokio-test = "0.4.4" + +[package.metadata.docs.rs] +all-features = true +rustdoc-args = ["--cfg", "docsrs"] diff --git a/async-openai/README.md b/async-openai/README.md index 58dd594..970dbd3 100644 --- a/async-openai/README.md +++ b/async-openai/README.md @@ -35,6 +35,7 @@ - [x] Models - [x] Moderations - [ ] Organizations | Administration + - [x] Realtime API types (Beta) - [ ] Uploads - SSE streaming on available APIs - Requests (except SSE streaming) including form submissions are retried with exponential backoff when [rate limited](https://platform.openai.com/docs/guides/rate-limits). @@ -58,6 +59,11 @@ $Env:OPENAI_API_KEY='sk-...' - Visit [examples](https://github.com/64bit/async-openai/tree/main/examples) directory on how to use `async-openai`. - Visit [docs.rs/async-openai](https://docs.rs/async-openai) for docs. +## Realtime API + +Only types for the Realtime API are implemented, and they can be enabled with the feature flag `realtime`. +These types may change when OpenAI releases official specs for them. + ## Image Generation Example ```rust diff --git a/async-openai/src/lib.rs b/async-openai/src/lib.rs index 7218ac4..af7fc21 100644 --- a/async-openai/src/lib.rs +++ b/async-openai/src/lib.rs @@ -76,6 +76,7 @@ //! ## Examples //! 
For full working examples for all supported features see [examples](https://github.com/64bit/async-openai/tree/main/examples) directory in the repository. //! +#![cfg_attr(docsrs, feature(doc_cfg))] mod assistant_files; mod assistants; mod audio; diff --git a/async-openai/src/types/mod.rs b/async-openai/src/types/mod.rs index d3b8996..90fd5f3 100644 --- a/async-openai/src/types/mod.rs +++ b/async-openai/src/types/mod.rs @@ -17,6 +17,9 @@ mod message; mod message_file; mod model; mod moderation; +#[cfg_attr(docsrs, doc(cfg(feature = "realtime")))] +#[cfg(feature = "realtime")] +pub mod realtime; mod run; mod step; mod thread; diff --git a/async-openai/src/types/realtime/client_event.rs b/async-openai/src/types/realtime/client_event.rs new file mode 100644 index 0000000..d5ec24d --- /dev/null +++ b/async-openai/src/types/realtime/client_event.rs @@ -0,0 +1,220 @@ +use serde::{Deserialize, Serialize}; +use tokio_tungstenite::tungstenite::Message; + +use super::{item::Item, session_resource::SessionResource}; + +#[derive(Debug, Serialize, Deserialize, Clone, Default)] +pub struct SessionUpdateEvent { + /// Optional client-generated ID used to identify this event. + #[serde(skip_serializing_if = "Option::is_none")] + pub event_id: Option, + /// Session configuration to update. + pub session: SessionResource, +} + +#[derive(Debug, Serialize, Deserialize, Clone, Default)] +pub struct InputAudioBufferAppendEvent { + /// Optional client-generated ID used to identify this event. + #[serde(skip_serializing_if = "Option::is_none")] + pub event_id: Option, + /// Base64-encoded audio bytes. + pub audio: String, +} + +#[derive(Debug, Serialize, Deserialize, Clone, Default)] +pub struct InputAudioBufferCommitEvent { + /// Optional client-generated ID used to identify this event. 
+ #[serde(skip_serializing_if = "Option::is_none")] + pub event_id: Option, +} + +#[derive(Debug, Serialize, Deserialize, Clone, Default)] +pub struct InputAudioBufferClearEvent { + /// Optional client-generated ID used to identify this event. + #[serde(skip_serializing_if = "Option::is_none")] + pub event_id: Option, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct ConversationItemCreateEvent { + /// Optional client-generated ID used to identify this event. + #[serde(skip_serializing_if = "Option::is_none")] + pub event_id: Option, + + /// The ID of the preceding item after which the new item will be inserted. + #[serde(skip_serializing_if = "Option::is_none")] + pub previous_item_id: Option, + + /// The item to add to the conversation. + pub item: Item, +} + +#[derive(Debug, Serialize, Deserialize, Clone, Default)] +pub struct ConversationItemTruncateEvent { + /// Optional client-generated ID used to identify this event. + #[serde(skip_serializing_if = "Option::is_none")] + pub event_id: Option, + + /// The ID of the assistant message item to truncate. + pub item_id: String, + + /// The index of the content part to truncate. + pub content_index: u32, + + /// Inclusive duration up to which audio is truncated, in milliseconds. + pub audio_end_ms: u32, +} + +#[derive(Debug, Serialize, Deserialize, Clone, Default)] +pub struct ConversationItemDeleteEvent { + /// Optional client-generated ID used to identify this event. + #[serde(skip_serializing_if = "Option::is_none")] + pub event_id: Option, + + /// The ID of the item to delete. + pub item_id: String, +} + +#[derive(Debug, Serialize, Deserialize, Clone, Default)] +pub struct ResponseCreateEvent { + /// Optional client-generated ID used to identify this event. + #[serde(skip_serializing_if = "Option::is_none")] + pub event_id: Option, + + /// Configuration for the response. 
+ pub response: Option, +} + +#[derive(Debug, Serialize, Deserialize, Clone, Default)] +pub struct ResponseCancelEvent { + /// Optional client-generated ID used to identify this event. + #[serde(skip_serializing_if = "Option::is_none")] + pub event_id: Option, +} + +/// These are events that the OpenAI Realtime WebSocket server will accept from the client. +#[derive(Debug, Serialize, Deserialize)] +#[serde(tag = "type")] +pub enum ClientEvent { + /// Send this event to update the session’s default configuration. + #[serde(rename = "session.update")] + SessionUpdate(SessionUpdateEvent), + + /// Send this event to append audio bytes to the input audio buffer. + #[serde(rename = "input_audio_buffer.append")] + InputAudioBufferAppend(InputAudioBufferAppendEvent), + + /// Send this event to commit audio bytes to a user message. + #[serde(rename = "input_audio_buffer.commit")] + InputAudioBufferCommit(InputAudioBufferCommitEvent), + + /// Send this event to clear the audio bytes in the buffer. + #[serde(rename = "input_audio_buffer.clear")] + InputAudioBufferClear(InputAudioBufferClearEvent), + + /// Send this event when adding an item to the conversation. + #[serde(rename = "conversation.item.create")] + ConversationItemCreate(ConversationItemCreateEvent), + + /// Send this event when you want to truncate a previous assistant message’s audio. + #[serde(rename = "conversation.item.truncate")] + ConversationItemTruncate(ConversationItemTruncateEvent), + + /// Send this event when you want to remove any item from the conversation history. + #[serde(rename = "conversation.item.delete")] + ConversationItemDelete(ConversationItemDeleteEvent), + + /// Send this event to trigger a response generation. + #[serde(rename = "response.create")] + ResponseCreate(ResponseCreateEvent), + + /// Send this event to cancel an in-progress response. 
+ #[serde(rename = "response.cancel")] + ResponseCancel(ResponseCancelEvent), +} + +impl From<&ClientEvent> for String { + fn from(value: &ClientEvent) -> Self { + serde_json::to_string(value).unwrap() + } +} + +impl From for Message { + fn from(value: ClientEvent) -> Self { + Message::Text(String::from(&value)) + } +} + +macro_rules! message_from_event { + ($from_typ:ty, $evt_typ:ty) => { + impl From<$from_typ> for Message { + fn from(value: $from_typ) -> Self { + Self::from(<$evt_typ>::from(value)) + } + } + }; +} + +macro_rules! event_from { + ($from_typ:ty, $evt_typ:ty, $variant:ident) => { + impl From<$from_typ> for $evt_typ { + fn from(value: $from_typ) -> Self { + <$evt_typ>::$variant(value) + } + } + }; +} + +event_from!(SessionUpdateEvent, ClientEvent, SessionUpdate); +event_from!( + InputAudioBufferAppendEvent, + ClientEvent, + InputAudioBufferAppend +); +event_from!( + InputAudioBufferCommitEvent, + ClientEvent, + InputAudioBufferCommit +); +event_from!( + InputAudioBufferClearEvent, + ClientEvent, + InputAudioBufferClear +); +event_from!( + ConversationItemCreateEvent, + ClientEvent, + ConversationItemCreate +); +event_from!( + ConversationItemTruncateEvent, + ClientEvent, + ConversationItemTruncate +); +event_from!( + ConversationItemDeleteEvent, + ClientEvent, + ConversationItemDelete +); +event_from!(ResponseCreateEvent, ClientEvent, ResponseCreate); +event_from!(ResponseCancelEvent, ClientEvent, ResponseCancel); + +message_from_event!(SessionUpdateEvent, ClientEvent); +message_from_event!(InputAudioBufferAppendEvent, ClientEvent); +message_from_event!(InputAudioBufferCommitEvent, ClientEvent); +message_from_event!(InputAudioBufferClearEvent, ClientEvent); +message_from_event!(ConversationItemCreateEvent, ClientEvent); +message_from_event!(ConversationItemTruncateEvent, ClientEvent); +message_from_event!(ConversationItemDeleteEvent, ClientEvent); +message_from_event!(ResponseCreateEvent, ClientEvent); +message_from_event!(ResponseCancelEvent, 
ClientEvent); + +impl From for ConversationItemCreateEvent { + fn from(value: Item) -> Self { + Self { + event_id: None, + previous_item_id: None, + item: value, + } + } +} diff --git a/async-openai/src/types/realtime/content_part.rs b/async-openai/src/types/realtime/content_part.rs new file mode 100644 index 0000000..eec93ab --- /dev/null +++ b/async-openai/src/types/realtime/content_part.rs @@ -0,0 +1,18 @@ +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Serialize, Deserialize, Clone)] +#[serde(tag = "type")] +pub enum ContentPart { + #[serde(rename = "text")] + Text { + /// The text content + text: String, + }, + #[serde(rename = "audio")] + Audio { + /// Base64-encoded audio data + audio: Option, + /// The transcript of the audio + transcript: String, + }, +} diff --git a/async-openai/src/types/realtime/conversation.rs b/async-openai/src/types/realtime/conversation.rs new file mode 100644 index 0000000..3ea43bd --- /dev/null +++ b/async-openai/src/types/realtime/conversation.rs @@ -0,0 +1,10 @@ +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct Conversation { + /// The unique ID of the conversation. + pub id: String, + + /// The object type, must be "realtime.conversation". + pub object: String, +} diff --git a/async-openai/src/types/realtime/error.rs b/async-openai/src/types/realtime/error.rs new file mode 100644 index 0000000..6ce907c --- /dev/null +++ b/async-openai/src/types/realtime/error.rs @@ -0,0 +1,19 @@ +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct RealtimeAPIError { + /// The type of error (e.g., "invalid_request_error", "server_error"). + pub r#type: String, + + /// Error code, if any. + pub code: Option, + + /// A human-readable error message. + pub message: String, + + /// Parameter related to the error, if any. + pub param: Option, + + /// The event_id of the client event that caused the error, if applicable. 
+ pub event_id: Option, +} diff --git a/async-openai/src/types/realtime/item.rs b/async-openai/src/types/realtime/item.rs new file mode 100644 index 0000000..3af7d0d --- /dev/null +++ b/async-openai/src/types/realtime/item.rs @@ -0,0 +1,99 @@ +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Serialize, Deserialize, Clone)] +#[serde(rename_all = "snake_case")] +pub enum ItemType { + Message, + FunctionCall, + FunctionCallOutput, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +#[serde(rename_all = "snake_case")] +pub enum ItemStatus { + Completed, + InProgress, + Incomplete, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +#[serde(rename_all = "lowercase")] +pub enum ItemRole { + User, + Assistant, + System, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +#[serde(rename_all = "snake_case")] +pub enum ItemContentType { + InputText, + InputAudio, + Text, + Audio, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct ItemContent { + /// The content type ("input_text", "input_audio", "text", "audio"). + pub r#type: ItemContentType, + + /// The text content. + #[serde(skip_serializing_if = "Option::is_none")] + pub text: Option, + + /// Base64-encoded audio bytes. + #[serde(skip_serializing_if = "Option::is_none")] + pub audio: Option, + + /// The transcript of the audio. + #[serde(skip_serializing_if = "Option::is_none")] + pub transcript: Option, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct Item { + /// The unique ID of the item. + #[serde(skip_serializing_if = "Option::is_none")] + pub id: Option, + + /// The type of the item ("message", "function_call", "function_call_output"). + #[serde(skip_serializing_if = "Option::is_none")] + pub r#type: Option, + + /// The status of the item ("completed", "in_progress", "incomplete"). + #[serde(skip_serializing_if = "Option::is_none")] + pub status: Option, + + /// The role of the message sender ("user", "assistant", "system"). 
+ #[serde(skip_serializing_if = "Option::is_none")] + pub role: Option, + + /// The content of the message. + #[serde(skip_serializing_if = "Option::is_none")] + pub content: Option>, + + /// The ID of the function call (for "function_call" items). + #[serde(skip_serializing_if = "Option::is_none")] + pub call_id: Option, + + /// The name of the function being called (for "function_call" items). + #[serde(skip_serializing_if = "Option::is_none")] + pub name: Option, + + /// The arguments of the function call (for "function_call" items). + #[serde(skip_serializing_if = "Option::is_none")] + pub arguments: Option, + + /// The output of the function call (for "function_call_output" items). + #[serde(skip_serializing_if = "Option::is_none")] + pub output: Option, +} + +impl TryFrom for Item { + type Error = serde_json::Error; + + fn try_from(value: serde_json::Value) -> Result { + serde_json::from_value(value) + } +} diff --git a/async-openai/src/types/realtime/mod.rs b/async-openai/src/types/realtime/mod.rs new file mode 100644 index 0000000..b47605f --- /dev/null +++ b/async-openai/src/types/realtime/mod.rs @@ -0,0 +1,19 @@ +mod client_event; +mod content_part; +mod conversation; +mod error; +mod item; +mod rate_limit; +mod response_resource; +mod server_event; +mod session_resource; + +pub use client_event::*; +pub use content_part::*; +pub use conversation::*; +pub use error::*; +pub use item::*; +pub use rate_limit::*; +pub use response_resource::*; +pub use server_event::*; +pub use session_resource::*; diff --git a/async-openai/src/types/realtime/rate_limit.rs b/async-openai/src/types/realtime/rate_limit.rs new file mode 100644 index 0000000..f3fc4aa --- /dev/null +++ b/async-openai/src/types/realtime/rate_limit.rs @@ -0,0 +1,13 @@ +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct RateLimit { + /// The name of the rate limit ("requests", "tokens", "input_tokens", "output_tokens"). 
+ pub name: String, + /// The maximum allowed value for the rate limit. + pub limit: u32, + /// The remaining value before the limit is reached. + pub remaining: u32, + /// Seconds until the rate limit resets. + pub reset_seconds: f32, +} diff --git a/async-openai/src/types/realtime/response_resource.rs b/async-openai/src/types/realtime/response_resource.rs new file mode 100644 index 0000000..4a50089 --- /dev/null +++ b/async-openai/src/types/realtime/response_resource.rs @@ -0,0 +1,59 @@ +use serde::{Deserialize, Serialize}; + +use super::item::Item; + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct Usage { + pub total_tokens: u32, + pub input_tokens: u32, + pub output_tokens: u32, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +#[serde(rename_all = "snake_case")] +pub enum ResponseStatus { + InProgress, + Completed, + Cancelled, + Failed, + Incomplete, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct FailedError { + pub code: String, + pub message: String, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +#[serde(rename_all = "snake_case")] +pub enum IncompleteReason { + Interruption, + MaxOutputTokens, + ContentFilter, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +#[serde(tag = "type")] +pub enum ResponseStatusDetail { + #[serde(rename = "incomplete")] + Incomplete { reason: IncompleteReason }, + #[serde(rename = "failed")] + Failed { error: Option }, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct ResponseResource { + /// The unique ID of the response. + pub id: String, + /// The object type, must be "realtime.response". + pub object: String, + /// The status of the response + pub status: ResponseStatus, + /// Additional details about the status. + pub status_details: Option, + /// The list of output items generated by the response. + pub output: Vec, + /// Usage statistics for the response. 
+ pub usage: Option, +} diff --git a/async-openai/src/types/realtime/server_event.rs b/async-openai/src/types/realtime/server_event.rs new file mode 100644 index 0000000..3ba5f55 --- /dev/null +++ b/async-openai/src/types/realtime/server_event.rs @@ -0,0 +1,459 @@ +use serde::{Deserialize, Serialize}; + +use super::{ + content_part::ContentPart, conversation::Conversation, error::RealtimeAPIError, item::Item, + rate_limit::RateLimit, response_resource::ResponseResource, session_resource::SessionResource, +}; + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct ErrorEvent { + /// The unique ID of the server event. + pub event_id: String, + /// Details of the error. + pub error: RealtimeAPIError, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct SessionCreatedEvent { + /// The unique ID of the server event. + pub event_id: String, + /// The session resource. + pub session: SessionResource, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct SessionUpdatedEvent { + /// The unique ID of the server event. + pub event_id: String, + /// The updated session resource. + pub session: SessionResource, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct ConversationCreatedEvent { + /// The unique ID of the server event. + pub event_id: String, + /// The conversation resource. + pub conversation: Conversation, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct InputAudioBufferCommitedEvent { + /// The unique ID of the server event. + pub event_id: String, + /// The ID of the preceding item after which the new item will be inserted. + pub previous_item_id: String, + /// The ID of the user message item that will be created. + pub item_id: String, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct InputAudioBufferClearedEvent { + /// The unique ID of the server event. 
+ pub event_id: String, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct InputAudioBufferSpeechStartedEvent { + /// The unique ID of the server event. + pub event_id: String, + /// Milliseconds since the session started when speech was detected. + pub audio_start_ms: u32, + /// The ID of the user message item that will be created when speech stops. + pub item_id: String, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct InputAudioBufferSpeechStoppedEvent { + /// The unique ID of the server event. + pub event_id: String, + /// Milliseconds since the session started when speech stopped. + pub audio_end_ms: u32, + /// The ID of the user message item that will be created. + pub item_id: String, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct ConversationItemCreatedEvent { + /// The unique ID of the server event. + pub event_id: String, + /// The ID of the preceding item. + pub previous_item_id: Option, + /// The item that was created. + pub item: Item, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct ConversationItemInputAudioTranscriptionCompletedEvent { + /// The unique ID of the server event. + pub event_id: String, + /// The ID of the user message item. + pub item_id: String, + /// The index of the content part containing the audio. + pub content_index: u32, + /// The transcribed text. + pub transcript: String, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct ConversationItemInputAudioTranscriptionFailedEvent { + /// The unique ID of the server event. + pub event_id: String, + /// The ID of the user message item. + pub item_id: String, + /// The index of the content part containing the audio. + pub content_index: u32, + /// Details of the transcription error. + pub error: RealtimeAPIError, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct ConversationItemTruncatedEvent { + /// The unique ID of the server event. 
+ pub event_id: String, + /// The ID of the assistant message item that was truncated. + pub item_id: String, + /// The index of the content part that was truncated. + pub content_index: u32, + /// The duration up to which the audio was truncated, in milliseconds. + pub audio_end_ms: u32, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct ConversationItemDeletedEvent { + /// The unique ID of the server event. + pub event_id: String, + /// The ID of the item that was deleted. + pub item_id: String, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct ResponseCreatedEvent { + /// The unique ID of the server event. + pub event_id: String, + /// The response resource. + pub response: ResponseResource, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct ResponseDoneEvent { + /// The unique ID of the server event. + pub event_id: String, + /// The response resource. + pub response: ResponseResource, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct ResponseOutputItemAddedEvent { + /// The unique ID of the server event. + pub event_id: String, + /// The ID of the response to which the item belongs. + pub response_id: String, + /// The index of the output item in the response. + pub output_index: u32, + /// The item that was added. + pub item: Item, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct ResponseOutputItemDoneEvent { + /// The unique ID of the server event. + pub event_id: String, + /// The ID of the response to which the item belongs. + pub response_id: String, + /// The index of the output item in the response. + pub output_index: u32, + /// The completed item. + pub item: Item, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct ResponseContentPartAddedEvent { + /// The unique ID of the server event. + pub event_id: String, + /// The ID of the response. + pub response_id: String, + /// The ID of the item to which the content part was added. 
+ pub item_id: String, + /// The index of the output item in the response. + pub output_index: u32, + /// The index of the content part in the item's content array. + pub content_index: u32, + /// The content part that was added. + pub part: ContentPart, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct ResponseContentPartDoneEvent { + /// The unique ID of the server event. + pub event_id: String, + /// The ID of the response. + pub response_id: String, + /// The ID of the item to which the content part was added. + pub item_id: String, + /// The index of the output item in the response. + pub output_index: u32, + /// The index of the content part in the item's content array. + pub content_index: u32, + /// The content part that is done. + pub part: ContentPart, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct ResponseTextDeltaEvent { + /// The unique ID of the server event. + pub event_id: String, + /// The ID of the response. + pub response_id: String, + /// The ID of the item. + pub item_id: String, + /// The index of the output item in the response. + pub output_index: u32, + /// The index of the content part in the item's content array. + pub content_index: u32, + /// The text delta. + pub delta: String, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct ResponseTextDoneEvent { + /// The unique ID of the server event. + pub event_id: String, + /// The ID of the response. + pub response_id: String, + /// The ID of the item. + pub item_id: String, + /// The index of the output item in the response. + pub output_index: u32, + /// The index of the content part in the item's content array. + pub content_index: u32, + /// The final text content. + pub text: String, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct ResponseAudioTranscriptDeltaEvent { + /// The unique ID of the server event. + pub event_id: String, + /// The ID of the response. + pub response_id: String, + /// The ID of the item. 
+ pub item_id: String, + /// The index of the output item in the response. + pub output_index: u32, + /// The index of the content part in the item's content array. + pub content_index: u32, + /// The text delta. + pub delta: String, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct ResponseAudioTranscriptDoneEvent { + /// The unique ID of the server event. + pub event_id: String, + /// The ID of the response. + pub response_id: String, + /// The ID of the item. + pub item_id: String, + /// The index of the output item in the response. + pub output_index: u32, + /// The index of the content part in the item's content array. + pub content_index: u32, + ///The final transcript of the audio. + pub transcript: String, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct ResponseAudioDeltaEvent { + /// The unique ID of the server event. + pub event_id: String, + /// The ID of the response. + pub response_id: String, + /// The ID of the item. + pub item_id: String, + /// The index of the output item in the response. + pub output_index: u32, + /// The index of the content part in the item's content array. + pub content_index: u32, + /// Base64-encoded audio data delta. + pub delta: String, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct ResponseAudioDoneEvent { + /// The unique ID of the server event. + pub event_id: String, + /// The ID of the response. + pub response_id: String, + /// The ID of the item. + pub item_id: String, + /// The index of the output item in the response. + pub output_index: u32, + /// The index of the content part in the item's content array. + pub content_index: u32, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct ResponseFunctionCallArgumentsDeltaEvent { + /// The unique ID of the server event. + pub event_id: String, + /// The ID of the response. + pub response_id: String, + /// The ID of the function call item. 
+ pub item_id: String, + /// The index of the output item in the response. + pub output_index: u32, + /// The ID of the function call. + pub call_id: String, + /// The arguments delta as a JSON string. + pub delta: String, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct ResponseFunctionCallArgumentsDoneEvent { + /// The unique ID of the server event. + pub event_id: String, + /// The ID of the response. + pub response_id: String, + /// The ID of the function call item. + pub item_id: String, + /// The index of the output item in the response. + pub output_index: u32, + /// The ID of the function call. + pub call_id: String, + /// The final arguments as a JSON string. + pub arguments: String, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct RateLimitsUpdatedEvent { + /// The unique ID of the server event. + pub event_id: String, + pub rate_limits: Vec, +} + +/// These are events emitted from the OpenAI Realtime WebSocket server to the client. +#[derive(Debug, Serialize, Deserialize, Clone)] +#[serde(tag = "type")] +pub enum ServerEvent { + /// Returned when an error occurs. + #[serde(rename = "error")] + Error(ErrorEvent), + + /// Returned when a session is created. Emitted automatically when a new connection is established. + #[serde(rename = "session.created")] + SessionCreated(SessionCreatedEvent), + + /// Returned when a session is updated. + #[serde(rename = "session.updated")] + SessionUpdated(SessionUpdatedEvent), + + /// Returned when a conversation is created. Emitted right after session creation. + #[serde(rename = "conversation.created")] + ConversationCreated(ConversationCreatedEvent), + + /// Returned when an input audio buffer is committed, either by the client or automatically in server VAD mode. + #[serde(rename = "input_audio_buffer.committed")] + InputAudioBufferCommited(InputAudioBufferCommitedEvent), + + /// Returned when the input audio buffer is cleared by the client. 
+ #[serde(rename = "input_audio_buffer.cleared")] + InputAudioBufferCleared(InputAudioBufferClearedEvent), + + /// Returned in server turn detection mode when speech is detected. + #[serde(rename = "input_audio_buffer.speech_started")] + InputAudioBufferSpeechStarted(InputAudioBufferSpeechStartedEvent), + + /// Returned in server turn detection mode when speech stops. + #[serde(rename = "input_audio_buffer.speech_stopped")] + InputAudioBufferSpeechStopped(InputAudioBufferSpeechStoppedEvent), + + /// Returned when a conversation item is created. + #[serde(rename = "conversation.item.created")] + ConversationItemCreated(ConversationItemCreatedEvent), + + /// Returned when input audio transcription is enabled and a transcription succeeds. + #[serde(rename = "conversation.item.input_audio_transcription.completed")] + ConversationItemInputAudioTranscriptionCompleted( + ConversationItemInputAudioTranscriptionCompletedEvent, + ), + + /// Returned when input audio transcription is configured, and a transcription request for a user message failed. + #[serde(rename = "conversation.item.input_audio_transcription.failed")] + ConversationItemInputAudioTranscriptionFailed( + ConversationItemInputAudioTranscriptionFailedEvent, + ), + + /// Returned when an earlier assistant audio message item is truncated by the client. + #[serde(rename = "conversation.item.truncated")] + ConversationItemTruncated(ConversationItemTruncatedEvent), + + /// Returned when an item in the conversation is deleted. + #[serde(rename = "conversation.item.deleted")] + ConversationItemDeleted(ConversationItemDeletedEvent), + + /// Returned when a new Response is created. The first event of response creation, where the response is in an initial state of "in_progress". + #[serde(rename = "response.created")] + ResponseCreated(ResponseCreatedEvent), + + /// Returned when a Response is done streaming. Always emitted, no matter the final state. 
+ #[serde(rename = "response.done")] + ResponseDone(ResponseDoneEvent), + + /// Returned when a new Item is created during response generation. + #[serde(rename = "response.output_item.added")] + ResponseOutputItemAdded(ResponseOutputItemAddedEvent), + + /// Returned when an Item is done streaming. Also emitted when a Response is interrupted, incomplete, or cancelled. + #[serde(rename = "response.output_item.done")] + ResponseOutputItemDone(ResponseOutputItemDoneEvent), + + /// Returned when a new content part is added to an assistant message item during response generation. + #[serde(rename = "response.content_part.added")] + ResponseContentPartAdded(ResponseContentPartAddedEvent), + + /// Returned when a content part is done streaming in an assistant message item. + /// Also emitted when a Response is interrupted, incomplete, or cancelled. + #[serde(rename = "response.content_part.done")] + ResponseContentPartDone(ResponseContentPartDoneEvent), + + /// Returned when the text value of a "text" content part is updated. + #[serde(rename = "response.text.delta")] + ResponseTextDelta(ResponseTextDeltaEvent), + + /// Returned when the text value of a "text" content part is done streaming. + /// Also emitted when a Response is interrupted, incomplete, or cancelled. + #[serde(rename = "response.text.done")] + ResponseTextDone(ResponseTextDoneEvent), + + /// Returned when the model-generated transcription of audio output is updated. + #[serde(rename = "response.audio_transcript.delta")] + ResponseAudioTranscriptDelta(ResponseAudioTranscriptDeltaEvent), + + /// Returned when the model-generated transcription of audio output is done streaming. + /// Also emitted when a Response is interrupted, incomplete, or cancelled. + #[serde(rename = "response.audio_transcript.done")] + ResponseAudioTranscriptDone(ResponseAudioTranscriptDoneEvent), + + /// Returned when the model-generated audio is updated. 
+    #[serde(rename = "response.audio.delta")]
+    ResponseAudioDelta(ResponseAudioDeltaEvent),
+
+    /// Returned when the model-generated audio is done.
+    /// Also emitted when a Response is interrupted, incomplete, or cancelled.
+    #[serde(rename = "response.audio.done")]
+    ResponseAudioDone(ResponseAudioDoneEvent),
+
+    /// Returned when the model-generated function call arguments are updated.
+    #[serde(rename = "response.function_call_arguments.delta")]
+    ResponseFunctionCallArgumentsDelta(ResponseFunctionCallArgumentsDeltaEvent),
+
+    /// Returned when the model-generated function call arguments are done streaming.
+    /// Also emitted when a Response is interrupted, incomplete, or cancelled.
+    #[serde(rename = "response.function_call_arguments.done")]
+    ResponseFunctionCallArgumentsDone(ResponseFunctionCallArgumentsDoneEvent),
+
+    /// Emitted after every "response.done" event to indicate the updated rate limits.
+    #[serde(rename = "rate_limits.updated")]
+    RateLimitsUpdated(RateLimitsUpdatedEvent),
+}
diff --git a/async-openai/src/types/realtime/session_resource.rs b/async-openai/src/types/realtime/session_resource.rs
new file mode 100644
index 0000000..d49e7a7
--- /dev/null
+++ b/async-openai/src/types/realtime/session_resource.rs
@@ -0,0 +1,136 @@
+use serde::{Deserialize, Serialize};
+
+#[derive(Debug, Serialize, Deserialize, Clone)]
+pub enum AudioFormat {
+    #[serde(rename = "pcm16")]
+    PCM16,
+    #[serde(rename = "g711_ulaw")]
+    G711ULAW,
+    #[serde(rename = "g711_alaw")]
+    G711ALAW,
+}
+
+#[derive(Debug, Serialize, Deserialize, Clone)]
+pub struct AudioTranscription {
+    /// Whether to enable input audio transcription.
+    pub enabled: bool,
+    /// The model to use for transcription (e.g., "whisper-1").
+    pub model: String,
+}
+
+#[derive(Debug, Serialize, Deserialize, Clone)]
+#[serde(tag = "type")]
+pub enum TurnDetection {
+    /// Type of turn detection, only "server_vad" is currently supported.
+    #[serde(rename = "server_vad")]
+    ServerVAD {
+        /// Activation threshold for VAD (0.0 to 1.0).
+        threshold: f32,
+        /// Amount of audio to include before speech starts (in milliseconds).
+        prefix_padding_ms: u32,
+        /// Duration of silence to detect speech stop (in milliseconds).
+        silence_duration_ms: u32,
+    },
+}
+
+#[derive(Debug, Serialize, Deserialize, Clone)]
+#[serde(untagged)]
+pub enum MaxOutputTokens {
+    Num(u16),
+    #[serde(rename = "inf")]
+    Inf,
+}
+
+#[derive(Debug, Serialize, Deserialize, Clone)]
+#[serde(tag = "type")]
+pub enum ToolDefinition {
+    #[serde(rename = "function")]
+    Function {
+        /// The name of the function.
+        name: String,
+        /// The description of the function.
+        description: String,
+        /// Parameters of the function in JSON Schema.
+        parameters: serde_json::Value,
+    },
+}
+
+#[derive(Debug, Serialize, Deserialize, Clone)]
+#[serde(rename_all = "lowercase")]
+pub enum FunctionType {
+    Function,
+}
+
+#[derive(Debug, Serialize, Deserialize, Clone)]
+#[serde(rename_all = "lowercase")]
+pub enum ToolChoice {
+    Auto,
+    None,
+    Required,
+    #[serde(untagged)]
+    Function {
+        r#type: FunctionType,
+        name: String,
+    },
+}
+
+#[derive(Debug, Serialize, Deserialize, Clone)]
+#[serde(rename_all = "lowercase")]
+pub enum RealtimeVoice {
+    Alloy,
+    Shimmer,
+    Echo,
+}
+
+#[derive(Debug, Serialize, Deserialize, Clone, Default)]
+pub struct SessionResource {
+    /// The default model used for this session.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub model: Option<String>,
+
+    /// The set of modalities the model can respond with. To disable audio, set this to ["text"].
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub modalities: Option<Vec<String>>,
+
+    /// The default system instructions prepended to model calls.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub instructions: Option<String>,
+
+    /// The voice the model uses to respond. Cannot be changed once the model has responded with audio at least once.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub voice: Option<RealtimeVoice>,
+
+    /// The format of input audio. Options are "pcm16", "g711_ulaw", or "g711_alaw".
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub input_audio_format: Option<AudioFormat>,
+
+    /// The format of output audio. Options are "pcm16", "g711_ulaw", or "g711_alaw".
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub output_audio_format: Option<AudioFormat>,
+
+    /// Configuration for input audio transcription. Can be set to null to turn off.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub input_audio_transcription: Option<AudioTranscription>,
+
+    /// Configuration for turn detection. Can be set to null to turn off.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub turn_detection: Option<TurnDetection>,
+
+    /// Tools (functions) available to the model.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub tools: Option<Vec<ToolDefinition>>,
+
+    #[serde(skip_serializing_if = "Option::is_none")]
+    /// How the model chooses tools.
+    pub tool_choice: Option<ToolChoice>,
+
+    #[serde(skip_serializing_if = "Option::is_none")]
+    /// Sampling temperature for the model.
+    pub temperature: Option<f32>,
+
+    /// Maximum number of output tokens for a single assistant response, inclusive of tool calls.
+    /// Provide an integer between 1 and 4096 to limit output tokens, or "inf" for the maximum available tokens for a given model.
+    /// Defaults to "inf".
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub max_output_tokens: Option<MaxOutputTokens>,
+}
diff --git a/async-openai/src/types/vector_store.rs b/async-openai/src/types/vector_store.rs
index 2433887..3d32493 100644
--- a/async-openai/src/types/vector_store.rs
+++ b/async-openai/src/types/vector_store.rs
@@ -40,7 +40,6 @@ pub enum VectorStoreChunkingStrategy {
     /// The default strategy. This strategy currently uses a `max_chunk_size_tokens` of `800` and `chunk_overlap_tokens` of `400`.
#[default] Auto, - /// Static(StaticChunkingStrategy), } diff --git a/examples/realtime/Cargo.toml b/examples/realtime/Cargo.toml new file mode 100644 index 0000000..4f9d978 --- /dev/null +++ b/examples/realtime/Cargo.toml @@ -0,0 +1,18 @@ +[package] +name = "realtime" +version = "0.1.0" +edition = "2021" + +[dependencies] +async-openai = { path = "../../async-openai", features = ["realtime"] } +futures-channel = "0.3.31" +futures-util = { version = "0.3.31", features = ["sink", "std"] } +serde = { version = "1.0.210", features = ["derive"] } +serde_json = "1.0.128" +tokio = { version = "1.40.0", features = [ + "io-std", + "io-util", + "macros", + "rt-multi-thread", +] } +tokio-tungstenite = { version = "0.24.0", features = ["connect", "native-tls"] } diff --git a/examples/realtime/README.md b/examples/realtime/README.md new file mode 100644 index 0000000..8967038 --- /dev/null +++ b/examples/realtime/README.md @@ -0,0 +1,138 @@ +## Overview + +The example takes input from stdin, for every input two client events are sent: +1. "conversation.item.create" with content of type "input_text" +2. "response.create" + +All the output happens on stderr, so conversation can be continued on stdin. To stop type "quit" and press enter. + +Code is based on https://github.com/snapview/tokio-tungstenite/blob/master/examples/client.rs + +## Sample Output + +``` +WebSocket handshake complete +session.created | +age of sun? +conversation.item.created | +response.created | +response.output_item.added | +conversation.item.created | +response.content_part.added | +response.audio_transcript.delta | The +response.audio_transcript.delta | Sun +response.audio_transcript.delta | is +response.audio_transcript.delta | about +response.audio_transcript.delta | +response.audio_transcript.delta | 4 +response.audio_transcript.delta | . 
+response.audio.delta | +response.audio.delta | +response.audio_transcript.delta | 6 +response.audio.delta | +response.audio_transcript.delta | billion +response.audio.delta | +response.audio.delta | +response.audio_transcript.delta | years +response.audio_transcript.delta | old +response.audio_transcript.delta | . +response.audio_transcript.delta | It +response.audio.delta | +response.audio.delta | +response.audio_transcript.delta | formed +response.audio.delta | +response.audio_transcript.delta | from +response.audio_transcript.delta | the +response.audio_transcript.delta | gravitational +response.audio_transcript.delta | collapse +response.audio_transcript.delta | of +response.audio_transcript.delta | a +response.audio.delta | +response.audio.delta | +response.audio.delta | +response.audio.delta | +response.audio.delta | +response.audio.delta | +response.audio.delta | +response.audio.delta | +response.audio_transcript.delta | region +response.audio.delta | +response.audio_transcript.delta | within +response.audio_transcript.delta | a +response.audio_transcript.delta | large +response.audio_transcript.delta | molecular +response.audio_transcript.delta | cloud +response.audio_transcript.delta | . 
+response.audio_transcript.delta | It +response.audio.delta | +response.audio.delta | +response.audio.delta | +response.audio.delta | +response.audio.delta | +response.audio.delta | +response.audio.delta | +response.audio.delta | +response.audio_transcript.delta | 's +response.audio.delta | +response.audio_transcript.delta | currently +response.audio_transcript.delta | in +response.audio_transcript.delta | the +response.audio_transcript.delta | middle +response.audio_transcript.delta | of +response.audio_transcript.delta | its +response.audio_transcript.delta | life +response.audio_transcript.delta | cycle +response.audio_transcript.delta | , +response.audio.delta | +response.audio.delta | +response.audio.delta | +response.audio.delta | +response.audio.delta | +response.audio.delta | +response.audio.delta | +response.audio.delta | +response.audio_transcript.delta | expected +response.audio.delta | +response.audio_transcript.delta | to +response.audio_transcript.delta | last +response.audio_transcript.delta | for +response.audio.delta | +response.audio.delta | +response.audio.delta | +response.audio.delta | +response.audio.delta | +response.audio.delta | +response.audio.delta | +response.audio.delta | +response.audio_transcript.delta | another +response.audio.delta | +response.audio_transcript.delta | +response.audio_transcript.delta | 5 +response.audio_transcript.delta | billion +response.audio_transcript.delta | years +response.audio_transcript.delta | or +response.audio_transcript.delta | so +response.audio_transcript.delta | . 
+response.audio.delta | 
+response.audio.delta | 
+response.audio.delta | 
+response.audio.delta | 
+response.audio.delta | 
+response.audio.delta | 
+response.audio.delta | 
+response.audio.delta | 
+response.audio.delta | 
+response.audio.delta | 
+response.audio.delta | 
+response.audio.delta | 
+response.audio.delta | 
+response.audio.delta | 
+response.audio.done | 
+response.audio_transcript.done | 
+response.content_part.done | 
+response.output_item.done | [Some(Assistant)]: The Sun is about 4.6 billion years old. It formed from the gravitational collapse of a region within a large molecular cloud. It's currently in the middle of its life cycle, expected to last for another 5 billion years or so.
+
+response.done | 
+rate_limits.updated | 
+quit
+```
diff --git a/examples/realtime/src/main.rs b/examples/realtime/src/main.rs
new file mode 100644
index 0000000..ae94e4a
--- /dev/null
+++ b/examples/realtime/src/main.rs
@@ -0,0 +1,148 @@
+use std::process::exit;
+
+use async_openai::types::realtime::{
+    ConversationItemCreateEvent, Item, ResponseCreateEvent, ServerEvent,
+};
+use futures_util::{future, pin_mut, StreamExt};
+
+use tokio::io::AsyncReadExt;
+use tokio_tungstenite::{
+    connect_async,
+    tungstenite::{client::IntoClientRequest, protocol::Message},
+};
+
+#[tokio::main]
+async fn main() {
+    let url = "wss://api.openai.com/v1/realtime?model=gpt-4o-realtime-preview-2024-10-01";
+    let api_key = std::env::var("OPENAI_API_KEY").expect("Please provide OPENAI_API_KEY env var");
+
+    let (stdin_tx, stdin_rx) = futures_channel::mpsc::unbounded();
+    tokio::spawn(read_stdin(stdin_tx));
+
+    // create request from url and add required headers
+    let mut request = url.into_client_request().unwrap();
+    request.headers_mut().insert(
+        "Authorization",
+        format!("Bearer {api_key}").parse().unwrap(),
+    );
+    request
+        .headers_mut()
+        .insert("OpenAI-Beta", "realtime=v1".parse().unwrap());
+
+    // connect to WebSocket endpoint
+    let (ws_stream, _) =
connect_async(request).await.expect("Failed to connect"); + + // output everything to stderr, for rest of the program stdin is used to send items of type "input_text" + eprintln!("WebSocket handshake complete"); + + let (write, read) = ws_stream.split(); + + let stdin_to_ws = stdin_rx.map(Ok).forward(write); + + let ws_to_stdout = { + read.for_each(|message| async { + let message = message.unwrap(); + + match message { + Message::Text(_) => { + let data = message.clone().into_data(); + let server_event: Result = + serde_json::from_slice(&data); + match server_event { + Ok(server_event) => { + let value = serde_json::to_value(&server_event).unwrap(); + let event_type = value["type"].clone(); + + eprint!("{:32} | ", event_type.as_str().unwrap()); + + match server_event { + ServerEvent::ResponseOutputItemDone(event) => { + event.item.content.unwrap_or(vec![]).iter().for_each( + |content| { + if let Some(ref transcript) = content.transcript { + eprintln!( + "[{:?}]: {}", + event.item.role, + transcript.trim(), + ); + } + }, + ); + } + ServerEvent::ResponseAudioTranscriptDelta(event) => { + eprint!("{}", event.delta.trim()); + } + ServerEvent::Error(e) => { + eprint!("{e:?}"); + } + _ => {} + } + } + Err(error) => { + eprintln!("failed to deserialize: {error:?}"); + eprintln!("{message:?}"); + } + } + } + Message::Binary(_) => eprintln!("Binary"), + Message::Frame(_) => eprintln!("Frame"), + Message::Ping(_) => eprintln!("Ping"), + Message::Pong(_) => eprintln!("Pong"), + Message::Close(_) => { + eprintln!("Close"); + exit(0); + } + } + + // after every message add newline + eprint!("\n"); + }) + }; + + pin_mut!(stdin_to_ws, ws_to_stdout); + future::select(stdin_to_ws, ws_to_stdout).await; +} + +// Read from stdin and send "conversation.item.create" and "response.create" client events. 
+// type "quit" to stop
+async fn read_stdin(tx: futures_channel::mpsc::UnboundedSender<Message>) {
+    let mut stdin = tokio::io::stdin();
+    loop {
+        let mut buf = vec![0; 1024];
+        let n = match stdin.read(&mut buf).await {
+            Err(_) | Ok(0) => break,
+            Ok(n) => n,
+        };
+        buf.truncate(n);
+
+        let text = String::from_utf8_lossy(&buf).into_owned();
+
+        if text.trim() == "quit" {
+            tx.close_channel();
+            return;
+        }
+
+        // Create item from json representation
+        let item = Item::try_from(serde_json::json!({
+            "type": "message",
+            "role": "user",
+            "content": [
+                {
+                    "type": "input_text",
+                    "text": String::from_utf8_lossy(&buf).into_owned()
+                }
+            ]
+        }))
+        .unwrap();
+
+        // Create event of type "conversation.item.create"
+        let event: ConversationItemCreateEvent = item.into();
+        // Create WebSocket message from client event
+        let message: Message = event.into();
+        // send WebSocket message containing event of type "conversation.item.create" to server
+        tx.unbounded_send(message).unwrap();
+        // send WebSocket message containing event of type "response.create" to server
+        tx.unbounded_send(ResponseCreateEvent::default().into())
+            .unwrap();
+    }
+}