Skip to content

Commit

Permalink
more GP kernels
Browse files Browse the repository at this point in the history
  • Loading branch information
msuchard committed Nov 25, 2023
1 parent 34c0da7 commit d504bc4
Showing 1 changed file with 82 additions and 7 deletions.
89 changes: 82 additions & 7 deletions src/dr/math/distributions/gp/GaussianProcessKernel.java
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@

import java.util.List;

import static dr.math.ModifiedBesselFirstKind.bessi;
import static dr.math.functionEval.GammaFunction.factorial;
import static dr.math.functionEval.GammaFunction.gamma;

/**
* @author Marc A. Suchard
* @author Filippo Monti
Expand Down Expand Up @@ -62,18 +66,73 @@ public double getCorrelation(double x, double y) {
}

public double getCorrelation(double[] x, double[] y) {
final int dim = x.length;
return functionalForm(normSquared(x, y));
}

double normSquared = 0.0;
for (int i = 0; i < dim; ++i) {
double diff = x[i] - y[i];
normSquared += diff * diff;
private static final String TYPE = "RadialBasisFunction";
}

class OrnsteinUhlenbeck extends L1Base {

public OrnsteinUhlenbeck(String name, List<Parameter> parameters) { super(name, parameters); }

double functionalForm(double norm) {
double length = parameters.get(0).getParameterValue(0);
return Math.exp(-norm / length);
}

private static final String TYPE = "OrnsteinUhlenbeck";
}

class Matern extends L1Base {

private final Parameter orderParameter;
private final int order;
private final double normalization;
private final double scale;

public Matern(String name, List<Parameter> parameters) {
super(name, parameters);
this.orderParameter = parameters.get(1);

this.order = (int) orderParameter.getParameterValue(0);
this.normalization = Math.pow(2, 1 - order) / factorial(order - 1);
this.scale = Math.sqrt(2 * order);
}

double functionalForm(double norm) {
double length = parameters.get(0).getParameterValue(0);
double argument = scale * norm / length;

return normalization * Math.pow(argument, order) * bessi(argument, order);
}

@Override
protected void handleVariableChangedEvent(Variable variable, int index, Parameter.ChangeType type) {
if (variable == orderParameter) {
throw new RuntimeException("Not yet implemented");
}
super.handleVariableChangedEvent(variable, index, type);
}

private static final String TYPE = "Matern";
}

abstract class L1Base extends Base {

public L1Base(String name, List<Parameter> parameters) { super(name, parameters); }

return functionalForm(normSquared);
abstract double functionalForm(double norm);

public double getCorrelation(double x, double y) {
double norm = Math.abs(x - y);
return functionalForm(norm);
}

private static final String TYPE = "RadialBasisFunction";
public double getCorrelation(double[] x, double[] y) {
double norm = Math.sqrt(normSquared(x, y));
return functionalForm(norm);
}
}

static GaussianProcessKernel factory(String type, String name, List<Parameter> parameters)
Expand All @@ -82,6 +141,10 @@ static GaussianProcessKernel factory(String type, String name, List<Parameter> p
return new DotProduct(name, parameters);
} else if (type.equalsIgnoreCase(RadialBasisFunction.TYPE)) {
return new RadialBasisFunction(name, parameters);
} else if (type.equalsIgnoreCase(Matern.TYPE)) {
return new Matern(name, parameters);
} else if (type.equalsIgnoreCase(OrnsteinUhlenbeck.TYPE)) {
return new OrnsteinUhlenbeck(name, parameters);
} else {
throw new IllegalArgumentException("Unknown kernel type");
}
Expand All @@ -102,6 +165,18 @@ public Base(String name,
}
}

double normSquared(double[] x, double[] y) {
final int dim = x.length;

double normSquared = 0.0;
for (int i = 0; i < dim; ++i) {
double diff = x[i] - y[i];
normSquared += diff * diff;
}

return normSquared;
}

@Override
public double getScale() {
return 1.0;
Expand Down

0 comments on commit d504bc4

Please sign in to comment.