From 6ae3bb86368272a181fb74971d9375819756fc19 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=B6ren=20Brunk?= Date: Sun, 18 Jun 2023 14:13:19 +0200 Subject: [PATCH 1/3] Implement reduction ops --- core/src/main/scala/{torch => }/package.scala | 5 +- core/src/main/scala/torch/Types.scala | 3 + .../torch/internal/NativeConverters.scala | 92 +- .../main/scala/torch/ops/PointwiseOps.scala | 277 +++++- .../main/scala/torch/ops/ReductionOps.scala | 877 +++++++++++++++++- .../scala/torch/ops/ReductionOpsSuite.scala | 2 +- 6 files changed, 1139 insertions(+), 117 deletions(-) rename core/src/main/scala/{torch => }/package.scala (84%) diff --git a/core/src/main/scala/torch/package.scala b/core/src/main/scala/package.scala similarity index 84% rename from core/src/main/scala/torch/package.scala rename to core/src/main/scala/package.scala index 391ab742..e4169585 100644 --- a/core/src/main/scala/torch/package.scala +++ b/core/src/main/scala/package.scala @@ -14,4 +14,7 @@ * limitations under the License. */ -package object torch {} +/** @groupname pointwise_ops Pointwise Ops + * @groupname reduction_ops Reduction Ops + */ +package object torch diff --git a/core/src/main/scala/torch/Types.scala b/core/src/main/scala/torch/Types.scala index c32a9823..9e6fc279 100644 --- a/core/src/main/scala/torch/Types.scala +++ b/core/src/main/scala/torch/Types.scala @@ -45,3 +45,6 @@ type OnlyOneBool[A <: DType, B <: DType] = NotGiven[A =:= Bool & B =:= Bool] /* Evidence used in operations where at least one Float is required */ type AtLeastOneFloat[A <: DType, B <: DType] = A <:< FloatNN | B <:< FloatNN + +/* Evidence used in operations where at least one Float or Complex is required */ +type AtLeastOneFloatOrComplex[A <: DType, B <: DType] = A <:< FloatNN | B <:< FloatNN diff --git a/core/src/main/scala/torch/internal/NativeConverters.scala b/core/src/main/scala/torch/internal/NativeConverters.scala index 84f13015..daafd300 100644 --- a/core/src/main/scala/torch/internal/NativeConverters.scala +++ b/core/src/main/scala/torch/internal/NativeConverters.scala @@ -34,48 +34,68 @@ import org.bytedeco.pytorch.GenericDict import org.bytedeco.pytorch.GenericDictIterator import spire.math.Complex import spire.math.UByte +import scala.annotation.targetName private[torch] object NativeConverters: - inline def toOptional[T, U <: T | Option[T], V >: Null](i: U, f: T => V): V = i match + inline def convertToOptional[T, U <: T | Option[T], V >: Null](i: U, f: T => V): V = i match case i: Option[T] => i.map(f(_)).orNull case i: T => f(i) - def toOptional(l: Long | Option[Long]): pytorch.LongOptional = - toOptional(l, pytorch.LongOptional(_)) - def toOptional(l: Double | Option[Double]): pytorch.DoubleOptional = - toOptional(l, pytorch.DoubleOptional(_)) - - def toOptional(l: Real | Option[Real]): pytorch.ScalarOptional = - toOptional( - l, - (r: Real) => - val scalar = toScalar(r) - pytorch.ScalarOptional(scalar) - ) - - def toOptional[D <: DType](t: Tensor[D] | Option[Tensor[D]]): TensorOptional = - toOptional(t, t => pytorch.TensorOptional(t.native)) - - def toArray(i: Long | (Long, Long)) = i match - case i: Long => Array(i) - case (i, j) => Array(i, j) - - def toNative(input: Int | (Int, Int)) = input match - case (h, w) => LongPointer(Array(h.toLong, w.toLong)*) - case x: Int => LongPointer(Array(x.toLong, x.toLong)*) - - def toScalar(x: ScalaType): pytorch.Scalar = x match - case x: Boolean => pytorch.Scalar(if true then 1: Byte else 0: Byte) - case x: UByte => Tensor(x.toInt).to(dtype = uint8).native.item() - case x: Byte => 
pytorch.Scalar(x) - case x: Short => pytorch.Scalar(x) - case x: Int => pytorch.Scalar(x) - case x: Long => pytorch.Scalar(x) - case x: Float => pytorch.Scalar(x) - case x: Double => pytorch.Scalar(x) - case x @ Complex(r: Float, i: Float) => Tensor(Seq(x)).to(dtype = complex64).native.item() - case x @ Complex(r: Double, i: Double) => Tensor(Seq(x)).to(dtype = complex128).native.item() + extension (l: Long | Option[Long]) + def toOptional: pytorch.LongOptional = convertToOptional(l, pytorch.LongOptional(_)) + + extension (l: Double | Option[Double]) + def toOptional: pytorch.DoubleOptional = convertToOptional(l, pytorch.DoubleOptional(_)) + + extension (l: Real | Option[Real]) + def toOptional: pytorch.ScalarOptional = + convertToOptional( + l, + (r: Real) => + val scalar = toScalar(r) + pytorch.ScalarOptional(scalar) + ) + + extension [D <: DType](t: Tensor[D] | Option[Tensor[D]]) + def toOptional: TensorOptional = + convertToOptional(t, t => pytorch.TensorOptional(t.native)) + + extension (i: Long | (Long, Long)) + def toArray = i match + case i: Long => Array(i) + case (i, j) => Array(i, j) + + extension (i: Int | Seq[Int]) + @targetName("intOrIntSeqToArray") + def toArray: Array[Long] = i match + case i: Int => Array(i.toLong) + case i: Seq[Int] => i.map(_.toLong).toArray + + extension (i: Long | Seq[Long]) + @targetName("longOrLongSeqToArray") + def toArray: Array[Long] = i match + case i: Long => Array(i) + case i: Seq[Long] => i.toArray + + extension (input: Int | (Int, Int)) + def toNative = input match + case (h, w) => LongPointer(Array(h.toLong, w.toLong)*) + case x: Int => LongPointer(Array(x.toLong, x.toLong)*) + + extension (x: ScalaType) + def toScalar: pytorch.Scalar = x match + case x: Boolean => pytorch.Scalar(if true then 1: Byte else 0: Byte) + case x: UByte => Tensor(x.toInt).to(dtype = uint8).native.item() + case x: Byte => pytorch.Scalar(x) + case x: Short => pytorch.Scalar(x) + case x: Int => pytorch.Scalar(x) + case x: Long => pytorch.Scalar(x) + case x: Float => pytorch.Scalar(x) + case x: Double => pytorch.Scalar(x) + case x @ Complex(r: Float, i: Float) => Tensor(Seq(x)).to(dtype = complex64).native.item() + case x @ Complex(r: Double, i: Double) => Tensor(Seq(x)).to(dtype = complex128).native.item() + def tensorOptions( dtype: DType, layout: Layout, diff --git a/core/src/main/scala/torch/ops/PointwiseOps.scala b/core/src/main/scala/torch/ops/PointwiseOps.scala index 6a1cc4fd..86983224 100644 --- a/core/src/main/scala/torch/ops/PointwiseOps.scala +++ b/core/src/main/scala/torch/ops/PointwiseOps.scala @@ -24,23 +24,38 @@ import org.bytedeco.pytorch.global.torch as torchNative * https://pytorch.org/docs/stable/torch.html#pointwise-ops */ -/** Computes the absolute value of each element in `input`. */ +/** Computes the absolute value of each element in `input`. + * + * @group pointwise_ops + */ def abs[D <: NumericNN](input: Tensor[D]): Tensor[D] = Tensor(torchNative.abs(input.native)) -/** Computes the inverse cosine of each element in `input`. */ +/** Computes the inverse cosine of each element in `input`. + * + * @group pointwise_ops + */ def acos[D <: DType](input: Tensor[D]): Tensor[D] = Tensor(torchNative.acos(input.native)) -/** Returns a new tensor with the inverse hyperbolic cosine of the elements of `input` . */ +/** Returns a new tensor with the inverse hyperbolic cosine of the elements of `input` . + * + * @group pointwise_ops + */ def acosh[D <: DType](input: Tensor[D]): Tensor[D] = Tensor(torchNative.acosh(input.native)) -/** Adds `other` to `input`. 
*/ +/** Adds `other` to `input`. + * + * @group pointwise_ops + */ def add[D <: DType, D2 <: DType](input: Tensor[D], other: Tensor[D2]): Tensor[Promoted[D, D2]] = Tensor(torchNative.add(input.native, other.native)) -/** Adds `other` to `input`. */ +/** Adds `other` to `input`. + * + * @group pointwise_ops + */ def add[D <: DType, S <: ScalaType]( input: Tensor[D], other: S @@ -49,6 +64,8 @@ def add[D <: DType, S <: ScalaType]( /** Performs the element-wise division of tensor1 by tensor2, multiplies the result by the scalar * value and adds it to input. + * + * @group pointwise_ops */ def addcdiv[D <: DType, D2 <: DType, D3 <: DType]( input: Tensor[D], @@ -60,6 +77,8 @@ def addcdiv[D <: DType, D2 <: DType, D3 <: DType]( /** Performs the element-wise multiplication of tensor1 by tensor2, multiplies the result by the * scalar value and adds it to input. + * + * @group pointwise_ops */ def addcmul[D <: DType, D2 <: DType, D3 <: DType]( input: Tensor[D], @@ -69,23 +88,38 @@ def addcmul[D <: DType, D2 <: DType, D3 <: DType]( ): Tensor[Promoted[D, Promoted[D2, D3]]] = Tensor(torchNative.addcmul(input.native, tensor1.native, tensor2.native, toScalar(value))) -/** Computes the element-wise angle (in radians) of the given `input` tensor. */ +/** Computes the element-wise angle (in radians) of the given `input` tensor. + * + * @group pointwise_ops + */ def angle[D <: DType](input: Tensor[D]): Tensor[FloatPromoted[ComplexToReal[D]]] = Tensor(torchNative.angle(input.native)) -/** Returns a new tensor with the arcsine of the elements of `input`. */ +/** Returns a new tensor with the arcsine of the elements of `input`. + * + * @group pointwise_ops + */ def asin[D <: DType](input: Tensor[D]): Tensor[FloatPromoted[D]] = Tensor(torchNative.asin(input.native)) -/** Returns a new tensor with the inverse hyperbolic sine of the elements of `input`. */ +/** Returns a new tensor with the inverse hyperbolic sine of the elements of `input`. + * + * @group pointwise_ops + */ def asinh[D <: DType](input: Tensor[D]): Tensor[FloatPromoted[D]] = Tensor(torchNative.asinh(input.native)) -/** Returns a new tensor with the arctangent of the elements of `input`. */ +/** Returns a new tensor with the arctangent of the elements of `input`. + * + * @group pointwise_ops + */ def atan[D <: DType](input: Tensor[D]): Tensor[FloatPromoted[D]] = Tensor(torchNative.atan(input.native)) -/** Returns a new tensor with the inverse hyperbolic tangent of the elements of `input`. */ +/** Returns a new tensor with the inverse hyperbolic tangent of the elements of `input`. + * + * @group pointwise_ops + */ def atanh[D <: DType](input: Tensor[D]): Tensor[FloatPromoted[D]] = Tensor(torchNative.atanh(input.native)) @@ -93,6 +127,8 @@ def atanh[D <: DType](input: Tensor[D]): Tensor[FloatPromoted[D]] = * tensor with the signed angles in radians between vector (other, input) and vector (1, 0). (Note * that other, the second parameter, is the x-coordinate, while input, the first parameter, is the * y-coordinate.) + * + * @group pointwise_ops */ def atan2[D <: RealNN, D2 <: RealNN]( input: Tensor[D], @@ -102,11 +138,15 @@ def atan2[D <: RealNN, D2 <: RealNN]( /** Computes the bitwise NOT of the given `input` tensor. The `input` tensor must be of integral or * Boolean types. For bool tensors, it computes the logical NOT. + * + * @group pointwise_ops */ def bitwiseNot[D <: BitwiseNN](input: Tensor[D]): Tensor[D] = Tensor(torchNative.bitwise_not(input.native)) /** Computes the bitwise AND of `input` and `other`. 
For bool tensors, it computes the logical AND. + * + * @group pointwise_ops */ def bitwiseAnd[D <: BitwiseNN, D2 <: BitwiseNN]( input: Tensor[D], @@ -115,6 +155,8 @@ def bitwiseAnd[D <: BitwiseNN, D2 <: BitwiseNN]( Tensor(torchNative.bitwise_and(input.native, other.native)) /** Computes the bitwise OR of `input` and `other`. For bool tensors, it computes the logical OR. + * + * @group pointwise_ops */ def bitwiseOr[D <: BitwiseNN, D2 <: BitwiseNN]( input: Tensor[D], @@ -123,6 +165,8 @@ def bitwiseOr[D <: BitwiseNN, D2 <: BitwiseNN]( Tensor(torchNative.bitwise_or(input.native, other.native)) /** Computes the bitwise XOR of `input` and `other`. For bool tensors, it computes the logical XOR. + * + * @group pointwise_ops */ def bitwiseXor[D <: BitwiseNN, D2 <: BitwiseNN]( input: Tensor[D], @@ -130,7 +174,10 @@ def bitwiseXor[D <: BitwiseNN, D2 <: BitwiseNN]( ): Tensor[Promoted[D, D2]] = Tensor(torchNative.bitwise_xor(input.native, other.native)) -/** Computes the left arithmetic shift of `input` by `other` bits. */ +/** Computes the left arithmetic shift of `input` by `other` bits. + * + * @group pointwise_ops + */ def bitwiseLeftShift[D <: BitwiseNN, D2 <: BitwiseNN]( input: Tensor[D], @@ -138,7 +185,10 @@ def bitwiseLeftShift[D <: BitwiseNN, D2 <: BitwiseNN]( )(using OnlyOneBool[D, D2]): Tensor[Promoted[D, D2]] = Tensor(torchNative.bitwise_left_shift(input.native, other.native)) -/** Computes the right arithmetic s\hift of `input` by `other` bits. */ +/** Computes the right arithmetic s\hift of `input` by `other` bits. + * + * @group pointwise_ops + */ def bitwiseRightShift[D <: BitwiseNN, D2 <: BitwiseNN]( input: Tensor[D], other: Tensor[D2] @@ -147,6 +197,8 @@ def bitwiseRightShift[D <: BitwiseNN, D2 <: BitwiseNN]( /** Returns a new tensor with the ceil of the elements of `input`, the smallest integer greater than * or equal to each element. + * + * @group pointwise_ops */ def ceil[D <: NumericRealNN](input: Tensor[D]): Tensor[D] = Tensor(torchNative.ceil(input.native)) @@ -154,6 +206,8 @@ def ceil[D <: NumericRealNN](input: Tensor[D]): Tensor[D] = /** Clamps all elements in `input` into the range [ min, max ]. Letting min_value and max_value be * min and max, respectively, this returns: `min(max(input, min_value), max_value)` If min is None, * there is no lower bound. Or, if max is None there is no upper bound. + * + * @group pointwise_ops */ // TODO Support Tensor for min and max def clamp[D <: RealNN]( @@ -165,12 +219,16 @@ def clamp[D <: RealNN]( /** Computes the element-wise conjugate of the given `input` tensor. If input has a non-complex * dtype, this function just returns input. + * + * @group pointwise_ops */ def conjPhysical[D <: DType](input: Tensor[D]): Tensor[D] = Tensor(torchNative.conj_physical(input.native)) /** Create a new floating-point tensor with the magnitude of input and the sign of other, * elementwise. + * + * @group pointwise_ops */ def copysign[D <: RealNN, D2 <: RealNN]( input: Tensor[D], @@ -184,21 +242,32 @@ def copysign[D <: RealNN, D2 <: RealNN]( torchNative.copysign(input.native, toScalar(other)) ) -/** Returns a new tensor with the cosine of the elements of `input`. */ +/** Returns a new tensor with the cosine of the elements of `input`. + * + * @group pointwise_ops + */ def cos[D <: DType](input: Tensor[D]): Tensor[FloatPromoted[D]] = Tensor(torchNative.cos(input.native)) -/** Returns a new tensor with the hyperbolic cosine of the elements of `input`. */ +/** Returns a new tensor with the hyperbolic cosine of the elements of `input`. 
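+ *
+ * Example (an illustrative sketch; cosh(0) is exactly 1):
+ * ```scala sc
+ * torch.cosh(torch.Tensor(Seq(0.0)))
+ * // -> 1.0
+ * ```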
+ * + * @group pointwise_ops + */ def cosh[D <: DType](input: Tensor[D]): Tensor[FloatPromoted[D]] = Tensor(torchNative.cosh(input.native)) /** Returns a new tensor with each of the elements of `input` converted from angles in degrees to * radians. + * + * @group pointwise_ops */ def deg2rad[D <: RealNN](input: Tensor[D]): Tensor[FloatPromoted[D]] = Tensor(torchNative.deg2rad(input.native)) -/** Divides each element of the input `input` by the corresponding element of `other`. */ +/** Divides each element of the input `input` by the corresponding element of `other`. + * + * @group pointwise_ops + */ // TODO handle roundingMode def div[D <: DType, D2 <: DType]( input: Tensor[D], @@ -217,7 +286,10 @@ export torch.special.erf export torch.special.erfc export torch.special.erfinv -/** Returns a new tensor with the exponential of the elements of the `input` tensor `input`. */ +/** Returns a new tensor with the exponential of the elements of the `input` tensor `input`. + * + * @group pointwise_ops + */ def exp[D <: DType](input: Tensor[D]): Tensor[D] = Tensor(torchNative.exp(input.native)) @@ -226,6 +298,8 @@ export torch.special.expm1 /** Returns a new tensor with the data in `input` fake quantized per channel using `scale`, * `zero_point`, `quant_min` and `quant_max`, across the channel specified by `axis`. + * + * @group pointwise_ops */ def fakeQuantizePerChannelAffine( input: Tensor[Float32], @@ -248,6 +322,8 @@ def fakeQuantizePerChannelAffine( /** Returns a new tensor with the data in `input` fake quantized using `scale`, `zero_point`, * `quant_min` and `quant_max`. + * + * @group pointwise_ops */ def fakeQuantizePerTensorAffine( input: Tensor[Float32], @@ -279,6 +355,8 @@ def fakeQuantizePerTensorAffine( /** Returns a new tensor with the truncated integer values of the elements of `input`. Alias for * torch.trunc + * + * @group pointwise_ops */ def fix[D <: NumericRealNN](input: Tensor[D]): Tensor[D] = Tensor(torchNative.fix(input.native)) @@ -286,6 +364,8 @@ def fix[D <: NumericRealNN](input: Tensor[D]): Tensor[D] = /** Raises `input` to the power of `exponent`, elementwise, in double precision. If neither input is * complex returns a `torch.float64` tensor, and if one or more inputs is complex returns a * `torch.complex128` tensor. + * + * @group pointwise_ops */ def floatPower[D <: DType, D2 <: DType]( input: Tensor[D], @@ -307,11 +387,16 @@ def floatPower[D <: DType, S <: ScalaType]( /** Returns a new tensor with the floor of the elements of `input`, the largest integer less than or * equal to each element. + * + * @group pointwise_ops */ def floor[D <: NumericRealNN](input: Tensor[D]): Tensor[D] = Tensor(torchNative.floor(input.native)) -/** Computes `input` divided by `other`, elementwise, and floors the result. */ +/** Computes `input` divided by `other`, elementwise, and floors the result. + * + * @group pointwise_ops + */ def floorDivide[D <: RealNN, D2 <: RealNN]( input: Tensor[D], other: Tensor[D2] @@ -326,6 +411,8 @@ def floorDivide[D <: RealNN, R <: Real]( /** Applies C++’s `std::fmod` entrywise. The result has the same sign as the dividend `input` and * its absolute value is less than that of `other`. + * + * @group pointwise_ops */ // NOTE: When the divisor is zero, returns NaN for floating point dtypes on both CPU and GPU; raises RuntimeError for integer division by zero on CPU; Integer division by zero on GPU may return any value. 
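+// Illustrative usage (a sketch, following the std::fmod semantics described above):
+//   torch.fmod(torch.Tensor(Seq(-3.0, 3.0)), 2.0)  // -> [-1.0, 1.0]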
def fmod[D <: RealNN, D2 <: RealNN]( @@ -340,12 +427,17 @@ def fmod[D <: RealNN, S <: ScalaType]( )(using OnlyOneBool[D, ScalaToDType[S]]): Tensor[Promoted[D, ScalaToDType[S]]] = Tensor(torchNative.fmod(input.native, toScalar(other))) -/** Computes the fractional portion of each element in `input`. */ +/** Computes the fractional portion of each element in `input`. + * + * @group pointwise_ops + */ def frac[D <: FloatNN](input: Tensor[D]): Tensor[D] = Tensor(torchNative.frac(input.native)) /** Decomposes `input` into `mantissa` and `exponent` tensors such that `input = mantissa * (2 ** * exponent)` The range of mantissa is the open interval (-1, 1). + * + * @group pointwise_ops */ def frexp[D <: FloatNN](input: Tensor[D]): (Tensor[FloatPromoted[D]], Tensor[Int32]) = val nativeTuple = torchNative.frexp(input.native) @@ -353,6 +445,8 @@ def frexp[D <: FloatNN](input: Tensor[D]): (Tensor[FloatPromoted[D]], Tensor[Int /** Estimates the gradient of a function g:Rn → R in one or more dimensions using the second-order * accurate central differences method. + * + * @group pointwise_ops */ def gradient[D <: Int8 | Int16 | Int32 | Int64 | FloatNN | ComplexNN]( input: Tensor[D], @@ -367,17 +461,24 @@ def gradient[D <: Int8 | Int16 | Int32 | Int64 | FloatNN | ComplexNN]( /** Returns a new tensor containing imaginary values of the `input` tensor. The returned tensor and * `input` share the same underlying storage. + * + * @group pointwise_ops */ def imag[D <: ComplexNN](input: Tensor[D]): Tensor[ComplexToReal[D]] = Tensor(torchNative.imag(input.native)) -/** Multiplies `input` by 2 ** `other`. */ +/** Multiplies `input` by 2 ** `other`. + * + * @group pointwise_ops + */ def ldexp[D <: DType](input: Tensor[D], other: Tensor[D]): Tensor[D] = Tensor(torchNative.ldexp(input.native, other.native)) /** Does a linear interpolation of two tensors `start` (given by `input`) and `end` (given by * `other`) based on a scalar or tensor weight and returns the resulting out tensor. out = start + * weight × (end − start) + * + * @group pointwise_ops */ def lerp[D <: DType]( input: Tensor[D], @@ -391,23 +492,38 @@ def lerp[D <: DType]( case weight: Double => torchNative.lerp(input.native, other.native, toScalar(weight)) ) -/** Computes the natural logarithm of the absolute value of the gamma function on `input`. */ +/** Computes the natural logarithm of the absolute value of the gamma function on `input`. + * + * @group pointwise_ops + */ def lgamma[D <: RealNN](input: Tensor[D]): Tensor[D] = Tensor(torchNative.lgamma(input.native)) -/** Returns a new tensor with the natural logarithm of the elements of `input`. */ +/** Returns a new tensor with the natural logarithm of the elements of `input`. + * + * @group pointwise_ops + */ def log[D <: DType](input: Tensor[D]): Tensor[FloatPromoted[D]] = Tensor(torchNative.log(input.native)) -/** Returns a new tensor with the logarithm to the base 10 of the elements of `input`. */ +/** Returns a new tensor with the logarithm to the base 10 of the elements of `input`. + * + * @group pointwise_ops + */ def log10[D <: DType](input: Tensor[D]): Tensor[FloatPromoted[D]] = Tensor(torchNative.log10(input.native)) -/** Returns a new tensor with the natural logarithm of (1 + input). */ +/** Returns a new tensor with the natural logarithm of (1 + input). + * + * @group pointwise_ops + */ def log1p[D <: DType](input: Tensor[D]): Tensor[FloatPromoted[D]] = Tensor(torchNative.log1p(input.native)) -/** Returns a new tensor with the logarithm to the base 2 of the elements of `input`. 
*/ +/** Returns a new tensor with the logarithm to the base 2 of the elements of `input`. + * + * @group pointwise_ops + */ def log2[D <: DType](input: Tensor[D]): Tensor[FloatPromoted[D]] = Tensor(torchNative.log2(input.native)) @@ -417,6 +533,8 @@ def log2[D <: DType](input: Tensor[D]): Tensor[FloatPromoted[D]] = * of the calculated probability is stored. This function allows adding probabilities stored in * such a fashion. This op should be disambiguated with `torch.logsumexp()` which performs a * reduction on a single tensor. + * + * @group pointwise_ops */ def logaddexp[D <: RealNN, D2 <: RealNN]( input: Tensor[D], @@ -426,6 +544,8 @@ def logaddexp[D <: RealNN, D2 <: RealNN]( /** Logarithm of the sum of exponentiations of the inputs in base-2. Calculates pointwise `log2(2**x * + 2**y)`. See torch.logaddexp() for more details. + * + * @group pointwise_ops */ def logaddexp2[D <: RealNN, D2 <: RealNN]( input: Tensor[D], @@ -435,6 +555,8 @@ def logaddexp2[D <: RealNN, D2 <: RealNN]( /** Computes the element-wise logical AND of the given `input` tensors. Zeros are treated as False * and nonzeros are treated as True. + * + * @group pointwise_ops */ def logicalAnd[D <: DType, D2 <: DType](input: Tensor[D], other: Tensor[D2]): Tensor[Bool] = Tensor(torchNative.logical_and(input.native, other.native)) @@ -443,25 +565,34 @@ def logicalAnd[D <: DType, D2 <: DType](input: Tensor[D], other: Tensor[D2]): Te * a bool tensor, zeros are treated as False and non-zeros are treated as True. * * TODO If not specified, the output tensor will have the bool dtype. + * + * @group pointwise_ops */ def logicalNot[D <: DType](input: Tensor[D]): Tensor[Bool] = Tensor(torchNative.logical_not(input.native)) /** Computes the element-wise logical OR of the given `input` tensors. Zeros are treated as False * and nonzeros are treated as True. + * + * @group pointwise_ops */ def logicalOr[D <: DType, D2 <: DType](input: Tensor[D], other: Tensor[D2]): Tensor[Bool] = Tensor(torchNative.logical_or(input.native, other.native)) /** Computes the element-wise logical XOR of the given `input` tensors. Zeros are treated as False * and nonzeros are treated as True. + * + * @group pointwise_ops */ def logicalXor[D <: DType, D2 <: DType](input: Tensor[D], other: Tensor[D2]): Tensor[Bool] = Tensor(torchNative.logical_xor(input.native, other.native)) export torch.special.logit -/** Given the legs of a right triangle, return its hypotenuse. */ +/** Given the legs of a right triangle, return its hypotenuse. + * + * @group pointwise_ops + */ // TODO Change `D2 <: RealNN` once we fix property testing compilation def hypot[D <: RealNN, D2 <: FloatNN]( input: Tensor[D], @@ -473,7 +604,10 @@ export torch.special.i0 export torch.special.igamma export torch.special.igammac -/** Multiplies input by other. */ +/** Multiplies input by other. + * + * @group pointwise_ops + */ def mul[D <: DType, D2 <: DType](input: Tensor[D], other: Tensor[D2]): Tensor[Promoted[D, D2]] = Tensor(torchNative.mul(input.native, other.native)) @@ -483,6 +617,8 @@ export torch.special.mvlgamma * specified by nan, posinf, and neginf, respectively. By default, NaNs are replaced with zero, * positive infinity is replaced with the greatest finite value representable by input’s dtype, and * negative infinity is replaced with the least finite value representable by input’s dtype. 
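+ *
+ * Example (an illustrative sketch of the default replacements for a float64 tensor):
+ * ```scala sc
+ * torch.nanToNum(torch.Tensor(Seq(Double.NaN, Double.PositiveInfinity, 1.0)))
+ * // -> [0.0, 1.7976931348623157E308, 1.0]
+ * ```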
+ * + * @group pointwise_ops */ def nanToNum[D <: RealNN]( input: Tensor[D], @@ -494,11 +630,17 @@ def nanToNum[D <: RealNN]( torchNative.nan_to_num(input.native, toOptional(nan), toOptional(posinf), toOptional(neginf)) ) -/** Returns a new tensor with the negative of the elements of `input`. */ +/** Returns a new tensor with the negative of the elements of `input`. + * + * @group pointwise_ops + */ def neg[D <: NumericNN](input: Tensor[D]): Tensor[D] = Tensor(torchNative.neg(input.native)) -/** Return the next floating-point value after `input` towards `other`, elementwise. */ +/** Return the next floating-point value after `input` towards `other`, elementwise. + * + * @group pointwise_ops + */ // TODO Change `D2 <: RealNN` once we fix property testing compilation def nextafter[D <: RealNN, D2 <: FloatNN]( input: Tensor[D], @@ -508,13 +650,18 @@ def nextafter[D <: RealNN, D2 <: FloatNN]( export torch.special.polygamma -/** Returns input. Normally throws a runtime error if input is a bool tensor in pytorch. */ +/** Returns input. Normally throws a runtime error if input is a bool tensor in pytorch. + * + * @group pointwise_ops + */ def positive[D <: NumericNN](input: Tensor[D]): Tensor[D] = Tensor(torchNative.positive(input.native)) /** Takes the power of each element in `input` with exponent and returns a tensor with the result. * `exponent` can be either a single float number or a Tensor with the same number of elements as * input. + * + * @group pointwise_ops */ def pow[D <: DType, D2 <: DType]( input: Tensor[D], @@ -541,22 +688,31 @@ def pow[S <: ScalaType, D <: DType]( /** Returns a new tensor with each of the elements of `input` converted from angles in radians to * degrees. + * + * @group pointwise_ops */ def rad2Deg[D <: RealNN | Bool](input: Tensor[D]): Tensor[FloatPromoted[D]] = Tensor(torchNative.rad2deg(input.native)) /** Returns a new tensor containing real values of the self tensor. The returned tensor and self * share the same underlying storage. + * + * @group pointwise_ops */ def real[D <: DType](input: Tensor[D]): Tensor[ComplexToReal[D]] = Tensor(torchNative.real(input.native)) -/** Returns a new tensor with the reciprocal of the elements of `input` */ +/** Returns a new tensor with the reciprocal of the elements of `input` + * + * @group pointwise_ops + */ def reciprocal[D <: DType](input: Tensor[D]): Tensor[FloatPromoted[D]] = Tensor(torchNative.reciprocal(input.native)) /** Computes Python’s modulus operation entrywise. The result has the same sign as the divisor * `other` and its absolute value is less than that of `other`. + * + * @group pointwise_ops */ def remainder[D <: RealNN, D2 <: RealNN]( input: Tensor[D], @@ -578,18 +734,25 @@ def remainder[D <: DType, R <: Real]( /** Rounds elements of `input` to the nearest integer. If decimals is negative, it specifies the * number of positions to the left of the decimal point. + * + * @group pointwise_ops */ def round[D <: FloatNN](input: Tensor[D], decimals: Long = 0): Tensor[D] = Tensor(torchNative.round(input.native, decimals)) /** Returns a new tensor with the reciprocal of the square-root of each of the elements of `input`. + * + * @group pointwise_ops */ def rsqrt[D <: DType](input: Tensor[D]): Tensor[FloatPromoted[D]] = Tensor(torchNative.rsqrt(input.native)) export torch.special.sigmoid -/** Returns a new tensor with the signs of the elements of `input`. */ +/** Returns a new tensor with the signs of the elements of `input`. 
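+ *
+ * Example (illustrative):
+ * ```scala sc
+ * torch.sign(torch.Tensor(Seq(-2.0, 0.0, 3.5)))
+ * // -> [-1.0, 0.0, 1.0]
+ * ```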
+ * + * @group pointwise_ops + */ def sign[D <: RealNN](input: Tensor[D]): Tensor[D] = Tensor(torchNative.sign(input.native)) @@ -597,35 +760,55 @@ def sign[D <: RealNN](input: Tensor[D]): Tensor[D] = * whose elements have the same angles as the corresponding elements of `input` and absolute values * (i.e. magnitudes) of one for complex tensors and is equivalent to torch.sign() for non-complex * tensors. + * + * @group pointwise_ops */ def sgn[D <: DType](input: Tensor[D]): Tensor[D] = Tensor(torchNative.sgn(input.native)) -/** Tests if each element of `input`` has its sign bit set or not. */ +/** Tests if each element of `input`` has its sign bit set or not. + * + * @group pointwise_ops + */ def signbit[D <: RealNN](input: Tensor[D]): Tensor[Bool] = Tensor(torchNative.signbit(input.native)) -/** Returns a new tensor with the sine of the elements of `input`. */ +/** Returns a new tensor with the sine of the elements of `input`. + * + * @group pointwise_ops + */ def sin[D <: DType](input: Tensor[D]): Tensor[FloatPromoted[D]] = Tensor(torchNative.sin(input.native)) export torch.special.sinc -/** Returns a new tensor with the hyperbolic sine of the elements of `input`. */ +/** Returns a new tensor with the hyperbolic sine of the elements of `input`. + * + * @group pointwise_ops + */ def sinh[D <: DType](input: Tensor[D]): Tensor[FloatPromoted[D]] = Tensor(torchNative.sinh(input.native)) export torch.nn.functional.softmax -/** Returns a new tensor with the square-root of the elements of `input`. */ +/** Returns a new tensor with the square-root of the elements of `input`. + * + * @group pointwise_ops + */ def sqrt[D <: DType](input: Tensor[D]): Tensor[FloatPromoted[D]] = Tensor(torchNative.sqrt(input.native)) -/** Returns a new tensor with the square of the elements of `input`. */ +/** Returns a new tensor with the square of the elements of `input`. + * + * @group pointwise_ops + */ def square[D <: DType](input: Tensor[D]): Tensor[NumericPromoted[D]] = Tensor(torchNative.square(input.native)) -/** Subtracts `other`, scaled by `alpha`, from `input`. */ +/** Subtracts `other`, scaled by `alpha`, from `input`. + * + * @group pointwise_ops + */ def sub[D <: NumericNN, D2 <: NumericNN]( input: Tensor[D], other: Tensor[D2] @@ -646,15 +829,24 @@ def sub[D <: NumericNN, D2 <: NumericNN]( ): Tensor[Promoted[D, D2]] = Tensor(torchNative.sub(input.native, toScalar(other), toScalar(alpha))) -/** Returns a new tensor with the tangent of the elements of `input`. */ +/** Returns a new tensor with the tangent of the elements of `input`. + * + * @group pointwise_ops + */ def tan[D <: DType](input: Tensor[D]): Tensor[FloatPromoted[D]] = Tensor(torchNative.tan(input.native)) -/** Returns a new tensor with the hyperbolic tangent of the elements of `input`. */ +/** Returns a new tensor with the hyperbolic tangent of the elements of `input`. + * + * @group pointwise_ops + */ def tanh[D <: DType](input: Tensor[D]): Tensor[FloatPromoted[D]] = Tensor(torchNative.tanh(input.native)) -/** Alias for `torch.div()` with `rounding_mode=None` */ +/** Alias for `torch.div()` with `rounding_mode=None` + * + * @group pointwise_ops + */ def trueDivide[D <: DType, D2 <: DType]( input: Tensor[D], other: Tensor[D2] @@ -667,7 +859,10 @@ def trueDivide[D <: DType, S <: ScalaType]( ): Tensor[FloatPromoted[Promoted[D, ScalaToDType[S]]]] = Tensor(torchNative.true_divide(input.native, toScalar(other))) -/** Returns a new tensor with the truncated integer values of the elements of `input`. 
*/ +/** Returns a new tensor with the truncated integer values of the elements of `input` + * + * @group pointwise_ops + */ def trunc[D <: NumericRealNN](input: Tensor[D]): Tensor[D] = Tensor(torchNative.trunc(input.native)) diff --git a/core/src/main/scala/torch/ops/ReductionOps.scala b/core/src/main/scala/torch/ops/ReductionOps.scala index 5f9fac14..7bb4f97f 100644 --- a/core/src/main/scala/torch/ops/ReductionOps.scala +++ b/core/src/main/scala/torch/ops/ReductionOps.scala @@ -14,60 +14,861 @@ * limitations under the License. */ +/** TODO figure out how to get these defines working. They work if defined on a class or object, but + * for some reason not on a package directly (they need to be in the same source file though). + * + * @define single_keepdim_details + * If `keepdim` is `true`, the output tensor is of the same size as `input` except in the + * dimension `dim` where it is of size 1. Otherwise, `dim` is squeezed (see `torch.squeeze`), + * resulting in the output tensor having 1 fewer dimension than `input`. + * + * @define multi_keepdim_details + * If `keepdim` is `true`, the output tensor is of the same size as `input` except in the + * dimension(s) `dim` where it is of size 1. Otherwise, `dim` is squeezed (see `torch.squeeze`), + * resulting in the output tensor having 1 (or `len(dim)`) fewer dimension(s). + * + * @define reduceops_dtype + * the desired data type of returned tensor. If specified, the input tensor is casted to `dtype` + * before the operation is performed. This is useful for preventing data type overflows. + */ package torch import internal.NativeConverters.* import org.bytedeco.pytorch.global.torch as torchNative +import org.bytedeco.pytorch.LongArrayRef +import org.bytedeco.pytorch.ScalarTypeOptional /** Reduction Ops * * https://pytorch.org/docs/stable/torch.html#reduction-ops */ -// TODO argmax Returns the indices of the maximum value of all elements in the `input` tensor. -// TODO argmin Returns the indices of the minimum value(s) of the flattened tensor or along a dimension -// TODO amax Returns the maximum value of each slice of the `input` tensor in the given dimension(s) dim. -// TODO amin Returns the minimum value of each slice of the `input` tensor in the given dimension(s) dim. -// TODO aminmax Computes the minimum and maximum values of the `input` tensor. -// TODO all Tests if all elements in `input` evaluate to True. -// TODO any Tests if any element in `input` evaluates to True. -// TODO max Returns the maximum value of all elements in the `input` tensor. -// TODO min Returns the minimum value of all elements in the `input` tensor. -// TODO dist Returns the p-norm of (input - other) -// TODO logsumexp Returns the log of summed exponentials of each row of the `input` tensor in the given dimension dim. -// TODO mean Returns the mean value of all elements in the `input` tensor. -// TODO nanmean Computes the mean of all non-NaN elements along the specified dimensions. -// TODO median Returns the median of the values in input. -// TODO nanmedian Returns the median of the values in input, ignoring NaN values. -// TODO mode Returns a namedtuple (values, indices) where values is the mode value of each row of the `input` tensor in the given dimension dim, i.e. a value which appears most often in that row, and indices is the index location of each mode value found. -// TODO norm Returns the matrix norm or vector norm of a given tensor. -// TODO nansum Returns the sum of all elements, treating Not a Numbers (NaNs) as zero. 
-// TODO prod Returns the product of all elements in the `input` tensor. -// TODO quantile Computes the q-th quantiles of each row of the `input` tensor along the dimension dim. +/** Returns the indices of the maximum value of all elements in the tensor. + * + * This is the second value returned by torch.max(). See its documentation for the exact semantics + * of this method. + * + * Example: + * ```scala sc + * val a = torch.rand(Seq(1, 3)) + * torch.argmax(a) + * // tensor dtype=int64, shape=[], device=CPU + * // 1 + * ``` + * + * @group reduction_ops + * + * @param dim + * the dimension to reduce. If [[None]], the argmin of the flattened input is returned. + * @param keepdim + * whether the output tensor has `dim` retained or not. + * @return + */ +def argmax[D <: DType]( + input: Tensor[D], + dim: Long | Option[Long] = None, + keepdim: Boolean = false +): Tensor[Int64] = Tensor( + torchNative.argmax(input.native, dim.toOptional, keepdim) +) + +/** Returns the indices of the minimum value of all elements in the tensor. + * + * This is the second value returned by torch.min(). See its documentation for the exact semantics + * of this method. + * + * Example: + * ```scala sc + * val a = torch.rand(Seq(1, 3)) + * argmin(a) + * // tensor dtype=int64, shape=[], device=CPU + * // 1 + * ``` + * + * @group reduction_ops + * + * @param dim + * the dimension to reduce. If [[None]], the argmin of the flattened input is returned. + * @param keepdim + * whether the output tensor has `dim` retained or not. + * @return + */ +def argmin[D <: DType]( + input: Tensor[D], + dim: Long | Option[Long] = None, + keepdim: Boolean = false +): Tensor[Int64] = Tensor( + torchNative.argmin(input.native, dim.toOptional, keepdim) +) + +/** Returns the maximum value of each slice of the `input` tensor in the given dimension(s) `dim`. + * + * $multi_keepdim_details + * + * @group reduction_ops + * + * @param dim + * the dimension or dimensions to reduce. + * @param keepdim + * whether the output tensor has `dim` retained or not. + * @return + */ +def amax[D <: DType](input: Tensor[D], dim: Long | Seq[Long], keepdim: Boolean = false): Tensor[D] = + Tensor( + torchNative.amax(input.native, dim.toArray, keepdim) + ) + +/** Returns the minimum value of each slice of the `input` tensor in the given dimension(s) `dim`. + * + * $multi_keepdim_details + * + * @group reduction_ops + * + * @param dim + * the dimension or dimensions to reduce. + * @param keepdim + * whether the output tensor has `dim` retained or not. + * @return + */ +def amin[D <: DType](input: Tensor[D], dim: Long | Seq[Long], keepdim: Boolean = false): Tensor[D] = + Tensor( + torchNative.amin(input.native, dim.toArray, keepdim) + ) + +/** Computes the minimum and maximum values of the `input` tensor. + * + * @group reduction_ops + * + * @param dim + * The dimension along which to compute the values. If [[None]], computes the values over the + * entire input tensor. Default is [[None]]. + * @param keepdim + * whether the output tensor has `dim` retained or not. + * @return + */ +def aminmax[D <: DType]( + input: Tensor[D], + dim: Long | Option[Long] = None, + keepdim: Boolean = false +): (Tensor[D], Tensor[D]) = + val native = torchNative.aminmax(input.native, dim.toOptional, keepdim) + (Tensor(native.get0()), Tensor(native.get1())) + +/** Tests if all elements of this tensor evaluate to `true`. 
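+ *
+ * Example (an illustrative sketch, assuming a boolean sequence lifts to a `Bool` tensor):
+ * ```scala sc
+ * torch.all(torch.Tensor(Seq(true, true, false)))
+ * // -> false
+ * ```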
+ * + * @group reduction_ops + */ +def all[D <: DType](input: Tensor[D]): Tensor[Bool] = Tensor(torchNative.all(input.native)) + +/** For each row of `input` in the given dimension `dim`, returns `true` if all elements in the row + * evaluate to `true` and `false` otherwise. + * + * $single_keepdim_details + * + * @group reduction_ops + * + * @param dim + * the dimension to reduce. + * @param keepdim + * whether the output tensor has `dim` retained or not. + */ +def all[D <: DType](input: Tensor[D], dim: Long, keepdim: Boolean = false): Tensor[Bool] = Tensor( + torchNative.all(input.native, dim, keepdim) +) + +/** Tests if any elements of this tensor evaluate to `true`. + * + * @group reduction_ops + */ +def any[D <: DType](input: Tensor[D]): Tensor[Bool] = Tensor(torchNative.any(input.native)) + +/** For each row of `input` in the given dimension `dim`, returns `true` if any element in the row + * evaluates to `true` and `false` otherwise. + * + * $single_keepdim_details + * + * @group reduction_ops + * + * @param dim + * the dimension to reduce. + * @param keepdim + * whether the output tensor has `dim` retained or not. + */ +def any[D <: DType](input: Tensor[D], dim: Long, keepdim: Boolean = false): Tensor[Bool] = Tensor( + torchNative.any(input.native, dim, keepdim) +) + +/** Returns the maximum value of all elements in the `input` tensor. + * + * @group reduction_ops + */ +def max(input: Tensor[?]): Tensor[Int64] = Tensor(input.native.max()) + +/** Returns a [[TensorTuple]] `(values, indices)` where `values` is the maximum value of each row of + * the `input` tensor in the given dimension `dim`. And `indices` is the index location of each + * maximum value found (argmax). + * + * $single_keepdim_details + * + * @note + * If there are multiple maximal values in a reduced row then the indices of the first maximal + * value are returned. + * + * @group reduction_ops + * + * @param dim + * the dimension to reduce. + * @param keepdim + * whether the output tensor has `dim` retained or not. + */ +def max[D <: DType](input: Tensor[D], dim: Long, keepdim: Boolean = false): TensorTuple[D] = + val nativeTuple = torchNative.max(input.native, dim, keepdim) + TensorTuple(values = Tensor[D](nativeTuple.get0), indices = new Int64Tensor(nativeTuple.get1)) + +/** Returns the maximum value of all elements in the `input` tensor. + * + * @group reduction_ops + */ +def min(input: Tensor[?]): Tensor[Int64] = Tensor(input.native.min()) + +/** Returns a [[TensorTuple]] `(values, indices)` where `values` is the minimum value of each row of + * the `input` tensor in the given dimension `dim`. And `indices` is the index location of each + * maximum value found (argmax). + * + * $single_keepdim_details + * + * @note + * If there are multiple minimal values in a reduced row then the indices of the first minimal + * value are returned. + * + * @group reduction_ops + * + * @param dim + * the dimension to reduce. + * @param keepdim + * whether the output tensor has `dim` retained or not. + */ +def min[D <: DType](input: Tensor[D], dim: Long, keepdim: Boolean = false): TensorTuple[D] = + val nativeTuple = torchNative.min(input.native, dim, keepdim) + TensorTuple(values = Tensor[D](nativeTuple.get0), indices = new Int64Tensor(nativeTuple.get1)) + +/** Returns the p-norm of (`input` - `other`) + * + * The shapes of `input` and `other` must be broadcastable. 
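+ *
+ * Example (illustrative; with the default `p = 2` this is the Euclidean distance):
+ * ```scala sc
+ * torch.dist(torch.Tensor(Seq(1.0, 2.0)), torch.Tensor(Seq(1.0, 4.0)))
+ * // -> 2.0
+ * ```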
+ * + * @group reduction_ops + * + * @param p + * the norm to be computed + */ +// TODO dtype promotion floatNN/complexNN => highest floatNN +// TODO (using AtLeastOneFloat[D, D2] +def dist[D <: DType, D2 <: DType](input: Tensor[D], other: Tensor[D2], p: Float = 2): Tensor[D] = + Tensor(torchNative.dist(input.native, other.native, toScalar(p))) + +/** Returns the log of summed exponentials of each row of the `input` tensor in the given dimension + * `dim`. The computation is numerically stabilized. + * + * For summation index $j$ given by `dim` and other indices $i$, the result is + * + * > $$\text{{logsumexp}}(x)_{{i}} = \log \sum_j \exp(x_{{ij}})$$ + * + * $multi_keepdim_details + * + * @group reduction_ops + * + * @param dim + * the dimension or dimensions to reduce. If empty, all dimensions are reduced. + * @param keepdim + * whether the output tensor has `dim` retained or not. + */ +def logsumexp[D <: DType]( + input: Tensor[D], + dim: Int | Seq[Int] = Seq.empty, + keepdim: Boolean = false +): Tensor[D] = Tensor( + torchNative.logsumexp(input.native, dim.toArray, keepdim) +) + +/** Returns the mean value of all elements in the `input` tensor. + * + * @group reduction_ops + */ +def mean[D <: DType]( + input: Tensor[D] +): Tensor[D] = Tensor(torchNative.mean(input.native)) + +/** Returns the mean value of all elements in the `input` tensor. + * + * @group reduction_ops + * + * @param dtype + * $reduceops_dtype + */ +def mean[D <: DType]( + input: Tensor[?], + dtype: D +): Tensor[D] = Tensor(torchNative.mean(input.native, new ScalarTypeOptional(dtype.toScalarType))) + +/** Returns the mean value of each row of the `input` tensor in the given dimension `dim`. If `dim` + * is a list of dimensions, reduce over all of them. + * + * $multi_keepdim_details + * + * @group reduction_ops + * + * @param dim + * the dimension or dimensions to reduce. If empty, all dimensions are reduced. + * @param keepdim + * whether the output tensor has `dim` retained or not. + * @param dtype + * $reduceops_dtype + */ +def mean[D <: DType, D2 <: DType | Derive]( + input: Tensor[D], + dim: Int | Seq[Int] = Seq.empty, + keepdim: Boolean = false, + dtype: D2 = derive +): Tensor[DTypeOrDeriveFromTensor[D, D2]] = + // TODO factor out + val derivedDType = dtype match + case _: Derive => input.dtype + case d: DType => d + Tensor( + torchNative.mean( + input.native, + dim.toArray, + keepdim, + new ScalarTypeOptional(derivedDType.toScalarType) + ) + ) + +/** Computes the mean of all [non-NaN] elements along the specified dimensions. + * + * This function is identical to `torch.mean` when there are no [NaN] values in the `input` tensor. + * In the presence of [NaN], `torch.mean` will propagate the [NaN] to the output whereas + * `torch.nanmean` will ignore the [NaN] values ([torch.nanmean(a)] is equivalent to + * [torch.mean(a\[\~a.isnan()\])]). + * + * $multi_keepdim_details + * + * @group reduction_ops + * + * @param dim + * the dimension or dimensions to reduce. If empty, all dimensions are reduced. + * @param keepdim + * whether the output tensor has `dim` retained or not. 
+ * @param dtype + * $reduceops_dtype + */ +def nanmean[D <: DType, D2 <: DType | Derive]( + input: Tensor[D], + dim: Int | Seq[Int] = Seq.empty, + keepdim: Boolean = false, + dtype: D2 = derive +): Tensor[DTypeOrDeriveFromTensor[D, D2]] = + // TODO factor out + val derivedDType = dtype match + case _: Derive => input.dtype + case d: DType => d + Tensor( + torchNative.nanmean( + input.native, + dim.toArray, + keepdim, + new ScalarTypeOptional(derivedDType.toScalarType) + ) + ) + + /** Returns the median of the values in `input`. + * + * @note + * The median is not unique for `input` tensors with an even number of elements. In this case + * the lower of the two medians is returned. To compute the mean of both medians, use + * `torch.quantile` with `q=0.5` instead. + * + * Warning + * + * This function produces deterministic (sub)gradients unlike `median(dim=0)` + * + * @group reduction_ops + */ +def median[D <: DType]( + input: Tensor[D] +): Tensor[D] = Tensor(torchNative.median(input.native)) + +/** Returns a [[TensorTuple]] `(values, indices)` where `values` contains the median of each row of + * `input` in the dimension `dim`, and `indices` contains the index of the median values found in + * the dimension `dim`. + * + * By default, `dim` is the last dimension of the `input` tensor. + * + * $single_keepdim_details + * + * @note + * The median is not unique for `input` tensors with an even number of elements in the dimension + * `dim`. In this case the lower of the two medians is returned. To compute the mean of both + * medians in `input`, use `torch.quantile` with `q=0.5` instead. + * + * Warning + * + * `indices` does not necessarily contain the first occurrence of each median value found, unless + * it is unique. The exact implementation details are device-specific. Do not expect the same + * result when run on CPU and GPU in general. For the same reason do not expect the gradients to be + * deterministic. + * + * @group reduction_ops + * + * @param dim + * the dimension to reduce. + * @param keepdim + * whether the output tensor has `dim` retained or not. + * @param dtype + * $reduceops_dtype + */ +def median[D <: DType, D2 <: DType | Derive]( + input: Tensor[D], + dim: Long = -1, + keepdim: Boolean = false +): TensorTuple[D] = + val nativeTuple = torchNative.median(input.native, dim, keepdim) + TensorTuple(values = Tensor[D](nativeTuple.get0), indices = new Int64Tensor(nativeTuple.get1)) + + /** Returns the median of the values in `input`, ignoring `NaN` values. + * + * This function is identical to `torch.median` when there are no `NaN` values in `input`. When + * `input` has one or more `NaN` values, `torch.median` will always return `NaN`, while this + * function will return the median of the non-`NaN` elements in `input`. If all the elements in + * `input` are `NaN` it will also return `NaN`. + * + * @group reduction_ops + */ +def nanmedian[D <: DType]( + input: Tensor[D] +): Tensor[D] = Tensor(torchNative.nanmedian(input.native)) + +/** Returns a [[TensorTuple]] ``(values, indices)`` where ``values`` contains the median of each row + * of `input` in the dimension `dim`, ignoring ``NaN`` values, and ``indices`` contains the index + * of the median values found in the dimension `dim`. + * + * This function is identical to :func:`torch.median` when there are no ``NaN`` values in a reduced + * row. 
When a reduced row has one or more ``NaN`` values, :func:`torch.median` will always reduce + * it to ``NaN``, while this function will reduce it to the median of the non-``NaN`` elements. If + * all the elements in a reduced row are ``NaN`` then it will be reduced to ``NaN``, too. + * + * @group reduction_ops + * + * @param dim + * the dimension or dimensions to reduce. If empty, all dimensions are reduced. + * @param keepdim + * whether the output tensor has `dim` retained or not. + * @param dtype + * $reduceops_dtype + */ +def nanmedian[D <: DType, D2 <: DType | Derive]( + input: Tensor[D], + dim: Long = -1, + keepdim: Boolean = false +): TensorTuple[D] = + val nativeTuple = torchNative.nanmedian(input.native, dim, keepdim) + TensorTuple(values = Tensor[D](nativeTuple.get0), indices = new Int64Tensor(nativeTuple.get1)) + +/** Returns a [[TensorTuple]] `(values, indices)` where `values` is the mode value of each row of + * the `input` tensor in the given dimension `dim`, + * i.e. a value which appears most often in that row, and `indices` is the index location of each + * mode value found. + * + * By default, `dim` is the last dimension of the `input` tensor. + * + * $single_keepdim_details + * + * @note + * This function is not defined for `torch.cuda.Tensor` yet. + * + * @group reduction_ops + * + * @param dim + * the dimension to reduce. + * @param keepdim + * whether the output tensor has `dim` retained or not. + */ +def mode[D <: DType](input: Tensor[D], dim: Long, keepdim: Boolean = false): TensorTuple[D] = + val nativeTuple = torchNative.mode(input.native, dim, keepdim) + TensorTuple(values = Tensor[D](nativeTuple.get0), indices = new Int64Tensor(nativeTuple.get1)) + +/** Returns the sum of each row of the `input` tensor in the given dimension `dim`, treating Not a + * Numbers (NaNs) as zero. If `dim` is a list of dimensions, reduce over all of them. + * + * $multi_keepdim_details + * + * @group reduction_ops + * + * @param dim + * the dimension or dimensions to reduce. If empty, all dimensions are reduced. + * @param keepdim + * whether the output tensor has `dim` retained or not. + * @param dtype + * $reduceops_dtype + */ +def nansum[D <: DType, D2 <: DType | Derive]( + input: Tensor[D], + dim: Int | Seq[Int] = Seq.empty, + keepdim: Boolean = false, + dtype: D2 = derive +): Tensor[DTypeOrDeriveFromTensor[D, D2]] = + // TODO factor out + val derivedDType = dtype match + case _: Derive => input.dtype + case d: DType => d + Tensor( + torchNative.nansum( + input.native, + dim.toArray, + keepdim, + new ScalarTypeOptional(derivedDType.toScalarType) + ) + ) + +/** Returns the product of all elements in the `input` tensor. + * + * @group reduction_ops + */ +def prod[D <: DType, D2 <: DType | Derive]( + input: Tensor[D] +): Tensor[D] = Tensor(torchNative.prod(input.native)) + +/** Returns the product of all elements in the `input` tensor. + * + * @group reduction_ops + * + * @param dtype + * $reduceops_dtype + */ +def prod[D <: DType]( + input: Tensor[?], + dtype: D +): Tensor[D] = Tensor(torchNative.prod(input.native, new ScalarTypeOptional(dtype.toScalarType))) + +/** Returns the product of each row of the `input` tensor in the given dimension `dim`. + * + * $single_keepdim_details + * + * @group reduction_ops + * + * @param dim + * the dimension to reduce. + * @param keepdim + * whether the output tensor has `dim` retained or not. 
+ * @param dtype + * $reduceops_dtype + */ +def prod[D <: DType, D2 <: DType | Derive]( + input: Tensor[D], + dim: Long, + keepdim: Boolean = false, + dtype: D2 = derive +): Tensor[DTypeOrDeriveFromTensor[D, D2]] = + // TODO factor out + val derivedDType = dtype match + case _: Derive => input.dtype + case d: DType => d + Tensor( + torchNative.prod( + input.native, + dim, + keepdim, + new ScalarTypeOptional(derivedDType.toScalarType) + ) + ) + + /** Computes the q-th quantiles of each row of the `input` tensor along the dimension `dim`. + * + * To compute the quantile, we map q in \[0, 1\] to the range of indices \[0, n\] to find the + * location of the quantile in the sorted input. If the quantile lies between two data points `a + * < b` with indices `i` and `j` in the sorted order, result is computed according to the given + * `interpolation` method as follows: + * + * - `linear`: `a + (b - a) * fraction`, where `fraction` is the fractional part of the + * computed quantile index. + * - `lower`: `a`. + * - `higher`: `b`. + * - `nearest`: `a` or `b`, whichever\'s index is closer to the computed quantile index + * (rounding down for .5 fractions). + * - `midpoint`: `(a + b) / 2`. + * + * If `q` is a 1D tensor, the first dimension of the output represents the quantiles and has size + * equal to the size of `q`, the remaining dimensions are what remains from the reduction. + * + * @note + * By default `dim` is `None` resulting in the `input` tensor being flattened before + * computation. + * + * @group reduction_ops + * + * @param q + * (float or Tensor): a scalar or 1D tensor of values in the range [0, 1]. + * @param dim + * the dimension or dimensions to reduce. If empty, all dimensions are reduced. + * @param keepdim + * whether the output tensor has `dim` retained or not. + * @param interpolation + * interpolation method to use when the desired quantile lies between two data points. Can be + * ``linear``, ``lower``, ``higher``, ``midpoint`` and ``nearest``. Default is ``linear``. + */ +// def quantile[D <: DType, D2 <: DType | Derive]( +// input: Tensor[D], +// q: Double | Tensor[?], // TODO only float tensor? +// dim: Option[Long] = None, +// keepdim: Boolean = false, + +// ): Tensor[DTypeOrDeriveFromTensor[D, D2]] = +// Tensor( +// torchNative.quantile( +// input.native, +// q, +// dim.toOptional, +// keepdim, +// // TODO figure out how to create c10:string_view values +// ) +// ) + // TODO nanquantile This is a variant of torch.quantile() that "ignores" NaN values, computing the quantiles q as if NaN values in `input` did not exist. -// TODO std Calculates the standard deviation over the dimensions specified by dim. -// TODO std_mean Calculates the standard deviation and mean over the dimensions specified by dim. +// (same issue as quantile, need to figure out how to create c10:string_view values) -/* Returns the sum of all elements in the `input` tensor. */ -def sum[D <: DType]( +/** Calculates the standard deviation over the dimensions specified by `dim`. `dim` can be a single + * dimension, list of dimensions, or `None` to reduce over all dimensions. + * + * The standard deviation ($\sigma$) is calculated as + * + * $$\sigma = \sqrt{\frac{1}{N - \delta N}\sum_{i=0}^{N-1}(x_i-\bar{x})^2}$$ + * + * where $x$ is the sample set of elements, $\bar{x}$ is the sample mean, $N$ is the number of + * samples and $\delta N$ is the `correction`. + * + * $multi_keepdim_details + * + * @group reduction_ops + * + * @param dim + * the dimension or dimensions to reduce. 
If empty, all dimensions are reduced. + * @param keepdim + * whether the output tensor has `dim` retained or not. + * @param correction + * difference between the sample size and sample degrees of freedom. Defaults to [Bessel\'s + * correction](https://en.wikipedia.org/wiki/Bessel%27s_correction), `correction=1`. + */ +def std[D <: DType]( input: Tensor[D], - dim: Array[Long] = Array(), + dim: Int | Seq[Int] = Seq.empty, keepdim: Boolean = false, - dtype: Option[DType] = None + correction: Long = 1 ): Tensor[D] = - val lar = new org.bytedeco.pytorch.LongArrayRef(dim, dim.size) - val laro = new org.bytedeco.pytorch.LongArrayRefOptional(lar) - // TODO Add dtype - val sto = new org.bytedeco.pytorch.ScalarTypeOptional() - Tensor(torchNative.sum(input.native, dim, keepdim, sto)) + Tensor( + torchNative.std( + input.native, + dim.toArray, + correction.toOptional, + keepdim + ) + ) + +/** Calculates the standard deviation and mean over the dimensions specified by `dim`. `dim` can be + * a single dimension, list of dimensions, or `None` to reduce over all dimensions. + * + * The standard deviation ($\sigma$) is calculated as + * + * $$\sigma = \sqrt{\frac{1}{N - \delta N}\sum_{i=0}^{N-1}(x_i-\bar{x})^2}$$ + * + * where $x$ is the sample set of elements, $\bar{x}$ is the sample mean, $N$ is the number of + * samples and $\delta N$ is the `correction`. + * + * $multi_keepdim_details + * + * @group reduction_ops + * + * @param dim + * the dimension or dimensions to reduce. If empty, all dimensions are reduced. + * @param keepdim + * whether the output tensor has `dim` retained or not. + * @param correction + * difference between the sample size and sample degrees of freedom. Defaults to [Bessel\'s + * correction](https://en.wikipedia.org/wiki/Bessel%27s_correction), `correction=1`. + * @return + * A tuple (std, mean) containing the standard deviation and mean. + */ +def std_mean[D <: DType]( + input: Tensor[D], + dim: Int | Seq[Int] = Seq.empty, + keepdim: Boolean = false, + correction: Long = 1 +): (Tensor[D], Tensor[D]) = + val nativeTuple = + torchNative.std_mean( + input.native, + dim.toArray, + correction.toOptional, + keepdim + ) + (Tensor[D](nativeTuple.get0), Tensor[D](nativeTuple.get1)) + +/** Returns the sum of all elements in the `input` tensor. + * + * @group reduction_ops + */ +def sum[D <: DType, D2 <: DType | Derive]( + input: Tensor[D] +): Tensor[D] = Tensor(torchNative.sum(input.native)) + +/** Returns the sum of all elements in the `input` tensor. + * + * @group reduction_ops + * + * @param dtype + * $reduceops_dtype + */ +def sum[D <: DType]( + input: Tensor[?], + dtype: D +): Tensor[D] = Tensor(torchNative.sum(input.native, new ScalarTypeOptional(dtype.toScalarType))) + +/** Returns the sum of each row of the `input` tensor in the given dimension `dim`. + * + * If dim is a list of dimensions, reduce over all of them. + * + * $multi_keepdim_details + * + * @group reduction_ops + * + * @param dim + * the dimension or dimensions to reduce. If empty, all dimensions are reduced. + * @param keepdim + * whether the output tensor has `dim` retained or not. 
+ * @param dtype + * $reduceops_dtype + */ +def sum[D <: DType, D2 <: DType | Derive]( + input: Tensor[D], + dim: Int | Seq[Int] = Seq.empty, + keepdim: Boolean = false, + dtype: D2 = derive +): Tensor[DTypeOrDeriveFromTensor[D, D2]] = + // TODO factor out + val derivedDType = dtype match + case _: Derive => input.dtype + case d: DType => d + Tensor( + torchNative.sum( + input.native, + dim.toArray, + keepdim, + new ScalarTypeOptional(derivedDType.toScalarType) + ) + ) + + // TODO unique Returns the unique elements of the `input` tensor. + // seems to be implemented in https://github.com/pytorch/pytorch/blob/main/torch/functional.py + // and calls different native functions depending on dim unique_dim or _unique2 + + // TODO unique_consecutive Eliminates all but the first element from every consecutive group of equivalent elements. + // Similar to unique we should look at _unique_consecutive_impl in Python first + // https://github.com/pytorch/pytorch/blob/dbc8eb2a8fd894fbc110bbb9f70037249868afa8/torch/functional.py#L827 + + // TODO var + /* TODO Calculates the variance over the dimensions specified by dim. */ + // def variance[D <: DType](input: Tensor[D], dim: Seq[Int] = Nil, correction: Option[Int] = None, keepdim: Boolean = false) = + // Tensor(torchNative.`var`(input.native, dim.toArray.map(_.toLong), toOptional(correction), keepdim)) -// TODO unique Returns the unique elements of the `input` tensor. -// TODO unique_consecutive Eliminates all but the first element from every consecutive group of equivalent elements. +/** Calculates the variance over the dimensions specified by `dim`. `dim` can be a single dimension, + * list of dimensions, or `None` to reduce over all dimensions. + * + * The variance ($\sigma^2$) is calculated as + * + * $$\sigma^2 = \frac{1}{N - \delta N}\sum_{i=0}^{N-1}(x_i-\bar{x})^2$$ + * + * where $x$ is the sample set of elements, $\bar{x}$ is the sample mean, $N$ is the number of + * samples and $\delta N$ is the `correction`. + * + * $multi_keepdim_details + * + * @group reduction_ops + * + * @param dim + * the dimension or dimensions to reduce. If empty, all dimensions are reduced. + * @param keepdim + * whether the output tensor has `dim` retained or not. + * @param correction + * difference between the sample size and sample degrees of freedom. Defaults to [Bessel\'s + * correction](https://en.wikipedia.org/wiki/Bessel%27s_correction), `correction=1`. + */ +def variance[D <: DType]( + input: Tensor[D], + dim: Int | Seq[Int] = Seq.empty, + keepdim: Boolean = false, + correction: Long = 1 +): Tensor[D] = + Tensor( + torchNative.`var`( + input.native, + dim.toArray, + correction.toOptional, + keepdim + ) + ) -/* TODO Calculates the variance over the dimensions specified by dim. */ -//def variance[D <: DType](input: Tensor[D], dim: Seq[Int] = Nil, correction: Option[Int] = None, keepdim: Boolean = false) = -// Tensor(torchNative.`var`(input.native, dim.toArray.map(_.toLong), toOptional(correction), keepdim)) +/** Calculates the variance and mean over the dimensions specified by `dim`. `dim` can be a single + * dimension, list of dimensions, or `None` to reduce over all dimensions. + * + * The variance ($\sigma^2$) is calculated as + * + * $$\sigma^2 = \frac{1}{N - \delta N}\sum_{i=0}^{N-1}(x_i-\bar{x})^2$$ + * + * where $x$ is the sample set of elements, $\bar{x}$ is the sample mean, $N$ is the number of + * samples and $\delta N$ is the `correction`. 
+ * + * $multi_keepdim_details + * + * @group reduction_ops + * + * @param dim + * the dimension or dimensions to reduce. If empty, all dimensions are reduced. + * @param keepdim + * whether the output tensor has `dim` retained or not. + * @param correction + * difference between the sample size and sample degrees of freedom. Defaults to [Bessel\'s + * correction](https://en.wikipedia.org/wiki/Bessel%27s_correction), `correction=1`. + * @return + * A tuple (var, mean) containing the variance and mean. + */ +def var_mean[D <: DType]( + input: Tensor[D], + dim: Int | Seq[Int] = Seq.empty, + keepdim: Boolean = false, + correction: Long = 1 +): (Tensor[D], Tensor[D]) = + val nativeTuple = + torchNative.var_mean( + input.native, + dim.toArray, + correction.toOptional, + keepdim + ) + (Tensor[D](nativeTuple.get0), Tensor[D](nativeTuple.get1)) -// TODO var_mean Calculates the variance and mean over the dimensions specified by dim. -// TODO count_nonzero Counts the number of non-zero values in the tensor `input` along the given dim. +/** Counts the number of non-zero values in the tensor `input` along the given `dim`. If no dim is + * specified then all non-zeros in the tensor are counted. + * + * @group reduction_ops + * + * @param dim + * Dim or seq of dims along which to count non-zeros. + */ +def count_nonzero( + input: Tensor[?], + dim: Int | Seq[Int] = Seq.empty +): Tensor[Int64] = + val nativeDim = dim.toArray + Tensor( + if nativeDim.isEmpty then torchNative.count_nonzero(input.native) + else torchNative.count_nonzero(input.native, nativeDim: _*) + ) diff --git a/core/src/test/scala/torch/ops/ReductionOpsSuite.scala b/core/src/test/scala/torch/ops/ReductionOpsSuite.scala index 50bea4c1..dbda5d48 100644 --- a/core/src/test/scala/torch/ops/ReductionOpsSuite.scala +++ b/core/src/test/scala/torch/ops/ReductionOpsSuite.scala @@ -19,7 +19,7 @@ package torch class ReductionOpsSuite extends TensorCheckSuite { testUnaryOp( - op = sum(_, Array(), false, None), + op = sum(_), opName = "sum", inputTensor = Tensor(Seq(5.0, 5.0)), expectedTensor = Tensor(10.0) From 328d7ceb5dea5e87003fdcaa1bd915643491532d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=B6ren=20Brunk?= Date: Sun, 18 Jun 2023 23:02:29 +0200 Subject: [PATCH 2/3] Add tests for most reduction ops --- core/src/main/scala/torch/Tensor.scala | 2 + core/src/main/scala/torch/Types.scala | 3 +- .../torch/internal/NativeConverters.scala | 2 +- .../main/scala/torch/ops/ReductionOps.scala | 61 +++-- core/src/test/scala/torch/Generators.scala | 2 +- .../scala/torch/ops/ReductionOpsSuite.scala | 223 +++++++++++++++++- 6 files changed, 264 insertions(+), 29 deletions(-) diff --git a/core/src/main/scala/torch/Tensor.scala b/core/src/main/scala/torch/Tensor.scala index 79734c34..bd765580 100644 --- a/core/src/main/scala/torch/Tensor.scala +++ b/core/src/main/scala/torch/Tensor.scala @@ -351,6 +351,8 @@ sealed abstract class Tensor[D <: DType]( /* private[torch] */ val native: pyto /** Returns the tensor with elements logged. 
*/ def log: Tensor[D] = Tensor(native.log()) + def long: Tensor[Int64] = to(dtype = int64) + def matmul[D2 <: DType](u: Tensor[D2]): Tensor[Promoted[D, D2]] = Tensor[Promoted[D, D2]](native.matmul(u.native)) diff --git a/core/src/main/scala/torch/Types.scala b/core/src/main/scala/torch/Types.scala index 9e6fc279..086da9a0 100644 --- a/core/src/main/scala/torch/Types.scala +++ b/core/src/main/scala/torch/Types.scala @@ -47,4 +47,5 @@ type OnlyOneBool[A <: DType, B <: DType] = NotGiven[A =:= Bool & B =:= Bool] type AtLeastOneFloat[A <: DType, B <: DType] = A <:< FloatNN | B <:< FloatNN /* Evidence used in operations where at least one Float or Complex is required */ -type AtLeastOneFloatOrComplex[A <: DType, B <: DType] = A <:< FloatNN | B <:< FloatNN +type AtLeastOneFloatOrComplex[A <: DType, B <: DType] = A <:< (FloatNN | ComplexNN) | + B <:< (FloatNN | ComplexNN) diff --git a/core/src/main/scala/torch/internal/NativeConverters.scala b/core/src/main/scala/torch/internal/NativeConverters.scala index daafd300..7d82a58d 100644 --- a/core/src/main/scala/torch/internal/NativeConverters.scala +++ b/core/src/main/scala/torch/internal/NativeConverters.scala @@ -85,7 +85,7 @@ private[torch] object NativeConverters: extension (x: ScalaType) def toScalar: pytorch.Scalar = x match - case x: Boolean => pytorch.Scalar(if true then 1: Byte else 0: Byte) + case x: Boolean => pytorch.Scalar(if x then 1: Byte else 0: Byte) case x: UByte => Tensor(x.toInt).to(dtype = uint8).native.item() case x: Byte => pytorch.Scalar(x) case x: Short => pytorch.Scalar(x) diff --git a/core/src/main/scala/torch/ops/ReductionOps.scala b/core/src/main/scala/torch/ops/ReductionOps.scala index 7bb4f97f..fba5af41 100644 --- a/core/src/main/scala/torch/ops/ReductionOps.scala +++ b/core/src/main/scala/torch/ops/ReductionOps.scala @@ -65,7 +65,7 @@ import org.bytedeco.pytorch.ScalarTypeOptional * whether the output tensor has `dim` retained or not. * @return */ -def argmax[D <: DType]( +def argmax[D <: IntNN | FloatNN]( input: Tensor[D], dim: Long | Option[Long] = None, keepdim: Boolean = false @@ -94,7 +94,7 @@ def argmax[D <: DType]( * whether the output tensor has `dim` retained or not. * @return */ -def argmin[D <: DType]( +def argmin[D <: IntNN | FloatNN]( input: Tensor[D], dim: Long | Option[Long] = None, keepdim: Boolean = false @@ -114,7 +114,11 @@ def argmin[D <: DType]( * whether the output tensor has `dim` retained or not. * @return */ -def amax[D <: DType](input: Tensor[D], dim: Long | Seq[Long], keepdim: Boolean = false): Tensor[D] = +def amax[D <: RealNN]( + input: Tensor[D], + dim: Long | Seq[Long], + keepdim: Boolean = false +): Tensor[D] = Tensor( torchNative.amax(input.native, dim.toArray, keepdim) ) @@ -131,7 +135,11 @@ def amax[D <: DType](input: Tensor[D], dim: Long | Seq[Long], keepdim: Boolean = * whether the output tensor has `dim` retained or not. * @return */ -def amin[D <: DType](input: Tensor[D], dim: Long | Seq[Long], keepdim: Boolean = false): Tensor[D] = +def amin[D <: RealNN]( + input: Tensor[D], + dim: Long | Seq[Long], + keepdim: Boolean = false +): Tensor[D] = Tensor( torchNative.amin(input.native, dim.toArray, keepdim) ) @@ -147,7 +155,7 @@ def amin[D <: DType](input: Tensor[D], dim: Long | Seq[Long], keepdim: Boolean = * whether the output tensor has `dim` retained or not. 
* @return */ -def aminmax[D <: DType]( +def aminmax[D <: RealNN]( input: Tensor[D], dim: Long | Option[Long] = None, keepdim: Boolean = false @@ -203,7 +211,7 @@ def any[D <: DType](input: Tensor[D], dim: Long, keepdim: Boolean = false): Tens * * @group reduction_ops */ -def max(input: Tensor[?]): Tensor[Int64] = Tensor(input.native.max()) +def max[D <: RealNN](input: Tensor[D]): Tensor[Int64] = Tensor(input.native.max()) /** Returns a [[TensorTuple]] `(values, indices)` where `values` is the maximum value of each row of * the `input` tensor in the given dimension `dim`. And `indices` is the index location of each @@ -222,7 +230,7 @@ def max(input: Tensor[?]): Tensor[Int64] = Tensor(input.native.max()) * @param keepdim * whether the output tensor has `dim` retained or not. */ -def max[D <: DType](input: Tensor[D], dim: Long, keepdim: Boolean = false): TensorTuple[D] = +def max[D <: RealNN](input: Tensor[D], dim: Long, keepdim: Boolean = false): TensorTuple[D] = val nativeTuple = torchNative.max(input.native, dim, keepdim) TensorTuple(values = Tensor[D](nativeTuple.get0), indices = new Int64Tensor(nativeTuple.get1)) @@ -230,7 +238,7 @@ def max[D <: DType](input: Tensor[D], dim: Long, keepdim: Boolean = false): Tens * * @group reduction_ops */ -def min(input: Tensor[?]): Tensor[Int64] = Tensor(input.native.min()) +def min[D <: RealNN](input: Tensor[D]): Tensor[Int64] = Tensor(input.native.min()) /** Returns a [[TensorTuple]] `(values, indices)` where `values` is the minimum value of each row of * the `input` tensor in the given dimension `dim`. And `indices` is the index location of each @@ -249,7 +257,7 @@ def min(input: Tensor[?]): Tensor[Int64] = Tensor(input.native.min()) * @param keepdim * whether the output tensor has `dim` retained or not. */ -def min[D <: DType](input: Tensor[D], dim: Long, keepdim: Boolean = false): TensorTuple[D] = +def min[D <: RealNN](input: Tensor[D], dim: Long, keepdim: Boolean = false): TensorTuple[D] = val nativeTuple = torchNative.min(input.native, dim, keepdim) TensorTuple(values = Tensor[D](nativeTuple.get0), indices = new Int64Tensor(nativeTuple.get1)) @@ -263,8 +271,11 @@ def min[D <: DType](input: Tensor[D], dim: Long, keepdim: Boolean = false): Tens * the norm to be computed */ // TODO dtype promotion floatNN/complexNN => highest floatNN -// TODO (using AtLeastOneFloat[D, D2] -def dist[D <: DType, D2 <: DType](input: Tensor[D], other: Tensor[D2], p: Float = 2): Tensor[D] = +def dist[D <: NumericNN, D2 <: NumericNN]( + input: Tensor[D], + other: Tensor[D2], + p: Float = 2 +)(using AtLeastOneFloat[D, D2]): Tensor[D] = Tensor(torchNative.dist(input.native, other.native, toScalar(p))) /** Returns the log of summed exponentials of each row of the `input` tensor in the given dimension @@ -283,7 +294,7 @@ def dist[D <: DType, D2 <: DType](input: Tensor[D], other: Tensor[D2], p: Float * @param keepdim * whether the output tensor has `dim` retained or not. 
*/ -def logsumexp[D <: DType]( +def logsumexp[D <: RealNN]( input: Tensor[D], dim: Int | Seq[Int] = Seq.empty, keepdim: Boolean = false @@ -295,7 +306,7 @@ def logsumexp[D <: DType]( * * @group reduction_ops */ -def mean[D <: DType]( +def mean[D <: FloatNN | ComplexNN]( input: Tensor[D] ): Tensor[D] = Tensor(torchNative.mean(input.native)) @@ -306,7 +317,7 @@ def mean[D <: DType]( * @param dtype * $reduceops_dtype */ -def mean[D <: DType]( +def mean[D <: FloatNN | ComplexNN]( input: Tensor[?], dtype: D ): Tensor[D] = Tensor(torchNative.mean(input.native, new ScalarTypeOptional(dtype.toScalarType))) @@ -362,7 +373,7 @@ def mean[D <: DType, D2 <: DType | Derive]( * @param dtype * $reduceops_dtype */ -def nanmean[D <: DType, D2 <: DType | Derive]( +def nanmean[D <: FloatNN, D2 <: DType | Derive]( input: Tensor[D], dim: Int | Seq[Int] = Seq.empty, keepdim: Boolean = false, @@ -394,7 +405,7 @@ def nanmean[D <: DType, D2 <: DType | Derive]( * * @group reduction_ops */ -def median[D <: DType]( +def median[D <: NumericRealNN]( input: Tensor[D] ): Tensor[D] = Tensor(torchNative.median(input.native)) @@ -427,7 +438,7 @@ def median[D <: DType]( * @param dtype * $reduceops_dtype */ -def median[D <: DType, D2 <: DType | Derive]( +def median[D <: NumericRealNN, D2 <: DType | Derive]( input: Tensor[D], dim: Long = -1, keepdim: Boolean = false @@ -444,7 +455,7 @@ def median[D <: DType, D2 <: DType | Derive]( * * @group reduction_ops */ -def nanmedian[D <: DType]( +def nanmedian[D <: NumericRealNN]( input: Tensor[D] ): Tensor[D] = Tensor(torchNative.nanmedian(input.native)) @@ -466,7 +477,7 @@ def nanmedian[D <: DType]( * @param dtype * $reduceops_dtype */ -def nanmedian[D <: DType, D2 <: DType | Derive]( +def nanmedian[D <: NumericRealNN, D2 <: DType | Derive]( input: Tensor[D], dim: Long = -1, keepdim: Boolean = false @@ -493,7 +504,7 @@ def nanmedian[D <: DType, D2 <: DType | Derive]( * @param keepdim * whether the output tensor has `dim` retained or not. */ -def mode[D <: DType](input: Tensor[D], dim: Long, keepdim: Boolean = false): TensorTuple[D] = +def mode[D <: RealNN](input: Tensor[D], dim: Long = -1, keepdim: Boolean = false): TensorTuple[D] = val nativeTuple = torchNative.mode(input.native, dim, keepdim) TensorTuple(values = Tensor[D](nativeTuple.get0), indices = new Int64Tensor(nativeTuple.get1)) @@ -511,7 +522,7 @@ def mode[D <: DType](input: Tensor[D], dim: Long, keepdim: Boolean = false): Ten * @param dtype * $reduceops_dtype */ -def nansum[D <: DType, D2 <: DType | Derive]( +def nansum[D <: RealNN, D2 <: DType | Derive]( input: Tensor[D], dim: Int | Seq[Int] = Seq.empty, keepdim: Boolean = false, @@ -658,7 +669,7 @@ def prod[D <: DType, D2 <: DType | Derive]( * difference between the sample size and sample degrees of freedom. Defaults to [Bessel\'s * correction](https://en.wikipedia.org/wiki/Bessel%27s_correction), `correction=1`. */ -def std[D <: DType]( +def std[D <: FloatNN | ComplexNN]( input: Tensor[D], dim: Int | Seq[Int] = Seq.empty, keepdim: Boolean = false, @@ -697,7 +708,7 @@ def std[D <: DType]( * @return * A tuple (std, mean) containing the standard deviation and mean. */ -def std_mean[D <: DType]( +def std_mean[D <: FloatNN | ComplexNN]( input: Tensor[D], dim: Int | Seq[Int] = Seq.empty, keepdim: Boolean = false, @@ -801,7 +812,7 @@ def sum[D <: DType, D2 <: DType | Derive]( * difference between the sample size and sample degrees of freedom. Defaults to [Bessel\'s * correction](https://en.wikipedia.org/wiki/Bessel%27s_correction), `correction=1`. 
*/ -def variance[D <: DType]( +def variance[D <: FloatNN | ComplexNN]( input: Tensor[D], dim: Int | Seq[Int] = Seq.empty, keepdim: Boolean = false, @@ -840,7 +851,7 @@ def variance[D <: DType]( * @return * A tuple (var, mean) containing the variance and mean. */ -def var_mean[D <: DType]( +def var_mean[D <: FloatNN | ComplexNN]( input: Tensor[D], dim: Int | Seq[Int] = Seq.empty, keepdim: Boolean = false, diff --git a/core/src/test/scala/torch/Generators.scala b/core/src/test/scala/torch/Generators.scala index 170c240b..d4f55638 100644 --- a/core/src/test/scala/torch/Generators.scala +++ b/core/src/test/scala/torch/Generators.scala @@ -56,7 +56,7 @@ object Generators: inline def genTensor[D <: DType]: Gen[Tensor[D]] = Gen.oneOf(allDTypes.filter(_.isInstanceOf[D])).map { dtype => - ones(10, dtype = dtype.asInstanceOf[D]) + ones(Seq(4, 4), dtype = dtype.asInstanceOf[D]) } val genDType = Gen.oneOf(allDTypes) diff --git a/core/src/test/scala/torch/ops/ReductionOpsSuite.scala b/core/src/test/scala/torch/ops/ReductionOpsSuite.scala index dbda5d48..1bc1c720 100644 --- a/core/src/test/scala/torch/ops/ReductionOpsSuite.scala +++ b/core/src/test/scala/torch/ops/ReductionOpsSuite.scala @@ -18,11 +18,232 @@ package torch class ReductionOpsSuite extends TensorCheckSuite { + // TODO test with dim/keepdim variants for corresponding ops + + testUnaryOp( + op = argmax(_), + opName = "argmax", + inputTensor = + Tensor(Seq(1.0, 0.5, 1.2, -2)), // TODO check why we can't call tensor ops on inputTensor here + expectedTensor = Tensor(2L) + ) + + testUnaryOp( + op = argmin(_), + opName = "argmin", + inputTensor = Tensor(Seq(1.0, 0.5, 1.2, -2)), + expectedTensor = Tensor(3L) + ) + + propertyTestUnaryOp(amax(_, 0), "amax") + // TODO unit test amax + + propertyTestUnaryOp(amin(_, 0), "amin") + // TODO unit test amin + + propertyTestUnaryOp(aminmax(_), "aminmax") + // TODO unit test aminmax + + testUnaryOp( + op = all, + opName = "all", + inputTensor = Tensor(Seq(true, true, false, true)), + expectedTensor = Tensor(false) + ) + + test("all") { + assertEquals( + all(Tensor(Seq(true, true, false, true))), + Tensor(false) + ) + assertEquals( + all(Tensor(Seq(true, true, true, true))), + Tensor(true) + ) + } + + testUnaryOp( + op = any, + opName = "any", + inputTensor = Tensor(Seq(true, true, false, true)), + expectedTensor = Tensor(true) + ) + + test("any") { + assertEquals( + any(Tensor(Seq(true, true, false, true))), + Tensor(true) + ) + assertEquals( + any(Tensor(Seq(false, false, false, false))), + Tensor(false) + ) + } + + testUnaryOp( + op = max, + opName = "max", + inputTensor = Tensor(Seq(1.0, 0.5, 1.2, -2)), + expectedTensor = Tensor(1.2) + ) + + testUnaryOp( + op = min, + opName = "min", + inputTensor = Tensor(Seq(1.0, 0.5, 1.2, -2)), + expectedTensor = Tensor(-2.0) + ) + + // TODO Enable property test once we figure out to compile properly with AtLeastOneFloatOrComplex + // propertyTestBinaryOp(dist35, "dist") + + def unitTestDist(p: Float, expected: Float) = unitTestBinaryOp[Float32, Float, Float32, Float]( + dist(_, _, p), + "dist", + inputTensors = ( + Tensor(Seq[Float](-1.5393, -0.8675, 0.5916, 1.6321)), + Tensor(Seq[Float](0.0967, -1.0511, 0.6295, 0.8360)) + ), + expectedTensor = Tensor(expected) + ) + + unitTestDist(3.5, 1.6727) + unitTestDist(3, 1.6973) + unitTestDist(0, 4) + unitTestDist(1, 2.6537) + + propertyTestUnaryOp(logsumexp(_, dim = 0), "logsumexp") + // TODO unit test logsumexp + + testUnaryOp( + op = mean, + opName = "mean", + inputTensor = Tensor(Seq(0.2294, -0.5481, 1.3288)), + 
expectedTensor = Tensor(0.3367) + ) + + test("mean with nan") { + assert(mean(Tensor(Seq(Float.NaN, 1, 2))).isnan.item) + } + testUnaryOp( - op = sum(_), + op = nanmean(_), + opName = "nanmean", + inputTensor = Tensor(Seq(Float.NaN, 1, 2, 1, 2, 3)), + expectedTensor = Tensor(1.8f) + ) + + test("nanmean with nan") { + val t = Tensor(Seq(Float.NaN, 1, 2)) + assert(!nanmean(t).isnan.item) + } + + testUnaryOp( + op = median, + opName = "median", + inputTensor = Tensor(Seq(1, 5, 2, 3, 4)), + expectedTensor = Tensor(3) + ) + + testUnaryOp( + op = nanmedian, + opName = "nanmedian", + inputTensor = Tensor(Seq(1, 5, Float.NaN, 3, 4)), + expectedTensor = Tensor(3f) + ) + + propertyTestUnaryOp(mode(_), "mode") + + test("mode") { + torch.manualSeed(0) + assertEquals( + torch.mode(Tensor[Int](Seq(6, 5, 1, 0, 2)), 0), + TensorTuple(Tensor(0), Tensor[Long](3L)) + ) + } + + // test("mode") { + // torch.manualSeed(0) + // val a = Tensor(Seq(6, 5, 1, 0, 2)) + // val b = a + Tensor(Seq(-3 , -11, -6, -7, 4)).reshape(5,1) + // val values = Tensor(Seq(-5, -5, -5, -4, -8)) + // val indices = Tensor(Seq(1, 1, 1, 1, 1)).long + // assertEquals( + // torch.mode(b, 0), + // TensorTuple(values, indices) + // ) + // } + + testUnaryOp( + op = nansum(_), + opName = "nansum", + inputTensor = Tensor(Seq(1, 5, Float.NaN, 3, 4)), + expectedTensor = Tensor(13f) + ) + + testUnaryOp( + op = prod, + opName = "prod", + inputTensor = Tensor(Seq(5.0, 5.0)), + expectedTensor = Tensor(25.0) + ) + + // TODO quantile + // TODO nanquantile + + propertyTestUnaryOp(std(_), "std") + + unitTestUnaryOp[Float32, Float]( + op = std(_, dim = 1, keepdim = true), + opName = "std", + inputTensor = Tensor( + Seq( + Seq[Float](0.2035, 1.2959, 1.8101, -0.4644), + Seq[Float](1.5027, -0.3270, 0.5905, 0.6538), + Seq[Float](-1.5745, 1.3330, -0.5596, -0.6548), + Seq[Float](0.1264, -0.5080, 1.6420, 0.1992) + ).flatten + ).reshape(4, 4), + expectedTensor = Tensor(Seq[Float](1.0311, 0.7477, 1.2204, 0.9087)).reshape(4, 1) + ) + + propertyTestUnaryOp(std_mean(_), "std_mean") + // TODO unit test std_mean + + testUnaryOp( + op = sum, opName = "sum", inputTensor = Tensor(Seq(5.0, 5.0)), expectedTensor = Tensor(10.0) ) + // TODO unique + // TODO unique_consecutive + + propertyTestUnaryOp(variance(_), "variance") + + unitTestUnaryOp[Float32, Float]( + op = variance(_, dim = 1, keepdim = true), + opName = "variance", + inputTensor = Tensor( + Seq( + Seq[Float](0.2035, 1.2959, 1.8101, -0.4644), + Seq[Float](1.5027, -0.3270, 0.5905, 0.6538), + Seq[Float](-1.5745, 1.3330, -0.5596, -0.6548), + Seq[Float](0.1264, -0.5080, 1.6420, 0.1992) + ).flatten + ).reshape(4, 4), + expectedTensor = Tensor(Seq[Float](1.0631, 0.5590, 1.4893, 0.8258)).reshape(4, 1) + ) + + propertyTestUnaryOp(var_mean(_), "var_mean") + // TODO unit test var_mean + + testUnaryOp( + op = count_nonzero(_), + opName = "count_nonzero", + inputTensor = Tensor(Seq(1, 0, 0, 1, 0)), + expectedTensor = Tensor(2L) + ) + } From fc94bccd07407cc79d005b39a0cb948af47917f8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=B6ren=20Brunk?= Date: Tue, 20 Jun 2023 21:48:58 +0200 Subject: [PATCH 3/3] Fix dist promoted type and add missing FloatPromoted cases --- core/src/main/scala/torch/DType.scala | 6 ++++-- core/src/main/scala/torch/ops/ReductionOps.scala | 5 +++-- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/torch/DType.scala b/core/src/main/scala/torch/DType.scala index f32e3773..18b33c6d 100644 --- a/core/src/main/scala/torch/DType.scala +++ b/core/src/main/scala/torch/DType.scala @@ 
-390,8 +390,10 @@ type NumericPromoted[D <: DType] <: DType = D match /** Promoted type for tensor operations that always output floats (e.g. `sin`) */ type FloatPromoted[D <: DType] <: FloatNN = D match - case Float64 => Float64 - case _ => Float32 + case Float16 => Float16 + case BFloat16 => BFloat16 + case Float64 => Float64 + case _ => Float32 /** Demoted type for complex to real type extractions (e.g. `imag`, `real`) */ type ComplexToReal[D <: DType] <: DType = D match diff --git a/core/src/main/scala/torch/ops/ReductionOps.scala b/core/src/main/scala/torch/ops/ReductionOps.scala index fba5af41..476ce085 100644 --- a/core/src/main/scala/torch/ops/ReductionOps.scala +++ b/core/src/main/scala/torch/ops/ReductionOps.scala @@ -270,12 +270,13 @@ def min[D <: RealNN](input: Tensor[D], dim: Long, keepdim: Boolean = false): Ten * @param p * the norm to be computed */ -// TODO dtype promotion floatNN/complexNN => highest floatNN def dist[D <: NumericNN, D2 <: NumericNN]( input: Tensor[D], other: Tensor[D2], p: Float = 2 -)(using AtLeastOneFloat[D, D2]): Tensor[D] = +)(using + AtLeastOneFloat[D, D2] +): Tensor[Promoted[FloatPromoted[ComplexToReal[D]], FloatPromoted[ComplexToReal[D2]]]] = Tensor(torchNative.dist(input.native, other.native, toScalar(p))) /** Returns the log of summed exponentials of each row of the `input` tensor in the given dimension
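
The `correction` parameter on the new `std`/`variance`/`std_mean`/`var_mean` ops selects between the Bessel-corrected (sample) and population estimators described in the scaladoc above. A small usage sketch, not part of the patch itself, assuming the package-level ops and the `Tensor(Seq(...))` constructor already used in `ReductionOpsSuite`:

import torch.*

// Four samples: mean = 2.5, squared deviations sum to 5.0.
val xs = Tensor(Seq(1.0f, 2.0f, 3.0f, 4.0f))

// Default correction = 1 (Bessel): 5.0 / (4 - 1) ≈ 1.6667
val sampleVar = variance(xs) // ≈ 1.6667
val sampleStd = std(xs)      // ≈ 1.2910

// correction = 0 gives the population estimator: 5.0 / 4 = 1.25
val popVar = variance(xs, correction = 0) // 1.25
val popStd = std(xs, correction = 0)      // ≈ 1.1180

// std_mean / var_mean return the statistic together with the mean in one call.
val (sigma, mu) = std_mean(xs) // (≈ 1.2910, 2.5)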
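
The `dtype: D2 = derive` pattern used by `sum`, `prod`, `mean` and `nansum` keeps the input dtype unless an explicit dtype overrides it, with the result type tracked through `DTypeOrDeriveFromTensor`. A minimal sketch of the `sum` overloads added in this series (the `float64` dtype value is assumed to exist alongside the `int64`/`uint8` values referenced elsewhere in the patch):

import torch.*

val t = Tensor(Seq(1.0f, 2.0f, 3.0f, 4.0f)).reshape(2, 2) // Tensor[Float32]

// Reduce over all elements; the dtype stays Float32.
val total = sum(t) // Tensor[Float32] containing 10.0f

// Reduce along dim 0 with the default dtype = derive: still Float32, shape (2).
val colSums = sum(t, dim = 0)

// Override the derived dtype; D2 = Float64, so the result is Tensor[Float64].
val colSumsF64 = sum(t, dim = 0, dtype = float64)

// keepdim = true keeps the reduced axis with size 1, so the shape is (1, 2).
val colSumsKeep = sum(t, dim = 0, keepdim = true)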
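
With the return type fixed in the last commit, `dist` now reports the promoted floating dtype of its two inputs instead of the first input's dtype, and the `AtLeastOneFloat` evidence still rejects purely integral pairs at compile time. A sketch under the same assumptions as above (hypothetical values, not from the patch):

import torch.*

val a = Tensor(Seq(1.0f, 2.0f, 3.0f)) // Tensor[Float32]
val b = Tensor(Seq(0.0, 0.0, 0.0))    // Tensor[Float64]

// Promoted[FloatPromoted[Float32], FloatPromoted[Float64]] = Float64
val d2 = dist(a, b)          // p = 2: sqrt(1 + 4 + 9) ≈ 3.7417
val d1 = dist(a, b, p = 1f)  // p = 1: 1 + 2 + 3 = 6.0

// Rejected at compile time: neither operand is a FloatNN, so no
// AtLeastOneFloat[Int32, Int32] evidence is available.
// val bad = dist(Tensor(Seq(1, 2, 3)), Tensor(Seq(4, 5, 6)))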