Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Enhance] Add in-memory caching in dataloader #1694

Merged
merged 22 commits into from
Mar 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
7a20152
Add MemCacheHandler
vinnamkim Feb 13, 2023
d78f00c
Attach MemCacheHandler to otx.train
vinnamkim Feb 14, 2023
bc367ab
Refactor and fix unit tests
vinnamkim Feb 14, 2023
06d3eaa
Update configuration.yaml for other tasks
vinnamkim Feb 14, 2023
93cb940
Update QUICK_START_GUIDE.md
vinnamkim Feb 14, 2023
ce1515f
Merge remote-tracking branch 'origin/develop' into vinnamki/add-mem-c…
vinnamkim Feb 14, 2023
ff59879
Fix comments
vinnamkim Feb 14, 2023
7cd7166
Rollback DatasetItemEntity.roi
vinnamkim Feb 14, 2023
90f8c99
Fix DatasetItemEntity.roi getter and setter
vinnamkim Feb 14, 2023
b153531
Add yaml recipes to package_data
vinnamkim Feb 15, 2023
6116934
Fix Codacy error
vinnamkim Feb 15, 2023
837b1c4
Change find path from otx/recipes to otx
vinnamkim Feb 15, 2023
e4d4331
Merge branch 'vinnamki/fix-add-yaml-recipe' into vinnamki/add-mem-cac…
vinnamkim Feb 15, 2023
325af51
Merge remote-tracking branch 'origin/develop' into vinnamki/add-mem-c…
vinnamkim Feb 15, 2023
30c9f0d
Fix MemSizeAction to follow the existing hyperparams parsing rule
vinnamkim Feb 16, 2023
ca7681f
Fix for anomaly and action
vinnamkim Feb 16, 2023
984b7dd
Merge remote-tracking branch 'origin/develop' into vinnamki/add-mem-c…
vinnamkim Feb 16, 2023
cf89189
Fix unit test
vinnamkim Feb 16, 2023
0df913d
Merge remote-tracking branch 'origin/develop' into vinnamki/add-mem-c…
vinnamkim Feb 17, 2023
210bac2
Merge remote-tracking branch 'origin/develop' into vinnamki/add-mem-c…
vinnamkim Feb 17, 2023
43f32a4
Merge remote-tracking branch 'origin/develop' into vinnamki/add-mem-c…
vinnamkim Mar 6, 2023
8549684
Clean up some code
vinnamkim Mar 6, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions otx/algorithms/action/configs/classification/configuration.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,22 @@ algo_backend:
value: INCREMENTAL
visible_in_ui: True
warning: null
mem_cache_size:
affects_outcome_of: TRAINING
default_value: 0
description: Size of memory pool for caching decoded data to load data faster (bytes).
editable: true
header: Size of memory pool
max_value: 9223372036854775807
min_value: 0
type: INTEGER
ui_rules:
action: DISABLE_EDITING
operator: AND
rules: []
type: UI_RULES
visible_in_ui: false
warning: null
type: PARAMETER_GROUP
visible_in_ui: true
type: CONFIGURABLE_PARAMETERS
Expand Down
16 changes: 16 additions & 0 deletions otx/algorithms/action/configs/detection/configuration.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,22 @@ algo_backend:
value: INCREMENTAL
visible_in_ui: True
warning: null
mem_cache_size:
affects_outcome_of: TRAINING
default_value: 0
description: Size of memory pool for caching decoded data to load data faster (bytes).
editable: true
header: Size of memory pool
max_value: 9223372036854775807
min_value: 0
type: INTEGER
ui_rules:
action: DISABLE_EDITING
operator: AND
rules: []
type: UI_RULES
visible_in_ui: false
warning: null
type: PARAMETER_GROUP
visible_in_ui: true
type: CONFIGURABLE_PARAMETERS
Expand Down
50 changes: 3 additions & 47 deletions otx/algorithms/classification/adapters/mmcls/data/pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
# Copyright (C) 2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import copy
import tempfile
from typing import Any, Dict, List

import numpy as np
Expand All @@ -13,59 +12,16 @@
from PIL import Image, ImageFilter
from torchvision import transforms as T

from otx.algorithms.common.utils.data import get_image
import otx.core.data.pipelines.load_image_from_otx_dataset as load_image_base
from otx.api.utils.argument_checks import check_input_parameters_type

_CACHE_DIR = tempfile.TemporaryDirectory(prefix="img-cache-") # pylint: disable=consider-using-with

# TODO: refactoring to common modules
# TODO: refactoring to Sphinx style.


@PIPELINES.register_module()
class LoadImageFromOTXDataset:
"""Pipeline element that loads an image from a OTX Dataset on the fly.

Can do conversion to float 32 if needed.
Expected entries in the 'results' dict that should be passed to this pipeline element are:
results['dataset_item']: dataset_item from which to load the image
results['dataset_id']: id of the dataset to which the item belongs
results['index']: index of the item in the dataset

:param to_float32: optional bool, True to convert images to fp32. defaults to False
"""

@check_input_parameters_type()
def __init__(self, to_float32: bool = False):
self.to_float32 = to_float32

@check_input_parameters_type()
def __call__(self, results: Dict[str, Any]):
"""Callback function of LoadImageFromOTXDataset."""
# Get image (possibly from cache)
img = get_image(results, _CACHE_DIR.name, to_float32=self.to_float32)
shape = img.shape

assert img.shape[0] == results["height"], f"{img.shape[0]} != {results['height']}"
assert img.shape[1] == results["width"], f"{img.shape[1]} != {results['width']}"

filename = f"Dataset item index {results['index']}"
results["filename"] = filename
results["ori_filename"] = filename
results["img"] = img
results["img_shape"] = shape
results["ori_shape"] = shape
# Set initial values for default meta_keys
results["pad_shape"] = shape
num_channels = 1 if len(shape) < 3 else shape[2]
results["img_norm_cfg"] = dict(
mean=np.zeros(num_channels, dtype=np.float32),
std=np.ones(num_channels, dtype=np.float32),
to_rgb=False,
)
results["img_fields"] = ["img"]

return results
class LoadImageFromOTXDataset(load_image_base.LoadImageFromOTXDataset):
    """Pipeline element that loads an image from an OTX Dataset on the fly.

    Thin subclass: all loading behavior (including the shared in-memory
    caching) is inherited from
    ``otx.core.data.pipelines.load_image_from_otx_dataset.LoadImageFromOTXDataset``.
    This class exists only so the loader can be registered with the mmcls
    ``PIPELINES`` registry under the same name.
    """


@PIPELINES.register_module()
Expand Down
16 changes: 16 additions & 0 deletions otx/algorithms/classification/configs/configuration.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -354,5 +354,21 @@ algo_backend:
value: INCREMENTAL
visible_in_ui: True
warning: null
mem_cache_size:
affects_outcome_of: TRAINING
default_value: 0
description: Size of memory pool for caching decoded data to load data faster (bytes).
editable: true
header: Size of memory pool
max_value: 9223372036854775807
min_value: 0
type: INTEGER
ui_rules:
action: DISABLE_EDITING
operator: AND
rules: []
type: UI_RULES
visible_in_ui: false
warning: null
type: PARAMETER_GROUP
visible_in_ui: true
10 changes: 10 additions & 0 deletions otx/algorithms/common/configs/training_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,16 @@ class BaseAlgoBackendParameters(ParameterGroup):
visible_in_ui=True,
)

# Size (in bytes) of the memory pool used to cache decoded data so samples
# load faster. Default 0 — matching the shipped configuration.yaml files —
# which presumably disables caching (TODO confirm against the cache handler).
# Hidden from the UI (visible_in_ui=False); changing it affects TRAINING
# outcome per ModelLifecycle. `maxsize` is presumably sys.maxsize — it matches
# the 9223372036854775807 max_value in the YAML configs; verify the import.
mem_cache_size = configurable_integer(
    header="Size of memory pool for caching decoded data to load data faster",
    description="Size of memory pool for caching decoded data to load data faster",
    default_value=0,
    min_value=0,
    max_value=maxsize,
    visible_in_ui=False,
    affects_outcome_of=ModelLifecycle.TRAINING,
)

@attrs
class BaseTilingParameters(ParameterGroup):
"""BaseTilingParameters for OTX Algorithms."""
Expand Down
33 changes: 32 additions & 1 deletion otx/algorithms/common/tasks/training_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
from otx.api.usecases.tasks.interfaces.inference_interface import IInferenceTask
from otx.api.usecases.tasks.interfaces.unload_interface import IUnload
from otx.api.utils.argument_checks import check_input_parameters_type
from otx.core.data import caching
from otx.mpa.builder import build
from otx.mpa.modules.hooks.cancel_interface_hook import CancelInterfaceHook
from otx.mpa.stage import Stage
Expand Down Expand Up @@ -324,6 +325,9 @@ def _initialize(self, options=None): # noqa: C901
dataloader_cfg["persistent_workers"] = False
data_cfg[f"{subset}_dataloader"] = dataloader_cfg

# Update recipe with caching modules
self._update_caching_modules(data_cfg)

if self._data_cfg is not None:
align_data_config_with_recipe(self._data_cfg, self._recipe_cfg)

Expand Down Expand Up @@ -412,7 +416,6 @@ def _init_deploy_cfg(self) -> Union[Config, None]:
deploy_cfg = MPAConfig.fromfile(deploy_cfg_path)

def patch_input_preprocessing(deploy_cfg):

normalize_cfg = get_configs_by_pairs(
self._recipe_cfg.data.test.pipeline,
dict(type="Normalize"),
Expand Down Expand Up @@ -620,3 +623,31 @@ def set_early_stopping_hook(self):
update_or_add_custom_hook(self._recipe_cfg, early_stop_hook)
else:
remove_custom_hook(self._recipe_cfg, "LazyEarlyStoppingHook")

def _update_caching_modules(self, data_cfg: Config) -> None:
    """Set up the in-memory cache and attach the MemCacheHook to the recipe.

    The cache mode is chosen from the dataloader configuration: if any
    ``workers_per_gpu`` entry in *data_cfg* is positive, workers run in
    separate processes and the cache must be multiprocessing-safe.

    :param data_cfg: data configuration to scan for ``workers_per_gpu``.
    """

    def _max_workers_in(cfg: dict) -> int:
        # Recursively walk nested dicts, keeping the largest
        # "workers_per_gpu" integer found (0 when none is present).
        best = 0
        for key, value in cfg.items():
            if key == "workers_per_gpu" and isinstance(value, int):
                best = max(best, value)
            elif isinstance(value, dict):
                best = max(best, _max_workers_in(value))
        return best

    # Older hyperparameter schemas may lack the field; fall back to 0.
    mem_cache_size = getattr(self.hyperparams.algo_backend, "mem_cache_size", 0)

    mode = "singleprocessing" if _max_workers_in(data_cfg) == 0 else "multiprocessing"
    caching.MemCacheHandlerSingleton.create(mode, mem_cache_size)

    update_or_add_custom_hook(
        self._recipe_cfg,
        ConfigDict(type="MemCacheHook", priority="VERY_LOW"),
    )
50 changes: 3 additions & 47 deletions otx/algorithms/detection/adapters/mmdet/data/pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,65 +13,21 @@
# See the License for the specific language governing permissions
# and limitations under the License.
import copy
import tempfile
from typing import Any, Dict, Optional

import numpy as np
from mmdet.datasets.builder import PIPELINES

from otx.algorithms.common.utils.data import get_image
import otx.core.data.pipelines.load_image_from_otx_dataset as load_image_base
from otx.api.entities.label import Domain
from otx.api.utils.argument_checks import check_input_parameters_type

from .dataset import get_annotation_mmdet_format

_CACHE_DIR = tempfile.TemporaryDirectory(prefix="img-cache-") # pylint: disable=consider-using-with


# pylint: disable=too-many-instance-attributes, too-many-arguments
@PIPELINES.register_module()
class LoadImageFromOTXDataset:
"""Pipeline element that loads an image from a OTX Dataset on the fly. Can do conversion to float 32 if needed.

Expected entries in the 'results' dict that should be passed to this pipeline element are:
results['dataset_item']: dataset_item from which to load the image
results['dataset_id']: id of the dataset to which the item belongs
results['index']: index of the item in the dataset

:param to_float32: optional bool, True to convert images to fp32. defaults to False
"""

@check_input_parameters_type()
def __init__(self, to_float32: bool = False):
self.to_float32 = to_float32

@check_input_parameters_type()
def __call__(self, results: Dict[str, Any]):
"""Callback function LoadImageFromOTXDataset."""
# Get image (possibly from cache)
img = get_image(results, _CACHE_DIR.name, to_float32=self.to_float32)
shape = img.shape

assert img.shape[0] == results["height"], f"{img.shape[0]} != {results['height']}"
assert img.shape[1] == results["width"], f"{img.shape[1]} != {results['width']}"

filename = f"Dataset item index {results['index']}"
results["filename"] = filename
results["ori_filename"] = filename
results["img"] = img
results["img_shape"] = shape
results["ori_shape"] = shape
# Set initial values for default meta_keys
results["pad_shape"] = shape
num_channels = 1 if len(shape) < 3 else shape[2]
results["img_norm_cfg"] = dict(
mean=np.zeros(num_channels, dtype=np.float32),
std=np.ones(num_channels, dtype=np.float32),
to_rgb=False,
)
results["img_fields"] = ["img"]

return results
class LoadImageFromOTXDataset(load_image_base.LoadImageFromOTXDataset):
    """Pipeline element that loads an image from an OTX Dataset on the fly.

    Thin subclass: all loading behavior (including the shared in-memory
    caching) is inherited from
    ``otx.core.data.pipelines.load_image_from_otx_dataset.LoadImageFromOTXDataset``.
    This class exists only so the loader can be registered with the mmdet
    ``PIPELINES`` registry under the same name.
    """


@PIPELINES.register_module()
Expand Down
16 changes: 16 additions & 0 deletions otx/algorithms/detection/configs/detection/configuration.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,22 @@ algo_backend:
value: INCREMENTAL
visible_in_ui: True
warning: null
mem_cache_size:
affects_outcome_of: TRAINING
default_value: 0
description: Size of memory pool for caching decoded data to load data faster (bytes).
editable: true
header: Size of memory pool
max_value: 9223372036854775807
min_value: 0
type: INTEGER
ui_rules:
action: DISABLE_EDITING
operator: AND
rules: []
type: UI_RULES
visible_in_ui: false
warning: null
type: PARAMETER_GROUP
visible_in_ui: true
type: CONFIGURABLE_PARAMETERS
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,22 @@ algo_backend:
value: INCREMENTAL
visible_in_ui: True
warning: null
mem_cache_size:
affects_outcome_of: TRAINING
default_value: 0
description: Size of memory pool for caching decoded data to load data faster (bytes).
editable: true
header: Size of memory pool
max_value: 9223372036854775807
min_value: 0
type: INTEGER
ui_rules:
action: DISABLE_EDITING
operator: AND
rules: []
type: UI_RULES
visible_in_ui: false
warning: null
type: PARAMETER_GROUP
visible_in_ui: true
type: CONFIGURABLE_PARAMETERS
Expand Down
48 changes: 4 additions & 44 deletions otx/algorithms/segmentation/adapters/mmseg/data/pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions
# and limitations under the License.
import tempfile
from copy import deepcopy
from typing import Any, Dict, List

Expand All @@ -24,55 +23,16 @@
from torchvision import transforms as T
from torchvision.transforms import functional as F

from otx.algorithms.common.utils.data import get_image
import otx.core.data.pipelines.load_image_from_otx_dataset as load_image_base
from otx.api.utils.argument_checks import check_input_parameters_type

from .dataset import get_annotation_mmseg_format

_CACHE_DIR = tempfile.TemporaryDirectory(prefix="img-cache-") # pylint: disable=consider-using-with


# pylint: disable=too-many-instance-attributes, too-many-arguments
@PIPELINES.register_module()
class LoadImageFromOTXDataset:
"""Pipeline element that loads an image from a OTX Dataset on the fly. Can do conversion to float 32 if needed.

Expected entries in the 'results' dict that should be passed to this pipeline element are:
results['dataset_item']: dataset_item from which to load the image
results['dataset_id']: id of the dataset to which the item belongs
results['index']: index of the item in the dataset

:param to_float32: optional bool, True to convert images to fp32. defaults to False
"""

@check_input_parameters_type()
def __init__(self, to_float32: bool = False):
self.to_float32 = to_float32

@check_input_parameters_type()
def __call__(self, results: Dict[str, Any]):
"""Callback function LoadImageFromOTXDataset."""
# Get image (possibly from cache)
img = get_image(results, _CACHE_DIR.name, to_float32=self.to_float32)
shape = img.shape

assert img.shape[0] == results["height"], f"{img.shape[0]} != {results['height']}"
assert img.shape[1] == results["width"], f"{img.shape[1]} != {results['width']}"

filename = f"Dataset item index {results['index']}"
results["filename"] = filename
results["ori_filename"] = filename
results["img"] = img
results["img_shape"] = shape
results["ori_shape"] = shape
# Set initial values for default meta_keys
results["pad_shape"] = shape
num_channels = 1 if len(shape) < 3 else shape[2]
results["img_norm_cfg"] = dict(
mean=np.zeros(num_channels, dtype=np.float32), std=np.ones(num_channels, dtype=np.float32), to_rgb=False
)
results["img_fields"] = ["img"]

return results
class LoadImageFromOTXDataset(load_image_base.LoadImageFromOTXDataset):
    """Pipeline element that loads an image from an OTX Dataset on the fly.

    Thin subclass: all loading behavior (including the shared in-memory
    caching) is inherited from
    ``otx.core.data.pipelines.load_image_from_otx_dataset.LoadImageFromOTXDataset``.
    This class exists only so the loader can be registered with the mmseg
    ``PIPELINES`` registry under the same name.
    """


@PIPELINES.register_module()
Expand Down
Loading