Skip to content

Commit

Permalink
[RF] Add assertions to check pointers for RooBatchCompute on GPU.
Browse files Browse the repository at this point in the history
RooBatchCompute is using host pointers in a kernel, which is hard to
detect. This assertion helps to catch the error a lot earlier.
  • Loading branch information
hageboeck committed Oct 10, 2024
1 parent 7bef764 commit 3bd9e12
Showing 1 changed file with 6 additions and 0 deletions.
6 changes: 6 additions & 0 deletions roofit/batchcompute/src/RooBatchCompute.cu
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ This file contains the code for cuda computations using the RooBatchCompute libr
#include "CudaInterface.h"

#include <algorithm>
#include <cassert>
#include <functional>
#include <map>
#include <queue>
Expand Down Expand Up @@ -299,6 +300,11 @@ ReduceNLLOutput RooBatchComputeClass::reduceNLL(RooBatchCompute::Config const &c
cudaStream_t stream = *cfg.cudaStream();
constexpr int shMemSize = 2 * blockSize * sizeof(double);

for (auto span : {probas, weights, offsetProbas}) {
cudaPointerAttributes attr;
assert(span.size() == 0 || span.data() == nullptr || (cudaPointerGetAttributes(&attr, span.data()) == cudaSuccess && attr.type == cudaMemoryTypeDevice));
}

nllSumKernel<<<gridSize, blockSize, shMemSize, stream>>>(
probas.data(), weights.size() == 1 ? nullptr : weights.data(),
offsetProbas.empty() ? nullptr : offsetProbas.data(), probas.size(), devOut.data());
Expand Down

0 comments on commit 3bd9e12

Please sign in to comment.