Skip to content

Commit

Permalink
feat: perform topic permission checks for KSQL service principal
Browse files Browse the repository at this point in the history
  • Loading branch information
spena committed Aug 20, 2019
1 parent 1228cb0 commit 5463a8e
Show file tree
Hide file tree
Showing 8 changed files with 174 additions and 53 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import io.confluent.ksql.metastore.MetaStore;
import io.confluent.ksql.metastore.model.DataSource;
import io.confluent.ksql.parser.tree.CreateAsSelect;
import io.confluent.ksql.parser.tree.CreateSource;
import io.confluent.ksql.parser.tree.InsertInto;
import io.confluent.ksql.parser.tree.Query;
import io.confluent.ksql.parser.tree.Statement;
Expand All @@ -44,15 +45,17 @@ public void validate(
final Statement statement
) {
if (statement instanceof Query) {
validateQueryTopicSources(serviceContext, metaStore, (Query)statement);
validateQuery(serviceContext, metaStore, (Query)statement);
} else if (statement instanceof InsertInto) {
validateInsertInto(serviceContext, metaStore, (InsertInto)statement);
} else if (statement instanceof CreateAsSelect) {
validateCreateAsSelect(serviceContext, metaStore, (CreateAsSelect)statement);
} else if (statement instanceof CreateSource) {
validateCreateSource(serviceContext, (CreateSource)statement);
}
}

private void validateQueryTopicSources(
private void validateQuery(
final ServiceContext serviceContext,
final MetaStore metaStore,
final Query query
Expand All @@ -78,11 +81,20 @@ private void validateCreateAsSelect(
* the target topic using the same ServiceContext used for validation.
*/

validateQueryTopicSources(serviceContext, metaStore, createAsSelect.getQuery());
validateQuery(serviceContext, metaStore, createAsSelect.getQuery());

// At this point, the topic should have been created by the TopicCreateInjector
final String kafkaTopic = getCreateAsSelectSinkTopic(metaStore, createAsSelect);
checkAccess(serviceContext, kafkaTopic, AclOperation.WRITE);

}

private void validateCreateSource(
final ServiceContext serviceContext,
final CreateSource createSource
) {
final String sourceTopic = createSource.getProperties().getKafkaTopic();
checkAccess(serviceContext, sourceTopic, AclOperation.READ);
}

private void validateInsertInto(
Expand All @@ -96,7 +108,7 @@ private void validateInsertInto(
* Validates Write on the target topic, and Read on the query sources topics.
*/

validateQueryTopicSources(serviceContext, metaStore, insertInto.getQuery());
validateQuery(serviceContext, metaStore, insertInto.getQuery());

final String kafkaTopic = getSourceTopicName(metaStore, insertInto.getTarget().getSuffix());
checkAccess(serviceContext, kafkaTopic, AclOperation.WRITE);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import io.confluent.kafka.schemaregistry.client.rest.exceptions.RestClientException;
import io.confluent.ksql.GenericRow;
import io.confluent.ksql.KsqlExecutionContext;
import io.confluent.ksql.exception.KsqlTopicAuthorizationException;
import io.confluent.ksql.execution.expression.tree.Expression;
import io.confluent.ksql.execution.expression.tree.Literal;
import io.confluent.ksql.execution.expression.tree.NullLiteral;
Expand Down Expand Up @@ -60,10 +61,14 @@
import org.apache.kafka.clients.producer.Producer;
import org.apache.kafka.clients.producer.ProducerRecord;
import org.apache.kafka.clients.producer.RecordMetadata;
import org.apache.kafka.common.acl.AclOperation;
import org.apache.kafka.common.errors.TopicAuthorizationException;
import org.apache.kafka.common.serialization.Serde;
import org.apache.kafka.connect.data.Struct;

// CHECKSTYLE_RULES.OFF: ClassDataAbstractionCoupling
public class InsertValuesExecutor {
// CHECKSTYLE_RULES.ON: ClassDataAbstractionCoupling

private static final Duration MAX_SEND_TIMEOUT = Duration.ofSeconds(5);

Expand Down Expand Up @@ -163,6 +168,17 @@ public void execute(
);

producer.sendRecord(record, serviceContext, config.getProducerClientConfigProps());
} catch (final TopicAuthorizationException e) {
// TopicAuthorizationException does not give much detailed information about why it failed,
// except which topics are denied. Here we just add the ACL to make the error message
// consistent with other authorization error messages.
final Exception rootCause = new KsqlTopicAuthorizationException(
AclOperation.WRITE,
e.unauthorizedTopics()
);

throw new KsqlException("Failed to insert values into stream/table: "
+ insertValues.getTarget().getSuffix(), rootCause);
} catch (final Exception e) {
throw new KsqlException("Failed to insert values into stream/table: "
+ insertValues.getTarget().getSuffix(), e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import static org.junit.internal.matchers.ThrowableMessageMatcher.hasMessage;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
Expand Down Expand Up @@ -63,6 +64,7 @@
import java.math.BigDecimal;
import java.math.MathContext;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.Set;
Expand All @@ -74,6 +76,7 @@
import org.apache.kafka.clients.producer.KafkaProducer;
import org.apache.kafka.clients.producer.ProducerRecord;
import org.apache.kafka.common.errors.SerializationException;
import org.apache.kafka.common.errors.TopicAuthorizationException;
import org.apache.kafka.common.serialization.Serde;
import org.apache.kafka.common.serialization.Serializer;
import org.apache.kafka.connect.data.Schema;
Expand Down Expand Up @@ -515,6 +518,30 @@ public void shouldThrowOnSerializingValueError() {
executor.execute(statement, engine, serviceContext);
}

@Test
public void shouldThrowOnTopicAuthorizationException() {
// Given:
final ConfiguredStatement<InsertValues> statement = givenInsertValues(
allFieldNames(SCHEMA),
ImmutableList.of(
new LongLiteral(1L),
new StringLiteral("str"),
new StringLiteral("str"),
new LongLiteral(2L))
);
doThrow(new TopicAuthorizationException(Collections.singleton("t1")))
.when(producer).send(any());

// Expect:
expectedException.expect(KsqlException.class);
expectedException.expectCause(hasMessage(
containsString("Authorization denied to Write on topic(s): [t1]"))
);

// When:
executor.execute(statement, engine, serviceContext);
}

@Test
public void shouldThrowIfRowKeyAndKeyDoNotMatch() {
// Given:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,11 @@

package io.confluent.ksql.rest.server.computation;

import static java.util.Objects.requireNonNull;

import io.confluent.ksql.KsqlExecutionContext;
import io.confluent.ksql.engine.TopicAccessValidator;
import io.confluent.ksql.metastore.MetaStore;
import io.confluent.ksql.parser.tree.Statement;
import io.confluent.ksql.rest.entity.CommandStatus;
import io.confluent.ksql.rest.entity.CommandStatusEntity;
Expand All @@ -25,7 +29,6 @@
import io.confluent.ksql.statement.Injector;
import io.confluent.ksql.util.KsqlServerException;
import java.time.Duration;
import java.util.Objects;
import java.util.Optional;
import java.util.function.BiFunction;

Expand All @@ -40,15 +43,19 @@ public class DistributingExecutor implements StatementExecutor<Statement> {
private final CommandQueue commandQueue;
private final Duration distributedCmdResponseTimeout;
private final BiFunction<KsqlExecutionContext, ServiceContext, Injector> injectorFactory;
private final TopicAccessValidator topicAccessValidator;

public DistributingExecutor(
final CommandQueue commandQueue,
final Duration distributedCmdResponseTimeout,
final BiFunction<KsqlExecutionContext, ServiceContext, Injector> injectorFactory) {
this.commandQueue = Objects.requireNonNull(commandQueue, "commandQueue");
final BiFunction<KsqlExecutionContext, ServiceContext, Injector> injectorFactory,
final TopicAccessValidator topicAccessValidator
) {
this.commandQueue = requireNonNull(commandQueue, "commandQueue");
this.distributedCmdResponseTimeout =
Objects.requireNonNull(distributedCmdResponseTimeout, "distributedCmdResponseTimeout");
this.injectorFactory = Objects.requireNonNull(injectorFactory, "injectorFactory");
requireNonNull(distributedCmdResponseTimeout, "distributedCmdResponseTimeout");
this.injectorFactory = requireNonNull(injectorFactory, "injectorFactory");
this.topicAccessValidator = requireNonNull(topicAccessValidator, "topicAccessValidator");
}

@Override
Expand All @@ -61,6 +68,8 @@ public Optional<KsqlEntity> execute(
.apply(executionContext, serviceContext)
.inject(statement);

checkExecutionPermissions(serviceContext, executionContext, injected.getStatement());

try {
final QueuedCommandStatus queuedCommandStatus = commandQueue.enqueueCommand(injected);
final CommandStatus commandStatus = queuedCommandStatus
Expand All @@ -78,4 +87,40 @@ public Optional<KsqlEntity> execute(
statement.getStatementText()), e);
}
}

/**
* Performs permissions checks on the statement resources.
* </p>
* Before persisting the statement in the KSQL command topic, this check verifies the KSQL
* server as well as the User in a authenticated environment have the right ACLs permissions
* to access the statement resources.
*
* @param userContext The context of the user executing this command.
* The KSQL context is used of no authentication service and impersonation
* is configured in the system.
* @param executionContext The execution context which contains the KSQL service context
* and the KSQL metastore.
* @param statement The statement that needs to be checked.
*/
private void checkExecutionPermissions(
final ServiceContext userContext,
final KsqlExecutionContext executionContext,
final Statement statement
) {
final ServiceContext serverContext = executionContext.getServiceContext();
final MetaStore metaStore = executionContext.getMetaStore();

topicAccessValidator.validate(userContext, metaStore, statement);

// If these service contexts are different, then KSQL is running in a secured environment
// with authentication and impersonation enabled.
if (userContext != serverContext) {
try {
// Perform a permission check for the KSQL server
topicAccessValidator.validate(serverContext, metaStore, statement);
} catch (final Exception e) {
throw new KsqlServerException("The KSQL server cannot execute the given command.", e);
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -112,14 +112,15 @@ public KsqlResource(
CustomValidators.VALIDATOR_MAP,
injectorFactory,
ksqlEngine::createSandbox,
ksqlConfig,
topicAccessValidator);
ksqlConfig);
this.handler = new RequestHandler(
CustomExecutors.EXECUTOR_MAP,
new DistributingExecutor(
commandQueue,
distributedCmdResponseTimeout,
injectorFactory),
injectorFactory,
topicAccessValidator
),
ksqlEngine,
ksqlConfig,
new DefaultCommandQueueSync(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@

import io.confluent.ksql.KsqlExecutionContext;
import io.confluent.ksql.engine.KsqlEngine;
import io.confluent.ksql.engine.TopicAccessValidator;
import io.confluent.ksql.parser.KsqlParser.ParsedStatement;
import io.confluent.ksql.parser.KsqlParser.PreparedStatement;
import io.confluent.ksql.parser.tree.CreateAsSelect;
Expand Down Expand Up @@ -55,7 +54,6 @@ public class RequestValidator {
private final BiFunction<KsqlExecutionContext, ServiceContext, Injector> injectorFactory;
private final Function<ServiceContext, KsqlExecutionContext> snapshotSupplier;
private final KsqlConfig ksqlConfig;
private final TopicAccessValidator topicAccessValidator;

/**
* @param customValidators a map describing how to validate each statement of type
Expand All @@ -69,14 +67,12 @@ public RequestValidator(
final Map<Class<? extends Statement>, StatementValidator<?>> customValidators,
final BiFunction<KsqlExecutionContext, ServiceContext, Injector> injectorFactory,
final Function<ServiceContext, KsqlExecutionContext> snapshotSupplier,
final KsqlConfig ksqlConfig,
final TopicAccessValidator topicAccessValidator
final KsqlConfig ksqlConfig
) {
this.customValidators = requireNonNull(customValidators, "customValidators");
this.injectorFactory = requireNonNull(injectorFactory, "injectorFactory");
this.snapshotSupplier = requireNonNull(snapshotSupplier, "snapshotSupplier");
this.ksqlConfig = requireNonNull(ksqlConfig, "ksqlConfig");
this.topicAccessValidator = topicAccessValidator;
}

/**
Expand Down Expand Up @@ -144,12 +140,6 @@ private <T extends Statement> int validate(
} else if (KsqlEngine.isExecutableStatement(configured.getStatement())) {
final ConfiguredStatement<?> statementInjected = injector.inject(configured);

topicAccessValidator.validate(
serviceContext,
executionContext.getMetaStore(),
statementInjected.getStatement()
);

executionContext.execute(serviceContext, statementInjected);
} else {
throw new KsqlStatementException(
Expand Down
Loading

0 comments on commit 5463a8e

Please sign in to comment.