Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FEATURE] Add train-type parameter to otx train #1874

Merged
merged 2 commits into from
Mar 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,8 @@ However, if you created a workspace with ``otx build``, the training process can
Comma-separated paths to unlabeled data folders
--unlabeled-file-list UNLABELED_FILE_LIST
Comma-separated paths to unlabeled file list
--train-type TRAIN_TYPE
The currently supported options: dict_keys(['INCREMENTAL', 'SEMISUPERVISED', 'SELFSUPERVISED']).
--load-weights LOAD_WEIGHTS
Load model weights from previously saved checkpoint.
--resume-from RESUME_FROM
Expand Down
2 changes: 1 addition & 1 deletion otx/cli/manager/config_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ def _get_train_type(self, ignore_args: bool = False) -> str:
if arg_algo_backend:
train_type = arg_algo_backend.get("train_type", {"value": "INCREMENTAL"}) # type: ignore
return train_type.get("value", "INCREMENTAL")
if self.mode in ("build") and self.args.train_type:
if hasattr(self.args, "train_type") and self.mode in ("build", "train") and self.args.train_type:
self.train_type = self.args.train_type.upper()
if self.train_type not in TASK_TYPE_TO_SUB_DIR_NAME:
raise ValueError(f"{self.train_type} is not currently supported by otx.")
Expand Down
7 changes: 7 additions & 0 deletions otx/cli/tools/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from otx.api.serialization.label_mapper import label_schema_to_bytes
from otx.api.usecases.adapters.model_adapter import ModelAdapter
from otx.cli.manager import ConfigManager
from otx.cli.manager.config_manager import TASK_TYPE_TO_SUB_DIR_NAME
from otx.cli.utils.hpo import run_hpo
from otx.cli.utils.importing import get_impl_class
from otx.cli.utils.io import read_binary, read_label_schema, save_model_data
Expand Down Expand Up @@ -60,6 +61,12 @@ def get_args():
"--unlabeled-file-list",
help="Comma-separated paths to unlabeled file list",
)
parser.add_argument(
"--train-type",
help=f"The currently supported options: {TASK_TYPE_TO_SUB_DIR_NAME.keys()}.",
type=str,
default="incremental",
)
parser.add_argument(
"--load-weights",
help="Load model weights from previously saved checkpoint.",
Expand Down