Skip to content

Commit

Permalink
gradient wrt drift-parameter in MDS
Browse files Browse the repository at this point in the history
  • Loading branch information
msuchard committed Sep 8, 2023
1 parent ffc8ca7 commit df277b5
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 29 deletions.
63 changes: 39 additions & 24 deletions src/dr/evomodel/antigenic/AntigenicGradientWrtParameter.java
Original file line number Diff line number Diff line change
Expand Up @@ -83,31 +83,46 @@ public void getGradient(double[] gradient, int offset,
}
}

class Drift extends Locations {

// OTHER {
// @Override
// public boolean requiresLocationGradient() { return false; }
//
// @Override
// public boolean requiresObservationGradient() { return true; }
//
// @Override
// public int getSize(int viruses, int sera, int dim) {
// return viruses * sera;
// }
//
// @Override
// void getGradient(double[] gradient, int offset,
// double[] locationGradient,
// double[] observationGradient) {
// System.arraycopy(observationGradient, 0, gradient, offset, observationGradient.length);
// }
//
// @Override
// Parameter getParameter(NewAntigenicLikelihood likelihood) {
// return null;
// }
// };
private final Parameter virusTime;
private final Parameter serumTime;

Drift(int viruses, int sera, int mdsDim,
Parameter locationDrift, Parameter virusTime, Parameter serumTime,
NewAntigenicLikelihood.Layout layout) {
super(viruses, sera, mdsDim, locationDrift, layout);
this.virusTime = virusTime;
this.serumTime = serumTime;
}

@Override
int getLocationOffset() {
throw new RuntimeException("Should not be called");
}

@Override
public int getSize() { return 1; }

@Override
public void getGradient(double[] gradient, int offset,
double[] locationGradient, double[] observationGradient) {

double derivative = 0;

int virusOffset = layout.getVirusLocationOffset();
for (int i = 0; i < viruses; ++i) {
derivative += locationGradient[virusOffset + i * mdsDim] * virusTime.getParameterValue(i);
}

int serumOffset = layout.getSerumLocationOffset();
for (int i = 0; i < sera; ++i) {
derivative += locationGradient[serumOffset + i * mdsDim] * serumTime.getParameterValue(i);
}

gradient[offset] = derivative;
}
}

abstract class Base implements AntigenicGradientWrtParameter {

Expand Down
40 changes: 35 additions & 5 deletions src/dr/evomodel/antigenic/NewAntigenicLikelihood.java
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ public NewAntigenicLikelihood(
this.tipTraitsParameter = tipTraitsParameter;
addVariable(tipTraitsParameter);
this.tipIndices = setupTipIndices(this.tipTraitsParameter, virusNames);
this.virusIndices = setupVirusIndices(tipIndices);

this.mdsDimension = mdsDimension;

Expand Down Expand Up @@ -417,6 +418,12 @@ private void transferLocations() {
for (int j = 0; j < mdsDimension; ++j) {
locations[offset + j] = parameter.getParameterValue(j);
}

if (locationDriftParameter != null) {
locations[offset] += locationDriftParameter.getParameterValue(0) *
virusOffsetsParameter.getParameterValue(i);
}

offset += mdsDimension;
}

Expand All @@ -426,6 +433,11 @@ private void transferLocations() {
for (int j = 0; j < mdsDimension; ++j) {
locations[offset + j] = parameter.getParameterValue(j);
}

if (locationDriftParameter != null) {
locations[offset] += locationDriftParameter.getParameterValue(0) *
serumOffsetsParameter.getParameterValue(i);
}
offset += mdsDimension;
}

Expand Down Expand Up @@ -684,6 +696,14 @@ private int[] setupTipIndices(CompoundParameter tipTraitsParameter,
return tipIndices;
}

private int[] setupVirusIndices(int[] tipIndices) {
int[] virusIndices = new int[tipIndices.length];
for (int i = 0; i < tipIndices.length; ++i) {
virusIndices[tipIndices[i]] = i;
}
return virusIndices;
}

private final int findStrain(String label, List<String> strainNames) {
int index = 0;
for (String strainName : strainNames) {
Expand Down Expand Up @@ -772,24 +792,30 @@ protected void handleVariableChangedEvent(Variable variable, int index, Variable
precisionKnown = false;
} else if (variable == locationDriftParameter) {
setLocationChangedFlags(true);
observationsKnown = false;
locationsKnown = false;
} else if (variable == virusDriftParameter) {
setLocationChangedFlags(true);
observationsKnown = false;
locationsKnown = false;
throw new IllegalArgumentException("Not yet implemented");
} else if (variable == serumDriftParameter) {
setLocationChangedFlags(true);
observationsKnown = false;
locationsKnown = false;
throw new IllegalArgumentException("Not yet implemented");
} else if (variable == serumPotenciesParameter) {
serumEffectChanged[index] = true;
observationsKnown = false;
throw new IllegalArgumentException("Not yet implemented");
} else if (variable == serumBreadthsParameter) {
serumEffectChanged[index] = true;
observationsKnown = false;
throw new IllegalArgumentException("Not yet implemented");
} else if (variable == virusAviditiesParameter) {
virusEffectChanged[index] = true;
observationsKnown = false;
throw new IllegalArgumentException("Not yet implemented");
} else {
throw new IllegalArgumentException("Unknown parameter");


}

likelihoodKnown = false;
Expand Down Expand Up @@ -1060,6 +1086,9 @@ public AntigenicGradientWrtParameter wrtFactory(Parameter parameter) {
} else if (parameter == serumLocationsParameter) {
return new AntigenicGradientWrtParameter.SerumLocations(numViruses, numSera, mdsDimension,
serumLocationsParameter, layout);
} else if (parameter == locationDriftParameter) {
return new AntigenicGradientWrtParameter.Drift(numViruses, numSera, mdsDimension,
locationDriftParameter, virusOffsetsParameter, serumOffsetsParameter, layout);
} else {
throw new IllegalArgumentException("Not yet implemented");
}
Expand Down Expand Up @@ -1108,7 +1137,8 @@ private Measurement(final int virus, final int serum, final double virusDate, fi
private final Parameter serumOffsetsParameter;

private final CompoundParameter tipTraitsParameter;
private int[] tipIndices;
private final int[] tipIndices;
private final int[] virusIndices;

private final Parameter virusAviditiesParameter;
private final Parameter serumPotenciesParameter;
Expand Down

0 comments on commit df277b5

Please sign in to comment.