diff --git a/cub/cub/device/dispatch/tuning/tuning_select_if.cuh b/cub/cub/device/dispatch/tuning/tuning_select_if.cuh index 171d2d598d7..27247ca70e5 100644 --- a/cub/cub/device/dispatch/tuning/tuning_select_if.cuh +++ b/cub/cub/device/dispatch/tuning/tuning_select_if.cuh @@ -124,7 +124,7 @@ struct sm90_tuning; }; @@ -135,7 +135,7 @@ struct sm90_tuning; }; @@ -146,7 +146,7 @@ struct sm90_tuning; }; @@ -284,7 +284,7 @@ struct sm90_tuning; }; @@ -296,7 +296,7 @@ struct sm90_tuning<__int128_t, flagged::no, keep_rejects::yes, offset_size::_4, static constexpr int threads = 192; static constexpr int items = 5; - static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_TRANSPOSE; + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE; using delay_constructor = detail::fixed_delay_constructor_t<1616, 1115>; }; @@ -307,7 +307,7 @@ struct sm90_tuning<__uint128_t, flagged::no, keep_rejects::yes, offset_size::_4, static constexpr int threads = 192; static constexpr int items = 5; - static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_TRANSPOSE; + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE; using delay_constructor = detail::fixed_delay_constructor_t<1616, 1115>; }; @@ -382,12 +382,310 @@ struct sm90_tuning<__uint128_t, flagged::yes, keep_rejects::yes, offset_size::_4 }; #endif + +template (), + input_size InputSize = classify_input_size()> +struct sm80_tuning +{ + static constexpr int threads = 128; + + static constexpr int nominal_4b_items_per_thread = 10; + + static constexpr int items = + CUB_MIN(nominal_4b_items_per_thread, + CUB_MAX(1, (nominal_4b_items_per_thread * 4 / sizeof(InputT)))); + + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT; + + using delay_constructor = detail::fixed_delay_constructor_t<350, 450>; +}; + +// select::if +template +struct sm80_tuning +{ + static constexpr int threads = 992; + static constexpr int items = 20; + + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT; + + using delay_constructor = detail::no_delay_constructor_t<395>; +}; + +template +struct sm80_tuning +{ + static constexpr int threads = 576; + static constexpr int items = 14; + + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT; + + using delay_constructor = detail::no_delay_constructor_t<870>; +}; + +template +struct sm80_tuning +{ + static constexpr int threads = 256; + static constexpr int items = 18; + + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE; + + using delay_constructor = detail::no_delay_constructor_t<1130>; +}; + +template +struct sm80_tuning +{ + static constexpr int threads = 192; + static constexpr int items = 10; + + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE; + + using delay_constructor = detail::fixed_delay_constructor_t<832, 1165>; +}; + +#if CUB_IS_INT128_ENABLED +template <> +struct sm80_tuning<__int128_t, flagged::no, keep_rejects::no, offset_size::_4, primitive::no, input_size::_16> +{ + static constexpr int threads = 384; + static constexpr int items = 4; + + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT; + + using delay_constructor = detail::no_delay_constructor_t<1140>; +}; + +template <> +struct sm80_tuning<__uint128_t, flagged::no, keep_rejects::no, offset_size::_4, primitive::no, input_size::_16> +{ + static constexpr int threads = 384; + static constexpr int items = 4; + + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT; + + using delay_constructor = detail::no_delay_constructor_t<1140>; +}; +#endif + +// select::flagged +template +struct sm80_tuning +{ + static constexpr int threads = 224; + static constexpr int items = 20; + + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT; + + using delay_constructor = detail::no_delay_constructor_t<735>; +}; + +template +struct sm80_tuning +{ + static constexpr int threads = 256; + static constexpr int items = 20; + + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE; + + using delay_constructor = detail::no_delay_constructor_t<1155>; +}; + +template +struct sm80_tuning +{ + static constexpr int threads = 320; + static constexpr int items = 10; + + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT; + + using delay_constructor = detail::fixed_delay_constructor_t<124, 1115>; +}; + +template +struct sm80_tuning +{ + static constexpr int threads = 384; + static constexpr int items = 6; + + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT; + + using delay_constructor = detail::no_delay_constructor_t<1130>; +}; + +#if CUB_IS_INT128_ENABLED +template <> +struct sm80_tuning<__int128_t, flagged::yes, keep_rejects::no, offset_size::_4, primitive::no, input_size::_16> +{ + static constexpr int threads = 256; + static constexpr int items = 5; + + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT; + + using delay_constructor = detail::fixed_delay_constructor_t<464, 1025>; +}; + +template <> +struct sm80_tuning<__uint128_t, flagged::yes, keep_rejects::no, offset_size::_4, primitive::no, input_size::_16> +{ + static constexpr int threads = 256; + static constexpr int items = 5; + + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT; + + using delay_constructor = detail::fixed_delay_constructor_t<464, 1025>; +}; +#endif + +// partition::if +template +struct sm80_tuning +{ + static constexpr int threads = 512; + static constexpr int items = 20; + + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT; + + using delay_constructor = detail::no_delay_constructor_t<510>; +}; + +template +struct sm80_tuning +{ + static constexpr int threads = 224; + static constexpr int items = 18; + + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE; + + using delay_constructor = detail::no_delay_constructor_t<1045>; +}; + +template +struct sm80_tuning +{ + static constexpr int threads = 192; + static constexpr int items = 15; + + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT; + + using delay_constructor = detail::no_delay_constructor_t<1040>; +}; + +template +struct sm80_tuning +{ + static constexpr int threads = 192; + static constexpr int items = 10; + + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE; + + using delay_constructor = detail::fixed_delay_constructor_t<68, 1160>; +}; + +#if CUB_IS_INT128_ENABLED +template <> +struct sm80_tuning<__int128_t, flagged::no, keep_rejects::yes, offset_size::_4, primitive::no, input_size::_16> +{ + static constexpr int threads = 256; + static constexpr int items = 5; + + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE; + + using delay_constructor = detail::fixed_delay_constructor_t<400, 1090>; +}; + +template <> +struct sm80_tuning<__uint128_t, flagged::no, keep_rejects::yes, offset_size::_4, primitive::no, input_size::_16> +{ + static constexpr int threads = 256; + static constexpr int items = 5; + + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE; + + using delay_constructor = detail::fixed_delay_constructor_t<400, 1090>; +}; +#endif + +// partition::flagged +template +struct sm80_tuning +{ + static constexpr int threads = 512; + static constexpr int items = 20; + + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT; + + using delay_constructor = detail::no_delay_constructor_t<595>; +}; + +template +struct sm80_tuning +{ + static constexpr int threads = 224; + static constexpr int items = 18; + + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE; + + using delay_constructor = detail::no_delay_constructor_t<1105>; +}; + +template +struct sm80_tuning +{ + static constexpr int threads = 192; + static constexpr int items = 12; + + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT; + + using delay_constructor = detail::fixed_delay_constructor_t<912, 1025>; +}; + +template +struct sm80_tuning +{ + static constexpr int threads = 192; + static constexpr int items = 12; + + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE; + + using delay_constructor = detail::fixed_delay_constructor_t<884, 1130>; +}; + +#if CUB_IS_INT128_ENABLED +template <> +struct sm80_tuning<__int128_t, flagged::yes, keep_rejects::yes, offset_size::_4, primitive::no, input_size::_16> +{ + static constexpr int threads = 256; + static constexpr int items = 5; + + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE; + + using delay_constructor = detail::fixed_delay_constructor_t<400, 1090>; +}; + +template <> +struct sm80_tuning<__uint128_t, flagged::yes, keep_rejects::yes, offset_size::_4, primitive::no, input_size::_16> +{ + static constexpr int threads = 256; + static constexpr int items = 5; + + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE; + + using delay_constructor = detail::fixed_delay_constructor_t<400, 1090>; +}; +#endif + } // namespace select template struct device_select_policy_hub { - struct Policy350 : ChainedPolicy<350, Policy350, Policy350> + struct DefaultTuning { static constexpr int NOMINAL_4B_ITEMS_PER_THREAD = 10; @@ -403,7 +701,32 @@ struct device_select_policy_hub detail::fixed_delay_constructor_t<350, 450>>; }; - struct Policy900 : ChainedPolicy<900, Policy900, Policy350> + struct Policy350 + : DefaultTuning + , ChainedPolicy<350, Policy350, Policy350> + {}; + + struct Policy800 : ChainedPolicy<800, Policy800, Policy350> + { + using tuning = detail::select::sm80_tuning(), + select::are_rejects_kept(), + select::classify_offset_size()>; + + using SelectIfPolicyT = AgentSelectIfPolicy; + }; + + struct Policy860 + : DefaultTuning + , ChainedPolicy<860, Policy860, Policy800> + {}; + + struct Policy900 : ChainedPolicy<900, Policy900, Policy860> { using tuning = detail::select::sm90_tuning(), diff --git a/cub/cub/device/dispatch/tuning/tuning_three_way_partition.cuh b/cub/cub/device/dispatch/tuning/tuning_three_way_partition.cuh index a1b3c43dc9e..8c8fabe79e0 100644 --- a/cub/cub/device/dispatch/tuning/tuning_three_way_partition.cuh +++ b/cub/cub/device/dispatch/tuning/tuning_three_way_partition.cuh @@ -134,7 +134,7 @@ struct sm90_tuning static constexpr int threads = 384; static constexpr int items = 7; - static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_TRANSPOSE; + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE; using delay_constructor = detail::fixed_delay_constructor_t<464, 1165>; }; @@ -145,7 +145,7 @@ struct sm90_tuning static constexpr int threads = 128; static constexpr int items = 7; - static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_TRANSPOSE; + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE; using delay_constructor = detail::no_delay_constructor_t<1040>; }; @@ -167,7 +167,7 @@ struct sm90_tuning static constexpr int threads = 640; static constexpr int items = 24; - static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_TRANSPOSE; + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE; using delay_constructor = detail::no_delay_constructor_t<245>; }; @@ -178,7 +178,7 @@ struct sm90_tuning static constexpr int threads = 256; static constexpr int items = 23; - static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_TRANSPOSE; + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE; using delay_constructor = detail::no_delay_constructor_t<910>; }; @@ -189,7 +189,7 @@ struct sm90_tuning static constexpr int threads = 256; static constexpr int items = 18; - static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_TRANSPOSE; + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE; using delay_constructor = detail::no_delay_constructor_t<1145>; }; @@ -200,18 +200,77 @@ struct sm90_tuning static constexpr int threads = 256; static constexpr int items = 11; - static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_TRANSPOSE; + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE; using delay_constructor = detail::no_delay_constructor_t<1050>; }; +template (), + offset_size OffsetSize = classify_offset_size()> +struct sm80_tuning +{ + static constexpr int threads = 256; + static constexpr int items = Nominal4BItemsToItems(9); + + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT; + + using AccumPackHelperT = detail::three_way_partition::accumulator_pack_t; + using AccumPackT = typename AccumPackHelperT::pack_t; + using delay_constructor = detail::default_delay_constructor_t; +}; + +template +struct sm80_tuning +{ + static constexpr int threads = 256; + static constexpr int items = 12; + + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE; + + using delay_constructor = detail::no_delay_constructor_t<910>; +}; + +template +struct sm80_tuning +{ + static constexpr int threads = 256; + static constexpr int items = 11; + + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE; + + using delay_constructor = detail::no_delay_constructor_t<1120>; +}; + +template +struct sm80_tuning +{ + static constexpr int threads = 224; + static constexpr int items = 11; + + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE; + + using delay_constructor = detail::fixed_delay_constructor_t<264, 1080>; +}; + +template +struct sm80_tuning +{ + static constexpr int threads = 128; + static constexpr int items = 10; + + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE; + + using delay_constructor = detail::fixed_delay_constructor_t<672, 1120>; +}; + } // namespace three_way_partition template struct device_three_way_partition_policy_hub { - /// SM35 - struct Policy350 : ChainedPolicy<350, Policy350, Policy350> + struct DefaultTuning { constexpr static int ITEMS_PER_THREAD = Nominal4BItemsToItems(9); @@ -222,8 +281,32 @@ struct device_three_way_partition_policy_hub cub::BLOCK_SCAN_WARP_SCANS>; }; + /// SM35 + struct Policy350 + : DefaultTuning + , ChainedPolicy<350, Policy350, Policy350> + {}; + + struct Policy800 : ChainedPolicy<800, Policy800, Policy350> + { + using tuning = detail::three_way_partition::sm80_tuning; + + using ThreeWayPartitionPolicy = + AgentThreeWayPartitionPolicy; + }; + + struct Policy860 + : DefaultTuning + , ChainedPolicy<860, Policy860, Policy800> + {}; + /// SM90 - struct Policy900 : ChainedPolicy<900, Policy900, Policy350> + struct Policy900 : ChainedPolicy<900, Policy900, Policy860> { using tuning = detail::three_way_partition::sm90_tuning;