diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 38dde4ec91e267..fec91ae6556ca3 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -504,7 +504,7 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike]): # Check format of the archive with safe_open(checkpoint_file, framework="pt") as f: metadata = f.metadata() - if metadata.get("format") not in ["pt", "tf", "flax"]: + if metadata.get("format") not in ["pt", "tf", "flax", "mlx"]: raise OSError( f"The safetensors archive passed at {checkpoint_file} does not contain the valid metadata. Make sure " "you save your model with the `save_pretrained` method."