Regression (#13485) Broken TorchAO Compat #14097

@Beinsezii

Describe the bug

Changes introduced in #13485 / 236e5dd have regressed compatibility with compiled 8-bit models quantized through TorchAO.

Reproduction

Run ./regression.py. Tested on 1xH200. Switch to the previous commit, rev = "6d9331ea607183f84a21bdc37da6389a611fe7bd", to see it work.

#! /usr/bin/env -S uv run --script
# /// script
# requires-python = "==3.12.13"
# dependencies = [
#     "accelerate==1.14.0",
#     "diffusers",
#     "torch==2.12.1",
#     "torchao==0.17.0",
#     "transformers==5.5.4",
# ]
#
# [[tool.uv.index]]
# url = "https://download.pytorch.org/whl/cu132"
#
# [tool.uv.sources]
# diffusers = { git = "https://github.com/huggingface/diffusers.git", rev = "236e5dd9f38e21ae40c002539368b9be9a5e0fc8" }
# ///

import torch
from diffusers import AutoModel, Flux2KleinPipeline, TorchAoConfig
from torchao.quantization.quant_api import Int8DynamicActivationInt8WeightConfig

with torch.no_grad():
    device = "cuda"
    dtype = torch.bfloat16

    pipe = Flux2KleinPipeline.from_pretrained(
        "black-forest-labs/FLUX.2-klein-9B",
        transformer=torch.compile(
            AutoModel.from_pretrained(
                "black-forest-labs/FLUX.2-klein-9B",
                subfolder="transformer",
                torch_dtype=dtype,
                device_map=device,
                quantization_config=TorchAoConfig(
                    Int8DynamicActivationInt8WeightConfig(version=2),
                    modules_to_not_convert=[
                        "pos_embed",
                        "time_guidance_embed",
                        "double_stream_modulation_img",
                        "double_stream_modulation_txt",
                        "single_stream_modulation",
                        "x_embedder",
                        "context_embedder",
                        "norm_out",
                        "proj_out",
                    ],
                ),
            ),
            mode="max-autotune",
            fullgraph=True,
            dynamic=True,
        ),
        torch_dtype=dtype,
        device_map=device,
    )

    prompt = "A cat holding a sign that says hello world"
    image = pipe(
        prompt=prompt,
        height=512,
        width=512,
        guidance_scale=1.0,
        num_inference_steps=4,
        generator=torch.Generator(device=device).manual_seed(0),
    ).images[0]
    image.save("flux-klein.png")
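For a fuller trace, the two environment variables named in the PyTorch error output under Logs (TORCHDYNAMO_VERBOSE and TORCH_LOGS) can be set when rerunning the script. A minimal sketch, assuming the script lives at ./regression.py; `debug_env` and `rerun` are hypothetical helper names of mine, not part of diffusers or PyTorch:

```python
import os
import subprocess


def debug_env(base=None):
    """Return a copy of the environment with the two PyTorch debug variables set."""
    env = dict(os.environ if base is None else base)
    env["TORCHDYNAMO_VERBOSE"] = "1"  # asks Dynamo for the internal stack trace
    env["TORCH_LOGS"] = "+dynamo"     # extra developer context, per the error message
    return env


def rerun(script="./regression.py"):
    """Run the reproduction script again under the debug environment (hypothetical helper)."""
    # check=False: the script is expected to fail, we only want the logs.
    return subprocess.run([script], env=debug_env(), check=False)
```

Calling `rerun()` from a shell with the same uv setup should produce the same failure with the additional logging, which is probably the more useful output to attach to a PyTorch-side report.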

Logs

File "/root/ml-api/./regression.py", line 59, in <module>
    image = pipe(
            ^^^^^
  File "/root/.cache/uv/environments-v2/regression-42c8e9b09668863f/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 124, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/root/.cache/uv/environments-v2/regression-42c8e9b09668863f/lib/python3.12/site-packages/diffusers/pipelines/flux2/pipeline_flux2_klein.py", line 849, in __call__
    noise_pred = self.transformer(
                 ^^^^^^^^^^^^^^^^^
  File "/root/.cache/uv/environments-v2/regression-42c8e9b09668863f/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 473, in __call__
    return super().__call__(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/.cache/uv/environments-v2/regression-42c8e9b09668863f/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1778, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/.cache/uv/environments-v2/regression-42c8e9b09668863f/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1789, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/.cache/uv/environments-v2/regression-42c8e9b09668863f/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 1062, in compile_wrapper
    raise e.remove_dynamo_frames() from None  # see TORCHDYNAMO_VERBOSE=1
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/.cache/uv/environments-v2/regression-42c8e9b09668863f/lib/python3.12/site-packages/torch/_inductor/compile_fx.py", line 1069, in _compile_fx_inner
    raise InductorError(e, currentframe()).with_traceback(
  File "/root/.cache/uv/environments-v2/regression-42c8e9b09668863f/lib/python3.12/site-packages/torch/_inductor/compile_fx.py", line 1049, in _compile_fx_inner
    mb_compiled_graph = fx_codegen_and_compile(
                        ^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/.cache/uv/environments-v2/regression-42c8e9b09668863f/lib/python3.12/site-packages/torch/_inductor/compile_fx.py", line 1836, in fx_codegen_and_compile
    return scheme.codegen_and_compile(gm, example_inputs, inputs_to_check, graph_kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/.cache/uv/environments-v2/regression-42c8e9b09668863f/lib/python3.12/site-packages/torch/_inductor/compile_fx.py", line 1597, in codegen_and_compile
    compiled_module = graph.compile_to_module()
                      ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/.cache/uv/environments-v2/regression-42c8e9b09668863f/lib/python3.12/site-packages/torch/_inductor/graph.py", line 2613, in compile_to_module
    return self._compile_to_module()
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/.cache/uv/environments-v2/regression-42c8e9b09668863f/lib/python3.12/site-packages/torch/_inductor/graph.py", line 2619, in _compile_to_module
    self.codegen_with_cpp_wrapper() if self.cpp_wrapper else self.codegen()
                                                             ^^^^^^^^^^^^^^
  File "/root/.cache/uv/environments-v2/regression-42c8e9b09668863f/lib/python3.12/site-packages/torch/_inductor/graph.py", line 2555, in codegen
    self.scheduler.codegen()
  File "/root/.cache/uv/environments-v2/regression-42c8e9b09668863f/lib/python3.12/site-packages/torch/_inductor/scheduler.py", line 7350, in codegen
    self._codegen_partitions()
  File "/root/.cache/uv/environments-v2/regression-42c8e9b09668863f/lib/python3.12/site-packages/torch/_inductor/scheduler.py", line 7490, in _codegen_partitions
    self._codegen(partition)
  File "/root/.cache/uv/environments-v2/regression-42c8e9b09668863f/lib/python3.12/site-packages/torch/_inductor/scheduler.py", line 7641, in _codegen
    self.get_backend(device).codegen_node(node)
  File "/root/.cache/uv/environments-v2/regression-42c8e9b09668863f/lib/python3.12/site-packages/torch/_inductor/codegen/cuda_combined_scheduling.py", line 152, in codegen_node
    return self._triton_scheduling.codegen_node(node)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/.cache/uv/environments-v2/regression-42c8e9b09668863f/lib/python3.12/site-packages/torch/_inductor/codegen/simd.py", line 1864, in codegen_node
    return self._codegen_nodes(nodes, coalesce_analysis)  # type: ignore[arg-type]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/.cache/uv/environments-v2/regression-42c8e9b09668863f/lib/python3.12/site-packages/torch/_inductor/codegen/simd.py", line 1837, in _codegen_nodes
    return self.codegen_node_schedule(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/.cache/uv/environments-v2/regression-42c8e9b09668863f/lib/python3.12/site-packages/torch/_inductor/codegen/simd.py", line 1952, in codegen_node_schedule
    self.codegen_node_schedule_with_kernel(node_schedule, kernel)
  File "/root/.cache/uv/environments-v2/regression-42c8e9b09668863f/lib/python3.12/site-packages/torch/_inductor/codegen/simd.py", line 2045, in codegen_node_schedule_with_kernel
    index_vars = kernel.split_and_set_ranges(node.get_ranges())
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/.cache/uv/environments-v2/regression-42c8e9b09668863f/lib/python3.12/site-packages/torch/_inductor/codegen/simd.py", line 943, in split_and_set_ranges
    return self.map_kernel_groups_to_node_sizes(groups, lengths, self.set_ranges)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/.cache/uv/environments-v2/regression-42c8e9b09668863f/lib/python3.12/site-packages/torch/_inductor/codegen/simd.py", line 970, in map_kernel_groups_to_node_sizes
    new_ranges, return_getters_groups = cls._split_iteration_ranges(groups, lengths)
                                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/.cache/uv/environments-v2/regression-42c8e9b09668863f/lib/python3.12/site-packages/torch/_inductor/codegen/simd.py", line 847, in _split_iteration_ranges
    raise CantSplit(size, remaining[current_group])
torch._inductor.exc.InductorError: CantSplit: 16384*s50 + 16384*s78 not divisible by s50 + s78

Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"
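
As a sanity check on the error itself: 16384*s50 + 16384*s78 factors as 16384*(s50 + s78), so for any positive integer values of the two symbols the division is exact. That suggests the failure comes from the symbolic divisibility check not recognizing the common factor, rather than from genuinely incompatible shapes. This is an assumption on my part and not confirmed against the Inductor source. A quick numeric sketch, where `s50` and `s78` are plain stand-in integers for the symbolic sizes in the message:

```python
from itertools import product


def splits_evenly(s50: int, s78: int) -> bool:
    """Check the divisibility that the CantSplit message reports as failing."""
    size = 16384 * s50 + 16384 * s78
    group = s50 + s78
    return size % group == 0


# Try a grid of positive stand-in values for the two symbolic sizes.
counterexamples = [
    (a, b) for a, b in product(range(1, 200), repeat=2) if not splits_evenly(a, b)
]
print(counterexamples)  # → []
```

No counterexample exists for positive sizes, so the concrete shapes used at runtime should always split cleanly; only the symbolic form trips the check.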

System Info

- 🤗 Diffusers version: 0.39.0.dev0
- Platform: Linux-5.15.0-157-generic-x86_64-with-glibc2.39
- Running on Google Colab?: No
- Python version: 3.12.3
- PyTorch version (GPU?): 2.12.1+cu132 (True)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Huggingface_hub version: 1.10.2
- Transformers version: 5.5.4
- Accelerate version: 1.12.0
- PEFT version: 0.18.1
- bitsandbytes version: 0.49.2
- optimum-quanto version: 0.2.7
- torchao version: 0.17.0
- Safetensors version: 0.8.0-rc.0
- xFormers version: not installed
- Accelerator: 8x NVIDIA H200, 143771 MiB each
- Using GPU in script?: true
- Using distributed or parallel set-up in script?: false

Who can help?

@sayakpaul (maintainer), @JingyaHuang (original implementation)
