diff --git a/tests/data/chesapeake/BAYWIDE/Baywide_13Class_20132014.tif b/tests/data/chesapeake/BAYWIDE/Baywide_13Class_20132014.tif
new file mode 100644
index 00000000000..042c7cb084e
Binary files /dev/null and b/tests/data/chesapeake/BAYWIDE/Baywide_13Class_20132014.tif differ
diff --git a/tests/data/chesapeake/BAYWIDE/Baywide_13Class_20132014.zip b/tests/data/chesapeake/BAYWIDE/Baywide_13Class_20132014.zip
new file mode 100644
index 00000000000..9a36194335b
Binary files /dev/null and b/tests/data/chesapeake/BAYWIDE/Baywide_13Class_20132014.zip differ
diff --git a/tests/datasets/test_chesapeake.py b/tests/datasets/test_chesapeake.py
new file mode 100644
index 00000000000..c89c4440807
--- /dev/null
+++ b/tests/datasets/test_chesapeake.py
@@ -0,0 +1,78 @@
+import os
+import shutil
+from pathlib import Path
+
+import matplotlib.pyplot as plt
+import pytest
+import torch
+from _pytest.fixtures import SubRequest
+from pytest import MonkeyPatch
+from rasterio.crs import CRS
+
+import torchgeo.datasets.utils
+from torchgeo.datasets import BoundingBox, Chesapeake13, ZipDataset
+from torchgeo.transforms import Identity
+
+
+def download_url(url: str, root: str, *args: str) -> None:
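+    """Stand-in for torchgeo.datasets.utils.download_url that copies a local file."""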
+ shutil.copy(url, root)
+
+
+class TestChesapeake13:
+ @pytest.fixture
+ def dataset(
+ self,
+        monkeypatch: MonkeyPatch,
+ tmp_path: Path,
+ request: SubRequest,
+ ) -> Chesapeake13:
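+        # Redirect downloads to the local test zipfile so that no network
+        # access is required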
+        monkeypatch.setattr(torchgeo.datasets.utils, "download_url", download_url)
+ md5 = "8363639b51651cc1de2bdbeb2be4f9b1"
+        monkeypatch.setattr(Chesapeake13, "md5", md5)
+ url = os.path.join(
+ "tests", "data", "chesapeake", "BAYWIDE", "Baywide_13Class_20132014.zip"
+ )
+        monkeypatch.setattr(Chesapeake13, "url", url)
+        monkeypatch.setattr(plt, "show", lambda *args: None)
+ (tmp_path / "chesapeake" / "BAYWIDE").mkdir(parents=True)
+ root = str(tmp_path)
+ transforms = Identity()
+ return Chesapeake13(root, transforms=transforms, download=True, checksum=True)
+
+ def test_getitem(self, dataset: Chesapeake13) -> None:
+ x = dataset[dataset.bounds]
+ assert isinstance(x, dict)
+ assert isinstance(x["crs"], CRS)
+ assert isinstance(x["masks"], torch.Tensor)
+
+ def test_add(self, dataset: Chesapeake13) -> None:
+ ds = dataset + dataset
+ assert isinstance(ds, ZipDataset)
+
+ def test_already_downloaded(self, dataset: Chesapeake13) -> None:
+ Chesapeake13(root=os.path.dirname(os.path.dirname(dataset.root)), download=True)
+
+ def test_plot(self, dataset: Chesapeake13) -> None:
+ query = dataset.bounds
+ x = dataset[query]
+ dataset.plot(x["masks"])
+
+ def test_url(self) -> None:
+ ds = Chesapeake13(os.path.join("tests", "data"))
+ assert "cicwebresources.blob.core.windows.net" in ds.url
+
+ def test_not_downloaded(self, tmp_path: Path) -> None:
+ with pytest.raises(RuntimeError, match="Dataset not found or corrupted."):
+ Chesapeake13(str(tmp_path))
+
+ def test_invalid_query(self, dataset: Chesapeake13) -> None:
+ query = BoundingBox(0, 0, 0, 0, 0, 0)
+ with pytest.raises(
+ IndexError, match="query: .* is not within bounds of the index:"
+ ):
+ dataset[query]
diff --git a/torchgeo/datasets/__init__.py b/torchgeo/datasets/__init__.py
index 69ad6e86ab3..fbc350480eb 100644
--- a/torchgeo/datasets/__init__.py
+++ b/torchgeo/datasets/__init__.py
@@ -2,6 +2,18 @@
from .benin_cashews import BeninSmallHolderCashews
from .cdl import CDL
+from .chesapeake import (
+ Chesapeake,
+ Chesapeake7,
+ Chesapeake13,
+ ChesapeakeDC,
+ ChesapeakeDE,
+ ChesapeakeMD,
+ ChesapeakeNY,
+ ChesapeakePA,
+ ChesapeakeVA,
+ ChesapeakeWV,
+)
from .cowc import COWC, COWCCounting, COWCDetection
from .cv4a_kenya_crop_type import CV4AKenyaCropType
from .cyclone import TropicalCycloneWindEstimation
@@ -31,6 +43,16 @@
"BoundingBox",
"CDL",
"collate_dict",
+ "Chesapeake",
+ "Chesapeake7",
+ "Chesapeake13",
+ "ChesapeakeDC",
+ "ChesapeakeDE",
+ "ChesapeakeMD",
+ "ChesapeakeNY",
+ "ChesapeakePA",
+ "ChesapeakeVA",
+ "ChesapeakeWV",
"COWC",
"COWCCounting",
"COWCDetection",
diff --git a/torchgeo/datasets/chesapeake.py b/torchgeo/datasets/chesapeake.py
new file mode 100644
index 00000000000..12e0ef5b11e
--- /dev/null
+++ b/torchgeo/datasets/chesapeake.py
@@ -0,0 +1,338 @@
+"""Chesapeake Bay High-Resolution Land Cover Project dataset."""
+
+import abc
+import os
+import sys
+from typing import Any, Callable, Dict, Optional
+
+import matplotlib.pyplot as plt
+import numpy as np
+import rasterio
+import torch
+from rasterio.crs import CRS
+from rasterio.vrt import WarpedVRT
+from rtree.index import Index, Property
+from torch import Tensor
+
+from .geo import GeoDataset
+from .utils import BoundingBox, check_integrity, download_and_extract_archive
+
+_crs = CRS.from_wkt(
+ """
+PROJCS["USA_Contiguous_Albers_Equal_Area_Conic_USGS_version",
+ GEOGCS["NAD83",
+ DATUM["North_American_Datum_1983",
+ SPHEROID["GRS 1980",6378137,298.257222101004,
+ AUTHORITY["EPSG","7019"]],
+ AUTHORITY["EPSG","6269"]],
+ PRIMEM["Greenwich",0],
+ UNIT["degree",0.0174532925199433,
+ AUTHORITY["EPSG","9122"]],
+ AUTHORITY["EPSG","4269"]],
+ PROJECTION["Albers_Conic_Equal_Area"],
+ PARAMETER["latitude_of_center",23],
+ PARAMETER["longitude_of_center",-96],
+ PARAMETER["standard_parallel_1",29.5],
+ PARAMETER["standard_parallel_2",45.5],
+ PARAMETER["false_easting",0],
+ PARAMETER["false_northing",0],
+ UNIT["metre",1,
+ AUTHORITY["EPSG","9001"]],
+ AXIS["Easting",EAST],
+ AXIS["Northing",NORTH]]
+"""
+)
+
+
+class Chesapeake(GeoDataset, abc.ABC):
+ """Abstract base class for all Chesapeake datasets.
+
+ `Chesapeake Bay High-Resolution Land Cover Project
+    <https://www.chesapeakeconservancy.org/conservation-innovation-center/high-resolution-data/land-cover-data-project/>`_
+ dataset.
+
+ This dataset was collected by the Chesapeake Conservancy's Conservation Innovation
+ Center (CIC) in partnership with the University of Vermont and WorldView Solutions,
+ Inc. It consists of one-meter resolution land cover information for the Chesapeake
+ Bay watershed (~100,000 square miles of land).
+
+ For more information, see:
+
+ * `User Guide
+      <https://chesapeakeconservancy.org/wp-content/uploads/2017/01/LandCover101Guide.pdf>`_
+ * `Class Descriptions
+      <https://chesapeakeconservancy.org/wp-content/uploads/2020/03/LC_Class_Descriptions.pdf>`_
+ * `Accuracy Assessment
+      <https://chesapeakeconservancy.org/wp-content/uploads/2017/01/Chesapeake_Conservancy_Accuracy_Assessment_Somers_Final.pdf>`_
+
+ If you use this dataset in your research, please cite the following paper:
+
+ * https://doi.org/10.1109/cvpr.2019.01301
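+
+    A minimal usage sketch (the ``"data"`` root directory below is
+    hypothetical; any writable path works)::
+
+        from torchgeo.datasets import Chesapeake13
+
+        ds = Chesapeake13("data", download=True, checksum=True)
+        sample = ds[ds.bounds]  # dict with "masks", "crs", and "bbox" keys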
+ """
+
+ @property
+ @abc.abstractmethod
+ def base_folder(self) -> str:
+ """Subdirectory to find/store dataset in."""
+
+ @property
+ @abc.abstractmethod
+ def filename(self) -> str:
+ """Filename to find/store dataset in."""
+
+ @property
+ @abc.abstractmethod
+ def zipfile(self) -> str:
+ """Name of zipfile in download URL."""
+
+ @property
+ @abc.abstractmethod
+ def md5(self) -> str:
+ """MD5 checksum to verify integrity of dataset."""
+
+ @property
+ def url(self) -> str:
+ """URL to download dataset from."""
+ url = "https://cicwebresources.blob.core.windows.net/chesapeakebaylandcover"
+ url += f"/{self.base_folder}/{self.zipfile}"
+ return url
+
+ def __init__(
+ self,
+ root: str,
+ crs: CRS = _crs,
+ transforms: Optional[Callable[[Any], Any]] = None,
+ download: bool = False,
+ checksum: bool = False,
+ ) -> None:
+ """Initialize a new Chesapeake dataset instance.
+
+ Args:
+ root: root directory where dataset can be found
+ crs: :term:`coordinate reference system (CRS)` to project to
+ transforms: a function/transform that takes input sample and its target as
+ entry and returns a transformed version
+ download: if True, download dataset and store it in the root directory
+ checksum: if True, check the MD5 of the downloaded files (may be slow)
+ """
+ self.root = os.path.join(root, "chesapeake", self.base_folder)
+ self.crs = crs
+ self.transforms = transforms
+ self.checksum = checksum
+
+ if download:
+ self._download()
+
+ if not self._check_integrity():
+ raise RuntimeError(
+ "Dataset not found or corrupted. "
+ + "You can use download=True to download it"
+ )
+
+ # Create an R-tree to index the dataset
+ self.index = Index(interleaved=False, properties=Property(dimension=3))
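+        # interleaved=False means coordinates are grouped by dimension:
+        # (minx, maxx, miny, maxy, mint, maxt), matching ``coords`` below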
+ filename = os.path.join(self.root, self.filename)
+        with rasterio.open(filename) as src:
+            cmap = src.colormap(1)
+            with WarpedVRT(src, crs=self.crs) as vrt:
+                minx, miny, maxx, maxy = vrt.bounds
+ mint = 0
+ maxt = sys.maxsize
+ coords = (minx, maxx, miny, maxy, mint, maxt)
+ self.index.insert(0, coords, filename)
+ self.cmap = np.array([cmap[i] for i in range(len(cmap))])
+
+ def __getitem__(self, query: BoundingBox) -> Dict[str, Any]:
+ """Retrieve labels and metadata indexed by query.
+
+ Args:
+ query: (minx, maxx, miny, maxy, mint, maxt) coordinates to index
+
+ Returns:
+ sample of labels and metadata at that index
+
+ Raises:
+ IndexError: if query is not within bounds of the index
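+
+        Example (a sketch; the 256 m sub-window is hypothetical and must lie
+        within ``self.bounds``)::
+
+            bb = dataset.bounds
+            query = BoundingBox(
+                bb.minx, bb.minx + 256, bb.miny, bb.miny + 256, bb.mint, bb.maxt
+            )
+            sample = dataset[query]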
+ """
+ if not query.intersects(self.bounds):
+ raise IndexError(
+ f"query: {query} is not within bounds of the index: {self.bounds}"
+ )
+
+ hits = self.index.intersection(query, objects=True)
+ filename = next(hits).object
+ with rasterio.open(filename) as src:
+ with WarpedVRT(src, crs=self.crs) as vrt:
+ window = rasterio.windows.from_bounds(
+ query.minx,
+ query.miny,
+ query.maxx,
+ query.maxy,
+ transform=vrt.transform,
+ )
+ masks = vrt.read(window=window)
+ sample = {
+ "masks": torch.tensor(masks), # type: ignore[attr-defined]
+ "crs": self.crs,
+ "bbox": query,
+ }
+
+ if self.transforms is not None:
+ sample = self.transforms(sample)
+
+ return sample
+
+ def _check_integrity(self) -> bool:
+ """Check integrity of dataset.
+
+ Returns:
+ True if dataset MD5s match, else False
+ """
+ integrity: bool = check_integrity(
+ os.path.join(self.root, self.zipfile),
+ self.md5 if self.checksum else None,
+ )
+ return integrity
+
+ def _download(self) -> None:
+ """Download the dataset and extract it."""
+ if self._check_integrity():
+ print("Files already downloaded and verified")
+ return
+
+ download_and_extract_archive(
+ self.url,
+ self.root,
+ filename=self.zipfile,
+ md5=self.md5,
+ )
+
+ def plot(self, image: Tensor) -> None:
+ """Plot an image on a map.
+
+ Args:
+ image: the image to plot
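+
+        A typical call, mirroring this change's tests::
+
+            sample = dataset[dataset.bounds]
+            dataset.plot(sample["masks"])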
+ """
+ # Convert from class labels to RGBA values
+ array = image.squeeze().numpy()
+ array = self.cmap[array]
+
+ # Plot the image
+ ax = plt.axes()
+ ax.imshow(array, origin="lower")
+ ax.axis("off")
+ plt.show()
+
+
+class Chesapeake7(Chesapeake):
+ """Complete 7-class dataset.
+
+ This version of the dataset is composed of 7 classes:
+
+ 0. No Data: Background values
+ 1. Barren: Areas devoid of vegetation consisting of natural earthen material
+ 2. Impervious Roads: Impervious surfaces that are used for transportation
+ 3. Impervious Surfaces: Human-constructed surfaces less than 2 meters in height
+ 4. Low Vegetation: Plant material less than 2 meters in height including lawns
+ 5. Tree Canopy: Deciduous and evergreen woody vegetation over 3-5 meters in height
+ 6. Tree Canopy over Impervious Surfaces: Tree cover overlapping impervious surfaces
+ 7. Water: All areas of open water including ponds, rivers, and lakes
+ """
+
+ # TODO: make sure these class numbers are correct
+
+ base_folder = "BAYWIDE"
+ filename = "Baywide_7Class_20132014.tif"
+ zipfile = "Baywide_7Class_20132014.zip"
+ md5 = ""
+
+
+class Chesapeake13(Chesapeake):
+ """Complete 13-class dataset.
+
+ This version of the dataset is composed of 13 classes:
+
+ 0. No Data: Background values
+ 1. Water: All areas of open water including ponds, rivers, and lakes
+ 2. Wetlands: Low vegetation areas located along marine or estuarine regions
+ 3. Tree Canopy: Deciduous and evergreen woody vegetation over 3-5 meters in height
+ 4. Shrubland: Heterogeneous woody vegetation including shrubs and young trees
+ 5. Low Vegetation: Plant material less than 2 meters in height including lawns
+ 6. Barren: Areas devoid of vegetation consisting of natural earthen material
+ 7. Structures: Human-constructed objects made of impervious materials
+ 8. Impervious Surfaces: Human-constructed surfaces less than 2 meters in height
+ 9. Impervious Roads: Impervious surfaces that are used for transportation
+ 10. Tree Canopy over Structures: Tree cover overlapping impervious structures
+ 11. Tree Canopy over Impervious Surfaces: Tree cover overlapping impervious surfaces
+ 12. Tree Canopy over Impervious Roads: Tree cover overlapping impervious roads
+ 13. Aberdeen Proving Ground: U.S. Army facility with no labels
+ """
+
+ base_folder = "BAYWIDE"
+ filename = "Baywide_13Class_20132014.tif"
+ zipfile = "Baywide_13Class_20132014.zip"
+ md5 = "7e51118923c91e80e6e268156d25a4b9"
+
+
+class ChesapeakeDC(Chesapeake):
+ """This subset of the dataset contains data only for Washington, D.C."""
+
+ base_folder = "DC"
+ filename = os.path.join("DC_11001", "DC_11001.img")
+ zipfile = "DC_11001.zip"
+ md5 = "ed06ba7570d2955e8857d7d846c53b06"
+
+
+class ChesapeakeDE(Chesapeake):
+ """This subset of the dataset contains data only for Delaware."""
+
+ base_folder = "DE"
+ filename = "DE_STATEWIDE.tif"
+ zipfile = "_DE_STATEWIDE.zip"
+ md5 = "5e12eff3b6950c01092c7e480b38e544"
+
+
+class ChesapeakeMD(Chesapeake):
+ """This subset of the dataset contains data only for Maryland."""
+
+ base_folder = "MD"
+ filename = "MD_STATEWIDE.tif"
+ zipfile = "_MD_STATEWIDE.zip"
+ md5 = "40c7cd697a887f2ffdb601b5c114e567"
+
+
+class ChesapeakeNY(Chesapeake):
+ """This subset of the dataset contains data only for New York."""
+
+ base_folder = "NY"
+ filename = "NY_STATEWIDE.tif"
+ zipfile = "_NY_STATEWIDE.zip"
+ md5 = "1100078c526616454ef2e508affda915"
+
+
+class ChesapeakePA(Chesapeake):
+ """This subset of the dataset contains data only for Pennsylvania."""
+
+ base_folder = "PA"
+ filename = "PA_STATEWIDE.tif"
+ zipfile = "_PA_STATEWIDE.zip"
+ md5 = ""
+
+
+class ChesapeakeVA(Chesapeake):
+ """This subset of the dataset contains data only for Virginia."""
+
+ base_folder = "VA"
+ filename = "CIC2014_VA_STATEWIDE.tif"
+ zipfile = "_VA_STATEWIDE.zip"
+ md5 = "6f2c97deaf73bb3e1ea9b21bd7a3fc8e"
+
+
+class ChesapeakeWV(Chesapeake):
+ """This subset of the dataset contains data only for West Virginia."""
+
+ base_folder = "WV"
+ filename = "WV_STATEWIDE.tif"
+ zipfile = "_WV_STATEWIDE.zip"
+ md5 = "350621ea293651fbc557a1c3e3c64cc3"