Skip to content

Commit

Permalink
[FEAT] Add unpivot/melt (#2204)
Browse files Browse the repository at this point in the history
Adds the unpivot dataframe operation
  • Loading branch information
kevinzwang authored May 6, 2024
1 parent a77bee7 commit c57aaad
Show file tree
Hide file tree
Showing 31 changed files with 1,069 additions and 218 deletions.
6 changes: 6 additions & 0 deletions daft/daft.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -1219,6 +1219,9 @@ class PyMicroPartition:
self, right: PyMicroPartition, left_on: list[PyExpr], right_on: list[PyExpr], is_sorted: bool
) -> PyMicroPartition: ...
def explode(self, to_explode: list[PyExpr]) -> PyMicroPartition: ...
def unpivot(
self, ids: list[PyExpr], values: list[PyExpr], variable_name: str, value_name: str
) -> PyMicroPartition: ...
def head(self, num: int) -> PyMicroPartition: ...
def sample_by_fraction(self, fraction: float, with_replacement: bool, seed: int | None) -> PyMicroPartition: ...
def sample_by_size(self, size: int, with_replacement: bool, seed: int | None) -> PyMicroPartition: ...
Expand Down Expand Up @@ -1310,6 +1313,9 @@ class LogicalPlanBuilder:
def filter(self, predicate: PyExpr) -> LogicalPlanBuilder: ...
def limit(self, limit: int, eager: bool) -> LogicalPlanBuilder: ...
def explode(self, to_explode: list[PyExpr]) -> LogicalPlanBuilder: ...
def unpivot(
self, ids: list[PyExpr], values: list[PyExpr], variable_name: str, value_name: str
) -> LogicalPlanBuilder: ...
def sort(self, sort_by: list[PyExpr], descending: list[bool]) -> LogicalPlanBuilder: ...
def hash_repartition(
self,
Expand Down
165 changes: 133 additions & 32 deletions daft/dataframe/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@

ColumnInputType = Union[Expression, str]

ColumnInputOrListType = Union[ColumnInputType, List[ColumnInputType]]
ManyColumnsInputType = Union[ColumnInputType, Iterable[ColumnInputType]]


class DataFrame:
Expand Down Expand Up @@ -529,28 +529,27 @@ def write_iceberg(self, table: "IcebergTable", mode: str = "append") -> "DataFra
###

def __column_input_to_expression(self, columns: Iterable[ColumnInputType]) -> List[Expression]:
    # TODO(Kevin): remove this method and use _column_inputs_to_expressions
    exprs: List[Expression] = []
    for c in columns:
        if isinstance(c, str):
            exprs.append(col(c))
        else:
            exprs.append(c)
    return exprs

def _inputs_to_expressions(self, inputs: Tuple[ColumnInputOrListType, ...]) -> List[Expression]:
def _is_column_input(self, x: Any) -> bool:
    """Returns whether ``x`` is a single column specifier: a column name string or an Expression."""
    if isinstance(x, str):
        return True
    return isinstance(x, Expression)

def _column_inputs_to_expressions(self, columns: ManyColumnsInputType) -> List[Expression]:
"""
Inputs to dataframe operations can be passed in as individual arguments or a list.
In addition, they may be strings, Expressions, or tuples (deprecated).
Inputs to dataframe operations can be passed in as individual arguments or an iterable.
In addition, they may be strings or Expressions.
This method normalizes the inputs to a list of Expressions.
"""
cols = inputs[0] if (len(inputs) == 1 and isinstance(inputs[0], list)) else inputs

exprs = []
for c in cols:
if isinstance(c, str):
exprs.append(col(c))
elif isinstance(c, Expression):
exprs.append(c)
elif isinstance(c, tuple):
exprs.append(self._agg_tuple_to_expression(c))
else:
raise ValueError(f"Unknown column type: {type(c)}")
column_iter: Iterable[ColumnInputType] = [columns] if self._is_column_input(columns) else columns # type: ignore
return [col(c) if isinstance(c, str) else c for c in column_iter]

def _wildcard_inputs_to_expressions(self, columns: Tuple[ManyColumnsInputType, ...]) -> List[Expression]:
"""Handles wildcard argument column inputs"""

return exprs
column_input: Iterable[ColumnInputType] = columns[0] if len(columns) == 1 else columns # type: ignore
return self._column_inputs_to_expressions(column_input)

def __getitem__(self, item: Union[slice, int, str, Iterable[Union[str, int]]]) -> Union[Expression, "DataFrame"]:
"""Gets a column from the DataFrame as an Expression (``df["mycol"]``)"""
Expand Down Expand Up @@ -1095,17 +1094,95 @@ def explode(self, *columns: ColumnInputType) -> "DataFrame":
builder = self._builder.explode(parsed_exprs)
return DataFrame(builder)

@DataframePublicAPI
def unpivot(
    self,
    ids: ManyColumnsInputType,
    values: ManyColumnsInputType = (),
    variable_name: str = "variable",
    value_name: str = "value",
) -> "DataFrame":
    """Unpivots a DataFrame from wide to long format.

    Example:
        >>> df = daft.from_pydict({
        ...     "year": [2020, 2021, 2022],
        ...     "Jan": [10, 30, 50],
        ...     "Feb": [20, 40, 60],
        ... })
        >>> df
        ╭───────┬───────┬───────╮
        │ year  ┆ Jan   ┆ Feb   │
        │ ---   ┆ ---   ┆ ---   │
        │ Int64 ┆ Int64 ┆ Int64 │
        ╞═══════╪═══════╪═══════╡
        │ 2020  ┆ 10    ┆ 20    │
        ├╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┤
        │ 2021  ┆ 30    ┆ 40    │
        ├╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┤
        │ 2022  ┆ 50    ┆ 60    │
        ╰───────┴───────┴───────╯
        (Showing first 3 of 3 rows)

        >>> df = df.unpivot("year", ["Jan", "Feb"], variable_name="month", value_name="inventory")
        >>> df = df.sort("year")
        >>> df.show()
        ╭───────┬───────┬───────────╮
        │ year  ┆ month ┆ inventory │
        │ ---   ┆ ---   ┆ ---       │
        │ Int64 ┆ Utf8  ┆ Int64     │
        ╞═══════╪═══════╪═══════════╡
        │ 2020  ┆ Jan   ┆ 10        │
        ├╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌┤
        │ 2020  ┆ Feb   ┆ 20        │
        ├╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌┤
        │ 2021  ┆ Jan   ┆ 30        │
        ├╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌┤
        │ 2021  ┆ Feb   ┆ 40        │
        ├╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌┤
        │ 2022  ┆ Jan   ┆ 50        │
        ├╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌┤
        │ 2022  ┆ Feb   ┆ 60        │
        ╰───────┴───────┴───────────╯
        (Showing first 6 of 6 rows)

    Args:
        ids (ManyColumnsInputType): Columns to keep as identifiers
        values (Optional[ManyColumnsInputType]): Columns to unpivot. If not specified, all columns except ids will be unpivoted.
        variable_name (Optional[str]): Name of the variable column. Defaults to "variable".
        value_name (Optional[str]): Name of the value column. Defaults to "value".

    Returns:
        DataFrame: Unpivoted DataFrame

    See also:
        `melt`
    """
    # NOTE: default for `values` is an immutable empty tuple (not `[]`) to avoid
    # the shared-mutable-default-argument pitfall; an empty input normalizes to
    # an empty expression list either way.
    ids_exprs = self._column_inputs_to_expressions(ids)
    values_exprs = self._column_inputs_to_expressions(values)

    builder = self._builder.unpivot(ids_exprs, values_exprs, variable_name, value_name)
    return DataFrame(builder)

@DataframePublicAPI
def melt(
    self,
    ids: ManyColumnsInputType,
    values: ManyColumnsInputType = (),
    variable_name: str = "variable",
    value_name: str = "value",
) -> "DataFrame":
    """Alias for unpivot

    See also:
        `unpivot`
    """
    # NOTE: default for `values` is an immutable empty tuple (not `[]`) to avoid
    # the shared-mutable-default-argument pitfall. Behavior is unchanged: an
    # empty iterable means "unpivot all non-id columns".
    return self.unpivot(ids, values, variable_name, value_name)

def _agg(self, to_agg: List[Expression], group_by: Optional[ExpressionsProjection] = None) -> "DataFrame":
    """Builds an aggregation plan node over `to_agg`, optionally grouped by `group_by`."""
    grouping = list(group_by) if group_by is not None else None
    return DataFrame(self._builder.agg(to_agg, grouping))

def _agg_tuple_to_expression(self, agg_tuple: Tuple[ColumnInputType, str]) -> Expression:
warnings.warn(
"Tuple arguments in aggregations is deprecated and will be removed "
"in Daft v0.3. Please use aggregation expressions instead.",
DeprecationWarning,
)

expr, op = agg_tuple

if isinstance(expr, str):
Expand All @@ -1130,26 +1207,50 @@ def _agg_tuple_to_expression(self, agg_tuple: Tuple[ColumnInputType, str]) -> Ex

raise NotImplementedError(f"Aggregation {op} is not implemented.")

def _agg_inputs_to_expressions(
    self, to_agg: Tuple[Union[Expression, Iterable[Expression]], ...]
) -> List[Expression]:
    """Normalizes `agg` arguments into a list of aggregation Expressions.

    Accepts either individual aggregation arguments or a single iterable of them.
    Each argument is an Expression or a (column, op-name) tuple [deprecated].
    """

    def is_agg_column_input(x: Any) -> bool:
        # aggs currently support Expression or tuple of (ColumnInputType, str) [deprecated]
        if isinstance(x, Expression):
            return True
        # isinstance (rather than exact type equality) so subclasses of
        # Expression/str are accepted as well.
        return (
            isinstance(x, tuple)
            and len(x) == 2
            and isinstance(x[0], (Expression, str))
            and isinstance(x[1], str)
        )

    # Materialize to a list up front: if the caller passed a single generator
    # (e.g. `df.agg(f(c) for c in cols)`), the `any(...)` scan below would
    # otherwise exhaust it before the result list is built.
    if len(to_agg) == 1 and not is_agg_column_input(to_agg[0]):
        columns: List[Any] = list(to_agg[0])  # type: ignore
    else:
        columns = list(to_agg)

    if any(isinstance(c, tuple) for c in columns):
        warnings.warn(
            "Tuple arguments in aggregations is deprecated and will be removed "
            "in Daft v0.3. Please use aggregation expressions instead.",
            DeprecationWarning,
        )
        return [self._agg_tuple_to_expression(c) if isinstance(c, tuple) else c for c in columns]
    else:
        return columns

def _apply_agg_fn(
    self,
    fn: Callable[[Expression], Expression],
    cols: Tuple[ManyColumnsInputType, ...],
    group_by: Optional[ExpressionsProjection] = None,
) -> "DataFrame":
    """Applies the aggregation `fn` to `cols` (or to every non-groupby column when none are given)."""
    if not cols:
        warnings.warn("No columns specified; performing aggregation on all columns.")
        grouped_names = set() if group_by is None else group_by.to_name_set()
        cols = tuple(name for name in self.column_names if name not in grouped_names)
    exprs = self._wildcard_inputs_to_expressions(cols)
    return self._agg([fn(e) for e in exprs], group_by)

def _map_groups(self, udf: Expression, group_by: Optional[ExpressionsProjection] = None) -> "DataFrame":
    """Applies a user-defined function per group (or to the whole frame when `group_by` is None)."""
    grouping = list(group_by) if group_by is not None else None
    return DataFrame(self._builder.map_groups(udf, grouping))

@DataframePublicAPI
def sum(self, *cols: ColumnInputOrListType) -> "DataFrame":
def sum(self, *cols: ManyColumnsInputType) -> "DataFrame":
"""Performs a global sum on the DataFrame
Args:
Expand Down Expand Up @@ -1238,7 +1339,7 @@ def agg_concat(self, *cols: ColumnInputType) -> "DataFrame":
return self._apply_agg_fn(Expression.agg_concat, cols)

@DataframePublicAPI
def agg(self, *to_agg: ColumnInputOrListType) -> "DataFrame":
def agg(self, *to_agg: Union[Expression, Iterable[Expression]]) -> "DataFrame":
"""Perform aggregations on this DataFrame. Allows for mixed aggregations for multiple columns
Will return a single row that aggregated the entire DataFrame.
Expand All @@ -1257,10 +1358,10 @@ def agg(self, *to_agg: ColumnInputOrListType) -> "DataFrame":
Returns:
DataFrame: DataFrame with aggregated results
"""
return self._agg(self._inputs_to_expressions(to_agg), group_by=None)
return self._agg(self._agg_inputs_to_expressions(to_agg), group_by=None)

@DataframePublicAPI
def groupby(self, *group_by: ColumnInputOrListType) -> "GroupedDataFrame":
def groupby(self, *group_by: ManyColumnsInputType) -> "GroupedDataFrame":
"""Performs a GroupBy on the DataFrame for aggregation
Args:
Expand All @@ -1269,7 +1370,7 @@ def groupby(self, *group_by: ColumnInputOrListType) -> "GroupedDataFrame":
Returns:
GroupedDataFrame: DataFrame to Aggregate
"""
return GroupedDataFrame(self, ExpressionsProjection(self._inputs_to_expressions(group_by)))
return GroupedDataFrame(self, ExpressionsProjection(self._wildcard_inputs_to_expressions(group_by)))

@DataframePublicAPI
def pivot(
Expand Down Expand Up @@ -1809,7 +1910,7 @@ def agg_concat(self, *cols: ColumnInputType) -> "DataFrame":
"""
return self.df._apply_agg_fn(Expression.agg_concat, cols, self.group_by)

def agg(self, *to_agg: ColumnInputOrListType) -> "DataFrame":
def agg(self, *to_agg: Union[Expression, Iterable[Expression]]) -> "DataFrame":
"""Perform aggregations on this GroupedDataFrame. Allows for mixed aggregations.
Example:
Expand All @@ -1821,12 +1922,12 @@ def agg(self, *to_agg: ColumnInputOrListType) -> "DataFrame":
>>> )
Args:
*to_agg (Expression): aggregation expressions
*to_agg (Union[Expression, Iterable[Expression]]): aggregation expressions
Returns:
DataFrame: DataFrame with grouped aggregations
"""
return self.df._agg(self.df._inputs_to_expressions(to_agg), group_by=self.group_by)
return self.df._agg(self.df._agg_inputs_to_expressions(to_agg), group_by=self.group_by)

def map_groups(self, udf: Expression) -> "DataFrame":
"""Apply a user-defined function to each group. The name of the resultant column will default to the name of the first input column.
Expand Down
24 changes: 24 additions & 0 deletions daft/execution/execution_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -653,6 +653,30 @@ def run_partial_metadata(self, input_metadatas: list[PartialPartitionMetadata])
]


@dataclass(frozen=True)
class Unpivot(SingleOutputInstruction):
    """Instruction that unpivots a single partition from wide to long format."""

    # Identifier columns, kept as-is on every output row.
    ids: ExpressionsProjection
    # Columns whose names/values are melted into the variable/value columns.
    values: ExpressionsProjection
    variable_name: str
    value_name: str

    def run(self, inputs: list[MicroPartition]) -> list[MicroPartition]:
        return self._unpivot(inputs)

    def _unpivot(self, inputs: list[MicroPartition]) -> list[MicroPartition]:
        [part] = inputs
        unpivoted = part.unpivot(self.ids, self.values, self.variable_name, self.value_name)
        return [unpivoted]

    def run_partial_metadata(self, input_metadatas: list[PartialPartitionMetadata]) -> list[PartialPartitionMetadata]:
        [meta] = input_metadatas
        # Each input row expands to one output row per unpivoted value column;
        # output size in bytes is not predictable ahead of execution.
        rows = None if meta.num_rows is None else meta.num_rows * len(self.values)
        return [
            PartialPartitionMetadata(
                num_rows=rows,
                size_bytes=None,
            )
        ]


@dataclass(frozen=True)
class HashJoin(SingleOutputInstruction):
left_on: ExpressionsProjection
Expand Down
24 changes: 24 additions & 0 deletions daft/execution/rust_physical_plan_shim.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,30 @@ def explode(
)


def unpivot(
    input: physical_plan.InProgressPhysicalPlan[PartitionT],
    ids: list[PyExpr],
    values: list[PyExpr],
    variable_name: str,
    value_name: str,
) -> physical_plan.InProgressPhysicalPlan[PartitionT]:
    """Pipelines an Unpivot instruction onto the child physical plan."""

    def to_projection(exprs: list[PyExpr]) -> ExpressionsProjection:
        # Wrap the Rust-side PyExprs back into Python-side Expressions.
        return ExpressionsProjection([Expression._from_pyexpr(e) for e in exprs])

    instruction = execution_step.Unpivot(
        ids=to_projection(ids),
        values=to_projection(values),
        variable_name=variable_name,
        value_name=value_name,
    )

    return physical_plan.pipeline_instruction(
        child_plan=input,
        pipeable_instruction=instruction,
        resource_request=ResourceRequest(),
    )


def local_aggregate(
input: physical_plan.InProgressPhysicalPlan[PartitionT],
agg_exprs: list[PyExpr],
Expand Down
8 changes: 8 additions & 0 deletions daft/logical/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,14 @@ def explode(self, explode_expressions: list[Expression]) -> LogicalPlanBuilder:
builder = self._builder.explode(explode_pyexprs)
return LogicalPlanBuilder(builder)

def unpivot(
    self, ids: list[Expression], values: list[Expression], variable_name: str, value_name: str
) -> LogicalPlanBuilder:
    """Adds an unpivot node over `ids`/`values` to the logical plan."""
    inner = self._builder.unpivot(
        [e._expr for e in ids],
        [e._expr for e in values],
        variable_name,
        value_name,
    )
    return LogicalPlanBuilder(inner)

def count(self) -> LogicalPlanBuilder:
# TODO(Clark): Add dedicated logical/physical ops when introducing metadata-based count optimizations.
first_col = col(self.schema().column_names()[0])
Expand Down
9 changes: 9 additions & 0 deletions daft/table/micropartition.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,15 @@ def explode(self, columns: ExpressionsProjection) -> MicroPartition:
to_explode_pyexprs = [e._expr for e in columns]
return MicroPartition._from_pymicropartition(self._micropartition.explode(to_explode_pyexprs))

def unpivot(
    self, ids: ExpressionsProjection, values: ExpressionsProjection, variable_name: str, value_name: str
) -> MicroPartition:
    """Unpivots this partition, delegating to the native PyMicroPartition implementation."""
    result = self._micropartition.unpivot(
        [e._expr for e in ids],
        [e._expr for e in values],
        variable_name,
        value_name,
    )
    return MicroPartition._from_pymicropartition(result)

def hash_join(
self,
right: MicroPartition,
Expand Down
2 changes: 2 additions & 0 deletions docs/source/api_docs/dataframe.rst
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ Manipulating Columns
DataFrame.pivot
DataFrame.exclude
DataFrame.explode
DataFrame.unpivot
DataFrame.melt

Filtering Rows
**************
Expand Down
1 change: 0 additions & 1 deletion src/daft-core/src/array/ops/arrow2/sort/primitive/sort.rs
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,6 @@ mod tests {

use arrow2::array::ord;
use arrow2::array::Array;
use arrow2::array::PrimitiveArray;
use arrow2::datatypes::DataType;

fn test_sort_primitive_arrays<T>(
Expand Down
Loading

0 comments on commit c57aaad

Please sign in to comment.