diff --git a/include/xtensor/xstrided_view.hpp b/include/xtensor/xstrided_view.hpp index a65f9467a..8352d5f5c 100644 --- a/include/xtensor/xstrided_view.hpp +++ b/include/xtensor/xstrided_view.hpp @@ -53,6 +53,33 @@ namespace xt using xstrided_view_base_t = typename xstrided_view_base::type; } + namespace detail + { + template > + struct get_linear_iterator : std::false_type + { + using iterator = typename C::iterator; + }; + + template + struct get_linear_iterator().linear_begin())>> : std::true_type + { + using iterator = typename C::linear_iterator; + }; + + template > + struct get_const_linear_iterator : std::false_type + { + using iterator = typename C::const_iterator; + }; + + template + struct get_const_linear_iterator().linear_cbegin())>> : std::true_type + { + using iterator = typename C::const_linear_iterator; + }; + } + template struct select_iterable_base { @@ -153,10 +180,13 @@ namespace xt using inner_storage_type = typename base_type::inner_storage_type; using storage_type = typename base_type::storage_type; - using linear_iterator = typename storage_type::iterator; - using const_linear_iterator = typename storage_type::const_iterator; - using reverse_linear_iterator = std::reverse_iterator; - using const_reverse_linear_iterator = std::reverse_iterator; + + using linear_iterator = typename detail::get_linear_iterator::iterator; + using const_linear_iterator = typename detail::get_const_linear_iterator::iterator; + using reverse_linear_iterator = std::reverse_iterator< + typename detail::get_linear_iterator::iterator>; + using const_reverse_linear_iterator = std::reverse_iterator< + typename detail::get_const_linear_iterator::iterator>; using iterable_base = select_iterable_base_t; using inner_shape_type = typename base_type::inner_shape_type; @@ -223,7 +253,6 @@ namespace xt const_linear_iterator linear_end() const; const_linear_iterator linear_cbegin() const; const_linear_iterator linear_cend() const; - reverse_linear_iterator linear_rbegin(); reverse_linear_iterator linear_rend(); const_reverse_linear_iterator linear_rbegin() const; @@ -487,13 +516,32 @@ namespace xt template inline auto xstrided_view::linear_begin() -> linear_iterator { - return this->storage().begin() + static_cast(data_offset()); + return xtl::mpl::static_if::value>( + [&](auto self) + { + return self(this->storage()).linear_begin() + static_cast(data_offset()); + }, + [&](auto self) + { + return self(this->storage()).begin() + static_cast(data_offset()); + } + ); } template inline auto xstrided_view::linear_end() -> linear_iterator { - return this->storage().begin() + static_cast(data_offset() + size()); + return xtl::mpl::static_if::value>( + [&](auto self) + { + return self(this->storage()).linear_begin() + + static_cast(data_offset() + size()); + }, + [&](auto self) + { + return self(this->storage()).begin() + static_cast(data_offset() + size()); + } + ); } template @@ -511,13 +559,31 @@ namespace xt template inline auto xstrided_view::linear_cbegin() const -> const_linear_iterator { - return this->storage().cbegin() + static_cast(data_offset()); + return xtl::mpl::static_if::value>( + [&](auto self) + { + return self(this->storage()).linear_cbegin() + static_cast(data_offset()); + }, + [&](auto self) + { + return self(this->storage()).cbegin() + static_cast(data_offset()); + } + ); } template inline auto xstrided_view::linear_cend() const -> const_linear_iterator { - return this->storage().cbegin() + static_cast(data_offset() + size()); + return xtl::mpl::static_if::value>( + [&](auto self) + { + return self(this->storage()).linear_cend() + static_cast(data_offset()); + }, + [&](auto self) + { + return self(this->storage()).cend() + static_cast(data_offset()); + } + ); } template diff --git a/include/xtensor/xstrided_view_base.hpp b/include/xtensor/xstrided_view_base.hpp index 9f3f8ef4c..533e6d83e 100644 --- a/include/xtensor/xstrided_view_base.hpp +++ b/include/xtensor/xstrided_view_base.hpp @@ -75,6 +75,52 @@ namespace xt size_type m_size; }; + template + class linear_flat_expression_adaptor : public flat_expression_adaptor + { + public: + + using xexpression_type = std::decay_t; + using shape_type = typename xexpression_type::shape_type; + using inner_strides_type = get_strides_t; + using index_type = inner_strides_type; + using size_type = typename xexpression_type::size_type; + using value_type = typename xexpression_type::value_type; + using const_reference = typename xexpression_type::const_reference; + using reference = std::conditional_t< + std::is_const>::value, + typename xexpression_type::const_reference, + typename xexpression_type::reference>; + + + using linear_iterator = decltype(std::declval>().linear_begin()); + using const_linear_iterator = decltype(std::declval>().linear_cbegin()); + using reverse_linear_iterator = decltype(std::declval>().linear_rbegin() + ); + using const_reverse_linear_iterator = decltype(std::declval>().linear_crbegin()); + + + explicit linear_flat_expression_adaptor(CT* e); + + template + linear_flat_expression_adaptor(CT* e, FST&& strides); + + linear_iterator linear_begin(); + linear_iterator linear_end(); + const_linear_iterator linear_begin() const; + const_linear_iterator linear_end() const; + const_linear_iterator linear_cbegin() const; + const_linear_iterator linear_cend() const; + + private: + + static index_type& get_index(); + + mutable CT* m_e; + inner_strides_type m_strides; + size_type m_size; + }; + template struct is_flat_expression_adaptor : std::false_type { @@ -85,9 +131,21 @@ namespace xt { }; + template + struct is_linear_flat_expression_adaptor : std::false_type + { + }; + + template + struct is_linear_flat_expression_adaptor> : std::true_type + { + }; + template - struct provides_data_interface - : xtl::conjunction>, xtl::negation>> + struct provides_data_interface : xtl::conjunction< + has_data_interface>, + xtl::negation>, + xtl::negation>> { }; } @@ -246,7 +304,11 @@ namespace xt template struct flat_adaptor_getter { - using type = flat_expression_adaptor, L>; + using type = std::conditional_t< + detail::has_linear_iterator>::value + && (std::remove_reference_t::static_layout == L), + linear_flat_expression_adaptor, L>, + flat_expression_adaptor, L>>; using reference = std::add_lvalue_reference_t; template @@ -318,9 +380,7 @@ namespace xt layout_type layout ) noexcept : m_e(std::forward(e)) - , - // m_storage(detail::get_flat_storage(m_e)), - m_storage(storage_getter::get_flat_storage(m_e)) + , m_storage(storage_getter::get_flat_storage(m_e)) , m_shape(std::forward(shape)) , m_strides(std::move(strides)) , m_offset(offset) @@ -345,6 +405,14 @@ namespace xt new_storage.update_pointer(std::addressof(expr)); return new_storage; } + + template + auto copy_move_storage(T& expr, const detail::linear_flat_expression_adaptor& storage) + { + detail::linear_flat_expression_adaptor new_storage = storage; // copy storage + new_storage.update_pointer(std::addressof(expr)); + return new_storage; + } } template @@ -652,7 +720,7 @@ namespace xt template inline bool xstrided_view_base::has_linear_assign(const O& str) const noexcept { - return has_data_interface::value && str.size() == strides().size() + return detail::has_linear_iterator::value && str.size() == strides().size() && std::equal(str.cbegin(), str.cend(), strides().begin()); } @@ -783,6 +851,58 @@ namespace xt thread_local static index_type index; return index; } + + template + inline linear_flat_expression_adaptor::linear_flat_expression_adaptor(CT* e) + : flat_expression_adaptor(e) + , m_e(e) + { + } + + template + template + inline linear_flat_expression_adaptor::linear_flat_expression_adaptor(CT* e, FST&& strides) + : flat_expression_adaptor(e, strides) + , m_e(e) + , m_strides(xtl::forward_sequence(strides)) + { + } + + template + inline auto linear_flat_expression_adaptor::linear_begin() -> linear_iterator + { + return m_e->linear_begin(); + } + + template + inline auto linear_flat_expression_adaptor::linear_end() -> linear_iterator + { + return m_e->linear_end(); + } + + template + inline auto linear_flat_expression_adaptor::linear_begin() const -> const_linear_iterator + { + return m_e->linear_cbegin(); + } + + template + inline auto linear_flat_expression_adaptor::linear_end() const -> const_linear_iterator + { + return m_e->linear_cend(); + } + + template + inline auto linear_flat_expression_adaptor::linear_cbegin() const -> const_linear_iterator + { + return m_e->linear_cbegin(); + } + + template + inline auto linear_flat_expression_adaptor::linear_cend() const -> const_linear_iterator + { + return m_e->linear_cend(); + } } /**********************************