diff --git a/libp2p/src/main/kotlin/io/libp2p/security/tls/TLSSecureChannel.kt b/libp2p/src/main/kotlin/io/libp2p/security/tls/TLSSecureChannel.kt index 4d162fc9e..ad4d229f5 100644 --- a/libp2p/src/main/kotlin/io/libp2p/security/tls/TLSSecureChannel.kt +++ b/libp2p/src/main/kotlin/io/libp2p/security/tls/TLSSecureChannel.kt @@ -71,6 +71,13 @@ class TlsSecureChannel(private val localKey: PrivKey, private val muxers: List): TlsSecureChannel { + return TlsSecureChannel(localKey, muxerIds, "ECDSA") } } @@ -132,8 +139,15 @@ fun buildTlsHandler( cause = cause.cause handshakeComplete.completeExceptionally(cause) } else { - val negotiatedProtocols = sslContext.applicationProtocolNegotiator().protocols() - val selectedMuxer = muxers.findBestMatch(negotiatedProtocols) + val nextProtocol = handler.applicationProtocol() + val selectedMuxer = muxers + .filter { mux -> + mux.protocolDescriptor.protocolMatcher.matches(nextProtocol) + } + .map { mux -> + NegotiatedStreamMuxer(mux, nextProtocol) + } + .firstOrNull() handshakeComplete.complete( SecureChannel.Session( PeerId.fromPubKey(localKey.publicKey()), @@ -151,15 +165,6 @@ fun buildTlsHandler( private val > List.allProtocols: List get() = this.flatMap { it.protocolDescriptor.announceProtocols } -private fun List.findBestMatch(remoteProtocols: List): NegotiatedStreamMuxer? = - this.firstNotNullOfOrNull { muxer -> - remoteProtocols.firstOrNull { remoteProtocol -> - muxer.protocolDescriptor.protocolMatcher.matches(remoteProtocol) - }?.let { negotiatedProtocol -> - NegotiatedStreamMuxer(muxer, negotiatedProtocol) - } - } - private class ChannelSetup( private val localKey: PrivKey, private val muxers: List, diff --git a/libp2p/src/test/java/io/libp2p/core/HostTestJava.java b/libp2p/src/test/java/io/libp2p/core/HostTestJava.java index c1284aa01..f8f92ef86 100644 --- a/libp2p/src/test/java/io/libp2p/core/HostTestJava.java +++ b/libp2p/src/test/java/io/libp2p/core/HostTestJava.java @@ -37,7 +37,7 @@ void ping() throws Exception { Host clientHost = new HostBuilder() .transport(TcpTransport::new) - .secureChannel(TlsSecureChannel::new) + .secureChannel((k, m) -> new TlsSecureChannel(k, m, "ECDSA")) .muxer(StreamMuxerProtocol::getYamux) .build();