Skip to content

Commit

Permalink
fix: move openai-key to chatgpt ui; minor fix
Browse files Browse the repository at this point in the history
  • Loading branch information
Mikubill committed Mar 12, 2023
1 parent 92ad5e7 commit 6eaa008
Showing 1 changed file with 44 additions and 30 deletions.
74 changes: 44 additions & 30 deletions example/chatgpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,27 +157,34 @@ def inference(self, image_path, text):
image_mask = Image.fromarray(visual_mask)
return image_mask.resize(image.size)

class ImageEditing:
def __init__(self, device):
print("Initializing StableDiffusionInpaint to %s" % device)
self.device = device
self.mask_former = MaskFormer(device=self.device)
self.inpainting = StableDiffusionInpaintPipeline.from_pretrained("runwayml/stable-diffusion-inpainting",).to(device)

def remove_part_of_image(self, input):
image_path, to_be_removed_txt = input.split(",")
print(f'remove_part_of_image: to_be_removed {to_be_removed_txt}')
return self.replace_part_of_image(f"{image_path},{to_be_removed_txt},background")

def replace_part_of_image(self, input):
image_path, to_be_replaced_txt, replace_with_txt = input.split(",")
print(f'replace_part_of_image: replace_with_txt {replace_with_txt}')
original_image = Image.open(image_path)
mask_image = self.mask_former.inference(image_path, to_be_replaced_txt)
updated_image = self.inpainting(prompt=replace_with_txt, image=original_image, mask_image=mask_image).images[0]
updated_image_path = get_new_image_name(image_path, func_name="replace-something")
updated_image.save(updated_image_path)
return updated_image_path
# class ImageEditing:
# def __init__(self, device):
# print("Initializing StableDiffusionInpaint to %s" % device)
# self.device = device
# self.mask_former = MaskFormer(device=self.device)
# # self.inpainting = StableDiffusionInpaintPipeline.from_pretrained("runwayml/stable-diffusion-inpainting",).to(device)

# def remove_part_of_image(self, input):
# image_path, to_be_removed_txt = input.split(",")
# print(f'remove_part_of_image: to_be_removed {to_be_removed_txt}')
# return self.replace_part_of_image(f"{image_path},{to_be_removed_txt},background")

# def replace_part_of_image(self, input):
# image_path, to_be_replaced_txt, replace_with_txt = input.split(",")
# print(f'replace_part_of_image: replace_with_txt {replace_with_txt}')
# mask_image = self.mask_former.inference(image_path, to_be_replaced_txt)
# buffered = io.BytesIO()
# mask_image.save(buffered, format="JPEG")
# resp = do_webui_request(
# url=ENDPOINT + "/sdapi/v1/img2img",
# init_images=[readImage(image_path)],
# mask=b64encode(buffered.getvalue()).decode("utf-8"),
# prompt=replace_with_txt,
# )
# image = Image.open(io.BytesIO(base64.b64decode(resp["images"][0])))
# updated_image_path = get_new_image_name(image_path, func_name="replace-something")
# updated_image.save(updated_image_path)
# return updated_image_path

# class Pix2Pix:
# def __init__(self, device):
Expand Down Expand Up @@ -497,10 +504,9 @@ def get_answer_from_question_and_image(self, inputs):


class ConversationBot:
def __init__(self, openai_api_key):
def __init__(self):
print("Initializing VisualChatGPT")
self.llm = OpenAI(temperature=0, openai_api_key=openai_api_key)
# self.edit = ImageEditing(device="cuda:6")
self.edit = ImageEditing(device=device)
self.i2t = ImageCaptioning(device=device)
self.t2i = T2I(device=device)
self.image2canny = image2canny()
Expand Down Expand Up @@ -590,6 +596,9 @@ def __init__(self, openai_api_key):
Tool(name="Generate Image Condition On Pose Image", func=self.pose2image.inference,
description="useful when you want to generate a new real image from both the user desciption and a human pose image. like: generate a real image of a human from this human pose image, or generate a new real image of a human from this pose. "
"The input to this tool should be a comma seperated string of two, representing the image_path and the user description")]

def init_langchain(self, openai_api_key):
self.llm = OpenAI(temperature=0, openai_api_key=openai_api_key)
self.agent = initialize_agent(
self.tools,
self.llm,
Expand All @@ -600,7 +609,9 @@ def __init__(self, openai_api_key):
agent_kwargs={'prefix': VISUAL_CHATGPT_PREFIX, 'format_instructions': VISUAL_CHATGPT_FORMAT_INSTRUCTIONS, 'suffix': VISUAL_CHATGPT_SUFFIX}
)

def run_text(self, text, state):
def run_text(self, openai_api_key, text, state):
if not hasattr(self, "agent"):
self.init_langchain(openai_api_key)
print("===============Running run_text =============")
print("Inputs:", text, state)
print("======>Previous memory:\n %s" % self.agent.memory)
Expand All @@ -612,7 +623,9 @@ def run_text(self, text, state):
print("Outputs:", state)
return state, state

def run_image(self, image, state, txt):
def run_image(self, openai_api_key, image, state, txt):
if not hasattr(self, "agent"):
self.init_langchain(openai_api_key)
print("===============Running run_image =============")
print("Inputs:", image, state)
print("======>Previous memory:\n %s" % self.agent.memory)
Expand All @@ -639,8 +652,9 @@ def run_image(self, image, state, txt):

if __name__ == '__main__':
os.makedirs("image/", exist_ok=True)
bot = ConversationBot(openai_api_key="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")
bot = ConversationBot()
with gr.Blocks(css="#chatbot .overflow-y-auto{height:500px}") as demo:
openai_api_key = gr.Textbox(type="password", label="Enter your OpenAI API key here")
chatbot = gr.Chatbot(elem_id="chatbot", label="Visual ChatGPT")
state = gr.State([])
with gr.Row():
Expand All @@ -650,10 +664,10 @@ def run_image(self, image, state, txt):
clear = gr.Button("Clear️")
with gr.Column(scale=0.15, min_width=0):
btn = gr.UploadButton("Upload", file_types=["image"])

txt.submit(bot.run_text, [txt, state], [chatbot, state])
txt.submit(bot.run_text, [openai_api_key, txt, state], [chatbot, state])
txt.submit(lambda: "", None, txt)
btn.upload(bot.run_image, [btn, state, txt], [chatbot, state, txt])
btn.upload(bot.run_image, [openai_api_key, btn, state, txt], [chatbot, state, txt])
clear.click(bot.memory.clear)
clear.click(lambda: [], None, chatbot)
clear.click(lambda: [], None, state)
Expand Down

0 comments on commit 6eaa008

Please sign in to comment.