From b6674177f1ecad8fbdec561b3d228dd4973dc0df Mon Sep 17 00:00:00 2001 From: "Marc A. Suchard" Date: Sat, 22 Jul 2023 12:15:02 -0700 Subject: [PATCH] demonstrate speed-up --- .../EfficientSpeciationLikelihood.java | 8 +- ...EfficientSpeciationLikelihoodGradient.java | 4 +- .../MasBirthDeathSerialSamplingModel.java | 181 ++++++++++++++++++ .../NewBirthDeathSerialSamplingModel.java | 92 ++++++--- ...ewBirthDeathSerialSamplingModelParser.java | 8 +- src/dr/util/Timer.java | 6 +- 6 files changed, 261 insertions(+), 38 deletions(-) create mode 100644 src/dr/evomodel/speciation/MasBirthDeathSerialSamplingModel.java diff --git a/src/dr/evomodel/speciation/EfficientSpeciationLikelihood.java b/src/dr/evomodel/speciation/EfficientSpeciationLikelihood.java index bdf455d23c..95a6036e0d 100644 --- a/src/dr/evomodel/speciation/EfficientSpeciationLikelihood.java +++ b/src/dr/evomodel/speciation/EfficientSpeciationLikelihood.java @@ -237,9 +237,11 @@ public void addTrait(TreeTrait trait) { public String getReport() { String message = super.getReport(); - message += "\n"; - // add likelihood calculation time - message += "Likelihood calculation time is " + likelihoodTime / likelihoodCounts + " nanoseconds.\n"; + if (MEASURE_RUN_TIME) { + message += "\n"; + // add likelihood calculation time + message += "Likelihood calculation time is " + likelihoodTime / likelihoodCounts + " nanoseconds.\n"; + } return message; } } \ No newline at end of file diff --git a/src/dr/evomodel/speciation/EfficientSpeciationLikelihoodGradient.java b/src/dr/evomodel/speciation/EfficientSpeciationLikelihoodGradient.java index 512ceaa87e..ed73a82165 100644 --- a/src/dr/evomodel/speciation/EfficientSpeciationLikelihoodGradient.java +++ b/src/dr/evomodel/speciation/EfficientSpeciationLikelihoodGradient.java @@ -35,6 +35,8 @@ import dr.util.Timer; import dr.xml.Reportable; +import static dr.evomodel.speciation.CachedGradientDelegate.MEASURE_RUN_TIME; + /** * @author Andy Magee * @author Yucai Shao @@ -136,7 
+138,7 @@ public LogColumn[] getColumns() { @Override public String getReport() { String message = GradientWrtParameterProvider.getReportAndCheckForError(this, 0.0, Double.POSITIVE_INFINITY, 1E-3); - if (gradientProvider instanceof CachedGradientDelegate) { + if (gradientProvider instanceof CachedGradientDelegate && MEASURE_RUN_TIME) { message += "\n"; message += "Gradient calculation time is " + ((CachedGradientDelegate) gradientProvider).getGradientTime() + " nanoseconds.\n"; } diff --git a/src/dr/evomodel/speciation/MasBirthDeathSerialSamplingModel.java b/src/dr/evomodel/speciation/MasBirthDeathSerialSamplingModel.java new file mode 100644 index 0000000000..a8ccbe2090 --- /dev/null +++ b/src/dr/evomodel/speciation/MasBirthDeathSerialSamplingModel.java @@ -0,0 +1,181 @@ +/* + * MasBirthDeathSerialSamplingModel.java + * + * Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard + * + * This file is part of BEAST. + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership and licensing. + * + * BEAST is free software; you can redistribute it and/or modify + * it under the terms of the GNU Lesser General Public License as + * published by the Free Software Foundation; either version 2 + * of the License, or (at your option) any later version. + * + * BEAST is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Lesser General Public License for more details. 
+ * + * You should have received a copy of the GNU Lesser General Public + * License along with BEAST; if not, write to the + * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor, + * Boston, MA 02110-1301 USA + */ + +package dr.evomodel.speciation; + +import dr.inference.model.Parameter; + +public class MasBirthDeathSerialSamplingModel extends NewBirthDeathSerialSamplingModel { + + public MasBirthDeathSerialSamplingModel(Parameter birthRate, Parameter deathRate, Parameter serialSamplingRate, Parameter treatmentProbability, Parameter samplingProbability, Parameter originTime, boolean condition, int numIntervals, double gridEnd, Type units) { + super(birthRate, deathRate, serialSamplingRate, treatmentProbability, samplingProbability, originTime, condition, numIntervals, gridEnd, units); + } + + @Override + public final double processModelSegmentBreakPoint(int model, double intervalStart, double intervalEnd, int nLineages) { +// double lnL = nLineages * (logQ(model, intervalEnd) - logQ(model, intervalStart)); + double lnL = nLineages * Math.log(Q(model, intervalEnd) / Q(model, intervalStart)); + if ( samplingProbability.getValue(model + 1) > 0.0 && samplingProbability.getValue(model + 1) < 1.0) { + // Add in probability of un-sampled lineages + // We don't need this at t=0 because all lineages in the tree are sampled + // TODO: check if we're right about how many lineages are actually alive at this time. Are we inadvertently over-counting or under-counting due to samples added at this _exact_ time? 
+ lnL += nLineages * Math.log(1.0 - samplingProbability.getValue(model + 1)); + } + this.savedLogQ = Double.NaN; + return lnL; + } + + final void accumulateGradientForInterval(final double[] gradient, final int currentModelSegment, final int nLineages, + final double[] partialQ_all_old, final double Q_Old, + final double[] partialQ_all_young, final double Q_young) { + + for (int k = 0; k <= currentModelSegment; k++) { + for (int p = 0; p < 4; ++p) { + gradient[k * 5 + p] += nLineages * (partialQ_all_old[k * 4 + p] / Q_Old + - partialQ_all_young[k * 4 + p] / Q_young); + } + } + } + + final void accumulateGradientForSampling(double[] gradient, int currentModelSegment, double term1, + double[] intermediate) { + + for (int k = 0; k <= currentModelSegment; k++) { + for (int p = 0; p < 4; ++p) { + gradient[k * 5 + p] += term1 * intermediate[k * 4 + p]; + } + } + + } + + final void dPCompute(int model, double t, double intervalStart, double eAt, double[] dP, double[] dG2) { + + double G1 = g1(eAt); + + double term1 = -A / lambda * ((1 - B) * (eAt - 1) + G1) / (G1 * G1); + + for (int k = 0; k < model; k ++) { + for (int p = 0; p < 4; p++) { + dP[k * 4 + p] = term1 * dB[k * 4 + p]; + } + } + + for (int p = 0; p < 3; ++p) { + double term2 = eAt * (1 + B) * dA[p] * (t - intervalStart) + (eAt - 1) * dB[model * 4 + p]; + dG2[p] = dA[p] - 2 * (G1 * (dA[p] * (1 - B) - dB[model * 4 + p] * A) - (1 - B) * term2 * A) / (G1 * G1); + } + + double G2 = g2(G1); + + dP[model * 4 + 0] = (-mu - psi - lambda * dG2[0] + G2) / (2 * lambda * lambda); + dP[model * 4 + 1] = (1 - dG2[1]) / (2 * lambda); + dP[model * 4 + 2] = (1 - dG2[2]) / (2 * lambda); + dP[model * 4 + 3] = -A / lambda * ((1 - B) * (eAt - 1) + G1) * dB[model * 4 + 3] / (G1 * G1); + } + + + final void dQCompute(int model, double t, double[] dQ, double eAt) { + + double dwell = t - modelStartTimes[model]; + double G1 = g1(eAt); + + double term1 = 8 * eAt; + double term2 = G1 / 2 - eAt * (1 + B); + double term3 = eAt - 1; + double 
term4 = G1 * G1 * G1; + double term5 = -term1 * term3 / term4; + + for (int k = 0; k < model; ++k) { + dQ[k * 4 + 0] = term5 * dB[k * 4 + 0]; + dQ[k * 4 + 1] = term5 * dB[k * 4 + 1]; + dQ[k * 4 + 2] = term5 * dB[k * 4 + 2]; + dQ[k * 4 + 3] = term5 * dB[k * 4 + 3]; + } + + double term6 = term1 / term4; + double term7 = dwell * term2; + + dQ[model * 4 + 0] = term6 * (dA[0] * term7 - dB[model * 4 + 0] * term3); + dQ[model * 4 + 1] = term6 * (dA[1] * term7 - dB[model * 4 + 1] * term3); + dQ[model * 4 + 2] = term6 * (dA[2] * term7 - dB[model * 4 + 2] * term3); + dQ[model * 4 + 3] = term5 * dB[model * 4 + 3]; + } + + + final double Q(int model, double time) { + double At = A * (time - modelStartTimes[model]); + double eAt = Math.exp(At); + double sqrtDenominator = g1(eAt); + return eAt / (sqrtDenominator * sqrtDenominator); + } + + final double logQ(int model, double time) { + double At = A * (time - modelStartTimes[model]); + double eAt = Math.exp(At); + double sqrtDenominator = g1(eAt); + return At - 2 * Math.log(sqrtDenominator); // TODO log4 (additive constant) is not needed since we always see logQ(a) - logQ(b) + } + + @Override + public double processInterval(int model, double tYoung, double tOld, int nLineages) { + double logQ_young; + double logQ_old = Q(model, tOld); + if (!Double.isNaN(this.savedLogQ)) { + logQ_young = this.savedLogQ; + } else { + logQ_young = Q(model, tYoung); + } + this.savedLogQ = logQ_old; + return nLineages * Math.log(logQ_old / logQ_young); + } + + @Override + public double processSampling(int model, double tOld) { + + double logSampProb; + + boolean sampleIsAtEventTime = tOld == modelStartTimes[model]; + boolean samplesTakenAtEventTime = rho > 0; + + if (sampleIsAtEventTime && samplesTakenAtEventTime) { + logSampProb = Math.log(rho); + if (model > 0) { + logSampProb += Math.log(r + ((1.0 - r) * previousP)); + } + } else { + double logPsi = Math.log(psi); + logSampProb = logPsi + Math.log(r + (1.0 - r) * p(model,tOld)); + } + + return 
logSampProb; + } +} + +/* + * Notes on inlining: + * https://www.baeldung.com/jvm-method-inlining#:~:text=Essentially%2C%20the%20JIT%20compiler%20tries,times%20we%20invoke%20the%20method. + * https://miuv.blog/2018/02/25/jit-optimizations-method-inlining/ + * static, private, final + */ diff --git a/src/dr/evomodel/speciation/NewBirthDeathSerialSamplingModel.java b/src/dr/evomodel/speciation/NewBirthDeathSerialSamplingModel.java index 84d0aad51a..e7000c86ff 100644 --- a/src/dr/evomodel/speciation/NewBirthDeathSerialSamplingModel.java +++ b/src/dr/evomodel/speciation/NewBirthDeathSerialSamplingModel.java @@ -82,11 +82,12 @@ public class NewBirthDeathSerialSamplingModel extends SpeciationModel implements double psi; double r; double rho; + double logRho; //double rho0; // TODO remove private boolean[] gradientFlags; - private double savedLogQ; + double savedLogQ; private double savedQ; private double[] partialQ; @@ -101,8 +102,8 @@ public class NewBirthDeathSerialSamplingModel extends SpeciationModel implements private double eAt_Old; private double eAt_End; - private final double[] dA; - private final double[] dB; + final double[] dA; + final double[] dB; private final double[] dG2; boolean computedBCurrent; @@ -257,7 +258,7 @@ protected void handleVariableChangedEvent(Variable variable, int index, Paramete // Do nothing } - private double p(int model, double t) { + final double p(int model, double t) { double eAt = Math.exp(A * (t - modelStartTimes[model])); return p(eAt); } @@ -267,7 +268,7 @@ private double p(double eAt) { return (lambda + mu + psi - A * ((eAt1B - (1.0 - B)) / (eAt1B + (1.0 - B)))) / (2.0 * lambda); } - private double logQ(int model, double time) { + double logQ(int model, double time) { double At = A * (time - modelStartTimes[model]); double eAt = Math.exp(At); double sqrtDenominator = g1(eAt); @@ -339,6 +340,7 @@ private void updateParameterValues(int model) { psi = serialSamplingRate.getParameterValue(model); r = 
treatmentProbability.getParameterValue(model); rho = samplingProbability.getParameterValue(model); +// logRho = Math.log(rho); this.savedLogQ = Double.NaN; } @@ -390,7 +392,7 @@ public void updateGradientModelValues(int model) { double end = modelStartTimes[model + 1]; double start = modelStartTimes[model]; eAt_End = Math.exp(A * (end - start)); - dPCompute(model, end, start, eAt_End, dPModelEnd); + dPCompute(model, end, start, eAt_End, dPModelEnd, dG2); } computedBCurrent = true; @@ -492,20 +494,20 @@ public List getCitations() { )); } - private double g1(double eAt) { + final double g1(double eAt) { return (1 + B) * eAt + (1 - B); } - private double g2(double G1) { + final double g2(double G1) { return A * (1 - 2 * (1 - B) / G1); } - public double q(int model, double t) { + public final double q(int model, double t) { double eAt = Math.exp(A * (t - modelStartTimes[model])); return q(eAt); } - public double q(double eAt) { + public final double q(double eAt) { double sqrtDenominator = g1(eAt); return 4 * eAt / (sqrtDenominator * sqrtDenominator); } @@ -606,7 +608,7 @@ private void dBCompute(int model, double[] dB) { } } - private void dPCompute(int model, double t, double intervalStart, double eAt, double[] dP) { + void dPCompute(int model, double t, double intervalStart, double eAt, double[] dP, double[] dG2) { double G1 = g1(eAt); @@ -689,7 +691,6 @@ private void dPCompute(int model, double t, double intervalStart, double eAt, do } dP[model * 4 + 3] = -A / lambda * ((1 - B) * (eAt - 1) + G1) * dB[model * 4 + 3] / (G1 * G1); } - } private void dQCompute(int model, double t, double[] dQ) { @@ -698,7 +699,7 @@ private void dQCompute(int model, double t, double[] dQ) { dQCompute(model, t, dQ, eAt); } - private void dQCompute(int model, double t, double[] dQ, double eAt) { + void dQCompute(int model, double t, double[] dQ, double eAt) { double dwell = t - modelStartTimes[model]; double G1 = g1(eAt); @@ -848,10 +849,38 @@ public void processGradientInterval(double[] 
gradient, int currentModelSegment, // gradient[k * 5 + p] += nLineages * (partialQ_all_old[k * 4 + p] / Q_Old - partialQ_all_young[k * 4 + p] / Q_young); // } // } + +// for (int p = 0; p < 4; ++p) { +// if (gradientFlags[p]) { +// for (int k = 0; k <= currentModelSegment; k++) { +// gradient[k * 5 + p] += nLineages * (partialQ_all_old[k * 4 + p] / Q_Old - partialQ_all_young[k * 4 + p] / Q_young); +// } +// } +// } + + accumulateGradientForInterval(gradient, currentModelSegment, nLineages, + partialQ_all_old, Q_Old, partialQ_all_young, Q_young); + } + + void accumulateGradientForInterval(double[] gradient, int currentModelSegment, int nLineages, + double[] partialQ_all_old, double Q_Old, + double[] partialQ_all_young, double Q_young) { for (int p = 0; p < 4; ++p) { if (gradientFlags[p]) { for (int k = 0; k <= currentModelSegment; k++) { - gradient[k * 5 + p] += nLineages * (partialQ_all_old[k * 4 + p] / Q_Old - partialQ_all_young[k * 4 + p] / Q_young); + gradient[k * 5 + p] += nLineages * (partialQ_all_old[k * 4 + p] / Q_Old + - partialQ_all_young[k * 4 + p] / Q_young); + } + } + } + } + + void accumulateGradientForSampling(double[] gradient, int currentModelSegment, double term1, + double[] intermediate) { + for (int p = 0; p < 4; p++) { + if (gradientFlags[p]) { + for (int k = 0; k <= currentModelSegment; k++) { + gradient[k * 5 + p] += term1 * intermediate[k * 4 + p]; } } } @@ -887,7 +916,7 @@ public void processGradientSampling(double[] gradient, int currentModelSegment, // double eAt = Math.exp(A * (intervalEnd - modelStartTimes[currentModelSegment])); - dPCompute(currentModelSegment, intervalEnd, modelStartTimes[currentModelSegment], eAt_Old, this.dPIntervalEnd); + dPCompute(currentModelSegment, intervalEnd, modelStartTimes[currentModelSegment], eAt_Old, dPIntervalEnd, dG2); double term1 = (1 - r) / ((1 - r) * p_it + r); @@ -898,13 +927,15 @@ public void processGradientSampling(double[] gradient, int currentModelSegment, // gradient[fractionIndex(k, 
numIntervals)] += term1 * dPIntervalEnd[k * 4 + 3]; // } - for (int p = 0; p < 4; p++) { - if (gradientFlags[p]) { - for (int k = 0; k <= currentModelSegment; k++) { - gradient[genericIndex(k, p, numIntervals)] += term1 * dPIntervalEnd[k * 4 + p]; - } - } - } +// for (int p = 0; p < 4; p++) { +// if (gradientFlags[p]) { +// for (int k = 0; k <= currentModelSegment; k++) { +// gradient[genericIndex(k, p, numIntervals)] += term1 * dPIntervalEnd[k * 4 + p]; +// } +// } +// } + + accumulateGradientForSampling(gradient, currentModelSegment, term1, dPIntervalEnd); } if (sampleIsAtEventTime && currentModelSegment > 0) { @@ -919,13 +950,16 @@ public void processGradientSampling(double[] gradient, int currentModelSegment, // gradient[samplingIndex(k, numIntervals)] += term1 * dPModelEnd_prev[k * 4 + 2]; // gradient[fractionIndex(k, numIntervals)] += term1 * dPModelEnd_prev[k * 4 + 3]; // } - for (int p = 0; p < 4; p++) { - if (gradientFlags[p]) { - for (int k = 0; k < currentModelSegment; k++) { - gradient[genericIndex(k, p, numIntervals)] += term1 * dPModelEnd_prev[k * 4 + p]; - } - } - } + +// for (int p = 0; p < 4; p++) { +// if (gradientFlags[p]) { +// for (int k = 0; k < currentModelSegment; k++) { +// gradient[genericIndex(k, p, numIntervals)] += term1 * dPModelEnd_prev[k * 4 + p]; +// } +// } +// } + + accumulateGradientForSampling(gradient, currentModelSegment, term1, dPModelEnd_prev); } } diff --git a/src/dr/evomodelxml/speciation/NewBirthDeathSerialSamplingModelParser.java b/src/dr/evomodelxml/speciation/NewBirthDeathSerialSamplingModelParser.java index 9dde88edb8..e1231a5fb3 100644 --- a/src/dr/evomodelxml/speciation/NewBirthDeathSerialSamplingModelParser.java +++ b/src/dr/evomodelxml/speciation/NewBirthDeathSerialSamplingModelParser.java @@ -26,6 +26,7 @@ package dr.evomodelxml.speciation; import dr.evolution.util.Units; +import dr.evomodel.speciation.MasBirthDeathSerialSamplingModel; import dr.evomodel.speciation.NewBirthDeathSerialSamplingModel; import 
dr.evoxml.util.XMLUnits; import dr.inference.model.Parameter; @@ -98,13 +99,16 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException { Logger.getLogger("dr.evomodel").info(citeThisModel); - // return new NewBirthDeathSerialSamplingModel(lambda, mu, psi, r, rho, origin, condition, (int)(numGridPoints.getParameterValue(0)), cutoff.getParameterValue(0), units); - NewBirthDeathSerialSamplingModel model = new NewBirthDeathSerialSamplingModel(lambda, mu, psi, r, rho, origin, condition, (int)(numGridPoints.getParameterValue(0)), cutoff.getParameterValue(0), units); + NewBirthDeathSerialSamplingModel model = MAS_TEST ? + new MasBirthDeathSerialSamplingModel(lambda, mu, psi, r, rho, origin, condition, (int)(numGridPoints.getParameterValue(0)), cutoff.getParameterValue(0), units) : + new NewBirthDeathSerialSamplingModel(lambda, mu, psi, r, rho, origin, condition, (int)(numGridPoints.getParameterValue(0)), cutoff.getParameterValue(0), units); model.setupGradientFlags(gradientFlags); model.setupTimeline(grids != null ? grids.getParameterValues(): null); return model; } + private static final boolean MAS_TEST = false; + //************************************************************************ // AbstractXMLObjectParser implementation //************************************************************************ diff --git a/src/dr/util/Timer.java b/src/dr/util/Timer.java index fea98c1863..5ba1f409b3 100644 --- a/src/dr/util/Timer.java +++ b/src/dr/util/Timer.java @@ -31,18 +31,18 @@ public class Timer { private long nanoStart = 0, nanoStop = 0; public void start() { + nanoStart = System.nanoTime(); // One wants the highest precision first. TODO Do we really need this? 
start = System.currentTimeMillis(); - nanoStart = System.nanoTime(); } public void stop() { - stop = System.currentTimeMillis(); nanoStop = System.nanoTime(); + stop = System.currentTimeMillis(); } public void update() { - stop = System.currentTimeMillis(); nanoStop = System.nanoTime(); + stop = System.currentTimeMillis(); } /**