Skip to content

Commit

Permalink
change Normalize, Log1p, CropData, SqueezeData to inplace transform t…
Browse files Browse the repository at this point in the history
…o avoid copying data for efficiency
  • Loading branch information
HydrogenSulfate committed Dec 20, 2023
1 parent fd47ef5 commit b6a512c
Showing 1 changed file with 28 additions and 28 deletions.
56 changes: 28 additions & 28 deletions ppsci/data/process/transform/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ def __call__(self, input_dict, label_dict, weight_dict):
class Normalize:
"""Normalize data class.
NOTE: This transform will modify the input data dict inplace.
Args:
mean (Union[np.ndarray, Tuple[float, ...]]): Mean of training dataset.
std (Union[np.ndarray, Tuple[float, ...]]): Standard Deviation of training dataset.
Expand All @@ -96,20 +98,20 @@ def __init__(
self.apply_keys = apply_keys

def __call__(self, input_item, label_item, weight_item):
input_item_copy = {**input_item}
label_item_copy = {**label_item}
if "input" in self.apply_keys:
for key, value in input_item_copy.items():
input_item_copy[key] = (value - self.mean) / self.std
for key, value in input_item.items():
input_item[key] = (value - self.mean) / self.std
if "label" in self.apply_keys:
for key, value in label_item_copy.items():
label_item_copy[key] = (value - self.mean) / self.std
return input_item_copy, label_item_copy, weight_item
for key, value in label_item.items():
label_item[key] = (value - self.mean) / self.std
return input_item, label_item, weight_item


class Log1p:
"""Calculates the natural logarithm of one plus the data, element-wise.
NOTE: This transform will modify the input data dict inplace.
Args:
scale (float, optional): Scale data. Defaults to 1.0.
apply_keys (Tuple[str, ...], optional): Which data is the log1p method applied to. Defaults to ("input", "label").
Expand All @@ -132,22 +134,22 @@ def __init__(
self.apply_keys = apply_keys

def __call__(self, input_item, label_item, weight_item):
input_item_copy = {**input_item}
label_item_copy = {**label_item}
if "input" in self.apply_keys:
for key, value in input_item_copy.items():
input_item_copy[key] = np.log1p(value / self.scale)
for key, value in input_item.items():
input_item[key] = np.log1p(value / self.scale)
if "label" in self.apply_keys:
for key, value in label_item_copy.items():
label_item_copy[key] = np.log1p(value / self.scale)
return input_item_copy, label_item_copy, weight_item
for key, value in label_item.items():
label_item[key] = np.log1p(value / self.scale)
return input_item, label_item, weight_item


class CropData:
"""Crop data class.
This class is used to crop data based on a specified bounding box.
NOTE: This transform will modify the input data dict inplace.
Args:
xmin (Tuple[int, ...]): Bottom left corner point, [x0, y0].
xmax (Tuple[int, ...]): Top right corner point, [x1, y1].
Expand Down Expand Up @@ -182,24 +184,24 @@ def __init__(
self.apply_keys = apply_keys

def __call__(self, input_item, label_item, weight_item):
input_item_copy = {**input_item}
label_item_copy = {**label_item}
if "input" in self.apply_keys:
for key, value in input_item_copy.items():
input_item_copy[key] = value[
for key, value in input_item.items():
input_item[key] = value[
:, self.xmin[0] : self.xmax[0], self.xmin[1] : self.xmax[1]
]
if "label" in self.apply_keys:
for key, value in label_item_copy.items():
label_item_copy[key] = value[
for key, value in label_item.items():
label_item[key] = value[
:, self.xmin[0] : self.xmax[0], self.xmin[1] : self.xmax[1]
]
return input_item_copy, label_item_copy, weight_item
return input_item, label_item, weight_item


class SqueezeData:
"""Squeeze data class.
NOTE: This transform will modify the input data dict inplace.
Args:
apply_keys (Tuple[str, ...], optional): Which data is the squeeze method applied to. Defaults to ("input", "label").
Expand All @@ -216,27 +218,25 @@ def __init__(self, apply_keys: Tuple[str, ...] = ("input", "label")):
self.apply_keys = apply_keys

def __call__(self, input_item, label_item, weight_item):
input_item_copy = {**input_item}
label_item_copy = {**label_item}
if "input" in self.apply_keys:
for key, value in input_item_copy.items():
for key, value in input_item.items():
if value.ndim == 4:
B, C, H, W = value.shape
input_item_copy[key] = value.reshape((B * C, H, W))
input_item[key] = value.reshape((B * C, H, W))
if value.ndim != 3:
raise ValueError(
f"Only support squeeze data to ndim=3 now, but got ndim={value.ndim}"
)
if "label" in self.apply_keys:
for key, value in label_item_copy.items():
for key, value in label_item.items():
if value.ndim == 4:
B, C, H, W = value.shape
label_item_copy[key] = value.reshape((B * C, H, W))
label_item[key] = value.reshape((B * C, H, W))
if value.ndim != 3:
raise ValueError(
f"Only support squeeze data to ndim=3 now, but got ndim={value.ndim}"
)
return input_item_copy, label_item_copy, weight_item
return input_item, label_item, weight_item


class FunctionalTransform:
Expand Down

0 comments on commit b6a512c

Please sign in to comment.