diff --git a/libp2p/src/main/kotlin/io/libp2p/core/mux/StreamMuxerProtocol.kt b/libp2p/src/main/kotlin/io/libp2p/core/mux/StreamMuxerProtocol.kt index 0c961c0cb..879cd60cd 100644 --- a/libp2p/src/main/kotlin/io/libp2p/core/mux/StreamMuxerProtocol.kt +++ b/libp2p/src/main/kotlin/io/libp2p/core/mux/StreamMuxerProtocol.kt @@ -3,6 +3,7 @@ package io.libp2p.core.mux import io.libp2p.core.multistream.MultistreamProtocol import io.libp2p.core.multistream.ProtocolBinding import io.libp2p.mux.mplex.MplexStreamMuxer +import io.libp2p.mux.yamux.YamuxStreamMuxer fun interface StreamMuxerProtocol { @@ -18,5 +19,15 @@ fun interface StreamMuxerProtocol { multistreamProtocol ) } + + @JvmStatic + val Yamux = StreamMuxerProtocol { multistreamProtocol, protocols -> + YamuxStreamMuxer( + multistreamProtocol.createMultistream( + protocols + ).toStreamHandler(), + multistreamProtocol + ) + } } } diff --git a/libp2p/src/main/kotlin/io/libp2p/mux/MuxFrame.kt b/libp2p/src/main/kotlin/io/libp2p/mux/MuxFrame.kt deleted file mode 100644 index 2d14c3f2e..000000000 --- a/libp2p/src/main/kotlin/io/libp2p/mux/MuxFrame.kt +++ /dev/null @@ -1,23 +0,0 @@ -package io.libp2p.mux - -import io.libp2p.etc.types.toByteArray -import io.libp2p.etc.types.toHex -import io.libp2p.etc.util.netty.mux.MuxId -import io.netty.buffer.ByteBuf -import io.netty.buffer.DefaultByteBufHolder -import io.netty.buffer.Unpooled - -open class MuxFrame(val id: MuxId, val flag: Flag, val data: ByteBuf? = null) : - DefaultByteBufHolder(data ?: Unpooled.EMPTY_BUFFER) { - - enum class Flag { - OPEN, - DATA, - CLOSE, - RESET - } - - override fun toString(): String { - return "MuxFrame(id=$id, flag=$flag, data=${data?.toByteArray()?.toHex()})" - } -} diff --git a/libp2p/src/main/kotlin/io/libp2p/mux/MuxHandler.kt b/libp2p/src/main/kotlin/io/libp2p/mux/MuxHandler.kt index ce10cd67e..71a56ed6a 100644 --- a/libp2p/src/main/kotlin/io/libp2p/mux/MuxHandler.kt +++ b/libp2p/src/main/kotlin/io/libp2p/mux/MuxHandler.kt @@ -9,22 +9,18 @@ import io.libp2p.core.mux.StreamMuxer import io.libp2p.etc.CONNECTION import io.libp2p.etc.STREAM import io.libp2p.etc.types.forward -import io.libp2p.etc.types.sliceMaxSize import io.libp2p.etc.util.netty.mux.AbstractMuxHandler import io.libp2p.etc.util.netty.mux.MuxChannel import io.libp2p.etc.util.netty.mux.MuxChannelInitializer -import io.libp2p.etc.util.netty.mux.MuxId import io.libp2p.transport.implementation.StreamOverNetty import io.netty.buffer.ByteBuf import io.netty.channel.ChannelHandlerContext import java.util.concurrent.CompletableFuture -import java.util.concurrent.atomic.AtomicLong abstract class MuxHandler( private val ready: CompletableFuture?, inboundStreamHandler: StreamHandler<*> ) : AbstractMuxHandler(), StreamMuxer.Session { - private val idGenerator = AtomicLong(0xF) protected abstract val multistreamProtocol: MultistreamProtocol protected abstract val maxFrameDataLength: Int @@ -38,45 +34,6 @@ abstract class MuxHandler( ready?.complete(this) } - override fun channelRead(ctx: ChannelHandlerContext, msg: Any) { - msg as MuxFrame - when (msg.flag) { - MuxFrame.Flag.OPEN -> onRemoteOpen(msg.id) - MuxFrame.Flag.CLOSE -> onRemoteDisconnect(msg.id) - MuxFrame.Flag.RESET -> onRemoteClose(msg.id) - MuxFrame.Flag.DATA -> childRead(msg.id, msg.data!!) - } - } - - override fun onChildWrite(child: MuxChannel, data: ByteBuf) { - val ctx = getChannelHandlerContext() - data.sliceMaxSize(maxFrameDataLength) - .map { frameSliceBuf -> - MuxFrame(child.id, MuxFrame.Flag.DATA, frameSliceBuf) - }.forEach { muxFrame -> - ctx.write(muxFrame) - } - ctx.flush() - } - - override fun onLocalOpen(child: MuxChannel) { - getChannelHandlerContext().writeAndFlush(MuxFrame(child.id, MuxFrame.Flag.OPEN)) - } - - override fun onLocalDisconnect(child: MuxChannel) { - getChannelHandlerContext().writeAndFlush(MuxFrame(child.id, MuxFrame.Flag.CLOSE)) - } - - override fun onLocalClose(child: MuxChannel) { - getChannelHandlerContext().writeAndFlush(MuxFrame(child.id, MuxFrame.Flag.RESET)) - } - - override fun onRemoteCreated(child: MuxChannel) { - } - - override fun generateNextId() = - MuxId(getChannelHandlerContext().channel().id(), idGenerator.incrementAndGet(), true) - private fun createStream(channel: MuxChannel): Stream { val connection = ctx!!.channel().attr(CONNECTION).get() val stream = StreamOverNetty(channel, connection, channel.initiator) diff --git a/libp2p/src/main/kotlin/io/libp2p/mux/mplex/MplexFlag.kt b/libp2p/src/main/kotlin/io/libp2p/mux/mplex/MplexFlag.kt new file mode 100644 index 000000000..6cc4c685b --- /dev/null +++ b/libp2p/src/main/kotlin/io/libp2p/mux/mplex/MplexFlag.kt @@ -0,0 +1,62 @@ +/* + * Copyright 2019 BLK Technologies Limited (web3labs.com). + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ +package io.libp2p.mux.mplex + +import io.libp2p.mux.mplex.MplexFlag.Type.* + +/** + * Contains all the permissible values for flags in the mplex protocol. + */ +enum class MplexFlag( + val value: Int, + val type: Type +) { + NewStream(0, OPEN), + MessageReceiver(1, DATA), + MessageInitiator(2, DATA), + CloseReceiver(3, CLOSE), + CloseInitiator(4, CLOSE), + ResetReceiver(5, RESET), + ResetInitiator(6, RESET); + + enum class Type { + OPEN, + DATA, + CLOSE, + RESET + } + + val isInitiator get() = value % 2 == 0 + + private val initiatorString get() = when (isInitiator) { + true -> "init" + false -> "resp" + } + + override fun toString(): String = "$type($initiatorString)" + + companion object { + private val valueToFlag = MplexFlag.values().associateBy { it.value } + + fun getByValue(flagValue: Int): MplexFlag = + valueToFlag[flagValue] ?: throw IllegalArgumentException("Invalid Mplex stream tag: $flagValue") + + fun getByType(type: Type, initiator: Boolean): MplexFlag = + when (type) { + OPEN -> NewStream + DATA -> if (initiator) MessageInitiator else MessageReceiver + CLOSE -> if (initiator) CloseInitiator else CloseReceiver + RESET -> if (initiator) ResetInitiator else ResetReceiver + } + } +} diff --git a/libp2p/src/main/kotlin/io/libp2p/mux/mplex/MplexFlags.kt b/libp2p/src/main/kotlin/io/libp2p/mux/mplex/MplexFlags.kt deleted file mode 100644 index b42431260..000000000 --- a/libp2p/src/main/kotlin/io/libp2p/mux/mplex/MplexFlags.kt +++ /dev/null @@ -1,52 +0,0 @@ -/* - * Copyright 2019 BLK Technologies Limited (web3labs.com). - * - * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on - * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the - * specific language governing permissions and limitations under the License. - */ -package io.libp2p.mux.mplex - -import io.libp2p.core.Libp2pException -import io.libp2p.mux.MuxFrame -import io.libp2p.mux.MuxFrame.Flag.CLOSE -import io.libp2p.mux.MuxFrame.Flag.DATA -import io.libp2p.mux.MuxFrame.Flag.OPEN -import io.libp2p.mux.MuxFrame.Flag.RESET - -/** - * Contains all the permissible values for flags in the mplex protocol. - */ -object MplexFlags { - const val NewStream = 0 - const val MessageReceiver = 1 - const val MessageInitiator = 2 - const val CloseReceiver = 3 - const val CloseInitiator = 4 - const val ResetReceiver = 5 - const val ResetInitiator = 6 - - fun isInitiator(mplexFlag: Int) = mplexFlag % 2 == 0 - - fun toAbstractFlag(mplexFlag: Int): MuxFrame.Flag = - when (mplexFlag) { - NewStream -> OPEN - MessageReceiver, MessageInitiator -> DATA - CloseReceiver, CloseInitiator -> CLOSE - ResetReceiver, ResetInitiator -> RESET - else -> throw Libp2pException("Unknown mplex flag: $mplexFlag") - } - - fun toMplexFlag(abstractFlag: MuxFrame.Flag, initiator: Boolean): Int = - when (abstractFlag) { - OPEN -> NewStream - DATA -> if (initiator) MessageInitiator else MessageReceiver - CLOSE -> if (initiator) CloseInitiator else CloseReceiver - RESET -> if (initiator) ResetInitiator else ResetReceiver - } -} diff --git a/libp2p/src/main/kotlin/io/libp2p/mux/mplex/MplexFrame.kt b/libp2p/src/main/kotlin/io/libp2p/mux/mplex/MplexFrame.kt index b38b52f0c..13868402d 100644 --- a/libp2p/src/main/kotlin/io/libp2p/mux/mplex/MplexFrame.kt +++ b/libp2p/src/main/kotlin/io/libp2p/mux/mplex/MplexFrame.kt @@ -12,11 +12,9 @@ */ package io.libp2p.mux.mplex -import io.libp2p.etc.types.toByteArray -import io.libp2p.etc.types.toHex import io.libp2p.etc.util.netty.mux.MuxId -import io.libp2p.mux.MuxFrame import io.netty.buffer.ByteBuf +import io.netty.buffer.Unpooled /** * Contains the fields that comprise an mplex frame. @@ -26,11 +24,16 @@ import io.netty.buffer.ByteBuf * @param data the data segment. * @see [mplex documentation](https://github.com/libp2p/specs/tree/master/mplex#opening-a-new-stream) */ -class MplexFrame(channelId: MuxId, val mplexFlag: Int, data: ByteBuf? = null) : - MuxFrame(channelId, MplexFlags.toAbstractFlag(mplexFlag), data) { +data class MplexFrame(val id: MuxId, val flag: MplexFlag, val data: ByteBuf) { - override fun toString(): String { - val init = if (MplexFlags.isInitiator(mplexFlag)) "init" else "resp" - return "MplexFrame(id=$id, flag=$flag ($init), data=${data?.toByteArray()?.toHex()})" + companion object { + fun createDataFrame(id: MuxId, data: ByteBuf) = + MplexFrame(id, MplexFlag.getByType(MplexFlag.Type.DATA, id.initiator), data) + fun createOpenFrame(id: MuxId) = + MplexFrame(id, MplexFlag.getByType(MplexFlag.Type.OPEN, id.initiator), Unpooled.EMPTY_BUFFER) + fun createCloseFrame(id: MuxId) = + MplexFrame(id, MplexFlag.getByType(MplexFlag.Type.CLOSE, id.initiator), Unpooled.EMPTY_BUFFER) + fun createResetFrame(id: MuxId) = + MplexFrame(id, MplexFlag.getByType(MplexFlag.Type.RESET, id.initiator), Unpooled.EMPTY_BUFFER) } } diff --git a/libp2p/src/main/kotlin/io/libp2p/mux/mplex/MplexFrameCodec.kt b/libp2p/src/main/kotlin/io/libp2p/mux/mplex/MplexFrameCodec.kt index a31658e38..9abe21ed8 100644 --- a/libp2p/src/main/kotlin/io/libp2p/mux/mplex/MplexFrameCodec.kt +++ b/libp2p/src/main/kotlin/io/libp2p/mux/mplex/MplexFrameCodec.kt @@ -16,9 +16,7 @@ import io.libp2p.core.ProtocolViolationException import io.libp2p.etc.types.readUvarint import io.libp2p.etc.types.writeUvarint import io.libp2p.etc.util.netty.mux.MuxId -import io.libp2p.mux.MuxFrame import io.netty.buffer.ByteBuf -import io.netty.buffer.Unpooled import io.netty.channel.ChannelHandlerContext import io.netty.handler.codec.ByteToMessageCodec @@ -29,7 +27,7 @@ const val DEFAULT_MAX_MPLEX_FRAME_DATA_LENGTH = 1 shl 20 */ class MplexFrameCodec( val maxFrameDataLength: Int = DEFAULT_MAX_MPLEX_FRAME_DATA_LENGTH -) : ByteToMessageCodec() { +) : ByteToMessageCodec() { /** * Encodes the given mplex frame into bytes and writes them into the output list. @@ -38,10 +36,10 @@ class MplexFrameCodec( * @param msg the mplex frame. * @param out the list to write the bytes to. */ - override fun encode(ctx: ChannelHandlerContext, msg: MuxFrame, out: ByteBuf) { - out.writeUvarint(msg.id.id.shl(3).or(MplexFlags.toMplexFlag(msg.flag, msg.id.initiator).toLong())) - out.writeUvarint(msg.data?.readableBytes() ?: 0) - out.writeBytes(msg.data ?: Unpooled.EMPTY_BUFFER) + override fun encode(ctx: ChannelHandlerContext, msg: MplexFrame, out: ByteBuf) { + out.writeUvarint(msg.id.id.shl(3).or(msg.flag.value.toLong())) + out.writeUvarint(msg.data.readableBytes()) + out.writeBytes(msg.data) } /** @@ -76,8 +74,8 @@ class MplexFrameCodec( val streamId = header.shr(3) val data = msg.readSlice(lenData.toInt()) data.retain() // MessageToMessageCodec releases original buffer, but it needs to be relayed - val initiator = if (streamTag == MplexFlags.NewStream) false else !MplexFlags.isInitiator(streamTag) - val mplexFrame = MplexFrame(MuxId(ctx.channel().id(), streamId, initiator), streamTag, data) + val flag = MplexFlag.getByValue(streamTag) + val mplexFrame = MplexFrame(MuxId(ctx.channel().id(), streamId, !flag.isInitiator), flag, data) out.add(mplexFrame) } } diff --git a/libp2p/src/main/kotlin/io/libp2p/mux/mplex/MplexHandler.kt b/libp2p/src/main/kotlin/io/libp2p/mux/mplex/MplexHandler.kt index 0cff7ffee..f886b3247 100644 --- a/libp2p/src/main/kotlin/io/libp2p/mux/mplex/MplexHandler.kt +++ b/libp2p/src/main/kotlin/io/libp2p/mux/mplex/MplexHandler.kt @@ -3,12 +3,60 @@ package io.libp2p.mux.mplex import io.libp2p.core.StreamHandler import io.libp2p.core.multistream.MultistreamProtocol import io.libp2p.core.mux.StreamMuxer +import io.libp2p.etc.types.sliceMaxSize +import io.libp2p.etc.util.netty.mux.MuxChannel +import io.libp2p.etc.util.netty.mux.MuxId import io.libp2p.mux.MuxHandler +import io.netty.buffer.ByteBuf +import io.netty.channel.ChannelHandlerContext import java.util.concurrent.CompletableFuture +import java.util.concurrent.atomic.AtomicLong open class MplexHandler( override val multistreamProtocol: MultistreamProtocol, override val maxFrameDataLength: Int, ready: CompletableFuture?, inboundStreamHandler: StreamHandler<*> -) : MuxHandler(ready, inboundStreamHandler) +) : MuxHandler(ready, inboundStreamHandler) { + + private val idGenerator = AtomicLong(0xF) + + override fun generateNextId() = + MuxId(getChannelHandlerContext().channel().id(), idGenerator.incrementAndGet(), true) + + override fun channelRead(ctx: ChannelHandlerContext, msg: Any) { + msg as MplexFrame + when (msg.flag.type) { + MplexFlag.Type.OPEN -> onRemoteOpen(msg.id) + MplexFlag.Type.CLOSE -> onRemoteDisconnect(msg.id) + MplexFlag.Type.RESET -> onRemoteClose(msg.id) + MplexFlag.Type.DATA -> childRead(msg.id, msg.data) + } + } + + override fun onChildWrite(child: MuxChannel, data: ByteBuf) { + val ctx = getChannelHandlerContext() + data.sliceMaxSize(maxFrameDataLength) + .map { frameSliceBuf -> + MplexFrame.createDataFrame(child.id, frameSliceBuf) + }.forEach { muxFrame -> + ctx.write(muxFrame) + } + ctx.flush() + } + + override fun onLocalOpen(child: MuxChannel) { + getChannelHandlerContext().writeAndFlush(MplexFrame.createOpenFrame(child.id)) + } + + override fun onLocalDisconnect(child: MuxChannel) { + getChannelHandlerContext().writeAndFlush(MplexFrame.createCloseFrame(child.id)) + } + + override fun onLocalClose(child: MuxChannel) { + getChannelHandlerContext().writeAndFlush(MplexFrame.createResetFrame(child.id)) + } + + override fun onRemoteCreated(child: MuxChannel) { + } +} diff --git a/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxFlags.kt b/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxFlags.kt new file mode 100644 index 000000000..85499d0dd --- /dev/null +++ b/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxFlags.kt @@ -0,0 +1,11 @@ +package io.libp2p.mux.yamux + +/** + * Contains all the permissible values for flags in the yamux protocol. + */ +object YamuxFlags { + const val SYN = 1 + const val ACK = 2 + const val FIN = 4 + const val RST = 8 +} diff --git a/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxFrame.kt b/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxFrame.kt new file mode 100644 index 000000000..fefdf1aee --- /dev/null +++ b/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxFrame.kt @@ -0,0 +1,23 @@ +package io.libp2p.mux.yamux + +import io.libp2p.etc.types.toByteArray +import io.libp2p.etc.util.netty.mux.MuxId +import io.netty.buffer.ByteBuf +import io.netty.buffer.DefaultByteBufHolder +import io.netty.buffer.Unpooled + +/** + * Contains the fields that comprise a yamux frame. + * @param streamId the ID of the stream. + * @param flag the flag value for this frame. + * @param data the data segment. + */ +class YamuxFrame(val id: MuxId, val type: Int, val flags: Int, val lenData: Long, val data: ByteBuf? = null) : + DefaultByteBufHolder(data ?: Unpooled.EMPTY_BUFFER) { + + override fun toString(): String { + if (data == null) + return "YamuxFrame(id=$id, type=$type, flag=$flags)" + return "YamuxFrame(id=$id, type=$type, flag=$flags, data=${String(data.toByteArray())})" + } +} diff --git a/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxFrameCodec.kt b/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxFrameCodec.kt new file mode 100644 index 000000000..d21fb2d4f --- /dev/null +++ b/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxFrameCodec.kt @@ -0,0 +1,81 @@ +package io.libp2p.mux.yamux + +import io.libp2p.core.ProtocolViolationException +import io.libp2p.etc.util.netty.mux.MuxId +import io.netty.buffer.ByteBuf +import io.netty.buffer.Unpooled +import io.netty.channel.ChannelHandlerContext +import io.netty.handler.codec.ByteToMessageCodec + +const val DEFAULT_MAX_YAMUX_FRAME_DATA_LENGTH = 1 shl 20 + +/** + * A Netty codec implementation that converts [YamuxFrame] instances to [ByteBuf] and vice-versa. + */ +class YamuxFrameCodec( + val isInitiator: Boolean, + val maxFrameDataLength: Int = DEFAULT_MAX_YAMUX_FRAME_DATA_LENGTH +) : ByteToMessageCodec() { + + /** + * Encodes the given yamux frame into bytes and writes them into the output list. + * @see [https://github.com/hashicorp/yamux/blob/master/spec.md] + * @param ctx the context. + * @param msg the yamux frame. + * @param out the list to write the bytes to. + */ + override fun encode(ctx: ChannelHandlerContext, msg: YamuxFrame, out: ByteBuf) { + out.writeByte(0) // version + out.writeByte(msg.type) + out.writeShort(msg.flags) + out.writeInt(msg.id.id.toInt()) + out.writeInt(msg.data?.readableBytes() ?: msg.lenData.toInt()) + out.writeBytes(msg.data ?: Unpooled.EMPTY_BUFFER) + } + + /** + * Decodes the bytes in the given byte buffer and constructs a [YamuxFrame] that is written into + * the output list. + * @param ctx the context. + * @param msg the byte buffer. + * @param out the list to write the extracted frame to. + */ + override fun decode(ctx: ChannelHandlerContext, msg: ByteBuf, out: MutableList) { + while (msg.isReadable) { + if (msg.readableBytes() < 12) + return + val readerIndex = msg.readerIndex() + msg.readByte(); // version always 0 + val type = msg.readUnsignedByte() + val flags = msg.readUnsignedShort() + val streamId = msg.readUnsignedInt() + val lenData = msg.readUnsignedInt() + if (type.toInt() != YamuxType.DATA) { + val yamuxFrame = YamuxFrame(MuxId(ctx.channel().id(), streamId, isInitiator.xor(streamId.mod(2).equals(1)).not()), type.toInt(), flags, lenData) + out.add(yamuxFrame) + continue + } + if (lenData > maxFrameDataLength) { + msg.skipBytes(msg.readableBytes()) + throw ProtocolViolationException("Yamux frame is too large: $lenData") + } + if (msg.readableBytes() < lenData) { + // not enough data to read the frame content + // will wait for more ... + msg.readerIndex(readerIndex) + return + } + val data = msg.readSlice(lenData.toInt()) + data.retain() // MessageToMessageCodec releases original buffer, but it needs to be relayed + val yamuxFrame = YamuxFrame(MuxId(ctx.channel().id(), streamId, isInitiator.xor(streamId.mod(2).equals(1)).not()), type.toInt(), flags, lenData, data) + out.add(yamuxFrame) + } + } + + override fun exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable) { + // notify higher level handlers on the error + ctx.fireExceptionCaught(cause) + // exceptions in [decode] are very likely unrecoverable so just close the connection + ctx.close() + } +} diff --git a/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxHandler.kt b/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxHandler.kt new file mode 100644 index 000000000..645c54c78 --- /dev/null +++ b/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxHandler.kt @@ -0,0 +1,174 @@ +package io.libp2p.mux.yamux + +import io.libp2p.core.Libp2pException +import io.libp2p.core.StreamHandler +import io.libp2p.core.multistream.MultistreamProtocol +import io.libp2p.core.mux.StreamMuxer +import io.libp2p.etc.types.sliceMaxSize +import io.libp2p.etc.util.netty.mux.MuxChannel +import io.libp2p.etc.util.netty.mux.MuxId +import io.libp2p.mux.MuxHandler +import io.netty.buffer.ByteBuf +import io.netty.channel.ChannelHandlerContext +import java.util.concurrent.CompletableFuture +import java.util.concurrent.ConcurrentHashMap +import java.util.concurrent.atomic.AtomicInteger + +const val INITIAL_WINDOW_SIZE = 256 * 1024 +const val MAX_BUFFERED_CONNECTION_WRITES = 1024 * 1024 + +open class YamuxHandler( + override val multistreamProtocol: MultistreamProtocol, + override val maxFrameDataLength: Int, + ready: CompletableFuture?, + inboundStreamHandler: StreamHandler<*>, + initiator: Boolean +) : MuxHandler(ready, inboundStreamHandler) { + private val idGenerator = AtomicInteger(if (initiator) 1 else 2) // 0 is reserved + private val receiveWindows = ConcurrentHashMap() + private val sendWindows = ConcurrentHashMap() + private val sendBuffers = ConcurrentHashMap() + private val totalBufferedWrites = AtomicInteger() + + inner class SendBuffer(val ctx: ChannelHandlerContext) { + private val buffered = ArrayDeque() + + fun add(data: ByteBuf) { + buffered.add(data) + } + + fun flush(sendWindow: AtomicInteger, id: MuxId): Int { + var written = 0 + while (! buffered.isEmpty()) { + val buf = buffered.first() + if (buf.readableBytes() + written < sendWindow.get()) { + buffered.removeFirst() + sendBlocks(ctx, buf, sendWindow, id) + written += buf.readableBytes() + } else + break + } + return written + } + } + + override fun channelRead(ctx: ChannelHandlerContext, msg: Any) { + msg as YamuxFrame + when (msg.type) { + YamuxType.DATA -> handleDataRead(msg) + YamuxType.WINDOW_UPDATE -> handleWindowUpdate(msg) + YamuxType.PING -> handlePing(msg) + YamuxType.GO_AWAY -> onRemoteClose(msg.id) + } + } + + fun handlePing(msg: YamuxFrame) { + val ctx = getChannelHandlerContext() + when (msg.flags) { + YamuxFlags.SYN -> ctx.writeAndFlush(YamuxFrame(MuxId(msg.id.parentId, 0, msg.id.initiator), YamuxType.PING, YamuxFlags.ACK, msg.lenData)) + YamuxFlags.ACK -> {} + } + } + + fun handleFlags(msg: YamuxFrame) { + val ctx = getChannelHandlerContext() + if (msg.flags == YamuxFlags.SYN) { + // ACK the new stream + onRemoteOpen(msg.id) + ctx.writeAndFlush(YamuxFrame(msg.id, YamuxType.WINDOW_UPDATE, YamuxFlags.ACK, 0)) + } + if (msg.flags == YamuxFlags.FIN) + onRemoteDisconnect(msg.id) + } + + fun handleDataRead(msg: YamuxFrame) { + val ctx = getChannelHandlerContext() + val size = msg.lenData + handleFlags(msg) + if (size.toInt() == 0) + return + val recWindow = receiveWindows.get(msg.id) + if (recWindow == null) + throw Libp2pException("No receive window for " + msg.id) + val newWindow = recWindow.addAndGet(-size.toInt()) + if (newWindow < INITIAL_WINDOW_SIZE / 2) { + val delta = INITIAL_WINDOW_SIZE / 2 + recWindow.addAndGet(delta) + ctx.write(YamuxFrame(msg.id, YamuxType.WINDOW_UPDATE, 0, delta.toLong())) + ctx.flush() + } + childRead(msg.id, msg.data!!) + } + + fun handleWindowUpdate(msg: YamuxFrame) { + handleFlags(msg) + val size = msg.lenData.toInt() + val sendWindow = sendWindows.get(msg.id) + if (sendWindow == null) + throw Libp2pException("No send window for " + msg.id) + sendWindow.addAndGet(size) + val buffer = sendBuffers.get(msg.id) + if (buffer != null) { + val writtenBytes = buffer.flush(sendWindow, msg.id) + totalBufferedWrites.addAndGet(-writtenBytes) + } + } + + override fun onChildWrite(child: MuxChannel, data: ByteBuf) { + val ctx = getChannelHandlerContext() + + val sendWindow = sendWindows.get(child.id) + if (sendWindow == null) + throw Libp2pException("No send window for " + child.id) + if (sendWindow.get() <= 0) { + // wait until the window is increased to send more data + val buffer = sendBuffers.getOrPut(child.id, { SendBuffer(ctx) }) + buffer.add(data) + if (totalBufferedWrites.addAndGet(data.readableBytes()) > MAX_BUFFERED_CONNECTION_WRITES) + throw Libp2pException("Overflowed send buffer for connection") + return + } + sendBlocks(ctx, data, sendWindow, child.id) + } + + fun sendBlocks(ctx: ChannelHandlerContext, data: ByteBuf, sendWindow: AtomicInteger, id: MuxId) { + data.sliceMaxSize(minOf(maxFrameDataLength, sendWindow.get())) + .map { frameSliceBuf -> + sendWindow.addAndGet(-frameSliceBuf.readableBytes()) + YamuxFrame(id, YamuxType.DATA, 0, frameSliceBuf.readableBytes().toLong(), frameSliceBuf) + }.forEach { muxFrame -> + ctx.write(muxFrame) + } + ctx.flush() + } + + override fun onLocalOpen(child: MuxChannel) { + getChannelHandlerContext().writeAndFlush(YamuxFrame(child.id, YamuxType.DATA, YamuxFlags.SYN, 0)) + receiveWindows.put(child.id, AtomicInteger(INITIAL_WINDOW_SIZE)) + sendWindows.put(child.id, AtomicInteger(INITIAL_WINDOW_SIZE)) + } + + override fun onLocalDisconnect(child: MuxChannel) { + sendWindows.remove(child.id) + receiveWindows.remove(child.id) + sendBuffers.remove(child.id) + getChannelHandlerContext().writeAndFlush(YamuxFrame(child.id, YamuxType.DATA, YamuxFlags.FIN, 0)) + } + + override fun onLocalClose(child: MuxChannel) { + getChannelHandlerContext().writeAndFlush(YamuxFrame(child.id, YamuxType.DATA, YamuxFlags.RST, 0)) + val sendWindow = sendWindows.remove(child.id) + val buffered = sendBuffers.remove(child.id) + if (buffered != null && sendWindow != null) { + buffered.flush(sendWindow, child.id) + } + } + + override fun onRemoteCreated(child: MuxChannel) { + receiveWindows.put(child.id, AtomicInteger(INITIAL_WINDOW_SIZE)) + sendWindows.put(child.id, AtomicInteger(INITIAL_WINDOW_SIZE)) + } + + override fun generateNextId() = + MuxId(getChannelHandlerContext().channel().id(), idGenerator.addAndGet(2).toLong(), true) +} diff --git a/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxStreamMuxer.kt b/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxStreamMuxer.kt new file mode 100644 index 000000000..4b43a0597 --- /dev/null +++ b/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxStreamMuxer.kt @@ -0,0 +1,39 @@ +package io.libp2p.mux.yamux + +import io.libp2p.core.ChannelVisitor +import io.libp2p.core.Connection +import io.libp2p.core.P2PChannel +import io.libp2p.core.StreamHandler +import io.libp2p.core.multistream.MultistreamProtocol +import io.libp2p.core.multistream.ProtocolDescriptor +import io.libp2p.core.mux.StreamMuxer +import io.libp2p.core.mux.StreamMuxerDebug +import java.util.concurrent.CompletableFuture + +class YamuxStreamMuxer( + val inboundStreamHandler: StreamHandler<*>, + private val multistreamProtocol: MultistreamProtocol +) : StreamMuxer, StreamMuxerDebug { + + override val protocolDescriptor = ProtocolDescriptor("/yamux/1.0.0") + override var muxFramesDebugHandler: ChannelVisitor? = null + + override fun initChannel(ch: P2PChannel, selectedProtocol: String): CompletableFuture { + val muxSessionReady = CompletableFuture() + + val yamuxFrameCodec = YamuxFrameCodec(ch.isInitiator) + ch.pushHandler(yamuxFrameCodec) + muxFramesDebugHandler?.also { it.visit(ch as Connection) } + ch.pushHandler( + YamuxHandler( + multistreamProtocol, + yamuxFrameCodec.maxFrameDataLength, + muxSessionReady, + inboundStreamHandler, + ch.isInitiator + ) + ) + + return muxSessionReady + } +} diff --git a/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxType.kt b/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxType.kt new file mode 100644 index 000000000..cf66f4b8b --- /dev/null +++ b/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxType.kt @@ -0,0 +1,11 @@ +package io.libp2p.mux.yamux + +/** + * Contains all the permissible values for flags in the yamux protocol. + */ +object YamuxType { + const val DATA = 0 + const val WINDOW_UPDATE = 1 + const val PING = 2 + const val GO_AWAY = 3 +} diff --git a/libp2p/src/test/kotlin/io/libp2p/core/HostTest.kt b/libp2p/src/test/kotlin/io/libp2p/core/HostTest.kt index f0ad484d9..eb4513f99 100644 --- a/libp2p/src/test/kotlin/io/libp2p/core/HostTest.kt +++ b/libp2p/src/test/kotlin/io/libp2p/core/HostTest.kt @@ -4,7 +4,7 @@ import io.libp2p.core.multistream.ProtocolMatcher import io.libp2p.etc.PROTOCOL import io.libp2p.etc.types.seconds import io.libp2p.etc.types.toByteArray -import io.libp2p.mux.MuxFrame +import io.libp2p.mux.mplex.MplexFrame import io.libp2p.protocol.Ping import io.libp2p.protocol.PingBinding import io.libp2p.protocol.PingProtocol @@ -131,7 +131,7 @@ class HostTest { val afterSecureTestHandler1 = TestByteBufChannelHandler("1-afterSecure") val preStreamTestHandler1 = TestByteBufChannelHandler("1-preStream") val streamTestHandler1 = TestByteBufChannelHandler("1-stream") - val muxFrameTestHandler1 = TestChannelHandler("1-mux") + val muxFrameTestHandler1 = TestChannelHandler("1-mux") hostFactory.hostBuilderModifier = { debug { @@ -148,7 +148,7 @@ class HostTest { val afterSecureTestHandler2 = TestByteBufChannelHandler("2-afterSecure") val preStreamTestHandler2 = TestByteBufChannelHandler("2-preStream") val streamTestHandler2 = TestByteBufChannelHandler("2-stream") - val muxFrameTestHandler2 = TestChannelHandler("2-mux") + val muxFrameTestHandler2 = TestChannelHandler("2-mux") hostFactory.hostBuilderModifier = { debug { diff --git a/libp2p/src/test/kotlin/io/libp2p/mux/MultiplexHandlerTest.kt b/libp2p/src/test/kotlin/io/libp2p/mux/MuxHandlerAbstractTest.kt similarity index 87% rename from libp2p/src/test/kotlin/io/libp2p/mux/MultiplexHandlerTest.kt rename to libp2p/src/test/kotlin/io/libp2p/mux/MuxHandlerAbstractTest.kt index 5a8e18013..b4ff22a37 100644 --- a/libp2p/src/test/kotlin/io/libp2p/mux/MultiplexHandlerTest.kt +++ b/libp2p/src/test/kotlin/io/libp2p/mux/MuxHandlerAbstractTest.kt @@ -4,19 +4,10 @@ import io.libp2p.core.ConnectionClosedException import io.libp2p.core.Libp2pException import io.libp2p.core.Stream import io.libp2p.core.StreamHandler -import io.libp2p.core.multistream.MultistreamProtocolV1 -import io.libp2p.etc.types.fromHex import io.libp2p.etc.types.getX import io.libp2p.etc.types.toByteArray -import io.libp2p.etc.types.toByteBuf import io.libp2p.etc.types.toHex -import io.libp2p.etc.util.netty.mux.MuxId import io.libp2p.etc.util.netty.nettyInitializer -import io.libp2p.mux.MuxFrame.Flag.DATA -import io.libp2p.mux.MuxFrame.Flag.OPEN -import io.libp2p.mux.MuxFrame.Flag.RESET -import io.libp2p.mux.mplex.DEFAULT_MAX_MPLEX_FRAME_DATA_LENGTH -import io.libp2p.mux.mplex.MplexHandler import io.libp2p.tools.TestChannel import io.netty.buffer.ByteBuf import io.netty.channel.ChannelHandler @@ -37,12 +28,14 @@ import java.util.concurrent.CompletableFuture /** * Created by Anton Nashatyrev on 09.07.2019. */ -class MultiplexHandlerTest { +abstract class MuxHandlerAbstractTest { val dummyParentChannelId = DefaultChannelId.newInstance() val childHandlers = mutableListOf() lateinit var multistreamHandler: MuxHandler lateinit var ech: TestChannel + abstract fun createMuxHandler(streamHandler: StreamHandler): MuxHandler + @BeforeEach fun startMultiplexor() { childHandlers.clear() @@ -54,19 +47,29 @@ class MultiplexHandlerTest { childHandlers += handler } ) - multistreamHandler = object : MplexHandler( - MultistreamProtocolV1, DEFAULT_MAX_MPLEX_FRAME_DATA_LENGTH, null, streamHandler - ) { - // MuxHandler consumes the exception. Override this behaviour for testing - @Deprecated("Deprecated in Java") - override fun exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable) { - ctx.fireExceptionCaught(cause) - } - } + multistreamHandler = createMuxHandler(streamHandler) ech = TestChannel("test", true, LoggingHandler(LogLevel.ERROR), multistreamHandler) } + abstract fun openStream(id: Long): Boolean + abstract fun writeStream(id: Long, msg: String): Boolean + abstract fun resetStream(id: Long): Boolean + + fun createStreamHandler(channelInitializer: ChannelHandler) = object : StreamHandler { + override fun handleStream(stream: Stream): CompletableFuture { + stream.pushHandler(channelInitializer) + return CompletableFuture.completedFuture(Unit) + } + } + + fun assertHandlerCount(count: Int) = assertEquals(count, childHandlers.size) + fun assertLastMessage(handler: Int, msgCount: Int, msg: String) { + val messages = childHandlers[handler].inboundMessages + assertEquals(msgCount, messages.size) + assertEquals(msg, messages.last()) + } + @Test fun singleStream() { openStream(12) @@ -238,26 +241,6 @@ class MultiplexHandlerTest { assertThrows(ConnectionClosedException::class.java) { staleStream.stream.getX(3.0) } } - fun assertHandlerCount(count: Int) = assertEquals(count, childHandlers.size) - fun assertLastMessage(handler: Int, msgCount: Int, msg: String) { - val messages = childHandlers[handler].inboundMessages - assertEquals(msgCount, messages.size) - assertEquals(msg, messages.last()) - } - - fun openStream(id: Long) = writeFrame(id, OPEN) - fun writeStream(id: Long, msg: String) = writeFrame(id, DATA, msg.fromHex().toByteBuf()) - fun resetStream(id: Long) = writeFrame(id, RESET) - fun writeFrame(id: Long, flag: MuxFrame.Flag, data: ByteBuf? = null) = - ech.writeInbound(MuxFrame(MuxId(dummyParentChannelId, id, true), flag, data)) - - fun createStreamHandler(channelInitializer: ChannelHandler) = object : StreamHandler { - override fun handleStream(stream: Stream): CompletableFuture { - stream.pushHandler(channelInitializer) - return CompletableFuture.completedFuture(Unit) - } - } - class TestHandler : ChannelInboundHandlerAdapter() { val inboundMessages = mutableListOf() var ctx: ChannelHandlerContext? = null diff --git a/libp2p/src/test/kotlin/io/libp2p/mux/mplex/MplexFrameCodecTest.kt b/libp2p/src/test/kotlin/io/libp2p/mux/mplex/MplexFrameCodecTest.kt index 031350c55..8139f61df 100644 --- a/libp2p/src/test/kotlin/io/libp2p/mux/mplex/MplexFrameCodecTest.kt +++ b/libp2p/src/test/kotlin/io/libp2p/mux/mplex/MplexFrameCodecTest.kt @@ -36,7 +36,7 @@ class MplexFrameCodecTest { val channelLarge = EmbeddedChannel(MplexFrameCodec(maxFrameDataLength = 1024)) val mplexFrame = MplexFrame( - MuxId(dummyId, 777, true), MplexFlags.MessageInitiator, + MuxId(dummyId, 777, true), MplexFlag.MessageInitiator, ByteArray(1024).toByteBuf() ) @@ -61,9 +61,9 @@ class MplexFrameCodecTest { val channel = EmbeddedChannel(MplexFrameCodec()) val mplexFrames = arrayOf( - MplexFrame(MuxId(dummyId, 777, true), MplexFlags.MessageInitiator, "Hello-1".toByteArray().toByteBuf()), - MplexFrame(MuxId(dummyId, 888, true), MplexFlags.MessageInitiator, "Hello-2".toByteArray().toByteBuf()), - MplexFrame(MuxId(dummyId, 999, true), MplexFlags.MessageInitiator, "Hello-3".toByteArray().toByteBuf()) + MplexFrame(MuxId(dummyId, 777, true), MplexFlag.MessageInitiator, "Hello-1".toByteArray().toByteBuf()), + MplexFrame(MuxId(dummyId, 888, true), MplexFlag.MessageInitiator, "Hello-2".toByteArray().toByteBuf()), + MplexFrame(MuxId(dummyId, 999, true), MplexFlag.MessageInitiator, "Hello-3".toByteArray().toByteBuf()) ) assertTrue( channel.writeOutbound(*mplexFrames) @@ -86,8 +86,36 @@ class MplexFrameCodecTest { assertEquals(777, resultFrames[0].id.id) assertEquals(888, resultFrames[1].id.id) assertEquals(999, resultFrames[2].id.id) - assertEquals("Hello-1", resultFrames[0].data!!.toByteArray().toString(UTF_8)) - assertEquals("Hello-2", resultFrames[1].data!!.toByteArray().toString(UTF_8)) - assertEquals("Hello-3", resultFrames[2].data!!.toByteArray().toString(UTF_8)) + assertEquals("Hello-1", resultFrames[0].data.toByteArray().toString(UTF_8)) + assertEquals("Hello-2", resultFrames[1].data.toByteArray().toString(UTF_8)) + assertEquals("Hello-3", resultFrames[2].data.toByteArray().toString(UTF_8)) + } + + @Test + fun `test id initiator is inverted on decoding`() { + val channel = EmbeddedChannel(MplexFrameCodec()) + + val mplexFrames = arrayOf( + MplexFrame.createOpenFrame(MuxId(dummyId, 1, true)), + MplexFrame.createDataFrame(MuxId(dummyId, 2, true), "Hello-2".toByteArray().toByteBuf()), + MplexFrame.createDataFrame(MuxId(dummyId, 3, false), "Hello-3".toByteArray().toByteBuf()), + MplexFrame.createCloseFrame(MuxId(dummyId, 4, true)), + MplexFrame.createCloseFrame(MuxId(dummyId, 5, false)), + MplexFrame.createResetFrame(MuxId(dummyId, 6, true)), + MplexFrame.createResetFrame(MuxId(dummyId, 7, false)), + ) + assertTrue( + channel.writeOutbound(*mplexFrames) + ) + + repeat(mplexFrames.size) { idx -> + val wireBytes = channel.readOutbound() + channel.writeInbound(wireBytes) + val resFrame = channel.readInbound() + + assertEquals(mplexFrames[idx].id.id, resFrame.id.id) + assertEquals(!mplexFrames[idx].id.initiator, resFrame.id.initiator) + assertEquals(mplexFrames[idx].flag, resFrame.flag) + } } } diff --git a/libp2p/src/test/kotlin/io/libp2p/mux/mplex/MplexHandlerTest.kt b/libp2p/src/test/kotlin/io/libp2p/mux/mplex/MplexHandlerTest.kt new file mode 100644 index 000000000..e64115a57 --- /dev/null +++ b/libp2p/src/test/kotlin/io/libp2p/mux/mplex/MplexHandlerTest.kt @@ -0,0 +1,32 @@ +package io.libp2p.mux.mplex + +import io.libp2p.core.StreamHandler +import io.libp2p.core.multistream.MultistreamProtocolV1 +import io.libp2p.etc.types.fromHex +import io.libp2p.etc.types.toByteBuf +import io.libp2p.etc.util.netty.mux.MuxId +import io.libp2p.mux.MuxHandler +import io.libp2p.mux.MuxHandlerAbstractTest +import io.netty.buffer.ByteBuf +import io.netty.buffer.Unpooled +import io.netty.channel.ChannelHandlerContext + +class MplexHandlerTest : MuxHandlerAbstractTest() { + + override fun createMuxHandler(streamHandler: StreamHandler): MuxHandler = + object : MplexHandler( + MultistreamProtocolV1, DEFAULT_MAX_MPLEX_FRAME_DATA_LENGTH, null, streamHandler + ) { + // MuxHandler consumes the exception. Override this behaviour for testing + @Deprecated("Deprecated in Java") + override fun exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable) { + ctx.fireExceptionCaught(cause) + } + } + + override fun openStream(id: Long) = writeFrame(id, MplexFlag.Type.OPEN) + override fun writeStream(id: Long, msg: String) = writeFrame(id, MplexFlag.Type.DATA, msg.fromHex().toByteBuf()) + override fun resetStream(id: Long) = writeFrame(id, MplexFlag.Type.RESET) + fun writeFrame(id: Long, flagType: MplexFlag.Type, data: ByteBuf = Unpooled.EMPTY_BUFFER) = + ech.writeInbound(MplexFrame(MuxId(dummyParentChannelId, id, true), MplexFlag.getByType(flagType, true), data)) +} diff --git a/libp2p/src/test/kotlin/io/libp2p/mux/yamux/YamuxHandlerTest.kt b/libp2p/src/test/kotlin/io/libp2p/mux/yamux/YamuxHandlerTest.kt new file mode 100644 index 000000000..69016f56b --- /dev/null +++ b/libp2p/src/test/kotlin/io/libp2p/mux/yamux/YamuxHandlerTest.kt @@ -0,0 +1,41 @@ +package io.libp2p.mux.yamux + +import io.libp2p.core.StreamHandler +import io.libp2p.core.multistream.MultistreamProtocolV1 +import io.libp2p.etc.types.fromHex +import io.libp2p.etc.types.toByteBuf +import io.libp2p.etc.util.netty.mux.MuxId +import io.libp2p.mux.MuxHandler +import io.libp2p.mux.MuxHandlerAbstractTest +import io.netty.channel.ChannelHandlerContext + +class YamuxHandlerTest : MuxHandlerAbstractTest() { + + override fun createMuxHandler(streamHandler: StreamHandler): MuxHandler = + object : YamuxHandler( + MultistreamProtocolV1, DEFAULT_MAX_YAMUX_FRAME_DATA_LENGTH, null, streamHandler, true + ) { + // MuxHandler consumes the exception. Override this behaviour for testing + @Deprecated("Deprecated in Java") + override fun exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable) { + ctx.fireExceptionCaught(cause) + } + } + + override fun openStream(id: Long) = + ech.writeInbound(YamuxFrame(MuxId(dummyParentChannelId, id, true), YamuxType.DATA, YamuxFlags.SYN, 0)) + + override fun writeStream(id: Long, msg: String) = + ech.writeInbound( + YamuxFrame( + MuxId(dummyParentChannelId, id, true), + YamuxType.DATA, + 0, + msg.fromHex().size.toLong(), + msg.fromHex().toByteBuf() + ) + ) + + override fun resetStream(id: Long) = + ech.writeInbound(YamuxFrame(MuxId(dummyParentChannelId, id, true), YamuxType.GO_AWAY, 0, 0)) +}