diff --git a/mpa/cls/stage.py b/mpa/cls/stage.py
index be06e2b5..b0034c51 100644
--- a/mpa/cls/stage.py
+++ b/mpa/cls/stage.py
@@ -195,7 +195,7 @@ def configure_task(cfg, training, model_meta=None, **kwargs):
 
         # model configuration update
         cfg.model.head.num_classes = len(dst_classes)
-        gamma = 2 if cfg['task_adapt'].get('efficient_mode', False) else 3
+        gamma = 2 if cfg['task_adapt'].get('efficient_mode', True) else 3
         cfg.model.head.loss = ConfigDict(
             type='SoftmaxFocalLoss',
             loss_weight=1.0,
@@ -212,7 +212,8 @@ def configure_task(cfg, training, model_meta=None, **kwargs):
             dst_classes=dst_classes,
             model_type=cfg.model.type,
             sampler_flag=sampler_flag,
-            efficient_mode=cfg['task_adapt'].get('efficient_mode', False)
+            sampler_type='balanced',
+            efficient_mode=cfg['task_adapt'].get('efficient_mode', True)
         )
         update_or_add_custom_hook(cfg, task_adapt_hook)
 
diff --git a/mpa/modules/datasets/samplers/balanced_sampler.py b/mpa/modules/datasets/samplers/balanced_sampler.py
new file mode 100644
index 00000000..668c80f5
--- /dev/null
+++ b/mpa/modules/datasets/samplers/balanced_sampler.py
@@ -0,0 +1,75 @@
+import math
+
+import numpy as np
+from torch.utils.data.sampler import Sampler
+
+from mpa.utils.logger import get_logger
+
+logger = get_logger()
+
+
+class BalancedSampler(Sampler):
+    """Class-balanced sampler for class-incremental tasks.
+
+    Every "trial" draws one random sample index from each class, so the
+    resulting index stream is balanced across classes regardless of the
+    class distribution of the underlying dataset.
+
+    In efficient mode the number of trials is reduced to the estimated
+    number of draws needed for every sample of the rarest (tail) class to
+    be selected at least once with probability 0.999.
+
+    Args:
+        dataset (Dataset): A built-up dataset exposing ``img_indices``,
+            a mapping from class key to the sample indices of that class.
+        batch_size (int): Batch size of the data loader (stored for
+            interface compatibility; not used for sampling itself).
+        efficient_mode (bool): Flag enabling the reduced-trial mode.
+    """
+    def __init__(self, dataset, batch_size, efficient_mode=True):
+        self.batch_size = batch_size
+        # RepeatDataset-style wrappers expose `times`; unwrap to the inner
+        # dataset and remember the repeat factor.
+        self.repeat = 1
+        if hasattr(dataset, 'times'):
+            self.repeat = dataset.times
+        if hasattr(dataset, 'dataset'):
+            self.dataset = dataset.dataset
+        else:
+            self.dataset = dataset
+
+        self.img_indices = self.dataset.img_indices
+        self.num_cls = len(self.img_indices.keys())
+        self.data_length = len(self.dataset)
+
+        if efficient_mode:
+            # Reduce the # of sampling (sampling data for a single epoch):
+            # num_trials t solves (1 - 1/num_tail)^t = 0.001, i.e. the tail
+            # class is fully covered with probability 0.999.
+            self.num_tail = min([len(cls_indices) for cls_indices in self.img_indices.values()])
+            base = 1 - (1/self.num_tail)
+            if base == 0:
+                raise ValueError('Required more than one sample per class')
+            self.num_trials = int(math.log(0.001, base))
+            # Never exceed the trials of the non-efficient mode.
+            if int(self.data_length / self.num_cls) < self.num_trials:
+                self.num_trials = int(self.data_length / self.num_cls)
+        else:
+            self.num_trials = int(self.data_length / self.num_cls)
+        self.compute_sampler_length()
+        logger.info(f"This sampler will select balanced samples {self.num_trials} times")
+
+    def compute_sampler_length(self):
+        # One index per class per trial, repeated `repeat` times.
+        self.sampler_length = self.num_trials * self.num_cls * self.repeat
+
+    def __iter__(self):
+        indices = []
+        for _ in range(self.repeat):
+            for _ in range(self.num_trials):
+                indice = np.concatenate(
+                    [np.random.choice(self.img_indices[cls_indices], 1) for cls_indices in self.img_indices.keys()])
+                indices.append(indice)
+
+        indices = np.concatenate(indices)
+        indices = indices.astype(np.int64).tolist()
+
+        return iter(indices)
+
+    def __len__(self):
+        return self.sampler_length
diff --git a/mpa/modules/hooks/task_adapt_hook.py b/mpa/modules/hooks/task_adapt_hook.py
index fd07b653..840e374c 100644
--- a/mpa/modules/hooks/task_adapt_hook.py
+++ b/mpa/modules/hooks/task_adapt_hook.py
@@ -6,6 +6,7 @@
 from torch.utils.data import DataLoader
 
 from mpa.modules.datasets.samplers.cls_incr_sampler import ClsIncrSampler
+from mpa.modules.datasets.samplers.balanced_sampler import BalancedSampler
 from mpa.utils.logger import get_logger
 
 logger = get_logger()
@@ -28,12 +29,14 @@ def __init__(self,
                  src_classes,
                  dst_classes,
                  model_type='FasterRCNN',
                  sampler_flag=False,
+                 sampler_type='cls_incr',
                  efficient_mode=False):
         self.src_classes = src_classes
         self.dst_classes = dst_classes
         self.model_type = model_type
-        self.efficient_mode = efficient_mode
         self.sampler_flag = sampler_flag
+        self.sampler_type = sampler_type
+        self.efficient_mode = efficient_mode
         logger.info(f'Task Adaptation: {self.src_classes} => {self.dst_classes}')
         logger.info(f'- Efficient Mode: {self.efficient_mode}')
@@ -47,7 +50,10 @@ def before_epoch(self, runner):
             num_workers = runner.data_loader.num_workers
             collate_fn = runner.data_loader.collate_fn
             worker_init_fn = runner.data_loader.worker_init_fn
-            sampler = ClsIncrSampler(dataset, batch_size, efficient_mode=self.efficient_mode)
+            if self.sampler_type == 'balanced':
+                sampler = BalancedSampler(dataset, batch_size, efficient_mode=self.efficient_mode)
+            else:
+                sampler = ClsIncrSampler(dataset, batch_size, efficient_mode=self.efficient_mode)
             runner.data_loader = DataLoader(
                 dataset,
                 batch_size=batch_size,