Skip to content

Commit

Permalink
Enhancement in permutation
Browse files Browse the repository at this point in the history
  • Loading branch information
romeric committed May 29, 2020
1 parent 039dc2f commit 97179bb
Show file tree
Hide file tree
Showing 3 changed files with 142 additions and 96 deletions.
86 changes: 83 additions & 3 deletions Fastor/meta/einsum_meta.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#define EINSUM_META_H

#include "Fastor/tensor/Tensor.h"
#include <array>

namespace Fastor {

Expand Down Expand Up @@ -875,7 +876,8 @@ IndexTensors<Index<Idx...>,Derived0<T,Rest...>,Index<Idx_t...>,Derived1<T,Rest_t



//---------------------------------------------------------------------------------
namespace internal {
//------------------------------------------------------------------------------------------------------------//
template<class arg>
struct meta_argmin_wrapper;
template<size_t ...rest>
Expand Down Expand Up @@ -928,9 +930,87 @@ struct tmp_argsort<Index<value>,Index<ss>> {
using reduced_argseq = Index<ss>;
using new_argseq = Index<ss>;
};
//---------------------------------------------------------------------------------
//------------------------------------------------------------------------------------------------------------//


// Permutation functions
//------------------------------------------------------------------------------------------------------------//
template<size_t N>
constexpr size_t count_less(const size_t (&seq)[N], size_t i, size_t cur = 0) {
return cur == N ? 0 : (count_less(seq, i, cur + 1) + (seq[cur] < i ? 1 : 0));
}

/* Check if a compile time array is sequential */
template<size_t N>
constexpr bool is_sequential(const size_t (&seq)[N], size_t i=0) {
return i+1 == N ? true : ( seq[i] + 1 == seq[i+1] ? is_sequential(seq, i+1) : false ) ;
}
template<size_t N>
constexpr bool is_sequential(const std::array<size_t,N> &seq, size_t i=0) {
return i+1 == N ? true : ( seq[i] + 1 == seq[i+1] ? is_sequential(seq, i+1) : false ) ;
}

// permutation helper class
template<class Idx, class Tens, class Seq>
struct permute_impl;

template<typename T, size_t ... ls, size_t ... fs, size_t... ss>
struct permute_impl<Index<ls...>, Tensor<T, fs...>, std_ext::index_sequence<ss...>> {
constexpr static size_t lst[sizeof...(ls)] = { ls... };
constexpr static size_t fvals[sizeof...(ls)] = {fs...};
using resulting_tensor = Tensor<T,fvals[count_less(lst, lst[ss])]...>;
using resulting_index = typename tmp_argsort<Index<ls...>,Index<ss...>>::new_argseq;
using maxes_out_type = Index<fvals[tmp_argsort<Index<ls...>,Index<ss...>>::new_argseq::values[ss]]...>;
static constexpr bool requires_permutation = !(is_same_v_<resulting_tensor,Tensor<T, fs...>> &&
is_sequential(resulting_index::values));
};

// permute helper class
template<class Idx, class Tens, class Seq>
struct new_permute_impl;

template<typename T, size_t ... ls, size_t ... fs, size_t... ss>
struct new_permute_impl<Index<ls...>, Tensor<T, fs...>, std_ext::index_sequence<ss...>> {
constexpr static size_t lst[sizeof...(ls)] = { ls... };
constexpr static size_t fvals[sizeof...(ls)] = {fs...};
using resulting_tensor = Tensor<T,fvals[count_less(lst, lst[ss])]...>;
constexpr static size_t aranger[sizeof...(ss)] = { ss... };
using resulting_index = Index<aranger[count_less(lst, lst[ss])]...>;
static constexpr bool requires_permutation = !(is_same_v_<resulting_tensor,Tensor<T, fs...>> &&
is_sequential(resulting_index::values));
};

} // internal


template<class Idx, class Tens>
struct requires_permutation;
template<typename T, size_t ... Idx, size_t ... Rest>
struct requires_permutation<Index<Idx...>, Tensor<T, Rest...>> {
using _permute_impl = internal::permute_impl<Index<Idx...>, Tensor<T,Rest...>,
typename std_ext::make_index_sequence<sizeof...(Idx)>::type>;
static constexpr bool value = _permute_impl::requires_permutation;
};

// helper
template<class Idx, class Tens>
constexpr bool requires_permutation_v = requires_permutation<Idx,Tens>::value;


template<class Idx, class Tens>
struct requires_permute;
template<typename T, size_t ... Idx, size_t ... Rest>
struct requires_permute<Index<Idx...>, Tensor<T, Rest...>> {
using _permute_impl = internal::new_permute_impl<Index<Idx...>, Tensor<T,Rest...>,
typename std_ext::make_index_sequence<sizeof...(Idx)>::type>;
static constexpr bool value = _permute_impl::requires_permutation;
};

// helper
template<class Idx, class Tens>
constexpr bool requires_permute_v = requires_permute<Idx,Tens>::value;
//------------------------------------------------------------------------------------------------------------//

//}

}

Expand Down
105 changes: 39 additions & 66 deletions Fastor/tensor_algebra/permutation.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,26 +11,6 @@ namespace Fastor {

namespace internal {

template<size_t N>
constexpr size_t count_less(const size_t (&seq)[N], size_t i, size_t cur = 0) {
return cur == N ? 0 : (count_less(seq, i, cur + 1) + (seq[cur] < i ? 1 : 0));
}

template<typename T, class List, class Tensor, class Seq>
struct permute_impl;

template<typename T, size_t ... ls, size_t ... fs, size_t... ss>
struct permute_impl<T,Index<ls...>, Tensor<T, fs...>, std_ext::index_sequence<ss...>>{
constexpr static size_t lst[sizeof...(ls)] = { ls... };
constexpr static size_t fvals[sizeof...(ls)] = {fs...};
using type = Tensor<T,fvals[count_less(lst, lst[ss])]...>;
using index_type = typename tmp_argsort<Index<ls...>,Index<ss...>>::new_argseq;
using maxes_out_type = Index<fvals[tmp_argsort<Index<ls...>,Index<ss...>>::new_argseq::values[ss]]...>;
};




template<class Idx, class Tens, size_t ... Args>
struct RecursiveCartesianPerm;

Expand All @@ -51,15 +31,15 @@ struct RecursiveCartesianPerm<Index<Idx...>, Tensor<T,Rest...>, First, Lasts...>
template<typename T, size_t Last, size_t ...Idx, size_t ...Rest>
struct RecursiveCartesianPerm<Index<Idx...>, Tensor<T,Rest...>,Last>
{
using OutTensor = typename permute_impl<T,Index<Idx...>, Tensor<T,Rest...>,
typename std_ext::make_index_sequence<sizeof...(Idx)>::type>::type;
using maxes_out_type = typename permute_impl<T,Index<Idx...>, Tensor<T,Rest...>,
using OutTensor = typename permute_impl<Index<Idx...>, Tensor<T,Rest...>,
typename std_ext::make_index_sequence<sizeof...(Idx)>::type>::resulting_tensor;
using maxes_out_type = typename permute_impl<Index<Idx...>, Tensor<T,Rest...>,
typename std_ext::make_index_sequence<sizeof...(Idx)>::type>::maxes_out_type;
using index_type = typename permute_impl<T,Index<Idx...>, Tensor<T,Rest...>,
typename std_ext::make_index_sequence<sizeof...(Idx)>::type>::index_type;
static constexpr std::array<size_t,sizeof...(Rest)> maxes_idx = permute_impl<T,Index<Idx...>, Tensor<T,Rest...>,
using index_type = typename permute_impl<Index<Idx...>, Tensor<T,Rest...>,
typename std_ext::make_index_sequence<sizeof...(Idx)>::type>::resulting_index;
static constexpr std::array<size_t,sizeof...(Rest)> maxes_idx = permute_impl<Index<Idx...>, Tensor<T,Rest...>,
typename std_ext::make_index_sequence<sizeof...(Idx)>::type>::index_type::values;
static constexpr std::array<size_t,sizeof...(Rest)> maxes_out = permute_impl<T,Index<Idx...>, Tensor<T,Rest...>,
static constexpr std::array<size_t,sizeof...(Rest)> maxes_out = permute_impl<Index<Idx...>, Tensor<T,Rest...>,
typename std_ext::make_index_sequence<sizeof...(Idx)>::type>::maxes_out_type::values;

static constexpr int a_dim = sizeof...(Rest);
Expand Down Expand Up @@ -115,15 +95,15 @@ constexpr std::array<size_t,sizeof...(Rest)> RecursiveCartesianPerm<Index<Idx...
// template<typename T, size_t ...Idx, size_t ...Rest>
// struct RecursiveCartesianPerm<Index<Idx...>, Tensor<T,Rest...>>
// {
// using OutTensor = typename permute_impl<T,Index<Idx...>, Tensor<T,Rest...>,
// using OutTensor = typename permute_impl<Index<Idx...>, Tensor<T,Rest...>,
// typename std_ext::make_index_sequence<sizeof...(Idx)>::type>::type;
// using maxes_out_type = typename permute_impl<T,Index<Idx...>, Tensor<T,Rest...>,
// using maxes_out_type = typename permute_impl<Index<Idx...>, Tensor<T,Rest...>,
// typename std_ext::make_index_sequence<sizeof...(Idx)>::type>::maxes_out_type;
// using index_type = typename permute_impl<T,Index<Idx...>, Tensor<T,Rest...>,
// using index_type = typename permute_impl<Index<Idx...>, Tensor<T,Rest...>,
// typename std_ext::make_index_sequence<sizeof...(Idx)>::type>::index_type;
// static constexpr auto maxes_idx = permute_impl<T,Index<Idx...>, Tensor<T,Rest...>,
// static constexpr auto maxes_idx = permute_impl<Index<Idx...>, Tensor<T,Rest...>,
// typename std_ext::make_index_sequence<sizeof...(Idx)>::type>::index_type::values;
// static constexpr auto maxes_out = permute_impl<T,Index<Idx...>, Tensor<T,Rest...>,
// static constexpr auto maxes_out = permute_impl<Index<Idx...>, Tensor<T,Rest...>,
// typename std_ext::make_index_sequence<sizeof...(Idx)>::type>::maxes_out_type::values;

// static constexpr int a_dim = sizeof...(Rest);
Expand Down Expand Up @@ -190,22 +170,22 @@ struct extractor_perm<Index<Idx...> > {
template<typename T, size_t ... Rest>
static
FASTOR_INLINE
typename permute_impl<T,Index<Idx...>, Tensor<T,Rest...>,
typename permute_impl<Index<Idx...>, Tensor<T,Rest...>,
typename std_ext::make_index_sequence<sizeof...(Idx)>::type>::type
permutation_impl(const Tensor<T,Rest...> &a) {

using _permute_impl = permute_impl<Index<Idx...>, Tensor<T,Rest...>,
typename std_ext::make_index_sequence<sizeof...(Idx)>::type>;
using resulting_index = typename _permute_impl::resulting_index;
using resulting_tensor = typename _permute_impl::resulting_tensor;
constexpr bool requires_permutation = _permute_impl::requires_permutation;
FASTOR_IF_CONSTEXPR(!requires_permutation) return a;

#if CONTRACT_OPT==-1

using OutTensor = typename permute_impl<T,Index<Idx...>, Tensor<T,Rest...>,
typename std_ext::make_index_sequence<sizeof...(Idx)>::type>::type;
using maxes_out_type = typename permute_impl<T,Index<Idx...>, Tensor<T,Rest...>,
typename std_ext::make_index_sequence<sizeof...(Idx)>::type>::maxes_out_type;
using index_type = typename permute_impl<T,Index<Idx...>, Tensor<T,Rest...>,
typename std_ext::make_index_sequence<sizeof...(Idx)>::type>::index_type;
constexpr auto& maxes_idx = permute_impl<T,Index<Idx...>, Tensor<T,Rest...>,
typename std_ext::make_index_sequence<sizeof...(Idx)>::type>::index_type::values;
constexpr auto& maxes_out = permute_impl<T,Index<Idx...>, Tensor<T,Rest...>,
typename std_ext::make_index_sequence<sizeof...(Idx)>::type>::maxes_out_type::values;
using maxes_out_type = typename _permute_impl::maxes_out_type;
constexpr auto& maxes_idx = resulting_index::values;
constexpr auto& maxes_out = maxes_out_type::values;

constexpr int a_dim = sizeof...(Rest);
constexpr int out_dim = a_dim;
Expand All @@ -216,7 +196,7 @@ struct extractor_perm<Index<Idx...> > {
constexpr auto& products_out = nprods<maxes_out_type,
typename std_ext::make_index_sequence<a_dim>::type>::values;

OutTensor out;
resulting_tensor out;
out.zeros();

T *a_data = a.data();
Expand Down Expand Up @@ -253,10 +233,7 @@ struct extractor_perm<Index<Idx...> > {

#else

using OutTensor = typename permute_impl<T,Index<Idx...>, Tensor<T,Rest...>,
typename std_ext::make_index_sequence<sizeof...(Idx)>::type>::type;

OutTensor out;
resulting_tensor out;
out.zeros();

T *a_data = a.data();
Expand Down Expand Up @@ -287,24 +264,20 @@ struct extractor_perm<Index<Idx...> > {
enable_if_t_<!requires_evaluation_v<Derived>,bool> = false>
static
FASTOR_INLINE
typename permute_impl<typename scalar_type_finder<Derived>::type,
typename permute_impl<
Index<Idx...>, typename tensor_type_finder<Derived>::type,
typename std_ext::make_index_sequence<sizeof...(Idx)>::type>::type
typename std_ext::make_index_sequence<sizeof...(Idx)>::type>::resulting_tensor
permutation_impl(const AbstractTensor<Derived,DIMS> &a) {

using T = typename scalar_type_finder<Derived>::type;
using tensor_type = typename tensor_type_finder<Derived>::type;

using OutTensor = typename permute_impl<T,Index<Idx...>, tensor_type,
typename std_ext::make_index_sequence<sizeof...(Idx)>::type>::type;
using maxes_out_type = typename permute_impl<T,Index<Idx...>, tensor_type,
using resulting_tensor = typename permute_impl<Index<Idx...>, tensor_type,
typename std_ext::make_index_sequence<sizeof...(Idx)>::type>::resulting_tensor;
using maxes_out_type = typename permute_impl<Index<Idx...>, tensor_type,
typename std_ext::make_index_sequence<sizeof...(Idx)>::type>::maxes_out_type;
// using index_type = typename permute_impl<T,Index<Idx...>, tensor_type,
// typename std_ext::make_index_sequence<sizeof...(Idx)>::type>::index_type;
constexpr auto& maxes_idx = permute_impl<T,Index<Idx...>, tensor_type,
typename std_ext::make_index_sequence<sizeof...(Idx)>::type>::index_type::values;
// constexpr auto& maxes_out = permute_impl<T,Index<Idx...>, tensor_type,
// typename std_ext::make_index_sequence<sizeof...(Idx)>::type>::maxes_out_type::values;
constexpr auto& maxes_idx = permute_impl<Index<Idx...>, tensor_type,
typename std_ext::make_index_sequence<sizeof...(Idx)>::type>::resulting_index::values;

constexpr int a_dim = DIMS;
constexpr int out_dim = a_dim;
Expand All @@ -315,7 +288,7 @@ struct extractor_perm<Index<Idx...> > {
constexpr auto& products_out = nprods<maxes_out_type,
typename std_ext::make_index_sequence<a_dim>::type>::values;

OutTensor out;
resulting_tensor out;
out.zeros();
T *out_data = out.data();
const Derived & a_src = a.self();
Expand Down Expand Up @@ -357,28 +330,28 @@ struct extractor_perm<Index<Idx...> > {

template<class Index_I, typename T, size_t ... Rest>
FASTOR_INLINE
typename internal::permute_impl<T,Index_I, Tensor<T,Rest...>,
typename std_ext::make_index_sequence<sizeof...(Rest)>::type>::type
typename internal::permute_impl<Index_I, Tensor<T,Rest...>,
typename std_ext::make_index_sequence<sizeof...(Rest)>::type>::resulting_tensor
permutation(const Tensor<T, Rest...> &a) {
return internal::extractor_perm<Index_I>::permutation_impl(a);
}

template<class Index_I, typename Derived, size_t DIMS,
enable_if_t_<!requires_evaluation_v<Derived>,bool> = false>
FASTOR_INLINE
typename internal::permute_impl<typename scalar_type_finder<Derived>::type,Index_I,
typename internal::permute_impl<Index_I,
typename Derived::result_type,
typename std_ext::make_index_sequence<DIMS>::type>::type
typename std_ext::make_index_sequence<DIMS>::type>::resulting_tensor
permutation(const AbstractTensor<Derived, DIMS> &a) {
return internal::extractor_perm<Index_I>::permutation_impl(a);
}

template<class Index_I, typename Derived, size_t DIMS,
enable_if_t_<requires_evaluation_v<Derived>,bool> = false>
FASTOR_INLINE
typename internal::permute_impl<typename scalar_type_finder<Derived>::type,Index_I,
typename internal::permute_impl<Index_I,
typename Derived::result_type,
typename std_ext::make_index_sequence<DIMS>::type>::type
typename std_ext::make_index_sequence<DIMS>::type>::resulting_tensor
permutation(const AbstractTensor<Derived, DIMS> &a) {
using result_type = typename Derived::result_type;
const result_type tmp(a);
Expand Down
Loading

0 comments on commit 97179bb

Please sign in to comment.