Commit

Merge branch 'main' into keep-nest-oustide-validity
HenningSE committed Jun 24, 2024
2 parents cf5bbaf + cd0f0fe commit 52999c9
Showing 1 changed file with 66 additions and 27 deletions.

fuse/plugins/micro_physics/input.py
@@ -10,7 +10,7 @@
import straxen

from ...dtypes import g4_fields, primary_positions_fields, deposit_positions_fields
from ...common import full_array_to_numpy, reshape_awkward, dynamic_chunking
from ...common import full_array_to_numpy, reshape_awkward, dynamic_chunking, awkward_to_flat_numpy
from ...plugin import FuseBasePlugin

export, __all__ = strax.exporter()
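
For reference, the newly imported awkward_to_flat_numpy helper flattens a jagged awkward array into a plain numpy array. A minimal sketch of such a helper, assuming it simply wraps ak.flatten (the actual implementation in fuse.common may differ):

import awkward as ak
import numpy as np

def awkward_to_flat_numpy_sketch(array):
    # Drop one level of nesting ([[a, b], [c]] -> [a, b, c])
    # and return the result as a plain numpy array
    return ak.to_numpy(ak.flatten(array))

# awkward_to_flat_numpy_sketch(ak.Array([[1, 2], [3]])) -> array([1, 2, 3])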
@@ -240,20 +240,31 @@ def output_chunk(self):
if self.event_rate > 0:
interactions["t"] = interactions["t"] - interactions["t"][:, 0]

inter_reshaped = full_array_to_numpy(interactions, self.dtype)
# Get the interaction times into flat numpy array
interaction_time = awkward_to_flat_numpy(interactions["t"])

# Need to check start and stop again....
# Remove interactions that happen way after the run ended
# we will apply the cut later on the times instead of t
delay_cut = interaction_time <= self.cut_delayed
log.info(
f"Removing {np.sum(~delay_cut)} ({np.sum(~delay_cut) / len(delay_cut):.4%}) "
f"interactions later than {self.cut_delayed:.2e} ns."
)

# Adjust event times if necessary
if self.event_rate > 0:

# We need to get the number of interactions in the event,
# as we could have empty events for the TPC but not for other detectors
num_interactions = len(interactions["t"])
event_times = self.rng.uniform(
low=start / self.event_rate, high=stop / self.event_rate, size=stop - start
low=start / self.event_rate, high=stop / self.event_rate, size=num_interactions
).astype(np.int64)

event_times = np.sort(event_times)

structure = np.unique(inter_reshaped["eventid"], return_counts=True)[1]
interactions["time"] = interactions["t"] + event_times

# Check again why [:len(structure)] is needed
interaction_time = np.repeat(event_times[: len(structure)], structure)
inter_reshaped["time"] = interaction_time + inter_reshaped["t"]
elif self.event_rate == 0:
log.info("Using event times from provided input file.")
if self.file_type == "root":
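
The core simplification in this hunk is interactions["time"] = interactions["t"] + event_times: awkward arrays broadcast one event time against each event's jagged list of interaction times, replacing the old np.unique/np.repeat bookkeeping on the flattened array. A toy illustration with made-up values:

import awkward as ak
import numpy as np

# Two events with 3 and 1 interactions respectively
t = ak.Array([[0, 5, 7], [2]])
# One randomly drawn absolute time per event
event_times = np.array([1_000, 2_000], dtype=np.int64)

# Awkward pairs the outer dimension and broadcasts each event
# time into that event's list of interaction times
print(t + event_times)  # [[1000, 1005, 1007], [2002]]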
@@ -262,32 +273,33 @@ def output_chunk(self):
"Use a source_rate > 0 instead."
)
log.warning(msg)
inter_reshaped["time"] = inter_reshaped["t"]
interactions["time"] = interactions["t"]

else:
raise ValueError("Source rate cannot be negative!")

# Remove interactions that happen way after the run ended
delay_cut = inter_reshaped["t"] <= self.cut_delayed
log.info(
f"Removing {np.sum(~delay_cut)} ({np.sum(~delay_cut) / len(delay_cut):.4%}) "
f"interactions later than {self.cut_delayed:.2e} ns."
)
inter_reshaped = inter_reshaped[delay_cut]

sort_idx = np.argsort(inter_reshaped["time"])
inter_reshaped = inter_reshaped[sort_idx]
# Overwrite interaction_time (based on "t") with the new event times
interaction_time = awkward_to_flat_numpy(interactions["time"])
# First calculate sort index for the interaction times
sort_idx = np.argsort(interaction_time)
# and now make it an integer for strax time field
interaction_time = interaction_time.astype(np.int64)
# Sort the interaction times
interaction_time = interaction_time[sort_idx]
# Apply the delay cut
interaction_time = interaction_time[delay_cut]

# Group into chunks
chunk_idx = dynamic_chunking(
inter_reshaped["time"], scale=self.separation_scale, n_min=self.n_interactions_per_chunk
interaction_time, scale=self.separation_scale, n_min=self.n_interactions_per_chunk
)

# Calculate chunk start and end times
unique_chunk_index_values = np.unique(chunk_idx)

chunk_start = np.array(
[inter_reshaped[chunk_idx == i][0]["time"] for i in np.unique(chunk_idx)]
[interaction_time[chunk_idx == i][0] for i in unique_chunk_index_values]
)
chunk_end = np.array(
[inter_reshaped[chunk_idx == i][-1]["time"] for i in np.unique(chunk_idx)]
[interaction_time[chunk_idx == i][-1] for i in unique_chunk_index_values]
)

if (len(chunk_start) > 1) & (len(chunk_end) > 1):
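
dynamic_chunking, imported from the package's common module, assigns a chunk index to every (sorted) time stamp. Judging from the call site alone, it plausibly opens a new chunk when the gap to the previous entry exceeds scale and the current chunk already holds at least n_min entries; the toy re-implementation below is a guess under that assumption, not fuse's actual code:

import numpy as np

def dynamic_chunking_sketch(times, scale, n_min):
    # Assumed behaviour: walk the sorted times and start a new chunk
    # once the gap to the previous entry exceeds `scale`, but only
    # after the current chunk has collected at least `n_min` entries
    chunk_idx = np.zeros(len(times), dtype=np.int64)
    current, count = 0, 1
    for i in range(1, len(times)):
        if times[i] - times[i - 1] > scale and count >= n_min:
            current += 1
            count = 0
        chunk_idx[i] = current
        count += 1
    return chunk_idx

times = np.array([0, 1, 2, 100, 101, 200])
print(dynamic_chunking_sketch(times, scale=10, n_min=2))  # [0 0 0 1 1 2]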
@@ -307,17 +319,44 @@ def output_chunk(self):
chunk_end[0] + self.last_chunk_length,
]

# We need to get the min and max times for each event
# to preselect events with interactions in the chunk bounds
times_min = ak.to_numpy(ak.min(interactions["time"], axis=1)).astype(np.int64)
times_max = ak.to_numpy(ak.max(interactions["time"], axis=1)).astype(np.int64)

# Process and yield each chunk
source_done = False
unique_chunk_index_values = np.unique(chunk_idx)
log.info(f"Simulating data in {len(unique_chunk_index_values)} chunks.")
for c_ix, chunk_left, chunk_right in zip(
unique_chunk_index_values, self.chunk_bounds[:-1], self.chunk_bounds[1:]
):

# We do a preselection of the events that have interactions within the chunk
# before converting the full array to numpy (which is expensive in terms of memory)
m = (times_min <= chunk_right) & (times_max >= chunk_left)
current_chunk = interactions[m]

if len(current_chunk) == 0:
current_chunk = np.empty(0, dtype=self.dtype)

else:
# Convert the chunk from awkward array to a numpy array
current_chunk = full_array_to_numpy(current_chunk, self.dtype)

# Now we have the chunk of data in strax/numpy format
# We can now filter only the interactions within the chunk
select_times = current_chunk["time"] >= chunk_left
select_times &= current_chunk["time"] <= chunk_right
current_chunk = current_chunk[select_times]

# Sorting each chunk by time within the chunk
sort_chunk = np.argsort(current_chunk["time"])
current_chunk = current_chunk[sort_chunk]

if c_ix == unique_chunk_index_values[-1]:
source_done = True
log.debug("Last chunk created!")

yield inter_reshaped[chunk_idx == c_ix], chunk_left, chunk_right, source_done
yield current_chunk, chunk_left, chunk_right, source_done

def last_chunk_bounds(self):
return self.chunk_bounds[-1]
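
The per-chunk preselection in this last hunk is the memory-saving step: per-event minima and maxima of the time column give cheap overlap bounds, so only events that can intersect the chunk window go through the expensive full_array_to_numpy conversion. A toy version of the overlap mask with made-up values:

import awkward as ak
import numpy as np

# Per-event interaction times (three toy events)
time = ak.Array([[10, 40], [120], [300, 310]])
times_min = ak.to_numpy(ak.min(time, axis=1)).astype(np.int64)
times_max = ak.to_numpy(ak.max(time, axis=1)).astype(np.int64)

chunk_left, chunk_right = 0, 200
# Keep any event whose [min, max] span touches the chunk window;
# the exact per-interaction cut is applied later on the numpy records
m = (times_min <= chunk_right) & (times_max >= chunk_left)
print(m)  # [ True  True False]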