Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement overrides of NumPy's public API on JAX arrays #611

Closed
wants to merge 37 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
6fccef3
Implement overrides of NumPy's public API on JAX arrays
shoyer Apr 13, 2019
7511ef0
Fix misc. test failures
shoyer Apr 13, 2019
6329907
More test coverage for __array_function__
shoyer Apr 13, 2019
42a51ea
Less nesting in __array_function__
shoyer Apr 13, 2019
5650011
Add __array_ufunc__ and __array_function__ to Tracer
shoyer Apr 13, 2019
0c9e50c
Move __array_ufunc__ and __array_function__ to UnshapedArray
shoyer Apr 13, 2019
f838721
Fixes to support the neural net example from the readme
shoyer Apr 13, 2019
3bfe51b
Fix API test
shoyer Apr 16, 2019
10067a9
Fix failing test with inplace arithmetic
shoyer Apr 26, 2019
af564e1
Merge branch 'master' into numpy-api-overrides
shoyer Apr 26, 2019
f7c2c27
Fix for CheckChiSquared
shoyer Apr 26, 2019
6247335
Merge branch 'master' into numpy-api-overrides
shoyer May 21, 2019
a6e0308
Merge branch 'master' into numpy-api-overrides
shoyer Jul 30, 2019
d8d7629
Ensure __array_function__ is only used by numpy
shoyer Jul 30, 2019
903c50c
Test fixes
shoyer Jul 31, 2019
15f0164
Merge branch 'numpy-api-overrides' of github.com:shoyer/jax into nump…
shoyer Jul 31, 2019
eb6e28d
Merge branch 'master' into numpy-api-overrides
shoyer Aug 2, 2019
c5215b8
tweak
shoyer Aug 2, 2019
45a6723
Explicit wrappers for aliases to numpy functions
shoyer Aug 3, 2019
951bd5b
Warn instead of erroring for unimplemented NumPy functions
shoyer Aug 3, 2019
195d176
Merge branch 'master' into numpy-api-overrides
shoyer Sep 15, 2019
feff162
cleanup
shoyer Sep 16, 2019
61d0866
flags for disabling numpy overrides
shoyer Sep 16, 2019
5d75257
Cleanup
shoyer Sep 19, 2019
02ce11d
spelling
shoyer Sep 19, 2019
cb57ab6
add missing import
shoyer Sep 19, 2019
71bc3fe
coercion blacklist
shoyer Sep 25, 2019
7d94dfa
Fix flags import
shoyer Sep 25, 2019
ee57b15
fixup
shoyer Sep 26, 2019
1303068
fix warnings on Python 2.7
shoyer Sep 30, 2019
7f1b497
Merge branch 'master' into numpy-api-overrides
shoyer Sep 30, 2019
dc5989f
revert lax_numpy changes
shoyer Sep 30, 2019
5aa2dcf
Revert lax/lax.py changes
shoyer Sep 30, 2019
0f9ddb9
fixups
shoyer Sep 30, 2019
bd097f2
use aval for __array_function__
shoyer Sep 30, 2019
0ec6dad
add a jit override test
shoyer Sep 30, 2019
81754ec
Remove duplicate test
shoyer Sep 30, 2019
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ python:
- "3.6"
env:
- JAX_ENABLE_X64=0 JAX_NUM_GENERATED_CASES=25
- JAX_ENABLE_X64=1 JAX_NUM_GENERATED_CASES=25
- JAX_ENABLE_X64=1 JAX_NUM_GENERATED_CASES=25 JAX_ENABLE_NUMPY_OVERRIDES=1 NUMPY_EXPERIMENTAL_ARRAY_FUNCTION=1
before_install:
- if [[ "$TRAVIS_PYTHON_VERSION" == "2.7" ]]; then
wget https://repo.continuum.io/miniconda/Miniconda2-latest-Linux-x86_64.sh -O miniconda.sh;
Expand Down
7 changes: 6 additions & 1 deletion jax/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -1226,7 +1226,12 @@ def device_put(x, device_num=0, backend=None):
def _device_get(x):
if isinstance(x, core.Tracer):
return x
return x.copy()
try:
copy = x.copy
except AttributeError:
return x
else:
return copy()

def device_get(x):
for y in tree_leaves(x):
Expand Down
20 changes: 18 additions & 2 deletions jax/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import six

from . import linear_util as lu
from .config import flags
from .util import safe_zip, safe_map, partial, curry
from .pprint_util import pp, vcat, hcat, pp_kv_pairs

Expand Down Expand Up @@ -261,8 +262,12 @@ class Tracer(object):
__slots__ = ['trace', '__weakref__']

def __array__(self):
raise Exception("Tracer can't be used with raw numpy functions. "
"You might have\n import numpy as np\ninstead of\n import jax.numpy as np")
from .numpy.lax_numpy import numpy_version
msg = ("Tracer can't be used with functions that convert their arguments "
"into raw NumPy arrays.")
if numpy_version < (1, 17) or not flags.FLAGS.jax_enable_numpy_overrides:
msg += " You might have\n import numpy as np\ninstead of\n import jax.numpy as np"
raise Exception(msg)

def __init__(self, trace):
self.trace = trace
Expand Down Expand Up @@ -324,6 +329,17 @@ def __complex__(self): return self.aval._complex(self)
def __hex__(self): return self.aval._hex(self)
def __oct__(self): return self.aval._oct(self)

# Like Python's checks for arithmetic special methods, NumPy doesn't check
# instance attributes when looking up __array_ufunc__ and __array_function__,
# so these need to be defined here rather than on UnshapedArray.
shoyer marked this conversation as resolved.
Show resolved Hide resolved
def __array_ufunc__(self, *args, **kwargs):
from .numpy.lax_numpy import __array_ufunc__
return self.aval.__array_ufunc__(*args, **kwargs)

def __array_function__(self, *args, **kwargs):
from .numpy.lax_numpy import __array_function__
return self.aval.__array_function__(*args, **kwargs)

def __setitem__(self, idx, val):
raise TypeError("JAX 'Tracer' objects do not support item assignment")

Expand Down
4 changes: 2 additions & 2 deletions jax/numpy/fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from .. import lax
from ..lib.xla_bridge import xla_client, canonicalize_dtype
from ..util import get_module_functions
from .lax_numpy import _not_implemented
from .lax_numpy import _NotImplementedByJAX
from .lax_numpy import _wraps
from . import lax_numpy as np

Expand Down Expand Up @@ -71,4 +71,4 @@ def fftn(a, s=None, axes=None, norm=None):

for func in get_module_functions(onp.fft):
if func.__name__ not in globals():
globals()[func.__name__] = _not_implemented(func)
globals()[func.__name__] = _NotImplementedByJAX(func)
228 changes: 206 additions & 22 deletions jax/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,9 @@

from distutils.util import strtobool
import collections
import functools
import itertools
import numbers
import os
import re
import string
Expand All @@ -41,12 +43,13 @@
import six
from six.moves import builtins, xrange

from jax import jit, device_put, custom_transforms, defjvp
from jax import jit, device_get, device_put, custom_transforms, defjvp
from .. import core
from ..abstract_arrays import UnshapedArray, ShapedArray, ConcreteArray
from ..config import flags
from ..interpreters.xla import DeviceArray
from .. import lax
from ..tree_util import tree_map
from ..util import partial, get_module_functions, unzip2, prod as _prod
from ..lib import pytree
from ..lib import xla_bridge
Expand All @@ -59,6 +62,14 @@
help=
'Control NumPy-style automatic rank promotion broadcasting '
'("allow", "warn", or "raise").')
flags.DEFINE_bool(
'jax_enable_numpy_overrides',
strtobool(os.getenv('JAX_ENABLE_NUMPY_OVERRIDES', 'False')),
help=
"If true, enable overrides of NumPy functions on JAX arrays, e.g., "
"np.sum(jax_array) should return a JAX array rather than a NumPy array. "
"Also requires NumPy 1.17, or NumPy 1.16 with the environment variable "
"NUMPY_EXPERIMENTAL_ARRAY_FUNCTION=1.")

if six.PY3:
def removechars(s, chars):
Expand Down Expand Up @@ -111,13 +122,7 @@ def __init__(shape, dtype=None, buffer=None, offset=0, strides=None,
" Use jax.numpy.array, or jax.numpy.zeros instead.")
# pylint: enable=invalid-name


isscalar = onp.isscalar
iscomplexobj = onp.iscomplexobj
result_type = onp.result_type
shape = _shape = onp.shape
ndim = _ndim = onp.ndim
size = onp.size
_dtype = lax.dtype

bool_ = onp.bool_
Expand Down Expand Up @@ -154,14 +159,74 @@ def __init__(shape, dtype=None, buffer=None, offset=0, strides=None,

ComplexWarning = onp.ComplexWarning

array_str = onp.array_str
array_repr = onp.array_repr

save = onp.save
savez = onp.savez
load = onp.load


# These wrappers are necessary to avoid infinite recursion inside JAX's
# __array_function__ method. The implementations below exactly match NumPy.

@functools.wraps(onp.ndim)
def ndim(x):
if isinstance(x, _JAX_ARRAY_TYPES):
return x.ndim
return onp.ndim(x)

_ndim = ndim

@functools.wraps(onp.shape)
def shape(x):
if isinstance(x, _JAX_ARRAY_TYPES):
return x.shape
return onp.shape(x)

_shape = shape

@functools.wraps(onp.result_type)
def result_type(*arrays_and_dtypes):
arrays_and_dtypes = [x.dtype if isinstance(x, _JAX_ARRAY_TYPES)
else x for x in arrays_and_dtypes]
return onp.result_type(*arrays_and_dtypes)

@functools.wraps(onp.iscomplexobj)
def iscomplexobj(x):
if isinstance(x, _JAX_ARRAY_TYPES):
return issubclass(x.dtype.type, onp.complexfloating)
return onp.iscomplexobj(x)

@functools.wraps(onp.size)
def size(a, axis=None):
if isinstance(a, _JAX_ARRAY_TYPES):
if axis is None:
return a.size
else:
return a.shape[axis]
return onp.size(a, axis)

@functools.wraps(onp.array_str)
def array_str(a, max_line_width=None, precision=None, suppress_small=None):
return onp.array_str(onp.asarray(a), max_line_width=max_line_width,
precision=precision, suppress_small=suppress_small)

@functools.wraps(onp.array_repr)
def array_repr(arr, max_line_width=None, precision=None, suppress_small=None):
return onp.array_repr(onp.asarray(arr), max_line_width=max_line_width,
precision=precision, suppress_small=suppress_small)

@functools.wraps(onp.save)
def save(file, arr, allow_pickle=True, fix_imports=True):
return onp.save(file, onp.asarray(arr), allow_pickle=allow_pickle,
fix_imports=fix_imports)

def _cast_if_needed(x):
return onp.asarray(x) if isinstance(x, _JAX_ARRAY_TYPES) else x

@functools.wraps(onp.savez)
def savez(file, *args, **kwds):
args = tuple(map(_cast_if_needed, args))
kwds = dict(zip(kwds.keys(), map(_cast_if_needed, kwds.values())))
return onp.savez(file, *args, **kwds)


### utility functions

def _promote_shapes(fun_name, *args):
Expand Down Expand Up @@ -829,7 +894,8 @@ def isclose(a, b, rtol=1e-05, atol=1e-08):
else:
return lax.eq(a, b)

numpy_version = tuple(map(int, onp.version.version.split('.')))
# ignore suffixes in NumPy's development version, e.g., 1.17.0.dev0+cc8b978
numpy_version = tuple(map(int, onp.__version__.split('+')[0].split('.')[:3]))
if numpy_version < (1, 14):
# see discussion at https://github.com/numpy/numpy/pull/9720
def _maybe_numpy_1_13_isclose_behavior(a, out):
Expand Down Expand Up @@ -893,7 +959,7 @@ def broadcast_to(arr, shape):
# lax.broadcast and lax.transpose
lax.broadcast_shapes(shape, _shape(arr)) # error checking
nlead = len(shape) - len(_shape(arr))
diff, = onp.where(onp.not_equal(shape[nlead:], _shape(arr)))
diff, = onp.nonzero(onp.not_equal(shape[nlead:], _shape(arr)))

new_dims = tuple(range(nlead)) + tuple(nlead + diff)
kept_dims = tuple(onp.delete(onp.arange(len(shape)), new_dims))
Expand Down Expand Up @@ -2765,7 +2831,7 @@ def _expand_bool_indices(idx):
"argument to a jit or vmap function).")
raise IndexError(msg)
else:
out.extend(onp.where(i))
out.extend(onp.where(onp.asarray(i)))
else:
out.append(i)
return tuple(out)
Expand Down Expand Up @@ -3019,17 +3085,24 @@ def _astype(arr, dtype):

### track unimplemented functions

def _not_implemented(fun):
@_wraps(fun)
def wrapped(*args, **kwargs):
msg = "Numpy function {} not yet implemented"
raise NotImplementedError(msg.format(fun))
return wrapped
def _numpy_function_name(fun):
try:
name = "'{}.{}'".format(fun.__module__, fun.__name__)
except AttributeError:
name = "NumPy function {}".format(fun)
return name

class _NotImplementedByJAX(object):
def __init__(self, fun):
self._fun = fun
def __call__(self, *args, **kwargs):
name = _numpy_function_name(fun)
raise NotImplementedError("{} is not yet implemented by JAX".format(name))

# Build a set of all unimplemented NumPy functions.
for func in get_module_functions(onp):
if func.__name__ not in globals():
globals()[func.__name__] = _not_implemented(func)
globals()[func.__name__] = _NotImplementedByJAX(func)


### add method and operator overloads to arraylike classes
Expand Down Expand Up @@ -3129,6 +3202,117 @@ def _unimplemented_setitem(self, i, x):
setattr(DeviceArray, "astype", _astype)


# Override NumPy's public API.
_JAX_ARRAY_TYPES = (DeviceArray, ShapedArray, core.Tracer)
_HANDLED_TYPES = _JAX_ARRAY_TYPES + (onp.ndarray, numbers.Number)

def _implement_via_coercion(func, args, kwargs):
args, kwargs = tree_map(onp.asarray, device_get((args, kwargs)))
return func(*args, **kwargs)

def _coercion_warning(name, stacklevel=3):
warnings.warn("{} is not yet implemented by JAX; coercing arguments to "
"NumPy arrays. Coerce arrays with onp.asarray() or "
"jax.device_get() to explicitly silence this warning."
.format(name), stacklevel=stacklevel)

def _implement_with_warning(func, args, kwargs, name, stacklevel=3):
_coercion_warning(name, stacklevel)
return _implement_via_coercion(func, args, kwargs)


def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
"""Override NumPy ufuncs, per NEP-13."""

if not _all(isinstance(x, _HANDLED_TYPES) for x in inputs):
return NotImplemented

ufunc_method = getattr(ufunc, method)

lax_func = globals().get(ufunc.__name__)
if lax_func is None or isinstance(lax_func, _NotImplementedByJAX):
if FLAGS.jax_enable_numpy_overrides:
_coercion_warning(ufunc)
return _implement_via_coercion(ufunc_method, inputs, kwargs)

if method != '__call__':
if FLAGS.jax_enable_numpy_overrides:
name = '{!r} method of NumPy universal functions'.format(method)
_coercion_warning(name)
return _implement_via_coercion(ufunc_method, inputs, kwargs)

# We special case the 'out' argument so we can support in-place arithmetic
# that assigns to NumPy arrays, e.g., np_array += jax_device_array
out = kwargs.pop('out', None)
if out is not None:
if _any(isinstance(o, _JAX_ARRAY_TYPES) for o in out):
raise TypeError("JAX arrays cannot be modified inplace.")
inputs = device_get(inputs)
return ufunc_method(*inputs, out=out, **kwargs)

if kwargs:
kwargs_str = ', '.join(map(str, kwargs))
warnings.warn("{} keyword argument(s) are ignored when called on JAX "
"arrays: {}".format(ufunc, kwargs_str), stacklevel=2)
return lax_func(*inputs)


# These NumPy functions don't need coercion on JAX arrays. We can always use
# our JAX implementations.
_COERCION_BLACKLIST = frozenset([
onp.result_type,
onp.ndim,
onp.shape,
onp.iscomplexobj,
onp.reshape,
onp.size,
onp.transpose,
onp.moveaxis,
])


def __array_function__(self, func, types, args, kwargs):
"""Override other NumPy functions, per NEP-18."""
from .. import numpy

if not FLAGS.jax_enable_numpy_overrides and func not in _COERCION_BLACKLIST:
return _implement_via_coercion(func, args, kwargs)

if not _all(issubclass(t, _HANDLED_TYPES) for t in types):
return NotImplemented

modules = func.__module__.split('.')
if modules[0] != 'numpy':
return NotImplemented

module = numpy
for submodule in modules[1:]:
try:
module = getattr(module, submodule)
except AttributeError:
lax_func = None
break
else:
lax_func = getattr(module, func.__name__, None)

if lax_func is None or isinstance(lax_func, _NotImplementedByJAX):
return _implement_with_warning(
func, args, kwargs, name=_numpy_function_name(func), stacklevel=4)
elif lax_func is func:
raise AssertionError('{} needs to be defined with a wrapper that does not '
'call the corresponding NumPy function directly on '
'JAX arrays'.format(func))

return lax_func(*args, **kwargs)


setattr(DeviceArray, '__array_ufunc__', __array_ufunc__)
setattr(DeviceArray, '__array_function__', __array_function__)

setattr(ShapedArray, '__array_ufunc__', __array_ufunc__)
setattr(ShapedArray, '__array_function__', __array_function__)


# Extra methods that are handy
setattr(ShapedArray, "broadcast", core.aval_method(lax.broadcast))
setattr(ShapedArray, "broadcast_in_dim", core.aval_method(lax.broadcast_in_dim))
Expand Down
Loading