A PyTorch implementation of "Manifold Matching via Deep Metric Learning for Generative Modeling" (ICCV 2021).
Paper: https://arxiv.org/abs/2106.10777
Objective for metric learning:
triplet_loss = triplet_(ml_real_out,ml_real_out_shuffle,ml_fake_out_shuffle)
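This excerpt does not show how `triplet_` is constructed. A minimal sketch, assuming it is an instance of `torch.nn.TripletMarginLoss` (the margin, norm, and tensor sizes below are illustrative and not confirmed by the source), with random tensors standing in for the metric network's outputs:

```python
import torch

torch.manual_seed(0)
batch_size, embed_dim = 8, 10  # illustrative sizes, not taken from the source

# Assumption: triplet_ is a standard triplet margin loss; margin and p are guesses.
triplet_ = torch.nn.TripletMarginLoss(margin=1.0, p=2)

# Random stand-ins for netML(real_img) and netML(fake_img).
ml_real_out = torch.randn(batch_size, embed_dim, requires_grad=True)
ml_fake_out = torch.randn(batch_size, embed_dim)

# Shuffle within the batch: positives are other real samples, negatives are fakes.
ml_real_out_shuffle = ml_real_out[torch.randperm(batch_size)]
ml_fake_out_shuffle = ml_fake_out[torch.randperm(batch_size)]

# anchor = real, positive = shuffled real, negative = shuffled fake
triplet_loss = triplet_(ml_real_out, ml_real_out_shuffle, ml_fake_out_shuffle)
triplet_loss.backward()  # gradients flow back toward the metric network's output
```

Under this reading, the loss pulls embeddings of real samples toward each other and pushes embeddings of generated samples away by at least the margin.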
Objective for manifold matching with learned metric:
g_loss = p_dist + c_dist
where
ml_real_out = netML(real_img) # real data
ml_fake_out = netML(fake_img) # generated data
# shuffle in batch
r1 = torch.randperm(batch_size)
r2 = torch.randperm(batch_size)
ml_real_out_shuffle = ml_real_out[r1]  # permute rows; same result as indexing with r1[:, None] and reshaping
ml_fake_out_shuffle = ml_fake_out[r2]
# pairwise distances
pd_r = pairwise_distances(ml_real_out, ml_real_out)
pd_f = pairwise_distances(ml_fake_out, ml_fake_out)
# matching terms
p_dist = torch.dist(pd_r,pd_f,2) # matching 2-diameters
c_dist = torch.dist(ml_real_out.mean(0),ml_fake_out.mean(0),2) # matching centroids
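`pairwise_distances` is not defined in this excerpt. The sketch below assumes it returns the matrix of Euclidean distances between the rows of its two arguments, implemented here with `torch.cdist`. The repository's helper may differ (for example, it might return squared distances). Random tensors stand in for the network outputs, and the sizes are illustrative.

```python
import torch


def pairwise_distances(x, y):
    """Hypothetical stand-in: Euclidean distance between every row of x and every row of y."""
    return torch.cdist(x, y, p=2)


torch.manual_seed(0)
ml_real_out = torch.randn(8, 10)  # stand-in for netML(real_img)
ml_fake_out = torch.randn(8, 10)  # stand-in for netML(fake_img)

# Distance matrices within each batch, shape (8, 8).
pd_r = pairwise_distances(ml_real_out, ml_real_out)
pd_f = pairwise_distances(ml_fake_out, ml_fake_out)

# L2 distance between the two distance matrices (the "2-diameters" term).
p_dist = torch.dist(pd_r, pd_f, 2)
# L2 distance between the batch centroids.
c_dist = torch.dist(ml_real_out.mean(0), ml_fake_out.mean(0), 2)

g_loss = p_dist + c_dist
```

Both terms are scalars, so `g_loss` can be passed directly to an optimizer step for the generator once `ml_fake_out` comes from real generated images rather than random stand-ins.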
To train a model for unconditional generation, run:
python train.py
We also tried our objective on generating higher-resolution images, using a StyleGAN2 data generator and a simple metric generator. Implementation details can be found here. Below are randomly generated 512x512 samples on the FFHQ dataset at ~150K iterations:
@InProceedings{Dai_2021_ICCV,
author = {Dai, Mengyu and Hang, Haibin},
title = {Manifold Matching via Deep Metric Learning for Generative Modeling},
booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
month = {October},
year = {2021},
pages = {6587-6597}
}