diff --git a/haystack/modeling/training/base.py b/haystack/modeling/training/base.py
index afb2be0564..f1baa30f43 100644
--- a/haystack/modeling/training/base.py
+++ b/haystack/modeling/training/base.py
@@ -667,7 +667,7 @@ def __init__(
         :param disable_tqdm: Disable tqdm progress bar (helps to reduce verbosity in some environments)
         :param max_grad_norm: Max gradient norm for clipping, default 1.0, set to None to disable
         :param distillation_loss_weight: The weight of the distillation loss. A higher weight means the teacher outputs are more important.
-        :param distillation_loss: Specifies how teacher and model logits should be compared. Can either be a string ("mse" for mean squared error or "kl_div" for kl divergence loss) or a callable loss function (needs to have named paramters student_logits and teacher_logits)
+        :param distillation_loss: Specifies how teacher and model logits should be compared. Can either be a string ("mse" for mean squared error or "kl_div" for kl divergence loss) or a callable loss function (needs to have named parameters student_logits and teacher_logits)
         :param temperature: The temperature for distillation. A higher temperature will result in less certainty of teacher outputs. A lower temperature means more certainty. A temperature of 1.0 does not change the certainty of the model.
         """
         super().__init__(
@@ -819,7 +819,7 @@ def __init__(
         :param disable_tqdm: Disable tqdm progress bar (helps to reduce verbosity in some environments)
         :param max_grad_norm: Max gradient norm for clipping, default 1.0, set to None to disable
         :param distillation_loss_weight: The weight of the distillation loss. A higher weight means the teacher outputs are more important.
-        :param distillation_loss: Specifies how teacher and model logits should be compared. Can either be a string ("mse" for mean squared error or "kl_div" for kl divergence loss) or a callable loss function (needs to have named paramters student_logits and teacher_logits)
+        :param distillation_loss: Specifies how teacher and model logits should be compared. Can either be a string ("mse" for mean squared error or "kl_div" for kl divergence loss) or a callable loss function (needs to have named parameters student_logits and teacher_logits)
         :param temperature: The temperature for distillation. A higher temperature will result in less certainty of teacher outputs. A lower temperature means more certainty. A temperature of 1.0 does not change the certainty of the model.
         """
         super().__init__(
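For context, here is a minimal sketch of a callable that matches the corrected docstring: a loss function that accepts the named parameters `student_logits` and `teacher_logits` and returns a scalar loss. The KL-divergence formulation, the temperature value, and the torch-based implementation are illustrative assumptions, not taken from the patched file; whether the trainer applies its own `temperature` scaling before calling a custom loss is not specified by this docstring, so the sketch scales the logits itself.

```python
import torch
import torch.nn.functional as F

def kl_div_distillation_loss(student_logits: torch.Tensor, teacher_logits: torch.Tensor) -> torch.Tensor:
    """Compare student and teacher logits with a temperature-scaled KL divergence.

    Uses the named parameters student_logits and teacher_logits required for a
    custom distillation_loss callable.
    """
    temperature = 2.0  # illustrative value, not taken from the source
    student_log_probs = F.log_softmax(student_logits / temperature, dim=-1)
    teacher_probs = F.softmax(teacher_logits / temperature, dim=-1)
    # "batchmean" averages over the batch, the standard reduction for KL divergence;
    # the temperature**2 factor keeps gradient magnitudes comparable across temperatures.
    return F.kl_div(student_log_probs, teacher_probs, reduction="batchmean") * temperature**2

# Quick sanity check with dummy logits (batch of 4, 10 classes):
loss = kl_div_distillation_loss(torch.randn(4, 10), torch.randn(4, 10))
print(loss.item())
```

Such a callable would be passed as the `distillation_loss` argument in place of the `"mse"` or `"kl_div"` string options described in the docstring.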