Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable manual partitioning #876

Merged
merged 5 commits into from
Aug 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions src/lava/magma/compiler/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,9 @@ def _compile_proc_groups(
f"Cache {cache_dir}\n")
return proc_builders, channel_map

# Get manual partitioning, if available
partitioning = self._compile_config.get("partitioning", None)

# Create the global ChannelMap that is passed between
# SubCompilers to communicate about Channels between Processes.

Expand All @@ -266,7 +269,8 @@ def _compile_proc_groups(
subcompilers.append(pg_subcompilers)

# Compile this ProcGroup.
self._compile_proc_group(pg_subcompilers, channel_map)
self._compile_proc_group(pg_subcompilers, channel_map,
partitioning)

# Flatten the list of all SubCompilers.
subcompilers = list(itertools.chain.from_iterable(subcompilers))
Expand Down Expand Up @@ -403,7 +407,8 @@ def _create_subcompilers(

@staticmethod
def _compile_proc_group(
subcompilers: ty.List[AbstractSubCompiler], channel_map: ChannelMap
subcompilers: ty.List[AbstractSubCompiler], channel_map: ChannelMap,
partitioning: ty.Dict[str, ty.Dict]
) -> None:
"""For a given list of SubCompilers that have been initialized with
the Processes of a single ProcGroup, iterate through the compilation
Expand All @@ -419,6 +424,8 @@ def _compile_proc_group(
channel_map : ChannelMap
The global ChannelMap that contains information about Channels
between Processes.
partitioning: ty.Dict
Optional manual mapping dictionary used by ncproc compiler.
"""
channel_map_prev = None

Expand All @@ -431,7 +438,7 @@ def _compile_proc_group(
for subcompiler in subcompilers:
# Compile the Processes registered with each SubCompiler and
# update the ChannelMap.
channel_map = subcompiler.compile(channel_map)
channel_map = subcompiler.compile(channel_map, partitioning)

@staticmethod
def _extract_proc_builders(
Expand Down
3 changes: 2 additions & 1 deletion src/lava/magma/compiler/subcompilers/py/pyproc_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,8 @@ def __init__(
super().__init__(proc_group, compile_config)
self._spike_io_counter_offset: Offset = Offset()

def compile(self, channel_map: ChannelMap) -> ChannelMap:
def compile(self, channel_map: ChannelMap,
partitioning: ty.Dict = None) -> ChannelMap:
return self._update_channel_map(channel_map)

def __del__(self):
Expand Down
13 changes: 8 additions & 5 deletions tests/lava/magma/compiler/test_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,8 @@ def create_patches(
and the compile() method returns the given ChannelMap unchanged.
."""

def compile_return(channel_map: ChannelMap) -> ChannelMap:
def compile_return(channel_map: ChannelMap,
partitioning=None) -> ChannelMap:
return channel_map

py_patch = patch(
Expand Down Expand Up @@ -391,13 +392,13 @@ def test_compile_proc_group_single_loop(self) -> None:
subcompilers = [py_proc_compiler]

# Call the method to be tested.
self.compiler._compile_proc_group(subcompilers, channel_map)
self.compiler._compile_proc_group(subcompilers, channel_map, None)

# Check that it called compile() on every SubCompiler instance
# exactly once. After that, the while loop should exit because the
# ChannelMap instance has not changed.
for sc in subcompilers:
sc.compile.assert_called_once_with({})
sc.compile.assert_called_once_with({}, None)

def test_compile_proc_group_multiple_loops(self) -> None:
"""Test whether the correct methods are called on all objects when
Expand All @@ -424,13 +425,15 @@ def test_compile_proc_group_multiple_loops(self) -> None:
subcompilers = [py_proc_compiler]

# Call the method to be tested.
self.compiler._compile_proc_group(subcompilers, channel_map)
self.compiler._compile_proc_group(subcompilers, channel_map,
None)

# Check that it called compile() on every SubCompiler instance
# exactly once. After that, the while loop should exit because the
# ChannelMap instance has not changed.
for sc in subcompilers:
sc.compile.assert_called_with({**channel_map1, **channel_map2})
sc.compile.assert_called_with({**channel_map1, **channel_map2},
None)
self.assertEqual(sc.compile.call_count, 3)

def test_extract_proc_builders(self) -> None:
Expand Down
Loading