Skip to content

Commit

Permalink
Fix convert. (#390)
Browse files Browse the repository at this point in the history
  • Loading branch information
Narsil authored Nov 20, 2023
1 parent 094e676 commit 9610b4f
Showing 1 changed file with 7 additions and 5 deletions.
12 changes: 7 additions & 5 deletions bindings/python/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,9 @@ def convert_multi(
def convert_single(
model_id: str, *, revision: Optional[str], folder: str, token: Optional[str], discard_names: List[str]
) -> ConversionResult:
pt_filename = hf_hub_download(repo_id=model_id, filename="pytorch_model.bin", token=token, cache_dir=folder)
pt_filename = hf_hub_download(
repo_id=model_id, revision=revision, filename="pytorch_model.bin", token=token, cache_dir=folder
)

sf_name = "model.safetensors"
sf_filename = os.path.join(folder, sf_name)
Expand Down Expand Up @@ -225,15 +227,15 @@ def create_diff(pt_infos: Dict[str, List[str]], sf_infos: Dict[str, List[str]])

def previous_pr(api: "HfApi", model_id: str, pr_title: str, revision=Optional[str]) -> Optional["Discussion"]:
try:
main_commit = api.list_repo_commits(model_id, revision=revision)[0].commit_id
discussions = api.get_repo_discussions(repo_id=model_id, revision=revision)
revision_commit = api.model_info(model_id, revision=revision).sha
discussions = api.get_repo_discussions(repo_id=model_id)
except Exception:
return None
for discussion in discussions:
if discussion.status in {"open", "closed"} and discussion.is_pull_request and discussion.title == pr_title:
commits = api.list_repo_commits(model_id, revision=discussion.git_reference)

if main_commit == commits[1].commit_id:
if revision_commit == commits[1].commit_id:
return discussion
return None

Expand Down Expand Up @@ -274,7 +276,7 @@ def convert(
info = api.model_info(model_id, revision=revision)
filenames = set(s.rfilename for s in info.siblings)

with TemporaryDirectory(prefix=os.getenv("HF_HOME", "") + "/") as d:
with TemporaryDirectory() as d:
folder = os.path.join(d, repo_folder_name(repo_id=model_id, repo_type="models"))
os.makedirs(folder)
new_pr = None
Expand Down

0 comments on commit 9610b4f

Please sign in to comment.