diff --git a/ydb/core/kqp/opt/logical/kqp_opt_log.cpp b/ydb/core/kqp/opt/logical/kqp_opt_log.cpp index 4734288dcdac..fca569454432 100644 --- a/ydb/core/kqp/opt/logical/kqp_opt_log.cpp +++ b/ydb/core/kqp/opt/logical/kqp_opt_log.cpp @@ -101,7 +101,7 @@ class TKqpLogicalOptTransformer : public TOptimizeTransformerBase { TMaybeNode RewriteAggregate(TExprBase node, TExprContext& ctx) { TMaybeNode output; auto aggregate = node.Cast(); - auto hopSetting = GetSetting(aggregate.Settings().Ref(), "hopping"); + auto hopSetting = GetSetting(aggregate.Settings().Ref(), "hopping"); if (hopSetting) { auto input = aggregate.Input().Maybe(); if (!input) { @@ -116,7 +116,9 @@ class TKqpLogicalOptTransformer : public TOptimizeTransformerBase { true, // defaultWatermarksMode true); // syncActor } else { - output = DqRewriteAggregate(node, ctx, TypesCtx, false, KqpCtx.Config->HasOptEnableOlapPushdown() || KqpCtx.Config->HasOptUseFinalizeByKey(), KqpCtx.Config->HasOptUseFinalizeByKey()); + output = DqRewriteAggregate(node, ctx, TypesCtx, false, + KqpCtx.Config->HasOptEnableOlapPushdown() || KqpCtx.Config->HasOptUseFinalizeByKey(), + KqpCtx.Config->HasOptUseFinalizeByKey(), false /* allowSpilling */); } if (output) { DumpAppliedRule("RewriteAggregate", node.Ptr(), output.Cast().Ptr(), ctx); diff --git a/ydb/core/kqp/opt/physical/kqp_opt_phy.cpp b/ydb/core/kqp/opt/physical/kqp_opt_phy.cpp index a4cfee447263..ea22fbc53523 100644 --- a/ydb/core/kqp/opt/physical/kqp_opt_phy.cpp +++ b/ydb/core/kqp/opt/physical/kqp_opt_phy.cpp @@ -54,6 +54,7 @@ class TKqpPhysicalOptTransformer : public TOptimizeTransformerBase { AddHandler(0, &TCoCombineByKey::Match, HNDL(PushCombineToStage)); AddHandler(0, &TCoPartitionsByKeys::Match, HNDL(BuildPartitionsStage)); AddHandler(0, &TCoFinalizeByKey::Match, HNDL(BuildFinalizeByKeyStage)); + AddHandler(0, &TCoFinalizeByKeyWithSpilling::Match, HNDL(BuildFinalizeByKeyWithSpillingStage)); AddHandler(0, &TCoShuffleByKeys::Match, HNDL(BuildShuffleStage)); AddHandler(0, &TCoPartitionByKey::Match, HNDL(BuildPartitionStage)); AddHandler(0, &TCoTop::Match, HNDL(BuildTopStage)); @@ -101,6 +102,7 @@ class TKqpPhysicalOptTransformer : public TOptimizeTransformerBase { AddHandler(1, &TCoCombineByKey::Match, HNDL(PushCombineToStage)); AddHandler(1, &TCoPartitionsByKeys::Match, HNDL(BuildPartitionsStage)); AddHandler(1, &TCoFinalizeByKey::Match, HNDL(BuildFinalizeByKeyStage)); + AddHandler(1, &TCoFinalizeByKeyWithSpilling::Match, HNDL(BuildFinalizeByKeyWithSpillingStage)); AddHandler(1, &TCoShuffleByKeys::Match, HNDL(BuildShuffleStage)); AddHandler(1, &TCoPartitionByKey::Match, HNDL(BuildPartitionStage)); AddHandler(1, &TCoTop::Match, HNDL(BuildTopStage)); @@ -321,6 +323,13 @@ class TKqpPhysicalOptTransformer : public TOptimizeTransformerBase { return output; } + template + TMaybeNode BuildFinalizeByKeyWithSpillingStage(TExprBase node, TExprContext& ctx, const TGetParents& getParents) { + TExprBase output = DqBuildFinalizeByKeyWithSpillingStage(node, ctx, *getParents(), IsGlobal); + DumpAppliedRule("BuildFinalizeByKeyWithSpillingStage", node.Ptr(), output.Ptr(), ctx); + return output; + } + template TMaybeNode BuildPartitionsStage(TExprBase node, TExprContext& ctx, IOptimizationContext& optCtx, const TGetParents& getParents) diff --git a/ydb/library/yql/core/expr_nodes/yql_expr_nodes.json b/ydb/library/yql/core/expr_nodes/yql_expr_nodes.json index 2b9527a6dde9..19ffb5d24065 100644 --- a/ydb/library/yql/core/expr_nodes/yql_expr_nodes.json +++ b/ydb/library/yql/core/expr_nodes/yql_expr_nodes.json @@ -438,6 +438,19 @@ {"Index": 5, "Name": "FinishHandlerLambda", "Type": "TCoLambda"} ] }, + { + "Name": "TCoCombineByKeyWithSpilling", + "Base": "TCoInputBase", + "Match": {"Type": "Callable", "Name": "CombineByKeyWithSpilling"}, + "Children": [ + {"Index": 1, "Name": "PreMapLambda", "Type": "TCoLambda"}, + {"Index": 2, "Name": "KeySelectorLambda", "Type": "TCoLambda"}, + {"Index": 3, "Name": "InitHandlerLambda", "Type": "TCoLambda"}, + {"Index": 4, "Name": "UpdateHandlerLambda", "Type": "TCoLambda"}, + {"Index": 5, "Name": "FinishHandlerLambda", "Type": "TCoLambda"}, + {"Index": 6, "Name": "LoadHandlerLambda", "Type": "TCoLambda"} + ] + }, { "Name": "TCoFinalizeByKey", "Base": "TCoInputBase", @@ -450,6 +463,19 @@ {"Index": 5, "Name": "FinishHandlerLambda", "Type": "TCoLambda"} ] }, + { + "Name": "TCoFinalizeByKeyWithSpilling", + "Base": "TCoInputBase", + "Match": {"Type": "Callable", "Name": "FinalizeByKeyWithSpilling"}, + "Children": [ + {"Index": 1, "Name": "PreMapLambda", "Type": "TCoLambda"}, + {"Index": 2, "Name": "KeySelectorLambda", "Type": "TCoLambda"}, + {"Index": 3, "Name": "InitHandlerLambda", "Type": "TCoLambda"}, + {"Index": 4, "Name": "UpdateHandlerLambda", "Type": "TCoLambda"}, + {"Index": 5, "Name": "FinishHandlerLambda", "Type": "TCoLambda"}, + {"Index": 6, "Name": "SaveHandlerLambda", "Type": "TCoLambda"} + ] + }, { "Name": "TCoAggrCountInit", "Base": "TCallable", @@ -1625,7 +1651,7 @@ {"Index": 5, "Name": "LeftRenames", "Type": "TCoAtomList"}, {"Index": 6, "Name": "RightRenames", "Type": "TCoAtomList"}, {"Index": 7, "Name": "LeftKeysColumnNames", "Type": "TCoAtomList"}, - {"Index": 8, "Name": "RightKeysColumnNames", "Type": "TCoAtomList"}, + {"Index": 8, "Name": "RightKeysColumnNames", "Type": "TCoAtomList"}, {"Index": 9, "Name": "Flags", "Type": "TCoAtomList"} ] }, @@ -1641,7 +1667,7 @@ {"Index": 4, "Name": "LeftRenames", "Type": "TCoAtomList"}, {"Index": 5, "Name": "RightRenames", "Type": "TCoAtomList"}, {"Index": 6, "Name": "LeftKeysColumnNames", "Type": "TCoAtomList"}, - {"Index": 7, "Name": "RightKeysColumnNames", "Type": "TCoAtomList"}, + {"Index": 7, "Name": "RightKeysColumnNames", "Type": "TCoAtomList"}, {"Index": 8, "Name": "Flags", "Type": "TCoAtomList"} ] }, @@ -1977,6 +2003,19 @@ {"Index": 5, "Name": "MemLimit", "Type": "TCoAtom", "Optional": true} ] }, + { + "Name": "TCoCombineCoreWithSpilling", + "Base": "TCoInputBase", + "Match": {"Type": "Callable", "Name": "CombineCoreWithSpilling"}, + "Children": [ + {"Index": 1, "Name": "KeyExtractor", "Type": "TCoLambda"}, + {"Index": 2, "Name": "InitHandler", "Type": "TCoLambda"}, + {"Index": 3, "Name": "UpdateHandler", "Type": "TCoLambda"}, + {"Index": 4, "Name": "FinishHandler", "Type": "TCoLambda"}, + {"Index": 5, "Name": "LoadHandler", "Type": "TCoLambda"}, + {"Index": 6, "Name": "MemLimit", "Type": "TCoAtom", "Optional": true} + ] + }, { "Name": "TCoWideCombiner", "Base": "TCoInputBase", @@ -1989,6 +2028,20 @@ {"Index": 5, "Name": "FinishHandler", "Type": "TCoLambda"} ] }, + { + "Name": "TCoWideCombinerWithSpilling", + "Base": "TCoInputBase", + "Match": {"Type": "Callable", "Name": "WideCombinerWithSpilling"}, + "Children": [ + {"Index": 1, "Name": "MemLimit", "Type": "TCoAtom"}, + {"Index": 2, "Name": "KeyExtractor", "Type": "TCoLambda"}, + {"Index": 3, "Name": "InitHandler", "Type": "TCoLambda"}, + {"Index": 4, "Name": "UpdateHandler", "Type": "TCoLambda"}, + {"Index": 5, "Name": "FinishHandler", "Type": "TCoLambda"}, + {"Index": 6, "Name": "SerializeHandler", "Type": "TCoLambda"}, + {"Index": 7, "Name": "DeserializeHandler", "Type": "TCoLambda"} + ] + }, { "Name": "TCoGroupingCore", "Base": "TCoInputBase", diff --git a/ydb/library/yql/core/peephole_opt/yql_opt_peephole_physical.cpp b/ydb/library/yql/core/peephole_opt/yql_opt_peephole_physical.cpp index 611816bcee25..fc165cca6b18 100644 --- a/ydb/library/yql/core/peephole_opt/yql_opt_peephole_physical.cpp +++ b/ydb/library/yql/core/peephole_opt/yql_opt_peephole_physical.cpp @@ -3335,6 +3335,41 @@ TExprNode::TPtr ExpandCombineByKey(const TExprNode::TPtr& node, TExprContext& ct .Ptr(); } +TExprNode::TPtr ExpandCombineByKeyWithSpilling(const TExprNode::TPtr& node, TExprContext& ctx) { + const bool isStreamOrFlow = node->GetTypeAnn()->GetKind() == ETypeAnnotationKind::Stream || + node->GetTypeAnn()->GetKind() == ETypeAnnotationKind::Flow; + + if (!isStreamOrFlow) { + return node; + } + + YQL_CLOG(DEBUG, CorePeepHole) << "Expand " << node->Content() << " over stream or flow"; + + TCoCombineByKeyWithSpilling combine(node); + + return Build(ctx, node->Pos()) + .Input() + .Input(combine.Input()) + .Lambda() + .Args({"arg"}) + .Body() + .Apply(combine.PreMapLambda()) + .With(0, "arg") + .Build() + .Build() + .Build() + .KeyExtractor(combine.KeySelectorLambda()) + .InitHandler(combine.InitHandlerLambda()) + .UpdateHandler(combine.UpdateHandlerLambda()) + .FinishHandler(combine.FinishHandlerLambda()) + .LoadHandler(combine.LoadHandlerLambda()) + .MemLimit() + .Value("0") + .Build() + .Done() + .Ptr(); +} + template TExprNode::TPtr MakeWideMapJoinCore(const TExprNode& mapjoin, TExprNode::TPtr&& input, TExprContext& ctx) { const auto inStructType = GetSeqItemType(mapjoin.Head().GetTypeAnn())->Cast(); @@ -3453,6 +3488,60 @@ ui32 CollectStateNodes(const TExprNode& initLambda, const TExprNode& updateLambd return size; } +ui32 CollectStateNodesForSpilling(const TExprNode& initLambda, const TExprNode& updateLambda, const TExprNode& serdeLambda, TExprNode::TListType& fields, TExprNode::TListType& init, TExprNode::TListType& update, TExprNode::TListType& serdeState, TExprContext& ctx) { + YQL_ENSURE(IsSameAnnotation(*initLambda.Tail().GetTypeAnn(), *updateLambda.Tail().GetTypeAnn()), "Must be same type."); + + if (ETypeAnnotationKind::Struct != initLambda.Tail().GetTypeAnn()->GetKind()) { + fields.clear(); + init = TExprNode::TListType(1U, initLambda.TailPtr()); + update = TExprNode::TListType(1U, updateLambda.TailPtr()); + serdeState = TExprNode::TListType(1U, serdeLambda.TailPtr()); + return 1U; + } + + const auto structType = initLambda.Tail().GetTypeAnn()->Cast(); + const auto size = structType->GetSize(); + fields.reserve(size); + init.reserve(size); + update.reserve(size); + serdeState.reserve(size); + fields.clear(); + init.clear(); + update.clear(); + serdeState.clear(); + + if (initLambda.Tail().IsCallable("AsStruct")) + initLambda.Tail().ForEachChild([&](const TExprNode& child) { fields.emplace_back(child.HeadPtr()); }); + else if (updateLambda.Tail().IsCallable("AsStruct")) + updateLambda.Tail().ForEachChild([&](const TExprNode& child) { fields.emplace_back(child.HeadPtr()); }); + else + for (const auto& item : structType->GetItems()) + fields.emplace_back(ctx.NewAtom(initLambda.Tail().Pos(), item->GetName())); + + if (initLambda.Tail().IsCallable("AsStruct")) + initLambda.Tail().ForEachChild([&](const TExprNode& child) { init.emplace_back(child.TailPtr()); }); + else + std::transform(fields.cbegin(), fields.cend(), std::back_inserter(init), [&](const TExprNode::TPtr& name) { + return ctx.NewCallable(name->Pos(), "Member", {initLambda.TailPtr(), name}); + }); + + if (updateLambda.Tail().IsCallable("AsStruct")) + updateLambda.Tail().ForEachChild([&](const TExprNode& child) { update.emplace_back(child.TailPtr()); }); + else + std::transform(fields.cbegin(), fields.cend(), std::back_inserter(update), [&](const TExprNode::TPtr& name) { + return ctx.NewCallable(name->Pos(), "Member", {updateLambda.TailPtr(), name}); + }); + + if (serdeLambda.Tail().IsCallable("AsStruct")) + serdeLambda.Tail().ForEachChild([&](const TExprNode& child) { serdeState.emplace_back(child.TailPtr()); }); + else + std::transform(fields.cbegin(), fields.cend(), std::back_inserter(serdeState), [&](const TExprNode::TPtr& name) { + return ctx.NewCallable(name->Pos(), "Member", {serdeLambda.TailPtr(), name}); + }); + + return size; +} + TExprNode::TPtr ExpandFinalizeByKey(const TExprNode::TPtr& node, TExprContext& ctx) { if (node->GetTypeAnn()->GetKind() == ETypeAnnotationKind::List) { return ctx.NewCallable(node->Pos(), "Collect", @@ -3631,6 +3720,220 @@ TExprNode::TPtr ExpandFinalizeByKey(const TExprNode::TPtr& node, TExprContext& c .Seal().Build(); } +TExprNode::TPtr ExpandFinalizeByKeyWithSpilling(const TExprNode::TPtr& node, TExprContext& ctx) { + if (node->GetTypeAnn()->GetKind() == ETypeAnnotationKind::List) { + return ctx.NewCallable(node->Pos(), "Collect", + { ctx.ChangeChild(*node, 0, ctx.NewCallable(node->Pos(), "ToFlow", { node->HeadPtr() })) }); + } + if (node->GetTypeAnn()->GetKind() == ETypeAnnotationKind::Stream) { + return ctx.NewCallable(node->Pos(), "FromFlow", + { ctx.ChangeChild(*node, 0, ctx.NewCallable(node->Pos(), "ToFlow", { node->HeadPtr() })) }); + } + + YQL_CLOG(DEBUG, CorePeepHole) << "Expand " << node->Content() << " over stream or flow"; + + TCoFinalizeByKeyWithSpilling combine(node); + + const auto inputStructType = GetSeqItemType(combine.PreMapLambda().Body().Ref().GetTypeAnn())->Cast(); + const auto inputWidth = inputStructType->GetSize(); + + TExprNode::TListType inputFields; + inputFields.reserve(inputWidth); + for (const auto& item : inputStructType->GetItems()) { + inputFields.emplace_back(ctx.NewAtom(combine.PreMapLambda().Pos(), item->GetName())); + } + + TExprNode::TListType stateFields, init, update, outputFields, save; + const auto stateWidth = CollectStateNodesForSpilling(combine.InitHandlerLambda().Ref(), combine.UpdateHandlerLambda().Ref(), combine.SaveHandlerLambda().Ref(), stateFields, init, update, save, ctx); + + auto output = combine.FinishHandlerLambda().Body().Ptr(); + const auto outputStructType = GetSeqItemType(node->GetTypeAnn())->Cast(); + const ui32 outputWidth = outputStructType ? outputStructType->GetSize() : 1; + TExprNode::TListType finish; + finish.reserve(outputWidth); + if (output->IsCallable("AsStruct")) { + output->ForEachChild([&](const TExprNode& child) { + outputFields.emplace_back(child.HeadPtr()); + finish.emplace_back(child.TailPtr()); + }); + } else if (outputStructType) { + for (const auto& item : outputStructType->GetItems()) { + outputFields.emplace_back(ctx.NewAtom(output->Pos(), item->GetName())); + finish.emplace_back(ctx.NewCallable(output->Pos(), "Member", { output, outputFields.back() })); + } + } else { + finish.emplace_back(output); + } + + const auto uniteToStructure = [&](TExprNodeBuilder& parent) -> TExprNodeBuilder& { + for (ui32 i = 0U; i < inputWidth; ++i) { + parent + .List(i) + .Add(0, inputFields[i]) + .Arg(1, "items", i) + .Seal(); + } + return parent; + }; + + return ctx.Builder(node->Pos()) + .Callable("NarrowMap") + .Callable(0, "WideCombinerWithSpilling") + .Callable(0, "ExpandMap") + .Callable(0, "FlatMap") + .Add(0, combine.Input().Ptr()) + .Add(1, combine.PreMapLambda().Ptr()) + .Seal() + .Lambda(1) + .Param("item") + .Do([&](TExprNodeBuilder& parent) -> TExprNodeBuilder& { + for (ui32 i = 0U; i < inputWidth; ++i) { + parent.Callable(i, "Member") + .Arg(0, "item") + .Add(1, inputFields[i]) + .Seal(); + } + return parent; + }) + .Seal() + .Seal() + .Atom(1, "") + .Lambda(2) + .Params("items", inputWidth) + .Apply(combine.KeySelectorLambda().Ref()) + .With(0) + .Callable("AsStruct") + .Do(uniteToStructure) + .Seal() + .Done() + .Seal() + .Seal() + .Lambda(3) + .Param("key") + .Params("items", inputWidth) + .ApplyPartial(combine.InitHandlerLambda().Args().Ptr(), init) + .With(0, "key") + .With(1) + .Callable("AsStruct") + .Do(uniteToStructure) + .Seal() + .Done() + .Seal() + .Seal() + .Lambda(4) + .Param("key") + .Params("items", inputWidth) + .Params("state", stateWidth) + .ApplyPartial(combine.UpdateHandlerLambda().Args().Ptr(), std::move(update)) + .With(0) + .Arg("key") + .Done() + .With(1) + .Callable("AsStruct") + .Do(uniteToStructure) + .Seal() + .Done() + .With(2) + .Do([&](TExprNodeBuilder& parent) -> TExprNodeBuilder& { + if (stateFields.empty()) + parent.Arg("state", 0); + else { + auto str = parent.Callable("AsStruct"); + for (ui32 i = 0U; i < stateWidth; ++i) { + str.List(i) + .Add(0, stateFields[i]) + .Arg(1, "state", i) + .Seal(); + } + str.Seal(); + } + return parent; + }) + .Done() + .Seal() + .Seal() + .Lambda(5) + .Param("key") + .Params("state", stateWidth) + .ApplyPartial(combine.FinishHandlerLambda().Args().Ptr(), std::move(finish)) + .With(0, "key") + .With(1) + .Do([&](TExprNodeBuilder& parent) -> TExprNodeBuilder& { + if (stateFields.empty()) + parent.Arg("state", 0); + else { + auto str = parent.Callable("AsStruct"); + for (ui32 i = 0U; i < stateWidth; ++i) { + str.List(i) + .Add(0, stateFields[i]) + .Arg(1, "state", i) + .Seal(); + } + str.Seal(); + } + return parent; + }) + .Done() + .Seal() + .Seal() + .Lambda(6) + .Param("key") + .Params("state", stateWidth) + .ApplyPartial(combine.SaveHandlerLambda().Args().Ptr(), std::move(save)) + .With(0, "key") + .With(1) + .Do([&](TExprNodeBuilder& parent) -> TExprNodeBuilder& { + if (stateFields.empty()) + parent.Arg("state", 0); + else { + auto str = parent.Callable("AsStruct"); + for (ui32 i = 0U; i < stateWidth; ++i) { + str.List(i) + .Add(0, stateFields[i]) + .Arg(1, "state", i) + .Seal(); + } + str.Seal(); + } + return parent; + }) + .Done() + .Seal() + .Seal() + .Lambda(7) + .Param("key") + .Params("items", inputWidth) + .ApplyPartial(combine.InitHandlerLambda().Args().Ptr(), std::move(init)) + .With(0, "key") + .With(1) + .Callable("AsStruct") + .Do(uniteToStructure) + .Seal() + .Done() + .Seal() + .Seal() + .Seal() + .Lambda(1) + .Params("items", outputWidth) + .Do([&](TExprNodeBuilder& parent) -> TExprNodeBuilder& { + if (outputFields.empty()) + parent.Arg("items", 0); + else { + auto str = parent.Callable("AsStruct"); + for (ui32 i = 0U; i < outputWidth; ++i) { + str.List(i) + .Add(0, std::move(outputFields[i])) + .Arg(1, "items", i) + .Seal(); + } + str.Seal(); + } + return parent; + }) + .Seal() + .Seal().Build(); +} + // TODO: move in context using TLiteralStructIndexMap = std::unordered_map; using TLieralStructsCacheMap = std::unordered_map; @@ -4504,6 +4807,191 @@ TExprNode::TPtr OptimizeCombineCore(const TExprNode::TPtr& node, TExprContext& c return node; } +TExprNode::TPtr OptimizeCombineCoreWithSpilling(const TExprNode::TPtr& node, TExprContext& ctx) { + if (node->Head().IsCallable("NarrowMap") && node->Child(4U)->Tail().IsCallable("Just")) { + YQL_CLOG(DEBUG, CorePeepHole) << "Swap " << node->Content() << " with " << node->Head().Content(); + + const auto& output = node->Child(4U)->Tail().Head(); + const auto inputWidth = node->Head().Tail().Head().ChildrenSize(); + + const auto structType = ETypeAnnotationKind::Struct == output.GetTypeAnn()->GetKind() ? output.GetTypeAnn()->Cast() : nullptr; + const auto outputWidth = structType ? structType->GetSize() : 1U; + + TExprNode::TListType stateFields, outputFields, init, update, finish, load; + outputFields.reserve(outputWidth); + finish.reserve(outputWidth); + + const auto stateWidth = CollectStateNodesForSpilling(*node->Child(2U), *node->Child(3U), *node->Child(5U), stateFields, init, update, load, ctx); + + if (output.IsCallable("AsStruct")) { + node->Child(4U)->Tail().Head().ForEachChild([&](const TExprNode& child) { + outputFields.emplace_back(child.HeadPtr()); + finish.emplace_back(child.TailPtr()); + }); + } else if (structType) { + for (const auto& item : structType->GetItems()) { + outputFields.emplace_back(ctx.NewAtom(output.Pos(), item->GetName())); + finish.emplace_back(ctx.NewCallable(output.Pos(), "Member", {node->Child(4U)->Tail().HeadPtr(), outputFields.back()})); + } + } else { + finish.emplace_back(node->Child(4U)->Tail().HeadPtr()); + } + + auto limit = node->ChildrenSize() > TCoCombineCoreWithSpilling::idx_MemLimit ? node->TailPtr() : ctx.NewAtom(node->Pos(), ""); + return ctx.Builder(node->Pos()) + .Callable("NarrowMap") + .Callable(0, "WideCombinerWithSpilling") + .Add(0, node->Head().HeadPtr()) + .Add(1, std::move(limit)) + .Lambda(2) + .Params("items", inputWidth) + .Apply(*node->Child(1U)) + .With(0) + .Apply(node->Head().Tail()) + .With("items") + .Seal() + .Done() + .Seal() + .Seal() + .Lambda(3) + .Param("key") + .Params("items", inputWidth) + .ApplyPartial(node->Child(2U)->HeadPtr(), std::move(init)) + .With(0, "key") + .With(1) + .Apply(node->Head().Tail()) + .With("items") + .Seal() + .Done() + .Seal() + .Seal() + .Lambda(4) + .Param("key") + .Params("items", inputWidth) + .Params("state", stateWidth) + .ApplyPartial(node->Child(3U)->HeadPtr(), std::move(update)) + .With(0, "key") + .With(1) + .Apply(node->Head().Tail()) + .With("items") + .Seal() + .Done() + .With(2) + .Do([&](TExprNodeBuilder& parent) -> TExprNodeBuilder& { + if (stateFields.empty()) + parent.Arg("state", 0); + else { + auto str = parent.Callable("AsStruct"); + for (ui32 i = 0U; i < stateWidth; ++i) { + str.List(i) + .Add(0, stateFields[i]) + .Arg(1, "state", i) + .Seal(); + } + str.Seal(); + } + return parent; + }) + .Done() + .Seal() + .Seal() + .Lambda(5) + .Param("key") + .Params("state", stateWidth) + .ApplyPartial(node->Child(4U)->HeadPtr(), finish) + .With(0, "key") + .With(1) + .Do([&](TExprNodeBuilder& parent) -> TExprNodeBuilder& { + if (stateFields.empty()) + parent.Arg("state", 0); + else { + auto str = parent.Callable("AsStruct"); + for (ui32 i = 0U; i < stateWidth; ++i) { + str.List(i) + .Add(0, stateFields[i]) + .Arg(1, "state", i) + .Seal(); + } + str.Seal(); + } + return parent; + }) + .Done() + .Seal() + .Seal() + .Lambda(6) + .Param("key") + .Params("state", stateWidth) + .ApplyPartial(node->Child(4U)->HeadPtr(), std::move(finish)) + .With(0, "key") + .With(1) + .Do([&](TExprNodeBuilder& parent) -> TExprNodeBuilder& { + if (stateFields.empty()) + parent.Arg("state", 0); + else { + auto str = parent.Callable("AsStruct"); + for (ui32 i = 0U; i < stateWidth; ++i) { + str.List(i) + .Add(0, stateFields[i]) + .Arg(1, "state", i) + .Seal(); + } + str.Seal(); + } + return parent; + }) + .Done() + .Seal() + .Seal() + .Lambda(7) + .Param("key") + .Params("items", outputWidth) + .ApplyPartial(node->Child(5U)->HeadPtr(), std::move(load)) + .With(0, "key") + .With(1) + .Do([&](TExprNodeBuilder& parent) -> TExprNodeBuilder& { + if (outputFields.empty()) + parent.Arg("items", 0); + else { + auto str = parent.Callable("AsStruct"); + for (ui32 i = 0U; i < outputWidth; ++i) { + str.List(i) + .Add(0, outputFields[i]) + .Arg(1, "items", i) + .Seal(); + } + str.Seal(); + } + return parent; + }) + .Done() + .Seal() + .Seal() + .Seal() + .Lambda(1) + .Params("items", outputWidth) + .Do([&](TExprNodeBuilder& parent) -> TExprNodeBuilder& { + if (outputFields.empty()) + parent.Arg("items", 0); + else { + auto str = parent.Callable("AsStruct"); + for (ui32 i = 0U; i < outputWidth; ++i) { + str.List(i) + .Add(0, std::move(outputFields[i])) + .Arg(1, "items", i) + .Seal(); + } + str.Seal(); + } + return parent; + }) + .Seal() + .Seal().Build(); + } + + return node; +} + bool IsExpression(const TExprNode& root, const TExprNode& arg) { if (&root == &arg) return false; @@ -6976,7 +7464,7 @@ template TExprNode::TPtr AggrComparePg(const TExprNode& node, TExprContext& ctx) { YQL_CLOG(DEBUG, CorePeepHole) << "Expand '" << node.Content() << "' over Pg."; auto op = Asc ? ( Equals ? "<=" : "<") : ( Equals ? ">=" : ">"); - auto finalPart = (Equals == Asc) ? MakeBool(node.Pos(), ctx) : + auto finalPart = (Equals == Asc) ? MakeBool(node.Pos(), ctx) : ctx.NewCallable(node.Pos(), "Exists", { node.TailPtr() }); if (!Asc) { finalPart = ctx.NewCallable(node.Pos(), "Not", { finalPart }); @@ -8324,7 +8812,9 @@ struct TPeepHoleRules { {"And", &OptimizeLogicalDups}, {"Or", &OptimizeLogicalDups}, {"CombineByKey", &ExpandCombineByKey}, + {"CombineByKeyWithSpilling", &ExpandCombineByKeyWithSpilling}, {"FinalizeByKey", &ExpandFinalizeByKey}, + {"FinalizeByKeyWithSpilling", &ExpandFinalizeByKeyWithSpilling}, {"SkipNullMembers", &ExpandSkipNullFields}, {"SkipNullElements", &ExpandSkipNullFields}, {"ConstraintsOf", &ExpandConstraintsOf}, @@ -8421,6 +8911,7 @@ struct TPeepHoleRules { {"Member", &OptimizeMember}, {"Condense1", &OptimizeCondense1}, {"CombineCore", &OptimizeCombineCore}, + {"CombineCoreWithSpilling", &OptimizeCombineCoreWithSpilling}, {"Chopper", &OptimizeChopper}, {"WideCombiner", &OptimizeWideCombiner}, {"WideCondense1", &OptimizeWideCondense1}, @@ -8488,7 +8979,7 @@ struct TPeepHoleRules { const TExtPeepHoleOptimizerMap BlockStageExtFinalRules = { {"BlockExtend", &ExpandBlockExtend}, {"BlockOrderedExtend", &ExpandBlockExtend}, - {"ReplicateScalars", &ExpandReplicateScalars} + {"ReplicateScalars", &ExpandReplicateScalars} }; static const TPeepHoleRules& Instance() { diff --git a/ydb/library/yql/core/type_ann/type_ann_core.cpp b/ydb/library/yql/core/type_ann/type_ann_core.cpp index 09af8bc773af..f28bab044022 100644 --- a/ydb/library/yql/core/type_ann/type_ann_core.cpp +++ b/ydb/library/yql/core/type_ann/type_ann_core.cpp @@ -1232,7 +1232,7 @@ namespace NTypeAnnImpl { YQL_ENSURE(IsSameAnnotation(*memberType, *resultType) || IsSameAnnotation(*ctx.Expr.MakeType(memberType), *resultType)); - + output = ctx.Expr.Builder(input->Pos()) .Callable("IfStrict") .Callable(0, "Exists") @@ -2749,8 +2749,8 @@ namespace NTypeAnnImpl { if (!(*dataTypeOne == *dataTypeTwo)) { ctx.Expr.AddError(TIssue( ctx.Expr.GetPosition(input->Pos()), - TStringBuilder() << "Cannot calculate with different decimals: " - << static_cast(*dataType[0]) << " != " + TStringBuilder() << "Cannot calculate with different decimals: " + << static_cast(*dataType[0]) << " != " << static_cast(*dataType[1]) )); @@ -6193,18 +6193,18 @@ template }; auto mergeMembers = [&ctx, &buildJustMember, &input, &left, &right, &mergeLambda](const TStringBuf& name, bool hasLeft, bool hasRight) -> TExprNode::TPtr { - auto leftMaybe = hasLeft ? + auto leftMaybe = hasLeft ? buildJustMember(left, name) : ctx.Expr.NewCallable(input->Pos(), "Nothing", { ExpandType(input->Pos(), *ctx.Expr.MakeType(right->GetTypeAnn()->Cast()->FindItemType(name)), ctx.Expr) }); - - auto rightMaybe = hasRight ? + + auto rightMaybe = hasRight ? buildJustMember(right, name) : ctx.Expr.NewCallable(input->Pos(), "Nothing", { ExpandType(input->Pos(), *ctx.Expr.MakeType(left->GetTypeAnn()->Cast()->FindItemType(name)), ctx.Expr) }); - + return ctx.Expr.Builder(input->Pos()) .List() .Atom(0, name) @@ -6492,7 +6492,7 @@ template .With(1, result) .Seal() .Build(); - } + } } } else if (collection->GetTypeAnn()->GetKind() == ETypeAnnotationKind::Tuple) { for (size_t idx = 0; idx < collection->GetTypeAnn()->Cast()->GetSize(); idx++) { @@ -6523,7 +6523,7 @@ template return IGraphTransformer::TStatus::Error; } - if (result) { + if (result) { output = result; } else { output = ctx.Expr.Builder(input->Pos()).Callable("Null").Seal().Build(); @@ -12311,6 +12311,7 @@ template Functions["GraceJoinCore"] = &GraceJoinCoreWrapper; Functions["GraceSelfJoinCore"] = &GraceSelfJoinCoreWrapper; Functions["CombineCore"] = &CombineCoreWrapper; + Functions["CombineCoreWithSpilling"] = &CombineCoreWithSpillingWrapper; Functions["GroupingCore"] = &GroupingCoreWrapper; Functions["HoppingTraits"] = &HoppingTraitsWrapper; Functions["HoppingCore"] = &HoppingCoreWrapper; @@ -12333,7 +12334,9 @@ template Functions["CallableResultType"] = &TypeArgWrapper; Functions["CallableArgumentType"] = &TypeArgWrapper; Functions["CombineByKey"] = &CombineByKeyWrapper; + Functions["CombineByKeyWithSpilling"] = &CombineByKeyWrapper; Functions["FinalizeByKey"] = &CombineByKeyWrapper; + Functions["FinalizeByKeyWithSpilling"] = &CombineByKeyWrapper; Functions["NewMTRand"] = &NewMTRandWrapper; Functions["NextMTRand"] = &NextMTRandWrapper; Functions["FormatType"] = &FormatTypeWrapper; @@ -12562,6 +12565,7 @@ template Functions["WideSkipWhileInclusive"] = &WideWhileWrapper; Functions["WideCondense1"] = &WideCondense1Wrapper; Functions["WideCombiner"] = &WideCombinerWrapper; + Functions["WideCombinerWithSpilling"] = &WideCombinerWithSpillingWrapper; Functions["WideChopper"] = &WideChopperWrapper; Functions["WideChain1Map"] = &WideChain1MapWrapper; Functions["WideTop"] = &WideTopWrapper; diff --git a/ydb/library/yql/core/type_ann/type_ann_impl.h b/ydb/library/yql/core/type_ann/type_ann_impl.h index d249e83db33d..b754b1967407 100644 --- a/ydb/library/yql/core/type_ann/type_ann_impl.h +++ b/ydb/library/yql/core/type_ann/type_ann_impl.h @@ -33,6 +33,7 @@ namespace NTypeAnnImpl { IGraphTransformer::TStatus CommonJoinCoreWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx); IGraphTransformer::TStatus EquiJoinWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx); IGraphTransformer::TStatus CombineCoreWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx); + IGraphTransformer::TStatus CombineCoreWithSpillingWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx); IGraphTransformer::TStatus GroupingCoreWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx); TMaybe FindOrReportMissingMember(TStringBuf memberName, TPositionHandle pos, const TStructExprType& structType, TExprContext& ctx); diff --git a/ydb/library/yql/core/type_ann/type_ann_list.cpp b/ydb/library/yql/core/type_ann/type_ann_list.cpp index 710bef530808..8d1e18396c69 100644 --- a/ydb/library/yql/core/type_ann/type_ann_list.cpp +++ b/ydb/library/yql/core/type_ann/type_ann_list.cpp @@ -1432,10 +1432,10 @@ namespace { } IGraphTransformer::TStatus ListTopSortWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx) { - if (!EnsureMinMaxArgsCount(*input, 2, 3, ctx.Expr)) { + if (!EnsureMinMaxArgsCount(*input, 2, 3, ctx.Expr)) { return IGraphTransformer::TStatus::Error; } - + TStringBuf newName = input->Content(); newName.Skip(4); bool desc = false; @@ -1457,7 +1457,7 @@ namespace { .Seal() .Build(); } - + return OptListWrapperImpl<4U, 4U>(ctx.Expr.Builder(input->Pos()) .Callable(newName) .Add(0, input->ChildPtr(0)) @@ -4166,7 +4166,17 @@ namespace { IGraphTransformer::TStatus CombineByKeyWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx) { Y_UNUSED(output); - if (!EnsureArgsCount(*input, 6, ctx.Expr)) { + ui32 expectedArgsCount = 6; + bool isWithSpilling = false; + if (input->Content().find("WithSpilling") != std::string::npos) { + expectedArgsCount = 7; + isWithSpilling = true; + } + bool isFinalize = false; + if (input->Content().find("FinalizeByKey") != std::string::npos) { + isFinalize = true; + } + if (!EnsureArgsCount(*input, expectedArgsCount, ctx.Expr)) { return IGraphTransformer::TStatus::Error; } @@ -4181,6 +4191,9 @@ namespace { status = status.Combine(ConvertToLambda(input->ChildRef(3), ctx.Expr, 2)); status = status.Combine(ConvertToLambda(input->ChildRef(4), ctx.Expr, 3)); status = status.Combine(ConvertToLambda(input->ChildRef(5), ctx.Expr, 2)); + if (isWithSpilling) { + status = status.Combine(ConvertToLambda(input->ChildRef(6), ctx.Expr, 2)); + } if (status.Level != IGraphTransformer::TStatus::Ok) { return status; } @@ -4261,7 +4274,7 @@ namespace { } const TTypeAnnotationNode* retItemType = nullptr; - if (input->Content() == "FinalizeByKey") { + if (isFinalize) { retItemType = lambdaFinishHandler->GetTypeAnn(); } else { if (!EnsureNewSeqType(*lambdaFinishHandler, ctx.Expr, &retItemType)) { @@ -4269,6 +4282,27 @@ namespace { } } + if (isWithSpilling) { + auto& lambdaSerDeHandler = input->ChildRef(6); + if (isFinalize) { + if (!UpdateLambdaAllArgumentsTypes(lambdaSerDeHandler, {lambdaKeySelector->GetTypeAnn(), stateType}, ctx.Expr)) { + return IGraphTransformer::TStatus::Error; + } + } else { + if (!UpdateLambdaAllArgumentsTypes(lambdaSerDeHandler, {lambdaKeySelector->GetTypeAnn(), retItemType}, ctx.Expr)) { + return IGraphTransformer::TStatus::Error; + } + } + + if (!lambdaSerDeHandler->GetTypeAnn()) { + return IGraphTransformer::TStatus::Repeat; + } + + if (!EnsureComputableType(lambdaSerDeHandler->Pos(), *lambdaSerDeHandler->GetTypeAnn(), ctx.Expr)) { + return IGraphTransformer::TStatus::Error; + } + } + input->SetTypeAnn(MakeSequenceType(inputTypeKind, *retItemType, ctx.Expr)); return IGraphTransformer::TStatus::Ok; } @@ -6503,7 +6537,7 @@ namespace { return IGraphTransformer::TStatus::Repeat; } - const TTypeAnnotationNode* outputType = ctx.Expr.MakeType(input->IsCallable("PercentRank") ? + const TTypeAnnotationNode* outputType = ctx.Expr.MakeType(input->IsCallable("PercentRank") ? EDataSlot::Double : EDataSlot::Uint64); if (!isAnsi && keyType->GetKind() == ETypeAnnotationKind::Optional) { outputType = ctx.Expr.MakeType(outputType); @@ -7007,6 +7041,121 @@ namespace { return IGraphTransformer::TStatus::Ok; } + IGraphTransformer::TStatus CombineCoreWithSpillingWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx) { + Y_UNUSED(output); + if (!EnsureMinMaxArgsCount(*input, 6U, 7U, ctx.Expr)) { + return IGraphTransformer::TStatus::Error; + } + + const TTypeAnnotationNode* itemType = nullptr; + if (!EnsureNewSeqType(input->Head(), ctx.Expr, &itemType)) { + return IGraphTransformer::TStatus::Error; + } + + if (TCoCombineCoreWithSpilling::idx_MemLimit < input->ChildrenSize()) { + if (!EnsureAtom(*input->Child(TCoCombineCoreWithSpilling::idx_MemLimit), ctx.Expr)) { + return IGraphTransformer::TStatus::Error; + } + + ui64 memLimit = 0ULL; + if (!TryFromString(input->Child(TCoCombineCoreWithSpilling::idx_MemLimit)->Content(), memLimit)) { + ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(input->Child(TCoCombineCoreWithSpilling::idx_MemLimit)->Pos()), TStringBuilder() << + "Bad memLimit value: " << input->Child(TCoCombineCoreWithSpilling::idx_MemLimit)->Content())); + return IGraphTransformer::TStatus::Error; + } + } + + TExprNode::TPtr& keyExtractor = input->ChildRef(TCoCombineCoreWithSpilling::idx_KeyExtractor); + TExprNode::TPtr& initHandler = input->ChildRef(TCoCombineCoreWithSpilling::idx_InitHandler); + TExprNode::TPtr& updateHandler = input->ChildRef(TCoCombineCoreWithSpilling::idx_UpdateHandler); + TExprNode::TPtr& finishHandler = input->ChildRef(TCoCombineCoreWithSpilling::idx_FinishHandler); + TExprNode::TPtr& loadHandler = input->ChildRef(TCoCombineCoreWithSpilling::idx_LoadHandler); + + auto status = ConvertToLambda(keyExtractor, ctx.Expr, 1); + status = status.Combine(ConvertToLambda(initHandler, ctx.Expr, 2)); + status = status.Combine(ConvertToLambda(updateHandler, ctx.Expr, 3)); + status = status.Combine(ConvertToLambda(finishHandler, ctx.Expr, 2)); + status = status.Combine(ConvertToLambda(loadHandler, ctx.Expr, 2)); + if (status.Level != IGraphTransformer::TStatus::Ok) { + return status; + } + + // keyExtractor + if (!UpdateLambdaAllArgumentsTypes(keyExtractor, {itemType}, ctx.Expr)) { + return IGraphTransformer::TStatus::Error; + } + if (!keyExtractor->GetTypeAnn()) { + return IGraphTransformer::TStatus::Repeat; + } + auto keyType = keyExtractor->GetTypeAnn(); + if (!EnsureHashableKey(keyExtractor->Pos(), keyType, ctx.Expr)) { + return IGraphTransformer::TStatus::Error; + } + if (!EnsureEquatableKey(keyExtractor->Pos(), keyType, ctx.Expr)) { + return IGraphTransformer::TStatus::Error; + } + + // initHandler + if (!UpdateLambdaAllArgumentsTypes(initHandler, {keyType, itemType}, ctx.Expr)) { + return IGraphTransformer::TStatus::Error; + } + if (!initHandler->GetTypeAnn()) { + return IGraphTransformer::TStatus::Repeat; + } + auto stateType = initHandler->GetTypeAnn(); + if (!EnsureComputableType(initHandler->Pos(), *stateType, ctx.Expr)) { + return IGraphTransformer::TStatus::Error; + } + + // updateHandler + if (!UpdateLambdaAllArgumentsTypes(updateHandler, {keyType, itemType, stateType}, ctx.Expr)) { + return IGraphTransformer::TStatus::Error; + } + if (!updateHandler->GetTypeAnn()) { + return IGraphTransformer::TStatus::Repeat; + } + if (!IsSameAnnotation(*updateHandler->GetTypeAnn(), *stateType)) { + ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(updateHandler->Pos()), "Mismatch of update lambda return type and state type")); + return IGraphTransformer::TStatus::Error; + } + + // finishHandler + if (!UpdateLambdaAllArgumentsTypes(finishHandler, {keyType, stateType}, ctx.Expr)) { + return IGraphTransformer::TStatus::Error; + } + if (!finishHandler->GetTypeAnn()) { + return IGraphTransformer::TStatus::Repeat; + } + auto finishType = finishHandler->GetTypeAnn(); + if (!EnsureSeqOrOptionalType(finishHandler->Pos(), *finishType, ctx.Expr)) { + return IGraphTransformer::TStatus::Error; + } + + auto retKind = finishType->GetKind(); + const TTypeAnnotationNode* retItemType = nullptr; + if (retKind == ETypeAnnotationKind::List) { + retItemType = finishType->Cast()->GetItemType(); + } else if (retKind == ETypeAnnotationKind::Optional) { + retItemType = finishType->Cast()->GetItemType(); + } else { + retItemType = finishType->Cast()->GetItemType(); + } + + // loadHandler + if (!UpdateLambdaAllArgumentsTypes(loadHandler, {keyType, retItemType}, ctx.Expr)) { + return IGraphTransformer::TStatus::Error; + } + if (!loadHandler->GetTypeAnn()) { + return IGraphTransformer::TStatus::Repeat; + } + if (!EnsureComputableType(loadHandler->Pos(), *loadHandler->GetTypeAnn(), ctx.Expr)) { + return IGraphTransformer::TStatus::Error; + } + + input->SetTypeAnn(MakeSequenceType(input->Head().GetTypeAnn()->GetKind(), *retItemType, ctx.Expr)); + return IGraphTransformer::TStatus::Ok; + } + IGraphTransformer::TStatus GroupingCoreWrapper(const TExprNode::TPtr& input, TExprNode::TPtr&, TContext& ctx) { if (!EnsureMinMaxArgsCount(*input, 3, 4, ctx.Expr)) { return IGraphTransformer::TStatus::Error; diff --git a/ydb/library/yql/core/type_ann/type_ann_wide.cpp b/ydb/library/yql/core/type_ann/type_ann_wide.cpp index 6e6e71b0d2eb..bc722e843b1f 100644 --- a/ydb/library/yql/core/type_ann/type_ann_wide.cpp +++ b/ydb/library/yql/core/type_ann/type_ann_wide.cpp @@ -418,6 +418,166 @@ IGraphTransformer::TStatus WideCombinerWrapper(const TExprNode::TPtr& input, TEx return IGraphTransformer::TStatus::Ok; } +IGraphTransformer::TStatus WideCombinerWithSpillingWrapper(const TExprNode::TPtr& input, TExprNode::TPtr&, TContext& ctx) { + if (!EnsureArgsCount(*input, 8U, ctx.Expr)) { + return IGraphTransformer::TStatus::Error; + } + + if (!EnsureWideFlowType(input->Head(), ctx.Expr)) { + return IGraphTransformer::TStatus::Error; + } + + const auto multiType = input->Head().GetTypeAnn()->Cast()->GetItemType()->Cast(); + + if (!EnsureAtom(*input->Child(1U), ctx.Expr)) { + return IGraphTransformer::TStatus::Error; + } + + if (const auto& limit = input->Child(1U)->Content(); !limit.empty()) { + i64 memLimit = 0LL; + if (!TryFromString(limit, memLimit)) { + ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(input->Child(1U)->Pos()), TStringBuilder() << + "Bad memLimit value: " << limit)); + return IGraphTransformer::TStatus::Error; + } + } + + auto& keyExtractor = input->ChildRef(2U); + auto& initHandler = input->ChildRef(3U); + auto& updateHandler = input->ChildRef(4U); + auto& finishHandler = input->ChildRef(5U); + auto& serializeHandler = input->ChildRef(6U); + auto& deserializeHandler = input->ChildRef(7U); + + if (const auto status = ConvertToLambda(keyExtractor, ctx.Expr, multiType->GetSize()); status.Level != IGraphTransformer::TStatus::Ok) { + return status; + } + + if (!UpdateLambdaAllArgumentsTypes(keyExtractor, multiType->GetItems(), ctx.Expr)) { + return IGraphTransformer::TStatus::Error; + } + + if (!keyExtractor->GetTypeAnn()) { + return IGraphTransformer::TStatus::Repeat; + } + + TTypeAnnotationNode::TListType keyTypes; + keyTypes.reserve(keyExtractor->ChildrenSize() - 1U); + for (ui32 i = 1U; i < keyExtractor->ChildrenSize(); ++i) { + const auto child = keyExtractor->Child(i); + keyTypes.emplace_back(child->GetTypeAnn()); + if (!EnsureHashableKey(child->Pos(), keyTypes.back(), ctx.Expr)) { + return IGraphTransformer::TStatus::Error; + } + if (!EnsureEquatableKey(child->Pos(), keyTypes.back(), ctx.Expr)) { + return IGraphTransformer::TStatus::Error; + } + } + + auto argTypes = multiType->GetItems(); + argTypes.insert(argTypes.cbegin(), keyTypes.cbegin(), keyTypes.cend()); + + if (const auto status = ConvertToLambda(initHandler, ctx.Expr, argTypes.size()); status.Level != IGraphTransformer::TStatus::Ok) { + return status; + } + + if (!UpdateLambdaAllArgumentsTypes(initHandler, argTypes, ctx.Expr)) { + return IGraphTransformer::TStatus::Error; + } + + if (!initHandler->GetTypeAnn()) { + return IGraphTransformer::TStatus::Repeat; + } + + TTypeAnnotationNode::TListType stateTypes; + stateTypes.reserve(initHandler->ChildrenSize() - 1U); + for (ui32 i = 1U; i < initHandler->ChildrenSize(); ++i) { + const auto child = initHandler->Child(i); + stateTypes.emplace_back(child->GetTypeAnn()); + if (!EnsureComputableType(child->Pos(), *stateTypes.back(), ctx.Expr)) { + return IGraphTransformer::TStatus::Error; + } + } + + argTypes.insert(argTypes.cend(), stateTypes.cbegin(), stateTypes.cend()); + + if (const auto status = ConvertToLambda(updateHandler, ctx.Expr, argTypes.size()); status.Level != IGraphTransformer::TStatus::Ok) { + return status; + } + + if (!UpdateLambdaAllArgumentsTypes(updateHandler, argTypes, ctx.Expr)) { + return IGraphTransformer::TStatus::Error; + } + + if (!updateHandler->GetTypeAnn()) { + return IGraphTransformer::TStatus::Repeat; + } + + for (auto i = updateHandler->ChildrenSize() - 1U; i;) { + const auto child = updateHandler->Child(i); + if (!IsSameAnnotation(*stateTypes[--i], *child->GetTypeAnn())) { + ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(child->Pos()), TStringBuilder() << "State type changed in update from " + << *stateTypes[i] << " on " << *child->GetTypeAnn())); + return IGraphTransformer::TStatus::Error; + } + } + + argTypes.erase(argTypes.cbegin(), argTypes.cbegin() + multiType->GetSize()); + + if (const auto status = ConvertToLambda(finishHandler, ctx.Expr, argTypes.size()); status.Level != IGraphTransformer::TStatus::Ok) { + return status; + } + + if (!UpdateLambdaAllArgumentsTypes(finishHandler, argTypes, ctx.Expr)) { + return IGraphTransformer::TStatus::Error; + } + + if (!finishHandler->GetTypeAnn()) { + return IGraphTransformer::TStatus::Repeat; + } + + if (const auto status = ConvertToLambda(serializeHandler, ctx.Expr, argTypes.size()); status.Level != IGraphTransformer::TStatus::Ok) { + return status; + } + + if (!UpdateLambdaAllArgumentsTypes(serializeHandler, argTypes, ctx.Expr)) { + return IGraphTransformer::TStatus::Error; + } + + if (!serializeHandler->GetTypeAnn()) { + return IGraphTransformer::TStatus::Repeat; + } + + auto serializeRetType = GetWideLambdaOutputType(*serializeHandler, ctx.Expr); + TTypeAnnotationNode::TListType actualRetType; + for (auto type : serializeRetType->GetItems()) { + if (type->GetKind() == ETypeAnnotationKind::Optional) { + type = type->Cast()->GetItemType(); + } + actualRetType.push_back(type); + } + serializeRetType = ctx.Expr.MakeType(actualRetType); + + argTypes.erase(argTypes.cbegin() + keyTypes.size(), argTypes.cend()); + argTypes.insert(argTypes.cbegin() + keyTypes.size(), serializeRetType->GetItems().begin(), serializeRetType->GetItems().end()); + if (const auto status = ConvertToLambda(deserializeHandler, ctx.Expr, argTypes.size()); status.Level != IGraphTransformer::TStatus::Ok) { + return status; + } + + if (!UpdateLambdaAllArgumentsTypes(deserializeHandler, argTypes, ctx.Expr)) { + return IGraphTransformer::TStatus::Error; + } + + if (!deserializeHandler->GetTypeAnn()) { + return IGraphTransformer::TStatus::Repeat; + } + + auto retType = GetWideLambdaOutputType(*finishHandler, ctx.Expr); + input->SetTypeAnn(ctx.Expr.MakeType(retType)); + + return IGraphTransformer::TStatus::Ok; +} + IGraphTransformer::TStatus WideChopperWrapper(const TExprNode::TPtr& input, TExprNode::TPtr&, TContext& ctx) { if (!EnsureArgsCount(*input, 4U, ctx.Expr)) { return IGraphTransformer::TStatus::Error; diff --git a/ydb/library/yql/core/type_ann/type_ann_wide.h b/ydb/library/yql/core/type_ann/type_ann_wide.h index fba62091a9dc..83a6fdf48048 100644 --- a/ydb/library/yql/core/type_ann/type_ann_wide.h +++ b/ydb/library/yql/core/type_ann/type_ann_wide.h @@ -14,6 +14,7 @@ namespace NTypeAnnImpl { IGraphTransformer::TStatus WideFilterWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx); IGraphTransformer::TStatus WideWhileWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx); IGraphTransformer::TStatus WideCombinerWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx); + IGraphTransformer::TStatus WideCombinerWithSpillingWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx); IGraphTransformer::TStatus WideCondense1Wrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx); IGraphTransformer::TStatus WideChopperWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx); IGraphTransformer::TStatus WideChain1MapWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx); diff --git a/ydb/library/yql/core/yql_aggregate_expander.cpp b/ydb/library/yql/core/yql_aggregate_expander.cpp index e8ef9ca7e4ff..ad039f28cbc5 100644 --- a/ydb/library/yql/core/yql_aggregate_expander.cpp +++ b/ydb/library/yql/core/yql_aggregate_expander.cpp @@ -36,6 +36,9 @@ TExprNode::TPtr TAggregateExpander::ExpandAggregateWithFullOutput() HaveDistinct = AnyOf(AggregatedColumns->ChildrenList(), [](const auto& child) { return child->ChildrenSize() == 3; }); + + AllowSpilling = AllowSpilling && UseFinalizeByKeys && !HaveDistinct; + EffectiveCompact = (HaveDistinct && CompactForDistinct && !TypesCtx.IsBlockEngineEnabled()) || ForceCompact || HasSetting(*settings, "compact"); for (const auto& trait : Traits) { auto mergeLambda = trait->Child(5); @@ -605,7 +608,7 @@ TExprNode::TPtr TAggregateExpander::MakeInputBlocks(const TExprNode::TPtr& strea TVector roots; for (ui32 i = 1; i < argsCount + 1; ++i) { auto root = trait->Child(2)->ChildPtr(i); - allTypes.push_back(root->GetTypeAnn()); + allTypes.push_back(root->GetTypeAnn()); auto status = RemapExpr(root, root, remaps, Ctx, TOptimizeExprSettings(&TypesCtx)); @@ -944,16 +947,188 @@ TExprNode::TPtr TAggregateExpander::GeneratePartialAggregateForNonDistinct(const .Seal() .Build(); - return Ctx.Builder(Node->Pos()) - .Callable("CombineByKey") - .Add(0, AggList) - .Add(1, PreMap) - .Add(2, keyExtractor) - .Add(3, std::move(combineInit)) - .Add(4, std::move(combineUpdate)) - .Add(5, std::move(combineSave)) + ui32 index = 0U; + auto combineLoad = Ctx.Builder(Node->Pos()) + .Lambda() + .Param("key") + .Param("item") + .Callable("AsStruct") + .Do([&](TExprNodeBuilder& parent) -> TExprNodeBuilder& { + for (ui32 i = 0; i < KeyColumns->ChildrenSize(); ++i) { + parent + .List(index++) + .Add(0, KeyColumns->ChildPtr(i)) + .Callable(1, "Member") + .Arg(0, "item") + .Add(1, KeyColumns->ChildPtr(i)) + .Seal() + .Seal(); + } + if (SessionWindowParams.Update) { + parent + .List(index++) + .Atom(0, SessionStartMemberName) + .Callable(1, "Member") + .Arg(0, "item") + .Atom(1, SessionStartMemberName) + .Seal() + .Seal(); + } + return parent; + }) + .Do([&](TExprNodeBuilder& parent) -> TExprNodeBuilder& { + for (ui32 i = 0; i < columnNames.size(); ++i) { + auto child = AggregatedColumns->Child(i); + auto trait = Traits[i]; + if (!EffectiveCompact) { + auto loadLambda = trait->Child(4); + auto extractorLambda = GetFinalAggStateExtractor(i); + + if (!DistinctFields.empty() || Suffix == "MergeManyFinalize") { + parent.List(index++) + .Add(0, columnNames[i]) + .Callable(1, "Map") + .Apply(0, *extractorLambda) + .With(0, "item") + .Seal() + .Add(1, loadLambda) + .Seal() + .Seal(); + } else { + parent.List(index++) + .Add(0, columnNames[i]) + .Apply(1, *loadLambda) + .With(0) + .Apply(*extractorLambda) + .With(0, "item") + .Seal() + .Done() + .Seal(); + } + } else { + auto initLambda = trait->Child(1); + auto distinctField = (child->ChildrenSize() == 3) ? child->Child(2) : nullptr; + auto initApply = [&](TExprNodeBuilder& parent) -> TExprNodeBuilder& { + parent.Apply(1, *initLambda) + .With(0) + .Do([&](TExprNodeBuilder& parent) -> TExprNodeBuilder& { + if (distinctField) { + parent + .Callable("Member") + .Arg(0, "item") + .Add(1, distinctField) + .Seal(); + } else { + parent + .Callable("CastStruct") + .Arg(0, "item") + .Add(1, ExpandType(Node->Pos(), *initLambda->Head().Head().GetTypeAnn(), Ctx)) + .Seal(); + } + + return parent; + }) + .Done() + .Do([&](TExprNodeReplaceBuilder& parent) -> TExprNodeReplaceBuilder& { + if (initLambda->Head().ChildrenSize() == 2) { + parent.With(1) + .Callable("Uint32") + .Atom(0, ToString(i), TNodeFlags::Default) + .Seal() + .Done(); + } + + return parent; + }) + .Seal(); + + return parent; + }; + + if (distinctField) { + const bool isFirst = *Distinct2Columns[distinctField->Content()].begin() == i; + if (isFirst) { + parent.List(index++) + .Add(0, columnNames[i]) + .List(1) + .Callable(0, "NamedApply") + .Add(0, UdfSetCreate[distinctField->Content()]) + .List(1) + .Do([&](TExprNodeBuilder& parent) -> TExprNodeBuilder& { + if (!DistinctFieldNeedsPickle[distinctField->Content()]) { + parent.Callable(0, "Member") + .Arg(0, "item") + .Add(1, distinctField) + .Seal(); + } else { + parent.Callable(0, "StablePickle") + .Callable(0, "Member") + .Arg(0, "item") + .Add(1, distinctField) + .Seal() + .Seal(); + } + + return parent; + }) + .Callable(1, "Uint32") + .Atom(0, "0", TNodeFlags::Default) + .Seal() + .Seal() + .Callable(2, "AsStruct").Seal() + .Callable(3, "DependsOn") + .Callable(0, "String") + .Add(0, distinctField) + .Seal() + .Seal() + .Seal() + .Do(initApply) + .Seal() + .Seal(); + } else { + parent.List(index++) + .Add(0, columnNames[i]) + .Do(initApply) + .Seal(); + } + } else { + parent.List(index++) + .Add(0, columnNames[i]) + .Do(initApply) + .Seal(); + } + } + } + return parent; + }) + .Seal() .Seal() .Build(); + + if (AllowSpilling) { + return Ctx.Builder(Node->Pos()) + .Callable("CombineByKeyWithSpilling") + .Add(0, AggList) + .Add(1, PreMap) + .Add(2, keyExtractor) + .Add(3, std::move(combineInit)) + .Add(4, std::move(combineUpdate)) + .Add(5, std::move(combineSave)) + .Add(6, std::move(combineLoad)) + .Seal() + .Build(); + } else { + return Ctx.Builder(Node->Pos()) + .Callable("CombineByKey") + .Add(0, AggList) + .Add(1, PreMap) + .Add(2, keyExtractor) + .Add(3, std::move(combineInit)) + .Add(4, std::move(combineUpdate)) + .Add(5, std::move(combineSave)) + .Seal() + .Build(); + } } void TAggregateExpander::GenerateInitForDistinct(TExprNodeBuilder& parent, ui32& ndx, const TIdxSet& indicies, const TExprNode::TPtr& distinctField) { @@ -1386,45 +1561,84 @@ TExprNode::TPtr TAggregateExpander::ReturnKeyAsIsForCombineInit(const TExprNode: } TExprNode::TPtr TAggregateExpander::BuildFinalizeByKeyLambda(const TExprNode::TPtr& preprocessLambda, const TExprNode::TPtr& keyExtractor) { - return Ctx.Builder(Node->Pos()) - .Lambda() - .Param("stream") - .Callable("FinalizeByKey") - .Arg(0, "stream") - .Lambda(1) - .Param("item") + auto preprocess = Ctx.Builder(Node->Pos()) + .Lambda() + .Param("item") .Callable("Just") .Apply(0, preprocessLambda) .With(0, "item") .Seal() .Seal() .Seal() - .Add(2, keyExtractor) - .Lambda(3) - .Param("key") - .Param("item") + .Build(); + auto initLambda = Ctx.Builder(Node->Pos()) + .Lambda() + .Param("key") + .Param("item") .Apply(GeneratePostAggregateInitPhase()) .With(0, "item") .Seal() .Seal() - .Lambda(4) - .Param("key") - .Param("item") - .Param("state") + .Build(); + auto mergeLambda = Ctx.Builder(Node->Pos()) + .Lambda() + .Param("key") + .Param("item") + .Param("state") .Apply(GeneratePostAggregateMergePhase()) .With(0, "item") .With(1, "state") .Seal() .Seal() - .Lambda(5) - .Param("key") - .Param("state") - .Apply(GeneratePostAggregateSavePhase()) + .Build(); + auto finishLambda = Ctx.Builder(Node->Pos()) + .Lambda() + .Param("key") + .Param("state") + .Apply(GeneratePostAggregateFinishPhase()) .With(0, "state") .Seal() .Seal() - .Seal() - .Seal().Build(); + .Build(); + if (AllowSpilling) { + auto saveLambda = Ctx.Builder(Node->Pos()) + .Lambda() + .Param("key") + .Param("state") + .Apply(GeneratePostAggregateSerializePhase()) + .With(0, "key") + .With(1, "state") + .Seal() + .Seal() + .Build(); + + return Ctx.Builder(Node->Pos()) + .Lambda() + .Param("stream") + .Callable("FinalizeByKeyWithSpilling") + .Arg(0, "stream") + .Add(1, preprocess) + .Add(2, keyExtractor) + .Add(3, initLambda) + .Add(4, mergeLambda) + .Add(5, finishLambda) + .Add(6, saveLambda) + .Seal() + .Seal().Build(); + } else { + return Ctx.Builder(Node->Pos()) + .Lambda() + .Param("stream") + .Callable("FinalizeByKey") + .Arg(0, "stream") + .Add(1, preprocess) + .Add(2, keyExtractor) + .Add(3, initLambda) + .Add(4, mergeLambda) + .Add(5, finishLambda) + .Seal() + .Seal().Build(); + } } @@ -1772,7 +1986,7 @@ TExprNode::TPtr TAggregateExpander::GeneratePostAggregate(const TExprNode::TPtr& .Add(2, condenseSwitch) .Add(3, GeneratePostAggregateMergePhase()) .Seal() - .Add(1, GeneratePostAggregateSavePhase()) + .Add(1, GeneratePostAggregateFinishPhase()) .Seal() .Done() .Seal() @@ -2020,7 +2234,7 @@ TExprNode::TPtr TAggregateExpander::GeneratePostAggregateInitPhase() .Build(); } -TExprNode::TPtr TAggregateExpander::GeneratePostAggregateSavePhase() +TExprNode::TPtr TAggregateExpander::GeneratePostAggregateFinishPhase() { bool aggregateOnly = (Suffix != ""); const auto& columnNames = aggregateOnly ? FinalColumnNames : InitialColumnNames; @@ -2407,6 +2621,105 @@ TExprNode::TPtr TAggregateExpander::GeneratePostAggregateMergePhase() .Build(); } +TExprNode::TPtr TAggregateExpander::GeneratePostAggregateSerializePhase() { + auto keyItemTypes = GetKeyItemTypes(); + TExprNode::TPtr pickleTypeNode = nullptr; + if (IsNeedPickle(keyItemTypes)) { + const TTypeAnnotationNode* pickleType = nullptr; + pickleType = KeyColumns->ChildrenSize() > 1 ? Ctx.MakeType(keyItemTypes) : keyItemTypes[0]; + pickleTypeNode = ExpandType(Node->Pos(), *pickleType, Ctx); + } + auto columnNames = FinalColumnNames; + return Ctx.Builder(Node->Pos()) + .Lambda() + .Param("key") + .Param("state") + .Callable("Just") + .Callable(0, "AsStruct") + .Do([&](TExprNodeBuilder& parent) -> TExprNodeBuilder& { + for (ui32 i = 0; i < columnNames.size(); ++i) { + if (NonDistinctColumns.find(i) == NonDistinctColumns.end()) { + parent.List(i) + .Add(0, columnNames[i]) + .Add(1, NothingStates[i]) + .Seal(); + } else { + auto trait = Traits[i]; + auto saveLambda = trait->Child(3); + if (!DistinctFields.empty()) { + parent.List(i) + .Add(0, columnNames[i]) + .Callable(1, "Just") + .Apply(0, *saveLambda) + .With(0) + .Callable("Member") + .Arg(0, "state") + .Add(1, columnNames[i]) + .Seal() + .Done() + .Seal() + .Seal() + .Seal(); + } else { + parent.List(i) + .Add(0, columnNames[i]) + .Apply(1, *saveLambda) + .With(0) + .Callable("Member") + .Arg(0, "state") + .Add(1, columnNames[i]) + .Seal() + .Done() + .Seal() + .Seal(); + } + } + } + return parent; + }) + .Do([&](TExprNodeBuilder& parent) -> TExprNodeBuilder& { + ui32 pos = 0; + for (ui32 i = 0; i < KeyColumns->ChildrenSize(); ++i) { + auto listBuilder = parent.List(columnNames.size() + i); + listBuilder.Add(0, KeyColumns->ChildPtr(i)); + if (KeyColumns->ChildrenSize() > 1) { + if (pickleTypeNode) { + listBuilder + .Callable(1, "Nth") + .Callable(0, "Unpickle") + .Add(0, pickleTypeNode) + .Arg(1, "key") + .Seal() + .Atom(1, ToString(pos), TNodeFlags::Default) + .Seal(); + } else { + listBuilder + .Callable(1, "Nth") + .Arg(0, "key") + .Atom(1, ToString(pos), TNodeFlags::Default) + .Seal(); + } + ++pos; + } else { + if (pickleTypeNode) { + listBuilder.Callable(1, "Unpickle") + .Add(0, pickleTypeNode) + .Arg(1, "key") + .Seal(); + } else { + listBuilder.Arg(1, "key"); + } + } + listBuilder.Seal(); + } + return parent; + }) + .Seal() + .Seal() + .Seal() + .Build(); +} + TExprNode::TPtr TAggregateExpander::GenerateJustOverStates(const TExprNode::TPtr& input, const TIdxSet& indicies) { return Ctx.Builder(Node->Pos()) .Callable("Map") diff --git a/ydb/library/yql/core/yql_aggregate_expander.h b/ydb/library/yql/core/yql_aggregate_expander.h index 5333db934229..90cffa962857 100644 --- a/ydb/library/yql/core/yql_aggregate_expander.h +++ b/ydb/library/yql/core/yql_aggregate_expander.h @@ -9,7 +9,7 @@ namespace NYql { class TAggregateExpander { public: TAggregateExpander(bool usePartitionsByKeys, const bool useFinalizeByKeys, const TExprNode::TPtr& node, TExprContext& ctx, TTypeAnnotationContext& typesCtx, - bool forceCompact = false, bool compactForDistinct = false, bool usePhases = false) + bool forceCompact = false, bool compactForDistinct = false, bool usePhases = false, bool allowSpilling = false) : Node(node) , Ctx(ctx) , TypesCtx(typesCtx) @@ -18,6 +18,7 @@ class TAggregateExpander { , ForceCompact(forceCompact) , CompactForDistinct(compactForDistinct) , UsePhases(usePhases) + , AllowSpilling(allowSpilling) , AggregatedColumns(nullptr) , VoidNode(ctx.NewCallable(node->Pos(), "Void", {})) , HaveDistinct(false) @@ -70,8 +71,9 @@ class TAggregateExpander { TExprNode::TPtr GenerateCondenseSwitch(const TExprNode::TPtr& keyExtractor); TExprNode::TPtr BuildFinalizeByKeyLambda(const TExprNode::TPtr& preprocessLambda, const TExprNode::TPtr& keyExtractor); TExprNode::TPtr GeneratePostAggregateInitPhase(); - TExprNode::TPtr GeneratePostAggregateSavePhase(); + TExprNode::TPtr GeneratePostAggregateFinishPhase(); TExprNode::TPtr GeneratePostAggregateMergePhase(); + TExprNode::TPtr GeneratePostAggregateSerializePhase(); std::function GetPartialAggArgExtractor(ui32 i, bool deserialize); TExprNode::TPtr GetFinalAggStateExtractor(ui32 i); @@ -98,6 +100,7 @@ class TAggregateExpander { bool ForceCompact; bool CompactForDistinct; bool UsePhases; + bool AllowSpilling; TStringBuf Suffix; TSessionWindowParams SessionWindowParams; diff --git a/ydb/library/yql/dq/common/dq_common.h b/ydb/library/yql/dq/common/dq_common.h index 557b3abaa3c1..fadaca37d9d4 100644 --- a/ydb/library/yql/dq/common/dq_common.h +++ b/ydb/library/yql/dq/common/dq_common.h @@ -89,6 +89,34 @@ enum class EHashJoinMode { GraceAndSelf /* "graceandself" */, }; +enum class EEnabledSpillingNodes : ui64 { + None = 0ULL /* None */, + GraceJoin = 1ULL /* "GraceJoin" */, + Aggregation = 2ULL /* "Aggregation" */, + All = ~0ULL /* "All" */, +}; + +class TSpillingSettings { +public: + TSpillingSettings() = default; + explicit TSpillingSettings(ui64 mask) : Mask(mask) {}; + + operator bool() const { + return Mask; + } + + bool IsGraceJoinSpillingEnabled() const { + return Mask & ui64(EEnabledSpillingNodes::GraceJoin); + } + + bool IsAggregationSpillingEnabled() const { + return Mask & ui64(EEnabledSpillingNodes::Aggregation); + } + +private: + const ui64 Mask = 0; +}; + } // namespace NYql::NDq IOutputStream& operator<<(IOutputStream& stream, const NYql::NDq::TTxId& txId); diff --git a/ydb/library/yql/dq/opt/dq_opt_log.cpp b/ydb/library/yql/dq/opt/dq_opt_log.cpp index 467fec490c9a..a0ad02b81860 100644 --- a/ydb/library/yql/dq/opt/dq_opt_log.cpp +++ b/ydb/library/yql/dq/opt/dq_opt_log.cpp @@ -17,12 +17,13 @@ using namespace NYql::NNodes; namespace NYql::NDq { TExprBase DqRewriteAggregate(TExprBase node, TExprContext& ctx, TTypeAnnotationContext& typesCtx, bool compactForDistinct, - bool usePhases, const bool useFinalizeByKey) + bool usePhases, const bool useFinalizeByKey, const bool allowSpilling) { if (!node.Maybe()) { return node; } - TAggregateExpander aggExpander(!typesCtx.IsBlockEngineEnabled() && !useFinalizeByKey, useFinalizeByKey, node.Ptr(), ctx, typesCtx, false, compactForDistinct, usePhases); + TAggregateExpander aggExpander(!typesCtx.IsBlockEngineEnabled() && !useFinalizeByKey, + useFinalizeByKey, node.Ptr(), ctx, typesCtx, false, compactForDistinct, usePhases, allowSpilling); auto result = aggExpander.ExpandAggregate(); YQL_ENSURE(result); diff --git a/ydb/library/yql/dq/opt/dq_opt_log.h b/ydb/library/yql/dq/opt/dq_opt_log.h index 56b33ebe5fe5..fecdcc07c486 100644 --- a/ydb/library/yql/dq/opt/dq_opt_log.h +++ b/ydb/library/yql/dq/opt/dq_opt_log.h @@ -19,7 +19,8 @@ namespace NYql { namespace NYql::NDq { -NNodes::TExprBase DqRewriteAggregate(NNodes::TExprBase node, TExprContext& ctx, TTypeAnnotationContext& typesCtx, bool compactForDistinct, bool usePhases, const bool useFinalizeByKey); +NNodes::TExprBase DqRewriteAggregate(NNodes::TExprBase node, TExprContext& ctx, TTypeAnnotationContext& typesCtx, + bool compactForDistinct, bool usePhases, const bool useFinalizeByKey, const bool allowSpilling); NNodes::TExprBase DqRewriteTakeSortToTopSort(NNodes::TExprBase node, TExprContext& ctx, const TParentsMap& parents); diff --git a/ydb/library/yql/dq/opt/dq_opt_phy.cpp b/ydb/library/yql/dq/opt/dq_opt_phy.cpp index b62bed91b36f..97a8d67d3ed8 100644 --- a/ydb/library/yql/dq/opt/dq_opt_phy.cpp +++ b/ydb/library/yql/dq/opt/dq_opt_phy.cpp @@ -1058,6 +1058,90 @@ TExprBase DqPushCombineToStage(TExprBase node, TExprContext& ctx, IOptimizationC return result.Cast(); } +TExprBase DqPushCombineWithSpillingToStage(TExprBase node, TExprContext& ctx, IOptimizationContext& optCtx, + const TParentsMap& parentsMap, bool allowStageMultiUsage) +{ + if (!node.Maybe().Input().Maybe()) { + return node; + } + + auto combine = node.Cast(); + auto dqUnion = combine.Input().Cast(); + if (!IsSingleConsumerConnection(dqUnion, parentsMap, allowStageMultiUsage)) { + return node; + } + + if (!IsDqCompletePureExpr(combine.PreMapLambda()) || + !IsDqCompletePureExpr(combine.KeySelectorLambda()) || + !IsDqCompletePureExpr(combine.InitHandlerLambda()) || + !IsDqCompletePureExpr(combine.UpdateHandlerLambda()) || + !IsDqCompletePureExpr(combine.FinishHandlerLambda()) || + !IsDqCompletePureExpr(combine.LoadHandlerLambda())) + { + return node; + } + + if (auto connToPushableStage = DqBuildPushableStage(dqUnion, ctx)) { + return TExprBase(ctx.ChangeChild(*node.Raw(), TCoCombineByKey::idx_Input, std::move(connToPushableStage))); + } + + auto lambda = Build(ctx, combine.Pos()) + .Args({"stream"}) + .Body() + .Input("stream") + .PreMapLambda(ctx.DeepCopyLambda(combine.PreMapLambda().Ref())) + .KeySelectorLambda(ctx.DeepCopyLambda(combine.KeySelectorLambda().Ref())) + .InitHandlerLambda(ctx.DeepCopyLambda(combine.InitHandlerLambda().Ref())) + .UpdateHandlerLambda(ctx.DeepCopyLambda(combine.UpdateHandlerLambda().Ref())) + .FinishHandlerLambda(ctx.DeepCopyLambda(combine.FinishHandlerLambda().Ref())) + .LoadHandlerLambda(ctx.DeepCopyLambda(combine.LoadHandlerLambda().Ref())) + .Build() + .Done(); + + if (HasContextFuncs(*lambda.Ptr())) { + lambda = Build(ctx, combine.Pos()) + .Args({ TStringBuf("stream") }) + .Body() + .Input() + .Apply(lambda) + .With(0, TStringBuf("stream")) + .Build() + .Name() + .Value("Agg") + .Build() + .Build() + .Done(); + } + + if (IsDqDependsOnStage(combine.PreMapLambda(), dqUnion.Output().Stage()) || + IsDqDependsOnStage(combine.KeySelectorLambda(), dqUnion.Output().Stage()) || + IsDqDependsOnStage(combine.InitHandlerLambda(), dqUnion.Output().Stage()) || + IsDqDependsOnStage(combine.UpdateHandlerLambda(), dqUnion.Output().Stage()) || + IsDqDependsOnStage(combine.FinishHandlerLambda(), dqUnion.Output().Stage()) || + IsDqDependsOnStage(combine.LoadHandlerLambda(), dqUnion.Output().Stage())) + { + return Build(ctx, combine.Pos()) + .Output() + .Stage() + .Inputs() + .Add(dqUnion) + .Build() + .Program(lambda) + .Settings(TDqStageSettings().BuildNode(ctx, node.Pos())) + .Build() + .Index().Build("0") + .Build() + .Done(); + } + + auto result = DqPushLambdaToStageUnionAll(dqUnion, lambda, {}, ctx, optCtx); + if (!result) { + return node; + } + + return result.Cast(); +} + NNodes::TExprBase DqPushAggregateCombineToStage(NNodes::TExprBase node, TExprContext& ctx, IOptimizationContext& optCtx, const TParentsMap& parentsMap, bool allowStageMultiUsage) { @@ -1240,6 +1324,57 @@ TExprBase DqBuildFinalizeByKeyStage(TExprBase node, TExprContext& ctx, .Done(); } +TExprBase DqBuildFinalizeByKeyWithSpillingStage(TExprBase node, TExprContext& ctx, + const TParentsMap& parentsMap, bool allowStageMultiUsage) +{ + auto finalizeInput = node.Maybe().Input(); + if (!finalizeInput.Maybe()) { + return node; + } + + auto finalize = node.Cast(); + auto dqUnion = finalize.Input().Cast(); + + if (!IsSingleConsumerConnection(dqUnion, parentsMap, allowStageMultiUsage)) { + return node; + } + + auto keyLambda = finalize.KeySelectorLambda(); + + TVector inputArgs; + TVector inputConns; + + inputConns.push_back(dqUnion); + + auto finalizeStage = Build(ctx, node.Pos()) + .Inputs() + .Add(inputConns) + .Build() + .Program() + .Args({ "input" }) + .Body() + .Input() + .Input("input") + .PreMapLambda(finalize.PreMapLambda()) + .KeySelectorLambda(finalize.KeySelectorLambda()) + .InitHandlerLambda(finalize.InitHandlerLambda()) + .UpdateHandlerLambda(finalize.UpdateHandlerLambda()) + .FinishHandlerLambda(finalize.FinishHandlerLambda()) + .SaveHandlerLambda(finalize.SaveHandlerLambda()) + .Build() + .Build() + .Build() + .Settings(TDqStageSettings().BuildNode(ctx, node.Pos())) + .Done(); + + return Build(ctx, node.Pos()) + .Output() + .Stage(finalizeStage) + .Index().Build("0") + .Build() + .Done(); +} + /* * Optimizer rule which handles a switch to scalar expression context for aggregation results. * This switch happens for full aggregations, such as @code select sum(column) from table @endcode). diff --git a/ydb/library/yql/dq/opt/dq_opt_phy.h b/ydb/library/yql/dq/opt/dq_opt_phy.h index b2912ad24334..58e26de2a3ce 100644 --- a/ydb/library/yql/dq/opt/dq_opt_phy.h +++ b/ydb/library/yql/dq/opt/dq_opt_phy.h @@ -41,6 +41,9 @@ NNodes::TExprBase DqBuildFlatmapStage(NNodes::TExprBase node, TExprContext& ctx, NNodes::TExprBase DqPushCombineToStage(NNodes::TExprBase node, TExprContext& ctx, IOptimizationContext& optCtx, const TParentsMap& parentsMap, bool allowStageMultiUsage = true); +NNodes::TExprBase DqPushCombineWithSpillingToStage(NNodes::TExprBase node, TExprContext& ctx, IOptimizationContext& optCtx, + const TParentsMap& parentsMap, bool allowStageMultiUsage = true); + NNodes::TExprBase DqPushAggregateCombineToStage(NNodes::TExprBase node, TExprContext& ctx, IOptimizationContext& optCtx, const TParentsMap& parentsMap, bool allowStageMultiUsage = true); @@ -56,6 +59,9 @@ NNodes::TExprBase DqBuildShuffleStage(NNodes::TExprBase node, TExprContext& ctx, NNodes::TExprBase DqBuildFinalizeByKeyStage(NNodes::TExprBase node, TExprContext& ctx, const TParentsMap& parentsMap, bool allowStageMultiUsage = true); +NNodes::TExprBase DqBuildFinalizeByKeyWithSpillingStage(NNodes::TExprBase node, TExprContext& ctx, + const TParentsMap& parentsMap, bool allowStageMultiUsage = true); + NNodes::TExprBase DqBuildAggregationResultStage(NNodes::TExprBase node, TExprContext& ctx, IOptimizationContext& optCtx); diff --git a/ydb/library/yql/dq/tasks/dq_task_program.h b/ydb/library/yql/dq/tasks/dq_task_program.h index 80e7e645edbc..cd11a301d040 100644 --- a/ydb/library/yql/dq/tasks/dq_task_program.h +++ b/ydb/library/yql/dq/tasks/dq_task_program.h @@ -11,23 +11,6 @@ namespace NYql::NDq { -class TSpillingSettings { -public: - TSpillingSettings() = default; - explicit TSpillingSettings(ui64 mask) : Mask(mask) {}; - - operator bool() const { - return Mask; - } - - bool IsGraceJoinSpillingEnabled() const { - return Mask & ui64(TDqConfiguration::EEnabledSpillingNodes::GraceJoin); - } - -private: - const ui64 Mask = 0; -}; - const TStructExprType* CollectParameters(NNodes::TCoLambda program, TExprContext& ctx); TString BuildProgram(NNodes::TCoLambda program, const TStructExprType& paramsType, diff --git a/ydb/library/yql/minikql/mkql_program_builder.cpp b/ydb/library/yql/minikql/mkql_program_builder.cpp index 10e63f3d4891..2bb93217bf2f 100644 --- a/ydb/library/yql/minikql/mkql_program_builder.cpp +++ b/ydb/library/yql/minikql/mkql_program_builder.cpp @@ -4830,12 +4830,100 @@ TRuntimeNode TProgramBuilder::WideLastCombiner(TRuntimeNode flow, const TWideLam return WideLastCombinerCommon(__func__, flow, extractor, init, update, finish); } -TRuntimeNode TProgramBuilder::WideLastCombinerWithSpilling(TRuntimeNode flow, const TWideLambda& extractor, const TBinaryWideLambda& init, const TTernaryWideLambda& update, const TBinaryWideLambda& finish) { - if constexpr (RuntimeVersion < 49U) { +TRuntimeNode TProgramBuilder::WideLastCombinerCommonWithSpilling(const TStringBuf& funcName, TRuntimeNode flow, const TWideLambda& extractor, const TBinaryWideLambda& init, const TTernaryWideLambda& update, const TBinaryWideLambda& finish, const TBinaryWideLambda& serialize, const TBinaryWideLambda& deserialize) { + const auto wideComponents = GetWideComponents(AS_TYPE(TFlowType, flow.GetStaticType())); + + TRuntimeNode::TList itemArgs; + itemArgs.reserve(wideComponents.size()); + + auto i = 0U; + std::generate_n(std::back_inserter(itemArgs), wideComponents.size(), [&](){ return Arg(wideComponents[i++]); }); + + const auto keys = extractor(itemArgs); + + TRuntimeNode::TList keyArgs; + keyArgs.reserve(keys.size()); + std::transform(keys.cbegin(), keys.cend(), std::back_inserter(keyArgs), [&](TRuntimeNode key){ return Arg(key.GetStaticType()); } ); + + const auto first = init(keyArgs, itemArgs); + + TRuntimeNode::TList stateArgs; + stateArgs.reserve(first.size()); + std::transform(first.cbegin(), first.cend(), std::back_inserter(stateArgs), [&](TRuntimeNode state){ return Arg(state.GetStaticType()); } ); + + const auto next = update(keyArgs, itemArgs, stateArgs); + MKQL_ENSURE(next.size() == first.size(), "Mismatch init and update state size."); + + TRuntimeNode::TList finishKeyArgs; + finishKeyArgs.reserve(keys.size()); + std::transform(keys.cbegin(), keys.cend(), std::back_inserter(finishKeyArgs), [&](TRuntimeNode key){ return Arg(key.GetStaticType()); } ); + + TRuntimeNode::TList finishStateArgs; + finishStateArgs.reserve(next.size()); + std::transform(next.cbegin(), next.cend(), std::back_inserter(finishStateArgs), [&](TRuntimeNode state){ return Arg(state.GetStaticType()); } ); + + const auto output = finish(finishKeyArgs, finishStateArgs); + + TRuntimeNode::TList serializeKeyArgs; + serializeKeyArgs.reserve(keys.size()); + std::transform(keys.cbegin(), keys.cend(), std::back_inserter(serializeKeyArgs), [&](TRuntimeNode key){ return Arg(key.GetStaticType()); } ); + + TRuntimeNode::TList serializeStateArgs; + serializeStateArgs.reserve(next.size()); + std::transform(next.cbegin(), next.cend(), std::back_inserter(serializeStateArgs), [&](TRuntimeNode state){ return Arg(state.GetStaticType()); } ); + + const auto serializeOutput = serialize(serializeKeyArgs, serializeStateArgs); + + TRuntimeNode::TList deserializeKeyArgs; + deserializeKeyArgs.reserve(keys.size()); + std::transform(keys.cbegin(), keys.cend(), std::back_inserter(deserializeKeyArgs), [&](TRuntimeNode key){ return Arg(key.GetStaticType()); } ); + + TRuntimeNode::TList deserializeStateArgs; + deserializeStateArgs.reserve(serializeOutput.size()); + std::transform(serializeOutput.cbegin(), serializeOutput.cend(), std::back_inserter(deserializeStateArgs), + [&](TRuntimeNode arg){ + TType* actualType = arg.GetStaticType(); + if (actualType->IsOptional()) { + actualType = AS_TYPE(TOptionalType, actualType)->GetItemType(); + } + return Arg(actualType); + } + ); + + const auto deserializeOutput = deserialize(deserializeKeyArgs, deserializeStateArgs); + + std::vector tupleItems; + tupleItems.reserve(output.size()); + std::transform(output.cbegin(), output.cend(), std::back_inserter(tupleItems), std::bind(&TRuntimeNode::GetStaticType, std::placeholders::_1)); + + TCallableBuilder callableBuilder(Env, funcName, NewFlowType(NewMultiType(tupleItems))); + callableBuilder.Add(flow); + callableBuilder.Add(NewDataLiteral(ui32(keyArgs.size()))); + callableBuilder.Add(NewDataLiteral(ui32(stateArgs.size()))); + std::for_each(itemArgs.cbegin(), itemArgs.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1)); + std::for_each(keys.cbegin(), keys.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1)); + std::for_each(keyArgs.cbegin(), keyArgs.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1)); + std::for_each(first.cbegin(), first.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1)); + std::for_each(stateArgs.cbegin(), stateArgs.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1)); + std::for_each(next.cbegin(), next.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1)); + std::for_each(finishKeyArgs.cbegin(), finishKeyArgs.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1)); + std::for_each(finishStateArgs.cbegin(), finishStateArgs.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1)); + std::for_each(output.cbegin(), output.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1)); + std::for_each(serializeKeyArgs.cbegin(), serializeKeyArgs.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1)); + std::for_each(serializeStateArgs.cbegin(), serializeStateArgs.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1)); + std::for_each(serializeOutput.cbegin(), serializeOutput.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1)); + std::for_each(deserializeKeyArgs.cbegin(), deserializeKeyArgs.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1)); + std::for_each(deserializeStateArgs.cbegin(), deserializeStateArgs.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1)); + std::for_each(deserializeOutput.cbegin(), deserializeOutput.cend(), std::bind(&TCallableBuilder::Add, std::ref(callableBuilder), std::placeholders::_1)); + return TRuntimeNode(callableBuilder.Build(), false); +} + +TRuntimeNode TProgramBuilder::WideLastCombinerWithSpilling(TRuntimeNode flow, const TWideLambda& extractor, const TBinaryWideLambda& init, const TTernaryWideLambda& update, const TBinaryWideLambda& finish, const TBinaryWideLambda& serialize, const TBinaryWideLambda& deserialize) { + if constexpr (RuntimeVersion < 51U) { THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__; } - return WideLastCombinerCommon(__func__, flow, extractor, init, update, finish); + return WideLastCombinerCommonWithSpilling(__func__, flow, extractor, init, update, finish, serialize, deserialize); } TRuntimeNode TProgramBuilder::WideCondense1(TRuntimeNode flow, const TWideLambda& init, const TWideSwitchLambda& switcher, const TBinaryWideLambda& update, bool useCtx) { @@ -4937,6 +5025,66 @@ TRuntimeNode TProgramBuilder::CombineCore(TRuntimeNode stream, return TRuntimeNode(callableBuilder.Build(), false); } +TRuntimeNode TProgramBuilder::CombineCoreWithSpilling(TRuntimeNode stream, + const TUnaryLambda& keyExtractor, + const TBinaryLambda& init, + const TTernaryLambda& update, + const TBinaryLambda& finish, + const TUnaryLambda& load, + ui64 memLimit) +{ + if constexpr (RuntimeVersion < 51U) { + THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__; + } + + const bool isStream = stream.GetStaticType()->IsStream(); + const auto itemType = isStream ? AS_TYPE(TStreamType, stream)->GetItemType() : AS_TYPE(TFlowType, stream)->GetItemType(); + + const auto itemArg = Arg(itemType); + + const auto key = keyExtractor(itemArg); + const auto keyType = key.GetStaticType(); + + const auto keyArg = Arg(keyType); + + const auto stateInit = init(keyArg, itemArg); + const auto stateType = stateInit.GetStaticType(); + + const auto stateArg = Arg(stateType); + + const auto stateUpdate = update(keyArg, itemArg, stateArg); + const auto finishItem = finish(keyArg, stateArg); + + const auto finishType = finishItem.GetStaticType(); + MKQL_ENSURE(finishType->IsList() || finishType->IsStream() || finishType->IsOptional(), "Expected list, stream or optional"); + + const auto stateLoad = load(stateArg); + + TType* retItemType = nullptr; + if (finishType->IsOptional()) { + retItemType = AS_TYPE(TOptionalType, finishType)->GetItemType(); + } else if (finishType->IsList()) { + retItemType = AS_TYPE(TListType, finishType)->GetItemType(); + } else if (finishType->IsStream()) { + retItemType = AS_TYPE(TStreamType, finishType)->GetItemType(); + } + + const auto resultStreamType = isStream ? NewStreamType(retItemType) : NewFlowType(retItemType); + TCallableBuilder callableBuilder(Env, __func__, resultStreamType); + callableBuilder.Add(stream); + callableBuilder.Add(itemArg); + callableBuilder.Add(key); + callableBuilder.Add(keyArg); + callableBuilder.Add(stateInit); + callableBuilder.Add(stateArg); + callableBuilder.Add(stateUpdate); + callableBuilder.Add(finishItem); + callableBuilder.Add(stateLoad); + callableBuilder.Add(NewDataLiteral(memLimit)); + + return TRuntimeNode(callableBuilder.Build(), false); +} + TRuntimeNode TProgramBuilder::GroupingCore(TRuntimeNode stream, const TBinaryLambda& groupSwitch, const TUnaryLambda& keyExtractor, diff --git a/ydb/library/yql/minikql/mkql_program_builder.h b/ydb/library/yql/minikql/mkql_program_builder.h index fe6b977a2720..3d7344848c7a 100644 --- a/ydb/library/yql/minikql/mkql_program_builder.h +++ b/ydb/library/yql/minikql/mkql_program_builder.h @@ -408,8 +408,9 @@ class TProgramBuilder : public TTypeBuilder { TRuntimeNode WideCombiner(TRuntimeNode flow, i64 memLimit, const TWideLambda& keyExtractor, const TBinaryWideLambda& init, const TTernaryWideLambda& update, const TBinaryWideLambda& finish); TRuntimeNode WideLastCombinerCommon(const TStringBuf& funcName, TRuntimeNode flow, const TWideLambda& keyExtractor, const TBinaryWideLambda& init, const TTernaryWideLambda& update, const TBinaryWideLambda& finish); + TRuntimeNode WideLastCombinerCommonWithSpilling(const TStringBuf& funcName, TRuntimeNode flow, const TWideLambda& keyExtractor, const TBinaryWideLambda& init, const TTernaryWideLambda& update, const TBinaryWideLambda& finish, const TBinaryWideLambda& serialize, const TBinaryWideLambda& deserialize); TRuntimeNode WideLastCombiner(TRuntimeNode flow, const TWideLambda& keyExtractor, const TBinaryWideLambda& init, const TTernaryWideLambda& update, const TBinaryWideLambda& finish); - TRuntimeNode WideLastCombinerWithSpilling(TRuntimeNode flow, const TWideLambda& keyExtractor, const TBinaryWideLambda& init, const TTernaryWideLambda& update, const TBinaryWideLambda& finish); + TRuntimeNode WideLastCombinerWithSpilling(TRuntimeNode flow, const TWideLambda& keyExtractor, const TBinaryWideLambda& init, const TTernaryWideLambda& update, const TBinaryWideLambda& finish, const TBinaryWideLambda& serialize, const TBinaryWideLambda& deserialize); TRuntimeNode WideCondense1(TRuntimeNode stream, const TWideLambda& init, const TWideSwitchLambda& switcher, const TBinaryWideLambda& handler, bool useCtx = false); TRuntimeNode WideTop(TRuntimeNode flow, TRuntimeNode count, const std::vector>& keys); @@ -469,6 +470,13 @@ class TProgramBuilder : public TTypeBuilder { const TTernaryLambda& update, const TBinaryLambda& finish, ui64 memLimit); + TRuntimeNode CombineCoreWithSpilling(TRuntimeNode stream, + const TUnaryLambda& keyExtractor, + const TBinaryLambda& init, + const TTernaryLambda& update, + const TBinaryLambda& finish, + const TUnaryLambda& load, + ui64 memLimit); TRuntimeNode GroupingCore(TRuntimeNode stream, const TBinaryLambda& groupSwitch, const TUnaryLambda& keyExtractor, diff --git a/ydb/library/yql/minikql/mkql_runtime_version.h b/ydb/library/yql/minikql/mkql_runtime_version.h index 22072157e87f..4d416ac7089d 100644 --- a/ydb/library/yql/minikql/mkql_runtime_version.h +++ b/ydb/library/yql/minikql/mkql_runtime_version.h @@ -24,7 +24,7 @@ namespace NMiniKQL { // 1. Bump this version every time incompatible runtime nodes are introduced. // 2. Make sure you provide runtime node generation for previous runtime versions. #ifndef MKQL_RUNTIME_VERSION -#define MKQL_RUNTIME_VERSION 50U +#define MKQL_RUNTIME_VERSION 51U #endif // History: diff --git a/ydb/library/yql/providers/common/mkql/yql_provider_mkql.cpp b/ydb/library/yql/providers/common/mkql/yql_provider_mkql.cpp index d1aa870f24eb..a4144fe05cb6 100644 --- a/ydb/library/yql/providers/common/mkql/yql_provider_mkql.cpp +++ b/ydb/library/yql/providers/common/mkql/yql_provider_mkql.cpp @@ -754,18 +754,49 @@ TMkqlCommonCallableCompiler::TShared::TShared() { return MkqlBuildWideLambda(node.Tail(), ctx, keys); }; - bool isStatePersistable = true; - // Traverse through childs skipping input and limit children - for (size_t i = 2U; i < node.ChildrenSize(); ++i) { - isStatePersistable = isStatePersistable && node.Child(i)->GetTypeAnn()->IsPersistable(); + if (withLimit) { + return ctx.ProgramBuilder.WideCombiner(flow, memLimit, keyExtractor, init, update, finish); } + return ctx.ProgramBuilder.WideLastCombiner(flow, keyExtractor, init, update, finish); + }); + + AddCallable("WideCombinerWithSpilling", [](const TExprNode& node, TMkqlBuildContext& ctx) { + const auto flow = MkqlBuildExpr(node.Head(), ctx); + i64 memLimit = 0LL; + const bool withLimit = TryFromString(node.Child(1U)->Content(), memLimit); + + const auto keyExtractor = [&](TRuntimeNode::TList items) { + return MkqlBuildWideLambda(*node.Child(2U), ctx, items); + }; + const auto init = [&](TRuntimeNode::TList keys, TRuntimeNode::TList items) { + keys.insert(keys.cend(), items.cbegin(), items.cend()); + return MkqlBuildWideLambda(*node.Child(3U), ctx, keys); + }; + const auto update = [&](TRuntimeNode::TList keys, TRuntimeNode::TList items, TRuntimeNode::TList state) { + keys.insert(keys.cend(), items.cbegin(), items.cend()); + keys.insert(keys.cend(), state.cbegin(), state.cend()); + return MkqlBuildWideLambda(*node.Child(4U), ctx, keys); + }; + const auto finish = [&](TRuntimeNode::TList keys, TRuntimeNode::TList state) { + keys.insert(keys.cend(), state.cbegin(), state.cend()); + return MkqlBuildWideLambda(*node.Child(5U), ctx, keys); + }; + const auto serialize = [&](TRuntimeNode::TList keys, TRuntimeNode::TList state) { + keys.insert(keys.cend(), state.cbegin(), state.cend()); + return MkqlBuildWideLambda(*node.Child(6U), ctx, keys); + }; + const auto deserialize = [&](TRuntimeNode::TList keys, TRuntimeNode::TList state) { + keys.insert(keys.cend(), state.cbegin(), state.cend()); + return MkqlBuildWideLambda(*node.Child(7U), ctx, keys); + }; + if (withLimit) { return ctx.ProgramBuilder.WideCombiner(flow, memLimit, keyExtractor, init, update, finish); } - if (isStatePersistable && RuntimeVersion >= 49U) { - return ctx.ProgramBuilder.WideLastCombinerWithSpilling(flow, keyExtractor, init, update, finish); + if (RuntimeVersion >= 51U) { + return ctx.ProgramBuilder.WideLastCombinerWithSpilling(flow, keyExtractor, init, update, finish, serialize, deserialize); } return ctx.ProgramBuilder.WideLastCombiner(flow, keyExtractor, init, update, finish); }); @@ -1814,6 +1845,28 @@ TMkqlCommonCallableCompiler::TShared::TShared() { return ctx.ProgramBuilder.CombineCore(stream, keyExtractor, init, update, finish, memLimit); }); + AddCallable("CombineCoreWithSpilling", [](const TExprNode& node, TMkqlBuildContext& ctx) { + NNodes::TCoCombineCoreWithSpilling core(&node); + + const auto stream = MkqlBuildExpr(core.Input().Ref(), ctx); + const auto memLimit = FromString(core.MemLimit().Cast().Value()); + + const auto keyExtractor = [&](TRuntimeNode item) { + return MkqlBuildLambda(core.KeyExtractor().Ref(), ctx, {item}); + }; + const auto init = [&](TRuntimeNode key, TRuntimeNode item) { + return MkqlBuildLambda(core.InitHandler().Ref(), ctx, {key, item}); + }; + const auto update = [&](TRuntimeNode key, TRuntimeNode item, TRuntimeNode state) { + return MkqlBuildLambda(core.UpdateHandler().Ref(), ctx, {key, item, state}); + }; + const auto finish = [&](TRuntimeNode key, TRuntimeNode state) { + return MkqlBuildLambda(core.FinishHandler().Ref(), ctx, {key, state}); + }; + + return ctx.ProgramBuilder.CombineCore(stream, keyExtractor, init, update, finish, memLimit); + }); + AddCallable("GroupingCore", [](const TExprNode& node, TMkqlBuildContext& ctx) { NNodes::TCoGroupingCore core(&node); @@ -2051,6 +2104,10 @@ TMkqlCommonCallableCompiler::TShared::TShared() { return CombineByKeyImpl(node, ctx); }); + AddCallable("CombineByKeyWithSpilling", [](const TExprNode& node, TMkqlBuildContext& ctx) { + return CombineByKeyImpl(node, ctx); + }); + AddCallable("Enumerate", [](const TExprNode& node, TMkqlBuildContext& ctx) { const auto arg = MkqlBuildExpr(node.Head(), ctx); TRuntimeNode start; diff --git a/ydb/library/yql/providers/dq/common/yql_dq_settings.cpp b/ydb/library/yql/providers/dq/common/yql_dq_settings.cpp index 24f16bd28312..8714c7624808 100644 --- a/ydb/library/yql/providers/dq/common/yql_dq_settings.cpp +++ b/ydb/library/yql/providers/dq/common/yql_dq_settings.cpp @@ -110,7 +110,7 @@ TDqConfiguration::TDqConfiguration() { if (s.empty()) { throw yexception() << "Empty value item"; } - auto value = FromString(s); + auto value = FromString(s); res |= ui64(value); } return res; diff --git a/ydb/library/yql/providers/dq/common/yql_dq_settings.h b/ydb/library/yql/providers/dq/common/yql_dq_settings.h index 8fe9163d9037..e2ab658ee633 100644 --- a/ydb/library/yql/providers/dq/common/yql_dq_settings.h +++ b/ydb/library/yql/providers/dq/common/yql_dq_settings.h @@ -27,13 +27,7 @@ struct TDqSettings { Disable /* "disable" */, File /* "file" */, }; - - enum class EEnabledSpillingNodes : ui64 { - None = 0ULL /* None */, - GraceJoin = 1ULL /* "GraceJoin" */, - All = ~0ULL /* "All" */, - }; - + struct TDefault { static constexpr ui32 MaxTasksPerStage = 20U; static constexpr ui32 MaxTasksPerOperation = 70U; diff --git a/ydb/library/yql/providers/dq/opt/logical_optimize.cpp b/ydb/library/yql/providers/dq/opt/logical_optimize.cpp index a25065a8854e..754770fca15e 100644 --- a/ydb/library/yql/providers/dq/opt/logical_optimize.cpp +++ b/ydb/library/yql/providers/dq/opt/logical_optimize.cpp @@ -137,7 +137,10 @@ class TDqsLogicalOptProposalTransformer : public TOptimizeTransformerBase { bool syncActor = Config->ComputeActorType.Get() != "async"; return NHopping::RewriteAsHoppingWindow(node, ctx, input.Cast(), analyticsHopping, lateArrivalDelay, defaultWatermarksMode, syncActor); } else { - return DqRewriteAggregate(node, ctx, TypesCtx, true, Config->UseAggPhases.Get().GetOrElse(false), Config->UseFinalizeByKey.Get().GetOrElse(false)); + NDq::TSpillingSettings spillingSettings(Config->GetEnabledSpillingNodes()); + return DqRewriteAggregate(node, ctx, TypesCtx, true, + Config->UseAggPhases.Get().GetOrElse(false), Config->UseFinalizeByKey.Get().GetOrElse(false), + spillingSettings.IsAggregationSpillingEnabled()); } } return node; @@ -183,7 +186,7 @@ class TDqsLogicalOptProposalTransformer : public TOptimizeTransformerBase { return node; } - TDqLookupSourceWrap lookupSourceWrap = right.Maybe() + TDqLookupSourceWrap lookupSourceWrap = right.Maybe() ? LookupSourceFromSource(right.Cast(), ctx) : LookupSourceFromRead(right.Cast(), ctx) ; diff --git a/ydb/library/yql/providers/dq/opt/physical_optimize.cpp b/ydb/library/yql/providers/dq/opt/physical_optimize.cpp index 0eba2f3be213..e26abc04c915 100644 --- a/ydb/library/yql/providers/dq/opt/physical_optimize.cpp +++ b/ydb/library/yql/providers/dq/opt/physical_optimize.cpp @@ -32,9 +32,11 @@ class TDqsPhysicalOptProposalTransformer : public TOptimizeTransformerBase { AddHandler(0, &TCoExtractMembers::Match, HNDL(PushExtractMembersToStage)); AddHandler(0, &TCoFlatMapBase::Match, HNDL(BuildFlatmapStage)); AddHandler(0, &TCoCombineByKey::Match, HNDL(PushCombineToStage)); + AddHandler(0, &TCoCombineByKeyWithSpilling::Match, HNDL(PushCombineWithSpillingToStage)); AddHandler(0, &TCoPartitionsByKeys::Match, HNDL(BuildPartitionsStage)); AddHandler(0, &TCoShuffleByKeys::Match, HNDL(BuildShuffleStage)); AddHandler(0, &TCoFinalizeByKey::Match, HNDL(BuildFinalizeByKeyStage)); + AddHandler(0, &TCoFinalizeByKeyWithSpilling::Match, HNDL(BuildFinalizeByKeyWithSpillingStage)); AddHandler(0, &TCoPartitionByKey::Match, HNDL(BuildPartitionStage)); AddHandler(0, &TCoAsList::Match, HNDL(BuildAggregationResultStage)); AddHandler(0, &TCoTopSort::Match, HNDL(BuildTopSortStage)); @@ -67,9 +69,11 @@ class TDqsPhysicalOptProposalTransformer : public TOptimizeTransformerBase { AddHandler(1, &TCoExtractMembers::Match, HNDL(PushExtractMembersToStage)); AddHandler(1, &TCoFlatMapBase::Match, HNDL(BuildFlatmapStage)); AddHandler(1, &TCoCombineByKey::Match, HNDL(PushCombineToStage)); + AddHandler(1, &TCoCombineByKeyWithSpilling::Match, HNDL(PushCombineWithSpillingToStage)); AddHandler(1, &TCoPartitionsByKeys::Match, HNDL(BuildPartitionsStage)); AddHandler(1, &TCoShuffleByKeys::Match, HNDL(BuildShuffleStage)); AddHandler(1, &TCoFinalizeByKey::Match, HNDL(BuildFinalizeByKeyStage)); + AddHandler(1, &TCoFinalizeByKeyWithSpilling::Match, HNDL(BuildFinalizeByKeyWithSpillingStage)); AddHandler(1, &TCoPartitionByKey::Match, HNDL(BuildPartitionStage)); AddHandler(1, &TCoTopSort::Match, HNDL(BuildTopSortStage)); AddHandler(1, &TCoSort::Match, HNDL(BuildSortStage)); @@ -140,6 +144,11 @@ class TDqsPhysicalOptProposalTransformer : public TOptimizeTransformerBase { return DqPushCombineToStage(node, ctx, optCtx, *getParents(), IsGlobal); } + template + TMaybeNode PushCombineWithSpillingToStage(TExprBase node, TExprContext& ctx, IOptimizationContext& optCtx, const TGetParents& getParents) { + return DqPushCombineWithSpillingToStage(node, ctx, optCtx, *getParents(), IsGlobal); + } + template TMaybeNode BuildPartitionsStage(TExprBase node, TExprContext& ctx, IOptimizationContext& optCtx, const TGetParents& getParents) { return DqBuildPartitionsStage(node, ctx, optCtx, *getParents(), IsGlobal); @@ -160,6 +169,11 @@ class TDqsPhysicalOptProposalTransformer : public TOptimizeTransformerBase { return DqBuildFinalizeByKeyStage(node, ctx, *getParents(), IsGlobal); } + template + TMaybeNode BuildFinalizeByKeyWithSpillingStage(TExprBase node, TExprContext& ctx, const TGetParents& getParents) { + return DqBuildFinalizeByKeyWithSpillingStage(node, ctx, *getParents(), IsGlobal); + } + TMaybeNode BuildAggregationResultStage(TExprBase node, TExprContext& ctx, IOptimizationContext& optCtx) { return DqBuildAggregationResultStage(node, ctx, optCtx); }