60 changes: 39 additions & 21 deletions src/diffusers/models/_modeling_parallel.py
@@ -79,29 +79,47 @@ def __post_init__(self):
if self.ulysses_degree is None:
self.ulysses_degree = 1

if self.ring_degree == 1 and self.ulysses_degree == 1:
raise ValueError(
"Either ring_degree or ulysses_degree must be greater than 1 in order to use context parallel inference"
)
if self.ring_degree < 1 or self.ulysses_degree < 1:
raise ValueError("`ring_degree` and `ulysses_degree` must be greater than or equal to 1.")
if self.ring_degree > 1 and self.ulysses_degree > 1:
raise ValueError(
"Unified Ulysses-Ring attention is not yet supported. Please set either `ring_degree` or `ulysses_degree` to 1."
)
if self.rotate_method != "allgather":
raise NotImplementedError(
f"Only rotate_method='allgather' is supported for now, but got {self.rotate_method}."
)

@property
def mesh_shape(self) -> Tuple[int, int]:
"""Shape of the device mesh (ring_degree, ulysses_degree)."""
return (self.ring_degree, self.ulysses_degree)

@property
def mesh_dim_names(self) -> Tuple[str, str]:
"""Dimension names for the device mesh."""
return ("ring", "ulysses")
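For illustration, a minimal sketch of how the relocated validation and the new properties behave (the import path is an assumption, since ContextParallelConfig lives in a private module):

from diffusers.models._modeling_parallel import ContextParallelConfig  # import path assumed

# Ring attention over 2 ranks; ulysses_degree defaults to 1 in __post_init__.
cfg = ContextParallelConfig(ring_degree=2)
assert cfg.mesh_shape == (2, 1)
assert cfg.mesh_dim_names == ("ring", "ulysses")

# Unified Ulysses-Ring attention is not supported yet, so this now fails as soon
# as the config is constructed rather than inside enable_parallelism.
try:
    ContextParallelConfig(ring_degree=2, ulysses_degree=2)
except ValueError as err:
    print(err)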

def setup(self, rank: int, world_size: int, device: torch.device, mesh: torch.distributed.device_mesh.DeviceMesh):
self._rank = rank
self._world_size = world_size
self._device = device
self._mesh = mesh
if self.ring_degree is None:
self.ring_degree = 1
if self.ulysses_degree is None:
self.ulysses_degree = 1
if self.rotate_method != "allgather":
raise NotImplementedError(
f"Only rotate_method='allgather' is supported for now, but got {self.rotate_method}."

if self.ulysses_degree * self.ring_degree > world_size:

Member
Should we ever hit this line, since both cannot be set, right?

Collaborator Author
Both can technically be set, but currently both can't be > 1. Also, this handles cases where you have 3 GPUs available and set something like ulysses_degree=1 and ring_degree=4 (more GPUs are being requested than the world size).

Member
Feels slightly confusing to me, but since we're erroring out early for unsupported ulysses_degree and ring_degree value combos, I think it's okay.

raise ValueError(
f"The product of `ring_degree` ({self.ring_degree}) and `ulysses_degree` ({self.ulysses_degree}) must not exceed the world size ({world_size})."
)
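As a hedged illustration of the scenario described in the thread above (hypothetical values, not taken from this PR):

from diffusers.models._modeling_parallel import ContextParallelConfig  # import path assumed

# Hypothetical: launched with 3 ranks (e.g. torchrun --nproc_per_node=3).
cfg = ContextParallelConfig(ring_degree=4, ulysses_degree=1)
# cfg.setup(rank, world_size=3, device=device, mesh=mesh) would then raise:
#   ValueError: The product of `ring_degree` (4) and `ulysses_degree` (1) must not exceed the world size (3).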
if self._flattened_mesh is None:
self._flattened_mesh = self._mesh._flatten()
if self._ring_mesh is None:
self._ring_mesh = self._mesh["ring"]
if self._ulysses_mesh is None:
self._ulysses_mesh = self._mesh["ulysses"]
if self._ring_local_rank is None:
self._ring_local_rank = self._ring_mesh.get_local_rank()
if self._ulysses_local_rank is None:
self._ulysses_local_rank = self._ulysses_mesh.get_local_rank()

self._flattened_mesh = self._mesh._flatten()
self._ring_mesh = self._mesh["ring"]
self._ulysses_mesh = self._mesh["ulysses"]
self._ring_local_rank = self._ring_mesh.get_local_rank()
self._ulysses_local_rank = self._ulysses_mesh.get_local_rank()
Comment on lines +118 to +122
Member
Can't they be None? Why unguard?

Collaborator Author
They are internal attributes that are derived from the mesh, which is set through the setup method. The device mesh object is also only created dynamically when enable_parallelism is called.

The guards are redundant; they would always be None unless set explicitly for some custom debugging.
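
To ground this, a small sketch of how the sub-meshes and local ranks fall out of the torch DeviceMesh, assuming a hypothetical 4-rank run launched with torchrun (note that _flatten() is a private PyTorch API, shown here only because the setup() above relies on it):

from torch.distributed.device_mesh import init_device_mesh

# ring_degree=4, ulysses_degree=1 on a hypothetical 4-GPU run.
mesh = init_device_mesh("cuda", mesh_shape=(4, 1), mesh_dim_names=("ring", "ulysses"))

ring_mesh = mesh["ring"]          # 1-D sub-mesh along the ring dimension
ulysses_mesh = mesh["ulysses"]    # 1-D sub-mesh along the ulysses dimension
flattened_mesh = mesh._flatten()  # private API, mirrors what setup() stores
print(ring_mesh.get_local_rank(), ulysses_mesh.get_local_rank())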



@dataclass
@@ -119,22 +137,22 @@ class ParallelConfig:
_rank: int = None
_world_size: int = None
_device: torch.device = None
_cp_mesh: torch.distributed.device_mesh.DeviceMesh = None
_mesh: torch.distributed.device_mesh.DeviceMesh = None

def setup(
self,
rank: int,
world_size: int,
device: torch.device,
*,
cp_mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None,
mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None,
):
self._rank = rank
self._world_size = world_size
self._device = device
self._cp_mesh = cp_mesh
self._mesh = mesh
if self.context_parallel_config is not None:
self.context_parallel_config.setup(rank, world_size, device, cp_mesh)
self.context_parallel_config.setup(rank, world_size, device, mesh)


@dataclass(frozen=True)
27 changes: 9 additions & 18 deletions src/diffusers/models/attention_dispatch.py
@@ -207,7 +207,7 @@ class _AttentionBackendRegistry:
_backends = {}
_constraints = {}
_supported_arg_names = {}
_supports_context_parallel = {}
_supports_context_parallel = set()
_active_backend = AttentionBackendName(DIFFUSERS_ATTN_BACKEND)
_checks_enabled = DIFFUSERS_ATTN_CHECKS

@@ -224,7 +224,9 @@ def decorator(func):
cls._backends[backend] = func
cls._constraints[backend] = constraints or []
cls._supported_arg_names[backend] = set(inspect.signature(func).parameters.keys())
cls._supports_context_parallel[backend] = supports_context_parallel
if supports_context_parallel:
cls._supports_context_parallel.add(backend.value)

return func

return decorator
@@ -238,15 +240,12 @@ def list_backends(cls):
return list(cls._backends.keys())

@classmethod
def _is_context_parallel_enabled(
cls, backend: AttentionBackendName, parallel_config: Optional["ParallelConfig"]
def _is_context_parallel_available(
cls,
backend: AttentionBackendName,
) -> bool:
supports_context_parallel = backend in cls._supports_context_parallel
is_degree_greater_than_1 = parallel_config is not None and (
parallel_config.context_parallel_config.ring_degree > 1
or parallel_config.context_parallel_config.ulysses_degree > 1
)
return supports_context_parallel and is_degree_greater_than_1
supports_context_parallel = backend.value in cls._supports_context_parallel
return supports_context_parallel


@contextlib.contextmanager
@@ -293,14 +292,6 @@ def dispatch_attention_fn(
backend_name = AttentionBackendName(backend)
backend_fn = _AttentionBackendRegistry._backends.get(backend_name)

if parallel_config is not None and not _AttentionBackendRegistry._is_context_parallel_enabled(
backend_name, parallel_config
):
raise ValueError(
f"Backend {backend_name} either does not support context parallelism or context parallelism "
f"was enabled with a world size of 1."
)

kwargs = {
"query": query,
"key": key,
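For context, a hedged sketch of how the set-based registry is queried after this change; these are private internals and the import path is assumed, shown only for illustration:

from diffusers.models.attention_dispatch import _AttentionBackendRegistry  # import path assumed

backend = _AttentionBackendRegistry._active_backend  # defaults to DIFFUSERS_ATTN_BACKEND
if _AttentionBackendRegistry._is_context_parallel_available(backend):
    print(f"'{backend.value}' can be used with context parallelism")
else:
    # Mirrors the error now raised in enable_parallelism below.
    print(f"pick one of: {sorted(_AttentionBackendRegistry._supports_context_parallel)}")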
85 changes: 53 additions & 32 deletions src/diffusers/models/modeling_utils.py
@@ -1483,59 +1483,72 @@ def enable_parallelism(
config: Union[ParallelConfig, ContextParallelConfig],
cp_plan: Optional[Dict[str, ContextParallelModelPlan]] = None,
):
from ..hooks.context_parallel import apply_context_parallel
from .attention import AttentionModuleMixin
from .attention_processor import Attention, MochiAttention

logger.warning(
"`enable_parallelism` is an experimental feature. The API may change in the future and breaking changes may be introduced at any time without warning."
)

if not torch.distributed.is_available() and not torch.distributed.is_initialized():
raise RuntimeError(
"torch.distributed must be available and initialized before calling `enable_parallelism`."
)

from ..hooks.context_parallel import apply_context_parallel
from .attention import AttentionModuleMixin
from .attention_dispatch import AttentionBackendName, _AttentionBackendRegistry
from .attention_processor import Attention, MochiAttention

if isinstance(config, ContextParallelConfig):
config = ParallelConfig(context_parallel_config=config)

if not torch.distributed.is_initialized():
raise RuntimeError("torch.distributed must be initialized before calling `enable_parallelism`.")

rank = torch.distributed.get_rank()
world_size = torch.distributed.get_world_size()
device_type = torch._C._get_accelerator().type
device_module = torch.get_device_module(device_type)
device = torch.device(device_type, rank % device_module.device_count())

cp_mesh = None
attention_classes = (Attention, MochiAttention, AttentionModuleMixin)

# Step 1: Validate attention backend supports context parallelism if enabled
if config.context_parallel_config is not None:
cp_config = config.context_parallel_config
if cp_config.ring_degree < 1 or cp_config.ulysses_degree < 1:
raise ValueError("`ring_degree` and `ulysses_degree` must be greater than or equal to 1.")
if cp_config.ring_degree > 1 and cp_config.ulysses_degree > 1:
raise ValueError(
"Unified Ulysses-Ring attention is not yet supported. Please set either `ring_degree` or `ulysses_degree` to 1."
)
if cp_config.ring_degree * cp_config.ulysses_degree > world_size:
raise ValueError(
f"The product of `ring_degree` ({cp_config.ring_degree}) and `ulysses_degree` ({cp_config.ulysses_degree}) must not exceed the world size ({world_size})."
)
cp_mesh = torch.distributed.device_mesh.init_device_mesh(
device_type=device_type,
mesh_shape=(cp_config.ring_degree, cp_config.ulysses_degree),
mesh_dim_names=("ring", "ulysses"),
)
for module in self.modules():
if not isinstance(module, attention_classes):
continue

config.setup(rank, world_size, device, cp_mesh=cp_mesh)
processor = module.processor
if processor is None or not hasattr(processor, "_attention_backend"):
continue

if cp_plan is None and self._cp_plan is None:
raise ValueError(
"`cp_plan` must be provided either as an argument or set in the model's `_cp_plan` attribute."
)
cp_plan = cp_plan if cp_plan is not None else self._cp_plan
attention_backend = processor._attention_backend
if attention_backend is None:
attention_backend, _ = _AttentionBackendRegistry.get_active_backend()
else:
attention_backend = AttentionBackendName(attention_backend)

if not _AttentionBackendRegistry._is_context_parallel_available(attention_backend):
compatible_backends = sorted(_AttentionBackendRegistry._supports_context_parallel)
raise ValueError(
f"Context parallelism is enabled but the attention processor '{processor.__class__.__name__}' "
f"is using backend '{attention_backend.value}' which does not support context parallelism. "
f"Please set a compatible attention backend: {compatible_backends} using `model.set_attention_backend()` before "
f"calling `enable_parallelism()`."
)

# All modules use the same attention processor and backend. We don't need to
# iterate over all modules after checking the first processor
break

mesh = None
if config.context_parallel_config is not None:
apply_context_parallel(self, config.context_parallel_config, cp_plan)
cp_config = config.context_parallel_config
mesh = torch.distributed.device_mesh.init_device_mesh(
device_type=device_type,
mesh_shape=cp_config.mesh_shape,
mesh_dim_names=cp_config.mesh_dim_names,
)

config.setup(rank, world_size, device, mesh=mesh)
self._parallel_config = config

attention_classes = (Attention, MochiAttention, AttentionModuleMixin)
for module in self.modules():
if not isinstance(module, attention_classes):
continue
@@ -1544,6 +1557,14 @@ def enable_parallelism(
continue
processor._parallel_config = config

if config.context_parallel_config is not None:
if cp_plan is None and self._cp_plan is None:
raise ValueError(
"`cp_plan` must be provided either as an argument or set in the model's `_cp_plan` attribute."
)
cp_plan = cp_plan if cp_plan is not None else self._cp_plan
apply_context_parallel(self, config.context_parallel_config, cp_plan)

@classmethod
def _load_pretrained_model(
cls,
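Finally, a hedged end-to-end sketch of the reworked flow in enable_parallelism: pick a context-parallel-capable attention backend first, then enable parallelism. The model class, checkpoint, and backend name below are assumptions for illustration (not taken from this PR), and the ContextParallelConfig import path is a guess:

import torch
import torch.distributed as dist
from diffusers import FluxTransformer2DModel  # example model; any model with a _cp_plan (or an explicit cp_plan) works
from diffusers.models._modeling_parallel import ContextParallelConfig  # import path assumed

dist.init_process_group("nccl")
rank = dist.get_rank()
torch.cuda.set_device(rank % torch.cuda.device_count())

model = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer", torch_dtype=torch.bfloat16
).to("cuda")
# Assumption: "native" is registered with context-parallel support; if it isn't, the new
# error message lists the compatible backends to pass here instead.
model.set_attention_backend("native")
model.enable_parallelism(ContextParallelConfig(ring_degree=dist.get_world_size()))

The device mesh is now created inside enable_parallelism, so no manual init_device_mesh call is needed.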