From feb239b75241fccbe5bd1dd1de56af4ffc6db5e3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tobias=20H=C3=B6lzer?= Date: Thu, 2 May 2024 02:41:32 +0200 Subject: [PATCH 1/9] Convert CLI of train, inference, download-planet and setuip-raw --- README.md | 32 +++ pyproject.toml | 16 +- src/thaw_slump_segmentation/__main__.py | 3 + src/thaw_slump_segmentation/main.py | 29 ++ .../download_s2_4band_planet_format.py | 48 +++- .../scripts/inference.py | 266 ++++++++++++------ .../scripts/setup_raw_data.py | 132 ++++++--- src/thaw_slump_segmentation/scripts/train.py | 133 ++++++--- 8 files changed, 485 insertions(+), 174 deletions(-) create mode 100644 src/thaw_slump_segmentation/__main__.py create mode 100644 src/thaw_slump_segmentation/main.py diff --git a/README.md b/README.md index 3f235bf..717c4c1 100644 --- a/README.md +++ b/README.md @@ -47,6 +47,38 @@ gdal_path: '%CONDA_PREFIX%\Scripts' # must be single quote gdal_bin: '%CONDA_PREFIX%\Library\bin' # must be single quote ``` +## CLI + +Run in dev: + +```sh +$ rye run thaw-slump-segmentation hello tobi +Hello tobi +``` + +or run as python module: + +```sh +$ rye run python -m thaw_slump_segmentation hello tobi +Hello tobi +``` + +With activated env, e.g. after installation, just remove the `rye run`: + +```sh +$ source .venv/bin/activate +$ thaw-slump-segmentation hello tobi +Hello tobi +``` + +or + +```sh +$ source .venv/bin/activate +$ python -m thaw_slump_segmentation hello tobi +Hello tobi +``` + ## Data Processing ### Data Preprocessing for Planet data diff --git a/pyproject.toml b/pyproject.toml index f8240b2..2298396 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ version = "0.10.1" description = "Add your description here" authors = [ { name = "Ingmar Nitze", email = "ingmar.nitze@awi.de" }, - { name = "Konrad Heidler", email = "k.heidler@tum.de" } + { name = "Konrad Heidler", email = "k.heidler@tum.de" }, ] dependencies = [ "torch==2.2.0", @@ -42,7 +42,11 @@ dependencies = [ "opencv-python>=4.9.0.80", "swifter>=1.4.0", "mkdocs-awesome-pages-plugin>=2.9.2", - "rich" + "ruff>=0.4.2", + "ipykernel>=6.29.4", + "rich>=13.7.1", + "torchsummary>=1.5.1", + "typer>=0.12.3", ] readme = "README.md" requires-python = ">= 3.10" @@ -56,6 +60,7 @@ process_02_inference = "thaw_slump_segmentation.scripts.process_02_inference:mai process_03_ensemble = "thaw_slump_segmentation.scripts.process_03_ensemble:main" setup_raw_data = "thaw_slump_segmentation.scripts.setup_raw_data:main" train = "thaw_slump_segmentation.scripts.train:main" +thaw-slump-segmentation = "thaw_slump_segmentation.main:cli" [build-system] requires = ["hatchling"] @@ -79,3 +84,10 @@ packages = ["src/thaw_slump_segmentation"] # [[tool.rye.sources]] # name = "pytorch" # url = "https://download.pytorch.org/whl/cu118" + +[tool.ruff] +line-length = 120 + +[tool.ruff.format] +quote-style = "single" +docstring-code-format = true diff --git a/src/thaw_slump_segmentation/__main__.py b/src/thaw_slump_segmentation/__main__.py new file mode 100644 index 0000000..82a5110 --- /dev/null +++ b/src/thaw_slump_segmentation/__main__.py @@ -0,0 +1,3 @@ +from thaw_slump_segmentation.main import cli + +cli() diff --git a/src/thaw_slump_segmentation/main.py b/src/thaw_slump_segmentation/main.py new file mode 100644 index 0000000..61dd655 --- /dev/null +++ b/src/thaw_slump_segmentation/main.py @@ -0,0 +1,29 @@ +import typer + +from thaw_slump_segmentation.scripts.download_s2_4band_planet_format import download_s2_4band_planet_format +from thaw_slump_segmentation.scripts.inference import inference 
+from thaw_slump_segmentation.scripts.setup_raw_data import setup_raw_data +from thaw_slump_segmentation.scripts.train import train + +cli = typer.Typer() + +cli.command()(train) +cli.command()(inference) + + +@cli.command() +def hello(name: str): + typer.echo(f'Hello {name}') + + +@cli.command() +def goodbye(name: str): + typer.echo(f'Goodbye {name}') + + +data_cli = typer.Typer() + +data_cli.command('download-planet')(download_s2_4band_planet_format) +data_cli.command('setup-raw')(setup_raw_data) + +cli.add_typer(data_cli, name='data') diff --git a/src/thaw_slump_segmentation/scripts/download_s2_4band_planet_format.py b/src/thaw_slump_segmentation/scripts/download_s2_4band_planet_format.py index f076fbc..fc06109 100644 --- a/src/thaw_slump_segmentation/scripts/download_s2_4band_planet_format.py +++ b/src/thaw_slump_segmentation/scripts/download_s2_4band_planet_format.py @@ -1,19 +1,18 @@ -import rasterio -import os +import argparse from pathlib import Path -import numpy as np +from typing import List + import ee -import eemont import geemap -ee.Initialize(opt_url='https://earthengine-highvolume.googleapis.com') +import typer +from typing_extensions import Annotated -import argparse +ee.Initialize(opt_url='https://earthengine-highvolume.googleapis.com') def download_S2image_preprocessed(s2_image_id, outfile, outbands=['B2', 'B3', 'B4', 'B8'], factor=1e4): - ic = ee.ImageCollection(ee.Image(f'COPERNICUS/S2_SR_HARMONIZED/{s2_image_id}')) # load basic image and preprocess (maks clouds, scale and offset) - image = ee.Image(f'COPERNICUS/S2_SR_HARMONIZED/{s2_image_id}').preprocess()#.spectralIndices(['NDVI']) + image = ee.Image(f'COPERNICUS/S2_SR_HARMONIZED/{s2_image_id}').preprocess() # .spectralIndices(['NDVI']) # select corresponding bands image_4Band = image.select(outbands) # scale by 10k and convert to uint16 @@ -24,20 +23,41 @@ def download_S2image_preprocessed(s2_image_id, outfile, outbands=['B2', 'B3', 'B geemap.download_ee_image(ee.Image(image_out), outfile) return 0 -if __name__ == "__main__": - parser = argparse.ArgumentParser(description='Download preprocessed S2 image.', - formatter_class=argparse.ArgumentDefaultsHelpFormatter) + +def download_s2_4band_planet_format( + data_dir: Annotated[Path, typer.Argument(help='Output directory')], + s2ids: Annotated[List[str], typer.Argument(help='S2 image ID, you can use several separated by space')], +): + """Download preprocessed S2 image.""" + for s2id in s2ids: + # Call the function with the provided s2id + outfile = data_dir / s2id / f'{s2id}_SR.tif' + if not data_dir.exists(): + print('Creating output directory', data_dir) + data_dir.mkdir() + download_S2image_preprocessed(s2id, outfile) + + +# ! 
Moving legacy argparse cli to main to maintain compatibility with the original script +def main(): + parser = argparse.ArgumentParser( + description='Download preprocessed S2 image.', formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) parser.add_argument('--s2id', type=str, nargs='+', help='S2 image ID, you can use several separated by space') parser.add_argument('--data_dir', type=str, help='Output directory') args = parser.parse_args() - + outdir = Path(args.data_dir) s2id = args.s2id - + for s2id in args.s2id: # Call the function with the provided s2id - outfile = outdir/ s2id / f'{s2id}_SR.tif' + outfile = outdir / s2id / f'{s2id}_SR.tif' if not outdir.exists(): print('Creating output directory', outdir) outdir.mkdir() download_S2image_preprocessed(s2id, outfile) + + +if __name__ == '__main__': + main() diff --git a/src/thaw_slump_segmentation/scripts/inference.py b/src/thaw_slump_segmentation/scripts/inference.py index a610b5a..e2fff7a 100644 --- a/src/thaw_slump_segmentation/scripts/inference.py +++ b/src/thaw_slump_segmentation/scripts/inference.py @@ -10,26 +10,28 @@ """ import argparse +import os +from collections import namedtuple +from datetime import datetime from pathlib import Path +from typing import List -import rasterio as rio -import numpy as np import matplotlib.pyplot as plt -import os +import numpy as np +import rasterio as rio import torch import torch.nn as nn +import typer +import yaml from tqdm import tqdm -from datetime import datetime +from typing_extensions import Annotated -from ..models import create_model -from ..utils.plot_info import flatui_cmap -from ..utils import init_logging, get_logger, log_run +from ..data_loading import DataSources from ..data_pre_processing import gdal - +from ..models import create_model from ..scripts.setup_raw_data import preprocess_directory -from ..data_loading import DataSources - -import yaml +from ..utils import get_logger, init_logging, log_run +from ..utils.plot_info import flatui_cmap cmap_prob = flatui_cmap('Midnight Blue', 'Alizarin') cmap_dem = flatui_cmap('Alizarin', 'Clouds', 'Peter River') @@ -38,53 +40,33 @@ FIGSIZE_MAX = 20 -parser = argparse.ArgumentParser(description='Inference Script', - formatter_class=argparse.ArgumentDefaultsHelpFormatter) -parser.add_argument("--gdal_bin", default='', help="Path to gdal binaries") -parser.add_argument("--gdal_path", default='', help="Path to gdal scripts") -parser.add_argument("--n_jobs", default=-1, type=int, help="number of parallel joblib jobs") -parser.add_argument("--ckpt", default='latest', type=str, help="Checkpoint to use") -parser.add_argument("--data_dir", default='data', type=Path, help="Path to data processing dir") -parser.add_argument("--log_dir", default='logs', type=Path, help="Path to log dir") -parser.add_argument("--inference_dir", default='inference', type=Path, help="Main inference directory") -parser.add_argument("-n", "--name", default=None, type=str, help="Name of inference run, data will be stored in subdirectory") -parser.add_argument("-m", "--margin_size", default=256, type=int, help="Size of patch overlap") -parser.add_argument("-p", "--patch_size", default=1024, type=int, help="Size of patches") -parser.add_argument("model_path", type=str, help="path to model, use the model base path") -parser.add_argument("tile_to_predict", type=str, help="path to model", nargs='+') - -args = parser.parse_args() -gdal.initialize(args) - - -def predict(model, imagery, device='cpu'): + +def predict(model, imagery, patch_size, margin_size, device='cpu'): 
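+    """Run tiled sliding-window inference over a full scene.
+
+    The imagery is traversed in overlapping patches of `patch_size` with a stride of
+    `patch_size - margin_size`. Each patch is passed through the model, squashed with a
+    sigmoid, weighted by a soft margin ramp and accumulated into a scene-wide prediction,
+    which is finally normalized by the summed weights (premultiplied alpha blending).
+    """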
prediction = torch.zeros(1, *imagery.shape[2:]) weights = torch.zeros(1, *imagery.shape[2:]) - PS = args.patch_size - MARGIN = args.margin_size - - margin_ramp = torch.cat([ - torch.linspace(0, 1, MARGIN), - torch.ones(PS - 2 * MARGIN), - torch.linspace(1, 0, MARGIN), - ]) - - soft_margin = margin_ramp.reshape(1, 1, PS) * \ - margin_ramp.reshape(1, PS, 1) - - for y in np.arange(0, imagery.shape[2], (PS - MARGIN)): - for x in np.arange(0, imagery.shape[3], (PS - MARGIN)): - if y + PS > imagery.shape[2]: - y = imagery.shape[2] - PS - if x + PS > imagery.shape[3]: - x = imagery.shape[3] - PS - patch_imagery = imagery[:, :, y:y + PS, x:x + PS] + margin_ramp = torch.cat( + [ + torch.linspace(0, 1, margin_size), + torch.ones(patch_size - 2 * margin_size), + torch.linspace(1, 0, margin_size), + ] + ) + + soft_margin = margin_ramp.reshape(1, 1, patch_size) * margin_ramp.reshape(1, patch_size, 1) + + for y in np.arange(0, imagery.shape[2], (patch_size - margin_size)): + for x in np.arange(0, imagery.shape[3], (patch_size - margin_size)): + if y + patch_size > imagery.shape[2]: + y = imagery.shape[2] - patch_size + if x + patch_size > imagery.shape[3]: + x = imagery.shape[3] - patch_size + patch_imagery = imagery[:, :, y : y + patch_size, x : x + patch_size] patch_pred = torch.sigmoid(model(patch_imagery.to(device))[0].cpu()) # Essentially premultiplied alpha blending - prediction[:, y:y + PS, x:x + PS] += patch_pred * soft_margin - weights[:, y:y + PS, x:x + PS] += soft_margin + prediction[:, y : y + patch_size, x : x + patch_size] += patch_pred * soft_margin + weights[:, y : y + patch_size, x : x + patch_size] += soft_margin # Avoid division by zero weights = torch.where(weights == 0, torch.ones_like(weights), weights) @@ -96,30 +78,35 @@ def flush_rio(filepath): a file after finishing a `with rio.open(...) as ...:` block Trying to open the file for reading seems to force a flush""" - with rio.open(filepath) as f: + with rio.open(filepath) as _: pass -def do_inference(tilename, sources, model, dev, logger, args=None, log_path=None): +def do_inference( + tilename, sources, model, dev, logger, name, data_dir, inference_dir, patch_size, margin_size, log_path=None +): tile_logger = get_logger(f'inference.{tilename}') # ===== PREPARE THE DATA ===== - DATA_ROOT = args.data_dir - INFERENCE_ROOT = args.inference_dir + DATA_ROOT = data_dir + INFERENCE_ROOT = inference_dir data_directory = DATA_ROOT / 'tiles' / tilename if not data_directory.exists(): logger.info(f'Preprocessing directory {tilename}') raw_directory = DATA_ROOT / 'input' / tilename if not raw_directory.exists(): - logger.error(f"Couldn't find tile '{tilename}' in {DATA_ROOT}/tiles or {DATA_ROOT}/input. Skipping this tile") + logger.error( + f"Couldn't find tile '{tilename}' in {DATA_ROOT}/tiles or {DATA_ROOT}/input. Skipping this tile" + ) return - preprocess_directory(raw_directory, args, log_path, label_required=False) + # TODO: The arguments don't match the function signature -> Invest how to resolve + preprocess_directory(raw_directory, log_path, label_required=False) # After this, data_directory should contain all the stuff that we need. 
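+    # Note: preprocess_directory (see setup_raw_data.py) now expects
+    # (image_dir, data_dir, aux_dir, backup_dir, log_path, gdal_bin, gdal_path, label_required),
+    # so the call above will not work as-is. Until the TODO is resolved, prepare the tiles
+    # beforehand, e.g. with `thaw-slump-segmentation data setup-raw`, so this branch is skipped.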
- - if args.name: - output_directory = INFERENCE_ROOT / args.name / tilename + + if name: + output_directory = INFERENCE_ROOT / name / tilename else: output_directory = INFERENCE_ROOT / tilename - + output_directory.mkdir(exist_ok=True, parents=True) planet_imagery_path = next(data_directory.glob('*_SR.tif')) @@ -172,13 +159,12 @@ def plot_results(image, outfile): plt.savefig(outfile, bbox_inches='tight', pad_inches=0) plt.close() - full_data = np.concatenate(data, axis=0) nodata = np.all(full_data == 0, axis=0, keepdims=True) full_data = torch.from_numpy(full_data) full_data = full_data.unsqueeze(0) # Pretend this is a batch of size 1 - res = predict(model, full_data, dev).numpy() + res = predict(model, full_data, patch_size, margin_size, dev).numpy() del full_data res[nodata] = np.nan @@ -200,17 +186,14 @@ def plot_results(image, outfile): count=1, compress='lzw', driver='COG', - #tiled=True + # tiled=True ) with rio.open(out_path_proba, 'w', **profile) as output_raster: output_raster.write(res.astype(np.float32)) flush_rio(out_path_proba) - profile.update( - dtype=rio.uint8, - nodata=255 - ) + profile.update(dtype=rio.uint8, nodata=255) with rio.open(out_path_label, 'w', **profile) as output_raster: output_raster.write(binarized) flush_rio(out_path_label) @@ -220,9 +203,14 @@ def plot_results(image, outfile): flush_rio(out_path_pre_poly) # create vectors - log_run(f'{gdal.polygonize} {out_path_pre_poly} -q -mask {out_path_pre_poly} -f "ESRI Shapefile" {out_path_shp}', tile_logger) - log_run(f'{gdal.polygonize} {out_path_pre_poly} -q -mask {out_path_pre_poly} -f "GPKG" {out_path_gpkg}', tile_logger) - #log_run(f'python {gdal.polygonize} {out_path_pre_poly} -q -mask {out_path_pre_poly} -f "ESRI Shapefile" {out_path_shp}', tile_logger) + log_run( + f'{gdal.polygonize} {out_path_pre_poly} -q -mask {out_path_pre_poly} -f "ESRI Shapefile" {out_path_shp}', + tile_logger, + ) + log_run( + f'{gdal.polygonize} {out_path_pre_poly} -q -mask {out_path_pre_poly} -f "GPKG" {out_path_gpkg}', tile_logger + ) + # log_run(f'python {gdal.polygonize} {out_path_pre_poly} -q -mask {out_path_pre_poly} -f "ESRI Shapefile" {out_path_shp}', tile_logger) out_path_pre_poly.unlink() h, w = res.shape[1:] @@ -239,7 +227,7 @@ def plot_results(image, outfile): kwargs = dict(colorbar=True, cmap=cmap_dem, vmin=0, vmax=1) elif src.name == 'slope': kwargs = dict(colorbar=True, cmap=cmap_slope, vmin=0, vmax=0.5) - make_img(f'{src.name}.jpg', src, mask=nodata[0],**kwargs) + make_img(f'{src.name}.jpg', src, mask=nodata[0], **kwargs) outpath = output_directory / 'pred_probability.jpg' plot_results(np.ma.masked_where(nodata[0], res[0]), outpath) @@ -247,17 +235,127 @@ def plot_results(image, outfile): outpath = output_directory / 'pred_binarized.jpg' plot_results(np.ma.masked_where(nodata[0], binarized[0]), outpath) + +def inference( + name: Annotated[ + str, typer.Option('--name', '-n', help='Name of inference run, data will be stored in subdirectory') + ], + model_path: Annotated[str, typer.Argument(help='path to model, use the model base path')], + tile_to_predict: Annotated[List[str], typer.Argument(help='path to model')], + gdal_bin: Annotated[str, typer.Option(help='Path to gdal binaries')] = '', + gdal_path: Annotated[str, typer.Option(help='Path to gdal scripts')] = '', + n_jobs: Annotated[int, typer.Option(help='number of parallel joblib jobs')] = -1, + ckpt: Annotated[str, typer.Option(help='Checkpoint to use')] = 'latest', + data_dir: Annotated[Path, typer.Option(help='Path to data processing dir')] = Path('data'), + 
log_dir: Annotated[Path, typer.Option(help='Path to log dir')] = Path('logs'), + inference_dir: Annotated[Path, typer.Option(help='Main inference directory')] = Path('inference'), + margin_size: Annotated[int, typer.Option('--margin_size', '-n', help='Size of patch overlap')] = 256, + patch_size: Annotated[int, typer.Option('--patch_size', '-p', help='Size of patches')] = 1024, +): + """Inference Script""" + + # TODO: let gdal.initialize take each argument separately + # Mock old args object + ARGS = namedtuple('gdalargs', ['gdal_bin', 'gdal_path']) + gdalargs = ARGS(gdal_bin, gdal_path) + gdal.initialize(gdalargs) + + timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S') + log_path = Path(log_dir) / f'inference-{timestamp}.log' + if not Path(log_dir).exists(): + os.mkdir(Path(log_dir)) + init_logging(log_path) + logger = get_logger('inference') + + # ===== LOAD THE MODEL ===== + cuda = True if torch.cuda.is_available() else False + dev = torch.device('cpu') if not cuda else torch.device('cuda') + logger.info(f'Running on {dev} device') + + if not model_path: + last_modified = 0 + last_modeldir = None + + for config_file in Path(log_dir).glob('*/config.yml'): + modified = config_file.stat().st_mtime + if modified > last_modified: + last_modified = modified + last_modeldir = config_file.parent + model_path = last_modeldir + + model_dir = Path(model_path) + config = yaml.load((model_dir / 'config.yml').open(), Loader=yaml.SafeLoader) + + m = config['model'] + # print(m['architecture'],m['encoder'], m['input_channels']) + model = create_model( + arch=m['architecture'], + encoder_name=m['encoder'], + encoder_weights=None if m['encoder_weights'] == 'random' else m['encoder_weights'], + classes=1, + in_channels=m['input_channels'], + ) + + if ckpt == 'latest': + ckpt_nums = [int(ckpt.stem) for ckpt in model_dir.glob('checkpoints/*.pt')] + last_ckpt = max(ckpt_nums) + else: + last_ckpt = int(ckpt) + ckpt = model_dir / 'checkpoints' / f'{last_ckpt:02d}.pt' + logger.info(f'Loading checkpoint {ckpt}') + + # Parallelized Model needs to be declared before loading + try: + model.load_state_dict(torch.load(ckpt, map_location=dev)) + except Exception: + model = nn.DataParallel(model) + model.load_state_dict(torch.load(ckpt, map_location=dev)) + + model = model.to(dev) + + sources = DataSources(config['data_sources']) + + torch.set_grad_enabled(False) + + for tilename in tqdm(tile_to_predict): + do_inference( + tilename, sources, model, dev, logger, name, data_dir, inference_dir, patch_size, margin_size, log_path + ) + + +# ! 
Moving legacy argparse cli to main to maintain compatibility with the original script def main(): + parser = argparse.ArgumentParser( + description='Inference Script', formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + parser.add_argument('--gdal_bin', default='', help='Path to gdal binaries') + parser.add_argument('--gdal_path', default='', help='Path to gdal scripts') + parser.add_argument('--n_jobs', default=-1, type=int, help='number of parallel joblib jobs') + parser.add_argument('--ckpt', default='latest', type=str, help='Checkpoint to use') + parser.add_argument('--data_dir', default='data', type=Path, help='Path to data processing dir') + parser.add_argument('--log_dir', default='logs', type=Path, help='Path to log dir') + parser.add_argument('--inference_dir', default='inference', type=Path, help='Main inference directory') + parser.add_argument( + '-n', '--name', default=None, type=str, help='Name of inference run, data will be stored in subdirectory' + ) + parser.add_argument('-m', '--margin_size', default=256, type=int, help='Size of patch overlap') + parser.add_argument('-p', '--patch_size', default=1024, type=int, help='Size of patches') + parser.add_argument('model_path', type=str, help='path to model, use the model base path') + parser.add_argument('tile_to_predict', type=str, help='path to model', nargs='+') + + args = parser.parse_args() + gdal.initialize(args) + timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S') log_path = Path(args.log_dir) / f'inference-{timestamp}.log' if not Path(args.log_dir).exists(): - os.mkdir(Path(args.log_dir)) + os.mkdir(Path(args.log_dir)) init_logging(log_path) logger = get_logger('inference') # ===== LOAD THE MODEL ===== cuda = True if torch.cuda.is_available() else False - dev = torch.device("cpu") if not cuda else torch.device("cuda") + dev = torch.device('cpu') if not cuda else torch.device('cuda') logger.info(f'Running on {dev} device') if not args.model_path: @@ -272,17 +370,16 @@ def main(): args.model_path = last_modeldir model_dir = Path(args.model_path) - model_name = model_dir.name config = yaml.load((model_dir / 'config.yml').open(), Loader=yaml.SafeLoader) m = config['model'] - #print(m['architecture'],m['encoder'], m['input_channels']) + # print(m['architecture'],m['encoder'], m['input_channels']) model = create_model( arch=m['architecture'], encoder_name=m['encoder'], encoder_weights=None if m['encoder_weights'] == 'random' else m['encoder_weights'], classes=1, - in_channels=m['input_channels'] + in_channels=m['input_channels'], ) if args.ckpt == 'latest': @@ -291,15 +388,15 @@ def main(): else: last_ckpt = int(args.ckpt) ckpt = model_dir / 'checkpoints' / f'{last_ckpt:02d}.pt' - logger.info(f"Loading checkpoint {ckpt}") - + logger.info(f'Loading checkpoint {ckpt}') + # Parallelized Model needs to be declared before loading try: model.load_state_dict(torch.load(ckpt, map_location=dev)) - except: + except Exception: model = nn.DataParallel(model) model.load_state_dict(torch.load(ckpt, map_location=dev)) - + model = model.to(dev) sources = DataSources(config['data_sources']) @@ -309,5 +406,6 @@ def main(): for tilename in tqdm(args.tile_to_predict): do_inference(tilename, sources, model, dev, logger, args, log_path) -if __name__ == "__main__": - main() + +if __name__ == '__main__': + main() diff --git a/src/thaw_slump_segmentation/scripts/setup_raw_data.py b/src/thaw_slump_segmentation/scripts/setup_raw_data.py index 37bb841..d678c25 100644 --- a/src/thaw_slump_segmentation/scripts/setup_raw_data.py +++ 
b/src/thaw_slump_segmentation/scripts/setup_raw_data.py @@ -7,25 +7,35 @@ """ Usecase 2 Data Preprocessing Script """ + import argparse -import yaml +import os +from collections import namedtuple from datetime import datetime from pathlib import Path -import os +import ee +import typer from joblib import Parallel, delayed +from typing_extensions import Annotated from .. import data_pre_processing -from ..data_pre_processing import * -from ..utils import init_logging, get_logger, yaml_custom - -parser = argparse.ArgumentParser() -parser.add_argument("--gdal_bin", default=None, help="Path to gdal binaries (ignored if --skip_gdal is passed)") -parser.add_argument("--gdal_path", default=None, help="Path to gdal scripts (ignored if --skip_gdal is passed)") -parser.add_argument("--n_jobs", default=-1, type=int, help="number of parallel joblib jobs") -parser.add_argument("--nolabel", action='store_false', help="Set flag to do preprocessing without label file") -parser.add_argument("--data_dir", default='data', type=Path, help="Path to data processing dir") -parser.add_argument("--log_dir", default='logs', type=Path, help="Path to log dir") +from ..data_pre_processing import ( + aux_data_to_tiles, + check_input_data, + gdal, + get_tcvis_from_gee, + has_projection, + make_ndvi_file, + mask_input_data, + move_files, + pre_cleanup, + rename_clip_to_standard, + vector_to_raster_mask, +) + +# from ..data_pre_processing import * +from ..utils import get_logger, init_logging is_ee_initialized = False # Module-global flag to avoid calling ee.Initialize multiple times @@ -34,8 +44,13 @@ SUCCESS_STATES = ['rename', 'label', 'ndvi', 'tcvis', 'rel_dem', 'slope', 'mask', 'move'] -def preprocess_directory(image_dir, data_dir, aux_dir, backup_dir, args, log_path, label_required=True): - gdal.initialize(args) +def preprocess_directory(image_dir, data_dir, aux_dir, backup_dir, log_path, gdal_bin, gdal_path, label_required=True): + # TODO: let gdal.initialize take each argument separately + # Mock old args object + ARGS = namedtuple('gdalargs', ['gdal_bin', 'gdal_path']) + gdalargs = ARGS(gdal_bin, gdal_path) + gdal.initialize(gdalargs) + init_logging(log_path) image_name = os.path.basename(image_dir) thread_logger = get_logger(f'setup_raw_data.{image_name}') @@ -47,7 +62,7 @@ def preprocess_directory(image_dir, data_dir, aux_dir, backup_dir, args, log_pat try: thread_logger.debug('Initializing Earth Engine') ee.Initialize() - except: + except Exception: thread_logger.warn('Initializing Earth Engine failed, trying to authenticate') ee.Authenticate() ee.Initialize() @@ -70,23 +85,18 @@ def preprocess_directory(image_dir, data_dir, aux_dir, backup_dir, args, log_pat success_state['ndvi'] = make_ndvi_file(image_dir) - ee_image_tcvis = ee.ImageCollection("users/ingmarnitze/TCTrend_SR_2000-2019_TCVIS").mosaic() - success_state['tcvis'] = get_tcvis_from_gee(image_dir, - ee_image_tcvis, - out_filename='tcvis.tif', - resolution=3) + ee_image_tcvis = ee.ImageCollection('users/ingmarnitze/TCTrend_SR_2000-2019_TCVIS').mosaic() + success_state['tcvis'] = get_tcvis_from_gee(image_dir, ee_image_tcvis, out_filename='tcvis.tif', resolution=3) - success_state['rel_dem'] = aux_data_to_tiles(image_dir, - aux_dir / 'ArcticDEM' / 'elevation.vrt', - 'relative_elevation.tif') + success_state['rel_dem'] = aux_data_to_tiles( + image_dir, aux_dir / 'ArcticDEM' / 'elevation.vrt', 'relative_elevation.tif' + ) - success_state['slope'] = aux_data_to_tiles(image_dir, - aux_dir / 'ArcticDEM' / 'slope.vrt', - 'slope.tif') + success_state['slope'] 
= aux_data_to_tiles(image_dir, aux_dir / 'ArcticDEM' / 'slope.vrt', 'slope.tif')
 
     success_state['mask'] = mask_input_data(image_dir, data_dir)
 
-    #backup_dir_full = os.path.join(backup_dir, os.path.basename(image_dir))
+    # backup_dir_full = os.path.join(backup_dir, os.path.basename(image_dir))
     backup_dir_full = backup_dir / image_dir.name
     success_state['move'] = move_files(image_dir, backup_dir_full)
 
@@ -95,21 +105,65 @@
     return success_state
 
 
+def setup_raw_data(
+    gdal_bin: Annotated[str, typer.Option(help='Path to gdal binaries')] = '',
+    gdal_path: Annotated[str, typer.Option(help='Path to gdal scripts')] = '',
+    n_jobs: Annotated[int, typer.Option(help='number of parallel joblib jobs')] = -1,
+    label: Annotated[bool, typer.Option(help='Set flag to do preprocessing with label file')] = False,
+    data_dir: Annotated[Path, typer.Option(help='Path to data processing dir')] = Path('data'),
+    log_dir: Annotated[Path, typer.Option(help='Path to log dir')] = Path('logs'),
+):
+    INPUT_DATA_DIR = data_dir / 'input'
+    BACKUP_DIR = data_dir / 'backup'
+    DATA_DIR = data_dir / 'tiles'
+    AUX_DIR = data_dir / 'auxiliary'
+
+    timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
+    log_path = Path(log_dir) / f'setup_raw_data-{timestamp}.log'
+    if not Path(log_dir).exists():
+        os.mkdir(Path(log_dir))
+    init_logging(log_path)
+    logger = get_logger('setup_raw_data')
+    logger.info('###########################')
+    logger.info('# Starting Raw Data Setup #')
+    logger.info('###########################')
+
+    dir_list = check_input_data(INPUT_DATA_DIR)
+    if len(dir_list) > 0:
+        Parallel(n_jobs=n_jobs)(
+            delayed(preprocess_directory)(
+                image_dir, DATA_DIR, AUX_DIR, BACKUP_DIR, log_path, gdal_bin, gdal_path, not label
+            )
+            for image_dir in dir_list
+        )
+    else:
+        logger.error('Empty Input Data Directory! No Data available to process!')
+
+
+# !
Moving legacy argparse cli to main to maintain compatibility with the original script def main(): + parser = argparse.ArgumentParser() + parser.add_argument('--gdal_bin', default=None, help='Path to gdal binaries (ignored if --skip_gdal is passed)') + parser.add_argument('--gdal_path', default=None, help='Path to gdal scripts (ignored if --skip_gdal is passed)') + parser.add_argument('--n_jobs', default=-1, type=int, help='number of parallel joblib jobs') + parser.add_argument('--nolabel', action='store_false', help='Set flag to do preprocessing without label file') + parser.add_argument('--data_dir', default='data', type=Path, help='Path to data processing dir') + parser.add_argument('--log_dir', default='logs', type=Path, help='Path to log dir') + args = parser.parse_args() - + global DATA_ROOT, INPUT_DATA_DIR, BACKUP_DIR, DATA_DIR, AUX_DIR DATA_ROOT = Path(args.data_dir) INPUT_DATA_DIR = DATA_ROOT / 'input' - BACKUP_DIR = DATA_ROOT / 'backup' - DATA_DIR = DATA_ROOT / 'tiles' - AUX_DIR = DATA_ROOT / 'auxiliary' + BACKUP_DIR = DATA_ROOT / 'backup' + DATA_DIR = DATA_ROOT / 'tiles' + AUX_DIR = DATA_ROOT / 'auxiliary' timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S') log_path = Path(args.log_dir) / f'setup_raw_data-{timestamp}.log' if not Path(args.log_dir).exists(): - os.mkdir(Path(args.log_dir)) + os.mkdir(Path(args.log_dir)) init_logging(log_path) logger = get_logger('setup_raw_data') logger.info('###########################') @@ -118,9 +172,15 @@ def main(): dir_list = check_input_data(INPUT_DATA_DIR) if len(dir_list) > 0: - Parallel(n_jobs=args.n_jobs)(delayed(preprocess_directory)(image_dir, DATA_DIR, AUX_DIR, BACKUP_DIR, args, log_path, args.nolabel) for image_dir in dir_list) + Parallel(n_jobs=args.n_jobs)( + delayed(preprocess_directory)( + image_dir, DATA_DIR, AUX_DIR, BACKUP_DIR, args.gdal_bin, args.gdal_path, log_path, args.nolabel + ) + for image_dir in dir_list + ) else: - logger.error("Empty Input Data Directory! No Data available to process!") + logger.error('Empty Input Data Directory! No Data available to process!') + -if __name__ == "__main__": - main() +if __name__ == '__main__': + main() diff --git a/src/thaw_slump_segmentation/scripts/train.py b/src/thaw_slump_segmentation/scripts/train.py index 728f35a..bd4530d 100644 --- a/src/thaw_slump_segmentation/scripts/train.py +++ b/src/thaw_slump_segmentation/scripts/train.py @@ -18,11 +18,12 @@ import torch import torch.nn as nn +import typer +import wandb import yaml from rich import pretty, traceback from tqdm import tqdm - -import wandb +from typing_extensions import Annotated from ..data_loading import DataSources, get_loader, get_slump_loader, get_vis_loader from ..metrics import F1, Accuracy, IoU, Metrics, Precision, Recall @@ -32,41 +33,28 @@ traceback.install(show_locals=True) pretty.install() -parser = argparse.ArgumentParser(description='Training script', formatter_class=argparse.ArgumentDefaultsHelpFormatter) -parser.add_argument('-s', '--summary', action='store_true', help='Only print model summary and return.') -parser.add_argument('--data_dir', default='data', type=Path, help='Path to data processing dir') -parser.add_argument('--log_dir', default='logs', type=Path, help='Path to log dir') -parser.add_argument( - '-n', '--name', default='', help='Give this run a name, so that it will be logged into logs/_.' 
-) -parser.add_argument('-c', '--config', default='config.yml', type=Path, help='Specify run config to use.') -parser.add_argument( - '-r', - '--resume', - default='', - help='Resume from the specified checkpoint.' - 'Can be either a run-id (e.g. "2020-06-29_18-12-03") to select the last' - 'checkpoint of that run, or a direct path to a checkpoint to be loaded.' - 'Overrides the resume option in the config file if given.', -) -parser.add_argument( - '-wp', '--wandb_project', default='thaw-slump-segmentation', help='Set a project name for weights and biases' -) -parser.add_argument('-wn', '--wandb_name', default=None, help='Set a run name for weights and biases') - class Engine: - def __init__(self): - args = parser.parse_args() - self.config = yaml.load(args.config.open(), Loader=yaml_custom.SaneYAMLLoader) - self.DATA_ROOT = args.data_dir + def __init__( + self, + config: Path, + data_dir: Path, + name: str, + log_dir: Path, + resume: str, + summary: bool, + wandb_project: str, + wandb_name: str, + ): + self.config = yaml.load(config.open(), Loader=yaml_custom.SaneYAMLLoader) + self.DATA_ROOT = data_dir # Logging setup timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S') - if args.name: - log_dir_name = f'{args.name}_{timestamp}' + if name: + log_dir_name = f'{name}_{timestamp}' else: log_dir_name = timestamp - self.log_dir = Path(args.log_dir) / log_dir_name + self.log_dir = Path(log_dir) / log_dir_name self.log_dir.mkdir(exist_ok=False) init_logging(self.log_dir / 'train.log') @@ -87,8 +75,8 @@ def __init__(self): # make parallel self.model = nn.DataParallel(self.model) - if args.resume: - self.config['resume'] = args.resume + if resume: + self.config['resume'] = resume if 'resume' in self.config and self.config['resume']: checkpoint = Path(self.config['resume']) @@ -115,7 +103,7 @@ def __init__(self): self.epoch = 0 self.metrics = Metrics(Accuracy, Precision, Recall, F1, IoU) - if args.summary: + if summary: from torchsummary import summary summary(self.model, [(self.config['model']['input_channels'], 256, 256)]) @@ -144,11 +132,11 @@ def __init__(self): # Metrics and Weights and Biases initialization self.trn_metrics = {} self.val_metrics = {} - print('wandb project:', args.wandb_project) - print('wandb name:', args.wandb_name) + print('wandb project:', wandb_project) + print('wandb name:', wandb_name) print('config:', self.config) print('entity:', 'ml4earth') - wandb.init(project=args.wandb_project, name=args.wandb_name, config=self.config, entity='ingmarnitze_team') + wandb.init(project=wandb_project, name=wandb_name, config=self.config, entity='ingmarnitze_team') def run(self): for phase in self.config['schedule']: @@ -359,8 +347,77 @@ def safe_append(dictionary, key, value): dictionary[key] = [value] +def train( + name: Annotated[ + str, + typer.Option( + '--name', + '-n', + prompt=True, + help='Give this run a name, so that it will be logged into logs/_.', + ), + ], + data_dir: Annotated[Path, typer.Option(help='Path to data processing dir')] = Path('data'), + log_dir: Annotated[Path, typer.Option(help='Path to log dir')] = Path('logs'), + config: Annotated[Path, typer.Option('--config', '-c', help='Specify run config to use.')] = Path('config.yml'), + resume: Annotated[ + str, + typer.Option( + '--resume', + '-r', + help='Resume from the specified checkpoint. Can be either a run-id (e.g. "2020-06-29_18-12-03") to select the last. 
Overrides the resume option in the config file if given.', + ), + ] = None, + summary: Annotated[bool, typer.Option('--summary', '-s', help='Only print model summary and return.')] = False, + wandb_project: Annotated[ + str, typer.Option('--wandb_project', '-wp', help='Set a project name for weights and biases') + ] = 'thaw-slump-segmentation', + wandb_name: Annotated[ + str, typer.Option('--wandb_name', '-wn', help='Set a run name for weights and biases') + ] = None, +): + """Training script""" + engine = Engine(config, data_dir, name, log_dir, resume, summary, wandb_project, wandb_name) + engine.run() + + +# ! Moving legacy argparse cli to main to maintain compatibility with the original script def main(): - Engine().run() + parser = argparse.ArgumentParser( + description='Training script', formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + parser.add_argument('-s', '--summary', action='store_true', help='Only print model summary and return.') + parser.add_argument('--data_dir', default='data', type=Path, help='Path to data processing dir') + parser.add_argument('--log_dir', default='logs', type=Path, help='Path to log dir') + parser.add_argument( + '-n', '--name', default='', help='Give this run a name, so that it will be logged into logs/_.' + ) + parser.add_argument('-c', '--config', default='config.yml', type=Path, help='Specify run config to use.') + parser.add_argument( + '-r', + '--resume', + default='', + help='Resume from the specified checkpoint.' + 'Can be either a run-id (e.g. "2020-06-29_18-12-03") to select the last' + 'Can be either a run-id (e.g. "2020-06-29_18-12-03") to select the last' + 'Overrides the resume option in the config file if given.', + ) + parser.add_argument( + '-wp', '--wandb_project', default='thaw-slump-segmentation', help='Set a project name for weights and biases' + ) + parser.add_argument('-wn', '--wandb_name', default=None, help='Set a run name for weights and biases') + + args = parser.parse_args() + Engine( + args.config, + args.data_dir, + args.name, + args.log_dir, + args.resume, + args.summary, + args.wandb_project, + args.wandb_name, + ).run() if __name__ == '__main__': From 962b13d0f1d5665c62cc11256d1e2f296a09b28b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tobias=20H=C3=B6lzer?= Date: Thu, 9 May 2024 16:23:22 +0200 Subject: [PATCH 2/9] Convert CLI of prepare_data and prepare_planet Also fix undeclared variables in prepare data --- .../data_pre_processing/gdal.py | 37 ++- src/thaw_slump_segmentation/main.py | 14 +- .../scripts/inference.py | 5 +- .../scripts/prepare_data.py | 278 ++++++++++++------ .../scripts/prepare_s2_4band_planet_format.py | 106 ++++--- .../scripts/setup_raw_data.py | 5 +- 6 files changed, 281 insertions(+), 164 deletions(-) diff --git a/src/thaw_slump_segmentation/data_pre_processing/gdal.py b/src/thaw_slump_segmentation/data_pre_processing/gdal.py index 13a9f63..78f5154 100644 --- a/src/thaw_slump_segmentation/data_pre_processing/gdal.py +++ b/src/thaw_slump_segmentation/data_pre_processing/gdal.py @@ -5,23 +5,28 @@ import os import sys -import yaml from pathlib import Path + +import yaml + _module = sys.modules[__name__] -def initialize(args=None): +def initialize(args=None, *, bin=None, path=None): # If command line arguments are given, use those: system_yml = Path('system.yml') - if args.gdal_bin is not None: - #print('Manually set path') + if args is not None: + # print('Manually set path') _module.gdal_path = args.gdal_path _module.gdal_bin = args.gdal_bin - + elif bin is not None and path is not None: + # 
print('Manually set path') + _module.gdal_path = path + _module.gdal_bin = bin # Otherwise, fall back to the ones from system.yml elif system_yml.exists(): - #print('yml file') + # print('yml file') system_config = yaml.load(system_yml.open(), Loader=yaml.SafeLoader) if 'gdal_path' in system_config: _module.gdal_path = system_config['gdal_path'] @@ -29,16 +34,16 @@ def initialize(args=None): _module.gdal_bin = system_config['gdal_bin'] else: - #print('Empty path') + # print('Empty path') _module.gdal_path = '' _module.gdal_bin = '' - - #print(_module.gdal_bin) - #print(_module.gdal_path) - _module.rasterize = os.path.join(_module.gdal_bin, 'gdal_rasterize') - _module.translate = os.path.join(_module.gdal_bin, 'gdal_translate') - _module.warp = os.path.join(_module.gdal_bin, 'gdalwarp') - - _module.merge = os.path.join(_module.gdal_path, 'gdal_merge.py') - _module.retile = os.path.join(_module.gdal_path, 'gdal_retile.py') + + # print(_module.gdal_bin) + # print(_module.gdal_path) + _module.rasterize = os.path.join(_module.gdal_bin, 'gdal_rasterize') + _module.translate = os.path.join(_module.gdal_bin, 'gdal_translate') + _module.warp = os.path.join(_module.gdal_bin, 'gdalwarp') + + _module.merge = os.path.join(_module.gdal_path, 'gdal_merge.py') + _module.retile = os.path.join(_module.gdal_path, 'gdal_retile.py') _module.polygonize = os.path.join(_module.gdal_path, 'gdal_polygonize.py') diff --git a/src/thaw_slump_segmentation/main.py b/src/thaw_slump_segmentation/main.py index 61dd655..77ae773 100644 --- a/src/thaw_slump_segmentation/main.py +++ b/src/thaw_slump_segmentation/main.py @@ -2,6 +2,8 @@ from thaw_slump_segmentation.scripts.download_s2_4band_planet_format import download_s2_4band_planet_format from thaw_slump_segmentation.scripts.inference import inference +from thaw_slump_segmentation.scripts.prepare_data import prepare_data +from thaw_slump_segmentation.scripts.prepare_s2_4band_planet_format import prepare_s2_4band_planet_format from thaw_slump_segmentation.scripts.setup_raw_data import setup_raw_data from thaw_slump_segmentation.scripts.train import train @@ -11,19 +13,11 @@ cli.command()(inference) -@cli.command() -def hello(name: str): - typer.echo(f'Hello {name}') - - -@cli.command() -def goodbye(name: str): - typer.echo(f'Goodbye {name}') - - data_cli = typer.Typer() data_cli.command('download-planet')(download_s2_4band_planet_format) +data_cli.command('prepare-planet')(prepare_s2_4band_planet_format) data_cli.command('setup-raw')(setup_raw_data) +data_cli.command('prepare')(prepare_data) cli.add_typer(data_cli, name='data') diff --git a/src/thaw_slump_segmentation/scripts/inference.py b/src/thaw_slump_segmentation/scripts/inference.py index e2fff7a..69f55cc 100644 --- a/src/thaw_slump_segmentation/scripts/inference.py +++ b/src/thaw_slump_segmentation/scripts/inference.py @@ -11,7 +11,6 @@ import argparse import os -from collections import namedtuple from datetime import datetime from pathlib import Path from typing import List @@ -256,9 +255,7 @@ def inference( # TODO: let gdal.initialize take each argument separately # Mock old args object - ARGS = namedtuple('gdalargs', ['gdal_bin', 'gdal_path']) - gdalargs = ARGS(gdal_bin, gdal_path) - gdal.initialize(gdalargs) + gdal.initialize(bin=gdal_bin, path=gdal_path) timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S') log_path = Path(log_dir) / f'inference-{timestamp}.log' diff --git a/src/thaw_slump_segmentation/scripts/prepare_data.py b/src/thaw_slump_segmentation/scripts/prepare_data.py index 076337c..204dd38 100644 
--- a/src/thaw_slump_segmentation/scripts/prepare_data.py +++ b/src/thaw_slump_segmentation/scripts/prepare_data.py @@ -8,37 +8,28 @@ """ Usecase 2 Data Preprocessing Script """ + import argparse +import os import shutil import sys from datetime import datetime from pathlib import Path -import yaml -import os import h5py import numpy as np import rasterio as rio +import typer from joblib import Parallel, delayed from skimage.io import imsave +from typing_extensions import Annotated from ..data_pre_processing import gdal -from ..utils import init_logging, get_logger, log_run, yaml_custom - -parser = argparse.ArgumentParser(description='Make data ready for training', - formatter_class=argparse.ArgumentDefaultsHelpFormatter) -parser.add_argument("--data_dir", default='data', type=Path, help="Path to data processing dir") -parser.add_argument("--log_dir", default='logs', type=Path, help="Path to log dir") -parser.add_argument("--skip_gdal", action='store_true', help="Skip the Gdal conversion stage (if it has already been " - "done)") -parser.add_argument("--gdal_bin", default=None, help="Path to gdal binaries (ignored if --skip_gdal is passed)") -parser.add_argument("--gdal_path", default=None, help="Path to gdal scripts (ignored if --skip_gdal is passed)") -parser.add_argument("--n_jobs", default=-1, type=int, help="number of parallel joblib jobs") -parser.add_argument("--nodata_threshold", default=50, type=float, help="Throw away data with at least this % of " - "nodata pixels") -parser.add_argument("--tile_size", default='256x256', type=str, help="Tiling size in pixels e.g. '256x256'") -parser.add_argument("--tile_overlap", default=25, type=int, help="Overlap of the tiles in pixels") +from ..utils import get_logger, init_logging, log_run +# Paths setup +RASTERFILTER = '*_SR*.tif' +VECTORFILTER = '*.shp' def read_and_assert_imagedata(image_path): @@ -48,7 +39,7 @@ def read_and_assert_imagedata(image_path): else: data = raster.read()[:3] # Assert data can safely be coerced to int16 - assert data.max() < 2 ** 15 + assert data.max() < 2**15 return data @@ -61,7 +52,7 @@ def get_planet_product_type(img_path): pl_type = 'OrthoTile' else: pl_type = 'Scene' - + return pl_type @@ -69,18 +60,18 @@ def mask_from_img(img_path): """ Given an image path, return path for the mask """ - # change for + # change for if get_planet_product_type(img_path) == 'Scene': date, time, *block, platform, _, sr, row, col = img_path.stem.split('_') block = '_'.join(block) base = img_path.parent.parent mask_path = base / 'mask' / f'{date}_{time}_{block}_mask_{row}_{col}.tif' - + else: block, tile, date, sensor, bgrn, sr, row, col = img_path.stem.split('_') base = img_path.parent.parent mask_path = base / 'mask' / f'{block}_{tile}_{date}_{sensor}_mask_{row}_{col}.tif' - + assert mask_path.exists() return mask_path @@ -95,7 +86,7 @@ def other_from_img(img_path, other): block = '_'.join(block) else: block, tile, date, sensor, bgrn, sr, row, col = img_path.stem.split('_') - + base = img_path.parent.parent path = base / other / f'{other}_{row}_{col}.tif' @@ -104,17 +95,23 @@ def other_from_img(img_path, other): return path -def glob_file(DATASET, filter_string): +def glob_file(DATASET, filter_string, logger=None): candidates = list(DATASET.glob(f'{filter_string}')) if len(candidates) == 1: logger.debug(f'Found file: {candidates[0]}') return candidates[0] else: - raise ValueError(f'Found {len(candidates)} candidates.' - 'Please make selection more specific!') + raise ValueError(f'Found {len(candidates)} candidates.' 
'Please make selection more specific!')
 
 
-def do_gdal_calls(DATASET, aux_data=['ndvi', 'tcvis', 'slope', 'relative_elevation'], logger=None):
+def do_gdal_calls(
+    DATASET,
+    xsize: int,
+    ysize: int,
+    overlap: int,
+    aux_data=['ndvi', 'tcvis', 'slope', 'relative_elevation'],
+    logger=None,
+):
     maskfile = DATASET / f'{DATASET.name}_mask.tif'
 
     tile_dir_data = DATASET / 'tiles' / 'data'
@@ -124,88 +121,92 @@ def do_gdal_calls(DATASET, aux_data=['ndvi', 'tcvis', 'slope', 'relative_elevati
     tile_dir_data.mkdir(exist_ok=True, parents=True)
     tile_dir_mask.mkdir(exist_ok=True)
 
-    rasterfile = glob_file(DATASET, RASTERFILTER)
+    rasterfile = glob_file(DATASET, RASTERFILTER, logger)
     # Retile data, mask
-    #log_run(f'python {gdal.retile} -ps {XSIZE} {YSIZE} -overlap {OVERLAP} -targetDir {tile_dir_data} {rasterfile}', logger)
-    #log_run(f'python {gdal.retile} -ps {XSIZE} {YSIZE} -overlap {OVERLAP} -targetDir {tile_dir_mask} {maskfile}', logger)
-    log_run(f'{gdal.retile} -ps {XSIZE} {YSIZE} -overlap {OVERLAP} -targetDir {tile_dir_data} {rasterfile}', logger)
-    log_run(f'{gdal.retile} -ps {XSIZE} {YSIZE} -overlap {OVERLAP} -targetDir {tile_dir_mask} {maskfile}', logger)
+    # log_run(f'python {gdal.retile} -ps {XSIZE} {YSIZE} -overlap {OVERLAP} -targetDir {tile_dir_data} {rasterfile}', logger)
+    # log_run(f'python {gdal.retile} -ps {XSIZE} {YSIZE} -overlap {OVERLAP} -targetDir {tile_dir_mask} {maskfile}', logger)
+    log_run(f'{gdal.retile} -ps {xsize} {ysize} -overlap {overlap} -targetDir {tile_dir_data} {rasterfile}', logger)
+    log_run(f'{gdal.retile} -ps {xsize} {ysize} -overlap {overlap} -targetDir {tile_dir_mask} {maskfile}', logger)
 
     # Retile additional data
     for aux in aux_data:
         auxfile = DATASET / f'{aux}.tif'
         tile_dir_aux = DATASET / 'tiles' / aux
         tile_dir_aux.mkdir(exist_ok=True)
-        #log_run(f'python {gdal.retile} -ps {XSIZE} {YSIZE} -overlap {OVERLAP} -targetDir {tile_dir_aux} {auxfile}', logger)
-        log_run(f'{gdal.retile} -ps {XSIZE} {YSIZE} -overlap {OVERLAP} -targetDir {tile_dir_aux} {auxfile}', logger)
+        # log_run(f'python {gdal.retile} -ps {XSIZE} {YSIZE} -overlap {OVERLAP} -targetDir {tile_dir_aux} {auxfile}', logger)
+        log_run(f'{gdal.retile} -ps {xsize} {ysize} -overlap {overlap} -targetDir {tile_dir_aux} {auxfile}', logger)
 
 
 def make_info_picture(tile, filename):
     "Make overview picture"
-    rgbn = np.clip(tile['planet'][:4,:,:].transpose(1, 2, 0) / 3000 * 255, 0, 255).astype(np.uint8)
+    rgbn = np.clip(tile['planet'][:4, :, :].transpose(1, 2, 0) / 3000 * 255, 0, 255).astype(np.uint8)
     tcvis = np.clip(tile['tcvis'].transpose(1, 2, 0), 0, 255).astype(np.uint8)
 
-    rgb = rgbn[:,:,:3]
-    nir = rgbn[:,:,[3,2,1]]
-    mask = (tile['mask'][[0,0,0]].transpose(1, 2, 0) * 255).astype(np.uint8)
+    rgb = rgbn[:, :, :3]
+    nir = rgbn[:, :, [3, 2, 1]]
+    mask = (tile['mask'][[0, 0, 0]].transpose(1, 2, 0) * 255).astype(np.uint8)
 
-    img = np.concatenate([
-        np.concatenate([rgb, nir], axis=1),
-        np.concatenate([tcvis, mask], axis=1),
-    ])
+    img = np.concatenate(
+        [
+            np.concatenate([rgb, nir], axis=1),
+            np.concatenate([tcvis, mask], axis=1),
+        ]
+    )
 
     imsave(filename, img)
 
 
-def main_function(dataset, args, log_path):
-    if not args.skip_gdal:
-        gdal.initialize(args)
-
+def main_function(
+    dataset, log_path, h5dir: Path, xsize: int, ysize: int, overlap: int, threshold: float, skip_gdal: bool
+):
    init_logging(log_path)
     thread_logger = get_logger(f'prepare_data.{dataset.name}')
     thread_logger.info(f'Starting preparation on dataset {dataset}')
 
-    if not args.skip_gdal:
+    if not skip_gdal:
         thread_logger.info('Doing GDAL Calls')
-
do_gdal_calls(dataset, logger=thread_logger) + do_gdal_calls(dataset, xsize, ysize, overlap, logger=thread_logger) else: thread_logger.info('Skipping GDAL Calls') - tifs = list(sorted(dataset.glob('tiles/data/*.tif'))) if len(tifs) == 0: - logger.warning(f'No tiles found for {dataset}, skipping this directory.') + thread_logger.warning(f'No tiles found for {dataset}, skipping this directory.') return - h5_path = H5_DIR / f'{dataset.name}.h5' - info_dir = H5_DIR / dataset.name + h5_path = h5dir / f'{dataset.name}.h5' + info_dir = h5dir / dataset.name info_dir.mkdir(parents=True) thread_logger.info(f'Creating H5File at {h5_path}') - h5 = h5py.File(h5_path, 'w', - rdcc_nbytes=2 * (1 << 30), # 2 GiB - rdcc_nslots=200003, - ) + h5 = h5py.File( + h5_path, + 'w', + rdcc_nbytes=2 * (1 << 30), # 2 GiB + rdcc_nslots=200003, + ) channel_numbers = dict(planet=4, ndvi=1, tcvis=3, relative_elevation=1, slope=1) datasets = dict() for dataset_name, nchannels in channel_numbers.items(): - ds = h5.create_dataset(dataset_name, - dtype=np.float32, - shape=(len(tifs), nchannels, XSIZE, YSIZE), - maxshape=(len(tifs), nchannels, XSIZE, YSIZE), - chunks=(1, nchannels, XSIZE, YSIZE), - compression='lzf', - scaleoffset=3, - ) + ds = h5.create_dataset( + dataset_name, + dtype=np.float32, + shape=(len(tifs), nchannels, xsize, ysize), + maxshape=(len(tifs), nchannels, xsize, ysize), + chunks=(1, nchannels, xsize, ysize), + compression='lzf', + scaleoffset=3, + ) datasets[dataset_name] = ds - datasets['mask'] = h5.create_dataset("mask", - dtype=np.uint8, - shape=(len(tifs), 1, XSIZE, YSIZE), - maxshape=(len(tifs), 1, XSIZE, YSIZE), - chunks=(1, 1, XSIZE, YSIZE), - compression='lzf', - ) + datasets['mask'] = h5.create_dataset( + 'mask', + dtype=np.uint8, + shape=(len(tifs), 1, xsize, ysize), + maxshape=(len(tifs), 1, xsize, ysize), + chunks=(1, 1, xsize, ysize), + compression='lzf', + ) # Convert data to HDF5 storage for efficient data loading i = 0 @@ -215,7 +216,7 @@ def main_function(dataset, args, log_path): with rio.open(img) as raster: tile['planet'] = raster.read() - if (tile['planet'] == 0).all(axis=0).mean() > THRESHOLD: + if (tile['planet'] == 0).all(axis=0).mean() > threshold: bad_tiles += 1 continue @@ -230,14 +231,14 @@ def main_function(dataset, args, log_path): data = raster.read() if data.shape[0] > channel_numbers[other]: # This is for tcvis mostly - data = data[:channel_numbers[other]] + data = data[: channel_numbers[other]] tile[other] = data # gdal_retile leaves narrow stripes at the right and bottom border, # which are filtered out here: is_narrow = False for tensor in tile.values(): - if tensor.shape[-2:] != (XSIZE, YSIZE): + if tensor.shape[-2:] != (xsize, ysize): is_narrow = True break if is_narrow: @@ -254,62 +255,159 @@ def main_function(dataset, args, log_path): datasets[t].resize(i, axis=0) +def tile_size_callback(value: str): + try: + x, y = map(int, value.split('x')) + except ValueError: + raise typer.BadParameter('Tile size must be in the format "256x256"') + return x, y + + +def prepare_data( + data_dir: Annotated[Path, typer.Option(help='Path to data processing dir')] = Path('data'), + log_dir: Annotated[Path, typer.Option(help='Path to log dir')] = Path('logs'), + skip_gdal: Annotated[ + bool, typer.Option(help='Skip the Gdal conversion stage (if it has already been done)') + ] = False, + gdal_bin: Annotated[str, typer.Option(help='Path to gdal binaries (ignored if --skip_gdal is passed)')] = None, + gdal_path: Annotated[str, typer.Option(help='Path to gdal scripts (ignored if 
--skip_gdal is passed)')] = None,
+    n_jobs: Annotated[int, typer.Option(help='number of parallel joblib jobs')] = -1,
+    nodata_threshold: Annotated[float, typer.Option(help='Throw away data with at least this % of nodata pixels')] = 50,
+    tile_size: Annotated[
+        str, typer.Option(help="Tiling size in pixels e.g. '256x256'", callback=tile_size_callback)
+    ] = '256x256',
+    tile_overlap: Annotated[int, typer.Option(help='Overlap of the tiles in pixels')] = 25,
+):
+    """Make data ready for training"""
+
+    # Tiling settings: tile_size_callback has already parsed the '256x256' string into an (x, y) tuple
+    xsize, ysize = tile_size
+    overlap = tile_overlap
+
+    timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
+    log_path = log_dir / f'prepare_data-{timestamp}.log'
+    log_dir.mkdir(exist_ok=True)
+    init_logging(log_path)
+    logger = get_logger('prepare_data')
+    logger.info('#############################')
+    logger.info('# Starting Data Preparation #')
+    logger.info('#############################')
+
+    threshold = nodata_threshold / 100
+
+    if not skip_gdal:
+        gdal.initialize(bin=gdal_bin, path=gdal_path)
+
+    DATA_ROOT = data_dir
+    DATA_DIR = DATA_ROOT / 'tiles'
+    h5dir = DATA_ROOT / 'h5'
+    h5dir.mkdir(exist_ok=True)
+
+    # All folders that contain the big raster (...AnalyticsML_SR.tif) are assumed to be a dataset
+    datasets = [raster.parent for raster in DATA_DIR.glob('*/' + RASTERFILTER)]
+
+    overwrite_conflicts = []
+    for dataset in datasets:
+        check_dir = h5dir / dataset.name
+        if check_dir.exists():
+            overwrite_conflicts.append(check_dir)
+
+    if overwrite_conflicts:
+        logger.warning(f"Found old data directories: {', '.join(dir.name for dir in overwrite_conflicts)}.")
+        decision = input('Delete and recreate them [d], skip them [s] or abort [a]? ').lower()
+        if decision == 'd':
+            logger.info('User chose to delete old directories.')
+            for old_dir in overwrite_conflicts:
+                shutil.rmtree(old_dir)
+        elif decision == 's':
+            logger.info('User chose to skip old directories.')
+            already_done = [d.name for d in overwrite_conflicts]
+            datasets = [d for d in datasets if d.name not in already_done]
+        else:
+            # When in doubt, don't overwrite/change anything to be safe
+            logger.error('Aborting due to conflicts with existing data directories.')
+            sys.exit(1)
+
+    Parallel(n_jobs=n_jobs)(
+        delayed(main_function)(dataset, log_path, h5dir, xsize, ysize, overlap, threshold, skip_gdal)
+        for dataset in datasets
+    )
+
+
 def main():
+    parser = argparse.ArgumentParser(
+        description='Make data ready for training', formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+    parser.add_argument('--data_dir', default='data', type=Path, help='Path to data processing dir')
+    parser.add_argument('--log_dir', default='logs', type=Path, help='Path to log dir')
+    parser.add_argument(
+        '--skip_gdal', action='store_true', help='Skip the Gdal conversion stage (if it has already been ' 'done)'
+    )
+    parser.add_argument('--gdal_bin', default=None, help='Path to gdal binaries (ignored if --skip_gdal is passed)')
+    parser.add_argument('--gdal_path', default=None, help='Path to gdal scripts (ignored if --skip_gdal is passed)')
+    parser.add_argument('--n_jobs', default=-1, type=int, help='number of parallel joblib jobs')
+    parser.add_argument(
+        '--nodata_threshold', default=50, type=float, help='Throw away data with at least this % of ' 'nodata pixels'
+    )
+    parser.add_argument('--tile_size', default='256x256', type=str, help="Tiling size in pixels e.g.
'256x256'") + parser.add_argument('--tile_overlap', default=25, type=int, help='Overlap of the tiles in pixels') + args = parser.parse_args() + # Tiling Settings - XSIZE, YSIZE = map(int, args.tile_size.split('x')) - OVERLAP = args.tile_overlap + xsize, ysize = map(int, args.tile_size.split('x')) + overlap = args.tile_overlap timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S') log_path = Path(args.log_dir) / f'prepare_data-{timestamp}.log' if not Path(args.log_dir).exists(): - os.mkdir(Path(args.log_dir)) + os.mkdir(Path(args.log_dir)) init_logging(log_path) logger = get_logger('prepare_data') logger.info('#############################') logger.info('# Starting Data Preparation #') logger.info('#############################') - # Paths setup - RASTERFILTER = '*_SR*.tif' - VECTORFILTER = '*.shp' - THRESHOLD = args.nodata_threshold / 100 + threshold = args.nodata_threshold / 100 if not args.skip_gdal: gdal.initialize(args) DATA_ROOT = Path(args.data_dir) DATA_DIR = DATA_ROOT / 'tiles' - H5_DIR = DATA_ROOT / 'h5' - H5_DIR.mkdir(exist_ok=True) + h5dir = DATA_ROOT / 'h5' + h5dir.mkdir(exist_ok=True) # All folders that contain the big raster (...AnalyticsML_SR.tif) are assumed to be a dataset datasets = [raster.parent for raster in DATA_DIR.glob('*/' + RASTERFILTER)] overwrite_conflicts = [] for dataset in datasets: - check_dir = H5_DIR / dataset.name + check_dir = h5dir / dataset.name if check_dir.exists(): overwrite_conflicts.append(check_dir) if overwrite_conflicts: logger.warning(f"Found old data directories: {', '.join(dir.name for dir in overwrite_conflicts)}.") - decision = input("Delete and recreate them [d], skip them [s] or abort [a]? ").lower() + decision = input('Delete and recreate them [d], skip them [s] or abort [a]? ').lower() if decision == 'd': - logger.info(f"User chose to delete old directories.") + logger.info('User chose to delete old directories.') for old_dir in overwrite_conflicts: shutil.rmtree(old_dir) elif decision == 's': - logger.info(f"User chose to skip old directories.") + logger.info('User chose to skip old directories.') already_done = [d.name for d in overwrite_conflicts] datasets = [d for d in datasets if d.name not in already_done] else: # When in doubt, don't overwrite/change anything to be safe - logger.error("Aborting due to conflicts with existing data directories.") + logger.error('Aborting due to conflicts with existing data directories.') sys.exit(1) - Parallel(n_jobs=args.n_jobs)(delayed(main_function)(dataset, args, log_path) for dataset in datasets) + Parallel(n_jobs=args.n_jobs)( + delayed(main_function)(dataset, log_path, h5dir, xsize, ysize, overlap, threshold, args.skip_gdal) + for dataset in datasets + ) -if __name__ == "__main__": +if __name__ == '__main__': main() diff --git a/src/thaw_slump_segmentation/scripts/prepare_s2_4band_planet_format.py b/src/thaw_slump_segmentation/scripts/prepare_s2_4band_planet_format.py index 411410e..2aa6d07 100644 --- a/src/thaw_slump_segmentation/scripts/prepare_s2_4band_planet_format.py +++ b/src/thaw_slump_segmentation/scripts/prepare_s2_4band_planet_format.py @@ -1,13 +1,18 @@ -import rasterio +import argparse import os from pathlib import Path -import numpy as np + import ee import geemap +import numpy as np +import rasterio +import typer from joblib import Parallel, delayed -import argparse -#from ..data_pre_processing.earthengine import ee_geom_from_image_bounds + +# from ..data_pre_processing.earthengine import ee_geom_from_image_bounds from rasterio.coords import BoundingBox +from 
typing_extensions import Annotated + ee.Initialize(opt_url='https://earthengine-highvolume.googleapis.com') @@ -16,17 +21,17 @@ def get_ndvi_from_4bandS2(image_path): ndvi_path = image_path.parent / 'ndvi.tif' if not ndvi_path.exists(): with rasterio.open(image_path) as src: - #read data + # read data data = src.read().astype(float) # calc ndvi - ndvi = (data[3]-data[2]) / (data[3]+data[2]) + ndvi = (data[3] - data[2]) / (data[3] + data[2]) # factor to correct output ndvi_out = ((np.clip(ndvi, -1, 1) + 1) * 1e4).astype(np.uint16) # get and adapt profile to match output profile = src.profile - profile.update({'dtype':'uint16', 'count':1}) + profile.update({'dtype': 'uint16', 'count': 1}) # save ndvi - + with rasterio.open(ndvi_path, 'w', **profile) as target: target.write(np.expand_dims(ndvi_out, 0)) else: @@ -39,7 +44,7 @@ def get_elevation_and_slope(image_path, rel_el_vrt, slope_vrt, parallel=True): epsg = src.crs.to_string() bounds = src.bounds xres, yres = src.res - + # setup gdal runs target_el = image_path.parent / 'relative_elevation.tif' if not target_el.exists(): @@ -55,8 +60,10 @@ def get_elevation_and_slope(image_path, rel_el_vrt, slope_vrt, parallel=True): s_slope = '' # run in console if parallel: + def execute_command(cmd): os.system(cmd) + # Parallel execution Parallel(n_jobs=2)(delayed(execute_command)(cmd) for cmd in [s_el, s_slope]) else: @@ -67,10 +74,10 @@ def execute_command(cmd): def replace_tcvis_zeronodata(infile): """ This function replaces zero values in the input raster file with 1, and writes the result to a new file. - + Parameters: infile (Path): The path to the input raster file. - + The function performs the following steps: 1. Reads the data from the input file. 2. Creates a mask of zero values. @@ -79,12 +86,12 @@ def replace_tcvis_zeronodata(infile): 5. Writes the modified data to a new file with the same profile as the input file. 6. Deletes the original file. 7. Renames the new file to have the same name as the original file. - + The function does not return any value. 
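    Example (a worked illustration of the masking logic above, not part of the
    original docstring): a pixel whose three TCVIS bands read (0, 120, 37) has a
    zero in only one band, so it is treated as valid data and becomes (1, 120, 37);
    a pixel that is (0, 0, 0) in all bands is assumed to be nodata and is left
    unchanged, keeping 0 as the nodata value.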
""" # setup outfile name tcvis_replace = infile.parent / 'tcvis_tmpfix.tif' - + with rasterio.open(infile, 'r') as src: # read data ds = src.read() @@ -94,18 +101,18 @@ def replace_tcvis_zeronodata(infile): mask_all = mask.all(axis=0) mask_any = mask.any(axis=0) replace_mask = np.logical_and(mask_any, ~mask_all) - #replace zero values with 1a - for i in [0,1,2]: + # replace zero values with 1a + for i in [0, 1, 2]: ds[i][(ds[i] == 0) & replace_mask] = 1 # get profile for new output profile = src.profile with rasterio.open(tcvis_replace, 'w', **profile) as dst: dst.write(ds) - + # delete_original infile.unlink() - # rename + # rename tcvis_replace.rename(infile) @@ -127,13 +134,12 @@ def download_tcvis(image_path): Example: download_tcvis('path/to/your/input_image.tif') """ - + with rasterio.open(image_path) as src: epsg = src.crs.to_string() crs = src.crs bounds = src.bounds xres, yres = src.res - image_shape = (src.height, src.width) # needs to cut one pixel on top fixed_bounds = BoundingBox(left=bounds.left, bottom=bounds.bottom, right=bounds.right, top=bounds.top - yres) @@ -143,65 +149,85 @@ def download_tcvis(image_path): geom = geemap.gdf_to_ee(gdf).first().geometry() # download result - ee_image_tcvis = ee.ImageCollection("users/ingmarnitze/TCTrend_SR_2000-2019_TCVIS").mosaic() - tcvis_outfile = image_path.parent /'tcvis.tif' - + ee_image_tcvis = ee.ImageCollection('users/ingmarnitze/TCTrend_SR_2000-2019_TCVIS').mosaic() + tcvis_outfile = image_path.parent / 'tcvis.tif' + if not tcvis_outfile.exists(): geemap.download_ee_image(ee_image_tcvis, filename=tcvis_outfile, region=geom, scale=xres, crs=epsg) else: print(f'TCVIS file for {image_path.parent.name} already exists!') return 0 - + # check outfile props with rasterio.open(tcvis_outfile) as src: - epsg_tcvis = src.crs.to_string() - crs_tcvis = src.crs bounds_tcvis = src.bounds xres_tcvis, yres_tcvis = src.res - shape_tcvis = (src.height, src.width) # show diff if not bounds == bounds_tcvis: print(f'Downloaded TCVIS Image for dataset {image_path.parent.name} has wrong dimensions') print('Input image:', bounds, xres, yres) print('TCVIS Image:', bounds_tcvis, xres_tcvis, yres_tcvis) - + # fix output in case of size mismatch tcvis_outfile_tmp = tcvis_outfile.parent / 'tcvis_temp.tif' tcvis_outfile.rename(tcvis_outfile_tmp) s_fix_tcvis = f'gdalwarp -te {bounds.left} {bounds.bottom} {bounds.right} {bounds.top} -tr {xres} {yres} -co COMPRESS=DEFLATE {tcvis_outfile_tmp} {tcvis_outfile}' os.system(s_fix_tcvis) tcvis_outfile_tmp.unlink() - + print('Write mask corrected TCVIS') replace_tcvis_zeronodata(tcvis_outfile) +def prepare_s2_4band_planet_format( + data_dir: Annotated[Path, typer.Argument(help='data directory (parent of download dir)')], + image_regex: Annotated[str, typer.Option(help='regex term to find image file')] = '*/*SR.tif', + n_jobs: Annotated[int, typer.Option(help='Number of parallel- images to prepare data for')] = 6, +): + input_dir = Path(data_dir) + infiles = list(input_dir.glob(image_regex)) + + base_dir = Path('/isipd/projects/p_aicore_pf/initze/processing/auxiliary/') + # make absolute paths + elevation = base_dir / 'elevation.vrt' + slope = base_dir / 'slope.vrt' + + # for image_path in infiles: + Parallel(n_jobs=n_jobs)(delayed(process_local_data)(image_path, elevation, slope) for image_path in infiles) + + for image_path in infiles: + download_tcvis(image_path) + + def main(): - parser = argparse.ArgumentParser(description='Prepare aux data for downloaded S2 images.', - 
formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser = argparse.ArgumentParser( + description='Prepare aux data for downloaded S2 images.', formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) parser.add_argument('--data_dir', type=str, help='data directory (parent of download dir)', required=True) parser.add_argument('--image_regex', type=str, default='*/*SR.tif', help='regex term to find image file') parser.add_argument('--n_jobs', type=int, default=6, help='Number of parallel- images to prepare data for') - parser.add_argument('--aux_dir', - default='/isipd/projects/p_aicore_pf/initze/processing/auxiliary/', - type=str, - help='parent directory of auxilliary data') + parser.add_argument( + '--aux_dir', + default='/isipd/projects/p_aicore_pf/initze/processing/auxiliary/', + type=str, + help='parent directory of auxilliary data', + ) args = parser.parse_args() - + input_dir = Path(args.data_dir) infiles = list(input_dir.glob(args.image_regex)) - + base_dir = Path('/isipd/projects/p_aicore_pf/initze/processing/auxiliary/') # make absolute paths elevation = base_dir / 'elevation.vrt' slope = base_dir / 'slope.vrt' - - #for image_path in infiles: + + # for image_path in infiles: Parallel(n_jobs=args.n_jobs)(delayed(process_local_data)(image_path, elevation, slope) for image_path in infiles) - + for image_path in infiles: download_tcvis(image_path) -if __name__ == "__main__": +if __name__ == '__main__': main() diff --git a/src/thaw_slump_segmentation/scripts/setup_raw_data.py b/src/thaw_slump_segmentation/scripts/setup_raw_data.py index d678c25..6d032d9 100644 --- a/src/thaw_slump_segmentation/scripts/setup_raw_data.py +++ b/src/thaw_slump_segmentation/scripts/setup_raw_data.py @@ -10,7 +10,6 @@ import argparse import os -from collections import namedtuple from datetime import datetime from pathlib import Path @@ -47,9 +46,7 @@ def preprocess_directory(image_dir, data_dir, aux_dir, backup_dir, log_path, gdal_bin, gdal_path, label_required=True): # TODO: let gdal.initialize take each argument separately # Mock old args object - ARGS = namedtuple('gdalargs', ['gdal_bin', 'gdal_path']) - gdalargs = ARGS(gdal_bin, gdal_path) - gdal.initialize(gdalargs) + gdal.initialize(bin=gdal_bin, path=gdal_path) init_logging(log_path) image_name = os.path.basename(image_dir) From b2134bb2017ca37bcfafb7473cdda4b7ca161267 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tobias=20H=C3=B6lzer?= Date: Thu, 9 May 2024 16:45:06 +0200 Subject: [PATCH 3/9] Convert CLI of both process scripts --- src/thaw_slump_segmentation/main.py | 9 + .../scripts/process_02_inference.py | 256 ++++++++++------ .../scripts/process_03_ensemble.py | 281 ++++++++++++------ 3 files changed, 379 insertions(+), 167 deletions(-) diff --git a/src/thaw_slump_segmentation/main.py b/src/thaw_slump_segmentation/main.py index 77ae773..1afa96d 100644 --- a/src/thaw_slump_segmentation/main.py +++ b/src/thaw_slump_segmentation/main.py @@ -4,6 +4,8 @@ from thaw_slump_segmentation.scripts.inference import inference from thaw_slump_segmentation.scripts.prepare_data import prepare_data from thaw_slump_segmentation.scripts.prepare_s2_4band_planet_format import prepare_s2_4band_planet_format +from thaw_slump_segmentation.scripts.process_02_inference import process_02_inference +from thaw_slump_segmentation.scripts.process_03_ensemble import process_03_ensemble from thaw_slump_segmentation.scripts.setup_raw_data import setup_raw_data from thaw_slump_segmentation.scripts.train import train @@ -21,3 +23,10 @@ 
data_cli.command('prepare')(prepare_data) cli.add_typer(data_cli, name='data') + +process_cli = typer.Typer() + +process_cli.command('inference')(process_02_inference) +process_cli.command('ensemble')(process_03_ensemble) + +cli.add_typer(process_cli, name='process') diff --git a/src/thaw_slump_segmentation/scripts/process_02_inference.py b/src/thaw_slump_segmentation/scripts/process_02_inference.py index 20095c4..ad0a614 100644 --- a/src/thaw_slump_segmentation/scripts/process_02_inference.py +++ b/src/thaw_slump_segmentation/scripts/process_02_inference.py @@ -1,60 +1,63 @@ -from pathlib import Path -import torch -import pandas as pd +import argparse import os -import numpy as np -import tqdm -from joblib import delayed, Parallel import shutil -from tqdm import tqdm -import swifter from datetime import datetime -import argparse +from pathlib import Path +from typing import List -from ..postprocessing import * - -# ### Settings -# Add argument definitions -parser = argparse.ArgumentParser(description="Script to run auto inference for RTS", - formatter_class=argparse.ArgumentDefaultsHelpFormatter) -parser.add_argument("--code_dir", type=Path, default=Path('/isipd/projects/p_aicore_pf/initze/code/aicore_inference'), - help="Local code directory") -parser.add_argument("--raw_data_dir", type=Path, nargs='+', - default=[ - Path('/isipd/projects/p_aicore_pf/initze/data/planet/planet_data_inference_grid/scenes'), - Path('/isipd/projects/p_aicore_pf/initze/data/planet/planet_data_inference_grid/tiles') - ], - help="Location of raw data") -parser.add_argument("--processing_dir", type=Path, default=Path('/isipd/projects/p_aicore_pf/initze/processing'), - help="Location for data processing") -parser.add_argument("--inference_dir", type=Path, default=Path('/isipd/projects/p_aicore_pf/initze/processed/inference'), - help="Target directory for inference results") -parser.add_argument("--model_dir", type=Path, default=Path('/isipd/projects/p_aicore_pf/initze/models/thaw_slumps'), - help="Target directory for models") -parser.add_argument("--model", type=str, default='RTS_v6_tcvis', - help="Model name, examples ['RTS_v6_tcvis', 'RTS_v6_notcvis']") -parser.add_argument("--use_gpu", nargs="+", type=int, default=[0], - help="List of GPU IDs to use, space separated") -parser.add_argument("--runs_per_gpu", type=int, default=5, - help="Number of runs per GPU") -parser.add_argument("--max_images", type=int, default=None, - help="Maximum number of images to process (optional)") -parser.add_argument("--skip_vrt", action="store_true", - help="set to skip DEM vrt creation") -parser.add_argument("--skip_vector_save", action="store_true", - help="set to skip output vector creation") - -# TODO, make flag to skip vrt -args = parser.parse_args() +import geopandas as gpd +import numpy as np +import pandas as pd +import typer +from joblib import Parallel, delayed +from tqdm import tqdm +from typing_extensions import Annotated -def main(): +from ..postprocessing import ( + copy_unprocessed_files, + get_processing_status, + load_and_parse_vector, + run_inference, + update_DEM2, +) + + +def process_02_inference( + code_dir: Annotated[Path, typer.Option(help='Local code directory')] = Path( + '/isipd/projects/p_aicore_pf/initze/code/aicore_inference' + ), + raw_data_dir: Annotated[List[Path], typer.Option(help='Location of raw data')] = [ + Path('/isipd/projects/p_aicore_pf/initze/data/planet/planet_data_inference_grid/scenes'), + Path('/isipd/projects/p_aicore_pf/initze/data/planet/planet_data_inference_grid/tiles'), + ], + 
processing_dir: Annotated[Path, typer.Option(help='Location for data processing')] = Path( + '/isipd/projects/p_aicore_pf/initze/processing' + ), + inference_dir: Annotated[Path, typer.Option(help='Target directory for inference results')] = Path( + '/isipd/projects/p_aicore_pf/initze/processed/inference' + ), + model_dir: Annotated[Path, typer.Option(help='Target directory for models')] = Path( + '/isipd/projects/p_aicore_pf/initze/models/thaw_slumps' + ), + model: Annotated[ + str, typer.Option(help="Model name, examples ['RTS_v6_tcvis', 'RTS_v6_notcvis']") + ] = 'RTS_v6_tcvis', + use_gpu: Annotated[List[int], typer.Option(help='List of GPU IDs to use, space separated')] = [0], + runs_per_gpu: Annotated[int, typer.Option(help='Number of runs per GPU')] = 5, + max_images: Annotated[int, typer.Option(help='Maximum number of images to process (optional)')] = None, + skip_vrt: Annotated[bool, typer.Option(help='set to skip DEM vrt creation')] = False, + skip_vector_save: Annotated[bool, typer.Option(help='set to skip output vector creation')] = False, +): + """Script to run auto inference for RTS""" # ### List all files with properties # TODO: run double for both paths - print("Checking processing status!") + print('Checking processing status!') # read processing status for raw data list # TODO: check here - produces very large output when double checking - df_processing_status_list = [get_processing_status(raw_data_dir, args.processing_dir, args.inference_dir, args.model) for raw_data_dir in args.raw_data_dir] - + df_processing_status_list = [ + get_processing_status(rdd, processing_dir, inference_dir, model) for rdd in raw_data_dir + ] + # get df for preprocessing df_final = pd.concat(df_processing_status_list).drop_duplicates() @@ -69,16 +72,16 @@ def main(): print(f'Number of images for preprocessing: {preprocessing_images}') print(f'Number of images for inference: {preprocessed_images - finished_images}') print(f'Number of finished images: {finished_images}') - + # TODO: images with processing status True but Inference False are crappy - + if total_images == finished_images: print('No processing needed: all images are already processed!') return 0 - + ## Preprocessing # #### Update Arctic DEM data - if args.skip_vrt == True: + if skip_vrt: print('Skipping Elevation VRT creation!') else: print('Updating Elevation VRTs!') @@ -86,17 +89,16 @@ def main(): vrt_target_dir = Path('/isipd/projects/p_aicore_pf/initze/processing/auxiliary/ArcticDEM') update_DEM2(dem_data_dir=dem_data_dir, vrt_target_dir=vrt_target_dir) - - # #### Copy data for Preprocessing + # #### Copy data for Preprocessing # make better documentation df_preprocess = df_final[~(df_final.preprocessed)] print(f'Number of images to preprocess: {len(df_preprocess)}') # Cleanup processing directories to avoid incomplete processing - input_dir_dslist = list((args.processing_dir / 'input').glob('*')) + input_dir_dslist = list((processing_dir / 'input').glob('*')) if len(input_dir_dslist) > 0: - print(f"Cleaning up {(args.processing_dir / 'input')}") + print(f"Cleaning up {(processing_dir / 'input')}") for d in input_dir_dslist: print('Delete', d) shutil.rmtree(d) @@ -105,76 +107,164 @@ def main(): # TODO: check for empty processing status # Copy Data - _ = df_preprocess.swifter.apply(lambda x: copy_unprocessed_files(x, args.processing_dir), axis=1) + _ = df_preprocess.swifter.apply(lambda x: copy_unprocessed_files(x, processing_dir), axis=1) - # #### Run Preprocessing + # #### Run Preprocessing import warnings + 
warnings.filterwarnings('ignore') - N_JOBS=40 - print(f'Preprocessing {len(df_preprocess)} images') #fix this + N_JOBS = 40 + print(f'Preprocessing {len(df_preprocess)} images') # fix this if len(df_preprocess) > 0: - pp_string = f'setup_raw_data --data_dir {args.processing_dir} --n_jobs {N_JOBS} --nolabel' + pp_string = f'setup_raw_data --data_dir {processing_dir} --n_jobs {N_JOBS} --nolabel' os.system(pp_string) # ## Processing/Inference # rerun processing status - df_processing_status2 = pd.concat([get_processing_status(raw_data_dir, args.processing_dir, args.inference_dir, args.model) for raw_data_dir in args.raw_data_dir]).drop_duplicates() + df_processing_status2 = pd.concat( + [get_processing_status(rdd, processing_dir, inference_dir, model) for rdd in raw_data_dir] + ).drop_duplicates() # Filter to images that are not preprocessed yet df_process = df_final[~df_final.inference_finished] # update overview and filter accordingly - really necessary? - df_process_final = df_process.set_index('name').join(df_processing_status2[df_processing_status2['preprocessed']][['name']].set_index('name'), how='inner').reset_index(drop=False).iloc[:args.max_images] + df_process_final = ( + df_process.set_index('name') + .join(df_processing_status2[df_processing_status2['preprocessed']][['name']].set_index('name'), how='inner') + .reset_index(drop=False) + .iloc[:max_images] + ) # validate if images are correctly preprocessed - df_process_final['preprocessing_valid'] = (df_process_final.apply(lambda x: len(list(x['path'].glob('*'))), axis=1) >= 5) + df_process_final['preprocessing_valid'] = ( + df_process_final.apply(lambda x: len(list(x['path'].glob('*'))), axis=1) >= 5 + ) # final filtering process to remove incorrectly preprocessed data df_process_final = df_process_final[df_process_final['preprocessing_valid']] # TODO: check for empty files and processing - print(f'Number of images:', len(df_process_final)) + print(f'Number of images: {len(df_process_final)}') - # #### Parallel runs + # #### Parallel runs # Make splits to distribute the processing - n_splits = len(args.use_gpu) * args.runs_per_gpu + n_splits = len(use_gpu) * runs_per_gpu df_split = np.array_split(df_process_final, n_splits) - gpu_split = args.use_gpu * args.runs_per_gpu + gpu_split = use_gpu * runs_per_gpu - #for split in df_split: + # for split in df_split: # print(f'Number of images: {len(split)}') print('Run inference!') # ### Parallel Inference execution - _ = Parallel(n_jobs=n_splits)(delayed(run_inference)(df_split[split], model=args.model, processing_dir=args.processing_dir, inference_dir=args.inference_dir, model_dir=args.model_dir, gpu=gpu_split[split], run=True) for split in range(n_splits)) + _ = Parallel(n_jobs=n_splits)( + delayed(run_inference)( + df_split[split], + model=model, + processing_dir=processing_dir, + inference_dir=inference_dir, + model_dir=model_dir, + gpu=gpu_split[split], + run=True, + ) + for split in range(n_splits) + ) # #### Merge output files - if not args.skip_vector_save: + if not skip_vector_save: if len(df_process_final) > 0: - # read all files which following the above defined threshold - flist = list((args.inference_dir / args.model).glob(f'*/*pred_binarized.shp')) + flist = list((inference_dir / model).glob('*/*pred_binarized.shp')) len(flist) - + # Save output vectors to merged file # load them in parallel out = Parallel(n_jobs=6)(delayed(load_and_parse_vector)(f) for f in tqdm(flist[:])) - + # merge them and save to geopackage file merged_gdf = gpd.pd.concat(out) - save_file = 
args.inference_dir / args.model / f'{args.model}_merged.gpkg' - + save_file = inference_dir / model / f'{model}_merged.gpkg' + # check if file already exists, create backup file if exists if save_file.exists(): # Get the current timestamp - timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + timestamp = datetime.now().strftime('%Y%m%d_%H%M%S') # Create the backup file name - save_file_bk = args.inference_dir / args.model / f"{args.model}_merged_bk_{timestamp}.gpkg" - print (f'Creating backup of file {save_file} to {save_file_bk}') + save_file_bk = inference_dir / model / f'{model}_merged_bk_{timestamp}.gpkg' + print(f'Creating backup of file {save_file} to {save_file_bk}') shutil.move(save_file, save_file_bk) - + # save to files print(f'Saving vectors to {save_file}') merged_gdf.to_file(save_file) else: print('Skipping output vector creation!') -if __name__ == "__main__": - main() \ No newline at end of file + +def main(): + # ### Settings + # Add argument definitions + parser = argparse.ArgumentParser( + description='Script to run auto inference for RTS', formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + parser.add_argument( + '--code_dir', + type=Path, + default=Path('/isipd/projects/p_aicore_pf/initze/code/aicore_inference'), + help='Local code directory', + ) + parser.add_argument( + '--raw_data_dir', + type=Path, + nargs='+', + default=[ + Path('/isipd/projects/p_aicore_pf/initze/data/planet/planet_data_inference_grid/scenes'), + Path('/isipd/projects/p_aicore_pf/initze/data/planet/planet_data_inference_grid/tiles'), + ], + help='Location of raw data', + ) + parser.add_argument( + '--processing_dir', + type=Path, + default=Path('/isipd/projects/p_aicore_pf/initze/processing'), + help='Location for data processing', + ) + parser.add_argument( + '--inference_dir', + type=Path, + default=Path('/isipd/projects/p_aicore_pf/initze/processed/inference'), + help='Target directory for inference results', + ) + parser.add_argument( + '--model_dir', + type=Path, + default=Path('/isipd/projects/p_aicore_pf/initze/models/thaw_slumps'), + help='Target directory for models', + ) + parser.add_argument( + '--model', type=str, default='RTS_v6_tcvis', help="Model name, examples ['RTS_v6_tcvis', 'RTS_v6_notcvis']" + ) + parser.add_argument('--use_gpu', nargs='+', type=int, default=[0], help='List of GPU IDs to use, space separated') + parser.add_argument('--runs_per_gpu', type=int, default=5, help='Number of runs per GPU') + parser.add_argument('--max_images', type=int, default=None, help='Maximum number of images to process (optional)') + parser.add_argument('--skip_vrt', action='store_true', help='set to skip DEM vrt creation') + parser.add_argument('--skip_vector_save', action='store_true', help='set to skip output vector creation') + + # TODO, make flag to skip vrt + args = parser.parse_args() + + process_02_inference( + code_dir=args.code_dir, + raw_data_dir=args.raw_data_dir, + processing_dir=args.processing_dir, + inference_dir=args.inference_dir, + model_dir=args.model_dir, + model=args.model, + use_gpu=args.use_gpu, + runs_per_gpu=args.runs_per_gpu, + max_images=args.max_images, + skip_vrt=args.skip_vrt, + skip_vector_save=args.skip_vector_save, + ) + + +if __name__ == '__main__': + main() diff --git a/src/thaw_slump_segmentation/scripts/process_03_ensemble.py b/src/thaw_slump_segmentation/scripts/process_03_ensemble.py index f2e7236..0c4354a 100644 --- a/src/thaw_slump_segmentation/scripts/process_03_ensemble.py +++ b/src/thaw_slump_segmentation/scripts/process_03_ensemble.py @@ 
-1,128 +1,150 @@ # # Create ensemble results from several model outputs -# ### Imports +# ### Imports +import argparse +import shutil +from datetime import datetime from pathlib import Path -import pandas as pd -from joblib import delayed, Parallel -#from tqdm.notebook import tqdm -from tqdm import tqdm -from ..postprocessing import * +from typing import List + import geopandas as gpd -from datetime import datetime -import argparse +import typer +from joblib import Parallel, delayed +from tqdm import tqdm +from typing_extensions import Annotated -# Add argument definitions -parser = argparse.ArgumentParser(description="Script to run auto inference for RTS", - formatter_class=argparse.ArgumentDefaultsHelpFormatter) -parser.add_argument("--raw_data_dir", type=Path, default=Path('/isipd/projects/p_aicore_pf/initze/data/planet/planet_data_inference_grid/tiles'), - help="Location of raw data") -parser.add_argument("--processing_dir", type=Path, default=Path('/isipd/projects/p_aicore_pf/initze/processing'), - help="Location for data processing") -parser.add_argument("--inference_dir", type=Path, default=Path('/isipd/projects/p_aicore_pf/initze/processed/inference'), - help="Target directory for inference results") -parser.add_argument("--model_dir", type=Path, default=Path('/isipd/projects/p_aicore_pf/initze/models/thaw_slumps'), - help="Target directory for models") -parser.add_argument("--ensemble_name", type=str, default='RTS_v6_ensemble_v2', - help="Target directory for models") -parser.add_argument("--model_names", type=str, nargs='+', default=['RTS_v6_tcvis', 'RTS_v6_notcvis'], - help="Model name, examples ['RTS_v6_tcvis', 'RTS_v6_notcvis']") -parser.add_argument("--gpu", type=int, default=0, - help="GPU IDs to use for edge cleaning") -parser.add_argument("--n_jobs", type=int, default=15, - help="number of CPU jobs for ensembling") -parser.add_argument("--n_vector_loaders", type=int, default=6, - help="number of parallel vector loaders for final merge") -parser.add_argument("--max_images", type=int, default=None, - help="Maximum number of images to process (optional)") -parser.add_argument("--vector_output_format", type=str, nargs='+', default=['gpkg', 'parquet'], - help="Output format extension of ensembled vector files") -parser.add_argument("--ensemble_thresholds", type=float, nargs='+', default=[0.4, 0.45, 0.5], - help="Thresholds for polygonized outputs of the ensemble, needs to be string, see examples") -parser.add_argument("--ensemble_border_size", type=int, default=10, - help="Number of pixels to remove around the border and no data") -parser.add_argument("--ensemble_mmu", type=int, default=32, - help="Minimum mapping unit of output objects in pixels") -parser.add_argument("--try_gpu", action="store_true", help="set to try image processing with gpu") -parser.add_argument("--force_vector_merge", action="store_true", help="force merging of output vectors even if no new ensemble tiles were processed") - -args = parser.parse_args() +# from tqdm.notebook import tqdm +from ..postprocessing import ( + create_ensemble_v2, + get_processing_status, + get_processing_status_ensemble, + load_and_parse_vector, +) -def main(): + +def process_03_ensemble( + raw_data_dir: Annotated[Path, typer.Option(help='Location of raw data')] = Path( + '/isipd/projects/p_aicore_pf/initze/data/planet/planet_data_inference_grid/tiles' + ), + processing_dir: Annotated[Path, typer.Option(help='Location for data processing')] = Path( + '/isipd/projects/p_aicore_pf/initze/processing' + ), + inference_dir: 
Annotated[Path, typer.Option(help='Target directory for inference results')] = Path( + '/isipd/projects/p_aicore_pf/initze/processed/inference' + ), + model_dir: Annotated[Path, typer.Option(help='Target directory for models')] = Path( + '/isipd/projects/p_aicore_pf/initze/models/thaw_slumps' + ), + ensemble_name: Annotated[str, typer.Option(help='Target directory for models')] = 'RTS_v6_ensemble_v2', + model_names: Annotated[List[str], typer.Option(help="Model name, examples ['RTS_v6_tcvis', 'RTS_v6_notcvis']")] = [ + 'RTS_v6_tcvis', + 'RTS_v6_notcvis', + ], + gpu: Annotated[int, typer.Option(help='GPU IDs to use for edge cleaning')] = 0, + n_jobs: Annotated[int, typer.Option(help='number of CPU jobs for ensembling')] = 15, + n_vector_loaders: Annotated[int, typer.Option(help='number of parallel vector loaders for final merge')] = 6, + max_images: Annotated[int, typer.Option(help='Maximum number of images to process (optional)')] = None, + vector_output_format: Annotated[ + List[str], typer.Option(help='Output format extension of ensembled vector files') + ] = ['gpkg', 'parquet'], + ensemble_thresholds: Annotated[ + List[float], + typer.Option(help='Thresholds for polygonized outputs of the ensemble, needs to be string, see examples'), + ] = [0.4, 0.45, 0.5], + ensemble_border_size: Annotated[ + int, typer.Option(help='Number of pixels to remove around the border and no data') + ] = 10, + ensemble_mmu: Annotated[int, typer.Option(help='Minimum mapping unit of output objects in pixels')] = 32, + try_gpu: Annotated[bool, typer.Option(help='set to try image processing with gpu')] = False, + force_vector_merge: Annotated[ + bool, typer.Option(help='force merging of output vectors even if no new ensemble tiles were processed') + ] = False, +): ### Start # check if cucim is available try: - import cucim - if args.try_gpu: + import cucim # type: ignore # noqa: F401 + + if try_gpu: try_gpu = True - print ('Running ensembling with GPU!') + print('Running ensembling with GPU!') else: try_gpu = False - print ('Running ensembling with CPU!') - except: + print('Running ensembling with CPU!') + except Exception as e: try_gpu = False - print ('Cucim import failed') + print(f'Cucim import failed: {e}') # setup all params kwargs_ensemble = { - 'ensemblename': args.ensemble_name, - 'inference_dir': args.inference_dir, - 'modelnames': args.model_names, - 'binary_threshold': args.ensemble_thresholds, - 'border_size': args.ensemble_border_size, - 'minimum_mapping_unit': args.ensemble_mmu, + 'ensemblename': ensemble_name, + 'inference_dir': inference_dir, + 'modelnames': model_names, + 'binary_threshold': ensemble_thresholds, + 'border_size': ensemble_border_size, + 'minimum_mapping_unit': ensemble_mmu, 'delete_binary': True, - 'try_gpu': try_gpu, # currently default to CPU only - 'gpu' : args.gpu, + 'try_gpu': try_gpu, # currently default to CPU only + 'gpu': gpu, } # Check for finalized products - df_processing_status = get_processing_status(args.raw_data_dir, args.processing_dir, args.inference_dir, model=kwargs_ensemble['ensemblename']) - df_ensemble_status = get_processing_status_ensemble(args.inference_dir, model_input_names=kwargs_ensemble['modelnames'], model_ensemble_name=kwargs_ensemble['ensemblename']) + get_processing_status(raw_data_dir, processing_dir, inference_dir, model=kwargs_ensemble['ensemblename']) + df_ensemble_status = get_processing_status_ensemble( + inference_dir, + model_input_names=kwargs_ensemble['modelnames'], + model_ensemble_name=kwargs_ensemble['ensemblename'], + ) # Check 
which need to be process - check for already processed and invalid files process = df_ensemble_status[df_ensemble_status['process']] - n_images = len(process.iloc[:args.max_images]) + n_images = len(process.iloc[:max_images]) # #### Run Ensemble Merging if len(process) > 0: - print(f'Start running ensemble for {n_images} images with {args.n_jobs} parallel jobs!') - print(f'Target ensemble name:', kwargs_ensemble['ensemblename']) - print(f'Source model output', kwargs_ensemble['modelnames']) - _ = Parallel(n_jobs=args.n_jobs)(delayed(create_ensemble_v2)(image_id=process.iloc[row]['name'], **kwargs_ensemble) for row in tqdm(range(n_images))) + print(f'Start running ensemble for {n_images} images with {n_jobs} parallel jobs!') + print(f'Target ensemble name: {ensemble_name}') + print(f'Source model output {model_names}') + _ = Parallel(n_jobs=n_jobs)( + delayed(create_ensemble_v2)(image_id=process.iloc[row]['name'], **kwargs_ensemble) + for row in tqdm(range(n_images)) + ) else: - print(f'Skipped ensembling, all files ready for {args.ensemble_name}!') + print(f'Skipped ensembling, all files ready for {ensemble_name}!') - # # #### run parallelized batch + # # #### run parallelized batch - if (len(process) > 0) or args.force_vector_merge: + if (len(process) > 0) or force_vector_merge: # ### Merge vectors to complete dataset # set probability levels: 'class_05' means 50%, 'class_045' means 45%. This is the regex to search for vector naming - #proba_strings = args.ensemble_thresholds - # TODO: needs to be 'class_04', - proba_strings = [f'class_{thresh}'.replace('.','') for thresh in args.ensemble_thresholds] + # proba_strings = args.ensemble_thresholds + # TODO: needs to be 'class_04', + proba_strings = [f'class_{thresh}'.replace('.', '') for thresh in ensemble_thresholds] for proba_string in proba_strings: # read all files which follow the above defined threshold - flist = list((args.inference_dir / args.ensemble_name).glob(f'*/*_{proba_string}.gpkg')) + flist = list((inference_dir / ensemble_name).glob(f'*/*_{proba_string}.gpkg')) len(flist) # load them in parallel - print (f'Loading results {proba_string}') - out = Parallel(n_jobs=6)(delayed(load_and_parse_vector)(f) for f in tqdm(flist[:args.max_images])) + print(f'Loading results {proba_string}') + out = Parallel(n_jobs=6)(delayed(load_and_parse_vector)(f) for f in tqdm(flist[:max_images])) # merge them and save to geopackage file - print ('Merging results') + print('Merging results') merged_gdf = gpd.pd.concat(out) - for vector_output_format in args.vector_output_format: + for vector_output_format in vector_output_format: # Save output to vector - save_file = args.inference_dir / args.ensemble_name / f'merged_{proba_string}.{vector_output_format}' - + save_file = inference_dir / ensemble_name / f'merged_{proba_string}.{vector_output_format}' + # make file backup if necessary if save_file.exists(): # Get the current timestamp - timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + timestamp = datetime.now().strftime('%Y%m%d_%H%M%S') # Create the backup file name - save_file_bk = args.inference_dir / args.ensemble_name / f'merged_{proba_string}_bk_{timestamp}.{vector_output_format}' - print (f'Creating backup of file {save_file} to {save_file_bk}') + save_file_bk = ( + inference_dir / ensemble_name / f'merged_{proba_string}_bk_{timestamp}.{vector_output_format}' + ) + print(f'Creating backup of file {save_file} to {save_file_bk}') shutil.move(save_file, save_file_bk) - + # save to files print(f'Saving vectors to {save_file}') if 
vector_output_format in ['shp', 'gpkg']: @@ -132,5 +154,96 @@ def main(): else: print(f'Unknown format {vector_output_format}!') -if __name__ == "__main__": - main() \ No newline at end of file + +def main(): + # Add argument definitions + parser = argparse.ArgumentParser( + description='Script to run auto inference for RTS', formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + parser.add_argument( + '--raw_data_dir', + type=Path, + default=Path('/isipd/projects/p_aicore_pf/initze/data/planet/planet_data_inference_grid/tiles'), + help='Location of raw data', + ) + parser.add_argument( + '--processing_dir', + type=Path, + default=Path('/isipd/projects/p_aicore_pf/initze/processing'), + help='Location for data processing', + ) + parser.add_argument( + '--inference_dir', + type=Path, + default=Path('/isipd/projects/p_aicore_pf/initze/processed/inference'), + help='Target directory for inference results', + ) + parser.add_argument( + '--model_dir', + type=Path, + default=Path('/isipd/projects/p_aicore_pf/initze/models/thaw_slumps'), + help='Target directory for models', + ) + parser.add_argument('--ensemble_name', type=str, default='RTS_v6_ensemble_v2', help='Target directory for models') + parser.add_argument( + '--model_names', + type=str, + nargs='+', + default=['RTS_v6_tcvis', 'RTS_v6_notcvis'], + help="Model name, examples ['RTS_v6_tcvis', 'RTS_v6_notcvis']", + ) + parser.add_argument('--gpu', type=int, default=0, help='GPU IDs to use for edge cleaning') + parser.add_argument('--n_jobs', type=int, default=15, help='number of CPU jobs for ensembling') + parser.add_argument( + '--n_vector_loaders', type=int, default=6, help='number of parallel vector loaders for final merge' + ) + parser.add_argument('--max_images', type=int, default=None, help='Maximum number of images to process (optional)') + parser.add_argument( + '--vector_output_format', + type=str, + nargs='+', + default=['gpkg', 'parquet'], + help='Output format extension of ensembled vector files', + ) + parser.add_argument( + '--ensemble_thresholds', + type=float, + nargs='+', + default=[0.4, 0.45, 0.5], + help='Thresholds for polygonized outputs of the ensemble, needs to be string, see examples', + ) + parser.add_argument( + '--ensemble_border_size', type=int, default=10, help='Number of pixels to remove around the border and no data' + ) + parser.add_argument('--ensemble_mmu', type=int, default=32, help='Minimum mapping unit of output objects in pixels') + parser.add_argument('--try_gpu', action='store_true', help='set to try image processing with gpu') + parser.add_argument( + '--force_vector_merge', + action='store_true', + help='force merging of output vectors even if no new ensemble tiles were processed', + ) + + args = parser.parse_args() + + process_03_ensemble( + raw_data_dir=args.raw_data_dir, + processing_dir=args.processing_dir, + inference_dir=args.inference_dir, + model_dir=args.model_dir, + ensemble_name=args.ensemble_name, + model_names=args.model_names, + gpu=args.gpu, + n_jobs=args.n_jobs, + n_vector_loaders=args.n_vector_loaders, + max_images=args.max_images, + vector_output_format=args.vector_output_format, + ensemble_thresholds=args.ensemble_thresholds, + ensemble_border_size=args.ensemble_border_size, + ensemble_mmu=args.ensemble_mmu, + try_gpu=args.try_gpu, + force_vector_merge=args.force_vector_merge, + ) + + +if __name__ == '__main__': + main() From 3b0a708a6488387fbb63a6dfd6e0a8cd39be92ea Mon Sep 17 00:00:00 2001 From: Ingmar Nitze Date: Thu, 23 May 2024 14:46:58 +0200 Subject: [PATCH 4/9] added 
option to save probabilities for ensemble --- src/thaw_slump_segmentation/postprocessing.py | 19 ++++++++-- .../scripts/process_03_ensemble.py | 37 +++++++++++++++---- 2 files changed, 45 insertions(+), 11 deletions(-) diff --git a/src/thaw_slump_segmentation/postprocessing.py b/src/thaw_slump_segmentation/postprocessing.py index efab578..32e02b6 100644 --- a/src/thaw_slump_segmentation/postprocessing.py +++ b/src/thaw_slump_segmentation/postprocessing.py @@ -294,7 +294,8 @@ def create_ensemble_v2(inference_dir: Path, binary_threshold: list=[0.5], border_size: int=10, minimum_mapping_unit: int=32, - delete_binary: bool=True, + save_binary: bool=False, + save_probability: bool=False, try_gpu: bool=True, gpu: int=0): """ @@ -326,6 +327,9 @@ def create_ensemble_v2(inference_dir: Path, delete_binary : bool, optional Whether to delete the binary file after processing, by default True. + save_probability : bool, optional + Whether to save the probability file after processing, by default False. + Returns: ------------ None @@ -346,7 +350,7 @@ def calculate_mean_image(images): ctr += 1 mean_image = np.mean(list_data, axis=0) - return mean_image, out_meta_binary + return mean_image, out_meta_binary, out_meta def dilate_data_mask(mask, size=10): selem = disk(size) @@ -371,11 +375,18 @@ def mask_edges(input_mask, size=10): return None try: - mean_image, out_meta_binary = calculate_mean_image(images) + mean_image, out_meta_binary, out_meta_probability = calculate_mean_image(images) except: print(f'Read error of files {images}') return None + # save probability layer, if specified + if save_probability: + outpath_proba = outpath = inference_dir / ensemblename / image_id / f'{image_id}_{ensemblename}_probability.tif' + os.makedirs(outpath_proba.parent, exist_ok=True) + with rasterio.open(outpath_proba, 'w', **out_meta_probability) as target: + target.write(mean_image) + for threshold in binary_threshold: # get binary object mask @@ -423,7 +434,7 @@ def mask_edges(input_mask, size=10): s_polygonize = f'gdal_polygonize.py {outpath} -q -mask {outpath} -f "GPKG" {outpath_shp}' os.system(s_polygonize) - if delete_binary: + if not save_binary: os.remove(outpath) diff --git a/src/thaw_slump_segmentation/scripts/process_03_ensemble.py b/src/thaw_slump_segmentation/scripts/process_03_ensemble.py index 0c4354a..00207a9 100644 --- a/src/thaw_slump_segmentation/scripts/process_03_ensemble.py +++ b/src/thaw_slump_segmentation/scripts/process_03_ensemble.py @@ -26,39 +26,59 @@ def process_03_ensemble( raw_data_dir: Annotated[Path, typer.Option(help='Location of raw data')] = Path( '/isipd/projects/p_aicore_pf/initze/data/planet/planet_data_inference_grid/tiles' ), + processing_dir: Annotated[Path, typer.Option(help='Location for data processing')] = Path( '/isipd/projects/p_aicore_pf/initze/processing' ), + inference_dir: Annotated[Path, typer.Option(help='Target directory for inference results')] = Path( '/isipd/projects/p_aicore_pf/initze/processed/inference' ), + model_dir: Annotated[Path, typer.Option(help='Target directory for models')] = Path( '/isipd/projects/p_aicore_pf/initze/models/thaw_slumps' ), + ensemble_name: Annotated[str, typer.Option(help='Target directory for models')] = 'RTS_v6_ensemble_v2', + model_names: Annotated[List[str], typer.Option(help="Model name, examples ['RTS_v6_tcvis', 'RTS_v6_notcvis']")] = [ 'RTS_v6_tcvis', 'RTS_v6_notcvis', ], + gpu: Annotated[int, typer.Option(help='GPU IDs to use for edge cleaning')] = 0, + n_jobs: Annotated[int, typer.Option(help='number of CPU jobs for 
ensembling')] = 15, + n_vector_loaders: Annotated[int, typer.Option(help='number of parallel vector loaders for final merge')] = 6, + max_images: Annotated[int, typer.Option(help='Maximum number of images to process (optional)')] = None, + vector_output_format: Annotated[ List[str], typer.Option(help='Output format extension of ensembled vector files') ] = ['gpkg', 'parquet'], + ensemble_thresholds: Annotated[ List[float], typer.Option(help='Thresholds for polygonized outputs of the ensemble, needs to be string, see examples'), ] = [0.4, 0.45, 0.5], + ensemble_border_size: Annotated[ int, typer.Option(help='Number of pixels to remove around the border and no data') ] = 10, + ensemble_mmu: Annotated[int, typer.Option(help='Minimum mapping unit of output objects in pixels')] = 32, + try_gpu: Annotated[bool, typer.Option(help='set to try image processing with gpu')] = False, + force_vector_merge: Annotated[ bool, typer.Option(help='force merging of output vectors even if no new ensemble tiles were processed') ] = False, + + save_binary: Annotated[bool, typer.Option(help='set to keep intermediate binary rasters')] = False, + + save_probability: Annotated[bool, typer.Option(help='set to keep intermediate probility rasters')] = False, + ): ### Start # check if cucim is available @@ -83,7 +103,8 @@ def process_03_ensemble( 'binary_threshold': ensemble_thresholds, 'border_size': ensemble_border_size, 'minimum_mapping_unit': ensemble_mmu, - 'delete_binary': True, + 'save_binary': save_binary, + 'save_probability': save_probability, 'try_gpu': try_gpu, # currently default to CPU only 'gpu': gpu, } @@ -130,9 +151,9 @@ def process_03_ensemble( print('Merging results') merged_gdf = gpd.pd.concat(out) - for vector_output_format in vector_output_format: + for vector_format in vector_output_format: # Save output to vector - save_file = inference_dir / ensemble_name / f'merged_{proba_string}.{vector_output_format}' + save_file = inference_dir / ensemble_name / f'merged_{proba_string}.{vector_format}' # make file backup if necessary if save_file.exists(): @@ -140,19 +161,19 @@ def process_03_ensemble( timestamp = datetime.now().strftime('%Y%m%d_%H%M%S') # Create the backup file name save_file_bk = ( - inference_dir / ensemble_name / f'merged_{proba_string}_bk_{timestamp}.{vector_output_format}' + inference_dir / ensemble_name / f'merged_{proba_string}_bk_{timestamp}.{vector_format}' ) print(f'Creating backup of file {save_file} to {save_file_bk}') shutil.move(save_file, save_file_bk) # save to files print(f'Saving vectors to {save_file}') - if vector_output_format in ['shp', 'gpkg']: + if vector_format in ['shp', 'gpkg']: merged_gdf.to_file(save_file) - elif vector_output_format in ['parquet']: + elif vector_format in ['parquet']: merged_gdf.to_parquet(save_file) else: - print(f'Unknown format {vector_output_format}!') + print(f'Unknown format {vector_format}!') def main(): @@ -216,6 +237,8 @@ def main(): '--ensemble_border_size', type=int, default=10, help='Number of pixels to remove around the border and no data' ) parser.add_argument('--ensemble_mmu', type=int, default=32, help='Minimum mapping unit of output objects in pixels') + parser.add_argument('--save_binary', action='store_true', help='set to keep intermediate binary rasters') + parser.add_argument('--save_probability', action='store_true', help='set to keep intermediate probability rasters') parser.add_argument('--try_gpu', action='store_true', help='set to try image processing with gpu') parser.add_argument( '--force_vector_merge', From 
edc6446f99fadd6c917064ec2e2f9ce7af5af365 Mon Sep 17 00:00:00 2001 From: Ingmar Nitze Date: Thu, 23 May 2024 14:50:03 +0200 Subject: [PATCH 5/9] added argparse to typer args --- src/thaw_slump_segmentation/scripts/process_03_ensemble.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/thaw_slump_segmentation/scripts/process_03_ensemble.py b/src/thaw_slump_segmentation/scripts/process_03_ensemble.py index 00207a9..c0e5a39 100644 --- a/src/thaw_slump_segmentation/scripts/process_03_ensemble.py +++ b/src/thaw_slump_segmentation/scripts/process_03_ensemble.py @@ -265,6 +265,8 @@ def main(): ensemble_mmu=args.ensemble_mmu, try_gpu=args.try_gpu, force_vector_merge=args.force_vector_merge, + save_binary=args.save_binary, + save_probability=args.save_probability, ) From 20f0f48359b38e11d494cd0437890d9db00681a5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tobias=20H=C3=B6lzer?= Date: Sun, 26 May 2024 15:05:27 +0200 Subject: [PATCH 6/9] Fix last CLI bugs --- .gitignore | 3 +- src/thaw_slump_segmentation/main.py | 19 +++- .../download_s2_4band_planet_format.py | 21 +--- .../scripts/inference.py | 95 +++++------------- .../scripts/prepare_data.py | 97 ++++++------------- .../scripts/prepare_s2_4band_planet_format.py | 36 +++---- .../scripts/setup_raw_data.py | 62 ++++-------- src/thaw_slump_segmentation/scripts/train.py | 10 +- 8 files changed, 114 insertions(+), 229 deletions(-) diff --git a/.gitignore b/.gitignore index c290862..79f38aa 100644 --- a/.gitignore +++ b/.gitignore @@ -15,4 +15,5 @@ dist/ .python-version requirements.lock requirements-dev.lock -wandb \ No newline at end of file +wandb +models \ No newline at end of file diff --git a/src/thaw_slump_segmentation/main.py b/src/thaw_slump_segmentation/main.py index 1afa96d..253e744 100644 --- a/src/thaw_slump_segmentation/main.py +++ b/src/thaw_slump_segmentation/main.py @@ -9,7 +9,24 @@ from thaw_slump_segmentation.scripts.setup_raw_data import setup_raw_data from thaw_slump_segmentation.scripts.train import train -cli = typer.Typer() +# TODO: Move this comment to docs +# GEE is used for: +# - Setup of raw data (init call inside function) +# - download of S2 images (init call global at module level) +# - prepare of S2 images (init call global at module level) +# GDAL is used for: +# - Setup of raw data (in threaded function) +# - prepare data (in main function) +# - inference (in main function) +# - prepare of S2 images (but its not implemented via gdal module but hardcoded) + +cli = typer.Typer(pretty_exceptions_show_locals=False) + + +@cli.command() +def hello(name: str): + typer.echo(f'Hello {name}') + cli.command()(train) cli.command()(inference) diff --git a/src/thaw_slump_segmentation/scripts/download_s2_4band_planet_format.py b/src/thaw_slump_segmentation/scripts/download_s2_4band_planet_format.py index fc06109..afb19d4 100644 --- a/src/thaw_slump_segmentation/scripts/download_s2_4band_planet_format.py +++ b/src/thaw_slump_segmentation/scripts/download_s2_4band_planet_format.py @@ -25,8 +25,8 @@ def download_S2image_preprocessed(s2_image_id, outfile, outbands=['B2', 'B3', 'B def download_s2_4band_planet_format( - data_dir: Annotated[Path, typer.Argument(help='Output directory')], - s2ids: Annotated[List[str], typer.Argument(help='S2 image ID, you can use several separated by space')], + data_dir: Annotated[Path, typer.Option('--data_dir', help='Output directory')], + s2ids: Annotated[List[str], typer.Option(help='S2 image ID, you can use several separated by space')], ): """Download preprocessed S2 image.""" for s2id in s2ids: @@ 
-39,7 +39,7 @@ def download_s2_4band_planet_format( # ! Moving legacy argparse cli to main to maintain compatibility with the original script -def main(): +if __name__ == '__main__': parser = argparse.ArgumentParser( description='Download preprocessed S2 image.', formatter_class=argparse.ArgumentDefaultsHelpFormatter ) @@ -47,17 +47,4 @@ def main(): parser.add_argument('--data_dir', type=str, help='Output directory') args = parser.parse_args() - outdir = Path(args.data_dir) - s2id = args.s2id - - for s2id in args.s2id: - # Call the function with the provided s2id - outfile = outdir / s2id / f'{s2id}_SR.tif' - if not outdir.exists(): - print('Creating output directory', outdir) - outdir.mkdir() - download_S2image_preprocessed(s2id, outfile) - - -if __name__ == '__main__': - main() + download_s2_4band_planet_format(args.data_dir, args.s2id) diff --git a/src/thaw_slump_segmentation/scripts/inference.py b/src/thaw_slump_segmentation/scripts/inference.py index 69f55cc..efca707 100644 --- a/src/thaw_slump_segmentation/scripts/inference.py +++ b/src/thaw_slump_segmentation/scripts/inference.py @@ -241,19 +241,22 @@ def inference( ], model_path: Annotated[str, typer.Argument(help='path to model, use the model base path')], tile_to_predict: Annotated[List[str], typer.Argument(help='path to model')], - gdal_bin: Annotated[str, typer.Option(help='Path to gdal binaries')] = '', - gdal_path: Annotated[str, typer.Option(help='Path to gdal scripts')] = '', - n_jobs: Annotated[int, typer.Option(help='number of parallel joblib jobs')] = -1, + gdal_bin: Annotated[str, typer.Option('--gdal_bin', help='Path to gdal binaries', envvar='GDAL_BIN')] = '/usr/bin', + gdal_path: Annotated[ + str, typer.Option('--gdal_path', help='Path to gdal scripts', envvar='GDAL_PATH') + ] = '/usr/bin', + n_jobs: Annotated[int, typer.Option('--n_jobs', help='number of parallel joblib jobs')] = -1, ckpt: Annotated[str, typer.Option(help='Checkpoint to use')] = 'latest', - data_dir: Annotated[Path, typer.Option(help='Path to data processing dir')] = Path('data'), - log_dir: Annotated[Path, typer.Option(help='Path to log dir')] = Path('logs'), - inference_dir: Annotated[Path, typer.Option(help='Main inference directory')] = Path('inference'), + data_dir: Annotated[Path, typer.Option('--data_dir', help='Path to data processing dir')] = Path('data'), + log_dir: Annotated[Path, typer.Option('--log_dir', help='Path to log dir')] = Path('logs'), + inference_dir: Annotated[Path, typer.Option('--inference_dir', help='Main inference directory')] = Path( + 'inference' + ), margin_size: Annotated[int, typer.Option('--margin_size', '-n', help='Size of patch overlap')] = 256, patch_size: Annotated[int, typer.Option('--patch_size', '-p', help='Size of patches')] = 1024, ): """Inference Script""" - # TODO: let gdal.initialize take each argument separately # Mock old args object gdal.initialize(bin=gdal_bin, path=gdal_path) @@ -321,7 +324,7 @@ def inference( # ! 
Moving legacy argparse cli to main to maintain compatibility with the original script -def main(): +if __name__ == '__main__': parser = argparse.ArgumentParser( description='Inference Script', formatter_class=argparse.ArgumentDefaultsHelpFormatter ) @@ -341,68 +344,18 @@ def main(): parser.add_argument('tile_to_predict', type=str, help='path to model', nargs='+') args = parser.parse_args() - gdal.initialize(args) - - timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S') - log_path = Path(args.log_dir) / f'inference-{timestamp}.log' - if not Path(args.log_dir).exists(): - os.mkdir(Path(args.log_dir)) - init_logging(log_path) - logger = get_logger('inference') - - # ===== LOAD THE MODEL ===== - cuda = True if torch.cuda.is_available() else False - dev = torch.device('cpu') if not cuda else torch.device('cuda') - logger.info(f'Running on {dev} device') - - if not args.model_path: - last_modified = 0 - last_modeldir = None - for config_file in Path(args.log_dir).glob('*/config.yml'): - modified = config_file.stat().st_mtime - if modified > last_modified: - last_modified = modified - last_modeldir = config_file.parent - args.model_path = last_modeldir - - model_dir = Path(args.model_path) - config = yaml.load((model_dir / 'config.yml').open(), Loader=yaml.SafeLoader) - - m = config['model'] - # print(m['architecture'],m['encoder'], m['input_channels']) - model = create_model( - arch=m['architecture'], - encoder_name=m['encoder'], - encoder_weights=None if m['encoder_weights'] == 'random' else m['encoder_weights'], - classes=1, - in_channels=m['input_channels'], + inference( + name=args.name, + model_path=args.model_path, + tile_to_predict=args.tile_to_predict, + gdal_bin=args.gdal_bin, + gdal_path=args.gdal_path, + n_jobs=args.n_jobs, + ckpt=args.ckpt, + data_dir=args.data_dir, + log_dir=args.log_dir, + inference_dir=args.inference_dir, + margin_size=args.margin_size, + patch_size=args.patch_size, ) - - if args.ckpt == 'latest': - ckpt_nums = [int(ckpt.stem) for ckpt in model_dir.glob('checkpoints/*.pt')] - last_ckpt = max(ckpt_nums) - else: - last_ckpt = int(args.ckpt) - ckpt = model_dir / 'checkpoints' / f'{last_ckpt:02d}.pt' - logger.info(f'Loading checkpoint {ckpt}') - - # Parallelized Model needs to be declared before loading - try: - model.load_state_dict(torch.load(ckpt, map_location=dev)) - except Exception: - model = nn.DataParallel(model) - model.load_state_dict(torch.load(ckpt, map_location=dev)) - - model = model.to(dev) - - sources = DataSources(config['data_sources']) - - torch.set_grad_enabled(False) - - for tilename in tqdm(args.tile_to_predict): - do_inference(tilename, sources, model, dev, logger, args, log_path) - - -if __name__ == '__main__': - main() diff --git a/src/thaw_slump_segmentation/scripts/prepare_data.py b/src/thaw_slump_segmentation/scripts/prepare_data.py index 204dd38..4835612 100644 --- a/src/thaw_slump_segmentation/scripts/prepare_data.py +++ b/src/thaw_slump_segmentation/scripts/prepare_data.py @@ -10,7 +10,6 @@ """ import argparse -import os import shutil import sys from datetime import datetime @@ -264,19 +263,28 @@ def tile_size_callback(value: str): def prepare_data( - data_dir: Annotated[Path, typer.Option(help='Path to data processing dir')] = Path('data'), - log_dir: Annotated[Path, typer.Option(help='Path to log dir')] = Path('logs'), + data_dir: Annotated[Path, typer.Option('--data_dir', help='Path to data processing dir')] = Path('data'), + log_dir: Annotated[Path, typer.Option('--log_dir', help='Path to log dir')] = Path('logs'), skip_gdal: 
Annotated[ - bool, typer.Option(help='Skip the Gdal conversion stage (if it has already been done)') + bool, + typer.Option('--skip_gdal/--use_gdal', help='Skip the Gdal conversion stage (if it has already been done)'), ] = False, - gdal_bin: Annotated[str, typer.Option(help='Path to gdal binaries (ignored if --skip_gdal is passed)')] = None, - gdal_path: Annotated[str, typer.Option(help='Path to gdal scripts (ignored if --skip_gdal is passed)')] = None, - n_jobs: Annotated[int, typer.Option(help='number of parallel joblib jobs')] = -1, - nodata_threshold: Annotated[float, typer.Option(help='Throw away data with at least this % of nodata pixels')] = 50, + gdal_bin: Annotated[ + str, + typer.Option('--gdal_bin', help='Path to gdal binaries (ignored if --skip_gdal is passed)', envvar='GDAL_BIN'), + ] = '/usr/bin', + gdal_path: Annotated[ + str, + typer.Option('--gdal_path', help='Path to gdal scripts (ignored if --skip_gdal is passed)', envvar='GDAL_PATH'), + ] = '/usr/bin', + n_jobs: Annotated[int, typer.Option('--n_jobs', help='number of parallel joblib jobs')] = -1, + nodata_threshold: Annotated[ + float, typer.Option('--nodata_threshold', help='Throw away data with at least this % of nodata pixels') + ] = 50, tile_size: Annotated[ - str, typer.Option(help="Tiling size in pixels e.g. '256x256'", callback=tile_size_callback) + str, typer.Option('--tile_size', help="Tiling size in pixels e.g. '256x256'", callback=tile_size_callback) ] = '256x256', - tile_overlap: Annotated[int, typer.Option(help='Overlap of the tiles in pixels')] = 25, + tile_overlap: Annotated[int, typer.Option('--tile_overlap', help='Overlap of the tiles in pixels')] = 25, ): """Make data ready for training""" @@ -334,7 +342,8 @@ def prepare_data( ) -def main(): +# ! Moving legacy argparse cli to main to maintain compatibility with the original script +if __name__ == '__main__': parser = argparse.ArgumentParser( description='Make data ready for training', formatter_class=argparse.ArgumentDefaultsHelpFormatter ) @@ -354,60 +363,14 @@ def main(): args = parser.parse_args() - # Tiling Settings - xsize, ysize = map(int, args.tile_size.split('x')) - overlap = args.tile_overlap - - timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S') - log_path = Path(args.log_dir) / f'prepare_data-{timestamp}.log' - if not Path(args.log_dir).exists(): - os.mkdir(Path(args.log_dir)) - init_logging(log_path) - logger = get_logger('prepare_data') - logger.info('#############################') - logger.info('# Starting Data Preparation #') - logger.info('#############################') - - threshold = args.nodata_threshold / 100 - - if not args.skip_gdal: - gdal.initialize(args) - - DATA_ROOT = Path(args.data_dir) - DATA_DIR = DATA_ROOT / 'tiles' - h5dir = DATA_ROOT / 'h5' - h5dir.mkdir(exist_ok=True) - - # All folders that contain the big raster (...AnalyticsML_SR.tif) are assumed to be a dataset - datasets = [raster.parent for raster in DATA_DIR.glob('*/' + RASTERFILTER)] - - overwrite_conflicts = [] - for dataset in datasets: - check_dir = h5dir / dataset.name - if check_dir.exists(): - overwrite_conflicts.append(check_dir) - - if overwrite_conflicts: - logger.warning(f"Found old data directories: {', '.join(dir.name for dir in overwrite_conflicts)}.") - decision = input('Delete and recreate them [d], skip them [s] or abort [a]? 
').lower() - if decision == 'd': - logger.info('User chose to delete old directories.') - for old_dir in overwrite_conflicts: - shutil.rmtree(old_dir) - elif decision == 's': - logger.info('User chose to skip old directories.') - already_done = [d.name for d in overwrite_conflicts] - datasets = [d for d in datasets if d.name not in already_done] - else: - # When in doubt, don't overwrite/change anything to be safe - logger.error('Aborting due to conflicts with existing data directories.') - sys.exit(1) - - Parallel(n_jobs=args.n_jobs)( - delayed(main_function)(dataset, log_path, h5dir, xsize, ysize, overlap, threshold, args.skip_gdal) - for dataset in datasets + prepare_data( + data_dir=args.data_dir, + log_dir=args.log_dir, + skip_gdal=args.skip_gdal, + gdal_bin=args.gdal_bin, + gdal_path=args.gdal_path, + n_jobs=args.n_jobs, + nodata_threshold=args.nodata_threshold, + tile_size=args.tile_size, + tile_overlap=args.tile_overlap, ) - - -if __name__ == '__main__': - main() diff --git a/src/thaw_slump_segmentation/scripts/prepare_s2_4band_planet_format.py b/src/thaw_slump_segmentation/scripts/prepare_s2_4band_planet_format.py index 2aa6d07..32a8a56 100644 --- a/src/thaw_slump_segmentation/scripts/prepare_s2_4band_planet_format.py +++ b/src/thaw_slump_segmentation/scripts/prepare_s2_4band_planet_format.py @@ -180,17 +180,19 @@ def download_tcvis(image_path): def prepare_s2_4band_planet_format( - data_dir: Annotated[Path, typer.Argument(help='data directory (parent of download dir)')], - image_regex: Annotated[str, typer.Option(help='regex term to find image file')] = '*/*SR.tif', - n_jobs: Annotated[int, typer.Option(help='Number of parallel- images to prepare data for')] = 6, + data_dir: Annotated[Path, typer.Option('--data_dir', help='data directory (parent of download dir)')], + image_regex: Annotated[str, typer.Option('--image_regex', help='regex term to find image file')] = '*/*SR.tif', + n_jobs: Annotated[int, typer.Option('--n_jobs', help='Number of parallel- images to prepare data for')] = 6, + aux_dir: Annotated[ + Path, typer.Option('--aux_dir', help='parent directory of auxilliary data') + ] = '/isipd/projects/p_aicore_pf/initze/processing/auxiliary/', ): input_dir = Path(data_dir) infiles = list(input_dir.glob(image_regex)) - base_dir = Path('/isipd/projects/p_aicore_pf/initze/processing/auxiliary/') # make absolute paths - elevation = base_dir / 'elevation.vrt' - slope = base_dir / 'slope.vrt' + elevation = aux_dir / 'elevation.vrt' + slope = aux_dir / 'slope.vrt' # for image_path in infiles: Parallel(n_jobs=n_jobs)(delayed(process_local_data)(image_path, elevation, slope) for image_path in infiles) @@ -199,7 +201,7 @@ def prepare_s2_4band_planet_format( download_tcvis(image_path) -def main(): +if __name__ == '__main__': parser = argparse.ArgumentParser( description='Prepare aux data for downloaded S2 images.', formatter_class=argparse.ArgumentDefaultsHelpFormatter ) @@ -214,20 +216,6 @@ def main(): ) args = parser.parse_args() - input_dir = Path(args.data_dir) - infiles = list(input_dir.glob(args.image_regex)) - - base_dir = Path('/isipd/projects/p_aicore_pf/initze/processing/auxiliary/') - # make absolute paths - elevation = base_dir / 'elevation.vrt' - slope = base_dir / 'slope.vrt' - - # for image_path in infiles: - Parallel(n_jobs=args.n_jobs)(delayed(process_local_data)(image_path, elevation, slope) for image_path in infiles) - - for image_path in infiles: - download_tcvis(image_path) - - -if __name__ == '__main__': - main() + prepare_s2_4band_planet_format( + 
data_dir=args.data_dir, image_regex=args.image_regex, n_jobs=args.n_jobs, aux_dir=args.aux_dir + ) diff --git a/src/thaw_slump_segmentation/scripts/setup_raw_data.py b/src/thaw_slump_segmentation/scripts/setup_raw_data.py index 6d032d9..8511093 100644 --- a/src/thaw_slump_segmentation/scripts/setup_raw_data.py +++ b/src/thaw_slump_segmentation/scripts/setup_raw_data.py @@ -44,7 +44,6 @@ def preprocess_directory(image_dir, data_dir, aux_dir, backup_dir, log_path, gdal_bin, gdal_path, label_required=True): - # TODO: let gdal.initialize take each argument separately # Mock old args object gdal.initialize(bin=gdal_bin, path=gdal_path) @@ -103,12 +102,16 @@ def preprocess_directory(image_dir, data_dir, aux_dir, backup_dir, log_path, gda def setup_raw_data( - gdal_bin: Annotated[str, typer.Option(help='Path to gdal binaries')] = '', - gdal_path: Annotated[str, typer.Option(help='Path to gdal scripts')] = '', - n_jobs: Annotated[int, typer.Option(help='number of parallel joblib jobs')] = -1, - label: Annotated[bool, typer.Option(help='Set flag to do preprocessing with label file')] = False, - data_dir: Annotated[Path, typer.Option(help='Path to data processing dir')] = Path('data'), - log_dir: Annotated[Path, typer.Option(help='Path to log dir')] = Path('logs'), + gdal_bin: Annotated[str, typer.Option('--gdal_bin', help='Path to gdal binaries', envvar='GDAL_BIN')] = '/usr/bin', + gdal_path: Annotated[ + str, typer.Option('--gdal_path', help='Path to gdal scripts', envvar='GDAL_PATH') + ] = '/usr/bin', + n_jobs: Annotated[int, typer.Option('--n_jobs', help='number of parallel joblib jobs')] = -1, + label: Annotated[ + bool, typer.Option('--label/--nolabel', help='Set flag to do preprocessing with label file') + ] = False, + data_dir: Annotated[Path, typer.Option('--data_dir', help='Path to data processing dir')] = Path('data'), + log_dir: Annotated[Path, typer.Option('--log_dir', help='Path to log dir')] = Path('logs'), ): INPUT_DATA_DIR = data_dir / 'input' BACKUP_DIR = data_dir / 'backup' @@ -126,10 +129,11 @@ def setup_raw_data( logger.info('###########################') dir_list = check_input_data(INPUT_DATA_DIR) + print(dir_list) if len(dir_list) > 0: Parallel(n_jobs=n_jobs)( delayed(preprocess_directory)( - image_dir, DATA_DIR, AUX_DIR, BACKUP_DIR, gdal_bin, gdal_path, log_path, not label + image_dir, DATA_DIR, AUX_DIR, BACKUP_DIR, log_path, gdal_bin, gdal_path, label ) for image_dir in dir_list ) @@ -138,7 +142,7 @@ def setup_raw_data( # ! 
Moving legacy argparse cli to main to maintain compatibility with the original script -def main(): +if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument('--gdal_bin', default=None, help='Path to gdal binaries (ignored if --skip_gdal is passed)') parser.add_argument('--gdal_path', default=None, help='Path to gdal scripts (ignored if --skip_gdal is passed)') @@ -149,35 +153,11 @@ def main(): args = parser.parse_args() - global DATA_ROOT, INPUT_DATA_DIR, BACKUP_DIR, DATA_DIR, AUX_DIR - - DATA_ROOT = Path(args.data_dir) - INPUT_DATA_DIR = DATA_ROOT / 'input' - BACKUP_DIR = DATA_ROOT / 'backup' - DATA_DIR = DATA_ROOT / 'tiles' - AUX_DIR = DATA_ROOT / 'auxiliary' - - timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S') - log_path = Path(args.log_dir) / f'setup_raw_data-{timestamp}.log' - if not Path(args.log_dir).exists(): - os.mkdir(Path(args.log_dir)) - init_logging(log_path) - logger = get_logger('setup_raw_data') - logger.info('###########################') - logger.info('# Starting Raw Data Setup #') - logger.info('###########################') - - dir_list = check_input_data(INPUT_DATA_DIR) - if len(dir_list) > 0: - Parallel(n_jobs=args.n_jobs)( - delayed(preprocess_directory)( - image_dir, DATA_DIR, AUX_DIR, BACKUP_DIR, args.gdal_bin, args.gdal_path, log_path, args.nolabel - ) - for image_dir in dir_list - ) - else: - logger.error('Empty Input Data Directory! No Data available to process!') - - -if __name__ == '__main__': - main() + setup_raw_data( + gdal_bin=args.gdal_bin, + gdal_path=args.gdal_path, + n_jobs=args.n_jobs, + label=args.nolabel, + data_dir=args.data_dir, + log_dir=args.log_dir, + ) diff --git a/src/thaw_slump_segmentation/scripts/train.py b/src/thaw_slump_segmentation/scripts/train.py index bd4530d..c85bbf9 100644 --- a/src/thaw_slump_segmentation/scripts/train.py +++ b/src/thaw_slump_segmentation/scripts/train.py @@ -357,8 +357,8 @@ def train( help='Give this run a name, so that it will be logged into logs/_.', ), ], - data_dir: Annotated[Path, typer.Option(help='Path to data processing dir')] = Path('data'), - log_dir: Annotated[Path, typer.Option(help='Path to log dir')] = Path('logs'), + data_dir: Annotated[Path, typer.Option('--data_dir', help='Path to data processing dir')] = Path('data'), + log_dir: Annotated[Path, typer.Option('--log_dir', help='Path to log dir')] = Path('logs'), config: Annotated[Path, typer.Option('--config', '-c', help='Specify run config to use.')] = Path('config.yml'), resume: Annotated[ str, @@ -382,7 +382,7 @@ def train( # ! 
Moving legacy argparse cli to main to maintain compatibility with the original script -def main(): +if __name__ == '__main__': parser = argparse.ArgumentParser( description='Training script', formatter_class=argparse.ArgumentDefaultsHelpFormatter ) @@ -418,7 +418,3 @@ def main(): args.wandb_project, args.wandb_name, ).run() - - -if __name__ == '__main__': - main() From cd168575558fe046b545ec5111ff3844a327b7e3 Mon Sep 17 00:00:00 2001 From: Ingmar Nitze Date: Mon, 27 May 2024 15:13:26 +0200 Subject: [PATCH 7/9] fixed bugs --- src/thaw_slump_segmentation/postprocessing.py | 18 +++++++++++------- .../scripts/process_02_inference.py | 15 ++++++++------- 2 files changed, 19 insertions(+), 14 deletions(-) diff --git a/src/thaw_slump_segmentation/postprocessing.py b/src/thaw_slump_segmentation/postprocessing.py index 32e02b6..ef9e437 100644 --- a/src/thaw_slump_segmentation/postprocessing.py +++ b/src/thaw_slump_segmentation/postprocessing.py @@ -65,16 +65,19 @@ def get_date_from_PSfilename(name): return date -# TODO: create empty dataframe if no files found def get_datasets(path, depth=0, preprocessed=False): dirs = listdirs2(path, depth=depth) df = pd.DataFrame(data=dirs, columns=['path']) - df['name'] = df.apply(lambda x: x['path'].name, axis=1) - df['preprocessed'] = preprocessed - df['PS_product_type'] = df.apply(lambda x: get_PS_products_type(x['name']), axis=1) - df['image_date'] = df.apply(lambda x: get_date_from_PSfilename(x['name']), axis=1) - df['tile_id'] = df.apply(lambda x: x['name'].split('_')[1], axis=1) - return df + if len(df) > 0: + df['name'] = df.apply(lambda x: x['path'].name, axis=1) + df['preprocessed'] = preprocessed + df['PS_product_type'] = df.apply(lambda x: get_PS_products_type(x['name']), axis=1) + df['image_date'] = df.apply(lambda x: get_date_from_PSfilename(x['name']), axis=1) + df['tile_id'] = df.apply(lambda x: x['name'].split('_')[1], axis=1) + return df + else: + return pd.DataFrame(columns=['path', 'name', 'preprocessed', 'PS_product_type', 'image_date', 'tile_id']) + def copy_unprocessed_files(row, processing_dir, quiet=True): """ @@ -157,6 +160,7 @@ def get_processing_status(raw_data_dir, processing_dir, inference_dir, model, re # get processed # TODO: make validation steps if files are alright df_processed = get_datasets(processing_dir / 'tiles', depth=0, preprocessed=True) + # check if all files are available df_processed = df_processed[df_processed.apply(lambda x: len(list(x['path'].glob('*')))>=5, axis=1)] diff --git a/src/thaw_slump_segmentation/scripts/process_02_inference.py b/src/thaw_slump_segmentation/scripts/process_02_inference.py index ad0a614..f916212 100644 --- a/src/thaw_slump_segmentation/scripts/process_02_inference.py +++ b/src/thaw_slump_segmentation/scripts/process_02_inference.py @@ -12,6 +12,7 @@ from joblib import Parallel, delayed from tqdm import tqdm from typing_extensions import Annotated +import swifter from ..postprocessing import ( copy_unprocessed_files, @@ -63,10 +64,10 @@ def process_02_inference( # TODO: move to function # print basic information - total_images = len(df_final) - preprocessed_images = df_final.preprocessed.sum() - preprocessing_images = total_images - preprocessed_images - finished_images = df_final.inference_finished.sum() + total_images = int(len(df_final)) + preprocessed_images = int(df_final.preprocessed.sum()) + preprocessing_images = int(total_images - preprocessed_images) + finished_images = int(df_final.inference_finished.sum()) print(f'Number of images: {total_images}') print(f'Number of 
preprocessed images: {preprocessed_images}') print(f'Number of images for preprocessing: {preprocessing_images}') @@ -92,7 +93,7 @@ def process_02_inference( # #### Copy data for Preprocessing # make better documentation - df_preprocess = df_final[~(df_final.preprocessed)] + df_preprocess = df_final[df_final.preprocessed == False] print(f'Number of images to preprocess: {len(df_preprocess)}') # Cleanup processing directories to avoid incomplete processing @@ -105,7 +106,6 @@ def process_02_inference( else: print('Processing directory is ready, nothing to do!') - # TODO: check for empty processing status # Copy Data _ = df_preprocess.swifter.apply(lambda x: copy_unprocessed_files(x, processing_dir), axis=1) @@ -126,7 +126,7 @@ def process_02_inference( [get_processing_status(rdd, processing_dir, inference_dir, model) for rdd in raw_data_dir] ).drop_duplicates() # Filter to images that are not preprocessed yet - df_process = df_final[~df_final.inference_finished] + df_process = df_final[df_final.inference_finished == False] # update overview and filter accordingly - really necessary? df_process_final = ( df_process.set_index('name') @@ -201,6 +201,7 @@ def process_02_inference( def main(): # ### Settings + # Add argument definitions parser = argparse.ArgumentParser( description='Script to run auto inference for RTS', formatter_class=argparse.ArgumentDefaultsHelpFormatter From e5b0b4189bd34bdf581a51be7b2fda2ef787b1ee Mon Sep 17 00:00:00 2001 From: Ingmar Nitze Date: Mon, 27 May 2024 17:25:43 +0200 Subject: [PATCH 8/9] fixed issues with udm, now using only udm2 --- .../data_pre_processing/udm.py | 68 +++++++++++++------ .../data_pre_processing/utils.py | 5 +- 2 files changed, 50 insertions(+), 23 deletions(-) diff --git a/src/thaw_slump_segmentation/data_pre_processing/udm.py b/src/thaw_slump_segmentation/data_pre_processing/udm.py index 9407766..35fea18 100644 --- a/src/thaw_slump_segmentation/data_pre_processing/udm.py +++ b/src/thaw_slump_segmentation/data_pre_processing/udm.py @@ -10,15 +10,50 @@ def get_mask_from_udm2(infile, bands=[1,2,4,5], nodata=1): + """ + return masked array, where True = NoData + """ with rio.open(infile) as src: a = src.read() - mask = a[bands].max(axis=0) < nodata + mask = a[bands].max(axis=0) == nodata return mask -def get_mask_from_udm2_v2(infile, band=1): + +def get_mask_from_udm2_v2(infile): + """ + Create a data mask from a UDM2 V2 file. + + The data mask is a boolean array where True values represent good data, + and False values represent no data or unusable data. + + Args: + infile (str): Path to the input UDM2 V2 file. + + Returns: + numpy.ndarray: A boolean array representing the data mask. + + Notes: + - The function assumes that the input file is a UDM2 V2 file with + specific band meanings: + - Band 0: Clear data + - Bands 1, 2, 4, 5: Unusable data (if any of these bands has a value of 1) + - Band 7: No data + + - The data mask is created by combining the following conditions: + - Clear data (band 0 == 1) + - Not no data (band 7 != 1) + - Not unusable data (maximum of bands 1, 2, 4, 5 != 1) + + - The function uses the rasterio library to read the input file. 
+ """ with rio.open(infile) as src: - a = src.read(1) - return a + a = src.read() + unusable_data = a[[1,2,4,5]].max(axis=0) == 1 + clear_data = a[[0]] == 1 + no_data = a[[7]] == 1 + # final data mask: 0 = no or crappy data, 1 = good data + data_mask = np.logical_and(clear_data, ~np.logical_or(no_data, unusable_data)) + return data_mask def get_mask_from_udm(infile, nodata=1): @@ -27,30 +62,19 @@ def get_mask_from_udm(infile, nodata=1): return np.array(mask, dtype=np.uint8) +# TODO: inconsistent application what the mask value is def burn_mask(file_src, file_dst, file_udm, file_udm2=None, mask_value=0): - with rio.Env(): - # checks - - if file_udm: - mask_udm = get_mask_from_udm(file_udm) - + """" + """ + with rio.Env(): if file_udm2: - mask_udm2 = get_mask_from_udm2(file_udm2) - #mask_udm2 = get_mask_from_udm2_v2(file_udm2) - - # merge masks if both - if ((file_udm is not None) & (file_udm2 is not None)): - clear_mask = np.array([mask_udm, mask_udm2]).max(axis=0) == mask_value - # only use one of udm or udm if only one exists - elif file_udm is not None: - clear_mask = mask_udm - elif file_udm2 is not None: - clear_mask = mask_udm2 + mask_udm2 = get_mask_from_udm2_v2(file_udm2) else: raise ValueError with rio.open(file_src) as ds_src: - data = ds_src.read()*clear_mask + # apply data mask (multiply) + data = ds_src.read() * mask_udm2 profile = ds_src.profile with rio.open(file_dst, 'w', **profile) as ds_dst: ds_dst.write(data) diff --git a/src/thaw_slump_segmentation/data_pre_processing/utils.py b/src/thaw_slump_segmentation/data_pre_processing/utils.py index 2ff98af..c2da40d 100755 --- a/src/thaw_slump_segmentation/data_pre_processing/utils.py +++ b/src/thaw_slump_segmentation/data_pre_processing/utils.py @@ -85,6 +85,7 @@ def get_mask_images(image_directory, udm='udm.tif', udm2='udm2.tif', images=['_S image_files = [] for im in images: image_files.extend([f for f in flist if im in f]) + # check which udms are available, if not then set to None try: udm_file = [f for f in flist if udm in f][0] @@ -94,6 +95,7 @@ def get_mask_images(image_directory, udm='udm.tif', udm2='udm2.tif', images=['_S udm2_file = [f for f in flist if udm2 in f][0] except: udm2_file = None + # raise error if no udms available if (udm_file == None) & (udm2_file == None): raise ValueError(f'There are no udm or udm2 files for image {image_directory.name}!') @@ -161,13 +163,14 @@ def aux_data_to_tiles(image_directory, aux_data, outfile): # load template and get props images = get_mask_images(image_directory, udm='udm.tif', udm2='udm2.tif', images=['_SR.tif']) image = images['images'][0] + # prepare gdalwarp call xmin, xmax, ymin, ymax = geom_from_image_bounds(image) crs = crs_from_image(image) xres, yres = resolution_from_image(image) + # run gdalwarp call outfile = f'{image_directory}/{outfile}'#os.path.join(image_directory,outfile) s_run = f'{gdal.warp} -te {xmin} {ymin} {xmax} {ymax} -tr {xres} {yres} -r cubic -t_srs {crs} -co COMPRESS=DEFLATE {aux_data} {outfile}' - #s_run = f'{gdal.warp} -te {xmin} {ymin} {xmax} {ymax} -tr 3 3 -r cubic -t_srs {crs} -co COMPRESS=DEFLATE {aux_data} {outfile}' log_run(s_run, _logger) return 1 From f97c9633d64f8715fa200aa861d8c167192a841b Mon Sep 17 00:00:00 2001 From: Ingmar Nitze Date: Mon, 27 May 2024 17:33:19 +0200 Subject: [PATCH 9/9] added docstring --- .../data_pre_processing/udm.py | 27 ++++++++++++++++++- 1 file changed, 26 insertions(+), 1 deletion(-) diff --git a/src/thaw_slump_segmentation/data_pre_processing/udm.py b/src/thaw_slump_segmentation/data_pre_processing/udm.py 
index 35fea18..deab524 100644 --- a/src/thaw_slump_segmentation/data_pre_processing/udm.py +++ b/src/thaw_slump_segmentation/data_pre_processing/udm.py @@ -64,7 +64,32 @@ def get_mask_from_udm(infile, nodata=1): # TODO: inconsistent application what the mask value is def burn_mask(file_src, file_dst, file_udm, file_udm2=None, mask_value=0): - """" + """ + Apply a data mask to a raster file and save the result to a new file. + + Args: + file_src (str): Path to the input raster file. + file_dst (str): Path to the output raster file. + file_udm (str): Path to the UDM file (not used in this function). + file_udm2 (str, optional): Path to the UDM2 V2 file. If provided, the data mask + will be derived from this file using the `get_mask_from_udm2_v2` function. + mask_value (int, optional): Value to use for masked (invalid) pixels in the output file. + Default is 0. + + Returns: + int: Always returns 1 (for successful execution). + + Raises: + ValueError: If `file_udm2` is not provided. + + Notes: + - The function reads the input raster file using rasterio and applies the data mask + by multiplying the raster data with the mask. + - The masked raster data is then written to the output file with the same metadata + as the input file. + - If `file_udm2` is not provided, a `ValueError` is raised. + - The `file_udm` parameter is not used in this function. + - The function uses the rasterio library for reading and writing raster files. """ with rio.Env(): if file_udm2:
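
A brief usage sketch of the UDM2-based masking introduced in the last two patches. This is not part of the patch series itself: the scene directory and file names below are hypothetical, and the snippet only assumes that `get_mask_from_udm2_v2` and `burn_mask` are importable from `thaw_slump_segmentation.data_pre_processing.udm` with the signatures shown in the diff above.

```python
# Hypothetical usage sketch of the UDM2-based masking added in this patch series.
# Paths and the scene name are illustrative placeholders only.
from pathlib import Path

from thaw_slump_segmentation.data_pre_processing.udm import (
    burn_mask,
    get_mask_from_udm2_v2,
)

scene_dir = Path('data/input/20230712_123456_1234_3B')  # hypothetical scene directory
sr_image = scene_dir / '20230712_123456_1234_3B_AnalyticMS_SR.tif'
udm2_file = scene_dir / '20230712_123456_1234_3B_udm2.tif'

# Boolean data mask derived from the UDM2 bands: clear (band 0), not unusable
# (bands 1, 2, 4, 5) and not nodata (band 7); True marks usable pixels.
mask = get_mask_from_udm2_v2(udm2_file)
print('usable pixel fraction:', mask.mean())

# Write a masked copy of the SR image; masked pixels are multiplied to 0.
burn_mask(
    file_src=sr_image,
    file_dst=scene_dir / 'masked_SR.tif',
    file_udm=None,        # ignored by the new UDM2-only implementation
    file_udm2=udm2_file,  # required, otherwise a ValueError is raised
)
```

Under these assumptions the sketch mirrors the behaviour described in the new docstrings: the mask is applied by multiplication, so the output raster keeps the input profile while flagged pixels become 0.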