Skip to content

Commit

Permalink
[SPARK-49894][PYTHON][CONNECT] Refine the string representation of co…
Browse files Browse the repository at this point in the history
…lumn field operations

### What changes were proposed in this pull request?
Refine the string representation of column field operations: `GetField`, `WithField`, and `DropFields`

### Why are the changes needed?
make the string representations consistent between pyspark classic and connect

### Does this PR introduce _any_ user-facing change?
yes

before
```
In [1]: from pyspark.sql import functions as sf

In [2]: c = sf.col("c")

In [3]: c.x
Out[3]: Column<'UnresolvedExtractValue(c, x)'>
```

after
```
In [1]: from pyspark.sql import functions as sf

In [2]: c = sf.col("c")

In [3]: c.x
Out[3]: Column<'c['x']'>
```

### How was this patch tested?
added ut

### Was this patch authored or co-authored using generative AI tooling?
no

Closes #48369 from zhengruifeng/py_connect_col_str.

Lead-authored-by: Ruifeng Zheng <[email protected]>
Co-authored-by: Hyukjin Kwon <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
  • Loading branch information
zhengruifeng and HyukjinKwon committed Oct 8, 2024
1 parent d8aca18 commit c6b09c0
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 3 deletions.
6 changes: 3 additions & 3 deletions python/pyspark/sql/connect/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -809,7 +809,7 @@ def to_plan(self, session: "SparkConnectClient") -> proto.Expression:
return expr

def __repr__(self) -> str:
return f"WithField({self._structExpr}, {self._fieldName}, {self._valueExpr})"
return f"update_field({self._structExpr}, {self._fieldName}, {self._valueExpr})"


class DropField(Expression):
Expand All @@ -833,7 +833,7 @@ def to_plan(self, session: "SparkConnectClient") -> proto.Expression:
return expr

def __repr__(self) -> str:
return f"DropField({self._structExpr}, {self._fieldName})"
return f"drop_field({self._structExpr}, {self._fieldName})"


class UnresolvedExtractValue(Expression):
Expand All @@ -857,7 +857,7 @@ def to_plan(self, session: "SparkConnectClient") -> proto.Expression:
return expr

def __repr__(self) -> str:
return f"UnresolvedExtractValue({str(self._child)}, {str(self._extraction)})"
return f"{self._child}['{self._extraction}']"


class UnresolvedRegex(Expression):
Expand Down
71 changes: 71 additions & 0 deletions python/pyspark/sql/tests/test_column.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,77 @@ def test_expr_str_representation(self):
when_cond = sf.when(expression, sf.lit(None))
self.assertEqual(str(when_cond), "Column<'CASE WHEN foo THEN NULL END'>")

def test_col_field_ops_representation(self):
# SPARK-49894: Test string representation of columns
c = sf.col("c")

# getField
self.assertEqual(str(c.x), "Column<'c['x']'>")
self.assertEqual(str(c.x.y), "Column<'c['x']['y']'>")
self.assertEqual(str(c.x.y.z), "Column<'c['x']['y']['z']'>")

self.assertEqual(str(c["x"]), "Column<'c['x']'>")
self.assertEqual(str(c["x"]["y"]), "Column<'c['x']['y']'>")
self.assertEqual(str(c["x"]["y"]["z"]), "Column<'c['x']['y']['z']'>")

self.assertEqual(str(c.getField("x")), "Column<'c['x']'>")
self.assertEqual(
str(c.getField("x").getField("y")),
"Column<'c['x']['y']'>",
)
self.assertEqual(
str(c.getField("x").getField("y").getField("z")),
"Column<'c['x']['y']['z']'>",
)

self.assertEqual(str(c.getItem("x")), "Column<'c['x']'>")
self.assertEqual(
str(c.getItem("x").getItem("y")),
"Column<'c['x']['y']'>",
)
self.assertEqual(
str(c.getItem("x").getItem("y").getItem("z")),
"Column<'c['x']['y']['z']'>",
)

self.assertEqual(
str(c.x["y"].getItem("z")),
"Column<'c['x']['y']['z']'>",
)
self.assertEqual(
str(c["x"].getField("y").getItem("z")),
"Column<'c['x']['y']['z']'>",
)
self.assertEqual(
str(c.getField("x").getItem("y").z),
"Column<'c['x']['y']['z']'>",
)
self.assertEqual(
str(c["x"].y.getField("z")),
"Column<'c['x']['y']['z']'>",
)

# WithField
self.assertEqual(
str(c.withField("x", sf.col("y"))),
"Column<'update_field(c, x, y)'>",
)
self.assertEqual(
str(c.withField("x", sf.col("y")).withField("x", sf.col("z"))),
"Column<'update_field(update_field(c, x, y), x, z)'>",
)

# DropFields
self.assertEqual(str(c.dropFields("x")), "Column<'drop_field(c, x)'>")
self.assertEqual(
str(c.dropFields("x", "y")),
"Column<'drop_field(drop_field(c, x), y)'>",
)
self.assertEqual(
str(c.dropFields("x", "y", "z")),
"Column<'drop_field(drop_field(drop_field(c, x), y), z)'>",
)

def test_lit_time_representation(self):
dt = datetime.date(2021, 3, 4)
self.assertEqual(str(sf.lit(dt)), "Column<'2021-03-04'>")
Expand Down

0 comments on commit c6b09c0

Please sign in to comment.