1 change: 1 addition & 0 deletions deeplearning4j-examples
Submodule deeplearning4j-examples added at 0511c1
@@ -1,10 +1,132 @@
##### TransferLearning
Demonstrates use of the DL4J transfer learning API, which lets users construct a model based on an existing model by modifying the architecture, selectively freezing certain parts, and then fine-tuning parameters. Read the documentation for the Transfer Learning API at [https://deeplearning4j.konduit.ai/tuning-and-training/transfer-learning](https://deeplearning4j.konduit.ai/tuning-and-training/transfer-learning).
* [EditLastLayerOthersFrozen.java](./editlastlayer/EditLastLayerOthersFrozen.java)
Modifies just the last layer in VGG16, freezes the rest, and trains the network on the flower dataset.
* [FeaturizedPreSave.java](./editlastlayer/presave/FeaturizedPreSave.java) & [FitFromFeaturized.java](./editlastlayer/presave/FitFromFeaturized.java)
Save time on the forward pass during multiple epochs by "featurizing" the datasets. FeaturizedPreSave saves the output at the last frozen layer, and FitFromFeaturized fits to the presaved data so you can iterate more quickly with different learning parameters.
* [EditAtBottleneckOthersFrozen.java](./editfrombottleneck/EditAtBottleneckOthersFrozen.java)
A more complex example of modifying the model architecture by adding/removing vertices.
* [FineTuneFromBlockFour.java](./finetuneonly/FineTuneFromBlockFour.java)
Reads in a saved model (training information and all) and fine-tunes it by overriding its training information with what is specified.
# Transfer Learning Examples – DeepLearning4J

This folder demonstrates multiple ways to use the DL4J Transfer Learning API.
Transfer learning allows you to take an existing pretrained model (such as VGG16) and adapt it to a new dataset by:

- Freezing certain layers
- Modifying the architecture
- Replacing the final classifier
- Fine-tuning deeper layers
- Speeding up training through featurized datasets

Official documentation:
🔗 https://deeplearning4j.konduit.ai/tuning-and-training/transfer-learning
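
Every example in this folder starts from the same step: loading a pretrained VGG16 from the DL4J model zoo. A minimal sketch of that shared starting point; the printed summary is where you find the layer names (such as `fc2` and `predictions`) that the transfer learning builders in the sections below refer to:

```java
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.zoo.ZooModel;
import org.deeplearning4j.zoo.model.VGG16;

public class LoadPretrainedVgg16 {
    public static void main(String[] args) throws Exception {
        // Downloads the ImageNet weights on first run, then caches them locally
        ZooModel zooModel = VGG16.builder().build();
        ComputationGraph vgg16 = (ComputationGraph) zooModel.initPretrained();

        // Inspect layer names ("fc2", "predictions", ...) used by the builders below
        System.out.println(vgg16.summary());
    }
}
```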

---

# 📁 Folder Overview

The transfer learning examples are grouped into three major strategies, depending on how much of the pretrained model you want to modify or retrain, plus a folder of shared dataset iterators.

---

## 1️⃣ **Edit Last Layer Only** (`editlastlayer/`)

### ➤ `EditLastLayerOthersFrozen.java`
Replaces only the final output layer of VGG16 and freezes all previous layers.

**Use Case:**
- New dataset is small
- New dataset is similar to the pretraining data (e.g., flower classification on top of ImageNet features)
- Extremely fast training
- Very little overfitting risk
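
A condensed sketch of the core calls this example makes. The layer names `fc2` and `predictions` are VGG16's; the updater, seed, and weight-init values here are illustrative rather than the example's exact settings:

```java
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.transferlearning.FineTuneConfiguration;
import org.deeplearning4j.nn.transferlearning.TransferLearning;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.zoo.model.VGG16;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.learning.config.Nesterovs;
import org.nd4j.linalg.lossfunctions.LossFunctions;

public class LastLayerSketch {
    public static void main(String[] args) throws Exception {
        ComputationGraph vgg16 = (ComputationGraph) VGG16.builder().build().initPretrained();

        FineTuneConfiguration fineTuneConf = new FineTuneConfiguration.Builder()
                .updater(new Nesterovs(5e-5))
                .seed(12345)
                .build();

        ComputationGraph vgg16Transfer = new TransferLearning.GraphBuilder(vgg16)
                .fineTuneConfiguration(fineTuneConf)
                .setFeatureExtractor("fc2")                 // freeze up to and including fc2
                .removeVertexKeepConnections("predictions") // drop the 1000-way ImageNet classifier
                .addLayer("predictions",
                        new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                                .nIn(4096).nOut(5)          // 5 flower classes
                                .weightInit(WeightInit.XAVIER)
                                .activation(Activation.SOFTMAX).build(),
                        "fc2")
                .build();

        System.out.println(vgg16Transfer.summary());
    }
}
```

Only the new `predictions` layer has trainable parameters; everything frozen contributes a fixed forward pass.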

---

### ⚡ Presaving Featurized Data (`editlastlayer/presave/`)
This subfolder contains examples for speeding up transfer learning.

#### ➤ `FeaturizedPreSave.java`
Runs the dataset once through the frozen layers and **saves the output features** to disk.

#### ➤ `FitFromFeaturized.java`
Uses the presaved features to train the classifier **without repeating the forward pass**.

**Why this helps:**
Huge speed-up when training for many epochs or trying multiple learning rates.
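
A sketch of the two-step flow, assuming the `FlowerDataSetIterator` from this folder's `iterators/` package; the batch size, split, directory, and file names are illustrative:

```java
import java.io.File;
import org.deeplearning4j.datasets.iterator.ExistingMiniBatchDataSetIterator;
import org.deeplearning4j.examples.advanced.features.transferlearning.iterators.FlowerDataSetIterator;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.transferlearning.FineTuneConfiguration;
import org.deeplearning4j.nn.transferlearning.TransferLearning;
import org.deeplearning4j.nn.transferlearning.TransferLearningHelper;
import org.deeplearning4j.zoo.model.VGG16;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.learning.config.Nesterovs;
import org.nd4j.linalg.lossfunctions.LossFunctions;

public class FeaturizePresaveSketch {
    public static void main(String[] args) throws Exception {
        // Build a last-layer transfer model as in the previous sketch
        ComputationGraph vgg16 = (ComputationGraph) VGG16.builder().build().initPretrained();
        ComputationGraph vgg16Transfer = new TransferLearning.GraphBuilder(vgg16)
                .fineTuneConfiguration(new FineTuneConfiguration.Builder()
                        .updater(new Nesterovs(5e-5)).build())
                .setFeatureExtractor("fc2")
                .removeVertexKeepConnections("predictions")
                .addLayer("predictions", new OutputLayer.Builder(
                        LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                        .nIn(4096).nOut(5).activation(Activation.SOFTMAX).build(), "fc2")
                .build();

        // The helper detects the frozen subgraph and splits the model there
        TransferLearningHelper helper = new TransferLearningHelper(vgg16Transfer);

        // Step 1 (FeaturizedPreSave): one forward pass through the frozen layers,
        // saving the activations at the frozen boundary to disk
        FlowerDataSetIterator.setup(15, 80);
        DataSetIterator trainIter = FlowerDataSetIterator.trainIterator();
        new File("featurizedTrain").mkdirs();
        int i = 0;
        while (trainIter.hasNext()) {
            DataSet featurized = helper.featurize(trainIter.next());
            featurized.save(new File("featurizedTrain", "flowers-train-" + i++ + ".bin"));
        }

        // Step 2 (FitFromFeaturized): train only the unfrozen part, reading the
        // saved activations instead of repeating the expensive forward pass
        DataSetIterator featurizedIter = new ExistingMiniBatchDataSetIterator(
                new File("featurizedTrain"), "flowers-train-%d.bin");
        helper.fitFeaturized(featurizedIter);
    }
}
```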

---

## 2️⃣ **Edit from Bottleneck** (`editfrombottleneck/`)

### ➤ `EditAtBottleneckOthersFrozen.java`
Modifies the model at the **bottleneck layer** (middle of the network).
Allows adding/removing vertices, changing layer shapes, or inserting new layers.

**Use Case:**
- New dataset is somewhat different from ImageNet
- Need more flexibility in how features are combined
- Want deeper customization but still freeze earlier layers

This is a more advanced example showing how to modify the computation graph.
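
A sketch of the kind of graph surgery involved, using VGG16's layer names; the replacement sizes and new layer names are illustrative and differ from the real example in detail:

```java
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.transferlearning.FineTuneConfiguration;
import org.deeplearning4j.nn.transferlearning.TransferLearning;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.zoo.model.VGG16;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.learning.config.Nesterovs;
import org.nd4j.linalg.lossfunctions.LossFunctions;

public class BottleneckEditSketch {
    public static void main(String[] args) throws Exception {
        ComputationGraph vgg16 = (ComputationGraph) VGG16.builder().build().initPretrained();

        FineTuneConfiguration fineTuneConf = new FineTuneConfiguration.Builder()
                .updater(new Nesterovs(5e-5)).seed(12345).build();

        ComputationGraph vgg16Transfer = new TransferLearning.GraphBuilder(vgg16)
                .fineTuneConfiguration(fineTuneConf)
                .setFeatureExtractor("block5_pool")           // freeze the convolutional base
                .nOutReplace("fc2", 1024, WeightInit.XAVIER)  // shrink fc2: 4096 -> 1024
                .removeVertexAndConnections("predictions")    // cut the old classifier out entirely
                .addLayer("fc3", new DenseLayer.Builder()
                        .activation(Activation.TANH).nIn(1024).nOut(256).build(), "fc2")
                .addLayer("newpredictions", new OutputLayer.Builder(
                        LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                        .activation(Activation.SOFTMAX).nIn(256).nOut(5).build(), "fc3")
                .setOutputs("newpredictions")                 // register the new output vertex
                .build();

        System.out.println(vgg16Transfer.summary());
    }
}
```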

---

## 3️⃣ **Fine-Tune Some Layers Only** (`finetuneonly/`)

### ➤ `FineTuneFromBlockFour.java`
Loads a previously saved model and fine-tunes only the last few blocks while keeping earlier layers frozen.

**Use Case:**
- Dataset is moderately different
- Need the network to adapt more deeply
- Fine-tuning block 4+ is common with VGG/ResNet architectures
- Allows training more parameters without overfitting as much as full fine-tuning
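
A sketch of the load-then-fine-tune pattern. `MyComputationGraph.zip` is the file the bottleneck example saves; `block3_pool` is an assumed VGG16 layer name chosen so that blocks four and five train while earlier blocks stay frozen, and the updater values are illustrative:

```java
import java.io.File;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.transferlearning.FineTuneConfiguration;
import org.deeplearning4j.nn.transferlearning.TransferLearning;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.linalg.learning.config.Nesterovs;

public class FineTuneFromSavedSketch {
    public static void main(String[] args) throws Exception {
        // Restore a previously saved graph (true = also load the updater state)
        ComputationGraph saved = ModelSerializer.restoreComputationGraph(
                new File("MyComputationGraph.zip"), true);

        // Override the saved training configuration; use a small learning rate,
        // since we are nudging pretrained weights rather than learning from scratch
        FineTuneConfiguration fineTuneConf = new FineTuneConfiguration.Builder()
                .updater(new Nesterovs(1e-5))
                .seed(12345)
                .build();

        // Freeze everything up to the end of block 3 so blocks 4+ fine-tune
        ComputationGraph fineTuned = new TransferLearning.GraphBuilder(saved)
                .fineTuneConfiguration(fineTuneConf)
                .setFeatureExtractor("block3_pool")
                .build();

        System.out.println(fineTuned.summary());
    }
}
```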

---

## 4️⃣ **Dataset Iterators** (`iterators/`)

These classes prepare the flower dataset for transfer learning.

### ➤ `FlowerDataSetIterator.java`
Loads images from disk and prepares them for VGG16-style input.

### ➤ `FlowerDataSetIteratorFeaturized.java`
Loads presaved (featurized) data for fast training.
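
Typical usage, mirroring how the other examples call these classes; the batch size and train/test split are illustrative values:

```java
import org.deeplearning4j.examples.advanced.features.transferlearning.iterators.FlowerDataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;

public class FlowerIteratorSketch {
    public static void main(String[] args) throws Exception {
        // Download/split the flower dataset: batches of 15, 80% train / 20% test
        FlowerDataSetIterator.setup(15, 80);

        DataSetIterator trainIter = FlowerDataSetIterator.trainIterator();
        DataSetIterator testIter = FlowerDataSetIterator.testIterator();

        // Each batch holds VGG16-shaped features: [batchSize, 3, 224, 224]
        System.out.println("Train batch features shape: "
                + java.util.Arrays.toString(trainIter.next().getFeatures().shape()));
    }
}
```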

---

# 🚀 How to Run Any Transfer Learning Example

Use:

```bash
mvn -q exec:java -Dexec.mainClass="org.deeplearning4j.examples.advanced.features.transferlearning.<folder>.<ClassName>"
```

Examples:

```bash
mvn -q exec:java -Dexec.mainClass="org.deeplearning4j.examples.advanced.features.transferlearning.editlastlayer.EditLastLayerOthersFrozen"

mvn -q exec:java -Dexec.mainClass="org.deeplearning4j.examples.advanced.features.transferlearning.editfrombottleneck.EditAtBottleneckOthersFrozen"

mvn -q exec:java -Dexec.mainClass="org.deeplearning4j.examples.advanced.features.transferlearning.finetuneonly.FineTuneFromBlockFour"
```

---

# 🎯 Summary of Strategies

| Strategy | Layers Trained | Best For |
|---------|----------------|----------|
| **Edit Last Layer Only** | Only final Dense layer | Small, similar datasets; fastest training |
| **Edit from Bottleneck** | Middle layers | More control; moderate dataset differences |
| **Fine-Tune Block Four** | Last few blocks | Moderate differences; more expressive training |
| **Full Featurization** | Only classifier | Fast experiments; repeated runs |

---

# 🙌 Why This README Matters

The original README was brief and lacked conceptual explanations.
This improved version:

- Documents each example clearly
- Explains when to use each transfer learning strategy
- Provides run commands
- Helps new users understand DL4J transfer learning best practices
- Improves readability and usefulness of advanced examples

This makes transfer learning in DL4J easier to understand and use.
@@ -17,6 +17,9 @@
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/

// Transfer learning example: modify architecture at the bottleneck layer.
// Earlier layers remain frozen; mid-level layers can be added/edited.

package org.deeplearning4j.examples.advanced.features.transferlearning.editfrombottleneck;

import org.deeplearning4j.examples.advanced.features.transferlearning.iterators.FlowerDataSetIterator;
@@ -140,6 +143,8 @@ public static void main(String [] args) throws Exception {
// 2. in place with the TransferLearningHelper constructor, which will take a model and a specific vertex name
// and freeze it and the vertices on the path from an input to it (as seen in the FeaturizedPreSave class)
//The saved model can be "fine-tuned" further as in the class "FitFromFeaturized"
// Insert/remove vertices to customize part of the model

File locationToSave = new File("MyComputationGraph.zip");
boolean saveUpdater = false;
vgg16Transfer.save(locationToSave, saveUpdater);
@@ -17,6 +17,9 @@
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/

// Transfer learning example: replace only the final classification layer of VGG16.
// All earlier layers are frozen for extremely fast training on a new dataset.

package org.deeplearning4j.examples.advanced.features.transferlearning.editlastlayer;

import org.deeplearning4j.examples.advanced.features.transferlearning.iterators.FlowerDataSetIterator;
@@ -65,6 +68,8 @@ public static void main(String [] args) throws IOException {
//Import vgg
//Note that the model imported does not have an output layer (check printed summary)
// nor any training related configs (model from keras was imported with only weights and json)
// Remove the original output layer and attach a new one for 5 flower classes

log.info("\n\nLoading org.deeplearning4j.transferlearning.vgg16...\n\n");
ZooModel zooModel = VGG16.builder().build();
ComputationGraph vgg16 = (ComputationGraph) zooModel.initPretrained();
@@ -91,6 +96,8 @@ public static void main(String [] args) throws IOException {
"fc2")
.build();
log.info(vgg16Transfer.summary());

// Train only the new output layer; frozen layers are not updated

//Dataset iterators
FlowerDataSetIterator.setup(batchSize,trainPerc);
@@ -17,6 +17,9 @@
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/

// Fine-tune last few blocks of a pretrained model while freezing early layers.
// Provides more adaptation than last-layer editing but avoids full retraining.

package org.deeplearning4j.examples.advanced.features.transferlearning.finetuneonly;

import org.deeplearning4j.examples.advanced.features.transferlearning.iterators.FlowerDataSetIterator;
73 changes: 73 additions & 0 deletions onnx-import-examples/README.md
@@ -0,0 +1,73 @@
# ONNX Import Examples (DeepLearning4J)

This module contains example programs demonstrating how to import ONNX models into DeepLearning4J (DL4J) and run inference using ND4J.

These examples help users understand the workflow required for loading pretrained ONNX models, preprocessing inputs, performing forward passes, and reading predictions.

## 🚀 What You Will Learn

- How to load ONNX models using DL4J's OnnxGraphImporter
- How to inspect ONNX graph metadata
- How to prepare NDArray input tensors for inference
- How to run inference on imported ONNX models
- How to read and interpret output layers

## 📦 How to Run

Use Maven to compile and run any example:

```bash
mvn clean compile exec:java -Dexec.mainClass="org.deeplearning4j.examples.onnx.<ExampleClassName>"
```

Example:

```bash
mvn exec:java -Dexec.mainClass="org.deeplearning4j.examples.onnx.ImportBasicOnnxModel"
```

Replace `<ExampleClassName>` with any class inside the `onnx-import-examples` folder.

## 📁 Model Requirements

To run inference, you need an `.onnx` model file. If an example references a model that is not included in the repository, download it from:

- https://github.com/onnx/models
- Any ONNX-compatible export from PyTorch/Keras/TF

Place it in the example's `resources/` directory or update the file path in the code.
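
For orientation, a rough sketch of what ONNX import and inference can look like through ND4J's SameDiff ONNX import module (`nd4j-samediff-frameworkimport-onnx`). The class name, method signature, and model file below are assumptions that vary across releases; check the importer's documentation for your version:

```java
// Sketch only: importer class/signature assumed from the nd4j SameDiff
// ONNX import module; names may differ between releases.
import java.util.Collections;
import java.util.Map;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.samediff.frameworkimport.onnx.importer.OnnxFrameworkImporter;

public class OnnxImportSketch {
    public static void main(String[] args) {
        // Import the ONNX graph into a SameDiff instance (no dynamic inputs here)
        SameDiff sd = new OnnxFrameworkImporter()
                .runImport("model.onnx", Collections.emptyMap(), true);

        // Dummy ImageNet-style input; a real example would preprocess an image
        INDArray input = Nd4j.rand(new int[]{1, 3, 224, 224});
        Map<String, INDArray> out = sd.output(
                Collections.singletonMap(sd.inputs().get(0), input),
                sd.outputs().get(0));

        INDArray logits = out.get(sd.outputs().get(0));
        System.out.println("Output shape: " + java.util.Arrays.toString(logits.shape()));
    }
}
```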

## 🧩 Folder Structure

```
onnx-import-examples/
├── src/main/java/org/deeplearning4j/examples/onnx/
│   ├── ImportBasicOnnxModel.java
│   ├── InspectOnnxGraph.java
│   └── ...
├── src/main/resources/
├── pom.xml
└── README.md   ← (this file)
```

## 🧪 Expected Output

Typical output may include:

```
Loading ONNX model...
Model imported successfully.
Running inference...
Output shape: [1, 1000]
Predicted class: 281 (tabby cat)
```

## ❤️ Contribution

Feel free to add additional examples demonstrating:

- ONNX opset compatibility
- Image preprocessing pipelines
- Importing models trained in TensorFlow / PyTorch
@@ -0,0 +1,74 @@
/**
* KerasMNISTImportExample
*
* Demonstrates how to import a trained Keras (.h5) model into DL4J
* and run inference on the MNIST handwritten digits dataset.
*
 * Requirements:
 * - The Keras model must be trained using the TensorFlow backend
 * - The model must be saved via model.save("model.h5") with no custom layers
*
 * This example:
 * 1. Loads MNIST test data using DL4J's built-in dataset iterators
 * 2. Loads a pre-trained Keras model (.h5)
 * 3. Converts the Keras model into a DL4J ComputationGraph
 * 4. Evaluates accuracy on the MNIST test set
 */

import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.modelimport.keras.KerasModelImport;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;

public class KerasMNISTImportExample {

public static void main(String[] args) throws Exception {

// ------------------------------------------------------------
// 1. Load MNIST Test Data
// ------------------------------------------------------------
int batchSize = 128;

/*
* MnistDataSetIterator automatically downloads MNIST if needed
* and loads normalized test images (28x28 grayscale).
*
* The iterator returns DataSet objects with:
* - features: [batchSize, 1, 28, 28]
* - labels: one-hot encoded (10 classes)
*/
DataSetIterator mnistTest = new MnistDataSetIterator(batchSize, false, 12345);

// ------------------------------------------------------------
// 2. Load Pretrained Keras Model (.h5 file)
// ------------------------------------------------------------
String modelPath = "keras_mnist_cnn.h5"; // example file name

/*
* KerasModelImport.importKerasModelAndWeights():
* - Reads .h5 model architecture + weights
* - Converts it to a DL4J ComputationGraph
* - Handles TensorFlow-compatible layers automatically
*/
ComputationGraph model = KerasModelImport.importKerasModelAndWeights(modelPath);

System.out.println("Imported Keras model summary:");
System.out.println(model.summary());

// ------------------------------------------------------------
// 3. Run Evaluation
// ------------------------------------------------------------

/*
* The Evaluation class computes:
* - Accuracy
* - Precision / Recall / F1
* - Confusion matrix
*/
Evaluation eval = new Evaluation(10);

while (mnistTest.hasNext()) {
DataSet ds = mnistTest.next();
INDArray output = model.outputSingle(ds.getFeatures());
eval.eval(ds.getLabels(), output);
}

System.out.println("\nModel Evaluation Results:");
System.out.println(eval.stats());
}
}