-
Notifications
You must be signed in to change notification settings - Fork 7
/
metrics.py
60 lines (48 loc) · 2.75 KB
/
metrics.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import numpy as np
from smplx import SMPL
import config
def compute_pve_neutral_pose_scale_corrected(predicted_smpl_shape, target_smpl_shape, gender):
    """
    Given predicted and target SMPL shape parameters, computes neutral-pose per-vertex error
    after scale-correction (to account for scale vs camera depth ambiguity).
    :param predicted_smpl_shape: predicted SMPL shape parameters tensor with shape (1, 10)
    :param target_smpl_shape: target SMPL shape parameters tensor with shape (1, 10)
    :param gender: gender of target, 'm' (male) or 'f' (female)
    :return: per-vertex error array with shape (1, 6890)
    :raises ValueError: if gender is not 'm' or 'f'
    """
    # Only build the SMPL model we actually need — constructing both genders on
    # every call is wasteful (model loading is expensive).
    if gender == 'm':
        smpl = SMPL(config.SMPL_MODEL_DIR, batch_size=1, gender='male')
    elif gender == 'f':
        smpl = SMPL(config.SMPL_MODEL_DIR, batch_size=1, gender='female')
    else:
        # Previously an unknown gender crashed later with an unrelated NameError;
        # fail fast with a clear message instead.
        raise ValueError("gender must be 'm' or 'f', got: {}".format(gender))
    # Get neutral-pose (default pose) vertices for both shapes.
    pred_smpl_neutral_pose_output = smpl(betas=predicted_smpl_shape)
    target_smpl_neutral_pose_output = smpl(betas=target_smpl_shape)
    pred_smpl_neutral_pose_vertices = pred_smpl_neutral_pose_output.vertices
    target_smpl_neutral_pose_vertices = target_smpl_neutral_pose_output.vertices
    # Rescale such that RMSD of predicted vertex mesh is the same as RMSD of target mesh.
    # This is done to combat scale vs camera depth ambiguity.
    pred_smpl_neutral_pose_vertices_rescale = scale_and_translation_transform_batch(pred_smpl_neutral_pose_vertices,
                                                                                    target_smpl_neutral_pose_vertices)
    # Compute PVE-T-SC: Euclidean distance per vertex after scale correction.
    pve_neutral_pose_scale_corrected = np.linalg.norm(pred_smpl_neutral_pose_vertices_rescale
                                                      - target_smpl_neutral_pose_vertices,
                                                      axis=-1)  # (1, 6890)
    return pve_neutral_pose_scale_corrected
def scale_and_translation_transform_batch(P, T):
    """
    Align each mesh in batch P to the corresponding mesh in T by similarity
    transform (uniform scale + translation only, no rotation).

    Each mesh in P is first centred at the origin and normalised to unit RMS
    vertex distance, then rescaled and shifted so that its centroid and RMS
    distance match those of the corresponding reference mesh in T.
    :param P: (batch_size, N, 3) batch of N-vertex 3D meshes to transform.
    :param T: (batch_size, N, 3) batch of N-vertex reference 3D meshes.
    :return: P transformed, same shape as P.
    """
    num_vertices = P.shape[1]
    # Centre P and compute its RMS vertex distance from the centroid.
    centroid_P = P.mean(axis=1, keepdims=True)
    centred_P = P - centroid_P
    rms_P = np.sqrt((centred_P ** 2).sum(axis=(1, 2), keepdims=True) / num_vertices)
    # Reference statistics from T.
    centroid_T = T.mean(axis=1, keepdims=True)
    centred_T = T - centroid_T
    rms_T = np.sqrt((centred_T ** 2).sum(axis=(1, 2), keepdims=True) / T.shape[1])
    # Normalise P to unit RMS, then apply T's scale and centroid.
    return (centred_P / rms_P) * rms_T + centroid_T