diff --git a/src/main/java/org/opensearch/security/OpenSearchSecurityPlugin.java b/src/main/java/org/opensearch/security/OpenSearchSecurityPlugin.java index 688b797e85..6b58b8f90e 100644 --- a/src/main/java/org/opensearch/security/OpenSearchSecurityPlugin.java +++ b/src/main/java/org/opensearch/security/OpenSearchSecurityPlugin.java @@ -46,6 +46,7 @@ import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.Optional; import java.util.Set; import java.util.concurrent.atomic.AtomicReference; import java.util.function.BiFunction; @@ -114,6 +115,8 @@ import org.opensearch.plugins.ExtensionAwarePlugin; import org.opensearch.plugins.IdentityPlugin; import org.opensearch.plugins.MapperPlugin; +import org.opensearch.plugins.SecureSettingsFactory; +import org.opensearch.plugins.SecureTransportSettingsProvider; import org.opensearch.repositories.RepositoriesService; import org.opensearch.rest.RestController; import org.opensearch.rest.RestHandler; @@ -167,11 +170,11 @@ import org.opensearch.security.securityconf.DynamicConfigFactory; import org.opensearch.security.setting.OpensearchDynamicSetting; import org.opensearch.security.setting.TransportPassiveAuthSetting; +import org.opensearch.security.ssl.OpenSearchSecureSettingsFactory; import org.opensearch.security.ssl.OpenSearchSecuritySSLPlugin; import org.opensearch.security.ssl.SslExceptionHandler; import org.opensearch.security.ssl.http.netty.ValidatingDispatcher; import org.opensearch.security.ssl.transport.DefaultPrincipalExtractor; -import org.opensearch.security.ssl.transport.SecuritySSLNettyTransport; import org.opensearch.security.ssl.util.SSLConfigConstants; import org.opensearch.security.support.ConfigConstants; import org.opensearch.security.support.GuardedSearchOperationWrapper; @@ -199,6 +202,7 @@ import org.opensearch.transport.TransportRequestOptions; import org.opensearch.transport.TransportResponseHandler; import org.opensearch.transport.TransportService; +import org.opensearch.transport.netty4.ssl.SecureNetty4Transport; import org.opensearch.watcher.ResourceWatcherService; import static org.opensearch.security.dlic.rest.api.RestApiAdminPrivilegesEvaluator.ENDPOINTS_WITH_PERMISSIONS; @@ -858,25 +862,27 @@ public void sendRequest( } @Override - public Map> getTransports( + public Map> getSecureTransports( Settings settings, ThreadPool threadPool, PageCacheRecycler pageCacheRecycler, CircuitBreakerService circuitBreakerService, NamedWriteableRegistry namedWriteableRegistry, NetworkService networkService, + SecureTransportSettingsProvider secureTransportSettingsProvider, Tracer tracer ) { Map> transports = new HashMap>(); if (SSLConfig.isSslOnlyMode()) { - return super.getTransports( + return super.getSecureTransports( settings, threadPool, pageCacheRecycler, circuitBreakerService, namedWriteableRegistry, networkService, + secureTransportSettingsProvider, tracer ); } @@ -884,18 +890,16 @@ public Map> getTransports( if (transportSSLEnabled) { transports.put( "org.opensearch.security.ssl.http.netty.SecuritySSLNettyTransport", - () -> new SecuritySSLNettyTransport( - settings, + () -> new SecureNetty4Transport( + migrateSettings(settings), Version.CURRENT, threadPool, networkService, pageCacheRecycler, namedWriteableRegistry, circuitBreakerService, - sks, - evaluateSslExceptionHandler(), sharedGroupFactory, - SSLConfig, + secureTransportSettingsProvider, tracer ) ); @@ -904,7 +908,7 @@ public Map> getTransports( } @Override - public Map> getHttpTransports( + public Map> getSecureHttpTransports( Settings settings, ThreadPool threadPool, BigArrays bigArrays, @@ -914,11 +918,12 @@ public Map> getHttpTransports( NetworkService networkService, Dispatcher dispatcher, ClusterSettings clusterSettings, + SecureTransportSettingsProvider secureTransportSettingsProvider, Tracer tracer ) { if (SSLConfig.isSslOnlyMode()) { - return super.getHttpTransports( + return super.getSecureHttpTransports( settings, threadPool, bigArrays, @@ -928,6 +933,7 @@ public Map> getHttpTransports( networkService, dispatcher, clusterSettings, + secureTransportSettingsProvider, tracer ); } @@ -944,16 +950,15 @@ public Map> getHttpTransports( ); // TODO close odshst final SecurityHttpServerTransport odshst = new SecurityHttpServerTransport( - settings, + migrateSettings(settings), networkService, bigArrays, threadPool, - sks, - evaluateSslExceptionHandler(), xContentRegistry, validatingDispatcher, clusterSettings, sharedGroupFactory, + secureTransportSettingsProvider, tracer, securityRestHandler ); @@ -963,7 +968,7 @@ public Map> getHttpTransports( return Collections.singletonMap( "org.opensearch.security.http.SecurityHttpServerTransport", () -> new SecurityNonSslHttpServerTransport( - settings, + migrateSettings(settings), networkService, bigArrays, threadPool, @@ -971,6 +976,7 @@ public Map> getHttpTransports( dispatcher, clusterSettings, sharedGroupFactory, + secureTransportSettingsProvider, tracer, securityRestHandler ) @@ -2005,6 +2011,11 @@ public SecurityTokenManager getTokenManager() { return tokenManager; } + @Override + public Optional getSecureSettingFactory(Settings settings) { + return Optional.of(new OpenSearchSecureSettingsFactory(settings, sks, sslExceptionHandler)); + } + public static class GuiceHolder implements LifecycleComponent { private static RepositoriesService repositoriesService; diff --git a/src/main/java/org/opensearch/security/http/SecurityHttpServerTransport.java b/src/main/java/org/opensearch/security/http/SecurityHttpServerTransport.java index eb75f898f4..a4a85f9e6c 100644 --- a/src/main/java/org/opensearch/security/http/SecurityHttpServerTransport.java +++ b/src/main/java/org/opensearch/security/http/SecurityHttpServerTransport.java @@ -34,19 +34,21 @@ import org.opensearch.common.util.BigArrays; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.http.netty4.ssl.SecureNetty4HttpServerTransport; +import org.opensearch.plugins.SecureTransportSettingsProvider; import org.opensearch.security.filter.SecurityResponse; import org.opensearch.security.filter.SecurityRestFilter; -import org.opensearch.security.ssl.SecurityKeyStore; -import org.opensearch.security.ssl.SslExceptionHandler; -import org.opensearch.security.ssl.http.netty.SecuritySSLNettyHttpServerTransport; +import org.opensearch.security.ssl.http.netty.Netty4ConditionalDecompressor; +import org.opensearch.security.ssl.http.netty.Netty4HttpRequestHeaderVerifier; import org.opensearch.security.ssl.http.netty.ValidatingDispatcher; import org.opensearch.telemetry.tracing.Tracer; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.SharedGroupFactory; +import io.netty.channel.ChannelInboundHandlerAdapter; import io.netty.util.AttributeKey; -public class SecurityHttpServerTransport extends SecuritySSLNettyHttpServerTransport { +public class SecurityHttpServerTransport extends SecureNetty4HttpServerTransport { public static final AttributeKey EARLY_RESPONSE = AttributeKey.newInstance("opensearch-http-early-response"); public static final AttributeKey> UNCONSUMED_PARAMS = AttributeKey.newInstance("opensearch-http-request-consumed-params"); @@ -56,17 +58,18 @@ public class SecurityHttpServerTransport extends SecuritySSLNettyHttpServerTrans public static final AttributeKey SHOULD_DECOMPRESS = AttributeKey.newInstance("opensearch-http-should-decompress"); public static final AttributeKey IS_AUTHENTICATED = AttributeKey.newInstance("opensearch-http-is-authenticated"); + private final ChannelInboundHandlerAdapter headerVerifier; + public SecurityHttpServerTransport( final Settings settings, final NetworkService networkService, final BigArrays bigArrays, final ThreadPool threadPool, - final SecurityKeyStore odsks, - final SslExceptionHandler sslExceptionHandler, final NamedXContentRegistry namedXContentRegistry, final ValidatingDispatcher dispatcher, final ClusterSettings clusterSettings, SharedGroupFactory sharedGroupFactory, + final SecureTransportSettingsProvider secureTransportSettingsProvider, Tracer tracer, SecurityRestFilter restFilter ) { @@ -75,14 +78,24 @@ public SecurityHttpServerTransport( networkService, bigArrays, threadPool, - odsks, namedXContentRegistry, dispatcher, - sslExceptionHandler, clusterSettings, sharedGroupFactory, - tracer, - restFilter + secureTransportSettingsProvider, + tracer ); + + headerVerifier = new Netty4HttpRequestHeaderVerifier(restFilter, threadPool, settings); + } + + @Override + protected ChannelInboundHandlerAdapter createHeaderVerifier() { + return headerVerifier; + } + + @Override + protected ChannelInboundHandlerAdapter createDecompressor() { + return new Netty4ConditionalDecompressor(); } } diff --git a/src/main/java/org/opensearch/security/http/SecurityNonSslHttpServerTransport.java b/src/main/java/org/opensearch/security/http/SecurityNonSslHttpServerTransport.java index f37ebb48e8..4842c35bcf 100644 --- a/src/main/java/org/opensearch/security/http/SecurityNonSslHttpServerTransport.java +++ b/src/main/java/org/opensearch/security/http/SecurityNonSslHttpServerTransport.java @@ -33,6 +33,8 @@ import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.http.HttpHandlingSettings; import org.opensearch.http.netty4.Netty4HttpServerTransport; +import org.opensearch.http.netty4.ssl.SecureNetty4HttpServerTransport; +import org.opensearch.plugins.SecureTransportSettingsProvider; import org.opensearch.security.filter.SecurityRestFilter; import org.opensearch.security.ssl.http.netty.Netty4ConditionalDecompressor; import org.opensearch.security.ssl.http.netty.Netty4HttpRequestHeaderVerifier; @@ -44,7 +46,7 @@ import io.netty.channel.ChannelHandler; import io.netty.channel.ChannelInboundHandlerAdapter; -public class SecurityNonSslHttpServerTransport extends Netty4HttpServerTransport { +public class SecurityNonSslHttpServerTransport extends SecureNetty4HttpServerTransport { private final ChannelInboundHandlerAdapter headerVerifier; @@ -57,6 +59,7 @@ public SecurityNonSslHttpServerTransport( final Dispatcher dispatcher, final ClusterSettings clusterSettings, final SharedGroupFactory sharedGroupFactory, + final SecureTransportSettingsProvider secureTransportSettingsProvider, final Tracer tracer, final SecurityRestFilter restFilter ) { @@ -69,6 +72,7 @@ public SecurityNonSslHttpServerTransport( dispatcher, clusterSettings, sharedGroupFactory, + secureTransportSettingsProvider, tracer ); headerVerifier = new Netty4HttpRequestHeaderVerifier(restFilter, threadPool, settings); diff --git a/src/main/java/org/opensearch/security/ssl/OpenSearchSecureSettingsFactory.java b/src/main/java/org/opensearch/security/ssl/OpenSearchSecureSettingsFactory.java new file mode 100644 index 0000000000..d85f490d0c --- /dev/null +++ b/src/main/java/org/opensearch/security/ssl/OpenSearchSecureSettingsFactory.java @@ -0,0 +1,74 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.security.ssl; + +import java.util.Optional; +import javax.net.ssl.SSLEngine; +import javax.net.ssl.SSLException; + +import org.opensearch.common.settings.Settings; +import org.opensearch.http.HttpServerTransport; +import org.opensearch.plugins.SecureSettingsFactory; +import org.opensearch.plugins.SecureTransportSettingsProvider; +import org.opensearch.transport.TcpTransport; + +public class OpenSearchSecureSettingsFactory implements SecureSettingsFactory { + private final Settings settings; + private final SecurityKeyStore sks; + private final SslExceptionHandler sslExceptionHandler; + + public OpenSearchSecureSettingsFactory(Settings settings, SecurityKeyStore sks, SslExceptionHandler sslExceptionHandler) { + this.settings = settings; + this.sks = sks; + this.sslExceptionHandler = sslExceptionHandler; + } + + @Override + public Optional getSecureTransportSettingsProvider(Settings settings) { + return Optional.of(new SecureTransportSettingsProvider() { + @Override + public Optional buildHttpServerExceptionHandler(Settings settings, HttpServerTransport transport) { + return Optional.of(new ServerExceptionHandler() { + @Override + public void onError(Throwable t) { + sslExceptionHandler.logError(t, true); + } + }); + } + + @Override + public Optional buildServerTransportExceptionHandler(Settings settings, TcpTransport transport) { + return Optional.of(new ServerExceptionHandler() { + @Override + public void onError(Throwable t) { + sslExceptionHandler.logError(t, false); + } + }); + } + + @Override + public Optional buildSecureHttpServerEngine(Settings settings, HttpServerTransport transport) throws SSLException { + return Optional.of(sks.createHTTPSSLEngine()); + } + + @Override + public Optional buildSecureServerTransportEngine(Settings settings, TcpTransport transport) throws SSLException { + return Optional.of(sks.createServerTransportSSLEngine()); + } + + @Override + public Optional buildSecureClientTransportEngine(Settings settings, String hostname, int port) throws SSLException { + return Optional.of(sks.createClientTransportSSLEngine(hostname, port)); + } + }); + } +} diff --git a/src/main/java/org/opensearch/security/ssl/OpenSearchSecuritySSLPlugin.java b/src/main/java/org/opensearch/security/ssl/OpenSearchSecuritySSLPlugin.java index e6e4e85b33..f9c2ac64c0 100644 --- a/src/main/java/org/opensearch/security/ssl/OpenSearchSecuritySSLPlugin.java +++ b/src/main/java/org/opensearch/security/ssl/OpenSearchSecuritySSLPlugin.java @@ -27,6 +27,7 @@ import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.Optional; import java.util.function.Function; import java.util.function.Supplier; @@ -62,6 +63,8 @@ import org.opensearch.http.HttpServerTransport.Dispatcher; import org.opensearch.plugins.NetworkPlugin; import org.opensearch.plugins.Plugin; +import org.opensearch.plugins.SecureSettingsFactory; +import org.opensearch.plugins.SecureTransportSettingsProvider; import org.opensearch.plugins.SystemIndexPlugin; import org.opensearch.repositories.RepositoriesService; import org.opensearch.rest.RestController; @@ -70,20 +73,21 @@ import org.opensearch.security.DefaultObjectMapper; import org.opensearch.security.NonValidatingObjectMapper; import org.opensearch.security.filter.SecurityRestFilter; -import org.opensearch.security.ssl.http.netty.SecuritySSLNettyHttpServerTransport; +import org.opensearch.security.http.SecurityHttpServerTransport; import org.opensearch.security.ssl.http.netty.ValidatingDispatcher; import org.opensearch.security.ssl.rest.SecuritySSLInfoAction; import org.opensearch.security.ssl.transport.DefaultPrincipalExtractor; import org.opensearch.security.ssl.transport.PrincipalExtractor; import org.opensearch.security.ssl.transport.SSLConfig; -import org.opensearch.security.ssl.transport.SecuritySSLNettyTransport; import org.opensearch.security.ssl.transport.SecuritySSLTransportInterceptor; import org.opensearch.security.ssl.util.SSLConfigConstants; +import org.opensearch.security.support.SecuritySettings; import org.opensearch.telemetry.tracing.Tracer; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.SharedGroupFactory; import org.opensearch.transport.Transport; import org.opensearch.transport.TransportInterceptor; +import org.opensearch.transport.netty4.ssl.SecureNetty4Transport; import org.opensearch.watcher.ResourceWatcherService; import io.netty.handler.ssl.OpenSsl; @@ -91,6 +95,19 @@ //For ES5 this class has only effect when SSL only plugin is installed public class OpenSearchSecuritySSLPlugin extends Plugin implements SystemIndexPlugin, NetworkPlugin { + private static final Setting SECURITY_SSL_TRANSPORT_ENFORCE_HOSTNAME_VERIFICATION = Setting.boolSetting( + SSLConfigConstants.SECURITY_SSL_TRANSPORT_ENFORCE_HOSTNAME_VERIFICATION, + true, + Property.NodeScope, + Property.Filtered + ); + + private static final Setting SECURITY_SSL_TRANSPORT_ENFORCE_HOSTNAME_VERIFICATION_RESOLVE_HOST_NAME = Setting.boolSetting( + SSLConfigConstants.SECURITY_SSL_TRANSPORT_ENFORCE_HOSTNAME_VERIFICATION_RESOLVE_HOST_NAME, + true, + Property.NodeScope, + Property.Filtered + ); private static boolean USE_NETTY_DEFAULT_ALLOCATOR = Booleans.parseBoolean( System.getProperty("opensearch.unsafe.use_netty_default_allocator"), @@ -237,7 +254,7 @@ public Object run() { } @Override - public Map> getHttpTransports( + public Map> getSecureHttpTransports( Settings settings, ThreadPool threadPool, BigArrays bigArrays, @@ -247,6 +264,7 @@ public Map> getHttpTransports( NetworkService networkService, Dispatcher dispatcher, ClusterSettings clusterSettings, + SecureTransportSettingsProvider secureTransportSettingsProvider, Tracer tracer ) { @@ -259,17 +277,16 @@ public Map> getHttpTransports( configPath, NOOP_SSL_EXCEPTION_HANDLER ); - final SecuritySSLNettyHttpServerTransport sgsnht = new SecuritySSLNettyHttpServerTransport( - settings, + final SecurityHttpServerTransport sgsnht = new SecurityHttpServerTransport( + migrateSettings(settings), networkService, bigArrays, threadPool, - sks, xContentRegistry, validatingDispatcher, - NOOP_SSL_EXCEPTION_HANDLER, clusterSettings, sharedGroupFactory, + secureTransportSettingsProvider, tracer, securityRestHandler ); @@ -313,13 +330,14 @@ public List getTransportInterceptors(NamedWriteableRegistr } @Override - public Map> getTransports( + public Map> getSecureTransports( Settings settings, ThreadPool threadPool, PageCacheRecycler pageCacheRecycler, CircuitBreakerService circuitBreakerService, NamedWriteableRegistry namedWriteableRegistry, NetworkService networkService, + SecureTransportSettingsProvider secureTransportSettingsProvider, Tracer tracer ) { @@ -327,18 +345,16 @@ public Map> getTransports( if (transportSSLEnabled) { transports.put( "org.opensearch.security.ssl.http.netty.SecuritySSLNettyTransport", - () -> new SecuritySSLNettyTransport( - settings, + () -> new SecureNetty4Transport( + migrateSettings(settings), Version.CURRENT, threadPool, networkService, pageCacheRecycler, namedWriteableRegistry, circuitBreakerService, - sks, - NOOP_SSL_EXCEPTION_HANDLER, sharedGroupFactory, - SSLConfig, + secureTransportSettingsProvider, tracer ) ); @@ -436,22 +452,8 @@ public List> getSettings() { Property.Filtered ) ); - settings.add( - Setting.boolSetting( - SSLConfigConstants.SECURITY_SSL_TRANSPORT_ENFORCE_HOSTNAME_VERIFICATION, - true, - Property.NodeScope, - Property.Filtered - ) - ); - settings.add( - Setting.boolSetting( - SSLConfigConstants.SECURITY_SSL_TRANSPORT_ENFORCE_HOSTNAME_VERIFICATION_RESOLVE_HOST_NAME, - true, - Property.NodeScope, - Property.Filtered - ) - ); + settings.add(SECURITY_SSL_TRANSPORT_ENFORCE_HOSTNAME_VERIFICATION); + settings.add(SECURITY_SSL_TRANSPORT_ENFORCE_HOSTNAME_VERIFICATION_RESOLVE_HOST_NAME); settings.add( Setting.simpleString(SSLConfigConstants.SECURITY_SSL_TRANSPORT_KEYSTORE_FILEPATH, Property.NodeScope, Property.Filtered) ); @@ -664,4 +666,29 @@ public List getSettingsFilter() { settingsFilter.add("plugins.security.*"); return settingsFilter; } + + @Override + public Optional getSecureSettingFactory(Settings settings) { + return Optional.of(new OpenSearchSecureSettingsFactory(settings, sks, NOOP_SSL_EXCEPTION_HANDLER)); + } + + protected Settings migrateSettings(Settings settings) { + final Settings.Builder builder = Settings.builder().put(settings); + + builder.remove(NetworkModule.TRANSPORT_SSL_DUAL_MODE_ENABLED_KEY); + builder.remove(NetworkModule.TRANSPORT_SSL_ENFORCE_HOSTNAME_VERIFICATION_RESOLVE_HOST_NAME_KEY); + builder.remove(NetworkModule.TRANSPORT_SSL_ENFORCE_HOSTNAME_VERIFICATION_KEY); + + builder.put(NetworkModule.TRANSPORT_SSL_DUAL_MODE_ENABLED_KEY, SecuritySettings.SSL_DUAL_MODE_SETTING.get(settings)); + builder.put( + NetworkModule.TRANSPORT_SSL_ENFORCE_HOSTNAME_VERIFICATION_RESOLVE_HOST_NAME_KEY, + SECURITY_SSL_TRANSPORT_ENFORCE_HOSTNAME_VERIFICATION_RESOLVE_HOST_NAME.get(settings) + ); + builder.put( + NetworkModule.TRANSPORT_SSL_ENFORCE_HOSTNAME_VERIFICATION_KEY, + SECURITY_SSL_TRANSPORT_ENFORCE_HOSTNAME_VERIFICATION.get(settings) + ); + + return builder.build(); + } } diff --git a/src/main/java/org/opensearch/security/ssl/http/netty/SecuritySSLNettyHttpServerTransport.java b/src/main/java/org/opensearch/security/ssl/http/netty/SecuritySSLNettyHttpServerTransport.java deleted file mode 100644 index fc2f31b2b0..0000000000 --- a/src/main/java/org/opensearch/security/ssl/http/netty/SecuritySSLNettyHttpServerTransport.java +++ /dev/null @@ -1,168 +0,0 @@ -/* - * Copyright 2015-2017 floragunn GmbH - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - */ - -package org.opensearch.security.ssl.http.netty; - -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; - -import org.opensearch.common.network.NetworkService; -import org.opensearch.common.settings.ClusterSettings; -import org.opensearch.common.settings.Settings; -import org.opensearch.common.util.BigArrays; -import org.opensearch.core.xcontent.NamedXContentRegistry; -import org.opensearch.http.HttpChannel; -import org.opensearch.http.HttpHandlingSettings; -import org.opensearch.http.netty4.Netty4HttpChannel; -import org.opensearch.http.netty4.Netty4HttpServerTransport; -import org.opensearch.security.filter.SecurityRestFilter; -import org.opensearch.security.ssl.SecurityKeyStore; -import org.opensearch.security.ssl.SslExceptionHandler; -import org.opensearch.telemetry.tracing.Tracer; -import org.opensearch.threadpool.ThreadPool; -import org.opensearch.transport.SharedGroupFactory; - -import io.netty.channel.Channel; -import io.netty.channel.ChannelHandler; -import io.netty.channel.ChannelHandlerContext; -import io.netty.channel.ChannelInboundHandlerAdapter; -import io.netty.handler.codec.DecoderException; -import io.netty.handler.ssl.ApplicationProtocolNames; -import io.netty.handler.ssl.ApplicationProtocolNegotiationHandler; -import io.netty.handler.ssl.SslHandler; - -public class SecuritySSLNettyHttpServerTransport extends Netty4HttpServerTransport { - private static final Logger logger = LogManager.getLogger(SecuritySSLNettyHttpServerTransport.class); - private final SecurityKeyStore sks; - private final SslExceptionHandler errorHandler; - private final ChannelInboundHandlerAdapter headerVerifier; - - public SecuritySSLNettyHttpServerTransport( - final Settings settings, - final NetworkService networkService, - final BigArrays bigArrays, - final ThreadPool threadPool, - final SecurityKeyStore sks, - final NamedXContentRegistry namedXContentRegistry, - final ValidatingDispatcher dispatcher, - final SslExceptionHandler errorHandler, - ClusterSettings clusterSettings, - SharedGroupFactory sharedGroupFactory, - Tracer tracer, - SecurityRestFilter restFilter - ) { - super( - settings, - networkService, - bigArrays, - threadPool, - namedXContentRegistry, - dispatcher, - clusterSettings, - sharedGroupFactory, - tracer - ); - this.sks = sks; - this.errorHandler = errorHandler; - headerVerifier = new Netty4HttpRequestHeaderVerifier(restFilter, threadPool, settings); - } - - @Override - public ChannelHandler configureServerChannelHandler() { - return new SSLHttpChannelHandler(this, handlingSettings, sks); - } - - @Override - public void onException(HttpChannel channel, Exception cause0) { - Throwable cause = cause0; - - if (cause0 instanceof DecoderException && cause0 != null) { - cause = cause0.getCause(); - } - - errorHandler.logError(cause, true); - logger.error("Exception during establishing a SSL connection: " + cause, cause); - - super.onException(channel, cause0); - } - - protected class SSLHttpChannelHandler extends Netty4HttpServerTransport.HttpChannelHandler { - /** - * Application negotiation handler to select either HTTP 1.1 or HTTP 2 protocol, based - * on client/server ALPN negotiations. - */ - private class Http2OrHttpHandler extends ApplicationProtocolNegotiationHandler { - protected Http2OrHttpHandler() { - super(ApplicationProtocolNames.HTTP_1_1); - } - - @Override - protected void configurePipeline(ChannelHandlerContext ctx, String protocol) throws Exception { - if (ApplicationProtocolNames.HTTP_2.equals(protocol)) { - configureDefaultHttp2Pipeline(ctx.pipeline()); - } else if (ApplicationProtocolNames.HTTP_1_1.equals(protocol)) { - configureDefaultHttpPipeline(ctx.pipeline()); - } else { - throw new IllegalStateException("Unknown application protocol: " + protocol); - } - } - - @Override - public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { - super.exceptionCaught(ctx, cause); - Netty4HttpChannel channel = ctx.channel().attr(HTTP_CHANNEL_KEY).get(); - if (channel != null) { - if (cause instanceof Error) { - onException(channel, new Exception(cause)); - } else { - onException(channel, (Exception) cause); - } - } - } - } - - protected SSLHttpChannelHandler( - Netty4HttpServerTransport transport, - final HttpHandlingSettings handlingSettings, - final SecurityKeyStore odsks - ) { - super(transport, handlingSettings); - } - - @Override - protected void initChannel(Channel ch) throws Exception { - super.initChannel(ch); - final SslHandler sslHandler = new SslHandler(SecuritySSLNettyHttpServerTransport.this.sks.createHTTPSSLEngine()); - ch.pipeline().addFirst("ssl_http", sslHandler); - } - - @Override - protected void configurePipeline(Channel ch) { - ch.pipeline().addLast(new Http2OrHttpHandler()); - } - } - - @Override - protected ChannelInboundHandlerAdapter createHeaderVerifier() { - return headerVerifier; - } - - @Override - protected ChannelInboundHandlerAdapter createDecompressor() { - return new Netty4ConditionalDecompressor(); - } -} diff --git a/src/main/java/org/opensearch/security/ssl/transport/DualModeSSLHandler.java b/src/main/java/org/opensearch/security/ssl/transport/DualModeSSLHandler.java deleted file mode 100644 index a7961f864b..0000000000 --- a/src/main/java/org/opensearch/security/ssl/transport/DualModeSSLHandler.java +++ /dev/null @@ -1,89 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - * - * Modifications Copyright OpenSearch Contributors. See - * GitHub history for details. - */ -package org.opensearch.security.ssl.transport; - -import java.nio.charset.StandardCharsets; -import java.util.List; -import javax.net.ssl.SSLException; - -import com.google.common.annotations.VisibleForTesting; -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; - -import org.opensearch.security.ssl.SecurityKeyStore; -import org.opensearch.security.ssl.util.SSLConnectionTestUtil; -import org.opensearch.security.ssl.util.TLSUtil; - -import io.netty.buffer.ByteBuf; -import io.netty.buffer.Unpooled; -import io.netty.channel.ChannelFutureListener; -import io.netty.channel.ChannelHandlerContext; -import io.netty.channel.ChannelPipeline; -import io.netty.handler.codec.ByteToMessageDecoder; -import io.netty.handler.ssl.SslHandler; - -/** - * Modifies the current pipeline dynamically to enable TLS - */ -public class DualModeSSLHandler extends ByteToMessageDecoder { - - private static final Logger logger = LogManager.getLogger(DualModeSSLHandler.class); - private final SecurityKeyStore securityKeyStore; - - private final SslHandler providedSSLHandler; - - public DualModeSSLHandler(SecurityKeyStore securityKeyStore) { - this(securityKeyStore, null); - } - - @VisibleForTesting - protected DualModeSSLHandler(SecurityKeyStore securityKeyStore, SslHandler providedSSLHandler) { - this.securityKeyStore = securityKeyStore; - this.providedSSLHandler = providedSSLHandler; - } - - @Override - protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out) throws Exception { - // Will use the first six bytes to detect a protocol. - if (in.readableBytes() < 6) { - return; - } - int offset = in.readerIndex(); - if (in.getCharSequence(offset, 6, StandardCharsets.UTF_8).equals(SSLConnectionTestUtil.DUAL_MODE_CLIENT_HELLO_MSG)) { - logger.debug("Received DualSSL Client Hello message"); - ByteBuf responseBuffer = Unpooled.buffer(6); - responseBuffer.writeCharSequence(SSLConnectionTestUtil.DUAL_MODE_SERVER_HELLO_MSG, StandardCharsets.UTF_8); - ctx.writeAndFlush(responseBuffer).addListener(ChannelFutureListener.CLOSE); - return; - } - - if (TLSUtil.isTLS(in)) { - logger.debug("Identified request as SSL request"); - enableSsl(ctx); - } else { - logger.debug("Identified request as non SSL request, running in HTTP mode as dual mode is enabled"); - ctx.pipeline().remove(this); - } - } - - private void enableSsl(ChannelHandlerContext ctx) throws SSLException { - SslHandler sslHandler; - if (providedSSLHandler != null) { - sslHandler = providedSSLHandler; - } else { - sslHandler = new SslHandler(securityKeyStore.createServerTransportSSLEngine()); - } - ChannelPipeline p = ctx.pipeline(); - p.addAfter("port_unification_handler", "ssl_server", sslHandler); - p.remove(this); - logger.debug("Removed port unification handler and added SSL handler as incoming request is SSL"); - } -} diff --git a/src/main/java/org/opensearch/security/ssl/transport/SecuritySSLNettyTransport.java b/src/main/java/org/opensearch/security/ssl/transport/SecuritySSLNettyTransport.java deleted file mode 100644 index 5be3424528..0000000000 --- a/src/main/java/org/opensearch/security/ssl/transport/SecuritySSLNettyTransport.java +++ /dev/null @@ -1,308 +0,0 @@ -/* - * Copyright 2015-2017 floragunn GmbH - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - */ - -/* - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - * - * Modifications Copyright OpenSearch Contributors. See - * GitHub history for details. - */ - -package org.opensearch.security.ssl.transport; - -import java.net.InetSocketAddress; -import java.net.SocketAddress; -import java.security.AccessController; -import java.security.PrivilegedAction; -import javax.net.ssl.SSLEngine; -import javax.net.ssl.SSLException; - -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; - -import org.opensearch.ExceptionsHelper; -import org.opensearch.OpenSearchSecurityException; -import org.opensearch.Version; -import org.opensearch.cluster.node.DiscoveryNode; -import org.opensearch.common.network.NetworkService; -import org.opensearch.common.settings.Settings; -import org.opensearch.common.util.PageCacheRecycler; -import org.opensearch.core.common.io.stream.NamedWriteableRegistry; -import org.opensearch.core.indices.breaker.CircuitBreakerService; -import org.opensearch.security.ssl.SecurityKeyStore; -import org.opensearch.security.ssl.SslExceptionHandler; -import org.opensearch.security.ssl.util.SSLConfigConstants; -import org.opensearch.security.ssl.util.SSLConnectionTestResult; -import org.opensearch.security.ssl.util.SSLConnectionTestUtil; -import org.opensearch.telemetry.tracing.Tracer; -import org.opensearch.threadpool.ThreadPool; -import org.opensearch.transport.SharedGroupFactory; -import org.opensearch.transport.TcpChannel; -import org.opensearch.transport.netty4.Netty4Transport; - -import io.netty.channel.Channel; -import io.netty.channel.ChannelHandler; -import io.netty.channel.ChannelHandlerContext; -import io.netty.channel.ChannelOutboundHandlerAdapter; -import io.netty.channel.ChannelPromise; -import io.netty.handler.codec.DecoderException; -import io.netty.handler.ssl.SslHandler; - -public class SecuritySSLNettyTransport extends Netty4Transport { - - private static final Logger logger = LogManager.getLogger(SecuritySSLNettyTransport.class); - private final SecurityKeyStore ossks; - private final SslExceptionHandler errorHandler; - private final SSLConfig SSLConfig; - - public SecuritySSLNettyTransport( - final Settings settings, - final Version version, - final ThreadPool threadPool, - final NetworkService networkService, - final PageCacheRecycler pageCacheRecycler, - final NamedWriteableRegistry namedWriteableRegistry, - final CircuitBreakerService circuitBreakerService, - final SecurityKeyStore ossks, - final SslExceptionHandler errorHandler, - SharedGroupFactory sharedGroupFactory, - final SSLConfig SSLConfig, - final Tracer tracer - ) { - super( - settings, - version, - threadPool, - networkService, - pageCacheRecycler, - namedWriteableRegistry, - circuitBreakerService, - sharedGroupFactory, - tracer - ); - - this.ossks = ossks; - this.errorHandler = errorHandler; - this.SSLConfig = SSLConfig; - } - - // This allows for testing log messages - Logger getLogger() { - return logger; - } - - @Override - public void onException(TcpChannel channel, Exception e) { - - Throwable cause = e; - - if (e instanceof DecoderException && e != null) { - cause = e.getCause(); - } - - errorHandler.logError(cause, false); - getLogger().error("Exception during establishing a SSL connection: " + cause, cause); - - if (channel == null || !channel.isOpen()) { - throw new OpenSearchSecurityException("The provided TCP channel is invalid.", e); - } - super.onException(channel, e); - } - - @Override - protected ChannelHandler getServerChannelInitializer(String name) { - return new SSLServerChannelInitializer(name); - } - - @Override - protected ChannelHandler getClientChannelInitializer(DiscoveryNode node) { - return new SSLClientChannelInitializer(node); - } - - protected class SSLServerChannelInitializer extends Netty4Transport.ServerChannelInitializer { - - public SSLServerChannelInitializer(String name) { - super(name); - } - - @Override - protected void initChannel(Channel ch) throws Exception { - super.initChannel(ch); - - boolean dualModeEnabled = SSLConfig.isDualModeEnabled(); - if (dualModeEnabled) { - logger.info("SSL Dual mode enabled, using port unification handler"); - final ChannelHandler portUnificationHandler = new DualModeSSLHandler(ossks); - ch.pipeline().addFirst("port_unification_handler", portUnificationHandler); - } else { - final SslHandler sslHandler = new SslHandler(ossks.createServerTransportSSLEngine()); - ch.pipeline().addFirst("ssl_server", sslHandler); - } - } - - @Override - public final void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { - if (cause instanceof DecoderException && cause != null) { - cause = cause.getCause(); - } - - errorHandler.logError(cause, false); - getLogger().error("Exception during establishing a SSL connection: " + cause, cause); - - super.exceptionCaught(ctx, cause); - } - } - - protected static class ClientSSLHandler extends ChannelOutboundHandlerAdapter { - private final Logger log = LogManager.getLogger(this.getClass()); - private final SecurityKeyStore sks; - private final boolean hostnameVerificationEnabled; - private final boolean hostnameVerificationResovleHostName; - private final SslExceptionHandler errorHandler; - - private ClientSSLHandler( - final SecurityKeyStore sks, - final boolean hostnameVerificationEnabled, - final boolean hostnameVerificationResovleHostName, - final SslExceptionHandler errorHandler - ) { - this.sks = sks; - this.hostnameVerificationEnabled = hostnameVerificationEnabled; - this.hostnameVerificationResovleHostName = hostnameVerificationResovleHostName; - this.errorHandler = errorHandler; - } - - @Override - public final void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { - if (cause instanceof DecoderException && cause != null) { - cause = cause.getCause(); - } - - errorHandler.logError(cause, false); - logger.error("Exception during establishing a SSL connection: " + cause, cause); - - super.exceptionCaught(ctx, cause); - } - - @Override - public void connect(ChannelHandlerContext ctx, SocketAddress remoteAddress, SocketAddress localAddress, ChannelPromise promise) - throws Exception { - SSLEngine engine = null; - try { - if (hostnameVerificationEnabled) { - final InetSocketAddress inetSocketAddress = (InetSocketAddress) remoteAddress; - String hostname = null; - if (hostnameVerificationResovleHostName) { - hostname = inetSocketAddress.getHostName(); - } else { - hostname = inetSocketAddress.getHostString(); - } - - if (log.isDebugEnabled()) { - log.debug( - "Hostname of peer is {} ({}/{}) with hostnameVerificationResovleHostName: {}", - hostname, - inetSocketAddress.getHostName(), - inetSocketAddress.getHostString(), - hostnameVerificationResovleHostName - ); - } - - engine = sks.createClientTransportSSLEngine(hostname, inetSocketAddress.getPort()); - } else { - engine = sks.createClientTransportSSLEngine(null, -1); - } - } catch (final SSLException e) { - throw ExceptionsHelper.convertToOpenSearchException(e); - } - final SslHandler sslHandler = new SslHandler(engine); - ctx.pipeline().replace(this, "ssl_client", sslHandler); - super.connect(ctx, remoteAddress, localAddress, promise); - } - } - - protected class SSLClientChannelInitializer extends Netty4Transport.ClientChannelInitializer { - private final boolean hostnameVerificationEnabled; - private final boolean hostnameVerificationResovleHostName; - private final DiscoveryNode node; - private SSLConnectionTestResult connectionTestResult; - - @SuppressWarnings("removal") - public SSLClientChannelInitializer(DiscoveryNode node) { - this.node = node; - hostnameVerificationEnabled = settings.getAsBoolean( - SSLConfigConstants.SECURITY_SSL_TRANSPORT_ENFORCE_HOSTNAME_VERIFICATION, - true - ); - hostnameVerificationResovleHostName = settings.getAsBoolean( - SSLConfigConstants.SECURITY_SSL_TRANSPORT_ENFORCE_HOSTNAME_VERIFICATION_RESOLVE_HOST_NAME, - true - ); - - connectionTestResult = SSLConnectionTestResult.SSL_AVAILABLE; - if (SSLConfig.isDualModeEnabled()) { - SSLConnectionTestUtil sslConnectionTestUtil = new SSLConnectionTestUtil( - node.getAddress().getAddress(), - node.getAddress().getPort() - ); - connectionTestResult = AccessController.doPrivileged( - (PrivilegedAction) sslConnectionTestUtil::testConnection - ); - } - } - - @Override - protected void initChannel(Channel ch) throws Exception { - super.initChannel(ch); - - if (connectionTestResult == SSLConnectionTestResult.OPENSEARCH_PING_FAILED) { - logger.error( - "SSL dual mode is enabled but dual mode handshake and OpenSearch ping has failed during client connection setup, closing channel" - ); - ch.close(); - return; - } - - if (connectionTestResult == SSLConnectionTestResult.SSL_AVAILABLE) { - logger.debug("Connection to {} needs to be ssl, adding ssl handler to the client channel ", node.getHostName()); - ch.pipeline() - .addFirst( - "client_ssl_handler", - new ClientSSLHandler(ossks, hostnameVerificationEnabled, hostnameVerificationResovleHostName, errorHandler) - ); - } else { - logger.debug("Connection to {} needs to be non ssl", node.getHostName()); - } - } - - @Override - public final void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { - if (cause instanceof DecoderException && cause != null) { - cause = cause.getCause(); - } - - errorHandler.logError(cause, false); - getLogger().error("Exception during establishing a SSL connection: " + cause, cause); - - super.exceptionCaught(ctx, cause); - } - } -} diff --git a/src/test/java/org/opensearch/security/ssl/transport/DualModeSSLHandlerTests.java b/src/test/java/org/opensearch/security/ssl/transport/DualModeSSLHandlerTests.java deleted file mode 100644 index e71e77d414..0000000000 --- a/src/test/java/org/opensearch/security/ssl/transport/DualModeSSLHandlerTests.java +++ /dev/null @@ -1,120 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - * - * Modifications Copyright OpenSearch Contributors. See - * GitHub history for details. - */ -package org.opensearch.security.ssl.transport; - -import java.nio.charset.StandardCharsets; -import java.util.ArrayList; -import java.util.List; - -import org.junit.Assert; -import org.junit.Before; -import org.junit.Test; - -import org.opensearch.security.ssl.SecurityKeyStore; -import org.opensearch.security.ssl.util.SSLConnectionTestUtil; - -import io.netty.buffer.ByteBuf; -import io.netty.buffer.ByteBufAllocator; -import io.netty.channel.ChannelFuture; -import io.netty.channel.ChannelHandlerContext; -import io.netty.channel.ChannelPipeline; -import io.netty.handler.ssl.SslHandler; -import org.mockito.ArgumentCaptor; -import org.mockito.Mockito; - -import static org.opensearch.transport.NettyAllocator.getAllocator; - -public class DualModeSSLHandlerTests { - - public static final int TLS_MAJOR_VERSION = 3; - public static final int TLS_MINOR_VERSION = 0; - private static final ByteBufAllocator ALLOCATOR = getAllocator(); - - private SecurityKeyStore securityKeyStore; - private ChannelPipeline pipeline; - private ChannelHandlerContext ctx; - private SslHandler sslHandler; - - @Before - public void setup() { - pipeline = Mockito.mock(ChannelPipeline.class); - ctx = Mockito.mock(ChannelHandlerContext.class); - Mockito.when(ctx.pipeline()).thenReturn(pipeline); - - securityKeyStore = Mockito.mock(SecurityKeyStore.class); - sslHandler = Mockito.mock(SslHandler.class); - } - - @Test - public void testInvalidMessage() throws Exception { - DualModeSSLHandler handler = new DualModeSSLHandler(securityKeyStore); - - handler.decode(ctx, ALLOCATOR.buffer(4), null); - // ensure pipeline is not fetched and manipulated - Mockito.verify(ctx, Mockito.times(0)).pipeline(); - } - - @Test - public void testValidTLSMessage() throws Exception { - DualModeSSLHandler handler = new DualModeSSLHandler(securityKeyStore, sslHandler); - - ByteBuf buffer = ALLOCATOR.buffer(6); - buffer.writeByte(20); - buffer.writeByte(TLS_MAJOR_VERSION); - buffer.writeByte(TLS_MINOR_VERSION); - buffer.writeByte(100); - buffer.writeByte(0); - buffer.writeByte(0); - - handler.decode(ctx, buffer, null); - // ensure ssl handler is added - Mockito.verify(ctx, Mockito.times(1)).pipeline(); - Mockito.verify(pipeline, Mockito.times(1)).addAfter("port_unification_handler", "ssl_server", sslHandler); - Mockito.verify(pipeline, Mockito.times(1)).remove(handler); - } - - @Test - public void testNonTLSMessage() throws Exception { - DualModeSSLHandler handler = new DualModeSSLHandler(securityKeyStore, sslHandler); - - ByteBuf buffer = ALLOCATOR.buffer(6); - - for (int i = 0; i < 6; i++) { - buffer.writeByte(1); - } - - handler.decode(ctx, buffer, null); - // ensure ssl handler is added - Mockito.verify(ctx, Mockito.times(1)).pipeline(); - Mockito.verify(pipeline, Mockito.times(0)).addAfter("port_unification_handler", "ssl_server", sslHandler); - Mockito.verify(pipeline, Mockito.times(1)).remove(handler); - } - - @Test - public void testDualModeClientHelloMessage() throws Exception { - ChannelFuture channelFuture = Mockito.mock(ChannelFuture.class); - Mockito.when(ctx.writeAndFlush(Mockito.any())).thenReturn(channelFuture); - Mockito.when(channelFuture.addListener(Mockito.any())).thenReturn(channelFuture); - - ByteBuf buffer = ALLOCATOR.buffer(6); - buffer.writeCharSequence(SSLConnectionTestUtil.DUAL_MODE_CLIENT_HELLO_MSG, StandardCharsets.UTF_8); - - DualModeSSLHandler handler = new DualModeSSLHandler(securityKeyStore, sslHandler); - List decodedObjs = new ArrayList<>(); - handler.decode(ctx, buffer, decodedObjs); - - ArgumentCaptor serverHelloReplyBuffer = ArgumentCaptor.forClass(ByteBuf.class); - Mockito.verify(ctx, Mockito.times(1)).writeAndFlush(serverHelloReplyBuffer.capture()); - - String actualReply = serverHelloReplyBuffer.getValue().getCharSequence(0, 6, StandardCharsets.UTF_8).toString(); - Assert.assertEquals(SSLConnectionTestUtil.DUAL_MODE_SERVER_HELLO_MSG, actualReply); - } -} diff --git a/src/test/java/org/opensearch/security/ssl/transport/SecuritySSLNettyTransportTests.java b/src/test/java/org/opensearch/security/ssl/transport/SecuritySSLNettyTransportTests.java deleted file mode 100644 index 32e0f48fac..0000000000 --- a/src/test/java/org/opensearch/security/ssl/transport/SecuritySSLNettyTransportTests.java +++ /dev/null @@ -1,201 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - * - * Modifications Copyright OpenSearch Contributors. See - * GitHub history for details. - */ - -package org.opensearch.security.ssl.transport; - -import java.util.Collections; - -import org.apache.logging.log4j.Logger; -import org.junit.Before; -import org.junit.Rule; -import org.junit.Test; - -import org.opensearch.OpenSearchSecurityException; -import org.opensearch.Version; -import org.opensearch.cluster.node.DiscoveryNode; -import org.opensearch.common.network.NetworkService; -import org.opensearch.common.settings.Settings; -import org.opensearch.common.util.PageCacheRecycler; -import org.opensearch.core.common.io.stream.NamedWriteableRegistry; -import org.opensearch.core.indices.breaker.CircuitBreakerService; -import org.opensearch.security.ssl.SecurityKeyStore; -import org.opensearch.security.ssl.SslExceptionHandler; -import org.opensearch.security.ssl.transport.SecuritySSLNettyTransport.SSLClientChannelInitializer; -import org.opensearch.security.ssl.transport.SecuritySSLNettyTransport.SSLServerChannelInitializer; -import org.opensearch.telemetry.tracing.Tracer; -import org.opensearch.threadpool.ThreadPool; -import org.opensearch.transport.FakeTcpChannel; -import org.opensearch.transport.SharedGroupFactory; -import org.opensearch.transport.TcpChannel; - -import io.netty.channel.ChannelHandler; -import io.netty.channel.ChannelHandlerContext; -import io.netty.handler.codec.DecoderException; -import org.mockito.Mock; -import org.mockito.junit.MockitoJUnit; -import org.mockito.junit.MockitoRule; - -import static org.hamcrest.MatcherAssert.assertThat; -import static org.hamcrest.Matchers.instanceOf; -import static org.hamcrest.Matchers.is; -import static org.hamcrest.Matchers.notNullValue; -import static org.junit.Assert.assertThrows; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.spy; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; - -public class SecuritySSLNettyTransportTests { - - @Mock - private Version version; - @Mock - private ThreadPool threadPool; - @Mock - private PageCacheRecycler pageCacheRecycler; - @Mock - private NamedWriteableRegistry namedWriteableRegistry; - @Mock - private CircuitBreakerService circuitBreakerService; - @Mock - private Tracer trace; - @Mock - private SecurityKeyStore ossks; - @Mock - private SslExceptionHandler sslExceptionHandler; - @Mock - private DiscoveryNode discoveryNode; - - // This initializes all the above mocks - @Rule - public MockitoRule rule = MockitoJUnit.rule(); - - private NetworkService networkService; - private SharedGroupFactory sharedGroupFactory; - private Logger mockLogger; - private SSLConfig sslConfig; - private SecuritySSLNettyTransport securitySSLNettyTransport; - Throwable testCause = new Throwable("Test Cause"); - - @Before - public void setup() { - - networkService = new NetworkService(Collections.emptyList()); - sharedGroupFactory = new SharedGroupFactory(Settings.EMPTY); - - sslConfig = new SSLConfig(Settings.EMPTY); - mockLogger = mock(Logger.class); - - securitySSLNettyTransport = spy( - new SecuritySSLNettyTransport( - Settings.EMPTY, - version, - threadPool, - networkService, - pageCacheRecycler, - namedWriteableRegistry, - circuitBreakerService, - ossks, - sslExceptionHandler, - sharedGroupFactory, - sslConfig, - trace - ) - ); - } - - @Test - public void OnException_withNullChannelShouldThrowException() { - - OpenSearchSecurityException exception = new OpenSearchSecurityException("The provided TCP channel is invalid"); - assertThrows(OpenSearchSecurityException.class, () -> securitySSLNettyTransport.onException(null, exception)); - } - - @Test - public void OnException_withClosedChannelShouldThrowException() { - - TcpChannel channel = new FakeTcpChannel(); - channel.close(); - OpenSearchSecurityException exception = new OpenSearchSecurityException("The provided TCP channel is invalid"); - assertThrows(OpenSearchSecurityException.class, () -> securitySSLNettyTransport.onException(channel, exception)); - } - - @Test - public void OnException_withNullExceptionShouldSucceed() { - - TcpChannel channel = new FakeTcpChannel(); - securitySSLNettyTransport.onException(channel, null); - verify(securitySSLNettyTransport, times(1)).onException(channel, null); - channel.close(); - } - - @Test - public void OnException_withDecoderExceptionShouldGetCause() { - - when(securitySSLNettyTransport.getLogger()).thenReturn(mockLogger); - DecoderException exception = new DecoderException("Test Exception", testCause); - TcpChannel channel = new FakeTcpChannel(); - securitySSLNettyTransport.onException(channel, exception); - verify(mockLogger, times(1)).error("Exception during establishing a SSL connection: " + exception.getCause(), exception.getCause()); - } - - @Test - public void getServerChannelInitializer_shouldReturnValidServerChannel() { - - ChannelHandler channelHandler = securitySSLNettyTransport.getServerChannelInitializer("test-server-channel"); - assertThat(channelHandler, is(notNullValue())); - assertThat(channelHandler, is(instanceOf(SSLServerChannelInitializer.class))); - } - - @Test - public void getClientChannelInitializer_shouldReturnValidClientChannel() { - ChannelHandler channelHandler = securitySSLNettyTransport.getClientChannelInitializer(discoveryNode); - assertThat(channelHandler, is(notNullValue())); - assertThat(channelHandler, is(instanceOf(SSLClientChannelInitializer.class))); - } - - @Test - public void exceptionWithServerChannelHandlerContext_nonNullDecoderExceptionShouldGetCause() throws Exception { - when(securitySSLNettyTransport.getLogger()).thenReturn(mockLogger); - Throwable exception = new DecoderException("Test Exception", testCause); - ChannelHandlerContext ctx = mock(ChannelHandlerContext.class); - securitySSLNettyTransport.getServerChannelInitializer(discoveryNode.getName()).exceptionCaught(ctx, exception); - verify(mockLogger, times(1)).error("Exception during establishing a SSL connection: " + exception.getCause(), exception.getCause()); - } - - @Test - public void exceptionWithServerChannelHandlerContext_nonNullCauseOnlyShouldNotGetCause() throws Exception { - when(securitySSLNettyTransport.getLogger()).thenReturn(mockLogger); - Throwable exception = new OpenSearchSecurityException("Test Exception", testCause); - ChannelHandlerContext ctx = mock(ChannelHandlerContext.class); - securitySSLNettyTransport.getServerChannelInitializer(discoveryNode.getName()).exceptionCaught(ctx, exception); - verify(mockLogger, times(1)).error("Exception during establishing a SSL connection: " + exception, exception); - } - - @Test - public void exceptionWithClientChannelHandlerContext_nonNullDecoderExceptionShouldGetCause() throws Exception { - when(securitySSLNettyTransport.getLogger()).thenReturn(mockLogger); - Throwable exception = new DecoderException("Test Exception", testCause); - ChannelHandlerContext ctx = mock(ChannelHandlerContext.class); - securitySSLNettyTransport.getClientChannelInitializer(discoveryNode).exceptionCaught(ctx, exception); - verify(mockLogger, times(1)).error("Exception during establishing a SSL connection: " + exception.getCause(), exception.getCause()); - } - - @Test - public void exceptionWithClientChannelHandlerContext_nonNullCauseOnlyShouldNotGetCause() throws Exception { - when(securitySSLNettyTransport.getLogger()).thenReturn(mockLogger); - Throwable exception = new OpenSearchSecurityException("Test Exception", testCause); - ChannelHandlerContext ctx = mock(ChannelHandlerContext.class); - securitySSLNettyTransport.getClientChannelInitializer(discoveryNode).exceptionCaught(ctx, exception); - verify(mockLogger, times(1)).error("Exception during establishing a SSL connection: " + exception, exception); - } -}