Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 101 additions & 0 deletions python/tvm/relax/frontend/torch/position_id_fix.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import torch
import torch.nn as nn
from torch.export import export as torch_export
from transformers import AutoModel

from tvm.relax.frontend.torch import from_exported_program


class StateDictWrapper(dict):
"""Wrap exported state_dict and inject extra keys (non-persistent buffers)."""

def __init__(self, base_dict, extra):
super().__init__(base_dict)
self.extra = extra

def __getitem__(self, key):
if key in self.extra:
return self.extra[key]
return super().__getitem__(key)

def get(self, key, default=None):
if key in self.extra:
return self.extra[key]
return super().get(key, default)


class M(nn.Module):
def __init__(self):
super().__init__()
self.bert = AutoModel.from_pretrained("bert-base-multilingual-uncased")
self.cls = nn.Linear(self.bert.config.hidden_size, 2)

def forward(self, x, mask=None):
out = self.bert(x, attention_mask=mask).last_hidden_state[:, 0, :]
return self.cls(out)


def main():
torch.manual_seed(0)
m = M().eval()

x = torch.randint(0, 30522, (2, 16))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The value 30522 is a magic number, which appears to be the vocabulary size for 'bert-base-multilingual-uncased'. It's better to fetch this value from the model's configuration to improve readability and maintainability. This makes the code more robust if the model changes.

Suggested change
x = torch.randint(0, 30522, (2, 16))
x = torch.randint(0, m.bert.config.vocab_size, (2, 16))

mask = torch.ones_like(x)

ep = torch_export(m, (x, mask))
print("\n torch.export completed successfully\n")

# --- Build extra buffers dict ---
extra = {}
for buf_name in m.bert.embeddings._non_persistent_buffers_set:
tensor = m.bert.embeddings._buffers.get(buf_name)
if tensor is not None:
extra[f"bert.embeddings.{buf_name}"] = tensor
print(f"Injecting buffer: bert.embeddings.{buf_name} -> shape {tensor.shape}")

# Wrap exported state_dict
sd_wrapped = StateDictWrapper(ep.state_dict, extra)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The custom StateDictWrapper class can be replaced by collections.ChainMap for a more concise and idiomatic implementation. ChainMap is designed for linking multiple dictionaries.

After this change, you can remove the StateDictWrapper class definition (lines 8-22) and add import collections to the top of the file.

Suggested change
sd_wrapped = StateDictWrapper(ep.state_dict, extra)
sd_wrapped = collections.ChainMap(extra, ep.state_dict)


# EP wrapper to override state_dict access
class EPWrapper:
def __init__(self, ep, sd_wrapped):
self.__dict__["_ep"] = ep
self.__dict__["_sd"] = sd_wrapped

def __getattr__(self, name):
if name == "state_dict":
return self._sd
return getattr(self._ep, name)

Comment on lines +77 to +86
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This proxy implementation for ExportedProgram is minimal. A more robust and idiomatic way to create this wrapper is by using a property for state_dict. This avoids overriding __getattr__ in a way that could be brittle and makes the intent clearer. The suggested implementation is cleaner and less prone to subtle bugs if the from_exported_program API has more complex interactions with the object.

Suggested change
class EPWrapper:
def __init__(self, ep, sd_wrapped):
self.__dict__["_ep"] = ep
self.__dict__["_sd"] = sd_wrapped
def __getattr__(self, name):
if name == "state_dict":
return self._sd
return getattr(self._ep, name)
class EPWrapper:
def __init__(self, ep, sd_wrapped):
self._ep = ep
self._sd = sd_wrapped
@property
def state_dict(self):
return self._sd
def __getattr__(self, name):
return getattr(self._ep, name)

ep_wrapped = EPWrapper(ep, sd_wrapped)

# Import to TVM
try:
mod = from_exported_program(ep_wrapped)
print("\n TVM import succeeded — all non-persistent buffers injected!\n")
except Exception as e:
print("\n TVM import failed with exception:")
import traceback

traceback.print_exc()


if __name__ == "__main__":
main()
101 changes: 101 additions & 0 deletions python/tvm/relax/frontend/torch/position_id_patch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import torch
import torch.nn as nn
from torch.export import export as torch_export
from transformers import AutoModel

from tvm.relax.frontend.torch import from_exported_program


class StateDictWrapper(dict):
"""Wrap exported state_dict and inject extra keys (non-persistent buffers)."""

def __init__(self, base_dict, extra):
super().__init__(base_dict)
self.extra = extra

def __getitem__(self, key):
if key in self.extra:
return self.extra[key]
return super().__getitem__(key)

def get(self, key, default=None):
if key in self.extra:
return self.extra[key]
return super().get(key, default)


class M(nn.Module):
def __init__(self):
super().__init__()
self.bert = AutoModel.from_pretrained("bert-base-multilingual-uncased")
self.cls = nn.Linear(self.bert.config.hidden_size, 2)

def forward(self, x, mask=None):
out = self.bert(x, attention_mask=mask).last_hidden_state[:, 0, :]
return self.cls(out)


def main():
torch.manual_seed(0)
m = M().eval()

x = torch.randint(0, 30522, (2, 16))
mask = torch.ones_like(x)

ep = torch_export(m, (x, mask))
print("\n torch.export completed successfully\n")

# --- Build extra buffers dict ---
extra = {}
for buf_name in m.bert.embeddings._non_persistent_buffers_set:
tensor = m.bert.embeddings._buffers.get(buf_name)
if tensor is not None:
extra[f"bert.embeddings.{buf_name}"] = tensor
print(f"Injecting buffer: bert.embeddings.{buf_name} -> shape {tensor.shape}")

# Wrap exported state_dict
sd_wrapped = StateDictWrapper(ep.state_dict, extra)

# EP wrapper to override state_dict access
class EPWrapper:
def __init__(self, ep, sd_wrapped):
self.__dict__["_ep"] = ep
self.__dict__["_sd"] = sd_wrapped

def __getattr__(self, name):
if name == "state_dict":
return self._sd
return getattr(self._ep, name)

ep_wrapped = EPWrapper(ep, sd_wrapped)

# Import to TVM
try:
mod = from_exported_program(ep_wrapped)
print("\n TVM import succeeded — all non-persistent buffers injected!\n")
except Exception as e:
print("\n TVM import failed with exception:")
import traceback

traceback.print_exc()


if __name__ == "__main__":
main()
Loading