Avoiding recompilation in Stan models
Brian Lau edited this page Jun 3, 2017 · 3 revisions
Compiling models takes time, so if the same model is going to be used repeatedly, we would like to compile it just once.
Within a session, you can avoid recompiling a model in two ways. The simplest method is to pass an existing fit object in the call to stan:
model_code = {
'data {'
' int<lower=0> N;'
' int<lower=0,upper=1> y[N];'
'}'
'parameters {'
' real<lower=0,upper=1> theta;'
'}'
'model {'
'for (n in 1:N)'
' y[n] ~ bernoulli(theta);'
'}'
};
data = struct('N',10,'y',[0, 1, 0, 0, 0, 0, 0, 0, 0, 1]);
% This call will compile the model
fit = stan('model_code',model_code,'data',data);
print(fit);
new_data = struct('N',10,'y',[0, 1, 0, 1, 0, 1, 0, 1, 1, 1]);
% Passing in StanFit object skips recompilation
fit2 = stan('fit',fit,'data',new_data);
print(fit2);
Alternatively, we can create a StanModel instance and call its compile method first:
sm = StanModel('model_code',model_code);
sm.compile();
% subsequent calls will skip recompilation
fit3 = sm.sampling('data',data);
print(fit3);
fit4 = sm.sampling('data',data);
print(fit4);
This can be used, for example, to fit the same model to many data sets:
% True values for theta
theta = 0.1:.2:1;
% Set this to the number of cores of your machine to avoid warnings
ncores = 4;
for i = 1:numel(theta)
   % Generate some fake data
   data = struct('N',10,'y',double(rand(1,10)<theta(i)));
   % Sample, using the model compiled above for each data set
   fit(i) = stan('fit',sm,'data',data,'chains',min(i,ncores),'iter',150000);
   fprintf('Model: %s, id: %s, #chains=%g, seed=%g\n',...
      fit(i).model.model_name,fit(i).model.id,...
      fit(i).model.chains,fit(i).model.seed);
end
theta
% Run the following command after all fits are completed
arrayfun(@(x) mean(x.extract.theta),fit)
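The compile-once pattern with StanModel shown earlier can also be wrapped in a small helper, so that code scattered across a session does not have to pass the compiled model around. The following is only a sketch: cached_stan_model is a hypothetical name, not part of MatlabStan, and it assumes that StanModel objects can be stored in a containers.Map and remain compiled when retrieved later. It uses only the StanModel constructor and compile method shown on this page, and would be saved as cached_stan_model.m on the MATLAB path.

```matlab
function sm = cached_stan_model(model_code)
% CACHED_STAN_MODEL  Return a compiled StanModel for model_code, compiling
% at most once per distinct program text within the current MATLAB session.
% Hypothetical helper for illustration; not part of MatlabStan.

persistent cache
if ~isa(cache,'containers.Map')
   % First call in this session: create an empty cache
   cache = containers.Map('KeyType','char','ValueType','any');
end

% Use the program text itself as the cache key
if iscell(model_code)
   key = strjoin(reshape(model_code,1,[]),sprintf('\n'));
else
   key = char(model_code);
end

if ~isKey(cache,key)
   % Not seen before: build and compile, as in the StanModel example
   sm = StanModel('model_code',model_code);
   sm.compile();
   cache(key) = sm;
end
sm = cache(key);
end
```

With this in place, repeated calls such as the following should compile only on the first one. The cache lives in a persistent variable, so it is emptied by clear functions or clear all, and it does not survive a restart of MATLAB.

```matlab
sm = cached_stan_model(model_code);   % compiles on the first call only
fit5 = sm.sampling('data',data);
print(fit5);
```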