Skip to content

Commit

Permalink
Merge branch 'hmc-clock' of github.com:beast-dev/beast-mcmc into hmc-…
Browse files Browse the repository at this point in the history
…clock
  • Loading branch information
msuchard committed Dec 13, 2023
2 parents 1bdff5e + 12bbd6c commit 1a1949c
Show file tree
Hide file tree
Showing 3 changed files with 107 additions and 10 deletions.
4 changes: 2 additions & 2 deletions ci/TestXML/testSkyglide.xml
Original file line number Diff line number Diff line change
Expand Up @@ -3033,8 +3033,8 @@
<populationSizes>

<!-- skygrid.logPopSize is in log units unlike other popSize -->
<!-- changing 2.100001 to 2.1 will cause instability for numerical derivative -->
<parameter id="skygrid.logPopSize" dimension="6" value="1.1 4.1 2.100001 2.1 1.2 0.8"/>
<!-- changing 2.1001 to 2.1 will cause instability for numerical derivative -->
<parameter id="skygrid.logPopSize" dimension="6" value="1.1 4.1 2.1001 2.1 1.2 0.8"/>
<!-- <parameter id="skygrid.logPopSize" dimension="6" value="1.3 1.3 1.3 1.3 1.3 1.3"/> -->
<!-- <parameter id="skygrid.logPopSize" dimension="3" value="1.1 1.1 1.1"/> -->
<!-- <parameter id="skygrid.logPopSize" dimension="5" value="1.1 2.1 1.2 3.1 4.1"/> -->
Expand Down
23 changes: 21 additions & 2 deletions src/dr/evomodel/coalescent/smooth/SkyGlideGradient.java
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
package dr.evomodel.coalescent.smooth;

import dr.inference.hmc.GradientWrtParameterProvider;
import dr.inference.hmc.HessianWrtParameterProvider;
import dr.inference.model.Likelihood;
import dr.inference.model.Parameter;
import dr.xml.Reportable;
Expand All @@ -39,7 +40,7 @@
* @author Xiang Ji
* @author Marc A. Suchard
*/
public class SkyGlideGradient implements GradientWrtParameterProvider, Reportable {
public class SkyGlideGradient implements GradientWrtParameterProvider, HessianWrtParameterProvider, Reportable {

private final SkyGlideLikelihood likelihood;

Expand Down Expand Up @@ -88,7 +89,19 @@ public double[] getGradientLogDensity() {

@Override
public String getReport() {
return GradientWrtParameterProvider.getReportAndCheckForError(this, wrtParameter.getParameterLowerBound(), wrtParameter.getParameterUpperBound(), tolerance);
String output = GradientWrtParameterProvider.getReportAndCheckForError(this, wrtParameter.getParameterLowerBound(), wrtParameter.getParameterUpperBound(), tolerance)
+ "\n" + HessianWrtParameterProvider.getReportAndCheckForError(this, tolerance);
return output;
}

@Override
public double[] getDiagonalHessianLogDensity() {
return wrtParameter.getDiagonalHessianLogDensity(likelihood);
}

@Override
public double[][] getHessianLogDensity() {
throw new RuntimeException("Not yet implemented.");
}

public enum WrtParameter {
Expand All @@ -98,6 +111,11 @@ public enum WrtParameter {
return likelihood.getGradientWrtLogPopulationSize();
}

@Override
double[] getDiagonalHessianLogDensity(SkyGlideLikelihood likelihood) {
return likelihood.getDiagonalHessianLogDensityWrtLogPopSize();
}

@Override
double getParameterLowerBound() {
return Double.NEGATIVE_INFINITY;
Expand All @@ -109,6 +127,7 @@ public enum WrtParameter {
}
};
abstract double[] getGradientLogDensity(SkyGlideLikelihood likelihood);
abstract double[] getDiagonalHessianLogDensity(SkyGlideLikelihood likelihood);
abstract double getParameterLowerBound();
abstract double getParameterUpperBound();
}
Expand Down
90 changes: 84 additions & 6 deletions src/dr/evomodel/coalescent/smooth/SkyGlideLikelihood.java
Original file line number Diff line number Diff line change
Expand Up @@ -147,15 +147,15 @@ public double[] getGradientWrtLogPopulationSize() {


if (firstGridIndex == lastGridIndex) {
updateIntervalGradient(intervalStart, intervalEnd, firstGridIndex, lineageCount, gradient);
updateIntervalGradientWrtLogPopSize(intervalStart, intervalEnd, firstGridIndex, lineageCount, gradient);
} else {
updateIntervalGradient(intervalStart, gridPointParameter.getParameterValue(firstGridIndex), firstGridIndex, lineageCount, gradient);
updateIntervalGradientWrtLogPopSize(intervalStart, gridPointParameter.getParameterValue(firstGridIndex), firstGridIndex, lineageCount, gradient);
currentGridIndex = firstGridIndex;
while(currentGridIndex + 1 < lastGridIndex) {
updateIntervalGradient(gridPointParameter.getParameterValue(currentGridIndex), gridPointParameter.getParameterValue(currentGridIndex + 1), currentGridIndex + 1, lineageCount, gradient);
updateIntervalGradientWrtLogPopSize(gridPointParameter.getParameterValue(currentGridIndex), gridPointParameter.getParameterValue(currentGridIndex + 1), currentGridIndex + 1, lineageCount, gradient);
currentGridIndex++;
}
updateIntervalGradient(gridPointParameter.getParameterValue(currentGridIndex), intervalEnd, currentGridIndex + 1, lineageCount, gradient);
updateIntervalGradientWrtLogPopSize(gridPointParameter.getParameterValue(currentGridIndex), intervalEnd, currentGridIndex + 1, lineageCount, gradient);
}
currentGridIndex = lastGridIndex;
}
Expand All @@ -166,6 +166,84 @@ public double[] getGradientWrtLogPopulationSize() {
return gradient;
}

public double[] getDiagonalHessianLogDensityWrtLogPopSize() {
double[] diagonalHessian = new double[logPopSizeParameter.getDimension()];

for (int index = 0; index < trees.size(); index++) {
BigFastTreeIntervals interval = intervals.get(index);
Tree thisTree = trees.get(index);
int currentGridIndex = 0;
for (int i = 0; i < interval.getIntervalCount(); i++) {
final int lineageCount = interval.getLineageCount(i);
int[] nodeIndices = interval.getNodeNumbersForInterval(i);
final double intervalStart = thisTree.getNodeHeight(thisTree.getNode(nodeIndices[0]));
final double intervalEnd = thisTree.getNodeHeight(thisTree.getNode(nodeIndices[1]));

if (intervalStart != intervalEnd) {
int[] gridIndices = getGridPoints(currentGridIndex, intervalStart, intervalEnd);
final int firstGridIndex = gridIndices[0];
final int lastGridIndex = gridIndices[1];


if (firstGridIndex == lastGridIndex) {
updateIntervalDiagonalHessianWrtLogPopSize(intervalStart, intervalEnd, firstGridIndex, lineageCount, diagonalHessian);
} else {
updateIntervalDiagonalHessianWrtLogPopSize(intervalStart, gridPointParameter.getParameterValue(firstGridIndex), firstGridIndex, lineageCount, diagonalHessian);
currentGridIndex = firstGridIndex;
while(currentGridIndex + 1 < lastGridIndex) {
updateIntervalDiagonalHessianWrtLogPopSize(gridPointParameter.getParameterValue(currentGridIndex), gridPointParameter.getParameterValue(currentGridIndex + 1), currentGridIndex + 1, lineageCount, diagonalHessian);
currentGridIndex++;
}
updateIntervalDiagonalHessianWrtLogPopSize(gridPointParameter.getParameterValue(currentGridIndex), intervalEnd, currentGridIndex + 1, lineageCount, diagonalHessian);
}
currentGridIndex = lastGridIndex;
}
}
}

return diagonalHessian;
}

private void updateIntervalDiagonalHessianWrtLogPopSize(double intervalStart, double intervalEnd, int gridIndex,
int lineageCount, double[] diagonalHessian) {
final double slope = getGridSlope(gridIndex);
final double intercept = getGridIntercept(gridIndex);
final double lineageMultiplier = -0.5 * lineageCount * (lineageCount - 1);
assert(slope != 0 || intercept != 0);
final double realSmall = getMagicUnderFlowBound(slope);
if (intervalStart != intervalEnd) {
final double expIntervalStart = Math.exp(-slope * intervalStart);
final double expIntervalEnd = Math.exp(-slope * intervalEnd);
final double expIntercept = Math.exp(-intercept);

final double thisGridTime = gridPointParameter.getParameterValue(gridIndex);
final double lastGridTime = gridIndex == 0 ? 0 : gridPointParameter.getParameterValue(gridIndex - 1);

final double secondDerivativeWrtIntercept = getLinearInverseIntegral(intervalStart, intervalEnd, gridIndex);
final double secondDerivativeWrtSlope = Math.abs(slope) < realSmall ? expIntercept
* (intervalEnd * intervalEnd * intervalEnd - intervalStart * intervalStart * intervalStart) / 3 :
expIntercept * (-2 / slope / slope * (intervalEnd * expIntervalEnd - intervalStart * expIntervalStart)
+(intervalStart * intervalStart * expIntervalStart - intervalEnd * intervalEnd * expIntervalEnd) / slope
+2 / slope / slope / slope * (expIntervalStart - expIntervalEnd));
final double derivativeWrtSlope = Math.abs(slope) < realSmall ? expIntercept * (intervalStart * intervalStart - intervalEnd * intervalEnd) / 2 :
expIntercept * ((intervalEnd * expIntervalEnd - intervalStart * expIntervalStart) / slope - (expIntervalStart - expIntervalEnd) / slope / slope);
final double secondDerivativeWrtInterceptSlope = -derivativeWrtSlope;

final double partialInterceptPartialFirstLogPopSize = thisGridTime / (thisGridTime - lastGridTime);
final double partialInterceptPartialSecondLogPopSize = - lastGridTime / (thisGridTime - lastGridTime);
final double partialSlopePartialFirstLogPopSize = - 1 / (thisGridTime - lastGridTime);
final double partialSLopePartialSecondLogPopSize = 1 / (thisGridTime - lastGridTime);

diagonalHessian[gridIndex] += lineageMultiplier * (secondDerivativeWrtIntercept * partialInterceptPartialFirstLogPopSize * partialInterceptPartialFirstLogPopSize
+ 2 * secondDerivativeWrtInterceptSlope * partialSlopePartialFirstLogPopSize * partialInterceptPartialFirstLogPopSize
+ secondDerivativeWrtSlope * partialSlopePartialFirstLogPopSize * partialSlopePartialFirstLogPopSize);
diagonalHessian[gridIndex + 1] += lineageMultiplier * (secondDerivativeWrtIntercept * partialInterceptPartialSecondLogPopSize * partialInterceptPartialSecondLogPopSize
+ 2 * secondDerivativeWrtInterceptSlope * partialSLopePartialSecondLogPopSize * partialInterceptPartialSecondLogPopSize
+ secondDerivativeWrtSlope * partialSLopePartialSecondLogPopSize * partialSLopePartialSecondLogPopSize);
}
}



public double getSingleTreeLogLikelihood(int index) {
BigFastTreeIntervals interval = intervals.get(index);
Expand Down Expand Up @@ -272,8 +350,8 @@ private double getMagicUnderFlowBound(double slope) { // TODO: arbitrary magic b
return MachineAccuracy.SQRT_EPSILON*(Math.abs(slope) + 1.0);
}

private void updateIntervalGradient(double intervalStart, double intervalEnd, int gridIndex, int lineageCount,
double[] gradient) {
private void updateIntervalGradientWrtLogPopSize(double intervalStart, double intervalEnd, int gridIndex, int lineageCount,
double[] gradient) {
final double slope = getGridSlope(gridIndex);
final double intercept = getGridIntercept(gridIndex);
final double lineageMultiplier = -0.5 * lineageCount * (lineageCount - 1);
Expand Down

0 comments on commit 1a1949c

Please sign in to comment.