@@ -89,19 +89,16 @@ pure module subroutine forward(self, input)
8989 ! of the input that correspond to the center of each window.
9090 istart = half_window + 1 ! TODO kernel_width
9191 jstart = half_window + 1 ! TODO kernel_height
92- iend = input_width - istart + 1
93- jend = input_height - jstart + 1
9492
95- ! convolution: do concurrent(i = istart:iend, j = jstart:jend)
96- convolution: do concurrent(i = 1 :self % width, j = 1 :self% height)
93+ convolution: do concurrent(i = 1 :self % width, j = 1 :self % height)
9794
9895 ! Start and end indices of the input data on the filter window
9996 ! iws and jws are also coincidentally the indices of the output matrix
100- iws = istart + self % stride(1 ) * (i-1 ) - half_window ! TODO kernel_width
101- iwe = min (iws + 2 * half_window, input_width) ! TODO kernel_width
97+ iws = istart + self % stride(1 ) * (i-1 ) - half_window ! TODO kernel_width
98+ iwe = min (iws + 2 * half_window, input_width) ! TODO kernel_width
10299
103- jws = jstart + self % stride(2 ) * (j-1 ) - half_window ! TODO kernel_height
104- jwe = min (jws + 2 * half_window, input_height) ! TODO kernel_height
100+ jws = jstart + self % stride(2 ) * (j-1 ) - half_window ! TODO kernel_height
101+ jwe = min (jws + 2 * half_window, input_height) ! TODO kernel_height
105102
106103 ! Compute the inner tensor product, sum(w_ij * x_ij), for each filter.
107104 do concurrent(n = 1 :self % filters)
@@ -166,27 +163,20 @@ pure module subroutine backward(self, input, gradient)
166163 k = 1 :self % channels, &
167164 i = 1 :self % width, &
168165 j = 1 :self % height &
169- ! i = istart:iend, &
170- ! j = jstart:jend &
171166 )
172167 ! Start and end indices of the input data on the filter window
173- ! iws = i - half_window ! TODO kernel_width
174- ! iwe = i + half_window ! TODO kernel_width
175- ! jws = j - half_window ! TODO kernel_height
176- ! jwe = j + half_window ! TODO kernel_height
177- iws = istart + self % stride(1 ) * (i-1 ) - half_window ! TODO kernel_width
178- iwe = min (iws + 2 * half_window, input_width) ! TODO kernel_width
168+ iws = istart + self % stride(1 ) * (i-1 ) - half_window ! TODO kernel_width
169+ iwe = min (iws + 2 * half_window, input_width) ! TODO kernel_width
179170
180- jws = jstart + self % stride(2 ) * (j-1 ) - half_window ! TODO kernel_height
181- jwe = min (jws + 2 * half_window, input_height) ! TODO kernel_height
171+ jws = jstart + self % stride(2 ) * (j-1 ) - half_window ! TODO kernel_height
172+ jwe = min (jws + 2 * half_window, input_height) ! TODO kernel_height
182173
183174 ! dL/dw = sum(gdz * x)
184- dw(n,k,1 :iwe- iws+1 ,1 :jwe- jws+1 ) = dw(n,k,1 :iwe- iws+1 ,1 :jwe- jws+1 ) &
185- + input(k,iws:iwe,jws:jwe) * gdz(n,i,j)
175+ dw(n,k,:,:) = dw(n,k,:,:) + input(k,iws:iwe,jws:jwe) * gdz(n,i,j)
186176
187177 ! dL/dx = sum(gdz * w)
188178 self % gradient(k,iws:iwe,jws:jwe) = self % gradient(k,iws:iwe,jws:jwe) &
189- + gdz(n,i,j) * self % kernel(n,k,1 :iwe - iws +1 , 1 :jwe - jws +1 )
179+ + gdz(n,i,j) * self % kernel(n,k,:,: )
190180
191181 end do
192182
0 commit comments