Skip to content

Commit

Permalink
Add train-type parameter to otx train
Browse files Browse the repository at this point in the history
  • Loading branch information
sovrasov committed Mar 9, 2023
1 parent 0e4a8b7 commit 6a8a893
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,8 @@ However, if you created a workspace with ``otx build``, the training process can
Comma-separated paths to unlabeled data folders
--unlabeled-file-list UNLABELED_FILE_LIST
Comma-separated paths to unlabeled file list
--train-type TRAIN_TYPE
The currently supported options: dict_keys(['INCREMENTAL', 'SEMISUPERVISED', 'SELFSUPERVISED']).
--load-weights LOAD_WEIGHTS
Load model weights from previously saved checkpoint.
--resume-from RESUME_FROM
Expand Down
2 changes: 1 addition & 1 deletion otx/cli/manager/config_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ def _get_train_type(self, ignore_args: bool = False) -> str:
if arg_algo_backend:
train_type = arg_algo_backend.get("train_type", {"value": "INCREMENTAL"}) # type: ignore
return train_type.get("value", "INCREMENTAL")
if self.mode in ("build") and self.args.train_type:
if self.args.train_type:
self.train_type = self.args.train_type.upper()
if self.train_type not in TASK_TYPE_TO_SUB_DIR_NAME:
raise ValueError(f"{self.train_type} is not currently supported by otx.")
Expand Down
7 changes: 7 additions & 0 deletions otx/cli/tools/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from otx.api.serialization.label_mapper import label_schema_to_bytes
from otx.api.usecases.adapters.model_adapter import ModelAdapter
from otx.cli.manager import ConfigManager
from otx.cli.manager.config_manager import TASK_TYPE_TO_SUB_DIR_NAME
from otx.cli.utils.hpo import run_hpo
from otx.cli.utils.importing import get_impl_class
from otx.cli.utils.io import read_binary, read_label_schema, save_model_data
Expand Down Expand Up @@ -60,6 +61,12 @@ def get_args():
"--unlabeled-file-list",
help="Comma-separated paths to unlabeled file list",
)
parser.add_argument(
"--train-type",
help=f"The currently supported options: {TASK_TYPE_TO_SUB_DIR_NAME.keys()}.",
type=str,
default="incremental",
)
parser.add_argument(
"--load-weights",
help="Load model weights from previously saved checkpoint.",
Expand Down

0 comments on commit 6a8a893

Please sign in to comment.