
Commit 8b5f3e1

Fix batch normalization training mode correctness
1 parent 9545b3c commit 8b5f3e1

File tree

2 files changed: +92 -11 lines changed

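For context, this is the behavior the fix targets: in train mode, torch.nn.BatchNorm2d updates its running statistics on every forward pass, so a faithful translation must surface the updated running mean and variance rather than only the normalized output. A minimal sketch (not part of the commit) demonstrating this in plain PyTorch:

import torch

# In train mode, BatchNorm2d folds each batch's statistics into its
# running mean/variance buffers (scaled by momentum).
bn = torch.nn.BatchNorm2d(3, track_running_stats=True)
bn.train()
x = torch.randn(2, 3, 4, 4)

before = bn.running_mean.clone()
bn(x)
assert not torch.equal(before, bn.running_mean)  # running stats changed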

python/tvm/relax/frontend/torch/exported_program_translator.py

Lines changed: 21 additions & 11 deletions
@@ -116,7 +116,7 @@ def _rsqrt(self, node: fx.Node) -> relax.Var:
 
     ########## Neural Network ##########
 
-    def _batch_norm(self, node: fx.Node, training: bool) -> relax.Var:
+    def _batch_norm(self, node: fx.Node, training: bool, return_tuple: bool = False) -> relax.Var:
         import numpy as np
 
         x = self.env[node.args[0]]
@@ -149,7 +149,7 @@ def _batch_norm(self, node: fx.Node, training: bool) -> relax.Var:
         if track_running_stats:
             training = True
 
-        return self.block_builder.emit(
+        bn_result = self.block_builder.emit(
             relax.op.nn.batch_norm(
                 data=x,
                 gamma=weight,
@@ -160,21 +160,31 @@ def _batch_norm(self, node: fx.Node, training: bool) -> relax.Var:
                 epsilon=eps,
                 momentum=momentum,
                 training=training,
-            )[0]
+            )
         )
 
+        if return_tuple:
+            return bn_result
+        else:
+            # Return only the output tensor (for backward compatibility)
+            return self.block_builder.emit(bn_result[0])
+
     def _batch_norm_legit_functional(self, node: fx.Node) -> relax.Var:
         # This method is called for batch_norm in training mode
-        # TODO does not have correctness!
-        # TODO we need to store the running mean and variance returned by the
-        # previous call to batch_norm and pass it again
-        training = True
-        return self._batch_norm(node, training)
+        bn_tuple = self._batch_norm(node, training=True, return_tuple=True)
+
+        x = self.env[node.args[0]]
+        channel = int(self.shape_of(x)[1])
+        dtype = x.struct_info.dtype
+
+        output = self.block_builder.emit(bn_tuple[0])
+        new_running_mean = self.block_builder.emit(bn_tuple[1])
+        reserve = self.block_builder.emit(relax.op.zeros(relax.ShapeExpr([channel]), dtype))
+
+        return self.block_builder.emit(relax.Tuple([output, new_running_mean, reserve, reserve, reserve]))
 
     def _batch_norm_legit_no_training(self, node: fx.Node) -> relax.Var:
-        # This method is called for batch_norm in eval mode
-        training = False
-        return self._batch_norm(node, training)
+        return self._batch_norm(node, training=False, return_tuple=False)
 
     def _batch_norm_legit_no_stats(self, node: fx.Node) -> relax.Var:
         import numpy as np
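For reference, relax.op.nn.batch_norm yields a 3-tuple of (normalized output, updated moving mean, updated moving variance); indexing [0] discarded the updated statistics, which is what return_tuple=True now exposes to callers. A standalone sketch (illustrative only; variable names are not from the commit, and the training flag is assumed to be available as in this repo's version of the op) of emitting the op with relax.BlockBuilder:

from tvm import relax

bb = relax.BlockBuilder()
x = relax.Var("x", relax.TensorStructInfo((2, 3, 4, 4), "float32"))
gamma = relax.Var("gamma", relax.TensorStructInfo((3,), "float32"))
beta = relax.Var("beta", relax.TensorStructInfo((3,), "float32"))
mean = relax.Var("mean", relax.TensorStructInfo((3,), "float32"))
var = relax.Var("var", relax.TensorStructInfo((3,), "float32"))

with bb.function("main", [x, gamma, beta, mean, var]):
    with bb.dataflow():
        # batch_norm returns (output, new_moving_mean, new_moving_var)
        bn = bb.emit(relax.op.nn.batch_norm(x, gamma, beta, mean, var, axis=1, training=True))
        out = bb.emit(bn[0])
        new_mean = bb.emit(bn[1])
        new_var = bb.emit(bn[2])
        gv = bb.emit_output(relax.Tuple([out, new_mean, new_var]))
    bb.emit_func_output(gv)

bb.get().show()  # prints TVMScript, similar in shape to expected3 in the test below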

tests/python/relax/test_frontend_from_exported_program.py

Lines changed: 71 additions & 0 deletions
@@ -1788,6 +1788,77 @@ def main(
     }
     verify_model(model_2, example_args, binding_2, expected2)
 
+    class BatchNorm2dTraining(Module):
+        def __init__(self):
+            super().__init__()
+            self.bn = torch.nn.BatchNorm2d(3, track_running_stats=True)
+
+        def forward(self, input):
+            return self.bn(input)
+
+    @tvm.script.ir_module
+    class expected3:
+        @R.function
+        def main(
+            input_1: R.Tensor((2, 3, 4, 4), dtype="float32"),
+            w1: R.Tensor((3,), dtype="float32"),
+            w2: R.Tensor((3,), dtype="float32"),
+            w3: R.Tensor((3,), dtype="float32"),
+            w4: R.Tensor((3,), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((3,), dtype="float32"), R.Tensor((3,), dtype="float32"), R.Tensor((), dtype="int64"), R.Tensor((2, 3, 4, 4), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((), dtype="int64") = R.add(R.const(0, "int64"), R.const(1, "int64"))
+                lv1: R.Tuple(
+                    R.Tensor((2, 3, 4, 4), dtype="float32"),
+                    R.Tensor((3,), dtype="float32"),
+                    R.Tensor((3,), dtype="float32"),
+                ) = R.nn.batch_norm(
+                    input_1,
+                    w1,
+                    w2,
+                    w3,
+                    w4,
+                    axis=1,
+                    epsilon=0.1,
+                    center=True,
+                    scale=True,
+                    momentum=1.0,
+                    training=True,
+                )
+                lv2: R.Tensor((2, 3, 4, 4), dtype="float32") = lv1[0]
+                lv3: R.Tensor((3,), dtype="float32") = lv1[1]
+                lv4: R.Tensor((3,), dtype="float32") = R.zeros(R.shape([3]), dtype="float32")
+                lv5: R.Tuple(
+                    R.Tensor((2, 3, 4, 4), dtype="float32"),
+                    R.Tensor((3,), dtype="float32"),
+                    R.Tensor((3,), dtype="float32"),
+                    R.Tensor((3,), dtype="float32"),
+                    R.Tensor((3,), dtype="float32"),
+                ) = (lv2, lv3, lv4, lv4, lv4)
+                lv6: R.Tensor((2, 3, 4, 4), dtype="float32") = lv5[0]
+                lv7: R.Tensor((3,), dtype="float32") = lv5[3]
+                lv8: R.Tensor((3,), dtype="float32") = lv5[4]
+                gv: R.Tuple(
+                    R.Tensor((3,), dtype="float32"),
+                    R.Tensor((3,), dtype="float32"),
+                    R.Tensor((), dtype="int64"),
+                    R.Tensor((2, 3, 4, 4), dtype="float32"),
+                ) = (lv7, lv8, lv, lv6)
+                R.output(gv)
+            return gv
+
+    example_args_train = (torch.randn(2, 3, 4, 4, dtype=torch.float32),)
+
+    model_3 = BatchNorm2dTraining()
+    model_3.train()  # Set to training mode
+    binding_3 = {
+        "w1": model_3.bn.weight.detach().numpy(),
+        "w2": model_3.bn.bias.detach().numpy(),
+        "w3": model_3.bn.running_mean.detach().numpy(),
+        "w4": model_3.bn.running_var.detach().numpy(),
+    }
+    verify_model(model_3, example_args_train, binding_3, expected3)
+
 
 def test_adaptive_avgpool1d():
     class AdaptiveAvgPool1d0(torch.nn.Module):
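To see the graph this test exercises, one can export the same module and inspect it: in training mode, recent PyTorch releases lower batch norm to the aten op that _batch_norm_legit_functional handles (the exact op name below is an assumption, not taken from the commit):

import torch

class BatchNorm2dTraining(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.bn = torch.nn.BatchNorm2d(3, track_running_stats=True)

    def forward(self, input):
        return self.bn(input)

model = BatchNorm2dTraining().train()
ep = torch.export.export(model, (torch.randn(2, 3, 4, 4),))
# Expect a call to _native_batch_norm_legit_functional in the printed graph
print(ep.graph)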
