34 changes: 23 additions & 11 deletions python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -116,7 +116,7 @@ def _rsqrt(self, node: fx.Node) -> relax.Var:
 
     ########## Neural Network ##########
 
-    def _batch_norm(self, node: fx.Node, training: bool) -> relax.Var:
+    def _batch_norm(self, node: fx.Node, training: bool, return_tuple: bool = False) -> relax.Var:
         import numpy as np
 
         x = self.env[node.args[0]]
@@ -149,7 +149,7 @@ def _batch_norm(self, node: fx.Node, training: bool) -> relax.Var:
         if track_running_stats:
             training = True
 
-        return self.block_builder.emit(
+        bn_result = self.block_builder.emit(
             relax.op.nn.batch_norm(
                 data=x,
                 gamma=weight,
@@ -160,21 +160,33 @@ def _batch_norm(self, node: fx.Node, training: bool) -> relax.Var:
                 epsilon=eps,
                 momentum=momentum,
                 training=training,
-            )[0]
+            )
         )
 
+        if return_tuple:
+            return bn_result
+        else:
+            # Return only the output tensor (for backward compatibility)
+            return self.block_builder.emit(bn_result[0])
+
     def _batch_norm_legit_functional(self, node: fx.Node) -> relax.Var:
-        # This method is called for batch_norm in training mode
-        # TODO does not have correctness!
-        # TODO we need to store the running mean and variance returned by the
-        # previous call to batch_norm and pass it again
-        training = True
-        return self._batch_norm(node, training)
+        bn_tuple = self._batch_norm(node, training=True, return_tuple=True)
+
+        x = self.env[node.args[0]]
+        channel = int(self.shape_of(x)[1])
+        dtype = x.struct_info.dtype
+
+        output = self.block_builder.emit(bn_tuple[0])
+        new_running_mean = self.block_builder.emit(bn_tuple[1])
+        reserve = self.block_builder.emit(relax.op.zeros(relax.ShapeExpr([channel]), dtype))
+
+        return self.block_builder.emit(
+            relax.Tuple([output, new_running_mean, reserve, reserve, reserve])
+        )
 
     def _batch_norm_legit_no_training(self, node: fx.Node) -> relax.Var:
         # This method is called for batch_norm in eval mode
-        training = False
-        return self._batch_norm(node, training)
+        return self._batch_norm(node, training=False, return_tuple=False)
 
     def _batch_norm_legit_no_stats(self, node: fx.Node) -> relax.Var:
         import numpy as np
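A note on what the new flag enables: `relax.op.nn.batch_norm` evaluates to a three-element tuple (normalized output, updated moving mean, updated moving variance), which is why the old code indexed the call with `[0]`. With `return_tuple=True`, `_batch_norm_legit_functional` can keep all three results and repack them into the five outputs PyTorch's functional batch-norm op expects. Below is a minimal standalone sketch of that emit-and-unpack pattern; the shapes, variable names, and the zeros-as-reserve placeholders mirror the translator code above but are illustrative only, not part of the PR.

```python
# Hypothetical standalone sketch of the emit/unpack pattern used in the
# translator above. Shapes, names, and attribute values are illustrative.
from tvm import relax

bb = relax.BlockBuilder()
x = relax.Var("x", relax.TensorStructInfo((2, 3, 4, 4), "float32"))
gamma = relax.Var("gamma", relax.TensorStructInfo((3,), "float32"))
beta = relax.Var("beta", relax.TensorStructInfo((3,), "float32"))
mean = relax.Var("mean", relax.TensorStructInfo((3,), "float32"))
var = relax.Var("var", relax.TensorStructInfo((3,), "float32"))

with bb.function("main", [x, gamma, beta, mean, var]):
    with bb.dataflow():
        # batch_norm yields (output, moving_mean, moving_var) as one tuple value
        bn = bb.emit(
            relax.op.nn.batch_norm(
                x, gamma, beta, mean, var, axis=1, epsilon=1e-5, training=True
            )
        )
        out = bb.emit(bn[0])       # normalized output
        new_mean = bb.emit(bn[1])  # updated running mean
        # Pad to a 5-tuple the way _batch_norm_legit_functional does,
        # using zeros as placeholder "reserve" tensors.
        reserve = bb.emit(relax.op.zeros(relax.ShapeExpr([3]), "float32"))
        gv = bb.emit_output(relax.Tuple([out, new_mean, reserve, reserve, reserve]))
    bb.emit_func_output(gv)

bb.get().show()  # prints the emitted Relax IRModule
```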
76 changes: 76 additions & 0 deletions tests/python/relax/test_frontend_from_exported_program.py
@@ -1788,6 +1788,82 @@ def main(
         }
         verify_model(model_2, example_args, binding_2, expected2)
 
+    class BatchNorm2dTraining(Module):
+        def __init__(self):
+            super().__init__()
+            self.bn = torch.nn.BatchNorm2d(3, track_running_stats=True)
+
+        def forward(self, input):
+            return self.bn(input)
+
+    @tvm.script.ir_module
+    class expected3:
+        @R.function
+        def main(
+            input_1: R.Tensor((2, 3, 4, 4), dtype="float32"),
+            w1: R.Tensor((3,), dtype="float32"),
+            w2: R.Tensor((3,), dtype="float32"),
+            w3: R.Tensor((3,), dtype="float32"),
+            w4: R.Tensor((3,), dtype="float32"),
+        ) -> R.Tuple(
+            R.Tensor((3,), dtype="float32"),
+            R.Tensor((3,), dtype="float32"),
+            R.Tensor((), dtype="int64"),
+            R.Tensor((2, 3, 4, 4), dtype="float32"),
+        ):
+            with R.dataflow():
+                lv: R.Tensor((), dtype="int64") = R.add(R.const(0, "int64"), R.const(1, "int64"))
+                lv1: R.Tuple(
+                    R.Tensor((2, 3, 4, 4), dtype="float32"),
+                    R.Tensor((3,), dtype="float32"),
+                    R.Tensor((3,), dtype="float32"),
+                ) = R.nn.batch_norm(
+                    input_1,
+                    w1,
+                    w2,
+                    w3,
+                    w4,
+                    axis=1,
+                    epsilon=0.1,
+                    center=True,
+                    scale=True,
+                    momentum=1.0,
+                    training=True,
+                )
+                lv2: R.Tensor((2, 3, 4, 4), dtype="float32") = lv1[0]
+                lv3: R.Tensor((3,), dtype="float32") = lv1[1]
+                lv4: R.Tensor((3,), dtype="float32") = R.zeros(R.shape([3]), dtype="float32")
+                lv5: R.Tuple(
+                    R.Tensor((2, 3, 4, 4), dtype="float32"),
+                    R.Tensor((3,), dtype="float32"),
+                    R.Tensor((3,), dtype="float32"),
+                    R.Tensor((3,), dtype="float32"),
+                    R.Tensor((3,), dtype="float32"),
+                ) = (lv2, lv3, lv4, lv4, lv4)
+                lv6: R.Tensor((2, 3, 4, 4), dtype="float32") = lv5[0]
+                lv7: R.Tensor((3,), dtype="float32") = lv5[3]
+                lv8: R.Tensor((3,), dtype="float32") = lv5[4]
+                gv: R.Tuple(
+                    R.Tensor((3,), dtype="float32"),
+                    R.Tensor((3,), dtype="float32"),
+                    R.Tensor((), dtype="int64"),
+                    R.Tensor((2, 3, 4, 4), dtype="float32"),
+                ) = (lv7, lv8, lv, lv6)
+                R.output(gv)
+            return gv
+
+    example_args_train = (torch.randn(2, 3, 4, 4, dtype=torch.float32),)
+
+    model_3 = BatchNorm2dTraining()
+    model_3.train()  # Set to training mode
+    binding_3 = {
+        "w1": model_3.bn.weight.detach().numpy(),
+        "w2": model_3.bn.bias.detach().numpy(),
+        "w3": model_3.bn.running_mean.detach().numpy(),
+        "w4": model_3.bn.running_var.detach().numpy(),
+    }
+    verify_model(model_3, example_args_train, binding_3, expected3)
+
 
 def test_adaptive_avgpool1d():
     class AdaptiveAvgPool1d0(torch.nn.Module):
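The new test drives this path end to end through `verify_model`. To see what `torch.export` actually hands the translator in training mode, a quick standalone check (illustrative, not part of the PR) is to print the exported graph: the `BatchNorm2d` call should appear as `aten._native_batch_norm_legit_functional`, a five-output op, which is the node the repacked 5-tuple above satisfies.

```python
# Standalone check (illustrative, not from the PR): inspect how torch.export
# lowers BatchNorm2d in training mode.
import torch


class BN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.bn = torch.nn.BatchNorm2d(3, track_running_stats=True)

    def forward(self, x):
        return self.bn(x)


model = BN().train()  # training mode, running stats tracked
ep = torch.export.export(model, (torch.randn(2, 3, 4, 4),))
# Expect a call to aten._native_batch_norm_legit_functional with five
# outputs in the printed graph.
print(ep.graph_module.graph)
```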