diff --git a/pipeline/train/instruction_following.py b/pipeline/train/instruction_following.py index 565ff505..eb6f6e06 100644 --- a/pipeline/train/instruction_following.py +++ b/pipeline/train/instruction_following.py @@ -161,7 +161,9 @@ def mask_embedding(m): if args.mask_lm_head: unwrapped_model = accelerator.unwrap_model(model) - if unwrapped_model.lang_encoder.__class__.__name__ in ["MPTForCausalLM", "MosaicGPT"]: + if isinstance(unwrapped_model, IdeficsForVisionText2Text): + unwrapped_model.lm_head.apply(mask_embedding) + elif unwrapped_model.lang_encoder.__class__.__name__ in ["MPTForCausalLM", "MosaicGPT"]: unwrapped_model.lang_encoder.transformer.wte.apply(mask_embedding) elif unwrapped_model.lang_encoder.__class__.__name__ == "LlamaForCausalLM": unwrapped_model.lang_encoder.model.embed_tokens.apply(mask_embedding)