Merge pull request #73 from ayrna/development
Changes in dependencies, module renaming, and file removals
franberchez authored Jul 8, 2024
2 parents 406f9c1 + f75968f commit 5305630
Showing 20 changed files with 228 additions and 110 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -16,7 +16,7 @@

## ⚙️ Installation

-`dlordinal v2.0.0` is the last version supported by Python 3.8, Python 3.9 and Python 3.10.
+`dlordinal v2.1.0` is the last version supported by Python 3.8, Python 3.9 and Python 3.10.

The easiest way to install `dlordinal` is via `pip`:

8 changes: 0 additions & 8 deletions dlordinal/layers/activation_function.py

This file was deleted.

71 changes: 0 additions & 71 deletions dlordinal/layers/tests/test_ordinal_fully_connected.py

This file was deleted.

6 changes: 5 additions & 1 deletion dlordinal/metrics/__init__.py
@@ -1,7 +1,9 @@
from .metrics import (
+    accuracy_off1,
+    amae,
    gmsec,
    minimum_sensitivity,
-    accuracy_off1,
+    mmae,
    write_array_to_file,
    write_metrics_dict_to_file,
)
@@ -10,6 +12,8 @@
"gmsec",
"minimum_sensitivity",
"accuracy_off1",
"amae",
"mmae",
"write_array_to_file",
"write_metrics_dict_to_file",
]
81 changes: 71 additions & 10 deletions dlordinal/metrics/metrics.py
@@ -1,10 +1,10 @@
-import numpy as np
-from sklearn.metrics import confusion_matrix
-from pathlib import Path
-from typing import Callable, Dict, Optional
import json
import os
-from sklearn.metrics import recall_score
+from pathlib import Path
+from typing import Callable, Dict, Optional
+
+import numpy as np
+from sklearn.metrics import confusion_matrix, recall_score


def minimum_sensitivity(y_true: np.ndarray, y_pred: np.ndarray) -> float:
@@ -103,10 +103,74 @@ def gmsec(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    return np.sqrt(sensitivities[0] * sensitivities[-1])


def amae(y_true: np.ndarray, y_pred: np.ndarray):
    """Computes the average mean absolute error computed independently for each class.

    Parameters
    ----------
    y_true : array-like
        Target labels with one-hot or integer encoding.
    y_pred : array-like
        Predicted probabilities or labels.

    Returns
    -------
    amae : float
        Average mean absolute error.
    """

    if len(y_true.shape) > 1:
        y_true = np.argmax(y_true, axis=1)
    if len(y_pred.shape) > 1:
        y_pred = np.argmax(y_pred, axis=1)

    cm = confusion_matrix(y_true, y_pred)
    n_class = cm.shape[0]
    costs = np.reshape(np.tile(range(n_class), n_class), (n_class, n_class))
    costs = np.abs(costs - np.transpose(costs))
    non_zero_cm_rows = ~np.all(cm == 0, axis=1)
    cm_ = cm[non_zero_cm_rows]
    errors = costs * cm_
    per_class_maes = np.sum(errors, axis=1) / np.sum(cm_, axis=1).astype("double")
    return np.mean(per_class_maes)
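
# Usage sketch from the editor (not part of the commit), mirroring the new
# tests below; assumes dlordinal is installed:
#
#     import numpy as np
#     from dlordinal.metrics import amae
#
#     y_true = np.array([0, 0, 2, 1])
#     y_pred = np.array([0, 2, 0, 1])
#     # per-class MAEs: class 0 -> (|0-0| + |0-2|)/2 = 1.0,
#     #                 class 1 -> |1-1| = 0.0, class 2 -> |2-0| = 2.0
#     amae(y_true, y_pred)  # (1.0 + 0.0 + 2.0) / 3 = 1.0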


def mmae(y_true: np.ndarray, y_pred: np.ndarray):
    """Computes the maximum mean absolute error computed independently for each class.

    Parameters
    ----------
    y_true : array-like
        Target labels with one-hot or integer encoding.
    y_pred : array-like
        Predicted probabilities or labels.

    Returns
    -------
    mmae : float
        Maximum mean absolute error.
    """

    if len(y_true.shape) > 1:
        y_true = np.argmax(y_true, axis=1)
    if len(y_pred.shape) > 1:
        y_pred = np.argmax(y_pred, axis=1)

    cm = confusion_matrix(y_true, y_pred)
    n_class = cm.shape[0]
    costs = np.reshape(np.tile(range(n_class), n_class), (n_class, n_class))
    costs = np.abs(costs - np.transpose(costs))
    non_zero_cm_rows = ~np.all(cm == 0, axis=1)
    cm_ = cm[non_zero_cm_rows]
    errors = costs * cm_
    per_class_maes = np.sum(errors, axis=1) / np.sum(cm_, axis=1).astype("double")
    return per_class_maes.max()
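
# Same sketch for mmae (editor's illustration, not part of the commit): with
# the arrays above, mmae reports the worst per-class error rather than the mean:
#
#     mmae(y_true, y_pred)  # max(1.0, 0.0, 2.0) = 2.0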


def write_metrics_dict_to_file(
    metrics: Dict[str, float],
    path_str: str,
-    filter_fn: Optional[Callable[[str, float], bool]] = None,
+    filter_fn: Optional[Callable[[str, float], bool]] = lambda n, v: True,
) -> None:
"""Writes a dictionary of metrics to a tabular file.
The dictionary is filtered by the filter function.
@@ -121,7 +185,7 @@ def write_metrics_dict_to_file(
        Path to the file that will be saved.
        The directory of the file will be created if it does not exist.
        If the file exists, the metrics will be appended to the file in a new row.
-    filter_fn : Optional[Callable[[str, bool], bool]]
+    filter_fn : Optional[Callable[[str, bool], bool]], default=lambda n, v: True
        Function that filters the metrics.
        The function takes the name and the value of the metric and returns ``True`` if the metric should be saved.
@@ -144,9 +208,6 @@ def write_metrics_dict_to_file(
    0.5
    """

-    filter_fn: Callable[[str, bool], bool] = (
-        filter_fn if filter_fn is not None else lambda n, v: True
-    )
    path = Path(path_str)
    directory = path.parents[0]
    os.makedirs(directory, exist_ok=True)
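A usage sketch from the editor (not part of the commit), showing the new default in action and how a predicate still filters rows; the output path is hypothetical:

    write_metrics_dict_to_file(
        {"acc": 0.5, "gmsec": 0.25},
        "results/metrics.txt",  # hypothetical path; parent directory is created
        filter_fn=lambda name, value: name != "gmsec",  # keep only "acc"
    )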
68 changes: 68 additions & 0 deletions dlordinal/metrics/tests/test_metrics.py
@@ -7,8 +7,10 @@

from dlordinal.metrics import (
    accuracy_off1,
+    amae,
    gmsec,
    minimum_sensitivity,
+    mmae,
    write_array_to_file,
    write_metrics_dict_to_file,
)
@@ -50,6 +52,72 @@ def test_gmsec():
    assert result == pytest.approx(expected_result, rel=1e-6)


def test_amae():
    y_true = np.array([0, 0, 1, 1])
    y_pred = np.array([0, 1, 0, 1])
    result = amae(y_true, y_pred)
    expected_result = 0.5
    assert result == pytest.approx(expected_result, rel=1e-6)

    y_true = np.array([0, 0, 1, 1, 2, 2])
    y_pred = np.array([0, 0, 1, 1, 2, 2])
    result = amae(y_true, y_pred)
    expected_result = 0.0
    assert result == pytest.approx(expected_result, rel=1e-6)

    y_true = np.array([0, 0, 2, 1])
    y_pred = np.array([0, 2, 0, 1])
    result = amae(y_true, y_pred)
    expected_result = 1.0
    assert result == pytest.approx(expected_result, rel=1e-6)

    y_true = np.array([0, 0, 2, 1, 3])
    y_pred = np.array([2, 2, 0, 3, 1])
    result = amae(y_true, y_pred)
    expected_result = 2.0
    assert result == pytest.approx(expected_result, rel=1e-6)

    # Test using one-hot and probabilities
    y_true = np.array([[1, 0], [1, 0], [0, 1], [0, 1]])
    y_pred = np.array([[1, 0], [0, 1], [1, 0], [0, 1]])
    result = amae(y_true, y_pred)
    expected_result = 0.5
    assert result == pytest.approx(expected_result, rel=1e-6)


def test_mmae():
    y_true = np.array([0, 0, 1, 1])
    y_pred = np.array([0, 1, 0, 1])
    result = mmae(y_true, y_pred)
    expected_result = 0.5
    assert result == pytest.approx(expected_result, rel=1e-6)

    y_true = np.array([0, 0, 1, 1, 2, 2])
    y_pred = np.array([0, 0, 1, 1, 2, 2])
    result = mmae(y_true, y_pred)
    expected_result = 0.0
    assert result == pytest.approx(expected_result, rel=1e-6)

    y_true = np.array([0, 0, 2, 1])
    y_pred = np.array([0, 2, 0, 1])
    result = mmae(y_true, y_pred)
    expected_result = 2.0
    assert result == pytest.approx(expected_result, rel=1e-6)

    y_true = np.array([0, 0, 2, 1, 3])
    y_pred = np.array([2, 2, 0, 3, 1])
    result = mmae(y_true, y_pred)
    expected_result = 2.0
    assert result == pytest.approx(expected_result, rel=1e-6)

    # Test using one-hot and probabilities
    y_true = np.array([[1, 0], [1, 0], [0, 1], [0, 1]])
    y_pred = np.array([[1, 0], [0, 1], [1, 0], [0, 1]])
    result = mmae(y_true, y_pred)
    expected_result = 0.5
    assert result == pytest.approx(expected_result, rel=1e-6)


def test_write_metrics_dict_to_file():
    metrics = {"acc": 0.5, "gmsec": 0.25}
    path_str = "test_results.txt"
@@ -3,13 +3,11 @@
    ResNetOrdinalFullyConnected,
    VGGOrdinalFullyConnected,
)
-from .activation_function import activation_function_by_name
from .stick_breaking_layer import StickBreakingLayer

__all__ = [
    "CLM",
    "ResNetOrdinalFullyConnected",
    "VGGOrdinalFullyConnected",
-    "activation_function_by_name",
    "StickBreakingLayer",
]
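
With the `layers` package renamed to `output_layers`, downstream imports need the corresponding update, as the revised CLM test below illustrates. A minimal migration sketch (names taken from this diff):

    # before this commit
    from dlordinal.layers import CLM
    # after this commit
    from dlordinal.output_layers import CLM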
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
@@ -2,7 +2,8 @@

import pytest
import torch
-from dlordinal.layers import CLM
+
+from dlordinal.output_layers import CLM


def test_clm_creation():
67 changes: 67 additions & 0 deletions dlordinal/output_layers/tests/test_ordinal_fully_connected.py
@@ -0,0 +1,67 @@
import torch
from torch import nn

from dlordinal.output_layers import (
    ResNetOrdinalFullyConnected,
    VGGOrdinalFullyConnected,
)


def test_ordinal_resnet_fc_creation():
    input_size = 10
    num_classes = 5
    resnet_fc = ResNetOrdinalFullyConnected(input_size, num_classes)

    assert isinstance(resnet_fc, ResNetOrdinalFullyConnected)


def test_ordinal_resnet_fc_output():
    input_size = 10
    num_classes = 5

    resnet_fc = ResNetOrdinalFullyConnected(input_size, num_classes)

    input_data = torch.randn(16, input_size)

    output = resnet_fc(input_data)

    # Check that the output has the correct size
    expected_output_size = (16, num_classes - 1)
    assert output.size() == expected_output_size

    # Check that all values are in the range [0, 1] after applying sigmoid
    assert (output >= 0).all()
    assert (output <= 1).all()


def test_initialisation_VGG():
    input_size = 512
    num_classes = 5
    activation_function = nn.ReLU

    model = VGGOrdinalFullyConnected(input_size, num_classes, activation_function)

    assert len(model.classifiers) == num_classes - 1
    for classifier in model.classifiers:
        assert isinstance(classifier, nn.Sequential)
        layers = list(classifier)
        assert isinstance(layers[0], nn.Linear)
        assert isinstance(layers[1], activation_function)
        assert isinstance(layers[2], nn.Dropout)
        assert isinstance(layers[3], nn.Linear)
        assert isinstance(layers[4], activation_function)
        assert isinstance(layers[5], nn.Dropout)
        assert isinstance(layers[6], nn.Linear)


def test_forward_VGG():
    input_size = 512
    num_classes = 5
    activation_function = nn.ReLU

    model = VGGOrdinalFullyConnected(input_size, num_classes, activation_function)
    x = torch.randn(10, input_size)  # Batch size of 10
    output = model(x)

    assert output.shape == (10, num_classes - 1)
    assert torch.all(output >= 0) and torch.all(output <= 1)
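
The num_classes - 1 sigmoid outputs asserted above match the usual ordinal binary decomposition, in which output k can be read as P(y > k). A decoding sketch under that assumption (the interpretation is the editor's, not stated in this diff):

    # count the cumulative thresholds passed to recover an ordinal label
    probas = model(x)                   # shape (10, num_classes - 1), values in [0, 1]
    labels = (probas > 0.5).sum(dim=1)  # shape (10,), values in {0, ..., num_classes - 1}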