
Bug: SoftExponential.forward branches on a learnable parameter, breaking autograd and producing NaNs #788

@shyuep


Summary

SoftExponential in src/matgl/layers/_activations.py uses a Python if/elif ladder on its learnable alpha parameter inside forward. This has four issues, ranging from a real correctness bug (with the default alpha = 0.0, alpha never receives a gradient and is never trained) to numerical and performance footguns.

Affected code

src/matgl/layers/_activations.py:62-75:

```python
class SoftExponential(nn.Module):
    def __init__(self, alpha: float | None = None):
        super().__init__()
        if alpha is None:
            self.alpha = nn.Parameter(torch.tensor(0.0))
        else:
            self.alpha = nn.Parameter(torch.tensor(alpha))
        self.alpha.requires_grad_(True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.alpha == 0.0:
            return x
        if self.alpha < 0.0:
            return -torch.log(1.0 - self.alpha * (x + self.alpha)) / self.alpha
        return (torch.exp(self.alpha * x) - 1.0) / self.alpha + self.alpha
```

self.alpha is a learnable nn.Parameter with requires_grad=True.

Bug 1 — branch decision discards the autograd graph (correctness)

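self.alpha == 0.0 and self.alpha < 0.0 each return a 0-d boolean tensor, and the if implicitly calls bool(...) on it. That forces a host-device sync (see Bug 4) and turns the branch choice into a Python-level decision that autograd never sees. Consequences:

  • Exactly one of the three formulas runs per call, so dL/d(alpha) is computed only along the branch that fired.
  • The identity branch returns x unchanged. That output does not depend on alpha, so this branch sends no gradient to alpha at all.
  • alpha defaults to 0.0, so a SoftExponential() built without an explicit alpha takes the identity branch on its very first call. alpha.grad stays None, the built-in torch.optim optimizers skip parameters whose .grad is None, and alpha never leaves 0.0: the parameter is trapped and the layer silently degenerates to the identity.
  • The trap comes from the alpha == 0 shortcut, not from the math. Near zero, both non-trivial formulas expand to x + alpha * (1 + x²/2) + O(alpha²) from either side, so the function and its alpha-derivative are continuous across zero; an implementation that followed that limit instead of returning a constant-in-alpha x would keep training.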
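Whether alpha can ever leave its default of 0.0 is easy to check in isolation. The sketch below is a condensed local copy of the quoted class (same forward, simplified constructor, not imported from matgl); it runs one SGD step and inspects alpha afterwards. The input values and learning rate are arbitrary choices for illustration.

```python
from typing import Optional

import torch
from torch import nn


class SoftExponential(nn.Module):
    """Local copy of the quoted if/elif implementation (not imported from matgl)."""

    def __init__(self, alpha: Optional[float] = None):
        super().__init__()
        self.alpha = nn.Parameter(torch.tensor(0.0 if alpha is None else alpha))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.alpha == 0.0:
            return x
        if self.alpha < 0.0:
            return -torch.log(1.0 - self.alpha * (x + self.alpha)) / self.alpha
        return (torch.exp(self.alpha * x) - 1.0) / self.alpha + self.alpha


layer = SoftExponential()  # default alpha = 0.0
opt = torch.optim.SGD(layer.parameters(), lr=0.1)

# x needs requires_grad here: with alpha out of the graph, the loss would
# otherwise have no grad_fn and backward() would raise.
x = torch.linspace(-2.0, 2.0, 5, requires_grad=True)
loss = (layer(x) ** 2).sum()
loss.backward()
opt.step()

print(layer.alpha.grad)    # None: no gradient ever reached alpha
print(layer.alpha.item())  # 0.0: SGD skips parameters whose .grad is None
```

Constructing the copy with a non-zero alpha (say 0.5) instead does produce a finite alpha.grad, so the trap is specific to starting at exactly zero.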

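A branch-free implementation uses torch.where, which keeps the selection on the device and lets autograd send the gradient through whichever formula produced each output. The unused branch must still be guarded with safe stand-in inputs (the "double-where trick"), because torch.where evaluates both sides and its backward feeds a zero gradient into the discarded one; combined with an Inf or NaN local derivative there (for example 0/0 at alpha == 0), that zero still yields NaN: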

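```python
import torch

def forward(self, x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    a = self.alpha
    small = a.abs() < eps
    # Double-where: give the formulas a non-zero stand-in alpha when |alpha| < eps,
    # so their divisions never see 0/0 (a NaN that would leak into backward).
    safe_a = torch.where(small, torch.full_like(a, eps), a)
    neg_mask = safe_a < 0
    safe_log_arg = torch.where(neg_mask, 1.0 - safe_a * (x + safe_a), torch.ones_like(x))
    neg = -torch.log(safe_log_arg) / safe_a
    pos = torch.expm1(safe_a * x) / safe_a + safe_a
    # Near zero, use the first-order expansion, which still carries an alpha gradient.
    return torch.where(small, x + a * (1.0 + 0.5 * x * x), torch.where(neg_mask, neg, pos))
```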
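The NaN leak that the double-where trick guards against reproduces in a few lines. In this sketch pos_branch is a hypothetical helper wrapping the positive-alpha formula, and x and eps are illustrative values; only alpha = 0.0 matters.

```python
import torch

x = torch.tensor([1.0, 2.0])
eps = 1e-6  # illustrative tolerance


def pos_branch(a: torch.Tensor) -> torch.Tensor:
    # Positive-alpha formula from the issue; it is 0/0 = NaN at a == 0.
    return torch.expm1(a * x) / a + a


# Single where: the NaN is discarded in forward but leaks into backward.
a1 = torch.tensor(0.0, requires_grad=True)
naive = torch.where(a1.abs() < eps, x, pos_branch(a1))
naive.sum().backward()
print(a1.grad.item())  # nan

# Double where: the formula only ever sees a non-zero stand-in alpha.
a2 = torch.tensor(0.0, requires_grad=True)
small = a2.abs() < eps
safe_a = torch.where(small, torch.full_like(a2, eps), a2)
guarded = torch.where(small, x, pos_branch(safe_a))
guarded.sum().backward()
print(a2.grad.item())  # 0.0
```

Both versions return x in the forward pass; only the guarded one has a usable backward. Its gradient is finite but exactly zero, because a plain x on the near-zero side does not depend on alpha, so the optimizer still has nothing to follow at alpha == 0.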

Bug 2 — exact equality test on a float parameter

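self.alpha == 0.0 compares a learned float to exactly zero. The only way to hit it is to start there: the default alpha = 0.0 does, but the identity branch then gives alpha no gradient, so it never moves. An alpha initialised anywhere else essentially never lands on exactly 0.0; as it drifts toward zero it becomes some tiny non-zero float (e.g. 1e-7) and keeps taking one of the other two branches, so the "identity" branch is unreachable in practice. The intent, behaving like the identity near zero, needs a tolerance, e.g. torch.abs(a) < eps.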
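A quick float32 check shows both halves of the problem. The values (alpha = 1e-7, x = 1, eps = 1e-6) are illustrative. A tiny alpha slips past the equality test, and the positive-branch formula as written in the quoted forward then loses most of its precision to cancellation in exp(...) - 1, which torch.expm1 avoids.

```python
import torch

alpha = torch.tensor(1e-7)  # tiny but non-zero, as after a few optimizer steps
x = torch.tensor(1.0)
eps = 1e-6                  # illustrative tolerance

print(bool(alpha == 0.0), bool(alpha.abs() < eps))  # False True

# Exact value is x + alpha * (1 + x**2 / 2) + O(alpha**2), i.e. about 1.00000015.
naive = (torch.exp(alpha * x) - 1.0) / alpha + alpha  # formula from the quoted forward
stable = torch.expm1(alpha * x) / alpha + alpha       # same formula, expm1 instead
print(naive.item(), stable.item())
```

In float32, exp(alpha * x) rounds to a neighbour of 1.0 before the subtraction, so the naive result lands far from the true value, while the expm1 version stays within a few ulps of it.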

Bug 3 — undefined-region values silently produce NaN/Inf

For a < 0, the formula −log(1 − a(x + a)) / a requires 1 − a(x + a) > 0, i.e. (with a < 0) x > −a − 1/|a|. For sufficiently negative x the log argument reaches 0 (giving Inf) or goes negative (giving NaN). In a PES context, where features can be large early in training, this is easy to trigger. Fix by clamping the log argument to a small positive floor, or by reformulating via softplus.
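The undefined region is easy to hit. This sketch uses an illustrative alpha = -0.5, for which the bound works out to x > 0.5 − 2 = −1.5, and compares the formula as written with a clamped variant; floor = 1e-6 is a hypothetical value, not a tuned one.

```python
import torch

a = torch.tensor(-0.5)                     # illustrative negative alpha
x = torch.tensor([-3.0, -1.5, 0.0, 2.0])   # valid only for x > -1.5

# Formula as written: the log argument here is 0.75 + 0.5 * x.
raw = -torch.log(1.0 - a * (x + a)) / a
print(raw)  # nan at x = -3 (negative argument), -inf at x = -1.5 (zero argument)

# Clamped variant: floor the log argument at a small positive value.
floor = 1e-6  # hypothetical floor; would need choosing per dtype/model
clamped = -torch.log(torch.clamp(1.0 - a * (x + a), min=floor)) / a
print(clamped)  # finite everywhere; matches raw inside the valid region
```

Clamping trades NaN for a flat spot: wherever the floor is active, torch.clamp blocks the gradient through the log argument, so dL/dx is zero for those inputs.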

Bug 4 — host-device sync on every call (perf)

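Each if on a CUDA alpha converts a device tensor to a Python bool, a device-to-host copy equivalent to .item(), so a forward with non-zero alpha pays two such round-trips, each blocking the host until all queued GPU work has finished. torch.compile cannot fuse across this data-dependent Python branch either. The torch.where rewrite removes it.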

Blast radius

SoftExponential is not the default for any shipped model. It is selectable via ActivationFunction and via string-config model assembly, so a user who picks it hits all four bugs at once. The "softplus2", "swish", "silu" defaults are unaffected.

Suggested fix

A torch.where-based rewrite with the double-where guard plus an eps tolerance on the zero-branch, as sketched above. Alternatively, deprecate SoftExponential from the ActivationFunction enum until fixed.

Happy to send a PR.

🤖 Generated with Claude Code

Metadata


Assignees

Labels

No labels

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests
