Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Changes to optionally return distances with .sel #41

Open
wants to merge 14 commits into
base: master
Choose a base branch
from
2 changes: 1 addition & 1 deletion .github/workflows/test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ jobs:
uses: conda-incubator/setup-miniconda@v2
with:
python-version: ${{ matrix.python-version }}
mamba-version: "*"
miniforge-variant: Mambaforge
channels: conda-forge,defaults
channel-priority: true
auto-activate-base: false
Expand Down
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@ repos:
- id: double-quote-string-fixer

- repo: https://github.com/ambv/black
rev: 20.8b1
rev: 22.3.0 # https://github.com/psf/black/issues/2964
hooks:
- id: black
args: ["--line-length", "100", "--skip-string-normalization"]

- repo: https://gitlab.com/PyCQA/flake8
- repo: https://github.com/pycqa/flake8
rev: 3.8.4
hooks:
- id: flake8
Expand Down
6 changes: 5 additions & 1 deletion .readthedocs.yml
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
# Required
version: 2

build:
os: "ubuntu-20.04"
tools:
python: "mambaforge-4.10"

# Build documentation in the doc/ directory with Sphinx
sphinx:
configuration: doc/conf.py
Expand All @@ -9,7 +14,6 @@ conda:
environment: environment_doc.yml

python:
version: 3.8
install:
- method: pip
path: .
234 changes: 209 additions & 25 deletions doc/examples/introduction.ipynb

Large diffs are not rendered by default.

47 changes: 43 additions & 4 deletions src/xoak/accessor.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Hashable, Iterable, List, Mapping, Tuple, Type, Union
from typing import Any, Hashable, Iterable, List, Mapping, Optional, Tuple, Type, Union

import numpy as np
import xarray as xr
Expand Down Expand Up @@ -134,12 +134,18 @@ def index(self) -> Union[None, Index, Iterable[Index]]:
return [wrp.index for wrp in index_wrappers]

def _query(self, indexers):
"""Find the distance(s) and indices of nearest point(s).

Note that the distances are converted inside this function from radians to kilometers
by multiplying them by the radius of the Earth (6371 km).
benbovy marked this conversation as resolved.
Show resolved Hide resolved
"""
X = coords_to_point_array([indexers[c] for c in self._index_coords])

if isinstance(X, np.ndarray) and isinstance(self._index, XoakIndexWrapper):
# directly call index wrapper's query method
res = self._index.query(X)
results = res['indices'][:, 0]
distances = res['distances'][:, 0]

else:
# Two-stage lazy query with dask
Expand Down Expand Up @@ -195,7 +201,7 @@ def _query(self, indexers):
concatenate=True,
)

return results
return results, distances * 6371
benbovy marked this conversation as resolved.
Show resolved Hide resolved
willirath marked this conversation as resolved.
Show resolved Hide resolved

def _get_pos_indexers(self, indices, indexers):
"""Returns positional indexers based on the query results and the
Expand Down Expand Up @@ -225,7 +231,10 @@ def _get_pos_indexers(self, indices, indexers):
return pos_indexers

def sel(
self, indexers: Mapping[Hashable, Any] = None, **indexers_kwargs: Any
self,
indexers: Mapping[Hashable, Any] = None,
distances_name: Optional[str] = None,
**indexers_kwargs: Any,
) -> Union[xr.Dataset, xr.DataArray]:
"""Selection based on a ball tree index.

Expand All @@ -243,14 +252,26 @@ def sel(
This triggers :func:`dask.compute` if the given indexers and/or the index
coordinates are chunked.

Parameters
----------
distances_name: str, optional
If a string is input, it is used to save the distances into the xarray
object. Distances are in km.
kthyng marked this conversation as resolved.
Show resolved Hide resolved

Returns
-------
xr.Dataset, xr.DataArray
Normally, the returned object has the same type as the input object. However, if a
str is given for `distances_name`, a Dataset is returned in order to
accommodate the additional variable.
"""
if not getattr(self, '_index', False):
raise ValueError(
'The index(es) has/have not been built yet. Call `.xoak.set_index()` first'
)

indexers = either_dict_or_kwargs(indexers, indexers_kwargs, 'xoak.sel')
indices = self._query(indexers)
indices, distances = self._query(indexers)

if not isinstance(indices, np.ndarray):
# TODO: remove (see todo below)
Expand All @@ -263,4 +284,22 @@ def sel(
# This would also allow lazy selection
result = self._xarray_obj.isel(indexers=pos_indexers)

# save distances as a new variable in xarray object if name is input
if distances_name is not None:
kthyng marked this conversation as resolved.
Show resolved Hide resolved
# need to have a Dataset instead of DataArray to add a new variable
# otherwise goes in as a coordinate
if not isinstance(result, xr.Dataset):
result = result.to_dataset()
# use same dimensions as indexers
attrs = {
'units': 'km',
benbovy marked this conversation as resolved.
Show resolved Hide resolved
'long_name': 'Distance from location to nearest comparison point.',
}

indexer_dim = list(indexers.values())[0].dims
indexer_shape = indexers[list(indexers.keys())[0]].shape
result[distances_name] = xr.Variable(
indexer_dim, distances.reshape(indexer_shape), attrs
)

return result
2 changes: 1 addition & 1 deletion src/xoak/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import dask
import dask.array
kthyng marked this conversation as resolved.
Show resolved Hide resolved
import numpy as np
import pytest
import xarray as xr
Expand Down
21 changes: 21 additions & 0 deletions src/xoak/tests/test_accessor.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import numpy as np
import pytest
import xarray as xr
from scipy.spatial import cKDTree
Expand Down Expand Up @@ -80,3 +81,23 @@ def test_index_property():
ds_chunk = ds.chunk(2)
ds_chunk.xoak.set_index(['x', 'y'], 'scipy_kdtree')
assert isinstance(ds_chunk.xoak.index, list)


def test_distances():
    """Check that ``xoak.sel`` stores the query distances under ``distances_name``.

    Two points are looked up against a geographic ball tree index; the result
    must be a Dataset whose ``'distances'`` variable holds the expected values
    (the docstring of ``sel`` states that distances are in km).
    """
    # indexed points: both coordinates share the dimension 'a'
    indexed_coords = {
        'x': ('a', [0, 1, 2, 3]),
        'y': ('a', [0, 1, 2, 3]),
    }
    ds = xr.Dataset(coords=indexed_coords)

    # points to look up, also laid out along dimension 'a'
    targets = xr.Dataset(
        {
            'lat_to_find': ('a', [0, 0]),
            'lon_to_find': ('a', [0, 0.5]),
        }
    )
    ds.xoak.set_index(['y', 'x'], 'sklearn_geo_balltree')

    indexers = {'y': targets.lat_to_find, 'x': targets.lon_to_find}
    output = ds.xoak.sel(indexers, distances_name='distances')

    # requesting distances always yields a Dataset carrying the extra variable
    assert isinstance(output, xr.Dataset)

    # first target coincides with an indexed point, second is 0.5 away in 'x'
    expected = [0, 55.59746332]
    assert np.allclose(output['distances'], expected)
kthyng marked this conversation as resolved.
Show resolved Hide resolved