
Not able to use FP16 in pytorch-pretrained-BERT. Getting error **Runtime error: Expected scalar type object Half but got scalar type Float for argument #2 target** #140

Closed
Ashish-Gupta03 opened this issue Dec 20, 2018 · 3 comments


Ashish-Gupta03 commented Dec 20, 2018

I'm not able to get fp16 working with the PyTorch BERT code, specifically with BertForSequenceClassification. When I enable fp16 I get:

Runtime error: Expected scalar type object Half but got scalar type Float for argument #2 target

I also tried casting manually:

    logits = logits.half()
    labels = labels.half()

but then training slowed down dramatically: without fp16 one epoch took about 2.5 hrs; after adding logits.half() and labels.half(), it shot up to about 8 hrs per epoch.
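For reference, a hedged sketch (my own illustration, not this repository's code) of the usual fp16 recipe: the model and its activations are cast to half, but the loss is computed in fp32 and integer class labels are never cast — `nn.CrossEntropyLoss` expects Long targets, so `labels.half()` is the wrong direction:

```python
import torch
import torch.nn as nn

# fp16 logits, as they would come out of a half-precision model
logits = torch.randn(4, 2).half()
labels = torch.tensor([0, 1, 1, 0])   # class labels stay torch.long

# Compute the loss in fp32 for numerical stability; CrossEntropyLoss
# targets must remain integer (Long) tensors, never be cast with .half().
loss = nn.CrossEntropyLoss()(logits.float(), labels)
```

This keeps the speed benefit of fp16 in the forward pass while avoiding dtype mismatches in the loss.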

thomwolf (Member) commented

Which kind of GPU are you using? fp16 only works on recent GPUs (it performs best on the Tesla/Volta series).
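As a quick way to check this (my own snippet, not from the repo): the Tensor Cores that make fp16 fast were introduced with the Volta generation, i.e. CUDA compute capability 7.0:

```python
import torch

# Fast fp16 (Tensor Cores) needs compute capability >= 7.0 (Volta or newer).
if torch.cuda.is_available():
    major, minor = torch.cuda.get_device_capability(0)
    print(f"compute capability {major}.{minor}, fast fp16: {major >= 7}")
else:
    print("no CUDA device visible")
```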

tholor (Contributor) commented Dec 28, 2018

I experienced a similar issue with CUDA 9.1. Upgrading to 9.2 solved it for me.

thomwolf (Member) commented Jan 7, 2019

Yes, CUDA 10 is recommended for using fp16 with good performance.
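To see which CUDA toolkit a given PyTorch build was compiled against (a small snippet of my own, not from this thread):

```python
import torch

# CUDA version PyTorch was compiled with; None for CPU-only builds.
print(torch.version.cuda)
```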

@thomwolf thomwolf closed this as completed Jan 7, 2019