PyTorch implementation of VL-JEPA, a JEPA-style vision-language model with masked patch prediction, EMA target encoders, and contrastive alignment. Optimized for commodity hardware (tested on a single RTX 3090; the configurations here stay under 10 GB of VRAM).
Paper authors: Delong Chen, Mustafa Shukor, Theo Moutakanni, Willy Chung, Jade Yu, Tejaswi Kasarla, Yejin Bang, Allen Bolourchi, Yann LeCun, Pascale Fung
Implementation: Bruno Santos + TARS
VL-JEPA learns joint vision-language representations by:
- Predicting masked patch embeddings (I-JEPA style MSE against an EMA teacher).
- Aligning image and text in a shared projection space (SigLIP sigmoid loss).
- Regularizing embedding variance (VICReg-style) to prevent collapse.
This implementation uses pretrained frozen encoders as a starting point, with phased training that gradually introduces JEPA reconstruction after contrastive alignment is established. Two encoder paths are supported:
- Hybrid (timm CLIP ViT-B/16 + DistilBERT) — the original path. The vision tower is multimodally-aligned, but the DistilBERT text CLS was never trained for sentence retrieval, which caps fine-grained ranking.
- End-to-end OpenCLIP (`configs/openclip_vitb16.yaml`): both towers come from CLIP's 400M-pair pretraining, so vision and text are already aligned. The caption pipeline automatically switches to CLIP's BPE tokenizer (vocab 49408, SOT/EOT) and CLIP image normalization. Feeding DistilBERT ids into the CLIP text tower would silently destroy the alignment.
Three runs — including the OpenCLIP end-to-end one — all plateaued at the same 25–26% R@1 despite R@5≈54% / R@10≈68% and ~92% in-batch NCE@1, with val loss still falling. The fact that a CLIP-initialized model couldn't beat its own zero-shot baseline (CLIP ViT-B/16 alone scores ~30%+ t2i R@1 on COCO 5K) was the giveaway: the pretrained cross-modal alignment was being discarded, so every run relearned alignment from scratch on 118K COCO images and capped at the same place. Two concrete bugs:
- Text was mean-pooled over a causal transformer. CLIP's text representation lives in the EOT token — the only position that attends to the whole sentence under the causal mask. Mean-pooling all token states throws CLIP's pretrained text embedding away.
- CLIP's native projection matrices were never used. The vision tower stopped at `ln_post` (never applying `visual.proj`, 768→512), and the text tower never applied `text_projection` (512→512). Instead, two randomly initialized MLP heads relearned the projection from scratch, discarding the exact matrices CLIP trained on 400M pairs.
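The difference between the two pooling choices fits in a few lines. This is a minimal illustration, not the repository's code. `mean_pool` and `eot_pool` are hypothetical helper names. EOT pooling here assumes the end-of-text id is the largest id in the vocabulary. That holds for CLIP's BPE vocab, where EOT is 49407, so `argmax` over the token ids finds its position; CLIP's own `encode_text` uses the same trick.

```python
import torch


def mean_pool(hidden: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Average token states over valid (non-padding) positions."""
    mask = attention_mask.unsqueeze(-1).to(hidden.dtype)          # (B, T, 1)
    return (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1.0)


def eot_pool(hidden: torch.Tensor, token_ids: torch.Tensor) -> torch.Tensor:
    """Select the hidden state at the end-of-text token.

    Assumes EOT has the largest id in the vocab (true for CLIP BPE, EOT = 49407).
    Under a causal mask this is the only position that attends to the whole caption.
    """
    eot_pos = token_ids.argmax(dim=-1)                               # (B,)
    batch_idx = torch.arange(hidden.shape[0], device=hidden.device)
    return hidden[batch_idx, eot_pos]                                # (B, D)


if __name__ == "__main__":
    SOT, EOT = 49406, 49407
    # Two zero-padded captions; ids other than SOT/EOT are arbitrary placeholders.
    token_ids = torch.tensor([[SOT, 320, 1929, EOT, 0, 0],
                              [SOT, 320, 2368, 530, 518, EOT]])
    attention_mask = (token_ids != 0).long()
    hidden = torch.randn(2, 6, 512)   # stand-in for causal text-transformer outputs

    pooled_eot = eot_pool(hidden, token_ids)
    pooled_mean = mean_pool(hidden, attention_mask)
    print(pooled_eot.shape, pooled_mean.shape)
    # EOT pooling returns exactly the states at the EOT positions (3 and 5).
    assert torch.equal(pooled_eot[0], hidden[0, 3])
    assert torch.equal(pooled_eot[1], hidden[1, 5])
```

Mean pooling averages in the states of early tokens, which under the causal mask have seen only a prefix of the caption. That is why it does not reproduce CLIP's pretrained text embedding.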
The fix (configs/openclip_vitb16_aligned.yaml): model.text_pool: eot + model.projection_type: clip. The joint projection is a single linear layer seeded from CLIP's native visual.proj / text_projection, and text is pooled at the EOT token. Together these make the joint embedding exactly equal CLIP's encode_image / encode_text at initialization (verified in tests/test_smoke.py::test_clip_proj_plus_eot_matches_openclip_encode), so training starts at CLIP zero-shot quality and fine-tunes upward instead of relearning alignment from zero. The trainer auto-drops the 10× projection LR to the base LR for CLIP-seeded projections so the first steps don't blow the pretrained matrices away.
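The sketch below shows why a CLIP-seeded linear projection reproduces CLIP at initialization. In the standard CLIP ViT-B/16 configuration, OpenCLIP stores `visual.proj` and `text_projection` as `(in_features, out_features)` matrices and applies them as `x @ proj`. A bias-free `nn.Linear` whose weight is the transpose therefore gives the same output. The random tensors stand in for the real CLIP matrices. `clip_seeded_linear` and `projection_param_groups` are hypothetical helpers, not this repository's API.

```python
import torch
import torch.nn as nn


def clip_seeded_linear(proj: torch.Tensor) -> nn.Linear:
    """Build a bias-free Linear layer equivalent to `x @ proj`.

    nn.Linear computes x @ weight.T, so weight must be proj.T.
    """
    in_features, out_features = proj.shape
    layer = nn.Linear(in_features, out_features, bias=False)
    with torch.no_grad():
        layer.weight.copy_(proj.t())
    return layer


def projection_param_groups(proj_head: nn.Module, base_lr: float, seeded: bool):
    # Hypothetical helper: random MLP heads get 10x LR, CLIP-seeded heads keep
    # the base LR so early updates do not wreck the pretrained matrices.
    lr = base_lr if seeded else 10 * base_lr
    return [{"params": proj_head.parameters(), "lr": lr}]


if __name__ == "__main__":
    torch.manual_seed(0)
    visual_proj = torch.randn(768, 512)       # stand-in for CLIP's visual.proj
    vision_proj = clip_seeded_linear(visual_proj)

    pooled = torch.randn(4, 768)              # stand-in for ln_post(CLS) features
    reference = pooled @ visual_proj          # what encode_image computes after ln_post
    assert torch.allclose(vision_proj(pooled), reference, atol=1e-3)

    optimizer = torch.optim.AdamW(projection_param_groups(vision_proj, 1e-4, seeded=True))
    print(optimizer.param_groups[0]["lr"])
```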
Earlier/secondary levers (still available):
- OpenCLIP end-to-end — replaces the weak DistilBERT CLS with CLIP's aligned text tower. Enabled by the CLIP-tokenizer switch.
- Hard-negative mining (`loss.hard_negative_weight`): a VSE++ max-violation hinge on the hardest in-batch negative. It is disabled (0.0) in the aligned config to isolate the alignment fix; re-enable it (e.g. 0.05) once the aligned baseline is established.
- More data (`src/image_text_dataset.py`, `experiments/download_cc3m.py`): pretrain on CC3M/CC12M before COCO (COCO's 118K images are seen ~50× per run).
- Higher resolution: set `model.image_size: 384`; the OpenCLIP vision tower interpolates its positional embeddings automatically.
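The hard-negative term can be sketched as a VSE++-style max-violation hinge. For each anchor, only the hardest in-batch negative in each retrieval direction contributes. The sketch assumes L2-normalized embeddings with matched pairs on the diagonal. The function name and the batch averaging are illustrative choices, not necessarily what `src/model.py` does.

```python
import torch
import torch.nn.functional as F


def max_violation_hinge(img: torch.Tensor, txt: torch.Tensor, margin: float = 0.2) -> torch.Tensor:
    """Hinge on the hardest in-batch negative, in both retrieval directions.

    img, txt: (B, D) L2-normalized embeddings; row i of each is a positive pair.
    """
    scores = img @ txt.t()                     # (B, B) cosine similarities
    pos = scores.diag()                        # (B,) positive-pair scores
    eye = torch.eye(scores.size(0), dtype=torch.bool, device=scores.device)

    # Image i as anchor: captions j scoring within `margin` of the positive are violations.
    cost_i2t = (margin + scores - pos.unsqueeze(1)).clamp(min=0).masked_fill(eye, 0)
    # Caption j as anchor: images i scoring within `margin` (column-wise) are violations.
    cost_t2i = (margin + scores - pos.unsqueeze(0)).clamp(min=0).masked_fill(eye, 0)

    # Keep only the hardest negative per anchor, then average over the batch.
    return cost_i2t.max(dim=1).values.mean() + cost_t2i.max(dim=0).values.mean()


if __name__ == "__main__":
    torch.manual_seed(0)
    img = F.normalize(torch.randn(8, 512), dim=-1)
    txt = F.normalize(img + 0.1 * torch.randn(8, 512), dim=-1)
    hard_negative_weight = 0.05   # the re-enable value suggested above (δ)
    print(hard_negative_weight * max_violation_hinge(img, txt))
```

The max over negatives focuses the gradient on the single most confusable caption or image. This is why the term targets Recall@1 specifically rather than the overall ranking.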

```
┌─────────────────────────────────────────────────────────────────────────────┐
│                            TRAINING FORWARD PASS                            │
└─────────────────────────────────────────────────────────────────────────────┘

Global + local crops (multi-crop)         Caption tokens (DistilBERT tokenizer)
          │                                                │
          ▼                                                ▼
┌────────────────────┐                           ┌────────────────────┐
│  Context encoder   │  75% block-masked         │  Language encoder  │
│ (timm CLIP ViT-B)  │  local/global views       │    (DistilBERT)    │
│ frozen, 86M params │                           │ frozen, 66M params │
└─────────┬──────────┘                           └─────────┬──────────┘
          │                                                │
          │           ┌────────────────────┐               │
          └──────────►│     Predictor      │               │
                      │   4L transformer   │               │
                      └─────────┬──────────┘               │
                                │                          │
            MSE (masked patches)│                          │
                                ▼                          ▼
                      ┌────────────────────┐  ┌──────────────────────────────┐
                      │  Target encoder    │  │  Projection heads (MLP)      │
                      │  (frozen teacher)  │  │  768 → 512, L2-normalized    │
                      │ full-image patches │  │  Global CLS → vision_proj    │
                      │  EMA τ 0.996→1.0   │  │                              │
                      └────────────────────┘  └────────────┬─────────────────┘
                                                           │
                                                           ▼
                    ┌────────────────────────────┬───────────────────────────┐
                    │ SigLIP (FP32)              │                           │
                    │ • Pairwise sigmoid loss    │ Phase A: alignment only   │
                    │ • No softmax dependency    │ Phase B: add JEPA MSE     │
                    │ • Works at batch_size=8    │ Phase C: full training    │
                    └────────────────────────────┼───────────────────────────┘
                                                 │
                                  L = α·MSE + β·SigLIP + γ·VarReg
```

| Component | Role | Details |
|---|---|---|
| Context encoder | Student ViT; encodes masked/cropped images | timm CLIP ViT-B/16 (openai), frozen in phase A |
| Target encoder | EMA copy of context encoder; stop-gradient patch targets | Same weights, τ cosine 0.996→1.0 |
| Predictor | Lightweight transformer; predicts teacher patch embeddings | 4 layers |
| Language encoder | Text encoder | DistilBERT (mean pool) or OpenCLIP text tower (EOT pool with text_pool: eot) |
| Projection heads | `mlp`: LayerNorm→Linear→GELU→Linear→L2 (random init, 10× LR). `clip`: single linear layer seeded from CLIP `visual.proj`/`text_projection` (base LR) | 768→512 / 512→512, L2-normalized |
| SigLIP loss | Pairwise sigmoid contrastive loss | No softmax, works at small batches |
| Variance reg | VICReg-style std ≥ 1 on pre-norm projections | γ = 0.01 |
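The target-encoder row can be sketched as an exponential moving average whose momentum τ rises from 0.996 to 1.0 on a cosine schedule. The sketch assumes a per-step cosine ramp of the BYOL/I-JEPA kind. The schedule shape and the helper names `ema_momentum` / `ema_update` are assumptions, not taken from `src/trainer.py`.

```python
import copy
import math

import torch
import torch.nn as nn


def ema_momentum(step: int, total_steps: int, tau_base: float = 0.996, tau_final: float = 1.0) -> float:
    """Cosine ramp of the EMA momentum from tau_base (step 0) to tau_final (last step)."""
    progress = min(step / max(total_steps, 1), 1.0)
    return tau_final - (tau_final - tau_base) * (math.cos(math.pi * progress) + 1) / 2


@torch.no_grad()
def ema_update(target: nn.Module, online: nn.Module, tau: float) -> None:
    """target <- tau * target + (1 - tau) * online; the teacher never gets gradients."""
    for t_param, o_param in zip(target.parameters(), online.parameters()):
        t_param.mul_(tau).add_(o_param.detach(), alpha=1 - tau)


if __name__ == "__main__":
    context_encoder = nn.Linear(16, 16)           # stand-in for the student ViT
    target_encoder = copy.deepcopy(context_encoder)
    target_encoder.requires_grad_(False)          # stop-gradient teacher

    total_steps = 1000
    for step in range(total_steps):
        # ... loss.backward() / optimizer.step() on context_encoder would go here ...
        ema_update(target_encoder, context_encoder, ema_momentum(step, total_steps))

    print(ema_momentum(0, total_steps), ema_momentum(total_steps, total_steps))
```

Ramping τ toward 1.0 makes the teacher change more and more slowly. The MSE targets therefore stabilize late in training instead of chasing the student.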

```
L = α · L_mse + β · L_siglip + γ · L_var + δ · L_hardneg

L_mse:     MSE(predicted_patches, target_patches) on masked positions only
L_siglip:  pairwise sigmoid contrastive (no softmax dependency on batch size)
L_var:     VICReg variance on vision_proj_raw and language_proj_raw
L_hardneg: VSE++ max-violation hinge on the hardest in-batch negative
           (δ defaults to 0; set loss.hard_negative_weight > 0 to sharpen Recall@1)
```

| Phase | Epochs | α (MSE) | β (SigLIP) | γ (Var) | What happens |
|---|---|---|---|---|---|
| A | 1-3 | 0.0 | 1.0 | 0.01 | Alignment only, frozen encoders |
| B | 4+ | 0.1 | 0.9 | 0.01 | Add gentle JEPA MSE; unfreeze last 4 vision blocks (epoch 5, encoder_unfreeze_lr=2e-5) |
The contrastive head also uses SigLIP label smoothing (loss.label_smoothing, default
0.05) and the train image pipeline adds colour jitter / grayscale / random erasing to
curb overfitting. The training micro-batch is 128 (× 2 grad-accum = effective 256); a
larger real batch is what increases SigLIP in-batch negatives, since accumulation
averages independent per-micro-batch sigmoid losses.
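Two points from the paragraph above are worth making concrete: the pairwise sigmoid loss itself, and why gradient accumulation does not add negatives. The sketch follows the SigLIP formulation: a learnable temperature and bias, and every off-diagonal pair treated as a negative. It uses the SigLIP paper's initialization (t = 10, b = -10). Applying label smoothing by pulling the {0, 1} targets toward 0.5 is an assumption about how `loss.label_smoothing` works, not something confirmed by the source.

```python
import math

import torch
import torch.nn.functional as F


def siglip_loss(img, txt, log_t, b, label_smoothing: float = 0.0) -> torch.Tensor:
    """Pairwise sigmoid loss over all B x B pairs; img/txt are (B, D) L2-normalized."""
    logits = img @ txt.t() * log_t.exp() + b                    # (B, B)
    targets = torch.eye(logits.size(0), device=logits.device)   # 1 = match, 0 = non-match
    # Assumed smoothing scheme: shrink hard targets toward 0.5.
    targets = targets * (1 - label_smoothing) + 0.5 * label_smoothing
    # Sum over all pairs, normalize by batch size (as in SigLIP).
    return F.binary_cross_entropy_with_logits(logits, targets, reduction="sum") / logits.size(0)


def negatives_per_step(micro_batch: int, accum_steps: int) -> int:
    """Negative pairs per optimizer step: each micro-batch only contrasts within itself."""
    return accum_steps * micro_batch * (micro_batch - 1)


if __name__ == "__main__":
    torch.manual_seed(0)
    img = F.normalize(torch.randn(8, 512), dim=-1)
    txt = F.normalize(torch.randn(8, 512), dim=-1)
    log_t = torch.tensor(math.log(10.0), requires_grad=True)
    b = torch.tensor(-10.0, requires_grad=True)
    siglip_loss(img, txt, log_t, b, label_smoothing=0.05).backward()

    print(negatives_per_step(128, 2))   # 32512: micro-batch 128 x 2 accumulation steps
    print(negatives_per_step(256, 1))   # 65280: one real batch of 256
```

The same effective batch of 256 yields about half as many negative pairs when it is split into two accumulated micro-batches.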
Current metrics with the CLIP-pretrained ViT-B/16 backbone and gradient accumulation:
| Metric | Value | Notes |
|---|---|---|
| i2t R@1 | ~24% | Image→text retrieval (vs MAE baseline: 14.7% after 22 epochs) |
| t2i R@1 | ~25% | Text→image retrieval |
| NCE@1 | ~92% | In-batch retrieval accuracy |
| Val Loss | 0.77 | — |
The CLIP backbone reaches ~24-25% R@1 by epoch 16, well above the MAE baseline's 14.7% R@1 after 22 epochs — the multimodally-aligned starting features converge faster and higher.
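Retrieval R@K can be sketched as follows for a 1:1 image-caption pairing. This is a simplification. COCO 5K has about five captions per image, so the real evaluation must count any of an image's captions as a hit. `recall_at_k` is a hypothetical helper, not the trainer's metric code.

```python
import torch
import torch.nn.functional as F


def recall_at_k(query: torch.Tensor, gallery: torch.Tensor, k: int) -> float:
    """Fraction of queries whose true match (same row index) is among the top-k.

    query, gallery: (N, D) L2-normalized embeddings; row i of each is a pair.
    """
    sims = query @ gallery.t()                                   # (N, N)
    topk = sims.topk(k, dim=1).indices                           # (N, k) best gallery ids
    targets = torch.arange(query.size(0), device=query.device).unsqueeze(1)
    return (topk == targets).any(dim=1).float().mean().item()


if __name__ == "__main__":
    torch.manual_seed(0)
    img = F.normalize(torch.randn(100, 512), dim=-1)
    txt = F.normalize(img + 0.5 * torch.randn(100, 512), dim=-1)
    for k in (1, 5, 10):
        i2t = recall_at_k(img, txt, k)   # image queries, caption gallery
        t2i = recall_at_k(txt, img, k)   # caption queries, image gallery
        print(f"R@{k}: i2t={i2t:.2%} t2i={t2i:.2%}")
```

Applied to a single training batch instead of the full test set, the same R@1 computation gives an in-batch accuracy like NCE@1. With far fewer distractors per query, this is one reason ~92% NCE@1 can coexist with ~25% test-set R@1.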
Single-epoch results with pretrained frozen encoders, SigLIP, and mean pooling (batch_size=8):

| Metric | Value | Notes |
|---|---|---|
| Train NCE@1 | 78.16% | In-batch retrieval accuracy |
| Val NCE@1 | 86.05% | In-batch retrieval accuracy |
| Val Loss | 1.04 | Started at 6.85 |
| MSE | 0.41 | Non-zero (JEPA masking working) |
| VRAM | 1.64 GB | Fits easily on RTX 3090 |
| Skipped batches | 0 | Zero NaN |
| Training time | 53 min | 1 epoch, batch_size=8 |
Comparison with previous approach (random encoders + InfoNCE):
| Approach | Val NCE@1 | VRAM | Notes |
|---|---|---|---|
| Random encoders + InfoNCE + 65K queue | ~3.2% | 7.3 GB | Stuck at random baseline for 25+ epochs |
| Pretrained + SigLIP + mean pooling | 86.05% | 1.64 GB | Works in 1 epoch |
- Python 3.11+
- PyTorch 2.1+ with CUDA
- GPU with 2GB+ VRAM (tested on RTX 3090)

```bash
git clone https://github.com/arnonbruno/vl-jepa.git
cd vl-jepa
pip install -r requirements.txt
```

```bash
python experiments/exp_jepa_training.py \
  --config configs/mvp_pretrained_siglip.yaml \
  --epochs 10 \
  --batch-size 8 \
  --fresh
```

```bash
python experiments/exp_jepa_training.py \
  --config configs/openclip_vitb16_aligned.yaml \
  --epochs 50 \
  --fresh
```

This starts from CLIP's aligned joint space via `projection_type: clip` (a linear projection seeded from `visual.proj`/`text_projection`) and `text_pool: eot` (CLIP end-of-text pooling). Both are required to recover CLIP's pretrained alignment; the CLI flags `--projection-type` and `--text-pool` expose them. The older `configs/openclip_vitb16.yaml` (random MLP projection + mean pooling + `hard_negative_weight: 0.2`) is kept for reference and still works unchanged.

```bash
# 1. Fetch the caption/URL TSV (small) and print the img2dataset command (images)
python experiments/download_cc3m.py download-tsv --dataset cc3m --out data/cc3m
python experiments/download_cc3m.py make-img2dataset --tsv data/cc3m/cc3m.tsv --out data/cc3m/images
# 2. After img2dataset finishes, build the <image>\t<caption> manifest
python experiments/download_cc3m.py build-manifest --images data/cc3m/images --out data/cc3m/train.tsv
```

`src.image_text_dataset.ImageTextPairDataset` then trains on the manifest with the same model/trainer code (it shares the caption tokenizer and image transforms).

```bash
python experiments/exp_jepa_training.py \
  --config configs/mvp_pretrained_siglip.yaml \
  --epochs 30 \
  --batch-size 8 \
  --gradient-accumulation-steps 8 \
  --phase-training \
  --gradient-checkpointing \
  --unfreeze-after-epoch 5 \
  --fresh
```

| Argument | Default | Description |
|---|---|---|
| `--config` | `configs/default.yaml` | YAML training config |
| `--epochs` | 15 | Training epochs |
| `--batch-size` | 8 | Batch size (per accumulation step) |
| `--gradient-accumulation-steps` | 8 | Micro-batches to accumulate per optimizer step; effective batch size = batch_size × accumulation_steps |
| `--lr` | 1e-4 | Peak learning rate |
| `--vision-backbone` | `vit_base_patch16_clip_224.openai` | timm vision model (CLIP-pretrained ViT-B/16) |
| `--text-backbone` | `distilbert-base-uncased` | HuggingFace text model |
| `--contrastive-loss` | `siglip` | Loss type: `siglip` or `infonce` |
| `--projection-type` | `mlp` | `mlp` (random MLP head) or `clip` (linear seeded from CLIP `visual.proj`/`text_projection`) |
| `--text-pool` | `mean` | `mean` over valid tokens or `eot` (CLIP end-of-text token; required to recover CLIP alignment) |
| `--phase-training` | false | Use the phased α/β/γ schedule |
| `--gradient-checkpointing` | false | Reduce VRAM at the cost of speed |
| `--unfreeze-after-epoch` | 5 | Epoch at which to unfreeze the last 4 vision blocks (encoder_unfreeze_lr=2e-5) |
| `--label-smoothing` | 0.05 | SigLIP target smoothing (anti-overfitting) |
| `--hard-negative-weight` | 0.0 | VSE++ hardest-negative ranking weight (δ); sharpens Recall@1 |
| `--hard-negative-margin` | 0.2 | Margin for the hard-negative hinge |
| `--text-backbone openclip` | — | Use CLIP's aligned text tower + BPE tokenizer |
| `--resume` | — | Resume from checkpoint |
| `--fresh` | — | Start from scratch |
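The `--gradient-accumulation-steps` row can be illustrated with a generic accumulation loop. In this sketch a toy `nn.Linear` model and an MSE loss stand in for VL-JEPA. `train_with_accumulation` is a hypothetical helper, not the trainer's loop, and it omits AMP and gradient clipping.

```python
import torch
import torch.nn as nn


def train_with_accumulation(model: nn.Module, batches, optimizer: torch.optim.Optimizer,
                            accum_steps: int) -> int:
    """Step the optimizer once every `accum_steps` micro-batches.

    Returns the number of optimizer steps taken. A trailing partial group of
    micro-batches is left unapplied in this sketch.
    """
    model.train()
    optimizer.zero_grad(set_to_none=True)
    steps = 0
    for i, (x, y) in enumerate(batches, start=1):
        loss = nn.functional.mse_loss(model(x), y)
        # Dividing makes the summed gradients equal the mean over the effective batch.
        (loss / accum_steps).backward()
        if i % accum_steps == 0:
            optimizer.step()
            optimizer.zero_grad(set_to_none=True)
            steps += 1
    return steps


if __name__ == "__main__":
    torch.manual_seed(0)
    batch_size, accum_steps = 8, 8            # the CLI defaults in the table above
    model = nn.Linear(16, 4)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    batches = [(torch.randn(batch_size, 16), torch.randn(batch_size, 4)) for _ in range(64)]
    steps = train_with_accumulation(model, batches, optimizer, accum_steps)
    print(f"effective batch = {batch_size * accum_steps}, optimizer steps = {steps}")
    # prints: effective batch = 64, optimizer steps = 8
```

Accumulation matches the gradient of a larger batch only for losses that decompose over samples. SigLIP's in-batch negatives do not decompose this way, which is why the micro-batch size itself matters.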
```
vl-jepa/
├── src/
│   ├── model.py               # VL-JEPA, Timm/OpenCLIP encoders, SigLIP + VSE++ losses
│   ├── trainer.py             # Training loop, EMA, AMP, retrieval metrics
│   ├── dataset.py             # COCO 2017 + DistilBERT/CLIP tokenizer (CaptionTokenizer)
│   ├── image_text_dataset.py  # Generic CC3M/CC12M manifest dataset
│   ├── cached_dataset.py      # Precomputed frozen-encoder embeddings
│   └── config.py              # YAML loader
├── experiments/
│   ├── exp_jepa_training.py
│   └── download_cc3m.py       # CC3M/CC12M TSV + manifest preparation
├── configs/
│   ├── default.yaml
│   ├── mvp_pretrained_siglip.yaml
│   ├── openclip_vitb16.yaml          # end-to-end OpenCLIP (random MLP proj, mean pool)
│   └── openclip_vitb16_aligned.yaml  # CLIP-aligned: clip projection + EOT pooling
├── tests/
│   ├── test_smoke.py
│   ├── test_dataset.py
│   ├── test_image_text_dataset.py
│   ├── test_cached_dataset.py
│   └── test_alignment_overfit.py
├── INVESTIGATION.md           # Debugging history
├── SOTA_RESEARCH_GPT.md       # SOTA research findings
└── README.md
```
Tested on: NVIDIA RTX 3090 (24 GB), Fedora 43.
VRAM usage:
- Pretrained frozen (batch 8): ~1.6 GB
- Pretrained frozen (batch 16): ~3.2 GB
- With gradient checkpointing: ~1.2 GB
- Unfrozen encoders (batch 8): ~6-8 GB
CC-BY 4.0 (matching the original paper).