2222from comfy .cli_args import args , PerformanceFeature
2323import comfy .float
2424import comfy .rmsnorm
25- import contextlib
2625
2726def run_every_op ():
2827 if torch .compiler .is_compiling ():
@@ -71,6 +70,93 @@ def cast_to_input(weight, input, non_blocking=False, copy=True):
7170 return comfy .model_management .cast_to (weight , input .dtype , input .device , non_blocking = non_blocking , copy = copy )
7271
7372
73+ def cast_prefetch_all (module , device ):
74+ if not comfy .model_management .device_supports_non_blocking (device ):
75+ #Adios! prefetching works against you if you can't get the CPU past it
76+ return None
77+
78+ offload_stream = None
79+
80+ for n , m in module .named_modules ():
81+ if hasattr (m , "comfy_cast_weights" ):
82+ if m .weight is not None and m .weight .device != device and not hasattr (m , "weight_prefetch" ):
83+ if offload_stream is None :
84+ offload_stream = comfy .model_management .get_offload_stream (device )
85+ if offload_stream is None :
86+ return None
87+ m .weight_prefetch = comfy .model_management .cast_to (m .weight , None , device , non_blocking = True , copy = True , stream = offload_stream )
88+ if m .bias is not None and m .bias .device != device and not hasattr (m , "bias_prefetch" ):
89+ if offload_stream is None :
90+ offload_stream = comfy .model_management .get_offload_stream (device )
91+ if offload_stream is None :
92+ return None
93+ m .bias_prefetch = comfy .model_management .cast_to (m .bias , None , device , non_blocking = True , copy = True , stream = offload_stream )
94+
95+ return offload_stream
96+
97+
98+ def uncast_prefetch_all (module ):
99+ for n , m in module .named_modules ():
100+ if hasattr (m , "comfy_cast_weights" ):
101+ if hasattr (m , "weight_prefetch" ):
102+ delattr (m , "weight_prefetch" )
103+ if hasattr (m , "bias_prefetch" ):
104+ delattr (m , "bias_prefetch" )
105+
106+
107+ def prefetch_queue_pop (queue , device , module ):
108+ consumed = queue .pop (0 )
109+ if consumed is not None :
110+ offload_stream , m = consumed
111+ #Sync the offload stream with compute so when it starts
112+ #freeing the prefetches the compute stream has finished
113+ if offload_stream is not None :
114+ offload_stream .wait_stream (comfy .model_management .current_stream (device ))
115+ uncast_prefetch_all (m )
116+
117+ active = queue [0 ]
118+ if active is not None :
119+ offload_stream , m = active
120+ assert m == module
121+ #wait for the prefetch to complete before using the data
122+ if offload_stream is not None :
123+ comfy .model_management .sync_stream (device , offload_stream )
124+
125+ prefetch = queue [1 ]
126+ if prefetch is not None :
127+ offload_stream = comfy .ops .cast_prefetch_all (prefetch , device )
128+ queue [1 ] = (offload_stream , prefetch )
129+
130+
131+ def make_prefetch_queue (queue ):
132+ return [None , None ] + queue + [None , None ]
133+
134+
135+ def move_bias_weight (s , device , offloadable = False ):
136+
137+ bias_has_function = len (s .bias_function ) > 0
138+ weight_has_function = len (s .weight_function ) > 0
139+
140+ if offloadable and (
141+ s .weight .device != device or (s .bias is not None and s .bias .device != device ) or
142+ bias_has_function or weight_has_function ):
143+ offload_stream = comfy .model_management .get_offload_stream (device )
144+ else :
145+ offload_stream = None
146+
147+ bias = None
148+ non_blocking = comfy .model_management .device_supports_non_blocking (device )
149+
150+ weight = comfy .model_management .cast_to (s .weight , None , device , non_blocking = non_blocking , copy = weight_has_function , stream = offload_stream )
151+
152+ if s .bias is not None :
153+ bias = comfy .model_management .cast_to (s .bias , None , device , non_blocking = non_blocking , copy = bias_has_function , stream = offload_stream )
154+
155+ comfy .model_management .sync_stream (device , offload_stream )
156+
157+ return weight , bias , offload_stream
158+
159+
74160def cast_bias_weight (s , input = None , dtype = None , device = None , bias_dtype = None , offloadable = False ):
75161 # NOTE: offloadable=False is a a legacy and if you are a custom node author reading this please pass
76162 # offloadable=True and call uncast_bias_weight() after your last usage of the weight/bias. This
@@ -83,40 +169,30 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
83169 if device is None :
84170 device = input .device
85171
86- if offloadable and (device != s .weight .device or
87- (s .bias is not None and device != s .bias .device )):
88- offload_stream = comfy .model_management .get_offload_stream (device )
89- else :
90- offload_stream = None
91-
92- if offload_stream is not None :
93- wf_context = offload_stream
94- else :
95- wf_context = contextlib .nullcontext ()
96-
97- non_blocking = comfy .model_management .device_supports_non_blocking (device )
98-
99- weight_has_function = len (s .weight_function ) > 0
100172 bias_has_function = len (s .bias_function ) > 0
173+ weight_has_function = len (s .weight_function ) > 0
101174
102- weight = comfy .model_management .cast_to (s .weight , None , device , non_blocking = non_blocking , copy = weight_has_function , stream = offload_stream )
175+ if hasattr (s , "weight_prefetch" ) or hasattr (s , "bias_prefetch" ):
176+ weight = getattr (s , "weight_prefetch" , None )
177+ bias = getattr (s , "bias_prefetch" , None )
178+ offload_stream = None
179+ else :
180+ weight , bias , offload_stream = move_bias_weight (s , device , offloadable = offloadable )
103181
104- bias = None
105- if s .bias is not None :
106- bias = comfy .model_management .cast_to (s .bias , bias_dtype , device , non_blocking = non_blocking , copy = bias_has_function , stream = offload_stream )
182+ if weight_has_function :
183+ weight = weight .to (dtype = dtype )
184+ for f in s .weight_function :
185+ weight = f (weight )
107186
108- if bias_has_function :
109- with wf_context :
110- for f in s .bias_function :
111- bias = f (bias )
187+ if s . bias is not None and bias_has_function :
188+ bias = bias . to ( dtype = bias_dtype )
189+ for f in s .bias_function :
190+ bias = f (bias )
112191
113- if weight_has_function or weight .dtype != dtype :
114- with wf_context :
115- weight = weight .to (dtype = dtype )
116- for f in s .weight_function :
117- weight = f (weight )
192+ weight = weight .to (dtype = dtype )
193+ if bias is not None :
194+ bias = bias .to (dtype = bias_dtype )
118195
119- comfy .model_management .sync_stream (device , offload_stream )
120196 if offloadable :
121197 return weight , bias , offload_stream
122198 else :
0 commit comments