Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add erfinv and erfcinv for Float16 and generalize logerfc and logerfcx #372

Merged
merged 14 commits into from
May 7, 2024
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "SpecialFunctions"
uuid = "276daf66-3868-5448-9aa4-cd146d93841b"
version = "2.3.0"
version = "2.4.0"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
75 changes: 63 additions & 12 deletions src/erf.jl
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,46 @@
end
end

function _erfinv(x::Float16)
a = abs(x)
if a >= Float16(1)
if x == Float16(1)
return Inf16
elseif x == -Float16(1)
return -Inf16

Check warning on line 371 in src/erf.jl

View check run for this annotation

Codecov / codecov/patch

src/erf.jl#L370-L371

Added lines #L370 - L371 were not covered by tests
end
devmotion marked this conversation as resolved.
Show resolved Hide resolved
throw(DomainError(a, "`abs(x)` cannot be greater than 1."))

Check warning on line 373 in src/erf.jl

View check run for this annotation

Codecov / codecov/patch

src/erf.jl#L373

Added line #L373 was not covered by tests
end

# Perform calculations with `Float32`
x32 = Float32(x)
a32 = Float32(a)
if a32 <= 0.75f0 # Table 7 in Blair et al.
t = x32^2 - 0.5625f0
y = x32 * @horner(t, -0.10976_672f1,
0.53062_1f0) /
@horner(t, -0.10123_953f1,
0.1f1)
elseif a32 <= 0.9375f0 # Table 26 in Blair et al.
t = x32^2 - 0.87890625f0
y = x32 * @horner(t, 0.10178_950f1,
-0.32827_601f1) /
@horner(t, 0.72455_99f0,
-0.33871_553f1,
0.1f1)
else # Table 47 in Blair et al.
t = inv(sqrt(-log1p(-a32)))
y = @horner(t, 0.98650_088f0,

Check warning on line 394 in src/erf.jl

View check run for this annotation

Codecov / codecov/patch

src/erf.jl#L393-L394

Added lines #L393 - L394 were not covered by tests
0.92601_777f0) /
(copysign(t, x32) *
@horner(t, 0.98424_719f0,
0.10074_7432f0,
0.1f0))
devmotion marked this conversation as resolved.
Show resolved Hide resolved
end

return Float16(y)
end

function _erfinv(y::BigFloat)
xfloat = erfinv(Float64(y))
if isfinite(xfloat)
Expand Down Expand Up @@ -482,6 +522,25 @@
end
end

function _erfcinv(y::Float16)
if y > Float16(0.0625)
return erfinv(Float16(1) - y)
elseif y <= Float16(0)
if y == Float16(0)
return Inf16

Check warning on line 530 in src/erf.jl

View check run for this annotation

Codecov / codecov/patch

src/erf.jl#L528-L530

Added lines #L528 - L530 were not covered by tests
end
throw(DomainError(y, "`y` must be nonnegative."))

Check warning on line 532 in src/erf.jl

View check run for this annotation

Codecov / codecov/patch

src/erf.jl#L532

Added line #L532 was not covered by tests
else # Table 47 in Blair et al.
t = 1.0f0 / sqrt(-log(Float32(y)))
x = @horner(t, 0.98650_088f0,

Check warning on line 535 in src/erf.jl

View check run for this annotation

Codecov / codecov/patch

src/erf.jl#L534-L535

Added lines #L534 - L535 were not covered by tests
0.92601_777f0) /
(t * @horner(t, 0.98424_719f0,
0.10074_7432f0,
0.1f0))
return Float16(x)

Check warning on line 540 in src/erf.jl

View check run for this annotation

Codecov / codecov/patch

src/erf.jl#L540

Added line #L540 was not covered by tests
end
end

function _erfcinv(y::BigFloat)
yfloat = Float64(y)
xfloat = erfcinv(yfloat)
Expand Down Expand Up @@ -526,13 +585,9 @@

# Implementation
Based on the [`erfc(x)`](@ref erfc) and [`erfcx(x)`](@ref erfcx) functions.
Currently only implemented for `Float32`, `Float64`, and `BigFloat`.
"""
logerfc(x::Real) = _logerfc(float(x))

function _logerfc(x::Union{Float32, Float64, BigFloat})
# Don't include Float16 in the Union, otherwise logerfc would currently work for x <= 0.0, but not x > 0.0
if x > 0.0
function logerfc(x::Real)
if x > zero(x)
return log(erfcx(x)) - x^2
else
return log(erfc(x))
Expand All @@ -557,13 +612,9 @@

# Implementation
Based on the [`erfc(x)`](@ref erfc) and [`erfcx(x)`](@ref erfcx) functions.
Currently only implemented for `Float32`, `Float64`, and `BigFloat`.
"""
logerfcx(x::Real) = _logerfcx(float(x))

function _logerfcx(x::Union{Float32, Float64, BigFloat})
# Don't include Float16 in the Union, otherwise logerfc would currently work for x <= 0.0, but not x > 0.0
if x < 0.0
function logerfcx(x::Real)
if x < zero(x)
return log(erfc(x)) + x^2
else
return log(erfcx(x))
Expand Down
93 changes: 38 additions & 55 deletions test/erf.jl
Original file line number Diff line number Diff line change
@@ -1,65 +1,48 @@
@testset "error functions" begin
@testset "real argument" begin
@test erf(Float16(1)) ≈ 0.84270079294971486934 rtol=2*eps(Float16)
@test erf(Float32(1)) ≈ 0.84270079294971486934 rtol=2*eps(Float32)
@test erf(Float64(1)) ≈ 0.84270079294971486934 rtol=2*eps(Float64)

@test erfc(Float16(1)) ≈ 0.15729920705028513066 rtol=2*eps(Float16)
@test erfc(Float32(1)) ≈ 0.15729920705028513066 rtol=2*eps(Float32)
@test erfc(Float64(1)) ≈ 0.15729920705028513066 rtol=2*eps(Float64)

@test erfcx(Float16(1)) ≈ 0.42758357615580700442 rtol=2*eps(Float16)
@test erfcx(Float32(1)) ≈ 0.42758357615580700442 rtol=2*eps(Float32)
@test erfcx(Float64(1)) ≈ 0.42758357615580700442 rtol=2*eps(Float64)

@test_throws MethodError logerfc(Float16(1))
@test_throws MethodError logerfc(Float16(-1))
@test logerfc(Float32(-100)) ≈ 0.6931471805599453 rtol=2*eps(Float32)
@test logerfc(Float64(-100)) ≈ 0.6931471805599453 rtol=2*eps(Float64)
@test logerfc(Float32(1000)) ≈ -1.0000074801207219e6 rtol=2*eps(Float32)
@test logerfc(Float64(1000)) ≈ -1.0000074801207219e6 rtol=2*eps(Float64)
@test logerfc(1000) ≈ -1.0000074801207219e6 rtol=2*eps(Float32)
@test logerfc(Float32(10000)) ≈ log(erfc(BigFloat(10000, precision=100))) rtol=2*eps(Float32)
@test logerfc(Float64(10000)) ≈ log(erfc(BigFloat(10000, precision=100))) rtol=2*eps(Float64)

@test_throws MethodError logerfcx(Float16(1))
@test_throws MethodError logerfcx(Float16(-1))
@test iszero(logerfcx(0))
@test logerfcx(Float32(1)) ≈ -0.849605509933248248576017509499 rtol=2eps(Float32)
@test logerfcx(Float64(1)) ≈ -0.849605509933248248576017509499 rtol=2eps(Float32)
@test logerfcx(Float32(-1)) ≈ 1.61123231767807049464268192445 rtol=2eps(Float32)
@test logerfcx(Float64(-1)) ≈ 1.61123231767807049464268192445 rtol=2eps(Float32)
@test logerfcx(Float32(-100)) ≈ 10000.6931471805599453094172321 rtol=2eps(Float32)
@test logerfcx(Float64(-100)) ≈ 10000.6931471805599453094172321 rtol=2eps(Float64)
@test logerfcx(Float32(100)) ≈ -5.17758512266433257046678208395 rtol=2eps(Float32)
@test logerfcx(Float64(100)) ≈ -5.17758512266433257046678208395 rtol=2eps(Float64)
@test logerfcx(Float32(-1000)) ≈ 1.00000069314718055994530941723e6 rtol=2eps(Float32)
@test logerfcx(Float64(-1000)) ≈ 1.00000069314718055994530941723e6 rtol=2eps(Float64)
@test logerfcx(Float32(1000)) ≈ -7.48012072190621214066734919080 rtol=2eps(Float32)
@test logerfcx(Float64(1000)) ≈ -7.48012072190621214066734919080 rtol=2eps(Float64)

@test erfi(Float16(1)) ≈ 1.6504257587975428760 rtol=2*eps(Float16)
@test erfi(Float32(1)) ≈ 1.6504257587975428760 rtol=2*eps(Float32)
@test erfi(Float64(1)) ≈ 1.6504257587975428760 rtol=2*eps(Float64)
for T in (Float16, Float32, Float64)
@test @inferred(erf(T(1))) isa T
@test erf(T(1)) ≈ T(0.84270079294971486934) rtol=2*eps(T)

@test erfinv(Integer(0)) == 0 == erfinv(0//1)
@test_throws MethodError erfinv(Float16(1))
@test erfinv(Float32(0.84270079294971486934)) ≈ 1 rtol=2*eps(Float32)
@test erfinv(Float64(0.84270079294971486934)) ≈ 1 rtol=2*eps(Float64)
@test @inferred(erfc(T(1))) isa T
@test erfc(T(1)) ≈ T(0.15729920705028513066) rtol=2*eps(T)

@test erfcinv(Integer(1)) == 0 == erfcinv(1//1)
@test_throws MethodError erfcinv(Float16(1))
@test erfcinv(Float32(0.15729920705028513066)) ≈ 1 rtol=2*eps(Float32)
@test erfcinv(Float64(0.15729920705028513066)) ≈ 1 rtol=2*eps(Float64)
@test @inferred(erfcx(T(1))) isa T
@test erfcx(T(1)) ≈ T(0.42758357615580700442) rtol=2*eps(T)

@test @inferred(logerfc(T(1))) isa T
@test logerfc(T(-100)) ≈ T(0.6931471805599453) rtol=2*eps(T)
@test logerfc(T(1000)) ≈ T(-1.0000074801207219e6) rtol=2*eps(T)
@test logerfc(T(10000)) ≈ T(log(erfc(BigFloat(10000, precision=100)))) rtol=2*eps(T)

@test @inferred(logerfcx(T(1))) isa T
@test logerfcx(T(1)) ≈ T(-0.849605509933248248576017509499) rtol=2eps(T)
@test logerfcx(T(-1)) ≈ T(1.61123231767807049464268192445) rtol=2eps(T)
@test logerfcx(T(-100)) ≈ T(10000.6931471805599453094172321) rtol=2eps(T)
@test logerfcx(T(100)) ≈ T(-5.17758512266433257046678208395) rtol=2eps(T)
@test logerfcx(T(-1000)) ≈ T(1.00000069314718055994530941723e6) rtol=2eps(T)
@test logerfcx(T(1000)) ≈ T(-7.48012072190621214066734919080) rtol=2eps(T)

@test @inferred(erfi(T(1))) isa T
@test erfi(T(1)) ≈ T(1.6504257587975428760) rtol=2*eps(T)

@test @inferred(erfinv(T(1))) isa T
@test erfinv(T(0.84270079294971486934)) ≈ 1 rtol=2*eps(T)

@test dawson(Float16(1)) ≈ 0.53807950691276841914 rtol=2*eps(Float16)
@test dawson(Float32(1)) ≈ 0.53807950691276841914 rtol=2*eps(Float32)
@test dawson(Float64(1)) ≈ 0.53807950691276841914 rtol=2*eps(Float64)
@test @inferred(erfcinv(T(1))) isa T
@test erfcinv(T(0.15729920705028513066)) ≈ 1 rtol=2*eps(T)

@test @inferred(dawson(T(1))) isa T
@test dawson(T(1)) ≈ T(0.53807950691276841914) rtol=2*eps(T)

@test @inferred(faddeva(T(1))) isa T
devmotion marked this conversation as resolved.
Show resolved Hide resolved
@test faddeeva(T(1)) ≈ 0.36787944117144233402+0.60715770584139372446im rtol=2*eps(T)
end

@test logerfc(1000) ≈ -1.0000074801207219e6 rtol=2*eps(Float32)
@test erfinv(Integer(0)) == 0 == erfinv(0//1)
@test erfcinv(Integer(1)) == 0 == erfcinv(1//1)
@test faddeeva(0) == faddeeva(0//1) == 1
@test faddeeva(Float16(1)) ≈ 0.36787944117144233402+0.60715770584139372446im rtol=2*eps(Float16)
@test faddeeva(Float32(1)) ≈ 0.36787944117144233402+0.60715770584139372446im rtol=2*eps(Float32)
@test faddeeva(Float64(1)) ≈ 0.36787944117144233402+0.60715770584139372446im rtol=2*eps(Float64)
end

@testset "complex arguments" begin
Expand Down
Loading