[PERF] Lazily import heavy modules to speed up import times #2826

Merged
41 commits merged on Sep 19, 2024

Commits
55b53a2
Defer Register of Super Type
samster25 Sep 11, 2024
67566e3
Add lazy import
desmondcheongzx Sep 10, 2024
b5a3900
Fix test
desmondcheongzx Sep 10, 2024
1a184d9
Add more lazy imports
desmondcheongzx Sep 11, 2024
701b4ae
Fix import order
desmondcheongzx Sep 11, 2024
d05d7e7
Fix tests
desmondcheongzx Sep 17, 2024
94a9779
Import pyarrow lazily
desmondcheongzx Sep 17, 2024
b48e3cb
Lazily register HTML visualization hooks
desmondcheongzx Sep 17, 2024
3faeb8e
Separate type checking imports from actual imports
desmondcheongzx Sep 17, 2024
6aa2828
Undo poc artifact
desmondcheongzx Sep 17, 2024
a8c5a46
Ban expensive imports with linter
desmondcheongzx Sep 17, 2024
e91afbc
Fix tests
desmondcheongzx Sep 17, 2024
ce44aab
Add flake8-type-checking
desmondcheongzx Sep 18, 2024
988e359
Move all lazy imports into dependencies module
desmondcheongzx Sep 18, 2024
f30b20e
Fix csv test
desmondcheongzx Sep 18, 2024
826c715
Attempt fix
desmondcheongzx Sep 18, 2024
d0a1fce
Test another fix
desmondcheongzx Sep 18, 2024
bddcce6
Fix test?
desmondcheongzx Sep 18, 2024
3052f24
Fix ray runner tests
desmondcheongzx Sep 18, 2024
ed8061c
Add action upterm
desmondcheongzx Sep 18, 2024
a583631
Fix upterm
desmondcheongzx Sep 18, 2024
3580b26
Fix upterm
desmondcheongzx Sep 18, 2024
55cace6
Set timeout
desmondcheongzx Sep 18, 2024
9d26e3b
Tweak
desmondcheongzx Sep 18, 2024
e41737e
Delay upterm
desmondcheongzx Sep 18, 2024
d300ce9
Remove upterm
desmondcheongzx Sep 18, 2024
1508747
Break up build and test
desmondcheongzx Sep 19, 2024
a71d072
Jiggle
desmondcheongzx Sep 19, 2024
4835c92
test
desmondcheongzx Sep 19, 2024
2d2e8d1
jiggle
desmondcheongzx Sep 19, 2024
81e76a6
jiggle
desmondcheongzx Sep 19, 2024
ebafdf7
getting
desmondcheongzx Sep 19, 2024
513c425
jiggle
desmondcheongzx Sep 19, 2024
463294a
test pyarrow version
desmondcheongzx Sep 19, 2024
f28f8e8
jiggle
desmondcheongzx Sep 19, 2024
7f692f6
modify
desmondcheongzx Sep 19, 2024
4496d1f
Fix
desmondcheongzx Sep 19, 2024
8e25dbc
Fix
desmondcheongzx Sep 19, 2024
f7be46d
Remove ssh
desmondcheongzx Sep 19, 2024
faaf4f1
Cleanup
desmondcheongzx Sep 19, 2024
f888e49
Import numpy via daft.dependencies for ray runner
desmondcheongzx Sep 19, 2024
13 changes: 13 additions & 0 deletions daft/.ruff.toml
@@ -0,0 +1,13 @@
extend = "../.ruff.toml"

[lint]
extend-select = [
"TID253", # banned-module-level-imports, derived from flake8-tidy-imports
"TCH" # flake8-type-checking
]

[lint.flake8-tidy-imports]
# Ban certain modules from being imported at module level, instead requiring
# that they're imported lazily (e.g., within a function definition,
# with daft.lazy_import.LazyImport, or with TYPE_CHECKING).
banned-module-level-imports = ["daft.unity_catalog", "fsspec", "numpy", "pandas", "PIL", "pyarrow", "ray", "xml"]
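
For context, the three import styles that the comment above allows look roughly like the following minimal sketch; `load_table` and its body are hypothetical and not code from this PR:

```python
from __future__ import annotations

from typing import TYPE_CHECKING

from daft.dependencies import pa  # LazyImport proxy: pyarrow only loads on first use

if TYPE_CHECKING:
    # Type-checking-only import: visible to type checkers, erased at runtime.
    import pandas as pd


def load_table(df: pd.DataFrame) -> pa.Table:
    # Function-level import: deferred until the function is actually called.
    import pandas as pd

    return pa.Table.from_pandas(df)
```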
30 changes: 19 additions & 11 deletions daft/arrow_utils.py
@@ -2,7 +2,7 @@

import sys

import pyarrow as pa
from daft.dependencies import pa


def ensure_array(arr: pa.Array) -> pa.Array:
@@ -34,13 +34,21 @@ class _FixEmptyStructArrays:
Python layer before going through ffi into Rust.
"""

EMPTY_STRUCT_TYPE = pa.struct([])
SINGLE_FIELD_STRUCT_TYPE = pa.struct({"": pa.null()})
SINGLE_FIELD_STRUCT_VALUE = {"": None}
@staticmethod
def get_empty_struct_type():
return pa.struct([])

@staticmethod
def get_single_field_struct_type():
return pa.struct({"": pa.null()})

@staticmethod
def get_single_field_struct_value():
return {"": None}

def ensure_table(table: pa.Table) -> pa.Table:
empty_struct_fields = [
(i, f) for (i, f) in enumerate(table.schema) if f.type == _FixEmptyStructArrays.EMPTY_STRUCT_TYPE
(i, f) for (i, f) in enumerate(table.schema) if f.type == _FixEmptyStructArrays.get_empty_struct_type()
]
if not empty_struct_fields:
return table
@@ -49,19 +57,19 @@ def ensure_table(table: pa.Table) -> pa.Table:
return table

def ensure_chunked_array(arr: pa.ChunkedArray) -> pa.ChunkedArray:
if arr.type != _FixEmptyStructArrays.EMPTY_STRUCT_TYPE:
if arr.type != _FixEmptyStructArrays.get_empty_struct_type():
return arr
return pa.chunked_array([_FixEmptyStructArrays.ensure_array(chunk) for chunk in arr.chunks])

def ensure_array(arr: pa.Array) -> pa.Array:
"""Recursively converts empty struct arrays to single-field struct arrays"""
if arr.type == _FixEmptyStructArrays.EMPTY_STRUCT_TYPE:
if arr.type == _FixEmptyStructArrays.get_empty_struct_type():
return pa.array(
[
_FixEmptyStructArrays.SINGLE_FIELD_STRUCT_VALUE if valid.as_py() else None
_FixEmptyStructArrays.get_single_field_struct_value() if valid.as_py() else None
for valid in arr.is_valid()
],
type=_FixEmptyStructArrays.SINGLE_FIELD_STRUCT_TYPE,
type=_FixEmptyStructArrays.get_single_field_struct_type(),
)

elif isinstance(arr, pa.StructArray):
@@ -77,10 +85,10 @@ def ensure_array(arr: pa.Array) -> pa.Array:

def remove_empty_struct_placeholders(arr: pa.Array):
"""Recursively removes the empty struct placeholders placed by _FixEmptyStructArrays.ensure_array"""
if arr.type == _FixEmptyStructArrays.SINGLE_FIELD_STRUCT_TYPE:
if arr.type == _FixEmptyStructArrays.get_single_field_struct_type():
return pa.array(
[{} if valid.as_py() else None for valid in arr.is_valid()],
type=_FixEmptyStructArrays.EMPTY_STRUCT_TYPE,
type=_FixEmptyStructArrays.get_empty_struct_type(),
)

elif isinstance(arr, pa.StructArray):
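
For context on why the class constants above became static methods: class attribute expressions run when the class body is executed, i.e. at import time, so the old constants would have pulled in pyarrow as soon as `daft.arrow_utils` was imported. A rough illustration of the difference (hypothetical classes, not code from the diff):

```python
from daft.dependencies import pa


class Eager:
    # Evaluated while the module is being imported, forcing pyarrow to load.
    EMPTY_STRUCT_TYPE = pa.struct([])


class Deferred:
    @staticmethod
    def get_empty_struct_type():
        # Evaluated only when a caller actually needs the type.
        return pa.struct([])
```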
14 changes: 6 additions & 8 deletions daft/daft/__init__.pyi
@@ -3,8 +3,6 @@ import datetime
from enum import Enum
from typing import TYPE_CHECKING, Any, Callable, Iterator

import pyarrow

from daft.dataframe.display import MermaidOptions
from daft.execution import physical_plan
from daft.io.scan import ScanOperator
@@ -994,7 +992,7 @@ class PyDataType:
def tensor(dtype: PyDataType, shape: tuple[int, ...] | None = None) -> PyDataType: ...
@staticmethod
def python() -> PyDataType: ...
def to_arrow(self, cast_tensor_type_for_ray: builtins.bool | None = None) -> pyarrow.DataType: ...
def to_arrow(self, cast_tensor_type_for_ray: builtins.bool | None = None) -> pa.DataType: ...
def is_numeric(self) -> builtins.bool: ...
def is_image(self) -> builtins.bool: ...
def is_fixed_shape_image(self) -> builtins.bool: ...
@@ -1271,11 +1269,11 @@ class PyCatalog:

class PySeries:
@staticmethod
def from_arrow(name: str, pyarrow_array: pyarrow.Array) -> PySeries: ...
def from_arrow(name: str, pyarrow_array: pa.Array) -> PySeries: ...
@staticmethod
def from_pylist(name: str, pylist: list[Any], pyobj: str) -> PySeries: ...
def to_pylist(self) -> list[Any]: ...
def to_arrow(self) -> pyarrow.Array: ...
def to_arrow(self) -> pa.Array: ...
def __abs__(self) -> PySeries: ...
def __add__(self, other: PySeries) -> PySeries: ...
def __sub__(self, other: PySeries) -> PySeries: ...
@@ -1456,10 +1454,10 @@ class PyTable:
def concat(tables: list[PyTable]) -> PyTable: ...
def slice(self, start: int, end: int) -> PyTable: ...
@staticmethod
def from_arrow_record_batches(record_batches: list[pyarrow.RecordBatch], schema: PySchema) -> PyTable: ...
def from_arrow_record_batches(record_batches: list[pa.RecordBatch], schema: PySchema) -> PyTable: ...
@staticmethod
def from_pylist_series(dict: dict[str, PySeries]) -> PyTable: ...
def to_arrow_record_batch(self) -> pyarrow.RecordBatch: ...
def to_arrow_record_batch(self) -> pa.RecordBatch: ...
@staticmethod
def empty(schema: PySchema | None = None) -> PyTable: ...

@@ -1476,7 +1474,7 @@ class PyMicroPartition:
@staticmethod
def from_tables(tables: list[PyTable]) -> PyMicroPartition: ...
@staticmethod
def from_arrow_record_batches(record_batches: list[pyarrow.RecordBatch], schema: PySchema) -> PyMicroPartition: ...
def from_arrow_record_batches(record_batches: list[pa.RecordBatch], schema: PySchema) -> PyMicroPartition: ...
@staticmethod
def concat(tables: list[PyMicroPartition]) -> PyMicroPartition: ...
def slice(self, start: int, end: int) -> PyMicroPartition: ...
4 changes: 3 additions & 1 deletion daft/dataframe/preview.py
@@ -1,8 +1,10 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING

from daft.table import MicroPartition
if TYPE_CHECKING:
from daft.table import MicroPartition


@dataclass(frozen=True)
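
Note that moving the `MicroPartition` import under `TYPE_CHECKING` is safe here because the module already uses `from __future__ import annotations`, so the `MicroPartition` annotation is stored as a string and never evaluated at runtime; the import only matters to type checkers.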
53 changes: 34 additions & 19 deletions daft/datatype.py
@@ -1,14 +1,14 @@
from __future__ import annotations

import builtins
from typing import TYPE_CHECKING

import pyarrow as pa

from daft.context import get_context
from daft.daft import ImageMode, PyDataType, PyTimeUnit
from daft.dependencies import pa

if TYPE_CHECKING:
import builtins

import numpy as np


@@ -501,25 +501,40 @@ def __hash__(self) -> int:
return self._dtype.__hash__()


class DaftExtension(pa.ExtensionType):
def __init__(self, dtype, metadata=b""):
# attributes need to be set first before calling
# super init (as that calls serialize)
self._metadata = metadata
super().__init__(dtype, "daft.super_extension")
_EXT_TYPE_REGISTERED = False
_STATIC_DAFT_EXTENSION = None

def __reduce__(self):
return type(self).__arrow_ext_deserialize__, (self.storage_type, self.__arrow_ext_serialize__())

def __arrow_ext_serialize__(self):
return self._metadata
def _ensure_registered_super_ext_type():
global _EXT_TYPE_REGISTERED
global _STATIC_DAFT_EXTENSION
if not _EXT_TYPE_REGISTERED:

@classmethod
def __arrow_ext_deserialize__(cls, storage_type, serialized):
return cls(storage_type, serialized)
class DaftExtension(pa.ExtensionType):
def __init__(self, dtype, metadata=b""):
# attributes need to be set first before calling
# super init (as that calls serialize)
self._metadata = metadata
super().__init__(dtype, "daft.super_extension")

def __reduce__(self):
return type(self).__arrow_ext_deserialize__, (self.storage_type, self.__arrow_ext_serialize__())

def __arrow_ext_serialize__(self):
return self._metadata

@classmethod
def __arrow_ext_deserialize__(cls, storage_type, serialized):
return cls(storage_type, serialized)

_STATIC_DAFT_EXTENSION = DaftExtension
pa.register_extension_type(DaftExtension(pa.null()))
import atexit

atexit.register(lambda: pa.unregister_extension_type("daft.super_extension"))
_EXT_TYPE_REGISTERED = True

pa.register_extension_type(DaftExtension(pa.null()))
import atexit

atexit.register(lambda: pa.unregister_extension_type("daft.super_extension"))
def get_super_ext_type():
_ensure_registered_super_ext_type()
return _STATIC_DAFT_EXTENSION
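
A small usage sketch of the deferred registration above (illustrative; only the function name comes from this diff): the first call to `get_super_ext_type()` defines `DaftExtension`, registers it with pyarrow, and caches the class, while later calls return the cached class without re-registering.

```python
from daft.datatype import get_super_ext_type

ext_cls = get_super_ext_type()          # defines, registers, and caches DaftExtension
assert ext_cls is get_super_ext_type()  # subsequent calls reuse the cached class
```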
5 changes: 4 additions & 1 deletion daft/delta_lake/delta_lake_scan.py
@@ -2,7 +2,7 @@

import logging
import os
from collections.abc import Iterator
from typing import TYPE_CHECKING

from deltalake.table import DeltaTable

@@ -20,6 +20,9 @@
from daft.io.scan import PartitionField, ScanOperator
from daft.logical.schema import Schema

if TYPE_CHECKING:
from collections.abc import Iterator

logger = logging.getLogger(__name__)


32 changes: 32 additions & 0 deletions daft/dependencies.py
@@ -0,0 +1,32 @@
from typing import TYPE_CHECKING

from daft.lazy_import import LazyImport

if TYPE_CHECKING:
import xml.etree.ElementTree as ET

import fsspec
import numpy as np
import pandas as pd
import PIL.Image as pil_image
import pyarrow as pa
import pyarrow.csv as pacsv
import pyarrow.dataset as pads
import pyarrow.fs as pafs
import pyarrow.json as pajson
import pyarrow.parquet as pq
else:
ET = LazyImport("xml.etree.ElementTree")

fsspec = LazyImport("fsspec")
np = LazyImport("numpy")
pd = LazyImport("pandas")
pil_image = LazyImport("PIL.Image")
pa = LazyImport("pyarrow")
pacsv = LazyImport("pyarrow.csv")
pads = LazyImport("pyarrow.dataset")
pafs = LazyImport("pyarrow.fs")
pajson = LazyImport("pyarrow.json")
pq = LazyImport("pyarrow.parquet")

unity_catalog = LazyImport("daft.unity_catalog")
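
The `LazyImport` wrapper used above lives in `daft/lazy_import.py`, which is not part of this excerpt. A minimal sketch of how such a proxy is commonly implemented, assuming it simply defers `importlib.import_module` until the first attribute access, looks like this:

```python
from __future__ import annotations

import importlib
from types import ModuleType


class LazyImport:
    """Proxy that imports `module_name` on first attribute access (sketch only)."""

    def __init__(self, module_name: str) -> None:
        self._module_name = module_name
        self._module: ModuleType | None = None

    def _load(self) -> ModuleType:
        if self._module is None:
            self._module = importlib.import_module(self._module_name)
        return self._module

    def __getattr__(self, name: str):
        # Only reached for attributes not set in __init__, e.g. `pa.array`.
        return getattr(self._load(), name)


pa = LazyImport("pyarrow")   # nothing imported yet
arr = pa.array([1, 2, 3])    # pyarrow is imported here, on first use
```

With a proxy like this in place, `import daft` avoids paying for pyarrow, numpy, pandas, and the other heavy dependencies until one of them is actually touched.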
11 changes: 7 additions & 4 deletions daft/execution/execution_step.py
@@ -1,15 +1,12 @@
from __future__ import annotations

import itertools
import pathlib
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Generic, Protocol

from daft.context import get_context
from daft.daft import FileFormat, IOConfig, JoinType, ResourceRequest, ScanTask
from daft.daft import ResourceRequest
from daft.expressions import Expression, ExpressionsProjection, col
from daft.logical.map_partition_ops import MapPartitionOp
from daft.logical.schema import Schema
from daft.runners.partitioning import (
Boundaries,
MaterializedResult,
@@ -20,9 +17,15 @@
from daft.table import MicroPartition, table_io

if TYPE_CHECKING:
import pathlib

from pyiceberg.schema import Schema as IcebergSchema
from pyiceberg.table import TableProperties as IcebergTableProperties

from daft.daft import FileFormat, IOConfig, JoinType, ScanTask
from daft.logical.map_partition_ops import MapPartitionOp
from daft.logical.schema import Schema


ID_GEN = itertools.count()

10 changes: 5 additions & 5 deletions daft/execution/native_executor.py
@@ -6,14 +6,14 @@
NativeExecutor as _NativeExecutor,
)
from daft.daft import PyDaftExecutionConfig
from daft.logical.builder import LogicalPlanBuilder
from daft.runners.partitioning import (
MaterializedResult,
PartitionT,
)
from daft.table import MicroPartition

if TYPE_CHECKING:
from daft.logical.builder import LogicalPlanBuilder
from daft.runners.partitioning import (
MaterializedResult,
PartitionT,
)
from daft.runners.pyrunner import PyMaterializedResult


9 changes: 6 additions & 3 deletions daft/execution/physical_plan.py
@@ -17,7 +17,6 @@
import itertools
import logging
import math
import pathlib
from collections import deque
from typing import (
TYPE_CHECKING,
@@ -30,7 +29,7 @@
)

from daft.context import get_context
from daft.daft import FileFormat, IOConfig, JoinType, ResourceRequest
from daft.daft import ResourceRequest
from daft.execution import execution_step
from daft.execution.execution_step import (
Instruction,
@@ -41,7 +40,6 @@
SingleOutputPartitionTask,
)
from daft.expressions import ExpressionsProjection
from daft.logical.schema import Schema
from daft.runners.partitioning import (
MaterializedResult,
PartitionT,
@@ -53,9 +51,14 @@
T = TypeVar("T")

if TYPE_CHECKING:
import pathlib

from pyiceberg.schema import Schema as IcebergSchema
from pyiceberg.table import TableProperties as IcebergTableProperties

from daft.daft import FileFormat, IOConfig, JoinType
from daft.logical.schema import Schema


# A PhysicalPlan that is still being built - may yield both PartitionTaskBuilders and PartitionTasks.
InProgressPhysicalPlan = Iterator[Union[None, PartitionTask[PartitionT], PartitionTaskBuilder[PartitionT]]]
3 changes: 2 additions & 1 deletion daft/execution/rust_physical_plan_shim.py
@@ -17,12 +17,13 @@
from daft.logical.map_partition_ops import MapPartitionOp
from daft.logical.schema import Schema
from daft.runners.partitioning import PartitionT
from daft.table import MicroPartition

if TYPE_CHECKING:
from pyiceberg.schema import Schema as IcebergSchema
from pyiceberg.table import TableProperties as IcebergTableProperties

from daft.table import MicroPartition


def scan_with_tasks(
scan_tasks: list[ScanTask],