diff --git a/core/src/main/scala/torch/Tensor.scala b/core/src/main/scala/torch/Tensor.scala
index 79734c34..bd765580 100644
--- a/core/src/main/scala/torch/Tensor.scala
+++ b/core/src/main/scala/torch/Tensor.scala
@@ -351,6 +351,8 @@ sealed abstract class Tensor[D <: DType]( /* private[torch] */ val native: pyto
   /** Returns the tensor with elements logged. */
   def log: Tensor[D] = Tensor(native.log())
 
+  def long: Tensor[Int64] = to(dtype = int64)
+
   def matmul[D2 <: DType](u: Tensor[D2]): Tensor[Promoted[D, D2]] =
     Tensor[Promoted[D, D2]](native.matmul(u.native))
 
diff --git a/core/src/main/scala/torch/Types.scala b/core/src/main/scala/torch/Types.scala
index 9e6fc279..086da9a0 100644
--- a/core/src/main/scala/torch/Types.scala
+++ b/core/src/main/scala/torch/Types.scala
@@ -47,4 +47,5 @@ type OnlyOneBool[A <: DType, B <: DType] = NotGiven[A =:= Bool & B =:= Bool]
 type AtLeastOneFloat[A <: DType, B <: DType] = A <:< FloatNN | B <:< FloatNN
 
 /* Evidence used in operations where at least one Float or Complex is required */
-type AtLeastOneFloatOrComplex[A <: DType, B <: DType] = A <:< FloatNN | B <:< FloatNN
+type AtLeastOneFloatOrComplex[A <: DType, B <: DType] = A <:< (FloatNN | ComplexNN) |
+  B <:< (FloatNN | ComplexNN)
diff --git a/core/src/main/scala/torch/internal/NativeConverters.scala b/core/src/main/scala/torch/internal/NativeConverters.scala
index daafd300..7d82a58d 100644
--- a/core/src/main/scala/torch/internal/NativeConverters.scala
+++ b/core/src/main/scala/torch/internal/NativeConverters.scala
@@ -85,7 +85,7 @@ private[torch] object NativeConverters:
 
   extension (x: ScalaType)
     def toScalar: pytorch.Scalar = x match
-      case x: Boolean => pytorch.Scalar(if true then 1: Byte else 0: Byte)
+      case x: Boolean => pytorch.Scalar(if x then 1: Byte else 0: Byte)
       case x: UByte   => Tensor(x.toInt).to(dtype = uint8).native.item()
       case x: Byte    => pytorch.Scalar(x)
       case x: Short   => pytorch.Scalar(x)
diff --git a/core/src/main/scala/torch/ops/ReductionOps.scala b/core/src/main/scala/torch/ops/ReductionOps.scala
index 7bb4f97f..fba5af41 100644
--- a/core/src/main/scala/torch/ops/ReductionOps.scala
+++ b/core/src/main/scala/torch/ops/ReductionOps.scala
@@ -65,7 +65,7 @@ import org.bytedeco.pytorch.ScalarTypeOptional
  *   whether the output tensor has `dim` retained or not.
  * @return
  */
-def argmax[D <: DType](
+def argmax[D <: IntNN | FloatNN](
     input: Tensor[D],
     dim: Long | Option[Long] = None,
     keepdim: Boolean = false
@@ -94,7 +94,7 @@ def argmax[D <: DType](
  *   whether the output tensor has `dim` retained or not.
  * @return
  */
-def argmin[D <: DType](
+def argmin[D <: IntNN | FloatNN](
     input: Tensor[D],
     dim: Long | Option[Long] = None,
     keepdim: Boolean = false
@@ -114,7 +114,11 @@ def argmin[D <: DType](
  *   whether the output tensor has `dim` retained or not.
  * @return
  */
-def amax[D <: DType](input: Tensor[D], dim: Long | Seq[Long], keepdim: Boolean = false): Tensor[D] =
+def amax[D <: RealNN](
+    input: Tensor[D],
+    dim: Long | Seq[Long],
+    keepdim: Boolean = false
+): Tensor[D] =
   Tensor(
     torchNative.amax(input.native, dim.toArray, keepdim)
   )
@@ -131,7 +135,11 @@ def amax[D <: DType](input: Tensor[D], dim: Long | Seq[Long], keepdim: Boolean =
  *   whether the output tensor has `dim` retained or not.
  * @return
  */
-def amin[D <: DType](input: Tensor[D], dim: Long | Seq[Long], keepdim: Boolean = false): Tensor[D] =
+def amin[D <: RealNN](
+    input: Tensor[D],
+    dim: Long | Seq[Long],
+    keepdim: Boolean = false
+): Tensor[D] =
   Tensor(
     torchNative.amin(input.native, dim.toArray, keepdim)
   )
@@ -147,7 +155,7 @@ def amin[D <: DType](input: Tensor[D], dim: Long | Seq[Long], keepdim: Boolean =
  *   whether the output tensor has `dim` retained or not.
  * @return
  */
-def aminmax[D <: DType](
+def aminmax[D <: RealNN](
     input: Tensor[D],
     dim: Long | Option[Long] = None,
     keepdim: Boolean = false
@@ -203,7 +211,7 @@ def any[D <: DType](input: Tensor[D], dim: Long, keepdim: Boolean = false): Tens
  *
  * @group reduction_ops
  */
-def max(input: Tensor[?]): Tensor[Int64] = Tensor(input.native.max())
+def max[D <: RealNN](input: Tensor[D]): Tensor[Int64] = Tensor(input.native.max())
 
 /** Returns a [[TensorTuple]] `(values, indices)` where `values` is the maximum value of each row of
  *   the `input` tensor in the given dimension `dim`. And `indices` is the index location of each
@@ -222,7 +230,7 @@ def max(input: Tensor[?]): Tensor[Int64] = Tensor(input.native.max())
  * @param keepdim
  *   whether the output tensor has `dim` retained or not.
  */
-def max[D <: DType](input: Tensor[D], dim: Long, keepdim: Boolean = false): TensorTuple[D] =
+def max[D <: RealNN](input: Tensor[D], dim: Long, keepdim: Boolean = false): TensorTuple[D] =
   val nativeTuple = torchNative.max(input.native, dim, keepdim)
   TensorTuple(values = Tensor[D](nativeTuple.get0), indices = new Int64Tensor(nativeTuple.get1))
 
@@ -230,7 +238,7 @@ def max[D <: DType](input: Tensor[D], dim: Long, keepdim: Boolean = false): Tens
  *
  * @group reduction_ops
  */
-def min(input: Tensor[?]): Tensor[Int64] = Tensor(input.native.min())
+def min[D <: RealNN](input: Tensor[D]): Tensor[Int64] = Tensor(input.native.min())
 
 /** Returns a [[TensorTuple]] `(values, indices)` where `values` is the minimum value of each row of
  *   the `input` tensor in the given dimension `dim`. And `indices` is the index location of each
@@ -249,7 +257,7 @@ def min(input: Tensor[?]): Tensor[Int64] = Tensor(input.native.min())
  * @param keepdim
  *   whether the output tensor has `dim` retained or not.
  */
-def min[D <: DType](input: Tensor[D], dim: Long, keepdim: Boolean = false): TensorTuple[D] =
+def min[D <: RealNN](input: Tensor[D], dim: Long, keepdim: Boolean = false): TensorTuple[D] =
   val nativeTuple = torchNative.min(input.native, dim, keepdim)
   TensorTuple(values = Tensor[D](nativeTuple.get0), indices = new Int64Tensor(nativeTuple.get1))
 
@@ -263,8 +271,11 @@ def min[D <: DType](input: Tensor[D], dim: Long, keepdim: Boolean = false): Tens
  *   the norm to be computed
  */
 // TODO dtype promotion floatNN/complexNN => highest floatNN
-// TODO (using AtLeastOneFloat[D, D2]
-def dist[D <: DType, D2 <: DType](input: Tensor[D], other: Tensor[D2], p: Float = 2): Tensor[D] =
+def dist[D <: NumericNN, D2 <: NumericNN](
+    input: Tensor[D],
+    other: Tensor[D2],
+    p: Float = 2
+)(using AtLeastOneFloat[D, D2]): Tensor[D] =
   Tensor(torchNative.dist(input.native, other.native, toScalar(p)))
 
 /** Returns the log of summed exponentials of each row of the `input` tensor in the given dimension
@@ -283,7 +294,7 @@
  * @param keepdim
  *   whether the output tensor has `dim` retained or not.
  */
-def logsumexp[D <: DType](
+def logsumexp[D <: RealNN](
     input: Tensor[D],
     dim: Int | Seq[Int] = Seq.empty,
     keepdim: Boolean = false
@@ -295,7 +306,7 @@
  *
  * @group reduction_ops
  */
-def mean[D <: DType](
+def mean[D <: FloatNN | ComplexNN](
     input: Tensor[D]
 ): Tensor[D] = Tensor(torchNative.mean(input.native))
 
@@ -306,7 +317,7 @@
  * @param dtype
  *   $reduceops_dtype
  */
-def mean[D <: DType](
+def mean[D <: FloatNN | ComplexNN](
     input: Tensor[?],
     dtype: D
 ): Tensor[D] = Tensor(torchNative.mean(input.native, new ScalarTypeOptional(dtype.toScalarType)))
@@ -362,7 +373,7 @@ def mean[D <: DType, D2 <: DType | Derive](
  * @param dtype
  *   $reduceops_dtype
  */
-def nanmean[D <: DType, D2 <: DType | Derive](
+def nanmean[D <: FloatNN, D2 <: DType | Derive](
     input: Tensor[D],
     dim: Int | Seq[Int] = Seq.empty,
     keepdim: Boolean = false,
@@ -394,7 +405,7 @@ def nanmean[D <: DType, D2 <: DType | Derive](
  *
  * @group reduction_ops
  */
-def median[D <: DType](
+def median[D <: NumericRealNN](
     input: Tensor[D]
 ): Tensor[D] = Tensor(torchNative.median(input.native))
 
@@ -427,7 +438,7 @@
  * @param dtype
  *   $reduceops_dtype
  */
-def median[D <: DType, D2 <: DType | Derive](
+def median[D <: NumericRealNN, D2 <: DType | Derive](
     input: Tensor[D],
     dim: Long = -1,
     keepdim: Boolean = false
@@ -444,7 +455,7 @@
  *
  * @group reduction_ops
  */
-def nanmedian[D <: DType](
+def nanmedian[D <: NumericRealNN](
     input: Tensor[D]
 ): Tensor[D] = Tensor(torchNative.nanmedian(input.native))
 
@@ -466,7 +477,7 @@ def nanmedian[D <: DType](
  * @param dtype
  *   $reduceops_dtype
  */
-def nanmedian[D <: DType, D2 <: DType | Derive](
+def nanmedian[D <: NumericRealNN, D2 <: DType | Derive](
     input: Tensor[D],
     dim: Long = -1,
     keepdim: Boolean = false
@@ -493,7 +504,7 @@ def nanmedian[D <: DType, D2 <: DType | Derive](
  * @param keepdim
  *   whether the output tensor has `dim` retained or not.
  */
-def mode[D <: DType](input: Tensor[D], dim: Long, keepdim: Boolean = false): TensorTuple[D] =
+def mode[D <: RealNN](input: Tensor[D], dim: Long = -1, keepdim: Boolean = false): TensorTuple[D] =
   val nativeTuple = torchNative.mode(input.native, dim, keepdim)
   TensorTuple(values = Tensor[D](nativeTuple.get0), indices = new Int64Tensor(nativeTuple.get1))
 
@@ -511,7 +522,7 @@
  * @param dtype
  *   $reduceops_dtype
  */
-def nansum[D <: DType, D2 <: DType | Derive](
+def nansum[D <: RealNN, D2 <: DType | Derive](
     input: Tensor[D],
     dim: Int | Seq[Int] = Seq.empty,
     keepdim: Boolean = false,
@@ -658,7 +669,7 @@ def prod[D <: DType, D2 <: DType | Derive](
  *   difference between the sample size and sample degrees of freedom. Defaults to [Bessel\'s
  *   correction](https://en.wikipedia.org/wiki/Bessel%27s_correction), `correction=1`.
  */
-def std[D <: DType](
+def std[D <: FloatNN | ComplexNN](
     input: Tensor[D],
     dim: Int | Seq[Int] = Seq.empty,
     keepdim: Boolean = false,
@@ -697,7 +708,7 @@
  * @return
  *   A tuple (std, mean) containing the standard deviation and mean.
  */
-def std_mean[D <: DType](
+def std_mean[D <: FloatNN | ComplexNN](
     input: Tensor[D],
     dim: Int | Seq[Int] = Seq.empty,
     keepdim: Boolean = false,
@@ -801,7 +812,7 @@ def sum[D <: DType, D2 <: DType | Derive](
  *   difference between the sample size and sample degrees of freedom. Defaults to [Bessel\'s
  *   correction](https://en.wikipedia.org/wiki/Bessel%27s_correction), `correction=1`.
  */
-def variance[D <: DType](
+def variance[D <: FloatNN | ComplexNN](
     input: Tensor[D],
     dim: Int | Seq[Int] = Seq.empty,
     keepdim: Boolean = false,
@@ -840,7 +851,7 @@
  * @return
  *   A tuple (var, mean) containing the variance and mean.
  */
-def var_mean[D <: DType](
+def var_mean[D <: FloatNN | ComplexNN](
     input: Tensor[D],
     dim: Int | Seq[Int] = Seq.empty,
     keepdim: Boolean = false,
diff --git a/core/src/test/scala/torch/Generators.scala b/core/src/test/scala/torch/Generators.scala
index 170c240b..d4f55638 100644
--- a/core/src/test/scala/torch/Generators.scala
+++ b/core/src/test/scala/torch/Generators.scala
@@ -56,7 +56,7 @@ object Generators:
 
   inline def genTensor[D <: DType]: Gen[Tensor[D]] =
     Gen.oneOf(allDTypes.filter(_.isInstanceOf[D])).map { dtype =>
-      ones(10, dtype = dtype.asInstanceOf[D])
+      ones(Seq(4, 4), dtype = dtype.asInstanceOf[D])
     }
 
   val genDType = Gen.oneOf(allDTypes)
diff --git a/core/src/test/scala/torch/ops/ReductionOpsSuite.scala b/core/src/test/scala/torch/ops/ReductionOpsSuite.scala
index dbda5d48..1bc1c720 100644
--- a/core/src/test/scala/torch/ops/ReductionOpsSuite.scala
+++ b/core/src/test/scala/torch/ops/ReductionOpsSuite.scala
@@ -18,11 +18,232 @@ package torch
 
 class ReductionOpsSuite extends TensorCheckSuite {
 
+  // TODO test with dim/keepdim variants for corresponding ops
+
+  testUnaryOp(
+    op = argmax(_),
+    opName = "argmax",
+    inputTensor =
+      Tensor(Seq(1.0, 0.5, 1.2, -2)), // TODO check why we can't call tensor ops on inputTensor here
+    expectedTensor = Tensor(2L)
+  )
+
+  testUnaryOp(
+    op = argmin(_),
+    opName = "argmin",
+    inputTensor = Tensor(Seq(1.0, 0.5, 1.2, -2)),
+    expectedTensor = Tensor(3L)
+  )
+
+  propertyTestUnaryOp(amax(_, 0), "amax")
+  // TODO unit test amax
+
+  propertyTestUnaryOp(amin(_, 0), "amin")
+  // TODO unit test amin
+
+  propertyTestUnaryOp(aminmax(_), "aminmax")
+  // TODO unit test aminmax
+
+  testUnaryOp(
+    op = all,
+    opName = "all",
+    inputTensor = Tensor(Seq(true, true, false, true)),
+    expectedTensor = Tensor(false)
+  )
+
+  test("all") {
+    assertEquals(
+      all(Tensor(Seq(true, true, false, true))),
+      Tensor(false)
+    )
+    assertEquals(
+      all(Tensor(Seq(true, true, true, true))),
+      Tensor(true)
+    )
+  }
+
+  testUnaryOp(
+    op = any,
+    opName = "any",
+    inputTensor = Tensor(Seq(true, true, false, true)),
+    expectedTensor = Tensor(true)
+  )
+
+  test("any") {
+    assertEquals(
+      any(Tensor(Seq(true, true, false, true))),
+      Tensor(true)
+    )
+    assertEquals(
+      any(Tensor(Seq(false, false, false, false))),
+      Tensor(false)
+    )
+  }
+
+  testUnaryOp(
+    op = max,
+    opName = "max",
+    inputTensor = Tensor(Seq(1.0, 0.5, 1.2, -2)),
+    expectedTensor = Tensor(1.2)
+  )
+
+  testUnaryOp(
+    op = min,
+    opName = "min",
+    inputTensor = Tensor(Seq(1.0, 0.5, 1.2, -2)),
+    expectedTensor = Tensor(-2.0)
+  )
+
+  // TODO Enable property test once we figure out how to compile properly with AtLeastOneFloatOrComplex
+  // propertyTestBinaryOp(dist35, "dist")
+
+  def unitTestDist(p: Float, expected: Float) = unitTestBinaryOp[Float32, Float, Float32, Float](
+    dist(_, _, p),
+    "dist",
+    inputTensors = (
+      Tensor(Seq[Float](-1.5393, -0.8675, 0.5916, 1.6321)),
+      Tensor(Seq[Float](0.0967, -1.0511, 0.6295, 0.8360))
+    ),
+    expectedTensor = Tensor(expected)
+  )
+
+  unitTestDist(3.5, 1.6727)
+  unitTestDist(3, 1.6973)
+  unitTestDist(0, 4)
+  unitTestDist(1, 2.6537)
+
+  propertyTestUnaryOp(logsumexp(_, dim = 0), "logsumexp")
+  // TODO unit test logsumexp
+
+  testUnaryOp(
+    op = mean,
+    opName = "mean",
+    inputTensor = Tensor(Seq(0.2294, -0.5481, 1.3288)),
+    expectedTensor = Tensor(0.3367)
+  )
+
+  test("mean with nan") {
+    assert(mean(Tensor(Seq(Float.NaN, 1, 2))).isnan.item)
+  }
+
   testUnaryOp(
-    op = sum(_),
+    op = nanmean(_),
+    opName = "nanmean",
+    inputTensor = Tensor(Seq(Float.NaN, 1, 2, 1, 2, 3)),
+    expectedTensor = Tensor(1.8f)
+  )
+
+  test("nanmean with nan") {
+    val t = Tensor(Seq(Float.NaN, 1, 2))
+    assert(!nanmean(t).isnan.item)
+  }
+
+  testUnaryOp(
+    op = median,
+    opName = "median",
+    inputTensor = Tensor(Seq(1, 5, 2, 3, 4)),
+    expectedTensor = Tensor(3)
+  )
+
+  testUnaryOp(
+    op = nanmedian,
+    opName = "nanmedian",
+    inputTensor = Tensor(Seq(1, 5, Float.NaN, 3, 4)),
+    expectedTensor = Tensor(3f)
+  )
+
+  propertyTestUnaryOp(mode(_), "mode")
+
+  test("mode") {
+    torch.manualSeed(0)
+    assertEquals(
+      torch.mode(Tensor[Int](Seq(6, 5, 1, 0, 2)), 0),
+      TensorTuple(Tensor(0), Tensor[Long](3L))
+    )
+  }
+
+  // test("mode") {
+  //   torch.manualSeed(0)
+  //   val a = Tensor(Seq(6, 5, 1, 0, 2))
+  //   val b = a + Tensor(Seq(-3, -11, -6, -7, 4)).reshape(5, 1)
+  //   val values = Tensor(Seq(-5, -5, -5, -4, -8))
+  //   val indices = Tensor(Seq(1, 1, 1, 1, 1)).long
+  //   assertEquals(
+  //     torch.mode(b, 0),
+  //     TensorTuple(values, indices)
+  //   )
+  // }
+
+  testUnaryOp(
+    op = nansum(_),
+    opName = "nansum",
+    inputTensor = Tensor(Seq(1, 5, Float.NaN, 3, 4)),
+    expectedTensor = Tensor(13f)
+  )
+
+  testUnaryOp(
+    op = prod,
+    opName = "prod",
+    inputTensor = Tensor(Seq(5.0, 5.0)),
+    expectedTensor = Tensor(25.0)
+  )
+
+  // TODO quantile
+  // TODO nanquantile
+
+  propertyTestUnaryOp(std(_), "std")
+
+  unitTestUnaryOp[Float32, Float](
+    op = std(_, dim = 1, keepdim = true),
+    opName = "std",
+    inputTensor = Tensor(
+      Seq(
+        Seq[Float](0.2035, 1.2959, 1.8101, -0.4644),
+        Seq[Float](1.5027, -0.3270, 0.5905, 0.6538),
+        Seq[Float](-1.5745, 1.3330, -0.5596, -0.6548),
+        Seq[Float](0.1264, -0.5080, 1.6420, 0.1992)
+      ).flatten
+    ).reshape(4, 4),
+    expectedTensor = Tensor(Seq[Float](1.0311, 0.7477, 1.2204, 0.9087)).reshape(4, 1)
+  )
+
+  propertyTestUnaryOp(std_mean(_), "std_mean")
+  // TODO unit test std_mean
+
+  testUnaryOp(
+    op = sum,
     opName = "sum",
     inputTensor = Tensor(Seq(5.0, 5.0)),
     expectedTensor = Tensor(10.0)
   )
+
+  // TODO unique
+  // TODO unique_consecutive
+
+  propertyTestUnaryOp(variance(_), "variance")
+
+  unitTestUnaryOp[Float32, Float](
+    op = variance(_, dim = 1, keepdim = true),
+    opName = "variance",
+    inputTensor = Tensor(
+      Seq(
+        Seq[Float](0.2035, 1.2959, 1.8101, -0.4644),
+        Seq[Float](1.5027, -0.3270, 0.5905, 0.6538),
+        Seq[Float](-1.5745, 1.3330, -0.5596, -0.6548),
+        Seq[Float](0.1264, -0.5080, 1.6420, 0.1992)
+      ).flatten
+    ).reshape(4, 4),
+    expectedTensor = Tensor(Seq[Float](1.0631, 0.5590, 1.4893, 0.8258)).reshape(4, 1)
+  )
+
+  propertyTestUnaryOp(var_mean(_), "var_mean")
+  // TODO unit test var_mean
+
+  testUnaryOp(
+    op = count_nonzero(_),
+    opName = "count_nonzero",
+    inputTensor = Tensor(Seq(1, 0, 0, 1, 0)),
+    expectedTensor = Tensor(2L)
+  )
+
 }
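
Reviewer note (not part of the patch): taken together, the hunks above tighten the dtype bounds on the reduction ops and add a `long` conversion to `Tensor`. A minimal sketch of how the new signatures behave from user code, assuming the top-level `torch` exports shown in these files are imported via `import torch.*`; the tensor values are made up for illustration:

    import torch.*

    val t = Tensor(Seq(1.0f, 0.5f, 1.2f, -2.0f)) // Tensor[Float32]

    // argmax/argmin are now bounded by IntNN | FloatNN; indices come back as Tensor[Int64]
    val idx = argmax(t) // here: the index of 1.2f

    // the new Tensor.long is shorthand for to(dtype = int64)
    val asInt64 = t.long // Tensor[Int64]

    // dist now takes NumericNN operands plus AtLeastOneFloat evidence
    val d = dist(t, Tensor(Seq(0.1f, -1.0f, 0.6f, 0.8f)), p = 2)

    // These no longer type-check under the new bounds:
    // argmax(Tensor(Seq(true, false)))            // Bool is outside IntNN | FloatNN
    // dist(Tensor(Seq(1, 2)), Tensor(Seq(3, 4)))  // neither operand is a FloatNN

The AtLeastOneFloatOrComplex change in Types.scala matters for the same reason: the old alias duplicated the FloatNN check on both sides, so complex operands could never satisfy the evidence.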