Skip to content

Commit

Permalink
implicit cast from Tensor to TensorIndex
Browse files Browse the repository at this point in the history
  • Loading branch information
yueyinqiu committed Apr 26, 2024
1 parent 737c5cb commit 8da9a66
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 24 deletions.
4 changes: 4 additions & 0 deletions RELEASENOTES.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@ __Breaking Changes__:

__API Changes__:

- #1266 Implicit operator from `Tensor` to `TensorIndex` has been added.
- The implicit operator from `TensorIndex` to `Tensor` has been removed.
- The indexer that accepts tensors as the indices has been removed.

__Bug Fixes__:


Expand Down
25 changes: 4 additions & 21 deletions src/TorchSharp/Tensor/Tensor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1415,12 +1415,6 @@ public Tensor this[params TensorIndex[] indices] {
set { index_put_(value, indices); }
}

[IndexerName("TensorItems")]
public Tensor this[params Tensor[] indices] {
get { return index(indices); }
set { index_put_(value, indices); }
}

/// <summary>
/// Tensor indexer.
/// </summary>
Expand Down Expand Up @@ -1574,14 +1568,6 @@ public Tensor index(params TensorIndex[] indices)
}
}

/// <summary>
/// Index into the tensor using Python-like indexing expressions.
/// </summary>
public Tensor index(params Tensor[] indices)
{
return index(indices.Select(t => TensorIndex.Tensor(t)).ToArray());
}

/// <summary>
/// Index into the tensor using Python-like indexing expressions and place a tensor at the index.
/// </summary>
Expand Down Expand Up @@ -7143,19 +7129,16 @@ public static implicit operator TensorIndex(long value)
return TensorIndex.Single(value);
}

public static implicit operator Tensor(TensorIndex value)
public static implicit operator TensorIndex(Tensor tensor)
{
_throw();
return new Tensor(IntPtr.Zero);
return TensorIndex.Tensor(tensor);
}

private static void _throw()
public static implicit operator TensorIndex((int? start, int? end) range)
{
throw new InvalidOperationException("Should not be called.");
return TensorIndex.Slice(range.start, range.end);
}

public static implicit operator TensorIndex((int? start, int? end) range) => TensorIndex.Slice((long?)range.start, (long?)range.end);

#if !NETSTANDARD2_0_OR_GREATER
public static implicit operator TensorIndex(System.Range range)
{
Expand Down
2 changes: 1 addition & 1 deletion src/TorchVision/AutoAugment.cs
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ public Tensor call(Tensor img)
{
if (probs[i].ToDouble() <= p) {
var (magnitudes, signed) = op_meta[op_name];
var magnitude = magnitude_id != null ? magnitudes[magnitude_id].ToDouble() : 0.0;
var magnitude = magnitude_id.HasValue ? magnitudes[magnitude_id.Value].ToDouble() : 0.0;
if (signed && signs[i].ToBoolean())
magnitude *= -1.0;
img = apply_op(img, op_name, magnitude, interpolation: this.interpolation, fill: this.fill);
Expand Down
4 changes: 2 additions & 2 deletions test/TorchSharpTest/TestTorchTensor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3183,11 +3183,11 @@ public void ScalarToTensor3()
public void NegativeScalarToTensor()
{
Scalar s = 10;
TensorIndex ti = 10;
// TensorIndex ti = 10;
Tensor t;

Assert.Throws<InvalidOperationException>(() => { t = s; });
Assert.Throws<InvalidOperationException>(() => { t = ti; });
// Assert.Throws<InvalidOperationException>(() => { t = ti; });
}

[Fact]
Expand Down

0 comments on commit 8da9a66

Please sign in to comment.