Skip to content

Commit

Permalink
Yamux implementation (#281)
Browse files Browse the repository at this point in the history
* Initial yamux implementation
* Implement per stream write buffers and per connection max write buffer size limit.
* Make MuxHandler abstract. Move Mplex specific members to MplexHandler. Derive YamuxHandler from it
* Refactor MuxHandler tests
* Refactor MplexFrame, remove obsolete MuxFrame
* Flush buffered writes in yamux on local disconnect
---------
Co-authored-by: Anton Nashatyrev <[email protected]>
  • Loading branch information
ianopolous committed May 24, 2023
1 parent f42740f commit 8971b31
Show file tree
Hide file tree
Showing 19 changed files with 612 additions and 185 deletions.
11 changes: 11 additions & 0 deletions libp2p/src/main/kotlin/io/libp2p/core/mux/StreamMuxerProtocol.kt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package io.libp2p.core.mux
import io.libp2p.core.multistream.MultistreamProtocol
import io.libp2p.core.multistream.ProtocolBinding
import io.libp2p.mux.mplex.MplexStreamMuxer
import io.libp2p.mux.yamux.YamuxStreamMuxer

fun interface StreamMuxerProtocol {

Expand All @@ -18,5 +19,15 @@ fun interface StreamMuxerProtocol {
multistreamProtocol
)
}

@JvmStatic
val Yamux = StreamMuxerProtocol { multistreamProtocol, protocols ->
YamuxStreamMuxer(
multistreamProtocol.createMultistream(
protocols
).toStreamHandler(),
multistreamProtocol
)
}
}
}
23 changes: 0 additions & 23 deletions libp2p/src/main/kotlin/io/libp2p/mux/MuxFrame.kt

This file was deleted.

43 changes: 0 additions & 43 deletions libp2p/src/main/kotlin/io/libp2p/mux/MuxHandler.kt
Original file line number Diff line number Diff line change
Expand Up @@ -9,22 +9,18 @@ import io.libp2p.core.mux.StreamMuxer
import io.libp2p.etc.CONNECTION
import io.libp2p.etc.STREAM
import io.libp2p.etc.types.forward
import io.libp2p.etc.types.sliceMaxSize
import io.libp2p.etc.util.netty.mux.AbstractMuxHandler
import io.libp2p.etc.util.netty.mux.MuxChannel
import io.libp2p.etc.util.netty.mux.MuxChannelInitializer
import io.libp2p.etc.util.netty.mux.MuxId
import io.libp2p.transport.implementation.StreamOverNetty
import io.netty.buffer.ByteBuf
import io.netty.channel.ChannelHandlerContext
import java.util.concurrent.CompletableFuture
import java.util.concurrent.atomic.AtomicLong

abstract class MuxHandler(
private val ready: CompletableFuture<StreamMuxer.Session>?,
inboundStreamHandler: StreamHandler<*>
) : AbstractMuxHandler<ByteBuf>(), StreamMuxer.Session {
private val idGenerator = AtomicLong(0xF)

protected abstract val multistreamProtocol: MultistreamProtocol
protected abstract val maxFrameDataLength: Int
Expand All @@ -38,45 +34,6 @@ abstract class MuxHandler(
ready?.complete(this)
}

override fun channelRead(ctx: ChannelHandlerContext, msg: Any) {
msg as MuxFrame
when (msg.flag) {
MuxFrame.Flag.OPEN -> onRemoteOpen(msg.id)
MuxFrame.Flag.CLOSE -> onRemoteDisconnect(msg.id)
MuxFrame.Flag.RESET -> onRemoteClose(msg.id)
MuxFrame.Flag.DATA -> childRead(msg.id, msg.data!!)
}
}

override fun onChildWrite(child: MuxChannel<ByteBuf>, data: ByteBuf) {
val ctx = getChannelHandlerContext()
data.sliceMaxSize(maxFrameDataLength)
.map { frameSliceBuf ->
MuxFrame(child.id, MuxFrame.Flag.DATA, frameSliceBuf)
}.forEach { muxFrame ->
ctx.write(muxFrame)
}
ctx.flush()
}

override fun onLocalOpen(child: MuxChannel<ByteBuf>) {
getChannelHandlerContext().writeAndFlush(MuxFrame(child.id, MuxFrame.Flag.OPEN))
}

override fun onLocalDisconnect(child: MuxChannel<ByteBuf>) {
getChannelHandlerContext().writeAndFlush(MuxFrame(child.id, MuxFrame.Flag.CLOSE))
}

override fun onLocalClose(child: MuxChannel<ByteBuf>) {
getChannelHandlerContext().writeAndFlush(MuxFrame(child.id, MuxFrame.Flag.RESET))
}

override fun onRemoteCreated(child: MuxChannel<ByteBuf>) {
}

override fun generateNextId() =
MuxId(getChannelHandlerContext().channel().id(), idGenerator.incrementAndGet(), true)

private fun createStream(channel: MuxChannel<ByteBuf>): Stream {
val connection = ctx!!.channel().attr(CONNECTION).get()
val stream = StreamOverNetty(channel, connection, channel.initiator)
Expand Down
62 changes: 62 additions & 0 deletions libp2p/src/main/kotlin/io/libp2p/mux/mplex/MplexFlag.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
/*
* Copyright 2019 BLK Technologies Limited (web3labs.com).
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
* an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
* specific language governing permissions and limitations under the License.
*/
package io.libp2p.mux.mplex

import io.libp2p.mux.mplex.MplexFlag.Type.*

/**
* Contains all the permissible values for flags in the <code>mplex</code> protocol.
*/
enum class MplexFlag(
val value: Int,
val type: Type
) {
NewStream(0, OPEN),
MessageReceiver(1, DATA),
MessageInitiator(2, DATA),
CloseReceiver(3, CLOSE),
CloseInitiator(4, CLOSE),
ResetReceiver(5, RESET),
ResetInitiator(6, RESET);

enum class Type {
OPEN,
DATA,
CLOSE,
RESET
}

val isInitiator get() = value % 2 == 0

private val initiatorString get() = when (isInitiator) {
true -> "init"
false -> "resp"
}

override fun toString(): String = "$type($initiatorString)"

companion object {
private val valueToFlag = MplexFlag.values().associateBy { it.value }

fun getByValue(flagValue: Int): MplexFlag =
valueToFlag[flagValue] ?: throw IllegalArgumentException("Invalid Mplex stream tag: $flagValue")

fun getByType(type: Type, initiator: Boolean): MplexFlag =
when (type) {
OPEN -> NewStream
DATA -> if (initiator) MessageInitiator else MessageReceiver
CLOSE -> if (initiator) CloseInitiator else CloseReceiver
RESET -> if (initiator) ResetInitiator else ResetReceiver
}
}
}
52 changes: 0 additions & 52 deletions libp2p/src/main/kotlin/io/libp2p/mux/mplex/MplexFlags.kt

This file was deleted.

19 changes: 11 additions & 8 deletions libp2p/src/main/kotlin/io/libp2p/mux/mplex/MplexFrame.kt
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,9 @@
*/
package io.libp2p.mux.mplex

import io.libp2p.etc.types.toByteArray
import io.libp2p.etc.types.toHex
import io.libp2p.etc.util.netty.mux.MuxId
import io.libp2p.mux.MuxFrame
import io.netty.buffer.ByteBuf
import io.netty.buffer.Unpooled

/**
* Contains the fields that comprise an mplex frame.
Expand All @@ -26,11 +24,16 @@ import io.netty.buffer.ByteBuf
* @param data the data segment.
* @see [mplex documentation](https://github.com/libp2p/specs/tree/master/mplex#opening-a-new-stream)
*/
class MplexFrame(channelId: MuxId, val mplexFlag: Int, data: ByteBuf? = null) :
MuxFrame(channelId, MplexFlags.toAbstractFlag(mplexFlag), data) {
data class MplexFrame(val id: MuxId, val flag: MplexFlag, val data: ByteBuf) {

override fun toString(): String {
val init = if (MplexFlags.isInitiator(mplexFlag)) "init" else "resp"
return "MplexFrame(id=$id, flag=$flag ($init), data=${data?.toByteArray()?.toHex()})"
companion object {
fun createDataFrame(id: MuxId, data: ByteBuf) =
MplexFrame(id, MplexFlag.getByType(MplexFlag.Type.DATA, id.initiator), data)
fun createOpenFrame(id: MuxId) =
MplexFrame(id, MplexFlag.getByType(MplexFlag.Type.OPEN, id.initiator), Unpooled.EMPTY_BUFFER)
fun createCloseFrame(id: MuxId) =
MplexFrame(id, MplexFlag.getByType(MplexFlag.Type.CLOSE, id.initiator), Unpooled.EMPTY_BUFFER)
fun createResetFrame(id: MuxId) =
MplexFrame(id, MplexFlag.getByType(MplexFlag.Type.RESET, id.initiator), Unpooled.EMPTY_BUFFER)
}
}
16 changes: 7 additions & 9 deletions libp2p/src/main/kotlin/io/libp2p/mux/mplex/MplexFrameCodec.kt
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,7 @@ import io.libp2p.core.ProtocolViolationException
import io.libp2p.etc.types.readUvarint
import io.libp2p.etc.types.writeUvarint
import io.libp2p.etc.util.netty.mux.MuxId
import io.libp2p.mux.MuxFrame
import io.netty.buffer.ByteBuf
import io.netty.buffer.Unpooled
import io.netty.channel.ChannelHandlerContext
import io.netty.handler.codec.ByteToMessageCodec

Expand All @@ -29,7 +27,7 @@ const val DEFAULT_MAX_MPLEX_FRAME_DATA_LENGTH = 1 shl 20
*/
class MplexFrameCodec(
val maxFrameDataLength: Int = DEFAULT_MAX_MPLEX_FRAME_DATA_LENGTH
) : ByteToMessageCodec<MuxFrame>() {
) : ByteToMessageCodec<MplexFrame>() {

/**
* Encodes the given mplex frame into bytes and writes them into the output list.
Expand All @@ -38,10 +36,10 @@ class MplexFrameCodec(
* @param msg the mplex frame.
* @param out the list to write the bytes to.
*/
override fun encode(ctx: ChannelHandlerContext, msg: MuxFrame, out: ByteBuf) {
out.writeUvarint(msg.id.id.shl(3).or(MplexFlags.toMplexFlag(msg.flag, msg.id.initiator).toLong()))
out.writeUvarint(msg.data?.readableBytes() ?: 0)
out.writeBytes(msg.data ?: Unpooled.EMPTY_BUFFER)
override fun encode(ctx: ChannelHandlerContext, msg: MplexFrame, out: ByteBuf) {
out.writeUvarint(msg.id.id.shl(3).or(msg.flag.value.toLong()))
out.writeUvarint(msg.data.readableBytes())
out.writeBytes(msg.data)
}

/**
Expand Down Expand Up @@ -76,8 +74,8 @@ class MplexFrameCodec(
val streamId = header.shr(3)
val data = msg.readSlice(lenData.toInt())
data.retain() // MessageToMessageCodec releases original buffer, but it needs to be relayed
val initiator = if (streamTag == MplexFlags.NewStream) false else !MplexFlags.isInitiator(streamTag)
val mplexFrame = MplexFrame(MuxId(ctx.channel().id(), streamId, initiator), streamTag, data)
val flag = MplexFlag.getByValue(streamTag)
val mplexFrame = MplexFrame(MuxId(ctx.channel().id(), streamId, !flag.isInitiator), flag, data)
out.add(mplexFrame)
}
}
Expand Down
50 changes: 49 additions & 1 deletion libp2p/src/main/kotlin/io/libp2p/mux/mplex/MplexHandler.kt
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,60 @@ package io.libp2p.mux.mplex
import io.libp2p.core.StreamHandler
import io.libp2p.core.multistream.MultistreamProtocol
import io.libp2p.core.mux.StreamMuxer
import io.libp2p.etc.types.sliceMaxSize
import io.libp2p.etc.util.netty.mux.MuxChannel
import io.libp2p.etc.util.netty.mux.MuxId
import io.libp2p.mux.MuxHandler
import io.netty.buffer.ByteBuf
import io.netty.channel.ChannelHandlerContext
import java.util.concurrent.CompletableFuture
import java.util.concurrent.atomic.AtomicLong

open class MplexHandler(
override val multistreamProtocol: MultistreamProtocol,
override val maxFrameDataLength: Int,
ready: CompletableFuture<StreamMuxer.Session>?,
inboundStreamHandler: StreamHandler<*>
) : MuxHandler(ready, inboundStreamHandler)
) : MuxHandler(ready, inboundStreamHandler) {

private val idGenerator = AtomicLong(0xF)

override fun generateNextId() =
MuxId(getChannelHandlerContext().channel().id(), idGenerator.incrementAndGet(), true)

override fun channelRead(ctx: ChannelHandlerContext, msg: Any) {
msg as MplexFrame
when (msg.flag.type) {
MplexFlag.Type.OPEN -> onRemoteOpen(msg.id)
MplexFlag.Type.CLOSE -> onRemoteDisconnect(msg.id)
MplexFlag.Type.RESET -> onRemoteClose(msg.id)
MplexFlag.Type.DATA -> childRead(msg.id, msg.data)
}
}

override fun onChildWrite(child: MuxChannel<ByteBuf>, data: ByteBuf) {
val ctx = getChannelHandlerContext()
data.sliceMaxSize(maxFrameDataLength)
.map { frameSliceBuf ->
MplexFrame.createDataFrame(child.id, frameSliceBuf)
}.forEach { muxFrame ->
ctx.write(muxFrame)
}
ctx.flush()
}

override fun onLocalOpen(child: MuxChannel<ByteBuf>) {
getChannelHandlerContext().writeAndFlush(MplexFrame.createOpenFrame(child.id))
}

override fun onLocalDisconnect(child: MuxChannel<ByteBuf>) {
getChannelHandlerContext().writeAndFlush(MplexFrame.createCloseFrame(child.id))
}

override fun onLocalClose(child: MuxChannel<ByteBuf>) {
getChannelHandlerContext().writeAndFlush(MplexFrame.createResetFrame(child.id))
}

override fun onRemoteCreated(child: MuxChannel<ByteBuf>) {
}
}
11 changes: 11 additions & 0 deletions libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxFlags.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package io.libp2p.mux.yamux

/**
* Contains all the permissible values for flags in the <code>yamux</code> protocol.
*/
object YamuxFlags {
const val SYN = 1
const val ACK = 2
const val FIN = 4
const val RST = 8
}
Loading

0 comments on commit 8971b31

Please sign in to comment.