Skip to content

Commit

Permalink
refactor: change rowtime rewrite conditions and add error handling (#…
Browse files Browse the repository at this point in the history
…3209)

* refactor: rowtime rewrite replaces dates to long instead of wrapping

* refactor: downscale conditions for rewrite

* chore: fix build issues

* fix: dont use precomputed timestamps in tests
  • Loading branch information
Zara Lim authored Aug 16, 2019
1 parent a79adb4 commit f41c246
Show file tree
Hide file tree
Showing 2 changed files with 152 additions and 114 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,14 @@
import io.confluent.ksql.execution.expression.tree.ComparisonExpression;
import io.confluent.ksql.execution.expression.tree.DereferenceExpression;
import io.confluent.ksql.execution.expression.tree.Expression;
import io.confluent.ksql.execution.expression.tree.FunctionCall;
import io.confluent.ksql.execution.expression.tree.QualifiedName;
import io.confluent.ksql.execution.expression.tree.LongLiteral;
import io.confluent.ksql.execution.expression.tree.StringLiteral;
import io.confluent.ksql.execution.expression.tree.VisitParentExpressionVisitor;
import io.confluent.ksql.parser.rewrite.ExpressionTreeRewriter.Context;
import io.confluent.ksql.util.KsqlException;
import io.confluent.ksql.util.timestamp.StringToTimestampParser;

import java.util.ArrayList;
import java.util.List;
import java.time.ZoneId;
import java.util.Objects;
import java.util.Optional;

Expand All @@ -46,118 +46,23 @@ public Expression rewriteForRowtime() {
new OperatorPlugin()::process).rewrite(expression, null);
}

private static final class TimestampPlugin
extends VisitParentExpressionVisitor<Optional<Expression>, Context<Void>> {
private static final String DATE_PATTERN = "yyyy-MM-dd";
private static final String TIME_PATTERN = "HH:mm:ss.SSS";
private static final String PATTERN = DATE_PATTERN + "'T'" + TIME_PATTERN;

private TimestampPlugin() {
super(Optional.empty());
}

@Override
public Optional<Expression> visitFunctionCall(
final FunctionCall node,
final Context<Void> context) {
return Optional.of(node);
}

@Override
public Optional<Expression> visitStringLiteral(
final StringLiteral node,
final Context<Void> context) {
if (!node.getValue().equals("ROWTIME")) {
return Optional.of(
new FunctionCall(
QualifiedName.of("STRINGTOTIMESTAMP"),
getFunctionArgs(node.getValue())
)
);
}
return Optional.of(node);
}

private List<Expression> getFunctionArgs(final String datestring) {
final List<Expression> args = new ArrayList<>();
final String date;
final String time;
final String timezone;
if (datestring.contains("T")) {
date = datestring.substring(0, datestring.indexOf('T'));
final String withTimezone = completeTime(datestring.substring(datestring.indexOf('T') + 1));
timezone = getTimezone(withTimezone);
time = completeTime(withTimezone.substring(0, timezone.length()));
} else {
date = completeDate(datestring);
time = completeTime("");
timezone = "";
}

if (timezone.length() > 0) {
args.add(new StringLiteral(date + "T" + time));
args.add(new StringLiteral(PATTERN));
args.add(new StringLiteral(timezone));
} else {
args.add(new StringLiteral(date + "T" + time));
args.add(new StringLiteral(PATTERN));
}
return args;
}

private String getTimezone(final String time) {
if (time.contains("+")) {
return time.substring(time.indexOf('+'));
} else if (time.contains("-")) {
return time.substring(time.indexOf('-'));
} else {
return "";
}
}

private String completeDate(final String date) {
final String[] parts = date.split("-");
if (parts.length == 1) {
return date + "-01-01";
} else if (parts.length == 2) {
return date + "-01";
} else {
// It is either a complete date or an incorrectly formatted one.
// In the latter case, we can pass the incorrectly formed string
// to STRINGTITIMESTAMP which will deal with the error handling.
return date;
}
}

private String completeTime(final String time) {
if (time.length() >= TIME_PATTERN.length()) {
return time;
}
return time + TIME_PATTERN.substring(time.length()).replaceAll("[a-zA-Z]", "0");
}
}

private static final class OperatorPlugin
extends VisitParentExpressionVisitor<Optional<Expression>, Context<Void>> {
private OperatorPlugin() {
super(Optional.empty());
}

private Expression rewriteTimestamp(final Expression expression) {
return new ExpressionTreeRewriter<>(new TimestampPlugin()::process).rewrite(expression, null);
}

@Override
public Optional<Expression> visitBetweenPredicate(
final BetweenPredicate node,
final Context<Void> context) {
if (StatementRewriteForRowtime.requiresRewrite(node)) {
if (requiresRewrite(node.getValue())) {
return Optional.of(
new BetweenPredicate(
node.getLocation(),
rewriteTimestamp(node.getValue()),
rewriteTimestamp(node.getMin()),
rewriteTimestamp(node.getMax())
node.getValue(),
rewriteTimestamp(((StringLiteral) node.getMin()).getValue()),
rewriteTimestamp(((StringLiteral) node.getMax()).getValue())
)
);
}
Expand All @@ -168,13 +73,22 @@ public Optional<Expression> visitBetweenPredicate(
public Optional<Expression> visitComparisonExpression(
final ComparisonExpression node,
final Context<Void> context) {
if (expressionIsRowtime(node.getLeft()) || expressionIsRowtime(node.getRight())) {
if (expressionIsRowtime(node.getLeft()) && node.getRight() instanceof StringLiteral) {
return Optional.of(
new ComparisonExpression(
node.getLocation(),
node.getType(),
rewriteTimestamp(node.getLeft()),
rewriteTimestamp(node.getRight())
node.getLeft(),
rewriteTimestamp(((StringLiteral) node.getRight()).getValue())
)
);
} else if (expressionIsRowtime(node.getRight()) && node.getLeft() instanceof StringLiteral) {
return Optional.of(
new ComparisonExpression(
node.getLocation(),
node.getType(),
rewriteTimestamp(((StringLiteral) node.getLeft()).getValue()),
node.getRight()
)
);
}
Expand All @@ -186,4 +100,69 @@ private static boolean expressionIsRowtime(final Expression node) {
return (node instanceof DereferenceExpression)
&& ((DereferenceExpression) node).getFieldName().equals("ROWTIME");
}

private static LongLiteral rewriteTimestamp(final String timestamp) {
final String timePattern = "HH:mm:ss.SSS";
final StringToTimestampParser parser = new StringToTimestampParser(
"yyyy-MM-dd'T'" + timePattern);

final String date;
final String time;
final String timezone;

if (timestamp.contains("T")) {
date = timestamp.substring(0, timestamp.indexOf('T'));
final String withTimezone = completeTime(
timestamp.substring(timestamp.indexOf('T') + 1),
timePattern);
timezone = getTimezone(withTimezone);
time = completeTime(withTimezone.substring(0, timezone.length()), timePattern);
} else {
date = completeDate(timestamp);
time = completeTime("", timePattern);
timezone = "";
}

try {
if (timezone.length() > 0) {
return new LongLiteral(parser.parse(date + "T" + time, ZoneId.of(timezone)));
} else {
return new LongLiteral(parser.parse(date + "T" + time));
}
} catch (final RuntimeException e) {
throw new KsqlException("Failed to parse timestamp '"
+ timestamp + "': " + e.getMessage(), e);
}
}

private static String getTimezone(final String time) {
if (time.contains("+")) {
return time.substring(time.indexOf('+'));
} else if (time.contains("-")) {
return time.substring(time.indexOf('-'));
} else {
return "";
}
}

private static String completeDate(final String date) {
final String[] parts = date.split("-");
if (parts.length == 1) {
return date + "-01-01";
} else if (parts.length == 2) {
return date + "-01";
} else {
// It is either a complete date or an incorrectly formatted one.
// In the latter case, we can pass the incorrectly formed string
// to the timestamp parser which will deal with the error handling.
return date;
}
}

private static String completeTime(final String time, final String timePattern) {
if (time.length() >= timePattern.length()) {
return time;
}
return time + timePattern.substring(time.length()).replaceAll("[a-zA-Z]", "0");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,31 +20,40 @@
import io.confluent.ksql.metastore.MetaStore;
import io.confluent.ksql.parser.KsqlParserTestUtil;
import io.confluent.ksql.parser.tree.*;
import io.confluent.ksql.util.KsqlException;
import io.confluent.ksql.util.MetaStoreFixture;
import io.confluent.ksql.util.timestamp.StringToTimestampParser;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;

import java.time.ZoneId;

import static org.hamcrest.CoreMatchers.containsString;
import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.mockito.Mockito.mock;

public class StatementRewriteForRowtimeTest {
@Rule
public final ExpectedException expectedException = ExpectedException.none();
private MetaStore metaStore;
final StringToTimestampParser parser = new StringToTimestampParser("yyyy-MM-dd'T'HH:mm:ss.SSS");

@Before
public void init() {
metaStore = MetaStoreFixture.getNewMetaStore(mock(FunctionRegistry.class));
}

@Test
public void shouldWrapDatestring() {
public void shouldReplaceDatestring() {
final String query = "SELECT * FROM orders where ROWTIME > '2017-01-01T00:00:00.000';";
final Query statement = (Query) KsqlParserTestUtil.buildSingleAst(query, metaStore).getStatement();
final Expression predicate = statement.getWhere().get();
final Expression rewritten = new StatementRewriteForRowtime(predicate).rewriteForRowtime();

assertThat(rewritten.toString(), equalTo("(ORDERS.ROWTIME > STRINGTOTIMESTAMP('2017-01-01T00:00:00.000', 'yyyy-MM-dd''T''HH:mm:ss.SSS'))"));
assertThat(rewritten.toString(), equalTo(String.format("(ORDERS.ROWTIME > %d)", parser.parse("2017-01-01T00:00:00.000"))));
}

@Test
Expand All @@ -54,7 +63,7 @@ public void shouldHandleInexactTimestamp() {
final Expression predicate = statement.getWhere().get();
final Expression rewritten = new StatementRewriteForRowtime(predicate).rewriteForRowtime();

assertThat(rewritten.toString(), equalTo("(ORDERS.ROWTIME = STRINGTOTIMESTAMP('2017-01-01T00:00:00.000', 'yyyy-MM-dd''T''HH:mm:ss.SSS'))"));
assertThat(rewritten.toString(), equalTo(String.format("(ORDERS.ROWTIME = %d)", parser.parse("2017-01-01T00:00:00.000"))));
}

@Test
Expand All @@ -64,9 +73,10 @@ public void shouldHandleBetweenExpression() {
final Expression predicate = statement.getWhere().get();
final Expression rewritten = new StatementRewriteForRowtime(predicate).rewriteForRowtime();

assertThat(rewritten.toString(), equalTo("(ORDERS.ROWTIME BETWEEN"
+ " STRINGTOTIMESTAMP('2017-01-01T00:00:00.000', 'yyyy-MM-dd''T''HH:mm:ss.SSS') AND"
+ " STRINGTOTIMESTAMP('2017-02-01T00:00:00.000', 'yyyy-MM-dd''T''HH:mm:ss.SSS'))"));
assertThat(rewritten.toString(), equalTo(String.format(
"(ORDERS.ROWTIME BETWEEN %d AND %d)",
parser.parse("2017-01-01T00:00:00.000"),
parser.parse("2017-02-01T00:00:00.000"))));
}

@Test
Expand All @@ -86,7 +96,7 @@ public void shouldIgnoreNonRowtimeStrings() {
final Expression predicate = statement.getWhere().get();
final Expression rewritten = new StatementRewriteForRowtime(predicate).rewriteForRowtime();

assertThat(rewritten.toString(), equalTo("((ORDERS.ROWTIME > STRINGTOTIMESTAMP('2017-01-01T00:00:00.000', 'yyyy-MM-dd''T''HH:mm:ss.SSS')) AND (ORDERS.ROWKEY = '2017-01-01'))"));
assertThat(rewritten.toString(), equalTo(String.format("((ORDERS.ROWTIME > %d) AND (ORDERS.ROWKEY = '2017-01-01'))", parser.parse("2017-01-01T00:00:00.000"))));
}

@Test
Expand All @@ -96,6 +106,55 @@ public void shouldHandleTimezones() {
final Expression predicate = statement.getWhere().get();
final Expression rewritten = new StatementRewriteForRowtime(predicate).rewriteForRowtime();

assertThat(rewritten.toString(), containsString("(ORDERS.ROWTIME = STRINGTOTIMESTAMP('2017-01-01T00:00:00.000', 'yyyy-MM-dd''T''HH:mm:ss.SSS', '+0100'))"));
assertThat(rewritten.toString(), equalTo(String.format("(ORDERS.ROWTIME = %d)", parser.parse("2017-01-01T00:00:00.000", ZoneId.of("+0100")))));
}

@Test
public void shouldNotProcessWhenRowtimeInFunction() {
final String simpleQuery = "SELECT * FROM orders where foo(ROWTIME) = '2017-01-01';";
final Query statement = (Query) KsqlParserTestUtil.buildSingleAst(simpleQuery, metaStore).getStatement();
final Expression predicate = statement.getWhere().get();
final Expression rewritten = new StatementRewriteForRowtime(predicate).rewriteForRowtime();

assertThat(rewritten.toString(), containsString("(FOO(ORDERS.ROWTIME) = '2017-01-01')"));
}

@Test
public void shouldNotProcessArithmetic() {
final String simpleQuery = "SELECT * FROM orders where '2017-01-01' + 10000 > ROWTIME;";
final Query statement = (Query) KsqlParserTestUtil.buildSingleAst(simpleQuery, metaStore).getStatement();
final Expression predicate = statement.getWhere().get();
final Expression rewritten = new StatementRewriteForRowtime(predicate).rewriteForRowtime();

assertThat(rewritten.toString(), containsString("(('2017-01-01' + 10000) > ORDERS.ROWTIME)"));
}

@Test
public void shouldThrowParseError() {
// Given:
final String simpleQuery = "SELECT * FROM orders where ROWTIME = '2oo17-01-01';";
final Query statement = (Query) KsqlParserTestUtil.buildSingleAst(simpleQuery, metaStore).getStatement();
final Expression predicate = statement.getWhere().get();

// Expect:
expectedException.expect(KsqlException.class);
expectedException.expectMessage("Failed to parse timestamp '2oo17-01-01'");

// When:
new StatementRewriteForRowtime(predicate).rewriteForRowtime();
}

@Test
public void shouldThrowTimezoneParseError() {
final String simpleQuery = "SELECT * FROM orders where ROWTIME = '2017-01-01T00:00:00.000+foo';";
final Query statement = (Query) KsqlParserTestUtil.buildSingleAst(simpleQuery, metaStore).getStatement();
final Expression predicate = statement.getWhere().get();

// Expect:
expectedException.expect(KsqlException.class);
expectedException.expectMessage("Failed to parse timestamp '2017-01-01T00:00:00.000+foo'");

// When:
new StatementRewriteForRowtime(predicate).rewriteForRowtime();
}
}

0 comments on commit f41c246

Please sign in to comment.