Skip to content

Commit

Permalink
Support musicgen conversion. (#296)
Browse files Browse the repository at this point in the history
  • Loading branch information
Narsil authored Jul 11, 2023
1 parent 0c354d9 commit 6d93a71
Showing 1 changed file with 18 additions and 3 deletions.
21 changes: 18 additions & 3 deletions bindings/python/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,8 +161,11 @@ def check_final_model(model_id: str, folder: str):
shutil.copy(config, os.path.join(folder, "config.json"))
config = AutoConfig.from_pretrained(folder)

_, (pt_model, pt_infos) = infer_framework_load_model(model_id, config, output_loading_info=True)
_, (sf_model, sf_infos) = infer_framework_load_model(folder, config, output_loading_info=True)
import transformers

class_ = getattr(transformers, config.architectures[0])
(pt_model, pt_infos) = class_.from_pretrained(folder, output_loading_info=True)
(sf_model, sf_infos) = class_.from_pretrained(folder, output_loading_info=True)

if pt_infos != sf_infos:
error_string = create_diff(pt_infos, sf_infos)
Expand Down Expand Up @@ -199,7 +202,19 @@ def check_final_model(model_id: str, folder: str):
sf_model = sf_model.cuda()
kwargs = {k: v.cuda() for k, v in kwargs.items()}

pt_logits = pt_model(**kwargs)[0]
try:
pt_logits = pt_model(**kwargs)[0]
except Exception as e:
try:
# Musicgen special exception.
decoder_input_ids = torch.ones((input_ids.shape[0] * pt_model.decoder.num_codebooks, 1), dtype=torch.long)
if torch.cuda.is_available():
decoder_input_ids = decoder_input_ids.cuda()

kwargs["decoder_input_ids"] = decoder_input_ids
pt_logits = pt_model(**kwargs)[0]
except Exception:
raise e
sf_logits = sf_model(**kwargs)[0]

torch.testing.assert_close(sf_logits, pt_logits)
Expand Down

0 comments on commit 6d93a71

Please sign in to comment.