-
Notifications
You must be signed in to change notification settings - Fork 415
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fake DPO / KTO #599
Fake DPO / KTO #599
Conversation
…upgrade_python_deps
# merges the LoRa layers into the base model. | ||
# This is needed if one wants to use the base model as a standalone model. | ||
logger.info("Merging LORA layers with base model.") | ||
if device == "cpu": |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This was a side fix I did in this PR, some models cant be merged on cpu in float16.
@@ -95,6 +142,7 @@ class Losses: | |||
"DPOLoss": DPOLoss, | |||
"HingeLoss": HingeLoss, | |||
"IPOLoss": IPOLoss, | |||
"KTOPairLoss": KTOPairLoss, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I guess KTOPairLoss
needs to be added to LOSS_REDUCTION
dict.
Mid-term it may sense to add get_batch_logps
function directly to the loss calculation instead of using it in the model (and pass output dict with logits + labels to the loss functions). But not high priority atm.
…o psi/dpofakepairs
@maxjeblick any idea why this mypy error happens while it does not for |
No, no idea, looks strange. |
Aftzer spending an hour on it without managing to solve the issues (apart from manually casting everything), I decided to remove the type annotation for return for both losses in order to not waste more time on it. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks a lot, LGTM! Let's maybe add a note in the README that we added KTO loss (and how to use it currently).
This PR adds the simple KTO loss.
It currently requires to build pairs of accepted and random rejected samples outside of LLm Studio.