[FLINK-26474][hive] Fold exprNode to fix the issue of failing to call some hive udf required constant parameters with implicit constant passed #18975

Merged: 4 commits, Aug 31, 2022
@@ -2307,7 +2307,8 @@ private RelNode genSelectLogicalPlan(
             } else {
                 // Case when this is an expression
                 HiveParserTypeCheckCtx typeCheckCtx =
-                        new HiveParserTypeCheckCtx(inputRR, frameworkConfig, cluster);
+                        new HiveParserTypeCheckCtx(
+                                inputRR, true, true, frameworkConfig, cluster);
                 // We allow stateful functions in the SELECT list (but nowhere else)
                 typeCheckCtx.setAllowStatefulFunctions(true);
                 if (!qbp.getDestToGroupBy().isEmpty()) {
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@
import org.apache.hadoop.hive.common.type.Decimal128;
import org.apache.hadoop.hive.common.type.HiveChar;
import org.apache.hadoop.hive.common.type.HiveDecimal;
import org.apache.hadoop.hive.common.type.HiveIntervalDayTime;
import org.apache.hadoop.hive.common.type.HiveIntervalYearMonth;
import org.apache.hadoop.hive.common.type.HiveVarchar;
import org.apache.hadoop.hive.ql.ErrorMsg;
import org.apache.hadoop.hive.ql.exec.FunctionRegistry;
@@ -456,9 +458,21 @@ public static RexNode convertConstant(ExprNodeConstantDesc literal, RelOptCluste
             default:
                 if (hiveShim.isIntervalYearMonthType(hiveTypeCategory)) {
                     // Calcite year-month literal value is months as BigDecimal
-                    BigDecimal totalMonths =
-                            BigDecimal.valueOf(
-                                    ((HiveParserIntervalYearMonth) value).getTotalMonths());
+                    BigDecimal totalMonths;
+                    if (value instanceof HiveParserIntervalYearMonth) {
+                        totalMonths =
+                                BigDecimal.valueOf(
+                                        ((HiveParserIntervalYearMonth) value).getTotalMonths());
+                    } else if (value instanceof HiveIntervalYearMonth) {
+                        totalMonths =
+                                BigDecimal.valueOf(
+                                        ((HiveIntervalYearMonth) value).getTotalMonths());
+                    } else {
+                        throw new SemanticException(
+                                String.format(
+                                        "Unexpected class %s for Hive's interval year month type",
+                                        value.getClass().getName()));
+                    }
                     calciteLiteral =
                             rexBuilder.makeIntervalLiteral(
                                     totalMonths,
@@ -467,12 +481,30 @@
                 } else if (hiveShim.isIntervalDayTimeType(hiveTypeCategory)) {
                     // Calcite day-time interval is millis value as BigDecimal
                     // Seconds converted to millis
-                    BigDecimal secsValueBd =
-                            BigDecimal.valueOf(
-                                    ((HiveParserIntervalDayTime) value).getTotalSeconds() * 1000);
-                    // Nanos converted to millis
-                    BigDecimal nanosValueBd =
-                            BigDecimal.valueOf(((HiveParserIntervalDayTime) value).getNanos(), 6);
+                    BigDecimal secsValueBd;
+                    // Nanos converted to millis
+                    BigDecimal nanosValueBd;
+                    if (value instanceof HiveParserIntervalDayTime) {
+                        secsValueBd =
+                                BigDecimal.valueOf(
+                                        ((HiveParserIntervalDayTime) value).getTotalSeconds()
+                                                * 1000);
+                        nanosValueBd =
+                                BigDecimal.valueOf(
+                                        ((HiveParserIntervalDayTime) value).getNanos(), 6);
+                    } else if (value instanceof HiveIntervalDayTime) {
+                        secsValueBd =
+                                BigDecimal.valueOf(
+                                        ((HiveIntervalDayTime) value).getTotalSeconds() * 1000);
+                        nanosValueBd =
+                                BigDecimal.valueOf(((HiveIntervalDayTime) value).getNanos(), 6);
+                    } else {
+                        throw new SemanticException(
+                                String.format(
+                                        "Unexpected class %s for Hive's interval day time type.",
+                                        value.getClass().getName()));
+                    }
                     calciteLiteral =
                             rexBuilder.makeIntervalLiteral(
                                     secsValueBd.add(nanosValueBd),
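The two conversions above can be sketched with plain `java.math.BigDecimal`. This is an illustration of the arithmetic only, not Flink's code; the interval values (1 year 2 months, and 90061 seconds plus 500000000 nanos) are assumed example inputs:

```java
import java.math.BigDecimal;

public class IntervalLiteralDemo {
    public static void main(String[] args) {
        // Year-month: Calcite wants the total month count as a BigDecimal.
        // For an assumed INTERVAL '1-2' YEAR TO MONTH:
        BigDecimal totalMonths = BigDecimal.valueOf(1 * 12 + 2);
        System.out.println(totalMonths); // 14

        // Day-time: Calcite wants total milliseconds as a BigDecimal.
        // Assumed example: 90061 total seconds (1d 1h 1m 1s) plus 500000000 nanos.
        BigDecimal secsValueBd = BigDecimal.valueOf(90061L * 1000);
        // valueOf(n, 6) reads n with scale 6, i.e. nanos / 10^6 = millis
        BigDecimal nanosValueBd = BigDecimal.valueOf(500_000_000, 6);
        System.out.println(secsValueBd.add(nanosValueBd)); // 90061500.000000
    }
}
```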
@@ -77,6 +77,7 @@
 import org.apache.hadoop.hive.ql.udf.generic.GenericUDFNvl;
 import org.apache.hadoop.hive.ql.udf.generic.GenericUDFOPAnd;
 import org.apache.hadoop.hive.ql.udf.generic.GenericUDFOPEqual;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDFOPNegative;
 import org.apache.hadoop.hive.ql.udf.generic.GenericUDFOPNot;
 import org.apache.hadoop.hive.ql.udf.generic.GenericUDFOPOr;
 import org.apache.hadoop.hive.ql.udf.generic.GenericUDFWhen;
@@ -1257,11 +1258,31 @@ protected ExprNodeDesc getXpathOrFuncExprNodeDesc(
         if (FunctionRegistry.isOpPositive(desc)) {
             assert (desc.getChildren().size() == 1);
             desc = desc.getChildren().get(0);
+        } else if (getGenericUDFClassFromExprDesc(desc) == GenericUDFOPNegative.class) {
+            // UDFOPNegative should always be folded.
+            assert (desc.getChildren().size() == 1);
+            ExprNodeDesc input = desc.getChildren().get(0);
+            if (input instanceof ExprNodeConstantDesc
+                    && desc instanceof ExprNodeGenericFuncDesc) {
+                ExprNodeDesc constantExpr =
+                        ConstantPropagateProcFactory.foldExpr((ExprNodeGenericFuncDesc) desc);
+                if (constantExpr != null) {
+                    desc = constantExpr;
+                }
+            }
         }
         assert (desc != null);
         return desc;
     }

+    private Class<? extends GenericUDF> getGenericUDFClassFromExprDesc(ExprNodeDesc desc) {
+        if (!(desc instanceof ExprNodeGenericFuncDesc)) {
+            return null;
+        }
+        ExprNodeGenericFuncDesc genericFuncDesc = (ExprNodeGenericFuncDesc) desc;
+        return genericFuncDesc.getGenericUDF().getClass();
+    }
+
     // try to create an ExprNodeDesc with a SqlOperator
     private ExprNodeDesc convertSqlOperator(
             String funcText, List<ExprNodeDesc> children, HiveParserTypeCheckCtx ctx)
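The change above delegates the actual folding to Hive's `ConstantPropagateProcFactory`. As a rough, hypothetical model of what that fold does for `GenericUDFOPNegative` (the `Expr`/`Const`/`Func` types below are invented for illustration, not Hive's classes): a unary minus applied to a constant child, e.g. the `-1` in `bround(55.0, -1)`, is collapsed into a single constant node, so UDFs that require a constant parameter can see one:

```java
import java.util.List;

interface Expr {}
record Const(Object value) implements Expr {}
record Func(String name, List<Expr> children) implements Expr {}

public class FoldDemo {
    // Collapse a unary minus over an integer constant into one constant node.
    static Expr fold(Expr e) {
        if (e instanceof Func f
                && f.name().equals("-")
                && f.children().size() == 1
                && f.children().get(0) instanceof Const c
                && c.value() instanceof Integer i) {
            return new Const(-i); // fold "-(1)" into the constant -1
        }
        return e; // anything else is left untouched
    }

    public static void main(String[] args) {
        Expr e = new Func("-", List.of(new Const(1)));
        System.out.println(((Const) fold(e)).value()); // -1
    }
}
```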
@@ -788,7 +788,7 @@ public void testCastTimeStampToDecimal() throws Exception {
                                         timestamp))
                         .collect());
         assertThat(results.toString())
-                .isEqualTo(String.format("[+I[%s]]", expectTimeStampDecimal.toFormatString(8)));
+                .isEqualTo(String.format("[+I[%s]]", expectTimeStampDecimal));
Member:
Is there any reason for this change? IIUC, the toFormatString(8) is on purpose because it is cast to decimal(30,8).

luoyuxia (Contributor, Author), Aug 31, 2022:
Yes, it is. But there's a special case when it comes to casting a constant with constant folding enabled. Actually, the current behavior is the same as Hive's.
I tried the following SQL in Hive:
I try with the following sql in Hive:

hive> select cast(cast('2012-12-19 11:12:19.1234567' as timestamp) as decimal(30,8));
1355915539.1234567

hive> insert into t2 values('2012-12-19 11:12:19.1234567')

hive> select  cast(c2 as decimal(30, 8)) from t2;
1355915539.12345670

hive > insert into t1 select * from t2;

hive > select * from t1;
1355915539.12345670

The plan in Hive for the SQL select cast(cast('2012-12-19 11:12:19.1234567' as timestamp) as decimal(30,8)) is:

STAGE PLANS:
  Stage: Stage-0
    Fetch Operator
      limit: -1
      Processor Tree:
        TableScan
          alias: _dummy_table
          Row Limit Per Split: 1
          Statistics: Num rows: 1 Data size: 10 Basic stats: COMPLETE Column stats: COMPLETE
          Select Operator
            expressions: 1355915539.1234567 (type: decimal(30,8))
            outputColumnNames: _col0
            Statistics: Num rows: 1 Data size: 112 Basic stats: COMPLETE Column stats: COMPLETE
            ListSink

The plan for select cast(c1 as decimal(30, 8)) from t1 is:

STAGE DEPENDENCIES:
  Stage-0 is a root stage

STAGE PLANS:
  Stage: Stage-0
    Fetch Operator
      limit: -1
      Processor Tree:
        TableScan
          alias: t1
          Statistics: Num rows: 1 Data size: 112 Basic stats: COMPLETE Column stats: NONE
          Select Operator
            expressions: CAST( c1 AS decimal(30,8)) (type: decimal(30,8))
            outputColumnNames: _col0
            Statistics: Num rows: 1 Data size: 112 Basic stats: COMPLETE Column stats: NONE
            ListSink

The reason I found is that the HiveDecimalConverter used to convert data in Hive's GenericUDFToDecimal function actually won't pad zeros for 2012-12-19 11:12:19.1234567, although the type is decimal(30,8).
Then, the first SQL will select the constant 1355915539.1234567.
But for the second SQL, a further padding will be done, which results in 1355915539.12345670.
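The padding difference described here can be reproduced with plain `java.math.BigDecimal` as an analogy (this is not `HiveDecimal` itself, only a sketch using the value from the discussion): the folded constant keeps the literal's own scale of 7, while a runtime conversion to decimal(30,8) pads to scale 8:

```java
import java.math.BigDecimal;

public class DecimalPaddingDemo {
    public static void main(String[] args) {
        // Constant-folded path: the literal keeps its own 7 fractional digits.
        BigDecimal folded = new BigDecimal("1355915539.1234567");
        // Runtime-cast path: the value is padded out to the declared scale 8.
        BigDecimal padded = folded.setScale(8);
        System.out.println(folded); // 1355915539.1234567
        System.out.println(padded); // 1355915539.12345670
    }
}
```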

Member:
Thank you for the explanation.


// test insert timestamp type to decimal type directly
tableEnv.executeSql("create table t1 (c1 DECIMAL(38,6))");
@@ -0,0 +1,17 @@
-- SORT_QUERY_RESULTS

select bround(55.0, -1);
Member:
Is there any test which the second parameter is positive?

luoyuxia (Contributor, Author), Aug 19, 2022:
I added a query select bround(55.0, +1) to cover a positive number with a + sign.
For a positive number without a + sign, I think the query select percentile_approx(x, 0.5) from foo in udaf.q has already covered it.


[+I[60]]

select bround(55.0, +1);

[+I[55]]

select round(123.45, -2);

[+I[100]]

select sha2('ABC', cast(null as int));

[+I[null]]
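For reference, Hive's bround is banker's rounding (HALF_EVEN) and round is HALF_UP; a negative second argument rounds to the left of the decimal point. The queries above can be sketched with `java.math.BigDecimal` as an analogy (not Hive's implementation):

```java
import java.math.BigDecimal;
import java.math.RoundingMode;

public class BroundDemo {
    public static void main(String[] args) {
        BigDecimal v = new BigDecimal("55.0");
        // bround(55.0, -1): 5.5 tens, half rounds to the even digit 6 -> 60
        System.out.println(v.setScale(-1, RoundingMode.HALF_EVEN).toPlainString()); // 60
        // bround(55.0, +1): already at scale 1, so the value is unchanged -> 55.0
        System.out.println(v.setScale(1, RoundingMode.HALF_EVEN).toPlainString()); // 55.0
        // round(123.45, -2): HALF_UP at the hundreds place -> 100
        System.out.println(
                new BigDecimal("123.45")
                        .setScale(-2, RoundingMode.HALF_UP)
                        .toPlainString()); // 100
    }
}
```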