Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add AbstractModel implementation of LogDensityProblems interface #110

Merged
merged 6 commits into from
Dec 19, 2022
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@ uuid = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
keywords = ["markov chain monte carlo", "probablistic programming"]
license = "MIT"
desc = "A lightweight interface for common MCMC methods."
version = "4.1.3"
version = "4.2"

[deps]
BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66"
ConsoleProgressMonitor = "88cd18e8-d9cc-4ea6-8889-5259c0d15c8b"
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
torfjelde marked this conversation as resolved.
Show resolved Hide resolved
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
LoggingExtras = "e6f89c97-d47a-5376-807f-9c37f3926c36"
ProgressLogging = "33c8b6b6-d38a-422a-b730-caa89a2f386c"
Expand All @@ -20,6 +21,7 @@ Transducers = "28d57a85-8fef-5791-bfe6-a80928e7c999"
[compat]
BangBang = "0.3.19"
ConsoleProgressMonitor = "0.1"
LogDensityProblems = "2"
torfjelde marked this conversation as resolved.
Show resolved Hide resolved
LoggingExtras = "0.4, 0.5"
ProgressLogging = "0.1"
StatsBase = "0.32, 0.33"
Expand Down
2 changes: 2 additions & 0 deletions src/AbstractMCMC.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ using ProgressLogging: ProgressLogging
using StatsBase: StatsBase
using TerminalLoggers: TerminalLoggers
using Transducers: Transducers
using LogDensityProblems: LogDensityProblems
torfjelde marked this conversation as resolved.
Show resolved Hide resolved

using Distributed: Distributed
using Logging: Logging
Expand Down Expand Up @@ -84,5 +85,6 @@ include("interface.jl")
include("sample.jl")
include("stepper.jl")
include("transducer.jl")
include("logdensityproblems.jl")

end # module AbstractMCMC
23 changes: 23 additions & 0 deletions src/logdensityproblems.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
"""
LogDensityModel <: AbstractMCMC.AbstractModel

Wrapper around something that implements the `LogDensityProblem` interface.

This itself then implements the `LogDensityProblem` interface by simply deferring to the wrapped object.
"""
struct LogDensityModel{L} <: AbstractModel
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe make L a subtype of LogDensityProblems? Otherwise, we import but never use LogDensityProblems.jl.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's no LogDensityProblem abstract type, so I'll just do what @devmotion suggested:)

logdensity::L
end

function LogDensityProblems.dimension(model::LogDensityModel)
return LogDensityProblems.dimension(model.logdensity)
end
function LogDensityProblems.capabilities(model::LogDensityModel)
return LogDensityProblems.capabilities(model.logdensity)
end
function LogDensityProblems.logdensity(model::LogDensityModel, x)
return LogDensityProblems.logdensity(model.logdensity, x)
end
function LogDensityProblems.logdensity_and_gradient(model::LogDensityModel, x)
return LogDensityProblems.logdensity_and_gradient(model.logdensity, x)
end
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My only concern is that wrappers can be very annoying due to method ambiguity issues (if e.g. someone defines methods for specific x and some abstract logdensity type, but is not aware of LogDensityModel) and since you are screwed if you forget to add some definition that should be forwarded (or some log densities define additional functions).

Maybe don't define any LogDensityProblems functions at all? If you know it's a LogDensityModel, you can just call dimension(model.logdensity) etc. in your sampler. And if not, you can't call dimension etc. on your model anyway.

So to summarize: I don't think LogDensityProblems.dimension etc. should be defined for the model.

Copy link
Member Author

@torfjelde torfjelde Dec 9, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is no abstract logdensity type though? So If someone implements something like LogDensityProblems.dimension, it should always be a type they've defined themselves.

Or am I misunderstanding you?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My point is different (I think): I think it is sufficient if LogDensityProblems.dimension etc. are defined for model.logdensity and don't see the need for defining them for model. That would also avoid the issues with the wrapper.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I mean, yeah, you definitively can do without it, but IMO it can be annoying.

Most of the time you don't use AbstractMCMC in isolation, and so IMO it's nice to go

model = AbstractMCMC.LogDensityModel(logdensitymodel)

smaples = sample(model, ...)

something_else_which_requires_logdensityproblems_interface(model, ...)

instead of

model = AbstractMCMC.LogDensityModel(logdensitymodel)

smaples = sample(model, ...)

# User has to manually unwrap to call the method.
something_else_which_requires_logdensityproblems_interface(model.logdensity, ...)

It's definitively not a big thing, but I can't see the downside of implementing these methods (since method ambiguities won't be an issue, as described above, unless someone is doing type-piracy).

Copy link
Member

@devmotion devmotion Dec 9, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It will be an issue as soon as someone defines e.g. logdensity(model::AbstractModel, x::MyType) I think? And the new LogDensityProblemsAD is another issue: You would have to depend on it here as well if you want to support ADgradient(..., model::LogDensityModel) (which IMO for the same reasons we should not do).

In any case, I don't see a good reason why you should add these methods: If you work with LogDensityModel in a sampler, you already have to restrict your dispatches to this specific type of models because otherwise you can't call logdensity etc. So you know you only accept this type anyway, and then you can just as easily call dimension(model.logdensity) in your sampler.

The example I think is also not really a problem: In that case the user already has the model.logdensity since the model is constructed with it in the first place 🙂 And even if the user does not, I don't see a problem with using model.logdensity instead of model since it allows us to avoid all these wrapper-related issues.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I really like the idea of being able to tell these users "Hey, if you don't want to use Turing but you want to use the samplers for your own model, just implement LogDensityInterface and wrap your model in LogDensityModel; all your existing code that assumes LogDensityProblems.jl-interface will work nicely with this too, so you lose nothing."

Wouldn't it be even cooler and cleaner if it would just work automatically and you don't have to wrap your LogDensityProblem-compatible model? Maybe we should support arbitrary types of models that implement the LogDensityProblems interface in AbstractMCMC and downstream packages? If that's the general API then one could just sample with anything and always call logdensity etc, regardless of the type (if samplers want to support that - otherwise they could still restrict sample to some specific types of models).

A LogDensityModel would be mainly useful then if multiple sampler packages have the same structure of models. But its purpose would not be to forward logdensity to logdensity of some wrapped function/model - since in that case you could have used the wrapped function/model directly.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wouldn't it be even cooler and cleaner if it would just work automatically and you don't have to wrap your LogDensityProblem-compatible model?

Aye it would be, but wouldn't sample(model, ...) without restricting to AbstractModel be doomed to lead to method ambiguities?

A LogDensityModel would be mainly useful then if multiple sampler packages have the same structure of models. But its purpose would not be to forward logdensity to logdensity of some wrapped function/model - since in that case you could have used the wrapped function/model directly.

As I said, I'm happy to drop them and just get this LogDensityModel into the package without the forwarded methods, with its purpose just being to indicate that the wrapped object implements the LogDensityInterface. So should I just do this then?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Aye it would be, but wouldn't sample(model, ...) without restricting to AbstractModel be doomed to lead to method ambiguities?

Is there a specific example you have in mind? I think it should be fine eg. if we define no method with type restriction on model in AbstractMCMC (as other arguments should also be more specific in downstream implementations) but maybe there's an example that would be problematic. Generally, I assume due to the AbstractSampler argument and the fact that StatsBase mainly deals with sampling arrays (IIRC) there should not be any problems with other (common) packages - but again, maybe I miss something.

As I said, I'm happy to drop them and just get this LogDensityModel into the package without the forwarded methods, with its purpose just being to indicate that the wrapped object implements the LogDensityInterface. So should I just do this then?

I think this would be the best way forward - not only because I still have doubts regarding the interface for LogDensityProblems but also because IMO it's easier to add stuff incrementally than to remove and break functionality and interfaces again.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a specific example you have in mind?

No, just slightly worried it might happen, though I agree I'd expect it to generally be okay.

I think this would be the best way forward - not only because I still have doubts regarding the interface for LogDensityProblems but also because IMO it's easier to add stuff incrementally than to remove and break functionality and interfaces again.

I agree with this 👍 I'll do that then:)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Made the change; could you have a look now?:)