Skip to content

Commit

Permalink
Accessor for partials precision and variance
Browse files Browse the repository at this point in the history
  • Loading branch information
pbastide committed Dec 15, 2023
1 parent 5266403 commit cf0a9ac
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -184,14 +184,14 @@ public void updatePreOrderPartial(

// A. Get current precision of k and j
final DenseMatrix64F Pk = wrap(preOrderPartials, kbo + dimTrait, dimTrait, dimTrait);
// final DenseMatrix64F Pj = wrap(partials, jbo + dimTrait, dimTrait, dimTrait);
// final DenseMatrix64F Pj = wrapPartialPrecision(jbo);

// final DenseMatrix64F Vk = wrap(prePartials, kbo + dimTrait + dimTrait * dimTrait, dimTrait, dimTrait);
final DenseMatrix64F Vj = wrap(partials, jbo + dimTrait + dimTrait * dimTrait, dimTrait, dimTrait);
final DenseMatrix64F Vj = wrapPartialVariance(jbo);

if (allZeroDiagonals(Vj)) {

final DenseMatrix64F Pj = wrap(partials, jbo + dimTrait, dimTrait, dimTrait);
final DenseMatrix64F Pj = wrapPartialPrecision(jbo);

assert (!allZeroDiagonals(Pj));

Expand Down Expand Up @@ -329,11 +329,11 @@ protected void updatePartial(
final double lpi = partials[ibo + dimTrait + 2 * dimTrait * dimTrait];
final double lpj = partials[jbo + dimTrait + 2 * dimTrait * dimTrait];

final DenseMatrix64F Pi = wrap(partials, ibo + dimTrait, dimTrait, dimTrait);
final DenseMatrix64F Pj = wrap(partials, jbo + dimTrait, dimTrait, dimTrait);
final DenseMatrix64F Pi = wrapPartialPrecision(ibo);
final DenseMatrix64F Pj = wrapPartialPrecision(jbo);

final DenseMatrix64F Vi = wrap(partials, ibo + dimTrait + dimTrait * dimTrait, dimTrait, dimTrait);
final DenseMatrix64F Vj = wrap(partials, jbo + dimTrait + dimTrait * dimTrait, dimTrait, dimTrait);
final DenseMatrix64F Vi = wrapPartialVariance(ibo);
final DenseMatrix64F Vj = wrapPartialVariance(jbo);

if (TIMING) {
endTime("peel1");
Expand Down Expand Up @@ -728,11 +728,11 @@ public void calculateRootLogLikelihood(int rootBufferIndex, int priorBufferIndex
// TODO For each trait in parallel
for (int trait = 0; trait < numTraits; ++trait) {

final DenseMatrix64F PRoot = wrap(partials, rootOffset + dimTrait, dimTrait, dimTrait);
final DenseMatrix64F PPrior = wrap(partials, priorOffset + dimTrait, dimTrait, dimTrait);
final DenseMatrix64F PRoot = wrapPartialPrecision(rootOffset);
final DenseMatrix64F PPrior = wrapPartialPrecision(priorOffset);

final DenseMatrix64F VRoot = wrap(partials, rootOffset + dimTrait + dimTrait * dimTrait, dimTrait, dimTrait);
final DenseMatrix64F VPrior = wrap(partials, priorOffset + dimTrait + dimTrait * dimTrait, dimTrait, dimTrait);
final DenseMatrix64F VRoot = wrapPartialVariance(rootOffset);
final DenseMatrix64F VPrior = wrapPartialVariance(priorOffset);

// TODO Block below is for the conjugate prior ONLY
{
Expand Down Expand Up @@ -862,4 +862,11 @@ public void getRootPriorPrecision(DenseMatrix64F Pd, DenseMatrix64F PPrior, Dens
// }

double[] inverseDiffusions;

protected DenseMatrix64F wrapPartialPrecision(int ibo) {
return wrap(partials, ibo + dimTrait, dimTrait, dimTrait);
}
protected DenseMatrix64F wrapPartialVariance(int ibo) {
return wrap(partials, ibo + dimTrait + dimTrait * dimTrait, dimTrait, dimTrait);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -207,10 +207,10 @@ public void updatePreOrderPartial(

// A. Get current precision of k and j
final DenseMatrix64F Pk = wrap(preOrderPartials, kbo + dimTrait, dimTrait, dimTrait);
// final DenseMatrix64F Pj = wrap(partials, jbo + dimTrait, dimTrait, dimTrait);
// final DenseMatrix64F Pj = wrapPartialPrecision(jbo);

// final DenseMatrix64F Vk = wrap(preOrderPartials, kbo + dimTrait + dimTrait * dimTrait, dimTrait, dimTrait);
// final DenseMatrix64F Vj = wrap(partials, jbo + dimTrait + dimTrait * dimTrait, dimTrait, dimTrait);
// final DenseMatrix64F Vj = wrapPartialVariance(jbo);

// B. Inflate variance along sibling branch using matrix inversion
// final DenseMatrix64F Vjp = matrix0;
Expand Down Expand Up @@ -264,7 +264,7 @@ public void updatePreOrderPartial(
System.err.println("pM: " + new WrappedVector.Raw(preOrderPartials, kbo, dimTrait));
System.err.println("pP: " + Pk);
System.err.println("sM: " + new WrappedVector.Raw(partials, jbo, dimTrait));
DenseMatrix64F Pj = wrap(partials, jbo + dimTrait, dimTrait, dimTrait);
DenseMatrix64F Pj = wrapPartialPrecision(jbo);
DenseMatrix64F Vj = new DenseMatrix64F(dimTrait, dimTrait);
CommonOps.invert(Pj, Vj);
System.err.println("sP: " + Vj);
Expand Down Expand Up @@ -400,8 +400,8 @@ protected void updatePartial(
}

if (DEBUG) {
final DenseMatrix64F Pi = wrap(partials, ibo + dimTrait, dimTrait, dimTrait);
final DenseMatrix64F Pj = wrap(partials, jbo + dimTrait, dimTrait, dimTrait);
final DenseMatrix64F Pi = wrapPartialPrecision(ibo);
final DenseMatrix64F Pj = wrapPartialPrecision(jbo);
reportMeansAndPrecisions(trait, ibo, jbo, kbo, Pi, Pj, Pk);
}

Expand Down Expand Up @@ -490,7 +490,7 @@ private InversionResult increaseVariances(int ibo,
}

// A. Get current precision of i and j
final DenseMatrix64F Pi = wrap(partials, ibo + dimTrait, dimTrait, dimTrait);
final DenseMatrix64F Pi = wrapPartialPrecision(ibo);

if (TIMING) {
endTime("peel1");
Expand All @@ -505,7 +505,7 @@ private InversionResult increaseVariances(int ibo,
if (useVariancei) {

final DenseMatrix64F Vip = matrix0;
final DenseMatrix64F Vi = wrap(partials, ibo + dimTrait + dimTrait * dimTrait, dimTrait, dimTrait);
final DenseMatrix64F Vi = wrapPartialVariance(ibo);
CommonOps.add(Vi, Vdi, Vip);
if (allZeroOrInfinite(Vip)) {
throw new RuntimeException("Zero-length branch on data is not allowed.");
Expand Down Expand Up @@ -610,8 +610,8 @@ public void calculateRootLogLikelihood(int rootBufferIndex, int priorBufferIndex
// TODO For each trait in parallel
for (int trait = 0; trait < numTraits; ++trait) {

final DenseMatrix64F PPrior = wrap(partials, priorOffset + dimTrait, dimTrait, dimTrait);
final DenseMatrix64F VPrior = wrap(partials, priorOffset + dimTrait + dimTrait * dimTrait, dimTrait, dimTrait);
final DenseMatrix64F PPrior = wrapPartialPrecision(priorOffset);
final DenseMatrix64F VPrior = wrapPartialVariance(priorOffset);


// TODO Block below is for the conjugate prior ONLY
Expand Down Expand Up @@ -646,7 +646,7 @@ public void calculateRootLogLikelihood(int rootBufferIndex, int priorBufferIndex
System.err.print(" " + partials[rootOffset + g]);
}
System.err.println("");
System.err.println("PRoot: " + wrap(partials, rootOffset + dimTrait, dimTrait, dimTrait));
System.err.println("PRoot: " + wrapPartialPrecision(rootOffset));
System.err.println("PPrior: " + PPrior);
System.err.println("PTotal: " + PTotal);
System.err.println("\n SS:" + SS);
Expand Down Expand Up @@ -675,7 +675,7 @@ public void calculateRootLogLikelihood(int rootBufferIndex, int priorBufferIndex
// if (anyDiagonalInfinities(P)) {
// // Inflate variance
// final DenseMatrix64F Vp = matrix0;
// final DenseMatrix64F Vi = wrap(partials, bo + dimTrait + dimTrait * dimTrait, dimTrait, dimTrait);
// final DenseMatrix64F Vi = wrapPartialVariance(bo);
//
// CommonOps.add(Vi, v, Vd, Vp);
// c = safeInvert(Vp, Pp, true);
Expand Down

0 comments on commit cf0a9ac

Please sign in to comment.