Skip to content

Commit

Permalink
fix prune.py (PaddlePaddle#212)
Browse files Browse the repository at this point in the history
* fix prune.py
  • Loading branch information
heavengate authored Feb 6, 2020
1 parent 791b8f4 commit 677dfad
Showing 1 changed file with 10 additions and 5 deletions.
15 changes: 10 additions & 5 deletions slim/prune/prune.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@
from ppdet.utils.cli import ArgsParser
from ppdet.utils.check import check_gpu, check_version
import ppdet.utils.checkpoint as checkpoint
from ppdet.modeling.model_input import create_feed

import logging
FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
Expand Down Expand Up @@ -142,9 +141,9 @@ def main():
if cfg.metric == 'COCO':
extra_keys = ['im_info', 'im_id', 'im_shape']
if cfg.metric == 'VOC':
extra_keys = ['gt_box', 'gt_label', 'is_difficult']
extra_keys = ['gt_bbox', 'gt_class', 'is_difficult']
if cfg.metric == 'WIDERFACE':
extra_keys = ['im_id', 'im_shape', 'gt_box']
extra_keys = ['im_id', 'im_shape', 'gt_bbox']
eval_keys, eval_values, eval_cls = parse_fetches(fetches, eval_prog,
extra_keys)

Expand Down Expand Up @@ -306,8 +305,14 @@ def main():
if 'mask' in results[0]:
resolution = model.mask_head.resolution
box_ap_stats = eval_results(
results, eval_feed, cfg.metric, cfg.num_classes, resolution,
is_bbox_normalized, FLAGS.output_eval, map_type)
results,
cfg.metric,
cfg.num_classes,
resolution,
is_bbox_normalized,
FLAGS.output_eval,
map_type,
dataset=dataset)

# use tb_paddle to log mAP
if FLAGS.use_tb:
Expand Down

0 comments on commit 677dfad

Please sign in to comment.