A few questions about the training pipeline #5

Closed
avipartho opened this issue Jul 3, 2023 · 15 comments

@avipartho

Hi,

I read your paper and it is great work! Thanks for sharing your codebase with the community. As I was going through your code, I came across a few places where I would greatly appreciate your explanations/suggestions. Here are my questions:

  • What does the CE loss from here stand for, i.e. which of the 4 losses from the paper does it refer to?
  • Does the CE loss from here refer to the $l_p$ loss in the paper? If not, which loss from the paper does it refer to?
  • From these lines (line1, line2, line3), it looks like all tokens that are not part of the caption text or [IMG0] have been set to -100 so that they are ignored when calculating the loss. Is my understanding correct? If it is, how are we learning embeddings for the other [IMG{r}] tokens (r={2,3,...,8})?
@kohjingyu
Owner

Thanks for your kind words!

What does the CE loss from here stand for, i.e. which of the 4 losses from the paper does it refer to?

Does the CE loss from here refer to the $l_p$ loss in the paper? If not, which loss from the paper does it refer to?

This and the loss defined on line 508 make up $l_p$ in the paper (equation 2). This is the loss for training the model to produce the [IMG] tokens at the end of "caption-like" text. L506 and L508 are actually the same loss, since the same caption (e.g., "A picture of a dog [IMG0]...[IMG7]") is used for both retrieval and generation. This is why each has a 0.5 multiplier, so that they sum to $l_p$.
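To make the 0.5 weighting concrete, here is a minimal framework-free sketch. The `cross_entropy` helper and the toy logits are hypothetical and not taken from main.py; the sketch only assumes, as stated above, that the retrieval and generation branches score the same caption, so their two CE terms coincide.

```python
import math

IGNORE_INDEX = -100  # label value that is skipped when averaging the loss


def cross_entropy(logits, labels, ignore_index=IGNORE_INDEX):
    """Mean negative log-likelihood over positions whose label is not ignored."""
    total, count = 0.0, 0
    for row, label in zip(logits, labels):
        if label == ignore_index:
            continue
        # Numerically stable log-sum-exp, then NLL of the target entry.
        m = max(row)
        log_z = m + math.log(sum(math.exp(x - m) for x in row))
        total += log_z - row[label]
        count += 1
    return total / count


# Toy next-token logits over a 3-word vocabulary for a 3-position caption.
logits = [[2.0, 0.5, 0.1], [0.2, 1.5, 0.3], [0.1, 0.2, 3.0]]
labels = [0, 1, 2]

ce_ret = cross_entropy(logits, labels)  # CE term from the retrieval branch
ce_gen = cross_entropy(logits, labels)  # CE term from the generation branch (same caption)
l_p = 0.5 * ce_ret + 0.5 * ce_gen       # the two halves add up to a single CE term
```

With identical inputs the weighted sum equals either term on its own, which is the sense in which the two lines "make up" one loss.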

From these lines (line1, line2, line3), it looks like all tokens that are not part of the caption text or [IMG0] have been set to -100 so that they are ignored when calculating the loss. Is my understanding correct? If it is, how are we learning embeddings for the other [IMG{r}] tokens (r={2,3,...,8})?

That's right. The reason is that we force the generation of the r={2,3,...,8} tokens as the next 7 tokens whenever the model produces [IMG0]: we always need all 8 tokens for generation/retrieval, so a partial set of [IMG] tokens doesn't make sense. The embeddings of the [IMG2]...[IMG8] tokens are therefore only learnt through the other losses (in particular the generation loss $l_g$), when their embeddings/hidden states are used for computing the generation/retrieval objectives; $l_p$ doesn't affect them. So the model will never produce [IMG2]...[IMG8] organically, but their representations are still helpful for feeding into the GILLMapper module for image generation.
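The two mechanisms described here (masking labels during training, forcing the remaining tokens at inference) can be sketched in plain Python. This is an illustration only: `build_labels`, `force_img_tokens`, the token ids, and the padding id are all hypothetical, and the real code works on batched tensors.

```python
IGNORE_INDEX = -100  # positions with this label are skipped by the CE loss


def build_labels(input_ids, img_token_ids, pad_id):
    """Copy input_ids into labels, ignoring padding and every [IMG] token except the first."""
    rest_imgs = set(img_token_ids[1:])
    labels = []
    for tok in input_ids:
        if tok == pad_id or tok in rest_imgs:
            labels.append(IGNORE_INDEX)
        else:
            labels.append(tok)
    return labels


def force_img_tokens(generated_ids, img_token_ids):
    """At inference, append the remaining [IMG] tokens whenever the first one is produced."""
    out = []
    for tok in generated_ids:
        out.append(tok)
        if tok == img_token_ids[0]:
            out.extend(img_token_ids[1:])
    return out


# Hypothetical ids: 8 [IMG] tokens placed at the end of the vocabulary.
IMG_IDS = list(range(50266, 50274))
PAD_ID = 1

caption = [100, 200, 300] + IMG_IDS + [PAD_ID, PAD_ID]
labels = build_labels(caption, IMG_IDS, PAD_ID)

# The model only ever has to predict the first [IMG] token; the rest are appended.
forced = force_img_tokens([100, 200, IMG_IDS[0]], IMG_IDS)
```

In this sketch only the caption tokens and the first [IMG] token contribute to the loss, while the full block of 8 tokens still appears in the sequence, so their hidden states remain available to the other objectives.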

Hope that makes sense!

@avipartho
Author

Thanks for your quick response. Reopening this issue for another query regarding the pipeline (didn't want to unnecessarily create new issue).

If I am not wrong, this line makes the entire OPT embedding layer trainable. It is also evident from the param_count.txt file your scripts generate. However, according to the paper, only the [IMG] embedding matrix $E_{img}$ was supposed to be trainable. Did I miss anything here?

First few lines from param_count.txt :

| Module | Trainable | Shape | Param Count |
| model.logit_scale | True | () | 1 |
| model.lm.model.decoder.embed_tokens.weight | True | (50274, 4096) | 205,922,304 |

@avipartho avipartho reopened this Jul 4, 2023
@kohjingyu
Owner

You're right, they become trainable, which is why we zero out the gradients of the non-[IMG] embeddings here:

gill/main.py

Lines 578 to 587 in 53fdcf2

# Zero out gradients of the embedding matrix outside of [IMG].
for param in model.module.model.input_embeddings.parameters():
    assert param.grad.shape[0] == len(tokenizer)
    # Keep other embeddings frozen.
    mask = torch.zeros((param.grad.shape[0], 1)).to(param.grad)
    for ret_idx in args.retrieval_token_idx:
        mask[ret_idx] = 1
    for gen_idx in args.gen_token_idx:
        mask[gen_idx] = 1
    param.grad = param.grad * mask

This is not super ideal, but I think it is overall cleaner than concatenating a trainable embedding matrix with a frozen one.
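The effect of the mask can be illustrated without PyTorch. The following sketch uses nested lists in place of tensors and a plain SGD update; `mask_embedding_grad` and `sgd_step` are hypothetical helpers, and the choice of SGD is an assumption made only to keep the arithmetic visible.

```python
def mask_embedding_grad(grad, trainable_rows):
    """Zero every row of `grad` except those listed in `trainable_rows`."""
    keep = set(trainable_rows)
    return [row if i in keep else [0.0] * len(row) for i, row in enumerate(grad)]


def sgd_step(weights, grad, lr):
    """Plain SGD update, used here only to show which rows move."""
    return [
        [w - lr * g for w, g in zip(w_row, g_row)]
        for w_row, g_row in zip(weights, grad)
    ]


# Toy embedding matrix: 5 tokens x 2 dims; rows 3 and 4 play the role of [IMG] tokens.
weights = [[1.0, 1.0] for _ in range(5)]
grad = [[0.5, -0.5] for _ in range(5)]

masked = mask_embedding_grad(grad, trainable_rows=[3, 4])
updated = sgd_step(weights, masked, lr=0.1)
```

Rows 0 to 2 keep their original values after the step, while rows 3 and 4 are updated, even though the whole matrix is nominally trainable.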

@avipartho
Author

Thanks again. Unfortunately, I missed this section of the script.

Is it also correct to say that for the $l_p$ loss, you are considering the loss for generating each token of the input text (caption), i.e. the negative log likelihood of generating token $s_t$ conditioned on $s_1, \ldots, s_{t-1}$, where $t = 1, \ldots, T$?

@kohjingyu
Owner

Yes that's right!

@avipartho
Author

In that case, I believe equation 2 is slightly misleading, as the summation goes over i from 1 to r there. This practically says that we are considering loss for generating all 8 [IMG] tokens.

@kohjingyu
Owner

You're absolutely right, thanks for pointing this out! We'll fix this in the paper soon. The correct scheme should be that the loss is only considered for the first [IMG0] token. The part about forcing generation of the remaining tokens during inference is still true.
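Combining the clarifications in this thread, the loss as implemented could be written roughly as follows. This is a reader's reconstruction, not the corrected equation from the paper: $s_1, \ldots, s_T$ are the caption tokens, treating [IMG0] as position $T+1$ is an assumed notation, and any normalization by the number of tokens is omitted.

```latex
l_p = -\sum_{t=1}^{T+1} \log p_\theta\left(s_t \mid s_1, \ldots, s_{t-1}\right),
\qquad s_{T+1} = \texttt{[IMG0]}
```

The remaining [IMG] tokens do not appear as prediction targets in this sum, which matches the -100 label masking discussed earlier.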

@avipartho
Author

As I was trying to train my own model and run inference using the saved checkpoint, I noticed a few things, please verify (might be helpful for other users).

  • Because GILLArgs() only has local variables and no attributes, none of them are saved in the model_args.json file unless explicitly set after instantiation. One such attribute is text_emb_layers. Turning all local variables into class attributes can solve this.
  • Currently, main.py also saves the scheduler state, which probably ends up saving pretty much the entire model and therefore results in a large checkpoint.
  • The pretrained checkpoint provided in this codebase has a shape of (8,4096) for input_embeddings.weight whereas running the main.py will produce a checkpoint with input_embeddings.weight of shape (50274, 4096). Looks like the provided checkpoint contains only the trainable [IMG] token embeddings. This requires either changing this line or this line to run inference with the produced checkpoint. For example,
img_token_embeddings = state_dict['model.input_embeddings.weight'].cpu().detach()[-model_kwargs['num_tokens']:, :]
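The first point in the list above can be reproduced in isolation. The sketch assumes that the defaults are declared at class level and that the arguments are serialized from the instance's `__dict__` (for example via `vars()`); neither detail is confirmed in this thread, and both class names and the default values are hypothetical.

```python
import json
from dataclasses import dataclass, asdict


class ClassLevelArgs:
    # Defaults declared on the class are NOT stored in each instance's __dict__.
    text_emb_layers = [-1]
    num_tokens = 8


@dataclass
class InstanceLevelArgs:
    # Dataclass fields become instance attributes, so all of them are serialized.
    text_emb_layers: tuple = (-1,)
    num_tokens: int = 8


args = ClassLevelArgs()
args.num_tokens = 4  # only attributes set after instantiation land in __dict__
saved_class_level = json.dumps(vars(args))

saved_instance_level = json.dumps(asdict(InstanceLevelArgs()))
```

Under these assumptions the first JSON string contains only `num_tokens`, while the second contains every field, which is one possible way to make sure values such as text_emb_layers reach model_args.json.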

@kohjingyu
Owner

Thanks for sharing this!

  • The pretrained checkpoint provided in this codebase has a shape of (8,4096) for input_embeddings.weight whereas running the main.py will produce a checkpoint with input_embeddings.weight of shape (50274, 4096). Looks like the provided checkpoint contains only the trainable [IMG] token embeddings. This requires either changing this line or this line to run inference with the produced checkpoint. For example,

You're right, and I also realized that I hadn't uploaded the script used to prune the checkpoints (keeping just the trained weights, and discarding the pretrained model weights). I just did that here: https://github.com/kohjingyu/gill/blob/main/scripts/prune_model_ckpt.py

I think this is essentially the same as the changes you probably made locally, though I haven't tested this script in a while.
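For readers who have not opened the script, the pruning idea can be sketched as follows. This is not the content of prune_model_ckpt.py: `prune_state_dict`, the toy state dict, and every key other than `model.input_embeddings.weight` are hypothetical, and lists stand in for tensors.

```python
from collections import OrderedDict


def prune_state_dict(state_dict, num_tokens, trainable_prefixes, emb_key):
    """Return a new dict with only trained weights; the embedding keeps its last num_tokens rows."""
    pruned = OrderedDict()
    for name, value in state_dict.items():
        if name == emb_key:
            # Assumes the [IMG] rows sit at the end of the embedding matrix.
            pruned[name] = value[-num_tokens:]
        elif name.startswith(trainable_prefixes):
            pruned[name] = value
        # Everything else is treated as frozen pretrained weights and dropped.
    return pruned


EMB_KEY = "model.input_embeddings.weight"
full = OrderedDict([
    (EMB_KEY, [[float(i)] * 2 for i in range(5)]),   # 5 "tokens", 2 dims
    ("model.trained_head.weight", [[0.1, 0.2]]),     # hypothetical trained module
    ("model.lm.frozen_block.weight", [[9.0, 9.0]]),  # hypothetical frozen LM weight
])

small = prune_state_dict(
    full, num_tokens=2, trainable_prefixes=("model.trained_head",), emb_key=EMB_KEY
)
```

The resulting embedding entry has `num_tokens` rows, analogous to the (8, 4096) shape of the released checkpoint.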

@avipartho
Author

Thanks for sharing the script! Just noticed a few things -

  • These arguments no longer exist in the current version. Did you perhaps coalesce them into num_tokens? Please verify.
  • This line raises an error, as it tries to mutate an ordered dict while iterating over it in the for loop. This can be avoided by creating an empty dict first and then copying everything into it (just like it's done here in models.py).
  • I believe the example usage should be python scripts/prune_model_ckpt.py runs/gill_exp, given the location of the script.
  • What's the use of share_ret_gen? I could not find any use of this in the models.py, validate.py or main.py script.
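The ordered-dict error mentioned in the list above is easy to reproduce on its own. The snippet is a generic illustration, not the line from the script: the key names and both helper functions are hypothetical, and the exact mutation in the original code may differ.

```python
from collections import OrderedDict


def drop_keys_in_place(state_dict, prefix):
    """Buggy pattern: deleting entries while iterating over the same dict."""
    for key in state_dict:
        if key.startswith(prefix):
            del state_dict[key]


def drop_keys_copy(state_dict, prefix):
    """Safe pattern: build a fresh dict and copy over the entries to keep."""
    kept = OrderedDict()
    for key, value in state_dict.items():
        if not key.startswith(prefix):
            kept[key] = value
    return kept


ckpt = OrderedDict([("lm.a", 1), ("lm.b", 2), ("head.w", 3)])

try:
    drop_keys_in_place(OrderedDict(ckpt), "lm.")  # work on a copy so ckpt stays intact
    error = None
except RuntimeError as exc:
    error = exc  # Python refuses to continue iterating over a mutated dict

cleaned = drop_keys_copy(ckpt, "lm.")
```

An alternative with the same effect is to iterate over `list(state_dict.keys())`, which snapshots the keys before any deletion.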

@kohjingyu
Owner

kohjingyu commented Jul 7, 2023

Thanks for the notes! Sorry about this, it's what happens when you don't test before you upload...

These arguments no longer exist in the current version. Did you perhaps coalesce them into num_tokens? Please verify.

That's right.

What's the use of share_ret_gen? I could not find any use of this in the models.py, validate.py or main.py script.

share_ret_gen doesn't exist anymore, I think it was something used during debugging previously. I've updated the script as such, hopefully it works as expected now. Thanks for your help in debugging this!

@avipartho
Author

Another small update. I could not find warmup-scheduler==0.3.2 (as listed in the requirements.txt file); the currently available version appears to be 0.3. Will it be compatible with your scripts? (I can confirm that training runs with this version.)

@kohjingyu
Owner

Ah, looks like it should be pip install git+https://github.com/ildoonet/pytorch-gradual-warmup-lr.git instead. The link you provided should still work though.

@avipartho
Author

avipartho commented Jul 14, 2023

I have another question. As mentioned above, the $l_p$ loss includes the negative log likelihood (NLL) of generating each token of the input text (caption). Did you find this helpful for the overall model performance? I am asking because, from the name and purpose of this loss, I would assume that it was intended to only consider the NLL of generating [IMG] tokens.

@kohjingyu
Owner

I have not run this particular ablation, sorry. I would guess that it does not have a significant effect on performance on the tasks we evaluated on.
