Skip to content

Commit

Permalink
Merge pull request #806 from NiklasGustafsson/main
Browse files · Browse the repository at this point in the history
Added support for label smoothing in CrossEntropyLoss.
  • Loading branch information
NiklasGustafsson committed Oct 19, 2022
2 parents 0473cb0 + c075870 commit 8a08803
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 7 deletions.
3 changes: 2 additions & 1 deletion src/Native/LibTorchSharp/THSLoss.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,12 @@ Tensor THSNN_cosine_embedding_loss(const Tensor input1, const Tensor input2, con
)
}

Tensor THSNN_cross_entropy(const Tensor input, const Tensor target, const Tensor weight, const int64_t ignore_index, const bool has_ii, const int64_t reduction)
Tensor THSNN_cross_entropy(const Tensor input, const Tensor target, const Tensor weight, const int64_t ignore_index, const bool has_ii, const int64_t reduction, const double smoothing)
{
CATCH_RETURN_Tensor(
auto opts = torch::nn::functional::CrossEntropyFuncOptions();
ApplyReduction(opts, reduction);
opts.label_smoothing(smoothing);
if (has_ii)
opts = opts.ignore_index(ignore_index);
if (weight != NULL)
Expand Down
2 changes: 1 addition & 1 deletion src/Native/LibTorchSharp/THSNN.h
Original file line number Diff line number Diff line change
Expand Up @@ -494,7 +494,7 @@ EXPORT_API(Tensor) THSNN_Sequential_forward(const NNModule module, const Tenso
EXPORT_API(Tensor) THSNN_binary_cross_entropy(const Tensor input, const Tensor target, const Tensor weight, const int64_t reduction);
EXPORT_API(Tensor) THSNN_binary_cross_entropy_with_logits(const Tensor input, const Tensor target, const Tensor weight, const int64_t reduction, const Tensor pos_weights_);
EXPORT_API(Tensor) THSNN_cosine_embedding_loss(const Tensor input1, const Tensor input2, const Tensor target, const double margin, const int64_t reduction);
EXPORT_API(Tensor) THSNN_cross_entropy(const Tensor input, const Tensor target, const Tensor weight, const int64_t ignore_index, const bool has_ii, const int64_t reduction);
EXPORT_API(Tensor) THSNN_cross_entropy(const Tensor input, const Tensor target, const Tensor weight, const int64_t ignore_index, const bool has_ii, const int64_t reduction, const double smoothing);
EXPORT_API(Tensor) THSNN_ctc_loss(const Tensor log_probs, const Tensor targets, const Tensor input_lengths, const Tensor target_lengths, int64_t blank, bool zero_infinity, const int64_t reduction);
EXPORT_API(Tensor) THSNN_hinge_embedding_loss(const Tensor input, const Tensor target, const double margin, const int64_t reduction);
EXPORT_API(Tensor) THSNN_huber_loss(const Tensor input, const Tensor target, const double delta, const int64_t reduction);
Expand Down
15 changes: 10 additions & 5 deletions src/TorchSharp/NN/Losses.cs
Original file line number Diff line number Diff line change
Expand Up @@ -396,10 +396,13 @@ public static Tensor binary_cross_entropy(Tensor input, Tensor target, Tensor? w
/// Note that ignore_index is only applicable when the target contains class indices.
/// </param>
/// <param name="reduction">Specifies the reduction to apply to the output</param>
/// <param name="label_smoothing">A float in [0.0, 1.0].
/// Specifies the amount of smoothing when computing the loss, where 0.0 means no smoothing.
/// The targets become a mixture of the original ground truth and a uniform distribution.</param>
/// <returns></returns>
public static Tensor cross_entropy(Tensor input, Tensor target, Tensor? weight = null, long ignore_index = -100, Reduction reduction = Reduction.Mean)
public static Tensor cross_entropy(Tensor input, Tensor target, Tensor? weight = null, long ignore_index = -100, Reduction reduction = Reduction.Mean, double label_smoothing = 0.0)
{
var res = THSNN_cross_entropy(input.Handle, target.Handle, weight?.Handle ?? IntPtr.Zero, ignore_index, true, (long)reduction);
var res = THSNN_cross_entropy(input.Handle, target.Handle, weight?.Handle ?? IntPtr.Zero, ignore_index, true, (long)reduction, label_smoothing);
if (res == IntPtr.Zero) { torch.CheckForErrors(); }
return new Tensor(res);
}
Expand Down Expand Up @@ -724,7 +727,7 @@ public static Tensor triplet_margin_with_distance_loss(Tensor anchor, Tensor pos

#region External functions
[DllImport("LibTorchSharp")]
internal static extern IntPtr THSNN_cross_entropy(IntPtr srct, IntPtr trgt, IntPtr wgt, long ignore_index, bool hasII, long reduction);
internal static extern IntPtr THSNN_cross_entropy(IntPtr srct, IntPtr trgt, IntPtr wgt, long ignore_index, bool hasII, long reduction, double smoothing);

[DllImport("LibTorchSharp")]
internal static extern IntPtr THSNN_binary_cross_entropy(IntPtr srct, IntPtr trgt, IntPtr wgt, long reduction);
Expand Down Expand Up @@ -800,20 +803,22 @@ namespace Modules

public sealed class CrossEntropyLoss : WeightedLoss<Tensor, Tensor, Tensor>
{
public CrossEntropyLoss(Tensor? weight = null, long? ignore_index = null, Reduction reduction = Reduction.Mean) : base(weight, reduction)
public CrossEntropyLoss(Tensor? weight = null, long? ignore_index = null, Reduction reduction = Reduction.Mean, double label_smoothing = 0.0) : base(weight, reduction)
{
this.ignore_index = ignore_index;
this.label_smoothing = label_smoothing;
}

public override Tensor forward(Tensor input, Tensor target)
{
var ii = ignore_index.HasValue ? ignore_index.Value : -100;
var res = THSNN_cross_entropy(input.Handle, target.Handle, weight?.Handle ?? IntPtr.Zero, ii, ignore_index.HasValue, (long)reduction);
var res = THSNN_cross_entropy(input.Handle, target.Handle, weight?.Handle ?? IntPtr.Zero, ii, ignore_index.HasValue, (long)reduction, label_smoothing);
if (res == IntPtr.Zero) { torch.CheckForErrors(); }
return new Tensor(res);
}

public long? ignore_index { get; }
public double label_smoothing { get; }
}

public sealed class BCELoss : WeightedLoss<Tensor, Tensor, Tensor>
Expand Down

0 comments on commit 8a08803

Please sign in to comment.