From 56e262060b301d3fdf16a7dc5c4ef67ac17453eb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Edgar=20Andr=C3=A9s=20Margffoy=20Tuay?= Date: Fri, 1 Dec 2023 21:01:56 -0500 Subject: [PATCH] Add take --- lib/extorch/native/tensor/ops/manipulation.ex | 27 +++++++++++++++++++ native/extorch/include/manipulation.h | 4 +++ native/extorch/src/csrc/manipulation.cc | 11 ++++++++ native/extorch/src/lib.rs | 1 + native/extorch/src/native/tensor/ops.rs.in | 6 +++++ native/extorch/src/nifs/tensor_ops.rs | 7 +++++ test/tensor/manipulation_test.exs | 7 +++++ 7 files changed, 63 insertions(+) diff --git a/lib/extorch/native/tensor/ops/manipulation.ex b/lib/extorch/native/tensor/ops/manipulation.ex index e3e046e..ec9d787 100644 --- a/lib/extorch/native/tensor/ops/manipulation.ex +++ b/lib/extorch/native/tensor/ops/manipulation.ex @@ -1646,5 +1646,32 @@ defmodule ExTorch.Native.Tensor.Ops.Manipulation do @spec t(ExTorch.Tensor.t()) :: ExTorch.Tensor.t() defbinding(t(input)) + @doc """ + Returns a new tensor with the elements of `input` at the given `indices`. + + The `input` tensor is treated as if it were viewed as a 1-D tensor. + The result takes the same shape as the `indices`. + + ## Arguments + - `input` (`ExTorch.Tensor`) - the input tensor. + - `indices` (`ExTorch.Tensor`) - the indices into tensor. + It must be of `:int64` or `:long` dtype. + + ## Examples + iex> a = ExTorch.rand({3, 3}) + #Tensor< + [[0.0860, 0.9378, 0.3475], + [0.3576, 0.7145, 0.1036], + [0.7352, 0.4285, 0.2933]] + [size: {3, 3}, dtype: :float, device: :cpu, requires_grad: false]> + + iex> ExTorch.take(a, ExTorch.tensor([1, 5, 6], dtype: :int64)) + #Tensor< + [0.9378, 0.1036, 0.7352] + [size: {3}, dtype: :float, device: :cpu, requires_grad: false]> + + """ + @spec take(ExTorch.Tensor.t(), ExTorch.Tensor.t()) :: ExTorch.Tensor.t() + defbinding(take(input, indices)) end end diff --git a/native/extorch/include/manipulation.h b/native/extorch/include/manipulation.h index 0c14702..47bdc35 100644 --- a/native/extorch/include/manipulation.h +++ b/native/extorch/include/manipulation.h @@ -178,3 +178,7 @@ std::shared_ptr squeeze( std::shared_ptr stack(TensorList seq, int64_t dim, TensorOut opt_out); std::shared_ptr t(const std::shared_ptr &input); + +std::shared_ptr take( + const std::shared_ptr &input, + const std::shared_ptr &indices); diff --git a/native/extorch/src/csrc/manipulation.cc b/native/extorch/src/csrc/manipulation.cc index 2949832..9df6cd4 100644 --- a/native/extorch/src/csrc/manipulation.cc +++ b/native/extorch/src/csrc/manipulation.cc @@ -742,3 +742,14 @@ std::shared_ptr t(const std::shared_ptr &input) { out_tensor = torch::t(in_tensor); return std::make_shared(std::move(out_tensor)); } + +std::shared_ptr take( + const std::shared_ptr &input, + const std::shared_ptr &indices) { + CrossTensor out_tensor; + CrossTensor in_tensor = *input.get(); + CrossTensor indices_tensor = *indices.get(); + + out_tensor = torch::take(in_tensor, indices_tensor); + return std::make_shared(std::move(out_tensor)); +} diff --git a/native/extorch/src/lib.rs b/native/extorch/src/lib.rs index 3b46a3b..6f65053 100644 --- a/native/extorch/src/lib.rs +++ b/native/extorch/src/lib.rs @@ -126,6 +126,7 @@ rustler::init!( squeeze, stack, t, + take, // Tensor comparing operations allclose, diff --git a/native/extorch/src/native/tensor/ops.rs.in b/native/extorch/src/native/tensor/ops.rs.in index b3658a8..ee6fe39 100644 --- a/native/extorch/src/native/tensor/ops.rs.in +++ b/native/extorch/src/native/tensor/ops.rs.in @@ -233,3 +233,9 @@ fn stack(seq: TensorList, dim: i64, out: TensorOut) -> Result) -> Result>; + +/// Index a tensor as if it were a 1D one. +fn take( + input: &SharedPtr, + indices: &SharedPtr, +) -> Result>; diff --git a/native/extorch/src/nifs/tensor_ops.rs b/native/extorch/src/nifs/tensor_ops.rs index 2c7b164..ca330ee 100644 --- a/native/extorch/src/nifs/tensor_ops.rs +++ b/native/extorch/src/nifs/tensor_ops.rs @@ -289,3 +289,10 @@ nif_impl!( TensorStruct<'a>, input: TensorStruct<'a> ); + +nif_impl!( + take, + TensorStruct<'a>, + input: TensorStruct<'a>, + indices: TensorStruct<'a> +); diff --git a/test/tensor/manipulation_test.exs b/test/tensor/manipulation_test.exs index 4a156c4..553b58d 100644 --- a/test/tensor/manipulation_test.exs +++ b/test/tensor/manipulation_test.exs @@ -847,4 +847,11 @@ defmodule ExTorchTest.Tensor.ManipulationTest do out = ExTorch.t(input) assert ExTorch.allclose(out, expected) end + + test "take/2" do + input = ExTorch.rand({3, 3}) + expected = input[{[0, 1, 2], [1, 2, 0]}] + out = ExTorch.take(input, ExTorch.tensor([1, 5, 6], dtype: :int64)) + assert ExTorch.allclose(out, expected) + end end