Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/ensemble proba #135

Merged
merged 11 commits into from
Jun 3, 2024
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,5 @@ dist/
.python-version
requirements.lock
requirements-dev.lock
wandb
wandb
models
32 changes: 32 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,38 @@ gdal_path: '%CONDA_PREFIX%\Scripts' # must be single quote
gdal_bin: '%CONDA_PREFIX%\Library\bin' # must be single quote
```

## CLI

Run in dev:

```sh
$ rye run thaw-slump-segmentation hello tobi
Hello tobi
```

or run as python module:

```sh
$ rye run python -m thaw_slump_segmentation hello tobi
Hello tobi
```

With an activated environment (e.g. after installation), simply omit the `rye run` prefix:

```sh
$ source .venv/bin/activate
$ thaw-slump-segmentation hello tobi
Hello tobi
```

or

```sh
$ source .venv/bin/activate
$ python -m thaw_slump_segmentation hello tobi
Hello tobi
```

## Data Processing

### Data Preprocessing for Planet data
Expand Down
16 changes: 14 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ version = "0.10.2"
description = "Thaw slump segmentation workflow using PlanetScope data and pytorch"
authors = [
{ name = "Ingmar Nitze", email = "[email protected]" },
{ name = "Konrad Heidler", email = "[email protected]" }
{ name = "Konrad Heidler", email = "[email protected]" },
]
dependencies = [
"torch==2.2.0",
Expand Down Expand Up @@ -42,7 +42,11 @@ dependencies = [
"opencv-python>=4.9.0.80",
"swifter>=1.4.0",
"mkdocs-awesome-pages-plugin>=2.9.2",
"rich"
"ruff>=0.4.2",
"ipykernel>=6.29.4",
"rich>=13.7.1",
"torchsummary>=1.5.1",
"typer>=0.12.3",
]
readme = "README.md"
requires-python = ">= 3.10"
Expand All @@ -56,6 +60,7 @@ process_02_inference = "thaw_slump_segmentation.scripts.process_02_inference:mai
process_03_ensemble = "thaw_slump_segmentation.scripts.process_03_ensemble:main"
setup_raw_data = "thaw_slump_segmentation.scripts.setup_raw_data:main"
train = "thaw_slump_segmentation.scripts.train:main"
thaw-slump-segmentation = "thaw_slump_segmentation.main:cli"

[build-system]
requires = ["hatchling"]
Expand All @@ -79,3 +84,10 @@ packages = ["src/thaw_slump_segmentation"]
# [[tool.rye.sources]]
# name = "pytorch"
# url = "https://download.pytorch.org/whl/cu118"

[tool.ruff]
line-length = 120

[tool.ruff.format]
quote-style = "single"
docstring-code-format = true
3 changes: 3 additions & 0 deletions src/thaw_slump_segmentation/__main__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
"""Package entry point: enables ``python -m thaw_slump_segmentation``."""

from thaw_slump_segmentation.main import cli

# Guard so that importing this module (e.g. by tooling or tests) does not
# immediately execute the CLI; only `python -m thaw_slump_segmentation` does.
if __name__ == '__main__':
    cli()
37 changes: 21 additions & 16 deletions src/thaw_slump_segmentation/data_pre_processing/gdal.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,40 +5,45 @@

import os
import sys
import yaml
from pathlib import Path

import yaml

_module = sys.modules[__name__]


def initialize(args=None, *, bin=None, path=None):
    """Resolve GDAL tool locations and publish them as module attributes.

    Resolution order:
      1. ``args`` — an argparse-style namespace providing ``gdal_path`` and
         ``gdal_bin`` (takes precedence over everything else),
      2. the keyword-only ``bin``/``path`` pair (both must be given),
      3. a ``system.yml`` file in the current working directory,
      4. empty strings, i.e. the tools are expected to be on the PATH.

    Sets on this module: ``gdal_path``, ``gdal_bin``, ``rasterize``,
    ``translate``, ``warp``, ``merge``, ``retile``, ``polygonize``.

    NOTE: the parameter name ``bin`` shadows the builtin of the same name;
    it is kept for backward compatibility with existing keyword callers.
    """
    system_yml = Path('system.yml')

    if args is not None:
        # Command-line arguments win.
        _module.gdal_path = args.gdal_path
        _module.gdal_bin = args.gdal_bin
    elif bin is not None and path is not None:
        # Explicit keyword override.
        _module.gdal_path = path
        _module.gdal_bin = bin
    elif system_yml.exists():
        # Fall back to the configuration in system.yml; default missing keys
        # to '' so the module attributes are always defined afterwards.
        system_config = yaml.safe_load(system_yml.open())
        _module.gdal_path = system_config.get('gdal_path', '')
        _module.gdal_bin = system_config.get('gdal_bin', '')
    else:
        # Nothing configured: rely on the tools being on the PATH.
        _module.gdal_path = ''
        _module.gdal_bin = ''

    # Compiled binaries live under gdal_bin ...
    _module.rasterize = os.path.join(_module.gdal_bin, 'gdal_rasterize')
    _module.translate = os.path.join(_module.gdal_bin, 'gdal_translate')
    _module.warp = os.path.join(_module.gdal_bin, 'gdalwarp')

    # ... while the Python helper scripts live under gdal_path.
    _module.merge = os.path.join(_module.gdal_path, 'gdal_merge.py')
    _module.retile = os.path.join(_module.gdal_path, 'gdal_retile.py')
    _module.polygonize = os.path.join(_module.gdal_path, 'gdal_polygonize.py')
93 changes: 71 additions & 22 deletions src/thaw_slump_segmentation/data_pre_processing/udm.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,50 @@


def get_mask_from_udm2(infile, bands=[1, 2, 4, 5], nodata=1):
    """Return a boolean NoData mask derived from a UDM2 file.

    Args:
        infile: path to the UDM2 raster.
        bands: 0-based indices of the quality bands whose flags mark
            unusable data. (The list default is never mutated, so it is
            safe; a tuple would change numpy's indexing semantics.)
        nodata: flag value marking an unusable pixel in the selected bands.

    Returns:
        numpy.ndarray: boolean array where True = NoData, i.e. any of the
        selected bands carries the ``nodata`` flag.
    """
    with rio.open(infile) as src:
        flags = src.read()
    # A pixel is NoData if any selected quality band equals the flag value.
    return flags[bands].max(axis=0) == nodata

def get_mask_from_udm2_v2(infile):
    """
    Create a data mask from a UDM2 V2 file.

    The data mask is a boolean array where True values represent good data,
    and False values represent no data or unusable data.

    Args:
        infile (str): Path to the input UDM2 V2 file.

    Returns:
        numpy.ndarray: A boolean array representing the data mask. Note that
        the leading band axis is kept (shape ``(1, height, width)``) because
        single bands are selected with list indexing (``a[[0]]``).

    Notes:
        - Assumed UDM2 V2 band meanings (confirm against Planet UDM2 docs):
            - Band 0: clear data
            - Bands 1, 2, 4, 5: unusable data (if any has a value of 1)
            - Band 7: no data
        - The mask combines: clear data AND NOT (no data OR unusable data).
    """
    with rio.open(infile) as src:
        a = src.read()
    # Any of the quality-flag bands set to 1 means the pixel is unusable.
    unusable_data = a[[1, 2, 4, 5]].max(axis=0) == 1
    clear_data = a[[0]] == 1
    no_data = a[[7]] == 1
    # Final data mask: False = no or unusable data, True = good data.
    data_mask = np.logical_and(clear_data, ~np.logical_or(no_data, unusable_data))
    return data_mask


def get_mask_from_udm(infile, nodata=1):
Expand All @@ -27,30 +62,44 @@ def get_mask_from_udm(infile, nodata=1):
return np.array(mask, dtype=np.uint8)


# TODO: the mask value is applied inconsistently across the masking functions
def burn_mask(file_src, file_dst, file_udm, file_udm2=None, mask_value=0):
with rio.Env():
# checks

if file_udm:
mask_udm = get_mask_from_udm(file_udm)

"""
Apply a data mask to a raster file and save the result to a new file.

Args:
file_src (str): Path to the input raster file.
file_dst (str): Path to the output raster file.
file_udm (str): Path to the UDM file (not used in this function).
file_udm2 (str, optional): Path to the UDM2 V2 file. If provided, the data mask
will be derived from this file using the `get_mask_from_udm2_v2` function.
mask_value (int, optional): Value to use for masked (invalid) pixels in the output file.
Default is 0.

Returns:
int: Always returns 1 (for successful execution).

Raises:
ValueError: If `file_udm2` is not provided.

Notes:
- The function reads the input raster file using rasterio and applies the data mask
by multiplying the raster data with the mask.
- The masked raster data is then written to the output file with the same metadata
as the input file.
- If `file_udm2` is not provided, a `ValueError` is raised.
- The `file_udm` parameter is not used in this function.
- The function uses the rasterio library for reading and writing raster files.
"""
with rio.Env():
if file_udm2:
mask_udm2 = get_mask_from_udm2(file_udm2)
#mask_udm2 = get_mask_from_udm2_v2(file_udm2)

# merge masks if both
if ((file_udm is not None) & (file_udm2 is not None)):
clear_mask = np.array([mask_udm, mask_udm2]).max(axis=0) == mask_value
# only use one of udm or udm if only one exists
elif file_udm is not None:
clear_mask = mask_udm
elif file_udm2 is not None:
clear_mask = mask_udm2
mask_udm2 = get_mask_from_udm2_v2(file_udm2)
else:
raise ValueError

with rio.open(file_src) as ds_src:
data = ds_src.read()*clear_mask
# apply data mask (multiply)
data = ds_src.read() * mask_udm2
profile = ds_src.profile
with rio.open(file_dst, 'w', **profile) as ds_dst:
ds_dst.write(data)
Expand Down
5 changes: 4 additions & 1 deletion src/thaw_slump_segmentation/data_pre_processing/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ def get_mask_images(image_directory, udm='udm.tif', udm2='udm2.tif', images=['_S
image_files = []
for im in images:
image_files.extend([f for f in flist if im in f])

# check which udms are available, if not then set to None
try:
udm_file = [f for f in flist if udm in f][0]
Expand All @@ -94,6 +95,7 @@ def get_mask_images(image_directory, udm='udm.tif', udm2='udm2.tif', images=['_S
udm2_file = [f for f in flist if udm2 in f][0]
except:
udm2_file = None

# raise error if no udms available
if (udm_file == None) & (udm2_file == None):
raise ValueError(f'There are no udm or udm2 files for image {image_directory.name}!')
Expand Down Expand Up @@ -161,13 +163,14 @@ def aux_data_to_tiles(image_directory, aux_data, outfile):
# load template and get props
images = get_mask_images(image_directory, udm='udm.tif', udm2='udm2.tif', images=['_SR.tif'])
image = images['images'][0]

# prepare gdalwarp call
xmin, xmax, ymin, ymax = geom_from_image_bounds(image)
crs = crs_from_image(image)
xres, yres = resolution_from_image(image)

# run gdalwarp call
outfile = f'{image_directory}/{outfile}'#os.path.join(image_directory,outfile)
s_run = f'{gdal.warp} -te {xmin} {ymin} {xmax} {ymax} -tr {xres} {yres} -r cubic -t_srs {crs} -co COMPRESS=DEFLATE {aux_data} {outfile}'
#s_run = f'{gdal.warp} -te {xmin} {ymin} {xmax} {ymax} -tr 3 3 -r cubic -t_srs {crs} -co COMPRESS=DEFLATE {aux_data} {outfile}'
log_run(s_run, _logger)
return 1
49 changes: 49 additions & 0 deletions src/thaw_slump_segmentation/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
"""Typer-based CLI entry point for the thaw-slump-segmentation package."""

import typer

from thaw_slump_segmentation.scripts.download_s2_4band_planet_format import download_s2_4band_planet_format
from thaw_slump_segmentation.scripts.inference import inference
from thaw_slump_segmentation.scripts.prepare_data import prepare_data
from thaw_slump_segmentation.scripts.prepare_s2_4band_planet_format import prepare_s2_4band_planet_format
from thaw_slump_segmentation.scripts.process_02_inference import process_02_inference
from thaw_slump_segmentation.scripts.process_03_ensemble import process_03_ensemble
from thaw_slump_segmentation.scripts.setup_raw_data import setup_raw_data
from thaw_slump_segmentation.scripts.train import train

# TODO: Move this comment to docs
# GEE is used for:
# - Setup of raw data (init call inside function)
# - download of S2 images (init call global at module level)
# - prepare of S2 images (init call global at module level)
# GDAL is used for:
# - Setup of raw data (in threaded function)
# - prepare data (in main function)
# - inference (in main function)
# - prepare of S2 images (but its not implemented via gdal module but hardcoded)

cli = typer.Typer(pretty_exceptions_show_locals=False)


@cli.command()
def hello(name: str):
    """Smoke-test command: greet *name*."""
    typer.echo(f'Hello {name}')


# Top-level commands, registered under their own function names.
for _command in (train, inference):
    cli.command()(_command)

# Data acquisition / preparation commands, grouped under `data`.
data_cli = typer.Typer()
for _name, _command in (
    ('download-planet', download_s2_4band_planet_format),
    ('prepare-planet', prepare_s2_4band_planet_format),
    ('setup-raw', setup_raw_data),
    ('prepare', prepare_data),
):
    data_cli.command(_name)(_command)
cli.add_typer(data_cli, name='data')

# Processing pipeline commands, grouped under `process`.
process_cli = typer.Typer()
for _name, _command in (
    ('inference', process_02_inference),
    ('ensemble', process_03_ensemble),
):
    process_cli.command(_name)(_command)
cli.add_typer(process_cli, name='process')
Loading