Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fake DPO / KTO #599

Merged
merged 30 commits into from
Feb 6, 2024
Merged

Fake DPO / KTO #599

merged 30 commits into from
Feb 6, 2024

Conversation

psinger
Copy link
Collaborator

@psinger psinger commented Jan 31, 2024

This PR adds the simple KTO loss.

It currently requires to build pairs of accepted and random rejected samples outside of LLm Studio.

@psinger psinger marked this pull request as ready for review February 1, 2024 13:48
# merges the LoRa layers into the base model.
# This is needed if one wants to use the base model as a standalone model.
logger.info("Merging LORA layers with base model.")
if device == "cpu":
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was a side fix I did in this PR, some models cant be merged on cpu in float16.

@@ -95,6 +142,7 @@ class Losses:
"DPOLoss": DPOLoss,
"HingeLoss": HingeLoss,
"IPOLoss": IPOLoss,
"KTOPairLoss": KTOPairLoss,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess KTOPairLoss needs to be added to LOSS_REDUCTION dict.

Mid-term it may sense to add get_batch_logps function directly to the loss calculation instead of using it in the model (and pass output dict with logits + labels to the loss functions). But not high priority atm.

@psinger
Copy link
Collaborator Author

psinger commented Feb 5, 2024

@maxjeblick any idea why this mypy error happens while it does not for DPOLoss
https://github.com/h2oai/h2o-llmstudio/actions/runs/7783291836/job/21221448386?pr=599#step:5:32

@maxjeblick
Copy link
Contributor

No, no idea, looks strange.

@psinger
Copy link
Collaborator Author

psinger commented Feb 5, 2024

No, no idea, looks strange.

Aftzer spending an hour on it without managing to solve the issues (apart from manually casting everything), I decided to remove the type annotation for return for both losses in order to not waste more time on it.

Copy link
Contributor

@maxjeblick maxjeblick left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks a lot, LGTM! Let's maybe add a note in the README that we added KTO loss (and how to use it currently).

@psinger psinger merged commit a7050b3 into main Feb 6, 2024
5 checks passed
@psinger psinger deleted the psi/dpofakepairs branch February 6, 2024 13:20
@psinger psinger mentioned this pull request Feb 14, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants