From 01d3a378553db7eddd5cb8e2ec3918194893026b Mon Sep 17 00:00:00 2001 From: joegilkes Date: Thu, 11 Jan 2024 16:53:02 +0000 Subject: [PATCH] Fixed problems with separate test datasets --- KPM/cli/test_args.py | 4 ++-- KPM/train.py | 7 ++++--- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/KPM/cli/test_args.py b/KPM/cli/test_args.py index 19edfd6..425fdff 100644 --- a/KPM/cli/test_args.py +++ b/KPM/cli/test_args.py @@ -40,8 +40,8 @@ def run(args): tester = ModelTester(args.model, args.dataset, args.num_reacs, args.plot_dir, args.verbose) X_test, y_test = tester.process_test_data() - Eact_pred_test = tester.predict(X_test, y_test, 'test') - tester.plot_correlation(y_test, Eact_pred_test, 'test') + Eact_pred_test, Eact_uncert_test = tester.predict(X_test, y_test, 'test') + tester.plot_correlation(y_test, Eact_pred_test, Eact_uncert_test, 'test') print('KPM finished.') input('Press ENTER to close.') diff --git a/KPM/train.py b/KPM/train.py index ac334b2..c77249c 100644 --- a/KPM/train.py +++ b/KPM/train.py @@ -150,8 +150,8 @@ def process(self): ea_test, dh_test, rs_test, ps_test = load_dataset(self.separate_test_dataset) # Extract and transform data dependent on train_direction. - num_train_reacs = len(ea_train) - num_test_reacs = len(ea_test) + num_train_reacs = len(ea_train) * (2 if self.train_direction=='both' else 1) + num_test_reacs = len(ea_test) * (2 if self.train_direction=='both' else 1) Eact_train, dH_train, rmol_train, pmol_train = extract_data(ea_train, dh_train, rs_train, ps_train, num_train_reacs, self.train_direction) Eact_test, dH_test, rmol_test, pmol_test = extract_data(ea_test, dh_test, rs_test, ps_test, @@ -376,7 +376,8 @@ def process_test_data(self): print(f'Total number of MOL objects = {len(rmol)}\n') # Normalise Eact - Eact = normalise(Eact, self.norm_avg_Eact, self.norm_std_Eact, self.norm_type) + if self.norm_eacts: + Eact = normalise(Eact, self.norm_avg_Eact, self.norm_std_Eact, self.norm_type) if self.verbose: print('Data loaded. Calculating reaction difference fingerprints.')