Skip to content

Commit 0814c1f

Browse files
committed
qwen: Implement transformer block prefetching
1 parent e279e1f commit 0814c1f

File tree

1 file changed

+6
-0
lines changed

1 file changed

+6
-0
lines changed

comfy/ldm/qwen_image/model.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -387,6 +387,9 @@ def _forward(
387387
hidden_states, img_ids, orig_shape = self.process_img(x)
388388
num_embeds = hidden_states.shape[1]
389389

390+
prefetch_queue = comfy.ops.make_prefetch_queue(list(self.transformer_blocks))
391+
comfy.ops.prefetch_queue_pop(prefetch_queue, x.device, None)
392+
390393
if ref_latents is not None:
391394
h = 0
392395
w = 0
@@ -436,6 +439,7 @@ def _forward(
436439
blocks_replace = patches_replace.get("dit", {})
437440

438441
for i, block in enumerate(self.transformer_blocks):
442+
comfy.ops.prefetch_queue_pop(prefetch_queue, x.device, block)
439443
if ("double_block", i) in blocks_replace:
440444
def block_wrap(args):
441445
out = {}
@@ -467,6 +471,8 @@ def block_wrap(args):
467471
if add is not None:
468472
hidden_states[:, :add.shape[1]] += add
469473

474+
comfy.ops.prefetch_queue_pop(prefetch_queue, x.device, block)
475+
470476
hidden_states = self.norm_out(hidden_states, temb)
471477
hidden_states = self.proj_out(hidden_states)
472478

0 commit comments

Comments
 (0)