-
Notifications
You must be signed in to change notification settings - Fork 1.4k
8627 perceptual loss errors out after hitting the maximum number of downloads #8652
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: dev
Are you sure you want to change the base?
Changes from all commits
b1e4a50
0aeb4d9
fa0639b
685aee2
915de5f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -24,6 +24,8 @@ | |
|
|
||
|
|
||
| class PercetualNetworkType(StrEnum): | ||
| """Types of neural networks that are supported by perceptua loss.""" | ||
|
|
||
| alex = "alex" | ||
| vgg = "vgg" | ||
| squeeze = "squeeze" | ||
|
|
@@ -108,9 +110,11 @@ def __init__( | |
|
|
||
| self.spatial_dims = spatial_dims | ||
| self.perceptual_function: nn.Module | ||
|
|
||
| # If spatial_dims is 3, only MedicalNet supports 3D models, otherwise, spatial_dims=2 and fake_3D must be used. | ||
| if spatial_dims == 3 and is_fake_3d is False: | ||
| self.perceptual_function = MedicalNetPerceptualSimilarity( | ||
| net=network_type, verbose=False, channel_wise=channel_wise | ||
| net=network_type, verbose=False, channel_wise=channel_wise, cache_dir=cache_dir | ||
| ) | ||
| elif "radimagenet_" in network_type: | ||
| self.perceptual_function = RadImageNetPerceptualSimilarity(net=network_type, verbose=False) | ||
|
|
@@ -122,7 +126,9 @@ def __init__( | |
| pretrained_state_dict_key=pretrained_state_dict_key, | ||
| ) | ||
| else: | ||
| # VGG, AlexNet and SqueezeNet are independently handled by LPIPS. | ||
| self.perceptual_function = LPIPS(pretrained=pretrained, net=network_type, verbose=False) | ||
|
|
||
| self.is_fake_3d = is_fake_3d | ||
| self.fake_3d_ratio = fake_3d_ratio | ||
| self.channel_wise = channel_wise | ||
|
|
@@ -194,7 +200,7 @@ class MedicalNetPerceptualSimilarity(nn.Module): | |
| """ | ||
| Component to perform the perceptual evaluation with the networks pretrained by Chen, et al. "Med3D: Transfer | ||
| Learning for 3D Medical Image Analysis". This class uses torch Hub to download the networks from | ||
| "Warvito/MedicalNet-models". | ||
| "Project-MONAI/perceptual-models". | ||
|
|
||
| Args: | ||
| net: {``"medicalnet_resnet10_23datasets"``, ``"medicalnet_resnet50_23datasets"``} | ||
|
|
@@ -205,11 +211,17 @@ class MedicalNetPerceptualSimilarity(nn.Module): | |
| """ | ||
|
|
||
| def __init__( | ||
| self, net: str = "medicalnet_resnet10_23datasets", verbose: bool = False, channel_wise: bool = False | ||
| self, | ||
| net: str = "medicalnet_resnet_10_23datasets", | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What we should do to improve security is to move "medicalnet_resnet10_23datasets" and any other valid name into constants then check |
||
| verbose: bool = False, | ||
| channel_wise: bool = False, | ||
| cache_dir: str | None = None, | ||
| ) -> None: | ||
| super().__init__() | ||
| torch.hub._validate_not_a_forked_repo = lambda a, b, c: True | ||
| self.model = torch.hub.load("warvito/MedicalNet-models", model=net, verbose=verbose, trust_repo=True) | ||
| self.model = torch.hub.load( | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Perhaps this should have |
||
| "Project-MONAI/perceptual-models:main", model=net, verbose=verbose, cache_dir=cache_dir | ||
| ) | ||
| self.eval() | ||
|
|
||
| self.channel_wise = channel_wise | ||
|
|
@@ -287,17 +299,17 @@ class RadImageNetPerceptualSimilarity(nn.Module): | |
| """ | ||
| Component to perform the perceptual evaluation with the networks pretrained on RadImagenet (pretrained by Mei, et | ||
| al. "RadImageNet: An Open Radiologic Deep Learning Research Dataset for Effective Transfer Learning"). This class | ||
| uses torch Hub to download the networks from "Warvito/radimagenet-models". | ||
| uses torch Hub to download the networks from "Project-MONAI/perceptual-models". | ||
|
|
||
| Args: | ||
| net: {``"radimagenet_resnet50"``} | ||
| Specifies the network architecture to use. Defaults to ``"radimagenet_resnet50"``. | ||
| verbose: if false, mute messages from torch Hub load function. | ||
| """ | ||
|
|
||
| def __init__(self, net: str = "radimagenet_resnet50", verbose: bool = False) -> None: | ||
| def __init__(self, net: str = "radimagenet_resnet50", verbose: bool = False, cache_dir: str | None = None) -> None: | ||
| super().__init__() | ||
| self.model = torch.hub.load("Warvito/radimagenet-models", model=net, verbose=verbose, trust_repo=True) | ||
| self.model = torch.hub.load("Project-MONAI/perceptual-models", model=net, verbose=verbose, cache_dir=cache_dir) | ||
| self.eval() | ||
|
|
||
| for param in self.parameters(): | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We might need to check that
huggingface_hubis installed as well, thetorch.hub.loadcall will fail if so.