Commit 41eff53

Disable FlexAttention max-autotune when deterministic mode is used (#1808)
With max-autotune, FlexAttention is not deterministic even if torch.use_deterministic_algorithms(True) is set, because the autotuning search may select a different kernel on each run. When deterministic mode is requested, we should therefore also drop `max-autotune` when compiling FlexAttention.
1 parent: a6f0cfc
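
The gist of the change: the kernel behind FlexAttention is produced by torch.compile, and max-autotune runs a benchmarking search whose winning kernel can vary from run to run. Below is a minimal sketch of the before/after; the mode string on the default path is an assumption for illustration, not taken from this diff.

import torch
from torch.nn.attention.flex_attention import flex_attention

# Assumed default path: autotuning benchmarks candidate kernels and picks the
# fastest, but the pick is not guaranteed to be the same on every run.
flex_attn = torch.compile(flex_attention, mode="max-autotune")

# What this commit switches to when determinism is requested: compile with the
# default mode, so the same kernel is generated on every run.
deterministic = True  # e.g. taken from the job config
if deterministic:
    flex_attn = torch.compile(flex_attention)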

File tree

1 file changed, 8 insertions(+), 0 deletions(-)

torchtitan/distributed/utils.py

Lines changed: 8 additions & 0 deletions
@@ -106,6 +106,14 @@ def set_determinism(
         # https://pytorch.org/docs/stable/generated/torch.use_deterministic_algorithms.html
         os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
 
+        # Ensure flex_attention is compiled without max-autotune. This is needed to ensure
+        # reproducibility, since the autotune results may not be deterministic.
+        from torch.nn.attention.flex_attention import flex_attention
+
+        from torchtitan.models.attention import FlexAttention
+
+        FlexAttention.flex_attn = torch.compile(flex_attention)
+
     if not world_mesh:
         if seed is not None:
             torch.manual_seed(seed)
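
As a rough sanity check (a sketch, not part of the commit): running the snippet below twice should print the same checksum with the plain torch.compile path, whereas a max-autotune build may not reproduce across runs because its kernel selection can differ. It requires a CUDA device, and the tensor shapes are arbitrary.

import os

import torch
from torch.nn.attention.flex_attention import flex_attention

# Mirror the set_determinism() context from the diff above.
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
torch.use_deterministic_algorithms(True)
torch.manual_seed(0)

flex_attn = torch.compile(flex_attention)  # deterministic path: no max-autotune

# (batch, heads, seq_len, head_dim); seq_len kept at the default block size.
q, k, v = (torch.randn(1, 2, 128, 64, device="cuda") for _ in range(3))
print(flex_attn(q, k, v).double().sum().item())  # should match run to run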
