
"Parameter indices which did not receive grad for rank x", Multi-GPU SDXL Training (unet + both text encoders) #997

Closed
fauzanardh opened this issue Dec 11, 2023 · 8 comments

@fauzanardh

I've encountered an error while training the SDXL UNet with both text encoders using the latest development branch.

Here's part of the traceback:

Traceback (most recent call last):
  File "/workspace/sd-scripts/sdxl_train.py", line 823, in <module>
    train(args)
  File "/workspace/sd-scripts/sdxl_train.py", line 521, in train
    encoder_hidden_states1, encoder_hidden_states2, pool2 = train_util.get_hidden_states_sdxl(
  File "/workspace/sd-scripts/library/train_util.py", line 4198, in get_hidden_states_sdxl
    enc_out = text_encoder1(input_ids1, output_hidden_states=True, return_dict=True)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/parallel/distributed.py", line 1139, in forward
    if torch.is_grad_enabled() and self.reducer._rebuild_buckets():
RuntimeError: Expected to have finished reduction in the prior iteration before starting a new one. This error indicates that your module has parameters that were not used in producing loss. You can enable unused parameter detection by passing the keyword argument `find_unused_parameters=True` to `torch.nn.parallel.DistributedDataParallel`, and by making sure all `forward` function outputs participate in calculating loss. If you already have done the above, then the distributed data parallel module wasn't able to locate the output tensors in the return value of your module's `forward` function. Please include the loss function and the structure of the return value of `forward` of your module when reporting this issue (e.g. list, dict, iterable).
Parameter indices which did not receive grad for rank 1: 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195
In addition, you can set the environment variable TORCH_DISTRIBUTED_DEBUG to either INFO or DETAIL to print out information about which particular parameters did not receive gradient on this rank as part of this error

This error only appears when training the text encoders. The parameter indices consistently reported as not receiving gradients are 178 through 195.

I also tried adding find_unused_parameters=True to the DistributedDataParallelKwargs, but the problem persists.
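
For reference, this is roughly how the flag was passed, via accelerate's kwargs handler (minimal sketch; the rest of the training setup is unchanged):

from accelerate import Accelerator, DistributedDataParallelKwargs

# Forward find_unused_parameters=True to the torch DDP wrapper that accelerate creates.
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
accelerator = Accelerator(kwargs_handlers=[ddp_kwargs])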

I'm pretty sure this started happening after PR #989 was merged into the dev branch.

@fauzanardh fauzanardh changed the title "Parameter indices which did not receive grad for rank x", when training SDXL unet + both text encoders "Parameter indices which did not receive grad for rank x", Multi-GPU SDXL Training (unet + both text encoders) Dec 11, 2023
@fauzanardh
Author

fauzanardh commented Dec 11, 2023

I'm thinking that it has something to do with the text encoders still being wrapped in DDP when getting the hidden states.

@fauzanardh
Author

Unwrapping both text encoders before calling get_hidden_states_sdxl does fix it:

from torch.nn.parallel import DistributedDataParallel as DDP

te1 = text_encoder1.module if isinstance(text_encoder1, DDP) else text_encoder1
te2 = text_encoder2.module if isinstance(text_encoder2, DDP) else text_encoder2

# unwrap_model is fine for models not wrapped by accelerator
encoder_hidden_states1, encoder_hidden_states2, pool2 = train_util.get_hidden_states_sdxl(
    args.max_token_length,
    input_ids1,
    input_ids2,
    tokenizer1,
    tokenizer2,
    te1,
    te2,
    None if not args.full_fp16 else weight_dtype,
)

and reverting this change:

unwrapped_text_encoder2 = text_encoder2 if accelerator is None else accelerator.unwrap_model(text_encoder2)
pool2 = pool_workaround(unwrapped_text_encoder2, enc_out["last_hidden_state"], input_ids2, tokenizer2.eos_token_id)

to

pool2 = pool_workaround(text_encoder2, enc_out["last_hidden_state"], input_ids2, tokenizer2.eos_token_id)

But with this change, I'm afraid it breaks gradient synchronization.
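
As a side note, instead of checking for DDP manually, the same unwrapping could presumably be done with accelerate's helper, which simply returns the model when it isn't wrapped (minimal sketch):

# Sketch: accelerator.unwrap_model extracts .module from a prepared (DDP-wrapped) model
# and passes a plain model through unchanged.
te1 = accelerator.unwrap_model(text_encoder1)
te2 = accelerator.unwrap_model(text_encoder2)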

@Isotr0py
Contributor

I reproduced this when training only text_encoder1, and I don't know why the parameters of hidden layer 11 in text_encoder1 did not receive grads from the loss calculation (this appears in both single-GPU and DDP training):

RuntimeError: Expected to have finished reduction in the prior iteration before starting a new one. This error indicates that your module has parameters that were not used in producing loss. Since `find_unused_parameters=True` is enabled, this likely  means that not all `forward` outputs participate in computing loss. You can fix this by making sure all `forward` function outputs participate in calculating loss. 
If you already have done the above, then the distributed data parallel module wasn't able to locate the output tensors in the return value of your module's `forward` function. Please include the loss function and the structure of the return value of `forward` of your module when reporting this issue (e.g. list, dict, iterable).
Parameters which did not receive grad for rank 0: text_model.final_layer_norm.bias, text_model.final_layer_norm.weight, text_model.encoder.layers.11.layer_norm2.bias, text_model.encoder.layers.11.layer_norm2.weight, text_model.encoder.layers.11.mlp.fc2.bias, text_model.encoder.layers.11.mlp.fc2.weight, text_model.encoder.layers.11.mlp.fc1.bias, text_model.encoder.layers.11.mlp.fc1.weight, text_model.encoder.layers.11.layer_norm1.bias, text_model.encoder.layers.11.layer_norm1.weight, text_model.encoder.layers.11.self_attn.out_proj.bias, text_model.encoder.layers.11.self_attn.out_proj.weight, text_model.encoder.layers.11.self_attn.q_proj.bias, text_model.encoder.layers.11.self_attn.q_proj.weight, text_model.encoder.layers.11.self_attn.v_proj.bias, text_model.encoder.layers.11.self_attn.v_proj.weight, text_model.encoder.layers.11.self_attn.k_proj.bias, text_model.encoder.layers.11.self_attn.k_proj.weight

For single-GPU training, hidden layer 11 in text_encoder1 likewise receives no grad, but this does not raise an error because no all_reduce is needed:

Keys without grad: ['text_model.encoder.layers.11.self_attn.k_proj.weight', 'text_model.encoder.layers.11.self_attn.k_proj.bias', 'text_model.encoder.layers.11.self_attn.v_proj.weight', 'text_model.encoder.layers.11.self_attn.v_proj.bias', 'text_model.encoder.layers.11.self_attn.q_proj.weight', 'text_model.encoder.layers.11.self_attn.q_proj.bias', 'text_model.encoder.layers.11.self_attn.out_proj.weight', 'text_model.encoder.layers.11.self_attn.out_proj.bias', 'text_model.encoder.layers.11.layer_norm1.weight', 'text_model.encoder.layers.11.layer_norm1.bias', 'text_model.encoder.layers.11.mlp.fc1.weight', 'text_model.encoder.layers.11.mlp.fc1.bias', 'text_model.encoder.layers.11.mlp.fc2.weight', 'text_model.encoder.layers.11.mlp.fc2.bias', 'text_model.encoder.layers.11.layer_norm2.weight', 'text_model.encoder.layers.11.layer_norm2.bias', 'text_model.final_layer_norm.weight', 'text_model.final_layer_norm.bias']
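
(A list like this can be produced with a quick post-backward check; minimal sketch, with text_encoder1 as the unwrapped CLIP text model:)

# Sketch: after the backward pass, list trainable parameters that received no gradient.
keys_without_grad = [
    name for name, param in text_encoder1.named_parameters()
    if param.requires_grad and param.grad is None
]
print("Keys without grad:", keys_without_grad)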

@Isotr0py
Contributor

Isotr0py commented Dec 12, 2023

Anyway, if the zero grad in text_encoder1 is expected during training, we can simply freeze layer 11 and final_layer_norm to prevent the error:

sd-scripts/sdxl_train.py, lines 400 to 401 at 4a2cef8:

    if train_text_encoder1:
        text_encoder1 = accelerator.prepare(text_encoder1)

->

    if train_text_encoder1:
        # freeze layer 11 and final_layer_norm
        text_encoder1.text_model.encoder.layers[11].requires_grad_(False)
        text_encoder1.text_model.final_layer_norm.requires_grad_(False)
        text_encoder1 = accelerator.prepare(text_encoder1)
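
A quick sanity check (sketch, run before accelerator.prepare): the frozen tensors should be exactly the 18 parameters named in the DDP error (layer 11's attention, MLP, and layer norms, plus final_layer_norm), which presumably correspond to indices 178 through 195 above:

# Sketch: confirm that only layer 11 and final_layer_norm were frozen.
frozen = [n for n, p in text_encoder1.named_parameters() if not p.requires_grad]
print(len(frozen), frozen)  # expect 18 entries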

@fauzanardh
Author

fauzanardh commented Dec 12, 2023

I checked open_clip's implementation and this repo, and I don't see anything mentioning that we should freeze layer 11 and final_layer_norm. So I don't think this is the right way to fix the problem either.

@kohya-ss
Owner

I don't know the reason for the RuntimeError, but SDXL uses the output of the penultimate layer of Text Encoder 1, so I think freezing the last layer and final_layer_norm of CLIP should not change the training result.
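
(That is presumably also why those particular parameters never receive grad: only an intermediate hidden state feeds the loss, so the last transformer layer and final_layer_norm sit outside the backward graph. A minimal sketch, with the exact hidden-state index assumed:)

# Sketch: the conditioning comes from an intermediate hidden state rather than the final
# output, so everything after that layer (layer 11 and final_layer_norm) gets no gradient.
enc_out = text_encoder1(input_ids1, output_hidden_states=True, return_dict=True)
encoder_hidden_states1 = enc_out["hidden_states"][-2]  # penultimate layer output (index assumed)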

@fauzanardh
Author

Oops, sorry, I had the text encoders flipped in my mind; I thought text_encoder1 was the one from open_clip. So yes, I think freezing the last layer and the final layer norm shouldn't affect the training at all.

@fauzanardh
Author

Fixed by PR #1000
