-
Notifications
You must be signed in to change notification settings - Fork 80
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add examples for wandb sweeps (#317)
- Loading branch information
Showing
6 changed files
with
130 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -123,4 +123,5 @@ config.env | |
.devcontainer | ||
/docs/source/api/ | ||
tmp | ||
wandb | ||
wandb | ||
!examples/wandb |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
# Using WandB with ETNA library | ||
|
||
## Colab example | ||
|
||
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1EBSqqBPaYgLWCRdpC5vMy9RiLBsCEd7I?usp=sharing) | ||
|
||
![](assets/etna-wandb.png) | ||
|
||
[Sweep Dashboard](https://wandb.ai/martins0n/wandb-etna-sweep/sweeps/c7e0r8sq/overview?workspace=user-martins0n) | ||
|
||
## Steps to start | ||
|
||
- Define your pipeline and hyperparameters in `pipeline.yaml`, in example we will optimize number of iterations `iterations` and `learning-rate` | ||
|
||
- Define WandB sweeps config `sweep.yaml` and push it to cloud: | ||
|
||
```bash | ||
WANDB_PROJECT=<project_name> WandB sweep sweep.yaml | ||
``` | ||
|
||
- You may change `dataloader` function and add additional parameters for WandB logger like tags for example in `run.py` | ||
|
||
- Run WandB agent for hyperparameters optimization start: | ||
|
||
```bash | ||
wandb agent <user_name>/<project_name>/<sweep_id> | ||
``` |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
backtest: | ||
n_folds: 3 | ||
n_jobs: 1 | ||
metrics: | ||
- _target_: etna.metrics.MAE | ||
- _target_: etna.metrics.MSE | ||
- _target_: etna.metrics.MAPE | ||
- _target_: etna.metrics.SMAPE | ||
- _target_: etna.metrics.R2 | ||
pipeline: | ||
_target_: etna.pipeline.Pipeline | ||
horizon: 10 | ||
model: | ||
_target_: etna.models.CatBoostModelMultiSegment | ||
iterations: ${iterations} | ||
learning_rate: ${learning-rate} | ||
transforms: | ||
- _target_: etna.transforms.SegmentEncoderTransform | ||
- _target_: etna.transforms.LagTransform | ||
in_column: target | ||
lags: [20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32] | ||
iterations: null | ||
learning-rate: null |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
""" | ||
Example of using WandB with ETNA library. | ||
Current script could be used for sweeps and simple validation runs. | ||
""" | ||
|
||
import argparse | ||
import random | ||
from typing import Any, Dict | ||
|
||
import hydra_slayer | ||
import numpy as np | ||
from etna.datasets import TSDataset, generate_ar_df | ||
from etna.loggers import WandbLogger, tslogger | ||
from etna.pipeline import Pipeline | ||
from omegaconf import OmegaConf | ||
|
||
SEED = 11 | ||
random.seed(SEED) | ||
np.random.seed(SEED) | ||
|
||
# Default config loading | ||
config = OmegaConf.load("pipeline.yaml") | ||
|
||
|
||
# Define arguments for WandB sweep parameters | ||
args = argparse.ArgumentParser() | ||
args.add_argument("--iterations", type=int) | ||
args.add_argument("--learning-rate", type=float) | ||
for key, value in vars(args.parse_args()).items(): | ||
if value: | ||
config[key] = value | ||
|
||
# Config for Pipeline and backtesting pipeline | ||
config = OmegaConf.to_container(config, resolve=True) | ||
pipeline = config["pipeline"] | ||
backtest = config["backtest"] | ||
|
||
|
||
# Define WandbLogger and passing it to global library logger | ||
# It will not log child processes in case of `spawn` (OSX or Windows) | ||
wblogger = WandbLogger(project="test-run", config=pipeline) | ||
tslogger.add(wblogger) | ||
|
||
|
||
def dataloader() -> TSDataset: | ||
df = generate_ar_df(periods=300, start_time="2021-01-02", n_segments=10) | ||
df = TSDataset.to_dataset(df) | ||
ts = TSDataset(df=df, freq="1D") | ||
return ts | ||
|
||
|
||
if __name__ == "__main__": | ||
|
||
ts = dataloader() | ||
|
||
pipeline: Pipeline = hydra_slayer.get_from_params(**pipeline) | ||
|
||
backtest_configs: Dict[str, Any] = hydra_slayer.get_from_params(**backtest) | ||
|
||
metrics, forecast, info = pipeline.backtest(ts=ts, **backtest_configs) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
program: | ||
run.py | ||
method: bayes | ||
parameters: | ||
learning-rate: | ||
min: 0.0001 | ||
max: 0.1 | ||
iterations: | ||
distribution: int_uniform | ||
min: 2 | ||
max: 30 | ||
metric: | ||
name: MAE_median | ||
goal: minimize | ||
command: | ||
- python | ||
- run.py | ||
- ${args} |