diff --git a/.gitignore b/.gitignore new file mode 100644 index 000000000..bd6cb6a05 --- /dev/null +++ b/.gitignore @@ -0,0 +1,8 @@ +.DS_Store + +__pycache__/ +*.pyc + +*.ipynb_checkpoints/ + +*.log diff --git a/recognition/UNet_task3_48339261/README.md b/recognition/UNet_task3_48339261/README.md new file mode 100644 index 000000000..fdbcc776a --- /dev/null +++ b/recognition/UNet_task3_48339261/README.md @@ -0,0 +1,134 @@ +# COMP3710 Report + +# Task 3 - 2D Prostate Segmentation with U-Net + +## Guanhua Ma 48339261 + + + +## 1. Project Description & Problem + +This project aims to solve the Task 3 (Normal Difficulty) pattern recognition problem. + +This project uses the processed 2D slices from the HipMRI Study on Prostate Cancer dataset to perform automatic semantic segmentation of the prostate region. The goal is to train a 2D Improved UNet to achieve a minimum Dice Similarity Coefficient (DSC) of 0.75 on the prostate label in the test set. + +## 2. Project File + +These are Project Files: + +- `modules.py`: Contains the definitions for the `SimpleUNet` model architecture and the `DiceLoss` function. +- `dataset.py`: Contains the `HipMRIDataset` class, responsible for loading and preprocessing the Nifti data. +- `train.py`: Contains the main training loop `train()`, the train/validation split logic. +- `predict.py`: Contains the `show_predictions()` function to load the trained model and visualize its performance on test samples. +- `utils.py`: Contains helper functions such as `calculate_dice_score()`, `show_epoch_predictions()`, and `plot_loss()`. +- `README.md`: The report document for this project. + +## 3. How it Works & Algorithm + +#### Algorithm Model + +This section is based on the SimpleUNet model defined in modules.py. This is a classic 2D U-Net architecture , which is an encoder-decoder network. Its key feature is the use of skip connections. This concatenates feature maps from the encoder (down-sampling path) with the decoder (up-sampling path). UNet allows the model to use both semantic features and spatial features, making it ideal for medical image segmentation. + +The model's main components include: + +- Encoder: A series of conv_block (Conv -> BatchNorm -> LeakyReLU -> Dropout) and MaxPool2d layers to extract features and reduce spatial dimensions. +- Decoder: Uses Upsample (Bilinear Interpolation) and _conv_block layers to reconstruct the segmentation mask. +- Output Layer: A final Conv2d layer followed by a Sigmoid activation function to output a probability map in the range [0, 1]. + +The Loss Function is Dice Loss (1 - Dice Coefficient), which directly optimizes DSC. It is well-suited for imbalanced classes. Therefore, it is very suitable for this task since the prostate region is much smaller than the background. + +The optimizer is Adam. Adam is an efficient and commonly used gradient descent optimizer. + +#### Data Preprocessing + +1. Data Loading: Uses the nibabel library to load .nii.gz Nifti format images and masks. +2. Label Binarization: The HipMRI masks are multi-class. To solve Task 3, pixels with the `prostate_label_value` (set to 5 in the code) are mapped to 1, and all other pixels are mapped to 0. This creates a binary prostate vs. non-prostate mask. +3. Resizing: All images and masks are resized to a fixed (128, 128) size. + - Images are resized using Bilinear interpolation. + - Masks are resized using Nearest Neighbor interpolation to ensure the label values (0 and 1) are not corrupted. +4. Normalization: Z-score normalization ((image - mean) / std) is applied to the images to set their mean to 0 and standard deviation to 1. + +This dataset keras_slices_data folder has already split in three subset: keras_slices_train, keras_slices_validate and keras_slices_test. +For this structure, the parameter `subset` is used to construct training set and validation set. + +`subset="train"`: loading the keras_slices_train folder. + +`subset="validate"`: loading the keras_slices_validate folder + +## 4. Reproducibility + +This project was run on the Google Colab with T4 GPU orA100 GPU + +The main dependencies are: + +``` +torch (PyTorch) +numpy +matplotlib +nibabel (!pip install nibabel) +tqdm +``` + +Due to the time queue in Rangpur is too long, so this project choose to run in the Goole Colab + +The files `modules.py`, `dataset.py`, `train.py`, `predict.py`, `utils.py` and `keras_slices_data` and `Run_on_Colab.ipynb` need to be in the same folder in Google Drive. + +Open the `Run_on_Colab.ipynb` and run for the whole task. + +Make sure that the path is /content/drive/MyDrive/Colab-Notebooks/UNet_task3_48339261/modules.py, for example for `modules.py`. + +This is the processing code in `Run_on_Colab.ipynb`. + +``` +# Run Task 3 in Google Colab + +from google.colab import drive +drive.mount('/content/drive', force_remount=True) + +!pip install nibabel -q + +import os +base_dir = "/content/drive/MyDrive/Colab-Notebooks/UNet_task3_48339261" +os.chdir(base_dir) + +# load dataset and train model. save hipmri_unet_model.pth and .png +print("Starting process train.py") +!python train.py + +# save final_predictions.png +print("Starting process predict.py") +!python predict.py + +print("Successful!") +print("Check the documents:") +print("hipmri_unet_model.pth") +print("training_loss_curve.png") +print("epoch_X_predictions.png") +print("final_predictions.png") +``` + +### 5. Results & Analysis + +#### Training Loss Curve + +The training loss (Dice Loss) over time is shown below in Figure 1. The loss steadily decreases from an initial average of 0.6089 and successfully converges to a final average loss of 0.1496. This indicates that the model learned effectively from the training data. + +![training_loss_curve](https://raw.githubusercontent.com/GuanhuaMa/PatternAnalysis-2025/topic-recognition/recognition/UNet_task3_48339261/training_loss_curve.png) + +[Figure 1: Training loss curve] + +#### Prediction Visualization + +The figure below shows the final segmentation performance of the model on 3 random samples from the test set after training for 20 epochs. + +![final_predictions](https://raw.githubusercontent.com/GuanhuaMa/PatternAnalysis-2025/topic-recognition/recognition/UNet_task3_48339261/final_predictions.png) + +[Figure 2: Final predictions] + +After training for 20 epochs, the model achieved a final average Dice Loss of 0.1496 on the training set, which corresponds to an average DSC of **0.8504**. + +This result exceeds the target of 0.75. + +The model's performance on the random test samples in Figure 2: Final predictions is excellent。It correctly identified two **True Negatives** (Original 110 and 39). These two samples with no prostate was present in the ground truth, and predicted an empty mask. Sample 110 and 39 result in perfect Dice scores of 1.000. The model only failed on Sample 315, which was a very small and challenging target, resulting in a Dice score of 0.000. + +Overall, the high average DSC (0.8504) and the strong performance on True Negatives confirm that the model successfully learned to segment the prostate gland. diff --git a/recognition/UNet_task3_48339261/Run_on_Colab.ipynb b/recognition/UNet_task3_48339261/Run_on_Colab.ipynb new file mode 100644 index 000000000..b0d157f5c --- /dev/null +++ b/recognition/UNet_task3_48339261/Run_on_Colab.ipynb @@ -0,0 +1,49 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "86c0c3bc", + "metadata": { + "vscode": { + "languageId": "plaintext" + } + }, + "outputs": [], + "source": [ + "# Run Task 3 in Google Colab\n", + "\n", + "from google.colab import drive\n", + "drive.mount('/content/drive', force_remount=True) \n", + "\n", + "!pip install nibabel -q \n", + "\n", + "import os\n", + "base_dir = \"/content/drive/MyDrive/Colab-Notebooks/UNet_task3_48339261\"\n", + "os.chdir(base_dir)\n", + "\n", + "# load dataset and train model. save hipmri_unet_model.pth and .png \n", + "print(\"Starting process train.py\")\n", + "!python train.py\n", + "\n", + "# save final_predictions.png\n", + "print(\"Starting process predict.py\")\n", + "!python predict.py\n", + "\n", + "print(\"Successful!\")\n", + "print(\"Check the documents:\")\n", + "print(\"hipmri_unet_model.pth\")\n", + "print(\"training_loss_curve.png\")\n", + "print(\"epoch_X_predictions.png\")\n", + "print(\"final_predictions.png\")" + ] + } + ], + "metadata": { + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/recognition/UNet_task3_48339261/dataset.py b/recognition/UNet_task3_48339261/dataset.py new file mode 100644 index 000000000..d7ccfeeb0 --- /dev/null +++ b/recognition/UNet_task3_48339261/dataset.py @@ -0,0 +1,92 @@ +""" +dataset.py +Gemini-Assist: +1. Suggested using .copy() to resolve negative stride error. This ensures +the numpy array is contiguous before tensor conversion. +2. Mask resize use NEAREST interpolation to prevent corrupting the (0, 1) label values. +Using BILINEAR would create invalid fractional values. +""" +from torch.utils.data import Dataset +import nibabel as nib +import numpy as np +import torch +import os +import glob +import torchvision.transforms.functional as TF + +class HipMRIDataset(Dataset): + """ + keras_slices_train/ : traning figure + keras_slices_seg_train/ : traning mask + keras_slices_validate/ : validate figure + keras_slices_seg_validate/ : validate mask + keras_slices_test/ : test figure + keras_slices_seg_test/ : test mask + """ + def __init__(self, data_dir, subset="train", prostate_label_value=5, resize_to=None): + self.subset = subset # 'train', 'validate' or 'test' + self.data_dir = data_dir # the root path of keras_slices_data + self.prostate_label_value = prostate_label_value # the integer value of prostate + self.resize_to = resize_to # tuple of (H, W) + + self.img_dir = os.path.join(data_dir, f"keras_slices_{subset}") + self.seg_dir = os.path.join(data_dir, f"keras_slices_seg_{subset}") + + self.image_files = sorted(glob.glob(os.path.join(self.img_dir, "*.nii.gz"))) + self.mask_files = sorted(glob.glob(os.path.join(self.seg_dir, "*.nii.gz"))) + + # validate the files exist + if len(self.image_files) == 0: + raise FileNotFoundError(f"no file in {self.img_dir}") + if len(self.mask_files) == 0: + raise FileNotFoundError(f"no mask file in {self.seg_dir}") + + # check the number of figure and mask are mactch + if subset != "validate" and len(self.image_files) != len(self.mask_files): + print(f"The number of Figure ({len(self.image_files)}) and Mask ({len(self.mask_files)}) are not match") + + print(f"Success load {subset} set: find {len(self.image_files)} figure file。") + + def __len__(self): + return len(self.image_files) + + + def __getitem__(self, idx): + img_path = self.image_files[idx] + + # path of mask + img_filename = os.path.basename(img_path) + mask_filename = img_filename.replace("case_", "seg_") + mask_path = os.path.join(self.seg_dir, mask_filename) + + # back to index mathch + if not os.path.exists(mask_path): + if idx < len(self.mask_files): + mask_path = self.mask_files[idx] + else: + raise FileNotFoundError(f"can't find mask for {img_path} figure") + + # load Nifti file + image = nib.load(img_path).get_fdata().astype(np.float32) + mask = nib.load(mask_path).get_fdata().astype(np.uint8) + + # change to Tensors + image_tensor = torch.from_numpy(image.copy()).unsqueeze(0) # (1,H,W) + mask_tensor = torch.from_numpy(mask.copy()).long() # (H,W) + + # Resize + if self.resize_to: + image_tensor = TF.resize(image_tensor, self.resize_to, interpolation=TF.InterpolationMode.BILINEAR) + # unsqueeze, then resize, then squeeze + mask_tensor = TF.resize(mask_tensor.unsqueeze(0), self.resize_to, interpolation=TF.InterpolationMode.NEAREST).squeeze(0) + + # Z-score normalization + mean, std = image_tensor.mean(), image_tensor.std() + # 1e-6 to prevent dividing by 0 + image_tensor = (image_tensor - mean) / (std + 1e-6) + + # binaray mask + binary_mask = (mask_tensor == self.prostate_label_value).long() + + return image_tensor, binary_mask + diff --git a/recognition/UNet_task3_48339261/final_predictions.png b/recognition/UNet_task3_48339261/final_predictions.png new file mode 100644 index 000000000..3d9ec17c2 Binary files /dev/null and b/recognition/UNet_task3_48339261/final_predictions.png differ diff --git a/recognition/UNet_task3_48339261/modules.py b/recognition/UNet_task3_48339261/modules.py new file mode 100644 index 000000000..7723f52bf --- /dev/null +++ b/recognition/UNet_task3_48339261/modules.py @@ -0,0 +1,68 @@ +""" +modules.py + +cite: UNet_segmentation_code_demo.ipynb +""" +import torch +import torch.nn as nn + +class SimpleUNet(nn.Module): + def __init__(self, in_channels=1, out_channels=1, dropout_p=0.2): + super().__init__() + + # Encoder (downsampling) + self.enc1 = self._conv_block(in_channels, 32, dropout_p) + self.enc2 = self._conv_block(32, 64, dropout_p) + self.enc3 = self._conv_block(64, 128, dropout_p) + + # Decoder (upsampling) + self.dec3 = self._conv_block(128 + 64, 64, dropout_p) + self.dec2 = self._conv_block(64 + 32, 32, dropout_p) + self.dec1 = nn.Conv2d(32, out_channels, 1) + + self.pool = nn.MaxPool2d(2) + self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) + self.sigmoid = nn.Sigmoid() + + def _conv_block(self, in_ch, out_ch, dropout_p=0.2): + return nn.Sequential( + nn.Conv2d(in_ch, out_ch, 3, padding=1), + nn.BatchNorm2d(out_ch), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(dropout_p), + nn.Conv2d(out_ch, out_ch, 3, padding=1), + nn.BatchNorm2d(out_ch), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(dropout_p), + ) + + def forward(self, x): + # Encoder + e1 = self.enc1(x) + e2 = self.enc2(self.pool(e1)) + e3 = self.enc3(self.pool(e2)) + + # Decoder with skip connections + d3 = self.dec3(torch.cat([self.upsample(e3), e2], 1)) + d2 = self.dec2(torch.cat([self.upsample(d3), e1], 1)) + out = self.dec1(d2) + + out = self.sigmoid(out) + + return out + + +class DiceLoss(nn.Module): + def __init__(self, smooth=1e-6): + super(DiceLoss, self).__init__() + self.smooth = smooth + + def forward(self, predictions, targets): + + predictions = predictions.reshape(-1) + targets = targets.reshape(-1).float() + + intersection = (predictions * targets).sum() + dice_coeff = (2.0 * intersection + self.smooth) / (predictions.sum() + targets.sum() + self.smooth) + + return 1 - dice_coeff \ No newline at end of file diff --git a/recognition/UNet_task3_48339261/predict.py b/recognition/UNet_task3_48339261/predict.py new file mode 100644 index 000000000..99c105512 --- /dev/null +++ b/recognition/UNet_task3_48339261/predict.py @@ -0,0 +1,89 @@ +import torch +import os +import matplotlib.pyplot as plt +import numpy as np +from tqdm import tqdm + +from utils import calculate_dice_score +from modules import SimpleUNet +from dataset import HipMRIDataset + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +print(f"Predict.py: Using device: {device}") + +def show_predictions(model, dataset, title="Final segmentation results (HipMRI)", n=3): + model.eval() + + fig, axes = plt.subplots(3, n, figsize=(12, 9)) + fig.suptitle(title, fontsize=16, fontweight='bold') + + with torch.no_grad(): + if len(dataset) < n: + n = len(dataset) + print(f"The siza of Dataset ({len(dataset)}) smaller than n ({n}), show {n} samples.") + + indices = np.random.choice(len(dataset), n, replace=False) + + for i, idx in enumerate(indices): + # (1,H,W) & (H,W) + image, true_mask = dataset[idx] + # (1,H,W) -> (1,1,H,W) + pred = model(image.unsqueeze(0).to(device)) + # (1,1,H,W) -> (H,W) + pred_prob = pred[0, 0].cpu().numpy() + pred_binary = (pred_prob > 0.5).astype(int) + + # Orifinal (gray) + img_display = image.squeeze().cpu().numpy() + axes[0, i].imshow(img_display, cmap='gray') + axes[0, i].set_title(f'Original {idx})', fontweight='bold') + axes[0, i].axis('off') + + # Ground Truth + axes[1, i].imshow(true_mask.cpu().numpy(), cmap='gray') + axes[1, i].set_title(f'Ground Truth (Prostate)', fontweight='bold') + axes[1, i].axis('off') + + # Prediction + dice = calculate_dice_score(pred_binary, true_mask) + axes[2, i].imshow(pred_binary, cmap='gray') + axes[2, i].set_title(f'Pridiction {dice:.3f})', fontweight='bold') + axes[2, i].axis('off') + + plt.tight_layout() + save_path = "final_predictions.png" + plt.savefig(save_path) + print(f"Prediction Result saved in: {save_path}") + plt.close(fig) + +if __name__ == '__main__': + print("Runing predict.py...") + + DATA_DIR = "/content/drive/MyDrive/Colab-Notebooks/UNet_task3_48339261/keras_slices_data" + MODEL_SAVE_PATH = "/content/drive/MyDrive/Colab-Notebooks/UNet_task3_48339261/hipmri_unet_model.pth" + RESIZE_TO = (128, 128) + PROSTATE_LABEL = 5 + NUM_EXAMPLES_TO_SHOW = 3 + + # Chenk the file exists + if not os.path.exists(MODEL_SAVE_PATH): + print(f"Error: can't find file '{MODEL_SAVE_PATH}'。") + else: + print(f"Loading the model: {MODEL_SAVE_PATH}") + model = SimpleUNet(in_channels=1, out_channels=1).to(device) + + model.load_state_dict(torch.load(MODEL_SAVE_PATH, map_location=device)) + + print(f"Loading the test set: {DATA_DIR} (subset=test)") + test_dataset = HipMRIDataset( + data_dir=DATA_DIR, + subset="test", + resize_to=RESIZE_TO, + prostate_label_value=PROSTATE_LABEL + ) + + if len(test_dataset) > 0: + print(f"Generating {NUM_EXAMPLES_TO_SHOW} prediction examples...") + show_predictions(model, test_dataset, n=NUM_EXAMPLES_TO_SHOW) + else: + print("Error: Dataset is empty.") \ No newline at end of file diff --git a/recognition/UNet_task3_48339261/train.py b/recognition/UNet_task3_48339261/train.py new file mode 100644 index 000000000..da5acd6a9 --- /dev/null +++ b/recognition/UNet_task3_48339261/train.py @@ -0,0 +1,122 @@ +""" +train.py +""" +import torch +import torch.optim as optim +from torch.utils.data import DataLoader, Dataset +from modules import DiceLoss, SimpleUNet +from dataset import HipMRIDataset +from utils import show_epoch_predictions, plot_loss, calculate_dice_score +from tqdm import tqdm +import numpy as np +import matplotlib.pyplot as plt + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +def train(model, train_loader, val_dataset, epochs=3, lr=0.001, visualize_every=1): + """ + param val_dataset: validation set + """ + model.to(device) + criterion = DiceLoss() + optimizer = optim.Adam(model.parameters(), lr=lr) + + losses = [] + print(f"Train Starting for {epochs} epochs") + + for epoch in range(epochs): + model.train() + epoch_loss = 0 + + # show process bar + progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}", leave=True) + + for batch_idx, (images, masks) in enumerate(progress_bar): + + images = images.to(device) + masks = masks.to(device) + + # Forward pass + outputs = model(images) + predictions_squeezed = outputs[:, 0] + loss = criterion(predictions_squeezed, masks) + + # Backward pass + optimizer.zero_grad() + loss.backward() + optimizer.step() + + epoch_loss += loss.item() + + # show batch loss in prcocess bar + progress_bar.set_postfix(batch_loss=f"{loss.item():.4f}") + + avg_loss = epoch_loss / len(train_loader) + losses.append(avg_loss) + print(f"Epoch {epoch+1}/{epochs} Complete: Avg Loss = {avg_loss:.4f}") + + # visualization + if (epoch + 1) % visualize_every == 0 or (epoch + 1) == epochs: + show_epoch_predictions(model, val_dataset, epoch + 1, n=3) + + print("Training complete with enhanced U-Net") + plot_loss(losses) + return losses + +if __name__ == "__main__": + + print(f"Starting training (train.py) ---") + + DATA_DIR = "/content/drive/MyDrive/Colab-Notebooks/UNet_task3_48339261/keras_slices_data" + MODEL_SAVE_PATH = "/content/drive/MyDrive/Colab-Notebooks/UNet_task3_48339261/hipmri_unet_model.pth" + + EPOCHS = 20 + LEARNING_RATE = 0.001 + BATCH_SIZE = 16 + RESIZE_TO = (128, 128) + PROSTATE_LABEL = 5 + + # load and split dataset + print(f"Loading dataset from {DATA_DIR} ...") + + # load the split train dataset + train_dataset = HipMRIDataset( + data_dir=DATA_DIR, + subset="train", + resize_to=RESIZE_TO, + prostate_label_value=PROSTATE_LABEL + ) + + # load the split validate dataset + val_dataset = HipMRIDataset( + data_dir=DATA_DIR, + subset="validate", + resize_to=RESIZE_TO, + prostate_label_value=PROSTATE_LABEL + ) + + print(f"Dataset loaded {len(train_dataset)} train dataset, {len(val_dataset)} validate sataset") + + # DataLoader + train_loader = DataLoader( + train_dataset, + batch_size=BATCH_SIZE, + shuffle=True, + num_workers=2 + ) + + model = SimpleUNet(in_channels=1, out_channels=1).to(device) + + # start training + training_losses = train( + model=model, + train_loader=train_loader, + val_dataset=val_dataset, + epochs=EPOCHS, + lr=LEARNING_RATE, + visualize_every=5 + ) + + print("Training Complete") + torch.save(model.state_dict(), MODEL_SAVE_PATH) + print(f"Model saved to: {MODEL_SAVE_PATH}") \ No newline at end of file diff --git a/recognition/UNet_task3_48339261/training_loss_curve.png b/recognition/UNet_task3_48339261/training_loss_curve.png new file mode 100644 index 000000000..3142238ab Binary files /dev/null and b/recognition/UNet_task3_48339261/training_loss_curve.png differ diff --git a/recognition/UNet_task3_48339261/utils.py b/recognition/UNet_task3_48339261/utils.py new file mode 100644 index 000000000..a9f14c34e --- /dev/null +++ b/recognition/UNet_task3_48339261/utils.py @@ -0,0 +1,81 @@ +""" +utils.py + +Gemini assist: +plt.show() hangs in Colab. Need to use plt.savefig() and plt.close() +""" +import torch +import numpy as np +import matplotlib.pyplot as plt + +def calculate_dice_score(pred_binary, true_mask): + """calculate Dice similarity coefficient""" + + if isinstance(pred_binary, torch.Tensor): + pred_binary = pred_binary.cpu().numpy() + if isinstance(true_mask, torch.Tensor): + true_mask = true_mask.cpu().numpy() + + pred_binary = pred_binary.flatten() + true_mask = true_mask.flatten() + + # + 1e-6 prevent the denominator is 0 + intersection = (pred_binary * true_mask).sum() + dice_score = (2. * intersection + 1e-6) / (pred_binary.sum() + true_mask.sum() + 1e-6) + + return dice_score + + +def show_epoch_predictions(model, dataset, epoch, n=3): + """MRI""" + model.eval() + + fig, axes = plt.subplots(3, n, figsize=(12, 9)) + fig.suptitle(f'Train the prediction results after the {epoch} round', fontsize=16) + + with torch.no_grad(): + indices = np.random.choice(len(dataset), n, replace=False) + + for i, idx in enumerate(indices): + image, true_mask = dataset[idx] + pred = model(image.unsqueeze(0).to(device)) + + pred_prob = pred[0, 0].cpu().numpy() + pred_binary = (pred_prob > 0.5).astype(int) + + # Show original image + img_display = image.squeeze().cpu().numpy() + axes[0, i].imshow(img_display, cmap='gray') + axes[0, i].set_title(f'Original {idx})') + axes[0, i].axis('off') + + # Show ground truth binary mask + axes[1, i].imshow(true_mask.cpu().numpy(), cmap='gray') + axes[1, i].set_title(f'Ground Truth (Prostate)') + axes[1, i].axis('off') + + # Show prediction + dice = calculate_dice_score(pred_binary, true_mask) + axes[2, i].imshow(pred_binary, cmap='gray') + axes[2, i].set_title(f'Prediction') + axes[2, i].axis('off') + + plt.tight_layout() + plt.savefig(f"epoch_{epoch}_predictions.png") + plt.close(fig) + model.train() + + +def plot_loss(losses): + """Plot train loss curve""" + plt.figure(figsize=(10, 5)) + plt.plot(losses, label='Training Loss') + plt.title('Training Loss vs. Epochs') + plt.xlabel('Epoch') + plt.ylabel('Loss (Dice Loss)') + plt.legend() + plt.grid(True) + plt.savefig("training_loss_curve.png") + plt.close() + + \ No newline at end of file