From 028f9bb2bca67faa71abe2fb2e848a8478de8ff4 Mon Sep 17 00:00:00 2001
From: Haonan Yu <51248379+hnyu@users.noreply.github.com>
Date: Tue, 28 Mar 2023 15:30:51 -0700
Subject: [PATCH] rename pprint and improve matching loss (#1455)

* rename pprint and improve matching loss

* address comments and add another test

* docstring for return values

* moved comment location

* no adding max_cost
---
 alf/trainers/policy_trainer.py           |   2 +-
 alf/utils/losses.py                      |  58 +++++-----
 alf/utils/losses_test.py                 | 128 +++++++++++++++++++++--
 alf/utils/{pprint.py => pretty_print.py} |   0
 4 files changed, 147 insertions(+), 41 deletions(-)
 rename alf/utils/{pprint.py => pretty_print.py} (100%)

diff --git a/alf/trainers/policy_trainer.py b/alf/trainers/policy_trainer.py
index 83c4d6482..6c9d0f26e 100644
--- a/alf/trainers/policy_trainer.py
+++ b/alf/trainers/policy_trainer.py
@@ -42,7 +42,7 @@
 from alf.utils import common
 from alf.utils import git_utils
 from alf.utils import math_ops
-from alf.utils.pprint import pformat_pycolor
+from alf.utils.pretty_print import pformat_pycolor
 from alf.utils.checkpoint_utils import Checkpointer
 import alf.utils.datagen as datagen
 from alf.utils.summary_utils import record_time
diff --git a/alf/utils/losses.py b/alf/utils/losses.py
index cc01e7008..9ec459e3d 100644
--- a/alf/utils/losses.py
+++ b/alf/utils/losses.py
@@ -714,56 +714,50 @@ class BipartiteMatchingLoss(object):
     """

     def __init__(self,
-                 pair_loss_fn: Callable = torch.cdist,
                  reduction: str = 'mean',
                  name: str = "BipartiteMatchingLoss"):
         """
         Args:
-            pair_loss_fn: the pairwise matching loss function. It should take
-                two sets of length ``N`` and output a cost matrix of shape
-                ``[N,N]``.
             reduction: 'sum', 'mean' or 'none'. This is how to reduce the
                 matching loss. For the former two, the loss shape is ``[B]``,
                 while for the 'none', the loss shape is ``[B,N]``.
         """
         super().__init__()
-        self._pair_loss_fn = pair_loss_fn
         self._reduction = reduction
         assert reduction in ['mean', 'sum', 'none']
         self._name = name

     def forward(self,
-                prediction: Tensor,
-                target: Tensor,
-                target_mask: Tensor = None):
-        """
+                matching_cost_mat: torch.Tensor,
+                cost_mat: torch.Tensor = None):
+        """Compute the optimal matching loss.
+
         Args:
-            prediction: the predicted set with a shape of ``[B,N,...]``.
-            target: the target set with a shape of ``[B,N,...]``.
-            target_mask: the valid mask for the target set. We assume that the
-                target and prediction sets each always contains ``N`` objects, but
-                some objects in the target set are just paddings whose mask values
-                are 0. These paddings will always have a matching loss of 0. The
-                shape should be ``[B,N]``. If None, then all target objects are
-                valid.
+            matching_cost_mat: the cost matrix used to determine the optimal
+                matching. Its shape should be ``[B,N,N]``.
+            cost_mat: the cost matrix used to compute the optimal loss once
+                the optimal matching is found. According to the DETR paper,
+                this cost matrix might be different from the one used for
+                matching. If None, ``matching_cost_mat`` will also be used to
+                compute the loss.
+
+        Returns:
+            tuple:
+            - the optimal loss. If reduction is 'mean' or 'sum', its shape is
+              ``[B]``; otherwise its shape is ``[B,N]``.
+            - the optimal matching given the cost matrix. Its shape is
+              ``[B,N]``, where the value of the n-th entry is its mapped index
+              in the target set.
         """
-        assert prediction.shape[:2] == target.shape[:2]
-        B, N = prediction.shape[:2]
-        cost_mat = self._pair_loss_fn(prediction, target)  # [B,N,N]
-        assert cost_mat.shape == (B, N, N), (
-            "The pairwise loss function must enumerate all pairs and output "
-            "a scalar loss for each pair!")
-
-        # mask out any cost with mask values=0
-        if target_mask is not None:
-            target_mask = target_mask.unsqueeze(1)  # [B,1,N]
-            cost_mat = cost_mat * target_mask
+        if cost_mat is None:
+            cost_mat = matching_cost_mat

         with torch.no_grad():
+            B, N = matching_cost_mat.shape[:2]
+            max_cost = matching_cost_mat.max() + 1.
             # [B*N, B*N]
-            max_cost = cost_mat.max() + 1.
-            big_cost_mat = torch.block_diag(*list(cost_mat - max_cost))
-            # fill in all off-diag entries with a max cost
-            big_cost_mat = big_cost_mat + max_cost
+            # Subtract a max cost from all diag entries so that no off-diag
+            # matching will be optimal.
+            big_cost_mat = torch.block_diag(
+                *list(matching_cost_mat - max_cost))
             np_big_cost_mat = big_cost_mat.cpu().numpy()
             # col_ind: [B*N]
             row_ind, col_ind = linear_sum_assignment(np_big_cost_mat)
@@ -777,7 +771,7 @@ def forward(self,
             optimal_loss = optimal_loss.mean(-1)
         elif self._reduction == 'sum':
             optimal_loss = optimal_loss.sum(-1)
-        return optimal_loss
+        return optimal_loss, col_ind.squeeze(-1)

     def __call__(self, *args, **kwargs):
         return self.forward(*args, **kwargs)
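A minimal sketch of the new call pattern introduced above, mirroring the updated tests: the caller now builds the ``[B,N,N]`` cost matrix itself (here with ``torch.cdist``) and receives both the loss and the matching.

```python
# Minimal usage sketch of the new interface (mirrors the updated tests).
import torch

from alf.utils import losses

prediction = torch.rand([2, 5, 4])  # [B,N,D] predicted set
target = torch.rand([2, 5, 4])      # [B,N,D] target set

matcher = losses.BipartiteMatchingLoss(reduction='none')
# The caller is now responsible for constructing the pairwise cost matrix.
cost_mat = torch.cdist(prediction, target, p=1)  # [B,N,N]
loss, matching = matcher(cost_mat)
# loss: [B,N]; matching: [B,N], where matching[b, n] is the index of the
# target matched to prediction n in batch b.
```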
""" - assert prediction.shape[:2] == target.shape[:2] - B, N = prediction.shape[:2] - cost_mat = self._pair_loss_fn(prediction, target) # [B,N,N] - assert cost_mat.shape == (B, N, N), ( - "The pairwise loss function must enumerate all pairs and output " - "a scalar loss for each pair!") - - # mask out any cost with mask values=0 - if target_mask is not None: - target_mask = target_mask.unsqueeze(1) # [B,1,N] - cost_mat = cost_mat * target_mask + if cost_mat is None: + cost_mat = matching_cost_mat with torch.no_grad(): + B, N = matching_cost_mat.shape[:2] + max_cost = matching_cost_mat.max() + 1. # [B*N, B*N] - max_cost = cost_mat.max() + 1. - big_cost_mat = torch.block_diag(*list(cost_mat - max_cost)) - # fill in all off-diag entries with a max cost - big_cost_mat = big_cost_mat + max_cost + # Subtract all diag entries by a max cost so that no off-diag matchings + # will be optimal. + big_cost_mat = torch.block_diag( + *list(matching_cost_mat - max_cost)) np_big_cost_mat = big_cost_mat.cpu().numpy() # col_ind: [B*N] row_ind, col_ind = linear_sum_assignment(np_big_cost_mat) @@ -777,7 +771,7 @@ def forward(self, optimal_loss = optimal_loss.mean(-1) elif self._reduction == 'sum': optimal_loss = optimal_loss.sum(-1) - return optimal_loss + return optimal_loss, col_ind.squeeze(-1) def __call__(self, *args, **kwargs): return self.forward(*args, **kwargs) diff --git a/alf/utils/losses_test.py b/alf/utils/losses_test.py index 16c8ee4ca..dda339461 100644 --- a/alf/utils/losses_test.py +++ b/alf/utils/losses_test.py @@ -91,7 +91,8 @@ def test_loss_shape(self, reduction): target = torch.rand([2, 5, 4]) matcher = losses.BipartiteMatchingLoss(reduction=reduction) - loss = matcher(prediction, target) + cost_mat = torch.cdist(prediction, target, p=1) + loss, _ = matcher(cost_mat) if reduction == 'none': self.assertEqual(loss.shape, (2, 5)) else: @@ -104,9 +105,9 @@ def test_forward_loss(self): requires_grad=True) target = torch.tensor([[[0.9, 0.9, 0.9], [0, 0, 0.1]], [[0.1, 0.1, 0.1], [0.5, 0.6, 0.5]]]) - matcher = losses.BipartiteMatchingLoss( - pair_loss_fn=partial(torch.cdist, p=1), reduction='none') - loss = matcher(prediction, target) + matcher = losses.BipartiteMatchingLoss(reduction='none') + cost_mat = torch.cdist(prediction, target, p=1) + loss, _ = matcher(cost_mat) self.assertTrue(loss.requires_grad) self.assertTensorClose(loss, torch.tensor([[0.1, 0.3], [0.3, 1.4]])) @@ -162,8 +163,7 @@ def test_loss_training(self): optimizer = torch.optim.Adam(list(model.parameters()), lr=1e-3) epochs = 10 batch_size = 100 - matcher = losses.BipartiteMatchingLoss( - pair_loss_fn=partial(torch.cdist, p=2), reduction='mean') + matcher = losses.BipartiteMatchingLoss(reduction='mean') for _ in range(epochs): idx = torch.randperm(tr_inputs.shape[0]) tr_inputs = tr_inputs[idx] @@ -175,7 +175,8 @@ def test_loss_training(self): b_target = tr_target[i:i + batch_size] b_pred = model(b_inputs) # [b,N+1,1] b_pred = b_pred[:, 1:, :] - loss = matcher(b_pred, b_target).mean() + cost_mat = torch.cdist(b_pred, b_target, p=2.) + loss = matcher(cost_mat)[0].mean() loss.backward() optimizer.step() l.append(loss) @@ -185,7 +186,8 @@ def test_loss_training(self): val_pred = model(val_inputs) val_pred = val_pred[:, 1:, :] - val_loss = matcher(val_pred, val_target) + cost_mat = torch.cdist(val_pred, val_target, p=2.) 
@@ -193,6 +195,116 @@
         print("Validation loss: ", val_loss.mean())
         self.assertLess(float(val_loss.mean()), 0.15)

+    def test_loss_training_discrete_target(self):
+        """A simple toy task for testing bipartite matching loss on discrete
+        target variables.
+
+        The inputs are sampled from some fixed random Gaussians, each Gaussian
+        representing a different object class. The target values are Gaussian
+        ids, but shuffled in a random order for each sample.
+        """
+        samples_n = 10200
+        M, N, D = 3, 5, 10
+        mean = torch.randn((N, D)).unsqueeze(0)  # [1,N,D]
+        std = torch.ones_like(mean) * 0.5  # [1,N,D]
+
+        inputs = torch.normal(
+            mean.repeat(samples_n, 1, 1),
+            std.repeat(samples_n, 1, 1))  # [samples_n,N,D]
+        target = torch.arange(N).unsqueeze(0).repeat(samples_n,
+                                                     1)  # [samples_n,N]
+
+        # randomly shuffle
+        idx = torch.argsort(torch.randn(samples_n, N), dim=1)
+        inputs = torch.gather(
+            inputs, dim=1, index=idx.unsqueeze(-1).expand(-1, -1, D))
+        target = torch.gather(target, dim=1, index=idx)
+        # Only take the first M objects for each sample
+        inputs = inputs[:, :M, :]
+        target = target[:, :M]
+        # shuffle the target again ...
+        idx = torch.argsort(torch.randn(samples_n, M), dim=1)
+        target = torch.gather(target, dim=1, index=idx)
+
+        d_model = 64
+        transform_layers = []
+        for i in range(3):
+            transform_layers.append(
+                alf.layers.TransformerBlock(
+                    d_model=d_model,
+                    num_heads=3,
+                    memory_size=M * 2,  # input + queries
+                    positional_encoding='abs' if i == 0 else 'none'))
+        model = torch.nn.Sequential(*transform_layers, alf.layers.FC(
+            d_model, N))
+        input_fc = alf.layers.FC(D, d_model)
+
+        queries = torch.nn.Parameter(torch.Tensor(M, d_model))
+        torch.nn.init.normal_(queries)
+
+        val_n = 200
+        tr_inputs = inputs[:-val_n, ...]
+        val_inputs = inputs[-val_n:, ...]
+        tr_target = target[:-val_n, ...]
+        val_target = target[-val_n:, ...]
+
+        def _compute_cost_mat(p, t):
+            p = torch.nn.functional.log_softmax(p, dim=-1)
+            oh_t = torch.nn.functional.one_hot(
+                t, num_classes=N).to(torch.float32)
+            return -torch.einsum('bnk,bmk->bnm', p, oh_t)
+
+        optimizer = torch.optim.Adam(
+            list(model.parameters()) + list(input_fc.parameters()) +
+            [queries],
+            lr=1e-3)
+        epochs = 5
+        batch_size = 100
+        matcher = losses.BipartiteMatchingLoss(reduction='mean')
+        for _ in range(epochs):
+            idx = torch.randperm(tr_inputs.shape[0])
+            tr_inputs = tr_inputs[idx]
+            tr_target = tr_target[idx]
+            l = []
+            for i in range(0, idx.shape[0], batch_size):
+                optimizer.zero_grad()
+                b_inputs = tr_inputs[i:i + batch_size]
+                b_target = tr_target[i:i + batch_size]
+                b_inputs = input_fc(b_inputs)
+                b_queries = queries.unsqueeze(0).expand(
+                    b_inputs.shape[0], -1, -1)
+                b_inputs = torch.cat([b_inputs, b_queries], dim=1)  # [b,2M,..]
+                b_pred = model(b_inputs)  # [b,2M,N]
+                b_pred = b_pred[:, M:, :]
+                cost_mat = _compute_cost_mat(b_pred, b_target)
+                loss = matcher(cost_mat)[0].mean()
+                loss.backward()
+                optimizer.step()
+                l.append(loss)
+            print("Training loss: ", sum(l) / len(l))
+
+        val_fc_out = input_fc(val_inputs)
+        val_queries = queries.unsqueeze(0).expand(val_fc_out.shape[0], -1, -1)
+        val_fc_out = torch.cat([val_fc_out, val_queries], dim=1)  # [b,2M,..]
+        val_pred = model(val_fc_out)
+        val_pred = val_pred[:, M:, :]
+        cost_mat = _compute_cost_mat(val_pred, val_target)
+        val_loss = matcher(cost_mat)[0]
+
+        print("Cluster mean")
+        print(mean)
+
+        print("Validation inputs")
+        print(val_inputs[:5])
+
+        print("Validation prediction")
+        print(torch.argmax(val_pred, dim=-1)[:5])
+
+        print("Validation target")
+        print(val_target[:5])
+
+        print("Validation loss: ", val_loss.mean())
+        self.assertLess(float(val_loss.mean()), 0.01)
+

 if __name__ == '__main__':
     logging.set_verbosity(logging.INFO)
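The separate ``cost_mat`` argument enables a DETR-style split where matching and training use different costs. A hypothetical sketch, assuming DETR's convention of matching on negative class probability while training on negative log-likelihood; these particular cost choices are illustrative, not prescribed by the patch:

```python
# Hypothetical DETR-style usage of the two cost matrices.
import torch

from alf.utils import losses

B, M, N = 2, 3, 5
logits = torch.randn(B, M, N, requires_grad=True)  # class logits per slot
target = torch.randint(N, (B, M))                  # target class ids

log_p = torch.nn.functional.log_softmax(logits, dim=-1)
oh_t = torch.nn.functional.one_hot(target, num_classes=N).to(torch.float32)
# Matching cost: negative class probability (DETR's class cost) ...
matching_cost = -torch.einsum('bnk,bmk->bnm', log_p.exp(), oh_t)
# ... while the loss uses the negative log-likelihood.
loss_cost = -torch.einsum('bnk,bmk->bnm', log_p, oh_t)

matcher = losses.BipartiteMatchingLoss(reduction='mean')
loss, matching = matcher(matching_cost, cost_mat=loss_cost)
loss.mean().backward()
```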
diff --git a/alf/utils/pprint.py b/alf/utils/pretty_print.py
similarity index 100%
rename from alf/utils/pprint.py
rename to alf/utils/pretty_print.py
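A final hedged sketch: the newly returned matching can also be used outside the loss, for example to reorder each target set so it lines up with the predictions. This assumes the indices come back as a ``[B,N]`` integer tensor, as the docstring states:

```python
# Sketch: align targets with predictions using the returned matching.
import torch

from alf.utils import losses

B, N, D = 2, 5, 4
prediction = torch.rand(B, N, D)
target = torch.rand(B, N, D)

matcher = losses.BipartiteMatchingLoss(reduction='none')
loss, matching = matcher(torch.cdist(prediction, target, p=1))
# matching[b, n] is the target index assigned to prediction n, so a
# gather along dim 1 reorders each target set to pair with its prediction.
aligned_target = torch.gather(
    target, dim=1, index=matching.unsqueeze(-1).expand(-1, -1, D))
```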