Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update to PyTorch 2.1.0 #64

Merged
merged 2 commits into from
Nov 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ ThisBuild / tlSitePublishBranch := Some("main")
ThisBuild / apiURL := Some(new URL("https://storch.dev/api/"))

val scrImageVersion = "4.0.34"
val pytorchVersion = "2.0.1"
val cudaVersion = "12.1-8.9"
val pytorchVersion = "2.1.0"
val cudaVersion = "12.3-8.9"
val openblasVersion = "0.3.23"
val mklVersion = "2023.1"
ThisBuild / scalaVersion := "3.3.1"
Expand Down
223 changes: 138 additions & 85 deletions core/src/main/scala/torch/DType.scala
Original file line number Diff line number Diff line change
Expand Up @@ -98,26 +98,33 @@ import scala.compiletime.{erasedValue, summonFrom}
// format: on
sealed abstract class DType private[torch] ():
private[torch] def toScalarType: ScalarType = this match
case _: UInt8 => ScalarType.Byte
case _: Int8 => ScalarType.Char
case _: Int16 => ScalarType.Short
case _: Int32 => ScalarType.Int
case _: Int64 => ScalarType.Long
case _: Float32 => ScalarType.Float
case _: Float64 => ScalarType.Double
case _: Complex32 => ScalarType.ComplexHalf
case _: Complex64 => ScalarType.ComplexFloat
case _: Complex128 => ScalarType.ComplexDouble
case _: Bool => ScalarType.Bool
case _: QInt8 => ScalarType.QInt8
case _: QUInt8 => ScalarType.QUInt8
case _: QInt32 => ScalarType.QInt32
case _: BFloat16 => ScalarType.BFloat16
case _: QUInt4x2 => ScalarType.QUInt4x2
case _: QUInt2x4 => ScalarType.QUInt2x4
case _: Float16 => ScalarType.Half
case _: Undefined => ScalarType.Undefined
case _: NumOptions => ScalarType.NumOptions
case _: UInt8 => ScalarType.Byte
case _: Int8 => ScalarType.Char
case _: Int16 => ScalarType.Short
case _: Int32 => ScalarType.Int
case _: Int64 => ScalarType.Long
case _: Float16 => ScalarType.Half
case _: Float32 => ScalarType.Float
case _: Float64 => ScalarType.Double
case _: Complex32 => ScalarType.ComplexHalf
case _: Complex64 => ScalarType.ComplexFloat
case _: Complex128 => ScalarType.ComplexDouble
case _: Bool => ScalarType.Bool
case _: QInt8 => ScalarType.QInt8
case _: QUInt8 => ScalarType.QUInt8
case _: QInt32 => ScalarType.QInt32
case _: BFloat16 => ScalarType.BFloat16
case _: QUInt4x2 => ScalarType.QUInt4x2
case _: QUInt2x4 => ScalarType.QUInt2x4
case _: Bits1x8 => ScalarType.Bits1x8
case _: Bits2x4 => ScalarType.Bits2x4
case _: Bits4x2 => ScalarType.Bits4x2
case _: Bits8 => ScalarType.Bits8
case _: Bits16 => ScalarType.Bits16
case _: Float8_e5m2 => ScalarType.Float8_e5m2
case _: Float8_e4m3fn => ScalarType.Float8_e4m3fn
case _: Undefined => ScalarType.Undefined
case _: NumOptions => ScalarType.NumOptions

object DType:
private[torch] def fromScalarType(t: ScalarType): DType = t.intern() match
Expand All @@ -138,6 +145,13 @@ object DType:
case ScalarType.BFloat16 => bfloat16
case ScalarType.QUInt4x2 => quint4x2
case ScalarType.QUInt2x4 => quint2x4
case ScalarType.Bits1x8 => bits1x8
case ScalarType.Bits2x4 => bits2x4
case ScalarType.Bits4x2 => bits4x2
case ScalarType.Bits8 => bits8
case ScalarType.Bits16 => bits16
case ScalarType.Float8_e5m2 => float8_e5m2
case ScalarType.Float8_e4m3fn => float8_e4m3fn
case ScalarType.Half => float16
case ScalarType.Undefined => undefined
case ScalarType.NumOptions => numoptions
Expand All @@ -146,48 +160,63 @@ object DType:
case object int16 extends Int16 /* 2, */
case object int32 extends Int32 /* 3, */
case object int64 extends Int64 /* 4, */
case object float32 extends Float32 /* 5, */
case object float64 extends Float64 /* 6, */
case object complex32 extends Complex32 /* 7, */
case object complex64 extends Complex64 /* 8, */
case object complex128 extends Complex128 /* 9, */
case object bool extends Bool /* 10 */
case object qint8 extends QInt8 /* 11 */
case object quint8 extends QUInt8 /* 12 */
case object qint32 extends QInt32 /* 13 */
case object bfloat16 extends BFloat16 /* 14 */
case object quint4x2 extends QUInt4x2 /* 15 */
case object quint2x4 extends QUInt2x4 /* 15 */
case object float16 extends Float16 /* 16, */
case object undefined extends Undefined /* 17 */
case object numoptions extends NumOptions /* 18 */
case object float16 extends Float16 /* 5*/
case object float32 extends Float32 /* 6 */
case object float64 extends Float64 /* 7 */
case object complex32 extends Complex32 /* 8 */
case object complex64 extends Complex64 /* 9 */
case object complex128 extends Complex128 /* 10 */
case object bool extends Bool /* 11 */
case object qint8 extends QInt8 /* 12 */
case object quint8 extends QUInt8 /* 13 */
case object qint32 extends QInt32 /* 14 */
case object bfloat16 extends BFloat16 /* 15 */
case object quint4x2 extends QUInt4x2 /* 16 */
case object quint2x4 extends QUInt2x4 /* 17 */
case object bits1x8 extends Bits1x8 /* 18 */
case object bits2x4 extends Bits2x4 /* 19 */
case object bits4x2 extends Bits4x2 /* 20 */
case object bits8 extends Bits8 /* 21 */
case object bits16 extends Bits16 /* 22 */
case object float8_e5m2 extends Float8_e5m2 /* 23 */
case object float8_e4m3fn extends Float8_e4m3fn /* 24 */
case object undefined extends Undefined /* 25 */
case object numoptions extends NumOptions /* 26 */

sealed abstract class UInt8 extends DType /* 0, Byte */
sealed abstract class Int8 extends DType /* 1, Char */
sealed abstract class Int16 extends DType /* 2, Short */
sealed abstract class Int32 extends DType /* 3, Int */
sealed abstract class Int64 extends DType /* 4, Long */
sealed abstract class Float32 extends DType /* 5, Float */
sealed abstract class Float64 extends DType /* 6, Double */
sealed abstract class Complex32 extends DType /* 7, ComplexHalf */
sealed abstract class Complex64 extends DType /* 8, ComplexFloat */
sealed abstract class Complex128 extends DType /* 9, ComplexDouble */
sealed abstract class Float16 extends DType /* 5, Half */
sealed abstract class Float32 extends DType /* 6, Float */
sealed abstract class Float64 extends DType /* 7, Double */
sealed abstract class Complex32 extends DType /* 8, ComplexHalf */
sealed abstract class Complex64 extends DType /* 9, ComplexFloat */
sealed abstract class Complex128 extends DType /* 10, ComplexDouble */
sealed abstract class Bool extends DType /* 10 */
sealed abstract class QInt8 extends DType /* 11 */
sealed abstract class QUInt8 extends DType /* 12 */
sealed abstract class QInt32 extends DType /* 13 */
sealed abstract class BFloat16 extends DType /* 14 */
sealed abstract class QUInt4x2 extends DType /* 15 */
sealed abstract class QUInt2x4 extends DType /* 16 */
sealed abstract class Float16 extends DType /* 17, Half */
sealed abstract class Undefined extends DType /* 18 */
sealed abstract class NumOptions extends DType /* 18 */
sealed abstract class Bits1x8 extends DType
sealed abstract class Bits2x4 extends DType
sealed abstract class Bits4x2 extends DType
sealed abstract class Bits8 extends DType
sealed abstract class Bits16 extends DType
sealed abstract class Float8_e5m2 extends DType
sealed abstract class Float8_e4m3fn extends DType
sealed abstract class Undefined extends DType
sealed abstract class NumOptions extends DType

val uint8: UInt8 = DType.uint8
val int8: Int8 = DType.int8
val int16: Int16 = DType.int16
val int32: Int32 = DType.int32
val int64: Int64 = DType.int64
val float16: Float16 = DType.float16
val float32: Float32 = DType.float32
val float64: Float64 = DType.float64
val complex32: Complex32 = DType.complex32
Expand All @@ -200,7 +229,13 @@ val qint32: QInt32 = DType.qint32
val bfloat16: BFloat16 = DType.bfloat16
val quint4x2: QUInt4x2 = DType.quint4x2
val quint2x4: QUInt2x4 = DType.quint2x4
val float16: Float16 = DType.float16
val bits1x8: Bits1x8 = DType.bits1x8
val bits2x4: Bits2x4 = DType.bits2x4
val bits4x2: Bits4x2 = DType.bits4x2
val bits8: Bits8 = DType.bits8
val bits16: Bits16 = DType.bits16
val float8_e5m2: Float8_e5m2 = DType.float8_e5m2
val float8_e4m3fn: Float8_e4m3fn = DType.float8_e4m3fn
val undefined: Undefined = DType.undefined
val numoptions: NumOptions = DType.numoptions

Expand Down Expand Up @@ -308,27 +343,34 @@ def scalaToDType[S <: ScalaType](s: S): DType = s match
case Complex(_: Float, _: Float) => complex64

type TensorType[T] <: DType = T match
case UInt8 => UInt8
case Int8 => Int8
case Int16 => Int16
case Int32 => Int32
case Int64 => Int64
case Float32 => Float32
case Float64 => Float64
case Complex32 => Complex32
case Complex64 => Complex64
case Complex128 => Complex128
case Bool => Bool
case QInt8 => QInt8
case QUInt8 => QUInt8
case QInt32 => QInt32
case BFloat16 => BFloat16
case QUInt4x2 => QUInt4x2
case QUInt2x4 => QUInt2x4
case Float16 => Float16
case Undefined => Undefined
case NumOptions => NumOptions
case DType => DType
case UInt8 => UInt8
case Int8 => Int8
case Int16 => Int16
case Int32 => Int32
case Int64 => Int64
case Float16 => Float16
case Float32 => Float32
case Float64 => Float64
case Complex32 => Complex32
case Complex64 => Complex64
case Complex128 => Complex128
case Bool => Bool
case QInt8 => QInt8
case QUInt8 => QUInt8
case QInt32 => QInt32
case BFloat16 => BFloat16
case QUInt4x2 => QUInt4x2
case QUInt2x4 => QUInt2x4
case Bits1x8 => Bits1x8
case Bits2x4 => Bits2x4
case Bits4x2 => Bits4x2
case Bits8 => Bits8
case Bits16 => Bits16
case Float8_e5m2 => Float8_e5m2
case Float8_e4m3fn => Float8_e4m3fn
case Undefined => Undefined
case NumOptions => NumOptions
case DType => DType

type DTypeOrDeriveFromScalar[T <: DType | Derive, U <: ScalaType] <: DType = T match
case Derive => ScalaToDType[U]
Expand Down Expand Up @@ -400,6 +442,10 @@ type Promoted[T <: DType, U <: DType] <: DType = (T, U) match
case (T, Int32) => T
case (Int64, U) => U
case (T, Int64) => T
case (Float8_e5m2, U) => U
case (T, Float8_e5m2) => T
case (Float8_e4m3fn, U) => U
case (T, Float8_e5m2) => T
case (Float16, BFloat16) | (BFloat16, Float16) => Float32
case (Float16, U) => U
case (T, Float16) => T
Expand Down Expand Up @@ -467,23 +513,30 @@ private[torch] type TypedBuffer[T <: ScalaType] <: Buffer = T match

transparent inline def deriveDType[T <: DType]: DType =
inline erasedValue[T] match
case _: UInt8 => uint8
case _: Int8 => int8
case _: Int16 => int16
case _: Int32 => int32
case _: Int64 => int64
case _: Float32 => float32
case _: Float64 => float64
case _: Complex32 => complex32
case _: Complex64 => complex64
case _: Complex128 => complex128
case _: Bool => bool
case _: QInt8 => qint8
case _: QUInt8 => quint8
case _: QInt32 => qint32
case _: BFloat16 => bfloat16
case _: QUInt4x2 => quint4x2
case _: QUInt2x4 => quint2x4
case _: Float16 => float16
case _: Undefined => undefined
case _: NumOptions => numoptions
case _: UInt8 => uint8
case _: Int8 => int8
case _: Int16 => int16
case _: Int32 => int32
case _: Int64 => int64
case _: Float16 => float16
case _: Float32 => float32
case _: Float64 => float64
case _: Complex32 => complex32
case _: Complex64 => complex64
case _: Complex128 => complex128
case _: Bool => bool
case _: QInt8 => qint8
case _: QUInt8 => quint8
case _: QInt32 => qint32
case _: BFloat16 => bfloat16
case _: QUInt4x2 => quint4x2
case _: QUInt2x4 => quint2x4
case _: Bits1x8 => bits1x8
case _: Bits2x4 => bits2x4
case _: Bits4x2 => bits4x2
case _: Bits8 => bits8
case _: Bits16 => bits16
case _: Float8_e5m2 => float8_e5m2
case _: Float8_e4m3fn => float8_e4m3fn
case _: Undefined => undefined
case _: NumOptions => numoptions
Loading