@@ -25,14 +25,19 @@ class ModelConfig:
     rmsnorm_epsilon: float = 1e-6
     use_residual_scaling: bool = True
     tie_embeddings: bool = True  # Whether to tie input and output embed
+    qknorm_epsilon: float = 1e-6
 
     dtype: jnp.dtype = jnp.float32
-    attention_init: nn.initializers.Initializer = nn.initializers.normal(stddev=0.02)
+    attention_init: nn.initializers.Initializer = nn.initializers.normal(
+        stddev=0.02
+    )
     linear_init: nn.initializers.Initializer = nn.initializers.normal(stddev=0.02)
     embed_init: nn.initializers.Initializer = nn.initializers.normal(stddev=0.02)
 
     def __post_init__(self):
-        self.residual_init = nn.initializers.normal(stddev=0.02 / jnp.sqrt(2 * self.num_layers))
+        self.residual_init = nn.initializers.normal(
+            stddev=0.02 / jnp.sqrt(2 * self.num_layers)
+        )
 
 
 class Mlp(nn.Module):
@@ -43,7 +48,6 @@ class Mlp(nn.Module):
     @nn.compact
     def __call__(self, x_BxLxD: jax.Array):
         cfg = self.cfg
-        # Use Xavier uniform initialization explicitly
         linear = partial(
             nn.Dense, kernel_init=cfg.linear_init, use_bias=False, dtype=cfg.dtype
         )
@@ -58,7 +62,14 @@ def __call__(self, x_BxLxD: jax.Array):
         x_BxLx2F = linear(2 * hidden_dim)(x_BxLxD)
         # Apply GLU activation
         x_BxLxF = nn.glu(x_BxLx2F, axis=-1)
-        x_BxLxD = nn.Dense(cfg.model_dim, use_bias=False, dtype=cfg.dtype, kernel_init=cfg.residual_init if cfg.use_residual_scaling else cfg.linear_init)(x_BxLxF)
+        x_BxLxD = nn.Dense(
+            cfg.model_dim,
+            use_bias=False,
+            dtype=cfg.dtype,
+            kernel_init=cfg.residual_init
+            if cfg.use_residual_scaling
+            else cfg.linear_init,
+        )(x_BxLxF)
         return x_BxLxD
 
 
@@ -114,8 +125,11 @@ class CausalAttn(nn.Module):
 
     def setup(self):
         cfg = self.cfg
-        assert cfg.model_dim % cfg.num_heads == 0, f'D {cfg.model_dim} not divisible by H {cfg.num_heads}'
+        assert cfg.model_dim % cfg.num_heads == 0, (
+            f'D {cfg.model_dim} not divisible by H {cfg.num_heads}'
+        )
         self.Dh = cfg.model_dim // cfg.num_heads
+        self.eps = cfg.qknorm_epsilon
 
         # Initialize rotary embeddings
         self.freqs_cis = init_rope(cfg.model_dim, cfg.seq_len, cfg.num_heads)
@@ -129,15 +143,22 @@ def setup(self):
             use_bias=False,
             dtype=cfg.dtype,
         )
-
         self.multilinear_query = self.multilinear(name='query')
         self.multilinear_key = self.multilinear(name='key')
         self.multilinear_value = self.multilinear(name='value')
+        # See Henry et al. (2020) "Query Key Normalization for Transformers"
+        seq_len = cfg.seq_len
+        attn_scale0 = jnp.log2(seq_len**2 - seq_len)
+        self.attn_scale = self.param(
+            'attn_scale', nn.initializers.constant(attn_scale0), ()
+        )
         self.output_projection = nn.DenseGeneral(
             features=cfg.model_dim,
             name='attn_out_proj',
             # axis=(-2, -1), #
-            kernel_init=cfg.residual_init if cfg.use_residual_scaling else cfg.linear_init,
+            kernel_init=cfg.residual_init
+            if cfg.use_residual_scaling
+            else cfg.linear_init,
             use_bias=False,
             dtype=cfg.dtype,
         )
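
Note on the new attn_scale parameter: following Henry et al. (2020), queries and keys are L2-normalized later in this commit, so the fixed 1/sqrt(Dh) factor no longer applies and a single trainable temperature, initialized to log2(L^2 - L), takes its place. A quick sanity check of the initial value, assuming the demo seq_len of 128 used in main() further down (an illustrative sketch, not part of the commit):

import jax.numpy as jnp

seq_len = 128                                  # demo value from main()
attn_scale0 = jnp.log2(seq_len**2 - seq_len)   # log2(16256) ~= 13.99
print(attn_scale0)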
@@ -153,8 +174,9 @@ def __call__(self, x_BxLxD: jax.Array):
         # Apply rotary embeddings to Q and K
         q_BxLxHxDh, k_BxLxHxDh = apply_rope(q_BxLxHxDh, k_BxLxHxDh, self.freqs_cis)
 
-        # Scale queries
-        q_BxLxHxDh /= self.Dh**0.5
+        # Apply QK normalization
+        q_BxLxHxDh /= jnp.linalg.norm(q_BxLxHxDh, axis=-1, keepdims=True) + self.eps
+        k_BxLxHxDh /= jnp.linalg.norm(k_BxLxHxDh, axis=-1, keepdims=True) + self.eps
 
         # Compute attention scores
         att_BxHxLxL = jnp.einsum('...qhd,...khd->...hqk', q_BxLxHxDh, k_BxLxHxDh)
@@ -166,6 +188,9 @@ def __call__(self, x_BxLxD: jax.Array):
         # Apply mask and softmax
         _NEG_INF = jnp.finfo(cfg.dtype).min
         att_BxHxLxL = jnp.where(mask_1x1xLxL, att_BxHxLxL, _NEG_INF)
+        att_BxHxLxL = (
+            self.attn_scale * att_BxHxLxL
+        )  # Learned scaling factor for QK norm
         att_BxHxLxL = jax.nn.softmax(att_BxHxLxL, axis=-1)
         att_BxHxLxL = att_BxHxLxL.astype(cfg.dtype)
 
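
Taken together with the setup changes above, these hunks replace the fixed 1/sqrt(Dh) query scaling with cosine-similarity logits and a single learned temperature. A minimal standalone sketch of that attention path, using hypothetical names and shapes rather than the repository's API:

import jax
import jax.numpy as jnp

def qknorm_attention(q_BxLxHxDh, k_BxLxHxDh, v_BxLxHxDh, attn_scale, mask_1x1xLxL, eps=1e-6):
    # L2-normalize queries and keys along the head dimension.
    q = q_BxLxHxDh / (jnp.linalg.norm(q_BxLxHxDh, axis=-1, keepdims=True) + eps)
    k = k_BxLxHxDh / (jnp.linalg.norm(k_BxLxHxDh, axis=-1, keepdims=True) + eps)
    # Cosine-similarity logits, causal mask, learned scale, then softmax.
    att = jnp.einsum('...qhd,...khd->...hqk', q, k)
    att = jnp.where(mask_1x1xLxL, att, jnp.finfo(att.dtype).min)
    att = jax.nn.softmax(attn_scale * att, axis=-1)
    # Weighted sum of values, back to (B, L, H, Dh).
    return jnp.einsum('...hqk,...khd->...qhd', att, v_BxLxHxDh)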
@@ -227,7 +252,10 @@ def setup(self):
             self.output_proj = lambda x: self.embed.attend(x.astype(jnp.float32))
         else:
             self.output_proj = nn.Dense(
-                cfg.vocab_size, kernel_init=cfg.embed_init, dtype=cfg.dtype, name='output_proj'
+                cfg.vocab_size,
+                kernel_init=cfg.embed_init,
+                dtype=cfg.dtype,
+                name='output_proj',
             )
 
     def __call__(self, y_BxL: jax.Array):
@@ -270,7 +298,9 @@ def predict(self, y_BxL: jax.Array, k: int = 1):
             next_token_logits = logits[:, -1, :]
             last_token_id = y_BxL[:, -1]
             # Prevent predicting the same token consecutively
-            next_token_logits = next_token_logits.at[jnp.arange(len(last_token_id)), last_token_id].set(float('-inf'))
+            next_token_logits = next_token_logits.at[
+                jnp.arange(len(last_token_id)), last_token_id
+            ].set(float('-inf'))
 
             # Get the most likely token
             next_token = jnp.argmax(next_token_logits, axis=-1)
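
The reflowed statement uses JAX's functional index update, .at[...].set(...), to block each row's previous token before the argmax. A toy illustration with made-up values:

import jax.numpy as jnp

logits = jnp.array([[0.1, 2.0, 0.3],
                    [1.5, 0.2, 0.4]])
last_token_id = jnp.array([1, 0])  # previously generated token per batch row
blocked = logits.at[jnp.arange(2), last_token_id].set(float('-inf'))
# Row 0 can no longer pick token 1 and row 1 can no longer pick token 0,
# so jnp.argmax(blocked, axis=-1) never repeats the last token.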
@@ -289,7 +319,14 @@ def main():
289319 """Create and run the DecoderOnly Transformer model."""
290320 # Initialize model configuration with smaller parameters for demo
291321 B , L = (2 , 128 ) # Batch size, sequence length
292- cfg = ModelConfig (model_dim = 128 , num_heads = 4 , seq_len = L , num_layers = 2 , vocab_size = 256 , expanded_model_dim = 4 * 128 )
322+ cfg = ModelConfig (
323+ model_dim = 128 ,
324+ num_heads = 4 ,
325+ seq_len = L ,
326+ num_layers = 2 ,
327+ vocab_size = 256 ,
328+ expanded_model_dim = 4 * 128 ,
329+ )
293330 model = TransformerDo (cfg )
294331
295332 # Print model info
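
To exercise the demo configuration above, a hedged sketch of the usual Flax init/apply pattern with integer token ids (assuming ModelConfig and TransformerDo are importable from the module this diff touches and follow standard nn.Module conventions; not shown in this commit):

import jax
import jax.numpy as jnp

cfg = ModelConfig(model_dim=128, num_heads=4, seq_len=128, num_layers=2,
                  vocab_size=256, expanded_model_dim=4 * 128)
model = TransformerDo(cfg)
rng = jax.random.PRNGKey(0)
tokens_BxL = jnp.zeros((2, 128), dtype=jnp.int32)   # dummy token ids, B=2, L=128
params = model.init(rng, tokens_BxL)                # standard Flax initialization
logits_BxLxV = model.apply(params, tokens_BxL)      # expected shape (2, 128, 256)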