Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ONNX export support for GIT #2132

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from

Conversation

marcindulak
Copy link

@marcindulak marcindulak commented Dec 19, 2024

What does this PR do?

Relates to #874

This an unsuccessful attempt to add ONNX export of https://huggingface.co/microsoft/git-base. I've followed https://huggingface.co/docs/optimum/main/en/exporters/onnx/usage_guides/contribute and the recent "Add ONNX export support" PRs by @xenova.

I'm not sure how feasible it is to add GIT exporter, or how much more work is required. I'll need some guidance.

Encountered problems:

I'm using 34b3d8b as the main base of the feature branch

1. hf-internal-testing/tiny-random-GitModel: ValueError: Input image size (64*64) doesn't match model (32*32).
docker exec -it hf bash -ci "optimum-cli export onnx --model hf-internal-testing/tiny-random-GitModel /tmp/tiny-random-GitModel"
/root/venv/lib/python3.11/site-packages/transformers/models/git/modeling_git.py:685: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if not interpolate_pos_encoding and (height != self.image_size or width != self.image_size):
Traceback (most recent call last):
  File "/root/venv/bin/optimum-cli", line 8, in <module>
    sys.exit(main())
             ^^^^^^
  File "/root/venv/lib/python3.11/site-packages/optimum/commands/optimum_cli.py", line 208, in main
    service.run()
  File "/root/venv/lib/python3.11/site-packages/optimum/commands/export/onnx.py", line 265, in run
    main_export(
  File "/root/venv/lib/python3.11/site-packages/optimum/exporters/onnx/__main__.py", line 373, in main_export
    onnx_export_from_model(
  File "/root/venv/lib/python3.11/site-packages/optimum/exporters/onnx/convert.py", line 1197, in onnx_export_from_model
    _, onnx_outputs = export_models(
                      ^^^^^^^^^^^^^^
  File "/root/venv/lib/python3.11/site-packages/optimum/exporters/onnx/convert.py", line 783, in export_models
    export(
  File "/root/venv/lib/python3.11/site-packages/optimum/exporters/onnx/convert.py", line 888, in export
    export_output = export_pytorch(
                    ^^^^^^^^^^^^^^^
  File "/root/venv/lib/python3.11/site-packages/optimum/exporters/onnx/convert.py", line 584, in export_pytorch
    onnx_export(
  File "/root/venv/lib/python3.11/site-packages/torch/onnx/__init__.py", line 375, in export
    export(
  File "/root/venv/lib/python3.11/site-packages/torch/onnx/utils.py", line 502, in export
    _export(
  File "/root/venv/lib/python3.11/site-packages/torch/onnx/utils.py", line 1564, in _export
    graph, params_dict, torch_out = _model_to_graph(
                                    ^^^^^^^^^^^^^^^^
  File "/root/venv/lib/python3.11/site-packages/torch/onnx/utils.py", line 1113, in _model_to_graph
    graph, params, torch_out, module = _create_jit_graph(model, args)
                                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/venv/lib/python3.11/site-packages/torch/onnx/utils.py", line 997, in _create_jit_graph
    graph, torch_out = _trace_and_get_graph_from_model(model, args)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/venv/lib/python3.11/site-packages/torch/onnx/utils.py", line 904, in _trace_and_get_graph_from_model
    trace_graph, torch_out, inputs_states = torch.jit._get_trace_graph(
                                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/venv/lib/python3.11/site-packages/torch/jit/_trace.py", line 1500, in _get_trace_graph
    outs = ONNXTracedModule(
           ^^^^^^^^^^^^^^^^^
  File "/root/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/venv/lib/python3.11/site-packages/torch/jit/_trace.py", line 139, in forward
    graph, out = torch._C._create_graph_by_tracing(
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/venv/lib/python3.11/site-packages/torch/jit/_trace.py", line 130, in wrapper
    outs.append(self.inner(*trace_inputs))
                ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1726, in _slow_forward
    result = self.forward(*input, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/venv/lib/python3.11/site-packages/optimum/exporters/onnx/model_patcher.py", line 151, in patched_forward
    outputs = self.orig_forward(*args, **kwargs)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/venv/lib/python3.11/site-packages/transformers/models/git/modeling_git.py", line 1300, in forward
    visual_features = self.image_encoder(
                      ^^^^^^^^^^^^^^^^^^^
  File "/root/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1726, in _slow_forward
    result = self.forward(*input, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/venv/lib/python3.11/site-packages/transformers/models/git/modeling_git.py", line 1100, in forward
    return self.vision_model(
           ^^^^^^^^^^^^^^^^^^
  File "/root/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1726, in _slow_forward
    result = self.forward(*input, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/venv/lib/python3.11/site-packages/transformers/models/git/modeling_git.py", line 1025, in forward
    hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1726, in _slow_forward
    result = self.forward(*input, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/venv/lib/python3.11/site-packages/transformers/models/git/modeling_git.py", line 686, in forward
    raise ValueError(
ValueError: Input image size (64*64) doesn't match model (32*32).
2. hf-internal-testing/tiny-random-GitForCausalLM: KeyError: "Unknown task: image-text-to-text.
docker exec -it hf bash -ci "optimum-cli export onnx --model hf-internal-testing/tiny-random-GitForCausalLM /tmp/tiny-random-GitForCausalLM"
Traceback (most recent call last):
  File "/root/venv/bin/optimum-cli", line 8, in <module>
    sys.exit(main())
             ^^^^^^
  File "/root/venv/lib/python3.11/site-packages/optimum/commands/optimum_cli.py", line 208, in main
    service.run()
  File "/root/venv/lib/python3.11/site-packages/optimum/commands/export/onnx.py", line 265, in run
    main_export(
  File "/root/venv/lib/python3.11/site-packages/optimum/exporters/onnx/__main__.py", line 303, in main_export
    model = TasksManager.get_model_from_task(
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/venv/lib/python3.11/site-packages/optimum/exporters/tasks.py", line 2150, in get_model_from_task
    model_class = TasksManager.get_model_class_for_task(
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/venv/lib/python3.11/site-packages/optimum/exporters/tasks.py", line 1470, in get_model_class_for_task
    raise KeyError(
KeyError: "Unknown task: image-text-to-text. Possible values are: `audio-classification` for AutoModelForAudioClassification, `audio-frame-classification` for AutoModelForAudioFrameClassification, `audio-xvector` for AutoModelForAudioXVector, `automatic-speech-recognition` for ('AutoModelForSpeechSeq2Seq', 'AutoModelForCTC'), `depth-estimation` for AutoModelForDepthEstimation, `feature-extraction` for AutoModel, `fill-mask` for AutoModelForMaskedLM, `image-classification` for AutoModelForImageClassification, `image-segmentation` for ('AutoModelForImageSegmentation', 'AutoModelForSemanticSegmentation'), `image-to-image` for AutoModelForImageToImage, `image-to-text` for ('AutoModelForVision2Seq', 'AutoModel'), `mask-generation` for AutoModel, `masked-im` for AutoModelForMaskedImageModeling, `multiple-choice` for AutoModelForMultipleChoice, `object-detection` for AutoModelForObjectDetection, `question-answering` for AutoModelForQuestionAnswering, `reinforcement-learning` for AutoModel, `semantic-segmentation` for AutoModelForSemanticSegmentation, `text-to-audio` for ('AutoModelForTextToSpectrogram', 'AutoModelForTextToWaveform'), `text-generation` for AutoModelForCausalLM, `text2text-generation` for AutoModelForSeq2SeqLM, `text-classification` for AutoModelForSequenceClassification, `token-classification` for AutoModelForTokenClassification, `zero-shot-image-classification` for AutoModelForZeroShotImageClassification, `zero-shot-object-detection` for AutoModelForZeroShotObjectDetection"
3. microsoft/git-base: ValueError: Input image size (64*64) doesn't match model (224*224).
docker exec -it hf bash -ci "optimum-cli export onnx --model microsoft/git-base /tmp/git-base"
/root/venv/lib/python3.11/site-packages/transformers/models/git/modeling_git.py:685: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if not interpolate_pos_encoding and (height != self.image_size or width != self.image_size):
Traceback (most recent call last):
  File "/root/venv/bin/optimum-cli", line 8, in <module>
    sys.exit(main())
             ^^^^^^
  File "/root/venv/lib/python3.11/site-packages/optimum/commands/optimum_cli.py", line 208, in main
    service.run()
  File "/root/venv/lib/python3.11/site-packages/optimum/commands/export/onnx.py", line 265, in run
    main_export(
  File "/root/venv/lib/python3.11/site-packages/optimum/exporters/onnx/__main__.py", line 373, in main_export
    onnx_export_from_model(
  File "/root/venv/lib/python3.11/site-packages/optimum/exporters/onnx/convert.py", line 1197, in onnx_export_from_model
    _, onnx_outputs = export_models(
                      ^^^^^^^^^^^^^^
  File "/root/venv/lib/python3.11/site-packages/optimum/exporters/onnx/convert.py", line 783, in export_models
    export(
  File "/root/venv/lib/python3.11/site-packages/optimum/exporters/onnx/convert.py", line 888, in export
    export_output = export_pytorch(
                    ^^^^^^^^^^^^^^^
  File "/root/venv/lib/python3.11/site-packages/optimum/exporters/onnx/convert.py", line 584, in export_pytorch
    onnx_export(
  File "/root/venv/lib/python3.11/site-packages/torch/onnx/__init__.py", line 375, in export
    export(
  File "/root/venv/lib/python3.11/site-packages/torch/onnx/utils.py", line 502, in export
    _export(
  File "/root/venv/lib/python3.11/site-packages/torch/onnx/utils.py", line 1564, in _export
    graph, params_dict, torch_out = _model_to_graph(
                                    ^^^^^^^^^^^^^^^^
  File "/root/venv/lib/python3.11/site-packages/torch/onnx/utils.py", line 1113, in _model_to_graph
    graph, params, torch_out, module = _create_jit_graph(model, args)
                                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/venv/lib/python3.11/site-packages/torch/onnx/utils.py", line 997, in _create_jit_graph
    graph, torch_out = _trace_and_get_graph_from_model(model, args)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/venv/lib/python3.11/site-packages/torch/onnx/utils.py", line 904, in _trace_and_get_graph_from_model
    trace_graph, torch_out, inputs_states = torch.jit._get_trace_graph(
                                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/venv/lib/python3.11/site-packages/torch/jit/_trace.py", line 1500, in _get_trace_graph
    outs = ONNXTracedModule(
           ^^^^^^^^^^^^^^^^^
  File "/root/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/venv/lib/python3.11/site-packages/torch/jit/_trace.py", line 139, in forward
    graph, out = torch._C._create_graph_by_tracing(
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/venv/lib/python3.11/site-packages/torch/jit/_trace.py", line 130, in wrapper
    outs.append(self.inner(*trace_inputs))
                ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1726, in _slow_forward
    result = self.forward(*input, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/venv/lib/python3.11/site-packages/optimum/exporters/onnx/model_patcher.py", line 151, in patched_forward
    outputs = self.orig_forward(*args, **kwargs)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/venv/lib/python3.11/site-packages/transformers/models/git/modeling_git.py", line 1570, in forward
    outputs = self.git(
              ^^^^^^^^^
  File "/root/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1726, in _slow_forward
    result = self.forward(*input, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/venv/lib/python3.11/site-packages/transformers/models/git/modeling_git.py", line 1300, in forward
    visual_features = self.image_encoder(
                      ^^^^^^^^^^^^^^^^^^^
  File "/root/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1726, in _slow_forward
    result = self.forward(*input, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/venv/lib/python3.11/site-packages/transformers/models/git/modeling_git.py", line 1100, in forward
    return self.vision_model(
           ^^^^^^^^^^^^^^^^^^
  File "/root/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1726, in _slow_forward
    result = self.forward(*input, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/venv/lib/python3.11/site-packages/transformers/models/git/modeling_git.py", line 1025, in forward
    hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1726, in _slow_forward
    result = self.forward(*input, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/venv/lib/python3.11/site-packages/transformers/models/git/modeling_git.py", line 686, in forward
    raise ValueError(
ValueError: Input image size (64*64) doesn't match model (224*224).

To reproduce

Setup optimum in docker
  1. Use the branch
git clone https://github.com/marcindulak/optimum
cd optimum
git checkout add-git
cd -
  1. Install optimum dependencies in a docker container
docker run --detach --volume $PWD:/opt/hf --env HF_HOME=/opt/hf --env OPENBLAS_NUM_THREADS=1 --name hf debian:stable bash -c "sleep infinity"
docker exec -it hf bash -c "apt-get update && apt-get install --yes git python-is-python3 python3-dev python3-pip python3-venv"
docker exec -it hf bash -c "cd ~/ && python -m venv venv && . venv/bin/activate"
docker exec -it hf bash -c "echo 'if [ -f ~/venv/bin/activate  ]; then . ~/venv/bin/activate; fi' >> ~/.bashrc"
docker exec -it hf bash -ci "python -m pip install torch==2.* transformers==4.* onnx"
  1. Use the optimum branch inside of the container
docker exec -it hf bash -ci "python -m pip install /opt/hf/optimum"
docker exec -it hf bash -ci "python -m pip freeze"

Output

certifi==2024.12.14
charset-normalizer==3.4.0
filelock==3.16.1
fsspec==2024.12.0
huggingface-hub==0.27.0
idna==3.10
Jinja2==3.1.4
MarkupSafe==3.0.2
mpmath==1.3.0
networkx==3.4.2
numpy==2.2.0
onnx==1.17.0
optimum @ file:///opt/hf/optimum
packaging==24.2
protobuf==5.29.2
PyYAML==6.0.2
regex==2024.11.6
requests==2.32.3
safetensors==0.4.5
sympy==1.13.1
tokenizers==0.21.0
torch==2.5.1
tqdm==4.67.1
transformers==4.47.1
typing_extensions==4.12.2
urllib3==2.2.3
  1. Test
docker exec -it hf bash -ci "optimum-cli export onnx --model hf-internal-testing/tiny-random-GitModel /tmp/tiny-random-GitModel"
docker exec -it hf bash -ci "optimum-cli export onnx --model hf-internal-testing/tiny-random-GitForCausalLM /tmp/tiny-random-GitForCausalLM"
docker exec -it hf bash -ci "optimum-cli export onnx --model microsoft/git-base /tmp/git-base"

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?

Who can review?

@echarlaix

Copy link
Collaborator

@echarlaix echarlaix left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks a lot for your contribution @marcindulak!

Comment on lines 701 to 705
"git-vision-model": supported_tasks_mapping(
"feature-extraction",
"image-to-text",
onnx="GITVisionModelOnnxConfig",
),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

looks like this model type doesn't exist so can be removed

Suggested change
"git-vision-model": supported_tasks_mapping(
"feature-extraction",
"image-to-text",
onnx="GITVisionModelOnnxConfig",
),

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed in marcindulak@ea2321c

I imagined that git-vision-model is expected since there is a separate clip-vision-model.
The docs show CLIPVisionModel
https://huggingface.co/docs/transformers/main/model_doc/clip#transformers.CLIPVisionModel
and GitVisionModel
https://huggingface.co/docs/transformers/main/model_doc/git#transformers.GitVisionModel
so I thought the setup will be similar.

I see git_vision_model in https://huggingface.co/microsoft/git-large/blob/main/config.json, but it's nested under vision_config. Is this the reason why there is no separate OnnxConfig needed?

Comment on lines 2638 to 2643
class GITVisionModelOnnxConfig(VisionOnnxConfig):
NORMALIZED_CONFIG_CLASS = NormalizedVisionConfig

@property
def inputs(self) -> Dict[str, Dict[int, str]]:
return {"pixel_values": {0: "batch_size", 1: "num_channels", 2: "height", 3: "width"}}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this case should be included in GITOnnxConfig directly depending on self.task no ? if not then in which case should the model be exported with input_ids as input?

Suggested change
class GITVisionModelOnnxConfig(VisionOnnxConfig):
NORMALIZED_CONFIG_CLASS = NormalizedVisionConfig
@property
def inputs(self) -> Dict[str, Dict[int, str]]:
return {"pixel_values": {0: "batch_size", 1: "num_channels", 2: "height", 3: "width"}}

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Attempted to use self.task in marcindulak@ea2321c

Problems:

1. ValueError: You have to specify either input_ids or inputs_embeds
optimum-cli export onnx --model microsoft/git-base /tmp/git-base
image-to-text <class 'optimum.utils.input_generators.DummyVisionInputGenerator'>
Traceback (most recent call last):
  File "/root/venv/bin/optimum-cli", line 8, in <module>
    sys.exit(main())
             ^^^^^^
  File "/root/venv/lib/python3.11/site-packages/optimum/commands/optimum_cli.py", line 208, in main
    service.run()
  File "/root/venv/lib/python3.11/site-packages/optimum/commands/export/onnx.py", line 265, in run
    main_export(
  File "/root/venv/lib/python3.11/site-packages/optimum/exporters/onnx/__main__.py", line 373, in main_export
    onnx_export_from_model(
  File "/root/venv/lib/python3.11/site-packages/optimum/exporters/onnx/convert.py", line 1176, in onnx_export_from_model
    _, onnx_outputs = export_models(
                      ^^^^^^^^^^^^^^
  File "/root/venv/lib/python3.11/site-packages/optimum/exporters/onnx/convert.py", line 762, in export_models
    export(
  File "/root/venv/lib/python3.11/site-packages/optimum/exporters/onnx/convert.py", line 867, in export
    export_output = export_pytorch(
                    ^^^^^^^^^^^^^^^
  File "/root/venv/lib/python3.11/site-packages/optimum/exporters/onnx/convert.py", line 563, in export_pytorch
    onnx_export(
  File "/root/venv/lib/python3.11/site-packages/torch/onnx/__init__.py", line 375, in export
    export(
  File "/root/venv/lib/python3.11/site-packages/torch/onnx/utils.py", line 502, in export
    _export(
  File "/root/venv/lib/python3.11/site-packages/torch/onnx/utils.py", line 1564, in _export
    graph, params_dict, torch_out = _model_to_graph(
                                    ^^^^^^^^^^^^^^^^
  File "/root/venv/lib/python3.11/site-packages/torch/onnx/utils.py", line 1113, in _model_to_graph
    graph, params, torch_out, module = _create_jit_graph(model, args)
                                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/venv/lib/python3.11/site-packages/torch/onnx/utils.py", line 997, in _create_jit_graph
    graph, torch_out = _trace_and_get_graph_from_model(model, args)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/venv/lib/python3.11/site-packages/torch/onnx/utils.py", line 904, in _trace_and_get_graph_from_model
    trace_graph, torch_out, inputs_states = torch.jit._get_trace_graph(
                                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/venv/lib/python3.11/site-packages/torch/jit/_trace.py", line 1500, in _get_trace_graph
    outs = ONNXTracedModule(
           ^^^^^^^^^^^^^^^^^
  File "/root/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/venv/lib/python3.11/site-packages/torch/jit/_trace.py", line 139, in forward
    graph, out = torch._C._create_graph_by_tracing(
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/venv/lib/python3.11/site-packages/torch/jit/_trace.py", line 130, in wrapper
    outs.append(self.inner(*trace_inputs))
                ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1726, in _slow_forward
    result = self.forward(*input, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/venv/lib/python3.11/site-packages/optimum/exporters/onnx/model_patcher.py", line 151, in patched_forward
    outputs = self.orig_forward(*args, **kwargs)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/venv/lib/python3.11/site-packages/transformers/models/git/modeling_git.py", line 1570, in forward
    outputs = self.git(
              ^^^^^^^^^
  File "/root/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1726, in _slow_forward
    result = self.forward(*input, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/venv/lib/python3.11/site-packages/transformers/models/git/modeling_git.py", line 1276, in forward
    raise ValueError("You have to specify either input_ids or inputs_embeds")
ValueError: You have to specify either input_ids or inputs_embeds

It looks like all three types of tasks "feature-extraction", "image-text-to-text", "image-to-text" want "input_ids" as input. Could it be due to the use of TextAndVisionOnnxConfig as the base class?

class GITOnnxConfig(TextAndVisionOnnxConfig):
    NORMALIZED_CONFIG_CLASS = NormalizedTextAndVisionConfig.with_args(vision_config="vision_config")
2. We don't have an op for aten::full but it isn't a special case. Argument types: int[], bool, NoneType, NoneType, Device, bool
optimum-cli export onnx --model hf-internal-testing/tiny-random-GitForCausalLM /tmp/tiny-random-GitForCausalLM
image-text-to-text <class 'optimum.utils.input_generators.DummyTextInputGenerator'>
image-text-to-text <class 'optimum.utils.input_generators.DummyVisionInputGenerator'>
/root/venv/lib/python3.11/site-packages/transformers/models/git/modeling_git.py:685: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if not interpolate_pos_encoding and (height != self.image_size or width != self.image_size):
/root/venv/lib/python3.11/site-packages/transformers/models/git/modeling_git.py:695: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if interpolate_pos_encoding:
/root/venv/lib/python3.11/site-packages/transformers/models/git/modeling_git.py:768: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
/root/venv/lib/python3.11/site-packages/transformers/models/git/modeling_git.py:808: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
Traceback (most recent call last):
  File "/root/venv/bin/optimum-cli", line 8, in <module>
    sys.exit(main())
             ^^^^^^
  File "/root/venv/lib/python3.11/site-packages/optimum/commands/optimum_cli.py", line 208, in main
    service.run()
  File "/root/venv/lib/python3.11/site-packages/optimum/commands/export/onnx.py", line 265, in run
    main_export(
  File "/root/venv/lib/python3.11/site-packages/optimum/exporters/onnx/__main__.py", line 373, in main_export
    onnx_export_from_model(
  File "/root/venv/lib/python3.11/site-packages/optimum/exporters/onnx/convert.py", line 1176, in onnx_export_from_model
    _, onnx_outputs = export_models(
                      ^^^^^^^^^^^^^^
  File "/root/venv/lib/python3.11/site-packages/optimum/exporters/onnx/convert.py", line 762, in export_models
    export(
  File "/root/venv/lib/python3.11/site-packages/optimum/exporters/onnx/convert.py", line 867, in export
    export_output = export_pytorch(
                    ^^^^^^^^^^^^^^^
  File "/root/venv/lib/python3.11/site-packages/optimum/exporters/onnx/convert.py", line 563, in export_pytorch
    onnx_export(
  File "/root/venv/lib/python3.11/site-packages/torch/onnx/__init__.py", line 375, in export
    export(
  File "/root/venv/lib/python3.11/site-packages/torch/onnx/utils.py", line 502, in export
    _export(
  File "/root/venv/lib/python3.11/site-packages/torch/onnx/utils.py", line 1564, in _export
    graph, params_dict, torch_out = _model_to_graph(
                                    ^^^^^^^^^^^^^^^^
  File "/root/venv/lib/python3.11/site-packages/torch/onnx/utils.py", line 1113, in _model_to_graph
    graph, params, torch_out, module = _create_jit_graph(model, args)
                                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/venv/lib/python3.11/site-packages/torch/onnx/utils.py", line 997, in _create_jit_graph
    graph, torch_out = _trace_and_get_graph_from_model(model, args)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/venv/lib/python3.11/site-packages/torch/onnx/utils.py", line 904, in _trace_and_get_graph_from_model
    trace_graph, torch_out, inputs_states = torch.jit._get_trace_graph(
                                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/venv/lib/python3.11/site-packages/torch/jit/_trace.py", line 1500, in _get_trace_graph
    outs = ONNXTracedModule(
           ^^^^^^^^^^^^^^^^^
  File "/root/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/venv/lib/python3.11/site-packages/torch/jit/_trace.py", line 139, in forward
    graph, out = torch._C._create_graph_by_tracing(
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: 0 INTERNAL ASSERT FAILED at "/pytorch/torch/csrc/jit/ir/alias_analysis.cpp":617, please report a bug to PyTorch. We don't have an op for aten::full but it isn't a special case.  Argument types: int[], bool, NoneType, NoneType, Device, bool, 

Candidates:
	aten::full.names(int[] size, Scalar fill_value, *, str[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
	aten::full(SymInt[] size, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
	aten::full.names_out(int[] size, Scalar fill_value, *, str[]? names, Tensor(a!) out) -> Tensor(a!)
	aten::full.out(SymInt[] size, Scalar fill_value, *, Tensor(a!) out) -> Tensor(a!)

Is this the case of pytorch/pytorch#137202
pytorch/pytorch#130229, or some misconfiguration?



class GITOnnxConfig(VisionOnnxConfig):
NORMALIZED_CONFIG_CLASS = NormalizedVisionConfig
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The issue you're reporting ValueError: Input image size (64*64) doesn't match model (32*32). should be fixed if you replace the config with :

Suggested change
NORMALIZED_CONFIG_CLASS = NormalizedVisionConfig
NormalizedTextAndVisionConfig.with_args(vision_config="vision_config")

as it seems that for GitConfig the image_size attribute needs to be taken from the vision_config directly https://github.com/huggingface/transformers/blob/504c4d36929b6bb8a8c2ecfad0f2625f4075f22a/src/transformers/models/git/configuration_git.py#L98.

What is currently happening is that before export this value is not found in the config and is default to 64

if normalized_config.has_attribute("image_size"):
when it should be set to 32 in your case https://huggingface.co/hf-internal-testing/tiny-random-GitModel/blob/main/config.json#L52

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@@ -692,6 +692,17 @@ class TasksManager:
"text-classification",
onnx="GemmaOnnxConfig",
),
"git": supported_tasks_mapping(
"feature-extraction",
"image-text-to-text",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The issue
KeyError: "Unknown task: image-text-to-text. Possible values are: `audio-classification` for AutoModelForAudioClassification, `audio-frame-classification` for AutoModelForAudioFrameClassification, `audio-xvector` for AutoModelForAudioXVector, `automatic-speech-recognition` for ('AutoModelForSpeechSeq2Seq', 'AutoModelForCTC'), `depth-estimation` for AutoModelForDepthEstimation, `feature-extraction` for AutoModel, `fill-mask` for AutoModelForMaskedLM, `image-classification` for AutoModelForImageClassification, `image-segmentation` for ('AutoModelForImageSegmentation', 'AutoModelForSemanticSegmentation'), `image-to-image` for AutoModelForImageToImage, `image-to-text` for ('AutoModelForVision2Seq', 'AutoModel'), `mask-generation` for AutoModel, `masked-im` for AutoModelForMaskedImageModeling, `multiple-choice` for AutoModelForMultipleChoice, `object-detection` for AutoModelForObjectDetection, `question-answering` for AutoModelForQuestionAnswering, `reinforcement-learning` for AutoModel, `semantic-segmentation` for AutoModelForSemanticSegmentation, `text-to-audio` for ('AutoModelForTextToSpectrogram', 'AutoModelForTextToWaveform'), `text-generation` for AutoModelForCausalLM, `text2text-generation` for AutoModelForSeq2SeqLM, `text-classification` for AutoModelForSequenceClassification, `token-classification` for AutoModelForTokenClassification, `zero-shot-image-classification` for AutoModelForZeroShotImageClassification, `zero-shot-object-detection` for AutoModelForZeroShotObjectDetection"

comes from the fact that we don't yet support the "image-text-to-text" task but can be added here

_TRANSFORMERS_TASKS_TO_MODEL_LOADERS = {

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@marcindulak marcindulak marked this pull request as draft December 20, 2024 19:57
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants