Skip to content

Commit

Permalink
Merge pull request #125 from GPlates/grid-compression
Browse files Browse the repository at this point in the history
Improve netCDF4 compression
  • Loading branch information
brmather committed Sep 17, 2024
2 parents e9bcbcd + b635a23 commit 982efe1
Show file tree
Hide file tree
Showing 2 changed files with 164 additions and 118 deletions.
228 changes: 136 additions & 92 deletions gplately/grids.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def realign_grid(array, lons, lats):
return array, lons, lats


def read_netcdf_grid(filename, return_grids=False, realign=False, resample=None):
def read_netcdf_grid(filename, return_grids=False, realign=False, resample=None, resize=None):
"""Read a `netCDF` (.nc) grid from a given `filename` and return its data as a
`MaskedArray`.
Expand Down Expand Up @@ -174,6 +174,10 @@ def read_netcdf_grid(filename, return_grids=False, realign=False, resample=None)
If passed as `resample = (spacingX, spacingY)`, the given `netCDF` grid is resampled
with these x and y resolutions.
resize : tuple, optional, default=None
If passed as `resample = (resX, resY)`, the given `netCDF` grid is resized
to the number of columns (resX) and rows (resY).
Returns
-------
grid_z : MaskedArray
Expand Down Expand Up @@ -237,6 +241,18 @@ def find_label(keys, labels):
cdf_lon = cdf[key_lon][:]
cdf_lat = cdf[key_lat][:]

# fill missing values
if hasattr(cdf[key_z], 'missing_value') and np.issubdtype(cdf_grid.dtype, np.floating):
fill_value = cdf[key_z].missing_value
cdf_grid[np.isclose(cdf_grid, fill_value, rtol=0.1)] = np.nan

# convert to boolean array
if np.issubdtype(cdf_grid.dtype, np.integer):
unique_grid = np.unique(cdf_grid)
if len(unique_grid) == 2:
if (unique_grid == [0,1]).all():
cdf_grid = cdf_grid.astype(bool)

if realign:
# realign longitudes to -180/180 dateline
cdf_grid_z, cdf_lon, cdf_lat = realign_grid(cdf_grid, cdf_lon, cdf_lat)
Expand All @@ -246,25 +262,58 @@ def find_label(keys, labels):
# resample
if resample is not None:
spacingX, spacingY = resample
lon_grid = np.arange(cdf_lon.min(), cdf_lon.max() + spacingX, spacingX)
lat_grid = np.arange(cdf_lat.min(), cdf_lat.max() + spacingY, spacingY)
lonq, latq = np.meshgrid(lon_grid, lat_grid)
original_extent = (
cdf_lon[0],
cdf_lon[-1],
cdf_lat[0],
cdf_lat[-1],
)
cdf_grid_z = sample_grid(
lonq,
latq,
cdf_grid_z,
method="nearest",
extent=original_extent,
return_indices=False,
)
cdf_lon = lon_grid
cdf_lat = lat_grid

# don't resample if already the same resolution
dX = np.diff(cdf_lon).mean()
dY = np.diff(cdf_lat).mean()

if spacingX != dX or spacingY != dY:
lon_grid = np.arange(cdf_lon.min(), cdf_lon.max() + spacingX, spacingX)
lat_grid = np.arange(cdf_lat.min(), cdf_lat.max() + spacingY, spacingY)
lonq, latq = np.meshgrid(lon_grid, lat_grid)
original_extent = (
cdf_lon[0],
cdf_lon[-1],
cdf_lat[0],
cdf_lat[-1],
)
cdf_grid_z = sample_grid(
lonq,
latq,
cdf_grid_z,
method="nearest",
extent=original_extent,
return_indices=False,
)
cdf_lon = lon_grid
cdf_lat = lat_grid

# resize
if resize is not None:
resX, resY = resize

# don't resize if already the same shape
if resX != cdf_grid_z.shape[1] or resY != cdf_grid_z.shape[0]:
original_extent = (
cdf_lon[0],
cdf_lon[-1],
cdf_lat[0],
cdf_lat[-1],
)
lon_grid = np.linspace(original_extent[0], original_extent[1], resX)
lat_grid = np.linspace(original_extent[2], original_extent[3], resY)
lonq, latq = np.meshgrid(lon_grid, lat_grid)

cdf_grid_z = sample_grid(
lonq,
latq,
cdf_grid_z,
method="nearest",
extent=original_extent,
return_indices=False,
)
cdf_lon = lon_grid
cdf_lat = lat_grid

# Fix grids with 9e36 as the fill value for nan.
# cdf_grid_z.fill_value = float('nan')
Expand All @@ -274,61 +323,9 @@ def find_label(keys, labels):
return cdf_grid_z, cdf_lon, cdf_lat
else:
return cdf_grid_z


def write_netcdf(filename, lons, lats, data):
"""Write geospatial data to a netCDF4 grid with a specified `filename`.
The latitude, longitude, and data variabels must be of the same size.
Parameters
----------
filename : str
The full path (including a filename and the ".nc" extension) to save the created netCDF4 file.
lons : 1D array
List of longitudinal coordinates to be written into a netCDF4 (.nc) file.
lats : 1D array
List of latitudinal coordinates to be written into a netCDF4 (.nc) file.
data : 1D array
List of data values at lon / lat coordinates to be written into a netCDF4 (.nc) file.
"""
import netCDF4

lons = np.asarray(lons)
lats = np.asarray(lats)
data = np.asarray(data)

with netCDF4.Dataset(filename, "w", driver=None) as cdf:
cdf.title = "Grid produced by gplately"
cdf.createDimension("lon", lons.size)
cdf_lon = cdf.createVariable("lon", lons.dtype, ("lon",), zlib=True)
cdf_lat = cdf.createVariable("lat", lats.dtype, ("lon",), zlib=True)
cdf_lon[:] = lons
cdf_lat[:] = lats

# Units for Geographic Grid type
cdf_lon.units = "degrees_east"
cdf_lon.standard_name = "lon"
cdf_lon.actual_range = [np.min(lons), np.max(lons)]
cdf_lat.units = "degrees_north"
cdf_lat.standard_name = "lat"
cdf_lat.actual_range = [np.min(lats), np.max(lats)]

cdf_data = cdf.createVariable("z", data.dtype, ("lon",), zlib=True)
# netCDF4 uses the missing_value attribute as the default _FillValue
# without this, _FillValue defaults to 9.969209968386869e+36
cdf_data.missing_value = np.nan
cdf_data.standard_name = "z"
# Ensure pygmt registers min and max z values properly
cdf_data.actual_range = [np.nanmin(data), np.nanmax(data)]

cdf_data[:] = data


def write_netcdf_grid(filename, grid, extent=[-180, 180, -90, 90]):
"""Write geological data contained in a `grid` to a netCDF4 grid with a specified `filename`.

def write_netcdf_grid(filename, grid, extent="global", significant_digits=None, fill_value=np.nan, **kwargs):
""" Write geological data contained in a `grid` to a netCDF4 grid with a specified `filename`.
Notes
-----
Expand All @@ -352,28 +349,36 @@ def write_netcdf_grid(filename, grid, extent=[-180, 180, -90, 90]):
An ndarray grid containing data to be written into a `netCDF` (.nc) file. Note: Rows correspond to
the data's latitudes, while the columns correspond to the data's longitudes.
extent : 1D numpy array, default=[-180,180,-90,90]
Four elements that specify the [min lon, max lon, min lat, max lat] to constrain the lat and lon
variables of the netCDF grid to. If no extents are supplied, full global extent `[-180, 180, -90, 90]`
is assumed.
extent : list, default=[-180,180,-90,90]
Four elements that specify the [min lon, max lon, min lat, max lat] to constrain the lat and lon
variables of the netCDF grid to. If no extents are supplied, full global extent `[-180, 180, -90, 90]`
is assumed.
Returns
-------
A netCDF grid will be saved to the path specified in `filename`.
"""
import netCDF4
from gplately import __version__ as _version

if extent == 'global':
extent = [-180, 180, -90, 90]
else:
assert len(extent) == 4, "specify the [min lon, max lon, min lat, max lat]"

nrows, ncols = np.shape(grid)

lon_grid = np.linspace(extent[0], extent[1], ncols)
lat_grid = np.linspace(extent[2], extent[3], nrows)

with netCDF4.Dataset(filename, "w", driver=None) as cdf:
cdf.title = "Grid produced by gplately"
cdf.createDimension("lon", lon_grid.size)
cdf.createDimension("lat", lat_grid.size)
cdf_lon = cdf.createVariable("lon", lon_grid.dtype, ("lon",), zlib=True)
cdf_lat = cdf.createVariable("lat", lat_grid.dtype, ("lat",), zlib=True)
data_kwds = {'compression': 'zlib', 'complevel': 9}

with netCDF4.Dataset(filename, 'w', driver=None) as cdf:
cdf.title = "Grid produced by gplately " + str(_version)
cdf.createDimension('lon', lon_grid.size)
cdf.createDimension('lat', lat_grid.size)
cdf_lon = cdf.createVariable('lon', lon_grid.dtype, ('lon',), **data_kwds)
cdf_lat = cdf.createVariable('lat', lat_grid.dtype, ('lat',), **data_kwds)
cdf_lon[:] = lon_grid
cdf_lat[:] = lat_grid

Expand All @@ -385,15 +390,45 @@ def write_netcdf_grid(filename, grid, extent=[-180, 180, -90, 90]):
cdf_lat.standard_name = "lat"
cdf_lat.actual_range = [lat_grid[0], lat_grid[-1]]

cdf_data = cdf.createVariable("z", grid.dtype, ("lat", "lon"), zlib=True)
# create container variable for CRS: lon/lat WGS84 datum
crso = cdf.createVariable('crs','i4')
crso.long_name = 'Lon/Lat Coords in WGS84'
crso.grid_mapping_name='latitude_longitude'
crso.longitude_of_prime_meridian = 0.0
crso.semi_major_axis = 6378137.0
crso.inverse_flattening = 298.257223563
crso.spatial_ref = """GEOGCS["WGS 84",DATUM["WGS_1984",SPHEROID["WGS 84",6378137,298.257223563,AUTHORITY["EPSG","7030"]],AUTHORITY["EPSG","6326"]],PRIMEM["Greenwich",0,AUTHORITY["EPSG","8901"]],UNIT["degree",0.01745329251994328,AUTHORITY["EPSG","9122"]],AUTHORITY["EPSG","4326"]]"""

# add more keyword arguments for quantizing data
if significant_digits:
# significant_digits needs to be >= 2 so that NaNs are preserved
data_kwds['significant_digits'] = max(2, int(significant_digits))
data_kwds['quantize_mode'] = 'GranularBitRound'

# boolean arrays need to be converted to integers
# no such thing as a mask on a boolean array
if grid.dtype is np.dtype(bool):
grid = grid.astype('i1')
fill_value = None

cdf_data = cdf.createVariable('z', grid.dtype, ('lat','lon'), **data_kwds)

# netCDF4 uses the missing_value attribute as the default _FillValue
# without this, _FillValue defaults to 9.969209968386869e+36
cdf_data.missing_value = np.nan
cdf_data.standard_name = "z"
# Ensure pygmt registers min and max z values properly
if fill_value is not None:
cdf_data.missing_value = fill_value

cdf_data.standard_name = 'z'

cdf_data.add_offset = 0.0
cdf_data.grid_mapping = 'crs'
cdf_data.set_auto_maskandscale(False)

# ensure min and max z values are properly registered
cdf_data.actual_range = [np.nanmin(grid), np.nanmax(grid)]

cdf_data[:, :] = grid
# write data
cdf_data[:,:] = grid


class RegularGridInterpolator(_RGI):
Expand Down Expand Up @@ -1511,6 +1546,7 @@ def __init__(
extent="global",
realign=False,
resample=None,
resize=None,
time=0.0,
origin=None,
**kwargs,
Expand Down Expand Up @@ -1540,6 +1576,10 @@ def __init__(
Optionally resample grid, pass spacing in X and Y direction as a
2-tuple e.g. resample=(spacingX, spacingY).
resize : 2-tuple, optional
Optionally resample grid to X-columns, Y-rows as a
2-tuple e.g. resample=(resX, resY).
time : float, default: 0.0
The time step represented by the raster data. Used for raster
reconstruction.
Expand Down Expand Up @@ -1610,6 +1650,7 @@ def __init__(
return_grids=True,
realign=realign,
resample=resample,
resize=resize,
)
self._lons = lons
self._lats = lats
Expand All @@ -1631,6 +1672,9 @@ def __init__(
if (not isinstance(data, str)) and (resample is not None):
self.resample(*resample, inplace=True)

if (not isinstance(data, str)) and (resize is not None):
self.resize(*resize, inplace=True)

@property
def time(self):
"""The time step represented by the raster data."""
Expand Down Expand Up @@ -1973,10 +2017,10 @@ def fill_NaNs(self, inplace=False, return_array=False):
else:
return Raster(data, self.plate_reconstruction, self.extent, self.time)

def save_to_netcdf4(self, filename):
def save_to_netcdf4(self, filename, significant_digits=None, fill_value=np.nan):
"""Saves the grid attributed to the `Raster` object to the given `filename` (including
the ".nc" extension) in netCDF4 format."""
write_netcdf_grid(str(filename), self.data, self.extent)
write_netcdf_grid(str(filename), self.data, self.extent, significant_digits, fill_value)

def reconstruct(
self,
Expand Down
Loading

0 comments on commit 982efe1

Please sign in to comment.