Skip to content
This repository has been archived by the owner on Aug 2, 2022. It is now read-only.

Commit

Permalink
Fix CAST bool field to integer issue (#600)
Browse files Browse the repository at this point in the history
* Test painless script

* Support cast bool to numeric value

* Add IT for group by cast alias

* Add comment
  • Loading branch information
dai-chen authored Jul 22, 2020
1 parent e8f0539 commit 7372a44
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ public class SQLFunctionsIT extends SQLIntegTestCase {
@Override
protected void init() throws Exception {
loadIndex(Index.ACCOUNT);
loadIndex(Index.BANK);
loadIndex(Index.ONLINE);
loadIndex(Index.DATE);
}
Expand Down Expand Up @@ -369,6 +370,54 @@ public void castFieldToDatetimeWithGroupByJdbcFormatTest() {
rows("2019-09-25T02:04:13.469Z"));
}

@Test
public void castBoolFieldToNumericValueInSelectClause() {
JSONObject response =
executeJdbcRequest(
"SELECT "
+ " male, "
+ " CAST(male AS INT) AS cast_int, "
+ " CAST(male AS LONG) AS cast_long, "
+ " CAST(male AS FLOAT) AS cast_float, "
+ " CAST(male AS DOUBLE) AS cast_double "
+ "FROM " + TestsConstants.TEST_INDEX_BANK + " "
+ "WHERE account_number = 1 OR account_number = 13"
);

verifySchema(response,
schema("male", "boolean"),
schema("cast_int", "integer"),
schema("cast_long", "long"),
schema("cast_float", "float"),
schema("cast_double", "double")
);
verifyDataRows(response,
rows(true, 1, 1, 1, 1),
rows(false, 0, 0, 0, 0)
);
}

@Test
public void castBoolFieldToNumericValueWithGroupByAlias() {
JSONObject response =
executeJdbcRequest(
"SELECT "
+ "CAST(male AS INT) AS cast_int, "
+ "COUNT(*) "
+ "FROM " + TestsConstants.TEST_INDEX_BANK + " "
+ "GROUP BY cast_int"
);

verifySchema(response,
schema("cast_int", "cast_int", "double"), //Type is double due to query plan fail to infer
schema("COUNT(*)", "integer")
);
verifyDataRows(response,
rows("0", 3),
rows("1", 4)
);
}

@Test
public void castStatementInWhereClauseGreaterThanTest() {
JSONObject response = executeJdbcRequest("SELECT balance FROM " + TEST_INDEX_ACCOUNT
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,11 @@ public static <T> void verifyOrder(JSONArray array, Matcher<T>... matchers) {
assertThat(objects, containsInRelativeOrder(matchers));
}

public static TypeSafeMatcher<JSONObject> schema(String expectedName,
String expectedType) {
return schema(expectedName, null, expectedType);
}

public static TypeSafeMatcher<JSONObject> schema(String expectedName, String expectedAlias,
String expectedType) {
return new TypeSafeMatcher<JSONObject>() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -973,13 +973,10 @@ public String getCastScriptStatement(String name, String castType, List<KVValue>
String castFieldName = String.format("doc['%s'].value", paramers.get(0).toString());
switch (StringUtils.toUpper(castType)) {
case "INT":
return String.format("def %s = Double.parseDouble(%s.toString()).intValue()", name, castFieldName);
case "LONG":
return String.format("def %s = Double.parseDouble(%s.toString()).longValue()", name, castFieldName);
case "FLOAT":
return String.format("def %s = Double.parseDouble(%s.toString()).floatValue()", name, castFieldName);
case "DOUBLE":
return String.format("def %s = Double.parseDouble(%s.toString()).doubleValue()", name, castFieldName);
return getCastToNumericValueScript(name, castFieldName, StringUtils.toLower(castType));
case "STRING":
return String.format("def %s = %s.toString()", name, castFieldName);
case "DATETIME":
Expand All @@ -990,6 +987,14 @@ public String getCastScriptStatement(String name, String castType, List<KVValue>
}
}

private String getCastToNumericValueScript(String varName, String docValue, String targetType) {
String script =
"def %1$s = (%2$s instanceof boolean) "
+ "? (%2$s ? 1 : 0) "
+ ": Double.parseDouble(%2$s.toString()).%3$sValue()";
return StringUtils.format(script, varName, docValue, targetType);
}

/**
* Returns return type of script function. This is simple approach, that might be not the best solution in the long
* term. For example - for JDBC, if the column type in index is INTEGER, and the query is "select column+5", current
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import com.amazon.opendistroforelasticsearch.sql.legacy.executor.format.Schema;
import com.amazon.opendistroforelasticsearch.sql.legacy.utils.SQLFunctions;
import com.google.common.collect.ImmutableList;
import java.util.Arrays;
import org.elasticsearch.common.collect.Tuple;
import org.junit.Assert;
import org.junit.Rule;
Expand All @@ -39,11 +40,12 @@
import java.util.ArrayList;
import java.util.List;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;

public class SQLFunctionsTest {

private SQLFunctions sqlFunctions;
private SQLFunctions sqlFunctions = new SQLFunctions();

@Rule
public ExpectedException exceptionRule = ExpectedException.none();
Expand Down Expand Up @@ -96,4 +98,16 @@ public void testCastReturnType() {
final Schema.Type returnType = sqlFunctions.getScriptFunctionReturnType(field, resolvedType);
Assert.assertEquals(returnType, Schema.Type.INTEGER);
}

@Test
public void testCastIntStatementScript() throws SqlParseException {
assertEquals(
"def result = (doc['age'].value instanceof boolean) "
+ "? (doc['age'].value ? 1 : 0) "
+ ": Double.parseDouble(doc['age'].value.toString()).intValue()",
sqlFunctions.getCastScriptStatement(
"result", "int", Arrays.asList(new KVValue("age")))
);
}

}

0 comments on commit 7372a44

Please sign in to comment.