Skip to content

Commit

Permalink
Merge branch 'hmc-clock' of https://github.com/beast-dev/beast-mcmc i…
Browse files Browse the repository at this point in the history
…nto hmc-clock
  • Loading branch information
xji3 committed Aug 4, 2023
2 parents 1afe972 + 503b389 commit 77713e9
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 82 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,16 +26,10 @@
package dr.evomodel.branchmodel;

import dr.evolution.tree.TreeUtils;
import dr.evolution.util.Taxon;
import dr.evomodel.substmodel.FrequencyModel;
import dr.evomodel.substmodel.SubstitutionModel;
import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evolution.util.TaxonList;
import dr.evomodel.tree.TreeModel;
import dr.evomodelxml.branchratemodel.LocalClockModelParser;
import dr.inference.model.AbstractModel;
import dr.inference.model.Model;
import dr.inference.model.Parameter;
import dr.inference.model.Variable;

Expand All @@ -49,8 +43,9 @@
* @version $Id$
*/
public class EstimableStemWeightBranchSpecificBranchModel extends BranchSpecificBranchModel {
private List<Parameter> stemWeightParameters = new ArrayList<>();
protected Map<BitSet, Integer> stemWeightMap = new HashMap<BitSet, Integer>();

final private List<Parameter> stemWeightParameters = new ArrayList<>();
protected Map<BitSet, Integer> stemWeightMap = new HashMap<>();
private boolean hasBackbone = false;

public EstimableStemWeightBranchSpecificBranchModel(TreeModel treeModel, SubstitutionModel rootSubstitutionModel) {
Expand Down Expand Up @@ -128,7 +123,7 @@ public void addBackbone(TaxonList taxonList, SubstitutionModel substitutionModel


protected final void handleVariableChangedEvent(Variable variable, int index, Parameter.ChangeType type) {
if ( stemWeightParameters.contains(variable) && clades.size() > 0) {
if (variable instanceof Parameter && stemWeightParameters.contains((Parameter)variable) && clades.size() > 0) {
setUpdateNodeMaps(true);
}
fireModelChanged();
Expand Down
16 changes: 8 additions & 8 deletions src/dr/evomodel/speciation/MasBirthDeathSerialSamplingModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,9 @@ public final double processModelSegmentBreakPoint(int model, double intervalStar
return lnL;
}

final void accumulateGradientForInterval(final double[] gradient, final int currentModelSegment, final int nLineages,
final double[] partialQ_all_old, final double Q_Old,
final double[] partialQ_all_young, final double Q_young) {
void accumulateGradientForInterval(final double[] gradient, final int currentModelSegment, final int nLineages,
final double[] partialQ_all_old, final double Q_Old,
final double[] partialQ_all_young, final double Q_young) {

for (int k = 0; k <= currentModelSegment; k++) {
gradient[k * 5 + 0] += nLineages * (partialQ_all_old[k * 4 + 0] / Q_Old
Expand All @@ -61,7 +61,7 @@ final void accumulateGradientForInterval(final double[] gradient, final int curr
}
}

final void accumulateGradientForSerialSampling(double[] gradient, int currentModelSegment, double term1,
void accumulateGradientForSerialSampling(double[] gradient, int currentModelSegment, double term1,
double[] intermediate) {

for (int k = 0; k <= currentModelSegment; k++) {
Expand All @@ -71,7 +71,7 @@ final void accumulateGradientForSerialSampling(double[] gradient, int currentMod
}
}

final void accumulateGradientForIntensiveSampling(double[] gradient, int currentModelSegment, double term1,
void accumulateGradientForIntensiveSampling(double[] gradient, int currentModelSegment, double term1,
double[] intermediate) {

for (int k = 0; k < currentModelSegment; k++) {
Expand All @@ -81,7 +81,7 @@ final void accumulateGradientForIntensiveSampling(double[] gradient, int current
}
}

final void dBCompute(int model, double[] dB) {
void dBCompute(int model, double[] dB) {

for (int k = 0; k < model; ++k) {
for (int p = 0; p < 4; p++) {
Expand All @@ -96,7 +96,7 @@ final void dBCompute(int model, double[] dB) {
dB[model * 4 + 2] = (A - dA[2] * (term1 * lambda + mu + psi)) / (A * A);
}

final void dPCompute(int model, double t, double intervalStart, double eAt, double[] dP, double[] dG2) {
void dPCompute(int model, double t, double intervalStart, double eAt, double[] dP, double[] dG2) {

double G1 = g1(eAt);

Expand All @@ -121,7 +121,7 @@ final void dPCompute(int model, double t, double intervalStart, double eAt, doub
}


final void dQCompute(int model, double t, double[] dQ, double eAt) {
void dQCompute(int model, double t, double[] dQ, double eAt) {

double dwell = t - modelStartTimes[model];
double G1 = g1(eAt);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,25 +27,12 @@

import dr.inference.model.Parameter;

public class TwoParamBirthDeathSerialSamplingModel extends NewBirthDeathSerialSamplingModel {
public class TwoParamBirthDeathSerialSamplingModel extends MasBirthDeathSerialSamplingModel {

public TwoParamBirthDeathSerialSamplingModel(Parameter birthRate, Parameter deathRate, Parameter serialSamplingRate, Parameter treatmentProbability, Parameter samplingProbability, Parameter originTime, boolean condition, int numIntervals, double gridEnd, Type units) {
super(birthRate, deathRate, serialSamplingRate, treatmentProbability, samplingProbability, originTime, condition, numIntervals, gridEnd, units);
}

@Override
public final double processModelSegmentBreakPoint(int model, double intervalStart, double intervalEnd, int nLineages) {
// double lnL = nLineages * (logQ(model, intervalEnd) - logQ(model, intervalStart));
double lnL = nLineages * Math.log(Q(model, intervalEnd) / Q(model, intervalStart));
if ( samplingProbability.getValue(model + 1) > 0.0 && samplingProbability.getValue(model + 1) < 1.0) {
// Add in probability of un-sampled lineages
// We don't need this at t=0 because all lineages in the tree are sampled
// TODO: check if we're right about how many lineages are actually alive at this time. Are we inadvertently over-counting or under-counting due to samples added at this _exact_ time?
lnL += nLineages * Math.log(1.0 - samplingProbability.getValue(model + 1));
}
this.savedLogQ = Double.NaN;
return lnL;
}

final void accumulateGradientForInterval(final double[] gradient, final int currentModelSegment, final int nLineages,
final double[] partialQ_all_old, final double Q_Old,
Expand Down Expand Up @@ -140,54 +127,6 @@ final void dQCompute(int model, double t, double[] dQ, double eAt) {
dQ[model * 4 + 2] = term6 * (dA[2] * term7 - dB[model * 4 + 2] * term3);
}


final double Q(int model, double time) {
double At = A * (time - modelStartTimes[model]);
double eAt = Math.exp(At);
double sqrtDenominator = g1(eAt);
return eAt / (sqrtDenominator * sqrtDenominator);
}

final double logQ(int model, double time) {
double At = A * (time - modelStartTimes[model]);
double eAt = Math.exp(At);
double sqrtDenominator = g1(eAt);
return At - 2 * Math.log(sqrtDenominator); // TODO log4 (additive constant) is not needed since we always see logQ(a) - logQ(b)
}

@Override
public double processInterval(int model, double tYoung, double tOld, int nLineages) {
double logQ_young;
double logQ_old = Q(model, tOld);
if (!Double.isNaN(this.savedLogQ)) {
logQ_young = this.savedLogQ;
} else {
logQ_young = Q(model, tYoung);
}
this.savedLogQ = logQ_old;
return nLineages * Math.log(logQ_old / logQ_young);
}

@Override
public double processSampling(int model, double tOld) {

double logSampProb;

boolean sampleIsAtEventTime = tOld == modelStartTimes[model];
boolean samplesTakenAtEventTime = rho > 0;

if (sampleIsAtEventTime && samplesTakenAtEventTime) {
logSampProb = Math.log(rho);
if (model > 0) {
logSampProb += Math.log(r + ((1.0 - r) * previousP));
}
} else {
double logPsi = Math.log(psi);
logSampProb = logPsi + Math.log(r + (1.0 - r) * p(model,tOld));
}

return logSampProb;
}
}

/*
Expand Down
4 changes: 2 additions & 2 deletions src/dr/evomodel/tree/MixedEffectsRateStatistic.java
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ private void prepareForComputation() {
for (int i = 0; i < offset; i++) {
NodeRef child = tree.getExternalNode(i);
rates[i] = branchRateModel.getBranchRate(tree, child);
locations[i] = branchRateTransform.getLocation(tree, tree.getNode(i));
locations[i] = branchRateTransform.getLocation(tree, child);
}
if (internal) {
final int n = tree.getInternalNodeCount();
Expand All @@ -100,7 +100,7 @@ private void prepareForComputation() {
NodeRef child = tree.getInternalNode(i);
if (!tree.isRoot(child)) {
rates[k] = branchRateModel.getBranchRate(tree, child);
locations[k] = branchRateTransform.getLocation(tree, tree.getNode(i));
locations[k] = branchRateTransform.getLocation(tree, child);
k++;
}
}
Expand Down
21 changes: 20 additions & 1 deletion src/dr/util/TimeProportionToFixedEffectTransform.java
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,29 @@ protected double[] transform(double[] values) {
double rateAncestral = Math.exp(values[1]);
double rateDescendant = Math.exp(values[2]);
double[] transformed = new double[1];
transformed[0] = Math.log((rateDescendant * propTime + rateAncestral * (1.0 - propTime)) / (rateDescendant * propTime));
transformed[0] = Math.log(propTime * rateDescendant / rateAncestral + (1.0 - propTime));
return transformed;
}

private void verifyOutput(double[] values) {
double propTime = values[0];
double rateAncestral = Math.exp(values[1]);
double rateDescendant = Math.exp(values[2]);
double[] transformed = new double[1];
transformed[0] = Math.log(propTime * rateDescendant / rateAncestral + (1.0 - propTime));

double fe = transformed[0];
double lb = rateAncestral;
double ub = rateDescendant;
if (rateAncestral > rateDescendant) {
lb = rateDescendant;
ub = rateAncestral;
}

double newRate = rateAncestral * Math.exp(fe);
assert(newRate >= lb && newRate <= ub);
}

@Override
protected double[] inverse(double[] values) {
throw new RuntimeException("Not yet implemented");
Expand Down

0 comments on commit 77713e9

Please sign in to comment.