Skip to content

Commit e114e9c

Browse files
committed
Format position_id_patch.py to pass lint
1 parent 0f3d043 commit e114e9c

File tree

1 file changed

+6
-0
lines changed

1 file changed

+6
-0
lines changed

position_id_patch.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,10 @@
55
from torch.export import export as torch_export
66
from tvm.relax.frontend.torch import from_exported_program
77

8+
89
class StateDictWrapper(dict):
910
"""Wrap exported state_dict and inject extra keys (non-persistent buffers)."""
11+
1012
def __init__(self, base_dict, extra):
1113
super().__init__(base_dict)
1214
self.extra = extra
@@ -21,6 +23,7 @@ def get(self, key, default=None):
2123
return self.extra[key]
2224
return super().get(key, default)
2325

26+
2427
class M(nn.Module):
2528
def __init__(self):
2629
super().__init__()
@@ -31,6 +34,7 @@ def forward(self, x, mask=None):
3134
out = self.bert(x, attention_mask=mask).last_hidden_state[:, 0, :]
3235
return self.cls(out)
3336

37+
3438
def main():
3539
torch.manual_seed(0)
3640
m = M().eval()
@@ -72,7 +76,9 @@ def __getattr__(self, name):
7276
except Exception as e:
7377
print("\n TVM import failed with exception:")
7478
import traceback
79+
7580
traceback.print_exc()
7681

82+
7783
if __name__ == "__main__":
7884
main()

0 commit comments

Comments
 (0)