Skip to content

Commit

Permalink
clean code
Browse files Browse the repository at this point in the history
  • Loading branch information
msuchard committed Dec 9, 2023
1 parent c8b32fa commit f3b481c
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 52 deletions.
2 changes: 1 addition & 1 deletion src/dr/evomodel/coalescent/basta/BastaLikelihood.java
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ private double calculateLogLikelihood() {
final List<TransitionMatrixOperation> matrixOperations =
treeTraversalDelegate.getMatrixOperations();
final List<Integer> intervalStarts = treeTraversalDelegate.getIntervalStarts();

if (COUNT_TOTAL_OPERATIONS) {
totalPropagationCount += branchOperations.size();
totalMatrixUpdateCount += matrixOperations.size();
Expand Down
31 changes: 13 additions & 18 deletions src/dr/evomodel/coalescent/basta/BastaLikelihoodDelegate.java
Original file line number Diff line number Diff line change
Expand Up @@ -101,15 +101,14 @@ public AbstractBastaLikelihoodDelegate(String name,
}

private int getMaxNumberOfCoalescentIntervals(Tree tree) {
// BigFastTreeIntervals intervals = new BigFastTreeIntervals((TreeModel) tree); // TODO fix BFTI to take a Tree
// int zeroLengthSampling = 0;
// for (int i = 0; i < intervals.getIntervalCount(); ++i) {
// if (intervals.getIntervalType(i) == IntervalType.SAMPLE && intervals.getIntervalTime(i) == 0.0) {
// ++zeroLengthSampling;
// }
// }
// return tree.getNodeCount() - zeroLengthSampling;
return tree.getNodeCount();
BigFastTreeIntervals intervals = new BigFastTreeIntervals((TreeModel) tree); // TODO fix BFTI to take a Tree
int zeroLengthSampling = 0;
for (int i = 0; i < intervals.getIntervalCount(); ++i) {
if (intervals.getIntervalType(i) == IntervalType.SAMPLE && intervals.getIntervalTime(i) == 0.0) {
++zeroLengthSampling;
}
}
return tree.getNodeCount() - zeroLengthSampling;
}

@Override
Expand Down Expand Up @@ -147,11 +146,9 @@ enum ParallelizationScheme {
FULL
}

abstract protected void clearAll();

abstract protected double computeBranchIntervalOperations(List<BranchIntervalOperation> branchIntervalOperations);
abstract protected void computeBranchIntervalOperations(List<BranchIntervalOperation> branchIntervalOperations);

abstract protected double computeTransitionProbabilityOperations(List<TransitionMatrixOperation> matrixOperations);
abstract protected void computeTransitionProbabilityOperations(List<TransitionMatrixOperation> matrixOperations);

abstract protected double computeCoalescentIntervalReduction(List<Integer> intervalStarts,
List<BranchIntervalOperation> branchIntervalOperations);
Expand All @@ -166,11 +163,9 @@ public double calculateLikelihood(List<BranchIntervalOperation> branchOperations
System.err.println("Tree = " + tree);
}

double logLikelihood = computeTransitionProbabilityOperations(matrixOperation);
logLikelihood += computeBranchIntervalOperations(branchOperations);
logLikelihood += computeCoalescentIntervalReduction(intervalStarts, branchOperations);

return logLikelihood;
computeTransitionProbabilityOperations(matrixOperation);
computeBranchIntervalOperations(branchOperations);
return computeCoalescentIntervalReduction(intervalStarts, branchOperations);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@
*/
public class GenericBastaLikelihoodDelegate extends BastaLikelihoodDelegate.AbstractBastaLikelihoodDelegate {

private static final boolean PRINT_COMMANDS = true;

private final double[] partials;
private final double[] matrices;
private final double[] sizes;
Expand Down Expand Up @@ -48,20 +46,7 @@ public GenericBastaLikelihoodDelegate(String name,
}

@Override
protected void clearAll() {
// Arrays.fill(partials, tree.getExternalNodeCount() * stateCount, partials.length, 0.0);
// Arrays.fill(matrices, 0.0);
// Arrays.fill(sizes, 0.0);

// Arrays.fill(coalescent, 0.0);
// Arrays.fill(e, 0.0);
// Arrays.fill(f, 0.0);
// Arrays.fill(g, 0.0);
// Arrays.fill(h, 0.0);
}

@Override
protected double computeBranchIntervalOperations(List<BranchIntervalOperation> branchIntervalOperations) {
protected void computeBranchIntervalOperations(List<BranchIntervalOperation> branchIntervalOperations) {

Arrays.fill(coalescent, 0.0);

Expand All @@ -81,12 +66,10 @@ protected double computeBranchIntervalOperations(List<BranchIntervalOperation> b
new WrappedVector.Raw(partials, operation.outputBuffer * stateCount, stateCount));
}
}

return 0.0;
}

@Override
protected double computeTransitionProbabilityOperations(List<TransitionMatrixOperation> matrixOperations) {
protected void computeTransitionProbabilityOperations(List<TransitionMatrixOperation> matrixOperations) {

for (TransitionMatrixOperation operation : matrixOperations) { // TODO execute in parallel
computeTransitionProbabilities(
Expand All @@ -101,8 +84,6 @@ protected double computeTransitionProbabilityOperations(List<TransitionMatrixOpe
operation.outputBuffer * stateCount * stateCount, stateCount * stateCount));
}
}

return 0.0;
}

@Override
Expand Down Expand Up @@ -142,7 +123,6 @@ protected double computeCoalescentIntervalReduction(List<Integer> intervalStarts
System.err.println("logL = " + logL + "\n");
}

// return 0.0;
return logL;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,30 +22,21 @@ public NativeBastaLikelihoodDelegate(String name,
}

@Override
protected void clearAll() {

}

@Override
protected double computeBranchIntervalOperations(List<BranchIntervalOperation> branchIntervalOperations) {
protected void computeBranchIntervalOperations(List<BranchIntervalOperation> branchIntervalOperations) {
if (branchIntervalOperations != null) {
for (BranchIntervalOperation operation : branchIntervalOperations) {
System.err.println(operation.toString());
}
}

return 0.0;
}

@Override
protected double computeTransitionProbabilityOperations(List<TransitionMatrixOperation> matrixOperations) {
protected void computeTransitionProbabilityOperations(List<TransitionMatrixOperation> matrixOperations) {
if (matrixOperations != null) {
for (TransitionMatrixOperation operation : matrixOperations) {
System.err.println(operation.toString());
}
}

return 0.0;
}

@Override
Expand Down

0 comments on commit f3b481c

Please sign in to comment.