Skip to content
This repository has been archived by the owner on Oct 31, 2023. It is now read-only.

Test in the wild #2

Open
chuhang opened this issue Nov 30, 2018 · 42 comments

@chuhang

chuhang commented Nov 30, 2018

Thanks for open-sourcing the code. Are you planning to release code for testing on an arbitrary input video, like the one shown in your "demo in the wild"?

@dariopavllo
Contributor

Hi,

Since the process is a bit involved, we did this manually. We might release automated code at some point, although we don't have a precise timeline.

I can give you some pointers on how to do it, but you'll have to write some code:

  • First, you have to predict the 2D keypoints from the video. You can use Detectron and the pretrained model on COCO (this is what we used). From an implementation perspective, you can either convert the video to a list of images (e.g. using ffmpeg) and run Detectron on the directory, or write custom code to run Detectron on videos directly.

  • For each frame, Detectron will emit a set of bounding boxes and the corresponding keypoints. If you assume that there is only one person in your video, you can just take the bounding box with the highest probability; otherwise you need to link people between frames (e.g. by bipartite matching on the bounding box overlap). We stuck with the simple single-person scenario (see the rough sketch below).

  • The next step is to predict the 2D keypoints using CPN, which requires the bounding boxes to be supplied externally. This is why we run Detectron first. CPN will give you slightly better poses, but you don't actually need to perform this step since it requires significant hacking of the CPN code. Detectron poses are already good enough.

  • Finally, you have to create a "fake" dataset (a NumPy archive with the same tree structure as the ones we provided for Human3.6M or HumanEva) and run our code using a model trained on COCO poses from Detectron. We haven't yet released this pretrained model to the public, but we're going to do it in a few days (so stay tuned).
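
Roughly, the per-frame logic for the single-person case looks something like this (untested sketch; cls_boxes and cls_keyps are assumed to be the raw outputs of Detectron's im_detect_all for one frame):

    import numpy as np

    def best_person(cls_boxes, cls_keyps):
        # class index 1 is 'person'; column 4 of each box is its confidence score
        person_boxes = cls_boxes[1]
        if len(person_boxes) == 0:
            return None, None  # no detection in this frame
        best = np.argmax(person_boxes[:, 4])
        # keypoints of that detection: shape (4, 17) -> rows are x, y, logit, probability
        return person_boxes[best], np.asarray(cls_keyps[1][best])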

Let me know if you need more help!

@Dene33

Dene33 commented Dec 2, 2018

Really cool results, guys. Testing on an arbitrary input video/images is a must-have for such a smooth implementation. :) Hope to see it soon!

Could you please describe the last step in more detail? Where can I check the NumPy archive's structure? Is the pretrained model this one: https://s3.amazonaws.com/video-pose-3d/pretrained_h36m_cpn.bin ?

@dariopavllo
Contributor

dariopavllo commented Dec 5, 2018

You can download the pretrained model here:
https://dl.fbaipublicfiles.com/video-pose-3d/d-pt-243.bin

And the corresponding 2D detections are here (just for completeness -- you don't really need them to generate demos in the wild):
https://dl.fbaipublicfiles.com/video-pose-3d/data_2d_h36m_detectron_pt_coco.npz

Specifically, this one (currently unlisted in the documentation) is trained on the poses generated by the pretrained Detectron model. In Detectron, you have to use the configuration 12_2017_baselines/e2e_keypoint_rcnn_R-101-FPN_s1x.yaml.

An important detail of this model is that it takes 3 coefficients per joint (instead of two), namely x, y, and the probability score from the heatmap. If you export the output tensors from Detectron, you will notice that there are 4 coefficients per joint (x, y, logit score before softmax, probability score after softmax), so you just need to take the indices 0, 1, and 3. You don't need to change anything in our code, since it already supports this kind of format.
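
In code, that selection is just the following (sketch; kp_raw is assumed to be the raw (4, 17) keypoint array exported by Detectron for one detected person):

    import numpy as np

    kp = np.asarray(kp_raw)    # (4, 17): x, y, logit score, probability score
    kp_xyp = kp[[0, 1, 3]].T   # (17, 3): x, y, probability -- the format expected by our model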

@dariopavllo
Contributor

The format of our datasets is straightforward, e.g.

data = np.load('data_2d_h36m_cpn_ft_h36m_dbb.npz')
actions = data['positions_2d'].item()
print(actions['S1']['Walking 1'][0].shape) # (3477, 17, 2)

The first index corresponds to the subject, the second to the action, and the third to the camera index.
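
If you want to build such an archive for your own video, a minimal sketch (the subject/action names and the output filename are arbitrary placeholders; keypoints is assumed to be a float32 array of shape (num_frames, 17, 3) holding x, y and the probability score):

    import numpy as np

    # keypoints: (num_frames, 17, 3) array assembled from the per-frame Detectron output
    positions_2d = {'S1': {'Walking': [keypoints.astype('float32')]}}  # one camera -> a list of length 1
    np.savez_compressed('data_2d_h36m_detectron_pt_coco_custom.npz', positions_2d=positions_2d)

    # sanity check (allow_pickle may be required on recent NumPy versions)
    data = np.load('data_2d_h36m_detectron_pt_coco_custom.npz', allow_pickle=True)
    print(data['positions_2d'].item()['S1']['Walking'][0].shape)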

@Dene33

Dene33 commented Dec 6, 2018

So, summarizing:

  1. Convert the desired video to images and process all the images with Detectron (with the config provided);
  2. Pack all the Detectron estimations into an .npz file, for example our_detectron_2d_poses.npz;
  3. Adjust the code to load our .npz file:
    data = np.load('our_detectron_2d_poses.npz')

So at the end we'll have a command for visualization like this (with the adjusted model and path_to_video):
python run.py -k cpn_ft_h36m_dbb -arc 3,3,3,3,3 -c checkpoint --evaluate d-pt-243.bin --render --viz-subject S1 --viz-action Walking --viz-camera 0 --viz-video path_to_video --viz-output output.gif --viz-size 3 --viz-downsample 2 --viz-limit 60

Is everything right, or am I missing something?

P.S. Where do the names of actions come from (Detectron names the actions, right?)? Where can I get the full list and maybe more information? How should I define --viz-action at the end (how can I actually know how Detectron named the action)? Thank you!

@dariopavllo
Contributor

Yes, you got the general idea. Regarding points 2 and 3, the easiest (but hacky) solution would be to just modify our dataset (e.g. data_2d_h36m_detectron_pt_coco.npz) and replace one random action (e.g. Walking of subject S1) with the points provided by Detectron, but you are free to create a new one from scratch. You don't have to modify our code anyway. If you create a new one, make sure to specify it with the -k argument.

The action names are the original ones from Human3.6M, but they are arbitrary. You can name them as you wish.
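
A rough sketch of the hacky route described above (my_keypoints is assumed to be a (num_frames, 17, 3) float32 array built from the Detectron output):

    import numpy as np

    data = np.load('data_2d_h36m_detectron_pt_coco.npz', allow_pickle=True)
    actions = data['positions_2d'].item()
    # overwrite one camera of one action with your own detections
    actions['S1']['Walking 1'][0] = my_keypoints.astype('float32')
    # if the archive contains keys other than positions_2d, copy them over as well
    np.savez_compressed('data_2d_h36m_detectron_pt_coco.npz', positions_2d=actions)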

@bucktoothsir

So the only difference between https://s3.amazonaws.com/video-pose-3d/d-pt-243.bin and https://s3.amazonaws.com/video-pose-3d/pretrained_h36m_cpn.bin is the way the 2D joints are obtained?

@dariopavllo
Contributor

The difference is that one is trained on 2D poses from Detectron (pre-trained on COCO) and the other is trained on 2D poses from CPN (fine-tuned on Human3.6M). The architecture of the model is the same, the input on which it is trained is the only thing that changes.

@Dene33

Dene33 commented Dec 7, 2018

I've managed to run Detectron in Google Colab. You can check the notebook here. As you can see, Detectron returns .jpg or .pdf as output (I didn't check other formats). So what format exactly do we need to have at the end to pack all of the output files into the .npz? Is there some kind of argument in Detectron, or should we adjust the code to extract the bounding boxes and 2D joints?

I plan to create the .ipynb notebook for VideoPose3D to run in Google Colab as well, so help here is really appreciated. Thank you!

@Godatplay

Godatplay commented Dec 7, 2018

@Dene33 FYI your Download Files cell is not working. First, your endswith is missing a tuple; it should be:
if file.endswith(('.jpg', '.pdf')):
Second, once that is fixed, there's a file-not-found error because the files are output with .jpg.jpg for some reason.

@Godatplay

There's one specific area where I'm slightly confused. It seems like infer_simple.py only visualizes inference; it doesn't emit any other data, correct? test_net.py could be another option, but it takes its input folder from the config file. The config file you suggested we use wants to pull images from coco_2014_minival.

Did you mean to use test_net.py and start with that config file but customize it as needed? Or is there some other process we should use that I'm missing?

@dariopavllo
Contributor

You are right about Detectron. The script infer_simple.py does not export raw data, but only a pdf, so you need to modify the code. The change is very easy: you just need to export cls_keyps and cls_boxes. You use cls_boxes to select the bounding box with the highest probability, and then you select the corresponding keypoints in cls_keyps.

I understand that the whole process is a bit involved. The goal of the demos in the wild was to show that our method generalizes well to unseen videos, and we didn't initially plan to release inference code on arbitrary videos since it requires ad-hoc scripts. Due to popular demand, I might release some scripts/diffs in a few days/weeks, so you don't have to reinvent the wheel. :)

The way I implemented it was to modify infer_simple.py to work directly on videos. The easiest and most efficient approach is to call ffmpeg to decode the video to a stream, and read it frame by frame. Then the output bounding boxes / keypoints are just concatenated and saved to a NumPy array.
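
Something along these lines (untested sketch; the frame size is assumed to be known in advance, e.g. via ffprobe):

    import subprocess
    import numpy as np

    def read_frames(video_path, width, height):
        # decode the video to raw BGR frames on stdout and yield them one by one
        cmd = ['ffmpeg', '-i', video_path, '-f', 'rawvideo', '-pix_fmt', 'bgr24', 'pipe:']
        pipe = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL, bufsize=10**8)
        frame_size = width * height * 3
        while True:
            raw = pipe.stdout.read(frame_size)
            if len(raw) < frame_size:
                break
            yield np.frombuffer(raw, dtype=np.uint8).reshape((height, width, 3))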

@Godatplay

Thanks for your reply! After going around in a couple of circles, that's actually what I'm doing now, although I was just about to get to the point where I needed to test the base keypoint format vs the class format. Thanks for the extra info to save that trouble :)

@Godatplay

Godatplay commented Dec 12, 2018

I may have an archive set up, but I'm running into problems actually running it. First of all, I had to transpose the array before saving it. The keypoints have all the joints together per value, whereas your archive has all the values together per joint. Second, I had to wrap my entire ndarray in a list to match the structure of the existing archive. Then, in the Interpreter while testing here, I replace the existing action and resave the npz, which is now ~840 MB instead of ~615 MB for some reason. The npz file for my own clip data is 86 KB, so is it just that I didn't save compressed and you did?

Anyway, the first issue I ran into when actually running my test was that my array shape was greater than mocap_length, so I tried just commenting that part out, since it seemed to be comparing it to positions_3d, which obviously isn't going to match. But that led to an AssertionError in model.py line 66:
assert x.shape[-1] == self.in_features
Any suggestions?

@Godatplay

Godatplay commented Dec 12, 2018

Ah, reading the code more, it looks like the list I had to wrap everything in is for the cameras. And for reference, the original dataset:

>>> actions['S1']['Walking 1'][0].shape
(3477, 17, 3)

my replacement:

>>> myactions['S1']['Walking 1'][0].shape
(320, 17, 4)

[edit: expanded shape to show all 3 dimensions]

@Godatplay

Godatplay commented Dec 12, 2018

I've been (clumsily, heh) trying to debug my issue, and I am running into a weird problem that may be related to the AssertionError; I can't tell yet. Whenever I try to save my camera list - i.e. the keypoints numpy array simply wrapped in a list as the only item - that list gets converted into another dimension of the enclosed numpy array. In other words, my [(320, 17, 4)] turns into (1, 320, 17, 4). And I have yet to find a way, with code at least, to prevent that from happening so that the list gets saved to the archive as-is instead of automatically getting converted.

My reply above shows the expected shape only because I used the interpreter to manually build my data and replace the existing action. For some reason it worked in the interpreter, but not in code? Here's my latest code attempt:

    list_with_camera = list()
    list_with_camera.append(keypoint_array) #keypoint_array is an ndarray of (320, 17, 4)
    dict = {'positions_2d': list_with_camera}
    np.savez("detections.npz", **dict)
    # results in (1, 320, 17, 4)

@Dene33

Dene33 commented Dec 12, 2018

You can reshape an ndarray of shape (1, 320, 17, 4) to (320, 17, 4) with yourndarray.reshape(320, 17, 4).

You are also using the Python built-in name dict as a variable name, which is not good practice.

Also, it's hard to debug code out of context. Could you please make a repo with everything you've done so far? I could look into it.

@Godatplay

Ok, I figured it out, whew. I misinterpreted the comment above about coefficients per joint to mean that having 4 would be fine. It's not fine; there can only be 3.

If you're just replacing an action in the existing dataset with another one of a different length, you do have to edit run.py to make sure ground truth is ignored.
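
Another option is to build a dummy 3D archive of matching length (that's what I ended up doing further down the thread), since the 3D values are only used as ground truth here; a rough sketch:

    import numpy as np

    num_frames = 320  # length of the 2D keypoint sequence
    # Human3.6M-style 3D data: (num_frames, 32, 3) per subject/action
    dummy_3d = {'S1': {'Walking': np.zeros((num_frames, 32, 3), dtype='float32')}}
    np.savez_compressed('data_3d_h36m.npz', positions_3d=dummy_3d)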

@Godatplay

[GIF: videopose3d_archive_org_vid_test]
Ha! So unfortunately this is not working in my test example. Over time, the poses bounce around the image. Detectron results look great for this frame, though, with "person 1.00" detected. Are the 3D positions of the action I replaced influencing the inference?

@bucktoothsir

Ha! So unfortunately this is not working in my test example. Over time, the poses bounce around the image. Detectron results look great for this frame, though, with "person 1.00" detected. Are the 3D positions of the action I replaced influencing the inference?

What do you mean by "the poses bounce around the image"?

@Godatplay

Godatplay commented Dec 12, 2018

[GIF: videopose3d-test]
Great question; apologies that it wasn't clearer. Here's a gif of me scrubbing along the timeline. The poses generated by the keypoints change to a new arbitrary pose every frame.

@Godatplay

Godatplay commented Dec 12, 2018

Ok I think at least part of my issue was that I was constructing the archive basically from scratch, instead of using the provided tools. So I've started over with hopefully a better pipeline:

  1. For each file inferred in Detectron, I add the keypoints and boxes to lists like so:
keypoints.append(cls_keyps)
boxes.append(cls_boxes)
  2. I save those out like so:
np.savez_compressed(os.path.join(args.output_dir, "detections.npz"), boxes=boxes, keypoints=keypoints)
  3. I rename the detections file to Walking.54138969.mp4.npz and place it in an S2 folder in my VideoPose3D data folder.
  4. I run python prepare_data_2d_h36m_generic.py -i . -o detectron_pt_coco_test which searches for preprocessed h36m npzs and finds the one I just placed.
  5. Looking at the resulting file data_2d_h36m_detectron_pt_coco_test.npz it seems to be missing the coefficients:
array({'S2': {'Walking': [array([], shape=(70, 17, 0), dtype=float32), None, None, None]}},
      dtype=object)
  6. If I ignore that and try to run I get an AssertionError:
AssertionError: Subject S1 is missing from the 2D detections dataset

And if I comment out lines 67-81 in run.py, another:

Traceback (most recent call last):
  File "run.py", line 88, in <module>
    kps[..., :2] = normalize_screen_coordinates(kps[..., :2], w=cam['res_w'], h=cam['res_h'])
  File "D:\VideoPose3D\common\camera.py", line 15, in normalize_screen_coordinates
    assert X.shape[-1] == 2
AssertionError

Any ideas on why the coefficients would get removed?

@dariopavllo
Contributor

As I said here, Detectron exports 4 features per joint (x, y, logit score before softmax, probability score after softmax). You need to select the indices 0, 1, and 3, so that you end up with a tensor of shape (17, 3).

I also think you missed a step. cls_boxes and cls_keyps contain all region proposals, including those that are not thresholded (this means that some of them may be wrong). This explains why your 2D pose is jumping across frames. Assuming that you have only one subject in your video, you should take the bounding box with the highest probability.

See this sample implementation (uncomment the last line and comment the other return statement).

@Godatplay

Godatplay commented Dec 13, 2018

Sorry if it wasn't clear, but before I streamlined everything, I was building my array more manually, choosing the highest probability, and it gave me a shape (n, 17, 3), where n was the number of frames/images. My gif above was based on that result. I must have had some kind of problem constructing my data, though.

But then I saw your data utilities and switched to using those. Thanks for the suggestion! I didn't make the connection; switching that line definitely fixes my issue with not having data.

For some reason I'm getting all 4 cameras, but I should be able to figure that out before too long here:

>>> len(action['S2']['Walking'])
4

@Godatplay

Godatplay commented Dec 13, 2018

Using the built-in tools, unfortunately the output looks the same as when I built the numpy archive myself. As was suggested above, I've swapped which line is commented out (lines 80 and 81) in import_detectron_poses in data_utils.py. Just to be sure, I also built my own 3D dataset with an action that has the same shape as my 2D action (except with 32 joints, so it gets reduced back down).

I tried passing in my data plainly and it didn't work, so just to be extra sure, before I run prepare_data_2d_h36m_generic which processes my data and picks the best box on its own, I pick the highest probability box in my infer file myself first. Here's the meat of my infer code for building the 2D and 3D data, which you can see is heavily based on infer_simple:

    keypoints = []
    boxes = []
    
    for i, im_name in enumerate(im_list):
        out_name = os.path.join(
            args.output_dir, '{}'.format(os.path.splitext(os.path.basename(im_name))[0] + '.' + args.output_ext)
        )
        logger.info('Processing {} -> {}'.format(im_name, out_name))
        im = cv2.imread(im_name)
        timers = defaultdict(Timer)
        t = time.time()
        with c2_utils.NamedCudaScope(0):
            cls_boxes, cls_segms, cls_keyps = infer_engine.im_detect_all(
                model, im, None, timers=timers
            )
        
        best_match = np.argmax(cls_boxes[1][:, 4])
        inferred_boxes = []
        inferred_boxes.append(cls_boxes[1][best_match])
        box_list = [[], np.array(inferred_boxes)]
        
        inferred_keypt_sets = []
        inferred_keypt_sets.append(np.array(cls_keyps[1][best_match]))
        keypt_set_list = [[], inferred_keypt_sets]
       
        keypoints.append(keypt_set_list)
        boxes.append(box_list)
    
    np.savez_compressed(os.path.join(args.output_dir, "Walking.54138969.mp4.npz"), boxes=boxes, keypoints=keypoints)
    
    output = {}
    output['S1'] = {}
    positions = []
    
    for i in range(len(keypoints)):
        pts = keypoints[i][1][0]
        pts = np.transpose(pts)
        pts = np.delete(pts, 3, 1)
        pts = np.append(pts, np.zeros((15, 3)), axis=0)
        positions.append(pts)
        
    positions = np.array(positions)
    output['S1']['Walking'] = positions.astype('float32')
    
    np.savez_compressed(os.path.join(args.output_dir, "data_3d_h36m.npz"), positions_3d=output)

Currently I'm assuming that the actual content of my 3D data is not important; I just started writing it this way to preserve formatting at first, but in the end I'm guessing I could have just filled a whole ndarray with zeros.

@bucktoothsir

Currently I'm assuming that the actual content of my 3D data is not important; I just started writing it this way to preserve formatting at first, but in the end I'm guessing I could have just filled a whole ndarray with zeros.

Your thought is right, because you don't have ground truth and you don't need to visualize it.

@Godatplay

Thank you, yeah that was my assumption, but I guess I got superstitious once I got errors when the 3D data didn't match the 2D. Making my own 3D data (like you mentioned in the other issue) means I have to change less code. But as I've gotten to know the codebase more, it seems like the errors just come down to the code not having the separation of concerns I initially hoped for.

@zhengyuezhi

Hello, I want to know where to find the videos corresponding to Human3.6M, or the videos you ran detection on.

@feiwangoooh

Detectron's keypoints do not correspond to the predicted 3D keypoints (no eyes, no ears). The figure above confuses me. Does the network change the keypoint representation?

@Godatplay

Thanks for sharing this. Interesting, it looks like the ground truth's root motion is being used in the reconstruction. That would definitely mess up the projection to 3D.

@feiwangoooh

Thanks for sharing this. Interesting, it looks like the ground truth's root motion is being used in the reconstruction. That would definitely mess up the projection to 3D.

Thanks, I fixed my problem; it was an issue with the 2D keypoint data.

@Godatplay

I finally got in-the-wild inference working with a quality that resembles the posted results. I used AlphaPose with PoseFlow to generate COCO-style keypoints, then I used the cpn-pt-243.bin model mentioned here along with a custom script I wrote to translate keypoints from a PoseFlow-tracked person to VideoPose3D format. My corresponding 3D npz is just ones in the proper shape.

Over the last couple days, I've had a lot of problems trying to visualize the results properly. The major issue has been that, regardless of skipping and frame limits, the entire input video is loaded into memory in raw format, which quickly runs out of RAM for any reasonable length video. I tried to customize that with ffmpeg options and other techniques, but just couldn't get it working. I gave up and just edited my input video to a smaller clip. The remaining issues are improper joint link indexing and the timing of the output is not synced up with the visualization timing. But the reconstruction looks quite similar to the posted results, with good prediction for the smoother movements and an averaging out of "sharper" movements, presumably due to such a high temporal window? Anyway, this is what I end up with:
[GIF: output]
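
The gist of the conversion script is roughly the following (simplified sketch; the JSON field names reflect the PoseFlow tracked output I had, so double-check them against your own results file, and the track id of the person to reconstruct has to be picked by inspection):

    import json
    import numpy as np

    # assumed PoseFlow layout: {frame_name: [{'keypoints': [x1, y1, s1, ...], 'idx': track_id}, ...]}
    with open('alphapose-results-forvis-tracked.json') as f:
        tracked = json.load(f)

    target_idx = 1  # track id of the person to reconstruct
    frames = []     # frames where this track is missing are simply skipped
    for frame_name in sorted(tracked.keys()):  # assumes zero-padded names that sort in temporal order
        for person in tracked[frame_name]:
            if person['idx'] == target_idx:
                kp = np.array(person['keypoints'], dtype='float32').reshape(-1, 3)  # (17, 3): x, y, score
                frames.append(kp)
                break

    keypoints = np.stack(frames)  # (num_frames, 17, 3), then packed into the positions_2d structure shown earlier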

@abhikhanna30

abhikhanna30 commented Jan 4, 2019

@Godatplay hey, your results look good. Even I am using the cpn_pt model. I just wanted to know: when you get your output 2D keypoints (COCO keypoints), which original dataset do you put them in? Because the author suggested we can use the original data_2d_h36m_gt.npz dataset and replace the action Walking / subject S1 data with our data. That's what I also tried to do, replacing one action and one subject's data with my own (N, 17, 2) COCO keypoints, but my results are a little off. Is that the correct approach?

@Godatplay

Godatplay commented Jan 4, 2019

One of these days I'll make a fork with my setup...until then I can say that I'm saving the npz I converted from COCO-style keypoints as Walking.54138969.mp4.npz in the $ROOT/data/S1 folder. This is because I'm using the included prepare_data_2d_h36m_generic.py to convert my file to the final dataset testing file, and it'll crawl the data folder looking for subject folders and within them npz files with the structure [action].[cam_id].mp4.npz. Note I've changed that file such that the center and width and height match my video, and have removed the other camera entries from the cam_map list. My -o parameter is cpn_pt_coco_test so that the resulting file is data_2d_h36m_cpn_pt_coco_test.npz because I intentionally wanted to include cpn and coco in the name. That ndarray then only has the Walking action for only the S1 subject.
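
In other words, roughly (paths relative to the repo; adjust if your layout differs):

    import os
    import shutil

    # place the converted detections where prepare_data_2d_h36m_generic.py expects them:
    # data/<subject>/<action>.<cam_id>.mp4.npz
    os.makedirs('data/S1', exist_ok=True)
    shutil.copy('detections.npz', 'data/S1/Walking.54138969.mp4.npz')
    # then, from the data folder:
    #   python prepare_data_2d_h36m_generic.py -i . -o cpn_pt_coco_test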

Since I've created the output you're referring to, I've switched to deepmatch-based PoseFlow instead of orb-based, which is supposed to be higher quality. But the results are jumbled up temporally. The pose of the first frame starts out correctly, but then it switches to posing the legs based on a different time of the clip, while the arms are referencing yet another time in the clip. Haven't had a chance to look into it yet, but it's certainly odd. I'm using a different part in this clip, so I'm not sure if that's related or not. (A visualization of the 2D detection looks great, FWIW)

@tobiascz

tobiascz commented Jan 27, 2019

I made a fork with my version of the code that runs on in-the-wild / your own videos. Here is a link to my fork. It's best to follow the Readme part I wrote, but it is at an early stage, so if you have any problems or remarks feel free to share!

I fixed the visualization problem of the 2D poses seen in the output created by @Godatplay.

I used the same in-the-wild video of the ice-skating girl, but my results are slightly worse than the authors'.

My output

[GIF: myOutput]

Authors' output

[GIF: authorsOutput]

I am not sure why my results are worse... If you have any idea let me know

@bucktoothsir

@tobiascz


Hi, you could check my issues. I ran into the same problem and figured it out:
#6
#23

@lxy5513

lxy5513 commented Jan 28, 2019

@tobiascz
Your 2D keypoint accuracy is poor. Maybe you can try AlphaPose or Simple Baselines for Human Pose Estimation to generate the 2D keypoints, then feed them to VideoPose3D.

BTW, could you please give me the download address of the original ski video?

@tobiascz

@lxy5513 Video Link

Thanks for your comment! I will continue this discussion in issue #6.

@lxy5513

lxy5513 commented Jan 28, 2019

@tobiascz, thanks.
I integrated the AlphaPose code into VideoPose3D: my videopose
Maybe you can use it as a reference.

@ihabkh

ihabkh commented Aug 6, 2019

I finally got in-the-wild inference working with a quality that resembles the posted results. I used AlphaPose with PoseFlow to generate COCO-style keypoints, then I used the cpn-pt-243.bin model mentioned here along with a custom script I wrote to translate keypoints from a PoseFlow-tracked person to VideoPose3D format. My corresponding 3D npz is just ones in the proper shape.


I'm using AlphaPose for 2D keypoint detection and PoseFlow for tracking. My idea is to visualize the 3D skeleton for tracked person X using VideoPose3D. I tried to build a script to convert the keypoints from PoseFlow into VideoPose3D format, but it is not working well. Could you please advise? Thanks.

@Alex-JYJ

Alex-JYJ commented May 4, 2020

As I said here, Detectron exports 4 features per joint (x, y, logit score before softmax, probability score after softmax). You need to select the indices 0, 1, and 3, so that you end up with a tensor of shape (17, 3).

I followed the suggestion to select the indices 0, 1, and 3, but the model in the line

model_pos_train = TemporalModel(poses_valid_2d[0].shape[-2], poses_valid_2d[0].shape[-1], dataset.skeleton().num_joints(),
only seems to accept shape (17, 2). Can I just use kps[..., :2] to get it to work?

@bobby20180331

cool.
