add numpy einsum backend #314
Conversation
This should be unblocked now, right?
funsor/einsum/util.py
Outdated
x = ops.permute(x, tuple(old_dims.index(dim) for dim in dims if dim in old_dims))
x = x.reshape(tuple(sizes[dim] if dim in old_dims else 1 for dim in dims))
x = ops.expand(x, shape)
# workaround: ndarray does not allow setting attribute "_pyro_dims"
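For readers unfamiliar with this alignment trick, here is a plain-NumPy sketch of the permute/reshape/expand sequence above (the function name and the explicit `sizes` dict are illustrative, not funsor's actual `ops` API):

```python
import numpy as np

def broadcast_to_dims(x, old_dims, dims, sizes):
    # Hypothetical helper illustrating the snippet above in plain numpy:
    # 1. permute the existing axes into the order they appear in `dims`,
    x = np.transpose(x, [old_dims.index(d) for d in dims if d in old_dims])
    # 2. insert size-1 axes for dims the operand does not have,
    x = x.reshape([sizes[d] if d in old_dims else 1 for d in dims])
    # 3. expand (broadcast) the size-1 axes out to the full shape.
    return np.broadcast_to(x, [sizes[d] for d in dims])
```

For example, `broadcast_to_dims(x, "ab", "bca", sizes)` returns an array indexed so that `result[b, c, a] == x[a, b]`.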
I think this is an important point: we can't set new attributes on ndarray or DeviceArray directly.
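A quick demonstration of the restriction being discussed (standard NumPy behavior, not PR code): a plain ndarray has no instance `__dict__`, so attaching a new attribute raises `AttributeError`, while an ndarray subclass does accept attributes — one common workaround:

```python
import numpy as np

x = np.zeros((2, 3))
try:
    x._pyro_dims = "ab"   # plain ndarray: no instance __dict__
except AttributeError:
    print("cannot attach attributes to a plain ndarray")

# An ndarray subclass does allow new attributes:
class DimArray(np.ndarray):
    pass

y = np.zeros((2, 3)).view(DimArray)
y._pyro_dims = "ab"       # works on the subclass
```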
LGTM.
@eb8680 has additional comments, please don't merge yet
Nice work, seems like you could actually simplify further by just removing all the _pyro_dims-related logic entirely.
funsor/einsum/numpy_map.py
Outdated
This assumes all operands have a ``._pyro_dims`` attribute set.
"""
# TODO: the implementation skips "backward" logic
equation = rename_equation(equation, *operands)
Is this actually necessary now? What happens if you remove this line?
Sorry I missed your comments. I'll address those soon.
Yes, it is necessary if we keep the _pyro_dims attribute. Otherwise one of the test_einsum tests fails: test_einsum[funsor.einsum.numpy_map-ab,bc,cd->da].
funsor/einsum/numpy_map.py
Outdated
contract_dims = ''.join(sorted(set().union(*(x._pyro_dims for x in operands)) - set(output)))
dims = output + contract_dims
result = reduce(operator.add, broadcast_all(*operands, dims=dims))
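To make the broadcast-then-reduce structure concrete, here is a self-contained sketch of a max-sum (MAP) einsum in plain NumPy. `map_einsum` is an illustrative name, and unlike the PR it recomputes dims from the equation rather than reading a `_pyro_dims` attribute:

```python
import numpy as np

def map_einsum(equation, *operands):
    # Illustrative max-sum (MAP) einsum: operands are log-factors, so
    # '+' plays the role of '*' and 'max' the role of 'sum'.
    inputs, output = equation.split("->")
    inputs = inputs.split(",")
    sizes = {d: s for dims, x in zip(inputs, operands)
             for d, s in zip(dims, x.shape)}
    contract_dims = "".join(sorted(set("".join(inputs)) - set(output)))
    dims = output + contract_dims
    # Broadcast every operand to the common `dims` layout.
    aligned = []
    for in_dims, x in zip(inputs, operands):
        x = np.transpose(x, [in_dims.index(d) for d in dims if d in in_dims])
        x = x.reshape([sizes[d] if d in in_dims else 1 for d in dims])
        aligned.append(x)
    total = sum(aligned)  # elementwise sum of the broadcast log-factors
    # Reduce the trailing contraction dims with max instead of sum.
    axes = tuple(range(len(output), len(dims)))
    return total.max(axis=axes) if axes else total
```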
Is this the only place where _pyro_dims are actually used and not just passed along?
This and the rename_equation and broadcast_all utilities require _pyro_dims. I am trying to remove _pyro_dims. If things pass, then I think we can simplify much of the code here.
@eb8680 To remove
@fehiepsi sure, that sounds easy enough
@eb8680 It turned out to be easy indeed. Thanks for the guidance!
LGTM
This PR ports the pyro einsum MAP backend to funsor. Here are the changes in this PR:

- Add the backend-agnostic funsor.einsum.numpy_log and funsor.einsum.numpy_map backends.
- Call set_backend(get_backend()) in the funsor.__init__ file due to a circular import issue (I'll revisit this issue later).

There are some issues with the numpy backend:

- We can't set ._pyro_dims (and the ._backward stuff) on ndarray/DeviceArray.
- memoize tests fail with numpy.generic data. As a workaround, I cast numpy.generic to numpy.ndarray at Tensor construction. Because this issue does not happen with the JAX backend, I didn't attempt to resolve it.
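The numpy.generic issue comes from the fact that full reductions in NumPy return scalar types (subclasses of numpy.generic) rather than 0-dim arrays; the cast-at-construction workaround amounts to the following (a sketch with a hypothetical helper name, not the PR's actual Tensor code):

```python
import numpy as np

s = np.arange(3).sum()  # full reductions yield numpy.generic scalars
assert isinstance(s, np.generic) and not isinstance(s, np.ndarray)

def as_tensor_data(data):
    # Promote numpy.generic scalars to 0-dim ndarrays at construction time.
    return np.asarray(data) if isinstance(data, np.generic) else data

d = as_tensor_data(s)
assert isinstance(d, np.ndarray) and d.shape == ()
```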