diff --git a/torchao/quantization/GPTQ.py b/torchao/quantization/GPTQ.py index 018688a67..08e4f35d8 100644 --- a/torchao/quantization/GPTQ.py +++ b/torchao/quantization/GPTQ.py @@ -942,12 +942,14 @@ def __init__( padding_allowed: bool = False, precision: torch.dtype = torch.float32, scales_precision: torch.dtype = torch.float32, + device: torch.device = torch.device("cpu"), ) -> None: super().__init__() self.groupsize: int = groupsize self.padding_allowed: bool = padding_allowed self.precision: torch.dtype = precision self.scales_precision: torch.dtype = scales_precision + self.device: torch.device = device @torch.no_grad() def _create_quantized_state_dict( @@ -988,9 +990,9 @@ def _create_quantized_state_dict( self.groupsize, self.scales_precision, ) - cur_state_dict[f"{fqn}.weight"] = weight_int8.to("cpu") - cur_state_dict[f"{fqn}.scales"] = scales.to("cpu") - cur_state_dict[f"{fqn}.zeros"] = zeros.to("cpu") + cur_state_dict[f"{fqn}.weight"] = weight_int8.to(self.device) + cur_state_dict[f"{fqn}.scales"] = scales.to(self.device) + cur_state_dict[f"{fqn}.zeros"] = zeros.to(self.device) # TODO: support bias? return cur_state_dict