1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38
def conv_backward(dZ, cache):
    """Backward pass of a 2-D convolution layer.

    Args:
        dZ: Gradient of the cost w.r.t. the conv layer's output, an array of
            shape (m, n_H, n_W, n_C).
        cache: Tuple ``(A_prev, W, b, hparameters)``. ``A_prev`` has shape
            (m, n_H_prev, n_W_prev, n_C_prev); ``W`` has shape
            (f, f, n_C_prev, n_C); ``b`` is indexed as a rank-4 array whose
            last axis has n_C entries (presumably (1, 1, 1, n_C) -- confirm
            against the forward pass); ``hparameters`` is a dict with integer
            ``"stride"`` and ``"pad"`` entries.

    Returns:
        Tuple ``(dA_prev, dW, db)``: gradients w.r.t. ``A_prev``, ``W`` and
        ``b``, with the same shapes as those arrays.
    """
    (A_prev, W, b, hparameters) = cache
    (m, n_H_prev, n_W_prev, n_C_prev) = A_prev.shape
    (f, f, n_C_prev, n_C) = W.shape
    stride = hparameters["stride"]
    pad = hparameters["pad"]
    (m, n_H, n_W, n_C) = dZ.shape

    dA_prev = np.zeros(A_prev.shape)
    dW = np.zeros(W.shape)
    db = np.zeros(b.shape)

    # Zero-pad only the two spatial axes; sample and channel axes are untouched.
    spatial_pad = ((0, 0), (pad, pad), (pad, pad), (0, 0))
    A_prev_pad = np.pad(A_prev, spatial_pad, mode="constant", constant_values=0)
    # Gradient accumulator in padded coordinates (starts at zero).
    dA_prev_pad = np.zeros(A_prev_pad.shape)

    for i in range(m):
        a_prev_pad = A_prev_pad[i]
        # Basic slicing yields a view, so "+=" below writes into dA_prev_pad.
        da_prev_pad = dA_prev_pad[i]
        for h in range(n_H):
            vert_start = stride * h
            vert_end = vert_start + f
            for w in range(n_W):
                horiz_start = stride * w
                horiz_end = horiz_start + f
                # The receptive field is the same for every output channel,
                # so extract it once per output position.
                a_slice = a_prev_pad[vert_start:vert_end, horiz_start:horiz_end, :]
                for c in range(n_C):
                    grad = dZ[i, h, w, c]
                    da_prev_pad[vert_start:vert_end, horiz_start:horiz_end, :] += (
                        W[:, :, :, c] * grad
                    )
                    dW[:, :, :, c] += a_slice * grad
                    db[:, :, :, c] += grad

    # Strip the padding with explicit end indices: a ``pad:-pad`` slice would
    # be empty when pad == 0 and the assignment would fail to broadcast.
    dA_prev[:, :, :, :] = dA_prev_pad[:, pad:pad + n_H_prev, pad:pad + n_W_prev, :]

    # Internal sanity check on the output shape.
    assert dA_prev.shape == (m, n_H_prev, n_W_prev, n_C_prev)

    return dA_prev, dW, db
|