Skip to content

Commit

Permalink
Remove unnecessary dask computation in 'nearest' resampler (#377)
Browse files Browse the repository at this point in the history
* Remove unnecessary dask computation in 'nearest' resampler

* Simplify if statement in kd_tree.py

Co-authored-by: Martin Raspaud <[email protected]>

* Add 'assert_maximum_dask_computes' context manager for testing

Co-authored-by: Martin Raspaud <[email protected]>
  • Loading branch information
djhoese and mraspaud authored Sep 17, 2021
1 parent 6a37a16 commit 034a336
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 51 deletions.
64 changes: 24 additions & 40 deletions pyresample/kd_tree.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# pyresample, Resampling of remote sensing image data in python
#
# Copyright (C) 2010, 2014, 2015 Esben S. Nielsen
# Adam.Dybbroe
# Copyright (C) 2010-2021 Pyresample developers
#
# This program is free software: you can redistribute it and/or modify it under
# the terms of the GNU Lesser General Public License as published by the Free
Expand Down Expand Up @@ -53,7 +52,7 @@


class EmptyResult(ValueError):
    """No valid data is produced.

    Subclasses ``ValueError`` so handlers catching ``ValueError`` also
    catch this exception.
    """


def resample_nearest(source_geo_def,
Expand Down Expand Up @@ -99,7 +98,6 @@ def resample_nearest(source_geo_def,
data : numpy array
Source data resampled to target geometry
"""

return _resample(source_geo_def, data, target_geo_def, 'nn',
radius_of_influence, neighbours=1,
epsilon=epsilon, fill_value=fill_value,
Expand Down Expand Up @@ -234,7 +232,6 @@ def resample_custom(source_geo_def, data, target_geo_def,
Weighted standard deviation for all pixels having more than one source value
Counts of number of source values used in weighting per pixel
"""

if not isinstance(weight_funcs, (list, tuple)):
if not isinstance(weight_funcs, types.FunctionType):
raise TypeError('weight_func must be function object')
Expand All @@ -254,7 +251,6 @@ def _resample(source_geo_def, data, target_geo_def, resample_type,
radius_of_influence, neighbours=8, epsilon=0, weight_funcs=None,
fill_value=0, reduce_data=True, nprocs=1, segments=None, with_uncert=False):
"""Resamples swath using kd-tree approach."""

valid_input_index, valid_output_index, index_array, distance_array = \
get_neighbour_info(source_geo_def,
target_geo_def,
Expand All @@ -279,7 +275,7 @@ def _resample(source_geo_def, data, target_geo_def, resample_type,
def get_neighbour_info(source_geo_def, target_geo_def, radius_of_influence,
neighbours=8, epsilon=0, reduce_data=True,
nprocs=1, segments=None):
"""Returns neighbour info.
"""Return neighbour info.
Parameters
----------
Expand Down Expand Up @@ -309,7 +305,6 @@ def get_neighbour_info(source_geo_def, target_geo_def, radius_of_influence,
index_array, distance_array) : tuple of numpy arrays
Neighbour resampling info
"""

if source_geo_def.size < neighbours:
warnings.warn('Searching for %s neighbours in %s data points' %
(neighbours, source_geo_def.size))
Expand Down Expand Up @@ -396,7 +391,6 @@ def _get_valid_input_index(source_geo_def,
radius_of_influence,
nprocs=1):
"""Find indices of reduced input data."""

source_lons, source_lats = source_geo_def.get_lonlats(nprocs=nprocs)
source_lons = np.asanyarray(source_lons).ravel()
source_lats = np.asanyarray(source_lats).ravel()
Expand All @@ -408,18 +402,15 @@ def _get_valid_input_index(source_geo_def,
raise ValueError('Mismatch between lons and lats')

# Remove illegal values
valid_input_index = ((source_lons >= -180) & (source_lons <= 180) &
(source_lats <= 90) & (source_lats >= -90))
valid_input_index = ((source_lons >= -180) & (source_lons <= 180) & (source_lats <= 90) & (source_lats >= -90))

if reduce_data:
# Reduce dataset
if (isinstance(source_geo_def, geometry.CoordinateDefinition) and
isinstance(target_geo_def, (geometry.GridDefinition,
geometry.AreaDefinition))) or \
(isinstance(source_geo_def, (geometry.GridDefinition,
geometry.AreaDefinition)) and
isinstance(target_geo_def, (geometry.GridDefinition,
geometry.AreaDefinition))):
griddish_types = (geometry.GridDefinition, geometry.AreaDefinition)
source_is_griddish = isinstance(source_geo_def, griddish_types)
target_is_griddish = isinstance(target_geo_def, griddish_types)
source_is_coord = isinstance(source_geo_def, geometry.CoordinateDefinition)
if (source_is_coord or source_is_griddish) and target_is_griddish:
# Resampling from swath to grid or from grid to grid
lonlat_boundary = target_geo_def.get_boundary_lonlats()

Expand All @@ -441,7 +432,6 @@ def _get_valid_input_index(source_geo_def,
def _get_valid_output_index(source_geo_def, target_geo_def, target_lons,
target_lats, reduce_data, radius_of_influence):
"""Find indices of reduced output data."""

valid_output_index = np.ones(target_lons.size, dtype=bool)

if reduce_data:
Expand All @@ -460,8 +450,7 @@ def _get_valid_output_index(source_geo_def, target_geo_def, target_lons,
valid_output_index = valid_output_index.astype(bool)

# Remove illegal values
valid_out = ((target_lons >= -180) & (target_lons <= 180) &
(target_lats <= 90) & (target_lats >= -90))
valid_out = ((target_lons >= -180) & (target_lons <= 180) & (target_lats <= 90) & (target_lats >= -90))

# Combine reduced and legal values
valid_output_index = (valid_output_index & valid_out)
Expand Down Expand Up @@ -561,16 +550,13 @@ def _query_resample_kdtree(resample_kdtree,


def _create_empty_info(source_geo_def, target_geo_def, neighbours):
"""Creates dummy info for empty result set."""

"""Create dummy info for empty result set."""
valid_output_index = np.ones(target_geo_def.size, dtype=bool)
if neighbours > 1:
index_array = (np.ones((target_geo_def.size, neighbours),
dtype=np.int32) * source_geo_def.size)
index_array = (np.ones((target_geo_def.size, neighbours), dtype=np.int32) * source_geo_def.size)
distance_array = np.ones((target_geo_def.size, neighbours))
else:
index_array = (np.ones(target_geo_def.size, dtype=np.int32) *
source_geo_def.size)
index_array = (np.ones(target_geo_def.size, dtype=np.int32) * source_geo_def.size)
distance_array = np.ones(target_geo_def.size)

return valid_output_index, index_array, distance_array
Expand Down Expand Up @@ -617,7 +603,6 @@ def get_sample_from_neighbour_info(resample_type, output_shape, data,
result : numpy array
Source data resampled to target geometry
"""

if data.ndim > 2 and data.shape[0] * data.shape[1] == valid_input_index.size:
data = data.reshape(data.shape[0] * data.shape[1], data.shape[2])
elif data.shape[0] != valid_input_index.size:
Expand Down Expand Up @@ -879,6 +864,7 @@ def get_sample_from_neighbour_info(resample_type, output_shape, data,


def lonlat2xyz(lons, lats):
"""Convert lon/lat degrees to geocentric x/y/z coordinates."""
R = 6370997.0
x_coords = R * np.cos(np.deg2rad(lats)) * np.cos(np.deg2rad(lons))
y_coords = R * np.cos(np.deg2rad(lats)) * np.sin(np.deg2rad(lons))
Expand Down Expand Up @@ -934,7 +920,7 @@ def query_no_distance(target_lons, target_lats, valid_output_index,

def _my_index(index_arr, vii, data_arr, vii_slices=None, ia_slices=None,
fill_value=np.nan):
"""Helper function for 'get_sample_from_neighbour_info'."""
"""Wrap index logic for 'get_sample_from_neighbour_info' to be used inside dask map_blocks."""
vii_slices = tuple(
x if x is not None else vii.ravel() for x in vii_slices)
mask_slices = tuple(
Expand All @@ -947,13 +933,15 @@ def _my_index(index_arr, vii, data_arr, vii_slices=None, ia_slices=None,


class XArrayResamplerNN(object):
"""Resampler for Xarray DataArray objects with the nearest neighbor algorithm."""

def __init__(self,
source_geo_def,
target_geo_def,
radius_of_influence=None,
neighbours=1,
epsilon=0):
"""
"""Resampler for xarray DataArrays using a nearest neighbor algorithm.
Parameters
----------
Expand Down Expand Up @@ -1016,8 +1004,7 @@ def _create_resample_kdtree(self, chunks=CHUNK_SIZE):
"""Set up kd tree on input."""
source_lons, source_lats = self.source_geo_def.get_lonlats(
chunks=chunks)
valid_input_idx = ((source_lons >= -180) & (source_lons <= 180) &
(source_lats <= 90) & (source_lats >= -90))
valid_input_idx = ((source_lons >= -180) & (source_lons <= 180) & (source_lats <= 90) & (source_lats >= -90))
input_coords = lonlat2xyz(source_lons, source_lats)
input_coords = input_coords[valid_input_idx.ravel(), :]

Expand Down Expand Up @@ -1070,8 +1057,7 @@ def get_neighbour_info(self, mask=None):

# TODO: Add 'chunks' keyword argument to this method and use it
target_lons, target_lats = self.target_geo_def.get_lonlats(chunks=CHUNK_SIZE)
valid_output_idx = ((target_lons >= -180) & (target_lons <= 180) &
(target_lats <= 90) & (target_lats >= -90))
valid_output_idx = ((target_lons >= -180) & (target_lons <= 180) & (target_lats <= 90) & (target_lats >= -90))

if mask is not None:
assert (mask.shape == self.source_geo_def.shape), \
Expand Down Expand Up @@ -1146,9 +1132,8 @@ def get_sample_from_neighbour_info(self, data, fill_value=np.nan):
# verify that the dims are next to each other
first_dim_idx = data.dims.index(src_geo_dims[0])
num_dims = len(src_geo_dims)
assert (data.dims[first_dim_idx:first_dim_idx + num_dims] ==
data_geo_dims), "Data's geolocation dimensions are not " \
"consecutive."
assert (data.dims[first_dim_idx:first_dim_idx + num_dims] == data_geo_dims),\
"Data's geolocation dimensions are not consecutive."

# FIXME: Can't include coordinates whose dimensions depend on the geo
# dims either
Expand All @@ -1158,8 +1143,8 @@ def contain_coords(var, coord_list):
coords = {c: c_var for c, c_var in data.coords.items()
if not contain_coords(c_var, src_geo_dims + dst_geo_dims)}
try:
# TODO: Add 'chunks' kwarg
coord_x, coord_y = self.target_geo_def.get_proj_vectors(chunks=CHUNK_SIZE)
# get these as numpy arrays because xarray is going to compute them anyway
coord_x, coord_y = self.target_geo_def.get_proj_vectors()
coords['y'] = coord_y
coords['x'] = coord_x
except AttributeError:
Expand Down Expand Up @@ -1241,7 +1226,6 @@ def _get_fill_mask_value(data_dtype):

def _remask_data(data, is_to_be_masked=True):
"""Interprets half the array as mask for the other half."""

channels = data.shape[-1]
if is_to_be_masked:
mask = data[..., (channels // 2):]
Expand Down
32 changes: 26 additions & 6 deletions pyresample/test/test_kd_tree.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,21 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (c) 2021 Pyresample developers
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
"""Test kd_tree operations."""
import os
import numpy as np

Expand All @@ -9,6 +27,7 @@


class Test(unittest.TestCase):
"""Test nearest neighbor resampling on numpy arrays."""

@classmethod
def setUpClass(cls):
Expand Down Expand Up @@ -329,8 +348,7 @@ def test_gauss_multi_uncert(self):
expected_stddev = [0.44621800779801657, 0.44363137712896705,
0.43861019464274459]
expected_counts = 4934802.0
self.assertTrue(res.shape == stddev.shape and stddev.shape ==
counts.shape and counts.shape == (800, 800, 3))
self.assertTrue(res.shape == stddev.shape and stddev.shape == counts.shape and counts.shape == (800, 800, 3))
self.assertAlmostEqual(cross_sum, expected)

for i, e_stddev in enumerate(expected_stddev):
Expand Down Expand Up @@ -881,13 +899,15 @@ def test_nearest_swath_2d_mask_to_area_1n(self):
def test_nearest_area_2d_to_area_1n(self):
"""Test 2D area definition to 2D area definition; 1 neighbor."""
from pyresample.kd_tree import XArrayResamplerNN
from pyresample.test.utils import assert_maximum_dask_computes
import xarray as xr
import dask.array as da
data = self.data_2d
resampler = XArrayResamplerNN(self.src_area_2d, self.area_def,
radius_of_influence=50000,
neighbours=1)
ninfo = resampler.get_neighbour_info()
with assert_maximum_dask_computes(0):
ninfo = resampler.get_neighbour_info()
for val in ninfo[:3]:
# vii, ia, voi
self.assertIsInstance(val, da.Array)
Expand All @@ -896,7 +916,8 @@ def test_nearest_area_2d_to_area_1n(self):

# rename data dimensions to match the expected area dimensions
data = data.rename({'my_dim_y': 'y', 'my_dim_x': 'x'})
res = resampler.get_sample_from_neighbour_info(data)
with assert_maximum_dask_computes(0):
res = resampler.get_sample_from_neighbour_info(data)
self.assertIsInstance(res, xr.DataArray)
self.assertIsInstance(res.data, da.Array)
res = res.values
Expand Down Expand Up @@ -946,8 +967,7 @@ def test_nearest_area_2d_to_area_1n_no_roi(self):
self.assertEqual(cross_sum, expected)

def test_nearest_area_2d_to_area_1n_3d_data(self):
"""Test 2D area definition to 2D area definition; 1 neighbor, 3d
data."""
"""Test 2D area definition to 2D area definition; 1 neighbor, 3d data."""
from pyresample.kd_tree import XArrayResamplerNN
import xarray as xr
import dask.array as da
Expand Down
25 changes: 20 additions & 5 deletions pyresample/test/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import sys
import types
import warnings
from contextlib import contextmanager

import numpy as np

Expand All @@ -39,10 +40,9 @@


def treat_deprecations_as_exceptions():
"""Turn all DeprecationWarnings (which indicate deprecated uses of Python
itself or Numpy, but not within Astropy, where we use our own deprecation
warning class) into exceptions so that we find out about them early.
"""Turn all DeprecationWarnings into exceptions.
Deprecation warnings indicate deprecated uses of Python itself or Numpy.
This completely resets the warning filters and any "already seen"
warning state.
"""
Expand Down Expand Up @@ -124,8 +124,9 @@ def treat_deprecations_as_exceptions():


class catch_warnings(warnings.catch_warnings):
"""A high-powered version of warnings.catch_warnings to use for testing and
to make sure that there is no dependence on the order in which the tests
"""A high-powered version of warnings.catch_warnings to use for testing.
Makes sure that there is no dependence on the order in which the tests
are run.
This completely blitzes any memory of any warnings that have
Expand All @@ -140,11 +141,14 @@ class catch_warnings(warnings.catch_warnings):
do.something.bad()
assert len(w) > 0
"""

def __init__(self, *classes):
"""Initialize the classes of warnings to catch."""
super(catch_warnings, self).__init__(record=True)
self.classes = classes

def __enter__(self):
"""Catch any warnings during this context."""
warning_list = super(catch_warnings, self).__enter__()
treat_deprecations_as_exceptions()
if len(self.classes) == 0:
Expand All @@ -156,10 +160,12 @@ def __enter__(self):
return warning_list

def __exit__(self, type, value, traceback):
"""Raise any warnings as errors."""
treat_deprecations_as_exceptions()


def create_test_longitude(start, stop, shape, twist_factor=0.0, dtype=np.float32):
"""Get basic sample of longitude data."""
if start > stop:
stop += 360.0

Expand All @@ -175,6 +181,7 @@ def create_test_longitude(start, stop, shape, twist_factor=0.0, dtype=np.float32


def create_test_latitude(start, stop, shape, twist_factor=0.0, dtype=np.float32):
"""Get basic sample of latitude data."""
num_cols = 1 if len(shape) < 2 else shape[1]
lat_col = np.linspace(start, stop, num=shape[0]).astype(dtype).reshape((shape[0], 1))
twist_array = np.arange(num_cols) * twist_factor
Expand All @@ -201,6 +208,14 @@ def __call__(self, dsk, keys, **kwargs):
return dask.get(dsk, keys, **kwargs)


@contextmanager
def assert_maximum_dask_computes(max_computes=1):
    """Context manager to make sure dask computations are not executed more than ``max_computes`` times.

    A ``CustomScheduler`` created with ``max_computes`` is installed as the
    dask scheduler through ``dask.config.set`` for the duration of the
    ``with`` block. The value yielded is whatever entering
    ``dask.config.set(...)`` returns.
    """
    import dask

    limited_scheduler = CustomScheduler(max_computes=max_computes)
    with dask.config.set(scheduler=limited_scheduler) as active_config:
        yield active_config


def friendly_crs_equal(expected, actual, keys=None, use_obj=True, use_wkt=True):
"""Test if two projection definitions are equal.
Expand Down
2 changes: 2 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ doc_files = docs/Makefile docs/source/*.rst

[flake8]
max-line-length = 120
per-file-ignores:
pyresample/test/*.py:D102

[versioneer]
VCS = git
Expand Down

0 comments on commit 034a336

Please sign in to comment.