Skip to content

Commit

Permalink
Avoid loading model weights before recipe application, if any
Browse files Browse the repository at this point in the history
  • Loading branch information
rahul-tuli committed Apr 8, 2024
1 parent 5aae81b commit 6323062
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 7 deletions.
13 changes: 9 additions & 4 deletions src/sparseml/pytorch/model_load/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,14 +89,18 @@ def log_model_load(
)


def apply_recipe_structure_to_model(model: Module, recipe_path: str, model_path: str):
def apply_recipe_structure_to_model(
model: Module, recipe_path: str, model_path: str, reload_weights=True
):
"""
Takes a loaded Pytorch model and applies any structural changes such as quantization
to the model, then reloads the model.
:param model: PyTorch model to apply structure to
:param recipe_path: path to recipe to apply to the model
:param model_path: path to model, used for reloading the state dict
    :param reload_weights: flag to reload the weights after applying the recipe.
        Default is True.
"""
orig_state_dict = model.state_dict()

Expand All @@ -121,22 +125,23 @@ def apply_recipe_structure_to_model(model: Module, recipe_path: str, model_path:
_LOGGER.info(f"Applied {msg} to the model at {model_path}")

# reload the state dict for the model now that architecture matches expected
if reload_model_state(model, model_path, orig_state_dict):
if reload_weights and reload_model_state(model, model_path, orig_state_dict):
_LOGGER.info(
"Reloaded model state after SparseML recipe structure modifications "
f"from {model_path}"
)


def reload_model_state(
model: Module, load_path: str, orig_state_dict: Dict[str, Any]
model: Module, load_path: str, orig_state_dict: Dict[str, Any], force_reload=False
) -> bool:
"""
Reload the weights after model architecture changes due to recipe application.
:param model: PyTorch model to reload
:param load_path: path to model
:param orig_state_dict: state dict of model
:param force_reload: flag to force reload the weights. Default is False.
:return: True if weights are successfully reloaded; False otherwise.
"""
invalid_load_path = not load_path or not os.path.isdir(load_path)
Expand All @@ -163,7 +168,7 @@ def reload_model_state(

current_state_dict = model.state_dict()

if set(orig_state_dict.keys()) == set(current_state_dict):
if not force_reload and set(orig_state_dict.keys()) == set(current_state_dict):
# no change in keys, ignore reload
return False

Expand Down
26 changes: 23 additions & 3 deletions src/sparseml/transformers/sparsification/sparse_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,15 @@
from sparseml.pytorch.model_load.helpers import (
apply_recipe_structure_to_model,
log_model_load,
reload_model_state,
)
from sparseml.transformers.compression.utils import (
get_safetensors_folder,
infer_compressor_from_model_config,
modify_save_pretrained,
)
from sparseml.transformers.sparsification.modification import modify_model
from sparseml.transformers.sparsification.sparse_config import SparseAutoConfig
from sparseml.transformers.utils.helpers import download_model_directory, resolve_recipe


Expand Down Expand Up @@ -111,9 +113,12 @@ def skip(*args, **kwargs):
logger = logging.getLogger("transformers.modeling_utils")
restore_log_level = logger.getEffectiveLevel()
logger.setLevel(level=logging.ERROR)
model = super(AutoModelForCausalLM, cls).from_pretrained(
pretrained_model_name_or_path, *model_args, **kwargs
)

config = SparseAutoConfig.from_pretrained(pretrained_model_name_or_path)

# instantiate model without loading weights
model = super(AutoModelForCausalLM, cls).from_config(config)

logger.setLevel(level=restore_log_level)
model = modify_model(model)
# override the PreTrainedModel instance with compression save function
Expand All @@ -130,12 +135,27 @@ def skip(*args, **kwargs):
compressor.overwrite_weights(model_path=model_path, model=model)

recipe = resolve_recipe(recipe=recipe, model_path=pretrained_model_name_or_path)

# this must be done before recipe is applied
original_state_dict = model.state_dict()

if recipe:
apply_recipe_structure_to_model(
model=model,
model_path=pretrained_model_name_or_path,
recipe_path=recipe,
reload_weights=False,
)

# Load the model weights
if reload_model_state(
model, pretrained_model_name_or_path, original_state_dict, force_reload=True
):
_LOGGER.info(
"Loaded model state after SparseML recipe structure modifications "
f"from {pretrained_model_name_or_path}"
)

return model


Expand Down

0 comments on commit 6323062

Please sign in to comment.