diff --git a/monai/losses/perceptual.py b/monai/losses/perceptual.py index 2ae03bc8dc..d31788c633 100644 --- a/monai/losses/perceptual.py +++ b/monai/losses/perceptual.py @@ -17,13 +17,24 @@ import torch.nn as nn from monai.utils import optional_import + from monai.utils.enums import StrEnum +# Valid model name to download from the repository +HF_MONAI_MODELS = ( + "medicalnet_resnet10_23datasets", + "medicalnet_resnet50_23datasets", + "radimagenet_resnet50", +) + LPIPS, _ = optional_import("lpips", name="LPIPS") torchvision, _ = optional_import("torchvision") + class PercetualNetworkType(StrEnum): + """Types of neural networks that are supported by perceptual loss.""" + alex = "alex" vgg = "vgg" squeeze = "squeeze" @@ -84,13 +95,18 @@ def __init__( if spatial_dims not in [2, 3]: raise NotImplementedError("Perceptual loss is implemented only in 2D and 3D.") - if (spatial_dims == 2 or is_fake_3d) and "medicalnet_" in network_type: - raise ValueError( - "MedicalNet networks are only compatible with ``spatial_dims=3``." - "Argument is_fake_3d must be set to False." - ) - if channel_wise and "medicalnet_" not in network_type: + # Strict validation for MedicalNet + if "medicalnet_" in network_type: + if spatial_dims == 2 or is_fake_3d: + raise ValueError( + "MedicalNet networks are only compatible with ``spatial_dims=3``. Argument is_fake_3d must be set to False." + ) + if not channel_wise: + warnings.warn("MedicalNet networks support channel-wise loss. Consider setting channel_wise=True.") + + # Channel-wise only for MedicalNet + elif channel_wise: raise ValueError("Channel-wise loss is only compatible with MedicalNet networks.") if network_type.lower() not in list(PercetualNetworkType): @@ -108,9 +124,11 @@ def __init__( self.spatial_dims = spatial_dims self.perceptual_function: nn.Module + + # If spatial_dims is 3, only MedicalNet supports 3D models, otherwise, spatial_dims=2 and fake_3D must be used. if spatial_dims == 3 and is_fake_3d is False: self.perceptual_function = MedicalNetPerceptualSimilarity( - net=network_type, verbose=False, channel_wise=channel_wise + net=network_type, verbose=False, channel_wise=channel_wise, cache_dir=cache_dir ) elif "radimagenet_" in network_type: self.perceptual_function = RadImageNetPerceptualSimilarity(net=network_type, verbose=False) @@ -122,7 +140,9 @@ def __init__( pretrained_state_dict_key=pretrained_state_dict_key, ) else: + # VGG, AlexNet and SqueezeNet are independently handled by LPIPS. self.perceptual_function = LPIPS(pretrained=pretrained, net=network_type, verbose=False) + self.is_fake_3d = is_fake_3d self.fake_3d_ratio = fake_3d_ratio self.channel_wise = channel_wise @@ -194,7 +214,7 @@ class MedicalNetPerceptualSimilarity(nn.Module): """ Component to perform the perceptual evaluation with the networks pretrained by Chen, et al. "Med3D: Transfer Learning for 3D Medical Image Analysis". This class uses torch Hub to download the networks from - "Warvito/MedicalNet-models". + "Project-MONAI/perceptual-models". Args: net: {``"medicalnet_resnet10_23datasets"``, ``"medicalnet_resnet50_23datasets"``} @@ -205,11 +225,23 @@ class MedicalNetPerceptualSimilarity(nn.Module): """ def __init__( - self, net: str = "medicalnet_resnet10_23datasets", verbose: bool = False, channel_wise: bool = False + self, + net: str = "medicalnet_resnet10_23datasets", + verbose: bool = False, + channel_wise: bool = False, + cache_dir: str | None = None, ) -> None: super().__init__() torch.hub._validate_not_a_forked_repo = lambda a, b, c: True - self.model = torch.hub.load("warvito/MedicalNet-models", model=net, verbose=verbose, trust_repo=True) + if net not in HF_MONAI_MODELS: + raise ValueError( + f"Invalid download model name '{net}'. Must be one of: {', '.join(HF_MONAI_MODELS)}." + ) + + self.model = torch.hub.load( + "Project-MONAI/perceptual-models:main", model=net, verbose=verbose, cache_dir=cache_dir, + trust_repo=True, + ) self.eval() self.channel_wise = channel_wise @@ -287,7 +319,7 @@ class RadImageNetPerceptualSimilarity(nn.Module): """ Component to perform the perceptual evaluation with the networks pretrained on RadImagenet (pretrained by Mei, et al. "RadImageNet: An Open Radiologic Deep Learning Research Dataset for Effective Transfer Learning"). This class - uses torch Hub to download the networks from "Warvito/radimagenet-models". + uses torch Hub to download the networks from "Project-MONAI/perceptual-models". Args: net: {``"radimagenet_resnet50"``} @@ -295,9 +327,14 @@ class RadImageNetPerceptualSimilarity(nn.Module): verbose: if false, mute messages from torch Hub load function. """ - def __init__(self, net: str = "radimagenet_resnet50", verbose: bool = False) -> None: + def __init__(self, net: str = "radimagenet_resnet50", verbose: bool = False, cache_dir: str | None = None) -> None: super().__init__() - self.model = torch.hub.load("Warvito/radimagenet-models", model=net, verbose=verbose, trust_repo=True) + if net not in HF_MONAI_MODELS: + raise ValueError( + f"Invalid download model name '{net}'. Must be one of: {', '.join(HF_MONAI_MODELS)}." + ) + self.model = torch.hub.load("Project-MONAI/perceptual-models:main", model=net, verbose=verbose, cache_dir=cache_dir, + trust_repo=True) self.eval() for param in self.parameters():