Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update BLAS2 GER operator #505

Merged
merged 4 commits into from
Apr 15, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 55 additions & 47 deletions include/interface/blas2_interface.h
Original file line number Diff line number Diff line change
Expand Up @@ -176,34 +176,38 @@ typename sb_handle_t::event_t _symv(
);

/*!
@brief Generalised vector product followed by a sum with a rectangular
non-symmetric matrix.

Generalised vector product followed by a sum with a rectangular non-symmetric
matrix, i.e. computing the mathematical operation:

A = alpha*x*yT + A

See the netlib blas interface documentation for more details of the high level
interface: http://www.netlib.org/lapack/explore-html/db/d5c/sger_8f.html

* @brief Generalised vector product followed by a sum with a rectangular
* non-symmetric matrix.
*
* Generalised vector product followed by a sum with a rectangular non-symmetric
* matrix, i.e. computing the mathematical operation:
*
* A = alpha*x*yT + A
*
* See the netlib blas interface documentation for more details of the high
* level interface:
* http://www.netlib.org/lapack/explore-html/db/d5c/sger_8f.html
*
* @param sb_handle SB_handle
* @param _M Number of rows in matrix A
* @param _N Number of columns in matrix A
* @param _alpha Scalar alpha
* @param _vx Input vector having (1 + (_M-1)*abs(_incx)) elements
* @param _incx Increment for vector X
* @param _vy Input vector having (1 + (_N-1)*abs(_incy)) elements
* @param _incy Increment for vector Y
* @param _mA Input/output matrix A(_lda, n)
* @param _lda Leading dimension of A
* @param _dependencies Vector of events
*/
// NOTE(review): this span carries the parameter list twice -- first the
// per-parameter annotated form, then the compact form that closes the
// declaration. It reads like both sides of a diff hunk; confirm which one the
// header actually keeps.
template <typename sb_handle_t, typename index_t, typename element_t,
typename container_0_t, typename increment_t, typename container_1_t,
typename container_2_t>
typename sb_handle_t::event_t _ger(
sb_handle_t& sb_handle, // sb_handle_t (sycl, parallel, serial, etc)
index_t _M, // The rows in matrix A
index_t _N, // The cols of matrix A
element_t _alpha, // Scalar alpha
container_0_t _vx, // >(1 + (_M-1)*abs(_incx)), input vector X
increment_t _incx, // Increment for vector X
container_1_t _vy, // >(1 + (_N-1)*abs(_incy)), input vector Y
increment_t _incy, // Increment for vector Y
container_2_t _mA, // (_lda, n) array containing A, the output
index_t _lda, // >max(1, m), Leading dimension of A
const typename sb_handle_t::event_t& _dependencies // Vector of events
);
sb_handle_t& sb_handle, index_t _M, index_t _N, element_t _alpha,
container_0_t _vx, increment_t _incx, container_1_t _vy, increment_t _incy,
container_2_t _mA, index_t _lda,
const typename sb_handle_t::event_t& _dependencies);

/*!
@brief Generalised vector squaring followed by a sum with a symmetric matrix.
Expand Down Expand Up @@ -746,35 +750,39 @@ typename sb_handle_t::event_t inline _symv(
}

/*!
@brief Generalised vector product followed by a sum with a rectangular
non-symmetric matrix.

Generalised vector product followed by a sum with a rectangular non-symmetric
matrix, i.e.
computing the mathematical operation:

A = alpha*x*yT + A

See the netlib blas interface documentation for more details of the high level
interface: http://www.netlib.org/lapack/explore-html/db/d5c/sger_8f.html

* @brief Generalised vector product followed by a sum with a rectangular
* non-symmetric matrix.
*
* Generalised vector product followed by a sum with a rectangular non-symmetric
* matrix, i.e.
* computing the mathematical operation:
*
* A = alpha*x*yT + A
*
* See the netlib blas interface documentation for more details of the high
* level interface:
* http://www.netlib.org/lapack/explore-html/db/d5c/sger_8f.html
*
* @param sb_handle SB_handle
* @param _M Number of rows in matrix A
* @param _N Number of columns in matrix A
* @param _alpha Scalar alpha
* @param _vx Input vector having (1 + (_M-1)*abs(_incx)) elements
* @param _incx Increment for vector X
* @param _vy Input vector having (1 + (_N-1)*abs(_incy)) elements
* @param _incy Increment for vector Y
* @param _mA Input/output matrix A(_lda, n)
* @param _lda Leading dimension of A
* @param _dependencies Vector of events
*/
// Public GER entry point: forwards every argument unchanged to
// internal::_ger and returns its event.
// NOTE(review): this span carries the parameter list twice -- first the
// per-parameter annotated form, then the compact form that opens the body.
// It reads like both sides of a diff hunk; confirm which one the header
// actually keeps.
template <typename sb_handle_t, typename index_t, typename element_t,
typename container_0_t, typename increment_t, typename container_1_t,
typename container_2_t>
typename sb_handle_t::event_t inline _ger(
sb_handle_t& sb_handle, // sb_handle_t (sycl, parallel, serial, etc)
index_t _M, // Number of rows in matrix A
index_t _N, // Number of columns in matrix A
element_t _alpha, // Scalar alpha
container_0_t _vx, // >(1 + (_M-1)*abs(_incx)), input vector X
increment_t _incx, // Increment for vector X
container_1_t _vy, // >(1 + (_N-1)*abs(_incy)), input vector Y
increment_t _incy, // Increment for vector Y
container_2_t _mA, // (_lda, n) array containing A, the output
index_t _lda, // >max(1, m), Leading dimension of A
const typename sb_handle_t::event_t& _dependencies = {} // Vector of events
) {
sb_handle_t& sb_handle, index_t _M, index_t _N, element_t _alpha,
container_0_t _vx, increment_t _incx, container_1_t _vy, increment_t _incy,
container_2_t _mA, index_t _lda,
const typename sb_handle_t::event_t& _dependencies = {}) {
return internal::_ger(sb_handle, _M, _N, _alpha, _vx, _incx, _vy, _incy, _mA,
_lda, _dependencies);
}
Expand Down
58 changes: 58 additions & 0 deletions include/operations/blas2_trees.h
Original file line number Diff line number Diff line change
Expand Up @@ -502,6 +502,64 @@ make_trsv(vector_t &lhs_, matrix_t &matrix_, sync_t &sync_) {
subgroups, is_upper, is_transposed, is_unit>(lhs_, matrix_, k_,
sync_);
}
/**
* @struct Ger
* @brief Tree node representing the sum of scalar-vector-vector product with a
* matrix, i.e., it computes lhs_ such that
*
* lhs_ = scalar_ * ( rhs_1_ * rhs_2_^t ) + lhs_
*
* @tparam lhs_t type of the input/output matrix operand
* @tparam rhs_1_t type of the first input vector operand
* @tparam rhs_2_t type of the second input vector operand; it also supplies
* the value_t and index_t aliases used by this node
*
* @param lhs_ input/output matrix
* @param scalar_ value for scaling vector product
* @param rhs_1_ first input vector
* @param rhs_2_ second input vector
* @param nRowsWG_ rows of the workgroup tile
* @param nColsWG_ cols of the workgroup tile
* @param nWG_row_ number of tiles per global size row
* @param nWG_col_ number of tiles per global size column
*
*/
template <typename lhs_t, typename rhs_1_t, typename rhs_2_t>
struct Ger {
// Scalar and index types are both taken from the second vector operand.
using value_t = typename rhs_2_t::value_t;
using index_t = typename rhs_2_t::index_t;

// Operands and tiling parameters, stored by value (see the @param list above).
lhs_t lhs_;
value_t scalar_;
rhs_1_t rhs_1_;
rhs_2_t rhs_2_;
index_t nRowsWG_;
index_t nColsWG_;
index_t nWG_row_;
index_t nWG_col_;

// NOTE(review): the four tiling arguments are taken by non-const lvalue
// reference, so callers must pass named variables rather than temporaries --
// confirm whether pass-by-value was intended.
Ger(lhs_t &_l, value_t _scl, rhs_1_t &_r1, rhs_2_t &_r2, index_t &_nRowsWG,
index_t &_nColsWG, index_t &_nWG_row, index_t &_nWG_col);

// The members below are only declared here; their definitions are not
// visible in this view.
index_t get_size() const;
bool valid_thread(cl::sycl::nd_item<1> ndItem) const;
value_t eval(index_t i);
value_t eval(cl::sycl::nd_item<1> ndItem);
// Overload taking an extra shrMem argument -- presumably the work-group
// local-memory scratch buffer; TODO confirm against the definition.
template <typename sharedT>
value_t eval(sharedT shrMem, cl::sycl::nd_item<1> ndItem);
void bind(cl::sycl::handler &h);
void adjust_access_displacement();
};

/**
 * @brief Factory helper that builds a Ger tree node.
 *
 * Lets callers rely on template argument deduction instead of spelling out
 * Ger<lhs_t, rhs_1_t, rhs_2_t>; all arguments are handed to the Ger
 * constructor in the same order they are received.
 *
 * @param lhs_ input/output matrix
 * @param scalar_ value for scaling the vector product
 * @param rhs_1_ first input vector
 * @param rhs_2_ second input vector
 * @param nRowsWG_ rows of the workgroup tile
 * @param nColsWG_ cols of the workgroup tile
 * @param nWG_row_ number of tiles per global size row
 * @param nWG_col_ number of tiles per global size column
 * @return the constructed Ger node
 */
template <typename lhs_t, typename rhs_1_t, typename rhs_2_t>
Ger<lhs_t, rhs_1_t, rhs_2_t> make_ger(lhs_t &lhs_,
                                      typename lhs_t::value_t scalar_,
                                      rhs_1_t &rhs_1_, rhs_2_t &rhs_2_,
                                      typename rhs_2_t::index_t nRowsWG_,
                                      typename rhs_2_t::index_t nColsWG_,
                                      typename rhs_2_t::index_t nWG_row_,
                                      typename rhs_2_t::index_t nWG_col_) {
  using ger_t = Ger<lhs_t, rhs_1_t, rhs_2_t>;
  // Parenthesised construction keeps the implicit scalar conversion the
  // constructor call has always allowed (brace-init would reject narrowing).
  return ger_t(lhs_, scalar_, rhs_1_, rhs_2_, nRowsWG_, nColsWG_, nWG_row_,
               nWG_col_);
}

/**** GER BY ROWS M ROWS x N BLOCK USING PROPERLY THE SHARED MEMORY ****/
// template <typename lhs_t,typename rhs_1_t,typename rhs_2_t>
Expand Down
67 changes: 51 additions & 16 deletions src/interface/blas2_interface.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -878,7 +878,7 @@ typename sb_handle_t::event_t _ger_impl(
container_t0 _vx, increment_t _incx, container_t1 _vy, increment_t _incy,
container_t2 _mA, index_t _lda,
const typename sb_handle_t::event_t& _dependencies, index_t _localSize = 0,
index_t _scratchPadSize = 0, index_t _nRowsWG = 0, index_t _nColsWG = 0) {
bool _useLocalMem = true, index_t _nRowsWG = 0, index_t _nColsWG = 0) {
index_t M = _M;
index_t N = _N;
auto mA = make_matrix_view<col_major>(_mA, M, N, _lda);
Expand All @@ -887,24 +887,39 @@ typename sb_handle_t::event_t _ger_impl(
typename VectorViewType<container_t1, index_t, increment_t>::type vy =
make_vector_view(_vy, _incy, N);

const index_t localSize =
(_localSize == 0) ? sb_handle.get_work_group_size() : _localSize;
const index_t nRowsWG = (_nRowsWG == 0) ? localSize : std::min(M, _nRowsWG);
_localSize = (_localSize == 0) ? sb_handle.get_work_group_size() : _localSize;
_nRowsWG = (_nRowsWG == 0) ? _localSize : _nRowsWG;
_nColsWG = (_nColsWG == 0) ? _localSize : _nColsWG;

const index_t nColsWG = (_nColsWG == 0) ? localSize : std::min(N, _nColsWG);
assert(_localSize % _nRowsWG == 0);
assert((_nRowsWG * _nColsWG) % _localSize == 0);
assert(_nColsWG % (_localSize / _nRowsWG) == 0);

const index_t scratchPadSize =
(_localSize == 0) ? localSize : _scratchPadSize;
if (_useLocalMem) {
assert((_nRowsWG <= _localSize) && (_nColsWG <= _localSize));
} else {
std::vector<size_t> subgroup_sizes =
sb_handle.get_queue()
.get_device()
.template get_info<sycl::info::device::sub_group_sizes>();
size_t min_subgroup_size = *subgroup_sizes.begin();
size_t max_subgroup_size = *subgroup_sizes.rbegin();
assert(((_nRowsWG * _nColsWG) / _localSize) <= min_subgroup_size);
assert(_nRowsWG % max_subgroup_size == 0);
}

const index_t nWGPerCol = (N - 1) / nColsWG + 1;
const index_t nWGPerRow = (M - 1) / nRowsWG + 1;
const index_t globalSize = localSize * nWGPerRow * nWGPerCol;
const index_t nWGPerCol = (N - 1) / _nColsWG + 1;
const index_t nWGPerRow = (M - 1) / _nRowsWG + 1;
const index_t globalSize = _localSize * nWGPerRow * nWGPerCol;

typename sb_handle_t::event_t ret;
auto assignOp =
make_ger_col(mA, _alpha, vx, vy, nWGPerRow, nWGPerCol, scratchPadSize);
return sb_handle.execute(assignOp, localSize, globalSize, scratchPadSize,
_dependencies);
make_ger(mA, _alpha, vx, vy, _nRowsWG, _nColsWG, nWGPerRow, nWGPerCol);

return _useLocalMem ? sb_handle.execute(assignOp, _localSize, globalSize,
_nRowsWG + _nColsWG, _dependencies)
: sb_handle.execute(assignOp, _localSize, globalSize,
_dependencies);
}

/*! _SYR.
Expand Down Expand Up @@ -1280,10 +1295,30 @@ typename sb_handle_t::event_t inline _ger(
container_t0 _vx, increment_t _incx, container_t1 _vy, increment_t _incy,
container_t2 _mA, index_t _lda,
const typename sb_handle_t::event_t& _dependencies) {
// TODO: Here we can use some heuristics to select localn global, local, and
// scratch size per device
index_t localSize = 0;
bool useLocalMem = true;
index_t nRowsWG = 0;
index_t nColsWG = 0;

#if defined(INTEL_GPU)
localSize = 32;
useLocalMem = false;
nRowsWG = 32;
nColsWG = 8;
#elif defined(NVIDIA_GPU)
localSize = 256;
useLocalMem = (_N < 8192 && _M < 8192) ? false : true;
nRowsWG = 32;
nColsWG = 32;
#elif defined(AMD_GPU)
localSize = (_N < 8192 && _M < 8192) ? 512 : 256;
useLocalMem = (_N < 8192 && _M < 8192) ? false : true;
nRowsWG = (_N < 8192 && _M < 8192) ? 64 : 128;
nColsWG = (_N < 8192 && _M < 8192) ? 64 : 256;
#endif
hjabird marked this conversation as resolved.
Show resolved Hide resolved

return _ger_impl(sb_handle, _M, _N, _alpha, _vx, _incx, _vy, _incy, _mA, _lda,
_dependencies);
_dependencies, localSize, useLocalMem, nRowsWG, nColsWG);
}

template <typename sb_handle_t, typename index_t, typename element_t,
Expand Down
Loading
Loading