diff --git a/comfy/cli_args.py b/comfy/cli_args.py
index d2b60e347eff..fdc31e7e2f99 100644
--- a/comfy/cli_args.py
+++ b/comfy/cli_args.py
@@ -90,6 +90,7 @@ def __call__(self, parser, namespace, values, option_string=None):
 parser.add_argument("--oneapi-device-selector", type=str, default=None, metavar="SELECTOR_STRING", help="Sets the oneAPI device(s) this instance will use.")
 parser.add_argument("--disable-ipex-optimize", action="store_true", help="Disables ipex.optimize default when loading models with Intel's Extension for Pytorch.")
 parser.add_argument("--supports-fp8-compute", action="store_true", help="ComfyUI will act like if the device supports fp8 compute.")
+parser.add_argument("--total-ram", type=float, default=0, help="Maximum system RAM in GB visible to ComfyUI (default 0: use all).")
 
 class LatentPreviewMethod(enum.Enum):
     NoPreviews = "none"
diff --git a/comfy/model_management.py b/comfy/model_management.py
index a9327ac80091..bc8179410d76 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -192,8 +192,11 @@ def get_total_memory(dev=None, torch_total_too=False):
     if dev is None:
         dev = get_torch_device()
 
-    if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'):
-        mem_total = psutil.virtual_memory().total
+    if hasattr(dev, "type") and (dev.type == "cpu" or dev.type == "mps"):
+        if args.total_ram != 0:
+            mem_total = args.total_ram * 1024 * 1024 * 1024  # --total-ram is in GB; mem_total is in bytes
+        else:
+            mem_total = psutil.virtual_memory().total
         mem_total_torch = mem_total
     else:
         if directml_enabled:
@@ -236,8 +240,14 @@ def mac_version():
         return None
 
 total_vram = get_total_memory(get_torch_device()) / (1024 * 1024)
-total_ram = psutil.virtual_memory().total / (1024 * 1024)
-logging.info("Total VRAM {:0.0f} MB, total RAM {:0.0f} MB".format(total_vram, total_ram))
+
+if args.total_ram != 0:
+    total_ram = args.total_ram * 1024  # --total-ram is in GB; total_ram is in MB
+else:
+    total_ram = psutil.virtual_memory().total / (1024 * 1024)
+logging.info(
+    "Total VRAM {:0.0f} MB, total RAM {:0.0f} MB".format(total_vram, total_ram)
+)
 
 try:
     logging.info("pytorch version: {}".format(torch_version))
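
For reference, a minimal sketch of the unit handling this patch relies on: --total-ram is given in GB, get_total_memory() works in bytes (matching psutil.virtual_memory().total), and the startup log works in MB. The helper names below are hypothetical, not part of the ComfyUI codebase.

# Sketch of the --total-ram unit conversions; helper names are hypothetical.

GB_BYTES = 1024 * 1024 * 1024  # bytes per GB

def ram_budget_bytes(total_ram_gb: float, detected_bytes: int) -> float:
    """RAM budget in bytes: the --total-ram cap if set, else the detected total."""
    if total_ram_gb != 0:
        return total_ram_gb * GB_BYTES  # GB -> bytes, as in get_total_memory()
    return detected_bytes  # psutil.virtual_memory().total already reports bytes

def ram_budget_mb(total_ram_gb: float, detected_bytes: int) -> float:
    """The same budget in MB, as reported by the startup log."""
    return ram_budget_bytes(total_ram_gb, detected_bytes) / (1024 * 1024)

# Example: "--total-ram 16" on a machine with 64 GB of physical RAM.
assert ram_budget_mb(16, 64 * GB_BYTES) == 16 * 1024  # log reports 16384 MB
assert ram_budget_mb(0, 64 * GB_BYTES) == 64 * 1024   # default 0: all detected RAM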