diff --git a/ksqldb-engine/src/main/java/io/confluent/ksql/connect/ConnectRequestHeadersExtension.java b/ksqldb-engine/src/main/java/io/confluent/ksql/connect/ConnectRequestHeadersExtension.java
index ae267ed4bde8..ee547841ec51 100644
--- a/ksqldb-engine/src/main/java/io/confluent/ksql/connect/ConnectRequestHeadersExtension.java
+++ b/ksqldb-engine/src/main/java/io/confluent/ksql/connect/ConnectRequestHeadersExtension.java
@@ -15,6 +15,7 @@
package io.confluent.ksql.connect;
+import io.confluent.ksql.security.KsqlPrincipal;
import java.util.Collections;
import java.util.List;
import java.util.Map;
@@ -39,7 +40,7 @@ public interface ConnectRequestHeadersExtension {
*
*
Set this to {@code false} in order to use this
* {@code ConnectRequestHeadersExtension} for additional custom headers
- * (via {@link ConnectRequestHeadersExtension#getHeaders()}) only, without
+ * (via {@link ConnectRequestHeadersExtension#getHeaders(Optional)}) only, without
* impacting ksqlDB's default behavior for the auth header.
*
* @return whether to use the custom auth header returned from
@@ -68,9 +69,12 @@ default boolean shouldUseCustomAuthHeader() {
* such as those required for authenticating with Connect. The custom headers are added
* to the request in addition to, and after, ksqlDB's default headers.
*
+ * @param userPrincipal principal associated with the user who submitted the connector
+ * request to ksqlDB, if present (i.e., if user authentication
+ * is enabled)
* @return additional headers to be included with connector requests made by ksqlDB
*/
- default Map getHeaders() {
+ default Map getHeaders(Optional userPrincipal) {
return Collections.emptyMap();
}
diff --git a/ksqldb-engine/src/main/java/io/confluent/ksql/security/DefaultKsqlPrincipal.java b/ksqldb-engine/src/main/java/io/confluent/ksql/security/DefaultKsqlPrincipal.java
index 58f7596fb806..6205510cbb3e 100644
--- a/ksqldb-engine/src/main/java/io/confluent/ksql/security/DefaultKsqlPrincipal.java
+++ b/ksqldb-engine/src/main/java/io/confluent/ksql/security/DefaultKsqlPrincipal.java
@@ -28,9 +28,15 @@
public class DefaultKsqlPrincipal implements KsqlPrincipal {
private final Principal principal;
+ private final String ipAddress;
public DefaultKsqlPrincipal(final Principal principal) {
+ this(principal, "");
+ }
+
+ protected DefaultKsqlPrincipal(final Principal principal, final String ipAddress) {
this.principal = Objects.requireNonNull(principal, "principal");
+ this.ipAddress = Objects.requireNonNull(ipAddress, "ipAddress");
}
@Override
@@ -54,4 +60,17 @@ public Map getUserProperties() {
public Principal getOriginalPrincipal() {
return principal;
}
+
+ @Override
+ public String getIpAddress() {
+ return ipAddress;
+ }
+
+ /**
+ * IP address is populated from the request context, and subsequently passed
+ * throughout the engine.
+ */
+ public DefaultKsqlPrincipal withIpAddress(final String ipAddress) {
+ return new DefaultKsqlPrincipal(principal, ipAddress);
+ }
}
diff --git a/ksqldb-engine/src/main/java/io/confluent/ksql/security/KsqlPrincipal.java b/ksqldb-engine/src/main/java/io/confluent/ksql/security/KsqlPrincipal.java
index dff227e583f1..bd68ac7edc07 100644
--- a/ksqldb-engine/src/main/java/io/confluent/ksql/security/KsqlPrincipal.java
+++ b/ksqldb-engine/src/main/java/io/confluent/ksql/security/KsqlPrincipal.java
@@ -31,4 +31,21 @@ default Map getUserProperties() {
return Collections.emptyMap();
}
+ /**
+ * Returns the user's IP address, as set by the ksqlDB server's request context.
+ *
+ * This method never returns {@code null}. An empty string may be returned in
+ * certain situations (those where an IP address is not available, e.g., domain
+ * socket requests).
+ *
+ *
Overriding the implementation of this method from custom {@code KsqlPrincipal}
+ * implementations has no effect on the return value of this method, when called from
+ * custom extensions, because incoming {@code KsqlPrincipal} instances are wrapped
+ * inside ksqlDB's own {@link DefaultKsqlPrincipal} before being passed throughout
+ * the ksqlDB engine, and {@code DefaultKsqlPrincipal} has its own implementation
+ * for tracking and returning the IP address from this method.
+ */
+ default String getIpAddress() {
+ return "";
+ }
}
diff --git a/ksqldb-engine/src/main/java/io/confluent/ksql/services/ConnectClientFactory.java b/ksqldb-engine/src/main/java/io/confluent/ksql/services/ConnectClientFactory.java
index 307985c3f34d..402e3097d310 100644
--- a/ksqldb-engine/src/main/java/io/confluent/ksql/services/ConnectClientFactory.java
+++ b/ksqldb-engine/src/main/java/io/confluent/ksql/services/ConnectClientFactory.java
@@ -15,6 +15,7 @@
package io.confluent.ksql.services;
+import io.confluent.ksql.security.KsqlPrincipal;
import java.util.List;
import java.util.Map.Entry;
import java.util.Optional;
@@ -23,7 +24,8 @@ public interface ConnectClientFactory {
ConnectClient get(
Optional authHeader,
- List> incomingRequestHeaders
+ List> incomingRequestHeaders,
+ Optional userPrincipal
);
default void close() {}
diff --git a/ksqldb-engine/src/main/java/io/confluent/ksql/services/DefaultConnectClientFactory.java b/ksqldb-engine/src/main/java/io/confluent/ksql/services/DefaultConnectClientFactory.java
index cf54fbbd37fc..08656bb24dfc 100644
--- a/ksqldb-engine/src/main/java/io/confluent/ksql/services/DefaultConnectClientFactory.java
+++ b/ksqldb-engine/src/main/java/io/confluent/ksql/services/DefaultConnectClientFactory.java
@@ -17,6 +17,7 @@
import com.google.common.annotations.VisibleForTesting;
import io.confluent.ksql.connect.ConnectRequestHeadersExtension;
+import io.confluent.ksql.security.KsqlPrincipal;
import io.confluent.ksql.util.FileWatcher;
import io.confluent.ksql.util.KsqlConfig;
import java.io.FileInputStream;
@@ -78,7 +79,8 @@ public DefaultConnectClientFactory(
@Override
public synchronized DefaultConnectClient get(
final Optional ksqlAuthHeader,
- final List> incomingRequestHeaders
+ final List> incomingRequestHeaders,
+ final Optional userPrincipal
) {
if (defaultConnectAuthHeader == null) {
defaultConnectAuthHeader = buildDefaultAuthHeader();
@@ -91,7 +93,7 @@ public synchronized DefaultConnectClient get(
ksqlConfig.getString(KsqlConfig.CONNECT_URL_PROPERTY),
buildAuthHeader(ksqlAuthHeader, incomingRequestHeaders),
requestHeadersExtension
- .map(ConnectRequestHeadersExtension::getHeaders)
+ .map(extension -> extension.getHeaders(userPrincipal))
.orElse(Collections.emptyMap()),
Optional.ofNullable(newSslContext(configWithPrefixOverrides)),
shouldVerifySslHostname(configWithPrefixOverrides)
diff --git a/ksqldb-engine/src/main/java/io/confluent/ksql/services/ServiceContextFactory.java b/ksqldb-engine/src/main/java/io/confluent/ksql/services/ServiceContextFactory.java
index 006f67f90018..2aacd858e226 100644
--- a/ksqldb-engine/src/main/java/io/confluent/ksql/services/ServiceContextFactory.java
+++ b/ksqldb-engine/src/main/java/io/confluent/ksql/services/ServiceContextFactory.java
@@ -42,7 +42,8 @@ public static ServiceContext create(
Collections.emptyMap())::get,
() -> new DefaultConnectClientFactory(ksqlConfig).get(
Optional.empty(),
- Collections.emptyList()),
+ Collections.emptyList(),
+ Optional.empty()),
ksqlClientSupplier
);
}
diff --git a/ksqldb-engine/src/test/java/io/confluent/ksql/embedded/KsqlContextTestUtil.java b/ksqldb-engine/src/test/java/io/confluent/ksql/embedded/KsqlContextTestUtil.java
index 63ef97a2a46b..b2610b61a3b6 100644
--- a/ksqldb-engine/src/test/java/io/confluent/ksql/embedded/KsqlContextTestUtil.java
+++ b/ksqldb-engine/src/test/java/io/confluent/ksql/embedded/KsqlContextTestUtil.java
@@ -61,7 +61,8 @@ public static KsqlContext create(
adminClient,
kafkaTopicClient,
() -> schemaRegistryClient,
- new DefaultConnectClientFactory(ksqlConfig).get(Optional.empty(), Collections.emptyList())
+ new DefaultConnectClientFactory(ksqlConfig)
+ .get(Optional.empty(), Collections.emptyList(), Optional.empty())
);
final String metricsPrefix = "instance-" + COUNTER.getAndIncrement() + "-";
diff --git a/ksqldb-engine/src/test/java/io/confluent/ksql/services/DefaultConnectClientFactoryTest.java b/ksqldb-engine/src/test/java/io/confluent/ksql/services/DefaultConnectClientFactoryTest.java
index 36f0f1f79605..4a6461a2e24c 100644
--- a/ksqldb-engine/src/test/java/io/confluent/ksql/services/DefaultConnectClientFactoryTest.java
+++ b/ksqldb-engine/src/test/java/io/confluent/ksql/services/DefaultConnectClientFactoryTest.java
@@ -27,6 +27,7 @@
import com.google.common.collect.ImmutableMap;
import io.confluent.ksql.connect.ConnectRequestHeadersExtension;
+import io.confluent.ksql.security.KsqlPrincipal;
import io.confluent.ksql.test.util.KsqlTestFolder;
import io.confluent.ksql.util.KsqlConfig;
import io.vertx.core.http.HttpHeaders;
@@ -96,6 +97,8 @@ public class DefaultConnectClientFactoryTest {
private ConnectRequestHeadersExtension requestHeadersExtension;
@Mock
private List> incomingRequestHeaders;
+ @Mock
+ private KsqlPrincipal userPrincipal;
private String credentialsPath;
@@ -107,7 +110,8 @@ public void setUp() {
when(config.getString(KsqlConfig.CONNECT_URL_PROPERTY)).thenReturn("http://localhost:8034");
when(config.getString(KsqlConfig.CONNECT_BASIC_AUTH_CREDENTIALS_SOURCE_PROPERTY)).thenReturn("NONE");
- when(config.valuesWithPrefixOverride(KsqlConfig.KSQL_CONNECT_PREFIX)).thenReturn(DEFAULT_CONFIGS_WITH_PREFIX_OVERRIDE);
+ when(config.valuesWithPrefixOverride(KsqlConfig.KSQL_CONNECT_PREFIX))
+ .thenReturn(DEFAULT_CONFIGS_WITH_PREFIX_OVERRIDE);
connectClientFactory = new DefaultConnectClientFactory(config);
}
@@ -115,7 +119,8 @@ public void setUp() {
@Test
public void shouldBuildWithoutAuthHeader() {
// When:
- final DefaultConnectClient connectClient = connectClientFactory.get(Optional.empty(), Collections.emptyList());
+ final DefaultConnectClient connectClient =
+ connectClientFactory.get(Optional.empty(), Collections.emptyList(), Optional.empty());
// Then:
assertThat(connectClient.getRequestHeaders(), is(EMPTY_HEADERS));
@@ -128,7 +133,8 @@ public void shouldBuildAuthHeader() throws Exception {
givenValidCredentialsFile();
// When:
- final DefaultConnectClient connectClient = connectClientFactory.get(Optional.empty(), Collections.emptyList());
+ final DefaultConnectClient connectClient =
+ connectClientFactory.get(Optional.empty(), Collections.emptyList(), Optional.empty());
// Then:
assertThat(connectClient.getRequestHeaders(),
@@ -142,8 +148,8 @@ public void shouldBuildAuthHeaderOnlyOnce() throws Exception {
givenValidCredentialsFile();
// When: get() is called twice
- connectClientFactory.get(Optional.empty(), Collections.emptyList());
- connectClientFactory.get(Optional.empty(), Collections.emptyList());
+ connectClientFactory.get(Optional.empty(), Collections.emptyList(), Optional.empty());
+ connectClientFactory.get(Optional.empty(), Collections.emptyList(), Optional.empty());
// Then: only loaded the credentials once -- ideally we'd check the number of times the file
// was read but this is an acceptable proxy for this unit test
@@ -154,7 +160,7 @@ public void shouldBuildAuthHeaderOnlyOnce() throws Exception {
public void shouldUseKsqlAuthHeaderIfNoAuthHeaderPresent() {
// When:
final DefaultConnectClient connectClient =
- connectClientFactory.get(Optional.of("some ksql request header"), Collections.emptyList());
+ connectClientFactory.get(Optional.of("some ksql request header"), Collections.emptyList(), Optional.empty());
// Then:
assertThat(connectClient.getRequestHeaders(),
@@ -168,7 +174,8 @@ public void shouldNotFailOnUnreadableCredentials() throws Exception {
givenInvalidCredentialsFiles();
// When:
- final DefaultConnectClient connectClient = connectClientFactory.get(Optional.empty(), Collections.emptyList());
+ final DefaultConnectClient connectClient =
+ connectClientFactory.get(Optional.empty(), Collections.emptyList(), Optional.empty());
// Then:
assertThat(connectClient.getRequestHeaders(), is(EMPTY_HEADERS));
@@ -181,7 +188,8 @@ public void shouldNotFailOnMissingCredentials() {
// no credentials file present
// When:
- final DefaultConnectClient connectClient = connectClientFactory.get(Optional.empty(), Collections.emptyList());
+ final DefaultConnectClient connectClient =
+ connectClientFactory.get(Optional.empty(), Collections.emptyList(), Optional.empty());
// Then:
assertThat(connectClient.getRequestHeaders(), is(EMPTY_HEADERS));
@@ -195,7 +203,7 @@ public void shouldReloadCredentialsOnFileCreation() throws Exception {
// no credentials file present
// verify that no auth header is present
- assertThat(connectClientFactory.get(Optional.empty(), Collections.emptyList()).getRequestHeaders(),
+ assertThat(connectClientFactory.get(Optional.empty(), Collections.emptyList(), Optional.empty()).getRequestHeaders(),
is(EMPTY_HEADERS));
// When: credentials file is created
@@ -205,7 +213,7 @@ public void shouldReloadCredentialsOnFileCreation() throws Exception {
// Then: auth header is present
assertThatEventually(
"Should load newly created credentials",
- () -> connectClientFactory.get(Optional.empty(), Collections.emptyList()).getRequestHeaders(),
+ () -> connectClientFactory.get(Optional.empty(), Collections.emptyList(), Optional.empty()).getRequestHeaders(),
arrayContaining(header(AUTH_HEADER_NAME, EXPECTED_HEADER)),
TimeUnit.SECONDS.toMillis(1),
TimeUnit.SECONDS.toMillis(1)
@@ -220,7 +228,7 @@ public void shouldReloadCredentialsOnFileChange() throws Exception {
givenValidCredentialsFile();
// verify auth header is present
- assertThat(connectClientFactory.get(Optional.empty(), Collections.emptyList()).getRequestHeaders(),
+ assertThat(connectClientFactory.get(Optional.empty(), Collections.emptyList(), Optional.empty()).getRequestHeaders(),
arrayContaining(header(AUTH_HEADER_NAME, EXPECTED_HEADER)));
// When: credentials file is modified
@@ -230,7 +238,7 @@ public void shouldReloadCredentialsOnFileChange() throws Exception {
// Then: new auth header is present
assertThatEventually(
"Should load updated credentials",
- () -> connectClientFactory.get(Optional.empty(), Collections.emptyList()).getRequestHeaders(),
+ () -> connectClientFactory.get(Optional.empty(), Collections.emptyList(), Optional.empty()).getRequestHeaders(),
arrayContaining(header(AUTH_HEADER_NAME, OTHER_EXPECTED_HEADER)),
TimeUnit.SECONDS.toMillis(1),
TimeUnit.SECONDS.toMillis(1)
@@ -247,12 +255,12 @@ public void shouldPassCustomRequestHeaders() {
// re-initialize client factory since request headers extension is configured in constructor
connectClientFactory = new DefaultConnectClientFactory(config);
- when(requestHeadersExtension.getHeaders())
+ when(requestHeadersExtension.getHeaders(Optional.of(userPrincipal)))
.thenReturn(ImmutableMap.of("header", "value"));
// When:
final DefaultConnectClient connectClient =
- connectClientFactory.get(Optional.empty(), Collections.emptyList());
+ connectClientFactory.get(Optional.empty(), Collections.emptyList(), Optional.of(userPrincipal));
// Then:
assertThat(connectClient.getRequestHeaders(), arrayContaining(header("header", "value")));
@@ -271,13 +279,13 @@ public void shouldPassCustomRequestHeadersInAdditionToDefaultBasicAuthHeader() t
// re-initialize client factory since request headers extension is configured in constructor
connectClientFactory = new DefaultConnectClientFactory(config);
- when(requestHeadersExtension.getHeaders())
+ when(requestHeadersExtension.getHeaders(Optional.of(userPrincipal)))
.thenReturn(ImmutableMap.of("header", "value"));
when(requestHeadersExtension.shouldUseCustomAuthHeader()).thenReturn(false);
// When:
final DefaultConnectClient connectClient =
- connectClientFactory.get(Optional.empty(), Collections.emptyList());
+ connectClientFactory.get(Optional.empty(), Collections.emptyList(), Optional.of(userPrincipal));
// Then:
assertThat(connectClient.getRequestHeaders(),
@@ -299,14 +307,14 @@ public void shouldFavorCustomAuthHeaderOverBasicAuthHeader() throws Exception {
// re-initialize client factory since request headers extension is configured in constructor
connectClientFactory = new DefaultConnectClientFactory(config);
- when(requestHeadersExtension.getHeaders())
+ when(requestHeadersExtension.getHeaders(Optional.of(userPrincipal)))
.thenReturn(ImmutableMap.of("header", "value"));
when(requestHeadersExtension.shouldUseCustomAuthHeader()).thenReturn(true);
when(requestHeadersExtension.getAuthHeader(incomingRequestHeaders)).thenReturn(Optional.of("some custom auth"));
// When:
final DefaultConnectClient connectClient =
- connectClientFactory.get(Optional.empty(), incomingRequestHeaders);
+ connectClientFactory.get(Optional.empty(), incomingRequestHeaders, Optional.of(userPrincipal));
// Then:
assertThat(connectClient.getRequestHeaders(),
diff --git a/ksqldb-engine/src/test/java/io/confluent/ksql/services/TestServiceContext.java b/ksqldb-engine/src/test/java/io/confluent/ksql/services/TestServiceContext.java
index 3486b81cdfe1..57de7b0281fc 100644
--- a/ksqldb-engine/src/test/java/io/confluent/ksql/services/TestServiceContext.java
+++ b/ksqldb-engine/src/test/java/io/confluent/ksql/services/TestServiceContext.java
@@ -114,7 +114,8 @@ public static ServiceContext create(
adminClient,
new KafkaTopicClientImpl(() -> adminClient),
srClientFactory,
- new DefaultConnectClientFactory(ksqlConfig).get(Optional.empty(), Collections.emptyList()),
+ new DefaultConnectClientFactory(ksqlConfig)
+ .get(Optional.empty(), Collections.emptyList(), Optional.empty()),
new KafkaConsumerGroupClientImpl(() -> adminClient)
);
}
diff --git a/ksqldb-rest-app/src/main/java/io/confluent/ksql/api/auth/ApiUser.java b/ksqldb-rest-app/src/main/java/io/confluent/ksql/api/auth/ApiUser.java
index b8ec31d2ae10..8a64254fb45c 100644
--- a/ksqldb-rest-app/src/main/java/io/confluent/ksql/api/auth/ApiUser.java
+++ b/ksqldb-rest-app/src/main/java/io/confluent/ksql/api/auth/ApiUser.java
@@ -15,10 +15,10 @@
package io.confluent.ksql.api.auth;
-import io.confluent.ksql.security.KsqlPrincipal;
+import io.confluent.ksql.security.DefaultKsqlPrincipal;
import io.vertx.ext.auth.User;
public interface ApiUser extends User {
- KsqlPrincipal getPrincipal();
+ DefaultKsqlPrincipal getPrincipal();
}
diff --git a/ksqldb-rest-app/src/main/java/io/confluent/ksql/api/auth/AuthenticationPluginHandler.java b/ksqldb-rest-app/src/main/java/io/confluent/ksql/api/auth/AuthenticationPluginHandler.java
index abac37fcd4fe..934253fb058e 100644
--- a/ksqldb-rest-app/src/main/java/io/confluent/ksql/api/auth/AuthenticationPluginHandler.java
+++ b/ksqldb-rest-app/src/main/java/io/confluent/ksql/api/auth/AuthenticationPluginHandler.java
@@ -26,7 +26,6 @@
import io.confluent.ksql.api.server.Server;
import io.confluent.ksql.rest.server.KsqlRestConfig;
import io.confluent.ksql.security.DefaultKsqlPrincipal;
-import io.confluent.ksql.security.KsqlPrincipal;
import io.vertx.core.AsyncResult;
import io.vertx.core.Handler;
import io.vertx.core.json.JsonObject;
@@ -101,7 +100,7 @@ public void handle(final RoutingContext routingContext) {
private static class AuthPluginUser implements ApiUser {
- private final KsqlPrincipal principal;
+ private final DefaultKsqlPrincipal principal;
AuthPluginUser(final Principal principal) {
Objects.requireNonNull(principal);
@@ -132,7 +131,7 @@ public void setAuthProvider(final AuthProvider authProvider) {
}
@Override
- public KsqlPrincipal getPrincipal() {
+ public DefaultKsqlPrincipal getPrincipal() {
return principal;
}
}
diff --git a/ksqldb-rest-app/src/main/java/io/confluent/ksql/api/auth/DefaultApiSecurityContext.java b/ksqldb-rest-app/src/main/java/io/confluent/ksql/api/auth/DefaultApiSecurityContext.java
index 1441e6f96ab9..47e5f9428315 100644
--- a/ksqldb-rest-app/src/main/java/io/confluent/ksql/api/auth/DefaultApiSecurityContext.java
+++ b/ksqldb-rest-app/src/main/java/io/confluent/ksql/api/auth/DefaultApiSecurityContext.java
@@ -37,8 +37,11 @@ public static DefaultApiSecurityContext create(final RoutingContext routingConte
final ApiUser apiUser = (ApiUser) user;
final String authToken = routingContext.request().getHeader("Authorization");
final List> requestHeaders = routingContext.request().headers().entries();
+ final String ipAddress = routingContext.request().remoteAddress().host();
return new DefaultApiSecurityContext(
- apiUser != null ? apiUser.getPrincipal() : null,
+ apiUser != null
+ ? apiUser.getPrincipal().withIpAddress(ipAddress == null ? "" : ipAddress)
+ : null,
authToken,
requestHeaders);
}
diff --git a/ksqldb-rest-app/src/main/java/io/confluent/ksql/api/auth/JaasAuthProvider.java b/ksqldb-rest-app/src/main/java/io/confluent/ksql/api/auth/JaasAuthProvider.java
index 0995c7f0e5fd..4617de9fc125 100644
--- a/ksqldb-rest-app/src/main/java/io/confluent/ksql/api/auth/JaasAuthProvider.java
+++ b/ksqldb-rest-app/src/main/java/io/confluent/ksql/api/auth/JaasAuthProvider.java
@@ -18,6 +18,7 @@
import com.google.common.annotations.VisibleForTesting;
import io.confluent.ksql.api.server.Server;
import io.confluent.ksql.rest.server.KsqlRestConfig;
+import io.confluent.ksql.security.DefaultKsqlPrincipal;
import io.confluent.ksql.security.KsqlPrincipal;
import io.vertx.core.AsyncResult;
import io.vertx.core.Future;
@@ -131,7 +132,7 @@ private void getUser(
final boolean authorized = validateRoles(lc, allowedRoles);
// if the subject from the login context is already a KsqlPrincipal, use the subject
- // directly rather than creating a new one
+ // (wrapped inside another DefaultKsqlPrincipal) rather than creating a new one
final Optional ksqlPrincipal = lc.getSubject().getPrincipals().stream()
.filter(p -> p instanceof KsqlPrincipal)
.map(p -> (KsqlPrincipal)p)
@@ -157,7 +158,7 @@ private static boolean validateRoles(final LoginContext lc, final List a
@SuppressWarnings("deprecation")
static class JaasUser extends io.vertx.ext.auth.AbstractUser implements ApiUser {
- private final KsqlPrincipal principal;
+ private final DefaultKsqlPrincipal principal;
private final boolean authorized;
JaasUser(
@@ -165,18 +166,21 @@ static class JaasUser extends io.vertx.ext.auth.AbstractUser implements ApiUser
final String password,
final boolean authorized
) {
- this(
- new JaasPrincipal(
- Objects.requireNonNull(username, "username"),
- Objects.requireNonNull(password, "password")),
- authorized);
+ this.principal = new JaasPrincipal(
+ Objects.requireNonNull(username, "username"),
+ Objects.requireNonNull(password, "password")
+ );
+ this.authorized = authorized;
}
JaasUser(
final KsqlPrincipal principal,
final boolean authorized
) {
- this.principal = Objects.requireNonNull(principal, "principal");
+ Objects.requireNonNull(principal, "principal");
+ this.principal = principal instanceof DefaultKsqlPrincipal
+ ? (DefaultKsqlPrincipal) principal
+ : new DefaultKsqlPrincipal(principal);
this.authorized = authorized;
}
@@ -198,7 +202,7 @@ public void setAuthProvider(final AuthProvider authProvider) {
}
@Override
- public KsqlPrincipal getPrincipal() {
+ public DefaultKsqlPrincipal getPrincipal() {
return principal;
}
}
diff --git a/ksqldb-rest-app/src/main/java/io/confluent/ksql/api/auth/JaasPrincipal.java b/ksqldb-rest-app/src/main/java/io/confluent/ksql/api/auth/JaasPrincipal.java
index f92ae7fcdbbe..76c549f2b1b5 100644
--- a/ksqldb-rest-app/src/main/java/io/confluent/ksql/api/auth/JaasPrincipal.java
+++ b/ksqldb-rest-app/src/main/java/io/confluent/ksql/api/auth/JaasPrincipal.java
@@ -15,8 +15,9 @@
package io.confluent.ksql.api.auth;
-import io.confluent.ksql.security.KsqlPrincipal;
+import io.confluent.ksql.security.DefaultKsqlPrincipal;
import java.nio.charset.StandardCharsets;
+import java.security.Principal;
import java.util.Base64;
import java.util.Collections;
import java.util.Map;
@@ -29,14 +30,22 @@
* extensions.
*
*/
-public class JaasPrincipal implements KsqlPrincipal {
+public class JaasPrincipal extends DefaultKsqlPrincipal {
private final String name;
+ private final String password;
private final String token;
public JaasPrincipal(final String name, final String password) {
+ this(name, password, "");
+ }
+
+ private JaasPrincipal(final String name, final String password, final String ipAddress) {
+ super(new BasicJaasPrincipal(name), ipAddress);
+
this.name = Objects.requireNonNull(name, "name");
- this.token = createToken(name, Objects.requireNonNull(password));
+ this.password = Objects.requireNonNull(password, "password");
+ this.token = createToken(name, password);
}
@Override
@@ -49,13 +58,36 @@ public Map getUserProperties() {
return Collections.emptyMap();
}
- private String createToken(final String name, final String secret) {
+ public String getToken() {
+ return token;
+ }
+
+ /**
+ * Preserve token functionality by returning another JaasPrincipal when the
+ * IP address is set from the routing context.
+ */
+ @Override
+ public DefaultKsqlPrincipal withIpAddress(final String ipAddress) {
+ return new JaasPrincipal(name, password, ipAddress);
+ }
+
+ private static String createToken(final String name, final String secret) {
return Base64.getEncoder().encodeToString((name + ":" + secret)
.getBytes(StandardCharsets.ISO_8859_1));
}
- public String getToken() {
- return token;
+ static class BasicJaasPrincipal implements Principal {
+
+ private final String name;
+
+ BasicJaasPrincipal(final String name) {
+ this.name = name;
+ }
+
+ @Override
+ public String getName() {
+ return name;
+ }
}
}
diff --git a/ksqldb-rest-app/src/main/java/io/confluent/ksql/api/auth/SystemAuthenticationHandler.java b/ksqldb-rest-app/src/main/java/io/confluent/ksql/api/auth/SystemAuthenticationHandler.java
index c0997186b84c..351eff231ad3 100644
--- a/ksqldb-rest-app/src/main/java/io/confluent/ksql/api/auth/SystemAuthenticationHandler.java
+++ b/ksqldb-rest-app/src/main/java/io/confluent/ksql/api/auth/SystemAuthenticationHandler.java
@@ -16,7 +16,6 @@
package io.confluent.ksql.api.auth;
import io.confluent.ksql.security.DefaultKsqlPrincipal;
-import io.confluent.ksql.security.KsqlPrincipal;
import io.vertx.core.AsyncResult;
import io.vertx.core.Handler;
import io.vertx.core.http.HttpConnection;
@@ -61,7 +60,7 @@ public static boolean isAuthenticatedAsSystemUser(final RoutingContext routingCo
private static class SystemUser implements ApiUser {
- private final KsqlPrincipal principal;
+ private final DefaultKsqlPrincipal principal;
SystemUser(final Principal principal) {
Objects.requireNonNull(principal);
@@ -94,7 +93,7 @@ public void setAuthProvider(final AuthProvider authProvider) {
}
@Override
- public KsqlPrincipal getPrincipal() {
+ public DefaultKsqlPrincipal getPrincipal() {
return principal;
}
}
diff --git a/ksqldb-rest-app/src/main/java/io/confluent/ksql/api/impl/DefaultKsqlSecurityContextProvider.java b/ksqldb-rest-app/src/main/java/io/confluent/ksql/api/impl/DefaultKsqlSecurityContextProvider.java
index 7a46ba28bf44..b8cb449ad41c 100644
--- a/ksqldb-rest-app/src/main/java/io/confluent/ksql/api/impl/DefaultKsqlSecurityContextProvider.java
+++ b/ksqldb-rest-app/src/main/java/io/confluent/ksql/api/impl/DefaultKsqlSecurityContextProvider.java
@@ -86,7 +86,8 @@ public KsqlSecurityContext provide(final ApiSecurityContext apiSecurityContext)
schemaRegistryClientFactory,
connectClientFactory,
sharedClient,
- requestHeaders)
+ requestHeaders,
+ principal)
);
}
@@ -100,7 +101,8 @@ public KsqlSecurityContext provide(final ApiSecurityContext apiSecurityContext)
provider.getSchemaRegistryClientFactory(principal.get()),
connectClientFactory,
sharedClient,
- requestHeaders)))
+ requestHeaders,
+ principal)))
.get();
}
diff --git a/ksqldb-rest-app/src/main/java/io/confluent/ksql/rest/server/KsqlRestApplication.java b/ksqldb-rest-app/src/main/java/io/confluent/ksql/rest/server/KsqlRestApplication.java
index f9d89a079642..b1d93f21596b 100644
--- a/ksqldb-rest-app/src/main/java/io/confluent/ksql/rest/server/KsqlRestApplication.java
+++ b/ksqldb-rest-app/src/main/java/io/confluent/ksql/rest/server/KsqlRestApplication.java
@@ -643,7 +643,7 @@ public static KsqlRestApplication buildApplication(
final ServiceContext tempServiceContext = new LazyServiceContext(() ->
RestServiceContextFactory.create(ksqlConfig, Optional.empty(),
schemaRegistryClientFactory, connectClientFactory, sharedClient,
- Collections.emptyList()));
+ Collections.emptyList(), Optional.empty()));
final String kafkaClusterId = KafkaClusterUtil.getKafkaClusterId(tempServiceContext);
final String ksqlServerId = ksqlConfig.getString(KsqlConfig.KSQL_SERVICE_ID_CONFIG);
updatedRestProps.putAll(
@@ -657,7 +657,8 @@ public static KsqlRestApplication buildApplication(
schemaRegistryClientFactory,
connectClientFactory,
sharedClient,
- Collections.emptyList()));
+ Collections.emptyList(),
+ Optional.empty()));
return buildApplication(
"",
diff --git a/ksqldb-rest-app/src/main/java/io/confluent/ksql/rest/server/services/RestServiceContextFactory.java b/ksqldb-rest-app/src/main/java/io/confluent/ksql/rest/server/services/RestServiceContextFactory.java
index c7a4510b0e1f..831e54311250 100644
--- a/ksqldb-rest-app/src/main/java/io/confluent/ksql/rest/server/services/RestServiceContextFactory.java
+++ b/ksqldb-rest-app/src/main/java/io/confluent/ksql/rest/server/services/RestServiceContextFactory.java
@@ -17,6 +17,7 @@
import io.confluent.kafka.schemaregistry.client.SchemaRegistryClient;
import io.confluent.ksql.rest.client.KsqlClient;
+import io.confluent.ksql.security.KsqlPrincipal;
import io.confluent.ksql.services.ConnectClientFactory;
import io.confluent.ksql.services.ServiceContext;
import io.confluent.ksql.services.ServiceContextFactory;
@@ -41,7 +42,8 @@ ServiceContext create(
Supplier srClientFactory,
ConnectClientFactory connectClientFactory,
KsqlClient sharedClient,
- List> requestHeaders
+ List> requestHeaders,
+ Optional userPrincipal
);
}
@@ -54,7 +56,8 @@ ServiceContext create(
Supplier srClientFactory,
ConnectClientFactory connectClientFactory,
KsqlClient sharedClient,
- List> requestHeaders
+ List> requestHeaders,
+ Optional userPrincipal
);
}
@@ -64,7 +67,8 @@ public static ServiceContext create(
final Supplier schemaRegistryClientFactory,
final ConnectClientFactory connectClientFactory,
final KsqlClient sharedClient,
- final List> requestHeaders
+ final List> requestHeaders,
+ final Optional userPrincipal
) {
return create(
ksqlConfig,
@@ -73,7 +77,8 @@ public static ServiceContext create(
schemaRegistryClientFactory,
connectClientFactory,
sharedClient,
- requestHeaders
+ requestHeaders,
+ userPrincipal
);
}
@@ -84,13 +89,14 @@ public static ServiceContext create(
final Supplier srClientFactory,
final ConnectClientFactory connectClientFactory,
final KsqlClient sharedClient,
- final List> requestHeaders
+ final List> requestHeaders,
+ final Optional userPrincipal
) {
return ServiceContextFactory.create(
ksqlConfig,
kafkaClientSupplier,
srClientFactory,
- () -> connectClientFactory.get(authHeader, requestHeaders),
+ () -> connectClientFactory.get(authHeader, requestHeaders, userPrincipal),
() -> new DefaultKsqlClient(authHeader, sharedClient)
);
}
diff --git a/ksqldb-rest-app/src/test/java/io/confluent/ksql/api/impl/DefaultKsqlSecurityContextProviderTest.java b/ksqldb-rest-app/src/test/java/io/confluent/ksql/api/impl/DefaultKsqlSecurityContextProviderTest.java
index 4079a29061c7..83d4348850e2 100644
--- a/ksqldb-rest-app/src/test/java/io/confluent/ksql/api/impl/DefaultKsqlSecurityContextProviderTest.java
+++ b/ksqldb-rest-app/src/test/java/io/confluent/ksql/api/impl/DefaultKsqlSecurityContextProviderTest.java
@@ -81,7 +81,7 @@ public void setup() {
userServiceContextFactory,
ksqlConfig,
() -> schemaRegistryClientFactory,
- (authHeader, userPrincipal) -> connectClient,
+ (authHeader, requestHeaders, userPrincipal) -> connectClient,
ksqlClient
);
@@ -89,9 +89,9 @@ public void setup() {
when(apiSecurityContext.getAuthToken()).thenReturn(Optional.empty());
when(apiSecurityContext.getRequestHeaders()).thenReturn(incomingRequestHeaders);
- when(defaultServiceContextFactory.create(any(), any(), any(), any(), any(), any()))
+ when(defaultServiceContextFactory.create(any(), any(), any(), any(), any(), any(), any()))
.thenReturn(defaultServiceContext);
- when(userServiceContextFactory.create(any(), any(), any(), any(), any(), any(), any()))
+ when(userServiceContextFactory.create(any(), any(), any(), any(), any(), any(), any(), any()))
.thenReturn(userServiceContext);
}
@@ -135,7 +135,7 @@ public void shouldCreateUserServiceContextIfUserContextProviderIsEnabled() {
// Then:
verify(userServiceContextFactory)
- .create(eq(ksqlConfig), eq(Optional.empty()), any(), any(), any(), any(), any());
+ .create(eq(ksqlConfig), eq(Optional.empty()), any(), any(), any(), any(), any(), any());
assertThat(ksqlSecurityContext.getUserPrincipal(), is(Optional.of(user1)));
assertThat(ksqlSecurityContext.getServiceContext(), is(userServiceContext));
}
@@ -150,7 +150,7 @@ public void shouldPassAuthHeaderToDefaultFactory() {
ksqlSecurityContextProvider.provide(apiSecurityContext);
// Then:
- verify(defaultServiceContextFactory).create(any(), eq(Optional.of("some-auth")), any(), any(), any(), any());
+ verify(defaultServiceContextFactory).create(any(), eq(Optional.of("some-auth")), any(), any(), any(), any(), any());
}
@Test
@@ -164,7 +164,7 @@ public void shouldPassAuthHeaderToUserFactory() {
// Then:
verify(userServiceContextFactory)
- .create(any(), eq(Optional.of("some-auth")), any(), any(), any(), any(), any());
+ .create(any(), eq(Optional.of("some-auth")), any(), any(), any(), any(), any(), any());
}
@Test
@@ -176,7 +176,7 @@ public void shouldPassRequestHeadersToDefaultFactory() {
ksqlSecurityContextProvider.provide(apiSecurityContext);
// Then:
- verify(defaultServiceContextFactory).create(any(), any(), any(), any(), any(), eq(incomingRequestHeaders));
+ verify(defaultServiceContextFactory).create(any(), any(), any(), any(), any(), eq(incomingRequestHeaders), any());
}
@Test
@@ -189,6 +189,31 @@ public void shouldPassRequestHeadersToUserFactory() {
// Then:
verify(userServiceContextFactory)
- .create(any(), any(), any(), any(), any(), any(), eq(incomingRequestHeaders));
+ .create(any(), any(), any(), any(), any(), any(), eq(incomingRequestHeaders), any());
+ }
+
+ @Test
+ public void shouldPassUserPrincipalToDefaultFactory() {
+ // Given:
+ when(securityExtension.getUserContextProvider()).thenReturn(Optional.empty());
+
+ // When:
+ ksqlSecurityContextProvider.provide(apiSecurityContext);
+
+ // Then:
+ verify(defaultServiceContextFactory).create(any(), any(), any(), any(), any(), any(), eq(Optional.of(user1)));
+ }
+
+ @Test
+ public void shouldPassUserPrincipalToUserFactory() {
+ // Given:
+ when(securityExtension.getUserContextProvider()).thenReturn(Optional.of(userContextProvider));
+
+ // When:
+ ksqlSecurityContextProvider.provide(apiSecurityContext);
+
+ // Then:
+ verify(userServiceContextFactory)
+ .create(any(), any(), any(), any(), any(), any(), any(), eq(Optional.of(user1)));
}
}
diff --git a/ksqldb-rest-app/src/test/java/io/confluent/ksql/rest/server/TestKsqlRestApp.java b/ksqldb-rest-app/src/test/java/io/confluent/ksql/rest/server/TestKsqlRestApp.java
index e899b5161d86..0ea43fa19cc0 100644
--- a/ksqldb-rest-app/src/test/java/io/confluent/ksql/rest/server/TestKsqlRestApp.java
+++ b/ksqldb-rest-app/src/test/java/io/confluent/ksql/rest/server/TestKsqlRestApp.java
@@ -337,7 +337,7 @@ protected void initialize() {
3,
serviceContext.get(),
() -> serviceContext.get().getSchemaRegistryClient(),
- (authHeader, userPrincipal) -> serviceContext.get().getConnectClient(),
+ (authHeader, requestHeaders, userPrincipal) -> serviceContext.get().getConnectClient(),
vertx,
InternalKsqlClientFactory.createInternalClient(
PropertiesUtil.toMapStrings(ksqlRestConfig.originals()),
diff --git a/ksqldb-rest-app/src/test/java/io/confluent/ksql/rest/server/services/TestRestServiceContextFactory.java b/ksqldb-rest-app/src/test/java/io/confluent/ksql/rest/server/services/TestRestServiceContextFactory.java
index a1b9d43912f9..d9f532e1b55f 100644
--- a/ksqldb-rest-app/src/test/java/io/confluent/ksql/rest/server/services/TestRestServiceContextFactory.java
+++ b/ksqldb-rest-app/src/test/java/io/confluent/ksql/rest/server/services/TestRestServiceContextFactory.java
@@ -22,8 +22,8 @@ public interface InternalSimpleKsqlClientFactory {
public static DefaultServiceContextFactory createDefault(
final InternalSimpleKsqlClientFactory ksqlClientFactory
) {
- return (ksqlConfig, authHeader, srClientFactory,
- connectClientFactory, sharedClient, userPrincipal) -> {
+ return (ksqlConfig, authHeader, srClientFactory, connectClientFactory,
+ sharedClient, requestHeaders, userPrincipal) -> {
return createUser(ksqlClientFactory).create(
ksqlConfig,
authHeader,
@@ -31,6 +31,7 @@ public static DefaultServiceContextFactory createDefault(
srClientFactory,
connectClientFactory,
sharedClient,
+ requestHeaders,
userPrincipal
);
};
@@ -39,8 +40,8 @@ public static DefaultServiceContextFactory createDefault(
public static UserServiceContextFactory createUser(
final InternalSimpleKsqlClientFactory ksqlClientFactory
) {
- return (ksqlConfig, authHeader, kafkaClientSupplier,
- srClientFactory, connectClientFactory, sharedClient, userPrincipal) -> {
+ return (ksqlConfig, authHeader, kafkaClientSupplier, srClientFactory,
+ connectClientFactory, sharedClient, requestHeaders, userPrincipal) -> {
return ServiceContextFactory.create(
ksqlConfig,
kafkaClientSupplier,