Skip to content

Commit

Permalink
Defer order/casting einsum parameters to NumPy implementation (dask#4914
Browse files Browse the repository at this point in the history
)
  • Loading branch information
pentschev authored and mrocklin committed Jun 17, 2019
1 parent 76f55fd commit abe9e28
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 8 deletions.
9 changes: 2 additions & 7 deletions dask/array/einsumfuncs.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,14 +194,9 @@ def parse_einsum_input(operands):

@derived_from(np)
def einsum(*operands, **kwargs):
casting = kwargs.pop('casting', 'safe')
dtype = kwargs.pop('dtype', None)
optimize = kwargs.pop('optimize', False)
order = kwargs.pop('order', 'K')
split_every = kwargs.pop('split_every', None)
if kwargs:
raise TypeError("einsum() got unexpected keyword "
"argument(s) %s" % ",".join(kwargs))

einsum_dtype = dtype

Expand Down Expand Up @@ -237,8 +232,8 @@ def einsum(*operands, **kwargs):
adjust_chunks={ind: 1 for ind in contract_inds}, dtype=dtype,
# np.einsum parameters
subscripts=subscripts, kernel_dtype=einsum_dtype,
ncontract_inds=ncontract_inds, order=order,
casting=casting, optimize=optimize)
ncontract_inds=ncontract_inds,
optimize=optimize, **kwargs)

# Now reduce over any extra contraction dimensions
if ncontract_inds > 0:
Expand Down
2 changes: 1 addition & 1 deletion dask/array/tests/test_routines.py
Original file line number Diff line number Diff line change
Expand Up @@ -1529,7 +1529,7 @@ def test_einsum_split_every(split_every):
def test_einsum_invalid_args():
_, da_inputs = _numpy_and_dask_inputs('a')
with pytest.raises(TypeError):
da.einsum('a', *da_inputs, foo=1, bar=2)
da.einsum('a', *da_inputs, foo=1, bar=2).compute()


def test_einsum_broadcasting_contraction():
Expand Down

0 comments on commit abe9e28

Please sign in to comment.