Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Respect message list limits when creating messages to send #218

Merged
merged 5 commits into from
Feb 2, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 69 additions & 38 deletions src/main/kotlin/io/libp2p/pubsub/AbstractRouter.kt
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import io.libp2p.etc.types.completedExceptionally
import io.libp2p.etc.types.copy
import io.libp2p.etc.types.forward
import io.libp2p.etc.types.lazyVarInit
import io.libp2p.etc.types.thenApplyAll
import io.libp2p.etc.types.toWBytes
import io.libp2p.etc.util.P2PServiceSemiDuplex
import io.libp2p.etc.util.netty.protobuf.LimitedProtobufVarint32FrameDecoder
Expand Down Expand Up @@ -53,11 +54,11 @@ abstract class AbstractRouter(
LRUSeenCache(SimpleSeenCache(), maxSeenMessagesLimit)
}

private val peerTopics = MultiSet<PeerHandler, String>()
private val peerTopics = MultiSet<PeerHandler, Topic>()
private var msgHandler: (PubsubMessage) -> CompletableFuture<ValidationResult> = { RESULT_VALID }
override var messageValidator = NOP_ROUTER_VALIDATOR

val subscribedTopics = linkedSetOf<String>()
val subscribedTopics = linkedSetOf<Topic>()
val pendingRpcParts = linkedMapOf<PeerHandler, MutableList<Rpc.RPC>>()
private var debugHandler: ChannelHandler? = null
private val pendingMessagePromises = MultiSet<PeerHandler, CompletableFuture<Unit>>()
Expand Down Expand Up @@ -89,16 +90,58 @@ abstract class AbstractRouter(
pendingRpcParts.getOrPut(toPeer, { mutableListOf() }) += msgPart
}

private fun addPendingSubscription(toPeer: PeerHandler, topic: Topic, subscriptionStatus: SubscriptionStatus) {
addPendingRpcPart(
toPeer,
Rpc.RPC.newBuilder()
.addSubscriptions(
Rpc.RPC.SubOpts.newBuilder()
.setSubscribe(subscriptionStatus == SubscriptionStatus.Subscribed)
.setTopicid(topic)
)
.build()
)
}

/**
* Drains all partial messages for [toPeer] and returns merged message
*/
protected fun collectPeerMessage(toPeer: PeerHandler): Rpc.RPC? {
val msgs = pendingRpcParts.remove(toPeer) ?: emptyList<Rpc.RPC>()
if (msgs.isEmpty()) return null
protected fun collectPeerMessages(toPeer: PeerHandler): List<Rpc.RPC> =
mergeMessageParts(pendingRpcParts.remove(toPeer) ?: emptyList())

protected fun mergeMessageParts(parts: List<Rpc.RPC>): List<Rpc.RPC> {
if (parts.isEmpty()) return emptyList()

val optimisticBuilder = parts.fold(Rpc.RPC.newBuilder()) { builder, part ->
builder.mergeFrom(part)
}
return if (validateMessageListLimits(optimisticBuilder)) {
// optimistic case
listOf(optimisticBuilder.build())
} else {
// need to split to multiple messages
splitToLimitedMessages(parts)
}
}

val bld = Rpc.RPC.newBuilder()
msgs.forEach { bld.mergeFrom(it) }
return bld.build()
private fun splitToLimitedMessages(parts: List<Rpc.RPC>): List<Rpc.RPC> {
val ret = mutableListOf<Rpc.RPC>()
val partQueue = ArrayDeque(parts)
while (partQueue.isNotEmpty()) {
var builder = Rpc.RPC.newBuilder()
while (partQueue.isNotEmpty()) {
val validationBuilder = builder.clone()
val part = partQueue.first()
validationBuilder.mergeFrom(part)
if (!validateMessageListLimits(validationBuilder)) {
break
}
partQueue.removeFirst()
builder = validationBuilder
}
ret += builder.build()
}
return ret
}

/**
Expand All @@ -110,11 +153,10 @@ abstract class AbstractRouter(
}

protected fun flushPending(peer: PeerHandler) {
collectPeerMessage(peer)?.also {
val future = send(peer, it)
pendingMessagePromises.removeAll(peer)?.forEach {
future.forward(it)
}
val peerMessages = collectPeerMessages(peer)
val allSendPromise = peerMessages.map { send(peer, it) }.thenApplyAll { }
pendingMessagePromises.removeAll(peer)?.forEach {
allSendPromise.forward(it)
}
}

Expand Down Expand Up @@ -162,13 +204,10 @@ abstract class AbstractRouter(
protected abstract fun processControl(ctrl: Rpc.ControlMessage, receivedFrom: PeerHandler)

override fun onPeerActive(peer: PeerHandler) {
val helloPubsubMsg = Rpc.RPC.newBuilder().addAllSubscriptions(
subscribedTopics.map {
Rpc.RPC.SubOpts.newBuilder().setSubscribe(true).setTopicid(it).build()
}
).build()

peer.writeAndFlush(helloPubsubMsg)
subscribedTopics.forEach {
addPendingSubscription(peer, it, SubscriptionStatus.Subscribed)
}
flushPending(peer)
}

protected open fun notifyMalformedMessage(peer: PeerHandler) {}
Expand Down Expand Up @@ -283,7 +322,7 @@ abstract class AbstractRouter(
}
}

internal open fun validateMessageListLimits(msg: Rpc.RPC): Boolean {
internal open fun validateMessageListLimits(msg: Rpc.RPCOrBuilder): Boolean {
return true
}

Expand Down Expand Up @@ -332,36 +371,26 @@ abstract class AbstractRouter(
}
}

protected open fun subscribe(topic: String) {
activePeers.forEach {
addPendingRpcPart(
it,
Rpc.RPC.newBuilder().addSubscriptions(Rpc.RPC.SubOpts.newBuilder().setSubscribe(true).setTopicid(topic)).build()
)
}
protected open fun subscribe(topic: Topic) {
activePeers.forEach { addPendingSubscription(it, topic, SubscriptionStatus.Subscribed) }
subscribedTopics += topic
}

override fun unsubscribe(vararg topics: String) {
override fun unsubscribe(vararg topics: Topic) {
runOnEventThread {
topics.forEach(::unsubscribe)
flushAllPending()
}
}

protected open fun unsubscribe(topic: String) {
activePeers.forEach {
addPendingRpcPart(
it,
Rpc.RPC.newBuilder().addSubscriptions(Rpc.RPC.SubOpts.newBuilder().setSubscribe(false).setTopicid(topic)).build()
)
}
protected open fun unsubscribe(topic: Topic) {
activePeers.forEach { addPendingSubscription(it, topic, SubscriptionStatus.Unsubscribed) }
subscribedTopics -= topic
}

override fun getPeerTopics(): CompletableFuture<Map<PeerId, Set<String>>> {
override fun getPeerTopics(): CompletableFuture<Map<PeerId, Set<Topic>>> {
return submitOnEventThread {
val topicsByPeerId = hashMapOf<PeerId, Set<String>>()
val topicsByPeerId = hashMapOf<PeerId, Set<Topic>>()
peerTopics.forEach { entry ->
topicsByPeerId[entry.key.peerId] = HashSet(entry.value)
}
Expand All @@ -376,4 +405,6 @@ abstract class AbstractRouter(
override fun initHandler(handler: (PubsubMessage) -> CompletableFuture<ValidationResult>) {
msgHandler = handler
}

protected enum class SubscriptionStatus { Subscribed, Unsubscribed }
}
8 changes: 4 additions & 4 deletions src/main/kotlin/io/libp2p/pubsub/gossip/GossipRouter.kt
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ open class GossipRouter @JvmOverloads constructor(
mesh.values.forEach { it.remove(peer) }
fanout.values.forEach { it.remove(peer) }
acceptRequestsWhitelist -= peer
collectPeerMessage(peer) // discard them
collectPeerMessages(peer) // discard them
super.onPeerDisconnected(peer)
}

Expand Down Expand Up @@ -208,7 +208,7 @@ open class GossipRouter @JvmOverloads constructor(
return peerScore >= score.params.graylistThreshold
}

override fun validateMessageListLimits(msg: Rpc.RPC): Boolean {
override fun validateMessageListLimits(msg: Rpc.RPCOrBuilder): Boolean {
return params.maxPublishedMessages?.let { msg.publishCount <= it } ?: true &&
params.maxTopicsPerPublishedMessage?.let { msg.publishList.none { m -> m.topicIDsCount > it } } ?: true &&
params.maxSubscriptions?.let { msg.subscriptionsCount <= it } ?: true &&
Expand All @@ -219,11 +219,11 @@ open class GossipRouter @JvmOverloads constructor(
params.maxPeersPerPruneMessage?.let { msg.control?.pruneList?.none { p -> p.peersCount > it } } ?: true
}

private fun countIWantMessageIds(msg: Rpc.RPC): Int {
private fun countIWantMessageIds(msg: Rpc.RPCOrBuilder): Int {
return msg.control?.iwantList?.map { w -> w.messageIDsCount }?.sum() ?: 0
}

private fun countIHaveMessageIds(msg: Rpc.RPC): Int {
private fun countIHaveMessageIds(msg: Rpc.RPCOrBuilder): Int {
return msg.control?.ihaveList?.map { w -> w.messageIDsCount }?.sum() ?: 0
}

Expand Down
87 changes: 87 additions & 0 deletions src/test/kotlin/io/libp2p/pubsub/AbstractRouterTest.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
package io.libp2p.pubsub

import io.libp2p.etc.types.toProtobuf
import io.libp2p.pubsub.TopicSubscriptionFilter.AllowAllTopicSubscriptionFilter
import org.assertj.core.api.Assertions.assertThat
import org.junit.jupiter.api.Test
import pubsub.pb.Rpc
import java.util.concurrent.CompletableFuture

class AbstractRouterTest {

private class TestRouter(val msgValidator: (Rpc.RPCOrBuilder) -> Boolean) :
AbstractRouter(AllowAllTopicSubscriptionFilter()) {
override val protocol = PubsubProtocol.Floodsub
override fun broadcastOutbound(msg: PubsubMessage): CompletableFuture<Unit> =
CompletableFuture.completedFuture(null)

override fun broadcastInbound(msgs: List<PubsubMessage>, receivedFrom: PeerHandler) {}
override fun processControl(ctrl: Rpc.ControlMessage, receivedFrom: PeerHandler) {}

override fun validateMessageListLimits(msg: Rpc.RPCOrBuilder) = msgValidator(msg)
fun testMerge(parts: List<Rpc.RPC>): List<Rpc.RPC> = mergeMessageParts(parts)
}

private fun Collection<Rpc.RPC>.merge(): Rpc.RPC =
this.fold(Rpc.RPC.newBuilder()) { bld, part -> bld.mergeFrom(part) }.build()

@Test
fun `test many subscriptions split to several messages`() {
val router = TestRouter { it.subscriptionsCount <= 5 }
val parts = (0 until 14).map {
Rpc.RPC.newBuilder().addSubscriptions(
Rpc.RPC.SubOpts.newBuilder()
.setTopicid("topic-$it")
.setSubscribe(true)
.build()
).build()
}
val msgs = router.testMerge(parts)

assertThat(msgs)
.hasSize(3)
.allMatch { it.subscriptionsCount <= 5 }

assertThat(msgs.merge()).isEqualTo(parts.merge())
}

@Test
fun `test few subscriptions don't split to several messages`() {
val router = TestRouter { it.subscriptionsCount <= 5 }
val parts = (0 until 5).map {
Rpc.RPC.newBuilder().addSubscriptions(
Rpc.RPC.SubOpts.newBuilder()
.setTopicid("topic-$it")
.setSubscribe(true)
.build()
).build()
}
val msgs = router.testMerge(parts)

assertThat(msgs).hasSize(1)
assertThat(msgs.merge()).isEqualTo(parts.merge())
}

@Test
fun `test that split doesn't result in topic publish before subscribe`() {
val router = TestRouter { it.subscriptionsCount <= 5 }
val parts = (0 until 6).map {
Rpc.RPC.newBuilder().addSubscriptions(
Rpc.RPC.SubOpts.newBuilder()
.setTopicid("topic-$it")
.setSubscribe(true)
.build()
).build()
} + Rpc.RPC.newBuilder().addPublish(
Rpc.Message.newBuilder()
.addTopicIDs("topic-5")
.setData(byteArrayOf(11).toProtobuf())
).build()
val msgs = router.testMerge(parts)

assertThat(msgs).hasSize(2)
assertThat(msgs[0].publishCount).isZero()
assertThat(msgs[1].publishCount).isEqualTo(1)
assertThat(msgs.merge()).isEqualTo(parts.merge())
}
}
81 changes: 81 additions & 0 deletions src/test/kotlin/io/libp2p/pubsub/gossip/SubscriptionsLimitTest.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
package io.libp2p.pubsub.gossip

import io.libp2p.core.pubsub.MessageApi
import io.libp2p.core.pubsub.Subscriber
import io.libp2p.core.pubsub.Topic
import io.libp2p.etc.types.toByteArray
import io.libp2p.etc.types.toByteBuf
import org.assertj.core.api.Assertions.assertThat
import org.junit.jupiter.api.Test
import org.junit.jupiter.api.assertDoesNotThrow

class SubscriptionsLimitTest : TwoGossipHostTestBase() {
override val params = GossipParams(maxSubscriptions = 5, floodPublish = true)

@Test
fun `new peer subscribed to many topics`() {
val topics = (0..13).map { Topic("topic-$it") }.toTypedArray()
gossip1.subscribe(Subscriber {}, *topics)
val messages2 = mutableListOf<MessageApi>()
gossip2.subscribe(Subscriber { messages2 += it }, *topics)

connect()
waitForSubscribed(router1, "topic-13")
waitForSubscribed(router2, "topic-13")

val topics1 = router1.getPeerTopics().join().values.first()
assertThat(topics1).containsExactlyInAnyOrderElementsOf(topics.map { it.topic })
val topics2 = router2.getPeerTopics().join().values.first()
assertThat(topics2).containsExactlyInAnyOrderElementsOf(topics.map { it.topic })

val msg1Promise =
gossip1.createPublisher(null).publish(byteArrayOf(11).toByteBuf(), Topic("topic-13"))

assertDoesNotThrow { msg1Promise.join() }
waitFor { messages2.isNotEmpty() }
assertThat(messages2)
.hasSize(1)
.allMatch {
it.topics == listOf(Topic("topic-13")) &&
it.data.toByteArray().contentEquals(byteArrayOf(11))
}
}

@Test
fun `new peer subscribed to few topics`() {
val topics = (0..4).map { Topic("topic-$it") }.toTypedArray()
gossip1.subscribe(Subscriber { }, *topics)
gossip2.subscribe(Subscriber { }, *topics)

connect()
waitForSubscribed(router1, "topic-4")
waitForSubscribed(router2, "topic-4")

val topics1 = router1.getPeerTopics().join().values.first()
assertThat(topics1).containsExactlyInAnyOrderElementsOf(topics.map { it.topic })
val topics2 = router2.getPeerTopics().join().values.first()
assertThat(topics2).containsExactlyInAnyOrderElementsOf(topics.map { it.topic })
}

@Test
fun `existing peer subscribed to many topics`() {
gossip1.subscribe(Subscriber { }, Topic("test-topic"))
gossip2.subscribe(Subscriber { }, Topic("test-topic"))

connect()
waitForSubscribed(router1, "test-topic")
waitForSubscribed(router2, "test-topic")

val topics = (0..13).map { Topic("topic-$it") }.toTypedArray()
gossip1.subscribe(Subscriber { }, *topics)
gossip2.subscribe(Subscriber { }, *topics)

waitForSubscribed(router1, "topic-13")
waitForSubscribed(router2, "topic-13")

val topics1 = router1.getPeerTopics().join().values.first()
assertThat(topics1).containsExactlyInAnyOrderElementsOf(topics.map { it.topic } + "test-topic")
val topics2 = router2.getPeerTopics().join().values.first()
assertThat(topics2).containsExactlyInAnyOrderElementsOf(topics.map { it.topic } + "test-topic")
}
}