Skip to content

Commit

Permalink
Merge pull request #2746 from spectre-ns/reshape
Browse files Browse the repository at this point in the history
[WIP] Make reshape_view accept -1 as a wildcard dimension
  • Loading branch information
JohanMabille authored Nov 16, 2023
2 parents 5f49f64 + aaa819e commit 69eaac5
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 15 deletions.
65 changes: 60 additions & 5 deletions include/xtensor/xstrided_view.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -807,6 +807,53 @@ namespace xt
);
}

namespace detail
{
template <typename S>
struct rebind_shape;

template <std::size_t... X>
struct rebind_shape<xt::fixed_shape<X...>>
{
using type = xt::fixed_shape<X...>;
};

template <class S>
struct rebind_shape
{
using type = rebind_container_t<size_t, S>;
};

template <
class S,
std::enable_if_t<std::is_signed<get_value_type_t<typename std::decay<S>::type>>::value, bool> = true>
inline void recalculate_shape_impl(S& shape, size_t size)
{
using value_type = get_value_type_t<typename std::decay_t<S>>;
XTENSOR_ASSERT(std::count(shape.cbegin(), shape.cend(), -1) <= 1);
auto iter = std::find(shape.begin(), shape.end(), -1);
if (iter != std::end(shape))
{
const auto total = std::accumulate(shape.cbegin(), shape.cend(), -1, std::multiplies<int>{});
const auto missing_dimension = size / total;
(*iter) = static_cast<value_type>(missing_dimension);
}
}

template <
class S,
std::enable_if_t<!std::is_signed<get_value_type_t<typename std::decay<S>::type>>::value, bool> = true>
inline void recalculate_shape_impl(S&, size_t)
{
}

template <class S>
inline auto recalculate_shape(S&& shape, size_t size)
{
return recalculate_shape_impl(shape, size);
}
}

template <layout_type L = XTENSOR_DEFAULT_TRAVERSAL, class E, class S>
inline auto reshape_view(E&& e, S&& shape)
{
Expand All @@ -815,18 +862,26 @@ namespace xt
"traversal has to be row or column major"
);

using shape_type = std::decay_t<S>;
get_strides_t<shape_type> strides;
using shape_type = std::decay_t<decltype(shape)>;
using unsigned_shape_type = typename detail::rebind_shape<shape_type>::type;
get_strides_t<unsigned_shape_type> strides;

detail::recalculate_shape(shape, e.size());
xt::resize_container(strides, shape.size());
compute_strides(shape, L, strides);
constexpr auto computed_layout = std::decay_t<E>::static_layout == L ? L : layout_type::dynamic;
using view_type = xstrided_view<
xclosure_t<E>,
shape_type,
unsigned_shape_type,
computed_layout,
detail::flat_adaptor_getter<xclosure_t<E>, L>>;
return view_type(std::forward<E>(e), std::forward<S>(shape), std::move(strides), 0, e.layout());
return view_type(
std::forward<E>(e),
xtl::forward_sequence<unsigned_shape_type, S>(shape),
std::move(strides),
0,
e.layout()
);
}

/**
Expand Down Expand Up @@ -858,7 +913,7 @@ namespace xt
template <layout_type L = XTENSOR_DEFAULT_TRAVERSAL, class E, class I, std::size_t N>
inline auto reshape_view(E&& e, const I (&shape)[N])
{
using shape_type = std::array<std::size_t, N>;
using shape_type = std::array<I, N>;
return reshape_view<L>(std::forward<E>(e), xtl::forward_sequence<shape_type, decltype(shape)>(shape));
}
}
Expand Down
27 changes: 17 additions & 10 deletions test/test_xstrided_view.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -696,24 +696,31 @@ namespace xt
EXPECT_EQ(av, e);
EXPECT_EQ(av, a);

bool truthy;
truthy = std::is_same<
typename decltype(xv)::temporary_type,
xtensor_fixed<double, xshape<3, 3>, XTENSOR_DEFAULT_LAYOUT>>();
EXPECT_TRUE(truthy);

truthy = std::is_same<typename decltype(av)::temporary_type, xtensor<double, 2, XTENSOR_DEFAULT_LAYOUT>>(
static_assert(
std::is_same<
typename decltype(xv)::temporary_type,
xtensor_fixed<double, xshape<3, 3>, XTENSOR_DEFAULT_LAYOUT>>::value,
"Container types do not match"
);
static_assert(
std::is_same<typename decltype(av)::temporary_type, xtensor<double, 2, XTENSOR_DEFAULT_LAYOUT>>::value,
"Container types do not match"
);
static_assert(
std::is_same<typename decltype(av)::shape_type, typename decltype(e)::shape_type>::value,
"Shape types do not match"
);
EXPECT_TRUE(truthy);
truthy = std::is_same<typename decltype(av)::shape_type, typename decltype(e)::shape_type>::value;
EXPECT_TRUE(truthy);

xarray<int> xa = {{1, 2, 3}, {4, 5, 6}};
std::vector<std::size_t> new_shape = {3, 2};
auto xrv = reshape_view(xa, new_shape);

xarray<int> xres = {{1, 2}, {3, 4}, {5, 6}};
EXPECT_EQ(xrv, xres);

auto nv = xt::reshape_view<XTENSOR_DEFAULT_LAYOUT>(a, {-1, 3});
std::vector<size_t> expected_shape({3, 3});
EXPECT_TRUE(std::equal(nv.shape().begin(), nv.shape().end(), expected_shape.begin()));
}

TEST(xstrided_view, reshape_view_assign)
Expand Down

0 comments on commit 69eaac5

Please sign in to comment.