Skip to content

Commit

Permalink
first working version of vectorized BASTA
Browse files Browse the repository at this point in the history
  • Loading branch information
msuchard committed Dec 8, 2023
1 parent 2fc2f63 commit 7019ee5
Show file tree
Hide file tree
Showing 3 changed files with 156 additions and 28 deletions.
22 changes: 17 additions & 5 deletions src/dr/evomodel/coalescent/basta/CoalescentIntervalTraversal.java
Original file line number Diff line number Diff line change
Expand Up @@ -146,8 +146,15 @@ public int getCurrentOffset(NodeRef node) {

/**
 * Returns the index of the partials buffer that is currently active for {@code node}.
 *
 * <p>The node's current offset is read from {@code getCurrentOffset(node)}. Any offset greater
 * than zero is incremented by one before it is scaled by {@code stride}, so this method never
 * returns an index inside the block that starts at {@code 1 * stride}; that block is the one
 * {@code getAccumulationBuffer} indexes into ({@code stride + node.getNumber()}).
 *
 * @param node the node whose active buffer index is requested
 * @return {@code adjustedOffset * stride + node.getNumber()}
 */
public int getActiveBuffer(NodeRef node) {
    if (DEBUG) test(node);

    int currentOffset = getCurrentOffset(node);
    if (currentOffset > 0) {
        // Skip over the block at 1 * stride so active buffers never share it.
        ++currentOffset;
    }
    return currentOffset * stride + node.getNumber();
}

/**
 * Returns the index of the accumulation buffer paired with {@code node}.
 *
 * @param node the node whose accumulation buffer index is requested
 * @return the node's number shifted by one {@code stride}
 */
public int getAccumulationBuffer(NodeRef node) {
    final int nodeNumber = node.getNumber();
    return stride + nodeNumber;
}

public int getExecutionOrder(NodeRef node) {
Expand Down Expand Up @@ -284,7 +291,7 @@ private void traverseReverseCoalescentLevelOrder() {
}
}

intervalStarts.remove(intervalStarts.size() - 1);
// intervalStarts.remove(intervalStarts.size() - 1);
}

@SuppressWarnings("unused")
Expand Down Expand Up @@ -326,8 +333,9 @@ private void propagateTransmissionProbabilities(int subInterval, NodeRef node, d
new BranchIntervalOperation(
outputBuffer,
inputBuffer1, -1,
inputMatrix1, -1, length,
executionOrder, subInterval));
inputMatrix1, -1,
outputBuffer, -1,
length, executionOrder, subInterval));

activeNodesForInterval.setExecutionOrder(node, executionOrder);
}
Expand All @@ -339,6 +347,9 @@ private void coalescenceTransmissionProbabilities(int subInterval, NodeRef nodeA
final int inputBuffer1 = activeNodesForInterval.getActiveBuffer(leftChild);
final int inputBuffer2 = activeNodesForInterval.getActiveBuffer(rightChild);

final int extraBuffer1 = activeNodesForInterval.getAccumulationBuffer(leftChild);
final int extraBuffer2 = activeNodesForInterval.getAccumulationBuffer(rightChild);

final int outputBuffer = activeNodesForInterval.getActiveBuffer(nodeAtTopOfInterval);
final int executionOrder = Math.max(
activeNodesForInterval.getExecutionOrder(leftChild),
Expand All @@ -352,6 +363,7 @@ private void coalescenceTransmissionProbabilities(int subInterval, NodeRef nodeA
outputBuffer,
inputBuffer1, inputBuffer2,
inputMatrix1, inputMatrix2,
extraBuffer1, extraBuffer2,
length, executionOrder, subInterval));

activeNodesForInterval.setExecutionOrder(nodeAtTopOfInterval, executionOrder);
Expand Down
141 changes: 127 additions & 14 deletions src/dr/evomodel/coalescent/basta/GenericBastaLikelihoodDelegate.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,11 @@ public class GenericBastaLikelihoodDelegate extends BastaLikelihoodDelegate.Abst
private final double[] sizes;
private final double[] coalescent;

private final double[] e;
private final double[] f;
private final double[] g;
private final double[] h;

private final double[] temp;

private final EigenDecomposition[] decompositions; // TODO flatten?
Expand All @@ -33,19 +38,26 @@ public GenericBastaLikelihoodDelegate(String name,
this.sizes = new double[2 * stateCount];
this.decompositions = new EigenDecomposition[1];

this.e = new double[maxNumCoalescentIntervals * stateCount];
this.f = new double[maxNumCoalescentIntervals * stateCount];
this.g = new double[maxNumCoalescentIntervals * stateCount];
this.h = new double[maxNumCoalescentIntervals * stateCount];

this.temp = new double[stateCount * stateCount];
}

@Override
protected double computeBranchIntervalOperations(List<BranchIntervalOperation> branchIntervalOperations) {

for (BranchIntervalOperation operation : branchIntervalOperations) { // TODO execute parallel by subIntervalNumber or executionOrder
for (BranchIntervalOperation operation : branchIntervalOperations) { // TODO execute parallel by intervalNumber or executionOrder
peelPartials(
partials, operation.outputBuffer,
operation.inputBuffer1, operation.inputBuffer2,
matrices,
operation.inputMatrix1, operation.inputMatrix2,
coalescent, operation.subIntervalNumber,
operation.accBuffer1, operation.accBuffer2,
coalescent, operation.intervalNumber,
sizes, 0,
stateCount);

if (PRINT_COMMANDS) {
Expand Down Expand Up @@ -80,11 +92,32 @@ protected double computeTransitionProbabilityOperations(List<TransitionMatrixOpe
@Override
protected double computeCoalescentIntervalReduction(List<Integer> intervalStarts,
List<BranchIntervalOperation> branchIntervalOperations) {
for (int start : intervalStarts) { // TODO execute in parallel
System.err.println(start);

for (int interval = 0; interval < intervalStarts.size() - 1; ++interval) { // TODO execute in parallel (no race conditions)
int start = intervalStarts.get(interval);
int end = intervalStarts.get(interval + 1);

for (int i = start; i < end; ++i) { // TODO execute in parallel (has race conditions)
BranchIntervalOperation operation = branchIntervalOperations.get(i);
reduceWithinInterval(e, f, g, h, partials,
operation.inputBuffer1, operation.inputBuffer2,
operation.accBuffer1, operation.accBuffer2,
operation.intervalNumber,
stateCount);

}
}

return 0.0;
double logL = 0.0;
for (int i = 0; i < intervalStarts.size() - 1; ++i) { // TODO execute in parallel
BranchIntervalOperation operation = branchIntervalOperations.get(intervalStarts.get(i));

logL += reduceAcrossIntervals(e, f, g, h,
operation.intervalNumber, operation.intervalLength,
sizes, coalescent, stateCount);
}

return logL;
}

@Override
Expand Down Expand Up @@ -112,8 +145,9 @@ private static void peelPartials(double[] partials,
int leftPartialOffset, int rightPartialOffset,
double[] matrices,
int leftMatrixOffset, int rightMatrixOffset,
double[] probability,
int probabilityOffset,
int leftAccOffset, int rightAccOffset,
double[] probability, int probabilityOffset,
double[] sizes, int sizesOffset,
int stateCount) {

resultOffset *= stateCount;
Expand All @@ -125,27 +159,46 @@ private static void peelPartials(double[] partials,
for (int i = 0; i < stateCount; ++i) {
double sum = 0.0;
for (int j = 0; j < stateCount; ++j) {
sum += matrices[leftMatrixOffset + i * stateCount + j] *
partials[leftPartialOffset + j];
sum += matrices[leftMatrixOffset + i * stateCount + j] * partials[leftPartialOffset + j];
}
partials[resultOffset + i] = sum;
}

if (rightPartialOffset >= 0) {
// Handle right
rightPartialOffset *= stateCount;
rightMatrixOffset *= stateCount * stateCount;

leftAccOffset *= stateCount;
rightAccOffset *= stateCount;

sizesOffset *= sizesOffset * stateCount;

double prob = 0.0;
for (int i = 0; i < stateCount; ++i) {
double sum = 0.0;
double right = 0.0;
for (int j = 0; j < stateCount; ++j) {
sum += matrices[rightMatrixOffset + i * stateCount + j] *
partials[rightPartialOffset + j];
right += matrices[rightMatrixOffset + i * stateCount + j] * partials[rightPartialOffset + j];
}
partials[resultOffset + i] *= sum;
// entry = left * right / size
double left = partials[resultOffset + i];
double entry = left * right / sizes[sizesOffset + i];

partials[resultOffset + i] = entry;
partials[leftAccOffset + i] = left;
partials[rightAccOffset + i] = right;

prob += entry;
}

for (int i = 0; i < stateCount; ++i) {
partials[resultOffset + i] /= prob;
}

probability[probabilityOffset] = 1.0; // TODO
probability[probabilityOffset] = prob;
}

// TODO rescale?
}

private static void computeTransitionProbabilities(double distance,
Expand Down Expand Up @@ -204,4 +257,64 @@ private static void computeTransitionProbabilities(double distance,
}
}
}

/**
 * Turns the accumulated sums of one interval into its log-likelihood term.
 *
 * <p>For every state {@code k} the quantity
 * {@code (e_k * e_k - f_k + g_k * g_k - h_k) / sizes[k]} is summed, where the accumulator entries
 * are read at {@code interval * stateCount + k}. The total is scaled by {@code -length / 4}, and
 * {@code log(coalescent[interval])} is added whenever that entry is not zero.
 *
 * @param e          accumulator read at the interval's offset
 * @param f          accumulator read at the interval's offset
 * @param g          accumulator read at the interval's offset
 * @param h          accumulator read at the interval's offset
 * @param interval   interval index, used for both the accumulator offset and {@code coalescent}
 * @param length     interval length used to scale the summed term
 * @param sizes      per-state divisors, indexed by state without any offset
 * @param coalescent per-interval probabilities; a zero entry contributes no log term
 * @param stateCount number of states per interval
 * @return the interval's contribution to the log-likelihood
 */
private static double reduceAcrossIntervals(double[] e, double[] f, double[] g, double[] h,
                                            int interval, double length,
                                            double[] sizes, double[] coalescent,
                                            int stateCount) {

    double total = 0.0;
    int index = interval * stateCount;
    for (int state = 0; state < stateCount; ++state, ++index) {
        final double eValue = e[index];
        final double gValue = g[index];
        total += (eValue * eValue - f[index] + gValue * gValue - h[index]) / sizes[state];
    }

    final double scaled = -length * total / 4;
    final double intervalProbability = coalescent[interval];

    // A zero probability is skipped rather than contributing log(0).
    return (intervalProbability != 0.0) ? scaled + Math.log(intervalProbability) : scaled;
}

/**
 * Adds the partials of one branch operation to the running sums of its interval.
 *
 * <p>For the first branch, and for the second branch when {@code startBuffer2 >= 0}, every state
 * {@code i} contributes its start partial to {@code e} and the square of it to {@code f}, and its
 * end partial to {@code g} and the square of it to {@code h}. All four accumulators are written
 * at {@code interval * stateCount + i}; buffer indices are scaled by {@code stateCount} to obtain
 * offsets into {@code partials}.
 *
 * @param e            receives the sum of start partials
 * @param f            receives the sum of squared start partials
 * @param g            receives the sum of end partials
 * @param h            receives the sum of squared end partials
 * @param partials     flat array holding every partials buffer
 * @param startBuffer1 buffer index of the first branch's start partials
 * @param startBuffer2 buffer index of the second branch's start partials, or a negative value
 *                     when there is no second branch
 * @param endBuffer1   buffer index of the first branch's end partials
 * @param endBuffer2   buffer index of the second branch's end partials; only read when
 *                     {@code startBuffer2 >= 0}
 * @param interval     interval index selecting the accumulator slots
 * @param stateCount   number of states per buffer and per interval
 */
private static void reduceWithinInterval(double[] e, double[] f, double[] g, double[] h,
                                         double[] partials,
                                         int startBuffer1, int startBuffer2,
                                         int endBuffer1, int endBuffer2,
                                         int interval, int stateCount) {

    final int intervalOffset = interval * stateCount;

    accumulateBranchSums(e, f, g, h, partials,
            startBuffer1 * stateCount, endBuffer1 * stateCount,
            intervalOffset, stateCount);

    if (startBuffer2 >= 0) {
        accumulateBranchSums(e, f, g, h, partials,
                startBuffer2 * stateCount, endBuffer2 * stateCount,
                intervalOffset, stateCount);
    }
}

/**
 * Adds one branch's start and end partials, and their squares, to the accumulators.
 * All offsets are already scaled by {@code stateCount}.
 */
private static void accumulateBranchSums(double[] e, double[] f, double[] g, double[] h,
                                         double[] partials,
                                         int startOffset, int endOffset,
                                         int intervalOffset, int stateCount) {
    for (int i = 0; i < stateCount; ++i) {
        final double startP = partials[startOffset + i];
        e[intervalOffset + i] += startP;
        f[intervalOffset + i] += startP * startP;

        final double endP = partials[endOffset + i];
        g[intervalOffset + i] += endP;
        h[intervalOffset + i] += endP * endP;
    }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,6 @@

package dr.evomodel.coalescent.basta;

import dr.evomodel.treedatalikelihood.TreeTraversal;

import java.util.List;

/**
* ProcessOnCoalescentIntervalDelegate - interface for a plugin delegate for the likelihood based on coalescent intervals
*
Expand All @@ -44,33 +40,40 @@ final class BranchIntervalOperation {
int inputBuffer2,
int inputMatrix1,
int inputMatrix2,
int accBuffer1,
int accBuffer2,
double intervalLength,
int executionOrder,
int subIntervalNumber) {
int intervalNumber) {
this.outputBuffer = outputBuffer;
this.inputBuffer1 = inputBuffer1;
this.inputBuffer2 = inputBuffer2;
this.inputMatrix1 = inputMatrix1;
this.inputMatrix2 = inputMatrix2;
this.accBuffer1 = accBuffer1;
this.accBuffer2 = accBuffer2;
this.intervalLength = intervalLength;
this.executionOrder = executionOrder;
this.subIntervalNumber = subIntervalNumber;
this.intervalNumber = intervalNumber;
}

public String toString() {
return subIntervalNumber + ":" + outputBuffer + " <- " +
return intervalNumber + ":" + outputBuffer + " <- " +
inputBuffer1 + " (" + inputMatrix1 + ") + " +
inputBuffer2 + " (" + inputMatrix2 + ") (" + intervalLength + ") @ " + executionOrder;
inputBuffer2 + " (" + inputMatrix2 + ") (" + intervalLength + ") ["+
accBuffer1 + " + " + accBuffer2 + "] @ " + executionOrder;
}

public final int outputBuffer;
public final int inputBuffer1;
public final int inputBuffer2;
public final int inputMatrix1;
public final int inputMatrix2;
public final int accBuffer1;
public final int accBuffer2;
public final double intervalLength;
public final int executionOrder;
public final int subIntervalNumber;
public final int intervalNumber;
}

final class OtherOperation {
Expand Down

0 comments on commit 7019ee5

Please sign in to comment.