Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow specifying extractors per telescope, fixes #1660 #1661

Merged
merged 2 commits into from
Apr 15, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 22 additions & 8 deletions ctapipe/calib/camera/calibrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,11 @@
from ctapipe.core import TelescopeComponent
from ctapipe.image.extractor import ImageExtractor
from ctapipe.image.reducer import DataVolumeReducer
from ctapipe.core.traits import create_class_enum_trait, BoolTelescopeParameter
from ctapipe.core.traits import (
TelescopeParameter,
create_class_enum_trait,
BoolTelescopeParameter,
)

from numba import guvectorize, float64, float32, int64

Expand All @@ -36,8 +40,12 @@ class CameraCalibrator(TelescopeComponent):
DataVolumeReducer, default_value="NullDataVolumeReducer"
).tag(config=True)

image_extractor_type = create_class_enum_trait(
ImageExtractor, default_value="NeighborPeakWindowSum"
image_extractor_type = TelescopeParameter(
trait=create_class_enum_trait(
ImageExtractor, default_value="NeighborPeakWindowSum"
),
default_value="NeighborPeakWindowSum",
help="Name of the ImageExtractor subclass to be used.",
).tag(config=True)

apply_waveform_time_shift = BoolTelescopeParameter(
Expand Down Expand Up @@ -96,12 +104,17 @@ def __init__(
self._r1_empty_warn = False
self._dl0_empty_warn = False

self.image_extractors = {}

if image_extractor is None:
self.image_extractor = ImageExtractor.from_name(
self.image_extractor_type, subarray=self.subarray, parent=self
)
for (_, _, name) in self.image_extractor_type:
self.image_extractors[name] = ImageExtractor.from_name(
name, subarray=self.subarray, parent=self
)
else:
self.image_extractor = image_extractor
name = image_extractor.__class__.__name__
self.image_extractor_type = [("type", "*", name)]
self.image_extractors[name] = image_extractor

if data_volume_reducer is None:
self.data_volume_reducer = DataVolumeReducer.from_name(
Expand Down Expand Up @@ -191,7 +204,8 @@ def _calibrate_dl1(self, event, telid):
else:
remaining_shift = time_shift

charge, peak_time = self.image_extractor(
extractor = self.image_extractors[self.image_extractor_type.tel[telid]]
charge, peak_time = extractor(
waveforms, telid=telid, selected_gain_channel=selected_gain_channel
)

Expand Down
47 changes: 38 additions & 9 deletions ctapipe/calib/camera/tests/test_calibrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,25 +32,36 @@ def test_camera_calibrator(example_event, example_subarray):


def test_manual_extractor(example_subarray):
calibrator = CameraCalibrator(
subarray=example_subarray,
image_extractor=LocalPeakWindowSum(subarray=example_subarray),
)
assert isinstance(calibrator.image_extractor, LocalPeakWindowSum)
extractor = LocalPeakWindowSum(subarray=example_subarray)
calibrator = CameraCalibrator(subarray=example_subarray, image_extractor=extractor)
assert "LocalPeakWindowSum" in calibrator.image_extractors
assert calibrator.image_extractor_type.tel[1] == "LocalPeakWindowSum"
assert calibrator.image_extractors["LocalPeakWindowSum"] is extractor


def test_config(example_subarray):
calibrator = CameraCalibrator(subarray=example_subarray)

# test defaults
assert isinstance(calibrator.image_extractor, NeighborPeakWindowSum)
assert len(calibrator.image_extractors) == 1
assert isinstance(
calibrator.image_extractors["NeighborPeakWindowSum"], NeighborPeakWindowSum
)
assert isinstance(calibrator.data_volume_reducer, NullDataVolumeReducer)

# test we can configure different extractors with different options
# per telescope.
config = Config(
{
"CameraCalibrator": {
"image_extractor_type": "LocalPeakWindowSum",
"image_extractor_type": [
("type", "*", "GlobalPeakWindowSum"),
("id", 1, "LocalPeakWindowSum"),
],
"LocalPeakWindowSum": {"window_width": 15},
"GlobalPeakWindowSum": {
"window_width": [("type", "*", 10), ("id", 2, 8)]
},
"data_volume_reducer_type": "TailCutsDataVolumeReducer",
"TailCutsDataVolumeReducer": {
"TailcutsImageCleaner": {"picture_threshold_pe": 20.0}
Expand All @@ -60,8 +71,26 @@ def test_config(example_subarray):
)

calibrator = CameraCalibrator(example_subarray, config=config)
assert isinstance(calibrator.image_extractor, LocalPeakWindowSum)
assert calibrator.image_extractor.window_width.tel[None] == 15
assert "GlobalPeakWindowSum" in calibrator.image_extractors
assert "LocalPeakWindowSum" in calibrator.image_extractors
assert isinstance(
calibrator.image_extractors["LocalPeakWindowSum"], LocalPeakWindowSum
)
assert isinstance(
calibrator.image_extractors["GlobalPeakWindowSum"], GlobalPeakWindowSum
)

extractor_1 = calibrator.image_extractors[calibrator.image_extractor_type.tel[1]]
assert isinstance(extractor_1, LocalPeakWindowSum)
assert extractor_1.window_width.tel[1] == 15

extractor_2 = calibrator.image_extractors[calibrator.image_extractor_type.tel[2]]
assert isinstance(extractor_2, GlobalPeakWindowSum)
assert extractor_2.window_width.tel[2] == 8

extractor_3 = calibrator.image_extractors[calibrator.image_extractor_type.tel[3]]
assert isinstance(extractor_3, GlobalPeakWindowSum)
assert extractor_3.window_width.tel[3] == 10

assert isinstance(calibrator.data_volume_reducer, TailCutsDataVolumeReducer)
assert calibrator.data_volume_reducer.cleaner.picture_threshold_pe.tel[None] == 20
Expand Down
29 changes: 21 additions & 8 deletions ctapipe/image/reducer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,12 @@
import numpy as np
from ctapipe.image import TailcutsImageCleaner
from ctapipe.core import TelescopeComponent
from ctapipe.core.traits import IntTelescopeParameter, BoolTelescopeParameter, Unicode
from ctapipe.core.traits import (
IntTelescopeParameter,
BoolTelescopeParameter,
TelescopeParameter,
create_class_enum_trait,
)
from ctapipe.image.extractor import ImageExtractor
from ctapipe.image.cleaning import dilate

Expand Down Expand Up @@ -115,9 +120,12 @@ class TailCutsDataVolumeReducer(DataVolumeReducer):
normal TailcutCleaning is used.
"""

image_extractor_type = Unicode(
image_extractor_type = TelescopeParameter(
trait=create_class_enum_trait(
ImageExtractor, default_value="NeighborPeakWindowSum"
),
default_value="NeighborPeakWindowSum",
help="Name of the image_extractor" "to be used.",
help="Name of the ImageExtractor subclass to be used.",
).tag(config=True)

n_end_dilates = IntTelescopeParameter(
Expand Down Expand Up @@ -157,17 +165,22 @@ def __init__(
else:
self.cleaner = cleaner

self.image_extractors = {}
if image_extractor is None:
self.image_extractor = ImageExtractor.from_name(
self.image_extractor_type, subarray=self.subarray, parent=self
)
for (_, _, name) in self.image_extractor_type:
self.image_extractors[name] = ImageExtractor.from_name(
name, subarray=self.subarray, parent=self
)
else:
self.image_extractor = image_extractor
name = image_extractor.__class__.__name__
self.image_extractor_type = [("type", "*", name)]
self.image_extractors[name] = image_extractor

def select_pixels(self, waveforms, telid=None, selected_gain_channel=None):
camera_geom = self.subarray.tel[telid].camera.geometry
# Pulse-integrate waveforms
charge, _ = self.image_extractor(
extractor = self.image_extractors[self.image_extractor_type.tel[telid]]
charge, _ = extractor(
waveforms, telid=telid, selected_gain_channel=selected_gain_channel
)

Expand Down
2 changes: 1 addition & 1 deletion ctapipe/tools/display_integrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ def start(self):
)
exit()

extractor_name = self.calibrate.image_extractor.__class__.__name__
extractor_name = self.calibrate.image_extractor_type.tel[telid]

plot(self.subarray, event, telid, self.channel, extractor_name)

Expand Down