Skip to content

Commit

Permalink
Merge pull request #109 from initze/108-cant-find-inferecepy
Browse files Browse the repository at this point in the history
108 cant find inferecepy
  • Loading branch information
initze authored Apr 25, 2024
2 parents 19a1913 + 1754afc commit 4f27986
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 29 deletions.
3 changes: 1 addition & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@
* Latest development version: `pip install git+https://github.com/initze/thaw-slump-segmentation`
* Latest release: `pip install https://github.com/initze/thaw-slump-segmentation/releases/download/untagged-f6739f56e0ee4c2c64fe/thaw_slump_segmentation-0.10.0-py3-none-any.whl`

This will pull the CUDA 12 version of pytorch. If you are running CUDA 11, you need to manually switch to the corresponding Pytorch package afterwards by running `pip3 install torch==2.1.2+cu118 torchvision==0.16.2+cu118 --index-url https://download.pytorch.org/whl/cu118`

This will pull the CUDA 12 version of pytorch. If you are running CUDA 11, you need to manually switch to the corresponding Pytorch package afterwards by running `pip3 install torch==2.2.0+cu118 torchvision==0.17.0+cu118 --index-url https://download.pytorch.org/whl/cu118`

gdal incl. gdal-utilities (preferably version >=3.6) need to be installed in your environment, e.g. with conda

Expand Down
2 changes: 1 addition & 1 deletion src/thaw_slump_segmentation/postprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def run_inference(df, model, processing_dir, inference_dir, model_dir=Path('/isi
print('Empty dataframe')
else:
tiles = ' '.join(df.name.values)
run_string = f"CUDA_VISIBLE_DEVICES='{gpu}' python inference.py -n {model} --data_dir {processing_dir} --inference_dir {inference_dir} --patch_size {patch_size} --margin_size {margin_size} {model_dir/model} {tiles}"
run_string = f"CUDA_VISIBLE_DEVICES='{gpu}' inference -n {model} --data_dir {processing_dir} --inference_dir {inference_dir} --patch_size {patch_size} --margin_size {margin_size} {model_dir/model} {tiles}"
print(run_string)
if run:
os.system(run_string)
Expand Down
4 changes: 2 additions & 2 deletions src/thaw_slump_segmentation/scripts/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def flush_rio(filepath):
pass


def do_inference(tilename, args=None, log_path=None):
def do_inference(tilename, sources, model, dev, logger, args=None, log_path=None):
tile_logger = get_logger(f'inference.{tilename}')
# ===== PREPARE THE DATA =====
DATA_ROOT = args.data_dir
Expand Down Expand Up @@ -307,7 +307,7 @@ def main():
torch.set_grad_enabled(False)

for tilename in tqdm(args.tile_to_predict):
do_inference(tilename, args, log_path)
do_inference(tilename, sources, model, dev, logger, args, log_path)

if __name__ == "__main__":
main()
53 changes: 29 additions & 24 deletions src/thaw_slump_segmentation/scripts/process_02_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@
help="Maximum number of images to process (optional)")
parser.add_argument("--skip_vrt", action="store_false",
help="set to skip DEM vrt creation")
parser.add_argument("--skip_vector_save", action="store_true",
help="set to skip output vector creation")

# TODO, make flag to skip vrt
args = parser.parse_args()
Expand Down Expand Up @@ -94,7 +96,7 @@ def main():
N_JOBS=40
print(f'Preprocessing {len(df_preprocess)} images') #fix this
if len(df_preprocess) > 0:
pp_string = f'python setup_raw_data.py --data_dir {args.processing_dir} --n_jobs {N_JOBS} --nolabel'
pp_string = f'setup_raw_data --data_dir {args.processing_dir} --n_jobs {N_JOBS} --nolabel'
os.system(pp_string)

# ## Processing/Inference
Expand Down Expand Up @@ -126,30 +128,33 @@ def main():
_ = Parallel(n_jobs=n_splits)(delayed(run_inference)(df_split[split], model=args.model, processing_dir=args.processing_dir, inference_dir=args.inference_dir, model_dir=args.model_dir, gpu=gpu_split[split], run=True) for split in range(n_splits))
# #### Merge output files

if not args.skip_vector_save:
# read all files which following the above defined threshold
flist = list((args.inference_dir / args.model).glob(f'*/*pred_binarized.shp'))
len(flist)
# TODO:uncomment here
if len(df_process_final) > 0:
# load them in parallel
out = Parallel(n_jobs=6)(delayed(load_and_parse_vector)(f) for f in tqdm(flist[:]))

# merge them and save to geopackage file
merged_gdf = gpd.pd.concat(out)
save_file = args.inference_dir / args.model / f'{args.model}_merged.gpkg'

# check if file already exists, create backup file if exists
if save_file.exists():
# Get the current timestamp
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
# Create the backup file name
save_file_bk = args.inference_dir / args.model / f"{args.model}_merged_bk_{timestamp}.gpkg"
print (f'Creating backup of file {save_file} to {save_file_bk}')
shutil.move(save_file, save_file_bk)

# save to files
print(f'Saving vectors to {save_file}')
merged_gdf.to_file(save_file)
flist = list((args.inference_dir / args.model).glob(f'*/*pred_binarized.shp'))
len(flist)
# TODO:uncomment here
if len(df_process_final) > 0:
# load them in parallel
out = Parallel(n_jobs=6)(delayed(load_and_parse_vector)(f) for f in tqdm(flist[:]))

# merge them and save to geopackage file
merged_gdf = gpd.pd.concat(out)
save_file = args.inference_dir / args.model / f'{args.model}_merged.gpkg'

# check if file already exists, create backup file if exists
if save_file.exists():
# Get the current timestamp
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
# Create the backup file name
save_file_bk = args.inference_dir / args.model / f"{args.model}_merged_bk_{timestamp}.gpkg"
print (f'Creating backup of file {save_file} to {save_file_bk}')
shutil.move(save_file, save_file_bk)

# save to files
print(f'Saving vectors to {save_file}')
merged_gdf.to_file(save_file)
else:
print('Skipping output vector creation!')

if __name__ == "__main__":
main()

0 comments on commit 4f27986

Please sign in to comment.