
Issue with use_fp16=True Leading to Type Conversion Error in unet.py #56

Open
DuoLi1999 opened this issue Dec 18, 2023 · 1 comment

@DuoLi1999

When setting use_fp16=False, the code runs correctly. However, with use_fp16=True an issue arises from an unexpected type conversion in unet.py (line 435).

The problem occurs at line 435, where the tensor a is converted from float16 to float32:

a = a.float()

Prior to this line, a is in float16, but after it, a is float32. If we remove or comment out this line, the code encounters an error. It seems that keeping a in float16 is essential for the use_fp16=True setting to work correctly, but the current implementation inadvertently converts it to float32, which leads to the failure.
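For context, a common way to handle this kind of dtype mismatch is to upcast the attention weights to float32 only for the softmax and then cast them back to the input dtype. The sketch below only illustrates that pattern under my assumptions; the function and variable names are placeholders, not the actual code in unet.py:

```python
import torch

def attention(q, k, v):
    # Hypothetical sketch (names assumed, not the repository's implementation):
    # upcast the attention weights to fp32 only for the softmax, then cast
    # back to the input dtype so the fp16 pipeline stays dtype-consistent.
    scale = q.shape[-1] ** -0.5
    a = torch.einsum("bid,bjd->bij", q * scale, k)      # fp16 under use_fp16=True
    a = torch.softmax(a.float(), dim=-1).to(q.dtype)    # fp32 softmax, restore dtype
    return torch.einsum("bij,bjd->bid", a, v)

# Quick dtype check; fp16 matmuls are most reliable on CUDA
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32
q, k, v = (torch.randn(2, 8, 64, device=device, dtype=dtype) for _ in range(3))
print(attention(q, k, v).dtype)  # matches the input dtype
```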

Additionally, I've noticed that the current code has been modified to disable flash attention. I also attempted to run the original version, but encountered similar errors.

@songtianhui

Same question
