Skip to content

Commit

Permalink
make action cache less again
Browse files Browse the repository at this point in the history
  • Loading branch information
xji3 committed Oct 18, 2024
1 parent 42488c2 commit d9adf39
Show file tree
Hide file tree
Showing 3 changed files with 164 additions and 136 deletions.
143 changes: 143 additions & 0 deletions src/dr/evomodel/substmodel/ActionEnabledSubstitution.java
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,14 @@

package dr.evomodel.substmodel;

import dr.evolution.datatype.DataType;
import dr.inference.model.AbstractModel;
import dr.inference.model.Model;
import dr.inference.model.Parameter;
import dr.inference.model.Variable;

import java.util.Arrays;

/**
* @author Marc A. Suchard
* @author Xiang Ji
Expand All @@ -34,4 +42,139 @@ public interface ActionEnabledSubstitution extends SubstitutionModel {

public void getNonZeroEntries(int[] rowIndices, int[] colIndices, double[] values);

public class ActionEnabledSubstitutionWrap extends AbstractModel implements ActionEnabledSubstitution {

private SubstitutionModel substitutionModel;

private boolean substitutionModelKnown;

private int[] rowIndices;

private int[] colIndices;

private double[] values;

private double[] Q;

private int stateCount;


private int numNonZeroEntries;

public ActionEnabledSubstitutionWrap(String name, SubstitutionModel substitutionModel) {
super(name);
this.substitutionModel = substitutionModel;
this.substitutionModelKnown = false;
this.stateCount = substitutionModel.getFrequencyModel().getFrequencyCount();

this.rowIndices = new int[stateCount * stateCount];
this.colIndices = new int[stateCount * stateCount];
this.values = new double[stateCount * stateCount];
this.Q = new double[stateCount * stateCount];
processSubstitutionModel();
addModel(substitutionModel);
}

private void processSubstitutionModel() {
Arrays.fill(rowIndices, 0);
Arrays.fill(colIndices, 0);
Arrays.fill(values, 0);

substitutionModel.getInfinitesimalMatrix(Q);

numNonZeroEntries = 0;
for (int row = 0; row < stateCount; row++) {
for (int col = 0; col < stateCount; col++) {
final double value = Q[row * stateCount + col];
if (value != 0) {
rowIndices[numNonZeroEntries] = row;
colIndices[numNonZeroEntries] = col;
values[numNonZeroEntries] = value;
numNonZeroEntries++;
}
}
}

substitutionModelKnown = true;

}

@Override
public int getNonZeroEntryCount() {
if (!substitutionModelKnown) {
processSubstitutionModel();
}
return numNonZeroEntries;
}

@Override
public void getNonZeroEntries(int[] inRowIndices, int[] inColIndices, double[] inValues) {
if (!substitutionModelKnown) {
processSubstitutionModel();
}
System.arraycopy(rowIndices, 0, inRowIndices, 0, numNonZeroEntries);
System.arraycopy(colIndices, 0, inColIndices, 0, numNonZeroEntries);
System.arraycopy(values, 0, inValues, 0, numNonZeroEntries);
}

@Override
public void getTransitionProbabilities(double distance, double[] matrix) {
throw new RuntimeException("Not yet implemented!");
}

@Override
public EigenDecomposition getEigenDecomposition() {
throw new RuntimeException("Not yet implemented!");
}

@Override
public FrequencyModel getFrequencyModel() {
return substitutionModel.getFrequencyModel();
}

@Override
public void getInfinitesimalMatrix(double[] matrix) {
throw new RuntimeException("Not yet implemented!");
}

@Override
public DataType getDataType() {
throw new RuntimeException("Not yet implemented!");
}

@Override
public boolean canReturnComplexDiagonalization() {
throw new RuntimeException("Not yet implemented!");
}


@Override
protected void handleModelChangedEvent(Model model, Object object, int index) {
substitutionModelKnown = false;
fireModelChanged();
}

@Override
protected void handleVariableChangedEvent(Variable variable, int index, Parameter.ChangeType type) {
substitutionModelKnown = false;
fireModelChanged();
}

@Override
protected void storeState() {

}

@Override
protected void restoreState() {
substitutionModelKnown = false;
}

@Override
protected void acceptState() {

}
}
}


Original file line number Diff line number Diff line change
Expand Up @@ -205,139 +205,7 @@ protected void acceptState() {

}

private class ActionEnabledSubstitutionWrap extends AbstractModel implements ActionEnabledSubstitution {

private SubstitutionModel substitutionModel;

private boolean substitutionModelKnown;

private int[] rowIndices;

private int[] colIndices;

private double[] values;

private double[] Q;

private int stateCount;


private int numNonZeroEntries;

ActionEnabledSubstitutionWrap(String name, SubstitutionModel substitutionModel) {
super(name);
this.substitutionModel = substitutionModel;
this.substitutionModelKnown = false;
this.stateCount = substitutionModel.getFrequencyModel().getFrequencyCount();

this.rowIndices = new int[stateCount * stateCount];
this.colIndices = new int[stateCount * stateCount];
this.values = new double[stateCount * stateCount];
this.Q = new double[stateCount * stateCount];
processSubstitutionModel();
addModel(substitutionModel);
}

private void processSubstitutionModel() {
Arrays.fill(rowIndices, 0);
Arrays.fill(colIndices, 0);
Arrays.fill(values, 0);

substitutionModel.getInfinitesimalMatrix(Q);

numNonZeroEntries = 0;
for (int row = 0; row < stateCount; row++) {
for (int col = 0; col < stateCount; col++) {
final double value = Q[row * stateCount + col];
if (value != 0) {
rowIndices[numNonZeroEntries] = row;
colIndices[numNonZeroEntries] = col;
values[numNonZeroEntries] = value;
numNonZeroEntries++;
}
}
}

substitutionModelKnown = true;

}

@Override
public int getNonZeroEntryCount() {
if (!substitutionModelKnown) {
processSubstitutionModel();
}
return numNonZeroEntries;
}

@Override
public void getNonZeroEntries(int[] inRowIndices, int[] inColIndices, double[] inValues) {
if (!substitutionModelKnown) {
processSubstitutionModel();
}
System.arraycopy(rowIndices, 0, inRowIndices, 0, numNonZeroEntries);
System.arraycopy(colIndices, 0, inColIndices, 0, numNonZeroEntries);
System.arraycopy(values, 0, inValues, 0, numNonZeroEntries);
}

@Override
public void getTransitionProbabilities(double distance, double[] matrix) {
throw new RuntimeException("Not yet implemented!");
}

@Override
public EigenDecomposition getEigenDecomposition() {
throw new RuntimeException("Not yet implemented!");
}

@Override
public FrequencyModel getFrequencyModel() {
return substitutionModel.getFrequencyModel();
}

@Override
public void getInfinitesimalMatrix(double[] matrix) {
throw new RuntimeException("Not yet implemented!");
}

@Override
public DataType getDataType() {
throw new RuntimeException("Not yet implemented!");
}

@Override
public boolean canReturnComplexDiagonalization() {
throw new RuntimeException("Not yet implemented!");
}


@Override
protected void handleModelChangedEvent(Model model, Object object, int index) {
substitutionModelKnown = false;
fireModelChanged();
}

@Override
protected void handleVariableChangedEvent(Variable variable, int index, Parameter.ChangeType type) {
substitutionModelKnown = false;
fireModelChanged();
}

@Override
protected void storeState() {

}

@Override
protected void restoreState() {
substitutionModelKnown = false;
}

@Override
protected void acceptState() {

}
}

public enum RateCase {
SINGLE("single") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,18 @@
import dr.evomodel.treedatalikelihood.EvolutionaryProcessDelegate;
import dr.evomodel.treedatalikelihood.PreOrderSettings;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;

public class ActionSubstitutionModelDelegate implements EvolutionaryProcessDelegate {

private final Tree tree;
private final BranchModel branchModel;
private final int nodeCount;

private final List<ActionEnabledSubstitution> substitutionModels;

private final int stateCount;

private final HashMap<SubstitutionModel, Integer> eigenIndexMap;
Expand All @@ -25,12 +29,25 @@ public ActionSubstitutionModelDelegate (Tree tree,
int nodeCount) {
this.tree = tree;
this.branchModel = branchModel;
this.substitutionModels = getSubstitutionModels(branchModel);
this.nodeCount = nodeCount;
this.stateCount = branchModel.getRootFrequencyModel().getFrequencyCount();
this.eigenIndexMap = new HashMap<>();
for (int i = 0; i < getSubstitutionModelCount(); i++) {
eigenIndexMap.put(branchModel.getSubstitutionModels().get(i), i);
eigenIndexMap.put(substitutionModels.get(i), i);
}
}

private List<ActionEnabledSubstitution> getSubstitutionModels(BranchModel branchModel) {
List<ActionEnabledSubstitution> substitutionModels = new ArrayList<>();
for (SubstitutionModel substitutionModel : branchModel.getSubstitutionModels()) {
if (substitutionModel instanceof ActionEnabledSubstitution) {
substitutionModels.add((ActionEnabledSubstitution) substitutionModel);
} else {
substitutionModels.add(new ActionEnabledSubstitution.ActionEnabledSubstitutionWrap("original.substitution.model", substitutionModel));
}
}
return substitutionModels;
}

@Override
Expand All @@ -40,7 +57,7 @@ public boolean canReturnComplexDiagonalization() {

@Override
public int getEigenBufferCount() {
return branchModel.getSubstitutionModels().size();
return substitutionModels.size();
}

@Override
Expand Down Expand Up @@ -106,12 +123,12 @@ private int getDifferentialMassMatrixBufferCount(PreOrderSettings settings) {
}
@Override
public int getSubstitutionModelCount() {
return branchModel.getSubstitutionModels().size();
return substitutionModels.size();
}

@Override
public SubstitutionModel getSubstitutionModel(int index) {
return branchModel.getSubstitutionModels().get(index);
return substitutionModels.get(index);
}

@Override
Expand Down

0 comments on commit d9adf39

Please sign in to comment.