Skip to content

Commit 5c8e986

Browse files
authored
convert nodes_tomesd.py to V3 schema (comfyanonymous#10180)
1 parent 8c26d7b commit 5c8e986

File tree

1 file changed

+32
-18
lines changed

1 file changed

+32
-18
lines changed

comfy_extras/nodes_tomesd.py

Lines changed: 32 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
#Taken from: https://github.com/dbolya/tomesd
22

33
import torch
4-
from typing import Tuple, Callable
4+
from typing import Tuple, Callable, Optional
5+
from typing_extensions import override
6+
from comfy_api.latest import ComfyExtension, io
57
import math
68

79
def do_nothing(x: torch.Tensor, mode:str=None):
@@ -144,33 +146,45 @@ def get_functions(x, ratio, original_shape):
144146

145147

146148

147-
class TomePatchModel:
149+
class TomePatchModel(io.ComfyNode):
148150
@classmethod
149-
def INPUT_TYPES(s):
150-
return {"required": { "model": ("MODEL",),
151-
"ratio": ("FLOAT", {"default": 0.3, "min": 0.0, "max": 1.0, "step": 0.01}),
152-
}}
153-
RETURN_TYPES = ("MODEL",)
154-
FUNCTION = "patch"
151+
def define_schema(cls):
152+
return io.Schema(
153+
node_id="TomePatchModel",
154+
category="model_patches/unet",
155+
inputs=[
156+
io.Model.Input("model"),
157+
io.Float.Input("ratio", default=0.3, min=0.0, max=1.0, step=0.01),
158+
],
159+
outputs=[io.Model.Output()],
160+
)
155161

156-
CATEGORY = "model_patches/unet"
157-
158-
def patch(self, model, ratio):
159-
self.u = None
162+
@classmethod
163+
def execute(cls, model, ratio) -> io.NodeOutput:
164+
u: Optional[Callable] = None
160165
def tomesd_m(q, k, v, extra_options):
166+
nonlocal u
161167
#NOTE: In the reference code get_functions takes x (input of the transformer block) as the argument instead of q
162168
#however from my basic testing it seems that using q instead gives better results
163-
m, self.u = get_functions(q, ratio, extra_options["original_shape"])
169+
m, u = get_functions(q, ratio, extra_options["original_shape"])
164170
return m(q), k, v
165171
def tomesd_u(n, extra_options):
166-
return self.u(n)
172+
nonlocal u
173+
return u(n)
167174

168175
m = model.clone()
169176
m.set_model_attn1_patch(tomesd_m)
170177
m.set_model_attn1_output_patch(tomesd_u)
171-
return (m, )
178+
return io.NodeOutput(m)
179+
180+
181+
class TomePatchModelExtension(ComfyExtension):
182+
@override
183+
async def get_node_list(self) -> list[type[io.ComfyNode]]:
184+
return [
185+
TomePatchModel,
186+
]
172187

173188

174-
NODE_CLASS_MAPPINGS = {
175-
"TomePatchModel": TomePatchModel,
176-
}
189+
async def comfy_entrypoint() -> TomePatchModelExtension:
190+
return TomePatchModelExtension()

0 commit comments

Comments
 (0)