Commit e279e1f

ops: Implement prefetching API
Implement an API that allows instrumenting a model with a prefetch queue. Units of work are at the nn.Module level.
1 parent c350009 commit e279e1f
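
A rough usage sketch of the new queue API, inferred from the functions added in this diff. The driver function, `blocks`, and `x` below are hypothetical placeholders and are not part of this commit; the commit only adds the queue primitives, not their call sites.

    import comfy.ops
    import comfy.model_management

    def run_with_prefetch(blocks, x):
        # `blocks`: hypothetical ordered list of nn.Module work units;
        # `x`: the activation passed between them.
        device = comfy.model_management.get_torch_device()
        queue = comfy.ops.make_prefetch_queue(list(blocks))

        # Warm-up pop: nothing is active yet, this only starts the prefetch of blocks[0].
        comfy.ops.prefetch_queue_pop(queue, device, None)

        for block in blocks:
            # Frees the previous block's prefetched weights, waits for this block's
            # prefetch to land, then starts prefetching the next block.
            comfy.ops.prefetch_queue_pop(queue, device, block)
            x = block(x)

        # Drain pop: releases the prefetched weights of the last block.
        comfy.ops.prefetch_queue_pop(queue, device, None)
        return x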


comfy/ops.py

Lines changed: 105 additions & 29 deletions
@@ -22,7 +22,6 @@
 from comfy.cli_args import args, PerformanceFeature
 import comfy.float
 import comfy.rmsnorm
-import contextlib
 
 def run_every_op():
     if torch.compiler.is_compiling():
@@ -71,6 +70,93 @@ def cast_to_input(weight, input, non_blocking=False, copy=True):
     return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)
 
 
+def cast_prefetch_all(module, device):
+    if not comfy.model_management.device_supports_non_blocking(device):
+        # Adios! prefetching works against you if you can't get the CPU past it
+        return None
+
+    offload_stream = None
+
+    for n, m in module.named_modules():
+        if hasattr(m, "comfy_cast_weights"):
+            if m.weight is not None and m.weight.device != device and not hasattr(m, "weight_prefetch"):
+                if offload_stream is None:
+                    offload_stream = comfy.model_management.get_offload_stream(device)
+                    if offload_stream is None:
+                        return None
+                m.weight_prefetch = comfy.model_management.cast_to(m.weight, None, device, non_blocking=True, copy=True, stream=offload_stream)
+            if m.bias is not None and m.bias.device != device and not hasattr(m, "bias_prefetch"):
+                if offload_stream is None:
+                    offload_stream = comfy.model_management.get_offload_stream(device)
+                    if offload_stream is None:
+                        return None
+                m.bias_prefetch = comfy.model_management.cast_to(m.bias, None, device, non_blocking=True, copy=True, stream=offload_stream)
+
+    return offload_stream
+
+
+def uncast_prefetch_all(module):
+    for n, m in module.named_modules():
+        if hasattr(m, "comfy_cast_weights"):
+            if hasattr(m, "weight_prefetch"):
+                delattr(m, "weight_prefetch")
+            if hasattr(m, "bias_prefetch"):
+                delattr(m, "bias_prefetch")
+
+
+def prefetch_queue_pop(queue, device, module):
+    consumed = queue.pop(0)
+    if consumed is not None:
+        offload_stream, m = consumed
+        # Sync the offload stream with compute so that when it starts
+        # freeing the prefetches the compute stream has finished
+        if offload_stream is not None:
+            offload_stream.wait_stream(comfy.model_management.current_stream(device))
+        uncast_prefetch_all(m)
+
+    active = queue[0]
+    if active is not None:
+        offload_stream, m = active
+        assert m == module
+        # wait for the prefetch to complete before using the data
+        if offload_stream is not None:
+            comfy.model_management.sync_stream(device, offload_stream)
+
+    prefetch = queue[1]
+    if prefetch is not None:
+        offload_stream = comfy.ops.cast_prefetch_all(prefetch, device)
+        queue[1] = (offload_stream, prefetch)
+
+
+def make_prefetch_queue(queue):
+    return [None, None] + queue + [None, None]
+
+
+def move_bias_weight(s, device, offloadable=False):
+
+    bias_has_function = len(s.bias_function) > 0
+    weight_has_function = len(s.weight_function) > 0
+
+    if offloadable and (
+            s.weight.device != device or (s.bias is not None and s.bias.device != device) or
+            bias_has_function or weight_has_function):
+        offload_stream = comfy.model_management.get_offload_stream(device)
+    else:
+        offload_stream = None
+
+    bias = None
+    non_blocking = comfy.model_management.device_supports_non_blocking(device)
+
+    weight = comfy.model_management.cast_to(s.weight, None, device, non_blocking=non_blocking, copy=weight_has_function, stream=offload_stream)
+
+    if s.bias is not None:
+        bias = comfy.model_management.cast_to(s.bias, None, device, non_blocking=non_blocking, copy=bias_has_function, stream=offload_stream)
+
+    comfy.model_management.sync_stream(device, offload_stream)
+
+    return weight, bias, offload_stream
+
+
 def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False):
     # NOTE: offloadable=False is a legacy and if you are a custom node author reading this please pass
     # offloadable=True and call uncast_bias_weight() after your last usage of the weight/bias. This
@@ -83,40 +169,30 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False):
     if device is None:
         device = input.device
 
-    if offloadable and (device != s.weight.device or
-            (s.bias is not None and device != s.bias.device)):
-        offload_stream = comfy.model_management.get_offload_stream(device)
-    else:
-        offload_stream = None
-
-    if offload_stream is not None:
-        wf_context = offload_stream
-    else:
-        wf_context = contextlib.nullcontext()
-
-    non_blocking = comfy.model_management.device_supports_non_blocking(device)
-
-    weight_has_function = len(s.weight_function) > 0
     bias_has_function = len(s.bias_function) > 0
+    weight_has_function = len(s.weight_function) > 0
 
-    weight = comfy.model_management.cast_to(s.weight, None, device, non_blocking=non_blocking, copy=weight_has_function, stream=offload_stream)
+    if hasattr(s, "weight_prefetch") or hasattr(s, "bias_prefetch"):
+        weight = getattr(s, "weight_prefetch", None)
+        bias = getattr(s, "bias_prefetch", None)
+        offload_stream = None
+    else:
+        weight, bias, offload_stream = move_bias_weight(s, device, offloadable=offloadable)
 
-    bias = None
-    if s.bias is not None:
-        bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=bias_has_function, stream=offload_stream)
+    if weight_has_function:
+        weight = weight.to(dtype=dtype)
+        for f in s.weight_function:
+            weight = f(weight)
 
-    if bias_has_function:
-        with wf_context:
-            for f in s.bias_function:
-                bias = f(bias)
+    if s.bias is not None and bias_has_function:
+        bias = bias.to(dtype=bias_dtype)
+        for f in s.bias_function:
+            bias = f(bias)
 
-    if weight_has_function or weight.dtype != dtype:
-        with wf_context:
-            weight = weight.to(dtype=dtype)
-            for f in s.weight_function:
-                weight = f(weight)
+    weight = weight.to(dtype=dtype)
+    if bias is not None:
+        bias = bias.to(dtype=bias_dtype)
 
-    comfy.model_management.sync_stream(device, offload_stream)
     if offloadable:
         return weight, bias, offload_stream
     else: