Skip to content

Commit

Permalink
NumpyArray::numbers_to_type must use flattened_length, not length.
Browse files Browse the repository at this point in the history
  • Loading branch information
jpivarski committed Dec 6, 2021
1 parent 7897564 commit 6c8db14
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 13 deletions.
28 changes: 15 additions & 13 deletions src/libawkward/array/NumpyArray.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5053,9 +5053,11 @@ namespace awkward {
ssize_t itemsize = util::dtype_to_itemsize(dtype);
std::vector<ssize_t> shape = contiguous_self.shape();
std::vector<ssize_t> strides;
int64_t flattened_length = 1;
for (int64_t j = (int64_t)shape.size(); j > 0; j--) {
strides.insert(strides.begin(), itemsize);
itemsize *= shape[(size_t)(j - 1)];
flattened_length *= shape[(size_t)(j - 1)];
}

IdentitiesPtr identities = contiguous_self.identities();
Expand All @@ -5066,47 +5068,47 @@ namespace awkward {
switch (dtype_) {
case util::dtype::boolean:
ptr = as_type<bool>(reinterpret_cast<bool*>(contiguous_self.data()),
contiguous_self.length(),
flattened_length,
dtype);
break;
case util::dtype::int8:
ptr = as_type<int8_t>(reinterpret_cast<int8_t*>(contiguous_self.data()),
contiguous_self.length(),
flattened_length,
dtype);
break;
case util::dtype::int16:
ptr = as_type<int16_t>(reinterpret_cast<int16_t*>(contiguous_self.data()),
contiguous_self.length(),
flattened_length,
dtype);
break;
case util::dtype::int32:
ptr = as_type<int32_t>(reinterpret_cast<int32_t*>(contiguous_self.data()),
contiguous_self.length(),
flattened_length,
dtype);
break;
case util::dtype::int64:
ptr = as_type<int64_t>(reinterpret_cast<int64_t*>(contiguous_self.data()),
contiguous_self.length(),
flattened_length,
dtype);
break;
case util::dtype::uint8:
ptr = as_type<uint8_t>(reinterpret_cast<uint8_t*>(contiguous_self.data()),
contiguous_self.length(),
flattened_length,
dtype);
break;
case util::dtype::uint16:
ptr = as_type<uint16_t>(reinterpret_cast<uint16_t*>(contiguous_self.data()),
contiguous_self.length(),
flattened_length,
dtype);
break;
case util::dtype::uint32:
ptr = as_type<uint32_t>(reinterpret_cast<uint32_t*>(contiguous_self.data()),
contiguous_self.length(),
flattened_length,
dtype);
break;
case util::dtype::uint64:
ptr = as_type<uint64_t>(reinterpret_cast<uint64_t*>(contiguous_self.data()),
contiguous_self.length(),
flattened_length,
dtype);
break;
case util::dtype::float16:
Expand All @@ -5116,12 +5118,12 @@ namespace awkward {
break;
case util::dtype::float32:
ptr = as_type<float>(reinterpret_cast<float*>(contiguous_self.data()),
contiguous_self.length(),
flattened_length,
dtype);
break;
case util::dtype::float64:
ptr = as_type<double>(reinterpret_cast<double*>(contiguous_self.data()),
contiguous_self.length(),
flattened_length,
dtype);
break;
case util::dtype::float128:
Expand All @@ -5131,12 +5133,12 @@ namespace awkward {
break;
case util::dtype::complex64:
ptr = as_type<std::complex<float>>(reinterpret_cast<std::complex<float>*>(contiguous_self.data()),
contiguous_self.length(),
flattened_length,
dtype);
break;
case util::dtype::complex128:
ptr = as_type<std::complex<double>>(reinterpret_cast<std::complex<double>*>(contiguous_self.data()),
contiguous_self.length(),
flattened_length,
dtype);
break;
case util::dtype::complex256:
Expand Down
15 changes: 15 additions & 0 deletions tests/test_1173-numbers_to_type-length.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# BSD 3-Clause License; see https://github.com/scikit-hep/awkward-1.0/blob/main/LICENSE

from __future__ import absolute_import

import pytest # noqa: F401
import numpy as np # noqa: F401
import awkward as ak # noqa: F401


def test():
assert ak.to_list(
ak.layout.NumpyArray(np.array([[1, 2], [3, 4]], np.int64)).numbers_to_type(
"int16"
)
) == [[1, 2], [3, 4]]

0 comments on commit 6c8db14

Please sign in to comment.