diff --git a/paddlenlp/transformers/ring_flash_attention.py b/paddlenlp/transformers/ring_flash_attention.py
index 9fa8ea52b655..b3faf2463dff 100644
--- a/paddlenlp/transformers/ring_flash_attention.py
+++ b/paddlenlp/transformers/ring_flash_attention.py
@@ -20,17 +20,6 @@
 from paddle import _C_ops
 from paddle.autograd.py_layer import PyLayer
 
-try:
-    from paddlenlp_ops import flash_attn_bwd
-except (ImportError, ModuleNotFoundError):
-    from paddlenlp.utils.log import logger
-
-    logger.warning(
-        "if you run ring_flash_attention.py, please ensure you install "
-        "the paddlenlp_ops by following the instructions "
-        "provided at https://github.com/PaddlePaddle/PaddleNLP/blob/develop/csrc/README.md"
-    )
-
 
 class RingCommunicator:
     def __init__(self, group, local_key, local_value):
@@ -233,6 +222,17 @@ def balanced_ring_flash_attention_bwd_func(
     if attn_mask is not None:
         attn_masks_list = paddle.split(attn_mask, num_or_sections=cp_size * 2, axis=3)
 
+    try:
+        from paddlenlp_ops import flash_attn_bwd
+    except (ImportError, ModuleNotFoundError):
+        from paddlenlp.utils.log import logger
+
+        logger.warning(
+            "if you run ring_flash_attention.py, please ensure you install "
+            "the paddlenlp_ops by following the instructions "
+            "provided at https://github.com/PaddlePaddle/PaddleNLP/blob/develop/csrc/README.md"
+        )
+
     for step in range(cp_size):
         block_k, block_v = kv_comm_buffer.get_buffers()
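
The hunks above move the optional `paddlenlp_ops` import from module scope into `balanced_ring_flash_attention_bwd_func`, so the missing-package warning is only emitted when the ring-attention backward pass actually needs `flash_attn_bwd`, not every time `ring_flash_attention.py` is imported. A minimal sketch of the same deferred-import pattern, with hypothetical names (`my_custom_ops`, `fused_bwd`, `backward_step`) and an explicit re-raise added for clarity (the patch itself only warns):

```python
import logging

logger = logging.getLogger(__name__)


def backward_step(dout, saved):
    # Defer the optional import so that merely importing this module never
    # warns; the message only appears when the backward pass actually runs.
    try:
        from my_custom_ops import fused_bwd  # hypothetical optional extension
    except (ImportError, ModuleNotFoundError):
        logger.warning(
            "my_custom_ops is not installed; build it before calling backward_step"
        )
        raise  # the real patch only warns and lets the later lookup fail
    return fused_bwd(dout, saved)
```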