Skip to content

Commit 647d88a

Browse files
committed
nodes_custom: Lazify the model
These little model reader bits need to lazily evaluate the model to allow the model to be freed incase there is RAM pressure.
1 parent 46581e9 commit 647d88a

File tree

1 file changed

+40
-10
lines changed

1 file changed

+40
-10
lines changed

comfy_extras/nodes_custom_sampler.py

Lines changed: 40 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,17 @@ def define_schema(cls):
1818
node_id="BasicScheduler",
1919
category="sampling/custom_sampling/schedulers",
2020
inputs=[
21-
io.Model.Input("model"),
21+
io.Model.Input("model", lazy=True),
2222
io.Combo.Input("scheduler", options=comfy.samplers.SCHEDULER_NAMES),
2323
io.Int.Input("steps", default=20, min=1, max=10000),
2424
io.Float.Input("denoise", default=1.0, min=0.0, max=1.0, step=0.01),
2525
],
2626
outputs=[io.Sigmas.Output()]
2727
)
2828

29+
def check_lazy_status(self, *args, **kwargs):
30+
return ["model"]
31+
2932
@classmethod
3033
def execute(cls, model, scheduler, steps, denoise) -> io.NodeOutput:
3134
total_steps = steps
@@ -137,13 +140,16 @@ def define_schema(cls):
137140
node_id="SDTurboScheduler",
138141
category="sampling/custom_sampling/schedulers",
139142
inputs=[
140-
io.Model.Input("model"),
143+
io.Model.Input("model", lazy=True),
141144
io.Int.Input("steps", default=1, min=1, max=10),
142145
io.Float.Input("denoise", default=1.0, min=0, max=1.0, step=0.01),
143146
],
144147
outputs=[io.Sigmas.Output()]
145148
)
146149

150+
def check_lazy_status(self, *args, **kwargs):
151+
return ["model"]
152+
147153
@classmethod
148154
def execute(cls, model, steps, denoise) -> io.NodeOutput:
149155
start_step = 10 - int(10 * denoise)
@@ -161,14 +167,17 @@ def define_schema(cls):
161167
node_id="BetaSamplingScheduler",
162168
category="sampling/custom_sampling/schedulers",
163169
inputs=[
164-
io.Model.Input("model"),
170+
io.Model.Input("model", lazy=True),
165171
io.Int.Input("steps", default=20, min=1, max=10000),
166172
io.Float.Input("alpha", default=0.6, min=0.0, max=50.0, step=0.01, round=False),
167173
io.Float.Input("beta", default=0.6, min=0.0, max=50.0, step=0.01, round=False),
168174
],
169175
outputs=[io.Sigmas.Output()]
170176
)
171177

178+
def check_lazy_status(self, *args, **kwargs):
179+
return ["model"]
180+
172181
@classmethod
173182
def execute(cls, model, steps, alpha, beta) -> io.NodeOutput:
174183
sigmas = comfy.samplers.beta_scheduler(model.get_model_object("model_sampling"), steps, alpha=alpha, beta=beta)
@@ -351,13 +360,16 @@ def define_schema(cls):
351360
node_id="SamplingPercentToSigma",
352361
category="sampling/custom_sampling/sigmas",
353362
inputs=[
354-
io.Model.Input("model"),
363+
io.Model.Input("model", lazy=True),
355364
io.Float.Input("sampling_percent", default=0.0, min=0.0, max=1.0, step=0.0001),
356365
io.Boolean.Input("return_actual_sigma", default=False, tooltip="Return the actual sigma value instead of the value used for interval checks.\nThis only affects results at 0.0 and 1.0."),
357366
],
358367
outputs=[io.Float.Output(display_name="sigma_value")]
359368
)
360369

370+
def check_lazy_status(self, *args, **kwargs):
371+
return ["model"]
372+
361373
@classmethod
362374
def execute(cls, model, sampling_percent, return_actual_sigma) -> io.NodeOutput:
363375
model_sampling = model.get_model_object("model_sampling")
@@ -622,7 +634,7 @@ def define_schema(cls):
622634
node_id="SamplerSASolver",
623635
category="sampling/custom_sampling/samplers",
624636
inputs=[
625-
io.Model.Input("model"),
637+
io.Model.Input("model", lazy=True),
626638
io.Float.Input("eta", default=1.0, min=0.0, max=10.0, step=0.01, round=False),
627639
io.Float.Input("sde_start_percent", default=0.2, min=0.0, max=1.0, step=0.001),
628640
io.Float.Input("sde_end_percent", default=0.8, min=0.0, max=1.0, step=0.001),
@@ -635,6 +647,9 @@ def define_schema(cls):
635647
outputs=[io.Sampler.Output()]
636648
)
637649

650+
def check_lazy_status(self, *args, **kwargs):
651+
return ["model"]
652+
638653
@classmethod
639654
def execute(cls, model, eta, sde_start_percent, sde_end_percent, s_noise, predictor_order, corrector_order, use_pece, simple_order_2) -> io.NodeOutput:
640655
model_sampling = model.get_model_object("model_sampling")
@@ -684,7 +699,7 @@ def define_schema(cls):
684699
node_id="SamplerCustom",
685700
category="sampling/custom_sampling",
686701
inputs=[
687-
io.Model.Input("model"),
702+
io.Model.Input("model", lazy=True),
688703
io.Boolean.Input("add_noise", default=True),
689704
io.Int.Input("noise_seed", default=0, min=0, max=0xffffffffffffffff, control_after_generate=True),
690705
io.Float.Input("cfg", default=8.0, min=0.0, max=100.0, step=0.1, round=0.01),
@@ -700,6 +715,9 @@ def define_schema(cls):
700715
]
701716
)
702717

718+
def check_lazy_status(self, *args, **kwargs):
719+
return ["model"]
720+
703721
@classmethod
704722
def execute(cls, model, add_noise, noise_seed, cfg, positive, negative, sampler, sigmas, latent_image) -> io.NodeOutput:
705723
latent = latent_image
@@ -745,12 +763,15 @@ def define_schema(cls):
745763
node_id="BasicGuider",
746764
category="sampling/custom_sampling/guiders",
747765
inputs=[
748-
io.Model.Input("model"),
766+
io.Model.Input("model", lazy=True),
749767
io.Conditioning.Input("conditioning"),
750768
],
751769
outputs=[io.Guider.Output()]
752770
)
753771

772+
def check_lazy_status(self, *args, **kwargs):
773+
return ["model"]
774+
754775
@classmethod
755776
def execute(cls, model, conditioning) -> io.NodeOutput:
756777
guider = Guider_Basic(model)
@@ -766,14 +787,17 @@ def define_schema(cls):
766787
node_id="CFGGuider",
767788
category="sampling/custom_sampling/guiders",
768789
inputs=[
769-
io.Model.Input("model"),
790+
io.Model.Input("model", lazy=True),
770791
io.Conditioning.Input("positive"),
771792
io.Conditioning.Input("negative"),
772793
io.Float.Input("cfg", default=8.0, min=0.0, max=100.0, step=0.1, round=0.01),
773794
],
774795
outputs=[io.Guider.Output()]
775796
)
776797

798+
def check_lazy_status(self, *args, **kwargs):
799+
return ["model"]
800+
777801
@classmethod
778802
def execute(cls, model, positive, negative, cfg) -> io.NodeOutput:
779803
guider = comfy.samplers.CFGGuider(model)
@@ -819,7 +843,7 @@ def define_schema(cls):
819843
node_id="DualCFGGuider",
820844
category="sampling/custom_sampling/guiders",
821845
inputs=[
822-
io.Model.Input("model"),
846+
io.Model.Input("model", lazy=True),
823847
io.Conditioning.Input("cond1"),
824848
io.Conditioning.Input("cond2"),
825849
io.Conditioning.Input("negative"),
@@ -830,6 +854,9 @@ def define_schema(cls):
830854
outputs=[io.Guider.Output()]
831855
)
832856

857+
def check_lazy_status(self, *args, **kwargs):
858+
return ["model"]
859+
833860
@classmethod
834861
def execute(cls, model, cond1, cond2, negative, cfg_conds, cfg_cond2_negative, style) -> io.NodeOutput:
835862
guider = Guider_DualCFG(model)
@@ -930,7 +957,7 @@ def define_schema(cls):
930957
category="_for_testing/custom_sampling/noise",
931958
is_experimental=True,
932959
inputs=[
933-
io.Model.Input("model"),
960+
io.Model.Input("model", lazy=True),
934961
io.Noise.Input("noise"),
935962
io.Sigmas.Input("sigmas"),
936963
io.Latent.Input("latent_image"),
@@ -940,6 +967,9 @@ def define_schema(cls):
940967
]
941968
)
942969

970+
def check_lazy_status(self, *args, **kwargs):
971+
return ["model"]
972+
943973
@classmethod
944974
def execute(cls, model, noise, sigmas, latent_image) -> io.NodeOutput:
945975
if len(sigmas) == 0:

0 commit comments

Comments
 (0)