Skip to content

Commit

Permalink
Fixed problems with separate test datasets
Browse files Browse the repository at this point in the history
  • Loading branch information
joegilkes committed Jan 11, 2024
1 parent 73cda55 commit 01d3a37
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 5 deletions.
4 changes: 2 additions & 2 deletions KPM/cli/test_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ def run(args):
tester = ModelTester(args.model, args.dataset, args.num_reacs, args.plot_dir, args.verbose)

X_test, y_test = tester.process_test_data()
Eact_pred_test = tester.predict(X_test, y_test, 'test')
tester.plot_correlation(y_test, Eact_pred_test, 'test')
Eact_pred_test, Eact_uncert_test = tester.predict(X_test, y_test, 'test')
tester.plot_correlation(y_test, Eact_pred_test, Eact_uncert_test, 'test')

print('KPM finished.')
input('Press ENTER to close.')
Expand Down
7 changes: 4 additions & 3 deletions KPM/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,8 +150,8 @@ def process(self):
ea_test, dh_test, rs_test, ps_test = load_dataset(self.separate_test_dataset)

# Extract and transform data dependent on train_direction.
num_train_reacs = len(ea_train)
num_test_reacs = len(ea_test)
num_train_reacs = len(ea_train) * (2 if self.train_direction=='both' else 1)
num_test_reacs = len(ea_test) * (2 if self.train_direction=='both' else 1)
Eact_train, dH_train, rmol_train, pmol_train = extract_data(ea_train, dh_train, rs_train, ps_train,
num_train_reacs, self.train_direction)
Eact_test, dH_test, rmol_test, pmol_test = extract_data(ea_test, dh_test, rs_test, ps_test,
Expand Down Expand Up @@ -376,7 +376,8 @@ def process_test_data(self):
print(f'Total number of MOL objects = {len(rmol)}\n')

# Normalise Eact
Eact = normalise(Eact, self.norm_avg_Eact, self.norm_std_Eact, self.norm_type)
if self.norm_eacts:
Eact = normalise(Eact, self.norm_avg_Eact, self.norm_std_Eact, self.norm_type)

if self.verbose: print('Data loaded. Calculating reaction difference fingerprints.')

Expand Down

0 comments on commit 01d3a37

Please sign in to comment.