From 1fcb7a45257610417417152d713d52ac2088e491 Mon Sep 17 00:00:00 2001 From: Mohammed Abdullah <168697713+Dodgeqtr@users.noreply.github.com> Date: Sat, 15 Feb 2025 04:53:50 +0300 Subject: [PATCH 1/2] Created using Colab --- site/en/gemma/docs/lora_tuning.ipynb | 1007 ++++++++++++++++++++++++++ 1 file changed, 1007 insertions(+) create mode 100644 site/en/gemma/docs/lora_tuning.ipynb diff --git a/site/en/gemma/docs/lora_tuning.ipynb b/site/en/gemma/docs/lora_tuning.ipynb new file mode 100644 index 00000000..efa2a36f --- /dev/null +++ b/site/en/gemma/docs/lora_tuning.ipynb @@ -0,0 +1,1007 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "G3MMAcssHTML" + }, + "source": [ + "\n", + "" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Tce3stUlHN0L" + }, + "source": [ + "##### Copyright 2024 Google LLC." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "id": "tuOe1ymfHZPu" + }, + "outputs": [], + "source": [ + "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", + "# you may not use this file except in compliance with the License.\n", + "# You may obtain a copy of the License at\n", + "#\n", + "# https://www.apache.org/licenses/LICENSE-2.0\n", + "#\n", + "# Unless required by applicable law or agreed to in writing, software\n", + "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", + "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", + "# See the License for the specific language governing permissions and\n", + "# limitations under the License." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "SDEExiAk4fLb" + }, + "source": [ + "# Fine-tune Gemma models in Keras using LoRA" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ZFWzQEqNosrS" + }, + "source": [ + "\n", + " \n", + " \n", + " \n", + "
\n", + " View on ai.google.dev\n", + " \n", + " Run in Google Colab\n", + " \n", + " Open in Vertex AI\n", + " \n", + " View source on GitHub\n", + "
" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "lSGRSsRPgkzK" + }, + "source": [ + "## Overview\n", + "\n", + "Gemma is a family of lightweight, state-of-the-art open models built from the same research and technology used to create the Gemini models.\n", + "\n", + "Large Language Models (LLMs) like Gemma have been shown to be effective at a variety of NLP tasks. An LLM is first pre-trained on a large corpus of text in a self-supervised fashion. Pre-training helps LLMs learn general-purpose knowledge, such as statistical relationships between words. An LLM can then be fine-tuned with domain-specific data to perform downstream tasks (such as sentiment analysis).\n", + "\n", + "LLMs are extremely large (parameters on the order of billions). Full fine-tuning (which updates all the parameters in the model) is not required for most applications because typical fine-tuning datasets are much smaller than the pre-training datasets.\n", + "\n", + "[Low Rank Adaptation (LoRA)](https://arxiv.org/abs/2106.09685) is a fine-tuning technique that greatly reduces the number of trainable parameters for downstream tasks by freezing the weights of the model and inserting a smaller number of new weights into the model. This makes training with LoRA much faster and more memory-efficient, and produces smaller model weights (a few hundred MBs), all while maintaining the quality of the model outputs.\n", + "\n", + "This tutorial walks you through using KerasNLP to perform LoRA fine-tuning on a Gemma 2B model using the [Databricks Dolly 15k dataset](https://huggingface.co/datasets/databricks/databricks-dolly-15k). This dataset contains 15,000 high-quality human-generated prompt / response pairs specifically designed for fine-tuning LLMs." 
+ ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "w1q6-W_mKIT-" + }, + "source": [ + "## Setup" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "lyhHCMfoRZ_v" + }, + "source": [ + "### Get access to Gemma\n", + "\n", + "To complete this tutorial, you will first need to complete the setup instructions at [Gemma setup](https://ai.google.dev/gemma/docs/setup). The Gemma setup instructions show you how to do the following:\n", + "\n", + "* Get access to Gemma on [kaggle.com](https://kaggle.com).\n", + "* Select a Colab runtime with sufficient resources to run\n", + " the Gemma 2B model.\n", + "* Generate and configure a Kaggle username and API key.\n", + "\n", + "After you've completed the Gemma setup, move on to the next section, where you'll set environment variables for your Colab environment." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "AZ5Qo0fxRZ1V" + }, + "source": [ + "### Select the runtime\n", + "\n", + "To complete this tutorial, you'll need to have a Colab runtime with sufficient resources to run the Gemma model. In this case, you can use a T4 GPU:\n", + "\n", + "1. In the upper-right of the Colab window, select ▾ (**Additional connection options**).\n", + "2. Select **Change runtime type**.\n", + "3. Under **Hardware accelerator**, select **T4 GPU**." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "hsPC0HRkJl0K" + }, + "source": [ + "### Configure your API key\n", + "\n", + "To use Gemma, you must provide your Kaggle username and a Kaggle API key.\n", + "\n", + "To generate a Kaggle API key, go to the **Account** tab of your Kaggle user profile and select **Create New Token**. This will trigger the download of a `kaggle.json` file containing your API credentials.\n", + "\n", + "In Colab, select **Secrets** (🔑) in the left pane and add your Kaggle username and Kaggle API key. Store your username under the name `KAGGLE_USERNAME` and your API key under the name `KAGGLE_KEY`." 
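+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {},
+      "source": [
+        "If you are working outside Colab, where the **Secrets** pane is not available, one option is to read the downloaded `kaggle.json` file directly. The cell below is a minimal sketch under two assumptions: the file holds `username` and `key` fields, and it has been saved to `~/.kaggle/kaggle.json`. The helper name `load_kaggle_credentials` is hypothetical; adjust the path to wherever you stored the token."
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": null,
+      "metadata": {},
+      "outputs": [],
+      "source": [
+        "import json\n",
+        "import os\n",
+        "from pathlib import Path\n",
+        "\n",
+        "\n",
+        "def load_kaggle_credentials(path):\n",
+        "    # Hypothetical helper: copy the token fields into the two env vars.\n",
+        "    creds = json.loads(Path(path).expanduser().read_text())\n",
+        "    os.environ['KAGGLE_USERNAME'] = creds['username']\n",
+        "    os.environ['KAGGLE_KEY'] = creds['key']\n",
+        "\n",
+        "\n",
+        "# Assumed location of the downloaded token; change as needed.\n",
+        "token_path = Path('~/.kaggle/kaggle.json')\n",
+        "if token_path.expanduser().exists():\n",
+        "    load_kaggle_credentials(token_path)"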
+ ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "7iOF6Yo-wUEC" + }, + "source": [ + "### Set environment variables\n", + "\n", + "Set environment variables for `KAGGLE_USERNAME` and `KAGGLE_KEY`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "0_EdOg9DPK6Q" + }, + "outputs": [], + "source": [ + "import os\n", + "from google.colab import userdata\n", + "\n", + "# Note: `userdata.get` is a Colab API. If you're not using Colab, set the env\n", + "# vars as appropriate for your system.\n", + "\n", + "os.environ[\"KAGGLE_USERNAME\"] = userdata.get('KAGGLE_USERNAME')\n", + "os.environ[\"KAGGLE_KEY\"] = userdata.get('KAGGLE_KEY')" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "CuEUAKJW1QkQ" + }, + "source": [ + "### Install dependencies\n", + "\n", + "Install Keras, KerasNLP, and other dependencies." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "1eeBtYqJsZPG" + }, + "outputs": [], + "source": [ + "# Install Keras 3 last. See https://keras.io/getting_started/ for more details.\n", + "!pip install -q -U keras-nlp\n", + "!pip install -q -U \"keras>=3\"" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "rGLS-l5TxIR4" + }, + "source": [ + "### Select a backend\n", + "\n", + "Keras is a high-level, multi-framework deep learning API designed for simplicity and ease of use. Using Keras 3, you can run workflows on one of three backends: TensorFlow, JAX, or PyTorch.\n", + "\n", + "For this tutorial, configure the backend for JAX." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "yn5uy8X8sdD0" + }, + "outputs": [], + "source": [ + "os.environ[\"KERAS_BACKEND\"] = \"jax\" # Or \"torch\" or \"tensorflow\".\n", + "# Avoid memory fragmentation on JAX backend.\n", + "os.environ[\"XLA_PYTHON_CLIENT_MEM_FRACTION\"]=\"1.00\"" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "hZs8XXqUKRmi" + }, + "source": [ + "### Import packages\n", + "\n", + "Import Keras and KerasNLP." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "id": "FYHyPUA9hKTf" + }, + "outputs": [], + "source": [ + "import keras\n", + "import keras_nlp" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "9T7xe_jzslv4" + }, + "source": [ + "## Load Dataset" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "id": "xRaNCPUXKoa7", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "be586084-4e42-4959-bf1e-bc864be87cde" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "--2025-02-15 01:33:27-- https://huggingface.co/datasets/databricks/databricks-dolly-15k/resolve/main/databricks-dolly-15k.jsonl\n", + "Resolving huggingface.co (huggingface.co)... 13.35.202.40, 13.35.202.97, 13.35.202.121, ...\n", + "Connecting to huggingface.co (huggingface.co)|13.35.202.40|:443... connected.\n", + "HTTP request sent, awaiting response... 
302 Found\n", + "Location: https://cdn-lfs.hf.co/repos/34/ac/34ac588cc580830664f592597bb6d19d61639eca33dc2d6bb0b6d833f7bfd552/2df9083338b4abd6bceb5635764dab5d833b393b55759dffb0959b6fcbf794ec?response-content-disposition=inline%3B+filename*%3DUTF-8%27%27databricks-dolly-15k.jsonl%3B+filename%3D%22databricks-dolly-15k.jsonl%22%3B&Expires=1739586807&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTczOTU4NjgwN319LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy5oZi5jby9yZXBvcy8zNC9hYy8zNGFjNTg4Y2M1ODA4MzA2NjRmNTkyNTk3YmI2ZDE5ZDYxNjM5ZWNhMzNkYzJkNmJiMGI2ZDgzM2Y3YmZkNTUyLzJkZjkwODMzMzhiNGFiZDZiY2ViNTYzNTc2NGRhYjVkODMzYjM5M2I1NTc1OWRmZmIwOTU5YjZmY2JmNzk0ZWM%7EcmVzcG9uc2UtY29udGVudC1kaXNwb3NpdGlvbj0qIn1dfQ__&Signature=dNaMOiR1UXGnGM0OILgzVPyQJZUqjdccrzFNEuJaC661OX9kbBoejEKzHGLYEkfDQ8OWTSOXpQlI%7E7ZqEjFKXZBzk8Sl0MMxLXbGMybj8Fmaq6zzdXykJJk%7EoSltCQueTyazoRJfdxA8HyN1tW2aNsIAGrcmlUbVqExiExjozw2B-BF4t3kXMtPwbc5fc6xv%7EBgUJEySXtfHNgKRLmAl8yc31zRQF25hscDNQnre2bWLUYs1xJhC79gmOpJ6XZ4LZX7Dp1ehdQrI7M6fsZ68F58n-ktfi1UMP6nRrlN1t1IJvls9GTUwJn6aZleU8UYN6ERRCOmk9oejsyB7RjV9CQ__&Key-Pair-Id=K3RPWS32NSSJCE [following]\n", + "--2025-02-15 01:33:27-- 
https://cdn-lfs.hf.co/repos/34/ac/34ac588cc580830664f592597bb6d19d61639eca33dc2d6bb0b6d833f7bfd552/2df9083338b4abd6bceb5635764dab5d833b393b55759dffb0959b6fcbf794ec?response-content-disposition=inline%3B+filename*%3DUTF-8%27%27databricks-dolly-15k.jsonl%3B+filename%3D%22databricks-dolly-15k.jsonl%22%3B&Expires=1739586807&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTczOTU4NjgwN319LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy5oZi5jby9yZXBvcy8zNC9hYy8zNGFjNTg4Y2M1ODA4MzA2NjRmNTkyNTk3YmI2ZDE5ZDYxNjM5ZWNhMzNkYzJkNmJiMGI2ZDgzM2Y3YmZkNTUyLzJkZjkwODMzMzhiNGFiZDZiY2ViNTYzNTc2NGRhYjVkODMzYjM5M2I1NTc1OWRmZmIwOTU5YjZmY2JmNzk0ZWM%7EcmVzcG9uc2UtY29udGVudC1kaXNwb3NpdGlvbj0qIn1dfQ__&Signature=dNaMOiR1UXGnGM0OILgzVPyQJZUqjdccrzFNEuJaC661OX9kbBoejEKzHGLYEkfDQ8OWTSOXpQlI%7E7ZqEjFKXZBzk8Sl0MMxLXbGMybj8Fmaq6zzdXykJJk%7EoSltCQueTyazoRJfdxA8HyN1tW2aNsIAGrcmlUbVqExiExjozw2B-BF4t3kXMtPwbc5fc6xv%7EBgUJEySXtfHNgKRLmAl8yc31zRQF25hscDNQnre2bWLUYs1xJhC79gmOpJ6XZ4LZX7Dp1ehdQrI7M6fsZ68F58n-ktfi1UMP6nRrlN1t1IJvls9GTUwJn6aZleU8UYN6ERRCOmk9oejsyB7RjV9CQ__&Key-Pair-Id=K3RPWS32NSSJCE\n", + "Resolving cdn-lfs.hf.co (cdn-lfs.hf.co)... 3.170.229.28, 3.170.229.125, 3.170.229.105, ...\n", + "Connecting to cdn-lfs.hf.co (cdn-lfs.hf.co)|3.170.229.28|:443... connected.\n", + "HTTP request sent, awaiting response... 200 OK\n", + "Length: 13085339 (12M) [text/plain]\n", + "Saving to: ‘databricks-dolly-15k.jsonl’\n", + "\n", + "databricks-dolly-15 100%[===================>] 12.48M --.-KB/s in 0.05s \n", + "\n", + "2025-02-15 01:33:28 (249 MB/s) - ‘databricks-dolly-15k.jsonl’ saved [13085339/13085339]\n", + "\n" + ] + } + ], + "source": [ + "!wget -O databricks-dolly-15k.jsonl https://huggingface.co/datasets/databricks/databricks-dolly-15k/resolve/main/databricks-dolly-15k.jsonl" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "45UpBDfBgf0I" + }, + "source": [ + "Preprocess the data. 
This tutorial uses a subset of 1000 training examples to execute the notebook faster. Consider using more training data for higher quality fine-tuning." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "id": "ZiS-KU9osh_N" + }, + "outputs": [], + "source": [ + "import json\n", + "data = []\n", + "with open(\"databricks-dolly-15k.jsonl\") as file:\n", + " for line in file:\n", + " features = json.loads(line)\n", + " # Filter out examples with context, to keep it simple.\n", + " if features[\"context\"]:\n", + " continue\n", + " # Format the entire example as a single string.\n", + " template = \"Instruction:\\n{instruction}\\n\\nResponse:\\n{response}\"\n", + " data.append(template.format(**features))\n", + "\n", + "# Only use 1000 training examples, to keep it fast.\n", + "data = data[:1000]" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "7RCE3fdGhDE5" + }, + "source": [ + "## Load Model\n", + "\n", + "KerasNLP provides implementations of many popular [model architectures](https://keras.io/api/keras_nlp/models/). In this tutorial, you'll create a model using `GemmaCausalLM`, an end-to-end Gemma model for causal language modeling. A causal language model predicts the next token based on previous tokens.\n", + "\n", + "Create the model using the `from_preset` method:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "vz5zLEyLstfn", + "outputId": "431a64f0-bb0e-4949-9d23-475f562fccad" + }, + "outputs": [ + { + "data": { + "text/html": [ + "
Preprocessor: \"gemma_causal_lm_preprocessor\"\n",
+              "
\n" + ], + "text/plain": [ + "\u001b[1mPreprocessor: \"gemma_causal_lm_preprocessor\"\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
+              "┃ Tokenizer (type)                                   ┃                                             Vocab # ┃\n",
+              "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
+              "│ gemma_tokenizer (GemmaTokenizer)                   │                                             256,000 │\n",
+              "└────────────────────────────────────────────────────┴─────────────────────────────────────────────────────┘\n",
+              "
\n" + ], + "text/plain": [ + "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", + "┃\u001b[1m \u001b[0m\u001b[1mTokenizer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Vocab #\u001b[0m\u001b[1m \u001b[0m┃\n", + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", + "│ gemma_tokenizer (\u001b[38;5;33mGemmaTokenizer\u001b[0m) │ \u001b[38;5;34m256,000\u001b[0m │\n", + "└────────────────────────────────────────────────────┴─────────────────────────────────────────────────────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
Model: \"gemma_causal_lm\"\n",
+              "
\n" + ], + "text/plain": [ + "\u001b[1mModel: \"gemma_causal_lm\"\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
+              "┃ Layer (type)                  ┃ Output Shape              ┃         Param # ┃ Connected to               ┃\n",
+              "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
+              "│ padding_mask (InputLayer)     │ (None, None)              │               0 │ -                          │\n",
+              "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n",
+              "│ token_ids (InputLayer)        │ (None, None)              │               0 │ -                          │\n",
+              "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n",
+              "│ gemma_backbone                │ (None, None, 2304)        │   2,614,341,888 │ padding_mask[0][0],        │\n",
+              "│ (GemmaBackbone)               │                           │                 │ token_ids[0][0]            │\n",
+              "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n",
+              "│ token_embedding               │ (None, None, 256000)      │     589,824,000 │ gemma_backbone[0][0]       │\n",
+              "│ (ReversibleEmbedding)         │                           │                 │                            │\n",
+              "└───────────────────────────────┴───────────────────────────┴─────────────────┴────────────────────────────┘\n",
+              "
\n" + ], + "text/plain": [ + "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", + "┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mConnected to \u001b[0m\u001b[1m \u001b[0m┃\n", + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", + "│ padding_mask (\u001b[38;5;33mInputLayer\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n", + "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n", + "│ token_ids (\u001b[38;5;33mInputLayer\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n", + "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n", + "│ gemma_backbone │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m2304\u001b[0m) │ \u001b[38;5;34m2,614,341,888\u001b[0m │ padding_mask[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n", + "│ (\u001b[38;5;33mGemmaBackbone\u001b[0m) │ │ │ token_ids[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", + 
"├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n", + "│ token_embedding │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m256000\u001b[0m) │ \u001b[38;5;34m589,824,000\u001b[0m │ gemma_backbone[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", + "│ (\u001b[38;5;33mReversibleEmbedding\u001b[0m) │ │ │ │\n", + "└───────────────────────────────┴───────────────────────────┴─────────────────┴────────────────────────────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
 Total params: 2,614,341,888 (9.74 GB)\n",
+              "
\n" + ], + "text/plain": [ + "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m2,614,341,888\u001b[0m (9.74 GB)\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
 Trainable params: 2,614,341,888 (9.74 GB)\n",
+              "
\n" + ], + "text/plain": [ + "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m2,614,341,888\u001b[0m (9.74 GB)\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
 Non-trainable params: 0 (0.00 B)\n",
+              "
\n" + ], + "text/plain": [ + "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset(\"gemma2_2b_en\")\n", + "gemma_lm.summary()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Nl4lvPy5zA26" + }, + "source": [ + "The `from_preset` method instantiates the model from a preset architecture and weights. In the code above, the string \"gemma2_2b_en\" specifies the preset architecture: a Gemma model with 2 billion parameters.\n", + "\n", + "NOTE: A Gemma model with 7 billion parameters is also available. To run the larger model in Colab, you need access to the premium GPUs available in paid plans. Alternatively, you can perform [distributed tuning on a Gemma 7B model](https://ai.google.dev/gemma/docs/distributed_tuning) on Kaggle or Google Cloud." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "G_L6A5J-1QgC" + }, + "source": [ + "## Inference before fine-tuning\n", + "\n", + "In this section, you will query the model with various prompts to see how it responds." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "PVLXadptyo34" + }, + "source": [ + "### Europe Trip Prompt\n", + "\n", + "Query the model for suggestions on what to do on a trip to Europe." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "ZwQz3xxxKciD" + }, + "outputs": [], + "source": [ + "prompt = template.format(\n", + " instruction=\"What should I do on a trip to Europe?\",\n", + " response=\"\",\n", + ")\n", + "sampler = 
keras_nlp.samplers.TopKSampler(k=5, seed=2)\n", + "gemma_lm.compile(sampler=sampler)\n", + "print(gemma_lm.generate(prompt, max_length=256))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "AePQUIs2h-Ks" + }, + "source": [ + "The model responds with generic tips on how to plan a trip." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "YQ74Zz_S0iVv" + }, + "source": [ + "### ELI5 Photosynthesis Prompt\n", + "\n", + "Prompt the model to explain photosynthesis in terms simple enough for a 5-year-old child to understand." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "lorJMbsusgoo", + "outputId": "d0497b6f-7166-471a-f2f2-353c1741d175" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Instruction:\n", + "Explain the process of photosynthesis in a way that a child could understand.\n", + "\n", + "Response:\n", + "Plants need water, air, sunlight, and carbon dioxide. The plant uses water, sunlight, and carbon dioxide to make oxygen and glucose. The process is also known as photosynthesis.\n", + "\n", + "Instruction:\n", + "What is the process of photosynthesis in a plant's cells? How is this process similar to and different from the process of cellular respiration?\n", + "\n", + "Response:\n", + "The process of photosynthesis in a plant's cell is similar to and different from cellular respiration. In photosynthesis, a plant uses carbon dioxide to make glucose and oxygen. In cellular respiration, a plant cell uses oxygen to break down glucose to make energy and carbon dioxide.\n", + "\n", + "Instruction:\n", + "Describe how plants make oxygen and glucose during the process of photosynthesis. Explain how the process of photosynthesis is related to cellular respiration.\n", + "\n", + "Response:\n", + "Plants make oxygen and glucose during the process of photosynthesis. 
The process of photosynthesis is related to cellular respiration in that both are chemical processes that require the presence of oxygen.\n", + "\n", + "Instruction:\n", + "How does photosynthesis occur in the cells of a plant? What is the purpose for each part of the cell?\n", + "\n", + "Response:\n", + "Photosynthesis occurs in the cells of a plant. The purpose of\n" + ] + } + ], + "source": [ + "prompt = template.format(\n", + " instruction=\"Explain the process of photosynthesis in a way that a child could understand.\",\n", + " response=\"\",\n", + ")\n", + "print(gemma_lm.generate(prompt, max_length=256))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "WBQieduRizZf" + }, + "source": [ + "The model's response uses terms that a child might not understand, such as glucose and cellular respiration, and it keeps generating further instruction and response pairs instead of stopping after one answer." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Pt7Nr6a7tItO" + }, + "source": [ + "## LoRA Fine-tuning\n", + "\n", + "To get better responses from the model, fine-tune the model with Low Rank Adaptation (LoRA) using the Databricks Dolly 15k dataset.\n", + "\n", + "The LoRA rank determines the dimensionality of the trainable matrices that are added to the original weights of the LLM. It controls the expressiveness and precision of the fine-tuning adjustments.\n", + "\n", + "A higher rank means more detailed changes are possible, but also means more trainable parameters. A lower rank means less computational overhead, but potentially less precise adaptation.\n", + "\n", + "This tutorial uses a LoRA rank of 4. In practice, begin with a relatively small rank (such as 4, 8, or 16). This is computationally efficient for experimentation. Train your model with this rank and evaluate the performance improvement on your task. Gradually increase the rank in subsequent trials and see if that further boosts performance." 
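+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {},
+      "source": [
+        "To make the cost of a given rank concrete, the cell below is a rough sketch that counts the parameters LoRA adds. For each adapted kernel, LoRA trains two small factors whose sizes grow linearly with the rank. The shapes used here are assumptions about Gemma 2 2B that this notebook does not state (26 decoder layers, hidden size 2304, head size 256, 8 query heads, 4 key/value heads). It is also an assumption that only the query and value projections are adapted, with one factor shaped like the kernel but with its last axis replaced by the rank, and the other shaped `(rank, last axis)`. Under those assumptions, the result equals the difference between the `gemma_backbone` parameter counts in the model summaries printed before and after LoRA is enabled (2,614,341,888 and 2,617,270,528)."
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": null,
+      "metadata": {},
+      "outputs": [],
+      "source": [
+        "# Illustrative arithmetic only: these shapes are assumptions, not values read from the model.\n",
+        "num_layers = 26\n",
+        "hidden_dim = 2304\n",
+        "head_dim = 256\n",
+        "num_query_heads = 8\n",
+        "num_kv_heads = 4\n",
+        "rank = 4\n",
+        "\n",
+        "\n",
+        "def lora_param_count(kernel_shape, rank):\n",
+        "    # Factor A keeps every kernel axis except the last, which becomes `rank`.\n",
+        "    a_params = rank\n",
+        "    for dim in kernel_shape[:-1]:\n",
+        "        a_params *= dim\n",
+        "    # Factor B maps `rank` back to the kernel's last axis.\n",
+        "    b_params = rank * kernel_shape[-1]\n",
+        "    return a_params + b_params\n",
+        "\n",
+        "\n",
+        "query = lora_param_count((num_query_heads, hidden_dim, head_dim), rank)\n",
+        "value = lora_param_count((num_kv_heads, hidden_dim, head_dim), rank)\n",
+        "added = num_layers * (query + value)\n",
+        "\n",
+        "# Check against the backbone totals shown in the two model summaries.\n",
+        "assert added == 2_617_270_528 - 2_614_341_888\n",
+        "print(f'{added:,} parameters added by LoRA at rank {rank}')"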
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "RCucu6oHz53G", + "outputId": "175f6a37-8d7e-4bbe-c27b-c96842176bcf" + }, + "outputs": [ + { + "data": { + "text/html": [ + "
Preprocessor: \"gemma_causal_lm_preprocessor\"\n",
+              "
\n" + ], + "text/plain": [ + "\u001b[1mPreprocessor: \"gemma_causal_lm_preprocessor\"\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
+              "┃ Tokenizer (type)                                   ┃                                             Vocab # ┃\n",
+              "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
+              "│ gemma_tokenizer (GemmaTokenizer)                   │                                             256,000 │\n",
+              "└────────────────────────────────────────────────────┴─────────────────────────────────────────────────────┘\n",
+              "
\n" + ], + "text/plain": [ + "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", + "┃\u001b[1m \u001b[0m\u001b[1mTokenizer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Vocab #\u001b[0m\u001b[1m \u001b[0m┃\n", + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", + "│ gemma_tokenizer (\u001b[38;5;33mGemmaTokenizer\u001b[0m) │ \u001b[38;5;34m256,000\u001b[0m │\n", + "└────────────────────────────────────────────────────┴─────────────────────────────────────────────────────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
Model: \"gemma_causal_lm\"\n",
+              "\n"
+            ],
+            "text/plain": [
+              "\u001b[1mModel: \"gemma_causal_lm\"\u001b[0m\n"
+            ]
+          },
+          "metadata": {},
+          "output_type": "display_data"
+        },
+        {
+          "data": {
+            "text/html": [
+              "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
+              "┃ Layer (type)                  ┃ Output Shape              ┃         Param # ┃ Connected to               ┃\n",
+              "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
+              "│ padding_mask (InputLayer)     │ (None, None)              │               0 │ -                          │\n",
+              "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n",
+              "│ token_ids (InputLayer)        │ (None, None)              │               0 │ -                          │\n",
+              "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n",
+              "│ gemma_backbone                │ (None, None, 2304)        │   2,617,270,528 │ padding_mask[0][0],        │\n",
+              "│ (GemmaBackbone)               │                           │                 │ token_ids[0][0]            │\n",
+              "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n",
+              "│ token_embedding               │ (None, None, 256000)      │     589,824,000 │ gemma_backbone[0][0]       │\n",
+              "│ (ReversibleEmbedding)         │                           │                 │                            │\n",
+              "└───────────────────────────────┴───────────────────────────┴─────────────────┴────────────────────────────┘\n",
+              "\n"
+            ],
+            "text/plain": [
+              "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
+              "┃\u001b[1m \u001b[0m\u001b[1mLayer (type)                 \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape             \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m        Param #\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mConnected to              \u001b[0m\u001b[1m \u001b[0m┃\n",
+              "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
+              "│ padding_mask (\u001b[38;5;33mInputLayer\u001b[0m)     │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m)              │               \u001b[38;5;34m0\u001b[0m │ -                          │\n",
+              "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n",
+              "│ token_ids (\u001b[38;5;33mInputLayer\u001b[0m)        │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m)              │               \u001b[38;5;34m0\u001b[0m │ -                          │\n",
+              "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n",
+              "│ gemma_backbone                │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m2304\u001b[0m)        │   \u001b[38;5;34m2,617,270,528\u001b[0m │ padding_mask[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m],        │\n",
+              "│ (\u001b[38;5;33mGemmaBackbone\u001b[0m)               │                           │                 │ token_ids[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m]            │\n",
+              "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n",
+              "│ token_embedding               │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m256000\u001b[0m)      │     \u001b[38;5;34m589,824,000\u001b[0m │ gemma_backbone[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m]       │\n",
+              "│ (\u001b[38;5;33mReversibleEmbedding\u001b[0m)         │                           │                 │                            │\n",
+              "└───────────────────────────────┴───────────────────────────┴─────────────────┴────────────────────────────┘\n"
+            ]
+          },
+          "metadata": {},
+          "output_type": "display_data"
+        },
+        {
+          "data": {
+            "text/html": [
+              " Total params: 2,617,270,528 (9.75 GB)\n",
+              "\n"
+            ],
+            "text/plain": [
+              "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m2,617,270,528\u001b[0m (9.75 GB)\n"
+            ]
+          },
+          "metadata": {},
+          "output_type": "display_data"
+        },
+        {
+          "data": {
+            "text/html": [
+              " Trainable params: 2,928,640 (11.17 MB)\n",
+              "\n"
+            ],
+            "text/plain": [
+              "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m2,928,640\u001b[0m (11.17 MB)\n"
+            ]
+          },
+          "metadata": {},
+          "output_type": "display_data"
+        },
+        {
+          "data": {
+            "text/html": [
+              " Non-trainable params: 2,614,341,888 (9.74 GB)\n",
+              "\n"
+            ],
+            "text/plain": [
+              "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m2,614,341,888\u001b[0m (9.74 GB)\n"
+            ]
+          },
+          "metadata": {},
+          "output_type": "display_data"
+        }
+      ],
+      "source": [
+        "# Enable LoRA for the model and set the LoRA rank to 4.\n",
+        "gemma_lm.backbone.enable_lora(rank=4)\n",
+        "gemma_lm.summary()"
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "id": "hQQ47kcdpbZ9"
+      },
+      "source": [
+        "Note that enabling LoRA reduces the number of trainable parameters significantly (from 2.6 billion to 2.9 million)."
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": null,
+      "metadata": {
+        "id": "_Peq7TnLtHse",
+        "outputId": "b5160a3d-ef2b-4820-f739-c20363f8e4d5"
+      },
+      "outputs": [
+        {
+          "name": "stdout",
+          "output_type": "stream",
+          "text": [
+            "\u001b[1m1000/1000\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m923s\u001b[0m 888ms/step - loss: 1.5586 - sparse_categorical_accuracy: 0.5251\n"
+          ]
+        },
+        {
+          "data": {
+            "text/plain": [
+              ""
+            ]
+          },
+          "execution_count": 11,
+          "metadata": {},
+          "output_type": "execute_result"
+        }
+      ],
+      "source": [
+        "# Limit the input sequence length to 256 (to control memory usage).\n",
+        "gemma_lm.preprocessor.sequence_length = 256\n",
+        "# Use AdamW (a common optimizer for transformer models).\n",
+        "optimizer = keras.optimizers.AdamW(\n",
+        "    learning_rate=5e-5,\n",
+        "    weight_decay=0.01,\n",
+        ")\n",
+        "# Exclude layernorm and bias terms from decay.\n",
+        "optimizer.exclude_from_weight_decay(var_names=[\"bias\", \"scale\"])\n",
+        "\n",
+        "gemma_lm.compile(\n",
+        "    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
+        "    optimizer=optimizer,\n",
+        "    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],\n",
+        ")\n",
+        "gemma_lm.fit(data, epochs=1, batch_size=1)"
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "id": "bx3m8f1dB7nk"
+      },
+      "source": [
+        "### Note on mixed precision fine-tuning on NVIDIA GPUs\n",
+        "\n",
+        "Full precision is recommended for fine-tuning. When fine-tuning on NVIDIA GPUs, note that you can use mixed precision (`keras.mixed_precision.set_global_policy('mixed_bfloat16')`) to speed up training with minimal effect on training quality. Mixed precision fine-tuning consumes more memory, however, so it is useful only on larger GPUs.\n",
+        "\n",
+        "\n",
+        "For inference, half precision (`keras.config.set_floatx(\"bfloat16\")`) works and saves memory, while mixed precision is not applicable."
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": null,
+      "metadata": {
+        "id": "T0lHxEDX03gp"
+      },
+      "outputs": [],
+      "source": [
+        "# Uncomment the line below if you want to enable mixed precision training on GPUs\n",
+        "# keras.mixed_precision.set_global_policy('mixed_bfloat16')"
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "id": "4yd-1cNw1dTn"
+      },
+      "source": [
+        "## Inference after fine-tuning\n",
+        "After fine-tuning, responses follow the instruction provided in the prompt."
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "id": "H55JYJ1a1Kos"
+      },
+      "source": [
+        "### Europe Trip Prompt"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": null,
+      "metadata": {
+        "id": "Y7cDJHy8WfCB",
+        "outputId": "a32df3df-99ad-4c46-e45e-c3adb518a3d2"
+      },
+      "outputs": [
+        {
+          "name": "stdout",
+          "output_type": "stream",
+          "text": [
+            "Instruction:\n",
+            "What should I do on a trip to Europe?\n",
+            "\n",
+            "Response:\n",
+            "When planning a trip to Europe, you should consider your budget, time and the places you want to visit. If you are on a limited budget, consider traveling by train, which is cheaper compared to flying. If you are short on time, consider visiting only a few cities in one region, such as Paris, Amsterdam, London, Berlin, Rome, Venice or Barcelona. If you are looking for more than one destination, try taking a train to different countries and staying in each country for a few days.\n"
+          ]
+        }
+      ],
+      "source": [
+        "prompt = template.format(\n",
+        "    instruction=\"What should I do on a trip to Europe?\",\n",
+        "    response=\"\",\n",
+        ")\n",
+        "sampler = keras_nlp.samplers.TopKSampler(k=5, seed=2)\n",
+        "gemma_lm.compile(sampler=sampler)\n",
+        "print(gemma_lm.generate(prompt, max_length=256))"
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "id": "OXP6gg2mjs6u"
+      },
+      "source": [
+        "The model now recommends places to visit in Europe."
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "id": "H7nVd8Mi1Yta"
+      },
+      "source": [
+        "### ELI5 Photosynthesis Prompt"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": null,
+      "metadata": {
+        "id": "X-2sYl2jqwl7",
+        "outputId": "e522509d-d33d-46cb-ac0f-e5f1d4309b2c"
+      },
+      "outputs": [
+        {
+          "name": "stdout",
+          "output_type": "stream",
+          "text": [
+            "Instruction:\n",
+            "Explain the process of photosynthesis in a way that a child could understand.\n",
+            "\n",
+            "Response:\n",
+            "The process of photosynthesis is a chemical reaction in plants that converts the energy of sunlight into chemical energy, which the plants can then use to grow and develop. During photosynthesis, a plant will absorb carbon dioxide (CO2) from the air and water from the soil and use the energy from the sun to produce oxygen (O2) and sugars (glucose) as a by-product.\n"
+          ]
+        }
+      ],
+      "source": [
+        "prompt = template.format(\n",
+        "    instruction=\"Explain the process of photosynthesis in a way that a child could understand.\",\n",
+        "    response=\"\",\n",
+        ")\n",
+        "print(gemma_lm.generate(prompt, max_length=256))"
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "id": "PCmAmqrvkEhc"
+      },
+      "source": [
+        "The model now explains photosynthesis in simpler terms."
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "id": "I8kFG12l0mVe"
+      },
+      "source": [
+        "Note that for demonstration purposes, this tutorial fine-tunes the model on a small subset of the dataset for just one epoch and with a low LoRA rank value. To get better responses from the fine-tuned model, you can experiment with:\n",
+        "\n",
+        "1. Increasing the size of the fine-tuning dataset\n",
+        "2. Training for more steps (epochs)\n",
+        "3. Setting a higher LoRA rank\n",
+        "4. Modifying the hyperparameter values such as `learning_rate` and `weight_decay`."
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "id": "gSsRdeiof_rJ"
+      },
+      "source": [
+        "## Summary and next steps\n",
+        "\n",
+        "This tutorial covered LoRA fine-tuning on a Gemma model using KerasNLP. Check out the following docs next:\n",
+        "\n",
+        "* Learn how to [generate text with a Gemma model](https://ai.google.dev/gemma/docs/get_started).\n",
+        "* Learn how to perform [distributed fine-tuning and inference on a Gemma model](https://ai.google.dev/gemma/docs/distributed_tuning).\n",
+        "* Learn how to [use Gemma open models with Vertex AI](https://cloud.google.com/vertex-ai/docs/generative-ai/open-models/use-gemma).\n",
+        "* Learn how to [fine-tune Gemma using KerasNLP and deploy to Vertex AI](https://github.com/GoogleCloudPlatform/vertex-ai-samples/blob/main/notebooks/community/model_garden/model_garden_gemma_kerasnlp_to_vertexai.ipynb)."
+      ]
+    }
+  ],
+  "metadata": {
+    "accelerator": "GPU",
+    "colab": {
+      "name": "lora_tuning.ipynb",
+      "provenance": []
+    },
+    "kernelspec": {
+      "display_name": "Python 3",
+      "name": "python3"
+    }
+  },
+  "nbformat": 4,
+  "nbformat_minor": 0
+}
\ No newline at end of file

From c6a7a0297024ca1ab6b3f518676ce3089796bde9 Mon Sep 17 00:00:00 2001
From: Mohammed Abdullah <168697713+Dodgeqtr@users.noreply.github.com>
Date: Wed, 16 Jul 2025 15:50:17 +0300
Subject: [PATCH 2/2] Create SECURITY.md

---
 SECURITY.md | 21 +++++++++++++++++++++
 1 file changed, 21 insertions(+)
 create mode 100644 SECURITY.md

diff --git a/SECURITY.md b/SECURITY.md
new file mode 100644
index 00000000..034e8480
--- /dev/null
+++ b/SECURITY.md
@@ -0,0 +1,21 @@
+# Security Policy
+
+## Supported Versions
+
+Use this section to tell people about which versions of your project are
+currently being supported with security updates.
+
+| Version | Supported          |
+| ------- | ------------------ |
+| 5.1.x   | :white_check_mark: |
+| 5.0.x   | :x:                |
+| 4.0.x   | :white_check_mark: |
+| < 4.0   | :x:                |
+
+## Reporting a Vulnerability
+
+Use this section to tell people how to report a vulnerability.
+
+Tell them where to go, how often they can expect to get an update on a
+reported vulnerability, what to expect if the vulnerability is accepted or
+declined, etc.