Skip to content

Commit

Permalink
Upgrade datafusion (apache#867)
Browse files Browse the repository at this point in the history
* update dependencies

* update get_logical_plan signature

* remove row_number() function

row_number was converted to a UDF in datafusion v42 apache/datafusion#12030
This specific functionality needs to be added back in.

* remove unneeded dependency

* fix pyo3 warnings

Implicit defaults for trailing optional arguments have been deprecated
in pyo3 v0.22.0 PyO3/pyo3#4078

* update object_store dependency

* change PyExpr -> PySortExpr

* comment out key.extract::<&PyTuple>() condition statement

* change more instances of PyExpr -> PySortExpr

* update function signatures to use _bound versions

* remove clone

* Working through some of the sort requirement changes

* remove unused import

* expr.display_name is deprecated, use format!() + schema_name() instead

* expr.canonical_name() is deprecated, use format!() expr instead

* remove comment

* fix tuple extraction in dataframe.__getitem__()

* remove unneeded import

* Add docstring comments to SortExpr python class

* change extract() to downcast()

Co-authored-by: Michael J Ward <[email protected]>

* deprecate Expr::display_name

Ref: apache/datafusion#11797

* fix lint errors

* update datafusion commit hash

* fix typo in cargo file for arrow features

* upgrade to datafusion 42

* cleanup

---------

Co-authored-by: Tim Saucer <[email protected]>
Co-authored-by: Michael J Ward <[email protected]>
Co-authored-by: Michael-J-Ward <[email protected]>
  • Loading branch information
4 people authored Sep 17, 2024
1 parent 02d4453 commit 6c8bf5f
Show file tree
Hide file tree
Showing 22 changed files with 710 additions and 595 deletions.
784 changes: 390 additions & 394 deletions Cargo.lock

Large diffs are not rendered by default.

17 changes: 8 additions & 9 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,24 +36,24 @@ substrait = ["dep:datafusion-substrait"]
[dependencies]
tokio = { version = "1.39", features = ["macros", "rt", "rt-multi-thread", "sync"] }
rand = "0.8"
pyo3 = { version = "0.21", features = ["extension-module", "abi3", "abi3-py38"] }
arrow = { version = "52", feature = ["pyarrow"] }
datafusion = { version = "41.0.0", features = ["pyarrow", "avro", "unicode_expressions"] }
datafusion-substrait = { version = "41.0.0", optional = true }
prost = "0.12" # keep in line with `datafusion-substrait`
prost-types = "0.12" # keep in line with `datafusion-substrait`
pyo3 = { version = "0.22", features = ["extension-module", "abi3", "abi3-py38"] }
arrow = { version = "53", features = ["pyarrow"] }
datafusion = { version = "42.0.0", features = ["pyarrow", "avro", "unicode_expressions"] }
datafusion-substrait = { version = "42.0.0", optional = true }
prost = "0.13" # keep in line with `datafusion-substrait`
prost-types = "0.13" # keep in line with `datafusion-substrait`
uuid = { version = "1.9", features = ["v4"] }
mimalloc = { version = "0.1", optional = true, default-features = false, features = ["local_dynamic_tls"] }
async-trait = "0.1"
futures = "0.3"
object_store = { version = "0.10.1", features = ["aws", "gcp", "azure"] }
object_store = { version = "0.11.0", features = ["aws", "gcp", "azure"] }
parking_lot = "0.12"
regex-syntax = "0.8"
syn = "2.0.68"
url = "2"

[build-dependencies]
pyo3-build-config = "0.21"
pyo3-build-config = "0.22"

[lib]
name = "datafusion_python"
Expand All @@ -62,4 +62,3 @@ crate-type = ["cdylib", "rlib"]
[profile.release]
lto = true
codegen-units = 1

13 changes: 8 additions & 5 deletions python/datafusion/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from datafusion._internal import AggregateUDF
from datafusion.catalog import Catalog, Table
from datafusion.dataframe import DataFrame
from datafusion.expr import Expr
from datafusion.expr import Expr, SortExpr, sort_list_to_raw_sort_list
from datafusion.record_batch import RecordBatchStream
from datafusion.udf import ScalarUDF

Expand Down Expand Up @@ -466,7 +466,7 @@ def register_listing_table(
table_partition_cols: list[tuple[str, str]] | None = None,
file_extension: str = ".parquet",
schema: pyarrow.Schema | None = None,
file_sort_order: list[list[Expr]] | None = None,
file_sort_order: list[list[Expr | SortExpr]] | None = None,
) -> None:
"""Register multiple files as a single table.
Expand All @@ -484,15 +484,18 @@ def register_listing_table(
"""
if table_partition_cols is None:
table_partition_cols = []
if file_sort_order is not None:
file_sort_order = [[x.expr for x in xs] for xs in file_sort_order]
file_sort_order_raw = (
[sort_list_to_raw_sort_list(f) for f in file_sort_order]
if file_sort_order is not None
else None
)
self.ctx.register_listing_table(
name,
str(path),
table_partition_cols,
file_extension,
schema,
file_sort_order,
file_sort_order_raw,
)

def sql(self, query: str, options: SQLOptions | None = None) -> DataFrame:
Expand Down
8 changes: 4 additions & 4 deletions python/datafusion/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from typing import Callable

from datafusion._internal import DataFrame as DataFrameInternal
from datafusion.expr import Expr
from datafusion.expr import Expr, SortExpr, sort_or_default
from datafusion._internal import (
LogicalPlan,
ExecutionPlan,
Expand Down Expand Up @@ -199,7 +199,7 @@ def aggregate(
aggs = [e.expr for e in aggs]
return DataFrame(self.df.aggregate(group_by, aggs))

def sort(self, *exprs: Expr) -> DataFrame:
def sort(self, *exprs: Expr | SortExpr) -> DataFrame:
"""Sort the DataFrame by the specified sorting expressions.
Note that any expression can be turned into a sort expression by
Expand All @@ -211,8 +211,8 @@ def sort(self, *exprs: Expr) -> DataFrame:
Returns:
DataFrame after sorting.
"""
exprs = [expr.expr for expr in exprs]
return DataFrame(self.df.sort(*exprs))
exprs_raw = [sort_or_default(expr) for expr in exprs]
return DataFrame(self.df.sort(*exprs_raw))

def limit(self, count: int, offset: int = 0) -> DataFrame:
"""Return a new :py:class:`DataFrame` with a limited number of rows.
Expand Down
83 changes: 70 additions & 13 deletions python/datafusion/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,15 @@

from __future__ import annotations

from ._internal import (
expr as expr_internal,
LogicalPlan,
functions as functions_internal,
)
from datafusion.common import NullTreatment, RexType, DataTypeMap
from typing import Any, Optional, Type

import pyarrow as pa
from datafusion.common import DataTypeMap, NullTreatment, RexType
from typing_extensions import deprecated

from ._internal import LogicalPlan
from ._internal import expr as expr_internal
from ._internal import functions as functions_internal

# The following are imported from the internal representation. We may choose to
# give these all proper wrappers, or to simply leave as is. These were added
Expand Down Expand Up @@ -84,7 +85,6 @@
ScalarVariable = expr_internal.ScalarVariable
SimilarTo = expr_internal.SimilarTo
Sort = expr_internal.Sort
SortExpr = expr_internal.SortExpr
Subquery = expr_internal.Subquery
SubqueryAlias = expr_internal.SubqueryAlias
TableScan = expr_internal.TableScan
Expand Down Expand Up @@ -159,6 +159,27 @@
]


def expr_list_to_raw_expr_list(
    expr_list: Optional[list[Expr]],
) -> Optional[list[expr_internal.Expr]]:
    """Unwrap an optional list of expressions into their raw internal form.

    Args:
        expr_list: Wrapped expressions, or ``None``.

    Returns:
        A new list holding the ``expr`` attribute of each element, in the
        same order, or ``None`` when ``expr_list`` is ``None``.
    """
    # ``None`` is passed through untouched so optional arguments stay optional.
    if expr_list is None:
        return None
    return [wrapped.expr for wrapped in expr_list]


def sort_or_default(e: Expr | SortExpr) -> expr_internal.SortExpr:
    """Return the raw sort expression for ``e``, wrapping plain expressions.

    Args:
        e: Either a :py:class:`SortExpr`, which is unwrapped as is, or an
            expression, which is first wrapped in a :py:class:`SortExpr`
            built with ``ascending=True`` and ``nulls_first=True``.

    Returns:
        The ``raw_sort`` attribute of the resulting :py:class:`SortExpr`.
    """
    sort_expr = e if isinstance(e, SortExpr) else SortExpr(e.expr, True, True)
    return sort_expr.raw_sort


def sort_list_to_raw_sort_list(
    sort_list: Optional[list[Expr | SortExpr]],
) -> Optional[list[expr_internal.SortExpr]]:
    """Convert an optional list of sort inputs into raw sort expressions.

    Args:
        sort_list: Expressions and/or sort expressions, or ``None``.

    Returns:
        A new list with :py:func:`sort_or_default` applied to each element,
        in the same order, or ``None`` when ``sort_list`` is ``None``.
    """
    # ``None`` is passed through untouched so optional arguments stay optional.
    if sort_list is None:
        return None
    return [sort_or_default(item) for item in sort_list]


class Expr:
"""Expression object.
Expand All @@ -174,12 +195,22 @@ def to_variant(self) -> Any:
"""Convert this expression into a python object if possible."""
return self.expr.to_variant()

@deprecated(
"display_name() is deprecated. Use :py:meth:`~Expr.schema_name` instead"
)
def display_name(self) -> str:
"""Returns the name of this expression as it should appear in a schema.
This name will not include any CAST expressions.
"""
return self.expr.display_name()
return self.schema_name()

def schema_name(self) -> str:
    """Return this expression's name as it should appear in a schema.

    The name will not include any CAST expressions.
    """
    # Delegates to the wrapped internal expression.
    name = self.expr.schema_name()
    return name

def canonical_name(self) -> str:
"""Returns a complete string representation of this expression."""
Expand Down Expand Up @@ -355,14 +386,14 @@ def alias(self, name: str) -> Expr:
"""Assign a name to the expression."""
return Expr(self.expr.alias(name))

def sort(self, ascending: bool = True, nulls_first: bool = True) -> Expr:
def sort(self, ascending: bool = True, nulls_first: bool = True) -> SortExpr:
"""Creates a sort :py:class:`Expr` from an existing :py:class:`Expr`.
Args:
ascending: If true, sort in ascending order.
nulls_first: Return null values first.
"""
return Expr(self.expr.sort(ascending=ascending, nulls_first=nulls_first))
return SortExpr(self.expr, ascending=ascending, nulls_first=nulls_first)

def is_null(self) -> Expr:
"""Returns ``True`` if this expression is null."""
Expand Down Expand Up @@ -455,14 +486,14 @@ def column_name(self, plan: LogicalPlan) -> str:
"""Compute the output column name based on the provided logical plan."""
return self.expr.column_name(plan)

def order_by(self, *exprs: Expr) -> ExprFuncBuilder:
def order_by(self, *exprs: Expr | SortExpr) -> ExprFuncBuilder:
"""Set the ordering for a window or aggregate function.
This function will create an :py:class:`ExprFuncBuilder` that can be used to
set parameters for either window or aggregate functions. If used on any other
type of expression, an error will be generated when ``build()`` is called.
"""
return ExprFuncBuilder(self.expr.order_by(list(e.expr for e in exprs)))
return ExprFuncBuilder(self.expr.order_by([sort_or_default(e) for e in exprs]))

def filter(self, filter: Expr) -> ExprFuncBuilder:
"""Filter an aggregate function.
Expand Down Expand Up @@ -522,7 +553,9 @@ def order_by(self, *exprs: Expr) -> ExprFuncBuilder:
Values given in ``exprs`` must be sort expressions. You can convert any other
expression to a sort expression using `.sort()`.
"""
return ExprFuncBuilder(self.builder.order_by(list(e.expr for e in exprs)))
return ExprFuncBuilder(
self.builder.order_by([sort_or_default(e) for e in exprs])
)

def filter(self, filter: Expr) -> ExprFuncBuilder:
"""Filter values during aggregation."""
Expand Down Expand Up @@ -659,3 +692,27 @@ def end(self) -> Expr:
Any non-matching cases will end in a `null` value.
"""
return Expr(self.case_builder.end())


class SortExpr:
    """Used to specify sorting on either a DataFrame or function.

    Wraps an internal sort expression, stored in ``raw_sort``; every accessor
    below delegates to that object.
    """

    def __init__(
        self, expr: expr_internal.Expr, ascending: bool, nulls_first: bool
    ) -> None:
        """This constructor should not be called by the end user.

        Args:
            expr: Raw internal expression to sort by (not the Python
                :py:class:`Expr` wrapper). It is handed unchanged to the
                internal ``SortExpr`` constructor.
            ascending: If true, sort in ascending order.
            nulls_first: Return null values first.
        """
        self.raw_sort = expr_internal.SortExpr(expr, ascending, nulls_first)

    def expr(self) -> Expr:
        """Return the expression backing the SortExpr, wrapped as an ``Expr``."""
        return Expr(self.raw_sort.expr())

    def ascending(self) -> bool:
        """Return ascending property."""
        return self.raw_sort.ascending()

    def nulls_first(self) -> bool:
        """Return nulls_first property."""
        return self.raw_sort.nulls_first()

    def __repr__(self) -> str:
        """Generate a string representation of this expression."""
        return repr(self.raw_sort)
Loading

0 comments on commit 6c8bf5f

Please sign in to comment.