From fe8dbde9713d7a71594d68b41154e7260498e8a5 Mon Sep 17 00:00:00 2001 From: Siyuan Liu Date: Thu, 13 Jun 2024 10:20:32 -0700 Subject: [PATCH] Remove JSON config mangling for Gemma ckpt (#124) update gemma convert --- README.md | 5 +++-- convert_checkpoints.py | 9 +-------- 2 files changed, 4 insertions(+), 10 deletions(-) diff --git a/README.md b/README.md index d2bea897..30da0427 100644 --- a/README.md +++ b/README.md @@ -59,11 +59,12 @@ the tokenizer that we will use. Please sign agreement on Huggingface website to access Gemma checkpoints. Download Gemma PyTorch checkpoint using huggingface-cli. Gemma Tokenizer is included in the checkpoint. ```bash +# Install huggingface-cli and login if it's not set up. +pip install -U "huggingface_hub[cli]" +huggingface-cli login huggingface-cli download google/gemma-7b-pytorch --local-dir $input_ckpt_dir ``` -Need to manually modify the `config.json` in the checkpoint folder to make it a valid JSON file. (Replace `'` with `"`, remove the excessive `,` after the last item in the JSON object) - ## Mixtral ### Get Mixtral Checkpoint from HuggingFace diff --git a/convert_checkpoints.py b/convert_checkpoints.py index 436a42d5..c3f83160 100644 --- a/convert_checkpoints.py +++ b/convert_checkpoints.py @@ -428,13 +428,6 @@ def _get_llama_state_dict(input_ckpt_dir): return state_dict, params -def fix_json(text): - text = text.replace("'", '"') - lines = text.split("\n") - lines[-3] = lines[-3].replace(",", "") - return "\n".join(lines) - - def _get_gemma_state_dict(input_ckpt_dir): ckpt_file = list(input_ckpt_dir.glob("*.ckpt")) assert len(ckpt_file) == 1, "only expect 1 ckpt file for Gemma model." @@ -442,7 +435,7 @@ def _get_gemma_state_dict(input_ckpt_dir): state_dict = torch.load(str(ckpt_file), map_location=torch.device("cpu"))[ "model_state_dict" ] - config_text = fix_json((input_ckpt_dir / "config.json").read_text()) + config_text = (input_ckpt_dir / "config.json").read_text() model_config = json.loads(config_text) for key in list(state_dict.keys()): if state_dict[key].dtype.is_complex and _OUTPUT_SAFETENSORS.value: