diff --git a/build.sbt b/build.sbt index 7cd092d6..9f8f388d 100644 --- a/build.sbt +++ b/build.sbt @@ -24,10 +24,10 @@ ThisBuild / tlSitePublishBranch := Some("main") ThisBuild / apiURL := Some(new URL("https://storch.dev/api/")) val scrImageVersion = "4.0.34" -val pytorchVersion = "2.1.1" +val pytorchVersion = "2.1.2" val cudaVersion = "12.3-8.9" -val openblasVersion = "0.3.23" -val mklVersion = "2023.1" +val openblasVersion = "0.3.25" +val mklVersion = "2024.0" ThisBuild / scalaVersion := "3.3.1" ThisBuild / javaCppVersion := "1.5.10-SNAPSHOT" ThisBuild / resolvers ++= Resolver.sonatypeOssRepos("snapshots") @@ -38,6 +38,11 @@ val enableGPU = settingKey[Boolean]("enable or disable GPU support") ThisBuild / enableGPU := false +val hasMKL = { + val firstPlatform = org.bytedeco.sbt.javacpp.Platform.current.head + firstPlatform == "linux-x86_64" || firstPlatform == "windows-x86_64" +} + lazy val commonSettings = Seq( Compile / doc / scalacOptions ++= Seq("-groups", "-snippet-compiler:compile"), javaCppVersion := (ThisBuild / javaCppVersion).value, @@ -73,9 +78,9 @@ lazy val core = project .settings( javaCppPresetLibs ++= Seq( (if (enableGPU.value) "pytorch-gpu" else "pytorch") -> pytorchVersion, - "mkl" -> mklVersion, "openblas" -> openblasVersion - ) ++ (if (enableGPU.value) Seq("cuda-redist" -> cudaVersion) else Seq()), + ) ++ (if (enableGPU.value) Seq("cuda-redist" -> cudaVersion) else Seq()) + ++ (if (hasMKL) Seq("mkl" -> mklVersion) else Seq()), javaCppPlatform := org.bytedeco.sbt.javacpp.Platform.current, fork := true, Test / fork := true, diff --git a/project/build.properties b/project/build.properties index fd5b1576..8cf07b7c 100644 --- a/project/build.properties +++ b/project/build.properties @@ -1 +1 @@ -sbt.version=1.9.0 \ No newline at end of file +sbt.version=1.9.8 \ No newline at end of file