From d1e82243be78ca79db7d58f525a32530dc91a511 Mon Sep 17 00:00:00 2001 From: Andy Magee Date: Fri, 22 Sep 2023 14:13:59 -0700 Subject: [PATCH] Copy, don't accumulate --- ...tLogAdditiveSubstitutionModelGradient.java | 38 ++++++++++--------- 1 file changed, 20 insertions(+), 18 deletions(-) diff --git a/src/dr/evomodel/treedatalikelihood/discrete/AbstractLogAdditiveSubstitutionModelGradient.java b/src/dr/evomodel/treedatalikelihood/discrete/AbstractLogAdditiveSubstitutionModelGradient.java index ab8576849a..15de465532 100644 --- a/src/dr/evomodel/treedatalikelihood/discrete/AbstractLogAdditiveSubstitutionModelGradient.java +++ b/src/dr/evomodel/treedatalikelihood/discrete/AbstractLogAdditiveSubstitutionModelGradient.java @@ -118,12 +118,8 @@ public double[] getGradientLogDensity() { double[] crossProducts = (double[]) treeTraitProvider.getTrait(tree, null); double[] generator = new double[crossProducts.length]; - if (substitutionModelCount > 1) { -// final int length = stateCount * stateCount; -// System.arraycopy( -// crossProducts, whichSubstitutionModel * length, -// crossProducts, 0, length); - crossProducts = accumulateAcrossSubstitutionModelInstances(crossProducts); + if (whichSubstitutionModel > 1 || substitutionModelCount > 1) { + accumulateAcrossSubstitutionModelInstances(crossProducts); } substitutionModel.getInfinitesimalMatrix(generator); @@ -244,35 +240,41 @@ private int determineSubstitutionModelCount(BranchModel branchModel) { return substitutionModels.size(); } - // Should maybe be void and just update crossProducts? - private double[] accumulateAcrossSubstitutionModelInstances(double[] crossProducts) { - double[] accumulated = new double[crossProducts.length]; + private void accumulateAcrossSubstitutionModelInstances(double[] crossProducts) { final int length = stateCount * stateCount; - // TODO first set of entries should be a copy (instead of accumulate) + // copy first set of entries instead of accumulating + System.arraycopy( + crossProducts, whichSubstitutionModel * length, + crossProducts, 0, length); - for (int i : crossProductAccumulationMap) { - for (int j = 0; j < length; j++) { - accumulated[j] += crossProducts[i * length + j]; + if ( crossProductAccumulationMap.length > 0 ) { + for (int i : crossProductAccumulationMap) { + for (int j = 0; j < length; j++) { + crossProducts[j] += crossProducts[i * length + j]; + } } } - return accumulated; } private void updateCrossProductAccumulationMap() { - System.err.println("Updating crossProductAccumulationMap"); +// System.err.println("Updating crossProductAccumulationMap"); List matchingModels = new ArrayList<>(); List substitutionModels = branchModel.getSubstitutionModels(); + + // We copy whichSubstitutionModel instead of accumulating it for (int i = 0; i < substitutionModels.size(); ++i) { - if (substitutionModel == substitutionModels.get(i)) { + if (i != whichSubstitutionModel && substitutionModel == substitutionModels.get(i)) { matchingModels.add(i); } } crossProductAccumulationMap = new int[matchingModels.size()]; - for (int i = 0; i < matchingModels.size(); ++i) { - crossProductAccumulationMap[i] = matchingModels.get(i); + if (matchingModels.size() > 0) { + for (int i = 0; i < matchingModels.size(); ++i) { + crossProductAccumulationMap[i] = matchingModels.get(i); + } } }