Skip to content

Commit

Permalink
Add substrait support for sort & fetch
Browse files Browse the repository at this point in the history
  • Loading branch information
westonpace committed Mar 21, 2023
1 parent 38bf313 commit d77339c
Show file tree
Hide file tree
Showing 2 changed files with 224 additions and 0 deletions.
102 changes: 102 additions & 0 deletions cpp/src/arrow/engine/substrait/relation_internal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,19 @@ Status DiscoverFilesFromDir(const std::shared_ptr<fs::LocalFileSystem>& local_fs
return Status::OK();
}

bool IsSortNullsFirst(const substrait::SortField::SortDirection& direction) {
return direction % 2 == 1;
}

compute::SortOrder SortOrderFromDirection(
const substrait::SortField::SortDirection& direction) {
if (direction < 3) {
return compute::SortOrder::Ascending;
} else {
return compute::SortOrder::Descending;
}
}

Result<DeclarationInfo> FromProto(const substrait::Rel& rel, const ExtensionSet& ext_set,
const ConversionOptions& conversion_options) {
static bool dataset_init = false;
Expand Down Expand Up @@ -683,6 +696,95 @@ Result<DeclarationInfo> FromProto(const substrait::Rel& rel, const ExtensionSet&
return ProcessEmit(std::move(join), std::move(join_declaration),
std::move(join_schema));
}
case substrait::Rel::RelTypeCase::kFetch: {
const auto& fetch = rel.fetch();
RETURN_NOT_OK(CheckRelCommon(fetch, conversion_options));

if (!fetch.has_input()) {
return Status::Invalid("substrait::FetchRel with no input relation");
}

ARROW_ASSIGN_OR_RAISE(auto input,
FromProto(fetch.input(), ext_set, conversion_options));

int64_t offset = fetch.offset();
int64_t count = fetch.count();

compute::Declaration fetch_dec{
"fetch", {input.declaration}, compute::FetchNodeOptions(offset, count)};

DeclarationInfo fetch_declaration{std::move(fetch_dec), input.output_schema};
return ProcessEmit(fetch, std::move(fetch_declaration),
fetch_declaration.output_schema);
}
case substrait::Rel::RelTypeCase::kSort: {
const auto& sort = rel.sort();
RETURN_NOT_OK(CheckRelCommon(sort, conversion_options));

if (!sort.has_input()) {
return Status::Invalid("substrait::SortRel with no input relation");
}

ARROW_ASSIGN_OR_RAISE(auto input,
FromProto(sort.input(), ext_set, conversion_options));

if (sort.sorts_size() == 0) {
return Status::Invalid("substrait::SortRel with no sorts");
}

std::vector<compute::SortKey> sort_keys;
compute::NullPlacement null_placement;
bool first = true;
for (const auto& sort : sort.sorts()) {
if (sort.direction() == substrait::SortField::SortDirection::
SortField_SortDirection_SORT_DIRECTION_UNSPECIFIED) {
return Status::Invalid(
"substrait::SortRel with sort that had unspecified direction");
}
if (sort.direction() == substrait::SortField::SortDirection::
SortField_SortDirection_SORT_DIRECTION_CLUSTERED) {
return Status::NotImplemented(
"substrait::SortRel with sort with clustered sort direction");
}
// Substrait allows null placement to differ for each field. Acero expects it to
// be consistent across all fields. So we grab the null placement from the first
// key and verify all other keys have the same null placement
if (first) {
null_placement = IsSortNullsFirst(sort.direction())
? compute::NullPlacement::AtStart
: compute::NullPlacement::AtEnd;
} else {
if ((null_placement == compute::NullPlacement::AtStart &&
!IsSortNullsFirst(sort.direction())) ||
(null_placement == compute::NullPlacement::AtEnd &&
IsSortNullsFirst(sort.direction()))) {
return Status::NotImplemented(
"substrait::SortRel with ordering with mixed null placement");
}
}
if (sort.sort_kind_case() != substrait::SortField::SortKindCase::kDirection) {
return Status::NotImplemented("substrait::SortRel with custom sort function");
}
ARROW_ASSIGN_OR_RAISE(compute::Expression expr,
FromProto(sort.expr(), ext_set, conversion_options));
const FieldRef* field_ref = expr.field_ref();
if (field_ref) {
sort_keys.push_back(
compute::SortKey(*field_ref, SortOrderFromDirection(sort.direction())));
} else {
return Status::Invalid("Sort key expressions must be a direct reference.");
}
}

compute::Declaration sort_dec{"order_by",
{input.declaration},
compute::OrderByNodeOptions(compute::Ordering(
std::move(sort_keys), null_placement))};

DeclarationInfo sort_declaration{std::move(sort_dec), input.output_schema};
return ProcessEmit(sort, std::move(sort_declaration),
sort_declaration.output_schema);
}
case substrait::Rel::RelTypeCase::kAggregate: {
const auto& aggregate = rel.aggregate();
RETURN_NOT_OK(CheckRelCommon(aggregate, conversion_options));
Expand Down
122 changes: 122 additions & 0 deletions cpp/src/arrow/engine/substrait/serde_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5136,6 +5136,128 @@ TEST(Substrait, CompoundEmitWithFilter) {
CheckRoundTripResult(std::move(expected_table), buf, {}, conversion_options);
}

TEST(Substrait, SortAndFetch) {
// Sort by A, ascending, take items [2, 5), then sort by B descending
std::string substrait_json = R"({
"version": {
"major_number": 9999,
"minor_number": 9999,
"patch_number": 9999
},
"relations": [
{
"rel": {
"sort": {
"input": {
"fetch": {
"input": {
"sort": {
"input": {
"read": {
"base_schema": {
"names": [
"A",
"B"
],
"struct": {
"types": [
{
"i32": {}
},
{
"i32": {}
}
]
}
},
"namedTable": {
"names": [
"table"
]
}
}
},
"sorts": [
{
"expr": {
"selection": {
"directReference": {
"structField": {
"field": 0
}
},
"rootReference": {}
}
},
"direction": "SORT_DIRECTION_ASC_NULLS_FIRST"
}
]
}
},
"offset": 2,
"count": 3
}
},
"sorts": [
{
"expr": {
"selection": {
"directReference": {
"structField": {
"field": 1
}
},
"rootReference": {}
}
},
"direction": "SORT_DIRECTION_DESC_NULLS_LAST"
}
]
}
}
}
],
"extension_uris": [],
"extensions": []
})";

ASSERT_OK_AND_ASSIGN(auto buf, internal::SubstraitFromJSON("Plan", substrait_json));
auto test_schema = schema({field("A", int32()), field("B", int32())});

auto input_table = TableFromJSON(test_schema, {R"([
[null, null],
[5, 8],
[null, null],
[null, null],
[3, 4],
[9, 6],
[4, 5]
])"});

// First sort by A, ascending, nulls first to yield rows:
// 0, 2, 3, 4, 6, 1, 5
// Apply fetch to grab rows 3, 4, 6
// Then sort by B, descending, to yield rows 6, 4, 3

auto output_table = TableFromJSON(test_schema, {R"([
[4, 5],
[3, 4],
[null, null]
])"});

NamedTableProvider table_provider = [&](const std::vector<std::string>& names,
const Schema&) {
std::shared_ptr<compute::ExecNodeOptions> options =
std::make_shared<compute::TableSourceNodeOptions>(input_table);
return compute::Declaration("table_source", {}, options, "mock_source");
};

ConversionOptions conversion_options;
conversion_options.named_table_provider = std::move(table_provider);

CheckRoundTripResult(std::move(output_table), buf, {}, conversion_options);
}

TEST(Substrait, PlanWithExtension) {
// This demos an extension relation
std::string substrait_json = R"({
Expand Down

0 comments on commit d77339c

Please sign in to comment.