The type of messages in deserialized CreateChatCompletionRequest are all SystemMessage #216

Closed
sontallive opened this issue Apr 21, 2024 · 5 comments · Fixed by #228
Labels
bug Something isn't working

Comments

@sontallive

I want to deserialize request JSON into a CreateChatCompletionRequest, but after deserialization every message ends up as the System variant.

code

use async_openai::types::{
    ChatCompletionRequestSystemMessageArgs, ChatCompletionRequestUserMessageArgs,
    CreateChatCompletionRequest, CreateChatCompletionRequestArgs,
};

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let request: CreateChatCompletionRequest = CreateChatCompletionRequestArgs::default()
        .messages([
            ChatCompletionRequestSystemMessageArgs::default()
                .content("your are a calculator")
                .build()?
                .into(),
            ChatCompletionRequestUserMessageArgs::default()
                .content("what is the result of 1+1")
                .build()?
                .into(),
        ])
        .build()?;
    // serialize the request
    let serialized = serde_json::to_string(&request)?;
    println!("{}", serialized);
    // deserialize the request
    let deserialized: CreateChatCompletionRequest = serde_json::from_str(&serialized)?;
    println!("{:?}", deserialized);
    Ok(())
}

result

{"messages":[{"content":"your are a calculator","role":"system"},{"content":"what is the result of 1+1","role":"user"}],"model":""}

CreateChatCompletionRequest { messages: [System(ChatCompletionRequestSystemMessage { content: "your are a calculator", role: System, name: None }), System(ChatCompletionRequestSystemMessage { content: "what is the result of 1+1", role: User, name: None })], model: "", frequency_penalty: None, logit_bias: None, logprobs: None, top_logprobs: None, max_tokens: None, n: None, presence_penalty: None, response_format: None, seed: None, stop: None, stream: None, temperature: None, top_p: None, tools: None, tool_choice: None, user: None, function_call: None, functions: None }
@sontallive sontallive changed the title the type of messages in CreateChatCompletionRequest deserialized are all SystemMessage the type of messages in deserialized CreateChatCompletionRequest are all SystemMessage Apr 21, 2024
@sontallive sontallive changed the title the type of messages in deserialized CreateChatCompletionRequest are all SystemMessage The type of messages in deserialized CreateChatCompletionRequest are all SystemMessage Apr 21, 2024
@64bit 64bit added the bug Something isn't working label Apr 28, 2024
@djmango

djmango commented May 11, 2024

I also have this issue. Using actix_web

djmango added a commit to djmango/async-openai that referenced this issue May 11, 2024
…quest all SystemMessage

Turns out we don't even need role in the child. SIMPLIFY REQUIREMENTS. Also just use the serde tag; it handles the serialization for us too.
thanks coca.codes

Closes 64bit#216
@djmango

djmango commented May 11, 2024

Was banging my head on this for a bit, but just pushed a fix on my branch.

thanks coco.codes from the NAMTAO discord!

To solve the parent issue of the messages always being System, we add the attribute #[serde(tag = "role", rename_all = "lowercase")] to ChatCompletionRequestMessage.

This maps the role key to the appropriate variant of ChatCompletionRequestMessage. However, what tripped me up was that doing so consumes the role key, so the child ChatCompletionRequestUserMessage fails during deserialization because it can no longer see the role key.

I solved this by deleting the role field in the child and implementing it in the parent as a method that matches on the enum variant (not even really needed; it turns out the role is not actually used anywhere in the lib or in my codebase).

I've verified this works in prod across a bunch of different model providers. I'm happy with this solution, though I don't know if it will be merged. You're free to merge from my fork if you like.

@sontallive
Author

Thank you, I will have a try.

@digitalscyther

digitalscyther commented May 29, 2024

I wrote a custom wrapper for serialization and deserialization:

use async_openai::types::{ChatCompletionRequestMessage};
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use serde::ser::SerializeStruct;
use serde_json::Value;

#[derive(Debug)]
pub struct Message(ChatCompletionRequestMessage);

impl Message {
    pub fn from_original(enum_val: ChatCompletionRequestMessage) -> Self {
        Message(enum_val)
    }

    pub fn into_original(self) -> ChatCompletionRequestMessage {
        self.0
    }
}

impl Serialize for Message {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        // Emit `{"type": <role>, "content": <inner message>}`.
        // The inner message is serialized directly, so any failure is
        // returned through `?` instead of panicking via `unwrap()`.
        let mut state = serializer.serialize_struct("Message", 2)?;
        match &self.0 {
            ChatCompletionRequestMessage::System(msg) => {
                state.serialize_field("type", "system")?;
                state.serialize_field("content", msg)?;
            }
            ChatCompletionRequestMessage::User(msg) => {
                state.serialize_field("type", "user")?;
                state.serialize_field("content", msg)?;
            }
            ChatCompletionRequestMessage::Assistant(msg) => {
                state.serialize_field("type", "assistant")?;
                state.serialize_field("content", msg)?;
            }
            ChatCompletionRequestMessage::Tool(msg) => {
                state.serialize_field("type", "tool")?;
                state.serialize_field("content", msg)?;
            }
            ChatCompletionRequestMessage::Function(msg) => {
                state.serialize_field("type", "function")?;
                state.serialize_field("content", msg)?;
            }
        }

        state.end()
    }
}

impl<'de> Deserialize<'de> for Message {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        // Buffer the whole object so the `type` field can be inspected first.
        let value: Value = Deserialize::deserialize(deserializer)?;

        let msg_type = value.get("type").and_then(Value::as_str).ok_or_else(|| {
            serde::de::Error::custom("Missing or invalid `type` field")
        })?;

        // A missing `content` key yields `Value::Null`, which then fails below
        // with a regular deserialization error rather than a panic.
        let content = value["content"].clone();

        match msg_type {
            "system" => {
                let msg = serde_json::from_value(content).map_err(serde::de::Error::custom)?;
                Ok(Message(ChatCompletionRequestMessage::System(msg)))
            }
            "user" => {
                let msg = serde_json::from_value(content).map_err(serde::de::Error::custom)?;
                Ok(Message(ChatCompletionRequestMessage::User(msg)))
            }
            "assistant" => {
                let msg = serde_json::from_value(content).map_err(serde::de::Error::custom)?;
                Ok(Message(ChatCompletionRequestMessage::Assistant(msg)))
            }
            "tool" => {
                let msg = serde_json::from_value(content).map_err(serde::de::Error::custom)?;
                Ok(Message(ChatCompletionRequestMessage::Tool(msg)))
            }
            "function" => {
                let msg = serde_json::from_value(content).map_err(serde::de::Error::custom)?;
                Ok(Message(ChatCompletionRequestMessage::Function(msg)))
            }
            _ => Err(serde::de::Error::unknown_variant(msg_type, &["system", "user", "assistant", "tool", "function"])),
        }
    }
}

@64bit
Owner

64bit commented Jun 5, 2024

Instead of complex ser-de implementations, the types have been updated for proper serialization and deserialization in v0.23.0.

Thank you @sontallive for contributing the test too. It's included as part of the tests in https://github.com/64bit/async-openai/blob/main/async-openai/tests/ser_de.rs
