Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add progress callback interface to HPO #2889

Merged
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 31 additions & 1 deletion src/otx/cli/utils/hpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,14 @@
import os
import re
import shutil
import time
from copy import deepcopy
from enum import Enum
from functools import partial
from inspect import isclass
from math import floor
from pathlib import Path
from threading import Thread
from typing import Any, Callable, Dict, List, Optional, Union

import torch
Expand All @@ -30,6 +32,7 @@
from otx.cli.utils.io import read_model, save_model_data
from otx.core.data.adapter import get_dataset_adapter
from otx.hpo import HyperBand, TrialStatus, run_hpo_loop
from otx.hpo.hpo_base import HpoBase
from otx.utils.logger import get_logger

logger = get_logger()
Expand Down Expand Up @@ -382,6 +385,7 @@ class HpoRunner:
val_dataset_size (int): validation dataset size
hpo_workdir (Union[str, Path]): work directory for HPO
hpo_time_ratio (int, optional): time ratio to use for HPO compared to training time. Defaults to 4.
progress_updater_callback (Optional[Callable[[Union[int, float]], None]]): callback to update progress
"""

# pylint: disable=too-many-instance-attributes
Expand All @@ -393,6 +397,7 @@ def __init__(
val_dataset_size: int,
hpo_workdir: Union[str, Path],
hpo_time_ratio: int = 4,
progress_updater_callback: Optional[Callable[[Union[int, float]], None]] = None,
):
if train_dataset_size <= 0:
raise ValueError(f"train_dataset_size should be bigger than 0. Your value is {train_dataset_size}")
Expand All @@ -409,6 +414,7 @@ def __init__(
self._val_dataset_size = val_dataset_size
self._fixed_hp: Dict[str, Any] = {}
self._initial_weight_name = "initial_weight.pth"
self._progress_updater_callback = progress_updater_callback

self._align_batch_size_search_space_to_dataset_size()

Expand Down Expand Up @@ -466,6 +472,10 @@ def run_hpo(self, train_func: Callable, data_roots: Dict[str, Dict]) -> Union[Di
"""
self._environment.save_initial_weight(self._get_initial_model_weight_path())
hpo_algo = self._get_hpo_algo()

progress_updater_thread = Thread(target=self._update_hpo_progress, args=[hpo_algo], daemon=True)
progress_updater_thread.start()

resource_type = "gpu" if torch.cuda.is_available() else "cpu"
run_hpo_loop(
hpo_algo,
Expand Down Expand Up @@ -543,9 +553,27 @@ def _get_default_hyper_parameters(self):
def _get_initial_model_weight_path(self):
return self._hpo_workdir / self._initial_weight_name

def _update_hpo_progress(self, hpo_algo: HpoBase):
"""Function for a thread to report a HPO progress regularly.

Args:
hpo_algo (HpoBase): HPO algorithm class
"""

while True:
if hpo_algo.is_done():
break
self._progress_updater_callback(hpo_algo.get_progress() * 100)
time.sleep(1)


def run_hpo(
hpo_time_ratio: int, output: Path, environment: TaskEnvironment, dataset: DatasetEntity, data_roots: Dict[str, Dict]
hpo_time_ratio: int,
output: Path,
environment: TaskEnvironment,
dataset: DatasetEntity,
data_roots: Dict[str, Dict],
progress_updater_callback: Optional[Callable[[Union[int, float]], None]] = None,
) -> Optional[TaskEnvironment]:
"""Run HPO and load optimized hyper parameter and best HPO model weight.

Expand All @@ -555,6 +583,7 @@ def run_hpo(
environment (TaskEnvironment): otx task environment
dataset (DatasetEntity): dataset to use for training
data_roots (Dict[str, Dict]): dataset path of each dataset type
progress_updater_callback (Optional[Callable[[Union[int, float]], None]]): callback to update progress
"""
task_type = environment.model_template.task_type
if not _check_hpo_enabled_task(task_type):
Expand All @@ -575,6 +604,7 @@ def run_hpo(
len(dataset.get_subset(Subset.VALIDATION)),
hpo_save_path,
hpo_time_ratio,
progress_updater_callback,
)

logger.info("started hyper-parameter optimization")
Expand Down
Loading