Skip to content
This repository has been archived by the owner on Aug 2, 2022. It is now read-only.

Commit

Permalink
Fixed interval type null/missing check failure (#1011)
Browse files Browse the repository at this point in the history
* fixed issue #991

* update

* addressed comments
  • Loading branch information
chloe-zh authored Jan 28, 2021
1 parent 959458b commit 4ec31b0
Show file tree
Hide file tree
Showing 5 changed files with 68 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.STRING;
import static com.amazon.opendistroforelasticsearch.sql.expression.function.FunctionDSL.define;
import static com.amazon.opendistroforelasticsearch.sql.expression.function.FunctionDSL.impl;
import static com.amazon.opendistroforelasticsearch.sql.expression.function.FunctionDSL.nullMissingHandling;

import com.amazon.opendistroforelasticsearch.sql.data.model.ExprIntervalValue;
import com.amazon.opendistroforelasticsearch.sql.data.model.ExprValue;
Expand Down Expand Up @@ -54,8 +55,8 @@ public void register(BuiltinFunctionRepository repository) {

private FunctionResolver interval() {
return define(BuiltinFunctionName.INTERVAL.getName(),
impl(IntervalClause::interval, INTERVAL, INTEGER, STRING),
impl(IntervalClause::interval, INTERVAL, LONG, STRING));
impl(nullMissingHandling(IntervalClause::interval), INTERVAL, INTEGER, STRING),
impl(nullMissingHandling(IntervalClause::interval), INTERVAL, LONG, STRING));
}

private ExprValue interval(ExprValue value, ExprValue unit) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,13 @@
package com.amazon.opendistroforelasticsearch.sql.expression.datetime;

import static com.amazon.opendistroforelasticsearch.sql.data.model.ExprValueUtils.intervalValue;
import static com.amazon.opendistroforelasticsearch.sql.data.model.ExprValueUtils.missingValue;
import static com.amazon.opendistroforelasticsearch.sql.data.model.ExprValueUtils.nullValue;
import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.INTEGER;
import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.INTERVAL;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.Mockito.when;

import com.amazon.opendistroforelasticsearch.sql.data.model.ExprValue;
import com.amazon.opendistroforelasticsearch.sql.exception.ExpressionEvaluationException;
Expand All @@ -39,6 +43,12 @@ public class IntervalClauseTest extends ExpressionTestBase {
@Mock
Environment<Expression, ExprValue> env;

@Mock
Expression nullRef;

@Mock
Expression missingRef;

@Test
public void microsecond() {
FunctionExpression expr = dsl.interval(DSL.literal(1), DSL.literal("microsecond"));
Expand Down Expand Up @@ -114,4 +124,22 @@ public void to_string() {
FunctionExpression expr = dsl.interval(DSL.literal(1), DSL.literal("day"));
assertEquals("interval(1, \"day\")", expr.toString());
}

@Test
public void null_value() {
when(nullRef.type()).thenReturn(INTEGER);
when(nullRef.valueOf(env)).thenReturn(nullValue());
FunctionExpression expr = dsl.interval(nullRef, DSL.literal("day"));
assertEquals(INTERVAL, expr.type());
assertEquals(nullValue(), expr.valueOf(env));
}

@Test
public void missing_value() {
when(missingRef.type()).thenReturn(INTEGER);
when(missingRef.valueOf(env)).thenReturn(missingValue());
FunctionExpression expr = dsl.interval(missingRef, DSL.literal("day"));
assertEquals(INTERVAL, expr.type());
assertEquals(missingValue(), expr.valueOf(env));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,14 @@ public void testNullLiteralInFunction() {
rows(null, null));
}

@Test
public void testNullLiteralInInterval() {
verifyDataRows(
query("SELECT INTERVAL NULL DAY, INTERVAL 60 * 60 * 24 * (NULL - FLOOR(NULL)) SECOND"),
rows(null, null)
);
}

private JSONObject query(String sql) {
return new JSONObject(executeQuery(sql, "jdbc"));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,7 @@
import com.amazon.opendistroforelasticsearch.sql.ast.Node;
import com.amazon.opendistroforelasticsearch.sql.ast.expression.AggregateFunction;
import com.amazon.opendistroforelasticsearch.sql.ast.expression.Alias;
import com.amazon.opendistroforelasticsearch.sql.ast.expression.Function;
import com.amazon.opendistroforelasticsearch.sql.ast.expression.Literal;
import com.amazon.opendistroforelasticsearch.sql.ast.expression.QualifiedName;
import com.amazon.opendistroforelasticsearch.sql.ast.expression.UnresolvedExpression;
import com.amazon.opendistroforelasticsearch.sql.ast.tree.Aggregation;
import com.amazon.opendistroforelasticsearch.sql.ast.tree.UnresolvedPlan;
Expand Down Expand Up @@ -125,36 +124,26 @@ private List<UnresolvedExpression> replaceGroupByItemIfAliasOrOrdinal() {
*/
private Optional<UnresolvedExpression> findNonAggregatedItemInSelect() {
return querySpec.getSelectItems().stream()
.filter(this::isNonAggregatedExpression)
.filter(this::isNonLiteralFunction)
.filter(this::isNonAggregateOrLiteralExpression)
.findFirst();
}

private boolean isAggregatorNotFoundAnywhere() {
return querySpec.getAggregators().isEmpty();
}

private boolean isNonLiteralFunction(UnresolvedExpression expr) {
// The base case for recursion
if (expr instanceof Literal) {
private boolean isNonAggregateOrLiteralExpression(UnresolvedExpression expr) {
if (expr instanceof AggregateFunction) {
return false;
}
if (expr instanceof Function) {
List<? extends Node> children = expr.getChild();
return children.stream().anyMatch(child ->
isNonLiteralFunction((UnresolvedExpression) child));
}
return true;
}

private boolean isNonAggregatedExpression(UnresolvedExpression expr) {
if (expr instanceof AggregateFunction) {
return false;
if (expr instanceof QualifiedName) {
return true;
}

List<? extends Node> children = expr.getChild();
return children.stream()
.allMatch(child -> isNonAggregatedExpression((UnresolvedExpression) child));
return children.stream().anyMatch(child ->
isNonAggregateOrLiteralExpression((UnresolvedExpression) child));
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,27 @@ void can_build_implicit_group_by_for_aggregator_in_having_clause() {
hasGroupByItems(),
hasAggregators(
alias("AVG(age)", aggregate("AVG", qualifiedName("age"))))));

assertThat(
buildAggregation("SELECT INTERVAL 1 DAY FROM test HAVING AVG(age) > 30"),
allOf(
hasGroupByItems(),
hasAggregators(
alias("AVG(age)", aggregate("AVG", qualifiedName("age"))))));

assertThat(
buildAggregation("SELECT CAST(1 AS LONG) FROM test HAVING AVG(age) > 30"),
allOf(
hasGroupByItems(),
hasAggregators(
alias("AVG(age)", aggregate("AVG", qualifiedName("age"))))));

assertThat(
buildAggregation("SELECT CASE WHEN true THEN 1 ELSE 2 END FROM test HAVING AVG(age) > 30"),
allOf(
hasGroupByItems(),
hasAggregators(
alias("AVG(age)", aggregate("AVG", qualifiedName("age"))))));
}

@Test
Expand Down

0 comments on commit 4ec31b0

Please sign in to comment.