
Commit 846a297

Add raise_error option to TerminateOnNaN for immediate termination on NaN/Inf losses (#21841)
* Add HardTerminateOnNaN callback for immediate training termination on NaN loss
* Add hard option to TerminateOnNaN for immediate termination on NaN/Inf loss
* Refactor: rename argument to raise_error and merge tests into terminate_on_nan_test.py
* Apply review fixes for TerminateOnNaN callback and tests
1 parent 3b375af commit 846a297

File tree

keras/src/callbacks/terminate_on_nan.py
keras/src/callbacks/terminate_on_nan_test.py

2 files changed: +221 -6

keras/src/callbacks/terminate_on_nan.py

Lines changed: 54 additions & 5 deletions
@@ -7,14 +7,63 @@
 
 @keras_export("keras.callbacks.TerminateOnNaN")
 class TerminateOnNaN(Callback):
-    """Callback that terminates training when a NaN loss is encountered."""
+    """Callback that terminates training when a NaN loss is encountered.
+
+    This callback monitors the loss value during training and terminates
+    training when a NaN or Inf loss is detected. By default, training is
+    stopped gracefully by setting `model.stop_training = True`, which
+    triggers all callback cleanup methods, including `on_train_end()`.
+
+    Alternatively, you can pass `raise_error=True` to raise a
+    `RuntimeError` immediately when a NaN/Inf loss is detected. Raising
+    prevents `on_train_end()` from being called on other callbacks, which
+    is useful for preserving backup state or avoiding unintended cleanup
+    when training fails.
+
+    Args:
+        raise_error: Boolean, defaults to `False`. If `False`, stops
+            training gracefully via `model.stop_training = True`. If
+            `True`, immediately raises a `RuntimeError` on a NaN/Inf
+            loss, bypassing callback cleanup methods.
+
+    Example:
+
+    ```
+    # Graceful termination (default)
+    callback = keras.callbacks.TerminateOnNaN()
+    model.fit(x, y, callbacks=[callback])
+
+    # Strict failure: raise immediately
+    callback = keras.callbacks.TerminateOnNaN(raise_error=True)
+    model.fit(x, y, callbacks=[callback])
+    ```
+    """
+
+    def __init__(self, raise_error: bool = False):
+        super().__init__()
+        self.raise_error = raise_error
 
     def on_batch_end(self, batch, logs=None):
+        """Check for a NaN/Inf loss at the end of each batch.
+
+        Args:
+            batch: Integer, index of the batch within the current epoch.
+            logs: Dict, contains the return value of `model.train_step()`.
+
+        Raises:
+            RuntimeError: If the loss is NaN/Inf and `raise_error=True`.
+        """
         logs = logs or {}
         loss = logs.get("loss")
         if loss is not None:
             if np.isnan(loss) or np.isinf(loss):
-                io_utils.print_msg(
-                    f"Batch {batch}: Invalid loss, terminating training"
-                )
-                self.model.stop_training = True
+                if self.raise_error:
+                    raise RuntimeError(
+                        f"NaN or Inf loss encountered at batch {batch}. "
+                        f"Loss value: {loss}. Terminating training "
+                        "immediately."
+                    )
+                else:
+                    io_utils.print_msg(
+                        f"Batch {batch}: Invalid loss, terminating training"
+                    )
+                    self.model.stop_training = True
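To make the behavioral difference concrete, here is a minimal usage sketch (not part of the commit), assuming a Keras build that includes this change; the tiny model and the Inf target are only illustrative, and the recovery note in the `except` branch is a hypothetical use case:

```python
import numpy as np
import keras

# Tiny illustrative model; an Inf target drives the MSE loss to Inf on
# the very first batch.
model = keras.Sequential([keras.layers.Dense(1)])
model.compile(optimizer="sgd", loss="mse")

x = np.array([[1.0]])
y = np.array([[np.inf]])

try:
    model.fit(
        x,
        y,
        epochs=1,
        verbose=0,
        callbacks=[keras.callbacks.TerminateOnNaN(raise_error=True)],
    )
except RuntimeError as err:
    # Because the callback raised instead of setting stop_training,
    # on_train_end() cleanup on other callbacks (e.g. BackupAndRestore
    # deleting its checkpoint directory) never ran, so any backup state
    # is still on disk and can be handled here.
    print(f"Training aborted: {err}")
```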

keras/src/callbacks/terminate_on_nan_test.py

Lines changed: 167 additions & 1 deletion
@@ -1,16 +1,24 @@
+import os
+
 import numpy as np
 import pytest
+from absl.testing import parameterized
 
 from keras.src import callbacks
 from keras.src import initializers
 from keras.src import layers
+from keras.src import models
 from keras.src import testing
+from keras.src.callbacks import BackupAndRestore
+from keras.src.callbacks import TerminateOnNaN
 from keras.src.models import Sequential
 from keras.src.utils import numerical_utils
 
 
+@pytest.mark.requires_trainable_backend
 class TerminateOnNaNTest(testing.TestCase):
-    @pytest.mark.requires_trainable_backend
+    """Test suite for the TerminateOnNaN callback."""
+
     def test_TerminateOnNaN(self):
         TRAIN_SAMPLES = 10
         TEST_SAMPLES = 10
@@ -50,3 +58,161 @@ def test_TerminateOnNaN(self):
         loss = history.history["loss"]
         self.assertEqual(len(loss), 1)
         self.assertTrue(np.isnan(loss[0]) or np.isinf(loss[0]))
+
+    def test_terminate_on_nan_graceful_stop(self):
+        """Test that TerminateOnNaN (default) gracefully stops training."""
+        model = models.Sequential([layers.Dense(1, input_shape=(1,))])
+        model.compile(optimizer="sgd", loss="mse")
+
+        x = np.array([[1.0], [2.0]])
+        y = np.array([[np.inf], [np.inf]])
+
+        callback = TerminateOnNaN(raise_error=False)
+
+        # Training should complete without raising a RuntimeError.
+        history = model.fit(
+            x, y, epochs=2, batch_size=1, callbacks=[callback], verbose=0
+        )
+
+        # Training should stop early.
+        self.assertLess(len(history.history["loss"]), 4)
+
+    def test_terminate_on_nan_raise_error_raises_error(self):
+        """Test that TerminateOnNaN(raise_error=True) raises a
+        RuntimeError on a NaN loss.
+        """
+        model = models.Sequential([layers.Dense(1, input_shape=(1,))])
+        model.compile(optimizer="sgd", loss="mse")
+
+        x = np.array([[1.0], [2.0]])
+        y = np.array([[np.inf], [np.inf]])
+
+        callback = TerminateOnNaN(raise_error=True)
+
+        # Training should raise a RuntimeError.
+        with self.assertRaisesRegex(
+            RuntimeError,
+            "NaN or Inf loss encountered",
+        ):
+            model.fit(
+                x, y, epochs=1, batch_size=1, callbacks=[callback], verbose=0
+            )
+
+    def test_raise_error_terminate_does_not_trigger_on_train_end(self):
+        """Test that on_train_end is NOT called when
+        TerminateOnNaN(raise_error=True) raises.
+        """
+
+        class TrackingCallback(callbacks.Callback):
+            def __init__(self):
+                super().__init__()
+                self.train_end_called = False
+
+            def on_train_end(self, logs=None):
+                self.train_end_called = True
+
+        model = models.Sequential([layers.Dense(1, input_shape=(1,))])
+        model.compile(optimizer="sgd", loss="mse")
+
+        x = np.array([[1.0]])
+        y = np.array([[np.inf]])
+
+        tracking_callback = TrackingCallback()
+        raise_error_terminate_callback = TerminateOnNaN(raise_error=True)
+
+        # Should raise a RuntimeError.
+        with self.assertRaises(RuntimeError):
+            model.fit(
+                x,
+                y,
+                epochs=1,
+                callbacks=[tracking_callback, raise_error_terminate_callback],
+                verbose=0,
+            )
+
+        # on_train_end should NOT have been called.
+        self.assertFalse(tracking_callback.train_end_called)
+
+    def test_raise_error_terminate_preserves_backup(self):
+        """Ensure the BackupAndRestore directory is preserved when
+        TerminateOnNaN(raise_error=True) triggers.
+        """
+        tmpdir = self.get_temp_dir()
+        backup_dir = os.path.join(tmpdir, "backups")
+        os.makedirs(backup_dir, exist_ok=True)
+
+        fake_file = os.path.join(backup_dir, "checkpoint.txt")
+        with open(fake_file, "w") as f:
+            f.write("dummy checkpoint")
+
+        model = models.Sequential([layers.Dense(1, input_shape=(1,))])
+        model.compile(optimizer="sgd", loss="mse")
+
+        x_nan = np.array([[1.0]])
+        y_nan = np.array([[np.inf]])
+
+        raise_error_terminate_callback = TerminateOnNaN(raise_error=True)
+        backup_callback = BackupAndRestore(backup_dir=backup_dir)
+
+        # Monkeypatch BackupAndRestore to prevent cleanup on train_end.
+        backup_callback.on_train_end = lambda logs=None: None
+
+        # Training should raise a RuntimeError.
+        with self.assertRaises(RuntimeError):
+            model.fit(
+                x_nan,
+                y_nan,
+                epochs=1,
+                callbacks=[backup_callback, raise_error_terminate_callback],
+                verbose=0,
+            )
+
+        # Verify the backup directory still exists and the file inside it
+        # is untouched.
+        self.assertTrue(
+            os.path.exists(backup_dir),
+            f"Backup dir deleted: {backup_dir}",
+        )
+        self.assertTrue(
+            os.path.exists(fake_file),
+            "Backup file missing unexpectedly.",
+        )
+
+    @parameterized.named_parameters(
+        ("raise_error_false", False),
+        ("raise_error_true", True),
+    )
+    def test_normal_training_does_not_raise(self, raise_error):
+        """Test that TerminateOnNaN does not raise on normal training."""
+        model = models.Sequential([layers.Dense(1, input_shape=(1,))])
+        model.compile(optimizer="sgd", loss="mse")
+
+        x = np.array([[1.0], [2.0]])
+        y = np.array([[1.0], [2.0]])
+
+        callback = TerminateOnNaN(raise_error=raise_error)
+
+        # Should complete without raising a RuntimeError.
+        history = model.fit(x, y, epochs=2, callbacks=[callback], verbose=0)
+
+        # Should have completed 2 epochs.
+        self.assertEqual(len(history.history["loss"]), 2)
+
+    def test_raise_error_terminate_stops_on_later_batch(self):
+        """Ensure TerminateOnNaN(raise_error=True) stops training
+        if the NaN appears in a later batch.
+        """
+        model = models.Sequential([layers.Dense(1, input_shape=(1,))])
+        model.compile(optimizer="sgd", loss="mse")
+
+        # First batch: normal loss. Second batch: NaN/Inf loss.
+        x = np.array([[1.0], [2.0]])
+        y = np.array([[1.0], [np.inf]])  # NaN/Inf appears only in 2nd batch
+
+        callback = TerminateOnNaN(raise_error=True)
+
+        with self.assertRaises(RuntimeError) as exc:
+            model.fit(
+                x, y, epochs=1, batch_size=1, callbacks=[callback], verbose=0
+            )
+
+        self.assertTrue(any(f"batch {i}" in str(exc.exception) for i in [0, 1]))
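For a quick manual check of the contrast these tests assert, here is a standalone sketch (again not part of the commit, and assuming a Keras build that includes this change) that runs both modes back to back on the same degenerate data:

```python
import numpy as np
import keras

x = np.array([[1.0], [2.0]])
y = np.array([[np.inf], [np.inf]])

for raise_error in (False, True):
    model = keras.Sequential([keras.layers.Dense(1)])
    model.compile(optimizer="sgd", loss="mse")
    callback = keras.callbacks.TerminateOnNaN(raise_error=raise_error)
    try:
        # Graceful mode prints the invalid-loss message and returns;
        # raise_error mode surfaces the failure as a RuntimeError.
        model.fit(x, y, epochs=2, batch_size=1, callbacks=[callback], verbose=0)
        print(f"raise_error={raise_error}: fit() returned normally")
    except RuntimeError as err:
        print(f"raise_error={raise_error}: fit() raised: {err}")
```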
