diff --git a/example/chatgpt.py b/example/chatgpt.py index 9a1edc897..0814f7b59 100644 --- a/example/chatgpt.py +++ b/example/chatgpt.py @@ -157,27 +157,34 @@ def inference(self, image_path, text): image_mask = Image.fromarray(visual_mask) return image_mask.resize(image.size) -class ImageEditing: - def __init__(self, device): - print("Initializing StableDiffusionInpaint to %s" % device) - self.device = device - self.mask_former = MaskFormer(device=self.device) - self.inpainting = StableDiffusionInpaintPipeline.from_pretrained("runwayml/stable-diffusion-inpainting",).to(device) - - def remove_part_of_image(self, input): - image_path, to_be_removed_txt = input.split(",") - print(f'remove_part_of_image: to_be_removed {to_be_removed_txt}') - return self.replace_part_of_image(f"{image_path},{to_be_removed_txt},background") - - def replace_part_of_image(self, input): - image_path, to_be_replaced_txt, replace_with_txt = input.split(",") - print(f'replace_part_of_image: replace_with_txt {replace_with_txt}') - original_image = Image.open(image_path) - mask_image = self.mask_former.inference(image_path, to_be_replaced_txt) - updated_image = self.inpainting(prompt=replace_with_txt, image=original_image, mask_image=mask_image).images[0] - updated_image_path = get_new_image_name(image_path, func_name="replace-something") - updated_image.save(updated_image_path) - return updated_image_path +# class ImageEditing: +# def __init__(self, device): +# print("Initializing StableDiffusionInpaint to %s" % device) +# self.device = device +# self.mask_former = MaskFormer(device=self.device) +# # self.inpainting = StableDiffusionInpaintPipeline.from_pretrained("runwayml/stable-diffusion-inpainting",).to(device) + +# def remove_part_of_image(self, input): +# image_path, to_be_removed_txt = input.split(",") +# print(f'remove_part_of_image: to_be_removed {to_be_removed_txt}') +# return self.replace_part_of_image(f"{image_path},{to_be_removed_txt},background") + +# def replace_part_of_image(self, input): +# image_path, to_be_replaced_txt, replace_with_txt = input.split(",") +# print(f'replace_part_of_image: replace_with_txt {replace_with_txt}') +# mask_image = self.mask_former.inference(image_path, to_be_replaced_txt) +# buffered = io.BytesIO() +# mask_image.save(buffered, format="JPEG") +# resp = do_webui_request( +# url=ENDPOINT + "/sdapi/v1/img2img", +# init_images=[readImage(image_path)], +# mask=b64encode(buffered.getvalue()).decode("utf-8"), +# prompt=replace_with_txt, +# ) +# image = Image.open(io.BytesIO(base64.b64decode(resp["images"][0]))) +# updated_image_path = get_new_image_name(image_path, func_name="replace-something") +# updated_image.save(updated_image_path) +# return updated_image_path # class Pix2Pix: # def __init__(self, device): @@ -497,10 +504,9 @@ def get_answer_from_question_and_image(self, inputs): class ConversationBot: - def __init__(self, openai_api_key): + def __init__(self): print("Initializing VisualChatGPT") - self.llm = OpenAI(temperature=0, openai_api_key=openai_api_key) - # self.edit = ImageEditing(device="cuda:6") + self.edit = ImageEditing(device=device) self.i2t = ImageCaptioning(device=device) self.t2i = T2I(device=device) self.image2canny = image2canny() @@ -590,6 +596,9 @@ def __init__(self, openai_api_key): Tool(name="Generate Image Condition On Pose Image", func=self.pose2image.inference, description="useful when you want to generate a new real image from both the user desciption and a human pose image. like: generate a real image of a human from this human pose image, or generate a new real image of a human from this pose. " "The input to this tool should be a comma seperated string of two, representing the image_path and the user description")] + + def init_langchain(self, openai_api_key): + self.llm = OpenAI(temperature=0, openai_api_key=openai_api_key) self.agent = initialize_agent( self.tools, self.llm, @@ -600,7 +609,9 @@ def __init__(self, openai_api_key): agent_kwargs={'prefix': VISUAL_CHATGPT_PREFIX, 'format_instructions': VISUAL_CHATGPT_FORMAT_INSTRUCTIONS, 'suffix': VISUAL_CHATGPT_SUFFIX} ) - def run_text(self, text, state): + def run_text(self, openai_api_key, text, state): + if not hasattr(self, "agent"): + self.init_langchain(openai_api_key) print("===============Running run_text =============") print("Inputs:", text, state) print("======>Previous memory:\n %s" % self.agent.memory) @@ -612,7 +623,9 @@ def run_text(self, text, state): print("Outputs:", state) return state, state - def run_image(self, image, state, txt): + def run_image(self, openai_api_key, image, state, txt): + if not hasattr(self, "agent"): + self.init_langchain(openai_api_key) print("===============Running run_image =============") print("Inputs:", image, state) print("======>Previous memory:\n %s" % self.agent.memory) @@ -639,8 +652,9 @@ def run_image(self, image, state, txt): if __name__ == '__main__': os.makedirs("image/", exist_ok=True) - bot = ConversationBot(openai_api_key="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") + bot = ConversationBot() with gr.Blocks(css="#chatbot .overflow-y-auto{height:500px}") as demo: + openai_api_key = gr.Textbox(type="password", label="Enter your OpenAI API key here") chatbot = gr.Chatbot(elem_id="chatbot", label="Visual ChatGPT") state = gr.State([]) with gr.Row(): @@ -650,10 +664,10 @@ def run_image(self, image, state, txt): clear = gr.Button("Clear️") with gr.Column(scale=0.15, min_width=0): btn = gr.UploadButton("Upload", file_types=["image"]) - - txt.submit(bot.run_text, [txt, state], [chatbot, state]) + + txt.submit(bot.run_text, [openai_api_key, txt, state], [chatbot, state]) txt.submit(lambda: "", None, txt) - btn.upload(bot.run_image, [btn, state, txt], [chatbot, state, txt]) + btn.upload(bot.run_image, [openai_api_key, btn, state, txt], [chatbot, state, txt]) clear.click(bot.memory.clear) clear.click(lambda: [], None, chatbot) clear.click(lambda: [], None, state)