Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix bug printing scientific numbers in MIME"text/latex" #3838

Merged
merged 4 commits into from
Oct 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 40 additions & 27 deletions src/print.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,27 +68,37 @@ function _is_im_for_printing(coef::Complex)
return _is_zero_for_printing(r) && _is_one_for_printing(i)
end

_escape_if_scientific(::MIME, x::String) = x

function _escape_if_scientific(::MIME"text/latex", x::String)
m = match(r"([0-9]+.[0-9]+)e(-?[0-9]+)", x)
if m === nothing
return x
end
return "$(m[1]) \\times 10^{$(m[2])}"
end

# Helper function that rounds carefully for the purposes of printing Reals
# for example, 5.3 => 5.3, and 1.0 => 1
function _string_round(x::Union{Float32,Float64})
function _string_round(mode, x::Union{Float32,Float64})
if isinteger(x) && typemin(Int64) <= x <= typemax(Int64)
return string(round(Int64, x))
end
return string(x)
return _escape_if_scientific(mode, string(x))
end

_string_round(::typeof(abs), x::Real) = _string_round(abs(x))
_string_round(mode, ::typeof(abs), x::Real) = _string_round(mode, abs(x))

_sign_string(x::Real) = x < zero(x) ? " - " : " + "

function _string_round(::typeof(abs), x::Complex)
function _string_round(mode, ::typeof(abs), x::Complex)
r, i = reim(x)
if _is_zero_for_printing(r)
return _string_round(Complex(r, abs(i)))
return _string_round(mode, Complex(r, abs(i)))
elseif _is_zero_for_printing(i)
return _string_round(Complex(abs(r), i))
return _string_round(mode, Complex(abs(r), i))
else
return _string_round(x)
return _string_round(mode, x)
end
end

Expand All @@ -105,15 +115,15 @@ end

# Fallbacks for other number types

_string_round(x::Any) = string(x)
_string_round(mode, x::Any) = string(x)

_string_round(::typeof(abs), x::Any) = _string_round(x)
_string_round(mode, ::typeof(abs), x::Any) = _string_round(mode, x)

_sign_string(::Any) = " + "

function _string_round(x::Complex)
function _string_round(mode, x::Complex)
r, i = reim(x)
r_str = _string_round(r)
r_str = _string_round(mode, r)
if _is_zero_for_printing(i)
return r_str
elseif _is_zero_for_printing(r)
Expand All @@ -124,12 +134,13 @@ function _string_round(x::Complex)
return "im"
end
else
return string(_string_round(i), "im")
return string(_string_round(mode, i), "im")
end
elseif _is_one_for_printing(i)
return string("(", r_str, _sign_string(i), "im)")
else
return string("(", r_str, _sign_string(i), _string_round(abs, i), "im)")
abs_i = _string_round(mode, abs, i)
return string("(", r_str, _sign_string(i), abs_i, "im)")
end
end

Expand Down Expand Up @@ -587,11 +598,11 @@ function nonlinear_constraint_string(
body = nonlinear_expr_string(model, mode, constraint.expression)
lhs = _set_lhs(constraint.set)
rhs = _set_rhs(constraint.set)
output = "$body $(_math_symbol(mode, rhs[1])) $(_string_round(rhs[2]))"
output = "$body $(_math_symbol(mode, rhs[1])) $(_string_round(mode, rhs[2]))"
if lhs === nothing
return output
end
return "$(_string_round(lhs[2])) $(_math_symbol(mode, lhs[1])) $output"
return "$(_string_round(mode, lhs[2])) $(_math_symbol(mode, lhs[1])) $output"
end

"""
Expand Down Expand Up @@ -825,13 +836,13 @@ function function_string(mode::MIME"text/latex", v::AbstractVariableRef)
return var_name
end

function _term_string(coef, factor)
function _term_string(mode, coef, factor)
if _is_one_for_printing(coef)
return factor
elseif _is_im_for_printing(coef)
return string(factor, " ", _string_round(abs, coef))
return string(factor, " ", _string_round(mode, abs, coef))
else
return string(_string_round(abs, coef), " ", factor)
return string(_string_round(mode, abs, coef), " ", factor)
end
end

Expand Down Expand Up @@ -873,20 +884,20 @@ end
# TODO(odow): remove show_constant in JuMP 1.0
function function_string(mode, a::GenericAffExpr, show_constant = true)
if length(linear_terms(a)) == 0
return show_constant ? _string_round(a.constant) : "0"
return show_constant ? _string_round(mode, a.constant) : "0"
end
terms = fill("", 2 * length(linear_terms(a)))
for (elm, (coef, var)) in enumerate(linear_terms(a))
terms[2*elm-1] = _sign_string(coef)
terms[2*elm] = _term_string(coef, function_string(mode, var))
terms[2*elm] = _term_string(mode, coef, function_string(mode, var))
end
terms[1] = terms[1] == " - " ? "-" : ""
ret = _terms_to_truncated_string(mode, terms)
if show_constant && !_is_zero_for_printing(a.constant)
ret = string(
ret,
_sign_string(a.constant),
_string_round(abs, a.constant),
_string_round(mode, abs, a.constant),
)
end
return ret
Expand All @@ -907,7 +918,7 @@ function function_string(mode, q::GenericQuadExpr)
times = mode == MIME("text/latex") ? "\\times " : "*"
factor = string(x, times, y)
end
terms[2*elm] = _term_string(coef, factor)
terms[2*elm] = _term_string(mode, coef, factor)
end
terms[1] = terms[1] == " - " ? "-" : ""
ret = _terms_to_truncated_string(mode, terms)
Expand Down Expand Up @@ -1032,25 +1043,27 @@ julia> in_set_string(MIME("text/plain"), MOI.Interval(1.0, 2.0))
function in_set_string end

function in_set_string(mode::MIME, set::MOI.LessThan)
return string(_math_symbol(mode, :leq), " ", _string_round(set.upper))
return string(_math_symbol(mode, :leq), " ", _string_round(mode, set.upper))
end

function in_set_string(mode::MIME, set::MOI.GreaterThan)
return string(_math_symbol(mode, :geq), " ", _string_round(set.lower))
return string(_math_symbol(mode, :geq), " ", _string_round(mode, set.lower))
end

function in_set_string(mode::MIME, set::MOI.EqualTo)
return string(_math_symbol(mode, :eq), " ", _string_round(set.value))
return string(_math_symbol(mode, :eq), " ", _string_round(mode, set.value))
end

function in_set_string(::MIME"text/latex", set::MOI.Interval)
lower, upper = _string_round(set.lower), _string_round(set.upper)
lower = _string_round(mode, set.lower)
upper = _string_round(mode, set.upper)
return string("\\in [", lower, ", ", upper, "]")
end

function in_set_string(mode::MIME"text/plain", set::MOI.Interval)
in = _math_symbol(mode, :in)
lower, upper = _string_round(set.lower), _string_round(set.upper)
lower = _string_round(mode, set.lower)
upper = _string_round(mode, set.upper)
return string("$in [", lower, ", ", upper, "]")
end

Expand Down
136 changes: 72 additions & 64 deletions test/test_print.jl
Original file line number Diff line number Diff line change
Expand Up @@ -887,73 +887,66 @@ function test_print_hermitian_psd_cone()
end

function test_print_complex_string_round()
@test JuMP._string_round(1.0 + 0.0im) == "1"
@test JuMP._string_round(-1.0 + 0.0im) == "-1"
@test JuMP._string_round(1.0 - 0.0im) == "1"
@test JuMP._string_round(-1.0 - 0.0im) == "-1"
@test JuMP._string_round(0.0 + 1.0im) == "im"
@test JuMP._string_round(-0.0 + 1.0im) == "im"
@test JuMP._string_round(0.0 - 1.0im) == "-im"
@test JuMP._string_round(-0.0 - 1.0im) == "-im"
@test JuMP._string_round(1.0 + 2.0im) == "(1 + 2im)"
@test JuMP._string_round(1.0 - 2.0im) == "(1 - 2im)"
@test JuMP._string_round(-1.0 + 2.0im) == "(-1 + 2im)"
@test JuMP._string_round(-1.0 - 2.0im) == "(-1 - 2im)"
@test JuMP._string_round(1.0 + 1.0im) == "(1 + im)"
@test JuMP._string_round(1.0 - 1.0im) == "(1 - im)"
@test JuMP._string_round(-1.0 + 1.0im) == "(-1 + im)"
@test JuMP._string_round(-1.0 - 1.0im) == "(-1 - im)"
for (test, result) in Any[
1.0+0.0im=>"1",
-1.0+0.0im=>"-1",
1.0-0.0im=>"1",
-1.0-0.0im=>"-1",
0.0+1.0im=>"im",
-0.0+1.0im=>"im",
0.0-1.0im=>"-im",
-0.0-1.0im=>"-im",
1.0+2.0im=>"(1 + 2im)",
1.0-2.0im=>"(1 - 2im)",
-1.0+2.0im=>"(-1 + 2im)",
-1.0-2.0im=>"(-1 - 2im)",
1.0+1.0im=>"(1 + im)",
1.0-1.0im=>"(1 - im)",
-1.0+1.0im=>"(-1 + im)",
-1.0-1.0im=>"(-1 - im)",
]
@test JuMP._string_round(MIME("text/plain"), test) == result
end
return
end

function test_print_huge_integer_string_round()
@test JuMP._string_round(-1 + Float32(typemax(Int32))) == "2147483648"
@test JuMP._string_round(-1 + Float32(typemin(Int32))) == "-2147483648"
@test JuMP._string_round(-1 + Float64(typemax(Int32))) == "2147483646"
@test JuMP._string_round(-1 + Float64(typemin(Int32))) == "-2147483649"

@test JuMP._string_round(Float32(typemax(Int32))) == "2147483648"
@test JuMP._string_round(Float32(typemin(Int32))) == "-2147483648"
@test JuMP._string_round(Float64(typemax(Int32))) == "2147483647"
@test JuMP._string_round(Float64(typemin(Int32))) == "-2147483648"

@test JuMP._string_round(1 + Float32(typemax(Int32))) == "2147483648"
@test JuMP._string_round(1 + Float32(typemin(Int32))) == "-2147483648"
@test JuMP._string_round(1 + Float64(typemax(Int32))) == "2147483648"
@test JuMP._string_round(1 + Float64(typemin(Int32))) == "-2147483647"

@test JuMP._string_round(2 * Float32(typemax(Int32))) == "4294967296"
@test JuMP._string_round(2 * Float32(typemin(Int32))) == "-4294967296"
@test JuMP._string_round(2 * Float64(typemax(Int32))) == "4294967294"
@test JuMP._string_round(2 * Float64(typemin(Int32))) == "-4294967296"

@test JuMP._string_round(-1 + Float32(typemax(Int64))) == "9.223372e18"
@test JuMP._string_round(-1 + Float32(typemin(Int64))) ==
"-9223372036854775808"
@test JuMP._string_round(-1 + Float64(typemax(Int64))) ==
"9.223372036854776e18"
@test JuMP._string_round(-1 + Float64(typemin(Int64))) ==
"-9223372036854775808"

@test JuMP._string_round(Float32(typemax(Int64))) == "9.223372e18"
@test JuMP._string_round(Float32(typemin(Int64))) == "-9223372036854775808"
@test JuMP._string_round(Float64(typemax(Int64))) == "9.223372036854776e18"
@test JuMP._string_round(Float64(typemin(Int64))) == "-9223372036854775808"

@test JuMP._string_round(1 + Float32(typemax(Int64))) == "9.223372e18"
@test JuMP._string_round(1 + Float32(typemin(Int64))) ==
"-9223372036854775808"
@test JuMP._string_round(1 + Float64(typemax(Int64))) ==
"9.223372036854776e18"
@test JuMP._string_round(1 + Float64(typemin(Int64))) ==
"-9223372036854775808"

@test JuMP._string_round(2 * Float32(typemax(Int64))) == "1.8446744e19"
@test JuMP._string_round(2 * Float32(typemin(Int64))) == "-1.8446744e19"
@test JuMP._string_round(2 * Float64(typemax(Int64))) ==
"1.8446744073709552e19"
@test JuMP._string_round(2 * Float64(typemin(Int64))) ==
"-1.8446744073709552e19"
for (test, result) in Any[
-1+Float32(typemax(Int32))=>"2147483648",
-1+Float32(typemin(Int32))=>"-2147483648",
-1+Float64(typemax(Int32))=>"2147483646",
-1+Float64(typemin(Int32))=>"-2147483649",
Float32(typemax(Int32))=>"2147483648",
Float32(typemin(Int32))=>"-2147483648",
Float64(typemax(Int32))=>"2147483647",
Float64(typemin(Int32))=>"-2147483648",
1+Float32(typemax(Int32))=>"2147483648",
1+Float32(typemin(Int32))=>"-2147483648",
1+Float64(typemax(Int32))=>"2147483648",
1+Float64(typemin(Int32))=>"-2147483647",
2*Float32(typemax(Int32))=>"4294967296",
2*Float32(typemin(Int32))=>"-4294967296",
2*Float64(typemax(Int32))=>"4294967294",
2*Float64(typemin(Int32))=>"-4294967296",
-1+Float32(typemax(Int64))=>"9.223372e18",
-1+Float32(typemin(Int64))=>"-9223372036854775808",
-1+Float64(typemax(Int64))=>"9.223372036854776e18",
-1+Float64(typemin(Int64))=>"-9223372036854775808",
Float32(typemax(Int64))=>"9.223372e18",
Float32(typemin(Int64))=>"-9223372036854775808",
Float64(typemax(Int64))=>"9.223372036854776e18",
Float64(typemin(Int64))=>"-9223372036854775808",
1+Float32(typemax(Int64))=>"9.223372e18",
1+Float32(typemin(Int64))=>"-9223372036854775808",
1+Float64(typemax(Int64))=>"9.223372036854776e18",
1+Float64(typemin(Int64))=>"-9223372036854775808",
2*Float32(typemax(Int64))=>"1.8446744e19",
2*Float32(typemin(Int64))=>"-1.8446744e19",
2*Float64(typemax(Int64))=>"1.8446744073709552e19",
2*Float64(typemin(Int64))=>"-1.8446744073709552e19",
]
@test JuMP._string_round(MIME("text/plain"), test) == result
end
return
end

Expand All @@ -965,7 +958,7 @@ function test_print_model_with_huge_integers()
@test sprint(io -> show(io, MIME("text/plain"), c)) == "1.0e20 x $eq 42"
eq = JuMP._math_symbol(MIME("text/latex"), :eq)
@test sprint(io -> show(io, MIME("text/latex"), c)) ==
"\$\$ 1.0e20 x $eq 42 \$\$"
"\$\$ 1.0 \\times 10^{20} x $eq 42 \$\$"
return
end

Expand Down Expand Up @@ -1124,4 +1117,19 @@ function test_show_generic_model_bigfloat()
return
end

function test_small_number_latex()
model = Model()
@variable(model, x)
y = 1e-8 * x
@test function_string(MIME("text/latex"), y) == "1.0 \\times 10^{-8} x"
@test function_string(MIME("text/plain"), y) == "1.0e-8 x"
y = 0.23e-8 * x
@test function_string(MIME("text/latex"), y) == "2.3 \\times 10^{-9} x"
@test function_string(MIME("text/plain"), y) == "2.3e-9 x"
y = 1.23 * x
@test function_string(MIME("text/latex"), y) == "1.23 x"
@test function_string(MIME("text/plain"), y) == "1.23 x"
return
end

end # TestPrint
Loading