diff --git a/python/datafusion/dataframe.py b/python/datafusion/dataframe.py index c5ac0bb8..34442fc4 100644 --- a/python/datafusion/dataframe.py +++ b/python/datafusion/dataframe.py @@ -21,7 +21,7 @@ from __future__ import annotations -from typing import Any, List, TYPE_CHECKING +from typing import Any, Iterable, List, TYPE_CHECKING from datafusion.record_batch import RecordBatchStream from typing_extensions import deprecated from datafusion.plan import LogicalPlan, ExecutionPlan @@ -160,6 +160,53 @@ def with_column(self, name: str, expr: Expr) -> DataFrame: """ return DataFrame(self.df.with_column(name, expr.expr)) + def with_columns( + self, *exprs: Expr | Iterable[Expr], **named_exprs: Expr + ) -> DataFrame: + """Add columns to the DataFrame. + + Columns can be given as expressions, iterables of expressions, or named + expressions. To pass named expressions use the form name=Expr. + + Example usage: + + The following will add 4 columns labeled a, b, c, and d. + + df = df.with_columns( + lit(0).alias('a'), + [lit(1).alias('b'), lit(2).alias('c')], + d=lit(3) + ) + + Args: + *exprs: Expressions, or iterables of expressions, to add as columns. + **named_exprs: Expressions to add, each aliased to its keyword name. + + Returns: + DataFrame with the new columns added. + """ + + def _simplify_expression( + *exprs: Expr | Iterable[Expr], **named_exprs: Expr + ) -> list[Expr]: + expr_list = [] + for expr in exprs: + if isinstance(expr, Expr): + expr_list.append(expr.expr) + elif isinstance(expr, Iterable): + for inner_expr in expr: + expr_list.append(inner_expr.expr) + else: + raise NotImplementedError + if named_exprs: + for alias, expr in named_exprs.items(): + expr_list.append(expr.alias(alias).expr) + return expr_list + + expressions = _simplify_expression(*exprs, **named_exprs) + + return DataFrame(self.df.with_columns(expressions)) + def with_column_renamed(self, old_name: str, new_name: str) -> DataFrame: r"""Rename one column by applying a new projection. 
diff --git a/python/tests/test_dataframe.py b/python/tests/test_dataframe.py index e89c5715..55f93975 100644 --- a/python/tests/test_dataframe.py +++ b/python/tests/test_dataframe.py @@ -205,6 +205,37 @@ def test_with_column(df): assert result.column(2) == pa.array([5, 7, 9]) +def test_with_columns(df): + df = df.with_columns( + (column("a") + column("b")).alias("c"), + (column("a") + column("b")).alias("d"), + [ + (column("a") + column("b")).alias("e"), + (column("a") + column("b")).alias("f"), + ], + g=(column("a") + column("b")), + ) + + # collect the first (and only) batch; expect a, b, then c-g in argument order + result = df.collect()[0] + + assert result.schema.field(0).name == "a" + assert result.schema.field(1).name == "b" + assert result.schema.field(2).name == "c" + assert result.schema.field(3).name == "d" + assert result.schema.field(4).name == "e" + assert result.schema.field(5).name == "f" + assert result.schema.field(6).name == "g" + + assert result.column(0) == pa.array([1, 2, 3]) + assert result.column(1) == pa.array([4, 5, 6]) + assert result.column(2) == pa.array([5, 7, 9]) + assert result.column(3) == pa.array([5, 7, 9]) + assert result.column(4) == pa.array([5, 7, 9]) + assert result.column(5) == pa.array([5, 7, 9]) + assert result.column(6) == pa.array([5, 7, 9]) + + def test_with_column_renamed(df): df = df.with_column("c", column("a") + column("b")).with_column_renamed("c", "sum") diff --git a/src/dataframe.rs b/src/dataframe.rs index e77ca842..81cdb0f6 100644 --- a/src/dataframe.rs +++ b/src/dataframe.rs @@ -180,6 +180,16 @@ impl PyDataFrame { Ok(Self::new(df)) } + fn with_columns(&self, exprs: Vec) -> PyResult { + let mut df = self.df.as_ref().clone(); + for expr in exprs { + let expr: Expr = expr.into(); + let name = format!("{}", expr.schema_name()); + df = df.with_column(name.as_str(), expr)? + } + Ok(Self::new(df)) + } + /// Rename one column by applying a new projection. This is a no-op if the column to be /// renamed does not exist. 
fn with_column_renamed(&self, old_name: &str, new_name: &str) -> PyResult {