Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Initial work on supporting units in the scatter viewer #2509

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 40 additions & 1 deletion glue/viewers/scatter/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import numpy as np

from glue.core import BaseData, Subset
from glue.core import BaseData, Subset, Data

from glue.config import colormaps
from glue.viewers.matplotlib.state import (MatplotlibDataViewerState,
Expand All @@ -14,6 +14,7 @@
from glue.core.data_combo_helper import ComponentIDComboHelper, ComboHelper
from glue.core.exceptions import IncompatibleAttribute
from glue.viewers.common.stretch_state_mixin import StretchStateMixin
from glue.core.units import find_unit_choices

from matplotlib.projections import get_projection_names

Expand All @@ -34,6 +35,9 @@ class ScatterViewerState(MatplotlibDataViewerState):
x_limits_percentile = DDCProperty(100, docstring="Percentile to use when automatically determining x limits")
y_limits_percentile = DDCProperty(100, docstring="Percentile to use when automatically determining y limits")

x_display_unit = DDSCProperty(docstring='The units to use to display the x-axis.')
y_display_unit = DDSCProperty(docstring='The units to use to display the y-axis')

def __init__(self, **kwargs):

super(ScatterViewerState, self).__init__()
Expand All @@ -43,11 +47,13 @@ def __init__(self, **kwargs):
self.x_lim_helper = StateAttributeLimitsHelper(self, attribute='x_att',
lower='x_min', upper='x_max',
log='x_log', margin=0.04,
display_units='x_display_unit',
limits_cache=self.limits_cache)

self.y_lim_helper = StateAttributeLimitsHelper(self, attribute='y_att',
lower='y_min', upper='y_max',
log='y_log', margin=0.04,
display_units='y_display_unit',
limits_cache=self.limits_cache)

self.add_callback('layers', self._layers_changed)
Expand All @@ -68,6 +74,9 @@ def __init__(self, **kwargs):
self.add_callback('x_log', self._reset_x_limits)
self.add_callback('y_log', self._reset_y_limits)

self.add_callback('x_att', self._update_x_display_unit_choices)
self.add_callback('y_att', self._update_y_display_unit_choices)

if self.using_polar:
self.full_circle()

Expand Down Expand Up @@ -197,6 +206,36 @@ def _layers_changed(self, *args):

self._layers_data_cache = layers_data

def _update_x_display_unit_choices(self, *args):

# NOTE: only Data and its subclasses support specifying units
if self.x_att is None or not isinstance(self.x_att.parent, Data):
ScatterViewerState.x_display_unit.set_choices(self, [])
return

component = self.x_att.parent.get_component(self.x_att)
if component.units:
x_choices = find_unit_choices([(self.x_att.parent, self.x_att, component.units)])
else:
x_choices = ['']
ScatterViewerState.x_display_unit.set_choices(self, x_choices)
self.x_display_unit = component.units

def _update_y_display_unit_choices(self, *args):

# NOTE: only Data and its subclasses support specifying units
if self.y_att is None or not isinstance(self.y_att.parent, Data):
ScatterViewerState.y_display_unit.set_choices(self, [])
return

component = self.y_att.parent.get_component(self.y_att)
if component.units:
y_choices = find_unit_choices([(self.y_att.parent, self.y_att, component.units)])
else:
y_choices = ['']
ScatterViewerState.y_display_unit.set_choices(self, y_choices)
self.y_display_unit = component.units


def display_func_slow(x):
if x == 'Linear':
Expand Down
115 changes: 113 additions & 2 deletions glue/viewers/scatter/tests/test_viewer.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import numpy as np
from numpy.testing import assert_allclose
from numpy.testing import assert_allclose, assert_equal

import matplotlib.pyplot as plt

Expand All @@ -8,8 +8,9 @@
from glue.viewers.scatter.viewer import SimpleScatterViewer
from glue.core.application_base import Application
from glue.core.data import Data
from glue.core.link_helpers import LinkSame
from glue.core.link_helpers import LinkSame, LinkSameWithUnits
from glue.core.data_derived import IndexedData
from glue.core.roi import RectangularROI


@visual_test
Expand Down Expand Up @@ -131,3 +132,113 @@ def test_indexed_data():

assert viewer.state.x_att is data_2d.main_components[0]
assert viewer.state.y_att is data_2d.main_components[1]


def test_unit_conversion():

d1 = Data(a=[1, 2, 3], b=[2, 3, 4])
d1.get_component('a').units = 'm'
d1.get_component('b').units = 's'

d2 = Data(c=[2000, 1000, 3000], d=[0.001, 0.002, 0.004])
d2.get_component('c').units = 'mm'
d2.get_component('d').units = 'ks'

# d3 is the same as d2 but we will link it differently
d3 = Data(e=[2000, 1000, 3000], f=[0.001, 0.002, 0.004])
d3.get_component('e').units = 'mm'
d3.get_component('f').units = 'ks'

d4 = Data(g=[2, 2, 3], h=[1, 2, 1])
d4.get_component('g').units = 'kg'
d4.get_component('h').units = 'm/s'

app = Application()
session = app.session

data_collection = session.data_collection
data_collection.append(d1)
data_collection.append(d2)
data_collection.append(d3)
data_collection.append(d4)

data_collection.add_link(LinkSameWithUnits(d1.id['a'], d2.id['c']))
data_collection.add_link(LinkSameWithUnits(d1.id['b'], d2.id['d']))
data_collection.add_link(LinkSame(d1.id['a'], d3.id['e']))
data_collection.add_link(LinkSame(d1.id['b'], d3.id['f']))
data_collection.add_link(LinkSame(d1.id['a'], d4.id['g']))
data_collection.add_link(LinkSame(d1.id['b'], d4.id['h']))

viewer = app.new_data_viewer(SimpleScatterViewer)
viewer.add_data(d1)
viewer.add_data(d2)
viewer.add_data(d3)
viewer.add_data(d4)

assert viewer.layers[0].enabled
assert viewer.layers[1].enabled
assert viewer.layers[2].enabled
assert viewer.layers[3].enabled

assert viewer.state.x_min == 0.92
assert viewer.state.x_max == 3.08
assert viewer.state.y_min == 1.92
assert viewer.state.y_max == 4.08

roi = RectangularROI(0.5, 2.5, 1.5, 4.5)
viewer.apply_roi(roi)

assert len(d1.subsets) == 1
assert_equal(d1.subsets[0].to_mask(), [1, 1, 0])

# Because of the LinkSameWithUnits, the points actually appear in the right
# place even before we set the display units.
assert len(d2.subsets) == 1
assert_equal(d2.subsets[0].to_mask(), [0, 1, 0])

# d3 is only linked with LinkSame not LinkSameWithUnits so currently the
# points are outside the visible axes
assert len(d3.subsets) == 1
assert_equal(d3.subsets[0].to_mask(), [0, 0, 0])

# As we haven't set display units yet, the values for this dataset are shown
# on the same scale as for d1 as if the units had never been set.
assert len(d4.subsets) == 1
assert_equal(d4.subsets[0].to_mask(), [0, 1, 0])

# Now try setting the units explicitly

viewer.state.x_display_unit = 'km'
viewer.state.y_display_unit = 'ms'

assert_allclose(viewer.state.x_min, 0.92e-3)
assert_allclose(viewer.state.x_max, 3.08e-3)
assert_allclose(viewer.state.y_min, 1.92e3)
assert_allclose(viewer.state.y_max, 4.08e3)

roi = RectangularROI(0.5e-3, 2.5e-3, 1.5e3, 4.5e3)
viewer.apply_roi(roi)

# Results are as above - the display units do not result in any changes to
# the actual content of the axes and does not deal with automatic conversion
# of different units between different datasets - LinkSameWithUnits should
# deal with that already.

assert_equal(d1.subsets[0].to_mask(), [1, 1, 0])
assert_equal(d2.subsets[0].to_mask(), [0, 1, 0])
assert_equal(d3.subsets[0].to_mask(), [0, 0, 0])
assert_equal(d4.subsets[0].to_mask(), [0, 1, 0])

# Change the limits to make sure they are always converted
viewer.state.x_min = 0.0001
viewer.state.x_max = 0.005
viewer.state.y_min = 200
viewer.state.y_max = 7000

viewer.state.x_display_unit = 'm'
viewer.state.y_display_unit = 's'

assert viewer.state.x_min == 0.1
assert viewer.state.x_max == 5
assert viewer.state.y_min == 0.2
assert viewer.state.y_max == 7
27 changes: 24 additions & 3 deletions glue/viewers/scatter/viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from glue.viewers.matplotlib.viewer import SimpleMatplotlibViewer
from glue.viewers.scatter.state import ScatterViewerState
from glue.viewers.scatter.layer_artist import ScatterLayerArtist
from glue.core.units import UnitConverter


__all__ = ['MatplotlibScatterMixin', 'SimpleScatterViewer']

Expand Down Expand Up @@ -152,9 +154,28 @@
x_date = 'datetime' in self.state.x_kinds
y_date = 'datetime' in self.state.y_kinds

if x_date or y_date:
roi = roi.transformed(xfunc=mpl_to_datetime64 if x_date else None,
yfunc=mpl_to_datetime64 if y_date else None)
converter = UnitConverter()

xfunc = None
if x_date:
xfunc = mpl_to_datetime64

Check warning on line 161 in glue/viewers/scatter/viewer.py

View check run for this annotation

Codecov / codecov/patch

glue/viewers/scatter/viewer.py#L161

Added line #L161 was not covered by tests
else:
if self.state.x_display_unit:
xfunc = lambda x: converter.to_native(self.state.x_att.parent,
self.state.x_att, x,
self.state.x_display_unit)

yfunc = None
if y_date:
yfunc = mpl_to_datetime64

Check warning on line 170 in glue/viewers/scatter/viewer.py

View check run for this annotation

Codecov / codecov/patch

glue/viewers/scatter/viewer.py#L170

Added line #L170 was not covered by tests
else:
if self.state.y_display_unit:
yfunc = lambda y: converter.to_native(self.state.y_att.parent,
self.state.y_att, y,
self.state.y_display_unit)

if xfunc or yfunc:
roi = roi.transformed(xfunc=xfunc, yfunc=yfunc)

use_transform = not self.using_rectilinear()
subset_state = roi_to_subset_state(roi,
Expand Down