VL-JEPA: Vision-Language Joint Embedding Predictive Architecture

PyTorch implementation of VL-JEPA, a JEPA-style vision-language model with masked patch prediction, EMA target encoders, and contrastive alignment. Optimized for commodity hardware (tested on an RTX 3090 with 24 GB VRAM).

Paper authors: Delong Chen, Mustafa Shukor, Theo Moutakanni, Willy Chung, Jade Yu, Tejaswi Kasarla, Yejin Bang, Allen Bolourchi, Yann LeCun, Pascale Fung
Implementation: Bruno Santos + TARS


Overview

VL-JEPA learns joint vision-language representations by:

  1. Predicting masked patch embeddings (I-JEPA style MSE against an EMA teacher).
  2. Aligning image and text in a shared projection space (SigLIP sigmoid loss).
  3. Regularizing embedding variance (VICReg-style) to prevent collapse.

This implementation uses pretrained frozen encoders as a starting point, with phased training that gradually introduces JEPA reconstruction after contrastive alignment is established. Two encoder paths are supported:

  • Hybrid (timm CLIP ViT-B/16 + DistilBERT) — the original path. The vision tower is multimodally-aligned, but the DistilBERT text CLS was never trained for sentence retrieval, which caps fine-grained ranking.
  • End-to-end OpenCLIP (configs/openclip_vitb16.yaml) — both towers come from CLIP's 400M-pair pretraining, so vision and text are already aligned. The caption pipeline automatically switches to CLIP's BPE tokenizer (vocab 49408, SOT/EOT) and CLIP image normalization — feeding DistilBERT ids into the CLIP text tower would silently destroy the alignment.

Breaking the ~25% R@1 ceiling

Three runs, including the end-to-end OpenCLIP one, all plateaued at the same 25–26% R@1 despite R@5 ≈ 54%, R@10 ≈ 68%, and ~92% in-batch NCE@1, with validation loss still falling. The giveaway was that a CLIP-initialized model could not beat its own zero-shot baseline (CLIP ViT-B/16 alone scores above 30% t2i R@1 on COCO 5K): the pretrained cross-modal alignment was being discarded, so every run relearned alignment from scratch on 118K COCO images and capped at the same place. Two concrete bugs caused this:

  1. Text was mean-pooled over a causal transformer. CLIP's text representation lives in the EOT token — the only position that attends to the whole sentence under the causal mask. Mean-pooling all token states throws CLIP's pretrained text embedding away.
  2. CLIP's native projection matrices were never used. The vision tower stopped at ln_post (never applying visual.proj, 768→512) and the text tower never applied text_projection (512→512). Instead two randomly-initialized MLP heads relearned the projection from scratch — discarding the exact matrices CLIP trained on 400M pairs.
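
To make the first bug concrete, here is a minimal framework-free sketch contrasting the two pooling strategies. It assumes CLIP's convention that the end-of-text token has the highest id in the 49408-entry vocabulary (49407), so its position is the argmax of the token ids. The function names, the two word-token ids, and the toy hidden states are illustrative and are not taken from this repository.

```python
# Toy comparison of EOT pooling and mean pooling (plain Python, no PyTorch).
# Assumption: CLIP BPE ids, where SOT = 49406, EOT = 49407, padding = 0.


def eot_index(token_ids):
    """Position of the end-of-text token, found as the argmax of the ids."""
    return max(range(len(token_ids)), key=lambda i: token_ids[i])


def pool_eot(hidden_states, token_ids):
    """Return the hidden state at the EOT position."""
    return hidden_states[eot_index(token_ids)]


def pool_mean(hidden_states, token_ids, pad_id=0):
    """Average the hidden states over all non-padding positions."""
    valid = [h for h, t in zip(hidden_states, token_ids) if t != pad_id]
    dim = len(valid[0])
    return [sum(h[d] for h in valid) / len(valid) for d in range(dim)]


# SOT, two arbitrary word ids, EOT, then padding.
token_ids = [49406, 320, 1929, 49407, 0, 0]
# One 2-d hidden state per position; only the EOT row carries the
# "sentence" signal in the second dimension.
hidden_states = [
    [1.0, 0.0],
    [2.0, 0.0],
    [3.0, 0.0],
    [4.0, 8.0],
    [9.0, 9.0],
    [9.0, 9.0],
]

eot_vec = pool_eot(hidden_states, token_ids)    # state at position 3
mean_vec = pool_mean(hidden_states, token_ids)  # average of positions 0..3
```

In the toy data the mean dilutes the EOT row with earlier positions, which under a causal mask have seen only a prefix of the sentence.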

The fix (configs/openclip_vitb16_aligned.yaml): model.text_pool: eot + model.projection_type: clip. The joint projection is a single linear layer seeded from CLIP's native visual.proj / text_projection, and text is pooled at the EOT token. Together these make the joint embedding exactly equal CLIP's encode_image / encode_text at initialization (verified in tests/test_smoke.py::test_clip_proj_plus_eot_matches_openclip_encode), so training starts at CLIP zero-shot quality and fine-tunes upward instead of relearning alignment from zero. The trainer auto-drops the 10× projection LR to the base LR for CLIP-seeded projections so the first steps don't blow the pretrained matrices away.
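
The learning-rate rule described above can be sketched as a small helper. This is only an illustration of the stated behaviour (10× the base LR for randomly initialized MLP heads, the base LR for CLIP-seeded heads); the function name and signature are hypothetical and do not mirror the trainer's actual code.

```python
def projection_lr(base_lr, projection_type, mlp_multiplier=10.0):
    """Learning rate for the projection heads.

    Randomly initialised MLP heads get a boosted rate, while CLIP-seeded
    linear heads stay at the base rate so that early updates do not
    overwrite the pretrained matrices.
    """
    if projection_type == "mlp":
        return base_lr * mlp_multiplier
    if projection_type == "clip":
        return base_lr
    raise ValueError(f"unknown projection_type: {projection_type!r}")


# With the CLI default peak learning rate of 1e-4:
mlp_lr = projection_lr(1e-4, "mlp")    # ten times the base rate
clip_lr = projection_lr(1e-4, "clip")  # the base rate itself
```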

Earlier/secondary levers (still available):

  • OpenCLIP end-to-end — replaces the weak DistilBERT CLS with CLIP's aligned text tower. Enabled by the CLIP-tokenizer switch.
  • Hard-negative mining (loss.hard_negative_weight) — a VSE++ max-violation hinge on the hardest in-batch negative. Disabled (0.0) in the aligned config to isolate the alignment fix; re-enable (e.g. 0.05) once the aligned baseline is established.
  • More data (src/image_text_dataset.py, experiments/download_cc3m.py) — pretrain on CC3M/CC12M before COCO (COCO's 118K images are seen ~50× per run).
  • Higher resolution — set model.image_size: 384; the OpenCLIP vision tower interpolates its positional embeddings automatically.
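
For the hard-negative lever, the following is a minimal sketch of a VSE++-style max-violation hinge on a square similarity matrix, written in plain Python for readability. It assumes the diagonal holds the matching image-caption pairs and averages over the batch; whether the repository sums or averages, and how it masks duplicates, is not stated here, so treat those details as assumptions.

```python
def max_violation_hinge(sim, margin=0.2):
    """VSE++-style hinge using only the hardest in-batch negative.

    sim[i][j] is the similarity between image i and caption j, and the
    diagonal entries are the positive pairs. Requires at least 2 pairs.
    """
    n = len(sim)
    total = 0.0
    for i in range(n):
        positive = sim[i][i]
        # Hardest wrong caption for image i (row) and hardest wrong
        # image for caption i (column).
        hardest_caption = max(sim[i][j] for j in range(n) if j != i)
        hardest_image = max(sim[j][i] for j in range(n) if j != i)
        total += max(0.0, margin + hardest_caption - positive)
        total += max(0.0, margin + hardest_image - positive)
    return total / n


sim = [
    [0.9, 0.8, 0.1],
    [0.2, 0.7, 0.3],
    [0.0, 0.1, 0.6],
]
loss = max_violation_hinge(sim)  # only two terms violate the 0.2 margin
```

Here only image 0 versus caption 1 (0.8 against a positive of 0.9) and caption 1 versus image 0 (0.8 against 0.7) contribute, which is why the term targets the near-misses that decide Recall@1.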

Architecture

┌─────────────────────────────────────────────────────────────────────────────┐
│                         TRAINING FORWARD PASS                                │
└─────────────────────────────────────────────────────────────────────────────┘

  Global + local crops (multi-crop)          Caption tokens (DistilBERT tokenizer)
           │                                              │
           ▼                                              ▼
  ┌────────────────────┐                       ┌────────────────────┐
  │  Context encoder   │  75% block-masked     │  Language encoder  │
  │ (timm CLIP ViT-B)  │  local/global views   │  (DistilBERT)      │
  │  frozen, 86M params│                       │  frozen, 66M params│
  └─────────┬──────────┘                       └─────────┬──────────┘
            │                                            │
            │         ┌────────────────────┐             │
            └────────►│  Predictor         │             │
                      │  4L transformer    │             │
                      └─────────┬──────────┘             │
                                │                        │
            MSE (masked patches)│                        │
                                ▼                        ▼
  ┌────────────────────┐              ┌──────────────────────────────────┐
  │  Target encoder    │◄── EMA τ ────│  Projection heads (MLP)         │
  │  (frozen teacher)  │   0.996→1.0  │  768 → 512, L2-normalized       │
  │  full-image patches│              │  Global CLS → vision_proj       │
  └────────────────────┘              └───────────────┬──────────────────┘
                                                      │
                         ┌────────────────────────────┼────────────────────────────┐
                         │  SigLIP (FP32)             │                            │
                         │  • Pairwise sigmoid loss   │  Phase A: alignment only   │
                         │  • No softmax dependency   │  Phase B: add JEPA MSE     │
                         │  • Works at batch_size=8   │  Phase C: full training    │
                         └────────────────────────────┴────────────────────────────┘
                                                      │
                         L = α·MSE + β·SigLIP + γ·VarReg

Components

| Component | Role | Details |
| --- | --- | --- |
| Context encoder | Student ViT; encodes masked/cropped images | timm CLIP ViT-B/16 (openai), frozen in phase A |
| Target encoder | EMA copy of context encoder; stop-gradient patch targets | Same weights, τ cosine 0.996→1.0 |
| Predictor | Lightweight transformer; predicts teacher patch embeddings | 4 layers |
| Language encoder | Text encoder | DistilBERT (mean pool) or OpenCLIP text tower (EOT pool with text_pool: eot) |
| Projection heads | mlp: LayerNorm→Linear→GELU→Linear→L2 (random, 10× LR). clip: single linear seeded from CLIP visual.proj/text_projection (base LR) | 768→512 / 512→512, L2-normalized |
| SigLIP loss | Pairwise sigmoid contrastive loss | No softmax, works at small batches |
| Variance reg | VICReg-style std ≥ 1 on pre-norm projections | γ = 0.01 |
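
The target encoder's "τ cosine 0.996→1.0" entry can be illustrated with a short sketch. It assumes the common cosine ramp used in BYOL-style training and a per-step update; the exact schedule form and update granularity used by this repository are not stated above, and the function names are hypothetical.

```python
import math


def ema_momentum(step, total_steps, tau_start=0.996, tau_end=1.0):
    """Cosine ramp of the EMA momentum from tau_start to tau_end."""
    progress = step / max(1, total_steps)
    return tau_end - (tau_end - tau_start) * (math.cos(math.pi * progress) + 1.0) / 2.0


def ema_update(target, online, tau):
    """One EMA step: target <- tau * target + (1 - tau) * online."""
    return [tau * t + (1.0 - tau) * o for t, o in zip(target, online)]


tau_first = ema_momentum(0, 1000)    # starts at tau_start
tau_mid = ema_momentum(500, 1000)    # halfway between the two values
tau_last = ema_momentum(1000, 1000)  # ends at tau_end, freezing the target

# Toy "weights": with tau = 0.5 the target moves halfway to the online net.
updated = ema_update([1.0, 2.0], [3.0, 4.0], tau=0.5)
```

As τ approaches 1.0 the target encoder changes less per step, so late in training the patch targets are nearly stationary.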

Loss

L = α · L_mse + β · L_siglip + γ · L_var + δ · L_hardneg

L_mse:     MSE(predicted_patches, target_patches) on masked positions only
L_siglip:  pairwise sigmoid contrastive (no softmax dependency on batch size)
L_var:     VICReg variance on vision_proj_raw and language_proj_raw
L_hardneg: VSE++ max-violation hinge on the hardest in-batch negative (δ defaults
           to 0; set loss.hard_negative_weight > 0 to sharpen Recall@1)
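
As a rough illustration of the contrastive and variance terms and how they combine, here is a plain-Python sketch. It assumes the standard SigLIP formulation (labels of +1 on the diagonal and -1 elsewhere, with a temperature of 10 and a bias of -10, the initial values suggested in the SigLIP paper) and the VICReg variance hinge with eps = 1e-4. The default weights in total_loss are the phase B values from this README; everything else, including the function names, is an assumption rather than the repository's code.

```python
import math


def log_sigmoid(x):
    """Numerically stable log(sigmoid(x))."""
    if x >= 0:
        return -math.log1p(math.exp(-x))
    return x - math.log1p(math.exp(x))


def siglip_loss(sim, temperature=10.0, bias=-10.0):
    """Pairwise sigmoid loss over every image-caption pair in the batch."""
    n = len(sim)
    total = 0.0
    for i in range(n):
        for j in range(n):
            label = 1.0 if i == j else -1.0
            total -= log_sigmoid(label * (temperature * sim[i][j] + bias))
    return total / n


def variance_penalty(embeddings, target_std=1.0, eps=1e-4):
    """VICReg-style hinge pushing each dimension's std up to target_std."""
    n, dim = len(embeddings), len(embeddings[0])
    penalty = 0.0
    for d in range(dim):
        column = [row[d] for row in embeddings]
        mean = sum(column) / n
        var = sum((v - mean) ** 2 for v in column) / (n - 1)
        penalty += max(0.0, target_std - math.sqrt(var + eps))
    return penalty / dim


def total_loss(l_mse, l_siglip, l_var, l_hardneg,
               alpha=0.1, beta=0.9, gamma=0.01, delta=0.0):
    """Weighted sum L = alpha*L_mse + beta*L_siglip + gamma*L_var + delta*L_hardneg."""
    return alpha * l_mse + beta * l_siglip + gamma * l_var + delta * l_hardneg


aligned = siglip_loss([[1.0, 0.0], [0.0, 1.0]])   # positives on the diagonal
shuffled = siglip_loss([[0.0, 1.0], [1.0, 0.0]])  # positives off the diagonal
collapsed = variance_penalty([[1.0, 1.0], [1.0, 1.0]])  # identical embeddings
spread = variance_penalty([[-2.0, -2.0], [2.0, 2.0]])   # std above the target
```

Because each pair contributes its own sigmoid term, the loss has no softmax normalization across the batch, which is the property the README relies on for small batch sizes.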

Phased Training

| Phase | Epochs | α (MSE) | β (SigLIP) | γ (Var) | What happens |
| --- | --- | --- | --- | --- | --- |
| A | 1-3 | 0.0 | 1.0 | 0.01 | Alignment only, frozen encoders |
| B | 4+ | 0.1 | 0.9 | 0.01 | Add gentle JEPA MSE; unfreeze last 4 vision blocks (epoch 5, encoder_unfreeze_lr=2e-5) |
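
The schedule above amounts to a lookup from epoch to loss weights. A minimal sketch, assuming 1-based epochs as in the table; the PhaseWeights type and the function name are hypothetical:

```python
from typing import NamedTuple


class PhaseWeights(NamedTuple):
    phase: str
    alpha: float  # weight on the JEPA MSE term
    beta: float   # weight on the SigLIP term
    gamma: float  # weight on the variance regularizer


def phase_weights(epoch):
    """Loss weights for a 1-based epoch, following the phase table."""
    if epoch < 1:
        raise ValueError("epochs are 1-based")
    if epoch <= 3:
        return PhaseWeights("A", 0.0, 1.0, 0.01)
    return PhaseWeights("B", 0.1, 0.9, 0.01)


first = phase_weights(1)   # alignment only
fourth = phase_weights(4)  # JEPA MSE switched on
```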

The contrastive head also uses SigLIP label smoothing (loss.label_smoothing, default 0.05), and the train image pipeline adds colour jitter, grayscale, and random erasing to curb overfitting. The training micro-batch is 128 (× 2 gradient-accumulation steps = effective 256). Only a larger real batch increases the number of SigLIP in-batch negatives: accumulation averages independent per-micro-batch sigmoid losses, so each pair is still contrasted only against the other samples in its own micro-batch.
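
The point about accumulation can be checked with simple counting. In a batch of B pairs, the pairwise sigmoid loss sees B·(B−1) negative pairs. The sketch below compares two accumulated micro-batches of 128 with one real batch of 256; the function names are illustrative.

```python
def negatives_with_accumulation(micro_batch, accumulation_steps):
    """Negative pairs seen per optimizer step when micro-batches are independent."""
    return accumulation_steps * micro_batch * (micro_batch - 1)


def negatives_in_real_batch(batch):
    """Negative pairs seen when all samples share one similarity matrix."""
    return batch * (batch - 1)


accumulated = negatives_with_accumulation(128, 2)  # 2 * 128 * 127 = 32512
real = negatives_in_real_batch(256)                # 256 * 255 = 65280
```

The effective batch size is 256 in both cases, but the real batch contrasts each pair against roughly twice as many negatives.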


Results (CLIP backbone, COCO 2017, epoch 16)

Current metrics with the CLIP-pretrained ViT-B/16 backbone and gradient accumulation:

| Metric | Value | Notes |
| --- | --- | --- |
| i2t R@1 | ~24% | Image→text retrieval (vs MAE baseline: 14.7% after 22 epochs) |
| t2i R@1 | ~25% | Text→image retrieval |
| NCE@1 | ~92% | In-batch retrieval accuracy |
| Val Loss | 0.77 | |

The CLIP backbone reaches ~24-25% R@1 by epoch 16, well above the MAE baseline's 14.7% R@1 after 22 epochs — the multimodally-aligned starting features converge faster and higher.
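
For readers unfamiliar with the metrics, here is a minimal sketch of Recall@K over a similarity matrix. It assumes a one-to-one setup where query q matches candidate q; COCO has several captions per image, so the repository's evaluation necessarily handles a many-to-one mapping that this toy version ignores.

```python
def recall_at_k(sim, k):
    """Fraction of queries whose matching candidate ranks in the top k.

    sim[q][c] is the similarity of query q to candidate c, and the
    correct candidate for query q is assumed to be candidate q.
    """
    hits = 0
    for q, row in enumerate(sim):
        ranked = sorted(range(len(row)), key=lambda c: row[c], reverse=True)
        if q in ranked[:k]:
            hits += 1
    return hits / len(sim)


sim = [
    [0.9, 0.1, 0.0],  # correct candidate ranked first
    [0.8, 0.5, 0.1],  # correct candidate ranked second
    [0.2, 0.3, 0.1],  # correct candidate ranked last
]
r1 = recall_at_k(sim, 1)
r2 = recall_at_k(sim, 2)
r3 = recall_at_k(sim, 3)
```

Applying the same function to a single batch's similarity matrix gives an in-batch accuracy in the spirit of NCE@1. With only a batch's worth of candidates instead of the full 5K test pool, this is a much easier task, consistent with the gap between ~92% NCE@1 and ~25% R@1.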

Early SigLIP baseline (RTX 3090, COCO 2017, 1 epoch)

| Metric | Value | Notes |
| --- | --- | --- |
| Train NCE@1 | 78.16% | In-batch retrieval accuracy |
| Val NCE@1 | 86.05% | In-batch retrieval accuracy |
| Val Loss | 1.04 | Started at 6.85 |
| MSE | 0.41 | Non-zero (JEPA masking working) |
| VRAM | 1.64 GB | Fits easily on RTX 3090 |
| Skipped batches | 0 | Zero NaN |
| Training time | 53 min | 1 epoch, batch_size=8 |

Comparison with previous approach (random encoders + InfoNCE):

| Approach | Val NCE@1 | VRAM | Notes |
| --- | --- | --- | --- |
| Random encoders + InfoNCE + 65K queue | ~3.2% | 7.3 GB | Stuck at random baseline for 25+ epochs |
| Pretrained + SigLIP + mean pooling | 86.05% | 1.64 GB | Works in 1 epoch |

Setup

Prerequisites

  • Python 3.11+
  • PyTorch 2.1+ with CUDA
  • GPU with 2GB+ VRAM (tested on RTX 3090)

Installation

git clone https://github.com/arnonbruno/vl-jepa.git
cd vl-jepa
pip install -r requirements.txt

Training

Quick start (pretrained SigLIP baseline)

python experiments/exp_jepa_training.py \
  --config configs/mvp_pretrained_siglip.yaml \
  --epochs 10 \
  --batch-size 8 \
  --fresh

End-to-end OpenCLIP, CLIP-aligned (recommended for breaking the ceiling)

python experiments/exp_jepa_training.py \
  --config configs/openclip_vitb16_aligned.yaml \
  --epochs 50 \
  --fresh

This starts from CLIP's aligned joint space via projection_type: clip (linear projection seeded from visual.proj/text_projection) and text_pool: eot (CLIP end-of-text pooling). Both are required to recover CLIP's pretrained alignment; the CLI flags --projection-type and --text-pool expose them. The older configs/openclip_vitb16.yaml (random MLP projection + mean pooling + hard_negative_weight: 0.2) is kept for reference and still works unchanged.

Pretraining on CC3M before COCO

# 1. Fetch the caption/URL TSV (small) and print the img2dataset command (images)
python experiments/download_cc3m.py download-tsv --dataset cc3m --out data/cc3m
python experiments/download_cc3m.py make-img2dataset --tsv data/cc3m/cc3m.tsv --out data/cc3m/images
# 2. After img2dataset finishes, build the <image>\t<caption> manifest
python experiments/download_cc3m.py build-manifest --images data/cc3m/images --out data/cc3m/train.tsv

src.image_text_dataset.ImageTextPairDataset then trains on the manifest with the same model/trainer code (shares the caption tokenizer + image transforms).
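
To show what a `<image>\t<caption>` manifest looks like in practice, here is a small standard-library sketch that parses one. It assumes one pair per line with no header row and skips malformed or empty rows; read_manifest and the sample paths are hypothetical and unrelated to how ImageTextPairDataset is actually implemented.

```python
import csv
import io


def read_manifest(handle):
    """Parse tab-separated <image path, caption> rows from an open text file."""
    pairs = []
    reader = csv.reader(handle, delimiter="\t", quoting=csv.QUOTE_NONE)
    for row in reader:
        # Skip blank lines, rows with the wrong column count, and
        # rows with a missing path or an empty caption.
        if len(row) != 2 or not row[0] or not row[1].strip():
            continue
        pairs.append((row[0], row[1].strip()))
    return pairs


# In-memory stand-in for a manifest such as data/cc3m/train.tsv.
sample = io.StringIO(
    "images/000001.jpg\ta dog on a beach\n"
    "\n"
    "images/000002.jpg\t\n"
    "images/000003.jpg\ttwo cats\n"
)
pairs = read_manifest(sample)
```

Web-scraped caption sets commonly contain failed downloads and empty captions, so filtering at load time keeps the dataset length consistent with the usable pairs.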

Full training with phased schedule

python experiments/exp_jepa_training.py \
  --config configs/mvp_pretrained_siglip.yaml \
  --epochs 30 \
  --batch-size 8 \
  --gradient-accumulation-steps 8 \
  --phase-training \
  --gradient-checkpointing \
  --unfreeze-after-epoch 5 \
  --fresh

CLI highlights

| Argument | Default | Description |
| --- | --- | --- |
| --config | configs/default.yaml | YAML training config |
| --epochs | 15 | Training epochs |
| --batch-size | 8 | Batch size (per accumulation step) |
| --gradient-accumulation-steps | 8 | Steps to accumulate before optimizer step; effective batch size = batch_size × accumulation_steps |
| --lr | 1e-4 | Peak learning rate |
| --vision-backbone | vit_base_patch16_clip_224.openai | timm vision model (CLIP-pretrained ViT-B/16) |
| --text-backbone | distilbert-base-uncased | HuggingFace text model |
| --contrastive-loss | siglip | Loss type: siglip or infonce |
| --projection-type | mlp | mlp (random MLP head) or clip (linear seeded from CLIP visual.proj/text_projection) |
| --text-pool | mean | mean over valid tokens or eot (CLIP end-of-text token; required to recover CLIP alignment) |
| --phase-training | false | Use phased α/β/γ schedule |
| --gradient-checkpointing | false | Reduce VRAM at cost of speed |
| --unfreeze-after-epoch | 5 | Epoch to unfreeze the last 4 vision blocks (encoder_unfreeze_lr=2e-5) |
| --label-smoothing | 0.05 | SigLIP target smoothing (anti-overfit) |
| --hard-negative-weight | 0.0 | VSE++ hardest-negative ranking weight (δ); sharpens Recall@1 |
| --hard-negative-margin | 0.2 | Margin for the hard-negative hinge |
| --text-backbone openclip | | Use CLIP's aligned text tower + BPE tokenizer |
| --resume | | Resume from checkpoint |
| --fresh | | Start from scratch |

Project structure

vl-jepa/
├── src/
│   ├── model.py              # VL-JEPA, Timm/OpenCLIP encoders, SigLIP + VSE++ losses
│   ├── trainer.py            # Training loop, EMA, AMP, retrieval metrics
│   ├── dataset.py            # COCO 2017 + DistilBERT/CLIP tokenizer (CaptionTokenizer)
│   ├── image_text_dataset.py # Generic CC3M/CC12M manifest dataset
│   ├── cached_dataset.py     # Precomputed frozen-encoder embeddings
│   └── config.py             # YAML loader
├── experiments/
│   ├── exp_jepa_training.py
│   └── download_cc3m.py      # CC3M/CC12M TSV + manifest preparation
├── configs/
│   ├── default.yaml
│   ├── mvp_pretrained_siglip.yaml
│   ├── openclip_vitb16.yaml          # end-to-end OpenCLIP (random MLP proj, mean pool)
│   └── openclip_vitb16_aligned.yaml  # CLIP-aligned: clip projection + EOT pooling
├── tests/
│   ├── test_smoke.py
│   ├── test_dataset.py
│   ├── test_image_text_dataset.py
│   ├── test_cached_dataset.py
│   └── test_alignment_overfit.py
├── INVESTIGATION.md      # Debugging history
├── SOTA_RESEARCH_GPT.md  # SOTA research findings
└── README.md

Hardware

Tested on: NVIDIA RTX 3090 (24 GB), Fedora 43.

VRAM usage:

  • Pretrained frozen (batch 8): ~1.6 GB
  • Pretrained frozen (batch 16): ~3.2 GB
  • With gradient checkpointing: ~1.2 GB
  • Unfrozen encoders (batch 8): ~6-8 GB

License

CC-BY 4.0 (matching the original paper).
