Skip to content

Commit ff66d8a

Browse files
committed
nodes_custom: Lazify the model
These little model reader bits need to lazily evaluate the model to allow the model to be freed incase there is RAM pressure.
1 parent 626b749 commit ff66d8a

File tree

1 file changed

+50
-10
lines changed

1 file changed

+50
-10
lines changed

comfy_extras/nodes_custom_sampler.py

Lines changed: 50 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,18 @@ def define_schema(cls):
1818
node_id="BasicScheduler",
1919
category="sampling/custom_sampling/schedulers",
2020
inputs=[
21-
io.Model.Input("model"),
21+
io.Model.Input("model", lazy=True),
2222
io.Combo.Input("scheduler", options=comfy.samplers.SCHEDULER_NAMES),
2323
io.Int.Input("steps", default=20, min=1, max=10000),
2424
io.Float.Input("denoise", default=1.0, min=0.0, max=1.0, step=0.01),
2525
],
2626
outputs=[io.Sigmas.Output()]
2727
)
2828

29+
@classmethod
30+
def check_lazy_status(self, *args, **kwargs):
31+
return ["model"]
32+
2933
@classmethod
3034
def execute(cls, model, scheduler, steps, denoise) -> io.NodeOutput:
3135
total_steps = steps
@@ -137,13 +141,17 @@ def define_schema(cls):
137141
node_id="SDTurboScheduler",
138142
category="sampling/custom_sampling/schedulers",
139143
inputs=[
140-
io.Model.Input("model"),
144+
io.Model.Input("model", lazy=True),
141145
io.Int.Input("steps", default=1, min=1, max=10),
142146
io.Float.Input("denoise", default=1.0, min=0, max=1.0, step=0.01),
143147
],
144148
outputs=[io.Sigmas.Output()]
145149
)
146150

151+
@classmethod
152+
def check_lazy_status(self, *args, **kwargs):
153+
return ["model"]
154+
147155
@classmethod
148156
def execute(cls, model, steps, denoise) -> io.NodeOutput:
149157
start_step = 10 - int(10 * denoise)
@@ -161,14 +169,18 @@ def define_schema(cls):
161169
node_id="BetaSamplingScheduler",
162170
category="sampling/custom_sampling/schedulers",
163171
inputs=[
164-
io.Model.Input("model"),
172+
io.Model.Input("model", lazy=True),
165173
io.Int.Input("steps", default=20, min=1, max=10000),
166174
io.Float.Input("alpha", default=0.6, min=0.0, max=50.0, step=0.01, round=False),
167175
io.Float.Input("beta", default=0.6, min=0.0, max=50.0, step=0.01, round=False),
168176
],
169177
outputs=[io.Sigmas.Output()]
170178
)
171179

180+
@classmethod
181+
def check_lazy_status(self, *args, **kwargs):
182+
return ["model"]
183+
172184
@classmethod
173185
def execute(cls, model, steps, alpha, beta) -> io.NodeOutput:
174186
sigmas = comfy.samplers.beta_scheduler(model.get_model_object("model_sampling"), steps, alpha=alpha, beta=beta)
@@ -351,13 +363,17 @@ def define_schema(cls):
351363
node_id="SamplingPercentToSigma",
352364
category="sampling/custom_sampling/sigmas",
353365
inputs=[
354-
io.Model.Input("model"),
366+
io.Model.Input("model", lazy=True),
355367
io.Float.Input("sampling_percent", default=0.0, min=0.0, max=1.0, step=0.0001),
356368
io.Boolean.Input("return_actual_sigma", default=False, tooltip="Return the actual sigma value instead of the value used for interval checks.\nThis only affects results at 0.0 and 1.0."),
357369
],
358370
outputs=[io.Float.Output(display_name="sigma_value")]
359371
)
360372

373+
@classmethod
374+
def check_lazy_status(self, *args, **kwargs):
375+
return ["model"]
376+
361377
@classmethod
362378
def execute(cls, model, sampling_percent, return_actual_sigma) -> io.NodeOutput:
363379
model_sampling = model.get_model_object("model_sampling")
@@ -622,7 +638,7 @@ def define_schema(cls):
622638
node_id="SamplerSASolver",
623639
category="sampling/custom_sampling/samplers",
624640
inputs=[
625-
io.Model.Input("model"),
641+
io.Model.Input("model", lazy=True),
626642
io.Float.Input("eta", default=1.0, min=0.0, max=10.0, step=0.01, round=False),
627643
io.Float.Input("sde_start_percent", default=0.2, min=0.0, max=1.0, step=0.001),
628644
io.Float.Input("sde_end_percent", default=0.8, min=0.0, max=1.0, step=0.001),
@@ -635,6 +651,10 @@ def define_schema(cls):
635651
outputs=[io.Sampler.Output()]
636652
)
637653

654+
@classmethod
655+
def check_lazy_status(self, *args, **kwargs):
656+
return ["model"]
657+
638658
@classmethod
639659
def execute(cls, model, eta, sde_start_percent, sde_end_percent, s_noise, predictor_order, corrector_order, use_pece, simple_order_2) -> io.NodeOutput:
640660
model_sampling = model.get_model_object("model_sampling")
@@ -684,7 +704,7 @@ def define_schema(cls):
684704
node_id="SamplerCustom",
685705
category="sampling/custom_sampling",
686706
inputs=[
687-
io.Model.Input("model"),
707+
io.Model.Input("model", lazy=True),
688708
io.Boolean.Input("add_noise", default=True),
689709
io.Int.Input("noise_seed", default=0, min=0, max=0xffffffffffffffff, control_after_generate=True),
690710
io.Float.Input("cfg", default=8.0, min=0.0, max=100.0, step=0.1, round=0.01),
@@ -700,6 +720,10 @@ def define_schema(cls):
700720
]
701721
)
702722

723+
@classmethod
724+
def check_lazy_status(self, *args, **kwargs):
725+
return ["model"]
726+
703727
@classmethod
704728
def execute(cls, model, add_noise, noise_seed, cfg, positive, negative, sampler, sigmas, latent_image) -> io.NodeOutput:
705729
latent = latent_image
@@ -745,12 +769,16 @@ def define_schema(cls):
745769
node_id="BasicGuider",
746770
category="sampling/custom_sampling/guiders",
747771
inputs=[
748-
io.Model.Input("model"),
772+
io.Model.Input("model", lazy=True),
749773
io.Conditioning.Input("conditioning"),
750774
],
751775
outputs=[io.Guider.Output()]
752776
)
753777

778+
@classmethod
779+
def check_lazy_status(self, *args, **kwargs):
780+
return ["model"]
781+
754782
@classmethod
755783
def execute(cls, model, conditioning) -> io.NodeOutput:
756784
guider = Guider_Basic(model)
@@ -766,14 +794,18 @@ def define_schema(cls):
766794
node_id="CFGGuider",
767795
category="sampling/custom_sampling/guiders",
768796
inputs=[
769-
io.Model.Input("model"),
797+
io.Model.Input("model", lazy=True),
770798
io.Conditioning.Input("positive"),
771799
io.Conditioning.Input("negative"),
772800
io.Float.Input("cfg", default=8.0, min=0.0, max=100.0, step=0.1, round=0.01),
773801
],
774802
outputs=[io.Guider.Output()]
775803
)
776804

805+
@classmethod
806+
def check_lazy_status(self, *args, **kwargs):
807+
return ["model"]
808+
777809
@classmethod
778810
def execute(cls, model, positive, negative, cfg) -> io.NodeOutput:
779811
guider = comfy.samplers.CFGGuider(model)
@@ -819,7 +851,7 @@ def define_schema(cls):
819851
node_id="DualCFGGuider",
820852
category="sampling/custom_sampling/guiders",
821853
inputs=[
822-
io.Model.Input("model"),
854+
io.Model.Input("model", lazy=True),
823855
io.Conditioning.Input("cond1"),
824856
io.Conditioning.Input("cond2"),
825857
io.Conditioning.Input("negative"),
@@ -830,6 +862,10 @@ def define_schema(cls):
830862
outputs=[io.Guider.Output()]
831863
)
832864

865+
@classmethod
866+
def check_lazy_status(self, *args, **kwargs):
867+
return ["model"]
868+
833869
@classmethod
834870
def execute(cls, model, cond1, cond2, negative, cfg_conds, cfg_cond2_negative, style) -> io.NodeOutput:
835871
guider = Guider_DualCFG(model)
@@ -930,7 +966,7 @@ def define_schema(cls):
930966
category="_for_testing/custom_sampling/noise",
931967
is_experimental=True,
932968
inputs=[
933-
io.Model.Input("model"),
969+
io.Model.Input("model", lazy=True),
934970
io.Noise.Input("noise"),
935971
io.Sigmas.Input("sigmas"),
936972
io.Latent.Input("latent_image"),
@@ -940,6 +976,10 @@ def define_schema(cls):
940976
]
941977
)
942978

979+
@classmethod
980+
def check_lazy_status(self, *args, **kwargs):
981+
return ["model"]
982+
943983
@classmethod
944984
def execute(cls, model, noise, sigmas, latent_image) -> io.NodeOutput:
945985
if len(sigmas) == 0:

0 commit comments

Comments
 (0)