Skip to content

Commit

Permalink
fix: changes required for compatibility with KIP-479
Browse files Browse the repository at this point in the history
  • Loading branch information
bbejeck committed Oct 4, 2019
1 parent 47313ff commit 530ae49
Show file tree
Hide file tree
Showing 5 changed files with 107 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
import io.confluent.ksql.execution.expression.tree.Expression;
import io.confluent.ksql.execution.expression.tree.FunctionCall;
import io.confluent.ksql.execution.expression.tree.LongLiteral;
import io.confluent.ksql.execution.streams.StreamJoinedFactory;
import io.confluent.ksql.name.ColumnName;
import io.confluent.ksql.name.FunctionName;
import io.confluent.ksql.name.SourceName;
Expand Down Expand Up @@ -120,6 +121,8 @@
import org.apache.kafka.streams.kstream.KTable;
import org.apache.kafka.streams.kstream.KeyValueMapper;
import org.apache.kafka.streams.kstream.Predicate;
import org.apache.kafka.streams.kstream.StreamJoined;
import org.apache.kafka.streams.kstream.ValueMapper;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
Expand All @@ -144,6 +147,8 @@ public class SchemaKStreamTest {
"group", Serdes.String(), Serdes.String());
private final Joined joined = Joined.with(
Serdes.String(), Serdes.String(), Serdes.String(), "join");
private final StreamJoined streamJoined = StreamJoined.with(
Serdes.String(), Serdes.String(), Serdes.String()).withName("join");
private final KeyField validJoinKeyField = KeyField.of(
Optional.of(ColumnName.of("left.COL1")),
metaStore.getSource(SourceName.of("TEST1"))
Expand Down Expand Up @@ -180,6 +185,8 @@ public class SchemaKStreamTest {
@Mock
private MaterializedFactory mockMaterializedFactory;
@Mock
private StreamJoinedFactory mockStreamJoinedFactory;
@Mock
private KStream mockKStream;
@Mock
private KeySerde keySerde;
Expand Down Expand Up @@ -207,6 +214,7 @@ public void init() {
when(mockGroupedFactory.create(anyString(), any(Serde.class), any(Serde.class)))
.thenReturn(grouped);
when(mockJoinedFactory.create(any(), any(), any(), anyString())).thenReturn(joined);
when(mockStreamJoinedFactory.create(any(), any(), any(), anyString(), anyString())).thenReturn(streamJoined);

final KsqlStream secondKsqlStream = (KsqlStream) metaStore.getSource(SourceName.of("ORDERS"));
secondKStream = builder
Expand Down Expand Up @@ -871,7 +879,7 @@ public void shouldPerformStreamToStreamLeftJoin() {
any(KStream.class),
any(KsqlValueJoiner.class),
any(JoinWindows.class),
any(Joined.class))
any(StreamJoined.class))
).thenReturn(mockKStream);
when(queryBuilder.buildValueSerde(any(), any(), any()))
.thenReturn(leftSerde)
Expand All @@ -892,12 +900,12 @@ public void shouldPerformStreamToStreamLeftJoin() {

// Then:
joinedKStream.getSourceStep().build(planBuilder);
verifyCreateJoined(rightSerde);
verifyCreateStreamJoined(rightSerde);
verify(mockKStream).leftJoin(
eq(secondKStream),
any(KsqlValueJoiner.class),
eq(joinWindow),
same(joined)
same(streamJoined)
);
assertThat(joinedKStream, instanceOf(SchemaKStream.class));
assertEquals(SchemaKStream.Type.JOIN, joinedKStream.type);
Expand Down Expand Up @@ -975,7 +983,7 @@ public void shouldPerformStreamToStreamInnerJoin() {
any(KStream.class),
any(KsqlValueJoiner.class),
any(JoinWindows.class),
any(Joined.class))
any(StreamJoined.class))
).thenReturn(mockKStream);
when(queryBuilder.buildValueSerde(any(), any(), any()))
.thenReturn(leftSerde)
Expand All @@ -996,12 +1004,12 @@ public void shouldPerformStreamToStreamInnerJoin() {

// Then:
joinedKStream.getSourceStep().build(planBuilder);
verifyCreateJoined(rightSerde);
verifyCreateStreamJoined(rightSerde);
verify(mockKStream).join(
eq(secondKStream),
any(KsqlValueJoiner.class),
eq(joinWindow),
same(joined)
same(streamJoined)
);

assertThat(joinedKStream, instanceOf(SchemaKStream.class));
Expand All @@ -1023,7 +1031,7 @@ public void shouldPerformStreamToStreamOuterJoin() {
any(KStream.class),
any(KsqlValueJoiner.class),
any(JoinWindows.class),
any(Joined.class))
any(StreamJoined.class))
).thenReturn(mockKStream);
when(queryBuilder.buildValueSerde(any(), any(), any()))
.thenReturn(leftSerde)
Expand All @@ -1044,12 +1052,12 @@ public void shouldPerformStreamToStreamOuterJoin() {

// Then:
joinedKStream.getSourceStep().build(planBuilder);
verifyCreateJoined(rightSerde);
verifyCreateStreamJoined(rightSerde);
verify(mockKStream).outerJoin(
eq(secondKStream),
any(KsqlValueJoiner.class),
eq(joinWindow),
same(joined)
same(streamJoined)
);
assertThat(joinedKStream, instanceOf(SchemaKStream.class));
assertEquals(SchemaKStream.Type.JOIN, joinedKStream.type);
Expand Down Expand Up @@ -1298,6 +1306,16 @@ private void givenSourcePropertiesWithSchema(
);
}

private void verifyCreateStreamJoined(final Serde<GenericRow> rightSerde) {
verify(mockStreamJoinedFactory).create(
same(keySerde),
same(leftSerde),
same(rightSerde),
eq(StreamsUtil.buildOpName(childContextStacker.getQueryContext())),
eq(StreamsUtil.buildOpName(childContextStacker.getQueryContext()))
);
}

private SchemaKStream buildSchemaKStream(
final LogicalSchema schema,
final KeyField keyField,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
/*
* Copyright 2018 Confluent Inc.
*
* Licensed under the Confluent Community License (the "License"); you may not use
* this file except in compliance with the License. You may obtain a copy of the
* License at
*
* http://www.confluent.io/confluent-community-license
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OF ANY KIND, either express or implied. See the License for the
* specific language governing permissions and limitations under the License.
*/

package io.confluent.ksql.execution.streams;

import io.confluent.ksql.util.KsqlConfig;
import org.apache.kafka.common.serialization.Serde;
import org.apache.kafka.streams.kstream.StreamJoined;

public interface StreamJoinedFactory {
<K, V, V0> StreamJoined<K, V, V0> create(
Serde<K> keySerde,
Serde<V> leftSerde,
Serde<V0> rightSerde,
String name,
String storeName);


static StreamJoinedFactory create(final KsqlConfig ksqlConfig) {
if (StreamsUtil.useProvidedName(ksqlConfig)) {
return new StreamJoinedFactory() {
@Override
public <K, V, V0> StreamJoined<K, V, V0> create(
final Serde<K> keySerde,
final Serde<V> leftSerde,
final Serde<V0> rightSerde,
final String name,
final String storeName) {
return StreamJoined.with(keySerde, leftSerde, rightSerde)
.withName(name).withStoreName(storeName);
}
};
}
return new StreamJoinedFactory() {
@Override
public <K, V, V0> StreamJoined<K, V, V0> create(
final Serde<K> keySerde,
final Serde<V> leftSerde,
final Serde<V0> rightSerde,
final String name,
final String storeName) {
return StreamJoined.with(keySerde, leftSerde, rightSerde);
}
};
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@
import io.confluent.ksql.serde.KeySerde;
import org.apache.kafka.common.serialization.Serde;
import org.apache.kafka.streams.kstream.JoinWindows;
import org.apache.kafka.streams.kstream.Joined;
import org.apache.kafka.streams.kstream.KStream;
import org.apache.kafka.streams.kstream.StreamJoined;

public final class StreamStreamJoinBuilder {
private static final String LEFT_SERDE_CTX = "left";
Expand All @@ -41,7 +41,7 @@ public static <K> KStreamHolder<K> build(
final KStreamHolder<K> right,
final StreamStreamJoin<K> join,
final KsqlQueryBuilder queryBuilder,
final JoinedFactory joinedFactory) {
final StreamJoinedFactory streamJoinedFactory) {
final Formats leftFormats = join.getLeftFormats();
final QueryContext queryContext = join.getProperties().getQueryContext();
final QueryContext.Stacker stacker = QueryContext.Stacker.of(queryContext);
Expand Down Expand Up @@ -71,10 +71,11 @@ public static <K> KStreamHolder<K> build(
leftPhysicalSchema,
queryContext
);
final Joined<K, GenericRow, GenericRow> joined = joinedFactory.create(
final StreamJoined<K, GenericRow, GenericRow> joined = streamJoinedFactory.create(
keySerde,
leftSerde,
rightSerde,
StreamsUtil.buildOpName(queryContext),
StreamsUtil.buildOpName(queryContext)
);
final KsqlValueJoiner joiner = new KsqlValueJoiner(leftSchema, rightSchema);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,23 +22,27 @@ public class StreamsFactories {
private final GroupedFactory groupedFactory;
private final JoinedFactory joinedFactory;
private final MaterializedFactory materializedFactory;
private final StreamJoinedFactory streamJoinedFactory;

public static StreamsFactories create(final KsqlConfig ksqlConfig) {
Objects.requireNonNull(ksqlConfig);
return new StreamsFactories(
GroupedFactory.create(ksqlConfig),
JoinedFactory.create(ksqlConfig),
MaterializedFactory.create(ksqlConfig)
MaterializedFactory.create(ksqlConfig),
StreamJoinedFactory.create(ksqlConfig)
);
}

public StreamsFactories(
final GroupedFactory groupedFactory,
final JoinedFactory joinedFactory,
final MaterializedFactory materializedFactory) {
final MaterializedFactory materializedFactory,
final StreamJoinedFactory streamJoinedFactory) {
this.groupedFactory = Objects.requireNonNull(groupedFactory);
this.joinedFactory = Objects.requireNonNull(joinedFactory);
this.materializedFactory = Objects.requireNonNull(materializedFactory);
this.streamJoinedFactory = Objects.requireNonNull(streamJoinedFactory);
}

public GroupedFactory getGroupedFactory() {
Expand All @@ -52,4 +56,8 @@ public JoinedFactory getJoinedFactory() {
public MaterializedFactory getMaterializedFactory() {
return materializedFactory;
}

public StreamJoinedFactory getStreamJoinedFactory() {
return streamJoinedFactory;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@
import org.apache.kafka.common.serialization.Serde;
import org.apache.kafka.connect.data.Struct;
import org.apache.kafka.streams.kstream.JoinWindows;
import org.apache.kafka.streams.kstream.Joined;
import org.apache.kafka.streams.kstream.KStream;
import org.apache.kafka.streams.kstream.StreamJoined;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
Expand Down Expand Up @@ -103,9 +103,9 @@ public class StreamStreamJoinBuilderTest {
@Mock
private ExecutionStep<KStreamHolder<Struct>> right;
@Mock
private Joined<Struct, GenericRow, GenericRow> joined;
private StreamJoined<Struct, GenericRow, GenericRow> joined;
@Mock
private JoinedFactory joinedFactory;
private StreamJoinedFactory streamJoinedFactory;
@Mock
private KsqlQueryBuilder queryBuilder;
@Mock
Expand All @@ -130,7 +130,7 @@ public void init() {
.thenReturn(leftSerde);
when(queryBuilder.buildValueSerde(eq(FormatInfo.of(Format.AVRO)), any(), any()))
.thenReturn(rightSerde);
when(joinedFactory.create(any(Serde.class), any(), any(), any())).thenReturn(joined);
when(streamJoinedFactory.create(any(Serde.class), any(), any(), any())).thenReturn(joined);
when(left.build(any())).thenReturn(
new KStreamHolder<>(leftKStream, keySerdeFactory));
when(right.build(any())).thenReturn(
Expand All @@ -141,7 +141,7 @@ public void init() {
mock(AggregateParams.Factory.class),
new StreamsFactories(
mock(GroupedFactory.class),
joinedFactory,
streamJoinedFactory,
mock(MaterializedFactory.class)
)
);
Expand All @@ -164,7 +164,7 @@ private void givenLeftJoin() {

@SuppressWarnings("unchecked")
private void givenOuterJoin() {
when(leftKStream.outerJoin(any(KStream.class), any(), any(), any())).thenReturn(resultKStream);
when(leftKStream.outerJoin(any(KStream.class), any(), any(), any(StreamJoined.class))).thenReturn(resultKStream);
join = new StreamStreamJoin<>(
new DefaultExecutionStepProperties(SCHEMA, CTX),
JoinType.OUTER,
Expand All @@ -179,7 +179,7 @@ private void givenOuterJoin() {

@SuppressWarnings("unchecked")
private void givenInnerJoin() {
when(leftKStream.join(any(KStream.class), any(), any(), any())).thenReturn(resultKStream);
when(leftKStream.join(any(KStream.class), any(), any(), any(StreamJoined.class))).thenReturn(resultKStream);
join = new StreamStreamJoin<>(
new DefaultExecutionStepProperties(SCHEMA, CTX),
JoinType.INNER,
Expand Down Expand Up @@ -261,7 +261,7 @@ public void shouldBuildJoinedCorrectly() {
join.build(planBuilder);

// Then:
verify(joinedFactory).create(keySerde, leftSerde, rightSerde, "jo-in");
verify(streamJoinedFactory).create(keySerde, leftSerde, rightSerde, "jo-in", "jo-in");
}

@Test
Expand Down

0 comments on commit 530ae49

Please sign in to comment.