Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tune scan on A100 #302

Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
148 changes: 146 additions & 2 deletions cub/cub/device/dispatch/tuning/tuning_scan.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,123 @@ template <> struct sm90_tuning<__uint128_t, primitive_op::yes, primitive_accum::
#endif
// clang-format on

template <class AccumT,
primitive_op PrimitiveOp,
primitive_accum PrimitiveAccumulator = is_primitive_accum<AccumT>(),
accum_size AccumSize = classify_accum_size<AccumT>()>
struct sm80_tuning
{
static constexpr int threads = 128;
static constexpr int items = 15;

using delay_constructor = detail::default_delay_constructor_t<AccumT>;

static constexpr bool LargeValues = sizeof(AccumT) > 128;

static constexpr BlockLoadAlgorithm load_algorithm = //
LargeValues ? BLOCK_LOAD_WARP_TRANSPOSE_TIMESLICED : BLOCK_LOAD_WARP_TRANSPOSE;
static constexpr BlockStoreAlgorithm store_algorithm = //
LargeValues ? BLOCK_STORE_WARP_TRANSPOSE_TIMESLICED : BLOCK_STORE_WARP_TRANSPOSE;
};

template <class T>
struct sm80_tuning<T, primitive_op::yes, primitive_accum::yes, accum_size::_1>
{
static constexpr int threads = 320;
static constexpr int items = 14;

using delay_constructor = detail::fixed_delay_constructor_t<368, 725>;

static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;
static constexpr BlockStoreAlgorithm store_algorithm = BLOCK_STORE_WARP_TRANSPOSE;
};

template <class T>
struct sm80_tuning<T, primitive_op::yes, primitive_accum::yes, accum_size::_2>
{
static constexpr int threads = 352;
static constexpr int items = 16;

using delay_constructor = detail::fixed_delay_constructor_t<488, 1040>;

static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;
static constexpr BlockStoreAlgorithm store_algorithm = BLOCK_STORE_WARP_TRANSPOSE;
};

template <class T>
struct sm80_tuning<T, primitive_op::yes, primitive_accum::yes, accum_size::_4>
{
static constexpr int threads = 320;
static constexpr int items = 12;

using delay_constructor = detail::fixed_delay_constructor_t<268, 1180>;

static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;
static constexpr BlockStoreAlgorithm store_algorithm = BLOCK_STORE_WARP_TRANSPOSE;
};

template <class T>
struct sm80_tuning<T, primitive_op::yes, primitive_accum::yes, accum_size::_8>
{
static constexpr int threads = 288;
static constexpr int items = 22;

using delay_constructor = detail::fixed_delay_constructor_t<716, 785>;

static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;
static constexpr BlockStoreAlgorithm store_algorithm = BLOCK_STORE_WARP_TRANSPOSE;
};

template <>
struct sm80_tuning<float, primitive_op::yes, primitive_accum::yes, accum_size::_4>
{
static constexpr int threads = 288;
static constexpr int items = 8;

using delay_constructor = detail::fixed_delay_constructor_t<724, 1050>;

static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;
static constexpr BlockStoreAlgorithm store_algorithm = BLOCK_STORE_WARP_TRANSPOSE;
};

template <>
struct sm80_tuning<double, primitive_op::yes, primitive_accum::yes, accum_size::_8>
{
static constexpr int threads = 384;
static constexpr int items = 12;

using delay_constructor = detail::fixed_delay_constructor_t<388, 1100>;

static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;
static constexpr BlockStoreAlgorithm store_algorithm = BLOCK_STORE_WARP_TRANSPOSE;
};

#if CUB_IS_INT128_ENABLED
template <>
struct sm80_tuning<__int128_t, primitive_op::yes, primitive_accum::no, accum_size::_16>
{
static constexpr int threads = 640;
static constexpr int items = 24;

using delay_constructor = detail::no_delay_constructor_t<1200>;

static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;
static constexpr BlockStoreAlgorithm store_algorithm = BLOCK_STORE_DIRECT;
};

template <>
struct sm80_tuning<__uint128_t, primitive_op::yes, primitive_accum::no, accum_size::_16>
{
static constexpr int threads = 640;
static constexpr int items = 24;

using delay_constructor = detail::no_delay_constructor_t<1200>;

static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;
static constexpr BlockStoreAlgorithm store_algorithm = BLOCK_STORE_DIRECT;
};
#endif

} // namespace scan
} // namespace detail

Expand Down Expand Up @@ -171,7 +288,7 @@ struct DeviceScanPolicy
};

/// SM600
struct Policy600 : ChainedPolicy<600, Policy600, Policy520>
struct DefaultTuning
{
using ScanPolicyT = policy_t<128,
15, ///< Threads per block, items per thread
Expand All @@ -183,8 +300,35 @@ struct DeviceScanPolicy
detail::default_delay_constructor_t<AccumT>>;
};

/// SM600
struct Policy600
: DefaultTuning
, ChainedPolicy<600, Policy600, Policy520>
{};

/// SM800
struct Policy800 : ChainedPolicy<800, Policy800, Policy600>
{
using tuning = detail::scan::sm80_tuning<AccumT, detail::scan::is_primitive_op<ScanOpT>()>;

using ScanPolicyT = policy_t<tuning::threads,
tuning::items,
AccumT,
tuning::load_algorithm,
LOAD_DEFAULT,
tuning::store_algorithm,
BLOCK_SCAN_WARP_SCANS,
typename tuning::delay_constructor>;
};

/// SM860
struct Policy860
: DefaultTuning
, ChainedPolicy<860, Policy860, Policy800>
{};

/// SM900
struct Policy900 : ChainedPolicy<900, Policy900, Policy600>
struct Policy900 : ChainedPolicy<900, Policy900, Policy860>
{
using tuning = detail::scan::sm90_tuning<AccumT, detail::scan::is_primitive_op<ScanOpT>()>;

Expand Down
Loading