Summary
SoftExponential in src/matgl/layers/_activations.py branches on its learnable alpha parameter with plain Python if statements inside forward. This has four issues, ranging from a real correctness bug (gradients with respect to alpha are computed for only one branch at a time) to numerical and performance footguns.
Affected code
src/matgl/layers/_activations.py:62-75:
class SoftExponential(nn.Module):
    def __init__(self, alpha: float | None = None):
        super().__init__()
        if alpha is None:
            self.alpha = nn.Parameter(torch.tensor(0.0))
        else:
            self.alpha = nn.Parameter(torch.tensor(alpha))
        self.alpha.requires_grad_(True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.alpha == 0.0:
            return x
        if self.alpha < 0.0:
            return -torch.log(1.0 - self.alpha * (x + self.alpha)) / self.alpha
        return (torch.exp(self.alpha * x) - 1.0) / self.alpha + self.alpha
self.alpha is a learnable nn.Parameter with requires_grad=True.
Bug 1 — branch decision discards the autograd graph (correctness)
self.alpha == 0.0 and self.alpha < 0.0 each return a 0-d boolean tensor; the if implicitly calls bool(...) on it, which forces a host-device sync and moves the branch decision into Python, where autograd cannot see it. Consequences:
- One of the three formulas is picked per call, and dL/d(alpha) is computed only along the branch that fired.
- At alpha == 0, which is also the default initial value, the identity branch returns x without touching alpha. No gradient reaches alpha, so the optimizer never moves it off zero.
- Near the boundaries (alpha → 0⁻, alpha → 0⁺) both remaining formulas divide by alpha and lose floating-point precision. If alpha is currently −0.01 and the optimum sits at some small +ε, the optimizer has to cross exactly the region where the value and its gradient are least reliable.
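The branch behaviour can be checked with a small reproduction. This is a minimal sketch that copies the forward quoted under "Affected code" into a stand-alone module; the class name `SoftExponentialAsReported` and the helper `alpha_grad` are hypothetical names used only for this illustration.

```python
import torch
from torch import nn


class SoftExponentialAsReported(nn.Module):
    """Stand-alone copy of the forward quoted in this report (Python branching)."""

    def __init__(self, alpha: float = 0.0):
        super().__init__()
        self.alpha = nn.Parameter(torch.tensor(alpha))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.alpha == 0.0:
            return x
        if self.alpha < 0.0:
            return -torch.log(1.0 - self.alpha * (x + self.alpha)) / self.alpha
        return (torch.exp(self.alpha * x) - 1.0) / self.alpha + self.alpha


def alpha_grad(alpha: float, x: torch.Tensor):
    """Return d(sum(out))/d(alpha), or None if the output ignores alpha."""
    act = SoftExponentialAsReported(alpha)
    out = act(x)
    if not out.requires_grad:
        # The identity branch returned x unchanged: alpha is not in the graph.
        return None
    out.sum().backward()
    return act.alpha.grad


x = torch.tensor([0.5, 1.0])
grad_at_zero = alpha_grad(0.0, x)   # identity branch
grad_at_pos = alpha_grad(0.5, x)    # exp branch only
grad_at_neg = alpha_grad(-0.5, x)   # log branch only
```

With alpha at 0.0 the output does not depend on alpha, so `alpha_grad` returns None. With a non-zero alpha a gradient exists, but it is built from a single formula chosen in Python.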
A torch.where-based implementation keeps the branch selection inside the tensor program, so there is no Python branch and no sync. Gradients still flow only through the selected side, but torch.where evaluates both sides, so the unused branch must be fed safe stand-in inputs (the "double-where trick"); otherwise a NaN or Inf produced on the unused side turns the gradient into NaN:
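def forward(self, x, eps=1e-6):  # eps: illustrative tolerance, not an existing argument
    a = self.alpha
    near_zero = a.abs() < eps
    use_neg, use_pos = (a < 0) & ~near_zero, (a > 0) & ~near_zero
    a_safe = torch.where(near_zero, torch.ones_like(a), a)  # never divide by ~0
    # Double-where: an unselected branch only ever sees harmless stand-in inputs.
    log_arg = torch.where(use_neg, 1.0 - a_safe * (x + a_safe), torch.ones_like(x))
    exp_arg = torch.where(use_pos, a_safe * x, torch.zeros_like(x))
    neg = -torch.log(log_arg) / a_safe
    pos = torch.expm1(exp_arg) / a_safe + a_safe
    out = torch.where(use_neg, neg, pos)
    return torch.where(near_zero, x, out)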
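To see why the guard matters, the following sketch isolates the double-where trick on a plain log, independent of SoftExponential. The function names are illustrative and not part of matgl.

```python
import torch


def masked_log_naive(x: torch.Tensor) -> torch.Tensor:
    # log is evaluated on every element, including those the mask discards.
    return torch.where(x > 0, torch.log(x), torch.zeros_like(x))


def masked_log_safe(x: torch.Tensor) -> torch.Tensor:
    # Inner where: swap invalid inputs for 1.0 before calling log.
    x_safe = torch.where(x > 0, x, torch.ones_like(x))
    # Outer where: select the result exactly as before.
    return torch.where(x > 0, torch.log(x_safe), torch.zeros_like(x))


def grad_of(fn, values):
    """Gradient of sum(fn(x)) with respect to x."""
    x = torch.tensor(values, requires_grad=True)
    fn(x).sum().backward()
    return x.grad


naive_grad = grad_of(masked_log_naive, [0.0, 2.0])
safe_grad = grad_of(masked_log_safe, [0.0, 2.0])
```

Both functions return the same forward values. In the naive version the discarded element still passes through the backward of log, where a zero upstream gradient is divided by a zero input, so its gradient entry is NaN. The safe version never shows log an invalid input, and its gradient stays finite.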
Bug 2 — exact equality test on a float parameter
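self.alpha == 0.0 compares a learned float to exactly zero. The test only holds while alpha sits at precisely 0.0, for example at the default initial value. Once alpha takes any other value, an optimizer step will almost never land it back on exactly zero; it ends up at some tiny non-zero float (e.g. 1e-7) instead, so the "identity" branch is unreachable in practice. The intent ("behave as identity near zero") needs a tolerance, e.g. torch.abs(a) < eps.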
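A short sketch of the tolerance check, together with the numerical reason it matters. The helper name `is_near_zero` and the value of `EPS` are assumptions made for illustration; the report does not prescribe a tolerance.

```python
import torch

EPS = 1e-6  # illustrative tolerance; the report does not fix a value


def is_near_zero(alpha: torch.Tensor, eps: float = EPS) -> torch.Tensor:
    """Tolerance-based replacement for the exact `alpha == 0.0` test."""
    return alpha.abs() < eps


alpha = torch.tensor(1e-7)  # a tiny non-zero float32 value
x = torch.tensor(1.0)

exact_hit = bool(alpha == 0.0)
tolerant_hit = bool(is_near_zero(alpha))

# Positive-branch formula as written in the report, and an expm1 variant.
naive = (torch.exp(alpha * x) - 1.0) / alpha + alpha
stable = torch.expm1(alpha * x) / alpha + alpha

# At this alpha the activation should be almost exactly x.
naive_error = (naive - x).abs().item()
stable_error = (stable - x).abs().item()
```

The exact test misses this alpha while the tolerance test catches it. In float32, `exp(alpha * x) - 1.0` cancels most significant digits at this scale, so `naive_error` is far larger than `stable_error`; falling through to the exp branch near zero is therefore inaccurate as well as unintended.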
Bug 3 — undefined-region values silently produce NaN/Inf
For a < 0, the formula −log(1 − a(x + a)) / a requires 1 − a(x + a) > 0, i.e. (with a < 0) x > −a − 1/|a|. Below that bound the log argument is ≤ 0: an argument of exactly 0 yields an infinite output, and a negative argument yields NaN. The bound is not far out; with a = −1 it is x > 0, so every non-positive input fails. In a potential energy surface (PES) context where features can be large early in training, this is easy to trigger. Fix by clamping the log argument to a small positive floor, or by reformulating via softplus.
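The domain bound and the clamping option can be sketched as follows. The function names and the floor value are hypothetical; clamping is only one of the two fixes mentioned above, and it is shown here without any claim that it is the preferable one.

```python
import torch


def domain_lower_bound(a: float) -> float:
    """Smallest x (exclusive) for which the negative branch is defined, a < 0."""
    return -a - 1.0 / abs(a)


def neg_branch(x: torch.Tensor, a: torch.Tensor) -> torch.Tensor:
    # Negative-alpha formula exactly as in the reported forward.
    return -torch.log(1.0 - a * (x + a)) / a


def neg_branch_clamped(x: torch.Tensor, a: torch.Tensor, floor: float = 1e-6) -> torch.Tensor:
    # Same formula, with the log argument held at or above a small positive floor.
    arg = torch.clamp(1.0 - a * (x + a), min=floor)
    return -torch.log(arg) / a


a = torch.tensor(-1.0)
x = torch.tensor([-1.0, 0.0, 1.0])  # below, on, and above the bound for a = -1

bound = domain_lower_bound(-1.0)
raw = neg_branch(x, a)
clamped = neg_branch_clamped(x, a)
```

For a = −1 the bound is 0, so the first input gives NaN and the second gives an infinite value in `raw`, while `clamped` stays finite everywhere. The trade-off is that clamped elements receive zero gradient and their output value is set by the floor rather than by the formula.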
Bug 4 — host-device sync on every call (perf)
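if self.alpha == 0.0: on a CUDA alpha converts a device tensor to a Python bool, which is a blocking device-to-host copy (the same cost as .item()) on every forward and stalls the GPU pipeline. The data-dependent branch also forces a graph break under torch.compile, so nothing can be fused across it. The torch.where rewrite removes both problems.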
Blast radius
SoftExponential is not the default for any shipped model. It is selectable via ActivationFunction and via string-config model assembly, so a user who picks it hits all four bugs at once. The "softplus2", "swish", "silu" defaults are unaffected.
Suggested fix
A torch.where-based rewrite with the double-where guard plus an eps tolerance on the zero-branch, as sketched above. Alternatively, deprecate the SoftExponential entry in the ActivationFunction enum until it is fixed.
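For discussion, here is one possible shape of such a rewrite as a complete module. It is a sketch under stated assumptions, not the matgl implementation: the class name `SoftExponentialWhere`, the `eps` and `log_floor` arguments, and their default values are all hypothetical. It also departs from the plain identity in one respect, which is this sketch's own choice: inside the near-zero region it uses the first-order expansion in alpha, x + alpha * (1 + x²/2), which both closed-form branches share as alpha → 0, so that alpha still receives a gradient at its default value of 0.

```python
from __future__ import annotations

import torch
from torch import nn


class SoftExponentialWhere(nn.Module):
    """Branch-free sketch of SoftExponential (hypothetical, not matgl API)."""

    def __init__(self, alpha: float | None = None, eps: float = 1e-6, log_floor: float = 1e-12):
        super().__init__()
        self.alpha = nn.Parameter(torch.tensor(0.0 if alpha is None else float(alpha)))
        self.eps = eps              # assumed width of the near-zero region
        self.log_floor = log_floor  # assumed floor for the log argument

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        a = self.alpha
        near_zero = a.abs() < self.eps
        use_neg = (a < 0) & ~near_zero
        use_pos = (a > 0) & ~near_zero

        # Stand-in alpha: the closed-form branches never divide by ~0.
        a_safe = torch.where(near_zero, torch.ones_like(a), a)

        # Double-where: each branch only sees real inputs when it is selected.
        log_arg = torch.where(
            use_neg,
            torch.clamp(1.0 - a_safe * (x + a_safe), min=self.log_floor),
            torch.ones_like(x),
        )
        exp_arg = torch.where(use_pos, a_safe * x, torch.zeros_like(x))

        neg = -torch.log(log_arg) / a_safe
        pos = torch.expm1(exp_arg) / a_safe + a_safe
        out = torch.where(use_neg, neg, pos)

        # First-order expansion in alpha, used only when |alpha| < eps.
        near_identity = x + a * (1.0 + 0.5 * x * x)
        return torch.where(near_zero, near_identity, out)
```

No tensor is converted to a Python bool in forward, so the host-device sync disappears, and every intermediate stays finite for the unselected branches. Whether the near-zero region should return x exactly or use the expansion, and which tolerance is appropriate, are open design questions for the maintainers.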
Happy to send a PR.
🤖 Generated with Claude Code