Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Yamux implementation #281

Merged
merged 26 commits into from
May 24, 2023
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
2b642cd
First attempt at yamux implementation
ianopolous Feb 10, 2023
5c76004
Fix deadlock in yamux window update
ianopolous Feb 10, 2023
db965af
Fix condition to send yamux window updates
ianopolous Feb 10, 2023
100e163
Remove debug
ianopolous Feb 10, 2023
395e5a6
Add yamux test. Fix new stream bug in yamux
ianopolous Feb 12, 2023
d9c078b
Fix yamux bug opening reverse stream on existing connection!
ianopolous Apr 18, 2023
8071148
move yamux files to new structure
ianopolous May 18, 2023
e389c6d
linting
ianopolous May 18, 2023
5c94d4d
add deprecated annotation in yamux test
ianopolous May 18, 2023
1df8489
Make sure there are enough bytes to read yamux header
ianopolous May 19, 2023
bb29ac7
Read unsigned ints in yamux decoder for length and stream id
ianopolous May 19, 2023
4426d71
Track yamux windows per stream
ianopolous May 19, 2023
af74486
Flush yamux acks
ianopolous May 19, 2023
d75ac36
Use Libp2pExeption in yamux for missing stream
ianopolous May 19, 2023
586b1d8
Allow decoding yamux frames in buffer after a non data frame
ianopolous May 19, 2023
ace7a18
Implement per stream write buffers and
ianopolous May 23, 2023
2f3691e
Make MuxHandler abstract. Move Mplex specific members to MplexHandler…
Nashatyrev May 23, 2023
070bcf1
Refactor MuxHandler tests
Nashatyrev May 23, 2023
0c9cf42
Refactor MplexFrame, remove obsolete MuxFrame
Nashatyrev May 23, 2023
6cf1f09
Flush buffered writes in yamux on local disconnect
ianopolous May 23, 2023
c927bc2
Fix tests after refactor
Nashatyrev May 23, 2023
8479b8e
Formatting
Nashatyrev May 23, 2023
2614d4f
Fix regression
Nashatyrev May 23, 2023
754f52d
Merge remote-tracking branch 'Peergos/upstream-yamux' into refactor-mux
Nashatyrev May 23, 2023
87cd661
Merge pull request #7 from Nashatyrev/refactor-mux
ianopolous May 23, 2023
8663c94
linting
ianopolous May 23, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions libp2p/src/main/kotlin/io/libp2p/core/mux/StreamMuxerProtocol.kt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package io.libp2p.core.mux
import io.libp2p.core.multistream.MultistreamProtocol
import io.libp2p.core.multistream.ProtocolBinding
import io.libp2p.mux.mplex.MplexStreamMuxer
import io.libp2p.mux.yamux.YamuxStreamMuxer

fun interface StreamMuxerProtocol {

Expand All @@ -18,5 +19,15 @@ fun interface StreamMuxerProtocol {
multistreamProtocol
)
}

@JvmStatic
val Yamux = StreamMuxerProtocol { multistreamProtocol, protocols ->
YamuxStreamMuxer(
multistreamProtocol.createMultistream(
protocols
).toStreamHandler(),
multistreamProtocol
)
}
}
}
11 changes: 11 additions & 0 deletions libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxFlags.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package io.libp2p.mux.yamux

/**
* Contains all the permissible values for flags in the <code>yamux</code> protocol.
*/
object YamuxFlags {
const val SYN = 1
const val ACK = 2
const val FIN = 4
const val RST = 8
}
23 changes: 23 additions & 0 deletions libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxFrame.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package io.libp2p.mux.yamux

import io.libp2p.etc.types.toByteArray
import io.libp2p.etc.util.netty.mux.MuxId
import io.netty.buffer.ByteBuf
import io.netty.buffer.DefaultByteBufHolder
import io.netty.buffer.Unpooled

/**
* Contains the fields that comprise a yamux frame.
* @param streamId the ID of the stream.
* @param flag the flag value for this frame.
* @param data the data segment.
*/
class YamuxFrame(val id: MuxId, val type: Int, val flags: Int, val lenData: Int, val data: ByteBuf? = null) :
DefaultByteBufHolder(data ?: Unpooled.EMPTY_BUFFER) {

override fun toString(): String {
if (data == null)
return "YamuxFrame(id=$id, type=$type, flag=$flags)"
return "YamuxFrame(id=$id, type=$type, flag=$flags, data=${String(data.toByteArray())})"
}
}
85 changes: 85 additions & 0 deletions libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxFrameCodec.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
package io.libp2p.mux.yamux

import io.libp2p.core.ProtocolViolationException
import io.libp2p.etc.util.netty.mux.MuxId
import io.netty.buffer.ByteBuf
import io.netty.buffer.Unpooled
import io.netty.channel.ChannelHandlerContext
import io.netty.handler.codec.ByteToMessageCodec

const val DEFAULT_MAX_YAMUX_FRAME_DATA_LENGTH = 1 shl 20

/**
* A Netty codec implementation that converts [YamuxFrame] instances to [ByteBuf] and vice-versa.
*/
class YamuxFrameCodec(
val isInitiator: Boolean,
val maxFrameDataLength: Int = DEFAULT_MAX_YAMUX_FRAME_DATA_LENGTH
) : ByteToMessageCodec<YamuxFrame>() {

/**
* Encodes the given yamux frame into bytes and writes them into the output list.
* @see [https://github.com/hashicorp/yamux/blob/master/spec.md]
* @param ctx the context.
* @param msg the yamux frame.
* @param out the list to write the bytes to.
*/
override fun encode(ctx: ChannelHandlerContext, msg: YamuxFrame, out: ByteBuf) {
out.writeByte(0) // version
out.writeByte(msg.type)
out.writeShort(msg.flags)
out.writeInt(msg.id.id.toInt())
out.writeInt(msg.data?.readableBytes() ?: msg.lenData)
out.writeBytes(msg.data ?: Unpooled.EMPTY_BUFFER)
}

/**
* Decodes the bytes in the given byte buffer and constructs a [YamuxFrame] that is written into
* the output list.
* @param ctx the context.
* @param msg the byte buffer.
* @param out the list to write the extracted frame to.
*/
override fun decode(ctx: ChannelHandlerContext, msg: ByteBuf, out: MutableList<Any>) {
while (msg.isReadable) {
val readerIndex = msg.readerIndex()
ianopolous marked this conversation as resolved.
Show resolved Hide resolved
msg.readByte(); // version always 0
val type = msg.readUnsignedByte()
val flags = msg.readUnsignedShort()
val streamId = msg.readInt()
val lenData = msg.readInt()
ianopolous marked this conversation as resolved.
Show resolved Hide resolved
if (type.toInt() != YamuxType.DATA) {
val yamuxFrame = YamuxFrame(MuxId(ctx.channel().id(), streamId.toLong(), isInitiator.xor(streamId % 2 == 1).not()), type.toInt(), flags, lenData)
out.add(yamuxFrame)
return
ianopolous marked this conversation as resolved.
Show resolved Hide resolved
}
if (lenData < 0) {
ianopolous marked this conversation as resolved.
Show resolved Hide resolved
// not enough data to read the frame length
// will wait for more ...
msg.readerIndex(readerIndex)
return
}
if (lenData > maxFrameDataLength) {
msg.skipBytes(msg.readableBytes())
throw ProtocolViolationException("Yamux frame is too large: $lenData")
}
if (msg.readableBytes() < lenData) {
// not enough data to read the frame content
// will wait for more ...
msg.readerIndex(readerIndex)
return
}
val data = msg.readSlice(lenData)
data.retain() // MessageToMessageCodec releases original buffer, but it needs to be relayed
val yamuxFrame = YamuxFrame(MuxId(ctx.channel().id(), streamId.toLong(), isInitiator.xor(streamId % 2 == 1).not()), type.toInt(), flags, lenData, data)
out.add(yamuxFrame)
}
}

override fun exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable) {
// notify higher level handlers on the error
ctx.fireExceptionCaught(cause)
// exceptions in [decode] are very likely unrecoverable so just close the connection
ctx.close()
}
}
151 changes: 151 additions & 0 deletions libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxHandler.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
package io.libp2p.mux.yamux

import io.libp2p.core.Stream
import io.libp2p.core.StreamHandler
import io.libp2p.core.StreamPromise
import io.libp2p.core.multistream.MultistreamProtocol
import io.libp2p.core.multistream.ProtocolBinding
import io.libp2p.core.mux.StreamMuxer
import io.libp2p.etc.CONNECTION
import io.libp2p.etc.STREAM
import io.libp2p.etc.types.forward
import io.libp2p.etc.types.sliceMaxSize
import io.libp2p.etc.util.netty.mux.AbstractMuxHandler
import io.libp2p.etc.util.netty.mux.MuxChannel
import io.libp2p.etc.util.netty.mux.MuxChannelInitializer
import io.libp2p.etc.util.netty.mux.MuxId
import io.libp2p.transport.implementation.StreamOverNetty
import io.netty.buffer.ByteBuf
import io.netty.channel.ChannelHandlerContext
import java.util.concurrent.CompletableFuture
import java.util.concurrent.Semaphore
import java.util.concurrent.atomic.AtomicInteger

const val INITIAL_WINDOW_SIZE = 256 * 1024

open class YamuxHandler(
protected val multistreamProtocol: MultistreamProtocol,
protected val maxFrameDataLength: Int,
private val ready: CompletableFuture<StreamMuxer.Session>?,
inboundStreamHandler: StreamHandler<*>,
initiator: Boolean
) : AbstractMuxHandler<ByteBuf>(), StreamMuxer.Session {
private val idGenerator = AtomicInteger(if (initiator) 1 else 2) // 0 is reserved
private val receiveWindow = AtomicInteger(INITIAL_WINDOW_SIZE)
private val sendWindow = AtomicInteger(INITIAL_WINDOW_SIZE)
private val lock = Semaphore(1)

override val inboundInitializer: MuxChannelInitializer<ByteBuf> = {
inboundStreamHandler.handleStream(createStream(it))
}

override fun handlerAdded(ctx: ChannelHandlerContext) {
super.handlerAdded(ctx)
ready?.complete(this)
}

override fun channelRead(ctx: ChannelHandlerContext, msg: Any) {
msg as YamuxFrame
when (msg.type) {
YamuxType.DATA -> handleDataRead(msg)
YamuxType.WINDOW_UPDATE -> handleWindowUpdate(msg)
YamuxType.PING -> handlePing(msg)
YamuxType.GO_AWAY -> onRemoteClose(msg.id)
}
}

fun handlePing(msg: YamuxFrame) {
val ctx = getChannelHandlerContext()
when (msg.flags) {
YamuxFlags.SYN -> ctx.write(YamuxFrame(MuxId(msg.id.parentId, 0, msg.id.initiator), YamuxType.PING, YamuxFlags.ACK, msg.lenData))
ianopolous marked this conversation as resolved.
Show resolved Hide resolved
YamuxFlags.ACK -> {}
}
}

fun handleFlags(msg: YamuxFrame) {
val ctx = getChannelHandlerContext()
if (msg.flags == YamuxFlags.SYN) {
// ACK the new stream
onRemoteOpen(msg.id)
ctx.write(YamuxFrame(msg.id, YamuxType.WINDOW_UPDATE, YamuxFlags.ACK, 0))
ianopolous marked this conversation as resolved.
Show resolved Hide resolved
}
if (msg.flags == YamuxFlags.FIN)
onRemoteDisconnect(msg.id)
}

fun handleDataRead(msg: YamuxFrame) {
val ctx = getChannelHandlerContext()
val size = msg.lenData
handleFlags(msg)
if (size == 0)
return
val newWindow = receiveWindow.addAndGet(-size)
if (newWindow < INITIAL_WINDOW_SIZE / 2) {
val delta = INITIAL_WINDOW_SIZE / 2
receiveWindow.addAndGet(delta)
ctx.write(YamuxFrame(msg.id, YamuxType.WINDOW_UPDATE, 0, delta))
ctx.flush()
}
childRead(msg.id, msg.data!!)
}

fun handleWindowUpdate(msg: YamuxFrame) {
handleFlags(msg)
val size = msg.lenData
sendWindow.addAndGet(size)
ianopolous marked this conversation as resolved.
Show resolved Hide resolved
lock.release()
}

override fun onChildWrite(child: MuxChannel<ByteBuf>, data: ByteBuf) {
val ctx = getChannelHandlerContext()
while (sendWindow.get() <= 0) {
// wait until the window is increased
lock.acquire()
Nashatyrev marked this conversation as resolved.
Show resolved Hide resolved
}
data.sliceMaxSize(minOf(maxFrameDataLength, sendWindow.get()))
.map { frameSliceBuf ->
sendWindow.addAndGet(-frameSliceBuf.readableBytes())
YamuxFrame(child.id, YamuxType.DATA, 0, frameSliceBuf.readableBytes(), frameSliceBuf)
}.forEach { muxFrame ->
ctx.write(muxFrame)
}
ctx.flush()
}

override fun onLocalOpen(child: MuxChannel<ByteBuf>) {
getChannelHandlerContext().writeAndFlush(YamuxFrame(child.id, YamuxType.DATA, YamuxFlags.SYN, 0))
}

override fun onLocalDisconnect(child: MuxChannel<ByteBuf>) {
getChannelHandlerContext().writeAndFlush(YamuxFrame(child.id, YamuxType.DATA, YamuxFlags.FIN, 0))
}

override fun onLocalClose(child: MuxChannel<ByteBuf>) {
getChannelHandlerContext().writeAndFlush(YamuxFrame(child.id, YamuxType.DATA, YamuxFlags.RST, 0))
}

override fun onRemoteCreated(child: MuxChannel<ByteBuf>) {
}

override fun generateNextId() =
MuxId(getChannelHandlerContext().channel().id(), idGenerator.addAndGet(2).toLong(), true)

private fun createStream(channel: MuxChannel<ByteBuf>): Stream {
val connection = ctx!!.channel().attr(CONNECTION).get()
val stream = StreamOverNetty(channel, connection, channel.initiator)
channel.attr(STREAM).set(stream)
return stream
}

override fun <T> createStream(protocols: List<ProtocolBinding<T>>): StreamPromise<T> {
return createStream(multistreamProtocol.createMultistream(protocols).toStreamHandler())
}

fun <T> createStream(streamHandler: StreamHandler<T>): StreamPromise<T> {
val controller = CompletableFuture<T>()
val stream = newStream {
streamHandler.handleStream(createStream(it)).forward(controller)
}.thenApply { it.attr(STREAM).get() }
return StreamPromise(stream, controller)
}
}
39 changes: 39 additions & 0 deletions libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxStreamMuxer.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
package io.libp2p.mux.yamux

import io.libp2p.core.ChannelVisitor
import io.libp2p.core.Connection
import io.libp2p.core.P2PChannel
import io.libp2p.core.StreamHandler
import io.libp2p.core.multistream.MultistreamProtocol
import io.libp2p.core.multistream.ProtocolDescriptor
import io.libp2p.core.mux.StreamMuxer
import io.libp2p.core.mux.StreamMuxerDebug
import java.util.concurrent.CompletableFuture

class YamuxStreamMuxer(
val inboundStreamHandler: StreamHandler<*>,
private val multistreamProtocol: MultistreamProtocol
) : StreamMuxer, StreamMuxerDebug {

override val protocolDescriptor = ProtocolDescriptor("/yamux/1.0.0")
override var muxFramesDebugHandler: ChannelVisitor<Connection>? = null

override fun initChannel(ch: P2PChannel, selectedProtocol: String): CompletableFuture<out StreamMuxer.Session> {
val muxSessionReady = CompletableFuture<StreamMuxer.Session>()

val yamuxFrameCodec = YamuxFrameCodec(ch.isInitiator)
ch.pushHandler(yamuxFrameCodec)
muxFramesDebugHandler?.also { it.visit(ch as Connection) }
ch.pushHandler(
YamuxHandler(
multistreamProtocol,
yamuxFrameCodec.maxFrameDataLength,
muxSessionReady,
inboundStreamHandler,
ch.isInitiator
)
)

return muxSessionReady
}
}
11 changes: 11 additions & 0 deletions libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxType.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package io.libp2p.mux.yamux

/**
* Contains all the permissible values for flags in the <code>yamux</code> protocol.
*/
object YamuxType {
const val DATA = 0
const val WINDOW_UPDATE = 1
const val PING = 2
const val GO_AWAY = 3
}
Loading