diff --git a/src/lava/magma/compiler/compiler.py b/src/lava/magma/compiler/compiler.py index 4a45b5fe5..e447b488c 100644 --- a/src/lava/magma/compiler/compiler.py +++ b/src/lava/magma/compiler/compiler.py @@ -247,6 +247,9 @@ def _compile_proc_groups( f"Cache {cache_dir}\n") return proc_builders, channel_map + # Get manual partitioning, if available + partitioning = self._compile_config.get("partitioning", None) + # Create the global ChannelMap that is passed between # SubCompilers to communicate about Channels between Processes. @@ -266,7 +269,8 @@ def _compile_proc_groups( subcompilers.append(pg_subcompilers) # Compile this ProcGroup. - self._compile_proc_group(pg_subcompilers, channel_map) + self._compile_proc_group(pg_subcompilers, channel_map, + partitioning) # Flatten the list of all SubCompilers. subcompilers = list(itertools.chain.from_iterable(subcompilers)) @@ -403,7 +407,8 @@ def _create_subcompilers( @staticmethod def _compile_proc_group( - subcompilers: ty.List[AbstractSubCompiler], channel_map: ChannelMap + subcompilers: ty.List[AbstractSubCompiler], channel_map: ChannelMap, + partitioning: ty.Dict[str, ty.Dict] ) -> None: """For a given list of SubCompilers that have been initialized with the Processes of a single ProcGroup, iterate through the compilation @@ -419,6 +424,8 @@ def _compile_proc_group( channel_map : ChannelMap The global ChannelMap that contains information about Channels between Processes. + partitioning: ty.Dict + Optional manual mapping dictionary used by ncproc compiler. """ channel_map_prev = None @@ -431,7 +438,7 @@ def _compile_proc_group( for subcompiler in subcompilers: # Compile the Processes registered with each SubCompiler and # update the ChannelMap. - channel_map = subcompiler.compile(channel_map) + channel_map = subcompiler.compile(channel_map, partitioning) @staticmethod def _extract_proc_builders( diff --git a/src/lava/magma/compiler/subcompilers/py/pyproc_compiler.py b/src/lava/magma/compiler/subcompilers/py/pyproc_compiler.py index 9c74c92d8..c5948a399 100644 --- a/src/lava/magma/compiler/subcompilers/py/pyproc_compiler.py +++ b/src/lava/magma/compiler/subcompilers/py/pyproc_compiler.py @@ -89,7 +89,8 @@ def __init__( super().__init__(proc_group, compile_config) self._spike_io_counter_offset: Offset = Offset() - def compile(self, channel_map: ChannelMap) -> ChannelMap: + def compile(self, channel_map: ChannelMap, + partitioning: ty.Dict = None) -> ChannelMap: return self._update_channel_map(channel_map) def __del__(self): diff --git a/tests/lava/magma/compiler/test_compiler.py b/tests/lava/magma/compiler/test_compiler.py index 0d4e6d07f..876def248 100644 --- a/tests/lava/magma/compiler/test_compiler.py +++ b/tests/lava/magma/compiler/test_compiler.py @@ -217,7 +217,8 @@ def create_patches( and the compile() method returns the given ChannelMap unchanged. .""" - def compile_return(channel_map: ChannelMap) -> ChannelMap: + def compile_return(channel_map: ChannelMap, + partitioning=None) -> ChannelMap: return channel_map py_patch = patch( @@ -391,13 +392,13 @@ def test_compile_proc_group_single_loop(self) -> None: subcompilers = [py_proc_compiler] # Call the method to be tested. - self.compiler._compile_proc_group(subcompilers, channel_map) + self.compiler._compile_proc_group(subcompilers, channel_map, None) # Check that it called compile() on every SubCompiler instance # exactly once. After that, the while loop should exit because the # ChannelMap instance has not changed. for sc in subcompilers: - sc.compile.assert_called_once_with({}) + sc.compile.assert_called_once_with({}, None) def test_compile_proc_group_multiple_loops(self) -> None: """Test whether the correct methods are called on all objects when @@ -424,13 +425,15 @@ def test_compile_proc_group_multiple_loops(self) -> None: subcompilers = [py_proc_compiler] # Call the method to be tested. - self.compiler._compile_proc_group(subcompilers, channel_map) + self.compiler._compile_proc_group(subcompilers, channel_map, + None) # Check that it called compile() on every SubCompiler instance # exactly once. After that, the while loop should exit because the # ChannelMap instance has not changed. for sc in subcompilers: - sc.compile.assert_called_with({**channel_map1, **channel_map2}) + sc.compile.assert_called_with({**channel_map1, **channel_map2}, + None) self.assertEqual(sc.compile.call_count, 3) def test_extract_proc_builders(self) -> None: