From f843e1065f8a9ea82fd082196b3a5b40a056e388 Mon Sep 17 00:00:00 2001
From: Your Name
Date: Tue, 21 Jun 2022 02:55:25 +0000
Subject: [PATCH] feature: add cog and replicate support

---
 cog.yaml   |  21 +++++++++++
 predict.py | 109 +++++++++++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 130 insertions(+)
 create mode 100644 cog.yaml
 create mode 100644 predict.py

diff --git a/cog.yaml b/cog.yaml
new file mode 100644
index 000000000..c1ed0bf1a
--- /dev/null
+++ b/cog.yaml
@@ -0,0 +1,21 @@
+image: "r8.im/nicholascelestin/latent-diffusion"
+build:
+  gpu: true
+  python_packages:
+    - "torch==1.11.0"
+    - "albumentations==0.4.3"
+    - "opencv-python==4.1.2.30"
+    - "pudb==2019.2"
+    - "imageio==2.9.0"
+    - "imageio-ffmpeg==0.4.2"
+    - "pytorch-lightning==1.5.10"
+    - "omegaconf==2.1.1"
+    - "test-tube==0.7.5"
+    - "streamlit==1.10.0"
+    - "einops==0.3.0"
+    - "torch-fidelity==0.3.0"
+    - "transformers==4.3.1"
+  run:
+    - "git clone https://github.com/CompVis/taming-transformers.git && cd taming-transformers && pip install -e . && cd .."
+    - "git clone https://github.com/openai/CLIP.git && cd CLIP && pip install -e . && cd .."
+predict: "predict.py:Predictor"
diff --git a/predict.py b/predict.py
new file mode 100644
index 000000000..9b110ac16
--- /dev/null
+++ b/predict.py
@@ -0,0 +1,109 @@
+import time
+import typing
+import uuid
+import os
+
+import numpy as np
+import torch
+from PIL import Image
+from cog import BasePredictor, Input, Path
+from einops import rearrange
+from omegaconf import OmegaConf
+from tqdm import trange
+
+from ldm.models.diffusion.ddim import DDIMSampler
+from ldm.models.diffusion.plms import PLMSSampler
+from ldm.util import instantiate_from_config
+
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+def load_model_from_config(config, ckpt, verbose=False):
+    print(f"Loading model from {ckpt}")
+    pl_sd = torch.load(ckpt, map_location="cpu")
+    sd = pl_sd["state_dict"]
+    model = instantiate_from_config(config.model)
+    m, u = model.load_state_dict(sd, strict=False)
+    if len(m) > 0 and verbose:
+        print("missing keys:")
+        print(m)
+    if len(u) > 0 and verbose:
+        print("unexpected keys:")
+        print(u)
+    model.cuda()
+    model.eval()
+    return model
+
+
+class Predictor(BasePredictor):
+    def setup(self):
+        start_time = time.time()
+        print(f'Performing setup!')
+
+        config = OmegaConf.load("configs/latent-diffusion/txt2img-1p4B-eval.yaml")
+        model = load_model_from_config(config, "models/ldm/text2img-large/model.ckpt")
+        print(f'Model loaded at {time.time() - start_time}')
+
+        device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+        model = model.to(device)
+        print(f'Model loaded on device at {time.time() - start_time}')
+
+        sampler = PLMSSampler(model)
+        print(f'Sampler loaded at {time.time() - start_time}')
+
+        self.model = model
+        self.sampler = sampler
+
+        print(f'Setup complete at {time.time() - start_time}')
+
+    def predict(
+        self,
+        prompt: str = Input(description="Image prompt"),
+        scale: float = Input(description="Unconditional guidance, increase for improved quality and less diversity",
+                             default=5.0),
+        steps: int = Input(description="Number of diffusion steps", default=50),
+        eta: float = Input(description="ddim_eta (recommend leaving at default of 0 for faster sampling)",
+                           default=0),
+        plms: bool = Input(description="Sampling method requiring fewer steps (e.g. 25) to get good quality images",
+                           default=True),
+        batch_size: int = Input(description="Number of images to generate per batch", default=4),
+        batches: int = Input(description="Number of batches", default=1),
+        width: int = Input(description="Width of images (use a multiple of 8 e.g. 256)", default=256),
+        height: int = Input(description="Height of images (use a multiple of 8 e.g. 256)", default=256)
+    ) -> typing.List[Path]:
+
+        print(f'Prediction started!')
+
+        if plms:
+            self.sampler = PLMSSampler(self.model)
+        else:
+            self.sampler = DDIMSampler(self.model)
+        print(f'Sampler loaded')
+
+        # Collect the generated image paths; Cog returns the list as the prediction output.
+        all_samples = list()
+        with torch.no_grad():
+            with self.model.ema_scope():
+                uc = None
+                if scale != 1.0:
+                    uc = self.model.get_learned_conditioning(batch_size * [""])
+                for n in trange(batches, desc="Sampling"):
+                    c = self.model.get_learned_conditioning(batch_size * [prompt])
+                    shape = [4, height // 8, width // 8]
+                    samples_ddim, _ = self.sampler.sample(S=steps,
+                                                          conditioning=c,
+                                                          batch_size=batch_size,
+                                                          shape=shape,
+                                                          verbose=False,
+                                                          unconditional_guidance_scale=scale,
+                                                          unconditional_conditioning=uc,
+                                                          eta=eta)
+
+                    x_samples_ddim = self.model.decode_first_stage(samples_ddim)
+                    x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
+
+                    for x_sample in x_samples_ddim:
+                        x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
+                        image_path = f'{uuid.uuid4()}.png'
+                        Image.fromarray(x_sample.astype(np.uint8)).save(image_path)
+                        all_samples.append(Path(image_path))
+        return all_samples
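
Usage sketch (assumes Cog is installed and the text2img-large weights have been downloaded to models/ldm/text2img-large/model.ckpt, the path predict.py expects; the prompt value is only an example):

    cog predict -i prompt="a painting of a virus monster playing guitar" -i steps=50 -i batch_size=1 -i batches=1
    cog push r8.im/nicholascelestin/latent-diffusion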