Add progress callback interface to HPO (#2889)
* add progress callback as HPO argument

* deal with edge case
eunwoosh authored Feb 7, 2024
1 parent fa3c86c commit ab01197
Showing 1 changed file with 32 additions and 1 deletion.

src/otx/cli/utils/hpo.py
@@ -8,12 +8,14 @@
 import os
 import re
 import shutil
+import time
 from copy import deepcopy
 from enum import Enum
 from functools import partial
 from inspect import isclass
 from math import floor
 from pathlib import Path
+from threading import Thread
 from typing import Any, Callable, Dict, List, Optional, Union
 
 import torch
@@ -30,6 +32,7 @@
 from otx.cli.utils.io import read_model, save_model_data
 from otx.core.data.adapter import get_dataset_adapter
 from otx.hpo import HyperBand, TrialStatus, run_hpo_loop
+from otx.hpo.hpo_base import HpoBase
 from otx.utils.logger import get_logger
 
 logger = get_logger()
@@ -382,6 +385,7 @@ class HpoRunner:
         val_dataset_size (int): validation dataset size
         hpo_workdir (Union[str, Path]): work directory for HPO
         hpo_time_ratio (int, optional): time ratio to use for HPO compared to training time. Defaults to 4.
+        progress_updater_callback (Optional[Callable[[Union[int, float]], None]]): callback to update progress
     """
 
     # pylint: disable=too-many-instance-attributes
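
The documented type means any callable accepting a single numeric argument qualifies; the reporter added further down forwards `HpoBase.get_progress()` scaled to a 0-100 percentage. A minimal sketch of a conforming callback (`print_hpo_progress` is an illustrative name, not part of this commit):

```python
from typing import Union

def print_hpo_progress(progress: Union[int, float]) -> None:
    """Minimal consumer for HPO progress values, reported on a 0-100 scale."""
    print(f"HPO progress: {progress:.1f}%")
```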
@@ -393,6 +397,7 @@ def __init__(
         val_dataset_size: int,
         hpo_workdir: Union[str, Path],
         hpo_time_ratio: int = 4,
+        progress_updater_callback: Optional[Callable[[Union[int, float]], None]] = None,
     ):
         if train_dataset_size <= 0:
             raise ValueError(f"train_dataset_size should be bigger than 0. Your value is {train_dataset_size}")
@@ -409,6 +414,7 @@ def __init__(
         self._val_dataset_size = val_dataset_size
         self._fixed_hp: Dict[str, Any] = {}
         self._initial_weight_name = "initial_weight.pth"
+        self._progress_updater_callback = progress_updater_callback
 
         self._align_batch_size_search_space_to_dataset_size()
 
@@ -466,6 +472,11 @@ def run_hpo(self, train_func: Callable, data_roots: Dict[str, Dict]) -> Union[Di
         """
         self._environment.save_initial_weight(self._get_initial_model_weight_path())
         hpo_algo = self._get_hpo_algo()
+
+        if self._progress_updater_callback is not None:
+            progress_updater_thread = Thread(target=self._update_hpo_progress, args=[hpo_algo], daemon=True)
+            progress_updater_thread.start()
+
         resource_type = "gpu" if torch.cuda.is_available() else "cpu"
         run_hpo_loop(
             hpo_algo,
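
The block above spawns the reporter only when a callback was actually supplied (the thread target calls the callback, so passing `None` through would raise a `TypeError`), and marks it `daemon=True` so the reporter can never keep the process alive on its own. A self-contained sketch of that lifecycle guarantee (all names are stand-ins, not OTX APIs):

```python
import time
from threading import Thread

def report_forever() -> None:
    # Stand-in reporter stuck in its loop, as if is_done() never became True.
    while True:
        time.sleep(1)

# daemon=True lets the interpreter exit even though the loop never returns;
# without it, this script would hang at shutdown waiting on the thread.
reporter = Thread(target=report_forever, daemon=True)
reporter.start()
print("main thread done; the daemon reporter dies with the process")
```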
@@ -543,9 +554,27 @@ def _get_default_hyper_parameters(self):
     def _get_initial_model_weight_path(self):
         return self._hpo_workdir / self._initial_weight_name
 
+    def _update_hpo_progress(self, hpo_algo: HpoBase):
+        """Function for a thread to report HPO progress regularly.
+
+        Args:
+            hpo_algo (HpoBase): HPO algorithm instance
+        """
+
+        while True:
+            if hpo_algo.is_done():
+                break
+            self._progress_updater_callback(hpo_algo.get_progress() * 100)
+            time.sleep(1)
+
 
 def run_hpo(
-    hpo_time_ratio: int, output: Path, environment: TaskEnvironment, dataset: DatasetEntity, data_roots: Dict[str, Dict]
+    hpo_time_ratio: int,
+    output: Path,
+    environment: TaskEnvironment,
+    dataset: DatasetEntity,
+    data_roots: Dict[str, Dict],
+    progress_updater_callback: Optional[Callable[[Union[int, float]], None]] = None,
 ) -> Optional[TaskEnvironment]:
     """Run HPO and load optimized hyper parameter and best HPO model weight.
     Args:
@@ -555,6 +584,7 @@ def run_hpo(
         environment (TaskEnvironment): otx task environment
         dataset (DatasetEntity): dataset to use for training
         data_roots (Dict[str, Dict]): dataset path of each dataset type
+        progress_updater_callback (Optional[Callable[[Union[int, float]], None]]): callback to update progress
     """
     task_type = environment.model_template.task_type
     if not _check_hpo_enabled_task(task_type):
@@ -575,6 +605,7 @@
         len(dataset.get_subset(Subset.VALIDATION)),
         hpo_save_path,
         hpo_time_ratio,
+        progress_updater_callback,
     )
 
     logger.info("started hyper-parameter optimization")
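
With the callback threaded through to `HpoRunner`, a caller that already has a `TaskEnvironment`, dataset, and data roots prepared could wire up progress reporting as sketched below (`run_hpo_with_progress` and `on_progress` are hypothetical names; only the `run_hpo` signature comes from this commit):

```python
from pathlib import Path
from typing import Union

from otx.cli.utils.hpo import run_hpo

def run_hpo_with_progress(environment, dataset, data_roots, output: Path):
    """Hypothetical wrapper that forwards HPO progress (0-100) to stdout."""

    def on_progress(progress: Union[int, float]) -> None:
        print(f"[HPO] {progress:.1f}% complete")

    return run_hpo(
        hpo_time_ratio=4,  # same default HpoRunner uses
        output=output,
        environment=environment,
        dataset=dataset,
        data_roots=data_roots,
        progress_updater_callback=on_progress,
    )
```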