Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactoring of compiler #513

Merged
merged 38 commits into from
Sep 14, 2018
Merged

Refactoring of compiler #513

merged 38 commits into from
Sep 14, 2018

Conversation

trappmartin
Copy link
Member

@trappmartin trappmartin commented Sep 10, 2018

This is a major refactoring of the compiler code. Currently, this PR is work in progress and should not be merged. The PR should be finished soon.

Todo:

@willtebbutt willtebbutt self-requested a review September 10, 2018 14:27
@trappmartin trappmartin changed the title Refactoring of compiler Refactoring of compiler (WIP) Sep 10, 2018
Copy link
Member

@willtebbutt willtebbutt left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks great, much cleaner than before. The thing that I feel most strongly about regards the use of MacroTools to make some code more robust and to avoid reinventing the wheel. Please see my specific comments on that.


Extract function name, arguments and body.
"""
function extractcomponents(fexpr::Expr)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please consider using MacroTools.jl here. It's splitdef function is designed precisely to do the job of this function.

src/core/compiler.jl Outdated Show resolved Hide resolved

function insdelim(c, deli=",")
reduce((e, res) -> append!(e, [res, deli]), c; init = [])[1:end-1]
reduce((e, res) -> append!(e, [res, deli]), c; init = [])[1:end-1]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

return statement please. Also, this function appears to only be used in is_inside.jl, so perhaps we could move it to that file?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point, I'll move it into is_inside.jl

while isa(curr, Expr) && curr.head == :ref
curr = curr.args[1]
end
curr
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

return

else
map(translate!, ex.args)
end
ex
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

return

fname_inner = Symbol("$(fname)_model")
fname_inner_str = string(fname_inner)
fdefn_inner_func = constructfunc(
fname_inner,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Indentation (again, sorry)

src/core/compiler.jl Outdated Show resolved Hide resolved
push!(fargs_outer[1].args, Expr(:kw, :compiler, compiler))

# turn f(x;..) into f(x=nothing;..)
for i = 2:length(fargs_outer)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(Personal preference) This is maybe a cleaner version:

fargs_outer = map(x -> x isa Symbol ? Expr(:kw, x, :nothing) : x, fargs_outer[2:end])

src/core/compiler.jl Outdated Show resolved Hide resolved
src/core/compiler.jl Outdated Show resolved Hide resolved
@trappmartin
Copy link
Member Author

Thanks for reviewing, I'll consider using MacroTools to improve the code.

@willtebbutt
Copy link
Member

Regarding the @VarName macro, here's a sketch of the function that I would propose to replace it with:

make_varname(expr::Symbol) = (expr, (), gensym())
function make_varname(expr::Expr)
    @assert expr.head ==  :ref "VarName: Malformed variable name $(expr)"
    return _make_varname(expr, Vector{Union{Symbol, Expr}}())..., gensym()
end
function _make_varname(expr::Expr, indices::Vector{Union{Symbol, Expr}})
    if expr.args[1] isa Symbol
        return expr.args[1], vcat(Symbol(expr.args[2]), indices)
    else
        return _make_varname(expr.args[1], vcat(Symbol(expr.args[2]), indices))
    end
end

(named make_varname to distinguish from the other varname function that we have.)

This doesn't quite work correctly, in particular the type of the second thing returned isn't quite the same as what @VarName currently returns, but it shouldn't be hard to refactor to make it return the right thing. eg. run make_varname(:(x[5][4])) and @VarName x[5][4] to see the difference.

@yebai
Copy link
Member

yebai commented Sep 10, 2018

Regarding the @varname macro, here's a sketch of the function that I would propose to replace it with:

There are cases where a function isn't enough, and a macro is required. Consider the following case

x = 1
y = 2
a = Array(undef, 4,4)

a[x,y] ~ Normal(0,1)

Here the indexing only becomes available at runtime. So we have to generate the variable name using a macro.

@trappmartin trappmartin added this to the 0.5 milestone Sep 10, 2018
@willtebbutt
Copy link
Member

@yebai I completely agree that we require a way to construct indices dynamically, but disagree that it necessitates the extra macro. I'll open a separate issue to discuss so that we can resolve this in a separate PR.

@yebai
Copy link
Member

yebai commented Sep 12, 2018

Hi @trappmartin, here is an updated compiler design after removing model_f(;data=Dict()) interface (#475). I found this very helpful when reasoning about closures. Hope this helps for refactoring the code.

    # Compiler design: sample(f_compiletime(x,y), sampler)
    #   f_compiletime(x, y; compiler=compiler) = begin
    #      ex = quote
    #          f_runtime(vi::VarInfo, sampler::Sampler; x = x, y = y) = begin
    #              # pour model definition `fbody`, e.g.
    #              x ~ Normal(0,1)
    #              y ~ Normal(x, 1)
    #          end
    #      end
    #      Main.eval(ex)
    #   end

@willtebbutt willtebbutt mentioned this pull request Sep 13, 2018
@trappmartin
Copy link
Member Author

trappmartin commented Sep 13, 2018

The current code breaks the ad.jl tests. I'm not sure why this is the case and I don't recall observing this before the last rebase.

ad1.jl: Error During Test at /home/travis/build/TuringLang/Turing.jl/test/utility.jl:73
  Got exception outside of a @test
  LoadError: DivideError: integer division error
  Stacktrace:
   [1] rem at ./int.jl:233 [inlined]
   [2] chunk_mode_gradient!(::Array{Float64,1}, ::getfield(Turing, Symbol("#f#84")){Turing.VarReplay.VarInfo,typeof(ad_test_model),Nothing}, ::Array{Float64,1}, ::ForwardDiff.GradientConfig{ForwardDiff.Tag{getfield(Turing, Symbol("#f#84")){Turing.VarReplay.VarInfo,typeof(ad_test_model),Nothing},Float64},Float64,0,Array{Dual{ForwardDiff.Tag{getfield(Turing, Symbol("#f#84")){Turing.VarReplay.VarInfo,typeof(ad_test_model),Nothing},Float64},Float64,0},1}}) at /home/travis/.julia/packages/ForwardDiff/hnKaN/src/gradient.jl:127
   [3] gradient!(::Array{Float64,1}, ::getfield(Turing, Symbol("#f#84")){Turing.VarReplay.VarInfo,typeof(ad_test_model),Nothing}, ::Array{Float64,1}, ::ForwardDiff.GradientConfig{ForwardDiff.Tag{getfield(Turing, Symbol("#f#84")){Turing.VarReplay.VarInfo,typeof(ad_test_model),Nothing},Float64},Float64,0,Array{Dual{ForwardDiff.Tag{getfield(Turing, Symbol("#f#84")){Turing.VarReplay.VarInfo,typeof(ad_test_model),Nothing},Float64},Float64,0},1}}) at /home/travis/.julia/packages/ForwardDiff/hnKaN/src/gradient.jl:37
   [4] gradient_forward(::Array{Float64,1}, ::Turing.VarReplay.VarInfo, ::Function, ::Nothing, ::Int64) at /home/travis/build/TuringLang/Turing.jl/src/core/ad.jl:32
   [5] gradient_forward(::Array{Float64,1}, ::Turing.VarReplay.VarInfo, ::Function) at /home/travis/build/TuringLang/Turing.jl/src/core/ad.jl:21
  ad.jl                                    |    6      2              8
    AD_compatibility_with_distributions.jl |                      No tests
    ad1.jl                                 |           1              1
    ad2.jl                                 |                      No tests
    ad3.jl                                 |           1              1
    pass_dual_to_dists.jl                  |    6                     6

Does anybody have an idea why the ad tests break?

cc: @yebai @xukai92

@willtebbutt
Copy link
Member

Looks like CHUNK_SIZE isn't getting changed from zero for some reason. Just changing the initialisation for this constant to something other than zero (10?) sorts it out.

const CHUNKSIZE = Ref(10) # default chunksize used by AD

on line 51 of Turing.jl

@trappmartin trappmartin changed the title Refactoring of compiler (WIP) Refactoring of compiler Sep 13, 2018
@trappmartin
Copy link
Member Author

Thanks, @yebai! The assume function calls look much better now.

@willtebbutt willtebbutt merged commit 8b7d753 into master Sep 14, 2018
@yebai yebai deleted the ref/compiler_rebased branch September 14, 2018 19:33
yebai pushed a commit that referenced this pull request Sep 18, 2018
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants