feat(utils): add MCDropoutWrapper for uncertainty-aware active learning #801

Open
0xSoftBoi wants to merge 6 commits into materialyzeai:main from 0xSoftBoi:feat/mc-dropout-uncertainty

0xSoftBoi commented Jun 8, 2026

Closes #800.

Summary

Adds matgl.utils.MCDropoutWrapper, a lightweight wrapper that enables Monte Carlo Dropout uncertainty estimation on pretrained MatGL models without retraining. The one unsupported configuration is described under "M3GNet scope" in the Design section.

from matgl.models._chgnet import CHGNet
from matgl.utils.uncertainty import MCDropoutWrapper

model = CHGNet.load()
wrapper = MCDropoutWrapper(model, dropout_p=0.1)

mean, std = wrapper.predict_uncertainty(structures, n_passes=20)
# mean, std: tensors of shape (N,) — usable in acquisition functions
# e.g. UCB score = mean - lambda * std

Design

Backbone stays deterministic. Only the readout layers (final_dropout, final_layer) are switched back to train() during inference. The convolutional backbone stays in eval(), so BatchNorm statistics are not perturbed.
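
As a rough illustration of the idea (not the wrapper's actual code), the selective mode switch can be sketched in plain PyTorch. `TinyModel` is a hypothetical stand-in for a MatGL model; only the attribute names `final_dropout` and `final_layer` are taken from this PR.

```python
import torch
from torch import nn


class TinyModel(nn.Module):
    """Hypothetical stand-in: a BatchNorm backbone followed by a dropout readout."""

    def __init__(self) -> None:
        super().__init__()
        self.backbone = nn.Sequential(nn.Linear(4, 8), nn.BatchNorm1d(8), nn.ReLU())
        self.final_dropout = nn.Dropout(p=0.1)
        self.final_layer = nn.Linear(8, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.final_layer(self.final_dropout(self.backbone(x))).squeeze(-1)


model = TinyModel()
model.eval()                 # whole model deterministic; BatchNorm uses running stats
model.final_dropout.train()  # re-enable only the readout modules
model.final_layer.train()

x = torch.randn(5, 4)
running_mean_before = model.backbone[1].running_mean.clone()
with torch.no_grad():
    # Each pass samples a fresh dropout mask in the readout only.
    passes = torch.stack([model(x) for _ in range(20)])  # shape (20, 5)
running_mean_after = model.backbone[1].running_mean.clone()
```

Because the BatchNorm layer never leaves eval mode, its running statistics are identical before and after the stochastic passes.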

Works on default pretrained models. CHGNet's default final_dropout=0 produces an nn.Identity placeholder. The wrapper replaces it with a real nn.Dropout at construction time, so no retraining is needed.
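
A minimal sketch of what such a placeholder swap could look like, assuming the dropout site is a direct attribute of the module. `enable_dropout` and `Readout` are hypothetical names for illustration and are not the `_inject_dropout` helper from this PR.

```python
from torch import nn


def enable_dropout(module: nn.Module, attr: str, p: float) -> None:
    """Hypothetical helper: make sure `module.<attr>` is an nn.Dropout with rate p."""
    current = getattr(module, attr)
    if isinstance(current, nn.Dropout):
        current.p = p  # existing dropout: update the rate in place
    elif isinstance(current, nn.Identity):
        setattr(module, attr, nn.Dropout(p))  # placeholder: swap in real dropout
    else:
        raise ValueError(f"{attr} is {type(current).__name__}; no dropout site to inject")


class Readout(nn.Module):
    """Hypothetical readout mirroring the final_dropout=0 -> nn.Identity convention."""

    def __init__(self, final_dropout: float = 0.0) -> None:
        super().__init__()
        self.final_dropout = nn.Dropout(final_dropout) if final_dropout else nn.Identity()
        self.final_layer = nn.Linear(8, 1)


readout = Readout()  # default: final_dropout is nn.Identity
weights_before = readout.final_layer.weight.detach().clone()
enable_dropout(readout, "final_dropout", p=0.1)
```

Dropout has no learnable parameters, so the swap leaves the pretrained weights untouched, which is why no retraining is needed.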

Batched forward passes. All N structures are packed into a single PyG Batch per pass via Batch.from_data_list, so inference costs O(n_passes) forward calls rather than O(N × n_passes). Graph conversion still happens exactly once per structure.
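
The cost argument can be illustrated with a plain-PyTorch sketch in which an ordinary tensor stands in for the PyG Batch. `mc_mean_std` and `toy_forward` are hypothetical names; the `n_passes >= 2` check mirrors the validation added in a later commit of this PR.

```python
from collections.abc import Callable

import torch
from torch import nn


def mc_mean_std(
    forward: Callable[[torch.Tensor], torch.Tensor],
    batch: torch.Tensor,
    n_passes: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Run n_passes stochastic forward calls over one batch and aggregate them."""
    if n_passes < 2:
        raise ValueError("n_passes must be >= 2; the std of a single sample is undefined")
    with torch.no_grad():
        # One forward call per pass covers all N items at once.
        samples = torch.stack([forward(batch) for _ in range(n_passes)])  # (n_passes, N)
    return samples.mean(dim=0), samples.std(dim=0)


call_log: list[int] = []  # records one entry per forward call
dropout = nn.Dropout(0.1)
dropout.train()
weights = torch.ones(8)


def toy_forward(batch: torch.Tensor) -> torch.Tensor:
    call_log.append(1)
    return dropout(batch) @ weights  # (N, 8) -> (N,)


batch = torch.randn(6, 8)  # stand-in for N = 6 structures packed into one batch
mean, std = mc_mean_std(toy_forward, batch, n_passes=20)
```

Here the forward function is called 20 times in total, independent of the number of items in the batch.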

Model restored to eval() after every call. The _stochastic_mode() context manager guarantees this even if an exception is raised mid-inference.
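
One way to obtain this guarantee is a `try`/`finally` inside a context manager. The sketch below is an assumption about the general shape, not the PR's `_stochastic_mode()` implementation; the function name and signature are hypothetical.

```python
from collections.abc import Iterator
from contextlib import contextmanager

from torch import nn


@contextmanager
def stochastic_mode(model: nn.Module, stochastic: list[nn.Module]) -> Iterator[None]:
    """Hypothetical sketch: enable train() on selected modules, always restore eval()."""
    model.eval()
    for module in stochastic:
        module.train()
    try:
        yield
    finally:
        model.eval()  # runs on normal exit and when an exception propagates


model = nn.Sequential(nn.Linear(4, 4), nn.Dropout(0.1), nn.Linear(4, 1))
try:
    with stochastic_mode(model, [model[1]]):
        inside = model[1].training  # dropout is active inside the block
        raise RuntimeError("simulated failure mid-inference")
except RuntimeError:
    pass
restored = not any(m.training for m in model.modules())
```

After the simulated failure, every module is back in eval mode, so later deterministic predictions are unaffected.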

M3GNet scope. is_intensive=True (intensive/per-atom targets, MLP readout) is fully supported via dropout injection. is_intensive=False (WeightedReadOut / GatedMLP) has no injectable dropout site compatible with MC Dropout and raises ValueError — documented explicitly in the test suite.

Motivation

Active learning for materials discovery requires per-candidate uncertainty estimates to rank structures by acquisition score. I benchmarked this on the WBM dataset (256K structures, Matbench Discovery) using CHGNet as the surrogate:

| Strategy | Stable found | DAF | Budget |
| --- | --- | --- | --- |
| Random | 370 | 1.01x | 2,200 / 256,963 |
| Greedy (μ) | 412 | 1.12x | 2,200 / 256,963 |
| UCB (μ - λσ) | 425 | 1.16x | 2,200 / 256,963 |

Results stable across 5 random seeds (±0.02 DAF). The std from MC Dropout is what separates UCB from Greedy.
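
To make the Greedy versus UCB difference concrete, here is a small sketch of budgeted selection from `(mean, std)`. It assumes lower scores are better (for example, a predicted energy above hull), which the benchmark description does not state explicitly; `select_candidates` and the toy numbers are hypothetical.

```python
import torch


def select_candidates(
    mean: torch.Tensor, std: torch.Tensor, budget: int, lam: float = 1.0
) -> torch.Tensor:
    """Return indices of the `budget` lowest scores, with score = mean - lam * std."""
    score = mean - lam * std
    return torch.topk(score, k=budget, largest=False).indices


# Toy predictions for four candidates (illustrative values only).
mean = torch.tensor([0.10, 0.05, 0.12, 0.20])
std = torch.tensor([0.00, 0.00, 0.05, 0.01])

greedy = select_candidates(mean, std, budget=2, lam=0.0)  # ranks by mean alone
ucb = select_candidates(mean, std, budget=2, lam=1.0)     # rewards uncertain candidates
```

With `lam=0.0` the ranking reduces to Greedy; with `lam=1.0` the uncertain candidate at index 2 displaces the confident one at index 0.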

Files changed

  • src/matgl/utils/uncertainty.py: MCDropoutWrapper class + _inject_dropout helper
  • src/matgl/utils/__init__.py: export MCDropoutWrapper
  • tests/utils/test_uncertainty.py: 15 tests covering CHGNet + M3GNet, single/batch inputs, output shapes, std > 0, eval() cleanup, and documented unsupported configs

Test plan

  • ruff check and ruff format --check pass against project pyproject.toml
  • pytest tests/utils/test_uncertainty.py — 15/15 passing
  • Manually verified on MoS₂ and LiFePO₄ with CHGNet.load() pretrained weights

🤖 Generated with Claude Code

Adds `matgl.utils.MCDropoutWrapper`, which enables Monte Carlo Dropout
uncertainty estimation on any pretrained MatGL model (CHGNet, M3GNet,
TensorNet, …) without requiring retraining.

Key design decisions:
- Backbone stays in eval() throughout; only the readout dropout layers
  (final_dropout, final_layer) are switched to train() during inference,
  so BatchNorm statistics are not corrupted.
- Works on models initialised with final_dropout=0 (the default): the
  nn.Identity placeholder is replaced with nn.Dropout at wrap time.
- Structures are converted to graphs once per call; the N forward passes
  reuse the cached graph for efficiency.
- Returns (mean, std) tensors directly usable in acquisition functions
  such as UCB: score = mean - lambda * std.

Motivation: active learning workflows benefit from per-structure
uncertainty estimates to prioritise expensive DFT calculations. See
issue materialyzeai#800 for context.

Tests cover CHGNet and M3GNet, single/batch inputs, shape correctness,
std > 0 with dropout active, and clean eval() restoration after predict.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
@0xSoftBoi 0xSoftBoi requested review from kenko911 and shyuep as code owners June 8, 2026 02:28
pre-commit-ci Bot and others added 5 commits June 8, 2026 02:29
…tness

- `__init__.py`: add missing `from __future__ import annotations` (ruff I002)
- Injection: break early once stochastic modules are found, preventing
  `_inject_dropout` from corrupting `_MLPNorm.final_layer` whose forward
  uses `layers[-1]` indexed access rather than full iteration
- Add explicit `nn.Dropout` branch so existing CHGNet dropout is updated
  in place without falling through to the injection path
- Single-structure output: track `single` flag, squeeze batch dim on return
  so `predict_uncertainty(structure)` returns scalar tensors as documented
- `_to_graph`: remove lazy `import matgl` inside method; use plain
  `torch.tensor(state_feats_default)` matching CHGNet's own pattern
- Tests: use `M3GNet(is_intensive=True)` for M3GNet tests so `final_layer`
  is `MLP` (iterable layers) rather than `WeightedReadOut` (non-injectable)

All 13 tests pass; ruff check + format --check clean.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Replace the O(N * n_passes) per-structure loop with Batch.from_data_list
so predict_uncertainty costs O(n_passes) forward calls regardless of N.
Also add n_passes >= 2 validation (std of 1 sample is undefined), document
that M3GNet is_intensive=False is unsupported (WeightedReadOut has no
injectable dropout site), and clean up the __init__ comment block.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
matgl places Args/Raises in __init__ and keeps the class docstring as a
brief description + example only (matches CHGNet, M3GNet pattern).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
mypy infers self.model.element_types / .cutoff as Tensor | Module via
nn.Module.__getattr__; the existing attr-defined ignore didn't cover the
resulting arg-type error reported by CI.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Development

Successfully merging this pull request may close these issues.

Feature request: MC Dropout uncertainty estimation for active learning