Skip to content

Commit

Permalink
Accept device for Int8DynActInt4WeightQuantizer (pytorch#475)
Browse files Browse the repository at this point in the history
Before we migrate away from the `Quantizer` APIs, we want to
align the `__init__` arguments between `Int8DynActInt4WeightQuantizer`
and `Int4WeightOnlyQuantizer` so the two classes are easier for users to work with interchangeably.
  • Loading branch information
larryliu0820 authored Jul 3, 2024
1 parent 32a6503 commit a9907f1
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions torchao/quantization/GPTQ.py
Original file line number Diff line number Diff line change
Expand Up @@ -942,12 +942,14 @@ def __init__(
padding_allowed: bool = False,
precision: torch.dtype = torch.float32,
scales_precision: torch.dtype = torch.float32,
device: torch.device = torch.device("cpu"),
) -> None:
super().__init__()
self.groupsize: int = groupsize
self.padding_allowed: bool = padding_allowed
self.precision: torch.dtype = precision
self.scales_precision: torch.dtype = scales_precision
self.device: torch.device = device

@torch.no_grad()
def _create_quantized_state_dict(
Expand Down Expand Up @@ -988,9 +990,9 @@ def _create_quantized_state_dict(
self.groupsize,
self.scales_precision,
)
cur_state_dict[f"{fqn}.weight"] = weight_int8.to("cpu")
cur_state_dict[f"{fqn}.scales"] = scales.to("cpu")
cur_state_dict[f"{fqn}.zeros"] = zeros.to("cpu")
cur_state_dict[f"{fqn}.weight"] = weight_int8.to(self.device)
cur_state_dict[f"{fqn}.scales"] = scales.to(self.device)
cur_state_dict[f"{fqn}.zeros"] = zeros.to(self.device)
# TODO: support bias?

return cur_state_dict
Expand Down

0 comments on commit a9907f1

Please sign in to comment.