-
-
Notifications
You must be signed in to change notification settings - Fork 1.1k
/
pycompat.py
140 lines (102 loc) · 4.27 KB
/
pycompat.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
from __future__ import annotations
from importlib import import_module
from types import ModuleType
from typing import TYPE_CHECKING, Any, Literal
import numpy as np
from packaging.version import Version
from xarray.core.utils import is_duck_array, is_scalar, module_available
# Tuple of integer classes (builtin ``int`` plus NumPy integer scalars), in the
# form ``isinstance`` accepts.
integer_types = (int, np.integer)

if TYPE_CHECKING:
    # Module names understood by ``DuckArrayModule(mod)``.
    ModType = Literal["dask", "pint", "cupy", "sparse", "cubed", "numbagg"]
    # A tuple of array classes, suitable as the second argument of ``isinstance``.
    DuckArrayTypes = tuple[type[Any], ...]  # TODO: improve this? maybe Generic
class DuckArrayModule:
    """
    Solely for internal isinstance and version checks.

    Motivated by having to only import pint when required (as pint currently imports xarray)
    https://github.com/pydata/xarray/pull/5561#discussion_r664815718

    If the module cannot be imported, has no ``__version__``, or lacks the
    expected array attribute, the instance is marked unavailable instead of
    raising. An unrecognised but importable module name raises
    ``NotImplementedError``.
    """

    # The imported module, or None when it could not be loaded.
    module: ModuleType | None
    # Parsed ``__version__``; ``Version("0.0.0")`` when unavailable.
    version: Version
    # Array classes for ``isinstance`` checks; ``()`` when unavailable.
    type: DuckArrayTypes
    # True exactly when ``module`` is not None.
    available: bool

    def __init__(self, mod: ModType) -> None:
        loaded: ModuleType | None
        parsed_version: Version
        array_classes: DuckArrayTypes

        try:
            loaded = import_module(mod)
            parsed_version = Version(loaded.__version__)
            # The class lookups are deferred in lambdas so that they execute
            # inside this ``try``: a missing attribute (or a failing
            # ``dask.array`` import) marks the module as unavailable.
            resolvers = {
                "dask": lambda: (import_module("dask.array").Array,),
                "pint": lambda: (loaded.Quantity,),
                "cupy": lambda: (loaded.ndarray,),
                "sparse": lambda: (loaded.SparseArray,),
                "cubed": lambda: (loaded.Array,),
                # Not a duck array module, but using this system regardless,
                # to get lazy imports
                "numbagg": lambda: (),
            }
            if mod not in resolvers:
                # Deliberately not caught by the handler below.
                raise NotImplementedError
            array_classes = resolvers[mod]()
        except (ImportError, AttributeError):  # pragma: no cover
            loaded = None
            parsed_version = Version("0.0.0")
            array_classes = ()

        self.module = loaded
        self.version = parsed_version
        self.type = array_classes
        self.available = loaded is not None
# Memo of ``DuckArrayModule`` instances keyed by module name, filled lazily by
# ``_get_cached_duck_array_module``.
_cached_duck_array_modules: dict[ModType, DuckArrayModule] = {}


def _get_cached_duck_array_module(mod: ModType) -> DuckArrayModule:
    """Return the ``DuckArrayModule`` for ``mod``, building and memoising it on first use.

    If construction raises, nothing is stored, so a later call will try again.
    """
    duck_module = _cached_duck_array_modules.get(mod)
    if duck_module is None:
        duck_module = DuckArrayModule(mod)
        _cached_duck_array_modules[mod] = duck_module
    return duck_module
def array_type(mod: ModType) -> DuckArrayTypes:
    """Return the array class(es) exposed by ``mod``, for use with ``isinstance``.

    Thin wrapper over the cached ``DuckArrayModule``. The tuple is empty when
    the module could not be imported (and always for ``"numbagg"``).
    """
    cached_module = _get_cached_duck_array_module(mod)
    return cached_module.type
def mod_version(mod: ModType) -> Version:
    """Return the parsed ``__version__`` of ``mod``.

    Thin wrapper over the cached ``DuckArrayModule``. The result is
    ``Version("0.0.0")`` when the module could not be imported.
    """
    cached_module = _get_cached_duck_array_module(mod)
    return cached_module.version
def is_dask_collection(x):
    """Return whether dask's ``is_dask_collection`` accepts ``x``.

    dask is imported lazily, only once ``module_available("dask")`` confirms it
    is installed. Without dask the answer is always False.
    """
    if not module_available("dask"):
        return False
    # Alias the import so it does not shadow this function's own name.
    from dask.base import is_dask_collection as _dask_is_collection

    return _dask_is_collection(x)
def is_duck_dask_array(x):
    """Return whether ``x`` is a duck array that is also a dask collection."""
    looks_like_array = is_duck_array(x)
    if looks_like_array:
        return is_dask_collection(x)
    # Short-circuit: dask is never consulted for non-arrays.
    return looks_like_array
def is_chunked_array(x) -> bool:
    """Return whether ``x`` is a chunked duck array.

    True for dask-backed duck arrays, and for any other duck array that has a
    ``chunks`` attribute.
    """
    dask_backed = is_duck_dask_array(x)
    if dask_backed:
        return dask_backed
    looks_like_array = is_duck_array(x)
    if not looks_like_array:
        return looks_like_array
    return hasattr(x, "chunks")
def is_0d_dask_array(x):
    """Return whether ``x`` is a dask-backed duck array that ``is_scalar`` accepts.

    Per the name, this is intended to detect 0-d dask arrays; the exact
    criterion is whatever ``is_scalar`` implements.
    """
    dask_backed = is_duck_dask_array(x)
    if dask_backed:
        return is_scalar(x)
    return dask_backed
def to_numpy(data) -> np.ndarray:
    """Eagerly convert ``data`` into an in-memory ``numpy.ndarray``.

    Steps, in order:

    * ``ExplicitlyIndexed`` wrappers are unwrapped via ``get_duck_array()``;
    * chunked arrays are computed through their chunk manager;
    * cupy arrays are copied to host with ``.get()``;
    * pint quantities are stripped to their ``.magnitude``;
    * sparse arrays are densified with ``.todense()``;
    * the result is passed through ``np.asarray``.
    """
    from xarray.core.indexing import ExplicitlyIndexed
    from xarray.core.parallelcompat import get_chunked_array_type

    if isinstance(data, ExplicitlyIndexed):
        data = data.get_duck_array()

    # TODO first attempt to call .to_numpy() once some libraries implement it
    # Gate on the shared ``is_chunked_array`` predicate rather than a bare
    # ``hasattr(data, "chunks")``. A non-duck-array object that merely has a
    # ``chunks`` attribute then falls through to ``np.asarray`` instead of
    # being handed to a chunk manager, which may not recognise it.
    if is_chunked_array(data):
        chunkmanager = get_chunked_array_type(data)
        data, *_ = chunkmanager.compute(data)

    # ``array_type`` returns ``()`` for unavailable modules, so each
    # ``isinstance`` check is simply False when that library is not installed.
    if isinstance(data, array_type("cupy")):
        data = data.get()
    # pint has to be imported dynamically as pint imports xarray
    if isinstance(data, array_type("pint")):
        data = data.magnitude
    if isinstance(data, array_type("sparse")):
        data = data.todense()
    data = np.asarray(data)

    return data
def to_duck_array(data):
    """Return ``data`` as a duck array without computing chunked data.

    ``ExplicitlyIndexed`` wrappers are unwrapped via ``get_duck_array()``,
    objects that already quack like arrays are returned unchanged, and anything
    else is coerced with ``np.asarray``.
    """
    from xarray.core.indexing import ExplicitlyIndexed

    if isinstance(data, ExplicitlyIndexed):
        return data.get_duck_array()
    if is_duck_array(data):
        return data
    return np.asarray(data)