From b6a512c036101524b716d615f9e394baa4293bcd Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Wed, 20 Dec 2023 06:14:03 +0000 Subject: [PATCH] change Normalize, Log1p, CropData, SqueezeData to inplace transform to avoid copying data for efficiency --- ppsci/data/process/transform/preprocess.py | 56 +++++++++++----------- 1 file changed, 28 insertions(+), 28 deletions(-) diff --git a/ppsci/data/process/transform/preprocess.py b/ppsci/data/process/transform/preprocess.py index 5552dfe6f..25986cc29 100644 --- a/ppsci/data/process/transform/preprocess.py +++ b/ppsci/data/process/transform/preprocess.py @@ -71,6 +71,8 @@ def __call__(self, input_dict, label_dict, weight_dict): class Normalize: """Normalize data class. + NOTE: This transform will modify the input data dict inplace. + Args: mean (Union[np.ndarray, Tuple[float, ...]]): Mean of training dataset. std (Union[np.ndarray, Tuple[float, ...]]): Standard Deviation of training dataset. @@ -96,20 +98,20 @@ def __init__( self.apply_keys = apply_keys def __call__(self, input_item, label_item, weight_item): - input_item_copy = {**input_item} - label_item_copy = {**label_item} if "input" in self.apply_keys: - for key, value in input_item_copy.items(): - input_item_copy[key] = (value - self.mean) / self.std + for key, value in input_item.items(): + input_item[key] = (value - self.mean) / self.std if "label" in self.apply_keys: - for key, value in label_item_copy.items(): - label_item_copy[key] = (value - self.mean) / self.std - return input_item_copy, label_item_copy, weight_item + for key, value in label_item.items(): + label_item[key] = (value - self.mean) / self.std + return input_item, label_item, weight_item class Log1p: """Calculates the natural logarithm of one plus the data, element-wise. + NOTE: This transform will modify the input data dict inplace. + Args: scale (float, optional): Scale data. Defaults to 1.0. apply_keys (Tuple[str, ...], optional): Which data is the log1p method applied to. Defaults to ("input", "label"). @@ -132,15 +134,13 @@ def __init__( self.apply_keys = apply_keys def __call__(self, input_item, label_item, weight_item): - input_item_copy = {**input_item} - label_item_copy = {**label_item} if "input" in self.apply_keys: - for key, value in input_item_copy.items(): - input_item_copy[key] = np.log1p(value / self.scale) + for key, value in input_item.items(): + input_item[key] = np.log1p(value / self.scale) if "label" in self.apply_keys: - for key, value in label_item_copy.items(): - label_item_copy[key] = np.log1p(value / self.scale) - return input_item_copy, label_item_copy, weight_item + for key, value in label_item.items(): + label_item[key] = np.log1p(value / self.scale) + return input_item, label_item, weight_item class CropData: @@ -148,6 +148,8 @@ class CropData: This class is used to crop data based on a specified bounding box. + NOTE: This transform will modify the input data dict inplace. + Args: xmin (Tuple[int, ...]): Bottom left corner point, [x0, y0]. xmax (Tuple[int, ...]): Top right corner point, [x1, y1]. @@ -182,24 +184,24 @@ def __init__( self.apply_keys = apply_keys def __call__(self, input_item, label_item, weight_item): - input_item_copy = {**input_item} - label_item_copy = {**label_item} if "input" in self.apply_keys: - for key, value in input_item_copy.items(): - input_item_copy[key] = value[ + for key, value in input_item.items(): + input_item[key] = value[ :, self.xmin[0] : self.xmax[0], self.xmin[1] : self.xmax[1] ] if "label" in self.apply_keys: - for key, value in label_item_copy.items(): - label_item_copy[key] = value[ + for key, value in label_item.items(): + label_item[key] = value[ :, self.xmin[0] : self.xmax[0], self.xmin[1] : self.xmax[1] ] - return input_item_copy, label_item_copy, weight_item + return input_item, label_item, weight_item class SqueezeData: """Squeeze data class. + NOTE: This transform will modify the input data dict inplace. + Args: apply_keys (Tuple[str, ...], optional): Which data is the squeeze method applied to. Defaults to ("input", "label"). @@ -216,27 +218,25 @@ def __init__(self, apply_keys: Tuple[str, ...] = ("input", "label")): self.apply_keys = apply_keys def __call__(self, input_item, label_item, weight_item): - input_item_copy = {**input_item} - label_item_copy = {**label_item} if "input" in self.apply_keys: - for key, value in input_item_copy.items(): + for key, value in input_item.items(): if value.ndim == 4: B, C, H, W = value.shape - input_item_copy[key] = value.reshape((B * C, H, W)) + input_item[key] = value.reshape((B * C, H, W)) if value.ndim != 3: raise ValueError( f"Only support squeeze data to ndim=3 now, but got ndim={value.ndim}" ) if "label" in self.apply_keys: - for key, value in label_item_copy.items(): + for key, value in label_item.items(): if value.ndim == 4: B, C, H, W = value.shape - label_item_copy[key] = value.reshape((B * C, H, W)) + label_item[key] = value.reshape((B * C, H, W)) if value.ndim != 3: raise ValueError( f"Only support squeeze data to ndim=3 now, but got ndim={value.ndim}" ) - return input_item_copy, label_item_copy, weight_item + return input_item, label_item, weight_item class FunctionalTransform: