File tree Expand file tree Collapse file tree 1 file changed +6
-0
lines changed Expand file tree Collapse file tree 1 file changed +6
-0
lines changed Original file line number Diff line number Diff line change @@ -387,6 +387,9 @@ def _forward(
387387 hidden_states , img_ids , orig_shape = self .process_img (x )
388388 num_embeds = hidden_states .shape [1 ]
389389
390+ prefetch_queue = comfy .ops .make_prefetch_queue (list (self .transformer_blocks ))
391+ comfy .ops .prefetch_queue_pop (prefetch_queue , x .device , None )
392+
390393 if ref_latents is not None :
391394 h = 0
392395 w = 0
@@ -436,6 +439,7 @@ def _forward(
436439 blocks_replace = patches_replace .get ("dit" , {})
437440
438441 for i , block in enumerate (self .transformer_blocks ):
442+ comfy .ops .prefetch_queue_pop (prefetch_queue , x .device , block )
439443 if ("double_block" , i ) in blocks_replace :
440444 def block_wrap (args ):
441445 out = {}
@@ -467,6 +471,8 @@ def block_wrap(args):
467471 if add is not None :
468472 hidden_states [:, :add .shape [1 ]] += add
469473
474+ comfy .ops .prefetch_queue_pop (prefetch_queue , x .device , block )
475+
470476 hidden_states = self .norm_out (hidden_states , temb )
471477 hidden_states = self .proj_out (hidden_states )
472478
You can’t perform that action at this time.
0 commit comments