Skip to content

Commit

Permalink
fix: make dynamic enums work as outputs in Ruby (#972)
Browse files Browse the repository at this point in the history
Sorbet's type enforcement was breaking dynamic enum outputs, so as a workaround, all BAML
enums become `enum | string` (`T.any(Baml::Types::<Enum>, String)`) in the Sorbet type system.
  • Loading branch information
sxlijin committed Sep 20, 2024
1 parent 02b495d commit 7530402
Show file tree
Hide file tree
Showing 14 changed files with 325 additions and 35 deletions.
2 changes: 1 addition & 1 deletion engine/language_client_codegen/src/ruby/field_type.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ impl ToRuby for FieldType {
fn to_ruby(&self) -> String {
match self {
FieldType::Class(name) => format!("Baml::Types::{}", name.clone()),
FieldType::Enum(name) => format!("Baml::Types::{}", name.clone()),
FieldType::Enum(name) => format!("T.any(Baml::Types::{}, String)", name.clone()),
// https://sorbet.org/docs/stdlib-generics
FieldType::List(inner) => format!("T::Array[{}]", inner.to_ruby()),
FieldType::Map(key, value) => {
Expand Down
10 changes: 10 additions & 0 deletions integ-tests/baml_src/test-files/dynamic/dynamic.baml
Original file line number Diff line number Diff line change
Expand Up @@ -81,3 +81,13 @@ function MyFunc(input: string) -> DynamicOutput {
"#
}

// Integration-test function whose return type is the enum DynEnumTwo.
// Sends `input` to the GPT35 client together with the rendered output format
// (`ctx.output_format`) so the reply can be parsed into a DynEnumTwo value.
function ClassifyDynEnumTwo(input: string) -> DynEnumTwo {
client GPT35
prompt #"
Given a string, extract info using the schema:

{{ input}}

{{ ctx.output_format }}
"#
}
26 changes: 26 additions & 0 deletions integ-tests/openapi/baml_client/openapi.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,19 @@ paths:
title: AudioInputResponse
type: string
operationId: AudioInput
# POST /call/ClassifyDynEnumTwo
# Request body: the shared `ClassifyDynEnumTwo` request-body component.
# 200 response: application/json matching the `DynEnumTwo` schema component.
/call/ClassifyDynEnumTwo:
post:
requestBody:
$ref: '#/components/requestBodies/ClassifyDynEnumTwo'
responses:
'200':
description: Successful operation
content:
application/json:
schema:
title: ClassifyDynEnumTwoResponse
$ref: '#/components/schemas/DynEnumTwo'
operationId: ClassifyDynEnumTwo
/call/ClassifyMessage:
post:
requestBody:
Expand Down Expand Up @@ -1073,6 +1086,19 @@ components:
required:
- aud
additionalProperties: false
# Request body for the ClassifyDynEnumTwo operation: a mandatory JSON object
# with exactly one required string property, `input`; no other properties
# are accepted.
ClassifyDynEnumTwo:
required: true
content:
application/json:
schema:
title: ClassifyDynEnumTwoRequest
type: object
properties:
input:
type: string
required:
- input
additionalProperties: false
ClassifyMessage:
required: true
content:
Expand Down
57 changes: 57 additions & 0 deletions integ-tests/python/baml_client/async_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,30 @@ async def AudioInput(
mdl = create_model("AudioInputReturnType", inner=(str, ...))
return coerce(mdl, raw.parsed())

async def ClassifyDynEnumTwo(
    self,
    input: str,
    baml_options: BamlCallOptions = {},
) -> Union[types.DynEnumTwo, str]:
    """Invoke the BAML function ``ClassifyDynEnumTwo`` and return its coerced result.

    Args:
        input: Forwarded to the runtime as the ``input`` argument.
        baml_options: Per-call options. The optional ``"tb"`` and
            ``"client_registry"`` entries are read and passed to the runtime.

    Returns:
        The parsed runtime result, coerced to ``Union[types.DynEnumTwo, str]``.
    """
    type_builder = baml_options.get("tb", None)
    # The runtime is handed the builder's `_tb` attribute, not the builder itself.
    tb = None if type_builder is None else type_builder._tb
    client_registry = baml_options.get("client_registry", None)

    call_args = {"input": input}
    raw = await self.__runtime.call_function(
        "ClassifyDynEnumTwo",
        call_args,
        self.__ctx_manager.get(),
        tb,
        client_registry,
    )

    # Wrap the return type in a one-field model so `coerce` can validate it.
    return_type = Union[types.DynEnumTwo, str]
    mdl = create_model("ClassifyDynEnumTwoReturnType", inner=(return_type, ...))
    return coerce(mdl, raw.parsed())

async def ClassifyMessage(
self,
input: str,
Expand Down Expand Up @@ -1984,6 +2008,39 @@ def AudioInput(
self.__ctx_manager.get(),
)

def ClassifyDynEnumTwo(
    self,
    input: str,
    baml_options: BamlCallOptions = {},
) -> baml_py.BamlStream[Optional[Union[types.DynEnumTwo, str]], Union[types.DynEnumTwo, str]]:
    """Start a streaming call of the BAML function ``ClassifyDynEnumTwo``.

    Args:
        input: Forwarded to the runtime as the ``input`` argument.
        baml_options: Per-call options. The optional ``"tb"`` and
            ``"client_registry"`` entries are read and passed to the runtime.

    Returns:
        A ``baml_py.BamlStream`` built from the raw runtime stream plus two
        coercion callbacks: one for the partial type
        (``Optional[Union[types.DynEnumTwo, str]]``) and one for the final
        type (``Union[types.DynEnumTwo, str]``).
    """
    type_builder = baml_options.get("tb", None)
    # The runtime is handed the builder's `_tb` attribute, not the builder itself.
    tb = None if type_builder is None else type_builder._tb
    client_registry = baml_options.get("client_registry", None)

    call_args = {"input": input}
    raw = self.__runtime.stream_function(
        "ClassifyDynEnumTwo",
        call_args,
        None,  # NOTE(review): meaning of this positional slot is not visible here
        self.__ctx_manager.get(),
        tb,
        client_registry,
    )

    final_type = Union[types.DynEnumTwo, str]
    partial_type = Optional[final_type]
    mdl = create_model("ClassifyDynEnumTwoReturnType", inner=(final_type, ...))
    partial_mdl = create_model("ClassifyDynEnumTwoPartialReturnType", inner=(partial_type, ...))

    return baml_py.BamlStream[partial_type, final_type](
        raw,
        lambda chunk: coerce(partial_mdl, chunk),
        lambda result: coerce(mdl, result),
        self.__ctx_manager.get(),
    )

def ClassifyMessage(
self,
input: str,
Expand Down
2 changes: 1 addition & 1 deletion integ-tests/python/baml_client/inlinedbaml.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
"test-files/comments/comments.baml": "// add some functions, classes, enums etc with comments all over.",
"test-files/descriptions/descriptions.baml": "\nclass Nested {\n prop3 string | null @description(#\"\n write \"three\"\n \"#)\n prop4 string | null @description(#\"\n write \"four\"\n \"#) @alias(\"blah\")\n prop20 Nested2\n}\n\nclass Nested2 {\n prop11 string | null @description(#\"\n write \"three\"\n \"#)\n prop12 string | null @description(#\"\n write \"four\"\n \"#) @alias(\"blah\")\n}\n\nclass Schema {\n prop1 string | null @description(#\"\n write \"one\"\n \"#)\n prop2 Nested | string @description(#\"\n write \"two\"\n \"#)\n prop5 (string | null)[] @description(#\"\n write \"hi\"\n \"#)\n prop6 string | Nested[] @alias(\"blah\") @description(#\"\n write the string \"blah\" regardless of the other types here\n \"#)\n nested_attrs (string | null | Nested)[] @description(#\"\n write the string \"nested\" regardless of other types\n \"#)\n parens (string | null) @description(#\"\n write \"parens1\"\n \"#)\n other_group (string | (int | string)) @description(#\"\n write \"other\"\n \"#) @alias(other)\n}\n\n\nfunction SchemaDescriptions(input: string) -> Schema {\n client GPT4o\n prompt #\"\n Return a schema with this format:\n\n {{ctx.output_format}}\n \"#\n}",
"test-files/dynamic/client-registry.baml": "// Intentionally use a bad key\nclient<llm> BadClient {\n provider openai\n options {\n model \"gpt-3.5-turbo\"\n api_key \"sk-invalid\"\n }\n}\n\nfunction ExpectFailure() -> string {\n client BadClient\n\n prompt #\"\n What is the capital of England?\n \"#\n}\n",
"test-files/dynamic/dynamic.baml": "class DynamicClassOne {\n @@dynamic\n}\n\nenum DynEnumOne {\n @@dynamic\n}\n\nenum DynEnumTwo {\n @@dynamic\n}\n\nclass SomeClassNestedDynamic {\n hi string\n @@dynamic\n\n}\n\nclass DynamicClassTwo {\n hi string\n some_class SomeClassNestedDynamic\n status DynEnumOne\n @@dynamic\n}\n\nfunction DynamicFunc(input: DynamicClassOne) -> DynamicClassTwo {\n client GPT35\n prompt #\"\n Please extract the schema from \n {{ input }}\n\n {{ ctx.output_format }}\n \"#\n}\n\nclass DynInputOutput {\n testKey string\n @@dynamic\n}\n\nfunction DynamicInputOutput(input: DynInputOutput) -> DynInputOutput {\n client GPT35\n prompt #\"\n Here is some input data:\n ----\n {{ input }}\n ----\n\n Extract the information.\n {{ ctx.output_format }}\n \"#\n}\n\nfunction DynamicListInputOutput(input: DynInputOutput[]) -> DynInputOutput[] {\n client GPT35\n prompt #\"\n Here is some input data:\n ----\n {{ input }}\n ----\n\n Extract the information.\n {{ ctx.output_format }}\n \"#\n}\n\n\n\nclass DynamicOutput {\n @@dynamic\n}\n \nfunction MyFunc(input: string) -> DynamicOutput {\n client GPT35\n prompt #\"\n Given a string, extract info using the schema:\n\n {{ input}}\n\n {{ ctx.output_format }}\n \"#\n}\n\n",
"test-files/dynamic/dynamic.baml": "class DynamicClassOne {\n @@dynamic\n}\n\nenum DynEnumOne {\n @@dynamic\n}\n\nenum DynEnumTwo {\n @@dynamic\n}\n\nclass SomeClassNestedDynamic {\n hi string\n @@dynamic\n\n}\n\nclass DynamicClassTwo {\n hi string\n some_class SomeClassNestedDynamic\n status DynEnumOne\n @@dynamic\n}\n\nfunction DynamicFunc(input: DynamicClassOne) -> DynamicClassTwo {\n client GPT35\n prompt #\"\n Please extract the schema from \n {{ input }}\n\n {{ ctx.output_format }}\n \"#\n}\n\nclass DynInputOutput {\n testKey string\n @@dynamic\n}\n\nfunction DynamicInputOutput(input: DynInputOutput) -> DynInputOutput {\n client GPT35\n prompt #\"\n Here is some input data:\n ----\n {{ input }}\n ----\n\n Extract the information.\n {{ ctx.output_format }}\n \"#\n}\n\nfunction DynamicListInputOutput(input: DynInputOutput[]) -> DynInputOutput[] {\n client GPT35\n prompt #\"\n Here is some input data:\n ----\n {{ input }}\n ----\n\n Extract the information.\n {{ ctx.output_format }}\n \"#\n}\n\n\n\nclass DynamicOutput {\n @@dynamic\n}\n \nfunction MyFunc(input: string) -> DynamicOutput {\n client GPT35\n prompt #\"\n Given a string, extract info using the schema:\n\n {{ input}}\n\n {{ ctx.output_format }}\n \"#\n}\n\nfunction ClassifyDynEnumTwo(input: string) -> DynEnumTwo {\n client GPT35\n prompt #\"\n Given a string, extract info using the schema:\n\n {{ input}}\n\n {{ ctx.output_format }}\n \"#\n}",
"test-files/functions/input/named-args/single/named-audio.baml": "function AudioInput(aud: audio) -> string{\n client Gemini\n prompt #\"\n {{ _.role(\"user\") }}\n\n Does this sound like a roar? Yes or no? One word no other characters.\n \n {{ aud }}\n \"#\n}\n\n\ntest TestURLAudioInput{\n functions [AudioInput]\n args {\n aud{ \n url https://actions.google.com/sounds/v1/emergency/beeper_emergency_call.ogg\n }\n } \n}\n\n\n",
"test-files/functions/input/named-args/single/named-boolean.baml": "\n\nfunction TestFnNamedArgsSingleBool(myBool: bool) -> string{\n client GPT35\n prompt #\"\n Return this value back to me: {{myBool}}\n \"#\n}\n\ntest TestFnNamedArgsSingleBool {\n functions [TestFnNamedArgsSingleBool]\n args {\n myBool true\n }\n}",
"test-files/functions/input/named-args/single/named-class-list.baml": "\n\n\nfunction TestFnNamedArgsSingleStringList(myArg: string[]) -> string{\n client GPT35\n prompt #\"\n Return this value back to me: {{myArg}}\n \"#\n}\n\ntest TestFnNamedArgsSingleStringList {\n functions [TestFnNamedArgsSingleStringList]\n args {\n myArg [\"hello\", \"world\"]\n }\n}",
Expand Down
57 changes: 57 additions & 0 deletions integ-tests/python/baml_client/sync_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,30 @@ def AudioInput(
mdl = create_model("AudioInputReturnType", inner=(str, ...))
return coerce(mdl, raw.parsed())

def ClassifyDynEnumTwo(
    self,
    input: str,
    baml_options: BamlCallOptions = {},
) -> Union[types.DynEnumTwo, str]:
    """Synchronously invoke the BAML function ``ClassifyDynEnumTwo``.

    Args:
        input: Forwarded to the runtime as the ``input`` argument.
        baml_options: Per-call options. The optional ``"tb"`` and
            ``"client_registry"`` entries are read and passed to the runtime.

    Returns:
        The parsed runtime result, coerced to ``Union[types.DynEnumTwo, str]``.
    """
    type_builder = baml_options.get("tb", None)
    # The runtime is handed the builder's `_tb` attribute, not the builder itself.
    tb = None if type_builder is None else type_builder._tb
    client_registry = baml_options.get("client_registry", None)

    call_args = {"input": input}
    raw = self.__runtime.call_function_sync(
        "ClassifyDynEnumTwo",
        call_args,
        self.__ctx_manager.get(),
        tb,
        client_registry,
    )

    # Wrap the return type in a one-field model so `coerce` can validate it.
    return_type = Union[types.DynEnumTwo, str]
    mdl = create_model("ClassifyDynEnumTwoReturnType", inner=(return_type, ...))
    return coerce(mdl, raw.parsed())

def ClassifyMessage(
self,
input: str,
Expand Down Expand Up @@ -1983,6 +2007,39 @@ def AudioInput(
self.__ctx_manager.get(),
)

def ClassifyDynEnumTwo(
    self,
    input: str,
    baml_options: BamlCallOptions = {},
) -> baml_py.BamlSyncStream[Optional[Union[types.DynEnumTwo, str]], Union[types.DynEnumTwo, str]]:
    """Start a synchronous streaming call of the BAML function ``ClassifyDynEnumTwo``.

    Args:
        input: Forwarded to the runtime as the ``input`` argument.
        baml_options: Per-call options. The optional ``"tb"`` and
            ``"client_registry"`` entries are read and passed to the runtime.

    Returns:
        A ``baml_py.BamlSyncStream`` built from the raw runtime stream plus two
        coercion callbacks: one for the partial type
        (``Optional[Union[types.DynEnumTwo, str]]``) and one for the final
        type (``Union[types.DynEnumTwo, str]``).
    """
    type_builder = baml_options.get("tb", None)
    # The runtime is handed the builder's `_tb` attribute, not the builder itself.
    tb = None if type_builder is None else type_builder._tb
    client_registry = baml_options.get("client_registry", None)

    call_args = {"input": input}
    raw = self.__runtime.stream_function_sync(
        "ClassifyDynEnumTwo",
        call_args,
        None,  # NOTE(review): meaning of this positional slot is not visible here
        self.__ctx_manager.get(),
        tb,
        client_registry,
    )

    final_type = Union[types.DynEnumTwo, str]
    partial_type = Optional[final_type]
    mdl = create_model("ClassifyDynEnumTwoReturnType", inner=(final_type, ...))
    partial_mdl = create_model("ClassifyDynEnumTwoPartialReturnType", inner=(partial_type, ...))

    return baml_py.BamlSyncStream[partial_type, final_type](
        raw,
        lambda chunk: coerce(partial_mdl, chunk),
        lambda result: coerce(mdl, result),
        self.__ctx_manager.get(),
    )

def ClassifyMessage(
self,
input: str,
Expand Down
4 changes: 2 additions & 2 deletions integ-tests/ruby/Rakefile
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ end
# Test task for the Ruby integration tests.
#
# Runs every test_*.rb file by default so the whole suite is exercised.
# To iterate on a subset, do not edit this list; pass TEST=<file> on the
# command line instead (e.g. `rake test TEST=test_filtered.rb`), which
# Rake::TestTask uses in place of `test_files`.
Rake::TestTask.new do |t|
  # Make the in-repo Ruby client and the generated baml_client loadable.
  t.libs << "../../engine/language_client_ruby/lib"
  t.libs << "baml_client"
  t.test_files = FileList["test_*.rb"]
  t.options = '--verbose'
end
Loading

0 comments on commit 7530402

Please sign in to comment.