From 113b645dc7d0b7710803f23ffbf937ce6461ed1e Mon Sep 17 00:00:00 2001 From: GoldenTeethCN Date: Thu, 6 Jun 2024 16:42:52 +0800 Subject: [PATCH] fix: device consistence (#1891) * fix: device consistence * style: make style on ./optimum/gptq/quantizer.py --- optimum/gptq/quantizer.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/optimum/gptq/quantizer.py b/optimum/gptq/quantizer.py index 2c2c9d7e71..902af87bbb 100644 --- a/optimum/gptq/quantizer.py +++ b/optimum/gptq/quantizer.py @@ -432,7 +432,10 @@ def store_input_hook(_, input, *args): for data in dataset: for k, v in data.items(): # put the data on gpu, we won't put them back to cpu - data[k] = v.to(0) + if not has_device_map or device.type == "cpu": + data[k] = v.to(0) + else: + data[k] = v.to(device) try: model(**data) except ValueError: @@ -458,7 +461,10 @@ def store_input_hook(_, input, *args): for data in dataset: for k, v in data.items(): # put the data on gpu, we won't put them back to cpu - data[k] = v.to(0) + if not has_device_map or device.type == "cpu": + data[k] = v.to(0) + else: + data[k] = v.to(device) try: model(**data) except ValueError: