Skip to content

Commit

Permalink
Remove redundant overloads of torch.mean to avoid compiler issues.
Browse files Browse the repository at this point in the history
  • Loading branch information
sbrunk committed Aug 19, 2023
1 parent 4a1a138 commit 2e770cd
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 23 deletions.
22 changes: 1 addition & 21 deletions core/src/main/scala/torch/ops/ReductionOps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -302,26 +302,6 @@ private[torch] trait ReductionOps {
torchNative.logsumexp(input.native, dim.toArray, keepdim)
)

/** Returns the mean value of all elements in the `input` tensor.
*
* @group reduction_ops
*/
def mean[D <: FloatNN | ComplexNN](
input: Tensor[D]
): Tensor[D] = Tensor(torchNative.mean(input.native))

/** Returns the mean value of all elements in the `input` tensor.
*
* @group reduction_ops
*
* @param dtype
* ${reduceops_dtype}
*/
def mean[D <: FloatNN | ComplexNN](
input: Tensor[?],
dtype: D
): Tensor[D] = Tensor(torchNative.mean(input.native, new ScalarTypeOptional(dtype.toScalarType)))

/** Returns the mean value of each row of the `input` tensor in the given dimension `dim`. If
* `dim` is a list of dimensions, reduce over all of them.
*
Expand All @@ -336,7 +316,7 @@ private[torch] trait ReductionOps {
* @param dtype
* ${reduceops_dtype}
*/
def mean[D <: DType, D2 <: DType | Derive](
def mean[D <: FloatNN | ComplexNN, D2 <: FloatNN | ComplexNN | Derive](
input: Tensor[D],
dim: Int | Seq[Int] = Seq.empty,
keepdim: Boolean = false,
Expand Down
4 changes: 2 additions & 2 deletions core/src/test/scala/torch/ops/ReductionOpsSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -116,14 +116,14 @@ class ReductionOpsSuite extends TensorCheckSuite {
// TODO unit test logsumexp

testUnaryOp(
op = mean,
op = mean(_),
opName = "mean",
inputTensor = Tensor(Seq(0.2294, -0.5481, 1.3288)),
expectedTensor = Tensor(0.3367)
)

test("mean with nan") {
assert(mean(Tensor(Seq(Float.NaN, 1, 2))).isnan.item)
assert(mean(Tensor[Float](Seq(Float.NaN, 1, 2))).isnan.item)
}

testUnaryOp(
Expand Down

0 comments on commit 2e770cd

Please sign in to comment.