Skip to content

Commit

Permalink
a delegate to handle node height gradient and hessian
Browse files Browse the repository at this point in the history
  • Loading branch information
xji3 committed Oct 9, 2023
1 parent 390bdd8 commit 262d2ce
Show file tree
Hide file tree
Showing 3 changed files with 94 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -85,13 +85,11 @@ public DiscreteTraitBranchRateGradient(String traitName,
BranchRateModel brm = treeDataLikelihood.getBranchRateModel();
this.branchRateModel = (brm instanceof DifferentiableBranchRates) ? (DifferentiableBranchRates) brm : null;

String name = DiscreteTraitBranchRateDelegate.getName(traitName);
String name = getTraitName(traitName);
TreeTrait test = treeDataLikelihood.getTreeTrait(name);

if (test == null) {
ProcessSimulationDelegate gradientDelegate = new DiscreteTraitBranchRateDelegate(traitName,
treeDataLikelihood.getTree(),
likelihoodDelegate);
ProcessSimulationDelegate gradientDelegate = makeGradientDelegate(traitName, tree, likelihoodDelegate);
TreeTraitProvider traitProvider = new ProcessSimulation(treeDataLikelihood, gradientDelegate);
treeDataLikelihood.addTraits(traitProvider.getTreeTraits());
}
Expand All @@ -106,6 +104,16 @@ public DiscreteTraitBranchRateGradient(String traitName,
// dim = treeDataLikelihood.getDataLikelihoodDelegate().getTraitDim();
}

protected String getTraitName(String traitName) {
return DiscreteTraitBranchRateDelegate.getName(null);
}

protected ProcessSimulationDelegate makeGradientDelegate(String traitName, Tree tree, BeagleDataLikelihoodDelegate likelihoodDelegate) {
return new DiscreteTraitBranchRateDelegate(traitName,
tree,
likelihoodDelegate);
}

@Override
public Likelihood getLikelihood() {
return treeDataLikelihood;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
/*
* DiscreteTraitNodeHeightDelegate.java
*
* Copyright (c) 2002-2017 Alexei Drummond, Andrew Rambaut and Marc Suchard
*
* This file is part of BEAST.
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership and licensing.
*
* BEAST is free software; you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2
* of the License, or (at your option) any later version.
*
* BEAST is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with BEAST; if not, write to the
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
* Boston, MA 02110-1301 USA
*/


package dr.evomodel.treedatalikelihood.discrete;

import beagle.Beagle;
import dr.evolution.tree.Tree;
import dr.evomodel.treedatalikelihood.BeagleDataLikelihoodDelegate;

/**
* @author Xiang Ji
* @author Marc A. Suchard
*/
public class DiscreteTraitNodeHeightDelegate extends DiscreteTraitBranchRateDelegate {

static final String GRADIENT_TRAIT_NAME = "NodeHeightGradient";

static final String HESSIAN_TRAIT_NAME = "NodeHeightHessian";

DiscreteTraitNodeHeightDelegate(String name, Tree tree, BeagleDataLikelihoodDelegate likelihoodDelegate) {
super(name, tree, likelihoodDelegate);
}


protected void getNodeDerivatives(Tree tree, double[] first, double[] second) {
super.getNodeDerivatives(tree, first, second);
final int internalNodeCount = tree.getInternalNodeCount();

double[][] prePartials = new double[internalNodeCount][patternCount * stateCount * categoryCount];
double[][] postPartials = new double[internalNodeCount][patternCount * stateCount * categoryCount];
double[][] transitionMatrices = new double[internalNodeCount][stateCount * stateCount * categoryCount];

for (int i = 0; i < internalNodeCount; i++) {
beagle.getPartials(getPostOrderPartialIndex(i + tree.getExternalNodeCount()), Beagle.NONE, postPartials[i]);
beagle.getPartials(getPreOrderPartialIndex(i + tree.getExternalNodeCount()), Beagle.NONE, prePartials[i]);
beagle.getTransitionMatrix(evolutionaryProcessDelegate.getMatrixIndex(i + tree.getExternalNodeCount()), transitionMatrices[i]);
}

}

protected String getGradientTraitName() {
return GRADIENT_TRAIT_NAME;
}

protected String getHessianTraitName() {
return HESSIAN_TRAIT_NAME;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import dr.evomodel.tree.TreeParameterModel;
import dr.evomodel.treedatalikelihood.BeagleDataLikelihoodDelegate;
import dr.evomodel.treedatalikelihood.TreeDataLikelihood;
import dr.evomodel.treedatalikelihood.preorder.ProcessSimulationDelegate;
import dr.inference.hmc.GradientWrtParameterProvider;
import dr.inference.loggers.Loggable;
import dr.inference.model.Parameter;
Expand Down Expand Up @@ -67,6 +68,15 @@ public NodeHeightGradientForDiscreteTrait(String traitName,
this.nodeHeightProxyParameter = new NodeHeightProxyParameter("internalNodeHeights", treeModel, true);
}

protected String getTraitName(String traitName) {
return DiscreteTraitNodeHeightDelegate.GRADIENT_TRAIT_NAME;
}

protected ProcessSimulationDelegate makeGradientDelegate(String traitName, Tree tree, BeagleDataLikelihoodDelegate likelihoodDelegate) {
return new DiscreteTraitNodeHeightDelegate(traitName,
tree,
likelihoodDelegate);
}
@Override
public Parameter getParameter() {
return nodeHeightProxyParameter;
Expand Down

0 comments on commit 262d2ce

Please sign in to comment.