Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix the performance issue #705

Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,6 @@ void ShmemRecvPort::QueueRecv() {
});
}
if (!ret) {
// sleep
helper::Sleep();
}
}
Expand All @@ -124,7 +123,8 @@ MetaDataPtr ShmemRecvPort::Recv() {
void ShmemRecvPort::Join() {
if (!done_) {
done_ = true;
recv_queue_thread_.join();
if (recv_queue_thread_.joinable())
recv_queue_thread_.join();
recv_queue_->Stop();
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,9 @@ PYBIND11_MODULE(MessageInfrastructurePywrapper, m) {
.def_property_readonly("shape", &RecvPortProxy::Shape)
.def_property_readonly("d_type", &RecvPortProxy::DType)
.def_property_readonly("size", &RecvPortProxy::Size);
py::class_<Selector, std::shared_ptr<Selector>> (m, "CPPSelector")
.def(py::init<>())
.def("select", &Selector::Select);
}

} // namespace message_infrastructure
Original file line number Diff line number Diff line change
Expand Up @@ -224,4 +224,5 @@ py::object RecvPortProxy::MDataToObject_(MetaDataPtr metadata) {
return py::reinterpret_steal<py::object>(array);
}


} // namespace message_infrastructure
17 changes: 17 additions & 0 deletions src/lava/magma/runtime/_c_message_infrastructure/csrc/port_proxy.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
#include <memory>
#include <string>
#include <vector>
#include <tuple>
#include <utility>

namespace message_infrastructure {

Expand Down Expand Up @@ -87,6 +89,21 @@ using RecvPortProxyPtr = std::shared_ptr<RecvPortProxy>;
using SendPortProxyList = std::vector<SendPortProxyPtr>;
using RecvPortProxyList = std::vector<RecvPortProxyPtr>;


class Selector {
public:
pybind11::object Select(std::vector<std::tuple<RecvPortProxyPtr,
py::function>> *args) {
while (true) {
for (auto it = args->begin(); it != args->end(); ++it) {
if (std::get<0>(*it)->Probe()) {
return std::get<1>(*it)();
}
}
}
}
};

} // namespace message_infrastructure

#endif // PORT_PROXY_H_
3 changes: 2 additions & 1 deletion src/lava/magma/runtime/message_infrastructure/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,8 @@ def load_library():
AbstractTransferPort, # noqa # nosec
support_grpc_channel,
support_fastdds_channel,
support_cyclonedds_channel)
support_cyclonedds_channel,
CPPSelector)

ChannelQueueSize = 1
SyncChannelBytes = 128
Expand Down
15 changes: 5 additions & 10 deletions src/lava/magma/runtime/message_infrastructure/ports.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
# SPDX-License-Identifier: BSD-3-Clause
# See: https://spdx.org/licenses/

from lava.magma.runtime.message_infrastructure import ChannelQueueSize
from lava.magma.runtime.message_infrastructure \
import ChannelQueueSize, CPPSelector
from lava.magma.runtime.message_infrastructure.MessageInfrastructurePywrapper \
import Channel as CppChannel
from lava.magma.runtime.message_infrastructure.MessageInfrastructurePywrapper \
Expand All @@ -20,15 +21,9 @@
import warnings


class Selector:
def select(
self,
*args: ty.Tuple[RecvPort, ty.Callable[[], ty.Any]],
):
for recv_port, action in args:
if recv_port.probe():
return action()
return None
class Selector(CPPSelector):
szc321 marked this conversation as resolved.
Show resolved Hide resolved
def select(self, *args: ty.Tuple[RecvPort, ty.Callable[[], ty.Any]]):
return super().select(args)


class SendPort(AbstractTransferPort):
Expand Down
124 changes: 124 additions & 0 deletions tests/lava/magma/runtime/message_infrastructure/test_selector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
import numpy as np
import unittest
from functools import partial
from lava.magma.runtime.message_infrastructure \
import Selector
from lava.magma.runtime.message_infrastructure import (
PURE_PYTHON_VERSION,
Channel)


class Builder:
def build(self, i):
pass


def prepare_data():
arr1 = np.array([1] * 9990)
arr2 = np.array([1, 2, 3, 4, 5,
6, 7, 8, 9, 0])
return np.concatenate((arr2, arr1))


def bound_target_a1(loop, actor_to_mp_0, actor_to_mp_1,
actor_to_mp_2, builder):
to_mp_0 = actor_to_mp_0.src_port
to_mp_1 = actor_to_mp_1.src_port
to_mp_2 = actor_to_mp_2.src_port
to_mp_0.start()
to_mp_1.start()
to_mp_2.start()
predata = prepare_data()
while loop > 0:
loop = loop - 1
to_mp_0.send(predata)
to_mp_1.send(predata)
to_mp_2.send(predata)
to_mp_0.join()
to_mp_1.join()
to_mp_2.join()


class TestSelector(unittest.TestCase):

def __init__(self, methodName: str = ...) -> None:
super().__init__(methodName)
self.loop_ = 1000

@unittest.skipIf(PURE_PYTHON_VERSION, "cpp msg lib test")
def test_selector(self):
from lava.magma.runtime.message_infrastructure \
.MessageInfrastructurePywrapper import ChannelType
from lava.magma.runtime.message_infrastructure \
.multiprocessing \
import MultiProcessing

loop = self.loop_ * 3
mp = MultiProcessing()
mp.start()
predata = prepare_data()
queue_size = 1
nbytes = np.prod(predata.shape) * predata.dtype.itemsize
selector = Selector()
actor_to_mp_0 = Channel(
ChannelType.SHMEMCHANNEL,
queue_size,
nbytes,
"actor_to_mp_0",
"actor_to_mp_0",
(2, 2),
np.int32)
actor_to_mp_1 = Channel(
ChannelType.SHMEMCHANNEL,
queue_size,
nbytes,
"actor_to_mp_1",
"actor_to_mp_1",
(2, 2),
np.int32)
actor_to_mp_2 = Channel(
ChannelType.SHMEMCHANNEL,
queue_size,
nbytes,
"actor_to_mp_2",
"actor_to_mp_2",
(2, 2),
np.int32)

target_a1 = partial(bound_target_a1, self.loop_, actor_to_mp_0,
actor_to_mp_1, actor_to_mp_2)

builder = Builder()

mp.build_actor(target_a1, builder) # actor1

from_a0 = actor_to_mp_0.dst_port
from_a1 = actor_to_mp_1.dst_port
from_a2 = actor_to_mp_2.dst_port

from_a0.start()
from_a1.start()
from_a2.start()
expect_result = predata * 3 * self.loop_
recv_port_list = [from_a0, from_a1, from_a2]
channel_actions = [(recv_port, (lambda y: (lambda: y))(
recv_port)) for recv_port in recv_port_list]
real_result = np.array(0)
while loop > 0:
loop = loop - 1
recv_port = selector.select(*channel_actions)
data = recv_port.recv()
real_result = real_result + data
if not np.array_equal(expect_result, real_result):
print("expect: ", expect_result)
print("result: ", real_result)
raise AssertionError()
from_a0.join()
from_a1.join()
from_a2.join()
mp.stop()
mp.cleanup(True)


if __name__ == '__main__':
unittest.main()