From 1fcb7a45257610417417152d713d52ac2088e491 Mon Sep 17 00:00:00 2001 From: Mohammed Abdullah <168697713+Dodgeqtr@users.noreply.github.com> Date: Sat, 15 Feb 2025 04:53:50 +0300 Subject: [PATCH 1/2] Created using Colab --- site/en/gemma/docs/lora_tuning.ipynb | 1007 ++++++++++++++++++++++++++ 1 file changed, 1007 insertions(+) create mode 100644 site/en/gemma/docs/lora_tuning.ipynb diff --git a/site/en/gemma/docs/lora_tuning.ipynb b/site/en/gemma/docs/lora_tuning.ipynb new file mode 100644 index 00000000..efa2a36f --- /dev/null +++ b/site/en/gemma/docs/lora_tuning.ipynb @@ -0,0 +1,1007 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "G3MMAcssHTML" + }, + "source": [ + "\n", + "" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Tce3stUlHN0L" + }, + "source": [ + "##### Copyright 2024 Google LLC." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "id": "tuOe1ymfHZPu" + }, + "outputs": [], + "source": [ + "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", + "# you may not use this file except in compliance with the License.\n", + "# You may obtain a copy of the License at\n", + "#\n", + "# https://www.apache.org/licenses/LICENSE-2.0\n", + "#\n", + "# Unless required by applicable law or agreed to in writing, software\n", + "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", + "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", + "# See the License for the specific language governing permissions and\n", + "# limitations under the License." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "SDEExiAk4fLb" + }, + "source": [ + "# Fine-tune Gemma models in Keras using LoRA" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ZFWzQEqNosrS" + }, + "source": [ + "
\n",
+ " View on ai.google.dev\n",
+ " | \n",
+ " Run in Google Colab\n",
+ " | \n",
+ " \n",
+ " | \n",
+ " \n",
+ " View source on GitHub\n",
+ " | \n",
+ "
Preprocessor: \"gemma_causal_lm_preprocessor\"\n",
+ "\n"
+ ],
+ "text/plain": [
+ "\u001b[1mPreprocessor: \"gemma_causal_lm_preprocessor\"\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "ββββββββββββββββββββββββββββββββββββββββββββββββββββββ³ββββββββββββββββββββββββββββββββββββββββββββββββββββββ\n", + "β Tokenizer (type) β Vocab # β\n", + "β‘βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ©\n", + "β gemma_tokenizer (GemmaTokenizer) β 256,000 β\n", + "ββββββββββββββββββββββββββββββββββββββββββββββββββββββ΄ββββββββββββββββββββββββββββββββββββββββββββββββββββββ\n", + "\n" + ], + "text/plain": [ + "ββββββββββββββββββββββββββββββββββββββββββββββββββββββ³ββββββββββββββββββββββββββββββββββββββββββββββββββββββ\n", + "β\u001b[1m \u001b[0m\u001b[1mTokenizer (type) \u001b[0m\u001b[1m \u001b[0mβ\u001b[1m \u001b[0m\u001b[1m Vocab #\u001b[0m\u001b[1m \u001b[0mβ\n", + "β‘βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ©\n", + "β gemma_tokenizer (\u001b[38;5;33mGemmaTokenizer\u001b[0m) β \u001b[38;5;34m256,000\u001b[0m β\n", + "ββββββββββββββββββββββββββββββββββββββββββββββββββββββ΄ββββββββββββββββββββββββββββββββββββββββββββββββββββββ\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
Model: \"gemma_causal_lm\"\n",
+ "\n"
+ ],
+ "text/plain": [
+ "\u001b[1mModel: \"gemma_causal_lm\"\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "βββββββββββββββββββββββββββββββββ³ββββββββββββββββββββββββββββ³ββββββββββββββββββ³βββββββββββββββββββββββββββββ\n", + "β Layer (type) β Output Shape β Param # β Connected to β\n", + "β‘βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ©\n", + "β padding_mask (InputLayer) β (None, None) β 0 β - β\n", + "βββββββββββββββββββββββββββββββββΌββββββββββββββββββββββββββββΌββββββββββββββββββΌβββββββββββββββββββββββββββββ€\n", + "β token_ids (InputLayer) β (None, None) β 0 β - β\n", + "βββββββββββββββββββββββββββββββββΌββββββββββββββββββββββββββββΌββββββββββββββββββΌβββββββββββββββββββββββββββββ€\n", + "β gemma_backbone β (None, None, 2304) β 2,614,341,888 β padding_mask[0][0], β\n", + "β (GemmaBackbone) β β β token_ids[0][0] β\n", + "βββββββββββββββββββββββββββββββββΌββββββββββββββββββββββββββββΌββββββββββββββββββΌβββββββββββββββββββββββββββββ€\n", + "β token_embedding β (None, None, 256000) β 589,824,000 β gemma_backbone[0][0] β\n", + "β (ReversibleEmbedding) β β β β\n", + "βββββββββββββββββββββββββββββββββ΄ββββββββββββββββββββββββββββ΄ββββββββββββββββββ΄βββββββββββββββββββββββββββββ\n", + "\n" + ], + "text/plain": [ + "βββββββββββββββββββββββββββββββββ³ββββββββββββββββββββββββββββ³ββββββββββββββββββ³βββββββββββββββββββββββββββββ\n", + "β\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0mβ\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0mβ\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0mβ\u001b[1m \u001b[0m\u001b[1mConnected to \u001b[0m\u001b[1m \u001b[0mβ\n", + "β‘βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ©\n", + "β padding_mask (\u001b[38;5;33mInputLayer\u001b[0m) β (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m) β \u001b[38;5;34m0\u001b[0m β - β\n", + "βββββββββββββββββββββββββββββββββΌββββββββββββββββββββββββββββΌββββββββββββββββββΌβββββββββββββββββββββββββββββ€\n", + "β token_ids (\u001b[38;5;33mInputLayer\u001b[0m) β (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m) β \u001b[38;5;34m0\u001b[0m β - β\n", + "βββββββββββββββββββββββββββββββββΌββββββββββββββββββββββββββββΌββββββββββββββββββΌβββββββββββββββββββββββββββββ€\n", + "β gemma_backbone β (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m2304\u001b[0m) β \u001b[38;5;34m2,614,341,888\u001b[0m β padding_mask[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], β\n", + "β (\u001b[38;5;33mGemmaBackbone\u001b[0m) β β β token_ids[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] β\n", + "βββββββββββββββββββββββββββββββββΌββββββββββββββββββββββββββββΌββββββββββββββββββΌβββββββββββββββββββββββββββββ€\n", + "β token_embedding β (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m256000\u001b[0m) β \u001b[38;5;34m589,824,000\u001b[0m β gemma_backbone[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] β\n", + "β (\u001b[38;5;33mReversibleEmbedding\u001b[0m) β β β β\n", + "βββββββββββββββββββββββββββββββββ΄ββββββββββββββββββββββββββββ΄ββββββββββββββββββ΄βββββββββββββββββββββββββββββ\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
Total params: 2,614,341,888 (9.74 GB)\n", + "\n" + ], + "text/plain": [ + "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m2,614,341,888\u001b[0m (9.74 GB)\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
Trainable params: 2,614,341,888 (9.74 GB)\n", + "\n" + ], + "text/plain": [ + "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m2,614,341,888\u001b[0m (9.74 GB)\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
Non-trainable params: 0 (0.00 B)\n", + "\n" + ], + "text/plain": [ + "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset(\"gemma2_2b_en\")\n", + "gemma_lm.summary()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Nl4lvPy5zA26" + }, + "source": [ + "The `from_preset` method instantiates the model from a preset architecture and weights. In the code above, the string \"gemma2_2b_en\" specifies the preset architecture β a Gemma model with 2 billion parameters.\n", + "\n", + "NOTE: A Gemma model with 7\n", + "billion parameters is also available. To run the larger model in Colab, you need access to the premium GPUs available in paid plans. Alternatively, you can perform [distributed tuning on a Gemma 7B model](https://ai.google.dev/gemma/docs/distributed_tuning) on Kaggle or Google Cloud." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "G_L6A5J-1QgC" + }, + "source": [ + "## Inference before fine tuning\n", + "\n", + "In this section, you will query the model with various prompts to see how it responds." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "PVLXadptyo34" + }, + "source": [ + "### Europe Trip Prompt\n", + "\n", + "Query the model for suggestions on what to do on a trip to Europe." + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "id": "ZwQz3xxxKciD", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 193 + }, + "outputId": "cbf86d7f-1208-4048-e6a4-fb8d0848e9ef" + }, + "outputs": [ + { + "output_type": "error", + "ename": "NameError", + "evalue": "name 'gemma_lm' is not defined", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m
Preprocessor: \"gemma_causal_lm_preprocessor\"\n",
+ "\n"
+ ],
+ "text/plain": [
+ "\u001b[1mPreprocessor: \"gemma_causal_lm_preprocessor\"\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "ββββββββββββββββββββββββββββββββββββββββββββββββββββββ³ββββββββββββββββββββββββββββββββββββββββββββββββββββββ\n", + "β Tokenizer (type) β Vocab # β\n", + "β‘βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ©\n", + "β gemma_tokenizer (GemmaTokenizer) β 256,000 β\n", + "ββββββββββββββββββββββββββββββββββββββββββββββββββββββ΄ββββββββββββββββββββββββββββββββββββββββββββββββββββββ\n", + "\n" + ], + "text/plain": [ + "ββββββββββββββββββββββββββββββββββββββββββββββββββββββ³ββββββββββββββββββββββββββββββββββββββββββββββββββββββ\n", + "β\u001b[1m \u001b[0m\u001b[1mTokenizer (type) \u001b[0m\u001b[1m \u001b[0mβ\u001b[1m \u001b[0m\u001b[1m Vocab #\u001b[0m\u001b[1m \u001b[0mβ\n", + "β‘βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ©\n", + "β gemma_tokenizer (\u001b[38;5;33mGemmaTokenizer\u001b[0m) β \u001b[38;5;34m256,000\u001b[0m β\n", + "ββββββββββββββββββββββββββββββββββββββββββββββββββββββ΄ββββββββββββββββββββββββββββββββββββββββββββββββββββββ\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
Model: \"gemma_causal_lm\"\n",
+ "\n"
+ ],
+ "text/plain": [
+ "\u001b[1mModel: \"gemma_causal_lm\"\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "βββββββββββββββββββββββββββββββββ³ββββββββββββββββββββββββββββ³ββββββββββββββββββ³βββββββββββββββββββββββββββββ\n", + "β Layer (type) β Output Shape β Param # β Connected to β\n", + "β‘βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ©\n", + "β padding_mask (InputLayer) β (None, None) β 0 β - β\n", + "βββββββββββββββββββββββββββββββββΌββββββββββββββββββββββββββββΌββββββββββββββββββΌβββββββββββββββββββββββββββββ€\n", + "β token_ids (InputLayer) β (None, None) β 0 β - β\n", + "βββββββββββββββββββββββββββββββββΌββββββββββββββββββββββββββββΌββββββββββββββββββΌβββββββββββββββββββββββββββββ€\n", + "β gemma_backbone β (None, None, 2304) β 2,617,270,528 β padding_mask[0][0], β\n", + "β (GemmaBackbone) β β β token_ids[0][0] β\n", + "βββββββββββββββββββββββββββββββββΌββββββββββββββββββββββββββββΌββββββββββββββββββΌβββββββββββββββββββββββββββββ€\n", + "β token_embedding β (None, None, 256000) β 589,824,000 β gemma_backbone[0][0] β\n", + "β (ReversibleEmbedding) β β β β\n", + "βββββββββββββββββββββββββββββββββ΄ββββββββββββββββββββββββββββ΄ββββββββββββββββββ΄βββββββββββββββββββββββββββββ\n", + "\n" + ], + "text/plain": [ + "βββββββββββββββββββββββββββββββββ³ββββββββββββββββββββββββββββ³ββββββββββββββββββ³βββββββββββββββββββββββββββββ\n", + "β\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0mβ\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0mβ\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0mβ\u001b[1m \u001b[0m\u001b[1mConnected to \u001b[0m\u001b[1m \u001b[0mβ\n", + "β‘βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ©\n", + "β padding_mask (\u001b[38;5;33mInputLayer\u001b[0m) β (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m) β \u001b[38;5;34m0\u001b[0m β - β\n", + "βββββββββββββββββββββββββββββββββΌββββββββββββββββββββββββββββΌββββββββββββββββββΌβββββββββββββββββββββββββββββ€\n", + "β token_ids (\u001b[38;5;33mInputLayer\u001b[0m) β (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m) β \u001b[38;5;34m0\u001b[0m β - β\n", + "βββββββββββββββββββββββββββββββββΌββββββββββββββββββββββββββββΌββββββββββββββββββΌβββββββββββββββββββββββββββββ€\n", + "β gemma_backbone β (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m2304\u001b[0m) β \u001b[38;5;34m2,617,270,528\u001b[0m β padding_mask[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], β\n", + "β (\u001b[38;5;33mGemmaBackbone\u001b[0m) β β β token_ids[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] β\n", + "βββββββββββββββββββββββββββββββββΌββββββββββββββββββββββββββββΌββββββββββββββββββΌβββββββββββββββββββββββββββββ€\n", + "β token_embedding β (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m256000\u001b[0m) β \u001b[38;5;34m589,824,000\u001b[0m β gemma_backbone[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] β\n", + "β (\u001b[38;5;33mReversibleEmbedding\u001b[0m) β β β β\n", + "βββββββββββββββββββββββββββββββββ΄ββββββββββββββββββββββββββββ΄ββββββββββββββββββ΄βββββββββββββββββββββββββββββ\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
Total params: 2,617,270,528 (9.75 GB)\n", + "\n" + ], + "text/plain": [ + "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m2,617,270,528\u001b[0m (9.75 GB)\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
Trainable params: 2,928,640 (11.17 MB)\n", + "\n" + ], + "text/plain": [ + "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m2,928,640\u001b[0m (11.17 MB)\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
Non-trainable params: 2,614,341,888 (9.74 GB)\n", + "\n" + ], + "text/plain": [ + "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m2,614,341,888\u001b[0m (9.74 GB)\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# Enable LoRA for the model and set the LoRA rank to 4.\n", + "gemma_lm.backbone.enable_lora(rank=4)\n", + "gemma_lm.summary()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "hQQ47kcdpbZ9" + }, + "source": [ + "Note that enabling LoRA reduces the number of trainable parameters significantly (from 2.6 billion to 2.9 million)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "_Peq7TnLtHse", + "outputId": "b5160a3d-ef2b-4820-f739-c20363f8e4d5" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[1m1000/1000\u001b[0m \u001b[32mββββββββββββββββββββ\u001b[0m\u001b[37m\u001b[0m \u001b[1m923s\u001b[0m 888ms/step - loss: 1.5586 - sparse_categorical_accuracy: 0.5251\n" + ] + }, + { + "data": { + "text/plain": [ + "
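+    {
+      "cell_type": "markdown",
+      "metadata": {},
+      "source": [
+        "The training run shown above is driven by a standard Keras `compile()` and `fit()` loop on top of the LoRA-enabled model. The cell below is a minimal sketch of such a configuration, assuming that `keras` was imported and a `data` list of formatted training strings was prepared in an earlier cell; the sequence length, learning rate, and batch size shown here are illustrative values rather than required settings."
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": null,
+      "metadata": {},
+      "outputs": [],
+      "source": [
+        "# Minimal sketch of a LoRA fine-tuning setup (assumes `keras` and `data`\n",
+        "# from earlier cells; hyperparameters here are illustrative).\n",
+        "\n",
+        "# Limit the input sequence length to control memory usage.\n",
+        "gemma_lm.preprocessor.sequence_length = 256\n",
+        "\n",
+        "# AdamW is a common optimizer choice for transformer fine-tuning.\n",
+        "optimizer = keras.optimizers.AdamW(\n",
+        "    learning_rate=5e-5,\n",
+        "    weight_decay=0.01,\n",
+        ")\n",
+        "# Exclude layernorm and bias terms from weight decay.\n",
+        "optimizer.exclude_from_weight_decay(var_names=[\"bias\", \"scale\"])\n",
+        "\n",
+        "gemma_lm.compile(\n",
+        "    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
+        "    optimizer=optimizer,\n",
+        "    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],\n",
+        ")\n",
+        "gemma_lm.fit(data, epochs=1, batch_size=1)"
+      ]
+    },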