Skip to content

Commit

Permalink
Merge pull request #289 from senior-zero/enh-main/github/sm80_select
Browse files Browse the repository at this point in the history
Tune Select and Partition on A100
  • Loading branch information
gevtushenko authored Aug 1, 2023
2 parents 2da0c1f + 1df02a4 commit 5ace4bb
Show file tree
Hide file tree
Showing 2 changed files with 423 additions and 17 deletions.
339 changes: 331 additions & 8 deletions cub/cub/device/dispatch/tuning/tuning_select_if.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ struct sm90_tuning<Input, flagged::no, keep_rejects::no, offset_size::_4, primit
static constexpr int threads = 256;
static constexpr int items = 22;

static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_TRANSPOSE;
static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;

using delay_constructor = detail::fixed_delay_constructor_t<320, 605>;
};
Expand All @@ -135,7 +135,7 @@ struct sm90_tuning<Input, flagged::no, keep_rejects::no, offset_size::_4, primit
static constexpr int threads = 384;
static constexpr int items = 17;

static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_TRANSPOSE;
static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;

using delay_constructor = detail::fixed_delay_constructor_t<76, 1150>;
};
Expand All @@ -146,7 +146,7 @@ struct sm90_tuning<Input, flagged::no, keep_rejects::no, offset_size::_4, primit
static constexpr int threads = 384;
static constexpr int items = 11;

static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_TRANSPOSE;
static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;

using delay_constructor = detail::fixed_delay_constructor_t<380, 1140>;
};
Expand Down Expand Up @@ -284,7 +284,7 @@ struct sm90_tuning<Input, flagged::no, keep_rejects::yes, offset_size::_4, primi
static constexpr int threads = 128;
static constexpr int items = 12;

static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_TRANSPOSE;
static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;

using delay_constructor = detail::fixed_delay_constructor_t<512, 1075>;
};
Expand All @@ -296,7 +296,7 @@ struct sm90_tuning<__int128_t, flagged::no, keep_rejects::yes, offset_size::_4,
static constexpr int threads = 192;
static constexpr int items = 5;

static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_TRANSPOSE;
static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;

using delay_constructor = detail::fixed_delay_constructor_t<1616, 1115>;
};
Expand All @@ -307,7 +307,7 @@ struct sm90_tuning<__uint128_t, flagged::no, keep_rejects::yes, offset_size::_4,
static constexpr int threads = 192;
static constexpr int items = 5;

static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_TRANSPOSE;
static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;

using delay_constructor = detail::fixed_delay_constructor_t<1616, 1115>;
};
Expand Down Expand Up @@ -382,12 +382,310 @@ struct sm90_tuning<__uint128_t, flagged::yes, keep_rejects::yes, offset_size::_4
};
#endif


// Fallback sm80 tuning: this primary template is selected whenever none of the
// explicit sm80_tuning specializations matches the combination of
// (flagged, keep_rejects, offset size, primitive, input size).
template <class InputT,
flagged,
keep_rejects,
offset_size OffsetSize,
primitive = is_primitive<InputT>(),
input_size InputSize = classify_input_size<InputT>()>
struct sm80_tuning
{
static constexpr int threads = 128;

// Baseline item count for a 4-byte input type; `items` below is derived from it.
static constexpr int nominal_4b_items_per_thread = 10;

// Scale the nominal count by 4 / sizeof(InputT) and keep the result within
// [1, nominal_4b_items_per_thread] (assuming CUB_MIN / CUB_MAX are plain
// min / max macros). sizeof yields size_t, so the division is unsigned.
static constexpr int items =
CUB_MIN(nominal_4b_items_per_thread,
CUB_MAX(1, (nominal_4b_items_per_thread * 4 / sizeof(InputT))));

static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;

using delay_constructor = detail::fixed_delay_constructor_t<350, 450>;
};

// select::if
// sm80 select::if tuning: 4-byte offsets, primitive input, input_size::_1.
template <class Input>
struct sm80_tuning<Input,
                   flagged::no,
                   keep_rejects::no,
                   offset_size::_4,
                   primitive::yes,
                   input_size::_1>
{
  // Delay constructor selected for this configuration.
  using delay_constructor = detail::no_delay_constructor_t<395>;

  // Block load strategy selected for this configuration.
  static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;

  // Tuned thread and item counts.
  static constexpr int threads = 992;
  static constexpr int items   = 20;
};

// sm80 select::if tuning: 4-byte offsets, primitive input, input_size::_2.
template <class Input>
struct sm80_tuning<Input,
                   flagged::no,
                   keep_rejects::no,
                   offset_size::_4,
                   primitive::yes,
                   input_size::_2>
{
  // Delay constructor selected for this configuration.
  using delay_constructor = detail::no_delay_constructor_t<870>;

  // Block load strategy selected for this configuration.
  static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;

  // Tuned thread and item counts.
  static constexpr int threads = 576;
  static constexpr int items   = 14;
};

// sm80 select::if tuning: 4-byte offsets, primitive input, input_size::_4.
template <class Input>
struct sm80_tuning<Input,
                   flagged::no,
                   keep_rejects::no,
                   offset_size::_4,
                   primitive::yes,
                   input_size::_4>
{
  // Delay constructor selected for this configuration.
  using delay_constructor = detail::no_delay_constructor_t<1130>;

  // Block load strategy selected for this configuration.
  static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;

  // Tuned thread and item counts.
  static constexpr int threads = 256;
  static constexpr int items   = 18;
};

// sm80 select::if tuning: 4-byte offsets, primitive input, input_size::_8.
template <class Input>
struct sm80_tuning<Input,
                   flagged::no,
                   keep_rejects::no,
                   offset_size::_4,
                   primitive::yes,
                   input_size::_8>
{
  // Delay constructor selected for this configuration.
  using delay_constructor = detail::fixed_delay_constructor_t<832, 1165>;

  // Block load strategy selected for this configuration.
  static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;

  // Tuned thread and item counts.
  static constexpr int threads = 192;
  static constexpr int items   = 10;
};

#if CUB_IS_INT128_ENABLED
// sm80 select::if tuning for __int128_t (non-primitive, input_size::_16),
// 4-byte offsets.
template <>
struct sm80_tuning<__int128_t,
                   flagged::no,
                   keep_rejects::no,
                   offset_size::_4,
                   primitive::no,
                   input_size::_16>
{
  // Delay constructor selected for this configuration.
  using delay_constructor = detail::no_delay_constructor_t<1140>;

  // Block load strategy selected for this configuration.
  static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;

  // Tuned thread and item counts.
  static constexpr int threads = 384;
  static constexpr int items   = 4;
};

// sm80 select::if tuning for __uint128_t (non-primitive, input_size::_16),
// 4-byte offsets. Values mirror the __int128_t specialization.
template <>
struct sm80_tuning<__uint128_t,
                   flagged::no,
                   keep_rejects::no,
                   offset_size::_4,
                   primitive::no,
                   input_size::_16>
{
  // Delay constructor selected for this configuration.
  using delay_constructor = detail::no_delay_constructor_t<1140>;

  // Block load strategy selected for this configuration.
  static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;

  // Tuned thread and item counts.
  static constexpr int threads = 384;
  static constexpr int items   = 4;
};
#endif

// select::flagged
// sm80 select::flagged tuning: 4-byte offsets, primitive input, input_size::_1.
template <class Input>
struct sm80_tuning<Input,
                   flagged::yes,
                   keep_rejects::no,
                   offset_size::_4,
                   primitive::yes,
                   input_size::_1>
{
  // Delay constructor selected for this configuration.
  using delay_constructor = detail::no_delay_constructor_t<735>;

  // Block load strategy selected for this configuration.
  static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;

  // Tuned thread and item counts.
  static constexpr int threads = 224;
  static constexpr int items   = 20;
};

// sm80 select::flagged tuning: 4-byte offsets, primitive input, input_size::_2.
template <class Input>
struct sm80_tuning<Input,
                   flagged::yes,
                   keep_rejects::no,
                   offset_size::_4,
                   primitive::yes,
                   input_size::_2>
{
  // Delay constructor selected for this configuration.
  using delay_constructor = detail::no_delay_constructor_t<1155>;

  // Block load strategy selected for this configuration.
  static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;

  // Tuned thread and item counts.
  static constexpr int threads = 256;
  static constexpr int items   = 20;
};

// sm80 select::flagged tuning: 4-byte offsets, primitive input, input_size::_4.
template <class Input>
struct sm80_tuning<Input,
                   flagged::yes,
                   keep_rejects::no,
                   offset_size::_4,
                   primitive::yes,
                   input_size::_4>
{
  // Delay constructor selected for this configuration.
  using delay_constructor = detail::fixed_delay_constructor_t<124, 1115>;

  // Block load strategy selected for this configuration.
  static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;

  // Tuned thread and item counts.
  static constexpr int threads = 320;
  static constexpr int items   = 10;
};

// sm80 select::flagged tuning: 4-byte offsets, primitive input, input_size::_8.
template <class Input>
struct sm80_tuning<Input,
                   flagged::yes,
                   keep_rejects::no,
                   offset_size::_4,
                   primitive::yes,
                   input_size::_8>
{
  // Delay constructor selected for this configuration.
  using delay_constructor = detail::no_delay_constructor_t<1130>;

  // Block load strategy selected for this configuration.
  static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;

  // Tuned thread and item counts.
  static constexpr int threads = 384;
  static constexpr int items   = 6;
};

#if CUB_IS_INT128_ENABLED
// sm80 select::flagged tuning for __int128_t (non-primitive, input_size::_16),
// 4-byte offsets.
template <>
struct sm80_tuning<__int128_t,
                   flagged::yes,
                   keep_rejects::no,
                   offset_size::_4,
                   primitive::no,
                   input_size::_16>
{
  // Delay constructor selected for this configuration.
  using delay_constructor = detail::fixed_delay_constructor_t<464, 1025>;

  // Block load strategy selected for this configuration.
  static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;

  // Tuned thread and item counts.
  static constexpr int threads = 256;
  static constexpr int items   = 5;
};

// sm80 select::flagged tuning for __uint128_t (non-primitive, input_size::_16),
// 4-byte offsets. Values mirror the __int128_t specialization.
template <>
struct sm80_tuning<__uint128_t,
                   flagged::yes,
                   keep_rejects::no,
                   offset_size::_4,
                   primitive::no,
                   input_size::_16>
{
  // Delay constructor selected for this configuration.
  using delay_constructor = detail::fixed_delay_constructor_t<464, 1025>;

  // Block load strategy selected for this configuration.
  static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;

  // Tuned thread and item counts.
  static constexpr int threads = 256;
  static constexpr int items   = 5;
};
#endif

// partition::if
// sm80 partition::if tuning: 4-byte offsets, primitive input, input_size::_1.
template <class Input>
struct sm80_tuning<Input,
                   flagged::no,
                   keep_rejects::yes,
                   offset_size::_4,
                   primitive::yes,
                   input_size::_1>
{
  // Delay constructor selected for this configuration.
  using delay_constructor = detail::no_delay_constructor_t<510>;

  // Block load strategy selected for this configuration.
  static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;

  // Tuned thread and item counts.
  static constexpr int threads = 512;
  static constexpr int items   = 20;
};

// sm80 partition::if tuning: 4-byte offsets, primitive input, input_size::_2.
template <class Input>
struct sm80_tuning<Input,
                   flagged::no,
                   keep_rejects::yes,
                   offset_size::_4,
                   primitive::yes,
                   input_size::_2>
{
  // Delay constructor selected for this configuration.
  using delay_constructor = detail::no_delay_constructor_t<1045>;

  // Block load strategy selected for this configuration.
  static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;

  // Tuned thread and item counts.
  static constexpr int threads = 224;
  static constexpr int items   = 18;
};

// sm80 partition::if tuning: 4-byte offsets, primitive input, input_size::_4.
template <class Input>
struct sm80_tuning<Input,
                   flagged::no,
                   keep_rejects::yes,
                   offset_size::_4,
                   primitive::yes,
                   input_size::_4>
{
  // Delay constructor selected for this configuration.
  using delay_constructor = detail::no_delay_constructor_t<1040>;

  // Block load strategy selected for this configuration.
  static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;

  // Tuned thread and item counts.
  static constexpr int threads = 192;
  static constexpr int items   = 15;
};

// sm80 partition::if tuning: 4-byte offsets, primitive input, input_size::_8.
template <class Input>
struct sm80_tuning<Input,
                   flagged::no,
                   keep_rejects::yes,
                   offset_size::_4,
                   primitive::yes,
                   input_size::_8>
{
  // Delay constructor selected for this configuration.
  using delay_constructor = detail::fixed_delay_constructor_t<68, 1160>;

  // Block load strategy selected for this configuration.
  static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;

  // Tuned thread and item counts.
  static constexpr int threads = 192;
  static constexpr int items   = 10;
};

#if CUB_IS_INT128_ENABLED
// sm80 partition::if tuning for __int128_t (non-primitive, input_size::_16),
// 4-byte offsets.
template <>
struct sm80_tuning<__int128_t,
                   flagged::no,
                   keep_rejects::yes,
                   offset_size::_4,
                   primitive::no,
                   input_size::_16>
{
  // Delay constructor selected for this configuration.
  using delay_constructor = detail::fixed_delay_constructor_t<400, 1090>;

  // Block load strategy selected for this configuration.
  static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;

  // Tuned thread and item counts.
  static constexpr int threads = 256;
  static constexpr int items   = 5;
};

// sm80 partition::if tuning for __uint128_t (non-primitive, input_size::_16),
// 4-byte offsets. Values mirror the __int128_t specialization.
template <>
struct sm80_tuning<__uint128_t,
                   flagged::no,
                   keep_rejects::yes,
                   offset_size::_4,
                   primitive::no,
                   input_size::_16>
{
  // Delay constructor selected for this configuration.
  using delay_constructor = detail::fixed_delay_constructor_t<400, 1090>;

  // Block load strategy selected for this configuration.
  static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;

  // Tuned thread and item counts.
  static constexpr int threads = 256;
  static constexpr int items   = 5;
};
#endif

// partition::flagged
// sm80 partition::flagged tuning: 4-byte offsets, primitive input, input_size::_1.
template <class Input>
struct sm80_tuning<Input,
                   flagged::yes,
                   keep_rejects::yes,
                   offset_size::_4,
                   primitive::yes,
                   input_size::_1>
{
  // Delay constructor selected for this configuration.
  using delay_constructor = detail::no_delay_constructor_t<595>;

  // Block load strategy selected for this configuration.
  static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;

  // Tuned thread and item counts.
  static constexpr int threads = 512;
  static constexpr int items   = 20;
};

// sm80 partition::flagged tuning: 4-byte offsets, primitive input, input_size::_2.
template <class Input>
struct sm80_tuning<Input,
                   flagged::yes,
                   keep_rejects::yes,
                   offset_size::_4,
                   primitive::yes,
                   input_size::_2>
{
  // Delay constructor selected for this configuration.
  using delay_constructor = detail::no_delay_constructor_t<1105>;

  // Block load strategy selected for this configuration.
  static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;

  // Tuned thread and item counts.
  static constexpr int threads = 224;
  static constexpr int items   = 18;
};

// sm80 partition::flagged tuning: 4-byte offsets, primitive input, input_size::_4.
template <class Input>
struct sm80_tuning<Input,
                   flagged::yes,
                   keep_rejects::yes,
                   offset_size::_4,
                   primitive::yes,
                   input_size::_4>
{
  // Delay constructor selected for this configuration.
  using delay_constructor = detail::fixed_delay_constructor_t<912, 1025>;

  // Block load strategy selected for this configuration.
  static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;

  // Tuned thread and item counts.
  static constexpr int threads = 192;
  static constexpr int items   = 12;
};

// sm80 partition::flagged tuning: 4-byte offsets, primitive input, input_size::_8.
template <class Input>
struct sm80_tuning<Input,
                   flagged::yes,
                   keep_rejects::yes,
                   offset_size::_4,
                   primitive::yes,
                   input_size::_8>
{
  // Delay constructor selected for this configuration.
  using delay_constructor = detail::fixed_delay_constructor_t<884, 1130>;

  // Block load strategy selected for this configuration.
  static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;

  // Tuned thread and item counts.
  static constexpr int threads = 192;
  static constexpr int items   = 12;
};

#if CUB_IS_INT128_ENABLED
// sm80 partition::flagged tuning for __int128_t (non-primitive, input_size::_16),
// 4-byte offsets.
template <>
struct sm80_tuning<__int128_t,
                   flagged::yes,
                   keep_rejects::yes,
                   offset_size::_4,
                   primitive::no,
                   input_size::_16>
{
  // Delay constructor selected for this configuration.
  using delay_constructor = detail::fixed_delay_constructor_t<400, 1090>;

  // Block load strategy selected for this configuration.
  static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;

  // Tuned thread and item counts.
  static constexpr int threads = 256;
  static constexpr int items   = 5;
};

// sm80 partition::flagged tuning for __uint128_t (non-primitive, input_size::_16),
// 4-byte offsets. Values mirror the __int128_t specialization.
template <>
struct sm80_tuning<__uint128_t,
                   flagged::yes,
                   keep_rejects::yes,
                   offset_size::_4,
                   primitive::no,
                   input_size::_16>
{
  // Delay constructor selected for this configuration.
  using delay_constructor = detail::fixed_delay_constructor_t<400, 1090>;

  // Block load strategy selected for this configuration.
  static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;

  // Tuned thread and item counts.
  static constexpr int threads = 256;
  static constexpr int items   = 5;
};
#endif

} // namespace select

template <class InputT, class FlagT, class OffsetT, bool MayAlias, bool KeepRejects>
struct device_select_policy_hub
{
struct Policy350 : ChainedPolicy<350, Policy350, Policy350>
struct DefaultTuning
{
static constexpr int NOMINAL_4B_ITEMS_PER_THREAD = 10;

Expand All @@ -403,7 +701,32 @@ struct device_select_policy_hub
detail::fixed_delay_constructor_t<350, 450>>;
};

struct Policy900 : ChainedPolicy<900, Policy900, Policy350>
// Chain entry 350: takes its policy members from DefaultTuning and names
// itself as the previous policy in the chain.
struct Policy350 : DefaultTuning, ChainedPolicy<350, Policy350, Policy350>
{};

// Chain entry 800: builds its agent policy from the sm80_tuning table,
// selected by input type, flag usage, reject handling and offset size.
struct Policy800 : ChainedPolicy<800, Policy800, Policy350>
{
  // Tuning record matching this hub's template arguments.
  using tuning = detail::select::sm80_tuning<InputT,
                                             select::is_flagged<FlagT>(),
                                             select::are_rejects_kept<KeepRejects>(),
                                             select::classify_offset_size<OffsetT>()>;

  // Agent policy assembled from the tuned values; the load modifier and the
  // scan algorithm are fixed here rather than tuned.
  using SelectIfPolicyT = AgentSelectIfPolicy<tuning::threads,
                                              tuning::items,
                                              tuning::load_algorithm,
                                              LOAD_DEFAULT,
                                              BLOCK_SCAN_WARP_SCANS,
                                              typename tuning::delay_constructor>;
};

// Chain entry 860: inherits DefaultTuning rather than the sm80 tuning table,
// and chains back to Policy800.
struct Policy860 : DefaultTuning, ChainedPolicy<860, Policy860, Policy800>
{};

struct Policy900 : ChainedPolicy<900, Policy900, Policy860>
{
using tuning = detail::select::sm90_tuning<InputT,
select::is_flagged<FlagT>(),
Expand Down
Loading

0 comments on commit 5ace4bb

Please sign in to comment.