Skip to content

Commit

Permalink
Add take
Browse files Browse the repository at this point in the history
  • Loading branch information
andfoy committed Dec 2, 2023
1 parent cc73121 commit 56e2620
Show file tree
Hide file tree
Showing 7 changed files with 63 additions and 0 deletions.
27 changes: 27 additions & 0 deletions lib/extorch/native/tensor/ops/manipulation.ex
Original file line number Diff line number Diff line change
Expand Up @@ -1646,5 +1646,32 @@ defmodule ExTorch.Native.Tensor.Ops.Manipulation do
@spec t(ExTorch.Tensor.t()) :: ExTorch.Tensor.t()
defbinding(t(input))

@doc """
Returns a new tensor with the elements of `input` at the given `indices`.
The `input` tensor is treated as if it were viewed as a 1-D tensor.
The result takes the same shape as the `indices`.
## Arguments
- `input` (`ExTorch.Tensor`) - the input tensor.
- `indices` (`ExTorch.Tensor`) - the indices into tensor.
It must be of `:int64` or `:long` dtype.
## Examples
iex> a = ExTorch.rand({3, 3})
#Tensor<
[[0.0860, 0.9378, 0.3475],
[0.3576, 0.7145, 0.1036],
[0.7352, 0.4285, 0.2933]]
[size: {3, 3}, dtype: :float, device: :cpu, requires_grad: false]>
iex> ExTorch.take(a, ExTorch.tensor([1, 5, 6], dtype: :int64))
#Tensor<
[0.9378, 0.1036, 0.7352]
[size: {3}, dtype: :float, device: :cpu, requires_grad: false]>
"""
@spec take(ExTorch.Tensor.t(), ExTorch.Tensor.t()) :: ExTorch.Tensor.t()
defbinding(take(input, indices))
end
end
4 changes: 4 additions & 0 deletions native/extorch/include/manipulation.h
Original file line number Diff line number Diff line change
Expand Up @@ -178,3 +178,7 @@ std::shared_ptr<CrossTensor> squeeze(
std::shared_ptr<CrossTensor> stack(TensorList seq, int64_t dim, TensorOut opt_out);

std::shared_ptr<CrossTensor> t(const std::shared_ptr<CrossTensor> &input);

std::shared_ptr<CrossTensor> take(
const std::shared_ptr<CrossTensor> &input,
const std::shared_ptr<CrossTensor> &indices);
11 changes: 11 additions & 0 deletions native/extorch/src/csrc/manipulation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -742,3 +742,14 @@ std::shared_ptr<CrossTensor> t(const std::shared_ptr<CrossTensor> &input) {
out_tensor = torch::t(in_tensor);
return std::make_shared<CrossTensor>(std::move(out_tensor));
}

std::shared_ptr<CrossTensor> take(
const std::shared_ptr<CrossTensor> &input,
const std::shared_ptr<CrossTensor> &indices) {
CrossTensor out_tensor;
CrossTensor in_tensor = *input.get();
CrossTensor indices_tensor = *indices.get();

out_tensor = torch::take(in_tensor, indices_tensor);
return std::make_shared<CrossTensor>(std::move(out_tensor));
}
1 change: 1 addition & 0 deletions native/extorch/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ rustler::init!(
squeeze,
stack,
t,
take,

// Tensor comparing operations
allclose,
Expand Down
6 changes: 6 additions & 0 deletions native/extorch/src/native/tensor/ops.rs.in
Original file line number Diff line number Diff line change
Expand Up @@ -233,3 +233,9 @@ fn stack(seq: TensorList, dim: i64, out: TensorOut) -> Result<SharedPtr<CrossTen

/// Transpose a 2D tensor.
fn t(input: &SharedPtr<CrossTensor>) -> Result<SharedPtr<CrossTensor>>;

/// Index a tensor as if it were a 1D one.
fn take(
input: &SharedPtr<CrossTensor>,
indices: &SharedPtr<CrossTensor>,
) -> Result<SharedPtr<CrossTensor>>;
7 changes: 7 additions & 0 deletions native/extorch/src/nifs/tensor_ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -289,3 +289,10 @@ nif_impl!(
TensorStruct<'a>,
input: TensorStruct<'a>
);

nif_impl!(
take,
TensorStruct<'a>,
input: TensorStruct<'a>,
indices: TensorStruct<'a>
);
7 changes: 7 additions & 0 deletions test/tensor/manipulation_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -847,4 +847,11 @@ defmodule ExTorchTest.Tensor.ManipulationTest do
out = ExTorch.t(input)
assert ExTorch.allclose(out, expected)
end

test "take/2" do
input = ExTorch.rand({3, 3})
expected = input[{[0, 1, 2], [1, 2, 0]}]
out = ExTorch.take(input, ExTorch.tensor([1, 5, 6], dtype: :int64))
assert ExTorch.allclose(out, expected)
end
end

0 comments on commit 56e2620

Please sign in to comment.