Skip to content

Commit

Permalink
Merge pull request #17 from darrenjw/main
Browse files Browse the repository at this point in the history
Expose exp and log on a tensor
  • Loading branch information
sbrunk committed Mar 19, 2023
2 parents cd47abc + f79ce0c commit 22a7412
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 0 deletions.
6 changes: 6 additions & 0 deletions core/src/main/scala/torch/Tensor.scala
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,9 @@ sealed abstract class Tensor[D <: DType]( /* private[torch] */ val native: pyto
/** True if `other` has the same size and elements as this tensor, false otherwise. */
def equal(other: Tensor[D]): Boolean = native.equal(other.native)

/** Returns the tensor with elements exponentiated. */
def exp: Tensor[D] = Tensor(native.exp())

def flatten: Tensor[D] = Tensor(native.flatten())

def flatten(startDim: Int = 0, endDim: Int = -1): Tensor[D] = Tensor(
Expand Down Expand Up @@ -331,6 +334,9 @@ sealed abstract class Tensor[D <: DType]( /* private[torch] */ val native: pyto

def layout: Layout = Layout.fromNative(native.layout())

/** Returns the tensor with elements logged. */
def log: Tensor[D] = Tensor(native.log())

def matmul[D2 <: DType](u: Tensor[D2]): Tensor[Promoted[D, D2]] =
Tensor[Promoted[D, D2]](native.matmul(u.native))

Expand Down
6 changes: 6 additions & 0 deletions core/src/test/scala/torch/TensorSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,12 @@ class TensorSuite extends ScalaCheckSuite {
assertEquals(t.toSeq, Seq.fill[Float](2 * 3)(1f))
}

test("exp and log") {
val t = Tensor(Seq(1.0, 2.0, 3.0))
assertEquals(t.log(0), Tensor(0.0))
assertEquals(t.log.exp, t)
}

test("toBuffer") {
val content = Seq(1, 2, 3, 4)
val t = Tensor(content)
Expand Down

0 comments on commit 22a7412

Please sign in to comment.