Skip to content

Commit

Permalink
fix: device consistence (#1891)
Browse files Browse the repository at this point in the history
* fix: device consistence

* style: make style on ./optimum/gptq/quantizer.py
  • Loading branch information
Daya-Jin authored Jun 6, 2024
1 parent ac951ca commit 113b645
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions optimum/gptq/quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,7 +432,10 @@ def store_input_hook(_, input, *args):
for data in dataset:
for k, v in data.items():
# put the data on gpu, we won't put them back to cpu
data[k] = v.to(0)
if not has_device_map or device.type == "cpu":
data[k] = v.to(0)
else:
data[k] = v.to(device)
try:
model(**data)
except ValueError:
Expand All @@ -458,7 +461,10 @@ def store_input_hook(_, input, *args):
for data in dataset:
for k, v in data.items():
# put the data on gpu, we won't put them back to cpu
data[k] = v.to(0)
if not has_device_map or device.type == "cpu":
data[k] = v.to(0)
else:
data[k] = v.to(device)
try:
model(**data)
except ValueError:
Expand Down

0 comments on commit 113b645

Please sign in to comment.