
Commit

Add comments on code
Stonepia committed Dec 16, 2024
1 parent 3cb9895 commit e366724
Showing 1 changed file with 7 additions and 2 deletions.
torchvision/ops/triton/nms.py: 9 changes (7 additions & 2 deletions)
@@ -13,8 +13,8 @@ def _combine_bits(val0, val1):
 def triton_nms_IoU_kernel(boxes, output_ptr, threshold, num_boxes, stride_i, stride_j, BLOCK_SIZE: tl.constexpr):
     """
     This nms_kernel computes the suppressed mask of boxes [i, j].
-    mask[i, j]==1 means if we choose box 1, the box j will be supressed.
-    The output is a mask of size [num_boxes, num_boxes].
+    mask[i, j]==1 means that if we choose box i, box j will be suppressed.
+    The output is a mask of size [num_boxes, num_boxes//32], where each item is an int32.

     Args:
         boxes (tl.tensor): A tensor containing the bounding boxes with shape (num_boxes, 4).
@@ -24,6 +24,9 @@ def triton_nms_IoU_kernel(boxes, output_ptr, threshold, num_boxes, stride_i, stride_j, BLOCK_SIZE: tl.constexpr):
         stride_i (int): The stride of the output tensor along the first dimension.
         stride_j (int): The stride of the output tensor along the second dimension.
         BLOCK_SIZE (tl.constexpr): The block size for the Triton kernel.
+    Returns:
+        Tensor (int32): Tensor of size [num_boxes, num_boxes//32]. It indicates, once box `i` is
+            chosen, whether box `j` can still be chosen. The value `1` means box `j` is suppressed.
     """

     # The Triton kernel is a 2D block kernel. The block size is BLOCK_SIZE x BLOCK_SIZE.
@@ -75,6 +78,8 @@ def triton_nms_IoU_kernel(boxes, output_ptr, threshold, num_boxes, stride_i, stride_j, BLOCK_SIZE: tl.constexpr):
     shift_offsets = tl.broadcast_to(shift_offsets.to(tl.int32), [BLOCK_SIZE, BLOCK_SIZE])
     iou_keep_out_bit_mask = iou_keep_out_bit_mask << shift_offsets

+    # Combine the bits. Note that Triton seems to have problems when the dtype is int64, so the mask
+    # is built with 32-bit words and converted to int64 at the end to avoid further potential overflow.
     iou_keep_out_bit_mask = tl.reshape(iou_keep_out_bit_mask, (BLOCK_SIZE, (BLOCK_SIZE + 32 - 1) // 32, 32))
     iou_keep_out_combined = tl.reduce(iou_keep_out_bit_mask, axis=2, combine_fn=_combine_bits)
     iou_keep_out_combined = iou_keep_out_combined.to(tl.int64)
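To make the packed-mask layout from the new docstring concrete, here is a minimal sketch in plain PyTorch (not part of the commit, and not torchvision API) of how a boolean [num_boxes, num_boxes] suppression matrix maps onto the [num_boxes, num_boxes//32] int32 layout: the flag for box j is shifted to bit position j % 32 and each group of 32 flags is OR-combined into one word, mirroring what the shift by shift_offsets and the tl.reduce over _combine_bits do in the kernel. The helper name pack_suppression_mask is an illustrative assumption.

    import torch


    def pack_suppression_mask(suppress: torch.Tensor) -> torch.Tensor:
        # Illustrative helper (assumed name, not torchvision API): pack a boolean
        # [num_boxes, num_boxes] suppression matrix into int32 words of shape
        # [num_boxes, ceil(num_boxes / 32)], mirroring the shift + bitwise-OR
        # reduction performed inside the Triton kernel.
        num_rows, num_boxes = suppress.shape
        num_words = (num_boxes + 32 - 1) // 32
        # Pad the column dimension to a multiple of 32 so it reshapes cleanly.
        padded = torch.zeros(num_rows, num_words * 32, dtype=torch.int32)
        padded[:, :num_boxes] = suppress.to(torch.int32)
        # The flag for box j lives in bit (j % 32) of word (j // 32).
        shifts = torch.arange(num_words * 32, dtype=torch.int32) % 32
        bits = padded << shifts               # broadcasted shift, like `<< shift_offsets`
        bits = bits.view(num_rows, num_words, 32)
        packed = bits[..., 0]
        for k in range(1, 32):                # bitwise-OR reduction, like tl.reduce with _combine_bits
            packed = packed | bits[..., k]
        return packed                         # int32, shape [num_rows, num_words]


    torch.manual_seed(0)
    suppress = torch.rand(100, 100) > 0.5     # stand-in for "IoU > threshold" decisions
    packed = pack_suppression_mask(suppress)
    i, j = 3, 70
    bit = (packed[i, j // 32] >> (j % 32)) & 1    # read a single flag back out
    assert bool(bit) == bool(suppress[i, j])

Reading one flag back needs only a single int32 load and a bit test, which is why the kernel can emit num_boxes//32 words per row instead of one value per box pair.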
