Skip to content

Commit

Permalink
thorney work refactored as dataLikelihoodDelegate
Browse files Browse the repository at this point in the history
  • Loading branch information
jtmccr1 committed Aug 17, 2023
1 parent 8709fa5 commit 354afcd
Show file tree
Hide file tree
Showing 47 changed files with 767 additions and 1,981 deletions.
4 changes: 4 additions & 0 deletions src/dr/evomodel/bigfasttree/BigFastTreeIntervals.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@
import java.util.Arrays;
import java.util.List;

/**
* Smart intervals that don't need a full recalculation.
* author: JT
*/
public class BigFastTreeIntervals extends AbstractModel implements Units, TreeIntervalList {
public BigFastTreeIntervals(TreeModel tree) {
this("bigFastIntervals",tree);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package dr.evomodel.bigfasttree.thorney;


public interface BranchLengthLikelihoodDelegate {
double getLogLikelihood(double mutations, double branchLength);

double getGradientWrtTime(double mutations, double time);
}
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package dr.evomodel.treelikelihood.thorneytreelikelihood;
package dr.evomodel.bigfasttree.thorney;

import dr.evomodel.tree.TreeModel;

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
package dr.evomodel.treelikelihood.thorneytreelikelihood;
package dr.evomodel.bigfasttree.thorney;

import dr.evolution.datatype.ContinuousDataType;

import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
Expand All @@ -9,13 +11,16 @@
import java.util.Map;


public class ConstrainedTreeBranchLengthProvider implements BranchLengthProvider {
public static final String CONSTRAINED_TREE_BRANCHLENGTH_PROVIDER = "ConstrainedTreeBranchLengthProvider";
public ConstrainedTreeBranchLengthProvider(ConstrainedTreeModel constrainedTreeModel,Tree dataTree,Double scale,double minBranchlength,boolean discrete){
this.scale = scale;
this.discrete = discrete;
this.minBranchlength = minBranchlength;
public class ConstrainedTreeBranchLengthProvider extends MutationBranchMap.AbstractMutationBranchMap {
public static final String CONSTRAINED_TREE_BRANCHLENGTH_PROVIDER = "ConstrainedTreeBranchMutationProvider";


public ConstrainedTreeBranchLengthProvider(ConstrainedTreeModel constrainedTreeModel,Tree dataTree,Double scale,double minBranchLength,boolean discrete){
super(ContinuousDataType.INSTANCE);

this.minBranchLength = minBranchLength;

this.constrainedTreeModel = constrainedTreeModel;

externalBranchLengths = new double[dataTree.getExternalNodeCount()];
cladeBranchLengths = new double[dataTree.getInternalNodeCount()];
Expand Down Expand Up @@ -52,22 +57,26 @@ public ConstrainedTreeBranchLengthProvider(ConstrainedTreeModel constrainedTreeM
this(constrainedTreeModel,dataTree,1.0,0.0,true);
}

@Override
public double getBranchLength(Tree tree, NodeRef node) {
if (tree.isExternal(node)) {
public double getBranchLength( NodeRef node) {
if (constrainedTreeModel.isExternal(node)) {
return externalBranchLengths[node.getNumber()];
}
TreeModel subtree = ((ConstrainedTreeModel) tree).getSubtree(node);
NodeRef nodeInSubtree = ((ConstrainedTreeModel) tree).getNodeInSubtree(subtree,node);
TreeModel subtree = constrainedTreeModel.getSubtree(node);
NodeRef nodeInSubtree = constrainedTreeModel.getNodeInSubtree(subtree,node);

if (subtree.isRoot(nodeInSubtree)) {
int subtreeIndex = ((ConstrainedTreeModel) tree).getSubtreeIndex(node);
int subtreeIndex = constrainedTreeModel.getSubtreeIndex(node);
return cladeBranchLengths[subtreeIndex];
}else{
return minBranchlength;
return minBranchLength;
}
}

public double getMutations(NodeRef node){
return getBranchLength(node);
}



/**
* Gets a HashMap of clade bitsets to nodes in tree. This is useful for comparing the topology of trees
Expand Down Expand Up @@ -108,7 +117,10 @@ private BitSet addBits(Tree referenceTree, Tree tree, NodeRef node, HashMap map)

private final double[] cladeBranchLengths;
private final double[] externalBranchLengths;
private final double scale;
private final boolean discrete;
private final double minBranchlength;

private final double minBranchLength;
private ConstrainedTreeModel constrainedTreeModel;



}
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package dr.evomodel.treelikelihood.thorneytreelikelihood;
package dr.evomodel.bigfasttree.thorney;

import dr.evolution.tree.*;
import dr.evolution.util.Taxon;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package dr.evomodel.treelikelihood.thorneytreelikelihood;
package dr.evomodel.bigfasttree.thorney;

import dr.evomodel.operators.AbstractAdaptableTreeOperator;
import dr.evomodel.tree.TreeModel;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
* Boston, MA 02110-1301 USA
*/

package dr.evomodel.treelikelihood.thorneytreelikelihood;
package dr.evomodel.bigfasttree.thorney;

import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,28 +3,26 @@
// (powered by FernFlower decompiler)
//

package dr.evomodel.treelikelihood.thorneytreelikelihood;
package dr.evomodel.bigfasttree.thorney;

import dr.evolution.datatype.ContinuousDataType;
import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;

import java.util.BitSet;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.Map.Entry;

public class FixedTreeBranchLengthProvider implements BranchLengthProvider {
public class FixedTreeBranchLengthProvider extends MutationBranchMap.AbstractMutationBranchMap {
public static final String FIXED_TREE_BRANCHLENGTH_PROVIDER = "FixedTreeBranchLengthProvider";
private final double[] branchLengths;
private final double scale;
private final boolean discrete;
private final double minBranchlength;
private final Tree tree;

public FixedTreeBranchLengthProvider(Tree fixedTree, Tree dataTree, Double scale, double minBranchlength, boolean discrete) {
this.scale = scale;
this.discrete = discrete;
this.minBranchlength = minBranchlength;

public FixedTreeBranchLengthProvider(Tree fixedTree, Tree dataTree, Double scale, Double minBranchLength, boolean discrete) {
super(ContinuousDataType.INSTANCE);

this.tree = fixedTree;
this.branchLengths = new double[dataTree.getNodeCount()];
if (this.tree.getNodeCount() != dataTree.getNodeCount()) {
Expand All @@ -42,7 +40,8 @@ public FixedTreeBranchLengthProvider(Tree fixedTree, Tree dataTree, Double scale
NodeRef node = fixedTree.getExternalNode(i);
String taxonId = fixedTree.getNodeTaxon(node).getId();
NodeRef dataNode = taxonIdNodeMap.get(taxonId);
this.branchLengths[node.getNumber()] = discrete ? (double)Math.round(dataTree.getBranchLength(dataNode) * scale) : dataTree.getBranchLength(dataNode) * scale;
double observedLength = discrete ? (double) Math.round(dataTree.getBranchLength(dataNode) * scale) : dataTree.getBranchLength(dataNode) * scale;
this.branchLengths[node.getNumber()] = Math.max(minBranchLength, observedLength);
}

Map<BitSet, NodeRef> dataTreeMap = this.getBitSetNodeMap(dataTree, dataTree);
Expand All @@ -58,14 +57,15 @@ public FixedTreeBranchLengthProvider(Tree fixedTree, Tree dataTree, Double scale
for(int i = 0; i < dataTree.getInternalNodeCount(); ++i) {
NodeRef dataNode = dataTree.getInternalNode(i);
NodeRef node = dataTreeNodeMap.get(dataNode);
this.branchLengths[node.getNumber()] = discrete ? (double)Math.round(dataTree.getBranchLength(dataNode) * scale) : dataTree.getBranchLength(dataNode) * scale;
double observedLength = discrete ? (double)Math.round(dataTree.getBranchLength(dataNode) * scale) : dataTree.getBranchLength(dataNode) * scale;
this.branchLengths[node.getNumber()] = Math.max(minBranchLength, observedLength);
}

}
}

public FixedTreeBranchLengthProvider(Tree tree, Tree dataTree) {
this(tree, dataTree, 1.0D, 0.0D, true);
this(tree, dataTree, 1.0D,0.0D, true);
}

public double getBranchLength(Tree tree, NodeRef node) {
Expand All @@ -75,6 +75,10 @@ public double getBranchLength(Tree tree, NodeRef node) {
throw new RuntimeException("Unrecognized Tree");
}
}
public double getMutations(NodeRef node){
return getBranchLength(tree, node);

}

private HashMap<BitSet, NodeRef> getBitSetNodeMap(Tree referenceTree, Tree tree) {
HashMap<BitSet, NodeRef> map = new HashMap();
Expand All @@ -97,4 +101,6 @@ private BitSet addBits(Tree referenceTree, Tree tree, NodeRef node, HashMap map)
map.put(bits, node);
return bits;
}


}
23 changes: 23 additions & 0 deletions src/dr/evomodel/bigfasttree/thorney/MutationBranchMap.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package dr.evomodel.bigfasttree.thorney;


import dr.evolution.datatype.DataType;
import dr.evolution.tree.NodeRef;

public interface MutationBranchMap{
public DataType getDataType();
public double getMutations(final NodeRef node);


public abstract class AbstractMutationBranchMap implements MutationBranchMap{
private DataType dataType;

public AbstractMutationBranchMap(DataType dataType){
this.dataType = dataType;
}

public DataType getDataType() {
return dataType;
}
}
}
71 changes: 71 additions & 0 deletions src/dr/evomodel/bigfasttree/thorney/MutationList.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
package dr.evomodel.bigfasttree.thorney;

import java.util.ArrayList;


public interface MutationList{
public double getMutationCount();



public class DetailedMutationList implements MutationList{

private ArrayList<Mutation> mutations;

public DetailedMutationList(){
this.mutations = new ArrayList<>();
}

public double getMutationCount(){
return (double) mutations.size();
};

public int getSite(int mutationIndex){
return mutations.get(mutationIndex).site;
}

public int getRef(int mutationIndex){
return mutations.get(mutationIndex).ref;
}

public int getAlt(int mutationIndex){
return mutations.get(mutationIndex).alt;
}

public void addMutation(Mutation mut){
this.mutations.add(mut);
}

public Mutation removeMutation(int mutationIndex){
return this.mutations.remove(mutationIndex);
}


protected class Mutation{

final private int alt;
final private int ref;
final private int site;

protected Mutation(int site, int ref, int alt){
this.alt= alt;
this.ref = ref;
this.site = site;
}
}
}

public class SimpleMutationList implements MutationList{
private double mutations;
public SimpleMutationList(double muts){
this.mutations = muts;
}

public double getMutationCount(){
return this.mutations;
};
public void setMutationCount(double muts){
this.mutations=muts;
}
}
}
Original file line number Diff line number Diff line change
@@ -1,45 +1,36 @@
package dr.evomodel.treelikelihood.thorneytreelikelihood;
package dr.evomodel.bigfasttree.thorney;


import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evomodel.branchratemodel.BranchRateModel;
import dr.evomodel.branchratemodel.StrictClockBranchRates;
import dr.inference.model.*;
import org.apache.commons.math.special.Gamma;
import org.apache.commons.math.util.FastMath;
import org.apache.commons.math.util.MathUtils;

public class PoissonBranchLengthLikelihoodDelegate extends AbstractModel implements ThorneyBranchLengthLikelihoodDelegate {
private final BranchRateModel branchRateModel;
public class PoissonBranchLengthLikelihoodDelegate extends AbstractModel implements BranchLengthLikelihoodDelegate {
private final double scale;

public PoissonBranchLengthLikelihoodDelegate(String name, BranchRateModel branchRateModel, double scale){
public PoissonBranchLengthLikelihoodDelegate(String name, double scale){
super(name);
this.branchRateModel = branchRateModel;
addModel(branchRateModel);
this.scale = scale;
}

@Override
public double getLogLikelihood(double mutations, Tree tree , NodeRef node) {
double rate = this.branchRateModel.getBranchRate(tree, node);
double time = tree.getBranchLength(node);
return SaddlePointExpansion.logPoissonProbability(time*rate*scale, (int) Math.round(mutations));
public double getLogLikelihood(double mutations, double branchLength) {
return SaddlePointExpansion.logPoissonProbability(branchLength*scale, (int) Math.round(mutations));
}


@Override
public double getGradientWrtTime(double mutations, double time) { // TODO: better chain rule handling
if (!(this.branchRateModel instanceof StrictClockBranchRates)){
throw new RuntimeException("gradients are only implemented for a strict clock model");
}
double rate = (double) branchRateModel.getVariable(0).getValue(0);
return SaddlePointExpansion.logPoissonMeanDerivative(time * rate * scale, (int) Math.round(mutations)) * rate * scale;
public double getGradientWrtTime(double mutations, double branchLength) { // TODO: better chain rule handling
// if (!(this.branchRateModel instanceof StrictClockBranchRates)){
// throw new RuntimeException("gradients are only implemented for a strict clock model");
// }
// double rate = (double) branchRateModel.getVariable(0).getValue(0);
// return SaddlePointExpansion.logPoissonMeanDerivative(time * rate * scale, (int) Math.round(mutations)) * rate * scale;
throw new RuntimeException("gradients are not implemented for this model");
}

public BranchRateModel getBranchRateModel(){
return this.branchRateModel;
}


public double getScale() {
return scale;
Expand Down
Loading

0 comments on commit 354afcd

Please sign in to comment.