Commit 2da937c

Merge remote-tracking branch 'upstream/main'

github-actions[bot] committed Aug 26, 2024
2 parents f23661e + 5b72c27

Showing 12 changed files with 50 additions and 0 deletions.

@@ -5,3 +5,4 @@ DATASETS:
   TEST: ("coco_2017_val_100",)
 TEST:
   EXPECTED_RESULTS: [["bbox", "AP", 50.18, 0.02], ["segm", "AP", 43.87, 0.02]]
+FLOAT32_PRECISION: "highest"

@@ -5,3 +5,4 @@ DATASETS:
   TEST: ("coco_2017_val_100",)
 TEST:
   EXPECTED_RESULTS: [["bbox", "AP", 45.70, 0.02]]
+FLOAT32_PRECISION: "highest"

@@ -5,3 +5,4 @@ DATASETS:
   TEST: ("keypoints_coco_2017_val_100",)
 TEST:
   EXPECTED_RESULTS: [["bbox", "AP", 52.47, 0.02], ["keypoints", "AP", 67.36, 0.02]]
+FLOAT32_PRECISION: "highest"

@@ -5,3 +5,4 @@ DATASETS:
   TEST: ("coco_2017_val_100",)
 TEST:
   EXPECTED_RESULTS: [["bbox", "AP", 47.37, 0.02], ["segm", "AP", 40.99, 0.02]]
+FLOAT32_PRECISION: "highest"

@@ -5,3 +5,4 @@ DATASETS:
   TEST: ("coco_2017_val_100",)
 TEST:
   EXPECTED_RESULTS: [["bbox", "AP", 47.44, 0.02], ["segm", "AP", 42.94, 0.02]]
+FLOAT32_PRECISION: "highest"

@@ -8,3 +8,4 @@ TEST:
   AUG:
     ENABLED: True
     MIN_SIZES: (700, 800) # to save some time
+FLOAT32_PRECISION: "highest"

@@ -5,3 +5,4 @@ DATASETS:
   TEST: ("coco_2017_val_100_panoptic_separated",)
 TEST:
   EXPECTED_RESULTS: [["bbox", "AP", 46.47, 0.02], ["segm", "AP", 43.39, 0.02], ["sem_seg", "mIoU", 42.55, 0.02], ["panoptic_seg", "PQ", 38.99, 0.02]]
+FLOAT32_PRECISION: "highest"

@@ -5,3 +5,4 @@ DATASETS:
   TEST: ("coco_2017_val_100",)
 TEST:
   EXPECTED_RESULTS: [["box_proposals", "AR@1000", 58.16, 0.02]]
+FLOAT32_PRECISION: "highest"

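(Note: in these quick-schedule test configs, each EXPECTED_RESULTS entry reads as [task, metric, expected value, tolerance]; for example, ["bbox", "AP", 50.18, 0.02] in the first hunk asserts a box AP of 50.18 ± 0.02. The FLOAT32_PRECISION: "highest" key added to each file presumably pins these accuracy tests to full-fp32 arithmetic so the expected metrics stay stable on TF32-capable GPUs.)
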
4 changes: 4 additions & 0 deletions detectron2/config/defaults.py

@@ -636,6 +636,10 @@
 # for about 10k iterations. It usually hurts total time, but can benefit for certain models.
 # If input images have the same or similar sizes, benchmark is often helpful.
 _C.CUDNN_BENCHMARK = False
+# Option to set the float32 precision of PyTorch matmul and cuDNN operations. When set to a
+# non-empty string, the corresponding precision ("highest", "high", or "medium") will be
+# used. The "highest" precision effectively disables TF32.
+_C.FLOAT32_PRECISION = ""
 # The period (in terms of steps) for minibatch visualization at train time.
 # Set to 0 to disable.
 _C.VIS_PERIOD = 0

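As a hedged sketch (not part of the commit; values illustrative), the new yacs key can be set programmatically or through detectron2's usual KEY VALUE command-line overrides:

from detectron2.config import get_cfg

cfg = get_cfg()
# Equivalent to appending `FLOAT32_PRECISION "highest"` to the opts list
# of a standard detectron2 launch script:
cfg.merge_from_list(["FLOAT32_PRECISION", "highest"])
assert cfg.FLOAT32_PRECISION == "highest"
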
32 changes: 32 additions & 0 deletions detectron2/engine/defaults.py

@@ -171,6 +171,30 @@ def _highlight(code, filename):
     return code
 
 
+# Adapted from:
+# https://github.com/pytorch/tnt/blob/ebda066f8f55af6a906807d35bc829686618074d/torchtnt/utils/device.py#L328-L346
+def _set_float32_precision(precision: str = "high") -> None:
+    """Sets the precision of float32 matrix multiplications and convolution operations.
+
+    For more information, see the PyTorch docs:
+    - https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html
+    - https://pytorch.org/docs/stable/backends.html#torch.backends.cudnn.allow_tf32
+
+    Args:
+        precision: The setting that determines which datatypes to use for matrix
+            multiplication and convolution operations.
+    """
+    if not torch.cuda.is_available():  # Not relevant for non-CUDA devices.
+        return
+    # Set precision for matrix multiplications.
+    torch.set_float32_matmul_precision(precision)
+    # Set precision for convolution operations.
+    if precision == "highest":
+        torch.backends.cudnn.allow_tf32 = False
+    else:
+        torch.backends.cudnn.allow_tf32 = True
+
+
 def default_setup(cfg, args):
     """
     Perform some basic common setups at the beginning of a job, including:

@@ -226,6 +250,14 @@ def default_setup(cfg, args):
         cfg, "CUDNN_BENCHMARK", "train.cudnn_benchmark", default=False
     )
 
+    fp32_precision = _try_get_key(cfg, "FLOAT32_PRECISION", "train.float32_precision", default="")
+    if fp32_precision != "":
+        logger.info(f"Set fp32 precision to {fp32_precision}")
+        _set_float32_precision(fp32_precision)
+        logger.info(f"{torch.get_float32_matmul_precision()=}")
+        logger.info(f"{torch.backends.cuda.matmul.allow_tf32=}")
+        logger.info(f"{torch.backends.cudnn.allow_tf32=}")
+
 
 def default_writers(output_dir: str, max_iter: Optional[int] = None):
     """

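A minimal sketch (assuming a CUDA-enabled PyTorch build; not part of the diff) of the three knobs that _set_float32_precision adjusts and that default_setup then logs:

import torch

torch.set_float32_matmul_precision("highest")  # full-fp32 matmuls
torch.backends.cudnn.allow_tf32 = False        # full-fp32 cuDNN convolutions

print(torch.get_float32_matmul_precision())    # "highest"
print(torch.backends.cuda.matmul.allow_tf32)   # False; "high"/"medium" would flip this to True
print(torch.backends.cudnn.allow_tf32)         # False
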
5 changes: 5 additions & 0 deletions detectron2/export/c10.py

@@ -84,6 +84,11 @@ def set(self, name, value):
         else:
             data_len = len(value)
         if len(self.batch_extra_fields):
+            # If we are tracing with Dynamo, the check here is needed since len(self)
+            # represents the number of bounding boxes detected in the image and thus is
+            # an unbounded SymInt.
+            if torch._utils.is_compiling():
+                torch._check(len(self) == data_len)
             assert (
                 len(self) == data_len
             ), "Adding a field of length {} to an Instances of length {}".format(data_len, len(self))

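A standalone sketch of the same pattern (illustrative names; assumes PyTorch 2.x): torch._check records the length equality on the otherwise unbounded SymInt so the compiler can rely on it, while the plain assert still guards eager execution.

import torch

def add_field(boxes: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
    # Under torch.compile, boxes.shape[0] is an unbounded SymInt; torch._check
    # tells the compiler the two lengths match instead of tripping a
    # data-dependent guard at the assert below.
    if torch.compiler.is_compiling():
        torch._check(boxes.shape[0] == scores.shape[0])
    assert boxes.shape[0] == scores.shape[0]
    return torch.cat([boxes, scores[:, None]], dim=1)

compiled = torch.compile(add_field, dynamic=True)
print(compiled(torch.rand(5, 4), torch.rand(5)).shape)  # torch.Size([5, 5])
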
1 change: 1 addition & 0 deletions projects/DensePose/densepose/modeling/losses/utils.py

@@ -225,6 +225,7 @@ def resample_data(
     grid_h = torch.arange(hout, device=z.device, dtype=torch.float) / hout
     grid_w_expanded = grid_w[None, None, :].expand(n, hout, wout)
     grid_h_expanded = grid_h[None, :, None].expand(n, hout, wout)
+    # pyre-fixme[16]: `float` has no attribute `__getitem__`.
     dx_expanded = (x1dst_norm - x0dst_norm)[:, None, None].expand(n, hout, wout)
     dy_expanded = (y1dst_norm - y0dst_norm)[:, None, None].expand(n, hout, wout)
     x0_expanded = x0dst_norm[:, None, None].expand(n, hout, wout)
