diff --git a/docs/source/en/api/loaders/lora.md b/docs/source/en/api/loaders/lora.md index 3866342d9be6..936b117323a3 100644 --- a/docs/source/en/api/loaders/lora.md +++ b/docs/source/en/api/loaders/lora.md @@ -148,6 +148,10 @@ LoRA is a fast and lightweight training method that inserts and trains a signifi [[autodoc]] loaders.lora_pipeline.Ideogram4LoraLoaderMixin +## Krea2LoraLoaderMixin + +[[autodoc]] loaders.lora_pipeline.Krea2LoraLoaderMixin + ## LoraBaseMixin [[autodoc]] loaders.lora_base.LoraBaseMixin diff --git a/examples/dreambooth/README_krea2.md b/examples/dreambooth/README_krea2.md new file mode 100644 index 000000000000..4904fe6c3b46 --- /dev/null +++ b/examples/dreambooth/README_krea2.md @@ -0,0 +1,211 @@ +# DreamBooth training example for Krea 2 + +[DreamBooth](https://huggingface.co/papers/2208.12242) is a method to personalize image generation models given just a few (3~5) images of a subject/concept. +[LoRA](https://huggingface.co/docs/peft/conceptual_guides/adapter#low-rank-adaptation-lora) is a popular parameter-efficient fine-tuning technique that allows you to achieve full-finetuning like performance but with a fraction of learnable parameters. + +The `train_dreambooth_lora_krea2.py` script shows how to implement the training procedure for [LoRAs](https://huggingface.co/blog/lora) and adapt it for [Krea 2](https://www.krea.ai/). + +> [!NOTE] +> **About Krea 2: RAW vs Turbo** +> +> Krea 2 ships as two checkpoints that are designed to work together: +> - **Krea 2 RAW** is the base model β€” a pre-trained checkpoint with **no distillation**. It is diverse and highly malleable, and it is the checkpoint you should use for **fine-tuning, post-training, and LoRA training**. It is *not* meant to be used for inference directly (do not expect high-quality outputs from it). +> - **Krea 2 Turbo** is an **8-step distilled** checkpoint built for fast, high-quality text-to-image **inference**. +> +> The recommended workflow is to **train your LoRA on RAW and run inference (and validation) on Turbo** β€” LoRAs trained on RAW express strongly on Turbo, so you get the best of both worlds: a malleable base to fine-tune and a fast, high-quality model to generate with. +> +> Architecturally, Krea 2 uses the Qwen-Image VAE, a 12B DiT (dense), and a Qwen3-VL text encoder with multi-layer feature aggregation. +> +> πŸ“– Read more here: Krea 2 release blog . + +## Running locally with PyTorch + +### Installing the dependencies + +Before running the scripts, make sure to install the library's training dependencies: + +**Important** + +To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment: + +```bash +git clone https://github.com/huggingface/diffusers +cd diffusers +pip install -e . +``` + +Then cd in the `examples/dreambooth` folder and run + +```bash +pip install -r requirements_krea2.txt +``` + +And initialize an [πŸ€—Accelerate](https://github.com/huggingface/accelerate/) environment with: + +```bash +accelerate config +``` + +Or for a default accelerate configuration without answering questions about your environment + +```bash +accelerate config default +``` + +Note that we use the PEFT library as backend for LoRA training, so make sure to have `peft>=0.11.1` installed in your environment. + +### Dog toy example + +Now let's get our dataset. For this example we will use some dog images: https://huggingface.co/datasets/diffusers/dog-example. + +Let's first download it locally: + +```python +from huggingface_hub import snapshot_download + +local_dir = "./dog" +snapshot_download( + "diffusers/dog-example", + local_dir=local_dir, repo_type="dataset", + ignore_patterns=".gitattributes", +) +``` + +## Training + +We train the LoRA on the **RAW** checkpoint. Because RAW is not meant for inference, validation and final inference are run on the **Turbo** checkpoint via `--validation_model_path` (see [Validation on Turbo](#validation-on-turbo)). + +```bash +export MODEL_NAME="krea/Krea-2-Raw" +export TURBO_NAME="krea/Krea-2-Turbo" +export INSTANCE_DIR="dog" +export OUTPUT_DIR="trained-krea2-lora" + +accelerate launch train_dreambooth_lora_krea2.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --instance_data_dir=$INSTANCE_DIR \ + --output_dir=$OUTPUT_DIR \ + --mixed_precision="bf16" \ + --instance_prompt="a photo of sks dog" \ + --resolution=1024 \ + --train_batch_size=1 \ + --gradient_checkpointing \ + --cache_latents \ + --rank=32 \ + --lora_alpha=32 \ + --optimizer="adamW" \ + --use_8bit_adam \ + --learning_rate=3e-4 \ + --lr_scheduler="constant" \ + --lr_warmup_steps=0 \ + --max_train_steps=1000 \ + --validation_model_path=$TURBO_NAME \ + --validation_prompt="a photo of sks dog" \ + --validation_epochs=25 \ + --seed="0" \ + --push_to_hub +``` + +To better track our training experiments, we're using the following flags in the command above: + +* `report_to="wandb"` will ensure the training runs are tracked on [Weights and Biases](https://wandb.ai/site). To use it, be sure to install `wandb` with `pip install wandb`. +* `validation_prompt`, `validation_epochs` and `validation_model_path` allow the script to run validation inference on Turbo during training (see below). + +> [!NOTE] +> If you want to train using long prompts, you can use `--max_sequence_length` to set the token limit (default 512). Note that this uses more resources and may slow down training. + +## Validation on Turbo + +Since RAW is a non-distilled base that is **not meant for inference**, validating on RAW is misleading. Instead, pass `--validation_model_path` pointing at the **Turbo** checkpoint: at every validation step the script transplants the adapter currently being trained on RAW onto the Turbo pipeline and generates with it, so your validation images reflect what the final result will actually look like. + +The Turbo inference recipe is the default for validation: + +* `--validation_num_inference_steps` (default `8`) β€” Turbo is an 8-step distilled model. +* `--validation_guidance_scale` (default `0.0`) β€” Turbo runs without classifier-free guidance. +* `--validation_mu` (default `1.15`) β€” Turbo uses a fixed `mu` for the timestep shift instead of computing it from the resolution. + +If `--validation_model_path` is omitted, validation and final inference fall back to the training checkpoint (using the pipeline defaults). + +## Memory Optimizations + +> [!NOTE] +> Many of these techniques complement each other and can be combined to further reduce memory consumption. Some are mutually exclusive, so check before launching. + +### CPU Offloading +Pass `--offload` to offload the VAE and text encoder to CPU memory and only move them to GPU when needed. + +### Latent Caching +Pre-encode the training images with the VAE and then free it. Enable with `--cache_latents`. + +### Low-precision training with quantization +- **NF4 / 4-bit (QLoRA)** with `bitsandbytes`: pass `--bnb_quantization_config_path` pointing at a JSON of `BitsAndBytesConfig` kwargs (e.g. `{"load_in_4bit": true, "bnb_4bit_quant_type": "nf4", "bnb_4bit_compute_dtype": "bfloat16"}`). This is the biggest single VRAM saver and lets a full Krea 2 LoRA run fit on a single mid-range GPU. +- **FP8 training** with `torchao`: pass `--do_fp8_training`. This uses FP8 scaled-matmul on a bf16-loaded transformer β€” it speeds up compute on supported hardware but, because the weights stay in bf16, it does not by itself reduce memory. Requires a GPU with compute capability β‰₯ 8.9. (`--do_fp8_training` and `--bnb_quantization_config_path` are mutually exclusive.) + +### Gradient Checkpointing and Accumulation +* `--gradient_accumulation_steps` accumulates gradients over several steps before an update, reducing the number of backward/update passes. +* `--gradient_checkpointing` saves memory by recomputing intermediate activations during the backward pass instead of storing them (at the cost of a slower backward pass). + +### 8-bit Adam Optimizer +When training with `AdamW` (not `prodigy`) pass `--use_8bit_adam` to reduce optimizer memory. Make sure `bitsandbytes` is installed. + +### Image Resolution +`--resolution` sets the resolution all train/validation images are resized to (default 1024). Lowering it reduces memory. + +### Precision of saved LoRA layers +By default trained layers are saved in the training precision (e.g. `bf16` under `--mixed_precision="bf16"`). Pass `--upcast_before_saving` to save them in `float32` instead (more memory). + +## LoRA Rank, Alpha and Target Modules + +Two key LoRA hyperparameters are rank and alpha: + +- `--rank`: dimension of the trainable LoRA matrices. Higher rank = more capacity (and more parameters). +- `--lora_alpha`: scaling factor; the LoRA update is scaled by `lora_alpha / rank`. With `lora_alpha == rank` the scale is 1.0. + +`--lora_layers` lets you choose exactly which modules to adapt (comma-separated). By default the script adapts the recommended layer set at rank/alpha 32: + +``` +img_in, final_layer.linear, to_q, to_k, to_v, to_out.0, to_gate, +ff.up, ff.down, text_fusion.projector, txt_in.linear_1, txt_in.linear_2, +time_embed.linear_1, time_embed.linear_2, time_mod_proj +``` + +> [!TIP] +> **Capacity: rank vs. target modules.** The default (rank/alpha **32** on the full layer set above) fits most styles, including ones with heavy high-frequency detail. For **long training runs**, it's recommended to add capacity by **increasing the rank and narrowing the target modules to the attention layers** β€” `--lora_layers="to_q,to_k,to_v,to_out.0,to_gate"` β€” rather than keeping the full layer set, so that prompt adherence doesn't degrade. In general, flat illustrative styles prefer **low-capacity** LoRAs (lower rank, fewer layers) and converge faster, while high-frequency styles (ink-brush paintings, etc.) benefit from more capacity. + +> [!TIP] +> Standard learning rates of `3e-4 ~ 7e-4` with a `constant` schedule work well, and you can go a bit higher with a `cosine` schedule. + +## Captioning for style LoRAs + +For training a style, it's recommended to use captions that **describe the parts of the image you do *not* want baked into the LoRA, while omitting the stylistic parts you *do* want it to learn**, and add a descriptive **trigger phrase** as a style anchor. For example, for a hand-drawn-illustration style: + +> "An astronaut standing beside a space rover on a flat landscape with cacti in the background while a large planet and stars are visible in the background. hand-drawn children's book illustration" + +Here the phrase *"hand-drawn children's book illustration"* anchors the style and is preferred over a random rare token (e.g. `Ill3$tr@te`). For object/character training a trigger word is fine, as long as the captions broadly get the class of the subject right. + +## Inference + +Train on RAW, then load your LoRA into **Turbo** for fast, high-quality generation: + +```python +import torch +from diffusers import Krea2Pipeline + +pipe = Krea2Pipeline.from_pretrained("krea/Krea-2-Turbo", torch_dtype=torch.bfloat16) +pipe.to("cuda") + +# Load your trained LoRA (trained on Krea 2 RAW) +pipe.load_lora_weights("path/to/your/trained-krea2-lora") + +image = pipe( + prompt="a photo of sks dog", + height=1024, + width=1024, + num_inference_steps=8, + guidance_scale=0.0, + mu=1.15, + generator=torch.Generator("cuda").manual_seed(0), +).images[0] + +image.save("output.png") +``` diff --git a/examples/dreambooth/requirements_krea2.txt b/examples/dreambooth/requirements_krea2.txt new file mode 100644 index 000000000000..85a505f450eb --- /dev/null +++ b/examples/dreambooth/requirements_krea2.txt @@ -0,0 +1,11 @@ +accelerate>=0.31.0 +torchvision +transformers>=4.41.2 +ftfy +tensorboard +Jinja2 +peft>=0.11.1 +sentencepiece +bitsandbytes +prodigyopt +datasets diff --git a/examples/dreambooth/train_dreambooth_lora_krea2.py b/examples/dreambooth/train_dreambooth_lora_krea2.py new file mode 100644 index 000000000000..01c303dd0f25 --- /dev/null +++ b/examples/dreambooth/train_dreambooth_lora_krea2.py @@ -0,0 +1,1883 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and + +# /// script +# dependencies = [ +# "diffusers @ git+https://github.com/huggingface/diffusers.git", +# "torch>=2.0.0", +# "accelerate>=0.31.0", +# "transformers>=4.41.2", +# "ftfy", +# "tensorboard", +# "Jinja2", +# "peft>=0.11.1", +# "sentencepiece", +# "torchvision", +# "datasets", +# "bitsandbytes", +# "prodigyopt", +# ] +# /// + +import argparse +import copy +import itertools +import json +import logging +import math +import os +import random +import shutil +import warnings +from contextlib import nullcontext +from pathlib import Path + +import numpy as np +import torch +import transformers +from accelerate import Accelerator, DistributedType +from accelerate.logging import get_logger +from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed +from huggingface_hub import create_repo, upload_folder +from huggingface_hub.utils import insecure_hashlib +from peft import LoraConfig, prepare_model_for_kbit_training, set_peft_model_state_dict +from peft.utils import get_peft_model_state_dict +from PIL import Image +from PIL.ImageOps import exif_transpose +from torch.utils.data import Dataset +from torchvision import transforms +from torchvision.transforms.functional import crop +from tqdm.auto import tqdm +from transformers import AutoTokenizer, Qwen3VLModel + +import diffusers +from diffusers import ( + AutoencoderKLQwenImage, + BitsAndBytesConfig, + FlowMatchEulerDiscreteScheduler, + Krea2Pipeline, + Krea2Transformer2DModel, +) +from diffusers.optimization import get_scheduler +from diffusers.training_utils import ( + _collate_lora_metadata, + cast_training_params, + compute_density_for_timestep_sampling, + compute_loss_weighting_for_sd3, + free_memory, + offload_models, +) +from diffusers.utils import ( + check_min_version, + convert_unet_state_dict_to_peft, + is_wandb_available, +) +from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card +from diffusers.utils.import_utils import is_torch_npu_available +from diffusers.utils.torch_utils import is_compiled_module + + +if is_wandb_available(): + import wandb + +# Will error if the minimal version of diffusers is not installed. Remove at your own risks. +check_min_version("0.39.0.dev0") + +logger = get_logger(__name__) + +if is_torch_npu_available(): + torch.npu.config.allow_internal_format = False + + +def save_model_card( + repo_id: str, + images=None, + base_model: str = None, + instance_prompt=None, + validation_prompt=None, + repo_folder=None, + inference_model: str = "krea/Krea-2-Turbo", +): + widget_dict = [] + if images is not None: + for i, image in enumerate(images): + image.save(os.path.join(repo_folder, f"image_{i}.png")) + widget_dict.append( + {"text": validation_prompt if validation_prompt else " ", "output": {"url": f"image_{i}.png"}} + ) + + # Only put `base_model` in the card metadata when it's a Hub id β€” a local training path is not a + # valid model id and the Hub rejects it. RAW is the (non-distilled) training base. + def _is_hub_id(s): + return bool(s) and "/" in s and not os.path.exists(s) + + # A local training path is not a valid Hub model id (the Hub rejects it in card metadata). Krea 2 + # LoRAs are trained on RAW, so fall back to the canonical RAW id when given a local path. + card_base_model = base_model if _is_hub_id(base_model) else "krea/Krea-2-Raw" + base_display = card_base_model + # The inference snippet always targets the distilled Turbo model; fall back to the canonical id + # if a local path (or nothing) was passed. + if not _is_hub_id(inference_model): + inference_model = "krea/Krea-2-Turbo" + + model_description = f""" +# Krea 2 DreamBooth LoRA - {repo_id} + + + +## Model description + +These are {repo_id} DreamBooth LoRA weights, trained on {base_display}. + +The weights were trained using [DreamBooth](https://dreambooth.github.io/) with the [Krea 2 diffusers trainer](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/README_krea2.md). + +Krea 2 ships as two checkpoints: **RAW** (the non-distilled base you fine-tune on) and **Turbo** (an 8-step distilled checkpoint for fast, high-quality inference). Train your LoRA on RAW and run it on Turbo β€” LoRAs trained on RAW express strongly on Turbo. + +## Trigger words + +You should use `{instance_prompt}` to trigger the image generation. + +## Download model + +[Download the *.safetensors LoRA]({repo_id}/tree/main) in the Files & versions tab. + +## Use it with the [🧨 diffusers library](https://github.com/huggingface/diffusers) + +```py +>>> import torch +>>> from diffusers import Krea2Pipeline + +>>> # Load the LoRA onto Krea 2 Turbo (the distilled inference model) +>>> pipe = Krea2Pipeline.from_pretrained("{inference_model}", torch_dtype=torch.bfloat16).to("cuda") +>>> pipe.load_lora_weights("{repo_id}") + +>>> # Turbo recipe: 8 steps, no classifier-free guidance +>>> image = pipe("{instance_prompt}", num_inference_steps=8, guidance_scale=0.0).images[0] +>>> image.save("output.png") +``` + +For more details, including weighting, merging and fusing LoRAs, check the [documentation on loading LoRAs in diffusers](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters) +""" + model_card = load_or_create_model_card( + repo_id_or_path=repo_id, + from_training=True, + license="apache-2.0", + base_model=card_base_model, + prompt=instance_prompt, + model_description=model_description, + widget=widget_dict, + ) + tags = [ + "text-to-image", + "diffusers-training", + "diffusers", + "lora", + "krea2", + "krea2-diffusers", + "template:sd-lora", + ] + + model_card = populate_model_card(model_card, tags=tags) + model_card.save(os.path.join(repo_folder, "README.md")) + + +def log_validation( + pipeline, + args, + accelerator, + pipeline_args, + epoch, + torch_dtype, + is_final_validation=False, + pipeline_call_kwargs=None, +): + args.num_validation_images = args.num_validation_images if args.num_validation_images else 1 + logger.info( + f"Running validation... \n Generating {args.num_validation_images} images with prompt:" + f" {args.validation_prompt}." + ) + pipeline = pipeline.to(accelerator.device, dtype=torch_dtype) + pipeline.set_progress_bar_config(disable=True) + + # run inference + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None + autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext() + + images = [] + for _ in range(args.num_validation_images): + with autocast_ctx: + image = pipeline( + prompt_embeds=pipeline_args["prompt_embeds"], + prompt_embeds_mask=pipeline_args["prompt_embeds_mask"], + negative_prompt_embeds=pipeline_args["negative_prompt_embeds"], + negative_prompt_embeds_mask=pipeline_args["negative_prompt_embeds_mask"], + generator=generator, + **(pipeline_call_kwargs or {}), + ).images[0] + images.append(image) + + for tracker in accelerator.trackers: + phase_name = "test" if is_final_validation else "validation" + if tracker.name == "tensorboard": + np_images = np.stack([np.asarray(img) for img in images]) + tracker.writer.add_images(phase_name, np_images, epoch, dataformats="NHWC") + if tracker.name == "wandb": + tracker.log( + { + phase_name: [ + wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images) + ] + } + ) + + del pipeline + free_memory() + + return images + + +def _validation_call_kwargs(args): + # When validating on a dedicated inference checkpoint (e.g. Krea 2 Turbo), use its recipe + # (few-step, no CFG). When validating on the training checkpoint, use pipeline defaults. + if args.validation_model_path is None: + return {} + return { + "num_inference_steps": args.validation_num_inference_steps, + "guidance_scale": args.validation_guidance_scale, + } + + +def build_validation_pipeline(args, accelerator, transformer, weight_dtype): + # Krea 2 RAW is a non-distilled base not meant for inference. If --validation_model_path is set + # (e.g. Krea 2 Turbo), build the pipeline from THAT checkpoint and transplant the adapter trained + # on RAW onto it (LoRAs trained on RAW express strongly on Turbo). Otherwise reuse the in-training + # transformer. Either way the text encoder is skipped β€” validation reuses precomputed embeddings. + if args.validation_model_path is not None: + tmp_lora = os.path.join(args.output_dir, "_val_lora") + Krea2Pipeline.save_lora_weights( + tmp_lora, + transformer_lora_layers=get_peft_model_state_dict(accelerator.unwrap_model(transformer)), + ) + pipeline = Krea2Pipeline.from_pretrained( + args.validation_model_path, + tokenizer=None, + text_encoder=None, + revision=args.revision, + variant=args.variant, + torch_dtype=weight_dtype, + ) + pipeline.load_lora_weights(tmp_lora) + return pipeline + return Krea2Pipeline.from_pretrained( + args.pretrained_model_name_or_path, + tokenizer=None, + text_encoder=None, + transformer=accelerator.unwrap_model(transformer), + revision=args.revision, + variant=args.variant, + torch_dtype=weight_dtype, + ) + + +def module_filter_fn(mod: torch.nn.Module, fqn: str): + # Keep precision-sensitive modules in higher precision: the final output projection and the + # patterns Krea2Transformer2DModel flags in `_skip_layerwise_casting_patterns` (time embedding, + # norms), plus the timestep modulation projection. + skip_patterns = ("final_layer.linear", "time_embed", "time_mod_proj", "norm") + if any(pattern in fqn for pattern in skip_patterns): + return False + # don't convert linear modules with weight dimensions not divisible by 16 + if isinstance(mod, torch.nn.Linear): + if mod.in_features % 16 != 0 or mod.out_features % 16 != 0: + return False + return True + + +def parse_args(input_args=None): + parser = argparse.ArgumentParser(description="Simple example of a training script.") + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + required=True, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--bnb_quantization_config_path", + type=str, + default=None, + help="Quantization config in a JSON file that will be used to define the bitsandbytes quant config of the DiT.", + ) + parser.add_argument( + "--do_fp8_training", + action="store_true", + help="if we are doing FP8 training (torchao float8 scaled-mm on a bf16-loaded transformer).", + ) + parser.add_argument( + "--validation_model_path", + type=str, + default=None, + help=( + "Path to the checkpoint validation and final inference run on. Krea 2 RAW is a non-distilled" + " base not meant for inference, so validation should run on the distilled Krea 2 Turbo" + " checkpoint: pass its path here and the adapter trained on RAW is transplanted onto Turbo for" + " every validation. If unset, validation falls back to the (RAW) training checkpoint." + ), + ) + parser.add_argument( + "--validation_num_inference_steps", + type=int, + default=8, + help="num_inference_steps for validation on --validation_model_path (Krea 2 Turbo is an 8-step model).", + ) + parser.add_argument( + "--validation_guidance_scale", + type=float, + default=0.0, + help="guidance_scale for validation on --validation_model_path (Krea 2 Turbo runs without CFG).", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + required=False, + help="Revision of pretrained model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--variant", + type=str, + default=None, + help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16", + ) + parser.add_argument( + "--dataset_name", + type=str, + default=None, + help=( + "The name of the Dataset (from the HuggingFace hub) containing the training data of instance images (could be your own, possibly private," + " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem," + " or to a folder containing files that πŸ€— Datasets can understand." + ), + ) + parser.add_argument( + "--dataset_config_name", + type=str, + default=None, + help="The config of the Dataset, leave as None if there's only one config.", + ) + parser.add_argument( + "--instance_data_dir", + type=str, + default=None, + help=("A folder containing the training data. "), + ) + + parser.add_argument( + "--cache_dir", + type=str, + default=None, + help="The directory where the downloaded models and datasets will be stored.", + ) + + parser.add_argument( + "--image_column", + type=str, + default="image", + help="The column of the dataset containing the target image. By " + "default, the standard Image Dataset maps out 'file_name' " + "to 'image'.", + ) + parser.add_argument( + "--caption_column", + type=str, + default=None, + help="The column of the dataset containing the instance prompt for each image", + ) + + parser.add_argument("--repeats", type=int, default=1, help="How many times to repeat the training data.") + + parser.add_argument( + "--class_data_dir", + type=str, + default=None, + required=False, + help="A folder containing the training data of class images.", + ) + parser.add_argument( + "--instance_prompt", + type=str, + default=None, + required=True, + help="The prompt with identifier specifying the instance, e.g. 'photo of a TOK dog', 'in the style of TOK'", + ) + parser.add_argument( + "--class_prompt", + type=str, + default=None, + help="The prompt to specify images in the same class as provided instance images.", + ) + parser.add_argument( + "--max_sequence_length", + type=int, + default=512, + help="Maximum sequence length to use with the Qwen3-VL text encoder.", + ) + + parser.add_argument( + "--validation_prompt", + type=str, + default=None, + help="A prompt that is used during validation to verify that the model is learning.", + ) + + parser.add_argument( + "--skip_final_inference", + default=False, + action="store_true", + help="Whether to skip the final inference step with loaded lora weights upon training completion. This will run intermediate validation inference if `validation_prompt` is provided. Specify to reduce memory.", + ) + + parser.add_argument( + "--final_validation_prompt", + type=str, + default=None, + help="A prompt that is used during a final validation to verify that the model is learning. Ignored if `--validation_prompt` is provided.", + ) + parser.add_argument( + "--num_validation_images", + type=int, + default=4, + help="Number of images that should be generated during validation with `validation_prompt`.", + ) + parser.add_argument( + "--validation_epochs", + type=int, + default=50, + help=( + "Run dreambooth validation every X epochs. Dreambooth validation consists of running the prompt" + " `args.validation_prompt` multiple times: `args.num_validation_images`." + ), + ) + parser.add_argument( + "--rank", + type=int, + default=32, + help=( + "The dimension of the LoRA update matrices. The Krea 2 authors recommend rank 32 for most styles; " + "increase it (and focus on the attention layers) for long runs or high-frequency styles." + ), + ) + parser.add_argument( + "--lora_alpha", + type=int, + default=32, + help="LoRA alpha to be used for additional scaling. The Krea 2 authors recommend alpha == rank (scale 1.0).", + ) + parser.add_argument("--lora_dropout", type=float, default=0.0, help="Dropout probability for LoRA layers") + + parser.add_argument( + "--with_prior_preservation", + default=False, + action="store_true", + help="Flag to add prior preservation loss.", + ) + parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.") + parser.add_argument( + "--num_class_images", + type=int, + default=100, + help=( + "Minimal class images for prior preservation loss. If there are not enough images already present in" + " class_data_dir, additional images will be sampled with class_prompt." + ), + ) + parser.add_argument( + "--output_dir", + type=str, + default="krea2-dreambooth-lora", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument( + "--resolution", + type=int, + default=512, + help=( + "The resolution for input images, all the images in the train/validation dataset will be resized to this" + " resolution" + ), + ) + parser.add_argument( + "--center_crop", + default=False, + action="store_true", + help=( + "Whether to center crop the input images to the resolution. If not set, the images will be randomly" + " cropped. The images will be resized to the resolution first before cropping." + ), + ) + parser.add_argument( + "--random_flip", + action="store_true", + help="whether to randomly flip images horizontally", + ) + parser.add_argument( + "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader." + ) + parser.add_argument( + "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images." + ) + parser.add_argument("--num_train_epochs", type=int, default=1) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--checkpointing_steps", + type=int, + default=500, + help=( + "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final" + " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming" + " training using `--resume_from_checkpoint`." + ), + ) + parser.add_argument( + "--checkpoints_total_limit", + type=int, + default=None, + help=("Max number of checkpoints to store."), + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help=( + "Whether training should be resumed from a previous checkpoint. Use a path saved by" + ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' + ), + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=3e-4, + help=( + "Initial learning rate (after the potential warmup period) to use. The Krea 2 authors recommend " + "3e-4 - 7e-4 with a constant schedule (lower end for a constant schedule; higher is fine with cosine)." + ), + ) + parser.add_argument( + "--scale_lr", + action="store_true", + default=False, + help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument( + "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." + ) + parser.add_argument( + "--lr_num_cycles", + type=int, + default=1, + help="Number of hard resets of the lr in cosine_with_restarts scheduler.", + ) + parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.") + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=0, + help=( + "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." + ), + ) + parser.add_argument( + "--weighting_scheme", + type=str, + default="none", + choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "none"], + help=('We default to the "none" weighting scheme for uniform sampling and uniform loss'), + ) + parser.add_argument( + "--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme." + ) + parser.add_argument( + "--logit_std", type=float, default=1.0, help="std to use when using the `'logit_normal'` weighting scheme." + ) + parser.add_argument( + "--mode_scale", + type=float, + default=1.29, + help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.", + ) + parser.add_argument( + "--optimizer", + type=str, + default="AdamW", + help=('The optimizer type to use. Choose between ["AdamW", "prodigy"]'), + ) + + parser.add_argument( + "--use_8bit_adam", + action="store_true", + help="Whether or not to use 8-bit Adam from bitsandbytes. Ignored if optimizer is not set to AdamW", + ) + + parser.add_argument( + "--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam and Prodigy optimizers." + ) + parser.add_argument( + "--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam and Prodigy optimizers." + ) + parser.add_argument( + "--prodigy_beta3", + type=float, + default=None, + help="coefficients for computing the Prodigy stepsize using running averages. If set to None, " + "uses the value of square root of beta2. Ignored if optimizer is adamW", + ) + parser.add_argument("--prodigy_decouple", type=bool, default=True, help="Use AdamW style decoupled weight decay") + parser.add_argument("--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for unet params") + parser.add_argument( + "--lora_layers", + type=str, + default=None, + help=( + "The transformer modules to apply LoRA training on, comma separated (matched as module-name suffixes). " + 'E.g. "to_q,to_k,to_v,to_out.0,to_gate" trains the attention layers only (the authors\' suggestion for ' + "long runs). If omitted, the Krea 2 authors' recommended default layer set is used." + ), + ) + + parser.add_argument( + "--adam_epsilon", + type=float, + default=1e-08, + help="Epsilon value for the Adam optimizer and Prodigy optimizers.", + ) + + parser.add_argument( + "--prodigy_use_bias_correction", + type=bool, + default=True, + help="Turn on Adam's bias correction. True by default. Ignored if optimizer is adamW", + ) + parser.add_argument( + "--prodigy_safeguard_warmup", + type=bool, + default=True, + help="Remove lr from the denominator of D estimate to avoid issues during warm-up stage. True by default. " + "Ignored if optimizer is adamW", + ) + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), + ) + parser.add_argument( + "--allow_tf32", + action="store_true", + help=( + "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" + " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" + ), + ) + parser.add_argument( + "--cache_latents", + action="store_true", + default=False, + help="Cache the VAE latents", + ) + parser.add_argument( + "--report_to", + type=str, + default="tensorboard", + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' + ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' + ), + ) + parser.add_argument( + "--mixed_precision", + type=str, + default=None, + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." + ), + ) + parser.add_argument( + "--upcast_before_saving", + action="store_true", + default=False, + help=( + "Whether to upcast the trained transformer layers to float32 before saving (at the end of training). " + "Defaults to precision dtype used for training to save memory" + ), + ) + parser.add_argument( + "--offload", + action="store_true", + help="Whether to offload the VAE and the text encoder to CPU when they are not used.", + ) + parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") + + if input_args is not None: + args = parser.parse_args(input_args) + else: + args = parser.parse_args() + + if args.dataset_name is None and args.instance_data_dir is None: + raise ValueError("Specify either `--dataset_name` or `--instance_data_dir`") + + if args.dataset_name is not None and args.instance_data_dir is not None: + raise ValueError("Specify only one of `--dataset_name` or `--instance_data_dir`") + + if args.do_fp8_training and args.bnb_quantization_config_path: + raise ValueError("Both `do_fp8_training` and `bnb_quantization_config_path` cannot be passed.") + + env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) + if env_local_rank != -1 and env_local_rank != args.local_rank: + args.local_rank = env_local_rank + + if args.with_prior_preservation: + if args.class_data_dir is None: + raise ValueError("You must specify a data directory for class images.") + if args.class_prompt is None: + raise ValueError("You must specify prompt for class images.") + else: + # logger is not available yet + if args.class_data_dir is not None: + warnings.warn("You need not use --class_data_dir without --with_prior_preservation.") + if args.class_prompt is not None: + warnings.warn("You need not use --class_prompt without --with_prior_preservation.") + + return args + + +class DreamBoothDataset(Dataset): + """ + A dataset to prepare the instance and class images with the prompts for fine-tuning the model. + It pre-processes the images. + """ + + def __init__( + self, + instance_data_root, + instance_prompt, + class_prompt, + class_data_root=None, + class_num=None, + size=1024, + repeats=1, + center_crop=False, + ): + self.size = size + self.center_crop = center_crop + + self.instance_prompt = instance_prompt + self.custom_instance_prompts = None + self.class_prompt = class_prompt + + # if --dataset_name is provided or a metadata jsonl file is provided in the local --instance_data directory, + # we load the training data using load_dataset + if args.dataset_name is not None: + try: + from datasets import load_dataset + except ImportError: + raise ImportError( + "You are trying to load your data using the datasets library. If you wish to train using custom " + "captions please install the datasets library: `pip install datasets`. If you wish to load a " + "local folder containing images only, specify --instance_data_dir instead." + ) + # Downloading and loading a dataset from the hub. + # See more about loading custom images at + # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script + dataset = load_dataset( + args.dataset_name, + args.dataset_config_name, + cache_dir=args.cache_dir, + ) + # Preprocessing the datasets. + column_names = dataset["train"].column_names + + # 6. Get the column names for input/target. + if args.image_column is None: + image_column = column_names[0] + logger.info(f"image column defaulting to {image_column}") + else: + image_column = args.image_column + if image_column not in column_names: + raise ValueError( + f"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}" + ) + instance_images = dataset["train"][image_column] + + if args.caption_column is None: + logger.info( + "No caption column provided, defaulting to instance_prompt for all images. If your dataset " + "contains captions/prompts for the images, make sure to specify the " + "column as --caption_column" + ) + self.custom_instance_prompts = None + else: + if args.caption_column not in column_names: + raise ValueError( + f"`--caption_column` value '{args.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}" + ) + custom_instance_prompts = dataset["train"][args.caption_column] + # create final list of captions according to --repeats + self.custom_instance_prompts = [] + for caption in custom_instance_prompts: + self.custom_instance_prompts.extend(itertools.repeat(caption, repeats)) + else: + self.instance_data_root = Path(instance_data_root) + if not self.instance_data_root.exists(): + raise ValueError("Instance images root doesn't exists.") + + instance_images = [Image.open(path) for path in list(Path(instance_data_root).iterdir())] + self.custom_instance_prompts = None + + self.instance_images = [] + for img in instance_images: + self.instance_images.extend(itertools.repeat(img, repeats)) + + self.pixel_values = [] + train_resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR) + train_crop = transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size) + train_flip = transforms.RandomHorizontalFlip(p=1.0) + train_transforms = transforms.Compose( + [ + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] + ) + for image in self.instance_images: + image = exif_transpose(image) + if not image.mode == "RGB": + image = image.convert("RGB") + image = train_resize(image) + if args.random_flip and random.random() < 0.5: + # flip + image = train_flip(image) + if args.center_crop: + y1 = max(0, int(round((image.height - args.resolution) / 2.0))) + x1 = max(0, int(round((image.width - args.resolution) / 2.0))) + image = train_crop(image) + else: + y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution)) + image = crop(image, y1, x1, h, w) + image = train_transforms(image) + self.pixel_values.append(image) + + self.num_instance_images = len(self.instance_images) + self._length = self.num_instance_images + + if class_data_root is not None: + self.class_data_root = Path(class_data_root) + self.class_data_root.mkdir(parents=True, exist_ok=True) + self.class_images_path = list(self.class_data_root.iterdir()) + if class_num is not None: + self.num_class_images = min(len(self.class_images_path), class_num) + else: + self.num_class_images = len(self.class_images_path) + self._length = max(self.num_class_images, self.num_instance_images) + else: + self.class_data_root = None + + self.image_transforms = transforms.Compose( + [ + transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] + ) + + def __len__(self): + return self._length + + def __getitem__(self, index): + example = {} + instance_image = self.pixel_values[index % self.num_instance_images] + example["instance_images"] = instance_image + + if self.custom_instance_prompts: + caption = self.custom_instance_prompts[index % self.num_instance_images] + if caption: + example["instance_prompt"] = caption + else: + example["instance_prompt"] = self.instance_prompt + + else: # custom prompts were provided, but length does not match size of image dataset + example["instance_prompt"] = self.instance_prompt + + if self.class_data_root: + class_image = Image.open(self.class_images_path[index % self.num_class_images]) + class_image = exif_transpose(class_image) + + if not class_image.mode == "RGB": + class_image = class_image.convert("RGB") + example["class_images"] = self.image_transforms(class_image) + example["class_prompt"] = self.class_prompt + + return example + + +def collate_fn(examples, with_prior_preservation=False): + pixel_values = [example["instance_images"] for example in examples] + prompts = [example["instance_prompt"] for example in examples] + + # Concat class and instance examples for prior preservation. + # We do this to avoid doing two forward passes. + if with_prior_preservation: + pixel_values += [example["class_images"] for example in examples] + prompts += [example["class_prompt"] for example in examples] + + pixel_values = torch.stack(pixel_values) + # Qwen expects a `num_frames` dimension too. + if pixel_values.ndim == 4: + pixel_values = pixel_values.unsqueeze(2) + pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() + + batch = {"pixel_values": pixel_values, "prompts": prompts} + return batch + + +class PromptDataset(Dataset): + "A simple dataset to prepare the prompts to generate class images on multiple GPUs." + + def __init__(self, prompt, num_samples): + self.prompt = prompt + self.num_samples = num_samples + + def __len__(self): + return self.num_samples + + def __getitem__(self, index): + example = {} + example["prompt"] = self.prompt + example["index"] = index + return example + + +def concat_prompt_embedding_batches( + *prompt_embedding_pairs: tuple[torch.Tensor, torch.Tensor], +) -> tuple[torch.Tensor, torch.Tensor]: + """Concatenate prompt embedding batches along the batch dimension for prior preservation. + + Krea 2 tokenizes every prompt to the same fixed sequence length, so the `(B, seq, num_text_layers, + dim)` embeddings and their `(B, seq)` masks already share a sequence length and can be concatenated + directly. + """ + merged_prompt_embeds = torch.cat([prompt_embeds for prompt_embeds, _ in prompt_embedding_pairs], dim=0) + merged_mask = torch.cat([prompt_embeds_mask for _, prompt_embeds_mask in prompt_embedding_pairs], dim=0) + return merged_prompt_embeds, merged_mask + + +def main(args): + if args.report_to == "wandb" and args.hub_token is not None: + raise ValueError( + "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token." + " Please use `hf auth login` to authenticate with the Hub." + ) + + if torch.backends.mps.is_available() and args.mixed_precision == "bf16": + # due to pytorch#99272, MPS does not yet support bfloat16. + raise ValueError( + "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead." + ) + + logging_dir = Path(args.output_dir, args.logging_dir) + + accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) + kwargs = DistributedDataParallelKwargs(find_unused_parameters=True) + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=args.report_to, + project_config=accelerator_project_config, + kwargs_handlers=[kwargs], + ) + + # Disable AMP for MPS. + if torch.backends.mps.is_available(): + accelerator.native_amp = False + + if args.report_to == "wandb": + if not is_wandb_available(): + raise ImportError("Make sure to install wandb if you want to use it for logging during training.") + + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. + if args.seed is not None: + set_seed(args.seed) + + # Generate class images if prior preservation is enabled. + if args.with_prior_preservation: + class_images_dir = Path(args.class_data_dir) + if not class_images_dir.exists(): + class_images_dir.mkdir(parents=True) + cur_class_images = len(list(class_images_dir.iterdir())) + + if cur_class_images < args.num_class_images: + pipeline = Krea2Pipeline.from_pretrained( + args.pretrained_model_name_or_path, + torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16, + revision=args.revision, + variant=args.variant, + ) + pipeline.set_progress_bar_config(disable=True) + + num_new_images = args.num_class_images - cur_class_images + logger.info(f"Number of class images to sample: {num_new_images}.") + + sample_dataset = PromptDataset(args.class_prompt, num_new_images) + sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size) + + sample_dataloader = accelerator.prepare(sample_dataloader) + pipeline.to(accelerator.device) + + for example in tqdm( + sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process + ): + images = pipeline(example["prompt"]).images + + for i, image in enumerate(images): + hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest() + image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg" + image.save(image_filename) + + pipeline.to("cpu") + del pipeline + free_memory() + + # Handle the repository creation + if accelerator.is_main_process: + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + if args.push_to_hub: + repo_id = create_repo( + repo_id=args.hub_model_id or Path(args.output_dir).name, + exist_ok=True, + ).repo_id + + # Load the tokenizers + tokenizer = AutoTokenizer.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="tokenizer", + revision=args.revision, + ) + + # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision + # as these weights are only used for inference, keeping weights in full precision is not required. + weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + # Load scheduler and models + # Krea 2's scheduler uses resolution-aware dynamic shifting, so the static `shift` is ignored for the training + # sigma grid; load it straight from the checkpoint config. + noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained( + args.pretrained_model_name_or_path, subfolder="scheduler", revision=args.revision + ) + noise_scheduler_copy = copy.deepcopy(noise_scheduler) + vae = AutoencoderKLQwenImage.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="vae", + revision=args.revision, + variant=args.variant, + ) + vae_scale_factor = 2 ** len(vae.temperal_downsample) + latents_mean = (torch.tensor(vae.config.latents_mean).view(1, vae.config.z_dim, 1, 1, 1)).to(accelerator.device) + latents_std = 1.0 / torch.tensor(vae.config.latents_std).view(1, vae.config.z_dim, 1, 1, 1).to(accelerator.device) + text_encoder = Qwen3VLModel.from_pretrained( + args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, torch_dtype=weight_dtype + ) + quantization_config = None + if args.bnb_quantization_config_path is not None: + with open(args.bnb_quantization_config_path, "r") as f: + config_kwargs = json.load(f) + if "load_in_4bit" in config_kwargs and config_kwargs["load_in_4bit"]: + config_kwargs["bnb_4bit_compute_dtype"] = weight_dtype + quantization_config = BitsAndBytesConfig(**config_kwargs) + + transformer = Krea2Transformer2DModel.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="transformer", + revision=args.revision, + variant=args.variant, + quantization_config=quantization_config, + torch_dtype=weight_dtype, + ) + if args.bnb_quantization_config_path is not None: + transformer = prepare_model_for_kbit_training(transformer, use_gradient_checkpointing=False) + + if args.do_fp8_training: + from torchao.float8 import Float8LinearConfig, convert_to_float8_training + + convert_to_float8_training( + transformer, module_filter_fn=module_filter_fn, config=Float8LinearConfig(pad_inner_dim=True) + ) + + # We only train the additional adapter LoRA layers + transformer.requires_grad_(False) + vae.requires_grad_(False) + text_encoder.requires_grad_(False) + + if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16: + # due to pytorch#99272, MPS does not yet support bfloat16. + raise ValueError( + "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead." + ) + + to_kwargs = {"dtype": weight_dtype, "device": accelerator.device} if not args.offload else {"dtype": weight_dtype} + # flux vae is stable in bf16 so load it in weight_dtype to reduce memory + vae.to(**to_kwargs) + text_encoder.to(**to_kwargs) + # we never offload the transformer to CPU, so we can just use the accelerator device + transformer_to_kwargs = ( + {"device": accelerator.device} + if args.bnb_quantization_config_path is not None + else {"device": accelerator.device, "dtype": weight_dtype} + ) + transformer.to(**transformer_to_kwargs) + + # Initialize a text encoding pipeline and keep it to CPU for now. `text_encoder_select_layers` (which + # decoder layers to tap) is restored from the pipeline config by `from_pretrained`. + text_encoding_pipeline = Krea2Pipeline.from_pretrained( + args.pretrained_model_name_or_path, + vae=None, + transformer=None, + tokenizer=tokenizer, + text_encoder=text_encoder, + scheduler=None, + ) + + if args.gradient_checkpointing: + transformer.enable_gradient_checkpointing() + + if args.lora_layers is not None: + target_modules = [layer.strip() for layer in args.lora_layers.split(",")] + else: + # The Krea 2 authors' recommended default config (fits most styles, including high-frequency detail): + # rank/alpha 32 on the layers below. Names map to their reference layer list as: + # first -> img_in, last.linear -> final_layer.linear, wq/wk/wv/wo -> to_q/to_k/to_v/to_out.0, + # gate -> to_gate, mlp.up/mlp.down -> ff.up/ff.down, txtfusion.projector -> text_fusion.projector, + # txtmlp.1/txtmlp.3 -> txt_in.linear_1/txt_in.linear_2, tmlp.0/tmlp.2 -> time_embed.linear_1/linear_2, + # tproj.1 -> time_mod_proj. + # For long runs, the authors suggest raising the rank and narrowing to the attention layers + # ("to_q,to_k,to_v,to_out.0,to_gate") via --lora_layers so prompt adherence doesn't drop. + target_modules = [ + "img_in", + "final_layer.linear", + "to_q", + "to_k", + "to_v", + "to_out.0", + "to_gate", + "ff.up", + "ff.down", + "text_fusion.projector", + "txt_in.linear_1", + "txt_in.linear_2", + "time_embed.linear_1", + "time_embed.linear_2", + "time_mod_proj", + ] + + # now we will add new LoRA weights the transformer layers + transformer_lora_config = LoraConfig( + r=args.rank, + lora_alpha=args.lora_alpha, + lora_dropout=args.lora_dropout, + init_lora_weights="gaussian", + target_modules=target_modules, + ) + transformer.add_adapter(transformer_lora_config) + + def unwrap_model(model): + model = accelerator.unwrap_model(model) + model = model._orig_mod if is_compiled_module(model) else model + return model + + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format + def save_model_hook(models, weights, output_dir): + if accelerator.is_main_process: + transformer_lora_layers_to_save = None + modules_to_save = {} + + for model in models: + if isinstance(unwrap_model(model), type(unwrap_model(transformer))): + model = unwrap_model(model) + transformer_lora_layers_to_save = get_peft_model_state_dict(model) + modules_to_save["transformer"] = model + else: + raise ValueError(f"unexpected save model: {model.__class__}") + + # make sure to pop weight so that corresponding model is not saved again + if weights: + weights.pop() + + Krea2Pipeline.save_lora_weights( + output_dir, + transformer_lora_layers=transformer_lora_layers_to_save, + **_collate_lora_metadata(modules_to_save), + ) + + def load_model_hook(models, input_dir): + transformer_ = None + + if not accelerator.distributed_type == DistributedType.DEEPSPEED: + while len(models) > 0: + model = models.pop() + + if isinstance(unwrap_model(model), type(unwrap_model(transformer))): + model = unwrap_model(model) + transformer_ = model + else: + raise ValueError(f"unexpected save model: {model.__class__}") + else: + transformer_ = Krea2Transformer2DModel.from_pretrained( + args.pretrained_model_name_or_path, subfolder="transformer" + ) + transformer_.add_adapter(transformer_lora_config) + + lora_state_dict = Krea2Pipeline.lora_state_dict(input_dir) + + transformer_state_dict = { + f"{k.replace('transformer.', '')}": v for k, v in lora_state_dict.items() if k.startswith("transformer.") + } + transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict) + incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default") + if incompatible_keys is not None: + # check only for unexpected keys + unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None) + if unexpected_keys: + logger.warning( + f"Loading adapter weights from state_dict led to unexpected keys not found in the model: " + f" {unexpected_keys}. " + ) + + # Make sure the trainable params are in float32. This is again needed since the base models + # are in `weight_dtype`. More details: + # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804 + if args.mixed_precision == "fp16": + models = [transformer_] + # only upcast trainable parameters (LoRA) into fp32 + cast_training_params(models) + + accelerator.register_save_state_pre_hook(save_model_hook) + accelerator.register_load_state_pre_hook(load_model_hook) + + # Enable TF32 for faster training on Ampere GPUs, + # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices + if args.allow_tf32 and torch.cuda.is_available(): + torch.backends.cuda.matmul.allow_tf32 = True + + if args.scale_lr: + args.learning_rate = ( + args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes + ) + + # Make sure the trainable params are in float32. + if args.mixed_precision == "fp16": + models = [transformer] + # only upcast trainable parameters (LoRA) into fp32 + cast_training_params(models, dtype=torch.float32) + + transformer_lora_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters())) + + # Optimization parameters + transformer_parameters_with_lr = {"params": transformer_lora_parameters, "lr": args.learning_rate} + params_to_optimize = [transformer_parameters_with_lr] + + # Optimizer creation + if not (args.optimizer.lower() == "prodigy" or args.optimizer.lower() == "adamw"): + logger.warning( + f"Unsupported choice of optimizer: {args.optimizer}.Supported optimizers include [adamW, prodigy]." + "Defaulting to adamW" + ) + args.optimizer = "adamw" + + if args.use_8bit_adam and not args.optimizer.lower() == "adamw": + logger.warning( + f"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. Optimizer was " + f"set to {args.optimizer.lower()}" + ) + + if args.optimizer.lower() == "adamw": + if args.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError( + "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." + ) + + optimizer_class = bnb.optim.AdamW8bit + else: + optimizer_class = torch.optim.AdamW + + optimizer = optimizer_class( + params_to_optimize, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + + if args.optimizer.lower() == "prodigy": + try: + import prodigyopt + except ImportError: + raise ImportError("To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`") + + optimizer_class = prodigyopt.Prodigy + + if args.learning_rate <= 0.1: + logger.warning( + "Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0" + ) + + optimizer = optimizer_class( + params_to_optimize, + betas=(args.adam_beta1, args.adam_beta2), + beta3=args.prodigy_beta3, + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + decouple=args.prodigy_decouple, + use_bias_correction=args.prodigy_use_bias_correction, + safeguard_warmup=args.prodigy_safeguard_warmup, + ) + + # Dataset and DataLoaders creation: + train_dataset = DreamBoothDataset( + instance_data_root=args.instance_data_dir, + instance_prompt=args.instance_prompt, + class_prompt=args.class_prompt, + class_data_root=args.class_data_dir if args.with_prior_preservation else None, + class_num=args.num_class_images, + size=args.resolution, + repeats=args.repeats, + center_crop=args.center_crop, + ) + + train_dataloader = torch.utils.data.DataLoader( + train_dataset, + batch_size=args.train_batch_size, + shuffle=True, + collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation), + num_workers=args.dataloader_num_workers, + ) + + def compute_text_embeddings(prompt, text_encoding_pipeline): + with torch.no_grad(): + prompt_embeds, prompt_embeds_mask = text_encoding_pipeline.encode_prompt( + prompt=prompt, max_sequence_length=args.max_sequence_length + ) + return prompt_embeds, prompt_embeds_mask + + # If no type of tuning is done on the text_encoder and custom instance prompts are NOT + # provided (i.e. the --instance_prompt is used for all images), we encode the instance prompt once to avoid + # the redundant encoding. + if not train_dataset.custom_instance_prompts: + with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload): + instance_prompt_embeds, instance_prompt_embeds_mask = compute_text_embeddings( + args.instance_prompt, text_encoding_pipeline + ) + + # Handle class prompt for prior-preservation. + if args.with_prior_preservation: + with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload): + class_prompt_embeds, class_prompt_embeds_mask = compute_text_embeddings( + args.class_prompt, text_encoding_pipeline + ) + + validation_embeddings = {} + if args.validation_prompt is not None: + with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload): + (validation_embeddings["prompt_embeds"], validation_embeddings["prompt_embeds_mask"]) = ( + compute_text_embeddings(args.validation_prompt, text_encoding_pipeline) + ) + # Krea 2 enables classifier-free guidance whenever `guidance_scale > 0` and then encodes the + # negative prompt. The validation pipeline drops the text encoder to save memory, so precompute + # the (empty) negative-prompt embeddings here and pass them through to inference. + ( + validation_embeddings["negative_prompt_embeds"], + validation_embeddings["negative_prompt_embeds_mask"], + ) = compute_text_embeddings("", text_encoding_pipeline) + + # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images), + # pack the statically computed variables appropriately here. This is so that we don't + # have to pass them to the dataloader. + if not train_dataset.custom_instance_prompts: + prompt_embeds = instance_prompt_embeds + prompt_embeds_mask = instance_prompt_embeds_mask + if args.with_prior_preservation: + prompt_embeds, prompt_embeds_mask = concat_prompt_embedding_batches( + (instance_prompt_embeds, instance_prompt_embeds_mask), + (class_prompt_embeds, class_prompt_embeds_mask), + ) + + # if cache_latents is set to True, we encode images to latents and store them. + # Similar to pre-encoding in the case of a single instance prompt, if custom prompts are provided + # we encode them in advance as well. + precompute_latents = args.cache_latents or train_dataset.custom_instance_prompts + if precompute_latents: + prompt_embeds_cache = [] + prompt_embeds_mask_cache = [] + latents_cache = [] + for batch in tqdm(train_dataloader, desc="Caching latents"): + with torch.no_grad(): + if args.cache_latents: + with offload_models(vae, device=accelerator.device, offload=args.offload): + batch["pixel_values"] = batch["pixel_values"].to( + accelerator.device, non_blocking=True, dtype=vae.dtype + ) + latents_cache.append(vae.encode(batch["pixel_values"]).latent_dist) + if train_dataset.custom_instance_prompts: + with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload): + prompt_embeds, prompt_embeds_mask = compute_text_embeddings( + batch["prompts"], text_encoding_pipeline + ) + prompt_embeds_cache.append(prompt_embeds) + prompt_embeds_mask_cache.append(prompt_embeds_mask) + + # move back to cpu before deleting to ensure memory is freed see: https://github.com/huggingface/diffusers/issues/11376#issue-3008144624 + if args.cache_latents: + vae = vae.to("cpu") + del vae + + # move back to cpu before deleting to ensure memory is freed see: https://github.com/huggingface/diffusers/issues/11376#issue-3008144624 + text_encoding_pipeline = text_encoding_pipeline.to("cpu") + del text_encoder, tokenizer + free_memory() + + # Scheduler and math around the number of training steps. + overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, + num_training_steps=args.max_train_steps * accelerator.num_processes, + num_cycles=args.lr_num_cycles, + power=args.lr_power, + ) + + # Prepare everything with our `accelerator`. + transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + transformer, optimizer, train_dataloader, lr_scheduler + ) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + tracker_name = "dreambooth-krea2-lora" + accelerator.init_trackers(tracker_name, config=vars(args)) + + # Train! + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num batches each epoch = {len(train_dataloader)}") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + global_step = 0 + first_epoch = 0 + + # Potentially load in the weights and states from a previous save + if args.resume_from_checkpoint: + if args.resume_from_checkpoint != "latest": + path = os.path.basename(args.resume_from_checkpoint) + else: + # Get the mos recent checkpoint + dirs = os.listdir(args.output_dir) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + + if path is None: + accelerator.print( + f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." + ) + args.resume_from_checkpoint = None + initial_global_step = 0 + else: + accelerator.print(f"Resuming from checkpoint {path}") + accelerator.load_state(os.path.join(args.output_dir, path)) + global_step = int(path.split("-")[1]) + + initial_global_step = global_step + first_epoch = global_step // num_update_steps_per_epoch + + else: + initial_global_step = 0 + + progress_bar = tqdm( + range(0, args.max_train_steps), + initial=initial_global_step, + desc="Steps", + # Only show the progress bar once on each machine. + disable=not accelerator.is_local_main_process, + ) + + def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): + sigmas = noise_scheduler_copy.sigmas.to(device=accelerator.device, dtype=dtype) + schedule_timesteps = noise_scheduler_copy.timesteps.to(accelerator.device) + timesteps = timesteps.to(accelerator.device) + step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < n_dim: + sigma = sigma.unsqueeze(-1) + return sigma + + # Keep the most recent validation batch around so the model card gallery is populated even when + # `--skip_final_inference` is set (we fall back to the last interim images). + images = [] + + for epoch in range(first_epoch, args.num_train_epochs): + transformer.train() + + for step, batch in enumerate(train_dataloader): + models_to_accumulate = [transformer] + prompts = batch["prompts"] + + with accelerator.accumulate(models_to_accumulate): + # encode batch prompts when custom prompts are provided for each image - + if train_dataset.custom_instance_prompts: + prompt_embeds = prompt_embeds_cache[step] + prompt_embeds_mask = prompt_embeds_mask_cache[step] + else: + # With prior preservation, prompt_embeds already contains [instance, class] embeddings + # from the cat above, but collate_fn also doubles the prompts list. Use half the + # prompts count to avoid a 2x over-repeat that produces more embeddings than latents. + num_repeat_elements = len(prompts) // 2 if args.with_prior_preservation else len(prompts) + prompt_embeds = prompt_embeds.repeat_interleave(num_repeat_elements, dim=0) + if prompt_embeds_mask is not None: + prompt_embeds_mask = prompt_embeds_mask.repeat_interleave(num_repeat_elements, dim=0) + # Convert images to latent space + if args.cache_latents: + model_input = latents_cache[step].sample() + else: + with offload_models(vae, device=accelerator.device, offload=args.offload): + pixel_values = batch["pixel_values"].to(dtype=vae.dtype) + model_input = vae.encode(pixel_values).latent_dist.sample() + + model_input = (model_input - latents_mean) * latents_std + model_input = model_input.to(dtype=weight_dtype) + + # Sample noise that we'll add to the latents + noise = torch.randn_like(model_input) + bsz = model_input.shape[0] + + # Sample a random timestep for each image + # for weighting schemes where we sample timesteps non-uniformly + u = compute_density_for_timestep_sampling( + weighting_scheme=args.weighting_scheme, + batch_size=bsz, + logit_mean=args.logit_mean, + logit_std=args.logit_std, + mode_scale=args.mode_scale, + ) + indices = (u * noise_scheduler_copy.config.num_train_timesteps).long() + timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device) + + # Add noise according to flow matching. + # zt = (1 - texp) * x + texp * z1 + sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype) + noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise + + # Predict the noise residual. + # Pack the latents into 2x2 patches: (B, C, 1, H, W) -> (B, (H/2)*(W/2), C*4). + # Inlined from `Krea2Pipeline._pack_latents` (patch_size=2): that pipeline method is an + # instance method (uses `self.patch_size`), so it can't be invoked at the class level here. + noisy_model_input = noisy_model_input.permute(0, 2, 1, 3, 4) + bsz_pack, c_pack = model_input.shape[0], model_input.shape[1] + h_pack, w_pack, p_pack = model_input.shape[3], model_input.shape[4], 2 + packed_noisy_model_input = noisy_model_input.view( + bsz_pack, c_pack, h_pack // p_pack, p_pack, w_pack // p_pack, p_pack + ) + packed_noisy_model_input = packed_noisy_model_input.permute(0, 2, 4, 1, 3, 5) + packed_noisy_model_input = packed_noisy_model_input.reshape( + bsz_pack, (h_pack // p_pack) * (w_pack // p_pack), c_pack * p_pack * p_pack + ) + # Rotary coordinates for the combined [text, image] sequence. All images in a batch share a + # resolution, so a single set of position ids is reused for the whole batch. + grid_height = args.resolution // (vae_scale_factor * 2) + grid_width = args.resolution // (vae_scale_factor * 2) + position_ids = Krea2Pipeline.prepare_position_ids( + prompt_embeds.shape[1], grid_height, grid_width, accelerator.device + ) + model_pred = transformer( + hidden_states=packed_noisy_model_input, + encoder_hidden_states=prompt_embeds, + timestep=timesteps / 1000, + position_ids=position_ids, + encoder_attention_mask=prompt_embeds_mask, + return_dict=False, + )[0] + # Unpack the predicted patches back to a latent grid. Inlined from + # `Krea2Pipeline._unpack_latents` (patch_size=2): that pipeline method is an instance method + # (uses `self.patch_size`/`self.vae_scale_factor`), so it can't be invoked at the class level here. + p_un = 2 + bsz_un, _, ch_un = model_pred.shape + h_un = p_un * (int(args.resolution) // (vae_scale_factor * p_un)) + w_un = p_un * (int(args.resolution) // (vae_scale_factor * p_un)) + model_pred = model_pred.view(bsz_un, h_un // p_un, w_un // p_un, ch_un // (p_un * p_un), p_un, p_un) + model_pred = model_pred.permute(0, 3, 1, 4, 2, 5) + model_pred = model_pred.reshape(bsz_un, ch_un // (p_un * p_un), 1, h_un, w_un) + + # these weighting schemes use a uniform timestep sampling + # and instead post-weight the loss + weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas) + + target = noise - model_input + if args.with_prior_preservation: + # Chunk the noise and model_pred into two parts and compute the loss on each part separately. + model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) + target, target_prior = torch.chunk(target, 2, dim=0) + weighting, weighting_prior = torch.chunk(weighting, 2, dim=0) + + # Compute prior loss + prior_loss = torch.mean( + (weighting_prior.float() * (model_pred_prior.float() - target_prior.float()) ** 2).reshape( + target_prior.shape[0], -1 + ), + 1, + ) + prior_loss = prior_loss.mean() + + # Compute regular loss. + loss = torch.mean( + (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1), + 1, + ) + loss = loss.mean() + + if args.with_prior_preservation: + # Add the prior loss to the instance loss. + loss = loss + args.prior_loss_weight * prior_loss + + accelerator.backward(loss) + if accelerator.sync_gradients: + params_to_clip = transformer.parameters() + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + if accelerator.is_main_process or accelerator.distributed_type == DistributedType.DEEPSPEED: + if global_step % args.checkpointing_steps == 0: + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if args.checkpoints_total_limit is not None: + checkpoints = os.listdir(args.output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= args.checkpoints_total_limit: + num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + accelerator.save_state(save_path) + logger.info(f"Saved state to {save_path}") + + logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + accelerator.log(logs, step=global_step) + + if global_step >= args.max_train_steps: + break + + if accelerator.is_main_process: + if args.validation_prompt is not None and epoch % args.validation_epochs == 0: + # Validation runs on --validation_model_path (e.g. Krea 2 Turbo) when set, since RAW + # is not meant for inference; otherwise it falls back to the training checkpoint. + pipeline = build_validation_pipeline(args, accelerator, transformer, weight_dtype) + images = log_validation( + pipeline=pipeline, + args=args, + accelerator=accelerator, + pipeline_args=validation_embeddings, + torch_dtype=weight_dtype, + epoch=epoch, + pipeline_call_kwargs=_validation_call_kwargs(args), + ) + del pipeline + free_memory() + + # Save the lora layers + accelerator.wait_for_everyone() + if accelerator.is_main_process: + modules_to_save = {} + transformer = unwrap_model(transformer) + if args.bnb_quantization_config_path is None: + if args.upcast_before_saving: + transformer.to(torch.float32) + else: + transformer = transformer.to(weight_dtype) + transformer_lora_layers = get_peft_model_state_dict(transformer) + modules_to_save["transformer"] = transformer + + Krea2Pipeline.save_lora_weights( + save_directory=args.output_dir, + transformer_lora_layers=transformer_lora_layers, + **_collate_lora_metadata(modules_to_save), + ) + + # `images` keeps the last interim validation batch (if any) as the gallery fallback; final + # inference below overwrites it with freshly generated images when it runs. + run_validation = (args.validation_prompt and args.num_validation_images > 0) or (args.final_validation_prompt) + should_run_final_inference = not args.skip_final_inference and run_validation + if should_run_final_inference: + # Final inference. Like interim validation, run on --validation_model_path (e.g. Krea 2 + # Turbo) when set, since RAW is not meant for inference; else the training checkpoint. + pipeline = Krea2Pipeline.from_pretrained( + args.validation_model_path or args.pretrained_model_name_or_path, + tokenizer=None, + text_encoder=None, + revision=args.revision, + variant=args.variant, + torch_dtype=weight_dtype, + ) + # load attention processors + pipeline.load_lora_weights(args.output_dir) + + # run inference + images = log_validation( + pipeline=pipeline, + args=args, + accelerator=accelerator, + pipeline_args=validation_embeddings, + epoch=epoch, + is_final_validation=True, + torch_dtype=weight_dtype, + pipeline_call_kwargs=_validation_call_kwargs(args), + ) + del pipeline + free_memory() + + validation_prompt = args.validation_prompt if args.validation_prompt else args.final_validation_prompt + save_model_card( + (args.hub_model_id or Path(args.output_dir).name) if not args.push_to_hub else repo_id, + images=images, + base_model=args.pretrained_model_name_or_path, + instance_prompt=args.instance_prompt, + validation_prompt=validation_prompt, + repo_folder=args.output_dir, + inference_model=args.validation_model_path, + ) + + if args.push_to_hub: + upload_folder( + repo_id=repo_id, + folder_path=args.output_dir, + commit_message="End of training", + ignore_patterns=["step_*", "epoch_*"], + ) + + images = None + + accelerator.end_training() + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/src/diffusers/loaders/__init__.py b/src/diffusers/loaders/__init__.py index 2eb1f5cc7a44..1b0661d4c251 100644 --- a/src/diffusers/loaders/__init__.py +++ b/src/diffusers/loaders/__init__.py @@ -84,6 +84,7 @@ def text_encoder_attn_modules(text_encoder): "HiDreamImageLoraLoaderMixin", "SkyReelsV2LoraLoaderMixin", "QwenImageLoraLoaderMixin", + "Krea2LoraLoaderMixin", "ZImageLoraLoaderMixin", "Flux2LoraLoaderMixin", "Ideogram4LoraLoaderMixin", @@ -131,6 +132,7 @@ def text_encoder_attn_modules(text_encoder): HunyuanVideoLoraLoaderMixin, Ideogram4LoraLoaderMixin, KandinskyLoraLoaderMixin, + Krea2LoraLoaderMixin, LoraLoaderMixin, LTX2LoraLoaderMixin, LTXVideoLoraLoaderMixin, diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 0abeba91e983..2212be27ca3d 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -5411,6 +5411,206 @@ def unfuse_lora(self, components: list[str] = ["transformer"], **kwargs): super().unfuse_lora(components=components, **kwargs) +class Krea2LoraLoaderMixin(LoraBaseMixin): + r""" + Load LoRA layers into [`Krea2Transformer2DModel`]. Specific to [`Krea2Pipeline`]. + """ + + _lora_loadable_modules = ["transformer"] + transformer_name = TRANSFORMER_NAME + + @classmethod + @validate_hf_hub_args + # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.lora_state_dict + def lora_state_dict( + cls, + pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor], + **kwargs, + ): + r""" + See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details. + """ + # Load the main state dict first which has the LoRA layers for either of + # transformer and text encoder or both. + cache_dir = kwargs.pop("cache_dir", None) + force_download = kwargs.pop("force_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", None) + token = kwargs.pop("token", None) + revision = kwargs.pop("revision", None) + subfolder = kwargs.pop("subfolder", None) + weight_name = kwargs.pop("weight_name", None) + use_safetensors = kwargs.pop("use_safetensors", None) + return_lora_metadata = kwargs.pop("return_lora_metadata", False) + + allow_pickle = False + if use_safetensors is None: + use_safetensors = True + allow_pickle = True + + user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} + + state_dict, metadata = _fetch_state_dict( + pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, + weight_name=weight_name, + use_safetensors=use_safetensors, + local_files_only=local_files_only, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + allow_pickle=allow_pickle, + ) + + is_dora_scale_present = any("dora_scale" in k for k in state_dict) + if is_dora_scale_present: + warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new." + logger.warning(warn_msg) + state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k} + + out = (state_dict, metadata) if return_lora_metadata else state_dict + return out + + # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights + def load_lora_weights( + self, + pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor], + adapter_name: str | None = None, + hotswap: bool = False, + **kwargs, + ): + """ + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for more details. + """ + if not USE_PEFT_BACKEND: + raise ValueError("PEFT backend is required for this method.") + + low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA) + if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." + ) + + # if a dict is passed, copy it instead of modifying it inplace + if isinstance(pretrained_model_name_or_path_or_dict, dict): + pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() + + # First, ensure that the checkpoint is a compatible one and can be successfully loaded. + kwargs["return_lora_metadata"] = True + state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) + + is_correct_format = all("lora" in key for key in state_dict.keys()) + if not is_correct_format: + raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.") + + self.load_lora_into_transformer( + state_dict, + transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, + adapter_name=adapter_name, + metadata=metadata, + _pipeline=self, + low_cpu_mem_usage=low_cpu_mem_usage, + hotswap=hotswap, + ) + + @classmethod + # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->Krea2Transformer2DModel + def load_lora_into_transformer( + cls, + state_dict, + transformer, + adapter_name=None, + _pipeline=None, + low_cpu_mem_usage=False, + hotswap: bool = False, + metadata=None, + ): + """ + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_unet`] for more details. + """ + if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." + ) + + # Load the layers corresponding to transformer. + logger.info(f"Loading {cls.transformer_name}.") + transformer.load_lora_adapter( + state_dict, + network_alphas=None, + adapter_name=adapter_name, + metadata=metadata, + _pipeline=_pipeline, + low_cpu_mem_usage=low_cpu_mem_usage, + hotswap=hotswap, + ) + + @classmethod + # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights + def save_lora_weights( + cls, + save_directory: str | os.PathLike, + transformer_lora_layers: dict[str, torch.nn.Module | torch.Tensor] = None, + is_main_process: bool = True, + weight_name: str = None, + save_function: Callable = None, + safe_serialization: bool = True, + transformer_lora_adapter_metadata: dict | None = None, + ): + r""" + See [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for more information. + """ + lora_layers = {} + lora_metadata = {} + + if transformer_lora_layers: + lora_layers[cls.transformer_name] = transformer_lora_layers + lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata + + if not lora_layers: + raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.") + + cls._save_lora_weights( + save_directory=save_directory, + lora_layers=lora_layers, + lora_metadata=lora_metadata, + is_main_process=is_main_process, + weight_name=weight_name, + save_function=save_function, + safe_serialization=safe_serialization, + ) + + # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora + def fuse_lora( + self, + components: list[str] = ["transformer"], + lora_scale: float = 1.0, + safe_fusing: bool = False, + adapter_names: list[str] | None = None, + **kwargs, + ): + r""" + See [`~loaders.StableDiffusionLoraLoaderMixin.fuse_lora`] for more details. + """ + super().fuse_lora( + components=components, + lora_scale=lora_scale, + safe_fusing=safe_fusing, + adapter_names=adapter_names, + **kwargs, + ) + + # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora + def unfuse_lora(self, components: list[str] = ["transformer"], **kwargs): + r""" + See [`~loaders.StableDiffusionLoraLoaderMixin.unfuse_lora`] for more details. + """ + super().unfuse_lora(components=components, **kwargs) + + class ZImageLoraLoaderMixin(LoraBaseMixin): r""" Load LoRA layers into [`ZImageTransformer2DModel`]. Specific to [`ZImagePipeline`]. diff --git a/src/diffusers/models/transformers/transformer_krea2.py b/src/diffusers/models/transformers/transformer_krea2.py index b098119eddb0..d1f6cd0ecded 100644 --- a/src/diffusers/models/transformers/transformer_krea2.py +++ b/src/diffusers/models/transformers/transformer_krea2.py @@ -14,6 +14,7 @@ import inspect import math +from typing import Any import torch import torch.nn as nn @@ -21,7 +22,7 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import PeftAdapterMixin -from ...utils import logging +from ...utils import apply_lora_scale, logging from ...utils.torch_utils import maybe_adjust_dtype_for_device from ..attention import AttentionMixin, AttentionModuleMixin from ..attention_dispatch import dispatch_attention_fn @@ -443,6 +444,7 @@ def __init__( self.final_layer = Krea2FinalLayer(hidden_size, out_channels=in_channels, eps=norm_eps) + @apply_lora_scale("attention_kwargs") def forward( self, hidden_states: torch.Tensor, @@ -450,6 +452,7 @@ def forward( timestep: torch.Tensor, position_ids: torch.Tensor, encoder_attention_mask: torch.Tensor | None = None, + attention_kwargs: dict[str, Any] | None = None, return_dict: bool = True, ) -> Transformer2DModelOutput | tuple[torch.Tensor]: r""" @@ -467,6 +470,9 @@ def forward( latent-grid coordinates. encoder_attention_mask (`torch.Tensor` of shape `(batch_size, text_seq_len)`, *optional*): Boolean mask marking valid text tokens. Pass `None` when every text token is valid. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that, when it contains a `scale` entry, sets the LoRA scale applied to this + transformer's adapters for the duration of the forward pass. return_dict (`bool`, *optional*, defaults to `True`): Whether to return a [`~models.modeling_outputs.Transformer2DModelOutput`] instead of a plain tuple. diff --git a/src/diffusers/pipelines/krea2/pipeline_krea2.py b/src/diffusers/pipelines/krea2/pipeline_krea2.py index 53e37ea9483a..51d33cb48619 100644 --- a/src/diffusers/pipelines/krea2/pipeline_krea2.py +++ b/src/diffusers/pipelines/krea2/pipeline_krea2.py @@ -13,13 +13,14 @@ # limitations under the License. import inspect -from typing import Callable +from typing import Any, Callable import numpy as np import torch from transformers import AutoTokenizer, Qwen3VLModel from ...image_processor import VaeImageProcessor +from ...loaders import Krea2LoraLoaderMixin from ...models import AutoencoderKLQwenImage, Krea2Transformer2DModel from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import is_torch_xla_available, logging, replace_example_docstring @@ -130,7 +131,7 @@ def retrieve_timesteps( return timesteps, num_inference_steps -class Krea2Pipeline(DiffusionPipeline): +class Krea2Pipeline(DiffusionPipeline, Krea2LoraLoaderMixin): r""" The Krea 2 pipeline for text-to-image generation. @@ -425,6 +426,10 @@ def guidance_scale(self): def do_classifier_free_guidance(self): return self._guidance_scale > 0 + @property + def attention_kwargs(self): + return self._attention_kwargs + @property def num_timesteps(self): return self._num_timesteps @@ -459,6 +464,7 @@ def __call__( return_dict: bool = True, callback_on_step_end: Callable[[int, int, dict], None] | None = None, callback_on_step_end_tensor_inputs: list[str] = ["latents"], + attention_kwargs: dict[str, Any] | None = None, max_sequence_length: int = 512, ): r""" @@ -512,6 +518,10 @@ def __call__( callback_on_step_end_tensor_inputs (`list[str]`, *optional*, defaults to `["latents"]`): The list of tensor inputs for the `callback_on_step_end` function. Must be a subset of `._callback_tensor_inputs`. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). max_sequence_length (`int`, defaults to 512): Fixed text sequence length consumed by the transformer; prompts are padded or truncated to it. @@ -546,6 +556,7 @@ def __call__( ) self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs self._current_timestep = None self._interrupt = False @@ -638,6 +649,7 @@ def __call__( timestep=timestep, position_ids=position_ids, encoder_attention_mask=prompt_embeds_mask, + attention_kwargs=self.attention_kwargs, return_dict=False, )[0] @@ -648,6 +660,7 @@ def __call__( timestep=timestep, position_ids=position_ids, encoder_attention_mask=negative_prompt_embeds_mask, + attention_kwargs=self.attention_kwargs, return_dict=False, )[0] noise_pred = noise_pred + guidance_scale * (noise_pred - neg_noise_pred) diff --git a/tests/lora/test_lora_layers_krea2.py b/tests/lora/test_lora_layers_krea2.py new file mode 100644 index 000000000000..6cef8dd0b52b --- /dev/null +++ b/tests/lora/test_lora_layers_krea2.py @@ -0,0 +1,190 @@ +# coding=utf-8 +# Copyright 2026 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import unittest + +import torch +from transformers import Qwen2Tokenizer, Qwen3VLConfig, Qwen3VLModel + +from diffusers import ( + AutoencoderKLQwenImage, + FlowMatchEulerDiscreteScheduler, + Krea2Pipeline, + Krea2Transformer2DModel, +) + +from ..testing_utils import floats_tensor, is_peft_available, require_peft_backend + + +if is_peft_available(): + from peft import LoraConfig + + +from .utils import PeftLoraLoaderMixinTests # noqa: E402 + + +@require_peft_backend +class Krea2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): + pipeline_class = Krea2Pipeline + scheduler_cls = FlowMatchEulerDiscreteScheduler + scheduler_kwargs = { + "use_dynamic_shifting": True, + "base_shift": 0.5, + "max_shift": 1.15, + "base_image_seq_len": 256, + "max_image_seq_len": 6400, + } + + transformer_cls = Krea2Transformer2DModel + transformer_kwargs = { + "in_channels": 16, + "num_layers": 2, + "attention_head_dim": 8, + "num_attention_heads": 4, + "num_key_value_heads": 2, + "intermediate_size": 32, + "timestep_embed_dim": 8, + "text_hidden_dim": 16, + "num_text_layers": 3, + "text_num_attention_heads": 2, + "text_num_key_value_heads": 1, + "text_intermediate_size": 16, + "num_layerwise_text_blocks": 1, + "num_refiner_text_blocks": 1, + "axes_dims_rope": (4, 2, 2), + "rope_theta": 1000.0, + } + + z_dim = 4 + vae_cls = AutoencoderKLQwenImage + vae_kwargs = { + "base_dim": z_dim * 6, + "z_dim": z_dim, + "dim_mult": [1, 2, 4], + "num_res_blocks": 1, + "temperal_downsample": [False, True], + "latents_mean": [0.0] * 4, + "latents_std": [1.0] * 4, + } + + tokenizer_cls, tokenizer_id = Qwen2Tokenizer, "hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration" + + # Krea2's attention uses split q/k/v/out projections in the diffusers transformer. + denoiser_target_modules = ["to_q", "to_k", "to_v", "to_out.0"] + # The text encoder (Qwen3-VL) is frozen and not LoRA-adapted by the Krea2 loader. + supports_text_encoder_loras = False + + @property + def output_shape(self): + return (1, 32, 32, 3) + + def get_dummy_components(self, scheduler_cls=None, use_dora=False, lora_alpha=None): + # The Krea2 pipeline uses a Qwen3-VL text encoder for which there is no tiny pretrained checkpoint, + # so build the components inline rather than relying on the base implementation. + scheduler_cls = self.scheduler_cls if scheduler_cls is None else scheduler_cls + rank = 4 + lora_alpha = rank if lora_alpha is None else lora_alpha + + torch.manual_seed(0) + transformer = self.transformer_cls(**self.transformer_kwargs) + + torch.manual_seed(0) + vae = self.vae_cls(**self.vae_kwargs) + + torch.manual_seed(0) + scheduler = scheduler_cls(**self.scheduler_kwargs) + + torch.manual_seed(0) + config = Qwen3VLConfig( + text_config={ + "hidden_size": 16, + "intermediate_size": 16, + "num_hidden_layers": 2, + "num_attention_heads": 2, + "num_key_value_heads": 2, + "head_dim": 8, + }, + vision_config={ + "depth": 2, + "hidden_size": 16, + "intermediate_size": 16, + "num_heads": 2, + "out_hidden_size": 16, + }, + vocab_size=152064, + ) + text_encoder = Qwen3VLModel(config).eval() + tokenizer = self.tokenizer_cls.from_pretrained(self.tokenizer_id) + + text_lora_config = LoraConfig( + r=rank, + lora_alpha=lora_alpha, + target_modules=["q_proj", "k_proj", "v_proj", "o_proj"], + init_lora_weights=False, + use_dora=use_dora, + ) + denoiser_lora_config = LoraConfig( + r=rank, + lora_alpha=lora_alpha, + target_modules=self.denoiser_target_modules, + init_lora_weights=False, + use_dora=use_dora, + ) + + pipeline_components = { + "scheduler": scheduler, + "vae": vae, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + "transformer": transformer, + "text_encoder_select_layers": (0, 1, 2), + } + + return pipeline_components, text_lora_config, denoiser_lora_config + + def get_dummy_inputs(self, with_generator=True): + batch_size = 1 + sequence_length = 16 + num_channels = 4 + sizes = (32, 32) + + generator = torch.manual_seed(0) + noise = floats_tensor((batch_size, num_channels) + sizes) + input_ids = torch.randint(1, sequence_length, size=(batch_size, sequence_length), generator=generator) + + pipeline_inputs = { + "prompt": "a dog is dancing", + "num_inference_steps": 2, + "guidance_scale": 3.0, + "height": 32, + "width": 32, + "max_sequence_length": sequence_length, + "output_type": "np", + } + if with_generator: + pipeline_inputs.update({"generator": generator}) + + return noise, input_ids, pipeline_inputs + + @unittest.skip("Not supported in Krea2.") + def test_simple_inference_with_text_denoiser_block_scale(self): + pass + + @unittest.skip("Not supported in Krea2.") + def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self): + pass + + @unittest.skip("Not supported in Krea2.") + def test_modify_padding_mode(self): + pass