Triton/CUDA Illegal Memory Access in triton_rmsnorm_forward with Wan2.1-T2V-1.3B-Diffusers #55

@1717Li

Description

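Hi authors, I ran into a CUDA error when running Wan2.1-T2V-1.3B-Diffusers with SAP or SVG.
The only change I made was setting model_id in wan_t2v_720p_svg.sh or wan_t2v_720p_sap.sh to the 1.3B model; all other parameters were left unchanged. The run fails with the following error (full traceback):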

Traceback (most recent call last):
  File "/workspace/Sparse-VideoGen/wan_t2v_inference.py", line 159, in <module>
    output = pipe(
             ^^^^^
  File "/xxx/.conda/envs/svg/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/xxx/.conda/envs/svg/lib/python3.12/site-packages/diffusers/pipelines/wan/pipeline_wan.py", line 536, in __call__
    noise_pred = self.transformer(
                 ^^^^^^^^^^^^^^^^^
  File "/xxx/.conda/envs/svg/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/xxx/.conda/envs/svg/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/Sparse-VideoGen/svg/models/wan/custom_models.py", line 172, in forward
    hidden_states = block(
                    ^^^^^^
  File "/xxx/.conda/envs/svg/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/xxx/.conda/envs/svg/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1879, in _call_impl
    return inner()
           ^^^^^^^
  File "/xxx/.conda/envs/svg/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1827, in inner
    result = forward_call(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/Sparse-VideoGen/svg/models/wan/custom_models.py", line 57, in forward
    attn_output = self.attn1(hidden_states=norm_hidden_states, rotary_emb=rotary_emb, timestep=timestep)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/xxx/.conda/envs/svg/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/xxx/.conda/envs/svg/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/xxx/.conda/envs/svg/lib/python3.12/site-packages/diffusers/models/attention_processor.py", line 605, in forward
    return self.processor(
           ^^^^^^^^^^^^^^^
  File "/workspace/Sparse-VideoGen/svg/models/wan/attention.py", line 168, in __call__
    query, key = self.get_qk_norm(attn, query, key)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/xxx.conda/envs/svg/lib/python3.12/contextlib.py", line 81, in inner
    return func(*args, **kwds)
           ^^^^^^^^^^^^^^^^^^^
  File "/workspace/Sparse-VideoGen/svg/models/wan/attention.py", line 117, in get_qk_norm
    key = triton_rmsnorm_forward(key, attn.norm_k.weight, attn.norm_k.eps)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/Sparse-VideoGen/svg/kernels/triton/rmsnorm.py", line 85, in triton_rmsnorm_forward
    _rms_norm_fwd_fused[(triton.cdiv(M, BLOCK_M),)](  #
  File "/xxx/.conda/envs/svg/lib/python3.12/site-packages/triton/runtime/jit.py", line 390, in <lambda>
    return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
                                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/xxx/.conda/envs/svg/lib/python3.12/site-packages/triton/runtime/jit.py", line 617, in run
    kernel.run(grid_0, grid_1, grid_2, stream, kernel.function, kernel.packed_metadata, launch_metadata,
  File "/xxx/.conda/envs/svg/lib/python3.12/site-packages/triton/backends/nvidia/driver.py", line 708, in __call__
    self.launch(gridX, gridY, gridZ, stream, function, self.launch_cooperative_grid, self.launch_pdl,
RuntimeError: Triton Error [CUDA]: an illegal memory access was encountered
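
One way to narrow this down, sketched here under stated assumptions, is to temporarily replace the fused Triton kernel with a plain-PyTorch RMSNorm and see whether the error goes away. The name `reference_rmsnorm_forward` is hypothetical, and the sketch assumes `triton_rmsnorm_forward(x, weight, eps)` normalizes over the last dimension and scales by a `weight` vector of that length; the actual kernel in svg/kernels/triton/rmsnorm.py may differ. Because CUDA errors are reported asynchronously, rerunning with CUDA_LAUNCH_BLOCKING=1 may also help confirm that this kernel, rather than an earlier one, is the one that faults.

```python
import torch


def reference_rmsnorm_forward(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    """Hypothetical plain-PyTorch stand-in for triton_rmsnorm_forward.

    Assumption: RMSNorm over the last dimension, scaled by `weight`
    of shape [x.shape[-1]]. Not the repository's implementation.
    """
    # Accumulate in float32 for numerical stability (inputs may be fp16/bf16).
    x_fp32 = x.float()
    # Mean of squares over the hidden dimension, kept for broadcasting.
    variance = x_fp32.pow(2).mean(dim=-1, keepdim=True)
    # Divide by the root-mean-square.
    y = x_fp32 * torch.rsqrt(variance + eps)
    # Apply the learned per-channel scale and cast back to the input dtype.
    return (y * weight.float()).to(x.dtype)
```

For example, in get_qk_norm the call `triton_rmsnorm_forward(key, attn.norm_k.weight, attn.norm_k.eps)` could be swapped for `reference_rmsnorm_forward(key, attn.norm_k.weight, attn.norm_k.eps)`, together with the matching call for the query. If the error disappears with the 1.3B model, the fused kernel is the likely culprit; if it persists, the fault probably originates elsewhere.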

Wan2.1-T2V-14B-Diffusers runs without problems.
How can I fix this?