Skip to content

Commit

Permalink
fix: add ScalarType and treat bare strings as char arrays (#2116)
Browse files Browse the repository at this point in the history
  • Loading branch information
agoose77 authored Jan 13, 2023
1 parent 0030145 commit 71ada43
Show file tree
Hide file tree
Showing 16 changed files with 347 additions and 191 deletions.
26 changes: 8 additions & 18 deletions src/awkward/highlevel.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,12 +449,6 @@ def type(self):
def typestr(self):
"""
The high-level type of this Array, presented as a string.
Note that the outermost element of an Array's type is always an
#ak.types.ArrayType, which specifies the number of elements in the array.
The type of a #ak.contents.Content (from #ak.Array.layout) is not
wrapped by an #ak.types.ArrayType.
"""
return str(self.type)

Expand Down Expand Up @@ -1704,18 +1698,20 @@ def type(self):
"""
The high-level type of this Record; same as #ak.type.
Note that the outermost element of a Record's type is always a
#ak.types.RecordType.
Note that the outermost element of a Record's type is always an
#ak.types.ScalarType, which .
The type of a #ak.record.Record (from #ak.Array.layout) is not
wrapped by an #ak.types.ScalarType.
"""
return self._layout.array.form.type_from_behavior(self._behavior)
return ak.types.ScalarType(
self._layout.array.form.type_from_behavior(self._behavior)
)

@property
def typestr(self):
"""
The high-level type of this Record, presented as a string.
Note that the outermost element of a Record's type is always a
#ak.types.RecordType.
"""
return str(self.type)

Expand Down Expand Up @@ -2341,12 +2337,6 @@ def type(self):
def typestr(self):
"""
The high-level type of this accumulated array, presented as a string.
Note that the outermost element of an Array's type is always an
#ak.types.ArrayType, which specifies the number of elements in the array.
The type of a #ak.contents.Content (from #ak.Array.layout) is not
wrapped by an #ak.types.ArrayType.
"""
return str(self.type)

Expand Down
27 changes: 10 additions & 17 deletions src/awkward/operations/ak_to_layout.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# BSD 3-Clause License; see https://github.com/scikit-hep/awkward-1.0/blob/main/LICENSE

from collections.abc import Iterable

from awkward_cpp.lib import _ext

import awkward as ak
Expand Down Expand Up @@ -67,27 +69,15 @@ def _impl(array, allow_record, allow_other):
return array.snapshot()

elif numpy.is_own_array(array):
return _impl(
ak.operations.from_numpy(
array, regulararray=True, recordarray=True, highlevel=False
),
allow_record,
allow_other,
return ak.operations.from_numpy(
array, regulararray=True, recordarray=True, highlevel=False
)

elif ak._nplikes.Cupy.is_own_array(array):
return _impl(
ak.operations.from_cupy(array, regulararray=True, highlevel=False),
allow_record,
allow_other,
)
return ak.operations.from_cupy(array, regulararray=True, highlevel=False)

elif ak._nplikes.Jax.is_own_array(array):
return _impl(
ak.operations.from_jax(array, regulararray=True, highlevel=False),
allow_record,
allow_other,
)
return ak.operations.from_jax(array, regulararray=True, highlevel=False)

elif ak._typetracer.TypeTracer.is_own_array(array):
backend = ak._backends.TypeTracerBackend.instance()
Expand All @@ -104,7 +94,10 @@ def _impl(array, allow_record, allow_other):

return ak.contents.NumpyArray(array, parameters=None, backend=backend)

elif ak._util.is_non_string_like_iterable(array):
elif isinstance(array, (str, bytes)):
return ak.operations.from_iter([array], highlevel=False)[0]

elif isinstance(array, Iterable):
return _impl(
ak.operations.from_iter(array, highlevel=False),
allow_record,
Expand Down
95 changes: 45 additions & 50 deletions src/awkward/operations/ak_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,12 @@
np = ak._nplikes.NumpyMetadata.instance()


def type(array):
def type(array, *, behavior=None):
"""
Args:
array: Array-like data (anything #ak.to_layout recognizes).
behavior (None or dict): Custom #ak.behavior for the output type, if
high-level.
The high-level type of an `array` (many types supported, including all
Awkward Arrays and Records) as #ak.types.Type objects.
Expand Down Expand Up @@ -72,74 +74,67 @@ def type(array):
"""
with ak._errors.OperationErrorContext(
"ak.type",
dict(array=array),
dict(array=array, behavior=behavior),
):
return _impl(array)
return _impl(array, behavior)


def _impl(array):
if array is None:
return ak.types.UnknownType()

elif isinstance(array, np.dtype):
return ak.types.NumpyType(ak.types.numpytype.dtype_to_primitive(array))

elif (
isinstance(array, np.generic)
or isinstance(array, builtins.type)
and issubclass(array, np.generic)
):
primitive = ak.types.numpytype.dtype_to_primitive(np.dtype(array))
return ak.types.NumpyType(primitive)
def _impl(array, behavior):
if isinstance(array, np.dtype):
return ak.types.ScalarType(
ak.types.NumpyType(ak.types.numpytype.dtype_to_primitive(array))
)

elif isinstance(array, bool): # np.bool_ in np.generic (above)
return ak.types.NumpyType("bool")
return ak.types.ScalarType(ak.types.NumpyType("bool"))

elif isinstance(array, numbers.Integral):
return ak.types.NumpyType("int64")
return ak.types.ScalarType(ak.types.NumpyType("int64"))

elif isinstance(array, numbers.Real):
return ak.types.NumpyType("float64")
return ak.types.ScalarType(ak.types.NumpyType("float64"))

elif isinstance(array, numbers.Complex):
return ak.types.NumpyType("complex128")
return ak.types.ScalarType(ak.types.NumpyType("complex128"))

elif isinstance(array, datetime): # np.datetime64 in np.generic (above)
return ak.types.NumpyType("datetime64")
return ak.types.ScalarType(ak.types.NumpyType("datetime64"))

elif isinstance(array, timedelta): # np.timedelta64 in np.generic (above)
return ak.types.NumpyType("timedelta")

elif isinstance(
array,
(
ak.highlevel.Array,
ak.highlevel.Record,
ak.highlevel.ArrayBuilder,
),
):
return array.type
return ak.types.ScalarType(ak.types.NumpyType("timedelta"))

elif isinstance(array, np.ndarray):
if len(array.shape) == 0:
return _impl(array.reshape((1,))[0])
else:
primitive = ak.types.numpytype.dtype_to_primitive(array.dtype)
out = ak.types.NumpyType(primitive)
for x in array.shape[-1:0:-1]:
out = ak.types.RegularType(out, x)
return ak.types.ArrayType(out, array.shape[0])
elif (
isinstance(array, np.generic)
or isinstance(array, builtins.type)
and issubclass(array, np.generic)
):
primitive = ak.types.numpytype.dtype_to_primitive(np.dtype(array))
return ak.types.ScalarType(ak.types.NumpyType(primitive))

elif isinstance(array, _ext.ArrayBuilder):
form = ak.forms.from_json(array.form())
return ak.types.ArrayType(form.type_from_behavior(None), len(array))
return ak.types.ArrayType(form.type_from_behavior(behavior), len(array))

elif isinstance(array, ak.record.Record):
return array.array.form.type

elif isinstance(array, ak.contents.Content):
return array.form.type
elif isinstance(array, ak.ArrayBuilder):
behavior = ak._util.behavior_of(array, behavior=behavior)
form = ak.forms.from_json(array._layout.form())
return ak.types.ArrayType(form.type_from_behavior(behavior), len(array._layout))

else:
layout = ak.to_layout(array, allow_other=False)
return _impl(ak._util.wrap(layout))
behavior = ak._util.behavior_of(array, behavior=behavior)
layout = ak.to_layout(array, allow_other=True, allow_record=True)
if layout is None:
return ak.types.ScalarType(ak.types.UnknownType())

elif isinstance(layout, ak.record.Record):
return ak.types.ScalarType(layout.array.form.type_from_behavior(behavior))

elif isinstance(layout, ak.contents.Content):
return ak.types.ArrayType(
layout.form.type_from_behavior(behavior), layout.length
)

else:
raise ak._errors.wrap_error(
TypeError(f"unrecognized array type: {array!r}")
)
1 change: 1 addition & 0 deletions src/awkward/types/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from awkward.types.optiontype import OptionType # noqa: F401
from awkward.types.recordtype import RecordType # noqa: F401
from awkward.types.regulartype import RegularType # noqa: F401
from awkward.types.scalartype import ScalarType # noqa: F401
from awkward.types.type import Type, from_datashape # noqa: F401
from awkward.types.uniontype import UnionType # noqa: F401
from awkward.types.unknowntype import UnknownType # noqa: F401
40 changes: 40 additions & 0 deletions src/awkward/types/scalartype.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# BSD 3-Clause License; see https://github.com/scikit-hep/awkward-1.0/blob/main/LICENSE

import sys

import awkward as ak


class ScalarType:
def __init__(self, content):
if not isinstance(content, ak.types.Type):
raise ak._errors.wrap_error(
TypeError(
"{} all 'contents' must be Type subclasses, not {}".format(
type(self).__name__, repr(content)
)
)
)
self._content = content

@property
def content(self):
return self._content

def __str__(self):
return "".join(self._str("", True))

def show(self, stream=sys.stdout):
stream.write("".join(self._str("", False) + ["\n"]))

def _str(self, indent, compact):
return self._content._str(indent, compact)

def __repr__(self):
return f"{type(self).__name__}({self._content!r})"

def __eq__(self, other):
if isinstance(other, ScalarType):
return self._content == other._content
else:
return False
14 changes: 8 additions & 6 deletions tests/test_0021_emptyarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,21 @@ def test_unknown():
e = ak.contents.EmptyArray()
a = ak.contents.ListOffsetArray(i, e)
assert to_list(a) == [[], [], []]
assert str(ak.operations.type(a)) == "var * unknown"
assert ak.operations.type(a) == ak.types.ListType(ak.types.UnknownType())
assert not ak.operations.type(a) == ak.types.NumpyType("float64")
assert str(ak.operations.type(a)) == "3 * var * unknown"
assert ak.operations.type(a) == ak.types.ArrayType(
ak.types.ListType(ak.types.UnknownType()), 3
)
assert ak.operations.type(a) != ak.types.ArrayType(ak.types.NumpyType("float64"), 3)

i = ak.index.Index64(np.array([0, 0, 0, 0, 0, 0], dtype=np.int64))
ii = ak.index.Index64(np.array([0, 0, 2, 5], dtype=np.int64))
a = ak.contents.ListOffsetArray(i, e)
a = ak.contents.ListOffsetArray(ii, a)

assert to_list(a) == [[], [[], []], [[], [], []]]
assert str(ak.operations.type(a)) == "var * var * unknown"
assert ak.operations.type(a) == ak.types.ListType(
ak.types.ListType(ak.types.UnknownType())
assert str(ak.operations.type(a)) == "3 * var * var * unknown"
assert ak.operations.type(a) == ak.types.ArrayType(
ak.types.ListType(ak.types.ListType(ak.types.UnknownType())), 3
)


Expand Down
4 changes: 2 additions & 2 deletions tests/test_0023_regular_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,11 @@


def test_simple_type():
assert str(ak.operations.type(content)) == "float64"
assert str(ak.operations.type(content)) == "10 * float64"


def test_type():
assert str(ak.operations.type(regulararray)) == "2 * var * float64"
assert str(ak.operations.type(regulararray)) == "3 * 2 * var * float64"


def test_iteration():
Expand Down
Loading

0 comments on commit 71ada43

Please sign in to comment.