Skip to content

Commit

Permalink
Update BLAS2 GER operator (#505)
Browse files Browse the repository at this point in the history
This patch introduces a new implementation for the BLAS2 GER operator.
  • Loading branch information
pgorlani authored Apr 15, 2024
1 parent e5f9738 commit f067e58
Show file tree
Hide file tree
Showing 5 changed files with 334 additions and 66 deletions.
102 changes: 55 additions & 47 deletions include/interface/blas2_interface.h
Original file line number Diff line number Diff line change
Expand Up @@ -176,34 +176,38 @@ typename sb_handle_t::event_t _symv(
);

/*!
@brief Generalised vector product followed by a sum with a rectangular
non-symmetric matrix.
Generalised vector product followed by a sum with a rectangular non-symmetric
matrix, i.e. computing the mathematical operation:
A = alpha*x*yT + A
See the netlib blas interface documentation for more details of the high level
interface: http://www.netlib.org/lapack/explore-html/db/d5c/sger_8f.html
* @brief Generalised vector product followed by a sum with a rectangular
* non-symmetric matrix.
*
* Generalised vector product followed by a sum with a rectangular non-symmetric
* matrix, i.e. computing the mathematical operation:
*
* A = alpha*x*yT + A
*
* See the netlib blas interface documentation for more details of the high
* level interface:
* http://www.netlib.org/lapack/explore-html/db/d5c/sger_8f.html
*
* @param sb_handle SB_handle
* @param _M Number of rows in matrix A
* @param _N Number of columns in matrix A
* @param _alpha Scalar alpha
* @param _vx Input vector having (1 + (_M-1)*abs(_incx)) elements
* @param _incx Increment for vector X
* @param _vy, Input vector having having (1 + (_N-1)*abs(_incy)) elements
* @param _incy Increment for vector Y
* @param _mA Input/output matrix A(_lda, n)
* @param _lda Leading dimension of A
* @param _dependencies Vector of events
*/
template <typename sb_handle_t, typename index_t, typename element_t,
typename container_0_t, typename increment_t, typename container_1_t,
typename container_2_t>
typename sb_handle_t::event_t _ger(
sb_handle_t& sb_handle, // sb_handle_t (sycl, parallel, serial, etc)
index_t _M, // The rows in matrix A
index_t _N, // The cols of matrix A
element_t _alpha, // Scalar alpha
container_0_t _vx, // >(1 + (_M-1)*abs(_incx)), input vector X
increment_t _incx, // Increment for vector X
container_1_t _vy, // >(1 + (_N-1)*abs(_incy)), input vector Y
increment_t _incy, // Increment for vector Y
container_2_t _mA, // (_lda, n) array containing A, the output
index_t _lda, // >max(1, m), Leading dimension of A
const typename sb_handle_t::event_t& _dependencies // Vector of events
);
sb_handle_t& sb_handle, index_t _M, index_t _N, element_t _alpha,
container_0_t _vx, increment_t _incx, container_1_t _vy, increment_t _incy,
container_2_t _mA, index_t _lda,
const typename sb_handle_t::event_t& _dependencies);

/*!
@brief Generalised vector squaring followed by a sum with a symmetric matrix.
Expand Down Expand Up @@ -746,35 +750,39 @@ typename sb_handle_t::event_t inline _symv(
}

/*!
@brief Generalised vector product followed by a sum with a rectangular
non-symmetric matrix.
Generalised vector product followed by a sum with a rectangular non-symmetric
matrix, i.e.
computing the mathematical operation:
A = alpha*x*yT + A
See the netlib blas interface documentation for more details of the high level
interface: http://www.netlib.org/lapack/explore-html/db/d5c/sger_8f.html
* @brief Generalised vector product followed by a sum with a rectangular
* non-symmetric matrix.
*
* Generalised vector product followed by a sum with a rectangular non-symmetric
* matrix, i.e.
* computing the mathematical operation:
*
* A = alpha*x*yT + A
*
* See the netlib blas interface documentation for more details of the high
* level interface:
* http://www.netlib.org/lapack/explore-html/db/d5c/sger_8f.html
*
* @param sb_handle SB_handle
* @param _M Number of rows in matrix A
* @param _N Number of columns in matrix A
* @param _alpha Scalar alpha
* @param _vx Input vector having (1 + (_M-1)*abs(_incx)) elements
* @param _incx Increment for vector X
* @param _vy, Input vector having having (1 + (_N-1)*abs(_incy)) elements
* @param _incy Increment for vector Y
* @param _mA Input/output matrix A(_lda, n)
* @param _lda Leading dimension of A
* @param _dependencies Vector of events
*/
template <typename sb_handle_t, typename index_t, typename element_t,
typename container_0_t, typename increment_t, typename container_1_t,
typename container_2_t>
typename sb_handle_t::event_t inline _ger(
sb_handle_t& sb_handle, // sb_handle_t (sycl, parallel, serial, etc)
index_t _M, // The rows in matrix M
index_t _N, // The rows of matrix N
element_t _alpha, // Scalar alpha
container_0_t _vx, // >(1 + (_M-1)*abs(_incx)), input vector X
increment_t _incx, // Increment for vector X
container_1_t _vy, // >(1 + (_N-1)*abs(_incy)), input vector Y
increment_t _incy, // Increment for vector Y
container_2_t _mA, // (_lda, n) array containing A, the output
index_t _lda, // >max(1, m), Leading dimension of A
const typename sb_handle_t::event_t& _dependencies = {} // Vector of events
) {
sb_handle_t& sb_handle, index_t _M, index_t _N, element_t _alpha,
container_0_t _vx, increment_t _incx, container_1_t _vy, increment_t _incy,
container_2_t _mA, index_t _lda,
const typename sb_handle_t::event_t& _dependencies = {}) {
return internal::_ger(sb_handle, _M, _N, _alpha, _vx, _incx, _vy, _incy, _mA,
_lda, _dependencies);
}
Expand Down
58 changes: 58 additions & 0 deletions include/operations/blas2_trees.h
Original file line number Diff line number Diff line change
Expand Up @@ -502,6 +502,64 @@ make_trsv(vector_t &lhs_, matrix_t &matrix_, sync_t &sync_) {
subgroups, is_upper, is_transposed, is_unit>(lhs_, matrix_, k_,
sync_);
}
/**
* @struct Ger
* @brief Tree node representing the sum of scalar-vector-vector product with a
* matrix, i.e., it computes lhs_ such that
*
* lhs_ = scalar_ * ( rhs_1_ * rhs_2_^t ) + lhs_
*
* @param lhs_ input/output matrix
* @param scalar_ value for scaling vector product
* @param rhs_1_ first input vector
* @param rhs_2_ second input vector
* @param nRowsWG_ rows of the workgroup tile
* @param nColsWG_ cols of the workgroup tile
* @param nWG_row_ number of tiles per global size row
* @param nWG_col_ number of tiles per global size column
*
*/
template <typename lhs_t, typename rhs_1_t, typename rhs_2_t>
struct Ger {
using value_t = typename rhs_2_t::value_t;
using index_t = typename rhs_2_t::index_t;

lhs_t lhs_;
value_t scalar_;
rhs_1_t rhs_1_;
rhs_2_t rhs_2_;
index_t nRowsWG_;
index_t nColsWG_;
index_t nWG_row_;
index_t nWG_col_;

Ger(lhs_t &_l, value_t _scl, rhs_1_t &_r1, rhs_2_t &_r2, index_t &_nRowsWG,
index_t &_nColsWG, index_t &_nWG_row, index_t &_nWG_col);

index_t get_size() const;
bool valid_thread(cl::sycl::nd_item<1> ndItem) const;
value_t eval(index_t i);
value_t eval(cl::sycl::nd_item<1> ndItem);
template <typename sharedT>
value_t eval(sharedT shrMem, cl::sycl::nd_item<1> ndItem);
void bind(cl::sycl::handler &h);
void adjust_access_displacement();
};

/*!
@brief Generator/factory for GER trees.
*/
template <typename lhs_t, typename rhs_1_t, typename rhs_2_t>
Ger<lhs_t, rhs_1_t, rhs_2_t> make_ger(lhs_t &lhs_,
typename lhs_t::value_t scalar_,
rhs_1_t &rhs_1_, rhs_2_t &rhs_2_,
typename rhs_2_t::index_t nRowsWG_,
typename rhs_2_t::index_t nColsWG_,
typename rhs_2_t::index_t nWG_row_,
typename rhs_2_t::index_t nWG_col_) {
return Ger<lhs_t, rhs_1_t, rhs_2_t>(lhs_, scalar_, rhs_1_, rhs_2_, nRowsWG_,
nColsWG_, nWG_row_, nWG_col_);
}

/**** GER BY ROWS M ROWS x N BLOCK USING PROPERLY THE SHARED MEMORY ****/
// template <typename lhs_t,typename rhs_1_t,typename rhs_2_t>
Expand Down
67 changes: 51 additions & 16 deletions src/interface/blas2_interface.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -878,7 +878,7 @@ typename sb_handle_t::event_t _ger_impl(
container_t0 _vx, increment_t _incx, container_t1 _vy, increment_t _incy,
container_t2 _mA, index_t _lda,
const typename sb_handle_t::event_t& _dependencies, index_t _localSize = 0,
index_t _scratchPadSize = 0, index_t _nRowsWG = 0, index_t _nColsWG = 0) {
bool _useLocalMem = true, index_t _nRowsWG = 0, index_t _nColsWG = 0) {
index_t M = _M;
index_t N = _N;
auto mA = make_matrix_view<col_major>(_mA, M, N, _lda);
Expand All @@ -887,24 +887,39 @@ typename sb_handle_t::event_t _ger_impl(
typename VectorViewType<container_t1, index_t, increment_t>::type vy =
make_vector_view(_vy, _incy, N);

const index_t localSize =
(_localSize == 0) ? sb_handle.get_work_group_size() : _localSize;
const index_t nRowsWG = (_nRowsWG == 0) ? localSize : std::min(M, _nRowsWG);
_localSize = (_localSize == 0) ? sb_handle.get_work_group_size() : _localSize;
_nRowsWG = (_nRowsWG == 0) ? _localSize : _nRowsWG;
_nColsWG = (_nColsWG == 0) ? _localSize : _nColsWG;

const index_t nColsWG = (_nColsWG == 0) ? localSize : std::min(N, _nColsWG);
assert(_localSize % _nRowsWG == 0);
assert((_nRowsWG * _nColsWG) % _localSize == 0);
assert(_nColsWG % (_localSize / _nRowsWG) == 0);

const index_t scratchPadSize =
(_localSize == 0) ? localSize : _scratchPadSize;
if (_useLocalMem) {
assert((_nRowsWG <= _localSize) && (_nColsWG <= _localSize));
} else {
std::vector<size_t> subgroup_sizes =
sb_handle.get_queue()
.get_device()
.template get_info<cl::sycl::info::device::sub_group_sizes>();
size_t min_subgroup_size = *subgroup_sizes.begin();
size_t max_subgroup_size = *subgroup_sizes.rbegin();
assert(((_nRowsWG * _nColsWG) / _localSize) <= min_subgroup_size);
assert(_nRowsWG % max_subgroup_size == 0);
}

const index_t nWGPerCol = (N - 1) / nColsWG + 1;
const index_t nWGPerRow = (M - 1) / nRowsWG + 1;
const index_t globalSize = localSize * nWGPerRow * nWGPerCol;
const index_t nWGPerCol = (N - 1) / _nColsWG + 1;
const index_t nWGPerRow = (M - 1) / _nRowsWG + 1;
const index_t globalSize = _localSize * nWGPerRow * nWGPerCol;

typename sb_handle_t::event_t ret;
auto assignOp =
make_ger_col(mA, _alpha, vx, vy, nWGPerRow, nWGPerCol, scratchPadSize);
return sb_handle.execute(assignOp, localSize, globalSize, scratchPadSize,
_dependencies);
make_ger(mA, _alpha, vx, vy, _nRowsWG, _nColsWG, nWGPerRow, nWGPerCol);

return _useLocalMem ? sb_handle.execute(assignOp, _localSize, globalSize,
_nRowsWG + _nColsWG, _dependencies)
: sb_handle.execute(assignOp, _localSize, globalSize,
_dependencies);
}

/*! _SYR.
Expand Down Expand Up @@ -1280,10 +1295,30 @@ typename sb_handle_t::event_t inline _ger(
container_t0 _vx, increment_t _incx, container_t1 _vy, increment_t _incy,
container_t2 _mA, index_t _lda,
const typename sb_handle_t::event_t& _dependencies) {
// TODO: Here we can use some heuristics to select localn global, local, and
// scratch size per device
index_t localSize = 0;
bool useLocalMem = true;
index_t nRowsWG = 0;
index_t nColsWG = 0;

#if defined(INTEL_GPU)
localSize = 32;
useLocalMem = false;
nRowsWG = 32;
nColsWG = 8;
#elif defined(NVIDIA_GPU)
localSize = 256;
useLocalMem = (_N < 8192 && _M < 8192) ? false : true;
nRowsWG = 32;
nColsWG = 32;
#elif defined(AMD_GPU)
localSize = (_N < 8192 && _M < 8192) ? 512 : 256;
useLocalMem = (_N < 8192 && _M < 8192) ? false : true;
nRowsWG = (_N < 8192 && _M < 8192) ? 64 : 128;
nColsWG = (_N < 8192 && _M < 8192) ? 64 : 256;
#endif

return _ger_impl(sb_handle, _M, _N, _alpha, _vx, _incx, _vy, _incy, _mA, _lda,
_dependencies);
_dependencies, localSize, useLocalMem, nRowsWG, nColsWG);
}

template <typename sb_handle_t, typename index_t, typename element_t,
Expand Down
Loading

0 comments on commit f067e58

Please sign in to comment.