From 2afaadd42516f65e0fdc6b4f688bbb3c21c05082 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Mon, 8 Sep 2025 11:46:05 +0530 Subject: [PATCH 1/9] Allow passing positions for Gemma --- .../layers/modeling/transformer_layer_utils.py | 16 ++++++++++++++++ keras_hub/src/models/gemma/gemma_attention.py | 9 +++++---- .../src/models/gemma/gemma_decoder_block.py | 8 ++++++++ 3 files changed, 29 insertions(+), 4 deletions(-) diff --git a/keras_hub/src/layers/modeling/transformer_layer_utils.py b/keras_hub/src/layers/modeling/transformer_layer_utils.py index ebc8ff37be..0b51f4e12f 100644 --- a/keras_hub/src/layers/modeling/transformer_layer_utils.py +++ b/keras_hub/src/layers/modeling/transformer_layer_utils.py @@ -90,3 +90,19 @@ def merge_padding_and_attention_mask( else: return ops.minimum(mask, attention_mask) return mask + + +def compute_positions_from_mask(mask): + """Computes positions from provided padding mask. + + Args: + mask: Tensor of shape `(batch_size, sequence_length)`. Padding mask, + 1 for non-padding tokens, 0 for padding tokens. + + Returns: + positions: Tensor of the same shape as `mask`, which contains indices + corresponding to positions of tokens in the sequence. + """ + positions = ops.cumsum(mask, axis=-1) + positions = ops.subtract(positions, ops.greater_equal(positions, 1)) + return positions diff --git a/keras_hub/src/models/gemma/gemma_attention.py b/keras_hub/src/models/gemma/gemma_attention.py index f66a4506ce..c9aa483205 100644 --- a/keras_hub/src/models/gemma/gemma_attention.py +++ b/keras_hub/src/models/gemma/gemma_attention.py @@ -97,9 +97,9 @@ def build(self, inputs_shape): self.built = True - def _apply_rope(self, x, start_index): + def _apply_rope(self, x, start_index, positions): """Rope rotate q or k.""" - x = self.rope_layer(x, start_index=start_index) + x = self.rope_layer(x, start_index=start_index, positions=positions) # Gemma uses a different layout for positional embeddings. # The transformation below ensures the embeddings are numerically # equivalent to the original gemma implementation. @@ -230,12 +230,13 @@ def call( self, x, attention_mask=None, + positions=None, cache=None, cache_update_index=0, training=False, ): query = self.query_dense(x) - query = self._apply_rope(query, cache_update_index) + query = self._apply_rope(query, cache_update_index, positions=positions) if cache is not None: key_cache = cache[:, 0, ...] 
@@ -249,7 +250,7 @@ def call( cache = ops.stack((key, value), axis=1) else: key = self.key_dense(x) - key = self._apply_rope(key, cache_update_index) + key = self._apply_rope(key, cache_update_index, positions=positions) value = self.value_dense(x) attention_vec = self._compute_attention( diff --git a/keras_hub/src/models/gemma/gemma_decoder_block.py b/keras_hub/src/models/gemma/gemma_decoder_block.py index b93e1cebc1..f6e98b24e5 100644 --- a/keras_hub/src/models/gemma/gemma_decoder_block.py +++ b/keras_hub/src/models/gemma/gemma_decoder_block.py @@ -4,6 +4,9 @@ from keras_hub.src.layers.modeling.transformer_layer_utils import ( compute_causal_mask, ) +from keras_hub.src.layers.modeling.transformer_layer_utils import ( + compute_positions_from_mask, +) from keras_hub.src.layers.modeling.transformer_layer_utils import ( merge_padding_and_attention_mask, ) @@ -166,6 +169,10 @@ def call( cache=None, cache_update_index=0, ): + positions = None + if padding_mask is not None: + positions = compute_positions_from_mask(padding_mask) + normalized_x = self.pre_attention_norm(x) attention_mask = self._compute_attention_mask( normalized_x, padding_mask, cache, cache_update_index @@ -181,6 +188,7 @@ def call( attention = self.attention( normalized_x, attention_mask=attention_mask, + positions=positions, ) if self.use_post_attention_norm: From 01dcfea1cbafb3d6523e057ed39cd00d99842000 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Mon, 8 Sep 2025 20:32:11 +0530 Subject: [PATCH 2/9] Add UT --- .../src/models/gemma/gemma_backbone_test.py | 14 ++++++ keras_hub/src/tests/test_case.py | 44 +++++++++++++++++++ 2 files changed, 58 insertions(+) diff --git a/keras_hub/src/models/gemma/gemma_backbone_test.py b/keras_hub/src/models/gemma/gemma_backbone_test.py index b5f8575332..cbcf6cdafa 100644 --- a/keras_hub/src/models/gemma/gemma_backbone_test.py +++ b/keras_hub/src/models/gemma/gemma_backbone_test.py @@ -31,6 +31,13 @@ def test_backbone_basics(self): expected_output_shape=(2, 5, 16), ) + def test_flexible_positions(self): + self.run_positions_test( + cls=GemmaBackbone, + init_kwargs=self.init_kwargs, + vocabulary_size=self.init_kwargs["vocabulary_size"], + ) + @pytest.mark.large def test_saved_model(self): self.run_model_saving_test( @@ -188,6 +195,13 @@ def test_backbone_basics(self): expected_output_shape=(2, 10, 16), ) + def test_flexible_positions(self): + self.run_positions_test( + cls=GemmaBackbone, + init_kwargs=self.init_kwargs, + vocabulary_size=self.init_kwargs["vocabulary_size"], + ) + def test_sliding_window(self): # Test sliding window correctness by hand. backbone = GemmaBackbone(**self.init_kwargs) diff --git a/keras_hub/src/tests/test_case.py b/keras_hub/src/tests/test_case.py index f70ab78840..86fed1fadf 100644 --- a/keras_hub/src/tests/test_case.py +++ b/keras_hub/src/tests/test_case.py @@ -720,6 +720,50 @@ def compare(actual, expected): output = ops.argmax(output, axis=-1) self.assertAllEqual(output, expected_labels) + def run_positions_test( + self, + cls, + init_kwargs, + vocabulary_size, + ): + """Tests that conventional and flexible positions give same output.""" + model = cls(**init_kwargs) + + x1 = { + "token_ids": keras.random.randint( + shape=(2, 5), minval=1, maxval=vocabulary_size, seed=42 + ), + "padding_mask": ops.array( + [ + [True] * 3 + [False] * 2, + [True] * 2 + [False] * 3, + ] + ), + } + + # Convert token_ids to list for easier manipulation. 
+ token_ids_lst = x1.tolist() + + x2 = { + "token_ids": ops.array( + [ + [0] + token_ids_lst[0][:3] + [0], + [0] * 2 + token_ids_lst[1][:2] + [0], + ] + ), + "padding_mask": ops.array( + [ + [False] + [True] * 3 + [False], + [False] * 2 + [True] * 2 + [False], + ] + ), + } + + output_1 = model(**x1) + output_2 = model(**x2) + self.assertAllClose(output_1[0][:3], output_2[0][1:4]) + self.assertAllClose(output_1[1][:2], output_2[1][2:4]) + def get_test_data_dir(self): return str(pathlib.Path(__file__).parent / "test_data") From 208dc7f2cfed47eccd50cf463a77bf8ea5c7dc5d Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Mon, 8 Sep 2025 20:41:15 +0530 Subject: [PATCH 3/9] Fix UT --- keras_hub/src/tests/test_case.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/keras_hub/src/tests/test_case.py b/keras_hub/src/tests/test_case.py index 86fed1fadf..b6bc733a64 100644 --- a/keras_hub/src/tests/test_case.py +++ b/keras_hub/src/tests/test_case.py @@ -740,10 +740,8 @@ def run_positions_test( ] ), } - # Convert token_ids to list for easier manipulation. - token_ids_lst = x1.tolist() - + token_ids_lst = x1["token_ids"].tolist() x2 = { "token_ids": ops.array( [ @@ -759,8 +757,8 @@ def run_positions_test( ), } - output_1 = model(**x1) - output_2 = model(**x2) + output_1 = model.predict(x1) + output_2 = model.predict(x2) self.assertAllClose(output_1[0][:3], output_2[0][1:4]) self.assertAllClose(output_1[1][:2], output_2[1][2:4]) From 36c82d46d9abb66005152351491abc3783b7bbcc Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Mon, 8 Sep 2025 20:48:55 +0530 Subject: [PATCH 4/9] Small fix --- keras_hub/src/models/gemma/gemma_attention.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/keras_hub/src/models/gemma/gemma_attention.py b/keras_hub/src/models/gemma/gemma_attention.py index c9aa483205..11833d08f8 100644 --- a/keras_hub/src/models/gemma/gemma_attention.py +++ b/keras_hub/src/models/gemma/gemma_attention.py @@ -97,7 +97,7 @@ def build(self, inputs_shape): self.built = True - def _apply_rope(self, x, start_index, positions): + def _apply_rope(self, x, start_index, positions=None): """Rope rotate q or k.""" x = self.rope_layer(x, start_index=start_index, positions=positions) # Gemma uses a different layout for positional embeddings. 
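
The next patch adds a unit test for the new utility. As a quick illustration of what it computes, here is a minimal NumPy stand-in for the `keras.ops` implementation of `compute_positions_from_mask` (an illustrative re-implementation under that assumption, not the library code): positions are the running count of non-padding tokens, shifted down by one once the first real token has been seen, so the first real token in each row sits at position 0.

import numpy as np

def compute_positions_from_mask(mask):
    # mask: (batch, seq_len), 1/True for real tokens, 0/False for padding.
    positions = np.cumsum(mask.astype("int32"), axis=-1)
    # Shift so the first real token lands on position 0; padding entries
    # before it stay at 0 and are masked out of attention anyway.
    return positions - (positions >= 1).astype("int32")

mask = np.array(
    [
        [False, False, True, True, False],
        [True, False, True, False, True],
    ]
)
print(compute_positions_from_mask(mask))
# [[0 0 0 1 1]
#  [0 0 1 1 2]]

These are the values the utility test below settles on once its expected output is corrected.
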
From 8b84c996d5784d72f0aa2f41cad3e0d6d5aa8865 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Mon, 8 Sep 2025 20:59:58 +0530 Subject: [PATCH 5/9] Add UT for utility fn --- .../layers/modeling/transformer_layer_utils_test.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/keras_hub/src/layers/modeling/transformer_layer_utils_test.py b/keras_hub/src/layers/modeling/transformer_layer_utils_test.py index 1c92950444..d5d8c78110 100644 --- a/keras_hub/src/layers/modeling/transformer_layer_utils_test.py +++ b/keras_hub/src/layers/modeling/transformer_layer_utils_test.py @@ -41,3 +41,15 @@ def test_bad_mask_shapes(self): padding_mask, attention_mask, ) + + def test_compute_positions_from_mask(self): + mask = ops.array( + [ + [False, False, True, True, False], + [True, False, True, False, True], + ] + ) + output = utils.compute_positions_from_mask(mask) + + expected_output = ops.array([[0, 0, 0, 1, 0], [0, 0, 1, 1, 2]]) + self.assertAllEqual(output, expected_output) From 5f5c1674313ef3755a28c3a11894364ba3b2bbea Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Mon, 8 Sep 2025 21:46:11 +0530 Subject: [PATCH 6/9] Fix UT --- .../src/layers/modeling/transformer_layer_utils_test.py | 2 +- keras_hub/src/models/gemma/gemma_decoder_block.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/keras_hub/src/layers/modeling/transformer_layer_utils_test.py b/keras_hub/src/layers/modeling/transformer_layer_utils_test.py index d5d8c78110..3c68370c96 100644 --- a/keras_hub/src/layers/modeling/transformer_layer_utils_test.py +++ b/keras_hub/src/layers/modeling/transformer_layer_utils_test.py @@ -51,5 +51,5 @@ def test_compute_positions_from_mask(self): ) output = utils.compute_positions_from_mask(mask) - expected_output = ops.array([[0, 0, 0, 1, 0], [0, 0, 1, 1, 2]]) + expected_output = ops.array([[0, 0, 0, 1, 1], [0, 0, 1, 1, 2]]) self.assertAllEqual(output, expected_output) diff --git a/keras_hub/src/models/gemma/gemma_decoder_block.py b/keras_hub/src/models/gemma/gemma_decoder_block.py index f6e98b24e5..53804fca4d 100644 --- a/keras_hub/src/models/gemma/gemma_decoder_block.py +++ b/keras_hub/src/models/gemma/gemma_decoder_block.py @@ -169,10 +169,6 @@ def call( cache=None, cache_update_index=0, ): - positions = None - if padding_mask is not None: - positions = compute_positions_from_mask(padding_mask) - normalized_x = self.pre_attention_norm(x) attention_mask = self._compute_attention_mask( normalized_x, padding_mask, cache, cache_update_index @@ -185,6 +181,10 @@ def call( cache_update_index=cache_update_index, ) else: + positions = None + if padding_mask is not None: + positions = compute_positions_from_mask(padding_mask) + attention = self.attention( normalized_x, attention_mask=attention_mask, From 23e11095c3c35c89d5c14d31a6aaf0a914ddd376 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Mon, 8 Sep 2025 22:35:24 +0530 Subject: [PATCH 7/9] Fix UT --- keras_hub/src/tests/test_case.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/keras_hub/src/tests/test_case.py b/keras_hub/src/tests/test_case.py index b6bc733a64..5842fe89ac 100644 --- a/keras_hub/src/tests/test_case.py +++ b/keras_hub/src/tests/test_case.py @@ -730,10 +730,10 @@ def run_positions_test( model = cls(**init_kwargs) x1 = { - "token_ids": keras.random.randint( + "token_ids": np.random.randint( shape=(2, 5), minval=1, maxval=vocabulary_size, seed=42 ), - "padding_mask": ops.array( + "padding_mask": np.array( [ [True] * 3 + [False] * 2, [True] * 2 + [False] * 3, @@ -743,13 
+743,13 @@ def run_positions_test( # Convert token_ids to list for easier manipulation. token_ids_lst = x1["token_ids"].tolist() x2 = { - "token_ids": ops.array( + "token_ids": np.array( [ [0] + token_ids_lst[0][:3] + [0], [0] * 2 + token_ids_lst[1][:2] + [0], ] ), - "padding_mask": ops.array( + "padding_mask": np.array( [ [False] + [True] * 3 + [False], [False] * 2 + [True] * 2 + [False], From edc3ec041ae824dd2ea5ac50883443262952e94a Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Mon, 8 Sep 2025 22:43:05 +0530 Subject: [PATCH 8/9] Fix UT --- keras_hub/src/tests/test_case.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/keras_hub/src/tests/test_case.py b/keras_hub/src/tests/test_case.py index 5842fe89ac..d07ec2e119 100644 --- a/keras_hub/src/tests/test_case.py +++ b/keras_hub/src/tests/test_case.py @@ -729,10 +729,9 @@ def run_positions_test( """Tests that conventional and flexible positions give same output.""" model = cls(**init_kwargs) + rng = np.random.default_rng(seed=42) x1 = { - "token_ids": np.random.randint( - shape=(2, 5), minval=1, maxval=vocabulary_size, seed=42 - ), + "token_ids": rng.integers(low=1, high=vocabulary_size, size=(2, 5)), "padding_mask": np.array( [ [True] * 3 + [False] * 2, From 3cbeb8e7f2edb5a81261c26615222b4cfbf53904 Mon Sep 17 00:00:00 2001 From: Abheesht Sharma Date: Wed, 10 Sep 2025 08:28:39 +0530 Subject: [PATCH 9/9] Address Matt's comment --- keras_hub/src/layers/modeling/transformer_layer_utils.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/keras_hub/src/layers/modeling/transformer_layer_utils.py b/keras_hub/src/layers/modeling/transformer_layer_utils.py index 0b51f4e12f..8ffb8e18a4 100644 --- a/keras_hub/src/layers/modeling/transformer_layer_utils.py +++ b/keras_hub/src/layers/modeling/transformer_layer_utils.py @@ -104,5 +104,4 @@ def compute_positions_from_mask(mask): corresponding to positions of tokens in the sequence. """ positions = ops.cumsum(mask, axis=-1) - positions = ops.subtract(positions, ops.greater_equal(positions, 1)) - return positions + return ops.subtract(positions, ops.greater_equal(positions, 1))
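
Taken together, the series lets Gemma place padding anywhere in the sequence, not only on the right. `run_positions_test` relies on the fact that a right-padded sequence and the same tokens shifted into a layout padded on both ends produce identical rotary positions for the real tokens. A small check of that property, reusing the NumPy stand-in sketched earlier (again an illustrative re-implementation, not the library code):

import numpy as np

def compute_positions_from_mask(mask):
    positions = np.cumsum(mask.astype("int32"), axis=-1)
    return positions - (positions >= 1).astype("int32")

# Same three real tokens: padded on the right vs. shifted one slot to the
# right with padding on both ends (the layout used by run_positions_test).
right_padded = np.array([[True, True, True, False, False]])
shifted = np.array([[False, True, True, True, False]])

print(compute_positions_from_mask(right_padded))  # [[0 1 2 2 2]]
print(compute_positions_from_mask(shifted))       # [[0 0 1 2 2]]

# In both layouts the real tokens get positions 0, 1, 2, so RoPE rotates
# them identically; combined with the padding mask on attention, the
# backbone outputs for the real tokens match slice-for-slice, which is
# what run_positions_test asserts.

Note that in the final state of the series, positions are only derived in the uncached branch of `GemmaDecoderBlock.call`; cached generation still advances through `cache_update_index` alone.
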