diff --git a/configs/_base_/datasets/kitti2015_raft_test.py b/configs/_base_/datasets/kitti2015_raft_test.py index 2a67368c..9c5a5dfc 100644 --- a/configs/_base_/datasets/kitti2015_raft_test.py +++ b/configs/_base_/datasets/kitti2015_raft_test.py @@ -1,4 +1,5 @@ -img_norm_cfg = dict(mean=[0., 0., 0.], std=[255., 255., 255.], to_rgb=False) +img_norm_cfg = dict( + mean=[127.5, 127.5, 127.5], std=[127.5, 127.5, 127.5], to_rgb=False) kitti_test_pipeline = [ dict(type='LoadImageFromFile'), diff --git a/mmflow/utils/set_env.py b/mmflow/utils/set_env.py index 1f38b92a..a99a161f 100644 --- a/mmflow/utils/set_env.py +++ b/mmflow/utils/set_env.py @@ -11,7 +11,6 @@ def setup_multi_processes(cfg): """Setup multi-processing environment variables.""" logger = get_root_logger() - # set multi-process start method if platform.system() != 'Windows': mp_start_method = cfg.get('mp_start_method', None) @@ -31,7 +30,13 @@ def setup_multi_processes(cfg): else: logger.info(f'OpenCV num_threads is `{cv2.getNumThreads()}') - if cfg.data.train_dataloader.workers_per_gpu > 1: + if cfg.data.get('train_dataloader') is not None: + workers_per_gpu = cfg.data.train_dataloader.get('workers_per_gpu', 0) + elif cfg.data.get('test_dataloader') is not None: + workers_per_gpu = cfg.data.test_dataloader.get('workers_per_gpu', 0) + else: + workers_per_gpu = 0 + if workers_per_gpu > 1: # setup OMP threads # This code is referred from https://github.com/pytorch/pytorch/blob/master/torch/distributed/run.py # noqa omp_num_threads = cfg.get('omp_num_threads', None)