Add AbstractModel implementation of LogDensityProblems interface #110

Merged (6 commits) on Dec 19, 2022


torfjelde
Member

This PR adds an implementation of AbstractModel which supports the LogDensityProblems.jl interface.

Benefits:

  1. We get a good "default" implementation of an AbstractModel, which allows model definitions to be shared between sampler packages. For example, with this, AdvancedHMC.jl (see "Use LogDensityProblems.jl", AdvancedHMC.jl#301, for related discussion) and AdvancedMH.jl can both use LogDensityModel instead of defining their own model types.
  2. We get AD "for free" through LogDensityProblemsAD.jl.

Comment on lines 1 to 23
"""
    LogDensityModel <: AbstractMCMC.AbstractModel

Wrapper around something that implements the `LogDensityProblem` interface.

This itself then implements the `LogDensityProblem` interface by simply deferring to the wrapped object.
"""
struct LogDensityModel{L} <: AbstractModel
    logdensity::L
end

function LogDensityProblems.dimension(model::LogDensityModel)
    return LogDensityProblems.dimension(model.logdensity)
end
function LogDensityProblems.capabilities(model::LogDensityModel)
    return LogDensityProblems.capabilities(model.logdensity)
end
function LogDensityProblems.logdensity(model::LogDensityModel, x)
    return LogDensityProblems.logdensity(model.logdensity, x)
end
function LogDensityProblems.logdensity_and_gradient(model::LogDensityModel, x)
    return LogDensityProblems.logdensity_and_gradient(model.logdensity, x)
end
Member

My only concern is that wrappers can be very annoying: they invite method ambiguity issues (e.g. if someone defines methods for a specific x and some abstract logdensity type but is not aware of LogDensityModel), and you are screwed if you forget to add a definition that should be forwarded (or if some log densities define additional functions).

Maybe don't define any LogDensityProblems functions at all? If you know it's a LogDensityModel, you can just call dimension(model.logdensity) etc. in your sampler. And if not, you can't call dimension etc. on your model anyway.

So to summarize: I don't think LogDensityProblems.dimension etc. should be defined for the model.

torfjelde (Member Author), Dec 9, 2022

There is no abstract logdensity type though? So if someone implements something like LogDensityProblems.dimension, it should always be for a type they've defined themselves.

Or am I misunderstanding you?

Member

My point is different (I think): I think it is sufficient if LogDensityProblems.dimension etc. are defined for model.logdensity and don't see the need for defining them for model. That would also avoid the issues with the wrapper.

Member Author

I mean, yeah, you definitely can do without it, but IMO it can be annoying.

Most of the time you don't use AbstractMCMC in isolation, and so IMO it's nice to go

model = AbstractMCMC.LogDensityModel(logdensitymodel)
samples = sample(model, ...)
something_else_which_requires_logdensityproblems_interface(model, ...)

instead of

model = AbstractMCMC.LogDensityModel(logdensitymodel)
samples = sample(model, ...)
# User has to manually unwrap to call the method.
something_else_which_requires_logdensityproblems_interface(model.logdensity, ...)

It's definitely not a big thing, but I can't see the downside of implementing these methods (since method ambiguities won't be an issue, as described above, unless someone is doing type-piracy).

devmotion (Member), Dec 9, 2022

It will be an issue as soon as someone defines e.g. logdensity(model::AbstractModel, x::MyType) I think? And the new LogDensityProblemsAD is another issue: You would have to depend on it here as well if you want to support ADgradient(..., model::LogDensityModel) (which IMO for the same reasons we should not do).
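
To make the ambiguity concern concrete, here is a minimal sketch. It runs in plain Julia because it uses local stand-ins: `AbstractModel`, `logdensity`, `StdNormal` and `MyType` are all defined here for illustration and are not the real AbstractMCMC.jl or LogDensityProblems.jl objects.

```julia
# Stand-ins so the sketch runs without AbstractMCMC.jl or LogDensityProblems.jl.
abstract type AbstractModel end

struct LogDensityModel{L} <: AbstractModel
    logdensity::L
end

# Hypothetical generic function standing in for LogDensityProblems.logdensity.
function logdensity end

# Forwarding method, as in the snippet under review.
logdensity(model::LogDensityModel, x) = logdensity(model.logdensity, x)

# A user's own problem type (hypothetical).
struct StdNormal end
logdensity(::StdNormal, x) = -sum(abs2, x) / 2

# Hypothetical downstream definition: any AbstractModel, but a specific `x` type.
struct MyType
    values::Vector{Float64}
end
logdensity(model::AbstractModel, x::MyType) = logdensity(model, x.values)

model = LogDensityModel(StdNormal())

# A plain vector only matches the forwarding method, so this works.
plain = logdensity(model, [1.0, 2.0])

# With MyType both methods match: one is more specific in `model`,
# the other in `x`, so Julia cannot pick and throws a MethodError.
err = try
    logdensity(model, MyType([1.0, 2.0]))
catch e
    e
end
```

Neither definition is type piracy in this sketch, since each one involves a type owned by whoever wrote the method; the conflict only appears when the two are loaded together.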

In any case, I don't see a good reason why you should add these methods: If you work with LogDensityModel in a sampler, you already have to restrict your dispatches to this specific type of models because otherwise you can't call logdensity etc. So you know you only accept this type anyway, and then you can just as easily call dimension(model.logdensity) in your sampler.

The example I think is also not really a problem: In that case the user already has the model.logdensity since the model is constructed with it in the first place 🙂 And even if the user does not, I don't see a problem with using model.logdensity instead of model since it allows us to avoid all these wrapper-related issues.

Member

> I really like the idea of being able to tell these users "Hey, if you don't want to use Turing but you want to use the samplers for your own model, just implement LogDensityInterface and wrap your model in LogDensityModel; all your existing code that assumes LogDensityProblems.jl-interface will work nicely with this too, so you lose nothing."

Wouldn't it be even cooler and cleaner if it would just work automatically and you don't have to wrap your LogDensityProblem-compatible model? Maybe we should support arbitrary types of models that implement the LogDensityProblems interface in AbstractMCMC and downstream packages? If that's the general API then one could just sample with anything and always call logdensity etc, regardless of the type (if samplers want to support that - otherwise they could still restrict sample to some specific types of models).

A LogDensityModel would be mainly useful then if multiple sampler packages have the same structure of models. But its purpose would not be to forward logdensity to logdensity of some wrapped function/model - since in that case you could have used the wrapped function/model directly.

Member Author

> Wouldn't it be even cooler and cleaner if it would just work automatically and you don't have to wrap your LogDensityProblem-compatible model?

Aye it would be, but wouldn't sample(model, ...) without restricting to AbstractModel be doomed to lead to method ambiguities?

> A LogDensityModel would be mainly useful then if multiple sampler packages have the same structure of models. But its purpose would not be to forward logdensity to logdensity of some wrapped function/model - since in that case you could have used the wrapped function/model directly.

As I said, I'm happy to drop them and just get this LogDensityModel into the package without the forwarded methods, with its purpose just being to indicate that the wrapped object implements the LogDensityInterface. So should I just do this then?
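
A rough sketch of what that wrapper-without-forwarding design could look like from a sampler's side. Everything here is a local stand-in so that it runs in plain Julia: `dimension` and `logdensity` mimic the interface's function names, and `StdNormal`, `ModeFinder` and `best_point` are hypothetical names invented for the example (`ModeFinder` is a toy stand-in for a sampler, not a real one).

```julia
# Stand-ins so the sketch runs without AbstractMCMC.jl or LogDensityProblems.jl.
abstract type AbstractModel end
abstract type AbstractSampler end

# Hypothetical generic functions mirroring the interface's names.
function dimension end
function logdensity end

# Marker wrapper: it only signals that the wrapped object implements the
# interface. No interface methods are defined for the wrapper itself.
struct LogDensityModel{L} <: AbstractModel
    logdensity::L
end

# A user's problem type implementing the (stand-in) interface.
struct StdNormal
    dim::Int
end
dimension(p::StdNormal) = p.dim
logdensity(::StdNormal, x) = -sum(abs2, x) / 2

# Toy stand-in for a sampler: picks the candidate with the highest log density.
struct ModeFinder <: AbstractSampler end

# Dispatch is restricted to LogDensityModel, and the wrapped object is
# unwrapped explicitly before calling the interface functions.
function best_point(model::LogDensityModel, ::ModeFinder, candidates)
    problem = model.logdensity
    d = dimension(problem)
    valid = [x for x in candidates if length(x) == d]
    scores = [logdensity(problem, x) for x in valid]
    return valid[argmax(scores)]
end

model = LogDensityModel(StdNormal(2))
best = best_point(model, ModeFinder(), [[1.0, 1.0], [0.1, 0.0], [3.0]])
```

Because the wrapper defines no `dimension` or `logdensity` methods of its own, there is nothing to forget to forward and nothing for downstream methods to be ambiguous with; the cost is the explicit `model.logdensity` in each sampler.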

Member

> Aye it would be, but wouldn't sample(model, ...) without restricting to AbstractModel be doomed to lead to method ambiguities?

Is there a specific example you have in mind? I think it should be fine, e.g. if we define no method with a type restriction on model in AbstractMCMC (as other arguments should also be more specific in downstream implementations), but maybe there's an example that would be problematic. Generally, I assume that due to the AbstractSampler argument and the fact that StatsBase mainly deals with sampling arrays (IIRC), there should not be any problems with other (common) packages - but again, maybe I'm missing something.
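
As an illustration of the untyped-model idea, here is a small sketch in which dispatch happens only on the sampler argument. All names are assumptions made up for the example (`draw` and `propose` are not AbstractMCMC.jl functions, and `OriginSampler` is a toy sampler that always returns the origin); the stand-ins are defined locally so the code runs in plain Julia.

```julia
# Stand-ins; none of these names are AbstractMCMC.jl's actual API.
abstract type AbstractSampler end

# Hypothetical interface function a model may implement.
function dimension end

# Hypothetical per-sampler hook that downstream packages would implement.
function propose end

# Generic entry point: `model` is deliberately left untyped, so the only
# type restriction in the "base package" is on the sampler.
function draw(model, sampler::AbstractSampler, n::Integer)
    return [propose(model, sampler) for _ in 1:n]
end

# Downstream toy sampler: more specific in the sampler argument only.
struct OriginSampler <: AbstractSampler end
propose(model, ::OriginSampler) = zeros(dimension(model))

# Any type implementing `dimension` works, without a wrapper or a shared supertype.
struct StdNormal
    dim::Int
end
dimension(p::StdNormal) = p.dim

draws = draw(StdNormal(3), OriginSampler(), 2)
```

In this sketch an ambiguity could still arise if another package added a method that restricts `model` while leaving the sampler as `AbstractSampler`, which is the kind of case the question above is asking about.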

> As I said, I'm happy to drop them and just get this LogDensityModel into the package without the forwarded methods, with its purpose just being to indicate that the wrapped object implements the LogDensityInterface. So should I just do this then?

I think this would be the best way forward - not only because I still have doubts regarding the interface for LogDensityProblems but also because IMO it's easier to add stuff incrementally than to remove and break functionality and interfaces again.

Member Author

> Is there a specific example you have in mind?

No, just slightly worried it might happen, though I agree I'd expect it to generally be okay.

> I think this would be the best way forward - not only because I still have doubts regarding the interface for LogDensityProblems but also because IMO it's easier to add stuff incrementally than to remove and break functionality and interfaces again.

I agree with this 👍 I'll do that then:)

Member Author

Made the change; could you have a look now?:)

Co-authored-by: David Widmann
# Fields
- `logdensity`: The object that implements the LogDensityProblems.jl interface.
"""
struct LogDensityModel{L} <: AbstractModel
Member

Maybe make L a subtype of LogDensityProblems? Otherwise, we import but never use LogDensityProblems.jl.

Member Author

There's no LogDensityProblem abstract type, so I'll just do what @devmotion suggested:)

@torfjelde
Member Author

Do we still support Julia 1.3?

@torfjelde
Member Author

And we should at least be testing on Julia 1.6, which we currently are not doing.

codecov bot commented Dec 19, 2022

Codecov Report

Base: 97.19% // Head: 97.19% // No change to project coverage 👍

Coverage data is based on head (314cb57) compared to base (6ef1dcb).
Patch has no changes to coverable lines.

Additional details and impacted files
@@           Coverage Diff           @@
##           master     #110   +/-   ##
=======================================
  Coverage   97.19%   97.19%           
=======================================
  Files           7        7           
  Lines         285      285           
=======================================
  Hits          277      277           
  Misses          8        8           
| Impacted Files | Coverage Δ |
|---|---|
| src/AbstractMCMC.jl | 100.00% <ø> (ø) |


@torfjelde
Member Author

As long as you're happy with the bump to 1.6, this should now be ready to go:)

4 participants