Skip to content

Commit

Permalink
correct extract column methods (#6133)
Browse files Browse the repository at this point in the history
  • Loading branch information
ivanmorozov333 authored Jul 2, 2024
1 parent 48a2519 commit 6e81fbe
Show file tree
Hide file tree
Showing 21 changed files with 246 additions and 221 deletions.
141 changes: 3 additions & 138 deletions ydb/core/formats/arrow/arrow_helpers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -152,141 +152,6 @@ std::shared_ptr<arrow::RecordBatch> MakeEmptyBatch(const std::shared_ptr<arrow::
return arrow::RecordBatch::Make(schema, rowsCount, columns);
}

namespace {

template <class TStringType, class TDataContainer>
std::shared_ptr<TDataContainer> ExtractColumnsImpl(const std::shared_ptr<TDataContainer>& srcBatch,
const std::vector<TStringType>& columnNames) {
std::vector<std::shared_ptr<arrow::Field>> fields;
fields.reserve(columnNames.size());
std::vector<std::shared_ptr<typename NAdapter::TDataBuilderPolicy<TDataContainer>::TColumn>> columns;
columns.reserve(columnNames.size());

auto srcSchema = srcBatch->schema();
for (auto& name : columnNames) {
int pos = srcSchema->GetFieldIndex(name);
if (pos < 0) {
return {};
}
fields.push_back(srcSchema->field(pos));
columns.push_back(srcBatch->column(pos));
}

return NAdapter::TDataBuilderPolicy<TDataContainer>::Build(std::move(fields), std::move(columns), srcBatch->num_rows());
}
}

std::shared_ptr<arrow::RecordBatch> ExtractColumns(const std::shared_ptr<arrow::RecordBatch>& srcBatch,
const std::vector<TString>& columnNames) {
return ExtractColumnsImpl(srcBatch, columnNames);
}

std::shared_ptr<arrow::RecordBatch> ExtractColumns(const std::shared_ptr<arrow::RecordBatch>& srcBatch,
const std::vector<std::string>& columnNames) {
return ExtractColumnsImpl(srcBatch, columnNames);
}

std::shared_ptr<arrow::Table> ExtractColumns(const std::shared_ptr<arrow::Table>& srcBatch,
const std::vector<TString>& columnNames) {
return ExtractColumnsImpl(srcBatch, columnNames);
}

std::shared_ptr<arrow::Table> ExtractColumns(const std::shared_ptr<arrow::Table>& srcBatch,
const std::vector<std::string>& columnNames) {
return ExtractColumnsImpl(srcBatch, columnNames);
}

namespace {
template <class TDataContainer, class TStringImpl>
std::shared_ptr<TDataContainer> ExtractColumnsValidateImpl(const std::shared_ptr<TDataContainer>& srcBatch,
const std::vector<TStringImpl>& columnNames, const bool necessaryColumns) {
if (!srcBatch) {
return srcBatch;
}
if (columnNames.empty()) {
return nullptr;
}
std::vector<std::shared_ptr<arrow::Field>> fields;
fields.reserve(columnNames.size());
std::vector<std::shared_ptr<typename NAdapter::TDataBuilderPolicy<TDataContainer>::TColumn>> columns;
columns.reserve(columnNames.size());

auto srcSchema = srcBatch->schema();
for (auto& name : columnNames) {
const int pos = srcSchema->GetFieldIndex(name);
if (necessaryColumns) {
AFL_VERIFY(pos >= 0)("field_name", name)("names", JoinSeq(",", columnNames))("fields", JoinSeq(",", srcBatch->schema()->field_names()));
} else if (pos == -1) {
continue;
}
fields.push_back(srcSchema->field(pos));
columns.push_back(srcBatch->column(pos));
}

return NAdapter::TDataBuilderPolicy<TDataContainer>::Build(std::move(fields), std::move(columns), srcBatch->num_rows());
}
}

std::shared_ptr<arrow::RecordBatch> ExtractColumnsValidate(const std::shared_ptr<arrow::RecordBatch>& srcBatch,
const std::vector<TString>& columnNames) {
return ExtractColumnsValidateImpl(srcBatch, columnNames, true);
}

std::shared_ptr<arrow::Table> ExtractColumnsValidate(const std::shared_ptr<arrow::Table>& srcBatch,
const std::vector<TString>& columnNames) {
return ExtractColumnsValidateImpl(srcBatch, columnNames, true);
}

std::shared_ptr<arrow::RecordBatch> ExtractColumnsOptional(const std::shared_ptr<arrow::RecordBatch>& srcBatch,
const std::vector<TString>& columnNames) {
return ExtractColumnsValidateImpl(srcBatch, columnNames, false);
}

std::shared_ptr<arrow::Table> ExtractColumnsOptional(const std::shared_ptr<arrow::Table>& srcBatch,
const std::vector<TString>& columnNames) {
return ExtractColumnsValidateImpl(srcBatch, columnNames, false);
}

std::shared_ptr<arrow::RecordBatch> ExtractColumnsOptional(const std::shared_ptr<arrow::RecordBatch>& srcBatch,
const std::vector<std::string>& columnNames) {
return ExtractColumnsValidateImpl(srcBatch, columnNames, false);
}

std::shared_ptr<arrow::Table> ExtractColumnsOptional(const std::shared_ptr<arrow::Table>& srcBatch,
const std::vector<std::string>& columnNames) {
return ExtractColumnsValidateImpl(srcBatch, columnNames, false);
}

std::shared_ptr<arrow::RecordBatch> ExtractColumns(const std::shared_ptr<arrow::RecordBatch>& srcBatch,
const std::shared_ptr<arrow::Schema>& dstSchema) {
Y_ABORT_UNLESS(srcBatch);
Y_ABORT_UNLESS(dstSchema);
std::vector<std::shared_ptr<arrow::Array>> columns;
columns.reserve(dstSchema->num_fields());

for (auto& field : dstSchema->fields()) {
const int index = srcBatch->schema()->GetFieldIndex(field->name());
if (index == -1) {
AFL_ERROR(NKikimrServices::ARROW_HELPER)("event", "not_found_column")("column", field->name())
("column_type", field->type()->ToString())("columns", JoinSeq(",", srcBatch->schema()->field_names()));
return nullptr;
} else {
columns.push_back(srcBatch->column(index));
auto srcField = srcBatch->schema()->field(index);
if (!field->Equals(srcField)) {
AFL_ERROR(NKikimrServices::ARROW_HELPER)("event", "cannot_use_incoming_batch")("reason", "invalid_column_type")("column", field->name())
("column_type", field->ToString(true))("incoming_type", srcField->ToString(true));
return nullptr;
}
}

AFL_VERIFY(columns.back()->type()->Equals(field->type()))("event", "cannot_use_incoming_batch")("reason", "invalid_column_type")("column", field->name())
("column_type", field->type()->ToString())("incoming_type", columns.back()->type()->ToString());
}

return arrow::RecordBatch::Make(dstSchema, srcBatch->num_rows(), columns);
}

std::shared_ptr<arrow::RecordBatch> CombineBatches(const std::vector<std::shared_ptr<arrow::RecordBatch>>& batches) {
if (batches.empty()) {
return nullptr;
Expand Down Expand Up @@ -427,7 +292,7 @@ void DedupSortedBatch(const std::shared_ptr<arrow::RecordBatch>& batch,

Y_DEBUG_ABORT_UNLESS(NArrow::IsSorted(batch, sortingKey));

auto keyBatch = ExtractColumns(batch, sortingKey);
auto keyBatch = TColumnOperator().Adapt(batch, sortingKey).DetachResult();
auto& keyColumns = keyBatch->columns();

bool same = false;
Expand Down Expand Up @@ -487,7 +352,7 @@ static bool IsSelfSorted(const std::shared_ptr<arrow::RecordBatch>& batch) {

bool IsSorted(const std::shared_ptr<arrow::RecordBatch>& batch,
const std::shared_ptr<arrow::Schema>& sortingKey, bool desc) {
auto keyBatch = ExtractColumns(batch, sortingKey);
auto keyBatch = TColumnOperator().Adapt(batch, sortingKey).DetachResult();
if (desc) {
return IsSelfSorted<true, false>(keyBatch);
} else {
Expand All @@ -497,7 +362,7 @@ bool IsSorted(const std::shared_ptr<arrow::RecordBatch>& batch,

bool IsSortedAndUnique(const std::shared_ptr<arrow::RecordBatch>& batch,
const std::shared_ptr<arrow::Schema>& sortingKey, bool desc) {
auto keyBatch = ExtractColumns(batch, sortingKey);
auto keyBatch = TColumnOperator().Adapt(batch, sortingKey).DetachResult();
if (desc) {
return IsSelfSorted<true, true>(keyBatch);
} else {
Expand Down
28 changes: 1 addition & 27 deletions ydb/core/formats/arrow/arrow_helpers.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#pragma once
#include "switch_type.h"
#include "process_columns.h"
#include <ydb/core/formats/factory.h>
#include <ydb/core/scheme/scheme_tablecell.h>
#include <library/cpp/json/writer/json_value.h>
Expand Down Expand Up @@ -56,33 +57,6 @@ std::shared_ptr<arrow::RecordBatch> DeserializeBatch(const TString& blob,
std::shared_ptr<arrow::RecordBatch> MakeEmptyBatch(const std::shared_ptr<arrow::Schema>& schema, const ui32 rowsCount = 0);
std::shared_ptr<arrow::Table> ToTable(const std::shared_ptr<arrow::RecordBatch>& batch);

std::shared_ptr<arrow::RecordBatch> ExtractColumns(const std::shared_ptr<arrow::RecordBatch>& srcBatch,
const std::vector<TString>& columnNames);
std::shared_ptr<arrow::RecordBatch> ExtractColumns(const std::shared_ptr<arrow::RecordBatch>& srcBatch,
const std::vector<std::string>& columnNames);
std::shared_ptr<arrow::Table> ExtractColumns(const std::shared_ptr<arrow::Table>& srcBatch,
const std::vector<TString>& columnNames);
std::shared_ptr<arrow::Table> ExtractColumns(const std::shared_ptr<arrow::Table>& srcBatch,
const std::vector<std::string>& columnNames);
std::shared_ptr<arrow::Table> ExtractColumnsValidate(const std::shared_ptr<arrow::Table>& srcBatch,
const std::vector<TString>& columnNames);
std::shared_ptr<arrow::RecordBatch> ExtractColumnsValidate(const std::shared_ptr<arrow::RecordBatch>& srcBatch,
const std::vector<TString>& columnNames);

std::vector<TString> ConvertStrings(const std::vector<std::string>& input);
std::vector<std::string> ConvertStrings(const std::vector<TString>& input);

std::shared_ptr<arrow::Table> ExtractColumnsOptional(const std::shared_ptr<arrow::Table>& srcBatch,
const std::vector<TString>& columnNames);
std::shared_ptr<arrow::Table> ExtractColumnsOptional(const std::shared_ptr<arrow::Table>& srcBatch,
const std::vector<std::string>& columnNames);
std::shared_ptr<arrow::RecordBatch> ExtractColumnsOptional(const std::shared_ptr<arrow::RecordBatch>& srcBatch,
const std::vector<TString>& columnNames);
std::shared_ptr<arrow::RecordBatch> ExtractColumnsOptional(const std::shared_ptr<arrow::RecordBatch>& srcBatch,
const std::vector<std::string>& columnNames);
std::shared_ptr<arrow::RecordBatch> ExtractColumns(const std::shared_ptr<arrow::RecordBatch>& srcBatch,
const std::shared_ptr<arrow::Schema>& dstSchema);

std::shared_ptr<arrow::RecordBatch> ToBatch(const std::shared_ptr<arrow::Table>& combinedTable, const bool combine);
std::shared_ptr<arrow::RecordBatch> CombineBatches(const std::vector<std::shared_ptr<arrow::RecordBatch>>& batches);
std::shared_ptr<arrow::RecordBatch> MergeColumns(const std::vector<std::shared_ptr<arrow::RecordBatch>>& rb);
Expand Down
6 changes: 6 additions & 0 deletions ydb/core/formats/arrow/common/adapter.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ class TDataBuilderPolicy<arrow::RecordBatch> {
[[nodiscard]] static std::shared_ptr<arrow::RecordBatch> Build(std::vector<std::shared_ptr<arrow::Field>>&& fields, std::vector<std::shared_ptr<TColumn>>&& columns, const ui32 count) {
return arrow::RecordBatch::Make(std::make_shared<arrow::Schema>(std::move(fields)), count, std::move(columns));
}
[[nodiscard]] static std::shared_ptr<arrow::RecordBatch> Build(const std::shared_ptr<arrow::Schema>& schema, std::vector<std::shared_ptr<TColumn>>&& columns, const ui32 count) {
return arrow::RecordBatch::Make(schema, count, std::move(columns));
}
[[nodiscard]] static std::shared_ptr<arrow::RecordBatch> ApplyArrowFilter(const std::shared_ptr<arrow::RecordBatch>& batch, const std::shared_ptr<arrow::BooleanArray>& filter) {
auto res = arrow::compute::Filter(batch, filter);
Y_VERIFY_S(res.ok(), res.status().message());
Expand All @@ -54,6 +57,9 @@ class TDataBuilderPolicy<arrow::Table> {
[[nodiscard]] static std::shared_ptr<arrow::Table> Build(std::vector<std::shared_ptr<arrow::Field>>&& fields, std::vector<std::shared_ptr<TColumn>>&& columns, const ui32 count) {
return arrow::Table::Make(std::make_shared<arrow::Schema>(std::move(fields)), std::move(columns), count);
}
[[nodiscard]] static std::shared_ptr<arrow::Table> Build(const std::shared_ptr<arrow::Schema>& schema, std::vector<std::shared_ptr<TColumn>>&& columns, const ui32 count) {
return arrow::Table::Make(schema, std::move(columns), count);
}
[[nodiscard]] static std::shared_ptr<arrow::Table> AddColumn(const std::shared_ptr<arrow::Table>& batch, const std::shared_ptr<arrow::Field>& field, const std::shared_ptr<arrow::Array>& extCol) {
return TStatusValidator::GetValid(batch->AddColumn(batch->num_columns(), field, std::make_shared<arrow::ChunkedArray>(extCol)));
}
Expand Down
5 changes: 1 addition & 4 deletions ydb/core/formats/arrow/permutations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,7 @@ std::shared_ptr<arrow::UInt64Array> MakePermutation(const int size, const bool r
}

std::shared_ptr<arrow::UInt64Array> MakeSortPermutation(const std::shared_ptr<arrow::RecordBatch>& batch, const std::shared_ptr<arrow::Schema>& sortingKey, const bool andUnique) {
auto keyBatch = ExtractColumns(batch, sortingKey);
AFL_VERIFY(batch);
AFL_VERIFY(sortingKey);
AFL_VERIFY(!!keyBatch)("problem", "cannot_find_columns")("schema", batch->schema()->ToString())("columns", sortingKey->ToString());
auto keyBatch = TColumnOperator().VerifyIfAbsent().Adapt(batch, sortingKey).DetachResult();
auto keyColumns = std::make_shared<TArrayVec>(keyBatch->columns());
std::vector<TRawReplaceKey> points;
points.reserve(keyBatch->num_rows());
Expand Down
141 changes: 141 additions & 0 deletions ydb/core/formats/arrow/process_columns.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
#include "process_columns.h"
#include "common/adapter.h"

#include <util/string/join.h>

namespace NKikimr::NArrow {

namespace {
template <class TDataContainer, class TStringImpl>
std::shared_ptr<TDataContainer> ExtractColumnsValidateImpl(const std::shared_ptr<TDataContainer>& srcBatch,
const std::vector<TStringImpl>& columnNames) {
std::vector<std::shared_ptr<arrow::Field>> fields;
fields.reserve(columnNames.size());
std::vector<std::shared_ptr<typename NAdapter::TDataBuilderPolicy<TDataContainer>::TColumn>> columns;
columns.reserve(columnNames.size());

auto srcSchema = srcBatch->schema();
for (auto& name : columnNames) {
const int pos = srcSchema->GetFieldIndex(name);
if (Y_LIKELY(pos > -1)) {
fields.push_back(srcSchema->field(pos));
columns.push_back(srcBatch->column(pos));
}
}

return NAdapter::TDataBuilderPolicy<TDataContainer>::Build(std::move(fields), std::move(columns), srcBatch->num_rows());
}

template <class TDataContainer>
TConclusion<std::shared_ptr<TDataContainer>> AdaptColumnsImpl(const std::shared_ptr<TDataContainer>& srcBatch,
const std::shared_ptr<arrow::Schema>& dstSchema) {
AFL_VERIFY(srcBatch);
AFL_VERIFY(dstSchema);
std::vector<std::shared_ptr<typename NAdapter::TDataBuilderPolicy<TDataContainer>::TColumn>> columns;
columns.reserve(dstSchema->num_fields());

for (auto& field : dstSchema->fields()) {
const int index = srcBatch->schema()->GetFieldIndex(field->name());
if (index > -1) {
columns.push_back(srcBatch->column(index));
auto srcField = srcBatch->schema()->field(index);
if (field->Equals(srcField)) {
AFL_VERIFY(columns.back()->type()->Equals(field->type()))("event", "cannot_use_incoming_batch")("reason", "invalid_column_type")("column", field->name())
("column_type", field->type()->ToString())("incoming_type", columns.back()->type()->ToString());
} else {
AFL_ERROR(NKikimrServices::ARROW_HELPER)("event", "cannot_use_incoming_batch")("reason", "invalid_column_type")("column", field->name())
("column_type", field->ToString(true))("incoming_type", srcField->ToString(true));
return TConclusionStatus::Fail("incompatible column types");
}
} else {
AFL_ERROR(NKikimrServices::ARROW_HELPER)("event", "not_found_column")("column", field->name())
("column_type", field->type()->ToString())("columns", JoinSeq(",", srcBatch->schema()->field_names()));
return TConclusionStatus::Fail("not found column '" + field->name() + "'");
}
}

return NAdapter::TDataBuilderPolicy<TDataContainer>::Build(dstSchema, std::move(columns), srcBatch->num_rows());
}

template <class TDataContainer, class TStringType>
std::shared_ptr<TDataContainer> ExtractImpl(const TColumnOperator::EExtractProblemsPolicy& policy,
const std::shared_ptr<TDataContainer>& incoming, const std::vector<TStringType>& columnNames) {
AFL_VERIFY(incoming);
AFL_VERIFY(columnNames.size());
auto result = ExtractColumnsValidateImpl(incoming, columnNames);
switch (policy) {
case TColumnOperator::EExtractProblemsPolicy::Verify:
AFL_VERIFY((ui32)result->num_columns() == columnNames.size())("schema", incoming->schema()->ToString())("required", JoinSeq(",", columnNames));
break;
case TColumnOperator::EExtractProblemsPolicy::Null:
if ((ui32)result->num_columns() != columnNames.size()) {
return nullptr;
}
break;
case TColumnOperator::EExtractProblemsPolicy::Skip:
break;
}
return result;
}

template <class TDataContainer, class TStringType>
TConclusion<std::shared_ptr<TDataContainer>> ReorderImpl(const std::shared_ptr<TDataContainer>& incoming, const std::vector<TStringType>& columnNames) {
AFL_VERIFY(!!incoming);
AFL_VERIFY(columnNames.size());
if ((ui32)incoming->num_columns() < columnNames.size()) {
return TConclusionStatus::Fail("not enough columns for exact reordering");
}
if ((ui32)incoming->num_columns() > columnNames.size()) {
return TConclusionStatus::Fail("need extraction before reorder call");
}
auto result = ExtractColumnsValidateImpl(incoming, columnNames);
AFL_VERIFY(result);
if ((ui32)result->num_columns() != columnNames.size()) {
return TConclusionStatus::Fail("not enough fields for exact reordering");
}
return result;
}

}

std::shared_ptr<arrow::RecordBatch> TColumnOperator::Extract(const std::shared_ptr<arrow::RecordBatch>& incoming, const std::vector<std::string>& columnNames) {
return ExtractImpl(AbsentColumnPolicy, incoming, columnNames);
}

std::shared_ptr<arrow::Table> TColumnOperator::Extract(const std::shared_ptr<arrow::Table>& incoming, const std::vector<std::string>& columnNames) {
return ExtractImpl(AbsentColumnPolicy, incoming, columnNames);
}

std::shared_ptr<arrow::RecordBatch> TColumnOperator::Extract(const std::shared_ptr<arrow::RecordBatch>& incoming, const std::vector<TString>& columnNames) {
return ExtractImpl(AbsentColumnPolicy, incoming, columnNames);
}

std::shared_ptr<arrow::Table> TColumnOperator::Extract(const std::shared_ptr<arrow::Table>& incoming, const std::vector<TString>& columnNames) {
return ExtractImpl(AbsentColumnPolicy, incoming, columnNames);
}

NKikimr::TConclusion<std::shared_ptr<arrow::RecordBatch>> TColumnOperator::Adapt(const std::shared_ptr<arrow::RecordBatch>& incoming, const std::shared_ptr<arrow::Schema>& dstSchema) {
return AdaptColumnsImpl(incoming, dstSchema);
}

NKikimr::TConclusion<std::shared_ptr<arrow::Table>> TColumnOperator::Adapt(const std::shared_ptr<arrow::Table>& incoming, const std::shared_ptr<arrow::Schema>& dstSchema) {
return AdaptColumnsImpl(incoming, dstSchema);
}

NKikimr::TConclusion<std::shared_ptr<arrow::RecordBatch>> TColumnOperator::Reorder(const std::shared_ptr<arrow::RecordBatch>& incoming, const std::vector<std::string>& columnNames) {
return ReorderImpl(incoming, columnNames);
}

NKikimr::TConclusion<std::shared_ptr<arrow::Table>> TColumnOperator::Reorder(const std::shared_ptr<arrow::Table>& incoming, const std::vector<std::string>& columnNames) {
return ReorderImpl(incoming, columnNames);
}

NKikimr::TConclusion<std::shared_ptr<arrow::RecordBatch>> TColumnOperator::Reorder(const std::shared_ptr<arrow::RecordBatch>& incoming, const std::vector<TString>& columnNames) {
return ReorderImpl(incoming, columnNames);
}

NKikimr::TConclusion<std::shared_ptr<arrow::Table>> TColumnOperator::Reorder(const std::shared_ptr<arrow::Table>& incoming, const std::vector<TString>& columnNames) {
return ReorderImpl(incoming, columnNames);
}

}
Loading

0 comments on commit 6e81fbe

Please sign in to comment.