diff --git a/.github/workflows/build-docs.yaml b/.github/workflows/build-docs.yaml
index 38c820e8..5b61bf30 100644
--- a/.github/workflows/build-docs.yaml
+++ b/.github/workflows/build-docs.yaml
@@ -22,7 +22,7 @@ jobs:
repository: Nixtla/docs
ref: scripts
path: docs-scripts
- - uses: actions/setup-python@39cd14951b08e74b54015e9e001cdefcf80e669f # 5.1.1
+ - uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # 5.2.0
with:
cache: "pip"
python-version: "3.10"
diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml
index d11b5f7e..645f5879 100644
--- a/.github/workflows/ci.yaml
+++ b/.github/workflows/ci.yaml
@@ -29,7 +29,7 @@ jobs:
uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7
- name: Set up python
- uses: actions/setup-python@39cd14951b08e74b54015e9e001cdefcf80e669f # 5.1.1
+ uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # 5.2.0
with:
python-version: ${{ matrix.python-version }}
@@ -50,7 +50,7 @@ jobs:
uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7
- name: Set up python
- uses: actions/setup-python@39cd14951b08e74b54015e9e001cdefcf80e669f # 5.1.1
+ uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # 5.2.0
with:
python-version: ${{ matrix.python-version }}
@@ -72,7 +72,7 @@ jobs:
uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7
- name: Set up python
- uses: actions/setup-python@39cd14951b08e74b54015e9e001cdefcf80e669f # 5.1.1
+ uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # 5.2.0
with:
python-version: ${{ matrix.python-version }}
diff --git a/.github/workflows/deploy-readme.yaml b/.github/workflows/deploy-readme.yaml
index 8a1e3757..cfd0b749 100644
--- a/.github/workflows/deploy-readme.yaml
+++ b/.github/workflows/deploy-readme.yaml
@@ -25,7 +25,7 @@ jobs:
persist-credentials: false
- name: Set up python
- uses: actions/setup-python@39cd14951b08e74b54015e9e001cdefcf80e669f # 5.1.1
+ uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # 5.2.0
with:
python-version: "3.10"
diff --git a/.github/workflows/lint.yaml b/.github/workflows/lint.yaml
index 0e95cccf..a6d30a74 100644
--- a/.github/workflows/lint.yaml
+++ b/.github/workflows/lint.yaml
@@ -14,7 +14,7 @@ jobs:
uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7
- name: Set up python
- uses: actions/setup-python@39cd14951b08e74b54015e9e001cdefcf80e669f # 5.1.1
+ uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # 5.2.0
with:
python-version: '3.10'
diff --git a/.github/workflows/models-performance.yaml b/.github/workflows/models-performance.yaml
index 4a131ef2..9a41a7b5 100644
--- a/.github/workflows/models-performance.yaml
+++ b/.github/workflows/models-performance.yaml
@@ -26,7 +26,7 @@ jobs:
uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7
- name: Set up python
- uses: actions/setup-python@39cd14951b08e74b54015e9e001cdefcf80e669f # 5.1.1
+ uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # 5.2.0
with:
python-version: "3.10"
diff --git a/.github/workflows/python-publish.yml b/.github/workflows/python-publish.yml
index 1fa437fc..98e2330c 100644
--- a/.github/workflows/python-publish.yml
+++ b/.github/workflows/python-publish.yml
@@ -13,7 +13,7 @@ jobs:
- uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7
- name: Set up Python
- uses: actions/setup-python@39cd14951b08e74b54015e9e001cdefcf80e669f # 5.1.1
+ uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # 5.2.0
with:
python-version: '3.10'
@@ -24,16 +24,5 @@ jobs:
run: python -m build
- name: Publish nixtla package
- uses: pypa/gh-action-pypi-publish@ec4db0b4ddc65acdf4bff5fa45ac92d78b56bdf0 # v1.9.0
+ uses: pypa/gh-action-pypi-publish@8a08d616893759ef8e1aa1f2785787c0b97e20d6 # v1.10.0
- - name: Build nixtlats package
- run: >
- rm -rf build dist &&
- mv nixtla nixtlats &&
- sed -i 's/name="nixtla"/name="nixtlats"/g' setup.py &&
- find nixtlats -type f -name '*.py' -exec sed -i 's/from nixtla/from nixtlats/g' {} + &&
- echo -e 'import warnings\nwarnings.warn("This package is deprecated, please install nixtla instead.", category=FutureWarning)' >> nixtlats/__init__.py &&
- python -m build
-
- - name: Publish nixtlats package
- uses: pypa/gh-action-pypi-publish@ec4db0b4ddc65acdf4bff5fa45ac92d78b56bdf0 # v1.9.0
diff --git a/.github/workflows/test-publish.yml b/.github/workflows/test-publish.yml
index 8339f9d2..7c6240a1 100644
--- a/.github/workflows/test-publish.yml
+++ b/.github/workflows/test-publish.yml
@@ -12,7 +12,7 @@ jobs:
- uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7
- name: Set up Python
- uses: actions/setup-python@39cd14951b08e74b54015e9e001cdefcf80e669f # 5.1.1
+ uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # 5.2.0
with:
python-version: '3.10'
@@ -23,20 +23,7 @@ jobs:
run: python -m build
- name: Publish nixtla package
- uses: pypa/gh-action-pypi-publish@ec4db0b4ddc65acdf4bff5fa45ac92d78b56bdf0 # v1.9.0
+ uses: pypa/gh-action-pypi-publish@8a08d616893759ef8e1aa1f2785787c0b97e20d6 # v1.10.0
with:
repository-url: https://test.pypi.org/legacy/
- - name: Build nixtlats package
- run: >
- rm -rf build dist &&
- mv nixtla nixtlats &&
- sed -i 's/name="nixtla"/name="nixtlats"/g' setup.py &&
- find nixtlats -type f -name '*.py' -exec sed -i 's/from nixtla/from nixtlats/g' {} + &&
- echo -e 'import warnings\nwarnings.warn("This package is deprecated, please install nixtla instead.", category=FutureWarning)' >> nixtlats/__init__.py &&
- python -m build
-
- - name: Publish nixtlats package
- uses: pypa/gh-action-pypi-publish@ec4db0b4ddc65acdf4bff5fa45ac92d78b56bdf0 # v1.9.0
- with:
- repository-url: https://test.pypi.org/legacy/
diff --git a/.gitignore b/.gitignore
index 6c9486e5..6943307a 100644
--- a/.gitignore
+++ b/.gitignore
@@ -28,3 +28,4 @@ nbs/.last_checked
longhorizon
data
*.rda
+nbs/_extensions
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 77f8571d..83a3ab3c 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1,9 +1,16 @@
repos:
+ - repo: local
+ hooks:
+ - id: nbdev_clean
+ name: Clean notebooks
+ entry: sh -c 'nbdev_clean && nbdev_clean --fname nbs/src --clear_all'
+ language: system
+
- repo: https://github.com/fastai/nbdev
rev: 2.2.10
hooks:
- - id: nbdev_clean
- id: nbdev_export
+
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.2.1
hooks:
diff --git a/nbs/date_features.ipynb b/nbs/date_features.ipynb
deleted file mode 100644
index 2b07251d..00000000
--- a/nbs/date_features.ipynb
+++ /dev/null
@@ -1,642 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "6d7d1bce-e447-4702-baf5-2bfb8d112635",
- "metadata": {},
- "outputs": [],
- "source": [
- "#| default_exp date_features"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "713b6f07-67b5-4fcc-bbd8-5aa3107a4463",
- "metadata": {},
- "outputs": [],
- "source": [
- "#| hide \n",
- "%load_ext autoreload\n",
- "%autoreload 2"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "bf0edfb1-a7b3-4d6e-acbd-f41758af1779",
- "metadata": {},
- "source": [
- "# Date Features "
- ]
- },
- {
- "cell_type": "markdown",
- "id": "5980c8e1-0416-4e2b-a335-08c7a809916e",
- "metadata": {},
- "source": [
- "Useful classes to generate date features and add them to `TimeGPT`."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "b22a584a-1591-46bf-860e-20f0a519e7c3",
- "metadata": {},
- "outputs": [],
- "source": [
- "#| export\n",
- "from typing import Dict, List\n",
- "\n",
- "import pandas as pd"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "d5f2f081-c3ff-4ba8-a7d0-b1dc992f5e09",
- "metadata": {},
- "outputs": [],
- "source": [
- "#| hide\n",
- "from nbdev.showdoc import show_doc"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "02b6979b-b97a-41eb-9bf1-8e068da1ebcb",
- "metadata": {},
- "outputs": [],
- "source": [
- "#| exporti\n",
- "def _transform_dict_holidays(dict_holidays_dates):\n",
- " dict_holidays = {}\n",
- " for key, value in dict_holidays_dates.items():\n",
- " if value not in dict_holidays:\n",
- " dict_holidays[value] = []\n",
- " dict_holidays[value].append(key)\n",
- " return dict_holidays"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "d442cd19-c0e3-4aac-a720-07203fef41c6",
- "metadata": {},
- "outputs": [],
- "source": [
- "#| exporti\n",
- "def _get_holidays_df(dates, categories, holiday_extractor, supported_categories):\n",
- " years = dates.year.unique().tolist()\n",
- " total_holidays = dict()\n",
- " for cat in categories:\n",
- " if cat not in supported_categories:\n",
- " raise Exception(f'Holidays for {cat} not available, please remove it.')\n",
- " dict_holidays = _transform_dict_holidays(holiday_extractor(cat, years=years))\n",
- " for key, val in dict_holidays.items():\n",
- " total_holidays[f'{cat}_{key}'] = [int(ds.date() in val) for ds in dates]\n",
- " return pd.DataFrame(total_holidays, index=dates)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "8bedda61-b017-430d-85ce-8a369b8e1bb4",
- "metadata": {},
- "outputs": [],
- "source": [
- "#| export\n",
- "class CountryHolidays:\n",
- " \"\"\"Given a list of countries, returns a dataframe with holidays for each country.\"\"\"\n",
- " \n",
- " def __init__(self, countries: List[str]):\n",
- " self.countries = countries\n",
- " \n",
- " def __call__(self, dates: pd.DatetimeIndex):\n",
- " try:\n",
- " from holidays.utils import country_holidays\n",
- " from holidays.utils import list_supported_countries\n",
- " except ModuleNotFoundError:\n",
- " raise Exception(\n",
- " 'You have to install additional libraries to use holidays, '\n",
- " 'please install them using `pip install \"nixtla[date_extras]\"`'\n",
- " )\n",
- " return _get_holidays_df(dates, self.countries, country_holidays, list_supported_countries())\n",
- " \n",
- " def __name__(self):\n",
- " return 'CountryHolidays'"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "bb22ad7a-fc3b-4ce1-aa2e-16f3aaf2c66d",
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/markdown": [
- "---\n",
- "\n",
- "#### CountryHolidays\n",
- "\n",
- "> CountryHolidays (countries:List[str])\n",
- "\n",
- "Given a list of countries, returns a dataframe with holidays for each country."
- ],
- "text/plain": [
- "---\n",
- "\n",
- "#### CountryHolidays\n",
- "\n",
- "> CountryHolidays (countries:List[str])\n",
- "\n",
- "Given a list of countries, returns a dataframe with holidays for each country."
- ]
- },
- "execution_count": null,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "show_doc(CountryHolidays, title_level=4)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "bcc7c4d1-663e-4835-8e95-5ca756cd35ce",
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/html": [
- "
\n",
- "\n",
- "
\n",
- " \n",
- " \n",
- " | \n",
- " US_New Year's Day | \n",
- " US_Martin Luther King Jr. Day | \n",
- " US_Washington's Birthday | \n",
- " US_Memorial Day | \n",
- " US_Independence Day | \n",
- " US_Labor Day | \n",
- " US_Columbus Day | \n",
- " US_Veterans Day | \n",
- " US_Veterans Day (Observed) | \n",
- " US_Thanksgiving | \n",
- " ... | \n",
- " MX_Día de la Independencia [Independence Day] | \n",
- " MX_Día de la Independencia [Independence Day] (Observed) | \n",
- " MX_Día de la Revolución [Revolution Day] (Observed) | \n",
- " MX_Día de la Revolución [Revolution Day] | \n",
- " MX_Transmisión del Poder Ejecutivo Federal [Change of Federal Government] | \n",
- " MX_Transmisión del Poder Ejecutivo Federal [Change of Federal Government] (Observed) | \n",
- " MX_Navidad [Christmas] | \n",
- " MX_Día de la Constitución [Constitution Day] | \n",
- " MX_Año Nuevo [New Year's Day] (Observed) | \n",
- " MX_Día del Trabajo [Labour Day] (Observed) | \n",
- "
\n",
- " \n",
- " \n",
- " \n",
- " 2018-09-03 | \n",
- " 0 | \n",
- " 0 | \n",
- " 0 | \n",
- " 0 | \n",
- " 0 | \n",
- " 1 | \n",
- " 0 | \n",
- " 0 | \n",
- " 0 | \n",
- " 0 | \n",
- " ... | \n",
- " 0 | \n",
- " 0 | \n",
- " 0 | \n",
- " 0 | \n",
- " 0 | \n",
- " 0 | \n",
- " 0 | \n",
- " 0 | \n",
- " 0 | \n",
- " 0 | \n",
- "
\n",
- " \n",
- " 2018-09-04 | \n",
- " 0 | \n",
- " 0 | \n",
- " 0 | \n",
- " 0 | \n",
- " 0 | \n",
- " 0 | \n",
- " 0 | \n",
- " 0 | \n",
- " 0 | \n",
- " 0 | \n",
- " ... | \n",
- " 0 | \n",
- " 0 | \n",
- " 0 | \n",
- " 0 | \n",
- " 0 | \n",
- " 0 | \n",
- " 0 | \n",
- " 0 | \n",
- " 0 | \n",
- " 0 | \n",
- "
\n",
- " \n",
- " 2018-09-05 | \n",
- " 0 | \n",
- " 0 | \n",
- " 0 | \n",
- " 0 | \n",
- " 0 | \n",
- " 0 | \n",
- " 0 | \n",
- " 0 | \n",
- " 0 | \n",
- " 0 | \n",
- " ... | \n",
- " 0 | \n",
- " 0 | \n",
- " 0 | \n",
- " 0 | \n",
- " 0 | \n",
- " 0 | \n",
- " 0 | \n",
- " 0 | \n",
- " 0 | \n",
- " 0 | \n",
- "
\n",
- " \n",
- " 2018-09-06 | \n",
- " 0 | \n",
- " 0 | \n",
- " 0 | \n",
- " 0 | \n",
- " 0 | \n",
- " 0 | \n",
- " 0 | \n",
- " 0 | \n",
- " 0 | \n",
- " 0 | \n",
- " ... | \n",
- " 0 | \n",
- " 0 | \n",
- " 0 | \n",
- " 0 | \n",
- " 0 | \n",
- " 0 | \n",
- " 0 | \n",
- " 0 | \n",
- " 0 | \n",
- " 0 | \n",
- "
\n",
- " \n",
- " 2018-09-07 | \n",
- " 0 | \n",
- " 0 | \n",
- " 0 | \n",
- " 0 | \n",
- " 0 | \n",
- " 0 | \n",
- " 0 | \n",
- " 0 | \n",
- " 0 | \n",
- " 0 | \n",
- " ... | \n",
- " 0 | \n",
- " 0 | \n",
- " 0 | \n",
- " 0 | \n",
- " 0 | \n",
- " 0 | \n",
- " 0 | \n",
- " 0 | \n",
- " 0 | \n",
- " 0 | \n",
- "
\n",
- " \n",
- "
\n",
- "
5 rows × 31 columns
\n",
- "
"
- ],
- "text/plain": [
- " US_New Year's Day US_Martin Luther King Jr. Day \\\n",
- "2018-09-03 0 0 \n",
- "2018-09-04 0 0 \n",
- "2018-09-05 0 0 \n",
- "2018-09-06 0 0 \n",
- "2018-09-07 0 0 \n",
- "\n",
- " US_Washington's Birthday US_Memorial Day US_Independence Day \\\n",
- "2018-09-03 0 0 0 \n",
- "2018-09-04 0 0 0 \n",
- "2018-09-05 0 0 0 \n",
- "2018-09-06 0 0 0 \n",
- "2018-09-07 0 0 0 \n",
- "\n",
- " US_Labor Day US_Columbus Day US_Veterans Day \\\n",
- "2018-09-03 1 0 0 \n",
- "2018-09-04 0 0 0 \n",
- "2018-09-05 0 0 0 \n",
- "2018-09-06 0 0 0 \n",
- "2018-09-07 0 0 0 \n",
- "\n",
- " US_Veterans Day (Observed) US_Thanksgiving ... \\\n",
- "2018-09-03 0 0 ... \n",
- "2018-09-04 0 0 ... \n",
- "2018-09-05 0 0 ... \n",
- "2018-09-06 0 0 ... \n",
- "2018-09-07 0 0 ... \n",
- "\n",
- " MX_Día de la Independencia [Independence Day] \\\n",
- "2018-09-03 0 \n",
- "2018-09-04 0 \n",
- "2018-09-05 0 \n",
- "2018-09-06 0 \n",
- "2018-09-07 0 \n",
- "\n",
- " MX_Día de la Independencia [Independence Day] (Observed) \\\n",
- "2018-09-03 0 \n",
- "2018-09-04 0 \n",
- "2018-09-05 0 \n",
- "2018-09-06 0 \n",
- "2018-09-07 0 \n",
- "\n",
- " MX_Día de la Revolución [Revolution Day] (Observed) \\\n",
- "2018-09-03 0 \n",
- "2018-09-04 0 \n",
- "2018-09-05 0 \n",
- "2018-09-06 0 \n",
- "2018-09-07 0 \n",
- "\n",
- " MX_Día de la Revolución [Revolution Day] \\\n",
- "2018-09-03 0 \n",
- "2018-09-04 0 \n",
- "2018-09-05 0 \n",
- "2018-09-06 0 \n",
- "2018-09-07 0 \n",
- "\n",
- " MX_Transmisión del Poder Ejecutivo Federal [Change of Federal Government] \\\n",
- "2018-09-03 0 \n",
- "2018-09-04 0 \n",
- "2018-09-05 0 \n",
- "2018-09-06 0 \n",
- "2018-09-07 0 \n",
- "\n",
- " MX_Transmisión del Poder Ejecutivo Federal [Change of Federal Government] (Observed) \\\n",
- "2018-09-03 0 \n",
- "2018-09-04 0 \n",
- "2018-09-05 0 \n",
- "2018-09-06 0 \n",
- "2018-09-07 0 \n",
- "\n",
- " MX_Navidad [Christmas] \\\n",
- "2018-09-03 0 \n",
- "2018-09-04 0 \n",
- "2018-09-05 0 \n",
- "2018-09-06 0 \n",
- "2018-09-07 0 \n",
- "\n",
- " MX_Día de la Constitución [Constitution Day] \\\n",
- "2018-09-03 0 \n",
- "2018-09-04 0 \n",
- "2018-09-05 0 \n",
- "2018-09-06 0 \n",
- "2018-09-07 0 \n",
- "\n",
- " MX_Año Nuevo [New Year's Day] (Observed) \\\n",
- "2018-09-03 0 \n",
- "2018-09-04 0 \n",
- "2018-09-05 0 \n",
- "2018-09-06 0 \n",
- "2018-09-07 0 \n",
- "\n",
- " MX_Día del Trabajo [Labour Day] (Observed) \n",
- "2018-09-03 0 \n",
- "2018-09-04 0 \n",
- "2018-09-05 0 \n",
- "2018-09-06 0 \n",
- "2018-09-07 0 \n",
- "\n",
- "[5 rows x 31 columns]"
- ]
- },
- "execution_count": null,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "c_holidays = CountryHolidays(countries=['US', 'MX'])\n",
- "periods = 365 * 5\n",
- "dates = pd.date_range(end='2023-09-01', periods=periods)\n",
- "holidays_df = c_holidays(dates)\n",
- "holidays_df.head()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "c548b898-2f7e-4b11-8426-dd50ed68931c",
- "metadata": {},
- "outputs": [],
- "source": [
- "#| hide\n",
- "# test shape of holidays_df\n",
- "assert len(holidays_df) == periods"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "4aee4fca-68f0-4817-80c0-697b6d04120a",
- "metadata": {},
- "outputs": [],
- "source": [
- "#| export\n",
- "class SpecialDates:\n",
- " \"\"\"Given a dictionary of categories and dates, returns a dataframe with the special dates.\"\"\"\n",
- " \n",
- " def __init__(self, special_dates: Dict[str, List[str]]):\n",
- " self.special_dates = special_dates\n",
- " \n",
- " def __call__(self, dates: pd.DatetimeIndex):\n",
- " total_special_dates = dict()\n",
- " for key, val in self.special_dates.items():\n",
- " date_vals = [ds.date() for ds in pd.to_datetime(val)]\n",
- " total_special_dates[key] = [int(ds.date() in date_vals) for ds in dates]\n",
- " return pd.DataFrame(total_special_dates, index=dates)\n",
- " \n",
- " def __name__(self):\n",
- " return 'SpecialDates'"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "9ef5f615-2d57-42bc-ace2-e3e77d82656d",
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/markdown": [
- "---\n",
- "\n",
- "#### SpecialDates\n",
- "\n",
- "> SpecialDates (special_dates:Dict[str,List[str]])\n",
- "\n",
- "Given a dictionary of categories and dates, returns a dataframe with the special dates."
- ],
- "text/plain": [
- "---\n",
- "\n",
- "#### SpecialDates\n",
- "\n",
- "> SpecialDates (special_dates:Dict[str,List[str]])\n",
- "\n",
- "Given a dictionary of categories and dates, returns a dataframe with the special dates."
- ]
- },
- "execution_count": null,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "show_doc(SpecialDates, title_level=4)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "f967cda6-4958-4e03-93c9-d881ec3c2548",
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/html": [
- "\n",
- "\n",
- "
\n",
- " \n",
- " \n",
- " | \n",
- " Important Dates | \n",
- " Very Important Dates | \n",
- "
\n",
- " \n",
- " \n",
- " \n",
- " 2018-09-03 | \n",
- " 0 | \n",
- " 0 | \n",
- "
\n",
- " \n",
- " 2018-09-04 | \n",
- " 0 | \n",
- " 0 | \n",
- "
\n",
- " \n",
- " 2018-09-05 | \n",
- " 0 | \n",
- " 0 | \n",
- "
\n",
- " \n",
- " 2018-09-06 | \n",
- " 0 | \n",
- " 0 | \n",
- "
\n",
- " \n",
- " 2018-09-07 | \n",
- " 0 | \n",
- " 0 | \n",
- "
\n",
- " \n",
- "
\n",
- "
"
- ],
- "text/plain": [
- " Important Dates Very Important Dates\n",
- "2018-09-03 0 0\n",
- "2018-09-04 0 0\n",
- "2018-09-05 0 0\n",
- "2018-09-06 0 0\n",
- "2018-09-07 0 0"
- ]
- },
- "execution_count": null,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "special_dates = SpecialDates(\n",
- " special_dates={\n",
- " 'Important Dates': ['2021-02-26', '2020-02-26'],\n",
- " 'Very Important Dates': ['2021-01-26', '2020-01-26', '2019-01-26']\n",
- " }\n",
- ")\n",
- "periods = 365 * 5\n",
- "dates = pd.date_range(end='2023-09-01', periods=periods)\n",
- "holidays_df = special_dates(dates)\n",
- "holidays_df.head()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "de9f4656-d808-490f-8b1a-621396c3e7ed",
- "metadata": {},
- "outputs": [],
- "source": [
- "#| hide\n",
- "# test shape of holidays_df\n",
- "assert len(holidays_df) == periods\n",
- "assert holidays_df.sum().sum() == 5"
- ]
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "python3",
- "language": "python",
- "name": "python3"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 5
-}
diff --git a/nbs/docs/getting-started/21_polars_quickstart.ipynb b/nbs/docs/getting-started/21_polars_quickstart.ipynb
index ab4cf865..6b695762 100644
--- a/nbs/docs/getting-started/21_polars_quickstart.ipynb
+++ b/nbs/docs/getting-started/21_polars_quickstart.ipynb
@@ -51,7 +51,7 @@
"id": "6ecd9d32-9178-4768-bffa-d70c93c98311",
"metadata": {},
"source": [
- "# TimeGPT Quickstart\n",
+ "# TimeGPT Quickstart (Polars)\n",
"\n",
"> TimeGPT is a production ready, generative pretrained transformer for time series. It's capable of accurately predicting various domains such as retail, electricity, finance, and IoT with just a few lines of code 🚀."
]
diff --git a/nbs/src/date_features.ipynb b/nbs/src/date_features.ipynb
new file mode 100644
index 00000000..be836ebb
--- /dev/null
+++ b/nbs/src/date_features.ipynb
@@ -0,0 +1,244 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "6d7d1bce-e447-4702-baf5-2bfb8d112635",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "#| default_exp date_features"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "713b6f07-67b5-4fcc-bbd8-5aa3107a4463",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "#| hide \n",
+ "%load_ext autoreload\n",
+ "%autoreload 2"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "bf0edfb1-a7b3-4d6e-acbd-f41758af1779",
+ "metadata": {},
+ "source": [
+ "# Date Features "
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "5980c8e1-0416-4e2b-a335-08c7a809916e",
+ "metadata": {},
+ "source": [
+ "Useful classes to generate date features and add them to `TimeGPT`."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "b22a584a-1591-46bf-860e-20f0a519e7c3",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "#| export\n",
+ "from typing import Dict, List\n",
+ "\n",
+ "import pandas as pd"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "d5f2f081-c3ff-4ba8-a7d0-b1dc992f5e09",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "#| hide\n",
+ "from nbdev.showdoc import show_doc"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "02b6979b-b97a-41eb-9bf1-8e068da1ebcb",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "#| exporti\n",
+ "def _transform_dict_holidays(dict_holidays_dates):\n",
+ " dict_holidays = {}\n",
+ " for key, value in dict_holidays_dates.items():\n",
+ " if value not in dict_holidays:\n",
+ " dict_holidays[value] = []\n",
+ " dict_holidays[value].append(key)\n",
+ " return dict_holidays"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "d442cd19-c0e3-4aac-a720-07203fef41c6",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "#| exporti\n",
+ "def _get_holidays_df(dates, categories, holiday_extractor, supported_categories):\n",
+ " years = dates.year.unique().tolist()\n",
+ " total_holidays = dict()\n",
+ " for cat in categories:\n",
+ " if cat not in supported_categories:\n",
+ " raise Exception(f'Holidays for {cat} not available, please remove it.')\n",
+ " dict_holidays = _transform_dict_holidays(holiday_extractor(cat, years=years))\n",
+ " for key, val in dict_holidays.items():\n",
+ " total_holidays[f'{cat}_{key}'] = [int(ds.date() in val) for ds in dates]\n",
+ " return pd.DataFrame(total_holidays, index=dates)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "8bedda61-b017-430d-85ce-8a369b8e1bb4",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "#| export\n",
+ "class CountryHolidays:\n",
+ " \"\"\"Given a list of countries, returns a dataframe with holidays for each country.\"\"\"\n",
+ " \n",
+ " def __init__(self, countries: List[str]):\n",
+ " self.countries = countries\n",
+ " \n",
+ " def __call__(self, dates: pd.DatetimeIndex):\n",
+ " try:\n",
+ " from holidays.utils import country_holidays\n",
+ " from holidays.utils import list_supported_countries\n",
+ " except ModuleNotFoundError:\n",
+ " raise Exception(\n",
+ " 'You have to install additional libraries to use holidays, '\n",
+ " 'please install them using `pip install \"nixtla[date_extras]\"`'\n",
+ " )\n",
+ " return _get_holidays_df(dates, self.countries, country_holidays, list_supported_countries())\n",
+ " \n",
+ " def __name__(self):\n",
+ " return 'CountryHolidays'"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "bb22ad7a-fc3b-4ce1-aa2e-16f3aaf2c66d",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "show_doc(CountryHolidays, title_level=4)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "bcc7c4d1-663e-4835-8e95-5ca756cd35ce",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "c_holidays = CountryHolidays(countries=['US', 'MX'])\n",
+ "periods = 365 * 5\n",
+ "dates = pd.date_range(end='2023-09-01', periods=periods)\n",
+ "holidays_df = c_holidays(dates)\n",
+ "holidays_df.head()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "c548b898-2f7e-4b11-8426-dd50ed68931c",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "#| hide\n",
+ "# test shape of holidays_df\n",
+ "assert len(holidays_df) == periods"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "4aee4fca-68f0-4817-80c0-697b6d04120a",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "#| export\n",
+ "class SpecialDates:\n",
+ " \"\"\"Given a dictionary of categories and dates, returns a dataframe with the special dates.\"\"\"\n",
+ " \n",
+ " def __init__(self, special_dates: Dict[str, List[str]]):\n",
+ " self.special_dates = special_dates\n",
+ " \n",
+ " def __call__(self, dates: pd.DatetimeIndex):\n",
+ " total_special_dates = dict()\n",
+ " for key, val in self.special_dates.items():\n",
+ " date_vals = [ds.date() for ds in pd.to_datetime(val)]\n",
+ " total_special_dates[key] = [int(ds.date() in date_vals) for ds in dates]\n",
+ " return pd.DataFrame(total_special_dates, index=dates)\n",
+ " \n",
+ " def __name__(self):\n",
+ " return 'SpecialDates'"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "9ef5f615-2d57-42bc-ace2-e3e77d82656d",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "show_doc(SpecialDates, title_level=4)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "f967cda6-4958-4e03-93c9-d881ec3c2548",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "special_dates = SpecialDates(\n",
+ " special_dates={\n",
+ " 'Important Dates': ['2021-02-26', '2020-02-26'],\n",
+ " 'Very Important Dates': ['2021-01-26', '2020-01-26', '2019-01-26']\n",
+ " }\n",
+ ")\n",
+ "periods = 365 * 5\n",
+ "dates = pd.date_range(end='2023-09-01', periods=periods)\n",
+ "holidays_df = special_dates(dates)\n",
+ "holidays_df.head()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "de9f4656-d808-490f-8b1a-621396c3e7ed",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "#| hide\n",
+ "# test shape of holidays_df\n",
+ "assert len(holidays_df) == periods\n",
+ "assert holidays_df.sum().sum() == 5"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "python3",
+ "language": "python",
+ "name": "python3"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/nbs/nixtla_client.ipynb b/nbs/src/nixtla_client.ipynb
similarity index 99%
rename from nbs/nixtla_client.ipynb
rename to nbs/src/nixtla_client.ipynb
index 3d4dfc9e..12da0ccf 100644
--- a/nbs/nixtla_client.ipynb
+++ b/nbs/src/nixtla_client.ipynb
@@ -670,7 +670,16 @@
" )\n",
" self._model_params: Dict[Tuple[str, str], Tuple[int, int]] = {}\n",
" if 'ai.azure' in base_url:\n",
- " self.supported_models = ['azureai', 'timegpt-1-long-horizon']\n",
+ " from packaging.version import Version\n",
+ "\n",
+ " import nixtla\n",
+ "\n",
+ " if Version(nixtla.__version__) > Version(\"0.5.2\"):\n",
+ " raise NotImplementedError(\n",
+ " \"This version doesn't support Azure endpoints, please install \"\n",
+ " \"an earlier version with: `pip install 'nixtla<=0.5.2'`\"\n",
+ " )\n",
+ " self.supported_models = ['azureai']\n",
" else:\n",
" self.supported_models = ['timegpt-1', 'timegpt-1-long-horizon']\n",
"\n",
@@ -694,7 +703,13 @@
" ensure_contiguous_arrays(payload)\n",
" content = orjson.dumps(payload, option=orjson.OPT_SERIALIZE_NUMPY)\n",
" resp = client.post(url=endpoint, content=content)\n",
- " resp_body = orjson.loads(resp.content)\n",
+ " try:\n",
+ " resp_body = orjson.loads(resp.content)\n",
+ " except orjson.JSONDecodeError:\n",
+ " raise ApiError(\n",
+ " status_code=resp.status_code,\n",
+ " body=f'Could not parse JSON: {resp.content}',\n",
+ " )\n",
" if resp.status_code != 200:\n",
" raise ApiError(status_code=resp.status_code, body=resp_body)\n",
" if 'data' in resp_body:\n",
diff --git a/nbs/utils.ipynb b/nbs/src/utils.ipynb
similarity index 100%
rename from nbs/utils.ipynb
rename to nbs/src/utils.ipynb
diff --git a/nixtla/_modidx.py b/nixtla/_modidx.py
index 91ed8adf..18d5ef19 100644
--- a/nixtla/_modidx.py
+++ b/nixtla/_modidx.py
@@ -15,98 +15,101 @@
'nixtla.core.pydantic_utilities': {},
'nixtla.core.remove_none_from_dict': {},
'nixtla.core.request_options': {},
- 'nixtla.date_features': { 'nixtla.date_features.CountryHolidays': ( 'date_features.html#countryholidays',
+ 'nixtla.date_features': { 'nixtla.date_features.CountryHolidays': ( 'src/date_features.html#countryholidays',
'nixtla/date_features.py'),
- 'nixtla.date_features.CountryHolidays.__call__': ( 'date_features.html#countryholidays.__call__',
+ 'nixtla.date_features.CountryHolidays.__call__': ( 'src/date_features.html#countryholidays.__call__',
'nixtla/date_features.py'),
- 'nixtla.date_features.CountryHolidays.__init__': ( 'date_features.html#countryholidays.__init__',
+ 'nixtla.date_features.CountryHolidays.__init__': ( 'src/date_features.html#countryholidays.__init__',
'nixtla/date_features.py'),
- 'nixtla.date_features.CountryHolidays.__name__': ( 'date_features.html#countryholidays.__name__',
+ 'nixtla.date_features.CountryHolidays.__name__': ( 'src/date_features.html#countryholidays.__name__',
'nixtla/date_features.py'),
- 'nixtla.date_features.SpecialDates': ('date_features.html#specialdates', 'nixtla/date_features.py'),
- 'nixtla.date_features.SpecialDates.__call__': ( 'date_features.html#specialdates.__call__',
+ 'nixtla.date_features.SpecialDates': ( 'src/date_features.html#specialdates',
+ 'nixtla/date_features.py'),
+ 'nixtla.date_features.SpecialDates.__call__': ( 'src/date_features.html#specialdates.__call__',
'nixtla/date_features.py'),
- 'nixtla.date_features.SpecialDates.__init__': ( 'date_features.html#specialdates.__init__',
+ 'nixtla.date_features.SpecialDates.__init__': ( 'src/date_features.html#specialdates.__init__',
'nixtla/date_features.py'),
- 'nixtla.date_features.SpecialDates.__name__': ( 'date_features.html#specialdates.__name__',
+ 'nixtla.date_features.SpecialDates.__name__': ( 'src/date_features.html#specialdates.__name__',
'nixtla/date_features.py'),
- 'nixtla.date_features._get_holidays_df': ( 'date_features.html#_get_holidays_df',
+ 'nixtla.date_features._get_holidays_df': ( 'src/date_features.html#_get_holidays_df',
'nixtla/date_features.py'),
- 'nixtla.date_features._transform_dict_holidays': ( 'date_features.html#_transform_dict_holidays',
+ 'nixtla.date_features._transform_dict_holidays': ( 'src/date_features.html#_transform_dict_holidays',
'nixtla/date_features.py')},
'nixtla.errors.unprocessable_entity_error': {},
- 'nixtla.nixtla_client': { 'nixtla.nixtla_client.NixtlaClient': ('nixtla_client.html#nixtlaclient', 'nixtla/nixtla_client.py'),
- 'nixtla.nixtla_client.NixtlaClient.__init__': ( 'nixtla_client.html#nixtlaclient.__init__',
+ 'nixtla.nixtla_client': { 'nixtla.nixtla_client.NixtlaClient': ( 'src/nixtla_client.html#nixtlaclient',
+ 'nixtla/nixtla_client.py'),
+ 'nixtla.nixtla_client.NixtlaClient.__init__': ( 'src/nixtla_client.html#nixtlaclient.__init__',
'nixtla/nixtla_client.py'),
- 'nixtla.nixtla_client.NixtlaClient._distributed_cross_validation': ( 'nixtla_client.html#nixtlaclient._distributed_cross_validation',
+ 'nixtla.nixtla_client.NixtlaClient._distributed_cross_validation': ( 'src/nixtla_client.html#nixtlaclient._distributed_cross_validation',
'nixtla/nixtla_client.py'),
- 'nixtla.nixtla_client.NixtlaClient._distributed_detect_anomalies': ( 'nixtla_client.html#nixtlaclient._distributed_detect_anomalies',
+ 'nixtla.nixtla_client.NixtlaClient._distributed_detect_anomalies': ( 'src/nixtla_client.html#nixtlaclient._distributed_detect_anomalies',
'nixtla/nixtla_client.py'),
- 'nixtla.nixtla_client.NixtlaClient._distributed_forecast': ( 'nixtla_client.html#nixtlaclient._distributed_forecast',
+ 'nixtla.nixtla_client.NixtlaClient._distributed_forecast': ( 'src/nixtla_client.html#nixtlaclient._distributed_forecast',
'nixtla/nixtla_client.py'),
- 'nixtla.nixtla_client.NixtlaClient._get_model_params': ( 'nixtla_client.html#nixtlaclient._get_model_params',
+ 'nixtla.nixtla_client.NixtlaClient._get_model_params': ( 'src/nixtla_client.html#nixtlaclient._get_model_params',
'nixtla/nixtla_client.py'),
- 'nixtla.nixtla_client.NixtlaClient._make_partitioned_requests': ( 'nixtla_client.html#nixtlaclient._make_partitioned_requests',
+ 'nixtla.nixtla_client.NixtlaClient._make_partitioned_requests': ( 'src/nixtla_client.html#nixtlaclient._make_partitioned_requests',
'nixtla/nixtla_client.py'),
- 'nixtla.nixtla_client.NixtlaClient._make_request': ( 'nixtla_client.html#nixtlaclient._make_request',
+ 'nixtla.nixtla_client.NixtlaClient._make_request': ( 'src/nixtla_client.html#nixtlaclient._make_request',
'nixtla/nixtla_client.py'),
- 'nixtla.nixtla_client.NixtlaClient._make_request_with_retries': ( 'nixtla_client.html#nixtlaclient._make_request_with_retries',
+ 'nixtla.nixtla_client.NixtlaClient._make_request_with_retries': ( 'src/nixtla_client.html#nixtlaclient._make_request_with_retries',
'nixtla/nixtla_client.py'),
- 'nixtla.nixtla_client.NixtlaClient._maybe_assign_feature_contributions': ( 'nixtla_client.html#nixtlaclient._maybe_assign_feature_contributions',
+ 'nixtla.nixtla_client.NixtlaClient._maybe_assign_feature_contributions': ( 'src/nixtla_client.html#nixtlaclient._maybe_assign_feature_contributions',
'nixtla/nixtla_client.py'),
- 'nixtla.nixtla_client.NixtlaClient._maybe_assign_weights': ( 'nixtla_client.html#nixtlaclient._maybe_assign_weights',
+ 'nixtla.nixtla_client.NixtlaClient._maybe_assign_weights': ( 'src/nixtla_client.html#nixtlaclient._maybe_assign_weights',
'nixtla/nixtla_client.py'),
- 'nixtla.nixtla_client.NixtlaClient._run_validations': ( 'nixtla_client.html#nixtlaclient._run_validations',
+ 'nixtla.nixtla_client.NixtlaClient._run_validations': ( 'src/nixtla_client.html#nixtlaclient._run_validations',
'nixtla/nixtla_client.py'),
- 'nixtla.nixtla_client.NixtlaClient.cross_validation': ( 'nixtla_client.html#nixtlaclient.cross_validation',
+ 'nixtla.nixtla_client.NixtlaClient.cross_validation': ( 'src/nixtla_client.html#nixtlaclient.cross_validation',
'nixtla/nixtla_client.py'),
- 'nixtla.nixtla_client.NixtlaClient.detect_anomalies': ( 'nixtla_client.html#nixtlaclient.detect_anomalies',
+ 'nixtla.nixtla_client.NixtlaClient.detect_anomalies': ( 'src/nixtla_client.html#nixtlaclient.detect_anomalies',
'nixtla/nixtla_client.py'),
- 'nixtla.nixtla_client.NixtlaClient.forecast': ( 'nixtla_client.html#nixtlaclient.forecast',
+ 'nixtla.nixtla_client.NixtlaClient.forecast': ( 'src/nixtla_client.html#nixtlaclient.forecast',
'nixtla/nixtla_client.py'),
- 'nixtla.nixtla_client.NixtlaClient.plot': ( 'nixtla_client.html#nixtlaclient.plot',
+ 'nixtla.nixtla_client.NixtlaClient.plot': ( 'src/nixtla_client.html#nixtlaclient.plot',
'nixtla/nixtla_client.py'),
- 'nixtla.nixtla_client.NixtlaClient.validate_api_key': ( 'nixtla_client.html#nixtlaclient.validate_api_key',
+ 'nixtla.nixtla_client.NixtlaClient.validate_api_key': ( 'src/nixtla_client.html#nixtlaclient.validate_api_key',
'nixtla/nixtla_client.py'),
- 'nixtla.nixtla_client._array_tails': ('nixtla_client.html#_array_tails', 'nixtla/nixtla_client.py'),
- 'nixtla.nixtla_client._cross_validation_wrapper': ( 'nixtla_client.html#_cross_validation_wrapper',
+ 'nixtla.nixtla_client._array_tails': ( 'src/nixtla_client.html#_array_tails',
+ 'nixtla/nixtla_client.py'),
+ 'nixtla.nixtla_client._cross_validation_wrapper': ( 'src/nixtla_client.html#_cross_validation_wrapper',
'nixtla/nixtla_client.py'),
- 'nixtla.nixtla_client._detect_anomalies_wrapper': ( 'nixtla_client.html#_detect_anomalies_wrapper',
+ 'nixtla.nixtla_client._detect_anomalies_wrapper': ( 'src/nixtla_client.html#_detect_anomalies_wrapper',
'nixtla/nixtla_client.py'),
- 'nixtla.nixtla_client._distributed_setup': ( 'nixtla_client.html#_distributed_setup',
+ 'nixtla.nixtla_client._distributed_setup': ( 'src/nixtla_client.html#_distributed_setup',
'nixtla/nixtla_client.py'),
- 'nixtla.nixtla_client._forecast_payload_to_in_sample': ( 'nixtla_client.html#_forecast_payload_to_in_sample',
+ 'nixtla.nixtla_client._forecast_payload_to_in_sample': ( 'src/nixtla_client.html#_forecast_payload_to_in_sample',
'nixtla/nixtla_client.py'),
- 'nixtla.nixtla_client._forecast_wrapper': ( 'nixtla_client.html#_forecast_wrapper',
+ 'nixtla.nixtla_client._forecast_wrapper': ( 'src/nixtla_client.html#_forecast_wrapper',
'nixtla/nixtla_client.py'),
- 'nixtla.nixtla_client._get_schema': ('nixtla_client.html#_get_schema', 'nixtla/nixtla_client.py'),
- 'nixtla.nixtla_client._maybe_add_date_features': ( 'nixtla_client.html#_maybe_add_date_features',
+ 'nixtla.nixtla_client._get_schema': ('src/nixtla_client.html#_get_schema', 'nixtla/nixtla_client.py'),
+ 'nixtla.nixtla_client._maybe_add_date_features': ( 'src/nixtla_client.html#_maybe_add_date_features',
'nixtla/nixtla_client.py'),
- 'nixtla.nixtla_client._maybe_add_intervals': ( 'nixtla_client.html#_maybe_add_intervals',
+ 'nixtla.nixtla_client._maybe_add_intervals': ( 'src/nixtla_client.html#_maybe_add_intervals',
'nixtla/nixtla_client.py'),
- 'nixtla.nixtla_client._maybe_convert_level_to_quantiles': ( 'nixtla_client.html#_maybe_convert_level_to_quantiles',
+ 'nixtla.nixtla_client._maybe_convert_level_to_quantiles': ( 'src/nixtla_client.html#_maybe_convert_level_to_quantiles',
'nixtla/nixtla_client.py'),
- 'nixtla.nixtla_client._maybe_drop_id': ( 'nixtla_client.html#_maybe_drop_id',
+ 'nixtla.nixtla_client._maybe_drop_id': ( 'src/nixtla_client.html#_maybe_drop_id',
'nixtla/nixtla_client.py'),
- 'nixtla.nixtla_client._maybe_infer_freq': ( 'nixtla_client.html#_maybe_infer_freq',
+ 'nixtla.nixtla_client._maybe_infer_freq': ( 'src/nixtla_client.html#_maybe_infer_freq',
'nixtla/nixtla_client.py'),
- 'nixtla.nixtla_client._parse_in_sample_output': ( 'nixtla_client.html#_parse_in_sample_output',
+ 'nixtla.nixtla_client._parse_in_sample_output': ( 'src/nixtla_client.html#_parse_in_sample_output',
'nixtla/nixtla_client.py'),
- 'nixtla.nixtla_client._partition_series': ( 'nixtla_client.html#_partition_series',
+ 'nixtla.nixtla_client._partition_series': ( 'src/nixtla_client.html#_partition_series',
'nixtla/nixtla_client.py'),
- 'nixtla.nixtla_client._prepare_level_and_quantiles': ( 'nixtla_client.html#_prepare_level_and_quantiles',
+ 'nixtla.nixtla_client._prepare_level_and_quantiles': ( 'src/nixtla_client.html#_prepare_level_and_quantiles',
'nixtla/nixtla_client.py'),
- 'nixtla.nixtla_client._preprocess': ('nixtla_client.html#_preprocess', 'nixtla/nixtla_client.py'),
- 'nixtla.nixtla_client._restrict_input_samples': ( 'nixtla_client.html#_restrict_input_samples',
+ 'nixtla.nixtla_client._preprocess': ('src/nixtla_client.html#_preprocess', 'nixtla/nixtla_client.py'),
+ 'nixtla.nixtla_client._restrict_input_samples': ( 'src/nixtla_client.html#_restrict_input_samples',
'nixtla/nixtla_client.py'),
- 'nixtla.nixtla_client._retry_strategy': ( 'nixtla_client.html#_retry_strategy',
+ 'nixtla.nixtla_client._retry_strategy': ( 'src/nixtla_client.html#_retry_strategy',
'nixtla/nixtla_client.py'),
- 'nixtla.nixtla_client._standardize_freq': ( 'nixtla_client.html#_standardize_freq',
+ 'nixtla.nixtla_client._standardize_freq': ( 'src/nixtla_client.html#_standardize_freq',
'nixtla/nixtla_client.py'),
- 'nixtla.nixtla_client._tail': ('nixtla_client.html#_tail', 'nixtla/nixtla_client.py'),
- 'nixtla.nixtla_client._validate_exog': ( 'nixtla_client.html#_validate_exog',
+ 'nixtla.nixtla_client._tail': ('src/nixtla_client.html#_tail', 'nixtla/nixtla_client.py'),
+ 'nixtla.nixtla_client._validate_exog': ( 'src/nixtla_client.html#_validate_exog',
'nixtla/nixtla_client.py'),
- 'nixtla.nixtla_client._validate_input_size': ( 'nixtla_client.html#_validate_input_size',
+ 'nixtla.nixtla_client._validate_input_size': ( 'src/nixtla_client.html#_validate_input_size',
'nixtla/nixtla_client.py')},
'nixtla.types.anomaly_detection_output': {},
'nixtla.types.cross_validation_input_finetune_loss': {},
@@ -142,5 +145,5 @@
'nixtla.types.single_series_insample_forecast_level_item': {},
'nixtla.types.validation_error': {},
'nixtla.types.validation_error_loc_item': {},
- 'nixtla.utils': { 'nixtla.utils.colab_badge': ('utils.html#colab_badge', 'nixtla/utils.py'),
- 'nixtla.utils.in_colab': ('utils.html#in_colab', 'nixtla/utils.py')}}}
+ 'nixtla.utils': { 'nixtla.utils.colab_badge': ('src/utils.html#colab_badge', 'nixtla/utils.py'),
+ 'nixtla.utils.in_colab': ('src/utils.html#in_colab', 'nixtla/utils.py')}}}
diff --git a/nixtla/date_features.py b/nixtla/date_features.py
index d3e4025c..7e8f7ef7 100644
--- a/nixtla/date_features.py
+++ b/nixtla/date_features.py
@@ -1,14 +1,14 @@
-# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/date_features.ipynb.
+# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/src/date_features.ipynb.
# %% auto 0
__all__ = ['CountryHolidays', 'SpecialDates']
-# %% ../nbs/date_features.ipynb 4
+# %% ../nbs/src/date_features.ipynb 4
from typing import Dict, List
import pandas as pd
-# %% ../nbs/date_features.ipynb 6
+# %% ../nbs/src/date_features.ipynb 6
def _transform_dict_holidays(dict_holidays_dates):
dict_holidays = {}
for key, value in dict_holidays_dates.items():
@@ -17,7 +17,7 @@ def _transform_dict_holidays(dict_holidays_dates):
dict_holidays[value].append(key)
return dict_holidays
-# %% ../nbs/date_features.ipynb 7
+# %% ../nbs/src/date_features.ipynb 7
def _get_holidays_df(dates, categories, holiday_extractor, supported_categories):
years = dates.year.unique().tolist()
total_holidays = dict()
@@ -29,7 +29,7 @@ def _get_holidays_df(dates, categories, holiday_extractor, supported_categories)
total_holidays[f"{cat}_{key}"] = [int(ds.date() in val) for ds in dates]
return pd.DataFrame(total_holidays, index=dates)
-# %% ../nbs/date_features.ipynb 8
+# %% ../nbs/src/date_features.ipynb 8
class CountryHolidays:
"""Given a list of countries, returns a dataframe with holidays for each country."""
@@ -52,7 +52,7 @@ def __call__(self, dates: pd.DatetimeIndex):
def __name__(self):
return "CountryHolidays"
-# %% ../nbs/date_features.ipynb 12
+# %% ../nbs/src/date_features.ipynb 12
class SpecialDates:
"""Given a dictionary of categories and dates, returns a dataframe with the special dates."""
diff --git a/nixtla/nixtla_client.py b/nixtla/nixtla_client.py
index bf753fec..14317ff7 100644
--- a/nixtla/nixtla_client.py
+++ b/nixtla/nixtla_client.py
@@ -1,9 +1,9 @@
-# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/nixtla_client.ipynb.
+# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/src/nixtla_client.ipynb.
# %% auto 0
__all__ = ['NixtlaClient']
-# %% ../nbs/nixtla_client.ipynb 3
+# %% ../nbs/src/nixtla_client.ipynb 3
import logging
import math
import os
@@ -78,7 +78,7 @@
from .core.api_error import ApiError
-# %% ../nbs/nixtla_client.ipynb 4
+# %% ../nbs/src/nixtla_client.ipynb 4
AnyDFType = TypeVar(
"AnyDFType",
"DaskDataFrame",
@@ -97,7 +97,7 @@
logging.getLogger("httpx").setLevel(logging.ERROR)
logger = logging.getLogger(__name__)
-# %% ../nbs/nixtla_client.ipynb 7
+# %% ../nbs/src/nixtla_client.ipynb 7
_Loss = Literal["default", "mae", "mse", "rmse", "mape", "smape"]
_Model = Literal["azureai", "timegpt-1", "timegpt-1-long-horizon"]
@@ -540,7 +540,7 @@ def _restrict_input_samples(level, input_size, model_horizon, h) -> int:
new_input_size = input_size
return new_input_size
-# %% ../nbs/nixtla_client.ipynb 8
+# %% ../nbs/src/nixtla_client.ipynb 8
class NixtlaClient:
def __init__(
@@ -601,7 +601,16 @@ def __init__(
)
self._model_params: Dict[Tuple[str, str], Tuple[int, int]] = {}
if "ai.azure" in base_url:
- self.supported_models = ["azureai", "timegpt-1-long-horizon"]
+ from packaging.version import Version
+
+ import nixtla
+
+ if Version(nixtla.__version__) > Version("0.5.2"):
+ raise NotImplementedError(
+ "This version doesn't support Azure endpoints, please install "
+ "an earlier version with: `pip install 'nixtla<=0.5.2'`"
+ )
+ self.supported_models = ["azureai"]
else:
self.supported_models = ["timegpt-1", "timegpt-1-long-horizon"]
@@ -629,7 +638,13 @@ def ensure_contiguous_arrays(d: Dict[str, Any]) -> None:
ensure_contiguous_arrays(payload)
content = orjson.dumps(payload, option=orjson.OPT_SERIALIZE_NUMPY)
resp = client.post(url=endpoint, content=content)
- resp_body = orjson.loads(resp.content)
+ try:
+ resp_body = orjson.loads(resp.content)
+ except orjson.JSONDecodeError:
+ raise ApiError(
+ status_code=resp.status_code,
+ body=f"Could not parse JSON: {resp.content}",
+ )
if resp.status_code != 200:
raise ApiError(status_code=resp.status_code, body=resp_body)
if "data" in resp_body:
@@ -1574,7 +1589,7 @@ def plot(
ax=ax,
)
-# %% ../nbs/nixtla_client.ipynb 52
+# %% ../nbs/src/nixtla_client.ipynb 52
def _forecast_wrapper(
df: pd.DataFrame,
client: NixtlaClient,
diff --git a/nixtla/utils.py b/nixtla/utils.py
index 5ea29421..eb283302 100644
--- a/nixtla/utils.py
+++ b/nixtla/utils.py
@@ -1,9 +1,9 @@
-# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/utils.ipynb.
+# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/src/utils.ipynb.
# %% auto 0
__all__ = []
-# %% ../nbs/utils.ipynb 3
+# %% ../nbs/src/utils.ipynb 3
def colab_badge(path: str):
from IPython.display import Markdown, display
@@ -13,9 +13,9 @@ def colab_badge(path: str):
badge_md = f"[![]({badge_svg})]({nb_url})"
display(Markdown(badge_md))
-# %% ../nbs/utils.ipynb 4
+# %% ../nbs/src/utils.ipynb 4
import sys
-# %% ../nbs/utils.ipynb 5
+# %% ../nbs/src/utils.ipynb 5
def in_colab():
return "google.colab" in sys.modules
diff --git a/setup.py b/setup.py
index 688f985b..5f1c8168 100644
--- a/setup.py
+++ b/setup.py
@@ -44,6 +44,7 @@
"fastcore",
"httpx",
"orjson",
+ "packaging",
"pandas",
"pydantic",
"tenacity",