Merge pull request #223 from mims-harvard/avelez-dev
Modernizing / productionizing: Google lint applied to tdc/* and conda env modified to pass all tests
amva13 authored Mar 5, 2024
2 parents 048e975 + d49f385 commit 9a5256a
Showing 68 changed files with 1,091 additions and 1,022 deletions.
environment.yml (9 changes: 8 additions & 1 deletion)
@@ -2,6 +2,8 @@ name: tdc-conda-env
channels:
- conda-forge
- defaults
- pyg
- pytorch
dependencies:
- dataclasses=0.8
- fuzzywuzzy=0.18.0
@@ -10,11 +12,16 @@ dependencies:
- python=3.9.13
- pip=23.3.1
- pandas=2.1.4
- pyg=2.5.0
- pytorch=2.2.1
- requests=2.31.0
- scikit-learn=1.3.0
- seaborn=0.12.2
- tqdm=4.65.0
- torchaudio=2.2.1
- torchvision=0.17.1
- pip:
- cellxgene-census==1.10.2
- PyTDC==0.4.1
- pydantic==2.6.3
- rdkit==2023.9.5
- yapf==0.40.2
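
One way to verify that the newly pinned graph/deep-learning stack resolves together after building the environment (a sketch, not part of this commit; it assumes `conda env create -f environment.yml` followed by `conda activate tdc-conda-env`):

# Sanity-check the versions pinned above; assumes the conda env from this
# environment.yml is active.
import pydantic
import torch
import torch_geometric  # installed via the "pyg" conda package/channel
import torchaudio
import torchvision

for name, module in [
    ("pytorch", torch),
    ("pyg", torch_geometric),
    ("torchaudio", torchaudio),
    ("torchvision", torchvision),
    ("pydantic", pydantic),
]:
    print(f"{name}: {module.__version__}")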
run_tests.py (6 changes: 5 additions & 1 deletion)
@@ -6,4 +6,8 @@
suite = loader.discover(start_dir)

runner = unittest.TextTestRunner()
runner.run(suite)
res = runner.run(suite)
if res.wasSuccessful():
print("All base tests passed")
else:
raise RuntimeError("Some base tests failed")
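
The added lines make the script fail loudly when any discovered test fails: a bare `runner.run(suite)` returns a result object but leaves the process exit code at 0, so a CI job would previously pass even with failing tests. A self-contained sketch of the resulting pattern (the start directory is an assumption; the real value sits in the elided top of run_tests.py):

import unittest

loader = unittest.TestLoader()
start_dir = "tdc/test"  # hypothetical; run_tests.py defines its own start_dir above this hunk
suite = loader.discover(start_dir)

runner = unittest.TextTestRunner()
res = runner.run(suite)
if res.wasSuccessful():
    print("All base tests passed")
else:
    # raising gives the process a nonzero exit code, so CI marks the job as failed
    raise RuntimeError("Some base tests failed")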
setup.py (3 changes: 2 additions & 1 deletion)
@@ -19,7 +19,8 @@ def readme():


# read the contents of requirements.txt
with open(path.join(this_directory, "requirements.txt"), encoding="utf-8") as f:
with open(path.join(this_directory, "requirements.txt"),
encoding="utf-8") as f:
requirements = f.read().splitlines()

setup(
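
The reworked open() call is only a line-length reflow; the requirements it reads typically feed setup(). A minimal sketch of that pattern (the actual setup() arguments live in the elided remainder of setup.py; the metadata and install_requires usage below are assumptions for illustration):

from os import path

from setuptools import find_packages, setup

this_directory = path.abspath(path.dirname(__file__))

# read the contents of requirements.txt
with open(path.join(this_directory, "requirements.txt"),
          encoding="utf-8") as f:
    requirements = f.read().splitlines()

setup(
    name="PyTDC",                   # placeholder metadata for illustration
    packages=find_packages(),
    install_requires=requirements,  # assumption: pins are passed straight through
)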
tdc/base_dataset.py (95 changes: 36 additions & 59 deletions)
@@ -16,7 +16,6 @@


class DataLoader:

"""base data loader class that contains functions shared by almost all data loader classes."""

def __init__(self):
@@ -35,13 +34,11 @@ def get_data(self, format="df"):
AttributeError: format not supported
"""
if format == "df":
return pd.DataFrame(
{
self.entity1_name + "_ID": self.entity1_idx,
self.entity1_name: self.entity1,
"Y": self.y,
}
)
return pd.DataFrame({
self.entity1_name + "_ID": self.entity1_idx,
self.entity1_name: self.entity1,
"Y": self.y,
})
elif format == "dict":
return {
self.entity1_name + "_ID": self.entity1_idx,
@@ -56,11 +53,8 @@ def get_data(self, format="df"):
def print_stats(self):
"""print statistics"""
print(
"There are "
+ str(len(np.unique(self.entity1)))
+ " unique "
+ self.entity1_name.lower()
+ "s",
"There are " + str(len(np.unique(self.entity1))) + " unique " +
self.entity1_name.lower() + "s",
flush=True,
file=sys.stderr,
)
@@ -86,7 +80,8 @@ def get_split(self, method="random", seed=42, frac=[0.7, 0.1, 0.2]):
if method == "random":
return utils.create_fold(df, seed, frac)
elif method == "cold_" + self.entity1_name.lower():
return utils.create_fold_setting_cold(df, seed, frac, self.entity1_name)
return utils.create_fold_setting_cold(df, seed, frac,
self.entity1_name)
else:
raise AttributeError("Please specify the correct splitting method")

@@ -110,30 +105,22 @@ def binarize(self, threshold=None, order="descending"):
if threshold is None:
raise AttributeError(
"Please specify the threshold to binarize the data by "
"'binarize(threshold = N)'!"
)
"'binarize(threshold = N)'!")

if len(np.unique(self.y)) == 2:
print("The data is already binarized!", flush=True, file=sys.stderr)
else:
print(
"Binariztion using threshold "
+ str(threshold)
+ ", default, we assume the smaller values are 1 "
"Binariztion using threshold " + str(threshold) +
", default, we assume the smaller values are 1 "
"and larger ones is 0, you can change the order "
"by 'binarize(order = 'ascending')'",
flush=True,
file=sys.stderr,
)
if (
np.unique(self.y)
.reshape(
-1,
)
.shape[0]
< 2
):
raise AttributeError("Adjust your threshold, there is only one class.")
if (np.unique(self.y).reshape(-1,).shape[0] < 2):
raise AttributeError(
"Adjust your threshold, there is only one class.")
self.y = utils.binarize(self.y, threshold, order)
return self

@@ -223,36 +210,26 @@ def balanced(self, oversample=False, seed=42):
flush=True,
file=sys.stderr,
)
val = (
pd.concat(
[
val[val.Y == major_class].sample(
n=len(val[val.Y == minor_class]),
replace=False,
random_state=seed,
),
val[val.Y == minor_class],
]
)
.sample(frac=1, replace=False, random_state=seed)
.reset_index(drop=True)
)
val = (pd.concat([
val[val.Y == major_class].sample(
n=len(val[val.Y == minor_class]),
replace=False,
random_state=seed,
),
val[val.Y == minor_class],
]).sample(frac=1, replace=False,
random_state=seed).reset_index(drop=True))
else:
print(
" Oversample of minority class is used. ", flush=True, file=sys.stderr
)
val = (
pd.concat(
[
val[val.Y == minor_class].sample(
n=len(val[val.Y == major_class]),
replace=True,
random_state=seed,
),
val[val.Y == major_class],
]
)
.sample(frac=1, replace=False, random_state=seed)
.reset_index(drop=True)
)
print(" Oversample of minority class is used. ",
flush=True,
file=sys.stderr)
val = (pd.concat([
val[val.Y == minor_class].sample(
n=len(val[val.Y == major_class]),
replace=True,
random_state=seed,
),
val[val.Y == major_class],
]).sample(frac=1, replace=False,
random_state=seed).reset_index(drop=True))
return val
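
The base_dataset.py edits are yapf/Google-style reflows of existing logic (get_data, print_stats, get_split, binarize, balanced); behavior is meant to be unchanged. A usage sketch of the touched methods, assuming a single-instance prediction loader such as tdc.single_pred.ADME with the Caco2_Wang dataset (dataset choice and threshold are illustrative, not prescribed by this commit):

from tdc.single_pred import ADME  # any DataLoader subclass works; ADME/Caco2_Wang is an assumption

data = ADME(name="Caco2_Wang")
data.print_stats()                               # unique-entity count, written to stderr
df = data.get_data(format="df")                  # columns: Drug_ID, Drug, Y
split = data.get_split(method="random", seed=42, frac=[0.7, 0.1, 0.2])

# binarize() requires an explicit threshold (the value here is a placeholder);
# balanced() then under- or oversamples the resulting binary classes.
data.binarize(threshold=-5.0, order="descending")
balanced_df = data.balanced(oversample=True, seed=42)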