
Bug report: tensor on different GPU when using default device_map #194

Closed
yyfcc17 opened this issue Nov 15, 2023 · 7 comments

Comments

yyfcc17 commented Nov 15, 2023

In quantizer.py, the .cuda() call targets GPU 0, so when a block has been dispatched to another GPU, like GPU 1, this leads to a device-mismatch error.

Have you tested on multiple GPUs?

Or is a single GPU enough during quantization? That is, move one block from CPU to GPU 0, then move it back to CPU once the quantization of that block is done?
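
Something like this rough sketch is what I have in mind (quantize_block and the loop are just placeholders, not the real AutoAWQ API):

```python
import torch

def quantize_blocks_one_gpu(blocks, quantize_block, device="cuda:0"):
    # keep the whole model on CPU and only place the block being quantized on the GPU
    for block in blocks:
        block.to(device)            # move this block to one fixed GPU
        quantize_block(block)       # run the per-block quantization there
        block.to("cpu")             # move it back so the next block has room
        torch.cuda.empty_cache()    # release cached GPU memory between blocks
```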

@casper-hansen (Owner)

Do you have a traceback? I have tested multi-GPU. The .cuda() method is supposed to move the tensor to the current GPU, as opposed to .to(), which lets you specify which GPU.
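
For reference, a quick illustration of that difference (assumes a machine with at least two visible GPUs):

```python
import torch

x = torch.randn(4, 4)

# .cuda() with no argument targets the *current* device, which is cuda:0
# unless it has been changed with torch.cuda.set_device()
print(x.cuda().device)        # cuda:0

torch.cuda.set_device(1)
print(x.cuda().device)        # cuda:1, because the current device changed

# .to() pins an explicit target, independent of the current device
print(x.to("cuda:1").device)  # cuda:1
```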

yyfcc17 (Author) commented Nov 15, 2023

Here is my case:

At init, I have a block dispatched to GPU 1.

When .cuda() is called at this line: https://github.com/casper-hansen/AutoAWQ/blob/main/awq/quantize/quantizer.py#L73

the block's parameters are moved to GPU 0 (inferred: the input tensor seems to still be on GPU 1).

Even though _get_input_feat in quantizer.py also moves inps to the block's parameter device (which is now GPU 0),

after running https://github.com/casper-hansen/AutoAWQ/blob/main/awq/quantize/quantizer.py#L355

it runs into an error. I have checked: it is because the layernorm parameters (in this block) are on GPU 0, but the input tensor is on GPU 1.
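
A guard along these lines avoids the mismatch (a sketch with illustrative names, not the actual quantizer.py code):

```python
import torch

def run_block_on_its_device(block, inps, **kwargs):
    # find out where the block's weights actually ended up after .cuda()/dispatch
    block_device = next(block.parameters()).device
    # move the cached activations (and any tensor kwargs) to that same device,
    # otherwise the layernorm inside the block fails with a device-mismatch error
    inps = inps.to(block_device)
    kwargs = {k: (v.to(block_device) if torch.is_tensor(v) else v) for k, v in kwargs.items()}
    return block(inps, **kwargs)
```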

yyfcc17 (Author) commented Nov 15, 2023

Maybe the real question is why .cuda() moves a block from GPU 1 to GPU 0, rather than leaving it where it is.

And what does "current device" mean in the PyTorch docs?
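
From what I can tell, the "current device" is just a default that torch.cuda calls fall back to, e.g. (assuming a second GPU is present):

```python
import torch

print(torch.cuda.current_device())  # 0 by default: this is the "current device"

torch.cuda.set_device(1)            # changes the default for subsequent CUDA calls
print(torch.cuda.current_device())  # now 1

# so block.cuda() behaves like block.to(torch.cuda.current_device()),
# which is why a block dispatched to cuda:1 can silently end up on cuda:0
```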

@casper-hansen (Owner)

This may just come down to how the model is initialized. Which version of AutoAWQ are you using? I pushed a big change for multi-GPU to the main branch yesterday.

yyfcc17 (Author) commented Nov 15, 2023

The last commit in git log is "Fix multi-GPU loading and inference", which is your version from yesterday.

The model I am trying to support is chatglm2-6b. I don't know if this line in the chatglm2-6b modeling code causes the problem:
https://huggingface.co/THUDM/chatglm2-6b/blob/main/modeling_chatglm.py#L349

In any case, after my workaround for this issue, the model was quantized to int4 and inference looks good. Thanks for your work.

I tried the official llm-awq first; its kernel implementation seems to have some bug. Your kernel implementation runs smoothly right now 👍

@casper-hansen (Owner)

I just fixed this! #196

I found that .cuda() was the problem, just like you said. I thought the block would already be on the right device, but .cuda() puts everything on cuda:0 (I do not understand why, though).
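
For anyone hitting the same thing, the change is roughly along these lines (a sketch, not necessarily the exact code in #196):

```python
import torch

def move_block_for_quantization(block, inps):
    # reuse whatever device the block was already dispatched to instead of
    # calling block.cuda(), which always targets the current device (cuda:0 by default)
    device = next(block.parameters()).device
    if device.type == "cpu":
        # only blocks that still live on CPU need to be moved to a GPU
        device = torch.device("cuda:0")
        block = block.to(device)
    inps = inps.to(device)  # keep the cached inputs on the same device as the block
    return block, inps
```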

yyfcc17 closed this as completed Nov 17, 2023
user-ZJ commented Jan 17, 2024

Can you share the code you added for the chatglm2-6b model?
