diff --git a/ksqldb-engine/src/main/java/io/confluent/ksql/connect/ConnectRequestHeadersExtension.java b/ksqldb-engine/src/main/java/io/confluent/ksql/connect/ConnectRequestHeadersExtension.java index ae267ed4bde8..ee547841ec51 100644 --- a/ksqldb-engine/src/main/java/io/confluent/ksql/connect/ConnectRequestHeadersExtension.java +++ b/ksqldb-engine/src/main/java/io/confluent/ksql/connect/ConnectRequestHeadersExtension.java @@ -15,6 +15,7 @@ package io.confluent.ksql.connect; +import io.confluent.ksql.security.KsqlPrincipal; import java.util.Collections; import java.util.List; import java.util.Map; @@ -39,7 +40,7 @@ public interface ConnectRequestHeadersExtension { * *

Set this to {@code false} in order to use this * {@code ConnectRequestHeadersExtension} for additional custom headers - * (via {@link ConnectRequestHeadersExtension#getHeaders()}) only, without + * (via {@link ConnectRequestHeadersExtension#getHeaders(Optional)}) only, without * impacting ksqlDB's default behavior for the auth header. * * @return whether to use the custom auth header returned from @@ -68,9 +69,12 @@ default boolean shouldUseCustomAuthHeader() { * such as those required for authenticating with Connect. The custom headers are added * to the request in addition to, and after, ksqlDB's default headers. * + * @param userPrincipal principal associated with the user who submitted the connector + * request to ksqlDB, if present (i.e., if user authentication + * is enabled) * @return additional headers to be included with connector requests made by ksqlDB */ - default Map getHeaders() { + default Map getHeaders(Optional userPrincipal) { return Collections.emptyMap(); } diff --git a/ksqldb-engine/src/main/java/io/confluent/ksql/security/DefaultKsqlPrincipal.java b/ksqldb-engine/src/main/java/io/confluent/ksql/security/DefaultKsqlPrincipal.java index 58f7596fb806..6205510cbb3e 100644 --- a/ksqldb-engine/src/main/java/io/confluent/ksql/security/DefaultKsqlPrincipal.java +++ b/ksqldb-engine/src/main/java/io/confluent/ksql/security/DefaultKsqlPrincipal.java @@ -28,9 +28,15 @@ public class DefaultKsqlPrincipal implements KsqlPrincipal { private final Principal principal; + private final String ipAddress; public DefaultKsqlPrincipal(final Principal principal) { + this(principal, ""); + } + + protected DefaultKsqlPrincipal(final Principal principal, final String ipAddress) { this.principal = Objects.requireNonNull(principal, "principal"); + this.ipAddress = Objects.requireNonNull(ipAddress, "ipAddress"); } @Override @@ -54,4 +60,17 @@ public Map getUserProperties() { public Principal getOriginalPrincipal() { return principal; } + + @Override + public String getIpAddress() { + return ipAddress; + } + + /** + * IP address is populated from the request context, and subsequently passed + * throughout the engine. + */ + public DefaultKsqlPrincipal withIpAddress(final String ipAddress) { + return new DefaultKsqlPrincipal(principal, ipAddress); + } } diff --git a/ksqldb-engine/src/main/java/io/confluent/ksql/security/KsqlPrincipal.java b/ksqldb-engine/src/main/java/io/confluent/ksql/security/KsqlPrincipal.java index dff227e583f1..bd68ac7edc07 100644 --- a/ksqldb-engine/src/main/java/io/confluent/ksql/security/KsqlPrincipal.java +++ b/ksqldb-engine/src/main/java/io/confluent/ksql/security/KsqlPrincipal.java @@ -31,4 +31,21 @@ default Map getUserProperties() { return Collections.emptyMap(); } + /** + * Returns the user's IP address, as set by the ksqlDB server's request context. + * + *

This method never returns {@code null}. An empty string may be returned in + * certain situations (those where an IP address is not available, e.g., domain + * socket requests). + * + *

Overriding the implementation of this method from custom {@code KsqlPrincipal} + * implementations has no effect on the return value of this method, when called from + * custom extensions, because incoming {@code KsqlPrincipal} instances are wrapped + * inside ksqlDB's own {@link DefaultKsqlPrincipal} before being passed throughout + * the ksqlDB engine, and {@code DefaultKsqlPrincipal} has its own implementation + * for tracking and returning the IP address from this method. + */ + default String getIpAddress() { + return ""; + } } diff --git a/ksqldb-engine/src/main/java/io/confluent/ksql/services/ConnectClientFactory.java b/ksqldb-engine/src/main/java/io/confluent/ksql/services/ConnectClientFactory.java index 307985c3f34d..402e3097d310 100644 --- a/ksqldb-engine/src/main/java/io/confluent/ksql/services/ConnectClientFactory.java +++ b/ksqldb-engine/src/main/java/io/confluent/ksql/services/ConnectClientFactory.java @@ -15,6 +15,7 @@ package io.confluent.ksql.services; +import io.confluent.ksql.security.KsqlPrincipal; import java.util.List; import java.util.Map.Entry; import java.util.Optional; @@ -23,7 +24,8 @@ public interface ConnectClientFactory { ConnectClient get( Optional authHeader, - List> incomingRequestHeaders + List> incomingRequestHeaders, + Optional userPrincipal ); default void close() {} diff --git a/ksqldb-engine/src/main/java/io/confluent/ksql/services/DefaultConnectClientFactory.java b/ksqldb-engine/src/main/java/io/confluent/ksql/services/DefaultConnectClientFactory.java index cf54fbbd37fc..08656bb24dfc 100644 --- a/ksqldb-engine/src/main/java/io/confluent/ksql/services/DefaultConnectClientFactory.java +++ b/ksqldb-engine/src/main/java/io/confluent/ksql/services/DefaultConnectClientFactory.java @@ -17,6 +17,7 @@ import com.google.common.annotations.VisibleForTesting; import io.confluent.ksql.connect.ConnectRequestHeadersExtension; +import io.confluent.ksql.security.KsqlPrincipal; import io.confluent.ksql.util.FileWatcher; import io.confluent.ksql.util.KsqlConfig; import java.io.FileInputStream; @@ -78,7 +79,8 @@ public DefaultConnectClientFactory( @Override public synchronized DefaultConnectClient get( final Optional ksqlAuthHeader, - final List> incomingRequestHeaders + final List> incomingRequestHeaders, + final Optional userPrincipal ) { if (defaultConnectAuthHeader == null) { defaultConnectAuthHeader = buildDefaultAuthHeader(); @@ -91,7 +93,7 @@ public synchronized DefaultConnectClient get( ksqlConfig.getString(KsqlConfig.CONNECT_URL_PROPERTY), buildAuthHeader(ksqlAuthHeader, incomingRequestHeaders), requestHeadersExtension - .map(ConnectRequestHeadersExtension::getHeaders) + .map(extension -> extension.getHeaders(userPrincipal)) .orElse(Collections.emptyMap()), Optional.ofNullable(newSslContext(configWithPrefixOverrides)), shouldVerifySslHostname(configWithPrefixOverrides) diff --git a/ksqldb-engine/src/main/java/io/confluent/ksql/services/ServiceContextFactory.java b/ksqldb-engine/src/main/java/io/confluent/ksql/services/ServiceContextFactory.java index 006f67f90018..2aacd858e226 100644 --- a/ksqldb-engine/src/main/java/io/confluent/ksql/services/ServiceContextFactory.java +++ b/ksqldb-engine/src/main/java/io/confluent/ksql/services/ServiceContextFactory.java @@ -42,7 +42,8 @@ public static ServiceContext create( Collections.emptyMap())::get, () -> new DefaultConnectClientFactory(ksqlConfig).get( Optional.empty(), - Collections.emptyList()), + Collections.emptyList(), + Optional.empty()), ksqlClientSupplier ); } diff --git a/ksqldb-engine/src/test/java/io/confluent/ksql/embedded/KsqlContextTestUtil.java b/ksqldb-engine/src/test/java/io/confluent/ksql/embedded/KsqlContextTestUtil.java index 63ef97a2a46b..b2610b61a3b6 100644 --- a/ksqldb-engine/src/test/java/io/confluent/ksql/embedded/KsqlContextTestUtil.java +++ b/ksqldb-engine/src/test/java/io/confluent/ksql/embedded/KsqlContextTestUtil.java @@ -61,7 +61,8 @@ public static KsqlContext create( adminClient, kafkaTopicClient, () -> schemaRegistryClient, - new DefaultConnectClientFactory(ksqlConfig).get(Optional.empty(), Collections.emptyList()) + new DefaultConnectClientFactory(ksqlConfig) + .get(Optional.empty(), Collections.emptyList(), Optional.empty()) ); final String metricsPrefix = "instance-" + COUNTER.getAndIncrement() + "-"; diff --git a/ksqldb-engine/src/test/java/io/confluent/ksql/services/DefaultConnectClientFactoryTest.java b/ksqldb-engine/src/test/java/io/confluent/ksql/services/DefaultConnectClientFactoryTest.java index 36f0f1f79605..4a6461a2e24c 100644 --- a/ksqldb-engine/src/test/java/io/confluent/ksql/services/DefaultConnectClientFactoryTest.java +++ b/ksqldb-engine/src/test/java/io/confluent/ksql/services/DefaultConnectClientFactoryTest.java @@ -27,6 +27,7 @@ import com.google.common.collect.ImmutableMap; import io.confluent.ksql.connect.ConnectRequestHeadersExtension; +import io.confluent.ksql.security.KsqlPrincipal; import io.confluent.ksql.test.util.KsqlTestFolder; import io.confluent.ksql.util.KsqlConfig; import io.vertx.core.http.HttpHeaders; @@ -96,6 +97,8 @@ public class DefaultConnectClientFactoryTest { private ConnectRequestHeadersExtension requestHeadersExtension; @Mock private List> incomingRequestHeaders; + @Mock + private KsqlPrincipal userPrincipal; private String credentialsPath; @@ -107,7 +110,8 @@ public void setUp() { when(config.getString(KsqlConfig.CONNECT_URL_PROPERTY)).thenReturn("http://localhost:8034"); when(config.getString(KsqlConfig.CONNECT_BASIC_AUTH_CREDENTIALS_SOURCE_PROPERTY)).thenReturn("NONE"); - when(config.valuesWithPrefixOverride(KsqlConfig.KSQL_CONNECT_PREFIX)).thenReturn(DEFAULT_CONFIGS_WITH_PREFIX_OVERRIDE); + when(config.valuesWithPrefixOverride(KsqlConfig.KSQL_CONNECT_PREFIX)) + .thenReturn(DEFAULT_CONFIGS_WITH_PREFIX_OVERRIDE); connectClientFactory = new DefaultConnectClientFactory(config); } @@ -115,7 +119,8 @@ public void setUp() { @Test public void shouldBuildWithoutAuthHeader() { // When: - final DefaultConnectClient connectClient = connectClientFactory.get(Optional.empty(), Collections.emptyList()); + final DefaultConnectClient connectClient = + connectClientFactory.get(Optional.empty(), Collections.emptyList(), Optional.empty()); // Then: assertThat(connectClient.getRequestHeaders(), is(EMPTY_HEADERS)); @@ -128,7 +133,8 @@ public void shouldBuildAuthHeader() throws Exception { givenValidCredentialsFile(); // When: - final DefaultConnectClient connectClient = connectClientFactory.get(Optional.empty(), Collections.emptyList()); + final DefaultConnectClient connectClient = + connectClientFactory.get(Optional.empty(), Collections.emptyList(), Optional.empty()); // Then: assertThat(connectClient.getRequestHeaders(), @@ -142,8 +148,8 @@ public void shouldBuildAuthHeaderOnlyOnce() throws Exception { givenValidCredentialsFile(); // When: get() is called twice - connectClientFactory.get(Optional.empty(), Collections.emptyList()); - connectClientFactory.get(Optional.empty(), Collections.emptyList()); + connectClientFactory.get(Optional.empty(), Collections.emptyList(), Optional.empty()); + connectClientFactory.get(Optional.empty(), Collections.emptyList(), Optional.empty()); // Then: only loaded the credentials once -- ideally we'd check the number of times the file // was read but this is an acceptable proxy for this unit test @@ -154,7 +160,7 @@ public void shouldBuildAuthHeaderOnlyOnce() throws Exception { public void shouldUseKsqlAuthHeaderIfNoAuthHeaderPresent() { // When: final DefaultConnectClient connectClient = - connectClientFactory.get(Optional.of("some ksql request header"), Collections.emptyList()); + connectClientFactory.get(Optional.of("some ksql request header"), Collections.emptyList(), Optional.empty()); // Then: assertThat(connectClient.getRequestHeaders(), @@ -168,7 +174,8 @@ public void shouldNotFailOnUnreadableCredentials() throws Exception { givenInvalidCredentialsFiles(); // When: - final DefaultConnectClient connectClient = connectClientFactory.get(Optional.empty(), Collections.emptyList()); + final DefaultConnectClient connectClient = + connectClientFactory.get(Optional.empty(), Collections.emptyList(), Optional.empty()); // Then: assertThat(connectClient.getRequestHeaders(), is(EMPTY_HEADERS)); @@ -181,7 +188,8 @@ public void shouldNotFailOnMissingCredentials() { // no credentials file present // When: - final DefaultConnectClient connectClient = connectClientFactory.get(Optional.empty(), Collections.emptyList()); + final DefaultConnectClient connectClient = + connectClientFactory.get(Optional.empty(), Collections.emptyList(), Optional.empty()); // Then: assertThat(connectClient.getRequestHeaders(), is(EMPTY_HEADERS)); @@ -195,7 +203,7 @@ public void shouldReloadCredentialsOnFileCreation() throws Exception { // no credentials file present // verify that no auth header is present - assertThat(connectClientFactory.get(Optional.empty(), Collections.emptyList()).getRequestHeaders(), + assertThat(connectClientFactory.get(Optional.empty(), Collections.emptyList(), Optional.empty()).getRequestHeaders(), is(EMPTY_HEADERS)); // When: credentials file is created @@ -205,7 +213,7 @@ public void shouldReloadCredentialsOnFileCreation() throws Exception { // Then: auth header is present assertThatEventually( "Should load newly created credentials", - () -> connectClientFactory.get(Optional.empty(), Collections.emptyList()).getRequestHeaders(), + () -> connectClientFactory.get(Optional.empty(), Collections.emptyList(), Optional.empty()).getRequestHeaders(), arrayContaining(header(AUTH_HEADER_NAME, EXPECTED_HEADER)), TimeUnit.SECONDS.toMillis(1), TimeUnit.SECONDS.toMillis(1) @@ -220,7 +228,7 @@ public void shouldReloadCredentialsOnFileChange() throws Exception { givenValidCredentialsFile(); // verify auth header is present - assertThat(connectClientFactory.get(Optional.empty(), Collections.emptyList()).getRequestHeaders(), + assertThat(connectClientFactory.get(Optional.empty(), Collections.emptyList(), Optional.empty()).getRequestHeaders(), arrayContaining(header(AUTH_HEADER_NAME, EXPECTED_HEADER))); // When: credentials file is modified @@ -230,7 +238,7 @@ public void shouldReloadCredentialsOnFileChange() throws Exception { // Then: new auth header is present assertThatEventually( "Should load updated credentials", - () -> connectClientFactory.get(Optional.empty(), Collections.emptyList()).getRequestHeaders(), + () -> connectClientFactory.get(Optional.empty(), Collections.emptyList(), Optional.empty()).getRequestHeaders(), arrayContaining(header(AUTH_HEADER_NAME, OTHER_EXPECTED_HEADER)), TimeUnit.SECONDS.toMillis(1), TimeUnit.SECONDS.toMillis(1) @@ -247,12 +255,12 @@ public void shouldPassCustomRequestHeaders() { // re-initialize client factory since request headers extension is configured in constructor connectClientFactory = new DefaultConnectClientFactory(config); - when(requestHeadersExtension.getHeaders()) + when(requestHeadersExtension.getHeaders(Optional.of(userPrincipal))) .thenReturn(ImmutableMap.of("header", "value")); // When: final DefaultConnectClient connectClient = - connectClientFactory.get(Optional.empty(), Collections.emptyList()); + connectClientFactory.get(Optional.empty(), Collections.emptyList(), Optional.of(userPrincipal)); // Then: assertThat(connectClient.getRequestHeaders(), arrayContaining(header("header", "value"))); @@ -271,13 +279,13 @@ public void shouldPassCustomRequestHeadersInAdditionToDefaultBasicAuthHeader() t // re-initialize client factory since request headers extension is configured in constructor connectClientFactory = new DefaultConnectClientFactory(config); - when(requestHeadersExtension.getHeaders()) + when(requestHeadersExtension.getHeaders(Optional.of(userPrincipal))) .thenReturn(ImmutableMap.of("header", "value")); when(requestHeadersExtension.shouldUseCustomAuthHeader()).thenReturn(false); // When: final DefaultConnectClient connectClient = - connectClientFactory.get(Optional.empty(), Collections.emptyList()); + connectClientFactory.get(Optional.empty(), Collections.emptyList(), Optional.of(userPrincipal)); // Then: assertThat(connectClient.getRequestHeaders(), @@ -299,14 +307,14 @@ public void shouldFavorCustomAuthHeaderOverBasicAuthHeader() throws Exception { // re-initialize client factory since request headers extension is configured in constructor connectClientFactory = new DefaultConnectClientFactory(config); - when(requestHeadersExtension.getHeaders()) + when(requestHeadersExtension.getHeaders(Optional.of(userPrincipal))) .thenReturn(ImmutableMap.of("header", "value")); when(requestHeadersExtension.shouldUseCustomAuthHeader()).thenReturn(true); when(requestHeadersExtension.getAuthHeader(incomingRequestHeaders)).thenReturn(Optional.of("some custom auth")); // When: final DefaultConnectClient connectClient = - connectClientFactory.get(Optional.empty(), incomingRequestHeaders); + connectClientFactory.get(Optional.empty(), incomingRequestHeaders, Optional.of(userPrincipal)); // Then: assertThat(connectClient.getRequestHeaders(), diff --git a/ksqldb-engine/src/test/java/io/confluent/ksql/services/TestServiceContext.java b/ksqldb-engine/src/test/java/io/confluent/ksql/services/TestServiceContext.java index 3486b81cdfe1..57de7b0281fc 100644 --- a/ksqldb-engine/src/test/java/io/confluent/ksql/services/TestServiceContext.java +++ b/ksqldb-engine/src/test/java/io/confluent/ksql/services/TestServiceContext.java @@ -114,7 +114,8 @@ public static ServiceContext create( adminClient, new KafkaTopicClientImpl(() -> adminClient), srClientFactory, - new DefaultConnectClientFactory(ksqlConfig).get(Optional.empty(), Collections.emptyList()), + new DefaultConnectClientFactory(ksqlConfig) + .get(Optional.empty(), Collections.emptyList(), Optional.empty()), new KafkaConsumerGroupClientImpl(() -> adminClient) ); } diff --git a/ksqldb-rest-app/src/main/java/io/confluent/ksql/api/auth/ApiUser.java b/ksqldb-rest-app/src/main/java/io/confluent/ksql/api/auth/ApiUser.java index b8ec31d2ae10..8a64254fb45c 100644 --- a/ksqldb-rest-app/src/main/java/io/confluent/ksql/api/auth/ApiUser.java +++ b/ksqldb-rest-app/src/main/java/io/confluent/ksql/api/auth/ApiUser.java @@ -15,10 +15,10 @@ package io.confluent.ksql.api.auth; -import io.confluent.ksql.security.KsqlPrincipal; +import io.confluent.ksql.security.DefaultKsqlPrincipal; import io.vertx.ext.auth.User; public interface ApiUser extends User { - KsqlPrincipal getPrincipal(); + DefaultKsqlPrincipal getPrincipal(); } diff --git a/ksqldb-rest-app/src/main/java/io/confluent/ksql/api/auth/AuthenticationPluginHandler.java b/ksqldb-rest-app/src/main/java/io/confluent/ksql/api/auth/AuthenticationPluginHandler.java index abac37fcd4fe..934253fb058e 100644 --- a/ksqldb-rest-app/src/main/java/io/confluent/ksql/api/auth/AuthenticationPluginHandler.java +++ b/ksqldb-rest-app/src/main/java/io/confluent/ksql/api/auth/AuthenticationPluginHandler.java @@ -26,7 +26,6 @@ import io.confluent.ksql.api.server.Server; import io.confluent.ksql.rest.server.KsqlRestConfig; import io.confluent.ksql.security.DefaultKsqlPrincipal; -import io.confluent.ksql.security.KsqlPrincipal; import io.vertx.core.AsyncResult; import io.vertx.core.Handler; import io.vertx.core.json.JsonObject; @@ -101,7 +100,7 @@ public void handle(final RoutingContext routingContext) { private static class AuthPluginUser implements ApiUser { - private final KsqlPrincipal principal; + private final DefaultKsqlPrincipal principal; AuthPluginUser(final Principal principal) { Objects.requireNonNull(principal); @@ -132,7 +131,7 @@ public void setAuthProvider(final AuthProvider authProvider) { } @Override - public KsqlPrincipal getPrincipal() { + public DefaultKsqlPrincipal getPrincipal() { return principal; } } diff --git a/ksqldb-rest-app/src/main/java/io/confluent/ksql/api/auth/DefaultApiSecurityContext.java b/ksqldb-rest-app/src/main/java/io/confluent/ksql/api/auth/DefaultApiSecurityContext.java index 1441e6f96ab9..47e5f9428315 100644 --- a/ksqldb-rest-app/src/main/java/io/confluent/ksql/api/auth/DefaultApiSecurityContext.java +++ b/ksqldb-rest-app/src/main/java/io/confluent/ksql/api/auth/DefaultApiSecurityContext.java @@ -37,8 +37,11 @@ public static DefaultApiSecurityContext create(final RoutingContext routingConte final ApiUser apiUser = (ApiUser) user; final String authToken = routingContext.request().getHeader("Authorization"); final List> requestHeaders = routingContext.request().headers().entries(); + final String ipAddress = routingContext.request().remoteAddress().host(); return new DefaultApiSecurityContext( - apiUser != null ? apiUser.getPrincipal() : null, + apiUser != null + ? apiUser.getPrincipal().withIpAddress(ipAddress == null ? "" : ipAddress) + : null, authToken, requestHeaders); } diff --git a/ksqldb-rest-app/src/main/java/io/confluent/ksql/api/auth/JaasAuthProvider.java b/ksqldb-rest-app/src/main/java/io/confluent/ksql/api/auth/JaasAuthProvider.java index 0995c7f0e5fd..4617de9fc125 100644 --- a/ksqldb-rest-app/src/main/java/io/confluent/ksql/api/auth/JaasAuthProvider.java +++ b/ksqldb-rest-app/src/main/java/io/confluent/ksql/api/auth/JaasAuthProvider.java @@ -18,6 +18,7 @@ import com.google.common.annotations.VisibleForTesting; import io.confluent.ksql.api.server.Server; import io.confluent.ksql.rest.server.KsqlRestConfig; +import io.confluent.ksql.security.DefaultKsqlPrincipal; import io.confluent.ksql.security.KsqlPrincipal; import io.vertx.core.AsyncResult; import io.vertx.core.Future; @@ -131,7 +132,7 @@ private void getUser( final boolean authorized = validateRoles(lc, allowedRoles); // if the subject from the login context is already a KsqlPrincipal, use the subject - // directly rather than creating a new one + // (wrapped inside another DefaultKsqlPrincipal) rather than creating a new one final Optional ksqlPrincipal = lc.getSubject().getPrincipals().stream() .filter(p -> p instanceof KsqlPrincipal) .map(p -> (KsqlPrincipal)p) @@ -157,7 +158,7 @@ private static boolean validateRoles(final LoginContext lc, final List a @SuppressWarnings("deprecation") static class JaasUser extends io.vertx.ext.auth.AbstractUser implements ApiUser { - private final KsqlPrincipal principal; + private final DefaultKsqlPrincipal principal; private final boolean authorized; JaasUser( @@ -165,18 +166,21 @@ static class JaasUser extends io.vertx.ext.auth.AbstractUser implements ApiUser final String password, final boolean authorized ) { - this( - new JaasPrincipal( - Objects.requireNonNull(username, "username"), - Objects.requireNonNull(password, "password")), - authorized); + this.principal = new JaasPrincipal( + Objects.requireNonNull(username, "username"), + Objects.requireNonNull(password, "password") + ); + this.authorized = authorized; } JaasUser( final KsqlPrincipal principal, final boolean authorized ) { - this.principal = Objects.requireNonNull(principal, "principal"); + Objects.requireNonNull(principal, "principal"); + this.principal = principal instanceof DefaultKsqlPrincipal + ? (DefaultKsqlPrincipal) principal + : new DefaultKsqlPrincipal(principal); this.authorized = authorized; } @@ -198,7 +202,7 @@ public void setAuthProvider(final AuthProvider authProvider) { } @Override - public KsqlPrincipal getPrincipal() { + public DefaultKsqlPrincipal getPrincipal() { return principal; } } diff --git a/ksqldb-rest-app/src/main/java/io/confluent/ksql/api/auth/JaasPrincipal.java b/ksqldb-rest-app/src/main/java/io/confluent/ksql/api/auth/JaasPrincipal.java index f92ae7fcdbbe..76c549f2b1b5 100644 --- a/ksqldb-rest-app/src/main/java/io/confluent/ksql/api/auth/JaasPrincipal.java +++ b/ksqldb-rest-app/src/main/java/io/confluent/ksql/api/auth/JaasPrincipal.java @@ -15,8 +15,9 @@ package io.confluent.ksql.api.auth; -import io.confluent.ksql.security.KsqlPrincipal; +import io.confluent.ksql.security.DefaultKsqlPrincipal; import java.nio.charset.StandardCharsets; +import java.security.Principal; import java.util.Base64; import java.util.Collections; import java.util.Map; @@ -29,14 +30,22 @@ * extensions. *

*/ -public class JaasPrincipal implements KsqlPrincipal { +public class JaasPrincipal extends DefaultKsqlPrincipal { private final String name; + private final String password; private final String token; public JaasPrincipal(final String name, final String password) { + this(name, password, ""); + } + + private JaasPrincipal(final String name, final String password, final String ipAddress) { + super(new BasicJaasPrincipal(name), ipAddress); + this.name = Objects.requireNonNull(name, "name"); - this.token = createToken(name, Objects.requireNonNull(password)); + this.password = Objects.requireNonNull(password, "password"); + this.token = createToken(name, password); } @Override @@ -49,13 +58,36 @@ public Map getUserProperties() { return Collections.emptyMap(); } - private String createToken(final String name, final String secret) { + public String getToken() { + return token; + } + + /** + * Preserve token functionality by returning another JaasPrincipal when the + * IP address is set from the routing context. + */ + @Override + public DefaultKsqlPrincipal withIpAddress(final String ipAddress) { + return new JaasPrincipal(name, password, ipAddress); + } + + private static String createToken(final String name, final String secret) { return Base64.getEncoder().encodeToString((name + ":" + secret) .getBytes(StandardCharsets.ISO_8859_1)); } - public String getToken() { - return token; + static class BasicJaasPrincipal implements Principal { + + private final String name; + + BasicJaasPrincipal(final String name) { + this.name = name; + } + + @Override + public String getName() { + return name; + } } } diff --git a/ksqldb-rest-app/src/main/java/io/confluent/ksql/api/auth/SystemAuthenticationHandler.java b/ksqldb-rest-app/src/main/java/io/confluent/ksql/api/auth/SystemAuthenticationHandler.java index c0997186b84c..351eff231ad3 100644 --- a/ksqldb-rest-app/src/main/java/io/confluent/ksql/api/auth/SystemAuthenticationHandler.java +++ b/ksqldb-rest-app/src/main/java/io/confluent/ksql/api/auth/SystemAuthenticationHandler.java @@ -16,7 +16,6 @@ package io.confluent.ksql.api.auth; import io.confluent.ksql.security.DefaultKsqlPrincipal; -import io.confluent.ksql.security.KsqlPrincipal; import io.vertx.core.AsyncResult; import io.vertx.core.Handler; import io.vertx.core.http.HttpConnection; @@ -61,7 +60,7 @@ public static boolean isAuthenticatedAsSystemUser(final RoutingContext routingCo private static class SystemUser implements ApiUser { - private final KsqlPrincipal principal; + private final DefaultKsqlPrincipal principal; SystemUser(final Principal principal) { Objects.requireNonNull(principal); @@ -94,7 +93,7 @@ public void setAuthProvider(final AuthProvider authProvider) { } @Override - public KsqlPrincipal getPrincipal() { + public DefaultKsqlPrincipal getPrincipal() { return principal; } } diff --git a/ksqldb-rest-app/src/main/java/io/confluent/ksql/api/impl/DefaultKsqlSecurityContextProvider.java b/ksqldb-rest-app/src/main/java/io/confluent/ksql/api/impl/DefaultKsqlSecurityContextProvider.java index 7a46ba28bf44..b8cb449ad41c 100644 --- a/ksqldb-rest-app/src/main/java/io/confluent/ksql/api/impl/DefaultKsqlSecurityContextProvider.java +++ b/ksqldb-rest-app/src/main/java/io/confluent/ksql/api/impl/DefaultKsqlSecurityContextProvider.java @@ -86,7 +86,8 @@ public KsqlSecurityContext provide(final ApiSecurityContext apiSecurityContext) schemaRegistryClientFactory, connectClientFactory, sharedClient, - requestHeaders) + requestHeaders, + principal) ); } @@ -100,7 +101,8 @@ public KsqlSecurityContext provide(final ApiSecurityContext apiSecurityContext) provider.getSchemaRegistryClientFactory(principal.get()), connectClientFactory, sharedClient, - requestHeaders))) + requestHeaders, + principal))) .get(); } diff --git a/ksqldb-rest-app/src/main/java/io/confluent/ksql/rest/server/KsqlRestApplication.java b/ksqldb-rest-app/src/main/java/io/confluent/ksql/rest/server/KsqlRestApplication.java index f9d89a079642..b1d93f21596b 100644 --- a/ksqldb-rest-app/src/main/java/io/confluent/ksql/rest/server/KsqlRestApplication.java +++ b/ksqldb-rest-app/src/main/java/io/confluent/ksql/rest/server/KsqlRestApplication.java @@ -643,7 +643,7 @@ public static KsqlRestApplication buildApplication( final ServiceContext tempServiceContext = new LazyServiceContext(() -> RestServiceContextFactory.create(ksqlConfig, Optional.empty(), schemaRegistryClientFactory, connectClientFactory, sharedClient, - Collections.emptyList())); + Collections.emptyList(), Optional.empty())); final String kafkaClusterId = KafkaClusterUtil.getKafkaClusterId(tempServiceContext); final String ksqlServerId = ksqlConfig.getString(KsqlConfig.KSQL_SERVICE_ID_CONFIG); updatedRestProps.putAll( @@ -657,7 +657,8 @@ public static KsqlRestApplication buildApplication( schemaRegistryClientFactory, connectClientFactory, sharedClient, - Collections.emptyList())); + Collections.emptyList(), + Optional.empty())); return buildApplication( "", diff --git a/ksqldb-rest-app/src/main/java/io/confluent/ksql/rest/server/services/RestServiceContextFactory.java b/ksqldb-rest-app/src/main/java/io/confluent/ksql/rest/server/services/RestServiceContextFactory.java index c7a4510b0e1f..831e54311250 100644 --- a/ksqldb-rest-app/src/main/java/io/confluent/ksql/rest/server/services/RestServiceContextFactory.java +++ b/ksqldb-rest-app/src/main/java/io/confluent/ksql/rest/server/services/RestServiceContextFactory.java @@ -17,6 +17,7 @@ import io.confluent.kafka.schemaregistry.client.SchemaRegistryClient; import io.confluent.ksql.rest.client.KsqlClient; +import io.confluent.ksql.security.KsqlPrincipal; import io.confluent.ksql.services.ConnectClientFactory; import io.confluent.ksql.services.ServiceContext; import io.confluent.ksql.services.ServiceContextFactory; @@ -41,7 +42,8 @@ ServiceContext create( Supplier srClientFactory, ConnectClientFactory connectClientFactory, KsqlClient sharedClient, - List> requestHeaders + List> requestHeaders, + Optional userPrincipal ); } @@ -54,7 +56,8 @@ ServiceContext create( Supplier srClientFactory, ConnectClientFactory connectClientFactory, KsqlClient sharedClient, - List> requestHeaders + List> requestHeaders, + Optional userPrincipal ); } @@ -64,7 +67,8 @@ public static ServiceContext create( final Supplier schemaRegistryClientFactory, final ConnectClientFactory connectClientFactory, final KsqlClient sharedClient, - final List> requestHeaders + final List> requestHeaders, + final Optional userPrincipal ) { return create( ksqlConfig, @@ -73,7 +77,8 @@ public static ServiceContext create( schemaRegistryClientFactory, connectClientFactory, sharedClient, - requestHeaders + requestHeaders, + userPrincipal ); } @@ -84,13 +89,14 @@ public static ServiceContext create( final Supplier srClientFactory, final ConnectClientFactory connectClientFactory, final KsqlClient sharedClient, - final List> requestHeaders + final List> requestHeaders, + final Optional userPrincipal ) { return ServiceContextFactory.create( ksqlConfig, kafkaClientSupplier, srClientFactory, - () -> connectClientFactory.get(authHeader, requestHeaders), + () -> connectClientFactory.get(authHeader, requestHeaders, userPrincipal), () -> new DefaultKsqlClient(authHeader, sharedClient) ); } diff --git a/ksqldb-rest-app/src/test/java/io/confluent/ksql/api/impl/DefaultKsqlSecurityContextProviderTest.java b/ksqldb-rest-app/src/test/java/io/confluent/ksql/api/impl/DefaultKsqlSecurityContextProviderTest.java index 4079a29061c7..83d4348850e2 100644 --- a/ksqldb-rest-app/src/test/java/io/confluent/ksql/api/impl/DefaultKsqlSecurityContextProviderTest.java +++ b/ksqldb-rest-app/src/test/java/io/confluent/ksql/api/impl/DefaultKsqlSecurityContextProviderTest.java @@ -81,7 +81,7 @@ public void setup() { userServiceContextFactory, ksqlConfig, () -> schemaRegistryClientFactory, - (authHeader, userPrincipal) -> connectClient, + (authHeader, requestHeaders, userPrincipal) -> connectClient, ksqlClient ); @@ -89,9 +89,9 @@ public void setup() { when(apiSecurityContext.getAuthToken()).thenReturn(Optional.empty()); when(apiSecurityContext.getRequestHeaders()).thenReturn(incomingRequestHeaders); - when(defaultServiceContextFactory.create(any(), any(), any(), any(), any(), any())) + when(defaultServiceContextFactory.create(any(), any(), any(), any(), any(), any(), any())) .thenReturn(defaultServiceContext); - when(userServiceContextFactory.create(any(), any(), any(), any(), any(), any(), any())) + when(userServiceContextFactory.create(any(), any(), any(), any(), any(), any(), any(), any())) .thenReturn(userServiceContext); } @@ -135,7 +135,7 @@ public void shouldCreateUserServiceContextIfUserContextProviderIsEnabled() { // Then: verify(userServiceContextFactory) - .create(eq(ksqlConfig), eq(Optional.empty()), any(), any(), any(), any(), any()); + .create(eq(ksqlConfig), eq(Optional.empty()), any(), any(), any(), any(), any(), any()); assertThat(ksqlSecurityContext.getUserPrincipal(), is(Optional.of(user1))); assertThat(ksqlSecurityContext.getServiceContext(), is(userServiceContext)); } @@ -150,7 +150,7 @@ public void shouldPassAuthHeaderToDefaultFactory() { ksqlSecurityContextProvider.provide(apiSecurityContext); // Then: - verify(defaultServiceContextFactory).create(any(), eq(Optional.of("some-auth")), any(), any(), any(), any()); + verify(defaultServiceContextFactory).create(any(), eq(Optional.of("some-auth")), any(), any(), any(), any(), any()); } @Test @@ -164,7 +164,7 @@ public void shouldPassAuthHeaderToUserFactory() { // Then: verify(userServiceContextFactory) - .create(any(), eq(Optional.of("some-auth")), any(), any(), any(), any(), any()); + .create(any(), eq(Optional.of("some-auth")), any(), any(), any(), any(), any(), any()); } @Test @@ -176,7 +176,7 @@ public void shouldPassRequestHeadersToDefaultFactory() { ksqlSecurityContextProvider.provide(apiSecurityContext); // Then: - verify(defaultServiceContextFactory).create(any(), any(), any(), any(), any(), eq(incomingRequestHeaders)); + verify(defaultServiceContextFactory).create(any(), any(), any(), any(), any(), eq(incomingRequestHeaders), any()); } @Test @@ -189,6 +189,31 @@ public void shouldPassRequestHeadersToUserFactory() { // Then: verify(userServiceContextFactory) - .create(any(), any(), any(), any(), any(), any(), eq(incomingRequestHeaders)); + .create(any(), any(), any(), any(), any(), any(), eq(incomingRequestHeaders), any()); + } + + @Test + public void shouldPassUserPrincipalToDefaultFactory() { + // Given: + when(securityExtension.getUserContextProvider()).thenReturn(Optional.empty()); + + // When: + ksqlSecurityContextProvider.provide(apiSecurityContext); + + // Then: + verify(defaultServiceContextFactory).create(any(), any(), any(), any(), any(), any(), eq(Optional.of(user1))); + } + + @Test + public void shouldPassUserPrincipalToUserFactory() { + // Given: + when(securityExtension.getUserContextProvider()).thenReturn(Optional.of(userContextProvider)); + + // When: + ksqlSecurityContextProvider.provide(apiSecurityContext); + + // Then: + verify(userServiceContextFactory) + .create(any(), any(), any(), any(), any(), any(), any(), eq(Optional.of(user1))); } } diff --git a/ksqldb-rest-app/src/test/java/io/confluent/ksql/rest/server/TestKsqlRestApp.java b/ksqldb-rest-app/src/test/java/io/confluent/ksql/rest/server/TestKsqlRestApp.java index e899b5161d86..0ea43fa19cc0 100644 --- a/ksqldb-rest-app/src/test/java/io/confluent/ksql/rest/server/TestKsqlRestApp.java +++ b/ksqldb-rest-app/src/test/java/io/confluent/ksql/rest/server/TestKsqlRestApp.java @@ -337,7 +337,7 @@ protected void initialize() { 3, serviceContext.get(), () -> serviceContext.get().getSchemaRegistryClient(), - (authHeader, userPrincipal) -> serviceContext.get().getConnectClient(), + (authHeader, requestHeaders, userPrincipal) -> serviceContext.get().getConnectClient(), vertx, InternalKsqlClientFactory.createInternalClient( PropertiesUtil.toMapStrings(ksqlRestConfig.originals()), diff --git a/ksqldb-rest-app/src/test/java/io/confluent/ksql/rest/server/services/TestRestServiceContextFactory.java b/ksqldb-rest-app/src/test/java/io/confluent/ksql/rest/server/services/TestRestServiceContextFactory.java index a1b9d43912f9..d9f532e1b55f 100644 --- a/ksqldb-rest-app/src/test/java/io/confluent/ksql/rest/server/services/TestRestServiceContextFactory.java +++ b/ksqldb-rest-app/src/test/java/io/confluent/ksql/rest/server/services/TestRestServiceContextFactory.java @@ -22,8 +22,8 @@ public interface InternalSimpleKsqlClientFactory { public static DefaultServiceContextFactory createDefault( final InternalSimpleKsqlClientFactory ksqlClientFactory ) { - return (ksqlConfig, authHeader, srClientFactory, - connectClientFactory, sharedClient, userPrincipal) -> { + return (ksqlConfig, authHeader, srClientFactory, connectClientFactory, + sharedClient, requestHeaders, userPrincipal) -> { return createUser(ksqlClientFactory).create( ksqlConfig, authHeader, @@ -31,6 +31,7 @@ public static DefaultServiceContextFactory createDefault( srClientFactory, connectClientFactory, sharedClient, + requestHeaders, userPrincipal ); }; @@ -39,8 +40,8 @@ public static DefaultServiceContextFactory createDefault( public static UserServiceContextFactory createUser( final InternalSimpleKsqlClientFactory ksqlClientFactory ) { - return (ksqlConfig, authHeader, kafkaClientSupplier, - srClientFactory, connectClientFactory, sharedClient, userPrincipal) -> { + return (ksqlConfig, authHeader, kafkaClientSupplier, srClientFactory, + connectClientFactory, sharedClient, requestHeaders, userPrincipal) -> { return ServiceContextFactory.create( ksqlConfig, kafkaClientSupplier,