Skip to content

Commit

Permalink
Log position and velocity separatelly for IBM and IOU
Browse files Browse the repository at this point in the history
  • Loading branch information
pbastide committed Oct 26, 2023
1 parent f07052f commit d5baf08
Show file tree
Hide file tree
Showing 4 changed files with 118 additions and 2 deletions.
16 changes: 15 additions & 1 deletion ci/TestXML/testRepeatedMeasuresIOU.xml
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@
<parameter id="meanRoot" value="-0.5 1.0 -0.5 1.0 1.0 -2.0 1.0 -2.0"/>
</meanParameter>
<priorSampleSize>
<parameter id="sampleSizeRoot" value="10.0"/>
<parameter id="sampleSizeRoot" value="0.0001"/>
</priorSampleSize>
</conjugateRootPrior>
</traitDataLikelihood>
Expand All @@ -130,6 +130,20 @@
<treeModel idref="treeModel"/>
</traitLogger>
</log>
<logTree id="treeFileLog" logEvery="1000" nexusFormat="true" fileName="testRepeatedMeasuresIOU.trees" sortTranslationTable="true">
<treeModel idref="treeModel"/>
<joint idref="iouPosterior"/>
<trait name="X" tag="all">
<traitDataLikelihood idref="iouLikelihood"/>
</trait>
<trait name="position.X" tag="location">
<traitDataLikelihood idref="iouLikelihood"/>
</trait>
<trait name="velocity.X" tag="velocity">
<traitDataLikelihood idref="iouLikelihood"/>
</trait>

</logTree>
</mcmc>

<assertEqual tolerance="1e-3" verbose="true">
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,61 @@ public double[] getTrait(Tree t, NodeRef node) {
treeTraitHelper.addTrait(tipPrecision);
}

protected void addPositionVelocityTrait(final Helper treeTraitHelper) {
// Log separatelly gradient from position
TreeTrait.DA positionTrait = new TreeTrait.DA() {

public String getTraitName() {
return getPositionName(name);
}

public Intent getIntent() {
return Intent.NODE;
}

public double[] getTrait(Tree t, NodeRef node) {

if (t != tree) { // TODO Write a wrapper class around t if TransformableTree
if (t == baseTree) {
node = getBaseNode(t, node);
} else {
throw new RuntimeException("Tree '" + t.getId() + "' and likelihood '" + tree.getId() + "' mismatch");
}
}

return getTraitForNode(node, dimProcessNode, dimNode);
}
};

treeTraitHelper.addTrait(positionTrait);

TreeTrait.DA velocityTrait = new TreeTrait.DA() {

public String getTraitName() {
return getVelocityName(name);
}

public Intent getIntent() {
return Intent.NODE;
}

public double[] getTrait(Tree t, NodeRef node) {

if (t != tree) { // TODO Write a wrapper class around t if TransformableTree
if (t == baseTree) {
node = getBaseNode(t, node);
} else {
throw new RuntimeException("Tree '" + t.getId() + "' and likelihood '" + tree.getId() + "' mismatch");
}
}

return getTraitForNode(node, 0, dimProcessNode);
}
};

treeTraitHelper.addTrait(velocityTrait);
}

public static String getTipTraitName(String name) {
return REALIZED_TIP_TRAIT + "." + name;
}
Expand All @@ -111,6 +166,14 @@ private static String getTipPrecisionName(String name) {
return "precision." + name;
}

private static String getPositionName(String name) {
return "position." + name;
}

private static String getVelocityName(String name) {
return "velocity." + name;
}

private double[] getTraitForAllTips() {

assert simulationProcess != null;
Expand All @@ -124,6 +187,22 @@ private double[] getTraitForAllTips() {
return trait;
}

private double[] getTraitForAllTips(int begin, int end) {

assert simulationProcess != null;

simulationProcess.cacheSimulatedTraits(null);

int dimLogTrait = end - begin;
final int length = dimLogTrait * tree.getExternalNodeCount();
double[] trait = new double[length];
for (int i = 0; i < tree.getExternalNodeCount(); i++) {
System.arraycopy(sample, i * dimLogTrait, trait, 0, dimLogTrait);
}

return trait;
}

private double[] getPrecisionForAllTips() {

assert simulationProcess != null;
Expand Down Expand Up @@ -154,12 +233,30 @@ private double[] getTraitForNode(final NodeRef node) {
}
}

private double[] getTraitForNode(final NodeRef node, int begin, int end) {

assert simulationProcess != null;

simulationProcess.cacheSimulatedTraits(null);

int dimLogTrait = end - begin;

if (node == null) {
return getTraitForAllTips(begin, end);
} else {
double[] trait = new double[dimLogTrait];
System.arraycopy(sample, node.getNumber() * dimNode + begin, trait, 0, dimLogTrait);

return trait;
}
}

public int vectorizeNodeOperations(final List<NodeOperation> nodeOperations, final int[] operations) {

int k = 0;
for (ProcessOnTreeDelegate.NodeOperation op : nodeOperations) {

operations[k ] = op.getNodeNumber(); // Parent sample
operations[k] = op.getNodeNumber(); // Parent sample
operations[k + 1] = op.getLeftChild(); // Node sample
operations[k + 2] = likelihoodDelegate.getActiveNodeIndex(op.getLeftChild()); // Node post-order partial
operations[k + 3] = likelihoodDelegate.getActiveMatrixIndex(op.getLeftChild()); // Node branch info
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ public MultivariateConditionalOnTipsRealizedDelegate(String name, Tree tree,
ContinuousDataLikelihoodDelegate likelihoodDelegate) {
super(name, tree, diffusionModel, dataModel, rootPrior, rateTransformation, likelihoodDelegate);
missingInformation = new PartiallyMissingInformation(tree, dataModel);
if (likelihoodDelegate.getDiffusionProcessDelegate().isIntegratedProcess()) {
addPositionVelocityTrait(treeTraitHelper);
}
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,7 @@ abstract class AbstractContinuousTraitDelegate extends AbstractDelegate {
final int dimProcess;
final int numTraits;
final int dimNode;
final int dimProcessNode;;

final MultivariateDiffusionModel diffusionModel;
final ContinuousTraitPartialsProvider dataModel;
Expand Down Expand Up @@ -186,6 +187,7 @@ abstract class AbstractContinuousTraitDelegate extends AbstractDelegate {
dimProcess = likelihoodDelegate.getDimProcess();
numTraits = likelihoodDelegate.getTraitCount();
dimNode = dimTrait * numTraits;
dimProcessNode = likelihoodDelegate.getDimProcess() * numTraits;
this.diffusionModel = diffusionModel;
this.dataModel = dataModel;
this.rateTransformation = rateTransformation;
Expand Down

0 comments on commit d5baf08

Please sign in to comment.