Skip to content

Commit

Permalink
Pass params to forecast cli (#671)
Browse files Browse the repository at this point in the history
* Pass params to forecast cli

* Upd CHANGELOG
  • Loading branch information
julia-shenshina authored May 11, 2022
1 parent d0ed655 commit 4c04584
Show file tree
Hide file tree
Showing 6 changed files with 69 additions and 13 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
-

### Fixed
-
- Add missing `forecast_params` to forecast CLI method ([#671](https://github.com/tinkoff-ai/etna/pull/671))
-
-
-
Expand Down
23 changes: 16 additions & 7 deletions docs/source/commands.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,18 @@ Basic ``forecast`` usage:
.. code-block:: console
Usage: etna forecast [OPTIONS] CONFIG_PATH TARGET_PATH FREQ OUTPUT_PATH [EXOG_PATH]
[RAW_OUTPUT]
[FORECAST_CONFIG_PATH] [RAW_OUTPUT]
Command to make forecast with etna without coding.
Arguments:
CONFIG_PATH path to yaml config with desired pipeline [required]
TARGET_PATH path to csv with data to forecast [required]
FREQ frequency of timestamp in files in pandas format [required]
OUTPUT_PATH where to save forecast [required]
[EXOG_PATH] path to csv with exog data
[RAW_OUTPUT] by default we return only forecast without features [default: False]
CONFIG_PATH path to yaml config with desired pipeline [required]
TARGET_PATH path to csv with data to forecast [required]
FREQ frequency of timestamp in files in pandas format [required]
OUTPUT_PATH where to save forecast [required]
[EXOG_PATH] path to csv with exog data
[FORECAST_CONFIG_PATH] path to yaml config with forecast params
[RAW_OUTPUT] by default we return only forecast without features [default: False]
**How to create config?**

Expand All @@ -34,6 +35,14 @@ Example of pipeline's config:
in_column: target
- _target_: etna.transforms.SegmentEncoderTransform
Example of forecast params config:

.. code-block:: yaml
prediction_interval: true
quantiles: [0.025, 0.975]
n_folds: 3
**How to prepare data?**

Example of dataset with data to forecast:
Expand Down
4 changes: 2 additions & 2 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@
SOURCE_PATH = Path(os.path.dirname(__file__)) # noqa # docs source
PROJECT_PATH = SOURCE_PATH.joinpath("../..") # noqa # project root

COMMIT_SHORT_SHA = os.environ["CI_COMMIT_SHORT_SHA"]
WORKFLOW_NAME = os.environ["WORKFLOW_NAME"]
COMMIT_SHORT_SHA = os.environ.get("CI_COMMIT_SHORT_SHA", None)
WORKFLOW_NAME = os.environ.get("WORKFLOW_NAME", None)

sys.path.insert(0, str(PROJECT_PATH)) # noqa

Expand Down
16 changes: 13 additions & 3 deletions etna/commands/forecast_command.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from pathlib import Path
from typing import Any
from typing import Dict
from typing import Optional

import hydra_slayer
Expand All @@ -16,6 +18,7 @@ def forecast(
freq: str = typer.Argument(..., help="frequency of timestamp in files in pandas format"),
output_path: Path = typer.Argument(..., help="where to save forecast"),
exog_path: Optional[Path] = typer.Argument(None, help="path to csv with exog data"),
forecast_config_path: Optional[Path] = typer.Argument(None, help="path to yaml config with forecast params"),
raw_output: bool = typer.Argument(False, help="by default we return only forecast without features"),
):
"""Command to make forecast with etna without coding.
Expand Down Expand Up @@ -51,6 +54,11 @@ def forecast(
============= =========== =============== ===============
"""
pipeline_configs = OmegaConf.to_object(OmegaConf.load(config_path))
if forecast_config_path:
forecast_params_config = OmegaConf.to_object(OmegaConf.load(forecast_config_path))
else:
forecast_params_config = {}
forecast_params: Dict[str, Any] = hydra_slayer.get_from_params(**forecast_params_config)

df_timeseries = pd.read_csv(target_path, parse_dates=["timestamp"])

Expand All @@ -65,12 +73,14 @@ def forecast(

pipeline: Pipeline = hydra_slayer.get_from_params(**pipeline_configs)
pipeline.fit(tsdataset)
forecast = pipeline.forecast()
forecast = pipeline.forecast(**forecast_params)

flatten = forecast.to_pandas(flatten=True)
if raw_output:
(forecast.to_pandas(True).to_csv(output_path, index=False))
(flatten.to_csv(output_path, index=False))
else:
(forecast.to_pandas(True)[["timestamp", "segment", "target"]].to_csv(output_path, index=False))
quantile_columns = [column for column in flatten.columns if column.startswith("target_0.")]
(flatten[["timestamp", "segment", "target"] + quantile_columns].to_csv(output_path, index=False))


if __name__ == "__main__":
Expand Down
15 changes: 15 additions & 0 deletions tests/test_commands/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,3 +76,18 @@ def base_timeseries_exog_path():
tmp.flush()
yield Path(tmp.name)
tmp.close()


@pytest.fixture
def base_forecast_omegaconf_path():
    """Yield the path to a temporary YAML file with forecast params.

    The file holds ``prediction_interval``, ``quantiles`` and ``n_folds``
    entries. It stays open while the test runs and is closed on teardown.
    """
    # The context manager closes the handle even if writing the config fails,
    # so it cannot leak before the test body is reached.
    with NamedTemporaryFile("w") as tmp:
        tmp.write(
            """
prediction_interval: true
quantiles: [0.025, 0.975]
n_folds: 3
"""
        )
        # Flush so the content is on disk before anything opens the path.
        tmp.flush()
        yield Path(tmp.name)
22 changes: 22 additions & 0 deletions tests/test_commands/test_forecast.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,25 @@ def test_dummy_run(base_pipeline_yaml_path, base_timeseries_path):
run(["etna", "forecast", str(base_pipeline_yaml_path), str(base_timeseries_path), "D", str(tmp_output_path)])
df_output = pd.read_csv(tmp_output_path)
assert len(df_output) == 2 * 4


def test_run_with_predictive_intervals(
    base_pipeline_yaml_path, base_timeseries_path, base_timeseries_exog_path, base_forecast_omegaconf_path
):
    """Check that ``etna forecast`` writes quantile columns when given a forecast params config.

    The command is run with the pipeline config, target data, frequency ``D``,
    output path, exog data and forecast params config, in that order. The
    resulting CSV must contain ``target_0.025`` and ``target_0.975`` columns.
    """
    # Keep the output file open only for as long as it is needed, so the
    # handle is released deterministically rather than at garbage collection.
    with NamedTemporaryFile("w") as tmp_output:
        tmp_output_path = Path(tmp_output.name)
        run(
            [
                "etna",
                "forecast",
                str(base_pipeline_yaml_path),
                str(base_timeseries_path),
                "D",
                str(tmp_output_path),
                str(base_timeseries_exog_path),
                str(base_forecast_omegaconf_path),
            ]
        )
        # Read the result before the context manager closes the temp file.
        df_output = pd.read_csv(tmp_output_path)

    for q in [0.025, 0.975]:
        assert f"target_{q}" in df_output.columns

1 comment on commit 4c04584

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please sign in to comment.