Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add erfinv and erfcinv for Float16 and generalize logerfc and logerfcx #372

Merged
merged 14 commits into from
May 7, 2024
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "SpecialFunctions"
uuid = "276daf66-3868-5448-9aa4-cd146d93841b"
version = "2.3.1"
version = "2.4.0"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
90 changes: 66 additions & 24 deletions src/erf.jl
Original file line number Diff line number Diff line change
Expand Up @@ -253,13 +253,10 @@

function _erfinv(x::Float64)
a = abs(x)
if a >= 1.0
if x == 1.0
return Inf
elseif x == -1.0
return -Inf
end
if a > 1.0
throw(DomainError(a, "`abs(x)` cannot be greater than 1."))
elseif a == 1.0
return copysign(Inf, x)
elseif a <= 0.75 # Table 17 in Blair et al.
t = x*x - 0.5625
return x * @horner(t, 0.16030_49558_44066_229311e2,
Expand Down Expand Up @@ -321,13 +318,10 @@

function _erfinv(x::Float32)
a = abs(x)
if a >= 1.0f0
if x == 1.0f0
return Inf32
elseif x == -1.0f0
return -Inf32
end
if a > 1f0
throw(DomainError(a, "`abs(x)` cannot be greater than 1."))
elseif a == 1f0
return copysign(Inf32, x)
elseif a <= 0.75f0 # Table 10 in Blair et al.
t = x*x - 0.5625f0
return x * @horner(t, -0.13095_99674_22f2,
Expand Down Expand Up @@ -362,6 +356,43 @@
end
end

function _erfinv(x::Float16)
a = abs(x)
if a > Float16(1)
throw(DomainError(a, "`abs(x)` cannot be greater than 1."))

Check warning on line 362 in src/erf.jl

View check run for this annotation

Codecov / codecov/patch

src/erf.jl#L362

Added line #L362 was not covered by tests
elseif a == Float16(1)
return copysign(Inf16, x)
else
# Perform calculations with `Float32`
x32 = Float32(x)
a32 = Float32(a)
if a32 <= 0.75f0 # Table 7 in Blair et al.
t = x32^2 - 0.5625f0
y = x32 * @horner(t, -0.10976_672f1,
0.53062_1f0) /
@horner(t, -0.10123_953f1,
0.1f1)
devmotion marked this conversation as resolved.
Show resolved Hide resolved
elseif a32 <= 0.9375f0 # Table 26 in Blair et al.
t = x32^2 - 0.87890625f0
y = x32 * @horner(t, 0.10178_950f1,
-0.32827_601f1) /
@horner(t, 0.72455_99f0,
-0.33871_553f1,
0.1f1)
else
# Simpler alternative to Table 47 in Blair et al.
# because of the reduced accuracy requirement
# (it turns out that this branch only covers 128 values).
# Note that the use of log(1-x) rather than log1p is intentional since it will be
# slightly faster and 1-x is exact.
# Ref: https://github.com/JuliaMath/SpecialFunctions.jl/pull/372#discussion_r1592710586
t = sqrt(-log(1-a32))
y = copysign(@horner(t, -0.429159f0, 1.04868f0), x32)

Check warning on line 390 in src/erf.jl

View check run for this annotation

Codecov / codecov/patch

src/erf.jl#L389-L390

Added lines #L389 - L390 were not covered by tests
end
return Float16(y)
end
end

function _erfinv(y::BigFloat)
xfloat = erfinv(Float64(y))
if isfinite(xfloat)
Expand Down Expand Up @@ -482,6 +513,25 @@
end
end

function _erfcinv(y::Float16)
if y > Float16(0.0625)
return erfinv(Float16(1) - y)
elseif y <= Float16(0)
if y == Float16(0)
return Inf16

Check warning on line 521 in src/erf.jl

View check run for this annotation

Codecov / codecov/patch

src/erf.jl#L519-L521

Added lines #L519 - L521 were not covered by tests
end
throw(DomainError(y, "`y` must be nonnegative."))

Check warning on line 523 in src/erf.jl

View check run for this annotation

Codecov / codecov/patch

src/erf.jl#L523

Added line #L523 was not covered by tests
else # Table 47 in Blair et al.
t = 1.0f0 / sqrt(-log(Float32(y)))
x = @horner(t, 0.98650_088f0,

Check warning on line 526 in src/erf.jl

View check run for this annotation

Codecov / codecov/patch

src/erf.jl#L525-L526

Added lines #L525 - L526 were not covered by tests
0.92601_777f0) /
(t * @horner(t, 0.98424_719f0,
0.10074_7432f0,
0.1f0))
return Float16(x)

Check warning on line 531 in src/erf.jl

View check run for this annotation

Codecov / codecov/patch

src/erf.jl#L531

Added line #L531 was not covered by tests
end
end

function _erfcinv(y::BigFloat)
yfloat = Float64(y)
xfloat = erfcinv(yfloat)
Expand Down Expand Up @@ -526,13 +576,9 @@

# Implementation
Based on the [`erfc(x)`](@ref erfc) and [`erfcx(x)`](@ref erfcx) functions.
Currently only implemented for `Float32`, `Float64`, and `BigFloat`.
"""
logerfc(x::Real) = _logerfc(float(x))

function _logerfc(x::Union{Float32, Float64, BigFloat})
# Don't include Float16 in the Union, otherwise logerfc would currently work for x <= 0.0, but not x > 0.0
if x > 0.0
function logerfc(x::Real)
if x > zero(x)
return log(erfcx(x)) - x^2
else
return log(erfc(x))
Expand All @@ -557,13 +603,9 @@

# Implementation
Based on the [`erfc(x)`](@ref erfc) and [`erfcx(x)`](@ref erfcx) functions.
Currently only implemented for `Float32`, `Float64`, and `BigFloat`.
"""
logerfcx(x::Real) = _logerfcx(float(x))

function _logerfcx(x::Union{Float32, Float64, BigFloat})
# Don't include Float16 in the Union, otherwise logerfc would currently work for x <= 0.0, but not x > 0.0
if x < 0.0
function logerfcx(x::Real)
if x < zero(x)
return log(erfc(x)) + x^2
else
return log(erfcx(x))
Expand Down
93 changes: 38 additions & 55 deletions test/erf.jl
Original file line number Diff line number Diff line change
@@ -1,65 +1,48 @@
@testset "error functions" begin
@testset "real argument" begin
@test erf(Float16(1)) ≈ 0.84270079294971486934 rtol=2*eps(Float16)
@test erf(Float32(1)) ≈ 0.84270079294971486934 rtol=2*eps(Float32)
@test erf(Float64(1)) ≈ 0.84270079294971486934 rtol=2*eps(Float64)

@test erfc(Float16(1)) ≈ 0.15729920705028513066 rtol=2*eps(Float16)
@test erfc(Float32(1)) ≈ 0.15729920705028513066 rtol=2*eps(Float32)
@test erfc(Float64(1)) ≈ 0.15729920705028513066 rtol=2*eps(Float64)

@test erfcx(Float16(1)) ≈ 0.42758357615580700442 rtol=2*eps(Float16)
@test erfcx(Float32(1)) ≈ 0.42758357615580700442 rtol=2*eps(Float32)
@test erfcx(Float64(1)) ≈ 0.42758357615580700442 rtol=2*eps(Float64)

@test_throws MethodError logerfc(Float16(1))
@test_throws MethodError logerfc(Float16(-1))
@test logerfc(Float32(-100)) ≈ 0.6931471805599453 rtol=2*eps(Float32)
@test logerfc(Float64(-100)) ≈ 0.6931471805599453 rtol=2*eps(Float64)
@test logerfc(Float32(1000)) ≈ -1.0000074801207219e6 rtol=2*eps(Float32)
@test logerfc(Float64(1000)) ≈ -1.0000074801207219e6 rtol=2*eps(Float64)
@test logerfc(1000) ≈ -1.0000074801207219e6 rtol=2*eps(Float32)
@test logerfc(Float32(10000)) ≈ log(erfc(BigFloat(10000, precision=100))) rtol=2*eps(Float32)
@test logerfc(Float64(10000)) ≈ log(erfc(BigFloat(10000, precision=100))) rtol=2*eps(Float64)

@test_throws MethodError logerfcx(Float16(1))
@test_throws MethodError logerfcx(Float16(-1))
@test iszero(logerfcx(0))
@test logerfcx(Float32(1)) ≈ -0.849605509933248248576017509499 rtol=2eps(Float32)
@test logerfcx(Float64(1)) ≈ -0.849605509933248248576017509499 rtol=2eps(Float32)
@test logerfcx(Float32(-1)) ≈ 1.61123231767807049464268192445 rtol=2eps(Float32)
@test logerfcx(Float64(-1)) ≈ 1.61123231767807049464268192445 rtol=2eps(Float32)
@test logerfcx(Float32(-100)) ≈ 10000.6931471805599453094172321 rtol=2eps(Float32)
@test logerfcx(Float64(-100)) ≈ 10000.6931471805599453094172321 rtol=2eps(Float64)
@test logerfcx(Float32(100)) ≈ -5.17758512266433257046678208395 rtol=2eps(Float32)
@test logerfcx(Float64(100)) ≈ -5.17758512266433257046678208395 rtol=2eps(Float64)
@test logerfcx(Float32(-1000)) ≈ 1.00000069314718055994530941723e6 rtol=2eps(Float32)
@test logerfcx(Float64(-1000)) ≈ 1.00000069314718055994530941723e6 rtol=2eps(Float64)
@test logerfcx(Float32(1000)) ≈ -7.48012072190621214066734919080 rtol=2eps(Float32)
@test logerfcx(Float64(1000)) ≈ -7.48012072190621214066734919080 rtol=2eps(Float64)

@test erfi(Float16(1)) ≈ 1.6504257587975428760 rtol=2*eps(Float16)
@test erfi(Float32(1)) ≈ 1.6504257587975428760 rtol=2*eps(Float32)
@test erfi(Float64(1)) ≈ 1.6504257587975428760 rtol=2*eps(Float64)
for T in (Float16, Float32, Float64)
@test @inferred(erf(T(1))) isa T
@test erf(T(1)) ≈ T(0.84270079294971486934) rtol=2*eps(T)

@test erfinv(Integer(0)) == 0 == erfinv(0//1)
@test_throws MethodError erfinv(Float16(1))
@test erfinv(Float32(0.84270079294971486934)) ≈ 1 rtol=2*eps(Float32)
@test erfinv(Float64(0.84270079294971486934)) ≈ 1 rtol=2*eps(Float64)
@test @inferred(erfc(T(1))) isa T
@test erfc(T(1)) ≈ T(0.15729920705028513066) rtol=2*eps(T)

@test erfcinv(Integer(1)) == 0 == erfcinv(1//1)
@test_throws MethodError erfcinv(Float16(1))
@test erfcinv(Float32(0.15729920705028513066)) ≈ 1 rtol=2*eps(Float32)
@test erfcinv(Float64(0.15729920705028513066)) ≈ 1 rtol=2*eps(Float64)
@test @inferred(erfcx(T(1))) isa T
@test erfcx(T(1)) ≈ T(0.42758357615580700442) rtol=2*eps(T)

@test @inferred(logerfc(T(1))) isa T
@test logerfc(T(-100)) ≈ T(0.6931471805599453) rtol=2*eps(T)
@test logerfc(T(1000)) ≈ T(-1.0000074801207219e6) rtol=2*eps(T)
@test logerfc(T(10000)) ≈ T(log(erfc(BigFloat(10000, precision=100)))) rtol=2*eps(T)

@test @inferred(logerfcx(T(1))) isa T
@test logerfcx(T(1)) ≈ T(-0.849605509933248248576017509499) rtol=2eps(T)
@test logerfcx(T(-1)) ≈ T(1.61123231767807049464268192445) rtol=2eps(T)
@test logerfcx(T(-100)) ≈ T(10000.6931471805599453094172321) rtol=2eps(T)
@test logerfcx(T(100)) ≈ T(-5.17758512266433257046678208395) rtol=2eps(T)
@test logerfcx(T(-1000)) ≈ T(1.00000069314718055994530941723e6) rtol=2eps(T)
@test logerfcx(T(1000)) ≈ T(-7.48012072190621214066734919080) rtol=2eps(T)

@test @inferred(erfi(T(1))) isa T
@test erfi(T(1)) ≈ T(1.6504257587975428760) rtol=2*eps(T)

@test @inferred(erfinv(T(1))) isa T
@test erfinv(T(0.84270079294971486934)) ≈ 1 rtol=2*eps(T)

@test dawson(Float16(1)) ≈ 0.53807950691276841914 rtol=2*eps(Float16)
@test dawson(Float32(1)) ≈ 0.53807950691276841914 rtol=2*eps(Float32)
@test dawson(Float64(1)) ≈ 0.53807950691276841914 rtol=2*eps(Float64)
@test @inferred(erfcinv(T(1))) isa T
@test erfcinv(T(0.15729920705028513066)) ≈ 1 rtol=2*eps(T)

@test @inferred(dawson(T(1))) isa T
@test dawson(T(1)) ≈ T(0.53807950691276841914) rtol=2*eps(T)

@test @inferred(faddeeva(T(1))) isa Complex{T}
@test faddeeva(T(1)) ≈ 0.36787944117144233402+0.60715770584139372446im rtol=2*eps(T)
end

@test logerfc(1000) ≈ -1.0000074801207219e6 rtol=2*eps(Float32)
@test erfinv(Integer(0)) == 0 == erfinv(0//1)
@test erfcinv(Integer(1)) == 0 == erfcinv(1//1)
@test faddeeva(0) == faddeeva(0//1) == 1
@test faddeeva(Float16(1)) ≈ 0.36787944117144233402+0.60715770584139372446im rtol=2*eps(Float16)
@test faddeeva(Float32(1)) ≈ 0.36787944117144233402+0.60715770584139372446im rtol=2*eps(Float32)
@test faddeeva(Float64(1)) ≈ 0.36787944117144233402+0.60715770584139372446im rtol=2*eps(Float64)
end

@testset "complex arguments" begin
Expand Down
Loading