Skip to content

Commit

Permalink
solve first set of store/restore issues
Browse files Browse the repository at this point in the history
  • Loading branch information
msuchard committed Dec 9, 2023
1 parent 7019ee5 commit c8b32fa
Show file tree
Hide file tree
Showing 4 changed files with 103 additions and 69 deletions.
113 changes: 52 additions & 61 deletions src/dr/evomodel/coalescent/basta/BastaLikelihood.java
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import dr.evolution.tree.Tree;
import dr.evolution.tree.TreeTrait;
import dr.evolution.tree.TreeTraitProvider;
import dr.evomodel.bigfasttree.BestSignalsFromBigFastTreeIntervals;
import dr.evomodel.bigfasttree.BigFastTreeIntervals;
import dr.evomodel.branchratemodel.BranchRateModel;
import dr.evomodel.branchratemodel.StrictClockBranchRates;
Expand Down Expand Up @@ -77,8 +78,6 @@ public class BastaLikelihood extends AbstractModelLikelihood implements
private double storedLogLikelihood;
protected boolean likelihoodKnown;

private final boolean isTreeRandom;

public BastaLikelihood(String name,
Tree treeModel,
PatternList patternList,
Expand Down Expand Up @@ -111,10 +110,6 @@ public BastaLikelihood(String name,
addModel(likelihoodDelegate);

this.tree = treeModel;
isTreeRandom = (treeModel instanceof AbstractModel) && ((AbstractModel) treeModel).isVariable();
if (isTreeRandom) {
addModel((AbstractModel)treeModel); // TODO maybe unnecessary as BFTI already signals
}

this.branchRateModel = branchRateModel;
addModel(branchRateModel);
Expand All @@ -127,10 +122,14 @@ public BastaLikelihood(String name,

this.stateCount = substitutionModel.getDataType().getStateCount();

treeIntervals = new BigFastTreeIntervals((TreeModel)treeModel);
treeTraversalDelegate = new CoalescentIntervalTraversal(treeModel, treeIntervals, branchRateModel, numberSubIntervals);
if (tree instanceof TreeModel) {
treeIntervals = new BestSignalsFromBigFastTreeIntervals((TreeModel) treeModel);
addModel(treeIntervals);
} else {
throw new RuntimeException("Not yet implemented");
}

addModel(treeIntervals);
treeTraversalDelegate = new CoalescentIntervalTraversal(treeModel, treeIntervals, branchRateModel, numberSubIntervals);

setTipData();

Expand Down Expand Up @@ -197,49 +196,53 @@ public final void makeDirty() {
}

protected void handleVariableChangedEvent(Variable variable, int index, Parameter.ChangeType type) {
// TODO
if (variable == popSizeParameter) {
likelihoodKnown = false;
} else {
throw new RuntimeException("Not yet implemented");
}
}

@Override @SuppressWarnings("Duplicates")
protected final void handleModelChangedEvent(Model model, Object object, int index) {

if (model == tree) {
if (object instanceof TreeChangedEvent) {

final TreeChangedEvent treeChangedEvent = (TreeChangedEvent) object;

if (!isTreeRandom) throw new IllegalStateException("Attempting to change a fixed tree");

if (treeChangedEvent.isNodeChanged()) {
// If a node event occurs the node and its two child nodes
// are flagged for updating this will result in everything
// above being updated as well. Node events occur when a node
// is added to a branch, removed from a branch or its height or
// rate changes.
updateNode(((TreeChangedEvent) object).getNode());
} else if (treeChangedEvent.isTreeChanged()) {
// Full tree events result in a complete updating of the tree likelihood
// This event type is now used for EmpiricalTreeDistributions.
updateAllNodes();
}
}
} else if (model == likelihoodDelegate) {
if (index == -1) {
updateAllNodes();
} else {
updateNode(tree.getNode(index));
}

} else if (model == branchRateModel) {
if (index == -1) {
updateAllNodes();
} else {
updateNode(tree.getNode(index));
}
} else {

assert false : "Unknown componentChangedEvent";
}
// if (model == tree) {
// if (object instanceof TreeChangedEvent) {
//
// final TreeChangedEvent treeChangedEvent = (TreeChangedEvent) object;
//
// if (!isTreeRandom) throw new IllegalStateException("Attempting to change a fixed tree");
//
// if (treeChangedEvent.isNodeChanged()) {
// // If a node event occurs the node and its two child nodes
// // are flagged for updating this will result in everything
// // above being updated as well. Node events occur when a node
// // is added to a branch, removed from a branch or its height or
// // rate changes.
// updateNode(((TreeChangedEvent) object).getNode());
// } else if (treeChangedEvent.isTreeChanged()) {
// // Full tree events result in a complete updating of the tree likelihood
// // This event type is now used for EmpiricalTreeDistributions.
// updateAllNodes();
// }
// }
// } else if (model == likelihoodDelegate) {
// if (index == -1) {
// updateAllNodes();
// } else {
// updateNode(tree.getNode(index));
// }
//
// } else if (model == branchRateModel) {
// if (index == -1) {
// updateAllNodes();
// } else {
// updateNode(tree.getNode(index));
// }
// } else {
//
// assert false : "Unknown componentChangedEvent";
// }

if (COUNT_TOTAL_OPERATIONS) totalModelChangedCount++;

Expand All @@ -249,27 +252,18 @@ protected final void handleModelChangedEvent(Model model, Object object, int ind

@Override
protected final void storeState() {

assert (likelihoodKnown) : "the likelihood should always be known at this point in the cycle";
storedLogLikelihood = logLikelihood;

if (TEST) treeIntervals.storeModelState();
}

@Override
protected final void restoreState() {

// restore the likelihood and flag it as known
logLikelihood = storedLogLikelihood;
likelihoodKnown = true;

if (TEST) treeIntervals.restoreModelState();
}

@Override
protected void acceptState() {
if (TEST) treeIntervals.acceptModelState();
} // nothing to do
protected void acceptState() { } // nothing to do

private double calculateLogLikelihood() {

Expand All @@ -287,10 +281,7 @@ private double calculateLogLikelihood() {
final List<TransitionMatrixOperation> matrixOperations =
treeTraversalDelegate.getMatrixOperations();
final List<Integer> intervalStarts = treeTraversalDelegate.getIntervalStarts();

final List<OtherOperation> nodeOperations =
treeTraversalDelegate.getOtherOperations();


if (COUNT_TOTAL_OPERATIONS) {
totalPropagationCount += branchOperations.size();
totalMatrixUpdateCount += matrixOperations.size();
Expand Down
28 changes: 20 additions & 8 deletions src/dr/evomodel/coalescent/basta/BastaLikelihoodDelegate.java
Original file line number Diff line number Diff line change
Expand Up @@ -79,31 +79,37 @@ default void setPopulationSizes(int index, double[] sizes) {

abstract class AbstractBastaLikelihoodDelegate extends AbstractModel implements BastaLikelihoodDelegate, Citable {

protected static final boolean PRINT_COMMANDS = true;

protected final int maxNumCoalescentIntervals;

protected final ParallelizationScheme parallelizationScheme;

protected final int stateCount;

protected final Tree tree;

public AbstractBastaLikelihoodDelegate(String name,
Tree tree,
int stateCount) {
super(name);

this.tree = tree;
this.stateCount = stateCount;
this.maxNumCoalescentIntervals = getMaxNumberOfCoalescentIntervals(tree);
this.parallelizationScheme = ParallelizationScheme.NONE;
}

private int getMaxNumberOfCoalescentIntervals(Tree tree) {
BigFastTreeIntervals intervals = new BigFastTreeIntervals((TreeModel) tree); // TODO fix BFTI to take a Tree
int zeroLengthSampling = 0;
for (int i = 0; i < intervals.getIntervalCount(); ++i) {
if (intervals.getIntervalType(i) == IntervalType.SAMPLE && intervals.getIntervalTime(i) == 0.0) {
++zeroLengthSampling;
}
}
return tree.getNodeCount() - zeroLengthSampling;
// BigFastTreeIntervals intervals = new BigFastTreeIntervals((TreeModel) tree); // TODO fix BFTI to take a Tree
// int zeroLengthSampling = 0;
// for (int i = 0; i < intervals.getIntervalCount(); ++i) {
// if (intervals.getIntervalType(i) == IntervalType.SAMPLE && intervals.getIntervalTime(i) == 0.0) {
// ++zeroLengthSampling;
// }
// }
// return tree.getNodeCount() - zeroLengthSampling;
return tree.getNodeCount();
}

@Override
Expand Down Expand Up @@ -141,6 +147,8 @@ enum ParallelizationScheme {
FULL
}

abstract protected void clearAll();

abstract protected double computeBranchIntervalOperations(List<BranchIntervalOperation> branchIntervalOperations);

abstract protected double computeTransitionProbabilityOperations(List<TransitionMatrixOperation> matrixOperations);
Expand All @@ -154,6 +162,10 @@ public double calculateLikelihood(List<BranchIntervalOperation> branchOperations
List<Integer> intervalStarts,
int rootNodeNumber) {

if (PRINT_COMMANDS) {
System.err.println("Tree = " + tree);
}

double logLikelihood = computeTransitionProbabilityOperations(matrixOperation);
logLikelihood += computeBranchIntervalOperations(branchOperations);
logLikelihood += computeCoalescentIntervalReduction(intervalStarts, branchOperations);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import dr.evomodel.substmodel.EigenDecomposition;
import dr.math.matrixAlgebra.WrappedVector;

import java.util.Arrays;
import java.util.List;

/**
Expand Down Expand Up @@ -46,9 +47,24 @@ public GenericBastaLikelihoodDelegate(String name,
this.temp = new double[stateCount * stateCount];
}

@Override
protected void clearAll() {
// Arrays.fill(partials, tree.getExternalNodeCount() * stateCount, partials.length, 0.0);
// Arrays.fill(matrices, 0.0);
// Arrays.fill(sizes, 0.0);

// Arrays.fill(coalescent, 0.0);
// Arrays.fill(e, 0.0);
// Arrays.fill(f, 0.0);
// Arrays.fill(g, 0.0);
// Arrays.fill(h, 0.0);
}

@Override
protected double computeBranchIntervalOperations(List<BranchIntervalOperation> branchIntervalOperations) {

Arrays.fill(coalescent, 0.0);

for (BranchIntervalOperation operation : branchIntervalOperations) { // TODO execute parallel by intervalNumber or executionOrder
peelPartials(
partials, operation.outputBuffer,
Expand Down Expand Up @@ -93,6 +109,11 @@ protected double computeTransitionProbabilityOperations(List<TransitionMatrixOpe
protected double computeCoalescentIntervalReduction(List<Integer> intervalStarts,
List<BranchIntervalOperation> branchIntervalOperations) {

Arrays.fill(e, 0.0);
Arrays.fill(f, 0.0);
Arrays.fill(g, 0.0);
Arrays.fill(h, 0.0);

for (int interval = 0; interval < intervalStarts.size() - 1; ++interval) { // TODO execute in parallel (no race conditions)
int start = intervalStarts.get(interval);
int end = intervalStarts.get(interval + 1);
Expand All @@ -117,6 +138,11 @@ protected double computeCoalescentIntervalReduction(List<Integer> intervalStarts
sizes, coalescent, stateCount);
}

if (PRINT_COMMANDS) {
System.err.println("logL = " + logL + "\n");
}

// return 0.0;
return logL;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,11 @@ public NativeBastaLikelihoodDelegate(String name,
jni = NativeBastaJniWrapper.getBastaJniWrapper();
}

@Override
protected void clearAll() {

}

@Override
protected double computeBranchIntervalOperations(List<BranchIntervalOperation> branchIntervalOperations) {
if (branchIntervalOperations != null) {
Expand Down

0 comments on commit c8b32fa

Please sign in to comment.