57 changes: 41 additions & 16 deletions comfy/ldm/chroma/model.py
@@ -4,7 +4,7 @@

import torch
from torch import Tensor, nn
from einops import rearrange, repeat
from einops import rearrange
import comfy.patcher_extension
import comfy.ldm.common_dit

@@ -257,29 +257,54 @@ def block_wrap(args):
img = self.final_layer(img, vec=final_mod) # (N, T, patch_size ** 2 * out_channels)
return img

def forward(self, x, timestep, context, guidance, control=None, transformer_options={}, **kwargs):
def forward(self, x, timestep, context, guidance=None, ref_latents=None, control=None, transformer_options={}, **kwargs):
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
self._forward,
self,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
).execute(x, timestep, context, guidance, control, transformer_options, **kwargs)
).execute(x, timestep, context, guidance, ref_latents, control, transformer_options, **kwargs)

def _forward(self, x, timestep, context, guidance, control=None, transformer_options={}, **kwargs):
bs, c, h, w = x.shape
x = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))

img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=self.patch_size, pw=self.patch_size)
def _forward(self, x, timestep, context, guidance=None, ref_latents=None, control=None, transformer_options={}, **kwargs):
bs, c, h_orig, w_orig = x.shape

h_len = ((h_orig + (self.patch_size // 2)) // self.patch_size)
w_len = ((w_orig + (self.patch_size // 2)) // self.patch_size)
img, img_ids = comfy.ldm.common_dit.process_img(x, patch_size=self.patch_size, transformer_options=transformer_options)
if img.ndim != 3 or context.ndim != 3:
raise ValueError("Input img and txt tensors must have 3 dimensions.")

h_len = ((h + (self.patch_size // 2)) // self.patch_size)
w_len = ((w + (self.patch_size // 2)) // self.patch_size)
img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
img_tokens = img.shape[1]
if ref_latents is not None:
h = 0
w = 0
index = 0
ref_latents_method = kwargs.get("ref_latents_method", "offset")
for ref in ref_latents:
if ref_latents_method == "index":
index += 1
h_offset = 0
w_offset = 0
elif ref_latents_method == "uxo":
index = 0
h_offset = h_len * self.patch_size + h
w_offset = w_len * self.patch_size + w
h += ref.shape[-2]
w += ref.shape[-1]
else:
index = 1
h_offset = 0
w_offset = 0
if ref.shape[-2] + h > ref.shape[-1] + w:
w_offset = w
else:
h_offset = h
h = max(h, ref.shape[-2] + h_offset)
w = max(w, ref.shape[-1] + w_offset)

kontext, kontext_ids = comfy.ldm.common_dit.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset, patch_size=self.patch_size)
img = torch.cat([img, kontext], dim=1)
img_ids = torch.cat([img_ids, kontext_ids], dim=1)

txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
out = self.forward_orig(img, img_ids, context, txt_ids, timestep, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None))
return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=self.patch_size, pw=self.patch_size)[:,:,:h,:w]
out = out[:, :img_tokens]
return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=self.patch_size, pw=self.patch_size)[:,:,:h_orig,:w_orig]
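For readers following the Chroma change: the loop above decides where each reference latent lands in the 3-channel RoPE id space, depending on ref_latents_method. Below is a standalone sketch of just that branch logic; the function name, default sizes, and dummy tensors are illustrative, not part of the PR.

import torch

def ref_coords(refs, method="offset", h_len=64, w_len=64, patch_size=2):
    """Return (index, h_offset, w_offset) per reference, mirroring the loop in _forward above."""
    h = w = 0
    index = 0
    coords = []
    for ref in refs:
        if method == "index":
            # every reference gets its own temporal index, no spatial offset
            index += 1
            h_offset = 0
            w_offset = 0
        elif method == "uxo":
            # references are placed past the main image's pixel extent and tile onward
            index = 0
            h_offset = h_len * patch_size + h
            w_offset = w_len * patch_size + w
            h += ref.shape[-2]
            w += ref.shape[-1]
        else:
            # "offset": shared index 1, packed along whichever running edge is shorter
            index = 1
            h_offset = 0
            w_offset = 0
            if ref.shape[-2] + h > ref.shape[-1] + w:
                w_offset = w
            else:
                h_offset = h
            h = max(h, ref.shape[-2] + h_offset)
            w = max(w, ref.shape[-1] + w_offset)
        coords.append((index, h_offset, w_offset))
    return coords

refs = [torch.zeros(1, 16, 32, 32), torch.zeros(1, 16, 32, 32)]
print(ref_coords(refs, method="index"))   # [(1, 0, 0), (2, 0, 0)]
print(ref_coords(refs, method="offset"))  # [(1, 0, 0), (1, 32, 0)]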
30 changes: 30 additions & 0 deletions comfy/ldm/common_dit.py
@@ -1,4 +1,5 @@
import torch
from einops import rearrange, repeat
import comfy.rmsnorm


@@ -14,3 +15,32 @@ def pad_to_patch_size(img, patch_size=(2, 2), padding_mode="circular"):


rms_norm = comfy.rmsnorm.rms_norm

def process_img(x, index=0, h_offset=0, w_offset=0, patch_size=(2, 2), transformer_options={}):
bs, c, h, w = x.shape
x = pad_to_patch_size(x, (patch_size, patch_size))

img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
h_len = ((h + (patch_size // 2)) // patch_size)
w_len = ((w + (patch_size // 2)) // patch_size)

h_offset = ((h_offset + (patch_size // 2)) // patch_size)
w_offset = ((w_offset + (patch_size // 2)) // patch_size)

steps_h = h_len
steps_w = w_len

rope_options = transformer_options.get("rope_options", None)
if rope_options is not None:
h_len = (h_len - 1.0) * rope_options.get("scale_y", 1.0) + 1.0
w_len = (w_len - 1.0) * rope_options.get("scale_x", 1.0) + 1.0

index += rope_options.get("shift_t", 0.0)
h_offset += rope_options.get("shift_y", 0.0)
w_offset += rope_options.get("shift_x", 0.0)

img_ids = torch.zeros((steps_h, steps_w, 3), device=x.device, dtype=x.dtype)
img_ids[:, :, 0] = img_ids[:, :, 1] + index
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=steps_h, device=x.device, dtype=x.dtype).unsqueeze(1)
img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=steps_w, device=x.device, dtype=x.dtype).unsqueeze(0)
return img, repeat(img_ids, "h w c -> b (h w) c", b=bs)
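A usage sketch for the new shared helper, assuming a checkout where this PR's comfy package is importable; the latent shape and patch size below are arbitrary example values, not taken from the PR.

import torch
import comfy.ldm.common_dit

x = torch.zeros(1, 16, 64, 64)   # (bs, c, h, w) latent; 64/patch_size=2 -> 32x32 tokens
img, img_ids = comfy.ldm.common_dit.process_img(x, patch_size=2)
print(img.shape)      # torch.Size([1, 1024, 64]) -> (bs, h_len*w_len, c*ph*pw)
print(img_ids.shape)  # torch.Size([1, 1024, 3])  -> per-token (index, y, x) RoPE ids

# rope_options rescales/shifts the positional grid without changing token count:
opts = {"rope_options": {"scale_x": 2.0, "scale_y": 2.0, "shift_t": 0.0}}
_, scaled_ids = comfy.ldm.common_dit.process_img(x, patch_size=2, transformer_options=opts)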
35 changes: 3 additions & 32 deletions comfy/ldm/flux/model.py
@@ -4,7 +4,7 @@

import torch
from torch import Tensor, nn
from einops import rearrange, repeat
from einops import rearrange
import comfy.ldm.common_dit
import comfy.patcher_extension

@@ -210,35 +210,6 @@ def block_wrap(args):
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
return img

def process_img(self, x, index=0, h_offset=0, w_offset=0, transformer_options={}):
bs, c, h, w = x.shape
patch_size = self.patch_size
x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))

img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
h_len = ((h + (patch_size // 2)) // patch_size)
w_len = ((w + (patch_size // 2)) // patch_size)

h_offset = ((h_offset + (patch_size // 2)) // patch_size)
w_offset = ((w_offset + (patch_size // 2)) // patch_size)

steps_h = h_len
steps_w = w_len

rope_options = transformer_options.get("rope_options", None)
if rope_options is not None:
h_len = (h_len - 1.0) * rope_options.get("scale_y", 1.0) + 1.0
w_len = (w_len - 1.0) * rope_options.get("scale_x", 1.0) + 1.0

index += rope_options.get("shift_t", 0.0)
h_offset += rope_options.get("shift_y", 0.0)
w_offset += rope_options.get("shift_x", 0.0)

img_ids = torch.zeros((steps_h, steps_w, 3), device=x.device, dtype=x.dtype)
img_ids[:, :, 0] = img_ids[:, :, 1] + index
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=steps_h, device=x.device, dtype=x.dtype).unsqueeze(1)
img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=steps_w, device=x.device, dtype=x.dtype).unsqueeze(0)
return img, repeat(img_ids, "h w c -> b (h w) c", b=bs)

def forward(self, x, timestep, context, y=None, guidance=None, ref_latents=None, control=None, transformer_options={}, **kwargs):
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
@@ -253,7 +224,7 @@ def _forward(self, x, timestep, context, y=None, guidance=None, ref_latents=None

h_len = ((h_orig + (patch_size // 2)) // patch_size)
w_len = ((w_orig + (patch_size // 2)) // patch_size)
img, img_ids = self.process_img(x, transformer_options=transformer_options)
img, img_ids = comfy.ldm.common_dit.process_img(x, patch_size=patch_size, transformer_options=transformer_options)
img_tokens = img.shape[1]
if ref_latents is not None:
h = 0
@@ -282,7 +253,7 @@ def _forward(self, x, timestep, context, y=None, guidance=None, ref_latents=None
h = max(h, ref.shape[-2] + h_offset)
w = max(w, ref.shape[-1] + w_offset)

kontext, kontext_ids = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset)
kontext, kontext_ids = comfy.ldm.common_dit.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset, patch_size=patch_size)
img = torch.cat([img, kontext], dim=1)
img_ids = torch.cat([img_ids, kontext_ids], dim=1)

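To see what the shared helper gives the reference-latent path in both models: reference tokens are appended after the main image's tokens, so the output can be trimmed back with out[:, :img_tokens] before unpatchifying, while the first id channel keeps main and reference tokens distinguishable for RoPE. A small check one might run (a sketch, again assuming comfy is importable; shapes are illustrative):

import torch
import comfy.ldm.common_dit

main = torch.zeros(1, 16, 64, 64)
ref = torch.zeros(1, 16, 32, 32)
img, img_ids = comfy.ldm.common_dit.process_img(main, patch_size=2)
kontext, kontext_ids = comfy.ldm.common_dit.process_img(ref, index=1, patch_size=2)
assert img_ids[..., 0].max() == 0      # main image sits at temporal index 0
assert kontext_ids[..., 0].min() == 1  # reference tokens are tagged with index=1
tokens = torch.cat([img, kontext], dim=1)      # main tokens first, refs appended
ids = torch.cat([img_ids, kontext_ids], dim=1)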