-
Notifications
You must be signed in to change notification settings - Fork 80
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge remote-tracking branch 'origin/master' into issue-323
# Conflicts: # CHANGELOG.md
- Loading branch information
Showing
13 changed files
with
448 additions
and
30 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -123,4 +123,5 @@ config.env | |
.devcontainer | ||
/docs/source/api/ | ||
tmp | ||
wandb | ||
wandb | ||
!examples/wandb |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,135 @@ | ||
import math | ||
from typing import Optional | ||
from typing import Sequence | ||
|
||
import numpy as np | ||
import pandas as pd | ||
|
||
from etna.transforms.base import Transform | ||
|
||
|
||
class FourierTransform(Transform):
    """Adds fourier features to the dataset."""

    def __init__(
        self,
        period: float,
        order: Optional[int] = None,
        mods: Optional[Sequence[int]] = None,
        out_column: Optional[str] = None,
    ):
        """Create instance of FourierTransform.

        Parameters
        ----------
        period:
            the period of the seasonality to capture in frequency units of time series, it should be >= 2
        order:
            upper order of Fourier components to include, it should be >= 1 and <= ceil(period/2)
        mods:
            alternative and precise way of defining which harmonics will be used,
            for example `mods=[1, 3, 4]` means that sin of the first order
            and sin and cos of the second order will be used,
            mods should be non-empty, every mod should be >= 1 and < period
        out_column:
            if set, name of added column, the final name will be '{out_column}_{mod}',
            don't forget to add 'regressor_' prefix
            if not set, name will be 'regressor_{repr}', repr will represent class that creates exactly this column

        Raises
        ------
        ValueError:
            if period < 2
        ValueError:
            if both or none of order, mods is set
        ValueError:
            if order is < 1 or > ceil(period/2)
        ValueError:
            if mods is empty
        ValueError:
            if at least one mod is < 1 or >= period

        Notes
        -----
        To understand how transform works we recommend: https://otexts.com/fpp2/useful-predictors.html#fourier-series

        * Parameter `period` is responsible for the seasonality we want to capture.
        * Parameters `order` and `mods` define which harmonics will be used.

        Parameter `order` is a more user-friendly version of `mods`.
        Odd mod `2k - 1` is the sin of order `k`, even mod `2k` is the cos of order `k`,
        and mods that are not strictly less than `period` are dropped.
        For example, `order=2` can be represented as `mods=[1, 2, 3, 4]` if `period` > 4,
        as `mods=[1, 2, 3]` if 3 < `period` <= 4 and as `mods=[1, 2]` if 2 < `period` <= 3.
        """
        if period < 2:
            raise ValueError("Period should be at least 2")
        self.period = period
        self.mods: Sequence[int]

        if order is not None and mods is None:
            if order < 1 or order > math.ceil(period / 2):
                raise ValueError("Order should be within [1, ceil(period/2)] range")
            # Expand the order into mods 1..2*order, keeping only those below the period.
            self.mods = [mod for mod in range(1, 2 * order + 1) if mod < period]
        elif mods is not None and order is None:
            # Without this guard min()/max() below would fail on an empty sequence
            # with an unrelated "min() arg is an empty sequence" message.
            if len(mods) == 0:
                raise ValueError("There should be at least one mod in mods")
            if min(mods) < 1 or max(mods) >= period:
                raise ValueError("Every mod should be within [1, int(period)) range")
            self.mods = mods
        else:
            raise ValueError("There should be exactly one option set: order or mods")

        self.out_column = out_column

    def fit(self, df: pd.DataFrame) -> "FourierTransform":
        """Fit method does nothing and is kept for compatibility.

        Parameters
        ----------
        df:
            dataframe with data.

        Returns
        -------
        result: FourierTransform
        """
        return self

    def _get_column_name(self, mod: int) -> str:
        """Return the name of the output column that holds the harmonic for the given mod."""
        if self.out_column is None:
            # The name is the repr of a transform that would create exactly this single column.
            return f"regressor_{repr(FourierTransform(period=self.period, mods=[mod]))}"
        return f"{self.out_column}_{mod}"

    @staticmethod
    def _construct_answer(df: pd.DataFrame, features: pd.DataFrame) -> pd.DataFrame:
        """Join the same feature columns to every segment of df.

        Parameters
        ----------
        df:
            dataframe whose columns have a level named "segment".
        features:
            dataframe with the columns to add to each segment.

        Returns
        -------
        result: pd.DataFrame
            dataframe with columns sorted and column levels named "segment" and "feature"
        """
        dataframes = []
        for seg in df.columns.get_level_values("segment").unique():
            tmp = df[seg].join(features)
            # Rebuild the column index with the segment name as the first level.
            _idx = tmp.columns.to_frame()
            _idx.insert(0, "segment", seg)
            tmp.columns = pd.MultiIndex.from_frame(_idx)
            dataframes.append(tmp)

        result = pd.concat(dataframes, axis=1).sort_index(axis=1)
        result.columns.names = ["segment", "feature"]
        return result

    def transform(self, df: pd.DataFrame) -> pd.DataFrame:
        """Add harmonics to the dataset.

        Parameters
        ----------
        df:
            dataframe with data to transform.

        Returns
        -------
        result: pd.Dataframe
            transformed dataframe

        Notes
        -----
        The phase is counted from the position of a row inside `df` (the first row is 0),
        not from the timestamps stored in the index.
        """
        features = pd.DataFrame(index=df.index)
        elapsed = np.arange(features.shape[0]) / self.period

        for mod in self.mods:
            order = (mod + 1) // 2
            is_cos = mod % 2 == 0

            # cos(x) is computed as sin(x + pi/2): the boolean flag turns the phase shift on and off.
            features[self._get_column_name(mod)] = np.sin(2 * np.pi * order * elapsed + np.pi / 2 * is_cos)

        return self._construct_answer(df, features)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
# Using WandB with ETNA library | ||
|
||
## Colab example | ||
|
||
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1EBSqqBPaYgLWCRdpC5vMy9RiLBsCEd7I?usp=sharing) | ||
|
||
![](assets/etna-wandb.png) | ||
|
||
[Sweep Dashboard](https://wandb.ai/martins0n/wandb-etna-sweep/sweeps/c7e0r8sq/overview?workspace=user-martins0n) | ||
|
||
## Steps to start | ||
|
||
- Define your pipeline and hyperparameters in `pipeline.yaml`; in this example we optimize the number of iterations (`iterations`) and the learning rate (`learning-rate`)
|
||
- Define the WandB sweep config in `sweep.yaml` and push it to the cloud:
|
||
```bash | ||
WANDB_PROJECT=<project_name> wandb sweep sweep.yaml
``` | ||
|
||
- In `run.py` you may change the `dataloader` function and pass additional parameters (for example, tags) to the WandB logger
|
||
- Run the WandB agent to start the hyperparameter optimization:
|
||
```bash | ||
wandb agent <user_name>/<project_name>/<sweep_id> | ||
``` |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
backtest: | ||
n_folds: 3 | ||
n_jobs: 1 | ||
metrics: | ||
- _target_: etna.metrics.MAE | ||
- _target_: etna.metrics.MSE | ||
- _target_: etna.metrics.MAPE | ||
- _target_: etna.metrics.SMAPE | ||
- _target_: etna.metrics.R2 | ||
pipeline: | ||
_target_: etna.pipeline.Pipeline | ||
horizon: 10 | ||
model: | ||
_target_: etna.models.CatBoostModelMultiSegment | ||
iterations: ${iterations} | ||
learning_rate: ${learning-rate} | ||
transforms: | ||
- _target_: etna.transforms.SegmentEncoderTransform | ||
- _target_: etna.transforms.LagTransform | ||
in_column: target | ||
lags: [20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32] | ||
iterations: null | ||
learning-rate: null |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
""" | ||
Example of using WandB with ETNA library. | ||
Current script could be used for sweeps and simple validation runs. | ||
""" | ||
|
||
import argparse | ||
import random | ||
from typing import Any, Dict | ||
|
||
import hydra_slayer | ||
import numpy as np | ||
from etna.datasets import TSDataset, generate_ar_df | ||
from etna.loggers import WandbLogger, tslogger | ||
from etna.pipeline import Pipeline | ||
from omegaconf import OmegaConf | ||
|
||
# Fix the seeds of both random number generators so runs are reproducible.
SEED = 11
random.seed(SEED)
np.random.seed(SEED)

# Default config loading
config = OmegaConf.load("pipeline.yaml")

# Define arguments for WandB sweep parameters
args = argparse.ArgumentParser()
args.add_argument("--iterations", type=int)
# argparse would derive the destination `learning_rate` from this flag, while the config key
# (and the `${learning-rate}` interpolation) is spelled with a dash, so the value would be
# written to an unused key. An explicit `dest` keeps the parsed key equal to the config key.
args.add_argument("--learning-rate", type=float, dest="learning-rate")
for key, value in vars(args.parse_args()).items():
    # Only arguments that were actually passed override the defaults from the config;
    # comparing with None keeps explicit falsy values such as 0.
    if value is not None:
        config[key] = value

# Config for Pipeline and backtesting pipeline
config = OmegaConf.to_container(config, resolve=True)
pipeline = config["pipeline"]
backtest = config["backtest"]

# Define WandbLogger and passing it to global library logger
# It will not log child processes in case of `spawn` (OSX or Windows)
wblogger = WandbLogger(project="test-run", config=pipeline)
tslogger.add(wblogger)
|
||
|
||
def dataloader() -> TSDataset:
    """Generate the example data and wrap it into a TSDataset.

    The data comes from `generate_ar_df` (300 periods, 10 segments, starting at 2021-01-02)
    and is converted with `TSDataset.to_dataset` before the dataset with "1D" frequency is built.
    """
    generated = generate_ar_df(periods=300, start_time="2021-01-02", n_segments=10)
    return TSDataset(df=TSDataset.to_dataset(generated), freq="1D")
|
||
|
||
if __name__ == "__main__":
    # Data the pipeline is evaluated on.
    dataset = dataloader()

    # Instantiate the objects described by the `pipeline` and `backtest` config sections.
    pipeline_instance: Pipeline = hydra_slayer.get_from_params(**pipeline)
    backtest_kwargs: Dict[str, Any] = hydra_slayer.get_from_params(**backtest)

    # Run the backtest with the instantiated parameters.
    metrics, forecast, info = pipeline_instance.backtest(ts=dataset, **backtest_kwargs)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
program: | ||
run.py | ||
method: bayes | ||
parameters: | ||
learning-rate: | ||
min: 0.0001 | ||
max: 0.1 | ||
iterations: | ||
distribution: int_uniform | ||
min: 2 | ||
max: 30 | ||
metric: | ||
name: MAE_median | ||
goal: minimize | ||
command: | ||
- python | ||
- run.py | ||
- ${args} |
Oops, something went wrong.