Add benchmark for climatology #1552

Merged: 8 commits, Oct 4, 2024
Changes from 3 commits
120 changes: 120 additions & 0 deletions tests/geospatial/test_climatology.py
@@ -0,0 +1,120 @@
"""This benchmark is a port of the climatology computation implemented in
https://github.com/google-research/weatherbench2/blob/47d72575cf5e99383a09bed19ba989b718d5fe30/scripts/compute_climatology.py
with the parameters

FREQUENCY = "hourly"
HOUR_INTERVAL = 6
WINDOW_SIZE = 61
STATISTICS = ["mean"]
METHOD = "explicit"
"""

import numpy as np
import pytest
import xarray as xr
from coiled.credentials.google import CoiledShippedCredentials


def compute_hourly_climatology(
    ds: xr.Dataset,
) -> xr.Dataset:
    hours = xr.DataArray(range(0, 24, 6), dims=["hour"])
    window_weights = create_window_weights(61)
    return xr.concat(
        [compute_rolling_mean(select_hour(ds, hour), window_weights) for hour in hours],
        dim=hours,
    )


def compute_rolling_mean(ds: xr.Dataset, window_weights: xr.DataArray) -> xr.Dataset:
    window_size = len(window_weights)
    half_window_size = window_size // 2  # For padding
    stacked = xr.concat(
        [
            replace_time_with_doy(ds.sel(time=str(y)))
            for y in np.unique(ds.time.dt.year)
        ],
        dim="year",
    )
    stacked = stacked.fillna(stacked.sel(dayofyear=365))
    stacked = stacked.pad(pad_width={"dayofyear": half_window_size}, mode="wrap")
    stacked = stacked.rolling(dayofyear=window_size, center=True).construct("window")
    rolling = stacked.weighted(window_weights).mean(dim=("window", "year"))
    return rolling.isel(dayofyear=slice(half_window_size, -half_window_size))


def create_window_weights(window_size: int) -> xr.DataArray:
    """Create linearly decaying window weights."""
    assert window_size % 2 == 1, "Window size must be odd."
    half_window_size = window_size // 2
    window_weights = np.concatenate(
        [
            np.linspace(0, 1, half_window_size + 1),
            np.linspace(1, 0, half_window_size + 1)[1:],
        ]
    )
    window_weights = window_weights / window_weights.mean()
    window_weights = xr.DataArray(window_weights, dims=["window"])
    return window_weights


def replace_time_with_doy(ds: xr.Dataset) -> xr.Dataset:
    """Replace time coordinate with days of year."""
    return ds.assign_coords({"time": ds.time.dt.dayofyear}).rename(
        {"time": "dayofyear"}
    )


def select_hour(ds: xr.Dataset, hour: int) -> xr.Dataset:
    """Select given hour of day from Dataset."""
    # Select hour
    ds = ds.isel(time=ds.time.dt.hour == hour)
    # Adjust time dimension
    ds = ds.assign_coords({"time": ds.time.astype("datetime64[D]")})
    return ds


@pytest.mark.client("compute_climatology")
def test_compute_climatology(client, gcs_url, scale):
    # Load dataset
    ds = xr.open_zarr(
        "gs://weatherbench2/datasets/era5/1959-2023_01_10-wb13-6h-1440x721_with_derived_variables.zarr",
    )

    if scale == "small":
        # 101.83 GiB (small)
        time_range = slice("2020-01-01", "2022-12-31")
        variables = ["sea_surface_temperature"]
    elif scale == "medium":
        # 2.12 TiB (medium)
        time_range = slice("1959-01-01", "2022-12-31")
        variables = ["sea_surface_temperature"]
    else:
        # 4.24 TiB (large)
        # This currently doesn't complete successfully.
        time_range = slice("1959-01-01", "2022-12-31")
        variables = ["sea_surface_temperature", "snow_depth"]
    ds = ds[variables].sel(time=time_range)

    ds = ds.drop_vars([k for k, v in ds.items() if "time" not in v.dims])
    pencil_chunks = {"time": -1, "longitude": "auto", "latitude": "auto"}

    working = ds.chunk(pencil_chunks)
    hours = xr.DataArray(range(0, 24, 6), dims=["hour"])
    daysofyear = xr.DataArray(range(1, 367), dims=["dayofyear"])
    template = (
        working.isel(time=0)
        .drop_vars("time")
        .expand_dims(hour=hours, dayofyear=daysofyear)
        .assign_coords(hour=hours, dayofyear=daysofyear)
    )
    working = working.map_blocks(compute_hourly_climatology, template=template)

Member Author:

I wonder if people typically compute climatology with complex-enough calculations that warrant the use of map_blocks or typically use calculations that leverage Xarray's higher-level API.

dcherian (Contributor), Sep 20, 2024:

People use rechunk+map_blocks because it usually would not work otherwise. It's basically "manual fusion of tasks" to not overwhelm the scheduler. And convenient because you can use Xarray API in there.

I strongly recommend running a version of this without those steps. Ideally, these tricks should not be necessary for this calculation.

Notice that the workload rechunks back to pancakes... The chunking to pencils is purely to get this to work with dask+distributed

hendrikmakait (Member Author), Sep 20, 2024:

Just to make sure, what I'm hearing you say is: The computations people generally execute for calculating climatology, be it mean, std, quantile (or SEEP?) can be expressed in Xarray's high-level API and do not fundamentally require a custom function mapped over all blocks.

In that case, I fully agree that we should avoid map_blocks but express these calculations in Xarray. We should probably benchmark both a computation based on a decomposable aggregation like mean and one based on a holistic aggregation like quantile.
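
A rough sketch of the purely high-level variant being discussed (not part of this PR; it reuses `ds` from the benchmark, drops the 61-day weighted smoothing for brevity, and the function name is made up):

import xarray as xr

def compute_hourly_climatology_highlevel(ds: xr.Dataset) -> xr.Dataset:
    # Decomposable aggregation (mean) over hour of day and day of year,
    # expressed with groupby instead of rechunk + map_blocks.
    hours = xr.DataArray(range(0, 24, 6), dims=["hour"])
    return xr.concat(
        [
            ds.isel(time=ds.time.dt.hour == hour)
            .groupby("time.dayofyear")
            .mean("time")
            for hour in hours
        ],
        dim=hours,
    )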

Contributor:

> The computations people generally execute for calculating climatology,

Yes. generally. The trick is to look at the function being mapped: in this case it's all Xarray ops and no for loops. That indicates the map_blocks is a workaround for some Xarray and/or dask/distributed inefficiency.

> We should probably benchmark both a computation based on a decomposable aggregation like mean and one based on a holistic aggregation like quantile.

Yes totally agree with adding quantile. Dask will force a rechunk :)

PS: What is SEEP?

PPS: I notice this is based on a Xarray-beam workload, it may be that rechunk+map_blocks is the only way to express this in that framework.
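
As a concrete illustration of the quantile point (an assumed sketch, not code from this PR, reusing `ds` and the module's imports): quantile over a dask-backed dimension generally needs that dimension in a single chunk, so the holistic aggregation forces a rechunk along time.

# Assumed sketch: holistic aggregation (90th percentile) per day of year.
# `time` generally has to be a single chunk for quantile on dask-backed data,
# which is the rechunk referred to above.
ds_single_time_chunk = ds.chunk({"time": -1, "latitude": "auto", "longitude": "auto"})
clim_q90 = ds_single_time_chunk.groupby("time.dayofyear").quantile(0.9, dim="time")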

hendrikmakait (Member Author), Sep 20, 2024:

> Yes. generally. The trick is to look at the function being mapped: in this case it's all Xarray ops and no for loops. That indicates the map_blocks is a workaround for some Xarray and/or dask/distributed inefficiency.

Thanks for the clarification!

> PS: What is SEEP?

All I can offer is https://github.com/google-research/weatherbench2/blob/47d72575cf5e99383a09bed19ba989b718d5fe30/scripts/compute_climatology.py#L147-L175. Maybe @shoyer can elaborate on SEEP and how common similar calculations are? Just looking at the code, this can probably be expressed in Xarray's high-level API as well.

> PPS: I notice this is based on a Xarray-beam workload, it may be that rechunk+map_blocks is the only way to express this in that framework.

Yes, that's what it's based on, but not what we should be limited by.

Another commenter:

In my experience, less experienced users usually try without rechunk+map_blocks first, and then grudgingly resort to that. It'd be nice to model that interaction and improve things if possible.

Yes, rechunk+map_blocks is definitely not the obvious solution. It would be quite nice if we could automatically optimize high-level Xarray code to use a rechunk + map_blocks style implementation when it will be more efficient.

Member Author:

I agree that we won't beat hand-optimized code. However, a major goal of these benchmarks is to understand how we can improve the performance of code that relies on high-level APIs to improve the end-user experience. We all seem to agree that it would be great if users wouldn't have to worry about rewriting their high-level code in a rechunk+map_blocks-style code if it can be avoided. So, wherever possible, I think we should try and benchmark code without optimization and analyze where and how this falls apart.

Contributor:

I'll also add that we want rechunk+map_blocks to always succeed, even if slow, so that it can be a dependable fallback, particularly when the algorithms get more complicated. This would be a major improvement over the dask.array experience today.

Member Author:

@dcherian: Could you elaborate on that last comment? In what situations does that combination not succeed today?

I can imagine this failing if the block sizes or the temporary memory footprint of the mapped function end up being too large. However, this seems like a situation that is beyond the control of higher-level APIs.

Contributor:

It will fail whenever rechunking fails... agree that after that, the blockwise is pretty stable, which is why people use this solution in combination with some kind of batching to avoid the rechunking failure. I don't have an example at hand. This is anecdotal experience from looking at a lot of non-expert code at NCAR.
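
A rough sketch of the batching pattern described here (assumed, not from this PR; it reuses `pencil_chunks`, `template`, and `compute_hourly_climatology` from the diff above, and the band size is arbitrary): split the domain into latitude bands so each rechunk stays small, then concatenate the per-band results.

# Assumed sketch of the batching workaround: rechunk + map_blocks one latitude
# band at a time, so each individual rechunk is small enough to succeed.
bands = []
for start in range(0, ds.sizes["latitude"], 180):
    band = ds.isel(latitude=slice(start, start + 180)).chunk(pencil_chunks)
    band_template = template.isel(latitude=slice(start, start + 180))
    bands.append(band.map_blocks(compute_hourly_climatology, template=band_template))
working = xr.concat(bands, dim="latitude")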


    # Rechunk back to "pancake" chunks: one (hour, dayofyear) slice per chunk,
    # spatial chunking as in the source dataset.
    pancake_chunks = {
        "hour": 1,
        "dayofyear": 1,
        "latitude": ds.chunks["latitude"],
        "longitude": ds.chunks["longitude"],
    }
    result = working.chunk(pancake_chunks)
    result.to_zarr(gcs_url, storage_options={"token": CoiledShippedCredentials()})