Skip to content

Commit

Permalink
GH-32763: [C++] Add FromProto for fetch & sort (#34651)
Browse files Browse the repository at this point in the history
This does not support the clustered sort direction, custom sort functions, or complex (non-reference) expressions as sort keys.
* Closes: #32763

Authored-by: Weston Pace <[email protected]>
Signed-off-by: Antoine Pitrou <[email protected]>
  • Loading branch information
westonpace authored May 22, 2023
1 parent 2a6848c commit fbe0d5f
Show file tree
Hide file tree
Showing 2 changed files with 341 additions and 0 deletions.
122 changes: 122 additions & 0 deletions cpp/src/arrow/engine/substrait/relation_internal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
#include "arrow/status.h"
#include "arrow/type.h"
#include "arrow/util/checked_cast.h"
#include "arrow/util/logging.h"
#include "arrow/util/string.h"
#include "arrow/util/uri.h"

Expand Down Expand Up @@ -329,6 +330,50 @@ ARROW_ENGINE_EXPORT Result<DeclarationInfo> MakeAggregateDeclaration(

} // namespace internal

namespace {

/// \brief Acero sort settings derived from a single Substrait sort direction
///
/// A Substrait sort direction encodes both the ordering (ascending / descending)
/// and where nulls are placed; this struct holds the two pieces separately.
struct SortBehavior {
  compute::NullPlacement null_placement;
  compute::SortOrder sort_order;

  /// \brief Translate a Substrait sort direction
  ///
  /// \param dir the direction taken from a substrait::SortField
  /// \return the matching placement/order pair, Status::Invalid if the direction
  ///         is unspecified, or Status::NotImplemented for the clustered direction
  ///         and any value not handled here
  static Result<SortBehavior> Make(substrait::SortField::SortDirection dir) {
    using compute::NullPlacement;
    using compute::SortOrder;

    switch (dir) {
      case substrait::SortField::SortDirection::
          SortField_SortDirection_SORT_DIRECTION_UNSPECIFIED:
        return Status::Invalid("The substrait plan does not specify a sort direction");
      case substrait::SortField::SortDirection::
          SortField_SortDirection_SORT_DIRECTION_ASC_NULLS_FIRST:
        return SortBehavior{NullPlacement::AtStart, SortOrder::Ascending};
      case substrait::SortField::SortDirection::
          SortField_SortDirection_SORT_DIRECTION_ASC_NULLS_LAST:
        return SortBehavior{NullPlacement::AtEnd, SortOrder::Ascending};
      case substrait::SortField::SortDirection::
          SortField_SortDirection_SORT_DIRECTION_DESC_NULLS_FIRST:
        return SortBehavior{NullPlacement::AtStart, SortOrder::Descending};
      case substrait::SortField::SortDirection::
          SortField_SortDirection_SORT_DIRECTION_DESC_NULLS_LAST:
        return SortBehavior{NullPlacement::AtEnd, SortOrder::Descending};
      case substrait::SortField::SortDirection::
          SortField_SortDirection_SORT_DIRECTION_CLUSTERED:
      default:
        // Handled below: nothing in Acero corresponds to these directions
        break;
    }
    return Status::NotImplemented(
        "Acero does not support the specified sort direction: ", dir);
  }
};

} // namespace

Result<DeclarationInfo> FromProto(const substrait::Rel& rel, const ExtensionSet& ext_set,
const ConversionOptions& conversion_options) {
static bool dataset_init = false;
Expand Down Expand Up @@ -714,6 +759,83 @@ Result<DeclarationInfo> FromProto(const substrait::Rel& rel, const ExtensionSet&

return ProcessEmit(join, join_declaration, join_schema);
}
case substrait::Rel::RelTypeCase::kFetch: {
  // Convert a substrait::FetchRel into an Acero "fetch" declaration stacked on
  // top of the converted input relation. The fetch does not change the schema,
  // so the input's output schema is reused.
  const auto& fetch = rel.fetch();
  RETURN_NOT_OK(CheckRelCommon(fetch, conversion_options));

  if (!fetch.has_input()) {
    return Status::Invalid("substrait::FetchRel with no input relation");
  }

  ARROW_ASSIGN_OR_RAISE(auto input,
                        FromProto(fetch.input(), ext_set, conversion_options));

  const int64_t offset = fetch.offset();
  const int64_t count = fetch.count();

  acero::Declaration fetch_dec{
      "fetch", {input.declaration}, acero::FetchNodeOptions(offset, count)};

  DeclarationInfo fetch_declaration{std::move(fetch_dec), input.output_schema};
  // Pass the schema from `input` rather than from `fetch_declaration`: the
  // latter is handed over with std::move in the same call and the order in
  // which function arguments are evaluated is unspecified.
  return ProcessEmit(fetch, std::move(fetch_declaration), input.output_schema);
}
case substrait::Rel::RelTypeCase::kSort: {
  // Convert a substrait::SortRel into an Acero "order_by" declaration stacked on
  // top of the converted input relation. Sorting does not change the schema, so
  // the input's output schema is reused.
  //
  // Not supported: custom sort functions, sort keys that are not direct field
  // references, and orderings whose keys disagree on null placement.
  const auto& sort = rel.sort();
  RETURN_NOT_OK(CheckRelCommon(sort, conversion_options));

  if (!sort.has_input()) {
    return Status::Invalid("substrait::SortRel with no input relation");
  }

  ARROW_ASSIGN_OR_RAISE(auto input,
                        FromProto(sort.input(), ext_set, conversion_options));

  if (sort.sorts_size() == 0) {
    return Status::Invalid("substrait::SortRel with no sorts");
  }

  std::vector<compute::SortKey> sort_keys;
  sort_keys.reserve(sort.sorts_size());
  // Substrait allows null placement to differ for each field.  Acero expects it to
  // be consistent across all fields.  So we grab the null placement from the first
  // key and verify all other keys have the same null placement
  std::optional<SortBehavior> sample_sort_behavior;
  for (const auto& sort_field : sort.sorts()) {
    // Reject custom comparison functions before looking at the direction: when
    // the oneof holds a function reference, direction() only yields the
    // UNSPECIFIED default, which would be reported as an invalid plan instead
    // of an unsupported one.
    if (sort_field.sort_kind_case() ==
        substrait::SortField::SortKindCase::kComparisonFunctionReference) {
      return Status::NotImplemented("substrait::SortRel with custom sort function");
    }
    // A field with no sort kind set at all falls through to here and is
    // rejected by Make as an unspecified direction.
    ARROW_ASSIGN_OR_RAISE(SortBehavior sort_behavior,
                          SortBehavior::Make(sort_field.direction()));
    if (sample_sort_behavior) {
      if (sample_sort_behavior->null_placement != sort_behavior.null_placement) {
        return Status::NotImplemented(
            "substrait::SortRel with ordering with mixed null placement");
      }
    } else {
      sample_sort_behavior = sort_behavior;
    }
    ARROW_ASSIGN_OR_RAISE(compute::Expression expr,
                          FromProto(sort_field.expr(), ext_set, conversion_options));
    const FieldRef* field_ref = expr.field_ref();
    if (field_ref == nullptr) {
      return Status::Invalid("Sort key expressions must be a direct reference.");
    }
    sort_keys.push_back(compute::SortKey(*field_ref, sort_behavior.sort_order));
  }

  // The loop ran at least once (sorts_size() > 0) and every iteration either
  // returned or left a sample behind.
  DCHECK(sample_sort_behavior.has_value());
  acero::Declaration sort_dec{
      "order_by",
      {input.declaration},
      acero::OrderByNodeOptions(compute::Ordering(
          std::move(sort_keys), sample_sort_behavior->null_placement))};

  DeclarationInfo sort_declaration{std::move(sort_dec), input.output_schema};
  // Pass the schema from `input` rather than from `sort_declaration`: the latter
  // is handed over with std::move in the same call and the order in which
  // function arguments are evaluated is unspecified.
  return ProcessEmit(sort, std::move(sort_declaration), input.output_schema);
}
case substrait::Rel::RelTypeCase::kAggregate: {
const auto& aggregate = rel.aggregate();
RETURN_NOT_OK(CheckRelCommon(aggregate, conversion_options));
Expand Down
219 changes: 219 additions & 0 deletions cpp/src/arrow/engine/substrait/serde_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5257,6 +5257,225 @@ TEST(Substrait, CompoundEmitWithFilter) {
CheckRoundTripResult(std::move(expected_table), buf, {}, conversion_options);
}

// End-to-end check of the sort and fetch relations: a plan that nests
// sort -> fetch -> sort over a named table is deserialized and its result is
// compared against a hand-computed table.
TEST(Substrait, SortAndFetch) {
// Sort by A, ascending, take items [2, 5), then sort by B descending
//
// Plan shape, outermost first:
//   sort  (field 1 / B, DESC_NULLS_LAST)
//   fetch (offset 2, count 3)
//   sort  (field 0 / A, ASC_NULLS_FIRST)
//   read  (named table "table" with i32 columns A and B)
std::string substrait_json = R"({
"version": {
"major_number": 9999,
"minor_number": 9999,
"patch_number": 9999
},
"relations": [
{
"rel": {
"sort": {
"input": {
"fetch": {
"input": {
"sort": {
"input": {
"read": {
"base_schema": {
"names": [
"A",
"B"
],
"struct": {
"types": [
{
"i32": {}
},
{
"i32": {}
}
]
}
},
"namedTable": {
"names": [
"table"
]
}
}
},
"sorts": [
{
"expr": {
"selection": {
"directReference": {
"structField": {
"field": 0
}
},
"rootReference": {}
}
},
"direction": "SORT_DIRECTION_ASC_NULLS_FIRST"
}
]
}
},
"offset": 2,
"count": 3
}
},
"sorts": [
{
"expr": {
"selection": {
"directReference": {
"structField": {
"field": 1
}
},
"rootReference": {}
}
},
"direction": "SORT_DIRECTION_DESC_NULLS_LAST"
}
]
}
}
}
],
"extension_uris": [],
"extensions": []
})";

// Serialize the JSON plan into a Substrait "Plan" message buffer
ASSERT_OK_AND_ASSIGN(auto buf, internal::SubstraitFromJSON("Plan", substrait_json));
auto test_schema = schema({field("A", int32()), field("B", int32())});

// Seven input rows (indices 0-6 are referenced in the walkthrough below);
// three of them are entirely null
auto input_table = TableFromJSON(test_schema, {R"([
[null, null],
[5, 8],
[null, null],
[null, null],
[3, 4],
[9, 6],
[4, 5]
])"});

// First sort by A, ascending, nulls first to yield rows:
// 0, 2, 3, 4, 6, 1, 5
// Apply fetch to grab rows 3, 4, 6
// Then sort by B, descending, to yield rows 6, 4, 3

auto output_table = TableFromJSON(test_schema, {R"([
[4, 5],
[3, 4],
[null, null]
])"});

// NOTE(review): AlwaysProvideSameTable is defined outside this view; presumably
// it answers every named-table lookup with input_table - confirm
ConversionOptions conversion_options;
conversion_options.named_table_provider =
AlwaysProvideSameTable(std::move(input_table));

// NOTE(review): CheckRoundTripResult is defined outside this view; presumably it
// executes the plan in buf and compares the result to output_table - confirm
CheckRoundTripResult(std::move(output_table), buf, {}, conversion_options);
}

// Verifies that a sort relation whose keys disagree on null placement is
// rejected at deserialization time with a NotImplemented status.
TEST(Substrait, MixedSort) {
// Substrait allows two sort keys with differing null placement but Acero
// does not.  We should detect this and reject it.
//
// The plan sorts a named table by field 0 with ASC_NULLS_FIRST and by field 1
// with ASC_NULLS_LAST.
std::string substrait_json = R"({
"version": {
"major_number": 9999,
"minor_number": 9999,
"patch_number": 9999
},
"relations": [
{
"rel": {
"sort": {
"input": {
"read": {
"base_schema": {
"names": [
"A",
"B"
],
"struct": {
"types": [
{
"i32": {}
},
{
"i32": {}
}
]
}
},
"namedTable": {
"names": [
"table"
]
}
}
},
"sorts": [
{
"expr": {
"selection": {
"directReference": {
"structField": {
"field": 0
}
},
"rootReference": {}
}
},
"direction": "SORT_DIRECTION_ASC_NULLS_FIRST"
},
{
"expr": {
"selection": {
"directReference": {
"structField": {
"field": 1
}
},
"rootReference": {}
}
},
"direction": "SORT_DIRECTION_ASC_NULLS_LAST"
}
]
}
}
}
],
"extension_uris": [],
"extensions": []
})";

// Serialize the JSON plan into a Substrait "Plan" message buffer
ASSERT_OK_AND_ASSIGN(auto buf, internal::SubstraitFromJSON("Plan", substrait_json));
auto test_schema = schema({field("A", int32()), field("B", int32())});

// Backing data for the named table referenced by the plan
auto input_table = TableFromJSON(test_schema, {R"([
[null, null],
[5, 8],
[null, null],
[null, null],
[3, 4],
[9, 6],
[4, 5]
])"});

// Provider that ignores the requested names and schema and always returns a
// "table_source" declaration over input_table. input_table is captured by
// reference, so the provider must not outlive this test body.
NamedTableProvider table_provider = [&](const std::vector<std::string>& names,
const Schema&) {
std::shared_ptr<acero::ExecNodeOptions> options =
std::make_shared<acero::TableSourceNodeOptions>(input_table);
return acero::Declaration("table_source", {}, options, "mock_source");
};

ConversionOptions conversion_options;
conversion_options.named_table_provider = std::move(table_provider);

// The failure is expected from DeserializePlan itself; the plan is never executed
ASSERT_THAT(
DeserializePlan(*buf, /*registry=*/nullptr, /*ext_set_out=*/nullptr,
conversion_options),
Raises(StatusCode::NotImplemented, testing::HasSubstr("mixed null placement")));
}

TEST(Substrait, PlanWithExtension) {
// This demos an extension relation
std::string substrait_json = R"({
Expand Down

0 comments on commit fbe0d5f

Please sign in to comment.