Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: support conversion of STRING to BIGINT for window bounds #4500

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -93,18 +93,16 @@ public Expression process(final Expression expression) {

private final RewritingVisitor<C> rewriter;

@SuppressWarnings("unchecked")
public static <C, T extends Expression> T rewriteWith(
final BiFunction<Expression, Context<C>, Optional<Expression>> plugin, final T expression) {
return rewriteWith(plugin, expression, null);
}

@SuppressWarnings("unchecked")
public static <C, T extends Expression> T rewriteWith(
final BiFunction<Expression, Context<C>, Optional<Expression>> plugin,
final T expression,
final C context) {
return new ExpressionTreeRewriter<C>(plugin).rewrite(expression, context);
return new ExpressionTreeRewriter<>(plugin).rewrite(expression, context);
}

@SuppressWarnings("unchecked")
Expand Down Expand Up @@ -182,7 +180,7 @@ public Expression visitSubscriptExpression(
final SubscriptExpression node,
final C context) {
final Optional<Expression> result
= plugin.apply(node, new Context<C>(context, this));
= plugin.apply(node, new Context<>(context, this));
if (result.isPresent()) {
return result.get();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
package io.confluent.ksql.engine.rewrite;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableSet;
import io.confluent.ksql.engine.rewrite.ExpressionTreeRewriter.Context;
import io.confluent.ksql.execution.expression.tree.BetweenPredicate;
import io.confluent.ksql.execution.expression.tree.ColumnReferenceExp;
Expand All @@ -24,28 +25,33 @@
import io.confluent.ksql.execution.expression.tree.LongLiteral;
import io.confluent.ksql.execution.expression.tree.StringLiteral;
import io.confluent.ksql.execution.expression.tree.VisitParentExpressionVisitor;
import io.confluent.ksql.name.ColumnName;
import io.confluent.ksql.util.SchemaUtil;
import io.confluent.ksql.util.timestamp.PartialStringToTimestampParser;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;

public class StatementRewriteForMagicPseudoTimestamp {
big-andy-coates marked this conversation as resolved.
Show resolved Hide resolved

private static final Set<ColumnName> SUPPORTED_COLUMNS = ImmutableSet.<ColumnName>builder()
.addAll(SchemaUtil.windowBoundsColumnNames())
.add(SchemaUtil.ROWTIME_NAME)
.build();

public class StatementRewriteForRowtime {

private final PartialStringToTimestampParser parser;

public StatementRewriteForRowtime() {
public StatementRewriteForMagicPseudoTimestamp() {
this(new PartialStringToTimestampParser());
}

@VisibleForTesting
StatementRewriteForRowtime(final PartialStringToTimestampParser parser) {
StatementRewriteForMagicPseudoTimestamp(final PartialStringToTimestampParser parser) {
this.parser = Objects.requireNonNull(parser, "parser");
}

public Expression rewriteForRowtime(final Expression expression) {
if (noRewriteRequired(expression)) {
return expression;
}
public Expression rewrite(final Expression expression) {
return new ExpressionTreeRewriter<>(new OperatorPlugin()::process)
.rewrite(expression, null);
}
Expand All @@ -66,16 +72,22 @@ public Optional<Expression> visitBetweenPredicate(
final BetweenPredicate node,
final Context<Void> context
) {
if (noRewriteRequired(node.getValue())) {
if (!supportedColumnRef(node.getValue())) {
return Optional.empty();
}

final Optional<Expression> min = maybeRewriteTimestamp(node.getMin());
final Optional<Expression> max = maybeRewriteTimestamp(node.getMax());
if (!min.isPresent() && !max.isPresent()) {
return Optional.empty();
}

return Optional.of(
new BetweenPredicate(
node.getLocation(),
node.getValue(),
rewriteTimestamp(((StringLiteral) node.getMin()).getValue()),
rewriteTimestamp(((StringLiteral) node.getMax()).getValue())
min.orElse(node.getMin()),
max.orElse(node.getMax())
)
);
}
Expand All @@ -85,42 +97,45 @@ public Optional<Expression> visitComparisonExpression(
final ComparisonExpression node,
final Context<Void> context
) {
if (expressionIsRowtime(node.getLeft()) && node.getRight() instanceof StringLiteral) {
return Optional.of(
new ComparisonExpression(
node.getLocation(),
node.getType(),
node.getLeft(),
rewriteTimestamp(((StringLiteral) node.getRight()).getValue())
)
);
if (supportedColumnRef(node.getLeft())) {
final Optional<Expression> right = maybeRewriteTimestamp(node.getRight());
return right.map(r -> new ComparisonExpression(
node.getLocation(),
node.getType(),
node.getLeft(),
r
));
}

if (expressionIsRowtime(node.getRight()) && node.getLeft() instanceof StringLiteral) {
return Optional.of(
new ComparisonExpression(
node.getLocation(),
node.getType(),
rewriteTimestamp(((StringLiteral) node.getLeft()).getValue()),
node.getRight()
)
);
if (supportedColumnRef(node.getRight())) {
final Optional<Expression> left = maybeRewriteTimestamp(node.getLeft());
return left.map(l -> new ComparisonExpression(
node.getLocation(),
node.getType(),
l,
node.getRight()
));
}

return Optional.empty();
}
}

private static boolean refIsRowtime(final ColumnReferenceExp node) {
return node.getReference().equals(SchemaUtil.ROWTIME_NAME);
}
private Optional<Expression> maybeRewriteTimestamp(final Expression maybeTimestamp) {
if (!(maybeTimestamp instanceof StringLiteral)) {
return Optional.empty();
}

private static boolean expressionIsRowtime(final Expression node) {
return (node instanceof ColumnReferenceExp)
&& refIsRowtime((ColumnReferenceExp) node);
final String text = ((StringLiteral) maybeTimestamp).getValue();

return Optional.of(new LongLiteral(parser.parse(text)));
}

private LongLiteral rewriteTimestamp(final String timestamp) {
return new LongLiteral(parser.parse(timestamp));
private static boolean supportedColumnRef(final Expression maybeColumnRef) {
if (!(maybeColumnRef instanceof ColumnReferenceExp)) {
return false;
}

return SUPPORTED_COLUMNS.contains(((ColumnReferenceExp) maybeColumnRef).getReference());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

import static java.util.Objects.requireNonNull;

import io.confluent.ksql.engine.rewrite.StatementRewriteForRowtime;
import io.confluent.ksql.engine.rewrite.StatementRewriteForMagicPseudoTimestamp;
import io.confluent.ksql.execution.builder.KsqlQueryBuilder;
import io.confluent.ksql.execution.context.QueryContext;
import io.confluent.ksql.execution.context.QueryContext.Stacker;
Expand Down Expand Up @@ -136,8 +136,7 @@ public SchemaKStream<K> filter(
}

static Expression rewriteTimeComparisonForFilter(final Expression expression) {
return new StatementRewriteForRowtime()
.rewriteForRowtime(expression);
return new StatementRewriteForMagicPseudoTimestamp().rewrite(expression);
}

public SchemaKStream<K> select(
Expand Down Expand Up @@ -325,7 +324,7 @@ public SchemaKStream<Struct> selectKey(
final Expression keyExpression,
final QueryContext.Stacker contextStacker
) {
if (!needsRepartition(keyExpression)) {
if (repartitionNotNeeded(keyExpression)) {
return (SchemaKStream<Struct>) this;
}

Expand Down Expand Up @@ -360,9 +359,9 @@ private KeyField getNewKeyField(final Expression expression) {
return getSchema().isMetaColumn(columnName) ? KeyField.none() : newKeyField;
}

protected boolean needsRepartition(final Expression expression) {
protected boolean repartitionNotNeeded(final Expression expression) {
if (!(expression instanceof UnqualifiedColumnReferenceExp)) {
return true;
return false;
}

final ColumnName columnName = ((UnqualifiedColumnReferenceExp) expression).getReference();
Expand All @@ -379,7 +378,7 @@ protected boolean needsRepartition(final Expression expression) {
.map(kf -> kf.ref().equals(proposedKey.ref()))
.orElse(false);

return !namesMatch && !isRowKey(columnName);
return namesMatch || isRowKey(columnName);
}

private boolean isRowKey(final ColumnName fieldName) {
Expand Down Expand Up @@ -453,7 +452,7 @@ public SchemaKGroupedStream groupBy(
);
}

@SuppressWarnings("unchecked")
@SuppressWarnings({"unchecked", "rawtypes"})
private SchemaKGroupedStream groupByKey(
final KeyFormat rekeyedKeyFormat,
final ValueFormat valueFormat,
Expand Down Expand Up @@ -534,7 +533,7 @@ LogicalSchema resolveSchema(final ExecutionStep<?> step) {
return new StepSchemaResolver(ksqlConfig, functionRegistry).resolve(step, schema);
}

LogicalSchema resolveSchema(final ExecutionStep<?> step, final SchemaKStream right) {
LogicalSchema resolveSchema(final ExecutionStep<?> step, final SchemaKStream<?> right) {
return new StepSchemaResolver(ksqlConfig, functionRegistry).resolve(
step,
schema,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ public SchemaKTable<K> select(
@Override
public SchemaKStream<Struct> selectKey(final Expression keyExpression,
final Stacker contextStacker) {
if (!needsRepartition(keyExpression)) {
if (repartitionNotNeeded(keyExpression)) {
return (SchemaKStream<Struct>) this;
}

Expand Down
Loading