Skip to content

Commit

Permalink
Merge pull request #2201 from chrishavlin/scipy_1pt11_kdtree_prefilter
Browse files Browse the repository at this point in the history
precompute finite value mask for kdtree query
  • Loading branch information
greglucas authored Jul 14, 2023
2 parents de7b307 + 58c5fe5 commit 41880a8
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 7 deletions.
22 changes: 15 additions & 7 deletions lib/cartopy/img_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@


try:
import pykdtree.kdtree
from pykdtree.kdtree import KDTree as _kdtreeClass
_is_pykdtree = True
except ImportError:
try:
import scipy.spatial
from scipy.spatial import cKDTree as _kdtreeClass
except ImportError as e:
raise ImportError("Using image transforms requires either "
"pykdtree or scipy.") from e
Expand Down Expand Up @@ -268,17 +268,25 @@ def regrid(array, source_x_coords, source_y_coords, source_proj, target_proj,
target_x_points.flatten(),
target_y_points.flatten())

# Find mask of valid points before querying kdtree: scipy >= 1.11 errors
# when querying nan points, might as well use for pykdtree too.
indices = np.zeros(target_xyz.shape[0], dtype=int)
finite_xyz = np.all(np.isfinite(target_xyz), axis=-1)

if _is_pykdtree:
kdtree = pykdtree.kdtree.KDTree(xyz)
kdtree = _kdtreeClass(xyz)
# Use sqr_dists=True because we don't care about distances,
# and it saves a sqrt.
_, indices = kdtree.query(target_xyz, k=1, sqr_dists=True)
_, indices[finite_xyz] = kdtree.query(target_xyz[finite_xyz, :],
k=1,
sqr_dists=True)
else:
# Versions of scipy >= v0.16 added the balanced_tree argument,
# which caused the KDTree to hang with this input.
kdtree = scipy.spatial.cKDTree(xyz, balanced_tree=False)
_, indices = kdtree.query(target_xyz, k=1)
mask = indices >= len(xyz)
kdtree = _kdtreeClass(xyz, balanced_tree=False)
_, indices[finite_xyz] = kdtree.query(target_xyz[finite_xyz, :], k=1)

mask = ~finite_xyz | (indices >= len(xyz))
indices[mask] = 0

desired_ny, desired_nx = target_x_points.shape
Expand Down
23 changes: 23 additions & 0 deletions lib/cartopy/tests/test_img_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import numpy as np
from numpy.testing import assert_array_equal
import pytest
import scipy.spatial

import cartopy.crs as ccrs
import cartopy.img_transform as img_trans
Expand Down Expand Up @@ -108,3 +109,25 @@ def test_gridding_data_outside_projection():
assert_array_equal([-180, 180, -90, 90], extent)
assert_array_equal(expected, image)
assert_array_equal(expected_mask, image.mask)


@pytest.mark.parametrize("target_prj",
(ccrs.Mollweide(), ccrs.Orthographic()))
@pytest.mark.parametrize("use_scipy", (True, False))
def test_regridding_with_invalid_extent(target_prj, use_scipy, monkeypatch):
# tests that when a valid extent results in invalid points in the
# transformed coordinates, the regridding does not error.

# create 3 data points
lats = np.array([65, 10, -45])
lons = np.array([-170, 10, 170])
data = np.array([1, 2, 3])
data_trans = ccrs.Geodetic()

target_x, target_y, extent = img_trans.mesh_projection(target_prj, 8, 4)

if use_scipy:
monkeypatch.setattr(img_trans, "_is_pykdtree", False)
monkeypatch.setattr(img_trans, "_kdtreeClass", scipy.spatial.cKDTree)
_ = img_trans.regrid(data, lons, lats, data_trans, target_prj,
target_x, target_y)

0 comments on commit 41880a8

Please sign in to comment.