diff --git a/src/sparseml/modifiers/quantization/gptq/utils/gptq_wrapper.py b/src/sparseml/modifiers/quantization/gptq/utils/gptq_wrapper.py
index 8b6ae1e5e0..5e8052ffe7 100644
--- a/src/sparseml/modifiers/quantization/gptq/utils/gptq_wrapper.py
+++ b/src/sparseml/modifiers/quantization/gptq/utils/gptq_wrapper.py
@@ -17,6 +17,7 @@
 from sparseml.modifiers.utils import SPARSITY_THRESHOLD
 from sparseml.modifiers.utils.compression_wrapper import ModuleCompressionWrapper
+from torch.nn import Parameter
 
 
 try:
     import transformers
@@ -119,6 +120,7 @@ def fasterprune(
         self.H[dead, dead] = 1
         W[:, dead] = 0
 
+        # Or read from self.layer.quantization_scheme
         if actorder:
             perm = torch.argsort(torch.diag(self.H), descending=True)
             W = W[:, perm]
@@ -135,15 +137,6 @@ def fasterprune(
         self.H = torch.linalg.cholesky(self.H, upper=True)
         Hinv = self.H
 
-        g_idx = []
-        if actorder:
-            g_idx = [perm[i] // quant_scheme.weights.group_size for i in range(self.columns)]
-            g_idx = g_idx[invperm]
-        else:
-            g_idx = [i // quant_scheme.weights.group_size for i in range(self.columns)]
-
-        g_idx = torch.tensor(g_idx, dtype=torch.int32, device=W.device)
-
         # See section 3.4 of https://arxiv.org/abs/2203.07259
         for i1 in range(0, self.columns, blocksize):
             i2 = min(i1 + blocksize, self.columns)
@@ -154,15 +147,6 @@ def fasterprune(
             Err1 = torch.zeros_like(W1)
             Losses1 = torch.zeros_like(W1)
             Hinv1 = Hinv[i1:i2, i1:i2]
-
-            # """
-            # if not channel wise
-
-            # strategy = quant_scheme.weights.strategy
-            # if strategy is not QuantizationStrategy.CHANNEL:
-            #     idx = i
-
-            # """
 
             if preserve_zeros:
                 W1_nz_mask = W_nz_mask[:, i1:i2]
@@ -189,6 +173,18 @@ def fasterprune(
                 if quant_scheme.weights is not None:
                     scale = self.layer.weight_scale
                     zero_point = self.layer.weight_zero_point
+
+                    group_size = quant_scheme.weights.group_size
+                    if group_size is None or group_size == -1:
+                        group_size = self.layer.weight.shape[1]
+
+                    if actorder:
+                        g_idx = torch.tensor([perm[j] // group_size for j in range(self.columns)], dtype=torch.int32, device=invperm.device)
+                        g_idx = g_idx[invperm]
+                        self.layer.weight_g_idx = Parameter(g_idx, requires_grad=False)
+                    else:
+                        g_idx = torch.tensor([j // group_size for j in range(self.columns)], dtype=torch.int32, device=W.device)
+
                     from compressed_tensors.quantization import QuantizationStrategy
                     from compressed_tensors.quantization.lifecycle.forward import (
                         fake_quantize,
@@ -255,7 +251,6 @@ def fasterprune(
 
         if actorder:
             W = W[:, invperm]
-            # g_idx = g_idx[invperm]
 
         if isinstance(self.layer, transformers.Conv1D):
             W = W.t()
@@ -265,7 +260,6 @@ def fasterprune(
         # place, clone() or direct assignment won't work
         self.layer.weight -= self.layer.weight
         self.layer.weight += W
-        self.g_idx = g_idx
 
     def free(self):
         """
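
Note (reviewer sketch, not part of the patch): the hunk at +173 builds a per-column group index so that group-wise quantization scales can be looked up per column even when activation ordering (actorder) permutes the processing order. The snippet below is a minimal, self-contained illustration of that mapping; all names and toy values (rows, columns, hessian_diag) are assumptions made for the example, and the actorder formula shown follows the AutoGPTQ-style convention in which original column c lands in group invperm[c] // group_size.

import torch

rows, columns, group_size = 4, 8, 4
num_groups = columns // group_size

# Without actorder, groups are contiguous blocks of `group_size` columns.
g_idx = torch.tensor([j // group_size for j in range(columns)], dtype=torch.int32)

# One scale per (row, group); indexing with g_idx expands it to one scale
# per column, which is how dequantization can proceed column-by-column.
scale = torch.rand(rows, num_groups)
per_column_scale = scale[:, g_idx.long()]  # shape: (rows, columns)
assert per_column_scale.shape == (rows, columns)

# With actorder, columns are quantized in order of descending Hessian
# diagonal. Original column c is processed at position invperm[c], so it
# falls in group invperm[c] // group_size (AutoGPTQ-style convention).
hessian_diag = torch.rand(columns)
perm = torch.argsort(hessian_diag, descending=True)
invperm = torch.argsort(perm)
g_idx_actorder = torch.div(invperm, group_size, rounding_mode="floor").to(torch.int32)

# Sanity check: position j of the permuted matrix holds original column
# perm[j], which by construction belongs to group j // group_size.
assert all(g_idx_actorder[perm[j]].item() == j // group_size for j in range(columns))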