Skip to content

Commit

Permalink
Working FreeRate prototype missing commits
Browse files Browse the repository at this point in the history
  • Loading branch information
mdhall272 committed Jul 19, 2023
1 parent 7900a87 commit abd52ff
Show file tree
Hide file tree
Showing 6 changed files with 110 additions and 77 deletions.
3 changes: 3 additions & 0 deletions src/dr/app/beast/development_parsers.properties
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,7 @@ dr.inference.regression.SelfControlledCaseSeries
# SITE PATTERNS
dr.evomodelxml.operators.PatternWeightIncrementOperatorParser


# BRANCH SPECIFIC STUFF
dr.evomodel.branchmodel.lineagespecific.CountableRealizationsParameterParser
dr.evomodel.branchmodel.lineagespecific.DirichletProcessPriorParser
Expand Down Expand Up @@ -315,6 +316,8 @@ dr.inferencexml.distribution.shrinkage.JointBayesianBridgeStatisticsParser
dr.inferencexml.hmc.CompoundPriorPreconditionerParser
dr.inferencexml.hmc.NumericalGradientParser

# SIMPLEX TRANSFORM
dr.util.RealDifferencesToSimplexTransform

# SMOOTH SKYGRID
dr.evomodelxml.coalescent.smooth.SmoothSkygridLikelihoodParser
Expand Down
10 changes: 4 additions & 6 deletions src/dr/evomodel/siteratemodel/DiscretizedSiteRateModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,8 @@
*/

package dr.evomodel.siteratemodel;

import dr.inference.model.*;
import dr.evomodel.substmodel.SubstitutionModel;

import java.util.Arrays;
import java.util.Comparator;

Expand Down Expand Up @@ -171,11 +169,11 @@ protected void handleModelChangedEvent(Model model, Object object, int index) {
}

protected final void handleVariableChangedEvent(Variable variable, int index, Parameter.ChangeType type) {
if (variable == nuParameter) {
// if (variable == nuParameter) {
ratesKnown = false; // MAS: I changed this because the rate parameter can affect the categories if the parameter is in siteModel and not clockModel
} else {
throw new RuntimeException("Unknown variable in DiscretizedSiteRateModel.handleVariableChangedEvent");
}
// } else {
// throw new RuntimeException("Unknown variable in DiscretizedSiteRateModel.handleVariableChangedEvent");
// }
listenerHelper.fireModelChanged(this, variable, index);
}

Expand Down
118 changes: 63 additions & 55 deletions src/dr/evomodel/siteratemodel/FreeRateDelegate.java
Original file line number Diff line number Diff line change
Expand Up @@ -49,13 +49,13 @@

public class FreeRateDelegate extends AbstractModel implements SiteRateDelegate, Citable {

public static final Parameterization DEFAULT_PARAMETERIZATION = Parameterization.ABSOLUTE;
/* public static final Parameterization DEFAULT_PARAMETERIZATION = Parameterization.ABSOLUTE;
public enum Parameterization {
ABSOLUTE,
RATIOS,
DIFFERENCES
};
};*/



Expand All @@ -66,31 +66,31 @@ public enum Parameterization {
public FreeRateDelegate(
String name,
int categoryCount,
Parameterization parameterization,
/* Parameterization parameterization,*/
Parameter rateParameter,
Parameter weightParameter) {

super(name);

this.categoryCount = categoryCount;
this.parameterization = parameterization;
// this.parameterization = parameterization;

this.rateParameter = rateParameter;
if (parameterization == Parameterization.ABSOLUTE) {
if (this.rateParameter.getDimension() == 1) {
this.rateParameter.setDimension(categoryCount);
} else if (this.rateParameter.getDimension() != categoryCount) {
throw new IllegalArgumentException("Rate parameter should have have an initial dimension of one or category count");
}
this.rateParameter.addBounds(new Parameter.DefaultBounds(Double.POSITIVE_INFINITY, 0.0, categoryCount));
} else {
if (this.rateParameter.getDimension() == 1) {
this.rateParameter.setDimension(categoryCount - 1);
} else if (this.rateParameter.getDimension() != categoryCount - 1) {
throw new IllegalArgumentException("Rate parameter should have have an initial dimension of one or category count - 1");
}
this.rateParameter.addBounds(new Parameter.DefaultBounds(Double.POSITIVE_INFINITY, 0.0, categoryCount - 1));
}
// if (parameterization == Parameterization.ABSOLUTE) {
// if (this.rateParameter.getDimension() == 1) {
// this.rateParameter.setDimension(categoryCount);
// } else if (this.rateParameter.getDimension() != categoryCount) {
// throw new IllegalArgumentException("Rate parameter should have have an initial dimension of one or category count");
// }
// this.rateParameter.addBounds(new Parameter.DefaultBounds(Double.POSITIVE_INFINITY, 0.0, categoryCount));
// } else {
// if (this.rateParameter.getDimension() == 1) {
// this.rateParameter.setDimension(categoryCount - 1);
// } else if (this.rateParameter.getDimension() != categoryCount - 1) {
// throw new IllegalArgumentException("Rate parameter should have have an initial dimension of one or category count - 1");
// }
// this.rateParameter.addBounds(new Parameter.DefaultBounds(Double.POSITIVE_INFINITY, 0.0, categoryCount - 1));
// }
addVariable(this.rateParameter);

this.weightParameter = weightParameter;
Expand All @@ -101,6 +101,7 @@ public FreeRateDelegate(
}

addVariable(this.weightParameter);

this.weightParameter.addBounds(new Parameter.DefaultBounds(1.0, 0.0, categoryCount));
}

Expand All @@ -116,43 +117,44 @@ public void getCategories(double[] categoryRates, double[] categoryProportions)
assert categoryRates != null && categoryRates.length == categoryCount;
assert categoryProportions != null && categoryProportions.length == categoryCount;

if (parameterization == Parameterization.ABSOLUTE) {
double sumRates = 0.0;
double sumWeights = 0.0;
for (int i = 0; i < categoryCount; i++) {
categoryRates[i] = rateParameter.getParameterValue(i);
sumRates += categoryRates[i];
categoryProportions[i] = weightParameter.getParameterValue(i);
sumWeights += categoryProportions[i];
}
assert Math.abs(sumRates - categoryCount) < 1E-10;
assert Math.abs(sumWeights - 1.0) < 1E-10;
} else {
categoryRates[0] = 1.0;
double sumRates = 0.0;
double sumWeights = 0.0;
for (int i = 0; i < categoryCount; i++) {
if (parameterization == Parameterization.RATIOS) {
if (i > 0) {
categoryRates[i] = categoryRates[i - 1] * rateParameter.getParameterValue(i);
}
} else { // Parameterization.DIFFERENCES
categoryRates[i] = categoryRates[i - 1] + rateParameter.getParameterValue(i);
}
sumRates += categoryRates[i + 1];

categoryProportions[i] = weightParameter.getParameterValue(i);
sumWeights += categoryProportions[i];
}
assert Math.abs(sumWeights - 1.0) < 1E-10;

// scale so their mean is 1
for (int i = 0; i < categoryCount; i++) {
categoryRates[i] = categoryCount * categoryRates[i] / sumRates;
}

// if (parameterization == Parameterization.ABSOLUTE) {
double meanRate = 0.0;
double sumWeights = 0.0;
for (int i = 0; i < categoryCount; i++) {
categoryRates[i] = rateParameter.getParameterValue(i);
categoryProportions[i] = weightParameter.getParameterValue(i);
meanRate += categoryRates[i]*categoryProportions[i];
sumWeights += categoryProportions[i];
}
assert Math.abs(meanRate - 1.0) < 1E-10;
assert Math.abs(sumWeights - 1.0) < 1E-10;
// } else {
// categoryRates[0] = 1.0;
// double sumRates = 0.0;
// double sumWeights = 0.0;
// for (int i = 0; i < categoryCount; i++) {
// if (parameterization == Parameterization.RATIOS) {
// if (i > 0) {
// categoryRates[i] = categoryRates[i - 1] * rateParameter.getParameterValue(i);
// }
// } else { // Parameterization.DIFFERENCES
// categoryRates[i] = categoryRates[i - 1] + rateParameter.getParameterValue(i);
// }
// sumRates += categoryRates[i + 1];
//
// categoryProportions[i] = weightParameter.getParameterValue(i);
// sumWeights += categoryProportions[i];
// }
// assert Math.abs(sumWeights - 1.0) < 1E-10;
//
// // scale so their mean is 1
// for (int i = 0; i < categoryCount; i++) {
// categoryRates[i] = categoryCount * categoryRates[i] / sumRates;
// }
// }
}

// *****************************************************************
// Interface ModelComponent
// *****************************************************************
Expand All @@ -162,6 +164,12 @@ protected void handleModelChangedEvent(Model model, Object object, int index) {
}

protected final void handleVariableChangedEvent(Variable variable, int index, Parameter.ChangeType type) {


if(variable==weightParameter){
rateParameter.fireParameterChangedEvent();
}

listenerHelper.fireModelChanged(this, variable, index);
}

Expand All @@ -188,7 +196,7 @@ protected void acceptState() {

private final int categoryCount;

private final Parameterization parameterization;
// private final Parameterization parameterization;

@Override
public Citation.Category getCategory() {
Expand Down
33 changes: 23 additions & 10 deletions src/dr/evomodelxml/siteratemodel/FreeRateSiteRateModelParser.java
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
import dr.evomodel.siteratemodel.FreeRateDelegate;
import dr.evomodel.substmodel.SubstitutionModel;
import dr.inference.model.Parameter;
import dr.inference.model.Variable;
import dr.inference.model.VariableListener;
import dr.oldevomodel.sitemodel.SiteModel;
import dr.xml.*;

Expand All @@ -43,13 +45,14 @@
public class FreeRateSiteRateModelParser extends AbstractXMLObjectParser {

public static final String FREE_RATE_SITE_RATE_MODEL = "freeRateSiteRateModel";
public static final String SUBSTITUTION_MODEL = "substitutionModel";
public static final String MUTATION_RATE = "mutationRate";
public static final String SUBSTITUTION_RATE = "substitutionRate";
public static final String RELATIVE_RATE = "relativeRate";
public static final String WEIGHT = "weight";
public static final String RATES = "rates";
public static final String CATEGORIES = "categories";
public static final String PARAMETERIZATION = "parameterization";
// public static final String PARAMETERIZATION = "parameterization";
public static final String WEIGHTS = "weights";

public String getParserName() {
Expand All @@ -59,7 +62,7 @@ public String getParserName() {
public Object parseXMLObject(XMLObject xo) throws XMLParseException {

String msg = "";
SubstitutionModel substitutionModel;
SubstitutionModel substitutionModel = null;

double muWeight = 1.0;

Expand All @@ -78,16 +81,19 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException {
}
}

if(xo.hasChildNamed(SUBSTITUTION_MODEL)){
substitutionModel = (SubstitutionModel)xo.getElementFirstChild(SUBSTITUTION_MODEL);
}

int catCount = 4;
catCount = xo.getIntegerAttribute(CATEGORIES);

FreeRateDelegate.Parameterization parameterization = FreeRateDelegate.Parameterization.ABSOLUTE;
if (xo.hasAttribute(PARAMETERIZATION)) {
parameterization = FreeRateDelegate.Parameterization.valueOf(xo.getStringAttribute(PARAMETERIZATION));
}
// FreeRateDelegate.Parameterization parameterization = FreeRateDelegate.Parameterization.ABSOLUTE;
// if (xo.hasAttribute(PARAMETERIZATION)) {
// parameterization = FreeRateDelegate.Parameterization.valueOf(xo.getStringAttribute(PARAMETERIZATION));
// }

Parameter ratesParameter = (Parameter)xo.getElementFirstChild(RATES);

Parameter weightsParameter = (Parameter)xo.getElementFirstChild(WEIGHTS);

msg += "\n " + catCount + " category discrete free rate site rate heterogeneity model)";
Expand All @@ -97,9 +103,16 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException {
Logger.getLogger("dr.evomodel").info("\nCreating free rate site rate model.");
}

FreeRateDelegate delegate = new FreeRateDelegate("FreeRateDelegate", catCount, parameterization, ratesParameter, weightsParameter);
FreeRateDelegate delegate = new FreeRateDelegate("FreeRateDelegate", catCount,
// parameterization,
ratesParameter, weightsParameter);

DiscretizedSiteRateModel siteRateModel = new DiscretizedSiteRateModel(SiteModel.SITE_MODEL, muParam, muWeight, delegate);

siteRateModel.setSubstitutionModel(substitutionModel);
siteRateModel.addModel(substitutionModel);

return new DiscretizedSiteRateModel(SiteModel.SITE_MODEL, muParam, muWeight, delegate);
return siteRateModel;
}

//************************************************************************
Expand All @@ -125,7 +138,7 @@ public XMLSyntaxRule[] getSyntaxRules() {

private final XMLSyntaxRule[] rules = {
AttributeRule.newIntegerRule(CATEGORIES, true),
AttributeRule.newStringRule(PARAMETERIZATION, true),
// AttributeRule.newStringRule(PARAMETERIZATION, true),
new XORRule(
new ElementRule(SUBSTITUTION_RATE, new XMLSyntaxRule[]{
new ElementRule(Parameter.class)
Expand Down
22 changes: 16 additions & 6 deletions src/dr/inference/model/TransformedMultivariateParameter.java
Original file line number Diff line number Diff line change
Expand Up @@ -55,19 +55,24 @@ public double getParameterValue(int dim) {

public void setParameterValue(int dim, double value) {
update();
transformedValues[dim] = value;
unTransformedValues = inverse(transformedValues);
unTransformedValues[dim] = value;

This comment has been minimized.

Copy link
@gabehassler

gabehassler Sep 1, 2023

Contributor

@mdhall272 This seems like a breaking change to me. Shouldn't setParameterValue set the value of the parameter you're calling it on? I feel like this change means that setParameterValue will do something different for TransformedMultivariateParameter objects than it will for other Parameter objects.

This comment has been minimized.

Copy link
@mdhall272

mdhall272 Sep 4, 2023

Author Contributor

To be honest I hadn't realised this had got to the main branch and it should probably go. The code surrounding transformed parameters seems a bit of a mess and needs refactoring - I've found some inconsistencies about whether the "transformed" or "untransformed" parameter is the one whose value you would expect to set, or the one whose value you would expect a model to use. I was going to do that in my freerate branch when I got the time. Anything from before the appearance of that branch should maybe go back to how it was. This version of FreeRate doesn't actually work properly anyway; I was forgetting what had already gone to @rambaut's quadrature branch.

This comment has been minimized.

Copy link
@gabehassler

gabehassler Sep 5, 2023

Contributor

Thanks for all the information. I may go ahead and revert this file if that's ok. I'm unfamiliar with the rest of the free rate and quadrature work, so I don't want to touch anything else in this commit.

/* transformedValues[dim] = value;
unTransformedValues = inverse(transformedValues);*/
// Need to update all values
parameter.setParameterValueNotifyChangedAll(0, unTransformedValues[0]); // Warn everyone is changed
for (int i = 1; i < parameter.getDimension(); i++) {
parameter.setParameterValueQuietly(i, unTransformedValues[i]); // Do the rest quietly
}
transformedValues = transform(unTransformedValues);
}

public void setParameterValueQuietly(int dim, double value) {
update();
transformedValues[dim] = value;
unTransformedValues = inverse(transformedValues);
unTransformedValues[dim] = value;
transformedValues = transform(unTransformedValues);

/* transformedValues[dim] = value;
unTransformedValues = inverse(transformedValues);*/
// Need to update all values
for (int i = 0; i < parameter.getDimension(); i++) {
parameter.setParameterValueQuietly(i, unTransformedValues[i]);
Expand All @@ -91,18 +96,23 @@ public void addBounds(Bounds<Double> bounds) {
// }

private void update() {
if (hasChanged()) {

// if (hasChanged()) {
unTransformedValues = parameter.getParameterValues();
transformedValues = transform(unTransformedValues);
}
// }
}

private boolean hasChanged() {


for (int i = 0; i < unTransformedValues.length; i++) {
if (parameter.getParameterValue(i) != unTransformedValues[i]) {
return true;
}
}


return false;
}
}
1 change: 1 addition & 0 deletions src/dr/inference/model/TransformedParameter.java
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ public TransformedParameter(Parameter parameter, Transform transform) {
}

public TransformedParameter(Parameter parameter, Transform transform, boolean inverse) {

this.parameter = parameter;
this.transform = transform;
this.inverse = inverse;
Expand Down

0 comments on commit abd52ff

Please sign in to comment.