Skip to content

Commit

Permalink
pointers on how use symmetric diagonal matrix in GMRF
Browse files Browse the repository at this point in the history
  • Loading branch information
msuchard committed Nov 2, 2023
1 parent 26125ca commit a2c0c9e
Show file tree
Hide file tree
Showing 4 changed files with 188 additions and 53 deletions.
1 change: 1 addition & 0 deletions src/dr/inference/distribution/RandomField.java
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ public class RandomField extends AbstractModelLikelihood {
public interface WeightProvider {
// TODO returns relative lengths (intercoalescent intervals) between field entries

int getDimension();
}

private final Parameter field;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* MultivariateNormalDistributionModelParser.java
* GaussianMarkovRandomFieldParser.java
*
* Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard
* Copyright (c) 2002-2023 Alexei Drummond, Andrew Rambaut and Marc Suchard
*
* This file is part of BEAST.
* See the NOTICE file distributed with this work for additional
Expand Down Expand Up @@ -36,6 +36,7 @@ public class GaussianMarkovRandomFieldParser extends AbstractXMLObjectParser {
private static final String DIMENSION = "dim";
private static final String PRECISION = "precision";
private static final String START = "start";
private static final String WEIGHTS = "weights";

public String getParserName() { return PARSER_NAME; }

Expand All @@ -49,30 +50,37 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException {
throw new XMLParseException("Scale must be > 0.0");
}

Parameter start = (Parameter) xo.getElementFirstChild(START);
Parameter start = xo.hasChildNamed(START) ?
(Parameter) xo.getElementFirstChild(START) : null;

return new GaussianMarkovRandomField(dim, incrementPrecision, start);
}
RandomField.WeightProvider weights = xo.hasChildNamed(WEIGHTS) ?
(RandomField.WeightProvider) xo.getElementFirstChild(WEIGHTS) : null;

if (weights != null && weights.getDimension() != dim - 1) {
throw new XMLParseException("Weights dimension (" + weights.getDimension() +
") != distribution dim (" + dim + ") - 1");
}

public XMLSyntaxRule[] getSyntaxRules() {
return rules;
return new GaussianMarkovRandomField(dim, incrementPrecision, start, weights);
}

public XMLSyntaxRule[] getSyntaxRules() { return rules; }

private final XMLSyntaxRule[] rules = {
AttributeRule.newIntegerRule(DIMENSION),
new ElementRule(PRECISION,
new XMLSyntaxRule[]{new ElementRule(Parameter.class)}),
new ElementRule(START,
new XMLSyntaxRule[]{new ElementRule(Parameter.class)}),
new XMLSyntaxRule[]{new ElementRule(Parameter.class)}, true),
new ElementRule(WEIGHTS,
new XMLSyntaxRule[]{new ElementRule(RandomField.WeightProvider.class)}, true)

};

public String getParserDescription() {
public String getParserDescription() { // TODO update
return "Describes a normal distribution with a given mean and precision " +
"that can be used in a distributionLikelihood element";
}

public Class getReturnType() {
return RandomField.class;
}

public Class getReturnType() { return RandomField.class; }
}
5 changes: 5 additions & 0 deletions src/dr/inferencexml/distribution/RandomFieldParser.java
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,11 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException {
RandomFieldDistribution distribution = (RandomFieldDistribution)
xo.getElementFirstChild(DISTRIBUTION);

if (field.getDimension() != distribution.getDimension()) {
throw new XMLParseException("Field dimension (" + field.getDimension() +
") != distribution dimension (" + distribution.getDimension() + ")");
}

return new RandomField(xo.getId(), field, distribution);
}

Expand Down
201 changes: 161 additions & 40 deletions src/dr/math/distributions/GaussianMarkovRandomField.java
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* GaussianMarkovRandomField.java
*
* Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard
* Copyright (c) 2002-2023 Alexei Drummond, Andrew Rambaut and Marc Suchard
*
* This file is part of BEAST.
* See the NOTICE file distributed with this work for additional
Expand All @@ -28,6 +28,8 @@
import dr.inference.distribution.RandomField;
import dr.inference.model.*;
import dr.inferencexml.distribution.MultivariateNormalDistributionModelParser;
import no.uib.cipr.matrix.DenseVector;
import no.uib.cipr.matrix.SymmTridiagMatrix;

import java.util.Arrays;

Expand Down Expand Up @@ -55,32 +57,29 @@ public class GaussianMarkovRandomField extends RandomFieldDistribution {
private boolean determinantKnown;

public GaussianMarkovRandomField(int dim,
Parameter incrementPrecision,
Parameter precision,
Parameter start) {
this(dim, incrementPrecision, start, null);
this(dim, precision, start, null);
}

public GaussianMarkovRandomField(int dim,
Parameter incrementPrecision,
Parameter precision,
Parameter start,
RandomField.WeightProvider weightProvider) {

super(MultivariateNormalDistributionModelParser.NORMAL_DISTRIBUTION_MODEL);

this.dim = dim;
this.meanParameter = start;
this.precisionParameter = incrementPrecision;
this.precisionParameter = precision;
this.weightProvider = weightProvider;

this.mean = new double[dim];
this.precision = new double[dim][dim];

// populateMean(this.mean);
// populatePrecision(this.precision);

meanKnown = false;
precisionKnown = false;
determinantKnown = false;
determinantKnown = false; // TODO No need to be computed separately
}

// private void check() {
Expand All @@ -96,7 +95,9 @@ public GaussianMarkovRandomField(int dim,

public double[] getMean() {
if (!meanKnown) {
if (meanParameter.getDimension() == 1) {
if (meanParameter == null) {
Arrays.fill(mean, 0.0);
} else if (meanParameter.getDimension() == 1) {
Arrays.fill(mean, meanParameter.getParameterValue(0));
} else {
for (int i = 0; i < mean.length; ++i) {
Expand Down Expand Up @@ -177,44 +178,17 @@ public double getLogDet() {
return logDet;
}

// private boolean isDiagonal(double x[][]) {
// for (int i = 0; i < x.length; ++i) {
// for (int j = i + 1; j < x.length; ++j) {
// if (x[i][j] != 0.0) {
// return false;
// }
// }
// }
// return true;
// }
//
// private double logDetForDiagonal(double x[][]) {
// double logDet = 0;
// for (int i = 0; i < x.length; ++i) {
// logDet += Math.log(x[i][i]);
// }
// return logDet;
// }


@Override
public double[][] getScaleMatrix() {
return getPrecision();
}

// public static double calculatePrecisionMatrixDeterminate(double[][] precision) {
// try {
// return new Matrix(precision).determinant();
// } catch (IllegalDimension e) {
// throw new RuntimeException(e.getMessage());
// }
// }


@Override
public Variable<Double> getLocationVariable() {
return null;
}

@Override
public double logPdf(double[] x) {
return logPdf(x, getMean(), getPrecision(), getLogDet());
}
Expand Down Expand Up @@ -270,6 +244,153 @@ public static double[] diagonalHessianLogPdf(double[] x, double[][] precision) {
return hessian;
}

// TODO Below is the relevant code from GMRFMultilocusSkyrideLikelihood for building a `SymmTridiagMatrix`
// TODO `getFieldScalar` rescaling should be handled by `WeightsProvider`

// protected double getFieldScalar() {
// final double rootHeight;
// if (rescaleByRootHeight) {
// rootHeight = tree.getNodeHeight(tree.getRoot());
// } else {
// rootHeight = 1.0;
// }
// return rootHeight;
// }
//
// protected void setupGMRFWeights() {
//
// setupSufficientStatistics();
//
// //Set up the weight Matrix
// double[] offdiag = new double[fieldLength - 1];
// double[] diag = new double[fieldLength];
//
// //First set up the offdiagonal entries;
//
// if (!timeAwareSmoothing) {
// for (int i = 0; i < fieldLength - 1; i++) {
// offdiag[i] = -1.0;
// }
// } else {
// for (int i = 0; i < fieldLength - 1; i++) {
// offdiag[i] = -2.0 / (coalescentIntervals[i] + coalescentIntervals[i + 1]) * getFieldScalar();
// }
// }
//
// //Then set up the diagonal entries;
// for (int i = 1; i < fieldLength - 1; i++)
// diag[i] = -(offdiag[i] + offdiag[i - 1]);
//
// //Take care of the endpoints
// diag[0] = -offdiag[0];
// diag[fieldLength - 1] = -offdiag[fieldLength - 2];
//
// weightMatrix = new SymmTridiagMatrix(diag, offdiag);
// }
//
// public SymmTridiagMatrix getScaledWeightMatrix(double precision) {
// SymmTridiagMatrix a = weightMatrix.copy();
// for (int i = 0; i < a.numRows() - 1; i++) {
// a.set(i, i, a.get(i, i) * precision);
// a.set(i + 1, i, a.get(i + 1, i) * precision);
// }
// a.set(fieldLength - 1, fieldLength - 1, a.get(fieldLength - 1, fieldLength - 1) * precision);
// return a;
// }
//
// public SymmTridiagMatrix getScaledWeightMatrix(double precision, double lambda) {
// if (lambda == 1)
// return getScaledWeightMatrix(precision);
//
// SymmTridiagMatrix a = weightMatrix.copy();
// for (int i = 0; i < a.numRows() - 1; i++) {
// a.set(i, i, precision * (1 - lambda + lambda * a.get(i, i)));
// a.set(i + 1, i, a.get(i + 1, i) * precision * lambda);
// }
//
// a.set(fieldLength - 1, fieldLength - 1, precision * (1 - lambda + lambda * a.get(fieldLength - 1, fieldLength - 1)));
// return a;
// }
//
// private DenseVector getMeanAdjustedGamma() {
// DenseVector currentGamma = new DenseVector(popSizeParameter.getParameterValues());
// updateGammaWithCovariates(currentGamma);
// return currentGamma;
// }
//
// double getLogFieldLikelihood() {
//
// DenseVector diagonal1 = new DenseVector(fieldLength);
// DenseVector currentGamma = getMeanAdjustedGamma();
//
// double currentLike = handleMissingValues();
//
// SymmTridiagMatrix currentQ = getScaledWeightMatrix(precisionParameter.getParameterValue(0), lambdaParameter.getParameterValue(0));
// currentQ.mult(currentGamma, diagonal1);
//
// currentLike += 0.5 * (fieldLength - 1) * Math.log(precisionParameter.getParameterValue(0)) - 0.5 * currentGamma.dot(diagonal1);
// if (lambdaParameter.getParameterValue(0) == 1) {
// currentLike -= (fieldLength - 1) / 2.0 * LOG_TWO_TIMES_PI;
// } else {
// currentLike -= fieldLength / 2.0 * LOG_TWO_TIMES_PI;
// }
//
// return currentLike;
// }

private static double logPdf(double[] x, double[] mean, SymmetricTriDiagonalMatrix Q,
double precision, double lambda) {

final int dim = x.length;
final double[] delta = new double[dim];

for (int i = 0; i < dim; ++i) {
delta[i] = x[i] - mean[i];
}

double SSE = 0.0;
for (int i = 0; i < dim - 1; i++) {
SSE += Q.diagonal[i] * delta[i] * delta[i] + 2 * Q.offDiagonal[i] * delta[i] * delta[i + 1];
}
SSE += Q.diagonal[dim - 1] * delta[dim - 1] * delta[dim - 1];

double logLikelihood = 0.5 * (dim - 1) * Math.log(precision) - 0.5 * SSE;
if (lambda == 1.0) {
logLikelihood -= (dim - 1) * logNormalize;
} else {
logLikelihood -= dim * logNormalize;
}

return logLikelihood;
}


class SymmetricTriDiagonalMatrix {

double[] diagonal;
double[] offDiagonal;

SymmetricTriDiagonalMatrix(double[] diagonal, double[] offDiagonal) {
this.diagonal = diagonal;
this.offDiagonal = offDiagonal;
}

void copy(SymmetricTriDiagonalMatrix copy) {
System.arraycopy(diagonal, 0, copy.diagonal, 0, diagonal.length);
System.arraycopy(offDiagonal, 0, copy.offDiagonal, 0, offDiagonal.length);
}

void swap(SymmetricTriDiagonalMatrix swap) {
double[] tmp1 = diagonal;
diagonal = swap.diagonal;
swap.diagonal = tmp1;

double[] tmp2 = offDiagonal;
offDiagonal = swap.offDiagonal;
swap.offDiagonal = tmp2;
}
}

// scale only modifies precision
// in one dimension, this is equivalent to:
// PDF[NormalDistribution[mean, Sqrt[scale]*Sqrt[1/precison]], x]
Expand Down Expand Up @@ -318,7 +439,7 @@ public static double logPdf(double[] x, double[] mean, double[][] precision,


@Override
public int getDimension() { return mean.length; }
public int getDimension() { return dim; }

// public Parameter getincrementPrecision() { return precisionParameter; }

Expand Down

0 comments on commit a2c0c9e

Please sign in to comment.