diff --git a/torchtitan/models/llama3/__init__.py b/torchtitan/models/llama3/__init__.py index f9e780445..16b5e9a9d 100644 --- a/torchtitan/models/llama3/__init__.py +++ b/torchtitan/models/llama3/__init__.py @@ -48,6 +48,17 @@ multiple_of=1024, rope_theta=500000, ), + "8B_flex_attn": TransformerModelArgs( + dim=4096, + n_layers=32, + n_heads=32, + n_kv_heads=8, + ffn_dim_multiplier=1.3, + multiple_of=1024, + rope_theta=500000, + use_flex_attn=True, + attn_mask_type="block_causal", + ), "70B": TransformerModelArgs( dim=8192, n_layers=80,