Skip to content

Commit

Permalink
Fix path evaluation order in optimizer (#47)
Browse files Browse the repository at this point in the history
  • Loading branch information
eb8680 authored and fritzo committed Mar 6, 2019
1 parent 4537e90 commit c47a17f
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 12 deletions.
5 changes: 3 additions & 2 deletions funsor/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,8 +180,9 @@ def optimize_reduction(op, arg, reduced_vars):
reduce_op, finitary_op = op, arg.op
operands = list(arg.operands)
for (a, b) in path:
ta = operands[a]
b, a = tuple(sorted((a, b), reverse=True))
tb = operands.pop(b)
ta = operands.pop(a)

# don't reduce a dimension too early - keep a collections.Counter
# and only reduce when the dimension is removed from all lhs terms in path
Expand All @@ -200,7 +201,7 @@ def optimize_reduction(op, arg, reduced_vars):
if path_end_reduced_vars:
path_end = Reduce(reduce_op, path_end, path_end_reduced_vars)

operands[a] = path_end
operands.append(path_end)

# reduce any remaining dims, if necessary
final_reduced_vars = frozenset(d for (d, count) in reduce_dim_counter.items()
Expand Down
16 changes: 9 additions & 7 deletions test/test_einsum.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,12 @@ def test_einsum(equation, backend):
actual_optimized = reinterpret(optimized_ast) # eager by default
actual = naive_einsum(equation, *funsor_operands, backend=backend)

assert_close(actual, actual_optimized, atol=1e-4)

assert isinstance(actual, funsor.Tensor) and len(outputs) == 1
if len(outputs[0]) > 0:
actual = actual.align(tuple(outputs[0]))
actual_optimized = actual_optimized.align(tuple(outputs[0]))

assert_close(actual, actual_optimized, atol=1e-4)
assert expected.shape == actual.data.shape
assert torch.allclose(expected, actual.data)
for output in outputs:
Expand Down Expand Up @@ -84,10 +84,11 @@ def test_einsum_categorical(equation):
actual_optimized = reinterpret(optimized_ast) # eager by default
actual = naive_einsum(equation, *map(reinterpret, funsor_operands))

assert_close(actual, actual_optimized, atol=1e-4)

if len(outputs[0]) > 0:
actual = actual.align(tuple(outputs[0]))
actual_optimized = actual_optimized.align(tuple(outputs[0]))

assert_close(actual, actual_optimized, atol=1e-4)

assert expected.shape == actual.data.shape
assert torch.allclose(expected, actual.data)
Expand All @@ -104,7 +105,7 @@ def test_einsum_categorical(equation):
(',ai,abij->', 'ij'),
('a,ai,bij->', 'ij'),
('ai,abi,bci,cdi->', 'i'),
('aij,abij,bcij,cdij->', 'ij'),
('aij,abij,bcij->', 'ij'),
('a,abi,bcij,cdij->', 'ij'),
]

Expand All @@ -120,10 +121,11 @@ def test_plated_einsum(equation, plates, backend):
actual_optimized = reinterpret(optimized_ast) # eager by default
actual = naive_plated_einsum(equation, *funsor_operands, plates=plates, backend=backend)

assert_close(actual, actual_optimized, atol=1e-4)

if len(outputs[0]) > 0:
actual = actual.align(tuple(outputs[0]))
actual_optimized = actual_optimized.align(tuple(outputs[0]))

assert_close(actual, actual_optimized, atol=1e-3 if backend == 'torch' else 1e-4)

assert expected.shape == actual.data.shape
assert torch.allclose(expected, actual.data)
Expand Down
6 changes: 3 additions & 3 deletions test/test_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def test_nested_einsum(eqn1, eqn2, optimize1, optimize2, backend1, backend2):
def make_plated_hmm_einsum(num_steps, num_obs_plates=1, num_hidden_plates=0):

assert num_obs_plates >= num_hidden_plates
t0 = num_obs_plates
t0 = num_obs_plates + 1

obs_plates = ''.join(opt_einsum.get_symbol(i) for i in range(num_obs_plates))
hidden_plates = ''.join(opt_einsum.get_symbol(i) for i in range(num_hidden_plates))
Expand All @@ -123,8 +123,8 @@ def make_plated_hmm_einsum(num_steps, num_obs_plates=1, num_hidden_plates=0):

PLATED_EINSUM_EXAMPLES = [
make_plated_hmm_einsum(num_steps, num_obs_plates=b, num_hidden_plates=a)
for num_steps in range(2, 6)
for (a, b) in [(0, 1), (0, 2), (0, 0), (1, 1), (1, 2), (1, 2)]
for num_steps in range(3, 50, 6)
for (a, b) in [(0, 1), (0, 2), (0, 0), (1, 1), (1, 2)]
]


Expand Down

0 comments on commit c47a17f

Please sign in to comment.