forked from sbrunk/storch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
build.sbt
150 lines (133 loc) · 4.68 KB
/
build.sbt
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
import sbt._
import Keys._
import MdocPlugin.autoImport._
import LaikaPlugin.autoImport._
// Build-wide project metadata: version series, organization, license,
// developers and publishing targets.
ThisBuild / tlBaseVersion := "0.0" // your current series x.y

ThisBuild / organization := "dev.storch"
ThisBuild / organizationName := "storch.dev"
ThisBuild / startYear := Some(2022)
ThisBuild / licenses := Seq(License.Apache2)
ThisBuild / developers := List(
  // your GitHub handle and name
  tlGitHubDev("sbrunk", "Sören Brunk")
)

// publish to s01.oss.sonatype.org (set to true to publish to oss.sonatype.org instead)
ThisBuild / tlSonatypeUseLegacyHost := false

// publish website from this branch
ThisBuild / tlSitePublishBranch := Some("main")

// Location of the published API documentation.
// Built with sbt's `url` helper rather than `new URL(String)`, which is
// deprecated on recent JDKs.
ThisBuild / apiURL := Some(url("https://storch.dev/api/"))
// Versions of third-party libraries, declared once so that the dependency
// declarations and the documentation variables in this file stay in sync.
val cudaVersion = "12.3-8.9"
val mklVersion = "2024.0"
val openblasVersion = "0.3.25"
val pytorchVersion = "2.1.2"
val scrImageVersion = "4.0.34"
// Scala version used by every project, and the Java version used by the
// generated GitHub workflow.
ThisBuild / scalaVersion := "3.3.1"
ThisBuild / githubWorkflowJavaVersions := Seq(JavaSpec.temurin("11"))

// JavaCPP is pinned to a snapshot version; the Sonatype OSS snapshot
// repositories are added as resolvers (presumably so it can be resolved).
ThisBuild / javaCppVersion := "1.5.10-SNAPSHOT"
ThisBuild / resolvers ++= Resolver.sonatypeOssRepos("snapshots")
// Build-wide switch read by the `core` project to pick the GPU or the
// CPU-only PyTorch preset libraries. Disabled by default.
lazy val enableGPU = settingKey[Boolean]("enable or disable GPU support")

ThisBuild / enableGPU := false
// Whether the MKL preset should be added to the native libraries of `core`.
// Only the first entry of the JavaCPP platform list is inspected; MKL is
// enabled when that entry is linux-x86_64 or windows-x86_64.
val hasMKL = {
  val mklPlatforms = Set("linux-x86_64", "windows-x86_64")
  // headOption instead of head: an empty platform list yields `false`
  // rather than a NoSuchElementException while the build is loading.
  org.bytedeco.sbt.javacpp.Platform.current.headOption.exists(mklPlatforms.contains)
}
// Settings shared by all sub-projects, followed by replacements for the
// `tlReleaseLocal` and `tlRelease` command aliases.
lazy val commonSettings = {
  // This is a hack to avoid depending on the native libs when publishing
  // but conveniently have them on the classpath during development.
  // There's probably a cleaner way to do this.
  val withoutNativeLibs = List(
    "reload",
    "project /",
    "set core / javaCppPlatform := Seq()",
    "set core / javaCppPresetLibs := Seq()"
  )

  // Renders a list of commands as one sbt command line ("; a; b; c").
  def chained(commands: List[String]): String =
    commands.mkString("; ", "; ", "")

  val projectDefaults = Seq(
    Compile / doc / scalacOptions ++= Seq("-groups", "-snippet-compiler:compile"),
    javaCppVersion := (ThisBuild / javaCppVersion).value,
    javaCppPlatform := Seq()
  )

  val releaseLocalAlias = tlReplaceCommandAlias(
    "tlReleaseLocal",
    chained(withoutNativeLibs :+ "+publishLocal")
  )

  val releaseAlias = tlReplaceCommandAlias(
    "tlRelease",
    chained(
      withoutNativeLibs ++ List(
        "+mimaReportBinaryIssues",
        "+publish",
        "tlSonatypeBundleReleaseIfRelevant"
      )
    )
  )

  projectDefaults ++ releaseLocalAlias ++ releaseAlias
}
// Core module: the PyTorch bindings together with the JavaCPP preset
// libraries selected for the current platform.
lazy val core = project
  .in(file("core"))
  .settings(commonSettings)
  .settings(
    javaCppPresetLibs ++= {
      // Read the GPU switch once; it selects both the PyTorch flavour and
      // whether the CUDA redistributable is added.
      val gpu = enableGPU.value
      val pytorchLib = (if (gpu) "pytorch-gpu" else "pytorch") -> pytorchVersion
      val cudaLibs = if (gpu) Seq("cuda-redist" -> cudaVersion) else Seq.empty
      val mklLibs = if (hasMKL) Seq("mkl" -> mklVersion) else Seq.empty
      Seq(pytorchLib, "openblas" -> openblasVersion) ++ cudaLibs ++ mklLibs
    },
    // Overrides the empty platform list set in commonSettings.
    javaCppPlatform := org.bytedeco.sbt.javacpp.Platform.current,
    fork := true,
    Test / fork := true,
    libraryDependencies ++= {
      val munitVersion = "0.7.29"
      // The pytorch artifact version is "<pytorch version>-<javacpp version>".
      val pytorchArtifactVersion = s"$pytorchVersion-${javaCppVersion.value}"
      Seq(
        "org.bytedeco" % "pytorch" % pytorchArtifactVersion,
        "org.typelevel" %% "spire" % "0.18.0",
        "org.typelevel" %% "shapeless3-typeable" % "3.3.0",
        "com.lihaoyi" %% "os-lib" % "0.9.1",
        "com.lihaoyi" %% "sourcecode" % "0.3.0",
        "dev.dirs" % "directories" % "26",
        "org.scalameta" %% "munit" % munitVersion % Test,
        "org.scalameta" %% "munit-scalacheck" % munitVersion % Test
      )
    }
  )
// Vision module: builds on `core` and adds the scrimage image libraries.
lazy val vision = project
  .in(file("vision"))
  .dependsOn(core)
  .settings(commonSettings)
  .settings(
    libraryDependencies ++= {
      // Both scrimage modules share the same organization and version.
      val scrimageModules = Seq("scrimage-core", "scrimage-webp")
        .map(module => "com.sksamuel.scrimage" % module % scrImageVersion)
      scrimageModules :+ ("org.scalameta" %% "munit" % "0.7.29" % Test)
    }
  )
// Example programs: depend on `vision`, are run in a forked JVM, and have
// NoPublishPlugin enabled.
lazy val examples = project
  .in(file("examples"))
  .dependsOn(vision)
  .enablePlugins(NoPublishPlugin)
  .settings(commonSettings)
  .settings(
    fork := true,
    libraryDependencies += "me.tongfei" % "progressbar" % "0.9.5",
    libraryDependencies += "com.github.alexarchambault" %% "case-app" % "2.1.0-M24",
    libraryDependencies += "org.scala-lang.modules" %% "scala-parallel-collections" % "1.0.4"
  )
// Documentation site, located in the `site` directory: mdoc variables,
// unified API docs (unidoc) and the Laika site configuration.
lazy val docs = project
  .in(file("site"))
  .dependsOn(vision)
  .enablePlugins(ScalaUnidocPlugin, TypelevelSitePlugin, StorchSitePlugin)
  .settings(commonSettings)
  .settings(
    // Library versions exposed to the documentation sources as mdoc variables.
    mdocVariables ++= Map(
      "CUDA_VERSION" -> cudaVersion,
      "JAVACPP_VERSION" -> javaCppVersion.value,
      "MKL_VERSION" -> mklVersion,
      "OPENBLAS_VERSION" -> openblasVersion,
      "PYTORCH_VERSION" -> pytorchVersion
    ),
    // The unified API docs cover every project except `examples`.
    ScalaUnidoc / unidoc / unidocProjectFilter := inAnyProject -- inProjects(examples),
    Laika / sourceDirectories += sourceDirectory.value,
    // Include the unidoc output in the generated site.
    laikaIncludeAPI := true,
    laikaGenerateAPI / mappings := (ScalaUnidoc / packageDoc / mappings).value
  )
// Root project at the repository top level: aggregates all sub-projects
// and has NoPublishPlugin enabled.
lazy val root = project
  .in(file("."))
  .enablePlugins(NoPublishPlugin)
  .aggregate(core, vision, examples, docs)
  .settings(javaCppVersion := (ThisBuild / javaCppVersion).value)