diff --git a/recognition/Project13-TristanGreen/.gitignore b/recognition/Project13-TristanGreen/.gitignore new file mode 100644 index 000000000..4096bf271 --- /dev/null +++ b/recognition/Project13-TristanGreen/.gitignore @@ -0,0 +1,2 @@ +/runs +/__pycache__ \ No newline at end of file diff --git a/recognition/Project13-TristanGreen/README.md b/recognition/Project13-TristanGreen/README.md new file mode 100644 index 000000000..5cab90397 --- /dev/null +++ b/recognition/Project13-TristanGreen/README.md @@ -0,0 +1,369 @@ +

![Logo](assets/images/braint5.png)

**Brain-T5: A lightweight model fine-tuned for simplifying medical jargon using FLAN-T5 and LoRA.**

Brain-T5 is a lightweight language model designed to translate technical clinical and biomedical text into layperson summaries that non-experts can understand. Built on top of FLAN-T5 with LoRA fine-tuning, it is deployable on consumer-grade GPUs, supports researchers approaching medicine from other disciplines, and acts as an assistant for patient communication. This repository includes the full training, evaluation and inference pipelines, from dataset intake to an interactive chat mode.

## Project Motivation:
Between medical professionals and the average person or a researcher from another discipline, what counts as "standard language" does not cross over very well. Jargon is used heavily inside the medical world, which can leave outsiders struggling to understand basic summaries, research abstracts/results, or diagnostic reports. The only tools that currently fit this use case effectively are large language models such as OpenAI's GPT-3+, Google's Gemini, Anthropic's Claude Sonnet and others; however, they cannot easily be run locally on consumer-grade hardware and may use submitted conversational data to train their models. Many medical institutions do not want their data to cross borders, making a local option preferable.

Brain-T5 aims to close this gap by:

- Using LoRA to fine-tune an existing, reliable text model.
- Using the lightweight FLAN-T5 model from Hugging Face, fine-tuned on reliable medical summarisations.
- Providing an open-source, easy-to-install model that runs on average consumer-grade hardware.

Brain-T5 is a step toward bridging the gap between the average person and medical knowledge, and aims to support both clinical practice and interdisciplinary research.

## Features:
* **LoRA-based fine-tuning** - train large models on consumer-grade GPUs.
* **Supports HuggingFace datasets, CSV, JSON** - flexible with data types.
* **Built-in ROUGE evaluation** - automatic scoring after each training epoch.
* **Interactive chat CLI (`chat.py`)** - real-time inference like a medical assistant.
* **Modular codebase** - easy to extend or adapt to other domains (legal, finance, etc.)

## Project Structure
```
├── train.py     # Full training pipeline with metrics and logging
├── predict.py   # Batch inference on JSONL or single text
├── chat.py      # Interactive CLI for conversational testing
├── modules.py   # Tokeniser/model loaders and LoRA attachment
├── dataset.py   # Dataset wrapper and fast collator
├── runs/        # LoRA adapters & metrics saved here
└── README.md
```

## Installation:

We recommend using a virtual environment for your install. See [here](https://www.w3schools.com/python/python_virtualenv.asp) for a Windows/macOS/Linux tutorial on creating one.

```
# First, clone the repository and check out the topic-recognition branch.
git clone -b topic-recognition https://github.com/TPGCIG/PatternAnalysis-2025/

# Change directory to the Brain-T5 project.
cd PatternAnalysis-2025/recognition/Project13-TristanGreen

# Activate your virtual environment here if you are using one - see the linked tutorial for your OS.

# Install the dependencies.
pip install -r requirements.txt
```

`torch` is not included in this install; install it separately by following [https://pytorch.org/get-started/locally](https://pytorch.org/get-started/locally). This project uses PyTorch 2.8.0 and **strongly recommends** CUDA 12.6.

You're now ready to go!
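Before training, a quick check that the pinned libraries and the GPU are visible to Python can save time. This is an optional snippet (not part of the repository) that only assumes the packages from `requirements.txt` plus your separately installed `torch`:

```py
# Optional environment check (not part of the repository).
import torch
import transformers
import peft

print("torch:", torch.__version__)                # 2.8.0 expected per the install notes
print("transformers:", transformers.__version__)  # 4.57.0 pinned in requirements.txt
print("peft:", peft.__version__)                  # 0.17.1 pinned in requirements.txt
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("GPU:", torch.cuda.get_device_name(0))
```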
### Dependencies (named as they are installable via pip)
* transformers 4.57.0
* peft 0.17.1
* evaluate 0.4.6
* datasets 4.1.1
* tqdm 4.67.1
* numpy 2.2.5
* matplotlib 3.10.5
* absl-py 2.3.1
* nltk 3.9.2
* rouge_score 0.1.2

## Training Usage

### 1) Quick-start command
```bash
python train.py --output_dir [dir_name]
```

### 2) Fine-grained control over training parameters
```bash
python train.py --output_dir runs/flan_t5_base_lora_biolaysumm --batch_size 1 --accum 16 --epochs 3 --lr 2e-4
```

### 3) What the script actually does
- Builds the tokenizer + tokenized datasets via `make_datasets(...)`, which loads the BioLaySumm dataset from the Hugging Face Hub.
- Performs an **80/10/10 train–validation–test split** automatically when `--self_split` is used, ensuring there is no data leakage between training and evaluation sets.
- Attaches **LoRA** adapters to FLAN-T5 and trains with AdamW + a cosine schedule.
- Evaluates with **ROUGE** at the end of each epoch.
- Saves the best adapters + tokenizer to `--output_dir`, along with `metrics.json`, `train_log.csv` and per-epoch plots of `loss` and `ROUGE` scores.

### 4) Arguments

- **Batching**: `--batch_size`, `--accum` (effective batch = batch_size × accum)
- **Optim**: `--lr`, `--weight_decay`, `--warmup_steps`, `--clip`
- **LoRA**: `--lora_r`, `--lora_alpha`, `--lora_dropout`
- **Eval**: `--eval_batch_size`, `--eval_max_new_tokens`, `--eval_beams`
- **Limiters**: `--max_train_samples`, `--max_eval_samples`, `--max_test_samples`
- **Data Splitting**:
  - `--self_split` automatically performs an **80/10/10** train–validation–test division when a pre-defined test split is unavailable (proportions adjustable via `--self_split_val` and `--self_split_test`).
  - `--train_split`, `--val_split`, `--test_split` select which named dataset splits to use.
  - Prevents **data leakage** by ensuring all splits are loaded and cached independently.
- **Misc**: `--epochs`, `--seed`, `--fp16`

### 5) Outputs
Inside your `--output_dir`:
```
runs/<output_dir>/
├── adapter_config.json        # LoRA adapter setup (rank, alpha, target modules)
├── adapter_model.safetensors  # Actual trained LoRA weight deltas
├── hardware.json              # GPU name, VRAM, and compute capability info
├── history_val.csv            # Validation ROUGE scores per epoch (for plotting)
├── metrics_test.json          # Final held-out test ROUGE scores
├── metrics_val.json           # Best validation epoch and its ROUGE metrics
├── special_tokens_map.json    # Special token IDs such as <pad> and </s> (auto from tokenizer)
├── README.md                  # Auto-generated PEFT model card (safe to delete)
├── tokenizer.json             # Full tokenizer vocab + merges
├── tokenizer_config.json      # Tokenizer settings (truncation, padding, etc.)
├── time.json                  # Total training time and epochs elapsed
├── rouge_val_curve.png        # Line plot of validation ROUGE vs epoch
├── rouge_test_bar.png         # Bar chart of test ROUGE metrics
├── loss_curve.png             # Smoothed training loss vs steps curve
├── params.json                # Total vs trainable parameter counts (LoRA ratio)
├── metrics.json               # Duplicate or summary of best ROUGE metrics
└── train_log.csv              # Step-wise loss log for generating loss_curve
```
You should also see console logs during training like:
```
[epoch 1] ROUGE: {'rouge1': ..., 'rouge2': ..., 'rougeL': ..., 'rougeLsum': ...}
```

### 6) Use the trained adapters

We **highly recommend** using `chat.py` to talk to the model you've trained (heavily inspired by OpenAI's ChatGPT):

```bash
python chat.py --model_dir runs/<output_dir>
```

`predict.py` is also available:

Single text:
```bash
python predict.py --adapter_dir runs/<output_dir> --text "Put clinical text here" --fp16
```
Batch JSONL:
```bash
python predict.py --adapter_dir runs/<output_dir> --jsonl dev.jsonl --input_col report --out_path predictions.jsonl
```

### 7) General usage tips
- If you hit CUDA OOM: lower `--batch_size` (raising `--accum` to keep the effective batch size), or drop `--fp16` if your GPU cannot handle the defaults.
- If ROUGE is flat, your data columns are probably wrong. Print a few samples.
- If `runs/` is empty, you never beat your previous best; check the learning rate and dataset.

## Chat Usage:

### Examples
> This sample was trained for 7 epochs over a 39-hour period; better results can come from longer training and adjusted parameters. The ROUGE scores for this training are shown in [Training Usage](#training-usage).

You: “CT angiography shows ~70% stenosis of the proximal LAD; plan for PCI if symptoms persist despite maximal medical therapy.”
+Model: “There’s a notable blockage in a main heart artery. The team suggests a procedure if medicines don’t help enough.” + +You: “MRI brain: multiple periventricular T2/FLAIR hyperintensities suggestive of demyelination; correlate clinically for MS.”
+Model: “Scans show several spots in brain areas often linked to nerve coating problems. Doctors will compare this with symptoms.” + +You: “HbA1c 9.2% despite metformin 1g BID; initiate basal insulin 10u qHS and titrate to fasting 5–7 mmol/L.”
+Model: “Blood sugar control isn’t good with current meds. Start a nightly insulin dose and adjust to get morning sugars into a better range.” + +You: “Echo: LVEF 35%, global hypokinesis, mild MR; optimize GDMT—ACEi/ARB, beta-blocker, MRA.”
Model: “Heart pumping is weaker than normal with some valve leak. Use standard lung medicines to improve function.”

### Error Analysis
The model consistently drops critical figures (example 1: ~70%, example 3: 9.2%) and summarises responses with shallow descriptions that lose much of their meaning. This is likely because, while T5 uses attention out of the box, it never learns to prioritise clinically critical details like these: the base model is pretrained on a corpus derived from [Common Crawl](https://commoncrawl.org/), which does not emphasise numeric values or risk markers. In downstream fine-tuning, the model should be explicitly taught that such figures carry substantial meaning and deserve more weight than the surrounding tokens.

The model also hallucinates occasionally (example 4: *use standard lung medicines* in a heart context) and can lose track of the context. This is obviously problematic, and downstream fine-tuning should push the model to attend closely to the context, since mistakes of this nature can cause real harm in the medical field.
These limitations, while problematic, can be mitigated through further training on consumer-grade hardware. Also, under proper supervision from a medical professional, such errors can be quickly identified and flagged.

## Evaluation Metrics
### ROUGE Evaluation
Brain-T5’s summarization quality is evaluated using **ROUGE** (Recall-Oriented Understudy for Gisting Evaluation), the standard metric for text summarization.
ROUGE measures the degree of overlap between the model’s generated summaries and the ground-truth human-written ones, capturing how well the model reproduces key phrases and sentence structure from the reference.

### How It Works

In `train.py`, the evaluation loop computes ROUGE using:

```py
scores = rouge_metric.compute(predictions=preds, references=refs, use_stemmer=True)
```

This returns four main sub-metrics:

* ROUGE-1 – overlap of unigrams (single words).

* ROUGE-2 – overlap of bigrams (two-word sequences).

* ROUGE-L – measures the Longest Common Subsequence (LCS) between prediction and reference.

* ROUGE-Lsum – a sentence-level variant of ROUGE-L, emphasizing structural similarity in multi-sentence outputs.

Each score ranges from 0 to 1, where higher is better. ROUGE can be computed in terms of precision, recall, and F1, but this project uses the F1 form returned by the Hugging Face `evaluate` package, balancing both correctness and completeness.

### Why It Matters

* ROUGE-1 reflects general lexical similarity — whether the model uses similar vocabulary.

* ROUGE-2 indicates phrase-level fluency — capturing short-range coherence.

* ROUGE-L and ROUGE-Lsum capture long-range structure and sentence organization — essential for readability and factual flow in lay summaries.

In practice, ROUGE-Lsum serves as the primary checkpoint criterion in train.py:

```py
if rougeLsum > best_rougeLsum:
    model.save_pretrained(args.output_dir)
```
meaning the “best” model is whichever epoch achieves the highest ROUGE-Lsum score on validation.

### Interpretation

A strong model shows:

* Gradual increases in ROUGE-1/2/L/Lsum across epochs.

* Consistent correlation between lower loss and higher ROUGE scores.

* Minimal gap between validation and test ROUGE, indicating good generalization.

## Training Results:
Training was performed on the BioLaySumm 2025 - LaymanRRG opensource track, using FLAN-T5-Base with LoRA fine-tuning for 3 epochs.
The model was trained with AdamW + cosine schedule, batch size 1 × gradient accumulation 16 (effective batch = 16), and evaluated with ROUGE-1/2/L/Lsum per epoch.

The full run passed over the entire dataset (150,000 datapoints) per epoch, while the medium-zoom run used a much smaller dataset (2,000 datapoints) with a higher epoch count.

1. Training Loss (full run)

![Training loss vs steps (full run)](assets/images/loss_curve_full.png)

This plot shows the training loss vs optimizer steps over the entire fine-tuning run.
The curve steadily declines and stabilises, showing smooth convergence without major oscillation — indicating that:

* The learning rate and warm-up schedule were well-tuned.

* Gradient accumulation was effective in maintaining numerical stability during training.

* No gradient explosions or plateaus occurred.

2. Training Loss (medium zoom)

![Training loss vs steps (medium zoom)](assets/images/loss_curve_med.png)

This is a zoomed-in view of the mid-training regime, showing finer granularity of step-wise noise.
Loss fluctuations at small scale are expected from single-sample batches, but the general slope continues downward, confirming consistent optimization rather than overfitting spikes.

3. Validation ROUGE Progress (full run)

![Validation ROUGE vs epoch (full run)](assets/images/rouge_val_curve_full.png)

This figure tracks ROUGE-1, ROUGE-2, ROUGE-L, and ROUGE-Lsum per epoch.

Interpretation:

* ROUGE-1 and ROUGE-L steadily improve and plateau by the third epoch, showing that lexical and long-span coherence both increased.

* ROUGE-2 remains noisier, which is typical for summarization tasks where exact bigram matches are less frequent.
* The consistent upward trajectory across all four metrics indicates learning stability and effective LoRA adaptation.

4. Validation ROUGE (medium zoom)

![Validation ROUGE vs epoch (medium zoom)](assets/images/rouge_val_curve_med.png)

This mid-range view highlights the epoch-to-epoch change more clearly:

* Rapid early gains in the first epoch.

* Smaller, diminishing returns after epoch 2, suggesting convergence.

* No regression in ROUGE-Lsum, evidence that the checkpoint selected (highest ROUGE-Lsum) indeed corresponds to the best score seen during training.

Overall, Brain-T5 demonstrates reliable convergence and solid generalisation across validation and test splits.
The model maintains smooth training dynamics and rising ROUGE performance without evidence of overfitting or divergence — validating the correctness of the pipeline in train.py and the dataset tokenization logic in dataset.py. Furthermore, full passes over the dataset converge to higher ROUGE values, making longer passes with fewer epochs preferable to smaller passes with more epochs.

### Pre-processing

We keep the input text *as written* (no lowercasing, stop-word removal or punctuation stripping) and rely on the FLAN-T5 tokenizer to handle normalization. Concretely:

- **Instruction prefix:** each input is prepended with a task prompt (e.g., `summarize: `) so the format matches FLAN-T5’s instruction-tuning.
- **Tokenization:** Hugging Face’s fast tokenizer for FLAN-T5 (SentencePiece) encodes inputs/targets; the pad token is set to the EOS token when missing (a T5 requirement). Inputs are truncated/padded to **1024** tokens; targets to **256** tokens.
- **Label masking:** target padding positions are set to **-100** so they are ignored by the cross-entropy loss during training.
- **Decoding safety:** we ensure `decoder_start_token_id` is defined so generation starts from a valid token.

### Train/Validation/Test splits and hyperparameters — justification

- **Official splits when available.** If the dataset provides `train/validation/test`, we honor them exactly.
- **Self-split when the test lacks references.** With `--self_split`, we create an **80/10/10** split *from the training partition only* to avoid leakage. A fixed random seed makes the partition reproducible.
- **Why an 80/10/10 split?** It gives the model the bulk of the data for parameter estimation (80%), a sufficient **validation** slice (10%) for model selection/early-stopping by ROUGE-Lsum, and a **held-out test** slice (10%) that is touched **once** at the end for unbiased reporting. Using different balances (e.g. 33/33/34) starves the component that needs the most data: training. If most of the dataset is committed to evaluation, the model is starved of training signal and severely underperforms. The 20% reserved for validation and test is valuable for evaluation, but evaluation only measures model quality; it does not improve it the way training does.
- **Learning rate of 2e-4** - Conservative base LR for **LoRA-only** updates on T5-base. Stable with cosine decay and accumulation on small batches.
- **Weight decay of 0.01** - Light L2 to regularise LoRA adapters (AdamW). Keeps updates from drifting without fighting low-rank adaptation.
- **lora_r=8, lora_alpha=16, lora_dropout=0.05** - Balanced adapter capacity vs stability. The **effective update scale** is roughly `lr × (alpha / r)` (= 2× lr here). `r=8` is a common sweet spot for T5-base; small dropout helps generalisation without fighting instruction-tuned priors.
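These values map directly onto the PEFT configuration built by `attach_lora` in `modules.py`; the following is a condensed sketch of that call, using the defaults exposed by `train.py`:

```py
# Condensed sketch of attach_lora() in modules.py, using the default hyperparameters from train.py.
from transformers import AutoModelForSeq2SeqLM
from peft import LoraConfig, get_peft_model

base = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-base")
cfg = LoraConfig(
    r=8,                                  # --lora_r: adapter rank
    lora_alpha=16,                        # --lora_alpha: scaling (alpha / r = 2 here)
    lora_dropout=0.05,                    # --lora_dropout: light regularisation
    target_modules=["q", "k", "v", "o"],  # T5 attention projections
    bias="none",
    task_type="SEQ_2_SEQ_LM",
)
model = get_peft_model(base, cfg)
model.print_trainable_parameters()        # only the LoRA adapters receive gradients
```

Because only these low-rank adapter matrices are trainable, the trainable-parameter count reported in `params.json` stays a small fraction of the full model.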
+ + +## Dataset +Brain-T5 is trained on the BioLaySumm 2025 – LaymanRRG (Open-Source Track) dataset, hosted on Hugging Face under the identifier [BioLaySumm/BioLaySumm2025-LaymanRRG-opensource-track](https://huggingface.co/datasets/BioLaySumm/BioLaySumm2025-LaymanRRG-opensource-track). + +This dataset is specifically curated for the layperson summarisation of biomedical text. +Each entry contains a technical radiology report paired with a human-written lay summary, enabling fine-tuning of models for domain translation between clinical and plain language. + +### Structure + +Each example includes two main fields: + +* `radiology_report`: The source input - detailed, jargon-heavy text extracted from radiology or clinical notes. +* `layman_report`: The target output - a simplified explanation written for a general audience. + +During preprocessing, dataset.py automatically prefixes each input with "summarize: " for FLAN-T5 instruction consistency, tokenizes both columns using the model’s tokenizer, and pads sequences for batch training. + +### Why BioLaySumm? + +BioLaySumm provides: + +* Authentic biomedical phrasing, exposing the model to realistic clinical structure and terminology. + +* Human-validated lay summaries, ensuring stylistic and semantic accuracy for non-expert readability. + +* Consistent formatting, ideal for instruction-based models like FLAN-T5 that thrive on aligned input/output pairs. + +Together, these qualities make BioLaySumm the ideal foundation for training Brain-T5 to bridge the gap between clinical documentation and human-understandable summaries. + + +# The FLAN-T5 Model +## What is T5? +T5 (Text-to-Text Transfer Transformer) is a transformer model built completely on a text-to-text framework. This framework treats every task in Natural Language Processing (NLP), whether it be machine translation, summarisation, or question-answering, as a process of taking text as input and producing text as output. This unification allows the same model architecture, objective function, and training procedure to be applied across all tasks, massively simplifying the entire NLP training pipeline. + + +The super summarised explanation on how the model works (provided by OpenAI's ChatGPT) is: + + +1. Tokenise → Embed → + Position. Words become vectors, add position info so the model knows order. +2. Encoder (repeated N×): + - Self-attention: each word looks at all other words to decide what matters. + - Feed-forward: a per-token mini-MLP to transform features. + - Add & Norm: residual skip + layer norm to keep training stable.
   Output = contextual vectors for every input token.

3. Decoder input: start with the decoder start token (`<pad>` for T5) and the previously generated tokens shifted right.
4. Decoder block (repeated N×):
   - Masked self-attention: looks only at past output tokens (the mask stops it peeking at the future).
   - Cross-attention: queries the encoder outputs so the decoder can “look up” relevant parts of the input.
   - Feed-forward and Add & Norm again.

5. Linear → Softmax: turn the decoder’s last vector into a probability distribution over the vocabulary; pick the next token; loop 4–5 until done.

![T5 text-to-text framework](assets/images/t5simple.png)

This is a representation of how T5 unifies all forms of text-to-text input/output to heavily generalise its use case and simplify learning.

## What is FLAN-T5?
FLAN-T5 (Fine-tuned LAnguage Net T5) is an enhanced version of the original [T5](https://medium.com/analytics-vidhya/t5-a-detailed-explanation-a0ac9bc53e51), further fine-tuned using a technique called instruction tuning.
During training, FLAN-T5 is exposed to a massive number of tasks that are all formatted as natural language instructions (e.g. "Answer the following question: ..."). This training paradigm significantly improves the model's ability to:

1. **Follow instructions**, since it is trained on user-style prompts instead of general text data.
2. **Generalise**, since exposure to many prompted task formats helps it answer new tasks it previously couldn't.
3. **Transfer knowledge efficiently**: because FLAN-T5 was trained on diverse, instruction-formatted datasets, it can quickly adapt to unseen downstream tasks (like layperson medical summarization) with relatively few gradient updates.
4. **Reduce hallucination and bias**, as instruction tuning encourages the model to anchor its responses to explicit prompts, producing more deterministic and context-aware outputs than the raw pretrained T5.

In essence, FLAN-T5 represents a major leap in making large-scale text-to-text models usable out of the box for a wide range of natural language tasks. Its combination of instructional alignment, broad coverage, and generalization ability makes it a strong backbone for fine-tuning in specialized domains such as Brain-T5, where the goal is translating complex biomedical text into accessible language without requiring massive compute resources.
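As a concrete illustration of that out-of-the-box usability, the base model can be queried with nothing more than `transformers`. This is a minimal sketch; the clinical sentence is taken from the chat examples above, and the generation settings mirror the defaults used in `chat.py`:

```py
# Minimal sketch: querying base FLAN-T5 before any LoRA fine-tuning.
# Brain-T5 prepends the same "summarize: " prefix during training and inference.
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

tok = AutoTokenizer.from_pretrained("google/flan-t5-base")
model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-base")

text = "summarize: CT angiography shows ~70% stenosis of the proximal LAD; plan for PCI if symptoms persist."
enc = tok(text, return_tensors="pt", truncation=True, max_length=1024)
out = model.generate(**enc, max_new_tokens=64, num_beams=4, no_repeat_ngram_size=3)
print(tok.decode(out[0], skip_special_tokens=True))
```

Brain-T5 simply layers LoRA adapters trained on BioLaySumm on top of this same loading path (see `chat.py` and `predict.py`).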
+ + diff --git a/recognition/Project13-TristanGreen/assets/images/braint5.png b/recognition/Project13-TristanGreen/assets/images/braint5.png new file mode 100644 index 000000000..cab463a0a Binary files /dev/null and b/recognition/Project13-TristanGreen/assets/images/braint5.png differ diff --git a/recognition/Project13-TristanGreen/assets/images/loss_curve_full.png b/recognition/Project13-TristanGreen/assets/images/loss_curve_full.png new file mode 100644 index 000000000..8a64c40e9 Binary files /dev/null and b/recognition/Project13-TristanGreen/assets/images/loss_curve_full.png differ diff --git a/recognition/Project13-TristanGreen/assets/images/loss_curve_med.png b/recognition/Project13-TristanGreen/assets/images/loss_curve_med.png new file mode 100644 index 000000000..3a1067539 Binary files /dev/null and b/recognition/Project13-TristanGreen/assets/images/loss_curve_med.png differ diff --git a/recognition/Project13-TristanGreen/assets/images/rouge_val_curve_full.png b/recognition/Project13-TristanGreen/assets/images/rouge_val_curve_full.png new file mode 100644 index 000000000..e820472c4 Binary files /dev/null and b/recognition/Project13-TristanGreen/assets/images/rouge_val_curve_full.png differ diff --git a/recognition/Project13-TristanGreen/assets/images/rouge_val_curve_med.png b/recognition/Project13-TristanGreen/assets/images/rouge_val_curve_med.png new file mode 100644 index 000000000..de313bae5 Binary files /dev/null and b/recognition/Project13-TristanGreen/assets/images/rouge_val_curve_med.png differ diff --git a/recognition/Project13-TristanGreen/assets/images/t5architecture.jpg b/recognition/Project13-TristanGreen/assets/images/t5architecture.jpg new file mode 100644 index 000000000..a970ff094 Binary files /dev/null and b/recognition/Project13-TristanGreen/assets/images/t5architecture.jpg differ diff --git a/recognition/Project13-TristanGreen/assets/images/t5simple.png b/recognition/Project13-TristanGreen/assets/images/t5simple.png new file mode 100644 index 000000000..507b18870 Binary files /dev/null and b/recognition/Project13-TristanGreen/assets/images/t5simple.png differ diff --git a/recognition/Project13-TristanGreen/chat.py b/recognition/Project13-TristanGreen/chat.py new file mode 100644 index 000000000..74add8607 --- /dev/null +++ b/recognition/Project13-TristanGreen/chat.py @@ -0,0 +1,58 @@ +""" +------------------------------------------------------------ + Interactive CLI for Brain-T5 (Chat Mode) + ----------------------------------------------------------- + Description: + Lightweight interface for real-time summarization queries. + Runs inference loop over the fine-tuned LoRA FLAN-T5 model. + + Usage: + $ python chat.py --model_dir runs/flan_t5_base_lora_biolaysumm + + Notes: + - Press Enter to re-prompt; type 'exit' or 'quit' to stop. 
+------------------------------------------------------------ +""" +import torch +from transformers import AutoTokenizer, AutoModelForSeq2SeqLM +from peft import PeftModel +import argparse + +p = argparse.ArgumentParser() + +p.add_argument("--model_dir", required=True) + +# --- config --- +ADAPTER_DIR = p.parse_args().model_dir +BASE_MODEL = "google/flan-t5-base" +DEVICE = "cuda" if torch.cuda.is_available() else "cpu" +PREFIX = "summarize: " +MAX_INPUT_LEN = 1024 +MAX_NEW_TOKENS = 256 +NUM_BEAMS = 4 + + +# --- load model --- +print("Loading model...") +tok = AutoTokenizer.from_pretrained(ADAPTER_DIR) +base = AutoModelForSeq2SeqLM.from_pretrained(BASE_MODEL) +model = PeftModel.from_pretrained(base, ADAPTER_DIR).to(DEVICE).eval() +print("Ready.") + +# --- chat loop --- +while True: + user = input("\n🧠 You: ").strip() + if not user: + continue + if user.lower() in {"exit", "quit", "q"}: + print("Bye.") + break + + enc = tok(PREFIX + user, return_tensors="pt", truncation=True, max_length=MAX_INPUT_LEN).to(DEVICE) + with torch.inference_mode(): + out = model.generate(**enc, + max_new_tokens=MAX_NEW_TOKENS, + num_beams=NUM_BEAMS, + no_repeat_ngram_size=3, + early_stopping=True) + print("\n🤖 Model:", tok.decode(out[0], skip_special_tokens=True)) diff --git a/recognition/Project13-TristanGreen/dataset.py b/recognition/Project13-TristanGreen/dataset.py new file mode 100644 index 000000000..eb60cdf6a --- /dev/null +++ b/recognition/Project13-TristanGreen/dataset.py @@ -0,0 +1,130 @@ +""" +------------------------------------------------------------ + Dataset Loader and Preprocessing for Brain-T5 + ----------------------------------------------------------- + Description: + Handles dataset intake and preprocessing for FLAN-T5 fine-tuning. + Supports Hugging Face (BioLaySumm) datasets, CSV, or JSONL inputs. + + Key Components: + - make_datasets(): loads and tokenizes splits (train/val/test). + - Seq2SeqCollatorFast: dynamic padding & label masking for T5. + + Notes: + - Automatically prefixes "summarize: " to each input. + - Pads to model’s max token length. + - Masks tokens in labels with -100 for CrossEntropyLoss. +------------------------------------------------------------ +""" +from __future__ import annotations +from typing import Optional, List, Dict +import torch +from datasets import load_dataset +from transformers import AutoTokenizer +from torch.nn.utils.rnn import pad_sequence + +DATASET_ID = "BioLaySumm/BioLaySumm2025-LaymanRRG-opensource-track" +INPUT_COL = "radiology_report" +TARGET_COL = "layman_report" + +# Collator: batch pad inputs/labels and map pad tokens in labels to -100 (ignored by CE loss). +# pad_to_multiple_of lets you round sequence lengths (e.g., to 8/16/32) for Tensor Core efficiency. +class Seq2SeqCollatorFast: + def __init__(self, tokenizer, label_pad_token_id=-100, pad_to_multiple_of=None): + self.tok = tokenizer + self.label_pad_token_id = label_pad_token_id + self.pad_to_multiple_of = pad_to_multiple_of + + def _maybe_pad_to_multiple(self, tensor, pad_value): + if self.pad_to_multiple_of is None: + return tensor + L = tensor.size(1) + if L % self.pad_to_multiple_of == 0: + return tensor + add = self.pad_to_multiple_of - (L % self.pad_to_multiple_of) + return torch.nn.functional.pad(tensor, (0, add), value=pad_value) + + def __call__(self, feats: List[Dict[str, torch.Tensor]]): + # IMPORTANT: convert tokenizer pad tokens in labels to -100 so loss ignores padded positions. 
+ ids = [f["input_ids"] if isinstance(f["input_ids"], torch.Tensor) else torch.tensor(f["input_ids"]) for f in feats] + am = [f["attention_mask"] if isinstance(f["attention_mask"], torch.Tensor) else torch.tensor(f["attention_mask"]) for f in feats] + labs = [f["labels"] if isinstance(f["labels"], torch.Tensor) else torch.tensor(f["labels"]) for f in feats] + + pad_id = self.tok.pad_token_id + ids = pad_sequence(ids, batch_first=True, padding_value=pad_id) + am = pad_sequence(am, batch_first=True, padding_value=0) + labs = pad_sequence(labs, batch_first=True, padding_value=pad_id) + labs = labs.masked_fill(labs.eq(pad_id), self.label_pad_token_id) + + ids = self._maybe_pad_to_multiple(ids, pad_id) + am = self._maybe_pad_to_multiple(am, 0) + labs = self._maybe_pad_to_multiple(labs, self.label_pad_token_id) + return {"input_ids": ids, "attention_mask": am, "labels": labs} + +# Build tokenizer + HF datasets with optional self-split (80/10/10). +# Ensures input instruction prefix and truncation to max lengths. +def make_datasets( + tokenizer_name: str = "google/flan-t5-base", + train_split: str = "train", + val_split: Optional[str] = "validation", + test_split: Optional[str] = "test", + max_input_len: int = 1024, + max_target_len: int = 256, + prefix_text: str = "summarize: ", + *, + self_split: bool = False, + self_split_seed: int = 1337, + self_split_val: float = 0.1, + self_split_test: float = 0.1, +): + + tok = AutoTokenizer.from_pretrained(tokenizer_name, use_fast=True) + if tok.pad_token is None: + tok.pad_token = tok.eos_token + + ds = load_dataset(DATASET_ID) + + from datasets import DatasetDict + + # Vectorize one batch: add instruction prefix, tokenize src/tgt independently, attach 'labels'. + if self_split: + base = ds["train"].train_test_split(test_size=self_split_test, seed=self_split_seed) + train_part = base["train"] + test_part = base["test"] + vt = train_part.train_test_split( + test_size=self_split_val / (1.0 - self_split_test), seed=self_split_seed) + ds = DatasetDict({ + "train": vt["train"], + "validation": vt["test"], + "test": test_part, + }) + + # Validate required columns exist + for split in [s for s in [train_split, val_split, test_split] if s and s in ds]: + cols = ds[split].column_names + if INPUT_COL not in cols or TARGET_COL not in cols: + raise KeyError(f"Expected columns '{INPUT_COL}', '{TARGET_COL}' in split '{split}', found {cols}") + + # Vectorize one batch: add instruction prefix, tokenize src/tgt independently, attach 'labels'. 
+ def encode_batch(batch): + srcs = [prefix_text + s for s in batch[INPUT_COL]] + enc = tok(srcs, max_length=max_input_len, truncation=True) + tgt = tok(text_target=batch[TARGET_COL], max_length=max_target_len, truncation=True) + enc["labels"] = tgt["input_ids"] + return enc + + remove_cols = ds[train_split].column_names + train_proc = ds[train_split].map(encode_batch, batched=True, remove_columns=remove_cols, desc="Tokenizing train") + + val_proc = None + if val_split and val_split in ds: + remove_cols_val = ds[val_split].column_names + val_proc = ds[val_split].map(encode_batch, batched=True, remove_columns=remove_cols_val, desc="Tokenizing val") + + test_proc = None + if test_split and test_split in ds: + remove_cols_test = ds[test_split].column_names + test_proc = ds[test_split].map(encode_batch, batched=True, remove_columns=remove_cols_test, desc="Tokenizing test") + + collator = Seq2SeqCollatorFast(tok, label_pad_token_id=-100, pad_to_multiple_of=None) + return tok, train_proc, val_proc, test_proc, collator diff --git a/recognition/Project13-TristanGreen/modules.py b/recognition/Project13-TristanGreen/modules.py new file mode 100644 index 000000000..896768ae3 --- /dev/null +++ b/recognition/Project13-TristanGreen/modules.py @@ -0,0 +1,108 @@ +""" +------------------------------------------------------------ + Model Utilities for Brain-T5 + ----------------------------------------------------------- + Description: + Provides helper functions for loading base models and attaching + LoRA adapters to target layers of FLAN-T5. + + Key Functions: + - load_base_model(): loads pretrained T5/FLAN-T5 with dtype control. + - attach_lora(): injects trainable low-rank adapters for fine-tuning. + + Notes: + - Uses PEFT (Parameter-Efficient Fine-Tuning) via Hugging Face. + - Keeps original model frozen except LoRA-injected parameters. +------------------------------------------------------------ +""" +from __future__ import annotations +from typing import Optional, Dict, Any, List + +import torch +from transformers import ( + AutoModelForSeq2SeqLM, + AutoTokenizer, +) + +try: + from peft import LoraConfig, get_peft_model, PeftModel + PEFT_AVAILABLE = True +except Exception: + PEFT_AVAILABLE = False + +# Use fast tokenizer; default pad_token from eos_token if missing (required by T5 decoding). +def get_tokenizer(name: str = "google/flan-t5-base"): + tok = AutoTokenizer.from_pretrained(name, use_fast=True) + if tok.pad_token is None: + tok.pad_token = tok.eos_token + return tok + +# Load FLAN-T5 with dtype/device_map options. +# Ensure decoder_start_token_id is set so generation starts from a valid token. +def load_base_model( + name: str = "google/flan-t5-base", + dtype: Optional[torch.dtype] = torch.float16, + device_map: Optional[str] = None, +): + model = AutoModelForSeq2SeqLM.from_pretrained( + name, + dtype=dtype, + device_map=device_map, + ) + if getattr(model.config, "decoder_start_token_id", None) is None: + model.config.decoder_start_token_id = model.config.pad_token_id + return model + +# Inject LoRA on attention projections (q/k/v/o). Bias=none keeps adapter minimal. +# r/alpha/dropout control rank, scaling, and regularization of the adapters. +def attach_lora(model, r: int = 8, alpha: int = 16, dropout: float = 0.05, target_modules: Optional[List[str]] = None): + if not PEFT_AVAILABLE: + raise RuntimeError("peft not installed. 
`pip install peft` to use LoRA.") + if target_modules is None: + target_modules = ["q", "k", "v", "o"] + cfg = LoraConfig( + r=r, lora_alpha=alpha, lora_dropout=dropout, + target_modules=target_modules, bias="none", task_type="SEQ_2_SEQ_LM", + ) + return get_peft_model(model, cfg) + +# Convenience generation wrapper (batched): +# - Applies optional "summarize: " prefix. +# - Pads/truncates, moves to device, decodes without special tokens. +# - Beam search defaults tuned for readability over speed. +@torch.no_grad() +def generate( + model, + tokenizer, + inputs: List[str], + max_input_len: int = 1024, + max_new_tokens: int = 256, + num_beams: int = 4, + no_repeat_ngram_size: int = 3, + length_penalty: float = 1.0, + add_prefix: bool = True, + prefix_text: str = "summarize: ", + device: Optional[str] = None, +) -> List[str]: + model.eval() + if device is None: + device = "cuda" if torch.cuda.is_available() else "cpu" + + batch = [(prefix_text + x) if add_prefix else x for x in inputs] + enc = tokenizer( + batch, + max_length=max_input_len, + truncation=True, + padding=True, + return_tensors="pt", + ).to(device) + + out = model.generate( + **enc, + max_new_tokens=max_new_tokens, + num_beams=num_beams, + no_repeat_ngram_size=no_repeat_ngram_size, + length_penalty=length_penalty, + early_stopping=True, + ) + return tokenizer.batch_decode(out, skip_special_tokens=True) diff --git a/recognition/Project13-TristanGreen/predict.py b/recognition/Project13-TristanGreen/predict.py new file mode 100644 index 000000000..1e5efcb2b --- /dev/null +++ b/recognition/Project13-TristanGreen/predict.py @@ -0,0 +1,110 @@ +""" +------------------------------------------------------------ + Prediction and Inference for Brain-T5 + ----------------------------------------------------------- + Description: + Generates summaries from fine-tuned LoRA adapters. + Supports both single-text (--text) and batch (--jsonl) modes. + + Key Functions: + - load_model(): loads base + LoRA adapter for inference. + - generate_batch(): batched generation with beam search. + + Notes: + - Outputs JSONL with 'prediction' field appended to each input. + - Uses max_new_tokens and num_beams for generation control. +------------------------------------------------------------ +""" +import os, argparse, json +from typing import List +import torch +from transformers import AutoModelForSeq2SeqLM, AutoTokenizer +from peft import PeftModel + +# Load tokenizer from adapter_dir (ensures identical preproc as training), attach LoRA onto base. +# dtype=float16 only when CUDA+--fp16; always move model to device and set eval(). +def load_model(adapter_dir: str, base_model: str, fp16: bool): + tok = AutoTokenizer.from_pretrained(adapter_dir) + dtype = torch.float16 if (fp16 and torch.cuda.is_available()) else torch.float32 + base = AutoModelForSeq2SeqLM.from_pretrained(base_model, dtype=dtype) + model = PeftModel.from_pretrained(base, adapter_dir) + device = "cuda" if torch.cuda.is_available() else "cpu" + model.to(device).eval() + if getattr(model.config, "decoder_start_token_id", None) is None: + model.config.decoder_start_token_id = model.config.pad_token_id + return tok, model, device + +def chunk(lst: List[str], n: int): + for i in range(0, len(lst), n): + yield lst[i:i+n] + +# Generate a batch with beam search; always prefix with instruction to match training distribution. 
+def generate_batch(model, tok, device, texts: List[str], max_in: int, max_new: int, beams: int, prefix: str): + batch = [prefix + t for t in texts] + enc = tok(batch, return_tensors="pt", truncation=True, max_length=max_in, padding=True).to(device) + with torch.inference_mode(): + out = model.generate( + **enc, + max_new_tokens=max_new, + num_beams=beams, + no_repeat_ngram_size=3, + length_penalty=1.0, + early_stopping=True, + use_cache=True, + ) + return tok.batch_decode(out, skip_special_tokens=True) + +def main(): + ap = argparse.ArgumentParser() + ap.add_argument("--adapter_dir", required=True, help="Path to saved LoRA adapters") + ap.add_argument("--base_model", default="google/flan-t5-base") + ap.add_argument("--text", default=None, help="Single input string to summarize") + ap.add_argument("--jsonl", default=None, help="Path to JSONL with an input column") + ap.add_argument("--input_col", default="report") + ap.add_argument("--out_path", default="predictions.jsonl") + ap.add_argument("--batch_size", type=int, default=8) + ap.add_argument("--max_input_len", type=int, default=1024) + ap.add_argument("--max_new_tokens", type=int, default=128) + ap.add_argument("--beams", type=int, default=4) + ap.add_argument("--prefix", default="summarize: ") + ap.add_argument("--fp16", action="store_true") + args = ap.parse_args() + + # Modes: + # --text "..." -> print single summary to stdout + # --jsonl file.jsonl -> stream predictions and write to --out_path + # --input_col selects field in JSONL to summarize (default: 'report') + + + tok, model, device = load_model(args.adapter_dir, args.base_model, args.fp16) + + # single text mode + if args.text is not None: + outs = generate_batch(model, tok, device, [args.text], args.max_input_len, args.max_new_tokens, args.beams, args.prefix) + print(outs[0]) + return + + # file mode + if args.jsonl is None: + raise SystemExit("Provide --text or --jsonl") + rows = [] + with open(args.jsonl, "r", encoding="utf-8") as f: + for line in f: + if line.strip(): + rows.append(json.loads(line)) + + inputs = [r.get(args.input_col, "") for r in rows] + preds = [] + for block in chunk(inputs, args.batch_size): + preds.extend(generate_batch(model, tok, device, block, args.max_input_len, args.max_new_tokens, args.beams, args.prefix)) + + # write JSONL with predictions + with open(args.out_path, "w", encoding="utf-8") as w: + for r, p in zip(rows, preds): + out = dict(r) + out["prediction"] = p + w.write(json.dumps(out, ensure_ascii=False) + "\n") + print(f"wrote {len(preds)} predictions to {args.out_path}") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/recognition/Project13-TristanGreen/requirements.txt b/recognition/Project13-TristanGreen/requirements.txt new file mode 100644 index 000000000..586e41917 --- /dev/null +++ b/recognition/Project13-TristanGreen/requirements.txt @@ -0,0 +1,10 @@ +transformers==4.57.0 +peft==0.17.1 +evaluate==0.4.6 +datasets==4.1.1 +tqdm==4.67.1 +numpy==2.2.5 +matplotlib==3.10.5 +absl-py==2.3.1 +nltk==3.9.2 +rouge_score==0.1.2 \ No newline at end of file diff --git a/recognition/Project13-TristanGreen/train.py b/recognition/Project13-TristanGreen/train.py new file mode 100644 index 000000000..4072069c6 --- /dev/null +++ b/recognition/Project13-TristanGreen/train.py @@ -0,0 +1,504 @@ +""" +------------------------------------------------------------ + Brain-T5: FLAN-T5 + LoRA Fine-Tuning Pipeline + ----------------------------------------------------------- + Description: + Main training script for Brain-T5. 
Handles dataset loading, + LoRA adapter attachment, training loop, logging, and evaluation. + + Key Functions: + - run_eval(): computes ROUGE scores on validation/test splits. + - log_val_rouge_row(): logs per-epoch ROUGE metrics to CSV. + - plot_loss_curve(), plot_val_rouge_curve(): generate plots. + + Notes: + - Uses AdamW + cosine schedule. + - Gradient accumulation supported via --accum. + - Mixed precision enabled via torch.amp. + - Best model checkpoint chosen by highest ROUGE-Lsum. +------------------------------------------------------------ +""" +import os, json, math, argparse, random, time, uuid, csv +from typing import Optional +import numpy as np +import torch +from torch.optim import AdamW +from torch.utils.data import DataLoader +from transformers import get_cosine_schedule_with_warmup +from tqdm.auto import tqdm + +import matplotlib +matplotlib.use("Agg") # headless safe +import matplotlib.pyplot as plt + +from dataset import make_datasets # locked dataset helper +from modules import load_base_model, attach_lora +import evaluate + +# ----------------------- +# Utils: logging & eval +# ----------------------- + +def csv_logger(path: str): + """Append-mode CSV logger for training steps; writes header if file is empty.""" + f = open(path, "a", newline="", encoding="utf-8") + w = csv.writer(f) + if f.tell() == 0: + w.writerow(["timestamp", "epoch", "global_step", "loss"]) + return f, w + +def set_seed(seed: int = 1337): + random.seed(seed); np.random.seed(seed); torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed) + +# Eval loop: generate summaries for a dataloader and compute ROUGE. +# Notes: +# - We re-enable use_cache for fast generation. +# - Convert label pad (-100) back to tokenizer.pad_token_id before decoding refs. +# - no_repeat_ngram_size=3 reduces trivial repetition. +def run_eval(model, tokenizer, loader: Optional[DataLoader], device, args, rouge_metric): + if loader is None: + return None + model.eval() + use_cache_was = getattr(model.config, "use_cache", True) + model.config.use_cache = True + + preds, refs = [], [] + with torch.inference_mode(): + for vb in tqdm(loader, desc="Eval", unit="batch", dynamic_ncols=True): + vb = {k: v.to(device) for k, v in vb.items()} + # Beam search generation for evaluation (deterministic-ish) + gen_out = model.generate( + input_ids=vb["input_ids"], + attention_mask=vb["attention_mask"], + max_new_tokens=args.eval_max_new_tokens, + num_beams=args.eval_beams, + no_repeat_ngram_size=3, + length_penalty=1.0, + early_stopping=True, + ) + pred_txt = tokenizer.batch_decode(gen_out, skip_special_tokens=True) + tgt = vb["labels"].clone() + tgt[tgt == -100] = tokenizer.pad_token_id + ref_txt = tokenizer.batch_decode(tgt, skip_special_tokens=True) + preds.extend(pred_txt); refs.extend(ref_txt) + + scores = rouge_metric.compute(predictions=preds, references=refs, use_stemmer=True) + keep = {k: float(v) for k, v in scores.items() if k in {"rouge1","rouge2","rougeL","rougeLsum"}} + model.config.use_cache = use_cache_was + return keep + +# ----------------------- +# Per-epoch ROUGE logging +# ----------------------- + +RUN_ID = os.environ.get("RUN_ID", str(uuid.uuid4())[:8]) + +# Persist per-epoch validation ROUGE to CSV for plotting and auditing. +# If multiple runs append to same file, we keep last row per epoch when plotting. +def log_val_rouge_row(run_dir, epoch, scores): + """ + Append one row per epoch: + run_id, timestamp, epoch, rouge1, rouge2, rougeL, rougeLsum + Writes header if file is empty. 
+ """ + path = os.path.join(run_dir, "history_val.csv") + new_file = (not os.path.exists(path)) or (os.path.getsize(path) == 0) + with open(path, "a", newline="", encoding="utf-8") as f: + w = csv.writer(f) + if new_file: + w.writerow(["run_id","timestamp","epoch","rouge1","rouge2","rougeL","rougeLsum"]) + w.writerow([ + RUN_ID, + int(time.time()), + int(epoch), + float(scores.get("rouge1", 0.0)), + float(scores.get("rouge2", 0.0)), + float(scores.get("rougeL", 0.0)), + float(scores.get("rougeLsum", 0.0)), + ]) + +# ----------------------- +# Plotting helpers +# ----------------------- + +# Plot validation ROUGE vs epoch. +# Robust to restarts: we select the *latest* row per epoch (by timestamp) to avoid stale re-runs. +def plot_val_rouge_curve(run_dir): + """ + Plot Validation ROUGE vs Epoch (robust): + - reads history_val.csv + - keeps only the latest row per epoch (by timestamp if present) + - sorts epochs ascending + - falls back to metrics_val.json if no rows + """ + import json + path = os.path.join(run_dir, "history_val.csv") + rows = [] + if os.path.exists(path) and os.path.getsize(path) > 0: + with open(path, newline="", encoding="utf-8") as f: + rows = list(csv.DictReader(f)) + + def _fallback_from_metrics(): + mv = os.path.join(run_dir, "metrics_val.json") + if not os.path.exists(mv): + print("[plot] no history_val.csv rows and no metrics_val.json; skipping val ROUGE plot") + return None + obj = json.load(open(mv, "r", encoding="utf-8")) + ep = int(obj.get("epoch", 1)) + s = obj.get("scores", obj) + return [ep], [float(s.get("rouge1", 0.0))], [float(s.get("rouge2", 0.0))], \ + [float(s.get("rougeL", 0.0))], [float(s.get("rougeLsum", 0.0))] + + if not rows: + vals = _fallback_from_metrics() + if vals is None: return + epochs, r1, r2, rL, rS = vals + else: + # latest row per epoch by timestamp if present; else by order + last_by_epoch = {} + for r in rows: + if "epoch" not in r: # malformed row + continue + try: + ep = int(r["epoch"]) + except Exception: + continue + ts = int(r.get("timestamp", 0)) if r.get("timestamp") else 0 + if (ep not in last_by_epoch) or (ts >= int(last_by_epoch[ep].get("timestamp", 0) or 0)): + last_by_epoch[ep] = r + + if not last_by_epoch: + vals = _fallback_from_metrics() + if vals is None: return + epochs, r1, r2, rL, rS = vals + else: + epochs = sorted(last_by_epoch.keys()) + def _f(e, k): + v = last_by_epoch[e].get(k, None) + return float(v) if v not in (None, "",) else 0.0 + r1 = [_f(e, "rouge1") for e in epochs] + r2 = [_f(e, "rouge2") for e in epochs] + rL = [_f(e, "rougeL") for e in epochs] + rS = [_f(e, "rougeLsum") for e in epochs] + + plt.figure() + plt.plot(epochs, r1, label="ROUGE-1") + plt.plot(epochs, r2, label="ROUGE-2") + plt.plot(epochs, rL, label="ROUGE-L") + plt.plot(epochs, rS, label="ROUGE-Lsum") + plt.xlabel("Epoch"); plt.ylabel("ROUGE") + plt.title("Validation ROUGE vs Epoch") + plt.grid(True, alpha=0.3) + plt.legend() + plt.tight_layout() + out = os.path.join(run_dir, "rouge_val_curve.png") + plt.savefig(out, dpi=160); plt.close() + print(f"[plot] wrote {out}") + +def plot_loss_curve(run_dir): + """ + Plot a single clean Training Loss vs Steps line even if train_log.csv + contains multiple runs or step resets. 
+ Strategy: + - read train_log.csv + - split into segments whenever global_step decreases (new run appended) + - keep ONLY the last segment (latest run) + - sort by step, clip outlier spikes (1..99p), smooth with small moving average + """ + import numpy as np + + path = os.path.join(run_dir, "train_log.csv") + if not os.path.exists(path) or os.path.getsize(path) == 0: + print("[plot] no train_log.csv; skipping loss plot"); return + + rows = [] + with open(path, newline="", encoding="utf-8") as f: + rows = list(csv.DictReader(f)) + if not rows: + print("[plot] empty train_log.csv; skipping loss plot"); return + + # parse numeric + raw_steps, raw_losses = [], [] + for r in rows: + try: + raw_steps.append(int(float(r["global_step"]))) + raw_losses.append(float(r["loss"])) + except Exception: + continue + if not raw_steps: + print("[plot] no numeric rows in train_log.csv; skipping"); return + + # split into segments whenever step decreases (step reset = new run) + segs = [] + seg_s, seg_l = [raw_steps[0]], [raw_losses[0]] + for s, l in zip(raw_steps[1:], raw_losses[1:]): + if s < seg_s[-1]: # reset + segs.append((seg_s, seg_l)) + seg_s, seg_l = [s], [l] + else: + seg_s.append(s); seg_l.append(l) + segs.append((seg_s, seg_l)) + + # pick the LAST segment (most recent run) + steps, losses = segs[-1] + + # sort by step + order = np.argsort(steps) + steps = [steps[i] for i in order] + losses = [losses[i] for i in order] + + # clip extreme spikes for visualization (1..99 percentile) + lo, hi = np.percentile(losses, [1, 99]) + keep = [(lo <= v <= hi) for v in losses] + steps = [s for s, m in zip(steps, keep) if m] + losses = [v for v, m in zip(losses, keep) if m] + + # moving average smoothing + from collections import deque + def movavg(x, k=None): + if len(x) == 0: return x + if k is None: + k = max(5, min(25, len(x)//20)) # gentle default + out, q, s = [], deque(), 0.0 + for v in x: + q.append(v); s += v + if len(q) > k: s -= q.popleft() + out.append(s / len(q)) + return out + + sm = movavg(losses) + + plt.figure() + plt.plot(steps, sm) + plt.xlabel("Global step (optimizer)") + plt.ylabel("Loss") + plt.title("Training Loss vs Steps") + plt.grid(True, alpha=0.3) + plt.tight_layout() + out = os.path.join(run_dir, "loss_curve.png") + plt.savefig(out, dpi=160); plt.close() + print(f"[plot] wrote {out}") + +def plot_test_rouge_bar(run_dir): + path = os.path.join(run_dir, "metrics_test.json") + if not os.path.exists(path): + print("[plot] no metrics_test.json; skipping test bar"); return + m = json.load(open(path, "r", encoding="utf-8")) + if isinstance(m, dict) and "note" in m: + print(f"[plot] {m['note']} — skipping test bar"); return + labels = ["ROUGE-1","ROUGE-2","ROUGE-L","ROUGE-Lsum"] + vals = [float(m.get("rouge1",0.0)), float(m.get("rouge2",0.0)), + float(m.get("rougeL",0.0)), float(m.get("rougeLsum",0.0))] + plt.figure() + plt.bar(labels, vals) + plt.ylabel("Score"); plt.title("Test ROUGE (Held-out)") + plt.tight_layout() + out = os.path.join(run_dir, "rouge_test_bar.png") + plt.savefig(out, dpi=160); plt.close() + print(f"[plot] wrote {out}") + +# ----------------------- +# Main +# ----------------------- + +def main(): + p = argparse.ArgumentParser() + # No dataset args — we're locked to the RRG opensource track via make_datasets + p.add_argument("--model_name", default="google/flan-t5-base") + p.add_argument("--train_split", default="train") + p.add_argument("--val_split", default="validation") + p.add_argument("--test_split", default="test") + + # training + p.add_argument("--epochs", 
type=int, default=5) + p.add_argument("--lr", type=float, default=2e-4) + p.add_argument("--weight_decay", type=float, default=0.01) + p.add_argument("--batch_size", type=int, default=1) + p.add_argument("--accum", type=int, default=16) + p.add_argument("--warmup_steps", type=int, default=1000) + p.add_argument("--clip", type=float, default=1.0) + p.add_argument("--lora_r", type=int, default=8) + p.add_argument("--lora_alpha", type=int, default=16) + p.add_argument("--lora_dropout", type=float, default=0.05) + p.add_argument("--output_dir", default="runs/flan_t5_base_lora_rrg") + p.add_argument("--seed", type=int, default=1337) + p.add_argument("--fp16", action="store_true") + + # eval/generation + p.add_argument("--eval_max_new_tokens", type=int, default=128) + p.add_argument("--eval_beams", type=int, default=4) + p.add_argument("--eval_batch_size", type=int, default=8) + + # dev-speed controls + p.add_argument("--max_train_samples", type=int, default=None) + p.add_argument("--max_eval_samples", type=int, default=None) + p.add_argument("--max_test_samples", type=int, default=None) + + # optional self-split (because official opensource test has empty refs) + p.add_argument("--self_split", action="store_true", + help="If set, create custom 80/10/10 train/val/test from training data.") + p.add_argument("--self_split_val", type=float, default=0.1, + help="Proportion of data to use as validation if self_split.") + p.add_argument("--self_split_test", type=float, default=0.1, + help="Proportion of data to use as test if self_split.") + + args = p.parse_args() + + os.makedirs(args.output_dir, exist_ok=True) + set_seed(args.seed) + + # Fresh logs for this run (truncate old content so headers/steps align with current run). + open(os.path.join(args.output_dir, "train_log.csv"), "w").close() + open(os.path.join(args.output_dir, "history_val.csv"), "w").close() + + # tokenizer + datasets (locked to RRG via make_datasets) + tokenizer, train_ds, val_ds, test_ds, collator = make_datasets( + tokenizer_name=args.model_name, + train_split=args.train_split, + val_split=args.val_split, + test_split=args.test_split, + max_input_len=1024, + max_target_len=256, + prefix_text="summarize", + self_split=args.self_split, + self_split_val=args.self_split_val, + self_split_test=args.self_split_test, + ) + + # Optional subsetting + if args.max_train_samples is not None: + train_ds = train_ds.select(range(min(args.max_train_samples, len(train_ds)))) + if val_ds is not None and args.max_eval_samples is not None: + val_ds = val_ds.select(range(min(args.max_eval_samples, len(val_ds)))) + if test_ds is not None and args.max_test_samples is not None: + test_ds = test_ds.select(range(min(args.max_test_samples, len(test_ds)))) + + # DataLoaders + train_loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True, + collate_fn=collator, pin_memory=True, num_workers=0) + val_loader = None + if val_ds is not None: + val_loader = DataLoader(val_ds, batch_size=args.eval_batch_size, shuffle=False, + collate_fn=collator, pin_memory=True, num_workers=0) + + # model + LoRA + dtype = torch.float16 if (args.fp16 and torch.cuda.is_available()) else torch.float32 + model = load_base_model(args.model_name, dtype=dtype, device_map=None) + model = attach_lora(model, r=args.lora_r, alpha=args.lora_alpha, + dropout=args.lora_dropout, target_modules=["q","k","v","o"]) + device = "cuda" if torch.cuda.is_available() else "cpu" + model.to(device) + + # Log params & hardware (artifact for report) + total = sum(p.numel() for p in 
model.parameters()) + trainable = sum(p.numel() for p in model.parameters() if p.requires_grad) + with open(os.path.join(args.output_dir, "params.json"), "w") as f: + json.dump({"total": total, "trainable": trainable, "ratio": trainable/total}, f, indent=2) + if torch.cuda.is_available(): + props = torch.cuda.get_device_properties(0) + with open(os.path.join(args.output_dir, "hardware.json"), "w") as f: + json.dump({"gpu_name": torch.cuda.get_device_name(0), + "total_vram_gb": round(props.total_memory/(1024**3),2), + "compute_capability": f"{props.major}.{props.minor}"}, + f, indent=2) + + # Logging + log_file, log_writer = csv_logger(os.path.join(args.output_dir, "train_log.csv")) + global_step = 0 + + # Optimizer & schedule: + # - AdamW with weight decay. + # - Cosine schedule with ~6% warmup (capped by --warmup_steps). + # - GradScaler enabled only when --fp16 on CUDA. + optim = AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) + total_steps = math.ceil(len(train_loader) / args.accum) * args.epochs + warmup = min(args.warmup_steps, int(0.06 * total_steps)) + sched = get_cosine_schedule_with_warmup(optim, warmup, total_steps) + scaler = torch.amp.GradScaler("cuda", enabled=(args.fp16 and device == "cuda")) + rouge_metric = evaluate.load("rouge") + + t0 = time.time() + best_rougeLsum = -1.0 + for epoch in range(1, args.epochs + 1): + model.train() + running = 0.0 + optim.zero_grad(set_to_none=True) + pbar = tqdm(train_loader, desc=f"Train e{epoch}", unit="batch", dynamic_ncols=True) + start_time = time.time() + step_in_epoch = 0 + + for batch in pbar: + step_in_epoch += 1; global_step += 1 + batch = {k: v.to(device) for k, v in batch.items()} + with torch.amp.autocast("cuda", enabled=(args.fp16 and device == "cuda")): + out = model(**batch) + loss = out.loss / args.accum + if not torch.isfinite(loss): + print(f"[WARN] non-finite loss at step {global_step}: {float(loss)} — skipping.") + optim.zero_grad(set_to_none=True); continue + + scaler.scale(loss).backward(); running += loss.item() + if (global_step % args.accum) == 0: + scaler.unscale_(optim) + torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip) + scaler.step(optim); scaler.update(); sched.step() + optim.zero_grad(set_to_none=True) + + # CSV log (per optimizer step) + avg_loss = running / max(1, (step_in_epoch // args.accum)) + log_writer.writerow([time.time(), epoch, global_step, avg_loss]); log_file.flush() + + # live status + avg_loss = running / max(1, (step_in_epoch // args.accum)) + elapsed = time.time() - start_time + sps = (step_in_epoch * args.batch_size) / max(1e-6, elapsed) + # Live progress: smoothed loss and samples/sec for quick sanity checks. 
+ pbar.set_postfix({"loss": f"{avg_loss:.4f}", "sps": f"{sps:.1f}"}) + + print(f"[epoch {epoch}] train_loss={avg_loss:.4f}") + + # validation at epoch end + scores = run_eval(model, tokenizer, val_loader, device, args, rouge_metric) + if scores is not None: + print(f"[epoch {epoch}] ROUGE (val): {scores}") + + # persist per-epoch history for plotting + log_val_rouge_row(args.output_dir, epoch, scores) + + rougeLsum = float(scores.get("rougeLsum", 0.0)) + if rougeLsum > best_rougeLsum: + best_rougeLsum = rougeLsum + model.save_pretrained(args.output_dir) + tokenizer.save_pretrained(args.output_dir) + with open(os.path.join(args.output_dir, "metrics_val.json"), "w") as f: + json.dump({"best_rougeLsum": best_rougeLsum, "epoch": epoch, "scores": scores}, f, indent=2) + print(f"[epoch {epoch}] saved best adapters to {args.output_dir}") + + # live plot after each epoch + plot_val_rouge_curve(args.output_dir) + + # timing + minutes = (time.time() - t0) / 60.0 + with open(os.path.join(args.output_dir, "time.json"), "w") as f: + json.dump({"minutes": minutes, "epochs": args.epochs}, f, indent=2) + + # test evaluation (held-out; for opensource we self-split or skip if refs missing) + if test_ds is not None: + test_loader = DataLoader(test_ds, batch_size=args.eval_batch_size, shuffle=False, + collate_fn=collator, pin_memory=True, num_workers=0) + test_scores = run_eval(model, tokenizer, test_loader, device, args, rouge_metric) + with open(os.path.join(args.output_dir, "metrics_test.json"), "w") as f: + json.dump(test_scores, f, indent=2) + print(f"[test] ROUGE: {test_scores}") + + # Always generate plots at the end for the report + plot_loss_curve(args.output_dir) + plot_val_rouge_curve(args.output_dir) + plot_test_rouge_bar(args.output_dir) + + print("done.") + +if __name__ == "__main__": + main()