Add an initial version of WMSDataset #1965

Open · wants to merge 14 commits into base: main
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -33,6 +33,6 @@ repos:
rev: v1.8.0
hooks:
- id: mypy
-        args: [--strict, --ignore-missing-imports, --show-error-codes]
+        args: [--python-version=3.10, --strict, --ignore-missing-imports, --show-error-codes]
additional_dependencies: [einops>=0.6.0, kornia>=0.6.9, lightning>=2.0.9, matplotlib>=3.8.1, numpy>=1.22, pytest>=6.1.2, pyvista>=0.34.2, scikit-image>=0.18.0, torch>=2.2, torchmetrics>=0.10]
exclude: (build|data|dist|logo|logs|output)/
2 changes: 2 additions & 0 deletions pyproject.toml
@@ -87,6 +87,8 @@ datasets = [
"laspy>=2",
# opencv-python 4.4.0.46+ required for Python 3.9 wheels
"opencv-python>=4.4.0.46",
# OWSLib is required for WMS Dataset
"OWSLib>=0.30.0",
# pycocotools 2.0.5+ required for cython 3+ support
"pycocotools>=2.0.5",
# pyvista 0.34.2+ required to avoid ImportError in CI
1 change: 1 addition & 0 deletions requirements/datasets.txt
@@ -2,6 +2,7 @@
h5py==3.10.0
laspy==2.5.3
opencv-python==4.9.0.80
OWSLib==0.30.0
pycocotools==2.0.7
pyvista==0.43.4
radiant-mlhub==0.4.1
1 change: 1 addition & 0 deletions requirements/min-reqs.old
@@ -25,6 +25,7 @@ torchvision==0.14.0
h5py==3.0.0
laspy==2.0.0
opencv-python==4.4.0.46
OWSLib==0.21.0
pycocotools==2.0.5
pyvista==0.34.2
radiant-mlhub==0.3.0
68 changes: 68 additions & 0 deletions tests/datasets/test_wms.py
@@ -0,0 +1,68 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import pytest
import urllib3

from torchgeo.datasets import WMSDataset

SERVICE_URL = "https://mesonet.agron.iastate.edu/cgi-bin/wms/nexrad/n0r-t.cgi?"


def service_ok(url: str, timeout: int = 5) -> bool:
try:
http = urllib3.PoolManager()
resp = http.request("HEAD", url, timeout=timeout)
ok = 200 == resp.status
except urllib3.exceptions.NewConnectionError:
ok = False
except Exception:
ok = False
return ok


class TestWMSDataset:

@pytest.mark.slow
@pytest.mark.skipif(
not service_ok(SERVICE_URL), reason="WMS service is unreachable"
)
def test_wms_no_layer(self) -> None:
"""MESONET GetMap 1.1.1"""
wms = WMSDataset(SERVICE_URL, 10.0)
assert "nexrad_base_reflect" in wms.layers()
assert 4326 == wms.crs.to_epsg()
wms.layer("nexrad_base_reflect", crs=4326)
assert -126 == wms.index.bounds[0]
assert -66 == wms.index.bounds[1]
assert 24 == wms.index.bounds[2]
assert 50 == wms.index.bounds[3]
assert "image/png" == wms._format

@pytest.mark.slow
@pytest.mark.skipif(
not service_ok(SERVICE_URL), reason="WMS service is unreachable"
)
def test_wms_layer(self) -> None:
"""MESONET GetMap 1.1.1"""
wms = WMSDataset(SERVICE_URL, 10.0, layer="nexrad_base_reflect", crs=4326)
assert 4326 == wms.crs.to_epsg()
assert -126 == wms.index.bounds[0]
assert -66 == wms.index.bounds[1]
assert 24 == wms.index.bounds[2]
assert 50 == wms.index.bounds[3]
assert "image/png" == wms._format

@pytest.mark.slow
@pytest.mark.skipif(
not service_ok(SERVICE_URL), reason="WMS service is unreachable"
)
def test_wms_layer_nocrs(self) -> None:
"""MESONET GetMap 1.1.1"""
wms = WMSDataset(SERVICE_URL, 10.0, layer="nexrad_base_reflect")
assert 4326 == wms.crs.to_epsg()
assert -126 == wms.index.bounds[0]
assert -66 == wms.index.bounds[1]
assert 24 == wms.index.bounds[2]
assert 50 == wms.index.bounds[3]
assert "image/png" == wms._format
2 changes: 2 additions & 0 deletions torchgeo/datasets/__init__.py
@@ -135,6 +135,7 @@
from .vaihingen import Vaihingen2D
from .vhr10 import VHR10
from .western_usa_live_fuel_moisture import WesternUSALiveFuelMoisture
from .wms import WMSDataset
from .xview import XView2
from .zuericrop import ZueriCrop

@@ -263,6 +264,7 @@
"RasterDataset",
"UnionDataset",
"VectorDataset",
"WMSDataset",
# Utilities
"BoundingBox",
"concat_samples",
155 changes: 155 additions & 0 deletions torchgeo/datasets/wms.py
@@ -0,0 +1,155 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""A simple class to fetch WMS images."""
from collections.abc import Callable
from io import BytesIO
from typing import Any

import torchvision.transforms as transforms
from PIL import Image
from rasterio.coords import BoundingBox
from rasterio.crs import CRS
from rasterio.errors import CRSError
from rtree.index import Index, Property

from torchgeo.datasets import GeoDataset


class WMSDataset(GeoDataset):
"""Allow models to fetch images from a WMS (at a good resolution)."""
Collaborator: Can you clarify what "at a good resolution" means? It would also be good to explain what WMS is and why it is useful; for example, see the documentation for some of our other existing dataset base classes.

Collaborator: Also, can you add:

.. versionadded:: 0.6

to the docstring? This will document what version of TorchGeo is required to use this feature.
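
A revised class docstring addressing both comments might read roughly as follows (a sketch only; the exact wording is open):

    """Fetch imagery from an OGC Web Map Service (WMS).

    WMS is an OGC standard for requesting rendered, georeferenced map images
    over HTTP. This dataset queries a WMS endpoint at a user-chosen resolution
    so the returned imagery can be indexed and sampled alongside other
    GeoDatasets.

    .. versionadded:: 0.6
    """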


try:
from owslib.map.wms111 import ContentMetadata
from owslib.wms import WebMapService
except ImportError:
raise ImportError("OWSLib is not installed and is required to use this dataset")
_url: str = ""
_wms: WebMapService = None

    _layers: list[str] = []
_layer: ContentMetadata = None
_layer_name: str = ""
is_image = True

def __init__(
self,
url: str,
res: float,
layer: str | None = None,
transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
crs: str | int | None = None,
) -> None:
"""Initialize the data set instance.

Args:
url: String pointing to the WMS Server (no need for
request=getcapabilities).
res: The best resolution to make requests at
(should probably match your other datasets).
layer: An optional string with the name of the
layer you want to fetch.
transforms: a function/transform that takes an input sample
and returns a transformed version
crs: An optional string or integer code that is a
valid EPSG code (without the 'EPSG:')

"""
try:
from owslib.wms import WebMapService
except ImportError:
raise ImportError(
"OWSLib is not installed and is required to use this dataset"
)
super().__init__(transforms)
self._url = url
self._res = res

if crs is not None:
self._crs = CRS.from_epsg(crs)
self._wms = WebMapService(url)
self._format = self._wms.getOperationByName("GetMap").formatOptions[0]
self._layers = list(self._wms.contents)

if layer is not None and layer in self._layers:
self.layer(layer, crs)

def layer(self, layer: str, crs: str | int | None = None) -> None:
"""Set the layer to be fetched.

Args:
layer: A string with the name of the layer you want to fetch.
crs: An optional string or integer code that is a valid EPSG
code (without the 'EPSG:')
"""
self._layer = self._wms[layer]
self._layer_name = layer
coords = self._wms[layer].boundingBox
self.index = Index(interleaved=False, properties=Property(dimension=3))
self.index.insert(
0,
(
float(coords[0]),
float(coords[2]),
float(coords[1]),
float(coords[3]),
0,
9.223372036854776e18,
),
)
        if crs is None:
            # Fall back to the first EPSG-compatible CRS advertised by the layer.
            i = 0
            while self._crs is None:
                crs_str = sorted(self._layer.crsOptions)[i].upper()
                i += 1
                if "EPSG:" in crs_str:
                    crs_str = crs_str[5:]
                elif crs_str == "CRS:84":
                    crs_str = "4326"
                try:
                    self._crs = CRS.from_epsg(crs_str)
                except CRSError:
                    pass

def getlayer(self) -> ContentMetadata:
"""Return the selected layer object."""
return self._layer

def layers(self) -> list[str]:
"""Return a list of availiable layers."""
return self._layers
Collaborator (on lines +113 to +119): Getters/setters are generally considered to be bad practice in Python, especially if you don't actually need to do anything special. In this case, if we want the ability to access the attribute, let's just make it a public attribute.


def __getitem__(self, query: BoundingBox) -> dict[str, Any]:
"""Retrieve image/mask and metadata indexed by query.

Args:
query: (minx, maxx, miny, maxy, mint, maxt) coordinates to index

Returns:
sample of image/mask and metadata at that index

Raises:
IndexError: if query is not found in the index
"""
img = self._wms.getmap(
layers=[self._layer_name],
srs="epsg:" + str(self.crs.to_epsg()),
bbox=(query.minx, query.miny, query.maxx, query.maxy),
# TODO fix size
size=(500, 500),
format=self._format,
transparent=True,
)
sample = {"crs": self.crs, "bbox": query}

transform = transforms.Compose([transforms.ToTensor()])
Collaborator: We don't normally use torchvision transforms. What is the output from WMS, a NumPy array? Let's just convert that directly to PyTorch without relying on PIL.

Author: It's a JPEG or PNG image; can I convert that directly to a tensor?

Collaborator: T.ToTensor() also automatically divides by 255, which we should leave up to the user to decide. You can convert the image to a tensor using torch.from_numpy(np.array(Image.open(path)).astype(np.float32)); you can break this up over multiple lines so it's not too verbose.
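
For reference, a rough sketch of that suggestion (not part of the PR; it assumes `img` is the getmap() response used below and reuses the Image and BytesIO imports already at the top of this file):

    import numpy as np
    import torch

    # Decode the PNG/JPEG bytes returned by getmap() and convert them directly
    # to a float tensor, leaving any scaling (e.g. dividing by 255) to the user.
    array = np.array(Image.open(BytesIO(img.read()))).astype(np.float32)
    img_tensor = torch.from_numpy(array)
    if img_tensor.ndim == 3:
        img_tensor = img_tensor.permute(2, 0, 1)  # HWC -> CHW, TorchGeo's layout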

# Convert the PIL image to Torch tensor
img_tensor = transform(Image.open(BytesIO(img.read())))
if self.is_image:
sample["image"] = img_tensor
else:
sample["mask"] = img_tensor

if self.transforms is not None:
sample = self.transforms(sample)

return sample
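
For context, a rough usage sketch based on the tests above (the endpoint and layer name come from tests/datasets/test_wms.py; the query window is illustrative, and a live network connection to the MESONET server is required):

    from torchgeo.datasets import BoundingBox, WMSDataset

    url = "https://mesonet.agron.iastate.edu/cgi-bin/wms/nexrad/n0r-t.cgi?"
    wms = WMSDataset(url, res=10.0, layer="nexrad_base_reflect", crs=4326)

    # Query a small window inside the advertised layer bounds (-126..-66, 24..50).
    query = BoundingBox(minx=-100.0, maxx=-95.0, miny=30.0, maxy=35.0, mint=0, maxt=1)
    sample = wms[query]
    print(sample["image"].shape)  # (C, 500, 500) given the fixed 500x500 request size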