Skip to content

Commit

Permalink
Support assignment from TensorMap wrapping a const type to a `Ten…
Browse files Browse the repository at this point in the history
…sor` storing non-`const`

* Add test that copy-assigns to a `Tensor` from a `TensorMap` wrapping an array of `const`-qualified data.
* Missing inline declaration for a `maskstore` overload causing "multiple definition of `Fastor::maskstore`" errors.
* Use destination type (i.e. non-const) to construct `SIMDVector` in `eval` functions, otherwise we get "assignment of read-only location" errors.
* Don't `const`-qualify destination tensor in `assign`, otherwise overload resolution of called `trivial_assign` fails.
  • Loading branch information
feltech committed May 30, 2020
1 parent 87225e3 commit d1274da
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 7 deletions.
2 changes: 2 additions & 0 deletions Fastor/simd_vector/simd_vector_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -599,6 +599,8 @@ void maskstore(std::complex<double> * FASTOR_RESTRICT a, const int (&maska)[4],
_mm256_maskstore_pd(reinterpret_cast<double*>(a ), (__m256i) mask0, lo);
_mm256_maskstore_pd(reinterpret_cast<double*>(a+2), (__m256i) mask1, hi);
}
template<>
FASTOR_INLINE
void maskstore(std::complex<float> * FASTOR_RESTRICT a, const int (&maska)[8], SIMDVector<std::complex<float>,simd_abi::avx> &v) {
// Split the mask in to a higher and lower part - we need two masks for this
__m256i mask0 = _mm256_set_epi32(maska[4],maska[4],maska[5],maska[5],maska[6],maska[6],maska[7],maska[7]);
Expand Down
12 changes: 6 additions & 6 deletions Fastor/tensor/TensorEvaluator.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
// Expression templates evaluators
//----------------------------------------------------------------------------------------------------------//
template<typename U=T>
FASTOR_INLINE SIMDVector<T,simd_abi_type> eval(FASTOR_INDEX i) const {
SIMDVector<T,simd_abi_type> _vec;
FASTOR_INLINE SIMDVector<U,simd_abi_type> eval(FASTOR_INDEX i) const {
SIMDVector<U,simd_abi_type> _vec;
_vec.load(&_data[get_mem_index(i)],false);
return _vec;
}
Expand All @@ -14,8 +14,8 @@ FASTOR_INLINE T eval_s(FASTOR_INDEX i) const {
return _data[get_mem_index(i)];
}
template<typename U=T>
FASTOR_INLINE SIMDVector<T,simd_abi_type> eval(FASTOR_INDEX i, FASTOR_INDEX j) const {
SIMDVector<T,simd_abi_type> _vec;
FASTOR_INLINE SIMDVector<U,simd_abi_type> eval(FASTOR_INDEX i, FASTOR_INDEX j) const {
SIMDVector<U,simd_abi_type> _vec;
_vec.load(&_data[get_flat_index(i,j)],false);
return _vec;
}
Expand All @@ -25,8 +25,8 @@ FASTOR_INLINE T eval_s(FASTOR_INDEX i, FASTOR_INDEX j) const {
}

template<typename U=T>
FASTOR_INLINE SIMDVector<T,simd_abi_type> teval(const std::array<int, dimension_t::value> &as) const {
SIMDVector<T,simd_abi_type> _vec;
FASTOR_INLINE SIMDVector<U,simd_abi_type> teval(const std::array<int, dimension_t::value> &as) const {
SIMDVector<U,simd_abi_type> _vec;
_vec.load(&_data[get_flat_index(as)],false);
return _vec;
}
Expand Down
2 changes: 1 addition & 1 deletion Fastor/tensor/TensorMap.h
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ FASTOR_MAKE_OS_STREAM_TENSORn(TensorMap)


template<typename Derived, size_t DIM, typename T, size_t ...Rest>
FASTOR_INLINE void assign(const AbstractTensor<Derived,DIM> &dst, const TensorMap<T,Rest...> &src) {
FASTOR_INLINE void assign(AbstractTensor<Derived,DIM> &dst, const TensorMap<T,Rest...> &src) {
if (dst.self().data()==src.data()) return;
trivial_assign(dst.self(),src);
}
Expand Down
10 changes: 10 additions & 0 deletions tests/test_tensormap/test_tensormap.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,16 @@ void run() {
FASTOR_EXIT_ASSERT(abs(a.sum() - ma.sum()) < Tol);
}

// Map a const array and copy-assign it to a non-const tensor.
{
const T data[] = {1, 2, 3};
TensorMap<const T, 3> mdata{data};
Tensor<T, 3> tdata = mdata;
Tensor<T, 3> check{1, 2, 3};

FASTOR_EXIT_ASSERT(all_of(tdata == check));
}

print(FGRN(BOLD("All tests passed successfully")));
}

Expand Down

0 comments on commit d1274da

Please sign in to comment.