Skip to content

Commit

Permalink
🐛 (Deserialization) Fixed type of deserialized CreateChatCompletionRe…
Browse files Browse the repository at this point in the history
…quest all SystemMessage

Turns out we dont even need role in the child. SIMPLIFY REQUIREMENTS. also just use serde tag, it handles the serilization for us too.
thanks coca.codes

Closes 64bit#216
  • Loading branch information
djmango committed May 11, 2024
1 parent d8dc8c7 commit 6349b78
Showing 1 changed file with 29 additions and 17 deletions.
46 changes: 29 additions & 17 deletions async-openai/src/types/chat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::{collections::HashMap, pin::Pin};

use derive_builder::Builder;
use futures::Stream;
use serde::{Deserialize, Serialize};
use serde::{Deserialize, Deserializer, Serialize};

use crate::error::OpenAIError;

Expand Down Expand Up @@ -104,9 +104,6 @@ pub struct CompletionUsage {
pub struct ChatCompletionRequestSystemMessage {
/// The contents of the system message.
pub content: String,
/// The role of the messages author, in this case `system`.
#[builder(default = "Role::System")]
pub role: Role,
/// An optional name for the participant. Provides the model information to differentiate between participants of the same role.
#[serde(skip_serializing_if = "Option::is_none")]
pub name: Option<String>,
Expand Down Expand Up @@ -185,9 +182,6 @@ pub enum ChatCompletionRequestUserMessageContent {
pub struct ChatCompletionRequestUserMessage {
/// The contents of the user message.
pub content: ChatCompletionRequestUserMessageContent,
/// The role of the messages author, in this case `user`.
#[builder(default = "Role::User")]
pub role: Role,
/// An optional name for the participant. Provides the model information to differentiate between participants of the same role.
#[serde(skip_serializing_if = "Option::is_none")]
pub name: Option<String>,
Expand All @@ -202,9 +196,6 @@ pub struct ChatCompletionRequestUserMessage {
pub struct ChatCompletionRequestAssistantMessage {
/// The contents of the assistant message.
pub content: Option<String>,
/// The role of the messages author, in this case `assistant`.
#[builder(default = "Role::Assistant")]
pub role: Role,
/// An optional name for the participant. Provides the model information to differentiate between participants of the same role.
#[serde(skip_serializing_if = "Option::is_none")]
pub name: Option<String>,
Expand All @@ -224,9 +215,6 @@ pub struct ChatCompletionRequestAssistantMessage {
#[builder(derive(Debug))]
#[builder(build_fn(error = "OpenAIError"))]
pub struct ChatCompletionRequestToolMessage {
/// The role of the messages author, in this case `tool`.
#[builder(default = "Role::Tool")]
pub role: Role,
/// The contents of the tool message.
pub content: String,
pub tool_call_id: String,
Expand All @@ -239,17 +227,14 @@ pub struct ChatCompletionRequestToolMessage {
#[builder(derive(Debug))]
#[builder(build_fn(error = "OpenAIError"))]
pub struct ChatCompletionRequestFunctionMessage {
/// The role of the messages author, in this case `function`.
#[builder(default = "Role::Function")]
pub role: Role,
/// The return value from the function call, to return to the model.
pub content: Option<String>,
/// The name of the function to call.
pub name: String,
}

#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
#[serde(untagged)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum ChatCompletionRequestMessage {
System(ChatCompletionRequestSystemMessage),
User(ChatCompletionRequestUserMessage),
Expand All @@ -258,6 +243,33 @@ pub enum ChatCompletionRequestMessage {
Function(ChatCompletionRequestFunctionMessage),
}

impl ChatCompletionRequestMessage {
pub fn role(&self) -> Role {
match self {
ChatCompletionRequestMessage::System(_) => Role::System,
ChatCompletionRequestMessage::User(_) => Role::User,
ChatCompletionRequestMessage::Assistant(_) => Role::Assistant,
ChatCompletionRequestMessage::Tool(_) => Role::Tool,
ChatCompletionRequestMessage::Function(_) => Role::Function,
}
}
}

// #[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
// #[serde(tag = "role")]
// pub enum ChatCompletionRequestMessage {
// #[serde(rename = "System")]
// System(ChatCompletionRequestSystemMessage),
// #[serde(rename = "User")]
// User(ChatCompletionRequestUserMessage),
// #[serde(rename = "Assistant")]
// Assistant(ChatCompletionRequestAssistantMessage),
// #[serde(rename = "Tool")]
// Tool(ChatCompletionRequestToolMessage),
// #[serde(rename = "Function")]
// Function(ChatCompletionRequestFunctionMessage),
// }

#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
pub struct ChatCompletionMessageToolCall {
/// The ID of the tool call.
Expand Down

0 comments on commit 6349b78

Please sign in to comment.