Skip to content

Commit

Permalink
add grad wrt mean for GP and clean some warnings
Browse files Browse the repository at this point in the history
  • Loading branch information
msuchard committed Nov 30, 2023
1 parent 3e13a4e commit c66bbea
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ public XMLSyntaxRule[] getSyntaxRules() {
return rules;
}

private XMLSyntaxRule[] rules = new XMLSyntaxRule[]{
private final XMLSyntaxRule[] rules = new XMLSyntaxRule[]{
new XORRule(
new StringAttributeRule(DataType.DATA_TYPE, "The type of sequence data",
DataType.getRegisteredDataTypeNames(), false),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@

import static dr.math.distributions.gp.AdditiveGaussianProcessDistribution.BasisDimension;
import static dr.inferencexml.distribution.RandomFieldParser.WEIGHTS_RULE;
import static dr.inferencexml.distribution.RandomFieldParser.parseWeightProvider;

public class GaussianProcessFieldParser extends AbstractXMLObjectParser {

Expand Down
3 changes: 1 addition & 2 deletions src/dr/math/distributions/GaussianMarkovRandomField.java
Original file line number Diff line number Diff line change
Expand Up @@ -496,8 +496,7 @@ static class SymmetricTriDiagonalMatrix {
double[] offDiagonal;

SymmetricTriDiagonalMatrix(int dim) {
this.diagonal = new double[dim];
this.offDiagonal = new double[dim - 1];
this(new double[dim], new double[dim - 1]);
}

SymmetricTriDiagonalMatrix(double[] diagonal, double[] offDiagonal) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -226,11 +226,6 @@ public double[] getMean() {
return mean;
}

@Override
public GradientProvider getGradientWrt(Parameter parameter) {
throw new RuntimeException("Unknown parameter");
}

@Override
public String getType() {
return TYPE;
Expand Down Expand Up @@ -313,6 +308,41 @@ protected void restoreState() {
@Override
protected void acceptState() { }

@Override
public GradientProvider getGradientWrt(Parameter parameter) {
if (parameter == meanParameter) {
return new GradientProvider() {
@Override
public int getDimension() {
return meanParameter.getDimension();
}

@Override
public double[] getGradientLogDensity(Object x) {

double[] gradient = gradLogPdf((double[]) x, getMean(), getPrecision());

if (meanParameter.getDimension() == dim) {
for (int i = 0; i < dim; ++i) {
gradient[i] *= -1;
}
return gradient;
} else if (meanParameter.getDimension() == 1) {
double sum = 0.0;
for (int i = 0; i < dim; ++i) {
sum += gradient[i];
}
return new double[]{sum}; // TODO should this be -sum?
}

throw new IllegalArgumentException("Unknown mean parameter structure");
}
};
} else {
throw new RuntimeException("Unknown parameter");
}
}

public static class BasisDimension {

private final GaussianProcessKernel kernel;
Expand Down

0 comments on commit c66bbea

Please sign in to comment.