Skip to content

Commit

Permalink
feat: Allows lots of table scans cases where keys cannot easily be ex…
Browse files Browse the repository at this point in the history
…tracted (#7155)

* feat: Allows lots of table scans cases where keys cannot easily be extracted
  • Loading branch information
AlanConfluent authored Mar 9, 2021
1 parent 782cf2b commit 71becea
Show file tree
Hide file tree
Showing 6 changed files with 382 additions and 42 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,14 @@

import io.confluent.ksql.GenericRow;
import io.confluent.ksql.execution.codegen.CodeGenRunner;
import io.confluent.ksql.execution.codegen.CompiledExpression;
import io.confluent.ksql.execution.expression.tree.Expression;
import io.confluent.ksql.execution.expression.tree.NullLiteral;
import io.confluent.ksql.execution.expression.tree.QualifiedColumnReferenceExp;
import io.confluent.ksql.execution.expression.tree.TraversalExpressionVisitor;
import io.confluent.ksql.execution.expression.tree.UnqualifiedColumnReferenceExp;
import io.confluent.ksql.execution.expression.tree.VisitParentExpressionVisitor;
import io.confluent.ksql.execution.interpreter.InterpretedExpressionFactory;
import io.confluent.ksql.execution.transform.ExpressionEvaluator;
import io.confluent.ksql.function.FunctionRegistry;
import io.confluent.ksql.logging.processing.ProcessingLogger;
import io.confluent.ksql.logging.processing.RecordProcessingError;
Expand Down Expand Up @@ -63,19 +64,22 @@ public class GenericExpressionResolver {
private final FunctionRegistry functionRegistry;
private final KsqlConfig config;
private final String operation;
private final boolean shouldUseInterpreter;

public GenericExpressionResolver(
final SqlType fieldType,
final ColumnName fieldName,
final FunctionRegistry functionRegistry,
final KsqlConfig config,
final String operation
final String operation,
final boolean shouldUseInterpreter
) {
this.fieldType = Objects.requireNonNull(fieldType, "fieldType");
this.fieldName = Objects.requireNonNull(fieldName, "fieldName");
this.functionRegistry = Objects.requireNonNull(functionRegistry, "functionRegistry");
this.config = Objects.requireNonNull(config, "config");
this.operation = Objects.requireNonNull(operation, "operation");
this.shouldUseInterpreter = shouldUseInterpreter;
}

public Object resolve(final Expression expression) {
Expand All @@ -87,17 +91,17 @@ private class Visitor extends VisitParentExpressionVisitor<Object, Void> {
@Override
protected Object visitExpression(final Expression expression, final Void context) {
new EnsureNoColReferences(expression).process(expression, context);
final CompiledExpression metadata =
CodeGenRunner.compileExpression(
final ExpressionEvaluator evaluator = shouldUseInterpreter
? InterpretedExpressionFactory.create(expression, NO_COLUMNS, functionRegistry, config)
: CodeGenRunner.compileExpression(
expression,
operation,
NO_COLUMNS,
config,
functionRegistry
);
functionRegistry);

// we expect no column references, so we can pass in an empty generic row
final Object value = metadata.evaluate(new GenericRow(), null, THROWING_LOGGER, IGNORED_MSG);
final Object value = evaluator.evaluate(new GenericRow(), null, THROWING_LOGGER, IGNORED_MSG);

return sqlValueCoercer.coerce(value, fieldType)
.orElseThrow(() -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,8 @@ private static Map<ColumnName, Object> resolveValues(
column,
functionRegistry,
config,
"insert value").resolve(valueExp);
"insert value",
false).resolve(valueExp);

values.put(column, value);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

package io.confluent.ksql.planner.plan;

import com.google.common.base.Preconditions;
import com.google.common.collect.BoundType;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
Expand Down Expand Up @@ -44,6 +45,7 @@
import io.confluent.ksql.schema.ksql.DefaultSqlValueCoercer;
import io.confluent.ksql.schema.ksql.LogicalSchema;
import io.confluent.ksql.schema.ksql.SystemColumns;
import io.confluent.ksql.schema.ksql.types.SqlTypes;
import io.confluent.ksql.schema.utils.FormatOptions;
import io.confluent.ksql.structured.SchemaKStream;
import io.confluent.ksql.util.KsqlConfig;
Expand Down Expand Up @@ -114,7 +116,7 @@ public PullFilterNode(
this.isWindowed = isWindowed;

// Basic validation of WHERE clause
this.requiresTableScan = validateWhereClause();
this.requiresTableScan = validateWhereClauseAndCheckTableScan();

// Extraction of key and system columns
extractKeysAndSystemCols();
Expand Down Expand Up @@ -165,15 +167,16 @@ public LogicalSchema getIntermediateSchema() {
return intermediateSchema;
}

private boolean validateWhereClause() {
boolean requiresTableScan = false;
private boolean validateWhereClauseAndCheckTableScan() {
for (Expression disjunct : disjuncts) {
final Validator validator = new Validator();
validator.process(disjunct, null);
requiresTableScan = requiresTableScan || validator.requiresTableScan;
if (validator.requiresTableScan) {
return true;
}
if (!validator.isKeyedQuery) {
if (pullPlannerOptions.getTableScansEnabled()) {
requiresTableScan = true;
return true;
} else {
throw invalidWhereClauseException("WHERE clause missing key column for disjunct: "
+ disjunct.toString(), isWindowed);
Expand All @@ -182,22 +185,22 @@ private boolean validateWhereClause() {

if (!validator.seenKeys.isEmpty()
&& validator.seenKeys.cardinality() != schema.key().size()) {
final List<ColumnName> seenKeyNames = validator.seenKeys
.stream()
.boxed()
.map(i -> schema.key().get(i))
.map(Column::name)
.collect(Collectors.toList());
if (pullPlannerOptions.getTableScansEnabled()) {
requiresTableScan = true;
return true;
} else {
final List<ColumnName> seenKeyNames = validator.seenKeys
.stream()
.boxed()
.map(i -> schema.key().get(i))
.map(Column::name)
.collect(Collectors.toList());
throw invalidWhereClauseException(
"Multi-column sources must specify every key in the WHERE clause. Specified: "
+ seenKeyNames + " Expected: " + schema.key(), isWindowed);
}
}
}
return requiresTableScan;
return false;
}

private void extractKeysAndSystemCols() {
Expand Down Expand Up @@ -295,7 +298,25 @@ public Void visitComparisonExpression(
final ComparisonExpression node,
final Object context
) {
final UnqualifiedColumnReferenceExp column = getColumnRefSide(node);
// First see if we can find a direct column reference
final UnqualifiedColumnReferenceExp column = getColumnRefSideOrNull(node);
if (column != null) {
final Expression other = getNonColumnRefSide(node);
final HasColumnRef hasColumnRef = new HasColumnRef();
hasColumnRef.process(other, null);

if (hasColumnRef.hasColumnRef()) {
setTableScanOrElseThrow(() ->
invalidWhereClauseException("A comparison must be between a key column and a "
+ "resolvable expression", isWindowed));
return null;
}
} else {
setTableScanOrElseThrow(() ->
invalidWhereClauseException("A comparison must directly reference a key column",
isWindowed));
return null;
}

final ColumnName columnName = column.getColumnName();
if (columnName.equals(SystemColumns.WINDOWSTART_NAME)
Expand Down Expand Up @@ -345,10 +366,33 @@ private void setTableScanOrElseThrow(final Supplier<KsqlException> exceptionSupp
}
}

private UnqualifiedColumnReferenceExp getColumnRefSide(final ComparisonExpression comp) {
/**
 * Expression visitor that remembers whether it has been asked to visit at least one
 * {@link UnqualifiedColumnReferenceExp}.
 *
 * <p>Intended use: create an instance, call {@code process(expression, null)} on the
 * expression of interest, then query {@link #hasColumnRef()}. Descending into
 * sub-expressions is left to the {@code TraversalExpressionVisitor} base class
 * (NOTE(review): presumed from the base class name; confirm it walks all children).
 *
 * <p>Only unqualified column references are tracked; no other node type is overridden here.
 */
private static final class HasColumnRef extends TraversalExpressionVisitor<Object> {

  // Set once an unqualified column reference is visited; never reset afterwards.
  private boolean columnRefSeen;

  HasColumnRef() {
    this.columnRefSeen = false;
  }

  /**
   * Records that a column reference was encountered.
   *
   * @param node the column reference being visited (not inspected)
   * @param context visitor context (unused)
   * @return always {@code null}
   */
  @Override
  public Void visitUnqualifiedColumnReference(
      final UnqualifiedColumnReferenceExp node,
      final Object context
  ) {
    columnRefSeen = true;
    return null;
  }

  /**
   * @return {@code true} if {@link #visitUnqualifiedColumnReference} has been invoked on this
   *     instance at least once, {@code false} otherwise
   */
  public boolean hasColumnRef() {
    return columnRefSeen;
  }
}

private UnqualifiedColumnReferenceExp getColumnRefSideOrNull(final ComparisonExpression comp) {
return (UnqualifiedColumnReferenceExp)
(comp.getRight() instanceof UnqualifiedColumnReferenceExp
? comp.getRight() : comp.getLeft());
? comp.getRight()
: (comp.getLeft() instanceof UnqualifiedColumnReferenceExp ? comp.getLeft() : null));
}

private Expression getNonColumnRefSide(final ComparisonExpression comparison) {
Expand Down Expand Up @@ -391,8 +435,9 @@ private final class KeyValueExtractor extends TraversalExpressionVisitor<Object>
@Override
public Void visitComparisonExpression(
final ComparisonExpression node, final Object context) {
final UnqualifiedColumnReferenceExp column = getColumnRefSide(node);
final UnqualifiedColumnReferenceExp column = getColumnRefSideOrNull(node);
final Expression other = getNonColumnRefSide(node);
Preconditions.checkNotNull(column, "UnqualifiedColumnReferenceExp should be found");
final ColumnName columnName = column.getColumnName();

final Optional<Column> col = schema.findColumn(columnName);
Expand Down Expand Up @@ -424,7 +469,8 @@ private Object resolveKey(
keyColumn.name(),
metaStore,
config,
"pull query"
"pull query",
ksqlConfig.getBoolean(KsqlConfig.KSQL_QUERY_PULL_INTERPRETER_ENABLED)
).resolve(exp);
}

Expand All @@ -449,7 +495,7 @@ private Object resolveKey(
* 1. An equality bound cannot be combined with other bounds.
* 2. No duplicate bounds are allowed, such as multiple greater than bounds.
*/
private static final class WindowBoundsExtractor
private final class WindowBoundsExtractor
extends TraversalExpressionVisitor<WindowBounds> {

@Override
Expand All @@ -472,17 +518,18 @@ public Void visitComparisonExpression(
}
boolean result = false;
if (node.getType().equals(Type.EQUAL)) {
final Range<Instant> instant = Range.singleton(asInstant(getNonColumnRefSide(node)));
final Range<Instant> instant = Range.singleton(asInstant(getNonColumnRefSide(node),
column.getColumnName()));
result = windowBounds.setEquality(column, instant);
}
final Type type = getSimplifiedBoundType(node);

if (type.equals(Type.LESS_THAN)) {
final Instant upper = asInstant(getNonColumnRefSide(node));
final Instant upper = asInstant(getNonColumnRefSide(node), column.getColumnName());
final BoundType upperType = getRangeBoundType(node);
result = windowBounds.setUpper(column, Range.upTo(upper, upperType));
} else if (type.equals(Type.GREATER_THAN)) {
final Instant lower = asInstant(getNonColumnRefSide(node));
final Instant lower = asInstant(getNonColumnRefSide(node), column.getColumnName());
final BoundType lowerType = getRangeBoundType(node);
result = windowBounds.setLower(column, Range.downTo(lower, lowerType));
}
Expand Down Expand Up @@ -543,7 +590,7 @@ private Expression getNonColumnRefSide(final ComparisonExpression comparison) {
: comparison.getRight();
}

private Instant asInstant(final Expression other) {
private Instant asInstant(final Expression other, final ColumnName name) {
if (other instanceof IntegerLiteral) {
return Instant.ofEpochMilli(((IntegerLiteral) other).getValue());
}
Expand All @@ -564,10 +611,23 @@ private Instant asInstant(final Expression other) {
}
}

throw invalidWhereClauseException(
"Window bounds must be an INT, BIGINT or STRING containing a datetime.",
true
);
try {
final Long value = (Long) new GenericExpressionResolver(
SqlTypes.BIGINT,
name,
metaStore,
ksqlConfig,
"pull query window bounds extractor",
ksqlConfig.getBoolean(KsqlConfig.KSQL_QUERY_PULL_INTERPRETER_ENABLED)
).resolve(other);

return Instant.ofEpochMilli(value);
} catch (final KsqlException e) {
throw invalidWhereClauseException(
"Window bounds must resolve to an INT, BIGINT, or STRING containing a datetime.",
true
);
}
}

private BoundType getRangeBoundType(final ComparisonExpression lowerComparison) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,8 @@ public void shouldResolveArbitraryExpressions() {
));

// When:
final Object o = new GenericExpressionResolver(type, FIELD_NAME, registry, config, "insert value").resolve(exp);
final Object o = new GenericExpressionResolver(type, FIELD_NAME, registry, config,
"insert value", false).resolve(exp);

// Then:
assertThat(o, is(new Struct(
Expand All @@ -71,7 +72,8 @@ public void shouldResolveNullLiteral() {
final Expression exp = new NullLiteral();

// When:
final Object o = new GenericExpressionResolver(type, FIELD_NAME, registry, config, "insert value").resolve(exp);
final Object o = new GenericExpressionResolver(type, FIELD_NAME, registry, config,
"insert value", false).resolve(exp);

// Then:
assertThat(o, Matchers.nullValue());
Expand All @@ -86,7 +88,8 @@ public void shouldThrowIfCannotCoerce() {
// When:
final KsqlException e = assertThrows(
KsqlException.class,
() -> new GenericExpressionResolver(type, FIELD_NAME, registry, config, "insert value").resolve(exp));
() -> new GenericExpressionResolver(type, FIELD_NAME, registry, config,
"insert value", false).resolve(exp));

// Then:
assertThat(e.getMessage(), containsString("Expected type ARRAY<INTEGER> for field `FOO` but got INTEGER(1)"));
Expand All @@ -101,7 +104,8 @@ public void shouldThrowIfCannotParseTimestamp() {
// When:
final KsqlException e = assertThrows(
KsqlException.class,
() -> new GenericExpressionResolver(type, FIELD_NAME, registry, config, "insert value").resolve(exp));
() -> new GenericExpressionResolver(type, FIELD_NAME, registry, config, "insert value",
false).resolve(exp));

// Then:
assertThat(e.getMessage(), containsString("Timestamp format must be yyyy-mm-ddThh:mm:ss[.S]"));
Expand All @@ -114,7 +118,8 @@ public void shouldParseTimestamp() {
final Expression exp = new StringLiteral("2021-01-09T04:40:02");

// When:
Object o = new GenericExpressionResolver(type, FIELD_NAME, registry, config, "insert value").resolve(exp);
Object o = new GenericExpressionResolver(type, FIELD_NAME, registry, config, "insert value",
false).resolve(exp);

// Then:
assertTrue(o instanceof Timestamp);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1316,7 +1316,8 @@ public void shouldThrowOnInvalidTimestampType() {
));

// Then:
assertThat(e.getMessage(), containsString("Window bounds must be an INT, BIGINT or STRING containing a datetime."));
assertThat(e.getMessage(), containsString("Window bounds must resolve to an INT, BIGINT, or "
+ "STRING containing a datetime."));
}


Expand Down
Loading

0 comments on commit 71becea

Please sign in to comment.