Skip to content

Commit

Permalink
init
Browse files Browse the repository at this point in the history
  • Loading branch information
Lordworms committed Sep 3, 2024
1 parent 6ffb1f6 commit 4a8a158
Show file tree
Hide file tree
Showing 6 changed files with 227 additions and 10 deletions.
7 changes: 7 additions & 0 deletions datafusion/substrait/src/extensions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ pub struct Extensions {
pub functions: HashMap<u32, String>, // anchor -> function name
pub types: HashMap<u32, String>, // anchor -> type name
pub type_variations: HashMap<u32, String>, // anchor -> type variation name
pub names: Option<Vec<String>>,
}

impl Extensions {
Expand Down Expand Up @@ -75,6 +76,11 @@ impl Extensions {
}
}
}
/// with the predefined names
pub fn with_projection_names(mut self, names: Vec<String>) -> Self {
self.names = Some(names);
self
}
}

impl TryFrom<&Vec<SimpleExtensionDeclaration>> for Extensions {
Expand Down Expand Up @@ -107,6 +113,7 @@ impl TryFrom<&Vec<SimpleExtensionDeclaration>> for Extensions {
functions,
types,
type_variations,
names: None,
})
}
}
Expand Down
60 changes: 50 additions & 10 deletions datafusion/substrait/src/logical_plan/consumer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ pub async fn from_substrait_plan(
plan: &Plan,
) -> Result<LogicalPlan> {
// Register function extension
let extensions = Extensions::try_from(&plan.extensions)?;
let mut extensions = Extensions::try_from(&plan.extensions)?;
if !extensions.type_variations.is_empty() {
return not_impl_err!("Type variation extensions are not supported");
}
Expand All @@ -214,6 +214,9 @@ pub async fn from_substrait_plan(
Ok(from_substrait_rel(ctx, rel, &extensions).await?)
},
plan_rel::RelType::Root(root) => {
if !root.names.is_empty() {
extensions = extensions.with_projection_names(root.names.clone());
}
let plan = from_substrait_rel(ctx, root.input.as_ref().unwrap(), &extensions).await?;
if root.names.is_empty() {
// Backwards compatibility for plans missing names
Expand All @@ -228,14 +231,32 @@ pub async fn from_substrait_plan(
match plan {
// If the last node of the plan produces expressions, bake the renames into those expressions.
// This isn't necessary for correctness, but helps with roundtrip tests.
LogicalPlan::Projection(p) => Ok(LogicalPlan::Projection(Projection::try_new(rename_expressions(p.expr, p.input.schema(), &renamed_schema)?, p.input)?)),
LogicalPlan::Projection(p) => {
Ok(LogicalPlan::Projection(
Projection::try_new(
rename_expressions(p.expr, p.input.schema(), &renamed_schema)?,
p.input
)?
))
},
LogicalPlan::Aggregate(a) => {
let new_aggr_exprs = rename_expressions(a.aggr_expr, a.input.schema(), &renamed_schema)?;
Ok(LogicalPlan::Aggregate(Aggregate::try_new(a.input, a.group_expr, new_aggr_exprs)?))
Ok(LogicalPlan::Aggregate(
Aggregate::try_new(a.input, a.group_expr, new_aggr_exprs)?
))
},
// There are probably more plans where we could bake things in, can add them later as needed.
// Otherwise, add a new Project to handle the renaming.
_ => Ok(LogicalPlan::Projection(Projection::try_new(rename_expressions(plan.schema().columns().iter().map(|c| col(c.to_owned())), plan.schema(), &renamed_schema)?, Arc::new(plan))?))
_ => Ok(LogicalPlan::Projection(
Projection::try_new(
rename_expressions(
plan.schema().columns().iter().map(|c| col(c.to_owned())),
plan.schema(),
&renamed_schema
)?,
Arc::new(plan)
)?
)),
}
}
},
Expand Down Expand Up @@ -363,7 +384,6 @@ fn make_renamed_schema(
}

let mut name_idx = 0;

let (qualifiers, fields): (_, Vec<Field>) = schema
.iter()
.map(|(q, f)| {
Expand All @@ -390,7 +410,6 @@ fn make_renamed_schema(
name_idx,
dfs_names.len());
}

DFSchema::from_field_specific_qualified_schema(
qualifiers,
&Arc::new(Schema::new(fields)),
Expand All @@ -410,19 +429,19 @@ pub async fn from_substrait_rel(
let mut input = LogicalPlanBuilder::from(
from_substrait_rel(ctx, input, extensions).await?,
);
println!("{:?}", p);
let mut names: HashSet<String> = HashSet::new();
let mut exprs: Vec<Expr> = vec![];
let mut exprs: Vec<Expr> = Vec::new();
for e in &p.expressions {
let x =
from_substrait_rex(ctx, e, input.clone().schema(), extensions)
.await?;

// if the expression is WindowFunction, wrap in a Window relation
if let Expr::WindowFunction(_) = &x {
// Adding the same expression here and in the project below
// works because the project's builder uses columnize_expr(..)
// to transform it into a column reference
input = input.window(vec![x.clone()])?
}

// Ensure the expression has a unique display name, so that project's
// validate_unique_names doesn't fail
let name = x.schema_name().to_string();
Expand All @@ -439,6 +458,27 @@ pub async fn from_substrait_rel(
}
names.insert(new_name);
}
let schema = input.schema();
if let (Some(extensions_names), true) =
(extensions.names.as_ref(), p.common.is_some())
{
extensions_names.iter().for_each(|name| {
if let Ok(field) =
schema.qualified_field_with_unqualified_name(name)
{
let expr = Expr::from(Column::from(field));
let schema_name = expr.schema_name().to_string();

if names.insert(schema_name.clone()) {
let position = extensions_names
.iter()
.position(|n| n == name)
.unwrap_or(exprs.len());
exprs.insert(position, expr);
}
}
});
}
input.project(exprs)?.build()
} else {
not_impl_err!("Projection without an input is not supported")
Expand Down
2 changes: 2 additions & 0 deletions datafusion/substrait/src/logical_plan/producer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2313,6 +2313,7 @@ mod test {
INTERVAL_MONTH_DAY_NANO_TYPE_NAME.to_string()
)]),
type_variations: HashMap::new(),
names: None,
}
);

Expand Down Expand Up @@ -2423,6 +2424,7 @@ mod test {
INTERVAL_MONTH_DAY_NANO_TYPE_NAME.to_string()
)]),
type_variations: HashMap::new(),
names: None,
}
);

Expand Down
54 changes: 54 additions & 0 deletions datafusion/substrait/tests/cases/bugs_converage.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

//! Regression tests for bugs found when consuming Substrait plans

#[cfg(test)]
mod tests {
    use datafusion::arrow::datatypes::{DataType, Field, Schema};
    use datafusion::common::Result;
    use datafusion::datasource::MemTable;
    use datafusion::prelude::SessionContext;
    use datafusion_substrait::logical_plan::consumer::from_substrait_plan;
    use std::fs::File;
    use std::io::BufReader;
    use std::sync::Arc;
    use substrait::proto::Plan;

    /// Consumes a Substrait plan whose root lists more output names than the
    /// project relation has expressions (three table columns plus one window
    /// function) and checks the textual form of the resulting logical plan.
    #[tokio::test]
    async fn extra_projection_with_input() -> Result<()> {
        let ctx = SessionContext::new();

        // Register an empty in-memory `users` table matching the schema the
        // plan's read relation declares.
        let users_schema = Arc::new(Schema::new(vec![
            Field::new("user_id", DataType::Utf8, false),
            Field::new("name", DataType::Utf8, false),
            Field::new("paid_for_service", DataType::Boolean, false),
        ]));
        let users_table = MemTable::try_new(users_schema, vec![vec![]]).unwrap();
        ctx.register_table("users", Arc::new(users_table))?;

        // Load the serialized Substrait plan from the JSON fixture.
        let path = "tests/testdata/extra_projection_with_input.json";
        let reader = BufReader::new(File::open(path).expect("file not found"));
        let substrait_plan: Plan =
            serde_json::from_reader(reader).expect("failed to parse json");

        let logical_plan = from_substrait_plan(&ctx, &substrait_plan).await?;
        let plan_str = logical_plan.to_string();

        let expected = "Projection: users.user_id, users.name, users.paid_for_service, row_number() ORDER BY [users.name ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING AS row_number\
        \n WindowAggr: windowExpr=[[row_number() ORDER BY [users.name ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]]\
        \n TableScan: users projection=[user_id, name, paid_for_service]";
        assert_eq!(plan_str, expected);
        Ok(())
    }
}
1 change: 1 addition & 0 deletions datafusion/substrait/tests/cases/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
// specific language governing permissions and limitations
// under the License.

mod bugs_converage;
mod consumer_integration;
mod function_test;
mod logical_plans;
Expand Down
113 changes: 113 additions & 0 deletions datafusion/substrait/tests/testdata/extra_projection_with_input.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
{
"extensionUris": [
{
"extensionUriAnchor": 1,
"uri": "/functions_arithmetic.yaml"
}
],
"extensions": [
{
"extensionFunction": {
"extensionUriReference": 1,
"functionAnchor": 1,
"name": "row_number"
}
}
],
"relations": [
{
"root": {
"input": {
"project": {
"common": {
"direct": {}
},
"input": {
"read": {
"common": {
"direct": {}
},
"baseSchema": {
"names": [
"user_id",
"name",
"paid_for_service"
],
"struct": {
"types": [
{
"string": {
"nullability": "NULLABILITY_REQUIRED"
}
},
{
"string": {
"nullability": "NULLABILITY_REQUIRED"
}
},
{
"bool": {
"nullability": "NULLABILITY_REQUIRED"
}
}
],
"nullability": "NULLABILITY_REQUIRED"
}
},
"namedTable": {
"names": [
"users"
]
}
}
},
"expressions": [
{
"windowFunction": {
"functionReference": 1,
"sorts": [
{
"expr": {
"selection": {
"directReference": {
"structField": {
"field": 1
}
},
"rootReference": {}
}
},
"direction": "SORT_DIRECTION_ASC_NULLS_FIRST"
}
],
"upperBound": {
"unbounded": {}
},
"lowerBound": {
"unbounded": {}
},
"outputType": {
"i64": {
"nullability": "NULLABILITY_REQUIRED"
}
},
"invocation": "AGGREGATION_INVOCATION_ALL"
}
}
]
}
},
"names": [
"user_id",
"name",
"paid_for_service",
"row_number"
]
}
}
],
"version": {
"minorNumber": 52,
"producer": "spark-substrait-gateway"
}
}

0 comments on commit 4a8a158

Please sign in to comment.