Skip to content

Commit

Permalink
conv2d back dsl debug prints
Browse files Browse the repository at this point in the history
  • Loading branch information
corepointer committed Jul 22, 2024
1 parent 8678238 commit 40025a8
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 2 deletions.
2 changes: 2 additions & 0 deletions local/test-conv2d.daph
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,10 @@ print(W);
#Xconv1, Hout1, Wout1 = conv2d.forward(X, W, b, C, Himg, Wimg, Hf, Wf, stride, stride, pad, pad);
Xconv2, Hout2, Wout2 = conv2d_dsl.forward(X, W, b, C, Himg, Wimg, Hf, Wf, stride, stride, pad, pad);

#print("\n\n---Start---Convolution debug outputs:\n");
#dXconv1, dW1, db1 = conv2d.backward(Xconv1, Hout1, Wout1, X, W, b, C, Himg, Wimg, Hf, Wf, stride, stride, pad, pad);
dXconv2, dW2, db2 = conv2d_dsl.backward(Xconv2, Hout2, Wout2, X, W, b, C, Himg, Wimg, Hf, Wf, stride, stride, pad, pad);
#print("\n\n---End-----Convolution debug outputs:\n");

print("\nX dims: ",0) ;print(nrow(X),0);print("x",0);print(ncol(X));
#print("Xconv1 dims: ",0) ;print(nrow(Xconv1),0);print("x",0);print(ncol(Xconv1));
Expand Down
10 changes: 9 additions & 1 deletion scripts/nn/layers/conv2d.daph
Original file line number Diff line number Diff line change
Expand Up @@ -156,9 +156,17 @@ def backward(dout:matrix, Hout, Wout, X:matrix, W:matrix, b:matrix, C, Hin, Win,
db = db + sum(doutn,0);

# Compute dX
dXn_padded_cols = t(W) @ doutn; # shape (C*Hf*Wf, Hout*Wout)
tW = t(W);
#print(tW);
#print(doutn);
dXn_padded_cols = tW @ doutn; # shape (C*Hf*Wf, Hout*Wout)
#print(dXn_padded_cols);
dXn_padded = util.col2im(dXn_padded_cols, C, Hin+2*padh, Win+2*padw, Hf, Wf, strideh, stridew, 0 /*"add"*/);

#print(dXn_padded);

dXn = util.unpad_image(dXn_padded, Hin, Win, padh, padw);

dX[n,] = reshape(dXn, 1, C * Hin * Win); # reshape
}
return dX, dW, db;
Expand Down
6 changes: 5 additions & 1 deletion scripts/nn/util.daph
Original file line number Diff line number Diff line change
Expand Up @@ -143,9 +143,13 @@ def col2im(img_cols:matrix, C, Hin, Win, Hf, Wf, strideh, stridew, reduction) ->
for (wout in 0:Wout - 1) { # all output columns
win_ = wout * stridew;
# Extract a local patch of the input image corresponding spatially to the filter sizes.
img_patch = reshape(img_cols[,hout * Wout + wout], C, Hf * Wf); # zeros
p = hout * Wout + wout;
#print("patch: " + p);
img_patch = reshape(img_cols[,p], C, Hf * Wf); # zeros
#print(img_patch);
for (c in 0:C - 1) { # all channels
img_patch_slice = reshape(img_patch[c,], Hf, Wf); # reshape
#print(img_patch_slice);
if (reduction == 0 /*"add"*/) {
img_slice = fill(0.0, Hin, Win);
img_slice[hin_:(hin_ + Hf), win_:(win_ + Wf)] = img_patch_slice;
Expand Down

0 comments on commit 40025a8

Please sign in to comment.