diff --git a/unimol_tools/unimol_tools/models/nnmodel.py b/unimol_tools/unimol_tools/models/nnmodel.py index b5da71f..a209dc8 100644 --- a/unimol_tools/unimol_tools/models/nnmodel.py +++ b/unimol_tools/unimol_tools/models/nnmodel.py @@ -103,8 +103,16 @@ def _init_model(self, model_name, **params): :return: An instance of the specified neural network model. :raises ValueError: If the model name is not recognized. """ + freeze_layers = params.get('freeze_layers', None) + freeze_layers_reversed = params.get('freeze_layers_reversed', False) if model_name in NNMODEL_REGISTER: model = NNMODEL_REGISTER[model_name](**params) + if isinstance(freeze_layers, str): + freeze_layers = freeze_layers.replace(' ', '').split(',') + if isinstance(freeze_layers, list): + for layer_name, layer_param in model.named_parameters(): + should_freeze = any(layer_name.startswith(freeze_layer) for freeze_layer in freeze_layers) + layer_param.requires_grad = not (freeze_layers_reversed ^ should_freeze) else: raise ValueError('Unknown model: {}'.format(self.model_name)) return model diff --git a/unimol_tools/unimol_tools/train.py b/unimol_tools/unimol_tools/train.py index 0b0eeba..57f3c1f 100644 --- a/unimol_tools/unimol_tools/train.py +++ b/unimol_tools/unimol_tools/train.py @@ -41,6 +41,8 @@ def __init__(self, max_norm=5.0, use_cuda=True, use_amp=True, + freeze_layers=None, + freeze_layers_reversed=False, **params, ): """ @@ -80,6 +82,8 @@ def __init__(self, :param max_norm: float, default=5.0, max norm of gradient clipping. :param use_cuda: bool, default=True, whether to use GPU. :param use_amp: bool, default=True, whether to use automatic mixed precision. + :param freeze_layers: str or list, frozen layers by startwith name list. ['encoder', 'gbf'] will freeze all the layers whose name start with 'encoder' or 'gbf'. + :param freeze_layers_reversed: bool, default=False, inverse selection of frozen layers :param params: dict, default=None, other parameters. """ @@ -105,6 +109,8 @@ def __init__(self, config.max_norm = max_norm config.use_cuda = use_cuda config.use_amp = use_amp + config.freeze_layers = freeze_layers + config.freeze_layers_reversed = freeze_layers_reversed self.save_path = save_path self.config = config