Skip to content

Commit

Permalink
check for pyarrow data types (#59)
Browse files Browse the repository at this point in the history
  • Loading branch information
jmoralez authored Jan 24, 2024
1 parent bbfd2cb commit ebf72a9
Show file tree
Hide file tree
Showing 8 changed files with 78 additions and 74 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ jobs:
run: pip install ./

- name: Run tests
run: nbdev_test --do_print --timing --flags 'matplotlib polars scipy'
run: nbdev_test --do_print --timing --flags 'matplotlib polars pyarrow scipy'

windows-tests:
runs-on: windows-latest
Expand All @@ -60,7 +60,7 @@ jobs:
run: pip install ".[dev]"

- name: Run tests
run: nbdev_test --do_print --timing --flags 'matplotlib polars scipy'
run: nbdev_test --do_print --timing --flags 'matplotlib polars pyarrow scipy'

minimal-tests:
runs-on: ${{ matrix.os }}
Expand Down
8 changes: 3 additions & 5 deletions nbs/processing.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@
"\n",
"from utilsforecast.compat import DataFrame, Series, pl, pl_DataFrame, pl_Series\n",
"from utilsforecast.validation import (\n",
" _get_np_dtype,\n",
" _is_dt_dtype,\n",
" _is_int_dtype,\n",
" ensure_shallow_copy,\n",
Expand Down Expand Up @@ -774,12 +773,11 @@
" if isinstance(times, (pd.Series, pd.Index)):\n",
" if isinstance(freq, str):\n",
" freq = pd.tseries.frequencies.to_offset(freq)\n",
" times_dtype = _get_np_dtype(times)\n",
" ints = _is_int_dtype(times_dtype) and isinstance(freq, int)\n",
" dts = _is_dt_dtype(times_dtype) and isinstance(freq, BaseOffset)\n",
" ints = _is_int_dtype(times) and isinstance(freq, int)\n",
" dts = _is_dt_dtype(times) and isinstance(freq, BaseOffset)\n",
" if not ints and not dts:\n",
" raise ValueError(\n",
" f\"Cannot offset times with data type: '{times_dtype}' \"\n",
" f\"Cannot offset times with data type: '{times.dtype}' \"\n",
" f\"using a frequency of type: '{type(freq)}'.\"\n",
" )\n",
" out = times + n * freq\n",
Expand Down
84 changes: 47 additions & 37 deletions nbs/validation.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
"import re\n",
"from typing import Optional, Union\n",
"\n",
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"from utilsforecast.compat import DataFrame, Series, pl_DataFrame, pl_Series, pl"
Expand Down Expand Up @@ -63,60 +62,73 @@
{
"cell_type": "code",
"execution_count": null,
"id": "98e45c3f-e81a-4ae8-832b-ec9531e46eb9",
"id": "70eb41a2-869c-451f-8608-27aae7d7ef73",
"metadata": {},
"outputs": [],
"source": [
"#| exporti\n",
"def _get_np_dtype(s: Union[Series, pd.Index]) -> type:\n",
" if isinstance(s, (pd.Series, pd.Index)):\n",
" dtype = s.dtype.type\n",
"def _is_int_dtype(s: Union[pd.Index, Series]) -> bool:\n",
" if isinstance(s, (pd.Index, pd.Series)):\n",
" out = pd.api.types.is_integer_dtype(s.dtype)\n",
" else:\n",
" try:\n",
" out = s.dtype.is_integer()\n",
" except AttributeError:\n",
" out = s.is_integer()\n",
" return out\n",
"\n",
"def _is_dt_dtype(s: Union[pd.Index, Series]) -> bool:\n",
" if isinstance(s, (pd.Index, pd.Series)):\n",
" out = pd.api.types.is_datetime64_any_dtype(s.dtype)\n",
" else:\n",
" dtype = s.head(1).to_numpy().dtype.type\n",
" return dtype"
" out = s.dtype in (pl.Date, pl.Datetime)\n",
" return out"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ea940c0a-5c6b-4c69-a0fd-ceea8ab0cc69",
"id": "b73338a3-b935-483b-942b-84a439ce6aab",
"metadata": {},
"outputs": [],
"source": [
"#| hide\n",
"test_eq(_get_np_dtype(pd.Index([1.0, 2.0], dtype=np.int64)), np.int64)\n",
"test_eq(_get_np_dtype(pd.Index([1.0, 2.0])), np.float64)\n",
"test_eq(_get_np_dtype(pd.to_datetime(['2000-01-01'])), np.datetime64)\n",
"test_eq(_get_np_dtype(pd.to_datetime(['2000-01-01']).to_series()), np.datetime64)"
"assert _is_int_dtype(pd.Series([1, 2]))\n",
"assert _is_int_dtype(pd.Index([1, 2], dtype='uint8'))\n",
"assert not _is_int_dtype(pd.Series([1.0]))\n",
"assert _is_dt_dtype(pd.to_datetime(['2000-01-01']))\n",
"assert _is_dt_dtype(pd.to_datetime(['2000-01-01'], utc=True))\n",
"assert _is_dt_dtype(pd.to_datetime(['2000-01-01']).astype('datetime64[s]'))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "266de0cf-2f1e-4165-8e6d-d8c68d4b03a1",
"id": "80f090c9-c293-47df-a95a-af530737201e",
"metadata": {},
"outputs": [],
"source": [
"#| hide\n",
"#| polars\n",
"test_eq(_get_np_dtype(pl.Series([1, 2], dtype=pl.Int64)), np.int64)\n",
"test_eq(_get_np_dtype(pl.Series([1.0, 2.0])), np.float64)\n",
"test_eq(_get_np_dtype(pl.Series([datetime.datetime(2000, 1, 1)])), np.datetime64)"
"#| pyarrow\n",
"assert _is_int_dtype(pd.Series([1, 2], dtype='int32[pyarrow]'))\n",
"assert _is_dt_dtype(pd.to_datetime(['2000-01-01']).astype('timestamp[ns][pyarrow]'))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "70eb41a2-869c-451f-8608-27aae7d7ef73",
"id": "242c213e-f5c6-400b-b905-605a9fa44fe6",
"metadata": {},
"outputs": [],
"source": [
"#| exporti\n",
"def _is_int_dtype(dtype: type) -> bool:\n",
" return np.issubdtype(dtype, np.integer)\n",
"\n",
"def _is_dt_dtype(dtype: type) -> bool:\n",
" return np.issubdtype(dtype, np.datetime64)"
"#| hide\n",
"#| polars\n",
"assert _is_int_dtype(pl.Series([1, 2]))\n",
"assert _is_int_dtype(pl.Series([1, 2], dtype=pl.UInt8))\n",
"assert not _is_int_dtype(pl.Series([1.0]))\n",
"assert _is_dt_dtype(pl.Series([datetime.date(2000, 1, 1)]))\n",
"assert _is_dt_dtype(pl.Series([datetime.datetime(2000, 1, 1)]))\n",
"assert _is_dt_dtype(pl.Series([datetime.datetime(2000, 1, 1, tzinfo=datetime.timezone.utc)]))"
]
},
{
Expand All @@ -128,8 +140,7 @@
"source": [
"#| exporti\n",
"def _is_dt_or_int(s: Series) -> bool:\n",
" dtype = _get_np_dtype(s)\n",
" return _is_dt_dtype(dtype) or _is_int_dtype(dtype)"
" return _is_dt_dtype(s) or _is_int_dtype(s)"
]
},
{
Expand Down Expand Up @@ -270,15 +281,15 @@
"\n",
" # time col\n",
" if not _is_dt_or_int(df[time_col]):\n",
" times_dtype = df[time_col].head(1).to_numpy().dtype\n",
" times_dtype = df[time_col].dtype\n",
" raise ValueError(f\"The time column ('{time_col}') should have either timestamps or integers, got '{times_dtype}'.\")\n",
"\n",
" # target col\n",
" if target_col is None:\n",
" return None\n",
" target = df[target_col]\n",
" if isinstance(target, pd.Series):\n",
" is_numeric = np.issubdtype(target.dtype.type, np.number)\n",
" is_numeric = pd.api.types.is_numeric_dtype(target.dtype)\n",
" else:\n",
" try:\n",
" is_numeric = target.dtype.is_numeric()\n",
Expand Down Expand Up @@ -310,14 +321,14 @@
"text/markdown": [
"---\n",
"\n",
"[source](https://github.com/Nixtla/utilsforecast/blob/main/utilsforecast/validation.py#L66){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"[source](https://github.com/Nixtla/utilsforecast/blob/main/utilsforecast/validation.py#L75){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### validate_format\n",
"\n",
"> validate_format\n",
"> (df:Union[pandas.core.frame.DataFrame,polars.dataframe.f\n",
"> rame.DataFrame], id_col:str='unique_id',\n",
"> time_col:str='ds', target_col:str='y')\n",
"> time_col:str='ds', target_col:Optional[str]='y')\n",
"\n",
"Ensure DataFrame has expected format.\n",
"\n",
Expand All @@ -326,20 +337,20 @@
"| df | Union | | DataFrame with time series in long format. |\n",
"| id_col | str | unique_id | Column that identifies each serie. |\n",
"| time_col | str | ds | Column that identifies each timestamp. |\n",
"| target_col | str | y | Column that contains the target. |\n",
"| target_col | Optional | y | Column that contains the target. |\n",
"| **Returns** | **None** | | |"
],
"text/plain": [
"---\n",
"\n",
"[source](https://github.com/Nixtla/utilsforecast/blob/main/utilsforecast/validation.py#L66){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"[source](https://github.com/Nixtla/utilsforecast/blob/main/utilsforecast/validation.py#L75){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### validate_format\n",
"\n",
"> validate_format\n",
"> (df:Union[pandas.core.frame.DataFrame,polars.dataframe.f\n",
"> rame.DataFrame], id_col:str='unique_id',\n",
"> time_col:str='ds', target_col:str='y')\n",
"> time_col:str='ds', target_col:Optional[str]='y')\n",
"\n",
"Ensure DataFrame has expected format.\n",
"\n",
Expand All @@ -348,7 +359,7 @@
"| df | Union | | DataFrame with time series in long format. |\n",
"| id_col | str | unique_id | Column that identifies each serie. |\n",
"| time_col | str | ds | Column that identifies each timestamp. |\n",
"| target_col | str | y | Column that contains the target. |\n",
"| target_col | Optional | y | Column that contains the target. |\n",
"| **Returns** | **None** | | |"
]
},
Expand Down Expand Up @@ -407,13 +418,12 @@
" times: Series,\n",
" freq: Union[str, int],\n",
") -> None:\n",
" time_dtype = times.head(1).to_numpy().dtype\n",
" if _is_int_dtype(time_dtype) and not isinstance(freq, int):\n",
" if _is_int_dtype(times) and not isinstance(freq, int):\n",
" raise ValueError(\n",
" \"Time column contains integers but the specified frequency is not an integer. \"\n",
" \"Please provide a valid integer, e.g. `freq=1`\"\n",
" )\n",
" if _is_dt_dtype(time_dtype) and isinstance(freq, int):\n",
" if _is_dt_dtype(times) and isinstance(freq, int):\n",
" raise ValueError(\n",
" \"Time column contains timestamps but the specified frequency is an integer. \"\n",
" \"Please provide a valid pandas or polars offset, e.g. `freq='D'` or `freq='1d'`.\"\n",
Expand Down
4 changes: 2 additions & 2 deletions settings.ini
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
[DEFAULT]
repo = utilsforecast
lib_name = utilsforecast
version = 0.0.25
version = 0.0.26
min_python = 3.8
license = apache2
black_formatting = True
doc_path = _docs
lib_path = utilsforecast
nbs_path = nbs
recursive = True
tst_flags = matplotlib polars scipy
tst_flags = matplotlib polars pyarrow scipy
put_version_in_init = True
branch = main
custom_sidebar = True
Expand Down
2 changes: 1 addition & 1 deletion utilsforecast/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.0.25"
__version__ = "0.0.26"
4 changes: 1 addition & 3 deletions utilsforecast/_modidx.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,9 +186,7 @@
'utilsforecast/target_transforms.py'),
'utilsforecast.target_transforms._transform': ( 'target_transforms.html#_transform',
'utilsforecast/target_transforms.py')},
'utilsforecast.validation': { 'utilsforecast.validation._get_np_dtype': ( 'validation.html#_get_np_dtype',
'utilsforecast/validation.py'),
'utilsforecast.validation._is_dt_dtype': ( 'validation.html#_is_dt_dtype',
'utilsforecast.validation': { 'utilsforecast.validation._is_dt_dtype': ( 'validation.html#_is_dt_dtype',
'utilsforecast/validation.py'),
'utilsforecast.validation._is_dt_or_int': ( 'validation.html#_is_dt_or_int',
'utilsforecast/validation.py'),
Expand Down
8 changes: 3 additions & 5 deletions utilsforecast/processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@

from .compat import DataFrame, Series, pl, pl_DataFrame, pl_Series
from utilsforecast.validation import (
_get_np_dtype,
_is_dt_dtype,
_is_int_dtype,
ensure_shallow_copy,
Expand Down Expand Up @@ -343,12 +342,11 @@ def offset_times(
if isinstance(times, (pd.Series, pd.Index)):
if isinstance(freq, str):
freq = pd.tseries.frequencies.to_offset(freq)
times_dtype = _get_np_dtype(times)
ints = _is_int_dtype(times_dtype) and isinstance(freq, int)
dts = _is_dt_dtype(times_dtype) and isinstance(freq, BaseOffset)
ints = _is_int_dtype(times) and isinstance(freq, int)
dts = _is_dt_dtype(times) and isinstance(freq, BaseOffset)
if not ints and not dts:
raise ValueError(
f"Cannot offset times with data type: '{times_dtype}' "
f"Cannot offset times with data type: '{times.dtype}' "
f"using a frequency of type: '{type(freq)}'."
)
out = times + n * freq
Expand Down
38 changes: 19 additions & 19 deletions utilsforecast/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,31 +7,32 @@
import re
from typing import Optional, Union

import numpy as np
import pandas as pd

from .compat import DataFrame, Series, pl_DataFrame, pl_Series, pl

# %% ../nbs/validation.ipynb 5
def _get_np_dtype(s: Union[Series, pd.Index]) -> type:
if isinstance(s, (pd.Series, pd.Index)):
dtype = s.dtype.type
def _is_int_dtype(s: Union[pd.Index, Series]) -> bool:
if isinstance(s, (pd.Index, pd.Series)):
out = pd.api.types.is_integer_dtype(s.dtype)
else:
dtype = s.head(1).to_numpy().dtype.type
return dtype

# %% ../nbs/validation.ipynb 8
def _is_int_dtype(dtype: type) -> bool:
return np.issubdtype(dtype, np.integer)
try:
out = s.dtype.is_integer()
except AttributeError:
out = s.is_integer()
return out


def _is_dt_dtype(dtype: type) -> bool:
return np.issubdtype(dtype, np.datetime64)
def _is_dt_dtype(s: Union[pd.Index, Series]) -> bool:
if isinstance(s, (pd.Index, pd.Series)):
out = pd.api.types.is_datetime64_any_dtype(s.dtype)
else:
out = s.dtype in (pl.Date, pl.Datetime)
return out

# %% ../nbs/validation.ipynb 9
def _is_dt_or_int(s: Series) -> bool:
dtype = _get_np_dtype(s)
return _is_dt_dtype(dtype) or _is_int_dtype(dtype)
return _is_dt_dtype(s) or _is_int_dtype(s)

# %% ../nbs/validation.ipynb 10
def ensure_shallow_copy(df: pd.DataFrame) -> pd.DataFrame:
Expand Down Expand Up @@ -109,7 +110,7 @@ def validate_format(

# time col
if not _is_dt_or_int(df[time_col]):
times_dtype = df[time_col].head(1).to_numpy().dtype
times_dtype = df[time_col].dtype
raise ValueError(
f"The time column ('{time_col}') should have either timestamps or integers, got '{times_dtype}'."
)
Expand All @@ -119,7 +120,7 @@ def validate_format(
return None
target = df[target_col]
if isinstance(target, pd.Series):
is_numeric = np.issubdtype(target.dtype.type, np.number)
is_numeric = pd.api.types.is_numeric_dtype(target.dtype)
else:
try:
is_numeric = target.dtype.is_numeric()
Expand All @@ -135,13 +136,12 @@ def validate_freq(
times: Series,
freq: Union[str, int],
) -> None:
time_dtype = times.head(1).to_numpy().dtype
if _is_int_dtype(time_dtype) and not isinstance(freq, int):
if _is_int_dtype(times) and not isinstance(freq, int):
raise ValueError(
"Time column contains integers but the specified frequency is not an integer. "
"Please provide a valid integer, e.g. `freq=1`"
)
if _is_dt_dtype(time_dtype) and isinstance(freq, int):
if _is_dt_dtype(times) and isinstance(freq, int):
raise ValueError(
"Time column contains timestamps but the specified frequency is an integer. "
"Please provide a valid pandas or polars offset, e.g. `freq='D'` or `freq='1d'`."
Expand Down

0 comments on commit ebf72a9

Please sign in to comment.