diff --git a/scripts/SAUCIE.py b/scripts/SAUCIE.py index 4aa81b7..00ef45e 100644 --- a/scripts/SAUCIE.py +++ b/scripts/SAUCIE.py @@ -109,7 +109,8 @@ def train_batch_correction(rawfiles): print("Training model {}".format(counter)) nonrefx = get_data(nonref) - alldata = np.concatenate([refx.as_matrix(), nonrefx.as_matrix()], axis=0) + # alldata = np.concatenate([refx.as_matrix(), nonrefx.as_matrix()], axis=0) + alldata = np.concatenate([np.array(refx), np.array(nonrefx)], axis = 0) alllabels = np.concatenate([np.zeros(refx.shape[0]), np.ones(nonrefx.shape[0])], axis=0) load = SAUCIE.Loader(data=alldata, labels=alllabels, shuffle=True) @@ -146,7 +147,8 @@ def output_batch_correction(rawfiles): print("Outputing file {}".format(counter)) nonrefx = get_data(nonref) - alldata = np.concatenate([refx.as_matrix(), nonrefx.as_matrix()], axis=0) + # alldata = np.concatenate([refx.as_matrix(), nonrefx.as_matrix()], axis=0) + alldata = np.concatenate([np.array(refx), np.array(nonrefx)], axis = 0) alllabels = np.concatenate([np.zeros(refx.shape[0]), np.ones(nonrefx.shape[0])], axis=0) load = SAUCIE.Loader(data=alldata, labels=alllabels, shuffle=False) @@ -371,4 +373,4 @@ def parse_args(): - print("Finished training models and outputing data!") \ No newline at end of file + print("Finished training models and outputing data!")