I am trying to train several BERT classifier models in one program, but I am running out of GPU RAM by loading too many BERT models with `const _bert_model, wordpiece, tokenizer = pretrain"Bert-uncased_L-12_H-768_A-12"`. I am following the CoLA example found here: https://github.com/chengchingwen/Transformers.jl/blob/master/example/BERT/cola/train.jl
I am wondering whether the `train!()` function in that example trains all of the parameters in `Flux.params(bert_model)` or only those in `Flux.params(_bert_model.classifier)`. This matters because, if only the classifier parameters are modified rather than all of the BERT parameters, I can load a single `pretrain"Bert-uncased_L-12_H-768_A-12"` into RAM and then train a new classifier (`_bert_model.classifier`) for each BERT classifier I need, instead of loading a separate full BERT model for each one. That would save a lot of RAM.
Please let me know whether the whole BERT model is trained by `train!()` or just the classifier parameters.
Thank you,
Jack
The whole model is trained, since the example uses `const ps = params(bert_model)`. You can switch to `params(bert_model.classifier)` if you only want to train the classifier. The difference between `_bert_model` and `bert_model` is that `_bert_model` is on the CPU while `bert_model` is on the GPU; only one BERT model is loaded onto the GPU. Your problem is probably that your GPU RAM is not big enough for a batch of size 4.
Whether you train only the classifier layer is a design decision for your own model. It is entirely doable, but it may change model performance (for better or worse), so that is up to you.
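For reference, here is a minimal sketch of that change, following the structure of the CoLA example; `train_batches` and `loss` are placeholders for the example's data pipeline and loss function, not part of the library:

```julia
using Transformers, Transformers.Pretrain
using Flux
using Flux.Optimise: update!

# Load the pretrained weights once and move the model to the GPU, as in the CoLA example.
const _bert_model, wordpiece, tokenizer = pretrain"Bert-uncased_L-12_H-768_A-12"
const bert_model = gpu(_bert_model)

# Collect only the classifier head's parameters (instead of params(bert_model)),
# so the shared BERT encoder weights stay frozen during training.
const ps  = params(bert_model.classifier)
const opt = ADAM(1e-6)

# Training loop sketch: gradients are taken with respect to `ps` only, so the
# encoder is used in the forward pass but never updated.
for batch in train_batches                    # placeholder: your prepared batches
    grad = gradient(() -> loss(batch), ps)    # placeholder: your loss function
    update!(opt, ps, grad)
end
```

Each additional classifier can then reuse the same `bert_model` encoder on the GPU, with its own parameter set and optimizer.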