Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add numpy einsum backend #314

Merged
merged 23 commits into from
Apr 21, 2020
Merged

add numpy einsum backend #314

merged 23 commits into from
Apr 21, 2020

Conversation

fehiepsi
Copy link
Member

@fehiepsi fehiepsi commented Feb 7, 2020

This PR ports pyro einsum MAP backend to funsor. Here are changes in this PR:

  • Make funsor.einsum.numpy_log, funsor.einsum.numpy_map backend-agnostic.
  • Address comments by @fritzo in the previous PR. However, I still keep set_backend(get_backend()) in funsor.__init__ file due to a circular import issue (I'll revisit this issue later).

There are some issues with numpy backend:

  • We can't set attributes ._pyro_dims (and ._backward stuff) to ndarray/DeviceArray.
  • memoize tests fail with numpy.generic data. As a workaround, I cast numpy.generic to numpy.ndarray at Tensor construction. Because this issue does not happen with JAX backend, I didn't attempt to resolve it.

@fehiepsi fehiepsi added Blocked Blocked by other issues WIP labels Feb 7, 2020
@fehiepsi fehiepsi mentioned this pull request Feb 18, 2020
13 tasks
@fritzo
Copy link
Member

fritzo commented Feb 18, 2020

This should be unblocked now, right?

funsor/einsum/numpy_map.py Outdated Show resolved Hide resolved
@fehiepsi fehiepsi removed the Blocked Blocked by other issues label Feb 18, 2020
x = ops.permute(x, tuple(old_dims.index(dim) for dim in dims if dim in old_dims))
x = x.reshape(tuple(sizes[dim] if dim in old_dims else 1 for dim in dims))
x = ops.expand(x, shape)
# workaround: ndarray does not allow setting attribute "_pyro_dims"
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is an important point: we can't set new attributes to ndarray, DeviceArray directly.

Copy link
Member

@fritzo fritzo left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM.

@fritzo
Copy link
Member

fritzo commented Feb 20, 2020

@eb8680 has additional comments, please don't merge yet

Copy link
Member

@eb8680 eb8680 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice work, seems like you could actually simplify further by just removing all the _pyro_dims-related logic entirely

funsor/einsum/numpy_map.py Outdated Show resolved Hide resolved
This assumes all operands have a ``._pyro_dims`` attribute set.
"""
# TODO: the implementation skips "backward" logic
equation = rename_equation(equation, *operands)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this actually necessary now? What happens if you remove this line?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry I missed your comments. I'll address those soon.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it is necessary if we keep _pyro_dims attribute. One of the test_einsum fail: test_einsum[funsor.einsum.numpy_map-ab,bc,cd->da].


contract_dims = ''.join(sorted(set().union(*(x._pyro_dims for x in operands)) - set(output)))
dims = output + contract_dims
result = reduce(operator.add, broadcast_all(*operands, dims=dims))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this the only place where _pyro_dims are actually used and not just passed along?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This and rename_equation, broadcast_all utility require pyro_dims. I am trying to remove pyro_dims. If things pass, then I think we can simplify much code here.

@fehiepsi
Copy link
Member Author

fehiepsi commented Apr 7, 2020

@eb8680 To remove pyro_dims, I intend to remove rename_equation and reimplement broadcast_all so that it takes args as values, names of dimensions of each value, and the output dims, then permute, reshape, and possibly expand as in current implementation. Does it sound reasonable to you?

@eb8680
Copy link
Member

eb8680 commented Apr 7, 2020

@fehiepsi sure, that sounds easy enough

@fehiepsi
Copy link
Member Author

@eb8680 It seems that it is actually easy. Thanks for guiding!

Copy link
Member

@eb8680 eb8680 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@eb8680 eb8680 merged commit a29a5d6 into pyro-ppl:master Apr 21, 2020
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants