
Add support for closed-source models in the Generalized Knowledge Distillation Trainer #2179

Open
imrankh46 opened this issue Oct 5, 2024 · 3 comments

@imrankh46

Feature request

Closed-source teacher model support for GKD, e.g. OpenAI GPT-4o, Claude, etc.

from datasets import Dataset
from trl import GKDConfig, GKDTrainer
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
)

NUM_DUMMY_SAMPLES = 100

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
# The model to optimise
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
# The teacher model to calculate the KL divergence against
teacher_model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-1.5B-Instruct")

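# Toy conversational datasets in the "messages" chat format expected by GKDTrainer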
train_dataset = Dataset.from_dict(
    {
        "messages": [
            [
                {"role": "user", "content": "Hi, how are you?"},
                {"role": "assistant", "content": "I'm great thanks"},
            ]
        ]
        * NUM_DUMMY_SAMPLES
    }
)
eval_dataset = Dataset.from_dict(
    {
        "messages": [
            [
                {"role": "user", "content": "What colour is the sky?"},
                {"role": "assistant", "content": "The sky is blue"},
            ]
        ]
        * NUM_DUMMY_SAMPLES
    }
)

args = GKDConfig(output_dir="gkd-model", per_device_train_batch_size=1)
trainer = GKDTrainer(
    model=model,
    teacher_model=teacher_model,
    args=args,
    tokenizer=tokenizer,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)
trainer.train()

Motivation

Your contribution

@imrankh46
Author

@kashif @lewtun

@qgallouedec qgallouedec added the ✨ enhancement New feature or request label Oct 7, 2024
@kashif kashif added the 🏋 GKD Related to GKD label Oct 7, 2024
@kashif
Collaborator

kashif commented Oct 7, 2024

Since the logits/vocabulary need to match between the teacher and student models, I don't think it's possible to train with closed models.
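
For context, GKD's objective is a generalized Jensen–Shannon divergence between the student's and teacher's full per-token distributions, so both models must expose logits over the same vocabulary. Below is a minimal sketch of that computation (illustrative only, not TRL's exact implementation; beta in (0, 1) interpolates between the two KL directions):

import math

import torch
import torch.nn.functional as F


def generalized_jsd(student_logits, teacher_logits, beta=0.5):
    # Both tensors must be (batch, seq_len, vocab_size) over the SAME vocabulary;
    # a closed API that returns only text (or a handful of top logprobs) cannot supply this.
    student_log_probs = F.log_softmax(student_logits, dim=-1)
    teacher_log_probs = F.log_softmax(teacher_logits, dim=-1)
    # Log of the mixture M = beta * P_teacher + (1 - beta) * P_student
    mixture_log_probs = torch.logsumexp(
        torch.stack(
            [teacher_log_probs + math.log(beta), student_log_probs + math.log(1 - beta)]
        ),
        dim=0,
    )
    # KL(teacher || M) and KL(student || M), combined with weight beta
    kl_teacher = F.kl_div(mixture_log_probs, teacher_log_probs, log_target=True, reduction="batchmean")
    kl_student = F.kl_div(mixture_log_probs, student_log_probs, log_target=True, reduction="batchmean")
    return beta * kl_teacher + (1 - beta) * kl_student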

@August-murr
Contributor

Since the logits/vocabulary need to match between the teacher and student models, I don't think it's possible to train with closed models.

The Anthropic API doesn't output any logits or logprobs, and they have no plans to add them, while OpenAI only returns at most 20 top logprobs per token. It seems like they really don't want you to distill.
OpenAI recently announced a distillation service, but it's only for their own models and is not open source.
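
To make the limitation concrete, here is a minimal sketch using the OpenAI Python SDK (the model name is just an example): you can request at most 20 alternatives per token via top_logprobs, which is nowhere near the full vocabulary distribution (on the order of hundreds of thousands of entries for modern tokenizers) that a distillation loss needs.

from openai import OpenAI

client = OpenAI()
response = client.chat.completions.create(
    model="gpt-4o",  # example closed-source teacher
    messages=[{"role": "user", "content": "Hi, how are you?"}],
    logprobs=True,
    top_logprobs=20,  # 20 is the API's maximum; full-vocabulary logits are never exposed
)
first_token = response.choices[0].logprobs.content[0]
print(first_token.token, [(t.token, t.logprob) for t in first_token.top_logprobs])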

@qgallouedec qgallouedec added the ⏳ needs more info Additional information or clarification is required to proceed label Oct 8, 2024