[Feature] Add PoseFormer backbone #1215

Open

wants to merge 24 commits into base: dev-0.26

Commits (24)
9b3baea
[Fix] Update mmcv installation CI and doc (#1205)
ly015 Feb 24, 2022
196debc
add deprecation message for deploy tools
QwQ2000 Feb 24, 2022
ee41593
change import warnings positions
QwQ2000 Feb 24, 2022
5142fba
do yapf
QwQ2000 Feb 24, 2022
6cadba0
do isort
QwQ2000 Feb 24, 2022
f64a14d
Add deprecation message for deploy tools (#1207)
QwQ2000 Feb 28, 2022
fd1e4fb
refactor dataset evaluation interface (#1209)
ly015 Feb 28, 2022
fcb75a6
[Feature] Add hrformer backbone (#1203)
zengwang430521 Mar 1, 2022
23d671e
[Fix] Fix data collate and scatter in inference (#1175)
ly015 Mar 1, 2022
33434b4
[Enhancement] api train support cpu training for mmcv<1.4.4 (#1161)
EasonQYS Mar 2, 2022
f198c80
[Feature] Switch to openmmlab pre-commit-hook for copyright check (#1…
ly015 Mar 2, 2022
1f449dd
PoseFormer backbone, head and config
QwQ2000 Mar 3, 2022
031e034
fix color channel order in visualization functions and docs (#1212)
ly015 Mar 3, 2022
1342818
[Feature] Add Windows CI (#1213)
ly015 Mar 4, 2022
40d5df7
Merge branch 'dev-0.24' into dev-0.24
QwQ2000 Mar 6, 2022
f8a2665
replace einops.rearrange
QwQ2000 Mar 8, 2022
ce2adbd
Merge branch 'dev-0.24' of github.com:QwQ2000/mmpose into dev-0.24
QwQ2000 Mar 8, 2022
82a2bb4
Merge branch 'dev-0.25' into dev-0.24
QwQ2000 Mar 8, 2022
73687ca
Add PoseFormer unit test.
QwQ2000 Mar 8, 2022
e035c23
Merge branch 'dev-0.24' of github.com:QwQ2000/mmpose into dev-0.24
QwQ2000 Mar 8, 2022
1d3cd3c
Match official weights & improve code clarity
QwQ2000 Mar 26, 2022
8302c2a
Update poseformer_h36m_81frame_cpn.py
QwQ2000 Apr 27, 2022
9878ac5
Update poseformer_h36m_81frame_cpn.py
QwQ2000 Apr 27, 2022
30fca23
Update poseformer_h36m_81frame_cpn.py
QwQ2000 Apr 27, 2022
150 changes: 150 additions & 0 deletions poseformer_h36m_81frame_cpn.py
@@ -0,0 +1,150 @@
_base_ = [
'../../../../_base_/default_runtime.py',
'../../../../_base_/datasets/h36m.py'
]
evaluation = dict(
interval=10, metric=['mpjpe', 'p-mpjpe'], key_indicator='MPJPE')

# optimizer settings
optimizer = dict(
type='Adam',
lr=2e-4,
weight_decay=0.1
)
optimizer_config = dict(grad_clip=None)
# learning policy
lr_config = dict(
policy='exp',
by_epoch=True,
gamma=0.98,
)

total_epochs = 130

log_config = dict(
interval=20,
hooks=[
dict(type='TextLoggerHook'),
# dict(type='TensorboardLoggerHook')
])

channel_cfg = dict(
num_output_channels=17,
dataset_joints=17,
dataset_channel=[
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16],
],
inference_channel=[
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16
])

# model settings
model = dict(
type='PoseLifter',
pretrained=None,
backbone=dict(
type='PoseFormer', num_frame=81, drop_path_rate=0.1),
keypoint_head=dict(
type='PoseFormerHead', loss_keypoint=dict(type='MPJPELoss')),
train_cfg=dict(),
test_cfg=dict(restore_global_position=True))

# data settings
data_root = 'data/h36m'
train_data_cfg = dict(
num_joints=17,
seq_len=81,
seq_frame_interval=1,
causal=False,
temporal_padding=False,
joint_2d_src='detection',
joint_2d_det_file=f'{data_root}/joint_2d_det_files/' +
'cpn_ft_h36m_dbb_train.npy',
need_camera_param=True,
camera_param_file=f'{data_root}/annotation_body3d/cameras.pkl',
)
test_data_cfg = dict(
num_joints=17,
seq_len=81,
seq_frame_interval=1,
causal=False,
temporal_padding=False,
joint_2d_src='detection',
joint_2d_det_file=f'{data_root}/joint_2d_det_files/' +
'cpn_ft_h36m_dbb_test.npy',
need_camera_param=True,
camera_param_file=f'{data_root}/annotation_body3d/cameras.pkl',
)

train_pipeline = [
dict(
type='GetRootCenteredPose',
item='target',
visible_item='target_visible',
root_index=0,
root_name='root_position',
remove_root=False),
dict(type='ImageCoordinateNormalization', item='input_2d'),
dict(
type='RelativeJointRandomFlip',
item=['input_2d', 'target'],
flip_cfg=[
dict(center_mode='static', center_x=0.),
dict(center_mode='root', center_index=0)
],
visible_item=['input_2d_visible', 'target_visible'],
flip_prob=0.5),
dict(type='PoseSequenceToTensor', item='input_2d', reshape=False),
dict(
type='Collect',
keys=[('input_2d', 'input'), 'target'],
meta_name='metas',
meta_keys=['target_image_path', 'flip_pairs', 'root_position'])
]

val_pipeline = [
dict(
type='GetRootCenteredPose',
item='target',
visible_item='target_visible',
root_index=0,
root_name='root_position',
remove_root=False),
dict(type='ImageCoordinateNormalization', item='input_2d'),
dict(type='PoseSequenceToTensor', item='input_2d', reshape=False),
dict(
type='Collect',
keys=[('input_2d', 'input'), 'target'],
meta_name='metas',
meta_keys=['target_image_path', 'flip_pairs', 'root_position'])
]

test_pipeline = val_pipeline

data = dict(
samples_per_gpu=128,
workers_per_gpu=2,
val_dataloader=dict(samples_per_gpu=128),
test_dataloader=dict(samples_per_gpu=128),
train=dict(
type='Body3DH36MDataset',
ann_file=f'{data_root}/annotation_body3d/fps50/h36m_train.npz',
img_prefix=f'{data_root}/images/',
data_cfg=train_data_cfg,
pipeline=train_pipeline,
dataset_info={{_base_.dataset_info}}),
val=dict(
type='Body3DH36MDataset',
ann_file=f'{data_root}/annotation_body3d/fps50/h36m_test.npz',
img_prefix=f'{data_root}/images/',
data_cfg=test_data_cfg,
pipeline=val_pipeline,
dataset_info={{_base_.dataset_info}}),
test=dict(
type='Body3DH36MDataset',
ann_file=f'{data_root}/annotation_body3d/fps50/h36m_test.npz',
img_prefix=f'{data_root}/images/',
data_cfg=test_data_cfg,
pipeline=test_pipeline,
dataset_info={{_base_.dataset_info}}),
)
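For reference, a minimal sketch of how a config like this is typically consumed in mmpose 0.x, assuming this PR is merged so that PoseFormer and PoseFormerHead are registered, and that the config file sits at its intended location in the repo so the _base_ paths resolve (the path below is a placeholder):

# Minimal sketch, not part of this diff. Assumes mmpose 0.x with this PR
# merged; replace the placeholder path with the config's real location.
from mmcv import Config
from mmpose.models import build_posenet

cfg = Config.fromfile('path/to/poseformer_h36m_81frame_cpn.py')
model = build_posenet(cfg.model)
print(type(model).__name__)  # expected: PoseLifter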
8 changes: 5 additions & 3 deletions mmpose/datasets/pipelines/pose3d_transform.py
@@ -446,8 +446,9 @@ class PoseSequenceToTensor:
item
"""

def __init__(self, item):
def __init__(self, item, reshape=True):
self.item = item
self.reshape = reshape

def __call__(self, results):
assert self.item in results
@@ -459,8 +460,9 @@ def __call__(self, results):
if seq.ndim == 2:
seq = seq[None, ...]

T = seq.shape[0]
seq = seq.transpose(1, 2, 0).reshape(-1, T)
if self.reshape:
T = seq.shape[0]
seq = seq.transpose(1, 2, 0).reshape(-1, T)
results[self.item] = torch.from_numpy(seq)

return results
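For clarity: the new reshape flag only changes the output layout. With the default reshape=True, a (T, K, C) pose sequence is flattened to (K*C, T) as before; with reshape=False, used by the PoseFormer config above, the sequence keeps its (T, K, C) shape. A small NumPy sketch of the two layouts, with shapes chosen to mirror the 81-frame, 17-joint H36M setting:

# Illustrative sketch with plain NumPy; mirrors the transform's reshape logic.
import numpy as np

seq = np.random.rand(81, 17, 2).astype(np.float32)  # (T, K, C)

# reshape=True (default): flatten joints and coords, keep time last -> (K*C, T)
T = seq.shape[0]
flat = seq.transpose(1, 2, 0).reshape(-1, T)
print(flat.shape)  # (34, 81)

# reshape=False (PoseFormer): the sequence passes through unchanged -> (T, K, C)
print(seq.shape)   # (81, 17, 2)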
3 changes: 2 additions & 1 deletion mmpose/models/backbones/__init__.py
@@ -9,6 +9,7 @@
from .mobilenet_v2 import MobileNetV2
from .mobilenet_v3 import MobileNetV3
from .mspn import MSPN
from .poseformer import PoseFormer
from .regnet import RegNet
from .resnest import ResNeSt
from .resnet import ResNet, ResNetV1d
@@ -30,5 +31,5 @@
'MobileNetV3', 'RegNet', 'ResNet', 'ResNetV1d', 'ResNeXt', 'SCNet',
'SEResNet', 'SEResNeXt', 'ShuffleNetV1', 'ShuffleNetV2', 'CPM', 'RSN',
'MSPN', 'ResNeSt', 'VGG', 'TCN', 'ViPNAS_ResNet', 'ViPNAS_MobileNetV3',
'LiteHRNet', 'V2VNet', 'HRFormer'
'LiteHRNet', 'V2VNet', 'HRFormer', 'PoseFormer'
]
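With this registration in place, the backbone can be built through the standard mmpose registry. A minimal sketch, assuming this PR is merged and reusing the backbone arguments from the config above:

# Minimal sketch, not part of this diff; assumes this PR is merged.
from mmpose.models import build_backbone

backbone = build_backbone(
    dict(type='PoseFormer', num_frame=81, drop_path_rate=0.1))
print(type(backbone).__name__)  # expected: PoseFormer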