Skip to content

Commit

Permalink
Example/c pingpong (#650)
Browse files Browse the repository at this point in the history
  • Loading branch information
Shouyin committed Mar 21, 2023
1 parent a548323 commit 57624c0
Show file tree
Hide file tree
Showing 5 changed files with 153 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -200,3 +200,5 @@ if(CMAKE_BUILD_TYPE STREQUAL "Debug")
else()
message("not debug mode and disable cpp unit test")
endif()

add_subdirectory(examples/c_pingpong)
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
include_directories(../../csrc)

add_executable(cprocess
"cprocess.cc"
)

target_link_libraries(cprocess PRIVATE
message_infrastructure
)


set_target_properties(cprocess
PROPERTIES
RUNTIME_OUTPUT_DIRECTORY "${CMAKE_CURRENT_SOURCE_DIR}"
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# C pingpong

## Run instructions
One needs to run the `p.py` first, start the `./cprocess` binary once seeing the prompt and hit the enter as indicated.

Two args can be given as the socket file names.

```bash
# p.py
$ python3 p.py c2py py2c

# cprocess, in another terminal window
$ ./cprocess c2py py2c
```

## Notes on current TempChennel:
TempChannel uses socket file.

The Recv port will bind the socket file in initialization, listen in `start()` and accept in `recv()`. After established a connection, the port closes it immediately after reading from the socket.

The Send port will connect to the recv port in initialization and write to the socket in send().

Therefore,
1. The send port can only be initialized after corresponding Recv port called `start()`
2. In each round, the send port is used one-off. One needs to create a new `TempChannel()` and get the send port from it each time. (The send port will be initialized when accessing the .dst_port property at the first time)
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
// Copyright (C) 2023 Intel Corporation
// SPDX-License-Identifier: BSD-3-Clause
// See: https://spdx.org/licenses/

#include <core/abstract_channel.h>
#include <core/channel_factory.h>
#include <core/abstract_port.h>
#include <core/utils.h>
#include <channel/socket/socket_port.h>

using namespace message_infrastructure; // NOLINT

int main(int argc, char *argv[]) {
char *c2py = (argc >= 2) ? argv[1] : const_cast<char *>("./c2py");
char *py2c = (argc >= 3) ? argv[2] : const_cast<char *>("./py2c");

std::cout << "socket files: " << c2py << " " << py2c << "\n";

ChannelFactory &channel_factory = GetChannelFactory();

AbstractChannelPtr ch = channel_factory.GetTempChannel(py2c);
AbstractRecvPortPtr rc = ch->GetRecvPort();

// order matters
rc->Start();

for (uint _ = 0; _ < 10; ++_) {
std::cout << "receiving\n";
MetaDataPtr recvd = rc->Recv();
std::cout << "received from py, total size: "
<< recvd->total_size
<< "\n";

AbstractChannelPtr ch2 = channel_factory.GetTempChannel(c2py);
AbstractSendPortPtr sd = ch2->GetSendPort();
sd->Start();
sd->Send(recvd);
sd->Join();
}

rc->Join();

return 0;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import os
import sys
import numpy as np

from lava.magma.runtime.message_infrastructure. \
MessageInfrastructurePywrapper import (
TempChannel
)


# float equal
def f_eq(a, b):
return abs(a - b) < 0.001


def soc_names_from_args():
# default file names
C2PY = "./c2py"
PY2C = "./py2c"

socket_file_names = [C2PY, PY2C]
filename_args = sys.argv[1:3]
if len(filename_args) == 1:
socket_file_names[0] = filename_args[0]
if len(filename_args) == 2:
socket_file_names = filename_args

return socket_file_names


def main():
c2py, py2c = soc_names_from_args()

if os.path.exists(c2py):
os.remove(c2py)
if os.path.exists(py2c):
os.remove(py2c)

# order matters
ch2 = TempChannel(c2py)
rc = ch2.dst_port
rc.start()

input("Start the c process, hit enter when you see *receiving*")

for i in range(10):
# send port is one-off
ch = TempChannel(py2c)
sd = ch.src_port
sd.start()

print("round ", i)
rands = np.array([np.random.random() * 100 for __ in range(10)]) # noqa
print("Sending array to C: ", rands)
sd.send(rands)

rands2 = rc.recv()
print("Got array from C: ", rands2)

print("Correctness: ", all([f_eq(x, y) for x, y in zip(rands, rands2)])) # noqa
print("========================================")
sd.join()
rc.join()


if __name__ == "__main__":
main()

0 comments on commit 57624c0

Please sign in to comment.