Skip to content

Commit

Permalink
[FIX] callback on unfreezing layers after milestones
Browse files Browse the repository at this point in the history
  • Loading branch information
BenCretois committed Jan 31, 2024
1 parent 52d4eb9 commit b57ba81
Showing 1 changed file with 27 additions and 13 deletions.
40 changes: 27 additions & 13 deletions callbacks/callbacks.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,41 @@
import pytorch_lightning as pl

from torch.optim.optimizer import Optimizer
from pytorch_lightning.callbacks.finetuning import BaseFinetuning


# See https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pl_examples/domain_templates/computer_vision_fine_tuning.py
class MilestonesFinetuning(BaseFinetuning):
    """Milestone-based fine-tuning callback.

    Before training, every parameter of ``pl_module.model`` is frozen except
    the last (sub)layer. At epoch ``milestones`` the whole model is unfrozen
    and its parameters are registered with the optimizer.
    """

    def __init__(self, milestones: int = 10):
        """
        Args:
            milestones: Epoch at which all parameters of ``pl_module.model``
                are unfrozen and added to the optimizer's param groups.
        """
        super().__init__()
        # Single source of truth for the unfreeze epoch — the same attribute
        # name is read in ``finetune_function`` (the pre-fix code compared
        # against ``self.milestones`` while setting ``unfreeze_at_epoch``,
        # so the milestone check never fired).
        self.unfreeze_at_epoch = milestones

    def freeze_before_training(self, pl_module: pl.LightningModule):
        """Freeze the whole model, then re-enable grads on its last layer only."""
        # Freeze all parameters initially
        for param in pl_module.model.parameters():
            param.requires_grad = False

        # Unfreeze the last layer's parameters
        print("[INFO] Unfreezing the last layer of the model")
        last_layer = list(pl_module.model.children())[-1]
        # If the last layer is itself a container (e.g. nn.Sequential),
        # unfreeze only its final sub-layer.
        if hasattr(last_layer, 'children') and list(last_layer.children()):
            last_sublayer = list(last_layer.children())[-1]
            for param in last_sublayer.parameters():
                param.requires_grad = True
        else:
            for param in last_layer.parameters():
                param.requires_grad = True

    def finetune_function(
        self,
        pl_module: pl.LightningModule,
        epoch: int,
        optimizer: Optimizer,
        opt_idx: int,
    ):
        """Unfreeze the entire model once the milestone epoch is reached.

        Args:
            pl_module: The LightningModule whose ``.model`` is being tuned.
            epoch: Current training epoch.
            optimizer: Optimizer to which the unfrozen params are added.
            opt_idx: Optimizer index (unused; required by the callback API).
        """
        # Unfreeze the entire model at the specified epoch
        if epoch == self.unfreeze_at_epoch:
            print("[INFO] Unfreezing all the parameters of the model")
            for param in pl_module.model.parameters():
                param.requires_grad = True
            # Register the newly-trainable parameters with the optimizer so
            # they actually receive updates (requires_grad alone is not enough
            # for params that were never in a param group).
            self.unfreeze_and_add_param_group(
                modules=pl_module.model, optimizer=optimizer
            )

0 comments on commit b57ba81

Please sign in to comment.