diff --git a/cub/cub/device/dispatch/tuning/tuning_select_if.cuh b/cub/cub/device/dispatch/tuning/tuning_select_if.cuh
index 171d2d598d7..27247ca70e5 100644
--- a/cub/cub/device/dispatch/tuning/tuning_select_if.cuh
+++ b/cub/cub/device/dispatch/tuning/tuning_select_if.cuh
@@ -124,7 +124,7 @@ struct sm90_tuning;
};
@@ -135,7 +135,7 @@ struct sm90_tuning;
};
@@ -146,7 +146,7 @@ struct sm90_tuning;
};
@@ -284,7 +284,7 @@ struct sm90_tuning;
};
@@ -296,7 +296,7 @@ struct sm90_tuning<__int128_t, flagged::no, keep_rejects::yes, offset_size::_4,
static constexpr int threads = 192;
static constexpr int items = 5;
- static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_TRANSPOSE;
+ static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;
using delay_constructor = detail::fixed_delay_constructor_t<1616, 1115>;
};
@@ -307,7 +307,7 @@ struct sm90_tuning<__uint128_t, flagged::no, keep_rejects::yes, offset_size::_4,
static constexpr int threads = 192;
static constexpr int items = 5;
- static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_TRANSPOSE;
+ static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;
using delay_constructor = detail::fixed_delay_constructor_t<1616, 1115>;
};
@@ -382,12 +382,310 @@ struct sm90_tuning<__uint128_t, flagged::yes, keep_rejects::yes, offset_size::_4
};
#endif
+
+template (),
+ input_size InputSize = classify_input_size()>
+struct sm80_tuning
+{
+ static constexpr int threads = 128;
+
+ static constexpr int nominal_4b_items_per_thread = 10;
+
+ static constexpr int items =
+ CUB_MIN(nominal_4b_items_per_thread,
+ CUB_MAX(1, (nominal_4b_items_per_thread * 4 / sizeof(InputT))));
+
+ static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;
+
+ using delay_constructor = detail::fixed_delay_constructor_t<350, 450>;
+};
+
+// select::if
+template
+struct sm80_tuning
+{
+ static constexpr int threads = 992;
+ static constexpr int items = 20;
+
+ static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;
+
+ using delay_constructor = detail::no_delay_constructor_t<395>;
+};
+
+template
+struct sm80_tuning
+{
+ static constexpr int threads = 576;
+ static constexpr int items = 14;
+
+ static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;
+
+ using delay_constructor = detail::no_delay_constructor_t<870>;
+};
+
+template
+struct sm80_tuning
+{
+ static constexpr int threads = 256;
+ static constexpr int items = 18;
+
+ static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;
+
+ using delay_constructor = detail::no_delay_constructor_t<1130>;
+};
+
+template
+struct sm80_tuning
+{
+ static constexpr int threads = 192;
+ static constexpr int items = 10;
+
+ static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;
+
+ using delay_constructor = detail::fixed_delay_constructor_t<832, 1165>;
+};
+
+#if CUB_IS_INT128_ENABLED
+template <>
+struct sm80_tuning<__int128_t, flagged::no, keep_rejects::no, offset_size::_4, primitive::no, input_size::_16>
+{
+ static constexpr int threads = 384;
+ static constexpr int items = 4;
+
+ static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;
+
+ using delay_constructor = detail::no_delay_constructor_t<1140>;
+};
+
+template <>
+struct sm80_tuning<__uint128_t, flagged::no, keep_rejects::no, offset_size::_4, primitive::no, input_size::_16>
+{
+ static constexpr int threads = 384;
+ static constexpr int items = 4;
+
+ static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;
+
+ using delay_constructor = detail::no_delay_constructor_t<1140>;
+};
+#endif
+
+// select::flagged
+template
+struct sm80_tuning
+{
+ static constexpr int threads = 224;
+ static constexpr int items = 20;
+
+ static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;
+
+ using delay_constructor = detail::no_delay_constructor_t<735>;
+};
+
+template
+struct sm80_tuning
+{
+ static constexpr int threads = 256;
+ static constexpr int items = 20;
+
+ static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;
+
+ using delay_constructor = detail::no_delay_constructor_t<1155>;
+};
+
+template
+struct sm80_tuning
+{
+ static constexpr int threads = 320;
+ static constexpr int items = 10;
+
+ static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;
+
+ using delay_constructor = detail::fixed_delay_constructor_t<124, 1115>;
+};
+
+template
+struct sm80_tuning
+{
+ static constexpr int threads = 384;
+ static constexpr int items = 6;
+
+ static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;
+
+ using delay_constructor = detail::no_delay_constructor_t<1130>;
+};
+
+#if CUB_IS_INT128_ENABLED
+template <>
+struct sm80_tuning<__int128_t, flagged::yes, keep_rejects::no, offset_size::_4, primitive::no, input_size::_16>
+{
+ static constexpr int threads = 256;
+ static constexpr int items = 5;
+
+ static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;
+
+ using delay_constructor = detail::fixed_delay_constructor_t<464, 1025>;
+};
+
+template <>
+struct sm80_tuning<__uint128_t, flagged::yes, keep_rejects::no, offset_size::_4, primitive::no, input_size::_16>
+{
+ static constexpr int threads = 256;
+ static constexpr int items = 5;
+
+ static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;
+
+ using delay_constructor = detail::fixed_delay_constructor_t<464, 1025>;
+};
+#endif
+
+// partition::if
+template
+struct sm80_tuning
+{
+ static constexpr int threads = 512;
+ static constexpr int items = 20;
+
+ static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;
+
+ using delay_constructor = detail::no_delay_constructor_t<510>;
+};
+
+template
+struct sm80_tuning
+{
+ static constexpr int threads = 224;
+ static constexpr int items = 18;
+
+ static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;
+
+ using delay_constructor = detail::no_delay_constructor_t<1045>;
+};
+
+template
+struct sm80_tuning
+{
+ static constexpr int threads = 192;
+ static constexpr int items = 15;
+
+ static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;
+
+ using delay_constructor = detail::no_delay_constructor_t<1040>;
+};
+
+template
+struct sm80_tuning
+{
+ static constexpr int threads = 192;
+ static constexpr int items = 10;
+
+ static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;
+
+ using delay_constructor = detail::fixed_delay_constructor_t<68, 1160>;
+};
+
+#if CUB_IS_INT128_ENABLED
+template <>
+struct sm80_tuning<__int128_t, flagged::no, keep_rejects::yes, offset_size::_4, primitive::no, input_size::_16>
+{
+ static constexpr int threads = 256;
+ static constexpr int items = 5;
+
+ static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;
+
+ using delay_constructor = detail::fixed_delay_constructor_t<400, 1090>;
+};
+
+template <>
+struct sm80_tuning<__uint128_t, flagged::no, keep_rejects::yes, offset_size::_4, primitive::no, input_size::_16>
+{
+ static constexpr int threads = 256;
+ static constexpr int items = 5;
+
+ static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;
+
+ using delay_constructor = detail::fixed_delay_constructor_t<400, 1090>;
+};
+#endif
+
+// partition::flagged
+template
+struct sm80_tuning
+{
+ static constexpr int threads = 512;
+ static constexpr int items = 20;
+
+ static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;
+
+ using delay_constructor = detail::no_delay_constructor_t<595>;
+};
+
+template
+struct sm80_tuning
+{
+ static constexpr int threads = 224;
+ static constexpr int items = 18;
+
+ static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;
+
+ using delay_constructor = detail::no_delay_constructor_t<1105>;
+};
+
+template
+struct sm80_tuning
+{
+ static constexpr int threads = 192;
+ static constexpr int items = 12;
+
+ static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;
+
+ using delay_constructor = detail::fixed_delay_constructor_t<912, 1025>;
+};
+
+template
+struct sm80_tuning
+{
+ static constexpr int threads = 192;
+ static constexpr int items = 12;
+
+ static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;
+
+ using delay_constructor = detail::fixed_delay_constructor_t<884, 1130>;
+};
+
+#if CUB_IS_INT128_ENABLED
+template <>
+struct sm80_tuning<__int128_t, flagged::yes, keep_rejects::yes, offset_size::_4, primitive::no, input_size::_16>
+{
+ static constexpr int threads = 256;
+ static constexpr int items = 5;
+
+ static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;
+
+ using delay_constructor = detail::fixed_delay_constructor_t<400, 1090>;
+};
+
+template <>
+struct sm80_tuning<__uint128_t, flagged::yes, keep_rejects::yes, offset_size::_4, primitive::no, input_size::_16>
+{
+ static constexpr int threads = 256;
+ static constexpr int items = 5;
+
+ static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;
+
+ using delay_constructor = detail::fixed_delay_constructor_t<400, 1090>;
+};
+#endif
+
} // namespace select
template
struct device_select_policy_hub
{
- struct Policy350 : ChainedPolicy<350, Policy350, Policy350>
+ struct DefaultTuning
{
static constexpr int NOMINAL_4B_ITEMS_PER_THREAD = 10;
@@ -403,7 +701,32 @@ struct device_select_policy_hub
detail::fixed_delay_constructor_t<350, 450>>;
};
- struct Policy900 : ChainedPolicy<900, Policy900, Policy350>
+ struct Policy350
+ : DefaultTuning
+ , ChainedPolicy<350, Policy350, Policy350>
+ {};
+
+ struct Policy800 : ChainedPolicy<800, Policy800, Policy350>
+ {
+ using tuning = detail::select::sm80_tuning(),
+ select::are_rejects_kept(),
+ select::classify_offset_size()>;
+
+ using SelectIfPolicyT = AgentSelectIfPolicy;
+ };
+
+ struct Policy860
+ : DefaultTuning
+ , ChainedPolicy<860, Policy860, Policy800>
+ {};
+
+ struct Policy900 : ChainedPolicy<900, Policy900, Policy860>
{
using tuning = detail::select::sm90_tuning(),
diff --git a/cub/cub/device/dispatch/tuning/tuning_three_way_partition.cuh b/cub/cub/device/dispatch/tuning/tuning_three_way_partition.cuh
index a1b3c43dc9e..8c8fabe79e0 100644
--- a/cub/cub/device/dispatch/tuning/tuning_three_way_partition.cuh
+++ b/cub/cub/device/dispatch/tuning/tuning_three_way_partition.cuh
@@ -134,7 +134,7 @@ struct sm90_tuning
static constexpr int threads = 384;
static constexpr int items = 7;
- static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_TRANSPOSE;
+ static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;
using delay_constructor = detail::fixed_delay_constructor_t<464, 1165>;
};
@@ -145,7 +145,7 @@ struct sm90_tuning
static constexpr int threads = 128;
static constexpr int items = 7;
- static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_TRANSPOSE;
+ static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;
using delay_constructor = detail::no_delay_constructor_t<1040>;
};
@@ -167,7 +167,7 @@ struct sm90_tuning
static constexpr int threads = 640;
static constexpr int items = 24;
- static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_TRANSPOSE;
+ static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;
using delay_constructor = detail::no_delay_constructor_t<245>;
};
@@ -178,7 +178,7 @@ struct sm90_tuning
static constexpr int threads = 256;
static constexpr int items = 23;
- static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_TRANSPOSE;
+ static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;
using delay_constructor = detail::no_delay_constructor_t<910>;
};
@@ -189,7 +189,7 @@ struct sm90_tuning
static constexpr int threads = 256;
static constexpr int items = 18;
- static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_TRANSPOSE;
+ static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;
using delay_constructor = detail::no_delay_constructor_t<1145>;
};
@@ -200,18 +200,77 @@ struct sm90_tuning
static constexpr int threads = 256;
static constexpr int items = 11;
- static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_TRANSPOSE;
+ static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;
using delay_constructor = detail::no_delay_constructor_t<1050>;
};
+template (),
+ offset_size OffsetSize = classify_offset_size()>
+struct sm80_tuning
+{
+ static constexpr int threads = 256;
+ static constexpr int items = Nominal4BItemsToItems(9);
+
+ static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;
+
+ using AccumPackHelperT = detail::three_way_partition::accumulator_pack_t;
+ using AccumPackT = typename AccumPackHelperT::pack_t;
+ using delay_constructor = detail::default_delay_constructor_t;
+};
+
+template
+struct sm80_tuning
+{
+ static constexpr int threads = 256;
+ static constexpr int items = 12;
+
+ static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;
+
+ using delay_constructor = detail::no_delay_constructor_t<910>;
+};
+
+template
+struct sm80_tuning
+{
+ static constexpr int threads = 256;
+ static constexpr int items = 11;
+
+ static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;
+
+ using delay_constructor = detail::no_delay_constructor_t<1120>;
+};
+
+template
+struct sm80_tuning
+{
+ static constexpr int threads = 224;
+ static constexpr int items = 11;
+
+ static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;
+
+ using delay_constructor = detail::fixed_delay_constructor_t<264, 1080>;
+};
+
+template
+struct sm80_tuning
+{
+ static constexpr int threads = 128;
+ static constexpr int items = 10;
+
+ static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;
+
+ using delay_constructor = detail::fixed_delay_constructor_t<672, 1120>;
+};
+
} // namespace three_way_partition
template
struct device_three_way_partition_policy_hub
{
- /// SM35
- struct Policy350 : ChainedPolicy<350, Policy350, Policy350>
+ struct DefaultTuning
{
constexpr static int ITEMS_PER_THREAD = Nominal4BItemsToItems(9);
@@ -222,8 +281,32 @@ struct device_three_way_partition_policy_hub
cub::BLOCK_SCAN_WARP_SCANS>;
};
+ /// SM35
+ struct Policy350
+ : DefaultTuning
+ , ChainedPolicy<350, Policy350, Policy350>
+ {};
+
+ struct Policy800 : ChainedPolicy<800, Policy800, Policy350>
+ {
+ using tuning = detail::three_way_partition::sm80_tuning;
+
+ using ThreeWayPartitionPolicy =
+ AgentThreeWayPartitionPolicy;
+ };
+
+ struct Policy860
+ : DefaultTuning
+ , ChainedPolicy<860, Policy860, Policy800>
+ {};
+
/// SM90
- struct Policy900 : ChainedPolicy<900, Policy900, Policy350>
+ struct Policy900 : ChainedPolicy<900, Policy900, Policy860>
{
using tuning = detail::three_way_partition::sm90_tuning;