Skip to content

Commit

Permalink
Passing the token around. (#302)
Browse files Browse the repository at this point in the history
* Passing the token around.

* Update Python version (flax is broken on 3.7)

* Fixing flax ?

* Apply suggestions from code review
  • Loading branch information
Narsil committed Jul 31, 2023
1 parent 6d93a71 commit f1e4d06
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 25 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/python.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ jobs:
os: [ubuntu-latest, macos-latest, windows-latest]
# Lowest and highest, no version specified so that
# new releases get automatically tested against
version: [{torch: torch==1.10, python: "3.7"}, {torch: torch, python: "3.10"}]
version: [{torch: torch==1.10, python: "3.8"}, {torch: torch, python: "3.10"}]
defaults:
run:
working-directory: ./bindings/python
Expand Down
27 changes: 13 additions & 14 deletions bindings/python/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from huggingface_hub.file_download import repo_folder_name
from safetensors.torch import load_file, save_file
from transformers import AutoConfig
from transformers.pipelines.base import infer_framework_load_model


COMMIT_DESCRIPTION = """
Expand Down Expand Up @@ -71,15 +70,15 @@ def rename(pt_filename: str) -> str:
return local


def convert_multi(model_id: str, folder: str) -> ConversionResult:
filename = hf_hub_download(repo_id=model_id, filename="pytorch_model.bin.index.json")
def convert_multi(model_id: str, folder: str, token: Optional[str]) -> ConversionResult:
filename = hf_hub_download(repo_id=model_id, filename="pytorch_model.bin.index.json", token=token)
with open(filename, "r") as f:
data = json.load(f)

filenames = set(data["weight_map"].values())
local_filenames = []
for filename in filenames:
pt_filename = hf_hub_download(repo_id=model_id, filename=filename)
pt_filename = hf_hub_download(repo_id=model_id, filename=filename, token=token)

sf_filename = rename(pt_filename)
sf_filename = os.path.join(folder, sf_filename)
Expand All @@ -102,8 +101,8 @@ def convert_multi(model_id: str, folder: str) -> ConversionResult:
return operations, errors


def convert_single(model_id: str, folder: str) -> ConversionResult:
pt_filename = hf_hub_download(repo_id=model_id, filename="pytorch_model.bin")
def convert_single(model_id: str, folder: str, token: Optional[str]) -> ConversionResult:
pt_filename = hf_hub_download(repo_id=model_id, filename="pytorch_model.bin", token=token)

sf_name = "model.safetensors"
sf_filename = os.path.join(folder, sf_name)
Expand Down Expand Up @@ -156,8 +155,8 @@ def create_diff(pt_infos: Dict[str, List[str]], sf_infos: Dict[str, List[str]])
return "\n".join(errors)


def check_final_model(model_id: str, folder: str):
config = hf_hub_download(repo_id=model_id, filename="config.json")
def check_final_model(model_id: str, folder: str, token: Optional[str]):
config = hf_hub_download(repo_id=model_id, filename="config.json", token=token)
shutil.copy(config, os.path.join(folder, "config.json"))
config = AutoConfig.from_pretrained(folder)

Expand Down Expand Up @@ -236,15 +235,15 @@ def previous_pr(api: "HfApi", model_id: str, pr_title: str) -> Optional["Discuss
return None


def convert_generic(model_id: str, folder: str, filenames: Set[str]) -> ConversionResult:
def convert_generic(model_id: str, folder: str, filenames: Set[str], token: Optional[str]) -> ConversionResult:
operations = []
errors = []

extensions = set([".bin", ".ckpt"])
for filename in filenames:
prefix, ext = os.path.splitext(filename)
if ext in extensions:
pt_filename = hf_hub_download(model_id, filename=filename)
pt_filename = hf_hub_download(model_id, filename=filename, token=token)
dirname, raw_filename = os.path.split(filename)
if raw_filename == "pytorch_model.bin":
# XXX: This is a special case to handle `transformers` and the
Expand Down Expand Up @@ -283,14 +282,14 @@ def convert(api: "HfApi", model_id: str, force: bool = False) -> Tuple["CommitIn
raise AlreadyExists(f"Model {model_id} already has an open PR check out {url}")
elif library_name == "transformers":
if "pytorch_model.bin" in filenames:
operations, errors = convert_single(model_id, folder)
operations, errors = convert_single(model_id, folder, token=api.token)
elif "pytorch_model.bin.index.json" in filenames:
operations, errors = convert_multi(model_id, folder)
operations, errors = convert_multi(model_id, folder, token=api.token)
else:
raise RuntimeError(f"Model {model_id} doesn't seem to be a valid pytorch model. Cannot convert")
check_final_model(model_id, folder)
check_final_model(model_id, folder, token=api.token)
else:
operations, errors = convert_generic(model_id, folder, filenames)
operations, errors = convert_generic(model_id, folder, filenames, token=api.token)

if operations:
new_pr = api.create_commit(
Expand Down
21 changes: 11 additions & 10 deletions bindings/python/py_src/safetensors/flax.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,16 @@
import numpy as np

import jax.numpy as jnp
from jax import Array
from safetensors import numpy


def save(tensors: Dict[str, jnp.DeviceArray], metadata: Optional[Dict[str, str]] = None) -> bytes:
def save(tensors: Dict[str, Array], metadata: Optional[Dict[str, str]] = None) -> bytes:
"""
Saves a dictionary of tensors into raw bytes in safetensors format.
Args:
tensors (`Dict[str, jnp.DeviceArray]`):
tensors (`Dict[str, Array]`):
The incoming tensors. Tensors need to be contiguous and dense.
metadata (`Dict[str, str]`, *optional*, defaults to `None`):
Optional text only metadata you might want to save in your header.
Expand All @@ -37,15 +38,15 @@ def save(tensors: Dict[str, jnp.DeviceArray], metadata: Optional[Dict[str, str]]


def save_file(
tensors: Dict[str, jnp.DeviceArray],
tensors: Dict[str, Array],
filename: Union[str, os.PathLike],
metadata: Optional[Dict[str, str]] = None,
) -> None:
"""
Saves a dictionary of tensors into raw bytes in safetensors format.
Args:
tensors (`Dict[str, jnp.DeviceArray]`):
tensors (`Dict[str, Array]`):
The incoming tensors. Tensors need to be contiguous and dense.
filename (`str`, or `os.PathLike`)):
The filename we're saving into.
Expand All @@ -71,7 +72,7 @@ def save_file(
return numpy.save_file(np_tensors, filename, metadata=metadata)


def load(data: bytes) -> Dict[str, jnp.DeviceArray]:
def load(data: bytes) -> Dict[str, Array]:
"""
Loads a safetensors file into flax format from pure bytes.
Expand All @@ -80,7 +81,7 @@ def load(data: bytes) -> Dict[str, jnp.DeviceArray]:
The content of a safetensors file
Returns:
`Dict[str, jnp.DeviceArray]`: dictionary that contains name as key, value as `jnp.DeviceArray` on cpu
`Dict[str, Array]`: dictionary that contains name as key, value as `Array` on cpu
Example:
Expand All @@ -98,7 +99,7 @@ def load(data: bytes) -> Dict[str, jnp.DeviceArray]:
return _np2jnp(flat)


def load_file(filename: Union[str, os.PathLike]) -> Dict[str, jnp.DeviceArray]:
def load_file(filename: Union[str, os.PathLike]) -> Dict[str, Array]:
"""
Loads a safetensors file into flax format.
Expand All @@ -110,7 +111,7 @@ def load_file(filename: Union[str, os.PathLike]) -> Dict[str, jnp.DeviceArray]:
available options are all regular flax device locations
Returns:
`Dict[str, jnp.DeviceArray]`: dictionary that contains name as key, value as `jnp.DeviceArray`
`Dict[str, Array]`: dictionary that contains name as key, value as `Array`
Example:
Expand All @@ -125,13 +126,13 @@ def load_file(filename: Union[str, os.PathLike]) -> Dict[str, jnp.DeviceArray]:
return _np2jnp(flat)


def _np2jnp(numpy_dict: Dict[str, np.ndarray]) -> Dict[str, jnp.DeviceArray]:
def _np2jnp(numpy_dict: Dict[str, np.ndarray]) -> Dict[str, Array]:
for k, v in numpy_dict.items():
numpy_dict[k] = jnp.array(v)
return numpy_dict


def _jnp2np(jnp_dict: Dict[str, jnp.DeviceArray]) -> Dict[str, np.array]:
def _jnp2np(jnp_dict: Dict[str, Array]) -> Dict[str, np.array]:
for k, v in jnp_dict.items():
jnp_dict[k] = np.asarray(v)
return jnp_dict

0 comments on commit f1e4d06

Please sign in to comment.