-
Notifications
You must be signed in to change notification settings - Fork 145
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Update] layer prefix to be set at model level (#1778)
* - Update `src/sparseml/modifiers/obcq/pytorch.py` to use layer prefix for from model - Remove `layer_prefix` from `SparseGPTModifier` base - Update ModelMetaData to include layer_prefix - Added a convenience function to update missing values in RecipeMetaData instance from another RecipeMetaData instance - Update simplify recipe to also include metadata - Update simplify_combine_recipes to include metadata - Add layer_prefix property to `ModifiableModel` - propagate `layer_prefix` to superclass - update session.py to set_layer_prefix on the model before initializing modifiers - Update example recipe to include layer_prefix in metadata * Add missing docstring * - address review comment - update docstring - add test for `update_missing_metadata` * Add test * Style * Fix tests * Style * [modifier refactor] Add constant pruning tests (#1752) * Initial commit * Add end to end tests * Add e2e tests for constant pruning modifier * Move imports inside the test fuctions so that torch isn't imported unless running the tests * Update setup.py to not run modifier tests unless pytorch is specified * [Bugfix] .dict() method on Recipe (#1753) * Bugfix .dict() method on Recipe * Remove extraneous local test, [faulty commit] * [modifier refactor] Add serialization tests (#1755) * Add serialization tests * Clean up * Keep original stage and group names Clean up _get_yaml_dict * fix comment * Typo * [Unit Tests][Modifier Refactor] (#1756) * Move valid recipes to a helper file Add tests for session.py * Increase test coverage of src/sparseml/core/session.py to 100% Run Style Add logs to .gitignore * Increase coverage of tests/sparseml/core/test_state.py to 100% * add tests for lifecycle/event.py * Increase code coverage of lifecycle/event to 100% * increase lifecycle/session.py code coverage to 93% * Address review comments from @Satrat * Address review comments on 1752 (#1772) Update makefile to only ignore *pytorch.py files in modifier dir Fix order in test Add regex to makefile Add helper function to determine if torch tests should be run Check masks Make transformers import optional in sparsegpt.py * Fix merge conflict * Add more tests to check valid modifiers are created (#1774) * [Bug][ConstantPruningModifier] Fix mask de register bug (#1773) * Fix mask de-register logic * forgot to remove commented out line * Move tests inside pytorch directory as requested * Fix session reset (#1790) * fix datasets version to be compatible with fsspec (#1797) * Add kvcache config for Mistral (#1766) * Add kvcache config for Mistral * Update configs.py * Update configs.py * Fix reset logic * Style after resolving merge conflicts --------- Co-authored-by: Sara Adkins <sara@neuralmagic.com> Co-authored-by: Michael Goin <michael@neuralmagic.com>
- Loading branch information
Showing
10 changed files
with
204 additions
and
12 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
|
||
import pytest | ||
|
||
from sparseml.core.recipe.metadata import ModelMetaData, RecipeMetaData | ||
|
||
|
||
class TestRecipeMetaData: | ||
@pytest.mark.parametrize( | ||
"self_metadata", | ||
[ | ||
dict(domain="cv", task="classification"), | ||
dict(), | ||
], | ||
) | ||
@pytest.mark.parametrize( | ||
"other_metadata", | ||
[ | ||
dict(domain="domain", task="segmentation", requirements=["torch>=1.6.0"]), | ||
dict( | ||
domain="cv", | ||
task="task", | ||
target_model=ModelMetaData(layer_prefix="something"), | ||
), | ||
], | ||
) | ||
def test_update_missing_metadata(self, self_metadata, other_metadata): | ||
|
||
metadata_a = RecipeMetaData(**self_metadata) | ||
metadata_b = RecipeMetaData(**other_metadata) | ||
|
||
metadata_a.update_missing_metadata(metadata_b) | ||
|
||
all_keys = set(self_metadata.keys()).union(other_metadata.keys()) | ||
|
||
# keys should not be overwritten | ||
# if they already exist | ||
for key in all_keys: | ||
if key in self_metadata: | ||
assert getattr(metadata_a, key) == self_metadata[key] | ||
elif key in other_metadata: | ||
assert getattr(metadata_a, key) == other_metadata[key] |