From c8b32fa45e5e512ea92727c1fe309b80fc54373e Mon Sep 17 00:00:00 2001 From: "Marc A. Suchard" Date: Sat, 9 Dec 2023 08:08:30 -0800 Subject: [PATCH] solve first set of store/restore issues --- .../coalescent/basta/BastaLikelihood.java | 113 ++++++++---------- .../basta/BastaLikelihoodDelegate.java | 28 +++-- .../basta/GenericBastaLikelihoodDelegate.java | 26 ++++ .../basta/NativeBastaLikelihoodDelegate.java | 5 + 4 files changed, 103 insertions(+), 69 deletions(-) diff --git a/src/dr/evomodel/coalescent/basta/BastaLikelihood.java b/src/dr/evomodel/coalescent/basta/BastaLikelihood.java index 96be4d227b..c7c8351abd 100644 --- a/src/dr/evomodel/coalescent/basta/BastaLikelihood.java +++ b/src/dr/evomodel/coalescent/basta/BastaLikelihood.java @@ -30,6 +30,7 @@ import dr.evolution.tree.Tree; import dr.evolution.tree.TreeTrait; import dr.evolution.tree.TreeTraitProvider; +import dr.evomodel.bigfasttree.BestSignalsFromBigFastTreeIntervals; import dr.evomodel.bigfasttree.BigFastTreeIntervals; import dr.evomodel.branchratemodel.BranchRateModel; import dr.evomodel.branchratemodel.StrictClockBranchRates; @@ -77,8 +78,6 @@ public class BastaLikelihood extends AbstractModelLikelihood implements private double storedLogLikelihood; protected boolean likelihoodKnown; - private final boolean isTreeRandom; - public BastaLikelihood(String name, Tree treeModel, PatternList patternList, @@ -111,10 +110,6 @@ public BastaLikelihood(String name, addModel(likelihoodDelegate); this.tree = treeModel; - isTreeRandom = (treeModel instanceof AbstractModel) && ((AbstractModel) treeModel).isVariable(); - if (isTreeRandom) { - addModel((AbstractModel)treeModel); // TODO maybe unnecessary as BFTI already signals - } this.branchRateModel = branchRateModel; addModel(branchRateModel); @@ -127,10 +122,14 @@ public BastaLikelihood(String name, this.stateCount = substitutionModel.getDataType().getStateCount(); - treeIntervals = new BigFastTreeIntervals((TreeModel)treeModel); - treeTraversalDelegate = new CoalescentIntervalTraversal(treeModel, treeIntervals, branchRateModel, numberSubIntervals); + if (tree instanceof TreeModel) { + treeIntervals = new BestSignalsFromBigFastTreeIntervals((TreeModel) treeModel); + addModel(treeIntervals); + } else { + throw new RuntimeException("Not yet implemented"); + } - addModel(treeIntervals); + treeTraversalDelegate = new CoalescentIntervalTraversal(treeModel, treeIntervals, branchRateModel, numberSubIntervals); setTipData(); @@ -197,49 +196,53 @@ public final void makeDirty() { } protected void handleVariableChangedEvent(Variable variable, int index, Parameter.ChangeType type) { - // TODO + if (variable == popSizeParameter) { + likelihoodKnown = false; + } else { + throw new RuntimeException("Not yet implemented"); + } } @Override @SuppressWarnings("Duplicates") protected final void handleModelChangedEvent(Model model, Object object, int index) { - if (model == tree) { - if (object instanceof TreeChangedEvent) { - - final TreeChangedEvent treeChangedEvent = (TreeChangedEvent) object; - - if (!isTreeRandom) throw new IllegalStateException("Attempting to change a fixed tree"); - - if (treeChangedEvent.isNodeChanged()) { - // If a node event occurs the node and its two child nodes - // are flagged for updating this will result in everything - // above being updated as well. Node events occur when a node - // is added to a branch, removed from a branch or its height or - // rate changes. - updateNode(((TreeChangedEvent) object).getNode()); - } else if (treeChangedEvent.isTreeChanged()) { - // Full tree events result in a complete updating of the tree likelihood - // This event type is now used for EmpiricalTreeDistributions. - updateAllNodes(); - } - } - } else if (model == likelihoodDelegate) { - if (index == -1) { - updateAllNodes(); - } else { - updateNode(tree.getNode(index)); - } - - } else if (model == branchRateModel) { - if (index == -1) { - updateAllNodes(); - } else { - updateNode(tree.getNode(index)); - } - } else { - - assert false : "Unknown componentChangedEvent"; - } +// if (model == tree) { +// if (object instanceof TreeChangedEvent) { +// +// final TreeChangedEvent treeChangedEvent = (TreeChangedEvent) object; +// +// if (!isTreeRandom) throw new IllegalStateException("Attempting to change a fixed tree"); +// +// if (treeChangedEvent.isNodeChanged()) { +// // If a node event occurs the node and its two child nodes +// // are flagged for updating this will result in everything +// // above being updated as well. Node events occur when a node +// // is added to a branch, removed from a branch or its height or +// // rate changes. +// updateNode(((TreeChangedEvent) object).getNode()); +// } else if (treeChangedEvent.isTreeChanged()) { +// // Full tree events result in a complete updating of the tree likelihood +// // This event type is now used for EmpiricalTreeDistributions. +// updateAllNodes(); +// } +// } +// } else if (model == likelihoodDelegate) { +// if (index == -1) { +// updateAllNodes(); +// } else { +// updateNode(tree.getNode(index)); +// } +// +// } else if (model == branchRateModel) { +// if (index == -1) { +// updateAllNodes(); +// } else { +// updateNode(tree.getNode(index)); +// } +// } else { +// +// assert false : "Unknown componentChangedEvent"; +// } if (COUNT_TOTAL_OPERATIONS) totalModelChangedCount++; @@ -249,27 +252,18 @@ protected final void handleModelChangedEvent(Model model, Object object, int ind @Override protected final void storeState() { - assert (likelihoodKnown) : "the likelihood should always be known at this point in the cycle"; storedLogLikelihood = logLikelihood; - - if (TEST) treeIntervals.storeModelState(); } @Override protected final void restoreState() { - - // restore the likelihood and flag it as known logLikelihood = storedLogLikelihood; likelihoodKnown = true; - - if (TEST) treeIntervals.restoreModelState(); } @Override - protected void acceptState() { - if (TEST) treeIntervals.acceptModelState(); - } // nothing to do + protected void acceptState() { } // nothing to do private double calculateLogLikelihood() { @@ -287,10 +281,7 @@ private double calculateLogLikelihood() { final List matrixOperations = treeTraversalDelegate.getMatrixOperations(); final List intervalStarts = treeTraversalDelegate.getIntervalStarts(); - - final List nodeOperations = - treeTraversalDelegate.getOtherOperations(); - + if (COUNT_TOTAL_OPERATIONS) { totalPropagationCount += branchOperations.size(); totalMatrixUpdateCount += matrixOperations.size(); diff --git a/src/dr/evomodel/coalescent/basta/BastaLikelihoodDelegate.java b/src/dr/evomodel/coalescent/basta/BastaLikelihoodDelegate.java index 45875c2943..695d5a7f89 100644 --- a/src/dr/evomodel/coalescent/basta/BastaLikelihoodDelegate.java +++ b/src/dr/evomodel/coalescent/basta/BastaLikelihoodDelegate.java @@ -79,31 +79,37 @@ default void setPopulationSizes(int index, double[] sizes) { abstract class AbstractBastaLikelihoodDelegate extends AbstractModel implements BastaLikelihoodDelegate, Citable { + protected static final boolean PRINT_COMMANDS = true; + protected final int maxNumCoalescentIntervals; protected final ParallelizationScheme parallelizationScheme; protected final int stateCount; + protected final Tree tree; + public AbstractBastaLikelihoodDelegate(String name, Tree tree, int stateCount) { super(name); + this.tree = tree; this.stateCount = stateCount; this.maxNumCoalescentIntervals = getMaxNumberOfCoalescentIntervals(tree); this.parallelizationScheme = ParallelizationScheme.NONE; } private int getMaxNumberOfCoalescentIntervals(Tree tree) { - BigFastTreeIntervals intervals = new BigFastTreeIntervals((TreeModel) tree); // TODO fix BFTI to take a Tree - int zeroLengthSampling = 0; - for (int i = 0; i < intervals.getIntervalCount(); ++i) { - if (intervals.getIntervalType(i) == IntervalType.SAMPLE && intervals.getIntervalTime(i) == 0.0) { - ++zeroLengthSampling; - } - } - return tree.getNodeCount() - zeroLengthSampling; +// BigFastTreeIntervals intervals = new BigFastTreeIntervals((TreeModel) tree); // TODO fix BFTI to take a Tree +// int zeroLengthSampling = 0; +// for (int i = 0; i < intervals.getIntervalCount(); ++i) { +// if (intervals.getIntervalType(i) == IntervalType.SAMPLE && intervals.getIntervalTime(i) == 0.0) { +// ++zeroLengthSampling; +// } +// } +// return tree.getNodeCount() - zeroLengthSampling; + return tree.getNodeCount(); } @Override @@ -141,6 +147,8 @@ enum ParallelizationScheme { FULL } + abstract protected void clearAll(); + abstract protected double computeBranchIntervalOperations(List branchIntervalOperations); abstract protected double computeTransitionProbabilityOperations(List matrixOperations); @@ -154,6 +162,10 @@ public double calculateLikelihood(List branchOperations List intervalStarts, int rootNodeNumber) { + if (PRINT_COMMANDS) { + System.err.println("Tree = " + tree); + } + double logLikelihood = computeTransitionProbabilityOperations(matrixOperation); logLikelihood += computeBranchIntervalOperations(branchOperations); logLikelihood += computeCoalescentIntervalReduction(intervalStarts, branchOperations); diff --git a/src/dr/evomodel/coalescent/basta/GenericBastaLikelihoodDelegate.java b/src/dr/evomodel/coalescent/basta/GenericBastaLikelihoodDelegate.java index 6b2e0a3842..2eae2600c8 100644 --- a/src/dr/evomodel/coalescent/basta/GenericBastaLikelihoodDelegate.java +++ b/src/dr/evomodel/coalescent/basta/GenericBastaLikelihoodDelegate.java @@ -4,6 +4,7 @@ import dr.evomodel.substmodel.EigenDecomposition; import dr.math.matrixAlgebra.WrappedVector; +import java.util.Arrays; import java.util.List; /** @@ -46,9 +47,24 @@ public GenericBastaLikelihoodDelegate(String name, this.temp = new double[stateCount * stateCount]; } + @Override + protected void clearAll() { +// Arrays.fill(partials, tree.getExternalNodeCount() * stateCount, partials.length, 0.0); +// Arrays.fill(matrices, 0.0); +// Arrays.fill(sizes, 0.0); + +// Arrays.fill(coalescent, 0.0); +// Arrays.fill(e, 0.0); +// Arrays.fill(f, 0.0); +// Arrays.fill(g, 0.0); +// Arrays.fill(h, 0.0); + } + @Override protected double computeBranchIntervalOperations(List branchIntervalOperations) { + Arrays.fill(coalescent, 0.0); + for (BranchIntervalOperation operation : branchIntervalOperations) { // TODO execute parallel by intervalNumber or executionOrder peelPartials( partials, operation.outputBuffer, @@ -93,6 +109,11 @@ protected double computeTransitionProbabilityOperations(List intervalStarts, List branchIntervalOperations) { + Arrays.fill(e, 0.0); + Arrays.fill(f, 0.0); + Arrays.fill(g, 0.0); + Arrays.fill(h, 0.0); + for (int interval = 0; interval < intervalStarts.size() - 1; ++interval) { // TODO execute in parallel (no race conditions) int start = intervalStarts.get(interval); int end = intervalStarts.get(interval + 1); @@ -117,6 +138,11 @@ protected double computeCoalescentIntervalReduction(List intervalStarts sizes, coalescent, stateCount); } + if (PRINT_COMMANDS) { + System.err.println("logL = " + logL + "\n"); + } + +// return 0.0; return logL; } diff --git a/src/dr/evomodel/coalescent/basta/NativeBastaLikelihoodDelegate.java b/src/dr/evomodel/coalescent/basta/NativeBastaLikelihoodDelegate.java index 09fee2cdf5..9ce619bd3c 100644 --- a/src/dr/evomodel/coalescent/basta/NativeBastaLikelihoodDelegate.java +++ b/src/dr/evomodel/coalescent/basta/NativeBastaLikelihoodDelegate.java @@ -21,6 +21,11 @@ public NativeBastaLikelihoodDelegate(String name, jni = NativeBastaJniWrapper.getBastaJniWrapper(); } + @Override + protected void clearAll() { + + } + @Override protected double computeBranchIntervalOperations(List branchIntervalOperations) { if (branchIntervalOperations != null) {