Skip to content

Commit

Permalink
add function to compute ESS
Browse files Browse the repository at this point in the history
  • Loading branch information
alexpopinga committed Jul 11, 2024
1 parent 898b2dd commit d8193f4
Showing 1 changed file with 42 additions and 0 deletions.
42 changes: 42 additions & 0 deletions WorkSpace/AlexP/SSIT_BurstingGene_MCMC.m
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@
result = MHResults;
end

%% ChatGPT Plot MH Traces
function plotMHTraces(samples, paramNames)
% Check if the number of parameter names matches the number of columns in samples
if length(paramNames) ~= size(samples, 2)
Expand All @@ -135,9 +136,34 @@ function plotMHTraces(samples, paramNames)
sgtitle('Metropolis-Hastings Trace Plots');
end

%% ChatGPT 4 Compute ESS
function acf = autocorr(x, max_lag)
n = length(x);
x_mean = mean(x);
acf = zeros(1, max_lag + 1);

for lag = 0:max_lag
num = sum((x(1:n-lag) - x_mean) .* (x(1+lag:n) - x_mean));
den = sum((x - x_mean).^2);
acf(lag + 1) = num / den;
end
end

function act = integrated_act(x, max_lag)
acf_values = autocorr(x, max_lag);
act = 1 + 2 * sum(acf_values(2:end));
end

function ess = effective_sample_size(x)
n = length(x);
max_lag = min(1000, n - 1); % You can choose an appropriate lag based on your data
act = integrated_act(x, max_lag);
ess = n / act;
end

F1 = setupAndSolveModel();

% Get computation time for MCMC
tic
MHResults = performMLEandMCMC(F1,true,0.1);
toc
Expand All @@ -151,6 +177,22 @@ function plotMHTraces(samples, paramNames)
% Call the function to plot the traces
plotMHTraces(samples, paramNames);

% Assume MHResults.mhsamples is your MCMC chain with size 1000x4
mcmc_chain = samples;
[num_iterations, num_parameters] = size(mcmc_chain);

% Initialize the ESS vector
ess_values = zeros(1, num_parameters);

% Calculate ESS for each parameter
for param = 1:num_parameters
ess_values(param) = effective_sample_size(mcmc_chain(:, param));
end

% Display the ESS values
disp('Effective Sample Size (ESS) for each parameter:');
disp(ess_values);

% Compute FIM for subsampling of MH results.
%J = floor(linspace(nSamplesMH/2,nSamplesMH,nFIMsamples));
%MHSamplesForFIM = exp(MHResults.mhSamples(J,:));
Expand Down

0 comments on commit d8193f4

Please sign in to comment.