diff --git a/gpt_oss/torch/utils.py b/gpt_oss/torch/utils.py
index ce87a85d..40d61fa2 100644
--- a/gpt_oss/torch/utils.py
+++ b/gpt_oss/torch/utils.py
@@ -23,18 +23,34 @@ def init_distributed() -> torch.device:
     # Initialize distributed inference
     world_size = int(os.environ.get("WORLD_SIZE", 1))
     rank = int(os.environ.get("RANK", 0))
+    xpu_available = hasattr(torch, "xpu") and torch.xpu.is_available()
+
+    if xpu_available:
+        backend = "xccl"
+        device_type = "xpu"
+    else:
+        backend = "nccl"
+        device_type = "cuda"
+
     if world_size > 1:
         dist.init_process_group(
-            backend="nccl", init_method="env://", world_size=world_size, rank=rank
+            backend=backend, init_method="env://", world_size=world_size, rank=rank
         )
-    torch.cuda.set_device(rank)
-    device = torch.device(f"cuda:{rank}")
-    # Warm up NCCL to avoid first-time latency
+    if xpu_available:
+        torch.xpu.set_device(rank)
+    else:
+        torch.cuda.set_device(rank)
+    device = torch.device(f"{device_type}:{rank}")
+
+    # Warm up backend to avoid first-time latency
     if world_size > 1:
         x = torch.ones(1, device=device)
         dist.all_reduce(x)
-        torch.cuda.synchronize(device)
+        if xpu_available:
+            torch.xpu.synchronize(device)
+        else:
+            torch.cuda.synchronize(device)
     suppress_output(rank)
     return device
diff --git a/gpt_oss/triton/attention.py b/gpt_oss/triton/attention.py
index bf689055..969b9a18 100644
--- a/gpt_oss/triton/attention.py
+++ b/gpt_oss/triton/attention.py
@@ -59,9 +59,10 @@ def _attn_fwd(
     q = Q.load([off_z, off_h, start_m * BLOCK_M, 0]).reshape([BLOCK_M, HEAD_DIM])
 
     if BANDWIDTH:
-        lo, hi = tl.maximum(start_q, start_q + start_m * BLOCK_M - BANDWIDTH), start_q + (start_m + 1) * BLOCK_M
+        lo, hi = tl.maximum(0, start_q + start_m * BLOCK_M - BANDWIDTH + 1), start_q + (start_m + 1) * BLOCK_M
     else:
-        lo, hi = start_q, start_q + (start_m + 1) * BLOCK_M
+        lo, hi = 0, start_q + (start_m + 1) * BLOCK_M
+    hi = tl.minimum(N_KV_CTX, hi)
 
     for start_n in range(lo, hi, BLOCK_N):
         start_n = tl.multiple_of(start_n, BLOCK_N)
@@ -216,12 +217,17 @@ def test_eq(batch_size, num_queries, num_keys, num_key_value_heads, num_key_valu
     if num_queries > num_keys:
         pytest.skip("too many queries")
 
-    q = torch.randn(batch_size, num_queries, num_key_value_heads, num_key_value_groups, head_dim).bfloat16().cuda()
-    k = torch.randn(batch_size, num_keys, num_key_value_heads, head_dim).bfloat16().cuda()
-    v = torch.randn(batch_size, num_keys, num_key_value_heads, head_dim).bfloat16().cuda()
-    sinks = torch.randn(num_key_value_heads * num_key_value_groups).bfloat16().cuda()
+    if torch.xpu.is_available():
+        device = "xpu"
+    else:
+        device = "cuda"
+
+    q = torch.randn(batch_size, num_queries, num_key_value_heads, num_key_value_groups, head_dim).bfloat16().to(device)
+    k = torch.randn(batch_size, num_keys, num_key_value_heads, head_dim).bfloat16().to(device)
+    v = torch.randn(batch_size, num_keys, num_key_value_heads, head_dim).bfloat16().to(device)
+    sinks = torch.randn(num_key_value_heads * num_key_value_groups).bfloat16().to(device)
 
-    start_q = torch.tensor([start_q], dtype=torch.int32).cuda()
+    start_q = torch.tensor([start_q], dtype=torch.int32).to(device)
 
     o1 = attention(q, k, v, sinks, sm_scale, sliding_window, start_q)
     o2 = attention_ref(q, k, v, sinks, sm_scale, sliding_window, start_q)
diff --git a/gpt_oss/triton/model.py b/gpt_oss/triton/model.py
index 2e14478f..83f93dad 100644
--- a/gpt_oss/triton/model.py
+++ b/gpt_oss/triton/model.py
@@ -437,7 +437,10 @@ def from_checkpoint(
     checkpoint = Checkpoint(path, device)
 
     for name, param in model.named_parameters():
-        torch.cuda.empty_cache()
+        if device.type == "xpu":
+            torch.xpu.empty_cache()
+        else:
+            torch.cuda.empty_cache()
         loaded_tensor = checkpoint.get(name)
 
         if "mlp1" in name:
@@ -463,7 +466,10 @@ def from_checkpoint(
             param.data.copy_(loaded_tensor)
 
     # NOTE: Required to avoid OOM errors
-    torch.cuda.empty_cache()
+    if device.type == "xpu":
+        torch.xpu.empty_cache()
+    else:
+        torch.cuda.empty_cache()
 
     return model
 
@@ -476,10 +482,15 @@ def __init__(self, checkpoint: str, context: int, device: torch.device):
         self.input_token = torch.zeros(1, dtype=torch.int32, device=self.device)
         # warmup
         self.model(self.input_token[None, :], caches=self.caches)
-        # capture for sampling
-        self.graph = torch.cuda.CUDAGraph()
-        with torch.cuda.graph(self.graph):
+
+        if self.device.type == "xpu":
+            self.graph = None
             self.logits = self.model(self.input_token[None, :], caches=self.caches)[0]
+        else:
+            # capture for sampling
+            self.graph = torch.cuda.CUDAGraph()
+            with torch.cuda.graph(self.graph):
+                self.logits = self.model(self.input_token[None, :], caches=self.caches)[0]
 
     @torch.inference_mode()
     def generate(self,
@@ -497,7 +508,10 @@ def generate(self,
         num_generated_tokens = 0
         while max_tokens == 0 or num_generated_tokens < max_tokens:
             self.input_token[0] = predicted_token
-            self.graph.replay()
+            if self.graph is not None:
+                self.graph.replay()
+            else:
+                self.logits = self.model(self.input_token[None, :], caches=self.caches)[0]
             if temperature == 0.0:
                 predicted_token = torch.argmax(self.logits[-1, :], dim=-1).item()
             else:
diff --git a/gpt_oss/triton/moe.py b/gpt_oss/triton/moe.py
index 925dbd54..86d915ed 100644
--- a/gpt_oss/triton/moe.py
+++ b/gpt_oss/triton/moe.py
@@ -9,13 +9,14 @@
 from triton_kernels.numerics import InFlexData
 from triton_kernels.routing import routing
 from triton_kernels.tensor import convert_layout
-from triton_kernels.tensor_details.layout import StridedLayout, HopperMXScaleLayout, HopperMXValueLayout
+from triton_kernels.tensor_details.layout import StridedLayout, HopperMXScaleLayout, HopperMXValueLayout, make_default_matmul_mxfp4_w_layout
 from triton_kernels.tensor import wrap_torch_tensor, FP4
 
 
 def quantize_mx4(w):
     w, w_scale = downcast_to_mxfp(w.to(torch.bfloat16), torch.uint8, axis=1)
-    w = convert_layout(wrap_torch_tensor(w, dtype=FP4), HopperMXValueLayout, mx_axis=1)
+    w_layout_cls, w_layout_kwargs = make_default_matmul_mxfp4_w_layout(mx_axis=1)
+    w = convert_layout(wrap_torch_tensor(w, dtype=FP4), w_layout_cls, **w_layout_kwargs)
     w_scale = convert_layout(wrap_torch_tensor(w_scale), StridedLayout)
     return w, w_scale
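For reference, the XPU/CUDA dispatch that this patch repeats inline across utils.py and model.py follows a single pattern. The sketch below is illustrative only and is not part of the diff; the helper names (empty_device_cache, synchronize_device) are hypothetical, and it assumes a PyTorch build that exposes torch.xpu, which the patched code itself checks for.

# Illustrative sketch only -- not part of the patch above.
import torch

def empty_device_cache(device: torch.device) -> None:
    # Mirrors the empty_cache branches added in from_checkpoint().
    if device.type == "xpu":
        torch.xpu.empty_cache()
    else:
        torch.cuda.empty_cache()

def synchronize_device(device: torch.device) -> None:
    # Mirrors the warm-up synchronization branch added in init_distributed().
    if device.type == "xpu":
        torch.xpu.synchronize(device)
    else:
        torch.cuda.synchronize(device)

With helpers like these, the two from_checkpoint() hunks above could call empty_device_cache(device) instead of branching on device.type at each site.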