diff --git a/comfy/model_management.py b/comfy/model_management.py
index a4410f2ece53..d7fb24e01190 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -39,6 +39,7 @@ class CPUState(Enum):
     GPU = 0
     CPU = 1
     MPS = 2
+    OCL = 3
 
 # Determine VRAM State
 vram_state = VRAMState.NORMAL_VRAM
@@ -101,6 +102,14 @@ def get_supported_float8_types():
     # torch_directml.disable_tiled_resources(True)
     lowvram_available = False #TODO: need to find a way to get free memory in directml before this can be enabled by default.
 
+ocl_available = False
+try:
+    import pytorch_ocl  # noqa: F401
+    import torch.ocl  # noqa: F401
+    ocl_available = True
+except ImportError:
+    pass
+
 try:
     import intel_extension_for_pytorch as ipex # noqa: F401
 except:
@@ -138,6 +147,10 @@ def get_supported_float8_types():
 except:
     ixuca_available = False
 
+if ocl_available:
+    # TODO: gate this behind a CLI flag instead of auto-selecting.
+    cpu_state = CPUState.OCL
+
 if args.cpu:
     cpu_state = CPUState.CPU
 
@@ -167,6 +180,12 @@ def is_ixuca():
         return True
     return False
 
+def is_ocl():
+    global ocl_available
+    if ocl_available:
+        return True
+    return False
+
 def get_torch_device():
     global directml_enabled
     global cpu_state
@@ -177,6 +196,8 @@ def get_torch_device():
         return torch.device("mps")
     if cpu_state == CPUState.CPU:
         return torch.device("cpu")
+    if cpu_state == CPUState.OCL:
+        return torch.device("ocl:0")
     else:
         if is_intel_xpu():
             return torch.device("xpu", torch.xpu.current_device())
@@ -192,7 +213,7 @@ def get_total_memory(dev=None, torch_total_too=False):
     if dev is None:
         dev = get_torch_device()
 
-    if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'):
+    if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps' or dev.type == 'ocl'):
         mem_total = psutil.virtual_memory().total
         mem_total_torch = mem_total
     else:
@@ -217,6 +238,9 @@ def get_total_memory(dev=None, torch_total_too=False):
             _, mem_total_mlu = torch.mlu.mem_get_info(dev)
             mem_total_torch = mem_reserved
             mem_total = mem_total_mlu
+        elif is_ocl():
+            mem_total = 1024 * 1024 * 1024 #TODO: query the real device memory instead of assuming 1GB.
+            mem_total_torch = mem_total
         else:
             stats = torch.cuda.memory_stats(dev)
             mem_reserved = stats['reserved_bytes.all.current']
@@ -1231,7 +1255,7 @@ def get_free_memory(dev=None, torch_free_too=False):
     if dev is None:
         dev = get_torch_device()
 
-    if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'):
+    if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps' or dev.type == 'ocl'):
         mem_free_total = psutil.virtual_memory().available
         mem_free_torch = mem_free_total
     else:
@@ -1259,6 +1283,15 @@ def get_free_memory(dev=None, torch_free_too=False):
             mem_free_mlu, _ = torch.mlu.mem_get_info(dev)
             mem_free_torch = mem_reserved - mem_active
             mem_free_total = mem_free_mlu + mem_free_torch
+        elif is_ocl():
+            # stats = torch.ocl.memory_stats(dev)
+            # mem_active = stats['active_bytes.all.current']
+            # mem_reserved = stats['reserved_bytes.all.current']
+            # mem_free_ocl, _ = torch.ocl.mem_get_info(dev)
+            # mem_free_torch = mem_reserved - mem_active
+            # mem_free_total = mem_free_ocl + mem_free_torch
+            mem_free_total = 1024 * 1024 * 1024 #TODO: query the real free memory instead of assuming 1GB.
+            mem_free_torch = mem_free_total
         else:
             stats = torch.cuda.memory_stats(dev)
             mem_active = stats['active_bytes.all.current']
@@ -1337,6 +1370,10 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, manual_cast=False):
     if is_mlu():
         return True
 
+    if is_ocl():
+        # TODO: verify per-device; RustiCL supports fp16 at least.
+        return True
+
     if is_ixuca():
         return True
 
@@ -1413,6 +1450,10 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, manual_cast=False):
                 return True
             return False
 
+    if is_ocl():
+        # TODO: verify bf16 support on the OpenCL backend.
+        return True
+
     props = torch.cuda.get_device_properties(device)
     if is_mlu():
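
One possible shape for the "gate this behind a CLI flag" TODO: make the OpenCL backend opt-in rather than selected automatically whenever pytorch_ocl imports. This is a minimal sketch only; the --use-ocl flag name and the argparse plumbing here are hypothetical, since ComfyUI defines its real arguments in comfy/cli_args.py.

import argparse

parser = argparse.ArgumentParser()
# Hypothetical flag; ComfyUI's actual argument definitions live in comfy/cli_args.py.
parser.add_argument("--use-ocl", action="store_true",
                    help="Opt in to the pytorch_ocl (OpenCL) backend if it is importable.")
args = parser.parse_args()

ocl_available = False
if args.use_ocl:  # opt-in: never auto-select OCL just because the import succeeds
    try:
        import pytorch_ocl  # noqa: F401  (importing registers the backend)
        ocl_available = True
    except ImportError:
        pass

Keeping the backend opt-in matches how the patch's own TODO frames it, and avoids silently changing device selection for users who happen to have pytorch_ocl installed.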
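On the two 1 GiB memory placeholders: the commented-out block in get_free_memory() sketches the intended CUDA-style allocator calls. The helper below wraps that sketch with a fallback. Note that torch.ocl.memory_stats and torch.ocl.mem_get_info are assumptions carried over from the patch's comments, not a shipped pytorch_ocl API, and ocl_free_memory is a hypothetical helper name.

import torch

def ocl_free_memory(dev):
    """Free-memory probe for an OCL device, falling back to the patch's 1 GiB guess."""
    try:
        # Hypothetical API mirroring torch.cuda/torch.mlu, taken from the
        # commented-out sketch in the patch, not from a released pytorch_ocl.
        stats = torch.ocl.memory_stats(dev)
        mem_active = stats['active_bytes.all.current']
        mem_reserved = stats['reserved_bytes.all.current']
        mem_free_ocl, _ = torch.ocl.mem_get_info(dev)
        mem_free_torch = mem_reserved - mem_active
        return mem_free_ocl + mem_free_torch, mem_free_torch
    except AttributeError:
        # torch.ocl does not expose these yet: keep the patch's 1 GiB placeholder.
        placeholder = 1024 * 1024 * 1024
        return placeholder, placeholder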