-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #38 from JeffersonLab/36-make-pandasstandardscaler…
…-for-hugs 36 make pandasstandardscaler for hugs
- Loading branch information
Showing
6 changed files
with
503 additions
and
13 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
226 changes: 226 additions & 0 deletions
226
jlab_datascience_toolkit/data_prep/pandas_standard_scaler.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,226 @@ | ||
from jlab_datascience_toolkit.core.jdst_data_prep import JDSTDataPrep | ||
from jlab_datascience_toolkit.utils.io import save_yaml_config, load_yaml_config | ||
from pathlib import Path | ||
import pandas as pd | ||
import numpy as np | ||
import logging | ||
import inspect | ||
import yaml | ||
import os | ||
|
||
prep_log = logging.getLogger("Prep Logger") | ||
|
||
|
||
def _fix_small_scales(scale, epsilon): | ||
"""Updates scale parameters below epsilon to 1 to prevent issues with small divisors | ||
Args: | ||
scale (array_like): Scale parameters to (potentially) fix | ||
epsilon (float): Smallest allowable value for scale parameters | ||
Returns: | ||
array_like: Updated scale parameters | ||
""" | ||
return np.where(scale < epsilon, 1, scale) | ||
|
||
|
||
class PandasStandardScaler(JDSTDataPrep): | ||
"""Module performs standard scaling on Pandas DataFrames. | ||
Intialization arguments: | ||
config: dict | ||
Optional configuration keys: | ||
axis: int = 0 | ||
Axis to perform scaling on. Accepts 0,1 or None. Defaults to 0. | ||
epsilon: float = 1e-7 | ||
Smallest allowable value for the standard deviation. Defaults to 1e-7. | ||
If smaller than epsilon, the output variance will not be modified. | ||
This avoids exploding small noise variance values. | ||
inplace: bool = False | ||
If True, operations modify the original DataFrame. Defaults to False. | ||
Attributes | ||
---------- | ||
name : str | ||
Name of the module | ||
config: dict | ||
Configuration information | ||
Methods | ||
------- | ||
get_info() | ||
Prints this docstring | ||
load(path) | ||
Loads this module (including fit scaler parameters) from `path` | ||
save(path) | ||
Saves this module (including fit scaler parameters) to `path` | ||
load_config(path) | ||
Loads a configuration file. Scaler parameters will be fit to new data. | ||
save_config(path) | ||
Calls `save(path)` | ||
run(data) | ||
Performs standard scaling on `data`. If the scaler has not been previously | ||
fit, the scaler parameters will be fit to `data`. Otherwise, the scaling | ||
will utilize mean and variance information from the most recent `fit()` call. | ||
fit(data) | ||
Sets scaler parameters for mean and variance based on `data` | ||
reverse(data) | ||
Performs inverse scaling on `data`. | ||
save_data(path) | ||
Does nothing. | ||
""" | ||
|
||
def __init__(self, config: dict = None, registry_config: dict = None): | ||
# Set default config | ||
self.config = dict(axis=0, epsilon=1e-7, inplace=False) | ||
|
||
if registry_config is not None: | ||
self.config.update(registry_config) | ||
if config is not None: | ||
self.config.update(config) | ||
|
||
self.setup() | ||
|
||
@property | ||
def name(self): | ||
return "PandasStandardScaler_v0" | ||
|
||
def setup(self): | ||
self.mean = None | ||
self.var = None | ||
self.scale = None | ||
self.n_samples = 0 | ||
|
||
def get_info(self): | ||
"""Prints this module's docstring.""" | ||
print(inspect.getdoc(self)) | ||
|
||
def save(self, path: str): | ||
"""Save entire module to a folder at `path` | ||
Args: | ||
path (str): Location to save the module. This path must not currently exist. | ||
""" | ||
os.makedirs(path) | ||
self.save_config(path) | ||
self.save_internal_state(path) | ||
|
||
def load(self, path: str): | ||
"""Load entire saved module from `path` | ||
Args: | ||
path (str): Directory to load module from. Should include a config.yaml | ||
and scaler_state.npz files. | ||
""" | ||
self.load_config(path) | ||
self.load_internal_state(path) | ||
|
||
def save_config(self, path: str, overwrite: bool = False): | ||
"""Save the module configuration to a folder at `path` | ||
Args: | ||
path (str): Location to save the module config yaml file | ||
overwrite (bool, optional): If True, overwrites file at path if it exists. | ||
Defaults to False. | ||
""" | ||
save_dir = Path(path) | ||
save_yaml_config(self.config, save_dir, overwrite) | ||
|
||
def load_config(self, path: str): | ||
"""Load the entire module state from `path` | ||
Args: | ||
path (str): Path to folder containing module files. | ||
""" | ||
base_path = Path(path) | ||
self.config.update(load_yaml_config(base_path)) | ||
self.setup() | ||
|
||
def save_internal_state(self, path: str): | ||
internal_state = dict( | ||
mean=self.mean, var=self.var, scale=self.scale, n_samples=self.n_samples | ||
) | ||
save_dir = Path(path) | ||
if not save_dir.exists(): | ||
os.makedirs(save_dir) | ||
np.savez(save_dir.joinpath("scaler_state.npz"), **internal_state) | ||
|
||
def load_internal_state(self, path: str): | ||
save_dir = Path(path) | ||
internal_state = np.load(save_dir.joinpath("scaler_state.npz")) | ||
self.mean = internal_state["mean"] | ||
self.var = internal_state["var"] | ||
self.scale = internal_state["scale"] | ||
self.n_samples = internal_state["n_samples"] | ||
|
||
def run(self, data: pd.DataFrame): | ||
if self.mean is None: | ||
prep_log.debug("Fitting new data on run()") | ||
self.fit(data) | ||
|
||
return self.transform(data) | ||
|
||
def reverse(self, data: pd.DataFrame): | ||
"""Performs inverse scaling on `data` | ||
Args: | ||
data (pd.DataFrame): Data to perform inverse scaling on. | ||
Returns: | ||
pd.DataFrame: Inverse scaled DataFrame | ||
""" | ||
return self.inverse_transform(data) | ||
|
||
def fit(self, data: pd.DataFrame): | ||
"""Sets internal scaler parameters based on the mean and variance of `data` | ||
Args: | ||
data (pd.DataFrame): DataFrame used to fit the scaler | ||
""" | ||
# Since we do not modify data here, we can avoid a copy using np.asarray | ||
data_view = np.asarray(data) | ||
self.mean = np.mean(data_view, axis=self.config["axis"]) | ||
self.var = np.var(data_view, axis=self.config["axis"]) | ||
self.scale = _fix_small_scales(np.sqrt(self.var), self.config["epsilon"]) | ||
self.n_samples = data.shape[0] | ||
|
||
def transform(self, data): | ||
if self.mean is None: | ||
raise RuntimeError() | ||
data_view = np.array(data, copy=not self.config["inplace"]) | ||
if self.config["axis"] is not None: | ||
data_rotated = np.rollaxis(data_view, self.config["axis"]) | ||
else: | ||
data_rotated = data_view | ||
data_rotated -= self.mean | ||
data_rotated /= self.scale | ||
|
||
if self.config["inplace"]: | ||
return | ||
|
||
output = data.copy() | ||
output.values[:] = data_view | ||
return output | ||
|
||
def inverse_transform(self, data): | ||
if self.mean is None: | ||
raise RuntimeError() | ||
data_view = np.array(data, copy=not self.config["inplace"]) | ||
if self.config["axis"] is not None: | ||
data_rotated = np.rollaxis(data_view, self.config["axis"]) | ||
else: | ||
data_rotated = data_view | ||
data_rotated *= self.scale | ||
data_rotated += self.mean | ||
|
||
if self.config["inplace"]: | ||
return | ||
|
||
output = data.copy() | ||
output.values[:] = data_view | ||
return output | ||
|
||
def save_data(self, data): | ||
super().save_data() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
import yaml | ||
from pathlib import Path | ||
import tempfile | ||
import logging | ||
import sys | ||
|
||
io_log = logging.getLogger('io_log') | ||
|
||
def save_yaml_config(config: dict, path: str | Path, overwrite: bool = False): | ||
""" Saves configuration dictionary to a yaml file | ||
Args: | ||
config (dict): Dictionary to save | ||
path (str | Path): Location to save configuration. | ||
If `path` does not exist, it will be created. | ||
If `path` is a directory, the configuration will be saved to config.yaml | ||
If `path` is a filename, the configuration will be saved to that filename | ||
overwrite (bool, optional): If True, the passed configuration will overwrite any existing | ||
file with the same `path`. Defaults to False. | ||
Raises: | ||
FileExistsError: If `path` exists and `overwrite==False` a FileExistsError will be raised. | ||
""" | ||
path = Path(path) | ||
|
||
if path.is_dir(): | ||
io_log.info('path.is_dir() == True') | ||
path = path.joinpath('config.yaml') | ||
|
||
path.parent.mkdir(exist_ok=True) | ||
|
||
if path.exists() and not overwrite: | ||
io_log.error(f'File {path} exists without overwrite flag set') | ||
raise FileExistsError('File already exists. Set overwrite=True if you would like to overwrite it.') | ||
|
||
with open(path, 'w') as f: | ||
io_log.info(f'Writing config to {path}') | ||
yaml.safe_dump(config, f) | ||
|
||
def load_yaml_config(path: str | Path): | ||
path = Path(path) | ||
if path.is_dir(): | ||
path = path.joinpath('config.yaml') | ||
|
||
if not path.exists(): | ||
io_log.error(f'Configuration file {path} not found.') | ||
raise FileNotFoundError(f'Configuration file {path} not found.') | ||
|
||
with open(path, 'r') as f: | ||
return yaml.safe_load(f) |
Oops, something went wrong.