Skip to content

Commit

Permalink
transformers_as_arg
Browse files Browse the repository at this point in the history
  • Loading branch information
marjan.asgari committed Sep 18, 2024
1 parent ab02c3e commit c331750
Show file tree
Hide file tree
Showing 5 changed files with 17 additions and 1 deletion.
1 change: 1 addition & 0 deletions geo_inference/config/sample.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ arguments:
vec: False # Vector Conversion: bool
yolo: False # YOLO Conversion: bool
coco: False # COCO Conversion: bool
transformers : True
device: "gpu" # cpu or gpu: str
gpu_id: 0
mgpu: False
Expand Down
6 changes: 5 additions & 1 deletion geo_inference/geo_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ def __init__(
gpu_id: int = 0,
num_classes: int = 5,
prediction_threshold : float = 0.3,
transformers : bool = False,
):
self.work_dir: Path = get_directory(work_dir)
self.device = (
Expand All @@ -95,6 +96,8 @@ def __init__(
),
map_location=self.device,
)
if transformers:
self.model = tta.SegmentationTTAWrapper(self.model, tta.aliases.d4_transform(), merge_mode='mean')
self.mask_to_vec = mask_to_vec
self.mask_to_coco = mask_to_coco
self.mask_to_yolo = mask_to_yolo
Expand Down Expand Up @@ -363,7 +366,8 @@ def main() -> None:
device=arguments["device"],
gpu_id=arguments["gpu_id"],
num_classes=arguments["classes"],
prediction_threshold=arguments["prediction_threshold"]
prediction_threshold=arguments["prediction_threshold"],
transformers=arguments["transformers"],
)
inference_mask_layer_name = geo_inference(
inference_input=arguments["image"],
Expand Down
7 changes: 7 additions & 0 deletions geo_inference/utils/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,6 +433,8 @@ def cmd_interface(argv=None):

parser.add_argument("-pr", "--prediction_thr", type=float, nargs=1, help="Prediction Threshold")

parser.add_argument("-tr", "--transformers", nargs=1, help="Transformers Addition")

args = parser.parse_args()

if args.args:
Expand All @@ -452,6 +454,8 @@ def cmd_interface(argv=None):
classes = config["arguments"]["classes"]
patch_size = config["arguments"]["patch_size"]
prediction_threshold = config["arguments"]["prediction_thr"]
transformers = config["arguments"]["transformers"]

elif args.image:
image =args.image[0]
model = args.model[0] if args.model else None
Expand All @@ -468,6 +472,8 @@ def cmd_interface(argv=None):
classes = args.classes[0] if args.classes else 5
patch_size = args.patch_size[0] if args.patch_size else 1024
prediction_threshold = args.prediction_thr[0] if args.prediction_thr else 0.3
transformers = args.transformers[0] if args.transformers else False

else:
print("use the help [-h] option for correct usage")
raise SystemExit
Expand All @@ -487,6 +493,7 @@ def cmd_interface(argv=None):
"gpu_id": gpu_id,
"patch_size": patch_size,
"prediction_threshold": prediction_threshold,
"transformers": transformers,
}
return arguments

Expand Down
1 change: 1 addition & 0 deletions tests/data/sample.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,5 @@ arguments:
classes : 5
n_workers: 20
prediction_thr : 0.3
transformers: False
patch_size: 1024
3 changes: 3 additions & 0 deletions tests/utils/test_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def test_read_yaml(test_data_dir):
"mgpu": False,
"classes": 5,
"prediction_thr": 0.3,
"transformers": False,
"n_workers": 20
}

Expand Down Expand Up @@ -145,6 +146,7 @@ def test_cmd_interface_with_args(monkeypatch, test_data_dir):
"classes": 5,
"multi_gpu": False,
"prediction_threshold": 0.3,
"transformers": False,
"patch_size": 1024
}

Expand All @@ -169,6 +171,7 @@ def test_cmd_interface_with_image(monkeypatch):
"gpu_id": 0,
"classes": 5,
"prediction_threshold": 0.3,
"transformers": False,
"multi_gpu": False,
}

Expand Down

0 comments on commit c331750

Please sign in to comment.