Skip to content

Commit

Permalink
first working code of affine-correction using cross-products
Browse files Browse the repository at this point in the history
  • Loading branch information
msuchard committed Oct 6, 2023
1 parent 821f98b commit 142c699
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 54 deletions.
2 changes: 1 addition & 1 deletion src/dr/evomodel/branchmodel/BranchModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ public interface BranchModel extends Model {
* Gets the substitution model that will be applied at the root.
* @return the substitution model
*/
SubstitutionModel getRootSubstitutionModel();
SubstitutionModel getRootSubstitutionModel(); // TODO should deprecate infavor of getRootFrequenceModel

/**
* Gets the frequency model that will be applied at the root.
Expand Down
2 changes: 1 addition & 1 deletion src/dr/evomodel/substmodel/DifferentialMassProvider.java
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ public String getReport() {
return "Exact";
}
},
APPROXIMATE("approximate") {
FIRST_ORDER("approximate") {
@Override
public double[] dispatch(double time,
DifferentiableSubstitutionModel model,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ protected double preProcessNormalization(double[] differentials, double[] genera
return calculateCovariateDifferential(generator, differentials, covariate, pi, normalize);
}

private double calculateCovariateDifferential(double[] generator, double[] differential,
private double calculateCovariateDifferential(double[] generator, double[] crossProduct,
double[] covariate, double[] pi,
boolean doNormalization) {

Expand All @@ -130,10 +130,15 @@ private double calculateCovariateDifferential(double[] generator, double[] diffe
double xij = covariate[k++];
double element = xij * generator[index(i,j)];

total += differential[index(i,j)] * element;
total -= differential[index(i,i)] * element;
if (element != 0.0) {
total += crossProduct[index(i, j)] * element;
total -= crossProduct[index(i, i)] * element;

normalization += element * pi[i];
total += correction(i, j, crossProduct) * element;
total -= correction(i, i, crossProduct) * element;

normalization += element * pi[i];
}
}
}

Expand All @@ -143,17 +148,22 @@ private double calculateCovariateDifferential(double[] generator, double[] diffe
double xij = covariate[k++];
double element = xij * generator[index(i,j)];

total += differential[index(i,j)] * element;
total -= differential[index(i,i)] * element;
if (element != 0.0) {
total += crossProduct[index(i, j)] * element;
total -= crossProduct[index(i, i)] * element;

normalization += element * pi[i];
total += correction(i, j, crossProduct) * element;
total -= correction(i, j, crossProduct) * element;

normalization += element * pi[i];
}
}
}

if (doNormalization) {
for (int i = 0; i < stateCount; ++i) {
for (int j = 0; j < stateCount; ++j) {
total -= differential[index(i,j)] * generator[index(i,j)] * normalization;
total -= crossProduct[index(i,j)] * generator[index(i,j)] * normalization;
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ public double[] getGradientLogDensity() {
// }

substitutionModel.getInfinitesimalMatrix(generator);
crossProducts = correctDifferentials(crossProducts);
// crossProducts = correctDifferentials(crossProducts);

if (DEBUG_CROSS_PRODUCTS) {
savedDifferentials = crossProducts.clone();
Expand All @@ -186,57 +186,45 @@ public double[] getGradientLogDensity() {
return gradient;
}

double[] correctDifferentials(double[] differentials) {
if (mode == ApproximationMode.AFFINE_CORRECTED) {
double[] correction = new double[differentials.length];
// System.arraycopy(differentials, 0, correction, 0, differentials.length);
double correction(int i, int j, double[] crossProducts) {

if (crossProductAccumulationMap.size() > 1) {
throw new RuntimeException("Not yet implemented");
}
if (mode == ApproximationMode.FIRST_ORDER) {
return 0.0;
}

EigenDecomposition ed = substitutionModel.getEigenDecomposition();
int index = findZeroEigenvalueIndex(ed.getEigenValues());

double[] eigenVectors = ed.getEigenVectors();
double[] inverseEigenVectors = ed.getInverseEigenVectors();

double[] qQPlus = getQQPlus(eigenVectors, inverseEigenVectors, index);
double[] qPlusQ = getQPlusQ(eigenVectors, inverseEigenVectors, index);

double[] generator = new double[stateCount * stateCount];
substitutionModel.getInfinitesimalMatrix(generator);

for (int m = 0; m < stateCount; ++m) {
for (int n = 0; n < stateCount; n++) {
double entryMN = 0.0;
for (int i = 0; i < stateCount; ++i) {
for (int j = 0; j < stateCount; ++j) {
if (i == j) {
entryMN += differentials[index12(i,j)] *
(1.0 - qQPlus[index12(i,m)]) * qQPlus[index12(n,j)];
} else {
entryMN += differentials[index12(i,j)] *
- qQPlus[index12(i,m)] * qQPlus[index12(n,j)];
}
// entryMN += differentials[i * stateCount + j] *
// qQPlus[i * stateCount + m] * qQPlus[n * stateCount + j];
}
}
correction[index12(m,n)] = entryMN;
}
}
double[] affineMatrix = new double[stateCount * stateCount];

if (crossProductAccumulationMap.size() > 1) {
throw new RuntimeException("Not yet implemented");
}

// TODO Start cache for each (i,j) .. only depends on substitutionModel
EigenDecomposition ed = substitutionModel.getEigenDecomposition();
int index = findZeroEigenvalueIndex(ed.getEigenValues());

double[] eigenVectors = ed.getEigenVectors();
double[] inverseEigenVectors = ed.getInverseEigenVectors();

System.err.println("diff: " + new WrappedVector.Raw(differentials));
System.err.println("corr: " + new WrappedVector.Raw(correction));
double[] qQPlus = getQQPlus(eigenVectors, inverseEigenVectors, index);

for (int i = 0; i < differentials.length; ++i) {
differentials[i] -= correction[i];
for (int m = 0; m < stateCount; ++m) {
for (int n = 0; n < stateCount; n++) {
// TODO there are only stateCount unique values
affineMatrix[index12(m,n)] = (m == i) ?
(qQPlus[index12(m,i)] - 1.0) * qQPlus[index12(j,n)] :
qQPlus[index12(m,i)] * qQPlus[index12(j,n)];
}
}
// TODO End cache

double correction = 0.0;
for (int m = 0; m < stateCount; ++m) {
for (int n = 0; n < stateCount; ++n) {
correction += crossProducts[index12(m,n)] * affineMatrix[index12(m,n)];
}
}

return differentials;
return correction;
}

private int findZeroEigenvalueIndex(double[] eigenvalues) {
Expand Down

0 comments on commit 142c699

Please sign in to comment.