Changes from all commits (31 commits)
28d8a15
Initial commit.
Nov 3, 2025
bb7d523
Retrieved dataset.
Nov 5, 2025
3c6ddaa
Implemented initial dataset handler.
Nov 5, 2025
5dc857c
Implemented flan-t5-large using a basic health check prompt.
Nov 5, 2025
a542579
Implemented layperson summary generation on the 1st element in the dataset.
Nov 5, 2025
1239674
Added assignment structure, changed prompt and model to t5-large. Cur…
Nov 7, 2025
210a10b
Made a fully functioning train.py. Currently lacks complete rouge eva…
Nov 8, 2025
df166ca
Fixed repo structure. Also fixed issue in github where two jac0be acc…
jac0be Nov 8, 2025
0d16dec
Added modules.py, which currently just builds and returns the base t5…
jac0be Nov 9, 2025
d1df701
Changed train.py to use the model generation in modules.py
jac0be Nov 9, 2025
cd32402
Fixed a bug where moving the model to cuda inside the helper caused d…
jac0be Nov 9, 2025
7a56000
Added held-out test split and saved ROUGE metrics to runs dir.
jac0be Nov 9, 2025
75467dd
Removed unused imports.
jac0be Nov 9, 2025
06acccb
Added logging of: parameter counts, GPU info, training time, as well …
jac0be Nov 9, 2025
b02f715
Added json logging for loss / val metrics. Also CSV/plot outputs.
jac0be Nov 9, 2025
88b2524
Refactored csv + plot generation into a separate helper function.
jac0be Nov 9, 2025
fafd8c2
Added a simple predict.py that uses the best saved checkpoint to summ…
jac0be Nov 10, 2025
c3f485f
Added an interactive chat interface for predict.py
jac0be Nov 10, 2025
8b51fe3
Made eval.py and moved the evaluation to post-training. Also made mai…
jac0be Nov 12, 2025
655029e
Optional held-out test split, used for hyperparameter tuning.
jac0be Nov 13, 2025
1dfcc8b
Added report indexing to predict.py, which allows specifying an index…
jac0be Nov 13, 2025
5f7943c
Added more explanatory comments for train.py
jac0be Nov 13, 2025
6b6a992
Restructured repository in preparation of pull request.
jac0be Nov 13, 2025
9ef1a54
Added a starting report / README.md
jac0be Nov 13, 2025
86014f9
Added a concise background knowledge section to README.
jac0be Nov 13, 2025
bd6cdeb
Finished the README report.
jac0be Nov 13, 2025
0a1e5c4
Added requirements.txt. Tested and working on WSL.
jac0be Nov 13, 2025
5eda2b4
Final code cleanup ahead of PR.
jac0be Nov 13, 2025
1296b1a
Corrected image paths in README following folder renaming.
jac0be Nov 13, 2025
f9b6c40
Updated requirements.txt to include rouge score dependencies.
jac0be Nov 13, 2025
8da83c4
Minor changes to README.md to tighten reasoning.
jac0be Nov 14, 2025
10 changes: 10 additions & 0 deletions recognition/Flan_T5_s45893623/.gitignore
@@ -0,0 +1,10 @@
# Byte-compiled / cached files
__pycache__/
*.py[cod]
*.pyo
*.pyd
*$py.class
# Generated folders, checkpoints and logs
runs/
eval/
archive/
288 changes: 288 additions & 0 deletions recognition/Flan_T5_s45893623/README.md

Large diffs are not rendered by default.

35 changes: 35 additions & 0 deletions recognition/Flan_T5_s45893623/dataset.py
@@ -0,0 +1,35 @@
# Simple dataset handler
from datasets import load_dataset
from torch.utils.data import Dataset

class BioSummDataset(Dataset):
def __init__(self, split="train", do_train_split=False):
ds = load_dataset("BioLaySumm/BioLaySumm2025-LaymanRRG-opensource-track")
        # Optionally carve a held-out test set out of the training data.
        # NOTE: construct both the train and test datasets with do_train_split=True,
        # otherwise the other side falls back to the default split.
        if do_train_split and split in ("train", "test"):
            full_train = ds["train"]
            # Keep the seed fixed at 42 so the split stays consistent across runs.
            split_ds = full_train.train_test_split(test_size=0.1, seed=42)

self.ds = split_ds["train"] if split == "train" else split_ds["test"]
        # Otherwise use the dataset's default train/validation/test splits (NOTE: the default test split does not contain layman summaries).
else:
self.ds = ds[split]

def __len__(self):
return len(self.ds)

def __getitem__(self, i):
        # We only need the text fields, not the image or source.
x = self.ds[i]["radiology_report"]
y = self.ds[i]["layman_report"]
return x, y

# Test main: prints the first 10 reports from each split.
if __name__ == "__main__":
train_ds = BioSummDataset(split="train")
val_ds = BioSummDataset(split="validation")
test_ds = BioSummDataset(split="test")
    for i in range(10):
print(f"Train [{i}]: {train_ds[i][0]}")
print(f"Val [{i}]: {val_ds[i][0]}")
print(f"Test [{i}]: {test_ds[i][0]}")
142 changes: 142 additions & 0 deletions recognition/Flan_T5_s45893623/eval.py
@@ -0,0 +1,142 @@
# eval.py
# Rebuilds plots and CSVs from the data logged during a training run.
import os, sys, shutil

def save_curves_and_plots_from_run(run_dir: str):
"""
Rebuild CSVs and plots using the jsonl logs in a run-like directory (e.g., eval/<run_name>).
Handles repeated 'step' values by constructing a monotonic 'gstep' and an inferred 'epoch'.
"""
    import csv, json
    import matplotlib.pyplot as plt

tl_jsonl = os.path.join(run_dir, "train_loss.jsonl")
vr_jsonl = os.path.join(run_dir, "val_rouge.jsonl")

# Load train loss
raw_loss = []
if os.path.isfile(tl_jsonl):
with open(tl_jsonl, "r", encoding="utf-8") as f:
for line in f:
line = line.strip()
if not line:
continue
try:
obj = json.loads(line)
raw_loss.append({"step": int(obj.get("step", 0)),
"loss": float(obj.get("loss", 0.0))})
except Exception:
pass

# Rebuild epoch + global step
loss_hist = []
if raw_loss:
epoch = 1
prev_step = -1
carry = 0
last_epoch_max = 0

for r in raw_loss:
s = r["step"]
# detect wrap (new epoch) when step doesn't increase
if s <= prev_step:
carry += max(last_epoch_max, prev_step)
last_epoch_max = 0
epoch += 1
last_epoch_max = max(last_epoch_max, s)
gstep = carry + s
loss_hist.append({"epoch": epoch, "step": s, "gstep": gstep, "loss": r["loss"]})
prev_step = s

# Load validation rouge
val_hist = []
if os.path.isfile(vr_jsonl):
with open(vr_jsonl, "r", encoding="utf-8") as f:
for line in f:
line = line.strip()
if not line:
continue
try:
obj = json.loads(line)
row = {"epoch": int(obj.get("epoch", 0))}
for k in ("rouge1", "rouge2", "rougeL", "rougeLsum"):
if k in obj:
row[k] = float(obj[k])
val_hist.append(row)
except Exception:
pass

# Write CSVs
if loss_hist:
loss_csv = os.path.join(run_dir, "train_loss.csv")
with open(loss_csv, "w", newline="", encoding="utf-8") as f:
w = csv.DictWriter(f, fieldnames=["epoch", "step", "gstep", "loss"])
w.writeheader()
            # Rows are already chronological; sort by gstep to be safe.
w.writerows(sorted(loss_hist, key=lambda d: d["gstep"]))

if val_hist:
fields = sorted({k for d in val_hist for k in d.keys()})
val_csv = os.path.join(run_dir, "val_rouge.csv")
with open(val_csv, "w", newline="", encoding="utf-8") as f:
w = csv.DictWriter(f, fieldnames=fields)
w.writeheader()
w.writerows(sorted(val_hist, key=lambda d: d.get("epoch", 0)))

# Create Plots
if loss_hist:
xs = [d["gstep"] for d in loss_hist]
ys = [d["loss"] for d in loss_hist]
plt.figure()
plt.plot(xs, ys)
plt.xlabel("global step"); plt.ylabel("loss"); plt.title("train loss")
plt.tight_layout()
plt.savefig(os.path.join(run_dir, "train_loss.png"))
plt.close()

if val_hist:
# single plot, multiple lines (rouge1, rouge2, rougeL, rougeLsum) vs epoch
epochs = [d.get("epoch", 0) for d in val_hist]
metrics = ["rouge1", "rouge2", "rougeL", "rougeLsum"]

plt.figure()
for metric_key in metrics:
ys = [d.get(metric_key, 0.0) for d in val_hist]
plt.plot(epochs, ys, marker="o", label=metric_key)
plt.xlabel("epoch")
plt.ylabel("ROUGE score")
plt.title("validation ROUGE over epochs")
plt.legend()
plt.tight_layout()
plt.savefig(os.path.join(run_dir, "val_rouge.png"))
plt.close()

def main():
if len(sys.argv) != 2:
print("Usage: python eval.py <run_directory>")
sys.exit(1)

run_dir = sys.argv[1]
if not os.path.isdir(run_dir):
print(f"Error: '{run_dir}' is not a valid directory")
sys.exit(1)

run_name = os.path.basename(os.path.normpath(run_dir))
eval_dir = os.path.join("eval", run_name)

# create eval/run_name folder, copy jsonls so function can read them
os.makedirs(eval_dir, exist_ok=True)
for fname in ("train_loss.jsonl", "val_rouge.jsonl"):
src = os.path.join(run_dir, fname)
dst = os.path.join(eval_dir, fname)
if os.path.isfile(src):
shutil.copy2(src, dst)

print(f"Rebuilding plots into {eval_dir} ...")
save_curves_and_plots_from_run(eval_dir)
print(f"Done. Outputs saved under {eval_dir}")

# Usage: python eval.py runs/flan_t5_lora
if __name__ == "__main__":
main()
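For reference, the jsonl lines this script parses look roughly as follows (field values here are illustrative; extra fields written by train.py are ignored). Because "step" resets each epoch, the rebuild unwraps repeated steps into a monotonic "gstep": e.g. logged steps 100, 200, 100, 200 become gstep 100, 200, 300, 400 with inferred epochs 1, 1, 2, 2.

# train_loss.jsonl, one JSON object per logged step:
{"step": 100, "loss": 2.431}
# val_rouge.jsonl, one JSON object per epoch:
{"epoch": 1, "rouge1": 0.41, "rouge2": 0.18, "rougeL": 0.33, "rougeLsum": 0.34}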
22 changes: 22 additions & 0 deletions recognition/Flan_T5_s45893623/modules.py
@@ -0,0 +1,22 @@
# modules.py
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from peft import LoraConfig, get_peft_model, TaskType

# Returns a fast tokenizer for seq2seq models.
def load_tokenizer(model_name: str):
return AutoTokenizer.from_pretrained(model_name, use_fast=True)

# Builds and returns the base Flan-T5 model with LoRA adapters attached.
def build_flan_t5_with_lora(model_name="google/flan-t5-base", r=8, alpha=16, dropout=0.05):
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
cfg = LoraConfig(
task_type=TaskType.SEQ_2_SEQ_LM,
r=r,
lora_alpha=alpha,
lora_dropout=dropout,
# NOTE: these match t5's projection layers
target_modules=["q", "k", "v", "o"],
bias="none",
)

    return get_peft_model(model, cfg)  # The model is moved to CUDA outside this function, to avoid device-mismatch issues.
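A minimal sketch of the intended call pattern, assuming the caller moves the model to the device only after building it (print_trainable_parameters is the standard PEFT helper):

# Sketch: build the LoRA-wrapped model, then move it to the device afterwards.
import torch
from modules import build_flan_t5_with_lora, load_tokenizer

tok = load_tokenizer("google/flan-t5-base")
model = build_flan_t5_with_lora("google/flan-t5-base", r=8, alpha=16, dropout=0.05)
model.print_trainable_parameters()  # PEFT helper: reports trainable vs. total parameter counts
model = model.to("cuda" if torch.cuda.is_available() else "cpu")  # move AFTER building, per the note above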
65 changes: 65 additions & 0 deletions recognition/Flan_T5_s45893623/predict.py
@@ -0,0 +1,65 @@
# predict.py
import argparse
import torch
import sys
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from datasets import load_dataset

DEFAULT_PROMPT = (
"You are a helpful medical assistant. Rewrite the radiology report for a layperson "
"in 1–3 sentences, avoid jargon, use plain language.\n\n"
"Report:\n{rad_report}\n\nLayperson summary:"
)

# Loads the model checkpoint and generates a layperson summary for the given report.
@torch.no_grad()
def predict(report_text, ckpt_dir="runs/flan_t5_lora", prompt=None, beams=4, max_new=128):
dev = "cuda" if torch.cuda.is_available() else "cpu"
tok = AutoTokenizer.from_pretrained(ckpt_dir, use_fast=True)
model = AutoModelForSeq2SeqLM.from_pretrained(ckpt_dir).to(dev).eval()

p = (prompt or DEFAULT_PROMPT).format(rad_report=report_text)
enc = tok([p], return_tensors="pt", truncation=True, max_length=1024).to(dev)
out = model.generate(
**enc,
max_new_tokens=max_new,
num_beams=beams,
early_stopping=True
)
return tok.batch_decode(out, skip_special_tokens=True)[0]

# Interactive entry point. Uses the same default run dir as train.py.
# If --idx is set, it summarises validation report [idx] and exits.
def main():
p = argparse.ArgumentParser()
p.add_argument("--ckpt", type=str, default="runs/flan_t5_lora",
help="Checkpoint directory")
p.add_argument("--idx", type=int, default=None,
help="Index of val-set report to evaluate")
args = p.parse_args()

    # If --idx is provided: run prediction on that validation report
if args.idx is not None:
ds = load_dataset(
"BioLaySumm/BioLaySumm2025-LaymanRRG-opensource-track"
)["validation"]

report = ds[args.idx]["radiology_report"]
gold = ds[args.idx]["layman_report"]
print(f"\n--- Test Sample {args.idx} ---")
print("Radiology Report:\n", report, "\n")
print("Gold Summary:\n", gold, "\n")
pred = predict(report, ckpt_dir=args.ckpt)
print("Model Prediction:\n", pred, "\n")
return

# Otherwise: interactive chat
print(f"Chat with your FLAN-T5 model ({args.ckpt}). Please enter only the report you want summarised. Type 'exit' to quit.\n")
while True:
msg = input("You: ").strip()
if msg.lower() in {"exit", "quit"}:
break
reply = predict(msg, ckpt_dir=args.ckpt)
print("Model:", reply, "\n")

if __name__ == "__main__":
main()
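Usage follows from the argparse flags above; for example:

python predict.py --idx 3                     (summarise validation report 3 with the default checkpoint)
python predict.py --ckpt runs/flan_t5_lora    (interactive chat against a specific checkpoint)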
12 changes: 12 additions & 0 deletions recognition/Flan_T5_s45893623/requirements.txt
@@ -0,0 +1,12 @@
torch==2.9.0
transformers==4.57.1
datasets==4.4.1
evaluate==0.4.6
peft==0.17.1
sentencepiece==0.2.1
numpy==2.3.3
matplotlib==3.10.7
tqdm==4.67.1
absl-py==2.3.1
nltk==3.9.2
rouge_score==0.1.2
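To reproduce the environment (tested on WSL per the commit history), the standard install applies:

pip install -r requirements.txt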