chore: enforce WITH KEY column type matches ROWKEY type #4147

Merged
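The gist of the change: a column named in WITH(KEY='...') must now have the same SQL type as the declared ROWKEY, and join key types are derived from the join column rather than assumed to be STRING. A hedged illustration of the WITH KEY rule — the first statement is lifted from the updated tests below; the second (T2) is hypothetical, and the exact rejection message is not shown in this diff:

CREATE TABLE T1 (ROWKEY BIGINT KEY, f0 BIGINT, f1 DOUBLE)
  WITH (kafka_topic='t1_topic', value_format='JSON', key='f0');
-- Accepted: f0 is BIGINT, matching the BIGINT ROWKEY.

CREATE TABLE T2 (ROWKEY STRING KEY, f0 BIGINT, f1 DOUBLE)
  WITH (kafka_topic='t2_topic', value_format='JSON', key='f0');
-- Now rejected: f0 is BIGINT but ROWKEY is STRING. This is also why several tests
-- below drop key='ordertime' (ORDERTIME is BIGINT while those streams keep the
-- default STRING ROWKEY).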
===================================================================

@@ -350,9 +350,9 @@ private static void handleExplicitKeyField(
       if (keyValue == null) {
         values.put(key.name(), rowKeyValue);
       } else {
-        values.put(SchemaUtil.ROWKEY_NAME, keyValue.toString());
+        values.put(SchemaUtil.ROWKEY_NAME, keyValue);
       }
-    } else if (keyValue != null && !Objects.equals(keyValue.toString(), rowKeyValue)) {
+    } else if (keyValue != null && !Objects.equals(keyValue, rowKeyValue)) {
       throw new KsqlException(String.format(
           "Expected ROWKEY and %s to match but got %s and %s respectively.",
           key.toString(FormatOptions.noEscape()), rowKeyValue, keyValue));
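The hunk above changes INSERT ... VALUES handling to compare and copy the key as a typed value instead of going through toString(). A hedged sketch of the behaviour the code implies (stream S is hypothetical; the error text is paraphrased from the format string above):

CREATE STREAM S (ROWKEY BIGINT KEY, ID BIGINT, V VARCHAR)
  WITH (kafka_topic='s_topic', value_format='JSON', key='ID');

INSERT INTO S (ID, V) VALUES (10, 'x');
-- ROWKEY is populated from ID as the BIGINT 10, not the string "10".

INSERT INTO S (ROWKEY, ID, V) VALUES (1, 2, 'x');
-- Throws: "Expected ROWKEY and ID to match but got 1 and 2 respectively."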
===================================================================

@@ -95,7 +95,7 @@ public DataSource<?> getDataSource() {
     return dataSource;
   }
 
-  SourceName getAlias() {
+  public SourceName getAlias() {
    return alias;
   }
 
===================================================================
ksql-engine/src/main/java/io/confluent/ksql/planner/plan/JoinNode.java (132 changes: 83 additions & 49 deletions)

@@ -25,13 +25,12 @@
 import io.confluent.ksql.execution.streams.JoinParamsFactory;
 import io.confluent.ksql.metastore.model.DataSource.DataSourceType;
 import io.confluent.ksql.metastore.model.KeyField;
-import io.confluent.ksql.name.SourceName;
 import io.confluent.ksql.parser.tree.WithinExpression;
 import io.confluent.ksql.schema.ksql.Column;
+import io.confluent.ksql.schema.ksql.Column.Namespace;
 import io.confluent.ksql.schema.ksql.ColumnRef;
 import io.confluent.ksql.schema.ksql.FormatOptions;
 import io.confluent.ksql.schema.ksql.LogicalSchema;
-import io.confluent.ksql.schema.ksql.SqlBaseType;
 import io.confluent.ksql.serde.ValueFormat;
 import io.confluent.ksql.services.KafkaTopicClient;
 import io.confluent.ksql.structured.SchemaKStream;
@@ -93,11 +92,7 @@ public JoinNode(
         ? left.getKeyField()
         : KeyField.of(leftKeyCol.ref());
 
-    this.schema = JoinParamsFactory.createSchema(left.getSchema(), right.getSchema());
-
-    if (schema.key().get(0).type().baseType() != SqlBaseType.STRING) {
-      throw new KsqlException("JOIN is not supported with non-STRING keys");
-    }
+    this.schema = buildJoinSchema(left, leftJoinFieldName, right, rightJoinFieldName);
   }
 
   @Override
@@ -237,11 +232,7 @@ SchemaKStream<K> buildStream(
   }
 
   @SuppressWarnings("unchecked")
-  SchemaKTable<K> buildTable(
-      final PlanNode node,
-      final ColumnRef joinFieldName,
-      final SourceName tableName
-  ) {
+  SchemaKTable<K> buildTable(final PlanNode node) {
     final SchemaKStream<?> schemaKStream = node.buildStream(
         builder.withKsqlConfig(builder.getKsqlConfig()
             .cloneWithPropertyOverwrite(Collections.singletonMap(
@@ -252,37 +243,7 @@
       throw new RuntimeException("Expected to find a Table, found a stream instead.");
     }
 
-    final Optional<Column> keyColumn = schemaKStream
-        .getKeyField()
-        .resolve(schemaKStream.getSchema());
-
-    final ColumnRef rowKey = ColumnRef.of(
-        tableName,
-        SchemaUtil.ROWKEY_NAME
-    );
-
-    final boolean namesMatch = keyColumn
-        .map(field -> field.ref().equals(joinFieldName))
-        .orElse(false);
-
-    if (namesMatch || joinFieldName.equals(rowKey)) {
-      return (SchemaKTable) schemaKStream;
-    }
-
-    if (!keyColumn.isPresent()) {
-      throw new KsqlException(
-          "Source table (" + tableName.name() + ") has no key column defined. "
-              + "Only 'ROWKEY' is supported in the join criteria."
-      );
-    }
-
-    throw new KsqlException(
-        "Source table (" + tableName.toString(FormatOptions.noEscape()) + ") key column ("
-            + keyColumn.get().ref().toString(FormatOptions.noEscape()) + ") "
-            + "is not the column used in the join criteria ("
-            + joinFieldName.toString(FormatOptions.noEscape()) + "). "
-            + "Only the table's key column or 'ROWKEY' is supported in the join criteria."
-    );
+    return (SchemaKTable<K>) schemaKStream;
   }
 
   @SuppressWarnings("unchecked")
@@ -378,8 +339,7 @@ public SchemaKStream<K> join() {
             + " the WITHIN clause) and try to execute your join again.");
       }
 
-      final SchemaKTable<K> rightTable = buildTable(
-          joinNode.getRight(), joinNode.rightJoinFieldName, joinNode.right.getAlias());
+      final SchemaKTable<K> rightTable = buildTable(joinNode.getRight());
 
       final SchemaKStream<K> leftStream = buildStream(
           joinNode.getLeft(), joinNode.leftJoinFieldName);
@@ -428,10 +388,8 @@ public SchemaKTable<K> join() {
             + "join again.");
       }
 
-      final SchemaKTable<K> leftTable = buildTable(
-          joinNode.getLeft(), joinNode.leftJoinFieldName, joinNode.left.getAlias());
-      final SchemaKTable<K> rightTable = buildTable(
-          joinNode.getRight(), joinNode.rightJoinFieldName, joinNode.right.getAlias());
+      final SchemaKTable<K> leftTable = buildTable(joinNode.getLeft());
+      final SchemaKTable<K> rightTable = buildTable(joinNode.getRight());
 
       switch (joinNode.joinType) {
         case LEFT:
@@ -465,4 +423,80 @@ private static DataSourceType calculateSinkType(
         ? DataSourceType.KTABLE
         : DataSourceType.KSTREAM;
   }
+
+  private static LogicalSchema buildJoinSchema(
+      final DataSourceNode left,
+      final ColumnRef leftJoinFieldName,
+      final DataSourceNode right,
+      final ColumnRef rightJoinFieldName
+  ) {
+    final LogicalSchema leftSchema = selectKey(left, leftJoinFieldName);
+    final LogicalSchema rightSchema = selectKey(right, rightJoinFieldName);
+
+    return JoinParamsFactory.createSchema(leftSchema, rightSchema);
+  }
+
+  /**
+   * Adjust the schema to take into account any change in key columns.
+   *
+   * @param source the source node
+   * @param joinColumnRef the join column
+   * @return the true source schema after any change of key columns.
+   */
+  private static LogicalSchema selectKey(
+      final DataSourceNode source,
+      final ColumnRef joinColumnRef
+  ) {
+    final LogicalSchema sourceSchema = source.getSchema();
+
+    final Column joinCol = sourceSchema.findColumn(joinColumnRef)
+        .orElseThrow(() -> new KsqlException("Unknown join column: " + joinColumnRef));
+
+    if (sourceSchema.key().size() != 1) {
+      throw new UnsupportedOperationException("Only single key columns supported");
+    }
+
+    if (joinCol.namespace() == Namespace.KEY) {
+      // Join column is the only key column, so no change of key columns required:
+      return sourceSchema;
+    }
+
+    final Optional<Column> keyColumn = source
+        .getKeyField()
+        .resolve(sourceSchema);
+
+    if (keyColumn.isPresent() && keyColumn.get().equals(joinCol)) {
+      // Join column is the KEY field, which is an alias for the only key column, so no change
+      // of key columns required:
+      return sourceSchema;
+    }
+
+    // Change of key columns required:
+
+    if (source.getDataSourceType() == DataSourceType.KTABLE) {
+      // Tables do not support rekey:
+      final String sourceName = source.getDataSource().getName().toString(FormatOptions.noEscape());
+
+      if (!keyColumn.isPresent()) {
+        throw new KsqlException(
+            "Invalid join criteria: Source table (" + sourceName + ") has no key column "
+                + "defined. Only 'ROWKEY' is supported in the join criteria."
+        );
+      }
+
+      throw new KsqlException(
+          "Invalid join criteria: Source table "
+              + "(" + sourceName + ") key column "
+              + "(" + keyColumn.get().ref().toString(FormatOptions.noEscape()) + ") "
+              + "is not the column used in the join criteria ("
+              + joinCol.ref().toString(FormatOptions.noEscape()) + "). "
+              + "Only the table's key column or 'ROWKEY' is supported in the join criteria."
+      );
+    }
+
+    return LogicalSchema.builder()
+        .keyColumn(source.getAlias(), SchemaUtil.ROWKEY_NAME, joinCol.type())
+        .valueColumns(sourceSchema.value())
+        .build();
+  }
 }
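Net effect of buildJoinSchema/selectKey: a stream joined on a non-key column is repartitioned, and the rekeyed ROWKEY now takes the join column's actual type instead of being forced to STRING; tables still reject any join criteria other than their key column or ROWKEY. A hedged sketch using the TEST2 definition from the updated PhysicalPlanBuilderTest below:

CREATE STREAM TEST2 (ROWKEY BIGINT KEY, ID BIGINT, COL0 VARCHAR, COL1 BIGINT)
  WITH (KAFKA_TOPIC = 'test2', VALUE_FORMAT = 'JSON', KEY='ID');

-- Joining TEST2 on a non-key column such as COL1 forces a repartition, and the
-- execution plan now carries the BIGINT key type:
--   [ REKEY ] | Schema: [ROWKEY BIGINT KEY, TEST2. ...]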
===================================================================

@@ -21,7 +21,6 @@
 import io.confluent.ksql.name.ColumnName;
 import io.confluent.ksql.name.SourceName;
 import io.confluent.ksql.schema.ksql.LogicalSchema;
-import io.confluent.ksql.schema.ksql.types.SqlPrimitiveType;
 import io.confluent.ksql.schema.ksql.types.SqlTypes;
 import io.confluent.ksql.serde.Format;
 import io.confluent.ksql.serde.FormatInfo;
@@ -30,6 +29,7 @@
 import io.confluent.ksql.serde.ValueFormat;
 import io.confluent.ksql.serde.WindowInfo;
 import io.confluent.ksql.util.MetaStoreFixture;
+import io.confluent.ksql.util.SchemaUtil;
 import java.util.Optional;
 import java.util.Set;
 import org.hamcrest.MatcherAssert;
@@ -47,8 +47,9 @@ public class DdlCommandExecTest {
   private static final SourceName TABLE_NAME = SourceName.of("t1");
   private static final String TOPIC_NAME = "topic";
   private static final LogicalSchema SCHEMA = new LogicalSchema.Builder()
-      .valueColumn(ColumnName.of("F1"), SqlPrimitiveType.of("INTEGER"))
-      .valueColumn(ColumnName.of("F2"), SqlPrimitiveType.of("VARCHAR"))
+      .keyColumn(SchemaUtil.ROWKEY_NAME, SqlTypes.BIGINT)
+      .valueColumn(ColumnName.of("F1"), SqlTypes.BIGINT)
+      .valueColumn(ColumnName.of("F2"), SqlTypes.STRING)
       .build();
   private static final ValueFormat VALUE_FORMAT = ValueFormat.of(FormatInfo.of(Format.JSON));
   private static final KeyFormat KEY_FORMAT = KeyFormat.nonWindowed(FormatInfo.of(Format.KAFKA));
===================================================================

@@ -279,15 +279,15 @@ public void shouldThrowOnInsertIntoWithKeyMismatch() {
     expectedException.expect(rawMessage(containsString(
         "Incompatible key fields for sink and results. "
             + "Sink key field is ORDERTIME (type: BIGINT) "
-            + "while result key field is ITEMID (type: STRING)")));
+            + "while result key field is ORDERID (type: BIGINT)")));
     expectedException.expect(statementText(
-        is("insert into bar select * from orders partition by itemid;")));
+        is("insert into bar select * from orders partition by orderid;")));
 
     // When:
     KsqlEngineTestUtil.execute(
         serviceContext,
         ksqlEngine,
-        "insert into bar select * from orders partition by itemid;",
+        "insert into bar select * from orders partition by orderid;",
         KSQL_CONFIG,
         Collections.emptyMap()
     );
@@ -767,7 +767,7 @@ public void shouldHandleMultipleStatements() {
         + "CREATE STREAM S0 (a INT, b VARCHAR) "
         + " WITH (kafka_topic='s0_topic', value_format='DELIMITED');\n"
         + "\n"
-        + "CREATE TABLE T1 (f0 BIGINT, f1 DOUBLE) "
+        + "CREATE TABLE T1 (ROWKEY BIGINT KEY, f0 BIGINT, f1 DOUBLE) "
         + " WITH (kafka_topic='t1_topic', value_format='JSON', key = 'f0');\n"
         + "\n"
         + "CREATE STREAM S1 AS SELECT * FROM S0;\n"
===================================================================

@@ -144,7 +144,8 @@ private void produceInitData() throws Exception {
   }
 
   private void execInitCreateStreamQueries() {
-    final String ordersStreamStr = String.format("CREATE STREAM %s (ORDERTIME bigint, ORDERID varchar, "
+    final String ordersStreamStr = String.format("CREATE STREAM %s ("
+        + "ROWKEY BIGINT KEY, ORDERTIME bigint, ORDERID varchar, "
         + "ITEMID varchar, ORDERUNITS double, PRICEARRAY array<double>, KEYVALUEMAP "
         + "map<varchar, double>) WITH (value_format = 'json', "
         + "kafka_topic='%s' , "
===================================================================

@@ -320,7 +320,7 @@ private static Map<String, Object> getKsqlConfig(final Credentials user) {
     return configs;
   }
 
-  private void produceInitData() throws Exception {
+  private void produceInitData() {
     if (topicClient.isTopicExists(INPUT_TOPIC)) {
       return;
     }
@@ -340,11 +340,11 @@ private void awaitAsyncInputTopicCreation() {
   }
 
   private void execInitCreateStreamQueries() {
-    final String ordersStreamStr = String.format("CREATE STREAM %s (ORDERTIME bigint, ORDERID varchar, "
-        + "ITEMID varchar, ORDERUNITS double, PRICEARRAY array<double>, KEYVALUEMAP "
-        + "map<varchar, double>) WITH (value_format = 'json', "
-        + "kafka_topic='%s' , "
-        + "key='ordertime');", INPUT_STREAM, INPUT_TOPIC);
+    final String ordersStreamStr =
+        "CREATE STREAM " + INPUT_STREAM + " (ORDERTIME bigint, ORDERID varchar, "
+            + "ITEMID varchar, ORDERUNITS double, PRICEARRAY array<double>, KEYVALUEMAP "
+            + "map<varchar, double>) WITH (value_format = 'json', "
+            + "kafka_topic='" + INPUT_TOPIC + "');";
 
     KsqlEngineTestUtil.execute(
         serviceContext,
===================================================================

@@ -432,13 +432,13 @@ private void createOrdersStream() {
         + " KEYVALUEMAP map<varchar, double>";
 
     ksqlContext.sql("CREATE STREAM " + JSON_STREAM_NAME + " (" + columns + ") WITH "
-        + "(kafka_topic='" + jsonTopicName + "', value_format='JSON', key='ordertime');");
+        + "(kafka_topic='" + jsonTopicName + "', value_format='JSON');");
 
     ksqlContext.sql("CREATE STREAM " + AVRO_STREAM_NAME + " (" + columns + ") WITH "
-        + "(kafka_topic='" + avroTopicName + "', value_format='AVRO', key='ordertime');");
+        + "(kafka_topic='" + avroTopicName + "', value_format='AVRO');");
 
     ksqlContext.sql("CREATE STREAM " + AVRO_TIMESTAMP_STREAM_NAME + " (" + columns + ") WITH "
-        + "(kafka_topic='" + avroTopicName + "', value_format='AVRO', key='ordertime', "
+        + "(kafka_topic='" + avroTopicName + "', value_format='AVRO', "
         + "timestamp='timestamp', timestamp_format='yyyy-MM-dd');");
   }
 
===================================================================

@@ -269,9 +269,9 @@ private Set<String> getTopicNames() {
     return names;
   }
 
-  private static Deserializer getKeyDeserializerFor(final Object key) {
+  private static Deserializer<?> getKeyDeserializerFor(final Object key) {
     if (key instanceof Windowed) {
-      if (((Windowed) key).window() instanceof SessionWindow) {
+      if (((Windowed<?>) key).window() instanceof SessionWindow) {
         return SESSION_WINDOWED_DESERIALIZER;
       }
       return TIME_WINDOWED_DESERIALIZER;
@@ -288,6 +288,6 @@ private void createOrdersStream() {
         + "ORDERUNITS double, "
         + "PRICEARRAY array<double>, "
         + "KEYVALUEMAP map<varchar, double>) "
-        + "WITH (kafka_topic='" + sourceTopicName + "', value_format='JSON', key='ordertime');");
+        + "WITH (kafka_topic='" + sourceTopicName + "', value_format='JSON');");
   }
 }
===================================================================

@@ -68,19 +68,19 @@ public class PhysicalPlanBuilderTest {
       + "WITH (KAFKA_TOPIC = 'test1', VALUE_FORMAT = 'JSON');";
 
   private static final String CREATE_STREAM_TEST2 = "CREATE STREAM TEST2 "
-      + "(ID BIGINT, COL0 VARCHAR, COL1 DOUBLE) "
+      + "(ROWKEY BIGINT KEY, ID BIGINT, COL0 VARCHAR, COL1 BIGINT) "
       + " WITH (KAFKA_TOPIC = 'test2', VALUE_FORMAT = 'JSON', KEY='ID');";
 
   private static final String CREATE_STREAM_TEST3 = "CREATE STREAM TEST3 "
-      + "(ID BIGINT, COL0 VARCHAR, COL1 DOUBLE) "
+      + "(ROWKEY BIGINT KEY, ID BIGINT, COL0 BIGINT, COL1 DOUBLE) "
       + " WITH (KAFKA_TOPIC = 'test3', VALUE_FORMAT = 'JSON', KEY='ID');";
 
   private static final String CREATE_TABLE_TEST4 = "CREATE TABLE TEST4 "
-      + "(ID BIGINT, COL0 VARCHAR, COL1 DOUBLE) "
+      + "(ROWKEY BIGINT KEY, ID BIGINT, COL0 VARCHAR, COL1 DOUBLE) "
       + " WITH (KAFKA_TOPIC = 'test4', VALUE_FORMAT = 'JSON', KEY='ID');";
 
   private static final String CREATE_TABLE_TEST5 = "CREATE TABLE TEST5 "
-      + "(ID BIGINT, COL0 VARCHAR, COL1 DOUBLE) "
+      + "(ROWKEY BIGINT KEY, ID BIGINT, COL0 VARCHAR, COL1 DOUBLE) "
      + " WITH (KAFKA_TOPIC = 'test5', VALUE_FORMAT = 'JSON', KEY='ID');";
 
   private static final String CREATE_STREAM_TEST6 = "CREATE STREAM TEST6 "
@@ -316,7 +316,7 @@ public void shouldRepartitionLeftStreamIfNotCorrectKey() {
         .get(0);
 
     // Then:
-    assertThat(result.getExecutionPlan(), containsString("[ REKEY ] | Schema: [ROWKEY DOUBLE KEY, TEST2."));
+    assertThat(result.getExecutionPlan(), containsString("[ REKEY ] | Schema: [ROWKEY BIGINT KEY, TEST2."));
   }
 
   @Test
@@ -332,7 +332,7 @@ public void shouldRepartitionRightStreamIfNotCorrectKey() {
         .get(0);
 
     // Then:
-    assertThat(result.getExecutionPlan(), containsString("[ REKEY ] | Schema: [ROWKEY STRING KEY, TEST3."));
+    assertThat(result.getExecutionPlan(), containsString("[ REKEY ] | Schema: [ROWKEY BIGINT KEY, TEST3."));
   }
 
   @Test