diff --git a/.github/workflows/build-docs.yaml b/.github/workflows/build-docs.yaml index 38c820e8..5b61bf30 100644 --- a/.github/workflows/build-docs.yaml +++ b/.github/workflows/build-docs.yaml @@ -22,7 +22,7 @@ jobs: repository: Nixtla/docs ref: scripts path: docs-scripts - - uses: actions/setup-python@39cd14951b08e74b54015e9e001cdefcf80e669f # 5.1.1 + - uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # 5.2.0 with: cache: "pip" python-version: "3.10" diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index d11b5f7e..645f5879 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -29,7 +29,7 @@ jobs: uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7 - name: Set up python - uses: actions/setup-python@39cd14951b08e74b54015e9e001cdefcf80e669f # 5.1.1 + uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # 5.2.0 with: python-version: ${{ matrix.python-version }} @@ -50,7 +50,7 @@ jobs: uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7 - name: Set up python - uses: actions/setup-python@39cd14951b08e74b54015e9e001cdefcf80e669f # 5.1.1 + uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # 5.2.0 with: python-version: ${{ matrix.python-version }} @@ -72,7 +72,7 @@ jobs: uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7 - name: Set up python - uses: actions/setup-python@39cd14951b08e74b54015e9e001cdefcf80e669f # 5.1.1 + uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # 5.2.0 with: python-version: ${{ matrix.python-version }} diff --git a/.github/workflows/deploy-readme.yaml b/.github/workflows/deploy-readme.yaml index 8a1e3757..cfd0b749 100644 --- a/.github/workflows/deploy-readme.yaml +++ b/.github/workflows/deploy-readme.yaml @@ -25,7 +25,7 @@ jobs: persist-credentials: false - name: Set up python - uses: actions/setup-python@39cd14951b08e74b54015e9e001cdefcf80e669f # 5.1.1 + uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # 5.2.0 with: python-version: "3.10" diff --git a/.github/workflows/lint.yaml b/.github/workflows/lint.yaml index 0e95cccf..a6d30a74 100644 --- a/.github/workflows/lint.yaml +++ b/.github/workflows/lint.yaml @@ -14,7 +14,7 @@ jobs: uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7 - name: Set up python - uses: actions/setup-python@39cd14951b08e74b54015e9e001cdefcf80e669f # 5.1.1 + uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # 5.2.0 with: python-version: '3.10' diff --git a/.github/workflows/models-performance.yaml b/.github/workflows/models-performance.yaml index 4a131ef2..9a41a7b5 100644 --- a/.github/workflows/models-performance.yaml +++ b/.github/workflows/models-performance.yaml @@ -26,7 +26,7 @@ jobs: uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7 - name: Set up python - uses: actions/setup-python@39cd14951b08e74b54015e9e001cdefcf80e669f # 5.1.1 + uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # 5.2.0 with: python-version: "3.10" diff --git a/.github/workflows/python-publish.yml b/.github/workflows/python-publish.yml index 1fa437fc..98e2330c 100644 --- a/.github/workflows/python-publish.yml +++ b/.github/workflows/python-publish.yml @@ -13,7 +13,7 @@ jobs: - uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7 - name: Set up Python - uses: actions/setup-python@39cd14951b08e74b54015e9e001cdefcf80e669f # 5.1.1 + uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # 5.2.0 with: python-version: '3.10' @@ -24,16 +24,5 @@ jobs: run: python -m build - name: Publish nixtla package - uses: pypa/gh-action-pypi-publish@ec4db0b4ddc65acdf4bff5fa45ac92d78b56bdf0 # v1.9.0 + uses: pypa/gh-action-pypi-publish@8a08d616893759ef8e1aa1f2785787c0b97e20d6 # v1.10.0 - - name: Build nixtlats package - run: > - rm -rf build dist && - mv nixtla nixtlats && - sed -i 's/name="nixtla"/name="nixtlats"/g' setup.py && - find nixtlats -type f -name '*.py' -exec sed -i 's/from nixtla/from nixtlats/g' {} + && - echo -e 'import warnings\nwarnings.warn("This package is deprecated, please install nixtla instead.", category=FutureWarning)' >> nixtlats/__init__.py && - python -m build - - - name: Publish nixtlats package - uses: pypa/gh-action-pypi-publish@ec4db0b4ddc65acdf4bff5fa45ac92d78b56bdf0 # v1.9.0 diff --git a/.github/workflows/test-publish.yml b/.github/workflows/test-publish.yml index 8339f9d2..7c6240a1 100644 --- a/.github/workflows/test-publish.yml +++ b/.github/workflows/test-publish.yml @@ -12,7 +12,7 @@ jobs: - uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7 - name: Set up Python - uses: actions/setup-python@39cd14951b08e74b54015e9e001cdefcf80e669f # 5.1.1 + uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # 5.2.0 with: python-version: '3.10' @@ -23,20 +23,7 @@ jobs: run: python -m build - name: Publish nixtla package - uses: pypa/gh-action-pypi-publish@ec4db0b4ddc65acdf4bff5fa45ac92d78b56bdf0 # v1.9.0 + uses: pypa/gh-action-pypi-publish@8a08d616893759ef8e1aa1f2785787c0b97e20d6 # v1.10.0 with: repository-url: https://test.pypi.org/legacy/ - - name: Build nixtlats package - run: > - rm -rf build dist && - mv nixtla nixtlats && - sed -i 's/name="nixtla"/name="nixtlats"/g' setup.py && - find nixtlats -type f -name '*.py' -exec sed -i 's/from nixtla/from nixtlats/g' {} + && - echo -e 'import warnings\nwarnings.warn("This package is deprecated, please install nixtla instead.", category=FutureWarning)' >> nixtlats/__init__.py && - python -m build - - - name: Publish nixtlats package - uses: pypa/gh-action-pypi-publish@ec4db0b4ddc65acdf4bff5fa45ac92d78b56bdf0 # v1.9.0 - with: - repository-url: https://test.pypi.org/legacy/ diff --git a/.gitignore b/.gitignore index 6c9486e5..6943307a 100644 --- a/.gitignore +++ b/.gitignore @@ -28,3 +28,4 @@ nbs/.last_checked longhorizon data *.rda +nbs/_extensions diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 77f8571d..83a3ab3c 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,9 +1,16 @@ repos: + - repo: local + hooks: + - id: nbdev_clean + name: Clean notebooks + entry: sh -c 'nbdev_clean && nbdev_clean --fname nbs/src --clear_all' + language: system + - repo: https://github.com/fastai/nbdev rev: 2.2.10 hooks: - - id: nbdev_clean - id: nbdev_export + - repo: https://github.com/astral-sh/ruff-pre-commit rev: v0.2.1 hooks: diff --git a/nbs/date_features.ipynb b/nbs/date_features.ipynb deleted file mode 100644 index 2b07251d..00000000 --- a/nbs/date_features.ipynb +++ /dev/null @@ -1,642 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "id": "6d7d1bce-e447-4702-baf5-2bfb8d112635", - "metadata": {}, - "outputs": [], - "source": [ - "#| default_exp date_features" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "713b6f07-67b5-4fcc-bbd8-5aa3107a4463", - "metadata": {}, - "outputs": [], - "source": [ - "#| hide \n", - "%load_ext autoreload\n", - "%autoreload 2" - ] - }, - { - "cell_type": "markdown", - "id": "bf0edfb1-a7b3-4d6e-acbd-f41758af1779", - "metadata": {}, - "source": [ - "# Date Features " - ] - }, - { - "cell_type": "markdown", - "id": "5980c8e1-0416-4e2b-a335-08c7a809916e", - "metadata": {}, - "source": [ - "Useful classes to generate date features and add them to `TimeGPT`." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "b22a584a-1591-46bf-860e-20f0a519e7c3", - "metadata": {}, - "outputs": [], - "source": [ - "#| export\n", - "from typing import Dict, List\n", - "\n", - "import pandas as pd" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "d5f2f081-c3ff-4ba8-a7d0-b1dc992f5e09", - "metadata": {}, - "outputs": [], - "source": [ - "#| hide\n", - "from nbdev.showdoc import show_doc" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "02b6979b-b97a-41eb-9bf1-8e068da1ebcb", - "metadata": {}, - "outputs": [], - "source": [ - "#| exporti\n", - "def _transform_dict_holidays(dict_holidays_dates):\n", - " dict_holidays = {}\n", - " for key, value in dict_holidays_dates.items():\n", - " if value not in dict_holidays:\n", - " dict_holidays[value] = []\n", - " dict_holidays[value].append(key)\n", - " return dict_holidays" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "d442cd19-c0e3-4aac-a720-07203fef41c6", - "metadata": {}, - "outputs": [], - "source": [ - "#| exporti\n", - "def _get_holidays_df(dates, categories, holiday_extractor, supported_categories):\n", - " years = dates.year.unique().tolist()\n", - " total_holidays = dict()\n", - " for cat in categories:\n", - " if cat not in supported_categories:\n", - " raise Exception(f'Holidays for {cat} not available, please remove it.')\n", - " dict_holidays = _transform_dict_holidays(holiday_extractor(cat, years=years))\n", - " for key, val in dict_holidays.items():\n", - " total_holidays[f'{cat}_{key}'] = [int(ds.date() in val) for ds in dates]\n", - " return pd.DataFrame(total_holidays, index=dates)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "8bedda61-b017-430d-85ce-8a369b8e1bb4", - "metadata": {}, - "outputs": [], - "source": [ - "#| export\n", - "class CountryHolidays:\n", - " \"\"\"Given a list of countries, returns a dataframe with holidays for each country.\"\"\"\n", - " \n", - " def __init__(self, countries: List[str]):\n", - " self.countries = countries\n", - " \n", - " def __call__(self, dates: pd.DatetimeIndex):\n", - " try:\n", - " from holidays.utils import country_holidays\n", - " from holidays.utils import list_supported_countries\n", - " except ModuleNotFoundError:\n", - " raise Exception(\n", - " 'You have to install additional libraries to use holidays, '\n", - " 'please install them using `pip install \"nixtla[date_extras]\"`'\n", - " )\n", - " return _get_holidays_df(dates, self.countries, country_holidays, list_supported_countries())\n", - " \n", - " def __name__(self):\n", - " return 'CountryHolidays'" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "bb22ad7a-fc3b-4ce1-aa2e-16f3aaf2c66d", - "metadata": {}, - "outputs": [ - { - "data": { - "text/markdown": [ - "---\n", - "\n", - "#### CountryHolidays\n", - "\n", - "> CountryHolidays (countries:List[str])\n", - "\n", - "Given a list of countries, returns a dataframe with holidays for each country." - ], - "text/plain": [ - "---\n", - "\n", - "#### CountryHolidays\n", - "\n", - "> CountryHolidays (countries:List[str])\n", - "\n", - "Given a list of countries, returns a dataframe with holidays for each country." - ] - }, - "execution_count": null, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "show_doc(CountryHolidays, title_level=4)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "bcc7c4d1-663e-4835-8e95-5ca756cd35ce", - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
US_New Year's DayUS_Martin Luther King Jr. DayUS_Washington's BirthdayUS_Memorial DayUS_Independence DayUS_Labor DayUS_Columbus DayUS_Veterans DayUS_Veterans Day (Observed)US_Thanksgiving...MX_Día de la Independencia [Independence Day]MX_Día de la Independencia [Independence Day] (Observed)MX_Día de la Revolución [Revolution Day] (Observed)MX_Día de la Revolución [Revolution Day]MX_Transmisión del Poder Ejecutivo Federal [Change of Federal Government]MX_Transmisión del Poder Ejecutivo Federal [Change of Federal Government] (Observed)MX_Navidad [Christmas]MX_Día de la Constitución [Constitution Day]MX_Año Nuevo [New Year's Day] (Observed)MX_Día del Trabajo [Labour Day] (Observed)
2018-09-030000010000...0000000000
2018-09-040000000000...0000000000
2018-09-050000000000...0000000000
2018-09-060000000000...0000000000
2018-09-070000000000...0000000000
\n", - "

5 rows × 31 columns

\n", - "
" - ], - "text/plain": [ - " US_New Year's Day US_Martin Luther King Jr. Day \\\n", - "2018-09-03 0 0 \n", - "2018-09-04 0 0 \n", - "2018-09-05 0 0 \n", - "2018-09-06 0 0 \n", - "2018-09-07 0 0 \n", - "\n", - " US_Washington's Birthday US_Memorial Day US_Independence Day \\\n", - "2018-09-03 0 0 0 \n", - "2018-09-04 0 0 0 \n", - "2018-09-05 0 0 0 \n", - "2018-09-06 0 0 0 \n", - "2018-09-07 0 0 0 \n", - "\n", - " US_Labor Day US_Columbus Day US_Veterans Day \\\n", - "2018-09-03 1 0 0 \n", - "2018-09-04 0 0 0 \n", - "2018-09-05 0 0 0 \n", - "2018-09-06 0 0 0 \n", - "2018-09-07 0 0 0 \n", - "\n", - " US_Veterans Day (Observed) US_Thanksgiving ... \\\n", - "2018-09-03 0 0 ... \n", - "2018-09-04 0 0 ... \n", - "2018-09-05 0 0 ... \n", - "2018-09-06 0 0 ... \n", - "2018-09-07 0 0 ... \n", - "\n", - " MX_Día de la Independencia [Independence Day] \\\n", - "2018-09-03 0 \n", - "2018-09-04 0 \n", - "2018-09-05 0 \n", - "2018-09-06 0 \n", - "2018-09-07 0 \n", - "\n", - " MX_Día de la Independencia [Independence Day] (Observed) \\\n", - "2018-09-03 0 \n", - "2018-09-04 0 \n", - "2018-09-05 0 \n", - "2018-09-06 0 \n", - "2018-09-07 0 \n", - "\n", - " MX_Día de la Revolución [Revolution Day] (Observed) \\\n", - "2018-09-03 0 \n", - "2018-09-04 0 \n", - "2018-09-05 0 \n", - "2018-09-06 0 \n", - "2018-09-07 0 \n", - "\n", - " MX_Día de la Revolución [Revolution Day] \\\n", - "2018-09-03 0 \n", - "2018-09-04 0 \n", - "2018-09-05 0 \n", - "2018-09-06 0 \n", - "2018-09-07 0 \n", - "\n", - " MX_Transmisión del Poder Ejecutivo Federal [Change of Federal Government] \\\n", - "2018-09-03 0 \n", - "2018-09-04 0 \n", - "2018-09-05 0 \n", - "2018-09-06 0 \n", - "2018-09-07 0 \n", - "\n", - " MX_Transmisión del Poder Ejecutivo Federal [Change of Federal Government] (Observed) \\\n", - "2018-09-03 0 \n", - "2018-09-04 0 \n", - "2018-09-05 0 \n", - "2018-09-06 0 \n", - "2018-09-07 0 \n", - "\n", - " MX_Navidad [Christmas] \\\n", - "2018-09-03 0 \n", - "2018-09-04 0 \n", - "2018-09-05 0 \n", - "2018-09-06 0 \n", - "2018-09-07 0 \n", - "\n", - " MX_Día de la Constitución [Constitution Day] \\\n", - "2018-09-03 0 \n", - "2018-09-04 0 \n", - "2018-09-05 0 \n", - "2018-09-06 0 \n", - "2018-09-07 0 \n", - "\n", - " MX_Año Nuevo [New Year's Day] (Observed) \\\n", - "2018-09-03 0 \n", - "2018-09-04 0 \n", - "2018-09-05 0 \n", - "2018-09-06 0 \n", - "2018-09-07 0 \n", - "\n", - " MX_Día del Trabajo [Labour Day] (Observed) \n", - "2018-09-03 0 \n", - "2018-09-04 0 \n", - "2018-09-05 0 \n", - "2018-09-06 0 \n", - "2018-09-07 0 \n", - "\n", - "[5 rows x 31 columns]" - ] - }, - "execution_count": null, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "c_holidays = CountryHolidays(countries=['US', 'MX'])\n", - "periods = 365 * 5\n", - "dates = pd.date_range(end='2023-09-01', periods=periods)\n", - "holidays_df = c_holidays(dates)\n", - "holidays_df.head()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "c548b898-2f7e-4b11-8426-dd50ed68931c", - "metadata": {}, - "outputs": [], - "source": [ - "#| hide\n", - "# test shape of holidays_df\n", - "assert len(holidays_df) == periods" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "4aee4fca-68f0-4817-80c0-697b6d04120a", - "metadata": {}, - "outputs": [], - "source": [ - "#| export\n", - "class SpecialDates:\n", - " \"\"\"Given a dictionary of categories and dates, returns a dataframe with the special dates.\"\"\"\n", - " \n", - " def __init__(self, special_dates: Dict[str, List[str]]):\n", - " self.special_dates = special_dates\n", - " \n", - " def __call__(self, dates: pd.DatetimeIndex):\n", - " total_special_dates = dict()\n", - " for key, val in self.special_dates.items():\n", - " date_vals = [ds.date() for ds in pd.to_datetime(val)]\n", - " total_special_dates[key] = [int(ds.date() in date_vals) for ds in dates]\n", - " return pd.DataFrame(total_special_dates, index=dates)\n", - " \n", - " def __name__(self):\n", - " return 'SpecialDates'" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "9ef5f615-2d57-42bc-ace2-e3e77d82656d", - "metadata": {}, - "outputs": [ - { - "data": { - "text/markdown": [ - "---\n", - "\n", - "#### SpecialDates\n", - "\n", - "> SpecialDates (special_dates:Dict[str,List[str]])\n", - "\n", - "Given a dictionary of categories and dates, returns a dataframe with the special dates." - ], - "text/plain": [ - "---\n", - "\n", - "#### SpecialDates\n", - "\n", - "> SpecialDates (special_dates:Dict[str,List[str]])\n", - "\n", - "Given a dictionary of categories and dates, returns a dataframe with the special dates." - ] - }, - "execution_count": null, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "show_doc(SpecialDates, title_level=4)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "f967cda6-4958-4e03-93c9-d881ec3c2548", - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
Important DatesVery Important Dates
2018-09-0300
2018-09-0400
2018-09-0500
2018-09-0600
2018-09-0700
\n", - "
" - ], - "text/plain": [ - " Important Dates Very Important Dates\n", - "2018-09-03 0 0\n", - "2018-09-04 0 0\n", - "2018-09-05 0 0\n", - "2018-09-06 0 0\n", - "2018-09-07 0 0" - ] - }, - "execution_count": null, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "special_dates = SpecialDates(\n", - " special_dates={\n", - " 'Important Dates': ['2021-02-26', '2020-02-26'],\n", - " 'Very Important Dates': ['2021-01-26', '2020-01-26', '2019-01-26']\n", - " }\n", - ")\n", - "periods = 365 * 5\n", - "dates = pd.date_range(end='2023-09-01', periods=periods)\n", - "holidays_df = special_dates(dates)\n", - "holidays_df.head()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "de9f4656-d808-490f-8b1a-621396c3e7ed", - "metadata": {}, - "outputs": [], - "source": [ - "#| hide\n", - "# test shape of holidays_df\n", - "assert len(holidays_df) == periods\n", - "assert holidays_df.sum().sum() == 5" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "python3", - "language": "python", - "name": "python3" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/nbs/docs/getting-started/21_polars_quickstart.ipynb b/nbs/docs/getting-started/21_polars_quickstart.ipynb index ab4cf865..6b695762 100644 --- a/nbs/docs/getting-started/21_polars_quickstart.ipynb +++ b/nbs/docs/getting-started/21_polars_quickstart.ipynb @@ -51,7 +51,7 @@ "id": "6ecd9d32-9178-4768-bffa-d70c93c98311", "metadata": {}, "source": [ - "# TimeGPT Quickstart\n", + "# TimeGPT Quickstart (Polars)\n", "\n", "> TimeGPT is a production ready, generative pretrained transformer for time series. It's capable of accurately predicting various domains such as retail, electricity, finance, and IoT with just a few lines of code 🚀." ] diff --git a/nbs/src/date_features.ipynb b/nbs/src/date_features.ipynb new file mode 100644 index 00000000..be836ebb --- /dev/null +++ b/nbs/src/date_features.ipynb @@ -0,0 +1,244 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "6d7d1bce-e447-4702-baf5-2bfb8d112635", + "metadata": {}, + "outputs": [], + "source": [ + "#| default_exp date_features" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "713b6f07-67b5-4fcc-bbd8-5aa3107a4463", + "metadata": {}, + "outputs": [], + "source": [ + "#| hide \n", + "%load_ext autoreload\n", + "%autoreload 2" + ] + }, + { + "cell_type": "markdown", + "id": "bf0edfb1-a7b3-4d6e-acbd-f41758af1779", + "metadata": {}, + "source": [ + "# Date Features " + ] + }, + { + "cell_type": "markdown", + "id": "5980c8e1-0416-4e2b-a335-08c7a809916e", + "metadata": {}, + "source": [ + "Useful classes to generate date features and add them to `TimeGPT`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b22a584a-1591-46bf-860e-20f0a519e7c3", + "metadata": {}, + "outputs": [], + "source": [ + "#| export\n", + "from typing import Dict, List\n", + "\n", + "import pandas as pd" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d5f2f081-c3ff-4ba8-a7d0-b1dc992f5e09", + "metadata": {}, + "outputs": [], + "source": [ + "#| hide\n", + "from nbdev.showdoc import show_doc" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "02b6979b-b97a-41eb-9bf1-8e068da1ebcb", + "metadata": {}, + "outputs": [], + "source": [ + "#| exporti\n", + "def _transform_dict_holidays(dict_holidays_dates):\n", + " dict_holidays = {}\n", + " for key, value in dict_holidays_dates.items():\n", + " if value not in dict_holidays:\n", + " dict_holidays[value] = []\n", + " dict_holidays[value].append(key)\n", + " return dict_holidays" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d442cd19-c0e3-4aac-a720-07203fef41c6", + "metadata": {}, + "outputs": [], + "source": [ + "#| exporti\n", + "def _get_holidays_df(dates, categories, holiday_extractor, supported_categories):\n", + " years = dates.year.unique().tolist()\n", + " total_holidays = dict()\n", + " for cat in categories:\n", + " if cat not in supported_categories:\n", + " raise Exception(f'Holidays for {cat} not available, please remove it.')\n", + " dict_holidays = _transform_dict_holidays(holiday_extractor(cat, years=years))\n", + " for key, val in dict_holidays.items():\n", + " total_holidays[f'{cat}_{key}'] = [int(ds.date() in val) for ds in dates]\n", + " return pd.DataFrame(total_holidays, index=dates)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8bedda61-b017-430d-85ce-8a369b8e1bb4", + "metadata": {}, + "outputs": [], + "source": [ + "#| export\n", + "class CountryHolidays:\n", + " \"\"\"Given a list of countries, returns a dataframe with holidays for each country.\"\"\"\n", + " \n", + " def __init__(self, countries: List[str]):\n", + " self.countries = countries\n", + " \n", + " def __call__(self, dates: pd.DatetimeIndex):\n", + " try:\n", + " from holidays.utils import country_holidays\n", + " from holidays.utils import list_supported_countries\n", + " except ModuleNotFoundError:\n", + " raise Exception(\n", + " 'You have to install additional libraries to use holidays, '\n", + " 'please install them using `pip install \"nixtla[date_extras]\"`'\n", + " )\n", + " return _get_holidays_df(dates, self.countries, country_holidays, list_supported_countries())\n", + " \n", + " def __name__(self):\n", + " return 'CountryHolidays'" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bb22ad7a-fc3b-4ce1-aa2e-16f3aaf2c66d", + "metadata": {}, + "outputs": [], + "source": [ + "show_doc(CountryHolidays, title_level=4)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bcc7c4d1-663e-4835-8e95-5ca756cd35ce", + "metadata": {}, + "outputs": [], + "source": [ + "c_holidays = CountryHolidays(countries=['US', 'MX'])\n", + "periods = 365 * 5\n", + "dates = pd.date_range(end='2023-09-01', periods=periods)\n", + "holidays_df = c_holidays(dates)\n", + "holidays_df.head()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c548b898-2f7e-4b11-8426-dd50ed68931c", + "metadata": {}, + "outputs": [], + "source": [ + "#| hide\n", + "# test shape of holidays_df\n", + "assert len(holidays_df) == periods" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4aee4fca-68f0-4817-80c0-697b6d04120a", + "metadata": {}, + "outputs": [], + "source": [ + "#| export\n", + "class SpecialDates:\n", + " \"\"\"Given a dictionary of categories and dates, returns a dataframe with the special dates.\"\"\"\n", + " \n", + " def __init__(self, special_dates: Dict[str, List[str]]):\n", + " self.special_dates = special_dates\n", + " \n", + " def __call__(self, dates: pd.DatetimeIndex):\n", + " total_special_dates = dict()\n", + " for key, val in self.special_dates.items():\n", + " date_vals = [ds.date() for ds in pd.to_datetime(val)]\n", + " total_special_dates[key] = [int(ds.date() in date_vals) for ds in dates]\n", + " return pd.DataFrame(total_special_dates, index=dates)\n", + " \n", + " def __name__(self):\n", + " return 'SpecialDates'" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9ef5f615-2d57-42bc-ace2-e3e77d82656d", + "metadata": {}, + "outputs": [], + "source": [ + "show_doc(SpecialDates, title_level=4)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f967cda6-4958-4e03-93c9-d881ec3c2548", + "metadata": {}, + "outputs": [], + "source": [ + "special_dates = SpecialDates(\n", + " special_dates={\n", + " 'Important Dates': ['2021-02-26', '2020-02-26'],\n", + " 'Very Important Dates': ['2021-01-26', '2020-01-26', '2019-01-26']\n", + " }\n", + ")\n", + "periods = 365 * 5\n", + "dates = pd.date_range(end='2023-09-01', periods=periods)\n", + "holidays_df = special_dates(dates)\n", + "holidays_df.head()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "de9f4656-d808-490f-8b1a-621396c3e7ed", + "metadata": {}, + "outputs": [], + "source": [ + "#| hide\n", + "# test shape of holidays_df\n", + "assert len(holidays_df) == periods\n", + "assert holidays_df.sum().sum() == 5" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "python3", + "language": "python", + "name": "python3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/nbs/nixtla_client.ipynb b/nbs/src/nixtla_client.ipynb similarity index 99% rename from nbs/nixtla_client.ipynb rename to nbs/src/nixtla_client.ipynb index 3d4dfc9e..12da0ccf 100644 --- a/nbs/nixtla_client.ipynb +++ b/nbs/src/nixtla_client.ipynb @@ -670,7 +670,16 @@ " )\n", " self._model_params: Dict[Tuple[str, str], Tuple[int, int]] = {}\n", " if 'ai.azure' in base_url:\n", - " self.supported_models = ['azureai', 'timegpt-1-long-horizon']\n", + " from packaging.version import Version\n", + "\n", + " import nixtla\n", + "\n", + " if Version(nixtla.__version__) > Version(\"0.5.2\"):\n", + " raise NotImplementedError(\n", + " \"This version doesn't support Azure endpoints, please install \"\n", + " \"an earlier version with: `pip install 'nixtla<=0.5.2'`\"\n", + " )\n", + " self.supported_models = ['azureai']\n", " else:\n", " self.supported_models = ['timegpt-1', 'timegpt-1-long-horizon']\n", "\n", @@ -694,7 +703,13 @@ " ensure_contiguous_arrays(payload)\n", " content = orjson.dumps(payload, option=orjson.OPT_SERIALIZE_NUMPY)\n", " resp = client.post(url=endpoint, content=content)\n", - " resp_body = orjson.loads(resp.content)\n", + " try:\n", + " resp_body = orjson.loads(resp.content)\n", + " except orjson.JSONDecodeError:\n", + " raise ApiError(\n", + " status_code=resp.status_code,\n", + " body=f'Could not parse JSON: {resp.content}',\n", + " )\n", " if resp.status_code != 200:\n", " raise ApiError(status_code=resp.status_code, body=resp_body)\n", " if 'data' in resp_body:\n", diff --git a/nbs/utils.ipynb b/nbs/src/utils.ipynb similarity index 100% rename from nbs/utils.ipynb rename to nbs/src/utils.ipynb diff --git a/nixtla/_modidx.py b/nixtla/_modidx.py index 91ed8adf..18d5ef19 100644 --- a/nixtla/_modidx.py +++ b/nixtla/_modidx.py @@ -15,98 +15,101 @@ 'nixtla.core.pydantic_utilities': {}, 'nixtla.core.remove_none_from_dict': {}, 'nixtla.core.request_options': {}, - 'nixtla.date_features': { 'nixtla.date_features.CountryHolidays': ( 'date_features.html#countryholidays', + 'nixtla.date_features': { 'nixtla.date_features.CountryHolidays': ( 'src/date_features.html#countryholidays', 'nixtla/date_features.py'), - 'nixtla.date_features.CountryHolidays.__call__': ( 'date_features.html#countryholidays.__call__', + 'nixtla.date_features.CountryHolidays.__call__': ( 'src/date_features.html#countryholidays.__call__', 'nixtla/date_features.py'), - 'nixtla.date_features.CountryHolidays.__init__': ( 'date_features.html#countryholidays.__init__', + 'nixtla.date_features.CountryHolidays.__init__': ( 'src/date_features.html#countryholidays.__init__', 'nixtla/date_features.py'), - 'nixtla.date_features.CountryHolidays.__name__': ( 'date_features.html#countryholidays.__name__', + 'nixtla.date_features.CountryHolidays.__name__': ( 'src/date_features.html#countryholidays.__name__', 'nixtla/date_features.py'), - 'nixtla.date_features.SpecialDates': ('date_features.html#specialdates', 'nixtla/date_features.py'), - 'nixtla.date_features.SpecialDates.__call__': ( 'date_features.html#specialdates.__call__', + 'nixtla.date_features.SpecialDates': ( 'src/date_features.html#specialdates', + 'nixtla/date_features.py'), + 'nixtla.date_features.SpecialDates.__call__': ( 'src/date_features.html#specialdates.__call__', 'nixtla/date_features.py'), - 'nixtla.date_features.SpecialDates.__init__': ( 'date_features.html#specialdates.__init__', + 'nixtla.date_features.SpecialDates.__init__': ( 'src/date_features.html#specialdates.__init__', 'nixtla/date_features.py'), - 'nixtla.date_features.SpecialDates.__name__': ( 'date_features.html#specialdates.__name__', + 'nixtla.date_features.SpecialDates.__name__': ( 'src/date_features.html#specialdates.__name__', 'nixtla/date_features.py'), - 'nixtla.date_features._get_holidays_df': ( 'date_features.html#_get_holidays_df', + 'nixtla.date_features._get_holidays_df': ( 'src/date_features.html#_get_holidays_df', 'nixtla/date_features.py'), - 'nixtla.date_features._transform_dict_holidays': ( 'date_features.html#_transform_dict_holidays', + 'nixtla.date_features._transform_dict_holidays': ( 'src/date_features.html#_transform_dict_holidays', 'nixtla/date_features.py')}, 'nixtla.errors.unprocessable_entity_error': {}, - 'nixtla.nixtla_client': { 'nixtla.nixtla_client.NixtlaClient': ('nixtla_client.html#nixtlaclient', 'nixtla/nixtla_client.py'), - 'nixtla.nixtla_client.NixtlaClient.__init__': ( 'nixtla_client.html#nixtlaclient.__init__', + 'nixtla.nixtla_client': { 'nixtla.nixtla_client.NixtlaClient': ( 'src/nixtla_client.html#nixtlaclient', + 'nixtla/nixtla_client.py'), + 'nixtla.nixtla_client.NixtlaClient.__init__': ( 'src/nixtla_client.html#nixtlaclient.__init__', 'nixtla/nixtla_client.py'), - 'nixtla.nixtla_client.NixtlaClient._distributed_cross_validation': ( 'nixtla_client.html#nixtlaclient._distributed_cross_validation', + 'nixtla.nixtla_client.NixtlaClient._distributed_cross_validation': ( 'src/nixtla_client.html#nixtlaclient._distributed_cross_validation', 'nixtla/nixtla_client.py'), - 'nixtla.nixtla_client.NixtlaClient._distributed_detect_anomalies': ( 'nixtla_client.html#nixtlaclient._distributed_detect_anomalies', + 'nixtla.nixtla_client.NixtlaClient._distributed_detect_anomalies': ( 'src/nixtla_client.html#nixtlaclient._distributed_detect_anomalies', 'nixtla/nixtla_client.py'), - 'nixtla.nixtla_client.NixtlaClient._distributed_forecast': ( 'nixtla_client.html#nixtlaclient._distributed_forecast', + 'nixtla.nixtla_client.NixtlaClient._distributed_forecast': ( 'src/nixtla_client.html#nixtlaclient._distributed_forecast', 'nixtla/nixtla_client.py'), - 'nixtla.nixtla_client.NixtlaClient._get_model_params': ( 'nixtla_client.html#nixtlaclient._get_model_params', + 'nixtla.nixtla_client.NixtlaClient._get_model_params': ( 'src/nixtla_client.html#nixtlaclient._get_model_params', 'nixtla/nixtla_client.py'), - 'nixtla.nixtla_client.NixtlaClient._make_partitioned_requests': ( 'nixtla_client.html#nixtlaclient._make_partitioned_requests', + 'nixtla.nixtla_client.NixtlaClient._make_partitioned_requests': ( 'src/nixtla_client.html#nixtlaclient._make_partitioned_requests', 'nixtla/nixtla_client.py'), - 'nixtla.nixtla_client.NixtlaClient._make_request': ( 'nixtla_client.html#nixtlaclient._make_request', + 'nixtla.nixtla_client.NixtlaClient._make_request': ( 'src/nixtla_client.html#nixtlaclient._make_request', 'nixtla/nixtla_client.py'), - 'nixtla.nixtla_client.NixtlaClient._make_request_with_retries': ( 'nixtla_client.html#nixtlaclient._make_request_with_retries', + 'nixtla.nixtla_client.NixtlaClient._make_request_with_retries': ( 'src/nixtla_client.html#nixtlaclient._make_request_with_retries', 'nixtla/nixtla_client.py'), - 'nixtla.nixtla_client.NixtlaClient._maybe_assign_feature_contributions': ( 'nixtla_client.html#nixtlaclient._maybe_assign_feature_contributions', + 'nixtla.nixtla_client.NixtlaClient._maybe_assign_feature_contributions': ( 'src/nixtla_client.html#nixtlaclient._maybe_assign_feature_contributions', 'nixtla/nixtla_client.py'), - 'nixtla.nixtla_client.NixtlaClient._maybe_assign_weights': ( 'nixtla_client.html#nixtlaclient._maybe_assign_weights', + 'nixtla.nixtla_client.NixtlaClient._maybe_assign_weights': ( 'src/nixtla_client.html#nixtlaclient._maybe_assign_weights', 'nixtla/nixtla_client.py'), - 'nixtla.nixtla_client.NixtlaClient._run_validations': ( 'nixtla_client.html#nixtlaclient._run_validations', + 'nixtla.nixtla_client.NixtlaClient._run_validations': ( 'src/nixtla_client.html#nixtlaclient._run_validations', 'nixtla/nixtla_client.py'), - 'nixtla.nixtla_client.NixtlaClient.cross_validation': ( 'nixtla_client.html#nixtlaclient.cross_validation', + 'nixtla.nixtla_client.NixtlaClient.cross_validation': ( 'src/nixtla_client.html#nixtlaclient.cross_validation', 'nixtla/nixtla_client.py'), - 'nixtla.nixtla_client.NixtlaClient.detect_anomalies': ( 'nixtla_client.html#nixtlaclient.detect_anomalies', + 'nixtla.nixtla_client.NixtlaClient.detect_anomalies': ( 'src/nixtla_client.html#nixtlaclient.detect_anomalies', 'nixtla/nixtla_client.py'), - 'nixtla.nixtla_client.NixtlaClient.forecast': ( 'nixtla_client.html#nixtlaclient.forecast', + 'nixtla.nixtla_client.NixtlaClient.forecast': ( 'src/nixtla_client.html#nixtlaclient.forecast', 'nixtla/nixtla_client.py'), - 'nixtla.nixtla_client.NixtlaClient.plot': ( 'nixtla_client.html#nixtlaclient.plot', + 'nixtla.nixtla_client.NixtlaClient.plot': ( 'src/nixtla_client.html#nixtlaclient.plot', 'nixtla/nixtla_client.py'), - 'nixtla.nixtla_client.NixtlaClient.validate_api_key': ( 'nixtla_client.html#nixtlaclient.validate_api_key', + 'nixtla.nixtla_client.NixtlaClient.validate_api_key': ( 'src/nixtla_client.html#nixtlaclient.validate_api_key', 'nixtla/nixtla_client.py'), - 'nixtla.nixtla_client._array_tails': ('nixtla_client.html#_array_tails', 'nixtla/nixtla_client.py'), - 'nixtla.nixtla_client._cross_validation_wrapper': ( 'nixtla_client.html#_cross_validation_wrapper', + 'nixtla.nixtla_client._array_tails': ( 'src/nixtla_client.html#_array_tails', + 'nixtla/nixtla_client.py'), + 'nixtla.nixtla_client._cross_validation_wrapper': ( 'src/nixtla_client.html#_cross_validation_wrapper', 'nixtla/nixtla_client.py'), - 'nixtla.nixtla_client._detect_anomalies_wrapper': ( 'nixtla_client.html#_detect_anomalies_wrapper', + 'nixtla.nixtla_client._detect_anomalies_wrapper': ( 'src/nixtla_client.html#_detect_anomalies_wrapper', 'nixtla/nixtla_client.py'), - 'nixtla.nixtla_client._distributed_setup': ( 'nixtla_client.html#_distributed_setup', + 'nixtla.nixtla_client._distributed_setup': ( 'src/nixtla_client.html#_distributed_setup', 'nixtla/nixtla_client.py'), - 'nixtla.nixtla_client._forecast_payload_to_in_sample': ( 'nixtla_client.html#_forecast_payload_to_in_sample', + 'nixtla.nixtla_client._forecast_payload_to_in_sample': ( 'src/nixtla_client.html#_forecast_payload_to_in_sample', 'nixtla/nixtla_client.py'), - 'nixtla.nixtla_client._forecast_wrapper': ( 'nixtla_client.html#_forecast_wrapper', + 'nixtla.nixtla_client._forecast_wrapper': ( 'src/nixtla_client.html#_forecast_wrapper', 'nixtla/nixtla_client.py'), - 'nixtla.nixtla_client._get_schema': ('nixtla_client.html#_get_schema', 'nixtla/nixtla_client.py'), - 'nixtla.nixtla_client._maybe_add_date_features': ( 'nixtla_client.html#_maybe_add_date_features', + 'nixtla.nixtla_client._get_schema': ('src/nixtla_client.html#_get_schema', 'nixtla/nixtla_client.py'), + 'nixtla.nixtla_client._maybe_add_date_features': ( 'src/nixtla_client.html#_maybe_add_date_features', 'nixtla/nixtla_client.py'), - 'nixtla.nixtla_client._maybe_add_intervals': ( 'nixtla_client.html#_maybe_add_intervals', + 'nixtla.nixtla_client._maybe_add_intervals': ( 'src/nixtla_client.html#_maybe_add_intervals', 'nixtla/nixtla_client.py'), - 'nixtla.nixtla_client._maybe_convert_level_to_quantiles': ( 'nixtla_client.html#_maybe_convert_level_to_quantiles', + 'nixtla.nixtla_client._maybe_convert_level_to_quantiles': ( 'src/nixtla_client.html#_maybe_convert_level_to_quantiles', 'nixtla/nixtla_client.py'), - 'nixtla.nixtla_client._maybe_drop_id': ( 'nixtla_client.html#_maybe_drop_id', + 'nixtla.nixtla_client._maybe_drop_id': ( 'src/nixtla_client.html#_maybe_drop_id', 'nixtla/nixtla_client.py'), - 'nixtla.nixtla_client._maybe_infer_freq': ( 'nixtla_client.html#_maybe_infer_freq', + 'nixtla.nixtla_client._maybe_infer_freq': ( 'src/nixtla_client.html#_maybe_infer_freq', 'nixtla/nixtla_client.py'), - 'nixtla.nixtla_client._parse_in_sample_output': ( 'nixtla_client.html#_parse_in_sample_output', + 'nixtla.nixtla_client._parse_in_sample_output': ( 'src/nixtla_client.html#_parse_in_sample_output', 'nixtla/nixtla_client.py'), - 'nixtla.nixtla_client._partition_series': ( 'nixtla_client.html#_partition_series', + 'nixtla.nixtla_client._partition_series': ( 'src/nixtla_client.html#_partition_series', 'nixtla/nixtla_client.py'), - 'nixtla.nixtla_client._prepare_level_and_quantiles': ( 'nixtla_client.html#_prepare_level_and_quantiles', + 'nixtla.nixtla_client._prepare_level_and_quantiles': ( 'src/nixtla_client.html#_prepare_level_and_quantiles', 'nixtla/nixtla_client.py'), - 'nixtla.nixtla_client._preprocess': ('nixtla_client.html#_preprocess', 'nixtla/nixtla_client.py'), - 'nixtla.nixtla_client._restrict_input_samples': ( 'nixtla_client.html#_restrict_input_samples', + 'nixtla.nixtla_client._preprocess': ('src/nixtla_client.html#_preprocess', 'nixtla/nixtla_client.py'), + 'nixtla.nixtla_client._restrict_input_samples': ( 'src/nixtla_client.html#_restrict_input_samples', 'nixtla/nixtla_client.py'), - 'nixtla.nixtla_client._retry_strategy': ( 'nixtla_client.html#_retry_strategy', + 'nixtla.nixtla_client._retry_strategy': ( 'src/nixtla_client.html#_retry_strategy', 'nixtla/nixtla_client.py'), - 'nixtla.nixtla_client._standardize_freq': ( 'nixtla_client.html#_standardize_freq', + 'nixtla.nixtla_client._standardize_freq': ( 'src/nixtla_client.html#_standardize_freq', 'nixtla/nixtla_client.py'), - 'nixtla.nixtla_client._tail': ('nixtla_client.html#_tail', 'nixtla/nixtla_client.py'), - 'nixtla.nixtla_client._validate_exog': ( 'nixtla_client.html#_validate_exog', + 'nixtla.nixtla_client._tail': ('src/nixtla_client.html#_tail', 'nixtla/nixtla_client.py'), + 'nixtla.nixtla_client._validate_exog': ( 'src/nixtla_client.html#_validate_exog', 'nixtla/nixtla_client.py'), - 'nixtla.nixtla_client._validate_input_size': ( 'nixtla_client.html#_validate_input_size', + 'nixtla.nixtla_client._validate_input_size': ( 'src/nixtla_client.html#_validate_input_size', 'nixtla/nixtla_client.py')}, 'nixtla.types.anomaly_detection_output': {}, 'nixtla.types.cross_validation_input_finetune_loss': {}, @@ -142,5 +145,5 @@ 'nixtla.types.single_series_insample_forecast_level_item': {}, 'nixtla.types.validation_error': {}, 'nixtla.types.validation_error_loc_item': {}, - 'nixtla.utils': { 'nixtla.utils.colab_badge': ('utils.html#colab_badge', 'nixtla/utils.py'), - 'nixtla.utils.in_colab': ('utils.html#in_colab', 'nixtla/utils.py')}}} + 'nixtla.utils': { 'nixtla.utils.colab_badge': ('src/utils.html#colab_badge', 'nixtla/utils.py'), + 'nixtla.utils.in_colab': ('src/utils.html#in_colab', 'nixtla/utils.py')}}} diff --git a/nixtla/date_features.py b/nixtla/date_features.py index d3e4025c..7e8f7ef7 100644 --- a/nixtla/date_features.py +++ b/nixtla/date_features.py @@ -1,14 +1,14 @@ -# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/date_features.ipynb. +# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/src/date_features.ipynb. # %% auto 0 __all__ = ['CountryHolidays', 'SpecialDates'] -# %% ../nbs/date_features.ipynb 4 +# %% ../nbs/src/date_features.ipynb 4 from typing import Dict, List import pandas as pd -# %% ../nbs/date_features.ipynb 6 +# %% ../nbs/src/date_features.ipynb 6 def _transform_dict_holidays(dict_holidays_dates): dict_holidays = {} for key, value in dict_holidays_dates.items(): @@ -17,7 +17,7 @@ def _transform_dict_holidays(dict_holidays_dates): dict_holidays[value].append(key) return dict_holidays -# %% ../nbs/date_features.ipynb 7 +# %% ../nbs/src/date_features.ipynb 7 def _get_holidays_df(dates, categories, holiday_extractor, supported_categories): years = dates.year.unique().tolist() total_holidays = dict() @@ -29,7 +29,7 @@ def _get_holidays_df(dates, categories, holiday_extractor, supported_categories) total_holidays[f"{cat}_{key}"] = [int(ds.date() in val) for ds in dates] return pd.DataFrame(total_holidays, index=dates) -# %% ../nbs/date_features.ipynb 8 +# %% ../nbs/src/date_features.ipynb 8 class CountryHolidays: """Given a list of countries, returns a dataframe with holidays for each country.""" @@ -52,7 +52,7 @@ def __call__(self, dates: pd.DatetimeIndex): def __name__(self): return "CountryHolidays" -# %% ../nbs/date_features.ipynb 12 +# %% ../nbs/src/date_features.ipynb 12 class SpecialDates: """Given a dictionary of categories and dates, returns a dataframe with the special dates.""" diff --git a/nixtla/nixtla_client.py b/nixtla/nixtla_client.py index bf753fec..14317ff7 100644 --- a/nixtla/nixtla_client.py +++ b/nixtla/nixtla_client.py @@ -1,9 +1,9 @@ -# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/nixtla_client.ipynb. +# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/src/nixtla_client.ipynb. # %% auto 0 __all__ = ['NixtlaClient'] -# %% ../nbs/nixtla_client.ipynb 3 +# %% ../nbs/src/nixtla_client.ipynb 3 import logging import math import os @@ -78,7 +78,7 @@ from .core.api_error import ApiError -# %% ../nbs/nixtla_client.ipynb 4 +# %% ../nbs/src/nixtla_client.ipynb 4 AnyDFType = TypeVar( "AnyDFType", "DaskDataFrame", @@ -97,7 +97,7 @@ logging.getLogger("httpx").setLevel(logging.ERROR) logger = logging.getLogger(__name__) -# %% ../nbs/nixtla_client.ipynb 7 +# %% ../nbs/src/nixtla_client.ipynb 7 _Loss = Literal["default", "mae", "mse", "rmse", "mape", "smape"] _Model = Literal["azureai", "timegpt-1", "timegpt-1-long-horizon"] @@ -540,7 +540,7 @@ def _restrict_input_samples(level, input_size, model_horizon, h) -> int: new_input_size = input_size return new_input_size -# %% ../nbs/nixtla_client.ipynb 8 +# %% ../nbs/src/nixtla_client.ipynb 8 class NixtlaClient: def __init__( @@ -601,7 +601,16 @@ def __init__( ) self._model_params: Dict[Tuple[str, str], Tuple[int, int]] = {} if "ai.azure" in base_url: - self.supported_models = ["azureai", "timegpt-1-long-horizon"] + from packaging.version import Version + + import nixtla + + if Version(nixtla.__version__) > Version("0.5.2"): + raise NotImplementedError( + "This version doesn't support Azure endpoints, please install " + "an earlier version with: `pip install 'nixtla<=0.5.2'`" + ) + self.supported_models = ["azureai"] else: self.supported_models = ["timegpt-1", "timegpt-1-long-horizon"] @@ -629,7 +638,13 @@ def ensure_contiguous_arrays(d: Dict[str, Any]) -> None: ensure_contiguous_arrays(payload) content = orjson.dumps(payload, option=orjson.OPT_SERIALIZE_NUMPY) resp = client.post(url=endpoint, content=content) - resp_body = orjson.loads(resp.content) + try: + resp_body = orjson.loads(resp.content) + except orjson.JSONDecodeError: + raise ApiError( + status_code=resp.status_code, + body=f"Could not parse JSON: {resp.content}", + ) if resp.status_code != 200: raise ApiError(status_code=resp.status_code, body=resp_body) if "data" in resp_body: @@ -1574,7 +1589,7 @@ def plot( ax=ax, ) -# %% ../nbs/nixtla_client.ipynb 52 +# %% ../nbs/src/nixtla_client.ipynb 52 def _forecast_wrapper( df: pd.DataFrame, client: NixtlaClient, diff --git a/nixtla/utils.py b/nixtla/utils.py index 5ea29421..eb283302 100644 --- a/nixtla/utils.py +++ b/nixtla/utils.py @@ -1,9 +1,9 @@ -# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/utils.ipynb. +# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/src/utils.ipynb. # %% auto 0 __all__ = [] -# %% ../nbs/utils.ipynb 3 +# %% ../nbs/src/utils.ipynb 3 def colab_badge(path: str): from IPython.display import Markdown, display @@ -13,9 +13,9 @@ def colab_badge(path: str): badge_md = f"[![]({badge_svg})]({nb_url})" display(Markdown(badge_md)) -# %% ../nbs/utils.ipynb 4 +# %% ../nbs/src/utils.ipynb 4 import sys -# %% ../nbs/utils.ipynb 5 +# %% ../nbs/src/utils.ipynb 5 def in_colab(): return "google.colab" in sys.modules diff --git a/setup.py b/setup.py index 688f985b..5f1c8168 100644 --- a/setup.py +++ b/setup.py @@ -44,6 +44,7 @@ "fastcore", "httpx", "orjson", + "packaging", "pandas", "pydantic", "tenacity",