diff --git a/src/Native/LibTorchSharp/THSLoss.cpp b/src/Native/LibTorchSharp/THSLoss.cpp
index 26e893c4e..2f948ca70 100644
--- a/src/Native/LibTorchSharp/THSLoss.cpp
+++ b/src/Native/LibTorchSharp/THSLoss.cpp
@@ -48,11 +48,12 @@ Tensor THSNN_cosine_embedding_loss(const Tensor input1, const Tensor input2, con
)
}
-Tensor THSNN_cross_entropy(const Tensor input, const Tensor target, const Tensor weight, const int64_t ignore_index, const bool has_ii, const int64_t reduction)
+Tensor THSNN_cross_entropy(const Tensor input, const Tensor target, const Tensor weight, const int64_t ignore_index, const bool has_ii, const int64_t reduction, const double smoothing)
{
CATCH_RETURN_Tensor(
auto opts = torch::nn::functional::CrossEntropyFuncOptions();
ApplyReduction(opts, reduction);
+ opts.label_smoothing(smoothing);
if (has_ii)
opts = opts.ignore_index(ignore_index);
if (weight != NULL)
diff --git a/src/Native/LibTorchSharp/THSNN.h b/src/Native/LibTorchSharp/THSNN.h
index df0d7d0bf..ecfaed899 100644
--- a/src/Native/LibTorchSharp/THSNN.h
+++ b/src/Native/LibTorchSharp/THSNN.h
@@ -494,7 +494,7 @@ EXPORT_API(Tensor) THSNN_Sequential_forward(const NNModule module, const Tenso
EXPORT_API(Tensor) THSNN_binary_cross_entropy(const Tensor input, const Tensor target, const Tensor weight, const int64_t reduction);
EXPORT_API(Tensor) THSNN_binary_cross_entropy_with_logits(const Tensor input, const Tensor target, const Tensor weight, const int64_t reduction, const Tensor pos_weights_);
EXPORT_API(Tensor) THSNN_cosine_embedding_loss(const Tensor input1, const Tensor input2, const Tensor target, const double margin, const int64_t reduction);
-EXPORT_API(Tensor) THSNN_cross_entropy(const Tensor input, const Tensor target, const Tensor weight, const int64_t ignore_index, const bool has_ii, const int64_t reduction);
+EXPORT_API(Tensor) THSNN_cross_entropy(const Tensor input, const Tensor target, const Tensor weight, const int64_t ignore_index, const bool has_ii, const int64_t reduction, const double smoothing);
EXPORT_API(Tensor) THSNN_ctc_loss(const Tensor log_probs, const Tensor targets, const Tensor input_lengths, const Tensor target_lengths, int64_t blank, bool zero_infinity, const int64_t reduction);
EXPORT_API(Tensor) THSNN_hinge_embedding_loss(const Tensor input, const Tensor target, const double margin, const int64_t reduction);
EXPORT_API(Tensor) THSNN_huber_loss(const Tensor input, const Tensor target, const double delta, const int64_t reduction);
diff --git a/src/TorchSharp/NN/Losses.cs b/src/TorchSharp/NN/Losses.cs
index b368cc029..e7a98ae65 100644
--- a/src/TorchSharp/NN/Losses.cs
+++ b/src/TorchSharp/NN/Losses.cs
@@ -396,10 +396,13 @@ public static Tensor binary_cross_entropy(Tensor input, Tensor target, Tensor? w
/// Note that ignore_index is only applicable when the target contains class indices.
///
/// Specifies the reduction to apply to the output
+ /// A float in [0.0, 1.0].
+ /// Specifies the amount of smoothing when computing the loss, where 0.0 means no smoothing.
+ /// The targets become a mixture of the original ground truth and a uniform distribution.
///
- public static Tensor cross_entropy(Tensor input, Tensor target, Tensor? weight = null, long ignore_index = -100, Reduction reduction = Reduction.Mean)
+ public static Tensor cross_entropy(Tensor input, Tensor target, Tensor? weight = null, long ignore_index = -100, Reduction reduction = Reduction.Mean, double label_smoothing = 0.0)
{
- var res = THSNN_cross_entropy(input.Handle, target.Handle, weight?.Handle ?? IntPtr.Zero, ignore_index, true, (long)reduction);
+ var res = THSNN_cross_entropy(input.Handle, target.Handle, weight?.Handle ?? IntPtr.Zero, ignore_index, true, (long)reduction, label_smoothing);
if (res == IntPtr.Zero) { torch.CheckForErrors(); }
return new Tensor(res);
}
@@ -724,7 +727,7 @@ public static Tensor triplet_margin_with_distance_loss(Tensor anchor, Tensor pos
#region External functions
[DllImport("LibTorchSharp")]
- internal static extern IntPtr THSNN_cross_entropy(IntPtr srct, IntPtr trgt, IntPtr wgt, long ignore_index, bool hasII, long reduction);
+ internal static extern IntPtr THSNN_cross_entropy(IntPtr srct, IntPtr trgt, IntPtr wgt, long ignore_index, bool hasII, long reduction, double smoothing);
[DllImport("LibTorchSharp")]
internal static extern IntPtr THSNN_binary_cross_entropy(IntPtr srct, IntPtr trgt, IntPtr wgt, long reduction);
@@ -800,20 +803,22 @@ namespace Modules
public sealed class CrossEntropyLoss : WeightedLoss
{
- public CrossEntropyLoss(Tensor? weight = null, long? ignore_index = null, Reduction reduction = Reduction.Mean) : base(weight, reduction)
+ public CrossEntropyLoss(Tensor? weight = null, long? ignore_index = null, Reduction reduction = Reduction.Mean, double label_smoothing = 0.0) : base(weight, reduction)
{
this.ignore_index = ignore_index;
+ this.label_smoothing = label_smoothing;
}
public override Tensor forward(Tensor input, Tensor target)
{
var ii = ignore_index.HasValue ? ignore_index.Value : -100;
- var res = THSNN_cross_entropy(input.Handle, target.Handle, weight?.Handle ?? IntPtr.Zero, ii, ignore_index.HasValue, (long)reduction);
+ var res = THSNN_cross_entropy(input.Handle, target.Handle, weight?.Handle ?? IntPtr.Zero, ii, ignore_index.HasValue, (long)reduction, label_smoothing);
if (res == IntPtr.Zero) { torch.CheckForErrors(); }
return new Tensor(res);
}
public long? ignore_index { get; }
+ public double label_smoothing { get; }
}
public sealed class BCELoss : WeightedLoss