Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extend the simple UDAF interface with function-level states #9167

Open
wants to merge 8 commits into
base: main
Choose a base branch
from

Conversation

liujiayi771
Copy link
Contributor

@liujiayi771 liujiayi771 commented Mar 20, 2024

The simple function interface for UDAFs currently doesn't allow function-level
states. Function-level states are variables hold by a UDAF instance that are
typically computed once and used at every row when adding inputs to accumulators
or extracting values from accumulators. With the traditional vector function
interface, UDAF authors use function-level states by simply defining them as
data members of the function class. On the contrary, the simple function
interface doesn’t expose any data member of the function class to the
author-defined accumulator logics.

Below are a few examples where function-level states are necessary and useful.

Decimal functions usually need to know the precision and scale of the result
Decimal type when extract values from accumulators. The result type is currently
not exposed to author-defined accumulator type in the simple function interface.
Some functions perform heavy computation on a constant argument before
processing all input rows, such as approx_most_frequent, and store the
computation result in a function-level state. With the current simple function
interface, the UDAF author would have to do the computation on the constant
argument at every row. To enable function-level states, we can extend
SimpleAggregateAdapter to allow the UDAF author to define a FunctionState struct
in the function class and let SimpleAggregateAdapter hold an instance of
FunctionState. The author then implement a void initialize(FunctionState& state,
const TypePtr& resultType, const std::vector& constantInputs) in
their function class that assign values to the FunctionState instance. This
initialize() function will be called only once when the aggregation function
instance is created. Finally, all methods in the author-defined accumulator type
receive this FunctionState instance as a const argument.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Mar 20, 2024
Copy link

netlify bot commented Mar 20, 2024

Deploy Preview for meta-velox canceled.

Name Link
🔨 Latest commit 28855c3
🔍 Latest deploy log https://app.netlify.com/sites/meta-velox/deploys/662b19359b97460008fb5354

@liujiayi771
Copy link
Contributor Author

liujiayi771 commented Mar 20, 2024

Hi @mbasmanova, @kagamiori, we need to use the FunctionState mentioned in #8711 for Spark decimal average. I have made some extensions based on the existing code examples in #8710, could you help review?

@liujiayi771
Copy link
Contributor Author

cc @rui-mo.

@liujiayi771
Copy link
Contributor Author

liujiayi771 commented Mar 28, 2024

Hi @mbasmanova @kagamiori, Could you help review? We need this change to continue to work on decimal avg and set_agg for spark.

@kagamiori
Copy link
Contributor

I'll take a look. Thank you for working on this!

Comment on lines 106 to 107
info.function->initialize(
aggregate.rawInputTypes, aggResultType, info.constantInputs);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @liujiayi771, I'm starting reviewing this PR. I was busy with the VeloxCon and sorry for the delay. Could you remind me of the use case where you need the rawInputTypes in the function-level state? Also, I think aggResultType is already passed to info.function when it's created through Aggregate::create(). Do we still need aggResultType in the function-level state?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @kagamiori, thank you for taking the time to review this PR.

Could you remind me of the use case where you need the rawInputTypes in the function-level state?

In the Spark decimal avg, I need to calculate the decimal type of the sum in the intermediateType based on the rawInputType, and obtain the correct precision and scale. We have discussed in #9048 (comment) that it is incorrect to calculate based on the resultType. When the precision of the resultType > 34, we cannot infer the precision of the sum.

Do we still need aggResultType in the function-level state?

Can we access the resultType info from aggInfo in the writeIntermediateResult and writeFinalResult methods of SimpleAggregateAdapter's struct?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @liujiayi771, thank you for reminding me of the discussion in #9048. Could you also share a concrete example about how you use the change in this PR to solve the problem of the merge_extract function of decimal avg()? I'd like to understand it better.

For example, suppose you're calling companion functions of avg(decimal(30, 20)) -> row(decimal(38, 20), bigint) -> decimal(34, 24). Do you use the AggregateCompanionAdapter to generate the companion functions, and what rawInputTypes and outputType_ do you set in the AggregationNode of them?

Copy link
Contributor Author

@liujiayi771 liujiayi771 Apr 12, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @kagamiori. In fact, the problem of the companion function not being registered in this case has not been resolved. We still cannot use the aggregate companion function, and can only register a normal aggregate function.

However, we intend to use this normal aggregate function in Gluten. If we can obtain the rawInputType in the function state, we can implement the logic for Spark decimal avg based on the simple function interface. I already have a basic implementation, and the code is here.

The sumType calculated based on rawInputType is needed when computing the final result of decimal avg.

Copy link
Contributor

@kagamiori kagamiori Apr 12, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @liujiayi771, thank you for the clarification. Are you still going to use the avg_merge_extract function for decimal type and do you plan to implement it manually as an aggregation function separate from avg? I'm asking because you mentioned in #9048 (reply in thread) that you plan to pass the rawInputType of avg as the input type to the factory of avg_merge_extract. I'm concerned that this does not work with the current auto-generated factory code of merge_extract companion functions because the auto-generated code uses the input type to resolve the original result type of avg.

const auto& [originalResultType, _] =
resolveAggregateFunction(mergeExtractFunctionName, argTypes);

Copy link
Contributor Author

@liujiayi771 liujiayi771 Apr 13, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@kagamiori Are you saying that, for the merge_extract companion function, if we follow the implementation of this PR, the rawInputType received in its initialize method (used to initialize the function state of this aggregate function) is actually the intermediateType of the original aggregate function?

If I understand your concern correctly, my answer is that I plan to use the current auto-generated factory code of merge_extract companion function. However, I can still get the correct sumType from its rawInputType. Yes, for the merge_extract companion function, its rawInputType is indeed the intermediateType of the original aggregate function. However, this does not affect Spark decimal avg's logic to obtain its sumType; the sumType doesn't always need to be obtained from the original rawInputType. For the merge_extract companion function, its rawInputType->childAt(0) is the sumType. I can use the initialize method to determine whether the current initialization is for the merge_extract companion function based on whether rawInputType is a ROW.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @liujiayi771, thank you for the explanation. I remember the intermediate type is ROW(DECIMAL(min(38, a_precision + 10), a_scale), bigint) while the result type is DECIMAL(min(38, a_precision + 4), min(38, a_scale + 4)). As we discussed in #9048, min(38, a_precision + 4) may not be inferable from min(38, a_precision + 10) if a_precision >= 28. The AggregateCompanionAdapter would require the result type be inferable from intermediate type. How would you plan to address this problem?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @kagamiori. As we discussed before, it is not possible to resolve the resultType through intermediateType in Velox. But in the case of Spark, this can be resolved through a tricky workaround, and this tricky code will only be kept in our forked repo. In Velox, the avg function signature of decimal still does not support companion aggregate functions.

The signature for decimal is

exec::AggregateFunctionSignatureBuilder()
        .integerVariable("a_precision")
        .integerVariable("a_scale")
        .integerVariable("i_precision", "min(38, a_precision + 10)")
        .integerVariable("r_precision", "min(38, a_precision + 4)")
        .integerVariable("r_scale", "min(38, a_scale + 4)")
        .argumentType("DECIMAL(a_precision, a_scale)")
        .intermediateType("ROW(DECIMAL(i_precision, a_scale), bigint)")
        .returnType("DECIMAL(r_precision, r_scale)")
        .build()

This signature will not register companion aggregate functions, because isResultTypeResolvableGivenIntermediateType is false. So, it does not involve resolve resultType from intermediateType, and there is no issue.

But if we still want to use the auto-generated avg_merge_extract, this is where the tricky part comes in. We will only add this tricky approach in our forked repo and introduce a new signature.

exec::AggregateFunctionSignatureBuilder()
        .integerVariable("a_precision")
        .integerVariable("a_scale")
        .argumentType("DECIMAL(a_precision, a_scale)")
        .intermediateType("ROW(DECIMAL(a_precision, a_scale), BIGINT)")
        .returnType("DECIMAL(a_precision, a_scale)")
        .build()

This signature does not correspond to the actual situation. Its purpose is to register companion aggregate functions for decimal type. The type of this signature is the same as the one above, but the precision and scale of decimal are incorrect. However, this situation still works fine for Spark, because Spark does not use function type resolution logic in Velox, and we have removed the resolve check for merge_extract function in our repo. All types are defined by the Spark plan, not based on the signature resolution by Velox. We entirely rely on the types in the Spark plan to build the vector types used in operators and functions. You can understand that Spark itself already includes the type resolution functionality, and we simply build the Velox plan and execute it based on the types in the Spark plan.

We can think of a better way to resolve the types for the merge_extract function of the decimal avg later, but even if we don't register the companion functions, the decimal avg still needs the information in the function state to complete the final result calculation.

Copy link
Contributor

@kagamiori kagamiori left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @liujiayi771, thank you for putting together this PR! I left a few comments in the code. I haven't read the changes in the documentation and will review them after we sort out the code.

velox/exec/SimpleAggregateAdapter.h Outdated Show resolved Hide resolved
velox/exec/AggregateCompanionAdapter.cpp Outdated Show resolved Hide resolved
velox/exec/AggregateWindow.cpp Outdated Show resolved Hide resolved
velox/exec/AggregateCompanionAdapter.cpp Outdated Show resolved Hide resolved
Copy link
Contributor

@kagamiori kagamiori left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @liujiayi771, thank you for updating! Since this change adds the support for initialize() for companion functions and aggregate window functions, we should add unit tests for them too. Could you make a dummy aggregation function just for testing via the simple UDAF interface that has function-level state? Ideally, its initialize() performs differently depending on different steps. Then, let's add unit test to verify the correct behavior for the simple UDAF, its companion functions, and the window function.

Comment on lines 76 to 82
void AggregateCompanionFunctionBase::initialize(
core::AggregationNode::Step step,
const std::vector<TypePtr>& rawInputType,
const facebook::velox::TypePtr& resultType,
const std::vector<VectorPtr>& constantInputs) {
fn_->initialize(step, rawInputType, resultType, constantInputs);
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's move this method to every derived class since they need to pass different step to the main function's initialize() method.

velox/exec/SimpleAggregateAdapter.h Outdated Show resolved Hide resolved
Comment on lines 489 to 491
using InputType = Row<int64_t>; // Input vector type wrapped in Row.
using IntermediateType = int64_t; // Intermediate result type.
using OutputType = int64_t; // Output vector type.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for adding the unit test!

Could we make the input, intermediate, and result types all be different, so that in initialize(), we can check that when step is partial, rawInputType is always the same as InputType, and when step is intermediate, rawInputType is always the same as IntermediateType? Also in initialize(), we can check that result type is expected.

Also, could we let this UDAF receive two arguments and make one of them constant literal in the aggregation function calls, e.g., call something like simple_function_state_agg(c0, 1) in unit tests. Then, let's check in initialize for the partial and main aggregation function that constantInputs is set correctly.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Optimizations have been made to the test cases.

@liujiayi771 liujiayi771 force-pushed the func-state branch 2 times, most recently from 7a15313 to 44560cc Compare April 17, 2024 14:45
velox/exec/tests/SimpleAggregateAdapterTest.cpp Outdated Show resolved Hide resolved
@@ -471,5 +483,324 @@ TEST_F(SimpleCountNullsAggregationTest, basic) {
testAggregations({vectors}, {}, {"simple_count_nulls(c2)"}, {expected});
}

// A testing aggregation function that uses the function state.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @liujiayi771, thank you for iterating on the unit test! I find the unit test still a bit hard to understand. Could we simplify it a bit by separating the checks for different purposes?

We need to test the function-level states from three aspects:

  1. function-level states work correctly with the simple UDAF interface.
  2. companion functions receive correct function-level states and work properly.
  3. the UDAF with function-level states work correctly in window operations.

So I propose that we construct the unit tests as follows:

  1. We don't need to check the function state when processing every row. So we can just move the checks into the UDAF's initialize() method.
  2. We can make the FunctionStateTestAggregate class templated by a bool testCompanion. Inside initialize(), if testCompanion is true, we check that rawInputType is always {bigint, bigint} when step is partial, and it's always {row<bigint, double>} when step is intermediate. Plus, the constantInput is set correctly when step is partial. On the other hand, if testCompanion is false, we check that rawInputType is always {bigint, bigint} and constantInput is set correctly for all steps.
  3. Let the addInput() method not use the argument increment but the constant value in the function state instead.
  4. We can register a UDAF simple_function_state_agg_main with testCompanion being false. With this UDAF, we can use the testAggregations API to check that results are correct. This test would tell us that function-level state works properly with simple UDAFs. (You don't need testIncrementalAggregation and testStreaming since those focus more on individual accumulator design. You can turn them off by calling disableTestStreaming() and disableTestIncremental() so that you don't need to modify AggregationTestBase.cpp.)
  5. We can register another UDAF simple_function_state_agg_companion with testCompanion being true. With this UDAF, we use testAggregationsWithCompanion() to check the results are correct.
  6. Let's also use the testWindowFunction() API to test that the main UDAF works correctly in window operations.

What do you think?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@kagamiori Thank you very much for providing the detailed suggestion. I will refactor the current test cases based on these suggestions.

Copy link
Contributor Author

@liujiayi771 liujiayi771 Apr 19, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@kagamiori If we use testAggregationsWithCompanion to test companion aggregate functions, we will encounter a scenario where simple_function_state_agg_companion_partial's actual step is kFinal. In this case, the step passed in initialize method is kPartial and the constantInputs is {nullptr}. However, when the step is kPartial, we would expect constantInputs to be {nullptr, 1}. So, I chose to use single aggregation to test companion aggregate functions before.

Should we add a companionStep to differentiate the actual step in function state? Or can we just check the single aggregation of companion functions?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @liujiayi771, that's good finding. Let's pass an additional companionStep to the initialize() method to make it clear what that aggregation step is doing.

@liujiayi771 liujiayi771 force-pushed the func-state branch 3 times, most recently from 69e3739 to 569ac66 Compare April 23, 2024 15:23
Copy link
Contributor

@kagamiori kagamiori left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @liujiayi771, thank you for refactoring the unit test! It looks much better now. The code looks mostly good to me. I left a few comments.

Comment on lines 132 to 137
// Initialize the function-level state of the simple function interface for
// UDAF.
// @param step The aggregation step.
// @param rawInputType The raw input type of the UDAF.
// @param resultType The result type of the UDAF.
// @param constantInputs Optional constant inputs.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please update the comment to include companionStep as well and explain what constant inputs mean. resultType is the result type of the current aggregation step, not necessarily the entire UDAF.

rawInputType,
resultType,
constantInputs,
core::AggregationNode::Step::kIntermediate);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's also override the initialize() method of MergeExtractFunction to pass kFinal as the companion step.

rawInputTypes,
outputType,
constantInputs,
core::AggregationNode::Step::kIntermediate);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's pass kFinal as the companion step too.

@@ -374,6 +391,7 @@ class AggregateWindowFunction : public exec::WindowFunction {
std::vector<TypePtr> argTypes_;
std::vector<column_index_t> argIndices_;
std::vector<VectorPtr> argVectors_;
std::vector<VectorPtr> constantInputs_;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add a comment explaining what constantInputs_ is.

Comment on lines 529 to 533
VELOX_CHECK(std::equal(
rawInputTypes.begin(),
rawInputTypes.end(),
expectedRawInputTypes.begin(),
expectedRawInputTypes.end()));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Would VELOX_CHECK(rawInputTypes == expectedRawInputTypes) work?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, this can work.

const TypePtr& resultType,
const core::QueryConfig& /*config*/)
-> std::unique_ptr<exec::Aggregate> {
VELOX_CHECK_LE(argTypes.size(), 2, "{} takes 2 argument", name);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: "{} takes at most 2 arguments"

SimpleAggregateAdapter<FunctionStateTestAggregate<testCompanion>>>(
resultType);
},
true /*registerCompanionFunctions*/,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: we only need to register companion functions when testCompanion is true.

@@ -471,5 +484,233 @@ TEST_F(SimpleCountNullsAggregationTest, basic) {
testAggregations({vectors}, {}, {"simple_count_nulls(c2)"}, {expected});
}

// A testing aggregation function that uses the function state.
template <bool testCompanion>
class FunctionStateTestAggregate {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's add a comment about what this UDAF does, e.g., the checks for expectations in initialize and the weighted average it calculates for input values.

WindowTestBase::testWindowFunction(
{inputVectors},
"simple_function_state_agg_main(c0, 1)",
{"partition by c0"},
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's not partition by c0 since this makes the operation trivial. You can add an addition column to the input and partition by that.

Copy link
Contributor

@kagamiori kagamiori left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM except a few nits.

@mbasmanova, could you also help take a look since this PR adds an initialize() method in the Aggregate class?

Comment on lines 155 to 156
// If UDAF does not require the use of FunctionState, it is necessary
// to declare an empty FunctionState struct.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Define a struct for function-level states. Even if the aggregation function doesn't use function-level states, it is still necessary to define an empty FunctionState struct.

TypePtr resultType;
};

// Optional. Used only when the UDAF needs to use FunctionState.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's add: This method is called once when the aggregation function is created.

@@ -169,6 +187,15 @@ function's argument type(s) wrapped in a Row<> even if the function only takes
one argument. This is needed for the SimpleAggregateAdapter to parse input
types for arbitrary aggregation functions properly.

A FunctionState struct needs to be declared in the simple aggregation function
class, it is used to hold the function-level variables that are typically
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: ...in the simple aggregation function class. FunctionState is initialized once when the aggregation function is created and used at every row when adding inputs to accumulators or extracting values from accumulators...

extracting values from accumulators. For example, if the UDAF needs to get the
result type or the raw input type of the aggregaiton function, the author can
hold them in the FunctionState struct, and initialize them in the initialize()
method. If the UDAF does not require the use ofFunctionState, it is necessary
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Let's use aggregation function instead of UDAF in this documentation to be consistent with the existing content.

ofFunctionState --> of FunctionState

Copy link
Contributor

@mbasmanova mbasmanova left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@liujiayi771 Would it be possible to update the PR description to provide more context? Perhaps, describe a specific use case and how it requires these changes and how it will be using them.

@@ -152,13 +152,33 @@ A simple aggregation function is implemented as a class as the following.
using IntermediateType = Array<Generic<T1>>;
using OutputType = Array<Generic<T1>>;

// Define a struct for function-level states. Even if the aggregation function
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we make 'FunctionState' struct optional?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need to pass FunctionState as an input parameter for methods like addInput, combine, etc. If it's not defined, It will result in a compilation failure.

const std::vector<TypePtr>& rawInputTypes,
const TypePtr& resultType,
const std::vector<VectorPtr>& constantInputs,
std::optional<core::AggregationNode::Step> companionStep) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you document all the parameters? What is 'companionStep'? It seems strange that we have such a parameter as function implementations should be agnostic to whether they are used as "regular" or a "companion" function.

CC: @kagamiori

Copy link
Contributor Author

@liujiayi771 liujiayi771 Apr 26, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have added document for all parameters in the definition of initialize in Aggregate.h. Should I also add these comments in the rst documentation?

For the partial companion function, we need to know that its companion step is kPartial, while the agg function itself includes kPartial, kFinal, or kSingle steps.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@liujiayi771 Thank you for updating PR description. It sounds like we want to allow stateful simple aggregate functions. This makes sense, but I wonder if we can make it work similar to Simple Function API for scalar functions. There, function author defines a struct with call-once initialize and call-per-row call methods. The author is then free to add member variables to hold state and initialize it however they want from 'initialize'. Would it make sense to follow this pattern for aggregate functions as well?

The result type is currently not exposed to author-defined accumulator type in the simple function interface.

Can we expose this?

Some functions perform heavy computation on a constant argument before
processing all input rows, such as approx_most_frequent, and store the
computation result in a function-level state.

Would you clarify what is "heavy computation" done by approx_most_frequent to help readers understand a bit better?

function-level state.

I wonder if a more accurate term would be "per-instance state". There is only one function, foo, but there are many instances of 'foo' and each has its own state, no?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @mbasmanova,

but I wonder if we can make it work similar to Simple Function API for scalar functions. There, function author defines a struct with call-once initialize and call-per-row call methods. The author is then free to add member variables to hold state and initialize it however they want from 'initialize'.

The SimpleAggregateAdapter currently doesn't hold an instance of the user-defined simple UDAF class (i.e., it only creates instances of the AccumulatorType struct inside the UDAF class). We can change SimpleAggregateAdapter to hold an instance of the UDAF class if we want to allow authors to freely access member variables in it.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can change SimpleAggregateAdapter to hold an instance of the UDAF class if we want to allow authors to freely access member variables in it.

This would be nice. Thanks.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @kagamiori, I will try to understand this method and see how it can be modified.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @liujiayi771, sorry for the delay. Here is a code pointer of how the simple scalar function interface calls the initialize() method:

return (*fn_).initialize(inputTypes, config, values...);

The fn_ here is a std::unique_ptr of the UDF class. Because this UDF instance is created, the author can add function-level states as member variables in the UDF class and the UDF's initialize() and call() member methods can access them directly. An example UDF that uses initialize() is below.

FOLLY_ALWAYS_INLINE void initialize(

What @mbasmanova suggested is that we can do it similar in the SimpleAggregateAdapter so that the UDAF authors doesn't have to keep function-level states. Specifically, below is what I’m thinking:

  1. The UDAF author doesn’t define a FunctionState struct, but rather add function-level variables as data members in the UDAF class (outside of its AccumulatorType struct).
  2. The UDAF class has an initialize() method that receives the aggregation step, the types, and the constantInput, and assigns values to the data members in the UDAF class.
  3. The AccumulatorType struct has a data member that is a pointer to the UDAF class. This would allow member methods inside AccumulatorType to access data members in the UDAF class. This UDAF-pointer can be set in SimpleAggregateAdapter::initializeNewGroupsInternal().

I’ll try to make a prototype to see if this works. Let’s discuss and review this design before coding in #8711 first.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for your reply @kagamiori. Let's discuss in #8711 further once the prototype has been validated for feasibility.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @kagamiori. Do you have time to make a prototype for this design. cc @rui-mo.

@@ -129,6 +129,27 @@ class Aggregate {
rowSizeOffset);
}

// Initialize the function-level state of the simple function interface for
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm... this API should not be aware of Simple Function Interface... looks like there might be some leak in the design.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @mbasmanova, thank you for the feedback. I saw that for simple scalar functions, we call initialize() in the constructor of SimpleFunctionAdapter. It can do this because ExprCompiler passes constantInputs to the function factory. The Aggregate::create() and aggregation function factory currently do not receive constantInputs as an argument. What about we pass constantInputs to them and move the call of initialize() into the constructor of SimpleAggregateAdapter? We'll have to pass constantInputs to all aggregation function factories though, since we cannot tell simple UDAFs from regular UDAFs apart in the function registry.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The Aggregate::create() and aggregation function factory currently do not receive constantInputs as an argument. What about we pass constantInputs to them and move the call of initialize() into the constructor of SimpleAggregateAdapter?

That sounds good. Thanks.

@liujiayi771
Copy link
Contributor Author

Would it be possible to update the PR description to provide more context? Perhaps, describe a specific use case and how it requires these changes and how it will be using them.

I think the explanation in the #8711 is already very detailed. I will copy the content into the description of this PR.

CC: @kagamiori

@liujiayi771
Copy link
Contributor Author

@kagamiori Gentle ping. Do you have time to make a prototype for this design?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants