From 6fccef3633160288adeaaa27690633677594a6c8 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Sat, 13 Apr 2019 10:09:06 -0700 Subject: [PATCH 01/30] Implement overrides of NumPy's public API on JAX arrays `__array_ufunc__` allows for writing NumPy's ufuncs, e.g., `onp.sin()`. `__array_function__` is a new, experimental override for most other functions in NumPy public API, e.g., `onp.concatenate()`. It will be enabled by default in NumPy 1.17, but is also available in NumPy 1.16 if you set the environment variable `NUMPY_EXPERIMENTAL_ARRAY_FUNCTION=1` before importing NumPy. Together, these should allow users to stick with `import numpy as np` for use with JAX, instead of requiring `import jax.numpy as np`. I expect this will be particularly useful for projects that want to remain implementation agnostic, e.g., so they can write functions that will run without changes on JAX, CuPy and Dask arrays. Note: if you want to test this out in Colab, I think you need to install the development version of NumPy (e.g., `pip install -U git+https://github.com/numpy/numpy.git`). As far as I can tell, it isn't possible to set an environment variable from Colab before importing NumPy. --- .travis.yml | 8 ++++-- jax/numpy/lax_numpy.py | 64 +++++++++++++++++++++++++++++++++++++++++ tests/lax_numpy_test.py | 63 ++++++++++++++++++++++++++++++++++++++++ 3 files changed, 133 insertions(+), 2 deletions(-) diff --git a/.travis.yml b/.travis.yml index bb1dab926542..6c56df8b02b1 100644 --- a/.travis.yml +++ b/.travis.yml @@ -7,8 +7,12 @@ python: - "2.7" - "3.6" env: - - JAX_ENABLE_X64=0 JAX_NUM_GENERATED_CASES=100 - - JAX_ENABLE_X64=1 JAX_NUM_GENERATED_CASES=100 + global: + - JAX_NUM_GENERATED_CASES=100 + - NUMPY_EXPERIMENTAL_ARRAY_FUNCTION=1 + matrix: + - JAX_ENABLE_X64=0 + - JAX_ENABLE_X64=1 before_install: - if [[ "$TRAVIS_PYTHON_VERSION" == "2.7" ]]; then wget https://repo.continuum.io/miniconda/Miniconda2-latest-Linux-x86_64.sh -O miniconda.sh; diff --git a/jax/numpy/lax_numpy.py b/jax/numpy/lax_numpy.py index 593602556c60..ae81c7254c30 100644 --- a/jax/numpy/lax_numpy.py +++ b/jax/numpy/lax_numpy.py @@ -18,6 +18,7 @@ import collections import itertools +import numbers import re import string import warnings @@ -2289,6 +2290,69 @@ def _swap_args(f): setattr(DeviceArray, "astype", lax.convert_element_type) +# Override NumPy's public API. + +_HANDLED_TYPES = (DeviceArray, onp.ndarray, numbers.Number) + + +def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): + """Override NumPy ufuncs, per NEP-13.""" + if method != '__call__' or kwargs or not _all( + isinstance(x, _HANDLED_TYPES) for x in inputs + ): + return NotImplemented + + lax_func = globals().get(ufunc.__name__) + if lax_func is None: + lax_func = _not_implemented(func) + + return lax_func(*inputs) + + +def __array_function__(self, func, types, args, kwargs): + """Override other NumPy functions, per NEP-18.""" + from .. import numpy + + if not _all(issubclass(t, _HANDLED_TYPES) for t in types): + return NotImplemented + + module = numpy + for submodule in func.__module__.split('.')[1:]: + try: + module = getattr(module, submodule) + except AttributeError: + lax_func = _not_implemented(func) + break + else: + lax_func = getattr(module, func.__name__, None) + if lax_func is None: + lax_func = _not_implemented(func) + elif lax_func is func: + # TODO(shoyer): remove these special cases once we've settled on a + # protocol for using original NumPy implementations upstream and it's + # made it into a NumPy release (see + # https://github.com/numpy/numpy/pull/13305). These could be replaced by + # something like: + # lax_func = func.__numpy_implementation__ + if func is onp.iscomplexobj: + # This matchs NumPy's original implementation + return issubclass(args[0].dtype.type, onp.complexfloating) + elif func is onp.result_type: + return onp.result_type(*[getattr(x, 'dtype', 'x') for x in args]) + elif func is onp.shape: + return args[0].shape + elif func is onp.ndim: + return args[0].ndim + elif func is onp.size: + return args[0].size + + return lax_func(*args, **kwargs) + + +setattr(DeviceArray, '__array_ufunc__', __array_ufunc__) +setattr(DeviceArray, '__array_function__', __array_function__) + + # Extra methods that are handy setattr(ShapedArray, "broadcast", core.aval_method(lax.broadcast)) setattr(ShapedArray, "split", core.aval_method(split)) diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index 777a56e36a95..9d5c0f07790a 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -243,6 +243,21 @@ def op_record(name, nargs, dtypes, shapes, rng, diff_modes, test_name=None, ] +# __array_function__ overrides are enabled if using NumPy 1.17+ (not yet +# released) or if the NUMPY_EXPERIMENTAL_ARRAY_FUNCTION environment is set: +# https://www.numpy.org/neps/nep-0018-array-function-protocol.html +class _Dummy(object): + def __array_function__(self, *args, **kwargs): + return True + def __array__(self, *args, **kwargs): + return TypeError + +try: + array_function_overrides_enabled = onp.atleast_1d(_Dummy()) +except TypeError: + array_function_overrides_enabled = False + + CombosWithReplacement = itertools.combinations_with_replacement @@ -1400,6 +1415,54 @@ def testArange(self): self.assertFalse(type(lnp.arange(77)) == type(onp.arange(77))) self.assertTrue(type(lnp.arange(77)) == type(lax.iota(onp.int32, 77))) + def testArrayUfuncUnary(self): + lnp_array = lnp.sqrt(onp.array([1, 2])) # make a DeviceArray object + onp_array = onp.sqrt(onp.array([1, 2])) # make a DeviceArray object + lnp_on_lnp = lnp.sin(lnp_array) + onp_on_lnp = onp.sin(lnp_array) + onp_on_onp = onp.sin(onp_array) + self.assertEqual(type(onp_on_lnp), type(lnp_on_lnp)) + self.assertAllClose(onp_on_lnp, onp_on_onp, check_dtypes=True) + + def testArrayUfuncErrors(self): + x = lnp.sqrt(onp.array([1, 2])) # make a DeviceArray object + with self.assertRaises(TypeError): + onp.sin(x, out=x) + with self.assertRaises(NotImplementedError): + onp.isnat(x) + + def testArrayUfuncBinary(self): + x = lnp.sqrt(onp.array([1, 2])) # make a DeviceArray object + lnp_expected = lnp.array(x) + 3 + onp_expected = onp.array(x) + 3 + + result = onp.add(lnp.array(x), 3) + self.assertEqual(type(result), type(lnp_expected)) + self.assertAllClose(result, onp_expected, check_dtypes=True) + + result = onp.add(lnp.array(x), onp.array(3)) + self.assertEqual(type(result), type(lnp_expected)) + self.assertAllClose(result, onp_expected, check_dtypes=True) + + result = onp.add(onp.array(3), lnp.array(x)) + self.assertEqual(type(result), type(lnp_expected)) + self.assertAllClose(result, onp_expected, check_dtypes=True) + + def testArrayFunction(self): + if not array_function_overrides_enabled: + self.skipTest('__array_function__ overrides not enabled') + lnp_array = lnp.sqrt(onp.array([1, 2])) # make a DeviceArray object + onp_array = onp.sqrt(onp.array([1, 2])) # make a DeviceArray object + lnp_on_lnp = lnp.concatenate([lnp_array] * 2) + onp_on_lnp = onp.concatenate([lnp_array] * 2) + onp_on_onp = onp.concatenate([onp_array] * 2) + self.assertEqual(type(onp_on_lnp), type(lnp_on_lnp)) + self.assertAllClose(onp_on_lnp, onp_on_onp, check_dtypes=True) + + onp_on_mixed = onp.concatenate([onp_array, lnp_array]) + self.assertEqual(type(onp_on_mixed), type(lnp_on_lnp)) + self.assertAllClose(onp_on_mixed, onp_on_onp, check_dtypes=True) + if __name__ == "__main__": absltest.main() From 7511ef02549dc7ced073a5510c69eb0f9cfa2371 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Sat, 13 Apr 2019 12:27:32 -0700 Subject: [PATCH 02/30] Fix misc. test failures --- jax/numpy/lax_numpy.py | 13 ++++++++----- jax/test_util.py | 8 +++++++- tests/batching_test.py | 2 +- tests/lapax_test.py | 6 +++--- tests/random_test.py | 1 + 5 files changed, 20 insertions(+), 10 deletions(-) diff --git a/jax/numpy/lax_numpy.py b/jax/numpy/lax_numpy.py index ae81c7254c30..816d28a02111 100644 --- a/jax/numpy/lax_numpy.py +++ b/jax/numpy/lax_numpy.py @@ -733,7 +733,7 @@ def broadcast_to(arr, shape): # lax.broadcast and lax.transpose lax.broadcast_shapes(shape, _shape(arr)) # error checking nlead = len(shape) - len(_shape(arr)) - diff, = onp.where(onp.not_equal(shape[nlead:], _shape(arr))) + diff, = onp.nonzero(onp.not_equal(shape[nlead:], _shape(arr))) new_dims = tuple(range(nlead)) + tuple(nlead + diff) kept_dims = tuple(onp.delete(onp.arange(len(shape)), new_dims)) @@ -2011,8 +2011,9 @@ def _rewriting_take(arr, idx, axis=0): msg = "Boolean index shape did not match indexed array shape prefix." raise IndexError(msg) else: + idx = onp.asarray(idx) reshaped_arr = arr.reshape((-1,) + arr.shape[idx.ndim:]) - int_idx, = onp.where(idx.ravel()) + int_idx, = onp.nonzero(idx.ravel()) return lax.index_take(reshaped_arr, (int_idx,), (0,)) # Handle non-advanced tuple indices by recursing once @@ -2194,8 +2195,8 @@ def lcm(x1, x2): def _not_implemented(fun): @_wraps(fun) def wrapped(*args, **kwargs): - msg = "Numpy function {} not yet implemented" - raise NotImplementedError(msg.format(fun)) + msg = "'{}.{}' is not yet implemented by JAX" + raise NotImplementedError(msg.format(fun.__module__, fun.__name__)) return wrapped # Build a set of all unimplemented NumPy functions. @@ -2338,13 +2339,15 @@ def __array_function__(self, func, types, args, kwargs): # This matchs NumPy's original implementation return issubclass(args[0].dtype.type, onp.complexfloating) elif func is onp.result_type: - return onp.result_type(*[getattr(x, 'dtype', 'x') for x in args]) + return onp.result_type(*[getattr(x, 'dtype', x) for x in args]) elif func is onp.shape: return args[0].shape elif func is onp.ndim: return args[0].ndim elif func is onp.size: return args[0].size + else: + raise AssertionError('{} needs an override'.format(func)) return lax_func(*args, **kwargs) diff --git a/jax/test_util.py b/jax/test_util.py index f411b4ec3bc6..d44833021a09 100644 --- a/jax/test_util.py +++ b/jax/test_util.py @@ -61,6 +61,8 @@ def numpy_eq(x, y): testing_tpu = FLAGS.jax_test_dut and FLAGS.jax_test_dut.startswith("tpu") testing_x32 = not FLAGS.jax_enable_x64 + x = onp.asarray(x) + y = onp.asarray(y) if testing_tpu or testing_x32: return onp.allclose(x, y, 1e-3, 1e-3, equal_nan=testing_tpu) else: @@ -74,6 +76,8 @@ def numpy_close(a, b, atol=ATOL, rtol=RTOL, equal_nan=False): atol = max(atol, 1e-1) rtol = max(rtol, 1e-1) assert a.shape == b.shape + a = onp.asarray(a) + b = onp.asarray(b) return onp.allclose(a, b, atol=atol * a.size, rtol=rtol * b.size, equal_nan=equal_nan or testing_tpu) @@ -99,7 +103,7 @@ def inner_prod(xs, ys): def scalar_mul(xs, a): - return tree_map(lambda x: onp.multiply(x, a, dtype=_dtype(x)), xs) + return tree_map(lambda x: onp.multiply(onp.array(x), a, dtype=_dtype(x)), xs) def rand_like(rng, x): @@ -413,6 +417,8 @@ def assertArraysAllClose(self, x, y, check_dtypes, atol=None, rtol=None): atol = max(atol, 0.5) rtol = max(rtol, 1e-1) + x = onp.asarray(x) + y = onp.asarray(y) if not onp.allclose(x, y, atol=atol, rtol=rtol, equal_nan=True): msg = ('Arguments x and y not equal to tolerance atol={}, rtol={}:\n' 'x:\n{}\n' diff --git a/tests/batching_test.py b/tests/batching_test.py index b6eab209de15..e63425835f33 100644 --- a/tests/batching_test.py +++ b/tests/batching_test.py @@ -373,7 +373,7 @@ def testRandom(self): expected = onp.stack([random.normal(random.PRNGKey(seed), (3, 2)) for seed in onp.arange(10)]) self.assertAllClose(ans, expected, check_dtypes=False) - assert len(onp.unique(ans)) == 10 * 3 * 2 + assert len(onp.unique(onp.asarray(ans))) == 10 * 3 * 2 def testSortKeyVal(self): k = onp.arange(12)[::-1].reshape(3, 4) diff --git a/tests/lapax_test.py b/tests/lapax_test.py index d26015c1b2ac..fa6de2021ef2 100644 --- a/tests/lapax_test.py +++ b/tests/lapax_test.py @@ -44,7 +44,7 @@ def testSolveLowerTriangularVec(self): rhs2 = npr.randn(3, 1) def check(fun, lhs, rhs): - a1 = onp.linalg.solve(lhs, rhs) + a1 = onp.linalg.solve(onp.asarray(lhs), onp.asarray(rhs)) a2 = fun(lhs, rhs) a3 = fun(lhs, rhs) self.assertArraysAllClose(a1, a2, check_dtypes=True) @@ -65,7 +65,7 @@ def testSolveLowerTriangularMat(self): rhs2 = npr.randn(4, 3) def check(fun, lhs, rhs): - a1 = onp.linalg.solve(lhs, rhs) + a1 = onp.linalg.solve(onp.asarray(lhs), onp.asarray(rhs)) a2 = fun(lhs, rhs) a3 = fun(lhs, rhs) self.assertArraysAllClose(a1, a2, check_dtypes=True) @@ -86,7 +86,7 @@ def testSolveLowerTriangularBroadcasting(self): rhs2 = npr.randn(3, 3, 2) def check(fun, lhs, rhs): - a1 = onp.linalg.solve(lhs, rhs) + a1 = onp.linalg.solve(onp.asarray(lhs), onp.asarray(rhs)) a2 = fun(lhs, rhs) a3 = fun(lhs, rhs) self.assertArraysAllClose(a1, a2, check_dtypes=True) diff --git a/tests/random_test.py b/tests/random_test.py index a2705d9b1db2..f549e7b283de 100644 --- a/tests/random_test.py +++ b/tests/random_test.py @@ -39,6 +39,7 @@ class LaxRandomTest(jtu.JaxTestCase): def _CheckCollisions(self, samples, nbits): fail_prob = 0.01 # conservative bound on statistical fail prob by Chebyshev + samples = onp.asarray(samples) nitems = len(samples) nbins = 2 ** nbits nexpected = nbins * (1 - ((nbins - 1) / nbins) ** nitems) From 6329907d0a8a90efc5436f4304b25edc80f78fe5 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Sat, 13 Apr 2019 12:43:10 -0700 Subject: [PATCH 03/30] More test coverage for __array_function__ --- jax/numpy/lax_numpy.py | 7 +++++-- tests/lax_numpy_test.py | 23 +++++++++++++++++++++-- 2 files changed, 26 insertions(+), 4 deletions(-) diff --git a/jax/numpy/lax_numpy.py b/jax/numpy/lax_numpy.py index 816d28a02111..c3849175833d 100644 --- a/jax/numpy/lax_numpy.py +++ b/jax/numpy/lax_numpy.py @@ -2195,8 +2195,11 @@ def lcm(x1, x2): def _not_implemented(fun): @_wraps(fun) def wrapped(*args, **kwargs): - msg = "'{}.{}' is not yet implemented by JAX" - raise NotImplementedError(msg.format(fun.__module__, fun.__name__)) + try: + name = "'{}.{}'".format(fun.__module__, fun.__name__) + except AttributeError: + name = "NumPy function {}".format(fun) + raise NotImplementedError("{} is not yet implemented by JAX".format(name)) return wrapped # Build a set of all unimplemented NumPy functions. diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index 9d5c0f07790a..848f09d3ef01 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -28,6 +28,7 @@ import six import numpy as onp +import numpy.lib.recfunctions from jax import api from jax import lax @@ -1429,6 +1430,7 @@ def testArrayUfuncErrors(self): with self.assertRaises(TypeError): onp.sin(x, out=x) with self.assertRaises(NotImplementedError): + # JAX is unlikely to ever implement datetime64 ufuncs onp.isnat(x) def testArrayUfuncBinary(self): @@ -1453,9 +1455,10 @@ def testArrayFunction(self): self.skipTest('__array_function__ overrides not enabled') lnp_array = lnp.sqrt(onp.array([1, 2])) # make a DeviceArray object onp_array = onp.sqrt(onp.array([1, 2])) # make a DeviceArray object - lnp_on_lnp = lnp.concatenate([lnp_array] * 2) - onp_on_lnp = onp.concatenate([lnp_array] * 2) + onp_on_onp = onp.concatenate([onp_array] * 2) + onp_on_lnp = onp.concatenate([lnp_array] * 2) + lnp_on_lnp = lnp.concatenate([lnp_array] * 2) self.assertEqual(type(onp_on_lnp), type(lnp_on_lnp)) self.assertAllClose(onp_on_lnp, onp_on_onp, check_dtypes=True) @@ -1463,6 +1466,22 @@ def testArrayFunction(self): self.assertEqual(type(onp_on_mixed), type(lnp_on_lnp)) self.assertAllClose(onp_on_mixed, onp_on_onp, check_dtypes=True) + def testArrayFunctionSubModule(self): + if not array_function_overrides_enabled: + self.skipTest('__array_function__ overrides not enabled') + onp_array = onp.sqrt(onp.array([[1, 2], [2, 1]])) + lnp_array = lnp.sqrt(onp.array([[1, 2], [2, 1]])) + + onp_on_onp = onp.linalg.inv(onp_array) + onp_on_lnp = onp.linalg.inv(lnp_array) + lnp_on_lnp = lnp.linalg.inv(lnp_array) + self.assertEqual(type(onp_on_lnp), type(lnp_on_lnp)) + self.assertAllClose(onp_on_lnp, onp_on_onp, check_dtypes=True) + + with self.assertRaisesRegexp(NotImplementedError, 'repack_fields'): + # An arbitrary choice from the long tail of NumPy's API + onp.lib.recfunctions.repack_fields(lnp_array) + if __name__ == "__main__": absltest.main() From 42a51ea3dd307749a5d43646da2493e369105f67 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Sat, 13 Apr 2019 13:14:05 -0700 Subject: [PATCH 04/30] Less nesting in __array_function__ --- jax/numpy/lax_numpy.py | 47 +++++++++++++++++++++--------------------- 1 file changed, 24 insertions(+), 23 deletions(-) diff --git a/jax/numpy/lax_numpy.py b/jax/numpy/lax_numpy.py index c3849175833d..5fcf0625e93a 100644 --- a/jax/numpy/lax_numpy.py +++ b/jax/numpy/lax_numpy.py @@ -2325,32 +2325,33 @@ def __array_function__(self, func, types, args, kwargs): try: module = getattr(module, submodule) except AttributeError: - lax_func = _not_implemented(func) + lax_func = None break else: lax_func = getattr(module, func.__name__, None) - if lax_func is None: - lax_func = _not_implemented(func) - elif lax_func is func: - # TODO(shoyer): remove these special cases once we've settled on a - # protocol for using original NumPy implementations upstream and it's - # made it into a NumPy release (see - # https://github.com/numpy/numpy/pull/13305). These could be replaced by - # something like: - # lax_func = func.__numpy_implementation__ - if func is onp.iscomplexobj: - # This matchs NumPy's original implementation - return issubclass(args[0].dtype.type, onp.complexfloating) - elif func is onp.result_type: - return onp.result_type(*[getattr(x, 'dtype', x) for x in args]) - elif func is onp.shape: - return args[0].shape - elif func is onp.ndim: - return args[0].ndim - elif func is onp.size: - return args[0].size - else: - raise AssertionError('{} needs an override'.format(func)) + + if lax_func is None: + lax_func = _not_implemented(func) + elif lax_func is func: + # TODO(shoyer): remove these special cases once we've settled on a + # protocol for using original NumPy implementations upstream and it's + # made it into a NumPy release (see + # https://github.com/numpy/numpy/pull/13305). These could be replaced by + # something like: + # lax_func = func.__numpy_implementation__ + if func is onp.iscomplexobj: + # This matchs NumPy's original implementation + return issubclass(args[0].dtype.type, onp.complexfloating) + elif func is onp.result_type: + return onp.result_type(*[getattr(x, 'dtype', x) for x in args]) + elif func is onp.shape: + return args[0].shape + elif func is onp.ndim: + return args[0].ndim + elif func is onp.size: + return args[0].size + else: + raise AssertionError('{} needs an override'.format(func)) return lax_func(*args, **kwargs) From 565001119cf128786941d8acb5a97ef2992de839 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Sat, 13 Apr 2019 14:07:03 -0700 Subject: [PATCH 05/30] Add __array_ufunc__ and __array_function__ to Tracer --- jax/core.py | 10 ++++++++++ jax/numpy/lax_numpy.py | 2 +- 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/jax/core.py b/jax/core.py index dd36918667ab..f5c4c0c1a6d6 100644 --- a/jax/core.py +++ b/jax/core.py @@ -198,9 +198,19 @@ class Tracer(object): __slots__ = ['trace'] def __array__(self): + # TODO(shoyer): update this error message, once __array_function__ support + # can be taken for granted. raise Exception("Tracer can't be used with raw numpy functions. " "You might have\n import numpy as np\ninstead of\n import jax.numpy as np") + def __array_ufunc__(self, *args, **kwargs): + from .numpy.lax_numpy import __array_ufunc__ + return __array_ufunc__(self, *args, **kwargs) + + def __array_function__(self, *args, **kwargs): + from .numpy.lax_numpy import __array_function__ + return __array_function__(self, *args, **kwargs) + def __init__(self, trace): self.trace = trace diff --git a/jax/numpy/lax_numpy.py b/jax/numpy/lax_numpy.py index 5fcf0625e93a..b5c19b2a76a0 100644 --- a/jax/numpy/lax_numpy.py +++ b/jax/numpy/lax_numpy.py @@ -2296,7 +2296,7 @@ def _swap_args(f): # Override NumPy's public API. -_HANDLED_TYPES = (DeviceArray, onp.ndarray, numbers.Number) +_HANDLED_TYPES = (DeviceArray, core.Tracer, onp.ndarray, numbers.Number) def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): From 0c9e50c3d21e575af2c81709e59b0cc554cada69 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Sat, 13 Apr 2019 14:11:40 -0700 Subject: [PATCH 06/30] Move __array_ufunc__ and __array_function__ to UnshapedArray --- jax/abstract_arrays.py | 8 ++++++++ jax/core.py | 8 -------- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/jax/abstract_arrays.py b/jax/abstract_arrays.py index 0233ee483986..13913be289d5 100644 --- a/jax/abstract_arrays.py +++ b/jax/abstract_arrays.py @@ -76,6 +76,14 @@ def join(self, other): def str_short(self): return onp.dtype(self.dtype).name + def __array_ufunc__(self, *args, **kwargs): + from .numpy.lax_numpy import __array_ufunc__ + return __array_ufunc__(self, *args, **kwargs) + + def __array_function__(self, *args, **kwargs): + from .numpy.lax_numpy import __array_function__ + return __array_function__(self, *args, **kwargs) + class ShapedArray(UnshapedArray): __slots__ = ['shape'] diff --git a/jax/core.py b/jax/core.py index f5c4c0c1a6d6..8f710599c47c 100644 --- a/jax/core.py +++ b/jax/core.py @@ -203,14 +203,6 @@ def __array__(self): raise Exception("Tracer can't be used with raw numpy functions. " "You might have\n import numpy as np\ninstead of\n import jax.numpy as np") - def __array_ufunc__(self, *args, **kwargs): - from .numpy.lax_numpy import __array_ufunc__ - return __array_ufunc__(self, *args, **kwargs) - - def __array_function__(self, *args, **kwargs): - from .numpy.lax_numpy import __array_function__ - return __array_function__(self, *args, **kwargs) - def __init__(self, trace): self.trace = trace From f8387214f6fe3bd43b6b113f1517295eee230a80 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Sat, 13 Apr 2019 14:47:50 -0700 Subject: [PATCH 07/30] Fixes to support the neural net example from the readme --- jax/abstract_arrays.py | 8 -------- jax/core.py | 11 +++++++++++ jax/numpy/lax_numpy.py | 27 +++++++++++++++++++++++---- 3 files changed, 34 insertions(+), 12 deletions(-) diff --git a/jax/abstract_arrays.py b/jax/abstract_arrays.py index 13913be289d5..0233ee483986 100644 --- a/jax/abstract_arrays.py +++ b/jax/abstract_arrays.py @@ -76,14 +76,6 @@ def join(self, other): def str_short(self): return onp.dtype(self.dtype).name - def __array_ufunc__(self, *args, **kwargs): - from .numpy.lax_numpy import __array_ufunc__ - return __array_ufunc__(self, *args, **kwargs) - - def __array_function__(self, *args, **kwargs): - from .numpy.lax_numpy import __array_function__ - return __array_function__(self, *args, **kwargs) - class ShapedArray(UnshapedArray): __slots__ = ['shape'] diff --git a/jax/core.py b/jax/core.py index 8f710599c47c..619c391cfb98 100644 --- a/jax/core.py +++ b/jax/core.py @@ -263,6 +263,17 @@ def __complex__(self): return self.aval._complex(self) def __hex__(self): return self.aval._hex(self) def __oct__(self): return self.aval._oct(self) + # Like Python's checks for arithmetic special methods, NumPy doesn't check + # instance attributes when looking up __array_ufunc__ and __array_function__, + # so these need to be defined here rather than on UnshapedArray. + def __array_ufunc__(self, *args, **kwargs): + from .numpy.lax_numpy import __array_ufunc__ + return __array_ufunc__(self, *args, **kwargs) + + def __array_function__(self, *args, **kwargs): + from .numpy.lax_numpy import __array_function__ + return __array_function__(self, *args, **kwargs) + def __setitem__(self, idx, val): raise TypeError("JAX 'Tracer' objects do not support item assignment") diff --git a/jax/numpy/lax_numpy.py b/jax/numpy/lax_numpy.py index b5c19b2a76a0..b451103d0c27 100644 --- a/jax/numpy/lax_numpy.py +++ b/jax/numpy/lax_numpy.py @@ -688,7 +688,8 @@ def isclose(a, b, rtol=1e-05, atol=1e-08): else: return lax.eq(a, b) -numpy_version = tuple(map(int, onp.version.version.split('.'))) +# ignore suffixes in NumPy's development version, e.g., 1.17.0.dev0+cc8b978 +numpy_version = tuple(map(int, onp.__version__.split('+')[0].split('.')[:3])) if numpy_version < (1, 14): # see discussion at https://github.com/numpy/numpy/pull/9720 def _maybe_numpy_1_13_isclose_behavior(a, out): @@ -2301,11 +2302,29 @@ def _swap_args(f): def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): """Override NumPy ufuncs, per NEP-13.""" - if method != '__call__' or kwargs or not _all( - isinstance(x, _HANDLED_TYPES) for x in inputs - ): + + if not _all(isinstance(x, _HANDLED_TYPES) for x in inputs): return NotImplemented + # We special case the 'out' argument so we can support in-place arithmetic + # that assigns to NumPy arrays, e.g., np_array += jax_device_array + out = kwargs.pop('out', None) + if out is not None: + if isinstance(out, (DeviceArray, core.Tracer)): + raise TypeError("JAX arrays cannot be modified inplace.") + inputs = tuple(map(onp.asarray, inputs)) + return getattr(ufunc, method)(*inputs, out=out, **kwargs) + + if method != '__call__': + raise NotImplementedError( + 'JAX does not yet support the {!r} method of NumPy universal functions' + .format(method)) + + if kwargs: + raise NotImplementedError( + 'JAX does not yet support keyword arguments on NumPy universal ' + 'functions: {}'.format(list(kwargs))) + lax_func = globals().get(ufunc.__name__) if lax_func is None: lax_func = _not_implemented(func) From 3bfe51b6057bee5aff4c7d3680f91a2246c1cca3 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Mon, 15 Apr 2019 21:39:56 -0700 Subject: [PATCH 08/30] Fix API test --- tests/api_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/api_test.py b/tests/api_test.py index d2fee437f5c8..cb3666dfc0c7 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -166,7 +166,7 @@ def test_grad_nonscalar_output(self): def test_unwrapped_numpy(self): def f(x): - return onp.exp(x) + return onp.array(x) jtu.check_raises(lambda: grad(f)(onp.zeros(3)), Exception, "Tracer can't be used with raw numpy functions. " From 10067a90ab0bfb74458ab54c45696b3a2426db75 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Thu, 25 Apr 2019 22:42:23 -0700 Subject: [PATCH 09/30] Fix failing test with inplace arithmetic --- jax/numpy/lax_numpy.py | 2 +- tests/lax_numpy_test.py | 12 +++++++++++- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/jax/numpy/lax_numpy.py b/jax/numpy/lax_numpy.py index b451103d0c27..5c76084485f5 100644 --- a/jax/numpy/lax_numpy.py +++ b/jax/numpy/lax_numpy.py @@ -2310,7 +2310,7 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): # that assigns to NumPy arrays, e.g., np_array += jax_device_array out = kwargs.pop('out', None) if out is not None: - if isinstance(out, (DeviceArray, core.Tracer)): + if _any(isinstance(o, (DeviceArray, core.Tracer)) for o in out): raise TypeError("JAX arrays cannot be modified inplace.") inputs = tuple(map(onp.asarray, inputs)) return getattr(ufunc, method)(*inputs, out=out, **kwargs) diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index 848f09d3ef01..0a2df68d1158 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -34,6 +34,7 @@ from jax import lax from jax import numpy as lnp from jax import test_util as jtu +from jax.numpy.lax_numpy import numpy_version from jax.config import config config.parse_flags_with_absl() @@ -227,7 +228,6 @@ def op_record(name, nargs, dtypes, shapes, rng, diff_modes, test_name=None, # TODO(mattjj): lshift, rshift ] -numpy_version = tuple(map(int, onp.version.version.split('.'))) if numpy_version >= (1, 15): JAX_COMPOUND_OP_RECORDS += [ op_record("gcd", 2, int_dtypes, all_shapes, jtu.rand_default(), []), @@ -1184,6 +1184,16 @@ def testRot90(self, shape, dtype, k, axes, rng): # TODO(mattjj): test infix operator overrides + @parameterized.named_parameters( + ('_numpy_on_lax', onp, lnp), + ('_lax_on_numpy', lnp, onp), + ('_lax_on_lax', lnp, lnp), + ) + def testInplaceArithmetic(self, inplace_mod, other_mod): + x = inplace_mod.zeros(3) + x += other_mod.arange(3) + self.assertAllClose(x, onp.arange(3), check_dtypes=True) + def testRavel(self): rng = onp.random.RandomState(0) args_maker = lambda: [rng.randn(3, 4).astype("float32")] From f7c2c27e940057eef8410e4dfdb9c4111e2581d3 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Fri, 26 Apr 2019 11:10:12 -0700 Subject: [PATCH 10/30] Fix for CheckChiSquared --- tests/random_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/random_test.py b/tests/random_test.py index faadba2b48c2..57a9b097804e 100644 --- a/tests/random_test.py +++ b/tests/random_test.py @@ -54,7 +54,7 @@ def _CheckKolmogorovSmirnovCDF(self, samples, cdf): def _CheckChiSquared(self, samples, pmf): alpha = 0.01 # significance level, threshold for p-value - values, actual_freq = onp.unique(samples, return_counts=True) + values, actual_freq = onp.unique(onp.asarray(samples), return_counts=True) expected_freq = pmf(values) * len(values) _, p_value = scipy.stats.chisquare(actual_freq, expected_freq) self.assertLess(p_value, alpha) From d8d76298f4795ae5330c95ced4fbdb6a5974b9c4 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Mon, 29 Jul 2019 21:15:39 -0700 Subject: [PATCH 11/30] Ensure __array_function__ is only used by numpy --- jax/numpy/lax_numpy.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/jax/numpy/lax_numpy.py b/jax/numpy/lax_numpy.py index 6feede8d830c..41f00d25e709 100644 --- a/jax/numpy/lax_numpy.py +++ b/jax/numpy/lax_numpy.py @@ -2862,8 +2862,12 @@ def __array_function__(self, func, types, args, kwargs): if not _all(issubclass(t, _HANDLED_TYPES) for t in types): return NotImplemented + modules = func.__module__.split('.') + if modules[0] != 'numpy': + return NotImplemented + module = numpy - for submodule in func.__module__.split('.')[1:]: + for submodule in modules[1:]: try: module = getattr(module, submodule) except AttributeError: From 903c50cdf9e9752685e93998f7028aa38fe97864 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Tue, 30 Jul 2019 22:58:26 -0700 Subject: [PATCH 12/30] Test fixes --- jax/numpy/lax_numpy.py | 9 ++------- jax/test_util.py | 5 ++++- tests/lax_numpy_test.py | 2 +- 3 files changed, 7 insertions(+), 9 deletions(-) diff --git a/jax/numpy/lax_numpy.py b/jax/numpy/lax_numpy.py index 41f00d25e709..bf9fc3c9f359 100644 --- a/jax/numpy/lax_numpy.py +++ b/jax/numpy/lax_numpy.py @@ -2879,17 +2879,12 @@ def __array_function__(self, func, types, args, kwargs): if lax_func is None: lax_func = _not_implemented(func) elif lax_func is func: - # TODO(shoyer): remove these special cases once we've settled on a - # protocol for using original NumPy implementations upstream and it's - # made it into a NumPy release (see - # https://github.com/numpy/numpy/pull/13305). These could be replaced by - # something like: - # lax_func = func.__numpy_implementation__ if func is onp.iscomplexobj: # This matchs NumPy's original implementation return issubclass(args[0].dtype.type, onp.complexfloating) elif func is onp.result_type: - return onp.result_type(*[getattr(x, 'dtype', x) for x in args]) + return onp.result_type( + *[x.dtype if isinstance(x, ndarray) else x for x in args]) elif func is onp.shape: return args[0].shape elif func is onp.ndim: diff --git a/jax/test_util.py b/jax/test_util.py index da90f076b246..a8a03b52db9b 100644 --- a/jax/test_util.py +++ b/jax/test_util.py @@ -31,6 +31,7 @@ from six.moves import xrange from . import api +from . import numpy as np from .config import flags from .util import partial from .tree_util import tree_multimap, tree_all, tree_map, tree_reduce @@ -558,7 +559,9 @@ def wrapped_fun(*args): def _CheckAgainstNumpy(self, numpy_reference_op, lax_op, args_maker, check_dtypes=False, tol=1e-5): args = args_maker() - numpy_ans = numpy_reference_op(*args) + np_args = tuple(onp.array(arg) if isinstance(arg, np.ndarray) else arg + for arg in args) + numpy_ans = numpy_reference_op(*np_args) lax_ans = lax_op(*args) self.assertAllClose(lax_ans, numpy_ans, check_dtypes=check_dtypes, atol=tol, rtol=tol) diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index b17cc60afe03..2c850b25e288 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -272,7 +272,7 @@ def __array__(self, *args, **kwargs): try: array_function_overrides_enabled = onp.atleast_1d(_Dummy()) -except TypeError: +except Exception: array_function_overrides_enabled = False From c5215b828de0bb3a39dc27ae7d341e2c51102e25 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Thu, 1 Aug 2019 22:03:05 -0700 Subject: [PATCH 13/30] tweak --- jax/core.py | 10 ++++++---- jax/numpy/lax_numpy.py | 8 +++++--- 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/jax/core.py b/jax/core.py index 7dfb28385eea..73204a687a55 100644 --- a/jax/core.py +++ b/jax/core.py @@ -284,10 +284,12 @@ class Tracer(object): __slots__ = ['trace'] def __array__(self): - # TODO(shoyer): update this error message, once __array_function__ support - # can be taken for granted. - raise Exception("Tracer can't be used with raw numpy functions. " - "You might have\n import numpy as np\ninstead of\n import jax.numpy as np") + from .numpy.lax_numpy import numpy_version + msg = ("Tracer can't be used with functions that convert their arguments " + "into raw NumPy arrays.") + if numpy_version < (1, 17): + msg += " You might have\n import numpy as np\ninstead of\n import jax.numpy as np" + raise Exception(msg) def __init__(self, trace): self.trace = trace diff --git a/jax/numpy/lax_numpy.py b/jax/numpy/lax_numpy.py index b9ccd95596a6..4103f3291ca5 100644 --- a/jax/numpy/lax_numpy.py +++ b/jax/numpy/lax_numpy.py @@ -2854,8 +2854,8 @@ def _unimplemented_setitem(self, i, x): # Override NumPy's public API. - -_HANDLED_TYPES = (DeviceArray, core.Tracer, onp.ndarray, numbers.Number) +_ARRAY_TYPES = DeviceArray, core.Tracer +_HANDLED_TYPES = _ARRAY_TYPES + (onp.ndarray, numbers.Number) def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): @@ -2914,12 +2914,14 @@ def __array_function__(self, func, types, args, kwargs): if lax_func is None: lax_func = _not_implemented(func) elif lax_func is func: + # Implementations of NumPy functions that work if at least one array + # argument is a JAX array. if func is onp.iscomplexobj: # This matchs NumPy's original implementation return issubclass(args[0].dtype.type, onp.complexfloating) elif func is onp.result_type: return onp.result_type( - *[x.dtype if isinstance(x, ndarray) else x for x in args]) + *[x.dtype if isinstance(x, _ARRAY_TYPES) else x for x in args]) elif func is onp.shape: return args[0].shape elif func is onp.ndim: From 45a6723b7491e9eba9e74f55fa9ebf246bd3e8f1 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Sat, 3 Aug 2019 13:27:56 -0700 Subject: [PATCH 14/30] Explicit wrappers for aliases to numpy functions --- jax/numpy/lax_numpy.py | 97 ++++++++++++++++++++++++++++++------------ 1 file changed, 70 insertions(+), 27 deletions(-) diff --git a/jax/numpy/lax_numpy.py b/jax/numpy/lax_numpy.py index 4103f3291ca5..5cf84d361f7d 100644 --- a/jax/numpy/lax_numpy.py +++ b/jax/numpy/lax_numpy.py @@ -28,6 +28,7 @@ from __future__ import print_function import collections +import functools import itertools import numbers import re @@ -100,13 +101,7 @@ def __init__(shape, dtype=None, buffer=None, offset=0, strides=None, " Use jax.numpy.array, or jax.numpy.zeros instead.") # pylint: enable=invalid-name - isscalar = onp.isscalar -iscomplexobj = onp.iscomplexobj -result_type = onp.result_type -shape = _shape = onp.shape -ndim = _ndim = onp.ndim -size = onp.size _dtype = lax.dtype bool_ = onp.bool_ @@ -143,13 +138,74 @@ def __init__(shape, dtype=None, buffer=None, offset=0, strides=None, ComplexWarning = onp.ComplexWarning -array_str = onp.array_str -array_repr = onp.array_repr - -save = onp.save -savez = onp.savez load = onp.load + +# These wrappers are necessary to avoid infinite recursion inside JAX's +# __array_function__ method. The implementations below exactly match NumPy. + +@functools.wraps(onp.iscomplexobj) +def iscomplexobj(x): + if isinstance(x, _ARRAY_TYPES): + return issubclass(x.dtype.type, onp.complexfloating) + return onp.iscomplexobj(x) + +@functools.wraps(onp.result_type) +def result_type(*arrays_and_dtypes): + arrays_and_dtypes = tuple( + x.dtype if isinstance(x, _ARRAY_TYPES) else x + for x in arrays_and_dtypes) + return onp.result_type(*arrays_and_dtypes) + +@functools.wraps(onp.shape) +def shape(a): + if isinstance(a, _ARRAY_TYPES): + return a.shape + return onp.shape(a) + +@functools.wraps(onp.ndim) +def ndim(a): + if isinstance(a, _ARRAY_TYPES): + return a.ndim + return onp.ndim(a) + +_shape = shape +_ndim = ndim + +@functools.wraps(onp.size) +def size(a, axis=None): + if isinstance(a, _ARRAY_TYPES): + if axis is None: + return a.size + else: + return a.shape[axis] + return onp.size(a, axis) + +@functools.wraps(onp.array_str) +def array_str(a, max_line_width=None, precision=None, suppress_small=None): + return onp.array_str(onp.asarray(a), max_line_width=max_line_width, + precision=precision, suppress_small=suppress_small) + +@functools.wraps(onp.array_repr) +def array_repr(arr, max_line_width=None, precision=None, suppress_small=None): + return onp.array_repr(onp.asarray(arr), max_line_width=max_line_width, + precision=precision, suppress_small=suppress_small) + +@functools.wraps(onp.save) +def save(file, arr, allow_pickle=True, fix_imports=True): + return onp.save(file, onp.asarray(arr), allow_pickle=allow_pickle, + fix_imports=fix_imports) + +def _cast_if_needed(x): + return onp.asarray(x) if isinstance(x, _ARRAY_TYPES) else x + +@functools.wraps(onp.savez) +def savez(file, *args, **kwds): + args = tuple(map(_cast_if_needed, args)) + kwds = dict(zip(kwds.keys(), map(_cast_if_needed, kwds.values()))) + return onp.savez(file, *args, **kwds) + + ### utility functions @@ -2914,22 +2970,9 @@ def __array_function__(self, func, types, args, kwargs): if lax_func is None: lax_func = _not_implemented(func) elif lax_func is func: - # Implementations of NumPy functions that work if at least one array - # argument is a JAX array. - if func is onp.iscomplexobj: - # This matchs NumPy's original implementation - return issubclass(args[0].dtype.type, onp.complexfloating) - elif func is onp.result_type: - return onp.result_type( - *[x.dtype if isinstance(x, _ARRAY_TYPES) else x for x in args]) - elif func is onp.shape: - return args[0].shape - elif func is onp.ndim: - return args[0].ndim - elif func is onp.size: - return args[0].size - else: - raise AssertionError('{} needs an override'.format(func)) + raise AssertionError('{} needs to be defined with a wrapper that does not ' + 'call the corresponding NumPy function directly on ' + 'JAX arrays'.format(func)) return lax_func(*args, **kwargs) From 951bd5b6ca349d9aa48fa7d7cba7dd70c487907d Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Sat, 3 Aug 2019 16:35:43 -0700 Subject: [PATCH 15/30] Warn instead of erroring for unimplemented NumPy functions --- jax/numpy/fft.py | 4 +-- jax/numpy/lax_numpy.py | 71 ++++++++++++++++++++++++----------------- jax/numpy/linalg.py | 4 +-- tests/lax_numpy_test.py | 51 ++++++++++++++++++++++++----- 4 files changed, 89 insertions(+), 41 deletions(-) diff --git a/jax/numpy/fft.py b/jax/numpy/fft.py index 8eb9392a8f01..718fc6f0a21e 100644 --- a/jax/numpy/fft.py +++ b/jax/numpy/fft.py @@ -21,7 +21,7 @@ from .. import lax from ..lib.xla_bridge import xla_client, canonicalize_dtype from ..util import get_module_functions -from .lax_numpy import _not_implemented +from .lax_numpy import _NotImplementedByJAX from .lax_numpy import _wraps from . import lax_numpy as np @@ -71,4 +71,4 @@ def fftn(a, s=None, axes=None, norm=None): for func in get_module_functions(onp.fft): if func.__name__ not in globals(): - globals()[func.__name__] = _not_implemented(func) + globals()[func.__name__] = _NotImplementedByJAX(func) diff --git a/jax/numpy/lax_numpy.py b/jax/numpy/lax_numpy.py index 5cf84d361f7d..9703118efc49 100644 --- a/jax/numpy/lax_numpy.py +++ b/jax/numpy/lax_numpy.py @@ -41,7 +41,7 @@ import six from six.moves import builtins, xrange -from jax import jit, device_put +from jax import jit, device_get, device_put from .. import core from ..abstract_arrays import UnshapedArray, ShapedArray, ConcreteArray from ..interpreters.xla import DeviceArray @@ -2796,20 +2796,24 @@ def median(a, axis=None, out=None, overwrite_input=False, keepdims=False): ### track unimplemented functions -def _not_implemented(fun): - @_wraps(fun) - def wrapped(*args, **kwargs): - try: - name = "'{}.{}'".format(fun.__module__, fun.__name__) - except AttributeError: - name = "NumPy function {}".format(fun) +def _numpy_function_name(fun): + try: + name = "'{}.{}'".format(fun.__module__, fun.__name__) + except AttributeError: + name = "NumPy function {}".format(fun) + return name + +class _NotImplementedByJAX(object): + def __init__(self, fun): + self._fun = fun + def __call__(self, *args, **kwargs): + name = _numpy_function_name(fun) raise NotImplementedError("{} is not yet implemented by JAX".format(name)) - return wrapped # Build a set of all unimplemented NumPy functions. for func in get_module_functions(onp): if func.__name__ not in globals(): - globals()[func.__name__] = _not_implemented(func) + globals()[func.__name__] = _NotImplementedByJAX(func) ### add method and operator overloads to arraylike classes @@ -2910,9 +2914,17 @@ def _unimplemented_setitem(self, i, x): # Override NumPy's public API. -_ARRAY_TYPES = DeviceArray, core.Tracer +_ARRAY_TYPES = (DeviceArray, core.Tracer) _HANDLED_TYPES = _ARRAY_TYPES + (onp.ndarray, numbers.Number) +def _implement_with_warning(func, args, kwargs, name, stacklevel=2): + warnings.warn("{} is not yet implemented by JAX; coercing arguments to " + "NumPy arrays. Coerce arrays with onp.asarray() or " + "jax.device_get() to explicitly silence this warning." + .format(name), stacklevel=stacklevel) + args, kwargs = device_get((args, kwargs)) + return func(*args, **kwargs) + def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): """Override NumPy ufuncs, per NEP-13.""" @@ -2920,29 +2932,29 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): if not _all(isinstance(x, _HANDLED_TYPES) for x in inputs): return NotImplemented + ufunc_method = getattr(ufunc, method) + + lax_func = globals().get(ufunc.__name__) + if lax_func is None or isinstance(lax_func, _NotImplementedByJAX): + return _implement_with_warning(ufunc_method, inputs, kwargs, ufunc) + + if method != '__call__': + name = '{!r} method of NumPy universal functions'.format(method) + return _implement_with_warning(ufunc_method, inputs, kwargs, name) + # We special case the 'out' argument so we can support in-place arithmetic # that assigns to NumPy arrays, e.g., np_array += jax_device_array out = kwargs.pop('out', None) if out is not None: - if _any(isinstance(o, (DeviceArray, core.Tracer)) for o in out): + if _any(isinstance(o, _ARRAY_TYPES) for o in out): raise TypeError("JAX arrays cannot be modified inplace.") - inputs = tuple(map(onp.asarray, inputs)) - return getattr(ufunc, method)(*inputs, out=out, **kwargs) - - if method != '__call__': - raise NotImplementedError( - 'JAX does not yet support the {!r} method of NumPy universal functions' - .format(method)) + inputs = device_get(inputs) + return ufunc_method(*inputs, out=out, **kwargs) if kwargs: - raise NotImplementedError( - 'JAX does not yet support keyword arguments on NumPy universal ' - 'functions: {}'.format(list(kwargs))) - - lax_func = globals().get(ufunc.__name__) - if lax_func is None: - lax_func = _not_implemented(func) - + kwargs_str = ', '.join(map(str, kwargs)) + warnings.warn("{} keyword argument(s) are ignored when called on JAX " + "arrays: {}".format(ufunc, kwargs_str), stacklevel=2) return lax_func(*inputs) @@ -2967,8 +2979,9 @@ def __array_function__(self, func, types, args, kwargs): else: lax_func = getattr(module, func.__name__, None) - if lax_func is None: - lax_func = _not_implemented(func) + if lax_func is None or isinstance(lax_func, _NotImplementedByJAX): + return _implement_with_warning( + func, args, kwargs, name=_numpy_function_name(func), stacklevel=4) elif lax_func is func: raise AssertionError('{} needs to be defined with a wrapper that does not ' 'call the corresponding NumPy function directly on ' diff --git a/jax/numpy/linalg.py b/jax/numpy/linalg.py index f7b052121241..1ce7b8c20bbb 100644 --- a/jax/numpy/linalg.py +++ b/jax/numpy/linalg.py @@ -22,7 +22,7 @@ from .. import lax from .. import lax_linalg -from .lax_numpy import _not_implemented +from .lax_numpy import _NotImplementedByJAX from .lax_numpy import _wraps from . import lax_numpy as np from ..util import get_module_functions @@ -255,4 +255,4 @@ def solve(a, b): for func in get_module_functions(onp.linalg): if func.__name__ not in globals(): - globals()[func.__name__] = _not_implemented(func) + globals()[func.__name__] = _NotImplementedByJAX(func) diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index 42386bbe0adb..eefd09599875 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -1701,13 +1701,26 @@ def testArrayUfuncUnary(self): self.assertEqual(type(onp_on_lnp), type(lnp_on_lnp)) self.assertAllClose(onp_on_lnp, onp_on_onp, check_dtypes=True) - def testArrayUfuncErrors(self): + def testArrayUfuncWarningsAndErrors(self): x = lnp.sqrt(onp.array([1, 2])) # make a DeviceArray object - with self.assertRaises(TypeError): + + with self.assertWarnsRegex(UserWarning, 'not yet implemented by JAX'): + y = onp.add.reduce(x) + self.assertTrue(onp.isscalar(y)) + self.assertAllClose(y, onp.add.reduce(onp.asarray(x)), check_dtypes=True) + + with self.assertRaisesRegex(TypeError, "cannot be modified inplace"): onp.sin(x, out=x) - with self.assertRaises(NotImplementedError): + + with self.assertWarnsRegex(UserWarning, 'not yet implemented by JAX'): # JAX is unlikely to ever implement datetime64 ufuncs - onp.isnat(x) + with self.assertRaises(TypeError): + onp.isnat(x) # x has the wrong dtpe + + with self.assertWarnsRegex( + UserWarning, r'keyword argument\(s\) are ignored'): + y = onp.add(x, x, dtype=onp.float32) + self.assertAllClose(y, x + x, check_dtypes=True) def testArrayUfuncBinary(self): x = lnp.sqrt(onp.array([1, 2])) # make a DeviceArray object @@ -1726,11 +1739,20 @@ def testArrayUfuncBinary(self): self.assertEqual(type(result), type(lnp_expected)) self.assertAllClose(result, onp_expected, check_dtypes=True) + def testArrayUfuncUnhandledType(self): + class Other(object): + def __array_ufunc__(self, *args, **kwargs): + return 'success' + + x = lnp.sqrt(onp.array([1, 2])) # make a DeviceArray object + result = onp.add(x, Other()) + self.assertEqual(result, 'success') + def testArrayFunction(self): if not array_function_overrides_enabled: self.skipTest('__array_function__ overrides not enabled') + onp_array = onp.sqrt(onp.array([1, 2])) lnp_array = lnp.sqrt(onp.array([1, 2])) # make a DeviceArray object - onp_array = onp.sqrt(onp.array([1, 2])) # make a DeviceArray object onp_on_onp = onp.concatenate([onp_array] * 2) onp_on_lnp = onp.concatenate([lnp_array] * 2) @@ -1754,9 +1776,22 @@ def testArrayFunctionSubModule(self): self.assertEqual(type(onp_on_lnp), type(lnp_on_lnp)) self.assertAllClose(onp_on_lnp, onp_on_onp, check_dtypes=True) - with self.assertRaisesRegexp(NotImplementedError, 'repack_fields'): - # An arbitrary choice from the long tail of NumPy's API - onp.lib.recfunctions.repack_fields(lnp_array) + def testArrayFunctionUnimplemented(self): + if not array_function_overrides_enabled: + self.skipTest('__array_function__ overrides not enabled') + lnp_array = lnp.sqrt(onp.arange(4)) + with self.assertWarnsRegex(UserWarning, 'not yet implemented by JAX'): + actual = onp.unique(lnp_array) + self.assertAllClose(actual, lnp_array, check_dtypes=False) + + def testArrayFunctionUnhandledType(self): + class Other(object): + def __array_function__(self, *args, **kwargs): + return 'success' + + x = lnp.sqrt(onp.array([1, 2])) # make a DeviceArray object + result = onp.concatenate([x, Other()]) + self.assertEqual(result, 'success') def testIssue830(self): a = lnp.arange(4, dtype=lnp.complex64) From feff162621f2daa75ed890bb5d9cf28df1dfbdd9 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Sun, 15 Sep 2019 17:18:51 -0700 Subject: [PATCH 16/30] cleanup --- jax/numpy/lax_numpy.py | 9 ++++++--- tests/api_test.py | 5 ++--- tests/lax_numpy_test.py | 20 ++++++++++---------- 3 files changed, 18 insertions(+), 16 deletions(-) diff --git a/jax/numpy/lax_numpy.py b/jax/numpy/lax_numpy.py index ab3ca5344987..92479023a03c 100644 --- a/jax/numpy/lax_numpy.py +++ b/jax/numpy/lax_numpy.py @@ -43,7 +43,7 @@ import six from six.moves import builtins, xrange -from jax import jit, device_get, device_put +from jax import jit, device_get, device_put, custom_transforms, defjvp from .. import core from ..abstract_arrays import UnshapedArray, ShapedArray, ConcreteArray from ..config import flags @@ -3130,13 +3130,16 @@ def _unimplemented_setitem(self, i, x): _ARRAY_TYPES = (DeviceArray, core.Tracer) _HANDLED_TYPES = _ARRAY_TYPES + (onp.ndarray, numbers.Number) +def _implement_via_coercion(func, args, kwargs): + args, kwargs = device_get((args, kwargs)) + return func(*args, **kwargs) + def _implement_with_warning(func, args, kwargs, name, stacklevel=2): warnings.warn("{} is not yet implemented by JAX; coercing arguments to " "NumPy arrays. Coerce arrays with onp.asarray() or " "jax.device_get() to explicitly silence this warning." .format(name), stacklevel=stacklevel) - args, kwargs = device_get((args, kwargs)) - return func(*args, **kwargs) + return _implement_via_coercion(func, args, kwargs) def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): diff --git a/tests/api_test.py b/tests/api_test.py index 2a267b6ed495..bc77977fb782 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -180,9 +180,8 @@ def f(x): return onp.array(x) jtu.check_raises(lambda: grad(f)(onp.zeros(3)), Exception, - "Tracer can't be used with raw numpy functions. " - "You might have\n import numpy as np\ninstead of\n" - " import jax.numpy as np") + "Tracer can't be used with functions that convert their " + "arguments into raw NumPy arrays.") def test_binop_mismatch(self): def f(x, y): diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index d4b547839ea3..3d9b7c7c58bc 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -1698,8 +1698,8 @@ def testIssue728(self): self.assertEqual(0, onp.sum(lnp.eye(1050) - onp.eye(1050))) def testArrayUfuncUnary(self): - lnp_array = lnp.sqrt(onp.array([1, 2])) # make a DeviceArray object - onp_array = onp.sqrt(onp.array([1, 2])) # make a DeviceArray object + lnp_array = lnp.array([1, 2]) + onp_array = onp.array([1, 2]) lnp_on_lnp = lnp.sin(lnp_array) onp_on_lnp = onp.sin(lnp_array) onp_on_onp = onp.sin(onp_array) @@ -1707,7 +1707,7 @@ def testArrayUfuncUnary(self): self.assertAllClose(onp_on_lnp, onp_on_onp, check_dtypes=True) def testArrayUfuncWarningsAndErrors(self): - x = lnp.sqrt(onp.array([1, 2])) # make a DeviceArray object + x = lnp.array([1, 2]) with self.assertWarnsRegex(UserWarning, 'not yet implemented by JAX'): y = onp.add.reduce(x) @@ -1728,7 +1728,7 @@ def testArrayUfuncWarningsAndErrors(self): self.assertAllClose(y, x + x, check_dtypes=True) def testArrayUfuncBinary(self): - x = lnp.sqrt(onp.array([1, 2])) # make a DeviceArray object + x = lnp.array([1, 2]) lnp_expected = lnp.array(x) + 3 onp_expected = onp.array(x) + 3 @@ -1749,15 +1749,15 @@ class Other(object): def __array_ufunc__(self, *args, **kwargs): return 'success' - x = lnp.sqrt(onp.array([1, 2])) # make a DeviceArray object + x = lnp.array([1, 2]) result = onp.add(x, Other()) self.assertEqual(result, 'success') def testArrayFunction(self): if not array_function_overrides_enabled: self.skipTest('__array_function__ overrides not enabled') - onp_array = onp.sqrt(onp.array([1, 2])) - lnp_array = lnp.sqrt(onp.array([1, 2])) # make a DeviceArray object + onp_array = onp.array([1, 2]) + lnp_array = lnp.array([1, 2]) onp_on_onp = onp.concatenate([onp_array] * 2) onp_on_lnp = onp.concatenate([lnp_array] * 2) @@ -1772,8 +1772,8 @@ def testArrayFunction(self): def testArrayFunctionSubModule(self): if not array_function_overrides_enabled: self.skipTest('__array_function__ overrides not enabled') - onp_array = onp.sqrt(onp.array([[1, 2], [2, 1]])) - lnp_array = lnp.sqrt(onp.array([[1, 2], [2, 1]])) + onp_array = onp.array([[1, 2], [2, 1]]) + lnp_array = lnp.array([[1, 2], [2, 1]]) onp_on_onp = onp.linalg.inv(onp_array) onp_on_lnp = onp.linalg.inv(lnp_array) @@ -1794,7 +1794,7 @@ class Other(object): def __array_function__(self, *args, **kwargs): return 'success' - x = lnp.sqrt(onp.array([1, 2])) # make a DeviceArray object + x = lnp.array([1, 2]) result = onp.concatenate([x, Other()]) self.assertEqual(result, 'success') From 61d086694692007c60eeff219f20fcffa737fb40 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Mon, 16 Sep 2019 09:52:14 -0700 Subject: [PATCH 17/30] flags for disabling numpy overrides --- .travis.yml | 4 +- jax/api.py | 7 ++- jax/lax/lax.py | 124 +++++++++++++++++++++++----------------- jax/numpy/lax_numpy.py | 73 ++++++++++++----------- tests/lax_numpy_test.py | 35 ++++++++++++ 5 files changed, 152 insertions(+), 91 deletions(-) diff --git a/.travis.yml b/.travis.yml index ea747d17b7ad..31eed0beb78a 100644 --- a/.travis.yml +++ b/.travis.yml @@ -7,8 +7,8 @@ python: - "2.7" - "3.6" env: - - JAX_ENABLE_X64=0 JAX_NUM_GENERATED_CASES=25 NUMPY_EXPERIMENTAL_ARRAY_FUNCTION=1 - - JAX_ENABLE_X64=1 JAX_NUM_GENERATED_CASES=25 NUMPY_EXPERIMENTAL_ARRAY_FUNCTION=1 + - JAX_ENABLE_X64=0 JAX_NUM_GENERATED_CASES=25 + - JAX_ENABLE_X64=1 JAX_NUM_GENERATED_CASES=25 JAX_ENABLE_NUMPY_OVERRIDES=1 NUMPY_EXPERIMENTAL_ARRAY_FUNCTION=1 before_install: - if [[ "$TRAVIS_PYTHON_VERSION" == "2.7" ]]; then wget https://repo.continuum.io/miniconda/Miniconda2-latest-Linux-x86_64.sh -O miniconda.sh; diff --git a/jax/api.py b/jax/api.py index c039a0312539..b5640e16dc7b 100644 --- a/jax/api.py +++ b/jax/api.py @@ -1172,7 +1172,12 @@ def device_put(x, device_num=0, backend=None): def _device_get(x): if isinstance(x, core.Tracer): return x - return x.copy() + try: + copy = x.copy + except AttributeError: + return x + else: + return copy() def device_get(x): for y in tree_leaves(x): diff --git a/jax/lax/lax.py b/jax/lax/lax.py index 2b6500cf470c..e6c86f30d184 100644 --- a/jax/lax/lax.py +++ b/jax/lax/lax.py @@ -94,6 +94,20 @@ def _canonicalize_shape(shape): def _identity(x): return x + +def shape(a): + try: + return a.shape + except AttributeError: + return onp.shape(a) + +def ndim(a): + try: + return a.ndim + except AttributeError: + return onp.ndim(a) + + ### traceables def neg(x): @@ -498,9 +512,9 @@ def dot(lhs, rhs, precision=None): # TODO(b/134526360): XLA doesn't support integer dots, so we emit a sum of # products instead. if onp.issubdtype(lhs.dtype, onp.integer): - lhs_shape = onp.shape(lhs) + lhs_shape = shape(lhs) lhs_ndim = len(lhs_shape) - rhs_ndim = onp.ndim(rhs) + rhs_ndim = ndim(rhs) if rhs_ndim > 1: lhs = broadcast_in_dim(lhs, lhs_shape + (1,), tuple(range(len(lhs_shape)))) if lhs_ndim > 1: @@ -537,17 +551,17 @@ def dot_general(lhs, rhs, dimension_numbers, precision=None): lhs_contract_dims, rhs_contract_dims = contract_dims lhs_batch_dims, rhs_batch_dims = batch_dims lhs_noncontract_dims = tuple(sorted( - set(range(onp.ndim(lhs))) - set(lhs_batch_dims) - set(lhs_contract_dims))) + set(range(ndim(lhs))) - set(lhs_batch_dims) - set(lhs_contract_dims))) rhs_noncontract_dims = tuple(sorted( - set(range(onp.ndim(rhs))) - set(rhs_batch_dims) - set(rhs_contract_dims))) + set(range(ndim(rhs))) - set(rhs_batch_dims) - set(rhs_contract_dims))) lhs = transpose(lhs, lhs_batch_dims + lhs_noncontract_dims + lhs_contract_dims) rhs = transpose(rhs, rhs_batch_dims + rhs_noncontract_dims + rhs_contract_dims) new_lhs_shape = onp.insert( - onp.shape(lhs), len(lhs_batch_dims) + len(lhs_noncontract_dims), - (1,) * len(rhs_noncontract_dims)) - new_rhs_shape = onp.insert(onp.shape(rhs), len(lhs_batch_dims), + shape(lhs), len(lhs_batch_dims) + len(lhs_noncontract_dims), + (1,) * len(rhs_noncontract_dims)) + new_rhs_shape = onp.insert(shape(rhs), len(lhs_batch_dims), (1,) * len(lhs_noncontract_dims)) lhs = reshape(lhs, new_lhs_shape) rhs = reshape(rhs, new_rhs_shape) @@ -595,15 +609,15 @@ def reshape(operand, new_sizes, dimensions=None): """ new_sizes = _canonicalize_shape(new_sizes) # TODO new_sizes = tuple(new_sizes) - same_shape = onp.shape(operand) == new_sizes - same_dims = dimensions is None or tuple(dimensions) == tuple(range(onp.ndim(operand))) - if onp.shape(operand) and same_shape and same_dims: + same_shape = shape(operand) == new_sizes + same_dims = dimensions is None or tuple(dimensions) == tuple(range(ndim(operand))) + if shape(operand) and same_shape and same_dims: return operand else: return reshape_p.bind( operand, new_sizes=new_sizes, dimensions=None if same_dims else tuple(dimensions), - old_sizes=onp.shape(operand)) + old_sizes=shape(operand)) def pad(operand, padding_value, padding_config): """Wraps XLA's `Pad @@ -894,7 +908,7 @@ def _get_min_identity(dtype): def _reduce_sum(operand, axes): return reduce_sum_p.bind(operand, axes=tuple(axes), - input_shape=onp.shape(operand)) + input_shape=shape(operand)) def _reduce_prod(operand, axes): return reduce_prod_p.bind(operand, axes=tuple(axes)) @@ -1022,9 +1036,9 @@ def full(shape, fill_value, dtype=None): "`static_argnums` or applying `jit` to smaller subfunctions instead.") raise TypeError(msg) - if onp.shape(fill_value): + if _shape(fill_value): msg = "full must be called with scalar fill_value, got fill_value.shape {}." - raise TypeError(msg.format(onp.shape(fill_value))) + raise TypeError(msg.format(_shape(fill_value))) dtype = dtype or _dtype(fill_value) dtype = xla_bridge.canonicalize_dtype(dtype) @@ -1248,7 +1262,7 @@ def full_like(x, fill_value, dtype=None, shape=None): An ndarray with the same shape as `x` with its entries set equal to `fill_value`, similar to the output of np.full. """ - shape = onp.shape(x) if shape is None else _canonicalize_shape(shape) + shape = _shape(x) if shape is None else _canonicalize_shape(shape) out = full(shape, fill_value, dtype or _dtype(x)) return tie_in(x, out) @@ -1519,16 +1533,16 @@ def _brcast(x, *others): # Requires shape info during jvp tracing, which isn't strictly necessary. # We don't need full numpy broadcasting, but otherwise the logic is the same # so we reuse the broadcast_shapes function after filtering out scalars. - shapes = tuple(filter(None, map(onp.shape, (x,) + others))) + shapes = tuple(filter(None, map(_shape, (x,) + others))) shape = shapes and broadcast_shapes(*shapes) - if onp.shape(x) != shape: + if _shape(x) != shape: return _brcast_to(x, shape) else: return x def _brcast_to(x, shape): - x_shape = onp.shape(x) + x_shape = _shape(x) assert x_shape != shape if x_shape: assert len(x_shape) == len(shape) @@ -2037,51 +2051,51 @@ def require(shape_cond): return lhs.shape[:-1] + rhs.shape[:-2] + rhs.shape[-1:] def _dot_transpose_lhs(t, rhs, precision): - if onp.ndim(t) == onp.ndim(rhs) == 2: + if ndim(t) == ndim(rhs) == 2: return dot(t, transpose(rhs, (1, 0)), precision=precision) - elif onp.ndim(t) == 1 and onp.ndim(rhs) == 2: + elif ndim(t) == 1 and ndim(rhs) == 2: return dot(rhs, t, precision=precision) - elif onp.ndim(t) == onp.ndim(rhs) == 1: + elif ndim(t) == ndim(rhs) == 1: return _outer(t, rhs) - elif onp.ndim(t) == 0 or onp.ndim(rhs) == 0: + elif ndim(t) == 0 or ndim(rhs) == 0: return mul(t, rhs) else: raise TypeError def _dot_transpose_rhs(t, lhs, precision): - if onp.ndim(lhs) == onp.ndim(t) == 2: + if ndim(lhs) == ndim(t) == 2: return dot(transpose(lhs, (1, 0)), t) - elif onp.ndim(lhs) == 2 and onp.ndim(t) == 1: + elif ndim(lhs) == 2 and ndim(t) == 1: return dot(t, lhs, precision=precision) - elif onp.ndim(t) == onp.ndim(lhs) == 1: + elif ndim(t) == ndim(lhs) == 1: return _outer(lhs, t) - elif onp.ndim(t) == 0 or onp.ndim(lhs) == 0: + elif ndim(t) == 0 or ndim(lhs) == 0: return mul(t, lhs) else: raise TypeError def _outer(x, y): - assert onp.ndim(x) == onp.ndim(y) == 1 + assert ndim(x) == ndim(y) == 1 return mul(reshape(x, (x.shape[0], 1)), reshape(y, (1, y.shape[0]))) def _dot_batch_rule(batched_args, batch_dims, precision=None): lhs, rhs = batched_args lbd, rbd = batch_dims - T = lambda x: transpose(x, onp.arange(onp.ndim(x))[::-1]) + T = lambda x: transpose(x, onp.arange(ndim(x))[::-1]) # in some cases, we can call dot instead of dot_general - if max(onp.ndim(lhs), onp.ndim(rhs)) <= 2: + if max(ndim(lhs), ndim(rhs)) <= 2: if rbd is None: assert lbd in (0, 1) if lbd == 0: return dot(lhs, rhs, precision=precision), 0 else: - return dot(T(rhs), lhs, precision=precision), onp.ndim(rhs) - 1 + return dot(T(rhs), lhs, precision=precision), ndim(rhs) - 1 if lbd is None: assert rbd in (0, 1) - if rbd == onp.ndim(rhs) - 1: - return dot(lhs, rhs, precision=precision), onp.ndim(lhs) - 1 + if rbd == ndim(rhs) - 1: + return dot(lhs, rhs, precision=precision), ndim(lhs) - 1 else: return dot(rhs, T(lhs), precision=precision), 0 @@ -2099,7 +2113,7 @@ def _dot_batch_rule(batched_args, batch_dims, precision=None): else: lhs = batching.moveaxis(lhs, lbd, 0) lhs_batch = (0,) - lhs_contracting = (onp.ndim(lhs) - 1,) + lhs_contracting = (ndim(lhs) - 1,) if rbd is None: assert lbd is not None @@ -2107,7 +2121,7 @@ def _dot_batch_rule(batched_args, batch_dims, precision=None): else: rhs = batching.moveaxis(rhs, rbd, 0) rhs_batch = (0,) - rhs_contracting = (onp.arange(1, onp.ndim(rhs))[-2:][0],) + rhs_contracting = (onp.arange(1, ndim(rhs))[-2:][0],) dim_nums = [(lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch)] return dot_general(lhs, rhs, dim_nums, precision=precision), 0 @@ -2507,14 +2521,14 @@ def _reshape_shape_rule(operand, new_sizes, dimensions, **unused_kwargs): if not onp.all(onp.greater_equal(new_sizes, 0)): msg = 'reshape new_sizes must all be positive, got {}.' raise TypeError(msg.format(new_sizes)) - if prod(onp.shape(operand)) != prod(new_sizes): + if prod(shape(operand)) != prod(new_sizes): msg = 'reshape total size must be unchanged, got new_sizes {} for shape {}.' - raise TypeError(msg.format(new_sizes, onp.shape(operand))) + raise TypeError(msg.format(new_sizes, shape(operand))) if dimensions is not None: - if set(dimensions) != set(range(onp.ndim(operand))): + if set(dimensions) != set(range(ndim(operand))): msg = ('reshape dimensions must be a permutation of operand dimensions, ' 'got dimensions {} for shape {}.') - raise TypeError(msg.format(dimensions, onp.shape(operand))) + raise TypeError(msg.format(dimensions, shape(operand))) return tuple(new_sizes) def _reshape_dtype_rule(operand, new_sizes, dimensions, **unused_kwargs): @@ -2635,31 +2649,31 @@ def _select_batch_rule(batched_args, batch_dims, **unused_kwargs): # avoid transposes and some broadcasts in special cases if pred_bdim == ot_bdim == of_bdim: - if onp.shape(pred) == onp.shape(on_true): + if shape(pred) == shape(on_true): return select(pred, on_true, on_false), pred_bdim else: # vmapped function had a scalar pred with nonscalar args - assert onp.ndim(pred) == 1 + assert ndim(pred) == 1 pred = broadcast_in_dim(pred, on_true.shape, [pred_bdim]) return select(pred, on_true, on_false), pred_bdim - elif onp.ndim(pred) == 0 and ot_bdim is not None and of_bdim is not None: + elif ndim(pred) == 0 and ot_bdim is not None and of_bdim is not None: if ot_bdim == of_bdim: return select(pred, on_true, on_false), ot_bdim - elif onp.shape(on_true) == onp.shape(on_false): + elif shape(on_true) == shape(on_false): on_false = batching.moveaxis(on_false, of_bdim, ot_bdim) return select(pred, on_true, on_false), ot_bdim - pred = batching.bdim_at_front(pred, pred_bdim, size) if onp.shape(pred) else pred - if not onp.shape(on_true) == onp.shape(on_false) == (): + pred = batching.bdim_at_front(pred, pred_bdim, size) if shape(pred) else pred + if not shape(on_true) == shape(on_false) == (): on_true = batching.bdim_at_front(on_true, ot_bdim, size) on_false = batching.bdim_at_front(on_false, of_bdim, size) - assert onp.shape(on_true) == onp.shape(on_false) - if 0 < onp.ndim(pred) < onp.ndim(on_true): + assert shape(on_true) == shape(on_false) + if 0 < ndim(pred) < ndim(on_true): # vmapped function had a scalar pred with nonscalar args - assert onp.ndim(pred) == 1 + assert ndim(pred) == 1 pred = broadcast_in_dim(pred, on_true.shape, [0]) - if onp.ndim(pred) > onp.ndim(on_true): - assert onp.ndim(on_true) == 0 + if ndim(pred) > ndim(on_true): + assert ndim(on_true) == 0 on_true = broadcast(on_true, pred.shape) on_false = broadcast(on_false, pred.shape) return select(pred, on_true, on_false), 0 @@ -4137,7 +4151,8 @@ def _stop_gradient_batch_rule(batched_args, batch_dims): ### util -_ndim = onp.ndim +_ndim = ndim +_shape = shape def _dilate_shape(shape, dilation): @@ -4308,7 +4323,14 @@ def _const(example, val): _twos = partial(full_like, fill_value=2) _two = partial(full_like, shape=(), fill_value=2) -_dtype = dtype = onp.result_type + +def dtype(*arrays_and_dtypes): + arrays_and_dtypes = tuple( + x if isinstance(x, type) else getattr(x, 'dtype', x) + for x in arrays_and_dtypes) + return onp.result_type(*arrays_and_dtypes) + +_dtype = dtype _iscomplex = lambda x: onp.issubdtype(_dtype(x), onp.complexfloating) diff --git a/jax/numpy/lax_numpy.py b/jax/numpy/lax_numpy.py index 92479023a03c..8c5a17685c90 100644 --- a/jax/numpy/lax_numpy.py +++ b/jax/numpy/lax_numpy.py @@ -49,6 +49,7 @@ from ..config import flags from ..interpreters.xla import DeviceArray from .. import lax +from ..tree_util import tree_map from ..util import partial, get_module_functions, unzip2, prod as _prod from ..lib import pytree from ..lib import xla_bridge @@ -61,6 +62,14 @@ help= 'Control NumPy-style automatic rank promotion broadcasting ' '("allow", "warn", or "raise").') +flags.DEFINE_bool( + 'jax_enable_numpy_overrides', + strtobool(os.getenv('JAX_ENABLE_NUMPY_OVERRIDES', 'False')), + help= + "If true, enable overrides of NumPy functions on JAX arrays, e.g., " + "np.sum(jax_array) should return a JAX array rather than a NumPy array. " + "Also requires NumPy 1.17, or NumPy 1.16 with the environment variable " + "NUMPY_EXPERIMENTAL_ARRAY_FUNCTION=1.") if six.PY3: def removechars(s, chars): @@ -114,7 +123,9 @@ def __init__(shape, dtype=None, buffer=None, offset=0, strides=None, # pylint: enable=invalid-name isscalar = onp.isscalar -_dtype = lax.dtype +result_type = _dtype = lax.dtype +_shape = shape = lax.shape +_ndim = ndim = lax.ndim bool_ = onp.bool_ uint8 = onp.uint8 @@ -162,28 +173,6 @@ def iscomplexobj(x): return issubclass(x.dtype.type, onp.complexfloating) return onp.iscomplexobj(x) -@functools.wraps(onp.result_type) -def result_type(*arrays_and_dtypes): - arrays_and_dtypes = tuple( - x.dtype if isinstance(x, _ARRAY_TYPES) else x - for x in arrays_and_dtypes) - return onp.result_type(*arrays_and_dtypes) - -@functools.wraps(onp.shape) -def shape(a): - if isinstance(a, _ARRAY_TYPES): - return a.shape - return onp.shape(a) - -@functools.wraps(onp.ndim) -def ndim(a): - if isinstance(a, _ARRAY_TYPES): - return a.ndim - return onp.ndim(a) - -_shape = shape -_ndim = ndim - @functools.wraps(onp.size) def size(a, axis=None): if isinstance(a, _ARRAY_TYPES): @@ -257,7 +246,7 @@ def _promote_dtypes(*args): if len(args) < 2: return args else: - from_dtypes = map(_dtype, args) + from_dtypes = tuple(map(_dtype, args)) to_dtype = xla_bridge.canonicalize_dtype(result_type(*from_dtypes)) return [lax.convert_element_type(x, to_dtype) if _dtype(x) != to_dtype else x for x in args] @@ -910,7 +899,7 @@ def where(condition, x=None, y=None): if not onp.issubdtype(_dtype(condition), onp.bool_): condition = lax.ne(condition, zeros_like(condition)) condition, x, y = broadcast_arrays(condition, x, y) - if not onp.size(x): + if not size(x): empty, _ = _promote_dtypes(x, y) return empty else: @@ -967,7 +956,7 @@ def broadcast_to(arr, shape): def split(ary, indices_or_sections, axis=0): dummy_val = onp.broadcast_to(0, ary.shape) # zero strides subarrays = onp.split(dummy_val, indices_or_sections, axis) # shapes - split_indices = onp.cumsum([0] + [onp.shape(sub)[axis] for sub in subarrays]) + split_indices = onp.cumsum([0] + [shape(sub)[axis] for sub in subarrays]) starts, ends = [0] * ndim(ary), shape(ary) _subval = lambda x, i, v: lax.subvals(x, [(i, v)]) return [lax.slice(ary, _subval(starts, axis, start), _subval(ends, axis, end)) @@ -2717,7 +2706,7 @@ def _index_to_gather(x_shape, idx): def _should_unpack_list_index(x): """Helper for _eliminate_deprecated_list_indexing.""" - return (isinstance(x, ndarray) and onp.ndim(x) != 0 + return (isinstance(x, ndarray) and ndim(x) != 0 or isinstance(x, collections.Sequence) or isinstance(x, slice) or x is Ellipsis or x is None) @@ -2771,7 +2760,7 @@ def _is_advanced_int_indexer(idx): """Returns True if idx should trigger int array indexing, False otherwise.""" # https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html#advanced-indexing assert isinstance(idx, tuple) - if _all(onp.ndim(elt) == 0 for elt in idx): + if _all(ndim(elt) == 0 for elt in idx): return False return _all(e is None or e is Ellipsis or isinstance(e, slice) or _is_int_arraylike(e) for e in idx) @@ -2872,15 +2861,15 @@ def cov(m, y=None, rowvar=True, bias=False, ddof=None, fweights=None, w = None if fweights is not None: - if onp.ndim(fweights) > 1: + if ndim(fweights) > 1: raise RuntimeError("cannot handle multidimensional fweights") - if onp.shape(fweights)[0] != X.shape[1]: + if shape(fweights)[0] != X.shape[1]: raise RuntimeError("incompatible numbers of samples and fweights") w = asarray(fweights) if aweights is not None: - if onp.ndim(aweights) > 1: + if ndim(aweights) > 1: raise RuntimeError("cannot handle multidimensional aweights") - if onp.shape(aweights)[0] != X.shape[1]: + if shape(aweights)[0] != X.shape[1]: raise RuntimeError("incompatible numbers of samples and aweights") w = aweights if w is None else w * aweights @@ -3131,14 +3120,17 @@ def _unimplemented_setitem(self, i, x): _HANDLED_TYPES = _ARRAY_TYPES + (onp.ndarray, numbers.Number) def _implement_via_coercion(func, args, kwargs): - args, kwargs = device_get((args, kwargs)) + args, kwargs = tree_map(onp.asarray, device_get((args, kwargs))) return func(*args, **kwargs) -def _implement_with_warning(func, args, kwargs, name, stacklevel=2): +def _coercion_warning(name, stacklevel=3): warnings.warn("{} is not yet implemented by JAX; coercing arguments to " "NumPy arrays. Coerce arrays with onp.asarray() or " "jax.device_get() to explicitly silence this warning." .format(name), stacklevel=stacklevel) + +def _implement_with_warning(func, args, kwargs, name, stacklevel=3): + _coercion_warning(name, stacklevel) return _implement_via_coercion(func, args, kwargs) @@ -3152,11 +3144,15 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): lax_func = globals().get(ufunc.__name__) if lax_func is None or isinstance(lax_func, _NotImplementedByJAX): - return _implement_with_warning(ufunc_method, inputs, kwargs, ufunc) + if FLAGS.jax_enable_numpy_overrides: + _coercion_warning(ufunc) + return _implement_via_coercion(ufunc_method, inputs, kwargs) if method != '__call__': - name = '{!r} method of NumPy universal functions'.format(method) - return _implement_with_warning(ufunc_method, inputs, kwargs, name) + if FLAGS.jax_enable_numpy_overrides: + name = '{!r} method of NumPy universal functions'.format(method) + _coercion_warning(name) + return _implement_via_coercion(ufunc_method, inputs, kwargs) # We special case the 'out' argument so we can support in-place arithmetic # that assigns to NumPy arrays, e.g., np_array += jax_device_array @@ -3178,6 +3174,9 @@ def __array_function__(self, func, types, args, kwargs): """Override other NumPy functions, per NEP-18.""" from .. import numpy + if not FLAGS.jax_enable_numpy_overrides: + return _implement_via_coercion(func, args, kwargs) + if not _all(issubclass(t, _HANDLED_TYPES) for t in types): return NotImplemented diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index 3d9b7c7c58bc..a5ec41dc678f 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -277,6 +277,7 @@ def __array__(self, *args, **kwargs): array_function_overrides_enabled = onp.atleast_1d(_Dummy()) except Exception: array_function_overrides_enabled = False +array_function_overrides_enabled &= FLAGS.jax_enable_numpy_overrides CombosWithReplacement = itertools.combinations_with_replacement @@ -1698,6 +1699,9 @@ def testIssue728(self): self.assertEqual(0, onp.sum(lnp.eye(1050) - onp.eye(1050))) def testArrayUfuncUnary(self): + if not FLAGS.jax_enable_numpy_overrides: + self.skipTest('requires numpy overrides') + lnp_array = lnp.array([1, 2]) onp_array = onp.array([1, 2]) lnp_on_lnp = lnp.sin(lnp_array) @@ -1707,6 +1711,9 @@ def testArrayUfuncUnary(self): self.assertAllClose(onp_on_lnp, onp_on_onp, check_dtypes=True) def testArrayUfuncWarningsAndErrors(self): + if not FLAGS.jax_enable_numpy_overrides: + self.skipTest('requires numpy overrides') + x = lnp.array([1, 2]) with self.assertWarnsRegex(UserWarning, 'not yet implemented by JAX'): @@ -1728,6 +1735,9 @@ def testArrayUfuncWarningsAndErrors(self): self.assertAllClose(y, x + x, check_dtypes=True) def testArrayUfuncBinary(self): + if not FLAGS.jax_enable_numpy_overrides: + self.skipTest('requires numpy overrides') + x = lnp.array([1, 2]) lnp_expected = lnp.array(x) + 3 onp_expected = onp.array(x) + 3 @@ -1745,6 +1755,9 @@ def testArrayUfuncBinary(self): self.assertAllClose(result, onp_expected, check_dtypes=True) def testArrayUfuncUnhandledType(self): + if not FLAGS.jax_enable_numpy_overrides: + self.skipTest('requires numpy overrides') + class Other(object): def __array_ufunc__(self, *args, **kwargs): return 'success' @@ -1756,6 +1769,7 @@ def __array_ufunc__(self, *args, **kwargs): def testArrayFunction(self): if not array_function_overrides_enabled: self.skipTest('__array_function__ overrides not enabled') + onp_array = onp.array([1, 2]) lnp_array = lnp.array([1, 2]) @@ -1769,9 +1783,26 @@ def testArrayFunction(self): self.assertEqual(type(onp_on_mixed), type(lnp_on_lnp)) self.assertAllClose(onp_on_mixed, onp_on_onp, check_dtypes=True) + def testArrayFunctionNotEnabled(self): + if array_function_overrides_enabled: + self.skipTest('__array_function__ overrides enabled') + + onp_array = onp.array([1, 2]) + lnp_array = lnp.array([1, 2]) + + onp_on_onp = onp.concatenate([onp_array] * 2) + onp_on_lnp = onp.concatenate([lnp_array] * 2) + self.assertEqual(type(onp_on_lnp), type(onp_on_onp)) + self.assertAllClose(onp_on_lnp, onp_on_onp, check_dtypes=True) + + onp_on_mixed = onp.concatenate([onp_array, lnp_array]) + self.assertEqual(type(onp_on_mixed), type(onp_on_onp)) + self.assertAllClose(onp_on_mixed, onp_on_onp, check_dtypes=True) + def testArrayFunctionSubModule(self): if not array_function_overrides_enabled: self.skipTest('__array_function__ overrides not enabled') + onp_array = onp.array([[1, 2], [2, 1]]) lnp_array = lnp.array([[1, 2], [2, 1]]) @@ -1784,12 +1815,16 @@ def testArrayFunctionSubModule(self): def testArrayFunctionUnimplemented(self): if not array_function_overrides_enabled: self.skipTest('__array_function__ overrides not enabled') + lnp_array = lnp.sqrt(onp.arange(4)) with self.assertWarnsRegex(UserWarning, 'not yet implemented by JAX'): actual = onp.unique(lnp_array) self.assertAllClose(actual, lnp_array, check_dtypes=False) def testArrayFunctionUnhandledType(self): + if not array_function_overrides_enabled: + self.skipTest('__array_function__ overrides not enabled') + class Other(object): def __array_function__(self, *args, **kwargs): return 'success' From 5d7525735fd047a2628b5d42e97a33ed4b206e48 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Wed, 18 Sep 2019 18:19:08 -0700 Subject: [PATCH 18/30] Cleanup --- jax/numpy/lax_numpy.py | 2 +- jax/test_util.py | 20 ++++++++++++++++++++ tests/lax_numpy_test.py | 2 +- 3 files changed, 22 insertions(+), 2 deletions(-) diff --git a/jax/numpy/lax_numpy.py b/jax/numpy/lax_numpy.py index 8c5a17685c90..887c5088258a 100644 --- a/jax/numpy/lax_numpy.py +++ b/jax/numpy/lax_numpy.py @@ -2745,7 +2745,7 @@ def _expand_bool_indices(idx): "argument to a jit or vmap function).") raise IndexError(msg) else: - out.extend(onp.where(i)) + out.extend(onp.where(onp.asarrray(i))) else: out.append(i) return tuple(out) diff --git a/jax/test_util.py b/jax/test_util.py index 8399ac448013..a052cbb45611 100644 --- a/jax/test_util.py +++ b/jax/test_util.py @@ -16,10 +16,12 @@ from __future__ import division from __future__ import print_function +import contextlib import functools import re import itertools as it import os +import sys from unittest import SkipTest from absl.testing import absltest @@ -567,3 +569,21 @@ def _CheckAgainstNumpy(self, numpy_reference_op, lax_op, args_maker, lax_ans = lax_op(*args) self.assertAllClose(lax_ans, numpy_ans, check_dtypes=check_dtypes, atol=tol, rtol=tol) + + if sys.version_info.major < 3: + @contextlib.contextmanager + def assertWarnsRegex(self, expected_warning, expected_regex): + """Minimal backport of Python 3's assertWarnsRegex. + + Only works as a context manager. + """ + with warnings.catch_warnings(record=True) as warnings_list: + warnings.simplefilter("always") + yield + if not any( + issubclass(warning.category, expected_warning) and + re.match(expected_regex, str(warning.message)) + for warning in warnings_list + ): + self.fail("{} with message {!r} not found in triggered warnings: {}" + .format(expected_warning, expected_regex, warnings)) diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index a5ec41dc678f..99994e2e3902 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -1331,7 +1331,7 @@ def testRot90(self, shape, dtype, k, axes, rng): ('_lax_on_lax', lnp, lnp), ) def testInplaceArithmetic(self, inplace_mod, other_mod): - x = inplace_mod.zeros(3) + x = inplace_mod.zeros(3, dtype=onp.intp) x += other_mod.arange(3) self.assertAllClose(x, onp.arange(3), check_dtypes=True) From 02ce11d6797ec99e540c7341f7b0960999b67ea6 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Thu, 19 Sep 2019 07:46:31 -0700 Subject: [PATCH 19/30] spelling --- jax/numpy/lax_numpy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax/numpy/lax_numpy.py b/jax/numpy/lax_numpy.py index 887c5088258a..684dbf0c51d6 100644 --- a/jax/numpy/lax_numpy.py +++ b/jax/numpy/lax_numpy.py @@ -2745,7 +2745,7 @@ def _expand_bool_indices(idx): "argument to a jit or vmap function).") raise IndexError(msg) else: - out.extend(onp.where(onp.asarrray(i))) + out.extend(onp.where(onp.asarray(i))) else: out.append(i) return tuple(out) From cb57ab64df5cf72478eb1a3f5fa8197211b73d6f Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Thu, 19 Sep 2019 09:30:21 -0700 Subject: [PATCH 20/30] add missing import --- jax/test_util.py | 1 + 1 file changed, 1 insertion(+) diff --git a/jax/test_util.py b/jax/test_util.py index a052cbb45611..89db85f32212 100644 --- a/jax/test_util.py +++ b/jax/test_util.py @@ -23,6 +23,7 @@ import os import sys from unittest import SkipTest +import warnings from absl.testing import absltest from absl.testing import parameterized From 71bc3fe5154f2111959bd9620283f934d9f39cec Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Tue, 24 Sep 2019 19:54:05 -0700 Subject: [PATCH 21/30] coercion blacklist --- jax/core.py | 2 +- jax/numpy/lax_numpy.py | 15 ++++++++++++++- 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/jax/core.py b/jax/core.py index 447711a9008d..fbe0582821b4 100644 --- a/jax/core.py +++ b/jax/core.py @@ -264,7 +264,7 @@ def __array__(self): from .numpy.lax_numpy import numpy_version msg = ("Tracer can't be used with functions that convert their arguments " "into raw NumPy arrays.") - if numpy_version < (1, 17): + if numpy_version < (1, 17) or not FLAGS.jax_enable_numpy_overrides: msg += " You might have\n import numpy as np\ninstead of\n import jax.numpy as np" raise Exception(msg) diff --git a/jax/numpy/lax_numpy.py b/jax/numpy/lax_numpy.py index 684dbf0c51d6..39c25e267dc1 100644 --- a/jax/numpy/lax_numpy.py +++ b/jax/numpy/lax_numpy.py @@ -3170,11 +3170,24 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): return lax_func(*inputs) +# These NumPy functions don't need coercion on JAX arrays. We can always use +# our JAX implementations. +_COERCION_BLACKLIST = frozenset([ + onp.result_type, + onp.ndim, + onp.shape, + onp.iscomplexobj, + onp.size, + onp.transpose, + onp.moveaxis, +]) + + def __array_function__(self, func, types, args, kwargs): """Override other NumPy functions, per NEP-18.""" from .. import numpy - if not FLAGS.jax_enable_numpy_overrides: + if not FLAGS.jax_enable_numpy_overrides and func not in _COERCION_BLACKLIST: return _implement_via_coercion(func, args, kwargs) if not _all(issubclass(t, _HANDLED_TYPES) for t in types): From 7d94dfa88b59847f46936ad2956c9645c5fbe884 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Tue, 24 Sep 2019 21:06:26 -0700 Subject: [PATCH 22/30] Fix flags import --- jax/core.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/jax/core.py b/jax/core.py index fbe0582821b4..4699bc193e37 100644 --- a/jax/core.py +++ b/jax/core.py @@ -26,6 +26,7 @@ import six from . import linear_util as lu +from .config import flags from .util import safe_zip, safe_map, partial, curry from .pprint_util import pp, vcat, hcat, pp_kv_pairs @@ -264,7 +265,7 @@ def __array__(self): from .numpy.lax_numpy import numpy_version msg = ("Tracer can't be used with functions that convert their arguments " "into raw NumPy arrays.") - if numpy_version < (1, 17) or not FLAGS.jax_enable_numpy_overrides: + if numpy_version < (1, 17) or not flags.FLAGS.jax_enable_numpy_overrides: msg += " You might have\n import numpy as np\ninstead of\n import jax.numpy as np" raise Exception(msg) From ee57b15bf2b4e612975cb25c365ac5511e0684da Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Wed, 25 Sep 2019 17:37:42 -0700 Subject: [PATCH 23/30] fixup --- jax/numpy/lax_numpy.py | 1 + jax/test_util.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/jax/numpy/lax_numpy.py b/jax/numpy/lax_numpy.py index 39c25e267dc1..d415ec167b67 100644 --- a/jax/numpy/lax_numpy.py +++ b/jax/numpy/lax_numpy.py @@ -3177,6 +3177,7 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): onp.ndim, onp.shape, onp.iscomplexobj, + onp.reshape, onp.size, onp.transpose, onp.moveaxis, diff --git a/jax/test_util.py b/jax/test_util.py index 89db85f32212..cb71356b9ca0 100644 --- a/jax/test_util.py +++ b/jax/test_util.py @@ -587,4 +587,4 @@ def assertWarnsRegex(self, expected_warning, expected_regex): for warning in warnings_list ): self.fail("{} with message {!r} not found in triggered warnings: {}" - .format(expected_warning, expected_regex, warnings)) + .format(expected_warning, expected_regex, warnings_list)) From 130306800f897607f0216ae925b555c419ee2fc7 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Sun, 29 Sep 2019 19:33:48 -0700 Subject: [PATCH 24/30] fix warnings on Python 2.7 --- jax/test_util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax/test_util.py b/jax/test_util.py index cb71356b9ca0..ede74ae9f830 100644 --- a/jax/test_util.py +++ b/jax/test_util.py @@ -583,7 +583,7 @@ def assertWarnsRegex(self, expected_warning, expected_regex): yield if not any( issubclass(warning.category, expected_warning) and - re.match(expected_regex, str(warning.message)) + re.search(expected_regex, str(warning.message.message)) for warning in warnings_list ): self.fail("{} with message {!r} not found in triggered warnings: {}" From dc5989f4b05b749850005bb3dd23f66d1d2428f3 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Sun, 29 Sep 2019 19:43:49 -0700 Subject: [PATCH 25/30] revert lax_numpy changes --- jax/numpy/lax_numpy.py | 40 ++++++++++++++++++++++++++++------------ 1 file changed, 28 insertions(+), 12 deletions(-) diff --git a/jax/numpy/lax_numpy.py b/jax/numpy/lax_numpy.py index 0e04931641ea..403cb9ba5acf 100644 --- a/jax/numpy/lax_numpy.py +++ b/jax/numpy/lax_numpy.py @@ -123,9 +123,7 @@ def __init__(shape, dtype=None, buffer=None, offset=0, strides=None, # pylint: enable=invalid-name isscalar = onp.isscalar -result_type = _dtype = lax.dtype -_shape = shape = lax.shape -_ndim = ndim = lax.ndim +_dtype = lax.dtype bool_ = onp.bool_ uint8 = onp.uint8 @@ -167,6 +165,24 @@ def __init__(shape, dtype=None, buffer=None, offset=0, strides=None, # These wrappers are necessary to avoid infinite recursion inside JAX's # __array_function__ method. The implementations below exactly match NumPy. +@functools.wraps(onp.ndim) +def ndim(x): + if isinstance(x, _ARRAY_TYPES): + return x.ndim + return onp.ndim(x) + +@functools.wraps(onp.shape) +def shape(x): + if isinstance(x, _ARRAY_TYPES): + return x.shape + return onp.shape(x) + +@functools.wraps(onp.result_type) +def result_type(*arrays_and_dtypes): + return onp.result_type( + x.dtype is isinstance(x, _ARRAY_TYPES) else x for x in arrays_and_dtypes + ) + @functools.wraps(onp.iscomplexobj) def iscomplexobj(x): if isinstance(x, _ARRAY_TYPES): @@ -246,7 +262,7 @@ def _promote_dtypes(*args): if len(args) < 2: return args else: - from_dtypes = tuple(map(_dtype, args)) + from_dtypes = map(_dtype, args) to_dtype = xla_bridge.canonicalize_dtype(result_type(*from_dtypes)) return [lax.convert_element_type(x, to_dtype) if _dtype(x) != to_dtype else x for x in args] @@ -899,7 +915,7 @@ def where(condition, x=None, y=None): if not onp.issubdtype(_dtype(condition), onp.bool_): condition = lax.ne(condition, zeros_like(condition)) condition, x, y = broadcast_arrays(condition, x, y) - if not size(x): + if not onp.size(x): empty, _ = _promote_dtypes(x, y) return empty else: @@ -956,7 +972,7 @@ def broadcast_to(arr, shape): def split(ary, indices_or_sections, axis=0): dummy_val = onp.broadcast_to(0, ary.shape) # zero strides subarrays = onp.split(dummy_val, indices_or_sections, axis) # shapes - split_indices = onp.cumsum([0] + [shape(sub)[axis] for sub in subarrays]) + split_indices = onp.cumsum([0] + [onp.shape(sub)[axis] for sub in subarrays]) starts, ends = [0] * ndim(ary), shape(ary) _subval = lambda x, i, v: lax.subvals(x, [(i, v)]) return [lax.slice(ary, _subval(starts, axis, start), _subval(ends, axis, end)) @@ -2772,7 +2788,7 @@ def _index_to_gather(x_shape, idx): def _should_unpack_list_index(x): """Helper for _eliminate_deprecated_list_indexing.""" - return (isinstance(x, ndarray) and ndim(x) != 0 + return (isinstance(x, ndarray) and onp.ndim(x) != 0 or isinstance(x, collections.Sequence) or isinstance(x, slice) or x is Ellipsis or x is None) @@ -2826,7 +2842,7 @@ def _is_advanced_int_indexer(idx): """Returns True if idx should trigger int array indexing, False otherwise.""" # https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html#advanced-indexing assert isinstance(idx, tuple) - if _all(ndim(elt) == 0 for elt in idx): + if _all(onp.ndim(elt) == 0 for elt in idx): return False return _all(e is None or e is Ellipsis or isinstance(e, slice) or _is_int_arraylike(e) for e in idx) @@ -2928,15 +2944,15 @@ def cov(m, y=None, rowvar=True, bias=False, ddof=None, fweights=None, w = None if fweights is not None: - if ndim(fweights) > 1: + if onp.ndim(fweights) > 1: raise RuntimeError("cannot handle multidimensional fweights") - if shape(fweights)[0] != X.shape[1]: + if onp.shape(fweights)[0] != X.shape[1]: raise RuntimeError("incompatible numbers of samples and fweights") w = asarray(fweights) if aweights is not None: - if ndim(aweights) > 1: + if onp.ndim(aweights) > 1: raise RuntimeError("cannot handle multidimensional aweights") - if shape(aweights)[0] != X.shape[1]: + if onp.shape(aweights)[0] != X.shape[1]: raise RuntimeError("incompatible numbers of samples and aweights") w = aweights if w is None else w * aweights From 5aa2dcfe67044c3664a577b8c56c52126b1bc2b8 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Sun, 29 Sep 2019 19:46:12 -0700 Subject: [PATCH 26/30] Revert lax/lax.py changes --- jax/lax/lax.py | 124 ++++++++++++++++++++----------------------------- 1 file changed, 51 insertions(+), 73 deletions(-) diff --git a/jax/lax/lax.py b/jax/lax/lax.py index 8f3e0fef680a..785281658238 100644 --- a/jax/lax/lax.py +++ b/jax/lax/lax.py @@ -95,20 +95,6 @@ def _canonicalize_shape(shape): def _identity(x): return x - -def shape(a): - try: - return a.shape - except AttributeError: - return onp.shape(a) - -def ndim(a): - try: - return a.ndim - except AttributeError: - return onp.ndim(a) - - ### traceables def neg(x): @@ -513,9 +499,9 @@ def dot(lhs, rhs, precision=None): # TODO(b/134526360): XLA doesn't support integer dots, so we emit a sum of # products instead. if onp.issubdtype(lhs.dtype, onp.integer): - lhs_shape = shape(lhs) + lhs_shape = onp.shape(lhs) lhs_ndim = len(lhs_shape) - rhs_ndim = ndim(rhs) + rhs_ndim = onp.ndim(rhs) if rhs_ndim > 1: lhs = broadcast_in_dim(lhs, lhs_shape + (1,), tuple(range(len(lhs_shape)))) if lhs_ndim > 1: @@ -552,17 +538,17 @@ def dot_general(lhs, rhs, dimension_numbers, precision=None): lhs_contract_dims, rhs_contract_dims = contract_dims lhs_batch_dims, rhs_batch_dims = batch_dims lhs_noncontract_dims = tuple(sorted( - set(range(ndim(lhs))) - set(lhs_batch_dims) - set(lhs_contract_dims))) + set(range(onp.ndim(lhs))) - set(lhs_batch_dims) - set(lhs_contract_dims))) rhs_noncontract_dims = tuple(sorted( - set(range(ndim(rhs))) - set(rhs_batch_dims) - set(rhs_contract_dims))) + set(range(onp.ndim(rhs))) - set(rhs_batch_dims) - set(rhs_contract_dims))) lhs = transpose(lhs, lhs_batch_dims + lhs_noncontract_dims + lhs_contract_dims) rhs = transpose(rhs, rhs_batch_dims + rhs_noncontract_dims + rhs_contract_dims) new_lhs_shape = onp.insert( - shape(lhs), len(lhs_batch_dims) + len(lhs_noncontract_dims), - (1,) * len(rhs_noncontract_dims)) - new_rhs_shape = onp.insert(shape(rhs), len(lhs_batch_dims), + onp.shape(lhs), len(lhs_batch_dims) + len(lhs_noncontract_dims), + (1,) * len(rhs_noncontract_dims)) + new_rhs_shape = onp.insert(onp.shape(rhs), len(lhs_batch_dims), (1,) * len(lhs_noncontract_dims)) lhs = reshape(lhs, new_lhs_shape) rhs = reshape(rhs, new_rhs_shape) @@ -610,15 +596,15 @@ def reshape(operand, new_sizes, dimensions=None): """ new_sizes = _canonicalize_shape(new_sizes) # TODO new_sizes = tuple(new_sizes) - same_shape = shape(operand) == new_sizes - same_dims = dimensions is None or tuple(dimensions) == tuple(range(ndim(operand))) - if shape(operand) and same_shape and same_dims: + same_shape = onp.shape(operand) == new_sizes + same_dims = dimensions is None or tuple(dimensions) == tuple(range(onp.ndim(operand))) + if onp.shape(operand) and same_shape and same_dims: return operand else: return reshape_p.bind( operand, new_sizes=new_sizes, dimensions=None if same_dims else tuple(dimensions), - old_sizes=shape(operand)) + old_sizes=onp.shape(operand)) def pad(operand, padding_value, padding_config): """Wraps XLA's `Pad @@ -913,7 +899,7 @@ def _get_min_identity(dtype): def _reduce_sum(operand, axes): return reduce_sum_p.bind(operand, axes=tuple(axes), - input_shape=shape(operand)) + input_shape=onp.shape(operand)) def _reduce_prod(operand, axes): return reduce_prod_p.bind(operand, axes=tuple(axes)) @@ -1041,9 +1027,9 @@ def full(shape, fill_value, dtype=None): "`static_argnums` or applying `jit` to smaller subfunctions instead.") raise TypeError(msg) - if _shape(fill_value): + if onp.shape(fill_value): msg = "full must be called with scalar fill_value, got fill_value.shape {}." - raise TypeError(msg.format(_shape(fill_value))) + raise TypeError(msg.format(onp.shape(fill_value))) dtype = dtype or _dtype(fill_value) dtype = xla_bridge.canonicalize_dtype(dtype) @@ -1267,7 +1253,7 @@ def full_like(x, fill_value, dtype=None, shape=None): An ndarray with the same shape as `x` with its entries set equal to `fill_value`, similar to the output of np.full. """ - shape = _shape(x) if shape is None else _canonicalize_shape(shape) + shape = onp.shape(x) if shape is None else _canonicalize_shape(shape) out = full(shape, fill_value, dtype or _dtype(x)) return tie_in(x, out) @@ -1538,16 +1524,16 @@ def _brcast(x, *others): # Requires shape info during jvp tracing, which isn't strictly necessary. # We don't need full numpy broadcasting, but otherwise the logic is the same # so we reuse the broadcast_shapes function after filtering out scalars. - shapes = tuple(filter(None, map(_shape, (x,) + others))) + shapes = tuple(filter(None, map(onp.shape, (x,) + others))) shape = shapes and broadcast_shapes(*shapes) - if _shape(x) != shape: + if onp.shape(x) != shape: return _brcast_to(x, shape) else: return x def _brcast_to(x, shape): - x_shape = _shape(x) + x_shape = onp.shape(x) assert x_shape != shape if x_shape: assert len(x_shape) == len(shape) @@ -2061,51 +2047,51 @@ def require(shape_cond): return lhs.shape[:-1] + rhs.shape[:-2] + rhs.shape[-1:] def _dot_transpose_lhs(t, rhs, precision): - if ndim(t) == ndim(rhs) == 2: + if onp.ndim(t) == onp.ndim(rhs) == 2: return dot(t, transpose(rhs, (1, 0)), precision=precision) - elif ndim(t) == 1 and ndim(rhs) == 2: + elif onp.ndim(t) == 1 and onp.ndim(rhs) == 2: return dot(rhs, t, precision=precision) - elif ndim(t) == ndim(rhs) == 1: + elif onp.ndim(t) == onp.ndim(rhs) == 1: return _outer(t, rhs) - elif ndim(t) == 0 or ndim(rhs) == 0: + elif onp.ndim(t) == 0 or onp.ndim(rhs) == 0: return mul(t, rhs) else: raise TypeError def _dot_transpose_rhs(t, lhs, precision): - if ndim(lhs) == ndim(t) == 2: + if onp.ndim(lhs) == onp.ndim(t) == 2: return dot(transpose(lhs, (1, 0)), t) - elif ndim(lhs) == 2 and ndim(t) == 1: + elif onp.ndim(lhs) == 2 and onp.ndim(t) == 1: return dot(t, lhs, precision=precision) - elif ndim(t) == ndim(lhs) == 1: + elif onp.ndim(t) == onp.ndim(lhs) == 1: return _outer(lhs, t) - elif ndim(t) == 0 or ndim(lhs) == 0: + elif onp.ndim(t) == 0 or onp.ndim(lhs) == 0: return mul(t, lhs) else: raise TypeError def _outer(x, y): - assert ndim(x) == ndim(y) == 1 + assert onp.ndim(x) == onp.ndim(y) == 1 return mul(reshape(x, (x.shape[0], 1)), reshape(y, (1, y.shape[0]))) def _dot_batch_rule(batched_args, batch_dims, precision=None): lhs, rhs = batched_args lbd, rbd = batch_dims - T = lambda x: transpose(x, onp.arange(ndim(x))[::-1]) + T = lambda x: transpose(x, onp.arange(onp.ndim(x))[::-1]) # in some cases, we can call dot instead of dot_general - if max(ndim(lhs), ndim(rhs)) <= 2: + if max(onp.ndim(lhs), onp.ndim(rhs)) <= 2: if rbd is None: assert lbd in (0, 1) if lbd == 0: return dot(lhs, rhs, precision=precision), 0 else: - return dot(T(rhs), lhs, precision=precision), ndim(rhs) - 1 + return dot(T(rhs), lhs, precision=precision), onp.ndim(rhs) - 1 if lbd is None: assert rbd in (0, 1) - if rbd == ndim(rhs) - 1: - return dot(lhs, rhs, precision=precision), ndim(lhs) - 1 + if rbd == onp.ndim(rhs) - 1: + return dot(lhs, rhs, precision=precision), onp.ndim(lhs) - 1 else: return dot(rhs, T(lhs), precision=precision), 0 @@ -2123,7 +2109,7 @@ def _dot_batch_rule(batched_args, batch_dims, precision=None): else: lhs = batching.moveaxis(lhs, lbd, 0) lhs_batch = (0,) - lhs_contracting = (ndim(lhs) - 1,) + lhs_contracting = (onp.ndim(lhs) - 1,) if rbd is None: assert lbd is not None @@ -2131,7 +2117,7 @@ def _dot_batch_rule(batched_args, batch_dims, precision=None): else: rhs = batching.moveaxis(rhs, rbd, 0) rhs_batch = (0,) - rhs_contracting = (onp.arange(1, ndim(rhs))[-2:][0],) + rhs_contracting = (onp.arange(1, onp.ndim(rhs))[-2:][0],) dim_nums = [(lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch)] return dot_general(lhs, rhs, dim_nums, precision=precision), 0 @@ -2531,14 +2517,14 @@ def _reshape_shape_rule(operand, new_sizes, dimensions, **unused_kwargs): if not onp.all(onp.greater_equal(new_sizes, 0)): msg = 'reshape new_sizes must all be positive, got {}.' raise TypeError(msg.format(new_sizes)) - if prod(shape(operand)) != prod(new_sizes): + if prod(onp.shape(operand)) != prod(new_sizes): msg = 'reshape total size must be unchanged, got new_sizes {} for shape {}.' - raise TypeError(msg.format(new_sizes, shape(operand))) + raise TypeError(msg.format(new_sizes, onp.shape(operand))) if dimensions is not None: - if set(dimensions) != set(range(ndim(operand))): + if set(dimensions) != set(range(onp.ndim(operand))): msg = ('reshape dimensions must be a permutation of operand dimensions, ' 'got dimensions {} for shape {}.') - raise TypeError(msg.format(dimensions, shape(operand))) + raise TypeError(msg.format(dimensions, onp.shape(operand))) return tuple(new_sizes) def _reshape_dtype_rule(operand, new_sizes, dimensions, **unused_kwargs): @@ -2659,31 +2645,31 @@ def _select_batch_rule(batched_args, batch_dims, **unused_kwargs): # avoid transposes and some broadcasts in special cases if pred_bdim == ot_bdim == of_bdim: - if shape(pred) == shape(on_true): + if onp.shape(pred) == onp.shape(on_true): return select(pred, on_true, on_false), pred_bdim else: # vmapped function had a scalar pred with nonscalar args - assert ndim(pred) == 1 + assert onp.ndim(pred) == 1 pred = broadcast_in_dim(pred, on_true.shape, [pred_bdim]) return select(pred, on_true, on_false), pred_bdim - elif ndim(pred) == 0 and ot_bdim is not None and of_bdim is not None: + elif onp.ndim(pred) == 0 and ot_bdim is not None and of_bdim is not None: if ot_bdim == of_bdim: return select(pred, on_true, on_false), ot_bdim - elif shape(on_true) == shape(on_false): + elif onp.shape(on_true) == onp.shape(on_false): on_false = batching.moveaxis(on_false, of_bdim, ot_bdim) return select(pred, on_true, on_false), ot_bdim - pred = batching.bdim_at_front(pred, pred_bdim, size) if shape(pred) else pred - if not shape(on_true) == shape(on_false) == (): + pred = batching.bdim_at_front(pred, pred_bdim, size) if onp.shape(pred) else pred + if not onp.shape(on_true) == onp.shape(on_false) == (): on_true = batching.bdim_at_front(on_true, ot_bdim, size) on_false = batching.bdim_at_front(on_false, of_bdim, size) - assert shape(on_true) == shape(on_false) - if 0 < ndim(pred) < ndim(on_true): + assert onp.shape(on_true) == onp.shape(on_false) + if 0 < onp.ndim(pred) < onp.ndim(on_true): # vmapped function had a scalar pred with nonscalar args - assert ndim(pred) == 1 + assert onp.ndim(pred) == 1 pred = broadcast_in_dim(pred, on_true.shape, [0]) - if ndim(pred) > ndim(on_true): - assert ndim(on_true) == 0 + if onp.ndim(pred) > onp.ndim(on_true): + assert onp.ndim(on_true) == 0 on_true = broadcast(on_true, pred.shape) on_false = broadcast(on_false, pred.shape) return select(pred, on_true, on_false), 0 @@ -4164,8 +4150,7 @@ def _stop_gradient_batch_rule(batched_args, batch_dims): ### util -_ndim = ndim -_shape = shape +_ndim = onp.ndim def _dilate_shape(shape, dilation): @@ -4339,14 +4324,7 @@ def _const(example, val): _twos = partial(full_like, fill_value=2) _two = partial(full_like, shape=(), fill_value=2) - -def dtype(*arrays_and_dtypes): - arrays_and_dtypes = tuple( - x if isinstance(x, type) else getattr(x, 'dtype', x) - for x in arrays_and_dtypes) - return onp.result_type(*arrays_and_dtypes) - -_dtype = dtype +_dtype = dtype = onp.result_type _iscomplex = lambda x: onp.issubdtype(_dtype(x), onp.complexfloating) From 0f9ddb999fa43a17a1dd6590f9e6ff0faf31795f Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Sun, 29 Sep 2019 19:53:01 -0700 Subject: [PATCH 27/30] fixups --- jax/numpy/lax_numpy.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/jax/numpy/lax_numpy.py b/jax/numpy/lax_numpy.py index 403cb9ba5acf..f86a36c0b5f5 100644 --- a/jax/numpy/lax_numpy.py +++ b/jax/numpy/lax_numpy.py @@ -171,17 +171,21 @@ def ndim(x): return x.ndim return onp.ndim(x) +_ndim = ndim + @functools.wraps(onp.shape) def shape(x): if isinstance(x, _ARRAY_TYPES): return x.shape return onp.shape(x) +_shape = shape + @functools.wraps(onp.result_type) def result_type(*arrays_and_dtypes): - return onp.result_type( - x.dtype is isinstance(x, _ARRAY_TYPES) else x for x in arrays_and_dtypes - ) + arrays_and_dtypes = [x.dtype if isinstance(x, _ARRAY_TYPES) + else x for x in arrays_and_dtypes] + return onp.result_type(*arrays_and_dtypes) @functools.wraps(onp.iscomplexobj) def iscomplexobj(x): From bd097f2734c95cbca71281432515cbad093b4e55 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Sun, 29 Sep 2019 20:34:14 -0700 Subject: [PATCH 28/30] use aval for __array_function__ --- jax/core.py | 4 ++-- jax/numpy/lax_numpy.py | 21 ++++++++++++--------- 2 files changed, 14 insertions(+), 11 deletions(-) diff --git a/jax/core.py b/jax/core.py index 4699bc193e37..f96beb6d8266 100644 --- a/jax/core.py +++ b/jax/core.py @@ -334,11 +334,11 @@ def __oct__(self): return self.aval._oct(self) # so these need to be defined here rather than on UnshapedArray. def __array_ufunc__(self, *args, **kwargs): from .numpy.lax_numpy import __array_ufunc__ - return __array_ufunc__(self, *args, **kwargs) + return self.aval.__array_ufunc__(*args, **kwargs) def __array_function__(self, *args, **kwargs): from .numpy.lax_numpy import __array_function__ - return __array_function__(self, *args, **kwargs) + return self.aval.__array_function__(*args, **kwargs) def __setitem__(self, idx, val): raise TypeError("JAX 'Tracer' objects do not support item assignment") diff --git a/jax/numpy/lax_numpy.py b/jax/numpy/lax_numpy.py index f86a36c0b5f5..4d2e5e356e52 100644 --- a/jax/numpy/lax_numpy.py +++ b/jax/numpy/lax_numpy.py @@ -167,7 +167,7 @@ def __init__(shape, dtype=None, buffer=None, offset=0, strides=None, @functools.wraps(onp.ndim) def ndim(x): - if isinstance(x, _ARRAY_TYPES): + if isinstance(x, _JAX_ARRAY_TYPES): return x.ndim return onp.ndim(x) @@ -175,7 +175,7 @@ def ndim(x): @functools.wraps(onp.shape) def shape(x): - if isinstance(x, _ARRAY_TYPES): + if isinstance(x, _JAX_ARRAY_TYPES): return x.shape return onp.shape(x) @@ -183,19 +183,19 @@ def shape(x): @functools.wraps(onp.result_type) def result_type(*arrays_and_dtypes): - arrays_and_dtypes = [x.dtype if isinstance(x, _ARRAY_TYPES) + arrays_and_dtypes = [x.dtype if isinstance(x, _JAX_ARRAY_TYPES) else x for x in arrays_and_dtypes] return onp.result_type(*arrays_and_dtypes) @functools.wraps(onp.iscomplexobj) def iscomplexobj(x): - if isinstance(x, _ARRAY_TYPES): + if isinstance(x, _JAX_ARRAY_TYPES): return issubclass(x.dtype.type, onp.complexfloating) return onp.iscomplexobj(x) @functools.wraps(onp.size) def size(a, axis=None): - if isinstance(a, _ARRAY_TYPES): + if isinstance(a, _JAX_ARRAY_TYPES): if axis is None: return a.size else: @@ -218,7 +218,7 @@ def save(file, arr, allow_pickle=True, fix_imports=True): fix_imports=fix_imports) def _cast_if_needed(x): - return onp.asarray(x) if isinstance(x, _ARRAY_TYPES) else x + return onp.asarray(x) if isinstance(x, _JAX_ARRAY_TYPES) else x @functools.wraps(onp.savez) def savez(file, *args, **kwds): @@ -3203,8 +3203,8 @@ def _unimplemented_setitem(self, i, x): # Override NumPy's public API. -_ARRAY_TYPES = (DeviceArray, core.Tracer) -_HANDLED_TYPES = _ARRAY_TYPES + (onp.ndarray, numbers.Number) +_JAX_ARRAY_TYPES = (DeviceArray, core.Tracer) +_HANDLED_TYPES = _JAX_ARRAY_TYPES + (onp.ndarray, numbers.Number) def _implement_via_coercion(func, args, kwargs): args, kwargs = tree_map(onp.asarray, device_get((args, kwargs))) @@ -3245,7 +3245,7 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): # that assigns to NumPy arrays, e.g., np_array += jax_device_array out = kwargs.pop('out', None) if out is not None: - if _any(isinstance(o, _ARRAY_TYPES) for o in out): + if _any(isinstance(o, _JAX_ARRAY_TYPES) for o in out): raise TypeError("JAX arrays cannot be modified inplace.") inputs = device_get(inputs) return ufunc_method(*inputs, out=out, **kwargs) @@ -3309,6 +3309,9 @@ def __array_function__(self, func, types, args, kwargs): setattr(DeviceArray, '__array_ufunc__', __array_ufunc__) setattr(DeviceArray, '__array_function__', __array_function__) +setattr(ShapedArray, '__array_ufunc__', __array_ufunc__) +setattr(ShapedArray, '__array_function__', __array_function__) + # Extra methods that are handy setattr(ShapedArray, "broadcast", core.aval_method(lax.broadcast)) From 0ec6dadb56f172f7e32336612f23a17ec4145e11 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Sun, 29 Sep 2019 20:41:11 -0700 Subject: [PATCH 29/30] add a jit override test --- jax/numpy/lax_numpy.py | 2 +- tests/lax_numpy_test.py | 8 ++++++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/jax/numpy/lax_numpy.py b/jax/numpy/lax_numpy.py index 4d2e5e356e52..e0dd4abbfd61 100644 --- a/jax/numpy/lax_numpy.py +++ b/jax/numpy/lax_numpy.py @@ -3203,7 +3203,7 @@ def _unimplemented_setitem(self, i, x): # Override NumPy's public API. -_JAX_ARRAY_TYPES = (DeviceArray, core.Tracer) +_JAX_ARRAY_TYPES = (DeviceArray, ShapedArray, core.Tracer) _HANDLED_TYPES = _JAX_ARRAY_TYPES + (onp.ndarray, numbers.Number) def _implement_via_coercion(func, args, kwargs): diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index 8237905e48f2..c948ecfb116b 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -1799,6 +1799,14 @@ def __array_ufunc__(self, *args, **kwargs): result = onp.add(x, Other()) self.assertEqual(result, 'success') + def testArrayUfuncTracer(self): + + if not FLAGS.jax_enable_numpy_overrides: + self.skipTest('requires numpy overrides') + + actual = jax.jit(onp.sin)(1.0) + self.assertAllClose(actual, onp.sin(1.0), check_dtypes=False) + def testArrayFunction(self): if not array_function_overrides_enabled: self.skipTest('__array_function__ overrides not enabled') From 81754eca296b859a840146fb9927802f1d44c3f9 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Sun, 29 Sep 2019 20:42:02 -0700 Subject: [PATCH 30/30] Remove duplicate test --- tests/lax_numpy_test.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index c948ecfb116b..21ba773817fe 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -1727,10 +1727,6 @@ def testArange(self): self.assertTrue(type(lnp.arange(77, dtype=lnp.int32)) == type(lax.iota(onp.int32, 77))) - def testIssue728(self): - assert lnp.allclose(lnp.eye(5000), onp.eye(5000)) - self.assertEqual(0, onp.sum(lnp.eye(1050) - onp.eye(1050))) - def testArrayUfuncUnary(self): if not FLAGS.jax_enable_numpy_overrides: self.skipTest('requires numpy overrides')