diff --git a/RELEASENOTES.md b/RELEASENOTES.md index ac759131b..984eed982 100644 --- a/RELEASENOTES.md +++ b/RELEASENOTES.md @@ -8,6 +8,9 @@ __Breaking Changes__: __API Changes__: +- #1266 An implicit operator from `Tensor` to `TensorIndex` has been added. + - The implicit operator from `TensorIndex` to `Tensor` has been removed. + __Bug Fixes__: diff --git a/src/TorchSharp/Tensor/Tensor.cs b/src/TorchSharp/Tensor/Tensor.cs index 8fe5fe013..36b1f56e6 100644 --- a/src/TorchSharp/Tensor/Tensor.cs +++ b/src/TorchSharp/Tensor/Tensor.cs @@ -7143,19 +7143,16 @@ public static implicit operator TensorIndex(long value) return TensorIndex.Single(value); } - public static implicit operator Tensor(TensorIndex value) + public static implicit operator TensorIndex(Tensor tensor) { - _throw(); - return new Tensor(IntPtr.Zero); + return TensorIndex.Tensor(tensor); } - private static void _throw() + public static implicit operator TensorIndex((int? start, int? end) range) { - throw new InvalidOperationException("Should not be called."); + return TensorIndex.Slice(range.start, range.end); } - public static implicit operator TensorIndex((int? start, int? end) range) => TensorIndex.Slice((long?)range.start, (long?)range.end); - #if !NETSTANDARD2_0_OR_GREATER public static implicit operator TensorIndex(System.Range range) { diff --git a/src/TorchVision/AutoAugment.cs b/src/TorchVision/AutoAugment.cs index 91f49ad0b..2ed41042b 100644 --- a/src/TorchVision/AutoAugment.cs +++ b/src/TorchVision/AutoAugment.cs @@ -141,7 +141,7 @@ public Tensor call(Tensor img) { if (probs[i].ToDouble() <= p) { var (magnitudes, signed) = op_meta[op_name]; - var magnitude = magnitude_id != null ? magnitudes[magnitude_id].ToDouble() : 0.0; + var magnitude = magnitude_id.HasValue ? magnitudes[magnitude_id.Value].ToDouble() : 0.0; if (signed && signs[i].ToBoolean()) magnitude *= -1.0; img = apply_op(img, op_name, magnitude, interpolation: this.interpolation, fill: this.fill); diff --git a/test/TorchSharpTest/TestTorchTensor.cs b/test/TorchSharpTest/TestTorchTensor.cs index 60f06d812..e8d617c6a 100644 --- a/test/TorchSharpTest/TestTorchTensor.cs +++ b/test/TorchSharpTest/TestTorchTensor.cs @@ -3183,11 +3183,11 @@ public void ScalarToTensor3() public void NegativeScalarToTensor() { Scalar s = 10; - TensorIndex ti = 10; + // TensorIndex ti = 10; Tensor t; Assert.Throws(() => { t = s; }); - Assert.Throws(() => { t = ti; }); + // Assert.Throws(() => { t = ti; }); } [Fact]