diff --git a/cpp/src/arrow/engine/substrait/relation_internal.cc b/cpp/src/arrow/engine/substrait/relation_internal.cc
index eefc37607ad98..c599d1f982561 100644
--- a/cpp/src/arrow/engine/substrait/relation_internal.cc
+++ b/cpp/src/arrow/engine/substrait/relation_internal.cc
@@ -293,6 +293,27 @@ Status DiscoverFilesFromDir(const std::shared_ptr<fs::LocalFileSystem>& local_fs
   return Status::OK();
 }
 
+// Returns true if `direction` places nulls before non-null values.
+//
+// Only the four ASC/DESC x NULLS_FIRST/NULLS_LAST values are meaningful here;
+// callers are expected to reject UNSPECIFIED and CLUSTERED beforehand.
+bool IsSortNullsFirst(const substrait::SortField::SortDirection& direction) {
+  return direction == substrait::SortField::SORT_DIRECTION_ASC_NULLS_FIRST ||
+         direction == substrait::SortField::SORT_DIRECTION_DESC_NULLS_FIRST;
+}
+
+// Maps a Substrait sort direction onto the Acero sort order, ignoring null placement.
+//
+// Only the four ASC/DESC x NULLS_FIRST/NULLS_LAST values are meaningful here;
+// callers are expected to reject UNSPECIFIED and CLUSTERED beforehand.
+compute::SortOrder SortOrderFromDirection(
+    const substrait::SortField::SortDirection& direction) {
+  const bool is_descending =
+      direction == substrait::SortField::SORT_DIRECTION_DESC_NULLS_FIRST ||
+      direction == substrait::SortField::SORT_DIRECTION_DESC_NULLS_LAST;
+  return is_descending ? compute::SortOrder::Descending : compute::SortOrder::Ascending;
+}
+
 Result<DeclarationInfo> FromProto(const substrait::Rel& rel, const ExtensionSet& ext_set,
                                   const ConversionOptions& conversion_options) {
   static bool dataset_init = false;
@@ -683,6 +704,99 @@ Result<DeclarationInfo> FromProto(const substrait::Rel& rel, const ExtensionSet&
       return ProcessEmit(std::move(join), std::move(join_declaration),
                          std::move(join_schema));
     }
+    case substrait::Rel::RelTypeCase::kFetch: {
+      const auto& fetch = rel.fetch();
+      RETURN_NOT_OK(CheckRelCommon(fetch, conversion_options));
+
+      if (!fetch.has_input()) {
+        return Status::Invalid("substrait::FetchRel with no input relation");
+      }
+
+      ARROW_ASSIGN_OR_RAISE(auto input,
+                            FromProto(fetch.input(), ext_set, conversion_options));
+
+      int64_t offset = fetch.offset();
+      int64_t count = fetch.count();
+
+      compute::Declaration fetch_dec{
+          "fetch", {input.declaration}, compute::FetchNodeOptions(offset, count)};
+
+      // Fetch does not change the schema, so reuse the input's.  Pass it from `input`
+      // rather than reading it back out of `fetch_declaration`, which is moved from
+      // in the same call.
+      DeclarationInfo fetch_declaration{std::move(fetch_dec), input.output_schema};
+      return ProcessEmit(fetch, std::move(fetch_declaration), input.output_schema);
+    }
+    case substrait::Rel::RelTypeCase::kSort: {
+      const auto& sort = rel.sort();
+      RETURN_NOT_OK(CheckRelCommon(sort, conversion_options));
+
+      if (!sort.has_input()) {
+        return Status::Invalid("substrait::SortRel with no input relation");
+      }
+
+      ARROW_ASSIGN_OR_RAISE(auto input,
+                            FromProto(sort.input(), ext_set, conversion_options));
+
+      if (sort.sorts_size() == 0) {
+        return Status::Invalid("substrait::SortRel with no sorts");
+      }
+
+      std::vector<compute::SortKey> sort_keys;
+      // Taken from the first sort field (there is at least one, see the check above).
+      // The initializer only keeps the variable from starting out indeterminate.
+      compute::NullPlacement null_placement = compute::NullPlacement::AtEnd;
+      bool first = true;
+      for (const auto& sort_field : sort.sorts()) {
+        // Reject non-direction sort kinds before looking at direction(): reading a
+        // oneof member that is not set yields its default value, which would otherwise
+        // be misreported as an unspecified direction.
+        if (sort_field.sort_kind_case() !=
+            substrait::SortField::SortKindCase::kDirection) {
+          return Status::NotImplemented("substrait::SortRel with custom sort function");
+        }
+        const auto direction = sort_field.direction();
+        if (direction == substrait::SortField::SORT_DIRECTION_UNSPECIFIED) {
+          return Status::Invalid(
+              "substrait::SortRel with sort that had unspecified direction");
+        }
+        if (direction == substrait::SortField::SORT_DIRECTION_CLUSTERED) {
+          return Status::NotImplemented(
+              "substrait::SortRel with sort with clustered sort direction");
+        }
+        // Substrait allows null placement to differ for each field.  Acero expects it to
+        // be consistent across all fields.  So we grab the null placement from the first
+        // key and verify all other keys have the same null placement
+        const compute::NullPlacement field_null_placement =
+            IsSortNullsFirst(direction) ? compute::NullPlacement::AtStart
+                                        : compute::NullPlacement::AtEnd;
+        if (first) {
+          null_placement = field_null_placement;
+          first = false;
+        } else if (null_placement != field_null_placement) {
+          return Status::NotImplemented(
+              "substrait::SortRel with ordering with mixed null placement");
+        }
+        ARROW_ASSIGN_OR_RAISE(compute::Expression expr,
+                              FromProto(sort_field.expr(), ext_set, conversion_options));
+        const FieldRef* field_ref = expr.field_ref();
+        if (field_ref == nullptr) {
+          return Status::Invalid("Sort key expressions must be a direct reference.");
+        }
+        sort_keys.emplace_back(*field_ref, SortOrderFromDirection(direction));
+      }
+
+      compute::Declaration sort_dec{"order_by",
+                                    {input.declaration},
+                                    compute::OrderByNodeOptions(compute::Ordering(
+                                        std::move(sort_keys), null_placement))};
+
+      // Sorting does not change the schema, so reuse the input's.  Pass it from `input`
+      // rather than reading it back out of `sort_declaration`, which is moved from in
+      // the same call.
+      DeclarationInfo sort_declaration{std::move(sort_dec), input.output_schema};
+      return ProcessEmit(sort, std::move(sort_declaration), input.output_schema);
+    }
     case substrait::Rel::RelTypeCase::kAggregate: {
       const auto& aggregate = rel.aggregate();
       RETURN_NOT_OK(CheckRelCommon(aggregate, conversion_options));
diff --git a/cpp/src/arrow/engine/substrait/serde_test.cc b/cpp/src/arrow/engine/substrait/serde_test.cc
index 06c3cc209659a..c094fde24be38 100644
--- a/cpp/src/arrow/engine/substrait/serde_test.cc
+++ b/cpp/src/arrow/engine/substrait/serde_test.cc
@@ -5136,6 +5136,128 @@ TEST(Substrait, CompoundEmitWithFilter) {
   CheckRoundTripResult(std::move(expected_table), buf, {}, conversion_options);
 }
 
+TEST(Substrait, SortAndFetch) {
+  // Sort by A, ascending, take items [2, 5), then sort by B descending
+  std::string substrait_json = R"({
+    "version": {
+      "major_number": 9999,
+      "minor_number": 9999,
+      "patch_number": 9999
+    },
+    "relations": [
+      {
+        "rel": {
+          "sort": {
+            "input": {
+              "fetch": {
+                "input": {
+                  "sort": {
+                    "input": {
+                      "read": {
+                        "base_schema": {
+                          "names": [
+                            "A",
+                            "B"
+                          ],
+                          "struct": {
+                            "types": [
+                              {
+                                "i32": {}
+                              },
+                              {
+                                "i32": {}
+                              }
+                            ]
+                          }
+                        },
+                        "namedTable": {
+                          "names": [
+                            "table"
+                          ]
+                        }
+                      }
+                    },
+                    "sorts": [
+                      {
+                        "expr": {
+                          "selection": {
+                            "directReference": {
+                              "structField": {
+                                "field": 0
+                              }
+                            },
+                            "rootReference": {}
+                          }
+                        },
+                        "direction": "SORT_DIRECTION_ASC_NULLS_FIRST"
+                      }
+                    ]
+                  }
+                },
+                "offset": 2,
+                "count": 3
+              }
+            },
+            "sorts": [
+              {
+                "expr": {
+                  "selection": {
+                    "directReference": {
+                      "structField": {
+                        "field": 1
+                      }
+                    },
+                    "rootReference": {}
+                  }
+                },
+                "direction": "SORT_DIRECTION_DESC_NULLS_LAST"
+              }
+            ]
+          }
+        }
+      }
+    ],
+    "extension_uris": [],
+    "extensions": []
+})";
+
+  ASSERT_OK_AND_ASSIGN(auto buf, internal::SubstraitFromJSON("Plan", substrait_json));
+  auto test_schema = schema({field("A", int32()), field("B", int32())});
+
+  auto input_table = TableFromJSON(test_schema, {R"([
+    [null, null],
+    [5, 8],
+    [null, null],
+    [null, null],
+    [3, 4],
+    [9, 6],
+    [4, 5]
+  ])"});
+
+  // First sort by A, ascending, nulls first to yield rows:
+  // 0, 2, 3, 4, 6, 1, 5
+  // Apply fetch (offset 2, count 3) to grab rows 3, 4, 6
+  // Then sort by B, descending, nulls last, to yield rows 6, 4, 3
+
+  auto output_table = TableFromJSON(test_schema, {R"([
+    [4, 5],
+    [3, 4],
+    [null, null]
+  ])"});
+
+  NamedTableProvider table_provider = [&](const std::vector<std::string>& names,
+                                          const Schema&) {
+    std::shared_ptr<compute::ExecNodeOptions> options =
+        std::make_shared<compute::TableSourceNodeOptions>(input_table);
+    return compute::Declaration("table_source", {}, options, "mock_source");
+  };
+
+  ConversionOptions conversion_options;
+  conversion_options.named_table_provider = std::move(table_provider);
+
+  CheckRoundTripResult(std::move(output_table), buf, {}, conversion_options);
+}
+
 TEST(Substrait, PlanWithExtension) {
   // This demos an extension relation
   std::string substrait_json = R"({