diff --git a/finetune/tag_images_by_wd14_tagger.py b/finetune/tag_images_by_wd14_tagger.py
index b56d921a3..e63ec3eb4 100644
--- a/finetune/tag_images_by_wd14_tagger.py
+++ b/finetune/tag_images_by_wd14_tagger.py
@@ -86,23 +86,26 @@ def main(args):
         logger.info(f"downloading wd14 tagger model from hf_hub. id: {args.repo_id}")
         files = FILES
         if args.onnx:
+            files = ["selected_tags.csv"]
             files += FILES_ONNX
+        else:
+            for file in SUB_DIR_FILES:
+                hf_hub_download(
+                    args.repo_id,
+                    file,
+                    subfolder=SUB_DIR,
+                    cache_dir=os.path.join(args.model_dir, SUB_DIR),
+                    force_download=True,
+                    force_filename=file,
+                )
         for file in files:
             hf_hub_download(args.repo_id, file, cache_dir=args.model_dir, force_download=True, force_filename=file)
-        for file in SUB_DIR_FILES:
-            hf_hub_download(
-                args.repo_id,
-                file,
-                subfolder=SUB_DIR,
-                cache_dir=os.path.join(args.model_dir, SUB_DIR),
-                force_download=True,
-                force_filename=file,
-            )
     else:
         logger.info("using existing wd14 tagger model")
 
     # 画像を読み込む
 
     if args.onnx:
+        import torch
         import onnx
         import onnxruntime as ort
diff --git a/requirements.txt b/requirements.txt
index 279de350c..805f0501d 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -22,9 +22,12 @@ huggingface-hub==0.20.1
 # for WD14 captioning (tensorflow)
 # tensorflow==2.10.1
 # for WD14 captioning (onnx)
-# onnx==1.14.1
-# onnxruntime-gpu==1.16.0
-# onnxruntime==1.16.0
+# onnx==1.15.0
+# onnxruntime-gpu==1.17.1
+# onnxruntime==1.17.1
+# for cuda 12.1(default 11.8)
+# onnxruntime-gpu --extra-index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/onnxruntime-cuda-12/pypi/simple/
+
 # this is for onnx:
 # protobuf==3.20.3
 # open clip for SDXL
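For readers skimming the patch, here is a minimal sketch of what the revised download branch in `tag_images_by_wd14_tagger.py` now does: with `--onnx`, only `selected_tags.csv` plus the ONNX model files are fetched, and the TensorFlow SavedModel subfolder (`SUB_DIR_FILES`) is downloaded only in the non-ONNX branch instead of unconditionally. The helper name `download_wd14_model` and the constant values below are illustrative assumptions; only the branching and the `hf_hub_download` calls mirror the patch.

```python
import os

from huggingface_hub import hf_hub_download

# Illustrative placeholders -- the real values live at module level in
# tag_images_by_wd14_tagger.py and are not shown in this diff.
FILES = ["keras_metadata.pb", "saved_model.pb", "selected_tags.csv"]  # assumed
FILES_ONNX = ["model.onnx"]                                           # assumed
SUB_DIR = "variables"                                                 # assumed
SUB_DIR_FILES = ["variables.data-00000-of-00001", "variables.index"]  # assumed


def download_wd14_model(repo_id: str, model_dir: str, use_onnx: bool) -> None:
    """Hypothetical helper mirroring the patched download logic."""
    if use_onnx:
        # ONNX path: only the tag list and the ONNX files are needed,
        # so the TensorFlow SavedModel files are skipped entirely.
        files = ["selected_tags.csv"] + FILES_ONNX
    else:
        # TensorFlow path: additionally fetch the SavedModel subfolder,
        # which the pre-patch code downloaded in every case.
        files = FILES
        for file in SUB_DIR_FILES:
            hf_hub_download(
                repo_id,
                file,
                subfolder=SUB_DIR,
                cache_dir=os.path.join(model_dir, SUB_DIR),
                force_download=True,
                force_filename=file,
            )
    for file in files:
        hf_hub_download(repo_id, file, cache_dir=model_dir, force_download=True, force_filename=file)
```

The new `import torch` ahead of the onnxruntime import, together with the CUDA 12.1 extra-index URL added to requirements.txt, presumably lets onnxruntime-gpu resolve the CUDA libraries that ship with the installed PyTorch build rather than requiring a separate CUDA install.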