Skip to content

Commit

Permalink
[C++][Acero] Fix key hashers for tables with differently ordered sche…
Browse files Browse the repository at this point in the history
…mas than the output
  • Loading branch information
JerAguilon committed Jan 25, 2024
1 parent 9c0a761 commit eaffb69
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 1 deletion.
2 changes: 1 addition & 1 deletion cpp/src/arrow/acero/asof_join_node.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1098,7 +1098,7 @@ class AsofJoinNode : public ExecNode {
auto inputs = this->inputs();
for (size_t i = 0; i < inputs.size(); i++) {
RETURN_NOT_OK(key_hashers_[i]->Init(plan()->query_context()->exec_context(),
output_schema()));
inputs[i]->output_schema()));
ARROW_ASSIGN_OR_RAISE(
auto input_state,
InputState::Make(i, tolerance_, must_hash_, may_rehash_, key_hashers_[i].get(),
Expand Down
61 changes: 61 additions & 0 deletions cpp/src/arrow/acero/asof_join_node_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
// under the License.

#include <gmock/gmock-matchers.h>
#include <iostream> // nocommit

#include <chrono>
#include <memory>
Expand Down Expand Up @@ -1582,6 +1583,66 @@ TEST(AsofJoinTest, BatchSequencing) {
return TestSequencing(MakeIntegerBatches, /*num_batches=*/32, /*batch_size=*/1);
}

template <typename BatchesMaker>
void TestSchemaResolution(BatchesMaker maker, int num_batches, int batch_size) {
auto l_schema =
schema({field("time", int32()), field("key", int32()), field("l_value", int32())});
auto r_schema =
schema({field("time", int32()), field("key", int32()), field("r0_value", int32())});

auto make_shift = [&maker, num_batches, batch_size](
const std::shared_ptr<Schema>& schema, int shift) {
return maker({[](int row) -> int64_t { return row; },
[num_batches](int row) -> int64_t { return row / num_batches; },
[shift](int row) -> int64_t { return row * 10 + shift; }},
schema, num_batches, batch_size);
};
ASSERT_OK_AND_ASSIGN(auto l_batches, make_shift(l_schema, 0));
ASSERT_OK_AND_ASSIGN(auto r_batches, make_shift(r_schema, 1));

Declaration l_src = {"source",
SourceNodeOptions(l_schema, l_batches.gen(false, false))};
Declaration r_src = {"source",
SourceNodeOptions(r_schema, r_batches.gen(false, false))};
Declaration l_project = {
"project",
{std::move(l_src)},
ProjectNodeOptions({compute::field_ref("time"),
compute::call("cast", {compute::field_ref("key")},
compute::CastOptions::Safe(utf8())),
compute::field_ref("l_value")},
{"time", "key", "l_value"})};
Declaration r_project = {
"project",
{std::move(r_src)},
ProjectNodeOptions({compute::call("cast", {compute::field_ref("key")},
compute::CastOptions::Safe(utf8())),
compute::field_ref("r0_value"), compute::field_ref("time")},
{"key", "r0_value", "time"})};

Declaration asofjoin = {
"asofjoin", {l_project, r_project}, GetRepeatedOptions(2, "time", {"key"}, 1000)};

QueryOptions query_options;
query_options.use_threads = false;
ASSERT_OK_AND_ASSIGN(auto table, DeclarationToTable(asofjoin, query_options));

Int32Builder expected_r0_b;
for (int i = 1; i <= 91; i += 10) {
ASSERT_OK(expected_r0_b.Append(i));
}
ASSERT_OK_AND_ASSIGN(auto expected_r0, expected_r0_b.Finish());

auto actual_r0 = table->GetColumnByName("r0_value");
std::vector<std::shared_ptr<arrow::Array>> chunks = {expected_r0};
auto expected_r0_chunked = std::make_shared<arrow::ChunkedArray>(chunks);
ASSERT_TRUE(actual_r0->Equals(expected_r0_chunked));
}

TEST(AsofJoinTest, OutputSchemaResolution) {
return TestSchemaResolution(MakeIntegerBatches, /*num_batches=*/1, /*batch_size=*/10);
}

namespace {

Result<AsyncGenerator<std::optional<ExecBatch>>> MakeIntegerBatchGenForTest(
Expand Down

0 comments on commit eaffb69

Please sign in to comment.