diff --git a/paddlenlp/transformers/model_utils.py b/paddlenlp/transformers/model_utils.py
index dc1c753206c4..9c9af9bbc694 100644
--- a/paddlenlp/transformers/model_utils.py
+++ b/paddlenlp/transformers/model_utils.py
@@ -798,7 +798,7 @@ def _load_state_dict_into_meta_model(
         dtype = convert_np_dtype_to_dtype_(dtype)

     error_msgs = []
-
+    model_state_dict = model.state_dict()
     for param_name, param in state_dict.items():
         # First part of the test is always true as loaded_state_dict_keys always contains state_dict keys.
         if param_name not in loaded_state_dict_keys or param_name not in expected_keys:
@@ -833,7 +833,7 @@ def _load_state_dict_into_meta_model(
         if old_param is not None:
             param = param.astype(dtype=old_param.dtype)
         with paddle.no_grad():
-            model.state_dict()[param_name].get_tensor()._share_data_with(param.value().get_tensor())
+            model_state_dict[param_name].get_tensor()._share_data_with(param.value().get_tensor())
             param.value().get_tensor()._clear()

     return error_msgs
@@ -1890,7 +1890,7 @@ def _find_mismatched_keys(
                 if (
                     shard_file.endswith(".safetensors")
                     and config.tensor_parallel_degree > 1
-                    and "tp" not in os.path.split(shard_file)[-1]
                 ):
                     pre_tensor_parallel_split = True
                     assert loaded_keys is not None, "loaded_keys is not None."
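
The first two hunks hoist `model.state_dict()` out of the per-parameter loop, so the name-to-tensor mapping is built once rather than once per weight. The third hunk restricts the `tp` marker test to the shard's filename. As a minimal sketch of why that matters (the paths below are hypothetical), note that `os.path.split(path)[-1]` is the basename alone, while a substring test on the full path also matches directory names: a common checkpoint directory such as `output/` contains `"tp"` as a substring and would wrongly make the loader treat an unsplit shard as already tensor-parallel-split.

import os

shard_files = [
    "output/model-00001-of-00002.safetensors",  # unsplit shard; dir name "output" contains "tp"
    "ckpt/model_state.tp01.safetensors",        # shard already split for tensor parallelism
]

for shard_file in shard_files:
    old_check = "tp" not in shard_file                     # matches anywhere in the path
    new_check = "tp" not in os.path.split(shard_file)[-1]  # matches the filename only
    print(f"{shard_file}: old={old_check}, new={new_check}")

For the first path the old check yields False (no pre-split) purely because of the directory name, while the new check correctly yields True; for the genuinely tp-split shard both checks agree.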