Skip to content

Commit

Permalink
transform: Turn a Join with a constant input into a FlatMap or Map
Browse files Browse the repository at this point in the history
  • Loading branch information
ggevay committed Sep 20, 2024
1 parent 2b128b7 commit d1d1e4c
Show file tree
Hide file tree
Showing 7 changed files with 465 additions and 45 deletions.
12 changes: 12 additions & 0 deletions src/expr/src/relation.proto
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import "expr/src/scalar.proto";
import "expr/src/relation/func.proto";

import "repr/src/relation_and_scalar.proto";
import "repr/src/row.proto";

package mz_expr.relation;

Expand Down Expand Up @@ -178,6 +179,16 @@ message ProtoTableFunc {
mz_repr.relation_and_scalar.ProtoRelationType relation = 2;
}

message ProtoRowAndDiff {
mz_repr.row.ProtoRow row = 1;
int64 diff = 2;
}

message ProtoJoinWithConstant {
repeated ProtoRowAndDiff rows = 1;
mz_repr.relation_and_scalar.ProtoRelationType typ = 2;
}

oneof kind {
bool jsonb_each = 1;
google.protobuf.Empty jsonb_object_keys = 2;
Expand All @@ -197,5 +208,6 @@ message ProtoTableFunc {
google.protobuf.Empty acl_explode = 16;
google.protobuf.Empty mz_acl_explode = 17;
mz_repr.relation_and_scalar.ProtoScalarType unnest_map = 18;
ProtoJoinWithConstant join_with_constant = 19;
}
}
48 changes: 47 additions & 1 deletion src/expr/src/relation/func.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,9 @@ use serde::{Deserialize, Serialize};

use crate::explain::{HumanizedExpr, HumanizerMode};
use crate::relation::proto_aggregate_func::{self, ProtoColumnOrders, ProtoFusedValueWindowFunc};
use crate::relation::proto_table_func::ProtoTabletizedScalar;
use crate::relation::proto_table_func::{
ProtoJoinWithConstant, ProtoRowAndDiff, ProtoTabletizedScalar,
};
use crate::relation::{
compare_columns, proto_table_func, ColumnOrder, ProtoAggregateFunc, ProtoTableFunc,
WindowFrame, WindowFrameBound, WindowFrameUnits,
Expand Down Expand Up @@ -3437,6 +3439,12 @@ pub enum TableFunc {
name: String,
relation: RelationType,
},
/// ////// todo: comment
JoinWithConstant {
#[mzreflect(ignore)]
rows: Vec<(Row, Diff)>,
typ: RelationType,
},
}

impl RustType<ProtoTableFunc> for TableFunc {
Expand Down Expand Up @@ -3471,6 +3479,12 @@ impl RustType<ProtoTableFunc> for TableFunc {
relation: Some(relation.into_proto()),
})
}
TableFunc::JoinWithConstant { rows, typ } => {
Kind::JoinWithConstant(ProtoJoinWithConstant {
rows: rows.iter().map(|r| r.into_proto()).collect(),
typ: Some(typ.into_proto()),
})
}
}),
}
}
Expand Down Expand Up @@ -3515,10 +3529,34 @@ impl RustType<ProtoTableFunc> for TableFunc {
.relation
.into_rust_if_some("ProtoTabletizedScalar::relation")?,
},
Kind::JoinWithConstant(v) => TableFunc::JoinWithConstant {
rows: v
.rows
.into_iter()
.map(|r| r.into_rust())
.collect::<Result<Vec<_>, _>>()?,
typ: v.typ.into_rust_if_some("ProtoJoinWithConstant::typ")?,
},
})
}
}

impl RustType<ProtoRowAndDiff> for (Row, Diff) {
fn into_proto(&self) -> ProtoRowAndDiff {
ProtoRowAndDiff {
row: Some(self.0.into_proto()),
diff: self.1,
}
}

fn from_proto(proto: ProtoRowAndDiff) -> Result<Self, TryFromProtoError> {
Ok((
proto.row.into_rust_if_some("ProtoRowAndDiff::row")?,
proto.diff,
))
}
}

impl TableFunc {
pub fn eval<'a>(
&'a self,
Expand Down Expand Up @@ -3594,6 +3632,7 @@ impl TableFunc {
let r = Row::pack_slice(datums);
Ok(Box::new(std::iter::once((r, 1))))
}
TableFunc::JoinWithConstant { rows, typ: _ } => Ok(Box::new(rows.clone().into_iter())),
}
}

Expand Down Expand Up @@ -3722,6 +3761,9 @@ impl TableFunc {
TableFunc::TabletizedScalar { relation, .. } => {
return relation.clone();
}
TableFunc::JoinWithConstant { rows: _, typ } => {
return typ.clone();
}
};

if !keys.is_empty() {
Expand Down Expand Up @@ -3751,6 +3793,7 @@ impl TableFunc {
TableFunc::UnnestMap { .. } => 2,
TableFunc::Wrap { width, .. } => *width,
TableFunc::TabletizedScalar { relation, .. } => relation.column_types.len(),
TableFunc::JoinWithConstant { rows: _, typ } => typ.column_types.len(),
}
}

Expand All @@ -3774,6 +3817,7 @@ impl TableFunc {
| TableFunc::UnnestMap { .. } => true,
TableFunc::Wrap { .. } => false,
TableFunc::TabletizedScalar { .. } => false,
TableFunc::JoinWithConstant { .. } => false,
}
}

Expand All @@ -3800,6 +3844,7 @@ impl TableFunc {
TableFunc::UnnestMap { .. } => true,
TableFunc::Wrap { .. } => true,
TableFunc::TabletizedScalar { .. } => true,
TableFunc::JoinWithConstant { .. } => true,
}
}
}
Expand All @@ -3825,6 +3870,7 @@ impl fmt::Display for TableFunc {
TableFunc::UnnestMap { .. } => f.write_str("unnest_map"),
TableFunc::Wrap { width, .. } => write!(f, "wrap{}", width),
TableFunc::TabletizedScalar { name, .. } => f.write_str(name),
TableFunc::JoinWithConstant { rows, typ: _ } => f.write_str("join_with_constant"), //////// todo: rows
}
}
}
Expand Down
50 changes: 34 additions & 16 deletions src/transform/src/fusion/join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,26 +82,19 @@ impl Join {
/// Return Ok(true) iff the action manipulated the tree after detecting the
/// most general pattern.
pub fn action(relation: &mut MirRelationExpr) -> Result<bool, TransformError> {
if eliminate_trivial_join(relation) {
return Ok(false);
}
if let MirRelationExpr::Join {
inputs,
equivalences,
..
implementation,
} = relation
{
// Local non-fusion tidying.
inputs.retain(|e| !e.is_constant_singleton());
if inputs.len() == 0 {
*relation = MirRelationExpr::constant(vec![vec![]], mz_repr::RelationType::empty())
.filter(unpack_equivalences(equivalences));
return Ok(false);
}
if inputs.len() == 1 {
*relation = inputs
.pop()
.unwrap()
.filter(unpack_equivalences(equivalences));
return Ok(false);
}
assert!(
!implementation.is_implemented(),
"fusion::Join is meant to be used on Unimplemented joins"
);

// Bail early if no children are MFPs around a Join
if inputs.iter().any(|mut expr| {
Expand Down Expand Up @@ -245,12 +238,37 @@ impl Join {
}
}

/// /////// todo: comment
pub fn eliminate_trivial_join(relation: &mut MirRelationExpr) -> bool {
if let MirRelationExpr::Join {
inputs,
equivalences,
implementation: _,
} = relation
{
inputs.retain(|e| !e.is_constant_singleton());
if inputs.len() == 0 {
*relation = MirRelationExpr::constant(vec![vec![]], mz_repr::RelationType::empty())
.filter(unpack_equivalences(equivalences));
return true;
}
if inputs.len() == 1 {
*relation = inputs
.pop()
.unwrap()
.filter(unpack_equivalences(equivalences));
return true;
}
}
false
}

/// Unpacks multiple equivalence classes into conjuncts that should all be true, essentially
/// turning join equivalences into a Filter.
///
/// Note that a join equivalence treats null equal to null, while an `=` in a Filter does not.
/// This function is mindful of this.
fn unpack_equivalences(equivalences: &Vec<Vec<MirScalarExpr>>) -> Vec<MirScalarExpr> {
pub fn unpack_equivalences(equivalences: &Vec<Vec<MirScalarExpr>>) -> Vec<MirScalarExpr> {
let mut result = Vec::new();
for mut class in equivalences.iter().cloned() {
// Let's put the simplest expression at the beginning of `class`, because all the
Expand Down
164 changes: 164 additions & 0 deletions src/transform/src/join_to_flat_map.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
// Copyright Materialize, Inc. and contributors. All rights reserved.
//
// Use of this software is governed by the Business Source License
// included in the LICENSE file.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0.

//! todo: comment
//!

use itertools::Itertools;
use mz_expr::visit::Visit;
use mz_expr::{MirRelationExpr, MirScalarExpr, TableFunc, RECURSION_LIMIT};
use mz_ore::stack::{CheckedRecursion, RecursionGuard};

use crate::TransformCtx;

/// todo: comment
#[derive(Debug)]
pub struct JoinToFlatMap {
recursion_guard: RecursionGuard, //////// todo: do we need this, or we'll do visit_mut_post instead?
}

impl Default for JoinToFlatMap {
fn default() -> JoinToFlatMap {
JoinToFlatMap {
recursion_guard: RecursionGuard::with_limit(RECURSION_LIMIT),
}
}
}

impl CheckedRecursion for JoinToFlatMap {
fn recursion_guard(&self) -> &RecursionGuard {
&self.recursion_guard
}
}

impl crate::Transform for JoinToFlatMap {
#[mz_ore::instrument(
target = "optimizer",
level = "debug",
fields(path.segment = "join_to_flat_map")
)]
fn transform(
&self,
relation: &mut MirRelationExpr,
_: &mut TransformCtx,
) -> Result<(), crate::TransformError> {
let result = self.action(relation);
mz_repr::explain::trace_plan(&*relation);
result
}
}

impl JoinToFlatMap {
/// todo: comment
pub fn action(&self, relation: &mut MirRelationExpr) -> Result<(), crate::TransformError> {
relation.visit_mut_post(&mut |e| {
crate::fusion::join::eliminate_trivial_join(e);
match e {
MirRelationExpr::Join {
inputs,
equivalences,
implementation,
} => {
assert!(
!implementation.is_implemented(),
"JoinToFlatMap is meant to be used on Unimplemented joins"
);
match &mut &mut **inputs {
&mut [in1, in2] => {
/////// todo: comment about these variables
let const_rows;
let const_typ;
let other_in;
let permute;
// Only match non-error constants.
if let MirRelationExpr::Constant {
rows: Ok(rows),
typ,
} = in1
{
const_rows = rows;
const_typ = typ;
other_in = in2;
permute = true;
} else {
if let MirRelationExpr::Constant {
rows: Ok(rows),
typ,
} = in2
{
const_rows = rows;
const_typ = typ;
other_in = in1;
permute = false;
} else {
return;
}
}
let other_arity = other_in.arity();
let const_arity = const_typ.arity();
// At this point, we know that one of the inputs is a constant. (It
// might be that actually both are constants, but we do our
// transformation anyway, and a later FoldConstants will eliminate all
// this.)
if let [(const_row, 1)] = &**const_rows {
// We can turn the join into a `Map`.
let scalar_exprs = const_row
.into_iter()
.zip_eq(const_typ.column_types.iter())
.map(|(datum, typ)| {
MirScalarExpr::literal_ok(datum, typ.scalar_type.clone())
})
.collect();
*e = other_in
.take_dangerous()
.map(scalar_exprs)
.filter(crate::fusion::join::unpack_equivalences(equivalences));
} else {
// We can turn the join into a `FlatMap`.
*e = other_in
.take_dangerous()
.flat_map(
TableFunc::JoinWithConstant {
rows: const_rows.clone(),
typ: const_typ.clone(),
},
vec![],
)
.filter(crate::fusion::join::unpack_equivalences(equivalences));
}
if permute {
*e = e.take_dangerous().project(
(other_arity..(other_arity + const_arity))
.chain(0..other_arity)
.collect(),
);
}
}
_ => {
// Join with more than 2 inputs.
// (0- and 1-input joins are eliminated by `eliminate_trivial_join`.)
// We don't transform these for now.
// Note that most joins have 2 inputs at this point in the optimizer
// pipeline. For example, something like
// ```
// SELECT *
// FROM t1, t2, t3, t4
// ```
// is still a series of binary joins at this point.
///////////// todo: check some examples where this branch happens
}
}
}
_ => {}
}
})?;

Ok(())
}
}
Loading

0 comments on commit d1d1e4c

Please sign in to comment.