[Refactor] Unify the --out and --dump in tools/test.py.
mzr1996 committed Jan 10, 2023
1 parent aa53f77 commit 45e5200
Showing 3 changed files with 17 additions and 28 deletions.
4 changes: 2 additions & 2 deletions docs/en/user_guides/train_test.md
@@ -141,8 +141,8 @@ CUDA_VISIBLE_DEVICES=-1 python tools/test.py ${CONFIG_FILE} ${CHECKPOINT_FILE} [
 | `CONFIG_FILE` | The path to the config file. |
 | `CHECKPOINT_FILE` | The path to the checkpoint file (It can be a http link, and you can find checkpoints [here](https://mmclassification.readthedocs.io/en/1.x/modelzoo_statistics.html)). |
 | `--work-dir WORK_DIR` | The directory to save the file containing evaluation metrics. |
-| `--out OUT` | The path to save the file containing evaluation metrics. |
-| `--dump DUMP` | The path to dump all outputs of the model for offline evaluation. |
+| `--out OUT` | The path to save the file containing test results. |
+| `--out-item OUT_ITEM` | To specify the content of the test results file, and it can be "pred" or "metrics". If "pred", save the outputs of the model for offline evaluation. If "metrics", save the evaluation metrics. Defaults to "pred". |
 | `--cfg-options CFG_OPTIONS` | Override some settings in the used config, the key-value pair in xxx=yyy format will be merged into the config file. If the value to be overwritten is a list, it should be of the form of either `key="[a,b]"` or `key=a,b`. The argument also allows nested list/tuple values, e.g. `key="[(a,b),(c,d)]"`. Note that the quotation marks are necessary and that no white space is allowed. |
 | `--show-dir SHOW_DIR` | The directory to save the result visualization images. |
 | `--show` | Visualize the prediction result in a window. |
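Beyond the doc text itself: with the new interface, `--out-item pred` writes the raw model outputs to the `--out` file for later offline evaluation. As a rough illustration (not part of this commit), the dumped file could be inspected like this, assuming the test was run with something like `python tools/test.py CONFIG CHECKPOINT --out results.pkl --out-item pred` and that the dump is a pickle holding one prediction record per test sample:

import mmengine

# 'results.pkl' is a hypothetical path produced by `--out results.pkl --out-item pred`.
preds = mmengine.load('results.pkl')  # mmengine.load handles .pkl files

print(f'Dumped predictions for {len(preds)} samples')
print(preds[0])  # the per-sample record structure depends on the model's data sample type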
4 changes: 2 additions & 2 deletions docs/zh_CN/user_guides/train_test.md
@@ -139,8 +139,8 @@ CUDA_VISIBLE_DEVICES=-1 python tools/test.py ${CONFIG_FILE} ${CHECKPOINT_FILE} [
 | `CONFIG_FILE` | The path to the config file. |
 | `CHECKPOINT_FILE` | The path to the checkpoint file (http links are supported; you can find the checkpoints you need [here](https://mmclassification.readthedocs.io/en/1.x/modelzoo_statistics.html)). |
 | `--work-dir WORK_DIR` | The directory to save the file containing test metrics. |
-| `--out OUT` | The file to save the test metric results. |
-| `--dump DUMP` | The file to save all model outputs, which can be used for offline evaluation. |
+| `--out OUT` | The file to save the test outputs. |
+| `--out-item OUT_ITEM` | Specify the content of the test output file, which can be "pred" or "metrics": "pred" saves all model outputs, which can be used for offline evaluation, while "metrics" saves the test metrics. Defaults to "pred". |
 | `--cfg-options CFG_OPTIONS` | Override some settings in the config file. Specify them as key-value pairs in the form `xxx=yyy`; they will be merged into the config loaded from the file. List values can be given as `key="[a,b]"` or `key=a,b`, and nesting is supported, e.g. `key="[(a,b),(c,d)]"`. The quotation marks are required, and no whitespace is allowed within each override. |
 | `--show-dir SHOW_DIR` | The directory to save the visualized prediction result images. |
 | `--show` | Display the prediction result images in a window. |
37 changes: 13 additions & 24 deletions tools/test.py
@@ -6,7 +6,7 @@

 import mmengine
 from mmengine.config import Config, ConfigDict, DictAction
-from mmengine.hooks import Hook
+from mmengine.evaluator import DumpResults
 from mmengine.runner import Runner
 
 from mmcls.utils import register_all_modules
@@ -22,9 +22,11 @@ def parse_args():
         help='the directory to save the file containing evaluation metrics')
     parser.add_argument('--out', help='the file to save metric results.')
     parser.add_argument(
-        '--dump',
-        type=str,
-        help='dump predictions to a pickle file for offline evaluation')
+        '--out-item',
+        default='pred',
+        choices=['metrics', 'pred'],
+        help='To output whether metrics or predictions. '
+        'Defaults to output predictions.')
     parser.add_argument(
         '--cfg-options',
         nargs='+',
@@ -102,16 +104,6 @@ def merge_args(cfg, args):
         cfg.default_hooks.visualization.out_dir = args.show_dir
         cfg.default_hooks.visualization.interval = args.interval
 
-    # -------------------- Dump predictions --------------------
-    if args.dump is not None:
-        assert args.dump.endswith(('.pkl', '.pickle')), \
-            'The dump file must be a pkl file.'
-        dump_metric = dict(type='DumpResults', out_file_path=args.dump)
-        if isinstance(cfg.test_evaluator, (list, tuple)):
-            cfg.test_evaluator = list(cfg.test_evaluator).append(dump_metric)
-        else:
-            cfg.test_evaluator = [cfg.test_evaluator, dump_metric]
-
     # -------------------- TTA related args --------------------
     if args.tta:
         if 'tta_model' not in cfg:
@@ -169,18 +161,15 @@ def main():
     # build the runner from config
     runner = Runner.from_cfg(cfg)
 
-    if args.out:
-
-        class SaveMetricHook(Hook):
-
-            def after_test_epoch(self, _, metrics=None):
-                if metrics is not None:
-                    mmengine.dump(metrics, args.out)
-
-        runner.register_hook(SaveMetricHook(), 'LOWEST')
+    if args.out and args.out_item == 'pred':
+        runner.test_evaluator.metrics.append(
+            DumpResults(out_file_path=args.out))
 
     # start testing
-    runner.test()
+    metrics = runner.test()
+
+    if args.out and args.out_item == 'metrics':
+        mmengine.dump(metrics, args.out)
 
 
 if __name__ == '__main__':
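To make the new flow concrete, here is a minimal sketch (not part of the commit) of driving the same two output modes from Python instead of the command line, mirroring the new `main()` logic above. The config and checkpoint paths are placeholders, and setting `cfg.load_from` and `cfg.work_dir` by hand stands in for the argument handling the script performs elsewhere:

import mmengine
from mmengine.config import Config
from mmengine.evaluator import DumpResults
from mmengine.runner import Runner

from mmcls.utils import register_all_modules

register_all_modules()  # register mmcls models, datasets, metrics, etc.

# Placeholder config and checkpoint -- substitute real paths.
cfg = Config.fromfile('configs/resnet/resnet18_8xb32_in1k.py')
cfg.load_from = 'checkpoint.pth'
cfg.work_dir = './work_dirs/test_demo'

runner = Runner.from_cfg(cfg)

# Equivalent of `--out pred.pkl --out-item pred`: dump raw predictions.
runner.test_evaluator.metrics.append(DumpResults(out_file_path='pred.pkl'))

# Equivalent of `--out metrics.json --out-item metrics`: dump the metrics
# returned by the test loop.
metrics = runner.test()
mmengine.dump(metrics, 'metrics.json')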
