161 changes: 160 additions & 1 deletion keras/src/backend/jax/distribution_lib.py
@@ -1,4 +1,14 @@
"""Utilities for distribution strategy with JAX backend."""
"""Utilities for distribution strategy with JAX backend.

This file contains the core JAX distribution primitives for Keras,
along with higher-level device management and auto-configuration
utilities. Helpers validate inputs explicitly rather than relying on
try-except blocks for error handling.
"""

import logging
from typing import Any
from typing import Dict
from typing import List
from typing import Optional

import jax
import numpy as np
@@ -8,6 +18,8 @@
from keras.src.utils import jax_utils
from keras.src.utils import rng_utils

logger = logging.getLogger(__name__)


def list_devices(device_type=None):
"""Return all the available devices based on the device type.
@@ -27,6 +39,153 @@ def list_devices(device_type=None):
return [f"{device.platform}:{device.id}" for device in jax_devices]


def get_device_info(device_id: str) -> Dict[str, Any]:
"""Get detailed information about a specific device.

Args:
device_id: Device identifier (e.g., 'gpu:0', 'tpu:0', 'cpu:0')

Returns:
Dictionary containing device information
"""
device_info = {
"id": device_id,
"type": None,
"index": None,
"memory": None,
"capabilities": None,
}

if ":" not in device_id:
raise ValueError(
f"Malformed device identifier: {device_id!r}. Expected "
"'<type>:<index>', e.g. 'gpu:0'."
)
device_type, device_index = device_id.split(":")
device_info["type"] = device_type.upper()
device_info["index"] = int(device_index)

return device_info


def get_best_devices(count: int = 1) -> List[str]:
"""Get the best available devices for tensor parallelism.

Currently this simply returns the first `count` entries reported by
`list_devices()`.

Args:
count: Number of devices needed

Returns:
List of best device identifiers
"""
all_devices = list_devices()

if count <= 0:
return []

if count > len(all_devices):
logger.warning(
f"Requested {count} devices but only {len(all_devices)} available"
)
count = len(all_devices)

return all_devices[:count]
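
# A minimal usage sketch for the helpers above (hypothetical session;
# actual device strings depend on the host):
#
#   >>> devices = get_best_devices(count=2)
#   >>> devices
#   ['gpu:0', 'gpu:1']
#   >>> get_device_info(devices[0])
#   {'id': 'gpu:0', 'type': 'GPU', 'index': 0, 'memory': None,
#    'capabilities': None}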


def get_device_backend(device_type: str) -> str:
"""Get the recommended backend for a device type.

Args:
device_type: Device type ('tpu', 'gpu', 'cpu')

Returns:
Recommended backend name
"""
backend_mapping = {"tpu": "jax", "gpu": "jax", "cpu": "jax"}

return backend_mapping.get(device_type.lower(), "jax")


def validate_device_placement(device_id: str) -> bool:
"""Check whether a device is present and usable for tensor operations.

Args:
device_id: Device identifier

Returns:
True if device is valid and available
"""
all_devices = list_devices()
return device_id in all_devices


def get_device_memory_info(device_id: str) -> Optional[Dict[str, Any]]:
"""Get memory information for a device (if available).

Note that the returned `memory` values are static placeholders; JAX
does not expose a uniform per-device memory query here.

Args:
device_id: Device identifier

Returns:
Memory information dictionary or None if not available
"""
if device_id.startswith("gpu:"):
return {
"type": "GPU",
"index": int(device_id.split(":")[1]),
"memory": "Available",
}
elif device_id.startswith("tpu:"):
return {
"type": "TPU",
"index": int(device_id.split(":")[1]),
"memory": "TPU Memory",
}
elif device_id.startswith("cpu:"):
return {
"type": "CPU",
"index": int(device_id.split(":")[1]),
"memory": "System RAM",
}

return None
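
# Sketch: gating work on device availability (hypothetical session;
# results depend on the host):
#
#   >>> if validate_device_placement("gpu:0"):
#   ...     get_device_memory_info("gpu:0")
#   {'type': 'GPU', 'index': 0, 'memory': 'Available'}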


def auto_configure_tensor_parallel(
world_size: Optional[int] = None, backend: Optional[str] = None
) -> Dict[str, Any]:
"""Automatically configure tensor parallelism with the best devices.

Args:
world_size: Number of devices to use (if None, uses all available)
backend: Backend to use (if None, will be set to 'jax')

Returns:
Configuration dictionary with devices, backend, and other settings
"""
all_devices = list_devices()

if not all_devices:
raise RuntimeError("No devices available for tensor parallelism")

if world_size is None:
world_size = len(all_devices)
else:
world_size = min(world_size, len(all_devices))

selected_devices = all_devices[:world_size]

recommended_backend = backend or "jax"

config = {
"devices": selected_devices,
"world_size": world_size,
"backend": recommended_backend,
}

logger.info(f"Auto-configured tensor parallelism: {config}")
return config
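
# Sketch: on a hypothetical 2-GPU host, the auto-configuration would
# yield something like:
#
#   >>> auto_configure_tensor_parallel(world_size=2)
#   {'devices': ['gpu:0', 'gpu:1'], 'world_size': 2, 'backend': 'jax'}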


def distribute_variable(value, layout):
"""Create a distributed variable for JAX.

198 changes: 198 additions & 0 deletions keras/src/distribution/distribution_lib.py
@@ -39,6 +39,24 @@ def list_devices(device_type=None):
return distribution_lib.list_devices(device_type)


@keras_export("keras.distribution.get_best_devices")
def get_best_devices(count):
"""Return all the available devices based on the device type.

Note: in a distributed setting, global devices are returned.

Args:
device_type: string, one of `"cpu"`, `"gpu"` or `"tpu"`.
Defaults to `"gpu"` or `"tpu"` if available when
`device_type` is not provided. Otherwise
will return the `"cpu"` devices.

Return:
List of devices that are available for distribute computation.
"""
return distribution_lib.get_best_devices(count)
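
# Sketch (hypothetical session): the exported symbol simply forwards
# to the active backend's implementation.
#
#   >>> keras.distribution.get_best_devices(1)
#   ['gpu:0']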


@keras_export("keras.distribution.initialize")
def initialize(job_addresses=None, num_processes=None, process_id=None):
"""Initialize the distribution system for multi-host/process setting.
@@ -896,3 +914,183 @@ def set_distribution(value):
value: a `Distribution` instance.
"""
global_state.set_global_attribute(GLOBAL_ATTRIBUTE_NAME, value)


@keras_export("keras.distribution.AutoTPDistribution")
class AutoTPDistribution(Distribution):
"""A distribution strategy for automated tensor and data parallelism.

This distribution strategy provides a high-level abstraction for combining
both data parallelism and tensor parallelism. It automatically shards a
Keras model's layers across multiple devices (tensor parallelism) while also
distributing the input data across those devices (data parallelism).

It uses a `DeviceMesh` to represent the grid of computational devices. If no
mesh is provided, it creates one using all available devices. The mesh must
have a 'data' axis for data sharding and a 'model' axis for model sharding.

Internally, this class wraps the user-provided Keras `Model` with the
`TensorParallelKeras` utility to handle the model sharding.

Args:
model: A `keras.Model` instance to be distributed.
device_mesh: (Optional) A `keras.distribution.DeviceMesh` instance.
If not provided, a `DeviceMesh` will be automatically created using
all available devices, arranging them for both data and model
parallelism.
auto_shard_dataset: (Optional) A boolean indicating whether to
automatically shard `tf.data.Dataset` instances across multiple
processes. Defaults to `True`.

Attributes:
model: The wrapped, tensor-parallel `keras.Model` instance that is
ready for distributed training.
device_mesh: The `DeviceMesh` instance used for distribution.

Raises:
RuntimeError: If no computational devices are found and `device_mesh`
is not provided.
ValueError: If the provided `device_mesh` does not have a 'data' axis.

Example:

```python
# Create a simple Keras model
inputs = keras.Input(shape=(64,))
x = keras.layers.Dense(128, activation="relu")(inputs)
outputs = keras.layers.Dense(10)(x)
model = keras.Model(inputs=inputs, outputs=outputs)

# Create the distribution strategy with the model
# It will automatically use all available GPUs/TPUs
distribution = keras.distribution.AutoTPDistribution(model)

# The distributed model is accessed via the .model attribute
distributed_model = distribution.model

# Compile the model as usual
distributed_model.compile(optimizer="adam", loss="mse")

# Prepare a dataset
input_data = np.random.rand(32, 64)
target_data = np.random.rand(32, 10)

# Train the model
distributed_model.fit(input_data, target_data)
```
"""

def __init__(self, model, device_mesh=None, auto_shard_dataset=True):
if device_mesh is None:
all_devices = list_devices()
if not all_devices:
raise RuntimeError("No computational devices found.")
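# Default mesh: one 'data' replica with every device on the 'model'
# axis, i.e. pure tensor parallelism across all available devices.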
device_mesh = DeviceMesh(
shape=(1, len(all_devices)),
axis_names=("data", "model"),
devices=all_devices,
)

if "data" not in device_mesh.axis_names:
raise ValueError(
"DeviceMesh for AutoTPDistribution must have a 'data' axis."
)
batch_dim_name = "data"

super().__init__(device_mesh, batch_dim_name, auto_shard_dataset)

self._original_model = model
self._num_process = distribution_lib.num_processes()
self._process_id = distribution_lib.process_id()
self._is_multi_process = self._num_process > 1
from keras.src.distribution.tensor_parallel.tensor_parallel import (
TensorParallelKeras,
)

self.model = TensorParallelKeras(
model=self._original_model,
world_size=np.prod(self.device_mesh.shape),
device_ids=self.device_mesh.devices.flatten().tolist(),
)

def get_data_layout(self, data_shape):
data_shard_spec = [None] * len(data_shape)
data_shard_spec[0] = self.batch_dim_name
return TensorLayout(data_shard_spec, self.device_mesh)

def get_variable_layout(self, variable):
warnings.warn(
"Variable layout is determined automatically within "
"AutoTPDistribution. This method will return a replicated layout."
)
return TensorLayout([None] * len(variable.shape), self.device_mesh)

def get_tensor_layout(self, path):
return None

def distribute_dataset(self, dataset):
"""Distributes the dataset across processes based on the device mesh."""
if not self._is_multi_process or not self.auto_shard_dataset:
return dataset

from keras.src.utils.module_utils import tensorflow as tf

if not tf.available or not isinstance(dataset, tf.data.Dataset):
raise ValueError(
"Only `tf.data.Dataset` is supported for auto-sharding, "
f"got {type(dataset)}"
)

from tensorflow.python.data.experimental.ops import (
distribute as tf_data_distribute,
)

global_batch_size = tf_data_distribute.compute_batch_size(dataset)
if global_batch_size.numpy() < 0:
raise ValueError(
"The batch size of the input dataset is unknown. "
"Please configure the batch size for the input dataset, "
"e.g., via `dataset.batch(batch_size)`"
)

mesh_batch_dim_index = self.device_mesh.axis_names.index(
self.batch_dim_name
)
num_model_replicas = self.device_mesh.shape[mesh_batch_dim_index]

if num_model_replicas == 1:
return dataset.prefetch(tf.data.AUTOTUNE)
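
# When num_model_replicas >= num_processes, each process hosts one or
# more full replicas and the global batch is sharded per process.
# Otherwise several processes together serve a single replica, so the
# dataset is sharded per replica instead.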

num_model_replicas_per_process = num_model_replicas / self._num_process
if num_model_replicas_per_process >= 1:
if global_batch_size % self._num_process != 0:
raise ValueError(
"Global batch size must be divisible by the number of "
f"processes. `global_batch_size`={global_batch_size} and "
f"`num_process`={self._num_process}"
)
per_process_batch_size = global_batch_size // self._num_process
distributed_dataset = dataset.rebatch(per_process_batch_size)
distributed_dataset = distributed_dataset.shard(
num_shards=self._num_process,
index=self._process_id,
)
return distributed_dataset.prefetch(tf.data.AUTOTUNE)
else:
if global_batch_size % num_model_replicas != 0:
raise ValueError(
"Global batch size must be divisible by the number of "
f"replicas. `global_batch_size`={global_batch_size} and "
f"`num_model_replicas`={num_model_replicas}"
)
per_replica_batch_size = global_batch_size // num_model_replicas
distributed_dataset = dataset.rebatch(per_replica_batch_size)

processes_per_replica = self._num_process // num_model_replicas
data_shard_id = self._process_id // processes_per_replica

distributed_dataset = distributed_dataset.shard(
num_shards=num_model_replicas,
index=data_shard_id,
)
return distributed_dataset.prefetch(tf.data.AUTOTUNE)
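
# End-to-end usage sketch (assumes TensorFlow is available and the
# dataset is batched; names mirror the class docstring example):
#
#   import numpy as np
#   import tensorflow as tf
#
#   distribution = AutoTPDistribution(model)
#   dataset = tf.data.Dataset.from_tensor_slices(
#       (np.random.rand(64, 64).astype("float32"),
#        np.random.rand(64, 10).astype("float32"))
#   ).batch(16)
#   dataset = distribution.distribute_dataset(dataset)
#   distribution.model.compile(optimizer="adam", loss="mse")
#   distribution.model.fit(dataset)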