
Fine-tune with a new object point cloud dataset #68

Open
noahcao opened this issue Mar 25, 2024 · 5 comments

Comments


noahcao commented Mar 25, 2024

Hi @ZENGXH ,

I was trying to fine-tune LION from the weights in unconditional/all55/checkpoints/epoch_10999_iters_2100999.pt, using the config file unconditional/all55/cfg.yml that you provide.

My basic idea is to freeze the weights of the VAE encoder and decoder and fine-tune only the two priors, imitating the behavior in train_2prior.py. I did the necessary preprocessing of the data points and used the pre-trained VAE to make sure that the input point clouds can be reconstructed.
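Freezing the VAE and optimizing only the two priors can be sketched in plain PyTorch as below. `ToyLION`, its `vae`/`priors` attributes, and `freeze_vae_and_build_optimizer` are hypothetical stand-ins, not LION's actual classes. The point is only that the VAE parameters get `requires_grad=False` and are never handed to the optimizer.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for LION: a "VAE" plus two "priors" (not the real modules)
class ToyLION(nn.Module):
    def __init__(self):
        super().__init__()
        self.vae = nn.Linear(8, 8)
        self.priors = nn.ModuleList([nn.Linear(8, 8), nn.Linear(8, 8)])

def freeze_vae_and_build_optimizer(model, lr=1e-4):
    # Freeze every VAE parameter so autograd never computes gradients for it
    for p in model.vae.parameters():
        p.requires_grad_(False)
    model.vae.eval()  # fix dropout / batch-norm behaviour inside the VAE
    # Only the prior parameters are passed to the optimizer
    prior_params = [p for p in model.priors.parameters() if p.requires_grad]
    return torch.optim.Adam(prior_params, lr=lr)

torch.manual_seed(0)
model = ToyLION()
opt = freeze_vae_and_build_optimizer(model)
vae_before = model.vae.weight.detach().clone()

# One toy training step: encode with the frozen VAE, update only the priors
x = torch.randn(4, 8)
with torch.no_grad():
    z = model.vae(x)
loss = sum(((prior(z) - z) ** 2).mean() for prior in model.priors)
opt.zero_grad()
loss.backward()
opt.step()
```

Since `model.train()` recursively switches every submodule back to training mode, `model.vae.eval()` needs to be called again after any `model.train()` call.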

However, the training does not go well, and the final results generated by demo.py look like this:

[Screenshot 2024-03-25 at 19:05:40: samples generated by demo.py after fine-tuning]

Here are the key components of the code I wrote:

```python
import torch
import torch.nn.functional as F

# timestep: diffusion step, 1 -> 1000

def gain_x_t(timestep, noise, x0):
    # Forward diffusion: perturb x0 into x_t at the given timestep
    t_p, var_t_p, m_t_p = self.iw_quantities(timestep)  # as in utils/diffusion_continuous.py
    x_t = m_t_p * x0 + torch.sqrt(var_t_p) * noise
    return t_p, x_t

# Encode with the pre-trained LION VAE; detach so no gradient reaches the VAE
x_start_obj_g, x_start_obj_l = LION.VAE.encode_obj(obj_points)
x_start_obj_g, x_start_obj_l = x_start_obj_g.detach(), x_start_obj_l.detach()
B = x_start_obj_g.shape[0]  # batch size

# Diffusion noise is standard Gaussian, eps ~ N(0, I)
noise = {}
noise['obj_g'] = torch.randn_like(x_start_obj_g)
noise['obj_l'] = torch.randn_like(x_start_obj_l)

t_p, x_t_obj_g = gain_x_t(timestep, noise['obj_g'], x_start_obj_g)
t_p, x_t_obj_l = gain_x_t(timestep, noise['obj_l'], x_start_obj_l)

# The local prior is conditioned on the style vector of the clean global latent
global_cond = LION.VAE.global2style(x_start_obj_g).detach()
pred_noise_g = LION.priors[0](x_t_obj_g, t_p, x0=None, clip_feat=None)
pred_noise_l = LION.priors[1](x_t_obj_l, t_p, x0=None, condition_input=global_cond, clip_feat=None)

# Epsilon-prediction loss for each prior
loss_g = F.mse_loss(pred_noise_g.view(B, -1), noise['obj_g'].view(B, -1), reduction='mean')
loss_l = F.mse_loss(pred_noise_l.view(B, -1), noise['obj_l'].view(B, -1), reduction='mean')
```
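For reference, here is a minimal sketch of the standard epsilon-prediction objective that this kind of training loop follows. It uses a discrete DDPM-style schedule rather than LION's utils/diffusion_continuous.py; the linear beta schedule and the `q_sample`/`eps_loss` names are assumptions for illustration only. The noise must be standard normal (`torch.randn_like`); `torch.rand_like` samples uniformly from [0, 1), which has mean 0.5 and variance 1/12.

```python
import torch
import torch.nn.functional as F

# Assumed DDPM-style linear beta schedule over T = 1000 steps;
# LION's actual schedule comes from its config and may differ.
T = 1000
betas = torch.linspace(1e-4, 0.02, T)
alphas_bar = torch.cumprod(1.0 - betas, dim=0)

def q_sample(x0, t, noise):
    """Forward process: x_t = sqrt(abar_t) * x0 + sqrt(1 - abar_t) * eps."""
    shape = (-1,) + (1,) * (x0.dim() - 1)  # broadcast per-sample coefficients
    m_t = alphas_bar[t].sqrt().view(shape)
    std_t = (1.0 - alphas_bar[t]).sqrt().view(shape)
    return m_t * x0 + std_t * noise

def eps_loss(model, x0):
    B = x0.shape[0]
    t = torch.randint(0, T, (B,))   # 0-based index for steps 1..1000
    noise = torch.randn_like(x0)    # eps ~ N(0, I), not U[0, 1)
    x_t = q_sample(x0, t, noise)
    pred = model(x_t, t)
    return F.mse_loss(pred.reshape(B, -1), noise.reshape(B, -1))

# Compare the two noise distributions
torch.manual_seed(0)
eps_gauss = torch.randn(100_000)
eps_unif = torch.rand(100_000)
print(eps_gauss.mean().item(), eps_gauss.var().item())  # close to 0 and 1
print(eps_unif.mean().item(), eps_unif.var().item())    # close to 0.5 and 1/12
```

A predictor that always outputs zeros, e.g. `eps_loss(lambda x_t, t: torch.zeros_like(x_t), torch.randn(4, 2048, 3))`, scores a loss of about 1 (the noise variance), so a trained prior should sit well below that.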

I can't find an obvious error on my side, and the training losses look fine to me. However, as shown above, the fine-tuned model can't generate valid point clouds.

[Screenshot 2024-03-25 at 19:17:33]

Also, the dataset I am using contains several object categories, and I don't use any CLIP feature as the condition. I assumed this should be fine, but could you confirm? It would be great if you could share any ideas! Thanks

Collaborator

ZENGXH commented Mar 26, 2024 via email

Author

noahcao commented Mar 26, 2024

Thanks for the prompt response!

Sure, this is a point cloud I sampled from my dataset:

This is its reconstruction using the function VAE.recont(points) (link):

[Screenshot 2024-03-26 at 18:21:44: reconstruction by VAE.recont]

In my opinion, the reconstruction looks pretty good.
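To back the visual impression with a number, one option is the Chamfer distance between each input cloud and its reconstruction from VAE.recont. Below is a minimal sketch using squared Euclidean distances averaged over points; this is one common convention, not necessarily the one LION's evaluation code uses, and the clouds here are random stand-ins rather than real data.

```python
import torch

def chamfer_distance(a, b):
    """Symmetric Chamfer distance between clouds a (B, N, 3) and b (B, M, 3).

    Squared Euclidean distances, averaged over points in each direction.
    """
    # Exact pairwise distances (no matmul shortcut, so identical points give exactly 0)
    d = torch.cdist(a, b, compute_mode="donot_use_mm_for_euclid_dist") ** 2  # (B, N, M)
    a_to_b = d.min(dim=2).values.mean(dim=1)  # each point in a -> nearest point in b
    b_to_a = d.min(dim=1).values.mean(dim=1)  # each point in b -> nearest point in a
    return a_to_b + b_to_a                    # shape (B,)

torch.manual_seed(0)
points = torch.randn(2, 1024, 3)
recon = points + 0.01 * torch.randn_like(points)  # stand-in for a VAE reconstruction
cd_same = chamfer_distance(points, points)
cd_recon = chamfer_distance(points, recon)
print(cd_same, cd_recon)
```

Comparing this value for your data against the same metric on the all55 validation clouds gives a rough sense of whether the VAE handles the new shapes as well as the original ones.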

Regarding your question:

> For the sampling, does it fail from the early iterations, or only after some iterations? The sampled images from different iterations are logged by the training code as well.

The generation quality gets worse very quickly. Before fine-tuning, I could sample point clouds with the pre-trained all55 weights you provided:

[Screenshot 2024-03-26 at 18:24:53: samples from the pre-trained all55 weights]

Author

noahcao commented Mar 26, 2024

Also, I didn't use the mixing_component function in my training. Would that matter a lot?

Author

noahcao commented Mar 26, 2024

Maybe this is the reason? #66 (comment).

I used the pre-trained all55 weights for both the priors and the VAE, then fine-tuned on top of them with my own dataset. But for training, I used the dataset definition from pointflow_datasets.

Collaborator

ZENGXH commented Mar 26, 2024

The same dataloader should work, but it would be good to check whether the loaded shapes are aligned with the validation point clouds of the all55 model.
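One way to run that check is to compare simple per-axis statistics of a batch from each source: a mismatch in centering, scale, or axis order (for example y-up versus z-up) shows up as a large gap. This is a generic sketch, not LION code; `val_points` and `new_points` are hypothetical placeholders for a batch of all55 validation clouds and a batch from the new dataset, and the 0.2 tolerance is an arbitrary choice.

```python
import torch

def axis_stats(points):
    """Per-axis mean and std of a batch of point clouds shaped (B, N, 3)."""
    flat = points.reshape(-1, 3)
    return flat.mean(dim=0), flat.std(dim=0)

def alignment_report(val_points, new_points, tol=0.2):
    """Flag axes whose mean or std differs by more than tol * (reference std)."""
    ref_mean, ref_std = axis_stats(val_points)
    new_mean, new_std = axis_stats(new_points)
    issues = []
    for axis, name in enumerate("xyz"):
        limit = tol * ref_std[axis].item()
        if abs(new_mean[axis].item() - ref_mean[axis].item()) > limit:
            issues.append(f"{name}: mean {new_mean[axis].item():.3f} vs {ref_mean[axis].item():.3f}")
        if abs(new_std[axis].item() - ref_std[axis].item()) > limit:
            issues.append(f"{name}: std {new_std[axis].item():.3f} vs {ref_std[axis].item():.3f}")
    return issues

torch.manual_seed(0)
scale = torch.tensor([0.3, 0.1, 0.2])  # anisotropic toy shapes
val_points = torch.randn(8, 2048, 3) * scale
same_frame = torch.randn(8, 2048, 3) * scale
# y/z axes swapped and scaled by 1.5, mimicking a different normalization
swapped = (torch.randn(8, 2048, 3) * scale)[..., [0, 2, 1]] * 1.5

print(alignment_report(val_points, same_frame))  # []
print(alignment_report(val_points, swapped))
```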

mixing_component might matter: it seems to be the missing piece compared to train_2prior.py.
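For context on why this might matter: in LSGM-style latent diffusion, which LION's utils/diffusion_continuous.py appears to follow, the prior can use a mixed prediction. There, the network output is blended through a learned sigmoid weight with the analytic noise prediction for a standard Gaussian, sqrt(var_t) * x_t. The sketch below assumes that formulation; `mixed_prediction` and `mixing_logit` are illustrative names, so train_2prior.py should be checked for the exact form. If the pre-trained priors were trained with this mixing but are fine-tuned without it, the raw network output is pushed toward a different target than the one it learned, which could plausibly explain samples degrading quickly.

```python
import torch

def mixing_component(x_t, var_t):
    """Analytic eps prediction assuming the clean latents are exactly N(0, I).

    With x_t = m_t * x0 + sqrt(var_t) * eps, m_t**2 + var_t = 1 (variance
    preserving) and x0 ~ N(0, I), the conditional mean E[eps | x_t] is
    sqrt(var_t) * x_t.
    """
    return torch.sqrt(var_t) * x_t

def mixed_prediction(net_out, x_t, var_t, mixing_logit):
    # Learned blend (assumed LSGM-style): coeff -> 0 falls back to the
    # Gaussian term, coeff -> 1 uses the raw network output.
    coeff = torch.sigmoid(mixing_logit)
    return (1.0 - coeff) * mixing_component(x_t, var_t) + coeff * net_out

# Check: for Gaussian latents, the analytic term alone reaches MSE 1 - var_t
torch.manual_seed(0)
var_t = torch.tensor(0.3)
x0 = torch.randn(200_000)
eps = torch.randn(200_000)
x_t = torch.sqrt(1.0 - var_t) * x0 + torch.sqrt(var_t) * eps
mse = ((mixing_component(x_t, var_t) - eps) ** 2).mean()
print(mse.item())  # close to 0.7
```

Under this parameterization the network only has to model how the latent distribution departs from N(0, I), which is why training the raw output directly against eps is not equivalent.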
