
logpdf_grad errors for HomogeneousMixture #445

Closed
fzaiser opened this issue Dec 23, 2021 · 5 comments

fzaiser commented Dec 23, 2021

The following example crashes:

using Gen

@gen function test()
    mix = HomogeneousMixture(broadcasted_normal, [1, 0])
    means = hcat([0.0, 0.0], [1.0, 1.0], [2.0, 2.0], [3.0, 3.0])
    @trace(mix([0.25, 0.25, 0.25, 0.25], means, [0.1, 0.1, 0.1, 0.1]), :x)
end
trace = Gen.simulate(test, ())
result = Gen.hmc(trace, selectall())

It throws the following error:

ERROR: LoadError: DimensionMismatch("new dimensions (1, 2) must be consistent with array size 4")
 [1] (::Base.var"#throw_dmrsa#196")(::Tuple{Int64,Int64}, ::Int64) at ./reshapedarray.jl:41
 [2] reshape at ./reshapedarray.jl:45 [inlined]
 [3] reshape(::Array{Float64,1}, ::Int64, ::Int64) at ./reshapedarray.jl:116
 [4] logpdf_grad(::HomogeneousMixture{Array{Float64,N} where N}, ::Array{Float64,1}, ::Array{Float64,1}, ::Array{Float64,2}, ::Array{Float64,1}) at [...]/packages/Gen/[...]/src/modeling_library/mixture.jl:115
...

I believe the cause is this line in mixture.jl:

(1 for d in 1:dist.dims[i])..., length(dist.dims))

Here, length(dist.dims) should be replaced by K, the number of mixture components. That change removes the exception, but I don't understand the code well enough to be sure that this is the correct fix, or whether other parts of the code also need to be changed.
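For concreteness, here is a minimal standalone illustration of the shape mismatch and the suggested change (the variable names below are illustrative, not the ones used in mixture.jl):

dims = [1, 0]              # number of array dimensions of each argument to broadcasted_normal
K = 4                      # number of mixture components
per_component = rand(K)    # a quantity with one entry per component
i = 1                      # reshaping relative to the i-th argument

# Current behavior: reshaping to (1, length(dims)) == (1, 2) cannot hold 4 elements,
# which raises the DimensionMismatch shown above.
# reshape(per_component, (1 for d in 1:dims[i])..., length(dims))

# Suggested fix: the trailing dimension should be the number of components K.
reshape(per_component, (1 for d in 1:dims[i])..., K)   # size (1, 4)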

bzinberg commented Dec 23, 2021

Haven't looked in depth, but I suspect this is indeed due to an assumption somewhere that args to distributions will be flat, i.e. cannot be array-valued. The use of length instead of size/axes looks suspect to me.

In general, the args to a distribution could be arrays of different shapes. I'm not aware of us having general machinery for flattening and unflattening arrays in the gradient operations (nor am I sure that flattening and unflattening is the right thing to do, necessarily).
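To illustrate why length alone can mislead once arguments are array-valued, consider the means matrix from the example above (a quick illustration, not a proposed fix):

means = hcat([0.0, 0.0], [1.0, 1.0], [2.0, 2.0], [3.0, 3.0])  # 2×4 Matrix{Float64}

length(means)    # 8      -- total number of elements
size(means)      # (2, 4) -- size along each axis
size(means, 2)   # 4      -- number of components along the trailing axis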

@bzinberg

(Oops, misread a doc. Deleted comment.)

fzaiser commented Dec 23, 2021

@bzinberg Thanks for the quick reply! In the documentation for HomogeneousMixture, there is an example with a multivariate normal distribution, which takes a mean vector and a covariance matrix (i.e. different shapes for the two arguments). Therefore, I thought it was supported. Do you think this functionality would be difficult to implement?
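For reference, a sketch of the kind of usage I had in mind, loosely following the documentation example (the parameter values here are illustrative, not taken from the docs):

using Gen
using LinearAlgebra

# mvnormal takes a mean vector (1 array dimension) and a covariance matrix
# (2 array dimensions), so the dims argument is [1, 2].
mixture_of_mvnormals = HomogeneousMixture(mvnormal, [1, 2])

# Three 2-dimensional components: the trailing axis of each argument indexes the component.
weights = [0.5, 0.3, 0.2]
means = hcat([0.0, 0.0], [1.0, 1.0], [2.0, 2.0])                                 # 2×3
covs = cat(Matrix(0.1I, 2, 2), Matrix(0.2I, 2, 2), Matrix(0.3I, 2, 2); dims=3)   # 2×2×3
x = mixture_of_mvnormals(weights, means, covs)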

alex-lew commented Mar 3, 2022

Hi @fzaiser! I think a lot of us were on winter break when you posted this and it fell through the cracks -- sorry about that!

I think you're right that the length(dist.dims) on that line should be replaced by K, the number of components. Thanks for tracking this down and filing the bug!

(As an aside, HMC will struggle to explore multiple modes in this target — but I think that may be the point of the experiment :).)

As Ben mentioned, there are parts of Gen (including the @dist DSL) that make certain restrictive assumptions about data shapes, but I don't think you should run into that on this example.

fzaiser commented Mar 8, 2022

Hi @alex-lew, no problem and thanks for the fix! I hope to have some time to experiment with it soon. Indeed, I'm aware of HMC struggling with such a multi-modal distribution. :) I was just playing around with gradient-based inference methods when I hit the bug and HMC was the simplest way to reproduce it.
