Skip to content

Commit

Permalink
Merge pull request #1197 from isi-vista/pipeline-fixes
Browse files Browse the repository at this point in the history
Pipeline updates
  • Loading branch information
spigo900 authored Oct 12, 2022
2 parents 47428f2 + 2880013 commit 92a7ff9
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 18 deletions.
44 changes: 27 additions & 17 deletions adam/experiment/run_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import shlex
from shutil import copytree
from subprocess import run, CalledProcessError
from typing import List, Mapping, NewType, Optional, Sequence
from typing import Callable, List, Mapping, NewType, Optional, Sequence

import yaml

Expand Down Expand Up @@ -247,21 +247,26 @@ def is_empty_dir(path: Path) -> bool:
return path.is_dir() and child is None


def ignore_from_base_curriculum(_parent: str, children: Sequence[str]) -> Sequence[str]:
def is_raw_input_file(path_str: str) -> bool:
path = Path(path_str)
return (
path.name.startswith("semantic")
or path.name.startswith("original_colors_color_segmentation_")
or path.name.startswith("color_segmentation_")
or path.name.startswith("color_refined_semantic_")
or path.name.startswith("combined_color_refined_semantic_")
or path.name.startswith("stroke_")
or path.name.startswith("feature")
or path.name.startswith("post_decode")
)
def ignore_from_base_curriculum(
*, copy_segmentations: bool
) -> Callable[[str, Sequence[str]], Sequence[str]]:
def ignore_fn(_parent: str, children: Sequence[str]) -> Sequence[str]:
def is_raw_input_file(path_str: str) -> bool:
path = Path(path_str)
return (
(path.name.startswith("semantic") and not copy_segmentations)
or path.name.startswith("original_colors_color_segmentation_")
or path.name.startswith("color_segmentation_")
or path.name.startswith("color_refined_semantic_")
or path.name.startswith("combined_color_refined_semantic_")
or path.name.startswith("stroke_")
or path.name.startswith("feature")
or path.name.startswith("post_decode")
)

return [child for child in children if is_raw_input_file(child)]

return [child for child in children if is_raw_input_file(child)]
return ignore_fn


def make_run_identifier(run_start: datetime) -> str:
Expand All @@ -273,6 +278,9 @@ def pipeline_entrypoint(params: Parameters) -> None:

pipeline_params = params.namespace("pipeline")
use_sbatch = parse_bool_param(pipeline_params, "use_sbatch")
copy_segmentations_from_base = parse_bool_param(
pipeline_params, "copy_segmentations_from_base"
)
do_object_segmentation = parse_bool_param(pipeline_params, "do_object_segmentation")
segmentation_model = pipeline_params.string("segmentation_model")
segmentation_api_port = params.integer("segmentation_api_port")
Expand All @@ -288,7 +296,7 @@ def pipeline_entrypoint(params: Parameters) -> None:
gnn_decode = parse_bool_param(pipeline_params, "gnn_decode")
email = Email(pipeline_params.string("email")) if "email" in pipeline_params else None
submission_details_path = pipeline_params.creatable_file("submission_details_path")
job_logs_path = pipeline_params.creatable_file("job_logs_path")
job_logs_path = pipeline_params.creatable_directory("job_logs_path")
if train_gnn:
model_path = pipeline_params.creatable_file("stroke_model_path")
elif gnn_decode:
Expand Down Expand Up @@ -362,7 +370,9 @@ def pipeline_entrypoint(params: Parameters) -> None:
copytree(
split_to_base_curriculum_path[split],
path,
ignore=ignore_from_base_curriculum,
ignore=ignore_from_base_curriculum(
copy_segmentations=copy_segmentations_from_base
),
)
elif not path.is_dir():
raise RuntimeError(f"Path {path} for split {split} is not a dir.")
Expand Down
1 change: 1 addition & 0 deletions parameters/experiments/p3/m6_pipeline_debug.params
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ _includes:
pipeline:
use_sbatch: false
do_object_segmentation: true
copy_segmentations_from_base: false
segmentation_model: "stego"
segment_colors: false
refine_colors: false
Expand Down
2 changes: 1 addition & 1 deletion slurm/segment.sh
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
#SBATCH --account=borrowed
#SBATCH --partition=ephemeral
#SBATCH --qos=ephemeral
#SBATCH --time=1:00:00 # Number of hours required per node, max 24 on SAGA
#SBATCH --time=6:00:00 # Number of hours required per node, max 24 on SAGA
#SBATCH --ntasks=1
#SBATCH --cpus-per-task=4
#SBATCH --mem=32g
Expand Down

0 comments on commit 92a7ff9

Please sign in to comment.