diff --git a/src/modalities/models/huggingface_adapters/hf_adapter.py b/src/modalities/models/huggingface_adapters/hf_adapter.py index 09c751836..e8e5de6de 100644 --- a/src/modalities/models/huggingface_adapters/hf_adapter.py +++ b/src/modalities/models/huggingface_adapters/hf_adapter.py @@ -17,7 +17,7 @@ class HFModelAdapterConfig(PretrainedConfig): model_type = "modalities" - def __init__(self, **kwargs): + def __init__(self, config={}, **kwargs): """ Initializes an HFModelAdapterConfig object. @@ -28,6 +28,7 @@ def __init__(self, **kwargs): ConfigError: If the config is not passed in HFModelAdapterConfig. """ super().__init__(**kwargs) + self.config = config # self.config is added by the super class via kwargs if self.config is None: raise ConfigError("Config is not passed in HFModelAdapterConfig.") @@ -115,7 +116,7 @@ def forward( raise NotImplementedError model_input = {"input_ids": input_ids, "attention_mask": attention_mask} model_forward_output: dict[str, torch.Tensor] = self.model.forward(model_input) - if return_dict: + if not return_dict: return ModalitiesModelOutput(**model_forward_output) else: return model_forward_output[self.prediction_key]