feat(utils): add MCDropoutWrapper for uncertainty-aware active learning #801
Open
0xSoftBoi wants to merge 6 commits into
Adds `matgl.utils.MCDropoutWrapper`, which enables Monte Carlo Dropout uncertainty estimation on any pretrained MatGL model (CHGNet, M3GNet, TensorNet, …) without requiring retraining.

Key design decisions:
- Backbone stays in `eval()` throughout; only the readout dropout layers (`final_dropout`, `final_layer`) are switched to `train()` during inference, so BatchNorm statistics are not corrupted.
- Works on models initialised with `final_dropout=0` (the default): the `nn.Identity` placeholder is replaced with `nn.Dropout` at wrap time.
- Structures are converted to graphs once per call; the N forward passes reuse the cached graph for efficiency.
- Returns `(mean, std)` tensors directly usable in acquisition functions such as UCB: `score = mean - lambda * std`.

Motivation: active learning workflows benefit from per-structure uncertainty estimates to prioritise expensive DFT calculations. See issue materialyzeai#800 for context.

Tests cover CHGNet and M3GNet, single/batch inputs, shape correctness, std > 0 with dropout active, and clean `eval()` restoration after predict.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
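For reviewers who want to see how the `score = mean - lambda * std` rule changes a ranking, here is a minimal stdlib-only sketch. It is not code from this PR: `acquisition_scores`, `rank_candidates`, and the sample values are hypothetical, and it assumes that lower predicted values are preferred (as for an energy-like target), which the commit message does not state.

```python
from statistics import mean, stdev


def acquisition_scores(samples_per_structure, lam=1.0):
    """Return mean - lam * std for each structure's list of MC predictions."""
    scores = []
    for samples in samples_per_structure:
        if len(samples) < 2:
            # The sample standard deviation is undefined for a single pass.
            raise ValueError("need at least 2 passes to estimate std")
        scores.append(mean(samples) - lam * stdev(samples))
    return scores


def rank_candidates(scores):
    """Indices sorted by ascending score (assumption: lower is better)."""
    return sorted(range(len(scores)), key=scores.__getitem__)


# Hypothetical MC Dropout predictions: one inner list per structure.
mc_samples = [
    [0.10, 0.10, 0.10],  # low mean, no spread
    [0.12, 0.02, 0.22],  # slightly higher mean, large spread
    [0.50, 0.52, 0.48],  # high mean, small spread
]

greedy_order = rank_candidates(acquisition_scores(mc_samples, lam=0.0))
ucb_order = rank_candidates(acquisition_scores(mc_samples, lam=1.0))
print(greedy_order, ucb_order)
```

With `lam=0.0` the ranking uses the mean alone; with `lam=1.0` the uncertain second structure moves ahead of the confident first one, which is the only place the `std` output enters the decision.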
…tness

- `__init__.py`: add missing `from __future__ import annotations` (ruff I002)
- Injection: break early once stochastic modules are found, preventing `_inject_dropout` from corrupting `_MLPNorm.final_layer`, whose forward uses `layers[-1]` indexed access rather than full iteration
- Add explicit `nn.Dropout` branch so existing CHGNet dropout is updated in place without falling through to the injection path
- Single-structure output: track `single` flag, squeeze batch dim on return so `predict_uncertainty(structure)` returns scalar tensors as documented
- `_to_graph`: remove lazy `import matgl` inside method; use plain `torch.tensor(state_feats_default)` matching CHGNet's own pattern
- Tests: use `M3GNet(is_intensive=True)` for M3GNet tests so `final_layer` is `MLP` (iterable layers) rather than `WeightedReadOut` (non-injectable)

All 13 tests pass; ruff check + format --check clean.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Replace the O(N * n_passes) per-structure loop with `Batch.from_data_list` so `predict_uncertainty` costs O(n_passes) forward calls regardless of N. Also add `n_passes >= 2` validation (std of 1 sample is undefined), document that M3GNet `is_intensive=False` is unsupported (`WeightedReadOut` has no injectable dropout site), and clean up the `__init__` comment block.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

matgl places Args/Raises in `__init__` and keeps the class docstring as a brief description + example only (matches CHGNet, M3GNet pattern).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

mypy infers `self.model.element_types` / `.cutoff` as `Tensor | Module` via `nn.Module.__getattr__`; the existing attr-defined ignore didn't cover the resulting arg-type error reported by CI.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Closes #800.
Summary
Adds `matgl.utils.MCDropoutWrapper`: a lightweight wrapper that enables Monte Carlo Dropout uncertainty estimation on any pretrained MatGL model without retraining.

Design
- Backbone stays deterministic. Only the readout layers (`final_dropout`, `final_layer`) are switched back to `train()` during inference. The convolutional backbone stays in `eval()`, so BatchNorm statistics are not perturbed.
- Works on default pretrained models. CHGNet's default `final_dropout=0` produces an `nn.Identity` placeholder. The wrapper replaces it with a real `nn.Dropout` at construction time, so no retraining is needed.
- Batched forward passes. All N structures are packed into a single PyG `Batch` per pass via `Batch.from_data_list`, so inference costs O(n_passes) forward calls rather than O(N × n_passes). Graph conversion still happens exactly once per structure.
- Model restored to `eval()` after every call. The `_stochastic_mode()` context manager guarantees this even if an exception is raised mid-inference.
- M3GNet scope. `is_intensive=True` (intensive/per-atom targets, MLP readout) is fully supported via dropout injection. `is_intensive=False` (WeightedReadOut / GatedMLP) has no injectable dropout site compatible with MC Dropout and raises `ValueError`; this is documented explicitly in the test suite.

Motivation
Active learning for materials discovery requires per-candidate uncertainty estimates to rank structures by acquisition score. I benchmarked this on the WBM dataset (256K structures, Matbench Discovery) using CHGNet as the surrogate:
Results are stable across 5 random seeds (±0.02 DAF). The `std` from MC Dropout is what separates UCB from Greedy.

Files changed
- `src/matgl/utils/uncertainty.py`: `MCDropoutWrapper` class + `_inject_dropout` helper
- `src/matgl/utils/__init__.py`: export `MCDropoutWrapper`
- `tests/utils/test_uncertainty.py`: 15 tests covering CHGNet + M3GNet, single/batch inputs, output shapes, std > 0, `eval()` cleanup, and documented unsupported configs

Test plan
- `ruff check` and `ruff format --check` pass against project `pyproject.toml`
- `pytest tests/utils/test_uncertainty.py`: 15/15 passing
- `CHGNet.load()` pretrained weights

🤖 Generated with Claude Code
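The Design points about keeping the backbone in `eval()`, enabling only dropout, and always restoring `eval()` can be illustrated with a small generic PyTorch sketch. This is not the PR's implementation: `stochastic_mode`, `mc_predict`, and the toy `nn.Sequential` model are hypothetical stand-ins, the sketch toggles only `nn.Dropout` modules (the PR also switches `final_layer`), and it skips the `nn.Identity` replacement and graph batching entirely.

```python
from contextlib import contextmanager

import torch
from torch import nn


@contextmanager
def stochastic_mode(model: nn.Module):
    """Temporarily put only nn.Dropout modules in train mode."""
    model.eval()  # everything else (e.g. BatchNorm) stays deterministic
    try:
        for module in model.modules():
            if isinstance(module, nn.Dropout):
                module.train()  # dropout masks are only sampled in train mode
        yield model
    finally:
        model.eval()  # restored even if a forward pass raises


def mc_predict(model: nn.Module, x: torch.Tensor, n_passes: int = 20):
    """Return (mean, std) over n_passes stochastic forward passes."""
    if n_passes < 2:
        raise ValueError("n_passes must be >= 2: std of one sample is undefined")
    with stochastic_mode(model), torch.no_grad():
        # One forward call per pass over the whole batch: O(n_passes) calls.
        samples = torch.stack([model(x) for _ in range(n_passes)])
    return samples.mean(dim=0), samples.std(dim=0)


# Toy stand-in for a pretrained model with a dropout layer before the readout.
torch.manual_seed(0)
toy = nn.Sequential(
    nn.Linear(4, 16),
    nn.BatchNorm1d(16),
    nn.ReLU(),
    nn.Dropout(p=0.2),
    nn.Linear(16, 1),
)
batch = torch.randn(8, 4)  # 8 hypothetical inputs with 4 features each

mean, std = mc_predict(toy, batch, n_passes=30)
print(mean.shape, std.shape, toy.training)
```

Putting the restore step in `finally` is what gives the guarantee described in the Design list: a caller that catches an exception from a failed pass still gets a model that is fully back in `eval()`.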