Skip to content

Conversation

DavidBert
Copy link

@DavidBert DavidBert commented Oct 9, 2025

This commit adds support for the Photon image generation model:

  • PhotonTransformer2DModel: Core transformer architecture
  • PhotonPipeline: Text-to-image generation pipeline
  • Attention processor updates for Photon-specific attention mechanism
  • Conversion script for loading Photon checkpoints
  • Documentation and tests

Some exemples below with the 512 model fine-tuned on the Alchemist dataset and distilled with PAG

image_10 image_4 image_0 image_1

What does this PR do?

Fixes # (issue)

Before submitting

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

This commit adds support for the Photon image generation model:
- PhotonTransformer2DModel: Core transformer architecture
- PhotonPipeline: Text-to-image generation pipeline
- Attention processor updates for Photon-specific attention mechanism
- Conversion script for loading Photon checkpoints
- Documentation and tests
print("✓ Created scheduler config")


def download_and_save_vae(vae_type: str, output_path: str):
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure on this one: I'm saving the VAE weights while they are already available on the Hub (Flux VAE and DC-AE).
Is there a way to avoid storing them and instead look directly for the original ones?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For now, it's okay to keep this as is. This way, everything is under the same model repo.

print(f"✓ Saved VAE to {vae_path}")


def download_and_save_text_encoder(output_path: str):
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here for the Text Encoder.

print("✓ Created scheduler config")


def download_and_save_vae(vae_type: str, output_path: str):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For now, it's okay to keep this as is. This way, everything is under the same model repo.

Copy link
Member

@sayakpaul sayakpaul left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the clean PR! I left some initial feedback for you. LMK if that makes sense.

Also, it would be great to see some samples of Photon!

Copy link
Member

@sayakpaul sayakpaul left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! Left a couple more comments. Let's also add the pipeline-level tests.

<img alt="LoRA" src="https://img.shields.io/badge/LoRA-d8b4fe?style=flat"/>
</div>

Photon is a text-to-image diffusion model using simplified MMDIT architecture with flow matching for efficient high-quality image generation. The model uses T5Gemma as the text encoder and supports either Flux VAE (AutoencoderKL) or DC-AE (AutoencoderDC) for latent compression.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cc: @stevhliu for a review on the docs.

return xq_out.reshape(*xq.shape).type_as(xq)


class PhotonAttnProcessor2_0:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we write it in a fashion similar to

?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I second this suggestion - in particular, I think it would be more in line with other diffusers models implementations to reuse the layers defined in Attention, such as to_q/to_k/to_v, etc. instead of defining them in PhotonBlock (e.g. PhotonBlock.img_qkv_proj), and to keep the entire attention implementation in the PhotonAttnProcessor2_0 class.

Attention supports stuff like QK norms and fusing projections, so that could potentially be reused as well. If you need some custom logic not found in Attention, you could potentially add it in there or create a new Attention-style class like Flux does:

class FluxAttention(torch.nn.Module, AttentionModuleMixin):

def __call__(
self,
prompt: Union[str, List[str]] = None,
height: Optional[int] = None,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We support passing prompt embeddings too in case users want to supply them precomputed:

prompt_embeds: Optional[torch.FloatTensor] = None,

Comment on lines 484 to 486
default_sample_size = getattr(self.config, "default_sample_size", DEFAULT_RESOLUTION)
height = height or default_sample_size
width = width or default_sample_size
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Prefer this pattern:

height = height or self.default_sample_size * self.vae_scale_factor

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I did it this way because the model works for two different vae with different scale_factors.
Is it ok to not make it depend of self.vae_scale_factor? It makes it hard to define a default value otherwise.

Copy link
Member

@sayakpaul sayakpaul Oct 15, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh good point! I think we could make a small utility function in the pipeline class that determines the default resolution given the VAE that's loaded into it? WDYT?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, way cleaner! I did it.

@DavidBert DavidBert requested a review from sayakpaul October 15, 2025 13:40
@sayakpaul sayakpaul requested review from dg845 and stevhliu and removed request for sayakpaul October 15, 2025 15:05
@DavidBert DavidBert requested a review from sayakpaul October 15, 2025 15:18
Copy link
Member

@stevhliu stevhliu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the docs, remember to add it to the toctree as well!

# See the License for the specific language governing permissions and
# limitations under the License. -->

# PhotonPipeline
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# PhotonPipeline
# Photon

Comment on lines +18 to +26
Photon is a text-to-image diffusion model using simplified MMDIT architecture with flow matching for efficient high-quality image generation. The model uses T5Gemma as the text encoder and supports either Flux VAE (AutoencoderKL) or DC-AE (AutoencoderDC) for latent compression.

Key features:

- **Simplified MMDIT architecture**: Uses a simplified MMDIT architecture for image generation where text tokens are not updated through the transformer blocks
- **Flow Matching**: Employs flow matching with discrete scheduling for efficient sampling
- **Flexible VAE Support**: Compatible with both Flux VAE (8x compression, 16 latent channels) and DC-AE (32x compression, 32 latent channels)
- **T5Gemma Text Encoder**: Uses Google's T5Gemma-2B-2B-UL2 model for text encoding offering multiple language support
- **Efficient Architecture**: ~1.3B parameters in the transformer, enabling fast inference while maintaining quality
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
Photon is a text-to-image diffusion model using simplified MMDIT architecture with flow matching for efficient high-quality image generation. The model uses T5Gemma as the text encoder and supports either Flux VAE (AutoencoderKL) or DC-AE (AutoencoderDC) for latent compression.
Key features:
- **Simplified MMDIT architecture**: Uses a simplified MMDIT architecture for image generation where text tokens are not updated through the transformer blocks
- **Flow Matching**: Employs flow matching with discrete scheduling for efficient sampling
- **Flexible VAE Support**: Compatible with both Flux VAE (8x compression, 16 latent channels) and DC-AE (32x compression, 32 latent channels)
- **T5Gemma Text Encoder**: Uses Google's T5Gemma-2B-2B-UL2 model for text encoding offering multiple language support
- **Efficient Architecture**: ~1.3B parameters in the transformer, enabling fast inference while maintaining quality
Photon generates high-quality images from text using a simplified MMDIT architecture where text tokens don't update through transformer blocks. It employs flow matching with discrete scheduling for efficient sampling and uses Google's T5Gemma-2B-2B-UL2 model for multi-language text encoding. The ~1.3B parameter transformer delivers fast inference without sacrificing quality. You can choose between Flux VAE (8x compression, 16 latent channels) for balanced quality and speed or DC-AE (32x compression, 32 latent channels) for latent compression and faster processing.

Comment on lines +28 to +33
## Available models:
We offer a range of **Photon models** featuring different **VAE configurations**, each optimized for generating images at various resolutions.
Both **fine-tuned** and **non-fine-tuned** versions are available:

- **Non-fine-tuned models** perform best with **highly detailed prompts**, capturing fine nuances and complex compositions.
- **Fine-tuned models**, trained on the [Alchemist dataset](https://huggingface.co/datasets/yandex/alchemist), enhance the **aesthetic quality** of the base models—especially when prompts are **less detailed**.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
## Available models:
We offer a range of **Photon models** featuring different **VAE configurations**, each optimized for generating images at various resolutions.
Both **fine-tuned** and **non-fine-tuned** versions are available:
- **Non-fine-tuned models** perform best with **highly detailed prompts**, capturing fine nuances and complex compositions.
- **Fine-tuned models**, trained on the [Alchemist dataset](https://huggingface.co/datasets/yandex/alchemist), enhance the **aesthetic quality** of the base models—especially when prompts are **less detailed**.
## Available models
Photon offers multiple variants with different VAE configurations, each optimized for specific resolutions. Base models excel with detailed prompts, capturing complex compositions and subtle details. Fine-tuned models trained on the [Alchemist dataset](https://huggingface.co/datasets/yandex/alchemist) improve aesthetic quality, especially with simpler prompts.


Refer to [this](https://huggingface.co/collections/Photoroom/photon-models-68e66254c202ebfab99ad38e) collection for more information.

## Loading the Pipeline
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
## Loading the Pipeline
## Loading the pipeline
Load the pipeline with [`~DiffusionPipeline.from_pretrained`].


### Manual Component Loading

You can also load components individually:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would be good to demonstrate why you're loading components individually, for example, it could be for quantization

Suggested change
You can also load components individually:
Load components individually to ...

pipe.to("cuda")
```

## VAE Variants
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The VAE section can be removed since its already mentioned in the first paragraph.


The VAE type is automatically determined from the checkpoint's `model_index.json` configuration.

## Generation Parameters
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This section can also be removed since its safe to assume a user has some background knowledge of using a text-to-image pipeline

return self.out_layer(self.silu(self.in_layer(x)))


class QKNorm(torch.nn.Module):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider reusing the QK norm implementation in Attention; I believe setting qk_norm == "rms_norm" should be equivalent:

elif qk_norm == "rms_norm":
self.norm_q = RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
self.norm_k = RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)


# img qkv
self.img_pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.img_qkv_proj = nn.Linear(hidden_size, hidden_size * 3, bias=False)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think using the layers from the Attention instance (self.attention) rather than defining them here would be more idiomatic in diffusers. See also https://github.com/huggingface/diffusers/pull/12456/files#r2434379626.


def forward(
self,
img: Tensor,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it would be more clear (in the diffusers context) to adopt the usual naming scheme in FluxTransformerBlock, WanTransformerBlock, etc.:

def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
temb: torch.Tensor,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:

so something like img --> hidden_states, txt --> encoder_hidden_states, vec --> temb, pe --> image_rotary_emb, etc.

attn_shift, attn_scale, attn_gate = mod_attn
mlp_shift, mlp_scale, mlp_gate = mod_mlp

# Inline attention forward
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would suggest putting all of the attention implementation logic in PhotonAttnProcessor2_0 (see https://github.com/huggingface/diffusers/pull/12456/files#r2434379626)

self._guidance_scale = guidance_scale

# 2. Encode input prompt
text_embeddings, cross_attn_mask, uncond_text_embeddings, uncond_cross_attn_mask = self.encode_prompt(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: I think it would be a little cleaner if text_embeddings was named prompt_embeds and uncond_text_embeddings was named negative_prompt_embeds here.

Comment on lines +265 to +266
self.text_encoder = text_encoder
self.tokenizer = tokenizer
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Setting these attributes explicitly shouldn't be necessary since the register_modules call below should handle that.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants