diff --git a/optimum/gptq/quantizer.py b/optimum/gptq/quantizer.py index 2c2c9d7e71..902af87bbb 100644 --- a/optimum/gptq/quantizer.py +++ b/optimum/gptq/quantizer.py @@ -432,7 +432,10 @@ def store_input_hook(_, input, *args): for data in dataset: for k, v in data.items(): # put the data on gpu, we won't put them back to cpu - data[k] = v.to(0) + if not has_device_map or device.type == "cpu": + data[k] = v.to(0) + else: + data[k] = v.to(device) try: model(**data) except ValueError: @@ -458,7 +461,10 @@ def store_input_hook(_, input, *args): for data in dataset: for k, v in data.items(): # put the data on gpu, we won't put them back to cpu - data[k] = v.to(0) + if not has_device_map or device.type == "cpu": + data[k] = v.to(0) + else: + data[k] = v.to(device) try: model(**data) except ValueError: