diff --git a/src/otx/algorithms/common/adapters/mmcv/hooks/logger_hook.py b/src/otx/algorithms/common/adapters/mmcv/hooks/logger_hook.py index 051e8102950..acbddd846a8 100644 --- a/src/otx/algorithms/common/adapters/mmcv/hooks/logger_hook.py +++ b/src/otx/algorithms/common/adapters/mmcv/hooks/logger_hook.py @@ -1,4 +1,8 @@ """Logger hooks.""" + +# Copyright (C) 2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + from collections import defaultdict from typing import Any, Dict, Optional @@ -29,6 +33,19 @@ def __repr__(self): points.append(f"({x},{y})") return "curve[" + ",".join(points) + "]" + _TAGS_TO_SKIP = ( + "accuracy_top-1", + "current_iters", + "decode.acc_seg", + "decode.loss_ce_ignore", + ) + + _TAGS_TO_RENAME = { + "train/time": "train/time (sec/iter)", + "train/data_time": "train/data_time (sec/iter)", + "val/accuracy": "val/accuracy (%)", + } + def __init__( self, curves: Optional[Dict[Any, Curve]] = None, @@ -43,12 +60,13 @@ def __init__( @master_only def log(self, runner: BaseRunner): """Log function for OTXLoggerHook.""" - tags = self.get_loggable_tags(runner, allow_text=False, tags_to_skip=()) + tags = self.get_loggable_tags(runner, allow_text=False, tags_to_skip=self._TAGS_TO_SKIP) if runner.max_epochs is not None: normalized_iter = self.get_iter(runner) / runner.max_iters * runner.max_epochs else: normalized_iter = self.get_iter(runner) for tag, value in tags.items(): + tag = self._TAGS_TO_RENAME.get(tag, tag) curve = self.curves[tag] # Remove duplicates. if len(curve.x) > 0 and curve.x[-1] == normalized_iter: @@ -57,6 +75,11 @@ def log(self, runner: BaseRunner): curve.x.append(normalized_iter) curve.y.append(value) + def before_run(self, runner: BaseRunner): + """Called before_run in OTXLoggerHook.""" + super().before_run(runner) + self.curves.clear() + def after_train_epoch(self, runner: BaseRunner): """Called after_train_epoch in OTXLoggerHook.""" # Iteration counter is increased right after the last iteration in the epoch,