Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 25 additions & 1 deletion src/diffusers/pipelines/pipeline_loading_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import importlib.util
import os
import re
import warnings
Expand Down Expand Up @@ -1073,6 +1074,22 @@ def get_connected_passed_kwargs(prefix):
return init_kwargs


def _is_deprecated_pipeline_module(module_candidate: str) -> bool:
"""Return whether ``module_candidate`` is a pipeline module relocated under
``diffusers.pipelines.deprecated``.

Deprecated pipelines (e.g. Wuerstchen) are no longer attributes of ``diffusers.pipelines``, so a
plain ``hasattr(diffusers.pipelines, module_candidate)`` check fails for them even though the
module still ships with diffusers. We resolve the spec without importing the module to avoid
triggering its (potentially heavy) import side effects.
"""
try:
return importlib.util.find_spec(f"diffusers.pipelines.deprecated.{module_candidate}") is not None
except (ImportError, ModuleNotFoundError, ValueError):
# ValueError covers malformed candidate names (e.g. containing path separators).
return False


def _get_custom_components_and_folders(
pretrained_model_name: str,
config_dict: dict[str, Any],
Expand Down Expand Up @@ -1101,7 +1118,14 @@ def _get_custom_components_and_folders(

if candidate_file in filenames:
custom_components[component] = module_candidate
elif module_candidate not in LOADABLE_CLASSES and not hasattr(pipelines, module_candidate):
elif (
module_candidate not in LOADABLE_CLASSES
and not hasattr(pipelines, module_candidate)
# Pipelines moved under `diffusers.pipelines.deprecated` are no longer attributes of
# `diffusers.pipelines`, so `hasattr` above misses them. Check the deprecated namespace
# too before treating the component as a missing custom module.
and not _is_deprecated_pipeline_module(module_candidate)
):
raise ValueError(
f"{candidate_file} as defined in `model_index.json` does not exist in {pretrained_model_name} and is not a module in 'diffusers/pipelines'."
)
Expand Down
52 changes: 51 additions & 1 deletion tests/pipelines/test_pipeline_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,12 @@
StableDiffusionPipeline,
UNet2DConditionModel,
)
from diffusers.pipelines.pipeline_loading_utils import is_safetensors_compatible, variant_compatible_siblings
from diffusers.pipelines.pipeline_loading_utils import (
_get_custom_components_and_folders,
_is_deprecated_pipeline_module,
is_safetensors_compatible,
variant_compatible_siblings,
)

from ..testing_utils import require_torch_accelerator, torch_device

Expand Down Expand Up @@ -233,6 +238,51 @@ def test_is_compatible_variant_and_non_safetensors(self):
self.assertFalse(is_safetensors_compatible(filenames, variant="fp16"))


class GetCustomComponentsAndFoldersTests(unittest.TestCase):
def test_deprecated_pipeline_module_is_recognized(self):
# Pipelines relocated under `diffusers.pipelines.deprecated` (e.g. Wuerstchen) are no longer
# attributes of `diffusers.pipelines`. They must still resolve instead of being mistaken for
# a missing custom module. Regression test for loading repos like `warp-ai/wuerstchen-prior`.
config_dict = {
"_class_name": "WuerstchenPriorPipeline",
"prior": ["wuerstchen", "WuerstchenPrior"],
}
custom_components, folder_names = _get_custom_components_and_folders(
"warp-ai/wuerstchen-prior", config_dict, filenames=[]
)
self.assertEqual(custom_components, {})
self.assertEqual(folder_names, ["prior"])

def test_missing_custom_module_still_raises(self):
# A component that is neither a loadable class, a known pipeline module, a deprecated module,
# nor an actual custom file on the Hub must still raise.
config_dict = {
"_class_name": "FooPipeline",
"foo": ["totally_made_up_module", "FooModel"],
}
with self.assertRaises(ValueError):
_get_custom_components_and_folders("some/repo", config_dict, filenames=[])

def test_custom_component_file_is_detected(self):
# When the custom module file is actually present on the Hub it is recorded as a custom component.
config_dict = {
"_class_name": "FooPipeline",
"foo": ["my_pipeline", "FooModel"],
}
custom_components, folder_names = _get_custom_components_and_folders(
"some/repo", config_dict, filenames=["foo/my_pipeline.py"]
)
self.assertEqual(custom_components, {"foo": "my_pipeline"})

def test_is_deprecated_pipeline_module(self):
self.assertTrue(_is_deprecated_pipeline_module("wuerstchen"))
self.assertFalse(_is_deprecated_pipeline_module("totally_made_up_module"))
# A non-deprecated (current) pipeline module is not under the deprecated namespace.
self.assertFalse(_is_deprecated_pipeline_module("stable_diffusion"))
# Malformed candidate names must not raise.
self.assertFalse(_is_deprecated_pipeline_module("weird/name"))


class VariantCompatibleSiblingsTest(unittest.TestCase):
def test_only_non_variants_downloaded(self):
ignore_patterns = ["*.bin"]
Expand Down
Loading