feat: impleneted ARAF model and instructions#32
Open
ronanhansel wants to merge 2 commits into
Open
Conversation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Adds the ARAF (Automatic Relevance Amortised Factors model) model, ported and consolidated from
agent-eval/model/amortized_irt.py. ARAF is a multidimensional factor model designed for the small-N, large-Jregime of agent evaluation: instead of fitting free per-item parameters (unstable when only ~30 agents have responded), item loadings are amortized through a learned projection of pre-computed item embeddings, and an Automatic Relevance Determination (ARD) gate prunes inactive latent dimensions automatically.The model is
with two transfer directions supported:
W, τ, b, Θand supplying new item embeddings.ARAF.adapt(), which freezes the item-side parameters and refits onlyθ, θ_bias.What's in the PR
src/torch_measure/models/amortized.pyARAFclass extendingPredictor, withdense_predict,set_embeddings,loadings,difficulty,tau,active_dims, plusfit()andadapt()that route through the dedicated training loop.src/torch_measure/fitting/araf.pyaraf_fitandaraf_adapttraining loops.{θ, θ_bias, W, global_bias, difficulty_proj}, SGD fortau_raw.λ_τ = 0fortau_warmupepochs, linear ramp overramp_epochs, full λ thereafter.1e-3to a frozen dead-zone value, applied every 10 epochs after warmup.src/torch_measure/models/__init__.py,src/torch_measure/fitting/__init__.pyARAF,araf_fit,araf_adapt.tests/test_models/test_araf.pyadapt()test.tutorials/araf.ipynbK* = 3, where ARD prunesK = 12initialization to ~3 active dims and canonical correlations on the recovered loading subspace are ≈ 1; (2) end-to-end on AfriMedQA (30 LLMs × 2,000 MCQ items) with sentence-transformers embeddings, in-sample comparison vsTwoPLbaseline, zero-shot generalization to held-out items, andadapt()on held-out agents.Tutorial tuning
The first pass of the tutorial showed ARD as 10/10 active, all bars at τ ≈ 0.57 — visually indistinguishable from "ARD is broken". Root cause:
lambda_tau=0.0plus uniformtau_rawinit at 0.5. The current PR ships with:latent_dim = 20(overspecified so ARD has dims to prune)tau_rawinit viaUniform(-0.2, 0.8)(uniform init makes every dim see identical gradients and drives the gate to "all-on or all-off")lambda_tau = 0.013(a sweep across {0.005 … 0.05} showed this is the graded-pruning sweet spot for AfriMedQA)After-tuning result: 10/20 active with a clean staircase from τ ≈ 0.28 down to ≈ 0.02, train Brier 0.193, zero-shot test Brier 0.211 (matching the unregularized fit). The Summary section in the notebook now includes a "Practical note on ARD tuning" subsection capturing this.
Testing
pytest tests/test_models/test_araf.py— all green locally.nbconvert --execute tutorials/araf.ipynb— runs end-to-end, ~1 min on CPU after the embeddings cache.K* = 3true factors, Procrustes-aligned r ≈ 0.99, ARD active dims ≈ K*.Notes for reviewers
agent-eval/model/amortized_irt.py:train_amortized_irt; default constants (TAU_WARMUP=100,RAMP_EPOCHS=400,LR_TAU=0.05,WD_THETA=5.0, etc.) are preserved at module top offitting/araf.pyso they're easy to spot for anyone cross-referencing the upstream.Predictor, notIRTModel, because the natural unit is the full dense matrix rather than per-cell IRT probabilities.predict(query)is implemented for compatibility with thePredictorcontract by gathering fromdense_predict.References