Skip to content

Commit

Permalink
Copy, don't accumulate
Browse files Browse the repository at this point in the history
  • Loading branch information
afmagee committed Sep 22, 2023
1 parent 6f5fea6 commit d1e8224
Showing 1 changed file with 20 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -118,12 +118,8 @@ public double[] getGradientLogDensity() {
double[] crossProducts = (double[]) treeTraitProvider.getTrait(tree, null);
double[] generator = new double[crossProducts.length];

if (substitutionModelCount > 1) {
// final int length = stateCount * stateCount;
// System.arraycopy(
// crossProducts, whichSubstitutionModel * length,
// crossProducts, 0, length);
crossProducts = accumulateAcrossSubstitutionModelInstances(crossProducts);
if (whichSubstitutionModel > 1 || substitutionModelCount > 1) {
accumulateAcrossSubstitutionModelInstances(crossProducts);
}

substitutionModel.getInfinitesimalMatrix(generator);
Expand Down Expand Up @@ -244,35 +240,41 @@ private int determineSubstitutionModelCount(BranchModel branchModel) {
return substitutionModels.size();
}

// Should maybe be void and just update crossProducts?
private double[] accumulateAcrossSubstitutionModelInstances(double[] crossProducts) {
double[] accumulated = new double[crossProducts.length];
private void accumulateAcrossSubstitutionModelInstances(double[] crossProducts) {
final int length = stateCount * stateCount;

// TODO first set of entries should be a copy (instead of accumulate)
// copy first set of entries instead of accumulating
System.arraycopy(
crossProducts, whichSubstitutionModel * length,
crossProducts, 0, length);

for (int i : crossProductAccumulationMap) {
for (int j = 0; j < length; j++) {
accumulated[j] += crossProducts[i * length + j];
if ( crossProductAccumulationMap.length > 0 ) {
for (int i : crossProductAccumulationMap) {
for (int j = 0; j < length; j++) {
crossProducts[j] += crossProducts[i * length + j];
}
}
}

return accumulated;
}

private void updateCrossProductAccumulationMap() {
System.err.println("Updating crossProductAccumulationMap");
// System.err.println("Updating crossProductAccumulationMap");
List<Integer> matchingModels = new ArrayList<>();
List<SubstitutionModel> substitutionModels = branchModel.getSubstitutionModels();

// We copy whichSubstitutionModel instead of accumulating it
for (int i = 0; i < substitutionModels.size(); ++i) {
if (substitutionModel == substitutionModels.get(i)) {
if (i != whichSubstitutionModel && substitutionModel == substitutionModels.get(i)) {
matchingModels.add(i);
}
}

crossProductAccumulationMap = new int[matchingModels.size()];
for (int i = 0; i < matchingModels.size(); ++i) {
crossProductAccumulationMap[i] = matchingModels.get(i);
if (matchingModels.size() > 0) {
for (int i = 0; i < matchingModels.size(); ++i) {
crossProductAccumulationMap[i] = matchingModels.get(i);
}
}
}

Expand Down

0 comments on commit d1e8224

Please sign in to comment.