Skip to content

Commit

Permalink
remove ugly ChatGPT ESS code and use built-in SSIT MH convergence plo…
Browse files Browse the repository at this point in the history
…t function
  • Loading branch information
alexpopinga committed Jul 11, 2024
1 parent d8193f4 commit 92d2528
Showing 1 changed file with 15 additions and 41 deletions.
56 changes: 15 additions & 41 deletions WorkSpace/AlexP/SSIT_BurstingGene_MCMC.m
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
F1 = F1.loadData(dataFile, dataToFit);
end

function result = performMLEandMCMC(F1,logSpace,stepSize)
function result = performMLEandMCMC(F1,logSpace,stepSize,chainLength)

% MLE Fitting Options
maxFitIter = 1000;
Expand All @@ -56,7 +56,7 @@
F1.fittingOptions.logPrior = @(x)-sum((log10(x)-log10PriorMean(1:4)).^2./(2*log10PriorStd(1:4).^2));

% Metropolis Hastings Properties
nSamplesMH = 1000; % Number of MH Samples to run
nSamplesMH = chainLength; % Number of MH Samples to run
nThinMH = 2; % Thin rate for MH sampling
nBurnMH = 100; % Number for MH burn in

Expand Down Expand Up @@ -136,36 +136,23 @@ function plotMHTraces(samples, paramNames)
sgtitle('Metropolis-Hastings Trace Plots');
end

%% ChatGPT 4 Compute ESS
function acf = autocorr(x, max_lag)
n = length(x);
x_mean = mean(x);
acf = zeros(1, max_lag + 1);

for lag = 0:max_lag
num = sum((x(1:n-lag) - x_mean) .* (x(1+lag:n) - x_mean));
den = sum((x - x_mean).^2);
acf(lag + 1) = num / den;
end
end

function act = integrated_act(x, max_lag)
acf_values = autocorr(x, max_lag);
act = 1 + 2 * sum(acf_values(2:end));
end

function ess = effective_sample_size(x)
n = length(x);
max_lag = min(1000, n - 1); % You can choose an appropriate lag based on your data
act = integrated_act(x, max_lag);
ess = n / act;
function ess = effective_sample_size(mhResults)
ac = xcorr(mhResults.mhValue-mean(mhResults.mhValue),'normalized');
ac = ac(size(mhResults.mhValue,1):end);
% Uncomment following plot line to plot autocorrelation function,
%% BUT this will overwrite one of the trace plots.
% plot(ac,'LineWidth',3); hold on
N = size(mhResults.mhValue,1);
tau = 1+2*sum((ac(2:N/100)));
Neff = N/tau;
ess = Neff;
end

F1 = setupAndSolveModel();

% Get computation time for MCMC
tic
MHResults = performMLEandMCMC(F1,true,0.1);
MHResults = performMLEandMCMC(F1,true,0.1,3000);
toc

% Assume MHResults.mhSamples is a 1000x4 matrix
Expand All @@ -177,21 +164,8 @@ function plotMHTraces(samples, paramNames)
% Call the function to plot the traces
plotMHTraces(samples, paramNames);

% Assume MHResults.mhsamples is your MCMC chain with size 1000x4
mcmc_chain = samples;
[num_iterations, num_parameters] = size(mcmc_chain);

% Initialize the ESS vector
ess_values = zeros(1, num_parameters);

% Calculate ESS for each parameter
for param = 1:num_parameters
ess_values(param) = effective_sample_size(mcmc_chain(:, param));
end

% Display the ESS values
disp('Effective Sample Size (ESS) for each parameter:');
disp(ess_values);
ess = effective_sample_size(MHResults);
ess

% Compute FIM for subsampling of MH results.
%J = floor(linspace(nSamplesMH/2,nSamplesMH,nFIMsamples));
Expand Down

0 comments on commit 92d2528

Please sign in to comment.