Skip to content

Commit

Permalink
Merge branch 'canary' into sam/undo-fuckery
Browse files Browse the repository at this point in the history
  • Loading branch information
sxlijin authored Jul 24, 2024
2 parents 964e494 + 7cd8794 commit 9e91b2c
Show file tree
Hide file tree
Showing 13 changed files with 330 additions and 19 deletions.
1 change: 1 addition & 0 deletions engine/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

33 changes: 32 additions & 1 deletion engine/baml-lib/baml-core/src/ir/repr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use baml_types::FieldType;
use either::Either;

use indexmap::IndexMap;
use internal_baml_diagnostics::Span;
use internal_baml_parser_database::{
walkers::{
ClassWalker, ClientWalker, ConfigurationWalker, EnumValueWalker, EnumWalker, FieldWalker,
Expand Down Expand Up @@ -436,7 +437,6 @@ impl WithRepr<Expression> for ast::Expression {
type TemplateStringId = String;

#[derive(serde::Serialize, Debug)]

pub struct TemplateString {
pub name: TemplateStringId,
pub params: Vec<Field>,
Expand Down Expand Up @@ -1094,12 +1094,39 @@ impl WithRepr<RetryPolicy> for ConfigurationWalker<'_> {
}
}

#[derive(serde::Serialize, Debug)]
pub struct TestCaseFunction(String);

impl TestCaseFunction {
pub fn name(&self) -> &str {
&self.0
}
}

#[derive(serde::Serialize, Debug)]
pub struct TestCase {
pub name: String,
pub functions: Vec<Node<TestCaseFunction>>,
pub args: IndexMap<String, Expression>,
}

impl WithRepr<TestCaseFunction> for (&ConfigurationWalker<'_>, usize) {
fn attributes(&self, _db: &ParserDatabase) -> NodeAttributes {
let span = self.0.test_case().functions[self.1].1.clone();
NodeAttributes {
meta: IndexMap::new(),
overrides: IndexMap::new(),
span: Some(span),
}
}

fn repr(&self, db: &ParserDatabase) -> Result<TestCaseFunction> {
Ok(TestCaseFunction(
self.0.test_case().functions[self.1].0.clone(),
))
}
}

impl WithRepr<TestCase> for ConfigurationWalker<'_> {
fn attributes(&self, _db: &ParserDatabase) -> NodeAttributes {
NodeAttributes {
Expand All @@ -1110,6 +1137,9 @@ impl WithRepr<TestCase> for ConfigurationWalker<'_> {
}

fn repr(&self, db: &ParserDatabase) -> Result<TestCase> {
let functions = (0..self.test_case().functions.len())
.map(|i| (self, i).node(db))
.collect::<Result<Vec<_>>>()?;
Ok(TestCase {
name: self.name().to_string(),
args: self
Expand All @@ -1118,6 +1148,7 @@ impl WithRepr<TestCase> for ConfigurationWalker<'_> {
.iter()
.map(|(k, (_, v))| Ok((k.clone(), v.repr(db)?)))
.collect::<Result<IndexMap<_, _>>>()?,
functions,
})
}
}
Expand Down
1 change: 1 addition & 0 deletions engine/baml-schema-wasm/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ console_log = { version = "1", features = ["color"] }
getrandom = { version = "0.2.15", features = ["js"] }
indexmap.workspace = true
internal-baml-codegen.workspace = true
internal-baml-core.workspace = true
js-sys = "=0.3.69"
log.workspace = true
serde.workspace = true
Expand Down
191 changes: 187 additions & 4 deletions engine/baml-schema-wasm/src/runtime_wasm/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,12 @@ use baml_runtime::{
internal::llm_client::LLMResponse, BamlRuntime, DiagnosticsError, IRHelper, RenderedPrompt,
};
use baml_types::{BamlMap, BamlValue};

use internal_baml_codegen::version_check::GeneratorType;
use internal_baml_codegen::version_check::{check_version, VersionCheckMode};
use internal_baml_core::ir::Expression;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::PathBuf;

use wasm_bindgen::prelude::*;

//Run: wasm-pack test --firefox --headless --features internal,wasm
Expand Down Expand Up @@ -253,6 +252,7 @@ pub struct WasmRuntime {
}

#[wasm_bindgen(getter_with_clone, inspectable)]
#[derive(Clone)]
pub struct WasmFunction {
#[wasm_bindgen(readonly)]
pub name: String,
Expand Down Expand Up @@ -317,6 +317,17 @@ impl Default for WasmSpan {
}
}

#[wasm_bindgen(getter_with_clone, inspectable)]
#[derive(Clone)]
pub struct WasmParentFunction {
#[wasm_bindgen(readonly)]
pub start: usize,
#[wasm_bindgen(readonly)]
pub end: usize,
#[wasm_bindgen(readonly)]
pub name: String,
}

#[wasm_bindgen(getter_with_clone, inspectable)]
#[derive(Clone)]
pub struct WasmTestCase {
Expand All @@ -328,6 +339,8 @@ pub struct WasmTestCase {
pub error: Option<String>,
#[wasm_bindgen(readonly)]
pub span: WasmSpan,
#[wasm_bindgen(readonly)]
pub parent_functions: Vec<WasmParentFunction>,
}

#[wasm_bindgen(getter_with_clone, inspectable)]
Expand Down Expand Up @@ -879,6 +892,23 @@ impl WasmRuntime {
inputs: params,
error,
span: wasm_span,
parent_functions: tc
.test_case()
.functions
.iter()
.map(|f| {
let (start, end) = f
.attributes
.span
.as_ref()
.map_or((0, 0), |f| (f.start, f.end));
WasmParentFunction {
start,
end,
name: f.elem.name().to_string(),
}
})
.collect(),
}
})
.collect(),
Expand Down Expand Up @@ -1054,20 +1084,173 @@ impl WasmRuntime {
pub fn get_function_at_position(
&self,
file_name: &str,
selected_func: &str,
cursor_idx: usize,
) -> Option<WasmFunction> {
let functions = self.list_functions();

for function in functions {
for function in functions.clone() {
let span = function.span.clone(); // Clone the span

if span.file_path.as_str().contains(file_name)
if span.file_path.as_str().ends_with(file_name)
&& ((span.start + 1)..=(span.end + 1)).contains(&cursor_idx)
{
return Some(function);
}
}

let testcases = self.list_testcases();

for tc in testcases {
let span = tc.span;
if span.file_path.as_str().ends_with(file_name)
&& ((span.start + 1)..=(span.end + 1)).contains(&cursor_idx)
{
if let Some(parent_function) =
tc.parent_functions.iter().find(|f| f.name == selected_func)
{
return functions.into_iter().find(|f| f.name == selected_func);
} else if let Some(first_function) = tc.parent_functions.get(0) {
return functions
.into_iter()
.find(|f| f.name == first_function.name);
}
}
}

None
}

#[wasm_bindgen]
pub fn get_function_of_testcase(
&self,
file_name: &str,
cursor_idx: usize,
) -> Option<WasmParentFunction> {
let testcases = self.list_testcases();

for tc in testcases {
let span = tc.span;
if span.file_path.as_str().ends_with(file_name)
&& ((span.start + 1)..=(span.end + 1)).contains(&cursor_idx)
{
let first_function = tc
.parent_functions
.iter()
.find(|f| f.start <= cursor_idx && cursor_idx <= f.end)
.cloned();

return first_function;
}
}
None
}

#[wasm_bindgen]
pub fn list_testcases(&self) -> Vec<WasmTestCase> {
self.runtime
.internal()
.ir()
.walk_tests()
.map(|tc| {
let params = match tc.test_case_params(&self.runtime.env_vars()) {
Ok(params) => Ok(params
.iter()
.map(|(k, v)| {
let as_str = match v {
Ok(v) => match serde_json::to_string(v) {
Ok(s) => Ok(s),
Err(e) => Err(e.to_string()),
},
Err(e) => Err(e.to_string()),
};

let (value, error) = match as_str {
Ok(s) => (Some(s), None),
Err(e) => (None, Some(e)),
};

WasmParam {
name: k.to_string(),
value,
error,
}
})
.collect()),
Err(e) => Err(e.to_string()),
};

let (mut params, error) = match params {
Ok(p) => (p, None),
Err(e) => (Vec::new(), Some(e)),
};

// Any missing params should be set to an error
let _ = tc.function().inputs().right().map(|func_params| {
for (param_name, t) in func_params {
if !params.iter().any(|p| p.name.cmp(param_name).is_eq())
&& !t.is_optional()
{
params.insert(
0,
WasmParam {
name: param_name.to_string(),
value: None,
error: Some("Missing parameter".to_string()),
},
);
}
}
});

let wasm_span = match tc.span() {
Some(span) => span.into(),
None => WasmSpan::default(),
};

WasmTestCase {
name: tc.test_case().name.clone(),
inputs: params,
error,
span: wasm_span,
parent_functions: tc
.test_case()
.functions
.iter()
.map(|f| {
let (start, end) = f
.attributes
.span
.as_ref()
.map_or((0, 0), |f| (f.start, f.end));
WasmParentFunction {
start,
end,
name: f.elem.name().to_string(),
}
})
.collect(),
}
})
.collect()
}

#[wasm_bindgen]
pub fn get_testcase_from_position(
&self,
parent_function: WasmFunction,
cursor_idx: usize,
) -> Option<WasmTestCase> {
let testcases = parent_function.test_cases;
for testcase in testcases {
let span = testcase.clone().span;

if span.file_path.as_str() == (parent_function.span.file_path)
&& ((span.start + 1)..=(span.end + 1)).contains(&cursor_idx)
{
return Some(testcase);
}
}
None
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,4 @@ test TestFnNamedArgsSingleMapStringToMap {
}
}
}
}
}
3 changes: 2 additions & 1 deletion integ-tests/baml_src/test-files/providers/providers.baml
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,11 @@ function TestAws(input: string) -> string {


test TestProvider {
functions [TestAnthropic, TestVertex, TestOpenAI, TestAzure, TestOllama, TestGemini, TestAws]
functions [TestAnthropic, TestVertex, PromptTestOpenAI, TestAzure, TestOllama, TestGemini, TestAws]
args {
input "Donkey kong and peanut butter"
}
}



Loading

0 comments on commit 9e91b2c

Please sign in to comment.