diff --git a/cpp/src/arrow/engine/substrait/relation_internal.cc b/cpp/src/arrow/engine/substrait/relation_internal.cc
index c19ba18a38335..73f55c27ee834 100644
--- a/cpp/src/arrow/engine/substrait/relation_internal.cc
+++ b/cpp/src/arrow/engine/substrait/relation_internal.cc
@@ -58,6 +58,7 @@
 #include "arrow/status.h"
 #include "arrow/type.h"
 #include "arrow/util/checked_cast.h"
+#include "arrow/util/logging.h"
 #include "arrow/util/string.h"
 #include "arrow/util/uri.h"
 
@@ -329,6 +330,54 @@ ARROW_ENGINE_EXPORT Result MakeAggregateDeclaration(
 
 }  // namespace internal
 
+namespace {
+
+/// Acero sort options (null placement and sort order) that correspond to a
+/// Substrait sort direction.
+struct SortBehavior {
+  compute::NullPlacement null_placement;
+  compute::SortOrder sort_order;
+
+  /// Translate a Substrait sort direction.  Returns Invalid when the direction is
+  /// unspecified and NotImplemented for clustered or unrecognized directions.
+  static Result<SortBehavior> Make(substrait::SortField::SortDirection dir) {
+    SortBehavior sort_behavior;
+    switch (dir) {
+      case substrait::SortField::SortDirection::
+          SortField_SortDirection_SORT_DIRECTION_UNSPECIFIED:
+        return Status::Invalid("The substrait plan does not specify a sort direction");
+      case substrait::SortField::SortDirection::
+          SortField_SortDirection_SORT_DIRECTION_ASC_NULLS_FIRST:
+        sort_behavior.null_placement = compute::NullPlacement::AtStart;
+        sort_behavior.sort_order = compute::SortOrder::Ascending;
+        break;
+      case substrait::SortField::SortDirection::
+          SortField_SortDirection_SORT_DIRECTION_ASC_NULLS_LAST:
+        sort_behavior.null_placement = compute::NullPlacement::AtEnd;
+        sort_behavior.sort_order = compute::SortOrder::Ascending;
+        break;
+      case substrait::SortField::SortDirection::
+          SortField_SortDirection_SORT_DIRECTION_DESC_NULLS_FIRST:
+        sort_behavior.null_placement = compute::NullPlacement::AtStart;
+        sort_behavior.sort_order = compute::SortOrder::Descending;
+        break;
+      case substrait::SortField::SortDirection::
+          SortField_SortDirection_SORT_DIRECTION_DESC_NULLS_LAST:
+        sort_behavior.null_placement = compute::NullPlacement::AtEnd;
+        sort_behavior.sort_order = compute::SortOrder::Descending;
+        break;
+      case substrait::SortField::SortDirection::
+          SortField_SortDirection_SORT_DIRECTION_CLUSTERED:
+      default:
+        return Status::NotImplemented(
+            "Acero does not support the specified sort direction: ", dir);
+    }
+    return sort_behavior;
+  }
+};
+
+}  // namespace
+
 Result<DeclarationInfo> FromProto(const substrait::Rel& rel, const ExtensionSet& ext_set,
                                   const ConversionOptions& conversion_options) {
   static bool dataset_init = false;
@@ -714,6 +763,85 @@ Result<DeclarationInfo> FromProto(const substrait::Rel& rel, const ExtensionSet&
 
       return ProcessEmit(join, join_declaration, join_schema);
     }
+    case substrait::Rel::RelTypeCase::kFetch: {
+      const auto& fetch = rel.fetch();
+      RETURN_NOT_OK(CheckRelCommon(fetch, conversion_options));
+
+      if (!fetch.has_input()) {
+        return Status::Invalid("substrait::FetchRel with no input relation");
+      }
+
+      ARROW_ASSIGN_OR_RAISE(auto input,
+                            FromProto(fetch.input(), ext_set, conversion_options));
+
+      int64_t offset = fetch.offset();
+      int64_t count = fetch.count();
+
+      acero::Declaration fetch_dec{
+          "fetch", {input.declaration}, acero::FetchNodeOptions(offset, count)};
+
+      DeclarationInfo fetch_declaration{std::move(fetch_dec), input.output_schema};
+      // Not moved: the schema argument reads from the same object in this call.
+      return ProcessEmit(fetch, fetch_declaration, fetch_declaration.output_schema);
+    }
+    case substrait::Rel::RelTypeCase::kSort: {
+      const auto& sort = rel.sort();
+      RETURN_NOT_OK(CheckRelCommon(sort, conversion_options));
+
+      if (!sort.has_input()) {
+        return Status::Invalid("substrait::SortRel with no input relation");
+      }
+
+      ARROW_ASSIGN_OR_RAISE(auto input,
+                            FromProto(sort.input(), ext_set, conversion_options));
+
+      if (sort.sorts_size() == 0) {
+        return Status::Invalid("substrait::SortRel with no sorts");
+      }
+
+      std::vector<compute::SortKey> sort_keys;
+      sort_keys.reserve(sort.sorts_size());
+      // Substrait allows null placement to differ for each field.  Acero expects it to
+      // be consistent across all fields.  So we grab the null placement from the first
+      // key and verify all other keys have the same null placement
+      std::optional<SortBehavior> sample_sort_behavior;
+      for (const auto& sort_field : sort.sorts()) {
+        // Check the sort kind before reading direction(): the direction is only
+        // set for kDirection sorts, so custom sorts get the right error here.
+        if (sort_field.sort_kind_case() !=
+            substrait::SortField::SortKindCase::kDirection) {
+          return Status::NotImplemented("substrait::SortRel with custom sort function");
+        }
+        ARROW_ASSIGN_OR_RAISE(SortBehavior sort_behavior,
+                              SortBehavior::Make(sort_field.direction()));
+        if (sample_sort_behavior) {
+          if (sample_sort_behavior->null_placement != sort_behavior.null_placement) {
+            return Status::NotImplemented(
+                "substrait::SortRel with ordering with mixed null placement");
+          }
+        } else {
+          sample_sort_behavior = sort_behavior;
+        }
+        ARROW_ASSIGN_OR_RAISE(compute::Expression expr,
+                              FromProto(sort_field.expr(), ext_set, conversion_options));
+        const FieldRef* field_ref = expr.field_ref();
+        if (field_ref == nullptr) {
+          return Status::Invalid("Sort key expressions must be a direct reference.");
+        }
+        sort_keys.push_back(compute::SortKey(*field_ref, sort_behavior.sort_order));
+      }
+
+      DCHECK(sample_sort_behavior.has_value());
+      acero::Declaration sort_dec{
+          "order_by",
+          {input.declaration},
+          acero::OrderByNodeOptions(compute::Ordering(
+              std::move(sort_keys), sample_sort_behavior->null_placement))};
+
+      DeclarationInfo sort_declaration{std::move(sort_dec), input.output_schema};
+      // Not moved: the schema argument reads from the same object in this call.
+      return ProcessEmit(sort, sort_declaration, sort_declaration.output_schema);
+    }
     case substrait::Rel::RelTypeCase::kAggregate: {
       const auto& aggregate = rel.aggregate();
       RETURN_NOT_OK(CheckRelCommon(aggregate, conversion_options));
diff --git a/cpp/src/arrow/engine/substrait/serde_test.cc b/cpp/src/arrow/engine/substrait/serde_test.cc
index fb211707f05f8..6342388744f39 100644
--- a/cpp/src/arrow/engine/substrait/serde_test.cc
+++ b/cpp/src/arrow/engine/substrait/serde_test.cc
@@ -5257,6 +5257,219 @@ TEST(Substrait, CompoundEmitWithFilter) {
   CheckRoundTripResult(std::move(expected_table), buf, {}, conversion_options);
 }
 
+TEST(Substrait, SortAndFetch) {
+  // Sort by A, ascending, take items [2, 5), then sort by B descending
+  std::string substrait_json = R"({
+    "version": {
+      "major_number": 9999,
+      "minor_number": 9999,
+      "patch_number": 9999
+    },
+    "relations": [
+      {
+        "rel": {
+          "sort": {
+            "input": {
+              "fetch": {
+                "input": {
+                  "sort": {
+                    "input": {
+                      "read": {
+                        "base_schema": {
+                          "names": [
+                            "A",
+                            "B"
+                          ],
+                          "struct": {
+                            "types": [
+                              {
+                                "i32": {}
+                              },
+                              {
+                                "i32": {}
+                              }
+                            ]
+                          }
+                        },
+                        "namedTable": {
+                          "names": [
+                            "table"
+                          ]
+                        }
+                      }
+                    },
+                    "sorts": [
+                      {
+                        "expr": {
+                          "selection": {
+                            "directReference": {
+                              "structField": {
+                                "field": 0
+                              }
+                            },
+                            "rootReference": {}
+                          }
+                        },
+                        "direction": "SORT_DIRECTION_ASC_NULLS_FIRST"
+                      }
+                    ]
+                  }
+                },
+                "offset": 2,
+                "count": 3
+              }
+            },
+            "sorts": [
+              {
+                "expr": {
+                  "selection": {
+                    "directReference": {
+                      "structField": {
+                        "field": 1
+                      }
+                    },
+                    "rootReference": {}
+                  }
+                },
+                "direction": "SORT_DIRECTION_DESC_NULLS_LAST"
+              }
+            ]
+          }
+        }
+      }
+    ],
+    "extension_uris": [],
+    "extensions": []
+})";
+
+  ASSERT_OK_AND_ASSIGN(auto buf, internal::SubstraitFromJSON("Plan", substrait_json));
+  auto test_schema = schema({field("A", int32()), field("B", int32())});
+
+  auto input_table = TableFromJSON(test_schema, {R"([
+      [null, null],
+      [5, 8],
+      [null, null],
+      [null, null],
+      [3, 4],
+      [9, 6],
+      [4, 5]
+  ])"});
+
+  // First sort by A, ascending, nulls first to yield rows:
+  // 0, 2, 3, 4, 6, 1, 5
+  // Apply fetch to grab rows 3, 4, 6
+  // Then sort by B, descending, to yield rows 6, 4, 3
+
+  auto output_table = TableFromJSON(test_schema, {R"([
+      [4, 5],
+      [3, 4],
+      [null, null]
+  ])"});
+
+  ConversionOptions conversion_options;
+  conversion_options.named_table_provider =
+      AlwaysProvideSameTable(std::move(input_table));
+
+  CheckRoundTripResult(std::move(output_table), buf, {}, conversion_options);
+}
+
+TEST(Substrait, MixedSort) {
+  // Substrait allows sort keys with differing null placement but Acero
+  // does not.  We should detect this and reject it.
+  std::string substrait_json = R"({
+    "version": {
+      "major_number": 9999,
+      "minor_number": 9999,
+      "patch_number": 9999
+    },
+    "relations": [
+      {
+        "rel": {
+          "sort": {
+            "input": {
+              "read": {
+                "base_schema": {
+                  "names": [
+                    "A",
+                    "B"
+                  ],
+                  "struct": {
+                    "types": [
+                      {
+                        "i32": {}
+                      },
+                      {
+                        "i32": {}
+                      }
+                    ]
+                  }
+                },
+                "namedTable": {
+                  "names": [
+                    "table"
+                  ]
+                }
+              }
+            },
+            "sorts": [
+              {
+                "expr": {
+                  "selection": {
+                    "directReference": {
+                      "structField": {
+                        "field": 0
+                      }
+                    },
+                    "rootReference": {}
+                  }
+                },
+                "direction": "SORT_DIRECTION_ASC_NULLS_FIRST"
+              },
+              {
+                "expr": {
+                  "selection": {
+                    "directReference": {
+                      "structField": {
+                        "field": 1
+                      }
+                    },
+                    "rootReference": {}
+                  }
+                },
+                "direction": "SORT_DIRECTION_ASC_NULLS_LAST"
+              }
+            ]
+          }
+        }
+      }
+    ],
+    "extension_uris": [],
+    "extensions": []
+})";
+
+  ASSERT_OK_AND_ASSIGN(auto buf, internal::SubstraitFromJSON("Plan", substrait_json));
+  auto test_schema = schema({field("A", int32()), field("B", int32())});
+
+  auto input_table = TableFromJSON(test_schema, {R"([
+      [null, null],
+      [5, 8],
+      [null, null],
+      [null, null],
+      [3, 4],
+      [9, 6],
+      [4, 5]
+  ])"});
+
+  ConversionOptions conversion_options;
+  conversion_options.named_table_provider =
+      AlwaysProvideSameTable(std::move(input_table));
+
+  ASSERT_THAT(
+      DeserializePlan(*buf, /*registry=*/nullptr, /*ext_set_out=*/nullptr,
+                      conversion_options),
+      Raises(StatusCode::NotImplemented, testing::HasSubstr("mixed null placement")));
+}
+
 TEST(Substrait, PlanWithExtension) {
   // This demos an extension relation
   std::string substrait_json = R"({