
Make HashGenerationOptimizer insensitive to hash symbol order
This improvement reduces the number of hash computations in a plan when the
hash symbols for partitions or joins are the same but appear in a different
order (e.g., TPC-H Q9).
sopel39 authored and martint committed Oct 21, 2016
Parent: 2fd1fa2 · Commit: 15d32ee
Showing 3 changed files with 96 additions and 14 deletions.
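
The change hinges on comparing a hash computation's symbols as an unordered collection, so a hash requested over (a, b) can reuse one already planned over (b, a). Below is a minimal, self-contained sketch of that matching step in plain Java with Guava; SymbolHash and findReusable are invented stand-ins for Presto's HashComputation and the lookup added to computeHash, not the actual classes.

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMultiset;
import com.google.common.collect.Multiset;

import java.util.Collection;
import java.util.List;
import java.util.Optional;

final class SymbolHash
{
    // symbols the hash is computed over, in computation order
    private final List<String> fields;

    SymbolHash(List<String> fields)
    {
        this.fields = ImmutableList.copyOf(fields);
    }

    List<String> getFields()
    {
        return fields;
    }
}

final class HashReuse
{
    // Return an already-planned computation covering the same symbols in any
    // order, instead of planning a new projection that recomputes the hash.
    static Optional<SymbolHash> findReusable(List<String> requested, Collection<SymbolHash> preferred)
    {
        Multiset<String> unordered = ImmutableMultiset.copyOf(requested);
        return preferred.stream()
                .filter(hash -> ImmutableMultiset.copyOf(hash.getFields()).equals(unordered))
                .findAny();
    }
}

Given a preferred hash over (suppkey, partkey), a request for (partkey, suppkey) returns the existing computation, so a TPC-H Q9-style plan carries one hash symbol through both joins instead of computing two.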
HashGenerationOptimizer.java
@@ -29,7 +29,6 @@
 import com.facebook.presto.sql.planner.plan.ExchangeNode;
 import com.facebook.presto.sql.planner.plan.GroupIdNode;
 import com.facebook.presto.sql.planner.plan.IndexJoinNode;
-import com.facebook.presto.sql.planner.plan.IndexJoinNode.EquiJoinClause;
 import com.facebook.presto.sql.planner.plan.JoinNode;
 import com.facebook.presto.sql.planner.plan.MarkDistinctNode;
 import com.facebook.presto.sql.planner.plan.PlanNode;
@@ -53,9 +52,10 @@
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableListMultimap;
 import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.ImmutableMultiset;
 import com.google.common.collect.ImmutableSet;
 import com.google.common.collect.Iterables;
-import com.google.common.collect.Lists;
+import com.google.common.collect.Multiset;
 
 import java.util.HashMap;
 import java.util.List;
@@ -141,7 +141,7 @@ public PlanWithProperties visitAggregation(AggregationNode node, HashComputation
 {
     Optional<HashComputation> groupByHash = Optional.empty();
     if (!canSkipHashGeneration(node.getGroupingKeys())) {
-        groupByHash = computeHash(node.getGroupingKeys());
+        groupByHash = computeHash(node.getGroupingKeys(), parentPreference);
     }
 
     // aggregation does not pass through preferred hash symbols
@@ -185,7 +185,7 @@ public PlanWithProperties visitDistinctLimit(DistinctLimitNode node, HashComputa
         return planSimpleNodeWithProperties(node, parentPreference);
     }
 
-    Optional<HashComputation> hashComputation = computeHash(node.getDistinctSymbols());
+    Optional<HashComputation> hashComputation = computeHash(node.getDistinctSymbols(), parentPreference);
     PlanWithProperties child = planAndEnforce(
             node.getSource(),
             new HashComputationSet(hashComputation),
@@ -206,7 +206,7 @@ public PlanWithProperties visitMarkDistinct(MarkDistinctNode node, HashComputati
         return planSimpleNodeWithProperties(node, parentPreference);
     }
 
-    Optional<HashComputation> hashComputation = computeHash(node.getDistinctSymbols());
+    Optional<HashComputation> hashComputation = computeHash(node.getDistinctSymbols(), parentPreference);
     PlanWithProperties child = planAndEnforce(
             node.getSource(),
             new HashComputationSet(hashComputation),
@@ -226,7 +226,7 @@ public PlanWithProperties visitRowNumber(RowNumberNode node, HashComputationSet
         return planSimpleNodeWithProperties(node, parentPreference);
     }
 
-    Optional<HashComputation> hashComputation = computeHash(node.getPartitionBy());
+    Optional<HashComputation> hashComputation = computeHash(node.getPartitionBy(), parentPreference);
     PlanWithProperties child = planAndEnforce(
             node.getSource(),
             new HashComputationSet(hashComputation),
@@ -252,7 +252,7 @@ public PlanWithProperties visitTopNRowNumber(TopNRowNumberNode node, HashComputa
         return planSimpleNodeWithProperties(node, parentPreference);
     }
 
-    Optional<HashComputation> hashComputation = computeHash(node.getPartitionBy());
+    Optional<HashComputation> hashComputation = computeHash(node.getPartitionBy(), parentPreference);
     PlanWithProperties child = planAndEnforce(
             node.getSource(),
             new HashComputationSet(hashComputation),
@@ -300,11 +300,15 @@ public PlanWithProperties visitJoin(JoinNode node, HashComputationSet parentPref
 
     // join does not pass through preferred hash symbols since they take more memory and,
     // since the join node filters, may take more compute
-    Optional<HashComputation> leftHashComputation = computeHash(Lists.transform(clauses, JoinNode.EquiJoinClause::getLeft));
+    Optional<HashComputation> leftHashComputation = computeHash(
+            clauses.stream()
+                    .map(JoinNode.EquiJoinClause::getLeft)
+                    .collect(toImmutableList()),
+            parentPreference);
     PlanWithProperties left = planAndEnforce(node.getLeft(), new HashComputationSet(leftHashComputation), true, new HashComputationSet(leftHashComputation));
     Symbol leftHashSymbol = left.getRequiredHashSymbol(leftHashComputation.get());
 
-    Optional<HashComputation> rightHashComputation = computeHash(Lists.transform(clauses, JoinNode.EquiJoinClause::getRight));
+    Optional<HashComputation> rightHashComputation = computeHash(getEquiJoinClauseRightSymbols(node.getCriteria(), leftHashComputation.get().getFields()));
     // drop undesired hash symbols from build to save memory
     PlanWithProperties right = planAndEnforce(node.getRight(), new HashComputationSet(rightHashComputation), true, new HashComputationSet(rightHashComputation));
     Symbol rightHashSymbol = right.getRequiredHashSymbol(rightHashComputation.get());
@@ -368,15 +372,19 @@ public PlanWithProperties visitIndexJoin(IndexJoinNode node, HashComputationSet
 
     // join does not pass through preferred hash symbols since they take more memory and,
     // since the join node filters, may take more compute
-    Optional<HashComputation> probeHashComputation = computeHash(Lists.transform(clauses, IndexJoinNode.EquiJoinClause::getProbe));
+    Optional<HashComputation> probeHashComputation = computeHash(
+            clauses.stream()
+                    .map(IndexJoinNode.EquiJoinClause::getProbe)
+                    .collect(toImmutableList()),
+            parentPreference);
     PlanWithProperties probe = planAndEnforce(
             node.getProbeSource(),
             new HashComputationSet(probeHashComputation),
             true,
             new HashComputationSet(probeHashComputation));
     Symbol probeHashSymbol = probe.getRequiredHashSymbol(probeHashComputation.get());
 
-    Optional<HashComputation> indexHashComputation = computeHash(Lists.transform(clauses, EquiJoinClause::getIndex));
+    Optional<HashComputation> indexHashComputation = computeHash(getEquiJoinClauseIndexSymbols(node.getCriteria(), probeHashComputation.get().getFields()));
     HashComputationSet requiredHashes = new HashComputationSet(indexHashComputation);
     PlanWithProperties index = planAndEnforce(node.getIndexSource(), requiredHashes, true, requiredHashes);
     Symbol indexHashSymbol = index.getRequiredHashSymbol(indexHashComputation.get());
@@ -407,7 +415,7 @@ public PlanWithProperties visitWindow(WindowNode node, HashComputationSet parent
         return planSimpleNodeWithProperties(node, parentPreference, true);
     }
 
-    Optional<HashComputation> hashComputation = computeHash(node.getPartitionBy());
+    Optional<HashComputation> hashComputation = computeHash(node.getPartitionBy(), parentPreference);
     PlanWithProperties child = planAndEnforce(
             node.getSource(),
             new HashComputationSet(hashComputation),
@@ -441,8 +449,9 @@ public PlanWithProperties visitExchange(ExchangeNode node, HashComputationSet pa
             partitioningScheme.getPartitioning().getArguments().stream().allMatch(ArgumentBinding::isVariable)) {
         // add precomputed hash for exchange
         partitionSymbols = computeHash(partitioningScheme.getPartitioning().getArguments().stream()
-                .map(ArgumentBinding::getColumn)
-                .collect(toImmutableList()));
+                        .map(ArgumentBinding::getColumn)
+                        .collect(toImmutableList()),
+                preference);
         preference = preference.withHashComputation(partitionSymbols);
     }

@@ -601,6 +610,26 @@ public PlanWithProperties visitUnnest(UnnestNode node, HashComputationSet parent
             hashSymbols);
 }
 
+private List<Symbol> getEquiJoinClauseRightSymbols(List<JoinNode.EquiJoinClause> clauses, List<Symbol> leftSymbols)
+{
+    return leftSymbols.stream()
+            .map(leftSymbol -> clauses.stream()
+                    .filter(clause -> clause.getLeft().equals(leftSymbol))
+                    .findFirst().get()
+                    .getRight())
+            .collect(toImmutableList());
+}
+
+private List<Symbol> getEquiJoinClauseIndexSymbols(List<IndexJoinNode.EquiJoinClause> clauses, List<Symbol> probeSymbols)
+{
+    return probeSymbols.stream()
+            .map(probeSymbol -> clauses.stream()
+                    .filter(clause -> clause.getProbe().equals(probeSymbol))
+                    .findFirst().get()
+                    .getIndex())
+            .collect(toImmutableList());
+}
+
 private PlanWithProperties planSimpleNodeWithProperties(PlanNode node, HashComputationSet preferredHashes)
 {
     return planSimpleNodeWithProperties(node, preferredHashes, true);
@@ -761,12 +790,29 @@ public HashComputationSet withHashComputation(Optional<HashComputation> hashComp
 }
 
 public static Optional<HashComputation> computeHash(Iterable<Symbol> fields)
 {
+    return computeHash(fields, new HashComputationSet());
+}
+
+public static Optional<HashComputation> computeHash(Iterable<Symbol> fields, HashComputationSet preferredHashes)
+{
     requireNonNull(fields, "fields is null");
+    requireNonNull(preferredHashes, "preferredHashes is null");
     List<Symbol> symbols = ImmutableList.copyOf(fields);
     if (symbols.isEmpty()) {
         return Optional.empty();
     }
 
+    // try to use one of the preferred hash computations if its symbols match (in any order)
+    Multiset<Symbol> unorderedFields = ImmutableMultiset.copyOf(fields);
+    Optional<HashComputation> preferredHash = preferredHashes.getHashes().stream()
+            .filter(hash -> ImmutableMultiset.copyOf(hash.getFields()).equals(unorderedFields))
+            .findAny();
+
+    if (preferredHash.isPresent()) {
+        return preferredHash;
+    }
+
     return Optional.of(new HashComputation(fields));
 }

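Reusing a permuted hash on the probe side only works if the build side computes its hash over symbols in the matching order. That is what getEquiJoinClauseRightSymbols above does: it walks the left hash's fields and, for each, picks the right-side symbol of the clause whose left side matches. Here is a standalone sketch of that reordering, with a hypothetical Clause pair standing in for JoinNode.EquiJoinClause:

import java.util.List;
import java.util.stream.Collectors;

final class Clause
{
    final String left;
    final String right;

    Clause(String left, String right)
    {
        this.left = left;
        this.right = right;
    }
}

final class RightSymbolOrder
{
    // For each field of the chosen left hash (in order), emit the right-side
    // symbol of the clause whose left side matches, so both join inputs hash
    // their symbols in the same order.
    static List<String> rightSymbols(List<Clause> clauses, List<String> leftHashFields)
    {
        return leftHashFields.stream()
                .map(field -> clauses.stream()
                        .filter(clause -> clause.left.equals(field))
                        .findFirst()
                        .orElseThrow(() -> new IllegalStateException("no clause for " + field))
                        .right)
                .collect(Collectors.toList());
    }
}

With clauses (l.suppkey = ps.suppkey, l.partkey = ps.partkey) and a reused left hash over (partkey, suppkey), this yields (ps.partkey, ps.suppkey), so probe and build rows hash identically.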
JoinNode.java
@@ -25,6 +25,7 @@
 import java.util.List;
 import java.util.Optional;
 
+import static com.google.common.base.Preconditions.checkState;
 import static java.util.Objects.requireNonNull;
 
 @Immutable
@@ -65,6 +66,8 @@ public JoinNode(@JsonProperty("id") PlanNodeId id,
     this.filter = filter;
     this.leftHashSymbol = leftHashSymbol;
     this.rightHashSymbol = rightHashSymbol;
+
+    checkState(leftHashSymbol.isPresent() == rightHashSymbol.isPresent(), "Either none or both hash symbols should be provided");
 }
 
 public enum Type
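The new checkState makes the pairing explicit: a precomputed join hash is only meaningful when both inputs carry one. A small, hedged illustration of the invariant in isolation (HashPairCheck and its symbol names are invented for this example):

import static com.google.common.base.Preconditions.checkState;

import java.util.Optional;

final class HashPairCheck
{
    static void validate(Optional<String> leftHashSymbol, Optional<String> rightHashSymbol)
    {
        // the optimizer must precompute hashes for both join inputs or for neither
        checkState(leftHashSymbol.isPresent() == rightHashSymbol.isPresent(),
                "Either none or both hash symbols should be provided");
    }

    public static void main(String[] args)
    {
        validate(Optional.of("$hash_l"), Optional.of("$hash_r")); // ok
        validate(Optional.empty(), Optional.empty());             // ok
        validate(Optional.of("$hash_l"), Optional.empty());       // throws IllegalStateException
    }
}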
TestLogicalPlanner.java
@@ -237,6 +237,39 @@ public void testCorrelatedScalarAggregationRewriteToLeftOuterJoin()
         ))))))));
 }
 
+@Test
+public void testGeneratesOneLeftHashForTwoJoinsWithShuffledSymbols()
+{
+    assertPlan(
+            "SELECT * " +
+            " FROM " +
+            " (SELECT " +
+            " l.suppkey," +
+            " l.partkey" +
+            " FROM" +
+            " lineitem l" +
+            " JOIN" +
+            " partsupp ps" +
+            " ON" +
+            " ps.suppkey = l.suppkey" +
+            " AND ps.partkey = l.partkey) l" +
+            " JOIN" +
+            " partsupp ps" +
+            " ON" +
+            " ps.partkey = l.partkey" +
+            " AND ps.suppkey = l.suppkey",
+            anyTree(node(JoinNode.class,
+                    project(
+                            node(JoinNode.class,
+                                    project(
+                                            anyTree()).withSymbol("hash", "H"),
+                                    anyTree())
+                    ).withSymbol("suppkey", "S")
+                            .withSymbol("partkey", "P")
+                            .withSymbol("hash", "H"),
+                    anyTree())));
+}
+
 private void assertPlan(String sql, PlanMatchPattern pattern)
 {
     assertPlan(sql, LogicalPlanner.Stage.OPTIMIZED_AND_VALIDATED, pattern);
