
Commit

Merge pull request #79 from breezedeus/pytorch
download model files from 'HF' instead of 'CN'
breezedeus authored Apr 10, 2024
2 parents 078f09d + 184bceb commit a7b1056
Showing 6 changed files with 56 additions and 33 deletions.
2 changes: 1 addition & 1 deletion Makefile
@@ -21,7 +21,7 @@ package:
rm -rf build
python setup.py sdist bdist_wheel

VERSION = 1.2.3.5
VERSION := $(shell sed -n "s/^__version__ = '\(.*\)'/\1/p" cnstd/__version__.py)
upload:
python -m twine upload dist/cnstd-$(VERSION)* --verbose

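The new `VERSION` assignment above derives the package version from `cnstd/__version__.py` instead of hard-coding it, so the Makefile can no longer drift out of sync with the package. As a rough illustration of what the `sed` expression extracts (a sketch only, not code from the repository):

```python
import re
from pathlib import Path

# Mirrors the sed expression in the Makefile: capture the value from the
# line "__version__ = '<version>'" in cnstd/__version__.py.
text = Path('cnstd/__version__.py').read_text()
match = re.search(r"^__version__ = '(.*)'", text, flags=re.MULTILINE)
version = match.group(1) if match else None
print(version)  # e.g. 1.2.3.6
```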
6 changes: 6 additions & 0 deletions RELEASE.md
@@ -1,5 +1,11 @@
# Release Notes

# Update 2024.04.10: Released V1.2.3.6

Major changes:

* The CN OSS is no longer available, so the default model download source has changed from `CN` to `HF`

# Update 2023.10.09: Released V1.2.3.5

Major changes:
2 changes: 1 addition & 1 deletion cnstd/__version__.py
@@ -17,4 +17,4 @@
# specific language governing permissions and limitations
# under the License.

__version__ = '1.2.3.5'
__version__ = '1.2.3.6'
3 changes: 2 additions & 1 deletion cnstd/consts.py
@@ -45,7 +45,8 @@
MODEL_VERSION = '.'.join(__version__.split('.', maxsplit=2)[:2])
VOCAB_FP = Path(__file__).parent.parent / 'label_cn.txt'
# Which OSS source will be used for downloading model files, 'CN' or 'HF'
DOWNLOAD_SOURCE = os.environ.get('CNSTD_DOWNLOAD_SOURCE', 'CN')
DOWNLOAD_SOURCE = os.environ.get('CNSTD_DOWNLOAD_SOURCE', 'HF')
HF_ENDPOINT_LIST = ['https://huggingface.co', 'https://hf-mirror.com']

MODEL_CONFIGS: Dict[str, Dict[str, Any]] = {
'db_resnet50': {
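Because `consts.py` only falls back to `'HF'` when `CNSTD_DOWNLOAD_SOURCE` is unset, the download source can still be overridden per environment. A minimal sketch (not from the repository) of forcing a different source; note that, per the release notes, the `CN` OSS is currently unavailable, so this is mainly useful for pinning the source explicitly:

```python
import os

# consts.py reads CNSTD_DOWNLOAD_SOURCE at import time and falls back to 'HF',
# so the override must be set before cnstd is imported.
os.environ['CNSTD_DOWNLOAD_SOURCE'] = 'CN'

from cnstd.consts import DOWNLOAD_SOURCE  # noqa: E402

print(DOWNLOAD_SOURCE)  # expected: CN
```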
58 changes: 37 additions & 21 deletions cnstd/utils/utils.py
@@ -36,7 +36,7 @@
import torch
from huggingface_hub import hf_hub_download

from ..consts import MODEL_VERSION, MODEL_CONFIGS, AVAILABLE_MODELS
from ..consts import MODEL_VERSION, MODEL_CONFIGS, HF_ENDPOINT_LIST


fmt = '[%(levelname)s %(asctime)s %(funcName)s:%(lineno)d] %(' 'message)s '
@@ -185,26 +185,38 @@ def download(url, path=None, download_source='CN', overwrite=False, sha1_hash=No
else:
total_length = int(total_length)
for chunk in tqdm(
r.iter_content(chunk_size=1024),
total=int(total_length / 1024.0 + 0.5),
unit='KB',
unit_scale=False,
dynamic_ncols=True,
r.iter_content(chunk_size=1024),
total=int(total_length / 1024.0 + 0.5),
unit='KB',
unit_scale=False,
dynamic_ncols=True,
):
f.write(chunk)
else:
HF_TOKEN = os.environ.get('HF_TOKEN')
logger.info('Downloading %s from HF Repo %s...' % (fname, url["repo_id"]))
with tempfile.TemporaryDirectory() as tmp_dir:
local_path = hf_hub_download(
repo_id=url["repo_id"],
subfolder=url["subfolder"],
filename=url["filename"],
repo_type="model",
cache_dir=tmp_dir,
token=HF_TOKEN,
)
shutil.copy2(local_path, fname)
for hf_endpoint in HF_ENDPOINT_LIST:
try:
logger.info(
'Downloading %s from HF Repo %s/%s...'
% (fname, hf_endpoint, url["repo_id"])
)
with tempfile.TemporaryDirectory() as tmp_dir:
local_path = hf_hub_download(
repo_id=url["repo_id"],
subfolder=url["subfolder"],
filename=url["filename"],
repo_type="model",
cache_dir=tmp_dir,
token=HF_TOKEN,
endpoint=hf_endpoint,
)
shutil.copy2(local_path, fname)
break
except:
logger.warning(
'Failed to download %s from HF Repo %s/%s.'
% (fname, hf_endpoint, url["repo_id"])
)

if sha1_hash and not check_sha1(fname, sha1_hash):
raise UserWarning(
@@ -246,13 +258,17 @@ def get_model_file(url, model_dir, download_source='CN'):
zip_file_path = os.path.join(par_dir, url['filename'])
if not os.path.exists(zip_file_path):
try:
download(url, path=zip_file_path, download_source=download_source, overwrite=True)
download(
url, path=zip_file_path, download_source=download_source, overwrite=True
)
except Exception as e:
logger.error(e)
message = f'Failed to download model: {url["filename"]}.'
message += '\n\tPlease open your VPN and try again. \n\t' \
'If this error persists, please follow the instruction at ' \
'[CnSTD/CnOCR Doc](https://www.breezedeus.com/cnocr) to manually download the model files.'
message += (
'\n\tPlease open your VPN and try again. \n\t'
'If this error persists, please follow the instruction at '
'[CnSTD/CnOCR Doc](https://www.breezedeus.com/cnocr) to manually download the model files.'
)
raise ModelDownloadingError(message)
with zipfile.ZipFile(zip_file_path) as zf:
zf.extractall(par_dir)
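The core of the change above is a retry loop that tries each endpoint in `HF_ENDPOINT_LIST` until one download succeeds. Below is a self-contained sketch of the same pattern for reuse outside `cnstd`; the function name `download_from_hf` and its parameters are illustrative, and it relies on the `endpoint` argument of `hf_hub_download` that the diff itself uses (available in recent `huggingface_hub` releases).

```python
import logging
import os
import shutil
import tempfile
from typing import Optional

from huggingface_hub import hf_hub_download

logger = logging.getLogger(__name__)

# Same endpoints as HF_ENDPOINT_LIST in cnstd/consts.py.
HF_ENDPOINT_LIST = ['https://huggingface.co', 'https://hf-mirror.com']


def download_from_hf(
    repo_id: str, filename: str, target_path: str, subfolder: Optional[str] = None
) -> bool:
    """Try each endpoint in turn; return True once the file lands at target_path."""
    hf_token = os.environ.get('HF_TOKEN')
    for hf_endpoint in HF_ENDPOINT_LIST:
        try:
            with tempfile.TemporaryDirectory() as tmp_dir:
                local_path = hf_hub_download(
                    repo_id=repo_id,
                    subfolder=subfolder,
                    filename=filename,
                    repo_type='model',
                    cache_dir=tmp_dir,
                    token=hf_token,
                    endpoint=hf_endpoint,
                )
                # Copy out of the temporary cache before it is removed.
                shutil.copy2(local_path, target_path)
            return True
        except Exception:
            logger.warning('Failed to download %s from %s/%s', filename, hf_endpoint, repo_id)
    return False
```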
18 changes: 9 additions & 9 deletions cnstd/yolov7/general.py
@@ -118,7 +118,7 @@ def check_requirements(requirements='requirements.txt', exclude=()):
if isinstance(requirements, (str, Path)): # requirements.txt file
file = Path(requirements)
if not file.exists():
print(f"{prefix} {file.resolve()} not found, check failed.")
logger.warning(f"{prefix} {file.resolve()} not found, check failed.")
return
requirements = [f'{x.name}{x.specifier}' for x in pkg.parse_requirements(file.open()) if x.name not in exclude]
else: # list or tuple of packages
@@ -130,7 +130,7 @@ def check_requirements(requirements='requirements.txt', exclude=()):
pkg.require(r)
except Exception as e: # DistributionNotFound or VersionConflict if requirements not met
n += 1
print(f"{prefix} {e.req} not found and is required by YOLOR, attempting auto-update...")
logger.warning(f"{prefix} {e.req} not found and is required by YOLOR, attempting auto-update...")
print(subprocess.check_output(f"pip install '{e.req}'", shell=True).decode())

if n: # if packages updated
@@ -144,7 +144,7 @@ def check_img_size(img_size, s=32):
# Verify img_size is a multiple of stride s
new_size = make_divisible(img_size, int(s)) # ceil gs-multiple
if new_size != img_size:
print('WARNING: --img-size %g must be multiple of max stride %g, updating to %g' % (img_size, s, new_size))
logger.warning('--img-size %g must be multiple of max stride %g, updating to %g' % (img_size, s, new_size))
return new_size


@@ -158,7 +158,7 @@ def check_imshow():
cv2.waitKey(1)
return True
except Exception as e:
print(f'WARNING: Environment does not support cv2.imshow() or PIL Image.show() image displays\n{e}')
logger.warning(f'Environment does not support cv2.imshow() or PIL Image.show() image displays\n{e}')
return False


@@ -179,16 +179,16 @@ def check_dataset(dict):
if val and len(val):
val = [Path(x).resolve() for x in (val if isinstance(val, list) else [val])] # val path
if not all(x.exists() for x in val):
print('\nWARNING: Dataset not found, nonexistent paths: %s' % [str(x) for x in val if not x.exists()])
logger.warning('\nDataset not found, nonexistent paths: %s' % [str(x) for x in val if not x.exists()])
if s and len(s): # download script
print('Downloading %s ...' % s)
logger.info('Downloading %s ...' % s)
if s.startswith('http') and s.endswith('.zip'): # URL
f = Path(s).name # filename
torch.hub.download_url_to_file(s, f)
r = os.system('unzip -q %s -d ../ && rm %s' % (f, f)) # unzip
else: # bash script
r = os.system(s)
print('Dataset autodownload %s\n' % ('success' if r == 0 else 'failure')) # analyze return value
logger.info('Dataset autodownload %s\n' % ('success' if r == 0 else 'failure')) # analyze return value
else:
raise Exception('Dataset not found.')

@@ -834,7 +834,7 @@ def non_max_suppression_kpt(prediction, conf_thres=0.25, iou_thres=0.45, classes

output[xi] = x[i]
if (time.time() - t) > time_limit:
print(f'WARNING: NMS time limit {time_limit}s exceeded')
logger.warning(f'NMS time limit {time_limit}s exceeded')
break # time limit exceeded

return output
@@ -853,7 +853,7 @@ def strip_optimizer(f='best.pt', s=''): # from utils.general import *; strip_op
p.requires_grad = False
torch.save(x, s or f)
mb = os.path.getsize(s or f) / 1E6 # filesize
print(f"Optimizer stripped from {f},{(' saved as %s,' % s) if s else ''} {mb:.1f}MB")
logger.info(f"Optimizer stripped from {f},{(' saved as %s,' % s) if s else ''} {mb:.1f}MB")


# def print_mutation(hyp, results, yaml_file='hyp_evolved.yaml', bucket=''):
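The `general.py` edits above replace bare `print` calls with a module logger and drop the literal 'WARNING:' prefixes, since the log level now carries the severity. They assume a `logger` object is already defined in that module; a typical way to obtain one (a sketch, not necessarily how `cnstd` wires it up) is:

```python
import logging

# Module-level logger used by the warning/info calls shown in the diff.
logger = logging.getLogger(__name__)
```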
