Skip to content

Commit dd901de

Browse files
committed
[feat] Allow specifying output_key for mlp head in MMFT
This will allow doing multitasking on single batch easily through different heads with different output_key per task
1 parent 08f062e commit dd901de

File tree

1 file changed

+3
-1
lines changed
  • mmf/models/transformers/heads

1 file changed

+3
-1
lines changed

mmf/models/transformers/heads/mlp.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ class Config(BaseTransformerHead.Config):
2020
hidden_dropout_prob: float = 0.1
2121
layer_norm_eps: float = 1e-6
2222
hidden_act: str = "gelu"
23+
output_key: str = "scores"
2324

2425
def __init__(self, config: Config, *args, **kwargs):
2526
super().__init__(config, *args, **kwargs)
@@ -33,6 +34,7 @@ def __init__(self, config: Config, *args, **kwargs):
3334
)
3435
self.num_labels = self.config.num_labels
3536
self.hidden_size = self.config.hidden_size
37+
self.output_key = self.config.get("output_key", "scores")
3638

3739
def forward(
3840
self,
@@ -46,5 +48,5 @@ def forward(
4648
output_dict = {}
4749
pooled_output = self.pooler(sequence_output)
4850
prediction = self.classifier(pooled_output)
49-
output_dict["scores"] = prediction.view(-1, self.num_labels)
51+
output_dict[self.output_key] = prediction.view(-1, self.num_labels)
5052
return output_dict

0 commit comments

Comments
 (0)