
Flux #1079

Open
sriraksharao wants to merge 64 commits into hao-ai-lab:main from sriraksharao:flux-test

Conversation

@sriraksharao

No description provided.

@gemini-code-assist
Contributor

Summary of Changes

Hello @sriraksharao, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed.

This pull request significantly expands the fastvideo framework by integrating the Flux family of diffusion models. It provides all necessary configurations and the core model implementation, alongside critical enhancements to the pipeline's data flow and performance monitoring. These changes ensure that Flux models can be efficiently run within the framework, benefiting from advanced features like dynamic resource management and detailed performance insights.

Highlights

  • Flux Model Integration: Introduces comprehensive support for the Flux family of diffusion models (Flux, Flux2, Flux2Klein), including their specific architectural, VAE, pipeline, and sampling configurations.
  • Core Flux Transformer Implementation: Adds the FluxTransformer2DModel with custom attention mechanisms, transformer blocks, and positional embeddings tailored for Flux models.
  • Enhanced Pipeline Stages: Implements new hooks in pipeline stages for customizable latent shape preparation, latent packing/unpacking, and VAE decoding, crucial for Flux models' unique processing.
  • Dynamic Timestep Preparation: Enhances timestep generation with dynamic shifting logic and a Flux-specific scaling stage, allowing for more precise control over the diffusion process.
  • Advanced Performance Logging & Offloading: Integrates a new performance logging and profiling system, along with significant enhancements to layerwise CPU offloading capabilities for efficient resource management.
  • New Example Script: Adds a basic inference example (basic_flux.py) demonstrating video generation from text prompts using the FLUX.1-dev model with configurable resource offloading.
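Of the highlights above, the latent packing step benefits from a concrete sketch: Flux-style models pack each 2x2 spatial patch of the latent grid into one transformer token before denoising. A minimal pure-Python illustration (the layout and the name `pack_latents` are illustrative, not the PR's actual `_patchify_latents`):

```python
def pack_latents(latents, patch=2):
    """Pack a [C][H][W] latent grid into (H//patch)*(W//patch) tokens,
    each of length C * patch * patch, row-major over patches.

    Mirrors the 2x2 patchify Flux applies before the transformer; the
    exact channel ordering in the PR may differ.
    """
    C, H, W = len(latents), len(latents[0]), len(latents[0][0])
    tokens = []
    for i in range(0, H, patch):
        for j in range(0, W, patch):
            # Flatten one patch x patch neighborhood across all channels.
            tokens.append([latents[c][i + di][j + dj]
                           for c in range(C)
                           for di in range(patch)
                           for dj in range(patch)])
    return tokens


# A 1-channel 4x4 grid becomes four tokens of four values each.
grid = [[[r * 4 + c for c in range(4)] for r in range(4)]]
print(pack_latents(grid))
```

Unpacking reverses this mapping, which is why the pipeline also carries per-token position ids (the `_prepare_latent_ids` helper mentioned below).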


Changelog
  • examples/inference/basic/basic_flux.py
    • Adds a basic inference example for the FLUX.1-dev model, demonstrating video generation from text prompts with configurable resource offloading.
  • fastvideo/configs/flux_t2i.json
    • Adds a JSON configuration file (flux_t2i.json) defining default parameters for Flux T2I models, including image dimensions, inference steps, and scaling factors.
  • fastvideo/configs/models/dits/__init__.py
    • Imports and exposes FluxConfig within the fastvideo.configs.models.dits package.
  • fastvideo/configs/models/dits/flux.py
    • Adds FluxArchConfig and FluxConfig dataclasses, detailing the architecture and default parameters for Flux Diffusion Transformers.
  • fastvideo/configs/models/vaes/__init__.py
    • Imports and exposes FluxVAEConfig and Flux2VAEConfig within the fastvideo.configs.models.vaes package.
  • fastvideo/configs/models/vaes/fluxvae.py
    • Adds FluxVAEConfig and Flux2VAEConfig dataclasses, specifying VAE architectures and parameters tailored for Flux models.
  • fastvideo/configs/pipelines/base.py
    • Introduces DataType enum for IMAGE and VIDEO.
    • Adds ModelTaskType enum to classify diffusion model tasks (e.g., T2I, T2V, I2I), including helper methods for task type properties.
    • Adds shard_rotary_emb_for_sp as a placeholder for sequence-parallel sharding of rotary embeddings.
    • Introduces ImagePipelineConfig as an alias for image-focused pipelines.
  • fastvideo/configs/pipelines/flux.py
    • Adds FluxPipelineConfig for the base Flux T2I pipeline, defining text encoders (T5, CLIP), VAE, and custom timestep shifting logic.
    • Introduces Flux2PipelineConfig for Flux2 (TI2I), featuring Mistral text encoder, specific latent packing/unpacking, and dynamic VAE decoding based on BatchNorm presence.
    • Adds Flux2KleinPipelineConfig for the distilled Flux2Klein model, using Qwen3 text encoder and disabling guidance.
    • Includes utility functions like _patchify_latents, _unpatchify_latents, _prepare_latent_ids, _prepare_text_ids, and _prepare_image_ids for latent and ID manipulation.
  • fastvideo/configs/pipelines/registry.py
    • Registers FluxT2IConfig for the model ID "black-forest-labs/FLUX.1-dev".
    • Adds a detector for "flux" in model IDs to automatically select the Flux pipeline configuration.
    • Sets FluxT2IConfig as the fallback pipeline configuration for the "flux" architecture.
  • fastvideo/configs/sample/flux.py
    • Adds FluxSamplingParam dataclass, defining default sampling parameters for Flux T2I models, including image dimensions, inference steps, and guidance scale.
  • fastvideo/configs/sample/registry.py
    • Registers FluxSamplingParam for the model ID "black-forest-labs/FLUX.1-dev".
    • Adds a detector for "flux" in model IDs to automatically select the Flux sampling parameters.
    • Sets FluxSamplingParam as the fallback sampling parameter configuration for the "flux" architecture.
  • fastvideo/envs.py
    • Adds environment variables FASTVIDEO_PERF_LOG_DIR, FASTVIDEO_DIFFUSION_STAGE_LOGGING, and FASTVIDEO_DIFFUSION_SYNC_STAGE_PROFILING to control performance logging and CUDA synchronization during profiling.
  • fastvideo/hooks/layerwise_offload.py
    • Adds a backward-compatible re-export of OffloadableDiTMixin from fastvideo.models.layerwise_offload.
  • fastvideo/layers/activation.py
    • Adds "gelu_tanh" activation function (GELU with approximate="tanh") to the activation registries.
  • fastvideo/layers/rotary_embedding.py
    • Modifies get_1d_rotary_pos_embed to accept a device argument, allowing torch.arange calls to create tensors directly on the specified device.
  • fastvideo/models/dits/flux.py
    • Adds FluxTransformer2DModel, the core architecture for Flux, including FluxAttention, FluxSingleTransformerBlock, and FluxTransformerBlock.
    • Introduces FluxPosEmbed for handling rotary positional embeddings.
    • Integrates ColumnParallelLinear, RMSNorm, and various AdaLayerNorm types.
    • Defines the forward pass logic for processing image latents, text embeddings, timesteps, and pooled projections.
  • fastvideo/models/layerwise_offload.py
    • Extends LayerwiseOffloadManager with methods for preparing for next requests, loading/unloading all layers, and enabling/disabling offloading.
    • Introduces _MultiLayerwiseOffloadManager to coordinate multiple LayerwiseOffloadManager instances.
    • Adds OffloadableDiTMixin to provide a standardized way for DiT models to integrate layerwise offloading, including configuration and lifecycle management.
  • fastvideo/models/loader/component_loader.py
    • Introduces _filter_config_for_arch to filter HuggingFace model configurations, ensuring only relevant fields are passed to FastVideo's architecture configs.
    • Applies this filtering to ImageEncoderLoader, VAELoader, and DiTLoader when updating model architectures.
    • Adds explicit support for loading AutoencoderKL and AutoencoderTiny VAEs directly from Diffusers.
  • fastvideo/models/loader/fsdp_load.py
    • Extends ALLOWED_NEW_PARAM_PATTERNS to include "adaLN_modulation", preventing warnings/errors when loading models with these parameters under FSDP.
  • fastvideo/models/registry.py
    • Registers FluxTransformer2DModel as a recognized DiT model, mapping it to its module and class name.
  • fastvideo/perf_logger.py
    • Adds RequestTimings dataclass to track performance metrics for individual requests, including stage and step durations.
    • Implements StageProfiler as a context manager for precise timing of pipeline stages and denoising steps.
    • Provides PerformanceLogger for generating structured benchmark reports and logging request summaries to a file.
    • Includes utility functions for determining log directories, retrieving Git commit hashes, and checking the main process.
  • fastvideo/pipelines/basic/flux/flux_pipeline.py
    • Adds FluxPipeline, a new composed pipeline for Flux text-to-image diffusion.
    • Introduces FluxPromptEncodingStage to handle specific CLIP embedding processing for positive and negative prompts.
    • Configures the pipeline with standard stages: input validation, prompt encoding, conditioning, timestep preparation (using FluxTimestepPreparationStage), latent preparation, denoising, and decoding.
  • fastvideo/pipelines/pipeline_registry.py
    • Registers FluxPipeline in the pipeline class to module name mapping, enabling its discovery and loading.
  • fastvideo/pipelines/schedule_batch.py
    • Adds Req (ForwardBatch) dataclass to encapsulate the entire state of a diffusion request, including sampling parameters, embeddings, latents, and other pipeline-specific data.
    • Implements __getattr__ and __setattr__ for Req to delegate access to sampling_params fields.
    • Introduces OutputBatch dataclass to hold the final output (image/audio), trajectory data, and performance timings.
  • fastvideo/pipelines/stages/__init__.py
    • Imports and exposes FluxTimestepPreparationStage within the fastvideo.pipelines.stages package.
  • fastvideo/pipelines/stages/decoding.py
    • Refines _denormalize_latents to support latents_mean/latents_std from pipeline configuration and more flexible handling of VAE scaling/shifting factors.
    • Introduces preprocess_decoding and postprocess_decoding hooks, enabling custom logic for VAE input/output manipulation defined in pipeline configurations.
    • Adds an optional debug feature to save the first decoded frame as a PNG image.
  • fastvideo/pipelines/stages/denoising.py
    • Minor formatting change (added newline).
  • fastvideo/pipelines/stages/latent_preparation.py
    • Adds a prepare_latent_shape hook, allowing pipeline configurations to define custom logic for calculating the initial latent shape.
    • Introduces a pack_latents_for_denoising hook, enabling pipeline configurations to apply custom packing logic to latents before the denoising process.
  • fastvideo/pipelines/stages/timestep_preparation.py
    • Introduces helper functions _to_scalar, _compute_image_seq_len, and _compute_mu for dynamic timestep calculation.
    • Modifies TimestepPreparationStage to incorporate force_dynamic_shifting and dynamically calculate a mu parameter for the scheduler's set_timesteps method, based on image sequence length.
    • Adds FluxTimestepPreparationStage to apply a timestep_input_scale to the generated timesteps, specific to Flux models.
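The dynamic timestep preparation described above (the `_compute_mu` helper) typically amounts to a linear interpolation of the scheduler's shift parameter in the image sequence length. A sketch under assumed defaults (`base_seq_len=256`, `max_seq_len=4096`, `base_shift=0.5`; `max_shift=1.15` matches the `flux_max_shift` value quoted later in this review, but the exact constants and names in the PR may differ):

```python
def compute_mu(image_seq_len: int,
               base_seq_len: int = 256,
               max_seq_len: int = 4096,
               base_shift: float = 0.5,
               max_shift: float = 1.15) -> float:
    """Interpolate the scheduler shift `mu` linearly in the number of
    image tokens, so larger images get a stronger timestep shift."""
    slope = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    intercept = base_shift - slope * base_seq_len
    return image_seq_len * slope + intercept


# At the endpoints, mu hits base_shift and max_shift exactly.
print(compute_mu(256), compute_mu(4096))
```

The resulting `mu` is then passed to the scheduler's `set_timesteps`, which is what the modified TimestepPreparationStage does when dynamic shifting is forced.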
Activity
  • The pull request was created by sriraksharao.
  • No explicit comments or reviews have been made yet.

@gemini-code-assist bot (Contributor) left a comment


Code Review

This pull request introduces support for the Flux model, which is a significant addition to the codebase. The changes are extensive, covering configurations, model definitions, pipeline stages, and various utilities. The effort to make pipeline stages more flexible with hooks is a commendable design choice. However, there are several areas that require attention. The new example file and the Flux model definition include a substantial amount of commented-out code that should be removed to improve clarity. A critical bug was identified in fastvideo/models/dits/flux.py where a class member is assigned twice, which needs to be addressed. Additionally, there are minor issues such as a duplicated import and an incorrect type hint that should be corrected. Overall, this is a valuable contribution that will be even better with some cleanup and a critical bug fix.

Comment on lines +814 to +826
self.ff = MLP(
    input_dim=dim, mlp_hidden_dim=dim * 4, output_dim=dim, act_type="gelu"
)
self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")

self.norm2_context = LayerNorm(dim, eps=1e-6, elementwise_affine=False)
self.ff_context = MLP(
    input_dim=dim, mlp_hidden_dim=dim * 4, output_dim=dim, act_type="gelu"
)

self.ff_context = FeedForward(
    dim=dim, dim_out=dim, activation_fn="gelu-approximate"
)
Contributor

critical

The class members self.ff and self.ff_context are each assigned twice. The second assignment overwrites the first one, which is likely a bug. You are defining them first as MLP and then immediately overwriting them with FeedForward. Please resolve this to use the intended implementation.
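The overwrite is easy to demonstrate with stand-in classes (`MLP` and `FeedForward` here are placeholders, not the PR's real modules):

```python
class MLP:
    kind = "mlp"


class FeedForward:
    kind = "feedforward"


class Block:
    def __init__(self):
        # The flagged pattern: the second binding silently replaces the
        # first, so the MLP instance (and any parameters it registered)
        # is never used in the forward pass.
        self.ff = MLP()
        self.ff = FeedForward()  # this assignment wins


print(Block().ff.kind)  # feedforward
```

In an `nn.Module` this is worse than dead code: the discarded submodule is still constructed, wasting time and memory before being garbage-collected.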

Comment on lines +1 to +26
# import os

# from fastvideo import VideoGenerator

# PROMPT = (
#     "A cinematic portrait of a fox, 35mm film, soft light, gentle grain."
# )


# def main() -> None:
#     generator = VideoGenerator.from_pretrained(
#         os.environ.get("FLUX_MODEL_PATH", "black-forest-labs/FLUX.1-dev"),
#         num_gpus=1,
#     )

#     output_path = "outputs_image/flux_basic/output_flux_t2i.mp4"
#     generator.generate_video(
#         prompt=PROMPT,
#         output_path=output_path,
#         save_video=True,
#     )
#     generator.shutdown()


# if __name__ == "__main__":
#     main()
Contributor

medium

This large block of commented-out code appears to be an older version of the example. To improve the clarity and maintainability of this example file, it's best to remove this unused code.

"HunyuanVideoConfig", "HunyuanVideo15Config", "WanVideoConfig",
"StepVideoConfig", "CosmosVideoConfig", "Cosmos25VideoConfig",
"LongCatVideoConfig", "LTX2VideoConfig"
"LongCatVideoConfig","FluxConfig", "LTX2VideoConfig"
Contributor

medium

For better readability and adherence to common Python style guides (like PEP 8), a space should be added after the comma.

    "LongCatVideoConfig", "FluxConfig", "LTX2VideoConfig"


arch_config: DiTArchConfig = field(default_factory=FluxArchConfig)

prefix: str = "Flux"
[no newline at end of file]
Contributor

medium

The file is missing a newline at the end. It's a common convention to end files with a newline character.

Suggested change
prefix: str = "Flux"
prefix: str = "Flux"


def _unpack_latents_with_ids(
    x: torch.Tensor, x_ids: torch.Tensor
) -> list[torch.Tensor]:
Contributor

medium

The return type hint for _unpack_latents_with_ids is list[torch.Tensor], but the function returns a torch.Tensor due to torch.stack. The type hint should be corrected to torch.Tensor to match the actual return type.

Suggested change
) -> list[torch.Tensor]:
) -> torch.Tensor:

from fastvideo.configs.pipelines.cosmos2_5 import Cosmos25Config

from fastvideo.configs.pipelines.flux import FluxT2IConfig
from fastvideo.configs.pipelines.flux import FluxT2IConfig
Contributor

medium

The import for FluxT2IConfig is duplicated. The redundant import should be removed to keep the code clean.

Comment on lines +1 to +487
# # SPDX-License-Identifier: Apache-2.0
# from typing import Any

# import torch
# import torch.nn as nn
# from einops import rearrange

# from fastvideo.attention import LocalAttention
# from fastvideo.configs.models.dits import FluxConfig
# from fastvideo.layers.activation import get_act_fn
# from fastvideo.layers.layernorm import FP32LayerNorm, RMSNorm
# from fastvideo.layers.linear import ReplicatedLinear
# from fastvideo.layers.mlp import MLP
# from fastvideo.layers.rotary_embedding import get_1d_rotary_pos_embed
# from fastvideo.models.dits.base import BaseDiT
# from fastvideo.platforms import AttentionBackendEnum


# def timestep_embedding(t: torch.Tensor,
# dim: int,
# max_period: int = 10000,
# time_factor: float = 1000.0) -> torch.Tensor:
# t = time_factor * t
# half = dim // 2
# freqs = torch.exp(-torch.log(torch.tensor(max_period, dtype=torch.float32)) *
# torch.arange(start=0, end=half, dtype=torch.float32) /
# half).to(t.device)
# args = t[:, None].float() * freqs[None]
# embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
# if dim % 2:
# embedding = torch.cat(
# [embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
# if torch.is_floating_point(t):
# embedding = embedding.to(t)
# return embedding


# class MLPEmbedder(nn.Module):

# def __init__(self, in_dim: int, hidden_dim: int, dtype: torch.dtype | None):
# super().__init__()
# self.in_layer = ReplicatedLinear(in_dim,
# hidden_dim,
# bias=True,
# params_dtype=dtype)
# self.act = get_act_fn("silu")
# self.out_layer = ReplicatedLinear(hidden_dim,
# hidden_dim,
# bias=True,
# params_dtype=dtype)

# def forward(self, x: torch.Tensor) -> torch.Tensor:
# x, _ = self.in_layer(x)
# x = self.act(x)
# x, _ = self.out_layer(x)
# return x


# class QKNorm(nn.Module):

# def __init__(self, dim: int, dtype: torch.dtype | None):
# super().__init__()
# self.query_norm = RMSNorm(dim, eps=1e-6, dtype=dtype)
# self.key_norm = RMSNorm(dim, eps=1e-6, dtype=dtype)

# def forward(self, q: torch.Tensor, k: torch.Tensor,
# v: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
# q = self.query_norm(q)
# k = self.key_norm(k)
# return q.to(v.dtype), k.to(v.dtype)


# class SelfAttention(nn.Module):

# def __init__(self,
# dim: int,
# num_heads: int,
# qkv_bias: bool,
# dtype: torch.dtype | None,
# supported_attention_backends: tuple[AttentionBackendEnum,
# ...]):
# super().__init__()
# self.num_heads = num_heads
# head_dim = dim // num_heads

# self.qkv = ReplicatedLinear(dim,
# dim * 3,
# bias=qkv_bias,
# params_dtype=dtype)
# self.norm = QKNorm(head_dim, dtype=dtype)
# self.proj = ReplicatedLinear(dim,
# dim,
# bias=True,
# params_dtype=dtype)

# self.attn = LocalAttention(
# num_heads=num_heads,
# head_size=head_dim,
# supported_attention_backends=supported_attention_backends,
# )

# def forward(self, x: torch.Tensor,
# freqs_cis: tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
# qkv, _ = self.qkv(x)
# q, k, v = rearrange(qkv,
# "b l (k h d) -> k b l h d",
# k=3,
# h=self.num_heads)
# q, k = self.norm(q, k, v)
# attn = self.attn(q, k, v, freqs_cis=freqs_cis)
# attn = attn.reshape(x.shape[0], x.shape[1], -1)
# out, _ = self.proj(attn)
# return out


# class Modulation(nn.Module):

# def __init__(self, dim: int, double: bool, dtype: torch.dtype | None):
# super().__init__()
# self.is_double = double
# self.multiplier = 6 if double else 3
# self.lin = ReplicatedLinear(dim,
# self.multiplier * dim,
# bias=True,
# params_dtype=dtype)

# def forward(self, vec: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor,
# torch.Tensor, torch.Tensor,
# torch.Tensor, torch.Tensor]:
# out, _ = self.lin(torch.nn.functional.silu(vec))
# chunks = out[:, None, :].chunk(self.multiplier, dim=-1)
# return chunks # shift/scale/gate tuples


# class DoubleStreamBlock(nn.Module):

# def __init__(self,
# hidden_size: int,
# num_heads: int,
# mlp_ratio: float,
# qkv_bias: bool,
# dtype: torch.dtype | None,
# supported_attention_backends: tuple[AttentionBackendEnum,
# ...],
# prefix: str = ""):
# super().__init__()
# mlp_hidden_dim = int(hidden_size * mlp_ratio)
# self.num_heads = num_heads
# self.hidden_size = hidden_size

# self.img_mod = Modulation(hidden_size, double=True, dtype=dtype)
# self.img_norm1 = FP32LayerNorm(hidden_size,
# elementwise_affine=False,
# eps=1e-6)
# self.img_attn = SelfAttention(hidden_size,
# num_heads,
# qkv_bias=qkv_bias,
# dtype=dtype,
# supported_attention_backends=
# supported_attention_backends)
# self.img_norm2 = FP32LayerNorm(hidden_size,
# elementwise_affine=False,
# eps=1e-6)
# self.img_mlp = MLP(hidden_size,
# mlp_hidden_dim,
# bias=True,
# act_type="gelu_tanh",
# dtype=dtype,
# prefix=f"{prefix}.img_mlp")

# self.txt_mod = Modulation(hidden_size, double=True, dtype=dtype)
# self.txt_norm1 = FP32LayerNorm(hidden_size,
# elementwise_affine=False,
# eps=1e-6)
# self.txt_attn = SelfAttention(hidden_size,
# num_heads,
# qkv_bias=qkv_bias,
# dtype=dtype,
# supported_attention_backends=
# supported_attention_backends)
# self.txt_norm2 = FP32LayerNorm(hidden_size,
# elementwise_affine=False,
# eps=1e-6)
# self.txt_mlp = MLP(hidden_size,
# mlp_hidden_dim,
# bias=True,
# act_type="gelu_tanh",
# dtype=dtype,
# prefix=f"{prefix}.txt_mlp")

# def forward(
# self,
# img: torch.Tensor,
# txt: torch.Tensor,
# vec: torch.Tensor,
# freqs_cis: tuple[torch.Tensor, torch.Tensor],
# ) -> tuple[torch.Tensor, torch.Tensor]:
# img_shift1, img_scale1, img_gate1, img_shift2, img_scale2, img_gate2 = self.img_mod(
# vec)
# txt_shift1, txt_scale1, txt_gate1, txt_shift2, txt_scale2, txt_gate2 = self.txt_mod(
# vec)

# img_mod = self.img_norm1(img)
# img_mod = (1 + img_scale1) * img_mod + img_shift1
# txt_mod = self.txt_norm1(txt)
# txt_mod = (1 + txt_scale1) * txt_mod + txt_shift1

# qkv_img, _ = self.img_attn.qkv(img_mod)
# img_q, img_k, img_v = rearrange(qkv_img,
# "b l (k h d) -> k b l h d",
# k=3,
# h=self.num_heads)
# img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)

# qkv_txt, _ = self.txt_attn.qkv(txt_mod)
# txt_q, txt_k, txt_v = rearrange(qkv_txt,
# "b l (k h d) -> k b l h d",
# k=3,
# h=self.num_heads)
# txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)

# q = torch.cat((txt_q, img_q), dim=1)
# k = torch.cat((txt_k, img_k), dim=1)
# v = torch.cat((txt_v, img_v), dim=1)

# attn = self.img_attn.attn(q, k, v, freqs_cis=freqs_cis)
# txt_attn = attn[:, :txt.shape[1]]
# img_attn = attn[:, txt.shape[1]:]

# img = img + img_gate1 * self.img_attn.proj(img_attn.reshape(
# img.shape[0], img.shape[1], -1))[0]
# img = img + img_gate2 * self.img_mlp((1 + img_scale2) *
# self.img_norm2(img) +
# img_shift2)

# txt = txt + txt_gate1 * self.txt_attn.proj(txt_attn.reshape(
# txt.shape[0], txt.shape[1], -1))[0]
# txt = txt + txt_gate2 * self.txt_mlp((1 + txt_scale2) *
# self.txt_norm2(txt) +
# txt_shift2)
# return img, txt


# class SingleStreamBlock(nn.Module):

# def __init__(self,
# hidden_size: int,
# num_heads: int,
# mlp_ratio: float,
# dtype: torch.dtype | None,
# supported_attention_backends: tuple[AttentionBackendEnum,
# ...],
# prefix: str = ""):
# super().__init__()
# self.hidden_size = hidden_size
# self.num_heads = num_heads
# self.mlp_hidden_dim = int(hidden_size * mlp_ratio)

# self.linear1 = ReplicatedLinear(hidden_size,
# hidden_size * 3 + self.mlp_hidden_dim,
# bias=True,
# params_dtype=dtype)
# self.linear2 = ReplicatedLinear(hidden_size + self.mlp_hidden_dim,
# hidden_size,
# bias=True,
# params_dtype=dtype)
# self.norm = QKNorm(hidden_size // num_heads, dtype=dtype)
# self.pre_norm = FP32LayerNorm(hidden_size,
# elementwise_affine=False,
# eps=1e-6)
# self.mlp_act = get_act_fn("gelu_tanh")
# self.modulation = Modulation(hidden_size, double=False, dtype=dtype)

# self.attn = LocalAttention(
# num_heads=num_heads,
# head_size=hidden_size // num_heads,
# supported_attention_backends=supported_attention_backends,
# )

# def forward(self, x: torch.Tensor, vec: torch.Tensor,
# freqs_cis: tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
# mod_shift, mod_scale, mod_gate = self.modulation(vec)[:3]
# x_mod = (1 + mod_scale) * self.pre_norm(x) + mod_shift
# linear1_out, _ = self.linear1(x_mod)
# qkv, mlp = torch.split(
# linear1_out, [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
# q, k, v = rearrange(qkv,
# "b l (k h d) -> k b l h d",
# k=3,
# h=self.num_heads)
# q, k = self.norm(q, k, v)
# attn = self.attn(q, k, v, freqs_cis=freqs_cis)
# attn = attn.reshape(x.shape[0], x.shape[1], -1)
# out, _ = self.linear2(torch.cat((attn, self.mlp_act(mlp)), dim=-1))
# return x + mod_gate * out


# class LastLayer(nn.Module):

# def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
# super().__init__()
# self.norm_final = nn.LayerNorm(hidden_size,
# elementwise_affine=False,
# eps=1e-6)
# self.linear = nn.Linear(hidden_size,
# patch_size * patch_size * out_channels,
# bias=True)
# self.adaLN_modulation = nn.Sequential(
# nn.SiLU(),
# nn.Linear(hidden_size, 2 * hidden_size, bias=True),
# )

# def forward(self, x: torch.Tensor, vec: torch.Tensor) -> torch.Tensor:
# shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
# x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
# x = self.linear(x)
# return x


# def _build_freqs_from_ids(ids: torch.Tensor, axes_dim: list[int],
# theta: float) -> tuple[torch.Tensor, torch.Tensor]:
# ids_0 = ids[0] # [S, A]
# cos_list = []
# sin_list = []
# for i, dim in enumerate(axes_dim):
# cos, sin = get_1d_rotary_pos_embed(dim, ids_0[:, i], theta=theta)
# cos_list.append(cos)
# sin_list.append(sin)
# cos = torch.cat(cos_list, dim=1)
# sin = torch.cat(sin_list, dim=1)
# return cos, sin


# def _build_ids_from_grid(height: int, width: int, n_axes: int,
# device: torch.device) -> torch.Tensor:
# grid_y = torch.arange(height, device=device)
# grid_x = torch.arange(width, device=device)
# yy, xx = torch.meshgrid(grid_y, grid_x, indexing="ij")
# yy = yy.reshape(-1)
# xx = xx.reshape(-1)
# if n_axes == 1:
# ids = torch.arange(height * width, device=device)[:, None]
# elif n_axes >= 2:
# extra = []
# if n_axes > 2:
# extra = [torch.zeros_like(yy) for _ in range(n_axes - 2)]
# ids = torch.stack([yy, xx, *extra], dim=-1)
# return ids


# class FluxTransformer2DModel(BaseDiT):
# _fsdp_shard_conditions = FluxConfig()._fsdp_shard_conditions
# _compile_conditions = FluxConfig()._compile_conditions
# _supported_attention_backends = FluxConfig()._supported_attention_backends
# param_names_mapping = FluxConfig().param_names_mapping
# reverse_param_names_mapping = FluxConfig().reverse_param_names_mapping
# lora_param_names_mapping = FluxConfig().lora_param_names_mapping

# def __init__(self, config: FluxConfig, hf_config: dict[str, Any]):
# super().__init__(config=config, hf_config=hf_config)
# dtype = getattr(config, "dtype", None)

# self.in_channels = config.in_channels
# self.out_channels = config.out_channels
# self.hidden_size = config.hidden_size
# self.num_attention_heads = config.num_attention_heads
# self.num_channels_latents = config.num_channels_latents

# self.vec_in_dim = config.pooled_projection_dim
# self.context_in_dim = config.joint_attention_dim
# self.axes_dim = list(config.rope_axes_dim)
# self.theta = config.rope_theta
# self.guidance_embed = config.guidance_embeds
# self.mlp_ratio = config.mlp_ratio
# self.qkv_bias = getattr(config, "qkv_bias", False)

# self.img_in = ReplicatedLinear(self.in_channels,
# self.hidden_size,
# bias=True,
# params_dtype=dtype)
# self.time_in = MLPEmbedder(in_dim=256,
# hidden_dim=self.hidden_size,
# dtype=dtype)
# self.vector_in = MLPEmbedder(in_dim=self.vec_in_dim,
# hidden_dim=self.hidden_size,
# dtype=dtype)
# self.guidance_in = (MLPEmbedder(
# in_dim=256, hidden_dim=self.hidden_size, dtype=dtype)
# if self.guidance_embed else nn.Identity())
# self.txt_in = ReplicatedLinear(self.context_in_dim,
# self.hidden_size,
# bias=True,
# params_dtype=dtype)

# self.double_blocks = nn.ModuleList([
# DoubleStreamBlock(
# self.hidden_size,
# self.num_attention_heads,
# mlp_ratio=self.mlp_ratio,
# qkv_bias=self.qkv_bias,
# dtype=dtype,
# supported_attention_backends=self._supported_attention_backends,
# prefix=f"{config.prefix}.double_blocks.{i}",
# ) for i in range(config.num_layers)
# ])

# self.single_blocks = nn.ModuleList([
# SingleStreamBlock(
# self.hidden_size,
# self.num_attention_heads,
# mlp_ratio=self.mlp_ratio,
# dtype=dtype,
# supported_attention_backends=self._supported_attention_backends,
# prefix=f"{config.prefix}.single_blocks.{i}",
# ) for i in range(config.num_single_layers)
# ])

# self.final_layer = LastLayer(self.hidden_size,
# patch_size=1,
# out_channels=self.out_channels)

# self.__post_init__()

# def forward(
# self,
# hidden_states: torch.Tensor,
# encoder_hidden_states: torch.Tensor | list[torch.Tensor],
# timestep: torch.LongTensor,
# encoder_hidden_states_2: torch.Tensor | None = None,
# img_ids: torch.Tensor | None = None,
# txt_ids: torch.Tensor | None = None,
# guidance: torch.Tensor | None = None,
# **kwargs,
# ) -> torch.Tensor:
# if hidden_states.ndim != 5:
# raise ValueError(
# "FluxTransformer2DModel expects hidden_states with shape [B, C, T, H, W]"
# )

# img = rearrange(hidden_states, "b c t h w -> b (t h w) c")
# txt = encoder_hidden_states
# if isinstance(txt, list):
# txt = txt[0]

# y = encoder_hidden_states_2
# if y is None:
# y = torch.zeros(txt.shape[0],
# self.vec_in_dim,
# device=txt.device,
# dtype=txt.dtype)

# img, _ = self.img_in(img)
# vec = self.time_in(timestep_embedding(timestep, 256))
# if self.guidance_embed:
# if guidance is None:
# raise ValueError(
# "Guidance value is required for guidance-distilled Flux.")
# vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
# vec = vec + self.vector_in(y)
# txt, _ = self.txt_in(txt)

# bsz, txt_len, _ = txt.shape
# _, img_len, _ = img.shape
# if txt_ids is None:
# txt_ids = torch.zeros(bsz,
# txt_len,
# len(self.axes_dim),
# device=txt.device)
# if img_ids is None:
# _, _, _, h, w = hidden_states.shape
# ids = _build_ids_from_grid(h, w, len(self.axes_dim), txt.device)
# img_ids = ids.unsqueeze(0).expand(bsz, -1, -1)

# ids = torch.cat((txt_ids, img_ids), dim=1)
# freqs_cis = _build_freqs_from_ids(ids, self.axes_dim, self.theta)

# for block in self.double_blocks:
# img, txt = block(img=img, txt=txt, vec=vec, freqs_cis=freqs_cis)

# img = torch.cat((txt, img), 1)
# for block in self.single_blocks:
# img = block(img, vec=vec, freqs_cis=freqs_cis)
# img = img[:, txt.shape[1]:, ...]

# img = self.final_layer(img, vec)
# img = rearrange(img, "b (t h w) c -> b c t h w", t=1, h=h, w=w)
# return img
Contributor

medium

This file contains a very large block of commented-out code. This code should be removed to improve readability and reduce clutter in the file.

Removed commented-out code and retained the core functionality for video generation.
@sriraksharao sriraksharao marked this pull request as draft February 7, 2026 03:40
@sriraksharao sriraksharao marked this pull request as ready for review February 7, 2026 03:44
OUTPUT_PATH = "video_samples"


def _print_frame_matrix(frames, label: str) -> None:
Collaborator

remove if it's for debug

dit_cpu_offload=False,
vae_cpu_offload=False,
text_encoder_cpu_offload=True,
pin_cpu_memory=True, # set to false if low CPU RAM or hit obscure "CUDA error: Invalid argument"
Collaborator

Set the workload_type arg to save the result as an image.

NONE = None


class DataType(str, Enum):
Collaborator

remove this, use FastVideoArgs.workload_type

self.__post_init__()


@dataclass
Collaborator

Remove the empty class.

flux_max_shift: float = 1.15
flux_shift: bool = True

task_type: ModelTaskType = ModelTaskType.T2I
Collaborator

Same. Use WorkloadType


import torch
import torch.nn as nn
from diffusers.models.attention import AttentionModuleMixin, FeedForward
Collaborator

Please do not use any module in diffusers.

Collaborator

please refactor this file; refer to #1075


decode_latents = latents
decode_ctx = None
preprocess = getattr(fastvideo_args.pipeline_config,
Collaborator

should be a separate stage instead of a callable in config
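For context, the getattr-based hook dispatch under discussion looks roughly like this (simplified; the config class and the doubling hook are hypothetical):

```python
class FluxLikeConfig:
    # Hypothetical hook: a callable attribute on the pipeline config that
    # the decoding stage looks up by name, as in the snippet quoted above.
    def preprocess_decoding(self, latents):
        return [x * 2.0 for x in latents]  # e.g. rescale before the VAE


def decode(config, latents):
    preprocess = getattr(config, "preprocess_decoding", None)
    if callable(preprocess):  # the hook is optional per pipeline
        latents = preprocess(latents)
    return latents  # a real stage would pass this to the VAE


print(decode(FluxLikeConfig(), [1.0, 2.0]))  # [2.0, 4.0]
print(decode(object(), [1.0]))               # [1.0] -- no hook, pass-through
```

The reviewer's suggestion is to express this as a dedicated pipeline stage instead, which keeps the config declarative and makes the hook's position in the stage order explicit.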

image = image.sample
postprocess = getattr(fastvideo_args.pipeline_config,
"postprocess_decoding", None)
if callable(postprocess):
Collaborator

same

Collaborator

why we need this?


2 participants