[Experimental Feature] FP8 weight dtype for base model when running train_network (or sdxl_train_network) #1057
Conversation
Amazing work. Have you compared results with full BF16 training?
I did some comparisons a few months ago. They do have some subtle differences, but it is hard to say whether it is a quality difference or a performance difference. It is just, different.
Thanks for the reply. So it will train in FP8 and then save as FP16 as usual?
The trainable part will not be converted to FP8.
Can you elaborate more? For example, when training SDXL with DreamBooth, we train both the UNet and the Text Encoder, all parts I think? Or am I missing something? Thank you.
Thank you for this PR! The changes are smaller than I expected. I will check as soon as possible.
I tested it a lot, and what we did in the past is basically the same as FP8. The problem is also similar: you need autocast. So some parts of the computation that don't use autocast may have problems, but these can be solved easily.
This PR is for LoRA/LyCORIS/hypernetwork (losalina) training.
Hi, I have a question about this PR. |
Yes, I chose e4m3 based on my experiments. If we had a better scaling method for it, maybe we could consider e5m2, but since we don't use FP8 for computation here, I think the better precision is more important.
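To make the e4m3-vs-e5m2 trade-off concrete, here is a small self-contained sketch (a hypothetical helper, not code from this PR) that computes the largest finite value of each FP8 format from its bit layout: e4m3fn keeps finer precision (3 mantissa bits) but tops out at 448, while e5m2 reaches 57344 with coarser precision.

```python
def fp8_max_normal(exp_bits: int, man_bits: int, bias: int, finite: bool) -> float:
    """Largest finite value of an FP8 format (illustrative helper).

    finite=True models e4m3fn, where the all-ones exponent still encodes
    normal numbers (only the all-ones mantissa is NaN), so the top value
    drops its least-significant mantissa bit.
    """
    if finite:
        exp = (2 ** exp_bits - 1) - bias          # all-ones exponent is usable
        mantissa = 2.0 - 2.0 ** (1 - man_bits)    # top mantissa minus its LSB
    else:
        exp = (2 ** exp_bits - 2) - bias          # all-ones exponent = inf/NaN
        mantissa = 2.0 - 2.0 ** (-man_bits)       # full all-ones mantissa
    return mantissa * 2.0 ** exp

print(fp8_max_normal(4, 3, 7, True))    # 448.0   (e4m3fn: narrow range, finer steps)
print(fp8_max_normal(5, 2, 15, False))  # 57344.0 (e5m2: wide range, coarser steps)
```

Since the frozen weights of a pretrained UNet/TE fall comfortably within ±448, the extra mantissa bit of e4m3 is the better trade here, matching the reasoning above.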
Thanks! |
good job |
Thank you again for the great work! |
…rain_network (or sdxl_train_network) (kohya-ss#1057)

* Add fp8 support
* remove some debug prints
* Better implementation for te
* Fix some misunderstanding
* as same as unet, add explicit convert
* better impl for convert TE to fp8
* fp8 for not only unet
* Better cache TE and TE lr
* match arg name
* Fix with list
* Add timeout settings
* Fix arg style
* Add custom seperator
* Fix typo
* Fix typo again
* Fix dtype error
* Fix gradient problem
* Fix req grad
* fix merge
* Fix merge
* Resolve merge
* arrangement and document
* Resolve merge error
* Add assert for mixed precision
Update README
Based on the PR for sd-webui on utilizing FP8, we can assume that FP8 can also be applied to the base model in train_network.
Since we don't need to update its weights, we only need to compute with them.
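As an illustration of what "FP8 for storage only" means, the following hypothetical sketch simulates round-to-nearest quantization into e4m3fn in pure Python (the real implementation would simply cast tensors to an FP8 dtype; the helper and constants here are mine, not the PR's):

```python
import math

E4M3_MAX = 448.0                 # largest finite e4m3fn value
E4M3_MIN_NORMAL = 2.0 ** -6      # exponent bias 7, smallest normal exponent
E4M3_MIN_SUBNORMAL = 2.0 ** -9   # 3 mantissa bits below the smallest normal

def quantize_e4m3(x: float) -> float:
    """Round x to the nearest e4m3fn-representable value, saturating at the max."""
    if x == 0.0:
        return 0.0
    sign = -1.0 if x < 0 else 1.0
    mag = min(abs(x), E4M3_MAX)  # saturate instead of overflowing to NaN
    if mag < E4M3_MIN_NORMAL:
        # Subnormal range: fixed spacing of 2^-9.
        return sign * round(mag / E4M3_MIN_SUBNORMAL) * E4M3_MIN_SUBNORMAL
    exp = math.floor(math.log2(mag))
    step = 2.0 ** (exp - 3)      # 3 mantissa bits => 8 steps per binade
    return sign * min(round(mag / step) * step, E4M3_MAX)

print(quantize_e4m3(0.3))     # 0.3125 (a few percent relative error)
print(quantize_e4m3(1000.0))  # 448.0  (saturated)
```

The rounding error is tolerable precisely because these weights are frozen: the trainable LoRA/LyCORIS parameters stay in higher precision and are updated as usual.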
So I implemented the first version of FP8 support in your framework, and it works well!
Actually, I uploaded an experimental model for FP8 training very early on; it consumes only 6.x GB of VRAM when training SDXL with LyCORIS/LoRA.
If we cache the latents and TE outputs, we can even train everything with 4.4 GB of VRAM, which is incredible.
(All the above experiments were done in a 1024x1024, batch size 1 setup.)
I think this is good for the SDXL community.
BTW, my implementation relies on autocast right now, which may be good news for old-GPU users or IPEX users. (But I think IPEX actually has autocast support; it is just slower than manual cast.)
If you think it is also a good idea, I can try to make a PR for manual cast. I have already verified that it can be used for training, but it may need some modifications.