-
Notifications
You must be signed in to change notification settings - Fork 468
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
transform: Turn a Join with a constant input into a FlatMap or Map
- Loading branch information
Showing
7 changed files
with
465 additions
and
45 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,164 @@ | ||
// Copyright Materialize, Inc. and contributors. All rights reserved. | ||
// | ||
// Use of this software is governed by the Business Source License | ||
// included in the LICENSE file. | ||
// | ||
// As of the Change Date specified in that file, in accordance with | ||
// the Business Source License, use of this software will be governed | ||
// by the Apache License, Version 2.0. | ||
|
||
//! todo: comment | ||
//! | ||
|
||
use itertools::Itertools; | ||
use mz_expr::visit::Visit; | ||
use mz_expr::{MirRelationExpr, MirScalarExpr, TableFunc, RECURSION_LIMIT}; | ||
use mz_ore::stack::{CheckedRecursion, RecursionGuard}; | ||
|
||
use crate::TransformCtx; | ||
|
||
/// todo: comment | ||
#[derive(Debug)] | ||
pub struct JoinToFlatMap { | ||
recursion_guard: RecursionGuard, //////// todo: do we need this, or we'll do visit_mut_post instead? | ||
} | ||
|
||
impl Default for JoinToFlatMap { | ||
fn default() -> JoinToFlatMap { | ||
JoinToFlatMap { | ||
recursion_guard: RecursionGuard::with_limit(RECURSION_LIMIT), | ||
} | ||
} | ||
} | ||
|
||
impl CheckedRecursion for JoinToFlatMap { | ||
fn recursion_guard(&self) -> &RecursionGuard { | ||
&self.recursion_guard | ||
} | ||
} | ||
|
||
impl crate::Transform for JoinToFlatMap { | ||
#[mz_ore::instrument( | ||
target = "optimizer", | ||
level = "debug", | ||
fields(path.segment = "join_to_flat_map") | ||
)] | ||
fn transform( | ||
&self, | ||
relation: &mut MirRelationExpr, | ||
_: &mut TransformCtx, | ||
) -> Result<(), crate::TransformError> { | ||
let result = self.action(relation); | ||
mz_repr::explain::trace_plan(&*relation); | ||
result | ||
} | ||
} | ||
|
||
impl JoinToFlatMap { | ||
/// todo: comment | ||
pub fn action(&self, relation: &mut MirRelationExpr) -> Result<(), crate::TransformError> { | ||
relation.visit_mut_post(&mut |e| { | ||
crate::fusion::join::eliminate_trivial_join(e); | ||
match e { | ||
MirRelationExpr::Join { | ||
inputs, | ||
equivalences, | ||
implementation, | ||
} => { | ||
assert!( | ||
!implementation.is_implemented(), | ||
"JoinToFlatMap is meant to be used on Unimplemented joins" | ||
); | ||
match &mut &mut **inputs { | ||
&mut [in1, in2] => { | ||
/////// todo: comment about these variables | ||
let const_rows; | ||
let const_typ; | ||
let other_in; | ||
let permute; | ||
// Only match non-error constants. | ||
if let MirRelationExpr::Constant { | ||
rows: Ok(rows), | ||
typ, | ||
} = in1 | ||
{ | ||
const_rows = rows; | ||
const_typ = typ; | ||
other_in = in2; | ||
permute = true; | ||
} else { | ||
if let MirRelationExpr::Constant { | ||
rows: Ok(rows), | ||
typ, | ||
} = in2 | ||
{ | ||
const_rows = rows; | ||
const_typ = typ; | ||
other_in = in1; | ||
permute = false; | ||
} else { | ||
return; | ||
} | ||
} | ||
let other_arity = other_in.arity(); | ||
let const_arity = const_typ.arity(); | ||
// At this point, we know that one of the inputs is a constant. (It | ||
// might be that actually both are constants, but we do our | ||
// transformation anyway, and a later FoldConstants will eliminate all | ||
// this.) | ||
if let [(const_row, 1)] = &**const_rows { | ||
// We can turn the join into a `Map`. | ||
let scalar_exprs = const_row | ||
.into_iter() | ||
.zip_eq(const_typ.column_types.iter()) | ||
.map(|(datum, typ)| { | ||
MirScalarExpr::literal_ok(datum, typ.scalar_type.clone()) | ||
}) | ||
.collect(); | ||
*e = other_in | ||
.take_dangerous() | ||
.map(scalar_exprs) | ||
.filter(crate::fusion::join::unpack_equivalences(equivalences)); | ||
} else { | ||
// We can turn the join into a `FlatMap`. | ||
*e = other_in | ||
.take_dangerous() | ||
.flat_map( | ||
TableFunc::JoinWithConstant { | ||
rows: const_rows.clone(), | ||
typ: const_typ.clone(), | ||
}, | ||
vec![], | ||
) | ||
.filter(crate::fusion::join::unpack_equivalences(equivalences)); | ||
} | ||
if permute { | ||
*e = e.take_dangerous().project( | ||
(other_arity..(other_arity + const_arity)) | ||
.chain(0..other_arity) | ||
.collect(), | ||
); | ||
} | ||
} | ||
_ => { | ||
// Join with more than 2 inputs. | ||
// (0- and 1-input joins are eliminated by `eliminate_trivial_join`.) | ||
// We don't transform these for now. | ||
// Note that most joins have 2 inputs at this point in the optimizer | ||
// pipeline. For example, something like | ||
// ``` | ||
// SELECT * | ||
// FROM t1, t2, t3, t4 | ||
// ``` | ||
// is still a series of binary joins at this point. | ||
///////////// todo: check some examples where this branch happens | ||
} | ||
} | ||
} | ||
_ => {} | ||
} | ||
})?; | ||
|
||
Ok(()) | ||
} | ||
} |
Oops, something went wrong.