diff --git a/otx/algorithms/action/configs/classification/configuration.yaml b/otx/algorithms/action/configs/classification/configuration.yaml index e9c63c445c5..5221eaaa03b 100644 --- a/otx/algorithms/action/configs/classification/configuration.yaml +++ b/otx/algorithms/action/configs/classification/configuration.yaml @@ -261,6 +261,22 @@ algo_backend: value: INCREMENTAL visible_in_ui: True warning: null + mem_cache_size: + affects_outcome_of: TRAINING + default_value: 0 + description: Size of memory pool for caching decoded data to load data faster (bytes). + editable: true + header: Size of memory pool + max_value: 9223372036854775807 + min_value: 0 + type: INTEGER + ui_rules: + action: DISABLE_EDITING + operator: AND + rules: [] + type: UI_RULES + visible_in_ui: false + warning: null type: PARAMETER_GROUP visible_in_ui: true type: CONFIGURABLE_PARAMETERS diff --git a/otx/algorithms/action/configs/detection/configuration.yaml b/otx/algorithms/action/configs/detection/configuration.yaml index e9c63c445c5..5221eaaa03b 100644 --- a/otx/algorithms/action/configs/detection/configuration.yaml +++ b/otx/algorithms/action/configs/detection/configuration.yaml @@ -261,6 +261,22 @@ algo_backend: value: INCREMENTAL visible_in_ui: True warning: null + mem_cache_size: + affects_outcome_of: TRAINING + default_value: 0 + description: Size of memory pool for caching decoded data to load data faster (bytes). + editable: true + header: Size of memory pool + max_value: 9223372036854775807 + min_value: 0 + type: INTEGER + ui_rules: + action: DISABLE_EDITING + operator: AND + rules: [] + type: UI_RULES + visible_in_ui: false + warning: null type: PARAMETER_GROUP visible_in_ui: true type: CONFIGURABLE_PARAMETERS diff --git a/otx/algorithms/classification/adapters/mmcls/data/pipelines.py b/otx/algorithms/classification/adapters/mmcls/data/pipelines.py index a13ede8a9bb..23543ddbc58 100644 --- a/otx/algorithms/classification/adapters/mmcls/data/pipelines.py +++ b/otx/algorithms/classification/adapters/mmcls/data/pipelines.py @@ -3,7 +3,6 @@ # Copyright (C) 2022 Intel Corporation # SPDX-License-Identifier: Apache-2.0 import copy -import tempfile from typing import Any, Dict, List import numpy as np @@ -13,59 +12,16 @@ from PIL import Image, ImageFilter from torchvision import transforms as T -from otx.algorithms.common.utils.data import get_image +import otx.core.data.pipelines.load_image_from_otx_dataset as load_image_base from otx.api.utils.argument_checks import check_input_parameters_type -_CACHE_DIR = tempfile.TemporaryDirectory(prefix="img-cache-") # pylint: disable=consider-using-with - # TODO: refactoring to common modules # TODO: refactoring to Sphinx style. @PIPELINES.register_module() -class LoadImageFromOTXDataset: - """Pipeline element that loads an image from a OTX Dataset on the fly. - - Can do conversion to float 32 if needed. - Expected entries in the 'results' dict that should be passed to this pipeline element are: - results['dataset_item']: dataset_item from which to load the image - results['dataset_id']: id of the dataset to which the item belongs - results['index']: index of the item in the dataset - - :param to_float32: optional bool, True to convert images to fp32. 
defaults to False - """ - - @check_input_parameters_type() - def __init__(self, to_float32: bool = False): - self.to_float32 = to_float32 - - @check_input_parameters_type() - def __call__(self, results: Dict[str, Any]): - """Callback function of LoadImageFromOTXDataset.""" - # Get image (possibly from cache) - img = get_image(results, _CACHE_DIR.name, to_float32=self.to_float32) - shape = img.shape - - assert img.shape[0] == results["height"], f"{img.shape[0]} != {results['height']}" - assert img.shape[1] == results["width"], f"{img.shape[1]} != {results['width']}" - - filename = f"Dataset item index {results['index']}" - results["filename"] = filename - results["ori_filename"] = filename - results["img"] = img - results["img_shape"] = shape - results["ori_shape"] = shape - # Set initial values for default meta_keys - results["pad_shape"] = shape - num_channels = 1 if len(shape) < 3 else shape[2] - results["img_norm_cfg"] = dict( - mean=np.zeros(num_channels, dtype=np.float32), - std=np.ones(num_channels, dtype=np.float32), - to_rgb=False, - ) - results["img_fields"] = ["img"] - - return results +class LoadImageFromOTXDataset(load_image_base.LoadImageFromOTXDataset): + """Pipeline element that loads an image from a OTX Dataset on the fly.""" @PIPELINES.register_module() diff --git a/otx/algorithms/classification/configs/configuration.yaml b/otx/algorithms/classification/configs/configuration.yaml index 541a4ab3528..897c3f7e13f 100644 --- a/otx/algorithms/classification/configs/configuration.yaml +++ b/otx/algorithms/classification/configs/configuration.yaml @@ -354,5 +354,21 @@ algo_backend: value: INCREMENTAL visible_in_ui: True warning: null + mem_cache_size: + affects_outcome_of: TRAINING + default_value: 0 + description: Size of memory pool for caching decoded data to load data faster (bytes). 
+ editable: true + header: Size of memory pool + max_value: 9223372036854775807 + min_value: 0 + type: INTEGER + ui_rules: + action: DISABLE_EDITING + operator: AND + rules: [] + type: UI_RULES + visible_in_ui: false + warning: null type: PARAMETER_GROUP visible_in_ui: true diff --git a/otx/algorithms/common/configs/training_base.py b/otx/algorithms/common/configs/training_base.py index c8e17af0392..1e99f5048ee 100644 --- a/otx/algorithms/common/configs/training_base.py +++ b/otx/algorithms/common/configs/training_base.py @@ -282,6 +282,16 @@ class BaseAlgoBackendParameters(ParameterGroup): visible_in_ui=True, ) + mem_cache_size = configurable_integer( + header="Size of memory pool for caching decoded data to load data faster", + description="Size of memory pool for caching decoded data to load data faster", + default_value=0, + min_value=0, + max_value=maxsize, + visible_in_ui=False, + affects_outcome_of=ModelLifecycle.TRAINING, + ) + @attrs class BaseTilingParameters(ParameterGroup): """BaseTilingParameters for OTX Algorithms.""" diff --git a/otx/algorithms/common/tasks/training_base.py b/otx/algorithms/common/tasks/training_base.py index db326b81573..a58fdf0e20f 100644 --- a/otx/algorithms/common/tasks/training_base.py +++ b/otx/algorithms/common/tasks/training_base.py @@ -45,6 +45,7 @@ from otx.api.usecases.tasks.interfaces.inference_interface import IInferenceTask from otx.api.usecases.tasks.interfaces.unload_interface import IUnload from otx.api.utils.argument_checks import check_input_parameters_type +from otx.core.data import caching from otx.mpa.builder import build from otx.mpa.modules.hooks.cancel_interface_hook import CancelInterfaceHook from otx.mpa.stage import Stage @@ -324,6 +325,9 @@ def _initialize(self, options=None): # noqa: C901 dataloader_cfg["persistent_workers"] = False data_cfg[f"{subset}_dataloader"] = dataloader_cfg + # Update recipe with caching modules + self._update_caching_modules(data_cfg) + if self._data_cfg is not None: align_data_config_with_recipe(self._data_cfg, self._recipe_cfg) @@ -412,7 +416,6 @@ def _init_deploy_cfg(self) -> Union[Config, None]: deploy_cfg = MPAConfig.fromfile(deploy_cfg_path) def patch_input_preprocessing(deploy_cfg): - normalize_cfg = get_configs_by_pairs( self._recipe_cfg.data.test.pipeline, dict(type="Normalize"), @@ -620,3 +623,31 @@ def set_early_stopping_hook(self): update_or_add_custom_hook(self._recipe_cfg, early_stop_hook) else: remove_custom_hook(self._recipe_cfg, "LazyEarlyStoppingHook") + + def _update_caching_modules(self, data_cfg: Config) -> None: + def _find_max_num_workers(cfg: dict): + num_workers = [0] + for key, value in cfg.items(): + if key == "workers_per_gpu" and isinstance(value, int): + num_workers += [value] + elif isinstance(value, dict): + num_workers += [_find_max_num_workers(value)] + + return max(num_workers) + + def _get_mem_cache_size(): + if not hasattr(self.hyperparams.algo_backend, "mem_cache_size"): + return 0 + + return self.hyperparams.algo_backend.mem_cache_size + + max_num_workers = _find_max_num_workers(data_cfg) + mem_cache_size = _get_mem_cache_size() + + mode = "multiprocessing" if max_num_workers > 0 else "singleprocessing" + caching.MemCacheHandlerSingleton.create(mode, mem_cache_size) + + update_or_add_custom_hook( + self._recipe_cfg, + ConfigDict(type="MemCacheHook", priority="VERY_LOW"), + ) diff --git a/otx/algorithms/detection/adapters/mmdet/data/pipelines.py b/otx/algorithms/detection/adapters/mmdet/data/pipelines.py index 836282067c1..0b2850c054c 100644 --- 
a/otx/algorithms/detection/adapters/mmdet/data/pipelines.py +++ b/otx/algorithms/detection/adapters/mmdet/data/pipelines.py @@ -13,65 +13,21 @@ # See the License for the specific language governing permissions # and limitations under the License. import copy -import tempfile from typing import Any, Dict, Optional -import numpy as np from mmdet.datasets.builder import PIPELINES -from otx.algorithms.common.utils.data import get_image +import otx.core.data.pipelines.load_image_from_otx_dataset as load_image_base from otx.api.entities.label import Domain from otx.api.utils.argument_checks import check_input_parameters_type from .dataset import get_annotation_mmdet_format -_CACHE_DIR = tempfile.TemporaryDirectory(prefix="img-cache-") # pylint: disable=consider-using-with - # pylint: disable=too-many-instance-attributes, too-many-arguments @PIPELINES.register_module() -class LoadImageFromOTXDataset: - """Pipeline element that loads an image from a OTX Dataset on the fly. Can do conversion to float 32 if needed. - - Expected entries in the 'results' dict that should be passed to this pipeline element are: - results['dataset_item']: dataset_item from which to load the image - results['dataset_id']: id of the dataset to which the item belongs - results['index']: index of the item in the dataset - - :param to_float32: optional bool, True to convert images to fp32. defaults to False - """ - - @check_input_parameters_type() - def __init__(self, to_float32: bool = False): - self.to_float32 = to_float32 - - @check_input_parameters_type() - def __call__(self, results: Dict[str, Any]): - """Callback function LoadImageFromOTXDataset.""" - # Get image (possibly from cache) - img = get_image(results, _CACHE_DIR.name, to_float32=self.to_float32) - shape = img.shape - - assert img.shape[0] == results["height"], f"{img.shape[0]} != {results['height']}" - assert img.shape[1] == results["width"], f"{img.shape[1]} != {results['width']}" - - filename = f"Dataset item index {results['index']}" - results["filename"] = filename - results["ori_filename"] = filename - results["img"] = img - results["img_shape"] = shape - results["ori_shape"] = shape - # Set initial values for default meta_keys - results["pad_shape"] = shape - num_channels = 1 if len(shape) < 3 else shape[2] - results["img_norm_cfg"] = dict( - mean=np.zeros(num_channels, dtype=np.float32), - std=np.ones(num_channels, dtype=np.float32), - to_rgb=False, - ) - results["img_fields"] = ["img"] - - return results +class LoadImageFromOTXDataset(load_image_base.LoadImageFromOTXDataset): + """Pipeline element that loads an image from a OTX Dataset on the fly.""" @PIPELINES.register_module() diff --git a/otx/algorithms/detection/configs/detection/configuration.yaml b/otx/algorithms/detection/configs/detection/configuration.yaml index 77ce90478d5..c3ec7071df1 100644 --- a/otx/algorithms/detection/configs/detection/configuration.yaml +++ b/otx/algorithms/detection/configs/detection/configuration.yaml @@ -262,6 +262,22 @@ algo_backend: value: INCREMENTAL visible_in_ui: True warning: null + mem_cache_size: + affects_outcome_of: TRAINING + default_value: 0 + description: Size of memory pool for caching decoded data to load data faster (bytes). 
+ editable: true + header: Size of memory pool + max_value: 9223372036854775807 + min_value: 0 + type: INTEGER + ui_rules: + action: DISABLE_EDITING + operator: AND + rules: [] + type: UI_RULES + visible_in_ui: false + warning: null type: PARAMETER_GROUP visible_in_ui: true type: CONFIGURABLE_PARAMETERS diff --git a/otx/algorithms/detection/configs/instance_segmentation/configuration.yaml b/otx/algorithms/detection/configs/instance_segmentation/configuration.yaml index 51d7e9d3696..57693128302 100644 --- a/otx/algorithms/detection/configs/instance_segmentation/configuration.yaml +++ b/otx/algorithms/detection/configs/instance_segmentation/configuration.yaml @@ -262,6 +262,22 @@ algo_backend: value: INCREMENTAL visible_in_ui: True warning: null + mem_cache_size: + affects_outcome_of: TRAINING + default_value: 0 + description: Size of memory pool for caching decoded data to load data faster (bytes). + editable: true + header: Size of memory pool + max_value: 9223372036854775807 + min_value: 0 + type: INTEGER + ui_rules: + action: DISABLE_EDITING + operator: AND + rules: [] + type: UI_RULES + visible_in_ui: false + warning: null type: PARAMETER_GROUP visible_in_ui: true type: CONFIGURABLE_PARAMETERS diff --git a/otx/algorithms/segmentation/adapters/mmseg/data/pipelines.py b/otx/algorithms/segmentation/adapters/mmseg/data/pipelines.py index e1cfdc4f2fa..9d0f0278954 100644 --- a/otx/algorithms/segmentation/adapters/mmseg/data/pipelines.py +++ b/otx/algorithms/segmentation/adapters/mmseg/data/pipelines.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions # and limitations under the License. -import tempfile from copy import deepcopy from typing import Any, Dict, List @@ -24,55 +23,16 @@ from torchvision import transforms as T from torchvision.transforms import functional as F -from otx.algorithms.common.utils.data import get_image +import otx.core.data.pipelines.load_image_from_otx_dataset as load_image_base from otx.api.utils.argument_checks import check_input_parameters_type from .dataset import get_annotation_mmseg_format -_CACHE_DIR = tempfile.TemporaryDirectory(prefix="img-cache-") # pylint: disable=consider-using-with - +# pylint: disable=too-many-instance-attributes, too-many-arguments @PIPELINES.register_module() -class LoadImageFromOTXDataset: - """Pipeline element that loads an image from a OTX Dataset on the fly. Can do conversion to float 32 if needed. - - Expected entries in the 'results' dict that should be passed to this pipeline element are: - results['dataset_item']: dataset_item from which to load the image - results['dataset_id']: id of the dataset to which the item belongs - results['index']: index of the item in the dataset - - :param to_float32: optional bool, True to convert images to fp32. 
defaults to False - """ - - @check_input_parameters_type() - def __init__(self, to_float32: bool = False): - self.to_float32 = to_float32 - - @check_input_parameters_type() - def __call__(self, results: Dict[str, Any]): - """Callback function LoadImageFromOTXDataset.""" - # Get image (possibly from cache) - img = get_image(results, _CACHE_DIR.name, to_float32=self.to_float32) - shape = img.shape - - assert img.shape[0] == results["height"], f"{img.shape[0]} != {results['height']}" - assert img.shape[1] == results["width"], f"{img.shape[1]} != {results['width']}" - - filename = f"Dataset item index {results['index']}" - results["filename"] = filename - results["ori_filename"] = filename - results["img"] = img - results["img_shape"] = shape - results["ori_shape"] = shape - # Set initial values for default meta_keys - results["pad_shape"] = shape - num_channels = 1 if len(shape) < 3 else shape[2] - results["img_norm_cfg"] = dict( - mean=np.zeros(num_channels, dtype=np.float32), std=np.ones(num_channels, dtype=np.float32), to_rgb=False - ) - results["img_fields"] = ["img"] - - return results +class LoadImageFromOTXDataset(load_image_base.LoadImageFromOTXDataset): + """Pipeline element that loads an image from a OTX Dataset on the fly.""" @PIPELINES.register_module() diff --git a/otx/algorithms/segmentation/configs/configuration.yaml b/otx/algorithms/segmentation/configs/configuration.yaml index 689a77855c5..0da91d335ba 100644 --- a/otx/algorithms/segmentation/configs/configuration.yaml +++ b/otx/algorithms/segmentation/configs/configuration.yaml @@ -292,6 +292,22 @@ algo_backend: value: INCREMENTAL visible_in_ui: True warning: null + mem_cache_size: + affects_outcome_of: TRAINING + default_value: 0 + description: Size of memory pool for caching decoded data to load data faster (bytes). 
+ editable: true + header: Size of memory pool + max_value: 9223372036854775807 + min_value: 0 + type: INTEGER + ui_rules: + action: DISABLE_EDITING + operator: AND + rules: [] + type: UI_RULES + visible_in_ui: false + warning: null type: PARAMETER_GROUP visible_in_ui: true type: CONFIGURABLE_PARAMETERS diff --git a/otx/api/entities/dataset_item.py b/otx/api/entities/dataset_item.py index 762db27f2b4..7976a03d2de 100644 --- a/otx/api/entities/dataset_item.py +++ b/otx/api/entities/dataset_item.py @@ -110,6 +110,8 @@ def __init__( if Rectangle.is_full_box(annotation.shape): roi = annotation break + if roi is None: + roi = Annotation(Rectangle.generate_full_box(), labels=[]) self.__roi = roi self.__metadata: List[MetadataItemEntity] = [] @@ -150,16 +152,13 @@ def __repr__(self): def roi(self) -> Annotation: """Region Of Interest.""" with self.__roi_lock: - if self.__roi is None: - requested_roi = Annotation(Rectangle.generate_full_box(), labels=[]) - self.__roi = requested_roi - else: - requested_roi = self.__roi - return requested_roi + return self.__roi @roi.setter def roi(self, roi: Optional[Annotation]): with self.__roi_lock: + if roi is None: + roi = Annotation(Rectangle.generate_full_box(), labels=[]) self.__roi = roi @property diff --git a/otx/cli/tools/train.py b/otx/cli/tools/train.py index 8e1662ebe83..bcc8e7ba878 100644 --- a/otx/cli/tools/train.py +++ b/otx/cli/tools/train.py @@ -33,6 +33,7 @@ from otx.cli.utils.io import read_binary, read_label_schema, save_model_data from otx.cli.utils.multi_gpu import MultiGPUManager from otx.cli.utils.parser import ( + MemSizeAction, add_hyper_parameters_sub_parser, get_parser_and_hprams_data, ) @@ -111,6 +112,16 @@ def get_args(): default=0, help="Total number of workers in a worker group.", ) + parser.add_argument( + "--mem-cache-size", + action=MemSizeAction, + dest="params.algo_backend.mem_cache_size", + type=str, + required=False, + help="Size of memory pool for caching decoded data to load data faster. " + "For example, you can use digits for bytes size (e.g. 1024) or a string with size units " + "(e.g. 7KB = 7 * 2^10, 3MB = 3 * 2^20, and 2GB = 2 * 2^30).", + ) parser.add_argument( "--data", type=str, diff --git a/otx/cli/utils/parser.py b/otx/cli/utils/parser.py index c41cf16067e..6f6b40c474f 100644 --- a/otx/cli/utils/parser.py +++ b/otx/cli/utils/parser.py @@ -15,6 +15,7 @@ # and limitations under the License. 
 import argparse
+import re
 import sys
 from argparse import RawTextHelpFormatter
 from pathlib import Path
@@ -24,6 +25,52 @@
 from otx.cli.registry import find_and_parse_model_template
 
 
+class MemSizeAction(argparse.Action):
+    """Parser add-on to parse a memory size string."""
+
+    def __init__(self, option_strings, dest, nargs=None, **kwargs):
+        if nargs is not None:
+            raise ValueError("nargs not allowed")
+        expected_dest = "params.algo_backend.mem_cache_size"
+        if dest != expected_dest:
+            raise ValueError(f"dest should be {expected_dest}, but dest={dest}.")
+        super().__init__(option_strings, dest, **kwargs)
+
+    def __call__(self, parser, namespace, values, option_string=None):
+        """Parse and set the attribute of namespace."""
+        setattr(namespace, self.dest, self._parse_mem_size_str(values))
+
+    @staticmethod
+    def _parse_mem_size_str(mem_size: str) -> int:
+        assert isinstance(mem_size, str)
+
+        match = re.match(r"^(\d+)\s*([a-zA-Z]{0,3})$", mem_size.strip())
+
+        if match is None:
+            raise ValueError(f"Cannot parse {mem_size} string.")
+
+        units = {
+            "": 1,
+            "B": 1,
+            "KB": 2**10,
+            "MB": 2**20,
+            "GB": 2**30,
+            "KIB": 10**3,
+            "MIB": 10**6,
+            "GIB": 10**9,
+            "K": 2**10,
+            "M": 2**20,
+            "G": 2**30,
+        }
+
+        number, unit = int(match.group(1)), match.group(2).upper()
+
+        if unit not in units:
+            raise ValueError(f"{mem_size} has disallowed unit ({unit}).")
+
+        return number * units[unit]
+
+
 def gen_param_help(hyper_parameters: Dict) -> Dict:
     """Generates help for hyper parameters section."""
 
@@ -67,27 +114,38 @@ def gen_params_dict_from_args(
 ) -> Dict[str, dict]:
     """Generates hyper parameters dict from parsed command line arguments."""
 
+    def _get_leaf_node(curr_dict: Dict[str, dict], curr_key: str):
+        split_key = curr_key.split(".")
+        node_key = split_key[0]
+
+        if len(split_key) == 1:
+            # It is a leaf node
+            return curr_dict, node_key
+
+        # Dive deeper
+        curr_key = ".".join(split_key[1:])
+        if node_key not in curr_dict:
+            curr_dict[node_key] = {}
+        return _get_leaf_node(curr_dict[node_key], curr_key)
+
+    _prefix = "params."
     params_dict: Dict[str, dict] = {}
     for param_name in dir(args):
-        if not param_name.startswith("params."):
+        value = getattr(args, param_name)
+
+        if not param_name.startswith(_prefix) or value is None:
             continue
         if override_param and param_name not in override_param:
            continue
+        # Equivalent to param_name.removeprefix(_prefix) (Python 3.9+)
+        origin_key = param_name[len(_prefix) :]
         value_type = None
-        cur_dict = params_dict
-        split_param_name = param_name.split(".")[1:]
-        if type_hint:
-            origin_key = ".".join(split_param_name)
-            value_type = type_hint[origin_key].get("type", None)
-        for i, k in enumerate(split_param_name):
-            if k not in cur_dict:
-                cur_dict[k] = {}
-            if i < len(split_param_name) - 1:
-                cur_dict = cur_dict[k]
-            else:
-                value = getattr(args, param_name)
-                cur_dict[k] = {"value": value_type(value) if value_type else value}
+        if type_hint is not None:
+            value_type = type_hint.get(origin_key, {}).get("type", None)
+
+        leaf_node_dict, node_key = _get_leaf_node(params_dict, origin_key)
+        leaf_node_dict[node_key] = {"value": value_type(value) if value_type else value}
 
     return params_dict
diff --git a/otx/core/data/caching/__init__.py b/otx/core/data/caching/__init__.py
new file mode 100644
index 00000000000..f604a62e843
--- /dev/null
+++ b/otx/core/data/caching/__init__.py
@@ -0,0 +1,9 @@
+"""Module for data caching."""
+# Copyright (C) 2023 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+#
+
+from .mem_cache_handler import MemCacheHandlerError, MemCacheHandlerSingleton
+from .mem_cache_hook import MemCacheHook
+
+__all__ = ["MemCacheHandlerSingleton", "MemCacheHook", "MemCacheHandlerError"]
diff --git a/otx/core/data/caching/mem_cache_handler.py b/otx/core/data/caching/mem_cache_handler.py
new file mode 100644
index 00000000000..44cab53e051
--- /dev/null
+++ b/otx/core/data/caching/mem_cache_handler.py
@@ -0,0 +1,202 @@
+"""Memory cache handler implementations and singleton class to call them."""
+# Copyright (C) 2023 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+#
+
+import ctypes as ct
+import multiprocessing as mp
+from multiprocessing.managers import DictProxy
+from multiprocessing.synchronize import Lock
+from typing import Any, Dict, Optional, Union
+
+import numpy as np
+from mmcv.runner import get_dist_info
+
+from otx.mpa.utils.logger import get_logger
+
+logger = get_logger()
+
+
+class _DummyLock:
+    def __enter__(self, *args, **kwargs):
+        pass
+
+    def __exit__(self, *args, **kwargs):
+        pass
+
+
+class MemCacheHandlerBase:
+    """Base class for memory cache handler.
+
+    It will be combined with LoadImageFromOTXDataset to store/retrieve the samples in memory.
+    """
+
+    def __init__(self, mem_size: int):
+        self._init_data_structs(mem_size)
+
+    def _init_data_structs(self, mem_size: int):
+        self._arr = (ct.c_uint8 * mem_size)()
+        self._cur_page = ct.c_size_t(0)
+        self._cache_addr: Union[Dict, DictProxy] = {}
+        self._lock: Union[Lock, _DummyLock] = _DummyLock()
+        self._freeze = ct.c_bool(False)
+
+    def __len__(self):
+        """Get the number of cached items."""
+        return len(self._cache_addr)
+
+    @property
+    def mem_size(self) -> int:
+        """Get the reserved memory pool size (bytes)."""
+        return len(self._arr)
+
+    def get(self, key: Any) -> Optional[np.ndarray]:
+        """Try to look up the cached item with the given key.
+
+        Args:
+            key (Any): A key for looking up the cached item
+
+        Returns:
+            If successful, returns np.ndarray; otherwise, returns None
+        """
+        if self.mem_size == 0 or key not in self._cache_addr:
+            return None
+
+        addr = self._cache_addr[key]
+
+        offset, count, shape, strides = addr
+
+        data = np.frombuffer(self._arr, dtype=np.uint8, count=count, offset=offset)
+        return np.lib.stride_tricks.as_strided(data, shape, strides)
+
+    def put(self, key: Any, data: np.ndarray) -> Optional[int]:
+        """Try to store np.ndarray with a key to the reserved memory pool.
+
+        Args:
+            key (Any): A key to store the cached item
+            data (np.ndarray): A data sample to store
+
+        Returns:
+            Optional[int]: If successful, returns the new page offset; otherwise, returns None
+        """
+        if self._freeze.value:
+            return None
+
+        assert data.dtype == np.uint8
+
+        with self._lock:
+            new_page = self._cur_page.value + data.size
+
+            if key in self._cache_addr or new_page > self.mem_size:
+                return None
+
+            offset = ct.byref(self._arr, self._cur_page.value)
+            ct.memmove(offset, data.ctypes.data, data.size)
+
+            self._cache_addr[key] = (
+                self._cur_page.value,
+                data.size,
+                data.shape,
+                data.strides,
+            )
+            self._cur_page.value = new_page
+            return new_page
+
+    def __repr__(self):
+        """Representation for the current handler status."""
+        perc = 100.0 * self._cur_page.value / self.mem_size if self.mem_size > 0 else 0.0
+        return (
+            f"{self.__class__.__name__} "
+            f"uses {self._cur_page.value} / {self.mem_size} ({perc:.1f}%) memory pool and "
+            f"stores {len(self)} items."
+        )
+
+    def freeze(self):
+        """If frozen, it is impossible to store a new item anymore."""
+        self._freeze.value = True
+
+    def unfreeze(self):
+        """If unfrozen, it is possible to store a new item."""
+        self._freeze.value = False
+
+
+class MemCacheHandlerForSP(MemCacheHandlerBase):
+    """Memory caching handler for single processing.
+
+    Use if PyTorch's DataLoader.num_workers == 0.
+    """
+
+
+class MemCacheHandlerForMP(MemCacheHandlerBase):
+    """Memory caching handler for multi processing.
+
+    Use if PyTorch's DataLoader.num_workers > 0.
+    """
+
+    def _init_data_structs(self, mem_size: int):
+        self._arr = mp.Array(ct.c_uint8, mem_size, lock=False)
+        self._cur_page = mp.Value(ct.c_size_t, 0, lock=False)
+
+        self._manager = mp.Manager()
+        self._cache_addr: DictProxy = self._manager.dict()
+        self._lock = mp.Lock()
+        self._freeze = mp.Value(ct.c_bool, False, lock=False)
+
+    def __del__(self):
+        """When deleting, the manager should also be shut down."""
+        self._manager.shutdown()
+
+
+class MemCacheHandlerError(Exception):
+    """Exception class for MemCacheHandler."""
+
+
+class MemCacheHandlerSingleton:
+    """A singleton class to create, delete and get MemCacheHandlerBase."""
+
+    instance: MemCacheHandlerBase
+
+    @classmethod
+    def get(cls) -> MemCacheHandlerBase:
+        """Get the created MemCacheHandlerBase.
+
+        If none has been created yet, raise MemCacheHandlerError.
+        """
+        if not hasattr(cls, "instance"):
+            cls_name = cls.__name__
+            raise MemCacheHandlerError(f"Before calling {cls_name}.get(), you should call {cls_name}.create() first.")
+
+        return cls.instance
+
+    @classmethod
+    def create(cls, mode: str, mem_size: int) -> MemCacheHandlerBase:
+        """Create a new MemCacheHandlerBase instance.
+
+        Args:
+            mode (str): There are three options: null, multiprocessing, or singleprocessing.
+            mem_size (int): The size of the memory pool (bytes).
+        """
+        logger.info(f"Creating a memory pool of {mem_size} bytes.")
+
+        _, world_size = get_dist_info()
+        if world_size > 1:
+            mem_size = mem_size // world_size
+            logger.info(f"Since world_size={world_size} > 1, each worker will use a memory pool of {mem_size} bytes.")
+
+        if mode == "null" or mem_size == 0:
+            cls.instance = MemCacheHandlerBase(mem_size=0)
+            cls.instance.freeze()
+        elif mode == "multiprocessing":
+            cls.instance = MemCacheHandlerForMP(mem_size)
+        elif mode == "singleprocessing":
+            cls.instance = MemCacheHandlerForSP(mem_size)
+        else:
+            raise MemCacheHandlerError(f"{mode} is an unknown mode.")
+
+        return cls.instance
+
+    @classmethod
+    def delete(cls) -> None:
+        """Delete the existing MemCacheHandlerBase instance."""
+        if hasattr(cls, "instance"):
+            del cls.instance
diff --git a/otx/core/data/caching/mem_cache_hook.py b/otx/core/data/caching/mem_cache_hook.py
new file mode 100644
index 00000000000..ecd48fa840a
--- /dev/null
+++ b/otx/core/data/caching/mem_cache_hook.py
@@ -0,0 +1,33 @@
+"""Memory cache hook for logging and freezing MemCacheHandler."""
+# Copyright (C) 2023 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+#
+
+from mmcv.runner.hooks import HOOKS, Hook
+
+from .mem_cache_handler import MemCacheHandlerSingleton
+
+
+@HOOKS.register_module()
+class MemCacheHook(Hook):
+    """Memory cache hook for logging and freezing MemCacheHandler."""
+
+    def __init__(self) -> None:
+        self.handler = MemCacheHandlerSingleton.get()
+        # The first evaluation comes at the very beginning of the training,
+        # and we don't want to cache validation samples first.
+        self.handler.freeze()
+
+    def before_epoch(self, runner):
+        """Before each epoch, unfreeze the handler."""
+        # We want to cache training samples first.
+        self.handler.unfreeze()
+
+    def after_epoch(self, runner):
+        """After each epoch, freeze the handler and log its statistics.
+
+        So that caching of the validation samples is not skipped,
+        this hook should have a lower priority than CustomEvalHook.
+        """
+        self.handler.freeze()
+        runner.logger.info(f"{self.handler}")
diff --git a/otx/core/data/pipelines/__init__.py b/otx/core/data/pipelines/__init__.py
new file mode 100644
index 00000000000..699c2577892
--- /dev/null
+++ b/otx/core/data/pipelines/__init__.py
@@ -0,0 +1,3 @@
+"""OTX data pipelines."""
+# Copyright (C) 2023 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
diff --git a/otx/core/data/pipelines/load_image_from_otx_dataset.py b/otx/core/data/pipelines/load_image_from_otx_dataset.py
new file mode 100644
index 00000000000..100c24c7e25
--- /dev/null
+++ b/otx/core/data/pipelines/load_image_from_otx_dataset.py
@@ -0,0 +1,89 @@
+"""Pipeline element that loads an image from a OTX Dataset on the fly."""
+# Copyright (C) 2023 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+from tempfile import TemporaryDirectory
+from typing import Any, Dict, Tuple
+
+import numpy as np
+
+from otx.algorithms.common.utils.data import get_image
+from otx.api.utils.argument_checks import check_input_parameters_type
+
+from ..caching import MemCacheHandlerError, MemCacheHandlerSingleton
+
+_CACHE_DIR = TemporaryDirectory(prefix="img-cache-")  # pylint: disable=consider-using-with
+
+# TODO: refactoring to common modules
+# TODO: refactoring to Sphinx style.
+
+
+class LoadImageFromOTXDataset:
+    """Pipeline element that loads an image from a OTX Dataset on the fly.
+
+    Can do conversion to float 32 if needed.
+    Expected entries in the 'results' dict that should be passed to this pipeline element are:
+        results['dataset_item']: dataset_item from which to load the image
+        results['dataset_id']: id of the dataset to which the item belongs
+        results['index']: index of the item in the dataset
+
+    :param to_float32: optional bool, True to convert images to fp32. defaults to False
+    """
+
+    @check_input_parameters_type()
+    def __init__(self, to_float32: bool = False):
+        self.to_float32 = to_float32
+        try:
+            self.mem_cache_handler = MemCacheHandlerSingleton.get()
+        except MemCacheHandlerError:
+            # Create a null handler
+            MemCacheHandlerSingleton.create(mode="null", mem_size=0)
+            self.mem_cache_handler = MemCacheHandlerSingleton.get()
+
+    @staticmethod
+    def _get_unique_key(results: Dict[str, Any]) -> Tuple:
+        # TODO: We should improve this by assigning a unique id to DatasetItemEntity.
+        # This is needed because there is a case in which
+        # d_item.media.path is None, but d_item.media.data is not None
+        d_item = results["dataset_item"]
+        return d_item.media.path, d_item.roi.id
+
+    @check_input_parameters_type()
+    def __call__(self, results: Dict[str, Any]):
+        """Callback function of LoadImageFromOTXDataset."""
+        key = self._get_unique_key(results)
+
+        img = self.mem_cache_handler.get(key)
+
+        if img is None:
+            # Get image (possibly from the file cache on disk)
+            img = get_image(results, _CACHE_DIR.name, to_float32=False)
+            self.mem_cache_handler.put(key, img)
+
+        if self.to_float32:
+            img = img.astype(np.float32)
+        shape = img.shape
+
+        if img.shape[0] != results["height"]:
+            results["height"] = img.shape[0]
+
+        if img.shape[1] != results["width"]:
+            results["width"] = img.shape[1]
+
+        filename = f"Dataset item index {results['index']}"
+        results["filename"] = filename
+        results["ori_filename"] = filename
+        results["img"] = img
+        results["img_shape"] = shape
+        results["ori_shape"] = shape
+        # Set initial values for default meta_keys
+        results["pad_shape"] = shape
+        num_channels = 1 if len(shape) < 3 else shape[2]
+        results["img_norm_cfg"] = dict(
+            mean=np.zeros(num_channels, dtype=np.float32),
+            std=np.ones(num_channels, dtype=np.float32),
+            to_rgb=False,
+        )
+        results["img_fields"] = ["img"]
+
+        return results
diff --git a/otx/mpa/modules/hooks/eval_hook.py b/otx/mpa/modules/hooks/eval_hook.py
index 920c5839a9b..eed4e743996 100644
--- a/otx/mpa/modules/hooks/eval_hook.py
+++ b/otx/mpa/modules/hooks/eval_hook.py
@@ -96,6 +96,7 @@ def single_gpu_test(model, data_loader):
         batch_size = data["img"].size(0)
         for _ in range(batch_size):
             prog_bar.update()
+    prog_bar.file.write("\n")
 
     return results
diff --git a/tests/unit/cli/utils/test_parser.py b/tests/unit/cli/utils/test_parser.py
index 8ad34f12b23..ece49aed58b 100644
--- a/tests/unit/cli/utils/test_parser.py
+++ b/tests/unit/cli/utils/test_parser.py
@@ -1,3 +1,7 @@
+# Copyright (C) 2023 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+#
+
 from argparse import ArgumentParser, ArgumentTypeError
 from pathlib import Path
 
@@ -5,6 +9,7 @@
 from otx.cli.utils import parser as target_package
 from otx.cli.utils.parser import (
+    MemSizeAction,
     add_hyper_parameters_sub_parser,
     gen_param_help,
     gen_params_dict_from_args,
@@ -54,7 +59,6 @@
 
 @e2e_pytest_unit
 def test_gen_param_help():
-
     param_help = gen_param_help(FAKE_HYPER_PARAMETERS)
 
     hp_type_map = {
@@ -78,6 +82,7 @@ def mock_args(mocker):
     setattr(mock_args, "params.a.c", True)
     setattr(mock_args, "params.b", "fake")
     setattr(mock_args, "params.c", 10)
+    setattr(mock_args, "params.d", None)
 
     return mock_args
 
@@ -93,6 +98,7 @@ def test_gen_params_dict_from_args(mock_args):
     assert param_dict["a"]["c"]["value"] is True
     assert param_dict["b"]["value"] == "fake"
     assert param_dict["c"]["value"] == 10
+    assert "d" not in param_dict
 
 
 @e2e_pytest_unit
@@ -206,3 +212,42 @@ def test_get_parser_and_hprams_data(mocker):
     assert hyper_parameters == {}
     assert params == ["params", "--left-args"]
     assert isinstance(parser, ArgumentParser)
+
+
+@pytest.fixture
+def fxt_argparse():
+    parser = ArgumentParser()
+    parser.add_argument(
+        "--mem-cache-size",
+        dest="params.algo_backend.mem_cache_size",
+        action=MemSizeAction,
+        type=str,
+        required=False,
+        default=0,
+    )
+    return parser
+
+
+@pytest.mark.parametrize(
+    "mem_size_arg,expected",
+    [
+        ("1561", 1561),
+        ("121k", 121 * (2**10)),
+        ("121kb", 121 * (2**10)),
+        ("121kib", 121 * (10**3)),
+        ("121m", 121 * (2**20)),
+        ("121mb", 121 * (2**20)),
+        ("121mib", 121 * (10**6)),
+        ("121g", 121 * (2**30)),
+        ("121gb", 121 * (2**30)),
+        ("121gib", 121 * (10**9)),
+        ("121as", None),
+        ("121dddd", None),
+    ],
+)
+def test_mem_size_parsing(fxt_argparse, mem_size_arg, expected):
+    try:
+        args = fxt_argparse.parse_args(["--mem-cache-size", mem_size_arg])
+        assert getattr(args, "params.algo_backend.mem_cache_size") == expected
+    except ValueError:
+        assert expected is None
diff --git a/tests/unit/core/__init__.py b/tests/unit/core/__init__.py
new file mode 100644
index 00000000000..9c68be83ef0
--- /dev/null
+++ b/tests/unit/core/__init__.py
@@ -0,0 +1,3 @@
+# Copyright (C) 2023 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+#
diff --git a/tests/unit/core/data/__init__.py b/tests/unit/core/data/__init__.py
index 2faffbe2b1f..9c68be83ef0 100644
--- a/tests/unit/core/data/__init__.py
+++ b/tests/unit/core/data/__init__.py
@@ -1,13 +1,3 @@
 # Copyright (C) 2023 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
 #
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#      http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions
-# and limitations under the License.
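
Note that the unit table in `MemSizeAction._parse_mem_size_str` is unconventional: `KB`/`MB`/`GB` (and bare `K`/`M`/`G`) map to binary multiples (2^10, 2^20, 2^30), while `KIB`/`MIB`/`GIB` map to decimal multiples (10^3, 10^6, 10^9) — the reverse of the usual SI/IEC convention — and the parametrized cases in `test_mem_size_parsing` above pin that behavior down. A minimal sketch of driving the action outside the OTX CLI (only the argument wiring mirrored from `otx/cli/tools/train.py` is assumed; the standalone `ArgumentParser` here is illustrative, not the real CLI parser):

```python
from argparse import ArgumentParser

from otx.cli.utils.parser import MemSizeAction

parser = ArgumentParser()
parser.add_argument(
    "--mem-cache-size",
    action=MemSizeAction,
    dest="params.algo_backend.mem_cache_size",  # MemSizeAction rejects any other dest
    type=str,
)

args = parser.parse_args(["--mem-cache-size", "512MB"])
# "512MB" -> 512 * 2**20 = 536870912 bytes
assert getattr(args, "params.algo_backend.mem_cache_size") == 512 * 2**20
```
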
diff --git a/tests/unit/core/data/test_caching.py b/tests/unit/core/data/test_caching.py
new file mode 100644
index 00000000000..fa7ae67dd2c
--- /dev/null
+++ b/tests/unit/core/data/test_caching.py
@@ -0,0 +1,145 @@
+# Copyright (C) 2023 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+#
+
+import string
+from unittest.mock import patch
+
+import numpy as np
+import pytest
+from torch.utils.data import DataLoader, Dataset
+
+from otx.api.entities.annotation import AnnotationSceneEntity, AnnotationSceneKind
+from otx.api.entities.dataset_item import DatasetItemEntity
+from otx.api.entities.image import Image
+from otx.core.data.caching import MemCacheHandlerSingleton
+from otx.core.data.pipelines.load_image_from_otx_dataset import LoadImageFromOTXDataset
+
+
+@pytest.fixture
+def fxt_data_list():
+    np.random.seed(3003)
+
+    num_data = 10
+    h = w = key_len = 16
+
+    data_list = []
+    for _ in range(num_data):
+        data = np.random.randint(0, 256, size=[h, w, 3], dtype=np.uint8)
+        key = "".join(
+            [string.ascii_lowercase[i] for i in np.random.randint(0, len(string.ascii_lowercase), size=[key_len])]
+        )
+        data_list += [(key, data)]
+
+    return data_list
+
+
+@pytest.fixture
+def fxt_caching_dataset_cls(fxt_data_list: list):
+    class CachingDataset(Dataset):
+        def __init__(self) -> None:
+            super().__init__()
+            self.d_items = [
+                DatasetItemEntity(
+                    media=Image(data=data),
+                    annotation_scene=AnnotationSceneEntity(annotations=[], kind=AnnotationSceneKind.ANNOTATION),
+                )
+                for _, data in fxt_data_list
+            ]
+            self.load = LoadImageFromOTXDataset()
+
+        def __len__(self):
+            return len(self.d_items)
+
+        def __getitem__(self, index):
+            d_item = self.d_items[index]
+
+            results = {
+                "dataset_item": d_item,
+                "height": d_item.media.numpy.shape[0],
+                "width": d_item.media.numpy.shape[1],
+                "index": index,
+            }
+
+            results = self.load(results)
+            return results["img"]
+
+    yield CachingDataset
+
+
+def get_data_list_size(data_list):
+    size = 0
+    for _, data in data_list:
+        size += data.size
+    return size
+
+
+class TestMemCacheHandler:
+    @pytest.mark.parametrize("mode", ["singleprocessing", "multiprocessing"])
+    def test_fully_caching(self, mode, fxt_data_list):
+        mem_size = get_data_list_size(fxt_data_list)
+        MemCacheHandlerSingleton.create(mode, mem_size)
+        handler = MemCacheHandlerSingleton.get()
+
+        for key, data in fxt_data_list:
+            assert handler.put(key, data) > 0
+
+        for key, data in fxt_data_list:
+            get_data = handler.get(key)
+
+            assert np.array_equal(get_data, data)
+
+        # Fully cached
+        assert len(handler) == len(fxt_data_list)
+
+    @pytest.mark.parametrize("mode", ["singleprocessing", "multiprocessing"])
+    def test_partially_caching(self, mode, fxt_data_list):
+        mem_size = get_data_list_size(fxt_data_list) // 2
+        MemCacheHandlerSingleton.create(mode, mem_size)
+        handler = MemCacheHandlerSingleton.get()
+
+        for idx, (key, data) in enumerate(fxt_data_list):
+            if idx < len(fxt_data_list) // 2:
+                assert handler.put(key, data) > 0
+            else:
+                assert handler.put(key, data) is None
+
+        for idx, (key, data) in enumerate(fxt_data_list):
+            get_data = handler.get(key)
+
+            if idx < len(fxt_data_list) // 2:
+                assert np.array_equal(get_data, data)
+            else:
+                assert get_data is None
+
+        # Partially (half) cached
+        assert len(handler) == len(fxt_data_list) // 2
+
+
+class TestLoadImageFromFileWithCache:
+    @pytest.mark.parametrize("mode", ["singleprocessing", "multiprocessing"])
+    def test_combine_with_dataloader(self, mode, fxt_caching_dataset_cls, fxt_data_list):
+        mem_size = get_data_list_size(fxt_data_list)
+        MemCacheHandlerSingleton.create(mode, mem_size)
+
+        dataset = fxt_caching_dataset_cls()
+
+        with patch(
+            "otx.core.data.pipelines.load_image_from_otx_dataset.get_image",
+            side_effect=[data for _, data in fxt_data_list],
+        ) as mock:
+            for _ in DataLoader(dataset):
+                continue
+
+            # This initial round requires all data samples to be read from disk.
+            assert mock.call_count == len(dataset)
+
+        with patch(
+            "otx.core.data.pipelines.load_image_from_otx_dataset.get_image",
+            side_effect=[data for _, data in fxt_data_list],
+        ) as mock:
+            for _ in DataLoader(dataset):
+                continue
+
+            # The second round requires no read.
+            assert mock.call_count == 0
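
Taken together, the pieces above wire `--mem-cache-size` from the CLI through `BaseTask._update_caching_modules()` into the `LoadImageFromOTXDataset` pipeline elements. A minimal sketch of the handler lifecycle these tests exercise, in isolation (the 16 MiB pool size and the string keys are illustrative; in the pipeline, the key is derived from the `(media path, roi id)` pair and the mode is chosen from the maximum `workers_per_gpu` found in the data config):

```python
import numpy as np

from otx.core.data.caching import MemCacheHandlerSingleton

# "singleprocessing" corresponds to DataLoader.num_workers == 0.
MemCacheHandlerSingleton.create(mode="singleprocessing", mem_size=16 * 2**20)
handler = MemCacheHandlerSingleton.get()

img = np.zeros((32, 32, 3), dtype=np.uint8)  # put() only accepts uint8 arrays
key = ("img-0001.jpg", None)

assert handler.put(key, img) > 0              # returns the new page offset on success
assert np.array_equal(handler.get(key), img)  # strided view into the pool, not a copy

handler.freeze()                              # MemCacheHook freezes outside training epochs
assert handler.put(("img-0002.jpg", None), img) is None  # writes rejected while frozen
```

One caveat worth noting: since `get()` returns a strided view into the pool rather than a copy, an in-place edit of the returned array would corrupt the cached entry; the `astype(np.float32)` call in `LoadImageFromOTXDataset` does copy, but only when `to_float32` is set.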