diff --git a/build.sbt b/build.sbt index 151af7a67c2..8068bbab8d0 100644 --- a/build.sbt +++ b/build.sbt @@ -68,6 +68,7 @@ val hadoopVersion = "3.3.4" val scalaTestVersion = "3.2.15" val scalaTestVersionForConnectors = "3.0.8" val parquet4sVersion = "1.9.4" +val icu4jVersion = "75.1" // Versions for Hive 3 val hadoopVersionForHive3 = "3.1.0" @@ -656,6 +657,7 @@ lazy val kernelDefaults = (project in file("kernel/kernel-defaults")) "org.apache.hadoop" % "hadoop-client-runtime" % hadoopVersion, "com.fasterxml.jackson.core" % "jackson-databind" % "2.13.5", "org.apache.parquet" % "parquet-hadoop" % "1.12.3", + "com.ibm.icu" % "icu4j" % icu4jVersion, "org.scalatest" %% "scalatest" % scalaTestVersion % "test", "junit" % "junit" % "4.13.2" % "test", diff --git a/connectors/golden-tables/src/main/resources/golden/data_skipping_basic_stats_collated_predicate/.part-00000-b839639e-0620-4d9c-baf8-20206fc2b063-c000.snappy.parquet.crc b/connectors/golden-tables/src/main/resources/golden/data_skipping_basic_stats_collated_predicate/.part-00000-b839639e-0620-4d9c-baf8-20206fc2b063-c000.snappy.parquet.crc new file mode 100644 index 00000000000..bbc0e0b700b Binary files /dev/null and b/connectors/golden-tables/src/main/resources/golden/data_skipping_basic_stats_collated_predicate/.part-00000-b839639e-0620-4d9c-baf8-20206fc2b063-c000.snappy.parquet.crc differ diff --git a/connectors/golden-tables/src/main/resources/golden/data_skipping_basic_stats_collated_predicate/.part-00000-fdd2f0a3-75ba-4b6d-85a3-03173742c909-c000.snappy.parquet.crc b/connectors/golden-tables/src/main/resources/golden/data_skipping_basic_stats_collated_predicate/.part-00000-fdd2f0a3-75ba-4b6d-85a3-03173742c909-c000.snappy.parquet.crc new file mode 100644 index 00000000000..c45a2045d3e Binary files /dev/null and b/connectors/golden-tables/src/main/resources/golden/data_skipping_basic_stats_collated_predicate/.part-00000-fdd2f0a3-75ba-4b6d-85a3-03173742c909-c000.snappy.parquet.crc differ diff --git a/connectors/golden-tables/src/main/resources/golden/data_skipping_basic_stats_collated_predicate/_delta_log/.00000000000000000000.json.crc b/connectors/golden-tables/src/main/resources/golden/data_skipping_basic_stats_collated_predicate/_delta_log/.00000000000000000000.json.crc new file mode 100644 index 00000000000..145dec7c2d7 Binary files /dev/null and b/connectors/golden-tables/src/main/resources/golden/data_skipping_basic_stats_collated_predicate/_delta_log/.00000000000000000000.json.crc differ diff --git a/connectors/golden-tables/src/main/resources/golden/data_skipping_basic_stats_collated_predicate/_delta_log/.00000000000000000001.json.crc b/connectors/golden-tables/src/main/resources/golden/data_skipping_basic_stats_collated_predicate/_delta_log/.00000000000000000001.json.crc new file mode 100644 index 00000000000..e2b9b3bbaa1 Binary files /dev/null and b/connectors/golden-tables/src/main/resources/golden/data_skipping_basic_stats_collated_predicate/_delta_log/.00000000000000000001.json.crc differ diff --git a/connectors/golden-tables/src/main/resources/golden/data_skipping_basic_stats_collated_predicate/_delta_log/.00000000000000000002.json.crc b/connectors/golden-tables/src/main/resources/golden/data_skipping_basic_stats_collated_predicate/_delta_log/.00000000000000000002.json.crc new file mode 100644 index 00000000000..c333c466fcb Binary files /dev/null and b/connectors/golden-tables/src/main/resources/golden/data_skipping_basic_stats_collated_predicate/_delta_log/.00000000000000000002.json.crc differ diff --git a/connectors/golden-tables/src/main/resources/golden/data_skipping_basic_stats_collated_predicate/_delta_log/00000000000000000000.json b/connectors/golden-tables/src/main/resources/golden/data_skipping_basic_stats_collated_predicate/_delta_log/00000000000000000000.json new file mode 100644 index 00000000000..7e7d0d3945b --- /dev/null +++ b/connectors/golden-tables/src/main/resources/golden/data_skipping_basic_stats_collated_predicate/_delta_log/00000000000000000000.json @@ -0,0 +1,3 @@ +{"commitInfo":{"timestamp":1725838990599,"operation":"CREATE TABLE","operationParameters":{"partitionBy":"[]","clusterBy":"[]","description":null,"isManaged":"true","properties":"{}"},"isolationLevel":"Serializable","isBlindAppend":true,"operationMetrics":{},"engineInfo":"Apache-Spark/3.5.2 Delta-Lake/3.2.0","txnId":"0c4e0d4d-0f13-4726-ad61-f10192ab81e2"}} +{"metaData":{"id":"9a7918b4-42a5-4b47-bf27-0e8a7289d654","format":{"provider":"parquet","options":{}},"schemaString":"{\"type\":\"struct\",\"fields\":[{\"name\":\"c1\",\"type\":\"string\",\"nullable\":true,\"metadata\":{}}]}","partitionColumns":[],"configuration":{},"createdTime":1725838990522}} +{"protocol":{"minReaderVersion":1,"minWriterVersion":2}} diff --git a/connectors/golden-tables/src/main/resources/golden/data_skipping_basic_stats_collated_predicate/_delta_log/00000000000000000001.json b/connectors/golden-tables/src/main/resources/golden/data_skipping_basic_stats_collated_predicate/_delta_log/00000000000000000001.json new file mode 100644 index 00000000000..b2f7b82b1da --- /dev/null +++ b/connectors/golden-tables/src/main/resources/golden/data_skipping_basic_stats_collated_predicate/_delta_log/00000000000000000001.json @@ -0,0 +1,2 @@ +{"commitInfo":{"timestamp":1725839016423,"operation":"WRITE","operationParameters":{"mode":"Append","partitionBy":"[]"},"readVersion":0,"isolationLevel":"Serializable","isBlindAppend":true,"operationMetrics":{"numFiles":"1","numOutputRows":"1","numOutputBytes":"440"},"engineInfo":"Apache-Spark/3.5.2 Delta-Lake/3.2.0","txnId":"94355dc1-e083-4da0-9934-716b093eaf3a"}} +{"add":{"path":"part-00000-b839639e-0620-4d9c-baf8-20206fc2b063-c000.snappy.parquet","partitionValues":{},"size":440,"modificationTime":1725839016378,"dataChange":true,"stats":"{\"numRecords\":1,\"minValues\":{\"c1\":\"a\"},\"maxValues\":{\"c1\":\"a\"},\"nullCount\":{\"c1\":0}}"}} diff --git a/connectors/golden-tables/src/main/resources/golden/data_skipping_basic_stats_collated_predicate/_delta_log/00000000000000000002.json b/connectors/golden-tables/src/main/resources/golden/data_skipping_basic_stats_collated_predicate/_delta_log/00000000000000000002.json new file mode 100644 index 00000000000..6a21f61d9c0 --- /dev/null +++ b/connectors/golden-tables/src/main/resources/golden/data_skipping_basic_stats_collated_predicate/_delta_log/00000000000000000002.json @@ -0,0 +1,2 @@ +{"commitInfo":{"timestamp":1725839020852,"operation":"WRITE","operationParameters":{"mode":"Append","partitionBy":"[]"},"readVersion":1,"isolationLevel":"Serializable","isBlindAppend":true,"operationMetrics":{"numFiles":"1","numOutputRows":"1","numOutputBytes":"440"},"engineInfo":"Apache-Spark/3.5.2 Delta-Lake/3.2.0","txnId":"0b39a4b8-878c-4092-a46e-c2bccfa9aaf3"}} +{"add":{"path":"part-00000-fdd2f0a3-75ba-4b6d-85a3-03173742c909-c000.snappy.parquet","partitionValues":{},"size":440,"modificationTime":1725839020847,"dataChange":true,"stats":"{\"numRecords\":1,\"minValues\":{\"c1\":\"A\"},\"maxValues\":{\"c1\":\"A\"},\"nullCount\":{\"c1\":0}}"}} diff --git a/connectors/golden-tables/src/main/resources/golden/data_skipping_basic_stats_collated_predicate/part-00000-b839639e-0620-4d9c-baf8-20206fc2b063-c000.snappy.parquet b/connectors/golden-tables/src/main/resources/golden/data_skipping_basic_stats_collated_predicate/part-00000-b839639e-0620-4d9c-baf8-20206fc2b063-c000.snappy.parquet new file mode 100644 index 00000000000..663e9680f6d Binary files /dev/null and b/connectors/golden-tables/src/main/resources/golden/data_skipping_basic_stats_collated_predicate/part-00000-b839639e-0620-4d9c-baf8-20206fc2b063-c000.snappy.parquet differ diff --git a/connectors/golden-tables/src/main/resources/golden/data_skipping_basic_stats_collated_predicate/part-00000-fdd2f0a3-75ba-4b6d-85a3-03173742c909-c000.snappy.parquet b/connectors/golden-tables/src/main/resources/golden/data_skipping_basic_stats_collated_predicate/part-00000-fdd2f0a3-75ba-4b6d-85a3-03173742c909-c000.snappy.parquet new file mode 100644 index 00000000000..f5c9efba0da Binary files /dev/null and b/connectors/golden-tables/src/main/resources/golden/data_skipping_basic_stats_collated_predicate/part-00000-fdd2f0a3-75ba-4b6d-85a3-03173742c909-c000.snappy.parquet differ diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/CollatedPredicate.java b/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/CollatedPredicate.java new file mode 100644 index 00000000000..adc46f2568b --- /dev/null +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/CollatedPredicate.java @@ -0,0 +1,22 @@ +package io.delta.kernel.expressions; + +public class CollatedPredicate extends Predicate { + public CollatedPredicate(String name, Expression left, Expression right, CollationIdentifier collationIdentifier) { + super(name, left, right); + this.collationIdentifier = collationIdentifier; + } + + public CollationIdentifier getCollationIdentifier() { + return collationIdentifier; + } + + private final CollationIdentifier collationIdentifier; + + @Override + public String toString() { + if (BINARY_OPERATORS.contains(name)) { + return String.format("(%s %s %s [%s])", children.get(0), name, children.get(1), collationIdentifier); + } + return super.toString(); + } +} diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/CollationIdentifier.java b/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/CollationIdentifier.java new file mode 100644 index 00000000000..98c37e66cd2 --- /dev/null +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/CollationIdentifier.java @@ -0,0 +1,102 @@ +/* + * Copyright (2024) The Delta Lake Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.delta.kernel.expressions; + +import java.util.Optional; + +public class CollationIdentifier { + public static final String PROVIDER_SPARK = "SPARK"; + public static final String PROVIDER_ICU = "ICU"; + + public static final String ICU_COLLATOR_VERSION = "75.1"; + + public static final String DEFAULT_COLLATION_NAME = "UTF8_BINARY"; + + public static final CollationIdentifier DEFAULT_COLLATION_IDENTIFIER = + new CollationIdentifier(PROVIDER_SPARK, DEFAULT_COLLATION_NAME); + + private final String provider; + private final String name; + private final Optional version; + + public CollationIdentifier(String provider, String collationName) { + this.provider = provider.toUpperCase(); + this.name = collationName.toUpperCase(); + this.version = Optional.empty(); + } + + public CollationIdentifier(String provider, String collationName, Optional version) { + this.provider = provider.toUpperCase(); + this.name = collationName.toUpperCase(); + if (version.isPresent()) { + this.version = Optional.of(version.get().toUpperCase()); + } else { + this.version = Optional.empty(); + } + } + + public String toStringWithoutVersion() { + return String.format("%s.%s", provider, name); + } + + public String getProvider() { + return provider; + } + + public String getName() { + return name; + } + + // Returns Optional.empty() + public Optional getVersion() { + return version; + } + + public static CollationIdentifier fromString(String identifier) { + long numDots = identifier.chars().filter(ch -> ch == '.').count(); + if (numDots == 0) { + throw new IllegalArgumentException( + String.format("Invalid collation identifier: %s", identifier)); + } else if (numDots == 1) { + String[] parts = identifier.split("\\."); + return new CollationIdentifier(parts[0], parts[1]); + } else { + String[] parts = identifier.split("\\.", 3); + return new CollationIdentifier(parts[0], parts[1], Optional.of(parts[2])); + } + } + + @Override + public boolean equals(Object o) { + if (!(o instanceof CollationIdentifier)) { + return false; + } + + CollationIdentifier other = (CollationIdentifier) o; + return this.provider.equals(other.provider) + && this.name.equals(other.name) + && this.version.equals(other.version); + } + + @Override + public String toString() { + if (version.isPresent()) { + return String.format("%s.%s.%s", provider, name, version.get()); + } else { + return String.format("%s.%s", provider, name); + } + } +} diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/Predicate.java b/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/Predicate.java index 2e2966e8614..d806a7b9bf9 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/Predicate.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/Predicate.java @@ -129,6 +129,6 @@ public String toString() { return super.toString(); } - private static final Set BINARY_OPERATORS = + protected static final Set BINARY_OPERATORS = Stream.of("<", "<=", ">", ">=", "=", "AND", "OR").collect(Collectors.toSet()); } diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/ScanImpl.java b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/ScanImpl.java index 96e93cd02e6..40913cb49c3 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/ScanImpl.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/ScanImpl.java @@ -255,9 +255,12 @@ public void close() throws IOException { private Optional getDataSkippingFilter() { return getDataFilters() .flatMap( - dataFilters -> - DataSkippingUtils.constructDataSkippingFilter( - dataFilters, metadata.getDataSchema())); + dataFilters -> { + dataFilters = DataSkippingUtils.omitCollatedPredicateFromDataSkippingFilter(dataFilters); + return DataSkippingUtils.constructDataSkippingFilter( + dataFilters, metadata.getDataSchema()); + } + ); } private CloseableIterator applyDataSkipping( diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/skipping/DataSkippingUtils.java b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/skipping/DataSkippingUtils.java index d93bb710286..d00ba159610 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/skipping/DataSkippingUtils.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/skipping/DataSkippingUtils.java @@ -15,6 +15,7 @@ */ package io.delta.kernel.internal.skipping; +import static io.delta.kernel.internal.DeltaErrors.timestampAfterLatestCommit; import static io.delta.kernel.internal.DeltaErrors.wrapEngineException; import static io.delta.kernel.internal.InternalScanFileUtils.ADD_FILE_ORDINAL; import static io.delta.kernel.internal.InternalScanFileUtils.ADD_FILE_STATS_ORDINAL; @@ -75,6 +76,153 @@ public static Optional constructDataSkippingFilter( return constructDataSkippingFilter(dataFilters, schemaHelper); } + public static Predicate omitCollatedPredicateFromDataSkippingFilter(Predicate dataFilters) { + return omitCollatedPredicateFromDataSkippingFilter(dataFilters, false)._1; + } + + /** + * TODO + * + * @param dataFilters + * @param isNotPropagated + * @return + */ + private static Tuple2 omitCollatedPredicateFromDataSkippingFilter(Predicate dataFilters, + boolean isNotPropagated) { + if (dataFilters instanceof CollatedPredicate) { + return new Tuple2<>(AlwaysTrue.ALWAYS_TRUE, true); + } + + String predicateName = dataFilters.getName().toUpperCase(Locale.ROOT); + if (isNotPropagated && REVERSE_PREDICATE.containsKey(predicateName)) { + predicateName = REVERSE_PREDICATE.get(predicateName); + } + + switch (predicateName) { + case "AND": + Predicate leftPredicate = asPredicate(getLeft(dataFilters)); + Predicate rightPredicate = asPredicate(getRight(dataFilters)); + Tuple2 leftResult = omitCollatedPredicateFromDataSkippingFilter(leftPredicate, isNotPropagated); + Tuple2 rightResult = omitCollatedPredicateFromDataSkippingFilter(rightPredicate, isNotPropagated); + boolean hasCollatedPredicate = leftResult._2 || rightResult._2; + Predicate resultingPredicate = AlwaysTrue.ALWAYS_TRUE; + if (leftResult._1 != AlwaysTrue.ALWAYS_TRUE) { + resultingPredicate = leftResult._1; + } + if (rightResult._1 != AlwaysTrue.ALWAYS_TRUE) { + if (resultingPredicate == AlwaysTrue.ALWAYS_TRUE) { + resultingPredicate = rightResult._1; + } else { + resultingPredicate = new And(leftResult._1, rightResult._1); + } + } + return new Tuple2<>(resultingPredicate, hasCollatedPredicate); + case "OR": + leftPredicate = asPredicate(getLeft(dataFilters)); + rightPredicate = asPredicate(getRight(dataFilters)); + leftResult = omitCollatedPredicateFromDataSkippingFilter(leftPredicate, isNotPropagated); + rightResult = omitCollatedPredicateFromDataSkippingFilter(rightPredicate, isNotPropagated); + hasCollatedPredicate = leftResult._2 || rightResult._2; + if (leftResult._1 == AlwaysTrue.ALWAYS_TRUE || rightResult._1 == AlwaysTrue.ALWAYS_TRUE) { + resultingPredicate = AlwaysTrue.ALWAYS_TRUE; + } else { + resultingPredicate = new Or(leftResult._1, rightResult._1); + } + return new Tuple2<>(resultingPredicate, hasCollatedPredicate); + case "=": + Expression left = getLeft(dataFilters); + Expression right = getRight(dataFilters); + hasCollatedPredicate = false; + if (left instanceof Predicate) { + leftResult = omitCollatedPredicateFromDataSkippingFilter((Predicate) left, isNotPropagated); + left = leftResult._1; + hasCollatedPredicate |= leftResult._2; + } + if (right instanceof Predicate) { + rightResult = omitCollatedPredicateFromDataSkippingFilter((Predicate) right, isNotPropagated); + right = rightResult._1; + hasCollatedPredicate |= rightResult._2; + } + if (hasCollatedPredicate) { + return new Tuple2<>(AlwaysTrue.ALWAYS_TRUE, true); + } + if (isNotPropagated) { + return new Tuple2<>( + new Or(new Predicate("<", left, right), + new Predicate("<", right, left)), + false); + } else { + return new Tuple2<>(new Predicate("=", left, right), false); + } + case ">": + case ">=": + left = getLeft(dataFilters); + right = getRight(dataFilters); + boolean hasCollatedPredicateOnLeft = false; + if (left instanceof Predicate) { + leftResult = omitCollatedPredicateFromDataSkippingFilter((Predicate) left, isNotPropagated); + left = leftResult._1; + hasCollatedPredicateOnLeft = leftResult._2; + } + boolean hasCollatedPredicateOnRight = false; + if (right instanceof Predicate) { + rightResult = omitCollatedPredicateFromDataSkippingFilter((Predicate) right, isNotPropagated); + right = rightResult._1; + hasCollatedPredicateOnRight = rightResult._2; + } + if (hasCollatedPredicateOnRight) { + return new Tuple2<>(AlwaysTrue.ALWAYS_TRUE, true); + } + return new Tuple2<>(new Predicate(predicateName, left, right), + hasCollatedPredicateOnLeft); + case "<": + case "<=": + left = getLeft(dataFilters); + right = getRight(dataFilters); + hasCollatedPredicateOnLeft = false; + if (left instanceof Predicate) { + leftResult = omitCollatedPredicateFromDataSkippingFilter((Predicate) left, isNotPropagated); + left = leftResult._1; + hasCollatedPredicateOnLeft = leftResult._2; + } + hasCollatedPredicateOnRight = false; + if (right instanceof Predicate) { + rightResult = omitCollatedPredicateFromDataSkippingFilter((Predicate) right, isNotPropagated); + right = rightResult._1; + hasCollatedPredicateOnRight = rightResult._2; + } + if (hasCollatedPredicateOnLeft) { + return new Tuple2<>(AlwaysTrue.ALWAYS_TRUE, true); + } + return new Tuple2<>(new Predicate(predicateName, left, right), + hasCollatedPredicateOnRight); + case "IS_NULL": + Expression child = getUnaryChild(dataFilters); + if (!(child instanceof Predicate)) { + return new Tuple2<>(dataFilters, false); + } + Tuple2 childResult = omitCollatedPredicateFromDataSkippingFilter((Predicate) child, isNotPropagated); + if (childResult._2) { + return new Tuple2<>(AlwaysTrue.ALWAYS_TRUE, true); + } else { + return new Tuple2<>(childResult._1, false); + } + case "IS_NOT_NULL": + child = getUnaryChild(dataFilters); + if (!(child instanceof Predicate)) { + return new Tuple2<>(dataFilters, false); + } + childResult = omitCollatedPredicateFromDataSkippingFilter((Predicate) child, isNotPropagated); + return new Tuple2<>(childResult._1, childResult._2); + case "NOT": + Predicate childPredicate = asPredicate(getUnaryChild(dataFilters)); + return omitCollatedPredicateFromDataSkippingFilter(childPredicate, !isNotPropagated); + } + + throw new IllegalArgumentException( + String.format("Invalid predicate name: %s.", predicateName)); + } + ////////////////////////////////////////////////////////////////////////////////// // Helper functions ////////////////////////////////////////////////////////////////////////////////// @@ -334,20 +482,25 @@ private static DataSkippingPredicate constructBinaryDataSkippingPredicate( exprName, Arrays.asList(adjColExpr, lit), Collections.singleton(column)); } - private static final Map REVERSE_COMPARATORS = + private static final Map REVERSE_PREDICATE = new HashMap() { { + put("AND", "OR"); + put("OR", "AND"); + put("IS_NULL", "IS_NOT_NULL"); + put("IS_NOT_NULL", "IS_NULL"); + put("NOT", "NOT"); put("=", "="); - put("<", ">"); - put("<=", ">="); - put(">", "<"); - put(">=", "<="); + put("<", ">="); + put("<=", ">"); + put(">", "<="); + put(">=", "<"); } }; private static Predicate reverseComparatorFilter(Predicate predicate) { return new Predicate( - REVERSE_COMPARATORS.get(predicate.getName().toUpperCase(Locale.ROOT)), + REVERSE_PREDICATE.get(predicate.getName().toUpperCase(Locale.ROOT)), getRight(predicate), getLeft(predicate)); } diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/types/DataTypeJsonSerDe.java b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/types/DataTypeJsonSerDe.java index ab79487f898..c5d86ab70a0 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/types/DataTypeJsonSerDe.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/types/DataTypeJsonSerDe.java @@ -89,7 +89,8 @@ public static String serializeDataType(DataType dataType) { */ public static StructType deserializeStructType(String structTypeJson) { try { - DataType parsedType = parseDataType(OBJECT_MAPPER.reader().readTree(structTypeJson)); + DataType parsedType = + parseDataType(OBJECT_MAPPER.reader().readTree(structTypeJson), "", new HashMap<>()); if (parsedType instanceof StructType) { return (StructType) parsedType; } else { @@ -132,21 +133,25 @@ public static StructType deserializeStructType(String structTypeJson) { * } * */ - static DataType parseDataType(JsonNode json) { + static DataType parseDataType( + JsonNode json, String fieldPath, HashMap collationMap) { switch (json.getNodeType()) { case STRING: // simple types are stored as just a string - return nameToType(json.textValue()); + return nameToType(json.textValue(), fieldPath, collationMap); case OBJECT: // complex types (array, map, or struct are stored as JSON objects) String type = getStringField(json, "type"); switch (type) { case "struct": + assertValidTypeForCollations(fieldPath, "struct", collationMap); return parseStructType(json); case "array": - return parseArrayType(json); + assertValidTypeForCollations(fieldPath, "array", collationMap); + return parseArrayType(json, fieldPath, collationMap); case "map": - return parseMapType(json); + assertValidTypeForCollations(fieldPath, "map", collationMap); + return parseMapType(json, fieldPath, collationMap); // No default case here; fall through to the following error when no match } default: @@ -160,12 +165,14 @@ static DataType parseDataType(JsonNode json) { * Parses an array * type */ - private static ArrayType parseArrayType(JsonNode json) { + private static ArrayType parseArrayType( + JsonNode json, String fieldPath, HashMap collationMap) { checkArgument( json.isObject() && json.size() == 3, String.format("Expected JSON object with 3 fields for array data type but got:\n%s", json)); boolean containsNull = getBooleanField(json, "containsNull"); - DataType dataType = parseDataType(getNonNullField(json, "elementType")); + DataType dataType = + parseDataType(getNonNullField(json, "elementType"), fieldPath + "element", collationMap); return new ArrayType(dataType, containsNull); } @@ -173,13 +180,16 @@ private static ArrayType parseArrayType(JsonNode json) { * Parses an map type * */ - private static MapType parseMapType(JsonNode json) { + private static MapType parseMapType( + JsonNode json, String fieldPath, HashMap collationMap) { checkArgument( json.isObject() && json.size() == 4, String.format("Expected JSON object with 4 fields for map data type but got:\n%s", json)); boolean valueContainsNull = getBooleanField(json, "valueContainsNull"); - DataType keyType = parseDataType(getNonNullField(json, "keyType")); - DataType valueType = parseDataType(getNonNullField(json, "valueType")); + DataType keyType = + parseDataType(getNonNullField(json, "keyType"), fieldPath + "key", collationMap); + DataType valueType = + parseDataType(getNonNullField(json, "valueType"), fieldPath + "value", collationMap); return new MapType(keyType, valueType, valueContainsNull); } @@ -211,14 +221,23 @@ private static StructType parseStructType(JsonNode json) { private static StructField parseStructField(JsonNode json) { Preconditions.checkArgument(json.isObject(), "Expected JSON object for struct field"); String name = getStringField(json, "name"); - DataType type = parseDataType(getNonNullField(json, "type")); + FieldMetadata metadata = parseFieldMetadata(json.get("metadata"), false); + DataType type = + parseDataType(getNonNullField(json, "type"), name, getCollationsMap(json.get("metadata"))); boolean nullable = getBooleanField(json, "nullable"); - FieldMetadata metadata = parseFieldMetadata(json.get("metadata")); return new StructField(name, type, nullable, metadata); } /** Parses an {@link FieldMetadata}. */ private static FieldMetadata parseFieldMetadata(JsonNode json) { + return parseFieldMetadata(json, true); + } + + /** + * Parses a {@link FieldMetadata}, optionally including collation metadata, depending on + * `includeCollationMetadata`. + */ + private static FieldMetadata parseFieldMetadata(JsonNode json, boolean includeCollationMetadata) { if (json == null || json.isNull()) { return FieldMetadata.empty(); } @@ -231,6 +250,10 @@ private static FieldMetadata parseFieldMetadata(JsonNode json) { JsonNode value = entry.getValue(); String key = entry.getKey(); + if (!includeCollationMetadata && key.equals(DataType.COLLATIONS_METADATA_KEY)) { + continue; + } + if (value.isNull()) { builder.putNull(key); } else if (value.isIntegralNumber()) { // covers both int and long @@ -298,8 +321,13 @@ private static List buildList(JsonNode json, Function access private static Pattern FIXED_DECIMAL_PATTERN = Pattern.compile(FIXED_DECIMAL_REGEX); /** Parses primitive string type names to a {@link DataType} */ - private static DataType nameToType(String name) { + private static DataType nameToType( + String name, String fieldPath, HashMap collationMap) { if (BasePrimitiveType.isPrimitiveType(name)) { + if (collationMap.containsKey(fieldPath)) { + assertValidTypeForCollations(fieldPath, name, collationMap); + return stringTypeWithCollation(collationMap.get(fieldPath)); + } return BasePrimitiveType.createPrimitive(name); } else if (name.equals("decimal")) { return DecimalType.USER_DEFAULT; @@ -341,6 +369,37 @@ private static String getStringField(JsonNode rootNode, String fieldName) { return node.textValue(); // double check this only works for string values! and isTextual()! } + private static StringType stringTypeWithCollation(String collationName) { + return new StringType(collationName); + } + + private static void assertValidTypeForCollations( + String fieldPath, String fieldType, Map collationMap) { + if (collationMap.containsKey(fieldPath) && !fieldType.equals("string")) { + throw new IllegalArgumentException(String.format("Invalid collation path \"%s\"", fieldPath)); + } + } + + private static HashMap getCollationsMap(JsonNode fieldMetadata) { + if (fieldMetadata == null || !fieldMetadata.has(DataType.COLLATIONS_METADATA_KEY)) { + return new HashMap<>(); + } + HashMap collationsMap = new HashMap<>(); + FieldMetadata collationFieldMetadata = + parseFieldMetadata(fieldMetadata.get(DataType.COLLATIONS_METADATA_KEY)); + for (Map.Entry collationField : + collationFieldMetadata.getEntries().entrySet()) { + String fieldPath = collationField.getKey(); + Object collationName = collationField.getValue(); + if (!(collationName instanceof String)) { + throw new IllegalArgumentException( + String.format("Invalid collation name: %s.", collationName)); + } + collationsMap.put(fieldPath, (String) collationName); + } + return collationsMap; + } + private static boolean getBooleanField(JsonNode rootNode, String fieldName) { JsonNode node = getNonNullField(rootNode, fieldName); Preconditions.checkArgument( @@ -414,7 +473,7 @@ private static void writeStructField(JsonGenerator gen, StructField field) throw writeDataType(gen, field.getDataType()); gen.writeBooleanField("nullable", field.isNullable()); gen.writeFieldName("metadata"); - writeFieldMetadata(gen, field.getMetadata()); + writeFieldMetadata(gen, field.getSerializationMetadata()); gen.writeEndObject(); } diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/types/DataType.java b/kernel/kernel-api/src/main/java/io/delta/kernel/types/DataType.java index 38fec7144b5..54aad10cd41 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/types/DataType.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/types/DataType.java @@ -25,6 +25,7 @@ */ @Evolving public abstract class DataType { + public static final String COLLATIONS_METADATA_KEY = "__COLLATIONS"; /** * Are the data types same? The metadata or column names could be different. diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/types/StringType.java b/kernel/kernel-api/src/main/java/io/delta/kernel/types/StringType.java index 08b5bbd1df7..bad4cf0fb6d 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/types/StringType.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/types/StringType.java @@ -16,6 +16,7 @@ package io.delta.kernel.types; import io.delta.kernel.annotation.Evolving; +import io.delta.kernel.expressions.CollationIdentifier; /** * The data type representing {@code string} type values. @@ -24,9 +25,22 @@ */ @Evolving public class StringType extends BasePrimitiveType { - public static final StringType STRING = new StringType(); + public static final StringType STRING = + new StringType(CollationIdentifier.DEFAULT_COLLATION_IDENTIFIER); - private StringType() { + private final CollationIdentifier collationIdentifier; + + public StringType(CollationIdentifier collationIdentifier) { + super("string"); + this.collationIdentifier = collationIdentifier; + } + + public StringType(String collationName) { super("string"); + this.collationIdentifier = CollationIdentifier.fromString(collationName); + } + + public CollationIdentifier getCollationIdentifier() { + return collationIdentifier; } } diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/types/StructField.java b/kernel/kernel-api/src/main/java/io/delta/kernel/types/StructField.java index 5e9885c7ba1..6fffcd478a2 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/types/StructField.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/types/StructField.java @@ -17,6 +17,10 @@ package io.delta.kernel.types; import io.delta.kernel.annotation.Evolving; +import io.delta.kernel.expressions.CollationIdentifier; +import io.delta.kernel.internal.util.Tuple2; +import java.util.ArrayList; +import java.util.List; import java.util.Objects; /** @@ -102,6 +106,45 @@ public String toString() { "StructField(name=%s,type=%s,nullable=%s,metadata=%s)", name, dataType, nullable, metadata); } + public FieldMetadata getSerializationMetadata() { + List> nestedCollatedFields = getNestedCollatedFields(dataType, name); + if (nestedCollatedFields.isEmpty()) { + return metadata; + } + + FieldMetadata.Builder metadataBuilder = new FieldMetadata.Builder(); + for (Tuple2 nestedField : nestedCollatedFields) { + metadataBuilder.putString(nestedField._1, nestedField._2); + } + return new FieldMetadata.Builder() + .fromMetadata(metadata) + .putFieldMetadata(DataType.COLLATIONS_METADATA_KEY, metadataBuilder.build()) + .build(); + } + + private List> getNestedCollatedFields(DataType parent, String path) { + List> nestedCollatedFields = new ArrayList<>(); + if (parent instanceof StringType) { + StringType stringType = (StringType) parent; + if (!stringType + .getCollationIdentifier() + .equals(CollationIdentifier.DEFAULT_COLLATION_IDENTIFIER)) { + nestedCollatedFields.add( + new Tuple2<>( + path, ((StringType) parent).getCollationIdentifier().toStringWithoutVersion())); + } + } else if (parent instanceof MapType) { + nestedCollatedFields.addAll( + getNestedCollatedFields(((MapType) parent).getKeyType(), path + ".key")); + nestedCollatedFields.addAll( + getNestedCollatedFields(((MapType) parent).getValueType(), path + ".value")); + } else if (parent instanceof ArrayType) { + nestedCollatedFields.addAll( + getNestedCollatedFields(((ArrayType) parent).getElementType(), path + ".element")); + } + return nestedCollatedFields; + } + @Override public boolean equals(Object o) { if (this == o) { diff --git a/kernel/kernel-api/src/test/scala/io/delta/kernel/internal/types/DataTypeJsonSerDeSuite.scala b/kernel/kernel-api/src/test/scala/io/delta/kernel/internal/types/DataTypeJsonSerDeSuite.scala index 6085b81dedb..7e839e12d90 100644 --- a/kernel/kernel-api/src/test/scala/io/delta/kernel/internal/types/DataTypeJsonSerDeSuite.scala +++ b/kernel/kernel-api/src/test/scala/io/delta/kernel/internal/types/DataTypeJsonSerDeSuite.scala @@ -17,8 +17,10 @@ package io.delta.kernel.internal.types import com.fasterxml.jackson.databind.ObjectMapper import io.delta.kernel.types._ +import io.delta.kernel.types.DataType.COLLATIONS_METADATA_KEY import org.scalatest.funsuite.AnyFunSuite +import java.util.HashMap import scala.reflect.ClassTag class DataTypeJsonSerDeSuite extends AnyFunSuite { @@ -28,7 +30,7 @@ class DataTypeJsonSerDeSuite extends AnyFunSuite { private val objectMapper = new ObjectMapper() private def parse(json: String): DataType = { - DataTypeJsonSerDe.parseDataType(objectMapper.readTree(json)) + DataTypeJsonSerDe.parseDataType(objectMapper.readTree(json), "", new HashMap()) } private def serialize(dataType: DataType): String = { @@ -130,6 +132,14 @@ class DataTypeJsonSerDeSuite extends AnyFunSuite { } } + test("parseDataType: types with collated strings") { + SAMPLE_JSON_TO_TYPES_WITH_COLLATION + .foreach { + case(json, structType) => + assert(parse(json) == structType) + } + } + test("serialize/deserialize: special characters for column name") { val json = structTypeJson(Seq( structFieldJson("@_! *c", "\"string\"", true) @@ -245,6 +255,112 @@ object DataTypeJsonSerDeSuite { ) ) + val SAMPLE_JSON_TO_TYPES_WITH_COLLATION = Seq( + ( + structTypeJson(Seq( + structFieldJson("a1", "\"string\"", true, + metadataJson = Some(s"""{"$COLLATIONS_METADATA_KEY" : {"a1" : "ICU.UNICODE"}}""")), + structFieldJson("a2", "\"integer\"", false), + structFieldJson("a3", "\"string\"", false, + metadataJson = Some(s"""{"$COLLATIONS_METADATA_KEY" : {"a3" : "KERNEL.UTF8_LCASE"}}""")), + structFieldJson("a4", "\"string\"", true))), + new StructType() + .add("a1", new StringType("ICU.UNICODE"), true) + .add("a2", IntegerType.INTEGER, false) + .add("a3", new StringType("KERNEL.UTF8_LCASE"), false) + .add("a4", StringType.STRING, true) + ), + ( + structTypeJson(Seq( + structFieldJson("a1", structTypeJson(Seq( + structFieldJson("b1", "\"string\"", true, + metadataJson = Some( + s"""{"$COLLATIONS_METADATA_KEY" + | : {"b1" : "ICU.UNICODE"}}""".stripMargin)))), true), + structFieldJson("a2", structTypeJson(Seq( + structFieldJson("b1", arrayTypeJson("\"string\"", false), true, + metadataJson = Some( + s"""{"$COLLATIONS_METADATA_KEY" + | : {"b1.element" : "KERNEL.UTF8_LCASE"}}""".stripMargin)), + structFieldJson("b2", mapTypeJson("\"string\"", "\"string\"", true), false, + metadataJson = Some( + s"""{"$COLLATIONS_METADATA_KEY" + | : {"b2.key" : "ICU.UNICODE_CI", + | "b2.value" : "KERNEL.UTF8_LCASE"}}""".stripMargin)), + structFieldJson("b3", arrayTypeJson("\"string\"", false), true), + structFieldJson("b4", mapTypeJson("\"string\"", "\"string\"", false), false))), true), + structFieldJson("a3", structTypeJson(Seq( + structFieldJson("b1", "\"string\"", false), + structFieldJson("b2", arrayTypeJson("\"integer\"", false), true))), false, + metadataJson = Some( + s"""{"$COLLATIONS_METADATA_KEY" + | : {"b1" : "KERNEL.UTF8_LCASE"}}""".stripMargin)))), + new StructType() + .add("a1", new StructType() + .add("b1", new StringType("ICU.UNICODE")), true) + .add("a2", new StructType() + .add("b1", new ArrayType(new StringType("KERNEL.UTF8_LCASE"), false)) + .add("b2", new MapType( + new StringType("ICU.UNICODE_CI"), new StringType("KERNEL.UTF8_LCASE"), true), false) + .add("b3", new ArrayType(StringType.STRING, false)) + .add("b4", new MapType( + StringType.STRING, StringType.STRING, false), false), true) + .add("a3", new StructType() + .add("b1", StringType.STRING, false) + .add("b2", new ArrayType(IntegerType.INTEGER, false), true), false) + ), + ( + structTypeJson(Seq( + structFieldJson("a1", "\"string\"", true), + structFieldJson("a2", structTypeJson(Seq( + structFieldJson("b1", mapTypeJson( + arrayTypeJson(arrayTypeJson("\"string\"", true), true), + structTypeJson(Seq( + structFieldJson("c1", "\"string\"", false, + metadataJson = Some( + s"""{"$COLLATIONS_METADATA_KEY" + | : {"c1" : "KERNEL.UTF8_LCASE"}}""".stripMargin)), + structFieldJson("c2", "\"string\"", true, + metadataJson = Some( + s"""{"$COLLATIONS_METADATA_KEY" + | : {\"c1\" : \"ICU.UNICODE\"}}""".stripMargin)), + structFieldJson("c3", "\"string\"", true))), true), true), + structFieldJson("b2", "\"long\"", true))), true, metadataJson = + Some( + s"""{"$COLLATIONS_METADATA_KEY" + | : {"b1.key.element.element" : \"KERNEL.UTF8_LCASE\"}}""".stripMargin)), + structFieldJson("a3", arrayTypeJson( + mapTypeJson( + "\"string\"", + structTypeJson(Seq( + structFieldJson("b1", "\"string\"", false, metadataJson = + Some( + s"""{"$COLLATIONS_METADATA_KEY" + | : {"a3.element.key" : "ICU.UNICODE_CI"}}""".stripMargin)))), + false), false), true, + metadataJson = Some( + s"""{"$COLLATIONS_METADATA_KEY" + | : {"a3.element.key" : "ICU.UNICODE_CI"}}""".stripMargin)))), + new StructType() + .add("a1", StringType.STRING, true) + .add("a2", new StructType() + .add("b1", new MapType( + new ArrayType( + new ArrayType( + new StringType("KERNEL.UTF8_LCASE"), true), true), + new StructType() + .add("c1", new StringType("KERNEL.UTF8_LCASE"), false) + .add("c2", new StringType("ICU.UNICODE"), true) + .add("c3", StringType.STRING), true)) + .add("b2", LongType.LONG), true) + .add("a3", new ArrayType( + new MapType( + new StringType("ICU.UNICODE_CI"), + new StructType() + .add("b1", new StringType("KERNEL.UTF8_LCASE"), false), false), false), true) + ) + ) + def arrayTypeJson(elementJson: String, containsNull: Boolean): String = { s""" |{ diff --git a/kernel/kernel-api/src/test/scala/io/delta/kernel/internal/util/DataSkippingUtilsSuite.scala b/kernel/kernel-api/src/test/scala/io/delta/kernel/internal/util/DataSkippingUtilsSuite.scala index c1c86d4e8ef..c9275fad0c7 100644 --- a/kernel/kernel-api/src/test/scala/io/delta/kernel/internal/util/DataSkippingUtilsSuite.scala +++ b/kernel/kernel-api/src/test/scala/io/delta/kernel/internal/util/DataSkippingUtilsSuite.scala @@ -16,9 +16,9 @@ package io.delta.kernel.internal.util import scala.collection.JavaConverters._ - -import io.delta.kernel.expressions.Column +import io.delta.kernel.expressions.{AlwaysTrue, And, CollatedPredicate, CollationIdentifier, Column, Literal, Or, Predicate} import io.delta.kernel.internal.skipping.DataSkippingUtils +import io.delta.kernel.internal.skipping.DataSkippingUtils.omitCollatedPredicateFromDataSkippingFilter import io.delta.kernel.types.IntegerType.INTEGER import io.delta.kernel.types.{DataType, StructField, StructType} import org.scalatest.funsuite.AnyFunSuite @@ -55,6 +55,381 @@ class DataSkippingUtilsSuite extends AnyFunSuite { s"expected=$expectedSchema\nfound=$prunedSchema") } + test("omitCollatedPredicateFromDataSkippingFilter - AND, OR") { + Seq( + // (starting predicate, resulting predicate) + ( + new And( + new Predicate("=", new Column("c1"), new Column("c2")), + new Predicate(">", Literal.ofString("a"), new Column("c1")) + ), + new And( + new Predicate("=", new Column("c1"), new Column("c2")), + new Predicate(">", Literal.ofString("a"), new Column("c1")) + ) + ), + ( + new And( + new CollatedPredicate("=", new Column("c1"), Literal.ofString("a"), + CollationIdentifier.DEFAULT_COLLATION_IDENTIFIER), + new Predicate("=", new Column("c1"), Literal.ofString("a")) + ), + new Predicate("=", new Column("c1"), Literal.ofString("a")) + ), + ( + new And( + new Predicate("=", new Column("c1"), Literal.ofString("a")), + new CollatedPredicate("=", new Column("c1"), Literal.ofString("a"), + CollationIdentifier.DEFAULT_COLLATION_IDENTIFIER) + ), + new Predicate("=", new Column("c1"), Literal.ofString("a")) + ), + ( + new And( + new CollatedPredicate("=", new Column("c1"), new Column("c2"), + CollationIdentifier.fromString("ICU.UNICODE")), + new CollatedPredicate(">", new Column("c2"), Literal.ofString("a"), + CollationIdentifier.DEFAULT_COLLATION_IDENTIFIER) + ), + AlwaysTrue.ALWAYS_TRUE + ), + ( + new Or( + new Predicate("=", new Column("c1"), new Column("c2")), + new Predicate(">", Literal.ofString("a"), new Column("c1")) + ), + new Or( + new Predicate("=", new Column("c1"), new Column("c2")), + new Predicate(">", Literal.ofString("a"), new Column("c1")) + ) + ), + ( + new Or( + new CollatedPredicate("=", new Column("c1"), Literal.ofString("a"), + CollationIdentifier.DEFAULT_COLLATION_IDENTIFIER), + new Predicate("=", new Column("c1"), Literal.ofString("a")) + ), + AlwaysTrue.ALWAYS_TRUE + ), + ( + new Or( + new Predicate("=", new Column("c1"), Literal.ofString("a")), + new CollatedPredicate("=", new Column("c1"), Literal.ofString("a"), + CollationIdentifier.DEFAULT_COLLATION_IDENTIFIER) + ), + AlwaysTrue.ALWAYS_TRUE + ), + ( + new Or( + new CollatedPredicate("=", new Column("c1"), new Column("c2"), + CollationIdentifier.fromString("ICU.UNICODE")), + new CollatedPredicate(">", new Column("c2"), Literal.ofString("a"), + CollationIdentifier.DEFAULT_COLLATION_IDENTIFIER) + ), + AlwaysTrue.ALWAYS_TRUE + ) + ).foreach { + case(startingPredicate, resultingPredicate) => + assert(omitCollatedPredicateFromDataSkippingFilter(startingPredicate).toString + === resultingPredicate.toString) + } + } + + test("omitCollatedPredicateFromDataSkippingFilter - =") { + Seq( + // (starting predicate, resulting predicate) + ( + new Predicate("=", Literal.ofString("a"), new Column("c1")), + new Predicate("=", Literal.ofString("a"), new Column("c1")), + ), + ( + new Predicate("=", + new CollatedPredicate("=", new Column("c1"), Literal.ofString("a"), + CollationIdentifier.DEFAULT_COLLATION_IDENTIFIER), + Literal.ofBoolean(true) + ), + AlwaysTrue.ALWAYS_TRUE + ), + ( + new Predicate("=", + Literal.ofBoolean(true), + new CollatedPredicate("=", new Column("c1"), Literal.ofString("a"), + CollationIdentifier.DEFAULT_COLLATION_IDENTIFIER) + ), + AlwaysTrue.ALWAYS_TRUE + ), + ( + new Predicate("=", + new CollatedPredicate("=", new Column("c1"), Literal.ofString("a"), + CollationIdentifier.DEFAULT_COLLATION_IDENTIFIER), + new CollatedPredicate("=", new Column("c1"), Literal.ofString("a"), + CollationIdentifier.DEFAULT_COLLATION_IDENTIFIER) + ), + AlwaysTrue.ALWAYS_TRUE + ), + ( + new Predicate("=", + new And( + new CollatedPredicate("=", new Column("c1"), Literal.ofString("a"), + CollationIdentifier.DEFAULT_COLLATION_IDENTIFIER), + new Predicate("=", new Column("a"), Literal.ofString("a")) + ), + Literal.ofBoolean(true) + ), + AlwaysTrue.ALWAYS_TRUE + ) + ).foreach { + case(startingPredicate, resultingPredicate) => + assert(omitCollatedPredicateFromDataSkippingFilter(startingPredicate).toString + === resultingPredicate.toString) + } + } + + test("omitCollatedPredicateFromDataSkippingFilter - >, >=") { + Seq(">", ">=") + .foreach { + name => + Seq( + // (starting predicate, resulting predicate) + ( + new Predicate(name, + Literal.ofString("a"), + new Column("c1") + ), + new Predicate(name, + Literal.ofString("a"), + new Column("c1") + ) + ), + ( + new Predicate(name, + new CollatedPredicate("=", + Literal.ofString("a"), + new Column("c1"), + CollationIdentifier.DEFAULT_COLLATION_IDENTIFIER + ), + Literal.ofBoolean(true) + ), + new Predicate(name, + AlwaysTrue.ALWAYS_TRUE, + Literal.ofBoolean(true) + ) + ), + ( + new Predicate(name, + Literal.ofBoolean(true), + new CollatedPredicate("=", + Literal.ofString("a"), + new Column("c1"), + CollationIdentifier.DEFAULT_COLLATION_IDENTIFIER + ) + ), + AlwaysTrue.ALWAYS_TRUE + ) + ).foreach { + case(startingPredicate, resultingPredicate) => + assert(omitCollatedPredicateFromDataSkippingFilter(startingPredicate).toString + === resultingPredicate.toString) + } + } + } + + test("omitCollatedPredicateFromDataSkippingFilter - <, <=") { + Seq("<", "<=") + .foreach { + name => + Seq( + // (starting predicate, resulting predicate) + ( + new Predicate(name, + Literal.ofString("a"), + new Column("c1") + ), + new Predicate(name, + Literal.ofString("a"), + new Column("c1") + ) + ), + ( + new Predicate(name, + Literal.ofBoolean(true), + new CollatedPredicate("=", + Literal.ofString("a"), + new Column("c1"), + CollationIdentifier.DEFAULT_COLLATION_IDENTIFIER + ) + ), + new Predicate(name, + Literal.ofBoolean(true), + AlwaysTrue.ALWAYS_TRUE + ) + ), + ( + new Predicate(name, + new CollatedPredicate("=", + Literal.ofString("a"), + new Column("c1"), + CollationIdentifier.DEFAULT_COLLATION_IDENTIFIER + ), + Literal.ofBoolean(true) + ), + AlwaysTrue.ALWAYS_TRUE + ) + ).foreach { + case(startingPredicate, resultingPredicate) => + assert(omitCollatedPredicateFromDataSkippingFilter(startingPredicate).toString + === resultingPredicate.toString) + } + } + } + + test("omitCollatedPredicateFromDataSkippingFilter - NOT") { + Seq( + // (starting predicate, resulting predicate) + ( + new Predicate("NOT", + new And( + new Predicate("=", + new Column("c1"), Literal.ofString("a") + ), + new Predicate(">", + new Column("c2"), new Column("c1") + ) + ) + ), + new Or( + new Or( + new Predicate("<", + new Column("c1"), Literal.ofString("a") + ), + new Predicate("<", + Literal.ofString("a"), new Column("c1")) + ), + new Predicate("<=", + new Column("c2"), new Column("c1") + ) + ) + ), + ( + new Predicate("NOT", + new Or( + new Predicate("=", + new Column("c1"), Literal.ofString("a") + ), + new Predicate(">", + new Column("c2"), new Column("c1") + ) + ) + ), + new And( + new Or( + new Predicate("<", + new Column("c1"), Literal.ofString("a") + ), + new Predicate("<", + Literal.ofString("a"), new Column("c1")) + ), + new Predicate("<=", + new Column("c2"), new Column("c1") + ) + ) + ), + ( + new Predicate("NOT", + new Predicate("=", + Literal.ofString("a"), + new Column("c1")) + ), + new Or( + new Predicate("<", + Literal.ofString("a"), + new Column("c1") + ), + new Predicate("<", + new Column("c1"), + Literal.ofString("a")) + ) + ), + ( + new Predicate("NOT", + new Predicate("=", + new CollatedPredicate("<", + new Column("c1"), + Literal.ofString("a"), + CollationIdentifier.DEFAULT_COLLATION_IDENTIFIER + ), + Literal.ofBoolean(true) + ) + ), + AlwaysTrue.ALWAYS_TRUE + ) + ).foreach { + case(startingPredicate, resultingPredicate) => + assert(omitCollatedPredicateFromDataSkippingFilter(startingPredicate).toString + === resultingPredicate.toString) + } + + val reverseSign = Map( + ">" -> "<=", + ">=" -> "<", + "<" -> ">=", + "<=" -> ">" + ) + + // check >, >= + Seq(">", ">=") + .foreach { + name => + Seq( + ( + new Predicate("NOT", + new Predicate(name, + Literal.ofString("a"), + new Column("c1") + ) + ), + new Predicate(reverseSign(name), + Literal.ofString("a"), + new Column("c1") + ) + ), + ( + new Predicate("NOT", + new Predicate(name, + new CollatedPredicate("=", + Literal.ofString("a"), + new Column("c1"), + CollationIdentifier.DEFAULT_COLLATION_IDENTIFIER + ), + Literal.ofBoolean(true) + ) + ), + AlwaysTrue.ALWAYS_TRUE + ), + ( + new Predicate("NOT", + new Predicate(name, + Literal.ofBoolean(true), + new CollatedPredicate("=", + Literal.ofString("a"), + new Column("c1"), + CollationIdentifier.DEFAULT_COLLATION_IDENTIFIER + ) + ) + ), + new Predicate(reverseSign(name), + Literal.ofBoolean(true), + AlwaysTrue.ALWAYS_TRUE + ) + ) + ).foreach { + case(startingPredicate, resultingPredicate) => + assert(omitCollatedPredicateFromDataSkippingFilter(startingPredicate).toString + === resultingPredicate.toString) + } + } + } + test("pruneStatsSchema - multiple basic cases one level of nesting") { val nestedField = new StructField( "nested", diff --git a/kernel/kernel-api/src/test/scala/io/delta/kernel/types/StructTypeSuite.scala b/kernel/kernel-api/src/test/scala/io/delta/kernel/types/StructTypeSuite.scala new file mode 100644 index 00000000000..a81df626950 --- /dev/null +++ b/kernel/kernel-api/src/test/scala/io/delta/kernel/types/StructTypeSuite.scala @@ -0,0 +1,142 @@ +/* + * Copyright (2024) The Delta Lake Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.delta.kernel.types + +import org.scalatest.funsuite.AnyFunSuite + +class StructTypeSuite extends AnyFunSuite { + + test("check toJson method with primitive types") { + val structType = new StructType() + .add("c1", BinaryType.BINARY, true) + .add("c2", BooleanType.BOOLEAN, false) + .add("c3", ByteType.BYTE, false) + .add("c4", DateType.DATE, true) + .add("c5", DecimalType.USER_DEFAULT, false) + .add("c6", DoubleType.DOUBLE, false) + .add("c7", FloatType.FLOAT, false) + .add("c8", IntegerType.INTEGER, true) + .add("c9", LongType.LONG, true) + .add("c10", ShortType.SHORT, true) + .add("c11", StringType.STRING, true) + .add("c12", TimestampNTZType.TIMESTAMP_NTZ, false) + .add("c13", TimestampType.TIMESTAMP, false) + .add("c14", VariantType.VARIANT, false) + val toJson = + """{"type":"struct","fields":[{"name":"c1","type" + |:"binary","nullable":true,"metadata":{}},{"name" + |:"c2","type":"boolean","nullable":false,"metadata" + |:{}},{"name":"c3","type":"byte","nullable":false, + |"metadata":{}},{"name":"c4","type":"date","nullable" + |:true,"metadata":{}},{"name":"c5","type":"decimal(10,0)" + |,"nullable":false,"metadata":{}},{"name":"c6","type": + |"double","nullable":false,"metadata":{}},{"name":"c7", + |"type":"float","nullable":false,"metadata":{}},{"name": + |"c8","type":"integer","nullable":true,"metadata":{}}, + |{"name":"c9","type":"long","nullable":true,"metadata":{}} + |,{"name":"c10","type":"short","nullable":true,"metadata" + |:{}},{"name":"c11","type":"string","nullable":true,"metadata": + |{}},{"name":"c12","type":"timestamp_ntz","nullable":false, + |"metadata":{}},{"name":"c13","type":"timestamp","nullable":false, + |"metadata":{}},{"name":"c14","type":"variant","nullable":false, + |"metadata":{}}]}""".stripMargin.replaceAll("\n", "") + + assert(structType.toJson == toJson) + } + + test("check toJson method with complex types") { + val structType = new StructType() + .add("a1", StringType.STRING, true) + .add("a2", new StructType() + .add("b1", new MapType( + new ArrayType( + new ArrayType(StringType.STRING, true), true), + new StructType() + .add("c1", StringType.STRING, false) + .add("c2", StringType.STRING, true), true)) + .add("b2", LongType.LONG), true) + .add("a3", new ArrayType( + new MapType( + StringType.STRING, + new StructType() + .add("b1", DateType.DATE, false), false), false), true) + val toJson = + """{"type":"struct","fields":[{"name":"a1","type":"string", + |"nullable":true,"metadata":{}},{"name":"a2","type":{"type" + |:"struct","fields":[{"name":"b1","type":{"type":"map", + |"keyType":{"type":"array","elementType":{"type":"array", + |"elementType":"string","containsNull":true},"containsNull": + |true},"valueType":{"type":"struct","fields":[{"name":"c1", + |"type":"string","nullable":false,"metadata":{}},{"name": + |"c2","type":"string","nullable":true,"metadata":{}}]}, + |"valueContainsNull":true},"nullable":true,"metadata":{}}, + |{"name":"b2","type":"long","nullable":true,"metadata":{}}]}, + |"nullable":true,"metadata":{}},{"name":"a3","type":{"type": + |"array","elementType":{"type":"map","keyType":"string", + |"valueType":{"type":"struct","fields":[{"name":"b1","type": + |"date","nullable":false,"metadata":{}}]},"valueContainsNull": + |false},"containsNull":false},"nullable":true,"metadata":{}}]}""" + .stripMargin.replaceAll("\n", "") + + assert(structType.toJson == toJson) + } + + test("check toJson method with complex types and collated strings") { + val structType = new StructType() + .add("a1", StringType.STRING, true) + .add("a2", new StructType() + .add("b1", new MapType( + new ArrayType( + new ArrayType( + new StringType("KERNEL.UTF8_LCASE"), true), true), + new StructType() + .add("c1", new StringType("KERNEL.UTF8_LCASE"), false) + .add("c2", new StringType("ICU.UNICODE"), true) + .add("c3", StringType.STRING), true)) + .add("b2", LongType.LONG), true) + .add("a3", new ArrayType( + new MapType( + new StringType("ICU.UNICODE_CI"), + new StructType() + .add("b1", new StringType("KERNEL.UTF8_LCASE"), false), false), false), true) + val toJson = + """{"type":"struct","fields":[{"name":"a1","type": + |"string","nullable":true,"metadata":{}},{"name": + |"a2","type":{"type":"struct","fields":[{"name": + |"b1","type":{"type":"map","keyType":{"type":"array", + |"elementType":{"type":"array","elementType":"string", + |"containsNull":true},"containsNull":true},"valueType": + |{"type":"struct","fields":[{"name":"c1","type":"string", + |"nullable":false,"metadata":{"__COLLATIONS": + |{"c1":"KERNEL.UTF8_LCASE"}}},{"name":"c2","type":"string", + |"nullable":true,"metadata":{"__COLLATIONS": + |{"c2":"ICU.UNICODE"}}},{"name":"c3","type":"string", + |"nullable":true,"metadata":{}}]},"valueContainsNull": + |true},"nullable":true,"metadata":{"__COLLATIONS": + |{"b1.key.element.element":"KERNEL.UTF8_LCASE"}}}, + |{"name":"b2","type":"long","nullable":true,"metadata": + |{}}]},"nullable":true,"metadata":{}},{"name":"a3", + |"type":{"type":"array","elementType":{"type":"map", + |"keyType":"string","valueType":{"type":"struct","fields": + |[{"name":"b1","type":"string","nullable":false,"metadata": + |{"__COLLATIONS":{"b1":"KERNEL.UTF8_LCASE"}}}]},"valueContainsNull": + |false},"containsNull":false},"nullable":true,"metadata": + |{"__COLLATIONS":{"a3.element.key":"ICU.UNICODE_CI"}}}]}""" + .stripMargin.replaceAll("\n", "") + + assert(structType.toJson == toJson) + } +} diff --git a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/CollationFactory.java b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/CollationFactory.java new file mode 100644 index 00000000000..f2daf276d38 --- /dev/null +++ b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/CollationFactory.java @@ -0,0 +1,254 @@ +package io.delta.kernel.defaults.internal.expressions; + +import com.ibm.icu.text.Collator; +import com.ibm.icu.util.ULocale; +import io.delta.kernel.expressions.CollationIdentifier; +import io.delta.kernel.internal.util.Tuple2; + +import java.util.*; +import java.util.concurrent.ConcurrentHashMap; +import java.util.function.BiFunction; + +import static io.delta.kernel.defaults.internal.expressions.CollationFactory.Collation.DEFAULT_COLLATION; +import static io.delta.kernel.defaults.internal.expressions.DefaultExpressionUtils.STRING_COMPARATOR; +import static io.delta.kernel.expressions.CollationIdentifier.*; + +public class CollationFactory { + private static final Map collationMap = new ConcurrentHashMap<>(); + + public static Collation fetchCollation(String collationName) { + if (collationName.startsWith("UTF8")) { + return fetchCollation(new CollationIdentifier(PROVIDER_SPARK, collationName)); + } else { + return fetchCollation(new CollationIdentifier(PROVIDER_ICU, collationName)); + } + } + + public static Collation fetchCollation(CollationIdentifier collationIdentifier) { + if (collationIdentifier.equals(DEFAULT_COLLATION_IDENTIFIER)) { + return DEFAULT_COLLATION; + } else if (collationMap.containsKey(collationIdentifier)) { + return collationMap.get(collationIdentifier); + } else { + Collation collation; + if (collationIdentifier.getProvider().equals(PROVIDER_SPARK)) { + collation = UTF8CollationFactory.fetchCollation(collationIdentifier); + } else if (collationIdentifier.getProvider().equals(PROVIDER_ICU)) { + collation = ICUCollationFactory.fetchCollation(collationIdentifier); + } else { + throw new IllegalArgumentException(String.format("Invalid collation provider: %s.", collationIdentifier.getProvider())); + } + collationMap.put(collationIdentifier, collation); + return collation; + } + } + + private static class UTF8CollationFactory { + private static Collation fetchCollation(CollationIdentifier collationIdentifier) { + if (collationIdentifier.equals(DEFAULT_COLLATION_IDENTIFIER)) { + return DEFAULT_COLLATION; + } else { + // TODO UTF8_LCASE + throw new IllegalArgumentException(String.format("Invalid collation identifier: %s.", collationIdentifier)); + } + } + } + + private static class ICUCollationFactory { + /** + * Bit 17 in collation ID having value 0 for case-sensitive and 1 for case-insensitive + * collation. + */ + private enum CaseSensitivity { + CS, CI + } + + /** + * Bit 16 in collation ID having value 0 for accent-sensitive and 1 for accent-insensitive + * collation. + */ + private enum AccentSensitivity { + AS, AI + } + + /** + * Mapping of locale names to corresponding `ULocale` instance. + */ + private static final Map ICULocaleMap = new HashMap<>(); + + private static Collation fetchCollation(CollationIdentifier collationIdentifier) { + if (collationIdentifier.getVersion().isPresent() && + !collationIdentifier.getVersion().get().equals(ICU_COLLATOR_VERSION)) { + throw new IllegalArgumentException(String.format("Invalid collation version: %s.", collationIdentifier.getVersion().get())); + } + + String locale = getICULocale(collationIdentifier); + + Tuple2 caseAndAccentSensitivity = getICUCaseAndAccentSensitivity(collationIdentifier, locale); + CaseSensitivity caseSensitivity = caseAndAccentSensitivity._1; + AccentSensitivity accentSensitivity = caseAndAccentSensitivity._2; + + Collator collator = getICUCollator(locale, caseSensitivity, accentSensitivity); + + return new Collation( + collationIdentifier, + collator::compare); + } + + private static String collationName(String locale, CaseSensitivity caseSensitivity, AccentSensitivity accentSensitivity) { + StringBuilder builder = new StringBuilder(); + builder.append(locale); + if (caseSensitivity != CaseSensitivity.CS) { + builder.append('_'); + builder.append(caseSensitivity.toString()); + } + if (accentSensitivity != AccentSensitivity.AS) { + builder.append('_'); + builder.append(accentSensitivity.toString()); + } + return builder.toString(); + } + + private static String getICULocale(CollationIdentifier collationIdentifier) { + String collationName = collationIdentifier.getName(); + String collationNameUpperCase = collationIdentifier.getName().toUpperCase(); + + // Search for the longest locale match because specifiers are designed to be different from + // script tag and country code, meaning the only valid locale name match can be the longest + // one. + int lastPos = -1; + for (int i = 1; i <= collationNameUpperCase.length(); i++) { + String localeName = collationNameUpperCase.substring(0, i); + if (ICULocaleMap.containsKey(localeName)) { + lastPos = i; + } + } + if (lastPos == -1) { + throw new IllegalArgumentException(String.format("Invalid collation name: %s.", collationIdentifier.toStringWithoutVersion())); + } else { + return collationName.substring(0, lastPos); + } + } + + private static Tuple2 getICUCaseAndAccentSensitivity(CollationIdentifier collationIdentifier, String locale) { + String collationName = collationIdentifier.getName(); + + // Try all combinations of AS/AI and CS/CI. + CaseSensitivity caseSensitivity; + AccentSensitivity accentSensitivity; + if (collationName.equals(locale) || + collationName.equals(locale + "_AS") || + collationName.equals(locale + "_CS") || + collationName.equals(locale + "_AS_CS") || + collationName.equals(locale + "_CS_AS") + ) { + caseSensitivity = CaseSensitivity.CS; + accentSensitivity = AccentSensitivity.AS; + } else if (collationName.equals(locale + "_CI") || + collationName.equals(locale + "_AS_CI") || + collationName.equals(locale + "_CI_AS")) { + caseSensitivity = CaseSensitivity.CI; + accentSensitivity = AccentSensitivity.AS; + } else if (collationName.equals(locale + "_AI") || + collationName.equals(locale + "_CS_AI") || + collationName.equals(locale + "_AI_CS")) { + caseSensitivity = CaseSensitivity.CS; + accentSensitivity = AccentSensitivity.AI; + } else if (collationName.equals(locale + "_AI_CI") || + collationName.equals(locale + "_CI_AI")) { + caseSensitivity = CaseSensitivity.CI; + accentSensitivity = AccentSensitivity.AI; + } else { + throw new IllegalArgumentException(String.format("Invalid collation name: %s.", collationIdentifier.toStringWithoutVersion())); + } + + return new Tuple2<>(caseSensitivity, accentSensitivity); + } + + private static Collator getICUCollator(String locale, CaseSensitivity caseSensitivity, AccentSensitivity accentSensitivity) { + ULocale.Builder builder = new ULocale.Builder(); + builder.setLocale(ICULocaleMap.get(locale)); + // Compute unicode locale keyword for all combinations of case/accent sensitivity. + if (caseSensitivity == CaseSensitivity.CS && + accentSensitivity == AccentSensitivity.AS) { + builder.setUnicodeLocaleKeyword("ks", "level3"); + } else if (caseSensitivity == CaseSensitivity.CS && + accentSensitivity == AccentSensitivity.AI) { + builder + .setUnicodeLocaleKeyword("ks", "level1") + .setUnicodeLocaleKeyword("kc", "true"); + } else if (caseSensitivity == CaseSensitivity.CI && + accentSensitivity == AccentSensitivity.AS) { + builder.setUnicodeLocaleKeyword("ks", "level2"); + } else if (caseSensitivity == CaseSensitivity.CI && + accentSensitivity == AccentSensitivity.AI) { + builder.setUnicodeLocaleKeyword("ks", "level1"); + } + ULocale resultLocale = builder.build(); + Collator collator = Collator.getInstance(resultLocale); + // Freeze ICU collator to ensure thread safety. + collator.freeze(); + return collator; + } + + static { + ICULocaleMap.put("UNICODE", ULocale.ROOT); + // ICU-implemented `ULocale`s which have corresponding `Collator` installed. + ULocale[] locales = Collator.getAvailableULocales(); + // Build locale names in format: language["_" optional script]["_" optional country code]. + // Examples: en, en_USA, sr_Cyrl_SRB + for (ULocale locale : locales) { + // Skip variants. + if (locale.getVariant().isEmpty()) { + String language = locale.getLanguage(); + // Require non-empty language as first component of locale name. + assert (!language.isEmpty()); + StringBuilder builder = new StringBuilder(language); + // Script tag. + String script = locale.getScript(); + if (!script.isEmpty()) { + builder.append('_'); + builder.append(script); + } + // 3-letter country code. + String country = locale.getISO3Country(); + if (!country.isEmpty()) { + builder.append('_'); + builder.append(country); + } + String localeName = builder.toString(); + // Verify locale names are unique. + assert (!ICULocaleMap.containsKey(localeName.toUpperCase())); + ICULocaleMap.put(localeName.toUpperCase(), locale); + } + } + } + } + + public static class Collation { + + public static Collation DEFAULT_COLLATION = new Collation(DEFAULT_COLLATION_IDENTIFIER, STRING_COMPARATOR); + + public Collation(CollationIdentifier collationIdentifier, Comparator collationComparator) { + this.identifier = collationIdentifier; + this.comparator = collationComparator; + this.equalsFunction = (s1, s2) -> this.comparator.compare(s1, s2) == 0; + } + + public CollationIdentifier getCollationIdentifier() { + return identifier; + } + + public Comparator getComparator() { + return comparator; + } + + public BiFunction getEqualsFunction() { + return equalsFunction; + } + + private final CollationIdentifier identifier; + private final Comparator comparator; + private final BiFunction equalsFunction; + } +} diff --git a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/DefaultExpressionEvaluator.java b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/DefaultExpressionEvaluator.java index 2fd009b79cb..7617ad22cd1 100644 --- a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/DefaultExpressionEvaluator.java +++ b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/DefaultExpressionEvaluator.java @@ -329,7 +329,11 @@ private Expression transformBinaryComparator(Predicate predicate) { throw unsupportedExpressionException(predicate, msg); } } - return new Predicate(predicate.getName(), left, right); + if (predicate instanceof CollatedPredicate) { + return new CollatedPredicate(predicate.getName(), left, right, ((CollatedPredicate) predicate).getCollationIdentifier()); + } else { + return new Predicate(predicate.getName(), left, right); + } } } @@ -427,35 +431,19 @@ ColumnVector visitComparator(Predicate predicate) { PredicateChildrenEvalResult argResults = evalBinaryExpressionChildren(predicate); switch (predicate.getName()) { case "=": - return comparatorVector( - argResults.leftResult, - argResults.rightResult, - (compareResult) -> (compareResult == 0)); case ">": - return comparatorVector( - argResults.leftResult, - argResults.rightResult, - (compareResult) -> (compareResult > 0)); case ">=": - return comparatorVector( - argResults.leftResult, - argResults.rightResult, - (compareResult) -> (compareResult >= 0)); case "<": - return comparatorVector( - argResults.leftResult, - argResults.rightResult, - (compareResult) -> (compareResult < 0)); case "<=": return comparatorVector( argResults.leftResult, argResults.rightResult, - (compareResult) -> (compareResult <= 0)); + predicate); case "IS NOT DISTINCT FROM": return nullSafeComparatorVector( argResults.leftResult, argResults.rightResult, - (compareResult) -> (compareResult == 0)); + predicate); default: // We should never reach this based on the ExpressionVisitor throw new IllegalStateException( diff --git a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/DefaultExpressionUtils.java b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/DefaultExpressionUtils.java index bf7c98fdc84..76b3247e946 100644 --- a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/DefaultExpressionUtils.java +++ b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/DefaultExpressionUtils.java @@ -20,7 +20,9 @@ import io.delta.kernel.data.ArrayValue; import io.delta.kernel.data.ColumnVector; import io.delta.kernel.data.MapValue; +import io.delta.kernel.expressions.CollatedPredicate; import io.delta.kernel.expressions.Expression; +import io.delta.kernel.expressions.Predicate; import io.delta.kernel.internal.util.Utils; import io.delta.kernel.types.*; import java.math.BigDecimal; @@ -113,19 +115,41 @@ public boolean getBoolean(int rowId) { }; } + /** + * Matches predicate name to comparison function. + */ + static IntPredicate getBooleanPredicate(Predicate predicate) { + switch(predicate.getName()) { + case "=": + case "IS NOT DISTINCT FROM": + return x -> x == 0; + case "<": + return x -> x < 0; + case "<=": + return x -> x <= 0; + case ">": + return x -> x > 0; + case ">=": + return x -> x >= 0; + default: + throw new IllegalArgumentException(String.format("Unsupported predicate '%s'", predicate.getName())); + } + } + /** * Utility method for getting value comparator * * @param left * @param right - * @param booleanComparator + * @param predicate * @return */ static IntPredicate getComparator( - ColumnVector left, ColumnVector right, IntPredicate booleanComparator) { + ColumnVector left, ColumnVector right, Predicate predicate) { checkArgument( left.getSize() == right.getSize(), "Left and right operand have different vector sizes."); + IntPredicate booleanComparator = getBooleanPredicate(predicate); DataType dataType = left.getDataType(); IntPredicate vectorValueComparator; if (dataType instanceof BooleanType) { @@ -162,10 +186,18 @@ static IntPredicate getComparator( booleanComparator.test( BIGDECIMAL_COMPARATOR.compare(left.getDecimal(rowId), right.getDecimal(rowId))); } else if (dataType instanceof StringType) { - vectorValueComparator = - rowId -> - booleanComparator.test( - STRING_COMPARATOR.compare(left.getString(rowId), right.getString(rowId))); + if (predicate instanceof CollatedPredicate) { + vectorValueComparator = + rowId -> + booleanComparator.test( + CollationFactory.fetchCollation(((CollatedPredicate) predicate).getCollationIdentifier()) + .getComparator().compare(left.getString(rowId), right.getString(rowId))); + } else { + vectorValueComparator = + rowId -> + booleanComparator.test( + STRING_COMPARATOR.compare(left.getString(rowId), right.getString(rowId))); + } } else if (dataType instanceof BinaryType) { vectorValueComparator = rowId -> @@ -185,8 +217,8 @@ static IntPredicate getComparator( *

Only primitive data types are supported. */ static ColumnVector comparatorVector( - ColumnVector left, ColumnVector right, IntPredicate booleanComparator) { - IntPredicate vectorValueComparator = getComparator(left, right, booleanComparator); + ColumnVector left, ColumnVector right, Predicate predicate) { + IntPredicate vectorValueComparator = getComparator(left, right, predicate); return new ColumnVector() { @@ -227,8 +259,8 @@ public boolean getBoolean(int rowId) { *

Only primitive data types are supported. */ static ColumnVector nullSafeComparatorVector( - ColumnVector left, ColumnVector right, IntPredicate booleanComparator) { - IntPredicate vectorValueComparator = getComparator(left, right, booleanComparator); + ColumnVector left, ColumnVector right, Predicate predicate) { + IntPredicate vectorValueComparator = getComparator(left, right, predicate); return new ColumnVector() { @Override public DataType getDataType() { diff --git a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/ExpressionVisitor.java b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/ExpressionVisitor.java index cf67d87dd23..0cb29ced8a2 100644 --- a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/ExpressionVisitor.java +++ b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/ExpressionVisitor.java @@ -98,7 +98,11 @@ private R visitScalarExpression(ScalarExpression expression) { case ">": case ">=": case "IS NOT DISTINCT FROM": - return visitComparator(new Predicate(name, children)); + if (expression instanceof Predicate) { + return visitComparator((Predicate) expression); + } else { + return visitComparator(new Predicate(name, children)); + } case "ELEMENT_AT": return visitElementAt(expression); case "NOT": diff --git a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/DeltaTableWritesSuite.scala b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/DeltaTableWritesSuite.scala index 2ce7d08d94e..2bd261fdd7e 100644 --- a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/DeltaTableWritesSuite.scala +++ b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/DeltaTableWritesSuite.scala @@ -18,17 +18,19 @@ package io.delta.kernel.defaults import io.delta.golden.GoldenTableUtils.goldenTablePath import io.delta.kernel.Operation.{CREATE_TABLE, WRITE} import io.delta.kernel._ -import io.delta.kernel.data.{ColumnarBatch, FilteredColumnarBatch, Row} +import io.delta.kernel.data.{ColumnVector, ColumnarBatch, FilteredColumnarBatch, Row} +import io.delta.kernel.defaults.internal.data.DefaultColumnarBatch +import io.delta.kernel.defaults.internal.expressions.DefaultExpressionEvaluator import io.delta.kernel.defaults.internal.parquet.ParquetSuiteBase import io.delta.kernel.defaults.utils.TestRow import io.delta.kernel.engine.Engine import io.delta.kernel.exceptions._ -import io.delta.kernel.expressions.Literal +import io.delta.kernel.expressions.{CollatedPredicate, CollationIdentifier, Column, Expression, Literal, ScalarExpression} import io.delta.kernel.expressions.Literal._ import io.delta.kernel.internal.checkpoints.CheckpointerSuite.selectSingleElement import io.delta.kernel.internal.util.SchemaUtils.casePreservingPartitionColNames import io.delta.kernel.internal.{SnapshotImpl, TableConfig} -import io.delta.kernel.internal.util.ColumnMapping +import io.delta.kernel.internal.util.{ColumnMapping, VectorUtils} import io.delta.kernel.types.DateType.DATE import io.delta.kernel.types.DoubleType.DOUBLE import io.delta.kernel.types.IntegerType.INTEGER diff --git a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/internal/expressions/CollationFactorySuite.scala b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/internal/expressions/CollationFactorySuite.scala new file mode 100644 index 00000000000..f1aecb74e7e --- /dev/null +++ b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/internal/expressions/CollationFactorySuite.scala @@ -0,0 +1,77 @@ +package io.delta.kernel.defaults.internal.expressions + +import io.delta.kernel.defaults.internal.expressions.CollationFactory.fetchCollation +import io.delta.kernel.expressions.CollationIdentifier +import org.scalatest.funsuite.AnyFunSuite + +import java.util.function.BiFunction + +class CollationFactorySuite extends AnyFunSuite { + test("basic ICU collator checks") { + // scalastyle:off nonascii + Seq( + CollationTestCase("UNICODE_CI", "a", "A", true), + CollationTestCase("UNICODE_CI", "a", "å", false), + CollationTestCase("UNICODE_CI", "a", "Å", false), + CollationTestCase("UNICODE_AI", "a", "A", false), + CollationTestCase("UNICODE_AI", "a", "å", true), + CollationTestCase("UNICODE_AI", "a", "Å", false), + CollationTestCase("UNICODE_CI_AI", "a", "A", true), + CollationTestCase("UNICODE_CI_AI", "a", "å", true), + CollationTestCase("UNICODE_CI_AI", "a", "Å", true) + ).foreach { + testCase => + assert(testCase.equalsFunction(testCase.s1, testCase.s2) == testCase.expectedResult) + } + + Seq( + CollationTestCase("en", "a", "A", -1), + CollationTestCase("en_CI", "a", "A", 0), + CollationTestCase("en_AI", "a", "å", 0), + CollationTestCase("sv", "Kypper", "Köpfe", -1), + CollationTestCase("de", "Kypper", "Köpfe", 1) + ).foreach { + testCase => + assert( + Integer.signum(testCase.compare(testCase.s1, testCase.s2)) == testCase.expectedResult) + } + // scalastyle:on nonascii + } + + test("collation aware compare") { + Seq( + CollationTestCase("UTF8_BINARY", "aaa", "aaa", 0), + CollationTestCase("UTF8_BINARY", "aaa", "AAA", 1), + CollationTestCase("UTF8_BINARY", "aaa", "bbb", -1), + CollationTestCase("UTF8_BINARY", "aaa", "BBB", 1), + CollationTestCase("UNICODE", "aaa", "aaa", 0), + CollationTestCase("UNICODE", "aaa", "AAA", -1), + CollationTestCase("UNICODE", "aaa", "bbb", -1), + CollationTestCase("UNICODE", "aaa", "BBB", -1), + CollationTestCase("UNICODE_CI", "aaa", "aaa", 0), + CollationTestCase("UNICODE_CI", "aaa", "AAA", 0), + CollationTestCase("UNICODE_CI", "aaa", "bbb", -1) + ).foreach { + testCase => + assert( + Integer.signum(testCase.compare(testCase.s1, testCase.s2)) == testCase.expectedResult) + } + } + + case class CollationTestCase[R](collationName: String, + s1: String, + s2: String, + expectedResult: R) { + val provider = + if (collationName.startsWith("UTF8")) { + CollationIdentifier.PROVIDER_SPARK + } else { + CollationIdentifier.PROVIDER_ICU + } + val fullCollationName = String.format("%s.%s", provider, collationName) + val collationIdentifier = CollationIdentifier.fromString(fullCollationName) + val collation = fetchCollation(collationIdentifier) + val equalsFunction: BiFunction[String, String, java.lang.Boolean] = collation.getEqualsFunction + val compare: (String, String) => Int = collation.getComparator.compare + } +} diff --git a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/internal/expressions/DefaultExpressionEvaluatorSuite.scala b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/internal/expressions/DefaultExpressionEvaluatorSuite.scala index bb1cc6b6809..8b750c6316f 100644 --- a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/internal/expressions/DefaultExpressionEvaluatorSuite.scala +++ b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/internal/expressions/DefaultExpressionEvaluatorSuite.scala @@ -773,6 +773,111 @@ class DefaultExpressionEvaluatorSuite extends AnyFunSuite with ExpressionSuiteBa ofNull(DoubleType.DOUBLE) ) + test("evaluate expression: compare collated strings") { + // scalastyle:off nonascii + // TODO: add UTF8_LCASE when supported + Seq( + "UTF8_BINARY", + "UNICODE", + "UNICODE_CI" + ).foreach { + collationName => + // Empty strings. + assertCompare("", "", collationName, 0) + assertCompare("a", "", collationName, 1) + assertCompare("", "a", collationName, -1) + // Basic tests. + assertCompare("a", "a", collationName, 0) + assertCompare("a", "b", collationName, -1) + assertCompare("b", "a", collationName, 1) + assertCompare("A", "A", collationName, 0) + assertCompare("A", "B", collationName, -1) + assertCompare("B", "A", collationName, 1) + assertCompare("aa", "a", collationName, 1) + assertCompare("b", "bb", collationName, -1) + assertCompare("abc", "a", collationName, 1) + assertCompare("abc", "b", collationName, -1) + assertCompare("abc", "ab", collationName, 1) + assertCompare("abc", "abc", collationName, 0) + assertCompare("aaaa", "aaa", collationName, 1) + assertCompare("hello", "world", collationName, -1) + assertCompare("Spark", "Spark", collationName, 0) + assertCompare("ü", "ü", collationName, 0) + assertCompare("ü", "", collationName, 1) + assertCompare("", "ü", collationName, -1) + assertCompare("äü", "äü", collationName, 0) + assertCompare("äxx", "äx", collationName, 1) + assertCompare("a", "ä", collationName, -1) + } + + // Advanced tests. + assertCompare("äü", "bü", "UTF8_BINARY", 1) + assertCompare("bxx", "bü", "UTF8_BINARY", -1) + assertCompare("äü", "bü", "UNICODE", -1) + assertCompare("bxx", "bü", "UNICODE", 1) + assertCompare("äü", "bü", "UNICODE_CI", -1) + assertCompare("bxx", "bü", "UNICODE_CI", 1) + // Case variation. + assertCompare("AbCd", "aBcD", "UTF8_BINARY", -1) + assertCompare("AbcD", "aBCd", "UNICODE", 1) + assertCompare("abcd", "ABCD", "UNICODE_CI", 0) + // Accent variation. + assertCompare("aBćD", "ABĆD", "UTF8_BINARY", 1) + assertCompare("äBCd", "ÄBCD", "UNICODE", -1) + assertCompare("Ab́cD", "AB́CD", "UNICODE_CI", 0) + // One-to-many case mapping (e.g. Turkish dotted I). + assertCompare("i\u0307", "İ", "UTF8_BINARY", -1) + assertCompare("İ", "i\u0307", "UTF8_BINARY", 1) + assertCompare("i\u0307", "İ", "UNICODE", -1) + assertCompare("İ", "i\u0307", "UNICODE", 1) + assertCompare("i\u0307", "İ", "UNICODE_CI", 0) + assertCompare("İ", "i\u0307", "UNICODE_CI", 0) + assertCompare("i\u0307İ", "i\u0307İ", "UNICODE_CI", 0) + assertCompare("i\u0307İ", "İi\u0307", "UNICODE_CI", 0) + assertCompare("İi\u0307", "i\u0307İ", "UNICODE_CI", 0) + assertCompare("İi\u0307", "İi\u0307", "UNICODE_CI", 0) + // Conditional case mapping (e.g. Greek sigmas). + assertCompare("ς", "σ", "UTF8_BINARY", -1) + assertCompare("ς", "Σ", "UTF8_BINARY", 1) + assertCompare("σ", "Σ", "UTF8_BINARY", 1) + assertCompare("ς", "σ", "UNICODE", 1) + assertCompare("ς", "Σ", "UNICODE", 1) + assertCompare("σ", "Σ", "UNICODE", -1) + assertCompare("ς", "σ", "UNICODE_CI", 0) + assertCompare("ς", "Σ", "UNICODE_CI", 0) + assertCompare("σ", "Σ", "UNICODE_CI", 0) + // Surrogate pairs. + assertCompare("a🙃b🙃c", "aaaaa", "UTF8_BINARY", 1) + assertCompare("a🙃b🙃c", "aaaaa", "UNICODE", -1) // != UTF8_BINARY + assertCompare("a🙃b🙃c", "aaaaa", "UNICODE_CI", -1) // != UTF8_LCASE + assertCompare("a🙃b🙃c", "a🙃b🙃c", "UTF8_BINARY", 0) + assertCompare("a🙃b🙃c", "a🙃b🙃c", "UNICODE", 0) + assertCompare("a🙃b🙃c", "a🙃b🙃c", "UNICODE_CI", 0) + assertCompare("a🙃b🙃c", "a🙃b🙃d", "UTF8_BINARY", -1) + assertCompare("a🙃b🙃c", "a🙃b🙃d", "UNICODE", -1) + assertCompare("a🙃b🙃c", "a🙃b🙃d", "UNICODE_CI", -1) + // scalastyle:on nonascii + + // Maximum code point. + val maxCodePoint = Character.MAX_CODE_POINT + val maxCodePointStr = new String(Character.toChars(maxCodePoint)) + var i = 0 + while (i < maxCodePoint && Character.isValidCodePoint(i)) { + assertCompare(new String(Character.toChars(i)), maxCodePointStr, "UTF8_BINARY", -1) + + i += 1 + } + // Minimum code point.// Minimum code point. + val minCodePoint = Character.MIN_CODE_POINT + val minCodePointStr = new String(Character.toChars(minCodePoint)) + i = minCodePoint + 1 + while (i <= maxCodePoint && Character.isValidCodePoint(i)) { + assertCompare(new String(Character.toChars(i)), minCodePointStr, "UTF8_BINARY", 1) + + i += 1 + } + } + test("evaluate expression: comparators `byte` with other implicit types") { // Mapping of comparator to expected results for: // (byte, short), (byte, int), (byte, long), (byte, float), (byte, double) @@ -1021,9 +1126,42 @@ class DefaultExpressionEvaluatorSuite extends AnyFunSuite with ExpressionSuiteBa new DefaultExpressionEvaluator(inputSchema, expression, outputType) } + private def assertCompare(s1: String, s2: String, collationName: String, expResult: Int): Unit = { + val l1 = ofString(s1) + val l2 = ofString(s2) + if (expResult == -1) { + testComparator("<", l1, l2, true, Some(collationName)) + } else if (expResult == 1) { + testComparator("<", l2, l1, true, Some(collationName)) + } else if (expResult == 0) { + testComparator("=", l1, l2, true, Some(collationName)) + } else { + throw new IllegalArgumentException( + String.format("Invalid expected result: %s.", expResult.toString)) + } + } + private def testComparator( - comparator: String, left: Expression, right: Expression, expResult: BooleanJ): Unit = { - val expression = new Predicate(comparator, left, right) + comparator: String, + left: Expression, + right: Expression, + expResult: BooleanJ, + collationName: Option[String] = Option.empty): Unit = { + val expression = + if (collationName.isEmpty) { + new Predicate(comparator, left, right) + } else { + val collationIdentifier = + if (collationName.get.startsWith("UTF8")) { + CollationIdentifier.fromString( + String.format("%s.%s", CollationIdentifier.PROVIDER_SPARK, collationName.get)) + } else { + CollationIdentifier.fromString( + String.format("%s.%s", CollationIdentifier.PROVIDER_ICU, collationName.get)) + } + new CollatedPredicate( + comparator, left, right, collationIdentifier) + } val batch = zeroColumnBatch(rowCount = 1) val outputVector = evaluator(batch.getSchema, expression, BooleanType.BOOLEAN).eval(batch) diff --git a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/internal/expressions/DefaultPredicateEvaluatorSuite.scala b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/internal/expressions/DefaultPredicateEvaluatorSuite.scala index a820c420bb7..82a625262ac 100644 --- a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/internal/expressions/DefaultPredicateEvaluatorSuite.scala +++ b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/internal/expressions/DefaultPredicateEvaluatorSuite.scala @@ -18,10 +18,9 @@ package io.delta.kernel.defaults.internal.expressions import java.lang.{Boolean => BooleanJ} import java.util.Optional import java.util.Optional.empty - -import io.delta.kernel.data.{ColumnarBatch, ColumnVector} +import io.delta.kernel.data.{ColumnVector, ColumnarBatch} import io.delta.kernel.defaults.internal.data.DefaultColumnarBatch -import io.delta.kernel.expressions.{Column, Literal} +import io.delta.kernel.expressions.{CollatedPredicate, CollationIdentifier, Column, Literal} import io.delta.kernel.types.{BooleanType, StructType} import org.scalatest.funsuite.AnyFunSuite