[Relax][PyTorch] Add extra_buffers support for exported program #18503
Code Review
This pull request aims to add support for extra non-persistent buffers in the PyTorch frontend. While the intention is good, the current implementation has a couple of issues. First, the `extra_buffers` are hardcoded, which limits the feature's reusability. It would be better to pass them as a parameter. Second, the logic to handle these extra buffers is incorrect and will lead to runtime errors due to an undefined variable and incorrect attribute access. I've provided suggestions to address these points.
```python
# Buffers
info = None
if spec.target in merged_state:
    info = merged_state[spec.target]
elif spec.target.split(".")[-1] in merged_state:
    info = merged_state[spec.target.split(".")[-1]]
if info is None:
    raise KeyError(f"Missing target in state_dict or extra buffers: {spec.target}")

# Handle both original and extra buffer
if hasattr(info, "shape") and hasattr(info, "dtype"):
    torch_shape = info.shape
    torch_dtype = info.dtype
```
This block has a couple of critical issues:

- The `merged_state` variable is used but not defined, which will cause a `NameError`. You should define it before the loop by merging `exported_program.state_dict` and `extra_buffers`. For example: `merged_state = {**exported_program.state_dict, **extra_buffers}`.
- The logic to get shape and dtype using `hasattr(info, "shape")` will not work for the dictionaries in `extra_buffers`, as dictionaries do not have `shape` and `dtype` attributes. This will leave `torch_shape` and `torch_dtype` undefined, causing another `NameError`.

Here is a suggested implementation that fixes these issues. Please note that you still need to define `merged_state` before this block.
```diff
-# Buffers
-info = None
-if spec.target in merged_state:
-    info = merged_state[spec.target]
-elif spec.target.split(".")[-1] in merged_state:
-    info = merged_state[spec.target.split(".")[-1]]
-if info is None:
-    raise KeyError(f"Missing target in state_dict or extra buffers: {spec.target}")
-# Handle both original and extra buffer
-if hasattr(info, "shape") and hasattr(info, "dtype"):
-    torch_shape = info.shape
-    torch_dtype = info.dtype
+info = merged_state.get(spec.target)
+if info is None:
+    short_name = spec.target.split(".")[-1]
+    info = merged_state.get(short_name)
+if info is None:
+    raise KeyError(f"Missing target in state_dict or extra buffers: {spec.target}")
+if isinstance(info, torch.Tensor):
+    torch_shape = info.shape
+    torch_dtype = info.dtype
+elif isinstance(info, dict):
+    torch_shape = info["shape"]
+    torch_dtype = info["dtype"]
+else:
+    raise TypeError(f"Unsupported type for buffer/parameter info: {type(info)}")
```
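As a minimal sketch of that missing definition (assuming `exported_program` and the user-supplied `extra_buffers` dict are both in scope where the buffer loop begins), `merged_state` could be built once up front:

```python
# Sketch only: merge the exported state_dict with the user-provided extra
# buffer descriptions into one lookup table; extra_buffers wins on name
# collisions. Placement before the buffer-resolution loop is an assumption.
merged_state = {**exported_program.state_dict, **extra_buffers}
```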
```python
extra_buffers = {
    "position_ids": {"shape": (1, 128), "dtype": torch.int64},
    "token_type_ids": {"shape": (1, 128), "dtype": torch.int64},
}
```
This `extra_buffers` dictionary is hardcoded, which limits the utility of this feature to a specific model configuration (BERT with sequence length 128). To make this more general, `extra_buffers` should be an optional parameter to `from_exported_program` and passed down to this function. This would allow users to provide their own non-persistent buffers for different models.
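A minimal sketch of that parameterization (the keyword-only signature and type hints are assumptions, not the current `from_exported_program` API):

```python
from typing import Any, Dict, Optional

def from_exported_program(
    exported_program,
    *,
    extra_buffers: Optional[Dict[str, Dict[str, Any]]] = None,
):
    # Hypothetical parameter: maps a buffer name to a {"shape": ..., "dtype": ...}
    # description for non-persistent buffers that the exported state_dict omits.
    extra_buffers = extra_buffers or {}
    # Forwarded into the buffer-resolution logic shown above, where it is
    # merged with exported_program.state_dict.
    ...
```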
Makes sense to me. Can you also add a test case for that in https://github.com/apache/tvm/blob/main/tests/python/relax/test_frontend_from_exported_program.py?
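A hedged sketch of what such a test might look like (the `extra_buffers` keyword follows this PR's proposal; the module, names, and assertion style are assumptions, not the test file's actual conventions):

```python
import torch
from torch import nn
from tvm.relax.frontend.torch import from_exported_program

def test_extra_buffers_for_non_persistent_buffer():
    class ModelWithNonPersistentBuffer(nn.Module):
        def __init__(self):
            super().__init__()
            # Non-persistent buffers are excluded from the state_dict,
            # which is exactly the case this PR handles.
            self.register_buffer(
                "position_ids", torch.arange(128).unsqueeze(0), persistent=False
            )

        def forward(self, x):
            return x + self.position_ids

    model = ModelWithNonPersistentBuffer()
    example_args = (torch.zeros(1, 128, dtype=torch.int64),)
    exported_program = torch.export.export(model, example_args)

    # Previously this raised a KeyError; with the PR the buffer is resolved
    # through the extra_buffers lookup instead of the state_dict.
    mod = from_exported_program(
        exported_program,
        extra_buffers={"position_ids": {"shape": (1, 128), "dtype": torch.int64}},
    )
    assert mod is not None
```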
Fixes #18357: a KeyError in the Relax PyTorch frontend when loading a HuggingFace BERT model exported with `torch.export`.
The issue occurs because non-persistent buffers, such as `position_ids` and `token_type_ids`, are not included in the `state_dict`, leading to missing-buffer errors.
This PR adds support for extra non-persistent buffers by introducing an `extra_buffers` lookup and extending the buffer-resolution logic to handle their shapes and dtypes correctly.
With this fix, HuggingFace BERT models exported via `torch.export` no longer fail with missing-buffer KeyErrors.
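For illustration, an end-to-end sketch of how the fix would be exercised (the `extra_buffers` keyword follows this PR's proposal; the BERT model and sequence length mirror the reviewers' example, and the rest of the setup is assumed):

```python
import torch
from transformers import BertModel
from tvm.relax.frontend.torch import from_exported_program

model = BertModel.from_pretrained("bert-base-uncased").eval()
input_ids = torch.randint(0, 30522, (1, 128), dtype=torch.int64)
exported_program = torch.export.export(model, (input_ids,))

# position_ids and token_type_ids are non-persistent buffers missing from
# the exported state_dict, so their shapes and dtypes are supplied here.
mod = from_exported_program(
    exported_program,
    extra_buffers={
        "position_ids": {"shape": (1, 128), "dtype": torch.int64},
        "token_type_ids": {"shape": (1, 128), "dtype": torch.int64},
    },
)
```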