TileDatasets #1353

Open
wants to merge 5 commits into base: main
Changes from 4 commits
37 changes: 37 additions & 0 deletions conf/l7irishtile.yaml
@@ -0,0 +1,37 @@
module:
  _target_: torchgeo.trainers.SemanticSegmentationTask
  loss: "ce"
  model: "unet"
  backbone: "resnet18"
  weights: true
  learning_rate: 1e-4
  learning_rate_schedule_patience: 6
  in_channels: 8
  num_classes: 5
  num_filters: 64
  ignore_index: 0
  weight_decay: 0

datamodule:
  _target_: torchgeo.datamodules.L7IrishTileDataModule
  root: "/home/calebrobinson/ssdprivate/torchgeo/data/L7IrishSimple/"
  batch_size: 32
  patch_size: 256
  train_batches_per_epoch: 2000
  val_batches_per_epoch: 200
  num_workers: 6

trainer:
  _target_: lightning.pytorch.Trainer
  accelerator: gpu
  devices:
    - 3
  min_epochs: 50
  max_epochs: 100

program:
  seed: 0
  output_dir: output/l7irish/
  log_dir: logs/l7irish/
  overwrite: True
  experiment_name: unet_imagenet_lr1e-4_wd0
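For reference, these values are the defaults that the sweep script in experiments/ssl4eo/run_l7irish.py overrides on the command line for each run, e.g. (illustrative) python train.py config_file=conf/l7irishtile.yaml module.learning_rate=0.001 module.weight_decay=0.01 program.experiment_name=unet_0.001_ce_0.01_True_0 trainer.devices=[0].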
81 changes: 81 additions & 0 deletions experiments/ssl4eo/run_l7irish.py
@@ -0,0 +1,81 @@
#!/usr/bin/env python3
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""Runs the train script with a grid of hyperparameters."""
import itertools
import os
import subprocess
from multiprocessing import Process, Queue

# list of GPU IDs that we want to use, one job will be started for every ID in the list
GPUS = [0, 0, 1, 1, 2, 2, 3, 3]
DRY_RUN = False  # if True then only print out the commands to be run, if False then run them

# Hyperparameter options
model_options = ["unet", "fcn"]
backbone_options = ["resnet18"]
lr_options = [0.001, 0.0003, 0.0001, 0.00003]
loss_options = ["ce"]
wd_options = [0, 0.1, 0.01]
weight_options = [True, False]
seed_options = [0, 1, 2]


def do_work(work: "Queue[str]", gpu_idx: int) -> bool:
    """Process for each ID in GPUS."""
    while not work.empty():
        experiment = work.get()
        experiment = experiment.replace("GPU", str(gpu_idx))
        print(experiment)
        if not DRY_RUN:
            subprocess.call(experiment.split(" "))
    return True


if __name__ == "__main__":
    work: "Queue[str]" = Queue()

    for model, backbone, lr, loss, wd, weights, seed in itertools.product(
        model_options,
        backbone_options,
        lr_options,
        loss_options,
        wd_options,
        weight_options,
        seed_options,
    ):
        if model == "fcn" and not weights:
            continue

        if model != "unet":
            experiment_name = f"{model}_{backbone}_{lr}_{loss}_{wd}_{weights}_{seed}"
        else:
            experiment_name = f"{model}_{lr}_{loss}_{wd}_{weights}_{seed}"

        config_file = os.path.join("conf", "l7irishtile.yaml")
@nilsleh (Collaborator) commented on May 22, 2023:

Should we make one run_{downstream_task}.py script that has the config file name as an additional variable at the beginning of the file? Then we can use this one script for all the different downstream tasks, or are there some differences beyond the config files that we need to account for?
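For illustration, a minimal sketch of that suggestion (not part of this PR; TASK is a hypothetical constant, and the rest of the grid logic would stay shared across tasks):

import os

TASK = "l7irish"  # hypothetical: the only per-task setting at the top of the shared script
CONFIG_FILE = os.path.join("conf", f"{TASK}tile.yaml")
# ... the loop below would then pass CONFIG_FILE instead of rebuilding the path per task.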


        command = (
            "python train.py"
            + f" config_file={config_file}"
            + f" module.model={model}"
            + f" module.backbone={backbone}"
            + f" module.learning_rate={lr}"
            + f" module.loss={loss}"
            + f" module.weight_decay={wd}"
            + f" module.weights={weights}"
            + f" program.seed={seed}"
            + f" program.experiment_name={experiment_name}"
            + " trainer.devices=[GPU]"
        )
        command = command.strip()

        work.put(command)

    processes = []
    for gpu_idx in GPUS:
        p = Process(target=do_work, args=(work, gpu_idx))
        processes.append(p)
        p.start()
    for p in processes:
        p.join()
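Note on the sweep size: the full grid above is 2 * 1 * 4 * 1 * 3 * 2 * 3 = 144 combinations; the skipped fcn-without-weights cases remove 1 * 1 * 4 * 1 * 3 * 1 * 3 = 36 of them, leaving 108 experiments shared across the eight worker processes listed in GPUS. Setting DRY_RUN = True prints the generated commands without launching them.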
3 changes: 2 additions & 1 deletion torchgeo/datamodules/__init__.py
@@ -15,7 +15,7 @@
from .geo import BaseDataModule, GeoDataModule, NonGeoDataModule
from .gid15 import GID15DataModule
from .inria import InriaAerialImageLabelingDataModule
from .l7irish import L7IrishDataModule
from .l7irish import L7IrishDataModule, L7IrishTileDataModule
from .l8biome import L8BiomeDataModule
from .landcoverai import LandCoverAIDataModule
from .loveda import LoveDADataModule
@@ -41,6 +41,7 @@
    # GeoDataset
    "ChesapeakeCVPRDataModule",
    "L7IrishDataModule",
    "L7IrishTileDataModule",
    "L8BiomeDataModule",
    "NAIPChesapeakeDataModule",
    # NonGeoDataset
89 changes: 87 additions & 2 deletions torchgeo/datamodules/l7irish.py
@@ -5,10 +5,12 @@

import os
from typing import Any, Optional, Union

from lightning.pytorch import LightningDataModule
import torch
from torch.utils.data import DataLoader

from ..datasets import L7Irish, random_bbox_assignment
from ..samplers import GridGeoSampler, RandomBatchGeoSampler
from ..datasets import L7Irish, random_bbox_assignment, TileDataset
from ..samplers import (
    GridGeoSampler,
    GridTileGeoSampler,
    RandomBatchGeoSampler,
    RandomTileGeoSampler,
)
from .geo import GeoDataModule


@@ -74,3 +76,86 @@ def setup(self, stage: str) -> None:
            self.test_sampler = GridGeoSampler(
                self.test_dataset, self.patch_size, self.patch_size
            )


class L7IrishTileDataModule(LightningDataModule):
    """LightningDataModule that drives TileDataset splits of the L7 Irish dataset."""

    @staticmethod
    def preprocess(sample):
        """Scale imagery to [0, 1] and remap mask values to contiguous class indices."""
        sample["image"] = sample["image"] / 255.0

        mask_mapping = {64: 1, 128: 2, 191: 3, 255: 4}
        if "mask" in sample:
            mask = sample["mask"].squeeze()
            for k, v in mask_mapping.items():
                mask[mask == k] = v
            sample["mask"] = mask
        return sample

    def _get_all_the_fns(self, root):
        """Collect (image, mask) filename pairs for every area and path/row under root."""
        areas = L7Irish.md5s.keys()
        image_fns = []
        mask_fns = []
        for area in areas:
            for path_row in os.listdir(os.path.join(root, area)):
                if path_row == "p46_r14":
                    continue
                path, row = path_row.split("_")[:2]
                image_fns.append(
                    os.path.join(root, area, path_row, f"L7_{path}_{row}_stacked.TIF")
                )
                mask_fns.append(
                    os.path.join(root, area, path_row, f"L7_{path}_{row}_newmask2015.TIF")
                )
        return image_fns, mask_fns

    def __init__(
        self,
        root,
        batch_size=1,
        patch_size=32,
        train_batches_per_epoch=None,
        val_batches_per_epoch=None,
        num_workers=0,
        seed=0,
    ):
        super().__init__()
        self.image_fns, self.mask_fns = self._get_all_the_fns(root)
        self.batch_size = batch_size
        self.patch_size = patch_size
        self.train_batches_per_epoch = train_batches_per_epoch
        self.val_batches_per_epoch = val_batches_per_epoch
        self.num_workers = num_workers

        generator = torch.Generator().manual_seed(seed)

        # 60/20/20 train/val/test split over the tile filenames
        idxs = torch.randperm(len(self.image_fns), generator=generator)
        train_idxs = idxs[: int(len(idxs) * 0.6)]
        val_idxs = idxs[int(len(idxs) * 0.6) : int(len(idxs) * 0.8)]
        test_idxs = idxs[int(len(idxs) * 0.8) :]

        self.train_image_fns = [self.image_fns[i] for i in train_idxs]
        self.train_mask_fns = [self.mask_fns[i] for i in train_idxs]
        self.val_image_fns = [self.image_fns[i] for i in val_idxs]
        self.val_mask_fns = [self.mask_fns[i] for i in val_idxs]
        self.test_image_fns = [self.image_fns[i] for i in test_idxs]
        self.test_mask_fns = [self.mask_fns[i] for i in test_idxs]

    def setup(self, stage):
        self.train_dataset = TileDataset(
            self.train_image_fns,
            self.train_mask_fns,
            transforms=L7IrishTileDataModule.preprocess,
        )
        self.val_dataset = TileDataset(
            self.val_image_fns,
            self.val_mask_fns,
            transforms=L7IrishTileDataModule.preprocess,
        )
        self.test_dataset = TileDataset(
            self.test_image_fns,
            self.test_mask_fns,
            transforms=L7IrishTileDataModule.preprocess,
        )

    # def on_after_batch_transfer(self, batch: Any, dataloader_idx: int) -> Any:
    #     return super().on_after_batch_transfer(batch, dataloader_idx)

    def train_dataloader(self):
        sampler = RandomTileGeoSampler(
            self.train_dataset,
            self.patch_size,
            self.batch_size * self.train_batches_per_epoch,
        )
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            sampler=sampler,
            num_workers=self.num_workers,
        )

    def val_dataloader(self):
        sampler = RandomTileGeoSampler(
            self.val_dataset,
            self.patch_size,
            self.batch_size * self.val_batches_per_epoch,
        )
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            sampler=sampler,
            num_workers=self.num_workers,
        )

    def test_dataloader(self):
        sampler = GridTileGeoSampler(self.test_dataset, self.patch_size, self.patch_size)
        return DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            sampler=sampler,
            num_workers=self.num_workers,
        )

    def plot(self, sample):
        import matplotlib.pyplot as plt

        image = sample["image"].permute(1, 2, 0).numpy()
        mask = sample["mask"].numpy().squeeze()
        fig, axs = plt.subplots(1, 2, figsize=(10, 5))
        axs[0].imshow(image[:, :, [2, 1, 0]])
        axs[0].axis("off")
        axs[1].imshow(mask, vmin=0, vmax=4)
        axs[1].axis("off")
        return fig
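For context, a minimal usage sketch of the new datamodule outside a Lightning Trainer (illustrative only; the root path is a placeholder and must follow the layout that _get_all_the_fns expects):

from torchgeo.datamodules import L7IrishTileDataModule

dm = L7IrishTileDataModule(
    root="data/L7IrishSimple",  # placeholder path
    batch_size=32,
    patch_size=256,
    train_batches_per_epoch=2000,
    val_batches_per_epoch=200,
)
dm.setup("fit")
batch = next(iter(dm.train_dataloader()))
# With the 8-band Landsat 7 stacks this should yield an image batch of shape
# [32, 8, 256, 256] and, after preprocess(), a mask batch of shape [32, 256, 256].
print(batch["image"].shape, batch["mask"].shape)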
3 changes: 3 additions & 0 deletions torchgeo/datasets/__init__.py
@@ -106,6 +106,7 @@
)
from .ssl4eo import SSL4EO, SSL4EOL, SSL4EOS12
from .sustainbench_crop_yield import SustainBenchCropYield
from .tile import TileDataset
from .ucmerced import UCMerced
from .usavars import USAVars
from .utils import (
@@ -241,4 +242,6 @@
"random_grid_cell_assignment",
"roi_split",
"time_series_split",
# TileDataset
"TileDataset",
)
59 changes: 59 additions & 0 deletions torchgeo/datasets/tile.py
@@ -0,0 +1,59 @@
import rasterio
import rasterio.io
import rasterio.merge
import rasterio.windows
import torch
from torch.utils.data import Dataset


class TileDataset(Dataset):
    """Dataset that reads patches from lists of image tiles and optional mask tiles."""

    def __init__(self, image_fns, mask_fns=None, transforms=None, sanity_check=False):
        super().__init__()
        self.image_fns = image_fns
        self.mask_fns = mask_fns
        if self.mask_fns is not None:
            assert len(image_fns) == len(mask_fns)

        if sanity_check and mask_fns is not None:
            # Check that each image/mask pair has matching spatial dimensions
            for image_fn, mask_fn in zip(image_fns, mask_fns):
                with rasterio.open(image_fn) as f:
                    image_height, image_width = f.shape
                with rasterio.open(mask_fn) as f:
                    mask_height, mask_width = f.shape
                assert image_height == mask_height
                assert image_width == mask_width

        self.transforms = transforms

    def __len__(self):
        return len(self.image_fns)

    def __getitem__(self, index):
        # index is a (tile index, row offset, column offset, patch size) tuple
        i, y, x, patch_size = index
        assert 0 <= i < len(self.image_fns)

        sample = {
            "y": y,
            "x": x,
        }

        window = rasterio.windows.Window(x, y, patch_size, patch_size)

        image_fn = self.image_fns[i]
        with rasterio.open(image_fn) as f:
            image = f.read(window=window)
        sample["image"] = torch.from_numpy(image).float()

        if self.mask_fns is not None:
            mask_fn = self.mask_fns[i]
            with rasterio.open(mask_fn) as f:
                mask = f.read(window=window)
            sample["mask"] = torch.from_numpy(mask).long()

        if self.transforms is not None:
            sample = self.transforms(sample)

        return sample
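A short usage sketch of the indexing convention (illustrative; the file names are placeholders): __getitem__ expects a (tile index, row offset, column offset, patch size) tuple rather than a plain integer, which is what RandomTileGeoSampler and GridTileGeoSampler are expected to yield.

from torchgeo.datasets import TileDataset

ds = TileDataset(["tile_0.tif"], ["mask_0.tif"])  # placeholder file names
sample = ds[(0, 128, 256, 64)]  # a 64 x 64 window at row 128, col 256 of the first tile
print(sample["image"].shape)  # (bands, 64, 64) float tensor; sample["mask"] is (1, 64, 64) long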
3 changes: 3 additions & 0 deletions torchgeo/samplers/__init__.py
@@ -7,14 +7,17 @@
from .constants import Units
from .single import GeoSampler, GridGeoSampler, PreChippedGeoSampler, RandomGeoSampler
from .utils import get_random_bounding_box, tile_to_chips
from .tile import RandomTileGeoSampler, GridTileGeoSampler

__all__ = (
    # Samplers
    "GridGeoSampler",
    "GridTileGeoSampler",
    "PreChippedGeoSampler",
    "RandomGeoSampler",
    "RandomTileGeoSampler",
    # Batch samplers
    "RandomBatchGeoSampler",
    # Base classes
    "GeoSampler",
    "BatchGeoSampler",