Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(core): implement file-based remote function register #829

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions ibis-server/app/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ class Config:
def __init__(self):
load_dotenv(override=True)
self.wren_engine_endpoint = os.getenv("WREN_ENGINE_ENDPOINT")
self.remote_function_list_path = os.getenv("REMOTE_FUNCTION_LIST_PATH")
self.validate_wren_engine_endpoint(self.wren_engine_endpoint)
self.diagnose = False
self.init_logger()
Expand Down Expand Up @@ -57,6 +58,9 @@ def update(self, diagnose: bool):
else:
self.init_logger()

def set_remote_function_list_path(self, path: str):
    """Override the remote function list path stored on this config.

    :param path: location of the function list file. The value is stored
        as-is in ``remote_function_list_path``; the test suite also passes
        ``None`` here to clear a previously set path.
    """
    self.remote_function_list_path = path


config = Config()

Expand Down
12 changes: 8 additions & 4 deletions ibis-server/app/mdl/rewriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@ def __init__(
self.manifest_str = manifest_str
self.data_source = data_source
if experiment:
self._rewriter = EmbeddedEngineRewriter(manifest_str)
config = get_config()
function_path = config.remote_function_list_path
self._rewriter = EmbeddedEngineRewriter(manifest_str, function_path)
else:
self._rewriter = ExternalEngineRewriter(manifest_str)

Expand Down Expand Up @@ -68,14 +70,16 @@ def rewrite(self, sql: str) -> str:


class EmbeddedEngineRewriter:
def __init__(self, manifest_str: str):
def __init__(self, manifest_str: str, function_path: str):
self.manifest_str = manifest_str
self.function_path = function_path

def rewrite(self, sql: str) -> str:
from wren_core import transform_sql
from wren_core import read_remote_function_list, transform_sql

try:
return transform_sql(self.manifest_str, sql)
functions = read_remote_function_list(self.function_path)
return transform_sql(self.manifest_str, functions, sql)
except Exception as e:
raise RewriteError(str(e))

Expand Down
5 changes: 5 additions & 0 deletions ibis-server/tests/resource/functions.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
function_type,name,return_type,description
scalar,add_two,int,"Adds two numbers together."
aggregate,median,int,"Returns the median value of a numeric column."
window,max_if,int,"If the condition is true, returns the maximum value in the window."
scalar,unistr,varchar,"Postgres: Evaluate escaped Unicode characters in the argument."
27 changes: 27 additions & 0 deletions ibis-server/tests/routers/v3/connector/test_postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from fastapi.testclient import TestClient
from testcontainers.postgres import PostgresContainer

from app.config import get_config
from app.main import app
from app.model.validator import rules
from tests.confest import file_path
Expand Down Expand Up @@ -362,6 +363,32 @@ def test_dry_plan(manifest_str):
assert response.status_code == 200
assert response.text is not None

def test_query_with_remote_function(manifest_str, postgres: PostgresContainer):
    """Run a query that calls ``unistr``, a function listed in the CSV fixture.

    The remote function list path lives on the shared config object, so it is
    put back in a ``finally`` block: a failing assertion must not leak the
    path into tests that run afterwards.
    """
    config = get_config()
    # Remember the current value so the shared config is left as it was found.
    previous_path = config.remote_function_list_path
    config.set_remote_function_list_path(file_path("resource/functions.csv"))
    try:
        connection_info = _to_connection_info(postgres)
        response = client.post(
            url=f"{base_url}/query",
            json={
                "connectionInfo": connection_info,
                "manifestStr": manifest_str,
                "sql": "SELECT unistr(o_orderstatus) FROM wren.public.orders LIMIT 1",
            },
        )
        assert response.status_code == 200
        result = response.json()
        assert len(result["columns"]) == 1
        assert len(result["data"]) == 1
        assert result["data"][0] == [
            "O",
        ]
        assert result["dtypes"] == {
            "unistr": "object",
        }
    finally:
        # Runs on success and on assertion failure alike.
        config.set_remote_function_list_path(previous_path)

def _to_connection_info(pg: PostgresContainer):
return {
"host": pg.get_container_host_ip(),
Expand Down
1 change: 1 addition & 0 deletions ibis-server/tests/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ def test_config():
assert response.status_code == 200
assert response.json() == {
"wren_engine_endpoint": "http://localhost:8080",
"remote_function_list_path": None,
"diagnose": False,
}

Expand Down
96 changes: 96 additions & 0 deletions wren-modeling-py/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 4 additions & 0 deletions wren-modeling-py/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@ wren-core = { path = "../wren-modeling-rs/core" }
base64 = "0.22.1"
serde_json = "1.0.117"
thiserror = "1.0"
csv = "1.3.0"
serde = { version = "1.0.210", features = ["derive"] }
env_logger = "0.11.5"
log = "0.4.22"

[build-dependencies]
pyo3-build-config = "0.21.2"
58 changes: 50 additions & 8 deletions wren-modeling-py/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,35 +3,64 @@ use std::sync::Arc;
use base64::prelude::*;
use pyo3::prelude::*;

use crate::errors::CoreError;
use crate::remote_functions::RemoteFunction;
use log::debug;
use wren_core::mdl;
use wren_core::mdl::manifest::Manifest;
use wren_core::mdl::AnalyzedWrenMDL;

use crate::errors::CoreError;

mod errors;
mod remote_functions;

#[pyfunction]
fn transform_sql(mdl_base64: &str, sql: &str) -> Result<String, CoreError> {
fn transform_sql(
mdl_base64: &str,
remote_functions: Vec<RemoteFunction>,
sql: &str,
) -> Result<String, CoreError> {
let mdl_json_bytes = BASE64_STANDARD
.decode(mdl_base64)
.map_err(CoreError::from)?;
let mdl_json = String::from_utf8(mdl_json_bytes).map_err(CoreError::from)?;
let manifest = serde_json::from_str::<Manifest>(&mdl_json)?;
let remote_functions: Vec<mdl::function::RemoteFunction> = remote_functions
.into_iter()
.map(|f| f.into())
.collect::<Vec<_>>();

let Ok(analyzed_mdl) = AnalyzedWrenMDL::analyze(manifest) else {
return Err(CoreError::new("Failed to analyze manifest"));
};
match mdl::transform_sql(Arc::new(analyzed_mdl), sql) {
match mdl::transform_sql(Arc::new(analyzed_mdl), &remote_functions, sql) {
Ok(transformed_sql) => Ok(transformed_sql),
Err(e) => Err(CoreError::new(&e.to_string())),
}
}

#[pyfunction]
fn read_remote_function_list(path: Option<&str>) -> Vec<RemoteFunction> {
debug!(
"Reading remote function list from {}",
path.unwrap_or("path is not provided")
);
if let Some(path) = path {
csv::Reader::from_path(path)
.unwrap()
.into_deserialize::<RemoteFunction>()
.filter_map(Result::ok)
.collect::<Vec<_>>()
} else {
vec![]
}
}

/// Defines the `wren_core` Python module and registers its functions.
///
/// Registers `transform_sql` and `read_remote_function_list`.
#[pymodule]
#[pyo3(name = "wren_core")]
fn wren_core_wrapper(m: &Bound<'_, PyModule>) -> PyResult<()> {
    // `env_logger::init` panics when a global logger has already been
    // installed (for example by another extension module in the same
    // process). `try_init` reports that case as an error instead; the
    // existing logger keeps working, so the error is deliberately ignored.
    let _ = env_logger::try_init();
    m.add_function(wrap_pyfunction!(transform_sql, m)?)?;
    m.add_function(wrap_pyfunction!(read_remote_function_list, m)?)?;
    Ok(())
}

Expand All @@ -41,7 +70,7 @@ mod tests {
use base64::Engine;
use serde_json::Value;

use crate::transform_sql;
use crate::{read_remote_function_list, transform_sql};

#[test]
fn test_transform_sql() {
Expand All @@ -66,13 +95,26 @@ mod tests {
}"#;
let v: Value = serde_json::from_str(data).unwrap();
let mdl_base64: String = BASE64_STANDARD.encode(v.to_string().as_bytes());
let transformed_sql =
transform_sql(&mdl_base64, "SELECT * FROM my_catalog.my_schema.customer")
.unwrap();
let transformed_sql = transform_sql(
&mdl_base64,
vec![],
"SELECT * FROM my_catalog.my_schema.customer",
)
.unwrap();
assert_eq!(
transformed_sql,
"SELECT customer.c_custkey, customer.c_name FROM \
(SELECT main.customer.c_custkey AS c_custkey, main.customer.c_name AS c_name FROM main.customer) AS customer"
);
}

/// The CSV at `tests/functions.csv` is expected to yield three functions,
/// while calling without a path yields none.
#[test]
fn test_read_remote_function_list() {
    let loaded = read_remote_function_list(Some("tests/functions.csv"));
    let loaded_count = loaded.len();
    assert_eq!(loaded_count, 3);

    let without_path = read_remote_function_list(None);
    let without_path_count = without_path.len();
    assert_eq!(without_path_count, 0);
}
}
Loading
Loading