From eba5df91cac6195bbec6183dcc62fa7ae43ef4b2 Mon Sep 17 00:00:00 2001 From: Mathis Richter Date: Fri, 4 Feb 2022 22:12:35 +0100 Subject: [PATCH] Added initial run-unittest for flatten() from issue #163 Signed-off-by: Mathis Richter --- .../magma/core/process/ports/test_flatten.py | 48 +++++++++++++++++++ 1 file changed, 48 insertions(+) create mode 100644 tests/lava/magma/core/process/ports/test_flatten.py diff --git a/tests/lava/magma/core/process/ports/test_flatten.py b/tests/lava/magma/core/process/ports/test_flatten.py new file mode 100644 index 000000000..daa6f2160 --- /dev/null +++ b/tests/lava/magma/core/process/ports/test_flatten.py @@ -0,0 +1,48 @@ +# Copyright (C) 2021 Intel Corporation +# SPDX-License-Identifier: BSD-3-Clause +# See: https://spdx.org/licenses/ + +from typing import List +import numpy as np + +from lava.magma.core.run_configs import RunConfig +from lava.magma.core.run_conditions import RunSteps +from lava.proc.io.source import RingBuffer as SendProcess +from lava.proc.io.sink import RingBuffer as ReceiveProcess +from lava.magma.core.model.py.model import PyLoihiProcessModel + + +class TestRunConfig(RunConfig): + """Run configuration selects appropriate ProcessModel based on tag + """ + def __init__(self, select_tag: str = 'fixed_pt'): + super(TestRunConfig, self).__init__(custom_sync_domains=None) + self.select_tag = select_tag + + def select( + self, _, proc_models: List[PyLoihiProcessModel] + ) -> PyLoihiProcessModel: + for pm in proc_models: + if self.select_tag in pm.tags: + return pm + raise AssertionError('No legal ProcessModel found.') + + +if __name__ == '__main__': + num_steps = 10 + shape = (64, 32, 16) + input = np.random.randint(256, size=shape + (num_steps,)) + input -= 128 + + source = SendProcess(data=input) + sink = ReceiveProcess(shape=(np.prod(shape), ), buffer=num_steps) + source.out_ports.s_out.flatten().connect(sink.in_ports.a_in) + + run_condition = RunSteps(num_steps=num_steps) + run_config = TestRunConfig(select_tag='floating_pt') + sink.run(condition=run_condition, run_cfg=run_config) + output = sink.data.get() + sink.stop() + + expected = input.reshape([-1, num_steps]) + print(np.all(output == expected)) \ No newline at end of file