Skip to content

Commit

Permalink
Fix GMRF to work with lambda null
Browse files Browse the repository at this point in the history
  • Loading branch information
PratyusaDatta committed Dec 15, 2023
1 parent 42e6fc3 commit 8b14d64
Show file tree
Hide file tree
Showing 3 changed files with 122 additions and 47 deletions.
55 changes: 52 additions & 3 deletions ci/TestXML/testGaussianMarkovRandomField.xml
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,6 @@
<start>
<parameter id="mean" value="0.0"/>
</start>
<lambda>
<parameter id="lambda" value="0.9"/>
</lambda>
</GaussianMarkovRandomField>
</distribution>
<data>
Expand Down Expand Up @@ -54,4 +51,56 @@
<randomFieldGradient idref="gradientMean"/>
</report>

<randomField id="gmrfProper">
<distribution>
<GaussianMarkovRandomField dim="4" matchPseudoDeterminant="true">
<precision>
<parameter id="precisionProper" value="1.5"/>
</precision>
<start>
<parameter id="meanProper" value="0.0"/>
</start>
<lambda>
<parameter id="lambda" value="0.9"/>
</lambda>
</GaussianMarkovRandomField>
</distribution>
<data>
<parameter idref="data"/>
</data>
</randomField>


<report>
<randomField idref="gmrfProper"/>
</report>

<randomFieldGradient id="gradientProper">
<randomField idref="gmrfProper"/>
<parameter idref="data"/>
</randomFieldGradient>

<report>
<randomFieldGradient idref="gradientProper"/>
</report>

<randomFieldGradient id="gradientPrecisionProper">
<randomField idref="gmrfProper"/>
<parameter idref="precisionProper"/>
</randomFieldGradient>

<report>
<randomFieldGradient idref="gradientPrecisionProper"/>
</report>

<randomFieldGradient id="gradientMeanProper">
<randomField idref="gmrfProper"/>
<parameter idref="meanProper"/>
</randomFieldGradient>

<report>
<randomFieldGradient idref="gradientMeanProper"/>
</report>


</beast>
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,8 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException {
Parameter start = xo.hasChildNamed(START) ?
(Parameter) xo.getElementFirstChild(START) : null;

Parameter lambda = (Parameter) xo.getElementFirstChild(LAMBDA);

if (Math.abs(lambda.getParameterValue(0)) > 1.0) {
throw new XMLParseException("Lambda must be between -1.0 and 1.0");
}
Parameter lambda = xo.hasChildNamed(LAMBDA) ?
(Parameter) xo.getElementFirstChild(LAMBDA) : null;

RandomField.WeightProvider weights = parseWeightProvider(xo, dim);

Expand Down
107 changes: 68 additions & 39 deletions src/dr/math/distributions/GaussianMarkovRandomField.java
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,8 @@ public class GaussianMarkovRandomField extends RandomFieldDistribution {
public GaussianMarkovRandomField(String name,
int dim,
Parameter precision,
Parameter start,
Parameter lambda) {
this(name, dim, precision, start, lambda, null, true);
Parameter start) {
this(name, dim, precision, start, null, null, true);
}

public GaussianMarkovRandomField(String name,
Expand All @@ -85,7 +84,10 @@ public GaussianMarkovRandomField(String name,

addVariable(meanParameter);
addVariable(precisionParameter);
addVariable(lambdaParameter);

if (lambda != null) {
addVariable(lambdaParameter);
}

if (weightProvider != null) {
addModel(weightProvider);
Expand Down Expand Up @@ -124,16 +126,30 @@ public double[] getMean() {
protected SymmetricTriDiagonalMatrix getQ() {
if (!qKnown) {
double precision = precisionParameter.getParameterValue(0);
double lambda = lambdaParameter.getParameterValue(0);
Q.diagonal[0] = precision;
for (int i = 1; i < dim - 1; ++i) {
Q.diagonal[i] = (1 + lambda * lambda) * precision;
if(lambdaParameter == null){
Q.diagonal[0] = precision;
for (int i = 1; i < dim - 1; ++i) {
Q.diagonal[i] = 2 * precision;
}
Q.diagonal[dim - 1] = precision;

for (int i = 0; i < dim - 1; ++i) {
Q.offDiagonal[i] = -precision;
}
}
Q.diagonal[dim - 1] = precision;
else {
double lambda = lambdaParameter.getParameterValue(0);
Q.diagonal[0] = precision;
for (int i = 1; i < dim - 1; ++i) {
Q.diagonal[i] = (1 + lambda * lambda) * precision;
}
Q.diagonal[dim - 1] = precision;

for (int i = 0; i < dim - 1; ++i) {
Q.offDiagonal[i] = -precision * lambda;
for (int i = 0; i < dim - 1; ++i) {
Q.offDiagonal[i] = -precision * lambda;
}
}

// TODO Update for lambda != 1 and for weights

qKnown = true;
Expand All @@ -145,23 +161,39 @@ private double[][] getPrecision() {

if (!precisionKnown) {
final double k = precisionParameter.getParameterValue(0);
final double p = lambdaParameter.getParameterValue(0);

precision[0][0] = k;
precision[0][1] = -1 * k * p;
precision[dim - 1][dim - 1] = k;
precision[dim - 1][dim - 2] = -1 * k * p;
for (int i = 1; i < dim - 1; ++i) {
precision[i][i] = (1 + p * p) * k;
precision[i][i - 1] = -1 * k * p;
precision[i][i + 1] = -1 * k * p;
if(lambdaParameter == null){
precision[0][0] = k;
precision[0][1] = -1 * k;
precision[dim - 1][dim - 1] = k;
precision[dim - 1][dim - 2] = -1 * k;
for (int i = 1; i < dim - 1; ++i) {
precision[i][i] = 2 * k;
precision[i][i - 1] = -1 * k;
precision[i][i + 1] = -1 * k;
}
}
else {
final double p = lambdaParameter.getParameterValue(0);

precision[0][0] = k;
precision[0][1] = -1 * k * p;
precision[dim - 1][dim - 1] = k;
precision[dim - 1][dim - 2] = -1 * k * p;
for (int i = 1; i < dim - 1; ++i) {
precision[i][i] = (1 + p * p) * k;
precision[i][i - 1] = -1 * k * p;
precision[i][i + 1] = -1 * k * p;
}
}

precisionKnown = true;
}
return precision;
}

private boolean checkImpropriety(){
final boolean checkImproper;
checkImproper = lambdaParameter == null;
return checkImproper;
}
@Override
public GradientProvider getGradientWrt(Parameter parameter) {

Expand All @@ -175,7 +207,7 @@ public int getDimension() {
@Override
public double[] getGradientLogDensity(Object x) {
double gradient = gradLogPdfWrtPrecision((double[]) x, getMean(), getQ(),
precisionParameter.getParameterValue(0), lambdaParameter.getParameterValue(0));
precisionParameter.getParameterValue(0), checkImpropriety());
return new double[]{gradient};
}
};
Expand Down Expand Up @@ -218,23 +250,22 @@ public String getType() {

private double matchPseudoDeterminantTerm(int dim) {
double term = 0.0;
double lambda = lambdaParameter.getParameterValue(0);
if(Math.abs(lambda) == 1) {
if(lambdaParameter == null) {
for (int i = 1; i < dim; ++i) {
double x = (2 - 2 * Math.cos(i * Math.PI / dim));
term += Math.log(x);
}
return term;
}
else{
return (dim - 1) * Math.log(1 - lambda * lambda);
double lambda = lambdaParameter.getParameterValue(0);
return (1 - dim) * Math.log(1 - lambda * lambda);
}
}

private double getLogDeterminant() {
double lambda = lambdaParameter.getParameterValue(0);
double logDet;
if(Math.abs(lambda) == 1) {
if(lambdaParameter == null) {
logDet = (dim - 1) * Math.log(precisionParameter.getParameterValue(0)) + logMatchTerm;
}
else{
Expand Down Expand Up @@ -280,8 +311,7 @@ public double logPdf(double[] x) {
//
// System.err.println(x1 + " ?= " + x2);
//
return logPdf(x, getMean(), getQ(), precisionParameter.getParameterValue(0),
1.0, getLogDeterminant());
return logPdf(x, getMean(), getQ(), precisionParameter.getParameterValue(0), checkImpropriety(), getLogDeterminant());
}

public static double[] gradLogPdf(double[] x, double[] mean, double[][] precision) {
Expand All @@ -302,17 +332,16 @@ public static double[] gradLogPdf(double[] x, double[] mean, double[][] precisio
}

public static double gradLogPdfWrtPrecision(double[] x, double[] mean, SymmetricTriDiagonalMatrix Q,
double precision, double lambda) {
double precision, boolean checkImproper) {
final int dim = x.length;

if(Math.abs(lambda) == 1){
if(checkImproper){
return 0.5 * (dim - 1 - getSSE(x, mean, Q)) / precision;
} // TODO Not correct with lambda != 1.0
}

else{
return 0.5 * (dim - getSSE(x, mean, Q)) / precision;
}

}

public static double[] gradLogPdf(double[] x, double[] mean, SymmetricTriDiagonalMatrix Q) {
Expand Down Expand Up @@ -480,8 +509,8 @@ public static double[] diagonalHessianLogPdf(double[] x, SymmetricTriDiagonalMat
// }

private static double logPdf(double[] x, double[] mean, SymmetricTriDiagonalMatrix Q,
double precision, double lambda, double logDeterminant) {
return getLogNormalization(x.length, precision, lambda, logDeterminant) - 0.5 * getSSE(x, mean, Q);
double precision, boolean checkImproper, double logDeterminant) {
return getLogNormalization(x.length, precision, checkImproper, logDeterminant) - 0.5 * getSSE(x, mean, Q);
}

private static double getSSE(double[] x, double[] mean, SymmetricTriDiagonalMatrix Q) {
Expand All @@ -502,11 +531,11 @@ private static double getSSE(double[] x, double[] mean, SymmetricTriDiagonalMatr
return SSE;
}

private static double getLogNormalization(int dim, double precision, double lambda, double logDeterminant) {
private static double getLogNormalization(int dim, double precision, boolean checkImproper, double logDeterminant) {

double logNorm = 0.5 * logDeterminant;

if (Math.abs(lambda) == 1.0) {
if (checkImproper) {
logNorm -= (dim - 1) * HALF_LOG_TWO_PI;
} else {
logNorm -= dim * HALF_LOG_TWO_PI;
Expand Down

0 comments on commit 8b14d64

Please sign in to comment.