diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 7e825fcce91..8aa299712bc 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -33,6 +33,6 @@ repos: rev: v1.8.0 hooks: - id: mypy - args: [--strict, --ignore-missing-imports, --show-error-codes] + args: [--python-version=3.10, --strict, --ignore-missing-imports, --show-error-codes] additional_dependencies: [einops>=0.6.0, kornia>=0.6.9, lightning>=2.0.9, matplotlib>=3.8.1, numpy>=1.22, pytest>=6.1.2, pyvista>=0.34.2, scikit-image>=0.18.0, torch>=2.2, torchmetrics>=0.10] exclude: (build|data|dist|logo|logs|output)/ diff --git a/pyproject.toml b/pyproject.toml index 350d9273bdc..547603eaad0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -87,6 +87,8 @@ datasets = [ "laspy>=2", # opencv-python 4.4.0.46+ required for Python 3.9 wheels "opencv-python>=4.4.0.46", + # OWSLib is required for WMS Dataset + "OWSLib>=0.30.0", # pycocotools 2.0.5+ required for cython 3+ support "pycocotools>=2.0.5", # pyvista 0.34.2+ required to avoid ImportError in CI diff --git a/requirements/datasets.txt b/requirements/datasets.txt index f183132801f..211be7599ec 100644 --- a/requirements/datasets.txt +++ b/requirements/datasets.txt @@ -2,6 +2,7 @@ h5py==3.10.0 laspy==2.5.3 opencv-python==4.9.0.80 +OWSLib==0.30.0 pycocotools==2.0.7 pyvista==0.43.4 radiant-mlhub==0.4.1 diff --git a/requirements/min-reqs.old b/requirements/min-reqs.old index 34b22649ceb..646525c1c20 100644 --- a/requirements/min-reqs.old +++ b/requirements/min-reqs.old @@ -25,6 +25,7 @@ torchvision==0.14.0 h5py==3.0.0 laspy==2.0.0 opencv-python==4.4.0.46 +OWSLib==0.21.0 pycocotools==2.0.5 pyvista==0.34.2 radiant-mlhub==0.3.0 diff --git a/tests/datasets/test_wms.py b/tests/datasets/test_wms.py new file mode 100644 index 00000000000..31c942fd78c --- /dev/null +++ b/tests/datasets/test_wms.py @@ -0,0 +1,68 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import pytest +import urllib3 + +from torchgeo.datasets import WMSDataset + +SERVICE_URL = "https://mesonet.agron.iastate.edu/cgi-bin/wms/nexrad/n0r-t.cgi?" + + +def service_ok(url: str, timeout: int = 5) -> bool: + try: + http = urllib3.PoolManager() + resp = http.request("HEAD", url, timeout=timeout) + ok = 200 == resp.status + except urllib3.exceptions.NewConnectionError: + ok = False + except Exception: + ok = False + return ok + + +class TestWMSDataset: + + @pytest.mark.slow + @pytest.mark.skipif( + not service_ok(SERVICE_URL), reason="WMS service is unreachable" + ) + def test_wms_no_layer(self) -> None: + """MESONET GetMap 1.1.1""" + wms = WMSDataset(SERVICE_URL, 10.0) + assert "nexrad_base_reflect" in wms.layers() + assert 4326 == wms.crs.to_epsg() + wms.layer("nexrad_base_reflect", crs=4326) + assert -126 == wms.index.bounds[0] + assert -66 == wms.index.bounds[1] + assert 24 == wms.index.bounds[2] + assert 50 == wms.index.bounds[3] + assert "image/png" == wms._format + + @pytest.mark.slow + @pytest.mark.skipif( + not service_ok(SERVICE_URL), reason="WMS service is unreachable" + ) + def test_wms_layer(self) -> None: + """MESONET GetMap 1.1.1""" + wms = WMSDataset(SERVICE_URL, 10.0, layer="nexrad_base_reflect", crs=4326) + assert 4326 == wms.crs.to_epsg() + assert -126 == wms.index.bounds[0] + assert -66 == wms.index.bounds[1] + assert 24 == wms.index.bounds[2] + assert 50 == wms.index.bounds[3] + assert "image/png" == wms._format + + @pytest.mark.slow + @pytest.mark.skipif( + not service_ok(SERVICE_URL), reason="WMS service is unreachable" + ) + def test_wms_layer_nocrs(self) -> None: + """MESONET GetMap 1.1.1""" + wms = WMSDataset(SERVICE_URL, 10.0, layer="nexrad_base_reflect") + assert 4326 == wms.crs.to_epsg() + assert -126 == wms.index.bounds[0] + assert -66 == wms.index.bounds[1] + assert 24 == wms.index.bounds[2] + assert 50 == wms.index.bounds[3] + assert "image/png" == wms._format diff --git a/torchgeo/datasets/__init__.py b/torchgeo/datasets/__init__.py index 5f3f974e2b2..fdb89bdd09f 100644 --- a/torchgeo/datasets/__init__.py +++ b/torchgeo/datasets/__init__.py @@ -135,6 +135,7 @@ from .vaihingen import Vaihingen2D from .vhr10 import VHR10 from .western_usa_live_fuel_moisture import WesternUSALiveFuelMoisture +from .wms import WMSDataset from .xview import XView2 from .zuericrop import ZueriCrop @@ -263,6 +264,7 @@ "RasterDataset", "UnionDataset", "VectorDataset", + "WMSDataset", # Utilities "BoundingBox", "concat_samples", diff --git a/torchgeo/datasets/wms.py b/torchgeo/datasets/wms.py new file mode 100644 index 00000000000..1ce46752ed3 --- /dev/null +++ b/torchgeo/datasets/wms.py @@ -0,0 +1,155 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""A simple class to fetch WMS images.""" +from collections.abc import Callable +from io import BytesIO +from typing import Any + +import torchvision.transforms as transforms +from PIL import Image +from rasterio.coords import BoundingBox +from rasterio.crs import CRS +from rasterio.errors import CRSError +from rtree.index import Index, Property + +from torchgeo.datasets import GeoDataset + + +class WMSDataset(GeoDataset): + """Allow models to fetch images from a WMS (at a good resolution).""" + + try: + from owslib.map.wms111 import ContentMetadata + from owslib.wms import WebMapService + except ImportError: + raise ImportError("OWSLib is not installed and is required to use this dataset") + _url: str = "" + _wms: WebMapService = None + + _layers: list[ContentMetadata] = [] + _layer: ContentMetadata = None + _layer_name: str = "" + is_image = True + + def __init__( + self, + url: str, + res: float, + layer: str | None = None, + transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, + crs: str | int | None = None, + ) -> None: + """Initialize the data set instance. + + Args: + url: String pointing to the WMS Server (no need for + request=getcapabilities). + res: The best resolution to make requests at + (should probably match your other datasets). + layer: An optional string with the name of the + layer you want to fetch. + transforms: a function/transform that takes an input sample + and returns a transformed version + crs: An optional string or integer code that is a + valid EPSG code (without the 'EPSG:') + + """ + try: + from owslib.wms import WebMapService + except ImportError: + raise ImportError( + "OWSLib is not installed and is required to use this dataset" + ) + super().__init__(transforms) + self._url = url + self._res = res + + if crs is not None: + self._crs = CRS.from_epsg(crs) + self._wms = WebMapService(url) + self._format = self._wms.getOperationByName("GetMap").formatOptions[0] + self._layers = list(self._wms.contents) + + if layer is not None and layer in self._layers: + self.layer(layer, crs) + + def layer(self, layer: str, crs: str | int | None = None) -> None: + """Set the layer to be fetched. + + Args: + layer: A string with the name of the layer you want to fetch. + crs: An optional string or integer code that is a valid EPSG + code (without the 'EPSG:') + """ + self._layer = self._wms[layer] + self._layer_name = layer + coords = self._wms[layer].boundingBox + self.index = Index(interleaved=False, properties=Property(dimension=3)) + self.index.insert( + 0, + ( + float(coords[0]), + float(coords[2]), + float(coords[1]), + float(coords[3]), + 0, + 9.223372036854776e18, + ), + ) + if crs is None: + i = 0 + while self._crs is None: + crs_str = sorted(self._layer.crsOptions)[i].upper() + if "EPSG:" in crs_str: + crs_str = crs_str[5:] + elif "CRS:84": + crs_str = "4326" + try: + self._crs = CRS.from_epsg(crs_str) + except CRSError: + pass + + def getlayer(self) -> ContentMetadata: + """Return the selected layer object.""" + return self._layer + + def layers(self) -> list[str]: + """Return a list of availiable layers.""" + return self._layers + + def __getitem__(self, query: BoundingBox) -> dict[str, Any]: + """Retrieve image/mask and metadata indexed by query. + + Args: + query: (minx, maxx, miny, maxy, mint, maxt) coordinates to index + + Returns: + sample of image/mask and metadata at that index + + Raises: + IndexError: if query is not found in the index + """ + img = self._wms.getmap( + layers=[self._layer_name], + srs="epsg:" + str(self.crs.to_epsg()), + bbox=(query.minx, query.miny, query.maxx, query.maxy), + # TODO fix size + size=(500, 500), + format=self._format, + transparent=True, + ) + sample = {"crs": self.crs, "bbox": query} + + transform = transforms.Compose([transforms.ToTensor()]) + # Convert the PIL image to Torch tensor + img_tensor = transform(Image.open(BytesIO(img.read()))) + if self.is_image: + sample["image"] = img_tensor + else: + sample["mask"] = img_tensor + + if self.transforms is not None: + sample = self.transforms(sample) + + return sample