Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactoring to use schema type for comparison. #614

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ protected Pair<String, Schema> visitComparisonExpression(ComparisonExpression no
Boolean unmangleNames) {
Pair<String, Schema> left = process(node.getLeft(), unmangleNames);
Pair<String, Schema> right = process(node.getRight(), unmangleNames);
if ((left.getRight() == Schema.STRING_SCHEMA) || (right.getRight() == Schema.STRING_SCHEMA)) {
if ((left.getRight().type() == Schema.Type.STRING) || (right.getRight().type() == Schema.Type.STRING)) {
if ("=".equals(node.getType().getValue())) {
return new Pair<>("(" + left.getLeft() + ".equals(" + right.getLeft() + "))",
Schema.BOOLEAN_SCHEMA);
Expand Down Expand Up @@ -261,19 +261,20 @@ protected Pair<String, Schema> visitCast(Cast node, Boolean context) {

case "INTEGER": {
Schema rightSchema = expr.getRight();
String exprStr = getCastToIntegerString(rightSchema, expr.getLeft());
String exprStr = getCastString(rightSchema, expr.getLeft(), "intValue", "Integer.parseInt");
return new Pair<>(exprStr, returnType);
}

case "BIGINT": {
Schema rightSchema = expr.getRight();
String exprStr = getCastToLongString(rightSchema, expr.getLeft());
String exprStr = getCastString(rightSchema, expr.getLeft(), "longValue", "Long"
+ ".parseLong");
return new Pair<>(exprStr, returnType);
}

case "DOUBLE": {
Schema rightSchema = expr.getRight();
String exprStr = getCastToDoubleString(rightSchema, expr.getLeft());
String exprStr = getCastString(rightSchema, expr.getLeft(), "doubleValue", "Double.parseDouble");
return new Pair<>(exprStr, returnType);
}
default:
Expand Down Expand Up @@ -403,58 +404,45 @@ private String joinExpressions(List<Expression> expressions, boolean unmangleNam
}

private String getCastToBooleanString(Schema schema, String exprStr) {
if (schema == Schema.BOOLEAN_SCHEMA) {
if (schema.type() == Schema.Type.BOOLEAN) {
return exprStr;
} else if (schema == Schema.STRING_SCHEMA) {
} else if (schema.type() == Schema.Type.STRING) {
return "Boolean.parseBoolean(" + exprStr + ")";
} else {
throw new KsqlFunctionException(
"Invalid cast operation: Cannot cast " + exprStr + " to boolean.");
}
}

private String getCastToIntegerString(Schema schema, String exprStr) {
if (schema == Schema.STRING_SCHEMA) {
return "Integer.parseInt(" + exprStr + ")";
} else if (schema == Schema.INT32_SCHEMA) {
return exprStr;
} else if (schema == Schema.INT64_SCHEMA) {
return "(new Long(" + exprStr + ").intValue())";
} else if (schema == Schema.FLOAT64_SCHEMA) {
return "(new Double(" + exprStr + ").intValue())";
} else {
throw new KsqlFunctionException(
"Invalid cast operation: Cannot cast " + exprStr + " to Integer.");
}
}
private String getCastString(Schema schema,
String exprStr,
String javaTypeMethod,
String javaStringParserMethod) {
switch (schema.type()) {
case INT32:
if (javaTypeMethod.equals("intValue")) {
return exprStr;
} else {
return "(new Integer(" + exprStr + ")." + javaTypeMethod + "())";
}
case INT64:
if (javaTypeMethod.equals("longValue")) {
return exprStr;
} else {
return "(new Long(" + exprStr + ")." + javaTypeMethod + "())";
}
case FLOAT64:
if (javaTypeMethod.equals("doubleValue")) {
return exprStr;
} else {
return "(new Double(" + exprStr + ")." + javaTypeMethod + "())";
}
case STRING:
return javaStringParserMethod + "(" + exprStr + ")";

private String getCastToLongString(Schema schema, String exprStr) {
if (schema == Schema.STRING_SCHEMA) {
return "Long.parseLong(" + exprStr + ")";
} else if (schema == Schema.INT32_SCHEMA) {
return "(new Integer(" + exprStr + ").longValue())";
} else if (schema == Schema.INT64_SCHEMA) {
return exprStr;
} else if (schema == Schema.FLOAT64_SCHEMA) {
return "(new Double(" + exprStr + ").longValue())";
} else {
throw new KsqlFunctionException("Invalid cast operation: Cannot cast "
+ exprStr + " to Long.");
}
}

private String getCastToDoubleString(Schema schema, String exprStr) {
if (schema == Schema.STRING_SCHEMA) {
return "Double.parseDouble(" + exprStr + ")";
} else if (schema == Schema.INT32_SCHEMA) {
return "(new Integer(" + exprStr + ").doubleValue())";
} else if (schema == Schema.INT64_SCHEMA) {
return "(new Long(" + exprStr + ").doubleValue())";
} else if (schema == Schema.FLOAT64_SCHEMA) {
return exprStr;
} else {
throw new KsqlFunctionException("Invalid cast operation: Cannot cast "
+ exprStr + " to Double.");
default:
throw new KsqlFunctionException("Invalid cast operation: Cannot cast "
+ exprStr + " to " + schema.type() + ".");
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -178,17 +178,19 @@ protected Expression visitFunctionCall(final FunctionCall node,

private Schema resolveArithmaticType(final Schema leftSchema,
final Schema rightSchema) {
if (leftSchema == rightSchema) {
Schema.Type leftType = leftSchema.type();
Schema.Type rightType = rightSchema.type();

if (leftType == rightType) {
return leftSchema;
} else if ((leftSchema == Schema.STRING_SCHEMA) || (rightSchema == Schema.STRING_SCHEMA)) {
throw new PlanException("Incompatible types.");
} else if ((leftSchema == Schema.BOOLEAN_SCHEMA) || (rightSchema == Schema.BOOLEAN_SCHEMA)) {
} else if (((leftType == Schema.Type.STRING) || (rightType == Schema.Type.STRING))
|| ((leftType == Schema.Type.BOOLEAN) || (rightType == Schema.Type.BOOLEAN))) {
throw new PlanException("Incompatible types.");
} else if ((leftSchema == Schema.FLOAT64_SCHEMA) || (rightSchema == Schema.FLOAT64_SCHEMA)) {
} else if ((leftType == Schema.Type.FLOAT64) || (rightType == Schema.Type.FLOAT64)) {
return Schema.FLOAT64_SCHEMA;
} else if ((leftSchema == Schema.INT64_SCHEMA) || (rightSchema == Schema.INT64_SCHEMA)) {
} else if ((leftType == Schema.Type.INT64) || (rightType == Schema.Type.INT64)) {
return Schema.INT64_SCHEMA;
} else if ((leftSchema == Schema.INT32_SCHEMA) || (rightSchema == Schema.INT32_SCHEMA)) {
} else if ((leftType == Schema.Type.INT32) || (rightType == Schema.Type.INT32)) {
return Schema.INT32_SCHEMA;
}
throw new PlanException("Unsupported types.");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,22 +37,24 @@ public Object enforceFieldType(final int index, final Object value) {
}

private Object enforceFieldType(Schema schema, final Object value) {
if (schema == Schema.FLOAT64_SCHEMA) {
return enforceDouble(value);
} else if (schema == Schema.INT64_SCHEMA) {
return enforceLong(value);
} else if (schema == Schema.INT32_SCHEMA) {
return enforceInteger(value);
} else if (schema == Schema.STRING_SCHEMA) {
return enforceString(value);
} else if (schema == Schema.BOOLEAN_SCHEMA) {
return enforceBoolean(value);
} else if (schema.type() == Schema.Type.ARRAY) {
return value;
} else if (schema.type() == Schema.Type.MAP) {
return value;
} else {
throw new KsqlException("Type is not supported: " + schema);

switch (schema.type()) {

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This could also be pushed down to an enum on Schema, e.g. Schema.INT32.enforceInteger() etc. - then you don't need this class.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unfortunately no, we are using Schema class from connect so we cannot change it!

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@hjafarpour - I still don't see more OO modelling in the latest commit...

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oops, I forgot to submit my comments yesterday. @bluemonk3y as I mentioned above, we use the Schema class from the Connect package, so we don't have control over the class implementation.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can create our own class to do this - i.e., one that wraps the Schema

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1 - the code we have in here needs a serious refactor

case INT32:
return enforceInteger(value);
case INT64:
return enforceLong(value);
case FLOAT64:
return enforceDouble(value);
case STRING:
return enforceString(value);
case BOOLEAN:
return enforceBoolean(value);
case ARRAY:
return value;
case MAP:
return value;
default:
throw new KsqlException("Type is not supported: " + schema);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@

import java.util.List;

import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.CoreMatchers.instanceOf;
import static org.hamcrest.MatcherAssert.assertThat;

public class SqlToJavaVisitorTest {

private static final KsqlParser KSQL_PARSER = new KsqlParser();
Expand Down Expand Up @@ -56,8 +60,30 @@ public void processBasicJavaMath() throws Exception {
String javaExpression = new SqlToJavaVisitor(schema, functionRegistry)
.process(analysis.getSelectExpressions().get(0));

Assert.assertEquals("(TEST1_COL0 + TEST1_COL3)", javaExpression);
assertThat(javaExpression, equalTo("(TEST1_COL0 + TEST1_COL3)"));

}

@Test
public void shouldCreateCorrectCastJavaExpression() throws Exception {


String simpleQuery = "SELECT cast(col0 AS INTEGER), cast(col3 as BIGINT), cast(col3 as "
+ "varchar) FROM "
+ "test1 WHERE "
+ "col0 > 100;";
Analysis analysis = analyzeQuery(simpleQuery);

String javaExpression0 = new SqlToJavaVisitor(schema, functionRegistry)
.process(analysis.getSelectExpressions().get(0));
String javaExpression1 = new SqlToJavaVisitor(schema, functionRegistry)
.process(analysis.getSelectExpressions().get(1));
String javaExpression2 = new SqlToJavaVisitor(schema, functionRegistry)
.process(analysis.getSelectExpressions().get(2));

assertThat(javaExpression0, equalTo("(new Long(TEST1_COL0).intValue())"));
assertThat(javaExpression1, equalTo("(new Double(TEST1_COL3).longValue())"));
assertThat(javaExpression2, equalTo("String.valueOf(TEST1_COL3)"));
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
import java.util.List;

import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.CoreMatchers.instanceOf;
import static org.hamcrest.MatcherAssert.assertThat;

public class LogicalPlannerTest {
Expand Down Expand Up @@ -97,30 +98,29 @@ public void testSimpleQueryLogicalPlan() throws Exception {
PlanNode logicalPlan = buildLogicalPlan(simpleQuery);

// Assert.assertTrue(logicalPlan instanceof OutputKafkaTopicNode);
Assert.assertTrue(logicalPlan.getSources().get(0) instanceof ProjectNode);
Assert.assertTrue(logicalPlan.getSources().get(0).getSources().get(0) instanceof FilterNode);
Assert.assertTrue(logicalPlan.getSources().get(0).getSources().get(0).getSources()
.get(0) instanceof StructuredDataSourceNode);

Assert.assertTrue(logicalPlan.getSchema().fields().size() == 3);
Assert.assertNotNull(
((FilterNode) logicalPlan.getSources().get(0).getSources().get(0)).getPredicate());
assertThat(logicalPlan.getSources().get(0), instanceOf(ProjectNode.class));
assertThat(logicalPlan.getSources().get(0).getSources().get(0), instanceOf(FilterNode.class));
assertThat(logicalPlan.getSources().get(0).getSources().get(0).getSources().get(0),
instanceOf(StructuredDataSourceNode.class));

assertThat(logicalPlan.getSchema().fields().size(), equalTo( 3));
Assert.assertNotNull(((FilterNode) logicalPlan.getSources().get(0).getSources().get(0)).getPredicate());
}

@Test
public void testSimpleLeftJoinLogicalPlan() throws Exception {
String simpleQuery = "SELECT t1.col1, t2.col1, t1.col4, t2.col2 FROM test1 t1 LEFT JOIN test2 t2 ON t1.col1 = t2.col1;";
PlanNode logicalPlan = buildLogicalPlan(simpleQuery);

// Assert.assertTrue(logicalPlan instanceof OutputKafkaTopicNode);
Assert.assertTrue(logicalPlan.getSources().get(0) instanceof ProjectNode);
Assert.assertTrue(logicalPlan.getSources().get(0).getSources().get(0) instanceof JoinNode);
Assert.assertTrue(logicalPlan.getSources().get(0).getSources().get(0).getSources()
.get(0) instanceof StructuredDataSourceNode);
Assert.assertTrue(logicalPlan.getSources().get(0).getSources().get(0).getSources()
.get(1) instanceof StructuredDataSourceNode);
// assertThat(logicalPlan instanceof OutputKafkaTopicNode);
assertThat(logicalPlan.getSources().get(0), instanceOf(ProjectNode.class));
assertThat(logicalPlan.getSources().get(0).getSources().get(0), instanceOf(JoinNode.class));
assertThat(logicalPlan.getSources().get(0).getSources().get(0).getSources()
.get(0), instanceOf(StructuredDataSourceNode.class));
assertThat(logicalPlan.getSources().get(0).getSources().get(0).getSources()
.get(1), instanceOf(StructuredDataSourceNode.class));

Assert.assertTrue(logicalPlan.getSchema().fields().size() == 4);
assertThat(logicalPlan.getSchema().fields().size(), equalTo(4));

}

Expand All @@ -132,22 +132,20 @@ public void testSimpleLeftJoinFilterLogicalPlan() throws Exception {
+ "t1.col1 = t2.col1 WHERE t1.col1 > 10 AND t2.col4 = 10.8;";
PlanNode logicalPlan = buildLogicalPlan(simpleQuery);

// Assert.assertTrue(logicalPlan instanceof OutputKafkaTopicNode);
Assert.assertTrue(logicalPlan.getSources().get(0) instanceof ProjectNode);
assertThat(logicalPlan.getSources().get(0), instanceOf(ProjectNode.class));
ProjectNode projectNode = (ProjectNode) logicalPlan.getSources().get(0);

Assert.assertTrue(projectNode.getKeyField().name().equalsIgnoreCase("t1.col1"));
Assert.assertTrue(projectNode.getSchema().fields().size() == 5);
assertThat(projectNode.getKeyField().name(), equalTo("t1.col1".toUpperCase()));
assertThat(projectNode.getSchema().fields().size(), equalTo(5));

Assert.assertTrue(projectNode.getSources().get(0) instanceof FilterNode);
assertThat(projectNode.getSources().get(0), instanceOf(FilterNode.class));
FilterNode filterNode = (FilterNode) projectNode.getSources().get(0);
Assert.assertTrue(filterNode.getPredicate().toString()
.equalsIgnoreCase("((T1.COL1 > 10) AND (T2.COL4 = 10.8))"));
assertThat(filterNode.getPredicate().toString(), equalTo("((T1.COL1 > 10) AND (T2.COL4 = 10.8))"));

Assert.assertTrue(filterNode.getSources().get(0) instanceof JoinNode);
assertThat(filterNode.getSources().get(0), instanceOf(JoinNode.class));
JoinNode joinNode = (JoinNode) filterNode.getSources().get(0);
Assert.assertTrue(joinNode.getSources().get(0) instanceof StructuredDataSourceNode);
Assert.assertTrue(joinNode.getSources().get(1) instanceof StructuredDataSourceNode);
assertThat(joinNode.getSources().get(0), instanceOf(StructuredDataSourceNode.class));
assertThat(joinNode.getSources().get(1), instanceOf(StructuredDataSourceNode.class));

}

Expand All @@ -159,18 +157,17 @@ public void testSimpleAggregateLogicalPlan() throws Exception {

PlanNode logicalPlan = buildLogicalPlan(simpleQuery);

Assert.assertTrue(logicalPlan.getSources().get(0) instanceof AggregateNode);
assertThat(logicalPlan.getSources().get(0), instanceOf(AggregateNode.class));
AggregateNode aggregateNode = (AggregateNode) logicalPlan.getSources().get(0);
Assert.assertTrue(aggregateNode.getFunctionList().size() == 2);
Assert.assertTrue(aggregateNode.getFunctionList().get(0).getName().getSuffix()
.equalsIgnoreCase("sum"));
Assert.assertTrue(aggregateNode.getWindowExpression().getKsqlWindowExpression().toString().equalsIgnoreCase(" TUMBLING ( SIZE 2 SECONDS ) "));
Assert.assertTrue(aggregateNode.getGroupByExpressions().size() == 1);
Assert.assertTrue(aggregateNode.getGroupByExpressions().get(0).toString().equalsIgnoreCase("TEST1.COL0"));
Assert.assertTrue(aggregateNode.getRequiredColumnList().size() == 2);
Assert.assertTrue(aggregateNode.getSchema().fields().get(1).schema() == Schema.FLOAT64_SCHEMA);
Assert.assertTrue(aggregateNode.getSchema().fields().get(2).schema() == Schema.INT64_SCHEMA);
Assert.assertTrue(logicalPlan.getSources().get(0).getSchema().fields().size() == 3);
assertThat(aggregateNode.getFunctionList().size(), equalTo(2));
assertThat(aggregateNode.getFunctionList().get(0).getName().getSuffix(), equalTo("SUM"));
assertThat(aggregateNode.getWindowExpression().getKsqlWindowExpression().toString(), equalTo(" TUMBLING ( SIZE 2 SECONDS ) "));
assertThat(aggregateNode.getGroupByExpressions().size(), equalTo(1));
assertThat(aggregateNode.getGroupByExpressions().get(0).toString(), equalTo("TEST1.COL0"));
assertThat(aggregateNode.getRequiredColumnList().size(), equalTo(2));
assertThat(aggregateNode.getSchema().fields().get(1).schema().type(), equalTo(Schema.Type.FLOAT64));
assertThat(aggregateNode.getSchema().fields().get(2).schema().type(), equalTo(Schema.Type.INT64));
assertThat(logicalPlan.getSources().get(0).getSchema().fields().size(), equalTo(3));

}

Expand All @@ -182,17 +179,16 @@ public void testComplexAggregateLogicalPlan() throws Exception {

PlanNode logicalPlan = buildLogicalPlan(simpleQuery);

Assert.assertTrue(logicalPlan.getSources().get(0) instanceof AggregateNode);
assertThat(logicalPlan.getSources().get(0), instanceOf(AggregateNode.class));
AggregateNode aggregateNode = (AggregateNode) logicalPlan.getSources().get(0);
Assert.assertTrue(aggregateNode.getFunctionList().size() == 2);
Assert.assertTrue(aggregateNode.getFunctionList().get(0).getName().getSuffix()
.equalsIgnoreCase("sum"));
Assert.assertTrue(aggregateNode.getWindowExpression().getKsqlWindowExpression().toString().equalsIgnoreCase(" HOPPING ( SIZE 2 SECONDS , ADVANCE BY 1 SECONDS ) "));
Assert.assertTrue(aggregateNode.getGroupByExpressions().size() == 1);
Assert.assertTrue(aggregateNode.getGroupByExpressions().get(0).toString().equalsIgnoreCase("TEST1.COL0"));
Assert.assertTrue(aggregateNode.getRequiredColumnList().size() == 2);
Assert.assertTrue(aggregateNode.getSchema().fields().get(1).schema() == Schema.FLOAT64_SCHEMA);
Assert.assertTrue(logicalPlan.getSources().get(0).getSchema().fields().size() == 2);
assertThat(aggregateNode.getFunctionList().size(), equalTo(2));
assertThat(aggregateNode.getFunctionList().get(0).getName().getSuffix(), equalTo("SUM"));
assertThat(aggregateNode.getWindowExpression().getKsqlWindowExpression().toString(), equalTo(" HOPPING ( SIZE 2 SECONDS , ADVANCE BY 1 SECONDS ) "));
assertThat(aggregateNode.getGroupByExpressions().size(), equalTo(1));
assertThat(aggregateNode.getGroupByExpressions().get(0).toString(), equalTo("TEST1.COL0"));
assertThat(aggregateNode.getRequiredColumnList().size(), equalTo(2));
assertThat(aggregateNode.getSchema().fields().get(1).schema().type(), equalTo(Schema.Type.FLOAT64));
assertThat(logicalPlan.getSources().get(0).getSchema().fields().size(), equalTo(2));

}
}
Loading