Memory improvements
- no_grad
- release models/resources
- only load once, only convert once
- release instance variables
- use model_management to get the CUDA device (pattern sketched below)
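The pattern behind these bullets, as a minimal sketch; the helper names build_model and run_inference are illustrative, not part of the repo:

import torch
import model_management  # ComfyUI's device helper, importable when running inside ComfyUI

def caption_once(build_model, run_inference):
    device = model_management.get_torch_device()  # instead of hard-coding "cuda"
    model = build_model().to(device)              # load once, convert once
    with torch.no_grad():                         # skip autograd bookkeeping during inference
        result = run_inference(model)
    del model                                     # drop the reference so the weights can be freed
    torch.cuda.empty_cache()                      # return cached blocks to the CUDA allocator
    return result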
christian-byrne committed May 28, 2024
1 parent 8c75b47 commit e9ff300
Showing 4 changed files with 72 additions and 38 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,7 +1,7 @@
[project]
name = "img2txt-comfyui-nodes"
description = "img2txt auto captioning. Choose from models: BLIP, Llava, MiniCPM, MS-GIT. Use model combos and merge results. Specify questions to ask about images (medium, art style, background). Automatic model download/management. Supports Chinese 🇨🇳 questions/answers via MiniCPM."
version = "1.0.0"
version = "1.1.0"
license = "LICENSE"
dependencies = ["transformers>=4.36.0", "bitsandbytes>=0.43.0", "timm==0.9.10", "sentencepiece==0.1.99", "accelerate>=0.3.0"]

45 changes: 35 additions & 10 deletions src/blip_img2txt.py
@@ -7,6 +7,9 @@
BlipVisionConfig,
)

import torch
import model_management


class BLIPImg2Txt:
def __init__(
@@ -22,37 +25,59 @@ def __init__(
self.conditional_caption = conditional_caption
self.model_id = model_id

# Determine do_sample and num_beams
if temperature < 1.05 and temperature > 0.95:
do_sample = True
num_beams = 1 # Sampling does not use beam search
else:
do_sample = False
num_beams = (
search_beams if search_beams > 1 else 1
) # Use beam search if num_beams > 1

# Initialize text config kwargs
self.text_config_kwargs = {
"do_sample": do_sample,
# "max_new_tokens": max_words,
"max_length": max_words,
"min_length": min_words,
"num_beams": search_beams,
# "temperature": temperature,
"repetition_penalty": repetition_penalty,
"padding": "max_length",
}
if not do_sample:
self.text_config_kwargs["temperature"] = temperature
self.text_config_kwargs["num_beams"] = num_beams

def generate_caption(self, image: Image):
def generate_caption(self, image: Image.Image) -> str:
if image.mode != "RGB":
image = image.convert("RGB")

processor = BlipProcessor.from_pretrained(self.model_id)

# https://huggingface.co/docs/transformers/model_doc/blip#transformers.BlipTextConfig
# Update and apply configurations
config_text = BlipTextConfig.from_pretrained(self.model_id)
config_text.update(self.text_config_kwargs)
config_vision = BlipVisionConfig.from_pretrained(self.model_id)
config = BlipConfig.from_text_vision_configs(config_text, config_vision)

# Update model configuration
model = BlipForConditionalGeneration.from_pretrained(
self.model_id, config=config
)
model = model.to("cuda")
self.model_id,
config=config,
trust_remote_code=True,
torch_dtype=torch.float16,
).to(model_management.get_torch_device())

inputs = processor(
image,
self.conditional_caption,
return_tensors="pt",
).to("cuda")
).to(model_management.get_torch_device(), torch.float16)

with torch.no_grad():
out = model.generate(**inputs)
ret = processor.decode(out[0], skip_special_tokens=True)

del model
torch.cuda.empty_cache()

return processor.decode(model.generate(**inputs)[0], skip_special_tokens=True)
return ret
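For reference, the generation-config override used above can be exercised on its own; a minimal sketch, assuming an example BLIP checkpoint and example values:

from transformers import BlipConfig, BlipTextConfig, BlipVisionConfig

model_id = "Salesforce/blip-image-captioning-base"  # example checkpoint; the node passes its own model_id
text_cfg = BlipTextConfig.from_pretrained(model_id)
text_cfg.update({"max_length": 48, "min_length": 8, "repetition_penalty": 1.2})  # example values
vision_cfg = BlipVisionConfig.from_pretrained(model_id)
config = BlipConfig.from_text_vision_configs(text_cfg, vision_cfg)
# config is then passed to BlipForConditionalGeneration.from_pretrained(model_id, config=config)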
34 changes: 20 additions & 14 deletions src/llava_img2txt.py
@@ -1,5 +1,7 @@
import os
from PIL import Image
import torch
import model_management
from transformers import AutoProcessor, LlavaForConditionalGeneration


@@ -34,8 +36,8 @@ def __init__(

def generate_caption(
self,
raw_image: Image,
):
raw_image: Image.Image,
) -> str:
"""
Generate a caption for an image using the Llava model.
@@ -56,23 +58,27 @@ def generate_caption(

# model.to() is not supported for 4-bit or 8-bit bitsandbytes models. With 4-bit quantization, use the model as it is, since the model will already be set to the correct devices and casted to the correct `dtype`.
if torch.cuda.is_available() and not self.use_4bit:
model = model.to(0)
model = model.to(model_management.get_torch_device(), torch.float16)

processor = AutoProcessor.from_pretrained(self.model_id)
prompt_chunks = self.__get_prompt_chunks(chunk_size=4)

caption = ""
for prompt_list in prompt_chunks:
prompt = self.__get_single_answer_prompt(prompt_list)
inputs = processor(prompt, raw_image, return_tensors="pt").to(
0, torch.float16
)
output = model.generate(
**inputs, max_new_tokens=self.max_tokens_per_chunk, do_sample=False
)
decoded = processor.decode(output[0][2:], skip_special_tokens=True)
cleaned = self.clean_output(decoded)
caption += cleaned
with torch.no_grad():
for prompt_list in prompt_chunks:
prompt = self.__get_single_answer_prompt(prompt_list)
inputs = processor(prompt, raw_image, return_tensors="pt").to(
model_management.get_torch_device(), torch.float16
)
output = model.generate(
**inputs, max_new_tokens=self.max_tokens_per_chunk, do_sample=False
)
decoded = processor.decode(output[0][2:], skip_special_tokens=True)
cleaned = self.clean_output(decoded)
caption += cleaned

del model
torch.cuda.empty_cache()

return caption
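The comment above notes that .to() is skipped for 4-bit bitsandbytes models. For context, a hedged sketch of what a 4-bit Llava load can look like; the node's actual loading code is outside this hunk, so the checkpoint and arguments are assumptions:

import torch
from transformers import BitsAndBytesConfig, LlavaForConditionalGeneration

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,                     # quantize weights to 4-bit at load time
    bnb_4bit_compute_dtype=torch.float16,  # run matmuls in fp16
)
model = LlavaForConditionalGeneration.from_pretrained(
    "llava-hf/llava-1.5-7b-hf",            # example checkpoint
    quantization_config=bnb_config,
)
# No model.to(device) here: bitsandbytes already places the quantized weights on the GPU.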

29 changes: 16 additions & 13 deletions src/mini_cpm_img2txt.py
@@ -4,9 +4,8 @@


class MiniPCMImg2Txt:
MODEL_ID = "openbmb/MiniCPM-V-2"

def __init__(self, question_list: list[str], temperature: float = 0.7):
self.model_id = "openbmb/MiniCPM-V-2"
self.question_list = question_list
self.question_list = self.__create_question_list()
self.temperature = temperature
@@ -17,32 +16,36 @@ def __create_question_list(self) -> list:
ret.append({"role": "user", "content": q})
return ret

def generate_captions(self, raw_image: Image):
def generate_captions(self, raw_image: Image.Image) -> str:
model = AutoModel.from_pretrained(
"openbmb/MiniCPM-V-2", trust_remote_code=True, torch_dtype=torch.bfloat16
)
try:
# For Nvidia GPUs that support BF16 (like A100, H100, RTX3090)
model = model.to(device="cuda", dtype=torch.bfloat16)
except Exception as e:
except Exception:
# For Nvidia GPUs that do NOT support BF16 (like V100, T4, RTX2080)
model = model.to(device="cuda", dtype=torch.float16)

tokenizer = AutoTokenizer.from_pretrained(
MiniPCMImg2Txt.MODEL_ID, trust_remote_code=True
self.model_id, trust_remote_code=True
)
model.eval()

if raw_image.mode != "RGB":
raw_image = raw_image.convert("RGB")

res, context, _ = model.chat(
image=raw_image,
msgs=self.question_list,
context=None,
tokenizer=tokenizer,
sampling=True,
temperature=self.temperature,
)
with torch.no_grad():
res, _, _ = model.chat(
image=raw_image,
msgs=self.question_list,
context=None,
tokenizer=tokenizer,
sampling=True,
temperature=self.temperature,
)

del model
torch.cuda.empty_cache()

return res
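A possible alternative to the try/except dtype fallback above is to query bf16 support directly; this is a sketch only, not what the commit does:

import torch

dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
# model = model.to(device=model_management.get_torch_device(), dtype=dtype)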
