-
Notifications
You must be signed in to change notification settings - Fork 8
/
cli.py
132 lines (117 loc) · 4.58 KB
/
cli.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
import logging
from pathlib import Path
import click
from paq2piq.common import set_up_seed
from paq2piq.inference_model import InferenceModel, RoIPoolModel
from paq2piq.trainer import Trainer, validate_and_test
def init_logging() -> None:
    """Configure the root logger to emit DEBUG-and-above records.

    Each record is rendered as ``time - logger name - level - message``.
    Because this delegates to ``logging.basicConfig``, it has no effect if
    the root logger already has handlers attached.
    """
    log_format = " - ".join(
        ("%(asctime)s", "%(name)s", "%(levelname)s", "%(message)s")
    )
    logging.basicConfig(format=log_format, level=logging.DEBUG)
@click.group()
def cli():
    # Root command group. It has no behavior of its own; the subcommands
    # (get-image-score, validate-model, train-model) are attached in main().
    # Deliberately no docstring: click would surface it as --help text.
    ...
"""
# %%
!ls paq2piq/*.py
!python /media/zq/Seagate/Git/paq2piq/cli.py --help
!python /media/zq/Seagate/Git/paq2piq/cli.py get-image-score --help
!python /media/zq/Seagate/Git/paq2piq/cli.py get-image-score --path_to_model_state paq2piq/RoIPoolModel.pth --path_to_image /media/zq/Seagate/Git/paq2piq/images/Picture1.jpg
#
#
pip install click
# %%
"""
@click.command("get-image-score", short_help="Get image scores")
@click.option(
    "--path_to_model_state",
    help="path to model weight .pth file",
    required=True,
    # click.Path(exists=True) rejects a missing/unreadable file up front with a
    # clear usage error, instead of it surfacing later from deeper inside the
    # model/image loading code. The value is still passed on as a str.
    type=click.Path(exists=True, dir_okay=False),
)
@click.option(
    "--path_to_image",
    help="path to the image file to score",
    required=True,
    type=click.Path(exists=True, dir_okay=False),
)
def get_image_score(path_to_model_state, path_to_image):
    """Score a single image and print the result.

    Builds an InferenceModel around a fresh RoIPoolModel using the weight
    file given by --path_to_model_state, runs predict_from_file on the image
    given by --path_to_image, and echoes the prediction to stdout.
    """
    model = InferenceModel(RoIPoolModel(), path_to_model_state=path_to_model_state)
    result = model.predict_from_file(path_to_image)
    click.echo(result)
"""
# %%
!python --version
!/home/zq/.virtualenvs/gpu/bin/python /media/zq/Seagate/Git/paq2piq/cli.py validate-model \
--path_to_model_state paq2piq/RoIPoolModel.pth \
--path_to_save_csv "!data/FLIVE/release" \
--path_to_images "!data/FLIVE/release/images" \
--batch_size 64
# %%
"""
@click.command("validate-model", short_help="Validate model")
@click.option(
    "--path_to_model_state",
    help="path to model weight .pth file",
    required=True,
    type=Path,
)
@click.option(
    "--path_to_save_csv",
    help="where save train.csv|val.csv|test.csv",
    required=True,
    type=Path,
)
@click.option("--path_to_images", help="images directory", required=True, type=Path)
@click.option("--batch_size", help="batch size", default=128, type=int)
@click.option("--num_workers", help="number of reading workers", default=16, type=int)
def validate_model(path_to_model_state, path_to_save_csv, path_to_images, batch_size, num_workers):
    # Thin CLI wrapper: forward every option unchanged to validate_and_test,
    # then report completion. (A --drop_out option was previously sketched
    # here but is disabled; validate_and_test is called without it.)
    forwarded = dict(
        path_to_model_state=path_to_model_state,
        path_to_save_csv=path_to_save_csv,
        path_to_images=path_to_images,
        batch_size=batch_size,
        num_workers=num_workers,
    )
    validate_and_test(**forwarded)
    click.echo("Done!")
"""
# %%
!/home/zq/.virtualenvs/gpu/bin/python
!python /media/zq/Seagate/Git/paq2piq/cli.py train-model \
--path_to_save_csv "!data/FLIVE/release" \
--path_to_images "!data/FLIVE/release/images" \
--experiment_dir "data/exp/t1-baseline" \
--num_epoch 100 \
--batch_size 64
# sh train.sh
# %%
"""
@click.command("train-model", short_help="Train model")
@click.option(
    "--path_to_save_csv",
    help="where save train.csv|val.csv|test.csv",
    required=True,
    type=Path,
)
@click.option("--path_to_images", help="images directory", required=True, type=Path)
@click.option(
    "--experiment_dir",
    help="directory name to save all logs and weight",
    required=True,
    type=Path,
)
@click.option("--model_type", help="res net model type", default="resnet18", type=str)
@click.option("--batch_size", help="batch size", default=128, type=int)
@click.option("--num_workers", help="number of reading workers", default=16, type=int)
@click.option("--num_epoch", help="number of epoch", default=32, type=int)
@click.option("--init_lr", help="initial learning rate", default=0.0001, type=float)
@click.option("--optimizer_type", help="optimizer type", default="adam", type=str)
@click.option("--seed", help="random seed", default=42, type=int)
def train_model(
    path_to_save_csv: Path,
    path_to_images: Path,
    experiment_dir: Path,
    model_type: str,
    batch_size: int,
    num_workers: int,
    num_epoch: int,
    init_lr: float,
    optimizer_type: str,
    seed: int,
):
    # Announce the run, seed via set_up_seed before the Trainer is built,
    # then hand every remaining option to Trainer and run its training loop.
    # (A --drop_out option was previously sketched here but is disabled.)
    click.echo("Train and validate model")
    set_up_seed(seed)
    trainer_kwargs = dict(
        path_to_save_csv=path_to_save_csv,
        path_to_images=path_to_images,
        experiment_dir=experiment_dir,
        model_type=model_type,
        batch_size=batch_size,
        num_workers=num_workers,
        num_epoch=num_epoch,
        init_lr=init_lr,
        optimizer_type=optimizer_type,
    )
    Trainer(**trainer_kwargs).train_model()
    click.echo("Done!")
def main():
    """Entry point: set up logging, register the subcommands, and run the CLI."""
    init_logging()
    # Registration order matches the order commands are defined in this file.
    for command in (get_image_score, validate_model, train_model):
        cli.add_command(command)
    cli()


if __name__ == "__main__":
    main()