Skip to content

Commit e2d4410

Browse files
committed
Tidy up
1 parent 49802ed commit e2d4410

File tree

2 files changed

+12
-22
lines changed

2 files changed

+12
-22
lines changed

example/cnn_mnist.f90

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ program cnn_mnist
1212
real, allocatable :: validation_images(:,:), validation_labels(:)
1313
real, allocatable :: testing_images(:,:), testing_labels(:)
1414
integer :: n
15-
integer, parameter :: num_epochs = 250
15+
integer, parameter :: num_epochs = 20
1616

1717
call load_mnist(training_images, training_labels, &
1818
validation_images, validation_labels, &

src/nf/nf_conv2d_layer_submodule.f90

Lines changed: 11 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -89,19 +89,16 @@ pure module subroutine forward(self, input)
8989
! of the input that correspond to the center of each window.
9090
istart = half_window + 1 ! TODO kernel_width
9191
jstart = half_window + 1 ! TODO kernel_height
92-
iend = input_width - istart + 1
93-
jend = input_height - jstart + 1
9492

95-
! convolution: do concurrent(i = istart:iend, j = jstart:jend)
96-
convolution: do concurrent(i = 1:self % width, j = 1:self%height)
93+
convolution: do concurrent(i = 1:self % width, j = 1:self % height)
9794

9895
! Start and end indices of the input data on the filter window
9996
! iws and jws are also coincidentally the indices of the output matrix
100-
iws = istart + self %stride(1) * (i-1) - half_window ! TODO kernel_width
101-
iwe = min(iws + 2*half_window, input_width) ! TODO kernel_width
97+
iws = istart + self % stride(1) * (i-1) - half_window ! TODO kernel_width
98+
iwe = min(iws + 2*half_window, input_width) ! TODO kernel_width
10299

103-
jws = jstart + self %stride(2) * (j-1) - half_window ! TODO kernel_height
104-
jwe = min(jws + 2*half_window, input_height) ! TODO kernel_height
100+
jws = jstart + self % stride(2) * (j-1) - half_window ! TODO kernel_height
101+
jwe = min(jws + 2*half_window, input_height) ! TODO kernel_height
105102

106103
! Compute the inner tensor product, sum(w_ij * x_ij), for each filter.
107104
do concurrent(n = 1:self % filters)
@@ -166,27 +163,20 @@ pure module subroutine backward(self, input, gradient)
166163
k = 1:self % channels, &
167164
i = 1:self % width, &
168165
j = 1:self % height &
169-
!i = istart:iend, &
170-
!j = jstart:jend &
171166
)
172167
! Start and end indices of the input data on the filter window
173-
!iws = i - half_window ! TODO kernel_width
174-
!iwe = i + half_window ! TODO kernel_width
175-
!jws = j - half_window ! TODO kernel_height
176-
!jwe = j + half_window ! TODO kernel_height
177-
iws = istart + self %stride(1) * (i-1) - half_window ! TODO kernel_width
178-
iwe = min(iws + 2*half_window, input_width) ! TODO kernel_width
168+
iws = istart + self % stride(1) * (i-1) - half_window ! TODO kernel_width
169+
iwe = min(iws + 2*half_window, input_width) ! TODO kernel_width
179170

180-
jws = jstart + self %stride(2) * (j-1) - half_window ! TODO kernel_height
181-
jwe = min(jws + 2*half_window, input_height) ! TODO kernel_height
171+
jws = jstart + self % stride(2) * (j-1) - half_window ! TODO kernel_height
172+
jwe = min(jws + 2*half_window, input_height) ! TODO kernel_height
182173

183174
! dL/dw = sum(gdz * x)
184-
dw(n,k,1:iwe-iws+1,1:jwe-jws+1) = dw(n,k,1:iwe-iws+1,1:jwe-jws+1) &
185-
+ input(k,iws:iwe,jws:jwe) * gdz(n,i,j)
175+
dw(n,k,:,:) = dw(n,k,:,:) + input(k,iws:iwe,jws:jwe) * gdz(n,i,j)
186176

187177
! dL/dx = sum(gdz * w)
188178
self % gradient(k,iws:iwe,jws:jwe) = self % gradient(k,iws:iwe,jws:jwe) &
189-
+ gdz(n,i,j) * self % kernel(n,k,1:iwe-iws+1,1:jwe-jws+1)
179+
+ gdz(n,i,j) * self % kernel(n,k,:,:)
190180

191181
end do
192182

0 commit comments

Comments
 (0)