Merge pull request #141 from initze/fix/list_datasets
Fix/list datasets
initze authored Jun 13, 2024
2 parents 4b060f0 + b7a6c27 commit 7115fc1
Showing 6 changed files with 85 additions and 18 deletions.
3 changes: 1 addition & 2 deletions README.md
@@ -9,7 +9,7 @@
We recommend using a new conda environment from scratch

```bash
-conda create -n thaw_slump_segmentation python=3.10 mamba -c conda-forge
+conda create -n thaw_slump_segmentation python=3.11 mamba -c conda-forge
conda activate thaw_slump_segmentation
```

@@ -24,7 +24,6 @@ mamba install gdal>=3.6 -c conda-forge
### Package Installation

* Latest development version: `pip install git+https://github.com/initze/thaw-slump-segmentation`
-* Latest release: `pip install https://github.com/initze/thaw-slump-segmentation/releases/download/untagged-f6739f56e0ee4c2c64fe/thaw_slump_segmentation-0.10.0-py3-none-any.whl`

This will pull the CUDA 12 version of pytorch. If you are running CUDA 11, you need to manually switch to the corresponding Pytorch package afterwards by running `pip3 install torch==2.2.0+cu118 torchvision==0.17.0+cu118 --index-url https://download.pytorch.org/whl/cu118`

8 changes: 4 additions & 4 deletions pyproject.toml
@@ -22,7 +22,7 @@ dependencies = [
"rasterio>=1.3.10",
"rioxarray>=0.15.5",
"pyproj>=3.6.1",
"earthengine-api>=0.1.381",
"earthengine-api==0.1.381",
"geedim>=1.7.2",
"geemap==0.29.6",
"eemont==0.3.6",
@@ -50,15 +50,15 @@ dependencies = [
"torchmetrics @ git+https://github.com/Lightning-AI/torchmetrics",
]
readme = "README.md"
requires-python = ">= 3.10"
requires-python = "== 3.11"

[project.scripts]
download_s2_4band_planet_format = "thaw_slump_segmentation.scripts.download_s2_4band_planet_format:main"
inference = "thaw_slump_segmentation.scripts.inference:main"
prepare_data = "thaw_slump_segmentation.scripts.prepare_data:main"
prepare_s2_4band_planet_format = "thaw_slump_segmentation.scripts.prepare_s2_4band_planet_format:main"
-# process_02_inference = "thaw_slump_segmentation.scripts.process_02_inference:main"
-# process_03_ensemble = "thaw_slump_segmentation.scripts.process_03_ensemble:main"
+process_02_inference = "thaw_slump_segmentation.scripts.process_02_inference:main"
+process_03_ensemble = "thaw_slump_segmentation.scripts.process_03_ensemble:main"
setup_raw_data = "thaw_slump_segmentation.scripts.setup_raw_data:main"
train = "thaw_slump_segmentation.scripts.train:main"
thaw-slump-segmentation = "thaw_slump_segmentation.main:cli"
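
The two re-enabled entries follow the same convention as the rest of this table: each `command = "package.module:function"` line under `[project.scripts]` tells pip to generate a console command that imports the module and calls the named function. A minimal sketch of the pattern (the module name and flag below are hypothetical, for illustration only; the `inference.py` and `setup_raw_data.py` diffs further down refactor the real scripts into exactly this shape):

```python
# Hypothetical module thaw_slump_segmentation/scripts/some_step.py, exposed in
# pyproject.toml as: some_step = "thaw_slump_segmentation.scripts.some_step:main"
import argparse


def main():
    # The generated console command imports this module and calls main().
    parser = argparse.ArgumentParser(description='Example entry point')
    parser.add_argument('--data_dir', default='data', help='Hypothetical input directory flag')
    args = parser.parse_args()
    print(f'Running with data_dir={args.data_dir}')


if __name__ == '__main__':
    # Keeps direct `python some_step.py ...` invocation working alongside the installed command.
    main()
```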
72 changes: 66 additions & 6 deletions src/thaw_slump_segmentation/postprocessing.py
@@ -152,13 +152,42 @@ def update_DEM2(dem_data_dir, vrt_target_dir):


def get_processing_status(raw_data_dir, processing_dir, inference_dir, model, reduce_to_raw=False):
-    # get raw tiles
-    try:
+    """
+    Get the processing status of raw data, intermediate data, and inference results.
+    Args:
+        raw_data_dir (Path): Path to the directory containing raw input data (tiles or scenes).
+        processing_dir (Path): Path to the directory containing processed intermediate data.
+        inference_dir (Path): Path to the directory containing inference results.
+        model (str): Name of the model used for inference.
+        reduce_to_raw (bool, optional): If True, return only the raw data that hasn't been processed yet.
+            Default is False.
+    Returns:
+        pandas.DataFrame: A DataFrame containing the processing status of each dataset, with columns:
+            - 'name': Name of the dataset.
+            - 'path': Path to the dataset.
+            - 'inference_finished': Boolean indicating if inference has been completed for the dataset.
+    Raises:
+        ValueError: If the provided raw_data_dir is not 'tiles' or 'scenes'.
+    Notes:
+        - The function assumes that the processed intermediate data is located in `processing_dir/tiles`.
+        - The function checks if at least 5 files are available for each processed dataset.
+        - The function assumes that the inference results are located in `inference_dir/model/*`.
+        - There are two TODO comments in the code that need to be addressed.
+    """
+    # get processing status for raw input data
+    if raw_data_dir.name == 'tiles':
        df_raw = get_datasets(raw_data_dir, depth=1)
-    except:
+    elif raw_data_dir.name == 'scenes':
        df_raw = get_datasets(raw_data_dir, depth=0)
+    else:
+        raise ValueError('Please point to tiles or scenes path!')
-    # get processed
    # TODO: make validation steps if files are alright

+    # get processing status for intermediate data
    df_processed = get_datasets(processing_dir / 'tiles', depth=0, preprocessed=True)

    # check if all files are available
@@ -173,7 +202,7 @@ def get_processing_status(raw_data_dir, processing_dir, inference_dir, model, re
    else:
        df_merged = pd.concat([df_processed, diff]).reset_index()

-
+    # make a dataframe with checks for processing status
    products_list = [prod.name for prod in list((inference_dir / model).glob('*'))]
    df_merged['inference_finished'] = df_merged.apply(lambda x: x['name'] in (products_list), axis=1)

@@ -595,4 +624,35 @@ def create_ensemble_with_negative(inference_dir: Path,
        return out_binary

    except:
-        return None
+        return None
+
+
+def print_processing_stats(df_final):
+    """
+    Print the processing statistics for a given DataFrame.
+    Args:
+        df_final (pandas.DataFrame): A DataFrame containing the processing status of each dataset,
+            with columns 'preprocessed' and 'inference_finished'.
+    Returns:
+        None
+    Prints:
+        - Number of total images
+        - Number of preprocessed images
+        - Number of images for preprocessing
+        - Number of images for inference
+        - Number of finished images
+    """
+    total_images = int(len(df_final))
+    preprocessed_images = int(df_final.preprocessed.sum())
+    preprocessing_images = int(total_images - preprocessed_images)
+    finished_images = int(df_final.inference_finished.sum())
+
+    print(f'Number of images: {total_images}')
+    print(f'Number of preprocessed images: {preprocessed_images}')
+    print(f'Number of images for preprocessing: {preprocessing_images}')
+    print(f'Number of images for inference: {preprocessed_images - finished_images}')
+    print(f'Number of finished images: {finished_images}')
+    return total_images, preprocessed_images, preprocessing_images, finished_images
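
The new docstring and helper suggest a workflow along these lines; the directory layout and model name below are assumptions for illustration, not paths taken from the repository, and the sketch assumes the returned frame carries the 'preprocessed' and 'inference_finished' flags that `print_processing_stats` expects:

```python
from pathlib import Path

from thaw_slump_segmentation.postprocessing import get_processing_status, print_processing_stats

# Hypothetical layout: raw data must sit in a directory named 'tiles' or 'scenes',
# otherwise get_processing_status raises ValueError.
raw_data_dir = Path('data/raw/tiles')
processing_dir = Path('data/processing')
inference_dir = Path('data/inference')

df_status = get_processing_status(raw_data_dir, processing_dir, inference_dir, model='some_model')

# Despite its "Returns: None" docstring, the helper also returns the four counts
# it prints, so callers can keep them for reporting.
totals = print_processing_stats(df_status)
```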
7 changes: 5 additions & 2 deletions src/thaw_slump_segmentation/scripts/inference.py
@@ -323,8 +323,7 @@ def inference(
    )


-# ! Moving legacy argparse cli to main to maintain compatibility with the original script
-if __name__ == '__main__':
+def main():
    parser = argparse.ArgumentParser(
        description='Inference Script', formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
@@ -359,3 +358,7 @@ def inference(
        margin_size=args.margin_size,
        patch_size=args.patch_size,
    )
+
+# ! Moving legacy argparse cli to main to maintain compatibility with the original script
+if __name__ == '__main__':
+    main()
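
With the argparse wiring wrapped in `main()`, the installed `inference` command from `[project.scripts]` and direct script execution now share one code path. Since `main()` reads `sys.argv`, it can also be driven programmatically; a minimal smoke-test sketch (only `--help` is used here, since the script's full flag set is not shown in this diff):

```python
import sys

from thaw_slump_segmentation.scripts.inference import main

# Simulate a command line before handing control to the script's parser.
# argparse prints usage and exits on --help, hence the SystemExit guard.
sys.argv = ['inference', '--help']
try:
    main()
except SystemExit:
    pass
```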
6 changes: 4 additions & 2 deletions src/thaw_slump_segmentation/scripts/process_02_inference.py
@@ -96,13 +96,15 @@ def process_02_inference(
    df_preprocess = df_final[df_final.preprocessed == False]
    print(f'Number of images to preprocess: {len(df_preprocess)}')

    # TODO make better check
+    # Cleanup processing directories to avoid incomplete processing
    input_dir_dslist = list((processing_dir / 'input').glob('*'))
    if len(input_dir_dslist) > 0:
        print(f"Cleaning up {(processing_dir / 'input')}")
        for d in input_dir_dslist:
-            print('Delete', d)
-            shutil.rmtree(d)
+            if len(list(d.glob('*'))) < 4:
+                print('Delete', d)
+                shutil.rmtree(d)
    else:
        print('Processing directory is ready, nothing to do!')

7 changes: 5 additions & 2 deletions src/thaw_slump_segmentation/scripts/setup_raw_data.py
@@ -141,8 +141,7 @@ def setup_raw_data(
        logger.error('Empty Input Data Directory! No Data available to process!')


-# ! Moving legacy argparse cli to main to maintain compatibility with the original script
-if __name__ == '__main__':
+def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--gdal_bin', default=None, help='Path to gdal binaries (ignored if --skip_gdal is passed)')
    parser.add_argument('--gdal_path', default=None, help='Path to gdal scripts (ignored if --skip_gdal is passed)')
@@ -161,3 +160,7 @@ def setup_raw_data(
        data_dir=args.data_dir,
        log_dir=args.log_dir,
    )
+
+# ! Moving legacy argparse cli to main to maintain compatibility with the original script
+if __name__ == '__main__':
+    main()
