Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AI DSL : Generic serialization support for Sealed classes and other primitives #602

Merged
merged 7 commits into from
Dec 26, 2023

Conversation

raulraja
Copy link
Contributor

@raulraja raulraja commented Dec 24, 2023

This Pull Request introduces support for serializing sealed classes in Kotlin Serialization. It does so by mapping function names to the serial names of descriptors and integrating them into the JSON object processed by Kotlin Serialization. This enhancement enables the sealed class serializer to correctly identify and handle the appropriate class.

Here's an example demonstrating the support for sealed classes:

package com.xebia.functional.xef.dsl

import com.xebia.functional.xef.AI
import kotlinx.serialization.Serializable

@Serializable
sealed class Response {

  @Serializable
  data class City(val name: String) : Response()

  @Serializable
  data class Country(val name: String) : Response()

  @Serializable
  data class Continent(val name: String) : Response()
}

suspend fun main() {
  val response = AI<Response>("Capital of France?")
  println(response) // This will print: City(name=Paris)
}

In this example, the Response sealed class has three subclasses: City, Country, and Continent. Depending on the question asked, the appropriate subclass is instantiated and serialized.

Additionally, the PR extends support to primitives and other types that are not wrapped in a JsonObject. Here's an example:

suspend fun main() {
  val two: Int = AI("What is 1 + 1?")
  val truth: Boolean = AI("Is the earth flat?")
  val name: String = AI("Hi AI, What is your name?")
  println(
    """
    |two: $two
    |truth: $truth
    |name: $name
    """.trimMargin()
  )
}

In this second example, simple types like Int, Boolean, and String are used. The AI function returns the appropriate type based on the question asked, demonstrating the flexibility and utility of the new serialization support.

?: error("No descriptor found for ${call.functionName}")
val newJson =
JsonObject(
jsonWithDiscriminator.jsonObject + ("type" to JsonPrimitive(descriptor.serialName))
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This addition supports polymorphic types like sealed classes whose serializer is chosen by the type property, the discriminator.

@@ -92,43 +95,57 @@ suspend fun <A> ChatApi.prompt(
requestedMemories,
prompt.configuration.messagePolicy.addMessagesToConversation
)
.mapNotNull { it.message.toolCalls?.firstOrNull()?.function?.arguments }
.mapNotNull {
val functionName = it.message.toolCalls?.firstOrNull()?.function?.name
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We also need the name to find the correct descriptor from the cases of the sealed class.

conversation.metric.parameter("functions", function?.let { "yes" } ?: "no")
conversation.metric.parameter(
"functions",
if (functions.isEmpty()) "no" else functions.joinToString(",") { it.name }
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of tracking yes we track the actual function names

*/
@JvmInline
@Serializable
value class ChatCompletionToolChoiceOption(val element: JsonElement) {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a trick to support the spec supporting either String or a complex json object that specifies a function name.

// adds a `type` field with the call.functionName serial name equivalent to the call arguments
val jsonWithDiscriminator = Json.decodeFromString(JsonElement.serializer(), call.arguments)
val descriptor =
descriptors.firstOrNull { it.serialName.endsWith(call.functionName) }
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This can be an issue because we use simple function names without FQN. Sadly, sending the full package name to Open AI is not a good idea because selecting the proper function makes it more challenging. We bail if the function is not found in the cases, forcing the user to reencode it or provide @SerialName.

Montagon
Montagon previously approved these changes Dec 26, 2023
Copy link
Contributor

@Montagon Montagon left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's great @raulraja!! 🔝

@raulraja raulraja merged commit bdc12b6 into main Dec 26, 2023
6 checks passed
@raulraja raulraja deleted the choices branch December 26, 2023 17:05
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants