Skip to content

Commit

Permalink
wip add llama
Browse files Browse the repository at this point in the history
  • Loading branch information
sbrunk committed Jul 26, 2023
1 parent 938c7bf commit 7f5a1c3
Show file tree
Hide file tree
Showing 2 changed files with 387 additions and 1 deletion.
9 changes: 8 additions & 1 deletion core/src/main/scala/torch/Tensor.scala
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,7 @@ sealed abstract class Tensor[D <: DType]( /* private[torch] */ val native: pyto
* x.expand(-1, 4) // -1 means not changing the size of that dimension
* ```
*/
def expand(sizes: Int*) = Tensor(native.expand(sizes.map(_.toLong)*))
def expand(sizes: Int*): Tensor[D] = Tensor(native.expand(sizes.map(_.toLong)*))

def flatten: Tensor[D] = Tensor(native.flatten())

Expand Down Expand Up @@ -530,6 +530,8 @@ sealed abstract class Tensor[D <: DType]( /* private[torch] */ val native: pyto

def size: Seq[Int] = ArraySeq.unsafeWrapArray(native.sizes.vec.get.map(_.toInt))

def sqrt: Tensor[D] = Tensor(native.sqrt())

def std: Tensor[D] = Tensor[D](native.std())

/** Returns a new tensor with the sine of the elements of this tensor. */
Expand All @@ -545,6 +547,11 @@ sealed abstract class Tensor[D <: DType]( /* private[torch] */ val native: pyto
*/
def t: Tensor[D] = Tensor(native.t())

/** Returns a tensor that is a transposed version of this tensor. The given dimensions dim0 and
* dim1 are swapped.
*/
def transpose(dim0: Int, dim1: Int): Tensor[D] = Tensor(native.transpose(dim0, dim1))

/** Calculates the variance of all elements of this tensor. */
def variance = Tensor(native.`var`())

Expand Down
Loading

0 comments on commit 7f5a1c3

Please sign in to comment.