Skip to content

Commit

Permalink
allow NewAntigenicLikelihood to use FastMatrixParameter
Browse files Browse the repository at this point in the history
  • Loading branch information
msuchard committed Dec 14, 2023
1 parent 837ffff commit 913f15b
Show file tree
Hide file tree
Showing 5 changed files with 49 additions and 13 deletions.
2 changes: 1 addition & 1 deletion src/dr/evolution/coalescent/CoalescentGradient.java
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ public class CoalescentGradient implements GradientWrtParameterProvider, Reporta

public enum Wrt {
NODE_HEIGHTS,
PARAMETER;
PARAMETER
}

public CoalescentGradient(CoalescentLikelihood likelihood,
Expand Down
21 changes: 16 additions & 5 deletions src/dr/evomodel/antigenic/NewAntigenicLikelihood.java
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ public NewAntigenicLikelihood(
Parameter virusDriftParameter,
Parameter serumDriftParameter,
MatrixParameter virusSamplingParameter, // TODO Remove
MatrixParameter serumLocationsParameter,
MatrixParameterInterface serumLocationsParameter,
CompoundParameter tipTraitsParameter,
Parameter virusOffsetsParameter,
Parameter serumOffsetsParameter,
Expand Down Expand Up @@ -651,9 +651,20 @@ private Parameter setupSerumBreadths(Parameter serumBreadthsParameter) {
return serumBreadthsParameter;
}

protected void setupLocationsParameter(MatrixParameter locationsParameter, List<String> strains) {
locationsParameter.setColumnDimension(mdsDimension);
locationsParameter.setRowDimension(strains.size());
protected void setupLocationsParameter(MatrixParameterInterface locationsParameter, List<String> strains) {
if (locationsParameter instanceof MatrixParameter) {
((MatrixParameter) locationsParameter).setColumnDimension(mdsDimension);
((MatrixParameter) locationsParameter).setRowDimension(strains.size());
} else if (locationsParameter instanceof FastMatrixParameter) {
FastMatrixParameter fmp = (FastMatrixParameter) locationsParameter;
if (fmp.getRowDimension() != mdsDimension) {
throw new IllegalArgumentException("Column dim must be " + mdsDimension);
}
if (fmp.getColumnDimension() != strains.size()) {
throw new IllegalArgumentException("Row dim must be " + strains.size());
}
}

for (int i = 0; i < strains.size(); i++) {
locationsParameter.getParameter(i).setId(strains.get(i));
}
Expand Down Expand Up @@ -1116,7 +1127,7 @@ private Measurement(final int virus, final int serum, final double virusDate, fi
private final Parameter serumDriftParameter;

private final MatrixParameter virusSamplingParameter;
private final MatrixParameter serumLocationsParameter;
private final MatrixParameterInterface serumLocationsParameter;

private final Parameter virusOffsetsParameter;
private final Parameter serumOffsetsParameter;
Expand Down
11 changes: 4 additions & 7 deletions src/dr/evomodelxml/antigenic/AntigenicLikelihoodParser.java
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
package dr.evomodelxml.antigenic;

import dr.evomodel.antigenic.NewAntigenicLikelihood;
import dr.inference.model.AbstractTransformedCompoundMatrix;
import dr.inference.model.CompoundParameter;
import dr.inference.model.MatrixParameter;
import dr.inference.model.Parameter;
import dr.inference.model.*;
import dr.util.Citable;
import dr.util.DataTable;
import dr.xml.*;
Expand Down Expand Up @@ -78,9 +75,9 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException {
// TOD Remove
}

MatrixParameter serumLocationsParameter = null;
MatrixParameterInterface serumLocationsParameter = null;
if (xo.hasChildNamed(SERUM_LOCATIONS)) {
serumLocationsParameter = (MatrixParameter) xo.getElementFirstChild(SERUM_LOCATIONS);
serumLocationsParameter = (MatrixParameterInterface) xo.getElementFirstChild(SERUM_LOCATIONS);
}

Parameter mdsPrecision = (Parameter) xo.getElementFirstChild(MDS_PRECISION);
Expand Down Expand Up @@ -173,7 +170,7 @@ public XMLSyntaxRule[] getSyntaxRules() {
AttributeRule.newDoubleRule(DRIFT_INITIAL_LOCATIONS, true, "The degree to drift initial virus and serum locations, defaults to 0.0"),
new ElementRule(TIP_TRAIT, CompoundParameter.class, "Optional parameter of tip locations from the tree", true),
// new ElementRule(VIRUS_LOCATIONS, MatrixParameter.class, "Parameter of locations of all virus"),
new ElementRule(SERUM_LOCATIONS, MatrixParameter.class, "Parameter of locations of all sera"),
new ElementRule(SERUM_LOCATIONS, MatrixParameterInterface.class, "Parameter of locations of all sera"),
new ElementRule(VIRUS_OFFSETS, Parameter.class, "Optional parameter for virus dates to be stored", true),
new ElementRule(SERUM_OFFSETS, Parameter.class, "Optional parameter for serum dates to be stored", true),
new ElementRule(SERUM_POTENCIES, Parameter.class, "Optional parameter for serum potencies", true),
Expand Down
24 changes: 24 additions & 0 deletions src/dr/inference/model/FastMatrixParameter.java
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,14 @@ public FastMatrixParameter(String id, List<Parameter> original, boolean signalCo
fireParameterChangedEvent(-1, ChangeType.ALL_VALUES_CHANGED);
}

@Override
public String getDimensionName(int dim) {
int pNum = dim / getRowDimension();
int index = dim % getRowDimension();
String name = getParameter(pNum).getParameterName() + (index + 1);
return name;
}

private void checkParameterLengths(List<Parameter> parameters) {
final int length = parameters.get(0).getDimension();
for (Parameter p : parameters) {
Expand All @@ -93,6 +101,17 @@ private void setProxyParameterNames(List<Parameter> original) {
}
}

private void setProxyParameterName(String name, int column) {
if (proxyParameterNames == null) {
proxyParameterNames = new ArrayList<>();
for (int i = 0; i < getColumnDimension(); ++i) {
proxyParameterNames.add(null);
}
}

proxyParameterNames.set(column, name);
}

private List<String> proxyParameterNames;

private String getProxyParameterName(int column) {
Expand Down Expand Up @@ -169,6 +188,11 @@ public void setParameterValueNotifyChangedAll(int dim, double value) {
throw new RuntimeException("Do not call");
}

@Override
public void setId(String name) {
matrix.setProxyParameterName(name, column);
}

@Override
public String getParameterName() {
String proxyName = matrix.getProxyParameterName(column);
Expand Down
4 changes: 4 additions & 0 deletions src/dr/inference/model/MatrixParameterInterface.java
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,10 @@ public interface MatrixParameterInterface extends Parameter {

double[] getParameterValues();

default int getParameterCount() {
throw new RuntimeException("Not yet implemented");
}

int getUniqueParameterCount();

Parameter getUniqueParameter(int index);
Expand Down

0 comments on commit 913f15b

Please sign in to comment.