
Model is ≈2.5× faster if the data loop is inside logpdf instead of the model definition #670

Closed · itsdfish opened this issue Jan 31, 2019 · 2 comments

@itsdfish (Contributor)

Hello-

I have been using Turing for models with custom likelihood functions (Turing is the only library that has let me do this easily, so thank you!). I usually pass all of the data as an array to the logpdf function so I can reuse intermediate calculations, which makes the model run much faster. I noticed that this approach also provides a speedup in cases where I did not anticipate one, so I wanted to bring it to your attention in case it is helpful for optimizing Turing. The code below demonstrates a speedup of about 2.5×:

using Turing, Parameters

import Distributions: logpdf

# Wrap the data-level loop in a custom distribution so the whole
# dataset can be scored with a single `~` statement.
mutable struct mydist{T1,T2} <: ContinuousUnivariateDistribution
    μ::T1
    σ::T2
end

function logpdf(dist::mydist, data::Array{Float64,1})
    @unpack μ, σ = dist
    LL = 0.0
    for d in data
        LL += logpdf(Normal(μ, σ), d)
    end
    return LL
end


@model model1(y) = begin
    μ ~ Normal(0, 1)
    σ ~ Truncated(Cauchy(0, 1), 0, Inf)
    # Observation loop inside the model body
    N = length(y)
    for n in 1:N
        y[n] ~ Normal(μ, σ)
    end
end

@model model2(y) = begin
    μ ~ Normal(0, 1)
    σ ~ Truncated(Cauchy(0, 1), 0, Inf)
    # Loop moved into logpdf(::mydist, ...)
    y ~ mydist(μ, σ)
end

data = rand(Normal(0, 1), 1000)

Nsamples = 10000
Nadapt = 2000
δ = 0.85

chain1 = sample(model1(data), NUTS(Nsamples, Nadapt, δ));
describe(chain1)

chain2 = sample(model2(data), NUTS(Nsamples, Nadapt, δ));
describe(chain2)

Here are the timings (estimates and ESS were very similar):

model1

[NUTS] Sampling...100% Time: 0:05:25
[NUTS] Finished with
  Running time        = 325.16674467499894;
  #lf / sample        = 0.0003;
  #evals / sample     = 14.381;
  pre-cond. metric    = [1.0, 1.0].

model2

[NUTS] Sampling...100% Time: 0:02:15
[NUTS] Finished with
  Running time        = 135.0531926290002;
  #lf / sample        = 0.0003;
  #evals / sample     = 15.23;
  pre-cond. metric    = [1.0, 1.0].

mydist's logpdf is only about 10% faster than broadcasting logpdf over Normal, which leads me to believe that the speed difference is in Turing itself.

using BenchmarkTools

@benchmark logpdf(mydist(1.0,1.0), data)
  BenchmarkTools.Trial: 
  memory estimate:  48 bytes
  allocs estimate:  2
  --------------
  minimum time:     7.376 μs (0.00% GC)
  median time:      7.398 μs (0.00% GC)
  mean time:        7.498 μs (0.00% GC)
  maximum time:     16.450 μs (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     4

@benchmark logpdf.(Normal(1.0,1.0), data)
  BenchmarkTools.Trial: 
  memory estimate:  8.02 KiB
  allocs estimate:  4
  --------------
  minimum time:     7.826 μs (0.00% GC)
  median time:      7.929 μs (0.00% GC)
  mean time:        8.192 μs (0.00% GC)
  maximum time:     19.067 μs (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     4
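
An aside on the allocation figures above: the ~8 KiB in the broadcast benchmark is the result vector that `logpdf.(Normal(...), data)` materializes, while the loop version only accumulates a scalar. A minimal, Turing-free sketch (plain Distributions.jl) of computing the same total without the intermediate array:

```julia
using Distributions

data = rand(Normal(0, 1), 1000)

# Broadcasting allocates a 1000-element result vector:
v = logpdf.(Normal(1.0, 1.0), data)

# Summing over a generator computes the same total log-likelihood
# without allocating the intermediate array:
total = sum(x -> logpdf(Normal(1.0, 1.0), x), data)
```

Both approaches agree: `sum(v) ≈ total`.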
@mohamed82008 (Member)

Thanks for your comment. This is a nice trick, but the speedup is not surprising: what you are doing is the function-barrier approach. Inside the logpdf function, all variables are type stable, so the loop runs fast. In the model's body, however, the code is type unstable, so any loop there will be rather slow. Turing's type instabilities are being fixed right now, so this speedup will be much smaller (if present at all) once Turing becomes more type stable. Some Turing performance PRs and issues to watch are #660 and #665.
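
For readers unfamiliar with the pattern, here is a minimal, Turing-free sketch of the function-barrier idea (the names and the toy computation are illustrative, not Turing internals):

```julia
# Simulate the type instability: parameters stored in an untyped
# container, as happens inside a type-unstable model body.
params = Any[0.5, 2.0]
data = rand(1000)

# Looping directly over the untyped values forces boxed, generic arithmetic.
function unstable_loop(params, data)
    s = 0.0
    for d in data
        s += (d - params[1])^2 / params[2]
    end
    return s
end

# Passing the values through a function call is the "barrier": inside
# `stable_loop`, μ and σ have concrete types, so the loop specializes.
function stable_loop(μ, σ, data)
    s = 0.0
    for d in data
        s += (d - μ)^2 / σ
    end
    return s
end

unstable_loop(params, data)              # slow: operates on boxed Any values
stable_loop(params[1], params[2], data)  # fast: same result, specialized code
```

This is the same mechanism at work in the issue above: moving the data loop into `logpdf` puts it behind a call with concretely typed arguments.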

@itsdfish (Contributor, Author)

Thanks for the explanation. Since it's related to known type-stability problems, I will close the issue.
