diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst index 362f3b2a375..6bf222dcdc6 100644 --- a/docs/api/datasets.rst +++ b/docs/api/datasets.rst @@ -124,6 +124,11 @@ NAIP .. autoclass:: NAIP +NLCD +^^^^ + +.. autoclass:: NLCD + Open Buildings ^^^^^^^^^^^^^^ diff --git a/docs/api/geo_datasets.csv b/docs/api/geo_datasets.csv index 3c5a8e75e41..4c59b922f38 100644 --- a/docs/api/geo_datasets.csv +++ b/docs/api/geo_datasets.csv @@ -16,5 +16,6 @@ Dataset,Type,Source,Size (px),Resolution (m) `Landsat`_,Imagery,Landsat,"8,900x8,900",30 `L8 Biome`_,"Imagery, Masks",Landsat,"8,900x8,900","15, 30" `NAIP`_,Imagery,Aerial,"6,100x7,600",1 +`NLCD`_,Masks,Landsat,-,30 `Open Buildings`_,Geometries,"Maxar, CNES/Airbus",-,- `Sentinel`_,Imagery,Sentinel,"10,000x10,000",10 diff --git a/tests/data/nlcd/data.py b/tests/data/nlcd/data.py new file mode 100644 index 00000000000..fa1c592ea29 --- /dev/null +++ b/tests/data/nlcd/data.py @@ -0,0 +1,87 @@ +#!/usr/bin/env python3 + +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
import hashlib
import os
import shutil

import numpy as np
import rasterio
from rasterio.crs import CRS
from rasterio.transform import Affine

# Height and width (in pixels) of every generated raster.
SIZE = 32

# Fixed seed so regenerated fixtures (and therefore their MD5s) are reproducible.
np.random.seed(0)

# Template shared by the directory, archive and raster names; ``{}`` is the year.
# (Named ``dirname_template`` rather than ``dir`` to avoid shadowing the builtin.)
dirname_template = "nlcd_{}_land_cover_l48_20210604"

years = [2011, 2019]

wkt = """
PROJCS["Albers Conical Equal Area",
    GEOGCS["WGS 84",
        DATUM["WGS_1984",
            SPHEROID["WGS 84",6378137,298.257223563,
                AUTHORITY["EPSG","7030"]],
            AUTHORITY["EPSG","6326"]],
        PRIMEM["Greenwich",0,AUTHORITY["EPSG","8901"]],
        UNIT["degree",0.0174532925199433,AUTHORITY["EPSG","9122"]],
        AUTHORITY["EPSG","4326"]],
    PROJECTION["Albers_Conic_Equal_Area"],
    PARAMETER["latitude_of_center",23],
    PARAMETER["longitude_of_center",-96],
    PARAMETER["standard_parallel_1",29.5],
    PARAMETER["standard_parallel_2",45.5],
    PARAMETER["false_easting",0],
    PARAMETER["false_northing",0],
    UNIT["meters",1],
    AXIS["Easting",EAST],
    AXIS["Northing",NORTH]]
"""


def create_file(path: str, dtype: str) -> None:
    """Create a single-band testing raster filled with random NLCD class codes.

    Args:
        path: where to write the raster
        dtype: rasterio/numpy data type name used for the band
    """
    profile = {
        "driver": "GTiff",
        "dtype": dtype,
        "count": 1,
        "crs": CRS.from_wkt(wkt),
        "transform": Affine(30.0, 0.0, -2493045.0, 0.0, -30.0, 3310005.0),
        "height": SIZE,
        "width": SIZE,
        "compress": "lzw",
        "predictor": 2,
    }

    allowed_values = [0, 11, 12, 21, 22, 23, 24, 31, 41, 42, 43, 52, 71, 81, 82, 90, 95]

    # np.random.choice yields a platform-default integer array; cast explicitly so
    # the array written always matches the dtype declared in the profile.
    Z = np.random.choice(allowed_values, size=(SIZE, SIZE)).astype(dtype)

    with rasterio.open(path, "w", **profile) as dst:
        dst.write(Z, 1)


def main() -> None:
    """Regenerate the per-year fixture directories and archives, printing MD5s."""
    for year in years:
        year_dir = dirname_template.format(year)

        # Remove old data
        if os.path.isdir(year_dir):
            shutil.rmtree(year_dir)

        # Relative paths already resolve against the current working directory.
        os.makedirs(year_dir)

        zip_filename = year_dir + ".zip"
        filename = year_dir + ".img"
        create_file(os.path.join(year_dir, filename), dtype="int8")

        # Compress data
        shutil.make_archive(year_dir, "zip", ".", year_dir)

        # Compute checksums
        with open(zip_filename, "rb") as f:
            md5 = hashlib.md5(f.read()).hexdigest()
            print(f"{zip_filename}: {md5}")


if __name__ == "__main__":
    main()


# ---------------------------------------------------------------------------
# Patch metadata that followed this script in the original diff (kept verbatim):
#
# diff --git a/tests/data/nlcd/nlcd_2011_land_cover_l48_20210604.zip b/tests/data/nlcd/nlcd_2011_land_cover_l48_20210604.zip
# new file mode 100644
# index 00000000000..6e5ad22538f
# Binary files /dev/null and b/tests/data/nlcd/nlcd_2011_land_cover_l48_20210604.zip differ
# diff --git a/tests/data/nlcd/nlcd_2011_land_cover_l48_20210604/nlcd_2011_land_cover_l48_20210604.img b/tests/data/nlcd/nlcd_2011_land_cover_l48_20210604/nlcd_2011_land_cover_l48_20210604.img
# new file mode 100644
# index 00000000000..aa0e0f7df1a
# Binary files /dev/null and b/tests/data/nlcd/nlcd_2011_land_cover_l48_20210604/nlcd_2011_land_cover_l48_20210604.img differ
# diff --git a/tests/data/nlcd/nlcd_2019_land_cover_l48_20210604.zip b/tests/data/nlcd/nlcd_2019_land_cover_l48_20210604.zip
# new file mode 100644
# index 00000000000..0021cc86961
# Binary files /dev/null and b/tests/data/nlcd/nlcd_2019_land_cover_l48_20210604.zip differ
# diff --git a/tests/data/nlcd/nlcd_2019_land_cover_l48_20210604/nlcd_2019_land_cover_l48_20210604.img b/tests/data/nlcd/nlcd_2019_land_cover_l48_20210604/nlcd_2019_land_cover_l48_20210604.img
# new file mode 100644
# index 00000000000..231f010d997
# Binary files /dev/null and b/tests/data/nlcd/nlcd_2019_land_cover_l48_20210604/nlcd_2019_land_cover_l48_20210604.img differ
# diff --git a/tests/datasets/test_nlcd.py b/tests/datasets/test_nlcd.py
# new file mode 100644
# index 00000000000..e891eeca9f6
# --- /dev/null
# +++ b/tests/datasets/test_nlcd.py
# @@ -0,0 +1,103 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import os
import shutil
from pathlib import Path

import matplotlib.pyplot as plt
import pytest
import torch
import torch.nn as nn
from _pytest.monkeypatch import MonkeyPatch
from rasterio.crs import CRS

import torchgeo.datasets.utils
from torchgeo.datasets import NLCD, BoundingBox, IntersectionDataset, UnionDataset


def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
    """Offline replacement for the downloader: copy the local fixture into root."""
    shutil.copy(url, root)


class TestNLCD:
    """Tests for the NLCD dataset, run against the small local fixtures."""

    @pytest.fixture
    def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> NLCD:
        """Build an NLCD dataset whose download step copies the test archives."""
        # Never touch the network: route downloads to the local copy helper.
        monkeypatch.setattr(torchgeo.datasets.nlcd, "download_url", download_url)

        # Checksums of the fixture archives, not of the real product.
        monkeypatch.setattr(
            NLCD,
            "md5s",
            {
                2011: "99546a3b89a0dddbe4e28e661c79984e",
                2019: "a4008746f15720b8908ddd357a75fded",
            },
        )
        monkeypatch.setattr(
            NLCD,
            "url",
            os.path.join(
                "tests", "data", "nlcd", "nlcd_{}_land_cover_l48_20210604.zip"
            ),
        )

        # Keep plotting tests from opening a window.
        monkeypatch.setattr(plt, "show", lambda *args: None)

        return NLCD(
            str(tmp_path),
            transforms=nn.Identity(),
            download=True,
            checksum=True,
            years=[2011, 2019],
        )

    def test_getitem(self, dataset: NLCD) -> None:
        sample = dataset[dataset.bounds]
        assert isinstance(sample, dict)
        assert isinstance(sample["crs"], CRS)
        assert isinstance(sample["mask"], torch.Tensor)

    def test_and(self, dataset: NLCD) -> None:
        intersection = dataset & dataset
        assert isinstance(intersection, IntersectionDataset)

    def test_or(self, dataset: NLCD) -> None:
        union = dataset | dataset
        assert isinstance(union, UnionDataset)

    def test_already_extracted(self, dataset: NLCD) -> None:
        # The fixture has already extracted the data into ``dataset.root``.
        NLCD(root=dataset.root, download=True, years=[2019])

    def test_already_downloaded(self, tmp_path: Path) -> None:
        archive = os.path.join(
            "tests", "data", "nlcd", "nlcd_2019_land_cover_l48_20210604.zip"
        )
        destination = str(tmp_path)
        shutil.copy(archive, destination)
        NLCD(destination, years=[2019])

    def test_invalid_year(self, tmp_path: Path) -> None:
        expected = "NLCD data product only exists for the following years:"
        with pytest.raises(AssertionError, match=expected):
            NLCD(str(tmp_path), years=[1996])

    def test_plot(self, dataset: NLCD) -> None:
        sample = dataset[dataset.bounds]
        dataset.plot(sample, suptitle="Test")
        plt.close()

    def test_plot_prediction(self, dataset: NLCD) -> None:
        sample = dataset[dataset.bounds]
        sample["prediction"] = sample["mask"].clone()
        dataset.plot(sample, suptitle="Prediction")
        plt.close()

    def test_not_downloaded(self, tmp_path: Path) -> None:
        with pytest.raises(RuntimeError, match="Dataset not found"):
            NLCD(str(tmp_path))

    def test_invalid_query(self, dataset: NLCD) -> None:
        empty_box = BoundingBox(0, 0, 0, 0, 0, 0)
        expected = "query: .* not found in index with bounds:"
        with pytest.raises(IndexError, match=expected):
            dataset[empty_box]


# ---------------------------------------------------------------------------
# Patch metadata that followed this test module in the original diff (kept verbatim):
#
# diff --git a/torchgeo/datasets/__init__.py b/torchgeo/datasets/__init__.py
# index 2eff64d8696..c626f8790ee 100644
# --- a/torchgeo/datasets/__init__.py
# +++ b/torchgeo/datasets/__init__.py
# @@ -73,6 +73,7 @@
#  from .millionaid import MillionAID
#  from .naip import NAIP
#  from .nasa_marine_debris import NASAMarineDebris
# +from .nlcd import NLCD
#  from .openbuildings import OpenBuildings
#  from .oscd import OSCD
#  from .patternnet import PatternNet
# @@ -153,6 +154,7 @@
#  "Landsat8",
#  "Landsat9",
#  "NAIP",
# +"NLCD",
#  "OpenBuildings",
#  "Sentinel",
#  "Sentinel1",
# diff --git a/torchgeo/datasets/nlcd.py b/torchgeo/datasets/nlcd.py
# new file mode 100644
# index 00000000000..8926cd9f885
# --- /dev/null
# +++ b/torchgeo/datasets/nlcd.py
# @@ -0,0 +1,302 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
"""NLCD dataset."""

import os
from typing import Any, Callable, Optional

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import ListedColormap
from rasterio.crs import CRS

from .geo import RasterDataset
from .utils import BoundingBox, download_url, extract_archive


class NLCD(RasterDataset):
    """National Land Cover Database (NLCD) dataset.

    The `NLCD dataset
    <https://www.usgs.gov/centers/eros/science/national-land-cover-database>`_
    is a land cover product that covers the United States and Puerto Rico. The current
    implementation supports maps for the continental United States only. The product is
    a joint effort between the United States Geological Survey
    (`USGS <https://www.usgs.gov/>`_) and the Multi-Resolution Land Characteristics
    Consortium (`MRLC <https://www.mrlc.gov/>`_) which released the first product
    in 2001 with new updates every five years since then.

    The dataset contains the following 17 classes:

    0. Background
    #. Open Water
    #. Perennial Ice/Snow
    #. Developed, Open Space
    #. Developed, Low Intensity
    #. Developed, Medium Intensity
    #. Developed, High Intensity
    #. Barren Land (Rock/Sand/Clay)
    #. Deciduous Forest
    #. Evergreen Forest
    #. Mixed Forest
    #. Shrub/Scrub
    #. Grassland/Herbaceous
    #. Pasture/Hay
    #. Cultivated Crops
    #. Woody Wetlands
    #. Emergent Herbaceous Wetlands

    Detailed descriptions of the classes can be found
    `here <https://www.mrlc.gov/data/legends/national-land-cover-database-class-legend-and-description>`__.

    Dataset format:

    * single channel .img file with integer class labels

    If you use this dataset in your research, please use the corresponding citation:

    * 2001: https://doi.org/10.5066/P9MZGHLF
    * 2006: https://doi.org/10.5066/P9HBR9V3
    * 2011: https://doi.org/10.5066/P97S2IID
    * 2016: https://doi.org/10.5066/P96HHBIE
    * 2019: https://doi.org/10.5066/P9KZCM54

    .. versionadded:: 0.5
    """  # noqa: E501

    filename_glob = "nlcd_*_land_cover_l48_20210604.img"
    # The named groups are required for the pattern to compile at all; ``date`` is
    # the 4-digit product year, matching ``date_format`` below.
    filename_regex = (
        r"nlcd_(?P<date>\d{4})_land_cover_l48_(?P<publication_date>\d{8})\.img"
    )
    zipfile_glob = "nlcd_*_land_cover_l48_20210604.zip"
    date_format = "%Y"
    is_image = False

    # ``{}`` is replaced by the product year.
    url = "https://s3-us-west-2.amazonaws.com/mrlc/nlcd_{}_land_cover_l48_20210604.zip"

    # MD5 checksum of each year's archive; the keys double as the supported years.
    md5s = {
        2001: "538166a4d783204764e3df3b221fc4cd",
        2006: "67454e7874a00294adb9442374d0c309",
        2011: "ea524c835d173658eeb6fa3c8e6b917b",
        2016: "452726f6e3bd3f70d8ca2476723d238a",
        2019: "82851c3f8105763b01c83b4a9e6f3961",
    }

    # Raw NLCD class code -> contiguous ordinal label in [0, 16].
    ordinal_label_map = {
        0: 0,
        11: 1,
        12: 2,
        21: 3,
        22: 4,
        23: 5,
        24: 6,
        31: 7,
        41: 8,
        42: 9,
        43: 10,
        52: 11,
        71: 12,
        81: 13,
        82: 14,
        90: 15,
        95: 16,
    }

    # Ordinal label -> RGBA colour (0-255) used by :meth:`plot`.
    cmap = {
        0: (0, 0, 0, 255),
        1: (70, 107, 159, 255),
        2: (209, 222, 248, 255),
        3: (222, 197, 197, 255),
        4: (217, 146, 130, 255),
        5: (235, 0, 0, 255),
        6: (171, 0, 0, 255),
        7: (179, 172, 159, 255),
        8: (104, 171, 95, 255),
        9: (28, 95, 44, 255),
        10: (181, 197, 143, 255),
        11: (204, 184, 121, 255),
        12: (223, 223, 194, 255),
        13: (220, 217, 57, 255),
        14: (171, 108, 40, 255),
        15: (184, 217, 235, 255),
        16: (108, 159, 184, 255),
    }

    def __init__(
        self,
        root: str = "data",
        crs: Optional[CRS] = None,
        res: Optional[float] = None,
        years: list[int] = [2019],
        transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None,
        cache: bool = True,
        download: bool = False,
        checksum: bool = False,
    ) -> None:
        """Initialize a new Dataset instance.

        Args:
            root: root directory where dataset can be found
            crs: :term:`coordinate reference system (CRS)` to warp to
                (defaults to the CRS of the first file found)
            res: resolution of the dataset in units of CRS
                (defaults to the resolution of the first file found)
            years: list of years for which to use nlcd layer
            transforms: a function/transform that takes an input sample
                and returns a transformed version
            cache: if True, cache file handle to speed up repeated sampling
            download: if True, download dataset and store it in the root directory
            checksum: if True, check the MD5 after downloading files (may be slow)

        Raises:
            FileNotFoundError: if no files are found in ``root``
            RuntimeError: if ``download=False`` but dataset is missing or checksum fails
            AssertionError: if ``years`` contains an unsupported year
        """
        # Raised explicitly (rather than via ``assert``) so the validation is not
        # stripped under ``python -O``; the exception type and message are unchanged.
        if not set(years).issubset(self.md5s.keys()):
            raise AssertionError(
                "NLCD data product only exists for the following years: "
                f"{list(self.md5s.keys())}."
            )

        # Copy so the shared default list (and the caller's list) is never aliased.
        self.years = list(years)
        self.root = root
        self.download = download
        self.checksum = checksum

        self._verify()

        super().__init__(root, crs, res, transforms=transforms, cache=cache)

    def __getitem__(self, query: BoundingBox) -> dict[str, Any]:
        """Retrieve mask and metadata indexed by query.

        Args:
            query: (minx, maxx, miny, maxy, mint, maxt) coordinates to index

        Returns:
            sample of mask and metadata at that index, with raw NLCD class codes
            replaced by their ordinal labels

        Raises:
            IndexError: if query is not found in the index
        """
        sample = super().__getitem__(query)

        raw = sample["mask"]
        # Compare against the untouched raw codes and write into a copy, so the
        # result does not depend on the iteration order of the mapping (ordinal
        # labels 11 and 12 are also valid raw codes).
        remapped = raw.clone()
        for code, label in self.ordinal_label_map.items():
            remapped[raw == code] = label

        sample["mask"] = remapped

        return sample

    def _verify(self) -> None:
        """Verify the integrity of the dataset.

        Raises:
            RuntimeError: if ``download=False`` but dataset is missing or checksum fails
        """
        # Check if the extracted files already exist
        extracted = []
        for year in self.years:
            filename_year = self.filename_glob.replace("*", str(year))
            dirname_year = filename_year.split(".")[0]
            pathname = os.path.join(self.root, dirname_year, filename_year)
            extracted.append(os.path.exists(pathname))

        if all(extracted):
            return

        # Check if the zip files have already been downloaded
        archived = [
            os.path.exists(
                os.path.join(self.root, self.zipfile_glob.replace("*", str(year)))
            )
            for year in self.years
        ]

        if all(archived):
            # ``_extract`` handles every requested year, so call it once and only
            # when every archive is present; calling it per archive found would
            # repeat the work and fail on the archives that are still missing.
            self._extract()
            return

        # Check if the user requested to download the dataset
        if not self.download:
            raise RuntimeError(
                f"Dataset not found in `root={self.root}` and `download=False`, "
                "either specify a different `root` directory or use `download=True` "
                "to automatically download the dataset."
            )

        # Download the dataset
        self._download()
        self._extract()

    def _download(self) -> None:
        """Download the archive of every requested year into ``root``."""
        for year in self.years:
            download_url(
                self.url.format(year),
                self.root,
                md5=self.md5s[year] if self.checksum else None,
            )

    def _extract(self) -> None:
        """Extract the archive of every requested year into ``root``."""
        for year in self.years:
            zipfile_name = self.zipfile_glob.replace("*", str(year))
            pathname = os.path.join(self.root, zipfile_name)
            extract_archive(pathname, self.root)

    def plot(
        self,
        sample: dict[str, Any],
        show_titles: bool = True,
        suptitle: Optional[str] = None,
    ) -> plt.Figure:
        """Plot a sample from the dataset.

        Args:
            sample: a sample returned by :meth:`RasterDataset.__getitem__`
            show_titles: flag indicating whether to show titles above each panel
            suptitle: optional string to use as a suptitle

        Returns:
            a matplotlib Figure with the rendered sample
        """
        mask = sample["mask"].squeeze().numpy()
        ncols = 1

        # Order colours by ordinal label so colour i always belongs to class i.
        plt_cmap = ListedColormap(
            np.stack(
                [np.array(self.cmap[label]) / 255 for label in sorted(self.cmap)],
                axis=0,
            )
        )
        # Pin the colour range to the full label range. Without vmin/vmax, imshow
        # rescales to the values present in each array, so a tile lacking class 0
        # or class 16 would be drawn with the wrong colours.
        imshow_kwargs = {
            "cmap": plt_cmap,
            "vmin": 0,
            "vmax": plt_cmap.N - 1,
            "interpolation": "none",
        }

        showing_predictions = "prediction" in sample
        if showing_predictions:
            pred = sample["prediction"].squeeze().numpy()
            ncols = 2

        fig, axs = plt.subplots(
            nrows=1, ncols=ncols, figsize=(ncols * 4, 4), squeeze=False
        )

        axs[0, 0].imshow(mask, **imshow_kwargs)
        axs[0, 0].axis("off")

        if show_titles:
            axs[0, 0].set_title("Mask")

        if showing_predictions:
            axs[0, 1].imshow(pred, **imshow_kwargs)
            axs[0, 1].axis("off")
            if show_titles:
                axs[0, 1].set_title("Prediction")

        if suptitle is not None:
            # Title the figure created here rather than pyplot's "current" figure.
            fig.suptitle(suptitle)

        return fig