From 1ff492fd44388c1851d0c98ec62fcda53a9a991f Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Sun, 19 May 2024 21:14:49 +0200 Subject: [PATCH 1/4] Store notebooks as python files using jupytext --- docs/conf.py | 3 + docs/tutorials/custom_raster_dataset.ipynb | 567 ---------------- docs/tutorials/custom_raster_dataset.py | 261 ++++++++ docs/tutorials/getting_started.ipynb | 322 --------- docs/tutorials/getting_started.py | 108 +++ docs/tutorials/indices.ipynb | 382 ----------- docs/tutorials/indices.py | 137 ++++ docs/tutorials/pretrained_weights.ipynb | 508 --------------- docs/tutorials/pretrained_weights.py | 126 ++++ docs/tutorials/trainers.ipynb | 336 ---------- docs/tutorials/trainers.py | 136 ++++ docs/tutorials/transforms.ipynb | 725 --------------------- docs/tutorials/transforms.py | 313 +++++++++ 13 files changed, 1084 insertions(+), 2840 deletions(-) delete mode 100644 docs/tutorials/custom_raster_dataset.ipynb create mode 100644 docs/tutorials/custom_raster_dataset.py delete mode 100644 docs/tutorials/getting_started.ipynb create mode 100644 docs/tutorials/getting_started.py delete mode 100644 docs/tutorials/indices.ipynb create mode 100644 docs/tutorials/indices.py delete mode 100644 docs/tutorials/pretrained_weights.ipynb create mode 100644 docs/tutorials/pretrained_weights.py delete mode 100644 docs/tutorials/trainers.ipynb create mode 100644 docs/tutorials/trainers.py delete mode 100644 docs/tutorials/transforms.ipynb create mode 100644 docs/tutorials/transforms.py diff --git a/docs/conf.py b/docs/conf.py index 36af7d3d275..b1321831812 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -126,6 +126,9 @@ } # nbsphinx +nbsphinx_custom_formats = { + '.py': 'jupytext.reads', +} nbsphinx_execute = 'never' # TODO: branch/tag should change depending on which version of docs you look at # TODO: width option of image directive is broken, see: diff --git a/docs/tutorials/custom_raster_dataset.ipynb b/docs/tutorials/custom_raster_dataset.ipynb deleted file mode 100644 index 1da51ec620d..00000000000 --- a/docs/tutorials/custom_raster_dataset.ipynb +++ /dev/null @@ -1,567 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "iiqWbXISOEAQ" - }, - "source": [ - "Copyright (c) Microsoft Corporation. All rights reserved.\n", - "\n", - "Licensed under the MIT License." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "8zfSLrVHOgwv" - }, - "source": [ - "# Custom Raster Datasets\n", - "\n", - "In this tutorial, we'll describe how to write a custom dataset in TorchGeo. There are many types of datasets that you may encounter, from image data, to segmentation masks, to point labels. We'll focus on the most common type of dataset: a raster file containing an image or mask. Let's get started!" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "HdTsXvc8UeSS" - }, - "source": [ - "## Choosing a base class\n", - "\n", - "In TorchGeo, there are two _types_ of datasets:\n", - "\n", - "* `GeoDataset`: for uncurated raw data with geospatial metadata\n", - "* `NonGeoDataset`: for curated benchmark datasets that lack geospatial metadata\n", - "\n", - "If you're not sure which type of dataset you need, a good rule of thumb is to run `gdalinfo` on one of the files. If `gdalinfo` returns information like the bounding box, resolution, and CRS of the file, then you should probably use `GeoDataset`." 
- ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "S86fPV92Wdc8" - }, - "source": [ - "### GeoDataset\n", - "\n", - "In TorchGeo, each `GeoDataset` uses an [R-tree](https://en.wikipedia.org/wiki/R-tree) to store the spatiotemporal bounding box of each file or data point. To simplify this process and reduce code duplication, we provide two subclasses of `GeoDataset`:\n", - "\n", - "* `RasterDataset`: recursively search for raster files in a directory\n", - "* `VectorDataset`: recursively search for vector files in a directory\n", - "\n", - "In this example, we'll be working with raster images, so we'll choose `RasterDataset` as the base class." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "C3fDQJdvWfsW" - }, - "source": [ - "### NonGeoDataset\n", - "\n", - "`NonGeoDataset` is almost identical to [torchvision](https://pytorch.org/vision/stable/index.html)'s `VisionDataset`, so we'll instead focus on `GeoDataset` in this tutorial. If you need to add a `NonGeoDataset`, the following tutorials may be helpful:\n", - "\n", - "* [Datasets & DataLoaders](https://pytorch.org/tutorials/beginner/basics/data_tutorial.html)\n", - "* [Writing Custom Datasets, DataLoaders and Transforms](https://pytorch.org/tutorials/beginner/data_loading_tutorial.html)\n", - "* [Developing Custom PyTorch DataLoaders](https://pytorch.org/tutorials/recipes/recipes/custom_dataset_transforms_loader.html)\n", - "\n", - "Of course, you can always look for similar datasets that already exist in TorchGeo and copy their design when creating your own dataset." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "ABkJW-3FOi4N" - }, - "source": [ - "## Setup\n", - "\n", - "First, we install TorchGeo and a couple of other dependencies for downloading data from Microsoft's Planetary Computer." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "aGYEmPNONp8W", - "outputId": "5f8c773c-b1e4-471a-fa99-31c5d7125eb4" - }, - "outputs": [], - "source": [ - "%pip install torchgeo planetary_computer pystac" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "MkycnrRMOBso" - }, - "source": [ - "## Imports\n", - "\n", - "Next, we import TorchGeo and any other libraries we need." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "9v1QN3-mOrdt" - }, - "outputs": [], - "source": [ - "import os\n", - "import tempfile\n", - "from urllib.parse import urlparse\n", - "\n", - "import matplotlib.pyplot as plt\n", - "import planetary_computer\n", - "import pystac\n", - "import torch\n", - "from torch.utils.data import DataLoader\n", - "\n", - "from torchgeo.datasets import RasterDataset, stack_samples, unbind_samples\n", - "from torchgeo.datasets.utils import download_url\n", - "from torchgeo.samplers import RandomGeoSampler\n", - "\n", - "%matplotlib inline\n", - "plt.rcParams['figure.figsize'] = (12, 12)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "W6PAktuVVoSP" - }, - "source": [ - "## Downloading\n", - "\n", - "Let's download some data to play around with. In this example, we'll create a dataset for loading Sentinel-2 images. Yes, TorchGeo already has a built-in class for this, but we'll use it as an example of the steps you would need to take to add a dataset that isn't yet available in TorchGeo. We'll show how to download a few bands of Sentinel-2 imagery from the Planetary Computer. This may take a few minutes." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 432, - "referenced_widgets": [ - "ecb1d85ebf264d04885a013a6e7a069c", - "4b513ad43aae431dadc2467f370a91c7", - "0af96038d774446f99181eece9a172e2", - "6c7ead8397fa418d9a67d7b1f68d794a", - "0a60e5df954c4874aaed83607e2f20ae", - "934ab0cb728541628ee8b09f0a60838f", - "66bfa56ececd4da1b1b10c4076c41ca3", - "549911c845fa4fbe82725416db96ff76", - "381bd014879f4b8cb9713fa640e6e744", - "fcedd561ee1e49f79e4e2ef140e34b8d", - "5d57d7756e4240e0a7243d817b4af6bd", - "77ed782e7b2343829235222ac932f845", - "dd23af3773714c009f8996c3c0f84ace", - "cc525ae67812410ab3a5135fce36df54", - "1651e0cd746e45699df3f9d6c8c7abef", - "cb44bd94cdc442ccb0b98f71bffb610f", - "c7fc3889e6224808bf919eeb70cde2e7", - "16b0250afcb94ce589af663dc7cd9b64", - "ed8b05db83c84c9aaca84a35c49b35e1", - "957435d60d7e40f0a945c46446943771", - "39f7a22b7a9f4844bfc6c89e9d3a94aa", - "06e5ccfc007b44b082af8cc4b418a69b", - "30a6013e4722448d94e9db91ad6d3e6f", - "5149cfb5beab412b90cb27d7662d0230", - "512d1ecfc2a74739a9945583a54e9d22", - "31cd3702e66c4d678aa9d7be5b672e8c", - "59f956f7423541f8b9f63df601422c95", - "7eb86dddca194032b143669997c8ee86", - "5b27ac2e02874fe1bf22f4af9a026488", - "11e0951fb0b440d3a5052487b75a5866", - "fac336dcb5424ad1884a6d458b19a05e", - "7e775c3f4f1d4f579900e62e235c4cc2", - "be81c4fc26f046af93ebc85ed2b9b049", - "652084f413184219ae276a6ce73b8fc2", - "fa8c6a2f39b94480b42e0739b740dd17", - "7605de5a54ef44c9a429763c9dae26d4", - "3af584f242aa4283bc4089f461b88bc4", - "fbee2fa720764ad28898c7076fa47515", - "3d873afcb0a147b5877212e70c14f428", - "82a5dfe7afeb4ce1be3c5bc0e79c03c2", - "6e25b5a3a9b74736b88cc55fdcad16e7", - "8de1e9d8038b4143add60c9c53407a42", - "7c62146cce7e41b79f31d1518f4b025f", - "7dccd3ba53154811a0e8dd63c4b36b11", - "88e4e6e8039e4b80a06b53ccaadfc7db", - "845c622bc12a4211bb3cb1bfd93830a3", - "26e81ed27e7e4b5bbc68accae6c6051c", - "8822d92bc7a6449a851e5140f97f1eb6", - "0a19c6e153e54b2c86b3f74a3f9ebc78", - "ffe4cc5f5dff4cdd90e1d1afa7f4a210", - "f12a91cdbe9c4c269f9111468b4d4473", - "5b379fc2a31b4c07a12198d98c9ec48d", - "8f0c13446f3846aeac06eba2e2a90a77", - "b0addc94caeb4f3ca6ed21f773d68725", - "e6564598832544878aa3f90a37cad7ae", - "f7e5fd99610a4fa9ab7ac076fcbe9cc1", - "ff7f7a7cb5cc4b8bb8a9ee5c8189b9b8", - "c53c376437f9408c902554d7ec58dfa5", - "5e22879c720f47f59f63b40a7f45a28d", - "9a7cac58ce4c4cdf8f7ea9bf348852f9", - "0d79a004bedc4ab0beb05410d69edcc5", - "45d6cfb69a2f43e3b411f3974ed258a0", - "94738bc89cd24297904d8aa71ffab7fa", - "e3ccaf4e4ccb40cb89ec0e1bc1bdf6aa", - "dc2e555542834dca9807927b31c5ea60", - "f1d46cc6d8cd466bae0aafe2345c34eb", - "1c943575cfb840b7b6602b7bf384baac", - "f4be36be86e7414782716bbf2ea97714", - "6e7fb39128f94190ab538ecc5cd2a529", - "a339df96ccfd4234805c35a19d4f6be1", - "e7874f8bfb4948afb2a9a94495223b5a", - "e980b11f57654baf9deb114d22d7b165", - "ccfca83232a246b183f1d8887b67bca5", - "d8748855e166467e9d8a3c6ce50b8426", - "90f125983bd74e04a8ff9ad250919884", - "d9ac7ee67f9c491d817c9b95e3ea4735", - "af59e0800a7e4788b7626f556d876017", - "8050a04b8a9c49598da21dfec7ee7992", - "ac013a8894884dcbafbd6046bcc69ace", - "bab46fc09bd94f59a3a514f7f5c86298", - "6ca58f93c56e435296b7659d454caa71", - "0ac7f49774444de48f8d1237c5842cbf", - "9291ab2d297d400c9bb89a1479506005", - "bccd758b413742b5a661a5842dd42e93", - "316708d30a5d4f34b66e423d683ee760", - "c44d19c7347e4a15993df5ee72f397fc", - "8d1f66bc4d2d4341afc437030a2baafa", - "ec3a76861bb74ee998c6b180db50104c" - ] - }, - "id": "re1vuzCQfNvr", - 
"outputId": "ac86ad5f-a0c4-4d80-cbdd-4869cb73b827" - }, - "outputs": [], - "source": [ - "root = os.path.join(tempfile.gettempdir(), 'sentinel')\n", - "item_urls = [\n", - " 'https://planetarycomputer.microsoft.com/api/stac/v1/collections/sentinel-2-l2a/items/S2B_MSIL2A_20220902T090559_R050_T40XDH_20220902T181115',\n", - " 'https://planetarycomputer.microsoft.com/api/stac/v1/collections/sentinel-2-l2a/items/S2B_MSIL2A_20220718T084609_R107_T40XEJ_20220718T175008',\n", - "]\n", - "\n", - "for item_url in item_urls:\n", - " item = pystac.Item.from_file(item_url)\n", - " signed_item = planetary_computer.sign(item)\n", - " for band in ['B02', 'B03', 'B04', 'B08']:\n", - " asset_href = signed_item.assets[band].href\n", - " filename = urlparse(asset_href).path.split('/')[-1]\n", - " download_url(asset_href, root, filename)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Hz3uPKcsPLAz" - }, - "source": [ - "This downloads the following files:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "zcBoq3RWPQhn", - "outputId": "ab32f780-a43a-4725-d609-4d4ea35d3ccc" - }, - "outputs": [], - "source": [ - "sorted(os.listdir(root))" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Pt-BP66NRkc7" - }, - "source": [ - "As you can see, each spectral band is stored in a different file. We have downloaded 2 total scenes, each with 4 spectral bands." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "5BX_C8dJSCZT" - }, - "source": [ - "## Defining a dataset\n", - "\n", - "To define a new dataset class, we subclass from `RasterDataset`. `RasterDataset` has several class attributes used to customize how to find and load files.\n", - "\n", - "### `filename_glob`\n", - "\n", - "In order to search for files that belong in a dataset, we need to know what the filenames look like. In our Sentinel-2 example, all files start with a capital `T` and end with `_10m.tif`. We also want to make sure that the glob only finds a single file for each scene, so we'll include `B02` in the glob. If you've never used Unix globs before, see Python's [fnmatch](https://docs.python.org/3/library/fnmatch.html) module for documentation on allowed characters.\n", - "\n", - "### `filename_regex`\n", - "\n", - "Rasterio can read the geospatial bounding box of each file, but it can't read the timestamp. In order to determine the timestamp of the file, we'll define a `filename_regex` with a group labeled \"date\". If your files don't have a timestamp in the filename, you can skip this step. If you've never used regular expressions before, see Python's [re](https://docs.python.org/3/library/re.html) module for documentation on allowed characters.\n", - "\n", - "### `date_format`\n", - "\n", - "The timestamp can come in many formats. In our example, we have the following format:\n", - "\n", - "* 4 digit year (`%Y`)\n", - "* 2 digit month (`%m`)\n", - "* 2 digit day (`%d`)\n", - "* the letter T\n", - "* 2 digit hour (`%H`)\n", - "* 2 digit minute (`%M`)\n", - "* 2 digit second (`%S`)\n", - "\n", - "We'll define the `date_format` variable using [datetime format codes](https://docs.python.org/3/library/datetime.html#strftime-and-strptime-format-codes).\n", - "\n", - "### `is_image`\n", - "\n", - "If your data only contains model inputs (such as images), use `is_image = True`. 
If your data only contains ground truth model outputs (such as segmentation masks), use `is_image = False` instead.\n", - "\n", - "### `dtype`\n", - "\n", - "Defaults to float32 for `is_image == True` and long for `is_image == False`. This is what you want for 99% of datasets, but can be overridden for tasks like pixel-wise regression (where the target mask should be float32).\n", - "\n", - "### `resampling`\n", - "\n", - "Defaults to bilinear for float Tensors and nearest for int Tensors. Can be overridden for custom resampling algorithms.\n", - "\n", - "### `separate_files`\n", - "\n", - "If your data comes with each spectral band in a separate files, as is the case with Sentinel-2, use `separate_files = True`. If all spectral bands are stored in a single file, use `separate_files = False` instead.\n", - "\n", - "### `all_bands`\n", - "\n", - "If your data is a multispectral image, you can define a list of all band names using the `all_bands` variable.\n", - "\n", - "### `rgb_bands`\n", - "\n", - "If your data is a multispectral image, you can define which bands correspond to the red, green, and blue channels. In the case of Sentinel-2, this corresponds to B04, B03, and B02, in that order.\n", - "\n", - "Putting this all together into a single class, we get:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "8sFb8BTTTxZD" - }, - "outputs": [], - "source": [ - "class Sentinel2(RasterDataset):\n", - " filename_glob = 'T*_B02_10m.tif'\n", - " filename_regex = r'^.{6}_(?P\\d{8}T\\d{6})_(?PB0[\\d])'\n", - " date_format = '%Y%m%dT%H%M%S'\n", - " is_image = True\n", - " separate_files = True\n", - " all_bands = ['B02', 'B03', 'B04', 'B08']\n", - " rgb_bands = ['B04', 'B03', 'B02']" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "a1AlbJp7XUEa" - }, - "source": [ - "We can now instantiate this class and see if it works correctly." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "NXvg9EL8XZAk", - "outputId": "134235ee-b108-4861-f864-ea3d8960b0ce" - }, - "outputs": [], - "source": [ - "dataset = Sentinel2(root)\n", - "print(dataset)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "msbeAkVOX-iJ" - }, - "source": [ - "As expected, we have a GeoDataset of size 2 because there are 2 scenes in our root data directory." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "IUjv7Km7YDpH" - }, - "source": [ - "## Plotting\n", - "\n", - "A great test to make sure that the dataset works correctly is to try to plot an image. We'll add a plot function to our dataset to help visualize it. First, we need to modify the image so that it only contains the RGB bands, and ensure that they are in the correct order. We also need to ensure that the image is in the range 0.0 to 1.0 (or 0 to 255). Finally, we'll create a plot using matplotlib." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "7PNFOy9mYq6K" - }, - "outputs": [], - "source": [ - "class Sentinel2(RasterDataset):\n", - " filename_glob = 'T*_B02_10m.tif'\n", - " filename_regex = r'^.{6}_(?P\\d{8}T\\d{6})_(?PB0[\\d])'\n", - " date_format = '%Y%m%dT%H%M%S'\n", - " is_image = True\n", - " separate_files = True\n", - " all_bands = ['B02', 'B03', 'B04', 'B08']\n", - " rgb_bands = ['B04', 'B03', 'B02']\n", - "\n", - " def plot(self, sample):\n", - " # Find the correct band index order\n", - " rgb_indices = []\n", - " for band in self.rgb_bands:\n", - " rgb_indices.append(self.all_bands.index(band))\n", - "\n", - " # Reorder and rescale the image\n", - " image = sample['image'][rgb_indices].permute(1, 2, 0)\n", - " image = torch.clamp(image / 10000, min=0, max=1).numpy()\n", - "\n", - " # Plot the image\n", - " fig, ax = plt.subplots()\n", - " ax.imshow(image)\n", - "\n", - " return fig" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "sF8HBA9gah3z" - }, - "source": [ - "Let's plot an image to see what it looks like. We'll use `RandomGeoSampler` to load small patches from each image." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 1000 - }, - "id": "I6lv4YcVbAox", - "outputId": "e6ee643f-66bd-457e-f88c-bbedf092e19d" - }, - "outputs": [], - "source": [ - "torch.manual_seed(1)\n", - "\n", - "dataset = Sentinel2(root)\n", - "sampler = RandomGeoSampler(dataset, size=4096, length=3)\n", - "dataloader = DataLoader(dataset, sampler=sampler, collate_fn=stack_samples)\n", - "\n", - "for batch in dataloader:\n", - " sample = unbind_samples(batch)[0]\n", - " dataset.plot(sample)\n", - " plt.axis('off')\n", - " plt.show()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "ALLYUzhXKkfS" - }, - "source": [ - "For those who are curious, these are glaciers on Novaya Zemlya, Russia." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "R_qrQkBCEvEl" - }, - "source": [ - "## Custom parameters\n", - "\n", - "If you want to add custom parameters to the class, you can override the `__init__` method. For example, let's say you have imagery that can be automatically downloaded. The `RasterDataset` base class doesn't support this, but you could add support in your subclass. Simply copy the parameters from the base class and add a new `download` parameter." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "TxODAvIHFKNt" - }, - "outputs": [], - "source": [ - "class Downloadable(RasterDataset):\n", - " def __init__(self, root, crs, res, transforms, cache, download=False):\n", - " super().__init__(root, crs, res, transforms, cache)\n", - "\n", - " if download:\n", - " # download the dataset\n", - " ..." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "cI43f8DMF3iR" - }, - "source": [ - "## Contributing\n", - "\n", - "TorchGeo is an open source ecosystem built from the contributions of users like you. If your dataset might be useful for other users, please consider contributing it to TorchGeo! You'll need a bit of documentation and some testing before your dataset can be added, but it will be included in the next minor release for all users to enjoy. See the [Contributing](https://torchgeo.readthedocs.io/en/stable/user/contributing.html) guide to get started." 
- ] - } - ], - "metadata": { - "colab": { - "collapsed_sections": [], - "provenance": [] - }, - "execution": { - "timeout": 1200 - }, - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.8" - } - }, - "nbformat": 4, - "nbformat_minor": 1 -} diff --git a/docs/tutorials/custom_raster_dataset.py b/docs/tutorials/custom_raster_dataset.py new file mode 100644 index 00000000000..ccea38734ce --- /dev/null +++ b/docs/tutorials/custom_raster_dataset.py @@ -0,0 +1,261 @@ +# --- +# jupyter: +# jupytext: +# text_representation: +# extension: .py +# format_name: light +# format_version: '1.5' +# jupytext_version: 1.16.0 +# kernelspec: +# display_name: Python 3 (ipykernel) +# language: python +# name: python3 +# --- + +# + [markdown] id="iiqWbXISOEAQ" +# Copyright (c) Microsoft Corporation. All rights reserved. +# +# Licensed under the MIT License. + +# + [markdown] id="8zfSLrVHOgwv" +# # Custom Raster Datasets +# +# In this tutorial, we'll describe how to write a custom dataset in TorchGeo. There are many types of datasets that you may encounter, from image data, to segmentation masks, to point labels. We'll focus on the most common type of dataset: a raster file containing an image or mask. Let's get started! + +# + [markdown] id="HdTsXvc8UeSS" +# ## Choosing a base class +# +# In TorchGeo, there are two _types_ of datasets: +# +# * `GeoDataset`: for uncurated raw data with geospatial metadata +# * `NonGeoDataset`: for curated benchmark datasets that lack geospatial metadata +# +# If you're not sure which type of dataset you need, a good rule of thumb is to run `gdalinfo` on one of the files. If `gdalinfo` returns information like the bounding box, resolution, and CRS of the file, then you should probably use `GeoDataset`. + +# + [markdown] id="S86fPV92Wdc8" +# ### GeoDataset +# +# In TorchGeo, each `GeoDataset` uses an [R-tree](https://en.wikipedia.org/wiki/R-tree) to store the spatiotemporal bounding box of each file or data point. To simplify this process and reduce code duplication, we provide two subclasses of `GeoDataset`: +# +# * `RasterDataset`: recursively search for raster files in a directory +# * `VectorDataset`: recursively search for vector files in a directory +# +# In this example, we'll be working with raster images, so we'll choose `RasterDataset` as the base class. + +# + [markdown] id="C3fDQJdvWfsW" +# ### NonGeoDataset +# +# `NonGeoDataset` is almost identical to [torchvision](https://pytorch.org/vision/stable/index.html)'s `VisionDataset`, so we'll instead focus on `GeoDataset` in this tutorial. If you need to add a `NonGeoDataset`, the following tutorials may be helpful: +# +# * [Datasets & DataLoaders](https://pytorch.org/tutorials/beginner/basics/data_tutorial.html) +# * [Writing Custom Datasets, DataLoaders and Transforms](https://pytorch.org/tutorials/beginner/data_loading_tutorial.html) +# * [Developing Custom PyTorch DataLoaders](https://pytorch.org/tutorials/recipes/recipes/custom_dataset_transforms_loader.html) +# +# Of course, you can always look for similar datasets that already exist in TorchGeo and copy their design when creating your own dataset. 
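+
+# + [markdown]
+# To make the distinction concrete, here is a minimal `NonGeoDataset`-style skeleton (an illustrative sketch, not a real TorchGeo dataset; the file indexing and sample contents are placeholders). Like any PyTorch dataset, it only needs `__getitem__` and `__len__`, and by TorchGeo convention it returns a dictionary of tensors:
+#
+# ```python
+# from torchgeo.datasets import NonGeoDataset
+#
+#
+# class MyBenchmark(NonGeoDataset):
+#     def __init__(self, root):
+#         self.files = sorted(...)  # enumerate your curated files here
+#
+#     def __getitem__(self, index):
+#         # load one sample and return TorchGeo's usual dictionary format
+#         return {'image': ..., 'label': ...}
+#
+#     def __len__(self):
+#         return len(self.files)
+# ```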
+ +# + [markdown] id="ABkJW-3FOi4N" +# ## Setup +# +# First, we install TorchGeo and a couple of other dependencies for downloading data from Microsoft's Planetary Computer. + +# + colab={"base_uri": "https://localhost:8080/"} id="aGYEmPNONp8W" outputId="5f8c773c-b1e4-471a-fa99-31c5d7125eb4" +# %pip install torchgeo planetary_computer pystac + +# + [markdown] id="MkycnrRMOBso" +# ## Imports +# +# Next, we import TorchGeo and any other libraries we need. + +# + id="9v1QN3-mOrdt" +import os +import tempfile +from urllib.parse import urlparse + +import matplotlib.pyplot as plt +import planetary_computer +import pystac +import torch +from torch.utils.data import DataLoader + +from torchgeo.datasets import RasterDataset, stack_samples, unbind_samples +from torchgeo.datasets.utils import download_url +from torchgeo.samplers import RandomGeoSampler + +# %matplotlib inline +plt.rcParams['figure.figsize'] = (12, 12) + +# + [markdown] id="W6PAktuVVoSP" +# ## Downloading +# +# Let's download some data to play around with. In this example, we'll create a dataset for loading Sentinel-2 images. Yes, TorchGeo already has a built-in class for this, but we'll use it as an example of the steps you would need to take to add a dataset that isn't yet available in TorchGeo. We'll show how to download a few bands of Sentinel-2 imagery from the Planetary Computer. This may take a few minutes. + +# + colab={"base_uri": "https://localhost:8080/", "height": 432, "referenced_widgets": ["ecb1d85ebf264d04885a013a6e7a069c", "4b513ad43aae431dadc2467f370a91c7", "0af96038d774446f99181eece9a172e2", "6c7ead8397fa418d9a67d7b1f68d794a", "0a60e5df954c4874aaed83607e2f20ae", "934ab0cb728541628ee8b09f0a60838f", "66bfa56ececd4da1b1b10c4076c41ca3", "549911c845fa4fbe82725416db96ff76", "381bd014879f4b8cb9713fa640e6e744", "fcedd561ee1e49f79e4e2ef140e34b8d", "5d57d7756e4240e0a7243d817b4af6bd", "77ed782e7b2343829235222ac932f845", "dd23af3773714c009f8996c3c0f84ace", "cc525ae67812410ab3a5135fce36df54", "1651e0cd746e45699df3f9d6c8c7abef", "cb44bd94cdc442ccb0b98f71bffb610f", "c7fc3889e6224808bf919eeb70cde2e7", "16b0250afcb94ce589af663dc7cd9b64", "ed8b05db83c84c9aaca84a35c49b35e1", "957435d60d7e40f0a945c46446943771", "39f7a22b7a9f4844bfc6c89e9d3a94aa", "06e5ccfc007b44b082af8cc4b418a69b", "30a6013e4722448d94e9db91ad6d3e6f", "5149cfb5beab412b90cb27d7662d0230", "512d1ecfc2a74739a9945583a54e9d22", "31cd3702e66c4d678aa9d7be5b672e8c", "59f956f7423541f8b9f63df601422c95", "7eb86dddca194032b143669997c8ee86", "5b27ac2e02874fe1bf22f4af9a026488", "11e0951fb0b440d3a5052487b75a5866", "fac336dcb5424ad1884a6d458b19a05e", "7e775c3f4f1d4f579900e62e235c4cc2", "be81c4fc26f046af93ebc85ed2b9b049", "652084f413184219ae276a6ce73b8fc2", "fa8c6a2f39b94480b42e0739b740dd17", "7605de5a54ef44c9a429763c9dae26d4", "3af584f242aa4283bc4089f461b88bc4", "fbee2fa720764ad28898c7076fa47515", "3d873afcb0a147b5877212e70c14f428", "82a5dfe7afeb4ce1be3c5bc0e79c03c2", "6e25b5a3a9b74736b88cc55fdcad16e7", "8de1e9d8038b4143add60c9c53407a42", "7c62146cce7e41b79f31d1518f4b025f", "7dccd3ba53154811a0e8dd63c4b36b11", "88e4e6e8039e4b80a06b53ccaadfc7db", "845c622bc12a4211bb3cb1bfd93830a3", "26e81ed27e7e4b5bbc68accae6c6051c", "8822d92bc7a6449a851e5140f97f1eb6", "0a19c6e153e54b2c86b3f74a3f9ebc78", "ffe4cc5f5dff4cdd90e1d1afa7f4a210", "f12a91cdbe9c4c269f9111468b4d4473", "5b379fc2a31b4c07a12198d98c9ec48d", "8f0c13446f3846aeac06eba2e2a90a77", "b0addc94caeb4f3ca6ed21f773d68725", "e6564598832544878aa3f90a37cad7ae", "f7e5fd99610a4fa9ab7ac076fcbe9cc1", "ff7f7a7cb5cc4b8bb8a9ee5c8189b9b8", 
"c53c376437f9408c902554d7ec58dfa5", "5e22879c720f47f59f63b40a7f45a28d", "9a7cac58ce4c4cdf8f7ea9bf348852f9", "0d79a004bedc4ab0beb05410d69edcc5", "45d6cfb69a2f43e3b411f3974ed258a0", "94738bc89cd24297904d8aa71ffab7fa", "e3ccaf4e4ccb40cb89ec0e1bc1bdf6aa", "dc2e555542834dca9807927b31c5ea60", "f1d46cc6d8cd466bae0aafe2345c34eb", "1c943575cfb840b7b6602b7bf384baac", "f4be36be86e7414782716bbf2ea97714", "6e7fb39128f94190ab538ecc5cd2a529", "a339df96ccfd4234805c35a19d4f6be1", "e7874f8bfb4948afb2a9a94495223b5a", "e980b11f57654baf9deb114d22d7b165", "ccfca83232a246b183f1d8887b67bca5", "d8748855e166467e9d8a3c6ce50b8426", "90f125983bd74e04a8ff9ad250919884", "d9ac7ee67f9c491d817c9b95e3ea4735", "af59e0800a7e4788b7626f556d876017", "8050a04b8a9c49598da21dfec7ee7992", "ac013a8894884dcbafbd6046bcc69ace", "bab46fc09bd94f59a3a514f7f5c86298", "6ca58f93c56e435296b7659d454caa71", "0ac7f49774444de48f8d1237c5842cbf", "9291ab2d297d400c9bb89a1479506005", "bccd758b413742b5a661a5842dd42e93", "316708d30a5d4f34b66e423d683ee760", "c44d19c7347e4a15993df5ee72f397fc", "8d1f66bc4d2d4341afc437030a2baafa", "ec3a76861bb74ee998c6b180db50104c"]} id="re1vuzCQfNvr" outputId="ac86ad5f-a0c4-4d80-cbdd-4869cb73b827" +root = os.path.join(tempfile.gettempdir(), 'sentinel') +item_urls = [ + 'https://planetarycomputer.microsoft.com/api/stac/v1/collections/sentinel-2-l2a/items/S2B_MSIL2A_20220902T090559_R050_T40XDH_20220902T181115', + 'https://planetarycomputer.microsoft.com/api/stac/v1/collections/sentinel-2-l2a/items/S2B_MSIL2A_20220718T084609_R107_T40XEJ_20220718T175008', +] + +for item_url in item_urls: + item = pystac.Item.from_file(item_url) + signed_item = planetary_computer.sign(item) + for band in ['B02', 'B03', 'B04', 'B08']: + asset_href = signed_item.assets[band].href + filename = urlparse(asset_href).path.split('/')[-1] + download_url(asset_href, root, filename) + +# + [markdown] id="Hz3uPKcsPLAz" +# This downloads the following files: + +# + colab={"base_uri": "https://localhost:8080/"} id="zcBoq3RWPQhn" outputId="ab32f780-a43a-4725-d609-4d4ea35d3ccc" +sorted(os.listdir(root)) + + +# + [markdown] id="Pt-BP66NRkc7" +# As you can see, each spectral band is stored in a different file. We have downloaded 2 total scenes, each with 4 spectral bands. + +# + [markdown] id="5BX_C8dJSCZT" +# ## Defining a dataset +# +# To define a new dataset class, we subclass from `RasterDataset`. `RasterDataset` has several class attributes used to customize how to find and load files. +# +# ### `filename_glob` +# +# In order to search for files that belong in a dataset, we need to know what the filenames look like. In our Sentinel-2 example, all files start with a capital `T` and end with `_10m.tif`. We also want to make sure that the glob only finds a single file for each scene, so we'll include `B02` in the glob. If you've never used Unix globs before, see Python's [fnmatch](https://docs.python.org/3/library/fnmatch.html) module for documentation on allowed characters. +# +# ### `filename_regex` +# +# Rasterio can read the geospatial bounding box of each file, but it can't read the timestamp. In order to determine the timestamp of the file, we'll define a `filename_regex` with a group labeled "date". If your files don't have a timestamp in the filename, you can skip this step. If you've never used regular expressions before, see Python's [re](https://docs.python.org/3/library/re.html) module for documentation on allowed characters. +# +# ### `date_format` +# +# The timestamp can come in many formats. 
In our example, we have the following format:
+#
+# * 4 digit year (`%Y`)
+# * 2 digit month (`%m`)
+# * 2 digit day (`%d`)
+# * the letter T
+# * 2 digit hour (`%H`)
+# * 2 digit minute (`%M`)
+# * 2 digit second (`%S`)
+#
+# We'll define the `date_format` variable using [datetime format codes](https://docs.python.org/3/library/datetime.html#strftime-and-strptime-format-codes).
+#
+# ### `is_image`
+#
+# If your data only contains model inputs (such as images), use `is_image = True`. If your data only contains ground truth model outputs (such as segmentation masks), use `is_image = False` instead.
+#
+# ### `dtype`
+#
+# Defaults to float32 for `is_image == True` and long for `is_image == False`. This is what you want for 99% of datasets, but can be overridden for tasks like pixel-wise regression (where the target mask should be float32).
+#
+# ### `resampling`
+#
+# Defaults to bilinear for float Tensors and nearest for int Tensors. Can be overridden for custom resampling algorithms.
+#
+# ### `separate_files`
+#
+# If your data comes with each spectral band in a separate file, as is the case with Sentinel-2, use `separate_files = True`. If all spectral bands are stored in a single file, use `separate_files = False` instead.
+#
+# ### `all_bands`
+#
+# If your data is a multispectral image, you can define a list of all band names using the `all_bands` variable.
+#
+# ### `rgb_bands`
+#
+# If your data is a multispectral image, you can define which bands correspond to the red, green, and blue channels. In the case of Sentinel-2, this corresponds to B04, B03, and B02, in that order.
+#
+# Putting this all together into a single class, we get:
+
+# + id="8sFb8BTTTxZD"
+class Sentinel2(RasterDataset):
+    filename_glob = 'T*_B02_10m.tif'
+    filename_regex = r'^.{6}_(?P<date>\d{8}T\d{6})_(?P<band>B0[\d])'
+    date_format = '%Y%m%dT%H%M%S'
+    is_image = True
+    separate_files = True
+    all_bands = ['B02', 'B03', 'B04', 'B08']
+    rgb_bands = ['B04', 'B03', 'B02']
+
+
+# + [markdown] id="a1AlbJp7XUEa"
+# We can now instantiate this class and see if it works correctly.
+
+# + colab={"base_uri": "https://localhost:8080/"} id="NXvg9EL8XZAk" outputId="134235ee-b108-4861-f864-ea3d8960b0ce"
+dataset = Sentinel2(root)
+print(dataset)
+
+
+# + [markdown] id="msbeAkVOX-iJ"
+# As expected, we have a GeoDataset of size 2 because there are 2 scenes in our root data directory.
+
+# + [markdown] id="IUjv7Km7YDpH"
+# ## Plotting
+#
+# A great test to make sure that the dataset works correctly is to try to plot an image. We'll add a plot function to our dataset to help visualize it. First, we need to modify the image so that it only contains the RGB bands, and ensure that they are in the correct order. We also need to ensure that the image is in the range 0.0 to 1.0 (or 0 to 255). Finally, we'll create a plot using matplotlib.
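+
+# + [markdown]
+# A note on the scaling used below (an editorial aside): Sentinel-2 L2A products store surface reflectance multiplied by 10,000, so dividing by 10,000 and clamping to [0, 1] is a reasonable default for visualization. If your plots come out too dark, a percentile stretch is a common alternative, e.g.:
+#
+# ```python
+# lo, hi = torch.quantile(image, torch.tensor([0.02, 0.98]))
+# image = torch.clamp((image - lo) / (hi - lo), min=0, max=1)
+# ```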
+
+# + id="7PNFOy9mYq6K"
+class Sentinel2(RasterDataset):
+    filename_glob = 'T*_B02_10m.tif'
+    filename_regex = r'^.{6}_(?P<date>\d{8}T\d{6})_(?P<band>B0[\d])'
+    date_format = '%Y%m%dT%H%M%S'
+    is_image = True
+    separate_files = True
+    all_bands = ['B02', 'B03', 'B04', 'B08']
+    rgb_bands = ['B04', 'B03', 'B02']
+
+    def plot(self, sample):
+        # Find the correct band index order
+        rgb_indices = []
+        for band in self.rgb_bands:
+            rgb_indices.append(self.all_bands.index(band))
+
+        # Reorder and rescale the image
+        image = sample['image'][rgb_indices].permute(1, 2, 0)
+        image = torch.clamp(image / 10000, min=0, max=1).numpy()
+
+        # Plot the image
+        fig, ax = plt.subplots()
+        ax.imshow(image)
+
+        return fig
+
+
+# + [markdown] id="sF8HBA9gah3z"
+# Let's plot an image to see what it looks like. We'll use `RandomGeoSampler` to load small patches from each image.
+
+# + colab={"base_uri": "https://localhost:8080/", "height": 1000} id="I6lv4YcVbAox" outputId="e6ee643f-66bd-457e-f88c-bbedf092e19d"
+torch.manual_seed(1)
+
+dataset = Sentinel2(root)
+sampler = RandomGeoSampler(dataset, size=4096, length=3)
+dataloader = DataLoader(dataset, sampler=sampler, collate_fn=stack_samples)
+
+for batch in dataloader:
+    sample = unbind_samples(batch)[0]
+    dataset.plot(sample)
+    plt.axis('off')
+    plt.show()
+
+
+# + [markdown] id="ALLYUzhXKkfS"
+# For those who are curious, these are glaciers on Novaya Zemlya, Russia.
+
+# + [markdown] id="R_qrQkBCEvEl"
+# ## Custom parameters
+#
+# If you want to add custom parameters to the class, you can override the `__init__` method. For example, let's say you have imagery that can be automatically downloaded. The `RasterDataset` base class doesn't support this, but you could add support in your subclass. Simply copy the parameters from the base class and add a new `download` parameter.
+
+# + id="TxODAvIHFKNt"
+class Downloadable(RasterDataset):
+    def __init__(self, root, crs, res, transforms, cache, download=False):
+        super().__init__(root, crs, res, transforms, cache)
+
+        if download:
+            # download the dataset
+            ...
+
+# + [markdown] id="cI43f8DMF3iR"
+# ## Contributing
+#
+# TorchGeo is an open source ecosystem built from the contributions of users like you. If your dataset might be useful for other users, please consider contributing it to TorchGeo! You'll need a bit of documentation and some testing before your dataset can be added, but it will be included in the next minor release for all users to enjoy. See the [Contributing](https://torchgeo.readthedocs.io/en/stable/user/contributing.html) guide to get started.
diff --git a/docs/tutorials/getting_started.ipynb b/docs/tutorials/getting_started.ipynb
deleted file mode 100644
index 1b1982711b3..00000000000
--- a/docs/tutorials/getting_started.ipynb
+++ /dev/null
@@ -1,322 +0,0 @@
-{
- "cells": [
-  {
-   "cell_type": "markdown",
-   "id": "35303546",
-   "metadata": {},
-   "source": [
-    "Copyright (c) Microsoft Corporation. All rights reserved.\n",
-    "\n",
-    "Licensed under the MIT License."
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "id": "9478ed9a",
-   "metadata": {
-    "id": "NdrXRgjU7Zih"
-   },
-   "source": [
-    "# Getting Started\n",
-    "\n",
-    "In this tutorial, we demonstrate some of the basic features of TorchGeo and show how easy it is to use if you're already familiar with other PyTorch domain libraries like torchvision.\n",
-    "\n",
-    "It's recommended to run this notebook on Google Colab if you don't have your own GPU. Click the \"Open in Colab\" button above to get started."
- ] - }, - { - "cell_type": "markdown", - "id": "34f10e9f", - "metadata": { - "id": "lCqHTGRYBZcz" - }, - "source": [ - "## Setup\n", - "\n", - "First, we install TorchGeo." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "019092f0", - "metadata": {}, - "outputs": [], - "source": [ - "%pip install torchgeo" - ] - }, - { - "cell_type": "markdown", - "id": "4db9f791", - "metadata": { - "id": "dV0NLHfGBMWl" - }, - "source": [ - "## Imports\n", - "\n", - "Next, we import TorchGeo and any other libraries we need." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "3d92b0f1", - "metadata": { - "id": "entire-albania" - }, - "outputs": [], - "source": [ - "import os\n", - "import tempfile\n", - "\n", - "from torch.utils.data import DataLoader\n", - "\n", - "from torchgeo.datasets import NAIP, ChesapeakeDE, stack_samples\n", - "from torchgeo.datasets.utils import download_url\n", - "from torchgeo.samplers import RandomGeoSampler" - ] - }, - { - "cell_type": "markdown", - "id": "7f26e4b8", - "metadata": { - "id": "5rLknZxrBEMz" - }, - "source": [ - "## Datasets\n", - "\n", - "For this tutorial, we'll be using imagery from the [National Agriculture Imagery Program (NAIP)](https://catalog.data.gov/dataset/national-agriculture-imagery-program-naip) and labels from the [Chesapeake Bay High-Resolution Land Cover Project](https://www.chesapeakeconservancy.org/conservation-innovation-center/high-resolution-data/land-cover-data-project/). First, we manually download a few NAIP tiles and create a PyTorch Dataset." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "4a39af46", - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 232, - "referenced_widgets": [ - "d00a2177bf4b4b8191bfc8796f0e749f", - "17d6b81aec50455989276b595457cc7f", - "06ccd130058b432dbfa025c102eaeb27", - "6bc5b9872b574cb5aa6ebd1d44e7a71f", - "f7746f028f874a85b6101185fc9a8efc", - "f7ef78d6f87a4a2685788e395525fa7c", - "5b2450e316e64b4ba432c78b63275124", - "d3bbd6112f144c77bc68e5f2a7a355ff", - "a0300b1252cd4da5a798b55c15f8f5fd", - "793c2851b6464b398f7b4d2f2f509722", - "8dd61c8479d74c95a55de147e04446b3", - "b57d15e6c32b4fff8994ae67320972f6", - "9a34f8907a264232adf6b0d0543461dd", - "e680eda3c84c440083e2959f04431bea", - "a073e33fd9ae4125822fc17971233770", - "87faaa32454a42939d3bd405e726228c", - "b3d4c9c99bec4e69a199e45920d52ce4", - "a215f3310ea543d1a8991f57ec824872", - "569f60397fd6440d825e8afb83b4e1ae", - "b7f604d2ba4e4328a451725973fa755f", - "737fa148dfae49a18cc0eabbe05f2d0f", - "0b6613adbcc74165a9d9f74988af366e", - "b25f274c737d4212b3ffeedb2372ba22", - "ef0fc75ff5044171be942a6b3ba0c2da", - "612d84013a6e4890a48eb229f6431233", - "9a689285370646ab800155432ea042a5", - "014ed48a23234e8b81dd7ac4dbf95817", - "93c536a27b024728a00486b1f68b4dde", - "8a8538a91a74439b81e3f7c6516763e3", - "caf540562b484594bab8d6210dd7c2c1", - "99cd2e65fb104380953745f2e0a93fac", - "c5b818707bb64c5a865236a46399cea2", - "54f5db9555c44efa9370cbb7ab58e142", - "1d83b20dbb9c4c6a9d5c100fe4770ba4", - "c51b2400ca9442a9a9e0712d5778cd9a", - "bd2e44a8eb1a4c19a32da5a1edd647d1", - "0f9feea4b8344a7f8054c9417150825e", - "31acb7a1ca8940078e1aacd72e547f47", - "0d0ca8d64d3e4c2f88d87342808dd677", - "54402c5f8df34b83b95c94104b26e2c6", - "910b98584fa74bb5ad308fe770f5b40e", - "b2dce834ee044d69858389178b493a2b", - "237f2e31bcfe476baafae8d922877e07", - "43ac7d95481b4ea3866feef6ace2f043" - ] - }, - "id": "e3138ac3", - "outputId": "11589c46-eee6-455d-839b-390f2934d834" - }, - "outputs": [], - "source": [ - 
"naip_root = os.path.join(tempfile.gettempdir(), 'naip')\n", - "naip_url = (\n", - " 'https://naipeuwest.blob.core.windows.net/naip/v002/de/2018/de_060cm_2018/38075/'\n", - ")\n", - "tiles = [\n", - " 'm_3807511_ne_18_060_20181104.tif',\n", - " 'm_3807511_se_18_060_20181104.tif',\n", - " 'm_3807512_nw_18_060_20180815.tif',\n", - " 'm_3807512_sw_18_060_20180815.tif',\n", - "]\n", - "for tile in tiles:\n", - " download_url(naip_url + tile, naip_root)\n", - "\n", - "naip = NAIP(naip_root)" - ] - }, - { - "cell_type": "markdown", - "id": "e25bad40", - "metadata": { - "id": "HQVji2B22Qfu" - }, - "source": [ - "Next, we tell TorchGeo to automatically download the corresponding Chesapeake labels." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "689bb2b0", - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "2Ah34KAw2biY", - "outputId": "03b7bdf0-78c1-4a13-ac56-59de740d7f59" - }, - "outputs": [], - "source": [ - "chesapeake_root = os.path.join(tempfile.gettempdir(), 'chesapeake')\n", - "os.makedirs(chesapeake_root, exist_ok=True)\n", - "chesapeake = ChesapeakeDE(chesapeake_root, crs=naip.crs, res=naip.res, download=True)" - ] - }, - { - "cell_type": "markdown", - "id": "56f2d78b", - "metadata": { - "id": "OWUhlfpD22IX" - }, - "source": [ - "Finally, we create an IntersectionDataset so that we can automatically sample from both GeoDatasets simultaneously." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "daefbc4d", - "metadata": { - "id": "WXxy8F8l2-aC" - }, - "outputs": [], - "source": [ - "dataset = naip & chesapeake" - ] - }, - { - "cell_type": "markdown", - "id": "ded44652", - "metadata": { - "id": "yF_R54Yf3EUd" - }, - "source": [ - "## Sampler\n", - "\n", - "Unlike typical PyTorch Datasets, TorchGeo GeoDatasets are indexed using lat/long/time bounding boxes. This requires us to use a custom GeoSampler instead of the default sampler/batch_sampler that comes with PyTorch." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "b8a0d99c", - "metadata": { - "id": "RLczuU293itT" - }, - "outputs": [], - "source": [ - "sampler = RandomGeoSampler(dataset, size=1000, length=10)" - ] - }, - { - "cell_type": "markdown", - "id": "5b8c1c52", - "metadata": { - "id": "OWa-mmYd8S6K" - }, - "source": [ - "## DataLoader\n", - "\n", - "Now that we have a Dataset and Sampler, we can combine these into a single DataLoader." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "96faa142", - "metadata": { - "id": "jfx-9ZmU8ZTc" - }, - "outputs": [], - "source": [ - "dataloader = DataLoader(dataset, sampler=sampler, collate_fn=stack_samples)" - ] - }, - { - "cell_type": "markdown", - "id": "64ae63f7", - "metadata": { - "id": "HZIfqqW58oZe" - }, - "source": [ - "## Training\n", - "\n", - "Other than that, the rest of the training pipeline is the same as it is for torchvision." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "8a2b44f8", - "metadata": { - "id": "7sGmNvBy8uIg" - }, - "outputs": [], - "source": [ - "for sample in dataloader:\n", - " image = sample['image']\n", - " target = sample['mask']" - ] - } - ], - "metadata": { - "colab": { - "collapsed_sections": [], - "name": "getting_started.ipynb", - "provenance": [] - }, - "execution": { - "timeout": 1200 - }, - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.8" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/docs/tutorials/getting_started.py b/docs/tutorials/getting_started.py new file mode 100644 index 00000000000..947e857ffc3 --- /dev/null +++ b/docs/tutorials/getting_started.py @@ -0,0 +1,108 @@ +# --- +# jupyter: +# jupytext: +# text_representation: +# extension: .py +# format_name: light +# format_version: '1.5' +# jupytext_version: 1.16.0 +# kernelspec: +# display_name: Python 3 (ipykernel) +# language: python +# name: python3 +# --- + +# Copyright (c) Microsoft Corporation. All rights reserved. +# +# Licensed under the MIT License. + +# + [markdown] id="NdrXRgjU7Zih" +# # Getting Started +# +# In this tutorial, we demonstrate some of the basic features of TorchGeo and show how easy it is to use if you're already familiar with other PyTorch domain libraries like torchvision. +# +# It's recommended to run this notebook on Google Colab if you don't have your own GPU. Click the "Open in Colab" button above to get started. + +# + [markdown] id="lCqHTGRYBZcz" +# ## Setup +# +# First, we install TorchGeo. +# - + +# %pip install torchgeo + +# + [markdown] id="dV0NLHfGBMWl" +# ## Imports +# +# Next, we import TorchGeo and any other libraries we need. + +# + id="entire-albania" +import os +import tempfile + +from torch.utils.data import DataLoader + +from torchgeo.datasets import NAIP, ChesapeakeDE, stack_samples +from torchgeo.datasets.utils import download_url +from torchgeo.samplers import RandomGeoSampler + +# + [markdown] id="5rLknZxrBEMz" +# ## Datasets +# +# For this tutorial, we'll be using imagery from the [National Agriculture Imagery Program (NAIP)](https://catalog.data.gov/dataset/national-agriculture-imagery-program-naip) and labels from the [Chesapeake Bay High-Resolution Land Cover Project](https://www.chesapeakeconservancy.org/conservation-innovation-center/high-resolution-data/land-cover-data-project/). First, we manually download a few NAIP tiles and create a PyTorch Dataset. 
+ +# + colab={"base_uri": "https://localhost:8080/", "height": 232, "referenced_widgets": ["d00a2177bf4b4b8191bfc8796f0e749f", "17d6b81aec50455989276b595457cc7f", "06ccd130058b432dbfa025c102eaeb27", "6bc5b9872b574cb5aa6ebd1d44e7a71f", "f7746f028f874a85b6101185fc9a8efc", "f7ef78d6f87a4a2685788e395525fa7c", "5b2450e316e64b4ba432c78b63275124", "d3bbd6112f144c77bc68e5f2a7a355ff", "a0300b1252cd4da5a798b55c15f8f5fd", "793c2851b6464b398f7b4d2f2f509722", "8dd61c8479d74c95a55de147e04446b3", "b57d15e6c32b4fff8994ae67320972f6", "9a34f8907a264232adf6b0d0543461dd", "e680eda3c84c440083e2959f04431bea", "a073e33fd9ae4125822fc17971233770", "87faaa32454a42939d3bd405e726228c", "b3d4c9c99bec4e69a199e45920d52ce4", "a215f3310ea543d1a8991f57ec824872", "569f60397fd6440d825e8afb83b4e1ae", "b7f604d2ba4e4328a451725973fa755f", "737fa148dfae49a18cc0eabbe05f2d0f", "0b6613adbcc74165a9d9f74988af366e", "b25f274c737d4212b3ffeedb2372ba22", "ef0fc75ff5044171be942a6b3ba0c2da", "612d84013a6e4890a48eb229f6431233", "9a689285370646ab800155432ea042a5", "014ed48a23234e8b81dd7ac4dbf95817", "93c536a27b024728a00486b1f68b4dde", "8a8538a91a74439b81e3f7c6516763e3", "caf540562b484594bab8d6210dd7c2c1", "99cd2e65fb104380953745f2e0a93fac", "c5b818707bb64c5a865236a46399cea2", "54f5db9555c44efa9370cbb7ab58e142", "1d83b20dbb9c4c6a9d5c100fe4770ba4", "c51b2400ca9442a9a9e0712d5778cd9a", "bd2e44a8eb1a4c19a32da5a1edd647d1", "0f9feea4b8344a7f8054c9417150825e", "31acb7a1ca8940078e1aacd72e547f47", "0d0ca8d64d3e4c2f88d87342808dd677", "54402c5f8df34b83b95c94104b26e2c6", "910b98584fa74bb5ad308fe770f5b40e", "b2dce834ee044d69858389178b493a2b", "237f2e31bcfe476baafae8d922877e07", "43ac7d95481b4ea3866feef6ace2f043"]} id="e3138ac3" outputId="11589c46-eee6-455d-839b-390f2934d834" +naip_root = os.path.join(tempfile.gettempdir(), 'naip') +naip_url = ( + 'https://naipeuwest.blob.core.windows.net/naip/v002/de/2018/de_060cm_2018/38075/' +) +tiles = [ + 'm_3807511_ne_18_060_20181104.tif', + 'm_3807511_se_18_060_20181104.tif', + 'm_3807512_nw_18_060_20180815.tif', + 'm_3807512_sw_18_060_20180815.tif', +] +for tile in tiles: + download_url(naip_url + tile, naip_root) + +naip = NAIP(naip_root) + +# + [markdown] id="HQVji2B22Qfu" +# Next, we tell TorchGeo to automatically download the corresponding Chesapeake labels. + +# + colab={"base_uri": "https://localhost:8080/"} id="2Ah34KAw2biY" outputId="03b7bdf0-78c1-4a13-ac56-59de740d7f59" +chesapeake_root = os.path.join(tempfile.gettempdir(), 'chesapeake') +os.makedirs(chesapeake_root, exist_ok=True) +chesapeake = ChesapeakeDE(chesapeake_root, crs=naip.crs, res=naip.res, download=True) + +# + [markdown] id="OWUhlfpD22IX" +# Finally, we create an IntersectionDataset so that we can automatically sample from both GeoDatasets simultaneously. + +# + id="WXxy8F8l2-aC" +dataset = naip & chesapeake + +# + [markdown] id="yF_R54Yf3EUd" +# ## Sampler +# +# Unlike typical PyTorch Datasets, TorchGeo GeoDatasets are indexed using lat/long/time bounding boxes. This requires us to use a custom GeoSampler instead of the default sampler/batch_sampler that comes with PyTorch. + +# + id="RLczuU293itT" +sampler = RandomGeoSampler(dataset, size=1000, length=10) + +# + [markdown] id="OWa-mmYd8S6K" +# ## DataLoader +# +# Now that we have a Dataset and Sampler, we can combine these into a single DataLoader. + +# + id="jfx-9ZmU8ZTc" +dataloader = DataLoader(dataset, sampler=sampler, collate_fn=stack_samples) + +# + [markdown] id="HZIfqqW58oZe" +# ## Training +# +# Other than that, the rest of the training pipeline is the same as it is for torchvision. 
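+
+# + [markdown]
+# For example, a bare-bones segmentation training step over these samples could look like the sketch below (an illustrative sketch, not part of this tutorial's code: the 1x1 convolution stands in for a real segmentation model, and `num_classes` is an assumption you should match to your mask):
+#
+# ```python
+# import torch
+# from torch import nn
+#
+# num_classes = 14  # assumption: set this to the number of classes in your mask
+# model = nn.Conv2d(4, num_classes, kernel_size=1)  # NAIP imagery has 4 bands (R, G, B, NIR)
+# loss_fn = nn.CrossEntropyLoss()
+# optimizer = torch.optim.Adam(model.parameters())
+#
+# for sample in dataloader:
+#     optimizer.zero_grad()
+#     prediction = model(sample['image'])
+#     target = sample['mask'].long()
+#     if target.ndim == 4:  # squeeze a (B, 1, H, W) mask to (B, H, W)
+#         target = target.squeeze(1)
+#     loss = loss_fn(prediction, target)
+#     loss.backward()
+#     optimizer.step()
+# ```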
+ +# + id="7sGmNvBy8uIg" +for sample in dataloader: + image = sample['image'] + target = sample['mask'] diff --git a/docs/tutorials/indices.ipynb b/docs/tutorials/indices.ipynb deleted file mode 100644 index 30576609ac8..00000000000 --- a/docs/tutorials/indices.ipynb +++ /dev/null @@ -1,382 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "DYndcZst_kdr" - }, - "source": [ - "Copyright (c) Microsoft Corporation. All rights reserved.\n", - "\n", - "Licensed under the MIT License." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "ZKIkyiLScf9P" - }, - "source": [ - "# Indices" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "PevsPoE4cY0j" - }, - "source": [ - "In this tutorial, we demonstrate how to use TorchGeo's functions and transforms to compute popular indices used in remote sensing and provide examples of how to utilize them for analyzing raw imagery or simply for visualization purposes. Some common indices and their formulas can be found at the following links:\n", - "\n", - "- [Index Database](https://www.indexdatabase.de/db/i.php)\n", - "- [Awesome Spectral Indices](https://github.com/awesome-spectral-indices/awesome-spectral-indices)\n", - "\n", - "It's recommended to run this notebook on Google Colab if you don't have your own GPU. Click the \"Open in Colab\" button above to get started." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "fsOYw-p2ccka" - }, - "source": [ - "## Setup" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "VqdMMzvacOF8" - }, - "source": [ - "Install TorchGeo" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "wOwsb8KT_uXR", - "outputId": "7a7ca2ff-a9a5-444f-99c8-61954b82c797" - }, - "outputs": [], - "source": [ - "%pip install torchgeo" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "u2f5_f4X_-vV" - }, - "source": [ - "## Imports" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "cvPMr76K_9uk" - }, - "outputs": [], - "source": [ - "%matplotlib inline\n", - "import os\n", - "import tempfile\n", - "\n", - "import matplotlib.pyplot as plt\n", - "\n", - "from torchgeo.datasets import EuroSAT100\n", - "from torchgeo.transforms import AppendNDBI, AppendNDVI, AppendNDWI" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Uh2IpthodK1R" - }, - "source": [ - "## Dataset" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "vLW4KhtEwx-e" - }, - "source": [ - "We will use the [EuroSAT](https://torchgeo.readthedocs.io/en/stable/api/datasets.html#eurosat) dataset throughout this tutorial. Specifically, a subset containing only 100 images. EuroSAT contains 13-channel multispectral imagery captured by the Sentinel-2 satellite." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 234, - "referenced_widgets": [ - "3565053e649f436fbe534d9a961b0b43", - "d87a5695377b495db20c9bab60a6965b", - "aebba2cac4084f8dbbaf171999eb6b1b", - "3dfed74698344f2ba749793161df3bf1", - "dbbcfb0780d94988a735bc2b5ba9d7b4", - "22f2f27e72ac48508afb133ecbabb860", - "bead2be5da244a328e82e1275dc714a2", - "42ca3d48adf4444f92ef62f3e724ee5f", - "541a3d7fea7f415e9150b02b4e64b759", - "c44913e195964ebc8e548a795443c8eb", - "38a6b50cd86f4f999afb343e8107155e", - "a94ca05f055c48d9a1ee2f2afee27825", - "afec7ff1326e45cc83037c2665f09d10", - "540f46c4200041cc82ef0341738ea736", - "4696d711c24c45a085340267e8519912", - "3d6fa143130e4c9eba5a0cdb1340e0c9", - "7d3f018af0c54609bdbe8017466c56ea", - "b33e1064e29047d3bd104dc9a4c9829e", - "e76359b5fb734b9ab21e1b549ac7af6d", - "7e6907e9aa8b46f99f58ccddc0287688", - "0f7bfe118f0c4f16a79bf255c9de6175", - "ea69ba4b5a42444c9958f00fae2c8d97", - "c98229f8b2e04823bf0746851c671f8a", - "fd9a29bed92648dda83878870347e1ec", - "4b7c7982ff1440919f1e59ab79ed4fb8", - "39b8bdc4e9714627bde2d9e4d8cad0eb", - "1cf14ec4410e458aa4eb2401b0b3db1d", - "3116ea407181459e95d95f49a4d558bb", - "87759c29fefb41939fae1d6d2f32b86b", - "b789b5d94c1b4495b3d2d8ee19e386ce", - "8186ff9b662c49529ff5cb630adb0304", - "571bd10adb7c4089ac67ebaffb39ed87", - "02fb96e5de6945c0bdbde3f05834a4ce", - "cf98bf5ee4c14e47911883e9c6f3aef5", - "98fbbf1bb5ed421d810b2bbabd4bcc1c", - "b9a81dcbbe1c4a57a3a4972d4e04be39", - "e22d78e388ad4a149e3d626e21359c36", - "e5142289278548f5aeb7e82dfc89dfd8", - "c0e3373d10e849b69d71be2f684178c9", - "3dbdfa7db0e9464689beab7ec9236ecc", - "51af8f9dffcf4eceaf70af2268b95ba3", - "9d47f8cc53cc4eed9529dabb52bd5d8d", - "4d492d6b56d04e7daff2b35fa1bbfa32", - "4237c5f8175d44d5a05940a8f05b4e90" - ] - }, - "id": "_seqhOz-Cw9c", - "outputId": "99187c2c-9f4f-4ca5-ec10-6635feaaf064" - }, - "outputs": [], - "source": [ - "root = os.path.join(tempfile.gettempdir(), 'eurosat100')\n", - "ds = EuroSAT100(root, download=True)\n", - "sample = ds[21]" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "DLdI6L4A08vu" - }, - "source": [ - "## True Color (RGB) Image" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "nkb51TZpPylY" - }, - "source": [ - "We can plot a true color image consisting of the first 3 channels (RGB) to visualize the sample." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 264 - }, - "id": "tymUtqGAQCQl", - "outputId": "a5df2986-cc56-410a-ff73-4fd7caec2b80" - }, - "outputs": [], - "source": [ - "ds.plot(sample)\n", - "plt.show()\n", - "plt.close()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "jWmSuZfI1gwN" - }, - "source": [ - "## Normalized Difference Vegetation Index (NDVI)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "yGTZzk6_QH1a" - }, - "source": [ - "Below we use TorchGeo's `indices.AppendNDVI` to compute the [Normalized Difference Vegetation Index (NDVI)](https://gisgeography.com/ndvi-normalized-difference-vegetation-index/). NDVI is useful for measuring the presence of vegetation and vegetation health. It can be calculated using the Near Infrared (NIR) and Red bands using the formula below, resulting in a value between [-1, 1] where low NDVI values represents no or unhealthy vegetation and high NDVI values represents healthy vegetation. 
Here we use a diverging red, yellow, green colormap representing -1, 0, and 1, respectively.\n", - "\n", - "$$\\text{NDVI} = \\frac{\\text{NIR} - \\text{R}}{\\text{NIR} + \\text{R}}$$" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 248 - }, - "id": "e9Aob95YQCQn", - "outputId": "0f12ca7e-d8ed-4327-8a41-0674d85e1df1" - }, - "outputs": [], - "source": [ - "# NDVI is appended to channel dimension (dim=0)\n", - "index = AppendNDVI(index_nir=7, index_red=3)\n", - "image = sample['image']\n", - "image = index(image)[0]\n", - "\n", - "# Normalize from [-1, 1] -> [0, 1] for visualization\n", - "image[-1] = (image[-1] + 1) / 2\n", - "\n", - "plt.imshow(image[-1], cmap='RdYlGn')\n", - "plt.axis('off')\n", - "plt.show()\n", - "plt.close()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "igNq-9m91nDt" - }, - "source": [ - "## Normalized Difference Water Index (NDWI)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "fDFQm4vOZEUH" - }, - "source": [ - "Below we use TorchGeo's `indices.AppendNDWI` to compute the [Normalized Difference Water Index (NDWI)](https://custom-scripts.sentinel-hub.com/custom-scripts/sentinel-2/ndwi/). NDWI is useful for measuring the presence of water content in water bodies. It can be calculated using the Green and Near Infrared (NIR) bands using the formula below, resulting in a value between [-1, 1] where low NDWI values represents no water and high NDWI values represents water bodies. Here we use a diverging brown, white, blue-green colormap representing -1, 0, and 1, respectively.\n", - "\n", - "$$\\text{NDWI} = \\frac{\\text{G} - \\text{NIR}}{\\text{G} + \\text{NIR}}$$" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 248 - }, - "id": "H8CPnPD9QCQp", - "outputId": "81b48a45-ab72-496b-cc4c-283b8a239396" - }, - "outputs": [], - "source": [ - "# NDWI is appended to channel dimension (dim=0)\n", - "index = AppendNDWI(index_green=2, index_nir=7)\n", - "image = index(image)[0]\n", - "\n", - "# Normalize from [-1, 1] -> [0, 1] for visualization\n", - "image[-1] = (image[-1] + 1) / 2\n", - "\n", - "plt.imshow(image[-1], cmap='BrBG')\n", - "plt.axis('off')\n", - "plt.show()\n", - "plt.close()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "oKEyz9TP2OK_" - }, - "source": [ - "## Normalized Difference Built-up Index (NDBI)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "ap3adJ06bXul" - }, - "source": [ - "Below we use TorchGeo's `indices.AppendNDBI` to compute the [Normalized Difference Built-up Index (NDBI)](https://www.linkedin.com/pulse/ndvi-ndbi-ndwi-calculation-using-landsat-7-8-tek-bahadur-kshetri/). NDBI is useful for measuring the presence of urban buildings. It can be calculated using the Short-wave Infrared (SWIR) and Near Infrared (NIR) bands using the formula below, resulting in a value between [-1, 1] where low NDBI values represents no urban land and high NDBI values represents urban land. 
Here we use a terrain colormap with blue, green-yellow, and brown representing -1, 0, and 1, respectively.\n", - "\n", - "$$\\text{NDBI} = \\frac{\\text{SWIR} - \\text{NIR}}{\\text{SWIR} + \\text{NIR}}$$" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 248 - }, - "id": "F607CaDoQCQq", - "outputId": "f18b2337-05e1-48d7-bc50-d13c536e0177" - }, - "outputs": [], - "source": [ - "# NDBI is appended to channel dimension (dim=0)\n", - "index = AppendNDBI(index_swir=11, index_nir=7)\n", - "image = index(image)[0]\n", - "\n", - "# Normalize from [-1, 1] -> [0, 1] for visualization\n", - "image[-1] = (image[-1] + 1) / 2\n", - "\n", - "plt.imshow(image[-1], cmap='terrain')\n", - "plt.axis('off')\n", - "plt.show()\n", - "plt.close()" - ] - } - ], - "metadata": { - "colab": { - "provenance": [] - }, - "execution": { - "timeout": 1200 - }, - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.8" - } - }, - "nbformat": 4, - "nbformat_minor": 0 -} diff --git a/docs/tutorials/indices.py b/docs/tutorials/indices.py new file mode 100644 index 00000000000..cd2a7668dd9 --- /dev/null +++ b/docs/tutorials/indices.py @@ -0,0 +1,137 @@ +# --- +# jupyter: +# jupytext: +# text_representation: +# extension: .py +# format_name: light +# format_version: '1.5' +# jupytext_version: 1.16.0 +# kernelspec: +# display_name: Python 3 (ipykernel) +# language: python +# name: python3 +# --- + +# + [markdown] id="DYndcZst_kdr" +# Copyright (c) Microsoft Corporation. All rights reserved. +# +# Licensed under the MIT License. + +# + [markdown] id="ZKIkyiLScf9P" +# # Indices + +# + [markdown] id="PevsPoE4cY0j" +# In this tutorial, we demonstrate how to use TorchGeo's functions and transforms to compute popular indices used in remote sensing and provide examples of how to utilize them for analyzing raw imagery or simply for visualization purposes. Some common indices and their formulas can be found at the following links: +# +# - [Index Database](https://www.indexdatabase.de/db/i.php) +# - [Awesome Spectral Indices](https://github.com/awesome-spectral-indices/awesome-spectral-indices) +# +# It's recommended to run this notebook on Google Colab if you don't have your own GPU. Click the "Open in Colab" button above to get started. + +# + [markdown] id="fsOYw-p2ccka" +# ## Setup + +# + [markdown] id="VqdMMzvacOF8" +# Install TorchGeo + +# + colab={"base_uri": "https://localhost:8080/"} id="wOwsb8KT_uXR" outputId="7a7ca2ff-a9a5-444f-99c8-61954b82c797" +# %pip install torchgeo + +# + [markdown] id="u2f5_f4X_-vV" +# ## Imports + +# + id="cvPMr76K_9uk" +# %matplotlib inline +import os +import tempfile + +import matplotlib.pyplot as plt + +from torchgeo.datasets import EuroSAT100 +from torchgeo.transforms import AppendNDBI, AppendNDVI, AppendNDWI + +# + [markdown] id="Uh2IpthodK1R" +# ## Dataset + +# + [markdown] id="vLW4KhtEwx-e" +# We will use the [EuroSAT](https://torchgeo.readthedocs.io/en/stable/api/datasets.html#eurosat) dataset throughout this tutorial. Specifically, a subset containing only 100 images. EuroSAT contains 13-channel multispectral imagery captured by the Sentinel-2 satellite. 
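+
+# + [markdown]
+# Each index computed in this tutorial (NDVI, NDWI, NDBI) is a normalized difference of two of these 13 bands. As a rough sketch of the shared computation (the helper name and epsilon below are illustrative, not part of TorchGeo's API; the actual transforms also append the result as a new channel), the core operation is:

+# +
+def normalized_difference(band_a, band_b, epsilon=1e-10):
+    """Illustrative normalized-difference index: (A - B) / (A + B).
+
+    The small epsilon guards against division by zero when both bands are zero.
+    """
+    return (band_a - band_b) / (band_a + band_b + epsilon)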
+
+# + colab={"base_uri": "https://localhost:8080/", "height": 234, "referenced_widgets": ["3565053e649f436fbe534d9a961b0b43", "d87a5695377b495db20c9bab60a6965b", "aebba2cac4084f8dbbaf171999eb6b1b", "3dfed74698344f2ba749793161df3bf1", "dbbcfb0780d94988a735bc2b5ba9d7b4", "22f2f27e72ac48508afb133ecbabb860", "bead2be5da244a328e82e1275dc714a2", "42ca3d48adf4444f92ef62f3e724ee5f", "541a3d7fea7f415e9150b02b4e64b759", "c44913e195964ebc8e548a795443c8eb", "38a6b50cd86f4f999afb343e8107155e", "a94ca05f055c48d9a1ee2f2afee27825", "afec7ff1326e45cc83037c2665f09d10", "540f46c4200041cc82ef0341738ea736", "4696d711c24c45a085340267e8519912", "3d6fa143130e4c9eba5a0cdb1340e0c9", "7d3f018af0c54609bdbe8017466c56ea", "b33e1064e29047d3bd104dc9a4c9829e", "e76359b5fb734b9ab21e1b549ac7af6d", "7e6907e9aa8b46f99f58ccddc0287688", "0f7bfe118f0c4f16a79bf255c9de6175", "ea69ba4b5a42444c9958f00fae2c8d97", "c98229f8b2e04823bf0746851c671f8a", "fd9a29bed92648dda83878870347e1ec", "4b7c7982ff1440919f1e59ab79ed4fb8", "39b8bdc4e9714627bde2d9e4d8cad0eb", "1cf14ec4410e458aa4eb2401b0b3db1d", "3116ea407181459e95d95f49a4d558bb", "87759c29fefb41939fae1d6d2f32b86b", "b789b5d94c1b4495b3d2d8ee19e386ce", "8186ff9b662c49529ff5cb630adb0304", "571bd10adb7c4089ac67ebaffb39ed87", "02fb96e5de6945c0bdbde3f05834a4ce", "cf98bf5ee4c14e47911883e9c6f3aef5", "98fbbf1bb5ed421d810b2bbabd4bcc1c", "b9a81dcbbe1c4a57a3a4972d4e04be39", "e22d78e388ad4a149e3d626e21359c36", "e5142289278548f5aeb7e82dfc89dfd8", "c0e3373d10e849b69d71be2f684178c9", "3dbdfa7db0e9464689beab7ec9236ecc", "51af8f9dffcf4eceaf70af2268b95ba3", "9d47f8cc53cc4eed9529dabb52bd5d8d", "4d492d6b56d04e7daff2b35fa1bbfa32", "4237c5f8175d44d5a05940a8f05b4e90"]} id="_seqhOz-Cw9c" outputId="99187c2c-9f4f-4ca5-ec10-6635feaaf064"
+root = os.path.join(tempfile.gettempdir(), 'eurosat100')
+ds = EuroSAT100(root, download=True)
+sample = ds[21]
+
+# + [markdown] id="DLdI6L4A08vu"
+# ## True Color (RGB) Image
+
+# + [markdown] id="nkb51TZpPylY"
+# We can plot a true color image consisting of the first 3 channels (RGB) to visualize the sample.
+
+# + colab={"base_uri": "https://localhost:8080/", "height": 264} id="tymUtqGAQCQl" outputId="a5df2986-cc56-410a-ff73-4fd7caec2b80"
+ds.plot(sample)
+plt.show()
+plt.close()
+
+# + [markdown] id="jWmSuZfI1gwN"
+# ## Normalized Difference Vegetation Index (NDVI)
+
+# + [markdown] id="yGTZzk6_QH1a"
+# Below we use TorchGeo's `indices.AppendNDVI` to compute the [Normalized Difference Vegetation Index (NDVI)](https://gisgeography.com/ndvi-normalized-difference-vegetation-index/). NDVI is useful for measuring the presence of vegetation and vegetation health. It can be calculated from the Near Infrared (NIR) and Red bands using the formula below, resulting in a value between [-1, 1] where low NDVI values represent no or unhealthy vegetation and high NDVI values represent healthy vegetation. Here we use a diverging red, yellow, green colormap representing -1, 0, and 1, respectively.
+#
+# $$\text{NDVI} = \frac{\text{NIR} - \text{R}}{\text{NIR} + \text{R}}$$
+
+# + colab={"base_uri": "https://localhost:8080/", "height": 248} id="e9Aob95YQCQn" outputId="0f12ca7e-d8ed-4327-8a41-0674d85e1df1"
+# NDVI is appended to channel dimension (dim=0)
+index = AppendNDVI(index_nir=7, index_red=3)
+image = sample['image']
+image = index(image)[0]
+
+# Normalize from [-1, 1] -> [0, 1] for visualization
+image[-1] = (image[-1] + 1) / 2
+
+plt.imshow(image[-1], cmap='RdYlGn')
+plt.axis('off')
+plt.show()
+plt.close()
+
+# + [markdown] id="igNq-9m91nDt"
+# ## Normalized Difference Water Index (NDWI)
+
+# + [markdown] id="fDFQm4vOZEUH"
+# Below we use TorchGeo's `indices.AppendNDWI` to compute the [Normalized Difference Water Index (NDWI)](https://custom-scripts.sentinel-hub.com/custom-scripts/sentinel-2/ndwi/). NDWI is useful for measuring the presence of water in water bodies. It can be calculated from the Green and Near Infrared (NIR) bands using the formula below, resulting in a value between [-1, 1] where low NDWI values represent no water and high NDWI values represent water bodies. Here we use a diverging brown, white, blue-green colormap representing -1, 0, and 1, respectively.
+#
+# $$\text{NDWI} = \frac{\text{G} - \text{NIR}}{\text{G} + \text{NIR}}$$
+
+# + colab={"base_uri": "https://localhost:8080/", "height": 248} id="H8CPnPD9QCQp" outputId="81b48a45-ab72-496b-cc4c-283b8a239396"
+# NDWI is appended to channel dimension (dim=0)
+index = AppendNDWI(index_green=2, index_nir=7)
+image = index(image)[0]
+
+# Normalize from [-1, 1] -> [0, 1] for visualization
+image[-1] = (image[-1] + 1) / 2
+
+plt.imshow(image[-1], cmap='BrBG')
+plt.axis('off')
+plt.show()
+plt.close()
+
+# + [markdown] id="oKEyz9TP2OK_"
+# ## Normalized Difference Built-up Index (NDBI)
+
+# + [markdown] id="ap3adJ06bXul"
+# Below we use TorchGeo's `indices.AppendNDBI` to compute the [Normalized Difference Built-up Index (NDBI)](https://www.linkedin.com/pulse/ndvi-ndbi-ndwi-calculation-using-landsat-7-8-tek-bahadur-kshetri/). NDBI is useful for measuring the presence of urban buildings. It can be calculated from the Short-wave Infrared (SWIR) and Near Infrared (NIR) bands using the formula below, resulting in a value between [-1, 1] where low NDBI values represent no urban land and high NDBI values represent urban land. Here we use a terrain colormap with blue, green-yellow, and brown representing -1, 0, and 1, respectively.
+#
+# $$\text{NDBI} = \frac{\text{SWIR} - \text{NIR}}{\text{SWIR} + \text{NIR}}$$
+
+# + colab={"base_uri": "https://localhost:8080/", "height": 248} id="F607CaDoQCQq" outputId="f18b2337-05e1-48d7-bc50-d13c536e0177"
+# NDBI is appended to channel dimension (dim=0)
+index = AppendNDBI(index_swir=11, index_nir=7)
+image = index(image)[0]
+
+# Normalize from [-1, 1] -> [0, 1] for visualization
+image[-1] = (image[-1] + 1) / 2
+
+plt.imshow(image[-1], cmap='terrain')
+plt.axis('off')
+plt.show()
+plt.close()
diff --git a/docs/tutorials/pretrained_weights.ipynb b/docs/tutorials/pretrained_weights.ipynb
deleted file mode 100644
index 28b354efee6..00000000000
--- a/docs/tutorials/pretrained_weights.ipynb
+++ /dev/null
@@ -1,508 +0,0 @@
-{
- "cells": [
-  {
-   "cell_type": "markdown",
-   "metadata": {
-    "id": "p63J-QmUrMN-"
-   },
-   "source": [
-    "Copyright (c) Microsoft Corporation. All rights reserved.\n",
-    "\n",
-    "Licensed under the MIT License."
- ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "XRSkMFqyrMOE" - }, - "source": [ - "# Pretrained Weights\n", - "\n", - "In this tutorial, we demonstrate some available pretrained weights in TorchGeo. The implementation follows torchvisions' recently introduced [Multi-Weight API](https://pytorch.org/blog/introducing-torchvision-new-multi-weight-support-api/). We will use the [EuroSAT](https://torchgeo.readthedocs.io/en/stable/api/datasets.html#eurosat) dataset throughout this tutorial. Specifically, a subset containing only 100 images.\n", - "\n", - "It's recommended to run this notebook on Google Colab if you don't have your own GPU. Click the \"Open in Colab\" button above to get started." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "NBa5RPAirMOF" - }, - "source": [ - "## Setup\n", - "\n", - "First, we install TorchGeo." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "5AIQ1B9DrMOG", - "outputId": "6bf360ea-8f60-45cf-c96e-0eac54818079" - }, - "outputs": [], - "source": [ - "%pip install torchgeo" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "IcCOnzVLrMOI" - }, - "source": [ - "## Imports\n", - "\n", - "Next, we import TorchGeo and any other libraries we need." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "rjEGiiurrMOI" - }, - "outputs": [], - "source": [ - "%matplotlib inline\n", - "\n", - "import os\n", - "import tempfile\n", - "\n", - "import timm\n", - "import torch\n", - "from lightning.pytorch import Trainer\n", - "\n", - "from torchgeo.datamodules import EuroSAT100DataModule\n", - "from torchgeo.models import ResNet18_Weights\n", - "from torchgeo.trainers import ClassificationTask" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "njAH71F3rMOJ" - }, - "source": [ - "The following variables can be used to control training." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "TVG_Z9MKrMOJ", - "nbmake": { - "mock": { - "batch_size": 1, - "fast_dev_run": true, - "max_epochs": 1, - "num_workers": 0 - } - } - }, - "outputs": [], - "source": [ - "batch_size = 10\n", - "num_workers = 2\n", - "max_epochs = 10\n", - "fast_dev_run = False" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "QNnDoIf2rMOK" - }, - "source": [ - "## Datamodule\n", - "\n", - "We will utilize TorchGeo's [Lightning](https://lightning.ai/docs/pytorch/stable/) datamodules to organize the dataloader setup." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "ia5ktOVerMOL" - }, - "outputs": [], - "source": [ - "root = os.path.join(tempfile.gettempdir(), 'eurosat100')\n", - "datamodule = EuroSAT100DataModule(\n", - " root=root, batch_size=batch_size, num_workers=num_workers, download=True\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "ksoszRjZrMOL" - }, - "source": [ - "## Weights\n", - "\n", - "Available pretrained weights are listed on the model documentation [page](https://torchgeo.readthedocs.io/en/stable/api/models.html). 
While some weights only accept RGB channel input, some weights have been pretrained on Sentinel 2 imagery with 13 input channels and can hence prove useful for transfer learning tasks involving Sentinel 2 data.\n", - "\n", - "To access these weights you can do the following:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "wJOrRqBGrMOM" - }, - "outputs": [], - "source": [ - "weights = ResNet18_Weights.SENTINEL2_ALL_MOCO" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "EIpnXuXgrMOM" - }, - "source": [ - "This set of weights is a torchvision `WeightEnum` and holds information such as the download url link or additional meta data. TorchGeo takes care of the downloading and initialization of models with a desired set of weights. Given that EuroSAT is a classification dataset, we can use a `ClassificationTask` object that holds the model and optimizer object as well as the training logic." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 86, - "referenced_widgets": [ - "8fb7022c6e4947d8955ed6da0d89ef05", - "c5adf2ec163a43f08b3de98d0eabf8df", - "4c6928b34aef4e778c837993c4197bcd", - "ec96c2767d30412c8af5306a5c2f5ee3", - "cbfb522eafbd4d4991ca5f45ad32bd2c", - "5d02318e1c9e4034bd686527cdbb18ef", - "739421651fb84d31a7baeacef2f8226c", - "8d83a7dad192492facd4b3e03cbb2392", - "fc423924e14a49b18b99064ab45e3f1e", - "17dd6c0bbbf947b1a47a9cb8267d43ef", - "8704c42feea442abb67cfd55ae3c4fa9" - ] - }, - "id": "RZ8MPYH1rMON", - "outputId": "fa683b8f-da21-4f26-ca3a-46163c9f12bf" - }, - "outputs": [], - "source": [ - "task = ClassificationTask(\n", - " model='resnet18',\n", - " loss='ce',\n", - " weights=weights,\n", - " in_channels=13,\n", - " num_classes=10,\n", - " lr=0.001,\n", - " patience=5,\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "dWidC6vDrMON" - }, - "source": [ - "If you do not want to utilize the `ClassificationTask` functionality for your experiments, you can also just create a [timm](https://github.com/rwightman/pytorch-image-models) model with pretrained weights from TorchGeo as follows:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "ZaZQ07jorMOO" - }, - "outputs": [], - "source": [ - "in_chans = weights.meta['in_chans']\n", - "model = timm.create_model('resnet18', in_chans=in_chans, num_classes=10)\n", - "model.load_state_dict(weights.get_state_dict(progress=True), strict=False)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "vgNswWKOrMOO" - }, - "source": [ - "## Training\n", - "\n", - "To train our pretrained model on the EuroSAT dataset we will make use of Lightning's [Trainer](https://lightning.ai/docs/pytorch/stable/common/trainer.html). For a more elaborate explanation of how TorchGeo uses Lightning, check out [this tutorial](https://torchgeo.readthedocs.io/en/stable/tutorials/trainers.html)." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "0Sf-CBorrMOO" - }, - "outputs": [], - "source": [ - "accelerator = 'gpu' if torch.cuda.is_available() else 'cpu'\n", - "default_root_dir = os.path.join(tempfile.gettempdir(), 'experiments')" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "veVvF-5LrMOP", - "outputId": "698e3e9e-8a53-4897-d40e-13b43470e29e" - }, - "outputs": [], - "source": [ - "trainer = Trainer(\n", - " accelerator=accelerator,\n", - " default_root_dir=default_root_dir,\n", - " fast_dev_run=fast_dev_run,\n", - " log_every_n_steps=1,\n", - " min_epochs=1,\n", - " max_epochs=max_epochs,\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 544, - "referenced_widgets": [ - "3435cd7003324d5dbee2222c1d919595", - "0a2c8c97b5df4d6c9388ed9637c19e8f", - "4f5209cdf8e346368da573e818ac8b98", - "8d9a2d2377b74bb397cadf9af65849ae", - "08871807406540fba8a935793cd28f93", - "cc743f08d18f4302b7c3fc8506c4c9d4", - "287c893f66294687910eccac3bb6353a", - "088567dd04944a7682b2add8353a4076", - "3af55148c50045e5b043485200ba7547", - "b68eec0296fc45458b5d4cd2111afe6e", - "b3423fc7e10a4cf3a16d6bc08f99dd17", - "9af62a957121474d8d8b2be00a083287", - "c728392f8d1f4da79492fb65fa10fbbe", - "01ac3f0b67844950a6d1e5fe44978d1b", - "d0140e681f9a437084b496927244f8ac", - "cbfee603fbad42daacef920f3acc6a22", - "cf52dda9f70249b99c893ff886a96c5b", - "9800bfe6afbb4185b25a81017cabe720", - "10538d0692ee42648400a204abb788dc", - "d6169eaaff4b4918acc91f88ce82b14b", - "71ca3d663ae8402497c9f6bb507e5e1b", - "2dfc6c6897d0449d9078ed2fd7e35e4b", - "8020d3f767d849cdbe2126c50dc546c8", - "198bcafb3e4146a984e26260dc0ca2fa", - "e774bef4588241e8bfd2a46a3ad0489d", - "bf6c49957c3b4b3ebc09819d6259c3bf", - "10b348c37c0b46018960247d083ed439", - "2174d7e0a0804ff687c4f6985d28b9c5", - "e500c7a093d14eed8ea9ac66c759f206", - "31eddc17aa7843beb41d740698b7276d", - "9a935b8fdbd4425988fb52e01dcb5d2d", - "487eb01386324ecfb7d4fae26b84c93d", - "b08c4d2cf51049c79ec61aa72b6e787e", - "a0b2c0f1d01d4b4ba8ab4ef2d1db6c1a", - "8ace9fe0a980414e90ae09c431f5f126", - "2813e52affed45e6901c468c573a69cb", - "a085a49f93a54e289d4007e0595d6de1", - "df0bf05e12404f199d5a412c8a077da0", - "ca5cdfec77ff49f7a088a96b910b796c", - "82ad1d0d948b44829471f1818ccc5b30", - "dedb7f740221433ea8d4668145b7818d", - "4b1d82016e4c47e48d89e86492aacdcb", - "e774743224c44ea5bf997a3842395d1b", - "24f5ac9cf9f44be892a138892a7bd219", - "7debe41bd81346eb951605e0048fadf9", - "9b4a4015605142bdb6d5147dc5cbc9e2", - "91f26f9100af4d3da3b2628f3ec6c3d3", - "43437ff07611497eb2bf4822fb434e53", - "ca455b6a71274921b8dbb789e9005449", - "164ccc2485c44f4ab37b72c47e5f0898", - "13a7139766e54d1e9073648c018b41bc", - "ea23ac32759c4c6d899911f3bad213ee", - "109c5832df954967b40487d6a51929c1", - "ec16f29395b34aa6b127d0d63502a88b", - "38071ab74d6e464a9b320fc2fc40bda3", - "fd6a665f5a1943c2ad22b97b0c1d6999", - "05b08c2cc14f4e0cae59fef46e283431", - "f42a11c6eb4c449eb214287e3e61884a", - "2167acd001b34219875104f2555a54ee", - "f71c8ebc88cb483c8adac440b787b7b8", - "ce776b04772f4795b5cf3c2279377634", - "c5c491755a644c5ebd1e7c34d5a2437c", - "e2c89dc3dea942aa9261a215b2a81d15", - "e2e4c0a732384398bec0c89af39a1c95", - "2aaa5c2a04d74d408434a43632e38d7a", - "84128d9cef9f4926b3c321ccd83d761d", - "6bcc169344064f15b4410261de4c16a5", - "f449d843801f4f829ea75a2c6e4b67db", - "9cc12d2e6b0846828d10cf8e48f1874d", - 
"5afb9ef1a858408485139346bf4960f0", - "1453657473fc45fd8cc43a3fd6d911ad", - "ac0e5a1eb7394e6a9a20b50b9cb3b2d6", - "d7882f9f6ff947828a9d965b07ae6308", - "8fca6a1b68ee4bfa96dafdc5c438ef91", - "dceb2d342e04422c9b965d6759b740d1", - "92c572a2a1d74e789576318f3edb08bf", - "96fee82a24e84a3ab39e8543e97fe539", - "2128512dbe3e4a02a68f86d13b0f467c", - "1b9d867c345843c78fdc27dd0efac5ff", - "44487915e4204d6abf86766d6b17e3df", - "e9a2537dd76144b395e308bdb0494321", - "2ca0650dfd9e4f30a206a81a16258c2e", - "13a0b8dcb53f4f50a1e253c33c197dd6", - "6430bd284b4c4453ab64b88bfdba4f9b", - "e91501471db04da991e73bbf28abb71d", - "a6e3c391120a4171b59e69f7111d34cd", - "df8651ab692b4a648a2e412b2e045bae", - "2133201a6113468490d75b3a4ef942b1", - "e6698265603c46aca5ce96aee64b485b", - "d6680844b3e345e3af2bb27d9afc60f0", - "e0dd466941a14188a40bfbe9d9f818ee", - "e6a846177d7d49e6a05eecc1174dc574", - "e39fed0066c14e16bfd3f9c366f0f288", - "f59c4c11e1b34a4288656f28c228ec76", - "f1aea0224e7246a39011c36f7ca72702", - "5ae6f538d2aa4be785b7be93af885562", - "4929a4a1103e47c3bd16e1db31f478e8", - "3b8ccaf505604661a65efd0267c0df0a", - "6bcd9a3684d0499395c21fbb617ed890", - "a0bbee764501493887d45b15f58d1c84", - "899c206bc1b948b7aede52ba26b0ab0e", - "6c0d5a1b905743ff86613e88e151b120", - "98c9c5edc84e4d70a63406fdd1e360e5", - "3200b22ab9bd4ffab358630eb09d21bf", - "eb0c0a2a3d814c08b8e78bd99f99048f", - "366fbb8b74554e7d87329676bbe53488", - "736cf1d1b23e4d6aba833d9977b77626", - "001bf9ab98ec495782911cab6a91fd67", - "8875242b17124440bac9eada00cbf7cc", - "a2e6022491e04e5b882907db35ac5c66", - "b040763ebdf646fd8d3d1d24914bcb47", - "a380e1c490a84b75b60f4e50e979f707", - "70382059066c41328a3348a7056032d1", - "81b1f031892043bfb7bc385edf52ce48", - "4f3da49fe0504640ad7c08fdf0b80113", - "69ff1319e4c64d47b9925bb17b9c7c93", - "3820af62fa1d498f80598e204a69e60d", - "d769c7a4405f474284f9f5ca94116858", - "d71c86b25cc048508a351238b4cd171f", - "51c78f4a481c4982a7576b4b9c1c81fb", - "811893b6b9d24f7f9976328d67eeb0c4", - "78206d211645442bbdb8675c175cc233", - "b1146db74d2f45859997e2173ce3350c", - "62fad5fc0fc24099a0a95305206a5950", - "d5ff5612fc3b448c8ad18b06fcf4f9cf", - "17b95a1469594b2791a0d057cf4b3367", - "8f09c3ccd53c44faac8ad55234c1487b", - "4c3e55fdbc024ad2be6ac2facf1a44db", - "b88c7c0bc00b4d88a1c49f11389126b5", - "fd2d47e147794aa4bedb9933fc4900ce", - "7a429ced76e0431ea49a795ec30d5ed2", - "08d02351fc8543228e70629577344c7a", - "4e38ef1829cb4c25ad84a85c3a7ab221", - "25193dcc82f54db8a2993bc4f828fa05", - "cfa2f244cd8740798360a2c4722d5748", - "ed353a8c04d6417a80ed5ddc1bcebd80", - "2252979760814e9da9e455fd34cce955", - "8414aeb74f4b4a03b66544867bbe5e8e", - "acb7fa2daf5646dca3eb9c9a0d29dc06", - "efaf4f813c8d4385a4308a8c7db4958a", - "f860da95c8b24bc1969b887d36ad9902", - "e33fb707901547d5a4045ea80e5a7c76", - "d2144b82918b4dbca37e9a0c1a90c2e6", - "eabe4f7af9fe47a6af5509bba81ec0cd", - "fead1396f54f4d86a52a04acb2deeea0", - "c0c405dcda684c88a6995805bbe1a714", - "ae14e90e6a084d0ca79dd0376dcbaa3a", - "621eb1d3097147b1991faa1474e7d966", - "aa7b75f8100749519494896ba4d12b1c", - "84cdf31dac71421b948474bd9e4cde53", - "4f5ad2bae09e4112b6b3a95397d7fb55", - "8e3980c9433447f89740624bd85a5768", - "7e22fc0da31d490aa43be2153cfd7403", - "a692d570c31649c28fea3bfac1aaab01", - "85755c2e08a34a2daefd4555df0b4fff", - "23618d84213a4ea98b9c17bf969272cd", - "44690735319544eda0fbd509b92b52c0", - "dbc11af4883c4f85b6805f8b52e8170c", - "78aed98077a74a75bc408d6fffa05065", - "941c2e43942448fba0ff0b13e4f6bf66", - "e32858052ad54f4f859a599dce4d13d6", - "19c4c142c6224c09ad29b13d0f9af2fe", - 
"6899eab30b4044a8a6493cd5acd48e7a", - "a23b211aefd54270821aea17dd96f88a", - "17af3f9a527d419ca4b369db183ae659", - "756215384ee8444ab59695bf05d29b55", - "d4e4b2592d54418f9744e8d3bfced6e5", - "bcfe68ab1f46499baf4e29bc9628ef64", - "f5145a7dbec343c39ea80958b5331caa", - "aed96231191b47a3adeac49c4121f814", - "4aad22e983004b1e9e4cebae0fe2897a", - "71782ed9d2f94f55be80f7c6dc1b0a59", - "945afa9b582e4431a71c8a39f4d7cacc", - "d723ac8b9b314ae8ae9861570d192bdf", - "11bed5e93c8b46c7a21de82c10e5cfd4", - "b2d11d577ed54cb3bcd49f766cc5013a" - ] - }, - "id": "9WQD4cuwrMOP", - "outputId": "590f75dd-064b-4bcc-b504-167bf2ad6cfb" - }, - "outputs": [], - "source": [ - "trainer.fit(model=task, datamodule=datamodule)" - ] - } - ], - "metadata": { - "accelerator": "GPU", - "colab": { - "provenance": [] - }, - "execution": { - "timeout": 1200 - }, - "gpuClass": "standard", - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.8" - }, - "vscode": { - "interpreter": { - "hash": "b058dd71d0e7047e70e62f655d92ec955f772479bbe5e5addd202027292e8f60" - } - } - }, - "nbformat": 4, - "nbformat_minor": 0 -} diff --git a/docs/tutorials/pretrained_weights.py b/docs/tutorials/pretrained_weights.py new file mode 100644 index 00000000000..f8dde65a6ae --- /dev/null +++ b/docs/tutorials/pretrained_weights.py @@ -0,0 +1,126 @@ +# --- +# jupyter: +# jupytext: +# text_representation: +# extension: .py +# format_name: light +# format_version: '1.5' +# jupytext_version: 1.16.0 +# kernelspec: +# display_name: Python 3 (ipykernel) +# language: python +# name: python3 +# --- + +# + [markdown] id="p63J-QmUrMN-" +# Copyright (c) Microsoft Corporation. All rights reserved. +# +# Licensed under the MIT License. + +# + [markdown] id="XRSkMFqyrMOE" +# # Pretrained Weights +# +# In this tutorial, we demonstrate some available pretrained weights in TorchGeo. The implementation follows torchvisions' recently introduced [Multi-Weight API](https://pytorch.org/blog/introducing-torchvision-new-multi-weight-support-api/). We will use the [EuroSAT](https://torchgeo.readthedocs.io/en/stable/api/datasets.html#eurosat) dataset throughout this tutorial. Specifically, a subset containing only 100 images. +# +# It's recommended to run this notebook on Google Colab if you don't have your own GPU. Click the "Open in Colab" button above to get started. + +# + [markdown] id="NBa5RPAirMOF" +# ## Setup +# +# First, we install TorchGeo. + +# + colab={"base_uri": "https://localhost:8080/"} id="5AIQ1B9DrMOG" outputId="6bf360ea-8f60-45cf-c96e-0eac54818079" +# %pip install torchgeo + +# + [markdown] id="IcCOnzVLrMOI" +# ## Imports +# +# Next, we import TorchGeo and any other libraries we need. + +# + id="rjEGiiurrMOI" +# %matplotlib inline + +import os +import tempfile + +import timm +import torch +from lightning.pytorch import Trainer + +from torchgeo.datamodules import EuroSAT100DataModule +from torchgeo.models import ResNet18_Weights +from torchgeo.trainers import ClassificationTask + +# + [markdown] id="njAH71F3rMOJ" +# The following variables can be used to control training. 
+
+# + id="TVG_Z9MKrMOJ" nbmake={"mock": {"batch_size": 1, "fast_dev_run": true, "max_epochs": 1, "num_workers": 0}}
+batch_size = 10
+num_workers = 2
+max_epochs = 10
+fast_dev_run = False
+
+# + [markdown] id="QNnDoIf2rMOK"
+# ## Datamodule
+#
+# We will utilize TorchGeo's [Lightning](https://lightning.ai/docs/pytorch/stable/) datamodules to organize the dataloader setup.
+
+# + id="ia5ktOVerMOL"
+root = os.path.join(tempfile.gettempdir(), 'eurosat100')
+datamodule = EuroSAT100DataModule(
+    root=root, batch_size=batch_size, num_workers=num_workers, download=True
+)
+
+# + [markdown] id="ksoszRjZrMOL"
+# ## Weights
+#
+# Available pretrained weights are listed on the model documentation [page](https://torchgeo.readthedocs.io/en/stable/api/models.html). While some weights only accept RGB channel input, others have been pretrained on Sentinel-2 imagery with 13 input channels and can therefore prove useful for transfer learning tasks involving Sentinel-2 data.
+#
+# To access these weights you can do the following:
+
+# + id="wJOrRqBGrMOM"
+weights = ResNet18_Weights.SENTINEL2_ALL_MOCO
+
+# + [markdown] id="EIpnXuXgrMOM"
+# This set of weights is a torchvision `WeightsEnum` and holds information such as the download URL and additional metadata. TorchGeo takes care of downloading and initializing models with a desired set of weights. Given that EuroSAT is a classification dataset, we can use a `ClassificationTask` object that holds the model and optimizer objects as well as the training logic.
+
+# + colab={"base_uri": "https://localhost:8080/", "height": 86, "referenced_widgets": ["8fb7022c6e4947d8955ed6da0d89ef05", "c5adf2ec163a43f08b3de98d0eabf8df", "4c6928b34aef4e778c837993c4197bcd", "ec96c2767d30412c8af5306a5c2f5ee3", "cbfb522eafbd4d4991ca5f45ad32bd2c", "5d02318e1c9e4034bd686527cdbb18ef", "739421651fb84d31a7baeacef2f8226c", "8d83a7dad192492facd4b3e03cbb2392", "fc423924e14a49b18b99064ab45e3f1e", "17dd6c0bbbf947b1a47a9cb8267d43ef", "8704c42feea442abb67cfd55ae3c4fa9"]} id="RZ8MPYH1rMON" outputId="fa683b8f-da21-4f26-ca3a-46163c9f12bf"
+task = ClassificationTask(
+    model='resnet18',
+    loss='ce',
+    weights=weights,
+    in_channels=13,
+    num_classes=10,
+    lr=0.001,
+    patience=5,
+)
+
+# + [markdown] id="dWidC6vDrMON"
+# If you do not want to utilize the `ClassificationTask` functionality for your experiments, you can also just create a [timm](https://github.com/rwightman/pytorch-image-models) model with pretrained weights from TorchGeo as follows:
+
+# + id="ZaZQ07jorMOO"
+in_chans = weights.meta['in_chans']
+model = timm.create_model('resnet18', in_chans=in_chans, num_classes=10)
+model.load_state_dict(weights.get_state_dict(progress=True), strict=False)
+
+# + [markdown] id="vgNswWKOrMOO"
+# ## Training
+#
+# To train our pretrained model on the EuroSAT dataset, we will make use of Lightning's [Trainer](https://lightning.ai/docs/pytorch/stable/common/trainer.html). For a more elaborate explanation of how TorchGeo uses Lightning, check out [this tutorial](https://torchgeo.readthedocs.io/en/stable/tutorials/trainers.html).
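+
+# + [markdown]
+# Before configuring the trainer, note that `weights.meta` is a plain dictionary, so you can print it to see everything recorded about this set of weights. This is purely an optional sanity check; the exact keys available (such as `'in_chans'`, used above) vary from one set of weights to another.

+# +
+# Optional: inspect the metadata shipped with the pretrained weights
+print(weights.meta)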
+ +# + id="0Sf-CBorrMOO" +accelerator = 'gpu' if torch.cuda.is_available() else 'cpu' +default_root_dir = os.path.join(tempfile.gettempdir(), 'experiments') + +# + colab={"base_uri": "https://localhost:8080/"} id="veVvF-5LrMOP" outputId="698e3e9e-8a53-4897-d40e-13b43470e29e" +trainer = Trainer( + accelerator=accelerator, + default_root_dir=default_root_dir, + fast_dev_run=fast_dev_run, + log_every_n_steps=1, + min_epochs=1, + max_epochs=max_epochs, +) + +# + colab={"base_uri": "https://localhost:8080/", "height": 544, "referenced_widgets": ["3435cd7003324d5dbee2222c1d919595", "0a2c8c97b5df4d6c9388ed9637c19e8f", "4f5209cdf8e346368da573e818ac8b98", "8d9a2d2377b74bb397cadf9af65849ae", "08871807406540fba8a935793cd28f93", "cc743f08d18f4302b7c3fc8506c4c9d4", "287c893f66294687910eccac3bb6353a", "088567dd04944a7682b2add8353a4076", "3af55148c50045e5b043485200ba7547", "b68eec0296fc45458b5d4cd2111afe6e", "b3423fc7e10a4cf3a16d6bc08f99dd17", "9af62a957121474d8d8b2be00a083287", "c728392f8d1f4da79492fb65fa10fbbe", "01ac3f0b67844950a6d1e5fe44978d1b", "d0140e681f9a437084b496927244f8ac", "cbfee603fbad42daacef920f3acc6a22", "cf52dda9f70249b99c893ff886a96c5b", "9800bfe6afbb4185b25a81017cabe720", "10538d0692ee42648400a204abb788dc", "d6169eaaff4b4918acc91f88ce82b14b", "71ca3d663ae8402497c9f6bb507e5e1b", "2dfc6c6897d0449d9078ed2fd7e35e4b", "8020d3f767d849cdbe2126c50dc546c8", "198bcafb3e4146a984e26260dc0ca2fa", "e774bef4588241e8bfd2a46a3ad0489d", "bf6c49957c3b4b3ebc09819d6259c3bf", "10b348c37c0b46018960247d083ed439", "2174d7e0a0804ff687c4f6985d28b9c5", "e500c7a093d14eed8ea9ac66c759f206", "31eddc17aa7843beb41d740698b7276d", "9a935b8fdbd4425988fb52e01dcb5d2d", "487eb01386324ecfb7d4fae26b84c93d", "b08c4d2cf51049c79ec61aa72b6e787e", "a0b2c0f1d01d4b4ba8ab4ef2d1db6c1a", "8ace9fe0a980414e90ae09c431f5f126", "2813e52affed45e6901c468c573a69cb", "a085a49f93a54e289d4007e0595d6de1", "df0bf05e12404f199d5a412c8a077da0", "ca5cdfec77ff49f7a088a96b910b796c", "82ad1d0d948b44829471f1818ccc5b30", "dedb7f740221433ea8d4668145b7818d", "4b1d82016e4c47e48d89e86492aacdcb", "e774743224c44ea5bf997a3842395d1b", "24f5ac9cf9f44be892a138892a7bd219", "7debe41bd81346eb951605e0048fadf9", "9b4a4015605142bdb6d5147dc5cbc9e2", "91f26f9100af4d3da3b2628f3ec6c3d3", "43437ff07611497eb2bf4822fb434e53", "ca455b6a71274921b8dbb789e9005449", "164ccc2485c44f4ab37b72c47e5f0898", "13a7139766e54d1e9073648c018b41bc", "ea23ac32759c4c6d899911f3bad213ee", "109c5832df954967b40487d6a51929c1", "ec16f29395b34aa6b127d0d63502a88b", "38071ab74d6e464a9b320fc2fc40bda3", "fd6a665f5a1943c2ad22b97b0c1d6999", "05b08c2cc14f4e0cae59fef46e283431", "f42a11c6eb4c449eb214287e3e61884a", "2167acd001b34219875104f2555a54ee", "f71c8ebc88cb483c8adac440b787b7b8", "ce776b04772f4795b5cf3c2279377634", "c5c491755a644c5ebd1e7c34d5a2437c", "e2c89dc3dea942aa9261a215b2a81d15", "e2e4c0a732384398bec0c89af39a1c95", "2aaa5c2a04d74d408434a43632e38d7a", "84128d9cef9f4926b3c321ccd83d761d", "6bcc169344064f15b4410261de4c16a5", "f449d843801f4f829ea75a2c6e4b67db", "9cc12d2e6b0846828d10cf8e48f1874d", "5afb9ef1a858408485139346bf4960f0", "1453657473fc45fd8cc43a3fd6d911ad", "ac0e5a1eb7394e6a9a20b50b9cb3b2d6", "d7882f9f6ff947828a9d965b07ae6308", "8fca6a1b68ee4bfa96dafdc5c438ef91", "dceb2d342e04422c9b965d6759b740d1", "92c572a2a1d74e789576318f3edb08bf", "96fee82a24e84a3ab39e8543e97fe539", "2128512dbe3e4a02a68f86d13b0f467c", "1b9d867c345843c78fdc27dd0efac5ff", "44487915e4204d6abf86766d6b17e3df", "e9a2537dd76144b395e308bdb0494321", "2ca0650dfd9e4f30a206a81a16258c2e", "13a0b8dcb53f4f50a1e253c33c197dd6", 
"6430bd284b4c4453ab64b88bfdba4f9b", "e91501471db04da991e73bbf28abb71d", "a6e3c391120a4171b59e69f7111d34cd", "df8651ab692b4a648a2e412b2e045bae", "2133201a6113468490d75b3a4ef942b1", "e6698265603c46aca5ce96aee64b485b", "d6680844b3e345e3af2bb27d9afc60f0", "e0dd466941a14188a40bfbe9d9f818ee", "e6a846177d7d49e6a05eecc1174dc574", "e39fed0066c14e16bfd3f9c366f0f288", "f59c4c11e1b34a4288656f28c228ec76", "f1aea0224e7246a39011c36f7ca72702", "5ae6f538d2aa4be785b7be93af885562", "4929a4a1103e47c3bd16e1db31f478e8", "3b8ccaf505604661a65efd0267c0df0a", "6bcd9a3684d0499395c21fbb617ed890", "a0bbee764501493887d45b15f58d1c84", "899c206bc1b948b7aede52ba26b0ab0e", "6c0d5a1b905743ff86613e88e151b120", "98c9c5edc84e4d70a63406fdd1e360e5", "3200b22ab9bd4ffab358630eb09d21bf", "eb0c0a2a3d814c08b8e78bd99f99048f", "366fbb8b74554e7d87329676bbe53488", "736cf1d1b23e4d6aba833d9977b77626", "001bf9ab98ec495782911cab6a91fd67", "8875242b17124440bac9eada00cbf7cc", "a2e6022491e04e5b882907db35ac5c66", "b040763ebdf646fd8d3d1d24914bcb47", "a380e1c490a84b75b60f4e50e979f707", "70382059066c41328a3348a7056032d1", "81b1f031892043bfb7bc385edf52ce48", "4f3da49fe0504640ad7c08fdf0b80113", "69ff1319e4c64d47b9925bb17b9c7c93", "3820af62fa1d498f80598e204a69e60d", "d769c7a4405f474284f9f5ca94116858", "d71c86b25cc048508a351238b4cd171f", "51c78f4a481c4982a7576b4b9c1c81fb", "811893b6b9d24f7f9976328d67eeb0c4", "78206d211645442bbdb8675c175cc233", "b1146db74d2f45859997e2173ce3350c", "62fad5fc0fc24099a0a95305206a5950", "d5ff5612fc3b448c8ad18b06fcf4f9cf", "17b95a1469594b2791a0d057cf4b3367", "8f09c3ccd53c44faac8ad55234c1487b", "4c3e55fdbc024ad2be6ac2facf1a44db", "b88c7c0bc00b4d88a1c49f11389126b5", "fd2d47e147794aa4bedb9933fc4900ce", "7a429ced76e0431ea49a795ec30d5ed2", "08d02351fc8543228e70629577344c7a", "4e38ef1829cb4c25ad84a85c3a7ab221", "25193dcc82f54db8a2993bc4f828fa05", "cfa2f244cd8740798360a2c4722d5748", "ed353a8c04d6417a80ed5ddc1bcebd80", "2252979760814e9da9e455fd34cce955", "8414aeb74f4b4a03b66544867bbe5e8e", "acb7fa2daf5646dca3eb9c9a0d29dc06", "efaf4f813c8d4385a4308a8c7db4958a", "f860da95c8b24bc1969b887d36ad9902", "e33fb707901547d5a4045ea80e5a7c76", "d2144b82918b4dbca37e9a0c1a90c2e6", "eabe4f7af9fe47a6af5509bba81ec0cd", "fead1396f54f4d86a52a04acb2deeea0", "c0c405dcda684c88a6995805bbe1a714", "ae14e90e6a084d0ca79dd0376dcbaa3a", "621eb1d3097147b1991faa1474e7d966", "aa7b75f8100749519494896ba4d12b1c", "84cdf31dac71421b948474bd9e4cde53", "4f5ad2bae09e4112b6b3a95397d7fb55", "8e3980c9433447f89740624bd85a5768", "7e22fc0da31d490aa43be2153cfd7403", "a692d570c31649c28fea3bfac1aaab01", "85755c2e08a34a2daefd4555df0b4fff", "23618d84213a4ea98b9c17bf969272cd", "44690735319544eda0fbd509b92b52c0", "dbc11af4883c4f85b6805f8b52e8170c", "78aed98077a74a75bc408d6fffa05065", "941c2e43942448fba0ff0b13e4f6bf66", "e32858052ad54f4f859a599dce4d13d6", "19c4c142c6224c09ad29b13d0f9af2fe", "6899eab30b4044a8a6493cd5acd48e7a", "a23b211aefd54270821aea17dd96f88a", "17af3f9a527d419ca4b369db183ae659", "756215384ee8444ab59695bf05d29b55", "d4e4b2592d54418f9744e8d3bfced6e5", "bcfe68ab1f46499baf4e29bc9628ef64", "f5145a7dbec343c39ea80958b5331caa", "aed96231191b47a3adeac49c4121f814", "4aad22e983004b1e9e4cebae0fe2897a", "71782ed9d2f94f55be80f7c6dc1b0a59", "945afa9b582e4431a71c8a39f4d7cacc", "d723ac8b9b314ae8ae9861570d192bdf", "11bed5e93c8b46c7a21de82c10e5cfd4", "b2d11d577ed54cb3bcd49f766cc5013a"]} id="9WQD4cuwrMOP" outputId="590f75dd-064b-4bcc-b504-167bf2ad6cfb" +trainer.fit(model=task, datamodule=datamodule) diff --git a/docs/tutorials/trainers.ipynb b/docs/tutorials/trainers.ipynb deleted file 
mode 100644 index 5de3937a026..00000000000 --- a/docs/tutorials/trainers.ipynb +++ /dev/null @@ -1,336 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "b13c2251", - "metadata": { - "id": "b13c2251" - }, - "source": [ - "Copyright (c) Microsoft Corporation. All rights reserved.\n", - "\n", - "Licensed under the MIT License." - ] - }, - { - "cell_type": "markdown", - "id": "e563313d", - "metadata": { - "id": "e563313d" - }, - "source": [ - "# Lightning Trainers\n", - "\n", - "In this tutorial, we demonstrate TorchGeo trainers to train and test a model. We will use the [EuroSAT](https://torchgeo.readthedocs.io/en/stable/api/datasets.html#eurosat) dataset throughout this tutorial. Specifically, a subset containing only 100 images. We will train models to predict land cover classes.\n", - "\n", - "It's recommended to run this notebook on Google Colab if you don't have your own GPU. Click the \"Open in Colab\" button above to get started." - ] - }, - { - "cell_type": "markdown", - "id": "8c1f4156", - "metadata": { - "id": "8c1f4156" - }, - "source": [ - "## Setup\n", - "\n", - "First, we install TorchGeo and TensorBoard." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "3f0d31a8", - "metadata": { - "id": "3f0d31a8" - }, - "outputs": [], - "source": [ - "%pip install torchgeo tensorboard" - ] - }, - { - "cell_type": "markdown", - "id": "c90c94c7", - "metadata": { - "id": "c90c94c7" - }, - "source": [ - "## Imports\n", - "\n", - "Next, we import TorchGeo and any other libraries we need." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "bd39f485", - "metadata": { - "id": "bd39f485" - }, - "outputs": [], - "source": [ - "%matplotlib inline\n", - "%load_ext tensorboard\n", - "\n", - "import os\n", - "import tempfile\n", - "\n", - "import torch\n", - "from lightning.pytorch import Trainer\n", - "from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint\n", - "from lightning.pytorch.loggers import TensorBoardLogger\n", - "\n", - "from torchgeo.datamodules import EuroSAT100DataModule\n", - "from torchgeo.models import ResNet18_Weights\n", - "from torchgeo.trainers import ClassificationTask" - ] - }, - { - "cell_type": "markdown", - "id": "e6e1d9b6", - "metadata": { - "id": "e6e1d9b6" - }, - "source": [ - "## Lightning modules\n", - "\n", - "Our trainers use [Lightning](https://lightning.ai/docs/pytorch/stable/) to organize both the training code, and the dataloader setup code. This makes it easy to create and share reproducible experiments and results.\n", - "\n", - "First we'll create a `EuroSAT100DataModule` object which is simply a wrapper around the [EuroSAT100](https://torchgeo.readthedocs.io/en/latest/api/datasets.html#eurosat) dataset. This object 1.) ensures that the data is downloaded, 2.) sets up PyTorch `DataLoader` objects for the train, validation, and test splits, and 3.) ensures that data from the same region **is not** shared between the training and validation sets so that you can properly evaluate the generalization performance of your model." - ] - }, - { - "cell_type": "markdown", - "id": "9f2daa0d", - "metadata": { - "id": "9f2daa0d" - }, - "source": [ - "The following variables can be modified to control training." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "8e100f8b", - "metadata": { - "id": "8e100f8b", - "nbmake": { - "mock": { - "batch_size": 1, - "fast_dev_run": true, - "max_epochs": 1, - "num_workers": 0 - } - } - }, - "outputs": [], - "source": [ - "batch_size = 10\n", - "num_workers = 2\n", - "max_epochs = 50\n", - "fast_dev_run = False" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "0f2a04c7", - "metadata": { - "id": "0f2a04c7" - }, - "outputs": [], - "source": [ - "root = os.path.join(tempfile.gettempdir(), 'eurosat100')\n", - "datamodule = EuroSAT100DataModule(\n", - " root=root, batch_size=batch_size, num_workers=num_workers, download=True\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "056b7b4c", - "metadata": { - "id": "056b7b4c" - }, - "source": [ - "Next, we create a `ClassificationTask` object that holds the model object, optimizer object, and training logic. We will use a ResNet-18 model that has been pre-trained on Sentinel-2 imagery." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "ba5c5442", - "metadata": { - "id": "ba5c5442" - }, - "outputs": [], - "source": [ - "task = ClassificationTask(\n", - " loss='ce',\n", - " model='resnet18',\n", - " weights=ResNet18_Weights.SENTINEL2_ALL_MOCO,\n", - " in_channels=13,\n", - " num_classes=10,\n", - " lr=0.1,\n", - " patience=5,\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "d4b67f3e", - "metadata": { - "id": "d4b67f3e" - }, - "source": [ - "## Training\n", - "\n", - "Now that we have the Lightning modules set up, we can use a Lightning [Trainer](https://lightning.ai/docs/pytorch/stable/common/trainer.html) to run the training and evaluation loops. There are many useful pieces of configuration that can be set in the `Trainer` -- below we set up model checkpointing based on the validation loss, early stopping based on the validation loss, and a TensorBoard based logger. We encourage you to see the [Lightning docs](https://lightning.ai/docs/pytorch/stable/) for other options that can be set here, e.g. CSV logging, automatically selecting your optimizer's learning rate, and easy multi-GPU training." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "ffe26e5c", - "metadata": { - "id": "ffe26e5c" - }, - "outputs": [], - "source": [ - "accelerator = 'gpu' if torch.cuda.is_available() else 'cpu'\n", - "default_root_dir = os.path.join(tempfile.gettempdir(), 'experiments')\n", - "checkpoint_callback = ModelCheckpoint(\n", - " monitor='val_loss', dirpath=default_root_dir, save_top_k=1, save_last=True\n", - ")\n", - "early_stopping_callback = EarlyStopping(monitor='val_loss', min_delta=0.00, patience=10)\n", - "logger = TensorBoardLogger(save_dir=default_root_dir, name='tutorial_logs')" - ] - }, - { - "cell_type": "markdown", - "id": "06afd8c7", - "metadata": { - "id": "06afd8c7" - }, - "source": [ - "For tutorial purposes we deliberately lower the maximum number of training epochs." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "225a6d36", - "metadata": { - "id": "225a6d36" - }, - "outputs": [], - "source": [ - "trainer = Trainer(\n", - " accelerator=accelerator,\n", - " callbacks=[checkpoint_callback, early_stopping_callback],\n", - " fast_dev_run=fast_dev_run,\n", - " log_every_n_steps=1,\n", - " logger=logger,\n", - " min_epochs=1,\n", - " max_epochs=max_epochs,\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "44d71e8f", - "metadata": { - "id": "44d71e8f" - }, - "source": [ - "When we first call `.fit(...)` the dataset will be downloaded and checksummed (if it hasn't already). After this, the training process will kick off, and results will be saved so that TensorBoard can read them." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "00e08790", - "metadata": { - "id": "00e08790" - }, - "outputs": [], - "source": [ - "trainer.fit(model=task, datamodule=datamodule)" - ] - }, - { - "cell_type": "markdown", - "id": "73700fb5", - "metadata": { - "id": "73700fb5" - }, - "source": [ - "We launch TensorBoard to visualize various performance metrics across training and validation epochs. We can see that our model is just starting to converge, and would probably benefit from additional training time and a lower initial learning rate." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "3e95ee0a", - "metadata": {}, - "outputs": [], - "source": [ - "%tensorboard --logdir \"$default_root_dir\"" - ] - }, - { - "cell_type": "markdown", - "id": "04cfc7a8", - "metadata": { - "id": "04cfc7a8" - }, - "source": [ - "Finally, after the model has been trained, we can easily evaluate it on the test set." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "604a3b2f", - "metadata": { - "id": "604a3b2f" - }, - "outputs": [], - "source": [ - "trainer.test(model=task, datamodule=datamodule)" - ] - } - ], - "metadata": { - "accelerator": "GPU", - "colab": { - "provenance": [] - }, - "execution": { - "timeout": 1200 - }, - "gpuClass": "standard", - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.8" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/docs/tutorials/trainers.py b/docs/tutorials/trainers.py new file mode 100644 index 00000000000..c34510d83d0 --- /dev/null +++ b/docs/tutorials/trainers.py @@ -0,0 +1,136 @@ +# --- +# jupyter: +# jupytext: +# text_representation: +# extension: .py +# format_name: light +# format_version: '1.5' +# jupytext_version: 1.16.0 +# kernelspec: +# display_name: Python 3 (ipykernel) +# language: python +# name: python3 +# --- + +# + [markdown] id="b13c2251" +# Copyright (c) Microsoft Corporation. All rights reserved. +# +# Licensed under the MIT License. + +# + [markdown] id="e563313d" +# # Lightning Trainers +# +# In this tutorial, we demonstrate TorchGeo trainers to train and test a model. We will use the [EuroSAT](https://torchgeo.readthedocs.io/en/stable/api/datasets.html#eurosat) dataset throughout this tutorial. Specifically, a subset containing only 100 images. We will train models to predict land cover classes. +# +# It's recommended to run this notebook on Google Colab if you don't have your own GPU. 
Click the "Open in Colab" button above to get started. + +# + [markdown] id="8c1f4156" +# ## Setup +# +# First, we install TorchGeo and TensorBoard. + +# + id="3f0d31a8" +# %pip install torchgeo tensorboard + +# + [markdown] id="c90c94c7" +# ## Imports +# +# Next, we import TorchGeo and any other libraries we need. + +# + id="bd39f485" +# %matplotlib inline +# %load_ext tensorboard + +import os +import tempfile + +import torch +from lightning.pytorch import Trainer +from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint +from lightning.pytorch.loggers import TensorBoardLogger + +from torchgeo.datamodules import EuroSAT100DataModule +from torchgeo.models import ResNet18_Weights +from torchgeo.trainers import ClassificationTask + +# + [markdown] id="e6e1d9b6" +# ## Lightning modules +# +# Our trainers use [Lightning](https://lightning.ai/docs/pytorch/stable/) to organize both the training code, and the dataloader setup code. This makes it easy to create and share reproducible experiments and results. +# +# First we'll create a `EuroSAT100DataModule` object which is simply a wrapper around the [EuroSAT100](https://torchgeo.readthedocs.io/en/latest/api/datasets.html#eurosat) dataset. This object 1.) ensures that the data is downloaded, 2.) sets up PyTorch `DataLoader` objects for the train, validation, and test splits, and 3.) ensures that data from the same region **is not** shared between the training and validation sets so that you can properly evaluate the generalization performance of your model. + +# + [markdown] id="9f2daa0d" +# The following variables can be modified to control training. + +# + id="8e100f8b" nbmake={"mock": {"batch_size": 1, "fast_dev_run": true, "max_epochs": 1, "num_workers": 0}} +batch_size = 10 +num_workers = 2 +max_epochs = 50 +fast_dev_run = False + +# + id="0f2a04c7" +root = os.path.join(tempfile.gettempdir(), 'eurosat100') +datamodule = EuroSAT100DataModule( + root=root, batch_size=batch_size, num_workers=num_workers, download=True +) + +# + [markdown] id="056b7b4c" +# Next, we create a `ClassificationTask` object that holds the model object, optimizer object, and training logic. We will use a ResNet-18 model that has been pre-trained on Sentinel-2 imagery. + +# + id="ba5c5442" +task = ClassificationTask( + loss='ce', + model='resnet18', + weights=ResNet18_Weights.SENTINEL2_ALL_MOCO, + in_channels=13, + num_classes=10, + lr=0.1, + patience=5, +) + +# + [markdown] id="d4b67f3e" +# ## Training +# +# Now that we have the Lightning modules set up, we can use a Lightning [Trainer](https://lightning.ai/docs/pytorch/stable/common/trainer.html) to run the training and evaluation loops. There are many useful pieces of configuration that can be set in the `Trainer` -- below we set up model checkpointing based on the validation loss, early stopping based on the validation loss, and a TensorBoard based logger. We encourage you to see the [Lightning docs](https://lightning.ai/docs/pytorch/stable/) for other options that can be set here, e.g. CSV logging, automatically selecting your optimizer's learning rate, and easy multi-GPU training. 
+ +# + id="ffe26e5c" +accelerator = 'gpu' if torch.cuda.is_available() else 'cpu' +default_root_dir = os.path.join(tempfile.gettempdir(), 'experiments') +checkpoint_callback = ModelCheckpoint( + monitor='val_loss', dirpath=default_root_dir, save_top_k=1, save_last=True +) +early_stopping_callback = EarlyStopping(monitor='val_loss', min_delta=0.00, patience=10) +logger = TensorBoardLogger(save_dir=default_root_dir, name='tutorial_logs') + +# + [markdown] id="06afd8c7" +# For tutorial purposes we deliberately lower the maximum number of training epochs. + +# + id="225a6d36" +trainer = Trainer( + accelerator=accelerator, + callbacks=[checkpoint_callback, early_stopping_callback], + fast_dev_run=fast_dev_run, + log_every_n_steps=1, + logger=logger, + min_epochs=1, + max_epochs=max_epochs, +) + +# + [markdown] id="44d71e8f" +# When we first call `.fit(...)` the dataset will be downloaded and checksummed (if it hasn't already). After this, the training process will kick off, and results will be saved so that TensorBoard can read them. + +# + id="00e08790" +trainer.fit(model=task, datamodule=datamodule) + +# + [markdown] id="73700fb5" +# We launch TensorBoard to visualize various performance metrics across training and validation epochs. We can see that our model is just starting to converge, and would probably benefit from additional training time and a lower initial learning rate. +# - + +# %tensorboard --logdir "$default_root_dir" + +# + [markdown] id="04cfc7a8" +# Finally, after the model has been trained, we can easily evaluate it on the test set. + +# + id="604a3b2f" +trainer.test(model=task, datamodule=datamodule) diff --git a/docs/tutorials/transforms.ipynb b/docs/tutorials/transforms.ipynb deleted file mode 100644 index a7de9f32c69..00000000000 --- a/docs/tutorials/transforms.ipynb +++ /dev/null @@ -1,725 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "DYndcZst_kdr" - }, - "source": [ - "Copyright (c) Microsoft Corporation. All rights reserved.\n", - "\n", - "Licensed under the MIT License." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "ZKIkyiLScf9P" - }, - "source": [ - "# Transforms" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "PevsPoE4cY0j" - }, - "source": [ - "In this tutorial, we demonstrate how to use TorchGeo's data augmentation transforms and provide examples of how to utilize them in your experiments with multispectral imagery.\n", - "\n", - "It's recommended to run this notebook on Google Colab if you don't have your own GPU. Click the \"Open in Colab\" button above to get started." 
- ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "fsOYw-p2ccka" - }, - "source": [ - "## Setup" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "VqdMMzvacOF8" - }, - "source": [ - "Install TorchGeo" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "wOwsb8KT_uXR", - "outputId": "e3e5f561-81a8-447b-f149-3e0e8305c598" - }, - "outputs": [], - "source": [ - "%pip install torchgeo" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "u2f5_f4X_-vV" - }, - "source": [ - "## Imports" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "cvPMr76K_9uk" - }, - "outputs": [], - "source": [ - "import os\n", - "import tempfile\n", - "\n", - "import kornia.augmentation as K\n", - "import torch\n", - "import torch.nn as nn\n", - "import torchvision.transforms as T\n", - "from PIL import Image\n", - "from torch import Tensor\n", - "from torch.utils.data import DataLoader\n", - "\n", - "from torchgeo.datasets import EuroSAT100\n", - "from torchgeo.transforms import AugmentationSequential, indices" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "oR3BCeV2AAop" - }, - "source": [ - "## Custom Transforms" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "oVgqhF2udp4z" - }, - "source": [ - "Here we create a transform to show an example of how you can chain custom operations along with TorchGeo and Kornia transforms/augmentations. Note how our transform takes as input a Dict of Tensors. We specify our data by the keys [\"image\", \"mask\", \"label\", etc.] and follow this standard across TorchGeo datasets." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "3mixIK7mAC9G" - }, - "outputs": [], - "source": [ - "class MinMaxNormalize(K.IntensityAugmentationBase2D):\n", - " \"\"\"Normalize channels to the range [0, 1] using min/max values.\"\"\"\n", - "\n", - " def __init__(self, mins: Tensor, maxs: Tensor) -> None:\n", - " super().__init__(p=1)\n", - " self.flags = {'mins': mins.view(1, -1, 1, 1), 'maxs': maxs.view(1, -1, 1, 1)}\n", - "\n", - " def apply_transform(\n", - " self,\n", - " input: Tensor,\n", - " params: dict[str, Tensor],\n", - " flags: dict[str, int],\n", - " transform: Tensor | None = None,\n", - " ) -> Tensor:\n", - " return (input - flags['mins']) / (flags['maxs'] - flags['mins'] + 1e-10)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "2ESh5W05AE3Y" - }, - "source": [ - "## Dataset Bands and Statistics" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "WFTBPWUo9b5o" - }, - "source": [ - "Below we have min/max values calculated across the dataset per band. The values were clipped to the interval [0, 98] to stretch the band values and avoid outliers influencing the band histograms." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "vRnMovSrAHgU" - }, - "outputs": [], - "source": [ - "mins = torch.tensor(\n", - " [\n", - " 1013.0,\n", - " 676.0,\n", - " 448.0,\n", - " 247.0,\n", - " 269.0,\n", - " 253.0,\n", - " 243.0,\n", - " 189.0,\n", - " 61.0,\n", - " 4.0,\n", - " 33.0,\n", - " 11.0,\n", - " 186.0,\n", - " ]\n", - ")\n", - "maxs = torch.tensor(\n", - " [\n", - " 2309.0,\n", - " 4543.05,\n", - " 4720.2,\n", - " 5293.05,\n", - " 3902.05,\n", - " 4473.0,\n", - " 5447.0,\n", - " 5948.05,\n", - " 1829.0,\n", - " 23.0,\n", - " 4894.05,\n", - " 4076.05,\n", - " 5846.0,\n", - " ]\n", - ")\n", - "bands = {\n", - " 'B01': 'Coastal Aerosol',\n", - " 'B02': 'Blue',\n", - " 'B03': 'Green',\n", - " 'B04': 'Red',\n", - " 'B05': 'Vegetation Red Edge 1',\n", - " 'B06': 'Vegetation Red Edge 2',\n", - " 'B07': 'Vegetation Red Edge 3',\n", - " 'B08': 'NIR 1',\n", - " 'B8A': 'NIR 2',\n", - " 'B09': 'Water Vapour',\n", - " 'B10': 'SWIR 1',\n", - " 'B11': 'SWIR 2',\n", - " 'B12': 'SWIR 3',\n", - "}" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "qQ40TIOG9qVZ" - }, - "source": [ - "The following variables can be used to control the dataloader." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "mKGubYto9qVZ", - "nbmake": { - "mock": { - "batch_size": 1, - "num_workers": 0 - } - } - }, - "outputs": [], - "source": [ - "batch_size = 4\n", - "num_workers = 2" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "hktYHfQHAJbs" - }, - "source": [ - "## Load the EuroSat MS dataset and dataloader" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "sUavkZSxeqCA" - }, - "source": [ - "We will use the [EuroSAT](https://torchgeo.readthedocs.io/en/stable/api/datasets.html#eurosat) dataset throughout this tutorial. Specifically, a subset containing only 100 images." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 269, - "referenced_widgets": [ - "4a45768962f84bfea3828b331f67dbf5", - "b87d2d75851848c1b33a00e9cee8bee3", - "7e8eb99543654b02b879432d5e286d29", - "3858c8543b2349d1aa06410d7f48d49f", - "61259e090b754bf4bb53e6be6e3f2233", - "b7e748408e824ba8829624c929cae8a9", - "d932459633764ddb85dd7fcc3f92c8e2", - "77c4e614779f41d39a1e5e23e2a0cb7e", - "47b2ca53cd7e4bdeae64192d747eba93", - "cc74eb4a618b414aae643b63e316e8be", - "f0cdff4785c54dfc9acbbb7052ce9311", - "b62e3f13bb584269a8388ebee3de3078", - "96e238e381034669b39fbf9d5483e1bb", - "f541d671cfb941b0972f622e06eb097d", - "bdf19d6aab3b4ec4aa34babe3c9701ce", - "f1f2301bee4448cea3e95846f53b7d6f", - "a3e02ee9526b4d02b1b47d3705aacaa8", - "d1a8127d743741fba825d3378aeb5062", - "d3d7851f634c4f3eb7ff0b5f7d433b10", - "8aefd08090d540e6ae2b7ec96a91dde0", - "c59ae4a8e136486793e1f5aef4b17fb9", - "2547bdf147874b9d8c636e0368839149", - "c11032f873ea4b4c8edac43bd3caa46c", - "9728941761214d74941716810364d0ec", - "9eca059d8a7d45178bec2a740c1134a1", - "f55f29996c2b466590402d490b8545f6", - "99c41c0bfda24961b461844314206525", - "37c8dd6f0525490281ac762b8d8469e4", - "399a9ea7fed342b7be22c508efe1fd28", - "4df7f51c9c4241c29b4857e1c110fc8f", - "5d0a21b260d14ea88455d39bebbc6a87", - "3f56e822e3a349458b0859b530ec890a", - "56c35731ea54485284a5396b89039aba", - "8dcaee61833c42ce896b606572ca5ebe", - "56d125f8e4ff4923ac5c15a1b803529a", - "0914f9a57f914351bc03d9bd6babdc27", - "60bf416d480b45c5993d816c46ce19f8", - "009688664acf46449e72fd92879f4268", - "1c1ff62fb40a41199e82f93f374d7b0b", - "11bd6ccbf258495eb177dbed11ecac1a", - "2095235bf0024396b23886fc877ca322", - "ab01d70248a1450bab0895b4447b1deb", - "77595a1376d340aa920a7ac27ab19fc4", - "82879fe9b81a47f3899f08513abb42be" - ] - }, - "id": "VHVgiNA4t5Tl", - "outputId": "abe91979-9f23-4eed-9589-daa0d69f5458" - }, - "outputs": [], - "source": [ - "root = os.path.join(tempfile.gettempdir(), 'eurosat100')\n", - "dataset = EuroSAT100(root, download=True)\n", - "dataloader = DataLoader(\n", - " dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers\n", - ")\n", - "dataloader = iter(dataloader)\n", - "print(f'Number of images in dataset: {len(dataset)}')\n", - "print(f'Dataset Classes: {dataset.classes}')" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "ovckKTXpA78o" - }, - "source": [ - "## Load a sample and batch of images and labels" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "BKYU2A3weY82" - }, - "source": [ - "Here we test our dataset by loading a single image and label. Note how the image is of shape (13, 64, 64) containing a 64x64 shape with 13 multispectral bands." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "3lhG1yM_v7Mi", - "outputId": "c689890e-80ac-47f9-8779-62f187a6e761" - }, - "outputs": [], - "source": [ - "sample = dataset[0]\n", - "x, y = sample['image'], sample['label']\n", - "print(x.shape, x.dtype, x.min(), x.max())\n", - "print(y, dataset.classes[y])" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "uw8F17tcAKPY" - }, - "source": [ - "Here we test our dataloader by loading a single batch of images and labels. Note how the image is of shape (4, 13, 64, 64) containing 4 samples due to our batch_size." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "0faJA5UiAJmK", - "outputId": "7448880b-fb51-4c01-c335-767b93868257" - }, - "outputs": [], - "source": [ - "batch = next(dataloader)\n", - "x, y = batch['image'], batch['label']\n", - "print(x.shape, x.dtype, x.min(), x.max())\n", - "print(y, [dataset.classes[i] for i in y])" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "x8-uLsPdfz0o" - }, - "source": [ - "## Transforms Usage" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "p28C8cTGE3dP" - }, - "source": [ - "Transforms are able to operate across batches of samples and singular samples. This allows them to be used inside the dataset itself or externally, chained together with other transform operations using `nn.Sequential`. " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "pJXUycffEjNX", - "outputId": "d029826c-a546-4c8e-e254-db680c5045e8" - }, - "outputs": [], - "source": [ - "transform = MinMaxNormalize(mins, maxs)\n", - "print(x.shape)\n", - "x = transform(x)\n", - "print(x.dtype, x.min(), x.max())" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "KRjb-u0EEmDf" - }, - "source": [ - "Indices can also be computed on batches of images and appended as an additional band to the specified channel dimension. Notice how the number of channels increases from 13 -> 14." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "HaG-1tvi9RKS", - "outputId": "8cbf5fc7-0e34-4670-bf03-700270a041c8" - }, - "outputs": [], - "source": [ - "transform = indices.AppendNDVI(index_nir=7, index_red=3)\n", - "batch = next(dataloader)\n", - "x = batch['image']\n", - "print(x.shape)\n", - "x = transform(x)\n", - "print(x.shape)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "q6WFG8UuGcF8" - }, - "source": [ - "This makes it incredibly easy to add indices as additional features during training by chaining multiple Appends together." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "H_EaAyfnGblR", - "outputId": "b3c7c8c9-1e8b-4125-bf72-69f4973878da" - }, - "outputs": [], - "source": [ - "transforms = nn.Sequential(\n", - " MinMaxNormalize(mins, maxs),\n", - " indices.AppendNDBI(index_swir=11, index_nir=7),\n", - " indices.AppendNDSI(index_green=3, index_swir=11),\n", - " indices.AppendNDVI(index_nir=7, index_red=3),\n", - " indices.AppendNDWI(index_green=2, index_nir=7),\n", - ")\n", - "\n", - "batch = next(dataloader)\n", - "x = batch['image']\n", - "print(x.shape)\n", - "x = transforms(x)\n", - "print(x.shape)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "w4ZbjxPyHoiB" - }, - "source": [ - "It's even possible to chain indices along with augmentations from Kornia for a single callable during training." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "ZKEDgnX0Hn-d", - "outputId": "129a7706-70b8-4d12-8d8c-ff60dc8d44e3" - }, - "outputs": [], - "source": [ - "transforms = AugmentationSequential(\n", - " MinMaxNormalize(mins, maxs),\n", - " indices.AppendNDBI(index_swir=11, index_nir=7),\n", - " indices.AppendNDSI(index_green=3, index_swir=11),\n", - " indices.AppendNDVI(index_nir=7, index_red=3),\n", - " indices.AppendNDWI(index_green=2, index_nir=7),\n", - " K.RandomHorizontalFlip(p=0.5),\n", - " K.RandomVerticalFlip(p=0.5),\n", - " data_keys=['image'],\n", - ")\n", - "\n", - "batch = next(dataloader)\n", - "print(batch['image'].shape)\n", - "batch = transforms(batch)\n", - "print(batch['image'].shape)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "IhKin8a2GPoI" - }, - "source": [ - "All of our transforms are `nn.Modules`. This allows us to push them and the data to the GPU to see significant gains for large scale operations." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "4QhMOtYzLmVK", - "outputId": "94b8b24a-80a2-4300-df37-aa833a6dde1c", - "tags": [ - "raises-exception" - ] - }, - "outputs": [], - "source": [ - "!nvidia-smi" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "4zokGELhGPF8" - }, - "outputs": [], - "source": [ - "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n", - "\n", - "transforms = AugmentationSequential(\n", - " MinMaxNormalize(mins, maxs),\n", - " indices.AppendNDBI(index_swir=11, index_nir=7),\n", - " indices.AppendNDSI(index_green=3, index_swir=11),\n", - " indices.AppendNDVI(index_nir=7, index_red=3),\n", - " indices.AppendNDWI(index_green=2, index_nir=7),\n", - " K.RandomHorizontalFlip(p=0.5),\n", - " K.RandomVerticalFlip(p=0.5),\n", - " K.RandomAffine(degrees=(0, 90), p=0.25),\n", - " K.RandomGaussianBlur(kernel_size=(3, 3), sigma=(0.1, 2.0), p=0.25),\n", - " K.RandomResizedCrop(size=(512, 512), scale=(0.8, 1.0), p=0.25),\n", - " data_keys=['image'],\n", - ")\n", - "\n", - "transforms_gpu = AugmentationSequential(\n", - " MinMaxNormalize(mins.to(device), maxs.to(device)),\n", - " indices.AppendNDBI(index_swir=11, index_nir=7),\n", - " indices.AppendNDSI(index_green=3, index_swir=11),\n", - " indices.AppendNDVI(index_nir=7, index_red=3),\n", - " indices.AppendNDWI(index_green=2, index_nir=7),\n", - " K.RandomHorizontalFlip(p=0.5),\n", - " K.RandomVerticalFlip(p=0.5),\n", - " K.RandomAffine(degrees=(0, 90), p=0.25),\n", - " K.RandomGaussianBlur(kernel_size=(3, 3), sigma=(0.1, 2.0), p=0.25),\n", - " K.RandomResizedCrop(size=(512, 512), scale=(0.8, 1.0), p=0.25),\n", - " data_keys=['image'],\n", - ").to(device)\n", - "\n", - "\n", - "def get_batch_cpu():\n", - " return dict(image=torch.randn(64, 13, 512, 512).to('cpu'))\n", - "\n", - "\n", - "def get_batch_gpu():\n", - " return dict(image=torch.randn(64, 13, 512, 512).to(device))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "vo43CqJ4IIXE", - "outputId": "75e438f7-5ab1-47f4-9de9-b444b5b759f6" - }, - "outputs": [], - "source": [ - "%%timeit -n 1 -r 5\n", - "_ = transforms(get_batch_cpu())" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "ICKXYZYrJCeh", - "outputId": 
"9335cd58-90a6-4b8f-d27c-8bc833e76600" - }, - "outputs": [], - "source": [ - "%%timeit -n 1 -r 5\n", - "_ = transforms_gpu(get_batch_gpu())" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "nkGy_g6tBAtF" - }, - "source": [ - "## Visualize Images and Labels" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "3k4W98v27NtL" - }, - "source": [ - "This is a Google Colab browser for the EuroSAT dataset. Adjust the slider to visualize images in the dataset." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "O_6k7tcxz17x" - }, - "outputs": [], - "source": [ - "transforms = AugmentationSequential(MinMaxNormalize(mins, maxs), data_keys=['image'])\n", - "dataset = EuroSAT100(root, transforms=transforms)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 290 - }, - "id": "Uw8xDeg3BY-u", - "outputId": "2b5f94c4-3aa8-4f30-a38d-b701b2332967" - }, - "outputs": [], - "source": [ - "# @title EuroSat Multispectral (MS) Browser { run: \"auto\", vertical-output: true }\n", - "idx = 21 # @param {type:\"slider\", min:0, max:59, step:1}\n", - "sample = dataset[idx]\n", - "rgb = sample['image'][0, 1:4]\n", - "image = T.ToPILImage()(rgb)\n", - "print(f\"Class Label: {dataset.classes[sample['label']]}\")\n", - "image.resize((256, 256), resample=Image.BILINEAR)" - ] - } - ], - "metadata": { - "accelerator": "GPU", - "colab": { - "provenance": [] - }, - "execution": { - "timeout": 1200 - }, - "interpreter": { - "hash": "6e850ee5f92358dcfdbb90dda05d686956eb0825584ddd5eff31b34875ddfee0" - }, - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.8" - } - }, - "nbformat": 4, - "nbformat_minor": 0 -} diff --git a/docs/tutorials/transforms.py b/docs/tutorials/transforms.py new file mode 100644 index 00000000000..1a039632243 --- /dev/null +++ b/docs/tutorials/transforms.py @@ -0,0 +1,313 @@ +# --- +# jupyter: +# jupytext: +# text_representation: +# extension: .py +# format_name: light +# format_version: '1.5' +# jupytext_version: 1.16.0 +# kernelspec: +# display_name: Python 3 (ipykernel) +# language: python +# name: python3 +# --- + +# + [markdown] id="DYndcZst_kdr" +# Copyright (c) Microsoft Corporation. All rights reserved. +# +# Licensed under the MIT License. + +# + [markdown] id="ZKIkyiLScf9P" +# # Transforms + +# + [markdown] id="PevsPoE4cY0j" +# In this tutorial, we demonstrate how to use TorchGeo's data augmentation transforms and provide examples of how to utilize them in your experiments with multispectral imagery. +# +# It's recommended to run this notebook on Google Colab if you don't have your own GPU. Click the "Open in Colab" button above to get started. 
+
+# + [markdown] id="fsOYw-p2ccka"
+# ## Setup
+
+# + [markdown] id="VqdMMzvacOF8"
+# Install TorchGeo
+
+# + colab={"base_uri": "https://localhost:8080/"} id="wOwsb8KT_uXR" outputId="e3e5f561-81a8-447b-f149-3e0e8305c598"
+# %pip install torchgeo
+
+# + [markdown] id="u2f5_f4X_-vV"
+# ## Imports
+
+# + id="cvPMr76K_9uk"
+import os
+import tempfile
+
+import kornia.augmentation as K
+import torch
+import torch.nn as nn
+import torchvision.transforms as T
+from PIL import Image
+from torch import Tensor
+from torch.utils.data import DataLoader
+
+from torchgeo.datasets import EuroSAT100
+from torchgeo.transforms import AugmentationSequential, indices
+
+
+# + [markdown] id="oR3BCeV2AAop"
+# ## Custom Transforms
+
+# + [markdown] id="oVgqhF2udp4z"
+# Here we create a transform to show an example of how you can chain custom operations together with TorchGeo and Kornia transforms/augmentations. Note how our transform takes a dict of tensors as input. We identify our data by the keys ["image", "mask", "label", etc.] and follow this convention across TorchGeo datasets.
+
+# + id="3mixIK7mAC9G"
+class MinMaxNormalize(K.IntensityAugmentationBase2D):
+    """Normalize channels to the range [0, 1] using min/max values."""
+
+    def __init__(self, mins: Tensor, maxs: Tensor) -> None:
+        super().__init__(p=1)
+        self.flags = {'mins': mins.view(1, -1, 1, 1), 'maxs': maxs.view(1, -1, 1, 1)}
+
+    def apply_transform(
+        self,
+        input: Tensor,
+        params: dict[str, Tensor],
+        flags: dict[str, Tensor],
+        transform: Tensor | None = None,
+    ) -> Tensor:
+        return (input - flags['mins']) / (flags['maxs'] - flags['mins'] + 1e-10)
+
+
+# + [markdown] id="2ESh5W05AE3Y"
+# ## Dataset Bands and Statistics
+
+# + [markdown] id="WFTBPWUo9b5o"
+# Below are per-band minimum/maximum values calculated across the dataset. The values were clipped to the [0, 98th] percentile interval to stretch the band values and keep outliers from distorting the band histograms (a sketch for recomputing such statistics appears just before the dataset is loaded below).
+
+# + id="vRnMovSrAHgU"
+mins = torch.tensor(
+    [
+        1013.0,
+        676.0,
+        448.0,
+        247.0,
+        269.0,
+        253.0,
+        243.0,
+        189.0,
+        61.0,
+        4.0,
+        33.0,
+        11.0,
+        186.0,
+    ]
+)
+maxs = torch.tensor(
+    [
+        2309.0,
+        4543.05,
+        4720.2,
+        5293.05,
+        3902.05,
+        4473.0,
+        5447.0,
+        5948.05,
+        1829.0,
+        23.0,
+        4894.05,
+        4076.05,
+        5846.0,
+    ]
+)
+bands = {
+    'B01': 'Coastal Aerosol',
+    'B02': 'Blue',
+    'B03': 'Green',
+    'B04': 'Red',
+    'B05': 'Vegetation Red Edge 1',
+    'B06': 'Vegetation Red Edge 2',
+    'B07': 'Vegetation Red Edge 3',
+    'B08': 'NIR 1',
+    'B8A': 'NIR 2',
+    'B09': 'Water Vapour',
+    'B10': 'Cirrus',
+    'B11': 'SWIR 1',
+    'B12': 'SWIR 2',
+}
+
+# + [markdown] id="qQ40TIOG9qVZ"
+# The following variables can be used to control the dataloader.
+
+# + id="mKGubYto9qVZ" nbmake={"mock": {"batch_size": 1, "num_workers": 0}}
+batch_size = 4
+num_workers = 2
+
+# + [markdown] id="hktYHfQHAJbs"
+# ## Load the EuroSAT MS dataset and dataloader
+
+# + [markdown] id="sUavkZSxeqCA"
+# We will use the [EuroSAT](https://torchgeo.readthedocs.io/en/stable/api/datasets.html#eurosat) dataset throughout this tutorial. Specifically, a subset containing only 100 images.
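+
+# + [markdown]
+# Because the subset is small, statistics like the ones above can also be
+# recomputed from it directly. The cell below is a rough sketch (the `stats_*`
+# names are illustrative, and clipping the maxima at the 98th percentile is an
+# assumption), so the numbers will not exactly match the table above.
+
+# +
+stats_root = os.path.join(tempfile.gettempdir(), 'eurosat100')
+stats_dataset = EuroSAT100(stats_root, download=True)
+imgs = [stats_dataset[i]['image'] for i in range(len(stats_dataset))]
+pixels = torch.stack(imgs).transpose(0, 1).reshape(13, -1)
+print(pixels.min(dim=1).values)  # per-band mins (0th percentile)
+print(pixels.quantile(0.98, dim=1))  # per-band 98th percentile maxs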
+
+# + colab={"base_uri": "https://localhost:8080/", "height": 269, "referenced_widgets": ["4a45768962f84bfea3828b331f67dbf5", "b87d2d75851848c1b33a00e9cee8bee3", "7e8eb99543654b02b879432d5e286d29", "3858c8543b2349d1aa06410d7f48d49f", "61259e090b754bf4bb53e6be6e3f2233", "b7e748408e824ba8829624c929cae8a9", "d932459633764ddb85dd7fcc3f92c8e2", "77c4e614779f41d39a1e5e23e2a0cb7e", "47b2ca53cd7e4bdeae64192d747eba93", "cc74eb4a618b414aae643b63e316e8be", "f0cdff4785c54dfc9acbbb7052ce9311", "b62e3f13bb584269a8388ebee3de3078", "96e238e381034669b39fbf9d5483e1bb", "f541d671cfb941b0972f622e06eb097d", "bdf19d6aab3b4ec4aa34babe3c9701ce", "f1f2301bee4448cea3e95846f53b7d6f", "a3e02ee9526b4d02b1b47d3705aacaa8", "d1a8127d743741fba825d3378aeb5062", "d3d7851f634c4f3eb7ff0b5f7d433b10", "8aefd08090d540e6ae2b7ec96a91dde0", "c59ae4a8e136486793e1f5aef4b17fb9", "2547bdf147874b9d8c636e0368839149", "c11032f873ea4b4c8edac43bd3caa46c", "9728941761214d74941716810364d0ec", "9eca059d8a7d45178bec2a740c1134a1", "f55f29996c2b466590402d490b8545f6", "99c41c0bfda24961b461844314206525", "37c8dd6f0525490281ac762b8d8469e4", "399a9ea7fed342b7be22c508efe1fd28", "4df7f51c9c4241c29b4857e1c110fc8f", "5d0a21b260d14ea88455d39bebbc6a87", "3f56e822e3a349458b0859b530ec890a", "56c35731ea54485284a5396b89039aba", "8dcaee61833c42ce896b606572ca5ebe", "56d125f8e4ff4923ac5c15a1b803529a", "0914f9a57f914351bc03d9bd6babdc27", "60bf416d480b45c5993d816c46ce19f8", "009688664acf46449e72fd92879f4268", "1c1ff62fb40a41199e82f93f374d7b0b", "11bd6ccbf258495eb177dbed11ecac1a", "2095235bf0024396b23886fc877ca322", "ab01d70248a1450bab0895b4447b1deb", "77595a1376d340aa920a7ac27ab19fc4", "82879fe9b81a47f3899f08513abb42be"]} id="VHVgiNA4t5Tl" outputId="abe91979-9f23-4eed-9589-daa0d69f5458"
+root = os.path.join(tempfile.gettempdir(), 'eurosat100')
+dataset = EuroSAT100(root, download=True)
+dataloader = DataLoader(
+    dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers
+)
+dataloader = iter(dataloader)
+print(f'Number of images in dataset: {len(dataset)}')
+print(f'Dataset Classes: {dataset.classes}')
+
+# + [markdown] id="ovckKTXpA78o"
+# ## Load a sample and batch of images and labels
+
+# + [markdown] id="BKYU2A3weY82"
+# Here we test our dataset by loading a single image and label. Note how the image has shape (13, 64, 64): a 64 x 64 image with 13 multispectral bands.
+
+# + colab={"base_uri": "https://localhost:8080/"} id="3lhG1yM_v7Mi" outputId="c689890e-80ac-47f9-8779-62f187a6e761"
+sample = dataset[0]
+x, y = sample['image'], sample['label']
+print(x.shape, x.dtype, x.min(), x.max())
+print(y, dataset.classes[y])
+
+# + [markdown] id="uw8F17tcAKPY"
+# Here we test our dataloader by loading a single batch of images and labels. Note how the batch has shape (4, 13, 64, 64): 4 samples, matching our batch_size.
+
+# + colab={"base_uri": "https://localhost:8080/"} id="0faJA5UiAJmK" outputId="7448880b-fb51-4c01-c335-767b93868257"
+batch = next(dataloader)
+x, y = batch['image'], batch['label']
+print(x.shape, x.dtype, x.min(), x.max())
+print(y, [dataset.classes[i] for i in y])
+
+# + [markdown] id="x8-uLsPdfz0o"
+# ## Transforms Usage
+
+# + [markdown] id="p28C8cTGE3dP"
+# Transforms can operate on both individual samples and batches of samples. This allows them to be used inside the dataset itself or externally, chained together with other transform operations using `nn.Sequential`. 
+
+# + colab={"base_uri": "https://localhost:8080/"} id="pJXUycffEjNX" outputId="d029826c-a546-4c8e-e254-db680c5045e8"
+transform = MinMaxNormalize(mins, maxs)
+print(x.shape)
+x = transform(x)
+print(x.dtype, x.min(), x.max())
+
+# + [markdown] id="KRjb-u0EEmDf"
+# Indices can also be computed on batches of images and appended as an additional band along the channel dimension (for example, NDVI = (NIR - Red) / (NIR + Red)). Notice how the number of channels increases from 13 to 14.
+
+# + colab={"base_uri": "https://localhost:8080/"} id="HaG-1tvi9RKS" outputId="8cbf5fc7-0e34-4670-bf03-700270a041c8"
+transform = indices.AppendNDVI(index_nir=7, index_red=3)
+batch = next(dataloader)
+x = batch['image']
+print(x.shape)
+x = transform(x)
+print(x.shape)
+
+# + [markdown] id="q6WFG8UuGcF8"
+# This makes it incredibly easy to add indices as additional features during training by chaining multiple Append transforms together.
+
+# + colab={"base_uri": "https://localhost:8080/"} id="H_EaAyfnGblR" outputId="b3c7c8c9-1e8b-4125-bf72-69f4973878da"
+transforms = nn.Sequential(
+    MinMaxNormalize(mins, maxs),
+    indices.AppendNDBI(index_swir=11, index_nir=7),
+    indices.AppendNDSI(index_green=2, index_swir=11),
+    indices.AppendNDVI(index_nir=7, index_red=3),
+    indices.AppendNDWI(index_green=2, index_nir=7),
+)
+
+batch = next(dataloader)
+x = batch['image']
+print(x.shape)
+x = transforms(x)
+print(x.shape)
+
+# + [markdown] id="w4ZbjxPyHoiB"
+# It's even possible to chain indices along with augmentations from Kornia for a single callable during training.
+
+# + colab={"base_uri": "https://localhost:8080/"} id="ZKEDgnX0Hn-d" outputId="129a7706-70b8-4d12-8d8c-ff60dc8d44e3"
+transforms = AugmentationSequential(
+    MinMaxNormalize(mins, maxs),
+    indices.AppendNDBI(index_swir=11, index_nir=7),
+    indices.AppendNDSI(index_green=2, index_swir=11),
+    indices.AppendNDVI(index_nir=7, index_red=3),
+    indices.AppendNDWI(index_green=2, index_nir=7),
+    K.RandomHorizontalFlip(p=0.5),
+    K.RandomVerticalFlip(p=0.5),
+    data_keys=['image'],
+)
+
+batch = next(dataloader)
+print(batch['image'].shape)
+batch = transforms(batch)
+print(batch['image'].shape)
+
+# + [markdown] id="IhKin8a2GPoI"
+# All of our transforms are `nn.Module`s. This allows us to move both the transforms and the data to the GPU and see significant speedups for large-scale operations.
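+
+# + [markdown]
+# The minimal pattern is just `.cuda()` (or `.to(device)`) on both the tensors
+# and the module. A rough sketch, guarded so it is skipped on CPU-only machines
+# (`cuda_normalize` is an illustrative name); the cells below benchmark a
+# fuller pipeline:
+
+# +
+if torch.cuda.is_available():
+    cuda_normalize = MinMaxNormalize(mins.cuda(), maxs.cuda())
+    print(cuda_normalize(next(dataloader)['image'].cuda()).device)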
+
+# + colab={"base_uri": "https://localhost:8080/"} id="4QhMOtYzLmVK" outputId="94b8b24a-80a2-4300-df37-aa833a6dde1c" tags=["raises-exception"]
+# !nvidia-smi
+
+# + id="4zokGELhGPF8"
+device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+transforms = AugmentationSequential(
+    MinMaxNormalize(mins, maxs),
+    indices.AppendNDBI(index_swir=11, index_nir=7),
+    indices.AppendNDSI(index_green=2, index_swir=11),
+    indices.AppendNDVI(index_nir=7, index_red=3),
+    indices.AppendNDWI(index_green=2, index_nir=7),
+    K.RandomHorizontalFlip(p=0.5),
+    K.RandomVerticalFlip(p=0.5),
+    K.RandomAffine(degrees=(0, 90), p=0.25),
+    K.RandomGaussianBlur(kernel_size=(3, 3), sigma=(0.1, 2.0), p=0.25),
+    K.RandomResizedCrop(size=(512, 512), scale=(0.8, 1.0), p=0.25),
+    data_keys=['image'],
+)
+
+transforms_gpu = AugmentationSequential(
+    MinMaxNormalize(mins.to(device), maxs.to(device)),
+    indices.AppendNDBI(index_swir=11, index_nir=7),
+    indices.AppendNDSI(index_green=2, index_swir=11),
+    indices.AppendNDVI(index_nir=7, index_red=3),
+    indices.AppendNDWI(index_green=2, index_nir=7),
+    K.RandomHorizontalFlip(p=0.5),
+    K.RandomVerticalFlip(p=0.5),
+    K.RandomAffine(degrees=(0, 90), p=0.25),
+    K.RandomGaussianBlur(kernel_size=(3, 3), sigma=(0.1, 2.0), p=0.25),
+    K.RandomResizedCrop(size=(512, 512), scale=(0.8, 1.0), p=0.25),
+    data_keys=['image'],
+).to(device)
+
+
+def get_batch_cpu():
+    return dict(image=torch.randn(64, 13, 512, 512).to('cpu'))
+
+
+def get_batch_gpu():
+    return dict(image=torch.randn(64, 13, 512, 512).to(device))
+
+
+# + colab={"base_uri": "https://localhost:8080/"} id="vo43CqJ4IIXE" outputId="75e438f7-5ab1-47f4-9de9-b444b5b759f6"
+# %%timeit -n 1 -r 5
+_ = transforms(get_batch_cpu())
+
+# + colab={"base_uri": "https://localhost:8080/"} id="ICKXYZYrJCeh" outputId="9335cd58-90a6-4b8f-d27c-8bc833e76600"
+# %%timeit -n 1 -r 5
+_ = transforms_gpu(get_batch_gpu())
+
+# + [markdown] id="nkGy_g6tBAtF"
+# ## Visualize Images and Labels
+
+# + [markdown] id="3k4W98v27NtL"
+# This is a Google Colab browser for the EuroSAT dataset. Adjust the slider to visualize images in the dataset.
+
+# + id="O_6k7tcxz17x"
+transforms = AugmentationSequential(MinMaxNormalize(mins, maxs), data_keys=['image'])
+dataset = EuroSAT100(root, transforms=transforms)
+
+# + colab={"base_uri": "https://localhost:8080/", "height": 290} id="Uw8xDeg3BY-u" outputId="2b5f94c4-3aa8-4f30-a38d-b701b2332967"
+# @title EuroSAT Multispectral (MS) Browser { run: "auto", vertical-output: true }
+idx = 21  # @param {type:"slider", min:0, max:59, step:1}
+sample = dataset[idx]
+rgb = sample['image'][0, [3, 2, 1]]  # B04, B03, B02 -> RGB channel order
+image = T.ToPILImage()(rgb)
+print(f"Class Label: {dataset.classes[sample['label']]}")
+image.resize((256, 256), resample=Image.BILINEAR)
From f343a06e299c08a0287c16285b936963b83a60c8 Mon Sep 17 00:00:00 2001
From: "Adam J. 
Stewart" Date: Sun, 19 May 2024 22:07:12 +0200 Subject: [PATCH 2/4] Add jupytext dependency --- pyproject.toml | 2 ++ requirements/docs.txt | 1 + 2 files changed, 3 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 9707f472d8e..a688ab757c6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -103,6 +103,8 @@ datasets = [ docs = [ # ipywidgets 7+ required by nbsphinx "ipywidgets>=7", + # TODO: find min version + "jupytext", # nbsphinx 0.8.5 fixes bug with nbformat attributes "nbsphinx>=0.8.5", # release versions missing files, must install from master diff --git a/requirements/docs.txt b/requirements/docs.txt index 6a996a311cc..956c751a5fb 100644 --- a/requirements/docs.txt +++ b/requirements/docs.txt @@ -1,4 +1,5 @@ # docs ipywidgets==8.1.2 +jupytext==1.16.2 nbsphinx==0.9.4 sphinx==5.3.0 From 8d253b420a3d5125219fd8c8309977ed655b72da Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Sun, 19 May 2024 22:09:35 +0200 Subject: [PATCH 3/4] Ruff --- docs/conf.py | 4 +--- docs/tutorials/custom_raster_dataset.py | 4 ++++ docs/tutorials/transforms.py | 4 ++-- 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index b1321831812..3ab20f339f4 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -126,9 +126,7 @@ } # nbsphinx -nbsphinx_custom_formats = { - '.py': 'jupytext.reads', -} +nbsphinx_custom_formats = {'.py': 'jupytext.reads'} nbsphinx_execute = 'never' # TODO: branch/tag should change depending on which version of docs you look at # TODO: width option of image directive is broken, see: diff --git a/docs/tutorials/custom_raster_dataset.py b/docs/tutorials/custom_raster_dataset.py index ccea38734ce..af042c9d20e 100644 --- a/docs/tutorials/custom_raster_dataset.py +++ b/docs/tutorials/custom_raster_dataset.py @@ -167,6 +167,7 @@ # # Putting this all together into a single class, we get: + # + id="8sFb8BTTTxZD" class Sentinel2(RasterDataset): filename_glob = 'T*_B02_10m.tif' @@ -194,6 +195,7 @@ class Sentinel2(RasterDataset): # # A great test to make sure that the dataset works correctly is to try to plot an image. We'll add a plot function to our dataset to help visualize it. First, we need to modify the image so that it only contains the RGB bands, and ensure that they are in the correct order. We also need to ensure that the image is in the range 0.0 to 1.0 (or 0 to 255). Finally, we'll create a plot using matplotlib. + # + id="7PNFOy9mYq6K" class Sentinel2(RasterDataset): filename_glob = 'T*_B02_10m.tif' @@ -246,6 +248,7 @@ def plot(self, sample): # # If you want to add custom parameters to the class, you can override the `__init__` method. For example, let's say you have imagery that can be automatically downloaded. The `RasterDataset` base class doesn't support this, but you could add support in your subclass. Simply copy the parameters from the base class and add a new `download` parameter. + # + id="TxODAvIHFKNt" class Downloadable(RasterDataset): def __init__(self, root, crs, res, transforms, cache, download=False): @@ -255,6 +258,7 @@ def __init__(self, root, crs, res, transforms, cache, download=False): # download the dataset ... 
+
 # + [markdown] id="cI43f8DMF3iR"
 # ## Contributing
 #
diff --git a/docs/tutorials/transforms.py b/docs/tutorials/transforms.py
index 1a039632243..7b38a9dd4bb 100644
--- a/docs/tutorials/transforms.py
+++ b/docs/tutorials/transforms.py
@@ -52,13 +52,13 @@
 from torchgeo.datasets import EuroSAT100
 from torchgeo.transforms import AugmentationSequential, indices
 
-
 # + [markdown] id="oR3BCeV2AAop"
 # ## Custom Transforms
 
 # + [markdown] id="oVgqhF2udp4z"
 # Here we create a transform to show an example of how you can chain custom operations together with TorchGeo and Kornia transforms/augmentations. Note how our transform takes a dict of tensors as input. We identify our data by the keys ["image", "mask", "label", etc.] and follow this convention across TorchGeo datasets.
 
+
 # + id="3mixIK7mAC9G"
 class MinMaxNormalize(K.IntensityAugmentationBase2D):
     """Normalize channels to the range [0, 1] using min/max values."""
@@ -182,7 +182,7 @@ def apply_transform(
 # ## Transforms Usage
 
 # + [markdown] id="p28C8cTGE3dP"
-# Transforms can operate on both individual samples and batches of samples. This allows them to be used inside the dataset itself or externally, chained together with other transform operations using `nn.Sequential`. 
+# Transforms can operate on both individual samples and batches of samples. This allows them to be used inside the dataset itself or externally, chained together with other transform operations using `nn.Sequential`.
 
 # + colab={"base_uri": "https://localhost:8080/"} id="pJXUycffEjNX" outputId="d029826c-a546-4c8e-e254-db680c5045e8"
 transform = MinMaxNormalize(mins, maxs)
From 334de54ecfb4c8b6cfd3fa3fde614d478db8b1d1 Mon Sep 17 00:00:00 2001
From: "Adam J. Stewart"
Date: Mon, 20 May 2024 00:20:38 +0200
Subject: [PATCH 4/4] Temporary best effort builds

---
 .readthedocs.yaml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/.readthedocs.yaml b/.readthedocs.yaml
index 2c11def62d9..e75d57c473c 100644
--- a/.readthedocs.yaml
+++ b/.readthedocs.yaml
@@ -20,4 +20,4 @@ python:
 # Configuration for Sphinx documentation
 sphinx:
   configuration: docs/conf.py
-  fail_on_warning: true
+  # fail_on_warning: true
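-- 
A note on the new format: with the tutorials stored as jupytext "light"
scripts, they can still be edited as regular notebooks locally. A rough
sketch of the round trip using jupytext's Python API (the paths are
illustrative; this mirrors the 'jupytext.reads' hook added to docs/conf.py):

    import jupytext

    # parse the light-format script into an in-memory notebook
    nb = jupytext.read('docs/tutorials/transforms.py')
    # save it as a .ipynb to open in Jupyter, then convert back the same way
    jupytext.write(nb, 'docs/tutorials/transforms.ipynb')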