ValueError: Specified device cuda:0 does not match device of data cuda:-2 #423

Closed
zzzacwork opened this issue Jun 14, 2022 · 28 comments · Fixed by k2-fsa/k2#1012

@zzzacwork

I am trying to fine-tune the pre-trained model from the gigaspeech recipe, but encountered the above error. Below is the traceback:

2022-06-14 15:53:20,088 INFO [train.py:734] Epoch 0, batch 0, loss[loss=0.3458, simple_loss=0.3931, pruned_loss=0.1492, over 2345.00 frames.], tot_loss[loss=0.3458, simple_loss=0.3931, pruned_loss=0.1492, over 2345.00 frames.], batch size: 13, lr: 1.14e-04
2022-06-14 15:54:07,373 INFO [train.py:754] Computing validation loss
cuda:0
Specified device cuda:0 does not match device of data cuda:-2
Traceback (most recent call last):
  File "./pruned_transducer_stateless2/train.py", line 988, in <module>
    main()
  File "./pruned_transducer_stateless2/train.py", line 981, in main
    run(rank=0, world_size=1, args=args)
  File "./pruned_transducer_stateless2/train.py", line 892, in run
    train_one_epoch(
  File "./pruned_transducer_stateless2/train.py", line 755, in train_one_epoch
    valid_info = compute_validation_loss(
  File "./pruned_transducer_stateless2/train.py", line 609, in compute_validation_loss
    loss, loss_info = compute_loss(
  File "./pruned_transducer_stateless2/train.py", line 552, in compute_loss
    simple_loss, pruned_loss = model(
  File "/home/zacwork/anaconda3/envs/icefall/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/zacwork/icefall/egs/gigaspeech/ASR/pruned_transducer_stateless2/model.py", line 144, in forward
    raise e
  File "/home/zacwork/icefallegs/gigaspeech/ASR/pruned_transducer_stateless2/model.py", line 140, in forward
    y_padded = y.pad(mode="constant", padding_value=0)

I tried to look into the code but could not find any clue about this error. My guess is some of the cuts from my dev set trigger it. Could you please point me to the relevant place to further debug this issue?

Thanks

@csukuangfj
Collaborator

File "/home/zacwork/icefallegs/gigaspeech/ASR/pruned_transducer_stateless2/model.py", line 140, in forward
y_padded = y.pad(mode="constant", padding_value=0)

Could you insert

print(y.device)

before y_padded = y.pad(mode="constant", padding_value=0) and show the output?
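For context, a minimal sketch of where that debug print would sit in the model's forward() (the surrounding line is paraphrased from model.py; only the print is the suggested change):

print(y.device)  # should report cuda:0 if the ragged tensor really lives on the GPU
y_padded = y.pad(mode="constant", padding_value=0)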

@zzzacwork
Author

I tried that and it shows "cuda:0"

@csukuangfj
Copy link
Collaborator

csukuangfj commented Jun 15, 2022

What is the output of

python3 -m k2.version

[EDITED]: There is a similar error reported at https://discuss.pytorch.org/t/allocation-of-tensor-on-cuda-fails/144204
That post suggests first performing the operations on the CPU and then moving the resulting tensor to GPU.

In your case, training runs successfully but the validation stage fails, which is very odd.
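For reference, a minimal sketch of that CPU round-trip workaround from the linked post (illustrative only; it sidesteps rather than fixes the underlying problem):

import torch
import k2.ragged as k2r

device = torch.device("cuda:0")
y = k2r.create_ragged_tensor([[1, 2, 3], [], [4]], device=device)

# pad on the CPU, then move the resulting dense tensor back to the GPU
y_padded = y.to(torch.device("cpu")).pad(mode="constant", padding_value=0).to(device)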

@zzzacwork
Author

Collecting environment information...

k2 version: 1.15.1
Build type: Release
Git SHA1: b173c11ba379e2da0056281fe6b2d56f081419be
Git date: Mon Apr 18 16:10:45 2022
Cuda used to build k2: 11.3
cuDNN used to build k2: 8.2.0
Python version used to build k2: 3.8
OS used to build k2: Ubuntu 18.04.6 LTS
CMake version: 3.18.4
GCC version: 7.5.0
CMAKE_CUDA_FLAGS:  -Wno-deprecated-gpu-targets  --expt-extended-lambda -gencode arch=compute_35,code=sm_35 --expt-extended-lambda -gencode arch=compute_50,code=sm_50 --expt-extended-lambda -gencode arch=compute_60,code=sm_60 --expt-extended-lambda -gencode arch=compute_61,code=sm_61 --expt-extended-lambda -gencode arch=compute_70,code=sm_70 --expt-extended-lambda -gencode arch=compute_75,code=sm_75 --expt-extended-lambda -gencode arch=compute_80,code=sm_80 --expt-extended-lambda -gencode arch=compute_86,code=sm_86 -DONNX_NAMESPACE=onnx_c2 -gencode arch=compute_35,code=sm_35 -gencode arch=compute_50,code=sm_50 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_61,code=sm_61 -gencode arch=compute_70,code=sm_70 -gencode arch=compute_75,code=sm_75 -gencode arch=compute_80,code=sm_80 -gencode arch=compute_86,code=sm_86 -gencode arch=compute_86,code=compute_86 -Xcudafe --diag_suppress=cc_clobber_ignored,--diag_suppress=integer_sign_change,--diag_suppress=useless_using_declaration,--diag_suppress=set_but_not_used,--diag_suppress=field_without_dll_interface,--diag_suppress=base_class_has_different_dll_interface,--diag_suppress=dll_interface_conflict_none_assumed,--diag_suppress=dll_interface_conflict_dllexport_assumed,--diag_suppress=implicit_return_from_non_void_function,--diag_suppress=unsigned_compare_with_zero,--diag_suppress=declared_but_not_referenced,--diag_suppress=bad_friend_decl --expt-relaxed-constexpr --expt-extended-lambda -D_GLIBCXX_USE_CXX11_ABI=0 --compiler-options -Wall  --compiler-options -Wno-strict-overflow  --compiler-options -Wno-unknown-pragmas 
CMAKE_CXX_FLAGS:  -D_GLIBCXX_USE_CXX11_ABI=0 -Wno-unused-variable  -Wno-strict-overflow 
PyTorch version used to build k2: 1.11.0
PyTorch is using Cuda: 11.3
NVTX enabled: True
With CUDA: True
Disable debug: True
Sync kernels : False
Disable checks: False
Max cpu memory allocate: 214748364800

k2.version shows the above info.

@zzzacwork
Author

What is the output of

python3 -m k2.version

[EDITED]: There is a similar error reported at https://discuss.pytorch.org/t/allocation-of-tensor-on-cuda-fails/144204 That post suggests first performing the operations on the CPU and then moving the resulting tensor to GPU.

In your case, training runs successfully but the validation stage fails, which is very odd.

I tried y.to(device("cuda:0")), but it still fails; it might be worth trying the CPU conversion suggested in that post. My first thought was that the error might come from the k2 computations, since y is a ragged tensor.

@zzzacwork
Author

I tried the following solutions but the issue still persists. I am not sure whether it is related to the dataset I prepared or to the functions invoked by that line of code, y_padded = y.pad(mode="constant", padding_value=0).

Solution 1: As suggested by @csukuangfj in #247, I switched to PyTorch 1.10 and CUDA 10.2; the error changes to the following (other parts of the stack trace are the same):

    y_padded = y.pad(mode="constant", padding_value=0)                                                                                                
RuntimeError: CUDA error: invalid argument 

Solution 2: As suggested in https://discuss.pytorch.org/t/allocation-of-tensor-on-cuda-fails/144204, I added lines to move the tensor to the CPU and back, but the error is still there.

Solution 3: As suggested in #247, I added back the function call to filter out input segments shorter than 1s or longer than 20s, but that doesn't solve the issue either (same error message).

Could you please guide me on how to debug this, or suggest a workaround? Any help is highly appreciated.

Thanks,

@danpovey
Collaborator

This could be an error from a previous kernel. Perhaps try doing
export K2_SYNC_KERNELS=1
export CUDA_LAUNCH_BLOCKING=1
and hopefully the error might show up a bit earlier.

@csukuangfj
Collaborator

Did you install k2 from source? If so, is the machine you used to build k2 the same as the one you are using for training?

@zzzacwork
Author

I installed k2 using conda

conda install -c k2-fsa -c pytorch -c conda-forge k2 python=3.8 cudatoolkit=10.2 pytorch=1.10.0

Do you suggest building from source? The k2.version output is below:

Collecting environment information...
k2 version: 1.16
Build type: Release
Git SHA1: 89300f06d8758a286b809b55532d52c40d88e82a
Git date: Sun Jun 19 11:45:43 2022
Cuda used to build k2: 10.2
cuDNN used to build k2: 8.0.2
Python version used to build k2: 3.8
OS used to build k2: Ubuntu 18.04.6 LTS
CMake version: 3.18.4
GCC version: 7.5.0
CMAKE_CUDA_FLAGS:   -lineinfo --expt-extended-lambda -use_fast_math -Xptxas=-w  --expt-extended-lambda -gencode arch=compute_35,code=sm_35  -lineinfo --expt-extended-lambda -use_fast_math -Xptxas=-w  --expt-extended-lambda -gencode arch=compute_50,code=sm_50  -lineinfo --expt-extended-lambda -use_fast_math -Xptxas=-w  --expt-extended-lambda -gencode arch=compute_60,code=sm_60  -lineinfo --expt-extended-lambda -use_fast_math -Xptxas=-w  --expt-extended-lambda -gencode arch=compute_61,code=sm_61  -lineinfo --expt-extended-lambda -use_fast_math -Xptxas=-w  --expt-extended-lambda -gencode arch=compute_70,code=sm_70  -lineinfo --expt-extended-lambda -use_fast_math -Xptxas=-w  --expt-extended-lambda -gencode arch=compute_75,code=sm_75 -D_GLIBCXX_USE_CXX11_ABI=0 --compiler-options -Wall  --compiler-options -Wno-strict-overflow  --compiler-options -Wno-unknown-pragmas 
CMAKE_CXX_FLAGS:  -D_GLIBCXX_USE_CXX11_ABI=0 -Wno-unused-variable  -Wno-strict-overflow 
PyTorch version used to build k2: 1.10.0
PyTorch is using Cuda: 10.2
NVTX enabled: True
With CUDA: True
Disable debug: True
Sync kernels : True
Disable checks: False
Max cpu memory allocate: 214748364800
k2 abort: False

@zzzacwork
Author

This could be an error from a previous kernel. Perhaps try doing export K2_SYNC_KERNELS=1 export CUDA_LAUNCH_BLOCKING=1 and hopefully the error might show up a bit earlier.

I added those two options, but the stack trace doesn't seem to include any more information. I pasted it below:

2022-06-23 07:23:31,341 INFO [train.py:730] (3/4) Epoch 0, batch 17500, loss[loss=0.2263, simple_loss=0.2928, pruned_loss=0.07986, over 2219.00 frames.], tot_loss[loss=0.2303, simple_loss=0.2954, pruned_loss=0.0826, over 4256698.56 frames.], batch size: 23, lr: 3.78e-05
Traceback (most recent call last):                                                                                                                    
  File "./pruned_transducer_stateless2/train.py", line 996, in <module>                                                                               
    main()                                                                                                                                            
  File "./pruned_transducer_stateless2/train.py", line 987, in main                                                                                   
    mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)   
  File "/home/azureuser/anaconda3/envs/icefall-torch1.10cu102/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 230, in spawn
    return start_processes(fn, args, nprocs, join, daemon, start_method='spawn')                                                            
  File "/home/azureuser/anaconda3/envs/icefall-torch1.10cu102/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 188, in start_processe
s                                                                                                                                                     
    while not context.join():                                                                                                                         
  File "/home/azureuser/anaconda3/envs/icefall-torch1.10cu102/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 150, in join          
    raise ProcessRaisedException(msg, error_index, failed_process.pid)                                                                                
torch.multiprocessing.spawn.ProcessRaisedException:                                                                                                   
                                                                                                                                                      
-- Process 3 terminated with the following error:                                                                                                     
Traceback (most recent call last):                                                                                                                    
  File "/home/azureuser/anaconda3/envs/icefall-torch1.10cu102/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 59, in _wrap   
    fn(i, *args)                                                                                                                                      
  File "/home/azureuser/codebase/icefall-birch/egs/gigaspeech/ASR/pruned_transducer_stateless2/train.py", line 900, in run                       
    train_one_epoch(                                                                                                                                  
  File "/home/azureuser/codebase/icefall-birch/egs/gigaspeech/ASR/pruned_transducer_stateless2/train.py", line 683, in train_one_epoch         
    loss, loss_info = compute_loss(                                                                                                                   
  File "/home/azureuser/codebase/icefall-birch/egs/gigaspeech/ASR/pruned_transducer_stateless2/train.py", line 548, in compute_loss              
    simple_loss, pruned_loss = model(                                                                                                                 
  File "/home/azureuser/anaconda3/envs/icefall-torch1.10cu102/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/azureuser/anaconda3/envs/icefall-torch1.10cu102/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 886, in forward
    output = self.module(*inputs[0], **kwargs[0])     
  File "/home/azureuser/anaconda3/envs/icefall-torch1.10cu102/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/azureuser/codebase/icefall-birch/egs/gigaspeech/ASR/pruned_transducer_stateless2/model.py", line 138, in forward
    y_padded = y.pad(mode="constant", padding_value=0)
RuntimeError: CUDA error: invalid argument

@danpovey
Collaborator

You could run it inside pdb ("python3 -m pdb [program] [args]") and try to print out the dimensions of y.
I assume you
export CUDA_LAUNCH_BLOCKING=1
export K2_SYNC_KERNELS=1
from your shell, and then ran the python code from that same shell.
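For example, a possible pdb session at a breakpoint on the y.pad(...) line, printing the device plus the shape information used elsewhere in this thread (y.shape.max_size(1) is the longest transcript in the batch, so 0 would mean every row is empty):

(Pdb) p y.device
(Pdb) p y.shape.max_size(1)
(Pdb) p y.shape.row_splits(1)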

@zzzacwork
Author

You could run it inside pdb ("python3 -m pdb [program] [args]") and try to print out the dimensions of y. I assume you export CUDA_LAUNCH_BLOCKING=1 export K2_SYNC_KERNELS=1 from your shell, and then ran the python code from that same shell.

Yes, that was within the same shell. I haven't tried pdb yet.

I built k2 from source with PyTorch 1.10.0, CUDA 10.2, and cuDNN 7.6.5. The exception disappears and training continues normally.

@zzzacwork
Author

Updates:

The exception re-appears during the second epoch (epoch 0 finished without error).
I am looking into this by switching to CPU training; hopefully that will give me more traceable logs.

@zzzacwork
Author

zzzacwork commented Jul 8, 2022

Updates on this issue:

This seems to be a bug in the pad function with empty tensors (our training data have background noise inputs that have empty transcripts).
To reproduce this bug:

>>> import k2.ragged as k2r
>>> z = k2r.create_ragged_tensor([[], [], []], device='cuda:0')
>>> z.device
device(type='cuda', index=0)
>>> z.pad(mode="constant", padding_value=0)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: CUDA error: invalid argument
CUDA kernel errors might be asynchronously reported at some other API call,so the stacktrace below might be incorrect.

If we convert the tensor to the CPU then it works, so I guess this is a bug related to PyTorch.
My temporary solution to this issue is to pad empty input tensors with one 0 token, inserted into the forward function in model.py:

if y.shape.max_size(1) == 0:
    y = concat(y, value=0, direction="right")

But then I noticed that training (pruned-rnn-t-stateless2) gives an inf loss on those empty utterances. I went back to the train_bpe code and found there are several special tokens such as <SPOKEN NOISE>, but it was pruned out of the lexicon, so I am not sure whether I can insert that token into those empty inputs. Do you have any suggestions on this issue?
Thank you in advance.

@csukuangfj
Collaborator

our training data have background noise inputs that have empty transcripts.

Shall we filter out such inputs from the training data?
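If you go that route, a hedged sketch of such a filter with lhotse (assuming each cut keeps its transcript in cut.supervisions[0].text and that train_cuts is your training CutSet; adjust to your manifests):

def has_transcript(cut) -> bool:
    # keep only cuts that have at least one supervision with non-empty text
    return len(cut.supervisions) > 0 and (cut.supervisions[0].text or "").strip() != ""

train_cuts = train_cuts.filter(has_transcript)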

@zzzacwork
Author

our training data have background noise inputs that have empty transcripts.

Shall we filter out such inputs from the training data?

Thanks for the advice; filtering should work, I think. But we also found in previous experiments with wav2vec models that those noisy inputs are very helpful. We often have noisy inputs in production as well, and by training on them the model learns to decode nothing rather than some random words. I am currently trying to retrain the BPE system with an additional special token '' added to those empty samples.

@csukuangfj
Collaborator

My temporary solution to this issue is to pad empty input tensors with one 0 token,

In that case, I suggest padding it with at least s_range 0 tokens. Only 1 is not enough for computing the loss.
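A hedged sketch of that idea applied to the raw token lists before the ragged tensor is built (pad_empty_transcripts and its arguments are illustrative names, not part of the recipe):

import k2

def pad_empty_transcripts(token_lists, s_range=5, blank_id=0):
    # replace an empty transcript with s_range blank tokens so the pruned
    # loss has enough symbols to form its ranges
    padded = [toks if len(toks) > 0 else [blank_id] * s_range for toks in token_lists]
    return k2.RaggedTensor(padded)

y = pad_empty_transcripts([[5, 9, 2], [], [7]])  # the empty row becomes five 0 tokens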

@zzzacwork
Author

Does the 0 token also serve as the termination token when computing the loss? I got that impression because normally we also have padding tokens at the end of each y tensor, and those tokens probably don't contribute to the loss computation.

@csukuangfj
Collaborator

boundary[:, 2] = y_lens
boundary[:, 3] = x_lens

Only things within the given boundary contribute to the loss.
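For illustration, a self-contained toy showing how that boundary tensor is laid out (the numbers are made up; the columns follow the [begin_symbol, begin_frame, end_symbol, end_frame] convention used by k2's pruned RNN-T loss):

import torch

# toy per-utterance lengths standing in for a real batch
y_lens = torch.tensor([12, 7, 0])     # unpadded token counts (0 = empty transcript)
x_lens = torch.tensor([100, 80, 50])  # valid encoder frames per utterance

boundary = torch.zeros((y_lens.size(0), 4), dtype=torch.int64)
boundary[:, 2] = y_lens  # tokens past this index are padding and are ignored
boundary[:, 3] = x_lens  # frames past this index are padding and are ignored
print(boundary)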

@zzzacwork
Author

boundary[:, 2] = y_lens
boundary[:, 3] = x_lens

Only things within the given boundary contribute to the loss.

Thank you for the explanation!

I will try padding those empty inputs with s_range tokens. By the way, I tried adding the <unk> token to the supervisions, but got an error on this line:

loss_value = tot_loss["loss"] / tot_loss["frames"]

It complains about division by 0. Does that mean the x_lens (frames) are all 0?

@csukuangfj
Collaborator

complaining division by 0, does it mean that the x_lens(frames) are all 0?

I suggest either setting a breakpoint and running it step by step, or just dumping the data to disk and inspecting it.

@danpovey
Collaborator

danpovey commented Jul 9, 2022

Is it clear that this is a Torch bug? We are calling pad on a k2 ragged object.
IMO it would be nice not to have to filter out empty inputs; it may be valid to have empty transcriptions under some circumstances.
k2 does support having a device associated with an empty tensor.

@zzzacwork
Author

I am not 100% sure that it is a torch bug. I see k2 supports empty tensors very well, but when I call the pad function on such an empty tensor, it gives an error if the tensor is on a CUDA device. It works fine on the CPU, though.

>>> import k2.ragged as k2r
>>> z = k2r.create_ragged_tensor([[], [], []], device='cpu')
>>> z.pad(mode="constant", padding_value=0)
tensor([], size=(3, 0), dtype=torch.int32)
>>> z_gpu = k2r.create_ragged_tensor([[], [], []], device='cuda:0')
>>> z_gpu.pad(mode="constant", padding_value=0)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: CUDA error: invalid argument
CUDA kernel errors might be asynchronously reported at some other API call,so the stacktrace below might be incorrect.

By the way, I followed @csukuangfj's suggestion and padded empty inputs with s_range 0 tokens (which is 5 in my current setup), and the pruned loss no longer returns inf.

@danpovey
Collaborator

It may be a bug in k2; we should debug that separately. @pkufool?

@zzzacwork
Author

My temporary solution to this issue is to pad empty input tensors with one 0 token,

In that case, I suggest padding it with at least s_range 0 tokens. Only 1 is not enough for computing the loss.

Another issue is that the loss function sometimes gives inf if the input batch contains samples that have only one symbol. I am curious how the system avoids such exceptions for inputs that have fewer than s_range symbols. For instance, s_range defaults to 5 in our setup. Will it automatically limit the ranges parameter so that the loss is computed correctly?

Thanks in advance.

@csukuangfj
Collaborator

From k2-fsa/fast_rnnt#10 (comment)

Can you try the latest master of k2?
(You have to install k2 from source. Please see https://k2-fsa.github.io/k2/installation/from_source.html)

@pkufool
Collaborator

pkufool commented Jul 13, 2022

My temporary solution to this issue is to pad empty input tensors with one 0 token,

In that case, I suggest padding it with at least s_range 0 tokens. Only 1 is not enough for computing the loss.

Another issue is that the loss function sometimes gives inf if the input batch contains samples that have only one symbol. I am curious how the system avoids such exceptions for inputs that have fewer than s_range symbols. For instance, s_range defaults to 5 in our setup. Will it automatically limit the ranges parameter so that the loss is computed correctly?

Thanks in advance.

Try this fix: k2-fsa/k2#1009. It has not been merged into master yet. I think this fix can handle samples that have only one symbol, no matter what s_range you are using.

@zzzacwork
Author

Thank you for the fix! I downloaded it and tested it against my bad input batch; no more inf.
