Skip to content

Commit

Permalink
clean exact gradient code
Browse files Browse the repository at this point in the history
  • Loading branch information
msuchard committed Aug 15, 2023
1 parent 04b3519 commit 4b1574c
Show file tree
Hide file tree
Showing 3 changed files with 143 additions and 71 deletions.
115 changes: 45 additions & 70 deletions src/dr/evomodel/substmodel/OldGLMSubstitutionModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@
* @author Marc A. Suchard
*/
@Deprecated
public class OldGLMSubstitutionModel extends ComplexSubstitutionModel implements ParameterReplaceableSubstitutionModel, DifferentiableSubstitutionModel{
public class OldGLMSubstitutionModel extends ComplexSubstitutionModel
implements ParameterReplaceableSubstitutionModel, DifferentiableSubstitutionModel{

public OldGLMSubstitutionModel(String name, DataType dataType, FrequencyModel rootFreqModel,
LogLinearModel glm) {
Expand Down Expand Up @@ -117,108 +118,82 @@ public List<Citation> getCitations() {
public ParameterReplaceableSubstitutionModel factory(List<Parameter> oldParameters, List<Parameter> newParameters) {

LogLinearModel newGLM = glm.factory(oldParameters, newParameters);

OldGLMSubstitutionModel newGLMSubstitutionModel = new OldGLMSubstitutionModel(getModelName(), dataType, freqModel, newGLM);

return newGLMSubstitutionModel;
return new OldGLMSubstitutionModel(getModelName(), dataType, freqModel, newGLM);
}

@Override
public WrappedMatrix getInfinitesimalDifferentialMatrix(DifferentialMassProvider.DifferentialWrapper.WrtParameter wrt) {
return DifferentiableSubstitutionModelUtil.getInfinitesimalDifferentialMatrix(wrt, this);
}

enum WrtOldGLMSubstitutionModelParameter implements DifferentialMassProvider.DifferentialWrapper.WrtParameter {
INDEPENDENT_PARAMETER {
@Override
void setDim(int dim) {
this.dim = dim;
}

void setEffectIndex(int fixedEffectIndex) {
this.fixedEffectIndex = fixedEffectIndex;
}
static class WrtOldGLMSubstitutionModelParameter implements DifferentialMassProvider.DifferentialWrapper.WrtParameter {

private int dim;
private int fixedEffectIndex;
private int stateCount;
private LogLinearModel glm;
@Override
public double getRate(int switchCase) {
throw new RuntimeException("Should not be called.");
}
final private int dim;
final private int fixedEffectIndex;
final private int stateCount;
final private LogLinearModel glm;

@Override
public double getNormalizationDifferential() {
return 0;
}
public WrtOldGLMSubstitutionModelParameter(LogLinearModel glm, int fixedEffectIndex, int dim, int stateCount) {
this.glm = glm;
this.fixedEffectIndex = fixedEffectIndex;
this.dim = dim;
this.stateCount = stateCount;
}

@Override
public void setupDifferentialFrequencies(double[] differentialFrequencies, double[] frequencies) {
// System.arraycopy(frequencies, 0, differentialFrequencies, 0, frequencies.length);
Arrays.fill(differentialFrequencies, 1);
}
@Override
public double getRate(int switchCase) {
throw new RuntimeException("Should not be called.");
}

public void setStateCount(int stateCount) {
this.stateCount = stateCount;
}
@Override
public double getNormalizationDifferential() {
return 0;
}

public void setGLM(LogLinearModel glm) {
this.glm = glm;
}
@Override
public void setupDifferentialFrequencies(double[] differentialFrequencies, double[] frequencies) {
Arrays.fill(differentialFrequencies, 1);
}

@Override
public void setupDifferentialRates(double[] differentialRates, double[] Q, double normalizingConstant) {
final double[] covariate = glm.getDesignMatrix(fixedEffectIndex).getColumnValues(dim);
@Override
public void setupDifferentialRates(double[] differentialRates, double[] Q, double normalizingConstant) {

// System.arraycopy(covariate, 0, differentialRates, 0, covariate.length);
final double[] covariate = glm.getDesignMatrix(fixedEffectIndex).getColumnValues(dim);

int k = 0;
for (int i = 0; i < stateCount; ++i) {
for (int j = i + 1; j < stateCount; ++j) {
int k = 0;
for (int i = 0; i < stateCount; ++i) {
for (int j = i + 1; j < stateCount; ++j) {

differentialRates[k] = covariate[k] * Q[index(i, j)];
k++;
differentialRates[k] = covariate[k] * Q[index(i, j)];
k++;

}
}
}

for (int j = 0; j < stateCount; ++j) {
for (int i = j + 1; i < stateCount; ++i) {
for (int j = 0; j < stateCount; ++j) {
for (int i = j + 1; i < stateCount; ++i) {

differentialRates[k] = covariate[k] * Q[index(i, j)];
k++;
differentialRates[k] = covariate[k] * Q[index(i, j)];
k++;

}
}

}
private int index(int i, int j) {
return i * stateCount + j;
}
};
abstract void setDim(int dim);
abstract void setEffectIndex(int effectIndex);
abstract void setStateCount(int stateCount);

abstract void setGLM(LogLinearModel glm);

}

private int index(int i, int j) {
return i * stateCount + j;
}
}

@Override
public DifferentialMassProvider.DifferentialWrapper.WrtParameter factory(Parameter parameter, int dim) {
assert(dim == 0);
WrtOldGLMSubstitutionModelParameter wrtParameter = WrtOldGLMSubstitutionModelParameter.INDEPENDENT_PARAMETER;

final int effectIndex = glm.getEffectNumber(parameter);
if (effectIndex == -1) {
throw new RuntimeException("Only implemented for single dimensions, break up beta to one for each block for now please.");
}
wrtParameter.setDim(dim);
wrtParameter.setEffectIndex(effectIndex);
wrtParameter.setStateCount(stateCount);
wrtParameter.setGLM(glm);
return wrtParameter;
return new WrtOldGLMSubstitutionModelParameter(glm, effectIndex, dim, stateCount);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
/*
* DiscreteTraitBranchSubstitutionParameterDelegate.java
*
* Copyright (c) 2002-2017 Alexei Drummond, Andrew Rambaut and Marc Suchard
*
* This file is part of BEAST.
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership and licensing.
*
* BEAST is free software; you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2
* of the License, or (at your option) any later version.
*
* BEAST is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with BEAST; if not, write to the
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
* Boston, MA 02110-1301 USA
*/

package dr.evomodel.treedatalikelihood.discrete;

import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evomodel.branchratemodel.BranchRateModel;
import dr.evomodel.treedatalikelihood.BeagleDataLikelihoodDelegate;
import dr.evomodel.treedatalikelihood.preorder.AbstractBeagleBranchGradientDelegate;

/**
* @author Andrew Holbrook
* @author Marc A. Suchard
*/
public class AffineCorrectedSubstitutionParameterDelegate extends AbstractBeagleBranchGradientDelegate {

private final BranchRateModel branchRateModel;
private final BranchDifferentialMassProvider branchDifferentialMassProvider;
private final String name;

private static final String GRADIENT_TRAIT_NAME = "affineCorrectedSubstitutionGradient";
private static final String HESSIAN_TRAIT_NAME = "affineCorrectedSubstitutionHessian";

public AffineCorrectedSubstitutionParameterDelegate(String name,
Tree tree,
BeagleDataLikelihoodDelegate likelihoodDelegate,
BranchRateModel branchRateModel,
BranchDifferentialMassProvider branchDifferentialMassProvider) {
super(name, tree, likelihoodDelegate);
this.name = name;
this.branchRateModel = branchRateModel;
this.branchDifferentialMassProvider = branchDifferentialMassProvider;
}

@Override
protected void cacheDifferentialMassMatrix(Tree tree, boolean cacheSquaredMatrix) {
for (int i = 0; i < tree.getNodeCount(); i++) {
NodeRef node = tree.getNode(i);
if (!tree.isRoot(node)) {

final double time = tree.getBranchLength(node) * branchRateModel.getBranchRate(tree, node);
double[] differentialMassMatrix = branchDifferentialMassProvider.getDifferentialMassMatrixForBranch(node, time);
double[] scaledDifferentialMassMatrix = DiscreteTraitBranchRateDelegate.scaleInfinitesimalMatrixByRates(differentialMassMatrix,
DiscreteTraitBranchRateDelegate.DifferentialChoice.GRADIENT, siteRateModel);
evolutionaryProcessDelegate.cacheFirstOrderDifferentialMatrix(beagle, i, scaledDifferentialMassMatrix);
}
}
if (cacheSquaredMatrix) {
throw new RuntimeException("Not yet implemented!");
}
}

@Override
protected int getFirstDerivativeMatrixBufferIndex(int nodeNum) {
return evolutionaryProcessDelegate.getFirstOrderDifferentialMatrixBufferIndex(nodeNum);
}

@Override
protected int getSecondDerivativeMatrixBufferIndex(int nodeNum) {
return evolutionaryProcessDelegate.getSecondOrderDifferentialMatrixBufferIndex(nodeNum);
}

protected String getGradientTraitName() {
return GRADIENT_TRAIT_NAME + ":" + name;
}

protected String getHessianTraitName() {
return HESSIAN_TRAIT_NAME + ":" + name;
}

public static String getName(String name) {
return GRADIENT_TRAIT_NAME + ":" + name;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@

package dr.evomodelxml.speciation;

import com.sun.org.apache.bcel.internal.generic.SWITCH;
import dr.evolution.util.Units;
import dr.evomodel.speciation.MasBirthDeathSerialSamplingModel;
import dr.evomodel.speciation.NewBirthDeathSerialSamplingModel;
Expand Down

0 comments on commit 4b1574c

Please sign in to comment.