Skip to content

Commit

Permalink
Fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
sbrunk committed Feb 5, 2024
1 parent 11eb92c commit a2bd00b
Show file tree
Hide file tree
Showing 4 changed files with 6 additions and 15 deletions.
8 changes: 4 additions & 4 deletions core/src/test/scala/TrainingSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ class TraininSuite extends munit.FunSuite {

torch.manualSeed(1)

var weight = torch.randn(Seq(1), requiresGrad = true)
var bias = torch.zeros(Seq(1), requiresGrad = true)
val weight = torch.randn(Seq(1), requiresGrad = true)
val bias = torch.zeros(Seq(1), requiresGrad = true)

def model(xb: Tensor[Float32]): Tensor[Float32] = (xb matmul weight) + bias

Expand Down Expand Up @@ -57,11 +57,11 @@ class TraininSuite extends munit.FunSuite {
noGrad {
weight.grad.foreach { grad =>
weight -= grad * learningRate
grad.zero()
grad.zero_()
}
bias.grad.foreach { grad =>
weight -= grad * learningRate
grad.zero()
grad.zero_()
}
}
loss
Expand Down
8 changes: 1 addition & 7 deletions core/src/test/scala/torch/DeviceSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,9 @@
package torch

import munit.ScalaCheckSuite
import torch.DeviceType.CUDA
import org.scalacheck.Prop.*
import org.bytedeco.pytorch.global.torch as torch_native
import org.scalacheck.{Arbitrary, Gen}
import org.scalacheck._
import Gen._
import Arbitrary.arbitrary
import DeviceType.CPU
import Generators.{*, given}
import Generators.given

class DeviceSuite extends ScalaCheckSuite {
test("device native roundtrip") {
Expand Down
2 changes: 1 addition & 1 deletion core/src/test/scala/torch/TensorCheckSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ package torch
import munit.ScalaCheckSuite
import shapeless3.typeable.{TypeCase, Typeable}
import shapeless3.typeable.syntax.typeable.*
import Generators.{*, given}
import Generators.*
import org.scalacheck.Prop.*

import scala.util.Try
Expand Down
3 changes: 0 additions & 3 deletions core/src/test/scala/torch/TensorSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,6 @@

package torch

import org.scalacheck.Prop.*
import Generators.given

class TensorSuite extends TensorCheckSuite {

test("tensor properties") {
Expand Down

0 comments on commit a2bd00b

Please sign in to comment.