From dff148b011df81b37417359c667b470b938e6ff6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=B6ren=20Brunk?= Date: Tue, 27 Jun 2023 09:03:53 +0200 Subject: [PATCH] Fix avg-pooling and add dtype tests for it --- .../scala/torch/nn/functional/Pooling.scala | 12 ++++++------ .../{MaxPoolSuite.scala => PoolingSuite.scala} | 17 ++++++++++++----- 2 files changed, 18 insertions(+), 11 deletions(-) rename core/src/test/scala/torch/nn/{MaxPoolSuite.scala => PoolingSuite.scala} (81%) diff --git a/core/src/main/scala/torch/nn/functional/Pooling.scala b/core/src/main/scala/torch/nn/functional/Pooling.scala index d38441f0..6d362449 100644 --- a/core/src/main/scala/torch/nn/functional/Pooling.scala +++ b/core/src/main/scala/torch/nn/functional/Pooling.scala @@ -96,8 +96,8 @@ private[torch] trait Pooling { ): Tensor[D] = val options = AvgPool2dOptions(toNative(kernelSize)) stride match - case s: Int => options.stride().put(toNative(s)) - case None => + case s: (Int | (Int, Int)) => options.stride().put(toNative(s)) + case None => options.padding().put(toNative(padding)) options.ceil_mode().put(ceilMode) options.count_include_pad().put(countIncludePad) @@ -110,7 +110,7 @@ private[torch] trait Pooling { * * @group nn_pooling */ - def avgPool3d[D <: FloatNN | Complex32]( + def avgPool3d[D <: Float16 | Float32 | Float64 | Complex32]( input: Tensor[D], kernelSize: Int | (Int, Int, Int), stride: Int | (Int, Int, Int) | None.type = None, @@ -121,8 +121,8 @@ private[torch] trait Pooling { ): Tensor[D] = val options = AvgPool3dOptions(toNative(kernelSize)) stride match - case s: Int => options.stride().put(toNative(s)) - case None => + case s: (Int | (Int, Int, Int)) => options.stride().put(toNative(s)) + case None => options.padding().put(toNative(padding)) options.ceil_mode().put(ceilMode) options.count_include_pad().put(countIncludePad) @@ -300,7 +300,7 @@ private[torch] trait Pooling { val options: MaxPool3dOptions = MaxPool3dOptions(toNative(kernelSize)) stride match case s: (Int | (Int, Int, Int)) => options.stride().put(toNative(s)) - case None => // options.stride().put(toNative(kernelSize)) + case None => options.padding().put(toNative(padding)) options.dilation().put(toNative(dilation)) options.ceil_mode().put(ceilMode) diff --git a/core/src/test/scala/torch/nn/MaxPoolSuite.scala b/core/src/test/scala/torch/nn/PoolingSuite.scala similarity index 81% rename from core/src/test/scala/torch/nn/MaxPoolSuite.scala rename to core/src/test/scala/torch/nn/PoolingSuite.scala index 831ee126..a15ddc6c 100644 --- a/core/src/test/scala/torch/nn/MaxPoolSuite.scala +++ b/core/src/test/scala/torch/nn/PoolingSuite.scala @@ -21,7 +21,7 @@ import functional as F import org.scalacheck.Gen import torch.Generators.allDTypes -class MaxPoolSuite extends TensorCheckSuite { +class PoolingSuite extends TensorCheckSuite { test("MaxPool2d output shapes") { val input = torch.randn(Seq(1, 3, 244, 244)) // pool of square window of size=3, stride=2 @@ -35,6 +35,7 @@ class MaxPoolSuite extends TensorCheckSuite { } val shape3d = Seq(16, 50, 32) + propertyTestUnaryOp(F.avgPool1d(_, 3), "avgPool1d", genRandTensor(shape3d)) propertyTestUnaryOp(F.maxPool1d(_, 3), "maxPool1d", genRandTensor(shape3d)) propertyTestUnaryOp(F.maxPool1dWithIndices(_, 3), "maxPool1dWithIndices", genRandTensor(shape3d)) @@ -43,19 +44,25 @@ class MaxPoolSuite extends TensorCheckSuite { torch.rand(shape, dtype = dtype.asInstanceOf[D]) } - val shape4d = Seq(20, 16, 50, 32) + val shape4d = Seq(8, 16, 50, 32) + propertyTestUnaryOp(F.avgPool2d(_, 3), "avgPool2d", genRandTensor(shape4d)) propertyTestUnaryOp(F.maxPool2d(_, 3), "maxPool2d", genRandTensor(shape4d)) propertyTestUnaryOp(F.maxPool2dWithIndices(_, 3), "maxPool2dWithIndices", genRandTensor(shape4d)) - val shape5d = Seq(20, 16, 50, 44, 31) + val shape5d = Seq(2, 16, 50, 44, 31) + propertyTestUnaryOp( + F.avgPool3d(_, (3, 2, 2), stride = (2, 1, 2)), + "avgPool3d", + genRandTensor(shape5d) + ) propertyTestUnaryOp( F.maxPool3d(_, (3, 2, 2), stride = (2, 1, 2)), "maxPool3d", - genRandTensor(shape4d) + genRandTensor(shape5d) ) propertyTestUnaryOp( F.maxPool3dWithIndices(_, (3, 2, 2), stride = (2, 1, 2)), "maxPool3dWithIndices", - genRandTensor(shape4d) + genRandTensor(shape5d) ) }