1 change: 1 addition & 0 deletions comfy/model_base.py
@@ -304,6 +304,7 @@ def load_model_weights(self, sd, unet_prefix=""):
to_load[k[len(unet_prefix):]] = sd.pop(k)

to_load = self.model_config.process_unet_state_dict(to_load)
comfy.model_management.free_ram(state_dict=to_load)
m, u = self.diffusion_model.load_state_dict(to_load, strict=False)
if len(m) > 0:
logging.warning("unet missing: {}".format(m))
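
Context for the new call above: `comfy.model_management.free_ram(state_dict=to_load)` sums the bytes of every tensor in the incoming state dict and asks the registered RAM listeners to make at least that much host memory available before `load_state_dict` materializes the weights. A minimal sketch of the size estimate, assuming plain `torch.Tensor` values (the helper name is illustrative, not part of the PR):

```python
import torch

def estimate_state_dict_bytes(state_dict):
    # Mirror the isinstance check in free_ram(): only tensors count,
    # any non-tensor metadata entries are skipped.
    total = 0
    for value in state_dict.values():
        if isinstance(value, torch.Tensor):
            total += value.numel() * value.element_size()
    return total

# Two 4x4 fp16 tensors -> 2 * 16 elements * 2 bytes = 64 bytes
sd = {"weight": torch.zeros(4, 4, dtype=torch.float16),
      "bias_like": torch.zeros(4, 4, dtype=torch.float16)}
print(estimate_state_dict_bytes(sd))  # 64
```
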
51 changes: 40 additions & 11 deletions comfy/model_management.py
@@ -445,6 +445,20 @@ def get_torch_device_name(device):
except:
logging.warning("Could not pick default device.")

current_ram_listeners = set()

def register_ram_listener(listener):
current_ram_listeners.add(listener)

def unregister_ram_listener(listener):
current_ram_listeners.discard(listener)

def free_ram(extra_ram=0, state_dict={}):
for tensor in state_dict.values():
if isinstance(tensor, torch.Tensor):
extra_ram += tensor.numel() * tensor.element_size()
for listener in current_ram_listeners:
listener.free_ram(extra_ram)

current_loaded_models = []

@@ -521,12 +535,18 @@ def should_reload_model(self, force_patch_weights=False):
return False

def model_unload(self, memory_to_free=None, unpatch_weights=True):
if self.model is None:
return True
logging.debug(f"Unloading {self.model.model.__class__.__name__}")
if memory_to_free is not None:
if memory_to_free < self.model.loaded_size():
freed = self.model.partially_unload(self.model.offload_device, memory_to_free)
freed, modules_to_offload = self.model.partially_unload(self.model.offload_device, memory_to_free)
offload_modules(modules_to_offload, self.model.offload_device)
if freed >= memory_to_free:
return False
self.model.detach(unpatch_weights)
if self.model is not None:
modules_to_offload = self.model.detach(unpatch_weights)
offload_modules(modules_to_offload, self.model.offload_device)
self.model_finalizer.detach()
self.model_finalizer = None
self.real_model = None
@@ -543,7 +563,7 @@ def __del__(self):
self._patcher_finalizer.detach()

def is_dead(self):
return self.real_model() is not None and self.model is None
return self.real_model is not None and self.real_model() is not None and self.model is None


def use_more_memory(extra_memory, loaded_models, device):
@@ -578,6 +598,13 @@ def extra_reserved_memory():
def minimum_inference_memory():
return (1024 * 1024 * 1024) * 0.8 + extra_reserved_memory()

def offload_modules(modules, offload_device):
for module in modules:
if module() is None:
continue
module().to(offload_device)
free_ram()

def free_memory(memory_required, device, keep_loaded=[]):
cleanup_models_gc()
unloaded_model = []
@@ -588,23 +615,25 @@ def free_memory(memory_required, device, keep_loaded=[]):
shift_model = current_loaded_models[i]
if shift_model.device == device:
if shift_model not in keep_loaded and not shift_model.is_dead():
can_unload.append((-shift_model.model_offloaded_memory(), sys.getrefcount(shift_model.model), shift_model.model_memory(), i))
can_unload.append((-shift_model.model_offloaded_memory(), sys.getrefcount(shift_model.model), shift_model.model_memory(), i, shift_model))
shift_model.currently_used = False

for x in sorted(can_unload):
i = x[-1]
shift_model = x[-1]
i = x[-2]
memory_to_free = None
if not DISABLE_SMART_MEMORY:
free_mem = get_free_memory(device)
if free_mem > memory_required:
break
memory_to_free = memory_required - free_mem
logging.debug(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}")
if current_loaded_models[i].model_unload(memory_to_free):
unloaded_model.append(i)
if shift_model.model_unload(memory_to_free):
unloaded_model.append((i, shift_model))

for i in sorted(unloaded_model, reverse=True):
unloaded_models.append(current_loaded_models.pop(i))
for i, shift_model in sorted(unloaded_model, reverse=True):
unloaded_models.append(shift_model)
if shift_model in current_loaded_models:
current_loaded_models.remove(shift_model)

if len(unloaded_model) > 0:
soft_empty_cache()
@@ -739,7 +768,7 @@ def cleanup_models_gc():
def cleanup_models():
to_delete = []
for i in range(len(current_loaded_models)):
if current_loaded_models[i].real_model() is None:
if current_loaded_models[i].real_model is None or current_loaded_models[i].real_model() is None:
to_delete = [i] + to_delete

for i in to_delete:
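
The listener registry introduced above is a plain set of objects that expose a `free_ram(extra_ram)` method; `comfy.model_management.free_ram()` forwards the estimated byte count to each of them, and `offload_modules()` calls it again after moving weights back to the offload device. A standalone sketch of the registry round-trip, with a hypothetical `LoggingListener` standing in for the real RAM-pressure cache:

```python
import torch

current_ram_listeners = set()

def register_ram_listener(listener):
    current_ram_listeners.add(listener)

def unregister_ram_listener(listener):
    current_ram_listeners.discard(listener)

def free_ram(extra_ram=0, state_dict={}):
    # Grow the request by the size of any tensors about to be loaded,
    # then let every listener try to release that much host RAM.
    for tensor in state_dict.values():
        if isinstance(tensor, torch.Tensor):
            extra_ram += tensor.numel() * tensor.element_size()
    for listener in current_ram_listeners:
        listener.free_ram(extra_ram)

class LoggingListener:
    """Hypothetical listener: only reports what it was asked to free."""
    def free_ram(self, extra_ram=0):
        print(f"asked to free {extra_ram / (1024 ** 2):.1f} MB")

listener = LoggingListener()
register_ram_listener(listener)
free_ram(state_dict={"w": torch.zeros(1024, 1024)})  # ~4.0 MB of fp32
unregister_ram_listener(listener)
```
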
24 changes: 16 additions & 8 deletions comfy/model_patcher.py
@@ -24,6 +24,7 @@
import logging
import math
import uuid
import weakref
from typing import Callable, Optional

import torch
@@ -830,6 +831,7 @@ def patch_model(self, device_to=None, lowvram_model_memory=0, load_weights=True,

def unpatch_model(self, device_to=None, unpatch_weights=True):
self.eject_model()
modules_to_move = []
if unpatch_weights:
self.unpatch_hooks()
self.unpin_all_weights()
@@ -854,7 +856,8 @@ def unpatch_model(self, device_to=None, unpatch_weights=True):
self.backup.clear()

if device_to is not None:
self.model.to(device_to)
modules_to_move = [ weakref.ref(m[3]) for m in self._load_list() ]
modules_to_move.append(weakref.ref(self.model))
self.model.device = device_to
self.model.model_loaded_weight_memory = 0
self.model.model_offload_buffer_memory = 0
@@ -868,12 +871,14 @@ def unpatch_model(self, device_to=None, unpatch_weights=True):
comfy.utils.set_attr(self.model, k, self.object_patches_backup[k])

self.object_patches_backup.clear()
return modules_to_move

def partially_unload(self, device_to, memory_to_free=0, force_patch_weights=False):
with self.use_ejected():
hooks_unpatched = False
memory_freed = 0
patch_counter = 0
modules_to_move = []
unload_list = self._load_list()
unload_list.sort()
offload_buffer = self.model.model_offload_buffer_memory
@@ -910,7 +915,7 @@ def partially_unload(self, device_to, memory_to_free=0, force_patch_weights=False):
bias_key = "{}.bias".format(n)
if move_weight:
cast_weight = self.force_cast_weights
m.to(device_to)
modules_to_move.append(weakref.ref(m))
module_mem += move_weight_functions(m, device_to)
if lowvram_possible:
if weight_key in self.patches:
@@ -946,20 +951,22 @@ def partially_unload(self, device_to, memory_to_free=0, force_patch_weights=False):
self.model.model_loaded_weight_memory -= memory_freed
self.model.model_offload_buffer_memory = offload_buffer
logging.info("Unloaded partially: {:.2f} MB freed, {:.2f} MB remains loaded, {:.2f} MB buffer reserved, lowvram patches: {}".format(memory_freed / (1024 * 1024), self.model.model_loaded_weight_memory / (1024 * 1024), offload_buffer / (1024 * 1024), self.model.lowvram_patch_counter))
return memory_freed
return memory_freed, modules_to_move

def partially_load(self, device_to, extra_memory=0, force_patch_weights=False):
with self.use_ejected(skip_and_inject_on_exit_only=True):
unpatch_weights = self.model.current_weight_patches_uuid is not None and (self.model.current_weight_patches_uuid != self.patches_uuid or force_patch_weights)
# TODO: force_patch_weights should not unload + reload full model
used = self.model.model_loaded_weight_memory
self.unpatch_model(self.offload_device, unpatch_weights=unpatch_weights)
modules_to_offload = self.unpatch_model(self.offload_device, unpatch_weights=unpatch_weights)
comfy.model_management.offload_modules(modules_to_offload, self.offload_device)
if unpatch_weights:
extra_memory += (used - self.model.model_loaded_weight_memory)

self.patch_model(load_weights=False)
if extra_memory < 0 and not unpatch_weights:
self.partially_unload(self.offload_device, -extra_memory, force_patch_weights=force_patch_weights)
_, modules_to_offload = self.partially_unload(self.offload_device, -extra_memory, force_patch_weights=force_patch_weights)
comfy.model_management.offload_modules(modules_to_offload, self.offload_device)
return 0
full_load = False
if self.model.model_lowvram == False and self.model.model_loaded_weight_memory > 0:
@@ -971,19 +978,20 @@ def partially_load(self, device_to, extra_memory=0, force_patch_weights=False):
try:
self.load(device_to, lowvram_model_memory=current_used + extra_memory, force_patch_weights=force_patch_weights, full_load=full_load)
except Exception as e:
self.detach()
comfy.model_management.offload_modules(self.detach(), self.offload_device)
raise e

return self.model.model_loaded_weight_memory - current_used

def detach(self, unpatch_all=True):
self.eject_model()
self.model_patches_to(self.offload_device)
modules_to_offload = []
if unpatch_all:
self.unpatch_model(self.offload_device, unpatch_weights=unpatch_all)
modules_to_offload = self.unpatch_model(self.offload_device, unpatch_weights=unpatch_all)
for callback in self.get_all_callbacks(CallbacksMP.ON_DETACH):
callback(self, unpatch_all)
return self.model
return modules_to_offload

def current_loaded_device(self):
return self.model.device
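
Rather than moving modules to the offload device inside `unpatch_model()` / `partially_unload()`, the patcher now returns them as `weakref.ref` objects and lets `comfy.model_management.offload_modules()` perform the move, which also triggers `free_ram()` so cache listeners can react. A weak reference dereferences to `None` once the module has been garbage collected, so stale entries are simply skipped. A small sketch of that pattern (the `collect_module_refs` helper is illustrative):

```python
import weakref
import torch

def collect_module_refs(model):
    # Hand back weak references so holding the list does not keep
    # modules alive if the model is dropped before the offload runs.
    return [weakref.ref(m) for m in model.modules()]

def offload_modules(modules, offload_device):
    for module_ref in modules:
        module = module_ref()
        if module is None:         # module was garbage collected in the meantime
            continue
        module.to(offload_device)  # move parameters and buffers

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU())
refs = collect_module_refs(model)
offload_modules(refs, torch.device("cpu"))
```
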
3 changes: 3 additions & 0 deletions comfy/sd.py
@@ -260,6 +260,7 @@ def encode(self, text):

def load_sd(self, sd, full_model=False):
if full_model:
comfy.model_management.free_ram(state_dict=sd)
return self.cond_stage_model.load_state_dict(sd, strict=False)
else:
return self.cond_stage_model.load_sd(sd)
@@ -625,6 +626,7 @@ def estimate_memory(shape, dtype, num_layers = 16, kv_cache_multiplier = 2):
self.first_stage_model = AutoencoderKL(**(config['params']))
self.first_stage_model = self.first_stage_model.eval()

comfy.model_management.free_ram(state_dict=sd)
m, u = self.first_stage_model.load_state_dict(sd, strict=False)
if len(m) > 0:
logging.warning("Missing VAE keys {}".format(m))
@@ -933,6 +935,7 @@ def load_style_model(ckpt_path):
model = comfy.ldm.flux.redux.ReduxImageEncoder()
else:
raise Exception("invalid style model {}".format(ckpt_path))
comfy.model_management.free_ram(state_dict=model_data)
model.load_state_dict(model_data)
return StyleModel(model)

36 changes: 21 additions & 15 deletions comfy_execution/caching.py
@@ -193,7 +193,7 @@ def clean_unused(self):
self._clean_cache()
self._clean_subcaches()

def poll(self, **kwargs):
def free_ram(self, *args, **kwargs):
pass

def _set_immediate(self, node_id, value):
@@ -284,7 +284,7 @@ def all_node_ids(self):
def clean_unused(self):
pass

def poll(self, **kwargs):
def free_ram(self, *args, **kwargs):
pass

def get(self, node_id):
@@ -366,9 +366,10 @@ async def ensure_subcache_for(self, node_id, children_ids):

class RAMPressureCache(LRUCache):

def __init__(self, key_class):
def __init__(self, key_class, min_headroom=4.0):
super().__init__(key_class, 0)
self.timestamps = {}
self.min_headroom = min_headroom

def clean_unused(self):
self._clean_subcaches()
@@ -381,19 +382,10 @@ def get(self, node_id):
self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time()
return super().get(node_id)

def poll(self, ram_headroom):
def _ram_gb():
return psutil.virtual_memory().available / (1024**3)

if _ram_gb() > ram_headroom:
return
gc.collect()
if _ram_gb() > ram_headroom:
return

def _build_clean_list(self):
clean_list = []

for key, (outputs, _), in self.cache.items():
for key, (_, outputs), in self.cache.items():
oom_score = RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER ** (self.generation - self.used_generation[key])

ram_usage = RAM_CACHE_DEFAULT_RAM_USAGE
@@ -416,8 +408,22 @@ def scan_list_for_ram_usage(outputs):
#In the case where we have no information on the node ram usage at all,
#break OOM score ties on the last touch timestamp (pure LRU)
bisect.insort(clean_list, (oom_score, self.timestamps[key], key))
return clean_list

def free_ram(self, extra_ram=0):
headroom_target = self.min_headroom + (extra_ram / (1024**3))
def _ram_gb():
return psutil.virtual_memory().available / (1024**3)

if _ram_gb() > headroom_target:
return
gc.collect()
if _ram_gb() > headroom_target:
return

clean_list = self._build_clean_list()

while _ram_gb() < ram_headroom * RAM_CACHE_HYSTERESIS and clean_list:
while _ram_gb() < headroom_target * RAM_CACHE_HYSTERESIS and clean_list:
_, _, key = clean_list.pop()
del self.cache[key]
gc.collect()
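
`RAMPressureCache` no longer polls after every node; it is now a RAM listener whose `free_ram(extra_ram)` turns the requested bytes into extra headroom on top of `min_headroom` (in GB), tries a cheap `gc.collect()` first, and only then evicts entries from the sorted clean list until free RAM exceeds the target scaled by the hysteresis factor. A simplified sketch of that loop, with an assumed value for `RAM_CACHE_HYSTERESIS` (the real constant is defined elsewhere in caching.py):

```python
import gc
import psutil

RAM_CACHE_HYSTERESIS = 1.1  # assumed value for illustration

def _available_gb():
    return psutil.virtual_memory().available / (1024 ** 3)

def free_ram(cache, clean_list, min_headroom, extra_ram=0):
    """Evict cache entries until free RAM exceeds the headroom target (sketch).

    clean_list holds (oom_score, timestamp, key) tuples in ascending order,
    so pop() removes the entry considered most expendable first.
    """
    headroom_target = min_headroom + extra_ram / (1024 ** 3)
    if _available_gb() > headroom_target:
        return
    gc.collect()                      # cheap attempt before touching the cache
    if _available_gb() > headroom_target:
        return
    while _available_gb() < headroom_target * RAM_CACHE_HYSTERESIS and clean_list:
        _, _, key = clean_list.pop()
        del cache[key]
        gc.collect()                  # actually release the dropped outputs
```
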
14 changes: 11 additions & 3 deletions execution.py
@@ -107,7 +107,7 @@ def __init__(self, cache_type=None, cache_args={}):
self.init_null_cache()
logging.info("Disabling intermediate node cache.")
elif cache_type == CacheType.RAM_PRESSURE:
cache_ram = cache_args.get("ram", 16.0)
cache_ram = cache_args.get("ram", 4.0)
self.init_ram_cache(cache_ram)
logging.info("Using RAM pressure cache.")
elif cache_type == CacheType.LRU:
@@ -129,7 +129,7 @@ def init_lru_cache(self, cache_size):
self.objects = HierarchicalCache(CacheKeySetID)

def init_ram_cache(self, min_headroom):
self.outputs = RAMPressureCache(CacheKeySetInputSignature)
self.outputs = RAMPressureCache(CacheKeySetInputSignature, min_headroom)
self.objects = HierarchicalCache(CacheKeySetID)

def init_null_cache(self):
@@ -613,13 +621,21 @@ async def await_completion():

class PromptExecutor:
def __init__(self, server, cache_type=False, cache_args=None):
self.caches = None
self.cache_args = cache_args
self.cache_type = cache_type
self.server = server
self.reset()

def reset(self):
if self.caches is not None:
for cache in self.caches.all:
comfy.model_management.unregister_ram_listener(cache)

self.caches = CacheSet(cache_type=self.cache_type, cache_args=self.cache_args)

for cache in self.caches.all:
comfy.model_management.register_ram_listener(cache)
self.status_messages = []
self.success = True

@@ -717,7 +725,7 @@ async def execute_async(self, prompt, prompt_id, extra_data={}, execute_outputs=
execution_list.unstage_node_execution()
else: # result == ExecutionResult.SUCCESS:
execution_list.complete_node_execution()
self.caches.outputs.poll(ram_headroom=self.cache_args["ram"])
self.caches.outputs.free_ram()
else:
# Only execute when the while-loop ends without break
self.add_message("execution_success", { "prompt_id": prompt_id }, broadcast=False)
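
Finally, `PromptExecutor.reset()` now wires the caches into the listener registry and detaches the previous set first, so a replaced cache can never be asked to free RAM after it is gone. A minimal sketch of that lifecycle, with the cache-set factory and registry functions passed in as assumptions rather than the real ComfyUI objects:

```python
class ExecutorSketch:
    """Illustrative listener lifecycle, not the real PromptExecutor."""

    def __init__(self, make_cache_set, register, unregister):
        self.caches = None
        self._make_cache_set = make_cache_set
        self._register = register
        self._unregister = unregister
        self.reset()

    def reset(self):
        # Unhook the old cache set before replacing it ...
        if self.caches is not None:
            for cache in self.caches.all:
                self._unregister(cache)
        self.caches = self._make_cache_set()
        # ... then register the fresh caches as RAM listeners.
        for cache in self.caches.all:
            self._register(cache)

class _CacheSet:
    def __init__(self):
        self.all = [object(), object()]

registry = set()
executor = ExecutorSketch(_CacheSet, registry.add, registry.discard)
executor.reset()      # old caches unregistered, new ones registered
print(len(registry))  # 2
```
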