Commit: docstrings

Sara Adkins committed Sep 11, 2023
1 parent c19d386 commit 83b23f7
Showing 8 changed files with 259 additions and 129 deletions.
14 changes: 14 additions & 0 deletions src/sparseml/experimental/sparsegpt/opt.py
@@ -1,3 +1,17 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import numpy as np
 import torch
 import torch.nn as nn
64 changes: 41 additions & 23 deletions src/sparseml/transformers/sparsification/obcq/data.py
@@ -13,32 +13,31 @@
 # limitations under the License.
 
 import random
+from typing import List, Tuple
 
 import numpy as np
 import torch
 from datasets import load_dataset
-from transformers import AutoTokenizer
+from torch.nn import Module
+from transformers import AutoTokenizer, GPT2Tokenizer
 
 
 __all__ = ["get_wikitext2", "get_ptb", "get_c4"]
 
+# TODO: update these to PyTorch dataloaders
 
-def set_seed(seed):
-    np.random.seed(seed)
-    torch.random.manual_seed(seed)
-
-
-def get_wikitext2(nsamples, seed, seqlen, model):
+def get_wikitext2(
+    nsamples: int, seed: int, seqlen: int, model: Module
+) -> Tuple[List, GPT2Tokenizer]:
+    """
+    load nsamples of tokenized data from the wikitext2 dataset of length seqlen
+
+    :param nsamples: number of samples to load
+    :param seed: seed to use for selecting random samples from dataset
+    :param seqlen: sequence length of each sample
+    :param model: trained pytorch module to load tokenizer from
+    :return: list of random samples from wikitext and tokenizer
+    """
     traindata = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
-    testdata = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
 
     tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False)
     trainenc = tokenizer(" ".join(traindata["text"]), return_tensors="pt")
-    testenc = tokenizer("\n\n".join(testdata["text"]), return_tensors="pt")
-
-    import random
 
     random.seed(seed)
     trainloader = []
@@ -49,16 +48,24 @@ def get_wikitext2(nsamples, seed, seqlen, model):
         tar = inp.clone()
         tar[:, :-1] = -100
         trainloader.append((inp, tar))
-    return trainloader, testenc, tokenizer
+    return trainloader, tokenizer
 
 
-def get_ptb(nsamples, seed, seqlen, model):
-    traindata = load_dataset("ptb_text_only", "penn_treebank", split="train")
-    testdata = load_dataset("ptb_text_only", "penn_treebank", split="test")
+def get_ptb(
+    nsamples: int, seed: int, seqlen: int, model: Module
+) -> Tuple[List, GPT2Tokenizer]:
+    """
+    load nsamples of tokenized data from the ptb dataset of length seqlen
+
+    :param nsamples: number of samples to load
+    :param seed: seed to use for selecting random samples from dataset
+    :param seqlen: sequence length of each sample
+    :param model: trained pytorch module to load tokenizer from
+    :return: list of random samples from ptb and tokenizer
+    """
+    traindata = load_dataset("ptb_text_only", "penn_treebank", split="train")
     tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False)
     trainenc = tokenizer(" ".join(traindata["sentence"]), return_tensors="pt")
-    testenc = tokenizer(" ".join(testdata["sentence"]), return_tensors="pt")
 
     random.seed(seed)
     trainloader = []
@@ -69,10 +76,21 @@ def get_ptb(nsamples, seed, seqlen, model):
         tar = inp.clone()
         tar[:, :-1] = -100
         trainloader.append((inp, tar))
-    return trainloader, testenc, tokenizer
+    return trainloader, tokenizer
 
 
-def get_c4(nsamples, seed, seqlen, model):
+def get_c4(
+    nsamples: int, seed: int, seqlen: int, model: Module
+) -> Tuple[List, GPT2Tokenizer]:
+    """
+    load nsamples of tokenized data from the c4 dataset of length seqlen
+
+    :param nsamples: number of samples to load
+    :param seed: seed to use for selecting random samples from dataset
+    :param seqlen: sequence length of each sample
+    :param model: trained pytorch module to load tokenizer from
+    :return: list of random samples from c4 and tokenizer
+    """
     traindata = load_dataset(
         "allenai/c4",
         "allenai--c4",
@@ -112,4 +130,4 @@ def __init__(self, input_ids):
 
     valenc = TokenizerWrapper(valenc)
 
-    return trainloader, valenc, tokenizer
+    return trainloader, tokenizer
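Taken together, these changes make every loader return just (samples, tokenizer), dropping the unused test encodings, alongside the new type hints and docstrings. A minimal usage sketch with illustrative values not taken from this commit (note that the visible call sites pass a Hugging Face path string for the model argument, despite the Module type hint):

from sparseml.transformers.sparsification.obcq.data import get_wikitext2

model_path = "facebook/opt-1.3b"  # assumed model stub, for illustration only
trainloader, tokenizer = get_wikitext2(
    nsamples=128, seed=0, seqlen=2048, model=model_path
)

# each calibration sample is an (input_ids, labels) pair; labels are set
# to -100 everywhere except the last position, masking them from the loss
inp, tar = trainloader[0]
print(inp.shape)  # (1, 2048)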
@@ -12,8 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import logging
 import inspect
+import logging
 from typing import Dict, List, Tuple
 
 import torch
@@ -157,6 +157,7 @@ def _sequentially_compress(self, **kwargs):
         nsamples = self.inputs.shape[0]
         for name in order:
             gpts = SparseGPT(subset[name])
+
             def add_batch(name):
                 def tmp(_, inp, out):
                     gpts.add_batch(inp[0].data, out.data)
@@ -228,4 +229,4 @@ def _find_layers(module, layers=[nn.Conv2d, nn.Linear], name=""):
                 child, layers=layers, name=name + "." + name1 if name != "" else name1
             )
         )
-    return res
+    return res
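The add_batch closure in the hunk above is PyTorch's forward-hook pattern: during a calibration pass, each layer's inputs and outputs are intercepted and fed to that layer's SparseGPT solver. A self-contained sketch of the same pattern, with placeholder names (collect_batch, captured) that are not from this file:

import torch
import torch.nn as nn

layer = nn.Linear(16, 16)
captured = []

def collect_batch(module, inputs, output):
    # forward hooks receive (module, inputs, output); inputs is a tuple
    captured.append(inputs[0].detach())

handle = layer.register_forward_hook(collect_batch)
layer(torch.randn(4, 16))  # one calibration batch flows through the hook
handle.remove()  # remove the hook once statistics have been gathered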
37 changes: 13 additions & 24 deletions src/sparseml/transformers/sparsification/obcq/manager.py
@@ -17,7 +17,6 @@
 
 import torch
 from torch.nn import Module
-from torch.utils.data import DataLoader
 
 from sparseml.optim import (
     BaseManager,
@@ -37,8 +36,15 @@
 
 class RecipeManagerOneShot(BaseManager):
     """
-    The base recipe manager, handles managing multiple Modifiers.
-    Entire current lifecycle is contained within a call to one_shot()
+    Recipe manager for handling multiple Modifiers called in a one-shot fashion. Call
+    one_shot() to run initialize() for each modifier in recipe.yaml, followed
+    by finalize() for each initialized modifier
+
+    Life-cycle:
+        - from_yaml(recipe.yaml)
+        - one_shot(model, dataloader)
+            - initialize
+            - finalize
     """
 
     @staticmethod
@@ -65,7 +71,7 @@ def from_yaml(
             in the recipe with (i.e. num_epochs, init_lr)
         :metadata: additional (to the information provided in the recipe) data to be
             preserved and utilized in the future - for reproducibility and completeness.
-        :return: ScheduledModifierManager() created from the recipe file
+        :return: RecipeManagerOneShot() created from the recipe file
         """
         recipe_variables = parse_recipe_variables(recipe_variables)
         yaml_str = load_recipe_yaml_str(file_path, **recipe_variables)
@@ -106,6 +112,7 @@ def one_shot(
             _LOGGER.warning("No GPU available, falling back to CPU")
         module.to(device)
 
+        # used by SparseGPTModifier for OBCQ algorithm
        initialize_kwargs = {"calibration_dataloader": data_loader, "device": device}
 
         self.initialize(module, **initialize_kwargs)
@@ -117,34 +124,16 @@ def initialize(
         **kwargs,
     ):
         """
-        Handles any initialization of the manager for the given model.
+        Initializes all modifiers for the given model.
 
-        :param model: the ONNX model to modify
+        :param model: the model to modify
         :param kwargs: Optional kwargs to support specific arguments
             for individual modifiers.
         """
 
         for mod in self.iter_modifiers():
             mod.initialize(module, **kwargs)
 
-    def update(
-        self,
-        module: Module,
-        data_loader: Optional[DataLoader] = None,
-    ):
-        """
-        Handles updating the contained modifiers' states or model
-
-        :param model: model to modify
-        :param data_loader": data loader to be used by modifier
-        """
-
-        for mod in self.iter_modifiers():
-            if not mod.enabled:
-                continue
-
-            mod.update(module, data_loader)
-
     def finalize(self, module: Module = None):
         """
         Handles any finalization of the modifier for the given model.
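Following the life-cycle in the new class docstring, a caller would drive the manager roughly as below. This is a sketch: the recipe path and dataloader are placeholders, and the one_shot() argument names are inferred from the hunks above rather than from the full signature:

import torch

from sparseml.transformers.sparsification.obcq.manager import RecipeManagerOneShot

module = torch.nn.Linear(8, 8)  # stand-in for a real transformer model
data_loader = []                # stand-in for calibration samples

manager = RecipeManagerOneShot.from_yaml("recipe.yaml")  # parses modifiers
# one_shot() runs initialize() on every modifier, then finalize(); the
# dataloader is forwarded to modifiers as calibration_dataloader
manager.one_shot(module, data_loader)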
25 changes: 24 additions & 1 deletion src/sparseml/transformers/sparsification/obcq/models.py
@@ -1,8 +1,31 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import torch
 
 
 __all__ = ["load_opt_model"]
 
-def load_opt_model(model_path):
+
+def load_opt_model(model_path: str) -> torch.nn.Module:
+    """
+    Load a pretrained OPT model from the specified hugging face path
+
+    :param model_path: hugging face path to model
+    :return: loaded pretrained model
+    """
+
     def skip(*args, **kwargs):
         pass
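The skip stub at the end of this hunk is a common fast-loading trick: PyTorch's default weight initializers are patched to no-ops before from_pretrained, since the checkpoint overwrites those weights anyway. A hedged sketch of how the rest of the function likely proceeds (the OPTForCausalLM call and the seqlen attribute are assumptions based on the surrounding files, not shown in this diff):

import torch
from transformers import OPTForCausalLM

def skip(*args, **kwargs):
    pass

def load_opt_model(model_path: str) -> torch.nn.Module:
    # no-op the initializers; loaded weights replace them immediately
    torch.nn.init.kaiming_uniform_ = skip
    torch.nn.init.uniform_ = skip
    torch.nn.init.normal_ = skip
    model = OPTForCausalLM.from_pretrained(model_path, torch_dtype="auto")
    model.seqlen = model.config.max_position_embeddings  # assumed convention
    model.eval()
    return model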
33 changes: 14 additions & 19 deletions src/sparseml/transformers/sparsification/obcq/obcq.py
@@ -18,8 +18,6 @@
 from pathlib import Path
 from typing import Optional
 
-from transformers import OPTForCausalLM
-
 from sparseml.optim.helpers import load_recipe_yaml_str
 from sparseml.transformers.sparsification.obcq.data import (
     get_c4,
@@ -37,16 +35,6 @@
 SUPPORTED_MODELS = ["opt"]
 
 
-def _save(model, tokenizer, save_path, recipe_path):
-    model.save_pretrained(save_path)
-    tokenizer.save_pretrained(save_path)
-
-    _LOGGER.info("Saving output to {}".format(os.path.abspath(save_path)))
-    recipe_output_path = os.path.join(save_path, "recipe.yaml")
-    with open(recipe_output_path, "w") as fp:
-        fp.write(load_recipe_yaml_str(recipe_path))
-
-
 def one_shot(
     model_path: str,
     dataset_name: str,
@@ -70,7 +58,6 @@ def one_shot(
     if deploy_dir.exists():
         raise RuntimeError(f"deploy_dir={deploy_dir} already exists")
 
-    # TODO: don't hardcode this for OPT
     model_loader_fn = None
     if "opt" in model_path:
         model_loader_fn = load_opt_model
@@ -92,7 +79,7 @@ def one_shot(
         f"dataset_name={dataset_name} should be one of {SUPPORTED_DATASETS}"
     )
 
-    calibration_data, test_encoder, tokenizer = data_loader_fn(
+    calibration_data, tokenizer = data_loader_fn(
         num_samples, 0, model.seqlen, model_path
     )
 
@@ -102,20 +89,28 @@ def one_shot(
     _save(model, tokenizer, deploy_dir, recipe_file)
 
 
+def _save(model, tokenizer, save_path, recipe_path):
+    model.save_pretrained(save_path)
+    tokenizer.save_pretrained(save_path)
+
+    _LOGGER.info("Saving output to {}".format(os.path.abspath(save_path)))
+    recipe_output_path = os.path.join(save_path, "recipe.yaml")
+    with open(recipe_output_path, "w") as fp:
+        fp.write(load_recipe_yaml_str(recipe_path))
+
+
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
 
-    parser.add_argument(
-        "model", type=str, help="OPT model to load; pass `facebook/opt-X`."
-    )
+    parser.add_argument("model", type=str, help="Hugging Face stub of model to load")
     parser.add_argument(
         "dataset",
         type=str,
         choices=["wikitext2", "ptb", "c4"],
-        help="Where to extract calibration data from.",
+        help="Name of dataset to extract calibration data from",
     )
     parser.add_argument(
-        "--nsamples", type=int, default=128, help="Number of calibration data samples."
+        "--nsamples", type=int, default=128, help="Number of calibration data samples"
     )
     parser.add_argument("--device", type=str, default="cuda:0")
     parser.add_argument("--deploy-dir", type=str, default=".")
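End to end, the script can also be driven from Python rather than the CLI. A sketch with placeholder arguments; the keyword names below are inferred from the visible hunks and may differ from the actual signature:

from sparseml.transformers.sparsification.obcq.obcq import one_shot

one_shot(
    model_path="facebook/opt-1.3b",   # any stub containing "opt"
    dataset_name="c4",                # one of wikitext2, ptb, c4
    num_samples=128,
    device="cuda:0",
    deploy_dir="obcq_deployment",     # must not already exist
    recipe_file="recipe.yaml",
)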