From c075870889c2888027d5eef1acccd02377c6ee1f Mon Sep 17 00:00:00 2001 From: Niklas Gustafsson Date: Wed, 19 Oct 2022 14:17:38 -0700 Subject: [PATCH] Added support for label smoothing in CrossEntropyLoss. --- src/Native/LibTorchSharp/THSLoss.cpp | 3 ++- src/Native/LibTorchSharp/THSNN.h | 2 +- src/TorchSharp/NN/Losses.cs | 15 ++++++++++----- 3 files changed, 13 insertions(+), 7 deletions(-) diff --git a/src/Native/LibTorchSharp/THSLoss.cpp b/src/Native/LibTorchSharp/THSLoss.cpp index 26e893c4e..2f948ca70 100644 --- a/src/Native/LibTorchSharp/THSLoss.cpp +++ b/src/Native/LibTorchSharp/THSLoss.cpp @@ -48,11 +48,12 @@ Tensor THSNN_cosine_embedding_loss(const Tensor input1, const Tensor input2, con ) } -Tensor THSNN_cross_entropy(const Tensor input, const Tensor target, const Tensor weight, const int64_t ignore_index, const bool has_ii, const int64_t reduction) +Tensor THSNN_cross_entropy(const Tensor input, const Tensor target, const Tensor weight, const int64_t ignore_index, const bool has_ii, const int64_t reduction, const double smoothing) { CATCH_RETURN_Tensor( auto opts = torch::nn::functional::CrossEntropyFuncOptions(); ApplyReduction(opts, reduction); + opts.label_smoothing(smoothing); if (has_ii) opts = opts.ignore_index(ignore_index); if (weight != NULL) diff --git a/src/Native/LibTorchSharp/THSNN.h b/src/Native/LibTorchSharp/THSNN.h index df0d7d0bf..ecfaed899 100644 --- a/src/Native/LibTorchSharp/THSNN.h +++ b/src/Native/LibTorchSharp/THSNN.h @@ -494,7 +494,7 @@ EXPORT_API(Tensor) THSNN_Sequential_forward(const NNModule module, const Tenso EXPORT_API(Tensor) THSNN_binary_cross_entropy(const Tensor input, const Tensor target, const Tensor weight, const int64_t reduction); EXPORT_API(Tensor) THSNN_binary_cross_entropy_with_logits(const Tensor input, const Tensor target, const Tensor weight, const int64_t reduction, const Tensor pos_weights_); EXPORT_API(Tensor) THSNN_cosine_embedding_loss(const Tensor input1, const Tensor input2, const Tensor target, const double margin, const int64_t reduction); -EXPORT_API(Tensor) THSNN_cross_entropy(const Tensor input, const Tensor target, const Tensor weight, const int64_t ignore_index, const bool has_ii, const int64_t reduction); +EXPORT_API(Tensor) THSNN_cross_entropy(const Tensor input, const Tensor target, const Tensor weight, const int64_t ignore_index, const bool has_ii, const int64_t reduction, const double smoothing); EXPORT_API(Tensor) THSNN_ctc_loss(const Tensor log_probs, const Tensor targets, const Tensor input_lengths, const Tensor target_lengths, int64_t blank, bool zero_infinity, const int64_t reduction); EXPORT_API(Tensor) THSNN_hinge_embedding_loss(const Tensor input, const Tensor target, const double margin, const int64_t reduction); EXPORT_API(Tensor) THSNN_huber_loss(const Tensor input, const Tensor target, const double delta, const int64_t reduction); diff --git a/src/TorchSharp/NN/Losses.cs b/src/TorchSharp/NN/Losses.cs index b368cc029..e7a98ae65 100644 --- a/src/TorchSharp/NN/Losses.cs +++ b/src/TorchSharp/NN/Losses.cs @@ -396,10 +396,13 @@ public static Tensor binary_cross_entropy(Tensor input, Tensor target, Tensor? w /// Note that ignore_index is only applicable when the target contains class indices. /// /// Specifies the reduction to apply to the output + /// A float in [0.0, 1.0]. + /// Specifies the amount of smoothing when computing the loss, where 0.0 means no smoothing. + /// The targets become a mixture of the original ground truth and a uniform distribution. /// - public static Tensor cross_entropy(Tensor input, Tensor target, Tensor? weight = null, long ignore_index = -100, Reduction reduction = Reduction.Mean) + public static Tensor cross_entropy(Tensor input, Tensor target, Tensor? weight = null, long ignore_index = -100, Reduction reduction = Reduction.Mean, double label_smoothing = 0.0) { - var res = THSNN_cross_entropy(input.Handle, target.Handle, weight?.Handle ?? IntPtr.Zero, ignore_index, true, (long)reduction); + var res = THSNN_cross_entropy(input.Handle, target.Handle, weight?.Handle ?? IntPtr.Zero, ignore_index, true, (long)reduction, label_smoothing); if (res == IntPtr.Zero) { torch.CheckForErrors(); } return new Tensor(res); } @@ -724,7 +727,7 @@ public static Tensor triplet_margin_with_distance_loss(Tensor anchor, Tensor pos #region External functions [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_cross_entropy(IntPtr srct, IntPtr trgt, IntPtr wgt, long ignore_index, bool hasII, long reduction); + internal static extern IntPtr THSNN_cross_entropy(IntPtr srct, IntPtr trgt, IntPtr wgt, long ignore_index, bool hasII, long reduction, double smoothing); [DllImport("LibTorchSharp")] internal static extern IntPtr THSNN_binary_cross_entropy(IntPtr srct, IntPtr trgt, IntPtr wgt, long reduction); @@ -800,20 +803,22 @@ namespace Modules public sealed class CrossEntropyLoss : WeightedLoss { - public CrossEntropyLoss(Tensor? weight = null, long? ignore_index = null, Reduction reduction = Reduction.Mean) : base(weight, reduction) + public CrossEntropyLoss(Tensor? weight = null, long? ignore_index = null, Reduction reduction = Reduction.Mean, double label_smoothing = 0.0) : base(weight, reduction) { this.ignore_index = ignore_index; + this.label_smoothing = label_smoothing; } public override Tensor forward(Tensor input, Tensor target) { var ii = ignore_index.HasValue ? ignore_index.Value : -100; - var res = THSNN_cross_entropy(input.Handle, target.Handle, weight?.Handle ?? IntPtr.Zero, ii, ignore_index.HasValue, (long)reduction); + var res = THSNN_cross_entropy(input.Handle, target.Handle, weight?.Handle ?? IntPtr.Zero, ii, ignore_index.HasValue, (long)reduction, label_smoothing); if (res == IntPtr.Zero) { torch.CheckForErrors(); } return new Tensor(res); } public long? ignore_index { get; } + public double label_smoothing { get; } } public sealed class BCELoss : WeightedLoss