Skip to content

Commit

Permalink
initial infrastructure for BEAGLE-based BASTA
Browse files Browse the repository at this point in the history
  • Loading branch information
msuchard committed Dec 11, 2023
1 parent ae71587 commit 2cb3c1b
Show file tree
Hide file tree
Showing 11 changed files with 332 additions and 15 deletions.
Binary file modified lib/beagle.jar
Binary file not shown.
77 changes: 77 additions & 0 deletions src/beagle/basta/BastaFactory.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
package beagle.basta;

import beagle.*;

import java.util.logging.Logger;

public class BastaFactory extends BeagleFactory {

public static BeagleBasta loadBastaInstance(
int tipCount,
int partialsBufferCount,
int compactBufferCount,
int stateCount,
int patternCount,
int eigenBufferCount,
int matrixBufferCount,
int categoryCount,
int scaleBufferCount,
int[] resourceList,
long preferenceFlags,
long requirementFlags) {

getBeagleJNIWrapper();
if (BeagleJNIWrapper.INSTANCE != null) {

getBastaJNIWrapper();
if (BastaJNIWrapper.INSTANCE != null) {

try {
BeagleBasta beagle = new BastaJNIImpl(
tipCount,
partialsBufferCount,
compactBufferCount,
stateCount,
patternCount,
eigenBufferCount,
matrixBufferCount,
categoryCount,
scaleBufferCount,
resourceList,
preferenceFlags,
requirementFlags
);

// In order to know that it was a CPU instance created, we have to let BEAGLE
// to make the instance and then override it...

InstanceDetails details = beagle.getDetails();

if (details != null) // If resourceList/requirements not met, details == null here
return beagle;

} catch (BeagleException beagleException) {
Logger.getLogger("beagle").info(" " + beagleException.getMessage());
}
} else {
throw new RuntimeException("No acceptable BEAGLE-BASTA library plugin found. " +
"Make sure that BEAGLE-BASTA is properly installed or try changing resource requirements.");
}
}

throw new RuntimeException("No acceptable BEAGLE library plugins found. " +
"Make sure that BEAGLE is properly installed or try changing resource requirements.");
}

private static BastaJNIWrapper getBastaJNIWrapper() {
if (BastaJNIWrapper.INSTANCE == null) {
try {
BastaJNIWrapper.loadBastaLibrary();
} catch (UnsatisfiedLinkError ule) {
System.err.println("Failed to load BEAGLE-BASTA library: " + ule.getMessage());
}
}

return BastaJNIWrapper.INSTANCE;
}
}
42 changes: 42 additions & 0 deletions src/beagle/basta/BastaJNIImpl.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
package beagle.basta;

import beagle.BeagleException;
import beagle.BeagleJNIImpl;

public class BastaJNIImpl extends BeagleJNIImpl implements BeagleBasta {

public BastaJNIImpl(int tipCount,
int partialsBufferCount,
int compactBufferCount,
int stateCount,
int patternCount,
int eigenBufferCount,
int matrixBufferCount,
int categoryCount,
int scaleBufferCount,
final int[] resourceList,
long preferenceFlags,
long requirementFlags) {
super(tipCount, partialsBufferCount, compactBufferCount, stateCount, patternCount, eigenBufferCount,
matrixBufferCount, categoryCount, scaleBufferCount, resourceList, preferenceFlags, requirementFlags);

}

@Override
public void updateBastaPartials(int[] operations, int operationCount, int populationSizeIndex) {
int errCode = BastaJNIWrapper.INSTANCE.updateBastaPartials(instance, operations, operationCount,
populationSizeIndex);
if (errCode != 0) {
throw new BeagleException("updateBastaPartials", errCode);
}
}

@Override
public void accumulateBastaPartials(int[] operations, int operationCount, int[] segments, int segmentCount) {
int errCode = BastaJNIWrapper.INSTANCE.accumulateBastaPartials(instance,operations, operationCount,
segments, segmentCount);
if (errCode != 0) {
throw new BeagleException("accumulateBastaPartials", errCode);
}
}
}
44 changes: 44 additions & 0 deletions src/beagle/basta/BastaJNIWrapper.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
package beagle.basta;

public class BastaJNIWrapper {

private static final String LIBRARY_NAME = getPlatformSpecificLibraryName();

private BastaJNIWrapper() { }

public native int updateBastaPartials(int instance,
final int[] operations,
int operationCount,
int populationSizeIndex);

public native int accumulateBastaPartials(int instance,
final int[] operations,
int operationCount,
final int[] segments,
int segmentCount);

private static String getPlatformSpecificLibraryName() {
String osName = System.getProperty("os.name").toLowerCase();
String osArch = System.getProperty("os.arch").toLowerCase();
if (osName.startsWith("windows")) {
if (osArch.equals("x86") || osArch.equals("i386")) return "hmsbeagle-basta32";
if (osArch.startsWith("amd64") || osArch.startsWith("x86_64")) return "hmsbeagle-basta64";
}
return "hmsbeagle-jni-basta";
}

public static void loadBastaLibrary() throws UnsatisfiedLinkError {
String path = "";
if (System.getProperty("beagle.library.path") != null) {
path = System.getProperty("beagle.library.path");
if (path.length() > 0 && !path.endsWith("/")) {
path += "/";
}
}

System.loadLibrary(path + LIBRARY_NAME);
INSTANCE = new BastaJNIWrapper();
}

public static BastaJNIWrapper INSTANCE;
}
15 changes: 15 additions & 0 deletions src/beagle/basta/BeagleBasta.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package beagle.basta;

import beagle.Beagle;

public interface BeagleBasta extends Beagle {

void updateBastaPartials(final int[] operations,
int operationCount,
int populationSizeIndex);

void accumulateBastaPartials(final int[] operations,
int operationCount,
final int[] segments,
int segmentCount);
}
4 changes: 2 additions & 2 deletions src/dr/evomodel/coalescent/basta/BastaLikelihood.java
Original file line number Diff line number Diff line change
Expand Up @@ -275,10 +275,10 @@ protected void acceptState() { } // nothing to do
private double calculateLogLikelihood() {

// update eigen-decomposition
likelihoodDelegate.setEigenDecomposition(0, substitutionModel.getEigenDecomposition()); // TODO do conditionally and double-buffer
likelihoodDelegate.updateEigenDecomposition(0, substitutionModel.getEigenDecomposition(), false); // TODO do conditionally and double-buffer

// update population sizes
likelihoodDelegate.setPopulationSizes(0, popSizeParameter.getParameterValues()); // TODO do conditionally and double-buffer
likelihoodDelegate.updatePopulationSizes(0, popSizeParameter.getParameterValues(), false); // TODO do conditionally and double-buffer

// update operations on tree
treeTraversalDelegate.dispatchTreeTraversalCollectBranchAndNodeOperations();
Expand Down
4 changes: 2 additions & 2 deletions src/dr/evomodel/coalescent/basta/BastaLikelihoodDelegate.java
Original file line number Diff line number Diff line change
Expand Up @@ -69,11 +69,11 @@ default void getPartials(int index, double[] partials) {
throw new RuntimeException("Not yet implemented");
}

default void setEigenDecomposition(int index, EigenDecomposition decomposition) {
default void updateEigenDecomposition(int index, EigenDecomposition decomposition, boolean flip) {
throw new RuntimeException("Not yet implemented");
}

default void setPopulationSizes(int index, double[] sizes) {
default void updatePopulationSizes(int index, double[] sizes, boolean flip) {
throw new RuntimeException("Not yet implemented");
}

Expand Down
100 changes: 100 additions & 0 deletions src/dr/evomodel/coalescent/basta/BeagleBastaLikelihoodDelegate.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
package dr.evomodel.coalescent.basta;

import beagle.Beagle;
import beagle.basta.BeagleBasta;
import beagle.basta.BastaFactory;
import dr.evolution.tree.Tree;
import dr.evomodel.substmodel.EigenDecomposition;
import dr.evomodel.treedatalikelihood.BufferIndexHelper;

import java.util.List;

/**
* @author Marc A. Suchard
*/
public class BeagleBastaLikelihoodDelegate extends BastaLikelihoodDelegate.AbstractBastaLikelihoodDelegate {

private final BeagleBasta beagle;

private final BufferIndexHelper eigenBufferHelper;
private final OffsetBufferIndexHelper populationSizesBufferHelper;

public BeagleBastaLikelihoodDelegate(String name,
Tree tree,
int stateCount) {
super(name, tree, stateCount);

beagle = BastaFactory.loadBastaInstance(1, 1, 1, 16,
1, 1, 1, 1,
1, null, 0L, 0L);

eigenBufferHelper = new BufferIndexHelper(1, 0);
populationSizesBufferHelper = new OffsetBufferIndexHelper(1, 0, 0);

double[] tmp = new double[16];
beagle.setPartials(0, tmp);
}

@Override
protected void computeBranchIntervalOperations(List<Integer> intervalStarts,
List<BranchIntervalOperation> branchIntervalOperations) {

}

@Override
protected void computeTransitionProbabilityOperations(List<TransitionMatrixOperation> matrixOperations) {

}

@Override
protected double computeCoalescentIntervalReduction(List<Integer> intervalStarts,
List<BranchIntervalOperation> branchIntervalOperations) {
return 0;
}

@Override
public void setPartials(int index, double[] partials) {
beagle.setPartials(index, partials);
}

@Override
public void getPartials(int index, double[] partials) {
assert index >= 0;
assert partials != null;

beagle.getPartials(index, Beagle.NONE, partials);
}

@Override
public void updateEigenDecomposition(int index, EigenDecomposition decomposition, boolean flip) {
if (flip) {
eigenBufferHelper.flipOffset(0);
}

beagle.setEigenDecomposition(
eigenBufferHelper.getOffsetIndex(0),
decomposition.getEigenVectors(),
decomposition.getInverseEigenVectors(),
decomposition.getEigenValues());
}

@Override
public void updatePopulationSizes(int index, double[] sizes, boolean flip) {
if (flip) {
populationSizesBufferHelper.flipOffset(0);
}

beagle.setPartials(populationSizesBufferHelper.getOffsetIndex(0),
sizes);
}

class OffsetBufferIndexHelper extends BufferIndexHelper {

public OffsetBufferIndexHelper(int maxIndexValue, int minIndexValue, int bufferSetNumber) {
super(maxIndexValue, minIndexValue, bufferSetNumber);
}

@Override
protected int computeOffset(int offset) { return offset; }
}
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
package dr.evomodel.coalescent.basta;

import beagle.Beagle;
import beagle.basta.BeagleBasta;
import beagle.basta.BastaFactory;
import dr.evolution.tree.Tree;
import dr.evomodel.substmodel.EigenDecomposition;
import dr.math.matrixAlgebra.WrappedVector;
Expand Down Expand Up @@ -31,6 +34,34 @@ public GenericBastaLikelihoodDelegate(String name,
int stateCount) {
super(name, tree, stateCount);

BeagleBasta basta = BastaFactory.loadBastaInstance(1, 1, 1, 16, 1, 1, 1, 1,
1, null, 0L, 0L);

// Beagle basta = BeagleFactory.loadBeagleInstance(10, 10, 0, 16, 1, 1, 1, 1,
// 1, null, 0L, 0L);


// beagle = BeagleFactory.loadBeagleInstance(
// tipCount,
// numPartials,
// compactPartialsCount,
// stateCount,
// patternCount,
// evolutionaryProcessDelegate.getEigenBufferCount(),
// numMatrices,
// categoryCount,
// numScaleBuffers, // Always allocate; they may become necessary
// resourceList,
// preferenceFlags,
// requirementFlags
// );

int cumulativeBufferIndex = Beagle.NONE;
/* No need to rescale partials */

double[] tmp = new double[16];
basta.setPartials(0, tmp);

this.partials = new double[maxNumCoalescentIntervals * tree.getNodeCount() * stateCount]; // TODO much too large
this.matrices = new double[maxNumCoalescentIntervals * stateCount * stateCount]; // TODO much too small (except for strict-clock
this.coalescent = new double[maxNumCoalescentIntervals];
Expand Down Expand Up @@ -154,12 +185,12 @@ public void setPartials(int index, double[] partials) {
}

@Override
public void setEigenDecomposition(int index, EigenDecomposition decomposition) {
public void updateEigenDecomposition(int index, EigenDecomposition decomposition, boolean flip) {
decompositions[index] = decomposition;
}

@Override
public void setPopulationSizes(int index, double[] sizes) {
public void updatePopulationSizes(int index, double[] sizes, boolean flip) {
assert sizes.length == stateCount;

System.arraycopy(sizes, 0, this.sizes, index * stateCount, stateCount);
Expand Down
Loading

0 comments on commit 2cb3c1b

Please sign in to comment.