diff --git a/comet/encoders/bert.py b/comet/encoders/bert.py
index 753dc70..d6b7992 100644
--- a/comet/encoders/bert.py
+++ b/comet/encoders/bert.py
@@ -168,15 +168,24 @@ def forward(
             Dict[str, torch.Tensor]: dictionary with 'sentemb', 'wordemb', 'all_layers'
                 and 'attention_mask'.
         """
-        last_hidden_states, pooler_output, all_layers = self.model(
+        output = self.model(
             input_ids=input_ids,
             token_type_ids=token_type_ids,
             attention_mask=attention_mask,
             output_hidden_states=True,
             return_dict=False,
         )
+        # Handle both transformers 4.x (3 values) and 5.x (2 values when add_pooling_layer=False)
+        if len(output) == 2:
+            last_hidden_states, all_layers = output
+            # Use CLS token as sentence embedding when no pooler output
+            sentemb = last_hidden_states[:, 0, :]
+        else:
+            last_hidden_states, pooler_output, all_layers = output
+            # Use pooler output if available, otherwise use CLS token
+            sentemb = pooler_output if pooler_output is not None else last_hidden_states[:, 0, :]
         return {
-            "sentemb": pooler_output,
+            "sentemb": sentemb,
             "wordemb": last_hidden_states,
             "all_layers": all_layers,
             "attention_mask": attention_mask,
diff --git a/comet/encoders/xlmr.py b/comet/encoders/xlmr.py
index 13cc15a..5ab76c2 100644
--- a/comet/encoders/xlmr.py
+++ b/comet/encoders/xlmr.py
@@ -92,12 +92,17 @@ def from_pretrained(
     def forward(
         self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs
     ) -> Dict[str, torch.Tensor]:
-        last_hidden_states, _, all_layers = self.model(
+        output = self.model(
             input_ids=input_ids,
             attention_mask=attention_mask,
             output_hidden_states=True,
             return_dict=False,
         )
+        # Handle both transformers 4.x (3 values) and 5.x (2 values when add_pooling_layer=False)
+        if len(output) == 2:
+            last_hidden_states, all_layers = output
+        else:
+            last_hidden_states, _, all_layers = output
         return {
             "sentemb": last_hidden_states[:, 0, :],
             "wordemb": last_hidden_states,
diff --git a/pyproject.toml b/pyproject.toml
index 214a48c..d98e997 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -39,7 +39,7 @@ comet-mbr = 'comet.cli.mbr:mbr_command'
 python = "^3.8.0"
 sentencepiece = "^0.2.0"
 pandas = ">=1.4.1"
-transformers = "^4.17"
+transformers = ">=4.17"
 pytorch-lightning = "^2.0.0"
 jsonargparse = "3.13.1"
 torch = ">=1.6.0"
@@ -48,7 +48,7 @@ torchmetrics = "^0.10.2"
 sacrebleu = "^2.0.0"
 scipy = "^1.5.4"
 entmax = "^1.1"
-huggingface-hub = ">=0.19.3,<1.0"
+huggingface-hub = ">=0.19.3"
 protobuf = "^4.24.4"
 
 [tool.poetry.dev-dependencies]
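
For reference, a minimal sketch (not part of the diff) of the tuple-length guard introduced above, exercised with dummy tensors instead of a real model. The names pick_sentemb, fake_output_v4, and fake_output_v5 are hypothetical and used only for illustration:

    import torch

    def pick_sentemb(output):
        # transformers 5.x with add_pooling_layer=False: (last_hidden_state, hidden_states)
        if len(output) == 2:
            last_hidden_states, all_layers = output
            return last_hidden_states[:, 0, :]
        # transformers 4.x: (last_hidden_state, pooler_output, hidden_states)
        last_hidden_states, pooler_output, all_layers = output
        return pooler_output if pooler_output is not None else last_hidden_states[:, 0, :]

    hidden = torch.randn(2, 8, 16)                           # (batch, seq_len, hidden_size)
    layers = tuple(torch.randn(2, 8, 16) for _ in range(3))  # stand-in for all hidden states
    fake_output_v4 = (hidden, torch.randn(2, 16), layers)    # 3-tuple, 4.x-style output
    fake_output_v5 = (hidden, layers)                        # 2-tuple, 5.x-style output
    assert pick_sentemb(fake_output_v4).shape == (2, 16)
    assert pick_sentemb(fake_output_v5).shape == (2, 16)

Both branches yield a (batch, hidden_size) tensor, the same shape the old pooler_output had, so "sentemb" keeps its shape regardless of the installed transformers version.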