
bf16 by default, gemma2 attns
Gemma-2 fine-tuning will not work until huggingface/transformers#31674 is merged.
hiyouga authored and zhangzh committed Jul 1, 2024
1 parent 8c86417 commit d797ddf
Showing 3 changed files with 9 additions and 3 deletions.
8 changes: 7 additions & 1 deletion src/llamafactory/model/model_utils/attention.py
@@ -28,7 +28,13 @@
 logger = get_logger(__name__)
 
 
-def configure_attn_implementation(config: "PretrainedConfig", model_args: "ModelArguments") -> None:
+def configure_attn_implementation(
+    config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool
+) -> None:
+    if getattr(config, "model_type", None) == "gemma2" and is_trainable:  # gemma2 adopts soft-cap attention
+        logger.warning("Gemma-2 models should use eager attention in training, change `flash_attn` to disabled.")
+        model_args.flash_attn = "disabled"
+
     if model_args.flash_attn == "auto":
         return
 
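The hunk ends at the `auto` early-return; further down (elided here), the resolved `model_args.flash_attn` value is translated into the `attn_implementation` string that Transformers understands. The sketch below is only an illustrative approximation of that mapping, not the repository's code; the function name and the exact choice strings are assumptions.

# Illustrative sketch only (not this repository's code): how a resolved
# `flash_attn` setting is commonly mapped to the string Transformers expects
# in `config._attn_implementation`. The "disabled" -> "eager" entry is what
# the Gemma-2 training guard above relies on.
def resolve_attn_implementation(flash_attn: str) -> str:
    mapping = {
        "disabled": "eager",
        "sdpa": "sdpa",
        "fa2": "flash_attention_2",  # assumed value name for FlashAttention-2
    }
    return mapping.get(flash_attn, "eager")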
2 changes: 1 addition & 1 deletion src/llamafactory/model/patcher.py
@@ -67,7 +67,7 @@ def patch_config(
         use_jit_compile = os.environ.get("JIT_COMPILE", "0").lower() in ["true", "1"]
         torch.npu.set_compile_mode(jit_compile=use_jit_compile)
 
-    configure_attn_implementation(config, model_args)
+    configure_attn_implementation(config, model_args, is_trainable)
     configure_rope(config, model_args, is_trainable)
     configure_longlora(config, model_args, is_trainable)
     configure_quantization(config, tokenizer, model_args, init_kwargs)
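To see why the call site now forwards `is_trainable`, here is a self-contained sketch that mirrors the guard added in attention.py (it re-implements the check locally rather than importing the real function; the `fa2` flash-attention value is an assumption): inference keeps the requested backend, while Gemma-2 training falls back to eager.

from types import SimpleNamespace

# Local sketch of the guard, not the real configure_attn_implementation:
# only training should lose FlashAttention on Gemma-2.
def _sketch_configure_attn(config, model_args, is_trainable):
    if getattr(config, "model_type", None) == "gemma2" and is_trainable:
        model_args.flash_attn = "disabled"

config = SimpleNamespace(model_type="gemma2")
train_args = SimpleNamespace(flash_attn="fa2")
infer_args = SimpleNamespace(flash_attn="fa2")

_sketch_configure_attn(config, train_args, is_trainable=True)
_sketch_configure_attn(config, infer_args, is_trainable=False)
print(train_args.flash_attn, infer_args.flash_attn)  # disabled fa2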
2 changes: 1 addition & 1 deletion src/llamafactory/webui/components/train.py
@@ -54,7 +54,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
         num_train_epochs = gr.Textbox(value="3.0")
         max_grad_norm = gr.Textbox(value="1.0")
         max_samples = gr.Textbox(value="100000")
-        compute_type = gr.Dropdown(choices=["fp16", "bf16", "fp32", "pure_bf16"], value="fp16")
+        compute_type = gr.Dropdown(choices=["bf16", "fp16", "fp32", "pure_bf16"], value="bf16")
 
     input_elems.update({learning_rate, num_train_epochs, max_grad_norm, max_samples, compute_type})
     elem_dict.update(
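bf16 only makes sense as the Web UI default on hardware that supports it (e.g. NVIDIA Ampere or newer). A minimal check, not part of this commit, for deciding whether to keep the new default on a given machine:

import torch

# Not part of this commit: probe whether the current GPU can do bfloat16;
# fall back to fp16 on older cards (e.g. pre-Ampere NVIDIA GPUs).
default_compute_type = "bf16" if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else "fp16"
print(default_compute_type)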
