Skip to content

Commit

Permalink
update tests to use utils.rand_equation
Browse files Browse the repository at this point in the history
  • Loading branch information
jcmgray committed Oct 9, 2024
1 parent 9b5a577 commit 13550f3
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 44 deletions.
8 changes: 6 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,15 @@ requires-python = ">=3.8"
classifiers = [
"Development Status :: 3 - Alpha",
"License :: OSI Approved :: Apache Software License",
"Operating System :: OS Independent",
"Intended Audience :: Science/Research",
"Programming Language :: Python",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12"
"Programming Language :: Python :: 3.12",
"Topic :: Scientific/Engineering",
]
keywords = ["tensor", "network", "contraction", "graph", "hypergraph", "partition", "einsum"]
dependencies = [
Expand Down Expand Up @@ -99,7 +103,7 @@ max_line_length = 79
[tool.ruff]
line-length = 79
target-version = "py38"
ignore = ["E741"]
lint.ignore = ["E741"]

[tool.black]
line-length = 79
Expand Down
17 changes: 10 additions & 7 deletions tests/test_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,13 +89,16 @@ def test_einsum_formats_interleaved():
assert np.allclose(x, y)


@pytest.mark.parametrize("eq,shapes", [
("c...a,b...c->b...a", [(2, 5, 6, 3), (4, 6, 2)]),
("a...a->...", [(3, 3)]),
("a...a->...a", [(3, 4, 5, 3)]),
("...,...ab->ba...", [(), (2, 3, 4, 5)]),
("a,b,ab...c->b...a", [(2,), (3,), (2, 3, 4, 5, 6)]),
])
@pytest.mark.parametrize(
"eq,shapes",
[
("c...a,b...c->b...a", [(2, 5, 6, 3), (4, 6, 2)]),
("a...a->...", [(3, 3)]),
("a...a->...a", [(3, 4, 5, 3)]),
("...,...ab->ba...", [(), (2, 3, 4, 5)]),
("a,b,ab...c->b...a", [(2,), (3,), (2, 3, 4, 5, 6)]),
],
)
def test_einsum_ellipses(eq, shapes):
arrays = [np.random.rand(*shape) for shape in shapes]
x = np.einsum(eq, *arrays)
Expand Down
60 changes: 25 additions & 35 deletions tests/test_optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,14 +168,10 @@ def test_binaries(contraction_20_5, optimize):
def test_hyper_slicer(parallel):
if parallel:
pytest.importorskip("distributed")
pytest.importorskip("opt_einsum")

import opt_einsum as oe

try:
eq, shapes = oe.helpers.rand_equation(30, reg=5, seed=42, d_max=3)
except AttributeError:
eq, shapes = oe.testing.rand_equation(30, reg=5, seed=42, d_max=3)
inputs, output, _, size_dict = ctg.utils.rand_equation(
30, reg=5, seed=42, d_max=3
)

optimizer = ctg.HyperOptimizer(
max_repeats=16,
Expand All @@ -184,23 +180,21 @@ def test_hyper_slicer(parallel):
slicing_opts={"target_slices": 1000},
progbar=True,
)
oe.contract_path(eq, *shapes, shapes=True, optimize=optimizer)
assert optimizer.get_tree().multiplicity >= 1000
tree = ctg.array_contract_tree(
inputs, output, size_dict, optimize=optimizer
)
assert tree.multiplicity >= 1000
assert optimizer.best["flops"] > optimizer.best["original_flops"]


@pytest.mark.parametrize("parallel", [False, True])
def test_hyper_reconf(parallel):
if parallel:
pytest.importorskip("distributed")
pytest.importorskip("opt_einsum")

import opt_einsum as oe

try:
eq, shapes = oe.helpers.rand_equation(30, reg=5, seed=42, d_max=3)
except AttributeError:
eq, shapes = oe.testing.rand_equation(30, reg=5, seed=42, d_max=3)
inputs, output, _, size_dict = ctg.utils.rand_equation(
30, reg=5, seed=42, d_max=3
)

optimizer = ctg.HyperOptimizer(
max_repeats=16,
Expand All @@ -209,22 +203,18 @@ def test_hyper_reconf(parallel):
reconf_opts={"subtree_size": 6},
progbar=True,
)
oe.contract_path(eq, *shapes, shapes=True, optimize=optimizer)
ctg.array_contract_tree(inputs, output, size_dict, optimize=optimizer)
assert optimizer.best["flops"] < optimizer.best["original_flops"]


@pytest.mark.parametrize("parallel", [False, True])
def test_hyper_slicer_reconf(parallel):
if parallel:
pytest.importorskip("distributed")
pytest.importorskip("opt_einsum")

import opt_einsum as oe

try:
eq, shapes = oe.helpers.rand_equation(30, reg=5, seed=42, d_max=3)
except AttributeError:
eq, shapes = oe.testing.rand_equation(30, reg=5, seed=42, d_max=3)
inputs, output, _, size_dict = ctg.utils.rand_equation(
30, reg=5, seed=42, d_max=3
)

optimizer = ctg.HyperOptimizer(
max_repeats=16,
Expand All @@ -238,8 +228,10 @@ def test_hyper_slicer_reconf(parallel):
},
progbar=True,
)
oe.contract_path(eq, *shapes, shapes=True, optimize=optimizer)
assert optimizer.get_tree().max_size() <= 2**19
tree = ctg.array_contract_tree(
inputs, output, size_dict, optimize=optimizer
)
assert tree.max_size() <= 2**19


@pytest.mark.parametrize("parallel_backend", ("dask", "ray"))
Expand All @@ -248,14 +240,10 @@ def test_insane_nested(parallel_backend):
pytest.importorskip("distributed")
else:
pytest.importorskip(parallel_backend)
pytest.importorskip("opt_einsum")

import opt_einsum as oe

try:
eq, shapes = oe.helpers.rand_equation(30, reg=5, seed=42, d_max=3)
except AttributeError:
eq, shapes = oe.testing.rand_equation(30, reg=5, seed=42, d_max=3)
inputs, output, _, size_dict = ctg.utils.rand_equation(
30, reg=5, seed=42, d_max=3
)

optimizer = ctg.HyperOptimizer(
max_repeats=16,
Expand All @@ -274,8 +262,10 @@ def test_insane_nested(parallel_backend):
},
},
)
oe.contract_path(eq, *shapes, shapes=True, optimize=optimizer)
assert optimizer.get_tree().max_size() <= 2**20
tree = ctg.array_contract_tree(
inputs, output, size_dict, optimize=optimizer
)
assert tree.max_size() <= 2**20


def test_plotting():
Expand Down

0 comments on commit 13550f3

Please sign in to comment.