Skip to content

Commit

Permalink
Merge pull request sbrunk#27 from davoclavo/fix_randn
Browse files Browse the repository at this point in the history
Fix torch.randn to use proper native function
  • Loading branch information
sbrunk authored Jun 17, 2023
2 parents 40f607e + b064124 commit 45146d5
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 1 deletion.
3 changes: 3 additions & 0 deletions core/src/main/scala/torch/Tensor.scala
Original file line number Diff line number Diff line change
Expand Up @@ -426,6 +426,9 @@ sealed abstract class Tensor[D <: DType]( /* private[torch] */ val native: pyto
*/
def t: Tensor[D] = Tensor(native.t())

/** Calculates the variance of all elements of this tensor. */
def variance = Tensor(native.`var`())

/** Returns a new tensor with a dimension of size one inserted at the specified position.
*
* The returned tensor shares the same underlying data with this tensor.
Expand Down
2 changes: 1 addition & 1 deletion core/src/main/scala/torch/torch.scala
Original file line number Diff line number Diff line change
Expand Up @@ -418,7 +418,7 @@ def randn[D <: FloatNN](
requiresGrad: Boolean = false
): Tensor[D] =
Tensor(
torchNative.torch_rand(
torchNative.torch_randn(
size.toArray.map(_.toLong),
NativeConverters.tensorOptions(dtype, layout, device, requiresGrad)
)
Expand Down
16 changes: 16 additions & 0 deletions core/src/test/scala/torch/TensorSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,22 @@ class TensorSuite extends TensorCheckSuite {
assertEquals(tensor(---, -1), Tensor(Seq(3, 7, 11, 15)))
}

// Random sampling

test("randn.unit-test") {
val randnTensor = randn(Seq(100000))
val randnMean = randnTensor.mean
val expectedMean = Tensor(0.0).to(dtype = float32)
val randnVariance = randnTensor.variance
val expectedVariance = Tensor(1.0).to(dtype = float32)

assert(
allclose(randnMean, expectedMean, atol = 1e-2) &&
allclose(randnVariance, expectedVariance, atol = 1e-2)
)
}

// End Random sampling
testUnaryOp(
op = abs,
opName = "abs",
Expand Down

0 comments on commit 45146d5

Please sign in to comment.