Skip to content

Commit

Permalink
Merge pull request #1192 from sdbds/main
Browse files Browse the repository at this point in the history
Add WDV3 support
  • Loading branch information
kohya-ss committed Mar 20, 2024
2 parents e281e86 + 6c51c97 commit 7da41be
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 12 deletions.
21 changes: 12 additions & 9 deletions finetune/tag_images_by_wd14_tagger.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,23 +86,26 @@ def main(args):
logger.info(f"downloading wd14 tagger model from hf_hub. id: {args.repo_id}")
files = FILES
if args.onnx:
files = ["selected_tags.csv"]
files += FILES_ONNX
else:
for file in SUB_DIR_FILES:
hf_hub_download(
args.repo_id,
file,
subfolder=SUB_DIR,
cache_dir=os.path.join(args.model_dir, SUB_DIR),
force_download=True,
force_filename=file,
)
for file in files:
hf_hub_download(args.repo_id, file, cache_dir=args.model_dir, force_download=True, force_filename=file)
for file in SUB_DIR_FILES:
hf_hub_download(
args.repo_id,
file,
subfolder=SUB_DIR,
cache_dir=os.path.join(args.model_dir, SUB_DIR),
force_download=True,
force_filename=file,
)
else:
logger.info("using existing wd14 tagger model")

# 画像を読み込む
if args.onnx:
import torch
import onnx
import onnxruntime as ort

Expand Down
9 changes: 6 additions & 3 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,12 @@ huggingface-hub==0.20.1
# for WD14 captioning (tensorflow)
# tensorflow==2.10.1
# for WD14 captioning (onnx)
# onnx==1.14.1
# onnxruntime-gpu==1.16.0
# onnxruntime==1.16.0
# onnx==1.15.0
# onnxruntime-gpu==1.17.1
# onnxruntime==1.17.1
# for cuda 12.1(default 11.8)
# onnxruntime-gpu --extra-index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/onnxruntime-cuda-12/pypi/simple/

# this is for onnx:
# protobuf==3.20.3
# open clip for SDXL
Expand Down

0 comments on commit 7da41be

Please sign in to comment.