Skip to content

Commit

Permalink
Merge pull request #463 from probcomp/mrb/kernel_dsl_arg_fix
Browse files Browse the repository at this point in the history
(Kernel DSL) Revert esc change toplevel, include GlobalRef to handle evaluation ambiguity in modules/submodules.
  • Loading branch information
ztangent authored May 17, 2022
2 parents 7b6a884 + 16b818b commit f9aef08
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 32 deletions.
76 changes: 44 additions & 32 deletions src/inference/kernel_dsl.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,11 @@ end
is_custom_primitive_kernel(::Function) = false
check_is_kernel(::Function) = false

const _is_custom_primitive_kernel = GlobalRef(Gen, :is_custom_primitive_kernel)
const _check_is_kernel = GlobalRef(Gen, :check_is_kernel)
const _check_observations = GlobalRef(Gen, :check_observations)
const _reversal = GlobalRef(Gen, :reversal)

"""
@pkern function k(trace, args...;
check=false, observations=EmptyChoiceMap())
Expand All @@ -22,24 +27,27 @@ check_is_kernel(::Function) = false
Declare a Julia function as a primitive stationary kernel.
The first argument of the function should be a trace, and the return value of the function should be a trace.
The first argument of the function should be a trace, and the return value of the function should be a `(trace, metadata)` where `metadata` is user-provided (but could be useful information, like the result of an accept-reject decision.
There should be keyword arguments check and observations.
"""
macro pkern(ex)
@capture(ex, function f_(args__) body_ end) || error("expected a function")
quote
function $(esc(f))($(args...))
$body
end
Gen.is_custom_primitive_kernel(::typeof($(esc(f)))) = true
Gen.check_is_kernel(::typeof($(esc(f)))) = true
expr = quote
function $(f)($(args...))
$body
end
$(_is_custom_primitive_kernel)(::typeof($f)) = true
$(_check_is_kernel)(::typeof($f)) = true
end
expr = postwalk(flatten unblock rmlines, expr)
esc(expr)
end

function expand_kern_ex(ex)
@capture(ex, function f_(T_, toplevel_args__) body_ end) || error("expected kernel syntax: function my_kernel(trace, args...) (...) end")
trace = T

body = postwalk(body) do x
if @capture(x, for idx_ in range_ body_ end)

Expand Down Expand Up @@ -88,10 +96,9 @@ function expand_kern_ex(ex)
end

elseif @capture(x, T_ ~ k_(T_, args__))

# applying a kernel
quote
check && Gen.check_is_kernel($(k))
check && $(_check_is_kernel)($(k))
$(T) = $(k)($(T), $(args...),
check=check, observations=observations)[1]
end
Expand All @@ -104,14 +111,14 @@ function expand_kern_ex(ex)
end

ex = quote
function $(esc(f))($(trace)::Trace, $(toplevel_args...);
function $(f)($(trace)::Trace, $(toplevel_args...);
check=false, observations=EmptyChoiceMap())
$body
check && Gen.check_observations(get_choices($(trace)), observations)
check && $(_check_observations)(get_choices($(trace)), observations)
metadata = nothing
($(trace), metadata)
end
Gen.check_is_kernel(::typeof($(esc(f)))) = true
$(_check_observations)(::typeof($f)) = true
end

ex, f
Expand All @@ -137,10 +144,10 @@ The two kernels must have the same argument type signatures.
macro rkern(ex)
@capture(ex, k_ : l_) || error("expected a pair of functions")
quote
Gen.is_custom_primitive_kernel($(esc(k))) || error("first function is not a custom primitive kernel")
Gen.is_custom_primitive_kernel($(esc(l))) || error("second function is not a custom primitive kernel")
Gen.reversal(::typeof($(esc(k)))) = $(esc(l))
Gen.reversal(::typeof($(esc(l)))) = $(esc(k))
$(_is_custom_primitive_kernel)($k) || error("first function is not a custom primitive kernel")
$(_is_custom_primitive_kernel)($l) || error("second function is not a custom primitive kernel")
$(_reversal)(::typeof($k)) = $(l)
$(_reversal)(::typeof($l)) = $(k)
end
end

Expand All @@ -163,9 +170,9 @@ function reversal_ex(ex)

# applying a kernel - apply the reverse kernel
quote
check && Gen.check_is_kernel(Gen.reversal($k))
$(T) = Gen.reversal($k)($(T), $(args...),
check=check, observations=observations)[1]
check && $(_check_is_kernel)($(_reversal)($k))
$(T) = $(_reversal)($k)($(T), $(args...),
check=check, observations=observations)[1]
end

elseif isa(x, Expr) && x.head == :block
Expand All @@ -191,16 +198,7 @@ function reversal_ex(ex)
ex, rev
end

"""
@kern function k(trace, args...)
...
end
Construct a composite MCMC kernel.
The resulting object is a Julia function that is annotated as a composite MCMC kernel, and can be called as a Julia function or applied within other composite kernels.
"""
macro kern(ex)
function toplevel(ex::Expr)
kern_ex, kern = expand_kern_ex(ex)
rev_kern_ex, rev_kern = reversal_ex(ex)
rev_kern_ex, _ = expand_kern_ex(rev_kern_ex)
Expand All @@ -212,11 +210,25 @@ macro kern(ex)
$rev_kern_ex

# bind the reversals for both
Gen.reversal(::typeof($(esc(kern)))) = $(esc(rev_kern))
Gen.reversal(::typeof($(esc(rev_kern)))) = $(esc(kern))
$(_reversal)(::typeof($kern)) = $(rev_kern)
$(_reversal)(::typeof($rev_kern)) = $(kern)
end
expr = postwalk(flatten unblock rmlines, expr)
expr
end

"""
@kern function k(trace, args...)
...
end
Construct a composite MCMC kernel.
The resulting object is a Julia function that is annotated as a composite MCMC kernel, and can be called as a Julia function or applied within other composite kernels.
"""
macro kern(ex)
expr = toplevel(ex)
esc(expr)
end

export @pkern, @rkern, @kern, reversal
1 change: 1 addition & 0 deletions test/inference/inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@ include("map_optimize.jl")
include("elliptical_slice.jl")
include("mh.jl")
include("trace_translators.jl")
include("kernel_dsl.jl")
40 changes: 40 additions & 0 deletions test/inference/kernel_dsl.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
@gen function kernel_dsl_test_model()
x ~ normal(0.0, 1.0)
y ~ normal(x, 1.0)
return y
end

@pkern function k1(tr, t)
metadata = false
for _ in 1 : t
tr, acc = mh(tr, select(:x))
if acc
metadata = true
end
end
return tr, metadata
end

@gen function k1_proposal(tr)
x ~ normal(0.0, 1.0)
end

@kern function k1_non_primitive(tr, t)
for _ in 1 : t
tr ~ mh(tr, k1_proposal, ())
end
end

@testset "Kernel DSL" begin
tr = simulate(kernel_dsl_test_model, ())
original = get_choices(tr)[:x]
tr, acc = k1(tr, 1)
if acc
@test original != get_choices(tr)[:x]
else
@test original == get_choices(tr)[:x]
end
original = get_choices(tr)[:x]
tr, acc = k1_non_primitive(tr, 1)
tr, acc = Gen.reversal(k1_non_primitive)(tr, 1)
end

0 comments on commit f9aef08

Please sign in to comment.