Skip to content

Commit

Permalink
init
Browse files Browse the repository at this point in the history
  • Loading branch information
Lordworms committed Sep 5, 2024
1 parent f638df3 commit 0b0659f
Show file tree
Hide file tree
Showing 4 changed files with 193 additions and 5 deletions.
30 changes: 25 additions & 5 deletions datafusion/substrait/src/logical_plan/consumer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -228,14 +228,32 @@ pub async fn from_substrait_plan(
match plan {
// If the last node of the plan produces expressions, bake the renames into those expressions.
// This isn't necessary for correctness, but helps with roundtrip tests.
LogicalPlan::Projection(p) => Ok(LogicalPlan::Projection(Projection::try_new(rename_expressions(p.expr, p.input.schema(), &renamed_schema)?, p.input)?)),
LogicalPlan::Projection(p) => {
Ok(LogicalPlan::Projection(
Projection::try_new(
rename_expressions(p.expr, p.input.schema(), &renamed_schema)?,
p.input
)?
))
},
LogicalPlan::Aggregate(a) => {
let new_aggr_exprs = rename_expressions(a.aggr_expr, a.input.schema(), &renamed_schema)?;
Ok(LogicalPlan::Aggregate(Aggregate::try_new(a.input, a.group_expr, new_aggr_exprs)?))
Ok(LogicalPlan::Aggregate(
Aggregate::try_new(a.input, a.group_expr, new_aggr_exprs)?
))
},
// There are probably more plans where we could bake things in, can add them later as needed.
// Otherwise, add a new Project to handle the renaming.
_ => Ok(LogicalPlan::Projection(Projection::try_new(rename_expressions(plan.schema().columns().iter().map(|c| col(c.to_owned())), plan.schema(), &renamed_schema)?, Arc::new(plan))?))
_ => Ok(LogicalPlan::Projection(
Projection::try_new(
rename_expressions(
plan.schema().columns().iter().map(|c| col(c.to_owned())),
plan.schema(),
&renamed_schema
)?,
Arc::new(plan)
)?
)),
}
}
},
Expand Down Expand Up @@ -363,7 +381,6 @@ fn make_renamed_schema(
}

let mut name_idx = 0;

let (qualifiers, fields): (_, Vec<Field>) = schema
.iter()
.map(|(q, f)| {
Expand All @@ -390,7 +407,6 @@ fn make_renamed_schema(
name_idx,
dfs_names.len());
}

DFSchema::from_field_specific_qualified_schema(
qualifiers,
&Arc::new(Schema::new(fields)),
Expand All @@ -412,6 +428,10 @@ pub async fn from_substrait_rel(
);
let mut names: HashSet<String> = HashSet::new();
let mut exprs: Vec<Expr> = vec![];
input.schema().iter().for_each(|(qualifier, field)| {
exprs.push(col(Column::from((qualifier, field))))
});

for e in &p.expressions {
let x =
from_substrait_rex(ctx, e, input.clone().schema(), extensions)
Expand Down
54 changes: 54 additions & 0 deletions datafusion/substrait/tests/cases/bugs_converage.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

//! Regression tests for bugs found in the Substrait consumer

#[cfg(test)]
mod tests {
    use datafusion::arrow::datatypes::{DataType, Field, Schema};
    use datafusion::common::Result;
    use datafusion::datasource::MemTable;
    use datafusion::prelude::SessionContext;
    use datafusion_substrait::logical_plan::consumer::from_substrait_plan;
    use std::fs::File;
    use std::io::BufReader;
    use std::sync::Arc;
    use substrait::proto::Plan;

    /// Converts the Substrait plan stored in
    /// `tests/testdata/extra_projection_with_input.json` and checks the
    /// textual form of the resulting logical plan.
    ///
    /// The fixture is a `project` relation over a read of `users` whose only
    /// expression is a `row_number` window function. The expected plan lists
    /// the three input columns ahead of the window expression in the final
    /// `Projection`.
    ///
    /// # Errors
    ///
    /// Propagates any DataFusion error raised while building the table,
    /// registering it, or converting the plan.
    ///
    /// # Panics
    ///
    /// Panics if the fixture file cannot be opened or is not a valid
    /// Substrait `Plan` in JSON form.
    #[tokio::test]
    async fn extra_projection_with_input() -> Result<()> {
        let ctx = SessionContext::new();

        // Mirror the `baseSchema` declared by the fixture's read relation.
        let schema = Schema::new(vec![
            Field::new("user_id", DataType::Utf8, false),
            Field::new("name", DataType::Utf8, false),
            Field::new("paid_for_service", DataType::Boolean, false),
        ]);
        // No rows are supplied: the test only inspects the plan text, never
        // executes it. Use `?` like every other fallible DataFusion call here
        // so a failure is reported through the test's `Result`.
        let memory_table = MemTable::try_new(schema.into(), vec![vec![]])?;
        ctx.register_table("users", Arc::new(memory_table))?;

        let path = "tests/testdata/extra_projection_with_input.json";
        let proto = serde_json::from_reader::<_, Plan>(BufReader::new(
            File::open(path).expect("file not found"),
        ))
        .expect("failed to parse json");

        let plan = from_substrait_plan(&ctx, &proto).await?;
        let plan_str = plan.to_string();
        // NOTE(review): the whitespace after each `\n` below is the child
        // indentation of the rendered plan and must match the plan's display
        // output exactly — confirm it was not collapsed when this file was
        // copied.
        assert_eq!(plan_str, "Projection: users.user_id, users.name, users.paid_for_service, row_number() ORDER BY [users.name ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING AS row_number\
            \n WindowAggr: windowExpr=[[row_number() ORDER BY [users.name ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]]\
            \n TableScan: users projection=[user_id, name, paid_for_service]");
        Ok(())
    }
}
1 change: 1 addition & 0 deletions datafusion/substrait/tests/cases/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
// specific language governing permissions and limitations
// under the License.

mod bugs_converage;
mod consumer_integration;
mod function_test;
mod logical_plans;
Expand Down
113 changes: 113 additions & 0 deletions datafusion/substrait/tests/testdata/extra_projection_with_input.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
{
"extensionUris": [
{
"extensionUriAnchor": 1,
"uri": "/functions_arithmetic.yaml"
}
],
"extensions": [
{
"extensionFunction": {
"extensionUriReference": 1,
"functionAnchor": 1,
"name": "row_number"
}
}
],
"relations": [
{
"root": {
"input": {
"project": {
"common": {
"direct": {}
},
"input": {
"read": {
"common": {
"direct": {}
},
"baseSchema": {
"names": [
"user_id",
"name",
"paid_for_service"
],
"struct": {
"types": [
{
"string": {
"nullability": "NULLABILITY_REQUIRED"
}
},
{
"string": {
"nullability": "NULLABILITY_REQUIRED"
}
},
{
"bool": {
"nullability": "NULLABILITY_REQUIRED"
}
}
],
"nullability": "NULLABILITY_REQUIRED"
}
},
"namedTable": {
"names": [
"users"
]
}
}
},
"expressions": [
{
"windowFunction": {
"functionReference": 1,
"sorts": [
{
"expr": {
"selection": {
"directReference": {
"structField": {
"field": 1
}
},
"rootReference": {}
}
},
"direction": "SORT_DIRECTION_ASC_NULLS_FIRST"
}
],
"upperBound": {
"unbounded": {}
},
"lowerBound": {
"unbounded": {}
},
"outputType": {
"i64": {
"nullability": "NULLABILITY_REQUIRED"
}
},
"invocation": "AGGREGATION_INVOCATION_ALL"
}
}
]
}
},
"names": [
"user_id",
"name",
"paid_for_service",
"row_number"
]
}
}
],
"version": {
"minorNumber": 52,
"producer": "spark-substrait-gateway"
}
}

0 comments on commit 0b0659f

Please sign in to comment.