Skip to content

Commit

Permalink
Feat: apply join synopses for approx. query processing
Browse files Browse the repository at this point in the history
  • Loading branch information
taewhi committed Sep 25, 2024
1 parent 74cf27d commit dee5be6
Show file tree
Hide file tree
Showing 5 changed files with 273 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,9 @@ MTable createJoinTable(List<String> schemaNames, List<String> tableNames,

void dropJoinTable(String schemaName, String joinTableName) throws CatalogException;

/**
 * Looks up the synopses that were built on join tables covering the given base tables.
 *
 * @param baseTableIds ids of the base tables that take part in the join
 * @param columnNames for each base table id, the column names the join table has to contain
 * @param joinCondition textual join condition used to pick the matching join table
 * @return the synopses of the matching join tables; callers should treat both {@code null}
 *     and an empty collection as "no join synopsis found"
 * @throws CatalogException if the catalog lookup fails
 */
Collection<MSynopsis> getJoinSynopses(
List<Long> baseTableIds, Map<Long, List<String>> columnNames, String joinCondition)
throws CatalogException;

/* Common */
void close();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import java.util.Map;
import java.util.Set;
import java.util.UUID;
import java.util.stream.Collectors;
import javax.jdo.PersistenceManager;
import javax.jdo.Query;
import javax.jdo.Transaction;
Expand Down Expand Up @@ -228,6 +229,54 @@ public void dropJoinTable(String schemaName, String joinTableName) throws Catalo
}
}

/**
 * Collects the synopses of every join table that covers all of the given base tables.
 *
 * <p>For each base table id, the {@code MJoin} entries whose source table is that id and whose
 * columns contain the required column names are looked up; only join tables found for every
 * base table are kept. A kept join table contributes all of its synopses as soon as one of its
 * table extensions has an external table URI that contains {@code joinCondition}.
 *
 * @param baseTableIds ids of the base tables taking part in the join
 * @param columnNames required column names, keyed by base table id
 * @param joinCondition textual join condition that must occur in the join table's URI
 * @return the matching synopses; never {@code null}, empty when nothing matches
 * @throws CatalogException if the lookup fails with a runtime error
 */
@Override
public Collection<MSynopsis> getJoinSynopses(
    List<Long> baseTableIds, Map<Long, List<String>> columnNames, String joinCondition)
    throws CatalogException {
  List<MSynopsis> joinSynopses = new ArrayList<>();
  // A null condition would make String.contains throw, and an empty one is contained in every
  // URI, so neither can identify a specific join table.
  if (baseTableIds == null || baseTableIds.isEmpty()
      || joinCondition == null || joinCondition.isEmpty()) {
    return joinSynopses;
  }

  try {
    List<Long> joinTableIds = null;
    for (Long tid : baseTableIds) {
      Query query = pm.newQuery(MJoin.class);
      setFilterPatterns(query, ImmutableMap.of("src_table_id", tid));
      List<MJoin> mJoins = (List<MJoin>) query.execute();
      // retainAll() below mutates the first list, so collect into a list that is guaranteed
      // to be modifiable (Collectors.toList() makes no such promise).
      List<Long> ids = mJoins.stream()
          .filter(obj -> obj.containsColumnNames(columnNames.get(tid)))
          .map(MJoin::getJoinTableId)
          .collect(Collectors.toCollection(ArrayList::new));
      if (joinTableIds == null) {
        joinTableIds = ids;
      } else {
        joinTableIds.retainAll(ids);
      }
      if (joinTableIds.isEmpty()) {
        // The intersection cannot grow again; skip the remaining catalog queries.
        return joinSynopses;
      }
    }

    for (Long joinTableId : joinTableIds) {
      Collection<MTable> joinTable = getTables(ImmutableMap.of("id", joinTableId));
      for (MTable jt : joinTable) {
        Collection<MTableExt> tableExts = jt.getTableExts();
        if (tableExts == null || tableExts.isEmpty()) {
          continue;
        }
        for (MTableExt tableExt : tableExts) {
          String uri = tableExt.getExternalTableUri();
          if (uri != null && uri.contains(joinCondition)) {
            Collection<MSynopsis> synopses =
                getAllSynopses(jt.getSchema().getSchemaName(), jt.getTableName());
            if (synopses != null) {
              joinSynopses.addAll(synopses);
            }
            // One matching extension is enough; do not add the same synopses twice.
            break;
          }
        }
      }
    }
    return joinSynopses;
  } catch (RuntimeException e) {
    throw new CatalogException("failed to get synopses", e);
  }
}

@Override
public MModel trainModel(
String modeltypeName, String modelName, String schemaName, String tableName,
Expand Down
8 changes: 8 additions & 0 deletions traindb-catalog/src/main/java/traindb/catalog/pm/MJoin.java
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,16 @@ public MJoin(long joinTableId, long srcTableId, List<String> columns) {
this.columns = columns;
}

/** Returns the id of the join table this entry refers to. */
public long getJoinTableId() {
  return this.join_table_id;
}

/** Returns the column names stored in this entry (the internal list itself, not a copy). */
public List<String> getColumnNames() {
  return this.columns;
}

/**
 * Tells whether every one of the given column names is among this entry's columns.
 *
 * @param columnNames the column names to look for
 * @return {@code true} if all of them are contained in this entry's column list
 */
public boolean containsColumnNames(List<String> columnNames) {
  final List<String> ownColumns = this.columns;
  return ownColumns.containsAll(columnNames);
}

}
35 changes: 35 additions & 0 deletions traindb-core/src/main/java/traindb/planner/TrainDBPlanner.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import org.apache.calcite.adapter.file.CsvEnumerator;
import org.apache.calcite.config.CalciteConnectionConfig;
import org.apache.calcite.plan.Context;
Expand Down Expand Up @@ -66,6 +68,7 @@
import traindb.catalog.CatalogException;
import traindb.catalog.pm.MModel;
import traindb.catalog.pm.MSynopsis;
import traindb.catalog.pm.MTable;
import traindb.common.TrainDBConfiguration;
import traindb.planner.caqp.CaqpExecutionTimePolicy;
import traindb.planner.caqp.CaqpExecutionTimePolicyType;
Expand Down Expand Up @@ -155,6 +158,38 @@ public CatalogContext getCatalogContext() {
return catalogContext;
}

/**
 * Returns the enabled join synopses that can answer a join over the given table scans.
 *
 * <p>Each scan is resolved to its catalog table (schema and table name are taken from positions
 * 1 and 2 of the scan's qualified name), then the catalog is asked for join synopses covering
 * all of those tables, the required columns of each scan and the given join condition.
 *
 * @param scans table scans that take part in the join
 * @param requiredScanColumnMap required column names for each scan
 * @param condition textual join condition passed on to the catalog lookup
 * @return the enabled synopses (possibly empty), or {@code null} when a scanned table is not
 *     in the catalog, the catalog has no answer, or the catalog lookup fails
 */
public Collection<MSynopsis> getAvailableJoinSynopses(
    List<TableScan> scans, Map<TableScan, List<String>> requiredScanColumnMap, String condition) {

  List<Long> scanTableIds = new ArrayList<>();
  Map<Long, List<String>> scanTableColumns = new HashMap<>();
  for (TableScan scan : scans) {
    String baseSchema = scan.getTable().getQualifiedName().get(1);
    String baseTable = scan.getTable().getQualifiedName().get(2);
    MTable mTable = catalogContext.getTable(baseSchema, baseTable);
    if (mTable == null) {
      return null;
    }
    Long tid = mTable.getId();
    scanTableIds.add(tid);
    scanTableColumns.put(tid, requiredScanColumnMap.get(scan));
  }

  try {
    Collection<MSynopsis> synopses =
        catalogContext.getJoinSynopses(scanTableIds, scanTableColumns, condition);
    if (synopses == null) {
      // The catalog may report "nothing found" as null; iterating it would throw.
      return null;
    }
    List<MSynopsis> availableSynopses = new ArrayList<>();
    for (MSynopsis synopsis : synopses) {
      if (synopsis.isEnabled()) {
        availableSynopses.add(synopsis);
      }
    }
    return availableSynopses;
  } catch (CatalogException ignored) {
    // Best effort: a failed catalog lookup only means no join synopsis can be used here.
    return null;
  }
}

public Collection<MSynopsis> getAvailableSynopses(List<String> qualifiedBaseTableName,
List<String> requiredColumnNames) {
String baseSchema = qualifiedBaseTableName.get(1);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@
import com.google.common.collect.ImmutableList;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelRule;
Expand All @@ -37,6 +39,7 @@
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexUtil;
import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.tools.RelBuilder;
import org.apache.calcite.util.mapping.Mappings;
import org.immutables.value.Value;
Expand Down Expand Up @@ -133,6 +136,40 @@ private List<Integer> getRequiredColumnIndex(RelNode node, int start, int end) {
return requiredColumnIndex;
}

/**
 * Renders the condition of a join as {@code "<table>.<column> <operator> <table>.<column>"}.
 *
 * <p>The first two input references found in the condition's operands are used; the smaller
 * index is taken as the left-input column and the larger one, shifted by the left input's field
 * count, as the right-input column. The table names come from the first and second entries of
 * {@code scans}.
 *
 * @param node the plan node to describe; only {@code Join} nodes produce a condition string
 * @param scans table scans below the join; the first two supply the table names
 * @return the condition string, or an empty string when the node is not a join, its condition
 *     is not a call, or the condition cannot be mapped to one column on each join side
 */
private String getConditionString(RelNode node, List<TableScan> scans) {
  if (!(node instanceof Join) || scans == null || scans.size() < 2) {
    return "";
  }
  Join join = (Join) node;
  RexNode joinCondition = join.getCondition();
  if (!(joinCondition instanceof RexCall)) {
    return "";
  }

  RexCall conditionCall = (RexCall) joinCondition;
  List<Integer> inputRefIndex = getRexInputRefIndex(conditionCall.getOperands());
  if (inputRefIndex == null || inputRefIndex.size() < 2) {
    // e.g. a comparison against a literal: there is no column pair to describe
    return "";
  }
  SqlOperator operator = conditionCall.getOperator();

  int leftFieldCount = join.getLeft().getRowType().getFieldCount();
  int first = inputRefIndex.get(0);
  int second = inputRefIndex.get(1);
  int li = Math.min(first, second);
  int ri = Math.max(first, second) - leftFieldCount;
  if (li >= leftFieldCount || ri < 0) {
    // Both references point into the same join input, so there is no left/right column pair.
    return "";
  }

  StringBuilder sb = new StringBuilder();
  sb.append(scans.get(0).getTable().getQualifiedName().get(2));
  sb.append(".");
  sb.append(join.getLeft().getRowType().getFieldNames().get(li));
  sb.append(" ").append(operator).append(" ");
  sb.append(scans.get(1).getTable().getQualifiedName().get(2));
  sb.append(".");
  sb.append(join.getRight().getRowType().getFieldNames().get(ri));
  return sb.toString();
}

private Mappings.TargetMapping createMapping(List<String> fromColumns, List<String> toColumns) {
List<Integer> targets = new ArrayList<>();
for (int i = 0; i < fromColumns.size(); i++) {
Expand All @@ -152,40 +189,176 @@ public void onMatch(RelOptRuleCall call) {

final Aggregate aggregate = call.rel(0);
List<TableScan> tableScans = ApproxAggregateUtil.findAllTableScans(aggregate);
Map<Join, List<TableScan>> joinScanMap = new HashMap<>();
Map<Join, List<Filter>> joinFilterMap = new HashMap<>();
Map<TableScan, List<String>> requiredScanColumnMap = new HashMap<>();
for (TableScan scan : tableScans) {
if (!isApplicable(aggregate, scan)) {
continue;
}

// build required column information for each table scan
Set<Integer> requiredColumnIndex = new HashSet<>();
List<Filter> filterList = new ArrayList<>();
int start = 0;
int end = scan.getRowType().getFieldCount();
RelNode node, parent;
boolean projected = false;
RelNode node;
RelNode parent;
for (node = scan, parent = ApproxAggregateUtil.getParent(aggregate, scan);
node != aggregate;
node = parent, parent = ApproxAggregateUtil.getParent(aggregate, node)) {
if (parent instanceof Join) {
RelNode left = ((Join) parent).getLeft();
Join parentJoin = (Join) parent;
List<TableScan> joinScans = joinScanMap.get(parentJoin);
if (joinScans == null) {
joinScans = new ArrayList<>();
joinScans.add(scan);
joinScanMap.put(parentJoin, joinScans);
} else {
joinScans.add(scan);
}
if (!filterList.isEmpty()) {
List<Filter> joinFilters = joinFilterMap.get(parentJoin);
if (joinFilters == null) {
joinFilterMap.put(parentJoin, filterList);
} else {
joinFilters.addAll(filterList);
}
}
if (projected) {
continue;
}

RelNode left = parentJoin.getLeft();
if (left instanceof RelSubset) {
left = ((RelSubset) left).getBestOrOriginal();
}
RelNode right = ((Join) parent).getRight();
RelNode right = parentJoin.getRight();
if (right instanceof RelSubset) {
right = ((RelSubset) right).getBestOrOriginal();
}
if (node == right) {
start = left.getRowType().getFieldCount();
end = left.getRowType().getFieldCount() + right.getRowType().getFieldCount();
}
} else if (parent instanceof Filter) {
filterList.add((Filter) parent);
}

if (projected) {
continue;
}

requiredColumnIndex.addAll(getRequiredColumnIndex(parent, start, end));
if (parent instanceof Project) {
break;
projected = true;
}
}

List<String> inputColumns = scan.getRowType().getFieldNames();
List<String> requiredColumnNames =
ApproxAggregateUtil.getSublistByIndex(inputColumns, new ArrayList(requiredColumnIndex));
requiredScanColumnMap.put(scan, requiredColumnNames);
}

// try join synopses first
for (Map.Entry<Join, List<TableScan>> entry : joinScanMap.entrySet()) {
Join join = entry.getKey();
List<TableScan> joinScans = entry.getValue();
String condStr = getConditionString(join, joinScans);
Collection<MSynopsis> candidateJoinSynopses =
planner.getAvailableJoinSynopses(joinScans, requiredScanColumnMap, condStr);
if (candidateJoinSynopses == null || candidateJoinSynopses.isEmpty()) {
continue;
}
List<String> requiredJoinColumnNames = new ArrayList<>();
for (TableScan joinScan : joinScans) {
requiredJoinColumnNames.addAll(requiredScanColumnMap.get(joinScan));
}

TableScan baseScan = joinScans.get(0);
MSynopsis bestSynopsis = planner.getBestSynopsis(
candidateJoinSynopses, baseScan, aggregate.getHints(), requiredJoinColumnNames);

RelOptTableImpl synopsisTable =
(RelOptTableImpl) planner.getSynopsisTable(bestSynopsis, baseScan.getTable());
if (synopsisTable == null) {
return;
}
TableScan newScan = planner.createSynopsisTableScan(bestSynopsis, synopsisTable, baseScan);
relBuilder.push(newScan);

List<String> synopsisColumns = bestSynopsis.getColumnNames();
List<String> inputColumns = new ArrayList<>();
inputColumns.addAll(join.getLeft().getRowType().getFieldNames());
inputColumns.addAll(join.getRight().getRowType().getFieldNames());
final Mappings.TargetMapping mapping = createMapping(inputColumns, synopsisColumns);
List<Filter> joinFilters = joinFilterMap.get(join);
if (joinFilters != null) {
for (Filter f : joinFilters) {
relBuilder.filter(f.getCondition());
}
}
relBuilder.project(relBuilder.fields(mapping), join.getRowType().getFieldNames(), true);

RelNode child;
RelNode node;
for (child = join, node = ApproxAggregateUtil.getParent(aggregate, join);
node != aggregate; child = node, node = ApproxAggregateUtil.getParent(aggregate, node)) {
if (node instanceof Filter) {
Filter filter = (Filter) node;
relBuilder.filter(filter.getCondition());
} else if (node instanceof Join) {
Join oj = (Join) node;
RexNode newCondition;
newCondition = oj.getCondition();
RelNode left = oj.getLeft();
RelNode right = oj.getRight();
if (left instanceof RelSubset
&& ((RelSubset) left).getBestOrOriginal() == child) {
final Join newJoin =
oj.copy(oj.getTraitSet(), newCondition, relBuilder.peek(), oj.getRight(),
oj.getJoinType(), oj.isSemiJoinDone());
relBuilder.clear();
relBuilder.push(newJoin);
} else if (right instanceof RelSubset
&& ((RelSubset) right).getBestOrOriginal() == child) {
final Join newJoin =
oj.copy(oj.getTraitSet(), newCondition, oj.getLeft(), relBuilder.peek(),
oj.getJoinType(), oj.isSemiJoinDone());
relBuilder.clear();
relBuilder.push(newJoin);
} else {
return;
}
} else if (node instanceof Project) {
Project project = (Project) node;
relBuilder.project(project.getProjects(), project.getRowType().getFieldNames());
} else {
break; /* cannot apply this rule */
}
}

if (node != aggregate) {
continue;
}
relBuilder.aggregate(relBuilder.groupKey(aggregate.getGroupSet()),
aggregate.getAggCallList());

double scaleFactor = 1.0 / bestSynopsis.getRatio();
List<RexNode> aggProjects = ApproxAggregateUtil.makeAggregateProjects(aggregate, scaleFactor);
relBuilder.project(aggProjects, aggregate.getRowType().getFieldNames());

call.transformTo(relBuilder.build());
return;
}

// apply synopses on single tables
for (Map.Entry<TableScan, List<String>> entry : requiredScanColumnMap.entrySet()) {
TableScan scan = entry.getKey();
List<String> requiredColumnNames = entry.getValue();
List<String> inputColumns = scan.getRowType().getFieldNames();

List<String> qualifiedTableName = scan.getTable().getQualifiedName();
Collection<MSynopsis> candidateSynopses =
Expand All @@ -212,6 +385,7 @@ public void onMatch(RelOptRuleCall call) {

boolean projected = false;
RelNode child;
RelNode node;
for (child = scan, node = ApproxAggregateUtil.getParent(aggregate, scan);
node != aggregate; child = node, node = ApproxAggregateUtil.getParent(aggregate, node)) {
if (node instanceof Filter) {
Expand Down

0 comments on commit dee5be6

Please sign in to comment.