Skip to content

Commit

Permalink
move rate limiting to distributing executor
Browse files Browse the repository at this point in the history
  • Loading branch information
lct45 committed Feb 28, 2022
1 parent 6d168be commit e2465ba
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 58 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,11 @@

import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.util.concurrent.RateLimiter;
import io.confluent.ksql.rest.Errors;
import io.confluent.ksql.rest.entity.CommandId;
import io.confluent.ksql.rest.server.CommandTopic;
import io.confluent.ksql.rest.server.CommandTopicBackup;
import io.confluent.ksql.rest.server.CommandTopicBackupImpl;
import io.confluent.ksql.rest.server.CommandTopicBackupNoOp;
import io.confluent.ksql.rest.server.resources.KsqlRestException;
import io.confluent.ksql.rest.util.CommandTopicBackupUtil;
import io.confluent.ksql.services.KafkaTopicClient;
import io.confluent.ksql.util.KsqlConfig;
Expand Down Expand Up @@ -82,7 +79,6 @@ public class CommandStore implements CommandQueue, Closeable {
private final Serializer<Command> commandSerializer;
private final Deserializer<CommandId> commandIdDeserializer;
private final CommandTopicBackup commandTopicBackup;
private final RateLimiter rateLimiter;


public static final class Factory {
Expand Down Expand Up @@ -122,8 +118,6 @@ public static CommandStore create(
internalTopicClient
);
}
final double rateLimit =
ksqlConfig.getDouble(KsqlConfig.KSQL_COMMAND_TOPIC_RATE_LIMIT_CONFIG);

return new CommandStore(
commandTopicName,
Expand All @@ -139,13 +133,11 @@ public static CommandStore create(
InternalTopicSerdes.serializer(),
InternalTopicSerdes.serializer(),
InternalTopicSerdes.deserializer(CommandId.class),
commandTopicBackup,
RateLimiter.create(rateLimit)
commandTopicBackup
);
}
}

// CHECKSTYLE_RULES.OFF: ParameterNumberCheck
CommandStore(
final String commandTopicName,
final CommandTopic commandTopic,
Expand All @@ -156,10 +148,8 @@ public static CommandStore create(
final Serializer<CommandId> commandIdSerializer,
final Serializer<Command> commandSerializer,
final Deserializer<CommandId> commandIdDeserializer,
final CommandTopicBackup commandTopicBackup,
final RateLimiter rateLimiter
final CommandTopicBackup commandTopicBackup
) {
// CHECKSTYLE_RULES.ON: ParameterNumberCheck
this.commandTopic = Objects.requireNonNull(commandTopic, "commandTopic");
this.commandStatusMap = Maps.newConcurrentMap();
this.sequenceNumberFutureStore =
Expand All @@ -179,8 +169,6 @@ public static CommandStore create(
Objects.requireNonNull(commandIdDeserializer, "commandIdDeserializer");
this.commandTopicBackup =
Objects.requireNonNull(commandTopicBackup, "commandTopicBackup");
this.rateLimiter =
Objects.requireNonNull(rateLimiter, "rateLimiter");
}

@Override
Expand Down Expand Up @@ -210,12 +198,6 @@ public QueuedCommandStatus enqueueCommand(
final Command command,
final Producer<CommandId, Command> transactionalProducer
) {
if (!rateLimiter.tryAcquire()) {
throw new KsqlRestException(
Errors.tooManyRequests(
"Too many writes to the command topic within a 1 second timeframe"
));
}
final CommandStatusFuture statusFuture = commandStatusMap.compute(
commandId,
(k, v) -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

package io.confluent.ksql.rest.server.computation;

import com.google.common.util.concurrent.RateLimiter;
import edu.umd.cs.findbugs.annotations.SuppressFBWarnings;
import io.confluent.ksql.KsqlExecutionContext;
import io.confluent.ksql.metastore.MetaStore;
Expand All @@ -32,6 +33,7 @@
import io.confluent.ksql.rest.entity.KsqlWarning;
import io.confluent.ksql.rest.entity.WarningEntity;
import io.confluent.ksql.rest.server.execution.StatementExecutorResponse;
import io.confluent.ksql.rest.server.resources.KsqlRestException;
import io.confluent.ksql.security.KsqlAuthorizationValidator;
import io.confluent.ksql.security.KsqlSecurityContext;
import io.confluent.ksql.services.ServiceContext;
Expand Down Expand Up @@ -70,6 +72,7 @@ public class DistributingExecutor {
private final ReservedInternalTopics internalTopics;
private final Errors errorHandler;
private final Supplier<String> commandRunnerWarning;
private final RateLimiter rateLimiter;

@SuppressFBWarnings(value = "EI_EXPOSE_REP2")
public DistributingExecutor(
Expand Down Expand Up @@ -98,6 +101,8 @@ public DistributingExecutor(
this.errorHandler = Objects.requireNonNull(errorHandler, "errorHandler");
this.commandRunnerWarning =
Objects.requireNonNull(commandRunnerWarning, "commandRunnerWarning");
this.rateLimiter =
RateLimiter.create(ksqlConfig.getDouble(KsqlConfig.KSQL_COMMAND_TOPIC_RATE_LIMIT_CONFIG));
}

// CHECKSTYLE_RULES.OFF: CyclomaticComplexity
Expand Down Expand Up @@ -196,6 +201,13 @@ public StatementExecutorResponse execute(
statement.getStatementText()), e);
}

if (!rateLimiter.tryAcquire()) {
throw new KsqlRestException(
Errors.tooManyRequests(
"Too many writes to the command topic within a 1 second timeframe"
));
}

CommandId commandId = null;
try {
transactionalProducer.beginTransaction();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,7 @@
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.core.IsEqual.equalTo;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertThrows;
import static org.junit.Assert.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.ArgumentMatchers.eq;
Expand All @@ -34,14 +32,10 @@
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

import com.google.common.util.concurrent.RateLimiter;
import io.confluent.ksql.rest.entity.CommandId;
import io.confluent.ksql.rest.entity.CommandStatus;
import io.confluent.ksql.rest.entity.KsqlErrorMessage;
import io.confluent.ksql.rest.server.CommandTopic;
import io.confluent.ksql.rest.server.CommandTopicBackup;
import io.confluent.ksql.rest.server.resources.KsqlRestException;
import io.confluent.ksql.util.KsqlConfig;
import io.confluent.ksql.util.KsqlException;
import java.time.Duration;
import java.util.ArrayList;
Expand Down Expand Up @@ -158,8 +152,7 @@ public void setUp() {
commandIdSerializer,
commandSerializer,
commandIdDeserializer,
commandTopicBackup,
RateLimiter.create(KsqlConfig.KSQL_COMMAND_TOPIC_RATE_LIMIT_CONFIG_DEFAULT)
commandTopicBackup
);
}

Expand Down Expand Up @@ -386,36 +379,6 @@ public void shouldSuccessfullyAbortAndRetry() {
commandStore.enqueueCommand(commandId, command, transactionalProducer);
}

@Test
public void shouldFailEnqueueIfRateLimitHit() {
// Given:
final CommandStore lowRateLimitStore = new CommandStore(
COMMAND_TOPIC_NAME,
commandTopic,
sequenceNumberFutureStore,
Collections.emptyMap(),
Collections.emptyMap(),
TIMEOUT,
commandIdSerializer,
commandSerializer,
commandIdDeserializer,
commandTopicBackup,
RateLimiter.create(1.0)
);

// When:
lowRateLimitStore.enqueueCommand(commandId, command, transactionalProducer);

// Then:
final KsqlRestException e = assertThrows(
KsqlRestException.class,
() -> lowRateLimitStore.enqueueCommand(commandId, command, transactionalProducer)
);
assertEquals(e.getResponse().getStatus(), 429);
final KsqlErrorMessage errorMessage = (KsqlErrorMessage) e.getResponse().getEntity();
assertTrue(errorMessage.getMessage().contains("Too many writes to the command topic within a 1 second timeframe"));
}

private static ConsumerRecords<byte[], byte[]> buildRecords(final Object... args) {
assertThat(args.length % 2, equalTo(0));
final List<ConsumerRecord<byte[], byte[]>> records = new ArrayList<>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.isA;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertThrows;
import static org.junit.Assert.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.doNothing;
Expand Down Expand Up @@ -59,8 +61,10 @@
import io.confluent.ksql.rest.entity.CommandStatus;
import io.confluent.ksql.rest.entity.CommandStatus.Status;
import io.confluent.ksql.rest.entity.CommandStatusEntity;
import io.confluent.ksql.rest.entity.KsqlErrorMessage;
import io.confluent.ksql.rest.entity.WarningEntity;
import io.confluent.ksql.rest.server.execution.StatementExecutorResponse;
import io.confluent.ksql.rest.server.resources.KsqlRestException;
import io.confluent.ksql.schema.ksql.LogicalSchema;
import io.confluent.ksql.security.KsqlAuthorizationValidator;
import io.confluent.ksql.security.KsqlSecurityContext;
Expand Down Expand Up @@ -496,4 +500,33 @@ CommonCreateConfigs.VALUE_FORMAT_PROPERTY, new StringLiteral("json")
assertThat("Should be present", response.getEntity().isPresent());
assertThat(((WarningEntity) response.getEntity().get()).getMessage(), containsString(""));
}

@Test
public void shouldThrowIfRateLimitHit() {
// Given:
final DistributingExecutor rateLimitedDistributor = new DistributingExecutor(
new KsqlConfig(ImmutableMap.of("ksql.command.topic.rate.limit", 1.0)),
queue,
DURATION_10_MS,
(ec, sc) -> InjectorChain.of(schemaInjector, topicInjector),
Optional.of(authorizationValidator),
validatedCommandFactory,
errorHandler,
commandRunnerWarning
);

// When:
distributor.execute(CONFIGURED_STATEMENT, executionContext, securityContext);


// Then:
final KsqlRestException e = assertThrows(
KsqlRestException.class,
() -> distributor.execute(CONFIGURED_STATEMENT, executionContext, securityContext)
);

assertEquals(e.getResponse().getStatus(), 429);
final KsqlErrorMessage errorMessage = (KsqlErrorMessage) e.getResponse().getEntity();
assertTrue(errorMessage.getMessage().contains("Too many writes to the command topic within a 1 second timeframe"));
}
}

0 comments on commit e2465ba

Please sign in to comment.