Skip to content

Commit

Permalink
Merge pull request #84 from initze/augmentation
Browse files Browse the repository at this point in the history
Augmentation
  • Loading branch information
initze authored Jan 3, 2024
2 parents 5345512 + 143b5d9 commit 8ec9977
Show file tree
Hide file tree
Showing 9 changed files with 297 additions and 137 deletions.
20 changes: 12 additions & 8 deletions config_example.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
model:
# Model Architecture. Available:
# Unet, UnetPlusPlus, Unet3Plus, MAnet, Linknet, FPN, PSPNet, DeepLabV3, DeepLabV3Plus, PAN
architecture: Unet
architecture: UnetPlusPlus
# Model Encoder. Examples:
# resnet18, resnet34, resnet50, resnet101, resnet152
# Check https://github.com/qubvel/segmentation_models.pytorch#encoders for the full list of available encoders
Expand All @@ -24,29 +24,33 @@ tile_size: 512 # tile size in pixels
sampling_mode: deterministic
data_sources: # Enabled input features
- PlanetScope
- NDVI
- TCVIS
#- RelativeElevation
- RelativeElevation
- AbsoluteElevation
- Slope
- Hillshade
datasets:
train:
normalize: true
augment: true
augment_types:
- HorizontalFlip
- VerticalFlip
- Blur
- RandomRotate90
- RandomBrightnessContrast
- MultiplicativeNoise
- RandomHorizontalFlip
- RandomVerticalFlip
- RandomRotation
- RandomResizedCrop
- GaussianBlur
shuffle: true
scenes:
- 4694518_0870514_2021-07-15_1064
val:
normalize: true
augment: false
shuffle: false
scenes:
- 20200722_081437_1032
test:
normalize: true
augment: false
shuffle: false
scenes:
Expand Down
61 changes: 32 additions & 29 deletions environment.yml
Original file line number Diff line number Diff line change
@@ -1,41 +1,44 @@
name: aicore
name: thaw_slump_segmentation
channels:
- pytorch
- conda-forge
- nvidia
- defaults
dependencies:
- earthengine-api=0.1.227
- albumentations
- earthengine-api
- einops
- efficientnet-pytorch=0.6.3
- efficientnet-pytorch==0.7.1
- fsspec
- geedim=1.7.0
- gdal
- geedim
- geemap
- geopandas
- h5netcdf=1.1.0
- joblib=1.0.1
- matplotlib=3.2.2
- numpy=1.18.5
- numpy-base=1.18.5
- opencv=4.8.1
- pandas=1.0.5
- pillow=9.4
- pip=20.1.1
- pretrainedmodels=0.7.4
- pyproj=2.6.1.post1
- python=3.7
- pytorch=1.13.1 #
- h5netcdf
- h5py
- joblib=1.3
- matplotlib
- numpy
- pandas
- pillow
- pip
- pretrainedmodels
- pyproj
- python=3.10
- pytorch=2.1.1
- pytorch-cuda=11.7
- pyyaml=5.3.1
- rasterio=1.1.5
- requests=2.24.0
- rioxarray=0.9.1
- scikit-image=0.19.3
- tensorboard=2.2.1
- tqdm=4.65.0
- torchvision=0.14.1 #
- timm=0.4.12
- wandb=0.15.3
- xarray=2023.4.2
- pyyaml
- rasterio
- requests
- rioxarray
- rtree
- scikit-image
- tensorboard
- tqdm
- torchvision
- timm
- wandb=0.15.3
- xarray
- yacs
- pip:
- torchsummary=1.5.1
- torchsummary
2 changes: 2 additions & 0 deletions lib/data/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,3 +218,5 @@ def create_encoding_dict(xr_dataset, encoding_dict={"compression": "gzip", "comp
return encoding




27 changes: 17 additions & 10 deletions lib/data/loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,14 @@
import torch
import numpy as np
from torch.utils.data import DataLoader, ConcatDataset, Dataset
from ..utils import Augment, Normalize
from ..utils.data import Augment_A2, Augment_TV, Normalize
from math import ceil
from einops import rearrange
from tqdm import tqdm
from pathlib import Path
from .base import _LAYER_REGISTRY
from torchvision.transforms import v2
from skimage.measure import find_contours


class NCDataset(Dataset):
Expand Down Expand Up @@ -48,6 +49,7 @@ def __getitem__(self, idx):
y0 = int(torch.randint(0, self.H - self.tile_size, ()))
x0 = int(torch.randint(0, self.W - self.tile_size, ()))
elif self.sampling_mode == 'targets_only':
# NOTE: torch.randint raises when self.bboxes is empty (it requires low < high)
bbox_idx = int(torch.randint(0, len(self.bboxes), ()))
ymin, xmin, ymax, xmax = self.bboxes[bbox_idx]

Expand Down Expand Up @@ -84,7 +86,7 @@ def __getitem__(self, idx):
if 'Mask' in tile:
return (
np.concatenate([tile[k] for k in tile if k != 'Mask'], axis=0),
tile['Mask'],
tile['Mask'].squeeze(),
metadata
)
else:
Expand All @@ -94,14 +96,18 @@ def __getitem__(self, idx):
)

def __len__(self):
# skip dataset if empty
if (self.sampling_mode == 'targets_only'):
if (len(self.bboxes) == 0):
return 0
return self.H_tile * self.W_tile


def single_tile_loader(tile_path, config):
data = NCDataset(tile_path, config)

return DataLoader(
all_data,
data,
shuffle = False,
batch_size = config['batch_size'],
num_workers=config['num_workers'],
Expand All @@ -128,26 +134,27 @@ def get_loader(config):
scene_names = config['scenes']
scenes = [NCDataset(f'{root}/{scene}.nc', config) for scene in scene_names]
all_data = ConcatDataset(scenes)


if config['augment']:
# add loading from config
# check if validation also gets augmented
print(config['augment_types'])
if config['augment_types'] is not None:
all_data = Augment(all_data, augment_types=config['augment_types'])

# TODO: test if normalization step can be used here
all_data = Normalize(all_data)
# Torchvision
all_data = Augment_TV(all_data, augment_types=config['augment_types'], tile_size=config['tile_size'])
# albumentations
#all_data = Augment_A2(all_data, augment_types=config['augment_types'], tile_size=config['tile_size'])
# moving it one level lower breaks validation
if config['normalize']:
all_data = Normalize(all_data)

return DataLoader(
all_data,
shuffle = (config['sampling_mode'] != 'deterministic'),
batch_size=config['batch_size'],
num_workers=config['num_workers'],
persistent_workers=True,
pin_memory=True,

pin_memory=True
)


Expand Down
20 changes: 19 additions & 1 deletion lib/data/planet.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,19 @@
from .base import TileSource, Scene, cache_path
from lib.data_pre_processing import udm

def scale_array_ignore_zero_add_constant_per_band_3D(arr, constant=1e-5, perc_min=2, perc_max=98):
    """Percentile-stretch the non-zero values of ``arr`` into ``[constant, 1]``.

    Zero entries are left at exactly 0 and are excluded from the percentile
    computation. For a 3-D array the stretch is computed independently for
    each slice along axis 0 (assumed to be the band axis, i.e. a
    ``(band, rows, cols)`` layout -- TODO confirm against callers). Arrays
    of any other rank are stretched as a single band.

    The previous implementation indexed with a boolean mask, which flattens
    the array, so one global percentile pair was applied to all bands.

    Args:
        arr: Array-like of raw values; 0 marks pixels to ignore.
        constant: Lower clip bound for non-zero pixels, so that valid pixels
            never collapse to 0.
        perc_min: Lower percentile mapped to 0 before clipping.
        perc_max: Upper percentile mapped to 1 before clipping.

    Returns:
        The input array itself when it contains no non-zero value; otherwise
        a new floating-point array of the same shape.
    """
    arr = np.asarray(arr)
    non_zero_mask = arr != 0
    if not non_zero_mask.any():
        return arr

    # View the data as a stack of bands; non-3-D input becomes a single band.
    if arr.ndim == 3:
        bands, masks = arr, non_zero_mask
    else:
        bands, masks = arr.reshape(1, -1), non_zero_mask.reshape(1, -1)

    out_dtype = arr.dtype if np.issubdtype(arr.dtype, np.floating) else np.float64
    out = np.zeros(bands.shape, dtype=out_dtype)

    for band, mask, dest in zip(bands, masks, out):
        if not mask.any():
            # Band without valid pixels: nothing to scale, stays all zero.
            continue
        values = band[mask]
        lo, hi = np.percentile(values, [perc_min, perc_max])
        if hi == lo:
            # Constant band: avoid a division by zero.
            hi += 1e-4
        # ``dest`` is a view into ``out``, so this writes the result in place.
        dest[mask] = np.clip((values - lo) / (hi - lo), constant, 1)

    return out.reshape(arr.shape)

class PlanetScope(TileSource):
def __init__(self, tile_path: Union[str, Path]):
Expand Down Expand Up @@ -65,4 +78,9 @@ def __repr__(self):

@staticmethod
def normalize(tile):
return np.clip((tile / 5000), 0, 1)
clipped = np.clip((tile / 3000), 0, 1)
#np.clip(scale_array_ignore_zero_add_constant(clipped), 0, 1)
return clipped
#return scale_array_ignore_zero_add_constant_per_band_3D(clipped)
#return np.clip((tile / 3000), 0, 1)

3 changes: 2 additions & 1 deletion lib/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,6 @@

from .plot_info import *
from .logging import init_logging, get_logger, log_run
from .data import Augment, Transformed, Scaling, Normalize
from .data import Transformed, Scaling, Normalize, Augment_A2, Augment_TV
from math import ceil
from .images import extract_patches, Compositor, extract_contours
Loading

0 comments on commit 8ec9977

Please sign in to comment.