Skip to content

Commit

Permalink
Add tests for most reduction ops
Browse files Browse the repository at this point in the history
  • Loading branch information
sbrunk committed Jun 19, 2023
1 parent 6ae3bb8 commit 9901dd0
Show file tree
Hide file tree
Showing 5 changed files with 248 additions and 28 deletions.
2 changes: 2 additions & 0 deletions core/src/main/scala/torch/Tensor.scala
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,8 @@ sealed abstract class Tensor[D <: DType]( /* private[torch] */ val native: pyto
/** Returns the tensor with elements logged. */
def log: Tensor[D] = Tensor(native.log())

def long: Tensor[Int64] = to(dtype = int64)

def matmul[D2 <: DType](u: Tensor[D2]): Tensor[Promoted[D, D2]] =
Tensor[Promoted[D, D2]](native.matmul(u.native))

Expand Down
3 changes: 2 additions & 1 deletion core/src/main/scala/torch/Types.scala
Original file line number Diff line number Diff line change
Expand Up @@ -47,4 +47,5 @@ type OnlyOneBool[A <: DType, B <: DType] = NotGiven[A =:= Bool & B =:= Bool]
type AtLeastOneFloat[A <: DType, B <: DType] = A <:< FloatNN | B <:< FloatNN

/* Evidence used in operations where at least one Float or Complex is required */
type AtLeastOneFloatOrComplex[A <: DType, B <: DType] = A <:< FloatNN | B <:< FloatNN
type AtLeastOneFloatOrComplex[A <: DType, B <: DType] = A <:< FloatNN | ComplexNN | B <:< FloatNN |
ComplexNN
2 changes: 1 addition & 1 deletion core/src/main/scala/torch/internal/NativeConverters.scala
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ private[torch] object NativeConverters:

extension (x: ScalaType)
def toScalar: pytorch.Scalar = x match
case x: Boolean => pytorch.Scalar(if true then 1: Byte else 0: Byte)
case x: Boolean => pytorch.Scalar(if x then 1: Byte else 0: Byte)
case x: UByte => Tensor(x.toInt).to(dtype = uint8).native.item()
case x: Byte => pytorch.Scalar(x)
case x: Short => pytorch.Scalar(x)
Expand Down
61 changes: 36 additions & 25 deletions core/src/main/scala/torch/ops/ReductionOps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ import org.bytedeco.pytorch.ScalarTypeOptional
* whether the output tensor has `dim` retained or not.
* @return
*/
def argmax[D <: DType](
def argmax[D <: IntNN | FloatNN](
input: Tensor[D],
dim: Long | Option[Long] = None,
keepdim: Boolean = false
Expand Down Expand Up @@ -94,7 +94,7 @@ def argmax[D <: DType](
* whether the output tensor has `dim` retained or not.
* @return
*/
def argmin[D <: DType](
def argmin[D <: IntNN | FloatNN](
input: Tensor[D],
dim: Long | Option[Long] = None,
keepdim: Boolean = false
Expand All @@ -114,7 +114,11 @@ def argmin[D <: DType](
* whether the output tensor has `dim` retained or not.
* @return
*/
def amax[D <: DType](input: Tensor[D], dim: Long | Seq[Long], keepdim: Boolean = false): Tensor[D] =
def amax[D <: RealNN](
input: Tensor[D],
dim: Long | Seq[Long],
keepdim: Boolean = false
): Tensor[D] =
Tensor(
torchNative.amax(input.native, dim.toArray, keepdim)
)
Expand All @@ -131,7 +135,11 @@ def amax[D <: DType](input: Tensor[D], dim: Long | Seq[Long], keepdim: Boolean =
* whether the output tensor has `dim` retained or not.
* @return
*/
def amin[D <: DType](input: Tensor[D], dim: Long | Seq[Long], keepdim: Boolean = false): Tensor[D] =
def amin[D <: RealNN](
input: Tensor[D],
dim: Long | Seq[Long],
keepdim: Boolean = false
): Tensor[D] =
Tensor(
torchNative.amin(input.native, dim.toArray, keepdim)
)
Expand All @@ -147,7 +155,7 @@ def amin[D <: DType](input: Tensor[D], dim: Long | Seq[Long], keepdim: Boolean =
* whether the output tensor has `dim` retained or not.
* @return
*/
def aminmax[D <: DType](
def aminmax[D <: RealNN](
input: Tensor[D],
dim: Long | Option[Long] = None,
keepdim: Boolean = false
Expand Down Expand Up @@ -203,7 +211,7 @@ def any[D <: DType](input: Tensor[D], dim: Long, keepdim: Boolean = false): Tens
*
* @group reduction_ops
*/
def max(input: Tensor[?]): Tensor[Int64] = Tensor(input.native.max())
def max[D <: RealNN](input: Tensor[D]): Tensor[Int64] = Tensor(input.native.max())

/** Returns a [[TensorTuple]] `(values, indices)` where `values` is the maximum value of each row of
* the `input` tensor in the given dimension `dim`. And `indices` is the index location of each
Expand All @@ -222,15 +230,15 @@ def max(input: Tensor[?]): Tensor[Int64] = Tensor(input.native.max())
* @param keepdim
* whether the output tensor has `dim` retained or not.
*/
def max[D <: DType](input: Tensor[D], dim: Long, keepdim: Boolean = false): TensorTuple[D] =
def max[D <: RealNN](input: Tensor[D], dim: Long, keepdim: Boolean = false): TensorTuple[D] =
val nativeTuple = torchNative.max(input.native, dim, keepdim)
TensorTuple(values = Tensor[D](nativeTuple.get0), indices = new Int64Tensor(nativeTuple.get1))

/** Returns the maximum value of all elements in the `input` tensor.
*
* @group reduction_ops
*/
def min(input: Tensor[?]): Tensor[Int64] = Tensor(input.native.min())
def min[D <: RealNN](input: Tensor[D]): Tensor[Int64] = Tensor(input.native.min())

/** Returns a [[TensorTuple]] `(values, indices)` where `values` is the minimum value of each row of
* the `input` tensor in the given dimension `dim`. And `indices` is the index location of each
Expand All @@ -249,7 +257,7 @@ def min(input: Tensor[?]): Tensor[Int64] = Tensor(input.native.min())
* @param keepdim
* whether the output tensor has `dim` retained or not.
*/
def min[D <: DType](input: Tensor[D], dim: Long, keepdim: Boolean = false): TensorTuple[D] =
def min[D <: RealNN](input: Tensor[D], dim: Long, keepdim: Boolean = false): TensorTuple[D] =
val nativeTuple = torchNative.min(input.native, dim, keepdim)
TensorTuple(values = Tensor[D](nativeTuple.get0), indices = new Int64Tensor(nativeTuple.get1))

Expand All @@ -263,8 +271,11 @@ def min[D <: DType](input: Tensor[D], dim: Long, keepdim: Boolean = false): Tens
* the norm to be computed
*/
// TODO dtype promotion floatNN/complexNN => highest floatNN
// TODO (using AtLeastOneFloat[D, D2]
def dist[D <: DType, D2 <: DType](input: Tensor[D], other: Tensor[D2], p: Float = 2): Tensor[D] =
def dist[D <: NumericNN, D2 <: NumericNN](
input: Tensor[D],
other: Tensor[D2],
p: Float = 2
)(using AtLeastOneFloat[D, D2]): Tensor[D] =
Tensor(torchNative.dist(input.native, other.native, toScalar(p)))

/** Returns the log of summed exponentials of each row of the `input` tensor in the given dimension
Expand All @@ -283,7 +294,7 @@ def dist[D <: DType, D2 <: DType](input: Tensor[D], other: Tensor[D2], p: Float
* @param keepdim
* whether the output tensor has `dim` retained or not.
*/
def logsumexp[D <: DType](
def logsumexp[D <: FloatNN | ComplexNN](
input: Tensor[D],
dim: Int | Seq[Int] = Seq.empty,
keepdim: Boolean = false
Expand All @@ -295,7 +306,7 @@ def logsumexp[D <: DType](
*
* @group reduction_ops
*/
def mean[D <: DType](
def mean[D <: FloatNN | ComplexNN](
input: Tensor[D]
): Tensor[D] = Tensor(torchNative.mean(input.native))

Expand All @@ -306,7 +317,7 @@ def mean[D <: DType](
* @param dtype
* $reduceops_dtype
*/
def mean[D <: DType](
def mean[D <: FloatNN | ComplexNN](
input: Tensor[?],
dtype: D
): Tensor[D] = Tensor(torchNative.mean(input.native, new ScalarTypeOptional(dtype.toScalarType)))
Expand Down Expand Up @@ -362,7 +373,7 @@ def mean[D <: DType, D2 <: DType | Derive](
* @param dtype
* $reduceops_dtype
*/
def nanmean[D <: DType, D2 <: DType | Derive](
def nanmean[D <: FloatNN, D2 <: DType | Derive](
input: Tensor[D],
dim: Int | Seq[Int] = Seq.empty,
keepdim: Boolean = false,
Expand Down Expand Up @@ -394,7 +405,7 @@ def nanmean[D <: DType, D2 <: DType | Derive](
*
* @group reduction_ops
*/
def median[D <: DType](
def median[D <: NumericRealNN](
input: Tensor[D]
): Tensor[D] = Tensor(torchNative.median(input.native))

Expand Down Expand Up @@ -427,7 +438,7 @@ def median[D <: DType](
* @param dtype
* $reduceops_dtype
*/
def median[D <: DType, D2 <: DType | Derive](
def median[D <: NumericRealNN, D2 <: DType | Derive](
input: Tensor[D],
dim: Long = -1,
keepdim: Boolean = false
Expand All @@ -444,7 +455,7 @@ def median[D <: DType, D2 <: DType | Derive](
*
* @group reduction_ops
*/
def nanmedian[D <: DType](
def nanmedian[D <: NumericRealNN](
input: Tensor[D]
): Tensor[D] = Tensor(torchNative.nanmedian(input.native))

Expand All @@ -466,7 +477,7 @@ def nanmedian[D <: DType](
* @param dtype
* $reduceops_dtype
*/
def nanmedian[D <: DType, D2 <: DType | Derive](
def nanmedian[D <: NumericRealNN, D2 <: DType | Derive](
input: Tensor[D],
dim: Long = -1,
keepdim: Boolean = false
Expand All @@ -493,7 +504,7 @@ def nanmedian[D <: DType, D2 <: DType | Derive](
* @param keepdim
* whether the output tensor has `dim` retained or not.
*/
def mode[D <: DType](input: Tensor[D], dim: Long, keepdim: Boolean = false): TensorTuple[D] =
def mode[D <: RealNN](input: Tensor[D], dim: Long = -1, keepdim: Boolean = false): TensorTuple[D] =
val nativeTuple = torchNative.mode(input.native, dim, keepdim)
TensorTuple(values = Tensor[D](nativeTuple.get0), indices = new Int64Tensor(nativeTuple.get1))

Expand All @@ -511,7 +522,7 @@ def mode[D <: DType](input: Tensor[D], dim: Long, keepdim: Boolean = false): Ten
* @param dtype
* $reduceops_dtype
*/
def nansum[D <: DType, D2 <: DType | Derive](
def nansum[D <: RealNN, D2 <: DType | Derive](
input: Tensor[D],
dim: Int | Seq[Int] = Seq.empty,
keepdim: Boolean = false,
Expand Down Expand Up @@ -658,7 +669,7 @@ def prod[D <: DType, D2 <: DType | Derive](
* difference between the sample size and sample degrees of freedom. Defaults to [Bessel\'s
* correction](https://en.wikipedia.org/wiki/Bessel%27s_correction), `correction=1`.
*/
def std[D <: DType](
def std[D <: FloatNN | ComplexNN](
input: Tensor[D],
dim: Int | Seq[Int] = Seq.empty,
keepdim: Boolean = false,
Expand Down Expand Up @@ -697,7 +708,7 @@ def std[D <: DType](
* @return
* A tuple (std, mean) containing the standard deviation and mean.
*/
def std_mean[D <: DType](
def std_mean[D <: FloatNN | ComplexNN](
input: Tensor[D],
dim: Int | Seq[Int] = Seq.empty,
keepdim: Boolean = false,
Expand Down Expand Up @@ -801,7 +812,7 @@ def sum[D <: DType, D2 <: DType | Derive](
* difference between the sample size and sample degrees of freedom. Defaults to [Bessel\'s
* correction](https://en.wikipedia.org/wiki/Bessel%27s_correction), `correction=1`.
*/
def variance[D <: DType](
def variance[D <: FloatNN | ComplexNN](
input: Tensor[D],
dim: Int | Seq[Int] = Seq.empty,
keepdim: Boolean = false,
Expand Down Expand Up @@ -840,7 +851,7 @@ def variance[D <: DType](
* @return
* A tuple (var, mean) containing the variance and mean.
*/
def var_mean[D <: DType](
def var_mean[D <: FloatNN | ComplexNN](
input: Tensor[D],
dim: Int | Seq[Int] = Seq.empty,
keepdim: Boolean = false,
Expand Down
Loading

0 comments on commit 9901dd0

Please sign in to comment.