Remove JSON config mangling for Gemma ckpt (#124)
update gemma convert
lsy323 authored Jun 13, 2024
1 parent 8a125b6 commit fe8dbde
Showing 2 changed files with 4 additions and 10 deletions.
5 changes: 3 additions & 2 deletions README.md
@@ -59,11 +59,12 @@ the tokenizer that we will use.
 Please sign agreement on Huggingface website to access Gemma checkpoints. Download Gemma PyTorch checkpoint using huggingface-cli. Gemma Tokenizer is included in the checkpoint.

 ```bash
+# Install huggingface-cli and login if it's not set up.
+pip install -U "huggingface_hub[cli]"
+huggingface-cli login
 huggingface-cli download google/gemma-7b-pytorch --local-dir $input_ckpt_dir
 ```

-Need to manually modify the `config.json` in the checkpoint folder to make it a valid JSON file. (Replace `'` with `"`, remove the excessive `,` after the last item in the JSON object)
-
 ## Mixtral
 ### Get Mixtral Checkpoint from HuggingFace
9 changes: 1 addition & 8 deletions convert_checkpoints.py
@@ -428,21 +428,14 @@ def _get_llama_state_dict(input_ckpt_dir):
   return state_dict, params


-def fix_json(text):
-  text = text.replace("'", '"')
-  lines = text.split("\n")
-  lines[-3] = lines[-3].replace(",", "")
-  return "\n".join(lines)
-
-
 def _get_gemma_state_dict(input_ckpt_dir):
   ckpt_file = list(input_ckpt_dir.glob("*.ckpt"))
   assert len(ckpt_file) == 1, "only expect 1 ckpt file for Gemma model."
   ckpt_file = ckpt_file[0]
   state_dict = torch.load(str(ckpt_file), map_location=torch.device("cpu"))[
       "model_state_dict"
   ]
-  config_text = fix_json((input_ckpt_dir / "config.json").read_text())
+  config_text = (input_ckpt_dir / "config.json").read_text()
   model_config = json.loads(config_text)
   for key in list(state_dict.keys()):
     if state_dict[key].dtype.is_complex and _OUTPUT_SAFETENSORS.value:
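Note on the change: with `fix_json` gone, `json.loads` receives the contents of `config.json` unmodified, so a file that still uses single quotes or a trailing comma will fail to parse. The sketch below is illustrative only and is not part of this commit. `load_model_config` is a hypothetical helper and `example_field` is a made-up key; it shows one way the read could be wrapped so that a malformed file produces an error naming the offending path.

```python
import json
import tempfile
from pathlib import Path


def load_model_config(ckpt_dir: Path) -> dict:
  """Hypothetical helper: parse config.json as-is, without rewriting it."""
  config_path = ckpt_dir / "config.json"
  try:
    return json.loads(config_path.read_text())
  except json.JSONDecodeError as err:
    # Surface the file path so the user knows which file to inspect.
    raise ValueError(f"{config_path} is not valid JSON: {err}") from err


# Demo with made-up config contents (not real Gemma config fields).
with tempfile.TemporaryDirectory() as tmp:
  ckpt_dir = Path(tmp)

  # A well-formed file parses directly.
  (ckpt_dir / "config.json").write_text('{"example_field": 28}')
  valid_config = load_model_config(ckpt_dir)

  # Single quotes plus a trailing comma are rejected by json.loads.
  (ckpt_dir / "config.json").write_text("{'example_field': 28,\n}")
  try:
    load_model_config(ckpt_dir)
    malformed_rejected = False
  except ValueError:
    malformed_rejected = True
```

Whether such a wrapper is worth adding depends on how often malformed checkpoints are still encountered; the commit itself simply relies on the file being valid JSON.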
