Skip to content

Commit 2e843f3

Browse files
committed
wan: Implement block level prefetching
1 parent 0814c1f commit 2e843f3

File tree

1 file changed

+4
-0
lines changed

1 file changed

+4
-0
lines changed

comfy/ldm/wan/model.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -538,6 +538,8 @@ def forward_orig(
538538
List[Tensor]:
539539
List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8]
540540
"""
541+
prefetch_queue = comfy.ops.make_prefetch_queue(list(self.blocks))
542+
comfy.ops.prefetch_queue_pop(prefetch_queue, x.device, None)
541543
# embeddings
542544
x = self.patch_embedding(x.float()).to(x.dtype)
543545
grid_sizes = x.shape[2:]
@@ -569,6 +571,7 @@ def forward_orig(
569571
patches_replace = transformer_options.get("patches_replace", {})
570572
blocks_replace = patches_replace.get("dit", {})
571573
for i, block in enumerate(self.blocks):
574+
comfy.ops.prefetch_queue_pop(prefetch_queue, x.device, block)
572575
if ("double_block", i) in blocks_replace:
573576
def block_wrap(args):
574577
out = {}
@@ -578,6 +581,7 @@ def block_wrap(args):
578581
x = out["img"]
579582
else:
580583
x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len, transformer_options=transformer_options)
584+
comfy.ops.prefetch_queue_pop(prefetch_queue, x.device, block)
581585

582586
# head
583587
x = self.head(x, e)

0 commit comments

Comments
 (0)