Merge pull request #454 from facebookresearch/fix_fsdp
Fix FSDP support with pytorch 2.1.0
JadeCopet authored May 2, 2024
2 parents 87af0bf + a2bf647 commit 795f8dc
Showing 3 changed files with 24 additions and 11 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -12,6 +12,8 @@ Typo fixes.
 
 Fixing setup.py to install only audiocraft, not the unit tests and scripts.
 
+Fix FSDP support with PyTorch 2.1.0.
+
 ## [1.2.0] - 2024-01-11
 
 Adding stereo models.
2 changes: 1 addition & 1 deletion audiocraft/__init__.py
@@ -23,4 +23,4 @@
 # flake8: noqa
 from . import data, modules, models
 
-__version__ = '1.3.0a1'
+__version__ = '1.3.0a2'
31 changes: 21 additions & 10 deletions audiocraft/optim/fsdp.py
@@ -123,16 +123,27 @@ def purge_fsdp(model: FSDP):
     """
     from torch.distributed.fsdp._runtime_utils import _reshard  # type: ignore
     for module in FSDP.fsdp_modules(model):
-        handles = module._handles
-        if not handles:
-            continue
-        handle = handles[0]
-        unsharded_flat_param = handle._get_padded_unsharded_flat_param()
-        storage_size: int = unsharded_flat_param._typed_storage()._size()  # type: ignore
-        if storage_size == 0:
-            continue
-        true_list = [True for h in handles]
-        _reshard(module, handles, true_list)
+        if hasattr(module, "_handles"):
+            # support for FSDP with torch<2.1.0
+            handles = module._handles
+            if not handles:
+                continue
+            handle = handles[0]
+            unsharded_flat_param = handle._get_padded_unsharded_flat_param()
+            storage_size: int = unsharded_flat_param._typed_storage()._size()  # type: ignore
+            if storage_size == 0:
+                continue
+            true_list = [True for h in handles]
+            _reshard(module, handles, true_list)
+        else:
+            handle = module._handle
+            if not handle:
+                continue
+            unsharded_flat_param = handle._get_padded_unsharded_flat_param()
+            storage_size: int = unsharded_flat_param._typed_storage()._size()  # type: ignore
+            if storage_size == 0:
+                continue
+            _reshard(module, handle, True)
 
 
 class _FSDPFixStateDict(FSDP):
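
The change makes `purge_fsdp` dispatch on `hasattr(module, "_handles")`: before PyTorch 2.1.0 each FSDP module carried a list of flat-parameter handles and `_reshard` took parallel lists, while from 2.1.0 onward a module exposes a single `_handle` attribute and `_reshard` takes one handle and one boolean. Below is a minimal usage sketch, not part of the commit: only `purge_fsdp` and its module path come from the diff above; the `torchrun` launch, the gloo backend, the toy `Linear` module, and the `SHARD_GRAD_OP` strategy are illustrative assumptions.

```python
# Hedged sketch (not from the PR). Run with e.g.:
#   torchrun --nproc_per_node=2 purge_demo.py
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy

from audiocraft.optim.fsdp import purge_fsdp

dist.init_process_group(backend="gloo")  # gloo so the sketch also works on CPU

# SHARD_GRAD_OP keeps parameters unsharded after the forward pass, which is
# the storage purge_fsdp then releases by resharding the flat parameters.
model = FSDP(torch.nn.Linear(16, 16),
             sharding_strategy=ShardingStrategy.SHARD_GRAD_OP)

with torch.no_grad():
    model(torch.randn(4, 16))

purge_fsdp(model)  # works on both torch<2.1.0 (_handles) and >=2.1.0 (_handle)
dist.destroy_process_group()
```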
