diff --git a/unet_demo.py b/unet_demo.py index 5ae1bb0..fef85dc 100644 --- a/unet_demo.py +++ b/unet_demo.py @@ -17,11 +17,11 @@ def __init__(self, in_channels, out_channels, kernel_size, padding, self.act = nn.ReLU() - self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size, padding, - stride) + self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, + stride=stride, padding=padding) - self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size, - padding, stride) + self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=kernel_size, + stride=stride, padding=padding) def forward(self, x): x = self.act(self.conv1(x))