overwrite initialized g_idx in gptq_wrapper
horheynm committed Jun 25, 2024
1 parent 298ff88 commit 94196b5
Showing 1 changed file with 14 additions and 20 deletions.
34 changes: 14 additions & 20 deletions src/sparseml/modifiers/quantization/gptq/utils/gptq_wrapper.py
@@ -17,6 +17,7 @@
from sparseml.modifiers.utils import SPARSITY_THRESHOLD
from sparseml.modifiers.utils.compression_wrapper import ModuleCompressionWrapper

+ from torch.nn import Parameter

try:
import transformers
@@ -119,6 +120,7 @@ def fasterprune(
self.H[dead, dead] = 1
W[:, dead] = 0

+ # Or read from self.layer.quantization_scheme
if actorder:
perm = torch.argsort(torch.diag(self.H), descending=True)
W = W[:, perm]
@@ -135,15 +137,6 @@
self.H = torch.linalg.cholesky(self.H, upper=True)
Hinv = self.H


- g_idx = []
- if actorder:
-     g_idx = [perm[i] // quant_scheme.weights.group_size for i in range(self.columns)]
-     g_idx = g_idx[invperm]
- else:
-     g_idx = [i // quant_scheme.weights.group_size for i in range(self.columns)]
- g_idx = torch.tensor(g_idx, dtype=torch.int32, device=W.device)
-
# See section 3.4 of https://arxiv.org/abs/2203.07259
for i1 in range(0, self.columns, blocksize):
i2 = min(i1 + blocksize, self.columns)
@@ -154,15 +147,6 @@
Err1 = torch.zeros_like(W1)
Losses1 = torch.zeros_like(W1)
Hinv1 = Hinv[i1:i2, i1:i2]

# """
# if not channel wise

# strategy = quant_scheme.weights.strategy
# if strategy is not QuantizationStrategy.CHANNEL:
# idx = i

# """

if preserve_zeros:
W1_nz_mask = W_nz_mask[:, i1:i2]
@@ -189,6 +173,18 @@
if quant_scheme.weights is not None:
scale = self.layer.weight_scale
zero_point = self.layer.weight_zero_point

+ group_size = quant_scheme.weights.group_size
+ if group_size is None or group_size == -1:
+     group_size = self.layer.weight.shape[1]
+
+ if actorder:
+     g_idx = torch.tensor([perm[j] // group_size for j in range(self.columns)], dtype=torch.int32, device=invperm.device)
+     g_idx = g_idx[invperm]
+     self.layer.weight_g_idx = Parameter(g_idx, requires_grad=False)
+ else:
+     g_idx = torch.tensor([j // group_size for j in range(self.columns)], dtype=torch.int32, device=W.device)

from compressed_tensors.quantization import QuantizationStrategy
from compressed_tensors.quantization.lifecycle.forward import (
fake_quantize,
@@ -255,7 +251,6 @@ def fasterprune(

if actorder:
W = W[:, invperm]
- # g_idx = g_idx[invperm]

if isinstance(self.layer, transformers.Conv1D):
W = W.t()
@@ -265,7 +260,6 @@
# place, clone() or direct assignment won't work
self.layer.weight -= self.layer.weight
self.layer.weight += W
- self.g_idx = g_idx

def free(self):
"""

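For reference, a minimal standalone sketch of the g_idx bookkeeping this commit adds. It is not part of the diff: the toy `columns` and `group_size` values and the `weight_g_idx` name used as a local variable are illustrative, with `perm` standing in for the Hessian-based activation ordering and `invperm` undoing it.

    import torch
    from torch.nn import Parameter

    columns, group_size = 8, 4
    perm = torch.randperm(columns)   # stand-in for argsort(diag(H), descending=True)
    invperm = torch.argsort(perm)    # inverse permutation back to original column order

    # Group index of each column in the permuted order, then mapped back with
    # invperm so it lines up with the layer's original columns, as in the diff.
    g_idx = torch.tensor([perm[j] // group_size for j in range(columns)], dtype=torch.int32)
    g_idx = g_idx[invperm]

    # A non-trainable Parameter; integer tensors are allowed when requires_grad=False.
    weight_g_idx = Parameter(g_idx, requires_grad=False)

The switch from the wrapper attribute (the deleted `self.g_idx = g_idx`) to a Parameter on the layer itself presumably lets the mapping persist in the module's state dict alongside `weight_scale` and `weight_zero_point`.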