Skip to content

Commit

Permalink
Merge pull request JiaweiZhuang#49 from Ouranosinc/use-cfxarray
Browse files Browse the repository at this point in the history
Use cf-xarray
  • Loading branch information
huard authored Nov 27, 2020
2 parents 70fa826 + d0a5379 commit b9af3da
Show file tree
Hide file tree
Showing 10 changed files with 129 additions and 42 deletions.
3 changes: 2 additions & 1 deletion CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,13 @@ Breaking changes
~~~~~~~~~~~~~~~~
* Deprecate `esmf_grid` in favor of `Grid.from_xarray`
* Deprecate `esmf_locstream` in favor of `LocStream.from_xarray`
* Installation requires numpy>=1.16
* Installation requires numpy>=1.16 and cf-xarray>=0.3.1

New features
~~~~~~~~~~~~
* Create `ESMF.Mesh` objects from `shapely.polygons` (:pull:`24`). By `Pascal Bourgault <https://github.com/aulemahal>`_
* New class `SpatialAverager` offers user-friendly mechanism to average a 2-D field over a polygon. Includes support to handle interior holes and multi-part geometries. (:pull:`24`) By `Pascal Bourgault <https://github.com/aulemahal>`_
* Automatic detection of coordinates and computation of vertices based on cf-xarray. (:pull:`49`) By `Pascal Bourgault <https://github.com/aulemahal>`_

Bug fixes
~~~~~~~~~
Expand Down
1 change: 1 addition & 0 deletions binder/environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,6 @@ dependencies:
- scipy
- matplotlib
- cartopy
- cf_xarray>=0.3.1
- pip:
- xesmf==0.2.2
1 change: 1 addition & 0 deletions ci/environment-upstream-dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,4 @@ dependencies:
- scipy
- pip:
- git+https://github.com/pydata/xarray.git
- git+https://github.com/xarray-contrib/cf-xarray.git
1 change: 1 addition & 0 deletions ci/environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ name: xesmf
channels:
- conda-forge
dependencies:
- cf_xarray>=0.3.1
- codecov
- dask
- esmpy
Expand Down
6 changes: 5 additions & 1 deletion doc/notebooks/Spatial_Averaging.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -551,7 +551,11 @@
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [
{
"data": {
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ extend-ignore = E203,E501,E402,W605

[isort]
known_first_party=xesmf
known_third_party=ESMF,dask,numpy,pkg_resources,pytest,scipy,setuptools,shapely,xarray
known_third_party=ESMF,cf_xarray,dask,numpy,pkg_resources,pytest,scipy,setuptools,shapely,xarray
multi_line_output=3
include_trailing_comma=True
force_grid_wrap=0
Expand Down
9 changes: 8 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,14 @@
if on_rtd:
INSTALL_REQUIRES = []
else:
INSTALL_REQUIRES = ['esmpy>=8.0.0', 'xarray <= 0.16.0', 'numpy >=1.16', 'scipy', 'shapely']
INSTALL_REQUIRES = [
'esmpy>=8.0.0',
'xarray <= 0.16.0',
'numpy >=1.16',
'scipy',
'shapely',
'cf-xarray>=0.3.1',
]

CLASSIFIERS = [
'Development Status :: 4 - Beta',
Expand Down
108 changes: 80 additions & 28 deletions xesmf/frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import warnings

import cf_xarray as cfxr
import numpy as np
import scipy.sparse as sps
import xarray as xr
Expand Down Expand Up @@ -32,6 +33,46 @@ def as_2d_mesh(lon, lat):
return lon, lat


def _get_lon_lat(ds):
"""Return lon and lat extracted from ds."""
try:
lon = ds.cf['longitude']
lat = ds.cf['latitude']
except (KeyError, AttributeError):
# KeyError if cfxr doesn't detect the coords
# AttributeError if ds is a dict
lon = ds['lon']
lat = ds['lat']

return lon, lat


def _get_lon_lat_bounds(ds):
"""Return bounds of lon and lat extracted from ds."""
if 'lat_b' in ds and 'lon_b' in ds:
# Old way.
return ds['lon_b'], ds['lat_b']
# else : cf-xarray way
try:
lon_bnds = ds.cf.get_bounds('longitude')
lat_bnds = ds.cf.get_bounds('latitude')
except KeyError: # bounds are not already present
if ds.cf['longitude'].ndim > 1:
# We cannot infer 2D bounds, raise KeyError as custom "lon_b" is missing.
raise KeyError('lon_b')
lon_name = ds.cf['longitude'].name
lat_name = ds.cf['latitude'].name
ds2 = ds.cf.add_bounds([lon_name, lat_name])
lon_bnds = ds2.cf.get_bounds('longitude')
lat_bnds = ds2.cf.get_bounds('latitude')

# Convert from CF bounds to xESMF bounds.
# order=None is because we don't want to assume the dimension order for 2D bounds.
lon_b = cfxr.bounds_to_vertices(lon_bnds, 'bounds', order=None)
lat_b = cfxr.bounds_to_vertices(lat_bnds, 'bounds', order=None)
return lon_b, lat_b


def ds_to_ESMFgrid(ds, need_bounds=False, periodic=None, append=None):
"""
Convert xarray DataSet or dictionary to ESMF.Grid object.
Expand All @@ -58,9 +99,9 @@ def ds_to_ESMFgrid(ds, need_bounds=False, periodic=None, append=None):
"""

# use np.asarray(dr) instead of dr.values, so it also works for dictionary
lon = np.asarray(ds['lon'])
lat = np.asarray(ds['lat'])
lon, lat = as_2d_mesh(lon, lat)

lon, lat = _get_lon_lat(ds)
lon, lat = as_2d_mesh(np.asarray(lon), np.asarray(lat))

if 'mask' in ds:
mask = np.asarray(ds['mask'])
Expand All @@ -74,9 +115,8 @@ def ds_to_ESMFgrid(ds, need_bounds=False, periodic=None, append=None):
grid = Grid.from_xarray(lon.T, lat.T, periodic=periodic, mask=None)

if need_bounds:
lon_b = np.asarray(ds['lon_b'])
lat_b = np.asarray(ds['lat_b'])
lon_b, lat_b = as_2d_mesh(lon_b, lat_b)
lon_b, lat_b = _get_lon_lat_bounds(ds)
lon_b, lat_b = as_2d_mesh(np.asarray(lon_b), np.asarray(lat_b))
add_corner(grid, lon_b.T, lat_b.T)

return grid, lon.shape
Expand All @@ -97,8 +137,8 @@ def ds_to_ESMFlocstream(ds):
"""

lon = np.asarray(ds['lon'])
lat = np.asarray(ds['lat'])
lon, lat = _get_lon_lat(ds)
lon, lat = np.asarray(lon), np.asarray(lat)

if len(lon.shape) > 1:
raise ValueError('lon can only be 1d')
Expand Down Expand Up @@ -514,15 +554,21 @@ def __init__(
Parameters
----------
ds_in, ds_out : xarray DataSet, or dictionary
Contain input and output grid coordinates. Look for variables
``lon``, ``lat``, optionally ``lon_b``, ``lat_b`` for
conservative methods, and ``mask``. Note that for `mask`,
the ESMF convention is used, where masked values are identified
by 0, and non-masked values by 1.
Contain input and output grid coordinates.
All variables that the cf-xarray accessor understand are accepted.
Otherwise, look for ``lon``, ``lat``,
optionally ``lon_b``, ``lat_b`` for conservative methods,
and ``mask``. Note that for `mask`, the ESMF convention is used,
where masked values are identified by 0, and non-masked values by 1.
For conservative methods, if bounds are not present, they will be
computed using `cf-xarray` (only 1D coordinates are currently supported).
Shape can be 1D (n_lon,) and (n_lat,) for rectilinear grids,
or 2D (n_y, n_x) for general curvilinear grids.
Shape of bounds should be (n+1,) or (n_y+1, n_x+1).
CF-bounds (shape (n, 2) or (n, m, 4)) are also accepted if they are
accessible through the cf-xarray accessor.
If either dataset includes a 2d mask variable, that will also be
used to inform the regridding.
Expand Down Expand Up @@ -616,21 +662,21 @@ def __init__(
super().__init__(grid_in, grid_out, method, **kwargs)

# record output grid and metadata
self._lon_out = np.asarray(ds_out['lon'])
self._lat_out = np.asarray(ds_out['lat'])
lon_out, lat_out = _get_lon_lat(ds_out)
self._lon_out, self._lat_out = np.asarray(lon_out), np.asarray(lat_out)

if self._lon_out.ndim == 2:
try:
self.lon_dim = self.lat_dim = ds_out['lon'].dims
self.lon_dim = self.lat_dim = lon_out.dims
except Exception:
self.lon_dim = self.lat_dim = ('y', 'x')

self.out_horiz_dims = self.lon_dim

elif self._lon_out.ndim == 1:
try:
(self.lon_dim,) = ds_out['lon'].dims
(self.lat_dim,) = ds_out['lat'].dims
(self.lon_dim,) = lon_out.dims
(self.lat_dim,) = lat_out.dims
except Exception:
self.lon_dim = 'lon'
self.lat_dim = 'lat'
Expand Down Expand Up @@ -756,16 +802,22 @@ def __init__(

# Create an output locstream so that the regridder knows the output shape and coords.
# Latitude and longitude coordinates are the polygon centroid.
lon_out, lat_out = _get_lon_lat(ds_in)
if hasattr(lon_out, 'name'):
self._lon_out_name = lon_out.name
self._lat_out_name = lat_out.name
else:
self._lon_out_name = 'lon'
self._lat_out_name = 'lat'

poly_centers = [poly.centroid.xy for poly in polys]
ds_out = xr.Dataset(
data_vars={
'lon': (('poly',), [c[0][0] for c in poly_centers]),
'lat': (('poly',), [c[1][0] for c in poly_centers]),
}
)
self._lon_out = np.asarray([c[0][0] for c in poly_centers])
self._lat_out = np.asarray([c[1][0] for c in poly_centers])

# We put names 'lon' and 'lat' so ds_to_ESMFlocstream finds them easily.
# _lon_out_name and _lat_out_name are used on the output anyway.
ds_out = {'lon': self._lon_out, 'lat': self._lat_out}
locstream_out, shape_out = ds_to_ESMFlocstream(ds_out)
self._lon_out = ds_out.lon
self._lat_out = ds_out.lat

# BaseRegridder with custom-computed weights and dummy out grid
super().__init__(
Expand Down Expand Up @@ -835,7 +887,7 @@ def _format_xroutput(self, out, new_dims=None):

# append output horizontal coordinate values
# extra coordinates are automatically tracked by apply_ufunc
out.coords['lon'] = xr.DataArray(self._lon_out, dims=(self.geom_dim_name,))
out.coords['lat'] = xr.DataArray(self._lat_out, dims=(self.geom_dim_name,))
out.coords[self._lon_out_name] = xr.DataArray(self._lon_out, dims=(self.geom_dim_name,))
out.coords[self._lat_out_name] = xr.DataArray(self._lat_out, dims=(self.geom_dim_name,))
out.attrs['regrid_method'] = self.method
return out
36 changes: 28 additions & 8 deletions xesmf/tests/test_frontend.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
import warnings

import cf_xarray # noqa
import dask
import numpy as np
import pytest
Expand Down Expand Up @@ -244,17 +245,36 @@ def test_regrid_with_1d_grid():
assert_equal(dr_out['lat'].values, ds_out_1d['lat'].values)


def test_regrid_with_1d_grid_infer_bounds():
ds_in_1d = ds_2d_to_1d(ds_in).rename(x='lon', y='lat')
ds_out_1d = ds_2d_to_1d(ds_out).rename(x='lon', y='lat')

regridder = xe.Regridder(ds_in_1d, ds_out_1d, 'conservative', periodic=True)

dr_out = regridder(ds_in['data'])

# compare with provided-bounds solution
dr_exp = xe.Regridder(ds_in, ds_out, 'conservative', periodic=True)(ds_in['data'])

assert_allclose(dr_out, dr_exp)


# TODO: consolidate (regrid method, input data types) combination
# using pytest fixtures and parameterization


def test_regrid_dataarray():
@pytest.mark.parametrize('use_cfxr', [True, False])
def test_regrid_dataarray(use_cfxr):
# xarray.DataArray containing in-memory numpy array
if use_cfxr:
ds_in2 = ds_in.rename(lat='Latitude', lon='Longitude')
else:
ds_in2 = ds_in

regridder = xe.Regridder(ds_in, ds_out, 'conservative')
regridder = xe.Regridder(ds_in2, ds_out, 'conservative')

outdata = regridder(ds_in['data'].values) # pure numpy array
dr_out = regridder(ds_in['data']) # xarray DataArray
outdata = regridder(ds_in2['data'].values) # pure numpy array
dr_out = regridder(ds_in2['data']) # xarray DataArray

# DataArray and numpy array should lead to the same result
assert_equal(outdata, dr_out.values)
Expand All @@ -268,18 +288,18 @@ def test_regrid_dataarray():
assert_equal(dr_out['lon'].values, ds_out['lon'].values)

# test broadcasting
dr_out_4D = regridder(ds_in['data4D'])
dr_out_4D = regridder(ds_in2['data4D'])

# data over broadcasting dimensions should agree
assert_almost_equal(
ds_in['data4D'].values.mean(axis=(2, 3)),
ds_in2['data4D'].values.mean(axis=(2, 3)),
dr_out_4D.values.mean(axis=(2, 3)),
decimal=10,
)

# check metadata
xr.testing.assert_identical(dr_out_4D['time'], ds_in['time'])
xr.testing.assert_identical(dr_out_4D['lev'], ds_in['lev'])
xr.testing.assert_identical(dr_out_4D['time'], ds_in2['time'])
xr.testing.assert_identical(dr_out_4D['lev'], ds_in2['lev'])


def test_regrid_dataarray_to_locstream():
Expand Down
4 changes: 2 additions & 2 deletions xesmf/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,8 @@ def grid_2d(lon0_b, lon1_b, d_lon, lat0_b, lat1_b, d_lat):

ds = xr.Dataset(
coords={
'lon': (['y', 'x'], lon),
'lat': (['y', 'x'], lat),
'lon': (['y', 'x'], lon, {'standard_name': 'longitude'}),
'lat': (['y', 'x'], lat, {'standard_name': 'latitude'}),
'lon_b': (['y_b', 'x_b'], lon_b),
'lat_b': (['y_b', 'x_b'], lat_b),
}
Expand Down

0 comments on commit b9af3da

Please sign in to comment.