Fix Sparsity Logs on FSDP Model Save (#2203)
* fix for reporting sparsity of FSDP models

* docstrings

* quality
Sara Adkins authored Apr 10, 2024
1 parent e9a6866 commit 2de5c92
Showing 5 changed files with 92 additions and 50 deletions.
22 changes: 17 additions & 5 deletions src/sparseml/pytorch/utils/sparsification.py
@@ -26,6 +26,7 @@
Iterable,
Iterator,
List,
Optional,
Tuple,
Union,
)
@@ -57,13 +58,22 @@ class ModuleSparsificationInfo:
and quantization
:param module: torch Module to analyze
:param state_dict: optional state_dict to analyze in place of the torch model. This
is used when analyzing an FSDP model, where the full weights may not be accessible
"""

def __init__(self, module: Module):
def __init__(
self, module: Module, state_dict: Optional[Dict[str, torch.Tensor]] = None
):
self.module = module
self.trainable_params = list(
filter(lambda param: param.requires_grad, self.module.parameters())
)
self.state_dict = state_dict

if self.state_dict is not None:
self.trainable_params = [param for _, param in state_dict.items()]
else:
self.trainable_params = list(
filter(lambda param: param.requires_grad, self.module.parameters())
)

def __str__(self):
return json.dumps(
@@ -124,7 +134,9 @@ def params_prunable_sparse(self) -> int:
"""
return sum(
round(tensor_sparsity(layer.weight).item() * torch.numel(layer.weight))
for (name, layer) in get_prunable_layers(self.module)
for (name, layer) in tqdm(
get_prunable_layers(self.module), desc="Calculating model sparsity"
)
)

@property
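
The change above lets the sparsity stats come from an explicitly supplied state_dict instead of module.parameters(). A minimal sketch of the two paths (the toy model and the manual pruning are illustrative; under FSDP the gathered state_dict would come from the accelerator):

import torch
from torch.nn import Linear, Sequential

from sparseml.pytorch.utils.sparsification import ModuleSparsificationInfo

# toy stand-in for an unwrapped FSDP module, with half of layer 0 pruned
model = Sequential(Linear(256, 256), Linear(256, 16))
with torch.no_grad():
    model[0].weight[:128].zero_()

# non-FSDP path: parameters are read directly from the module
info = ModuleSparsificationInfo(model)

# FSDP path: pass a gathered full state_dict, since the module's own
# parameters may be flattened shards that are not directly accessible
gathered = model.state_dict()  # e.g. accelerator.get_state_dict(model) under FSDP
info_fsdp = ModuleSparsificationInfo(model, state_dict=gathered)

print(info.params_sparse_percent, info_fsdp.params_sparse_percent)
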
32 changes: 25 additions & 7 deletions src/sparseml/transformers/compression/config/base.py
@@ -12,9 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional
from typing import Dict, Optional

from pydantic import BaseModel
from torch import Tensor
from torch.nn import Module

import sparseml.core.session as session_manager
@@ -40,14 +41,19 @@ class CompressionConfig(RegistryMixin, BaseModel):
sparsity_structure: Optional[str] = "unstructured"

@staticmethod
def infer_global_sparsity(model: Module) -> float:
def infer_global_sparsity(
model: Module, state_dict: Optional[Dict[str, Tensor]] = None
) -> float:
"""
Calculates the global percentage of sparse zero weights in the model
:param model: pytorch model to infer sparsity of
:param state_dict: optional state_dict to replace that in model, used for
gathering global FSDP model info
:return: global sparsity of model
"""
info = ModuleSparsificationInfo(model)

info = ModuleSparsificationInfo(model, state_dict=state_dict)
global_sparsity = info.params_sparse_percent
return global_sparsity

@@ -75,17 +81,23 @@ def infer_sparsity_structure() -> str:

@staticmethod
def infer_config_from_model(
model: Module, compress: bool = False
model: Module,
state_dict: Optional[Dict[str, Tensor]] = None,
compress: bool = False,
) -> Optional["CompressionConfig"]:
"""
Determines compression type and informational parameters for a given model
:param model: pytorch model to calculate sparsity config for
:param state_dict: optional state_dict to replace that in model, used for
gathering global FSDP model info
:param compress: whether or not to compress the model on disk
:return: compression config inferred from the model
"""

global_sparsity = CompressionConfig.infer_global_sparsity(model)
global_sparsity = CompressionConfig.infer_global_sparsity(
model, state_dict=state_dict
)

if global_sparsity < 0.05:
return None
@@ -102,11 +114,17 @@ def infer_config_from_model(
sparsity_structure=sparsity_structure,
)

def fill_config_details(self, model: Module):
def fill_config_details(
self, model: Module, state_dict: Optional[Dict[str, Tensor]] = None
):
"""
Fills in informational sparsity parameters from a given model
:param model: pytorch model to infer config parameters from
:param state_dict: optional state_dict to replace that in model, used for
gathering global FSDP model info
"""
self.global_sparsity = CompressionConfig.infer_global_sparsity(model)
self.global_sparsity = CompressionConfig.infer_global_sparsity(
model, state_dict=state_dict
)
self.sparsity_structure = CompressionConfig.infer_sparsity_structure()
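
Taken together, these hooks let the whole config-inference path run off a gathered FSDP state_dict. A sketch of how the updated entry points might be called (dense toy model; under FSDP the state_dict would come from the accelerator):

from torch.nn import Linear

from sparseml.transformers.compression.config.base import CompressionConfig

model = Linear(64, 64)
# under FSDP this would be the gathered full state_dict,
# e.g. accelerator.get_state_dict(model, unwrap=False)
state_dict = model.state_dict()

# global sparsity is now measured on the supplied state_dict
sparsity = CompressionConfig.infer_global_sparsity(model, state_dict=state_dict)
print(f"global sparsity: {sparsity}")

# a freshly initialized (dense) model falls below the sparsity threshold,
# so infer_config_from_model returns None and nothing is compressed
config = CompressionConfig.infer_config_from_model(
    model, state_dict=state_dict, compress=False
)
print(config)
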
8 changes: 4 additions & 4 deletions src/sparseml/transformers/compression/utils/compress_save.py
@@ -74,6 +74,8 @@ def save_pretrained_wrapper(
:param kwargs: additional kwargs to pass on to model.save_pretrained
"""
model = model_ref()
# state_dict gets passed in as a kwarg for FSDP models
state_dict = kwargs.get("state_dict", None)

if qat_active(model):
_LOGGER.info(
@@ -86,7 +88,7 @@
)

if sparsity_config is not None:
sparsity_config.fill_config_details(model)
sparsity_config.fill_config_details(model, state_dict=state_dict)
elif not skip_compression_stats:
# try to infer a sparsity config from the model if none is provided
_LOGGER.info(
@@ -96,7 +98,7 @@
"skip_compression_stats=True"
)
sparsity_config = CompressionConfig.infer_config_from_model(
model, compress=save_compressed
model, state_dict=state_dict, compress=save_compressed
)

if sparsity_config is None:
@@ -111,8 +113,6 @@
sparsity_config.format, config=sparsity_config
)

# state_dict gets passed in as a kwarg for FSDP models
state_dict = kwargs.get("state_dict", None)
if state_dict is None:
state_dict = model.state_dict()

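
To summarize the wrapper's control flow, here is a hedged paraphrase (not the real implementation) of how the sparsity config is resolved once the state_dict kwarg has been read:

from typing import Dict, Optional

from torch import Tensor
from torch.nn import Module

from sparseml.transformers.compression.config.base import CompressionConfig


def resolve_sparsity_config(
    model: Module,
    state_dict: Optional[Dict[str, Tensor]],
    sparsity_config: Optional[CompressionConfig],
    skip_compression_stats: bool,
    save_compressed: bool,
) -> Optional[CompressionConfig]:
    """Illustrative paraphrase of the config resolution in the wrapper above."""
    if sparsity_config is not None:
        # user supplied a config: only fill in the measured sparsity details
        sparsity_config.fill_config_details(model, state_dict=state_dict)
        return sparsity_config
    if skip_compression_stats:
        # caller explicitly opted out of the (potentially slow) stats pass
        return None
    # otherwise infer everything from the model / gathered FSDP state_dict
    return CompressionConfig.infer_config_from_model(
        model, state_dict=state_dict, compress=save_compressed
    )
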
61 changes: 35 additions & 26 deletions src/sparseml/transformers/finetune/session_mixin.py
@@ -371,11 +371,7 @@ def train(self, *args, stage: Optional[str] = None, **kwargs):
self.accelerator.wait_for_everyone()

# log model sparsity
with summon_full_params_context(self.model, offload_to_cpu=True):
if self.accelerator.is_main_process:
if not qat_active(self.model):
self.log_model_sparsification()

self.maybe_log_model_sparsification()
self.accelerator.wait_for_everyone()

return output
@@ -433,11 +429,7 @@ def one_shot(self, calib_data: DataLoader, stage: Optional[str] = None):
)

# log model sparsity
with summon_full_params_context(self.model, offload_to_cpu=True):
if self.accelerator.is_main_process:
if not qat_active(self.model):
self.log_model_sparsification()

self.maybe_log_model_sparsification()
self.accelerator.wait_for_everyone()

def save_model(
@@ -479,17 +471,36 @@ def save_model(
if not self.recipe:
return

# save recipe, will contain modifiers from the model's original recipe as well
# as those added from self.recipe
recipe_path = os.path.join(output_dir, RECIPE_FILE_NAME)
session = session_manager.active_session()
recipe_yaml_str = session.get_serialized_recipe()
with open(recipe_path, "w") as fp:
fp.write(recipe_yaml_str)
if self.accelerator.is_main_process:
# save recipe, will contain modifiers from the model's original recipe as
# well as those added from self.recipe
recipe_path = os.path.join(output_dir, RECIPE_FILE_NAME)
session = session_manager.active_session()
recipe_yaml_str = session.get_serialized_recipe()
with open(recipe_path, "w") as fp:
fp.write(recipe_yaml_str)

_LOGGER.info(f"Saved SparseML recipe with model state to {recipe_path}")

_LOGGER.info(f"Saved SparseML recipe with model state to {recipe_path}")
self.accelerator.wait_for_everyone()

def maybe_log_model_sparsification(self):
"""
Log info on model sparsity and quantization if possible. Only print logs on the
main process, and avoid logging for quantized FSDP models
"""
with summon_full_params_context(self.model, offload_to_cpu=True):
# offload to avoid OOM errors
if not self.accelerator.is_main_process:
# only calculate stats on the rank 0 process
return
if self.is_fsdp_enabled and qat_active(self.model):
# due to state dict changes we can't log sparsity info with quantized
# models in FSDP
return

self.log_model_sparsification()

def log_model_sparsification(self):
"""
Log the current model sparsification info including pruned and quantized states
@@ -499,18 +510,16 @@ def log_model_sparsification(self):
_LOGGER.info(
f"Sparsification info for {self.model_state_path}: "
f"{sparsification_info.params_total} total params. "
f"Of those there are {sparsification_info.params_prunable_total} prunable "
)
_LOGGER.info(
f"There are {sparsification_info.params_prunable_total} prunable "
f"params which have {sparsification_info.params_prunable_sparse_percent} "
"avg sparsity."
)
model_type = (
"sparse"
if sparsification_info.params_prunable_sparse_percent > 5
else "dense"
)
_LOGGER.info(
f"{model_type} model detected, "
f"all sparsification info: {sparsification_info}"
f"There are {sparsification_info.params_quantizable} quantizable "
f"params, with a quantization percentage of "
f"{sparsification_info.params_quantized_percent}."
)

def _prepare_model_for_fsdp(self):
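
Both save_model and the new maybe_log_model_sparsification follow the same pattern: do the work only on the main process, then hit a barrier so every rank stays in sync. A generic sketch of that pattern with accelerate (the logging body is a placeholder, not the trainer's code):

from accelerate import Accelerator
from torch import nn


def log_on_main_only(accelerator: Accelerator, model: nn.Module) -> None:
    """Sketch of the rank-0 gating used by maybe_log_model_sparsification."""
    if not accelerator.is_main_process:
        # non-zero ranks skip the stats pass entirely...
        return
    total = sum(p.numel() for p in model.parameters())
    print(f"{total} total params")


if __name__ == "__main__":
    accelerator = Accelerator()
    model = nn.Linear(8, 8)
    log_on_main_only(accelerator, model)
    # ...but every rank must still reach the barrier afterwards, otherwise
    # the ranks that returned early would leave rank 0 waiting
    accelerator.wait_for_everyone()
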
19 changes: 11 additions & 8 deletions src/sparseml/utils/fsdp/helpers.py
@@ -163,14 +163,17 @@ def save_pretrained_fsdp(
):
state_dict = accelerator.get_state_dict(model, unwrap=False)

accelerator.unwrap_model(model).save_pretrained(
output_dir,
is_main_process=accelerator.is_main_process,
save_function=accelerator.save,
state_dict=state_dict,
save_compressed=save_compressed,
safe_serialization=save_safetensors,
)
if accelerator.is_main_process:
accelerator.unwrap_model(model).save_pretrained(
output_dir,
is_main_process=accelerator.is_main_process,
save_function=accelerator.save,
state_dict=state_dict,
save_compressed=save_compressed,
safe_serialization=save_safetensors,
)

accelerator.wait_for_everyone()


def get_fsdp_parent(layer_name: str, model: Module) -> Optional[Module]:
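
A hedged usage sketch of the updated helper; the keyword names below match the variables used in the body, but the full signature is collapsed in this diff, and the tiny checkpoint is just a convenient stand-in:

from accelerate import Accelerator
from transformers import AutoModelForCausalLM

from sparseml.utils.fsdp.helpers import save_pretrained_fsdp

accelerator = Accelerator()
# any HF model works here; under real FSDP it would be wrapped by
# accelerator.prepare(...) before training
model = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")

# all ranks participate in gathering the state dict, but after this change
# only rank 0 writes files, and the trailing barrier realigns the ranks
save_pretrained_fsdp(model=model, accelerator=accelerator, output_dir="./ckpt")
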
