Skip to content

Commit

Permalink
align cultivated band name to official
Browse files Browse the repository at this point in the history
  • Loading branch information
Emma Ai committed Nov 6, 2024
1 parent e176cd8 commit 729e294
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 5 deletions.
4 changes: 2 additions & 2 deletions odc/stats/plugins/lc_treelite_cultivated.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ class StatsCultivatedClass(StatsMLTree):

@property
def measurements(self) -> Tuple[str, ...]:
_measurements = ["cultivated_class"]
_measurements = ["cultivated"]
return _measurements

def predict(self, input_array):
Expand Down Expand Up @@ -304,7 +304,7 @@ def reduce(self, xx: xr.Dataset) -> xr.Dataset:
attrs = res[var].attrs.copy()
attrs["nodata"] = int(NODATA)
res[var].attrs = attrs
var_rename = {var: "cultivated_class"}
var_rename = dict(zip(res.data_vars, self.measurements))
return res.rename(var_rename)


Expand Down
6 changes: 3 additions & 3 deletions tests/test_rf_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,10 +485,10 @@ def test_cultivated_reduce(
)
dask_client.register_plugin(cultivated.dask_worker_plugin)
res = cultivated.reduce(input_datasets)
assert res["cultivated_class"].attrs["nodata"] == 255
assert res["cultivated_class"].data.dtype == "uint8"
assert res["cultivated"].attrs["nodata"] == 255
assert res["cultivated"].data.dtype == "uint8"
assert (
res["cultivated_class"].data.compute()
res["cultivated"].data.compute()
== np.array([[112, 255], [112, 112]], dtype="uint8")
).all()

Expand Down

0 comments on commit 729e294

Please sign in to comment.