Skip to content

Add table option to ConsoleLogger #544

Merged
merged 4 commits into from
Feb 18, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Change the way `Sklearn` models works with regressors ([#440](https://github.com/tinkoff-ai/etna/pull/440))
- Change the way `FeatureSelectionTransform` works with regressors, rename variables replacing the "regressor" to "feature" ([#522](https://github.com/tinkoff-ai/etna/pull/522))
-
-
- Add table option to ConsoleLogger ([#544](https://github.com/tinkoff-ai/etna/pull/544))
- Installation instruction ([#526](https://github.com/tinkoff-ai/etna/pull/526))
-
- Trainer kwargs for deep models ([#540](https://github.com/tinkoff-ai/etna/pull/540))
Expand Down
36 changes: 23 additions & 13 deletions etna/loggers/console_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,20 @@
class ConsoleLogger(BaseLogger):
"""Log any events and metrics to stderr output. Uses loguru."""

def __init__(self):
"""Create instance of ConsoleLogger."""
def __init__(self, table: bool = True):
"""Create instance of ConsoleLogger.

Parameters
----------
table:
Indicator for writing tables to the console
"""
super().__init__()
if 0 in _logger._core.handlers:
self.table = table
try:
_logger.remove(0)
except ValueError:
pass
_logger.add(sink=sys.stderr)
self.logger = _logger.opt(depth=2, lazy=True, colors=True)

Expand All @@ -37,7 +46,7 @@ def log(self, msg: Union[str, Dict[str, Any]], **kwargs):
kwargs:
Parameters for changing additional info in log message
"""
self.logger.patch(lambda r: r.update(**kwargs)).info(msg)
self.logger.patch(lambda r: r.update(**kwargs)).info(msg) # type: ignore

def log_backtest_metrics(
self, ts: "TSDataset", metrics_df: pd.DataFrame, forecast_df: pd.DataFrame, fold_info_df: pd.DataFrame
Expand All @@ -60,15 +69,16 @@ def log_backtest_metrics(
-----
The result of logging will be different for aggregate_metrics=True and aggregate_metrics=False
"""
for _, row in metrics_df.iterrows():
for metric in metrics_df.columns[1:-1]:
# case for aggregate_metrics=False
if "fold_number" in row:
msg = f'Fold {row["fold_number"]}:{row["segment"]}:{metric} = {row[metric]}'
# case for aggregate_metrics=True
else:
msg = f'Segment {row["segment"]}:{metric} = {row[metric]}'
self.logger.info(msg)
if self.table:
for _, row in metrics_df.iterrows():
for metric in metrics_df.columns[1:-1]:
# case for aggregate_metrics=False
if "fold_number" in row:
msg = f'Fold {row["fold_number"]}:{row["segment"]}:{metric} = {row[metric]}'
# case for aggregate_metrics=True
else:
msg = f'Segment {row["segment"]}:{metric} = {row[metric]}'
self.logger.info(msg)

@property
def pl_logger(self):
Expand Down
18 changes: 18 additions & 0 deletions tests/test_loggers/test_console_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,24 @@ def test_backtest_logging(example_tsds: TSDataset):
tslogger.remove(idx)


def test_backtest_logging_no_tables(example_tsds: TSDataset):
"""Check working of logging inside backtest with `table=False`."""
file = NamedTemporaryFile()
_logger.add(file.name)
idx = tslogger.add(ConsoleLogger(table=False))
metrics = [MAE(), MSE(), SMAPE()]
date_flags = DateFlagsTransform(day_number_in_week=True, day_number_in_month=True)
pipe = Pipeline(model=CatBoostModelMultiSegment(), horizon=10, transforms=[date_flags])
n_folds = 5
pipe.backtest(ts=example_tsds, metrics=metrics, n_jobs=1, n_folds=n_folds)
with open(file.name, "r") as in_file:
lines = in_file.readlines()
# remain lines only about backtest
lines = [line for line in lines if "backtest" in line]
assert len(lines) == 0
tslogger.remove(idx)


@pytest.mark.parametrize("model", [LinearPerSegmentModel(), LinearMultiSegmentModel()])
def test_model_logging(example_tsds, model):
"""Check working of logging in fit/forecast of model."""
Expand Down