Skip to content

Commit

Permalink
add Compat for v0.4
Browse files Browse the repository at this point in the history
  • Loading branch information
jrevels committed Jun 14, 2016
1 parent 9c6ff5f commit 5652772
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 32 deletions.
1 change: 1 addition & 0 deletions REQUIRE
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
julia 0.4
Compat
Calculus
NaNMath
19 changes: 19 additions & 0 deletions src/ForwardDiff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,28 @@ isdefined(Base, :__precompile__) && __precompile__()

module ForwardDiff

using Compat

import Calculus
import NaNMath

#######################
# compatibility patch #
#######################

if v"0.4" <= VERSION < v"0.5-"
# e.g. @operator Base.:op -> Base.(:op)
macro operator(qualified_name)
func_name = qualified_name.args[2].args[1]
qualified_name.args[2] = func_name
return qualified_name
end
else
macro operator(qualified_name)
return qualified_name
end
end

#############################
# types/functions/constants #
#############################
Expand Down
14 changes: 7 additions & 7 deletions src/Partials.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ Base.done(partials::Partials, i) = done(partials.values, i)
@inline Base.rand{N,T}(rng::AbstractRNG, ::Type{Partials{N,T}}) = Partials(rand_tuple(rng, NTuple{N,T}))

Base.isequal{N}(a::Partials{N}, b::Partials{N}) = isequal(a.values, b.values)
Base.:(==){N}(a::Partials{N}, b::Partials{N}) = a.values == b.values
@operator(Base.:(==)){N}(a::Partials{N}, b::Partials{N}) = a.values == b.values

const PARTIALS_HASH = hash(Partials)

Expand Down Expand Up @@ -66,12 +66,12 @@ Base.convert{N,T}(::Type{Partials{N,T}}, partials::Partials{N,T}) = partials
# Arithmetic Functions #
########################

@inline Base.:+{N}(a::Partials{N}, b::Partials{N}) = Partials(add_tuples(a.values, b.values))
@inline Base.:-{N}(a::Partials{N}, b::Partials{N}) = Partials(sub_tuples(a.values, b.values))
@inline Base.:-(partials::Partials) = Partials(minus_tuple(partials.values))
@inline Base.:*(partials::Partials, x::Real) = Partials(scale_tuple(partials.values, x))
@inline Base.:*(x::Real, partials::Partials) = partials*x
@inline Base.:/(partials::Partials, x::Real) = Partials(div_tuple_by_scalar(partials.values, x))
@inline @operator(Base.:+){N}(a::Partials{N}, b::Partials{N}) = Partials(add_tuples(a.values, b.values))
@inline @operator(Base.:-){N}(a::Partials{N}, b::Partials{N}) = Partials(sub_tuples(a.values, b.values))
@inline @operator(Base.:-)(partials::Partials) = Partials(minus_tuple(partials.values))
@inline @operator(Base.:*)(partials::Partials, x::Real) = Partials(scale_tuple(partials.values, x))
@inline @operator(Base.:*)(x::Real, partials::Partials) = partials*x
@inline @operator(Base.:/)(partials::Partials, x::Real) = Partials(div_tuple_by_scalar(partials.values, x))

@inline function _mul_partials{N}(a::Partials{N}, b::Partials{N}, afactor, bfactor)
return Partials(mul_tuples(a.values, b.values, afactor, bfactor))
Expand Down
50 changes: 25 additions & 25 deletions src/dual.jl
Original file line number Diff line number Diff line change
Expand Up @@ -103,26 +103,26 @@ end
isconstant(n::Dual) = iszero(partials(n))

@ambiguous Base.isequal{N}(a::Dual{N}, b::Dual{N}) = isequal(value(a), value(b))
@ambiguous Base.:(==){N}(a::Dual{N}, b::Dual{N}) = value(a) == value(b)
@ambiguous @operator(Base.:(==)){N}(a::Dual{N}, b::Dual{N}) = value(a) == value(b)
@ambiguous Base.isless{N}(a::Dual{N}, b::Dual{N}) = value(a) < value(b)
@ambiguous Base.:<{N}(a::Dual{N}, b::Dual{N}) = isless(a, b)
@ambiguous Base.:(<=){N}(a::Dual{N}, b::Dual{N}) = <=(value(a), value(b))
@ambiguous @operator(Base.:<){N}(a::Dual{N}, b::Dual{N}) = isless(a, b)
@ambiguous @operator(Base.:(<=)){N}(a::Dual{N}, b::Dual{N}) = <=(value(a), value(b))

for T in (AbstractFloat, Irrational, Real)
Base.isequal(n::Dual, x::T) = isequal(value(n), x)
Base.isequal(x::T, n::Dual) = isequal(n, x)

Base.:(==)(n::Dual, x::T) = (value(n) == x)
Base.:(==)(x::T, n::Dual) = ==(n, x)
@operator(Base.:(==))(n::Dual, x::T) = (value(n) == x)
@operator(Base.:(==))(x::T, n::Dual) = ==(n, x)

Base.isless(n::Dual, x::T) = value(n) < x
Base.isless(x::T, n::Dual) = x < value(n)

Base.:<(n::Dual, x::T) = isless(n, x)
Base.:<(x::T, n::Dual) = isless(x, n)
@operator(Base.:<)(n::Dual, x::T) = isless(n, x)
@operator(Base.:<)(x::T, n::Dual) = isless(x, n)

Base.:(<=)(n::Dual, x::T) = <=(value(n), x)
Base.:(<=)(x::T, n::Dual) = <=(x, value(n))
@operator(Base.:(<=))(n::Dual, x::T) = <=(value(n), x)
@operator(Base.:(<=))(x::T, n::Dual) = <=(x, value(n))
end

Base.isnan(n::Dual) = isnan(value(n))
Expand Down Expand Up @@ -167,49 +167,49 @@ Base.float{N,T}(n::Dual{N,T}) = Dual{N,promote_type(T, Float16)}(n)
# Addition/Subtraction #
#----------------------#

@ambiguous @inline Base.:+{N}(n1::Dual{N}, n2::Dual{N}) = Dual(value(n1) + value(n2), partials(n1) + partials(n2))
@inline Base.:+(n::Dual, x::Real) = Dual(value(n) + x, partials(n))
@inline Base.:+(x::Real, n::Dual) = n + x
@ambiguous @inline @operator(Base.:+){N}(n1::Dual{N}, n2::Dual{N}) = Dual(value(n1) + value(n2), partials(n1) + partials(n2))
@inline @operator(Base.:+)(n::Dual, x::Real) = Dual(value(n) + x, partials(n))
@inline @operator(Base.:+)(x::Real, n::Dual) = n + x

@ambiguous @inline Base.:-{N}(n1::Dual{N}, n2::Dual{N}) = Dual(value(n1) - value(n2), partials(n1) - partials(n2))
@inline Base.:-(n::Dual, x::Real) = Dual(value(n) - x, partials(n))
@inline Base.:-(x::Real, n::Dual) = Dual(x - value(n), -(partials(n)))
@inline Base.:-(n::Dual) = Dual(-(value(n)), -(partials(n)))
@ambiguous @inline @operator(Base.:-){N}(n1::Dual{N}, n2::Dual{N}) = Dual(value(n1) - value(n2), partials(n1) - partials(n2))
@inline @operator(Base.:-)(n::Dual, x::Real) = Dual(value(n) - x, partials(n))
@inline @operator(Base.:-)(x::Real, n::Dual) = Dual(x - value(n), -(partials(n)))
@inline @operator(Base.:-)(n::Dual) = Dual(-(value(n)), -(partials(n)))

# Multiplication #
#----------------#

@inline Base.:*(n::Dual, x::Bool) = x ? n : (signbit(value(n))==0 ? zero(n) : -zero(n))
@inline Base.:*(x::Bool, n::Dual) = n * x
@inline @operator(Base.:*)(n::Dual, x::Bool) = x ? n : (signbit(value(n))==0 ? zero(n) : -zero(n))
@inline @operator(Base.:*)(x::Bool, n::Dual) = n * x

@ambiguous @inline function Base.:*{N}(n1::Dual{N}, n2::Dual{N})
@ambiguous @inline function @operator(Base.:*){N}(n1::Dual{N}, n2::Dual{N})
v1, v2 = value(n1), value(n2)
return Dual(v1 * v2, _mul_partials(partials(n1), partials(n2), v2, v1))
end

@inline Base.:*(n::Dual, x::Real) = Dual(value(n) * x, partials(n) * x)
@inline Base.:*(x::Real, n::Dual) = n * x
@inline @operator(Base.:*)(n::Dual, x::Real) = Dual(value(n) * x, partials(n) * x)
@inline @operator(Base.:*)(x::Real, n::Dual) = n * x

# Division #
#----------#

@ambiguous @inline function Base.:/{N}(n1::Dual{N}, n2::Dual{N})
@ambiguous @inline function @operator(Base.:/){N}(n1::Dual{N}, n2::Dual{N})
v1, v2 = value(n1), value(n2)
return Dual(v1 / v2, _div_partials(partials(n1), partials(n2), v1, v2))
end

@inline function Base.:/(x::Real, n::Dual)
@inline function @operator(Base.:/)(x::Real, n::Dual)
v = value(n)
divv = x / v
return Dual(divv, -(divv / v) * partials(n))
end

@inline Base.:/(n::Dual, x::Real) = Dual(value(n) / x, partials(n) / x)
@inline @operator(Base.:/)(n::Dual, x::Real) = Dual(value(n) / x, partials(n) / x)

# Exponentiation #
#----------------#

for f in (:(Base.:^), :(NaNMath.pow))
for f in (macroexpand(:(@operator(Base.:^))), :(NaNMath.pow))
@eval begin
@ambiguous @inline function ($f){N}(n1::Dual{N}, n2::Dual{N})
if iszero(partials(n2))
Expand Down

0 comments on commit 5652772

Please sign in to comment.