diff --git a/src/runtime/local/context/CUDAContext.cpp b/src/runtime/local/context/CUDAContext.cpp
index e0bbf513c..ab742e73c 100644
--- a/src/runtime/local/context/CUDAContext.cpp
+++ b/src/runtime/local/context/CUDAContext.cpp
@@ -61,6 +61,10 @@ void CUDAContext::init() {
     CHECK_CUDNN(cudnnCreateActivationDescriptor(&activation_desc));
     CHECK_CUDNN(cudnnCreateConvolutionDescriptor(&conv_desc));
     CHECK_CUDNN(cudnnCreateFilterDescriptor(&filter_desc));
+
+    CHECK_CUDNN(cudnnCreateTensorDescriptor(&dy_tensor_desc));
+    CHECK_CUDNN(cudnnCreateTensorDescriptor(&bn_scale_bias_tensor_desc));
+
     CHECK_CUSOLVER(cusolverDnCreate(&cusolver_handle));
     CHECK_CUDART(cudaStreamCreateWithFlags(&cusolver_stream, cudaStreamNonBlocking));
diff --git a/src/runtime/local/context/CUDAContext.h b/src/runtime/local/context/CUDAContext.h
index 08c3e8e7e..1266e5035 100644
--- a/src/runtime/local/context/CUDAContext.h
+++ b/src/runtime/local/context/CUDAContext.h
@@ -104,6 +104,8 @@ class CUDAContext final : public IContext {
     cudnnConvolutionDescriptor_t conv_desc{};
     cudnnBatchNormMode_t bn_mode = CUDNN_BATCHNORM_SPATIAL;
 
+    cudnnTensorDescriptor_t dy_tensor_desc{}, bn_scale_bias_tensor_desc{};
+
     // A block size of 256 works well in many cases.
     // Putting it here to avoid hard coding things elsewhere.
     const uint32_t default_block_size = 256;
diff --git a/src/runtime/local/kernels/CUDA/BatchNorm.cpp b/src/runtime/local/kernels/CUDA/BatchNorm.cpp
index f24bfc848..fdcca2daa 100644
--- a/src/runtime/local/kernels/CUDA/BatchNorm.cpp
+++ b/src/runtime/local/kernels/CUDA/BatchNorm.cpp
@@ -16,6 +16,7 @@
 
 #include "BatchNorm.h"
 #include <runtime/local/datastructures/AllocationDescriptorCUDA.h>
+#include <cmath>
 
 namespace CUDA::BatchNorm {
     template<typename DTRes, typename DTArg>
@@ -52,7 +53,66 @@ namespace CUDA::BatchNorm {
                 d_gamma, d_beta, d_ema_mean, d_ema_var, eps));
     }
 
+    template<typename DTRes, typename DTArg>
+    void Backward<DTRes, DTArg>::apply(DTRes *&dX, DTRes *&dGamma, DTRes *&dBeta,
+                                       const DTArg *mean, const DTArg *invVar,
+                                       const DTArg *in, const DTArg *dout,
+                                       const DTArg *gamma, const typename DTArg::VT eps, DCTX(dctx))
+    {
+        const size_t deviceID = 0; // TODO: multi-device support
+        auto ctx = CUDAContext::get(dctx, deviceID);
+        AllocationDescriptorCUDA alloc_desc(dctx, deviceID);
+        using VT = typename DTRes::VT;
+        const size_t N = in->getNumRows();
+        const size_t CHW = in->getNumCols();
+        const size_t C = gamma->getNumRows();
+        const size_t HW = CHW / C;
+        // The (N x CHW) input is interpreted as NCHW with square images (H == W).
+        auto H = static_cast<int>(std::sqrt(HW));
+
+        VT alphaDataDiff = 1.0;
+        VT betaDataDiff = 0.0;
+        VT alphaParamDiff = 1.0;
+        VT betaParamDiff = 0.0;
+
+        const VT* d_mean = mean->getValues(&alloc_desc);
+        const VT* d_invVar = invVar->getValues(&alloc_desc);
+        const VT* d_in = in->getValues(&alloc_desc);
+        const VT* d_gamma = gamma->getValues(&alloc_desc);
+        const VT* d_dout = dout->getValues(&alloc_desc);
+
+        CHECK_CUDNN(cudnnSetTensor4dDescriptor(ctx->src_tensor_desc, ctx->tensor_format, ctx->getCUDNNDataType<VT>(), N, C, H, H));
+        CHECK_CUDNN(cudnnSetTensor4dDescriptor(ctx->dy_tensor_desc, ctx->tensor_format, ctx->getCUDNNDataType<VT>(), N, C, H, H));
+        CHECK_CUDNN(cudnnSetTensor4dDescriptor(ctx->dst_tensor_desc, ctx->tensor_format, ctx->getCUDNNDataType<VT>(), N, C, H, H));
+        CHECK_CUDNN(cudnnDeriveBNTensorDescriptor(ctx->bn_scale_bias_tensor_desc, ctx->src_tensor_desc, ctx->bn_mode));
+
+        if (dX == nullptr)
+            dX = DataObjectFactory::create<DenseMatrix<VT>>(N, CHW, false, &alloc_desc);
+        if (dGamma == nullptr)
+            dGamma = DataObjectFactory::create<DenseMatrix<VT>>(C, 1, false, &alloc_desc);
+        if (dBeta == nullptr)
+            dBeta = DataObjectFactory::create<DenseMatrix<VT>>(C, 1, false, &alloc_desc);
+
+        VT* d_dX = dX->getValues(&alloc_desc);
+        VT* d_dGamma = dGamma->getValues(&alloc_desc);
+        VT* d_dBeta = dBeta->getValues(&alloc_desc);
+
+        CHECK_CUDNN(cudnnBatchNormalizationBackward(ctx->getCUDNNHandle(),
+                ctx->bn_mode,
+                &alphaDataDiff, &betaDataDiff, &alphaParamDiff, &betaParamDiff,
+                ctx->src_tensor_desc, d_in,
+                ctx->dy_tensor_desc, d_dout,
+                ctx->dst_tensor_desc, d_dX,
+                ctx->bn_scale_bias_tensor_desc, d_gamma, d_dGamma, d_dBeta,
+                eps,
+                d_mean, d_invVar));
+    }
+
     template struct Forward<DenseMatrix<float>, DenseMatrix<float>>;
     template struct Forward<DenseMatrix<double>, DenseMatrix<double>>;
+
+    template struct Backward<DenseMatrix<float>, DenseMatrix<float>>;
+    template struct Backward<DenseMatrix<double>, DenseMatrix<double>>;
 }
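For orientation, the `cudnnBatchNormalizationBackward` call above delegates the standard training-mode batch norm gradients with saved per-channel statistics to cuDNN. A minimal CPU sketch of that per-channel math (an illustration, not code from this patch; `batchNormBackwardRef` is a hypothetical helper):

```cpp
#include <cstddef>
#include <vector>

// Reference batch norm backward over NCHW data flattened per channel:
//   xhat   = (x - mean_c) * invVar_c, with invVar_c = 1/sqrt(var_c + eps)
//   dBeta  = sum(dy),  dGamma = sum(dy * xhat)
//   dX     = gamma_c * invVar_c / m * (m*dy - dBeta - xhat*dGamma), m = N*H*W
void batchNormBackwardRef(const std::vector<double>& x, const std::vector<double>& dy,
                          const std::vector<double>& gamma, const std::vector<double>& mean,
                          const std::vector<double>& invVar, std::vector<double>& dX,
                          std::vector<double>& dGamma, std::vector<double>& dBeta,
                          std::size_t N, std::size_t C, std::size_t HW) {
    const double m = static_cast<double>(N * HW);
    dX.assign(x.size(), 0.0);
    dGamma.assign(C, 0.0);
    dBeta.assign(C, 0.0);
    for (std::size_t c = 0; c < C; ++c) {
        // First pass: accumulate the two per-channel reductions.
        for (std::size_t n = 0; n < N; ++n)
            for (std::size_t i = 0; i < HW; ++i) {
                const std::size_t idx = (n * C + c) * HW + i;
                const double xhat = (x[idx] - mean[c]) * invVar[c];
                dBeta[c] += dy[idx];
                dGamma[c] += dy[idx] * xhat;
            }
        // Second pass: input gradient using those reductions.
        for (std::size_t n = 0; n < N; ++n)
            for (std::size_t i = 0; i < HW; ++i) {
                const std::size_t idx = (n * C + c) * HW + i;
                const double xhat = (x[idx] - mean[c]) * invVar[c];
                dX[idx] = gamma[c] * invVar[c] / m * (m * dy[idx] - dBeta[c] - xhat * dGamma[c]);
            }
    }
}
```

With the test data further down (where dOut is an affine function of the normalized input), the two reduction terms nearly cancel in dX, which is why the expected values there are on the order of 1e-5.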
diff --git a/src/runtime/local/kernels/CUDA/BatchNorm.h b/src/runtime/local/kernels/CUDA/BatchNorm.h
index 23d970f58..a3c9accf7 100644
--- a/src/runtime/local/kernels/CUDA/BatchNorm.h
+++ b/src/runtime/local/kernels/CUDA/BatchNorm.h
@@ -23,9 +23,18 @@
 #include "HostUtils.h"
 
 namespace CUDA::BatchNorm {
+
     template<typename DTRes, typename DTArg>
     struct Forward {
         static void apply(DTRes *&res, const DTArg *data, const DTArg *gamma, const DTArg *beta,
                           const DTArg *ema_mean, const DTArg *ema_var, typename DTArg::VT eps, DCTX(dctx));
     };
+
+    template<typename DTRes, typename DTArg>
+    struct Backward {
+        static void apply(DTRes *&dX, DTRes *&dGamma, DTRes *&dBeta,
+                          const DTArg *mean, const DTArg *invVar,
+                          const DTArg *in, const DTArg *dout,
+                          const DTArg *gamma, const typename DTArg::VT eps, DCTX(dctx));
+    };
 }
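One subtlety of the new `Backward` interface: `invVar` is the saved inverse standard deviation that cuDNN calls `savedInvVariance`, i.e. `1/sqrt(var + eps)`, not the raw batch variance. The tests below construct it exactly that way (`1 / std::sqrt(1.25 + 1e-5)`). A caller holding the plain variance would convert per channel, e.g. (hypothetical helper, not part of the patch):

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// Convert per-channel batch variances to the inverse standard deviations
// expected by the Backward kernel's invVar argument.
std::vector<float> toInvStdDev(const std::vector<float>& var, float eps) {
    std::vector<float> invVar(var.size());
    for (std::size_t c = 0; c < var.size(); ++c)
        invVar[c] = 1.0f / std::sqrt(var[c] + eps);
    return invVar;
}
```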
diff --git a/src/runtime/local/kernels/CUDA/Pooling.cpp b/src/runtime/local/kernels/CUDA/Pooling.cpp
index daf8ad585..a28b4f070 100644
--- a/src/runtime/local/kernels/CUDA/Pooling.cpp
+++ b/src/runtime/local/kernels/CUDA/Pooling.cpp
@@ -61,10 +61,80 @@ namespace CUDA::NN::Pooling {
                 d_input, &blend_beta, ctx->dst_tensor_desc, d_res));
     }
 
+    template<template<typename> class OP, typename DTRes, typename DTArg>
+    void Backward<OP, DTRes, DTArg>::apply(DTRes *&res,
+            const DTArg *input, const DTArg *output, const DTArg *dOut,
+            const size_t batch_size, const size_t num_channels,
+            const size_t img_h, const size_t img_w,
+            const size_t pool_h, const size_t pool_w,
+            const size_t stride_h, const size_t stride_w,
+            const size_t pad_h, const size_t pad_w,
+            DCTX(dctx))
+    {
+        const size_t deviceID = 0; // TODO: multi-device support
+        auto ctx = CUDAContext::get(dctx, deviceID);
+        AllocationDescriptorCUDA alloc_desc(dctx, deviceID);
+
+        using VT = typename DTRes::VT;
+        const VT blend_alpha = 1;
+        const VT blend_beta = 0;
+
+        CHECK_CUDNN(cudnnSetPooling2dDescriptor(ctx->pooling_desc,
+                OP<VT>::isMAX() ? CUDNN_POOLING_MAX : CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING,
+                CUDNN_PROPAGATE_NAN,
+                pool_h, pool_w,
+                pad_h, pad_w,
+                stride_h, stride_w));
+
+        const VT* d_input = input->getValues(&alloc_desc);
+        CHECK_CUDNN(cudnnSetTensor4dDescriptor(ctx->src_tensor_desc,
+                ctx->tensor_format,
+                ctx->getCUDNNDataType<VT>(),
+                batch_size, num_channels, img_h, img_w));
+
+        const int tensorDims = 4;
+        int tensorOutputDimA[tensorDims];
+        CHECK_CUDNN(cudnnGetPoolingNdForwardOutputDim(ctx->pooling_desc, ctx->src_tensor_desc, tensorDims,
+                tensorOutputDimA));
+
+        int n = tensorOutputDimA[0]; int c = tensorOutputDimA[1];
+        int h = tensorOutputDimA[2]; int w = tensorOutputDimA[3];
+
+        const VT* d_output = output->getValues(&alloc_desc);
+        CHECK_CUDNN(cudnnSetTensor4dDescriptor(ctx->dst_tensor_desc,
+                ctx->tensor_format,
+                ctx->getCUDNNDataType<VT>(),
+                n, c, h, w));
+
+        const VT* d_dOut = dOut->getValues(&alloc_desc);
+        CHECK_CUDNN(cudnnSetTensor4dDescriptor(ctx->dy_tensor_desc,
+                ctx->tensor_format,
+                ctx->getCUDNNDataType<VT>(),
+                n, c, h, w));
+
+        if (res == nullptr) {
+            res = DataObjectFactory::create<DTRes>(batch_size, num_channels * img_h * img_w, false, &alloc_desc);
+        }
+        VT* d_res = res->getValues(&alloc_desc);
+
+        CHECK_CUDNN(cudnnPoolingBackward(ctx->getCUDNNHandle(),
+                ctx->pooling_desc, &blend_alpha,
+                ctx->dst_tensor_desc, d_output,
+                ctx->dy_tensor_desc, d_dOut,
+                ctx->src_tensor_desc, d_input,
+                &blend_beta, ctx->src_tensor_desc, d_res));
+    }
+
     template struct Forward<::NN::Pooling::AVG, DenseMatrix<float>, DenseMatrix<float>>;
     template struct Forward<::NN::Pooling::AVG, DenseMatrix<double>, DenseMatrix<double>>;
     template struct Forward<::NN::Pooling::MAX, DenseMatrix<float>, DenseMatrix<float>>;
     template struct Forward<::NN::Pooling::MAX, DenseMatrix<double>, DenseMatrix<double>>;
+
+    template struct Backward<::NN::Pooling::AVG, DenseMatrix<float>, DenseMatrix<float>>;
+    template struct Backward<::NN::Pooling::AVG, DenseMatrix<double>, DenseMatrix<double>>;
+
+    template struct Backward<::NN::Pooling::MAX, DenseMatrix<float>, DenseMatrix<float>>;
+    template struct Backward<::NN::Pooling::MAX, DenseMatrix<double>, DenseMatrix<double>>;
 }
diff --git a/src/runtime/local/kernels/CUDA/Pooling.h b/src/runtime/local/kernels/CUDA/Pooling.h
index f629ebfaa..cd9bb56a0 100644
--- a/src/runtime/local/kernels/CUDA/Pooling.h
+++ b/src/runtime/local/kernels/CUDA/Pooling.h
@@ -35,4 +35,16 @@ namespace CUDA::NN::Pooling {
                           size_t pool_h, size_t pool_w, size_t stride_h, size_t stride_w,
                           size_t pad_h, size_t pad_w, DCTX(dctx));
     };
+
+    template<template<typename> class OP, typename DTRes, typename DTArg>
+    struct Backward {
+        static void apply(DTRes *&res,
+                          const DTArg *input, const DTArg *output, const DTArg *dOut,
+                          const size_t batch_size, const size_t num_channels,
+                          const size_t img_h, const size_t img_w,
+                          const size_t pool_h, const size_t pool_w,
+                          const size_t stride_h, const size_t stride_w,
+                          const size_t pad_h, const size_t pad_w,
+                          DCTX(dctx));
+    };
 }
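The backward kernel re-derives the forward output shape via `cudnnGetPoolingNdForwardOutputDim`; per spatial dimension that is the usual floor-division formula. A small self-contained check (illustrative only; `pooledDim` is a hypothetical helper):

```cpp
#include <cstdio>

// Same formula cuDNN applies to each spatial dimension (integer division).
int pooledDim(int in, int pad, int window, int stride) {
    return 1 + (in + 2 * pad - window) / stride;
}

int main() {
    // The tests below pool 3x3 images with 2x2 windows, stride 2, pad 1:
    std::printf("%d\n", pooledDim(3, 1, 2, 2)); // prints 2 -> 2x2 output
}
```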
diff --git a/test/runtime/local/kernels/DNNBatchNorm2DBackwardTest.cpp b/test/runtime/local/kernels/DNNBatchNorm2DBackwardTest.cpp
new file mode 100644
index 000000000..dd62fdcef
--- /dev/null
+++ b/test/runtime/local/kernels/DNNBatchNorm2DBackwardTest.cpp
@@ -0,0 +1,172 @@
+/*
+ * Copyright 2021 The DAPHNE Consortium
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifdef USE_CUDA
+#include "run_tests.h"
+
+#include <runtime/local/datagen/GenGivenVals.h>
+#include <runtime/local/datastructures/DenseMatrix.h>
+#include <runtime/local/kernels/CheckEqApprox.h>
+#include "runtime/local/kernels/CUDA/BatchNorm.h"
+#include "runtime/local/kernels/CUDA/CreateCUDAContext.h"
+#include <cmath>
+
+#include <tags.h>
+#include <catch.hpp>
+#include <vector>
+
+template<class DT>
+void checkBatchNorm2DBackwardCUDA(const DT* in, const DT* dOut, const DT* gamma, const DT* mean, const DT* invVar,
+                                  const DT* exp1, const DT* exp2, const DT* exp3, DaphneContext* dctx) {
+    DT* dX = nullptr;
+    DT* dGamma = nullptr;
+    DT* dBeta = nullptr;
+
+    typename DT::VT epsilon = 1e-5;
+    CUDA::BatchNorm::Backward<DT, DT>::apply(dX, dGamma, dBeta, mean, invVar, in, dOut, gamma, epsilon, dctx);
+    CHECK(checkEqApprox(dX, exp1, 1e-5, nullptr));
+    CHECK(checkEqApprox(dGamma, exp2, 1e-4, nullptr));
+    CHECK(checkEqApprox(dBeta, exp3, 1e-5, nullptr));
+}
+
+TEMPLATE_PRODUCT_TEST_CASE("batch_norm_bwd_cuda", TAG_DNN, (DenseMatrix), (float, double)) { // NOLINT(cert-err58-cpp)
+    auto dctx = setupContextAndLogger();
+    CUDA::createCUDAContext(dctx.get());
+    using DT = TestType;
+
+    auto in = genGivenVals<DT>(2, { 1, 2, 3, 4,
+                                    5, 6, 7, 8,
+                                    9, 10, 11, 12,
+
+                                    1, 2, 3, 4,
+                                    5, 6, 7, 8,
+                                    9, 10, 11, 12});
+
+    auto dOut = genGivenVals<DT>(2, { 1, 2, 3, 4,
+                                      5, 6, 7, 8,
+                                      9, 10, 11, 12,
+
+                                      1, 2, 3, 4,
+                                      5, 6, 7, 8,
+                                      9, 10, 11, 12});
+
+    auto gamma = genGivenVals<DT>(3, { 1, 1, 1 });
+    auto mean = genGivenVals<DT>(3, { 2.5, 6.5, 10.5 });
+    auto invVar = genGivenVals<DT>(3, { 1 / std::sqrt(1.25 + 1e-5), 1 / std::sqrt(1.25 + 1e-5), 1 / std::sqrt(1.25 + 1e-5) });
+    auto res1 = genGivenVals<DT>(2, {-1.0733e-05, -3.5777e-06, 3.5777e-06, 1.0733e-05,
+                                     -1.0733e-05, -3.5777e-06, 3.5777e-06, 1.0733e-05,
+                                     -1.0733e-05, -3.5777e-06, 3.5777e-06, 1.0733e-05,
+                                     -1.0733e-05, -3.5777e-06, 3.5777e-06, 1.0733e-05,
+                                     -1.0733e-05, -3.5777e-06, 3.5777e-06, 1.0733e-05,
+                                     -1.0733e-05, -3.5777e-06, 3.5777e-06, 1.0733e-05});
+    auto res2 = genGivenVals<DT>(3, { 8.9442, 8.9442, 8.9442 });
+    auto res3 = genGivenVals<DT>(3, { 20, 52, 84 });
+
+    checkBatchNorm2DBackwardCUDA(in, dOut, gamma, mean, invVar, res1, res2, res3, dctx.get());
+}
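+
+// Both tests in this file use the same hand-derived expected gradients.
+// Per channel c (m = N*H*W = 8 values), with xhat = (x - mean_c) * invVar_c:
+//   dBeta_c  = sum(dOut)        -> {2*(1+2+3+4), 2*(5+6+7+8), 2*(9+10+11+12)}
+//                                = {20, 52, 84}
+//   dGamma_c = sum(dOut * xhat) -> 2 * 4.4721 ~= 8.9442 for every channel,
+//              since each row of dOut rises by 1 against the same symmetric
+//              xhat pattern {-1.3416, -0.4472, 0.4472, 1.3416}.
+//   dX is ~0 (order 1e-5): dOut is an affine function of xhat here, so the
+//   mean- and variance-gradient terms nearly cancel.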
+
+#include <runtime/local/datagen/GenGivenVals.h>
+#include <runtime/local/datastructures/DenseMatrix.h>
+#include <runtime/local/kernels/CheckEqApprox.h>
+#include "runtime/local/kernels/BatchNorm2DBackward.h"
+#include "runtime/local/kernels/CUDA/BatchNorm.h"
+#include <tags.h>
+#include <catch.hpp>
+
+#include <cmath>
+#include <cstdint>
+#include <vector>
+
+template<class DT>
+void checkBatchNorm2DBackward(const DT* in, const DT* dOut, const DT* gamma, const DT* mean, const DT* invVar,
+                              const DT* exp1, const DT* exp2, const DT* exp3, DaphneContext* dctx)
+{
+    DT* dX = nullptr;
+    DT* dGamma = nullptr;
+    DT* dBeta = nullptr;
+
+    typename DT::VT epsilon = 1e-5;
+    BatchNorm2DBackward<DT, DT>::apply(dX, dGamma, dBeta, mean, invVar, in, dOut, gamma, epsilon, dctx);
+
+    CHECK(checkEqApprox(dX, exp1, 1e-5, nullptr));
+    CHECK(checkEqApprox(dGamma, exp2, 1e-4, nullptr));
+    CHECK(checkEqApprox(dBeta, exp3, 1e-5, nullptr));
+}
+
+TEMPLATE_PRODUCT_TEST_CASE("batch_norm_bwd", TAG_DNN, (DenseMatrix), (float, double)) { // NOLINT(cert-err58-cpp)
+    auto dctx = setupContextAndLogger();
+    using DT = TestType;
+
+    auto in = genGivenVals<DT>(2, { 1, 2, 3, 4,
+                                    5, 6, 7, 8,
+                                    9, 10, 11, 12,
+
+                                    1, 2, 3, 4,
+                                    5, 6, 7, 8,
+                                    9, 10, 11, 12});
+
+    auto dOut = genGivenVals<DT>(2, { 1, 2, 3, 4,
+                                      5, 6, 7, 8,
+                                      9, 10, 11, 12,
+
+                                      1, 2, 3, 4,
+                                      5, 6, 7, 8,
+                                      9, 10, 11, 12});
+
+    auto gamma = genGivenVals<DT>(3, { 1, 1, 1 });
+    auto mean = genGivenVals<DT>(3, { 2.5, 6.5, 10.5 });
+    auto invVar = genGivenVals<DT>(3, { 1 / std::sqrt(1.25 + 1e-5), 1 / std::sqrt(1.25 + 1e-5), 1 / std::sqrt(1.25 + 1e-5) });
+    auto res1 = genGivenVals<DT>(2, {-1.0733e-05, -3.5777e-06, 3.5777e-06, 1.0733e-05,
+                                     -1.0733e-05, -3.5777e-06, 3.5777e-06, 1.0733e-05,
+                                     -1.0733e-05, -3.5777e-06, 3.5777e-06, 1.0733e-05,
+                                     -1.0733e-05, -3.5777e-06, 3.5777e-06, 1.0733e-05,
+                                     -1.0733e-05, -3.5777e-06, 3.5777e-06, 1.0733e-05,
+                                     -1.0733e-05, -3.5777e-06, 3.5777e-06, 1.0733e-05});
+    auto res2 = genGivenVals<DT>(3, { 8.9442, 8.9442, 8.9442 });
+    auto res3 = genGivenVals<DT>(3, { 20, 52, 84 });
+
+    checkBatchNorm2DBackward(in, dOut, gamma, mean, invVar, res1, res2, res3, dctx.get());
+}
+#endif // USE_CUDA
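For the pooling tests that follow, it helps to spell out the gradient implied by the kernel's `CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING` choice: every input cell covered by a window receives dOut/(pool_h*pool_w), with padded cells counted in the divisor but receiving nothing. A minimal single-channel reference sketch (an illustration, not the DAPHNE CPU kernel; `avgPoolBackwardRef` is hypothetical):

```cpp
#include <cstddef>
#include <vector>

// Backward of 2D average pooling with a fixed divisor (include-padding mode).
void avgPoolBackwardRef(const std::vector<double>& dOut, std::vector<double>& dX,
                        int h, int w, int pool, int stride, int pad) {
    const int oh = 1 + (h + 2 * pad - pool) / stride; // output height
    const int ow = 1 + (w + 2 * pad - pool) / stride; // output width
    const double div = static_cast<double>(pool * pool);
    dX.assign(static_cast<std::size_t>(h) * w, 0.0);
    for (int oy = 0; oy < oh; ++oy)
        for (int ox = 0; ox < ow; ++ox)
            for (int ky = 0; ky < pool; ++ky)
                for (int kx = 0; kx < pool; ++kx) {
                    const int y = oy * stride - pad + ky;
                    const int x = ox * stride - pad + kx;
                    if (y >= 0 && y < h && x >= 0 && x < w) // skip padded cells
                        dX[y * w + x] += dOut[oy * ow + ox] / div;
                }
}
```

This reproduces the expected pattern in pool_bwd_avg below, e.g. dX(0,0) = 1/4 = 0.25 even though that corner window covers only one real cell.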
diff --git a/test/runtime/local/kernels/DNNPoolingBackwardTest.cpp b/test/runtime/local/kernels/DNNPoolingBackwardTest.cpp
new file mode 100644
--- /dev/null
+++ b/test/runtime/local/kernels/DNNPoolingBackwardTest.cpp
+/*
+ * Copyright 2021 The DAPHNE Consortium
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "run_tests.h"
+
+#include <runtime/local/datagen/GenGivenVals.h>
+#include <runtime/local/datastructures/DenseMatrix.h>
+#include <runtime/local/kernels/CheckEq.h>
+
+#ifdef USE_CUDA
+    #include <runtime/local/kernels/CUDA/Pooling.h>
+    #include "runtime/local/kernels/CUDA/CreateCUDAContext.h"
+#else
+    #include <runtime/local/kernels/AvgPoolBackward.h>
+    #include <runtime/local/kernels/MaxPoolBackward.h>
+#endif
+
+#include <tags.h>
+#include <catch.hpp>
+
+#include <cstdint>
+#include <vector>
+
+template<class DT>
+DT* genInput() {
+    return genGivenVals<DT>(2, {
+            1, 2, 3, 4, 5, 6, 7, 8, 9,
+            10, 11, 12, 13, 14, 15, 16, 17, 18,
+            19, 20, 21, 22, 23, 24, 25, 26, 27,
+
+            28, 29, 30, 31, 32, 33, 34, 35, 36,
+            37, 38, 39, 40, 41, 42, 43, 44, 45,
+            46, 47, 48, 49, 50, 51, 52, 53, 54
+    });
+}
+
+template<class DT>
+DT* genDOut() {
+    return genGivenVals<DT>(2, {
+            1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
+            1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12
+    });
+}
+
+template<class DT>
+void check_max(const DT* in, const DT* dOut, const DT* exp, DaphneContext* dctx) {
+    DT* res = nullptr;
+    DT* output = nullptr;
+    size_t res_h, res_w;
+#ifdef USE_CUDA
+    // The CUDA backward kernel needs the forward output for cudnnPoolingBackward.
+    CUDA::NN::Pooling::Forward<::NN::Pooling::MAX, DT, DT>::apply(output, res_h, res_w, in, 2, 3, 3, 3, 2, 2, 2, 2, 1, 1, dctx);
+    CUDA::NN::Pooling::Backward<::NN::Pooling::MAX, DT, DT>::apply(res, in, output, dOut, 2, 3, 3, 3, 2, 2, 2, 2, 1, 1, dctx);
+#else
+    MaxPoolBackward<DT, DT>::apply(res, in, dOut, 2, 3, 3, 3, 2, 2, 2, 2, 1, 1, dctx);
+#endif
+    CHECK(*res == *exp);
+}
+
+template<class DT>
+void check_avg(const DT* in, const DT* dOut, const DT* exp, DaphneContext* dctx) {
+    DT* res = nullptr;
+    DT* output = nullptr;
+    size_t res_h, res_w;
+#ifdef USE_CUDA
+    CUDA::NN::Pooling::Forward<::NN::Pooling::AVG, DT, DT>::apply(output, res_h, res_w, in, 2, 3, 3, 3, 2, 2, 2, 2, 1, 1, dctx);
+    CUDA::NN::Pooling::Backward<::NN::Pooling::AVG, DT, DT>::apply(res, in, output, dOut, 2, 3, 3, 3, 2, 2, 2, 2, 1, 1, dctx);
+#else
+    AvgPoolBackward<DT, DT>::apply(res, in, dOut, 2, 3, 3, 3, 2, 2, 2, 2, 1, 1, dctx);
+#endif
+    CHECK(*res == *exp);
+}
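+
+// Window geometry for both tests: 3x3 inputs, 2x2 windows, stride 2, pad 1
+// give a 2x2 output per channel. After padding, output (0,0) covers only
+// input (0,0); (0,1) covers inputs (0,1)-(0,2); (1,0) covers (1,0) and (2,0);
+// (1,1) covers the bottom-right 2x2 block. In include-padding average mode,
+// each dOut value is divided by 4 regardless of how many real cells the
+// window covers.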
+
+TEMPLATE_PRODUCT_TEST_CASE("pool_bwd_avg", TAG_DNN, (DenseMatrix), (float, double)) { // NOLINT(cert-err58-cpp)
+    using DT = TestType;
+
+    auto dctx_avg = setupContextAndLogger();
+#ifdef USE_CUDA
+    CUDA::createCUDAContext(dctx_avg.get());
+#endif
+
+    auto inputs = genInput<DT>();
+    auto dOut = genDOut<DT>();
+
+    auto dX = genGivenVals<DT>(2, {0.25, 0.50, 0.50, 0.75, 1.00, 1.00, 0.75, 1.00, 1.00,
+                                   1.25, 1.50, 1.50, 1.75, 2.00, 2.00, 1.75, 2.00, 2.00,
+                                   2.25, 2.50, 2.50, 2.75, 3.00, 3.00, 2.75, 3.00, 3.00,
+
+                                   0.25, 0.50, 0.50, 0.75, 1.00, 1.00, 0.75, 1.00, 1.00,
+                                   1.25, 1.50, 1.50, 1.75, 2.00, 2.00, 1.75, 2.00, 2.00,
+                                   2.25, 2.50, 2.50, 2.75, 3.00, 3.00, 2.75, 3.00, 3.00
+    });
+
+    check_avg(inputs, dOut, dX, dctx_avg.get());
+
+    DataObjectFactory::destroy(inputs);
+    DataObjectFactory::destroy(dOut);
+    DataObjectFactory::destroy(dX);
+}
+
+TEMPLATE_PRODUCT_TEST_CASE("pool_bwd_max", TAG_DNN, (DenseMatrix), (float, double)) { // NOLINT(cert-err58-cpp)
+    using DT = TestType;
+
+    auto dctx_max = setupContextAndLogger();
+#ifdef USE_CUDA
+    CUDA::createCUDAContext(dctx_max.get());
+#endif
+
+    auto inputs = genInput<DT>();
+    auto dOut = genDOut<DT>();
+
+    auto dX = genGivenVals<DT>(2, {1, 0, 2, 0, 0, 0, 3, 0, 4,
+                                   5, 0, 6, 0, 0, 0, 7, 0, 8,
+                                   9, 0, 10, 0, 0, 0, 11, 0, 12,
+
+                                   1, 0, 2, 0, 0, 0, 3, 0, 4,
+                                   5, 0, 6, 0, 0, 0, 7, 0, 8,
+                                   9, 0, 10, 0, 0, 0, 11, 0, 12
+    });
+
+    check_max(inputs, dOut, dX, dctx_max.get());
+
+    DataObjectFactory::destroy(inputs);
+    DataObjectFactory::destroy(dOut);
+    DataObjectFactory::destroy(dX);
+}
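The max-pooling counterpart routes each output gradient to the arg-max position of its (unpadded) window, which is exactly the sparsity pattern the pool_bwd_max expectations encode. A single-channel sketch under the same assumptions (`maxPoolBackwardRef` is hypothetical):

```cpp
#include <cstddef>
#include <vector>

// Backward of 2D max pooling: each output gradient goes to the position of
// the window maximum (ties resolved to the first occurrence).
void maxPoolBackwardRef(const std::vector<double>& in, const std::vector<double>& dOut,
                        std::vector<double>& dX, int h, int w, int pool, int stride, int pad) {
    const int oh = 1 + (h + 2 * pad - pool) / stride;
    const int ow = 1 + (w + 2 * pad - pool) / stride;
    dX.assign(in.size(), 0.0);
    for (int oy = 0; oy < oh; ++oy)
        for (int ox = 0; ox < ow; ++ox) {
            int best = -1; // index of the window maximum
            for (int ky = 0; ky < pool; ++ky)
                for (int kx = 0; kx < pool; ++kx) {
                    const int y = oy * stride - pad + ky;
                    const int x = ox * stride - pad + kx;
                    if (y < 0 || y >= h || x < 0 || x >= w)
                        continue; // padded cells never win
                    const int idx = y * w + x;
                    if (best < 0 || in[idx] > in[best])
                        best = idx;
                }
            if (best >= 0)
                dX[best] += dOut[oy * ow + ox];
        }
}
```

Applied to the first channel of genInput ({1..9}) with dOut {1, 2, 3, 4}, the window maxima sit at inputs 1, 3, 7, and 9, yielding exactly the expected row {1, 0, 2, 0, 0, 0, 3, 0, 4}.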