Skip to content

Commit

Permalink
wip add llama
Browse files Browse the repository at this point in the history
  • Loading branch information
sbrunk committed Jul 25, 2023
1 parent 938c7bf commit 91452c1
Show file tree
Hide file tree
Showing 2 changed files with 385 additions and 1 deletion.
7 changes: 6 additions & 1 deletion core/src/main/scala/torch/Tensor.scala
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,7 @@ sealed abstract class Tensor[D <: DType]( /* private[torch] */ val native: pyto
* x.expand(-1, 4) // -1 means not changing the size of that dimension
* ```
*/
def expand(sizes: Int*) = Tensor(native.expand(sizes.map(_.toLong)*))
def expand(sizes: Int*): Tensor[D] = Tensor(native.expand(sizes.map(_.toLong)*))

def flatten: Tensor[D] = Tensor(native.flatten())

Expand Down Expand Up @@ -545,6 +545,11 @@ sealed abstract class Tensor[D <: DType]( /* private[torch] */ val native: pyto
*/
def t: Tensor[D] = Tensor(native.t())

/** Returns a tensor that is a transposed version of this tensor. The given dimensions dim0 and
* dim1 are swapped.
*/
def transpose(dim0: Int, dim1: Int): Tensor[D] = Tensor(native.transpose(dim0, dim1))

/** Calculates the variance of all elements of this tensor. */
def variance = Tensor(native.`var`())

Expand Down
Loading

0 comments on commit 91452c1

Please sign in to comment.