Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

YQ-2549 Checkpointing in match_recognize #1860

Merged
merged 23 commits into from
Mar 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
219 changes: 196 additions & 23 deletions ydb/library/yql/minikql/comp_nodes/mkql_match_recognize.cpp

Large diffs are not rendered by default.

46 changes: 44 additions & 2 deletions ydb/library/yql/minikql/comp_nodes/mkql_match_recognize_list.h
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
#pragma once

#include "mkql_match_recognize_save_load.h"

#include <ydb/library/yql/minikql/defs.h>
#include <ydb/library/yql/minikql/computation/mkql_computation_node_impl.h>
#include <ydb/library/yql/minikql/computation/mkql_computation_node_holders.h>
#include <ydb/library/yql/minikql/comp_nodes/mkql_saveload.h>
#include <ydb/library/yql/public/udf/udf_value.h>
#include <unordered_map>

Expand Down Expand Up @@ -131,15 +135,37 @@ class TSparseList {
}
}

// Serializes the container: first the number of stored entries, then for every
// entry its key, its value and its lock count. Entries are written in the
// iteration order of Storage.
void Save(TOutputSerializer& serializer) const {
    serializer(Storage.size());
    for (const auto& entry : Storage) {
        serializer(entry.first, entry.second.Value, entry.second.LockCount);
    }
}

// Restores the container from a stream written by Save(): reads the entry
// count, then that many (key, value, lock count) triples.
void Load(TInputSerializer& serializer) {
    const auto size = serializer.Read<TStorage::size_type>();
    // emplace() never overwrites an existing key, so start from an empty map;
    // otherwise entries left over from before the load would survive it.
    Storage.clear();
    Storage.reserve(size);
    for (TStorage::size_type i = 0; i < size; ++i) {
        TStorage::key_type key{};
        NUdf::TUnboxedValue row;
        decltype(TItem::LockCount) lockCount{};
        serializer(key, row, lockCount);
        // The local row is not used afterwards, so hand it over without a copy.
        Storage.emplace(key, TItem{std::move(row), lockCount});
    }
}

private:
//TODO consider to replace hash table with contiguous chunks
using TAllocator = TMKQLAllocator<std::pair<const size_t, TItem>, EMemorySubPool::Temporary>;
std::unordered_map<

using TStorage = std::unordered_map<
size_t,
TItem,
std::hash<size_t>,
std::equal_to<size_t>,
TAllocator> Storage;
TAllocator>;

TStorage Storage;
};
using TContainerPtr = TContainer::TPtr;

Expand Down Expand Up @@ -242,6 +268,14 @@ class TSparseList {
ToIndex = -1;
}

// Writes this range to the output serializer: the container pointer followed by
// the FromIndex and ToIndex bounds.
// NOTE(review): how the serializer encodes the Container pointer is not visible here — confirm.
void Save(TOutputSerializer& serializer) const {
serializer(Container, FromIndex, ToIndex);
}

// Counterpart of Save(): passes the same three members, in the same order
// (Container, FromIndex, ToIndex), to the input serializer.
// NOTE(review): presumably the serializer restores or re-links the Container pointer — confirm.
void Load(TInputSerializer& serializer) {
serializer(Container, FromIndex, ToIndex);
}

private:
TRange(TContainerPtr container, size_t index)
: Container(container)
Expand Down Expand Up @@ -297,6 +331,14 @@ class TSparseList {
return Size() == 0;
}

// Writes the list to the output serializer: the container pointer followed by ListSize.
// NOTE(review): how the serializer encodes the Container pointer is not visible here — confirm.
void Save(TOutputSerializer& serializer) const {
serializer(Container, ListSize);
}

// Counterpart of Save(): passes Container and ListSize, in the same order,
// to the input serializer.
void Load(TInputSerializer& serializer) {
serializer(Container, ListSize);
}

private:
TContainerPtr Container = MakeIntrusive<TContainer>();
size_t ListSize = 0; //impl: max index ever stored + 1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ namespace NKikimr::NMiniKQL::NMatchRecognize {

template<class R>
using TMatchedVar = std::vector<R, TMKQLAllocator<R>>;

template<class R>
void Extend(TMatchedVar<R>& var, const R& r) {
if (var.empty()) {
Expand Down Expand Up @@ -110,8 +111,7 @@ class TMatchedVarsValue : public TComputationValue<TMatchedVarsValue<R>> {
: TComputationValue<TMatchedVarsValue>(memInfo)
, HolderFactory(holderFactory)
, Vars(vars)
{
}
{}

NUdf::TUnboxedValue GetElement(ui32 index) const override {
return HolderFactory.Create<TRangeList>(HolderFactory, Vars[index]);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ class TRowForMeasureValue: public TComputationValue<TRowForMeasureValue>
, VarNames(varNames)
, MatchNumber(matchNumber)
{}

NUdf::TUnboxedValue GetElement(ui32 index) const override {
switch(ColumnOrder[index].first) {
case EMeasureInputDataSpecialColumns::Classifier: {
Expand Down
137 changes: 132 additions & 5 deletions ydb/library/yql/minikql/comp_nodes/mkql_match_recognize_nfa.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#pragma once

#include "mkql_match_recognize_matched_vars.h"
#include "mkql_match_recognize_save_load.h"
#include "../computation/mkql_computation_node_holders.h"
#include "../computation/mkql_computation_node_impl.h"
#include <ydb/library/yql/core/sql_types/match_recognize.h>
Expand All @@ -12,20 +13,38 @@ namespace NKikimr::NMiniKQL::NMatchRecognize {
using namespace NYql::NMatchRecognize;

// Empty transition alternative: it carries no data at all.
struct TVoidTransition {
    // Instances have no state, so any two of them always compare equal.
    friend bool operator==(const TVoidTransition& /*lhs*/, const TVoidTransition& /*rhs*/) {
        return true;
    }
};
using TEpsilonTransition = size_t; // to (destination index)
// A set of epsilon transitions leaving one node: a list of destination indices.
using TEpsilonTransitions = std::vector<TEpsilonTransition, TMKQLAllocator<TEpsilonTransition>>;
using TMatchedVarTransition = std::pair<std::pair<ui32, bool>, size_t>; // {{varIndex, saveState}, to}
using TQuantityEnterTransition = size_t; // to (destination index)
using TQuantityExitTransition = std::pair<std::pair<ui64, ui64>, std::pair<size_t, size_t>>; // {{min, max}, {toFindMore, toMatched}}
using TNfaTransition = std::variant<

// Groups a std::variant over Ts with the matching std::tuple type and provides
// a way to create a variant from a runtime alternative index.
template <typename... Ts>
struct TVariantHelper {
    using TVariant = std::variant<Ts...>;
    using TTuple = std::tuple<Ts...>;

    // Returns a variant holding a value-initialized instance of the i-th
    // alternative. MKQL_ENSURE rejects any index that is not below sizeof...(Ts).
    static std::variant<Ts...> getVariantByIndex(size_t i) {
        constexpr size_t alternativesCount = sizeof...(Ts);
        MKQL_ENSURE(i < alternativesCount, "Wrong variant index");
        // One prototype per alternative, built once on first use; callers get a copy.
        static const TVariant prototypes[] = { Ts{ }... };
        return prototypes[i];
    }
};

using TNfaTransitionHelper = TVariantHelper<
TVoidTransition,
TMatchedVarTransition,
TEpsilonTransitions,
TQuantityEnterTransition,
TQuantityExitTransition
>;

using TNfaTransition = TNfaTransitionHelper::TVariant;

struct TNfaTransitionDestinationVisitor {
std::function<size_t(size_t)> callback;

Expand Down Expand Up @@ -61,11 +80,42 @@ struct TNfaTransitionDestinationVisitor {
};

struct TNfaTransitionGraph {
std::vector<TNfaTransition, TMKQLAllocator<TNfaTransition>> Transitions;
using TTransitions = std::vector<TNfaTransition, TMKQLAllocator<TNfaTransition>>;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Граф переходов — это статическая конфигурация, которая строится по SQL-запросу. Его нет большого смысла сохранять. Нужно ли уметь проверять, что он не изменился при перезапуске?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Я не понял, что лучше делать:

  • граф переходов не сохранять и строить заново при восстановлении,
  • сохранять/восстанавливать и проверять что он не изменился при перезапуске (видимо сравнить восстановленный и заново построенный?).
    Или что-то другое имелось в виду?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Пожалуй, на данном этапе будет лучше сохранять и при восстановлении сравнивать, что он не отличается от заново построенного

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Добавил проверку: сохраняю TNfaTransitionGraph один раз (в TStateForInterleavedPartitions), при восстановлении проверяю, что он не изменился.


TTransitions Transitions;
size_t Input;
size_t Output;

using TPtr = std::shared_ptr<TNfaTransitionGraph>;

template<class>
inline constexpr static bool always_false_v = false;

// Serializes the graph: the number of transitions, then for each transition the
// index of its active variant alternative followed by its payload (written by
// visiting the variant with the serializer), and finally Input and Output.
void Save(TOutputSerializer& serializer) const {
    serializer(Transitions.size());
    for (const auto& transition : Transitions) {
        serializer.Write(transition.index());
        std::visit(serializer, transition);
    }
    serializer(Input, Output);
}

// Restores the graph from a stream written by Save(). For each transition the
// stored alternative index selects a default-constructed variant alternative,
// which is then filled by visiting it with the serializer. Input and Output
// are read last.
void Load(TInputSerializer& serializer) {
    const ui64 transitionCount = serializer.Read<TTransitions::size_type>();
    Transitions.resize(transitionCount);
    for (auto& transition : Transitions) {
        const size_t alternative = serializer.Read<std::size_t>();
        transition = TNfaTransitionHelper::getVariantByIndex(alternative);
        std::visit(serializer, transition);
    }
    serializer(Input, Output);
}

bool operator==(const TNfaTransitionGraph& other) {
return Transitions == other.Transitions
&& Input == other.Input
&& Output == other.Output;
}
};

class TNfaTransitionGraphOptimizer {
Expand All @@ -78,6 +128,7 @@ class TNfaTransitionGraphOptimizer {
EliminateSingleEpsilons();
CollectGarbage();
}

private:
void EliminateEpsilonChains() {
for (size_t node = 0; node != Graph->Transitions.size(); node++) {
Expand Down Expand Up @@ -250,14 +301,69 @@ class TNfaTransitionGraphBuilder {
class TNfa {
using TRange = TSparseList::TRange;
using TMatchedVars = TMatchedVars<TRange>;


struct TState {

TState() {}

// Builds a state at the given index with a copy of the matched variables.
// The quantifier stack is taken by rvalue reference and moved in; a named
// rvalue reference is an lvalue, so without std::move it would be copied.
TState(size_t index, const TMatchedVars& vars, std::stack<ui64, std::deque<ui64, TMKQLAllocator<ui64>>>&& quantifiers)
    : Index(index)
    , Vars(vars)
    , Quantifiers(std::move(quantifiers)) {}
const size_t Index;
size_t Index;
TMatchedVars Vars;
std::stack<ui64, std::deque<ui64, TMKQLAllocator<ui64>>> Quantifiers; //get rid of this

using TQuantifiersStdStack = std::stack<
ui64,
std::deque<ui64, TMKQLAllocator<ui64>>>; //get rid of this

// std::stack wrapper that exposes iteration over, and clearing of, the
// underlying container (the protected member `c`), so the stack can be saved
// and restored element by element.
struct TQuantifiersStack: public TQuantifiersStdStack {
    // Forwards any constructor arguments to the underlying std::stack.
    // The arguments are received by value, so they are moved on to the base
    // instead of being copied a second time.
    template<typename...TArgs>
    TQuantifiersStack(TArgs... args) : TQuantifiersStdStack(std::move(args)...) {}

    // Iteration follows the underlying container's order.
    auto begin() const { return c.begin(); }
    auto end() const { return c.end(); }
    // Removes all elements from the underlying container.
    auto clear() { return c.clear(); }
};

TQuantifiersStack Quantifiers;

// Serializes the state in this order: Index; the number of matched variables;
// for each variable its number of ranges followed by the ranges themselves;
// then the number of quantifier counters followed by each counter, in the
// iteration order of Quantifiers.
void Save(TOutputSerializer& serializer) const {
    serializer.Write(Index);

    serializer.Write(Vars.size());
    for (const auto& ranges : Vars) {
        serializer.Write(ranges.size());
        for (const auto& matchedRange : ranges) {
            matchedRange.Save(serializer);
        }
    }

    serializer.Write(Quantifiers.size());
    for (ui64 counter : Quantifiers) {
        serializer.Write(counter);
    }
}

// Restores the state from a stream written by Save(). Existing matched
// variables and quantifier counters are discarded before the stored ones are
// read back in the same order they were written.
void Load(TInputSerializer& serializer) {
    serializer.Read(Index);

    const auto varsCount = serializer.Read<TMatchedVars::size_type>();
    Vars.clear();
    Vars.resize(varsCount);
    for (auto& ranges : Vars) {
        const ui64 rangesCount = serializer.Read<ui64>();
        ranges.resize(rangesCount);
        for (auto& matchedRange : ranges) {
            matchedRange.Load(serializer);
        }
    }

    Quantifiers.clear();
    const auto countersCount = serializer.Read<ui64>();
    for (size_t n = 0; n < countersCount; ++n) {
        // Counters were written in container order, so pushing them in read
        // order rebuilds the stack in that same order.
        Quantifiers.push(serializer.Read<ui64>());
    }
}

friend inline bool operator<(const TState& lhs, const TState& rhs) {
return std::tie(lhs.Index, lhs.Quantifiers, lhs.Vars) < std::tie(rhs.Index, rhs.Quantifiers, rhs.Vars);
Expand All @@ -267,13 +373,14 @@ class TNfa {
}
};
public:

// Creates an automaton over the given transition graph.
// transitionGraph is a by-value shared pointer, so it is moved into the member
// to avoid an extra reference-count increment and decrement.
TNfa(TNfaTransitionGraph::TPtr transitionGraph, IComputationExternalNode* matchedRangesArg, const TComputationNodePtrVector& defines)
    : TransitionGraph(std::move(transitionGraph))
    , MatchedRangesArg(matchedRangesArg)
    , Defines(defines) {
}

void ProcessRow(TSparseList::TRange&& currentRowLock, TComputationContext& ctx) {
void ProcessRow(TSparseList::TRange&& currentRowLock, TComputationContext& ctx) {
ActiveStates.emplace(TransitionGraph->Input, TMatchedVars(Defines.size()), std::stack<ui64, std::deque<ui64, TMKQLAllocator<ui64>>>{});
MakeEpsilonTransitions();
std::set<TState, std::less<TState>, TMKQLAllocator<TState>> newStates;
Expand Down Expand Up @@ -329,6 +436,25 @@ class TNfa {
return ActiveStates.size();
}

// Serializes the automaton's dynamic state: the number of active states, each
// active state in turn, then EpsilonTransitionsLastRow.
// TransitionGraph is not saved/loaded, passed in constructor.
void Save(TOutputSerializer& serializer) const {
    serializer.Write(ActiveStates.size());
    for (const auto& activeState : ActiveStates) {
        activeState.Save(serializer);
    }
    serializer.Write(EpsilonTransitionsLastRow);
}

// Restores the automaton's dynamic state from a stream written by Save():
// the active states, then EpsilonTransitionsLastRow.
void Load(TInputSerializer& serializer) {
    // Start from an empty set so that the result is exactly the saved state,
    // even if this automaton already had active states.
    ActiveStates.clear();
    const auto stateSize = serializer.Read<ui64>();
    for (ui64 i = 0; i < stateSize; ++i) {
        TState state;
        state.Load(serializer);
        // The local is not used again; move it to avoid copying its vectors.
        ActiveStates.emplace(std::move(state));
    }
    serializer.Read(EpsilonTransitionsLastRow);
}

private:
//TODO (zverevgeny): Consider to change to std::vector for the sake of perf
using TStateSet = std::set<TState, std::less<TState>, TMKQLAllocator<TState>>;
Expand Down Expand Up @@ -376,6 +502,7 @@ class TNfa {
TStateSet& NewStates;
TStateSet& DeletedStates;
};

bool MakeEpsilonTransitionsImpl() {
TStateSet newStates;
TStateSet deletedStates;
Expand Down
Loading
Loading