Skip to content

Commit

Permalink
Avoid stack overflows with non-standard float types; closes JuliaMath#76
Browse files Browse the repository at this point in the history
  • Loading branch information
moble committed May 30, 2024
1 parent 68e9db0 commit 00b0efa
Show file tree
Hide file tree
Showing 3 changed files with 109 additions and 7 deletions.
3 changes: 2 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ OpenLibm_jll = "05823500-19ac-5b8b-9628-191a04bc5112"
julia = "1.6"

[extras]
DoubleFloats = "497a8b3b-efae-58df-a0af-a86822472b78"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test"]
test = ["DoubleFloats", "Test"]
46 changes: 43 additions & 3 deletions src/NaNMath.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,27 +3,67 @@ module NaNMath
using OpenLibm_jll
const libm = OpenLibm_jll.libopenlibm


for f in (:sin, :cos, :tan, :asin, :acos, :acosh, :atanh, :log, :log2, :log10,
:lgamma, :log1p)
@eval begin
($f)(x::Float64) = ccall(($(string(f)),libm), Float64, (Float64,), x)
($f)(x::Float32) = ccall(($(string(f,"f")),libm), Float32, (Float32,), x)
($f)(x::Real) = ($f)(float(x))
($f)(x::Float16) = Float16(($f)(Float32(x)))
function ($f)(x::Real)
xf = float(x)
x === xf && throw(MethodError($f, (x,)))
return ($f)(xf)
end
end
end
sin(x::T) where {T<:AbstractFloat} = isfinite(x) ? Base.sin(x) : T(NaN)
cos(x::T) where {T<:AbstractFloat} = isfinite(x) ? Base.cos(x) : T(NaN)
tan(x::T) where {T<:AbstractFloat} = isfinite(x) ? Base.tan(x) : T(NaN)
asin(x::T) where {T<:AbstractFloat} = abs(x) > 1 ? T(NaN) : Base.asin(x)
acos(x::T) where {T<:AbstractFloat} = abs(x) > 1 ? T(NaN) : Base.acos(x)
acosh(x::T) where {T<:AbstractFloat} = x < 1 ? T(NaN) : Base.acosh(x)
atanh(x::T) where {T<:AbstractFloat} = abs(x) > 1 ? T(NaN) : Base.atanh(x)
log(x::T) where {T<:AbstractFloat} = x < 0 ? T(NaN) : Base.log(x)
log2(x::T) where {T<:AbstractFloat} = x < 0 ? T(NaN) : Base.log2(x)
log10(x::T) where {T<:AbstractFloat} = x < 0 ? T(NaN) : Base.log10(x)
# lgamma does not have a Base version; the MethodError above will suffice
log1p(x::T) where {T<:AbstractFloat} = x < -1 ? T(NaN) : Base.log1p(x)


# Would be more efficient to remove the domain check in Base.sqrt(),
# but this doesn't seem easy to do.
sqrt(x::T) where {T<:AbstractFloat} = x < 0.0 ? T(NaN) : Base.sqrt(x)
sqrt(x::Real) = sqrt(float(x))
function sqrt(x::Real)
xf = float(x)
x === xf && throw(MethodError(sqrt, (x,)))
return sqrt(xf)
end

# Don't override built-in ^ operator
pow(x::Float64, y::Float64) = ccall((:pow,libm), Float64, (Float64,Float64), x, y)
pow(x::Float32, y::Float32) = ccall((:powf,libm), Float32, (Float32,Float32), x, y)
pow(x::Float16, y::Float16) = Float16(pow(Float32(x), Float32(y)))
# We `promote` first before converting to floating pointing numbers to ensure that
# e.g. `pow(::Float32, ::Int)` ends up calling `pow(::Float32, ::Float32)`
pow(x::Number, y::Number) = pow(promote(x, y)...)
pow(x::T, y::T) where {T<:Number} = pow(float(x), float(y))
function pow(x::T, y::T) where {T<:Number}
xf = float(x)
yf = float(y)
x === xf && y === yf && throw(MethodError(pow, (x,y)))
return pow(xf, yf)
end
function pow(x::T, y::T) where {T<:AbstractFloat}
try
return x^y
catch e
if isa(e, DomainError)
return T(NaN)
else
rethrow(e)
end
end
end

"""
NaNMath.sum(A)
Expand Down
67 changes: 64 additions & 3 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,56 @@
using NaNMath
using Test
using DoubleFloats


# https://github.com/JuliaMath/NaNMath.jl/issues/76
@test_throws MethodError NaNMath.pow(1.0, 1.0+im)


for T in (Float64, Float32, Float16, BigFloat)
for f in (:sin, :cos, :tan, :asin, :acos, :acosh, :atanh, :log, :log2, :log10,
:log1p) # Note: do :lgamma separately because it can't handle BigFloat
@eval begin
@test NaNMath.$f($T(2//3)) isa $T
@test NaNMath.$f($T(3//2)) isa $T
@test NaNMath.$f($T(-2//3)) isa $T
@test NaNMath.$f($T(-3//2)) isa $T
@test NaNMath.$f($T(Inf)) isa $T
@test NaNMath.$f($T(-Inf)) isa $T
end
end
end
for T in (Float64, Float32, Float16)
@test NaNMath.lgamma(T(2//3)) isa T
@test NaNMath.lgamma(T(3//2)) isa T
@test NaNMath.lgamma(T(-2//3)) isa T
@test NaNMath.lgamma(T(-3//2)) isa T
@test NaNMath.lgamma(T(Inf)) isa T
@test NaNMath.lgamma(T(-Inf)) isa T
end
@test_throws MethodError NaNMath.lgamma(BigFloat(2//3))

@test isnan(NaNMath.log(-10))
@test isnan(NaNMath.log(-10f0))
@test isnan(NaNMath.log(Float16(-10)))
@test isnan(NaNMath.log1p(-100))
@test isnan(NaNMath.log1p(-100f0))
@test isnan(NaNMath.log1p(Float16(-100)))
@test isnan(NaNMath.pow(-1.5,2.3))
@test isnan(NaNMath.pow(-1.5f0,2.3f0))
@test isnan(NaNMath.pow(-1.5,2.3f0))
@test isnan(NaNMath.pow(-1.5f0,2.3))
@test isnan(NaNMath.pow(Float16(-1.5),Float16(2.3)))
@test isnan(NaNMath.pow(Float16(-1.5),2.3))
@test isnan(NaNMath.pow(-1.5,Float16(2.3)))
@test isnan(NaNMath.pow(Float16(-1.5),2.3f0))
@test isnan(NaNMath.pow(-1.5f0,Float16(2.3)))
@test isnan(NaNMath.pow(-1.5f0,BigFloat(2.3)))
@test isnan(NaNMath.pow(BigFloat(-1.5),BigFloat(2.3)))
@test isnan(NaNMath.pow(BigFloat(-1.5),2.3f0))
@test isnan(NaNMath.pow(-1.5f0,Double64(2.3)))
@test isnan(NaNMath.pow(Double64(-1.5),Double64(2.3)))
@test isnan(NaNMath.pow(Double64(-1.5),2.3f0))
@test NaNMath.pow(-1,2) isa Float64
@test NaNMath.pow(-1.5f0,2) isa Float32
@test NaNMath.pow(-1.5f0,2//1) isa Float32
Expand All @@ -16,11 +60,28 @@ using Test
@test NaNMath.pow(-1.5,2//1) isa Float64
@test NaNMath.pow(-1.5,2.3f0) isa Float64
@test NaNMath.pow(-1.5,2.3) isa Float64
@test isnan(NaNMath.sqrt(-5))
@test NaNMath.pow(Float16(-1.5),2.3) isa Float64
@test NaNMath.pow(Float16(-1.5),Float16(2.3)) isa Float16
@test NaNMath.pow(-1.5,Float16(2.3)) isa Float64
@test NaNMath.pow(Float16(-1.5),2.3f0) isa Float32
@test NaNMath.pow(-1.5f0,Float16(2.3)) isa Float32
@test NaNMath.pow(-1.5f0,BigFloat(2.3)) isa BigFloat
@test NaNMath.pow(BigFloat(-1.5),BigFloat(2.3)) isa BigFloat
@test NaNMath.pow(BigFloat(-1.5),2.3f0) isa BigFloat
@test NaNMath.pow(-1.5f0,Double64(2.3)) isa Double64
@test NaNMath.pow(Double64(-1.5),Double64(2.3)) isa Double64
@test NaNMath.pow(Double64(-1.5),2.3f0) isa Double64
@test NaNMath.sqrt(-5) isa Float64
@test NaNMath.sqrt(5) == Base.sqrt(5)
@test NaNMath.sqrt(-5f0) isa Float32
@test NaNMath.sqrt(5f0) == Base.sqrt(5f0)
@test NaNMath.sqrt(Float16(-5)) isa Float16
@test NaNMath.sqrt(Float16(5)) == Base.sqrt(Float16(5))
@test NaNMath.sqrt(BigFloat(-5)) isa BigFloat
@test NaNMath.sqrt(BigFloat(5)) == Base.sqrt(BigFloat(5))
@test isnan(NaNMath.sqrt(-3.2f0)) && NaNMath.sqrt(-3.2f0) isa Float32
@test isnan(NaNMath.sqrt(-BigFloat(7.0))) && NaNMath.sqrt(-BigFloat(7.0)) isa BigFloat
@test isnan(NaNMath.sqrt(-7)) && NaNMath.sqrt(-7) isa Float64
@test isnan(NaNMath.sqrt(-BigFloat(7.0))) && NaNMath.sqrt(-BigFloat(7.0)) isa BigFloat
@test isnan(NaNMath.sqrt(-7)) && NaNMath.sqrt(-7) isa Float64
@inferred NaNMath.sqrt(5)
@inferred NaNMath.sqrt(5.0)
@inferred NaNMath.sqrt(5.0f0)
Expand Down

0 comments on commit 00b0efa

Please sign in to comment.