From c2ecacb8b204f2d4223df3a5c24c06f007b36675 Mon Sep 17 00:00:00 2001 From: yucais Date: Mon, 31 Jul 2023 21:17:50 -0700 Subject: [PATCH 1/4] remove duplicate code --- .../MasBirthDeathSerialSamplingModel.java | 16 ++--- ...TwoParamBirthDeathSerialSamplingModel.java | 63 +------------------ 2 files changed, 9 insertions(+), 70 deletions(-) diff --git a/src/dr/evomodel/speciation/MasBirthDeathSerialSamplingModel.java b/src/dr/evomodel/speciation/MasBirthDeathSerialSamplingModel.java index 022dedc952..b640885087 100644 --- a/src/dr/evomodel/speciation/MasBirthDeathSerialSamplingModel.java +++ b/src/dr/evomodel/speciation/MasBirthDeathSerialSamplingModel.java @@ -47,9 +47,9 @@ public final double processModelSegmentBreakPoint(int model, double intervalStar return lnL; } - final void accumulateGradientForInterval(final double[] gradient, final int currentModelSegment, final int nLineages, - final double[] partialQ_all_old, final double Q_Old, - final double[] partialQ_all_young, final double Q_young) { + void accumulateGradientForInterval(final double[] gradient, final int currentModelSegment, final int nLineages, + final double[] partialQ_all_old, final double Q_Old, + final double[] partialQ_all_young, final double Q_young) { for (int k = 0; k <= currentModelSegment; k++) { gradient[k * 5 + 0] += nLineages * (partialQ_all_old[k * 4 + 0] / Q_Old @@ -61,7 +61,7 @@ final void accumulateGradientForInterval(final double[] gradient, final int curr } } - final void accumulateGradientForSerialSampling(double[] gradient, int currentModelSegment, double term1, + void accumulateGradientForSerialSampling(double[] gradient, int currentModelSegment, double term1, double[] intermediate) { for (int k = 0; k <= currentModelSegment; k++) { @@ -71,7 +71,7 @@ final void accumulateGradientForSerialSampling(double[] gradient, int currentMod } } - final void accumulateGradientForIntensiveSampling(double[] gradient, int currentModelSegment, double term1, + void accumulateGradientForIntensiveSampling(double[] gradient, int currentModelSegment, double term1, double[] intermediate) { for (int k = 0; k < currentModelSegment; k++) { @@ -81,7 +81,7 @@ final void accumulateGradientForIntensiveSampling(double[] gradient, int current } } - final void dBCompute(int model, double[] dB) { + void dBCompute(int model, double[] dB) { for (int k = 0; k < model; ++k) { for (int p = 0; p < 4; p++) { @@ -96,7 +96,7 @@ final void dBCompute(int model, double[] dB) { dB[model * 4 + 2] = (A - dA[2] * (term1 * lambda + mu + psi)) / (A * A); } - final void dPCompute(int model, double t, double intervalStart, double eAt, double[] dP, double[] dG2) { + void dPCompute(int model, double t, double intervalStart, double eAt, double[] dP, double[] dG2) { double G1 = g1(eAt); @@ -121,7 +121,7 @@ final void dPCompute(int model, double t, double intervalStart, double eAt, doub } - final void dQCompute(int model, double t, double[] dQ, double eAt) { + void dQCompute(int model, double t, double[] dQ, double eAt) { double dwell = t - modelStartTimes[model]; double G1 = g1(eAt); diff --git a/src/dr/evomodel/speciation/TwoParamBirthDeathSerialSamplingModel.java b/src/dr/evomodel/speciation/TwoParamBirthDeathSerialSamplingModel.java index 405c425200..3930745171 100644 --- a/src/dr/evomodel/speciation/TwoParamBirthDeathSerialSamplingModel.java +++ b/src/dr/evomodel/speciation/TwoParamBirthDeathSerialSamplingModel.java @@ -27,25 +27,12 @@ import dr.inference.model.Parameter; -public class TwoParamBirthDeathSerialSamplingModel extends NewBirthDeathSerialSamplingModel { +public class TwoParamBirthDeathSerialSamplingModel extends MasBirthDeathSerialSamplingModel { public TwoParamBirthDeathSerialSamplingModel(Parameter birthRate, Parameter deathRate, Parameter serialSamplingRate, Parameter treatmentProbability, Parameter samplingProbability, Parameter originTime, boolean condition, int numIntervals, double gridEnd, Type units) { super(birthRate, deathRate, serialSamplingRate, treatmentProbability, samplingProbability, originTime, condition, numIntervals, gridEnd, units); } - @Override - public final double processModelSegmentBreakPoint(int model, double intervalStart, double intervalEnd, int nLineages) { -// double lnL = nLineages * (logQ(model, intervalEnd) - logQ(model, intervalStart)); - double lnL = nLineages * Math.log(Q(model, intervalEnd) / Q(model, intervalStart)); - if ( samplingProbability.getValue(model + 1) > 0.0 && samplingProbability.getValue(model + 1) < 1.0) { - // Add in probability of un-sampled lineages - // We don't need this at t=0 because all lineages in the tree are sampled - // TODO: check if we're right about how many lineages are actually alive at this time. Are we inadvertently over-counting or under-counting due to samples added at this _exact_ time? - lnL += nLineages * Math.log(1.0 - samplingProbability.getValue(model + 1)); - } - this.savedLogQ = Double.NaN; - return lnL; - } final void accumulateGradientForInterval(final double[] gradient, final int currentModelSegment, final int nLineages, final double[] partialQ_all_old, final double Q_Old, @@ -140,54 +127,6 @@ final void dQCompute(int model, double t, double[] dQ, double eAt) { dQ[model * 4 + 2] = term6 * (dA[2] * term7 - dB[model * 4 + 2] * term3); } - - final double Q(int model, double time) { - double At = A * (time - modelStartTimes[model]); - double eAt = Math.exp(At); - double sqrtDenominator = g1(eAt); - return eAt / (sqrtDenominator * sqrtDenominator); - } - - final double logQ(int model, double time) { - double At = A * (time - modelStartTimes[model]); - double eAt = Math.exp(At); - double sqrtDenominator = g1(eAt); - return At - 2 * Math.log(sqrtDenominator); // TODO log4 (additive constant) is not needed since we always see logQ(a) - logQ(b) - } - - @Override - public double processInterval(int model, double tYoung, double tOld, int nLineages) { - double logQ_young; - double logQ_old = Q(model, tOld); - if (!Double.isNaN(this.savedLogQ)) { - logQ_young = this.savedLogQ; - } else { - logQ_young = Q(model, tYoung); - } - this.savedLogQ = logQ_old; - return nLineages * Math.log(logQ_old / logQ_young); - } - - @Override - public double processSampling(int model, double tOld) { - - double logSampProb; - - boolean sampleIsAtEventTime = tOld == modelStartTimes[model]; - boolean samplesTakenAtEventTime = rho > 0; - - if (sampleIsAtEventTime && samplesTakenAtEventTime) { - logSampProb = Math.log(rho); - if (model > 0) { - logSampProb += Math.log(r + ((1.0 - r) * previousP)); - } - } else { - double logPsi = Math.log(psi); - logSampProb = logPsi + Math.log(r + (1.0 - r) * p(model,tOld)); - } - - return logSampProb; - } } /* From 30a664493572d47229af0aa754ccc80be6b3f312 Mon Sep 17 00:00:00 2001 From: Andy Magee Date: Tue, 1 Aug 2023 10:11:13 -0700 Subject: [PATCH 2/4] That was dumb --- .../tree/MixedEffectsRateStatistic.java | 4 ++-- .../TimeProportionToFixedEffectTransform.java | 22 ++++++++++++++++++- 2 files changed, 23 insertions(+), 3 deletions(-) diff --git a/src/dr/evomodel/tree/MixedEffectsRateStatistic.java b/src/dr/evomodel/tree/MixedEffectsRateStatistic.java index 6fbd29924a..3bcb3c64b2 100644 --- a/src/dr/evomodel/tree/MixedEffectsRateStatistic.java +++ b/src/dr/evomodel/tree/MixedEffectsRateStatistic.java @@ -91,7 +91,7 @@ private void prepareForComputation() { for (int i = 0; i < offset; i++) { NodeRef child = tree.getExternalNode(i); rates[i] = branchRateModel.getBranchRate(tree, child); - locations[i] = branchRateTransform.getLocation(tree, tree.getNode(i)); + locations[i] = branchRateTransform.getLocation(tree, child); } if (internal) { final int n = tree.getInternalNodeCount(); @@ -100,7 +100,7 @@ private void prepareForComputation() { NodeRef child = tree.getInternalNode(i); if (!tree.isRoot(child)) { rates[k] = branchRateModel.getBranchRate(tree, child); - locations[k] = branchRateTransform.getLocation(tree, tree.getNode(i)); + locations[k] = branchRateTransform.getLocation(tree, child); k++; } } diff --git a/src/dr/util/TimeProportionToFixedEffectTransform.java b/src/dr/util/TimeProportionToFixedEffectTransform.java index 776d9e0cea..e934d21870 100644 --- a/src/dr/util/TimeProportionToFixedEffectTransform.java +++ b/src/dr/util/TimeProportionToFixedEffectTransform.java @@ -43,10 +43,30 @@ protected double[] transform(double[] values) { double rateAncestral = Math.exp(values[1]); double rateDescendant = Math.exp(values[2]); double[] transformed = new double[1]; - transformed[0] = Math.log((rateDescendant * propTime + rateAncestral * (1.0 - propTime)) / (rateDescendant * propTime)); + transformed[0] = Math.log(propTime + (1.0 - propTime) * rateDescendant/rateAncestral); + verifyOutput(values); return transformed; } + private void verifyOutput(double[] values) { + double propTime = values[0]; + double rateAncestral = Math.exp(values[1]); + double rateDescendant = Math.exp(values[2]); + double[] transformed = new double[1]; + transformed[0] = Math.log(propTime + (1.0 - propTime) * rateDescendant/rateAncestral); + + double fe = transformed[0]; + double lb = rateAncestral; + double ub = rateDescendant; + if (rateAncestral > rateDescendant) { + lb = rateDescendant; + ub = rateAncestral; + } + + double newRate = rateAncestral * Math.exp(fe); + assert(newRate >= lb && newRate <= ub); + } + @Override protected double[] inverse(double[] values) { throw new RuntimeException("Not yet implemented"); From 83178d117d828259954e6edcbd7c1a5e9ecccfaf Mon Sep 17 00:00:00 2001 From: Andy Magee Date: Tue, 1 Aug 2023 16:51:13 -0700 Subject: [PATCH 3/4] That was also stupid --- src/dr/util/TimeProportionToFixedEffectTransform.java | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/dr/util/TimeProportionToFixedEffectTransform.java b/src/dr/util/TimeProportionToFixedEffectTransform.java index e934d21870..ebc3edcaee 100644 --- a/src/dr/util/TimeProportionToFixedEffectTransform.java +++ b/src/dr/util/TimeProportionToFixedEffectTransform.java @@ -43,8 +43,7 @@ protected double[] transform(double[] values) { double rateAncestral = Math.exp(values[1]); double rateDescendant = Math.exp(values[2]); double[] transformed = new double[1]; - transformed[0] = Math.log(propTime + (1.0 - propTime) * rateDescendant/rateAncestral); - verifyOutput(values); + transformed[0] = Math.log(propTime * rateDescendant / rateAncestral + (1.0 - propTime)); return transformed; } @@ -53,7 +52,7 @@ private void verifyOutput(double[] values) { double rateAncestral = Math.exp(values[1]); double rateDescendant = Math.exp(values[2]); double[] transformed = new double[1]; - transformed[0] = Math.log(propTime + (1.0 - propTime) * rateDescendant/rateAncestral); + transformed[0] = Math.log(propTime * rateDescendant / rateAncestral + (1.0 - propTime)); double fe = transformed[0]; double lb = rateAncestral; From 503b3897ae9d1cd840bb09c176b67efb3ff08bda Mon Sep 17 00:00:00 2001 From: Marc Suchard Date: Wed, 2 Aug 2023 12:43:43 -0700 Subject: [PATCH 4/4] clean up all warnings --- ...stimableStemWeightBranchSpecificBranchModel.java | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/src/dr/evomodel/branchmodel/EstimableStemWeightBranchSpecificBranchModel.java b/src/dr/evomodel/branchmodel/EstimableStemWeightBranchSpecificBranchModel.java index 3a2c5a2562..36c8fdbeb9 100644 --- a/src/dr/evomodel/branchmodel/EstimableStemWeightBranchSpecificBranchModel.java +++ b/src/dr/evomodel/branchmodel/EstimableStemWeightBranchSpecificBranchModel.java @@ -26,16 +26,10 @@ package dr.evomodel.branchmodel; import dr.evolution.tree.TreeUtils; -import dr.evolution.util.Taxon; -import dr.evomodel.substmodel.FrequencyModel; import dr.evomodel.substmodel.SubstitutionModel; import dr.evolution.tree.NodeRef; -import dr.evolution.tree.Tree; import dr.evolution.util.TaxonList; import dr.evomodel.tree.TreeModel; -import dr.evomodelxml.branchratemodel.LocalClockModelParser; -import dr.inference.model.AbstractModel; -import dr.inference.model.Model; import dr.inference.model.Parameter; import dr.inference.model.Variable; @@ -49,8 +43,9 @@ * @version $Id$ */ public class EstimableStemWeightBranchSpecificBranchModel extends BranchSpecificBranchModel { - private List stemWeightParameters = new ArrayList<>(); - protected Map stemWeightMap = new HashMap(); + + final private List stemWeightParameters = new ArrayList<>(); + protected Map stemWeightMap = new HashMap<>(); private boolean hasBackbone = false; public EstimableStemWeightBranchSpecificBranchModel(TreeModel treeModel, SubstitutionModel rootSubstitutionModel) { @@ -128,7 +123,7 @@ public void addBackbone(TaxonList taxonList, SubstitutionModel substitutionModel protected final void handleVariableChangedEvent(Variable variable, int index, Parameter.ChangeType type) { - if ( stemWeightParameters.contains(variable) && clades.size() > 0) { + if (variable instanceof Parameter && stemWeightParameters.contains((Parameter)variable) && clades.size() > 0) { setUpdateNodeMaps(true); } fireModelChanged();