Skip to content

Commit

Permalink
a diagonal hessian that works for 3 tip tree
Browse files Browse the repository at this point in the history
  • Loading branch information
xji3 committed Dec 18, 2023
1 parent 990985d commit a67b22d
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 23 deletions.
5 changes: 2 additions & 3 deletions src/dr/evomodel/coalescent/smooth/SkyGlideGradient.java
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,7 @@ public double[] getGradientLogDensity() {
@Override
public String getReport() {
String output = GradientWrtParameterProvider.getReportAndCheckForError(this, wrtParameter.getParameterLowerBound(), wrtParameter.getParameterUpperBound(), tolerance)
;
// + "\n" + HessianWrtParameterProvider.getReportAndCheckForError(this, tolerance);
+ "\n" + HessianWrtParameterProvider.getReportAndCheckForError(this, tolerance);
return output;
}

Expand Down Expand Up @@ -151,7 +150,7 @@ public enum WrtParameter {

@Override
double[] getDiagonalHessianLogDensity(SkyGlideLikelihood likelihood, int treeIndex) {
throw new RuntimeException("Not yet implemented.");
return likelihood.getDiagonalHessianWrtNodeHeight(treeIndex);
}

@Override
Expand Down
79 changes: 59 additions & 20 deletions src/dr/evomodel/coalescent/smooth/SkyGlideLikelihood.java
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,14 @@ public List<TreeModel> getTrees() {
return trees;
}

public BigFastTreeIntervals getIntervals(int treeIndex) {
return intervals.get(treeIndex);
}

public TreeModel getTree(int treeIndex) {
return trees.get(treeIndex);
}

@Override
public String getReport() {
return "skyGlideLikelihood(" + getLogLikelihood() + ")";
Expand Down Expand Up @@ -247,7 +255,55 @@ private void updateIntervalDiagonalHessianWrtLogPopSize(double intervalStart, do
}
}

public enum NodeHeightDerivativeType {
GRADIENT {
@Override
double getNodeHeightDerivative(double intercept, double slope, double time, double lineageMultiplier) {
return lineageMultiplier * Math.exp(-intercept - slope * time);
}

@Override
void updateSingleTreePopulationInverseGradientWrtNodeHeight(SkyGlideLikelihood likelihood, int treeIndex, double[] derivatives) {

int currentGridIndex = 0;
BigFastTreeIntervals interval = likelihood.getIntervals(treeIndex);
TreeModel tree = likelihood.getTree(treeIndex);

for (int i = 0; i < interval.getIntervalCount(); i++) {
if (interval.getIntervalType(i) == IntervalType.COALESCENT) {
final double time = interval.getIntervalTime(i + 1);
final int nodeIndex = interval.getNodeNumbersForInterval(i)[1];
currentGridIndex = likelihood.getGridIndex(time, currentGridIndex);
final double slope = likelihood.getGridSlope(currentGridIndex);
derivatives[nodeIndex - tree.getExternalNodeCount()] -= slope;
}
}
}
},
DIAGONAL_HESSIAN {
@Override
double getNodeHeightDerivative(double intercept, double slope, double time, double lineageMultiplier) {
return - lineageMultiplier * Math.exp(-intercept - slope * time) * slope;
}

@Override
void updateSingleTreePopulationInverseGradientWrtNodeHeight(SkyGlideLikelihood likelihood, int treeIndex, double[] derivatives) {

}
};
abstract double getNodeHeightDerivative(double intercept, double slope, double time, double lineageMultiplier);
abstract void updateSingleTreePopulationInverseGradientWrtNodeHeight(SkyGlideLikelihood likelihood, int treeIndex, double[] derivatives);
}

public double[] getGradientWrtNodeHeight(int treeIndex) {
return getDerivativeWrtNodeHeight(treeIndex, NodeHeightDerivativeType.GRADIENT);
}

public double[] getDiagonalHessianWrtNodeHeight(int treeIndex) {
return getDerivativeWrtNodeHeight(treeIndex, NodeHeightDerivativeType.DIAGONAL_HESSIAN);
}

public double[] getDerivativeWrtNodeHeight(int treeIndex, NodeHeightDerivativeType derivativeType) {

BigFastTreeIntervals interval = intervals.get(treeIndex);
Tree thisTree = trees.get(treeIndex);
Expand Down Expand Up @@ -279,7 +335,7 @@ public double[] getGradientWrtNodeHeight(int treeIndex) {

final double lineageMultiplier = 0.5 * lineageCount * (lineageCount - 1);
if (!thisTree.isExternal(thisTree.getNode(nodeIndices[0]))) {
tmp += lineageMultiplier * Math.exp(-firstGridIntercept - firstGridSlope * intervalStart);
tmp += derivativeType.getNodeHeightDerivative(firstGridIntercept, firstGridSlope, intervalStart, lineageMultiplier);
}

int count = 0;
Expand All @@ -295,7 +351,7 @@ public double[] getGradientWrtNodeHeight(int treeIndex) {
}

if (interval.getIntervalType(i) == IntervalType.COALESCENT) {
tmp = -lineageMultiplier * Math.exp(-lastGridIntercept - lastGridSlope * intervalEnd);
tmp = -derivativeType.getNodeHeightDerivative(lastGridIntercept, lastGridSlope, intervalEnd, lineageMultiplier);
}
}
currentGridIndex = lastGridIndex;
Expand All @@ -312,27 +368,10 @@ public double[] getGradientWrtNodeHeight(int treeIndex) {
j++;
}

updateSingleTreePopulationInverseGradientWrtNodeHeight(treeIndex, gradient);
derivativeType.updateSingleTreePopulationInverseGradientWrtNodeHeight(this, treeIndex, gradient);

return gradient;
}
private void updateSingleTreePopulationInverseGradientWrtNodeHeight(int index, double[] gradient) {

BigFastTreeIntervals interval = intervals.get(index);
TreeModel tree = trees.get(index);
int currentGridIndex = 0;

for (int i = 0; i < interval.getIntervalCount(); i++) {

if (interval.getIntervalType(i) == IntervalType.COALESCENT) {
final double time = interval.getIntervalTime(i + 1);
final int nodeIndex = interval.getNodeNumbersForInterval(i)[1];
currentGridIndex = getGridIndex(time, currentGridIndex);
final double slope = getGridSlope(currentGridIndex);
gradient[nodeIndex - tree.getExternalNodeCount()] -= slope;
}
}
}


public double getSingleTreeLogLikelihood(int index) {
Expand Down

0 comments on commit a67b22d

Please sign in to comment.