From 142c6990729074628836e177da43a13cedea635d Mon Sep 17 00:00:00 2001 From: Marc Suchard Date: Thu, 5 Oct 2023 17:18:41 -0700 Subject: [PATCH] first working code of affine-correction using cross-products --- src/dr/evomodel/branchmodel/BranchModel.java | 2 +- .../substmodel/DifferentialMassProvider.java | 2 +- .../AbstractGlmSubstitutionModelGradient.java | 26 +++++-- ...tLogAdditiveSubstitutionModelGradient.java | 76 ++++++++----------- 4 files changed, 52 insertions(+), 54 deletions(-) diff --git a/src/dr/evomodel/branchmodel/BranchModel.java b/src/dr/evomodel/branchmodel/BranchModel.java index 93aade89db..9424845321 100644 --- a/src/dr/evomodel/branchmodel/BranchModel.java +++ b/src/dr/evomodel/branchmodel/BranchModel.java @@ -62,7 +62,7 @@ public interface BranchModel extends Model { * Gets the substitution model that will be applied at the root. * @return the substitution model */ - SubstitutionModel getRootSubstitutionModel(); + SubstitutionModel getRootSubstitutionModel(); // TODO should deprecate infavor of getRootFrequenceModel /** * Gets the frequency model that will be applied at the root. diff --git a/src/dr/evomodel/substmodel/DifferentialMassProvider.java b/src/dr/evomodel/substmodel/DifferentialMassProvider.java index a73532ed66..f3c16b3308 100644 --- a/src/dr/evomodel/substmodel/DifferentialMassProvider.java +++ b/src/dr/evomodel/substmodel/DifferentialMassProvider.java @@ -51,7 +51,7 @@ public String getReport() { return "Exact"; } }, - APPROXIMATE("approximate") { + FIRST_ORDER("approximate") { @Override public double[] dispatch(double time, DifferentiableSubstitutionModel model, diff --git a/src/dr/evomodel/treedatalikelihood/discrete/AbstractGlmSubstitutionModelGradient.java b/src/dr/evomodel/treedatalikelihood/discrete/AbstractGlmSubstitutionModelGradient.java index 9988dbfe54..a5ee171ccc 100644 --- a/src/dr/evomodel/treedatalikelihood/discrete/AbstractGlmSubstitutionModelGradient.java +++ b/src/dr/evomodel/treedatalikelihood/discrete/AbstractGlmSubstitutionModelGradient.java @@ -116,7 +116,7 @@ protected double preProcessNormalization(double[] differentials, double[] genera return calculateCovariateDifferential(generator, differentials, covariate, pi, normalize); } - private double calculateCovariateDifferential(double[] generator, double[] differential, + private double calculateCovariateDifferential(double[] generator, double[] crossProduct, double[] covariate, double[] pi, boolean doNormalization) { @@ -130,10 +130,15 @@ private double calculateCovariateDifferential(double[] generator, double[] diffe double xij = covariate[k++]; double element = xij * generator[index(i,j)]; - total += differential[index(i,j)] * element; - total -= differential[index(i,i)] * element; + if (element != 0.0) { + total += crossProduct[index(i, j)] * element; + total -= crossProduct[index(i, i)] * element; - normalization += element * pi[i]; + total += correction(i, j, crossProduct) * element; + total -= correction(i, i, crossProduct) * element; + + normalization += element * pi[i]; + } } } @@ -143,17 +148,22 @@ private double calculateCovariateDifferential(double[] generator, double[] diffe double xij = covariate[k++]; double element = xij * generator[index(i,j)]; - total += differential[index(i,j)] * element; - total -= differential[index(i,i)] * element; + if (element != 0.0) { + total += crossProduct[index(i, j)] * element; + total -= crossProduct[index(i, i)] * element; - normalization += element * pi[i]; + total += correction(i, j, crossProduct) * element; + total -= correction(i, j, crossProduct) * element; + + normalization += element * pi[i]; + } } } if (doNormalization) { for (int i = 0; i < stateCount; ++i) { for (int j = 0; j < stateCount; ++j) { - total -= differential[index(i,j)] * generator[index(i,j)] * normalization; + total -= crossProduct[index(i,j)] * generator[index(i,j)] * normalization; } } } diff --git a/src/dr/evomodel/treedatalikelihood/discrete/AbstractLogAdditiveSubstitutionModelGradient.java b/src/dr/evomodel/treedatalikelihood/discrete/AbstractLogAdditiveSubstitutionModelGradient.java index 7a37548980..ac5c5c21fa 100644 --- a/src/dr/evomodel/treedatalikelihood/discrete/AbstractLogAdditiveSubstitutionModelGradient.java +++ b/src/dr/evomodel/treedatalikelihood/discrete/AbstractLogAdditiveSubstitutionModelGradient.java @@ -159,7 +159,7 @@ public double[] getGradientLogDensity() { // } substitutionModel.getInfinitesimalMatrix(generator); - crossProducts = correctDifferentials(crossProducts); +// crossProducts = correctDifferentials(crossProducts); if (DEBUG_CROSS_PRODUCTS) { savedDifferentials = crossProducts.clone(); @@ -186,57 +186,45 @@ public double[] getGradientLogDensity() { return gradient; } - double[] correctDifferentials(double[] differentials) { - if (mode == ApproximationMode.AFFINE_CORRECTED) { - double[] correction = new double[differentials.length]; -// System.arraycopy(differentials, 0, correction, 0, differentials.length); + double correction(int i, int j, double[] crossProducts) { - if (crossProductAccumulationMap.size() > 1) { - throw new RuntimeException("Not yet implemented"); - } + if (mode == ApproximationMode.FIRST_ORDER) { + return 0.0; + } - EigenDecomposition ed = substitutionModel.getEigenDecomposition(); - int index = findZeroEigenvalueIndex(ed.getEigenValues()); - - double[] eigenVectors = ed.getEigenVectors(); - double[] inverseEigenVectors = ed.getInverseEigenVectors(); - - double[] qQPlus = getQQPlus(eigenVectors, inverseEigenVectors, index); - double[] qPlusQ = getQPlusQ(eigenVectors, inverseEigenVectors, index); - - double[] generator = new double[stateCount * stateCount]; - substitutionModel.getInfinitesimalMatrix(generator); - - for (int m = 0; m < stateCount; ++m) { - for (int n = 0; n < stateCount; n++) { - double entryMN = 0.0; - for (int i = 0; i < stateCount; ++i) { - for (int j = 0; j < stateCount; ++j) { - if (i == j) { - entryMN += differentials[index12(i,j)] * - (1.0 - qQPlus[index12(i,m)]) * qQPlus[index12(n,j)]; - } else { - entryMN += differentials[index12(i,j)] * - - qQPlus[index12(i,m)] * qQPlus[index12(n,j)]; - } -// entryMN += differentials[i * stateCount + j] * -// qQPlus[i * stateCount + m] * qQPlus[n * stateCount + j]; - } - } - correction[index12(m,n)] = entryMN; - } - } + double[] affineMatrix = new double[stateCount * stateCount]; + + if (crossProductAccumulationMap.size() > 1) { + throw new RuntimeException("Not yet implemented"); + } + + // TODO Start cache for each (i,j) .. only depends on substitutionModel + EigenDecomposition ed = substitutionModel.getEigenDecomposition(); + int index = findZeroEigenvalueIndex(ed.getEigenValues()); + + double[] eigenVectors = ed.getEigenVectors(); + double[] inverseEigenVectors = ed.getInverseEigenVectors(); - System.err.println("diff: " + new WrappedVector.Raw(differentials)); - System.err.println("corr: " + new WrappedVector.Raw(correction)); + double[] qQPlus = getQQPlus(eigenVectors, inverseEigenVectors, index); - for (int i = 0; i < differentials.length; ++i) { - differentials[i] -= correction[i]; + for (int m = 0; m < stateCount; ++m) { + for (int n = 0; n < stateCount; n++) { + // TODO there are only stateCount unique values + affineMatrix[index12(m,n)] = (m == i) ? + (qQPlus[index12(m,i)] - 1.0) * qQPlus[index12(j,n)] : + qQPlus[index12(m,i)] * qQPlus[index12(j,n)]; } + } + // TODO End cache + double correction = 0.0; + for (int m = 0; m < stateCount; ++m) { + for (int n = 0; n < stateCount; ++n) { + correction += crossProducts[index12(m,n)] * affineMatrix[index12(m,n)]; + } } - return differentials; + return correction; } private int findZeroEigenvalueIndex(double[] eigenvalues) {