Support glob expression for predict.src_las (#15)
* Support glob expression for predict.src_las

* Add condition to save docker image on push only
CharlesGaydon authored May 11, 2022
1 parent d4bd918 commit bab20e7
Showing 4 changed files with 43 additions and 15 deletions.
3 changes: 1 addition & 2 deletions .github/workflows/cicd.yaml
@@ -64,8 +64,6 @@ jobs:
--ignore=actions-runner
--ignore="notebooks"
# Everything ran so we tag the valid docker image to keep it
# This happens for push events, which are in particular
# triggered when a pull request is merged.
@@ -87,6 +85,7 @@ jobs:

# This needs writing rights to the mounted path
- name: Save the docker image as lidar_deep_im_${BRANCH_NAME}.tar
if: github.event_name == 'push'
run: docker save lidar_deep_im:${BRANCH_NAME} -o /var/data/cicd/CICD_github_assets/CICD_docker_images/lidar_deep_im_${BRANCH_NAME}.tar

- name: Clean dangling docker images
21 changes: 15 additions & 6 deletions configs/predict/default.yaml
@@ -1,11 +1,20 @@
src_las: "/path/to/input.las"
output_dir: "/path/to/output_dir/"
resume_from_checkpoint: "/path/to/lightning_model.ckpt"
src_las: "/path/to/input.las" # Any glob pattern can be used to predict on multiple files.
output_dir: "/path/to/output_dir/" # Predictions are saved in a new file which shares src_las basename.
resume_from_checkpoint: "/path/to/lightning_model.ckpt" # Checkpoint of trained model.
gpus: 0 # 0 for none, 1 for one, [gpu_id] to specify which gpu to use e.g [1]

probas_to_save: "all" # override with a list of string matching class names to select specific probas to save
# Specifying the output:
# A list of string matching class names to select specific probas to save
# OR keyword "all" to save all probabilities.
# In addition, these dimensions are always created:
# - `PredictedClassification`: predicted classification based on argmax, with classes
# specified by datamodule.dataset_description class dictionary.
# - `entropy`: Shannon entropy of predicted probabilities
probas_to_save: "all"

# Relative to how probas are interpolated
# e.g. subtile_overlap=25 to use a sliding window of inference of whihc predictions will be merged.
# Probas interpolation parameters
# subtile_overlap=25 to use a sliding window of inference of which predictions will be merged.
# This comes with a computing cost as the effective predicted area is multiplied.
subtile_overlap: 0
# Higher interpolation_k to use more context
interpolation_k: 10
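The new glob behavior of `src_las` can be sketched in plain Python (hypothetical file names; this is an illustration, not the package's code):

```python
import os
import tempfile
from glob import glob

# Create dummy .las files standing in for real point clouds (hypothetical names).
tmp_dir = tempfile.mkdtemp()
for name in ["tile_a.las", "tile_b.las", "notes.txt"]:
    open(os.path.join(tmp_dir, name), "w").close()

# src_las may now be a glob pattern instead of a single path.
src_las = os.path.join(tmp_dir, "*.las")
matches = sorted(os.path.basename(p) for p in glob(src_las))
print(matches)  # ['tile_a.las', 'tile_b.las']
```

Only the `.las` files match; unrelated files in the same directory are ignored.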
6 changes: 4 additions & 2 deletions docs/source/tutorials/make_predictions.md
Expand Up @@ -2,13 +2,13 @@

Refer to [this tutorial](./setup_install.md) for how to setup a virtual environment and install the library.

## Run inference from installed package

To run inference, you will need:
- A source point cloud in LAS format on which to infer classes and probabilities.
- A checkpoint of a trained lightning module implementing model logic (class `lidar_multiclass.models.model.Model`)
- A minimal yaml configuration specifying parameters. We use [hydra](https://hydra.cc/) to manage configurations, and this yaml results from the model training. The `datamodule` and `model` parameter groups must match dataset characteristics and model training settings. The `predict` parameter group specifies paths to the model and data, as well as batch size (N=50 works well; the larger, the faster) and optional use of a GPU.

## Run inference from installed package

Fill out the {missing parameters} below and run:

```bash
@@ -29,6 +29,8 @@ To show your current inference config, simply add a `--help` flag:
python -m lidar_multiclass.predict --config-path {/path/to/.hydra} --config-name {config.yaml} --help
```

Note that `predict.src_las` may be any valid glob pattern (e.g. `/path/to/multiple_files/*.las`), in order to **predict on multiple files successively**.
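A quick way to sanity-check which files a pattern selects, mirroring the fail-fast guard this commit adds (a sketch with a hypothetical helper name, not the package's API):

```python
from glob import glob

def expand_src_las(pattern: str) -> list:
    """Return files matched by the pattern, failing fast when nothing matches (sketch)."""
    matches = glob(pattern)
    if not matches:
        raise FileNotFoundError(f"Pattern {pattern} did not match any file.")
    return matches

# An empty match is rejected up front rather than silently predicting on nothing.
try:
    expand_src_las("/nonexistent_dir_for_demo/*.las")
except FileNotFoundError:
    print("empty pattern rejected")  # prints: empty pattern rejected
```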

## Run inference from sources

From the line for package-based inference above, simply change `python -m lidar_multiclass.predict` to `python run.py` to run directly from sources.
28 changes: 23 additions & 5 deletions lidar_multiclass/predict.py
@@ -1,7 +1,8 @@
from glob import glob
import os
import hydra
import torch
from omegaconf import DictConfig, OmegaConf
from omegaconf import DictConfig
from pytorch_lightning import LightningDataModule, LightningModule
from tqdm import tqdm

@@ -62,9 +63,15 @@ def predict(config: DictConfig) -> str:

@hydra.main(config_path="../configs/", config_name="config.yaml")
def main(config: DictConfig):
"""See function predict
"""This wrapper allows specifying a hydra configuration from the command line.
:meta private:
The config file logged during training should be used for prediction. It should
be edited so that it does not rely on unnecessary environment variables (`oc.env:` prefix).
Parameters in the configuration group `predict` can be specified directly in the config file
or overridden via CLI at runtime.
This wrapper supports running predictions on all files matched
by a glob pattern given via the config parameter `predict.src_las`.
"""
# Imports should be nested inside @hydra.main to optimize tab completion
@@ -78,10 +85,21 @@ def main(config: DictConfig):
if config.get("print_config"):
utils.print_config(config, resolve=False)

return predict(config)
# Parameter predict.src_las can be a path or a glob pattern
# e.g. /path/to/files_*.las
src_las_iterable = glob(config.predict.src_las)

if not src_las_iterable:
raise FileNotFoundError(
f"Globing pattern {config.predict.src_las} (param predict.src_las) did not return any file."
)

# Iterate over the files and predict.
for config.predict.src_las in tqdm(src_las_iterable):
predict(config)


if __name__ == "__main__":
# cf. https://github.com/facebookresearch/hydra/issues/1283
OmegaConf.register_new_resolver("get_method", hydra.utils.get_method, replace=True)
# OmegaConf.register_new_resolver("get_method", hydra.utils.get_method, replace=True)
main()
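The loop in the diff uses `config.predict.src_las` itself as the loop target, so each iteration writes the current path back into the config before `predict(config)` reads it. A minimal stand-in (using `SimpleNamespace` in place of hydra's `DictConfig`, which is an assumption for illustration) shows the idiom:

```python
from types import SimpleNamespace

# Hypothetical stand-in for the hydra config object.
config = SimpleNamespace(predict=SimpleNamespace(src_las="/path/*.las"))

seen = []
# The for-loop target is an attribute reference: it is reassigned each iteration.
for config.predict.src_las in ["a.las", "b.las"]:
    seen.append(config.predict.src_las)
print(seen)  # ['a.las', 'b.las']
```

This keeps the rest of `predict` unchanged: it still reads a single path from the config on every call.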
