Merge pull request #28 from sbrunk/reduction-ops
Add reduction ops
sbrunk committed Jun 20, 2023
2 parents 4d8a82f + fc94bcc commit 89c4d5f
Showing 9 changed files with 1,380 additions and 120 deletions.
@@ -14,4 +14,7 @@
  * limitations under the License.
  */
 
-package object torch {}
+/** @groupname pointwise_ops Pointwise Ops
+  * @groupname reduction_ops Reduction Ops
+  */
+package object torch
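The `@groupname` tags declare scaladoc groups at the package level; ops defined in the `torch` package can then be sorted into those groups with `@group` tags. A minimal sketch of how an op might opt into the new group, assuming scaladoc grouping is enabled in the build (the op name and signature below are hypothetical, not taken from this diff):

package torch

/** Hypothetical reduction op, used only to illustrate the scaladoc group tag.
  *
  * @group reduction_ops
  */
def myReduction[D <: DType](input: Tensor[D]): Tensor[D] = ???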
6 changes: 4 additions & 2 deletions core/src/main/scala/torch/DType.scala
@@ -390,8 +390,10 @@ type NumericPromoted[D <: DType] <: DType = D match
 
 /** Promoted type for tensor operations that always output floats (e.g. `sin`) */
 type FloatPromoted[D <: DType] <: FloatNN = D match
-  case Float64 => Float64
-  case _ => Float32
+  case Float16 => Float16
+  case BFloat16 => BFloat16
+  case Float64 => Float64
+  case _ => Float32
 
 /** Demoted type for complex to real type extractions (e.g. `imag`, `real`) */
 type ComplexToReal[D <: DType] <: DType = D match
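With the widened `FloatPromoted`, half-precision inputs are no longer reported as `Float32` at the type level. A minimal sketch of how the updated match type should now reduce, assuming these lines sit somewhere the `torch` package contents are imported (e.g. a small test file); the dtype type names come from the library's `DType` definitions:

import torch.*

// The two new cases map half-precision dtypes to themselves:
val ev16  = summon[FloatPromoted[Float16] =:= Float16]
val evBf  = summon[FloatPromoted[BFloat16] =:= BFloat16]
// Non-float dtypes still fall through to Float32, as before:
val evInt = summon[FloatPromoted[Int64] =:= Float32]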
2 changes: 2 additions & 0 deletions core/src/main/scala/torch/Tensor.scala
@@ -351,6 +351,8 @@ sealed abstract class Tensor[D <: DType]( /* private[torch] */ val native: pyto
/** Returns the tensor with elements logged. */
def log: Tensor[D] = Tensor(native.log())

def long: Tensor[Int64] = to(dtype = int64)

def matmul[D2 <: DType](u: Tensor[D2]): Tensor[Promoted[D, D2]] =
Tensor[Promoted[D, D2]](native.matmul(u.native))

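The new `long` method is a convenience cast to 64-bit integers, mirroring `Tensor.long()` in PyTorch. A small usage sketch, assuming the `Tensor(Seq(...))` factory accepts `Int` elements the same way it accepts the `Int` and `Complex` values used in NativeConverters below:

import torch.*

val indices = Tensor(Seq(1, 2, 3))
// `long` is just shorthand for an explicit dtype conversion:
val asLong: Tensor[Int64] = indices.long // same as indices.to(dtype = int64)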
4 changes: 4 additions & 0 deletions core/src/main/scala/torch/Types.scala
@@ -45,3 +45,7 @@ type OnlyOneBool[A <: DType, B <: DType] = NotGiven[A =:= Bool & B =:= Bool]

/* Evidence used in operations where at least one Float is required */
type AtLeastOneFloat[A <: DType, B <: DType] = A <:< FloatNN | B <:< FloatNN

/* Evidence used in operations where at least one Float or Complex is required */
type AtLeastOneFloatOrComplex[A <: DType, B <: DType] = A <:< (FloatNN | ComplexNN) |
B <:< (FloatNN | ComplexNN)
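Like the existing `AtLeastOneFloat`, the new alias is meant to be demanded as `using` evidence so that an op only compiles when at least one operand has a floating-point or complex dtype. A hypothetical sketch of a signature consuming it (the op name and body are placeholders, not part of this change):

import torch.*

// Call sites must pass at least one floating-point or complex tensor,
// otherwise the evidence cannot be resolved and compilation fails.
def floatOrComplexOp[D1 <: DType, D2 <: DType](a: Tensor[D1], b: Tensor[D2])(using
    AtLeastOneFloatOrComplex[D1, D2]
): Tensor[Promoted[D1, D2]] = ???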
92 changes: 56 additions & 36 deletions core/src/main/scala/torch/internal/NativeConverters.scala
@@ -34,48 +34,68 @@ import org.bytedeco.pytorch.GenericDict
import org.bytedeco.pytorch.GenericDictIterator
import spire.math.Complex
import spire.math.UByte
import scala.annotation.targetName

private[torch] object NativeConverters:

inline def toOptional[T, U <: T | Option[T], V >: Null](i: U, f: T => V): V = i match
inline def convertToOptional[T, U <: T | Option[T], V >: Null](i: U, f: T => V): V = i match
case i: Option[T] => i.map(f(_)).orNull
case i: T => f(i)

def toOptional(l: Long | Option[Long]): pytorch.LongOptional =
toOptional(l, pytorch.LongOptional(_))
def toOptional(l: Double | Option[Double]): pytorch.DoubleOptional =
toOptional(l, pytorch.DoubleOptional(_))

def toOptional(l: Real | Option[Real]): pytorch.ScalarOptional =
toOptional(
l,
(r: Real) =>
val scalar = toScalar(r)
pytorch.ScalarOptional(scalar)
)

def toOptional[D <: DType](t: Tensor[D] | Option[Tensor[D]]): TensorOptional =
toOptional(t, t => pytorch.TensorOptional(t.native))

def toArray(i: Long | (Long, Long)) = i match
case i: Long => Array(i)
case (i, j) => Array(i, j)

def toNative(input: Int | (Int, Int)) = input match
case (h, w) => LongPointer(Array(h.toLong, w.toLong)*)
case x: Int => LongPointer(Array(x.toLong, x.toLong)*)

def toScalar(x: ScalaType): pytorch.Scalar = x match
case x: Boolean => pytorch.Scalar(if true then 1: Byte else 0: Byte)
case x: UByte => Tensor(x.toInt).to(dtype = uint8).native.item()
case x: Byte => pytorch.Scalar(x)
case x: Short => pytorch.Scalar(x)
case x: Int => pytorch.Scalar(x)
case x: Long => pytorch.Scalar(x)
case x: Float => pytorch.Scalar(x)
case x: Double => pytorch.Scalar(x)
case x @ Complex(r: Float, i: Float) => Tensor(Seq(x)).to(dtype = complex64).native.item()
case x @ Complex(r: Double, i: Double) => Tensor(Seq(x)).to(dtype = complex128).native.item()
extension (l: Long | Option[Long])
def toOptional: pytorch.LongOptional = convertToOptional(l, pytorch.LongOptional(_))

extension (l: Double | Option[Double])
def toOptional: pytorch.DoubleOptional = convertToOptional(l, pytorch.DoubleOptional(_))

extension (l: Real | Option[Real])
def toOptional: pytorch.ScalarOptional =
convertToOptional(
l,
(r: Real) =>
val scalar = toScalar(r)
pytorch.ScalarOptional(scalar)
)

extension [D <: DType](t: Tensor[D] | Option[Tensor[D]])
def toOptional: TensorOptional =
convertToOptional(t, t => pytorch.TensorOptional(t.native))

extension (i: Long | (Long, Long))
def toArray = i match
case i: Long => Array(i)
case (i, j) => Array(i, j)

extension (i: Int | Seq[Int])
@targetName("intOrIntSeqToArray")
def toArray: Array[Long] = i match
case i: Int => Array(i.toLong)
case i: Seq[Int] => i.map(_.toLong).toArray

extension (i: Long | Seq[Long])
@targetName("longOrLongSeqToArray")
def toArray: Array[Long] = i match
case i: Long => Array(i)
case i: Seq[Long] => i.toArray

extension (input: Int | (Int, Int))
def toNative = input match
case (h, w) => LongPointer(Array(h.toLong, w.toLong)*)
case x: Int => LongPointer(Array(x.toLong, x.toLong)*)

extension (x: ScalaType)
def toScalar: pytorch.Scalar = x match
case x: Boolean => pytorch.Scalar(if x then 1: Byte else 0: Byte)
case x: UByte => Tensor(x.toInt).to(dtype = uint8).native.item()
case x: Byte => pytorch.Scalar(x)
case x: Short => pytorch.Scalar(x)
case x: Int => pytorch.Scalar(x)
case x: Long => pytorch.Scalar(x)
case x: Float => pytorch.Scalar(x)
case x: Double => pytorch.Scalar(x)
case x @ Complex(r: Float, i: Float) => Tensor(Seq(x)).to(dtype = complex64).native.item()
case x @ Complex(r: Double, i: Double) => Tensor(Seq(x)).to(dtype = complex128).native.item()

def tensorOptions(
dtype: DType,
layout: Layout,
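Turning the converters into extension methods makes the conversions read postfix at call sites, and the `@targetName` annotations presumably keep the overloaded `toArray` extensions from clashing after erasure. A hedged usage sketch; since `NativeConverters` is `private[torch]`, this assumes code living inside the `torch` package:

package torch

import torch.internal.NativeConverters.*

// A union-typed parameter (the shape the library's ops use) picks up the new
// extension, because `toArray` is not a member of `Int | Seq[Int]` itself:
def dimsToNative(dim: Int | Seq[Int]): Array[Long] = dim.toArray

val single = dimsToNative(2)         // Array(2L)
val multi  = dimsToNative(Seq(0, 1)) // Array(0L, 1L)

// Scalar conversion reads postfix as well:
val alpha: org.bytedeco.pytorch.Scalar = 0.5.toScalar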
(Diffs for the remaining 4 changed files are not shown.)
