Skip to content

Commit

Permalink
Merge pull request #234 from leondavi/warningFix
Browse files Browse the repository at this point in the history
normalization option fix
  • Loading branch information
leondavi authored Aug 22, 2023
2 parents 83cdb0f + c183c27 commit f66e9c5
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions src_py/apiServer/apiServer.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ def is_port_in_use(port: int) -> bool:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
return s.connect_ex(('localhost', port)) == 0


class ApiServer():
def __init__(self):
self.json_dir_parser = JsonDirParser()
Expand Down Expand Up @@ -66,7 +65,7 @@ def help(self):
==========Experiment info===========
-print_saved_experiments() prints saved experiments and their number for statistics later
-plot_loss(ExpNum) saves and shows the loss over time of the chosen experiment
-accuracy_matrix(ExpNum) shows a graphic for the confusion matrix. Also returns all ConfMat[worker][nueron]
-accuracy_matrix(ExpNum, Normalize) Normalize = True | False. shows a graphic for the confusion matrix. Also returns all ConfMat[worker][nueron]
-communication_stats() prints the communication statistics of the current network. integer => message count, float => avg calc time (ms)
_____GLOBAL VARIABLES / CONSTANTS_____
Expand Down Expand Up @@ -313,7 +312,7 @@ def plot_loss(self, expNum):
plt.savefig(f'/usr/local/lib/nerlnet-lib/NErlNet/Results/{expForStats.name}/Training/{fileName}.png')
print(f'\n{fileName}.png was Saved...')

def accuracy_matrix(self, expNum, normalizeEnabled = 'false'):
def accuracy_matrix(self, expNum, normalizeEnabled = False):
expForStats = self.experiments[expNum-1]

# Choose the matching (to the original labeled CSV) CSV from the prediction results list:
Expand Down Expand Up @@ -412,7 +411,10 @@ def accuracy_matrix(self, expNum, normalizeEnabled = 'false'):
for j in range(labelsLen):
# print(f"worker {worker}, has {len(workerNeuronRes[worker][TRUE_LABLE_IND])} labels, with {len(workerNeuronRes[worker][TRUE_LABLE_IND][j])} samples")
# print(f"confusion {worker}:{j}, has is of {workerNeuronRes[worker][TRUE_LABLE_IND][j]}, {workerNeuronRes[worker][PRED_LABLE_IND][j]}")
confMatList[worker][j] = confusion_matrix(workerNeuronRes[worker][globe.TRUE_LABLE_IND][j], workerNeuronRes[worker][globe.PRED_LABLE_IND][j], normalize=normalizeEnabled)
if normalizeEnabled == True :
confMatList[worker][j] = confusion_matrix(workerNeuronRes[worker][globe.TRUE_LABLE_IND][j], workerNeuronRes[worker][globe.PRED_LABLE_IND][j], normalize='all')
else:
confMatList[worker][j] = confusion_matrix(workerNeuronRes[worker][globe.TRUE_LABLE_IND][j], workerNeuronRes[worker][globe.PRED_LABLE_IND][j])
# print(confMatList[worker][j])
disp = ConfusionMatrixDisplay(confMatList[worker][j], display_labels=["X", labelNames[j]])
disp.plot(ax=axes[i, j], colorbar=False)
Expand Down

0 comments on commit f66e9c5

Please sign in to comment.