diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuJsonToStructs.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuJsonToStructs.scala
index 3447c91d861..e50654255ad 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuJsonToStructs.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuJsonToStructs.scala
@@ -37,6 +37,8 @@ case class GpuJsonToStructs(
     extends GpuUnaryExpression with TimeZoneAwareExpression with ExpectsInputTypes
         with NullIntolerant {
 
+  lazy val emptyRowStr = constructEmptyRow(schema)
+
   private def constructEmptyRow(schema: DataType): String = {
     schema match {
       case struct: StructType if struct.fields.nonEmpty =>
@@ -45,70 +47,49 @@ case class GpuJsonToStructs(
         throw new IllegalArgumentException(s"$other is not supported as a top level type")
     }
   }
-  lazy val emptyRowStr = constructEmptyRow(schema)
-
   private def cleanAndConcat(input: cudf.ColumnVector): (cudf.ColumnVector, cudf.ColumnVector) = {
-    withResource(cudf.Scalar.fromString(emptyRowStr)) { emptyRow =>
+    val stripped = if (input.getData == null) {
+      input.incRefCount
+    } else {
+      withResource(cudf.Scalar.fromString(" ")) { space =>
+        input.strip(space)
+      }
+    }
 
-      val stripped = if (input.getData == null) {
-        input.incRefCount
-      } else {
-        withResource(cudf.Scalar.fromString(" ")) { space =>
-          input.strip(space)
+    withResource(stripped) { stripped =>
+      val isEmpty = withResource(stripped.getCharLengths) { lengths =>
+        withResource(cudf.Scalar.fromInt(0)) { zero =>
+          lengths.lessOrEqualTo(zero)
         }
       }
-
-      withResource(stripped) { stripped =>
-        val isNullOrEmptyInput = withResource(input.isNull) { isNull =>
-          val isEmpty = withResource(stripped.getCharLengths) { lengths =>
-            withResource(cudf.Scalar.fromInt(0)) { zero =>
-              lengths.lessOrEqualTo(zero)
-            }
-          }
-          withResource(isEmpty) { isEmpty =>
-            isNull.binaryOp(cudf.BinaryOp.NULL_LOGICAL_OR, isEmpty, cudf.DType.BOOL8)
-          }
+      val isNullOrEmptyInput = withResource(isEmpty) { _ =>
+        withResource(input.isNull) { isNull =>
+          isNull.binaryOp(cudf.BinaryOp.NULL_LOGICAL_OR, isEmpty, cudf.DType.BOOL8)
         }
-        closeOnExcept(isNullOrEmptyInput) { _ =>
+      }
+      closeOnExcept(isNullOrEmptyInput) { _ =>
+        withResource(cudf.Scalar.fromString(emptyRowStr)) { emptyRow =>
           withResource(isNullOrEmptyInput.ifElse(emptyRow, stripped)) { nullsReplaced =>
             val isLiteralNull = withResource(Scalar.fromString("null")) { literalNull =>
               nullsReplaced.equalTo(literalNull)
             }
             withResource(isLiteralNull) { _ =>
               withResource(isLiteralNull.ifElse(emptyRow, nullsReplaced)) { cleaned =>
-                withResource(cudf.Scalar.fromString("\n")) { lineSep =>
-                  withResource(cudf.Scalar.fromString("\r")) { returnSep =>
-                    withResource(cleaned.stringContains(lineSep)) { inputHas =>
-                      withResource(inputHas.any()) { anyLineSep =>
-                        if (anyLineSep.isValid && anyLineSep.getBoolean) {
-                          throw new IllegalArgumentException(
-                            "We cannot currently support parsing " +
-                            "JSON that contains a line separator in it")
-                        }
-                      }
-                    }
-                    withResource(cleaned.stringContains(returnSep)) { inputHas =>
-                      withResource(inputHas.any()) { anyReturnSep =>
-                        if (anyReturnSep.isValid && anyReturnSep.getBoolean) {
-                          throw new IllegalArgumentException(
-                            "We cannot currently support parsing " +
-                            "JSON that contains a carriage return in it")
-                        }
-                      }
-                    }
-                  }
-
-                  // if the last entry in a column is incomplete or invalid, then cuDF
-                  // will drop the row rather than replace with null if there is no newline, so we
-                  // add a newline here to prevent that
-                  val joined = withResource(cleaned.joinStrings(lineSep, emptyRow)) { joined =>
-                    withResource(ColumnVector.fromStrings("\n")) { newline =>
-                      ColumnVector.stringConcatenate(Array[ColumnView](joined, newline))
-                    }
+                checkForNewline(cleaned, "\n", "line separator")
+                checkForNewline(cleaned, "\r", "carriage return")
+                // if the last entry in a column is incomplete or invalid, then cuDF
+                // will drop the row rather than replace with null if there is no newline, so we
+                // add a newline here to prevent that
+                val joined = withResource(cudf.Scalar.fromString("\n")) { lineSep =>
+                  cleaned.joinStrings(lineSep, emptyRow)
+                }
+                val concat = withResource(joined) { _ =>
+                  withResource(ColumnVector.fromStrings("\n")) { newline =>
+                    ColumnVector.stringConcatenate(Array[ColumnView](joined, newline))
                   }
-
-                  (isNullOrEmptyInput, joined)
                 }
+
+                (isNullOrEmptyInput, concat)
               }
             }
           }
@@ -117,6 +98,19 @@ case class GpuJsonToStructs(
     }
   }
 
+  private def checkForNewline(cleaned: ColumnVector, newlineStr: String, name: String): Unit = {
+    withResource(cudf.Scalar.fromString(newlineStr)) { newline =>
+      withResource(cleaned.stringContains(newline)) { hasNewline =>
+        withResource(hasNewline.any()) { anyNewline =>
+          if (anyNewline.isValid && anyNewline.getBoolean) {
+            throw new IllegalArgumentException(
+              s"We cannot currently support parsing JSON that contains a $name in it")
+          }
+        }
+      }
+    }
+  }
+
   // Process a sequence of field names. If there are duplicated field names, we only keep the field
   // name with the largest index in the sequence, for others, replace the field names with null.
   // Example: