diff --git a/core/src/main/scala/torch/DType.scala b/core/src/main/scala/torch/DType.scala
index f32e3773..18b33c6d 100644
--- a/core/src/main/scala/torch/DType.scala
+++ b/core/src/main/scala/torch/DType.scala
@@ -390,8 +390,10 @@ type NumericPromoted[D <: DType] <: DType = D match
 
 /** Promoted type for tensor operations that always output floats (e.g. `sin`) */
 type FloatPromoted[D <: DType] <: FloatNN = D match
-  case Float64 => Float64
-  case _       => Float32
+  case Float16  => Float16
+  case BFloat16 => BFloat16
+  case Float64  => Float64
+  case _        => Float32
 
 /** Demoted type for complex to real type extractions (e.g. `imag`, `real`) */
 type ComplexToReal[D <: DType] <: DType = D match
diff --git a/core/src/main/scala/torch/ops/ReductionOps.scala b/core/src/main/scala/torch/ops/ReductionOps.scala
index fba5af41..476ce085 100644
--- a/core/src/main/scala/torch/ops/ReductionOps.scala
+++ b/core/src/main/scala/torch/ops/ReductionOps.scala
@@ -270,12 +270,13 @@ def min[D <: RealNN](input: Tensor[D], dim: Long, keepdim: Boolean = false): Ten
   * @param p
   *   the norm to be computed
   */
-// TODO dtype promotion floatNN/complexNN => highest floatNN
 def dist[D <: NumericNN, D2 <: NumericNN](
     input: Tensor[D],
     other: Tensor[D2],
     p: Float = 2
-)(using AtLeastOneFloat[D, D2]): Tensor[D] =
+)(using
+    AtLeastOneFloat[D, D2]
+): Tensor[Promoted[FloatPromoted[ComplexToReal[D]], FloatPromoted[ComplexToReal[D2]]]] =
   Tensor(torchNative.dist(input.native, other.native, toScalar(p)))
 
 /** Returns the log of summed exponentials of each row of the `input` tensor in the given dimension
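
Sketch of the observable effect (not part of the diff): a minimal, hypothetical example assuming storch's public API. The `summon` lines are compile-time checks of the match-type reductions added above; `torch.rand` and its `dtype` parameter are used purely for illustration.

import torch.*

// Compile-time checks: FloatPromoted now preserves half-precision dtypes
// instead of widening them to Float32, per the new match cases.
summon[FloatPromoted[Float16] =:= Float16]
summon[FloatPromoted[BFloat16] =:= BFloat16]
summon[FloatPromoted[Float64] =:= Float64]
summon[FloatPromoted[Int32] =:= Float32]

// Value-level effect on dist: a float32/float64 pair should now be typed as
// the promoted Tensor[Float64] rather than the first operand's Tensor[Float32].
val x = torch.rand(Seq(3))                  // Tensor[Float32] (assumed default dtype)
val y = torch.rand(Seq(3), dtype = float64) // Tensor[Float64]
val d = torch.dist(x, y)                    // Tensor[Float64] under the new return type

This is what allows the TODO above `dist` to be dropped: the return type now computes the highest floating-point dtype from the complex-stripped input dtypes, which is the promotion the removed comment (`floatNN/complexNN => highest floatNN`) asked for.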