Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Mk/rational arithmetic #141

Open
wants to merge 10 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions src/implementations/BigInt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,14 @@ function operate_to!(output::BigInt, ::typeof(+), a::BigInt, b::BigInt)
return Base.GMP.MPZ.add!(output, a, b)
end

function operate_to!(output::BigInt, ::typeof(copy), a::BigInt)
return Base.GMP.MPZ.set!(output, a)
end

function operate_to!(output::BigInt, ::typeof(copy), a::Int)
return Base.GMP.MPZ.set_si!(output, a)
end

# -

promote_operation(::typeof(-), ::Vararg{Type{BigInt},N}) where {N} = BigInt
Expand All @@ -35,6 +43,10 @@ function operate_to!(output::BigInt, ::typeof(-), a::BigInt, b::BigInt)
return Base.GMP.MPZ.sub!(output, a, b)
end

function operate_to!(output::BigInt, ::typeof(-), a::BigInt)
return Base.GMP.MPZ.neg!(output, a)
end

# *

promote_operation(::typeof(*), ::Vararg{Type{BigInt},N}) where {N} = BigInt
Expand Down
139 changes: 131 additions & 8 deletions src/implementations/Rational.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,29 @@ end
# +

function promote_operation(
::typeof(+),
::Union{typeof(+),typeof(-)},
::Type{Rational{S}},
::Type{Rational{T}},
) where {S,T}
return Rational{promote_sum_mul(S, T)}
end

function promote_operation(
op::Union{typeof(+),typeof(-)},
::Type{Rational{S}},
::Type{I},
) where {S,I<:Integer}
return promote_operation(op, Rational{S}, Rational{I})
end

function promote_operation(
op::Union{typeof(+),typeof(-)},
::Type{I},
::Type{Rational{S}},
) where {S,I<:Integer}
return promote_operation(op, Rational{S}, Rational{I})
end

function operate_to!(output::Rational, ::typeof(+), x::Rational, y::Rational)
xd, yd = Base.divgcd(promote(x.den, y.den)...)
# TODO: Use `checked_mul` and `checked_add` like in Base
Expand All @@ -46,16 +62,28 @@ function operate_to!(output::Rational, ::typeof(+), x::Rational, y::Rational)
return output
end

# -
function operate_to!(output::Rational, ::typeof(+), x::Rational, y::Integer)
# TODO Use `checked_mul` and `checked_add` like in Base
operate_to!(output.num, *, x.den, y)
operate!(+, output.num, x.num)
operate_to!(output.den, *, x.den, oftype(x.den, 1))
return output
end

function promote_operation(
::typeof(-),
::Type{Rational{S}},
::Type{Rational{T}},
) where {S,T}
return Rational{promote_sum_mul(S, T)}
function operate_to!(output::Rational, ::typeof(+), y::Integer, x::Rational)
return operate_to!(output, +, x, y)
end

# unary -

function operate_to!(output::Rational, ::typeof(-), x::Rational)
operate_to!(output.num, -, x.num)
operate_to!(output.den, copy, x.den)
return output
end

# binary -

function operate_to!(output::Rational, ::typeof(-), x::Rational, y::Rational)
xd, yd = Base.divgcd(promote(x.den, y.den)...)
# TODO: Use `checked_mul` and `checked_sub` like in Base
Expand All @@ -65,6 +93,22 @@ function operate_to!(output::Rational, ::typeof(-), x::Rational, y::Rational)
return output
end

function operate_to!(output::Rational, ::typeof(-), x::Rational, y::Integer)
# TODO Use `checked_mul` and `checked_sub` like in Base
operate_to!(output.num, *, x.den, y)
operate!(-, output.num)
operate!(+, output.num, x.num)
operate_to!(output.den, copy, x.den)
return output
end

function operate_to!(output::Rational, ::typeof(-), y::Integer, x::Rational)
# TODO Use `checked_mul` and `checked_sub` like in Base
operate_to!(output, -, x, y)
operate_to!(output, -, output)
return output
end

# *

function promote_operation(
Expand All @@ -75,6 +119,22 @@ function promote_operation(
return Rational{promote_operation(*, S, T)}
end

function promote_operation(
::typeof(*),
::Type{Rational{S}},
::Type{I},
) where {S,I<:Integer}
return promote_operation(*, Rational{S}, Rational{I})
end

function promote_operation(
::typeof(*),
::Type{I},
::Type{Rational{S}},
) where {S,I<:Integer}
return promote_operation(*, Rational{S}, Rational{I})
end

function operate_to!(output::Rational, ::typeof(*), x::Rational, y::Rational)
xn, yd = Base.divgcd(promote(x.num, y.den)...)
xd, yn = Base.divgcd(promote(x.den, y.num)...)
Expand All @@ -83,6 +143,69 @@ function operate_to!(output::Rational, ::typeof(*), x::Rational, y::Rational)
return output
end

function operate_to!(output::Rational, ::typeof(*), x::Rational, y::Integer)
xn = x.num
xd, yn = Base.divgcd(promote(x.den, y)...)
operate_to!(output.num, *, xn, yn)
operate_to!(output.den, copy, x.den)
return output
end

function operate_to!(output::Rational, ::typeof(*), y::Integer, x::Rational)
return operate_to!(output, *, x, y)
end

# //

function operate_to!(
output::Rational,
op::Union{typeof(/),typeof(//)},
x::Rational,
y::Rational,
)
xn, yn = Base.divgcd(promote(x.num, y.num)...)
xd, yd = Base.divgcd(promote(x.den, y.den)...)
operate_to!(output.num, *, xn, yd)
operate_to!(output.den, *, xd, yn)
return output
end

function operate_to!(
output::Rational,
op::Union{typeof(/),typeof(//)},
x::Rational,
y::Integer,
)
xn, yn = Base.divgcd(promote(x.num, y)...)
operate_to!(output.num, copy, xn)
operate_to!(output.den, *, x.den, yn)
return output
end

function operate_to!(
output::Rational,
op::Union{typeof(/),typeof(//)},
x::Integer,
y::Rational,
)
xn, yd = Base.divgcd(promote(x, y.den)...)
operate_to!(output.num, *, xn, yd)
operate_to!(output.den, copy, y.num)
return output
end

function operate_to!(
output::Rational,
op::Union{typeof(/),typeof(//)},
x::Integer,
y::Integer,
)
n, d = Base.divgcd(promote(x, y)...)
operate_to!(output.num, copy, n)
operate_to!(output.den, copy, d)
return output
end

# gcd

function promote_operation(
Expand Down
20 changes: 6 additions & 14 deletions src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,28 +20,22 @@ function promote_operation_fallback(
end

function promote_operation_fallback(
::typeof(/),
op::Function,
::Type{S},
::Type{T},
) where {S,T}
return typeof(zero(S) / oneunit(T))
U = Base.promote_op(op, S, T)
return return U == Union{} ? typeof(op(oneunit(S), oneunit(T))) : U
end

# Julia v1.0.x has trouble with inference with the `Vararg` method, see
# https://travis-ci.org/jump-dev/JuMP.jl/jobs/617606373
function promote_operation_fallback(
op::F,
::Type{S},
::Type{T},
) where {F<:Function,S,T}
return typeof(op(zero(S), zero(T)))
end

function promote_operation_fallback(
op::F,
args::Vararg{Type,N},
) where {F<:Function,N}
return typeof(op(zero.(args)...))
U = Base.promote_op(op, args...)
return return U == Union{} ? typeof(op(oneunit.(args)...)) : U
end

promote_operation_fallback(::typeof(*), ::Type{T}) where {T} = T
Expand Down Expand Up @@ -172,9 +166,7 @@ function operate(
) where {N}
return op(x, y, args...)
end

operate(op::Union{typeof(-),typeof(/)}, x, y) where {N} = op(x, y)

operate(op::Union{typeof(-),typeof(/),typeof(//)}, x, y) = op(x, y)
operate(::typeof(convert), ::Type{T}, x) where {T} = convert(T, x)

operate(::typeof(convert), ::Type{T}, x::T) where {T} = copy_if_mutable(x)
Expand Down
2 changes: 1 addition & 1 deletion test/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ end
if VERSION >= v"1.5"
# FIXME This should not allocate but I couldn't figure out where these
# 240 come from.
alloc_test(() -> MA.broadcast!!(+, a, b), 240)
alloc_test(() -> MA.broadcast!!(+, a, b), 80)
alloc_test(() -> MA.broadcast!!(+, a, c), 0)
end
end
21 changes: 21 additions & 0 deletions test/rational.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
for op in (+, -, *, //)
for (a,b) in (
(2 // 3, 5),
(2, 3 // 5),
(2 // 3, 5 // 7),
(big(2) // 3, 5),
(big(2), 3 // 5),
(big(2) // 3, 5 // 7),
)
@test MA.operate_to!!(MA.copy_if_mutable(op(a, b)), op, a, b) ==
op(a, b)
@test MA.operate_to!!(MA.copy_if_mutable(op(b, a)), op, b, a) ==
op(b, a)
end
end

op = //
for (a, b) in ((2, 3), (big(2), 3), (2, big(3)))
@test MA.operate_to!!(MA.copy_if_mutable(op(a, b)), op, a, b) == op(a, b)
@test MA.operate_to!!(MA.copy_if_mutable(op(b, a)), op, b, a) == op(b, a)
end
5 changes: 5 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,11 @@ end
@testset "BigInt" begin
include("big.jl")
end

@testset "Rational" begin
include("rational.jl")
end

@testset "Broadcast" begin
include("broadcast.jl")
end
Expand Down