Skip to content

feat: impleneted ARAF model and instructions#32

Open
ronanhansel wants to merge 2 commits into
aims-foundations:mainfrom
ronanhansel:feat/araf-model
Open

feat: impleneted ARAF model and instructions#32
ronanhansel wants to merge 2 commits into
aims-foundations:mainfrom
ronanhansel:feat/araf-model

Conversation

@ronanhansel
Copy link
Copy Markdown

@ronanhansel ronanhansel commented May 18, 2026

Summary

Adds the ARAF (Automatic Relevance Amortised Factors model) model, ported and consolidated from agent-eval/model/amortized_irt.py. ARAF is a multidimensional factor model designed for the small-N, large-J regime of agent evaluation: instead of fitting free per-item parameters (unstable when only ~30 agents have responded), item loadings are amortized through a learned projection of pre-computed item embeddings, and an Automatic Relevance Determination (ARD) gate prunes inactive latent dimensions automatically.

The model is

a_j  = ReLU(τ) ⊙ (x_j @ W_normᵀ)            # amortized loadings
b_j  = Linear(x_j)                            # amortized difficulty
μ_ij = σ(θ_iᵀ a_j + b_j + θ_bias_i + g)      # response probability

with two transfer directions supported:

  • Zero-shot to new items by reusing W, τ, b, Θ and supplying new item embeddings.
  • Adaptation to new agents via ARAF.adapt(), which freezes the item-side parameters and refits only θ, θ_bias.

What's in the PR

  • src/torch_measure/models/amortized.py
    • New ARAF class extending Predictor, with dense_predict, set_embeddings, loadings, difficulty, tau, active_dims, plus fit() and adapt() that route through the dedicated training loop.
  • src/torch_measure/fitting/araf.py
    • New araf_fit and araf_adapt training loops.
    • Two optimizer groups: AdamW for {θ, θ_bias, W, global_bias, difficulty_proj}, SGD for tau_raw.
    • ARD schedule: λ_τ = 0 for tau_warmup epochs, linear ramp over ramp_epochs, full λ thereafter.
    • τ-snapping at 1e-3 to a frozen dead-zone value, applied every 10 epochs after warmup.
    • Bernoulli (binary) and Beta (continuous in (0, 1)) likelihoods.
  • src/torch_measure/models/__init__.py, src/torch_measure/fitting/__init__.py
    • Re-export ARAF, araf_fit, araf_adapt.
  • tests/test_models/test_araf.py
    • Construction / shape tests, dense-predict range checks, end-to-end fit smoke tests with synthetic low-rank ground truth, ARD-rank-recovery test, zero-shot transfer test, and adapt() test.
  • tutorials/araf.ipynb
    • Two-part tutorial: (1) synthetic ground-truth recovery with K* = 3, where ARD prunes K = 12 initialization to ~3 active dims and canonical correlations on the recovered loading subspace are ≈ 1; (2) end-to-end on AfriMedQA (30 LLMs × 2,000 MCQ items) with sentence-transformers embeddings, in-sample comparison vs TwoPL baseline, zero-shot generalization to held-out items, and adapt() on held-out agents.

Tutorial tuning

The first pass of the tutorial showed ARD as 10/10 active, all bars at τ ≈ 0.57 — visually indistinguishable from "ARD is broken". Root cause: lambda_tau=0.0 plus uniform tau_raw init at 0.5. The current PR ships with:

  • latent_dim = 20 (overspecified so ARD has dims to prune)
  • non-uniform tau_raw init via Uniform(-0.2, 0.8) (uniform init makes every dim see identical gradients and drives the gate to "all-on or all-off")
  • lambda_tau = 0.013 (a sweep across {0.005 … 0.05} showed this is the graded-pruning sweet spot for AfriMedQA)

After-tuning result: 10/20 active with a clean staircase from τ ≈ 0.28 down to ≈ 0.02, train Brier 0.193, zero-shot test Brier 0.211 (matching the unregularized fit). The Summary section in the notebook now includes a "Practical note on ARD tuning" subsection capturing this.

Testing

  • pytest tests/test_models/test_araf.py — all green locally.
  • nbconvert --execute tutorials/araf.ipynb — runs end-to-end, ~1 min on CPU after the embeddings cache.
  • Synthetic recovery: canonical correlations ≈ 1 on each of the K* = 3 true factors, Procrustes-aligned r ≈ 0.99, ARD active dims ≈ K*.

Notes for reviewers

  • The training loop is a faithful port of agent-eval/model/amortized_irt.py:train_amortized_irt; default constants (TAU_WARMUP=100, RAMP_EPOCHS=400, LR_TAU=0.05, WD_THETA=5.0, etc.) are preserved at module top of fitting/araf.py so they're easy to spot for anyone cross-referencing the upstream.
  • ARAF inherits from Predictor, not IRTModel, because the natural unit is the full dense matrix rather than per-cell IRT probabilities. predict(query) is implemented for compatibility with the Predictor contract by gathering from dense_predict.
  • The Beta likelihood path is exercised by tests but not by the tutorial.

References

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant