From 40025a87cb6466e7606b4cd4695b7cf10e9e318f Mon Sep 17 00:00:00 2001 From: Mark Dokter Date: Mon, 22 Jul 2024 17:38:41 +0200 Subject: [PATCH] conv2d back dsl debug prints --- local/test-conv2d.daph | 2 ++ scripts/nn/layers/conv2d.daph | 10 +++++++++- scripts/nn/util.daph | 6 +++++- 3 files changed, 16 insertions(+), 2 deletions(-) diff --git a/local/test-conv2d.daph b/local/test-conv2d.daph index 01e1a19b0..84418ab73 100644 --- a/local/test-conv2d.daph +++ b/local/test-conv2d.daph @@ -29,8 +29,10 @@ print(W); #Xconv1, Hout1, Wout1 = conv2d.forward(X, W, b, C, Himg, Wimg, Hf, Wf, stride, stride, pad, pad); Xconv2, Hout2, Wout2 = conv2d_dsl.forward(X, W, b, C, Himg, Wimg, Hf, Wf, stride, stride, pad, pad); +#print("\n\n---Start---Convolution debug outputs:\n"); #dXconv1, dW1, db1 = conv2d.backward(Xconv1, Hout1, Wout1, X, W, b, C, Himg, Wimg, Hf, Wf, stride, stride, pad, pad); dXconv2, dW2, db2 = conv2d_dsl.backward(Xconv2, Hout2, Wout2, X, W, b, C, Himg, Wimg, Hf, Wf, stride, stride, pad, pad); +#print("\n\n---End-----Convolution debug outputs:\n"); print("\nX dims: ",0) ;print(nrow(X),0);print("x",0);print(ncol(X)); #print("Xconv1 dims: ",0) ;print(nrow(Xconv1),0);print("x",0);print(ncol(Xconv1)); diff --git a/scripts/nn/layers/conv2d.daph b/scripts/nn/layers/conv2d.daph index 7eaa72d93..cef9169a1 100644 --- a/scripts/nn/layers/conv2d.daph +++ b/scripts/nn/layers/conv2d.daph @@ -156,9 +156,17 @@ def backward(dout:matrix, Hout, Wout, X:matrix, W:matrix, b:matrix, C, Hin, Win, db = db + sum(doutn,0); # Compute dX - dXn_padded_cols = t(W) @ doutn; # shape (C*Hf*Wf, Hout*Wout) + tW = t(W); + #print(tW); + #print(doutn); + dXn_padded_cols = tW @ doutn; # shape (C*Hf*Wf, Hout*Wout) +#print(dXn_padded_cols); dXn_padded = util.col2im(dXn_padded_cols, C, Hin+2*padh, Win+2*padw, Hf, Wf, strideh, stridew, 0 /*"add"*/); + + #print(dXn_padded); + dXn = util.unpad_image(dXn_padded, Hin, Win, padh, padw); + dX[n,] = reshape(dXn, 1, C * Hin * Win); # reshape } return dX, dW, db; diff --git a/scripts/nn/util.daph b/scripts/nn/util.daph index ff2bd28bf..935167874 100644 --- a/scripts/nn/util.daph +++ b/scripts/nn/util.daph @@ -143,9 +143,13 @@ def col2im(img_cols:matrix, C, Hin, Win, Hf, Wf, strideh, stridew, reduction) -> for (wout in 0:Wout - 1) { # all output columns win_ = wout * stridew; # Extract a local patch of the input image corresponding spatially to the filter sizes. - img_patch = reshape(img_cols[,hout * Wout + wout], C, Hf * Wf); # zeros + p = hout * Wout + wout; + #print("patch: " + p); + img_patch = reshape(img_cols[,p], C, Hf * Wf); # zeros + #print(img_patch); for (c in 0:C - 1) { # all channels img_patch_slice = reshape(img_patch[c,], Hf, Wf); # reshape + #print(img_patch_slice); if (reduction == 0 /*"add"*/) { img_slice = fill(0.0, Hin, Win); img_slice[hin_:(hin_ + Hf), win_:(win_ + Wf)] = img_patch_slice;