From 2808124950920a40cf2efb92fa0d864c5e08b01a Mon Sep 17 00:00:00 2001 From: Darren Wilkinson Date: Sat, 18 Mar 2023 18:46:34 +0000 Subject: [PATCH 1/2] expose exp and log on a tensor --- core/src/main/scala/torch/Tensor.scala | 6 ++++++ core/src/test/scala/torch/TensorSuite.scala | 6 ++++++ 2 files changed, 12 insertions(+) diff --git a/core/src/main/scala/torch/Tensor.scala b/core/src/main/scala/torch/Tensor.scala index 8156789e..064038e2 100644 --- a/core/src/main/scala/torch/Tensor.scala +++ b/core/src/main/scala/torch/Tensor.scala @@ -282,6 +282,9 @@ sealed abstract class Tensor[D <: DType]( /* private[torch] */ val native: pyto /** True if `other` has the same size and elements as this tensor, false otherwise. */ def equal(other: Tensor[D]): Boolean = native.equal(other.native) + /** Returns the tensor with element exponentiated. */ + def exp: Tensor[D] = Tensor(native.exp()) + def flatten: Tensor[D] = Tensor(native.flatten()) def flatten(startDim: Int = 0, endDim: Int = -1): Tensor[D] = Tensor( @@ -331,6 +334,9 @@ sealed abstract class Tensor[D <: DType]( /* private[torch] */ val native: pyto def layout: Layout = Layout.fromNative(native.layout()) + /** Returns the tensor with elements logged. */ + def log: Tensor[D] = Tensor(native.log()) + def matmul[D2 <: DType](u: Tensor[D2]): Tensor[Promoted[D, D2]] = Tensor[Promoted[D, D2]](native.matmul(u.native)) diff --git a/core/src/test/scala/torch/TensorSuite.scala b/core/src/test/scala/torch/TensorSuite.scala index 272bc0ac..caabd17c 100644 --- a/core/src/test/scala/torch/TensorSuite.scala +++ b/core/src/test/scala/torch/TensorSuite.scala @@ -78,6 +78,12 @@ class TensorSuite extends ScalaCheckSuite { assertEquals(t.toSeq, Seq.fill[Float](2 * 3)(1f)) } + test("exp and log") { + val t = Tensor(Seq(1.0, 2.0, 3.0)) + assertEquals(t.log(0), Tensor(0.0)) + assertEquals(t.log.exp, t) + } + test("toBuffer") { val content = Seq(1, 2, 3, 4) val t = Tensor(content) From f79ce0c726ceb3eead99f07fb548f7a70685726b Mon Sep 17 00:00:00 2001 From: Darren Wilkinson Date: Sat, 18 Mar 2023 18:48:03 +0000 Subject: [PATCH 2/2] expose exp and log on a tensor --- core/src/main/scala/torch/Tensor.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/torch/Tensor.scala b/core/src/main/scala/torch/Tensor.scala index 064038e2..7a9c346d 100644 --- a/core/src/main/scala/torch/Tensor.scala +++ b/core/src/main/scala/torch/Tensor.scala @@ -282,7 +282,7 @@ sealed abstract class Tensor[D <: DType]( /* private[torch] */ val native: pyto /** True if `other` has the same size and elements as this tensor, false otherwise. */ def equal(other: Tensor[D]): Boolean = native.equal(other.native) - /** Returns the tensor with element exponentiated. */ + /** Returns the tensor with elements exponentiated. */ def exp: Tensor[D] = Tensor(native.exp()) def flatten: Tensor[D] = Tensor(native.flatten())