[RT-DETR] Fix onnx inference bug for Optype (Where) #33877
Conversation
Hey! 🤗 Thanks for your contribution to the transformers library! Before merging this pull request, slow tests CI should be triggered. To enable this:
(For maintainers) The documentation for slow tests CI on PRs is here.
Force-pushed from 01e5d38 to b146090
Force-pushed from b146090 to 70e48fa
Hi @SangbumChoi, how are you? I saw that you are the main maintainer of RT-DETR. Thank you very much for your work! I propose a fix for the anchor generation to avoid a bug in ONNX inference (see the reproduction sketch below). What do you think about it?
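(For context, a minimal way to reproduce the issue; this is a sketch rather than code from the PR, and the checkpoint name, input shape, and opset are assumptions. Before the fix, loading or running the exported graph with ONNX Runtime fails with the Where type mismatch quoted later in this thread.)

import torch
import onnxruntime as ort
from transformers import RTDetrForObjectDetection

# Assumed public checkpoint; any RT-DETR checkpoint should reproduce the issue.
model = RTDetrForObjectDetection.from_pretrained("PekingU/rtdetr_r50vd").eval()
model.config.return_dict = False  # return tuples so the traced graph has plain tensor outputs
pixel_values = torch.randn(1, 3, 640, 640)

torch.onnx.export(
    model,
    (pixel_values,),
    "rtdetr.onnx",
    input_names=["pixel_values"],
    opset_version=17,
)

# Before this fix, ONNX Runtime fails with:
#   Type parameter (T) of Optype (Where) bound to different types
#   (tensor(float) and tensor(double)) in node (/model/decoder/Where)
session = ort.InferenceSession("rtdetr.onnx", providers=["CPUExecutionProvider"])
outputs = session.run(None, {"pixel_values": pixel_values.numpy()})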
Thanks for the contribution! It looks great to me.
- I think it would be helpful to add an ONNX conversion description/script in rt_detr.md?
Requesting @qubvel for the final approval!
@@ -1736,7 +1736,7 @@ def generate_anchors(self, spatial_shapes=None, grid_size=0.05, device="cpu", dtype=torch.float32):
         anchors = torch.concat(anchors, 1)
         valid_mask = ((anchors > eps) * (anchors < 1 - eps)).all(-1, keepdim=True)
         anchors = torch.log(anchors / (1 - anchors))
-        anchors = torch.where(valid_mask, anchors, torch.inf)
+        anchors = torch.where(valid_mask, anchors, torch.tensor(float("inf"), dtype=torch.float32, device=device))
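(A rough sketch of the root cause, not code from the PR: torch.inf is a Python float, which the ONNX exporter materializes as a double constant, while anchors is float32, so the exported Where node mixes tensor(float) and tensor(double). Building the fill value as a tensor with an explicit dtype keeps both branches of the Where op in the same type. The shapes below are made up for illustration.)

import torch

anchors = torch.rand(1, 300, 4)                      # float32, like the generated anchors
valid_mask = anchors.gt(0.01).all(-1, keepdim=True)  # boolean condition

# Python float scalar: traced as a double constant in the exported graph.
out_buggy = torch.where(valid_mask, anchors, torch.inf)

# Tensor built with an explicit dtype: both Where inputs stay float32.
fill = torch.tensor(float("inf"), dtype=anchors.dtype, device=anchors.device)
out_fixed = torch.where(valid_mask, anchors, fill)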
Suggested change:
-        anchors = torch.where(valid_mask, anchors, torch.tensor(float("inf"), dtype=torch.float32, device=device))
+        anchors = torch.where(valid_mask, anchors, torch.tensor(float("inf"), dtype=dtype, device=device))
Added some docs in f157406 :) Tell me what you think.
Thanks for your reply! Yes, it would be interesting. What do you think about adding an "ONNX Tips" section? See:

from transformers.models.rt_detr import RTDetrConfig, RTDetrOnnxConfig  # type: ignore[import-untyped]
from transformers.onnx.convert import export_pytorch  # type: ignore[import-untyped]

rtdetr_onnx_config = RTDetrOnnxConfig(config=RTDetrConfig(), task="object-detection")
export_pytorch(
    preprocessor=preprocessor,  # e.g. RTDetrImageProcessor.from_pretrained(...)
    model=model,                # e.g. RTDetrForObjectDetection.from_pretrained(...)
    config=rtdetr_onnx_config,
    opset=17,
    output=output_path,         # pathlib.Path to the target .onnx file
    tokenizer=None,
    device="cuda",
)
Hi @YHallouard, thanks a lot for your contribution! Overall looks good to me!
The only concern I have is that ONNX export in transformers is no longer maintained, and at some point we might end up removing it.
It would be better to add this config to Optimum, leaving here just a snippet of code showing how to export the model using Optimum.
cc @ArthurZucker regarding onnx
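(For reference, a rough sketch of what such a snippet could look like; the checkpoint and task names are assumptions, and the whole thing depends on RT-DETR ONNX export support actually landing in Optimum.)

# Sketch only: assumes RT-DETR ONNX export support is available in Optimum.
from optimum.exporters.onnx import main_export

main_export(
    model_name_or_path="PekingU/rtdetr_r50vd",
    output="rtdetr_onnx",
    task="object-detection",
    opset=17,
)
# CLI equivalent:
#   optimum-cli export onnx --model PekingU/rtdetr_r50vd --task object-detection rtdetr_onnx/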
@@ -1736,7 +1736,7 @@ def generate_anchors(self, spatial_shapes=None, grid_size=0.05, device="cpu", dtype=torch.float32):
         anchors = torch.concat(anchors, 1)
         valid_mask = ((anchors > eps) * (anchors < 1 - eps)).all(-1, keepdim=True)
         anchors = torch.log(anchors / (1 - anchors))
-        anchors = torch.where(valid_mask, anchors, torch.inf)
+        anchors = torch.where(valid_mask, anchors, torch.tensor(float("inf"), dtype=dtype, device=device))
cc @yonigozlan re dtype/device for torch.compile
Hi @qubvel, honestly I wasn't aware of this repository! Very interesting discovery! But I can rework the ONNX config there.
Hey @YHallouard, sorry for the confusion, I think we need better docs on this 😢 Let's not add anything that we know is already deprecated! If you need guidance on ONNX and Optimum contribution for this, I am sure @michaelbenayoun will be happy to help!
Force-pushed from 2b90511 to e38ba18
Hi @ArthurZucker, @michaelbenayoun, no problem, I removed the ONNX config and opened a pull request in Optimum: huggingface/optimum#2040
Thanks! 🤗
@@ -1736,7 +1736,7 @@ def generate_anchors(self, spatial_shapes=None, grid_size=0.05, device="cpu", dtype=torch.float32):
         anchors = torch.concat(anchors, 1)
         valid_mask = ((anchors > eps) * (anchors < 1 - eps)).all(-1, keepdim=True)
         anchors = torch.log(anchors / (1 - anchors))
-        anchors = torch.where(valid_mask, anchors, torch.inf)
+        anchors = torch.where(valid_mask, anchors, torch.tensor(float("inf"), dtype=dtype, device=device))
Suggested change:
-        anchors = torch.where(valid_mask, anchors, torch.tensor(float("inf"), dtype=dtype, device=device))
+        anchors = torch.where(valid_mask, anchors, torch.tensor(torch.finfo(dtype).min, dtype=dtype, device=device))
If the dtype is the dtype of valid_mask, then this would make more sense! Otherwise the min (float32 -inf) is not going to be the same!
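(To illustrate that point with a trivial check, not from the PR: the floating-point limits depend on the dtype, so taking them from the wrong dtype would bake a mismatched constant into the graph.)

import torch

print(torch.finfo(torch.float32).min)  # -3.4028234663852886e+38
print(torch.finfo(torch.float16).min)  # -65504.0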
dtype is the dtype of anchors; valid_mask is just the condition, so their dtypes only match by construction.
But you're right, torch.finfo(dtype).max is good :)
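(For reference, a sketch of what the line looks like with that suggestion applied, assuming dtype and device are the arguments passed to generate_anchors, i.e. the dtype and device of anchors.)

anchors = torch.where(
    valid_mask,
    anchors,
    torch.tensor(torch.finfo(dtype).max, dtype=dtype, device=device),
)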
Force-pushed from 6b0f18a to 1aa73fe
Hi @ArthurZucker, should I run the slow tests CI?
LGTM now thanks for updating!
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
* feat: [RT-DETR] Add onnx runtime config and fix onnx inference bug Optype (Where)
* fix lint
* use dtype istead of torch.float32
* add doc
* remove onnx config
* use dtype info
* use tensor to fix lint
What does this PR do?
Implement RTDetrOnnxConfig for RT-DETR and fix the following error raised during ONNX inference:
Error: Type parameter (T) of Optype (Where) bound to different types (tensor(float) and tensor(double)) in node (/model/decoder/Where)
Already fixed on lyuwenyu/RT-DETR: lyuwenyu/RT-DETR#307
Before submitting
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case: lyuwenyu/RT-DETR#307 ("Error: Type parameter (T) of Optype (Where) bound to different types (tensor(float) and tensor(double)) in node (/model/decoder/Where)")
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.