Skip to content

Commit

Permalink
Fix avg-pooling and add dtype tests for it
Browse files Browse the repository at this point in the history
  • Loading branch information
sbrunk committed Jun 27, 2023
1 parent c36f0a2 commit aa6d1f5
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 11 deletions.
12 changes: 6 additions & 6 deletions core/src/main/scala/torch/nn/functional/Pooling.scala
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,8 @@ private[torch] trait Pooling {
): Tensor[D] =
val options = AvgPool2dOptions(toNative(kernelSize))
stride match
case s: Int => options.stride().put(toNative(s))
case None =>
case s: (Int | (Int, Int)) => options.stride().put(toNative(s))
case None =>
options.padding().put(toNative(padding))
options.ceil_mode().put(ceilMode)
options.count_include_pad().put(countIncludePad)
Expand All @@ -110,7 +110,7 @@ private[torch] trait Pooling {
*
* @group nn_pooling
*/
def avgPool3d[D <: FloatNN | Complex32](
def avgPool3d[D <: Float16 | Float32 | Float64 | Complex32](
input: Tensor[D],
kernelSize: Int | (Int, Int, Int),
stride: Int | (Int, Int, Int) | None.type = None,
Expand All @@ -121,8 +121,8 @@ private[torch] trait Pooling {
): Tensor[D] =
val options = AvgPool3dOptions(toNative(kernelSize))
stride match
case s: Int => options.stride().put(toNative(s))
case None =>
case s: (Int | (Int, Int, Int)) => options.stride().put(toNative(s))
case None =>
options.padding().put(toNative(padding))
options.ceil_mode().put(ceilMode)
options.count_include_pad().put(countIncludePad)
Expand Down Expand Up @@ -300,7 +300,7 @@ private[torch] trait Pooling {
val options: MaxPool3dOptions = MaxPool3dOptions(toNative(kernelSize))
stride match
case s: (Int | (Int, Int, Int)) => options.stride().put(toNative(s))
case None => // options.stride().put(toNative(kernelSize))
case None =>
options.padding().put(toNative(padding))
options.dilation().put(toNative(dilation))
options.ceil_mode().put(ceilMode)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import functional as F
import org.scalacheck.Gen
import torch.Generators.allDTypes

class MaxPoolSuite extends TensorCheckSuite {
class PoolingSuite extends TensorCheckSuite {
test("MaxPool2d output shapes") {
val input = torch.randn(Seq(1, 3, 244, 244))
// pool of square window of size=3, stride=2
Expand All @@ -35,6 +35,7 @@ class MaxPoolSuite extends TensorCheckSuite {
}

val shape3d = Seq(16, 50, 32)
propertyTestUnaryOp(F.avgPool1d(_, 3), "avgPool1d", genRandTensor(shape3d))
propertyTestUnaryOp(F.maxPool1d(_, 3), "maxPool1d", genRandTensor(shape3d))
propertyTestUnaryOp(F.maxPool1dWithIndices(_, 3), "maxPool1dWithIndices", genRandTensor(shape3d))

Expand All @@ -43,19 +44,25 @@ class MaxPoolSuite extends TensorCheckSuite {
torch.rand(shape, dtype = dtype.asInstanceOf[D])
}

val shape4d = Seq(20, 16, 50, 32)
val shape4d = Seq(8, 16, 50, 32)
propertyTestUnaryOp(F.avgPool2d(_, 3), "avgPool2d", genRandTensor(shape4d))
propertyTestUnaryOp(F.maxPool2d(_, 3), "maxPool2d", genRandTensor(shape4d))
propertyTestUnaryOp(F.maxPool2dWithIndices(_, 3), "maxPool2dWithIndices", genRandTensor(shape4d))

val shape5d = Seq(20, 16, 50, 44, 31)
val shape5d = Seq(2, 16, 50, 44, 31)
propertyTestUnaryOp(
F.avgPool3d(_, (3, 2, 2), stride = (2, 1, 2)),
"avgPool3d",
genRandTensor(shape5d)
)
propertyTestUnaryOp(
F.maxPool3d(_, (3, 2, 2), stride = (2, 1, 2)),
"maxPool3d",
genRandTensor(shape4d)
genRandTensor(shape5d)
)
propertyTestUnaryOp(
F.maxPool3dWithIndices(_, (3, 2, 2), stride = (2, 1, 2)),
"maxPool3dWithIndices",
genRandTensor(shape4d)
genRandTensor(shape5d)
)
}

0 comments on commit aa6d1f5

Please sign in to comment.