From 49fdf2b30866ad24f737c8e8975ea2a951d5420e Mon Sep 17 00:00:00 2001 From: McCoy Becker Date: Thu, 14 Apr 2022 10:29:30 -0400 Subject: [PATCH 1/6] Revert esc change toplevel, include GlobalRef to handle evaluation ambiguity in modules/submodules. --- src/inference/kernel_dsl.jl | 72 +++++++++++++++++++++++-------------- 1 file changed, 46 insertions(+), 26 deletions(-) diff --git a/src/inference/kernel_dsl.jl b/src/inference/kernel_dsl.jl index cf028273..fbb209ac 100644 --- a/src/inference/kernel_dsl.jl +++ b/src/inference/kernel_dsl.jl @@ -27,12 +27,14 @@ There should be keyword arguments check and observations. """ macro pkern(ex) @capture(ex, function f_(args__) body_ end) || error("expected a function") + gr1 = GlobalRef(Gen, :is_custom_primitive_kernel) + gr2 = GlobalRef(Gen, :check_is_kernel) quote - function $(esc(f))($(args...)) + function $(f)($(args...)) $body end - Gen.is_custom_primitive_kernel(::typeof($(esc(f)))) = true - Gen.check_is_kernel(::typeof($(esc(f)))) = true + $(gr1)(::typeof($f)) = true + $(gr2)(::typeof($f)) = true end end @@ -88,10 +90,10 @@ function expand_kern_ex(ex) end elseif @capture(x, T_ ~ k_(T_, args__)) - # applying a kernel + gr = GlobalRef(Gen, :check_is_kernel) quote - check && Gen.check_is_kernel($(k)) + check && $(gr)($(k)) $(T) = $(k)($(T), $(args...), check=check, observations=observations)[1] end @@ -103,15 +105,17 @@ function expand_kern_ex(ex) end end + gr1 = GlobalRef(Gen, :check_is_kernel) + gr2 = GlobalRef(Gen, :check_observations) ex = quote - function $(esc(f))($(trace)::Trace, $(toplevel_args...); + function $(f)($(trace)::Trace, $(toplevel_args...); check=false, observations=EmptyChoiceMap()) $body - check && Gen.check_observations(get_choices($(trace)), observations) + check && $(gr2)(get_choices($(trace)), observations) metadata = nothing ($(trace), metadata) end - Gen.check_is_kernel(::typeof($(esc(f)))) = true + $(gr1)(::typeof($f)) = true end ex, f @@ -136,11 +140,13 @@ The two kernels must have the same argument type signatures. """ macro rkern(ex) @capture(ex, k_ : l_) || error("expected a pair of functions") + gr1 = GlobalRef(Gen, :is_custom_primitive_kernel) + gr2 = GlobalRef(Gen, :reversal) quote - Gen.is_custom_primitive_kernel($(esc(k))) || error("first function is not a custom primitive kernel") - Gen.is_custom_primitive_kernel($(esc(l))) || error("second function is not a custom primitive kernel") - Gen.reversal(::typeof($(esc(k)))) = $(esc(l)) - Gen.reversal(::typeof($(esc(l)))) = $(esc(k)) + $(gr1)($k) || error("first function is not a custom primitive kernel") + $(gr1)($l) || error("second function is not a custom primitive kernel") + $(gr2)(::typeof($k)) = $(l) + $(gr2)(::typeof($l)) = $(k) end end @@ -162,9 +168,11 @@ function reversal_ex(ex) elseif @capture(x, T_ ~ k_(T_, args__)) # applying a kernel - apply the reverse kernel + gr1 = GlobalRef(Gen, :check_is_kernel) + gr2 = GlobalRef(Gen, :reversal) quote - check && Gen.check_is_kernel(Gen.reversal($k)) - $(T) = Gen.reversal($k)($(T), $(args...), + check && $(gr1)($(gr2)($k)) + $(T) = $(gr2)($k)($(T), $(args...), check=check, observations=observations)[1] end @@ -191,19 +199,11 @@ function reversal_ex(ex) ex, rev end -""" - @kern function k(trace, args...) - ... - end - -Construct a composite MCMC kernel. - -The resulting object is a Julia function that is annotated as a composite MCMC kernel, and can be called as a Julia function or applied within other composite kernels. -""" -macro kern(ex) +function toplevel(ex::Expr) kern_ex, kern = expand_kern_ex(ex) rev_kern_ex, rev_kern = reversal_ex(ex) rev_kern_ex, _ = expand_kern_ex(rev_kern_ex) + gr = GlobalRef(Gen, :reversal) expr = quote # define forward kerel $kern_ex @@ -212,11 +212,31 @@ macro kern(ex) $rev_kern_ex # bind the reversals for both - Gen.reversal(::typeof($(esc(kern)))) = $(esc(rev_kern)) - Gen.reversal(::typeof($(esc(rev_kern)))) = $(esc(kern)) + $(gr)(::typeof($kern)) = $(rev_kern) + $(gr)(::typeof($rev_kern)) = $(kern) end expr = postwalk(flatten ∘ unblock ∘ rmlines, expr) expr end +""" + @kern function k(trace, args...) + ... + end + +Construct a composite MCMC kernel. + +The resulting object is a Julia function that is annotated as a composite MCMC kernel, and can be called as a Julia function or applied within other composite kernels. +""" +macro kern(ex) + expr = toplevel(ex) + esc(expr) +end + +macro kern(debug_flag, ex) + expr = toplevel(ex) + debug_flag == :debug && display(expr) + esc(expr) +end + export @pkern, @rkern, @kern, reversal From 596651208e3827dedc8860029134a20ae26cb0ba Mon Sep 17 00:00:00 2001 From: McCoy Becker Date: Thu, 14 Apr 2022 11:22:38 -0400 Subject: [PATCH 2/6] Revert esc change toplevel, include GlobalRef to handle evaluation ambiguity in modules/submodules. --- src/inference/kernel_dsl.jl | 49 +++++++++++++++++-------------------- 1 file changed, 22 insertions(+), 27 deletions(-) diff --git a/src/inference/kernel_dsl.jl b/src/inference/kernel_dsl.jl index fbb209ac..d1191fc1 100644 --- a/src/inference/kernel_dsl.jl +++ b/src/inference/kernel_dsl.jl @@ -13,6 +13,11 @@ end is_custom_primitive_kernel(::Function) = false check_is_kernel(::Function) = false +_is_custom_primitive_kernel = GlobalRef(Gen, :is_custom_primitive_kernel) +_check_is_kernel = GlobalRef(Gen, :check_is_kernel) +_check_observations = GlobalRef(Gen, :check_observations) +_reversal = GlobalRef(Gen, :reversal) + """ @pkern function k(trace, args...; check=false, observations=EmptyChoiceMap()) @@ -27,21 +32,19 @@ There should be keyword arguments check and observations. """ macro pkern(ex) @capture(ex, function f_(args__) body_ end) || error("expected a function") - gr1 = GlobalRef(Gen, :is_custom_primitive_kernel) - gr2 = GlobalRef(Gen, :check_is_kernel) quote function $(f)($(args...)) - $body - end - $(gr1)(::typeof($f)) = true - $(gr2)(::typeof($f)) = true + $body + end + $(is_custom_primitive_kernel)(::typeof($f)) = true + $(_check_is_kernel)(::typeof($f)) = true end end function expand_kern_ex(ex) @capture(ex, function f_(T_, toplevel_args__) body_ end) || error("expected kernel syntax: function my_kernel(trace, args...) (...) end") trace = T - + body = postwalk(body) do x if @capture(x, for idx_ in range_ body_ end) @@ -91,9 +94,8 @@ function expand_kern_ex(ex) elseif @capture(x, T_ ~ k_(T_, args__)) # applying a kernel - gr = GlobalRef(Gen, :check_is_kernel) quote - check && $(gr)($(k)) + check && $(_check_is_kernel)($(k)) $(T) = $(k)($(T), $(args...), check=check, observations=observations)[1] end @@ -105,17 +107,15 @@ function expand_kern_ex(ex) end end - gr1 = GlobalRef(Gen, :check_is_kernel) - gr2 = GlobalRef(Gen, :check_observations) ex = quote function $(f)($(trace)::Trace, $(toplevel_args...); check=false, observations=EmptyChoiceMap()) $body - check && $(gr2)(get_choices($(trace)), observations) + check && $(_check_observations)(get_choices($(trace)), observations) metadata = nothing ($(trace), metadata) end - $(gr1)(::typeof($f)) = true + $(_check_observations)(::typeof($f)) = true end ex, f @@ -140,13 +140,11 @@ The two kernels must have the same argument type signatures. """ macro rkern(ex) @capture(ex, k_ : l_) || error("expected a pair of functions") - gr1 = GlobalRef(Gen, :is_custom_primitive_kernel) - gr2 = GlobalRef(Gen, :reversal) quote - $(gr1)($k) || error("first function is not a custom primitive kernel") - $(gr1)($l) || error("second function is not a custom primitive kernel") - $(gr2)(::typeof($k)) = $(l) - $(gr2)(::typeof($l)) = $(k) + $(_is_custom_primitive_kernel)($k) || error("first function is not a custom primitive kernel") + $(_is_custom_primitive_kernel)($l) || error("second function is not a custom primitive kernel") + $(_reversal)(::typeof($k)) = $(l) + $(_reversal)(::typeof($l)) = $(k) end end @@ -168,12 +166,10 @@ function reversal_ex(ex) elseif @capture(x, T_ ~ k_(T_, args__)) # applying a kernel - apply the reverse kernel - gr1 = GlobalRef(Gen, :check_is_kernel) - gr2 = GlobalRef(Gen, :reversal) quote - check && $(gr1)($(gr2)($k)) - $(T) = $(gr2)($k)($(T), $(args...), - check=check, observations=observations)[1] + check && $(_check_is_kernel)($(_reversal)($k)) + $(T) = $(_reversal)($k)($(T), $(args...), + check=check, observations=observations)[1] end elseif isa(x, Expr) && x.head == :block @@ -203,7 +199,6 @@ function toplevel(ex::Expr) kern_ex, kern = expand_kern_ex(ex) rev_kern_ex, rev_kern = reversal_ex(ex) rev_kern_ex, _ = expand_kern_ex(rev_kern_ex) - gr = GlobalRef(Gen, :reversal) expr = quote # define forward kerel $kern_ex @@ -212,8 +207,8 @@ function toplevel(ex::Expr) $rev_kern_ex # bind the reversals for both - $(gr)(::typeof($kern)) = $(rev_kern) - $(gr)(::typeof($rev_kern)) = $(kern) + $(_reversal)(::typeof($kern)) = $(rev_kern) + $(_reversal)(::typeof($rev_kern)) = $(kern) end expr = postwalk(flatten ∘ unblock ∘ rmlines, expr) expr From 51f2692dd1acc15c273864612208cfb1f4ac8f38 Mon Sep 17 00:00:00 2001 From: McCoy Becker Date: Thu, 14 Apr 2022 11:23:42 -0400 Subject: [PATCH 3/6] Revert esc change toplevel, include GlobalRef to handle evaluation ambiguity in modules/submodules. --- src/inference/kernel_dsl.jl | 6 ------ 1 file changed, 6 deletions(-) diff --git a/src/inference/kernel_dsl.jl b/src/inference/kernel_dsl.jl index d1191fc1..9627df74 100644 --- a/src/inference/kernel_dsl.jl +++ b/src/inference/kernel_dsl.jl @@ -228,10 +228,4 @@ macro kern(ex) esc(expr) end -macro kern(debug_flag, ex) - expr = toplevel(ex) - debug_flag == :debug && display(expr) - esc(expr) -end - export @pkern, @rkern, @kern, reversal From 09c461e236fbb4b5e12e358c3b3adb5031beb8a4 Mon Sep 17 00:00:00 2001 From: "McCoy R. Becker" Date: Mon, 16 May 2022 09:04:14 -0400 Subject: [PATCH 4/6] Added const. --- src/inference/kernel_dsl.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/inference/kernel_dsl.jl b/src/inference/kernel_dsl.jl index 9627df74..a0a340a2 100644 --- a/src/inference/kernel_dsl.jl +++ b/src/inference/kernel_dsl.jl @@ -13,10 +13,10 @@ end is_custom_primitive_kernel(::Function) = false check_is_kernel(::Function) = false -_is_custom_primitive_kernel = GlobalRef(Gen, :is_custom_primitive_kernel) -_check_is_kernel = GlobalRef(Gen, :check_is_kernel) -_check_observations = GlobalRef(Gen, :check_observations) -_reversal = GlobalRef(Gen, :reversal) +const _is_custom_primitive_kernel = GlobalRef(Gen, :is_custom_primitive_kernel) +const _check_is_kernel = GlobalRef(Gen, :check_is_kernel) +const _check_observations = GlobalRef(Gen, :check_observations) +const _reversal = GlobalRef(Gen, :reversal) """ @pkern function k(trace, args...; From cd6161aa51d479022b0df16713b0261d9d887033 Mon Sep 17 00:00:00 2001 From: "McCoy R. Becker" Date: Mon, 16 May 2022 09:25:31 -0400 Subject: [PATCH 5/6] Added const. --- src/inference/kernel_dsl.jl | 9 ++++++--- test/inference/inference.jl | 1 + test/inference/kernel_dsl.jl | 36 ++++++++++++++++++++++++++++++++++++ 3 files changed, 43 insertions(+), 3 deletions(-) create mode 100644 test/inference/kernel_dsl.jl diff --git a/src/inference/kernel_dsl.jl b/src/inference/kernel_dsl.jl index a0a340a2..4cab0025 100644 --- a/src/inference/kernel_dsl.jl +++ b/src/inference/kernel_dsl.jl @@ -27,18 +27,21 @@ const _reversal = GlobalRef(Gen, :reversal) Declare a Julia function as a primitive stationary kernel. -The first argument of the function should be a trace, and the return value of the function should be a trace. +The first argument of the function should be a trace, and the return value of the function should be a `(trace, metadata)` where `metadata` is user-provided (but could be useful information, like the result of an accept-reject decision. + There should be keyword arguments check and observations. """ macro pkern(ex) @capture(ex, function f_(args__) body_ end) || error("expected a function") - quote + expr = quote function $(f)($(args...)) $body end - $(is_custom_primitive_kernel)(::typeof($f)) = true + $(_is_custom_primitive_kernel)(::typeof($f)) = true $(_check_is_kernel)(::typeof($f)) = true end + expr = postwalk(flatten ∘ unblock ∘ rmlines, expr) + esc(expr) end function expand_kern_ex(ex) diff --git a/test/inference/inference.jl b/test/inference/inference.jl index c86bff2d..f3a5adc9 100644 --- a/test/inference/inference.jl +++ b/test/inference/inference.jl @@ -8,3 +8,4 @@ include("map_optimize.jl") include("elliptical_slice.jl") include("mh.jl") include("trace_translators.jl") +include("kernel_dsl.jl") diff --git a/test/inference/kernel_dsl.jl b/test/inference/kernel_dsl.jl new file mode 100644 index 00000000..87c22cdc --- /dev/null +++ b/test/inference/kernel_dsl.jl @@ -0,0 +1,36 @@ +@gen function kernel_dsl_test_model() + x ~ normal(0.0, 1.0) + y ~ normal(x, 1.0) + return y +end + +@pkern function k1(tr, t) + metadata = false + for _ in 1 : t + tr, acc = mh(tr, select(:x)) + if acc + metadata = true + end + end + return tr, metadata +end + +@kern function k1_non_primitive(tr, t) + for _ in 1 : t + tr ~ mh(tr, select(:x)) + end +end + +@testset "Kernel DSL" begin + tr = simulate(kernel_dsl_test_model, ()) + original = get_choices(tr)[:x] + tr, acc = k1(tr, 1) + if acc + @test original != get_choices(tr)[:x] + else + @test original == get_choices(tr)[:x] + end + original = get_choices(tr)[:x] + tr, acc = k1_non_primitive(tr, 1) + tr, acc = Gen.reversal(k1_non_primitive)(tr, 1) +end From 16b818b2106ae3a048872cf8e3b6d8ccdab0b1de Mon Sep 17 00:00:00 2001 From: "McCoy R. Becker" Date: Tue, 17 May 2022 12:15:22 -0400 Subject: [PATCH 6/6] Update -- added test with captured proposal. --- test/inference/kernel_dsl.jl | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/test/inference/kernel_dsl.jl b/test/inference/kernel_dsl.jl index 87c22cdc..7f6f1e83 100644 --- a/test/inference/kernel_dsl.jl +++ b/test/inference/kernel_dsl.jl @@ -15,9 +15,13 @@ end return tr, metadata end +@gen function k1_proposal(tr) + x ~ normal(0.0, 1.0) +end + @kern function k1_non_primitive(tr, t) for _ in 1 : t - tr ~ mh(tr, select(:x)) + tr ~ mh(tr, k1_proposal, ()) end end