diff --git a/configs/mask_rcnn/mask_rcnn_r50_fpn_1x_wandb_coco.py b/configs/mask_rcnn/mask_rcnn_r50_fpn_1x_wandb_coco.py new file mode 100644 index 00000000000..88c85767a92 --- /dev/null +++ b/configs/mask_rcnn/mask_rcnn_r50_fpn_1x_wandb_coco.py @@ -0,0 +1,26 @@ +_base_ = [ + '../_base_/models/mask_rcnn_r50_fpn.py', + '../_base_/datasets/coco_instance.py', + '../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py' +] + +# Set evaluation interval +evaluation = dict(interval=2) +# Set checkpoint interval +checkpoint_config = dict(interval=4) + +# yapf:disable +log_config = dict( + interval=50, + hooks=[ + dict(type='TextLoggerHook'), + dict(type='MMDetWandbHook', + init_kwargs={ + 'project': 'mmdetection', + 'group': 'maskrcnn-r50-fpn-1x-coco' + }, + interval=50, + log_checkpoint=True, + log_checkpoint_metadata=True, + num_eval_images=100) + ]) diff --git a/mmdet/core/evaluation/eval_hooks.py b/mmdet/core/evaluation/eval_hooks.py index 7c1fbe968d2..98856c18ce6 100644 --- a/mmdet/core/evaluation/eval_hooks.py +++ b/mmdet/core/evaluation/eval_hooks.py @@ -25,6 +25,7 @@ class EvalHook(BaseEvalHook): def __init__(self, *args, dynamic_intervals=None, **kwargs): super(EvalHook, self).__init__(*args, **kwargs) + self.latest_results = None self.use_dynamic_intervals = dynamic_intervals is not None if self.use_dynamic_intervals: @@ -53,7 +54,11 @@ def _do_evaluate(self, runner): return from mmdet.apis import single_gpu_test + + # Changed results to self.results so that MMDetWandbHook can access + # the evaluation results and log them to wandb. results = single_gpu_test(runner.model, self.dataloader, show=False) + self.latest_results = results runner.log_buffer.output['eval_iter_num'] = len(self.dataloader) key_score = self.evaluate(runner, results) # the key_score may be `None` so it needs to skip the action to save @@ -69,6 +74,7 @@ class DistEvalHook(BaseDistEvalHook): def __init__(self, *args, dynamic_intervals=None, **kwargs): super(DistEvalHook, self).__init__(*args, **kwargs) + self.latest_results = None self.use_dynamic_intervals = dynamic_intervals is not None if self.use_dynamic_intervals: @@ -114,11 +120,15 @@ def _do_evaluate(self, runner): tmpdir = osp.join(runner.work_dir, '.eval_hook') from mmdet.apis import multi_gpu_test + + # Changed results to self.results so that MMDetWandbHook can access + # the evaluation results and log them to wandb. results = multi_gpu_test( runner.model, self.dataloader, tmpdir=tmpdir, gpu_collect=self.gpu_collect) + self.latest_results = results if runner.rank == 0: print('\n') runner.log_buffer.output['eval_iter_num'] = len(self.dataloader) diff --git a/mmdet/core/hook/__init__.py b/mmdet/core/hook/__init__.py index 788ab494cbd..7b9ac9ff3ef 100644 --- a/mmdet/core/hook/__init__.py +++ b/mmdet/core/hook/__init__.py @@ -5,11 +5,13 @@ from .set_epoch_info_hook import SetEpochInfoHook from .sync_norm_hook import SyncNormHook from .sync_random_size_hook import SyncRandomSizeHook +from .wandblogger_hook import MMDetWandbHook from .yolox_lrupdater_hook import YOLOXLrUpdaterHook from .yolox_mode_switch_hook import YOLOXModeSwitchHook __all__ = [ 'SyncRandomSizeHook', 'YOLOXModeSwitchHook', 'SyncNormHook', 'ExpMomentumEMAHook', 'LinearMomentumEMAHook', 'YOLOXLrUpdaterHook', - 'CheckInvalidLossHook', 'SetEpochInfoHook', 'MemoryProfilerHook' + 'CheckInvalidLossHook', 'SetEpochInfoHook', 'MemoryProfilerHook', + 'MMDetWandbHook' ] diff --git a/mmdet/core/hook/wandblogger_hook.py b/mmdet/core/hook/wandblogger_hook.py new file mode 100644 index 00000000000..e4aef88502a --- /dev/null +++ b/mmdet/core/hook/wandblogger_hook.py @@ -0,0 +1,565 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import importlib +import os.path as osp +import sys + +import mmcv +import numpy as np +import pycocotools.mask as mask_util +from mmcv.runner import HOOKS +from mmcv.runner.dist_utils import master_only +from mmcv.runner.hooks.checkpoint import CheckpointHook +from mmcv.runner.hooks.logger.wandb import WandbLoggerHook + +from mmdet.core import DistEvalHook, EvalHook +from mmdet.core.mask.structures import polygon_to_bitmap + + +@HOOKS.register_module() +class MMDetWandbHook(WandbLoggerHook): + """Enhanced Wandb logger hook for MMDetection. + + Comparing with the :cls:`mmcv.runner.WandbLoggerHook`, this hook can not + only automatically log all the metrics but also log the following extra + information - saves model checkpoints as W&B Artifact, and + logs model prediction as interactive W&B Tables. + + - Metrics: The MMDetWandbHook will automatically log training + and validation metrics along with system metrics (CPU/GPU). + + - Checkpointing: If `log_checkpoint` is True, the checkpoint saved at + every checkpoint interval will be saved as W&B Artifacts. + This depends on the : class:`mmcv.runner.CheckpointHook` whose priority + is higher than this hook. Please refer to + https://docs.wandb.ai/guides/artifacts/model-versioning + to learn more about model versioning with W&B Artifacts. + + - Checkpoint Metadata: If evaluation results are available for a given + checkpoint artifact, it will have a metadata associated with it. + The metadata contains the evaluation metrics computed on validation + data with that checkpoint along with the current epoch. It depends + on `EvalHook` whose priority is more than MMDetWandbHook. + + - Evaluation: At every evaluation interval, the `MMDetWandbHook` logs the + model prediction as interactive W&B Tables. The number of samples + logged is given by `num_eval_images`. Currently, the `MMDetWandbHook` + logs the predicted bounding boxes along with the ground truth at every + evaluation interval. This depends on the `EvalHook` whose priority is + more than `MMDetWandbHook`. Also note that the data is just logged once + and subsequent evaluation tables uses reference to the logged data + to save memory usage. Please refer to + https://docs.wandb.ai/guides/data-vis to learn more about W&B Tables. + + For more details check out W&B's MMDetection docs: + https://docs.wandb.ai/guides/integrations/mmdetection + + ``` + Example: + log_config = dict( + ... + hooks=[ + ..., + dict(type='MMDetWandbHook', + init_kwargs={ + 'entity': "YOUR_ENTITY", + 'project': "YOUR_PROJECT_NAME" + }, + interval=50, + log_checkpoint=True, + log_checkpoint_metadata=True, + num_eval_images=100, + bbox_score_thr=0.3) + ]) + ``` + + Args: + init_kwargs (dict): A dict passed to wandb.init to initialize + a W&B run. Please refer to https://docs.wandb.ai/ref/python/init + for possible key-value pairs. + interval (int): Logging interval (every k iterations). Defaults to 50. + log_checkpoint (bool): Save the checkpoint at every checkpoint interval + as W&B Artifacts. Use this for model versioning where each version + is a checkpoint. Defaults to False. + log_checkpoint_metadata (bool): Log the evaluation metrics computed + on the validation data with the checkpoint, along with current + epoch as a metadata to that checkpoint. + Defaults to True. + num_eval_images (int): The number of validation images to be logged. + If zero, the evaluation won't be logged. Defaults to 100. + bbox_score_thr (float): Threshold for bounding box scores. + Defaults to 0.3. + """ + + def __init__(self, + init_kwargs=None, + interval=50, + log_checkpoint=False, + log_checkpoint_metadata=False, + num_eval_images=100, + bbox_score_thr=0.3, + **kwargs): + super(MMDetWandbHook, self).__init__(init_kwargs, interval, **kwargs) + + self.log_checkpoint = log_checkpoint + self.log_checkpoint_metadata = ( + log_checkpoint and log_checkpoint_metadata) + self.num_eval_images = num_eval_images + self.bbox_score_thr = bbox_score_thr + self.log_evaluation = (num_eval_images > 0) + self.ckpt_hook: CheckpointHook = None + self.eval_hook: EvalHook = None + + @master_only + def before_run(self, runner): + super(MMDetWandbHook, self).before_run(runner) + + # Save and Log config. + if runner.meta is not None: + src_cfg_path = osp.join(runner.work_dir, + runner.meta.get('exp_name', None)) + if osp.exists(src_cfg_path): + self.wandb.save(src_cfg_path, base_path=runner.work_dir) + self._update_wandb_config(runner) + else: + runner.logger.warning('No meta information found in the runner. ') + + # Inspect CheckpointHook and EvalHook + for hook in runner.hooks: + if isinstance(hook, CheckpointHook): + self.ckpt_hook = hook + if isinstance(hook, (EvalHook, DistEvalHook)): + self.eval_hook = hook + + # Check conditions to log checkpoint + if self.log_checkpoint: + if self.ckpt_hook is None: + self.log_checkpoint = False + self.log_checkpoint_metadata = False + runner.logger.warning( + 'To log checkpoint in MMDetWandbHook, `CheckpointHook` is' + 'required, please check hooks in the runner.') + else: + self.ckpt_interval = self.ckpt_hook.interval + + # Check conditions to log evaluation + if self.log_evaluation or self.log_checkpoint_metadata: + if self.eval_hook is None: + self.log_evaluation = False + self.log_checkpoint_metadata = False + runner.logger.warning( + 'To log evaluation or checkpoint metadata in ' + 'MMDetWandbHook, `EvalHook` or `DistEvalHook` in mmdet ' + 'is required, please check whether the validation ' + 'is enabled.') + else: + self.eval_interval = self.eval_hook.interval + self.val_dataset = self.eval_hook.dataloader.dataset + # Determine the number of samples to be logged. + if self.num_eval_images > len(self.val_dataset): + self.num_eval_images = len(self.val_dataset) + runner.logger.warning( + f'The num_eval_images ({self.num_eval_images}) is ' + 'greater than the total number of validation samples ' + f'({len(self.val_dataset)}). The complete validation ' + 'dataset will be logged.') + + # Check conditions to log checkpoint metadata + if self.log_checkpoint_metadata: + assert self.ckpt_interval % self.eval_interval == 0, \ + 'To log checkpoint metadata in MMDetWandbHook, the interval ' \ + f'of checkpoint saving ({self.ckpt_interval}) should be ' \ + 'divisible by the interval of evaluation ' \ + f'({self.eval_interval}).' + + # Initialize evaluation table + if self.log_evaluation: + # Initialize data table + self._init_data_table() + # Add data to the data table + self._add_ground_truth(runner) + # Log ground truth data + self._log_data_table() + + @master_only + def after_train_epoch(self, runner): + super(MMDetWandbHook, self).after_train_epoch(runner) + + if not self.by_epoch: + return + + # Log checkpoint and metadata. + if (self.log_checkpoint + and self.every_n_epochs(runner, self.ckpt_interval) + or (self.ckpt_hook.save_last and self.is_last_epoch(runner))): + if self.log_checkpoint_metadata and self.eval_hook: + metadata = { + 'epoch': runner.epoch + 1, + **self._get_eval_results() + } + else: + metadata = None + aliases = [f'epoch_{runner.epoch+1}', 'latest'] + model_path = osp.join(self.ckpt_hook.out_dir, + f'epoch_{runner.epoch+1}.pth') + self._log_ckpt_as_artifact(model_path, aliases, metadata) + + # Save prediction table + if self.log_evaluation and self.eval_hook._should_evaluate(runner): + results = self.eval_hook.latest_results + # Initialize evaluation table + self._init_pred_table() + # Log predictions + self._log_predictions(results) + # Log the table + self._log_eval_table(runner.epoch + 1) + + @master_only + def after_train_iter(self, runner): + if self.get_mode(runner) == 'train': + # An ugly patch. The iter-based eval hook will call the + # `after_train_iter` method of all logger hooks before evaluation. + # Use this trick to skip that call. + # Don't call super method at first, it will clear the log_buffer + return super(MMDetWandbHook, self).after_train_iter(runner) + else: + super(MMDetWandbHook, self).after_train_iter(runner) + + if self.by_epoch: + return + + # Save checkpoint and metadata + if (self.log_checkpoint + and self.every_n_iters(runner, self.ckpt_interval) + or (self.ckpt_hook.save_last and self.is_last_iter(runner))): + if self.log_checkpoint_metadata and self.eval_hook: + metadata = { + 'iter': runner.iter + 1, + **self._get_eval_results() + } + else: + metadata = None + aliases = [f'iter_{runner.iter+1}', 'latest'] + model_path = osp.join(self.ckpt_hook.out_dir, + f'iter_{runner.iter+1}.pth') + self._log_ckpt_as_artifact(model_path, aliases, metadata) + + # Save prediction table + if self.log_evaluation and self.eval_hook._should_evaluate(runner): + results = self.eval_hook.latest_results + # Initialize evaluation table + self._init_pred_table() + # Log predictions + self._log_predictions(results) + # Log the table + self._log_eval_table(runner.iter + 1) + + @master_only + def after_run(self, runner): + self.wandb.finish() + + def _update_wandb_config(self, runner): + """Update wandb config.""" + # Import the config file. + sys.path.append(runner.work_dir) + config_filename = runner.meta['exp_name'][:-3] + configs = importlib.import_module(config_filename) + # Prepare a nested dict of config variables. + config_keys = [key for key in dir(configs) if not key.startswith('__')] + config_dict = {key: getattr(configs, key) for key in config_keys} + # Update the W&B config. + self.wandb.config.update(config_dict) + + def _log_ckpt_as_artifact(self, model_path, aliases, metadata=None): + """Log model checkpoint as W&B Artifact. + + Args: + model_path (str): Path of the checkpoint to log. + aliases (list): List of the aliases associated with this artifact. + metadata (dict, optional): Metadata associated with this artifact. + """ + model_artifact = self.wandb.Artifact( + f'run_{self.wandb.run.id}_model', type='model', metadata=metadata) + model_artifact.add_file(model_path) + self.wandb.log_artifact(model_artifact, aliases=aliases) + + def _get_eval_results(self): + """Get model evaluation results.""" + results = self.eval_hook.latest_results + eval_results = self.val_dataset.evaluate( + results, logger='silent', **self.eval_hook.eval_kwargs) + return eval_results + + def _init_data_table(self): + """Initialize the W&B Tables for validation data.""" + columns = ['image_name', 'image'] + self.data_table = self.wandb.Table(columns=columns) + + def _init_pred_table(self): + """Initialize the W&B Tables for model evaluation.""" + columns = ['image_name', 'ground_truth', 'prediction'] + self.eval_table = self.wandb.Table(columns=columns) + + def _add_ground_truth(self, runner): + # Get image loading pipeline + from mmdet.datasets.pipelines import LoadImageFromFile + img_loader = None + for t in self.val_dataset.pipeline.transforms: + if isinstance(t, LoadImageFromFile): + img_loader = t + + if img_loader is None: + self.log_evaluation = False + runner.logger.warning( + 'LoadImageFromFile is required to add images ' + 'to W&B Tables.') + return + + # Select the images to be logged. + self.eval_image_indexs = np.arange(len(self.val_dataset)) + # Set seed so that same validation set is logged each time. + np.random.seed(42) + np.random.shuffle(self.eval_image_indexs) + self.eval_image_indexs = self.eval_image_indexs[:self.num_eval_images] + + CLASSES = self.val_dataset.CLASSES + self.class_id_to_label = { + id + 1: name + for id, name in enumerate(CLASSES) + } + self.class_set = self.wandb.Classes([{ + 'id': id, + 'name': name + } for id, name in self.class_id_to_label.items()]) + + img_prefix = self.val_dataset.img_prefix + + for idx in self.eval_image_indexs: + img_info = self.val_dataset.data_infos[idx] + image_name = img_info.get('filename', f'img_{idx}') + img_height, img_width = img_info['height'], img_info['width'] + + img_meta = img_loader( + dict(img_info=img_info, img_prefix=img_prefix)) + + # Get image and convert from BGR to RGB + image = mmcv.bgr2rgb(img_meta['img']) + + data_ann = self.val_dataset.get_ann_info(idx) + bboxes = data_ann['bboxes'] + labels = data_ann['labels'] + masks = data_ann.get('masks', None) + + # Get dict of bounding boxes to be logged. + assert len(bboxes) == len(labels) + wandb_boxes = self._get_wandb_bboxes(bboxes, labels) + + # Get dict of masks to be logged. + if masks is not None: + wandb_masks = self._get_wandb_masks( + masks, + labels, + is_poly_mask=True, + height=img_height, + width=img_width) + else: + wandb_masks = None + # TODO: Panoramic segmentation visualization. + + # Log a row to the data table. + self.data_table.add_data( + image_name, + self.wandb.Image( + image, + boxes=wandb_boxes, + masks=wandb_masks, + classes=self.class_set)) + + def _log_predictions(self, results): + table_idxs = self.data_table_ref.get_index() + assert len(table_idxs) == len(self.eval_image_indexs) + + for ndx, eval_image_index in enumerate(self.eval_image_indexs): + # Get the result + result = results[eval_image_index] + if isinstance(result, tuple): + bbox_result, segm_result = result + if isinstance(segm_result, tuple): + segm_result = segm_result[0] # ms rcnn + else: + bbox_result, segm_result = result, None + assert len(bbox_result) == len(self.class_id_to_label) + + # Get labels + bboxes = np.vstack(bbox_result) + labels = [ + np.full(bbox.shape[0], i, dtype=np.int32) + for i, bbox in enumerate(bbox_result) + ] + labels = np.concatenate(labels) + + # Get segmentation mask if available. + segms = None + if segm_result is not None and len(labels) > 0: + segms = mmcv.concat_list(segm_result) + segms = mask_util.decode(segms) + segms = segms.transpose(2, 0, 1) + assert len(segms) == len(labels) + # TODO: Panoramic segmentation visualization. + + # Remove bounding boxes and masks with score lower than threshold. + if self.bbox_score_thr > 0: + assert bboxes is not None and bboxes.shape[1] == 5 + scores = bboxes[:, -1] + inds = scores > self.bbox_score_thr + bboxes = bboxes[inds, :] + labels = labels[inds] + if segms is not None: + segms = segms[inds, ...] + + # Get dict of bounding boxes to be logged. + wandb_boxes = self._get_wandb_bboxes(bboxes, labels, log_gt=False) + # Get dict of masks to be logged. + if segms is not None: + wandb_masks = self._get_wandb_masks(segms, labels) + else: + wandb_masks = None + + # Log a row to the eval table. + self.eval_table.add_data( + self.data_table_ref.data[ndx][0], + self.data_table_ref.data[ndx][1], + self.wandb.Image( + self.data_table_ref.data[ndx][1], + boxes=wandb_boxes, + masks=wandb_masks, + classes=self.class_set)) + + def _get_wandb_bboxes(self, bboxes, labels, log_gt=True): + """Get list of structured dict for logging bounding boxes to W&B. + + Args: + bboxes (list): List of bounding box coordinates in + (minX, minY, maxX, maxY) format. + labels (int): List of label ids. + log_gt (bool): Whether to log ground truth or prediction boxes. + + Returns: + Dictionary of bounding boxes to be logged. + """ + wandb_boxes = {} + + box_data = [] + for bbox, label in zip(bboxes, labels): + if not isinstance(label, int): + label = int(label) + label = label + 1 + + if len(bbox) == 5: + confidence = float(bbox[4]) + class_name = self.class_id_to_label[label] + box_caption = f'{class_name} {confidence:.2f}' + else: + box_caption = str(self.class_id_to_label[label]) + + position = dict( + minX=int(bbox[0]), + minY=int(bbox[1]), + maxX=int(bbox[2]), + maxY=int(bbox[3])) + + box_data.append({ + 'position': position, + 'class_id': label, + 'box_caption': box_caption, + 'domain': 'pixel' + }) + + wandb_bbox_dict = { + 'box_data': box_data, + 'class_labels': self.class_id_to_label + } + + if log_gt: + wandb_boxes['ground_truth'] = wandb_bbox_dict + else: + wandb_boxes['predictions'] = wandb_bbox_dict + + return wandb_boxes + + def _get_wandb_masks(self, + masks, + labels, + is_poly_mask=False, + height=None, + width=None): + """Get list of structured dict for logging masks to W&B. + + Args: + masks (list): List of masks. + labels (int): List of label ids. + is_poly_mask (bool): Whether the mask is polygonal or not. + This is true for CocoDataset. + height (int): Height of the image. + width (int): Width of the image. + + Returns: + Dictionary of masks to be logged. + """ + mask_label_dict = dict() + for mask, label in zip(masks, labels): + label = label + 1 + # Get bitmap mask from polygon. + if is_poly_mask: + if height is not None and width is not None: + mask = polygon_to_bitmap(mask, height, width) + # Create composite masks for each class. + if label not in mask_label_dict.keys(): + mask_label_dict[label] = mask + else: + mask_label_dict[label] = np.logical_or(mask_label_dict[label], + mask) + + wandb_masks = dict() + for key, value in mask_label_dict.items(): + # Create mask for that class. + value = value.astype(np.uint8) + value[value > 0] = key + + # Create dict of masks for logging. + class_name = self.class_id_to_label[key] + wandb_masks[class_name] = { + 'mask_data': value, + 'class_labels': self.class_id_to_label + } + + return wandb_masks + + def _log_data_table(self): + """Log the W&B Tables for validation data as artifact and calls + `use_artifact` on it so that the evaluation table can use the reference + of already uploaded images. + + This allows the data to be uploaded just once. + """ + data_artifact = self.wandb.Artifact('val', type='dataset') + data_artifact.add(self.data_table, 'val_data') + + self.wandb.run.use_artifact(data_artifact) + data_artifact.wait() + + self.data_table_ref = data_artifact.get('val_data') + + def _log_eval_table(self, idx): + """Log the W&B Tables for model evaluation. + + The table will be logged multiple times creating new version. Use this + to compare models at different intervals interactively. + """ + pred_artifact = self.wandb.Artifact( + f'run_{self.wandb.run.id}_pred', type='evaluation') + pred_artifact.add(self.eval_table, 'eval_data') + if self.by_epoch: + aliases = ['latest', f'epoch_{idx}'] + else: + aliases = ['latest', f'iter_{idx}'] + self.wandb.run.log_artifact(pred_artifact, aliases=aliases)