diff --git a/comfy/ldm/chroma/model.py b/comfy/ldm/chroma/model.py index ad1c523fe1d3..d56dd12d0470 100644 --- a/comfy/ldm/chroma/model.py +++ b/comfy/ldm/chroma/model.py @@ -4,7 +4,7 @@ import torch from torch import Tensor, nn -from einops import rearrange, repeat +from einops import rearrange import comfy.patcher_extension import comfy.ldm.common_dit @@ -257,29 +257,54 @@ def block_wrap(args): img = self.final_layer(img, vec=final_mod) # (N, T, patch_size ** 2 * out_channels) return img - def forward(self, x, timestep, context, guidance, control=None, transformer_options={}, **kwargs): + def forward(self, x, timestep, context, guidance=None, ref_latents=None, control=None, transformer_options={}, **kwargs): return comfy.patcher_extension.WrapperExecutor.new_class_executor( self._forward, self, comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options) - ).execute(x, timestep, context, guidance, control, transformer_options, **kwargs) + ).execute(x, timestep, context, guidance, ref_latents, control, transformer_options, **kwargs) - def _forward(self, x, timestep, context, guidance, control=None, transformer_options={}, **kwargs): - bs, c, h, w = x.shape - x = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size)) - - img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=self.patch_size, pw=self.patch_size) + def _forward(self, x, timestep, context, guidance=None, ref_latents=None, control=None, transformer_options={}, **kwargs): + bs, c, h_orig, w_orig = x.shape + h_len = ((h_orig + (self.patch_size // 2)) // self.patch_size) + w_len = ((w_orig + (self.patch_size // 2)) // self.patch_size) + img, img_ids = comfy.ldm.common_dit.process_img(x, patch_size=self.patch_size, transformer_options=transformer_options) if img.ndim != 3 or context.ndim != 3: raise ValueError("Input img and txt tensors must have 3 dimensions.") - - h_len = ((h + (self.patch_size // 2)) // self.patch_size) - w_len = ((w + (self.patch_size // 2)) // self.patch_size) - img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype) - img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1) - img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0) - img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs) + img_tokens = img.shape[1] + if ref_latents is not None: + h = 0 + w = 0 + index = 0 + ref_latents_method = kwargs.get("ref_latents_method", "offset") + for ref in ref_latents: + if ref_latents_method == "index": + index += 1 + h_offset = 0 + w_offset = 0 + elif ref_latents_method == "uxo": + index = 0 + h_offset = h_len * self.patch_size + h + w_offset = w_len * self.patch_size + w + h += ref.shape[-2] + w += ref.shape[-1] + else: + index = 1 + h_offset = 0 + w_offset = 0 + if ref.shape[-2] + h > ref.shape[-1] + w: + w_offset = w + else: + h_offset = h + h = max(h, ref.shape[-2] + h_offset) + w = max(w, ref.shape[-1] + w_offset) + + kontext, kontext_ids = comfy.ldm.common_dit.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset, patch_size=self.patch_size) + img = torch.cat([img, kontext], dim=1) + img_ids = torch.cat([img_ids, kontext_ids], dim=1) txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype) out = self.forward_orig(img, img_ids, context, txt_ids, timestep, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None)) - return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=self.patch_size, pw=self.patch_size)[:,:,:h,:w] + out = out[:, :img_tokens] + return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=self.patch_size, pw=self.patch_size)[:,:,:h_orig,:w_orig] diff --git a/comfy/ldm/common_dit.py b/comfy/ldm/common_dit.py index f7f56b72ca6f..a31c9f571b63 100644 --- a/comfy/ldm/common_dit.py +++ b/comfy/ldm/common_dit.py @@ -1,4 +1,5 @@ import torch +from einops import rearrange, repeat import comfy.rmsnorm @@ -14,3 +15,32 @@ def pad_to_patch_size(img, patch_size=(2, 2), padding_mode="circular"): rms_norm = comfy.rmsnorm.rms_norm + +def process_img(x, index=0, h_offset=0, w_offset=0, patch_size=(2, 2), transformer_options={}): + bs, c, h, w = x.shape + x = pad_to_patch_size(x, (patch_size, patch_size)) + + img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size) + h_len = ((h + (patch_size // 2)) // patch_size) + w_len = ((w + (patch_size // 2)) // patch_size) + + h_offset = ((h_offset + (patch_size // 2)) // patch_size) + w_offset = ((w_offset + (patch_size // 2)) // patch_size) + + steps_h = h_len + steps_w = w_len + + rope_options = transformer_options.get("rope_options", None) + if rope_options is not None: + h_len = (h_len - 1.0) * rope_options.get("scale_y", 1.0) + 1.0 + w_len = (w_len - 1.0) * rope_options.get("scale_x", 1.0) + 1.0 + + index += rope_options.get("shift_t", 0.0) + h_offset += rope_options.get("shift_y", 0.0) + w_offset += rope_options.get("shift_x", 0.0) + + img_ids = torch.zeros((steps_h, steps_w, 3), device=x.device, dtype=x.dtype) + img_ids[:, :, 0] = img_ids[:, :, 1] + index + img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=steps_h, device=x.device, dtype=x.dtype).unsqueeze(1) + img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=steps_w, device=x.device, dtype=x.dtype).unsqueeze(0) + return img, repeat(img_ids, "h w c -> b (h w) c", b=bs) diff --git a/comfy/ldm/flux/model.py b/comfy/ldm/flux/model.py index b9d36f2024d7..fb38d1a73e97 100644 --- a/comfy/ldm/flux/model.py +++ b/comfy/ldm/flux/model.py @@ -4,7 +4,7 @@ import torch from torch import Tensor, nn -from einops import rearrange, repeat +from einops import rearrange import comfy.ldm.common_dit import comfy.patcher_extension @@ -210,35 +210,6 @@ def block_wrap(args): img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels) return img - def process_img(self, x, index=0, h_offset=0, w_offset=0, transformer_options={}): - bs, c, h, w = x.shape - patch_size = self.patch_size - x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size)) - - img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size) - h_len = ((h + (patch_size // 2)) // patch_size) - w_len = ((w + (patch_size // 2)) // patch_size) - - h_offset = ((h_offset + (patch_size // 2)) // patch_size) - w_offset = ((w_offset + (patch_size // 2)) // patch_size) - - steps_h = h_len - steps_w = w_len - - rope_options = transformer_options.get("rope_options", None) - if rope_options is not None: - h_len = (h_len - 1.0) * rope_options.get("scale_y", 1.0) + 1.0 - w_len = (w_len - 1.0) * rope_options.get("scale_x", 1.0) + 1.0 - - index += rope_options.get("shift_t", 0.0) - h_offset += rope_options.get("shift_y", 0.0) - w_offset += rope_options.get("shift_x", 0.0) - - img_ids = torch.zeros((steps_h, steps_w, 3), device=x.device, dtype=x.dtype) - img_ids[:, :, 0] = img_ids[:, :, 1] + index - img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=steps_h, device=x.device, dtype=x.dtype).unsqueeze(1) - img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=steps_w, device=x.device, dtype=x.dtype).unsqueeze(0) - return img, repeat(img_ids, "h w c -> b (h w) c", b=bs) def forward(self, x, timestep, context, y=None, guidance=None, ref_latents=None, control=None, transformer_options={}, **kwargs): return comfy.patcher_extension.WrapperExecutor.new_class_executor( @@ -253,7 +224,7 @@ def _forward(self, x, timestep, context, y=None, guidance=None, ref_latents=None h_len = ((h_orig + (patch_size // 2)) // patch_size) w_len = ((w_orig + (patch_size // 2)) // patch_size) - img, img_ids = self.process_img(x, transformer_options=transformer_options) + img, img_ids = comfy.ldm.common_dit.process_img(x, patch_size=patch_size, transformer_options=transformer_options) img_tokens = img.shape[1] if ref_latents is not None: h = 0 @@ -282,7 +253,7 @@ def _forward(self, x, timestep, context, y=None, guidance=None, ref_latents=None h = max(h, ref.shape[-2] + h_offset) w = max(w, ref.shape[-1] + w_offset) - kontext, kontext_ids = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset) + kontext, kontext_ids = comfy.ldm.common_dit.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset, patch_size=patch_size) img = torch.cat([img, kontext], dim=1) img_ids = torch.cat([img_ids, kontext_ids], dim=1)