presets: return and propagate trees better
jcmgray committed Oct 2, 2024
1 parent 3188142 commit ffa22ad
Showing 5 changed files with 160 additions and 49 deletions.
64 changes: 58 additions & 6 deletions cotengra/__init__.py
@@ -65,6 +65,7 @@
from .pathfinders import (
path_basic,
path_compressed_greedy,
path_greedy,
path_igraph,
path_kahypar,
path_labels,
@@ -174,6 +175,7 @@
"optimize_quickbb",
"path_basic",
"path_compressed_greedy",
"path_greedy",
"path_igraph",
"path_kahypar",
"path_labels",
@@ -206,9 +208,38 @@
# add some presets


def hyper_optimize(inputs, output, size_dict, memory_limit=None, **opts):
def hyper_optimize(
inputs,
output,
size_dict,
memory_limit=None,
get="path",
**opts,
):
optimizer = HyperOptimizer(**opts)
return optimizer(inputs, output, size_dict, memory_limit)
if get == "path":
return optimizer(inputs, output, size_dict, memory_limit)
elif get == "tree":
return optimizer.search(inputs, output, size_dict, memory_limit)
else:
raise ValueError(f"Unknown get option {get}")


def hyper_compressed_optimize(
inputs,
output,
size_dict,
get="path",
**opts,
):
optimizer = HyperCompressedOptimizer(**opts)

if get == "path":
return optimizer(inputs, output, size_dict)
elif get == "tree":
return optimizer.search(inputs, output, size_dict)
else:
raise ValueError(f"Unknown get option {get}")
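
For context, a minimal sketch of how the new ``get`` switch might be called; the three-tensor network and the ``max_repeats`` value below are illustrative assumptions, not part of the commit:

    import cotengra as ctg

    inputs = [("a", "b"), ("b", "c"), ("c", "d")]
    output = ("a", "d")
    size_dict = {"a": 2, "b": 2, "c": 2, "d": 2}

    # default: a flat contraction path, i.e. a sequence of pairwise merges
    path = ctg.hyper_optimize(inputs, output, size_dict, max_repeats=8)

    # new: the full ContractionTree, which keeps structure (subtree costs,
    # any sliced indices, ...) that a flat path cannot represent
    tree = ctg.hyper_optimize(
        inputs, output, size_dict, max_repeats=8, get="tree"
    )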


def random_greedy_optimize(
@@ -222,35 +253,54 @@ def random_greedy_optimize(
register_preset(
"hyper",
hyper_optimize,
optimizer_tree=functools.partial(hyper_optimize, get="tree"),
)
register_preset(
"hyper-256",
functools.partial(hyper_optimize, max_repeats=256),
optimizer_tree=functools.partial(
hyper_optimize, max_repeats=256, get="tree"
),
)
register_preset(
"hyper-greedy",
functools.partial(hyper_optimize, methods=["greedy"]),
optimizer_tree=functools.partial(
hyper_optimize, methods=["greedy"], get="tree"
),
)
register_preset(
"hyper-labels",
functools.partial(hyper_optimize, methods=["labels"]),
optimizer_tree=functools.partial(
hyper_optimize, methods=["labels"], get="tree"
),
)
register_preset(
"hyper-kahypar",
functools.partial(hyper_optimize, methods=["kahypar"]),
optimizer_tree=functools.partial(
hyper_optimize, methods=["kahypar"], get="tree"
),
)
register_preset(
"hyper-balanced",
functools.partial(
hyper_optimize, methods=["kahypar-balanced"], max_repeats=16
),
optimizer_tree=functools.partial(
hyper_optimize,
methods=["kahypar-balanced"],
max_repeats=16,
get="tree",
),
)
register_preset(
"hyper-compressed",
functools.partial(
hyper_optimize,
minimize="peak-compressed",
methods=("greedy-span", "greedy-compressed", "kahypar-agglom"),
hyper_compressed_optimize,
optimizer_tree=functools.partial(
hyper_compressed_optimize,
get="tree",
),
compressed=True,
)
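
With the ``optimizer_tree`` variants registered above, requesting a tree by preset name no longer needs a path round-trip. A hedged sketch, assuming cotengra's ``array_contract_tree`` interface and an illustrative network:

    import cotengra as ctg

    inputs = [("a", "b"), ("b", "c")]
    output = ("a", "c")
    size_dict = {"a": 4, "b": 8, "c": 4}

    # "hyper" now resolves to hyper_optimize(get="tree") when a tree is wanted
    tree = ctg.array_contract_tree(inputs, output, size_dict, optimize="hyper")
    print(tree.contraction_cost())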
@@ -297,11 +347,13 @@ def random_greedy_optimize(
register_preset(
"greedy-compressed",
path_compressed_greedy.greedy_compressed,
path_compressed_greedy.trial_greedy_compressed,
compressed=True,
)
register_preset(
"greedy-span",
path_compressed_greedy.greedy_span,
path_compressed_greedy.trial_greedy_span,
compressed=True,
)
except KeyError:
63 changes: 45 additions & 18 deletions cotengra/interface.py
@@ -18,24 +18,47 @@
register_path_fn,
)

_PRESETS = {}
_PRESETS_PATH = {}
_PRESETS_TREE = {}
_COMPRESSED_PRESETS = set()


def register_preset(
preset, optimizer, register_opt_einsum="auto", compressed=False
preset,
optimizer,
optimizer_tree=None,
register_opt_einsum="auto",
compressed=False,
):
"""Register a preset optimizer."""
_PRESETS[preset] = optimizer
"""Register a preset optimizer.
if register_opt_einsum == "auto":
register_opt_einsum = opt_einsum_installed
Parameters
----------
preset : str
The name of the preset.
optimizer : callable
The optimizer function that returns a path.
optimizer_tree : callable, optional
The optimizer function that returns a tree.
register_opt_einsum : bool or str, optional
        If ``True``, register the preset with opt_einsum; if ``'auto'``,
        only do so when opt_einsum is installed.
compressed : bool, optional
        If ``True``, mark the preset as a compressed contraction optimizer.
"""
if optimizer is not None:
_PRESETS_PATH[preset] = optimizer

if register_opt_einsum:
try:
register_path_fn(preset, optimizer)
except KeyError:
pass
if register_opt_einsum == "auto":
register_opt_einsum = opt_einsum_installed

if register_opt_einsum:
try:
register_path_fn(preset, optimizer)
except KeyError:
pass

if optimizer_tree is not None:
_PRESETS_TREE[preset] = optimizer_tree

if compressed:
_COMPRESSED_PRESETS.add(preset)
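
A usage sketch of the extended signature; the preset name and both optimizer functions here are hypothetical, purely to show the new ``optimizer_tree`` hook:

    from cotengra import HyperOptimizer
    from cotengra.interface import register_preset

    def my_path_opt(inputs, output, size_dict, memory_limit=None, **kwargs):
        # path variant: returns a sequence of pairwise contractions
        opt = HyperOptimizer(max_repeats=4)
        return opt(inputs, output, size_dict, memory_limit)

    def my_tree_opt(inputs, output, size_dict, **kwargs):
        # tree variant: returns a ContractionTree directly
        return HyperOptimizer(max_repeats=4).search(inputs, output, size_dict)

    register_preset(
        "my-hyper-lite",  # hypothetical preset name
        my_path_opt,
        optimizer_tree=my_tree_opt,
    )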
@@ -45,7 +68,7 @@ def register_preset(
def preset_to_optimizer(preset):
""" """
try:
return _PRESETS[preset]
return _PRESETS_PATH[preset]
except KeyError:
if not opt_einsum_installed:
raise KeyError(
@@ -277,12 +300,16 @@ def _find_tree_optimizer_basic(inputs, output, size_dict, optimize, **kwargs):


def _find_tree_preset(inputs, output, size_dict, optimize, **kwargs):
compressed = optimize in _COMPRESSED_PRESETS
optimize = preset_to_optimizer(optimize)
tree = find_tree(inputs, output, size_dict, optimize, **kwargs)
if compressed:
tree.__class__ = ContractionTreeCompressed
return tree
try:
# preset method can directly return a tree
return _PRESETS_TREE[optimize](inputs, output, size_dict, **kwargs)
except KeyError:
# preset method returns a path, which we convert to a tree
        # record compressed status before rebinding optimize to a function
        compressed = optimize in _COMPRESSED_PRESETS
        optimize = preset_to_optimizer(optimize)
        tree = find_tree(inputs, output, size_dict, optimize, **kwargs)
        if compressed:
            tree.__class__ = ContractionTreeCompressed
        return tree
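
The intended effect for compressed presets is that the right tree class now arrives end-to-end, whether the preset builds the tree directly or it is converted from a path; an illustrative check (network, sizes, and the assertion itself are assumptions about the new behaviour):

    import cotengra as ctg
    from cotengra import ContractionTreeCompressed

    inputs = [("a", "b"), ("b", "c"), ("c", "d")]
    output = ("a", "d")
    size_dict = {"a": 4, "b": 4, "c": 4, "d": 4}

    tree = ctg.array_contract_tree(
        inputs, output, size_dict, optimize="greedy-compressed"
    )
    # expected under this change:
    assert isinstance(tree, ContractionTreeCompressed)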


def _find_tree_tree(inputs, output, size_dict, optimize, **kwargs):
26 changes: 10 additions & 16 deletions cotengra/pathfinders/path_compressed_greedy.py
@@ -6,7 +6,6 @@
import math

from ..core import (
ContractionTree,
ContractionTreeCompressed,
get_hypergraph,
)
@@ -214,18 +213,19 @@ def __call__(self, inputs, output, size_dict, memory_limit=None):


def greedy_compressed(inputs, output, size_dict, memory_limit=None, **kwargs):
chi = max(size_dict.values()) ** 2
try:
chi = kwargs.pop("chi")
except KeyError:
chi = max(size_dict.values()) ** 2
return GreedyCompressed(chi, **kwargs)(inputs, output, size_dict)


def trial_greedy_compressed(inputs, output, size_dict, **kwargs):
opt = GreedyCompressed(**kwargs)
ssa_path = opt.get_ssa_path(inputs, output, size_dict)
tree = ContractionTree.from_path(
inputs, output, size_dict, ssa_path=ssa_path
)
tree.set_surface_order_from_path(ssa_path)
return tree
try:
chi = kwargs.pop("chi")
except KeyError:
chi = max(size_dict.values()) ** 2
return GreedyCompressed(chi, **kwargs).search(inputs, output, size_dict)
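
Note the default: when no ``chi`` is supplied, the compression target is the largest index size squared. A tiny worked example with an assumed ``size_dict``:

    size_dict = {"a": 2, "b": 4, "c": 4}
    chi = max(size_dict.values()) ** 2  # -> 16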


register_hyper_function(
@@ -444,13 +444,7 @@ def greedy_span(inputs, output, size_dict, memory_limit=None, **kwargs):


def trial_greedy_span(inputs, output, size_dict, **kwargs):
opt = GreedySpan(**kwargs)
ssa_path = opt.get_ssa_path(inputs, output, size_dict)
tree = ContractionTree.from_path(
inputs, output, size_dict, ssa_path=ssa_path
)
tree.set_surface_order_from_path(ssa_path)
return tree
return GreedySpan(**kwargs).search(inputs, output, size_dict)


_allowed_perms = tuple(
2 changes: 2 additions & 0 deletions cotengra/pathfinders/path_greedy.py
@@ -47,6 +47,8 @@ def trial_greedy(
)

# greedy but less exploratory -> better for a small number of runs
# note this hyper driver is slightly different from the overall preset
# "random-greedy", which doesn't use the hyper framework
register_hyper_function(
name="random-greedy",
ssa_func=trial_greedy,
54 changes: 45 additions & 9 deletions cotengra/presets.py
@@ -156,12 +156,48 @@ def __init__(self, **kwargs):


# these names overlap with opt_einsum, but won't override presets there
register_preset("auto", auto_optimize)
register_preset("auto-hq", auto_hq_optimize)
register_preset("greedy", greedy_optimize)
register_preset("eager", greedy_optimize)
register_preset("opportunistic", greedy_optimize)
register_preset("optimal", optimal_optimize)
register_preset("dp", optimal_optimize)
register_preset("dynamic-programming", optimal_optimize)
register_preset("optimal-outer", optimal_outer_optimize)
register_preset(
"auto",
auto_optimize,
auto_optimize.search,
)
register_preset(
"auto-hq",
auto_hq_optimize,
    auto_hq_optimize.search,
)
register_preset(
"greedy",
greedy_optimize,
greedy_optimize.search,
)
register_preset(
"eager",
greedy_optimize,
greedy_optimize.search,
)
register_preset(
"opportunistic",
greedy_optimize,
greedy_optimize.search,
)
register_preset(
    "optimal",
    optimal_optimize,
    optimal_optimize.search,
)
register_preset(
    "dp",
    optimal_optimize,
    optimal_optimize.search,
)
register_preset(
    "dynamic-programming",
    optimal_optimize,
    optimal_optimize.search,
)
register_preset(
    "optimal-outer",
    optimal_outer_optimize,
    optimal_outer_optimize.search,
)
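
Each preset above pairs an optimizer's plain call, which yields a path, with its ``.search`` method, which yields a tree. A minimal sketch of the distinction, with an assumed scalar-output network:

    from cotengra.presets import greedy_optimize

    inputs = [("a", "b"), ("b", "c"), ("c", "a")]
    output = ()
    size_dict = {"a": 8, "b": 8, "c": 8}

    path = greedy_optimize(inputs, output, size_dict)  # e.g. ((0, 1), (0, 1))
    tree = greedy_optimize.search(inputs, output, size_dict)  # ContractionTree
    print(tree.contraction_width())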
