Skip to content

Commit

Permalink
fix gradient wrt log pop size
Browse files Browse the repository at this point in the history
  • Loading branch information
xji3 committed Dec 6, 2023
1 parent e272842 commit 88e85b2
Showing 1 changed file with 25 additions and 25 deletions.
50 changes: 25 additions & 25 deletions src/dr/evomodel/coalescent/smooth/SkyGlideLikelihood.java
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import dr.inference.model.Model;
import dr.inference.model.Parameter;
import dr.inference.model.Variable;
import dr.math.MachineAccuracy;
import dr.xml.Reportable;

import java.util.ArrayList;
Expand All @@ -54,6 +55,9 @@ public class SkyGlideLikelihood extends AbstractModelLikelihood implements Repor
private final Parameter logPopSizeParameter;
private final Parameter gridPointParameter;

private boolean likelihoodKnown = false;
private double logLikelihood;

public SkyGlideLikelihood(String name,
List<TreeModel> trees,
Parameter logPopSizeParameter,
Expand All @@ -68,6 +72,7 @@ public SkyGlideLikelihood(String name,
this.intervals.add(treeIntervals);
addModel(treeIntervals);
}
addVariable(logPopSizeParameter);
}

@Override
Expand All @@ -77,12 +82,12 @@ public String getReport() {

@Override
protected void handleModelChangedEvent(Model model, Object object, int index) {

likelihoodKnown = false;
}

@Override
protected void handleVariableChangedEvent(Variable variable, int index, Parameter.ChangeType type) {

likelihoodKnown = false;
}

@Override
Expand All @@ -92,7 +97,7 @@ protected void storeState() {

@Override
protected void restoreState() {

likelihoodKnown = false;
}

@Override
Expand All @@ -107,11 +112,15 @@ public Model getModel() {

@Override
public double getLogLikelihood() {
double lnL = 0;
for (int i = 0; i < trees.size(); i++) {
lnL += getSingleTreeLogLikelihood(i);
if (!likelihoodKnown) {
double lnL = 0;
for (int i = 0; i < trees.size(); i++) {
lnL += getSingleTreeLogLikelihood(i);
}
logLikelihood = lnL;
likelihoodKnown = true;
}
return lnL;
return logLikelihood;
}

public Parameter getLogPopSizeParameter() {
Expand Down Expand Up @@ -261,28 +270,26 @@ private double getLinearInverseIntegral(double start, double end, int gridIndex)
if (slope == 0) {
return Math.exp(-intercept) * (end - start);
} else {
return (Math.exp(-(slope * start + intercept)) - Math.exp(-(slope * end + intercept))) / slope;
return Math.exp(-intercept) * (Math.exp(-slope * start) - Math.exp(-slope * end)) / slope;
}
}


private void updateIntervalGradient(double intervalStart, double intervalEnd, int gridIndex, int lineageCount,
double[] gradient) {
final double slope = getGridSlope(gridIndex);
final double intercept = getGridIntercept(gridIndex);
final double lineageMultiplier = -0.5 * lineageCount * (lineageCount - 1);
assert(slope != 0 || intercept != 0);
final double realSmall = MachineAccuracy.SQRT_EPSILON*(Math.abs(slope) + 1.0); // TODO: arbitrary magic bound
if (intervalStart != intervalEnd) {
if (slope == 0) {
final double multiplier = (intervalEnd - intervalStart) * (-Math.exp(-intercept));
updateGridInterceptDerivativeWrtLogPopSize(gridIndex, gradient, lineageMultiplier * multiplier);
} else {
final double interceptMultiplier = ( - Math.exp(-(slope * intervalStart + intercept)) + Math.exp(-(slope * intervalEnd + intercept))) / slope;
final double slopeMultiplier = (-intervalStart * Math.exp(-(slope * intervalStart + intercept)) + intervalEnd * Math.exp(-(slope * intervalEnd + intercept))) / slope
- (Math.exp(-(slope * intervalStart + intercept)) - Math.exp(-(slope * intervalEnd + intercept))) / slope / slope;
final double slopeMultiplier = slope < realSmall ? Math.exp(-intercept) * (intervalStart * intervalStart - intervalEnd * intervalEnd) / 2
: Math.exp(-intercept) * ( (-intervalStart * Math.exp(-slope * intervalStart) + intervalEnd * Math.exp(-slope * intervalEnd))
- (Math.exp(-slope * intervalStart) - Math.exp(-slope * intervalEnd)) / slope) / slope;
final double interceptMultiplier = slope < realSmall ? (intervalEnd - intervalStart) * (-Math.exp(-intercept))
: Math.exp(-intercept) * (-Math.exp(-slope * intervalStart ) + Math.exp(-slope * intervalEnd )) / slope;

updateGridInterceptDerivativeWrtLogPopSize(gridIndex, gradient, lineageMultiplier * interceptMultiplier);
updateGridSlopeDerivativeWrtLogPopSize(gridIndex, gradient, lineageMultiplier * slopeMultiplier);
}
}
}

Expand Down Expand Up @@ -330,13 +337,6 @@ private void updateGridInterceptDerivativeWrtLogPopSize(int gridIndex, double[]
final double firstDerivative = thisGridTime / (thisGridTime - lastGridTime) * multiplier;
final double secondDerivative = -lastGridTime / (thisGridTime - lastGridTime) * multiplier;

// if (logPopSizeParameter.getParameterValue(gridIndex) == logPopSizeParameter.getParameterValue(gridIndex + 1)) {
// gradient[gridIndex] += (firstDerivative + secondDerivative) / 2;
// gradient[gridIndex + 1] += (firstDerivative + secondDerivative) / 2;
// } else {
// gradient[gridIndex] += firstDerivative;
// gradient[gridIndex + 1] += secondDerivative;
// }
gradient[gridIndex] += firstDerivative;
gradient[gridIndex + 1] += secondDerivative;
}
Expand All @@ -358,6 +358,6 @@ private int[] getGridPoints(int startGridIndex, double startTime, double endTime

@Override
public void makeDirty() {

likelihoodKnown = false;
}
}

0 comments on commit 88e85b2

Please sign in to comment.