diff --git a/comfy/model_base.py b/comfy/model_base.py index 9b76c285ed1d..76dc3e37030c 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -304,6 +304,7 @@ def load_model_weights(self, sd, unet_prefix=""): to_load[k[len(unet_prefix):]] = sd.pop(k) to_load = self.model_config.process_unet_state_dict(to_load) + comfy.model_management.free_ram(state_dict=to_load) m, u = self.diffusion_model.load_state_dict(to_load, strict=False) if len(m) > 0: logging.warning("unet missing: {}".format(m)) diff --git a/comfy/model_management.py b/comfy/model_management.py index aeddbaefe4b9..3a5a6c4e3dde 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -445,6 +445,20 @@ def get_torch_device_name(device): except: logging.warning("Could not pick default device.") +current_ram_listeners = set() + +def register_ram_listener(listener): + current_ram_listeners.add(listener) + +def unregister_ram_listener(listener): + current_ram_listeners.discard(listener) + +def free_ram(extra_ram=0, state_dict={}): + for tensor in state_dict.values(): + if isinstance(tensor, torch.Tensor): + extra_ram += tensor.numel() * tensor.element_size() + for listener in current_ram_listeners: + listener.free_ram(extra_ram) current_loaded_models = [] @@ -521,12 +535,18 @@ def should_reload_model(self, force_patch_weights=False): return False def model_unload(self, memory_to_free=None, unpatch_weights=True): + if self.model is None: + return True + logging.debug(f"Unloading {self.model.model.__class__.__name__}") if memory_to_free is not None: if memory_to_free < self.model.loaded_size(): - freed = self.model.partially_unload(self.model.offload_device, memory_to_free) + freed, modules_to_offload = self.model.partially_unload(self.model.offload_device, memory_to_free) + offload_modules(modules_to_offload, self.model.offload_device) if freed >= memory_to_free: return False - self.model.detach(unpatch_weights) + if self.model is not None: + modules_to_offload = self.model.detach(unpatch_weights) + offload_modules(modules_to_offload, self.model.offload_device) self.model_finalizer.detach() self.model_finalizer = None self.real_model = None @@ -543,7 +563,7 @@ def __del__(self): self._patcher_finalizer.detach() def is_dead(self): - return self.real_model() is not None and self.model is None + return self.real_model is not None and self.real_model() is not None and self.model is None def use_more_memory(extra_memory, loaded_models, device): @@ -578,6 +598,13 @@ def extra_reserved_memory(): def minimum_inference_memory(): return (1024 * 1024 * 1024) * 0.8 + extra_reserved_memory() +def offload_modules(modules, offload_device): + for module in modules: + if module() is None: + continue + module().to(offload_device) + free_ram() + def free_memory(memory_required, device, keep_loaded=[]): cleanup_models_gc() unloaded_model = [] @@ -588,23 +615,25 @@ def free_memory(memory_required, device, keep_loaded=[]): shift_model = current_loaded_models[i] if shift_model.device == device: if shift_model not in keep_loaded and not shift_model.is_dead(): - can_unload.append((-shift_model.model_offloaded_memory(), sys.getrefcount(shift_model.model), shift_model.model_memory(), i)) + can_unload.append((-shift_model.model_offloaded_memory(), sys.getrefcount(shift_model.model), shift_model.model_memory(), i, shift_model)) shift_model.currently_used = False for x in sorted(can_unload): - i = x[-1] + shift_model = x[-1] + i = x[-2] memory_to_free = None if not DISABLE_SMART_MEMORY: free_mem = get_free_memory(device) if free_mem > memory_required: break memory_to_free = memory_required - free_mem - logging.debug(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}") - if current_loaded_models[i].model_unload(memory_to_free): - unloaded_model.append(i) + if shift_model.model_unload(memory_to_free): + unloaded_model.append((i, shift_model)) - for i in sorted(unloaded_model, reverse=True): - unloaded_models.append(current_loaded_models.pop(i)) + for i, shift_model in sorted(unloaded_model, reverse=True): + unloaded_models.append(shift_model) + if shift_model in current_loaded_models: + current_loaded_models.remove(shift_model) if len(unloaded_model) > 0: soft_empty_cache() @@ -739,7 +768,7 @@ def cleanup_models_gc(): def cleanup_models(): to_delete = [] for i in range(len(current_loaded_models)): - if current_loaded_models[i].real_model() is None: + if current_loaded_models[i].real_model is None or current_loaded_models[i].real_model() is None: to_delete = [i] + to_delete for i in to_delete: diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 3eac772758b5..078c23019d4a 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -24,6 +24,7 @@ import logging import math import uuid +import weakref from typing import Callable, Optional import torch @@ -830,6 +831,7 @@ def patch_model(self, device_to=None, lowvram_model_memory=0, load_weights=True, def unpatch_model(self, device_to=None, unpatch_weights=True): self.eject_model() + modules_to_move = [] if unpatch_weights: self.unpatch_hooks() self.unpin_all_weights() @@ -854,7 +856,8 @@ def unpatch_model(self, device_to=None, unpatch_weights=True): self.backup.clear() if device_to is not None: - self.model.to(device_to) + modules_to_move = [ weakref.ref(m[3]) for m in self._load_list() ] + modules_to_move.append(weakref.ref(self.model)) self.model.device = device_to self.model.model_loaded_weight_memory = 0 self.model.model_offload_buffer_memory = 0 @@ -868,12 +871,14 @@ def unpatch_model(self, device_to=None, unpatch_weights=True): comfy.utils.set_attr(self.model, k, self.object_patches_backup[k]) self.object_patches_backup.clear() + return modules_to_move def partially_unload(self, device_to, memory_to_free=0, force_patch_weights=False): with self.use_ejected(): hooks_unpatched = False memory_freed = 0 patch_counter = 0 + modules_to_move = [] unload_list = self._load_list() unload_list.sort() offload_buffer = self.model.model_offload_buffer_memory @@ -910,7 +915,7 @@ def partially_unload(self, device_to, memory_to_free=0, force_patch_weights=Fals bias_key = "{}.bias".format(n) if move_weight: cast_weight = self.force_cast_weights - m.to(device_to) + modules_to_move.append(weakref.ref(m)) module_mem += move_weight_functions(m, device_to) if lowvram_possible: if weight_key in self.patches: @@ -946,20 +951,22 @@ def partially_unload(self, device_to, memory_to_free=0, force_patch_weights=Fals self.model.model_loaded_weight_memory -= memory_freed self.model.model_offload_buffer_memory = offload_buffer logging.info("Unloaded partially: {:.2f} MB freed, {:.2f} MB remains loaded, {:.2f} MB buffer reserved, lowvram patches: {}".format(memory_freed / (1024 * 1024), self.model.model_loaded_weight_memory / (1024 * 1024), offload_buffer / (1024 * 1024), self.model.lowvram_patch_counter)) - return memory_freed + return memory_freed, modules_to_move def partially_load(self, device_to, extra_memory=0, force_patch_weights=False): with self.use_ejected(skip_and_inject_on_exit_only=True): unpatch_weights = self.model.current_weight_patches_uuid is not None and (self.model.current_weight_patches_uuid != self.patches_uuid or force_patch_weights) # TODO: force_patch_weights should not unload + reload full model used = self.model.model_loaded_weight_memory - self.unpatch_model(self.offload_device, unpatch_weights=unpatch_weights) + modules_to_offload = self.unpatch_model(self.offload_device, unpatch_weights=unpatch_weights) + comfy.model_management.offload_modules(modules_to_offload, self.offload_device) if unpatch_weights: extra_memory += (used - self.model.model_loaded_weight_memory) self.patch_model(load_weights=False) if extra_memory < 0 and not unpatch_weights: - self.partially_unload(self.offload_device, -extra_memory, force_patch_weights=force_patch_weights) + _, modules_to_offload = self.partially_unload(self.offload_device, -extra_memory, force_patch_weights=force_patch_weights) + comfy.model_management.offload_modules(modules_to_offload, self.offload_device) return 0 full_load = False if self.model.model_lowvram == False and self.model.model_loaded_weight_memory > 0: @@ -971,7 +978,7 @@ def partially_load(self, device_to, extra_memory=0, force_patch_weights=False): try: self.load(device_to, lowvram_model_memory=current_used + extra_memory, force_patch_weights=force_patch_weights, full_load=full_load) except Exception as e: - self.detach() + comfy.model_management.offload_modules(self.detach(), self.offload_device) raise e return self.model.model_loaded_weight_memory - current_used @@ -979,11 +986,12 @@ def partially_load(self, device_to, extra_memory=0, force_patch_weights=False): def detach(self, unpatch_all=True): self.eject_model() self.model_patches_to(self.offload_device) + modules_to_offload = [] if unpatch_all: - self.unpatch_model(self.offload_device, unpatch_weights=unpatch_all) + modules_to_offload = self.unpatch_model(self.offload_device, unpatch_weights=unpatch_all) for callback in self.get_all_callbacks(CallbacksMP.ON_DETACH): callback(self, unpatch_all) - return self.model + return modules_to_offload def current_loaded_device(self): return self.model.device diff --git a/comfy/sd.py b/comfy/sd.py index f9e5efab5001..bfa63debcdb3 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -260,6 +260,7 @@ def encode(self, text): def load_sd(self, sd, full_model=False): if full_model: + comfy.model_management.free_ram(state_dict=sd) return self.cond_stage_model.load_state_dict(sd, strict=False) else: return self.cond_stage_model.load_sd(sd) @@ -625,6 +626,7 @@ def estimate_memory(shape, dtype, num_layers = 16, kv_cache_multiplier = 2): self.first_stage_model = AutoencoderKL(**(config['params'])) self.first_stage_model = self.first_stage_model.eval() + comfy.model_management.free_ram(state_dict=sd) m, u = self.first_stage_model.load_state_dict(sd, strict=False) if len(m) > 0: logging.warning("Missing VAE keys {}".format(m)) @@ -933,6 +935,7 @@ def load_style_model(ckpt_path): model = comfy.ldm.flux.redux.ReduxImageEncoder() else: raise Exception("invalid style model {}".format(ckpt_path)) + comfy.model_management.free_ram(state_dict=model_data) model.load_state_dict(model_data) return StyleModel(model) diff --git a/comfy_execution/caching.py b/comfy_execution/caching.py index 326a279fc997..43f88246952b 100644 --- a/comfy_execution/caching.py +++ b/comfy_execution/caching.py @@ -193,7 +193,7 @@ def clean_unused(self): self._clean_cache() self._clean_subcaches() - def poll(self, **kwargs): + def free_ram(self, *args, **kwargs): pass def _set_immediate(self, node_id, value): @@ -284,7 +284,7 @@ def all_node_ids(self): def clean_unused(self): pass - def poll(self, **kwargs): + def free_ram(self, *args, **kwargs): pass def get(self, node_id): @@ -366,9 +366,10 @@ async def ensure_subcache_for(self, node_id, children_ids): class RAMPressureCache(LRUCache): - def __init__(self, key_class): + def __init__(self, key_class, min_headroom=4.0): super().__init__(key_class, 0) self.timestamps = {} + self.min_headroom = min_headroom def clean_unused(self): self._clean_subcaches() @@ -381,19 +382,10 @@ def get(self, node_id): self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time() return super().get(node_id) - def poll(self, ram_headroom): - def _ram_gb(): - return psutil.virtual_memory().available / (1024**3) - - if _ram_gb() > ram_headroom: - return - gc.collect() - if _ram_gb() > ram_headroom: - return - + def _build_clean_list(self): clean_list = [] - for key, (outputs, _), in self.cache.items(): + for key, (_, outputs), in self.cache.items(): oom_score = RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER ** (self.generation - self.used_generation[key]) ram_usage = RAM_CACHE_DEFAULT_RAM_USAGE @@ -416,8 +408,22 @@ def scan_list_for_ram_usage(outputs): #In the case where we have no information on the node ram usage at all, #break OOM score ties on the last touch timestamp (pure LRU) bisect.insort(clean_list, (oom_score, self.timestamps[key], key)) + return clean_list + + def free_ram(self, extra_ram=0): + headroom_target = self.min_headroom + (extra_ram / (1024**3)) + def _ram_gb(): + return psutil.virtual_memory().available / (1024**3) + + if _ram_gb() > headroom_target: + return + gc.collect() + if _ram_gb() > headroom_target: + return + + clean_list = self._build_clean_list() - while _ram_gb() < ram_headroom * RAM_CACHE_HYSTERESIS and clean_list: + while _ram_gb() < headroom_target * RAM_CACHE_HYSTERESIS and clean_list: _, _, key = clean_list.pop() del self.cache[key] gc.collect() diff --git a/execution.py b/execution.py index 17c77beab8ec..dd5fc8baf7c9 100644 --- a/execution.py +++ b/execution.py @@ -107,7 +107,7 @@ def __init__(self, cache_type=None, cache_args={}): self.init_null_cache() logging.info("Disabling intermediate node cache.") elif cache_type == CacheType.RAM_PRESSURE: - cache_ram = cache_args.get("ram", 16.0) + cache_ram = cache_args.get("ram", 4.0) self.init_ram_cache(cache_ram) logging.info("Using RAM pressure cache.") elif cache_type == CacheType.LRU: @@ -129,7 +129,7 @@ def init_lru_cache(self, cache_size): self.objects = HierarchicalCache(CacheKeySetID) def init_ram_cache(self, min_headroom): - self.outputs = RAMPressureCache(CacheKeySetInputSignature) + self.outputs = RAMPressureCache(CacheKeySetInputSignature, min_headroom) self.objects = HierarchicalCache(CacheKeySetID) def init_null_cache(self): @@ -613,13 +613,21 @@ async def await_completion(): class PromptExecutor: def __init__(self, server, cache_type=False, cache_args=None): + self.caches = None self.cache_args = cache_args self.cache_type = cache_type self.server = server self.reset() def reset(self): + if self.caches is not None: + for cache in self.caches.all: + comfy.model_management.unregister_ram_listener(cache) + self.caches = CacheSet(cache_type=self.cache_type, cache_args=self.cache_args) + + for cache in self.caches.all: + comfy.model_management.register_ram_listener(cache) self.status_messages = [] self.success = True @@ -717,7 +725,7 @@ async def execute_async(self, prompt, prompt_id, extra_data={}, execute_outputs= execution_list.unstage_node_execution() else: # result == ExecutionResult.SUCCESS: execution_list.complete_node_execution() - self.caches.outputs.poll(ram_headroom=self.cache_args["ram"]) + self.caches.outputs.free_ram() else: # Only execute when the while-loop ends without break self.add_message("execution_success", { "prompt_id": prompt_id }, broadcast=False)