Skip to content

Commit

Permalink
Fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
sbrunk committed Feb 5, 2024
1 parent 11eb92c commit 0be74e7
Show file tree
Hide file tree
Showing 5 changed files with 7 additions and 17 deletions.
3 changes: 1 addition & 2 deletions core/src/main/scala/torch/hub.scala
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ object hub:
if !os.exists(cachedFile) then
System.err.println(s"Downloading: $url to $cachedFile")
Using.resource(URL(url).openStream()) { inputStream =>
Files.copy(inputStream, cachedFile.toNIO)
()
val _ = Files.copy(inputStream, cachedFile.toNIO)
}
torch.pickleLoad(cachedFile.toNIO)
8 changes: 4 additions & 4 deletions core/src/test/scala/TrainingSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ class TraininSuite extends munit.FunSuite {

torch.manualSeed(1)

var weight = torch.randn(Seq(1), requiresGrad = true)
var bias = torch.zeros(Seq(1), requiresGrad = true)
val weight = torch.randn(Seq(1), requiresGrad = true)
val bias = torch.zeros(Seq(1), requiresGrad = true)

def model(xb: Tensor[Float32]): Tensor[Float32] = (xb matmul weight) + bias

Expand Down Expand Up @@ -57,11 +57,11 @@ class TraininSuite extends munit.FunSuite {
noGrad {
weight.grad.foreach { grad =>
weight -= grad * learningRate
grad.zero()
grad.zero_()
}
bias.grad.foreach { grad =>
weight -= grad * learningRate
grad.zero()
grad.zero_()
}
}
loss
Expand Down
8 changes: 1 addition & 7 deletions core/src/test/scala/torch/DeviceSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,9 @@
package torch

import munit.ScalaCheckSuite
import torch.DeviceType.CUDA
import org.scalacheck.Prop.*
import org.bytedeco.pytorch.global.torch as torch_native
import org.scalacheck.{Arbitrary, Gen}
import org.scalacheck._
import Gen._
import Arbitrary.arbitrary
import DeviceType.CPU
import Generators.{*, given}
import Generators.given

class DeviceSuite extends ScalaCheckSuite {
test("device native roundtrip") {
Expand Down
2 changes: 1 addition & 1 deletion core/src/test/scala/torch/TensorCheckSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ package torch
import munit.ScalaCheckSuite
import shapeless3.typeable.{TypeCase, Typeable}
import shapeless3.typeable.syntax.typeable.*
import Generators.{*, given}
import Generators.*
import org.scalacheck.Prop.*

import scala.util.Try
Expand Down
3 changes: 0 additions & 3 deletions core/src/test/scala/torch/TensorSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,6 @@

package torch

import org.scalacheck.Prop.*
import Generators.given

class TensorSuite extends TensorCheckSuite {

test("tensor properties") {
Expand Down

0 comments on commit 0be74e7

Please sign in to comment.