diff --git a/environment.yml b/environment.yml index 4aa999fe..67f5f2fe 100644 --- a/environment.yml +++ b/environment.yml @@ -2,6 +2,8 @@ name: tdc-conda-env channels: - conda-forge - defaults + - pyg + - pytorch dependencies: - dataclasses=0.8 - fuzzywuzzy=0.18.0 @@ -10,11 +12,16 @@ dependencies: - python=3.9.13 - pip=23.3.1 - pandas=2.1.4 + - pyg=2.5.0 + - pytorch=2.2.1 - requests=2.31.0 - scikit-learn=1.3.0 - seaborn=0.12.2 - tqdm=4.65.0 + - torchaudio=2.2.1 + - torchvision=0.17.1 - pip: - cellxgene-census==1.10.2 - - PyTDC==0.4.1 + - pydantic==2.6.3 - rdkit==2023.9.5 + - yapf==0.40.2 diff --git a/run_tests.py b/run_tests.py index 981f21b7..6e89f236 100644 --- a/run_tests.py +++ b/run_tests.py @@ -6,4 +6,8 @@ suite = loader.discover(start_dir) runner = unittest.TextTestRunner() - runner.run(suite) \ No newline at end of file + res = runner.run(suite) + if res.wasSuccessful(): + print("All base tests passed") + else: + raise RuntimeError("Some base tests failed") diff --git a/setup.py b/setup.py index 063fca8b..22635ad6 100644 --- a/setup.py +++ b/setup.py @@ -19,7 +19,8 @@ def readme(): # read the contents of requirements.txt -with open(path.join(this_directory, "requirements.txt"), encoding="utf-8") as f: +with open(path.join(this_directory, "requirements.txt"), + encoding="utf-8") as f: requirements = f.read().splitlines() setup( diff --git a/tdc/base_dataset.py b/tdc/base_dataset.py index 8f133d91..4e49bea3 100644 --- a/tdc/base_dataset.py +++ b/tdc/base_dataset.py @@ -16,7 +16,6 @@ class DataLoader: - """base data loader class that contains functions shared by almost all data loader classes.""" def __init__(self): @@ -35,13 +34,11 @@ def get_data(self, format="df"): AttributeError: format not supported """ if format == "df": - return pd.DataFrame( - { - self.entity1_name + "_ID": self.entity1_idx, - self.entity1_name: self.entity1, - "Y": self.y, - } - ) + return pd.DataFrame({ + self.entity1_name + "_ID": self.entity1_idx, + self.entity1_name: self.entity1, + "Y": self.y, + }) elif format == "dict": return { self.entity1_name + "_ID": self.entity1_idx, @@ -56,11 +53,8 @@ def get_data(self, format="df"): def print_stats(self): """print statistics""" print( - "There are " - + str(len(np.unique(self.entity1))) - + " unique " - + self.entity1_name.lower() - + "s", + "There are " + str(len(np.unique(self.entity1))) + " unique " + + self.entity1_name.lower() + "s", flush=True, file=sys.stderr, ) @@ -86,7 +80,8 @@ def get_split(self, method="random", seed=42, frac=[0.7, 0.1, 0.2]): if method == "random": return utils.create_fold(df, seed, frac) elif method == "cold_" + self.entity1_name.lower(): - return utils.create_fold_setting_cold(df, seed, frac, self.entity1_name) + return utils.create_fold_setting_cold(df, seed, frac, + self.entity1_name) else: raise AttributeError("Please specify the correct splitting method") @@ -110,30 +105,22 @@ def binarize(self, threshold=None, order="descending"): if threshold is None: raise AttributeError( "Please specify the threshold to binarize the data by " - "'binarize(threshold = N)'!" 
- ) + "'binarize(threshold = N)'!") if len(np.unique(self.y)) == 2: print("The data is already binarized!", flush=True, file=sys.stderr) else: print( - "Binariztion using threshold " - + str(threshold) - + ", default, we assume the smaller values are 1 " + "Binariztion using threshold " + str(threshold) + + ", default, we assume the smaller values are 1 " "and larger ones is 0, you can change the order " "by 'binarize(order = 'ascending')'", flush=True, file=sys.stderr, ) - if ( - np.unique(self.y) - .reshape( - -1, - ) - .shape[0] - < 2 - ): - raise AttributeError("Adjust your threshold, there is only one class.") + if (np.unique(self.y).reshape(-1,).shape[0] < 2): + raise AttributeError( + "Adjust your threshold, there is only one class.") self.y = utils.binarize(self.y, threshold, order) return self @@ -223,36 +210,26 @@ def balanced(self, oversample=False, seed=42): flush=True, file=sys.stderr, ) - val = ( - pd.concat( - [ - val[val.Y == major_class].sample( - n=len(val[val.Y == minor_class]), - replace=False, - random_state=seed, - ), - val[val.Y == minor_class], - ] - ) - .sample(frac=1, replace=False, random_state=seed) - .reset_index(drop=True) - ) + val = (pd.concat([ + val[val.Y == major_class].sample( + n=len(val[val.Y == minor_class]), + replace=False, + random_state=seed, + ), + val[val.Y == minor_class], + ]).sample(frac=1, replace=False, + random_state=seed).reset_index(drop=True)) else: - print( - " Oversample of minority class is used. ", flush=True, file=sys.stderr - ) - val = ( - pd.concat( - [ - val[val.Y == minor_class].sample( - n=len(val[val.Y == major_class]), - replace=True, - random_state=seed, - ), - val[val.Y == major_class], - ] - ) - .sample(frac=1, replace=False, random_state=seed) - .reset_index(drop=True) - ) + print(" Oversample of minority class is used. 
", + flush=True, + file=sys.stderr) + val = (pd.concat([ + val[val.Y == minor_class].sample( + n=len(val[val.Y == major_class]), + replace=True, + random_state=seed, + ), + val[val.Y == major_class], + ]).sample(frac=1, replace=False, + random_state=seed).reset_index(drop=True)) return val diff --git a/tdc/benchmark_deprecated.py b/tdc/benchmark_deprecated.py index 0150b5af..16b10fac 100644 --- a/tdc/benchmark_deprecated.py +++ b/tdc/benchmark_deprecated.py @@ -25,6 +25,7 @@ class BenchmarkGroup: + def __init__( self, name, @@ -162,7 +163,8 @@ def __next__(self): ncpu=self.num_cpus, num_max_call=self.num_max_call, ) - data = pd.read_csv(os.path.join(self.path, "zinc.tab"), sep="\t") + data = pd.read_csv(os.path.join(self.path, "zinc.tab"), + sep="\t") return {"oracle": oracle, "data": data, "name": dataset} else: return {"train_val": train, "test": test, "name": dataset} @@ -200,15 +202,19 @@ def get_train_valid_split(self, seed, benchmark, split_type="default"): frac = [frac[0], frac[1], 0.0] """ if split_method == "scaffold": - out = create_scaffold_split(train_val, seed, frac=frac, entity="Drug") + out = create_scaffold_split(train_val, + seed, + frac=frac, + entity="Drug") elif split_method == "random": out = create_fold(train_val, seed, frac=frac) elif split_method == "combination": out = create_combination_split(train_val, seed, frac=frac) elif split_method == "group": - out = create_group_split( - train_val, seed, holdout_frac=0.2, group_column="Year" - ) + out = create_group_split(train_val, + seed, + holdout_frac=0.2, + group_column="Year") else: raise NotImplementedError return out["train"], out["valid"] @@ -246,7 +252,12 @@ def get(self, benchmark, num_max_call=5000): else: return {"train_val": train, "test": test, "name": dataset} - def evaluate(self, pred, true=None, benchmark=None, m1_api=None, save_dict=True): + def evaluate(self, + pred, + true=None, + benchmark=None, + m1_api=None, + save_dict=True): if self.name == "docking_group": results_all = {} @@ -284,7 +295,8 @@ def evaluate(self, pred, true=None, benchmark=None, m1_api=None, save_dict=True) dataset = fuzzy_search(data_name, self.dataset_names) # docking scores for the top K smiles (K <= 100) - target_pdb_file = os.path.join(self.path, dataset + ".pdb") + target_pdb_file = os.path.join(self.path, + dataset + ".pdb") from .oracles import Oracle data_path = os.path.join(self.path, dataset) @@ -304,12 +316,14 @@ def evaluate(self, pred, true=None, benchmark=None, m1_api=None, save_dict=True) docking_scores = oracle(pred_) print_sys("---- Calculating average docking scores ----") - if ( - len(np.where(np.array(list(docking_scores.values())) > 0)[0]) - > 0.7 - ): + if (len( + np.where( + np.array(list(docking_scores.values())) > 0)[0]) + > 0.7): ## check if the scores are all positive.. 
if so, make them all negative - docking_scores = {j: -k for j, k in docking_scores.items()} + docking_scores = { + j: -k for j, k in docking_scores.items() + } if save_dict: results["docking_scores_dict"] = docking_scores values = np.array(list(docking_scores.values())) @@ -327,23 +341,23 @@ def evaluate(self, pred, true=None, benchmark=None, m1_api=None, save_dict=True) ) from .oracles import Oracle - m1 = Oracle(name="Molecule One Synthesis", api_token=m1_api) + m1 = Oracle(name="Molecule One Synthesis", + api_token=m1_api) import heapq from operator import itemgetter top10_docking_smiles = list( dict( - heapq.nsmallest( - 10, docking_scores.items(), key=itemgetter(1) - ) - ).keys() - ) + heapq.nsmallest(10, + docking_scores.items(), + key=itemgetter(1))).keys()) m1_scores = m1(top10_docking_smiles) scores_array = list(m1_scores.values()) - scores_array = np.array([float(i) for i in scores_array]) - scores_array[ - np.where(scores_array == -1.0)[0] - ] = 10 # m1 score errors are usually large complex molecules + scores_array = np.array( + [float(i) for i in scores_array]) + scores_array[np.where( + scores_array == -1.0 + )[0]] = 10 # m1 score errors are usually large complex molecules if save_dict: results["m1_dict"] = m1_scores results["m1"] = np.mean(scores_array) @@ -361,8 +375,7 @@ def evaluate(self, pred, true=None, benchmark=None, m1_api=None, save_dict=True) results["pass_list"] = pred_filter results["%pass"] = float(len(pred_filter)) / 100 results["top1_%pass"] = max( - [docking_scores[i] for i in pred_filter] - ) + [docking_scores[i] for i in pred_filter]) print_sys("---- Calculating diversity ----") from .evaluator import Evaluator @@ -371,13 +384,13 @@ def evaluate(self, pred, true=None, benchmark=None, m1_api=None, save_dict=True) results["diversity"] = score print_sys("---- Calculating novelty ----") evaluator = Evaluator(name="Novelty") - training = pd.read_csv( - os.path.join(self.path, "zinc.tab"), sep="\t" - ) + training = pd.read_csv(os.path.join(self.path, "zinc.tab"), + sep="\t") score = evaluator(pred_, training.smiles.values) results["novelty"] = score results["top smiles"] = [ - i[0] for i in sorted(docking_scores.items(), key=lambda x: x[1]) + i[0] for i in sorted(docking_scores.items(), + key=lambda x: x[1]) ] results_max_call[num_max_call] = results results_all[data_name] = results_max_call @@ -395,8 +408,11 @@ def evaluate(self, pred, true=None, benchmark=None, m1_api=None, save_dict=True) elif self.file_format == "pkl": test = pd.read_pickle(os.path.join(data_path, "test.pkl")) y = test.Y.values - evaluator = eval("Evaluator(name = '" + metric_dict[data_name] + "')") - out[data_name] = {metric_dict[data_name]: round(evaluator(y, pred_), 3)} + evaluator = eval("Evaluator(name = '" + metric_dict[data_name] + + "')") + out[data_name] = { + metric_dict[data_name]: round(evaluator(y, pred_), 3) + } # If reporting accuracy across target classes if "target_class" in test.columns: @@ -407,13 +423,11 @@ def evaluate(self, pred, true=None, benchmark=None, m1_api=None, save_dict=True) y_subset = test_subset.Y.values pred_subset = test_subset.pred.values - evaluator = eval( - "Evaluator(name = '" + metric_dict[data_name_subset] + "')" - ) + evaluator = eval("Evaluator(name = '" + + metric_dict[data_name_subset] + "')") out[data_name_subset] = { - metric_dict[data_name_subset]: round( - evaluator(y_subset, pred_subset), 3 - ) + metric_dict[data_name_subset]: + round(evaluator(y_subset, pred_subset), 3) } return out else: @@ -424,12 +438,15 @@ def evaluate(self, pred, 
true=None, benchmark=None, m1_api=None, save_dict=True)
                 )
             data_name = fuzzy_search(benchmark, self.dataset_names)
             metric_dict = bm_metric_names[self.name]
-            evaluator = eval("Evaluator(name = '" + metric_dict[data_name] + "')")
+            evaluator = eval("Evaluator(name = '" + metric_dict[data_name] +
+                             "')")
             return {metric_dict[data_name]: round(evaluator(true, pred), 3)}
 
-    def evaluate_many(
-        self, preds, save_file_name=None, m1_api=None, results_individual=None
-    ):
+    def evaluate_many(self,
+                      preds,
+                      save_file_name=None,
+                      m1_api=None,
+                      results_individual=None):
         """
         :param preds: list of dict
         :return: dict
         """
diff --git a/tdc/benchmark_group/docking_group.py b/tdc/benchmark_group/docking_group.py
--- a/tdc/benchmark_group/docking_group.py
+++ b/tdc/benchmark_group/docking_group.py
@@ ... @@
-            if len(np.where(np.array(list(docking_scores.values())) > 0)[0]) > 0.7:
+            if len(
+                    np.where(np.array(list(docking_scores.values())) > 0)
+                    [0]) > 0.7:
                 ## check if the scores are all positive.. if so, make them all negative
                 docking_scores = {j: -k for j, k in docking_scores.items()}
                 if save_dict:
@@ -275,7 +284,8 @@ def evaluate(self, pred, true=None, benchmark=None, m1_api=None, save_dict=True)
                 if save_dict:
                     results["pass_list"] = pred_filter
                 results["%pass"] = float(len(pred_filter)) / 100
-                results["top1_%pass"] = min([docking_scores[i] for i in pred_filter])
+                results["top1_%pass"] = min(
+                    [docking_scores[i] for i in pred_filter])
                 print_sys("---- Calculating diversity ----")
                 from ..evaluator import Evaluator
@@ -284,19 +294,23 @@ def evaluate(self, pred, true=None, benchmark=None, m1_api=None, save_dict=True)
                 results["diversity"] = score
                 print_sys("---- Calculating novelty ----")
                 evaluator = Evaluator(name="Novelty")
-                training = pd.read_csv(os.path.join(self.path, "zinc.tab"), sep="\t")
+                training = pd.read_csv(os.path.join(self.path, "zinc.tab"),
+                                       sep="\t")
                 score = evaluator(pred_, training.smiles.values)
                 results["novelty"] = score
                 results["top smiles"] = [
-                    i[0] for i in sorted(docking_scores.items(), key=lambda x: x[1])
+                    i[0]
+                    for i in sorted(docking_scores.items(), key=lambda x: x[1])
                 ]
                 results_max_call[num_max_call] = results
             results_all[data_name] = results_max_call
         return results_all
 
-    def evaluate_many(
-        self, preds, save_file_name=None, m1_api=None, results_individual=None
-    ):
+    def evaluate_many(self,
+                      preds,
+                      save_file_name=None,
+                      m1_api=None,
+                      results_individual=None):
        """evaluate many runs together and output submission ready pkl file.
Args: @@ -310,11 +324,9 @@ def evaluate_many( """ min_requirement = 3 if len(preds) < min_requirement: - return ValueError( - "Must have predictions from at least " - + str(min_requirement) - + " runs for leaderboard submission" - ) + return ValueError("Must have predictions from at least " + + str(min_requirement) + + " runs for leaderboard submission") if results_individual is None: individual_results = [] for pred in preds: @@ -345,13 +357,10 @@ def evaluate_many( for metric in metrics: if metric == "top smiles": results_agg_target_call[metric] = np.unique( - np.array( - [ - individual_results[fold][target][num_calls][metric] - for fold in range(num_folds) - ] - ).reshape(-1) - ).tolist() + np.array([ + individual_results[fold][target][num_calls] + [metric] for fold in range(num_folds) + ]).reshape(-1)).tolist() else: res = [ individual_results[fold][target][num_calls][metric] diff --git a/tdc/benchmark_group/drugcombo_group.py b/tdc/benchmark_group/drugcombo_group.py index 801b411c..a1b5421a 100644 --- a/tdc/benchmark_group/drugcombo_group.py +++ b/tdc/benchmark_group/drugcombo_group.py @@ -15,11 +15,11 @@ class drugcombo_group(BenchmarkGroup): def __init__(self, path="./data"): """create a drug combination benchmark group""" super().__init__(name="DrugCombo_Group", path=path, file_format="pkl") - - + def get_cell_line_meta_data(self): import os from ..utils.load import download_wrapper from ..utils import load_dict - name = download_wrapper('drug_comb_meta_data', self.path, ['drug_comb_meta_data']) - return load_dict(os.path.join(self.path, name + '.pkl')) \ No newline at end of file + name = download_wrapper('drug_comb_meta_data', self.path, + ['drug_comb_meta_data']) + return load_dict(os.path.join(self.path, name + '.pkl')) diff --git a/tdc/chem_utils/evaluator.py b/tdc/chem_utils/evaluator.py index 8a00f290..e2e60fb3 100644 --- a/tdc/chem_utils/evaluator.py +++ b/tdc/chem_utils/evaluator.py @@ -12,7 +12,8 @@ rdBase.DisableLog("rdApp.error") except: - raise ImportError("Please install rdkit by 'conda install -c conda-forge rdkit'! ") + raise ImportError( + "Please install rdkit by 'conda install -c conda-forge rdkit'! 
") def single_molecule_validity(smiles): @@ -57,7 +58,8 @@ def canonicalize(smiles): def unique_lst_of_smiles(list_of_smiles): canonical_smiles_lst = list(map(canonicalize, list_of_smiles)) - canonical_smiles_lst = list(filter(lambda x: x is not None, canonical_smiles_lst)) + canonical_smiles_lst = list( + filter(lambda x: x is not None, canonical_smiles_lst)) canonical_smiles_lst = list(set(canonical_smiles_lst)) return canonical_smiles_lst @@ -88,11 +90,9 @@ def novelty(generated_smiles_lst, training_smiles_lst): """ generated_smiles_lst = unique_lst_of_smiles(generated_smiles_lst) training_smiles_lst = unique_lst_of_smiles(training_smiles_lst) - novel_ratio = ( - sum([1 if i in training_smiles_lst else 0 for i in generated_smiles_lst]) - * 1.0 - / len(generated_smiles_lst) - ) + novel_ratio = (sum( + [1 if i in training_smiles_lst else 0 for i in generated_smiles_lst]) * + 1.0 / len(generated_smiles_lst)) return 1 - novel_ratio @@ -107,14 +107,19 @@ def diversity(list_of_smiles): div: float """ list_of_unique_smiles = unique_lst_of_smiles(list_of_smiles) - list_of_mol = [Chem.MolFromSmiles(smiles) for smiles in list_of_unique_smiles] + list_of_mol = [ + Chem.MolFromSmiles(smiles) for smiles in list_of_unique_smiles + ] list_of_fp = [ - AllChem.GetMorganFingerprintAsBitVect(mol, 2, nBits=2048, useChirality=False) + AllChem.GetMorganFingerprintAsBitVect(mol, + 2, + nBits=2048, + useChirality=False) for mol in list_of_mol ] avg_lst = [] for idx, fp in enumerate(list_of_fp): - for fp2 in list_of_fp[idx + 1 :]: + for fp2 in list_of_fp[idx + 1:]: sim = DataStructs.TanimotoSimilarity(fp, fp2) ### option I distance = 1 - sim @@ -235,7 +240,9 @@ def get_fingerprints(mols, radius=2, length=4096): Returns: a list of fingerprints """ - return [AllChem.GetMorganFingerprintAsBitVect(m, radius, length) for m in mols] + return [ + AllChem.GetMorganFingerprintAsBitVect(m, radius, length) for m in mols + ] def get_mols(smiles_list): @@ -267,10 +274,8 @@ def calculate_internal_pairwise_similarities(smiles_list): Symmetric matrix of pairwise similarities. Diagonal is set to zero. """ if len(smiles_list) > 10000: - logger.warning( - f"Calculating internal similarity on large set of " - f"SMILES strings ({len(smiles_list)})" - ) + logger.warning(f"Calculating internal similarity on large set of " + f"SMILES strings ({len(smiles_list)})") mols = get_mols(smiles_list) fps = get_fingerprints(mols) @@ -313,7 +318,8 @@ def kl_divergence(generated_smiles_lst, training_smiles_lst): def canonical(smiles): mol = Chem.MolFromSmiles(smiles) if mol is not None: - return Chem.MolToSmiles(mol, isomericSmiles=True) ### todo double check + return Chem.MolToSmiles(mol, + isomericSmiles=True) ### todo double check else: return None @@ -323,17 +329,20 @@ def canonical(smiles): generated_lst_mol = list(filter(filter_out_func, generated_lst_mol)) training_lst_mol = list(filter(filter_out_func, training_lst_mol)) - d_sampled = calculate_pc_descriptors(generated_lst_mol, pc_descriptor_subset) + d_sampled = calculate_pc_descriptors(generated_lst_mol, + pc_descriptor_subset) d_chembl = calculate_pc_descriptors(training_lst_mol, pc_descriptor_subset) kldivs = {} for i in range(4): - kldiv = continuous_kldiv(X_baseline=d_chembl[:, i], X_sampled=d_sampled[:, i]) + kldiv = continuous_kldiv(X_baseline=d_chembl[:, i], + X_sampled=d_sampled[:, i]) kldivs[pc_descriptor_subset[i]] = kldiv # ... and for the int valued ones. 
for i in range(4, 9): - kldiv = discrete_kldiv(X_baseline=d_chembl[:, i], X_sampled=d_sampled[:, i]) + kldiv = discrete_kldiv(X_baseline=d_chembl[:, i], + X_sampled=d_sampled[:, i]) kldivs[pc_descriptor_subset[i]] = kldiv # pairwise similarity @@ -344,7 +353,8 @@ def canonical(smiles): sampled_sim = calculate_internal_pairwise_similarities(generated_lst_mol) sampled_sim = sampled_sim.max(axis=1) - kldiv_int_int = continuous_kldiv(X_baseline=chembl_sim, X_sampled=sampled_sim) + kldiv_int_int = continuous_kldiv(X_baseline=chembl_sim, + X_sampled=sampled_sim) kldivs["internal_similarity"] = kldiv_int_int """ # for some reason, this runs into problems when both sets are identical. @@ -395,10 +405,14 @@ def _calculate_distribution_statistics(chemnet, molecules): cov = np.cov(gen_mol_act.T) return mu, cov - mu_ref, cov_ref = _calculate_distribution_statistics(chemnet, training_smiles_lst) + mu_ref, cov_ref = _calculate_distribution_statistics( + chemnet, training_smiles_lst) mu, cov = _calculate_distribution_statistics(chemnet, generated_smiles_lst) - FCD = fcd.calculate_frechet_distance(mu1=mu_ref, mu2=mu, sigma1=cov_ref, sigma2=cov) + FCD = fcd.calculate_frechet_distance(mu1=mu_ref, + mu2=mu, + sigma1=cov_ref, + sigma2=cov) fcd_distance = np.exp(-0.2 * FCD) return fcd_distance diff --git a/tdc/chem_utils/featurize/_smiles2pubchem.py b/tdc/chem_utils/featurize/_smiles2pubchem.py index f41fc82a..0cd21cbb 100644 --- a/tdc/chem_utils/featurize/_smiles2pubchem.py +++ b/tdc/chem_utils/featurize/_smiles2pubchem.py @@ -7,8 +7,8 @@ rdBase.DisableLog("rdApp.error") except: - raise ImportError("Please install rdkit by 'conda install -c conda-forge rdkit'! ") - + raise ImportError( + "Please install rdkit by 'conda install -c conda-forge rdkit'! ") try: import networkx as nx @@ -409,9 +409,11 @@ def func_4(mol, bits): for bondIdx in ring: BeginAtom = mol.GetBondWithIdx(bondIdx).GetBeginAtom() EndAtom = mol.GetBondWithIdx(bondIdx).GetEndAtom() - if BeginAtom.GetAtomicNum() not in [1, 6] or EndAtom.GetAtomicNum() not in [ - 1, - 6, + if BeginAtom.GetAtomicNum() not in [ + 1, 6 + ] or EndAtom.GetAtomicNum() not in [ + 1, + 6, ]: heteroatom = True break @@ -752,9 +754,11 @@ def func_7(mol, bits): for bondIdx in ring: BeginAtom = mol.GetBondWithIdx(bondIdx).GetBeginAtom() EndAtom = mol.GetBondWithIdx(bondIdx).GetEndAtom() - if BeginAtom.GetAtomicNum() not in [1, 6] or EndAtom.GetAtomicNum() not in [ - 1, - 6, + if BeginAtom.GetAtomicNum() not in [ + 1, 6 + ] or EndAtom.GetAtomicNum() not in [ + 1, + 6, ]: heteroatom = True break @@ -862,9 +866,11 @@ def func_8(mol, bits): for bondIdx in ring: BeginAtom = mol.GetBondWithIdx(bondIdx).GetBeginAtom() EndAtom = mol.GetBondWithIdx(bondIdx).GetEndAtom() - if BeginAtom.GetAtomicNum() not in [1, 6] or EndAtom.GetAtomicNum() not in [ - 1, - 6, + if BeginAtom.GetAtomicNum() not in [ + 1, 6 + ] or EndAtom.GetAtomicNum() not in [ + 1, + 6, ]: heteroatom = True break @@ -936,6 +942,7 @@ def calcPubChemFingerAll(s): AllBits[index3 + 115] = 1 return np.array(AllBits) + def canonicalize(smiles): mol = Chem.MolFromSmiles(smiles) if mol is not None: @@ -943,13 +950,13 @@ def canonicalize(smiles): else: return None + def smiles2pubchem(s): s = canonicalize(s) try: features = calcPubChemFingerAll(s) except: - print( - "pubchem fingerprint not working for smiles: " + s + " convert to 0 vectors" - ) + print("pubchem fingerprint not working for smiles: " + s + + " convert to 0 vectors") features = np.zeros((881,)) return np.array(features) diff --git 
a/tdc/chem_utils/featurize/_xyz2mol.py b/tdc/chem_utils/featurize/_xyz2mol.py index 92af581f..9c91d2c6 100644 --- a/tdc/chem_utils/featurize/_xyz2mol.py +++ b/tdc/chem_utils/featurize/_xyz2mol.py @@ -2,15 +2,14 @@ from collections import defaultdict from typing import List - try: from rdkit import Chem from rdkit import rdBase rdBase.DisableLog("rdApp.error") except: - raise ImportError("Please install rdkit by 'conda install -c conda-forge rdkit'! ") - + raise ImportError( + "Please install rdkit by 'conda install -c conda-forge rdkit'! ") try: import networkx as nx @@ -19,7 +18,6 @@ from ...utils import print_sys - ############## begin xyz2mol ################ # from https://github.com/jensengroup/xyz2mol/blob/master/xyz2mol.py @@ -121,7 +119,6 @@ "pu", ] - global atomic_valence global atomic_valence_electrons @@ -179,7 +176,8 @@ def get_UA(maxValence_list, valence_list): """ """ UA = [] DU = [] - for i, (maxValence, valence) in enumerate(zip(maxValence_list, valence_list)): + for i, (maxValence, valence) in enumerate(zip(maxValence_list, + valence_list)): if not maxValence - valence > 0: continue UA.append(i) @@ -235,7 +233,8 @@ def charge_is_OK( BO_valences = list(BO.sum(axis=1)) for i, atom in enumerate(atoms): - q = get_atomic_charge(atom, atomic_valence_electrons[atom], BO_valences[i]) + q = get_atomic_charge(atom, atomic_valence_electrons[atom], + BO_valences[i]) Q += q if atom == 6: number_of_single_bonds_to_C = list(BO[i, :]).count(1) @@ -381,8 +380,7 @@ def BO2mol( if l != l2: raise RuntimeError( - "sizes of adjMat ({0:d}) and Atoms {1:d} differ".format(l, l2) - ) + "sizes of adjMat ({0:d}) and Atoms {1:d} differ".format(l, l2)) rwMol = Chem.RWMol(mol) @@ -403,23 +401,23 @@ def BO2mol( mol = rwMol.GetMol() if allow_charged_fragments: - mol = set_atomic_charges( - mol, atoms, atomic_valence_electrons, BO_valences, BO_matrix, mol_charge - ) + mol = set_atomic_charges(mol, atoms, atomic_valence_electrons, + BO_valences, BO_matrix, mol_charge) else: - mol = set_atomic_radicals(mol, atoms, atomic_valence_electrons, BO_valences) + mol = set_atomic_radicals(mol, atoms, atomic_valence_electrons, + BO_valences) return mol -def set_atomic_charges( - mol, atoms, atomic_valence_electrons, BO_valences, BO_matrix, mol_charge -): +def set_atomic_charges(mol, atoms, atomic_valence_electrons, BO_valences, + BO_matrix, mol_charge): """ """ q = 0 for i, atom in enumerate(atoms): a = mol.GetAtomWithIdx(i) - charge = get_atomic_charge(atom, atomic_valence_electrons[atom], BO_valences[i]) + charge = get_atomic_charge(atom, atomic_valence_electrons[atom], + BO_valences[i]) q += charge if atom == 6: number_of_single_bonds_to_C = list(BO_matrix[i, :]).count(1) @@ -444,7 +442,8 @@ def set_atomic_radicals(mol, atoms, atomic_valence_electrons, BO_valences): """ for i, atom in enumerate(atoms): a = mol.GetAtomWithIdx(i) - charge = get_atomic_charge(atom, atomic_valence_electrons[atom], BO_valences[i]) + charge = get_atomic_charge(atom, atomic_valence_electrons[atom], + BO_valences[i]) if abs(charge) > 0: a.SetNumRadicalElectrons(abs(int(charge))) @@ -457,7 +456,7 @@ def get_bonds(UA, AC): bonds = [] for k, i in enumerate(UA): - for j in UA[k + 1 :]: + for j in UA[k + 1:]: if AC[i, j] == 1: bonds.append(tuple(sorted([i, j]))) @@ -510,7 +509,9 @@ def AC2BO(AC, atoms, charge, allow_charged_fragments=True, use_graph=True): for i, (atomicNum, valence) in enumerate(zip(atoms, AC_valence)): # valence can't be smaller than number of neighbourgs - possible_valence = [x for x in atomic_valence[atomicNum] if x >= 
valence] + possible_valence = [ + x for x in atomic_valence[atomicNum] if x >= valence + ] if not possible_valence: print_sys( "Valence of atom", @@ -553,7 +554,12 @@ def AC2BO(AC, atoms, charge, allow_charged_fragments=True, use_graph=True): UA_pairs_list = get_UA_pairs(UA, AC, use_graph=use_graph) for UA_pairs in UA_pairs_list: - BO = get_BO(AC, UA, DU_from_AC, valences, UA_pairs, use_graph=use_graph) + BO = get_BO(AC, + UA, + DU_from_AC, + valences, + UA_pairs, + use_graph=use_graph) status = BO_is_OK( BO, AC, @@ -577,17 +583,19 @@ def AC2BO(AC, atoms, charge, allow_charged_fragments=True, use_graph=True): if status: return BO, atomic_valence_electrons - elif ( - BO.sum() >= best_BO.sum() - and valences_not_too_large(BO, valences) - and charge_OK - ): + elif (BO.sum() >= best_BO.sum() and + valences_not_too_large(BO, valences) and charge_OK): best_BO = BO.copy() return best_BO, atomic_valence_electrons -def AC2mol(mol, AC, atoms, charge, allow_charged_fragments=True, use_graph=True): +def AC2mol(mol, + AC, + atoms, + charge, + allow_charged_fragments=True, + use_graph=True): """ """ # convert AC matrix to bond order (BO) matrix @@ -614,9 +622,8 @@ def AC2mol(mol, AC, atoms, charge, allow_charged_fragments=True, use_graph=True) # return [] # BO2mol returns an arbitrary resonance form. Let's make the rest - mols = rdchem.ResonanceMolSupplier( - mol, Chem.UNCONSTRAINED_CATIONS, Chem.UNCONSTRAINED_ANIONS - ) + mols = rdchem.ResonanceMolSupplier(mol, Chem.UNCONSTRAINED_CATIONS, + Chem.UNCONSTRAINED_ANIONS) mols = [mol for mol in mols] return mols, BO @@ -754,15 +761,14 @@ def xyz2AC_huckel(atomicNumList, xyz, charge): mol_huckel = Chem.Mol(mol) mol_huckel.GetAtomWithIdx(0).SetFormalCharge( - charge - ) # mol charge arbitrarily added to 1st atom + charge) # mol charge arbitrarily added to 1st atom passed, result = rdEHTTools.RunMol(mol_huckel) opop = result.GetReducedOverlapPopulationMatrix() tri = np.zeros((num_atoms, num_atoms)) - tri[ - np.tril(np.ones((num_atoms, num_atoms), dtype=bool)) - ] = opop # lower triangular to square matrix + tri[np.tril(np.ones( + (num_atoms, num_atoms), + dtype=bool))] = opop # lower triangular to square matrix for i in range(num_atoms): for j in range(i + 1, num_atoms): pair_pop = abs(tri[j, i]) diff --git a/tdc/chem_utils/featurize/molconvert.py b/tdc/chem_utils/featurize/molconvert.py index 4741d5d4..6a217f58 100644 --- a/tdc/chem_utils/featurize/molconvert.py +++ b/tdc/chem_utils/featurize/molconvert.py @@ -5,7 +5,6 @@ import numpy as np from typing import List - try: from rdkit import Chem, DataStructs from rdkit.Chem import AllChem @@ -15,8 +14,8 @@ from rdkit.Chem.Fingerprints import FingerprintMols from rdkit.Chem import MACCSkeys except: - raise ImportError("Please install rdkit by 'conda install -c conda-forge rdkit'! ") - + raise ImportError( + "Please install rdkit by 'conda install -c conda-forge rdkit'! 
") from ...utils import print_sys from ..oracle.oracle import ( @@ -52,15 +51,14 @@ def smiles2morgan(s, radius=2, nBits=1024): try: s = canonicalize(s) mol = Chem.MolFromSmiles(s) - features_vec = AllChem.GetMorganFingerprintAsBitVect(mol, radius, nBits=nBits) + features_vec = AllChem.GetMorganFingerprintAsBitVect(mol, + radius, + nBits=nBits) features = np.zeros((1,)) DataStructs.ConvertToNumpyArray(features_vec, features) except: - print_sys( - "rdkit not found this smiles for morgan: " - + s - + " convert to all 0 features" - ) + print_sys("rdkit not found this smiles for morgan: " + s + + " convert to all 0 features") features = np.zeros((nBits,)) return features @@ -89,9 +87,8 @@ def smiles2rdkit2d(s): NaNs = np.isnan(features) features[NaNs] = 0 except: - print_sys( - "descriptastorus not found this smiles: " + s + " convert to all 0 features" - ) + print_sys("descriptastorus not found this smiles: " + s + + " convert to all 0 features") features = np.zeros((200,)) return np.array(features) @@ -115,7 +112,8 @@ def smiles2daylight(s): features = np.zeros((NumFinger,)) features[np.array(temp)] = 1 except: - print_sys("rdkit not found this smiles: " + s + " convert to all 0 features") + print_sys("rdkit not found this smiles: " + s + + " convert to all 0 features") features = np.zeros((2048,)) return np.array(features) @@ -210,7 +208,6 @@ def smiles2ECFP6(smiles): class MoleculeFingerprint: - """ Example: MolFP = MoleculeFingerprint(fp = 'ECFP6') @@ -239,10 +236,9 @@ def __init__(self, fp="ECFP4"): try: assert fp in fp2func except: - raise Exception( - "The fingerprint you specify are not supported. \ + raise Exception("The fingerprint you specify are not supported. \ It can only among 'ECFP2', 'ECFP4', 'ECFP6', 'MACCS', 'Daylight', 'RDKit2D', 'Morgan', 'PubChem'" - ) + ) self.fp = fp self.func = fp2func[fp] @@ -388,12 +384,11 @@ def onek_encoding_unk(x, allowable_set): def get_atom_features(atom): return torch.Tensor( - onek_encoding_unk(atom.GetSymbol(), ELEM_LIST) - + onek_encoding_unk(atom.GetDegree(), [0, 1, 2, 3, 4, 5]) - + onek_encoding_unk(atom.GetFormalCharge(), [-1, -2, 1, 2, 0]) - + onek_encoding_unk(int(atom.GetChiralTag()), [0, 1, 2, 3]) - + [atom.GetIsAromatic()] - ) + onek_encoding_unk(atom.GetSymbol(), ELEM_LIST) + + onek_encoding_unk(atom.GetDegree(), [0, 1, 2, 3, 4, 5]) + + onek_encoding_unk(atom.GetFormalCharge(), [-1, -2, 1, 2, 0]) + + onek_encoding_unk(int(atom.GetChiralTag()), [0, 1, 2, 3]) + + [atom.GetIsAromatic()]) def smiles2PyG(smiles): @@ -413,8 +408,9 @@ def smiles2PyG(smiles): atom_features = torch.stack(atom_features) y = [atom.GetSymbol() for atom in mol.GetAtoms()] y = list( - map(lambda x: ELEM_LIST.index(x) if x in ELEM_LIST else len(ELEM_LIST) - 1, y) - ) + map( + lambda x: ELEM_LIST.index(x) + if x in ELEM_LIST else len(ELEM_LIST) - 1, y)) y = torch.LongTensor(y) bond_features = [] for bond in mol.GetBonds(): @@ -438,6 +434,7 @@ def molfile2PyG(molfile): ############### PyG end ############### + ############### DGL begin ############### def smiles2DGL(smiles): """convert SMILES string into dgl.DGLGraph @@ -468,7 +465,6 @@ def smiles2DGL(smiles): ############### DGL end ############### - from ._xyz2mol import xyzfile2mol @@ -511,7 +507,8 @@ def xyzfile2selfies(xyzfile): def distance3d(coordinate_1, coordinate_2): - return np.sqrt(sum([(c1 - c2) ** 2 for c1, c2 in zip(coordinate_1, coordinate_2)])) + return np.sqrt( + sum([(c1 - c2)**2 for c1, c2 in zip(coordinate_1, coordinate_2)])) def upper_atom(atomsymbol): @@ -526,7 +523,9 @@ def 
xyzfile2graph3d(xyzfile): for j in range(i + 1, num_atoms): distance = distance3d(xyz_coordinates[i], xyz_coordinates[j]) distance_adj_matrix[i, j] = distance_adj_matrix[j, i] = distance - idx2atom = {idx: upper_atom(str_atom(atom)) for idx, atom in enumerate(atoms)} + idx2atom = { + idx: upper_atom(str_atom(atom)) for idx, atom in enumerate(atoms) + } mol, BO = xyzfile2mol(xyzfile) return idx2atom, distance_adj_matrix, BO @@ -599,9 +598,9 @@ def mol_conformer2graph3d(mol_conformer_lst): positions = np.concatenate(positions, 0) for i in range(atom_num): for j in range(i + 1, atom_num): - distance_adj_matrix[i, j] = distance_adj_matrix[j, i] = distance3d( - positions[i], positions[j] - ) + distance_adj_matrix[i, + j] = distance_adj_matrix[j, i] = distance3d( + positions[i], positions[j]) for bond in mol.GetBonds(): a1 = bond.GetBeginAtom().GetIdx() a2 = bond.GetEndAtom().GetIdx() @@ -687,6 +686,7 @@ def xyzfile2coulomb(xyzfile): # 2D_format = ['SMILES', 'SELFIES', 'Graph2D', 'PyG', 'DGL', 'ECFP2', 'ECFP4', 'ECFP6', 'MACCS', 'Daylight', 'RDKit2D', 'Morgan', 'PubChem'] # 3D_format = ['Graph3D', 'Coulumb'] + ## XXX2smiles def molfile2smiles(molfile): """convert molfile into SMILES string @@ -722,7 +722,6 @@ def mol2file2smiles(molfile): ## smiles2xxx - atom_types = ["C", "N", "O", "H", "F", "unknown"] ### Cl, S? @@ -868,7 +867,6 @@ def raw3D2pyg(raw3d_feature): class MolConvert: - """MolConvert: convert the molecule from src formet to dst format. @@ -902,7 +900,8 @@ def __init__(self, src="SMILES", dst="Graph2D", radius=2, nBits=1024): global sf except: - raise Exception("Please install selfies via 'pip install selfies'") + raise Exception( + "Please install selfies via 'pip install selfies'") if "Coulumb" == dst: try: @@ -1023,7 +1022,8 @@ def __call__(self, x): else: lst = [] for x0 in x: - lst.append(self.func(x0, radius=self._radius, nBits=self._nbits)) + lst.append( + self.func(x0, radius=self._radius, nBits=self._nbits)) out = lst if self._dst in fingerprints_list: out = np.array(out) diff --git a/tdc/chem_utils/oracle/docking.py b/tdc/chem_utils/oracle/docking.py index a92314a5..9a920c78 100644 --- a/tdc/chem_utils/oracle/docking.py +++ b/tdc/chem_utils/oracle/docking.py @@ -9,8 +9,11 @@ center = [float(i) for i in center] box_size = [sys.argv[7], sys.argv[8], sys.argv[9]] box_size = [float(i) for i in box_size] + + # print(ligand_pdbqt_file, receptor_pdbqt_file, output_file, center, box_size) -def docking(ligand_pdbqt_file, receptor_pdbqt_file, output_file, center, box_size): +def docking(ligand_pdbqt_file, receptor_pdbqt_file, output_file, center, + box_size): t1 = time() v = Vina(sf_name="vina") v.set_receptor(rigid_pdbqt_filename=receptor_pdbqt_file) @@ -25,7 +28,6 @@ def docking(ligand_pdbqt_file, receptor_pdbqt_file, output_file, center, box_siz docking(ligand_pdbqt_file, receptor_pdbqt_file, output_file, center, box_size) - """ Example: python XXXX.py data/1iep_ligand.pdbqt ./data/1iep_receptor.pdbqt ./data/out 15.190 53.903 16.917 20 20 20 diff --git a/tdc/chem_utils/oracle/filter.py b/tdc/chem_utils/oracle/filter.py index 06b4e0c3..47e6bde0 100644 --- a/tdc/chem_utils/oracle/filter.py +++ b/tdc/chem_utils/oracle/filter.py @@ -9,7 +9,8 @@ rdBase.DisableLog("rdApp.error") except: - raise ImportError("Please install rdkit by 'conda install -c conda-forge rdkit'! ") + raise ImportError( + "Please install rdkit by 'conda install -c conda-forge rdkit'! 
") from ...utils import print_sys, install @@ -73,16 +74,14 @@ def __init__( for i in filters: if i not in all_filters: raise ValueError( - i - + " not found; Please choose from a list of available filters from 'BMS', 'Dundee', 'Glaxo', 'Inpharmatica', 'LINT', 'MLSMR', 'PAINS', 'SureChEMBL'" + i + + " not found; Please choose from a list of available filters from 'BMS', 'Dundee', 'Glaxo', 'Inpharmatica', 'LINT', 'MLSMR', 'PAINS', 'SureChEMBL'" ) alert_file_name = pkg_resources.resource_filename( - "rd_filters", "data/alert_collection.csv" - ) + "rd_filters", "data/alert_collection.csv") rules_file_path = pkg_resources.resource_filename( - "rd_filters", "data/rules.json" - ) + "rd_filters", "data/rules.json") self.rf = RDFilters(alert_file_name) self.rule_dict = read_rules(rules_file_path) self.rule_dict["Rule_Inpharmatica"] = False @@ -163,15 +162,13 @@ def __call__(self, input_data): "Rot", ], ) - df_ok = df[ - (df.FILTER == "OK") - & df.MW.between(*self.rule_dict["MW"]) - & df.LogP.between(*self.rule_dict["LogP"]) - & df.HBD.between(*self.rule_dict["HBD"]) - & df.HBA.between(*self.rule_dict["HBA"]) - & df.TPSA.between(*self.rule_dict["TPSA"]) - & df.Rot.between(*self.rule_dict["Rot"]) - ] + df_ok = df[(df.FILTER == "OK") & + df.MW.between(*self.rule_dict["MW"]) & + df.LogP.between(*self.rule_dict["LogP"]) & + df.HBD.between(*self.rule_dict["HBD"]) & + df.HBA.between(*self.rule_dict["HBA"]) & + df.TPSA.between(*self.rule_dict["TPSA"]) & + df.Rot.between(*self.rule_dict["Rot"])] else: df = pd.DataFrame( diff --git a/tdc/chem_utils/oracle/oracle.py b/tdc/chem_utils/oracle/oracle.py index 64d93092..f29f1848 100644 --- a/tdc/chem_utils/oracle/oracle.py +++ b/tdc/chem_utils/oracle/oracle.py @@ -12,7 +12,6 @@ from packaging import version import pkg_resources - try: import rdkit from rdkit import Chem, DataStructs @@ -24,7 +23,8 @@ from rdkit.Chem import rdMolDescriptors from rdkit.six import iteritems except: - raise ImportError("Please install rdkit by 'conda install -c conda-forge rdkit'! ") + raise ImportError( + "Please install rdkit by 'conda install -c conda-forge rdkit'! ") try: from scipy.stats.mstats import gmean @@ -43,7 +43,8 @@ "geometric": gmean, "arithmetic": np.mean, } -SKLEARN_VERSION = version.parse(pkg_resources.get_distribution("scikit-learn").version) +SKLEARN_VERSION = version.parse( + pkg_resources.get_distribution("scikit-learn").version) def smiles_to_rdkit_mol(smiles): @@ -259,9 +260,11 @@ class ClippedScoreModifier(ScoreModifier): Then the generated values are clipped between low and high scores. """ - def __init__( - self, upper_x: float, lower_x=0.0, high_score=1.0, low_score=0.0 - ) -> None: + def __init__(self, + upper_x: float, + lower_x=0.0, + high_score=1.0, + low_score=0.0) -> None: """ Args: upper_x: x-value from which (or until which if smaller than lower_x) the score is maximal @@ -292,9 +295,11 @@ class SmoothClippedScoreModifier(ScoreModifier): center of the logistic function. 
""" - def __init__( - self, upper_x: float, lower_x=0.0, high_score=1.0, low_score=0.0 - ) -> None: + def __init__(self, + upper_x: float, + lower_x=0.0, + high_score=1.0, + low_score=0.0) -> None: """ Args: upper_x: x-value from which (or until which if smaller than lower_x) the score approaches high_score @@ -315,7 +320,8 @@ def __init__( self.L = high_score - low_score def __call__(self, x): - return self.low_score + self.L / (1 + np.exp(-self.k * (x - self.middle_x))) + return self.low_score + self.L / (1 + np.exp(-self.k * + (x - self.middle_x))) class ThresholdedLinearModifier(ScoreModifier): @@ -371,8 +377,7 @@ def calculateScore(m): # fragment score fp = rdMolDescriptors.GetMorganFingerprint( - m, 2 - ) # <- 2 is the *radius* of the circular fingerprint + m, 2) # <- 2 is the *radius* of the circular fingerprint fps = fp.GetNonzeroElements() score1 = 0.0 nf = 0 @@ -404,14 +409,8 @@ def calculateScore(m): if nMacrocycles > 0: macrocyclePenalty = math.log10(2) - score2 = ( - 0.0 - - sizePenalty - - stereoPenalty - - spiroPenalty - - bridgePenalty - - macrocyclePenalty - ) + score2 = (0.0 - sizePenalty - stereoPenalty - spiroPenalty - bridgePenalty - + macrocyclePenalty) # correction for the fingerprint density # not in the original publication, added in version 1.1 @@ -519,7 +518,8 @@ def cyp3a4_veith(smiles): try: from DeepPurpose import utils except: - raise ImportError("Please install DeepPurpose by 'pip install DeepPurpose'") + raise ImportError( + "Please install DeepPurpose by 'pip install DeepPurpose'") import os @@ -535,9 +535,10 @@ def cyp3a4_veith(smiles): X_drug = [smiles] drug_encoding = "CNN" y = [1] - X_pred = utils.data_process( - X_drug=X_drug, y=y, drug_encoding=drug_encoding, split_method="no_split" - ) + X_pred = utils.data_process(X_drug=X_drug, + y=y, + drug_encoding=drug_encoding, + split_method="no_split") # cyp3a4_veith_model = cyp3a4_veith_model.to("cuda:0") y_pred = cyp3a4_veith_model.predict(X_pred) return y_pred[0] @@ -564,8 +565,14 @@ def similarity(smiles_a, smiles_b): bmol = Chem.MolFromSmiles(smiles_b) if amol is None or bmol is None: return 0.0 - fp1 = AllChem.GetMorganFingerprintAsBitVect(amol, 2, nBits=2048, useChirality=False) - fp2 = AllChem.GetMorganFingerprintAsBitVect(bmol, 2, nBits=2048, useChirality=False) + fp1 = AllChem.GetMorganFingerprintAsBitVect(amol, + 2, + nBits=2048, + useChirality=False) + fp2 = AllChem.GetMorganFingerprintAsBitVect(bmol, + 2, + nBits=2048, + useChirality=False) return DataStructs.TanimotoSimilarity(fp1, fp2) @@ -709,6 +716,7 @@ def __call__(self, smiles): class AtomCounter: + def __init__(self, element): """ Args: @@ -789,6 +797,7 @@ def canonicalize(smiles: str, include_stereocenters=True): class Isomer_scoring_prev: + def __init__(self, target_smiles, means="geometric"): assert means in ["geometric", "arithmetic"] if means == "geometric": @@ -797,11 +806,11 @@ def __init__(self, target_smiles, means="geometric"): self.mean_func = np.mean atom2cnt_lst = parse_molecular_formula(target_smiles) total_atom_num = sum([cnt for atom, cnt in atom2cnt_lst]) - self.total_atom_modifier = GaussianModifier(mu=total_atom_num, sigma=2.0) - self.AtomCounter_Modifier_lst = [ - ((AtomCounter(atom)), GaussianModifier(mu=cnt, sigma=1.0)) - for atom, cnt in atom2cnt_lst - ] + self.total_atom_modifier = GaussianModifier(mu=total_atom_num, + sigma=2.0) + self.AtomCounter_Modifier_lst = [((AtomCounter(atom)), + GaussianModifier(mu=cnt, sigma=1.0)) + for atom, cnt in atom2cnt_lst] def __call__(self, test_smiles): molecule = 
smiles_to_rdkit_mol(test_smiles) @@ -818,6 +827,7 @@ def __call__(self, test_smiles): class Isomer_scoring: + def __init__(self, target_smiles, means="geometric"): assert means in ["geometric", "arithmetic"] if means == "geometric": @@ -826,11 +836,11 @@ def __init__(self, target_smiles, means="geometric"): self.mean_func = np.mean atom2cnt_lst = parse_molecular_formula(target_smiles) total_atom_num = sum([cnt for atom, cnt in atom2cnt_lst]) - self.total_atom_modifier = GaussianModifier(mu=total_atom_num, sigma=2.0) - self.AtomCounter_Modifier_lst = [ - ((AtomCounter(atom)), GaussianModifier(mu=cnt, sigma=1.0)) - for atom, cnt in atom2cnt_lst - ] + self.total_atom_modifier = GaussianModifier(mu=total_atom_num, + sigma=2.0) + self.AtomCounter_Modifier_lst = [((AtomCounter(atom)), + GaussianModifier(mu=cnt, sigma=1.0)) + for atom, cnt in atom2cnt_lst] def __call__(self, test_smiles): #### difference 1 @@ -864,29 +874,34 @@ def isomer_meta(target_smiles, means="geometric"): return Isomer_scoring(target_smiles, means=means) -isomers_c7h8n2o2_prev = isomer_meta_prev(target_smiles="C7H8N2O2", means="geometric") -isomers_c9h10n2o2pf2cl_prev = isomer_meta_prev( - target_smiles="C9H10N2O2PF2Cl", means="geometric" -) -isomers_c11h24_prev = isomer_meta_prev(target_smiles="C11H24", means="geometric") +isomers_c7h8n2o2_prev = isomer_meta_prev(target_smiles="C7H8N2O2", + means="geometric") +isomers_c9h10n2o2pf2cl_prev = isomer_meta_prev(target_smiles="C9H10N2O2PF2Cl", + means="geometric") +isomers_c11h24_prev = isomer_meta_prev(target_smiles="C11H24", + means="geometric") isomers_c7h8n2o2 = isomer_meta(target_smiles="C7H8N2O2", means="geometric") -isomers_c9h10n2o2pf2cl = isomer_meta(target_smiles="C9H10N2O2PF2Cl", means="geometric") +isomers_c9h10n2o2pf2cl = isomer_meta(target_smiles="C9H10N2O2PF2Cl", + means="geometric") isomers_c11h24 = isomer_meta(target_smiles="C11H24", means="geometric") class rediscovery_meta: + def __init__(self, target_smiles, fp="ECFP4"): self.similarity_func = fp2fpfunc[fp] self.target_fp = self.similarity_func(target_smiles) def __call__(self, test_smiles): test_fp = self.similarity_func(test_smiles) - similarity_value = DataStructs.TanimotoSimilarity(self.target_fp, test_fp) + similarity_value = DataStructs.TanimotoSimilarity( + self.target_fp, test_fp) return similarity_value class similarity_meta: + def __init__(self, target_smiles, fp="FCFP4", modifier_func=None): self.similarity_func = fp2fpfunc[fp] self.target_fp = self.similarity_func(target_smiles) @@ -894,7 +909,8 @@ def __init__(self, target_smiles, fp="FCFP4", modifier_func=None): def __call__(self, test_smiles): test_fp = self.similarity_func(test_smiles) - similarity_value = DataStructs.TanimotoSimilarity(self.target_fp, test_fp) + similarity_value = DataStructs.TanimotoSimilarity( + self.target_fp, test_fp) if self.modifier_func is None: modifier_score = similarity_value else: @@ -903,14 +919,14 @@ def __call__(self, test_smiles): celecoxib_rediscovery = rediscovery_meta( - target_smiles="CC1=CC=C(C=C1)C1=CC(=NN1C1=CC=C(C=C1)S(N)(=O)=O)C(F)(F)F", fp="ECFP4" -) + target_smiles="CC1=CC=C(C=C1)C1=CC(=NN1C1=CC=C(C=C1)S(N)(=O)=O)C(F)(F)F", + fp="ECFP4") troglitazone_rediscovery = rediscovery_meta( - target_smiles="Cc1c(C)c2OC(C)(COc3ccc(CC4SC(=O)NC4=O)cc3)CCc2c(C)c1O", fp="ECFP4" -) + target_smiles="Cc1c(C)c2OC(C)(COc3ccc(CC4SC(=O)NC4=O)cc3)CCc2c(C)c1O", + fp="ECFP4") thiothixene_rediscovery = rediscovery_meta( - target_smiles="CN(C)S(=O)(=O)c1ccc2Sc3ccccc3C(=CCCN4CCN(C)CC4)c2c1", fp="ECFP4" -) + 
target_smiles="CN(C)S(=O)(=O)c1ccc2Sc3ccccc3C(=CCCN4CCN(C)CC4)c2c1", + fp="ECFP4") similarity_modifier = ClippedScoreModifier(upper_x=0.75) aripiprazole_similarity = similarity_meta( @@ -933,6 +949,7 @@ def __call__(self, test_smiles): class median_meta: + def __init__( self, target_smiles_1, @@ -954,13 +971,12 @@ def __init__( def __call__(self, test_smiles): test_fp1 = self.similarity_func1(test_smiles) - test_fp2 = ( - test_fp1 - if self.similarity_func2 == self.similarity_func1 - else self.similarity_func2(test_smiles) - ) - similarity_value1 = DataStructs.TanimotoSimilarity(self.target_fp1, test_fp1) - similarity_value2 = DataStructs.TanimotoSimilarity(self.target_fp2, test_fp2) + test_fp2 = (test_fp1 if self.similarity_func2 == self.similarity_func1 + else self.similarity_func2(test_smiles)) + similarity_value1 = DataStructs.TanimotoSimilarity( + self.target_fp1, test_fp1) + similarity_value2 = DataStructs.TanimotoSimilarity( + self.target_fp2, test_fp2) if self.modifier_func1 is None: modifier_score1 = similarity_value1 else: @@ -1000,6 +1016,7 @@ def __call__(self, test_smiles): class MPO_meta: + def __init__(self, means): """ target_smiles, fp in ['ECFP4', 'AP', ..., ] @@ -1023,8 +1040,7 @@ def osimertinib_mpo(test_smiles): if "osimertinib_fp_fcfc4" not in globals().keys(): global osimertinib_fp_fcfc4, osimertinib_fp_ecfc6 osimertinib_smiles = ( - "COc1cc(N(C)CCN(C)C)c(NC(=O)C=C)cc1Nc2nccc(n2)c3cn(C)c4ccccc34" - ) + "COc1cc(N(C)CCN(C)C)c(NC(=O)C=C)cc1Nc2nccc(n2)c3cn(C)c4ccccc34") osimertinib_fp_fcfc4 = smiles_2_fingerprint_FCFP4(osimertinib_smiles) osimertinib_fp_ecfc6 = smiles_2_fingerprint_ECFP6(osimertinib_smiles) @@ -1039,13 +1055,12 @@ def osimertinib_mpo(test_smiles): tpsa_score = tpsa_modifier(Descriptors.TPSA(molecule)) logp_score = logp_modifier(Descriptors.MolLogP(molecule)) similarity_v1 = sim_v1_modifier( - DataStructs.TanimotoSimilarity(osimertinib_fp_fcfc4, fp_fcfc4) - ) + DataStructs.TanimotoSimilarity(osimertinib_fp_fcfc4, fp_fcfc4)) similarity_v2 = sim_v2_modifier( - DataStructs.TanimotoSimilarity(osimertinib_fp_ecfc6, fp_ecfc6) - ) + DataStructs.TanimotoSimilarity(osimertinib_fp_ecfc6, fp_ecfc6)) - osimertinib_gmean = gmean([tpsa_score, logp_score, similarity_v1, similarity_v2]) + osimertinib_gmean = gmean( + [tpsa_score, logp_score, similarity_v1, similarity_v2]) return osimertinib_gmean @@ -1053,8 +1068,7 @@ def fexofenadine_mpo(test_smiles): if "fexofenadine_fp" not in globals().keys(): global fexofenadine_fp fexofenadine_smiles = ( - "CC(C)(C(=O)O)c1ccc(cc1)C(O)CCCN2CCC(CC2)C(O)(c3ccccc3)c4ccccc4" - ) + "CC(C)(C(=O)O)c1ccc(cc1)C(O)CCCN2CCC(CC2)C(O)(c3ccccc3)c4ccccc4") fexofenadine_fp = smiles_2_fingerprint_AP(fexofenadine_smiles) similar_modifier = ClippedScoreModifier(upper_x=0.8) @@ -1066,8 +1080,7 @@ def fexofenadine_mpo(test_smiles): tpsa_score = tpsa_modifier(Descriptors.TPSA(molecule)) logp_score = logp_modifier(Descriptors.MolLogP(molecule)) similarity_value = similar_modifier( - DataStructs.TanimotoSimilarity(fp_ap, fexofenadine_fp) - ) + DataStructs.TanimotoSimilarity(fp_ap, fexofenadine_fp)) fexofenadine_gmean = gmean([tpsa_score, logp_score, similarity_value]) return fexofenadine_gmean @@ -1089,11 +1102,11 @@ def ranolazine_mpo(test_smiles): tpsa_score = tpsa_modifier(Descriptors.TPSA(molecule)) logp_score = logp_modifier(Descriptors.MolLogP(molecule)) similarity_value = similar_modifier( - DataStructs.TanimotoSimilarity(fp_ap, ranolazine_fp) - ) + DataStructs.TanimotoSimilarity(fp_ap, ranolazine_fp)) fluorine_value = 
fluorine_modifier(fluorine_counter(molecule)) - ranolazine_gmean = gmean([tpsa_score, logp_score, similarity_value, fluorine_value]) + ranolazine_gmean = gmean( + [tpsa_score, logp_score, similarity_value, fluorine_value]) return ranolazine_gmean @@ -1147,7 +1160,8 @@ def zaleplon_mpo_prev(test_smiles): global zaleplon_fp, isomer_scoring_C19H17N3O2 zaleplon_smiles = "O=C(C)N(CC)C1=CC=CC(C2=CC=NC3=C(C=NN23)C#N)=C1" zaleplon_fp = smiles_2_fingerprint_ECFP4(zaleplon_smiles) - isomer_scoring_C19H17N3O2 = Isomer_scoring_prev(target_smiles="C19H17N3O2") + isomer_scoring_C19H17N3O2 = Isomer_scoring_prev( + target_smiles="C19H17N3O2") fp = smiles_2_fingerprint_ECFP4(test_smiles) similarity_value = DataStructs.TanimotoSimilarity(fp, zaleplon_fp) @@ -1176,8 +1190,10 @@ def sitagliptin_mpo_prev(test_smiles): sitagliptin_mol = Chem.MolFromSmiles(sitagliptin_smiles) sitagliptin_logp = Descriptors.MolLogP(sitagliptin_mol) sitagliptin_tpsa = Descriptors.TPSA(sitagliptin_mol) - sitagliptin_logp_modifier = GaussianModifier(mu=sitagliptin_logp, sigma=0.2) - sitagliptin_tpsa_modifier = GaussianModifier(mu=sitagliptin_tpsa, sigma=5) + sitagliptin_logp_modifier = GaussianModifier(mu=sitagliptin_logp, + sigma=0.2) + sitagliptin_tpsa_modifier = GaussianModifier(mu=sitagliptin_tpsa, + sigma=5) isomers_scoring_C16H15F6N5O = Isomer_scoring_prev("C16H15F6N5O") sitagliptin_similar_modifier = GaussianModifier(mu=0, sigma=0.1) @@ -1189,8 +1205,7 @@ def sitagliptin_mpo_prev(test_smiles): tpsa_score = sitagliptin_tpsa_modifier(tpsa_score) isomer_score = isomers_scoring_C16H15F6N5O(test_smiles) similarity_value = sitagliptin_similar_modifier( - DataStructs.TanimotoSimilarity(fp_ecfp4, sitagliptin_fp_ecfp4) - ) + DataStructs.TanimotoSimilarity(fp_ecfp4, sitagliptin_fp_ecfp4)) return gmean([similarity_value, logp_score, tpsa_score, isomer_score]) @@ -1202,8 +1217,10 @@ def sitagliptin_mpo(test_smiles): sitagliptin_mol = Chem.MolFromSmiles(sitagliptin_smiles) sitagliptin_logp = Descriptors.MolLogP(sitagliptin_mol) sitagliptin_tpsa = Descriptors.TPSA(sitagliptin_mol) - sitagliptin_logp_modifier = GaussianModifier(mu=sitagliptin_logp, sigma=0.2) - sitagliptin_tpsa_modifier = GaussianModifier(mu=sitagliptin_tpsa, sigma=5) + sitagliptin_logp_modifier = GaussianModifier(mu=sitagliptin_logp, + sigma=0.2) + sitagliptin_tpsa_modifier = GaussianModifier(mu=sitagliptin_tpsa, + sigma=5) isomers_scoring_C16H15F6N5O = Isomer_scoring("C16H15F6N5O") sitagliptin_similar_modifier = GaussianModifier(mu=0, sigma=0.1) @@ -1215,8 +1232,7 @@ def sitagliptin_mpo(test_smiles): tpsa_score = sitagliptin_tpsa_modifier(tpsa_score) isomer_score = isomers_scoring_C16H15F6N5O(test_smiles) similarity_value = sitagliptin_similar_modifier( - DataStructs.TanimotoSimilarity(fp_ecfp4, sitagliptin_fp_ecfp4) - ) + DataStructs.TanimotoSimilarity(fp_ecfp4, sitagliptin_fp_ecfp4)) return gmean([similarity_value, logp_score, tpsa_score, isomer_score]) @@ -1228,6 +1244,7 @@ def get_PHCO_fingerprint(mol): class SMARTS_scoring: + def __init__(self, target_smarts, inverse): self.target_mol = Chem.MolFromSmarts(target_smarts) self.inverse = inverse @@ -1253,12 +1270,10 @@ def deco_hop(test_smiles): pharmacophor_mol = smiles_to_rdkit_mol(pharmacophor_smiles) pharmacophor_fp = get_PHCO_fingerprint(pharmacophor_mol) - deco1_smarts_scoring = SMARTS_scoring( - target_smarts="CS([#6])(=O)=O", inverse=True - ) + deco1_smarts_scoring = SMARTS_scoring(target_smarts="CS([#6])(=O)=O", + inverse=True) deco2_smarts_scoring = SMARTS_scoring( - target_smarts="[#7]-c1ccc2ncsc2c1", 
inverse=True - ) + target_smarts="[#7]-c1ccc2ncsc2c1", inverse=True) scaffold_smarts_scoring = SMARTS_scoring( target_smarts="[#7]-c1n[c;h1]nc2[c;h1]c(-[#8])[c;h0][c;h1]c12", inverse=False, @@ -1269,43 +1284,41 @@ def deco_hop(test_smiles): similarity_modifier = ClippedScoreModifier(upper_x=0.85) similarity_value = similarity_modifier( - DataStructs.TanimotoSimilarity(fp, pharmacophor_fp) - ) + DataStructs.TanimotoSimilarity(fp, pharmacophor_fp)) deco1_score = deco1_smarts_scoring(molecule) deco2_score = deco2_smarts_scoring(molecule) scaffold_score = scaffold_smarts_scoring(molecule) - all_scores = np.mean([similarity_value, deco1_score, deco2_score, scaffold_score]) + all_scores = np.mean( + [similarity_value, deco1_score, deco2_score, scaffold_score]) return all_scores def scaffold_hop(test_smiles): - if ( - "pharmacophor_fp" not in globals().keys() - or "scaffold_smarts_scoring" not in globals().keys() - or "deco_smarts_scoring" not in globals().keys() - ): + if ("pharmacophor_fp" not in globals().keys() or + "scaffold_smarts_scoring" not in globals().keys() or + "deco_smarts_scoring" not in globals().keys()): global pharmacophor_fp, deco_smarts_scoring, scaffold_smarts_scoring pharmacophor_smiles = "CCCOc1cc2ncnc(Nc3ccc4ncsc4c3)c2cc1S(=O)(=O)C(C)(C)C" pharmacophor_mol = smiles_to_rdkit_mol(pharmacophor_smiles) pharmacophor_fp = get_PHCO_fingerprint(pharmacophor_mol) deco_smarts_scoring = SMARTS_scoring( - target_smarts="[#6]-[#6]-[#6]-[#8]-[#6]~[#6]~[#6]~[#6]~[#6]-[#7]-c1ccc2ncsc2c1", + target_smarts= + "[#6]-[#6]-[#6]-[#8]-[#6]~[#6]~[#6]~[#6]~[#6]-[#7]-c1ccc2ncsc2c1", inverse=False, ) scaffold_smarts_scoring = SMARTS_scoring( - target_smarts="[#7]-c1n[c;h1]nc2[c;h1]c(-[#8])[c;h0][c;h1]c12", inverse=True - ) + target_smarts="[#7]-c1n[c;h1]nc2[c;h1]c(-[#8])[c;h0][c;h1]c12", + inverse=True) molecule = smiles_to_rdkit_mol(test_smiles) fp = get_PHCO_fingerprint(molecule) similarity_modifier = ClippedScoreModifier(upper_x=0.75) similarity_value = similarity_modifier( - DataStructs.TanimotoSimilarity(fp, pharmacophor_fp) - ) + DataStructs.TanimotoSimilarity(fp, pharmacophor_fp)) deco_score = deco_smarts_scoring(molecule) scaffold_score = scaffold_smarts_scoring(molecule) @@ -1349,8 +1362,6 @@ def valsartan_smarts(test_smiles): ########################################################################### ### END of Guacamol ########################################################################### - - """ Synthesizability from a full retrosynthetic analysis Including: @@ -1502,9 +1513,9 @@ def askcos( # For each entry, repeat to test up to num_trials times if got error message for _ in range(num_trials): print("Trying to send the request, for the %i times now" % (_ + 1)) - resp = requests.get( - host_ip + "/api/treebuilder/", params=params, verify=False - ) + resp = requests.get(host_ip + "/api/treebuilder/", + params=params, + verify=False) if "error" not in resp.json().keys(): break @@ -1513,8 +1524,7 @@ def askcos( json.dump(resp.json(), f_data) num_path, status, depth, p_score, synthesizability, price = tree_analysis( - resp.json() - ) + resp.json()) if output == "plausibility": return p_score @@ -1539,13 +1549,13 @@ def ibm_rxn(smiles, api_key, output="confidence", sleep_time=30): rxn4chemistry_wrapper = RXN4ChemistryWrapper(api_key=api_key) response = rxn4chemistry_wrapper.create_project("test") time.sleep(sleep_time) - response = rxn4chemistry_wrapper.predict_automatic_retrosynthesis(product=smiles) + response = rxn4chemistry_wrapper.predict_automatic_retrosynthesis( + product=smiles) 
status = "" while status != "SUCCESS": time.sleep(sleep_time) results = rxn4chemistry_wrapper.get_predict_automatic_retrosynthesis_results( - response["prediction_id"] - ) + response["prediction_id"]) status = results["status"] if output == "confidence": @@ -1557,19 +1567,22 @@ def ibm_rxn(smiles, api_key, output="confidence", sleep_time=30): class molecule_one_retro: + def __init__(self, api_token): try: from m1wrapper import MoleculeOneWrapper except: try: - install("git+https://github.com/molecule-one/m1wrapper-python@v1") + install( + "git+https://github.com/molecule-one/m1wrapper-python@v1") from m1wrapper import MoleculeOneWrapper except: raise ImportError( "Install Molecule.One Wrapper via pip install git+https://github.com/molecule-one/m1wrapper-python@v1" ) - self.m1wrapper = MoleculeOneWrapper(api_token, "https://tdc.molecule.one") + self.m1wrapper = MoleculeOneWrapper(api_token, + "https://tdc.molecule.one") def __call__(self, smiles): if isinstance(smiles, str): @@ -1577,7 +1590,10 @@ def __call__(self, smiles): search = self.m1wrapper.run_batch_search( targets=smiles, - parameters={"exploratory_search": False, "detail_level": "score"}, + parameters={ + "exploratory_search": False, + "detail_level": "score" + }, ) status_cur = search.get_status() @@ -1594,7 +1610,8 @@ def __call__(self, smiles): if status_cur != status: print_sys(status) status_cur = status - result = search.get_results(precision=5, only=["targetSmiles", "result"]) + result = search.get_results(precision=5, + only=["targetSmiles", "result"]) return {i["targetSmiles"]: i["result"] for i in result} @@ -1666,7 +1683,11 @@ def __call__(self, test_smiles, error_value=None): class Score_3d: """Evaluate Vina score (force field) for a conformer binding to a receptor""" - def __init__(self, receptor_pdbqt_file, center, box_size, scorefunction="vina"): + def __init__(self, + receptor_pdbqt_file, + center, + box_size, + scorefunction="vina"): try: from vina import Vina except: @@ -1702,7 +1723,11 @@ def __call__(self, ligand_pdbqt_file, minimize=True): class Vina_3d: """Perform docking search from a conformer.""" - def __init__(self, receptor_pdbqt_file, center, box_size, scorefunction="vina"): + def __init__(self, + receptor_pdbqt_file, + center, + box_size, + scorefunction="vina"): try: from vina import Vina except: @@ -1722,9 +1747,11 @@ def __init__(self, receptor_pdbqt_file, center, box_size, scorefunction="vina"): "Cannot compute the affinity map, please check center and box_size" ) - def __call__( - self, ligand_pdbqt_file, output_file="out.pdbqt", exhaustiveness=8, n_poses=10 - ): + def __call__(self, + ligand_pdbqt_file, + output_file="out.pdbqt", + exhaustiveness=8, + n_poses=10): try: self.v.set_ligand_from_file(ligand_pdbqt_file) self.v.dock(exhaustiveness=exhaustiveness, n_poses=n_poses) @@ -1739,7 +1766,11 @@ def __call__( class Vina_smiles: """Perform docking search from a conformer.""" - def __init__(self, receptor_pdbqt_file, center, box_size, scorefunction="vina"): + def __init__(self, + receptor_pdbqt_file, + center, + box_size, + scorefunction="vina"): try: from vina import Vina except: @@ -1760,9 +1791,11 @@ def __init__(self, receptor_pdbqt_file, center, box_size, scorefunction="vina"): "Cannot compute the affinity map, please check center and box_size" ) - def __call__( - self, ligand_smiles, output_file="out.pdbqt", exhaustiveness=8, n_poses=10 - ): + def __call__(self, + ligand_smiles, + output_file="out.pdbqt", + exhaustiveness=8, + n_poses=10): try: m = Chem.MolFromSmiles(ligand_smiles) m = 
Chem.AddHs(m) @@ -1818,15 +1851,12 @@ def smina(ligand, protein, score_only=False, raw_input=False): f.write("%d\n\n" % n_atoms) for atom_i in range(n_atoms): atom = mol_atom[atom_i] - f.write( - "%s %.9f %.9f %.9f\n" - % ( - atom, - mol_coord[atom_i, 0], - mol_coord[atom_i, 1], - mol_coord[atom_i, 2], - ) - ) + f.write("%s %.9f %.9f %.9f\n" % ( + atom, + mol_coord[atom_i, 0], + mol_coord[atom_i, 1], + mol_coord[atom_i, 2], + )) f.close() # 2. convert to sdf file try: @@ -1838,8 +1868,8 @@ def smina(ligand, protein, score_only=False, raw_input=False): ligand = "temp_ligand.sdf" if score_only: msg = os.popen( - f"./{smina_model_path} -l {ligand} -r {protein} --score_only" - ).read() + f"./{smina_model_path} -l {ligand} -r {protein} --score_only").read( + ) return float(msg.split("\n")[-7].split(" ")[-2]) else: os.system(f"./{smina_model_path} -l {ligand} -r {protein} --score_only") diff --git a/tdc/evaluator.py b/tdc/evaluator.py index e0807d27..51c3ceb0 100644 --- a/tdc/evaluator.py +++ b/tdc/evaluator.py @@ -126,8 +126,8 @@ def range_logAUC(true_y, predicted_score, FPR_range=(0.001, 0.1)): upper_bound_idx = np.where(x == upper_bound)[-1][-1] # Create a new array trimmed at the lower and upper bound - trim_x = x[lower_bound_idx : upper_bound_idx + 1] - trim_y = y[lower_bound_idx : upper_bound_idx + 1] + trim_x = x[lower_bound_idx:upper_bound_idx + 1] + trim_y = y[lower_bound_idx:upper_bound_idx + 1] area = auc(trim_x, trim_y) / (upper_bound - lower_bound) return area @@ -371,7 +371,6 @@ def kabsch_weighted(P, Q, W=None): class Evaluator: - """evaluator to evaluate predictions Args: diff --git a/tdc/generation/bi_generation_dataset.py b/tdc/generation/bi_generation_dataset.py index 229a9b20..6c95bce3 100644 --- a/tdc/generation/bi_generation_dataset.py +++ b/tdc/generation/bi_generation_dataset.py @@ -15,7 +15,6 @@ class DataLoader(base_dataset.DataLoader): - """A base dataset loader class. Attributes: @@ -34,7 +33,9 @@ def __init__( threshold=15, remove_Hs=True, keep_het=False, - allowed_atom_list=["C", "N", "O", "S", "H", "B", "Br", "Cl", "P", "I", "F"], + allowed_atom_list=[ + "C", "N", "O", "S", "H", "B", "Br", "Cl", "P", "I", "F" + ], ): """To create a base dataloader object that each generation task can inherit from. @@ -115,6 +116,7 @@ def get_split(self, method="random", seed=42, frac=[0.7, 0.1, 0.2]): protein, ligand = data["protein"], data["ligand"] if method == "random": - return create_combination_generation_split(protein, ligand, seed, frac) + return create_combination_generation_split(protein, ligand, seed, + frac) else: raise AttributeError("Please use the correct split method") diff --git a/tdc/generation/generation_dataset.py b/tdc/generation/generation_dataset.py index cf008761..704eef19 100644 --- a/tdc/generation/generation_dataset.py +++ b/tdc/generation/generation_dataset.py @@ -20,7 +20,6 @@ class DataLoader(base_dataset.DataLoader): - """A base dataset loader class. 
Attributes: @@ -42,8 +41,7 @@ def __init__(self, name, path, print_stats, column_name): from ..metadata import single_molecule_dataset_names self.smiles_lst = distribution_dataset_load( - name, path, single_molecule_dataset_names, column_name=column_name - ) + name, path, single_molecule_dataset_names, column_name=column_name) ### including fuzzy-search self.name = name self.path = path @@ -102,7 +100,6 @@ def get_split(self, method="random", seed=42, frac=[0.7, 0.1, 0.2]): class PairedDataLoader(base_dataset.DataLoader): - """A basic class for generation of biomedical entities conditioned on other entities, such as reaction prediction. Attributes: @@ -125,8 +122,8 @@ def __init__(self, name, path, print_stats, input_name, output_name): from ..metadata import paired_dataset_names self.input_smiles_lst, self.output_smiles_lst = generation_paired_dataset_load( - name, path, paired_dataset_names, input_name, output_name - ) ### including fuzzy-search + name, path, paired_dataset_names, input_name, + output_name) ### including fuzzy-search self.name = name self.path = path self.dataset_names = paired_dataset_names @@ -155,11 +152,15 @@ def get_data(self, format="df"): AttributeError: Use the correct format as input (df, dict) """ if format == "df": - return pd.DataFrame( - {"input": self.input_smiles_lst, "output": self.output_smiles_lst} - ) + return pd.DataFrame({ + "input": self.input_smiles_lst, + "output": self.output_smiles_lst + }) elif format == "dict": - return {"input": self.input_smiles_lst, "output": self.output_smiles_lst} + return { + "input": self.input_smiles_lst, + "output": self.output_smiles_lst + } else: raise AttributeError("Please use the correct format input") @@ -187,7 +188,6 @@ def get_split(self, method="random", seed=42, frac=[0.7, 0.1, 0.2]): class DataLoader3D(base_dataset.DataLoader): - """A basic class for generation of 3D biomedical entities. (under construction) Attributes: @@ -209,8 +209,7 @@ def __init__(self, name, path, print_stats, dataset_names, column_name): column_name (str): The name of the column containing smiles strings. """ self.df, self.path, self.name = three_dim_dataset_load( - name, path, dataset_names - ) + name, path, dataset_names) if print_stats: self.print_stats() print_sys("Done!") @@ -240,7 +239,8 @@ def get_data(self, format="df", more_features="None"): """ if more_features in ["None", "SMILES"]: pass - elif more_features in ["Graph3D", "Coulumb", "SELFIES"]: # why SELFIES here? + elif more_features in ["Graph3D", "Coulumb", + "SELFIES"]: # why SELFIES here? 
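The `MolConvert` helper invoked just below is also usable on its own. A short sketch mirroring the converter calls exercised in `test_molconverter.py` later in this diff (same `src`/`dst` pair and inputs as that test):

```python
# Sketch: standalone use of the MolConvert helper that DataLoader3D calls
# below; mirrors tdc/test/dev_tests/chem_utils_test/test_molconverter.py.
from tdc.chem_utils import MolConvert

converter = MolConvert(src="SMILES", dst="Graph2D")
graphs = converter([
    "Clc1ccccc1C2C(=C(/N/C(=C2/C(=O)OCC)COCCN)C)\C(=O)OC",
    "CCCOc1cc2ncnc(Nc3ccc4ncsc4c3)c2cc1S(=O)(=O)C(C)(C)C",
])

MolConvert.eligible_format()  # shows the supported src/dst formats
```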
try: from rdkit.Chem.PandasTools import LoadSDF from rdkit import rdBase @@ -256,7 +256,8 @@ def get_data(self, format="df", more_features="None"): convert = MolConvert(src="SDF", dst=more_features) for i in sdf_file_names[self.name]: - self.df[i + "_" + more_features] = convert(self.path + i + ".sdf") + self.df[i + "_" + more_features] = convert(self.path + i + + ".sdf") if format == "df": return self.df diff --git a/tdc/generation/ligandmolgen.py b/tdc/generation/ligandmolgen.py index aeecd538..18ec63af 100644 --- a/tdc/generation/ligandmolgen.py +++ b/tdc/generation/ligandmolgen.py @@ -11,7 +11,6 @@ class LigandMolGen(bi_generation_dataset.DataLoader): - """Data loader class accessing to pocket-based ligand generation task.""" def __init__(self, name, path="./data", print_stats=False): diff --git a/tdc/generation/molgen.py b/tdc/generation/molgen.py index 0c75c486..8e76cb5c 100644 --- a/tdc/generation/molgen.py +++ b/tdc/generation/molgen.py @@ -11,10 +11,13 @@ class MolGen(generation_dataset.DataLoader): - """Data loader class accessing to molecular generation task (distribution learning)""" - def __init__(self, name, path="./data", print_stats=False, column_name="smiles"): + def __init__(self, + name, + path="./data", + print_stats=False, + column_name="smiles"): """To create an data loader object for molecular generation task. The goal is to generate diverse, novel molecules that has desirable chemical properties. One can combined with oracle functions. diff --git a/tdc/generation/reaction.py b/tdc/generation/reaction.py index 79f11080..4a3d83ca 100644 --- a/tdc/generation/reaction.py +++ b/tdc/generation/reaction.py @@ -11,7 +11,6 @@ class Reaction(generation_dataset.PairedDataLoader): - """Data loader class accessing to forward reaction prediction task.""" def __init__( diff --git a/tdc/generation/retrosyn.py b/tdc/generation/retrosyn.py index df1232c8..2420ce26 100644 --- a/tdc/generation/retrosyn.py +++ b/tdc/generation/retrosyn.py @@ -12,7 +12,6 @@ class RetroSyn(generation_dataset.PairedDataLoader): - """Data loader class accessing to retro-synthetic prediction task.""" def __init__( @@ -66,10 +65,8 @@ def get_split( df["reaction_type"] = rt except: raise ValueError( - "Reaction Type Unavailable for " - + str(self.name) - + "! Please turn include_reaction_type to be false!" - ) + "Reaction Type Unavailable for " + str(self.name) + + "! Please turn include_reaction_type to be false!") if method == "random": return create_fold(df, seed, frac) diff --git a/tdc/generation/sbdd.py b/tdc/generation/sbdd.py index 80368335..a697049d 100644 --- a/tdc/generation/sbdd.py +++ b/tdc/generation/sbdd.py @@ -16,7 +16,6 @@ class SBDD(base_dataset.DataLoader): - """Data loader class accessing to structure-based drug design task.""" def __init__( @@ -51,7 +50,8 @@ def __init__( try: import biopandas except: - raise ImportError("Please install biopandas by 'pip install biopandas'! ") + raise ImportError( + "Please install biopandas by 'pip install biopandas'! 
") protein, ligand = bi_distribution_dataset_load( name, path, @@ -126,7 +126,8 @@ def get_split(self, method="random", seed=42, frac=[0.7, 0.1, 0.2]): data = self.get_data(format="dict") protein, ligand = data["protein"], data["ligand"] - splitted_data = create_combination_generation_split(protein, ligand, seed, frac) + splitted_data = create_combination_generation_split( + protein, ligand, seed, frac) if self.save: np.savez( diff --git a/tdc/metadata.py b/tdc/metadata.py index 9ad5f686..26050b67 100644 --- a/tdc/metadata.py +++ b/tdc/metadata.py @@ -3,7 +3,6 @@ # License: MIT from packaging import version import pkg_resources - """This file contains all metadata of datasets in TDC. Attributes: @@ -121,18 +120,16 @@ "clearance_microsome_az", ] -hts_dataset_names = ["hiv", - "sarscov2_3clpro_diamond", - "sarscov2_vitro_touret", - "orexin1_receptor_butkiewicz", - "m1_muscarinic_receptor_agonists_butkiewicz", - "m1_muscarinic_receptor_antagonists_butkiewicz", - "potassium_ion_channel_kir2.1_butkiewicz", - "kcnq2_potassium_channel_butkiewicz", - "cav3_t-type_calcium_channels_butkiewicz", - "choline_transporter_butkiewicz", - "serine_threonine_kinase_33_butkiewicz", - "tyrosyl-dna_phosphodiesterase_butkiewicz"] +hts_dataset_names = [ + "hiv", "sarscov2_3clpro_diamond", "sarscov2_vitro_touret", + "orexin1_receptor_butkiewicz", "m1_muscarinic_receptor_agonists_butkiewicz", + "m1_muscarinic_receptor_antagonists_butkiewicz", + "potassium_ion_channel_kir2.1_butkiewicz", + "kcnq2_potassium_channel_butkiewicz", + "cav3_t-type_calcium_channels_butkiewicz", "choline_transporter_butkiewicz", + "serine_threonine_kinase_33_butkiewicz", + "tyrosyl-dna_phosphodiesterase_butkiewicz" +] qm_dataset_names = ["qm7", "qm7b", "qm8", "qm9"] @@ -142,7 +139,6 @@ develop_dataset_names = ["tap", "sabdab_chen"] - # multi_pred prediction dti_dataset_names = [ @@ -183,7 +179,6 @@ #################################### # generation - retrosyn_dataset_names = ["uspto50k", "uspto"] forwardsyn_dataset_names = ["uspto"] @@ -194,7 +189,6 @@ paired_dataset_names = ["uspto50k", "uspto"] - #################################### # resource @@ -287,7 +281,6 @@ "scaffold_hop", ] - #################################### # Benchmark Datasets @@ -342,7 +335,10 @@ } docking_target_info = { - "3pbl": {"center": (9, 22.5, 26), "size": (15, 15, 15)}, + "3pbl": { + "center": (9, 22.5, 26), + "size": (15, 15, 15) + }, "1iep": { "center": (15.61389189189189, 53.38013513513513, 15.454837837837842), "size": (15, 15, 15), @@ -355,7 +351,10 @@ "center": (-9.063639999999998, -7.1446, 55.86259999999999), "size": (15, 15, 15), }, - "3ny8": {"center": (2.2488, 4.68495, 51.39820000000001), "size": (15, 15, 15)}, + "3ny8": { + "center": (2.2488, 4.68495, 51.39820000000001), + "size": (15, 15, 15) + }, "4rlu": { "center": (-0.7359999999999999, 22.75547368421052, -31.2368947368421), "size": (15, 15, 15), @@ -460,13 +459,16 @@ #################################### # evaluator for single molecule, the input of __call__ is a single smiles OR list of smiles -download_oracle_names = ["drd2", "gsk3b", "jnk3", "fpscores", "cyp3a4_veith", "smina"] -# download_oracle_names = ['drd2', 'gsk3b', 'jnk3', 'fpscores', 'cyp3a4_veith'] -download_oracle_names = ["drd2", "gsk3b", "jnk3", "fpscores", "cyp3a4_veith"] + [ - "drd2_current", - "gsk3b_current", - "jnk3_current", +download_oracle_names = [ + "drd2", "gsk3b", "jnk3", "fpscores", "cyp3a4_veith", "smina" ] +# download_oracle_names = ['drd2', 'gsk3b', 'jnk3', 'fpscores', 'cyp3a4_veith'] +download_oracle_names = ["drd2", 
"gsk3b", "jnk3", "fpscores", "cyp3a4_veith" + ] + [ + "drd2_current", + "gsk3b_current", + "jnk3_current", + ] trivial_oracle_names = ["qed", "logp", "sa"] + guacamol_oracle synthetic_oracle_name = ["askcos", "ibm_rxn"] @@ -503,7 +505,6 @@ "3pbl_docking_vina", ] - meta_oracle_name = [ "isomer_meta", "rediscovery_meta", @@ -514,24 +515,16 @@ "pyscreener", ] -oracle_names = ( - download_oracle_names - + trivial_oracle_names - + distribution_oracles - + synthetic_oracle_name - + meta_oracle_name - + docking_oracles - + download_receptor_oracle_name -) +oracle_names = (download_oracle_names + trivial_oracle_names + + distribution_oracles + synthetic_oracle_name + + meta_oracle_name + docking_oracles + + download_receptor_oracle_name) molgenpaired_dataset_names = ["qed", "drd2", "logp"] -generation_datasets = ( - retrosyn_dataset_names - + forwardsyn_dataset_names - + molgenpaired_dataset_names - + multiple_molecule_dataset_names -) +generation_datasets = (retrosyn_dataset_names + forwardsyn_dataset_names + + molgenpaired_dataset_names + + multiple_molecule_dataset_names) # generation #################################### @@ -559,7 +552,7 @@ "GDA", "Catalyst", "TCR_Epitope_Binding", - "TrialOutcome", + "TrialOutcome", ], "generation": ["RetroSyn", "Reaction", "MolGen"], } @@ -599,8 +592,8 @@ def get_task2category(): "CRISPROutcome": crisproutcome_dataset_names, "test_single_pred": test_single_pred_dataset_names, "test_multi_pred": test_multi_pred_dataset_names, - "TCREpitopeBinding": tcr_epi_dataset_names, - "TrialOutcome": trial_outcome_dataset_names, + "TCREpitopeBinding": tcr_epi_dataset_names, + "TrialOutcome": trial_outcome_dataset_names, } benchmark_names = { @@ -662,14 +655,14 @@ def get_task2category(): "hiv": "tab", "sarscov2_3clpro_diamond": "tab", "sarscov2_vitro_touret": "tab", - "orexin1_receptor_butkiewicz": "tab", - "m1_muscarinic_receptor_agonists_butkiewicz": "tab", + "orexin1_receptor_butkiewicz": "tab", + "m1_muscarinic_receptor_agonists_butkiewicz": "tab", "m1_muscarinic_receptor_antagonists_butkiewicz": "tab", - "potassium_ion_channel_kir2.1_butkiewicz": "tab", - "kcnq2_potassium_channel_butkiewicz": "tab", - "cav3_t-type_calcium_channels_butkiewicz": "tab", - "choline_transporter_butkiewicz": "tab", - "serine_threonine_kinase_33_butkiewicz": "tab", + "potassium_ion_channel_kir2.1_butkiewicz": "tab", + "kcnq2_potassium_channel_butkiewicz": "tab", + "cav3_t-type_calcium_channels_butkiewicz": "tab", + "choline_transporter_butkiewicz": "tab", + "serine_threonine_kinase_33_butkiewicz": "tab", "tyrosyl-dna_phosphodiesterase_butkiewicz": "tab", "davis": "tab", "kiba": "tab", @@ -735,10 +728,10 @@ def get_task2category(): "primekg": "tab", "primekg_drug_feature": "tab", "primekg_disease_feature": "tab", - "drug_comb_meta_data": "pkl", + "drug_comb_meta_data": "pkl", "phase1": "tab", - "phase2": "tab", - "phase3": "tab", + "phase2": "tab", + "phase3": "tab", } name2id = { @@ -782,14 +775,14 @@ def get_task2category(): "ppbr_ma": 4259603, "sarscov2_3clpro_diamond": 4259606, "sarscov2_vitro_touret": 4259607, - "orexin1_receptor_butkiewicz": 6894447, - "m1_muscarinic_receptor_agonists_butkiewicz": 6894443, + "orexin1_receptor_butkiewicz": 6894447, + "m1_muscarinic_receptor_agonists_butkiewicz": 6894443, "m1_muscarinic_receptor_antagonists_butkiewicz": 6894446, - "potassium_ion_channel_kir2.1_butkiewicz": 6894442, - "kcnq2_potassium_channel_butkiewicz": 6894444, - "cav3_t-type_calcium_channels_butkiewicz": 6894445, - "choline_transporter_butkiewicz": 6894441, - 
"serine_threonine_kinase_33_butkiewicz": 6894448, + "potassium_ion_channel_kir2.1_butkiewicz": 6894442, + "kcnq2_potassium_channel_butkiewicz": 6894444, + "cav3_t-type_calcium_channels_butkiewicz": 6894445, + "choline_transporter_butkiewicz": 6894441, + "serine_threonine_kinase_33_butkiewicz": 6894448, "tyrosyl-dna_phosphodiesterase_butkiewicz": 6894440, "solubility_aqsoldb": 4259610, "tox21": 4259612, @@ -849,10 +842,10 @@ def get_task2category(): "primekg": 6180626, "primekg_drug_feature": 6180619, "primekg_disease_feature": 6180618, - "drug_comb_meta_data": 7104245, - "phase1": 7331305, - "phase2": 7331306, - "phase3": 7331307, + "drug_comb_meta_data": 7104245, + "phase1": 7331305, + "phase2": 7331306, + "phase3": 7331307, } oracle2type = { @@ -880,7 +873,6 @@ def get_task2category(): "gsk3b_current": 6413412, } - benchmark2type = { "admet_group": "zip", "drugcombo_group": "zip", @@ -907,7 +899,6 @@ def get_task2category(): "3pbl": [5257195, 5617666], } ## 'drd3': 5137901, - sdf_file_names = {"grambow": ["Product", "Reactant", "TS"]} name2stats = { diff --git a/tdc/multi_pred/antibodyaff.py b/tdc/multi_pred/antibodyaff.py index 85e68774..cfca5143 100644 --- a/tdc/multi_pred/antibodyaff.py +++ b/tdc/multi_pred/antibodyaff.py @@ -13,7 +13,6 @@ class AntibodyAff(bi_pred_dataset.DataLoader): - """Data loader class to load datasets in Antibody-antigen Affinity Prediction task. More info: https://tdcommons.ai/multi_pred_tasks/antibodyaff/ diff --git a/tdc/multi_pred/bi_pred_dataset.py b/tdc/multi_pred/bi_pred_dataset.py index 8c00be44..5fc788cc 100644 --- a/tdc/multi_pred/bi_pred_dataset.py +++ b/tdc/multi_pred/bi_pred_dataset.py @@ -26,7 +26,6 @@ class DataLoader(base_dataset.DataLoader): - """A base data loader class that each bi-instance prediction task dataloader class can inherit from. Attributes: TODO @@ -52,10 +51,8 @@ def __init__(self, name, path, label_name, print_stats, dataset_names): if label_name is None: raise ValueError( "Please select a label name. " - "You can use tdc.utils.retrieve_label_name_list('" - + name.lower() - + "') to retrieve all available label names." 
- ) + "You can use tdc.utils.retrieve_label_name_list('" + + name.lower() + "') to retrieve all available label names.") name = fuzzy_search(name, dataset_names) if name == "bindingdb_patent": @@ -70,9 +67,11 @@ def __init__(self, name, path, label_name, print_stats, dataset_names): entity1_idx, entity2_idx, aux_column_val, - ) = interaction_dataset_load( - name, path, label_name, dataset_names, aux_column=aux_column - ) + ) = interaction_dataset_load(name, + path, + label_name, + dataset_names, + aux_column=aux_column) self.name = name self.entity1 = entity1 @@ -107,26 +106,22 @@ def get_data(self, format="df"): """ if format == "df": if self.aux_column is None: - return pd.DataFrame( - { - self.entity1_name + "_ID": self.entity1_idx, - self.entity1_name: self.entity1, - self.entity2_name + "_ID": self.entity2_idx, - self.entity2_name: self.entity2, - "Y": self.y, - } - ) + return pd.DataFrame({ + self.entity1_name + "_ID": self.entity1_idx, + self.entity1_name: self.entity1, + self.entity2_name + "_ID": self.entity2_idx, + self.entity2_name: self.entity2, + "Y": self.y, + }) else: - return pd.DataFrame( - { - self.entity1_name + "_ID": self.entity1_idx, - self.entity1_name: self.entity1, - self.entity2_name + "_ID": self.entity2_idx, - self.entity2_name: self.entity2, - "Y": self.y, - self.aux_column: self.aux_column_val, - } - ) + return pd.DataFrame({ + self.entity1_name + "_ID": self.entity1_idx, + self.entity1_name: self.entity1, + self.entity2_name + "_ID": self.entity2_idx, + self.entity2_name: self.entity2, + "Y": self.y, + self.aux_column: self.aux_column_val, + }) elif format == "DeepPurpose": return self.entity1.values, self.entity2.values, self.y.values @@ -165,12 +160,8 @@ def print_stats(self): file=sys.stderr, ) print( - str(len(self.y)) - + " " - + self.entity1_name.lower() - + "-" - + self.entity2_name.lower() - + " pairs.", + str(len(self.y)) + " " + self.entity1_name.lower() + "-" + + self.entity2_name.lower() + " pairs.", flush=True, file=sys.stderr, ) @@ -218,12 +209,10 @@ def get_split( return create_fold_setting_cold(df, seed, frac, self.entity2_name) elif method == "cold_split": if column_name is None or not all( - list(map(lambda x: x in df.columns.values, column_name)) - ): + list(map(lambda x: x in df.columns.values, column_name))): raise AttributeError( "For cold_split, please provide one or multiple column names " - "that are contained in the dataframe." - ) + "that are contained in the dataframe.") return create_fold_setting_cold(df, seed, frac, column_name) elif method == "combination": return create_combination_split(df, seed, frac) @@ -298,9 +287,8 @@ def to_graph( if len(np.unique(self.raw_y)) > 2: print( "The dataset label consists of affinity scores. " - "Binarization using threshold " - + str(threshold) - + " is conducted to construct the positive edges in the network. " + "Binarization using threshold " + str(threshold) + + " is conducted to construct the positive edges in the network. " "Adjust the threshold by to_graph(threshold = X)", flush=True, file=sys.stderr, @@ -308,29 +296,34 @@ def to_graph( if threshold is None: raise AttributeError( "Please specify the threshold to binarize the data by " - "'to_graph(threshold = N)'!" 
- ) - df["label_binary"] = label_transform( - self.raw_y, True, threshold, False, verbose=False, order=order - ) + "'to_graph(threshold = N)'!") + df["label_binary"] = label_transform(self.raw_y, + True, + threshold, + False, + verbose=False, + order=order) else: # already binary df["label_binary"] = df["Y"] - df[self.entity1_name + "_ID"] = df[self.entity1_name + "_ID"].astype(str) - df[self.entity2_name + "_ID"] = df[self.entity2_name + "_ID"].astype(str) + df[self.entity1_name + "_ID"] = df[self.entity1_name + + "_ID"].astype(str) + df[self.entity2_name + "_ID"] = df[self.entity2_name + + "_ID"].astype(str) df_pos = df[df.label_binary == 1] df_neg = df[df.label_binary == 0] return_dict = {} - pos_edges = df_pos[ - [self.entity1_name + "_ID", self.entity2_name + "_ID"] - ].values - neg_edges = df_neg[ - [self.entity1_name + "_ID", self.entity2_name + "_ID"] - ].values - edges = df[[self.entity1_name + "_ID", self.entity2_name + "_ID"]].values + pos_edges = df_pos[[ + self.entity1_name + "_ID", self.entity2_name + "_ID" + ]].values + neg_edges = df_neg[[ + self.entity1_name + "_ID", self.entity2_name + "_ID" + ]].values + edges = df[[self.entity1_name + "_ID", + self.entity2_name + "_ID"]].values if format == "edge_list": return_dict["edge_list"] = pos_edges @@ -364,7 +357,8 @@ def to_graph( edge_list1 = np.array([dict_[i] for i in pos_edges.T[0]]) edge_list2 = np.array([dict_[i] for i in pos_edges.T[1]]) - edge_index = torch.tensor([edge_list1, edge_list2], dtype=torch.long) + edge_index = torch.tensor([edge_list1, edge_list2], + dtype=torch.long) x = torch.tensor(np.array(index), dtype=torch.float) data = Data(x=x, edge_index=edge_index) return_dict["pyg_graph"] = data diff --git a/tdc/multi_pred/catalyst.py b/tdc/multi_pred/catalyst.py index 285a06f7..3c797a55 100644 --- a/tdc/multi_pred/catalyst.py +++ b/tdc/multi_pred/catalyst.py @@ -13,7 +13,6 @@ class Catalyst(bi_pred_dataset.DataLoader): - """Data loader class to load datasets in Catalyst Prediction task More info: https://tdcommons.ai/multi_pred_tasks/catalyst/ @@ -33,9 +32,11 @@ class Catalyst(bi_pred_dataset.DataLoader): def __init__(self, name, path="./data", label_name=None, print_stats=False): """Create Catalyst Prediction dataloader object""" - super().__init__( - name, path, label_name, print_stats, dataset_names=dataset_names["Catalyst"] - ) + super().__init__(name, + path, + label_name, + print_stats, + dataset_names=dataset_names["Catalyst"]) self.entity1_name = "Reactant" self.entity2_name = "Product" self.two_types = True diff --git a/tdc/multi_pred/ddi.py b/tdc/multi_pred/ddi.py index ce5b41d9..b56559c6 100644 --- a/tdc/multi_pred/ddi.py +++ b/tdc/multi_pred/ddi.py @@ -13,7 +13,6 @@ class DDI(bi_pred_dataset.DataLoader): - """Data loader class to load datasets in Drug-Drug Interaction Prediction task More info: https://tdcommons.ai/multi_pred_tasks/ddi/ @@ -32,9 +31,11 @@ class DDI(bi_pred_dataset.DataLoader): def __init__(self, name, path="./data", label_name=None, print_stats=False): """Create Drug-Drug Interaction (DDI) Prediction dataloader object""" - super().__init__( - name, path, label_name, print_stats, dataset_names=dataset_names["DDI"] - ) + super().__init__(name, + path, + label_name, + print_stats, + dataset_names=dataset_names["DDI"]) self.entity1_name = "Drug1" self.entity2_name = "Drug2" self.two_types = False @@ -50,9 +51,9 @@ def print_stats(self): print_sys("--- Dataset Statistics ---") print( - "There are " - + str(len(np.unique(self.entity1.tolist() + self.entity2.tolist()))) - + " unique 
drugs.", + "There are " + + str(len(np.unique(self.entity1.tolist() + self.entity2.tolist()))) + + " unique drugs.", flush=True, file=sys.stderr, ) diff --git a/tdc/multi_pred/drugres.py b/tdc/multi_pred/drugres.py index b526fec3..85214cf5 100644 --- a/tdc/multi_pred/drugres.py +++ b/tdc/multi_pred/drugres.py @@ -14,7 +14,6 @@ class DrugRes(bi_pred_dataset.DataLoader): - """Data loader class to load datasets in Drug Response Prediction Task. More info: https://tdcommons.ai/multi_pred_tasks/drugres/ @@ -33,9 +32,11 @@ class DrugRes(bi_pred_dataset.DataLoader): def __init__(self, name, path="./data", label_name=None, print_stats=False): """Create Drug Response Prediction dataloader object""" - super().__init__( - name, path, label_name, print_stats, dataset_names=dataset_names["DrugRes"] - ) + super().__init__(name, + path, + label_name, + print_stats, + dataset_names=dataset_names["DrugRes"]) self.entity1_name = "Drug" self.entity2_name = "Cell Line" self.two_types = True @@ -51,12 +52,11 @@ def get_gene_symbols(self): Retrieve the gene symbols for the cell line gene expression """ path = self.path - name = download_wrapper("gdsc_gene_symbols", path, ["gdsc_gene_symbols"]) + name = download_wrapper("gdsc_gene_symbols", path, + ["gdsc_gene_symbols"]) print_sys("Loading...") import pandas as pd import os df = pd.read_csv(os.path.join(path, name + ".tab"), sep="\t") - return df.values.reshape( - -1, - ) + return df.values.reshape(-1,) diff --git a/tdc/multi_pred/drugsyn.py b/tdc/multi_pred/drugsyn.py index 5c27673a..9a2ee2df 100644 --- a/tdc/multi_pred/drugsyn.py +++ b/tdc/multi_pred/drugsyn.py @@ -13,7 +13,6 @@ class DrugSyn(multi_pred_dataset.DataLoader): - """Data loader class to load datasets in Drug Synergy Prediction task. More info: https://tdcommons.ai/multi_pred_tasks/drugsyn/ @@ -32,9 +31,10 @@ class DrugSyn(multi_pred_dataset.DataLoader): def __init__(self, name, path="./data", print_stats=False): """Create Drug Synergy Prediction dataloader object""" - super().__init__( - name, path, print_stats, dataset_names=dataset_names["DrugSyn"] - ) + super().__init__(name, + path, + print_stats, + dataset_names=dataset_names["DrugSyn"]) if print_stats: self.print_stats() diff --git a/tdc/multi_pred/dti.py b/tdc/multi_pred/dti.py index 33655a04..aa12b736 100644 --- a/tdc/multi_pred/dti.py +++ b/tdc/multi_pred/dti.py @@ -13,7 +13,6 @@ class DTI(bi_pred_dataset.DataLoader): - """Data loader class to load datasets in Drug-Target Interaction Prediction task. More info: https://tdcommons.ai/multi_pred_tasks/dti/ @@ -34,9 +33,11 @@ class DTI(bi_pred_dataset.DataLoader): def __init__(self, name, path="./data", label_name=None, print_stats=False): """Create Drug-Target Interaction Prediction dataloader object""" - super().__init__( - name, path, label_name, print_stats, dataset_names=dataset_names["DTI"] - ) + super().__init__(name, + path, + label_name, + print_stats, + dataset_names=dataset_names["DTI"]) self.entity1_name = "Drug" self.entity2_name = "Target" self.two_types = True @@ -60,30 +61,21 @@ def harmonize_affinities(self, mode=None): print_sys( "The scale is converted to log scale, so we will take the maximum!" ) - df = ( - df_.groupby(["Drug_ID", "Drug", "Target_ID", "Target"]) - .Y.agg(max) - .reset_index() - ) + df = (df_.groupby(["Drug_ID", "Drug", "Target_ID", + "Target"]).Y.agg(max).reset_index()) else: print_sys( "The scale is in original affinity scale, so we will take the minimum!" 
) - df = ( - df_.groupby(["Drug_ID", "Drug", "Target_ID", "Target"]) - .Y.agg(min) - .reset_index() - ) + df = (df_.groupby(["Drug_ID", "Drug", "Target_ID", + "Target"]).Y.agg(min).reset_index()) elif mode == "mean": import numpy as np df_ = self.get_data() - df = ( - df_.groupby(["Drug_ID", "Drug", "Target_ID", "Target"]) - .Y.agg(np.mean) - .reset_index() - ) + df = (df_.groupby(["Drug_ID", "Drug", "Target_ID", + "Target"]).Y.agg(np.mean).reset_index()) self.entity1_idx = df.Drug_ID.values self.entity2_idx = df.Target_ID.values @@ -92,4 +84,4 @@ def harmonize_affinities(self, mode=None): self.entity2 = df.Target.values self.y = df.Y.values print_sys("The original data has been updated!") - return df + return df \ No newline at end of file diff --git a/tdc/multi_pred/gda.py b/tdc/multi_pred/gda.py index e3609a6f..bc69477d 100644 --- a/tdc/multi_pred/gda.py +++ b/tdc/multi_pred/gda.py @@ -13,7 +13,6 @@ class GDA(bi_pred_dataset.DataLoader): - """Data loader class to load datasets in Gene-Disease Association Prediction task. More info: https://tdcommons.ai/multi_pred_tasks/gdi/ @@ -35,9 +34,11 @@ class GDA(bi_pred_dataset.DataLoader): def __init__(self, name, path="./data", label_name=None, print_stats=False): """Create Gene-Disease Association Prediction dataloader object""" - super().__init__( - name, path, label_name, print_stats, dataset_names=dataset_names["GDA"] - ) + super().__init__(name, + path, + label_name, + print_stats, + dataset_names=dataset_names["GDA"]) self.entity1_name = "Gene" self.entity2_name = "Disease" self.two_types = True diff --git a/tdc/multi_pred/mti.py b/tdc/multi_pred/mti.py index 2e2dafe4..f0ec169a 100644 --- a/tdc/multi_pred/mti.py +++ b/tdc/multi_pred/mti.py @@ -13,7 +13,6 @@ class MTI(bi_pred_dataset.DataLoader): - """Data loader class to load datasets in MicroRNA-Target Interaction Prediction task. More info: https://tdcommons.ai/multi_pred_tasks/mti/ @@ -35,9 +34,11 @@ class MTI(bi_pred_dataset.DataLoader): def __init__(self, name, path="./data", label_name=None, print_stats=False): """Create MicroRNA-Target Interaction Prediction dataloader object""" - super().__init__( - name, path, label_name, print_stats, dataset_names=dataset_names["MTI"] - ) + super().__init__(name, + path, + label_name, + print_stats, + dataset_names=dataset_names["MTI"]) self.entity1_name = "miRNA" self.entity2_name = "Target" self.two_types = True diff --git a/tdc/multi_pred/multi_pred_dataset.py b/tdc/multi_pred/multi_pred_dataset.py index 4c5ba3ae..eabe3fe2 100644 --- a/tdc/multi_pred/multi_pred_dataset.py +++ b/tdc/multi_pred/multi_pred_dataset.py @@ -41,9 +41,8 @@ def __init__(self, name, path, print_stats, dataset_names): if label_name is None: raise ValueError( "Please select a label name. You can use tdc.utils.retrieve_label_name_list('" - + name.lower() - + "') to retrieve all available label names." - ) + + name.lower() + + "') to retrieve all available label names.") df = multi_dataset_load(name, path, dataset_names) @@ -77,9 +76,11 @@ def print_stats(self): print(str(len(self.df)) + " data points.", flush=True, file=sys.stderr) print_sys("--------------------------") - def get_split( - self, method="random", seed=42, frac=[0.7, 0.1, 0.2], column_name=None - ): + def get_split(self, + method="random", + seed=42, + frac=[0.7, 0.1, 0.2], + column_name=None): """split dataset into train/validation/test. 
Args: @@ -106,9 +107,8 @@ def get_split( elif method == "cold_split": if isinstance(column_name, str): column_name = [column_name] - if (column_name is None) or ( - not all([x in df.columns.values for x in column_name]) - ): + if (column_name is None) or (not all( + [x in df.columns.values for x in column_name])): raise AttributeError( "For cold_split, please provide one or multiple column names that are contained in the dataframe." ) diff --git a/tdc/multi_pred/peptidemhc.py b/tdc/multi_pred/peptidemhc.py index 21219b05..cb9c7779 100644 --- a/tdc/multi_pred/peptidemhc.py +++ b/tdc/multi_pred/peptidemhc.py @@ -13,7 +13,6 @@ class PeptideMHC(bi_pred_dataset.DataLoader): - """Data loader class to load datasets in Peptide-MHC Binding Prediction task. More info: https://tdcommons.ai/multi_pred_tasks/peptidemhc/ diff --git a/tdc/multi_pred/ppi.py b/tdc/multi_pred/ppi.py index d135d4e0..31d4a7b1 100644 --- a/tdc/multi_pred/ppi.py +++ b/tdc/multi_pred/ppi.py @@ -13,7 +13,6 @@ class PPI(bi_pred_dataset.DataLoader): - """Data loader class to load datasets in Protein-Protein Interaction Prediction task. More info: https://tdcommons.ai/multi_pred_tasks/ppi/ @@ -33,9 +32,11 @@ class PPI(bi_pred_dataset.DataLoader): def __init__(self, name, path="./data", label_name=None, print_stats=False): """Create Protein-Protein Interaction Prediction dataloader object""" - super().__init__( - name, path, label_name, print_stats, dataset_names=dataset_names["PPI"] - ) + super().__init__(name, + path, + label_name, + print_stats, + dataset_names=dataset_names["PPI"]) self.entity1_name = "Protein1" self.entity2_name = "Protein2" self.two_types = False @@ -51,9 +52,9 @@ def print_stats(self): print_sys("--- Dataset Statistics ---") print( - "There are " - + str(len(np.unique(self.entity1.tolist() + self.entity2.tolist()))) - + " unique proteins.", + "There are " + + str(len(np.unique(self.entity1.tolist() + self.entity2.tolist()))) + + " unique proteins.", flush=True, file=sys.stderr, ) diff --git a/tdc/multi_pred/tcr_epi.py b/tdc/multi_pred/tcr_epi.py index efa37f97..0d1aeb5a 100644 --- a/tdc/multi_pred/tcr_epi.py +++ b/tdc/multi_pred/tcr_epi.py @@ -14,7 +14,6 @@ class TCREpitopeBinding(multi_pred_dataset.DataLoader): - """Data loader class to load datasets in T cell receptor (TCR) Specificity Prediction Task. More info: @@ -31,9 +30,10 @@ class TCREpitopeBinding(multi_pred_dataset.DataLoader): def __init__(self, name, path="./data", print_stats=False): """Create TCR Specificity Prediction dataloader object""" - super().__init__( - name, path, print_stats, dataset_names=dataset_names["TCREpitopeBinding"] - ) + super().__init__(name, + path, + print_stats, + dataset_names=dataset_names["TCREpitopeBinding"]) self.entity1_name = "TCR" self.entity2_name = "Epitope" diff --git a/tdc/multi_pred/test_multi_pred.py b/tdc/multi_pred/test_multi_pred.py index 4aa01c3d..981e182a 100644 --- a/tdc/multi_pred/test_multi_pred.py +++ b/tdc/multi_pred/test_multi_pred.py @@ -13,7 +13,6 @@ class TestMultiPred(bi_pred_dataset.DataLoader): - """Summary Attributes: diff --git a/tdc/multi_pred/trialoutcome.py b/tdc/multi_pred/trialoutcome.py index 03f46192..25490e14 100644 --- a/tdc/multi_pred/trialoutcome.py +++ b/tdc/multi_pred/trialoutcome.py @@ -13,7 +13,6 @@ class TrialOutcome(multi_pred_dataset.DataLoader): - """Data loader class to load datasets in clinical trial outcome Prediction task. 
More info: https://tdcommons.ai/multi_pred_tasks/trialoutcome/ @@ -35,9 +34,10 @@ class TrialOutcome(multi_pred_dataset.DataLoader): def __init__(self, name, path="./data", label_name=None, print_stats=False): """Create Clinical Trial Outcome Prediction dataloader object""" - super().__init__( - name, path, print_stats, dataset_names=dataset_names["TrialOutcome"] - ) + super().__init__(name, + path, + print_stats, + dataset_names=dataset_names["TrialOutcome"]) self.entity1_name = "drug_molecule" self.entity2_name = "disease_code" # self.entity3_name = "eligibility_criteria" diff --git a/tdc/oracles.py b/tdc/oracles.py index 1e88e244..4d27f695 100644 --- a/tdc/oracles.py +++ b/tdc/oracles.py @@ -17,7 +17,8 @@ docking_target_info, ) -SKLEARN_VERSION = version.parse(pkg_resources.get_distribution("scikit-learn").version) +SKLEARN_VERSION = version.parse( + pkg_resources.get_distribution("scikit-learn").version) def _normalize_docking_score(raw_score): @@ -25,7 +26,6 @@ def _normalize_docking_score(raw_score): class Oracle: - """the oracle class to retrieve any oracle given by query name Args: @@ -112,28 +112,24 @@ def assign_evaluator(self): from .chem_utils import similarity_meta self.evaluator_func = similarity_meta( - target_smiles=self.target_smiles, **self.kwargs - ) + target_smiles=self.target_smiles, **self.kwargs) elif self.name == "rediscovery_meta": from .chem_utils import rediscovery_meta self.evaluator_func = rediscovery_meta( - target_smiles=self.target_smiles, **self.kwargs - ) + target_smiles=self.target_smiles, **self.kwargs) elif self.name == "isomer_meta": from .chem_utils import isomer_meta - self.evaluator_func = isomer_meta( - target_smiles=self.target_smiles, **self.kwargs - ) + self.evaluator_func = isomer_meta(target_smiles=self.target_smiles, + **self.kwargs) elif self.name == "median_meta": from .chem_utils import median_meta self.evaluator_func = median_meta( target_smiles_1=self.target_smiles[0], target_smiles_2=self.target_smiles[1], - **self.kwargs - ) + **self.kwargs) elif self.name == "rediscovery": from .chem_utils import ( celecoxib_rediscovery, @@ -257,7 +253,10 @@ def assign_evaluator(self): elif self.name == "hop": from .chem_utils import deco_hop, scaffold_hop - self.evaluator_func = {"Deco Hop": deco_hop, "Scaffold Hop": scaffold_hop} + self.evaluator_func = { + "Deco Hop": deco_hop, + "Scaffold Hop": scaffold_hop + } elif self.name == "deco_hop": from .chem_utils import deco_hop @@ -318,12 +317,9 @@ def assign_evaluator(self): box_size=boxsize, ) - elif ( - self.name == "drd3_docking" - or self.name == "3pbl_docking" - or self.name == "drd3_docking_normalize" - or self.name == "3pbl_docking_normalize" - ): + elif (self.name == "drd3_docking" or self.name == "3pbl_docking" or + self.name == "drd3_docking_normalize" or + self.name == "3pbl_docking_normalize"): from .chem_utils import PyScreener_meta @@ -596,10 +592,9 @@ def __call__(self, *args, **kwargs): self.num_called -= len(smiles_lst) raise ValueError( "The maximum number of evaluator call is reached! The maximum is: " - + str(self.num_max_call) - + ". The current requested call (plus accumulated calls) is: " - + str(self.num_called + len(smiles_lst)) - ) + + str(self.num_max_call) + + ". 
The current requested call (plus accumulated calls) is: " + + str(self.num_called + len(smiles_lst))) #### evaluator for single molecule, #### the input of __call__ is a single smiles OR list of smiles @@ -618,16 +613,14 @@ def __call__(self, *args, **kwargs): for smiles in smiles_lst: results_lst.append( self.normalize( - self.evaluator_func(smiles, *(args[1:]), **kwargs) - ) - ) + self.evaluator_func(smiles, *(args[1:]), + **kwargs))) else: results_lst = [] for smiles in smiles_lst: try: - results = self.evaluator_func( - [smiles], *(args[1:]), **kwargs - ) + results = self.evaluator_func([smiles], *(args[1:]), + **kwargs) results = results[0] except: results = self.default_property @@ -649,10 +642,9 @@ def __call__(self, *args, **kwargs): self.num_called -= 1 raise ValueError( "The maximum number of evaluator call is reached! The maximum is: " - + str(self.num_max_call) - + ". The current requested call (plus accumulated calls) is: " - + str(self.num_called + 1) - ) + + str(self.num_max_call) + + ". The current requested call (plus accumulated calls) is: " + + str(self.num_called + 1)) ## a single smiles if type(self.evaluator_func) == dict: diff --git a/tdc/resource/primekg.py b/tdc/resource/primekg.py index 5cbedba9..00a3a9d4 100644 --- a/tdc/resource/primekg.py +++ b/tdc/resource/primekg.py @@ -16,7 +16,6 @@ class PrimeKG: - """PrimeKG data loader class to load the knowledge graph with additional support functions.""" def __init__(self, path="./data"): @@ -32,19 +31,18 @@ def to_nx(self): G = nx.Graph() for i in self.df.relation.unique(): - G.add_edges_from( - self.df[self.df.relation == i][["x_id", "y_id"]].values, relation=i - ) + G.add_edges_from(self.df[self.df.relation == i][["x_id", + "y_id"]].values, + relation=i) return G def get_features(self, feature_type): if feature_type not in ["drug", "disease"]: raise ValueError("feature_type only supports drug/disease!") - return general_load("primekg_" + feature_type + "_feature", self.path, "\t") + return general_load("primekg_" + feature_type + "_feature", self.path, + "\t") def get_node_list(self, node_type): df = self.df - return np.unique( - df[(df.x_type == node_type)].x_id.unique().tolist() - + df[(df.y_type == node_type)].y_id.unique().tolist() - ) + return np.unique(df[(df.x_type == node_type)].x_id.unique().tolist() + + df[(df.y_type == node_type)].y_id.unique().tolist()) diff --git a/tdc/single_pred/adme.py b/tdc/single_pred/adme.py index b9cf7b0f..02b43645 100644 --- a/tdc/single_pred/adme.py +++ b/tdc/single_pred/adme.py @@ -50,9 +50,9 @@ def __init__( import pandas as pd import os - self.ppbr_df = pd.read_csv( - os.path.join(self.path, self.name + ".tab"), sep="\t" - ) + self.ppbr_df = pd.read_csv(os.path.join(self.path, + self.name + ".tab"), + sep="\t") df = self.ppbr_df[self.ppbr_df.Species == "Homo sapiens"] self.entity1 = df.Drug.values self.y = df.Y.values @@ -66,10 +66,11 @@ def get_approved_set(self): import pandas as pd if self.name not in ["pampa_ncats"]: - raise ValueError("This function is only available for PAMPA_NCATS dataset") - entity1, y, entity1_idx = property_dataset_load( - "approved_pampa_ncats", self.path, None, dataset_names["ADME"] - ) + raise ValueError( + "This function is only available for PAMPA_NCATS dataset") + entity1, y, entity1_idx = property_dataset_load("approved_pampa_ncats", + self.path, None, + dataset_names["ADME"]) return pd.DataFrame({"Drug_ID": entity1_idx, "Drug": entity1, "Y": y}) def get_other_species(self, species=None): @@ -83,7 +84,8 @@ def get_other_species(self, 
species=None): return self.ppbr_df if species in self.ppbr_df.Species.unique(): - return self.ppbr_df[self.ppbr_df.Species == species].reset_index(drop=True) + return self.ppbr_df[self.ppbr_df.Species == species].reset_index( + drop=True) else: raise ValueError( "You can only specify the following set of species name: 'Canis lupus familiaris', 'Cavia porcellus', 'Homo sapiens', 'Mus musculus', 'Rattus norvegicus', 'all'" @@ -99,19 +101,15 @@ def harmonize(self, mode=None): if mode == "max": df_ = self.get_data() - df = ( - df_.sort_values("Y", ascending=True) - .drop_duplicates("Drug") - .reset_index(drop=True) - ) + df = (df_.sort_values( + "Y", + ascending=True).drop_duplicates("Drug").reset_index(drop=True)) elif mode == "min": df_ = self.get_data() - df = ( - df_.sort_values("Y", ascending=False) - .drop_duplicates("Drug") - .reset_index(drop=True) - ) + df = (df_.sort_values( + "Y", + ascending=False).drop_duplicates("Drug").reset_index(drop=True)) elif mode == "remove_all": df_ = self.get_data() diff --git a/tdc/single_pred/crispr_outcome.py b/tdc/single_pred/crispr_outcome.py index 97f5966e..7f0d686f 100644 --- a/tdc/single_pred/crispr_outcome.py +++ b/tdc/single_pred/crispr_outcome.py @@ -13,7 +13,6 @@ class CRISPROutcome(single_pred_dataset.DataLoader): - """Data loader class to load datasets in CRISPROutcome task. More info: https://tdcommons.ai/single_pred_tasks/CRISPROutcome/ Args: diff --git a/tdc/single_pred/develop.py b/tdc/single_pred/develop.py index d88dc8c7..cc953298 100644 --- a/tdc/single_pred/develop.py +++ b/tdc/single_pred/develop.py @@ -13,7 +13,6 @@ class Develop(single_pred_dataset.DataLoader): - """Data loader class to load datasets in Develop task. More info: https://tdcommons.ai/single_pred_tasks/develop/ Args: @@ -85,17 +84,20 @@ def graphein( from graphein.protein.utils import get_obsolete_mapping obs = get_obsolete_mapping() - train_obs = [t for t in split["train"]["Antibody_ID"] if t in obs.keys()] - valid_obs = [t for t in split["valid"]["Antibody_ID"] if t in obs.keys()] - test_obs = [t for t in split["test"]["Antibody_ID"] if t in obs.keys()] - - split["train"] = split["train"].loc[ - ~split["train"]["Antibody_ID"].isin(train_obs) + train_obs = [ + t for t in split["train"]["Antibody_ID"] if t in obs.keys() ] - split["test"] = split["test"].loc[~split["test"]["Antibody_ID"].isin(test_obs)] - split["valid"] = split["valid"].loc[ - ~split["valid"]["Antibody_ID"].isin(valid_obs) + valid_obs = [ + t for t in split["valid"]["Antibody_ID"] if t in obs.keys() ] + test_obs = [t for t in split["test"]["Antibody_ID"] if t in obs.keys()] + + split["train"] = split["train"].loc[~split["train"]["Antibody_ID"]. + isin(train_obs)] + split["test"] = split["test"].loc[~split["test"]["Antibody_ID"]. + isin(test_obs)] + split["valid"] = split["valid"].loc[~split["valid"]["Antibody_ID"]. + isin(valid_obs)] self.split = split @@ -104,8 +106,7 @@ def get_label_map(split_name: str) -> Dict[str, torch.Tensor]: zip( split[split_name].Antibody_ID, split[split_name].Y.apply(torch.tensor), - ) - ) + )) train_labels = get_label_map("train") valid_labels = get_label_map("valid") diff --git a/tdc/single_pred/epitope.py b/tdc/single_pred/epitope.py index ec9448fd..cb47a0ee 100644 --- a/tdc/single_pred/epitope.py +++ b/tdc/single_pred/epitope.py @@ -13,7 +13,6 @@ class Epitope(single_pred_dataset.DataLoader): - """Data loader class to load datasets in Epitope Prediction task. 
More info: https://tdcommons.ai/single_pred_tasks/epitope/ Args: diff --git a/tdc/single_pred/hts.py b/tdc/single_pred/hts.py index b3589cee..c203d88a 100644 --- a/tdc/single_pred/hts.py +++ b/tdc/single_pred/hts.py @@ -13,7 +13,6 @@ class HTS(single_pred_dataset.DataLoader): - """Data loader class to load datasets in HTS task. More info: https://tdcommons.ai/single_pred_tasks/hts/ Args: diff --git a/tdc/single_pred/paratope.py b/tdc/single_pred/paratope.py index 8d19060c..1866f000 100644 --- a/tdc/single_pred/paratope.py +++ b/tdc/single_pred/paratope.py @@ -13,7 +13,6 @@ class Paratope(single_pred_dataset.DataLoader): - """Data loader class to load datasets in Paratope Prediction task. More info: https://tdcommons.ai/single_pred_tasks/paratope/ Args: diff --git a/tdc/single_pred/qm.py b/tdc/single_pred/qm.py index 9e851340..75bbe0e7 100644 --- a/tdc/single_pred/qm.py +++ b/tdc/single_pred/qm.py @@ -13,7 +13,6 @@ class QM(single_pred_dataset.DataLoader): - """Data loader class to load datasets in QM (Quantum Mechanics Modeling) task. More info: https://tdcommons.ai/single_pred_tasks/qm/ Args: diff --git a/tdc/single_pred/single_pred_dataset.py b/tdc/single_pred/single_pred_dataset.py index 48523af9..8b2771ab 100644 --- a/tdc/single_pred/single_pred_dataset.py +++ b/tdc/single_pred/single_pred_dataset.py @@ -21,7 +21,6 @@ class DataLoader(base_dataset.DataLoader): - """A base data loader class. Args: @@ -65,13 +64,11 @@ def __init__( if label_name is None: raise ValueError( "Please select a label name. You can use tdc.utils.retrieve_label_name_list('" - + name.lower() - + "') to retrieve all available label names." - ) + + name.lower() + + "') to retrieve all available label names.") - entity1, y, entity1_idx = property_dataset_load( - name, path, label_name, dataset_names - ) + entity1, y, entity1_idx = property_dataset_load(name, path, label_name, + dataset_names) self.entity1 = entity1 self.y = y @@ -106,32 +103,34 @@ def get_data(self, format="df"): if format == "df": if self.convert_format is not None: - return pd.DataFrame( - { - self.entity1_name + "_ID": self.entity1_idx, - self.entity1_name: self.entity1, - self.entity1_name - + "_" - + self.convert_format: self.convert_result, - "Y": self.y, - } - ) + return pd.DataFrame({ + self.entity1_name + "_ID": + self.entity1_idx, + self.entity1_name: + self.entity1, + self.entity1_name + "_" + self.convert_format: + self.convert_result, + "Y": + self.y, + }) else: - return pd.DataFrame( - { - self.entity1_name + "_ID": self.entity1_idx, - self.entity1_name: self.entity1, - "Y": self.y, - } - ) + return pd.DataFrame({ + self.entity1_name + "_ID": self.entity1_idx, + self.entity1_name: self.entity1, + "Y": self.y, + }) elif format == "dict": if self.convert_format is not None: return { - self.entity1_name + "_ID": self.entity1_idx.values, - self.entity1_name: self.entity1.values, - self.entity1_name + "_" + self.convert_format: self.convert_result, - "Y": self.y.values, + self.entity1_name + "_ID": + self.entity1_idx.values, + self.entity1_name: + self.entity1.values, + self.entity1_name + "_" + self.convert_format: + self.convert_result, + "Y": + self.y.values, } else: return { diff --git a/tdc/single_pred/test_single_pred.py b/tdc/single_pred/test_single_pred.py index 20edc3ee..315937d9 100644 --- a/tdc/single_pred/test_single_pred.py +++ b/tdc/single_pred/test_single_pred.py @@ -13,7 +13,6 @@ class TestSinglePred(single_pred_dataset.DataLoader): - """Data loader class to test the single instance prediction data loader. 
Args: diff --git a/tdc/single_pred/tox.py b/tdc/single_pred/tox.py index 8ad02ca1..3d3892b5 100644 --- a/tdc/single_pred/tox.py +++ b/tdc/single_pred/tox.py @@ -13,7 +13,6 @@ class Tox(single_pred_dataset.DataLoader): - """Data loader class to load datasets in Tox (Toxicity Prediction) task. More info: https://tdcommons.ai/single_pred_tasks/tox/ Args: diff --git a/tdc/single_pred/yields.py b/tdc/single_pred/yields.py index 1091c83c..a4894573 100644 --- a/tdc/single_pred/yields.py +++ b/tdc/single_pred/yields.py @@ -13,7 +13,6 @@ class Yields(single_pred_dataset.DataLoader): - """Data loader class to load datasets in Yields (Reaction Yields Prediction) task. More info: https://tdcommons.ai/single_pred_tasks/yields/ Args: diff --git a/tdc/tdc_hf.py b/tdc/tdc_hf.py index 51857d22..34a00e16 100644 --- a/tdc/tdc_hf.py +++ b/tdc/tdc_hf.py @@ -4,16 +4,17 @@ deeppurpose_repo = [ 'hERG_Karim-Morgan', - 'hERG_Karim-CNN', - 'hERG_Karim-AttentiveFP', + 'hERG_Karim-CNN', + 'hERG_Karim-AttentiveFP', 'BBB_Martins-AttentiveFP', - 'BBB_Martins-Morgan', - 'BBB_Martins-CNN', - 'CYP3A4_Veith-Morgan', - 'CYP3A4_Veith-CNN', - 'CYP3A4_Veith-AttentiveFP', + 'BBB_Martins-Morgan', + 'BBB_Martins-CNN', + 'CYP3A4_Veith-Morgan', + 'CYP3A4_Veith-CNN', + 'CYP3A4_Veith-AttentiveFP', ] + class tdc_hf_interface: ''' Example use cases: @@ -25,31 +26,27 @@ class tdc_hf_interface: dp_model = tdc_hf_herg.load_deeppurpose('./data') dp_model.predict(XXX) ''' - + def __init__(self, repo_name): self.repo_id = "tdc/" + repo_name self.model_name = repo_name.split('-')[1] - + def upload(self, folder_path): create_repo(repo_id=self.repo_id) api = HfApi() - api.upload_folder( - folder_path=folder_path, - path_in_repo="model", - repo_id=self.repo_id, - repo_type="model" - ) - + api.upload_folder(folder_path=folder_path, + path_in_repo="model", + repo_id=self.repo_id, + repo_type="model") + def file_download(self, save_path, filename): - model_ckpt = hf_hub_download( - repo_id = self.repo_id, - filename = filename, - cache_dir = save_path - ) - + model_ckpt = hf_hub_download(repo_id=self.repo_id, + filename=filename, + cache_dir=save_path) + def repo_download(self, save_path): - snapshot_download(repo_id=self.repo_id, cache_dir= save_path) - + snapshot_download(repo_id=self.repo_id, cache_dir=save_path) + def load_deeppurpose(self, save_path): if self.repo_id[4:] in deeppurpose_repo: save_path = save_path + '/' + self.repo_id[4:] @@ -57,36 +54,43 @@ def load_deeppurpose(self, save_path): os.mkdir(save_path) self.file_download(save_path, "model/model.pt") self.file_download(save_path, "model/config.pkl") - - save_path = save_path + '/models--tdc--' + self.repo_id[4:] + '/blobs/' + + save_path = save_path + '/models--tdc--' + self.repo_id[ + 4:] + '/blobs/' file_name1 = save_path + os.listdir(save_path)[0] file_name2 = save_path + os.listdir(save_path)[1] - + if os.path.getsize(file_name1) > os.path.getsize(file_name2): model_file, config_file = file_name1, file_name2 else: config_file, model_file = file_name1, file_name2 os.rename(model_file, save_path + 'model.pt') - os.rename(config_file, save_path + 'config.pkl') + os.rename(config_file, save_path + 'config.pkl') try: from DeepPurpose import CompoundPred except: - raise ValueError("Please install DeepPurpose package following https://github.com/kexinhuang12345/DeepPurpose#installation") - - net = CompoundPred.model_pretrained(path_dir = save_path) + raise ValueError( + "Please install DeepPurpose package following https://github.com/kexinhuang12345/DeepPurpose#installation" + ) + + net 
= CompoundPred.model_pretrained(path_dir=save_path) return net else: raise ValueError("This repo does not host a DeepPurpose model!") + def predict_deeppurpose(self, model, drugs): try: from DeepPurpose import utils except: - raise ValueError("Please install DeepPurpose package following https://github.com/kexinhuang12345/DeepPurpose#installation") + raise ValueError( + "Please install DeepPurpose package following https://github.com/kexinhuang12345/DeepPurpose#installation" + ) if self.model_name == 'AttentiveFP': self.model_name = 'DGL_' + self.model_name - X_pred = utils.data_process(X_drug = drugs, y = [0]*len(drugs), - drug_encoding = self.model_name, - split_method='no_split') + X_pred = utils.data_process(X_drug=drugs, + y=[0] * len(drugs), + drug_encoding=self.model_name, + split_method='no_split') y_pred = model.predict(X_pred)[0] return y_pred diff --git a/tdc/test/dev_tests/chem_utils_test/test_molconverter.py b/tdc/test/dev_tests/chem_utils_test/test_molconverter.py index cd034247..c08b847a 100644 --- a/tdc/test/dev_tests/chem_utils_test/test_molconverter.py +++ b/tdc/test/dev_tests/chem_utils_test/test_molconverter.py @@ -11,31 +11,31 @@ # temporary solution for relative imports in case TDC is not installed # if TDC is installed, no need to use the following line -sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../"))) +sys.path.append( + os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../"))) class TestMolConvert(unittest.TestCase): + def setUp(self): print(os.getcwd()) pass - @unittest.skip("dev test") + def test_MolConvert(self): from tdc.chem_utils import MolConvert converter = MolConvert(src="SMILES", dst="Graph2D") - converter( - [ - "Clc1ccccc1C2C(=C(/N/C(=C2/C(=O)OCC)COCCN)C)\C(=O)OC", - "CCCOc1cc2ncnc(Nc3ccc4ncsc4c3)c2cc1S(=O)(=O)C(C)(C)C", - ] - ) + converter([ + "Clc1ccccc1C2C(=C(/N/C(=C2/C(=O)OCC)COCCN)C)\C(=O)OC", + "CCCOc1cc2ncnc(Nc3ccc4ncsc4c3)c2cc1S(=O)(=O)C(C)(C)C", + ]) from tdc.chem_utils import MolConvert MolConvert.eligible_format() - # @unittest.skip("dev test") + # def tearDown(self): print(os.getcwd()) diff --git a/tdc/test/dev_tests/chem_utils_test/test_molfilter.py b/tdc/test/dev_tests/chem_utils_test/test_molfilter.py index 9b0476c6..95bf402c 100644 --- a/tdc/test/dev_tests/chem_utils_test/test_molfilter.py +++ b/tdc/test/dev_tests/chem_utils_test/test_molfilter.py @@ -11,22 +11,24 @@ # temporary solution for relative imports in case TDC is not installed # if TDC is installed, no need to use the following line -sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../"))) +sys.path.append( + os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../"))) class TestMolFilter(unittest.TestCase): + def setUp(self): print(os.getcwd()) pass - @unittest.skip("dev test") + def test_MolConvert(self): from tdc.chem_utils import MolFilter filters = MolFilter(filters=["PAINS"], HBD=[0, 6]) filters(["CCSc1ccccc1C(=O)Nc1onc2c1CCC2"]) - # @unittest.skip("dev test") + # def tearDown(self): print(os.getcwd()) diff --git a/tdc/test/dev_tests/chem_utils_test/test_oracles.py b/tdc/test/dev_tests/chem_utils_test/test_oracles.py index 10dededf..884e8a55 100644 --- a/tdc/test/dev_tests/chem_utils_test/test_oracles.py +++ b/tdc/test/dev_tests/chem_utils_test/test_oracles.py @@ -11,43 +11,41 @@ # temporary solution for relative imports in case TDC is not installed # if TDC is installed, no need to use the following line -sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), 
"../../../"))) +sys.path.append( + os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../"))) class TestOracle(unittest.TestCase): + def setUp(self): print(os.getcwd()) pass - @unittest.skip("dev test") + def test_Oracle(self): from tdc import Oracle oracle = Oracle(name="SA") - x = oracle( - [ - "CC(C)(C)[C@H]1CCc2c(sc(NC(=O)COc3ccc(Cl)cc3)c2C(N)=O)C1", - "CCNC(=O)c1ccc(NC(=O)N2CC[C@H](C)[C@H](O)C2)c(C)c1", - "C[C@@H]1CCN(C(=O)CCCc2ccccc2)C[C@@H]1O", - ] - ) + x = oracle([ + "CC(C)(C)[C@H]1CCc2c(sc(NC(=O)COc3ccc(Cl)cc3)c2C(N)=O)C1", + "CCNC(=O)c1ccc(NC(=O)N2CC[C@H](C)[C@H](O)C2)c(C)c1", + "C[C@@H]1CCN(C(=O)CCCc2ccccc2)C[C@@H]1O", + ]) oracle = Oracle(name="Hop") x = oracle(["CC(=O)OC1=CC=CC=C1C(=O)O", "C1=CC=C(C=C1)C=O"]) - @unittest.skip("dev test") + def test_distribution(self): from tdc import Evaluator evaluator = Evaluator(name="Diversity") - x = evaluator( - [ - "CC(C)(C)[C@H]1CCc2c(sc(NC(=O)COc3ccc(Cl)cc3)c2C(N)=O)C1", - "C[C@@H]1CCc2c(sc(NC(=O)c3ccco3)c2C(N)=O)C1", - "CCNC(=O)c1ccc(NC(=O)N2CC[C@H](C)[C@H](O)C2)c(C)c1", - "C[C@@H]1CCN(C(=O)CCCc2ccccc2)C[C@@H]1O", - ] - ) + x = evaluator([ + "CC(C)(C)[C@H]1CCc2c(sc(NC(=O)COc3ccc(Cl)cc3)c2C(N)=O)C1", + "C[C@@H]1CCc2c(sc(NC(=O)c3ccco3)c2C(N)=O)C1", + "CCNC(=O)c1ccc(NC(=O)N2CC[C@H](C)[C@H](O)C2)c(C)c1", + "C[C@@H]1CCN(C(=O)CCCc2ccccc2)C[C@@H]1O", + ]) def tearDown(self): print(os.getcwd()) diff --git a/tdc/test/dev_tests/utils_tests/test_misc_utils.py b/tdc/test/dev_tests/utils_tests/test_misc_utils.py index 8e4ab331..c28e1bdb 100644 --- a/tdc/test/dev_tests/utils_tests/test_misc_utils.py +++ b/tdc/test/dev_tests/utils_tests/test_misc_utils.py @@ -11,15 +11,17 @@ # temporary solution for relative imports in case TDC is not installed # if TDC is installed, no need to use the following line -sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../../.."))) +sys.path.append( + os.path.abspath(os.path.join(os.path.dirname(__file__), "../../.."))) class TestFunctions(unittest.TestCase): + def setUp(self): print(os.getcwd()) pass - @unittest.skip("dev test") + @unittest.skip("long running test") def test_neg_sample(self): from tdc.multi_pred import PPI @@ -31,7 +33,7 @@ def test_neg_sample(self): # data = ADME(name='Caco2_Wang') # x = data.label_distribution() - @unittest.skip("dev test") + def test_get_label_map(self): from tdc.multi_pred import DDI from tdc.utils import get_label_map @@ -40,26 +42,26 @@ def test_get_label_map(self): split = data.get_split() get_label_map(name="DrugBank", task="DDI") - @unittest.skip("dev test") + def test_balanced(self): from tdc.single_pred import HTS data = HTS(name="SARSCoV2_3CLPro_Diamond") data.balanced(oversample=True, seed=42) - @unittest.skip("dev test") + def test_cid2smiles(self): from tdc.utils import cid2smiles smiles = cid2smiles(2248631) - @unittest.skip("dev test") + def test_uniprot2seq(self): from tdc.utils import uniprot2seq seq = uniprot2seq("P49122") - @unittest.skip("dev test") + def test_to_graph(self): from tdc.multi_pred import DTI @@ -93,7 +95,7 @@ def test_to_graph(self): ) # output: {'pyg_graph': the PyG graph object, 'index_to_entities': a dict map from ID in the data to node ID in the PyG object, 'split': {'train': df, 'valid': df, 'test': df}} - # @unittest.skip("dev test") + # def tearDown(self): print(os.getcwd()) diff --git a/tdc/test/dev_tests/utils_tests/test_splits.py b/tdc/test/dev_tests/utils_tests/test_splits.py index 7377a723..fdf22ac1 100644 --- a/tdc/test/dev_tests/utils_tests/test_splits.py +++ 
b/tdc/test/dev_tests/utils_tests/test_splits.py @@ -11,22 +11,24 @@ # temporary solution for relative imports in case TDC is not installed # if TDC is installed, no need to use the following line -sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../../.."))) +sys.path.append( + os.path.abspath(os.path.join(os.path.dirname(__file__), "../../.."))) class TestFunctions(unittest.TestCase): + def setUp(self): print(os.getcwd()) pass - @unittest.skip("dev test") + def test_random_split(self): from tdc.single_pred import ADME data = ADME(name="Caco2_Wang") split = data.get_split(method="random") - @unittest.skip("dev test") + def test_scaffold_split(self): ## requires RDKit from tdc.single_pred import ADME @@ -34,7 +36,7 @@ def test_scaffold_split(self): data = ADME(name="Caco2_Wang") split = data.get_split(method="scaffold") - @unittest.skip("dev test") + def test_cold_start_split(self): from tdc.multi_pred import DTI @@ -42,19 +44,24 @@ def test_cold_start_split(self): split = data.get_split(method="cold_split", column_name="Drug") self.assertEqual( - 0, len(set(split["train"]["Drug"]).intersection(set(split["test"]["Drug"]))) - ) + 0, + len( + set(split["train"]["Drug"]).intersection( + set(split["test"]["Drug"])))) self.assertEqual( - 0, len(set(split["valid"]["Drug"]).intersection(set(split["test"]["Drug"]))) - ) + 0, + len( + set(split["valid"]["Drug"]).intersection( + set(split["test"]["Drug"])))) self.assertEqual( 0, - len(set(split["train"]["Drug"]).intersection(set(split["valid"]["Drug"]))), + len( + set(split["train"]["Drug"]).intersection( + set(split["valid"]["Drug"]))), ) - multi_split = data.get_split( - method="cold_split", column_name=["Drug_ID", "Target_ID"] - ) + multi_split = data.get_split(method="cold_split", + column_name=["Drug_ID", "Target_ID"]) for entity in ["Drug_ID", "Target_ID"]: train_entity = set(multi_split["train"][entity]) valid_entity = set(multi_split["valid"][entity]) @@ -63,21 +70,21 @@ def test_cold_start_split(self): self.assertEqual(0, len(train_entity.intersection(test_entity))) self.assertEqual(0, len(valid_entity.intersection(test_entity))) - @unittest.skip("dev test") + def test_combination_split(self): from tdc.multi_pred import DrugSyn data = DrugSyn(name="DrugComb") split = data.get_split(method="combination") - @unittest.skip("dev test") + def test_time_split(self): from tdc.multi_pred import DTI data = DTI(name="BindingDB_Patent") split = data.get_split(method="time", time_column="Year") - @unittest.skip("dev test") + def test_tearDown(self): print(os.getcwd()) diff --git a/tdc/test/test_benchmark.py b/tdc/test/test_benchmark.py index b4b78083..90e16272 100644 --- a/tdc/test/test_benchmark.py +++ b/tdc/test/test_benchmark.py @@ -5,7 +5,8 @@ import sys import os -sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))) +sys.path.append( + os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))) from tdc.benchmark_group import admet_group @@ -17,6 +18,7 @@ def is_classification(values): class TestBenchmarkGroup(unittest.TestCase): + def setUp(self): self.group = admet_group(path="data/") @@ -52,11 +54,14 @@ def test_ADME_evaluate_many(self): for ds_name, metrics in results.items(): self.assertEqual(len(metrics), 2) u, std = metrics - self.assertTrue(u in (1, 0)) # A perfect score for all metrics is 1 or 0 + self.assertTrue(u + in (1, + 0)) # A perfect score for all metrics is 1 or 0 self.assertEqual(0, std) for my_group in self.group: self.assertTrue(my_group["name"] in results) + if 
__name__ == "__main__": - unittest.main() \ No newline at end of file + unittest.main() diff --git a/tdc/test/test_dataloaders.py b/tdc/test/test_dataloaders.py index c5de2fa9..6e04d92a 100644 --- a/tdc/test/test_dataloaders.py +++ b/tdc/test/test_dataloaders.py @@ -11,11 +11,13 @@ # temporary solution for relative imports in case TDC is not installed # if TDC is installed, no need to use the following line -sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))) +sys.path.append( + os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))) # TODO: add verification for the generation other than simple integration class TestDataloader(unittest.TestCase): + def setUp(self): print(os.getcwd()) pass @@ -42,5 +44,6 @@ def tearDown(self): print(os.getcwd()) shutil.rmtree(os.path.join(os.getcwd(), "data")) + if __name__ == "__main__": - unittest.main() \ No newline at end of file + unittest.main() diff --git a/tdc/test/test_functions.py b/tdc/test/test_functions.py index e713512e..40f0b976 100644 --- a/tdc/test/test_functions.py +++ b/tdc/test/test_functions.py @@ -11,10 +11,12 @@ # temporary solution for relative imports in case TDC is not installed # if TDC is installed, no need to use the following line -sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))) +sys.path.append( + os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))) class TestFunctions(unittest.TestCase): + def setUp(self): print(os.getcwd()) pass @@ -51,5 +53,6 @@ def tearDown(self): if os.path.exists(os.path.join(os.getcwd(), "oracle")): shutil.rmtree(os.path.join(os.getcwd(), "oracle")) + if __name__ == "__main__": - unittest.main() \ No newline at end of file + unittest.main() diff --git a/tdc/utils/label.py b/tdc/utils/label.py index a73829ad..b875d882 100644 --- a/tdc/utils/label.py +++ b/tdc/utils/label.py @@ -21,7 +21,7 @@ def convert_y_unit(y, from_, to_): if from_ == "nM": y = y elif from_ == "p": - y = (10 ** (-y) - 1e-10) / 1e-9 + y = (10**(-y) - 1e-10) / 1e-9 if to_ == "p": y = -np.log10(y * 1e-9 + 1e-10) @@ -31,9 +31,12 @@ def convert_y_unit(y, from_, to_): return y -def label_transform( - y, binary, threshold, convert_to_log, verbose=True, order="descending" -): +def label_transform(y, + binary, + threshold, + convert_to_log, + verbose=True, + order="descending"): """label transformation helper function Args: @@ -62,7 +65,8 @@ def label_transform( elif order == "ascending": y = np.array([1 if i else 0 for i in np.array(y) > threshold]) else: - raise ValueError("Please select order from 'descending or ascending!") + raise ValueError( + "Please select order from 'descending or ascending!") else: if (len(np.unique(y)) > 2) and convert_to_log: if verbose: @@ -144,16 +148,16 @@ def label_dist(y, name=None): median = np.median(y) mean = np.mean(y) - f, (ax_box, ax_hist) = plt.subplots( - 2, sharex=True, gridspec_kw={"height_ratios": (0.15, 1)} - ) + f, (ax_box, + ax_hist) = plt.subplots(2, + sharex=True, + gridspec_kw={"height_ratios": (0.15, 1)}) if name is None: sns.boxplot(y, ax=ax_box).set_title("Label Distribution") else: - sns.boxplot(y, ax=ax_box).set_title( - "Label Distribution of " + str(name) + " Dataset" - ) + sns.boxplot(y, ax=ax_box).set_title("Label Distribution of " + + str(name) + " Dataset") ax_box.axvline(median, color="b", linestyle="--") ax_box.axvline(mean, color="g", linestyle="--") @@ -191,7 +195,8 @@ def NegSample(df, column_names, frac, two_types): pos_set = set([tuple([i[0], i[1]]) for i in pos]) np.random.seed(1234) 
samples = np.random.choice(df_unique, size=(x, 2), replace=True) - neg_set = set([tuple([i[0], i[1]]) for i in samples if i[0] != i[1]]) - pos_set + neg_set = set([tuple([i[0], i[1]]) for i in samples if i[0] != i[1] + ]) - pos_set while len(neg_set) < x: sample = np.random.choice(df_unique, 2, replace=False) @@ -207,11 +212,17 @@ def NegSample(df, column_names, frac, two_types): for i in neg_list: neg_list_val.append([i[0], id2seq[i[0]], i[1], id2seq[i[1]], 0]) - df = df.append( - pd.DataFrame(neg_list_val).rename( - columns={0: id1, 1: x1, 2: id2, 3: x2, 4: "Y"} - ) - ).reset_index(drop=True) + df = pd.concat([ + df, + pd.DataFrame(neg_list_val).rename(columns={ + 0: id1, + 1: x1, + 2: id2, + 3: x2, + 4: "Y" + }) + ], + ignore_index=True, sort=False) return df else: df_unique_id1 = np.unique(df[id1].values.reshape(-1)) @@ -224,16 +235,11 @@ def NegSample(df, column_names, frac, two_types): sample_id1 = np.random.choice(df_unique_id1, size=len(df), replace=True) sample_id2 = np.random.choice(df_unique_id2, size=len(df), replace=True) - neg_set = ( - set( - [ - tuple([sample_id1[i], sample_id2[i]]) - for i in range(len(df)) - if sample_id1[i] != sample_id2[i] - ] - ) - - pos_set - ) + neg_set = (set([ + tuple([sample_id1[i], sample_id2[i]]) + for i in range(len(df)) + if sample_id1[i] != sample_id2[i] + ]) - pos_set) while len(neg_set) < len(df): sample_id1 = np.random.choice(df_unique_id1, size=1, replace=True) @@ -251,9 +257,15 @@ def NegSample(df, column_names, frac, two_types): for i in neg_list: neg_list_val.append([i[0], id2seq1[i[0]], i[1], id2seq2[i[1]], 0]) - df = df.append( - pd.DataFrame(neg_list_val).rename( - columns={0: id1, 1: x1, 2: id2, 3: x2, 4: "Y"} - ) - ).reset_index(drop=True) + df = pd.concat([ + df, + pd.DataFrame(neg_list_val).rename(columns={ + 0: id1, + 1: x1, + 2: id2, + 3: x2, + 4: "Y" + }) + ], + ignore_index=True, sort=False) return df diff --git a/tdc/utils/label_name_list.py b/tdc/utils/label_name_list.py index e164d91c..24a433dc 100644 --- a/tdc/utils/label_name_list.py +++ b/tdc/utils/label_name_list.py @@ -636,11 +636,9 @@ "Tanguay_ZF_120hpf_YSE_up", ] - QM7_targets = ["Y"] # QM7_targets = ["E_PBE0", "E_max_EINDO", "I_max_ZINDO", "HOMO_ZINDO", "LUMO_ZINDO", "E_1st_ZINDO", "IP_ZINDO", "EA_ZINDO", "HOMO_PBE0", "LUMO_PBE0", "HOMO_GW", "LUMO_GW", "alpha_PBE0", "alpha_SCS"] - #### qm7b: 14 labels QM7b_targets = [ "AE_PBE0", @@ -683,7 +681,6 @@ "f1-CAM", ] - # QM9_targets = [ # "mu", "alpha", "homo", "lumo", "gap", "r2", "zpve", "cv", "u0", "u298", # "h298", "g298" diff --git a/tdc/utils/load.py b/tdc/utils/load.py index 82b348e4..ea88283a 100644 --- a/tdc/utils/load.py +++ b/tdc/utils/load.py @@ -68,12 +68,16 @@ def download_wrapper(name, path, dataset_names): os.mkdir(path) if os.path.exists( - os.path.join(path, name + "-" + str(i + 1) + "." + name2type[name]) - ): + os.path.join( + path, name + "-" + str(i + 1) + "." 
+ name2type[name])): print_sys("Found local copy...") else: print_sys("Downloading...") - dataverse_download(dataset_path, path, name, name2type, id=i + 1) + dataverse_download(dataset_path, + path, + name, + name2type, + id=i + 1) return name else: @@ -118,11 +122,16 @@ def zip_data_download_wrapper(name, path, dataset_names): ) else: print_sys(f"Downloading {i+1}/{len(name2idlist[name])} file...") - dataverse_download(dataset_path, path, name, name2type, id=i + 1) - print_sys(f"Extracting zip {i+1}/{len(name2idlist[name])} file...") + dataverse_download(dataset_path, + path, + name, + name2type, + id=i + 1) + print_sys( + f"Extracting zip {i+1}/{len(name2idlist[name])} file...") with ZipFile( - os.path.join(path, name + "-" + str(i + 1) + ".zip"), "r" - ) as zip: + os.path.join(path, name + "-" + str(i + 1) + ".zip"), + "r") as zip: zip.extractall(path=os.path.join(path)) if not os.path.exists(os.path.join(path, name)): os.mkdir(os.path.join(path, name)) @@ -196,7 +205,8 @@ def oracle_download_wrapper(name, path, oracle_names): print_sys("Found local copy...") else: print_sys("Downloading Oracle...") - dataverse_download(dataset_path, path, name, oracle2type) ## to-do to-check + dataverse_download(dataset_path, path, name, + oracle2type) ## to-do to-check print_sys("Done!") return name @@ -222,19 +232,16 @@ def receptor_download_wrapper(name, path): os.mkdir(path) if os.path.exists(os.path.join(path, name + ".pdbqt")) and os.path.exists( - os.path.join(path, name + ".pdb") - ): + os.path.join(path, name + ".pdb")): print_sys("Found local copy...") else: print_sys("Downloading receptor...") receptor2type = defaultdict(lambda: "pdbqt") - dataverse_download( - dataset_paths[0], path, name, receptor2type - ) ## to-do to-check + dataverse_download(dataset_paths[0], path, name, + receptor2type) ## to-do to-check receptor2type = defaultdict(lambda: "pdb") - dataverse_download( - dataset_paths[1], path, name, receptor2type - ) ## to-do to-check + dataverse_download(dataset_paths[1], path, name, + receptor2type) ## to-do to-check print_sys("Done!") return name @@ -284,11 +291,13 @@ def pd_load(name, path): """ try: if name2type[name] == "tab": - df = pd.read_csv(os.path.join(path, name + "." + name2type[name]), sep="\t") + df = pd.read_csv(os.path.join(path, name + "." + name2type[name]), + sep="\t") elif name2type[name] == "csv": df = pd.read_csv(os.path.join(path, name + "." + name2type[name])) elif name2type[name] == "pkl": - df = pd.read_pickle(os.path.join(path, name + "." + name2type[name])) + df = pd.read_pickle(os.path.join(path, + name + "." + name2type[name])) elif name2type[name] == "zip": df = pd.read_pickle(os.path.join(path, name + "/" + name + ".pkl")) else: @@ -328,7 +337,8 @@ def property_dataset_load(name, path, target, dataset_names): target = fuzzy_search(target, df.columns.values) # df = df.T.drop_duplicates().T ### does not work # df2 = df.loc[:,~df.T.duplicated(keep='first')] ### does not work - df2 = df.loc[:, ~df.columns.duplicated()] ### remove the duplicate columns + df2 = df.loc[:, + ~df.columns.duplicated()] ### remove the duplicate columns df = df2 df = df[df[target].notnull()].reset_index(drop=True) except: @@ -337,8 +347,8 @@ def property_dataset_load(name, path, target, dataset_names): import pickle file_content = pickle.load( - open(os.path.join(path, name + "." + name2type[name]), "rb") - ) + open(os.path.join(path, name + "." 
+ name2type[name]), + "rb")) else: file_content = " ".join(f.readlines()) flag = "Service Unavailable" in " ".join(file_content) @@ -352,7 +362,8 @@ def property_dataset_load(name, path, target, dataset_names): else: import sys - sys.exit("Please report this error to contact@tdcommons.ai, thanks!") + sys.exit( + "Please report this error to contact@tdcommons.ai, thanks!") try: return df["X"], df[target], df["ID"] except: @@ -386,7 +397,8 @@ def interaction_dataset_load(name, path, target, dataset_names, aux_column): if aux_column is None: return df["X1"], df["X2"], df[target], df["ID1"], df["ID2"], "_" else: - return df["X1"], df["X2"], df[target], df["ID1"], df["ID2"], df[aux_column] + return df["X1"], df["X2"], df[target], df["ID1"], df["ID2"], df[ + aux_column] except: with open(os.path.join(path, name + "." + name2type[name]), "r") as f: @@ -400,7 +412,8 @@ def interaction_dataset_load(name, path, target, dataset_names, aux_column): else: import sys - sys.exit("Please report this error to cosamhkx@gmail.com, thanks!") + sys.exit( + "Please report this error to cosamhkx@gmail.com, thanks!") def multi_dataset_load(name, path, dataset_names): @@ -421,7 +434,8 @@ def multi_dataset_load(name, path, dataset_names): return df -def generation_paired_dataset_load(name, path, dataset_names, input_name, output_name): +def generation_paired_dataset_load(name, path, dataset_names, input_name, + output_name): """a wrapper to download, process and load generation-paired task datasets Args: @@ -502,26 +516,26 @@ def bi_distribution_dataset_load( if name == "pdbbind": print_sys("Processing (this may take long)...") - protein, ligand = process_pdbbind( - path, name, return_pocket, remove_protein_Hs, remove_ligand_Hs, keep_het - ) + protein, ligand = process_pdbbind(path, name, return_pocket, + remove_protein_Hs, remove_ligand_Hs, + keep_het) elif name == "dude": print_sys("Processing (this may take long)...") if return_pocket: raise ImportError("DUD-E does not support pocket extraction yet") - protein, ligand = process_dude( - path, name, return_pocket, remove_protein_Hs, remove_ligand_Hs, keep_het - ) + protein, ligand = process_dude(path, name, return_pocket, + remove_protein_Hs, remove_ligand_Hs, + keep_het) elif name == "scpdb": print_sys("Processing (this may take long)...") - protein, ligand = process_scpdb( - path, name, return_pocket, remove_protein_Hs, remove_ligand_Hs, keep_het - ) + protein, ligand = process_scpdb(path, name, return_pocket, + remove_protein_Hs, remove_ligand_Hs, + keep_het) elif name == "crossdock": print_sys("Processing (this may take long)...") - protein, ligand = process_crossdock( - path, name, return_pocket, remove_protein_Hs, remove_ligand_Hs, keep_het - ) + protein, ligand = process_crossdock(path, name, return_pocket, + remove_protein_Hs, remove_ligand_Hs, + keep_het) return protein, ligand @@ -631,15 +645,13 @@ def process_pdbbind( try: if return_pocket: protein = PandasPdb().read_pdb( - os.path.join(path, f"{file}/{file}_pocket.pdb") - ) + os.path.join(path, f"{file}/{file}_pocket.pdb")) else: protein = PandasPdb().read_pdb( - os.path.join(path, f"{file}/{file}_protein.pdb") - ) - ligand = Chem.SDMolSupplier( - os.path.join(path, f"{file}/{file}_ligand.sdf"), sanitize=False - )[0] + os.path.join(path, f"{file}/{file}_protein.pdb")) + ligand = Chem.SDMolSupplier(os.path.join( + path, f"{file}/{file}_ligand.sdf"), + sanitize=False)[0] ligand = extract_atom_from_mol(ligand, remove_ligand_Hs) # if ligand contains unallowed atoms if ligand is None: @@ -716,17 +728,16 @@ 
def process_crossdock( else: # full protein not stored in the preprocessed crossdock by Luo et al 2021 protein = PandasPdb().read_pdb(os.path.join(path, pocket_fn)) - ligand = Chem.SDMolSupplier(os.path.join(path, ligand_fn), sanitize=False)[ - 0 - ] + ligand = Chem.SDMolSupplier(os.path.join(path, ligand_fn), + sanitize=False)[0] ligand = extract_atom_from_mol(ligand, remove_ligand_Hs) if ligand is None: continue else: ligand_coord, ligand_atom_type = ligand protein_coord, protein_atom_type = extract_atom_from_protein( - protein.df["ATOM"], protein.df["HETATM"], remove_protein_Hs, keep_het - ) + protein.df["ATOM"], protein.df["HETATM"], remove_protein_Hs, + keep_het) protein_coords.append(protein_coord) ligand_coords.append(ligand_coord) protein_atom_types.append(protein_atom_type) @@ -778,23 +789,24 @@ def process_dude( failure = 0 total_ct = 0 for idx, file in enumerate(tqdm(files)): - protein = PandasPdb().read_pdb(os.path.join(path, f"{file}/receptor.pdb")) + protein = PandasPdb().read_pdb( + os.path.join(path, f"{file}/receptor.pdb")) if not os.path.exists(os.path.join(path, f"{file}/actives_final.sdf")): os.system(f"gzip -d {path}/{file}/actives_final.sdf.gz") - crystal_ligand = Chem.MolFromMol2File( - os.path.join(path, f"{file}/crystal_ligand.mol2"), sanitize=False - ) + crystal_ligand = Chem.MolFromMol2File(os.path.join( + path, f"{file}/crystal_ligand.mol2"), + sanitize=False) crystal_ligand = extract_atom_from_mol(crystal_ligand, remove_ligand_Hs) if crystal_ligand is None: continue else: crystal_ligand_coord, crystal_ligand_atom_type = crystal_ligand - ligands = Chem.SDMolSupplier( - os.path.join(path, f"{file}/actives_final.sdf"), sanitize=False - ) + ligands = Chem.SDMolSupplier(os.path.join(path, + f"{file}/actives_final.sdf"), + sanitize=False) protein_coord, protein_atom_type = extract_atom_from_protein( - protein.df["ATOM"], protein.df["HETATM"], remove_protein_Hs, keep_het - ) + protein.df["ATOM"], protein.df["HETATM"], remove_protein_Hs, + keep_het) protein_coords.append(protein_coord) ligand_coords.append(crystal_ligand_coord) protein_atom_types.append(protein_atom_type) @@ -863,15 +875,13 @@ def process_scpdb( try: if return_pocket: protein = PandasMol2().read_mol2( - os.path.join(path, f"{file}/site.mol2") - ) + os.path.join(path, f"{file}/site.mol2")) else: protein = PandasMol2().read_mol2( - os.path.join(path, f"{file}/protein.mol2") - ) - ligand = Chem.SDMolSupplier( - os.path.join(path, f"{file}/ligand.sdf"), sanitize=False - )[0] + os.path.join(path, f"{file}/protein.mol2")) + ligand = Chem.SDMolSupplier(os.path.join(path, + f"{file}/ligand.sdf"), + sanitize=False)[0] ligand = extract_atom_from_mol(ligand, remove_Hs=remove_ligand_Hs) # if ligand contains unallowed atoms if ligand is None: @@ -879,8 +889,7 @@ def process_scpdb( else: ligand_coord, ligand_atom_type = ligand protein_coord, protein_atom_type = extract_atom_from_protein( - protein.df, None, remove_Hs=remove_protein_Hs, keep_het=False - ) + protein.df, None, remove_Hs=remove_protein_Hs, keep_het=False) protein_coords.append(protein_coord) ligand_coords.append(ligand_coord) protein_atom_types.append(protein_atom_type) @@ -957,23 +966,15 @@ def extract_atom_from_protein(data_frame, data_frame_het, remove_Hs, keep_het): if keep_het and data_frame_het is not None: data_frame = pd.concat([data_frame, data_frame_het]) if remove_Hs: - data_frame = data_frame[data_frame["atom_name"].str.startswith("H") == False] + data_frame = data_frame[data_frame["atom_name"].str.startswith("H") == + False] 
data_frame.reset_index(inplace=True, drop=True) - x = ( - data_frame["x_coord"].to_numpy() - if "x_coord" in data_frame - else data_frame["x"].to_numpy() - ) - y = ( - data_frame["y_coord"].to_numpy() - if "y_coord" in data_frame - else data_frame["y"].to_numpy() - ) - z = ( - data_frame["z_coord"].to_numpy() - if "z_coord" in data_frame - else data_frame["z"].to_numpy() - ) + x = (data_frame["x_coord"].to_numpy() + if "x_coord" in data_frame else data_frame["x"].to_numpy()) + y = (data_frame["y_coord"].to_numpy() + if "y_coord" in data_frame else data_frame["y"].to_numpy()) + z = (data_frame["z_coord"].to_numpy() + if "z_coord" in data_frame else data_frame["z"].to_numpy()) x = np.expand_dims(x, axis=1) y = np.expand_dims(y, axis=1) z = np.expand_dims(z, axis=1) diff --git a/tdc/utils/misc.py b/tdc/utils/misc.py index c7f65abb..57dc9dc1 100644 --- a/tdc/utils/misc.py +++ b/tdc/utils/misc.py @@ -33,7 +33,8 @@ def fuzzy_search(name, dataset_names): return s else: raise ValueError( - s + " does not belong to this task, please refer to the correct task name!" + s + + " does not belong to this task, please refer to the correct task name!" ) @@ -56,7 +57,9 @@ def get_closet_match(predefined_tokens, test_token, threshold=0.8): for token in predefined_tokens: # print(token) - prob_list.append(fuzz.ratio(str(token).lower(), str(test_token).lower())) + prob_list.append(fuzz.ratio( + str(token).lower(), + str(test_token).lower())) assert len(prob_list) == len(predefined_tokens) @@ -67,8 +70,8 @@ def get_closet_match(predefined_tokens, test_token, threshold=0.8): if prob_max / 100 < threshold: print_sys(predefined_tokens) raise ValueError( - test_token, "does not match to available values. " "Please double check." - ) + test_token, "does not match to available values. 
" + "Please double check.") return token_max, prob_max / 100 diff --git a/tdc/utils/query.py b/tdc/utils/query.py index 1d450844..a4dede37 100644 --- a/tdc/utils/query.py +++ b/tdc/utils/query.py @@ -15,7 +15,8 @@ def _parse_prop(search, proplist): """Extract property value from record using the given urn search filter.""" props = [ - i for i in proplist if all(item in i["urn"].items() for item in search.items()) + i for i in proplist + if all(item in i["urn"].items() for item in search.items()) ] if len(props) > 0: return props[0]["value"][list(props[0]["value"].keys())[0]] @@ -48,18 +49,15 @@ def request( urlid, postdata = None, None if namespace == "sourceid": identifier = identifier.replace("/", ".") - if ( - namespace in ["listkey", "formula", "sourceid"] - or searchtype == "xref" - or (searchtype and namespace == "cid") - or domain == "sources" - ): + if (namespace in ["listkey", "formula", "sourceid"] or + searchtype == "xref" or (searchtype and namespace == "cid") or + domain == "sources"): urlid = quote(identifier.encode("utf8")) else: postdata = urlencode([(namespace, identifier)]).encode("utf8") comps = filter( - None, [API_BASE, domain, searchtype, namespace, urlid, operation, output] - ) + None, + [API_BASE, domain, searchtype, namespace, urlid, operation, output]) apiurl = "/".join(comps) # Make request response = urlopen(apiurl, postdata) @@ -99,8 +97,12 @@ def cid2smiles(cid): """ try: smiles = _parse_prop( - {"label": "SMILES", "name": "Canonical"}, - json.loads(request(cid).read().decode())["PC_Compounds"][0]["props"], + { + "label": "SMILES", + "name": "Canonical" + }, + json.loads( + request(cid).read().decode())["PC_Compounds"][0]["props"], ) except: print("cid " + str(cid) + " failed, use NULL string") diff --git a/tdc/utils/retrieve.py b/tdc/utils/retrieve.py index 429cfbab..a50c8bca 100644 --- a/tdc/utils/retrieve.py +++ b/tdc/utils/retrieve.py @@ -73,7 +73,8 @@ def get_reaction_type(name, path="./data", output_format="array"): elif output_format == "array": return df["category"].values else: - raise ValueError("Please use the correct output format, select from df, array.") + raise ValueError( + "Please use the correct output format, select from df, array.") def retrieve_label_name_list(name): diff --git a/tdc/utils/split.py b/tdc/utils/split.py index c8f518cc..c1eef453 100644 --- a/tdc/utils/split.py +++ b/tdc/utils/split.py @@ -21,9 +21,9 @@ def create_fold(df, fold_seed, frac): train_frac, val_frac, test_frac = frac test = df.sample(frac=test_frac, replace=False, random_state=fold_seed) train_val = df[~df.index.isin(test.index)] - val = train_val.sample( - frac=val_frac / (1 - test_frac), replace=False, random_state=1 - ) + val = train_val.sample(frac=val_frac / (1 - test_frac), + replace=False, + random_state=1) train = train_val[~train_val.index.isin(val.index)] return { @@ -54,10 +54,9 @@ def create_fold_setting_cold(df, fold_seed, frac, entities): # For each entity, sample the instances belonging to the test datasets test_entity_instances = [ - df[e] - .drop_duplicates() - .sample(frac=test_frac, replace=False, random_state=fold_seed) - .values + df[e].drop_duplicates().sample(frac=test_frac, + replace=False, + random_state=fold_seed).values for e in entities ] @@ -69,8 +68,7 @@ def create_fold_setting_cold(df, fold_seed, frac, entities): if len(test) == 0: raise ValueError( "No test samples found. Try another seed, increasing the test frac or a " - "less stringent splitting strategy." 
- ) + "less stringent splitting strategy.") # Proceed with validation data train_val = df.copy() @@ -78,10 +76,9 @@ def create_fold_setting_cold(df, fold_seed, frac, entities): train_val = train_val[~train_val[e].isin(test_entity_instances[i])] val_entity_instances = [ - train_val[e] - .drop_duplicates() - .sample(frac=val_frac / (1 - test_frac), replace=False, random_state=fold_seed) - .values + train_val[e].drop_duplicates().sample(frac=val_frac / (1 - test_frac), + replace=False, + random_state=fold_seed).values for e in entities ] val = train_val.copy() @@ -91,8 +88,7 @@ def create_fold_setting_cold(df, fold_seed, frac, entities): if len(val) == 0: raise ValueError( "No validation samples found. Try another seed, increasing the test frac " - "or a less stringent splitting strategy." - ) + "or a less stringent splitting strategy.") train = train_val.copy() for i, e in enumerate(entities): @@ -127,8 +123,7 @@ def create_scaffold_split(df, seed, frac, entity): RDLogger.DisableLog("rdApp.*") except: raise ImportError( - "Please install rdkit by 'conda install -c conda-forge rdkit'! " - ) + "Please install rdkit by 'conda install -c conda-forge rdkit'! ") from tqdm import tqdm from random import Random @@ -144,8 +139,7 @@ def create_scaffold_split(df, seed, frac, entity): for i, smiles in tqdm(enumerate(s), total=len(s)): try: scaffold = MurckoScaffold.MurckoScaffoldSmiles( - mol=Chem.MolFromSmiles(smiles), includeChirality=False - ) + mol=Chem.MolFromSmiles(smiles), includeChirality=False) scaffolds[scaffold].add(i) except: print_sys(smiles + " returns RDKit error and is thus omitted...") @@ -213,9 +207,9 @@ def create_combination_generation_split(dict1, dict2, seed, frac): length = len(dict1["coord"]) indices = np.random.permutation(length) train_idx, val_idx, test_idx = ( - indices[: int(length * train_frac)], - indices[int(length * train_frac) : int(length * (train_frac + val_frac))], - indices[int(length * (train_frac + val_frac)) :], + indices[:int(length * train_frac)], + indices[int(length * train_frac):int(length * (train_frac + val_frac))], + indices[int(length * (train_frac + val_frac)):], ) return { @@ -272,9 +266,10 @@ def create_combination_split(df, seed, frac): intxn = intxn.intersection(c) # Split combinations into train, val and test - test_choices = np.random.choice( - list(intxn), int(test_size / len(df["Cell_Line_ID"].unique())), replace=False - ) + test_choices = np.random.choice(list(intxn), + int(test_size / + len(df["Cell_Line_ID"].unique())), + replace=False) trainval_intxn = intxn.difference(test_choices) val_choices = np.random.choice( list(trainval_intxn), @@ -312,15 +307,18 @@ def create_fold_time(df, frac, date_column): df = df.sort_values(by=date_column).reset_index(drop=True) train_frac, val_frac, test_frac = frac[0], frac[1], frac[2] - split_date = df[: int(len(df) * (train_frac + val_frac))].iloc[-1][date_column] + split_date = df[:int(len(df) * + (train_frac + val_frac))].iloc[-1][date_column] test = df[df[date_column] >= split_date].reset_index(drop=True) train_val = df[df[date_column] < split_date] - split_date_valid = train_val[ - : int(len(train_val) * train_frac / (train_frac + val_frac)) - ].iloc[-1][date_column] - train = train_val[train_val[date_column] <= split_date_valid].reset_index(drop=True) - valid = train_val[train_val[date_column] > split_date_valid].reset_index(drop=True) + split_date_valid = train_val[:int( + len(train_val) * train_frac / + (train_frac + val_frac))].iloc[-1][date_column] + train = train_val[train_val[date_column] <= 
split_date_valid].reset_index( + drop=True) + valid = train_val[train_val[date_column] > split_date_valid].reset_index( + drop=True) return { "train": train,