Skip to content

Commit 49802ed

Browse files
committed
Fix prior bug in the accumulation of gradients in the conv2d backward pass
1 parent c89c36f commit 49802ed

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

src/nf/nf_conv2d_layer_submodule.f90

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -180,12 +180,13 @@ pure module subroutine backward(self, input, gradient)
180180
jws = jstart + self %stride(2) * (j-1) - half_window ! TODO kernel_height
181181
jwe = min(jws + 2*half_window, input_height) ! TODO kernel_height
182182

183-
! dL/dw = sum(dL/dy * sigma'(z) * x)
184-
dw(n,k,:,:) = dw(n,k,:,:) + input(k,iws:iwe,jws:jwe) * gdz(n,iws:iwe,jws:jwe)
183+
! dL/dw = sum(gdz * x)
184+
dw(n,k,1:iwe-iws+1,1:jwe-jws+1) = dw(n,k,1:iwe-iws+1,1:jwe-jws+1) &
185+
+ input(k,iws:iwe,jws:jwe) * gdz(n,i,j)
185186

186-
! dL/dx = dL/dy * sigma'(z) .inner. w
187+
! dL/dx = sum(gdz * w)
187188
self % gradient(k,iws:iwe,jws:jwe) = self % gradient(k,iws:iwe,jws:jwe) &
188-
+ gdz(n,iws:iwe,jws:jwe) * self % kernel(n,k,1:iwe-iws+1,1:jwe-jws+1)
189+
+ gdz(n,i,j) * self % kernel(n,k,1:iwe-iws+1,1:jwe-jws+1)
189190

190191
end do
191192

0 commit comments

Comments
 (0)