Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 19 additions & 7 deletions monai/losses/perceptual.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@


class PercetualNetworkType(StrEnum):
"""Types of neural networks that are supported by perceptua loss."""

alex = "alex"
vgg = "vgg"
squeeze = "squeeze"
Expand Down Expand Up @@ -108,9 +110,11 @@ def __init__(

self.spatial_dims = spatial_dims
self.perceptual_function: nn.Module

# If spatial_dims is 3, only MedicalNet supports 3D models, otherwise, spatial_dims=2 and fake_3D must be used.
if spatial_dims == 3 and is_fake_3d is False:
self.perceptual_function = MedicalNetPerceptualSimilarity(
net=network_type, verbose=False, channel_wise=channel_wise
net=network_type, verbose=False, channel_wise=channel_wise, cache_dir=cache_dir
)
elif "radimagenet_" in network_type:
self.perceptual_function = RadImageNetPerceptualSimilarity(net=network_type, verbose=False)
Expand All @@ -122,7 +126,9 @@ def __init__(
pretrained_state_dict_key=pretrained_state_dict_key,
)
else:
# VGG, AlexNet and SqueezeNet are independently handled by LPIPS.
self.perceptual_function = LPIPS(pretrained=pretrained, net=network_type, verbose=False)

self.is_fake_3d = is_fake_3d
self.fake_3d_ratio = fake_3d_ratio
self.channel_wise = channel_wise
Expand Down Expand Up @@ -194,7 +200,7 @@ class MedicalNetPerceptualSimilarity(nn.Module):
"""
Component to perform the perceptual evaluation with the networks pretrained by Chen, et al. "Med3D: Transfer
Learning for 3D Medical Image Analysis". This class uses torch Hub to download the networks from
"Warvito/MedicalNet-models".
"Project-MONAI/perceptual-models".

Args:
net: {``"medicalnet_resnet10_23datasets"``, ``"medicalnet_resnet50_23datasets"``}
Expand All @@ -205,11 +211,17 @@ class MedicalNetPerceptualSimilarity(nn.Module):
"""

def __init__(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We might need to check that huggingface_hub is installed as well, the torch.hub.load call will fail if so.

self, net: str = "medicalnet_resnet10_23datasets", verbose: bool = False, channel_wise: bool = False
self,
net: str = "medicalnet_resnet_10_23datasets",
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What we should do to improve security is to move "medicalnet_resnet10_23datasets" and any other valid name into constants then check net is one of them.

verbose: bool = False,
channel_wise: bool = False,
cache_dir: str | None = None,
) -> None:
super().__init__()
torch.hub._validate_not_a_forked_repo = lambda a, b, c: True
self.model = torch.hub.load("warvito/MedicalNet-models", model=net, verbose=verbose, trust_repo=True)
self.model = torch.hub.load(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps this should have trust_repo=True as well since we do trust this repo and it's name is hardcoded here.

"Project-MONAI/perceptual-models:main", model=net, verbose=verbose, cache_dir=cache_dir
)
self.eval()

self.channel_wise = channel_wise
Expand Down Expand Up @@ -287,17 +299,17 @@ class RadImageNetPerceptualSimilarity(nn.Module):
"""
Component to perform the perceptual evaluation with the networks pretrained on RadImagenet (pretrained by Mei, et
al. "RadImageNet: An Open Radiologic Deep Learning Research Dataset for Effective Transfer Learning"). This class
uses torch Hub to download the networks from "Warvito/radimagenet-models".
uses torch Hub to download the networks from "Project-MONAI/perceptual-models".

Args:
net: {``"radimagenet_resnet50"``}
Specifies the network architecture to use. Defaults to ``"radimagenet_resnet50"``.
verbose: if false, mute messages from torch Hub load function.
"""

def __init__(self, net: str = "radimagenet_resnet50", verbose: bool = False) -> None:
def __init__(self, net: str = "radimagenet_resnet50", verbose: bool = False, cache_dir: str | None = None) -> None:
super().__init__()
self.model = torch.hub.load("Warvito/radimagenet-models", model=net, verbose=verbose, trust_repo=True)
self.model = torch.hub.load("Project-MONAI/perceptual-models", model=net, verbose=verbose, cache_dir=cache_dir)
self.eval()

for param in self.parameters():
Expand Down
Loading