rename pprint and improve matching loss (#1455)
* rename pprint and improve matching loss

* address comments and add another test

* docstring for return values

* moved comment location

* no adding max_cost
hnyu authored Mar 28, 2023
1 parent 997084a commit 028f9bb
Showing 4 changed files with 147 additions and 41 deletions.
2 changes: 1 addition & 1 deletion alf/trainers/policy_trainer.py
@@ -42,7 +42,7 @@
from alf.utils import common
from alf.utils import git_utils
from alf.utils import math_ops
from alf.utils.pprint import pformat_pycolor
from alf.utils.pretty_print import pformat_pycolor
from alf.utils.checkpoint_utils import Checkpointer
import alf.utils.datagen as datagen
from alf.utils.summary_utils import record_time
58 changes: 26 additions & 32 deletions alf/utils/losses.py
@@ -714,56 +714,50 @@ class BipartiteMatchingLoss(object):
"""

def __init__(self,
pair_loss_fn: Callable = torch.cdist,
reduction: str = 'mean',
name: str = "BipartiteMatchingLoss"):
"""
Args:
pair_loss_fn: the pairwise matching loss function. It should take
two sets of length ``N`` and output a cost matrix of shape ``[N,N]``.
reduction: 'sum', 'mean' or 'none'. This specifies how the matching
loss is reduced. For the former two, the loss shape is ``[B]``, while for
'none', the loss shape is ``[B,N]``.
"""
super().__init__()
self._pair_loss_fn = pair_loss_fn
self._reduction = reduction
assert reduction in ['mean', 'sum', 'none']
self._name = name

def forward(self,
prediction: Tensor,
target: Tensor,
target_mask: Tensor = None):
"""
matching_cost_mat: torch.Tensor,
cost_mat: torch.Tensor = None):
"""Compute the optimal matching loss.
Args:
prediction: the predicted set with a shape of ``[B,N,...]``.
target: the target set with a shape of ``[B,N,...]``.
target_mask: the valid mask for the target set. We assume that the
target and prediction sets each always contain ``N`` objects, but
some objects in the target set are just paddings whose mask values
are 0. These paddings will always have a matching loss of 0. The
shape should be ``[B,N]``. If None, then all target objects are
valid.
matching_cost_mat: the cost matrix used to determine the optimal
matching. Its shape should be ``[B,N,N]``.
cost_mat: the cost matrix used to compute the optimal loss once the
optimal matching is found. According to the DETR paper, this
cost matrix might be different from the one used for matching.
If None, then ``matching_cost_mat`` is also used to compute the loss.
Returns:
tuple:
- the optimal loss. If reduction is 'mean' or 'sum', its shape is
``[B]``, otherwise ('none') its shape is ``[B,N]``.
- the optimal matching given the cost matrix. Its shape is ``[B,N]``,
where the value of the n-th entry is the index of its matched element
in the target set.
"""
assert prediction.shape[:2] == target.shape[:2]
B, N = prediction.shape[:2]
cost_mat = self._pair_loss_fn(prediction, target) # [B,N,N]
assert cost_mat.shape == (B, N, N), (
"The pairwise loss function must enumerate all pairs and output "
"a scalar loss for each pair!")

# mask out any cost with mask values=0
if target_mask is not None:
target_mask = target_mask.unsqueeze(1) # [B,1,N]
cost_mat = cost_mat * target_mask
if cost_mat is None:
cost_mat = matching_cost_mat

with torch.no_grad():
B, N = matching_cost_mat.shape[:2]
max_cost = matching_cost_mat.max() + 1.
# [B*N, B*N]
max_cost = cost_mat.max() + 1.
big_cost_mat = torch.block_diag(*list(cost_mat - max_cost))
# fill in all off-diag entries with a max cost
big_cost_mat = big_cost_mat + max_cost
# Subtract a max cost from all entries before building the block-diagonal
# matrix so that no off-block matchings will be optimal.
big_cost_mat = torch.block_diag(
*list(matching_cost_mat - max_cost))
np_big_cost_mat = big_cost_mat.cpu().numpy()
# col_ind: [B*N]
row_ind, col_ind = linear_sum_assignment(np_big_cost_mat)
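
For reference, the block-diagonal trick above batches all B assignment problems into a single ``linear_sum_assignment`` call: subtracting a value larger than the maximum cost from every per-sample cost matrix makes each diagonal block strictly negative while off-block entries stay at 0, so the optimal assignment never pairs entries from different samples. A minimal standalone sketch of the same idea (names are illustrative, not part of this diff):

    import torch
    from scipy.optimize import linear_sum_assignment

    B, N = 2, 3
    cost = torch.rand(B, N, N)  # per-sample matching costs

    max_cost = cost.max() + 1.
    # Diagonal blocks hold (cost - max_cost) < 0; all off-block entries are 0,
    # so the optimal assignment never matches across samples.
    big_cost = torch.block_diag(*list(cost - max_cost))  # [B*N, B*N]
    row_ind, col_ind = linear_sum_assignment(big_cost.numpy())
    # Recover the per-sample matching: prediction n of sample b is matched to
    # target (col_ind[b * N + n] - b * N).
    matching = (torch.as_tensor(col_ind).reshape(B, N)
                - torch.arange(B).unsqueeze(-1) * N)
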
@@ -777,7 +771,7 @@ def forward(self,
optimal_loss = optimal_loss.mean(-1)
elif self._reduction == 'sum':
optimal_loss = optimal_loss.sum(-1)
return optimal_loss
return optimal_loss, col_ind.squeeze(-1)

def __call__(self, *args, **kwargs):
return self.forward(*args, **kwargs)
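
With this change the matcher no longer computes the pairwise costs itself: the caller builds the cost matrix (for example with ``torch.cdist``, as the updated tests below do) and gets back both the reduced loss and the matching indices. A minimal usage sketch mirroring those tests (variable names are illustrative):

    import torch
    from alf.utils import losses

    prediction = torch.rand([2, 5, 4], requires_grad=True)  # [B, N, ...]
    target = torch.rand([2, 5, 4])

    # Pairwise L1 distances between predicted and target elements: [B, N, N]
    matching_cost_mat = torch.cdist(prediction, target, p=1)

    matcher = losses.BipartiteMatchingLoss(reduction='none')
    loss, matching = matcher(matching_cost_mat)  # loss: [B, N], matching: [B, N]
    loss.mean().backward()

    # Following DETR, the matching cost may differ from the training loss, e.g.
    # match on the L1 cost but compute the loss from an L2 cost:
    l2_cost = torch.cdist(prediction, target, p=2)
    loss2, _ = matcher(matching_cost_mat, cost_mat=l2_cost)
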
128 changes: 120 additions & 8 deletions alf/utils/losses_test.py
@@ -91,7 +91,8 @@ def test_loss_shape(self, reduction):
target = torch.rand([2, 5, 4])

matcher = losses.BipartiteMatchingLoss(reduction=reduction)
loss = matcher(prediction, target)
cost_mat = torch.cdist(prediction, target, p=1)
loss, _ = matcher(cost_mat)
if reduction == 'none':
self.assertEqual(loss.shape, (2, 5))
else:
@@ -104,9 +105,9 @@ def test_forward_loss(self):
requires_grad=True)
target = torch.tensor([[[0.9, 0.9, 0.9], [0, 0, 0.1]],
[[0.1, 0.1, 0.1], [0.5, 0.6, 0.5]]])
matcher = losses.BipartiteMatchingLoss(
pair_loss_fn=partial(torch.cdist, p=1), reduction='none')
loss = matcher(prediction, target)
matcher = losses.BipartiteMatchingLoss(reduction='none')
cost_mat = torch.cdist(prediction, target, p=1)
loss, _ = matcher(cost_mat)
self.assertTrue(loss.requires_grad)
self.assertTensorClose(loss, torch.tensor([[0.1, 0.3], [0.3, 1.4]]))

@@ -162,8 +163,7 @@ def test_loss_training(self):
optimizer = torch.optim.Adam(list(model.parameters()), lr=1e-3)
epochs = 10
batch_size = 100
matcher = losses.BipartiteMatchingLoss(
pair_loss_fn=partial(torch.cdist, p=2), reduction='mean')
matcher = losses.BipartiteMatchingLoss(reduction='mean')
for _ in range(epochs):
idx = torch.randperm(tr_inputs.shape[0])
tr_inputs = tr_inputs[idx]
@@ -175,7 +175,8 @@
b_target = tr_target[i:i + batch_size]
b_pred = model(b_inputs) # [b,N+1,1]
b_pred = b_pred[:, 1:, :]
loss = matcher(b_pred, b_target).mean()
cost_mat = torch.cdist(b_pred, b_target, p=2.)
loss = matcher(cost_mat)[0].mean()
loss.backward()
optimizer.step()
l.append(loss)
@@ -185,14 +186,125 @@

val_pred = model(val_inputs)
val_pred = val_pred[:, 1:, :]
val_loss = matcher(val_pred, val_target)
cost_mat = torch.cdist(val_pred, val_target, p=2.)
val_loss = matcher(cost_mat)[0]

print("Validation prediction - inputs")
print(torch.round(val_pred[:3] - val_inputs[:3, :1, :], decimals=2))

print("Validation loss: ", val_loss.mean())
self.assertLess(float(val_loss.mean()), 0.15)

def test_loss_training_discrete_target(self):
"""A simple toy task for testing bipartite matching loss on discrete
target variables.
The inputs are sampled from some fixed random Gaussians, each Gaussian
representing a different object class. The target values are Gaussian ids,
but shuffled in a random order for each sample.
"""
samples_n = 10200
M, N, D = 3, 5, 10
mean = torch.randn((N, D)).unsqueeze(0) # [1,N,D]
std = torch.ones_like(mean) * 0.5 # [1,N,D]

inputs = torch.normal(
mean.repeat(samples_n, 1, 1), std.repeat(samples_n, 1,
1)) # [samples_n,N,D]
target = torch.arange(N).unsqueeze(0).repeat(samples_n,
1) # [samples_n,N]

# randomly shuffle
idx = torch.argsort(torch.randn(samples_n, N), dim=1)
inputs = torch.gather(
inputs, dim=1, index=idx.unsqueeze(-1).expand(-1, -1, D))
target = torch.gather(target, dim=1, index=idx)
# Only take the first M objects for each sample
inputs = inputs[:, :M, :]
target = target[:, :M]
# shuffle the target again ...
idx = torch.argsort(torch.randn(samples_n, M), dim=1)
target = torch.gather(target, dim=1, index=idx)

d_model = 64
transform_layers = []
for i in range(3):
transform_layers.append(
alf.layers.TransformerBlock(
d_model=d_model,
num_heads=3,
memory_size=M * 2, # input + queries
positional_encoding='abs' if i == 0 else 'none'))
model = torch.nn.Sequential(*transform_layers, alf.layers.FC(
d_model, N))
input_fc = alf.layers.FC(D, d_model)

queries = torch.nn.Parameter(torch.Tensor(M, d_model))
torch.nn.init.normal_(queries)

val_n = 200
tr_inputs = inputs[:-val_n, ...]
val_inputs = inputs[-val_n:, ...]
tr_target = target[:-val_n, ...]
val_target = target[-val_n:, ...]

def _compute_cost_mat(p, t):
p = torch.nn.functional.log_softmax(p, dim=-1)
oh_t = torch.nn.functional.one_hot(
t, num_classes=N).to(torch.float32)
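# cost[b, n, m] = -log p[b, n, t[b, m]]: negative log-likelihood of
# prediction n explaining target object m.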
return -torch.einsum('bnk,bmk->bnm', p, oh_t)

optimizer = torch.optim.Adam(
list(model.parameters()) + list(input_fc.parameters()) + [queries],
lr=1e-3)
epochs = 5
batch_size = 100
matcher = losses.BipartiteMatchingLoss(reduction='mean')
for _ in range(epochs):
idx = torch.randperm(tr_inputs.shape[0])
tr_inputs = tr_inputs[idx]
tr_target = tr_target[idx]
l = []
for i in range(0, idx.shape[0], batch_size):
optimizer.zero_grad()
b_inputs = tr_inputs[i:i + batch_size]
b_target = tr_target[i:i + batch_size]
b_inputs = input_fc(b_inputs)
b_queries = queries.unsqueeze(0).expand(
b_inputs.shape[0], -1, -1)
b_inputs = torch.cat([b_inputs, b_queries], dim=1) # [b,2M,..]
b_pred = model(b_inputs) # [b,2M,N]
b_pred = b_pred[:, M:, :]
cost_mat = _compute_cost_mat(b_pred, b_target)
loss = matcher(cost_mat)[0].mean()
loss.backward()
optimizer.step()
l.append(loss)
print("Training loss: ", sum(l) / len(l))

val_fc_out = input_fc(val_inputs)
val_queries = queries.unsqueeze(0).expand(val_fc_out.shape[0], -1, -1)
val_fc_out = torch.cat([val_fc_out, val_queries], dim=1) # [b,2M,..]
val_pred = model(val_fc_out)
val_pred = val_pred[:, M:, :]
cost_mat = _compute_cost_mat(val_pred, val_target)
val_loss = matcher(cost_mat)[0]

print("Cluster mean")
print(mean)

print("Validation inputs")
print(val_inputs[:5])

print("Validation prediction")
print(torch.argmax(val_pred, dim=-1)[:5])

print("Validation target")
print(val_target[:5])

print("Validation loss: ", val_loss.mean())
self.assertLess(float(val_loss.mean()), 0.01)


if __name__ == '__main__':
logging.set_verbosity(logging.INFO)
alf/utils/{pprint.py → pretty_print.py}
File renamed without changes.
