Skip to content

Commit

Permalink
Merge branch 'hmc-clock' of https://github.com/beast-dev/beast-mcmc i…
Browse files Browse the repository at this point in the history
…nto hmc-clock
  • Loading branch information
xji3 committed Dec 18, 2023
2 parents dcce80a + fd2a747 commit 990985d
Show file tree
Hide file tree
Showing 6 changed files with 375 additions and 199 deletions.
59 changes: 54 additions & 5 deletions ci/TestXML/testGaussianMarkovRandomField.xml
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,9 @@
<precision>
<parameter id="precision" value="1.5"/>
</precision>
<start>
<mean>
<parameter id="mean" value="0.0"/>
</start>
<lambda>
<parameter id="lambda" value="0.9"/>
</lambda>
</mean>
</GaussianMarkovRandomField>
</distribution>
<data>
Expand Down Expand Up @@ -54,4 +51,56 @@
<randomFieldGradient idref="gradientMean"/>
</report>

<randomField id="gmrfProper">
<distribution>
<GaussianMarkovRandomField dim="4" matchPseudoDeterminant="true">
<precision>
<parameter id="precisionProper" value="1.5"/>
</precision>
<mean>
<parameter id="meanProper" value="0.0"/>
</mean>
<lambda>
<parameter id="lambda" value="0.9"/>
</lambda>
</GaussianMarkovRandomField>
</distribution>
<data>
<parameter idref="data"/>
</data>
</randomField>


<report>
<randomField idref="gmrfProper"/>
</report>

<randomFieldGradient id="gradientProper">
<randomField idref="gmrfProper"/>
<parameter idref="data"/>
</randomFieldGradient>

<report>
<randomFieldGradient idref="gradientProper"/>
</report>

<randomFieldGradient id="gradientPrecisionProper">
<randomField idref="gmrfProper"/>
<parameter idref="precisionProper"/>
</randomFieldGradient>

<report>
<randomFieldGradient idref="gradientPrecisionProper"/>
</report>

<randomFieldGradient id="gradientMeanProper">
<randomField idref="gmrfProper"/>
<parameter idref="meanProper"/>
</randomFieldGradient>

<report>
<randomFieldGradient idref="gradientMeanProper"/>
</report>


</beast>
101 changes: 101 additions & 0 deletions ci/TestXML/testSkyglide.xml
Original file line number Diff line number Diff line change
Expand Up @@ -3063,7 +3063,33 @@
<report>
<gmrfSkyGridLikelihood idref="skygrid"/>
</report>

<smoothSkygridLikelihood id="smoothSkygrid" old="false">
<populationSizes>

<!-- skygrid.logPopSize is in log units unlike other popSize -->
<parameter idref="skygrid.logPopSize"/>
</populationSizes>
<precisionParameter>
<parameter idref="skygrid.precision"/>
</precisionParameter>
<numGridPoints>
<parameter idref="skygrid.numGridPoints"/>
</numGridPoints>
<cutOff>
<parameter idref="skygrid.cutOff"/>
</cutOff>
<populationTree>
<treeModel idref="treeModel"/>
</populationTree>
<smoothRate>
<parameter id="smooth.Rate" value="100.0" lower="0.0"/>
</smoothRate>
</smoothSkygridLikelihood>

<report>
<smoothSkygridLikelihood idref="smoothSkygrid"/>
</report>

<skyGlideLikelihood id="skyGlideLikelihood">
<populationSizes>
Expand Down Expand Up @@ -3100,6 +3126,81 @@
</report>


<gmrfDistributionLikelihood id="gmrfDistribution">
<data>
<!-- skygrid.logPopSize is in log units unlike other popSize -->
<parameter idref="skygrid.logPopSize"/>
</data>
<precisionParameter>
<parameter idref="skygrid.precision"/>
</precisionParameter>
<!--
<gridPoints>
<parameter idref="gridPoints"/>
</gridPoints>
-->

<numGridPoints>
<parameter idref="skygrid.numGridPoints"/>
</numGridPoints>
<cutOff>
<parameter idref="skygrid.cutOff"/>
</cutOff>

</gmrfDistributionLikelihood>

<report>
<gmrfDistributionLikelihood idref="gmrfDistribution"/>
</report>

<randomField id="gmrf">
<distribution>
<gaussianMarkovRandomField dim="6" matchPseudoDeterminant="false"> <!-- false returns Mandev's calculation -->
<precision>
<parameter idref="skygrid.precision"/>
</precision>
<mean>
<parameter id="mean" value="0.0"/>
</mean>
</gaussianMarkovRandomField>
</distribution>
<data>
<parameter idref="skygrid.logPopSize"/>
</data>
</randomField>


<report>
<randomField idref="gmrf"/>
</report>

<randomFieldGradient id="gradient">
<randomField idref="gmrf"/>
<parameter idref="skygrid.logPopSize"/>
</randomFieldGradient>

<report>
<randomFieldGradient idref="gradient"/>
</report>

<randomFieldGradient id="gradientPrecision">
<randomField idref="gmrf"/>
<parameter idref="skygrid.precision"/>
</randomFieldGradient>

<report>
<randomFieldGradient idref="gradientPrecision"/>
</report>

<randomFieldGradient id="gradientMean">
<randomField idref="gmrf"/>
<parameter idref="mean"/>
</randomFieldGradient>

<report>
<randomFieldGradient idref="gradientMean"/>
</report>

<!--
<smoothSkygridGradient id="smoothSkygridGradient" wrtParameter="nodeHeight" gradientCheckTolerance="0.005">
<smoothSkygridLikelihood idref="smoothSkygrid"/>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
/*
* CtmcFrequencyDistributionGradient.java
*
* Copyright (c) 2002-2024 Alexei Drummond, Andrew Rambaut and Marc Suchard
*
* This file is part of BEAST.
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership and licensing.
*
* BEAST is free software; you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2
* of the License, or (at your option) any later version.
*
* BEAST is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with BEAST; if not, write to the
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
* Boston, MA 02110-1301 USA
*/

package dr.evomodel.treedatalikelihood.discrete;

import dr.evomodel.substmodel.FrequencyModel;
import dr.evomodel.substmodel.GlmSubstitutionModel;
import dr.evomodel.substmodel.SubstitutionModel;
import dr.evomodel.treedatalikelihood.BeagleDataLikelihoodDelegate;
import dr.evomodel.treedatalikelihood.TreeDataLikelihood;
import dr.inference.loggers.LogColumn;
import dr.inference.model.Parameter;
import dr.util.Citation;

import java.util.List;

/**
* @author Marc A. Suchard
*/

public class CtmcFrequencyModelGradient extends AbstractLogAdditiveSubstitutionModelGradient {

private final FrequencyModel frequencyModel;

public CtmcFrequencyModelGradient(String traitName,
TreeDataLikelihood treeDataLikelihood,
BeagleDataLikelihoodDelegate likelihoodDelegate,
GlmSubstitutionModel substitutionModel) {
super(traitName, treeDataLikelihood, likelihoodDelegate, substitutionModel,
ApproximationMode.FIRST_ORDER);

List<SubstitutionModel> substitutionModels = likelihoodDelegate.getBranchModel().getSubstitutionModels();
this.frequencyModel = likelihoodDelegate.getBranchModel().getRootFrequencyModel();

for (SubstitutionModel model : substitutionModels) {
if (frequencyModel != model.getFrequencyModel()) {
throw new RuntimeException("Not yet implemented");
}
}
}

@Override
protected double preProcessNormalization(double[] differentials, double[] generator,
boolean normalize) {
double total = 0.0;
if (normalize) {
for (int i = 0; i < stateCount; ++i) {
for (int j = 0; j < stateCount; ++j) {
final int ij = i * stateCount + j;
total += differentials[ij] * generator[ij];
}
}
}
return total;
}

@Override
double processSingleGradientDimension(int j, double[] differentials, double[] generator, double[] pi,
boolean normalize, double normalizationConstant) {
// derivative wrt pi[j]
double total = 0.0;

for (int i = 0; i < stateCount; ++i) {
final int ii = i * stateCount + i;
final int ij = i * stateCount + j;
total += (differentials[ij] - differentials[ii]) * generator[ij];
}

if (normalize) {
for (int i = 0; i < stateCount; ++i) {
final int ij = i * stateCount + j;
total -= generator[ij] * pi[j] * normalizationConstant;
}
}

return total;
}

@Override
public Parameter getParameter() {
return frequencyModel.getFrequencyParameter();
}

@Override
public LogColumn[] getColumns() {
throw new RuntimeException("Not yet implemented");
}

@Override
public Citation.Category getCategory() {
return Citation.Category.SUBSTITUTION_MODELS;
}

@Override
public String getDescription() {
return null; // TODO
}

@Override
public List<Citation> getCitations() {
// TODO Update
return null;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -58,14 +58,6 @@ public RandomEffectsSubstitutionModelGradient(String traitName,
}
}

// Parameter makeCompoundParameter(GeneralizedLinearModel glm) {
// CompoundParameter parameter = new CompoundParameter("random.effects");
// for (int i = 0; i < glm.getNumberOfRandomEffects(); ++i) {
// parameter.addParameter(glm.getRandomEffect(i));
// }
// return parameter;
// }

ParameterMap makeParameterMap(GeneralizedLinearModel glm) {

return new ParameterMap() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ public class GaussianMarkovRandomFieldParser extends AbstractXMLObjectParser {
private static final String PARSER_NAME = "gaussianMarkovRandomField";
private static final String DIMENSION = "dim";
private static final String PRECISION = "precision";
private static final String START = "start";
private static final String MEAN = "mean";
private static final String LAMBDA = "lambda";
private static final String MATCH_PSEUDO_DETERMINANT = "matchPseudoDeterminant";

Expand All @@ -54,14 +54,11 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException {
throw new XMLParseException("Scale must be > 0.0");
}

Parameter start = xo.hasChildNamed(START) ?
(Parameter) xo.getElementFirstChild(START) : null;
Parameter start = xo.hasChildNamed(MEAN) ?
(Parameter) xo.getElementFirstChild(MEAN) : null;

Parameter lambda = (Parameter) xo.getElementFirstChild(LAMBDA);

if (Math.abs(lambda.getParameterValue(0)) > 1.0) {
throw new XMLParseException("Lambda must be between -1.0 and 1.0");
}
Parameter lambda = xo.hasChildNamed(LAMBDA) ?
(Parameter) xo.getElementFirstChild(LAMBDA) : null;

RandomField.WeightProvider weights = parseWeightProvider(xo, dim);

Expand All @@ -78,7 +75,7 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException {
AttributeRule.newIntegerRule(DIMENSION),
new ElementRule(PRECISION,
new XMLSyntaxRule[]{new ElementRule(Parameter.class)}),
new ElementRule(START,
new ElementRule(MEAN,
new XMLSyntaxRule[]{new ElementRule(Parameter.class)}, true),
new ElementRule(LAMBDA,
new XMLSyntaxRule[]{new ElementRule(Parameter.class)}, true),
Expand Down
Loading

0 comments on commit 990985d

Please sign in to comment.