Skip to content

Commit

Permalink
Merge pull request #513 from TuringLang/ref/compiler_rebased
Browse files Browse the repository at this point in the history
Refactoring of compiler
  • Loading branch information
willtebbutt authored Sep 14, 2018
2 parents 1baf3a1 + ededd96 commit 10de954
Show file tree
Hide file tree
Showing 29 changed files with 919 additions and 740 deletions.
1 change: 1 addition & 0 deletions REQUIRE
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ MCMCChain 0.1.1
Libtask 0.1.1
Flux 0.6.7
Stan
MacroTools

ProgressMeter 0.6.0
BinaryProvider 0.4.0
16 changes: 14 additions & 2 deletions src/Turing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ using LinearAlgebra
using ProgressMeter
using Markdown
using Libtask
using MacroTools

# @init @require Stan="682df890-35be-576f-97d0-3d8c8b33a550" begin
using Stan
Expand All @@ -37,6 +38,7 @@ import MCMCChain: AbstractChains, Chains
const ADBACKEND = Ref(:reverse_diff)
setadbackend(backend_sym) = begin
@assert backend_sym == :forward_diff || backend_sym == :reverse_diff
backend_sym == :forward_diff && CHUNKSIZE[] == 0 && setchunksize(40)
ADBACKEND[] = backend_sym
end

Expand Down Expand Up @@ -72,7 +74,7 @@ const CACHERANGES = 0b01

abstract type InferenceAlgorithm end
abstract type Hamiltonian <: InferenceAlgorithm end

abstract type AbstractSampler end
"""
Sampler{T}
Expand All @@ -85,11 +87,21 @@ An implementation of an algorithm should include the following:
Turing translates models to chunks that call the modelling functions at specified points. The dispatch is based on the value of a `sampler` variable. To include a new inference algorithm implements the requirements mentioned above in a separate file,
then include that file at the end of this one.
"""
mutable struct Sampler{T<:InferenceAlgorithm}
mutable struct Sampler{T<:InferenceAlgorithm} <: AbstractSampler
alg :: T
info :: Dict{Symbol, Any} # sampler infomation
end

"""
Robust initialization method for model parameters in Hamiltonian samplers.
"""
struct HamiltonianRobustInit <: AbstractSampler end
struct SampleFromPrior <: AbstractSampler end

# This can be removed when all `spl=nothing` is replaced with
# `spl=SampleFromPrior`
const AnySampler = Union{Nothing, AbstractSampler}

include("utilities/helper.jl")
include("utilities/transform.jl")
include("core/varinfo.jl") # core internal variable container
Expand Down
4 changes: 2 additions & 2 deletions src/core/ad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ function gradient_forward(
# Define function to compute log joint.
function f(θ)
vi[spl] = θ
return -runmodel(model, vi, spl).logp
return -runmodel!(model, vi, spl).logp
end

# Set chunk size and do ForwardMode.
Expand Down Expand Up @@ -57,7 +57,7 @@ function gradient_reverse(
# Specify objective function.
function f(θ)
vi[spl] = θ
return -runmodel(model, vi, spl).logp
return -runmodel!(model, vi, spl).logp
end

# Compute forward and reverse passes.
Expand Down
Loading

0 comments on commit 10de954

Please sign in to comment.