Skip to content

Commit

Permalink
count identifier
Browse files Browse the repository at this point in the history
  • Loading branch information
ianna committed Nov 11, 2021
1 parent c36d9d6 commit 21428f5
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 2 deletions.
4 changes: 3 additions & 1 deletion src/awkward/_v2/contents/content.py
Original file line number Diff line number Diff line change
Expand Up @@ -1014,7 +1014,9 @@ def validityerror(self, path="layout"):
return self._validityerror(path)

def nbytes(self):
    """Return the total number of bytes recorded for this layout.

    A fresh accumulator dict is passed to ``self._nbytes_part``, which
    records a byte count per buffer key in it (in place). The result is
    the sum of every recorded value.
    """
    # The pre-change ``return self._nbytes_part()`` line is dropped: it
    # made the accumulator logic below unreachable and called
    # ``_nbytes_part`` without the ``largest`` argument it requires.
    # The ``{0: 0}`` seed adds nothing to the sum.
    largest = {0: 0}
    self._nbytes_part(largest)
    return sum(largest.values())

def purelist_parameter(self, key):
return self.Form.purelist_parameter(self, key)
Expand Down
6 changes: 6 additions & 0 deletions src/awkward/_v2/contents/numpyarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1004,6 +1004,12 @@ def _rpad(self, target, axis, depth, clip):
else:
return self.rpad_axis0(target, clip=True)

def _nbytes_part(self, largest):
    """Record this array's buffer size in the shared ``largest`` dict.

    ``largest`` maps a buffer key to the biggest ``nbytes`` seen so far
    for that key and is updated in place. If an identifier is attached,
    it is asked to record itself in the same dict. Returns ``None``.
    """
    # NOTE(review): the key is id() of whatever ``self.ptr`` returns. If
    # ``ptr`` yields a fresh object on every access, that id may not be a
    # stable buffer key -- confirm against the ``ptr`` definition.
    key = id(self.ptr)
    size = self.data.nbytes

    recorded = largest.get(key)
    if recorded is None or recorded < size:
        largest[key] = size

    identifier = self.identifier
    if identifier is not None:
        identifier._nbytes_part(largest)

def _to_arrow(self, pyarrow, mask_node, validbytes, length, options):
if self._data.ndim != 1:
Expand Down
5 changes: 5 additions & 0 deletions src/awkward/_v2/identifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,3 +115,8 @@ def referentially_equal(self, other):
and self._data.strides == other._data.strides
and self._data.dtype == other._data.dtype
)

def _nbytes_part(self, largest):
    """Record this identifier's data size in the shared ``largest`` dict.

    The entry is keyed by ``id(self.ref)`` and only grows: it is written
    when the key is new or when ``self.data.nbytes`` exceeds the value
    already stored. ``largest`` is mutated in place; returns ``None``.
    """
    key = id(self.ref)
    size = self.data.nbytes

    recorded = largest.get(key)
    if recorded is None or recorded < size:
        largest[key] = size
18 changes: 17 additions & 1 deletion tests/v2/test_0927-numpy-array-nbytes.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,31 @@
import numpy as np # noqa: F401
import awkward as ak # noqa: F401

from awkward._v2.tmp_for_testing import v1_to_v2, v1_to_v2_index
from awkward._v2.tmp_for_testing import v1_to_v2

# Module-level pytest mark: skip every test in this file when
# ``ak._util.py27`` is truthy.
pytestmark = pytest.mark.skipif(
ak._util.py27, reason="No Python 2.7 support in Awkward 2.x"
)


def test():
    """A converted v2 layout reports the same nbytes as its NumPy source."""
    # 4 rows of float64 values; the shape arithmetic works out to
    # 100 * 1024 * 1024 bytes in total.
    shape = (4, 100 * 1024 * 1024 // 8 // 4)
    np_data = np.random.random(size=shape)

    v1_array = ak.from_numpy(np_data, regulararray=False)
    layout = v1_to_v2(v1_array.layout)

    assert np_data.nbytes == layout.nbytes()


def test_NumpyArray():
    """nbytes() of a NumpyArray includes the bytes of its Identifier."""
    identifier = ak._v2.identifier.Identifier.zeros(
        123, {1: "one", 2: "two"}, 5, 10, np, np.int64
    )
    # Expected identifier size: 8-byte np.int64 entries times the 5 and 10
    # passed to ``zeros`` above.
    identifier_nbytes = 8 * 5 * 10

    # The identifier alone records exactly its own bytes.
    seen = {0: 0}
    identifier._nbytes_part(seen)
    assert sum(seen.values()) == identifier_nbytes

    # With the identifier attached, the array reports data plus identifier.
    np_data = np.random.random(size=(4, 100 * 1024 * 1024 // 8 // 4))
    array = ak._v2.contents.numpyarray.NumpyArray(np_data, identifier)
    assert array.nbytes() == np_data.nbytes + identifier_nbytes

0 comments on commit 21428f5

Please sign in to comment.