diff --git a/README.md b/README.md
index cc17013..1be8507 100644
--- a/README.md
+++ b/README.md
@@ -63,6 +63,7 @@ images = model.generate_text2img(
 [![Framework: PyTorch](https://img.shields.io/badge/Framework-PyTorch-orange.svg)](https://pytorch.org/)
 [![Huggingface space](https://img.shields.io/badge/🤗-Huggingface-yello.svg)](https://huggingface.co/sberbank-ai/Kandinsky_2.1)
 [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1xSbu-b-EwYd6GdaFPRVgvXBX_mciZ41e?usp=sharing)
+[![Replicate](https://replicate.com/cjwbw/kandinsky-2/badge)](https://replicate.com/cjwbw/kandinsky-2)
 
 [Habr post](https://habr.com/ru/company/sberbank/blog/725282/)
 
diff --git a/cog.yaml b/cog.yaml
new file mode 100644
index 0000000..d9ca936
--- /dev/null
+++ b/cog.yaml
@@ -0,0 +1,28 @@
+# Configuration for Cog ⚙️
+# Reference: https://github.com/replicate/cog/blob/main/docs/yaml.md
+
+build:
+  gpu: true
+  cuda: "11.6"
+  python_version: "3.10"
+  system_packages:
+    - "libgl1-mesa-glx"
+    - "libglib2.0-0"
+  python_packages:
+    - "torch==1.13.1"
+    - "sentencepiece==0.1.97"
+    - "accelerate==0.16.0"
+    - "Pillow==9.5.0"
+    - "attrs==22.2.0"
+    - "opencv-python==4.7.0.72"
+    - git+https://github.com/openai/CLIP.git
+    - "tqdm==4.65.0"
+    - "ftfy==6.1.1"
+    - "blobfile==2.0.1"
+    - "transformers==4.23.1"
+    - "torchvision==0.14.1"
+    - "omegaconf==2.3.0"
+    - "pytorch_lightning==2.0.1"
+    - "einops==0.6.0"
+
+predict: "predict.py:Predictor"
diff --git a/predict.py b/predict.py
new file mode 100644
index 0000000..770508b
--- /dev/null
+++ b/predict.py
@@ -0,0 +1,65 @@
+from typing import List
+from cog import BasePredictor, Input, Path
+from kandinsky2 import get_kandinsky2
+
+
+class Predictor(BasePredictor):
+    def setup(self):
+        self.model = get_kandinsky2(
+            "cuda",
+            task_type="text2img",
+            cache_dir="./kandinsky2-weights",
+            model_version="2.1",
+            use_flash_attention=False,
+        )
+
+    def predict(
+        self,
+        prompt: str = Input(description="Input Prompt", default="red cat, 4k photo"),
+        num_inference_steps: int = Input(
+            description="Number of denoising steps", ge=1, le=500, default=50
+        ),
+        guidance_scale: float = Input(
+            description="Scale for classifier-free guidance", ge=1, le=20, default=4
+        ),
+        scheduler: str = Input(
+            description="Choose a scheduler",
+            default="p_sampler",
+            choices=["ddim_sampler", "p_sampler", "plms_sampler"],
+        ),
+        prior_cf_scale: int = Input(default=4),
+        prior_steps: str = Input(default="5"),
+        width: int = Input(
+            description="Choose width. Lower the setting if out of memory.",
+            default=512,
+            choices=[256, 288, 432, 512, 576, 768, 1024],
+        ),
+        height: int = Input(
+            description="Choose height. Lower the setting if out of memory.",
+            default=512,
+            choices=[256, 288, 432, 512, 576, 768, 1024],
+        ),
+        batch_size: int = Input(
+            description="Choose batch size. Lower the setting if out of memory.",
+            default=1,
+            choices=[1, 2, 3, 4],
+        ),
+    ) -> List[Path]:
+        images = self.model.generate_text2img(
+            prompt,
+            num_steps=num_inference_steps,
+            batch_size=batch_size,
+            guidance_scale=guidance_scale,
+            h=height,
+            w=width,
+            sampler=scheduler,
+            prior_cf_scale=prior_cf_scale,
+            prior_steps=prior_steps,
+        )
+        output = []
+        for i, im in enumerate(images):
+            out = f"/tmp/out_{i}.png"
+            im.save(out)
+            im.save(f"out_{i}.png")
+            output.append(Path(out))
+        return output
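
The new predict.py exposes a standard Cog Predictor, so once the image is built it can be driven via the Cog CLI, e.g. `cog predict -i prompt="red cat, 4k photo"`. As a hedged local smoke-test sketch (not part of the patch itself), the class can also be instantiated directly; this assumes a CUDA GPU, the kandinsky2 package, and that the 2.1 weights can be cached into ./kandinsky2-weights. All arguments are passed explicitly here because the cog.Input defaults are only resolved by the Cog runtime, not by a plain Python call.

# Minimal local smoke test of the Cog predictor added above (assumption:
# run from the repo root with GPU, kandinsky2, and weights available).
from predict import Predictor

predictor = Predictor()
predictor.setup()  # downloads/loads the Kandinsky 2.1 text2img weights onto the GPU

paths = predictor.predict(
    prompt="red cat, 4k photo",
    num_inference_steps=50,
    guidance_scale=4.0,
    scheduler="p_sampler",  # one of: ddim_sampler, p_sampler, plms_sampler
    prior_cf_scale=4,
    prior_steps="5",
    width=512,
    height=512,
    batch_size=1,
)
print(paths)  # e.g. [Path('/tmp/out_0.png')]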