Skip to content

Commit

Permalink
chore: pipe user IP via KsqlPrincipal to connect headers extension (c…
Browse files Browse the repository at this point in the history
  • Loading branch information
vcrfxia authored Feb 3, 2022
1 parent ba572e0 commit 8329dc4
Show file tree
Hide file tree
Showing 21 changed files with 200 additions and 73 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

package io.confluent.ksql.connect;

import io.confluent.ksql.security.KsqlPrincipal;
import java.util.Collections;
import java.util.List;
import java.util.Map;
Expand All @@ -39,7 +40,7 @@ public interface ConnectRequestHeadersExtension {
*
* <p>Set this to {@code false} in order to use this
* {@code ConnectRequestHeadersExtension} for additional custom headers
* (via {@link ConnectRequestHeadersExtension#getHeaders()}) only, without
* (via {@link ConnectRequestHeadersExtension#getHeaders(Optional)}) only, without
* impacting ksqlDB's default behavior for the auth header.
*
* @return whether to use the custom auth header returned from
Expand Down Expand Up @@ -68,9 +69,12 @@ default boolean shouldUseCustomAuthHeader() {
* such as those required for authenticating with Connect. The custom headers are added
* to the request in addition to, and after, ksqlDB's default headers.
*
* @param userPrincipal principal associated with the user who submitted the connector
* request to ksqlDB, if present (i.e., if user authentication
* is enabled)
* @return additional headers to be included with connector requests made by ksqlDB
*/
default Map<String, String> getHeaders() {
default Map<String, String> getHeaders(Optional<KsqlPrincipal> userPrincipal) {
return Collections.emptyMap();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,15 @@
public class DefaultKsqlPrincipal implements KsqlPrincipal {

private final Principal principal;
private final String ipAddress;

public DefaultKsqlPrincipal(final Principal principal) {
this(principal, "");
}

protected DefaultKsqlPrincipal(final Principal principal, final String ipAddress) {
this.principal = Objects.requireNonNull(principal, "principal");
this.ipAddress = Objects.requireNonNull(ipAddress, "ipAddress");
}

@Override
Expand All @@ -54,4 +60,17 @@ public Map<String, Object> getUserProperties() {
public Principal getOriginalPrincipal() {
return principal;
}

@Override
public String getIpAddress() {
return ipAddress;
}

/**
* IP address is populated from the request context, and subsequently passed
* throughout the engine.
*/
public DefaultKsqlPrincipal withIpAddress(final String ipAddress) {
return new DefaultKsqlPrincipal(principal, ipAddress);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,4 +31,21 @@ default Map<String, Object> getUserProperties() {
return Collections.emptyMap();
}

/**
* Returns the user's IP address, as set by the ksqlDB server's request context.
*
* <p>This method never returns {@code null}. An empty string may be returned in
* certain situations (those where an IP address is not available, e.g., domain
* socket requests).
*
* <p>Overriding the implementation of this method from custom {@code KsqlPrincipal}
* implementations has no effect on the return value of this method, when called from
* custom extensions, because incoming {@code KsqlPrincipal} instances are wrapped
* inside ksqlDB's own {@link DefaultKsqlPrincipal} before being passed throughout
* the ksqlDB engine, and {@code DefaultKsqlPrincipal} has its own implementation
* for tracking and returning the IP address from this method.
*/
default String getIpAddress() {
return "";
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

package io.confluent.ksql.services;

import io.confluent.ksql.security.KsqlPrincipal;
import java.util.List;
import java.util.Map.Entry;
import java.util.Optional;
Expand All @@ -23,7 +24,8 @@ public interface ConnectClientFactory {

ConnectClient get(
Optional<String> authHeader,
List<Entry<String, String>> incomingRequestHeaders
List<Entry<String, String>> incomingRequestHeaders,
Optional<KsqlPrincipal> userPrincipal
);

default void close() {}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import com.google.common.annotations.VisibleForTesting;
import io.confluent.ksql.connect.ConnectRequestHeadersExtension;
import io.confluent.ksql.security.KsqlPrincipal;
import io.confluent.ksql.util.FileWatcher;
import io.confluent.ksql.util.KsqlConfig;
import java.io.FileInputStream;
Expand Down Expand Up @@ -78,7 +79,8 @@ public DefaultConnectClientFactory(
@Override
public synchronized DefaultConnectClient get(
final Optional<String> ksqlAuthHeader,
final List<Entry<String, String>> incomingRequestHeaders
final List<Entry<String, String>> incomingRequestHeaders,
final Optional<KsqlPrincipal> userPrincipal
) {
if (defaultConnectAuthHeader == null) {
defaultConnectAuthHeader = buildDefaultAuthHeader();
Expand All @@ -91,7 +93,7 @@ public synchronized DefaultConnectClient get(
ksqlConfig.getString(KsqlConfig.CONNECT_URL_PROPERTY),
buildAuthHeader(ksqlAuthHeader, incomingRequestHeaders),
requestHeadersExtension
.map(ConnectRequestHeadersExtension::getHeaders)
.map(extension -> extension.getHeaders(userPrincipal))
.orElse(Collections.emptyMap()),
Optional.ofNullable(newSslContext(configWithPrefixOverrides)),
shouldVerifySslHostname(configWithPrefixOverrides)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ public static ServiceContext create(
Collections.emptyMap())::get,
() -> new DefaultConnectClientFactory(ksqlConfig).get(
Optional.empty(),
Collections.emptyList()),
Collections.emptyList(),
Optional.empty()),
ksqlClientSupplier
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,8 @@ public static KsqlContext create(
adminClient,
kafkaTopicClient,
() -> schemaRegistryClient,
new DefaultConnectClientFactory(ksqlConfig).get(Optional.empty(), Collections.emptyList())
new DefaultConnectClientFactory(ksqlConfig)
.get(Optional.empty(), Collections.emptyList(), Optional.empty())
);

final String metricsPrefix = "instance-" + COUNTER.getAndIncrement() + "-";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

import com.google.common.collect.ImmutableMap;
import io.confluent.ksql.connect.ConnectRequestHeadersExtension;
import io.confluent.ksql.security.KsqlPrincipal;
import io.confluent.ksql.test.util.KsqlTestFolder;
import io.confluent.ksql.util.KsqlConfig;
import io.vertx.core.http.HttpHeaders;
Expand Down Expand Up @@ -96,6 +97,8 @@ public class DefaultConnectClientFactoryTest {
private ConnectRequestHeadersExtension requestHeadersExtension;
@Mock
private List<Entry<String, String>> incomingRequestHeaders;
@Mock
private KsqlPrincipal userPrincipal;

private String credentialsPath;

Expand All @@ -107,15 +110,17 @@ public void setUp() {

when(config.getString(KsqlConfig.CONNECT_URL_PROPERTY)).thenReturn("http://localhost:8034");
when(config.getString(KsqlConfig.CONNECT_BASIC_AUTH_CREDENTIALS_SOURCE_PROPERTY)).thenReturn("NONE");
when(config.valuesWithPrefixOverride(KsqlConfig.KSQL_CONNECT_PREFIX)).thenReturn(DEFAULT_CONFIGS_WITH_PREFIX_OVERRIDE);
when(config.valuesWithPrefixOverride(KsqlConfig.KSQL_CONNECT_PREFIX))
.thenReturn(DEFAULT_CONFIGS_WITH_PREFIX_OVERRIDE);

connectClientFactory = new DefaultConnectClientFactory(config);
}

@Test
public void shouldBuildWithoutAuthHeader() {
// When:
final DefaultConnectClient connectClient = connectClientFactory.get(Optional.empty(), Collections.emptyList());
final DefaultConnectClient connectClient =
connectClientFactory.get(Optional.empty(), Collections.emptyList(), Optional.empty());

// Then:
assertThat(connectClient.getRequestHeaders(), is(EMPTY_HEADERS));
Expand All @@ -128,7 +133,8 @@ public void shouldBuildAuthHeader() throws Exception {
givenValidCredentialsFile();

// When:
final DefaultConnectClient connectClient = connectClientFactory.get(Optional.empty(), Collections.emptyList());
final DefaultConnectClient connectClient =
connectClientFactory.get(Optional.empty(), Collections.emptyList(), Optional.empty());

// Then:
assertThat(connectClient.getRequestHeaders(),
Expand All @@ -142,8 +148,8 @@ public void shouldBuildAuthHeaderOnlyOnce() throws Exception {
givenValidCredentialsFile();

// When: get() is called twice
connectClientFactory.get(Optional.empty(), Collections.emptyList());
connectClientFactory.get(Optional.empty(), Collections.emptyList());
connectClientFactory.get(Optional.empty(), Collections.emptyList(), Optional.empty());
connectClientFactory.get(Optional.empty(), Collections.emptyList(), Optional.empty());

// Then: only loaded the credentials once -- ideally we'd check the number of times the file
// was read but this is an acceptable proxy for this unit test
Expand All @@ -154,7 +160,7 @@ public void shouldBuildAuthHeaderOnlyOnce() throws Exception {
public void shouldUseKsqlAuthHeaderIfNoAuthHeaderPresent() {
// When:
final DefaultConnectClient connectClient =
connectClientFactory.get(Optional.of("some ksql request header"), Collections.emptyList());
connectClientFactory.get(Optional.of("some ksql request header"), Collections.emptyList(), Optional.empty());

// Then:
assertThat(connectClient.getRequestHeaders(),
Expand All @@ -168,7 +174,8 @@ public void shouldNotFailOnUnreadableCredentials() throws Exception {
givenInvalidCredentialsFiles();

// When:
final DefaultConnectClient connectClient = connectClientFactory.get(Optional.empty(), Collections.emptyList());
final DefaultConnectClient connectClient =
connectClientFactory.get(Optional.empty(), Collections.emptyList(), Optional.empty());

// Then:
assertThat(connectClient.getRequestHeaders(), is(EMPTY_HEADERS));
Expand All @@ -181,7 +188,8 @@ public void shouldNotFailOnMissingCredentials() {
// no credentials file present

// When:
final DefaultConnectClient connectClient = connectClientFactory.get(Optional.empty(), Collections.emptyList());
final DefaultConnectClient connectClient =
connectClientFactory.get(Optional.empty(), Collections.emptyList(), Optional.empty());

// Then:
assertThat(connectClient.getRequestHeaders(), is(EMPTY_HEADERS));
Expand All @@ -195,7 +203,7 @@ public void shouldReloadCredentialsOnFileCreation() throws Exception {
// no credentials file present

// verify that no auth header is present
assertThat(connectClientFactory.get(Optional.empty(), Collections.emptyList()).getRequestHeaders(),
assertThat(connectClientFactory.get(Optional.empty(), Collections.emptyList(), Optional.empty()).getRequestHeaders(),
is(EMPTY_HEADERS));

// When: credentials file is created
Expand All @@ -205,7 +213,7 @@ public void shouldReloadCredentialsOnFileCreation() throws Exception {
// Then: auth header is present
assertThatEventually(
"Should load newly created credentials",
() -> connectClientFactory.get(Optional.empty(), Collections.emptyList()).getRequestHeaders(),
() -> connectClientFactory.get(Optional.empty(), Collections.emptyList(), Optional.empty()).getRequestHeaders(),
arrayContaining(header(AUTH_HEADER_NAME, EXPECTED_HEADER)),
TimeUnit.SECONDS.toMillis(1),
TimeUnit.SECONDS.toMillis(1)
Expand All @@ -220,7 +228,7 @@ public void shouldReloadCredentialsOnFileChange() throws Exception {
givenValidCredentialsFile();

// verify auth header is present
assertThat(connectClientFactory.get(Optional.empty(), Collections.emptyList()).getRequestHeaders(),
assertThat(connectClientFactory.get(Optional.empty(), Collections.emptyList(), Optional.empty()).getRequestHeaders(),
arrayContaining(header(AUTH_HEADER_NAME, EXPECTED_HEADER)));

// When: credentials file is modified
Expand All @@ -230,7 +238,7 @@ public void shouldReloadCredentialsOnFileChange() throws Exception {
// Then: new auth header is present
assertThatEventually(
"Should load updated credentials",
() -> connectClientFactory.get(Optional.empty(), Collections.emptyList()).getRequestHeaders(),
() -> connectClientFactory.get(Optional.empty(), Collections.emptyList(), Optional.empty()).getRequestHeaders(),
arrayContaining(header(AUTH_HEADER_NAME, OTHER_EXPECTED_HEADER)),
TimeUnit.SECONDS.toMillis(1),
TimeUnit.SECONDS.toMillis(1)
Expand All @@ -247,12 +255,12 @@ public void shouldPassCustomRequestHeaders() {
// re-initialize client factory since request headers extension is configured in constructor
connectClientFactory = new DefaultConnectClientFactory(config);

when(requestHeadersExtension.getHeaders())
when(requestHeadersExtension.getHeaders(Optional.of(userPrincipal)))
.thenReturn(ImmutableMap.of("header", "value"));

// When:
final DefaultConnectClient connectClient =
connectClientFactory.get(Optional.empty(), Collections.emptyList());
connectClientFactory.get(Optional.empty(), Collections.emptyList(), Optional.of(userPrincipal));

// Then:
assertThat(connectClient.getRequestHeaders(), arrayContaining(header("header", "value")));
Expand All @@ -271,13 +279,13 @@ public void shouldPassCustomRequestHeadersInAdditionToDefaultBasicAuthHeader() t
// re-initialize client factory since request headers extension is configured in constructor
connectClientFactory = new DefaultConnectClientFactory(config);

when(requestHeadersExtension.getHeaders())
when(requestHeadersExtension.getHeaders(Optional.of(userPrincipal)))
.thenReturn(ImmutableMap.of("header", "value"));
when(requestHeadersExtension.shouldUseCustomAuthHeader()).thenReturn(false);

// When:
final DefaultConnectClient connectClient =
connectClientFactory.get(Optional.empty(), Collections.emptyList());
connectClientFactory.get(Optional.empty(), Collections.emptyList(), Optional.of(userPrincipal));

// Then:
assertThat(connectClient.getRequestHeaders(),
Expand All @@ -299,14 +307,14 @@ public void shouldFavorCustomAuthHeaderOverBasicAuthHeader() throws Exception {
// re-initialize client factory since request headers extension is configured in constructor
connectClientFactory = new DefaultConnectClientFactory(config);

when(requestHeadersExtension.getHeaders())
when(requestHeadersExtension.getHeaders(Optional.of(userPrincipal)))
.thenReturn(ImmutableMap.of("header", "value"));
when(requestHeadersExtension.shouldUseCustomAuthHeader()).thenReturn(true);
when(requestHeadersExtension.getAuthHeader(incomingRequestHeaders)).thenReturn(Optional.of("some custom auth"));

// When:
final DefaultConnectClient connectClient =
connectClientFactory.get(Optional.empty(), incomingRequestHeaders);
connectClientFactory.get(Optional.empty(), incomingRequestHeaders, Optional.of(userPrincipal));

// Then:
assertThat(connectClient.getRequestHeaders(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,8 @@ public static ServiceContext create(
adminClient,
new KafkaTopicClientImpl(() -> adminClient),
srClientFactory,
new DefaultConnectClientFactory(ksqlConfig).get(Optional.empty(), Collections.emptyList()),
new DefaultConnectClientFactory(ksqlConfig)
.get(Optional.empty(), Collections.emptyList(), Optional.empty()),
new KafkaConsumerGroupClientImpl(() -> adminClient)
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@

package io.confluent.ksql.api.auth;

import io.confluent.ksql.security.KsqlPrincipal;
import io.confluent.ksql.security.DefaultKsqlPrincipal;
import io.vertx.ext.auth.User;

public interface ApiUser extends User {

KsqlPrincipal getPrincipal();
DefaultKsqlPrincipal getPrincipal();
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
import io.confluent.ksql.api.server.Server;
import io.confluent.ksql.rest.server.KsqlRestConfig;
import io.confluent.ksql.security.DefaultKsqlPrincipal;
import io.confluent.ksql.security.KsqlPrincipal;
import io.vertx.core.AsyncResult;
import io.vertx.core.Handler;
import io.vertx.core.json.JsonObject;
Expand Down Expand Up @@ -101,7 +100,7 @@ public void handle(final RoutingContext routingContext) {

private static class AuthPluginUser implements ApiUser {

private final KsqlPrincipal principal;
private final DefaultKsqlPrincipal principal;

AuthPluginUser(final Principal principal) {
Objects.requireNonNull(principal);
Expand Down Expand Up @@ -132,7 +131,7 @@ public void setAuthProvider(final AuthProvider authProvider) {
}

@Override
public KsqlPrincipal getPrincipal() {
public DefaultKsqlPrincipal getPrincipal() {
return principal;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,11 @@ public static DefaultApiSecurityContext create(final RoutingContext routingConte
final ApiUser apiUser = (ApiUser) user;
final String authToken = routingContext.request().getHeader("Authorization");
final List<Entry<String, String>> requestHeaders = routingContext.request().headers().entries();
final String ipAddress = routingContext.request().remoteAddress().host();
return new DefaultApiSecurityContext(
apiUser != null ? apiUser.getPrincipal() : null,
apiUser != null
? apiUser.getPrincipal().withIpAddress(ipAddress == null ? "" : ipAddress)
: null,
authToken,
requestHeaders);
}
Expand Down
Loading

0 comments on commit 8329dc4

Please sign in to comment.