Skip to content

Commit

Permalink
update g_idx
Browse files (browse the repository at this point in the history)
  • Loading branch information
horheynm committed Jun 25, 2024
1 parent 94196b5 commit 2525f69
Showing 1 changed file with 23 additions and 8 deletions.
31 changes: 23 additions & 8 deletions src/sparseml/modifiers/quantization/gptq/utils/gptq_wrapper.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -14,10 +14,11 @@

import time

from torch.nn import Parameter

from sparseml.modifiers.utils import SPARSITY_THRESHOLD
from sparseml.modifiers.utils.compression_wrapper import ModuleCompressionWrapper

from torch.nn import Parameter

try:
import transformers
Expand Down Expand Up @@ -177,34 +178,48 @@ def fasterprune(
group_size = quant_scheme.weights.group_size
if group_size is None or group_size == -1:
group_size = self.layer.weight.shape[1]

if actorder:
g_idx = torch.Tensor([perm[j] // group_size for j in range(self.columns)], dtype=torch.int32, device=invperm.device)
g_idx = torch.tensor(
[perm[j] // group_size for j in range(self.columns)],
dtype=torch.int32,
device=invperm.device
)

g_idx = g_idx[invperm]
self.layer.weight_g_idx = Parameter(g_idx, requires_grad=False,)
self.layer.weight_g_idx = Parameter(
g_idx,
requires_grad=False,
)
else:
g_idx = torch.Tensor([j // group_size for j in range(self.columns)], dtype=torch.int32, device=W.device)
g_idx = torch.Tensor(
[j // group_size for j in range(self.columns)],

device=W.device,
)

from compressed_tensors.quantization import QuantizationStrategy
from compressed_tensors.quantization.lifecycle.forward import (
fake_quantize,
)

strategy = quant_scheme.weights.strategy

breakpoint()
if strategy == QuantizationStrategy.TENSOR:
q = fake_quantize(
q,
scale,
zero_point,
self.layer.quantization_scheme.weights,
g_idx,
)
elif strategy == QuantizationStrategy.CHANNEL:
# TODO: for channelwise why isn't this just a 1d tensor?
q = fake_quantize(
q,
scale[:, 0],
zero_point[:, 0],
# g_idx,
quant_scheme.weights,
)
else: # strategy == QuantizationStrategy.GROUP
Expand All @@ -222,6 +237,7 @@ def fasterprune(
q,
scale[:, input_dim_group],
zero_point[:, input_dim_group],
# g_idx,
altered_qargs,
)

Expand All @@ -247,8 +263,7 @@ def fasterprune(

_LOGGER.info("time %.2f" % (time.time() - tick))
_LOGGER.info("error %.2f" % torch.sum(Losses).item())



if actorder:
W = W[:, invperm]

Expand Down

0 comments on commit 2525f69

Please sign in to comment.