Feature request: MC Dropout uncertainty estimation for active learning #800

@0xSoftBoi

Summary

I've been running an iterative active learning pipeline on the WBM dataset (256K structures from Matbench Discovery) using CHGNet as a surrogate, and uncertainty estimation was the core mechanism for candidate selection. I'd love to see first-class support for this in MatGL.

What I implemented (prototype)

I wrapped CHGNet in a thin MC Dropout layer:

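import numpy as np
import torch
from torch import nn

class MCDropoutSurrogate:
    def __init__(self, model, dropout_p=0.3):
        self.model = model  # keep a handle so the predict method can reach it
        # Set the rate on the Dropout layers in the MLP readout head
        for mod in model.mlp.modules():
            if isinstance(mod, nn.Dropout):
                mod.p = dropout_p

    def predict_with_uncertainty(self, graphs, n_passes=20):
        self.model.eval()
        self.model.mlp.train()  # keep dropout active in head only
        with torch.no_grad():  # inference only, no gradients needed
            # forward_pass is the prototype's own helper (not shown here)
            preds = [forward_pass(graphs) for _ in range(n_passes)]
        self.model.eval()  # restore fully deterministic mode
        return np.mean(preds, axis=0), np.std(preds, axis=0)  # μ, σ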

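This gives per-structure uncertainty estimates that drive UCB acquisition: score = μ - λ·σ, where μ and σ are the mean and standard deviation over the n_passes stochastic passes and λ sets the exploration weight.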
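To make the acquisition step concrete, here is a minimal sketch of batch selection from (μ, σ). It assumes that a lower score is better (for example, when μ is a predicted energy above the hull), so the candidates with the smallest μ - λ·σ are picked. The function name `select_batch` and its parameters are illustrative and not part of MatGL or CHGNet.

```python
from typing import List, Sequence


def select_batch(
    mu: Sequence[float],
    sigma: Sequence[float],
    lam: float = 1.0,
    k: int = 10,
) -> List[int]:
    """Return the indices of the k candidates with the lowest mu - lam * sigma.

    Assumption: lower scores are better, so subtracting lam * sigma favours
    candidates that are either predicted to be good or highly uncertain.
    """
    if len(mu) != len(sigma):
        raise ValueError("mu and sigma must have the same length")
    scores = [m - lam * s for m, s in zip(mu, sigma)]
    # Sort candidate indices by ascending score and keep the first k.
    order = sorted(range(len(scores)), key=scores.__getitem__)
    return order[:k]
```

Setting `lam=0` reduces this to greedy selection on μ alone, which is a convenient baseline when checking how much the uncertainty term actually contributes.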

Results on WBM (256K structures)

With only 2,200 labeled structures (0.9% budget), UCB acquisition found 1.16x as many stable materials as random screening, consistent across 5 random seeds. The key finding: MC Dropout on the pre-trained surrogate (no fine-tuning) outperforms fine-tuning on the small, biased labeled set.

Feature request

Would you consider adding a predict_uncertainty(structures, n_passes=N) method or mixin to MatGL's model interface? The key pieces:

  1. Dropout injection at model load time (or as a wrapper) into the final MLP readout head
  2. Stochastic inference mode: backbone in eval(), MLP head in train() for N forward passes
  3. Return (μ, σ) per structure — enough for UCB/Thompson sampling acquisition functions
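The three pieces above could look roughly like the following in plain PyTorch. This is only a sketch under stated assumptions: it assumes the model splits cleanly into a `backbone` and an `nn.Sequential` readout `head`, and the names `inject_dropout` and `predict_uncertainty` are hypothetical, not existing MatGL functions. Because the backbone is deterministic in eval mode, the sketch runs it once and repeats only the head.

```python
import torch
from torch import nn


def inject_dropout(head: nn.Sequential, p: float = 0.3) -> nn.Sequential:
    """Return a new Sequential with Dropout(p) after each activation in `head`.

    Assumption: the readout head is a flat nn.Sequential of Linear layers
    and activations; the original layers (and their weights) are reused.
    """
    layers = []
    for layer in head:
        layers.append(layer)
        if isinstance(layer, (nn.ReLU, nn.SiLU, nn.Softplus)):
            layers.append(nn.Dropout(p))
    return nn.Sequential(*layers)


@torch.no_grad()
def predict_uncertainty(backbone: nn.Module, head: nn.Module, x, n_passes: int = 20):
    """Return (mean, std) over n_passes stochastic passes through the head."""
    backbone.eval()
    head.eval()
    # Re-enable only the Dropout layers so any other head layers stay in eval mode.
    for mod in head.modules():
        if isinstance(mod, nn.Dropout):
            mod.train()
    feats = backbone(x)  # deterministic, so computed a single time
    preds = torch.stack([head(feats) for _ in range(n_passes)])
    head.eval()  # restore deterministic inference
    # unbiased=False matches the population std that np.std computes by default.
    return preds.mean(dim=0), preds.std(dim=0, unbiased=False)
```

Switching only the `nn.Dropout` modules to train mode, rather than calling `head.train()`, avoids side effects on layers such as batch normalization if the head happens to contain any.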

This would make MatGL an excellent backbone for active learning workflows in computational materials discovery, and the WBM dataset provides a ready-made benchmark for it.

Happy to contribute a PR if there's interest — I have a working implementation I could clean up.

Related: I opened a similar discussion on CederGroupHub/chgnet (#250) but was pointed toward MatGL as the active development target.
