diff --git a/tests/data/chesapeake/BAYWIDE/Baywide_13Class_20132014.tif b/tests/data/chesapeake/BAYWIDE/Baywide_13Class_20132014.tif
new file mode 100644
index 00000000000..042c7cb084e
Binary files /dev/null and b/tests/data/chesapeake/BAYWIDE/Baywide_13Class_20132014.tif differ
diff --git a/tests/data/chesapeake/BAYWIDE/Baywide_13Class_20132014.zip b/tests/data/chesapeake/BAYWIDE/Baywide_13Class_20132014.zip
new file mode 100644
index 00000000000..9a36194335b
Binary files /dev/null and b/tests/data/chesapeake/BAYWIDE/Baywide_13Class_20132014.zip differ
diff --git a/tests/datasets/test_chesapeake.py b/tests/datasets/test_chesapeake.py
new file mode 100644
index 00000000000..c89c4440807
--- /dev/null
+++ b/tests/datasets/test_chesapeake.py
@@ -0,0 +1,78 @@
+import os
+import shutil
+from pathlib import Path
+from typing import Generator
+
+import matplotlib.pyplot as plt
+import pytest
+import torch
+from _pytest.fixtures import SubRequest
+from pytest import MonkeyPatch
+from rasterio.crs import CRS
+
+import torchgeo.datasets.utils
+from torchgeo.datasets import BoundingBox, Chesapeake13, ZipDataset
+from torchgeo.transforms import Identity
+
+
+def download_url(url: str, root: str, *args: str) -> None:
+    shutil.copy(url, root)
+
+
+class TestChesapeake13:
+    @pytest.fixture
+    def dataset(
+        self,
+        monkeypatch: Generator[MonkeyPatch, None, None],
+        tmp_path: Path,
+        request: SubRequest,
+    ) -> Chesapeake13:
+        monkeypatch.setattr(  # type: ignore[attr-defined]
+            torchgeo.datasets.utils, "download_url", download_url
+        )
+        md5 = "8363639b51651cc1de2bdbeb2be4f9b1"
+        monkeypatch.setattr(Chesapeake13, "md5", md5)  # type: ignore[attr-defined]
+        url = os.path.join(
+            "tests", "data", "chesapeake", "BAYWIDE", "Baywide_13Class_20132014.zip"
+        )
+        monkeypatch.setattr(Chesapeake13, "url", url)  # type: ignore[attr-defined]
+        monkeypatch.setattr(  # type: ignore[attr-defined]
+            plt, "show", lambda *args: None
+        )
+        (tmp_path / "chesapeake" / "BAYWIDE").mkdir(parents=True)
+        root = str(tmp_path)
+        transforms = Identity()
+        return Chesapeake13(root, transforms=transforms, download=True, checksum=True)
+
+    def test_getitem(self, dataset: Chesapeake13) -> None:
+        x = dataset[dataset.bounds]
+        assert isinstance(x, dict)
+        assert isinstance(x["crs"], CRS)
+        assert isinstance(x["masks"], torch.Tensor)
+
+    def test_add(self, dataset: Chesapeake13) -> None:
+        ds = dataset + dataset
+        assert isinstance(ds, ZipDataset)
+
+    def test_already_downloaded(self, dataset: Chesapeake13) -> None:
+        Chesapeake13(
+            root=os.path.dirname(os.path.dirname(dataset.root)), download=True
+        )
+
+    def test_plot(self, dataset: Chesapeake13) -> None:
+        query = dataset.bounds
+        x = dataset[query]
+        dataset.plot(x["masks"])
+
+    def test_url(self) -> None:
+        ds = Chesapeake13(os.path.join("tests", "data"))
+        assert "cicwebresources.blob.core.windows.net" in ds.url
+
+    def test_not_downloaded(self, tmp_path: Path) -> None:
+        with pytest.raises(RuntimeError, match="Dataset not found or corrupted."):
+            Chesapeake13(str(tmp_path))
+
+    def test_invalid_query(self, dataset: Chesapeake13) -> None:
+        query = BoundingBox(0, 0, 0, 0, 0, 0)
+        with pytest.raises(
+            IndexError, match="query: .* is not within bounds of the index:"
+        ):
+            dataset[query]
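Reviewer note: the tests above always index the full `dataset.bounds`. For clarity, here is a minimal sketch (not part of the PR) of querying a spatial sub-window through the same `BoundingBox` API; it assumes the fixture data under `tests/data` is present on disk.

```python
import os

from torchgeo.datasets import BoundingBox, Chesapeake13

ds = Chesapeake13(os.path.join("tests", "data"))
b = ds.bounds
# BoundingBox is ordered (minx, maxx, miny, maxy, mint, maxt)
query = BoundingBox(
    b.minx, b.minx + (b.maxx - b.minx) / 2,  # left half in x
    b.miny, b.miny + (b.maxy - b.miny) / 2,  # bottom half in y
    b.mint, b.maxt,                          # full time extent
)
sample = ds[query]
print(sample["masks"].shape, sample["crs"])
```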
diff --git a/torchgeo/datasets/__init__.py b/torchgeo/datasets/__init__.py
index 69ad6e86ab3..fbc350480eb 100644
--- a/torchgeo/datasets/__init__.py
+++ b/torchgeo/datasets/__init__.py
@@ -2,6 +2,18 @@
 from .benin_cashews import BeninSmallHolderCashews
 from .cdl import CDL
+from .chesapeake import (
+    Chesapeake,
+    Chesapeake7,
+    Chesapeake13,
+    ChesapeakeDC,
+    ChesapeakeDE,
+    ChesapeakeMD,
+    ChesapeakeNY,
+    ChesapeakePA,
+    ChesapeakeVA,
+    ChesapeakeWV,
+)
 from .cowc import COWC, COWCCounting, COWCDetection
 from .cv4a_kenya_crop_type import CV4AKenyaCropType
 from .cyclone import TropicalCycloneWindEstimation
@@ -31,6 +43,16 @@
     "BoundingBox",
     "CDL",
     "collate_dict",
+    "Chesapeake",
+    "Chesapeake7",
+    "Chesapeake13",
+    "ChesapeakeDC",
+    "ChesapeakeDE",
+    "ChesapeakeMD",
+    "ChesapeakeNY",
+    "ChesapeakePA",
+    "ChesapeakeVA",
+    "ChesapeakeWV",
     "COWC",
     "COWCCounting",
     "COWCDetection",
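With the re-exports above, every Chesapeake variant becomes importable from the package root. A quick import check (illustration only, not part of the PR):

```python
# All variants share the same abstract base class
from torchgeo.datasets import Chesapeake13, ChesapeakeDE

print(Chesapeake13.__mro__[1].__name__)  # "Chesapeake"
```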
diff --git a/torchgeo/datasets/chesapeake.py b/torchgeo/datasets/chesapeake.py
new file mode 100644
index 00000000000..12e0ef5b11e
--- /dev/null
+++ b/torchgeo/datasets/chesapeake.py
@@ -0,0 +1,338 @@
+"""Chesapeake Bay High-Resolution Land Cover Project datasets."""
+
+import abc
+import os
+import sys
+from typing import Any, Callable, Dict, Optional
+
+import matplotlib.pyplot as plt
+import numpy as np
+import rasterio
+import torch
+from rasterio.crs import CRS
+from rasterio.vrt import WarpedVRT
+from rtree.index import Index, Property
+from torch import Tensor
+
+from .geo import GeoDataset
+from .utils import BoundingBox, check_integrity, download_and_extract_archive
+
+_crs = CRS.from_wkt(
+    """
+PROJCS["USA_Contiguous_Albers_Equal_Area_Conic_USGS_version",
+    GEOGCS["NAD83",
+        DATUM["North_American_Datum_1983",
+            SPHEROID["GRS 1980",6378137,298.257222101004,
+                AUTHORITY["EPSG","7019"]],
+            AUTHORITY["EPSG","6269"]],
+        PRIMEM["Greenwich",0],
+        UNIT["degree",0.0174532925199433,
+            AUTHORITY["EPSG","9122"]],
+        AUTHORITY["EPSG","4269"]],
+    PROJECTION["Albers_Conic_Equal_Area"],
+    PARAMETER["latitude_of_center",23],
+    PARAMETER["longitude_of_center",-96],
+    PARAMETER["standard_parallel_1",29.5],
+    PARAMETER["standard_parallel_2",45.5],
+    PARAMETER["false_easting",0],
+    PARAMETER["false_northing",0],
+    UNIT["metre",1,
+        AUTHORITY["EPSG","9001"]],
+    AXIS["Easting",EAST],
+    AXIS["Northing",NORTH]]
+"""
+)
+
+
+class Chesapeake(GeoDataset, abc.ABC):
+    """Abstract base class for all Chesapeake datasets.
+
+    `Chesapeake Bay High-Resolution Land Cover Project`_ dataset.
+
+    This dataset was collected by the Chesapeake Conservancy's Conservation
+    Innovation Center (CIC) in partnership with the University of Vermont and
+    WorldView Solutions, Inc. It consists of one-meter resolution land cover
+    information for the Chesapeake Bay watershed (~100,000 square miles of land).
+
+    For more information, see:
+
+    * `User Guide`_
+    * `Class Descriptions`_
+    * `Accuracy Assessment`_
+
+    If you use this dataset in your research, please cite the following paper:
+
+    * https://doi.org/10.1109/cvpr.2019.01301
+    """
+
+    @property
+    @abc.abstractmethod
+    def base_folder(self) -> str:
+        """Subdirectory to find/store dataset in."""
+
+    @property
+    @abc.abstractmethod
+    def filename(self) -> str:
+        """Filename to find/store dataset in."""
+
+    @property
+    @abc.abstractmethod
+    def zipfile(self) -> str:
+        """Name of zipfile in download URL."""
+
+    @property
+    @abc.abstractmethod
+    def md5(self) -> str:
+        """MD5 checksum to verify integrity of dataset."""
+
+    @property
+    def url(self) -> str:
+        """URL to download dataset from."""
+        url = "https://cicwebresources.blob.core.windows.net/chesapeakebaylandcover"
+        url += f"/{self.base_folder}/{self.zipfile}"
+        return url
+
+    def __init__(
+        self,
+        root: str,
+        crs: CRS = _crs,
+        transforms: Optional[Callable[[Any], Any]] = None,
+        download: bool = False,
+        checksum: bool = False,
+    ) -> None:
+        """Initialize a new Chesapeake dataset instance.
+
+        Args:
+            root: root directory where dataset can be found
+            crs: :term:`coordinate reference system (CRS)` to project to
+            transforms: a function/transform that takes an input sample
+                and returns a transformed version
+            download: if True, download dataset and store it in the root directory
+            checksum: if True, check the MD5 of the downloaded files (may be slow)
+        """
+        self.root = os.path.join(root, "chesapeake", self.base_folder)
+        self.crs = crs
+        self.transforms = transforms
+        self.checksum = checksum
+
+        if download:
+            self._download()
+
+        if not self._check_integrity():
+            raise RuntimeError(
+                "Dataset not found or corrupted. "
+                + "You can use download=True to download it"
+            )
+
+        # Create an R-tree to index the dataset
+        self.index = Index(interleaved=False, properties=Property(dimension=3))
+        filename = os.path.join(self.root, self.filename)
+        with rasterio.open(filename) as src:
+            cmap = src.colormap(1)
+            with WarpedVRT(src, crs=self.crs) as vrt:
+                minx, miny, maxx, maxy = vrt.bounds
+        mint = 0
+        maxt = sys.maxsize
+        coords = (minx, maxx, miny, maxy, mint, maxt)
+        self.index.insert(0, coords, filename)
+        self.cmap = np.array([cmap[i] for i in range(len(cmap))])
+
+    def __getitem__(self, query: BoundingBox) -> Dict[str, Any]:
+        """Retrieve labels and metadata indexed by query.
+
+        Args:
+            query: (minx, maxx, miny, maxy, mint, maxt) coordinates to index
+
+        Returns:
+            sample of labels and metadata at that index
+
+        Raises:
+            IndexError: if query is not within bounds of the index
+        """
+        if not query.intersects(self.bounds):
+            raise IndexError(
+                f"query: {query} is not within bounds of the index: {self.bounds}"
+            )
+
+        hits = self.index.intersection(query, objects=True)
+        filename = next(hits).object
+        with rasterio.open(filename) as src:
+            with WarpedVRT(src, crs=self.crs) as vrt:
+                window = rasterio.windows.from_bounds(
+                    query.minx,
+                    query.miny,
+                    query.maxx,
+                    query.maxy,
+                    transform=vrt.transform,
+                )
+                masks = vrt.read(window=window)
+        sample = {
+            "masks": torch.tensor(masks),  # type: ignore[attr-defined]
+            "crs": self.crs,
+            "bbox": query,
+        }
+
+        if self.transforms is not None:
+            sample = self.transforms(sample)
+
+        return sample
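The `interleaved=False` R-tree convention used in `__init__` and `__getitem__` is easy to misread: coordinates are grouped per axis, not point by point. A standalone sketch with toy values (illustration only):

```python
from rtree.index import Index, Property

# 3-D index: two spatial dimensions plus time
idx = Index(interleaved=False, properties=Property(dimension=3))
# interleaved=False means (minx, maxx, miny, maxy, mint, maxt),
# not the more common interleaved (minx, miny, maxx, maxy) ordering
idx.insert(0, (0.0, 10.0, 0.0, 10.0, 0, 1), "file.tif")
hits = idx.intersection((2.0, 3.0, 2.0, 3.0, 0, 1), objects=True)
print(next(hits).object)  # "file.tif"
```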
+
+    def _check_integrity(self) -> bool:
+        """Check integrity of dataset.
+
+        Returns:
+            True if dataset MD5s match, else False
+        """
+        integrity: bool = check_integrity(
+            os.path.join(self.root, self.zipfile),
+            self.md5 if self.checksum else None,
+        )
+        return integrity
+
+    def _download(self) -> None:
+        """Download the dataset and extract it."""
+        if self._check_integrity():
+            print("Files already downloaded and verified")
+            return
+
+        download_and_extract_archive(
+            self.url,
+            self.root,
+            filename=self.zipfile,
+            md5=self.md5,
+        )
+
+    def plot(self, image: Tensor) -> None:
+        """Plot an image on a map.
+
+        Args:
+            image: the image to plot
+        """
+        # Convert from class labels to RGBA values
+        array = image.squeeze().numpy()
+        array = self.cmap[array]
+
+        # Plot the image
+        ax = plt.axes()
+        ax.imshow(array, origin="lower")
+        ax.axis("off")
+        plt.show()
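The `self.cmap[array]` lookup in `plot` relies on NumPy fancy indexing: indexing an `(N, 4)` RGBA table with an `(H, W)` integer label array yields an `(H, W, 4)` image. A toy illustration (the real table comes from `src.colormap(1)`):

```python
import numpy as np

# Hypothetical 3-class RGBA table
cmap = np.array([[0, 0, 0, 255], [0, 0, 255, 255], [0, 255, 0, 255]])
labels = np.array([[0, 1], [2, 1]])  # (H, W) class labels
rgba = cmap[labels]                  # (H, W, 4) image ready for imshow
print(rgba.shape)  # (2, 2, 4)
```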
+
+
+class Chesapeake7(Chesapeake):
+    """Complete 7-class dataset.
+
+    This version of the dataset is composed of 7 classes:
+
+    0. No Data: Background values
+    1. Water: All areas of open water including ponds, rivers, and lakes
+    2. Tree Canopy and Shrubs: All woody vegetation including trees and shrubs
+    3. Low Vegetation: Plant material less than 2 meters in height including lawns
+    4. Barren: Areas devoid of vegetation consisting of natural earthen material
+    5. Impervious Surfaces: Human-constructed surfaces less than 2 meters in height
+    6. Impervious Roads: Impervious surfaces that are used for transportation
+    7. Aberdeen Proving Ground: U.S. Army facility with no labels
+    """
+
+    # TODO: make sure these class numbers are correct
+
+    base_folder = "BAYWIDE"
+    filename = "Baywide_7Class_20132014.tif"
+    zipfile = "Baywide_7Class_20132014.zip"
+    md5 = ""
+
+
+class Chesapeake13(Chesapeake):
+    """Complete 13-class dataset.
+
+    This version of the dataset is composed of 13 classes:
+
+    0. No Data: Background values
+    1. Water: All areas of open water including ponds, rivers, and lakes
+    2. Wetlands: Low vegetation areas located along marine or estuarine regions
+    3. Tree Canopy: Deciduous and evergreen woody vegetation over 3-5 meters in height
+    4. Shrubland: Heterogeneous woody vegetation including shrubs and young trees
+    5. Low Vegetation: Plant material less than 2 meters in height including lawns
+    6. Barren: Areas devoid of vegetation consisting of natural earthen material
+    7. Structures: Human-constructed objects made of impervious materials
+    8. Impervious Surfaces: Human-constructed surfaces less than 2 meters in height
+    9. Impervious Roads: Impervious surfaces that are used for transportation
+    10. Tree Canopy over Structures: Tree cover overlapping impervious structures
+    11. Tree Canopy over Impervious Surfaces: Tree cover overlapping impervious surfaces
+    12. Tree Canopy over Impervious Roads: Tree cover overlapping impervious roads
+    13. Aberdeen Proving Ground: U.S. Army facility with no labels
+    """
+
+    base_folder = "BAYWIDE"
+    filename = "Baywide_13Class_20132014.tif"
+    zipfile = "Baywide_13Class_20132014.zip"
+    md5 = "7e51118923c91e80e6e268156d25a4b9"
+
+
+class ChesapeakeDC(Chesapeake):
+    """This subset of the dataset contains data only for Washington, D.C."""
+
+    base_folder = "DC"
+    filename = os.path.join("DC_11001", "DC_11001.img")
+    zipfile = "DC_11001.zip"
+    md5 = "ed06ba7570d2955e8857d7d846c53b06"
+
+
+class ChesapeakeDE(Chesapeake):
+    """This subset of the dataset contains data only for Delaware."""
+
+    base_folder = "DE"
+    filename = "DE_STATEWIDE.tif"
+    zipfile = "_DE_STATEWIDE.zip"
+    md5 = "5e12eff3b6950c01092c7e480b38e544"
+
+
+class ChesapeakeMD(Chesapeake):
+    """This subset of the dataset contains data only for Maryland."""
+
+    base_folder = "MD"
+    filename = "MD_STATEWIDE.tif"
+    zipfile = "_MD_STATEWIDE.zip"
+    md5 = "40c7cd697a887f2ffdb601b5c114e567"
+
+
+class ChesapeakeNY(Chesapeake):
+    """This subset of the dataset contains data only for New York."""
+
+    base_folder = "NY"
+    filename = "NY_STATEWIDE.tif"
+    zipfile = "_NY_STATEWIDE.zip"
+    md5 = "1100078c526616454ef2e508affda915"
+
+
+class ChesapeakePA(Chesapeake):
+    """This subset of the dataset contains data only for Pennsylvania."""
+
+    base_folder = "PA"
+    filename = "PA_STATEWIDE.tif"
+    zipfile = "_PA_STATEWIDE.zip"
+    md5 = ""
+
+
+class ChesapeakeVA(Chesapeake):
+    """This subset of the dataset contains data only for Virginia."""
+
+    base_folder = "VA"
+    filename = "CIC2014_VA_STATEWIDE.tif"
+    zipfile = "_VA_STATEWIDE.zip"
+    md5 = "6f2c97deaf73bb3e1ea9b21bd7a3fc8e"
+
+
+class ChesapeakeWV(Chesapeake):
+    """This subset of the dataset contains data only for West Virginia."""
+
+    base_folder = "WV"
+    filename = "WV_STATEWIDE.tif"
+    zipfile = "_WV_STATEWIDE.zip"
+    md5 = "350621ea293651fbc557a1c3e3c64cc3"
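Because the abstract base class owns download, R-tree indexing, querying, and plotting, each subset above is nothing more than four class attributes. A hypothetical sketch (not part of this PR; the region name, filenames, and checksum below are invented) of what adding a new regional subset would look like:

```python
from torchgeo.datasets import Chesapeake


class ChesapeakeXX(Chesapeake):
    """Hypothetical subset for a new region (illustration only)."""

    base_folder = "XX"             # subdirectory under root/chesapeake
    filename = "XX_STATEWIDE.tif"  # raster inside the extracted archive
    zipfile = "_XX_STATEWIDE.zip"  # archive name appended to the download URL
    md5 = "00000000000000000000000000000000"  # placeholder checksum
```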