feat: block inserting into sources with headers (#8417)
* feat: block inserting into sources with headers

* more detailed error messages

* unused import
Zara Lim authored Dec 2, 2021
1 parent 8aa55b1 commit 0239a95
Showing 6 changed files with 188 additions and 2 deletions.
LogicalSchema.java
@@ -301,6 +301,15 @@ public boolean isKeyColumn(final ColumnName columnName) {
.isPresent();
}

/**
* @param columnName the column name to check
* @return {@code true} if the column matches the name of any header column.
*/
public boolean isHeaderColumn(final ColumnName columnName) {
return findColumnMatching(withNamespace(HEADERS).and(withName(columnName)))
.isPresent();
}

/**
* Returns True if this schema is compatible with {@code other} schema.
*/
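A minimal usage sketch of the new predicate (the builder calls mirror the LogicalSchema API exercised in the test files of this commit, imports included; the column names here are illustrative, not from the diff):

final LogicalSchema schema = LogicalSchema.builder()
    .keyColumn(ColumnName.of("K0"), SqlTypes.STRING)
    .valueColumn(ColumnName.of("V0"), SqlTypes.BIGINT)
    .headerColumn(ColumnName.of("H0"), Optional.empty())
    .build();

schema.isHeaderColumn(ColumnName.of("H0")); // true: H0 is in the HEADERS namespace
schema.isHeaderColumn(ColumnName.of("K0")); // false: key column
schema.isHeaderColumn(ColumnName.of("V0")); // false: value column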
LogicalSchemaTest.java
@@ -847,6 +847,14 @@ public void shouldRemoveAllButKeyCols() {
));
}

@Test
public void shouldMatchHeaderColumnName() {
assertThat(SOME_SCHEMA.isHeaderColumn(H0), is(true));
assertThat(SOME_SCHEMA.isHeaderColumn(ROWPARTITION_NAME), is(false));
assertThat(SOME_SCHEMA.isHeaderColumn(K0), is(false));
assertThat(SOME_SCHEMA.isHeaderColumn(F0), is(false));
}

@Test
public void shouldMatchMetaColumnName() {
assertThat(SystemColumns.isPseudoColumn(ROWTIME_NAME, ROWTIME_PSEUDOCOLUMN_VERSION), is(true));
DistributingExecutor.java
@@ -163,7 +163,7 @@ public StatementExecutorResponse execute(
.inject(statement);

if (injected.getStatement() instanceof InsertInto) {
- throwIfInsertOnReadOnlyTopic(
+ validateInsertIntoQueries(
executionContext.getMetaStore(),
(InsertInto) injected.getStatement()
);
@@ -266,7 +266,7 @@ private void checkAuthorization(
}
}

- private void throwIfInsertOnReadOnlyTopic(
+ private void validateInsertIntoQueries(
final MetaStore metaStore,
final InsertInto insertInto
) {
@@ -280,5 +280,10 @@ private void throwIfInsertOnReadOnlyTopic(
throw new KsqlException("Cannot insert into read-only topic: "
+ dataSource.getKafkaTopicName());
}

if (!dataSource.getSchema().headers().isEmpty()) {
throw new KsqlException("Cannot insert into " + insertInto.getTarget().text()
+ " because it has header columns");
}
}
}
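The new guard keys off LogicalSchema#headers() being non-empty, so an INSERT INTO query is rejected for any source that declares header columns, regardless of which columns the query would write. A short sketch of the condition, given a schema with a single header column H0 as built in the earlier sketch (the stream shape follows the one in InsertValuesExecutor below):

final List<String> headerNames = schema.headers().stream()
    .map(column -> column.name().text())
    .collect(Collectors.toList()); // -> ["H0"]

// headerNames is non-empty, so the distributor throws:
// "Cannot insert into <target> because it has header columns"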
InsertValuesExecutor.java
@@ -29,6 +29,7 @@
import io.confluent.ksql.logging.processing.NoopProcessingLogContext;
import io.confluent.ksql.metastore.MetaStore;
import io.confluent.ksql.metastore.model.DataSource;
import io.confluent.ksql.name.ColumnName;
import io.confluent.ksql.parser.tree.InsertValues;
import io.confluent.ksql.rest.SessionProperties;
import io.confluent.ksql.schema.ksql.PersistenceSchema;
@@ -50,12 +51,14 @@
import io.confluent.ksql.util.KsqlStatementException;
import io.confluent.ksql.util.ReservedInternalTopics;
import java.time.Duration;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;
import java.util.function.LongSupplier;
import java.util.stream.Collectors;
import org.apache.commons.lang3.ObjectUtils;
import org.apache.commons.lang3.exception.ExceptionUtils;
import org.apache.hc.core5.http.HttpStatus;
@@ -144,6 +147,8 @@ public void execute(

final DataSource dataSource = getDataSource(config, metaStore, insertValues);

validateInsert(insertValues.getColumns(), dataSource);

final ProducerRecord<byte[], byte[]> record =
buildRecord(statement, metaStore, dataSource, serviceContext);

@@ -164,6 +169,25 @@
}
}

private void validateInsert(final List<ColumnName> columns, final DataSource dataSource) {
final List<String> headerColumns;
if (columns.isEmpty()) {
headerColumns = dataSource.getSchema().headers()
.stream()
.map(column -> column.name().text())
.collect(Collectors.toList());
} else {
headerColumns = columns.stream()
.filter(columnName -> dataSource.getSchema().isHeaderColumn(columnName))
.map(ColumnName::text)
.collect(Collectors.toList());
}
if (!headerColumns.isEmpty()) {
throw new KsqlException("Cannot insert into HEADER columns: "
+ String.join(", ", headerColumns));
}
}

private static DataSource getDataSource(
final KsqlConfig ksqlConfig,
final MetaStore metaStore,
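The validateInsert method above has two branches: an empty column list (a bare INSERT INTO ... VALUES) collects every header column in the schema, while an explicit column list collects only the header columns it names. A sketch of the explicit-list branch, with a schema in the style of SCHEMA_WITH_HEADERS from InsertValuesExecutorTest further down (the exact string literals behind K0/COL0/HEAD0 are assumptions):

final LogicalSchema schema = LogicalSchema.builder()
    .keyColumn(ColumnName.of("K0"), SqlTypes.STRING)
    .valueColumn(ColumnName.of("COL0"), SqlTypes.STRING)
    .headerColumn(ColumnName.of("HEAD0"), Optional.empty())
    .build();

final List<ColumnName> columns = ImmutableList.of(
    ColumnName.of("K0"), ColumnName.of("COL0"), ColumnName.of("HEAD0"));

final List<String> offenders = columns.stream()
    .filter(schema::isHeaderColumn) // keeps HEAD0 only
    .map(ColumnName::text)
    .collect(Collectors.toList());
// offenders -> ["HEAD0"], so the insert fails with
// "Cannot insert into HEADER columns: HEAD0"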
DistributingExecutorTest.java
@@ -32,13 +32,15 @@
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.confluent.ksql.KsqlExecutionContext;
import io.confluent.ksql.config.SessionConfig;
import io.confluent.ksql.exception.KsqlTopicAuthorizationException;
import io.confluent.ksql.execution.expression.tree.StringLiteral;
import io.confluent.ksql.metastore.MetaStore;
import io.confluent.ksql.metastore.model.DataSource;
import io.confluent.ksql.name.ColumnName;
import io.confluent.ksql.name.SourceName;
import io.confluent.ksql.parser.KsqlParser.PreparedStatement;
import io.confluent.ksql.parser.properties.with.CreateSourceProperties;
@@ -59,6 +61,7 @@
import io.confluent.ksql.rest.entity.CommandStatusEntity;
import io.confluent.ksql.rest.entity.WarningEntity;
import io.confluent.ksql.rest.server.execution.StatementExecutorResponse;
import io.confluent.ksql.schema.ksql.LogicalSchema;
import io.confluent.ksql.security.KsqlAuthorizationValidator;
import io.confluent.ksql.security.KsqlSecurityContext;
import io.confluent.ksql.services.SandboxedServiceContext;
@@ -408,6 +411,31 @@ public void shouldThrowExceptionWhenInsertIntoProcessingLogTopic() {
+ "default_ksql_processing_log"));
}

@Test
public void shouldThrowExceptionWhenInsertIntoSourceWithHeaders() {
// Given:
final PreparedStatement<Statement> preparedStatement =
PreparedStatement.of("", new InsertInto(SourceName.of("s1"), mock(Query.class)));
final ConfiguredStatement<Statement> configured =
ConfiguredStatement.of(preparedStatement, SessionConfig.of(KSQL_CONFIG, ImmutableMap.of())
);
final DataSource dataSource = mock(DataSource.class);
final LogicalSchema schema = mock(LogicalSchema.class);
doReturn(dataSource).when(metaStore).getSource(SourceName.of("s1"));
doReturn(schema).when(dataSource).getSchema();
doReturn(ImmutableList.of(ColumnName.of("a"))).when(schema).headers();
when(dataSource.getKafkaTopicName()).thenReturn("topic");

// When:
final Exception e = assertThrows(
KsqlException.class,
() -> distributor.execute(configured, executionContext, mock(KsqlSecurityContext.class))
);

// Then:
assertThat(e.getMessage(), is("Cannot insert into s1 because it has header columns"));
}

@Test
public void shouldAbortOnError_ProducerFencedException() {
// When:
InsertValuesExecutorTest.java
@@ -19,6 +19,7 @@
import static io.confluent.ksql.GenericRow.genericRow;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.is;
import static org.junit.Assert.assertThrows;
import static org.junit.internal.matchers.ThrowableMessageMatcher.hasMessage;
import static org.mockito.ArgumentMatchers.any;
@@ -47,6 +48,7 @@
import io.confluent.ksql.execution.expression.tree.FunctionCall;
import io.confluent.ksql.execution.expression.tree.IntegerLiteral;
import io.confluent.ksql.execution.expression.tree.LongLiteral;
import io.confluent.ksql.execution.expression.tree.NullLiteral;
import io.confluent.ksql.execution.expression.tree.StringLiteral;
import io.confluent.ksql.function.TestFunctionRegistry;
import io.confluent.ksql.logging.processing.NoopProcessingLogContext;
@@ -109,6 +111,8 @@ public class InsertValuesExecutorTest {
private static final ColumnName COL0 = ColumnName.of("COL0");
private static final ColumnName COL1 = ColumnName.of("COL1");
private static final ColumnName INT_COL = ColumnName.of("INT");
private static final ColumnName HEAD0 = ColumnName.of("HEAD0");
private static final ColumnName HEAD1 = ColumnName.of("HEAD1");

private static final LogicalSchema SINGLE_VALUE_COLUMN_SCHEMA = LogicalSchema.builder()
.keyColumn(K0, SqlTypes.STRING)
@@ -121,6 +125,21 @@
.valueColumn(COL1, SqlTypes.BIGINT)
.build();

private static final LogicalSchema SCHEMA_WITH_HEADERS = LogicalSchema.builder()
.keyColumn(K0, SqlTypes.STRING)
.valueColumn(COL0, SqlTypes.STRING)
.valueColumn(COL1, SqlTypes.BIGINT)
.headerColumn(HEAD0, Optional.empty()) // no key: the column captures the full headers list
.build();

private static final LogicalSchema SCHEMA_WITH_KEY_HEADERS = LogicalSchema.builder()
.keyColumn(K0, SqlTypes.STRING)
.valueColumn(COL0, SqlTypes.STRING)
.valueColumn(COL1, SqlTypes.BIGINT)
.headerColumn(HEAD0, Optional.of("a")) // extracts only the header keyed "a"
.headerColumn(HEAD1, Optional.of("b")) // extracts only the header keyed "b"
.build();

private static final LogicalSchema BIG_SCHEMA = LogicalSchema.builder()
.keyColumn(K0, SqlTypes.STRING)
.valueColumn(COL0, SqlTypes.STRING) // named COL0 for auto-ROWKEY
@@ -1116,6 +1135,99 @@ public void shouldThrowWhenNotAuthorizedToWriteValSchemaToSR() throws Exception
+ KsqlConstants.getSRSubject(TOPIC_NAME, false)));
}

@Test
public void shouldThrowOnInsertHeaders() {
// Given:
givenSourceStreamWithSchema(SCHEMA_WITH_HEADERS, SerdeFeatures.of(), SerdeFeatures.of());
final ConfiguredStatement<InsertValues> statement = givenInsertValues(
allColumnNames(SCHEMA_WITH_HEADERS),
ImmutableList.of(
new StringLiteral("key"),
new StringLiteral("str"),
new LongLiteral(2L),
new NullLiteral()
)
);

// When:
final Exception e = assertThrows(
KsqlException.class,
() -> executor.execute(statement, mock(SessionProperties.class), engine, serviceContext)
);

// Then:
assertThat(e.getMessage(), is("Cannot insert into HEADER columns: HEAD0"));
}

@Test
public void shouldThrowOnInsertKeyHeaders() {
// Given:
givenSourceStreamWithSchema(SCHEMA_WITH_KEY_HEADERS, SerdeFeatures.of(), SerdeFeatures.of());
final ConfiguredStatement<InsertValues> statement = givenInsertValues(
allColumnNames(SCHEMA_WITH_KEY_HEADERS),
ImmutableList.of(
new StringLiteral("key"),
new StringLiteral("str"),
new LongLiteral(2L),
new NullLiteral(),
new NullLiteral()
)
);

// When:
final Exception e = assertThrows(
KsqlException.class,
() -> executor.execute(statement, mock(SessionProperties.class), engine, serviceContext)
);

// Then:
assertThat(e.getMessage(), is("Cannot insert into HEADER columns: HEAD0, HEAD1"));
}

@Test
public void shouldInsertValuesIntoHeaderSchemaValueColumns() {
// Given:
givenSourceStreamWithSchema(SCHEMA_WITH_HEADERS, SerdeFeatures.of(), SerdeFeatures.of());
final ConfiguredStatement<InsertValues> statement = givenInsertValues(
ImmutableList.of(K0, COL0, COL1),
ImmutableList.of(
new StringLiteral("key"),
new StringLiteral("str"),
new LongLiteral(2L)
)
);

// When:
executor.execute(statement, mock(SessionProperties.class), engine, serviceContext);

// Then:
verify(producer).send(new ProducerRecord<>(TOPIC_NAME, null, 1L, KEY, VALUE));
}

@Test
public void shouldThrowOnInsertAllWithHeaders() {
// Given:
givenSourceStreamWithSchema(SCHEMA_WITH_HEADERS, SerdeFeatures.of(), SerdeFeatures.of());
final ConfiguredStatement<InsertValues> statement = givenInsertValues(
ImmutableList.of(),
ImmutableList.of(
new StringLiteral("key"),
new StringLiteral("str"),
new LongLiteral(2L),
new NullLiteral()
)
);

// When:
final Exception e = assertThrows(
KsqlException.class,
() -> executor.execute(statement, mock(SessionProperties.class), engine, serviceContext)
);

// Then:
assertThat(e.getMessage(), is("Cannot insert into HEADER columns: HEAD0"));
}

private static ConfiguredStatement<InsertValues> givenInsertValues(
final List<ColumnName> columns,
final List<Expression> values
