-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdnn_im2col.py
More file actions
105 lines (76 loc) · 3.81 KB
/
dnn_im2col.py
File metadata and controls
105 lines (76 loc) · 3.81 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
"""
This script is adapted, with modifications, from an assignment of Stanford CS231n.
Do not change this script.
If our script cannot run your code or the format is improper, your code will not be graded.
"""
import numpy as np
# from builtins import range
def get_im2col_indices(x_shape, field_height, field_width, padding=1, stride=1):
    """Build the index arrays that gather every receptive field of a 4-D input.

    Args:
        x_shape: 4-tuple ``(N, C, H, W)``; only ``C``, ``H`` and ``W`` are used.
        field_height: height of the sliding window.
        field_width: width of the sliding window.
        padding: padding per spatial side, used when computing the output size.
        stride: step between consecutive windows.

    Returns:
        A tuple ``(k, i, j)`` where
        ``k`` has shape ``(C * field_height * field_width, 1)`` and holds the
        channel index of every window element, and ``i`` / ``j`` have shape
        ``(C * field_height * field_width, out_height * out_width)`` and hold
        the row / column index of every window element at every window
        position.
    """
    _, C, H, W = x_shape
    out_height = int((H + 2 * padding - field_height) / stride + 1)
    out_width = int((W + 2 * padding - field_width) / stride + 1)

    # Offsets of each element inside one window, repeated for every channel.
    window_rows = np.tile(np.repeat(np.arange(field_height), field_width), C)
    window_cols = np.tile(np.arange(field_width), field_height * C)

    # Top-left corner of each window, enumerated row by row.
    start_rows = stride * np.repeat(np.arange(out_height), out_width)
    start_cols = stride * np.tile(np.arange(out_width), out_height)

    # Broadcasting (offset column) + (start row) gives one entry per
    # (window element, window position) pair.
    i = window_rows[:, np.newaxis] + start_rows[np.newaxis, :]
    j = window_cols[:, np.newaxis] + start_cols[np.newaxis, :]

    # Channel index for every window element, as a column vector.
    k = np.repeat(np.arange(C), field_height * field_width)[:, np.newaxis]
    return k, i, j
def im2col_indices(x, field_height, field_width, padding=1, stride=1):
    """Rearrange the sliding windows of a 4-D array into matrix columns.

    Args:
        x: array indexed as ``(N, C, H, W)``.
        field_height: height of the sliding window.
        field_width: width of the sliding window.
        padding: number of zeros added on each side of both spatial axes.
        stride: step between consecutive windows.

    Returns:
        A 2-D array with ``field_height * field_width * C`` rows. Each column
        is one flattened window; columns are ordered by window position
        first and by sample index second.
    """
    # Zero-pad only the two spatial axes.
    pad_spec = ((0, 0), (0, 0), (padding, padding), (padding, padding))
    x_padded = np.pad(x, pad_spec, mode='constant')

    k, i, j = get_im2col_indices(
        x.shape, field_height, field_width, padding, stride
    )

    # Fancy indexing yields (N, C * fh * fw, number of window positions).
    patches = x_padded[:, k, i, j]

    channels = x.shape[1]
    n_rows = field_height * field_width * channels
    # Move the sample axis last so that reshape interleaves samples per position.
    return np.transpose(patches, (1, 2, 0)).reshape(n_rows, -1)
def col2im_indices(cols, x_shape, field_height=3, field_width=3, padding=1, stride=1):
    """Scatter-add window columns back into an array of shape ``x_shape``.

    Where windows overlap, their contributions are summed.

    Args:
        cols: 2-D array that can be reshaped to
            ``(C * field_height * field_width, -1, N)``.
        x_shape: target shape ``(N, C, H, W)``.
        field_height: height of the sliding window.
        field_width: width of the sliding window.
        padding: padding per spatial side that was used to build ``cols``.
        stride: step between consecutive windows.

    Returns:
        Array of shape ``(N, C, H, W)`` with dtype ``cols.dtype``. When
        ``padding`` is non-zero the padded border is cut off before returning.
    """
    N, C, H, W = x_shape
    padded_shape = (N, C, H + 2 * padding, W + 2 * padding)
    x_padded = np.zeros(padded_shape, dtype=cols.dtype)

    k, i, j = get_im2col_indices(
        x_shape, field_height, field_width, padding, stride
    )

    # Split the interleaved sample axis back out and bring it to the front.
    n_rows = C * field_height * field_width
    per_sample = cols.reshape(n_rows, -1, N).transpose(2, 0, 1)

    # np.add.at accumulates on repeated indices, which overlapping windows need.
    np.add.at(x_padded, (slice(None), k, i, j), per_sample)

    # A slice like 0:-0 would be empty, so the unpadded case returns directly.
    return (
        x_padded
        if padding == 0
        else x_padded[:, :, padding:-padding, padding:-padding]
    )
def get_maxpool_im2col_indices(x_shape, field_height, field_width, padding=1, stride=1):
    """Build row/column index arrays for every pooling window.

    Unlike ``get_im2col_indices`` the window offsets are not repeated per
    channel and no channel index is returned.

    Args:
        x_shape: 4-tuple ``(N, C, H, W)``; only ``H`` and ``W`` are used.
        field_height: height of the pooling window.
        field_width: width of the pooling window.
        padding: padding per spatial side, used when computing the output size.
        stride: step between consecutive windows.

    Returns:
        A tuple ``(i, j)`` of arrays with shape
        ``(field_height * field_width, out_height * out_width)`` holding the
        row and column index of every window element at every window position.
    """
    _, _, H, W = x_shape
    out_height = int((H + 2 * padding - field_height) / stride + 1)
    out_width = int((W + 2 * padding - field_width) / stride + 1)

    # Offsets of each element inside a single window.
    window_rows = np.repeat(np.arange(field_height), field_width)
    window_cols = np.tile(np.arange(field_width), field_height)

    # Top-left corner of each window, enumerated row by row.
    start_rows = stride * np.repeat(np.arange(out_height), out_width)
    start_cols = stride * np.tile(np.arange(out_width), out_height)

    i = window_rows[:, np.newaxis] + start_rows[np.newaxis, :]
    j = window_cols[:, np.newaxis] + start_cols[np.newaxis, :]
    return i, j
def maxpool_im2col_indices(x, field_height, field_width, padding=0, stride=1):
    """Compute the maximum of every pooling window and where it occurs.

    Args:
        x: array indexed as ``(N, C, H, W)``.
        field_height: height of the pooling window.
        field_width: width of the pooling window.
        padding: number of zeros added on each side of both spatial axes.
        stride: step between consecutive windows.

    Returns:
        A tuple ``(max_cols, argmax_cols)``, both of shape
        ``(N, C, number of window positions)``. ``argmax_cols`` is the
        position of the maximum inside the flattened window.

    Note:
        The padding value is zero and padded cells take part in the maximum.
    """
    # Zero-pad only the two spatial axes.
    pad_spec = ((0, 0), (0, 0), (padding, padding), (padding, padding))
    x_padded = np.pad(x, pad_spec, mode='constant')

    i, j = get_maxpool_im2col_indices(
        x.shape, field_height, field_width, padding, stride
    )

    # Shape (N, C, field_height * field_width, number of window positions).
    windows = x_padded[:, :, i, j]

    # Reduce over the window-element axis.
    return windows.max(axis=2), windows.argmax(axis=2)
def maxpool_col2im_indices(grad, argmax_cols, x_shape, field_height=3, field_width=3, padding=0, stride=1):
    """Scatter-add ``grad`` onto the positions that won each pooling window.

    Args:
        grad: array with one value per ``(n, c, window position)``; it is
            flattened in that order.
        argmax_cols: index of the maximum inside each flattened window, with
            one entry per ``(n, c, window position)`` in the same order.
        x_shape: target shape ``(N, C, H, W)``.
        field_height: height of the pooling window.
        field_width: width of the pooling window.
        padding: padding per spatial side that was used when pooling.
        stride: step between consecutive windows.

    Returns:
        Array of shape ``(N, C, H, W)`` with dtype ``grad.dtype``. When
        ``padding`` is non-zero the padded border is cut off before returning.
    """
    N, C, H, W = x_shape
    padded_shape = (N, C, H + 2 * padding, W + 2 * padding)
    x_padded = np.zeros(padded_shape, dtype=grad.dtype)

    i, j = get_maxpool_im2col_indices(
        x_shape, field_height, field_width, padding, stride
    )
    map_size = i.shape[1]

    # One winning window element per (n, c, window position), flattened.
    winners = argmax_cols.reshape(-1)
    # Window position of each flattened entry; the pattern repeats per (n, c).
    positions = np.tile(np.arange(map_size), N * C)

    # Translate (window element, window position) into padded coordinates.
    max_i = i[winners, positions]
    max_j = j[winners, positions]

    # Sample and channel index of each flattened entry.
    max_n = np.repeat(np.arange(N), map_size * C)
    max_c = np.tile(np.repeat(np.arange(C), map_size), N)

    # np.add.at accumulates when several windows share the same winning cell.
    np.add.at(x_padded, (max_n, max_c, max_i, max_j), grad.reshape(-1))

    # A slice like 0:-0 would be empty, so the unpadded case returns directly.
    return (
        x_padded
        if padding == 0
        else x_padded[:, :, padding:-padding, padding:-padding]
    )