Skip to content

Commit

Permalink
[xla:cpu] Add support for 17 sort inputs.
Browse files Browse the repository at this point in the history
Fixes jax-ml/jax#23727
This is a temporary fix. We will add a fallback sort kernel soon.

PiperOrigin-RevId: 676420937
  • Loading branch information
penpornk authored and Google-ML-Automation committed Sep 19, 2024
1 parent b8fffbb commit a7bacdc
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 0 deletions.
3 changes: 3 additions & 0 deletions xla/backends/cpu/runtime/sort_thunk.cc
Original file line number Diff line number Diff line change
Expand Up @@ -468,6 +468,9 @@ static absl::Status SortInplace(absl::Span<se::DeviceMemoryBase> data,
case 16:
sort(std::integral_constant<size_t, 16>{});
break;
case 17:
sort(std::integral_constant<size_t, 17>{});
break;
case 25:
sort(std::integral_constant<size_t, 25>{});
break;
Expand Down
2 changes: 2 additions & 0 deletions xla/tests/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1803,6 +1803,8 @@ xla_test(
":test_macros_header",
":xla_internal_test_main",
"//xla:error_spec",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@com_google_googletest//:gtest_main",
],
)
Expand Down
52 changes: 52 additions & 0 deletions xla/tests/sort_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,15 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <string>
#include <string_view>
#include <vector>

#include <gtest/gtest.h>
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "absl/strings/str_replace.h"
#include "xla/error_spec.h"
#include "xla/tests/hlo_test_base.h"
#include "xla/tests/test_macros.h"
Expand Down Expand Up @@ -85,5 +91,51 @@ XLA_TEST_F(SortTest, SortTwiceWithSameComparator) {
EXPECT_TRUE(RunAndCompare(hlo_text_module, ErrorSpec{0.0, 0.0}));
}

// TODO(penporn): Parameterize `num_inputs` and test several numbers when we
// have a generic fallback sort kernel.
XLA_TEST_F(SortTest, SortManyInputs) {
constexpr int num_inputs = 17;
std::string_view hlo_text_module_template = R"(
HloModule sort
compare {
${COMPARE_DECLARATIONS}
ROOT lt = pred[] compare(p0, p1), direction=LT
}
ENTRY e {
${SORT_DECLARATIONS}
ROOT sort = (${SORT_SHAPE}) sort(${SORT_PARAMS}), dimensions={0},
to_apply=compare
}
)";

// Prepare values for template substitutions.
std::string sort_decls = "";
std::vector<std::string> param_names;
param_names.reserve(num_inputs * 2);
for (int i = 0; i < num_inputs; ++i) {
sort_decls += absl::StrFormat("p%d = f32[32,64] parameter(%d)\n", i, i);
param_names.emplace_back(absl::StrCat("p", i));
}
std::string sort_params = absl::StrJoin(param_names, ", ");
std::string sort_shape =
absl::StrJoin(std::vector<std::string>(num_inputs, "f32[32,64]"), ",");
std::string compare_decls = "";
for (int i = 0; i < num_inputs * 2; ++i) {
compare_decls += absl::StrFormat("p%d = f32[] parameter(%d)\n", i, i);
}
std::string compare_params = absl::StrJoin(param_names, ", ");

// Finalize HLO text.
std::string hlo_text_module = absl::StrReplaceAll(
hlo_text_module_template, {{"${SORT_DECLARATIONS}", sort_decls},
{"${SORT_SHAPE}", sort_shape},
{"${SORT_PARAMS}", sort_params},
{"${COMPARE_DECLARATIONS}", compare_decls}});

EXPECT_TRUE(RunAndCompare(hlo_text_module, ErrorSpec{0.0, 0.0}));
}

} // namespace
} // namespace xla

0 comments on commit a7bacdc

Please sign in to comment.