|
1 | 1 | #Taken from: https://github.com/dbolya/tomesd
|
2 | 2 |
|
3 | 3 | import torch
|
4 |
| -from typing import Tuple, Callable |
| 4 | +from typing import Tuple, Callable, Optional |
| 5 | +from typing_extensions import override |
| 6 | +from comfy_api.latest import ComfyExtension, io |
5 | 7 | import math
|
6 | 8 |
|
7 | 9 | def do_nothing(x: torch.Tensor, mode:str=None):
|
@@ -144,33 +146,45 @@ def get_functions(x, ratio, original_shape):
|
144 | 146 |
|
145 | 147 |
|
146 | 148 |
|
147 |
| -class TomePatchModel: |
| 149 | +class TomePatchModel(io.ComfyNode): |
148 | 150 | @classmethod
|
149 |
| - def INPUT_TYPES(s): |
150 |
| - return {"required": { "model": ("MODEL",), |
151 |
| - "ratio": ("FLOAT", {"default": 0.3, "min": 0.0, "max": 1.0, "step": 0.01}), |
152 |
| - }} |
153 |
| - RETURN_TYPES = ("MODEL",) |
154 |
| - FUNCTION = "patch" |
| 151 | + def define_schema(cls): |
| 152 | + return io.Schema( |
| 153 | + node_id="TomePatchModel", |
| 154 | + category="model_patches/unet", |
| 155 | + inputs=[ |
| 156 | + io.Model.Input("model"), |
| 157 | + io.Float.Input("ratio", default=0.3, min=0.0, max=1.0, step=0.01), |
| 158 | + ], |
| 159 | + outputs=[io.Model.Output()], |
| 160 | + ) |
155 | 161 |
|
156 |
| - CATEGORY = "model_patches/unet" |
157 |
| - |
158 |
| - def patch(self, model, ratio): |
159 |
| - self.u = None |
| 162 | + @classmethod |
| 163 | + def execute(cls, model, ratio) -> io.NodeOutput: |
| 164 | + u: Optional[Callable] = None |
160 | 165 | def tomesd_m(q, k, v, extra_options):
|
| 166 | + nonlocal u |
161 | 167 | #NOTE: In the reference code get_functions takes x (input of the transformer block) as the argument instead of q
|
162 | 168 | #however from my basic testing it seems that using q instead gives better results
|
163 |
| - m, self.u = get_functions(q, ratio, extra_options["original_shape"]) |
| 169 | + m, u = get_functions(q, ratio, extra_options["original_shape"]) |
164 | 170 | return m(q), k, v
|
165 | 171 | def tomesd_u(n, extra_options):
|
166 |
| - return self.u(n) |
| 172 | + nonlocal u |
| 173 | + return u(n) |
167 | 174 |
|
168 | 175 | m = model.clone()
|
169 | 176 | m.set_model_attn1_patch(tomesd_m)
|
170 | 177 | m.set_model_attn1_output_patch(tomesd_u)
|
171 |
| - return (m, ) |
| 178 | + return io.NodeOutput(m) |
| 179 | + |
| 180 | + |
| 181 | +class TomePatchModelExtension(ComfyExtension): |
| 182 | + @override |
| 183 | + async def get_node_list(self) -> list[type[io.ComfyNode]]: |
| 184 | + return [ |
| 185 | + TomePatchModel, |
| 186 | + ] |
172 | 187 |
|
173 | 188 |
|
174 |
| -NODE_CLASS_MAPPINGS = { |
175 |
| - "TomePatchModel": TomePatchModel, |
176 |
| -} |
| 189 | +async def comfy_entrypoint() -> TomePatchModelExtension: |
| 190 | + return TomePatchModelExtension() |
0 commit comments