diff --git a/tests/cli/utils/model_test.py b/tests/cli/utils/model_test.py index 80232ca7..996bd007 100644 --- a/tests/cli/utils/model_test.py +++ b/tests/cli/utils/model_test.py @@ -34,6 +34,10 @@ testcase_name="gemma3-1b", model_name="gemma3-1b", ), + dict( + testcase_name="gemma3-270m", + model_name="gemma3-270m", + ), dict( testcase_name="llama3.2-1b", model_name="llama3.2-1b", diff --git a/tunix/models/gemma3/model.py b/tunix/models/gemma3/model.py index e4455419..02203865 100644 --- a/tunix/models/gemma3/model.py +++ b/tunix/models/gemma3/model.py @@ -108,6 +108,26 @@ class ModelConfig: remat_config: RematConfig = RematConfig.NONE param_dtype: jnp.dtype = jnp.bfloat16 + @classmethod + def gemma3_270m( + cls, + sharding_config: ShardingConfig = ShardingConfig.get_default_sharding(), + ) -> 'ModelConfig': + """Gemma3-270M text-only config.""" + return cls( + num_layers=18, + num_embed=262144, + embed_dim=640, + hidden_dim=2048, + num_heads=4, + head_dim=256, + num_kv_heads=1, + sliding_window_size=512, + local_base_frequency=10_000, + global_base_frequency=1_000_000, + shd_config=sharding_config, + ) + @classmethod def gemma3_1b( cls, diff --git a/tunix/models/gemma3/params.py b/tunix/models/gemma3/params.py index e878aa14..54115698 100644 --- a/tunix/models/gemma3/params.py +++ b/tunix/models/gemma3/params.py @@ -32,11 +32,13 @@ # Pretrained +GEMMA3_270M_PT = 'gs://gemma-data/checkpoints/gemma3-270m-pt' GEMMA3_1B_PT = 'gs://gemma-data/checkpoints/gemma3-1b-pt' GEMMA3_4B_PT = 'gs://gemma-data/checkpoints/gemma3-4b-pt' GEMMA3_12B_PT = 'gs://gemma-data/checkpoints/gemma3-12b-pt' GEMMA3_27B_PT = 'gs://gemma-data/checkpoints/gemma3-27b-pt' # Instruction Tuned +GEMMA3_270M_IT = 'gs://gemma-data/checkpoints/gemma3-270m-it' GEMMA3_1B_IT = 'gs://gemma-data/checkpoints/gemma3-1b-it' GEMMA3_4B_IT = 'gs://gemma-data/checkpoints/gemma3-4b-it' GEMMA3_12B_IT = 'gs://gemma-data/checkpoints/gemma3-12b-it'