Skip to content

Commit

Permalink
Various fixes to main script
Browse files (browse the repository at this point in the history)
  • Loading branch information
HealthyPear committed Mar 30, 2021
1 parent 04ba81a commit 4a16c42
Showing 1 changed file with 13 additions and 6 deletions.
19 changes: 13 additions & 6 deletions protopipe/scripts/build_model.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -30,7 +30,7 @@ def main():
parser.add_argument(
"--max_events",
type=int,
default=-1,
default=None,
help="maximum number of events for training",
)
mode_group = parser.add_mutually_exclusive_group()
Expand Down Expand Up @@ -157,7 +157,7 @@ def main():
if args.infile_background is None:
data_bkg_file = cfg["General"]["data_bkg_file"].format(args.mode)
else:
data_sig_file = args.infile_background
data_bkg_file = args.infile_background

filename_sig = path.join(data_dir, data_sig_file)
filename_bkg = path.join(data_dir, data_bkg_file)
Expand All @@ -167,7 +167,9 @@ def main():
cam_ids = cfg["General"]["cam_id_list"]
elif args.cameras_from_file:
print("TAKING CAMERAS FROM TRAINING FILE")
cam_ids = get_camera_names(filename)
# in the same analysis all particle types are analyzed in the
# same way so we can just use gammas
cam_ids = get_camera_names(filename_sig)
else:
print("TAKING CAMERAS FROM CLI")
cam_ids = args.cam_id_lists.split()
Expand Down Expand Up @@ -212,7 +214,9 @@ def main():
if model_type in "regressor":
# Load data
data = pd.read_hdf(filename, table_name[idx], mode="r")
data = prepare_data(ds=data, cuts=cuts)[0 : args.max_events]
data = prepare_data(ds=data, cuts=cuts)[0:args.max_events]

print(f"Going to split {len(data)} SIGNAL images...")

# Init model factory
factory = TrainModel(
Expand All @@ -232,8 +236,11 @@ def main():
data_sig = prepare_data(ds=data_sig, label=1, cuts=sig_cuts)
data_bkg = prepare_data(ds=data_bkg, label=0, cuts=bkg_cuts)

data_sig = data_sig[0 : args.max_events]
data_bkg = data_bkg[0 : args.max_events]
if args.max_events:
data_sig = data_sig[0:(args.max_events - 1)]
data_bkg = data_bkg[0:(args.max_events - 1)]

print(f"Going to split {len(data_sig)} SIGNAL images and {len(data_bkg)} BACKGROUND images")

# Init model factory
factory = TrainModel(
Expand Down

0 comments on commit 4a16c42

Please sign in to comment.