
QKVFlashAttention unexpected parameters error, running in Google Colab #3

Closed
JonathanFly opened this issue Apr 13, 2023 · 21 comments · Fixed by #17 · May be fixed by #38

Comments

@JonathanFly

I tried to generate samples in Colab, and everything works except that I had to change this block of code in /cm/unet.py, clearing out factory_kwargs.

Not sure if this is a bug or if I did something wrong. This is how I ran it: https://github.com/JonathanFly/consistency_models_colab_notebook/blob/main/Consistency_Models_Make_Samples.ipynb


class QKVFlashAttention(nn.Module):
    def __init__(
        self,
        embed_dim,
        num_heads,
        batch_first=True,
        attention_dropout=0.0,
        causal=False,
        device=None,
        dtype=None,
        **kwargs,
    ) -> None:
        from einops import rearrange
        from flash_attn.flash_attention import FlashAttention

        assert batch_first
        #factory_kwargs = {"device": device, "dtype": dtype}
        factory_kwargs = {}
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.causal = causal
@zzhouj

zzhouj commented Apr 13, 2023

I also hit the same error. I guess this code base was written against an earlier version of flash_attn.

Logging to /tmp/openai-2023-04-13-14-33-55-278549
creating model and diffusion...
Traceback (most recent call last):
  File "/content/consistency_models/scripts/image_sample.py", line 143, in <module>
    main()
  File "/content/consistency_models/scripts/image_sample.py", line 37, in main
    model, diffusion = create_model_and_diffusion(
  File "/content/consistency_models/cm/script_util.py", line 76, in create_model_and_diffusion
    model = create_model(
  File "/content/consistency_models/cm/script_util.py", line 140, in create_model
    return UNetModel(
  File "/content/consistency_models/cm/unet.py", line 612, in __init__
    AttentionBlock(
  File "/content/consistency_models/cm/unet.py", line 293, in __init__
    self.attention = QKVFlashAttention(channels, self.num_heads)
  File "/content/consistency_models/cm/unet.py", line 359, in __init__
    self.inner_attn = FlashAttention(
TypeError: __init__() got an unexpected keyword argument 'device'

@JonathanFly

JonathanFly commented Apr 13, 2023

Related minor issue: in th_evaluator.py, inception-2015-12-05.pt tries to download automatically but fails, and it doesn't seem like you can pass the path on the command line. Also, is it supposed to automatically calculate stats for a reference batch? (I'm probably trying to run the sample out of order.)

class FIDAndIS:
    def __init__(
        self,
        softmax_batch_size=512,
        clip_score_batch_size=512,
        path="https://openaipublic.blob.core.windows.net/consistency/inception/inception-2015-12-05.pt",
    ):
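A possible workaround, as a sketch: download the checkpoint manually and hand FIDAndIS a local path through the same path parameter shown above. (The import path in the commented lines is a guess; only the constructor signature above is from the repo.)

    # Hypothetical workaround: fetch the Inception checkpoint once,
    # then construct FIDAndIS with the local file instead of the URL.
    import urllib.request
    from pathlib import Path

    CKPT_URL = (
        "https://openaipublic.blob.core.windows.net/consistency/"
        "inception/inception-2015-12-05.pt"
    )
    local_ckpt = Path("inception-2015-12-05.pt")
    if not local_ckpt.exists():
        urllib.request.urlretrieve(CKPT_URL, str(local_ckpt))

    # from evaluations.th_evaluator import FIDAndIS  # module path is a guess
    # fid_is = FIDAndIS(path=str(local_ckpt))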

@ovencampb214

class FIDAndIS:
    def __init__(

@XipengY

XipengY commented Apr 17, 2023

flash-attn v1.0.2 has no device parameter:
https://github.com/HazyResearch/flash-attention/blob/v1.0.2/flash_attn/flash_attention.py#L21

But v0.2.8 has a device parameter:
https://github.com/HazyResearch/flash-attention/blob/v0.2.8/flash_attn/flash_attention.py#L21
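(If you'd rather keep one code path for both versions, a guard inside QKVFlashAttention.__init__ is possible; this is just a sketch, assuming only that the constructor signature differs between the two versions:)

    import inspect
    from flash_attn.flash_attention import FlashAttention

    # Pass device/dtype only when this flash-attn version's constructor
    # accepts them (v0.2.8 does; v1.0.x does not).
    _params = inspect.signature(FlashAttention.__init__).parameters
    factory_kwargs = {"device": device, "dtype": dtype} if "device" in _params else {}
    self.inner_attn = FlashAttention(attention_dropout=attention_dropout, **factory_kwargs)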

@XipengY

XipengY commented Apr 17, 2023

Using pip install flash-attn==0.2.8 solved it for me.

@boxwayne

Using pip install flash-attn==0.2.8 solved it for me.

After doing this, I started training the model with these parameters and then got the following error. Does anyone know what it means? I'm a PyTorch rookie.

(py3_8_16) bld@bld:~/consistency_models/scripts$ mpiexec -n 1 python cm_train.py --training_mode consistency_training --target_ema_mode adaptive --start_ema 0.95 --scale_mode progressive --start_scales 2 --end_scales 150 --total_training_steps 100000 --loss_norm lpips --lr_anneal_steps 0 --teacher_model_path /home/bld/pre_train_model/edm_bedroom256_ema.pt --attention_resolutions 32,16,8 --class_cond False --use_scale_shift_norm False --dropout 0.0 --teacher_dropout 0.1 --ema_rate 0.9999,0.99994,0.9999432189950708 --global_batch_size 1 --image_size 256 --lr 0.00005 --num_channels 256 --num_head_channels 64 --num_res_blocks 2 --resblock_updown True --schedule_sampler uniform --use_fp16 True --weight_decay 0.0 --weight_schedule uniform --data_dir /home/bld/lsun/lsun_train_output_dir
Logging to /tmp/openai-2023-04-19-10-51-33-746807
creating model and diffusion...
creating data loader...
loading the teacher model from /home/bld/pre_train_model/edm_bedroom256_ema.pt
creating the target model
training...
/home/bld/anaconda3/envs/py3_8_16/lib/python3.8/site-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.
  warnings.warn(
/home/bld/anaconda3/envs/py3_8_16/lib/python3.8/site-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=VGG16_Weights.IMAGENET1K_V1`. You can also use `weights=VGG16_Weights.DEFAULT` to get the most up-to-date weights.
  warnings.warn(msg)
Traceback (most recent call last):
  File "cm_train.py", line 171, in <module>
    main()
  File "cm_train.py", line 121, in main
    CMTrainLoop(
  File "/home/bld/consistency_models/cm/train_util.py", line 367, in run_loop
    self.run_step(batch, cond)
  File "/home/bld/consistency_models/cm/train_util.py", line 389, in run_step
    self.forward_backward(batch, cond)
  File "/home/bld/consistency_models/cm/train_util.py", line 501, in forward_backward
    losses = compute_losses()
  File "/home/bld/consistency_models/cm/karras_diffusion.py", line 191, in consistency_losses
    distiller = denoise_fn(x_t, t)
  File "/home/bld/consistency_models/cm/karras_diffusion.py", line 125, in denoise_fn
    return self.denoise(model, x, t, **model_kwargs)[1]
  File "/home/bld/consistency_models/cm/karras_diffusion.py", line 347, in denoise
    model_output = model(c_in * x_t, rescaled_t, **model_kwargs)
  File "/home/bld/anaconda3/envs/py3_8_16/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/bld/anaconda3/envs/py3_8_16/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 1156, in forward
    output = self._run_ddp_forward(*inputs, **kwargs)
  File "/home/bld/anaconda3/envs/py3_8_16/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 1110, in _run_ddp_forward
    return module_to_run(*inputs[0], **kwargs[0])  # type: ignore[index]
  File "/home/bld/anaconda3/envs/py3_8_16/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/bld/consistency_models/cm/unet.py", line 765, in forward
    h = module(h, emb)
  File "/home/bld/anaconda3/envs/py3_8_16/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/bld/consistency_models/cm/unet.py", line 77, in forward
    x = layer(x)
  File "/home/bld/anaconda3/envs/py3_8_16/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/bld/consistency_models/cm/unet.py", line 308, in forward
    return checkpoint(
  File "/home/bld/consistency_models/cm/nn.py", line 155, in checkpoint
    return func(*inputs)
  File "/home/bld/consistency_models/cm/unet.py", line 325, in _forward
    h = checkpoint(self.attention, (qkv,), (), self.use_attention_checkpoint)
  File "/home/bld/consistency_models/cm/nn.py", line 155, in checkpoint
    return func(*inputs)
  File "/home/bld/anaconda3/envs/py3_8_16/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/bld/consistency_models/cm/unet.py", line 368, in forward
    qkv, _ = self.inner_attn(
  File "/home/bld/anaconda3/envs/py3_8_16/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/bld/anaconda3/envs/py3_8_16/lib/python3.8/site-packages/flash_attn/flash_attention.py", line 47, in forward
    output = flash_attn_unpadded_qkvpacked_func(
  File "/home/bld/anaconda3/envs/py3_8_16/lib/python3.8/site-packages/flash_attn/flash_attn_interface.py", line 266, in flash_attn_unpadded_qkvpacked_func
    return FlashAttnQKVPackedFunc.apply(qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale,
  File "/home/bld/anaconda3/envs/py3_8_16/lib/python3.8/site-packages/torch/autograd/function.py", line 506, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/home/bld/anaconda3/envs/py3_8_16/lib/python3.8/site-packages/flash_attn/flash_attn_interface.py", line 58, in forward
    out, softmax_lse, S_dmask = _flash_attn_forward(
  File "/home/bld/anaconda3/envs/py3_8_16/lib/python3.8/site-packages/flash_attn/flash_attn_interface.py", line 21, in _flash_attn_forward
    softmax_lse, *rest = flash_attn_cuda.fwd(
RuntimeError: Expected q.stride(-1) == 1 to be true, but got false.  (Could this error message be improved?  If so, please report an enhancement request to PyTorch.)
--------------------------------------------------------------------------
Primary job  terminated normally, but 1 process returned
a non-zero exit code. Per user-direction, the job has been aborted.
--------------------------------------------------------------------------
--------------------------------------------------------------------------
mpiexec detected that one or more processes exited with non-zero status, thus causing
the job to be terminated. The first process to do so was:

  Process name: [[48068,1],0]
  Exit code:    1

@aarontan-git

Using pip install flash-attn==0.2.8 solved it for me.

... RuntimeError: Expected q.stride(-1) == 1 to be true, but got false.

I got the same problem as well

@boxwayne

Using pip install flash-attn==0.2.8 solved it for me.

... RuntimeError: Expected q.stride(-1) == 1 to be true, but got false.

I got the same problem as well

Have you solved this problem? I don't even know what the error message means.

@asanakoy

Solution:
Make the following change in /content/consistency_models/cm/unet.py, line 359, in __init__:

-        self.inner_attn = FlashAttention(
-            attention_dropout=attention_dropout, **factory_kwargs
-        )
+        self.inner_attn = FlashAttention(attention_dropout=attention_dropout)

@aarontan-git

I tried that, but it didn't solve the problem. Were there any other changes you made?

@asanakoy

asanakoy commented Apr 21, 2023

Since I'm running on a V100, I also had to disable flash attention (apparently it only works on A100s):

index 3fe5184..d9f7c2f 100644
--- a/cm/unet.py
+++ b/cm/unet.py
@@ -270,7 +270,7 @@ class AttentionBlock(nn.Module):
         num_heads=1,
         num_head_channels=-1,
         use_checkpoint=False,
-        attention_type="flash",
+        attention_type="default", #"flash", # disable flash-attention by default in order to run on V100
         encoder_channels=None,
         dims=2,
         channels_last=False,

@asanakoy

Still doesn't work for me. This is what I get for CD on ImageNet 64, and I get a similar result with EDM.
(sample image attached)

@treefreq

Still doesn't work for me. This is what I get for CD on ImageNet 64, and I get a similar result with EDM. (sample image attached)

I cannot obtain images of similar quality to those in the paper.

@ain-soph

ain-soph commented May 2, 2023

@boxwayne @aarontan-git @asanakoy For the stride issue, I think it's a rearrange issue caused by the flash_attn version.

        qkv = self.rearrange(
            qkv, "b (three h d) s -> b s three h d", three=3, h=self.num_heads
        )
        # print(qkv.shape, qkv.stride())
        qkv, _ = self.inner_attn(
            qkv.contiguous(),
            key_padding_mask=key_padding_mask,
            need_weights=need_weights,
            causal=self.causal,
        )

The printed result is torch.Size([1, 256, 3, 6, 64]) with strides (256, 1, 98304, 16384, 256), which means the tensor after rearranging is no longer contiguous. (The old version might not have required contiguity, while the new version does.) So I simply add a contiguous operation before calling inner_attn:

qkv = qkv.contiguous()

Let me know if that solves the issue. I tested on my side and it works.
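(To see the stride issue in isolation, here is a minimal standalone check; the shapes mirror the ones printed above, and it assumes only that torch and einops are installed:)

    import torch
    from einops import rearrange

    # Mirror the qkv layout from cm/unet.py: (batch, 3 * heads * head_dim, seq)
    b, h, d, s = 1, 6, 64, 256
    qkv = torch.randn(b, 3 * h * d, s)

    out = rearrange(qkv, "b (three h d) s -> b s three h d", three=3, h=h)
    print(out.shape, out.stride())    # last-dim stride is 256, not 1
    print(out.is_contiguous())        # False: this rearrange is a reshape + permute
    print(out.contiguous().stride())  # (294912, 1152, 384, 64, 1): kernel-friendly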

@stonecropa

The flash-attn I installed is version 1.0.2, and it works with no problem.
(screenshot attached)

@aarontan-git

@boxwayne @aarontan-git @asanakoy For the stride issue, I think it's a rearrange issue caused by the flash_attn version. ... So I simply add a contiguous operation before calling inner_attn: qkv = qkv.contiguous()

I tried your fix, and got the following warning message when trying to run ImageNet consistency training:

Grad strides do not match bucket view strides. This may indicate grad was not created according to the gradient layout contract, or that the param's strides changed since DDP was constructed. This is not an error, but may impair performance.
grad.sizes() = [384, 384, 1, 1], strides() = [384, 1, 384, 384]
bucket_view.sizes() = [384, 384, 1, 1], strides() = [384, 1, 1, 1] (Triggered internally at ../torch/csrc/distributed/c10d/reducer.cpp:325.)

@ain-soph

ain-soph commented May 2, 2023

@aarontan-git You can leave this warning alone if you don't care about it. If you want to fix it, check the code to see which part changes the gradient strides, and adjust the tensor storage strides there to avoid the warning.
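(If you want to experiment with silencing it, one blunt sketch, assuming model is the nn.Module you wrap in DDP, is to force every gradient contiguous with a tensor hook; this is untested against this repo:)

    # Hypothetical experiment: make each parameter's gradient contiguous
    # as soon as autograd produces it, so grad strides match DDP's bucket views.
    for p in model.parameters():
        if p.requires_grad:
            p.register_hook(lambda grad: grad.contiguous())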

@treefreq

treefreq commented May 3, 2023

The flash-attn I installed is version 1.0.2, and it works with no problem. (screenshot attached)

When installing version 1.0.2, the following error occurs; how did you solve it?
(screenshot attached)

@nekoshadow1

nekoshadow1 commented May 3, 2023

The flash-attn I installed is version 1.0.2, and it works with no problem. (screenshot attached)

When installing version 1.0.2, the following error occurs; how did you solve it? (screenshot attached)

As Jonathan said at the top, change this block of code in /cm/unet.py, clearing out factory_kwargs (see the snippet in the first comment).

@treefreq

treefreq commented May 4, 2023

The flash-attn I installed is version 1.0.2, and it works with no problem. (screenshot attached)

When installing version 1.0.2, the following error occurs; how did you solve it? (screenshot attached)

As Jonathan said at the top, change this block of code in /cm/unet.py, clearing out factory_kwargs (see the snippet in the first comment).

What GPU are you using?
This doesn't seem to have anything to do with the flash-attn version; if I change attention_type="flash" to "default", the code runs, but the results are poor. If I don't change it, I get the following error message: (screenshot attached)

My GPU is a V100.

@nekoshadow1

What GPU are you using? This doesn't seem to have anything to do with the flash-attn version; if I change attention_type="flash" to "default", the code runs, but the results are poor. If I don't change it, I get the following error message: (screenshot attached) My GPU is a V100.

I used a single A100.
