diff --git a/docs/user-guide/draftp.rst b/docs/user-guide/draftp.rst index 3c67b8553..b6fdb8d3c 100644 --- a/docs/user-guide/draftp.rst +++ b/docs/user-guide/draftp.rst @@ -58,7 +58,7 @@ You can then run the following snipet to convert it to a ``.tar`` file: Reward Model ############ -Currently, we only have support for `Pickscore `__ reward model. Since Pickscore is a CLIP-based model, +Currently, we only have support for `Pickscore-style `__ reward models (PickScore/HPSv2). Since Pickscore is a CLIP-based model, you can use the `conversion script `__ from NeMo to convert it from huggingface to NeMo. DRaFT+ Training @@ -81,8 +81,9 @@ To launch reward model training, you must have checkpoints for `UNet `__ and `sd_lora_infer.py `__ scripts from the NeMo codebase. The generated images with the fine-tuned model should have -better prompt alignment and aesthetic quality. \ No newline at end of file +better prompt alignment and aesthetic quality. + +User-controllable finetuning with Annealed Importance Guidance (AIG) +%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% + +AIG provides the inference-time flexibility to interpolate between the base Stable Diffusion model (with low rewards and high diversity) and the DRaFT-finetuned model (with high rewards and low diversity) to obtain images with high rewards and high diversity. AIG inference is easily done by specifying comma-separated ``weight_type`` strategies to interpolate between the base and finetuned model. + +.. tab-set:: + .. tab-item:: AIG on Stable Diffusion XL + :sync: key2 + + Weight type of ``base`` uses the base model for AIG, ``draft`` uses the finetuned model (no interpolation is done in either case). + Weight type of the form ``power_<p>`` (for example ``power_2.0``) interpolates with weight ``(step / total_steps) ** p`` on the finetuned model, the decay schedule specified in the AIG paper. + + To run AIG inference on the terminal directly: + + .. 
code-block:: bash + + NUMNODES=1 + LR=${LR:=0.00025} + INF_STEPS=${INF_STEPS:=25} + KL_COEF=${KL_COEF:=0.1} + ETA=${ETA:=0.0} + DATASET=${DATASET:="pickapic50k.tar"} + MICRO_BS=${MICRO_BS:=1} + GRAD_ACCUMULATION=${GRAD_ACCUMULATION:=4} + PEFT=${PEFT:="sdlora"} + NUM_DEVICES=${NUM_DEVICES:=8} + GLOBAL_BATCH_SIZE=$((MICRO_BS*NUM_DEVICES*GRAD_ACCUMULATION*NUMNODES)) + LOG_WANDB=${LOG_WANDB:="False"} + + echo "additional kwargs: ${ADDITIONAL_KWARGS}" + + WANDB_NAME=SDXL_Draft_annealing + WEBDATASET_PATH=/path/to/${DATASET} + + CONFIG_PATH="/opt/nemo-aligner/examples/mm/stable_diffusion/conf" + CONFIG_NAME=${CONFIG_NAME:="draftp_sdxl"} + UNET_CKPT="/path/to/unet.ckpt" + VAE_CKPT="/path/to/vae.ckpt" + RM_CKPT="/path/to/reward_model.nemo" + PROMPT=${PROMPT:="Bananas growing on an apple tree"} + DIR_SAVE_CKPT_PATH=/path/to/explicit_log_dir + + if [ ! -z "${ACT_CKPT}" ]; then + ACT_CKPT="model.activation_checkpointing=$ACT_CKPT " + echo $ACT_CKPT + fi + + EVAL_SCRIPT=${EVAL_SCRIPT:-"anneal_sdxl.py"} + export DEVICE="0,1,2,3,4,5,6,7" && echo "Running DRaFT+ on ${DEVICE}" && export HYDRA_FULL_ERROR=1 + set -x + CUDA_VISIBLE_DEVICES="${DEVICE}" torchrun --nproc_per_node=$NUM_DEVICES /opt/nemo-aligner/examples/mm/stable_diffusion/${EVAL_SCRIPT} \ + --config-path=${CONFIG_PATH} \ + --config-name=${CONFIG_NAME} \ + model.optim.lr=${LR} \ + model.optim.weight_decay=0.0005 \ + model.optim.sched.warmup_steps=0 \ + model.sampling.base.steps=${INF_STEPS} \ + model.kl_coeff=${KL_COEF} \ + model.truncation_steps=1 \ + trainer.draftp_sd.max_epochs=5 \ + trainer.draftp_sd.max_steps=10000 \ + trainer.draftp_sd.save_interval=200 \ + trainer.draftp_sd.val_check_interval=20 \ + trainer.draftp_sd.gradient_clip_val=10.0 \ + model.micro_batch_size=${MICRO_BS} \ + model.global_batch_size=${GLOBAL_BATCH_SIZE} \ + model.peft.peft_scheme=${PEFT} \ + model.data.webdataset.local_root_path=$WEBDATASET_PATH \ + rm.model.restore_from_path=${RM_CKPT} \ + trainer.devices=${NUM_DEVICES} \ + 
trainer.num_nodes=${NUMNODES} \ + rm.trainer.devices=${NUM_DEVICES} \ + rm.trainer.num_nodes=${NUMNODES} \ + +prompt="${PROMPT}" \ + exp_manager.create_wandb_logger=${LOG_WANDB} \ + model.first_stage_config.from_pretrained=${VAE_CKPT} \ + model.first_stage_config.from_NeMo=True \ + model.unet_config.from_pretrained=${UNET_CKPT} \ + model.unet_config.from_NeMo=True \ + $ACT_CKPT \ + exp_manager.wandb_logger_kwargs.name=${WANDB_NAME} \ + exp_manager.resume_if_exists=True \ + exp_manager.explicit_log_dir=${DIR_SAVE_CKPT_PATH} \ + exp_manager.wandb_logger_kwargs.project=${PROJECT} +weight_type='draft,base,power_2.0' + + + diff --git a/examples/mm/stable_diffusion/anneal_sdxl.py b/examples/mm/stable_diffusion/anneal_sdxl.py new file mode 100644 index 000000000..ab87315cf --- /dev/null +++ b/examples/mm/stable_diffusion/anneal_sdxl.py @@ -0,0 +1,324 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+import os
+from copy import deepcopy
+from functools import partial
+
+import numpy as np
+import torch
+import torch.distributed
+import torch.multiprocessing as mp
+from megatron.core import parallel_state
+from megatron.core.tensor_parallel.random import get_cuda_rng_tracker, get_data_parallel_rng_tracker_name
+from megatron.core.utils import divide
+from omegaconf.omegaconf import OmegaConf, open_dict
+from packaging.version import Version
+from PIL import Image
+from torch import nn
+
+# checkpointing
+from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
+    CheckpointImpl,
+    apply_activation_checkpointing,
+    checkpoint_wrapper,
+)
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+
+from nemo.collections.multimodal.models.text_to_image.stable_diffusion.diffusion_engine import (
+    DiffusionEngine,
+    MegatronDiffusionEngine,
+)
+from nemo.collections.multimodal.models.text_to_image.stable_diffusion.ldm.autoencoder import (
+    AutoencoderKL,
+    AutoencoderKLInferenceWrapper,
+)
+from nemo.collections.multimodal.models.text_to_image.stable_diffusion.ldm.ddpm import (
+    LatentDiffusion,
+    MegatronLatentDiffusion,
+)
+from nemo.collections.multimodal.modules.stable_diffusion.diffusionmodules.model import (
+    AttnBlock,
+    Decoder,
+    Encoder,
+    ResnetBlock,
+)
+from nemo.collections.multimodal.modules.stable_diffusion.diffusionmodules.openaimodel import (
+    ResBlock,
+    SpatialTransformer,
+    TimestepEmbedSequential,
+    UNetModel,
+)
+from nemo.collections.multimodal.modules.stable_diffusion.encoders.modules import (
+    FrozenCLIPEmbedder,
+    FrozenOpenCLIPEmbedder,
+    FrozenOpenCLIPEmbedder2,
+)
+from nemo.collections.nlp.modules.common.megatron.adapters.parallel_adapters import ParallelLinearAdapter
+from nemo.collections.nlp.parts.megatron_trainer_builder import MegatronTrainerBuilder
+from nemo.collections.nlp.parts.nlp_overrides import NLPDDPStrategy, NLPFSDPStrategy
+
+# from nemo.collections.nlp.parts.megatron_trainer_builder import MegatronStableDiffusionTrainerBuilder
+from nemo.collections.nlp.parts.peft_config import PEFT_CONFIG_MAP
+from nemo.core.config import hydra_runner
+from nemo.utils import logging
+from nemo.utils.exp_manager import exp_manager
+from nemo_aligner.algorithms.supervised import SupervisedTrainer
+from nemo_aligner.data.mm import text_webdataset
+from nemo_aligner.data.nlp.builders import build_dataloader
+from nemo_aligner.models.mm.stable_diffusion.image_text_rms import MegatronCLIPRewardModel, get_reward_model
+from nemo_aligner.models.mm.stable_diffusion.megatron_sdxl_draftp_model import MegatronSDXLDRaFTPModel
+from nemo_aligner.utils.distributed import Timer
+from nemo_aligner.utils.train_script_utils import (
+    CustomLoggerWrapper,
+    add_custom_checkpoint_callback,
+    extract_optimizer_scheduler_from_ptl_model,
+    init_distributed,
+    init_peft,
+    init_using_ptl,
+    retrieve_custom_trainer_state_dict,
+    temp_pop_from_config,
+)
+
+mp.set_start_method("spawn", force=True)
+
+
+# TODO: this functionality should go into NeMo
+# Specifically, the NeMo MegatronTrainerBuilder must also accept extra FSDP wrap modules so that it doesnt need to be subclassed
+class MegatronStableDiffusionTrainerBuilder(MegatronTrainerBuilder):
+    """Builder for SD model Trainer with overrides."""
+
+    def _training_strategy(self) -> NLPDDPStrategy:
+        """
+        Returns a DDP or a FSDP strategy passed to Trainer.strategy. Copied from `sd_xl_train.py`
+
+        FSDP is selected when ``cfg.model.fsdp`` is truthy; in that case the distributed fused Adam
+        optimizer is rejected and ``megatron_amp_O2`` is forced to ``False``.
+        """
+        if self.cfg.model.get("fsdp", False):
+            logging.info("FSDP.")
+            assert (
+                not self.cfg.model.optim.get("name") == "distributed_fused_adam"
+            ), "Distributed optimizer cannot be used with FSDP."
+            if self.cfg.model.get("megatron_amp_O2", False):
+                logging.info("Torch FSDP is not compatible with O2 precision recipe. Setting O2 `False`.")
+                self.cfg.model.megatron_amp_O2 = False
+
+            # Check if its a full-finetuning or PEFT
+            return NLPFSDPStrategy(
+                limit_all_gathers=self.cfg.model.get("fsdp_limit_all_gathers", True),
+                sharding_strategy=self.cfg.model.get("fsdp_sharding_strategy", "full"),
+                cpu_offload=self.cfg.model.get("fsdp_cpu_offload", False),  # offload on is not supported
+                grad_reduce_dtype=self.cfg.model.get("fsdp_grad_reduce_dtype", 32),
+                precision=self.cfg.trainer.precision,
+                ## nn Sequential is supposed to capture the `t_embed`, `label_emb`, `out` layers in the unet
+                extra_fsdp_wrap_module={
+                    UNetModel,
+                    TimestepEmbedSequential,
+                    Decoder,
+                    ResnetBlock,
+                    AttnBlock,
+                    nn.Sequential,
+                    MegatronCLIPRewardModel,
+                    FrozenOpenCLIPEmbedder,
+                    FrozenOpenCLIPEmbedder2,
+                    FrozenCLIPEmbedder,
+                    ParallelLinearAdapter,
+                },
+                use_orig_params=False,
+                set_buffer_dtype=self.cfg.get("fsdp_set_buffer_dtype", None),
+            )
+
+        return NLPDDPStrategy(
+            no_ddp_communication_hook=(not self.cfg.model.get("ddp_overlap")),
+            gradient_as_bucket_view=self.cfg.model.gradient_as_bucket_view,
+            find_unused_parameters=False,
+        )
+
+
+def resolve_and_create_trainer(cfg, pop_trainer_key):
+    """Resolve ``cfg`` and build the PTL trainer.
+
+    ``pop_trainer_key`` is removed from ``cfg.trainer`` while the trainer is constructed and is
+    restored afterwards.
+    """
+    OmegaConf.resolve(cfg)
+    with temp_pop_from_config(cfg.trainer, pop_trainer_key):
+        return MegatronStableDiffusionTrainerBuilder(cfg).create_trainer()
+
+
+def _make_weighing_fn(wt_type):
+    """Return the schedule ``(sigma, sigma_next, i, total) -> w`` selected by ``wt_type``.
+
+    ``w`` is the weight passed to ``annealed_guidance`` for the fine-tuned (DRaFT) model at step
+    ``i`` of ``total``. Supported values: ``base`` (always 0), ``draft`` (always 1), ``linear``
+    (``i / total``), ``power_<p>`` (``(i / total) ** p``) and ``step_<f>`` (0 until ``i / total``
+    reaches ``f``, then 1).
+
+    Raises:
+        ValueError: if ``wt_type`` matches none of the supported forms.
+    """
+    if wt_type == "base":
+        # dummy function that assigns a value of 0 all the time
+        logging.info("using the base model")
+        return lambda sigma, sigma_next, i, total: 0
+    if wt_type == "linear":
+        return lambda sigma, sigma_next, i, total: i * 1.0 / total
+    if wt_type == "draft":
+        return lambda sigma, sigma_next, i, total: 1
+    if wt_type.startswith("power"):  # its of the form power_{power}
+        exponent = float(wt_type.split("_")[1])
+        return lambda sigma, sigma_next, i, total: (i * 1.0 / total) ** exponent
+    if wt_type.startswith("step"):  # use a step function (step_{p})
+        frac = float(wt_type.split("_")[1])
+        return lambda sigma, sigma_next, i, total: float((i * 1.0 / total) >= frac)
+    raise ValueError(f"invalid weighing type: {wt_type}")
+
+
+@hydra_runner(config_path="conf", config_name="draftp_sdxl")
+def main(cfg) -> None:
+    """Generate images with annealed guidance for every requested weighing schedule.
+
+    Builds the trainer and the SDXL DRaFT+ model, then for each entry of ``cfg.weight_type``
+    (comma-separated string or list; a built-in set of schedules when unset) samples images for
+    every validation batch, or only for ``cfg.prompt`` when it is given. One PNG and one prompt
+    text file per sample are written under
+    ``<exp_manager.explicit_log_dir>/annealed_outputs_sdxl_<weight_type>/``.
+    """
+    logging.info("\n\n************** Experiment configuration ***********")
+    logging.info(f"\n{OmegaConf.to_yaml(cfg)}")
+
+    # set cuda device for each process
+    local_rank = int(os.environ.get("LOCAL_RANK", 0))
+    torch.cuda.set_device(local_rank)
+
+    # turn off wandb logging
+    cfg.exp_manager.create_wandb_logger = False
+
+    if Version(torch.__version__) >= Version("1.12"):
+        torch.backends.cuda.matmul.allow_tf32 = True
+    cfg.model.data.train.dataset_path = [
+        cfg.model.data.webdataset.local_root_path for _ in range(cfg.trainer.devices * cfg.trainer.num_nodes)
+    ]
+    cfg.model.data.validation.dataset_path = [
+        cfg.model.data.webdataset.local_root_path for _ in range(cfg.trainer.devices * cfg.trainer.num_nodes)
+    ]
+
+    trainer = resolve_and_create_trainer(cfg, "draftp_sd")
+    exp_manager(trainer, cfg.exp_manager)
+    logger = CustomLoggerWrapper(trainer.loggers)
+    # Instantiating the model here
+    ptl_model = MegatronSDXLDRaFTPModel(cfg.model, trainer).to(torch.cuda.current_device())
+    init_peft(ptl_model, cfg.model)  # init peft
+
+    trainer_restore_path = trainer.ckpt_path
+
+    if trainer_restore_path is not None:
+        custom_trainer_state_dict = retrieve_custom_trainer_state_dict(trainer)
+        consumed_samples = custom_trainer_state_dict["consumed_samples"]
+    else:
+        custom_trainer_state_dict = None
+        consumed_samples = 0
+
+    init_distributed(trainer, ptl_model, cfg.model.get("transformer_engine", False))
+
+    # only the validation split is used; the train split is discarded
+    _, validation_ds = text_webdataset.build_train_valid_datasets(cfg.model.data, consumed_samples=consumed_samples)
+    validation_ds = [d["captions"] for d in validation_ds]
+
+    val_dataloader = build_dataloader(
+        cfg,
+        dataset=validation_ds,
+        consumed_samples=consumed_samples,
+        mbs=cfg.model.micro_batch_size,
+        gbs=cfg.model.global_batch_size,
+        load_gbs=True,
+    )
+    init_using_ptl(trainer, ptl_model, val_dataloader, validation_ds)
+
+    if cfg.model.get("activation_checkpointing", False):
+        # call activation checkpointing here
+        # checkpoint wrapper
+        logging.info("Applying activation checkpointing on UNet and Decoder.")
+        non_reentrant_wrapper = partial(checkpoint_wrapper, checkpoint_impl=CheckpointImpl.NO_REENTRANT)
+
+        def checkpoint_check_fn(module):
+            return isinstance(module, (Decoder, UNetModel, MegatronCLIPRewardModel))
+
+        apply_activation_checkpointing(
+            ptl_model, checkpoint_wrapper_fn=non_reentrant_wrapper, check_fn=checkpoint_check_fn
+        )
+
+    optimizer, scheduler = extract_optimizer_scheduler_from_ptl_model(ptl_model)
+
+    ckpt_callback = add_custom_checkpoint_callback(trainer, ptl_model)
+
+    logger.log_hyperparams(OmegaConf.to_container(cfg))
+
+    torch.distributed.barrier()
+
+    # NOTE(review): this repeats the add_custom_checkpoint_callback call made before the barrier;
+    # kept as-is because the helper's side effects are not visible here - confirm one call suffices.
+    ckpt_callback = add_custom_checkpoint_callback(trainer, ptl_model)
+    timer = Timer(cfg.exp_manager.get("max_time_per_run", None) if cfg.exp_manager else None)
+
+    draft_p_trainer = SupervisedTrainer(
+        cfg=cfg.trainer.draftp_sd,
+        model=ptl_model,
+        optimizer=optimizer,
+        scheduler=scheduler,
+        train_dataloader=val_dataloader,
+        val_dataloader=val_dataloader,
+        test_dataloader=[],
+        logger=logger,
+        ckpt_callback=ckpt_callback,
+        run_timer=timer,
+        run_init_validation=True,
+    )
+
+    if custom_trainer_state_dict is not None:
+        draft_p_trainer.load_state_dict(custom_trainer_state_dict)
+
+    torch.cuda.empty_cache()
+
+    if cfg.get("prompt") is not None:
+        logging.info(f"Override val dataset with custom prompt: {cfg.prompt}")
+        val_dataloader = [[cfg.prompt]]
+
+    wt_types = cfg.get("weight_type", None)
+    if wt_types is None:
+        wt_types = ["base", "draft", "linear", "power_2", "power_4", "step_0.6"]
+    else:
+        wt_types = wt_types.split(",") if isinstance(wt_types, str) else wt_types
+    logging.info(f"Running on types: {wt_types}")
+
+    # run for all weight types
+    for wt_type in wt_types:
+        wt_draft = _make_weighing_fn(wt_type)
+        logging.info(f"using weighing type for annealed outputs: {wt_type}.")
+
+        # numbering restarts and the generator is re-seeded with the same per-rank seed for every
+        # weight type, so each schedule draws from the same random stream
+        global_idx = 0
+        annealed_out_dir = os.path.join(cfg.exp_manager.explicit_log_dir, f"annealed_outputs_sdxl_{wt_type}/")
+        gen = torch.Generator(device="cpu")
+        gen.manual_seed((1243 + 1247837 * local_rank) % (int(2 ** 32 - 1)))
+        os.makedirs(annealed_out_dir, exist_ok=True)
+
+        for batch in val_dataloader:
+            batch_size = len(batch)
+            # NOTE(review): nesting reconstructed from a whitespace-collapsed diff; only the latent
+            # draw is placed under the RNG-tracker fork - confirm against the original file.
+            with get_cuda_rng_tracker().fork(get_data_parallel_rng_tracker_name()):
+                latents = torch.randn(
+                    [
+                        batch_size,
+                        ptl_model.in_channels,
+                        ptl_model.height // ptl_model.downsampling_factor,
+                        ptl_model.width // ptl_model.downsampling_factor,
+                    ],
+                    generator=gen,
+                ).to(torch.cuda.current_device())
+            images = ptl_model.annealed_guidance(batch, latents, weighing_fn=wt_draft)
+            images = (
+                images.permute(0, 2, 3, 1).detach().cpu().numpy().astype(np.uint8)
+            )  # outputs are already scaled from [0, 255]
+            # save to pil
+            for local_idx in range(images.shape[0]):
+                # the file name uses the running index, but `images` and `batch` must be indexed
+                # with the position inside the current batch (the global index overruns them from
+                # the second batch on)
+                sample_idx = global_idx + local_idx
+                img_path = os.path.join(annealed_out_dir, f"img_{sample_idx:05d}_{local_rank:02d}.png")
+                prompt_path = os.path.join(annealed_out_dir, f"prompt_{sample_idx:05d}_{local_rank:02d}.txt")
+                Image.fromarray(images[local_idx]).save(img_path)
+                with open(prompt_path, "w", encoding="utf-8") as fi:
+                    fi.write(batch[local_idx])
+            # increment global index
+            global_idx += batch_size
+    logging.info("Saved all images.")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/examples/mm/stable_diffusion/train_sdxl_draftp.py b/examples/mm/stable_diffusion/train_sdxl_draftp.py
index cc56c2b59..8e7d36a1a 100644
--- a/examples/mm/stable_diffusion/train_sdxl_draftp.py
+++ b/examples/mm/stable_diffusion/train_sdxl_draftp.py
@@ -91,6 +91,8 @@
 mp.set_start_method("spawn", force=True)
 
 
+# TODO: this functionality should go into NeMo
+# Specifically, the NeMo
MegatronTrainerBuilder must also accept extra FSDP wrap modules so that it doesnt need to be subclassed
 class MegatronStableDiffusionTrainerBuilder(MegatronTrainerBuilder):
     """Builder for SD model Trainer with overrides."""
 
diff --git a/nemo_aligner/models/mm/stable_diffusion/megatron_sd_draftp_model.py b/nemo_aligner/models/mm/stable_diffusion/megatron_sd_draftp_model.py
index 746b95606..c18aba5fe 100644
--- a/nemo_aligner/models/mm/stable_diffusion/megatron_sd_draftp_model.py
+++ b/nemo_aligner/models/mm/stable_diffusion/megatron_sd_draftp_model.py
@@ -205,6 +205,109 @@ def log_visualization(self, prompts):
 
         return vae_decoder_output_draft_p, images, captions
 
+    @torch.no_grad()
+    def annealed_guidance(self, batch, x_T, weighing_fn=None):
+        """Sample with a per-step blend of the base and fine-tuned noise predictions.
+
+        At every DDIM step both samplers are run on the same latent and their noise predictions are
+        mixed as ``(1 - w) * eps_base + w * eps_finetuned``; the mixed noise drives the latent update.
+
+        Args:
+            batch: prompts to condition on; ``len(batch)`` is used as the batch size.
+            x_T: initial latent that the sampling loop starts from.
+            weighing_fn: callable ``(sigma, sigma_next, i, total) -> w`` giving the weight of the
+                fine-tuned model at step ``i`` of ``total``. Both sigma arguments are passed as ``None``.
+                Defaults to the linear ramp ``i / total``.
+
+        Returns:
+            The VAE-decoded result, clipped to ``[0, 1]`` and scaled by 255.
+        """
+        if weighing_fn is None:
+            weighing_fn = lambda sigma1, sigma2, i, total: i * 1.0 / total
+
+        # NOTE(review): block nesting was reconstructed from a whitespace-collapsed diff; confirm that
+        # the VAE decode below is meant to run inside the autocast context.
+        with torch.cuda.amp.autocast(
+            enabled=self.autocast_dtype in (torch.half, torch.bfloat16), dtype=self.autocast_dtype,
+        ):
+            batch_size = len(batch)
+            prev_img_draft_p = x_T
+
+            device_draft_p = self.model.betas.device
+
+            # init sampler and make schedule
+            sampler_draft_p = sampling_utils.initialize_sampler(self.model, self.sampler_type.upper())
+            sampler_init = sampling_utils.initialize_sampler(self.init_model, self.sampler_type.upper())
+            sampler_draft_p.make_schedule(ddim_num_steps=self.inference_steps, ddim_eta=self.eta, verbose=False)
+            sampler_init.make_schedule(ddim_num_steps=self.inference_steps, ddim_eta=self.eta, verbose=False)
+
+            cond, u_cond = sampling_utils.encode_prompt(
+                self.model.cond_stage_model, batch, self.unconditional_guidance_scale
+            )
+
+            timesteps = sampler_draft_p.ddim_timesteps
+            time_range = np.flip(timesteps)
+            total_steps = timesteps.shape[0]
+
+            iterator = tqdm(time_range, desc=f"{sampler_draft_p.sampler.name} Sampler", total=total_steps)
+
+            denoise_step_kwargs = {
+                "unconditional_guidance_scale": self.unconditional_guidance_scale,
+                "unconditional_conditioning": u_cond,
+            }
+            for i, step in enumerate(iterator):
+                denoise_step_args = [total_steps, i, batch_size, device_draft_p, step, cond]
+
+                # only the noise predictions are used; the per-model latents and x0 estimates are discarded
+                _, _, eps_t_draft_p = sampler_draft_p.single_ddim_denoise_step(
+                    prev_img_draft_p.clone(), *denoise_step_args, **denoise_step_kwargs
+                )
+                _, _, eps_t_init = sampler_init.single_ddim_denoise_step(
+                    prev_img_draft_p.clone(), *denoise_step_args, **denoise_step_kwargs
+                )
+                # get weighing scheme
+                w_draft = float(weighing_fn(None, None, i, total_steps))
+                w_base = 1 - w_draft
+                # combine weights
+                eps = w_base * eps_t_init + w_draft * eps_t_draft_p
+                # use this to get new image
+                index = total_steps - i - 1
+                ts = torch.full((batch_size,), step, device=device_draft_p, dtype=torch.long)
+                # get new image
+                img_new_p, pred_x0_new_p = sampler_draft_p._get_x_prev_and_pred_x0(
+                    False,
+                    batch_size,
+                    index,
+                    device_draft_p,
+                    prev_img_draft_p.clone(),
+                    ts,
+                    None,  # model output, we shouldnt need this
+                    eps,
+                    False,
+                    False,
+                    1.0,
+                    0.0,
+                )
+                prev_img_draft_p = img_new_p
+
+            # decode the x0 estimate computed from the blended noise; the fine-tuned-only estimate
+            # would ignore the weighting at the last step
+            last_states = [pred_x0_new_p]
+            # stack
+            trajectories_predx0 = (
+                torch.stack(last_states, dim=0).transpose(0, 1).contiguous().view(-1, *last_states[0].shape[1:])
+            )  # B1CHW -> BCHW
+
+            vae_decoder_output = []
+            for i in range(0, batch_size, self.vae_batch_size):
+                image = self.model.differentiable_decode_first_stage(trajectories_predx0[i : i + self.vae_batch_size])
+                vae_decoder_output.append(image)
+
+            vae_decoder_output = torch.cat(vae_decoder_output, dim=0)
+            vae_decoder_output = torch.clip((vae_decoder_output + 1) / 2, 0, 1) * 255.0
+
+            return vae_decoder_output
+
     def generate(
         self,
batch, x_T,
     ):
diff --git a/nemo_aligner/models/mm/stable_diffusion/megatron_sdxl_draftp_model.py b/nemo_aligner/models/mm/stable_diffusion/megatron_sdxl_draftp_model.py
index 9fc881a67..a7f57e334 100644
--- a/nemo_aligner/models/mm/stable_diffusion/megatron_sdxl_draftp_model.py
+++ b/nemo_aligner/models/mm/stable_diffusion/megatron_sdxl_draftp_model.py
@@ -42,6 +42,7 @@
     get_unique_embedder_keys_from_conditioner,
 )
 from nemo.collections.multimodal.parts.stable_diffusion.sdxl_pipeline import get_sampler_config
+from nemo.collections.multimodal.parts.stable_diffusion.utils import append_dims, default, instantiate_from_config
 from nemo.collections.nlp.modules.common.megatron.utils import average_losses_across_data_parallel_group
 from nemo.collections.nlp.parts.utils_funcs import get_last_rank
 from nemo.utils import logging
@@ -264,6 +265,96 @@ def generate_log_images(self, latents, batch, model):
         ]
         return log_img, log_reward, vae_decoder_output
 
+    @torch.no_grad()
+    def annealed_guidance(self, batch, x_T, weighing_fn=None):
+        """Sample with a per-step blend of the base and fine-tuned noise estimates.
+
+        At every sampler step both denoisers are run on the same latent and their noise estimates are
+        mixed as ``(1 - w) * eps_base + w * eps_finetuned``; an Euler step with the mix updates ``x``.
+
+        Args:
+            batch: prompts; passed through ``append_sdxl_size_keys`` before conditioning.
+            x_T: initial latent; it is cloned, so the caller's tensor is not modified in place.
+            weighing_fn: callable ``(sigma, sigma_next, i, total) -> w`` giving the weight of the
+                fine-tuned model at step ``i`` of ``total``. Defaults to the linear ramp
+                ``i / total``.
+
+        Returns:
+            The decoded first-stage output, clamped to ``[0, 1]`` and scaled by 255.
+        """
+        if weighing_fn is None:
+            weighing_fn = lambda sigma1, sigma2, i, total: i * 1.0 / total
+
+        # NOTE(review): block nesting was reconstructed from a whitespace-collapsed diff; confirm that
+        # the latent decode below is meant to run inside the autocast context.
+        with torch.cuda.amp.autocast(
+            enabled=self.autocast_dtype in (torch.half, torch.bfloat16), dtype=self.autocast_dtype,
+        ):
+            batch_c = self.append_sdxl_size_keys(batch)
+            force_uc_zero_embeddings = ["txt", "captions"]
+            sampler = self.sampler
+            # get conditional guidance keys
+            cond, uc = self.model.conditioner.get_unconditional_conditioning(
+                batch_c, batch_uc=None, force_uc_zero_embeddings=force_uc_zero_embeddings,
+            )
+            additional_model_inputs = {}
+            # get denoisers for base and trained model
+            denoiser_draft = lambda input, sigma, c: self.model.denoiser(
+                self.model.model, input, sigma, c, **additional_model_inputs
+            )
+            base_model = self.init_model or self.model
+            denoiser_base = lambda input, sigma, c: base_model.denoiser(
+                base_model.model, input, sigma, c, **additional_model_inputs
+            )
+            # prep initial sampler config
+            x = x_T.clone()
+            num_steps = sampler.num_steps
+            x, s_in, sigmas, num_sigmas, cond, uc = sampler.prepare_sampling_loop(x, cond, uc, num_steps)
+            # last step doesnt count since there is no additional sigma
+            total_steps = num_sigmas - 1
+            iterator = tqdm(range(total_steps), desc=f"{sampler.__class__.__name__} Sampler", total=total_steps)
+            for i in iterator:
+                gamma = sampler.get_gamma(sigmas, num_sigmas, i)
+                # only the noise estimate of each model is used; the per-model next latent is discarded
+                _, eps_draft = sampler.sampler_step(
+                    s_in * sigmas[i],
+                    s_in * sigmas[i + 1],
+                    denoiser_draft,
+                    x.clone(),
+                    cond,
+                    uc,
+                    gamma,
+                    return_noise=True,
+                )
+                # base model pass; presumably adapter_control switches the adapters off - verify
+                with adapter_control(base_model):
+                    _, eps_init = sampler.sampler_step(
+                        s_in * sigmas[i],
+                        s_in * sigmas[i + 1],
+                        denoiser_base,
+                        x.clone(),
+                        cond,
+                        uc,
+                        gamma,
+                        return_noise=True,
+                    )
+                # get weighing scheme
+                w_draft = float(weighing_fn(sigmas[i], sigmas[i + 1], i, total_steps))
+                w_base = 1 - w_draft
+                # combine weights
+                eps = w_base * eps_init + w_draft * eps_draft
+                dt = append_dims(s_in * sigmas[i + 1] - s_in * sigmas[i], x.ndim)
+                euler_step = sampler.euler_step(x, eps, dt)
+                # get next x
+                x = sampler.possible_correction_step(
+                    euler_step, x.clone(), eps, dt, s_in * sigmas[i + 1], denoiser_draft, cond, uc
+                )
+                iterator.set_description(f"iteration: {i}/{total_steps}, w_base={w_base:06f}")
+            # decode the latent
+            image = self.model.differentiable_decode_first_stage(x)
+            image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0) * 255.0
+        return image
+
     @torch.no_grad()
     def log_visualization(self, prompts):
         batch_size = len(prompts)