Skip to content

Commit

Permalink
added extension to make it work with Turing.jl
Browse files Browse the repository at this point in the history
  • Loading branch information
torfjelde committed May 16, 2024
1 parent 44de469 commit 6d1b768
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 1 deletion.
16 changes: 15 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "AutomaticMALA"
uuid = "fa450ca5-f5a1-411d-b5db-2b752c7e65ee"
authors = ["Tor Erlend Fjelde <tor.erlend95@gmail.com> and contributors"]
authors = ["Tor Erlend Fjelde <tor.github@gmail.com> and contributors"]
version = "0.1.0"

[deps]
Expand All @@ -11,5 +11,19 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

[weakdeps]
Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"

[extensions]
AutomaticMALATuringExt = ["Turing"]

[compat]
AbstractMCMC = "5"
Distributions = "0.25"
DocStringExtensions = "0.9"
LogDensityProblems = "2"
Turing = "0.31"
julia = "1.6"

[extras]
Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
13 changes: 13 additions & 0 deletions ext/AutomaticMALATuringExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
module AutomaticMALATuringExt

if isdefined(Base, :get_extension)
using AutomaticMALA: AutomaticMALA
using Turing: Turing
else
using ..AutomaticMALA: AutomaticMALA
using ..Turing: Turing
end

Turing.Inference.getparams(::Turing.DynamicPPL.Model, state::AutomaticMALA.AutoMALAState) = state.x

end
12 changes: 12 additions & 0 deletions src/AutomaticMALA.jl
Original file line number Diff line number Diff line change
Expand Up @@ -210,4 +210,16 @@ function round_based_adaptation(
return AutoMALA(ϵ_init, 0), initial_params
end

if !isdefined(Base, :get_extension)
using Requires
end

@static if !isdefined(Base, :get_extension)
function __init__()
@require Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" include(
"../ext/AutomaticMALATuringExt.jl"
)
end
end

end

0 comments on commit 6d1b768

Please sign in to comment.