From 8a22bab94117bbd2924935cf68b2d26c1e7fd833 Mon Sep 17 00:00:00 2001
From: Han Qi
Date: Mon, 10 Jun 2024 23:29:30 +0000
Subject: [PATCH] Fix convert_checkpoint.py for hf and gemma

---
 convert_checkpoints.py | 12 ++++++++++--
 1 file changed, 10 insertions(+), 2 deletions(-)

diff --git a/convert_checkpoints.py b/convert_checkpoints.py
index 1b3af726..9ba6836f 100644
--- a/convert_checkpoints.py
+++ b/convert_checkpoints.py
@@ -278,7 +278,7 @@ def _load_orig_llama_weight(input_ckpt_dir: epath.Path):
 
 def _load_hf_llama_weight(input_ckpt_dir: epath.Path):
   print(f"Loading checkpoint files from {input_ckpt_dir}.")
-  safetensors_files = input_ckpt_dir.glob("*.safetensors")
+  safetensors_files = list(input_ckpt_dir.glob("*.safetensors"))
   if len(list(safetensors_files)) == 0:
     raise ValueError(
         f"No *.safetensors found in the input dir {input_ckpt_dir}"
@@ -419,6 +419,13 @@ def _get_llama_state_dict(input_ckpt_dir):
   return state_dict, params
 
 
+def fix_json(text):
+  text = text.replace("'", '"')
+  lines = text.split("\n")
+  lines[-3] = lines[-3].replace(",", "")
+  return "\n".join(lines)
+
+
 def _get_gemma_state_dict(input_ckpt_dir):
   ckpt_file = list(input_ckpt_dir.glob("*.ckpt"))
   assert len(ckpt_file) == 1, "only expect 1 ckpt file for Gemma model."
@@ -426,7 +433,8 @@
   state_dict = torch.load(str(ckpt_file), map_location=torch.device("cpu"))[
       "model_state_dict"
   ]
-  model_config = json.loads((input_ckpt_dir / "config.json").read_text())
+  config_text = fix_json((input_ckpt_dir / "config.json").read_text())
+  model_config = json.loads(config_text)
   for key in list(state_dict.keys()):
     if state_dict[key].dtype.is_complex and _OUTPUT_SAFETENSORS.value:
       assert (
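
Note on fix_json (a reviewer sketch, not part of the patch): the Gemma checkpoint's config.json is not strict JSON; it uses single quotes and leaves a trailing comma on the line before the closing brace, so json.loads fails on the raw text. The snippet below runs the helper on an invented two-key config; the key names, and the assumption that the file ends with "}" followed by a final newline, are illustrative only.

import json

def fix_json(text):
  # Copied from the patch: swap single quotes for double quotes, then drop
  # the trailing comma, which sits on the third-to-last line when the file
  # ends with "}" followed by a newline.
  text = text.replace("'", '"')
  lines = text.split("\n")
  lines[-3] = lines[-3].replace(",", "")
  return "\n".join(lines)

# Hypothetical config text in the style the patch assumes Gemma ships.
sample = """{
    'num_hidden_layers': 18,
    'hidden_size': 2048,
}
"""
print(json.loads(fix_json(sample)))  # {'num_hidden_layers': 18, 'hidden_size': 2048}

The blanket quote swap assumes no value in the config contains an apostrophe, which appears to hold for the mostly numeric fields this path reads.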