Fix dist promoted type and add missing FloatPromoted cases
sbrunk committed Jun 20, 2023
1 parent 328d7ce commit fc94bcc
Showing 2 changed files with 7 additions and 4 deletions.
core/src/main/scala/torch/DType.scala (6 changes: 4 additions & 2 deletions)
@@ -390,8 +390,10 @@ type NumericPromoted[D <: DType] <: DType = D match
 
 /** Promoted type for tensor operations that always output floats (e.g. `sin`) */
 type FloatPromoted[D <: DType] <: FloatNN = D match
-  case Float64 => Float64
-  case _       => Float32
+  case Float16  => Float16
+  case BFloat16 => BFloat16
+  case Float64  => Float64
+  case _        => Float32
 
 /** Demoted type for complex to real type extractions (e.g. `imag`, `real`) */
 type ComplexToReal[D <: DType] <: DType = D match
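The effect of the new cases: half-precision inputs now keep their width instead of being widened to Float32. A minimal compile-time sketch (assuming the dtype aliases from this file, such as Float16 and Int32, are in scope via import torch.*; the summon checks are purely illustrative):

import torch.*

// Reduced-precision float dtypes are now preserved by FloatPromoted:
summon[FloatPromoted[Float16] =:= Float16]
summon[FloatPromoted[BFloat16] =:= BFloat16]
// Float64 is still kept, and non-float dtypes still promote to Float32:
summon[FloatPromoted[Float64] =:= Float64]
summon[FloatPromoted[Int32] =:= Float32]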
core/src/main/scala/torch/ops/ReductionOps.scala (5 changes: 3 additions & 2 deletions)
@@ -270,12 +270,13 @@ def min[D <: RealNN](input: Tensor[D], dim: Long, keepdim: Boolean = false): Ten
   * @param p
   *   the norm to be computed
   */
-// TODO dtype promotion floatNN/complexNN => highest floatNN
 def dist[D <: NumericNN, D2 <: NumericNN](
     input: Tensor[D],
     other: Tensor[D2],
     p: Float = 2
-)(using AtLeastOneFloat[D, D2]): Tensor[D] =
+)(using
+    AtLeastOneFloat[D, D2]
+): Tensor[Promoted[FloatPromoted[ComplexToReal[D]], FloatPromoted[ComplexToReal[D2]]]] =
   Tensor(torchNative.dist(input.native, other.native, toScalar(p)))
 
 /** Returns the log of summed exponentials of each row of the `input` tensor in the given dimension
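With the new return type, the result dtype of dist is computed from both operands rather than echoing the first operand's dtype D. A rough sketch of what the type reduces to for mixed-precision inputs (assuming Promoted, FloatPromoted, and ComplexToReal reduce as defined in DType.scala; the summon check is illustrative only):

import torch.*

// A call such as torch.dist(a32, b64) on a Float32 and a Float64 tensor is now
// typed Tensor[Float64] (previously Tensor[Float32], the first operand's dtype):
summon[Promoted[FloatPromoted[ComplexToReal[Float32]], FloatPromoted[ComplexToReal[Float64]]] =:= Float64]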
