Skip to content

Commit

Permalink
[SPARK-49863][SQL] Fix NormalizeFloatingNumbers to preserve nullabili…
Browse files Browse the repository at this point in the history
…ty of nested structs

### What changes were proposed in this pull request?

- Fixes a bug in `NormalizeFloatingNumbers` to respect the `nullable` attribute of nested expressions when normalizing.

### Why are the changes needed?

- Without the fix, there would be a degradation in the nullability of the expression post normalization.
- For example, for an expression like: `namedStruct("struct", namedStruct("double", <DoubleType-field>)) ` with the following data type:

```
StructType(StructField("struct", StructType(StructField("double", DoubleType, true, {})), false, {}))
```

after normalizing we would have ended up with the dataType:
```
StructType(StructField("struct", StructType(StructField("double", DoubleType, true, {})), true, {}))
```

Note, the change in the `nullable` attribute of the "double" StructField from `false` to `true`.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

- Added unit test.

### Was this patch authored or co-authored using generative AI tooling?

No

Closes apache#48331 from nikhilsheoran-db/SPARK-49863-fix.

Authored-by: Nikhil Sheoran <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
  • Loading branch information
nikhilsheoran-db authored and cloud-fan committed Oct 9, 2024
1 parent c1f18a0 commit 5e27eec
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,17 @@ object NormalizeFloatingNumbers extends Rule[LogicalPlan] {
case (name, i) => Seq(Literal(name), normalize(GetStructField(expr, i)))
}
val struct = CreateNamedStruct(fields.flatten.toImmutableArraySeq)
KnownFloatingPointNormalized(If(IsNull(expr), Literal(null, struct.dataType), struct))
// For nested structs (and other complex types), this branch is called again with either a
// `GetStructField` or a `NamedLambdaVariable` expression. Even if the field for which this
// has been recursively called might have `nullable = false`, directly creating an `If`
// predicate would end up creating an expression with `nullable = true` (as the trueBranch is
// nullable). Hence, use the `expr.nullable` to create an `If` predicate only when the column
// is nullable.
if (expr.nullable) {
KnownFloatingPointNormalized(If(IsNull(expr), Literal(null, struct.dataType), struct))
} else {
KnownFloatingPointNormalized(struct)
}

case _ if expr.dataType.isInstanceOf[ArrayType] =>
val ArrayType(et, containsNull) = expr.dataType
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,5 +124,13 @@ class NormalizeFloatingPointNumbersSuite extends PlanTest {

comparePlans(doubleOptimized, correctAnswer)
}

test("SPARK-49863: NormalizeFloatingNumbers preserves nullability for nested struct") {
val relation = LocalRelation($"a".double, $"b".string)
val nestedExpr = namedStruct("struct", namedStruct("double", relation.output.head))
.as("nestedExpr").toAttribute
val normalizedExpr = NormalizeFloatingNumbers.normalize(nestedExpr)
assert(nestedExpr.dataType == normalizedExpr.dataType)
}
}

0 comments on commit 5e27eec

Please sign in to comment.