Sparsification Refactor for LLMs (#1713)
* Initial start implementation
* add in further completion state for session and events
* add in recipe helper functions for merging, loading, and running callbacks
* minor fixes for new framework
* add constant pruning modifier
* add magnitude pruning modifier
* knowledge distillation implementation
* fix import errors and multiframework inits
* fix import errors and multiframework inits
* initialization
* RecipeModifiers working
* fix import errors
* modifiers loading in stages
* adding test files
* modifier factory implementation
* running example, but sparsity not working correctly
* fix polynomial scheduler, leave masks enabled on end
* remove e2e files
* add on_event for modifier lifecycle and add initial integration for torchvision
* leave_enabled fixes
* fixing evals and finalization
* Add test
* Add changes to allow accepting strings
* fix recipe staging issue
* style
* style fixes
* bug fixes that came up during obcq implementation

Co-authored-by: Sara Adkins <sara@neuralmagic.com>
Co-authored-by: rahul-tuli <rahul@neuralmagic.com>
1 parent 2d0f8a0, commit b9d6b70
Showing 67 changed files with 5,160 additions and 0 deletions.
@@ -0,0 +1,27 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# flake8: noqa

from .data import *
from .event import *
from .factory import *
from .framework import *
from .framework_object import *
from .lifecycle import *
from .model import *
from .modifier import *
from .optimizer import *
from .recipe import *
from .state import *
@@ -0,0 +1,17 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# flake8: noqa

from .base import ModifiableData
@@ -0,0 +1,38 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from typing import Generic, TypeVar

from sparseml.core.framework_object import MultiFrameworkObject


__all__ = ["ModifiableData"]

DT = TypeVar("DT")  # Dataset Type


@dataclass
class ModifiableData(Generic[DT], MultiFrameworkObject):
    data: DT = None
    num_samples: int = None

    def get_num_batches(self) -> int:
        raise NotImplementedError()

    def set_batch_size(self, batch_size: int):
        raise NotImplementedError()

    def get_batch_size(self) -> int:
        raise NotImplementedError()
@@ -0,0 +1,140 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Mapping, Sequence

import torch
from torch.utils.data import DataLoader

from sparseml.core.data.base import ModifiableData


__all__ = ["ModifiableDataPyTorch", "DynamicBatchSizeDataLoader"]


class DynamicBatchSizeDataLoader:
    def __init__(self, data_loader: DataLoader):
        self.data_loader = data_loader
        self.current_batch_size = data_loader.batch_size

    def __iter__(self):
        if self.current_batch_size == self.data_loader.batch_size:
            yield from self.data_loader
        elif self.current_batch_size < self.data_loader.batch_size:
            yield from self._data_split_iter()
        else:
            yield from self._data_merge_iter()

    def set_batch_size(self, batch_size: int):
        self.current_batch_size = batch_size

    def get_batch_size(self) -> int:
        return self.current_batch_size

    def _data_split_iter(self):
        if self.current_batch_size >= self.data_loader.batch_size:
            raise ValueError(
                "Current batch size must be less than the original batch size"
            )

        for batch in self.data_loader:
            num_splits = self.data_loader.batch_size // self.current_batch_size
            for i in range(num_splits):
                start_idx = i * self.current_batch_size
                end_idx = (i + 1) * self.current_batch_size
                yield DynamicBatchSizeDataLoader.split_batch(batch, start_idx, end_idx)

    def _data_merge_iter(self):
        if self.current_batch_size <= self.data_loader.batch_size:
            raise ValueError(
                "Current batch size must be greater than the original batch size"
            )

        buffer = []
        buffer_size = 0
        for batch in self.data_loader:
            buffer.append(batch)
            buffer_size += len(batch)
            while buffer_size >= self.current_batch_size:
                merged = DynamicBatchSizeDataLoader.merge_batches(buffer)
                yield DynamicBatchSizeDataLoader.split_batch(
                    merged, 0, self.current_batch_size
                )
                buffer = [
                    DynamicBatchSizeDataLoader.split_batch(
                        merged, self.current_batch_size, buffer_size
                    )
                ]
                buffer_size -= self.current_batch_size

    @staticmethod
    def split_batch(batch, start_idx, end_idx):
        """
        Splits a batch based on its type (Tensor, Mapping, Sequence) and the provided
        indices.
        """
        if isinstance(batch, torch.Tensor):
            return batch[start_idx:end_idx]
        elif isinstance(batch, Mapping):
            return {
                key: DynamicBatchSizeDataLoader.split_batch(value, start_idx, end_idx)
                for key, value in batch.items()
            }
        elif isinstance(batch, Sequence):
            return [
                DynamicBatchSizeDataLoader.split_batch(item, start_idx, end_idx)
                for item in batch
            ]
        else:
            raise TypeError(f"Unsupported batch type: {type(batch)}")

    @staticmethod
    def merge_batches(batches):
        """
        Merges a sequence of batches into a single batch.
        """
        sample_batch = batches[0]
        if isinstance(sample_batch, torch.Tensor):
            return torch.cat(batches, dim=0)
        elif isinstance(sample_batch, Mapping):
            return {
                key: DynamicBatchSizeDataLoader.merge_batches(
                    [batch[key] for batch in batches]
                )
                for key in sample_batch.keys()
            }
        elif isinstance(sample_batch, Sequence):
            return [
                DynamicBatchSizeDataLoader.merge_batches(
                    [batch[i] for batch in batches]
                )
                for i in range(len(sample_batch))
            ]
        else:
            raise TypeError(f"Unsupported batch type: {type(sample_batch)}")


class ModifiableDataPyTorch(ModifiableData[DynamicBatchSizeDataLoader]):
    def __init__(self, data_loader: DataLoader, framework=None):
        super().__init__()
        self.data = DynamicBatchSizeDataLoader(data_loader)

    def get_num_batches(self) -> int:
        return self.num_samples // self.data.get_batch_size()

    def set_batch_size(self, batch_size: int):
        self.data.set_batch_size(batch_size)

    def get_batch_size(self) -> int:
        return self.data.get_batch_size()
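To make the splitting behavior above concrete, here is a minimal usage sketch (not part of the commit) that re-chunks an existing DataLoader from batch size 8 down to 2; the import path sparseml.core.data.pytorch is an assumption based on the imports shown in this diff.

import torch
from torch.utils.data import DataLoader, TensorDataset

from sparseml.core.data.pytorch import DynamicBatchSizeDataLoader  # assumed module path

# 32 single-feature samples, originally loaded 8 at a time
dataset = TensorDataset(torch.arange(32).float().unsqueeze(1))
loader = DynamicBatchSizeDataLoader(DataLoader(dataset, batch_size=8))

# requesting a smaller batch size routes iteration through _data_split_iter
loader.set_batch_size(2)
for batch in loader:
    # each yielded batch is a list holding one tensor of shape (2, 1)
    assert batch[0].shape[0] == 2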
@@ -0,0 +1,146 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from copy import deepcopy
from dataclasses import dataclass
from enum import Enum
from typing import Optional


__all__ = [
    "EventType",
    "Event",
]


class EventType(Enum):
    # training lifecycle
    PRE_INIT = "pre_init"
    INITIALIZE = "initialize"
    FINALIZE = "finalize"

    # batch lifecycle
    BATCH_START = "batch_start"
    LOSS_CALCULATED = "loss_calculated"
    BATCH_END = "batch_end"

    # step lifecycle
    OPTIM_PRE_STEP = "optim_pre_step"
    OPTIM_POST_STEP = "optim_post_step"

    def order(self) -> int:
        if self == EventType.PRE_INIT:
            return 0
        elif self == EventType.INITIALIZE:
            return 10
        elif self == EventType.FINALIZE:
            return 20
        elif self == EventType.BATCH_START:
            return 100
        elif self == EventType.LOSS_CALCULATED:
            return 110
        elif self == EventType.OPTIM_PRE_STEP:
            return 120
        elif self == EventType.OPTIM_POST_STEP:
            return 130
        elif self == EventType.BATCH_END:
            return 140
        else:
            raise ValueError(f"invalid event type {self}")


@dataclass
class Event:
    type_: EventType = None

    steps_per_epoch: int = None
    batches_per_step: int = None
    invocations_per_step: int = None

    global_step: int = 0
    global_batch: int = 0

    @property
    def epoch_based(self) -> bool:
        return self.steps_per_epoch is not None

    @property
    def epoch(self) -> int:
        return self.global_step // self.steps_per_epoch

    @property
    def epoch_full(self) -> float:
        return self.global_step / float(self.steps_per_epoch)

    @property
    def epoch_step(self) -> int:
        return self.global_step % self.steps_per_epoch

    @property
    def epoch_batch(self) -> int:
        batches_per_epoch = (
            self.steps_per_epoch * self.batches_per_step
            if self.batches_per_step
            else self.steps_per_epoch
        )

        return self.global_batch % batches_per_epoch

    @property
    def current_index(self) -> float:
        if not self.epoch_based:
            return self.global_step

        if self.epoch_full - self.epoch > 1.0:
            raise ValueError("too many steps per epoch for epoch based event")

        return self.epoch_full

    @current_index.setter
    def current_index(self, value: float):
        if not self.epoch_based:
            self.global_step = int(value)
            self.global_batch = (
                self.global_step
                if self.batches_per_step is None or self.batches_per_step < 2
                else self.global_step * self.batches_per_step
            )
            return

        self.global_step = int(value * self.steps_per_epoch)
        self.global_batch = (
            self.global_step
            if self.batches_per_step is None or self.batches_per_step < 2
            else self.global_step * self.batches_per_step
        )

    def should_update(
        self, start: Optional[float], end: Optional[float], update: float
    ):
        current = self.current_index

        if start is not None and current < start:
            return False

        if end is not None and current > end:
            return False

        return update is None or update <= 0.0 or current % update < 1e-10

    def new_instance(self, **kwargs) -> "Event":
        instance = deepcopy(self)
        for key, value in kwargs.items():
            setattr(instance, key, value)

        return instance
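As a quick illustration of how the epoch-based bookkeeping and the should_update gate fit together, a minimal sketch (not part of the commit); the import path sparseml.core.event is an assumption based on the package imports earlier in this diff.

from sparseml.core.event import Event, EventType  # assumed module path

# 100 optimizer steps per epoch, currently at global step 150
event = Event(type_=EventType.BATCH_START, steps_per_epoch=100, global_step=150)

assert event.epoch == 1            # 150 // 100
assert event.epoch_full == 1.5     # fractional epochs elapsed
assert event.current_index == 1.5  # epoch-based, so the index is measured in epochs

# a modifier configured to run every 0.5 epochs between epochs 1.0 and 3.0 fires here
assert event.should_update(start=1.0, end=3.0, update=0.5)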