Skip to content

Commit

Permalink
fixed class label parsing
Browse files Browse the repository at this point in the history
  • Loading branch information
fracpete committed Aug 25, 2024
1 parent ab25951 commit 4487403
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions src/main/java/meka/classifiers/multitarget/NSR.java
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,10 @@ public void buildClassifier(Instances D) throws Exception {
}
}

protected double[] classLabelToDistribution(String label) {
return A.toDoubleArray(MLUtils.toIntArray("[" + label.replace("+", ",") + "]"));
}

@Override
public double[] distributionForInstance(Instance x) throws Exception {

Expand All @@ -114,15 +118,15 @@ public double[] distributionForInstance(Instance x) throws Exception {
//int max_j = (int)m_Classifier.classifyInstance(x_sl); // where comb_i is selected
String y_max = m_InstancesTemplate.classAttribute().value(max_j); // comb_i e.g. "0+3+0+0+1+2+0+0"

double y[] = Arrays.copyOf(A.toDoubleArray(MLUtils.toIntArray(y_max)),L*2); // "0+3+0+0+1+2+0+0" -> [0.0,3.0,0.0,...,0.0]
double y[] = Arrays.copyOf(classLabelToDistribution(y_max),L*2); // "0+3+0+0+1+2+0+0" -> [0.0,3.0,0.0,...,0.0]

HashMap<Double,Double> votes[] = new HashMap[L];
for(int j = 0; j < L; j++) {
votes[j] = new HashMap<Double,Double>();
}

for(int i = 0; i < w.length; i++) {
double y_i[] = A.toDoubleArray(MLUtils.toIntArray(m_InstancesTemplate.classAttribute().value(i)));
double y_i[] = classLabelToDistribution(m_InstancesTemplate.classAttribute().value(i));
for(int j = 0; j < y_i.length; j++) {
votes[j].put(y_i[j] , votes[j].containsKey(y_i[j]) ? votes[j].get(y_i[j]) + w[i] : w[i]);
}
Expand Down

0 comments on commit 4487403

Please sign in to comment.