Commit d35cdde

Add query-key normalization to CausalAttn and Attention classes, including learned scaling factor
1 parent b59afa0 commit d35cdde

3 files changed: +72, -18 lines

algoperf/workloads/lm/lm_jax/nanodo_model.py

Lines changed: 49 additions & 12 deletions
@@ -25,14 +25,19 @@ class ModelConfig:
   rmsnorm_epsilon: float = 1e-6
   use_residual_scaling: bool = True
   tie_embeddings: bool = True  # Whether to tie input and output embed
+  qknorm_epsilon: float = 1e-6

   dtype: jnp.dtype = jnp.float32
-  attention_init: nn.initializers.Initializer = nn.initializers.normal(stddev=0.02)
+  attention_init: nn.initializers.Initializer = nn.initializers.normal(
+    stddev=0.02
+  )
   linear_init: nn.initializers.Initializer = nn.initializers.normal(stddev=0.02)
   embed_init: nn.initializers.Initializer = nn.initializers.normal(stddev=0.02)

   def __post_init__(self):
-    self.residual_init = nn.initializers.normal(stddev=0.02/jnp.sqrt(2 * self.num_layers))
+    self.residual_init = nn.initializers.normal(
+      stddev=0.02 / jnp.sqrt(2 * self.num_layers)
+    )


 class Mlp(nn.Module):
@@ -43,7 +48,6 @@ class Mlp(nn.Module):
   @nn.compact
   def __call__(self, x_BxLxD: jax.Array):
     cfg = self.cfg
-    # Use Xavier uniform initialization explicitly
     linear = partial(
       nn.Dense, kernel_init=cfg.linear_init, use_bias=False, dtype=cfg.dtype
     )
@@ -58,7 +62,14 @@ def __call__(self, x_BxLxD: jax.Array):
     x_BxLx2F = linear(2 * hidden_dim)(x_BxLxD)
     # Apply GLU activation
     x_BxLxF = nn.glu(x_BxLx2F, axis=-1)
-    x_BxLxD = nn.Dense(cfg.model_dim, use_bias=False, dtype=cfg.dtype, kernel_init=cfg.residual_init if cfg.use_residual_scaling else cfg.linear_init)(x_BxLxF)
+    x_BxLxD = nn.Dense(
+      cfg.model_dim,
+      use_bias=False,
+      dtype=cfg.dtype,
+      kernel_init=cfg.residual_init
+      if cfg.use_residual_scaling
+      else cfg.linear_init,
+    )(x_BxLxF)
     return x_BxLxD

@@ -114,8 +125,11 @@ class CausalAttn(nn.Module):

   def setup(self):
     cfg = self.cfg
-    assert cfg.model_dim % cfg.num_heads == 0, f'D {cfg.model_dim} not divisible by H {cfg.num_heads}'
+    assert cfg.model_dim % cfg.num_heads == 0, (
+      f'D {cfg.model_dim} not divisible by H {cfg.num_heads}'
+    )
     self.Dh = cfg.model_dim // cfg.num_heads
+    self.eps = cfg.qknorm_epsilon

     # Initialize rotary embeddings
     self.freqs_cis = init_rope(cfg.model_dim, cfg.seq_len, cfg.num_heads)
@@ -129,15 +143,22 @@ def setup(self):
       use_bias=False,
       dtype=cfg.dtype,
     )
-
     self.multilinear_query = self.multilinear(name='query')
     self.multilinear_key = self.multilinear(name='key')
     self.multilinear_value = self.multilinear(name='value')
+    # See Henry et al. (2020) "Query Key Normalization for Transformers"
+    seq_len = cfg.seq_len
+    attn_scale0 = jnp.log2(seq_len**2 - seq_len)
+    self.attn_scale = self.param(
+      'attn_scale', nn.initializers.constant(attn_scale0), ()
+    )
     self.output_projection = nn.DenseGeneral(
       features=cfg.model_dim,
       name='attn_out_proj',
       # axis=(-2, -1), #
-      kernel_init=cfg.residual_init if cfg.use_residual_scaling else cfg.linear_init,
+      kernel_init=cfg.residual_init
+      if cfg.use_residual_scaling
+      else cfg.linear_init,
       use_bias=False,
       dtype=cfg.dtype,
     )
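
Note: the learned attn_scale is initialized to log2(L**2 - L), which grows only logarithmically with the sequence length L. A quick check of that starting value for a few lengths (128 is the demo seq_len used in main() further down; the loop itself is purely illustrative):

import jax.numpy as jnp

for seq_len in (128, 1024, 2048):
  # Initial value of the learned attn_scale parameter.
  print(seq_len, float(jnp.log2(seq_len**2 - seq_len)))
# 128  -> ~13.99
# 1024 -> ~20.00
# 2048 -> ~22.00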
@@ -153,8 +174,9 @@ def __call__(self, x_BxLxD: jax.Array):
     # Apply rotary embeddings to Q and K
     q_BxLxHxDh, k_BxLxHxDh = apply_rope(q_BxLxHxDh, k_BxLxHxDh, self.freqs_cis)

-    # Scale queries
-    q_BxLxHxDh /= self.Dh**0.5
+    # Apply QK normalization
+    q_BxLxHxDh /= jnp.linalg.norm(q_BxLxHxDh, axis=-1, keepdims=True) + self.eps
+    k_BxLxHxDh /= jnp.linalg.norm(k_BxLxHxDh, axis=-1, keepdims=True) + self.eps

     # Compute attention scores
     att_BxHxLxL = jnp.einsum('...qhd,...khd->...hqk', q_BxLxHxDh, k_BxLxHxDh)
@@ -166,6 +188,9 @@ def __call__(self, x_BxLxD: jax.Array):
     # Apply mask and softmax
     _NEG_INF = jnp.finfo(cfg.dtype).min
     att_BxHxLxL = jnp.where(mask_1x1xLxL, att_BxHxLxL, _NEG_INF)
+    att_BxHxLxL = (
+      self.attn_scale * att_BxHxLxL
+    )  # Learned scaling factor for QK norm
     att_BxHxLxL = jax.nn.softmax(att_BxHxLxL, axis=-1)
     att_BxHxLxL = att_BxHxLxL.astype(cfg.dtype)

@@ -227,7 +252,10 @@ def setup(self):
       self.output_proj = lambda x: self.embed.attend(x.astype(jnp.float32))
     else:
       self.output_proj = nn.Dense(
-        cfg.vocab_size, kernel_init=cfg.embed_init, dtype=cfg.dtype, name='output_proj'
+        cfg.vocab_size,
+        kernel_init=cfg.embed_init,
+        dtype=cfg.dtype,
+        name='output_proj',
       )

   def __call__(self, y_BxL: jax.Array):
@@ -270,7 +298,9 @@ def predict(self, y_BxL: jax.Array, k: int = 1):
     next_token_logits = logits[:, -1, :]
     last_token_id = y_BxL[:, -1]
     # Prevent predicting the same token consecutively
-    next_token_logits = next_token_logits.at[jnp.arange(len(last_token_id)), last_token_id].set(float('-inf'))
+    next_token_logits = next_token_logits.at[
+      jnp.arange(len(last_token_id)), last_token_id
+    ].set(float('-inf'))

     # Get the most likely token
     next_token = jnp.argmax(next_token_logits, axis=-1)
@@ -289,7 +319,14 @@ def main():
   """Create and run the DecoderOnly Transformer model."""
   # Initialize model configuration with smaller parameters for demo
   B, L = (2, 128)  # Batch size, sequence length
-  cfg = ModelConfig(model_dim=128, num_heads=4, seq_len=L, num_layers=2, vocab_size=256, expanded_model_dim=4 * 128)
+  cfg = ModelConfig(
+    model_dim=128,
+    num_heads=4,
+    seq_len=L,
+    num_layers=2,
+    vocab_size=256,
+    expanded_model_dim=4 * 128,
+  )
   model = TransformerDo(cfg)

   # Print model info
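
Note: taken together, the CausalAttn changes swap the usual 1/sqrt(Dh) query scaling for query-key normalization: q and k are L2-normalized along the head dimension, the raw attention scores become cosine similarities bounded in [-1, 1], and the learned attn_scale restores a usable softmax temperature. A minimal standalone sketch of that score path, assuming plain jax.numpy arrays rather than the Flax module (the helper name qknorm_attention_scores and the random inputs are illustrative only):

import jax
import jax.numpy as jnp

def qknorm_attention_scores(q_BxLxHxDh, k_BxLxHxDh, attn_scale, eps=1e-6):
  # L2-normalize queries and keys over the head dimension.
  q = q_BxLxHxDh / (jnp.linalg.norm(q_BxLxHxDh, axis=-1, keepdims=True) + eps)
  k = k_BxLxHxDh / (jnp.linalg.norm(k_BxLxHxDh, axis=-1, keepdims=True) + eps)
  # Raw scores are cosine similarities, so they lie in [-1, 1].
  att_BxHxLxL = jnp.einsum('...qhd,...khd->...hqk', q, k)
  # The learned scalar sets the softmax temperature.
  return attn_scale * att_BxHxLxL

B, L, H, Dh = 2, 128, 4, 32
kq, kk = jax.random.split(jax.random.PRNGKey(0))
q = jax.random.normal(kq, (B, L, H, Dh))
k = jax.random.normal(kk, (B, L, H, Dh))
scores = qknorm_attention_scores(q, k, attn_scale=jnp.log2(L**2 - L))
print(scores.shape)                  # (2, 4, 128, 128)
print(float(jnp.abs(scores).max()))  # bounded by log2(L**2 - L) ~= 13.99

In the module itself the scaling is applied after the causal mask, which only makes the already-masked entries more negative before the softmax, so the mask still behaves as intended.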

algoperf/workloads/lm/lm_pytorch/plainlm_model.py

Lines changed: 21 additions & 5 deletions
@@ -23,6 +23,7 @@ class ModelConfig:
   expanded_model_dim: int
   multiple_of: int = 256
   rmsnorm_epsilon: float = 1e-6
+  qknorm_epsilon: float = 1e-6
   use_residual_scaling: bool = True
   tie_embeddings: bool = True

@@ -92,9 +93,14 @@ def __init__(self, cfg: ModelConfig):
     # Split into Q, K, V sections
     wq, wk, wv = torch.chunk(self.w_qkv.weight, 3, dim=0)
     for w in [wq, wk, wv]:
-        nn.init.normal_(w, std=0.02)
+      nn.init.normal_(w, std=0.02)
     nn.init.normal_(self.w_out.weight, std=0.02)

+    self.eps = cfg.qknorm_epsilon  # e.g., 1e-6
+    seq_len = cfg.seq_len
+    attn_scale0 = math.log2(seq_len**2 - seq_len)
+    self.attn_scale = nn.Parameter(torch.tensor(attn_scale0))
+
   def forward(self, x, freqs_cis):
     bsz, seqlen, d = x.shape  # (bsz, seqlen, d)

@@ -117,10 +123,14 @@ def forward(self, x, freqs_cis):
     k = k.transpose(1, 2)  # (bsz, nh, seqlen, h_dim)
     v = v.transpose(1, 2)  # (bsz, nh, seqlen, h_dim)

+    # Apply QK normalization (eps in the denominator guards against zero norms)
+    q = q / (torch.norm(q, dim=-1, keepdim=True) + self.eps)
+    k = k / (torch.norm(k, dim=-1, keepdim=True) + self.eps)
+    q = q * self.attn_scale
+
     out = F.scaled_dot_product_attention(
-      q, k, v, is_causal=True
+      q, k, v, is_causal=True, scale=1.0
     )  # (bsz, nh, seqlen, h_dim)
-
     out = (
       out.transpose(1, 2).contiguous().view(bsz, seqlen, d)
     )  # (bsz, seqlen, d)
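
Note: scale=1.0 disables the default 1/sqrt(head_dim) factor inside F.scaled_dot_product_attention (the scale keyword is only available in newer PyTorch releases); with QK normalization the temperature is instead carried by attn_scale, which is folded into q before the call. A small equivalence check, with made-up standalone tensors in place of the module's projections:

import math
import torch
import torch.nn.functional as F

bsz, nh, seqlen, h_dim = 2, 4, 128, 32
eps = 1e-6
attn_scale = math.log2(seqlen**2 - seqlen)  # initial value of the learned scale

q = torch.randn(bsz, nh, seqlen, h_dim)
k = torch.randn(bsz, nh, seqlen, h_dim)
v = torch.randn(bsz, nh, seqlen, h_dim)

# L2-normalize q and k over the head dimension.
q = q / (torch.norm(q, dim=-1, keepdim=True) + eps)
k = k / (torch.norm(k, dim=-1, keepdim=True) + eps)

# Option A (as in the diff): fold the factor into q and disable SDPA's own scaling.
out_a = F.scaled_dot_product_attention(q * attn_scale, k, v, is_causal=True, scale=1.0)

# Option B: leave q alone and pass the same factor as SDPA's scale argument.
out_b = F.scaled_dot_product_attention(q, k, v, is_causal=True, scale=attn_scale)

print(torch.allclose(out_a, out_b, atol=1e-5))  # True: the two scalings match

Folding the factor into q is the natural choice here because SDPA's scale argument is documented as a plain Python float, while attn_scale is a trainable nn.Parameter.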
@@ -133,7 +143,11 @@ def __init__(self, layer_id: int, cfg: ModelConfig):
     super().__init__()
     self.attn = Attention(cfg)
     self.attn_norm = nn.RMSNorm(cfg.model_dim, eps=cfg.rmsnorm_epsilon)
-    self.mlp = MLP(dim=cfg.model_dim, hidden_dim=cfg.expanded_model_dim, multiple_of=cfg.multiple_of)
+    self.mlp = MLP(
+      dim=cfg.model_dim,
+      hidden_dim=cfg.expanded_model_dim,
+      multiple_of=cfg.multiple_of,
+    )
     self.mlp_norm = nn.RMSNorm(cfg.model_dim, eps=cfg.rmsnorm_epsilon)
     self.layer_id = layer_id

@@ -263,7 +277,9 @@ def _init_weights(self, module):

   def _scale_residual_branches(self):
     for n, p in self.named_parameters():
-      if n.endswith('fc2.weight') or n.endswith('w_out.weight'):  # mlp/glu output layer
+      if n.endswith('fc2.weight') or n.endswith(
+        'w_out.weight'
+      ):  # mlp/glu output layer
         torch.nn.init.normal_(
           p, mean=0.0, std=0.02 / math.sqrt(2 * self.n_layers)
         )

tests/modeldiffs/lm/compare.py

Lines changed: 2 additions & 1 deletion
@@ -644,7 +644,8 @@ def test_initialization_statistics():
     num_layers=12,
     vocab_size=50000,
     expanded_model_dim=2048,
-    dtype=jnp.float32)
+    dtype=jnp.float32,
+  )
   jax_model = TransformerDo(jax_cfg)
   jax_params = jax_model.init(
     jax.random.PRNGKey(42), jnp.ones((1, 10), dtype=jnp.int32)
