procrustes alignment #2691
Hi! Thanks for your contribution, great first issue!
This sounds good. Would you be interested in adding it to TM? Create a draft PR and then we can help you finish it. 👼
Hi @heth27, I took a stab at implementing a batched version of your implementation:

```python
import torch
from scipy.spatial import procrustes as procrustes_scipy


def procrustus_batch(data1, data2):
    """Batched Procrustes disparity between (B, N, D) point sets, mirroring scipy.spatial.procrustes."""
    if data1.shape != data2.shape:
        raise ValueError("data1 and data2 must have the same shape")
    if data1.ndim == 2:
        # promote a single (N, D) pair to a batch of size 1
        data1 = data1[None, :, :]
        data2 = data2[None, :, :]
    # center each set and scale it to unit Frobenius norm (out-of-place, so the caller's tensors are not modified)
    data1 = data1 - data1.mean(dim=1, keepdim=True)
    data2 = data2 - data2.mean(dim=1, keepdim=True)
    data1 = data1 / torch.linalg.norm(data1, dim=(1, 2), keepdim=True)
    data2 = data2 / torch.linalg.norm(data2, dim=(1, 2), keepdim=True)
    try:
        u, w, vh = torch.linalg.svd(torch.matmul(data2.transpose(1, 2), data1).transpose(1, 2), full_matrices=False)
    except RuntimeError as err:  # torch.linalg.LinAlgError is a subclass of RuntimeError
        raise ValueError("SVD did not converge") from err
    rotation = torch.matmul(u, vh)
    scale = w.sum(1, keepdim=True)  # shape (B, 1)
    data2 = scale[:, None] * torch.matmul(data2, rotation.transpose(1, 2))
    disparity = (data1 - data2).square().sum(dim=(1, 2))
    return disparity


coords_3d_true = torch.rand(2, 10, 3)
coords_3d_prediction = torch.rand(2, 10, 3)
p2 = procrustus_batch(coords_3d_true, coords_3d_prediction)
print(p2)

# reference values from scipy, one pair at a time
for i in range(2):
    mtx1, mtx2, disparity = procrustes_scipy(coords_3d_true[i].numpy(), coords_3d_prediction[i].numpy())
    print(disparity)
```

For random inputs it seems to work when compared against scipy.
Hi @SkafteNicki, thank you! Please feel free to create a PR. How do you feel about also returning the rotation matrix or the transformed coordinates? They are used downstream to compute the Procrustes-aligned mean per joint position error (PA-MPJPE) in many human pose estimation tasks.
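To make the request concrete, here is a minimal sketch of an alignment that returns the aligned prediction, rotation and scale, and of PA-MPJPE built on top of it. The names `procrustes_align` and `pa_mpjpe` are hypothetical, not an existing TorchMetrics API. Two assumptions differ from `scipy.spatial.procrustes`: the alignment is done in the original coordinate units (no unit-norm standardization), and reflections are excluded by forcing `det(R) = +1`, which is how PA-MPJPE is commonly computed in pose estimation code.

```python
import torch


def procrustes_align(target, pred):
    """Hypothetical helper: similarity-align pred (B, N, D) onto target (B, N, D).

    Returns the aligned prediction (B, N, D), the rotation (B, D, D) and the scale (B,).
    """
    mu_t = target.mean(dim=1, keepdim=True)
    mu_p = pred.mean(dim=1, keepdim=True)
    t0 = target - mu_t
    p0 = pred - mu_p
    # cross-covariance of the centered sets, shape (B, D, D)
    m = p0.transpose(1, 2) @ t0
    u, s, vh = torch.linalg.svd(m)
    # assumption: forbid reflections by flipping the last singular direction when det(U Vh) = -1
    corr = torch.ones_like(s)
    corr[:, -1] = torch.sign(torch.linalg.det(u @ vh))
    rotation = u @ torch.diag_embed(corr) @ vh
    # optimal scale = trace(corr * S) / ||p0||_F^2
    scale = (s * corr).sum(dim=1) / p0.square().sum(dim=(1, 2))
    aligned = scale[:, None, None] * (p0 @ rotation) + mu_t
    return aligned, rotation, scale


def pa_mpjpe(target, pred):
    """Hypothetical: mean per-joint Euclidean error after Procrustes alignment, one value per sample."""
    aligned, _, _ = procrustes_align(target, pred)
    return torch.linalg.norm(aligned - target, dim=-1).mean(dim=-1)
```

Returning the rotation and scale alongside the aligned coordinates lets downstream code reuse the transform (for example to align other joints or a mesh) without solving the SVD again.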
@heth27 I would be fine with that. Maybe it makes sense to add an additional argument like …
This metric does not fit under any of our current subdomains; do you have a recommendation for which new domain it should go under?
Wikipedia suggests shape analysis. When I have more time, I plan to add metrics like Mean Per Joint Position Error (MPJPE) and Percentage of Correct Keypoints (PCK) as well; those would also fit this domain.
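For reference, a minimal sketch of the two follow-up metrics on batched (B, J, D) joint tensors. The function names `mpjpe` and `pck` are hypothetical, and the `threshold` semantics (an absolute distance in the coordinates' own units, compared with a strict `<`) are an assumption; 2D PCK variants commonly normalize the threshold by a head or bounding-box size instead.

```python
import torch


def mpjpe(pred, target):
    """Hypothetical: mean Euclidean distance over joints, one value per sample. Inputs are (B, J, D)."""
    return torch.linalg.norm(pred - target, dim=-1).mean(dim=-1)


def pck(pred, target, threshold):
    """Hypothetical: fraction of joints whose error is below `threshold` (assumed absolute units)."""
    dist = torch.linalg.norm(pred - target, dim=-1)  # (B, J)
    return (dist < threshold).float().mean(dim=-1)


# four joints at distances 0, 1, 2 and 3 from the target
target = torch.zeros(1, 4, 3)
pred = torch.tensor([[[0.0, 0, 0], [1.0, 0, 0], [2.0, 0, 0], [3.0, 0, 0]]])
print(mpjpe(pred, target))                 # tensor([1.5000])
print(pck(pred, target, threshold=1.5))    # tensor([0.5000])
```

Applying either function to the output of a Procrustes alignment gives the "PA" variants of the metrics.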
🚀 Feature
Spatial Procrustes alignment, a similarity test for two data sets.
Motivation
Procrustes alignment is a staple when computing metrics for 3D human pose estimation, but there seems to be no library that offers it for PyTorch, so presumably everyone maintains their own version.
Pitch
There is a variant in SciPy:
https://docs.scipy.org/doc/scipy/reference/generated/scipy.spatial.procrustes.html
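A short statement of what the SciPy variant computes, written as our reading of its documentation rather than a transcription of its source: both point sets $X, Y \in \mathbb{R}^{N \times D}$ are centered and scaled to unit Frobenius norm, and the second is then matched to the first with an orthogonal transform $R$ (rotation or reflection) and a scale $s$:

$$
\tilde X = \frac{X - \mathbf{1}\bar x^\top}{\lVert X - \mathbf{1}\bar x^\top \rVert_F}, \qquad
\tilde Y = \frac{Y - \mathbf{1}\bar y^\top}{\lVert Y - \mathbf{1}\bar y^\top \rVert_F}, \qquad
d(X, Y) = \min_{R^\top R = I,\ s} \lVert \tilde X - s\,\tilde Y R \rVert_F^2
$$

With the SVD $\tilde Y^\top \tilde X = U \Sigma V^\top$, the minimizer is $R = U V^\top$ and $s = \operatorname{tr}\Sigma$, so

$$
d(X, Y) = 1 + s^2 - 2\,s\,\operatorname{tr}\Sigma = 1 - (\operatorname{tr}\Sigma)^2 .
$$

The scale $\operatorname{tr}\Sigma$ is the sum of singular values, which is what the `w.sum(...)` line in the batched version computes.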
Alternatives
Additional context
The implementation I'm using (I don't know if it is any good):
Example usage:
The problem with this version is that it does not work on batches.