From 034a336b92ef0a66b3529a46def3cd369280fd07 Mon Sep 17 00:00:00 2001
From: David Hoese
Date: Fri, 17 Sep 2021 09:48:30 -0500
Subject: [PATCH] Remove unnecessary dask computation in 'nearest' resampler
 (#377)

* Remove unnecessary dask computation in 'nearest' resampler

* Simplify if statement in kd_tree.py

Co-authored-by: Martin Raspaud

* Add 'assert_maximum_dask_computes' context manager for testing

Co-authored-by: Martin Raspaud
---
 pyresample/kd_tree.py           | 64 +++++++++++++--------------------
 pyresample/test/test_kd_tree.py | 32 +++++++++++++----
 pyresample/test/utils.py        | 25 ++++++++++---
 setup.cfg                       |  2 ++
 4 files changed, 72 insertions(+), 51 deletions(-)

diff --git a/pyresample/kd_tree.py b/pyresample/kd_tree.py
index 70cd26bcd..33a5d3399 100644
--- a/pyresample/kd_tree.py
+++ b/pyresample/kd_tree.py
@@ -1,7 +1,6 @@
 # pyresample, Resampling of remote sensing image data in python
 #
-# Copyright (C) 2010, 2014, 2015 Esben S. Nielsen
-#                                Adam.Dybbroe
+# Copyright (C) 2010-2021 Pyresample developers
 #
 # This program is free software: you can redistribute it and/or modify it under
 # the terms of the GNU Lesser General Public License as published by the Free
@@ -53,7 +52,7 @@
 
 
 class EmptyResult(ValueError):
-    pass
+    """No valid data is produced."""
 
 
 def resample_nearest(source_geo_def,
@@ -99,7 +98,6 @@ def resample_nearest(source_geo_def,
     data : numpy array
         Source data resampled to target geometry
     """
-
     return _resample(source_geo_def, data, target_geo_def, 'nn',
                      radius_of_influence, neighbours=1,
                      epsilon=epsilon, fill_value=fill_value,
@@ -234,7 +232,6 @@ def resample_custom(source_geo_def, data, target_geo_def,
         Weighted standard devaition for all pixels having more than one source value
         Counts of number of source values used in weighting per pixel
     """
-
     if not isinstance(weight_funcs, (list, tuple)):
         if not isinstance(weight_funcs, types.FunctionType):
             raise TypeError('weight_func must be function object')
@@ -254,7 +251,6 @@ def _resample(source_geo_def, data, target_geo_def, resample_type,
               radius_of_influence, neighbours=8, epsilon=0, weight_funcs=None,
               fill_value=0, reduce_data=True, nprocs=1, segments=None,
               with_uncert=False):
     """Resamples swath using kd-tree approach."""
-
     valid_input_index, valid_output_index, index_array, distance_array = \
         get_neighbour_info(source_geo_def,
                            target_geo_def,
@@ -279,7 +275,7 @@
 def get_neighbour_info(source_geo_def, target_geo_def, radius_of_influence,
                        neighbours=8, epsilon=0, reduce_data=True,
                        nprocs=1, segments=None):
-    """Returns neighbour info.
+    """Return neighbour info.
 
     Parameters
     ----------
@@ -309,7 +305,6 @@ def get_neighbour_info(source_geo_def, target_geo_def, radius_of_influence,
         index_array, distance_array) : tuple of numpy arrays
         Neighbour resampling info
     """
-
    if source_geo_def.size < neighbours:
        warnings.warn('Searching for %s neighbours in %s data points' %
                      (neighbours, source_geo_def.size))
@@ -396,7 +391,6 @@ def _get_valid_input_index(source_geo_def,
                            radius_of_influence,
                            nprocs=1):
     """Find indices of reduced inputput data."""
-
     source_lons, source_lats = source_geo_def.get_lonlats(nprocs=nprocs)
     source_lons = np.asanyarray(source_lons).ravel()
     source_lats = np.asanyarray(source_lats).ravel()
@@ -408,18 +402,15 @@ def _get_valid_input_index(source_geo_def,
         raise ValueError('Mismatch between lons and lats')
 
     # Remove illegal values
-    valid_input_index = ((source_lons >= -180) & (source_lons <= 180) &
-                         (source_lats <= 90) & (source_lats >= -90))
+    valid_input_index = ((source_lons >= -180) & (source_lons <= 180) & (source_lats <= 90) & (source_lats >= -90))
 
     if reduce_data:
         # Reduce dataset
-        if (isinstance(source_geo_def, geometry.CoordinateDefinition) and
-            isinstance(target_geo_def, (geometry.GridDefinition,
-                                        geometry.AreaDefinition))) or \
-                (isinstance(source_geo_def, (geometry.GridDefinition,
-                                             geometry.AreaDefinition)) and
-                 isinstance(target_geo_def, (geometry.GridDefinition,
-                                             geometry.AreaDefinition))):
+        griddish_types = (geometry.GridDefinition, geometry.AreaDefinition)
+        source_is_griddish = isinstance(source_geo_def, griddish_types)
+        target_is_griddish = isinstance(target_geo_def, griddish_types)
+        source_is_coord = isinstance(source_geo_def, geometry.CoordinateDefinition)
+        if (source_is_coord or source_is_griddish) and target_is_griddish:
             # Resampling from swath to grid or from grid to grid
             lonlat_boundary = target_geo_def.get_boundary_lonlats()
 
@@ -441,7 +432,6 @@
 def _get_valid_output_index(source_geo_def, target_geo_def, target_lons,
                             target_lats, reduce_data, radius_of_influence):
     """Find indices of reduced output data."""
-
     valid_output_index = np.ones(target_lons.size, dtype=bool)
 
     if reduce_data:
@@ -460,8 +450,7 @@ def _get_valid_output_index(source_geo_def, target_geo_def, target_lons,
         valid_output_index = valid_output_index.astype(bool)
 
     # Remove illegal values
-    valid_out = ((target_lons >= -180) & (target_lons <= 180) &
-                 (target_lats <= 90) & (target_lats >= -90))
+    valid_out = ((target_lons >= -180) & (target_lons <= 180) & (target_lats <= 90) & (target_lats >= -90))
 
     # Combine reduced and legal values
     valid_output_index = (valid_output_index & valid_out)
@@ -561,16 +550,13 @@ def _query_resample_kdtree(resample_kdtree,
 
 
 def _create_empty_info(source_geo_def, target_geo_def, neighbours):
-    """Creates dummy info for empty result set."""
-
+    """Create dummy info for empty result set."""
     valid_output_index = np.ones(target_geo_def.size, dtype=bool)
     if neighbours > 1:
-        index_array = (np.ones((target_geo_def.size, neighbours),
-                               dtype=np.int32) * source_geo_def.size)
+        index_array = (np.ones((target_geo_def.size, neighbours), dtype=np.int32) * source_geo_def.size)
         distance_array = np.ones((target_geo_def.size, neighbours))
     else:
-        index_array = (np.ones(target_geo_def.size, dtype=np.int32) *
-                       source_geo_def.size)
+        index_array = (np.ones(target_geo_def.size, dtype=np.int32) * source_geo_def.size)
         distance_array = np.ones(target_geo_def.size)
 
     return valid_output_index, index_array, distance_array
@@ -617,7 +603,6 @@ def get_sample_from_neighbour_info(resample_type, output_shape,
     data, result : numpy array
         Source data resampled to target geometry
     """
-
     if data.ndim > 2 and data.shape[0] * data.shape[1] == valid_input_index.size:
         data = data.reshape(data.shape[0] * data.shape[1], data.shape[2])
     elif data.shape[0] != valid_input_index.size:
@@ -879,6 +864,7 @@ def get_sample_from_neighbour_info(resample_type, output_shape, data,
 
 
 def lonlat2xyz(lons, lats):
+    """Convert lon/lat degrees to geocentric x/y/z coordinates."""
     R = 6370997.0
     x_coords = R * np.cos(np.deg2rad(lats)) * np.cos(np.deg2rad(lons))
     y_coords = R * np.cos(np.deg2rad(lats)) * np.sin(np.deg2rad(lons))
@@ -934,7 +920,7 @@ def query_no_distance(target_lons, target_lats, valid_output_index,
 
 def _my_index(index_arr, vii, data_arr, vii_slices=None, ia_slices=None,
               fill_value=np.nan):
-    """Helper function for 'get_sample_from_neighbour_info'."""
+    """Wrap index logic for 'get_sample_from_neighbour_info' to be used inside dask map_blocks."""
     vii_slices = tuple(
         x if x is not None else vii.ravel() for x in vii_slices)
     mask_slices = tuple(
@@ -947,13 +933,15 @@ def _my_index(index_arr, vii, data_arr, vii_slices=None, ia_slices=None,
 
 
 class XArrayResamplerNN(object):
+    """Resampler for Xarray DataArray objects with the nearest neighbor algorithm."""
+
     def __init__(self,
                  source_geo_def,
                  target_geo_def,
                  radius_of_influence=None,
                  neighbours=1,
                  epsilon=0):
-        """
+        """Resampler for xarray DataArrays using a nearest neighbor algorithm.
 
         Parameters
         ----------
@@ -1016,8 +1004,7 @@ def _create_resample_kdtree(self, chunks=CHUNK_SIZE):
         """Set up kd tree on input."""
         source_lons, source_lats = self.source_geo_def.get_lonlats(
             chunks=chunks)
-        valid_input_idx = ((source_lons >= -180) & (source_lons <= 180) &
-                           (source_lats <= 90) & (source_lats >= -90))
+        valid_input_idx = ((source_lons >= -180) & (source_lons <= 180) & (source_lats <= 90) & (source_lats >= -90))
 
         input_coords = lonlat2xyz(source_lons, source_lats)
         input_coords = input_coords[valid_input_idx.ravel(), :]
@@ -1070,8 +1057,7 @@ def get_neighbour_info(self, mask=None):
         # TODO: Add 'chunks' keyword argument to this method and use it
         target_lons, target_lats = self.target_geo_def.get_lonlats(chunks=CHUNK_SIZE)
-        valid_output_idx = ((target_lons >= -180) & (target_lons <= 180) &
-                            (target_lats <= 90) & (target_lats >= -90))
+        valid_output_idx = ((target_lons >= -180) & (target_lons <= 180) & (target_lats <= 90) & (target_lats >= -90))
 
         if mask is not None:
             assert (mask.shape == self.source_geo_def.shape), \
@@ -1146,9 +1132,8 @@ def get_sample_from_neighbour_info(self, data, fill_value=np.nan):
         # verify that the dims are next to each other
         first_dim_idx = data.dims.index(src_geo_dims[0])
         num_dims = len(src_geo_dims)
-        assert (data.dims[first_dim_idx:first_dim_idx + num_dims] ==
-                data_geo_dims), "Data's geolocation dimensions are not " \
-                                "consecutive."
+        assert (data.dims[first_dim_idx:first_dim_idx + num_dims] == data_geo_dims),\
+            "Data's geolocation dimensions are not consecutive."
 
         # FIXME: Can't include coordinates whose dimensions depend on the geo
         # dims either
@@ -1158,8 +1143,8 @@ def contain_coords(var, coord_list):
         coords = {c: c_var for c, c_var in data.coords.items()
                   if not contain_coords(c_var, src_geo_dims + dst_geo_dims)}
         try:
-            # TODO: Add 'chunks' kwarg
-            coord_x, coord_y = self.target_geo_def.get_proj_vectors(chunks=CHUNK_SIZE)
+            # get these as numpy arrays because xarray is going to compute them anyway
+            coord_x, coord_y = self.target_geo_def.get_proj_vectors()
             coords['y'] = coord_y
             coords['x'] = coord_x
         except AttributeError:
@@ -1241,7 +1226,6 @@ def _get_fill_mask_value(data_dtype):
 
 def _remask_data(data, is_to_be_masked=True):
     """Interprets half the array as mask for the other half."""
-
     channels = data.shape[-1]
     if is_to_be_masked:
         mask = data[..., (channels // 2):]
diff --git a/pyresample/test/test_kd_tree.py b/pyresample/test/test_kd_tree.py
index d73c54ded..9846f5da7 100644
--- a/pyresample/test/test_kd_tree.py
+++ b/pyresample/test/test_kd_tree.py
@@ -1,3 +1,21 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+#
+# Copyright (c) 2021 Pyresample developers
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <http://www.gnu.org/licenses/>.
+"""Test kd_tree operations."""
 import os
 
 import numpy as np
@@ -9,6 +27,7 @@
 
 
 class Test(unittest.TestCase):
+    """Test nearest neighbor resampling on numpy arrays."""
 
     @classmethod
     def setUpClass(cls):
@@ -329,8 +348,7 @@ def test_gauss_multi_uncert(self):
         expected_stddev = [0.44621800779801657, 0.44363137712896705,
                            0.43861019464274459]
         expected_counts = 4934802.0
-        self.assertTrue(res.shape == stddev.shape and stddev.shape ==
-                        counts.shape and counts.shape == (800, 800, 3))
+        self.assertTrue(res.shape == stddev.shape and stddev.shape == counts.shape and counts.shape == (800, 800, 3))
         self.assertAlmostEqual(cross_sum, expected)
 
         for i, e_stddev in enumerate(expected_stddev):
@@ -881,13 +899,15 @@ def test_nearest_swath_2d_mask_to_area_1n(self):
     def test_nearest_area_2d_to_area_1n(self):
         """Test 2D area definition to 2D area definition; 1 neighbor."""
         from pyresample.kd_tree import XArrayResamplerNN
+        from pyresample.test.utils import assert_maximum_dask_computes
         import xarray as xr
         import dask.array as da
         data = self.data_2d
         resampler = XArrayResamplerNN(self.src_area_2d, self.area_def,
                                       radius_of_influence=50000,
                                       neighbours=1)
-        ninfo = resampler.get_neighbour_info()
+        with assert_maximum_dask_computes(0):
+            ninfo = resampler.get_neighbour_info()
         for val in ninfo[:3]:
             # vii, ia, voi
             self.assertIsInstance(val, da.Array)
@@ -896,7 +916,8 @@ def test_nearest_area_2d_to_area_1n(self):
 
         # rename data dimensions to match the expected area dimensions
         data = data.rename({'my_dim_y': 'y', 'my_dim_x': 'x'})
-        res = resampler.get_sample_from_neighbour_info(data)
+        with assert_maximum_dask_computes(0):
+            res = resampler.get_sample_from_neighbour_info(data)
         self.assertIsInstance(res, xr.DataArray)
         self.assertIsInstance(res.data, da.Array)
         res = res.values
@@ -946,8 +967,7 @@ def test_nearest_area_2d_to_area_1n_no_roi(self):
         self.assertEqual(cross_sum, expected)
 
     def test_nearest_area_2d_to_area_1n_3d_data(self):
-        """Test 2D area definition to 2D area definition; 1 neighbor, 3d
-        data."""
+        """Test 2D area definition to 2D area definition; 1 neighbor, 3d data."""
         from pyresample.kd_tree import XArrayResamplerNN
         import xarray as xr
         import dask.array as da
diff --git a/pyresample/test/utils.py b/pyresample/test/utils.py
index 199780789..43591224c 100644
--- a/pyresample/test/utils.py
+++ b/pyresample/test/utils.py
@@ -24,6 +24,7 @@
 import sys
 import types
 import warnings
+from contextlib import contextmanager
 
 import numpy as np
@@ -39,10 +40,9 @@
 
 
 def treat_deprecations_as_exceptions():
-    """Turn all DeprecationWarnings (which indicate deprecated uses of Python
-    itself or Numpy, but not within Astropy, where we use our own deprecation
-    warning class) into exceptions so that we find out about them early.
+    """Turn all DeprecationWarnings into exceptions.
 
+    Deprecation warnings indicate deprecated uses of Python itself or Numpy.
     This completely resets the warning filters and any "already seen"
     warning state.
     """
@@ -124,8 +124,9 @@
 
 
 class catch_warnings(warnings.catch_warnings):
-    """A high-powered version of warnings.catch_warnings to use for testing and
-    to make sure that there is no dependence on the order in which the tests
+    """A high-powered version of warnings.catch_warnings to use for testing.
+
+    Makes sure that there is no dependence on the order in which the tests
     are run.
 
     This completely blitzes any memory of any warnings that have
@@ -140,11 +141,14 @@ class catch_warnings(warnings.catch_warnings):
         do.something.bad()
     assert len(w) > 0
     """
+
     def __init__(self, *classes):
+        """Initialize the classes of warnings to catch."""
         super(catch_warnings, self).__init__(record=True)
         self.classes = classes
 
     def __enter__(self):
+        """Catch any warnings during this context."""
         warning_list = super(catch_warnings, self).__enter__()
         treat_deprecations_as_exceptions()
         if len(self.classes) == 0:
@@ -156,10 +160,12 @@ def __enter__(self):
         return warning_list
 
     def __exit__(self, type, value, traceback):
+        """Raise any warnings as errors."""
         treat_deprecations_as_exceptions()
 
 
 def create_test_longitude(start, stop, shape, twist_factor=0.0, dtype=np.float32):
+    """Get basic sample of longitude data."""
     if start > stop:
         stop += 360.0
 
@@ -175,6 +181,7 @@ def create_test_longitude(start, stop, shape, twist_factor=0.0, dtype=np.float32
 
 
 def create_test_latitude(start, stop, shape, twist_factor=0.0, dtype=np.float32):
+    """Get basic sample of latitude data."""
     num_cols = 1 if len(shape) < 2 else shape[1]
     lat_col = np.linspace(start, stop, num=shape[0]).astype(dtype).reshape((shape[0], 1))
     twist_array = np.arange(num_cols) * twist_factor
@@ -201,6 +208,14 @@ def __call__(self, dsk, keys, **kwargs):
         return dask.get(dsk, keys, **kwargs)
 
 
+@contextmanager
+def assert_maximum_dask_computes(max_computes=1):
+    """Context manager to make sure dask computations are not executed more than ``max_computes`` times."""
+    import dask
+    with dask.config.set(scheduler=CustomScheduler(max_computes=max_computes)) as new_config:
+        yield new_config
+
+
 def friendly_crs_equal(expected, actual, keys=None, use_obj=True, use_wkt=True):
     """Test if two projection definitions are equal.
diff --git a/setup.cfg b/setup.cfg
index 4e57f8837..589a60d39 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -5,6 +5,8 @@ doc_files = docs/Makefile docs/source/*.rst
 
 [flake8]
 max-line-length = 120
+per-file-ignores:
+    pyresample/test/*.py:D102
 
 [versioneer]
 VCS = git
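
Usage sketch for the new test helper (illustrative only, not part of the patch): ``assert_maximum_dask_computes`` installs the pre-existing ``CustomScheduler`` from pyresample/test/utils.py via ``dask.config.set``, and that scheduler is assumed to raise ``RuntimeError`` once more than ``max_computes`` computations are triggered:

    import dask.array as da
    from pyresample.test.utils import assert_maximum_dask_computes

    arr = da.zeros((10, 10), chunks=5)

    # A single compute fits within the max_computes=1 budget.
    with assert_maximum_dask_computes(max_computes=1):
        arr.sum().compute()

    # With a budget of 0, building a graph is fine but any compute fails,
    # which is how the updated tests verify that get_neighbour_info() and
    # get_sample_from_neighbour_info() stay fully lazy.
    with assert_maximum_dask_computes(0):
        lazy = arr + 1  # no computation is triggered here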