Skip to content

Commit

Permalink
Merge pull request #12 from sicara/fix-validation
Browse files Browse the repository at this point in the history
Fix call to validation during training (0.2.0 => 0.2.1)
  • Loading branch information
ebennequin authored Jun 22, 2021
2 parents 8387628 + 966bf34 commit afb3155
Show file tree
Hide file tree
Showing 7 changed files with 81 additions and 44 deletions.
2 changes: 1 addition & 1 deletion .bumpversion.cfg
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[bumpversion]
current_version = 0.2.0
current_version = 0.2.1
commit = True
tag = False

Expand Down
1 change: 1 addition & 0 deletions dev_requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ matplotlib>=3.3.4
pandas>=1.1.0
pylint>=2.7.0
pytest>=6.2.2
pytest-mock>=3.6.1
torch>=1.7.1
torchvision>=0.8.2
tqdm>=4.56.0
2 changes: 1 addition & 1 deletion easyfsl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@
for few-shot learning experiences.
"""

__version__ = "0.2.0"
__version__ = "0.2.1"
2 changes: 1 addition & 1 deletion easyfsl/methods/abstract_meta_learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ def fit(

# Validation
if val_loader:
if episode_index + 1 % validation_frequency == 0:
if (episode_index + 1) % validation_frequency == 0:
self.validate(val_loader)

def validate(self, val_loader: DataLoader) -> float:
Expand Down
13 changes: 7 additions & 6 deletions easyfsl/tests/data_tools/easy_set_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,9 +92,10 @@ class TestEasySetListDataInstances:
)
],
)
def test_list_data_instances_returns_expected_values(class_roots, images, labels):
with patch("pathlib.Path.glob") as mock_glob:
mock_glob.return_value = [Path("a.png"), Path("b.png")]
with patch("pathlib.Path.is_file") as mock_is_file:
mock_is_file.return_value = True
assert (images, labels) == EasySet.list_data_instances(class_roots)
def test_list_data_instances_returns_expected_values(
class_roots, images, labels, mocker
):
mocker.patch("pathlib.Path.glob", return_value=[Path("a.png"), Path("b.png")])
mocker.patch("pathlib.Path.is_file", return_value=True)

assert (images, labels) == EasySet.list_data_instances(class_roots)
103 changes: 69 additions & 34 deletions easyfsl/tests/methods/abstract_meta_learner_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from unittest.mock import patch

import pytest
import torch
from torchvision.models import resnet18
Expand Down Expand Up @@ -30,21 +28,24 @@ def test_evaluate_on_one_task_gives_correct_output(
query_labels,
expected_correct,
expected_total,
mocker,
):
with patch("torch.Tensor.cuda", new=torch.Tensor.cpu):
with patch("easyfsl.methods.AbstractMetaLearner.forward") as mock_forward:
with patch("easyfsl.methods.AbstractMetaLearner.process_support_set"):
mock_forward.return_value = torch.tensor(5 * [[0.25, 0.75]]).cuda()
model = AbstractMetaLearner(resnet18())
assert (
model.evaluate_on_one_task(
support_images,
support_labels,
query_images,
query_labels,
)
== (expected_correct, expected_total)
)
mocker.patch("torch.Tensor.cuda", new=torch.Tensor.cpu)
mocker.patch(
"easyfsl.methods.AbstractMetaLearner.forward",
return_value=torch.tensor(5 * [[0.25, 0.75]]).cuda(),
)
mocker.patch("easyfsl.methods.AbstractMetaLearner.process_support_set")
model = AbstractMetaLearner(resnet18())
assert (
model.evaluate_on_one_task(
support_images,
support_labels,
query_images,
query_labels,
)
== (expected_correct, expected_total)
)


# pylint: enable=not-callable
Expand All @@ -66,20 +67,20 @@ def test_process_support_set_raises_error_when_not_implemented():

class TestAMLValidate:
@staticmethod
def test_validate_returns_accuracy():
with patch("easyfsl.methods.AbstractMetaLearner.evaluate") as mock_evaluate:
mock_evaluate.return_value = 0.0
meta_learner = AbstractMetaLearner(resnet18())
assert meta_learner.validate(None) == 0.0
def test_validate_returns_accuracy(mocker):
mocker.patch("easyfsl.methods.AbstractMetaLearner.evaluate", return_value=0.0)
meta_learner = AbstractMetaLearner(resnet18())
assert meta_learner.validate(None) == 0.0

@staticmethod
def test_validate_updates_best_model_state_if_it_has_best_validation_accuracy():
with patch("easyfsl.methods.AbstractMetaLearner.evaluate") as mock_evaluate:
mock_evaluate.return_value = 0.5
meta_learner = AbstractMetaLearner(resnet18())
meta_learner.best_validation_accuracy = 0.1
meta_learner.validate(None)
assert meta_learner.best_model_state is not None
def test_validate_updates_best_model_state_if_it_has_best_validation_accuracy(
mocker,
):
mocker.patch("easyfsl.methods.AbstractMetaLearner.evaluate", return_value=0.5)
meta_learner = AbstractMetaLearner(resnet18())
meta_learner.best_validation_accuracy = 0.1
meta_learner.validate(None)
assert meta_learner.best_model_state is not None

@staticmethod
@pytest.mark.parametrize(
Expand All @@ -91,10 +92,44 @@ def test_validate_updates_best_model_state_if_it_has_best_validation_accuracy():
)
def test_validate_leaves_best_model_state_if_it_has_worse_validation_accuracy(
accuracy,
mocker,
):
mocker.patch(
"easyfsl.methods.AbstractMetaLearner.evaluate", return_value=accuracy
)
meta_learner = AbstractMetaLearner(resnet18())
meta_learner.best_validation_accuracy = 0.1
meta_learner.validate(None)
assert meta_learner.best_model_state is None

@staticmethod
@pytest.mark.parametrize(
"n_train_episodes,validation_frequency,expected_number_of_validations",
[
(5, 1, 5),
(5, 5, 1),
(5, 6, 0),
(5, 3, 1),
(6, 3, 2),
],
)
def test_validation_occurs_when_expected(
n_train_episodes, validation_frequency, expected_number_of_validations, mocker
):
with patch("easyfsl.methods.AbstractMetaLearner.evaluate") as mock_evaluate:
mock_evaluate.return_value = accuracy
meta_learner = AbstractMetaLearner(resnet18())
meta_learner.best_validation_accuracy = 0.1
meta_learner.validate(None)
assert meta_learner.best_model_state is None
mocker.patch(
"easyfsl.methods.AbstractMetaLearner.fit_on_task", return_value=0.0
)
mocker.patch("easyfsl.methods.AbstractMetaLearner.validate")
spy_validate = mocker.spy(AbstractMetaLearner, "validate")

meta_learner = AbstractMetaLearner(resnet18())
train_loader = n_train_episodes * [(None, None, None, None, None)]

meta_learner.fit(
train_loader=train_loader,
optimizer=None,
val_loader=True,
validation_frequency=validation_frequency,
)

assert spy_validate.call_count == expected_number_of_validations
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

setup(
name="easyfsl",
version="0.2.0",
version="0.2.1",
description="Ready-to-use PyTorch code to boost your way into few-shot image classification",
long_description=long_description,
long_description_content_type="text/markdown",
Expand Down

0 comments on commit afb3155

Please sign in to comment.