From 4a25c8cef2c5c657820f8764afadec5b05cb485d Mon Sep 17 00:00:00 2001 From: Andy Magee Date: Thu, 29 Jun 2023 14:42:58 -0700 Subject: [PATCH 1/6] More old work on BBMRF stuff --- ...sianBridgeMarkovRandomFieldLikelihood.java | 1 + ...seFirstOrderFiniteDifferenceTransform.java | 15 ++++------ src/dr/util/Transform.java | 28 +++++++++++++++++++ 3 files changed, 34 insertions(+), 10 deletions(-) diff --git a/src/dr/inference/distribution/BayesianBridgeMarkovRandomFieldLikelihood.java b/src/dr/inference/distribution/BayesianBridgeMarkovRandomFieldLikelihood.java index b6b0c4a1cb..65ef2e729b 100644 --- a/src/dr/inference/distribution/BayesianBridgeMarkovRandomFieldLikelihood.java +++ b/src/dr/inference/distribution/BayesianBridgeMarkovRandomFieldLikelihood.java @@ -90,6 +90,7 @@ public double getLogLikelihood() { double logPdf = 0.0; logPdf += firstElementDistribution.logPdf(unconstrained[0]); logPdf += bridge.logPdf(unconstrained[1]); + // Would be +logJacobian of the constrained->unconstrained transform, but this is the unconstrained->constrained transform logPdf -= transform.getLogJacobian(transform.inverse(variables.getParameterValues(),0,dim)); return logPdf; } diff --git a/src/dr/util/InverseFirstOrderFiniteDifferenceTransform.java b/src/dr/util/InverseFirstOrderFiniteDifferenceTransform.java index 493cc2af8c..6da920fd95 100644 --- a/src/dr/util/InverseFirstOrderFiniteDifferenceTransform.java +++ b/src/dr/util/InverseFirstOrderFiniteDifferenceTransform.java @@ -128,21 +128,16 @@ public double getLogJacobian(double[] values) { logJacobian += Math.log(incrementTransform.gradientInverse(s)); } // Why is this inverted? - return -logJacobian; + return logJacobian; } @Override public double[] getGradientLogJacobianInverse(double[] values) { - - double[] gradLogJacobian = firstOrderFiniteDifferenceTransform.getGradientLogJacobianInverse(values); - double[] gradient = new double[dim]; - - for (int i = 0; i < dim - 1; i++) { - gradient[i] = (gradLogJacobian[i] - gradLogJacobian[i+1]) * incrementTransform.derivativeOfTransformWrtValue(values[i]); + double[] grad = new double[dim]; + for (int i = 0; i < dim; i++) { + grad[i] = (1.0 / incrementTransform.derivativeOfTransformWrtValue(values[i])) * incrementTransform.secondDerivativeOfTransformWrtValue(values[i]); } - gradient[dim - 1] = gradLogJacobian[dim - 1] * incrementTransform.derivativeOfTransformWrtValue(values[dim - 1]); - - return gradient; + return grad; } @Override diff --git a/src/dr/util/Transform.java b/src/dr/util/Transform.java index 65ce9a2338..3985c7bce0 100644 --- a/src/dr/util/Transform.java +++ b/src/dr/util/Transform.java @@ -121,6 +121,10 @@ public interface Transform { double[] derivativeOfTransformWrtValue(double[] values, int from, int to); + double secondDerivativeOfTransformWrtValue(double value); + + double[] secondDerivativeOfTransformWrtValue(double[] values, int from, int to); + double secondDerivativeOfInverseTransformWrtValue(double value); double[] secondDerivativeOfInverseTransformWrtValue(double[] values, int from, int to); @@ -315,6 +319,18 @@ public double[] derivativeOfTransformWrtValue(double[] values, int from, int to) return result; } + public double secondDerivativeOfTransformWrtValue(double value) { + throw new RuntimeException("Not yet implemented."); + }; + + public double[] secondDerivativeOfTransformWrtValue(double[] values, int from, int to) { + double[] result = values.clone(); + for (int i = from; i < to; ++i) { + result[i] = secondDerivativeOfTransformWrtValue(values[i]); + } + return result; + } + public double secondDerivativeOfInverseTransformWrtValue(double value) { throw new RuntimeException("Not yet implemented."); } @@ -413,6 +429,14 @@ public double[] derivativeOfTransformWrtValue(double[] values, int from, int to) throw new RuntimeException("Not yet implemented."); } + public double secondDerivativeOfTransformWrtValue(double value) { + throw new RuntimeException("Transformation not permitted for this type of parameter, exiting ..."); + }; + + public double[] secondDerivativeOfTransformWrtValue(double[] values, int from, int to) { + throw new RuntimeException("Not yet implemented."); + } + public double secondDerivativeOfInverseTransformWrtValue(double value) { throw new RuntimeException("Transformation not permitted for this type of parameter, exiting ..."); } @@ -623,6 +647,8 @@ public double gradient(double value) { public double derivativeOfTransformWrtValue(double value) { return 1.0 / value; } + public double secondDerivativeOfTransformWrtValue(double value) { return -1.0 / (value * value); } + public double secondDerivativeOfInverseTransformWrtValue(double value) { return Math.exp(value); } public double logSecondDerivativeOfInverseTransformWrtValue(double value) { return value; } @@ -1288,6 +1314,8 @@ public double getLogJacobian(double value) { public double derivativeOfTransformWrtValue(double value) { return 1.0; } + public double secondDerivativeOfTransformWrtValue(double value) { return 0.0; } + public double secondDerivativeOfInverseTransformWrtValue(double value) { return 0.0; } public double logSecondDerivativeOfInverseTransformWrtValue(double value) { return Double.NEGATIVE_INFINITY; } From 57e03127059047b029958d5bbaa457b3e6bb2584 Mon Sep 17 00:00:00 2001 From: Andy Magee Date: Fri, 30 Jun 2023 08:40:06 -0700 Subject: [PATCH 2/6] Tests --- ci/TestXML/testBBMRFGradient.xml | 36 ++ .../testBranchSpecificSubstitutionModel.xml | 390 ++++++++++++++++++ 2 files changed, 426 insertions(+) create mode 100644 ci/TestXML/testBranchSpecificSubstitutionModel.xml diff --git a/ci/TestXML/testBBMRFGradient.xml b/ci/TestXML/testBBMRFGradient.xml index cc0ad749bd..7b27e010e3 100644 --- a/ci/TestXML/testBBMRFGradient.xml +++ b/ci/TestXML/testBBMRFGradient.xml @@ -74,6 +74,42 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/ci/TestXML/testBranchSpecificSubstitutionModel.xml b/ci/TestXML/testBranchSpecificSubstitutionModel.xml new file mode 100644 index 0000000000..e8bf572cf7 --- /dev/null +++ b/ci/TestXML/testBranchSpecificSubstitutionModel.xml @@ -0,0 +1,390 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + AAACCCGGTAACAA + + + + AAACCTGGGAATAA + + + + AAACTCGGGAATGA + + + + ATACCCGGTGGTAG + + + + + + + + + + + + + + + + + + + (A:1.0,(B:1.0,(C:1.0,D:1.0):1.0):1.0); + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + check fixed stem weight 0.0 + + + + + + + + -68.07217469138813 + + + + + + check fixed stem weight 1.0 + + + + + + + + -67.91898348796958 + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + From 19f5f8c2297a49bc679d6cbc6cb056d92ae660aa Mon Sep 17 00:00:00 2001 From: Andy Magee Date: Fri, 30 Jun 2023 09:57:20 -0700 Subject: [PATCH 3/6] Bad test --- ci/TestXML/testBBMRFGradient.xml | 66 ++++++++++++++++---------------- 1 file changed, 33 insertions(+), 33 deletions(-) diff --git a/ci/TestXML/testBBMRFGradient.xml b/ci/TestXML/testBBMRFGradient.xml index 7b27e010e3..fa015d3ee2 100644 --- a/ci/TestXML/testBBMRFGradient.xml +++ b/ci/TestXML/testBBMRFGradient.xml @@ -74,41 +74,41 @@ - + - - - - - - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + - - - + + + From 7397d58b17d2ce0a983896589a99810e85692677 Mon Sep 17 00:00:00 2001 From: xji3 Date: Fri, 30 Jun 2023 13:40:54 -0500 Subject: [PATCH 4/6] cache smooth skygrid sufficient statistics --- .../smooth/SmoothSkygridLikelihood.java | 61 ++++++++++++++----- 1 file changed, 46 insertions(+), 15 deletions(-) diff --git a/src/dr/evomodel/coalescent/smooth/SmoothSkygridLikelihood.java b/src/dr/evomodel/coalescent/smooth/SmoothSkygridLikelihood.java index 599da4697e..27df1c602b 100644 --- a/src/dr/evomodel/coalescent/smooth/SmoothSkygridLikelihood.java +++ b/src/dr/evomodel/coalescent/smooth/SmoothSkygridLikelihood.java @@ -31,6 +31,7 @@ import dr.evomodel.bigfasttree.BigFastTreeIntervals; import dr.evomodel.coalescent.AbstractCoalescentLikelihood; import dr.evomodel.tree.TreeModel; +import dr.inference.model.Model; import dr.inference.model.Parameter; import dr.util.Author; import dr.util.Citable; @@ -76,6 +77,18 @@ public SmoothSkygridLikelihood(String name, this.populationSizeInverse = new SmoothSkygridPopulationSizeInverse(logPopSizeParameter, gridPointParameter, smoothFunction, smoothRate); this.lineageCount = new OldSmoothLineageCount(trees.get(0), smoothFunction, smoothRate); intervalsList = new ArrayList<>(); + + this.tmpA = new double[trees.get(0).getNodeCount()]; + this.tmpB = new double[trees.get(0).getNodeCount()]; + this.tmpC = new double[trees.get(0).getNodeCount()]; + this.tmpD = new double[gridPointParameter.getDimension()]; + this.tmpE = new double[gridPointParameter.getDimension()]; + this.tmpF = new double[gridPointParameter.getDimension()]; + this.tmpLineageEffect = new double[trees.get(0).getNodeCount()]; + this.tmpTimes = new double[trees.get(0).getNodeCount()]; + this.tmpCounts = new int[trees.get(0).getNodeCount()]; + this.tmpSumsKnown = false; + for (int i = 0; i < trees.size(); i++) { intervalsList.add(new BigFastTreeIntervals(trees.get(i))); addModel(intervalsList.get(i)); @@ -256,24 +269,25 @@ private double getLineageCountDifference(int intervalIndex, BigFastTreeIntervals } } - protected double calculateLogLikelihood() { - assert(trees.size() == 1); - if (!likelihoodKnown) { + private double[] tmpA; + private double[] tmpB; + private double[] tmpC; + private double[] tmpD; + private double[] tmpE; + private double[] tmpF; + private double[] tmpLineageEffect; + private double[] tmpTimes; + private int[] tmpCounts; + private int uniqueTimes; + private boolean tmpSumsKnown; + + private void calculateTmpSums() { + if (!tmpSumsKnown) { TreeModel tree = trees.get(0); final double startTime = 0; final double endTime = tree.getNodeHeight(tree.getRoot()); final int maxGridIndex = getMaxGridIndex(gridPointParameter, endTime); - double[] tmpA = new double[tree.getNodeCount()]; - double[] tmpB = new double[tree.getNodeCount()]; - double[] tmpC = new double[tree.getNodeCount()]; - double[] tmpD = new double[maxGridIndex]; - double[] tmpE = new double[maxGridIndex]; - double[] tmpF = new double[maxGridIndex]; - double[] tmpLineageEffect = new double[tree.getNodeCount()]; - double[] tmpTimes = new double[tree.getNodeCount()]; - int[] tmpCounts = new int[tree.getNodeCount()]; - NodeRef[] nodes = new NodeRef[tree.getNodeCount()]; System.arraycopy(tree.getNodes(), 0, nodes, 0, tree.getNodeCount()); Arrays.parallelSort(nodes, (a, b) -> Double.compare(tree.getNodeHeight(a), tree.getNodeHeight(b))); @@ -301,7 +315,7 @@ protected double calculateLogLikelihood() { } tmpLineageEffect[index] = currentLineageEffect; tmpCounts[index] = currentCount; - final int uniqueTimes = index + 1; + uniqueTimes = index + 1; for (int i = 0; i < uniqueTimes; i++) { final double timeI = tmpTimes[i]; @@ -354,6 +368,19 @@ protected double calculateLogLikelihood() { tmpE[k] = sum; tmpF[k] = sum * sum - quadraticSum; } + tmpSumsKnown = true; + } + } + + protected double calculateLogLikelihood() { + assert(trees.size() == 1); + if (!likelihoodKnown) { + TreeModel tree = trees.get(0); + final double startTime = 0; + final double endTime = tree.getNodeHeight(tree.getRoot()); + final int maxGridIndex = getMaxGridIndex(gridPointParameter, endTime); + + calculateTmpSums(); double tripleIntegrationSum = 0; double lineageEffectSqaredSum = 0; @@ -447,8 +474,12 @@ protected double calculateLogLikelihood() { return logLikelihood; } + protected void handleModelChangedEvent(Model model, Object object, int index) { + super.handleModelChangedEvent(model, object, index); + tmpSumsKnown = false; + } - private double getLineageCountEffect(Tree tree, int node) { + private double getLineageCountEffect(Tree tree, int node) { if (tree.isExternal(tree.getNode(node))) { return 1; } else { From 385ca76ec9ef54fb220f2b52a903b27d48c265b3 Mon Sep 17 00:00:00 2001 From: xji3 Date: Fri, 30 Jun 2023 14:26:56 -0500 Subject: [PATCH 5/6] break large function into separate smaller (but still large) functions --- .../smooth/SmoothSkygridLikelihood.java | 157 ++++++++++-------- 1 file changed, 88 insertions(+), 69 deletions(-) diff --git a/src/dr/evomodel/coalescent/smooth/SmoothSkygridLikelihood.java b/src/dr/evomodel/coalescent/smooth/SmoothSkygridLikelihood.java index 27df1c602b..a5aea7ad22 100644 --- a/src/dr/evomodel/coalescent/smooth/SmoothSkygridLikelihood.java +++ b/src/dr/evomodel/coalescent/smooth/SmoothSkygridLikelihood.java @@ -382,96 +382,115 @@ protected double calculateLogLikelihood() { calculateTmpSums(); - double tripleIntegrationSum = 0; - double lineageEffectSqaredSum = 0; + double lineageEffectSquaredSum = 0; for (int i = 0; i < uniqueTimes; i++) { - final double lineageCountEffect = tmpLineageEffect[i]; - lineageEffectSqaredSum += lineageCountEffect * lineageCountEffect; - tripleIntegrationSum += lineageCountEffect * tmpA[i] * tmpB[i] * tmpC[i]; + lineageEffectSquaredSum += tmpLineageEffect[i] * tmpLineageEffect[i]; } - tripleIntegrationSum *= 2; - for (int k = 0; k < maxGridIndex; k++) { - final double currentPopSizeInverse = Math.exp(-logPopSizeParameter.getParameterValue(k)); - final double nextPopSizeInverse = Math.exp(-logPopSizeParameter.getParameterValue(k + 1)); - tripleIntegrationSum += (nextPopSizeInverse - currentPopSizeInverse) * tmpF[k] * tmpD[k]; - } + double tripleIntegrationSum = getTripleIntegration(startTime, endTime, maxGridIndex, lineageEffectSquaredSum); - tripleIntegrationSum /= -smoothRate.getParameterValue(0) * 2; - tripleIntegrationSum += -0.5 * (1 - lineageEffectSqaredSum) - * (Math.exp(-logPopSizeParameter.getParameterValue(maxGridIndex)) - Math.exp(-logPopSizeParameter.getParameterValue(0))) - * (endTime - startTime); + double doubleIntegrationSum = getDoubleIntegration(startTime, endTime, maxGridIndex, lineageEffectSquaredSum); - double tripleWithQuadraticIntegrationSum = 0; - final double commonFirstTermMultiplier = (Math.exp(-logPopSizeParameter.getParameterValue(maxGridIndex)) - Math.exp(-logPopSizeParameter.getParameterValue(0))) * (endTime - startTime); - for (int i = 0; i < uniqueTimes; i++) { - final double lineageCountEffect = tmpLineageEffect[i] * tmpLineageEffect[i]; - final double timeI = tmpTimes[i]; - double thisResult = commonFirstTermMultiplier; - final double commonSecondTermMultiplier = smoothFunction.getInverseOnePlusExponential(timeI - startTime, smoothRate.getParameterValue(0)) - - smoothFunction.getInverseOnePlusExponential(timeI - endTime, smoothRate.getParameterValue(0)); - for (int k = 0; k < maxGridIndex; k++) { - final double currentPopSizeInverse = Math.exp(-logPopSizeParameter.getParameterValue(k)); - final double nextPopSizeInverse = Math.exp(-logPopSizeParameter.getParameterValue(k + 1)); - final double gridTime = gridPointParameter.getParameterValue(k); - final double inverse = smoothFunction.getInverseOneMinusExponential(gridTime - timeI, smoothRate.getParameterValue(0)); - thisResult += (nextPopSizeInverse - currentPopSizeInverse) / smoothRate.getParameterValue(0) - * (inverse * commonSecondTermMultiplier + (2.0 - inverse) * inverse * tmpC[i] + - (1 - inverse) * (1 - inverse) * tmpD[k]); - } - thisResult *= lineageCountEffect; - tripleWithQuadraticIntegrationSum += thisResult; - } - tripleWithQuadraticIntegrationSum *= -0.5; + final double singleIntegration = getSingleIntegration(startTime, endTime); - double firstDoubleIntegrationOffDiagonalSum = 0; - double firstDoubleIntegrationDiagonalSum = 0; - for (int i = 0; i < uniqueTimes; i++) { - final double lineageCountEffect = tmpLineageEffect[i]; - final double timeI = tmpTimes[i]; - firstDoubleIntegrationOffDiagonalSum += lineageCountEffect * tmpA[i] * tmpC[i]; - firstDoubleIntegrationDiagonalSum += lineageCountEffect * lineageCountEffect - * smoothFunction.getQuadraticIntegration(startTime, endTime, timeI, smoothRate.getParameterValue(0)); + double logPopulationSizeInverse = 0; + for (int i = 0; i < tree.getInternalNodeCount(); i++) { + NodeRef node = tree.getNode(tree.getExternalNodeCount() + i); + logPopulationSizeInverse += Math.log(getSmoothPopulationSizeInverse(tree.getNodeHeight(node), tree.getNodeHeight(tree.getRoot()))); } - firstDoubleIntegrationOffDiagonalSum /= smoothRate.getParameterValue(0); - firstDoubleIntegrationOffDiagonalSum += 0.5 * (1 - lineageEffectSqaredSum) * (endTime - startTime); - final double firstDoubleIntegrationSum = -(firstDoubleIntegrationDiagonalSum * 0.5 + firstDoubleIntegrationOffDiagonalSum) * Math.exp(-logPopSizeParameter.getParameterValue(0)); + logLikelihood = logPopulationSizeInverse + singleIntegration + doubleIntegrationSum + tripleIntegrationSum; - double secondDoubleIntegrationSum = 0; - for (int i = 0; i < uniqueTimes; i++) { - secondDoubleIntegrationSum += 0.5 * tmpB[i] * tmpC[i] * tmpLineageEffect[i]; - } + likelihoodKnown = true; + } + return logLikelihood; + } + double getTripleIntegration(double startTime, double endTime, int maxGridIndex, double lineageEffectSquaredSum) { + double tripleIntegrationSum = 0; + for (int i = 0; i < uniqueTimes; i++) { + final double lineageCountEffect = tmpLineageEffect[i]; + tripleIntegrationSum += lineageCountEffect * tmpA[i] * tmpB[i] * tmpC[i]; + } + tripleIntegrationSum *= 2; + + for (int k = 0; k < maxGridIndex; k++) { + final double currentPopSizeInverse = Math.exp(-logPopSizeParameter.getParameterValue(k)); + final double nextPopSizeInverse = Math.exp(-logPopSizeParameter.getParameterValue(k + 1)); + tripleIntegrationSum += (nextPopSizeInverse - currentPopSizeInverse) * tmpF[k] * tmpD[k]; + } + + tripleIntegrationSum /= -smoothRate.getParameterValue(0) * 2; + tripleIntegrationSum += -0.5 * (1 - lineageEffectSquaredSum) + * (Math.exp(-logPopSizeParameter.getParameterValue(maxGridIndex)) - Math.exp(-logPopSizeParameter.getParameterValue(0))) + * (endTime - startTime); + + double tripleWithQuadraticIntegrationSum = 0; + final double commonFirstTermMultiplier = (Math.exp(-logPopSizeParameter.getParameterValue(maxGridIndex)) - Math.exp(-logPopSizeParameter.getParameterValue(0))) * (endTime - startTime); + for (int i = 0; i < uniqueTimes; i++) { + final double lineageCountEffect = tmpLineageEffect[i] * tmpLineageEffect[i]; + final double timeI = tmpTimes[i]; + double thisResult = commonFirstTermMultiplier; + final double commonSecondTermMultiplier = smoothFunction.getInverseOnePlusExponential(timeI - startTime, smoothRate.getParameterValue(0)) + - smoothFunction.getInverseOnePlusExponential(timeI - endTime, smoothRate.getParameterValue(0)); for (int k = 0; k < maxGridIndex; k++) { final double currentPopSizeInverse = Math.exp(-logPopSizeParameter.getParameterValue(k)); final double nextPopSizeInverse = Math.exp(-logPopSizeParameter.getParameterValue(k + 1)); - secondDoubleIntegrationSum += 0.5 * tmpE[k] * tmpD[k] * (nextPopSizeInverse - currentPopSizeInverse); + final double gridTime = gridPointParameter.getParameterValue(k); + final double inverse = smoothFunction.getInverseOneMinusExponential(gridTime - timeI, smoothRate.getParameterValue(0)); + thisResult += (nextPopSizeInverse - currentPopSizeInverse) / smoothRate.getParameterValue(0) + * (inverse * commonSecondTermMultiplier + (2.0 - inverse) * inverse * tmpC[i] + + (1 - inverse) * (1 - inverse) * tmpD[k]); } + thisResult *= lineageCountEffect; + tripleWithQuadraticIntegrationSum += thisResult; + } + tripleWithQuadraticIntegrationSum *= -0.5; + return tripleIntegrationSum + tripleWithQuadraticIntegrationSum; + } - secondDoubleIntegrationSum /= smoothRate.getParameterValue(0); - secondDoubleIntegrationSum += 0.5 * (endTime - startTime) * (Math.exp(-logPopSizeParameter.getParameterValue(maxGridIndex)) - Math.exp(-logPopSizeParameter.getParameterValue(0))); + double getDoubleIntegration(double startTime, double endTime, int maxGridIndex, double lineageEffectSquaredSum) { + double firstDoubleIntegrationOffDiagonalSum = 0; + double firstDoubleIntegrationDiagonalSum = 0; + for (int i = 0; i < uniqueTimes; i++) { + final double lineageCountEffect = tmpLineageEffect[i]; + final double timeI = tmpTimes[i]; + firstDoubleIntegrationOffDiagonalSum += lineageCountEffect * tmpA[i] * tmpC[i]; + firstDoubleIntegrationDiagonalSum += lineageCountEffect * lineageCountEffect + * smoothFunction.getQuadraticIntegration(startTime, endTime, timeI, smoothRate.getParameterValue(0)); + } + firstDoubleIntegrationOffDiagonalSum /= smoothRate.getParameterValue(0); + firstDoubleIntegrationOffDiagonalSum += 0.5 * (1 - lineageEffectSquaredSum) * (endTime - startTime); + final double firstDoubleIntegrationSum = -(firstDoubleIntegrationDiagonalSum * 0.5 + firstDoubleIntegrationOffDiagonalSum) * Math.exp(-logPopSizeParameter.getParameterValue(0)); - double singleIntegration = 0; - for (int i = 0; i < uniqueTimes; i++) { - final double timeI = tmpTimes[i]; - final double lineageCountEffectI = tmpLineageEffect[i]; - singleIntegration += lineageCountEffectI * smoothFunction.getSingleIntegration(startTime, endTime, timeI, smoothRate.getParameterValue(0)); - } - singleIntegration *= 0.5 * Math.exp(-logPopSizeParameter.getParameterValue(0)); + double secondDoubleIntegrationSum = 0; + for (int i = 0; i < uniqueTimes; i++) { + secondDoubleIntegrationSum += 0.5 * tmpB[i] * tmpC[i] * tmpLineageEffect[i]; + } - double logPopulationSizeInverse = 0; - for (int i = 0; i < tree.getInternalNodeCount(); i++) { - NodeRef node = tree.getNode(tree.getExternalNodeCount() + i); - logPopulationSizeInverse += Math.log(getSmoothPopulationSizeInverse(tree.getNodeHeight(node), tree.getNodeHeight(tree.getRoot()))); - } + for (int k = 0; k < maxGridIndex; k++) { + final double currentPopSizeInverse = Math.exp(-logPopSizeParameter.getParameterValue(k)); + final double nextPopSizeInverse = Math.exp(-logPopSizeParameter.getParameterValue(k + 1)); + secondDoubleIntegrationSum += 0.5 * tmpE[k] * tmpD[k] * (nextPopSizeInverse - currentPopSizeInverse); + } - logLikelihood = logPopulationSizeInverse + singleIntegration + firstDoubleIntegrationSum + secondDoubleIntegrationSum + tripleIntegrationSum + tripleWithQuadraticIntegrationSum; + secondDoubleIntegrationSum /= smoothRate.getParameterValue(0); + secondDoubleIntegrationSum += 0.5 * (endTime - startTime) * (Math.exp(-logPopSizeParameter.getParameterValue(maxGridIndex)) - Math.exp(-logPopSizeParameter.getParameterValue(0))); - likelihoodKnown = true; + return firstDoubleIntegrationSum + secondDoubleIntegrationSum; + + } + + private double getSingleIntegration(double startTime, double endTime) { + double singleIntegration = 0; + for (int i = 0; i < uniqueTimes; i++) { + final double timeI = tmpTimes[i]; + final double lineageCountEffectI = tmpLineageEffect[i]; + singleIntegration += lineageCountEffectI * smoothFunction.getSingleIntegration(startTime, endTime, timeI, smoothRate.getParameterValue(0)); } - return logLikelihood; + singleIntegration *= 0.5 * Math.exp(-logPopSizeParameter.getParameterValue(0)); + return singleIntegration; } protected void handleModelChangedEvent(Model model, Object object, int index) { From 8540f368cadda4082b91060d14c060eaeea75019 Mon Sep 17 00:00:00 2001 From: xji3 Date: Mon, 17 Jul 2023 16:00:25 -0500 Subject: [PATCH 6/6] intermediate commit, implementing gradient wrt node height for smooth skygrid --- .../smooth/SmoothSkygridLikelihood.java | 142 ++++++++++++++++++ 1 file changed, 142 insertions(+) diff --git a/src/dr/evomodel/coalescent/smooth/SmoothSkygridLikelihood.java b/src/dr/evomodel/coalescent/smooth/SmoothSkygridLikelihood.java index a5aea7ad22..0d5080bbd5 100644 --- a/src/dr/evomodel/coalescent/smooth/SmoothSkygridLikelihood.java +++ b/src/dr/evomodel/coalescent/smooth/SmoothSkygridLikelihood.java @@ -81,6 +81,9 @@ public SmoothSkygridLikelihood(String name, this.tmpA = new double[trees.get(0).getNodeCount()]; this.tmpB = new double[trees.get(0).getNodeCount()]; this.tmpC = new double[trees.get(0).getNodeCount()]; + this.tmpADerivOverS = new double[trees.get(0).getNodeCount()]; + this.tmpBDerivOverS = new double[trees.get(0).getNodeCount()]; + this.tmpCDerivOverS = new double[trees.get(0).getNodeCount()]; this.tmpD = new double[gridPointParameter.getDimension()]; this.tmpE = new double[gridPointParameter.getDimension()]; this.tmpF = new double[gridPointParameter.getDimension()]; @@ -270,8 +273,11 @@ private double getLineageCountDifference(int intervalIndex, BigFastTreeIntervals } private double[] tmpA; + private double[] tmpADerivOverS; private double[] tmpB; + private double[] tmpBDerivOverS; private double[] tmpC; + private double[] tmpCDerivOverS; private double[] tmpD; private double[] tmpE; private double[] tmpF; @@ -372,6 +378,49 @@ private void calculateTmpSums() { } } + private void calculateTmpSumDerivatives() { + if (!tmpSumsKnown) { + calculateTmpSums(); + } + + TreeModel tree = trees.get(0); + final double startTime = 0; + final double endTime = tree.getNodeHeight(tree.getRoot()); + final int maxGridIndex = getMaxGridIndex(gridPointParameter, endTime); + + for (int i = 0; i < uniqueTimes; i++) { + final double timeI = tmpTimes[i]; + double sum = 0; + for (int j = 0; j < uniqueTimes; j++) { + if (j != i) { + final double timeJ = tmpTimes[j]; + final double lineageCountEffect = tmpLineageEffect[j]; + final double thisInverse = smoothFunction.getInverseOneMinusExponential(timeJ - timeI, smoothRate.getParameterValue(0)); + sum += lineageCountEffect * thisInverse * (1 - thisInverse); + } + } + tmpADerivOverS[i] = - sum; + } + + for (int i = 0; i < uniqueTimes; i++) { + final double timeI = tmpTimes[i]; + double sum = 0; + for (int k = 0; k < maxGridIndex; k++) { + final double currentPopSizeInverse = Math.exp(-logPopSizeParameter.getParameterValue(k)); + final double nextPopSizeInverse = Math.exp(-logPopSizeParameter.getParameterValue(k + 1)); + final double gridTime = gridPointParameter.getParameterValue(k); + final double thisInverse = smoothFunction.getInverseOneMinusExponential(gridTime - timeI, smoothRate.getParameterValue(0)); + sum += (nextPopSizeInverse - currentPopSizeInverse) * thisInverse * (1 - thisInverse); + } + tmpBDerivOverS[i] = -sum; + } + + for (int i = 0; i < uniqueTimes; i++) { + final double timeI = tmpTimes[i]; + tmpCDerivOverS[i] = smoothFunction.getSingleIntegrationDerivative(startTime, endTime, timeI, smoothRate.getParameterValue(0)); + } + } + protected double calculateLogLikelihood() { assert(trees.size() == 1); if (!likelihoodKnown) { @@ -406,6 +455,26 @@ protected double calculateLogLikelihood() { return logLikelihood; } + private double[] getGradientWrtNodeHeightNew() { + if (!likelihoodKnown) { + calculateLogLikelihood(); + } + TreeModel tree = trees.get(0); + final double startTime = 0; + final double endTime = tree.getNodeHeight(tree.getRoot()); + double[] gradient = new double[tree.getInternalNodeCount()]; + getGradientWrtNodeHeightFromSingleIntegration(startTime, endTime, gradient); + + double lineageEffectSquaredSum = 0; + for (int i = 0; i < uniqueTimes; i++) { + lineageEffectSquaredSum += tmpLineageEffect[i] * tmpLineageEffect[i]; + } + getGradientWrtNodeHeightFromDoubleIntegration(startTime, endTime, getMaxGridIndex(gridPointParameter, endTime), gradient); + + getGradientWrtNodeHeightFromTripleIntegration(startTime, endTime, getMaxGridIndex(gridPointParameter, endTime), gradient); + return gradient; + } + double getTripleIntegration(double startTime, double endTime, int maxGridIndex, double lineageEffectSquaredSum) { double tripleIntegrationSum = 0; for (int i = 0; i < uniqueTimes; i++) { @@ -449,6 +518,42 @@ protected double calculateLogLikelihood() { return tripleIntegrationSum + tripleWithQuadraticIntegrationSum; } + private void getGradientWrtNodeHeightFromTripleIntegration(double startTime, double endTime, int maxGridIndex, + double[] gradient) { + for (int i = 0; i < uniqueTimes; i++) { + final double lineageCountEffect = tmpLineageEffect[i]; + final double timeI = tmpTimes[i]; + gradient[i] += lineageCountEffect * (tmpADerivOverS[i] * tmpB[i] * tmpC[i] + tmpA[i] * tmpBDerivOverS[i] * tmpC[i] + tmpA[i] * tmpB[i] * tmpCDerivOverS[i]); + for (int k = 0; k < maxGridIndex; k++) { + final double currentPopSizeInverse = Math.exp(-logPopSizeParameter.getParameterValue(k)); + final double nextPopSizeInverse = Math.exp(-logPopSizeParameter.getParameterValue(k + 1)); + final double gridTime = gridPointParameter.getParameterValue(k); + final double tmpEInverse = smoothFunction.getInverseOneMinusExponential(timeI - gridTime, smoothRate.getParameterValue(0)); + + gradient[i] += (nextPopSizeInverse - currentPopSizeInverse) * tmpD[k] * (tmpE[k] - lineageCountEffect * tmpEInverse ) * tmpEInverse * (1 - tmpEInverse) * lineageCountEffect; + } + + + final double startTimeInverse = smoothFunction.getInverseOnePlusExponential(timeI - startTime, smoothRate.getParameterValue(0)); + final double endTimeInverse = smoothFunction.getInverseOnePlusExponential(timeI - startTime, smoothRate.getParameterValue(0)); + final double commonSecondTermMultiplier = startTimeInverse - endTimeInverse; + final double commonSecondTermMultiplierDerivativeOverS = - startTimeInverse * (1 - startTimeInverse) + endTimeInverse * (1 - endTimeInverse); + + for (int k = 0; k < maxGridIndex; k++) { + final double currentPopSizeInverse = Math.exp(-logPopSizeParameter.getParameterValue(k)); + final double nextPopSizeInverse = Math.exp(-logPopSizeParameter.getParameterValue(k + 1)); + final double gridTime = gridPointParameter.getParameterValue(k); + final double inverse = smoothFunction.getInverseOneMinusExponential(gridTime - timeI, smoothRate.getParameterValue(0)); + final double inverseDerivativeOverS = -inverse * (1 - inverse); + gradient[i] += (nextPopSizeInverse - currentPopSizeInverse) + * (inverseDerivativeOverS * commonSecondTermMultiplier + inverse * commonSecondTermMultiplierDerivativeOverS + + 2 * (1 - inverse) * inverseDerivativeOverS * tmpC[i] + (2.0 - inverse) * inverse * tmpCDerivOverS[i] + + 2 * (1 - inverse) * (-inverseDerivativeOverS) * tmpD[k]); + } + } + + } + double getDoubleIntegration(double startTime, double endTime, int maxGridIndex, double lineageEffectSquaredSum) { double firstDoubleIntegrationOffDiagonalSum = 0; double firstDoubleIntegrationDiagonalSum = 0; @@ -482,6 +587,34 @@ protected double calculateLogLikelihood() { } + private void getGradientWrtNodeHeightFromDoubleIntegration(double startTime, double endTime, int maxGridIndex, + double[] gradient) { + final double firstPopSize = Math.exp(-logPopSizeParameter.getParameterValue(0)); + for (int i = 0; i < uniqueTimes; i++) { + final double lineageCountEffect = tmpLineageEffect[i]; + final double timeI = tmpTimes[i]; + //firstDoubleIntegrationOffDiagonalSum += lineageCountEffect * tmpA[i] * tmpC[i]; + gradient[i] += -lineageCountEffect * (tmpA[i] * tmpCDerivOverS[i] + tmpADerivOverS[i] * tmpC[i]) * firstPopSize; + + //firstDoubleIntegrationDiagonalSum + gradient[i] += lineageCountEffect * lineageCountEffect + * (smoothFunction.getSingleIntegrationDerivative(startTime, endTime, timeI, smoothRate.getParameterValue(0)) + + (smoothFunction.getDerivative(timeI, endTime, 0, 1, smoothRate.getParameterValue(0)) + - smoothFunction.getDerivative(timeI, startTime, 0, 1, smoothRate.getParameterValue(0)) / smoothRate.getParameterValue(0)) + ) * -0.5 * firstPopSize; + + gradient[i] += 0.5 * tmpLineageEffect[i] * (tmpB[i] * tmpCDerivOverS[i] + tmpBDerivOverS[i] * tmpC[i]); + + for (int k = 0; k < maxGridIndex; k++) { + final double currentPopSizeInverse = Math.exp(-logPopSizeParameter.getParameterValue(k)); + final double nextPopSizeInverse = Math.exp(-logPopSizeParameter.getParameterValue(k + 1)); + final double gridTime = gridPointParameter.getParameterValue(k); + final double tmpEInverse = smoothFunction.getInverseOneMinusExponential(timeI - gridTime, smoothRate.getParameterValue(0)); + gradient[i] += 0.5 * tmpD[k] * (nextPopSizeInverse - currentPopSizeInverse) * tmpEInverse * (1 - tmpEInverse) * lineageCountEffect; + } + } + } + private double getSingleIntegration(double startTime, double endTime) { double singleIntegration = 0; for (int i = 0; i < uniqueTimes; i++) { @@ -493,6 +626,15 @@ private double getSingleIntegration(double startTime, double endTime) { return singleIntegration; } + private void getGradientWrtNodeHeightFromSingleIntegration(double startTime, double endTime, double[] gradient) { + for (int i = 0; i < uniqueTimes; i++) { + final double timeI = tmpTimes[i]; + final double lineageCountEffectI = tmpLineageEffect[i]; + gradient[i] += lineageCountEffectI * smoothFunction.getSingleIntegrationDerivative(startTime, endTime, timeI, smoothRate.getParameterValue(0)) + * 0.5 * Math.exp(-logPopSizeParameter.getParameterValue(0)); + } + } + protected void handleModelChangedEvent(Model model, Object object, int index) { super.handleModelChangedEvent(model, object, index); tmpSumsKnown = false;