Skip to content

Commit

Permalink
Ensure manual geometry updates (such as invert_rotation_axis=True) …
Browse files Browse the repository at this point in the history
…are only applied once to each model (dials#2469)

Use a set to keep track of which models have already been manually updated, to avoid doing this more than once
  • Loading branch information
dagewa authored Aug 4, 2023
1 parent 4fb2a86 commit 4b86bb6
Show file tree
Hide file tree
Showing 3 changed files with 141 additions and 18 deletions.
1 change: 1 addition & 0 deletions newsfragments/2469.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
``dials.import``: ensure manual geometry updates are only applied once to each model. This ensure ``invert_rotation_axis=True`` will only invert the rotation axis once.
68 changes: 50 additions & 18 deletions src/dials/command_line/dials_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,7 @@ def __init__(self, params):
Save the params
"""
self.params = params
self.touched = set()

def __call__(self, imageset):
"""
Expand All @@ -361,17 +362,29 @@ def __call__(self, imageset):
for j in range(len(imageset)):
imageset.set_scan(None, j)
imageset.set_goniometer(None, j)

beam = imageset.get_beam()
detector = imageset.get_detector()
goniometer = imageset.get_goniometer()
scan = imageset.get_scan()

# Create a new model with updated geometry for each model that is not
# already in the touched set
if isinstance(imageset, ImageSequence):
beam = BeamFactory.from_phil(self.params.geometry, imageset.get_beam())
detector = DetectorFactory.from_phil(
self.params.geometry, imageset.get_detector(), beam
)
goniometer = GoniometerFactory.from_phil(
self.params.geometry, imageset.get_goniometer()
)
scan = ScanFactory.from_phil(
self.params.geometry, deepcopy(imageset.get_scan())
)
if beam and beam not in self.touched:
beam = BeamFactory.from_phil(self.params.geometry, imageset.get_beam())
if detector and detector not in self.touched:
detector = DetectorFactory.from_phil(
self.params.geometry, imageset.get_detector(), beam
)
if goniometer and goniometer not in self.touched:
goniometer = GoniometerFactory.from_phil(
self.params.geometry, imageset.get_goniometer()
)
if scan and scan not in self.touched:
scan = ScanFactory.from_phil(
self.params.geometry, deepcopy(imageset.get_scan())
)
i0, i1 = scan.get_array_range()
j0, j1 = imageset.get_scan().get_array_range()
if i0 < j0 or i1 > j1:
Expand All @@ -389,18 +402,37 @@ def __call__(self, imageset):
imageset.set_scan(scan)
else:
for i in range(len(imageset)):
beam = BeamFactory.from_phil(self.params.geometry, imageset.get_beam(i))
detector = DetectorFactory.from_phil(
self.params.geometry, imageset.get_detector(i), beam
)
goniometer = GoniometerFactory.from_phil(
self.params.geometry, imageset.get_goniometer(i)
)
scan = ScanFactory.from_phil(self.params.geometry, imageset.get_scan(i))
if beam and beam not in self.touched:
beam = BeamFactory.from_phil(
self.params.geometry, imageset.get_beam(i)
)
if detector and detector not in self.touched:
detector = DetectorFactory.from_phil(
self.params.geometry, imageset.get_detector(i), beam
)
if goniometer and goniometer not in self.touched:
goniometer = GoniometerFactory.from_phil(
self.params.geometry, imageset.get_goniometer(i)
)
if scan and scan not in self.touched:
scan = ScanFactory.from_phil(
self.params.geometry, imageset.get_scan(i)
)
imageset.set_beam(beam, i)
imageset.set_detector(detector, i)
imageset.set_goniometer(goniometer, i)
imageset.set_scan(scan, i)

# Add the models from this imageset to the touched set, so they will not
# have their geometry updated again
if beam:
self.touched.add(beam)
if detector:
self.touched.add(detector)
if goniometer:
self.touched.add(goniometer)
if scan:
self.touched.add(scan)
return imageset

def extrapolate_imageset(
Expand Down
90 changes: 90 additions & 0 deletions tests/command_line/test_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,12 @@
import pytest

from dxtbx.imageset import ImageSequence
from dxtbx.model.experiment_list import ExperimentListFactory
from dxtbx.serialize import load

from dials.command_line.dials_import import ManualGeometryUpdater
from dials.util.options import geometry_phil_scope


@pytest.mark.parametrize("use_beam", ["True", "False"])
@pytest.mark.parametrize("use_gonio", ["True", "False"])
Expand Down Expand Up @@ -134,6 +138,92 @@ def test_can_import_multiple_sequences(dials_data, tmp_path):
assert experiment.identifier != ""


def test_invert_axis_with_two_sequences_sharing_a_goniometer(dials_data, tmp_path):
# Test for regression of https://github.com/dials/dials/issues/2467
image_files = sorted(
dials_data("centroid_test_data", pathlib=True).glob("centroid*.cbf")
)
del image_files[4] # Delete filename to force two sequences

result = subprocess.run(
[
shutil.which("dials.import"),
"output.experiments=experiments_multiple_sequences.expt",
"invert_rotation_axis=True",
]
+ image_files,
cwd=tmp_path,
capture_output=True,
)
assert not result.returncode and not result.stderr
assert (tmp_path / "experiments_multiple_sequences.expt").is_file()

experiments = load.experiment_list(tmp_path / "experiments_multiple_sequences.expt")
assert len(experiments.goniometers()) == 1
assert experiments.goniometers()[0].get_rotation_axis() == (-1.0, 0.0, 0.0)


def test_ManualGeometryUpdater_inverts_axis(dials_data):
# Test behaviour of inverting axes with multiple imagesets as suggested in
# https://github.com/dials/dials/pull/2469#discussion_r1278264665

# Create four imagesets, first two share a goniometer model, second two
# have independent inverted goniometer models
filenames = sorted(
dials_data("centroid_test_data", pathlib=True).glob("centroid*.cbf")
)
experiments = ExperimentListFactory.from_filenames(filenames[0:3])
experiments.extend(ExperimentListFactory.from_filenames(filenames[2:5]))
experiments.extend(ExperimentListFactory.from_filenames(filenames[4:7]))
experiments.extend(ExperimentListFactory.from_filenames(filenames[7:9]))
imagesets = experiments.imagesets()
imagesets[1].set_goniometer(imagesets[0].get_goniometer())
imagesets[2].get_goniometer().set_rotation_axis((-1.0, 0, 0))
imagesets[3].get_goniometer().set_rotation_axis((-1.0, 0, 0))

assert imagesets[0].get_goniometer().get_rotation_axis() == (1.0, 0.0, 0.0)
assert imagesets[1].get_goniometer().get_rotation_axis() == (1.0, 0.0, 0.0)
assert imagesets[2].get_goniometer().get_rotation_axis() == (-1.0, 0.0, 0.0)
assert imagesets[3].get_goniometer().get_rotation_axis() == (-1.0, 0.0, 0.0)

# Set the manual geometry parameters. The hierarchy.group should be unset
# here, to model how the scope_extract appears to the ManualGeometryUpdater
# during a dials.import run.
params = geometry_phil_scope.extract()
params.geometry.goniometer.invert_rotation_axis = True
params.geometry.detector.hierarchy.group = []

mgu = ManualGeometryUpdater(params=params)

# Update the first imageset (affects first two, which share a gonio)
mgu(imagesets[0])
assert imagesets[0].get_goniometer().get_rotation_axis() == (-1.0, 0.0, 0.0)
assert imagesets[1].get_goniometer().get_rotation_axis() == (-1.0, 0.0, 0.0)
assert imagesets[2].get_goniometer().get_rotation_axis() == (-1.0, 0.0, 0.0)
assert imagesets[3].get_goniometer().get_rotation_axis() == (-1.0, 0.0, 0.0)

# Run the updater on the second (should not invert again)
mgu(imagesets[1])
assert imagesets[0].get_goniometer().get_rotation_axis() == (-1, 0, 0)
assert imagesets[1].get_goniometer().get_rotation_axis() == (-1, 0, 0)
assert imagesets[2].get_goniometer().get_rotation_axis() == (-1.0, 0.0, 0.0)
assert imagesets[3].get_goniometer().get_rotation_axis() == (-1.0, 0.0, 0.0)

# Run on the third (should invert only that one)
mgu(imagesets[2])
assert imagesets[0].get_goniometer().get_rotation_axis() == (-1, 0, 0)
assert imagesets[1].get_goniometer().get_rotation_axis() == (-1, 0, 0)
assert imagesets[2].get_goniometer().get_rotation_axis() == (1.0, 0.0, 0.0)
assert imagesets[3].get_goniometer().get_rotation_axis() == (-1.0, 0.0, 0.0)

# Run on the fourth (should invert only that one)
mgu(imagesets[3])
assert imagesets[0].get_goniometer().get_rotation_axis() == (-1, 0, 0)
assert imagesets[1].get_goniometer().get_rotation_axis() == (-1, 0, 0)
assert imagesets[2].get_goniometer().get_rotation_axis() == (1.0, 0.0, 0.0)
assert imagesets[3].get_goniometer().get_rotation_axis() == (1.0, 0.0, 0.0)


def test_with_mask(dials_data, tmp_path):
image_files = sorted(
dials_data("centroid_test_data", pathlib=True).glob("centroid*.cbf")
Expand Down

0 comments on commit 4b86bb6

Please sign in to comment.