diff --git a/Project.toml b/Project.toml index 7ca7dc1..efc6caf 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "AutomaticMALA" uuid = "fa450ca5-f5a1-411d-b5db-2b752c7e65ee" -authors = ["Tor Erlend Fjelde and contributors"] +authors = ["Tor Erlend Fjelde and contributors"] version = "0.1.0" [deps] @@ -11,5 +11,19 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +[weakdeps] +Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" + +[extensions] +AutomaticMALATuringExt = ["Turing"] + [compat] +AbstractMCMC = "5" +Distributions = "0.25" +DocStringExtensions = "0.9" +LogDensityProblems = "2" +Turing = "0.31" julia = "1.6" + +[extras] +Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" \ No newline at end of file diff --git a/ext/AutomaticMALATuringExt.jl b/ext/AutomaticMALATuringExt.jl new file mode 100644 index 0000000..b31a707 --- /dev/null +++ b/ext/AutomaticMALATuringExt.jl @@ -0,0 +1,13 @@ +module AutomaticMALATuringExt + +if isdefined(Base, :get_extension) + using AutomaticMALA: AutomaticMALA + using Turing: Turing +else + using ..AutomaticMALA: AutomaticMALA + using ..Turing: Turing +end + +Turing.Inference.getparams(::Turing.DynamicPPL.Model, state::AutomaticMALA.AutoMALAState) = state.x + +end diff --git a/src/AutomaticMALA.jl b/src/AutomaticMALA.jl index f8f727f..3a821bf 100644 --- a/src/AutomaticMALA.jl +++ b/src/AutomaticMALA.jl @@ -210,4 +210,16 @@ function round_based_adaptation( return AutoMALA(ϵ_init, 0), initial_params end +if !isdefined(Base, :get_extension) + using Requires +end + +@static if !isdefined(Base, :get_extension) + function __init__() + @require Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" include( + "../ext/AutomaticMALATuringExt.jl" + ) + end +end + end