Skip to content

Commit

Permalink
[BugFix] Model State Reload with Quantized Stubs in SparseAutoModelForCausalLM (#2226)
Browse files Browse the repository at this point in the history

* Fix bug for loading models from hf hub

* Update to download only relevant files and not the whole model repo

* Add py files to relevant suffixes
  • Loading branch information
rahul-tuli authored Apr 5, 2024
1 parent fd0a779 commit 88196d5
Show file tree
Hide file tree
Showing 2 changed files with 101 additions and 13 deletions.
16 changes: 4 additions & 12 deletions src/sparseml/transformers/sparsification/sparse_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,7 @@
modify_save_pretrained,
)
from sparseml.transformers.sparsification.modification import modify_model
from sparseml.transformers.utils.helpers import resolve_recipe
from sparseml.utils import download_zoo_training_dir
from sparseml.utils.fsdp.context import main_process_first_context
from sparseml.transformers.utils.helpers import download_model_directory, resolve_recipe


__all__ = ["SparseAutoModel", "SparseAutoModelForCausalLM", "get_shared_tokenizer_src"]
Expand Down Expand Up @@ -101,15 +99,9 @@ def skip(*args, **kwargs):
else pretrained_model_name_or_path
)

if pretrained_model_name_or_path.startswith("zoo:"):
_LOGGER.debug(
"Passed zoo stub to SparseAutoModelForCausalLM object. "
"Loading model from SparseZoo training files..."
)
with main_process_first_context():
pretrained_model_name_or_path = download_zoo_training_dir(
zoo_stub=pretrained_model_name_or_path
)
pretrained_model_name_or_path = download_model_directory(
pretrained_model_name_or_path, **kwargs
)

# determine compression format, if any, from the model config
compressor = infer_compressor_from_model_config(pretrained_model_name_or_path)
Expand Down
98 changes: 97 additions & 1 deletion src/sparseml/transformers/utils/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,10 @@
from transformers.trainer_utils import get_last_checkpoint
from transformers.utils import PaddingStrategy

from huggingface_hub import HUGGINGFACE_CO_URL_HOME, hf_hub_download
from huggingface_hub import HUGGINGFACE_CO_URL_HOME, HfFileSystem, hf_hub_download
from sparseml.export.helpers import ONNX_MODEL_NAME
from sparseml.utils import download_zoo_training_dir
from sparseml.utils.fsdp.context import main_process_first_context
from sparsezoo import Model, setup_model


Expand All @@ -52,6 +54,8 @@
"ALL_TASK_NAMES",
"create_fake_dataloader",
"POSSIBLE_TOKENIZER_FILES",
"download_repo_from_huggingface_hub",
"download_model_directory",
]


Expand Down Expand Up @@ -92,6 +96,7 @@ class TaskNames(Enum):
"special_tokens_map.json",
"tokenizer_config.json",
}
RELEVANT_HF_SUFFIXES = ["json", "md", "bin", "safetensors", "yaml", "yml", "py"]


def remove_past_key_value_support_from_config(config: AutoConfig) -> AutoConfig:
Expand Down Expand Up @@ -553,3 +558,94 @@ def fetch_recipe_path(target: str):
recipe_path = hf_hub_download(repo_id=target, filename=DEFAULT_RECIPE_NAME)

return recipe_path


def download_repo_from_huggingface_hub(repo_id, **kwargs):
    """
    Download relevant model files from the Hugging Face Hub
    using the huggingface_hub.hf_hub_download function

    Note(s):
        - Does not download the entire repo, only the relevant files
          for the model, such as the model weights, tokenizer files, etc.
        - Does not re-download files that already exist locally, unless
          the force_download flag is set to True

    :pre-condition: the repo_id must be a valid Hugging Face Hub repo id
    :param repo_id: the repo id to download
    :param kwargs: additional keyword arguments to pass to hf_hub_download
    :return: local path to the directory containing the downloaded files
    :raises ValueError: if the repo contains no files, or contains no files
        whose suffix is in RELEVANT_HF_SUFFIXES
    """
    hf_filesystem = HfFileSystem()
    files = hf_filesystem.ls(repo_id)

    if not files:
        raise ValueError(f"Could not find any files in HF repo {repo_id}")

    # All file(s) from hf_filesystem have "name" key;
    # keep only the basenames of files with a relevant suffix
    relevant_file_names = [
        Path(file["name"]).name
        for file in files
        if any(file["name"].endswith(suffix) for suffix in RELEVANT_HF_SUFFIXES)
    ]

    # Bug fix: previously, if no file matched a relevant suffix, the download
    # loop below never ran and `last_file` was referenced while unbound,
    # raising an opaque UnboundLocalError. Fail fast with a clear message.
    if not relevant_file_names:
        raise ValueError(
            f"Could not find any relevant files in HF repo {repo_id}; "
            f"expected at least one file ending in one of {RELEVANT_HF_SUFFIXES}"
        )

    # forward only the kwargs that hf_hub_download accepts
    hub_kwargs_names = (
        "subfolder",
        "repo_type",
        "revision",
        "library_name",
        "library_version",
        "cache_dir",
        "local_dir",
        "local_dir_use_symlinks",
        "user_agent",
        "force_download",
        "force_filename",
        "proxies",
        "etag_timeout",
        "resume_download",
        "token",
        "local_files_only",
        "headers",
        "legacy_cache_layout",
        "endpoint",
    )
    hub_kwargs = {name: kwargs[name] for name in hub_kwargs_names if name in kwargs}

    for file_name in relevant_file_names:
        last_file = hf_hub_download(repo_id=repo_id, filename=file_name, **hub_kwargs)

    # parent directory of the last downloaded file is the model directory
    return str(Path(last_file).parent.resolve().absolute())


def download_model_directory(pretrained_model_name_or_path: str, **kwargs):
    """
    Resolve a model identifier to a local directory, downloading it from
    the SparseZoo or the Hugging Face Hub when it is not already on disk

    :param pretrained_model_name_or_path: the name of or path to the model to load
        can be a SparseZoo/HuggingFace model stub
    :param kwargs: additional keyword arguments to pass to the download function
    :return: the path to the downloaded model directory
    """
    # Guard clause: nothing to download when the path already exists locally
    if Path(pretrained_model_name_or_path).exists():
        _LOGGER.debug(
            "Model directory already exists locally.",
        )
        return pretrained_model_name_or_path

    # Ensure only the main process performs the download under distributed runs
    with main_process_first_context():
        is_zoo_stub = pretrained_model_name_or_path.startswith("zoo:")
        if is_zoo_stub:
            _LOGGER.debug(
                "Passed zoo stub to SparseAutoModelForCausalLM object. "
                "Loading model from SparseZoo training files..."
            )
            return download_zoo_training_dir(zoo_stub=pretrained_model_name_or_path)

        _LOGGER.debug("Downloading model from HuggingFace Hub.")
        return download_repo_from_huggingface_hub(
            repo_id=pretrained_model_name_or_path, **kwargs
        )

0 comments on commit 88196d5

Please sign in to comment.