Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
46 commits
Select commit Hold shift + click to select a range
91de4b3
Add RewardModel and PrefGRPORewardModel classes for reward computation
LouisRouss Sep 10, 2025
db14e8a
Refactor RewardModel and PrefGRPORewardModel to enhance image handlin…
LouisRouss Sep 11, 2025
02e1ca8
Add return_latents option to Diffuser's denoise method for latent rep…
LouisRouss Sep 13, 2025
2ff6348
Add attribute delegation and enhanced dir() support to Diffuser class
LouisRouss Sep 13, 2025
b54d0c7
Fix dtype argument in model initialization
LouisRouss Sep 13, 2025
7e60ff7
Add one_step_denoise_grpo method for GRPO training in Flow class
LouisRouss Sep 13, 2025
b0240f0
Refactor training classes to use a common trainer and reorganize impo…
LouisRouss Sep 16, 2025
f60fd65
Add GRPO support to Diffuser and Flow classes with new methods and ut…
LouisRouss Sep 16, 2025
2a3ffc8
Enhance RewardModel and PrefGRPORewardModel with n_image_per_prompt s…
LouisRouss Sep 16, 2025
43f6cd3
Add GRPO support with new BatchData structures and update training cl…
LouisRouss Sep 19, 2025
6b014f7
fix typing
LouisRouss Sep 19, 2025
fd4252a
fix loss calculation grpo flow
LouisRouss Sep 20, 2025
ba8f074
Refactor loss computation in Flow class to use a list for step-wise l…
LouisRouss Sep 20, 2025
10a823d
Refactor trainer imports and implement validation step in GRPOTrainer
LouisRouss Sep 20, 2025
7519c8c
Finish GRPO training loop and fix epoch level scheduler logic
LouisRouss Sep 20, 2025
2f1940e
Add clip in reward model
LouisRouss Sep 21, 2025
d442082
Implement StepResult and Sampler classes for diffusion process; add E…
LouisRouss Sep 23, 2025
693e403
adapt to abstraction sampler and clean GRPO logic
LouisRouss Sep 23, 2025
342982a
Refactor ContextEmbedder to implement properties for n_output and out…
LouisRouss Sep 23, 2025
f7c4200
Refactor PrefGRPORewardModel to standardize clip model ID usage and i…
LouisRouss Sep 23, 2025
dc033ab
Refactor sampler classes to standardize set_steps method for improved…
LouisRouss Sep 25, 2025
8f17901
Add DDIM and DDPM sampler implementations with step and parameter set…
LouisRouss Sep 25, 2025
2a9847a
Refactor Flow and EulerMaruyama classes for improved parameter handli…
LouisRouss Sep 25, 2025
c64db67
improve tensor handling and device compatibility in flow and euler me…
LouisRouss Sep 25, 2025
2d476d4
- Refactor model input handling in Diffuser, Flow, and GRPOTrainer cl…
LouisRouss Sep 27, 2025
d2425fb
Add a generic abstract sampler class over modelization specific sampl…
LouisRouss Sep 27, 2025
0b9e159
Refactor diffusion model classes to standardize sampler initializatio…
LouisRouss Sep 27, 2025
3c81ce9
update docstring
LouisRouss Sep 27, 2025
854d82a
Refactor denoise method signatures in Diffuser, Flow, and GaussianDif…
LouisRouss Sep 27, 2025
89d8c90
Allow MMDiT to use a context embedder without pooled embedding
LouisRouss Sep 28, 2025
d079027
- Add loguru dependency
LouisRouss Sep 28, 2025
4ae13d2
Refactor preprocess method in DinoV2
LouisRouss Sep 28, 2025
8838fba
Enhance input validation in encode method of DCAE class to support ad…
LouisRouss Sep 28, 2025
ee3d906
Implement DDT architecture and refactor modulation classes for enhanc…
LouisRouss Sep 29, 2025
562f1f3
add dinoV3 and precompute functions
LouisRouss Oct 1, 2025
2b8a642
Refactor SD3TextEmbedder to improve type casting and add attention ma…
LouisRouss Oct 5, 2025
5c24b79
Update step method docstring in Euler and EulerMaruyama classes to re…
LouisRouss Oct 22, 2025
20945ec
add dependencies
LouisRouss Oct 22, 2025
9960d67
improve attn unet
LouisRouss Oct 22, 2025
94c9c44
Refactor ContextEmbedder to use ContextEmbedderOutput for forward method
LouisRouss Oct 22, 2025
d333dc5
Enhance MMDiTAttention and MMDiTBlock to support attention masks and …
LouisRouss Oct 26, 2025
da43da7
- rename mask to attn mask in context embedder output
LouisRouss Oct 26, 2025
ee8f7d0
Update DDT to utilize ContextEmbedderOutput for improved context hand…
LouisRouss Oct 26, 2025
5aed411
use transformers instead of open clip
LouisRouss Oct 26, 2025
ba5028a
Add attention mask to U-Net and use torch scaled do product attn
LouisRouss Oct 26, 2025
59b1d1d
finish forward method of PrefGRPORewardModel
LouisRouss Oct 26, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 1 addition & 34 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,6 @@ share/python-wheels/
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

Expand Down Expand Up @@ -83,36 +81,12 @@ notebooks/
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
.pdm.toml
.pdm-python
.pdm-build/

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
# PEP 582
__pypackages__/

# Celery stuff
Expand Down Expand Up @@ -155,13 +129,6 @@ dmypy.json
# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

# running logs
examples/wandb
outputs/
Expand Down
4 changes: 2 additions & 2 deletions examples/train_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from torch.utils.data import DataLoader

from diffulab.diffuse import Diffuser
from diffulab.training import Trainer
from diffulab.training import BaseTrainer


@hydra.main(version_base=None, config_path="../configs", config_name="train_mnist_flow_matching")
Expand Down Expand Up @@ -53,7 +53,7 @@ def count_parameters(model: torch.nn.Module) -> int:
)

# TODO: add a run name for wandb
trainer = Trainer(
trainer = BaseTrainer(
n_epoch=cfg.trainer.n_epoch,
gradient_accumulation_step=cfg.trainer.gradient_accumulation_step,
precision_type=cfg.trainer.precision_type,
Expand Down
4 changes: 2 additions & 2 deletions examples/train_repa.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from torch.utils.data import DataLoader

from diffulab.diffuse import Diffuser
from diffulab.training import Trainer
from diffulab.training import BaseTrainer
from diffulab.training.losses.repa import RepaLoss


Expand Down Expand Up @@ -77,7 +77,7 @@ def count_parameters(model: torch.nn.Module) -> int:
+ list(repa_loss.resampler.parameters() if repa_loss.resampler else []),
)

trainer = Trainer(
trainer = BaseTrainer(
n_epoch=cfg.trainer.n_epoch,
gradient_accumulation_step=cfg.trainer.gradient_accumulation_step,
precision_type=cfg.trainer.precision_type,
Expand Down
12 changes: 9 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,20 +1,23 @@
[project]
name = "diffulab"
version = "0.1.0"
description = "Add your description here"
description = "DiffuLab is designed to provide a simple and flexible way to train diffusion models while allowing full customization of its core components"
readme = "README.md"
requires-python = ">=3.10"
dependencies = [
"accelerate>=1.4.0",
"blobfile>=3.1.0",
"diffusers>=0.33.1",
"einops>=0.8.1",
"ema-pytorch>=0.7.7",
"hydra-core>=1.3.2",
"jaxtyping>=0.3.0",
"loguru>=0.7.3",
"mosaicml-streaming>=0.12.0",
"omegaconf>=2.3.0",
"open-clip-torch>=2.30.0",
"pyopenssl==23.2.0",
"sentencepiece>=0.2.1",
"tiktoken>=0.11.0",
"torch>=2.6.0",
"transformers>=4.49.0",
"wandb>=0.19.6",
Expand All @@ -31,6 +34,9 @@ dev = [
repa = [
"timm>=1.0.15",
]
prefgrpo = [
"qwen-vl-utils>=0.0.11",
]

[tool.uv.sources]
diffulab = {workspace = true}
Expand Down Expand Up @@ -58,7 +64,7 @@ ignore = [
combine-as-imports = true

[tool.pyright]
include = ["src/diffulab", "tests/", "examples/"]
include = ["src/diffulab", "tests/", "examples/", "local/"]
strict = ["*"]
exclude = ["**/__pycache__"]
reportMissingTypeStubs = "warning"
4 changes: 3 additions & 1 deletion src/diffulab/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from .datasets import BaseDataset, CIFAR10Dataset, ImageNetLatentREPA, MNISTDataset
from .diffuse import Diffuser, Flow, GaussianDiffusion
from .networks import DCAE, REPA, Denoiser, DinoV2, MMDiT, PerceiverResampler, SD3TextEmbedder, UNetModel, VisionTower
from .training import LossFunction, RepaLoss, Trainer
from .training import BaseTrainer, GRPOTrainer, LossFunction, RepaLoss, Trainer

__all__ = [
"BaseDataset",
Expand All @@ -22,5 +22,7 @@
"VisionTower",
"LossFunction",
"RepaLoss",
"BaseTrainer",
"GRPOTrainer",
"Trainer",
]
9 changes: 7 additions & 2 deletions src/diffulab/datasets/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,17 @@
from torch import Tensor
from torch.utils.data import Dataset

from diffulab.networks.denoisers.common import ModelInput
from diffulab.networks.denoisers.common import ExtraInputGRPO, ModelInput, ModelInputGRPO


class BatchData(TypedDict, total=False):
model_inputs: Required[ModelInput]
extra: NotRequired[dict[str, Tensor | None]]
extra: NotRequired[dict[str, Tensor | list[str] | None]]


class BatchDataGRPO(TypedDict, total=False):
model_inputs: Required[ModelInputGRPO]
extra: Required[ExtraInputGRPO]


class BaseDataset(Dataset[BatchData], ABC):
Expand Down
64 changes: 40 additions & 24 deletions src/diffulab/diffuse/diffuser.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from diffulab.diffuse.modelizations.diffusion import Diffusion
from diffulab.diffuse.modelizations.flow import Flow
from diffulab.diffuse.modelizations.gaussian_diffusion import GaussianDiffusion
from diffulab.diffuse.utils import SamplingOutput
from diffulab.networks.denoisers.common import Denoiser, ModelInput
from diffulab.networks.vision_towers.common import VisionTower
from diffulab.training.losses import LossFunction
Expand Down Expand Up @@ -96,71 +97,83 @@ def draw_timesteps(self, batch_size: int) -> Tensor:
def compute_loss(
self,
model_inputs: ModelInput,
timesteps: Tensor,
timesteps: Tensor | None = None,
noise: Tensor | None = None,
extra_args: dict[str, Any] = {},
grpo: bool = False,
grpo_args: dict[str, Any] = {},
) -> dict[str, Tensor]:
"""
Compute the loss for the diffusion model using the denoiser and diffusion process.
This method serves as a bridge between the Diffuser class and the underlying
diffusion implementation by forwarding the loss computation to the diffusion model.
Args:
model_inputs (ModelInput): A dictionary containing the model inputs,
including the data tensor keyed as 'x' and any conditional information.
including the data tensor keyed as 'x' and any conditional information. For GRPO,
x is not needed in the model inputs.
timesteps (Tensor): A tensor of timesteps for the batch.
noise (Tensor | None, optional): Pre-defined noise to add to the input.
If None, random noise will be generated. Defaults to None.
extra_args (dict[str, Any], optional): Additional arguments for the additional losses computation.
Returns:
dict[str, Tensor]: A dictionary containing the loss value and any additional losses
"""
return self.diffusion.compute_loss(self.denoiser, model_inputs, timesteps, noise, self.extra_losses, extra_args)
if grpo:
assert isinstance(self.diffusion, Flow), "GRPO loss computation is only available for Flow-based models"
return self.diffusion.compute_loss_grpo(
self.denoiser,
model_inputs, # type: ignore[reportArgumentType]
**grpo_args,
)
assert timesteps is not None, "timesteps must be provided for loss computation"
return self.diffusion.compute_loss(self.denoiser, model_inputs, timesteps, noise, self.extra_losses, extra_args) # type: ignore[reportArgumentType]

def set_steps(self, n_steps: int, **extra_args: dict[str, Any]) -> None:
def set_steps(self, n_steps: int, schedule: str = "linear") -> None:
"""
Update the number of diffusion steps and related parameters.
This method allows changing the number of steps used in the diffusion process
after the Diffuser has been initialized. It delegates to the underlying diffusion
model's set_steps method.
Args:
n_steps (int): The new number of diffusion steps to use.
**extra_args (dict[str, Any]): Additional arguments to pass to the diffusion
model's set_steps method. These may include parameters like 'schedule'
or 'section_counts' depending on the diffusion model implementation.
schedule (str, optional): The schedule to use for the timesteps. Defaults to "linear".
Example:
```
diffuser = Diffuser(denoiser, sampling_method="ddpm", n_steps=1000)
# Later, change to use fewer steps for faster sampling
diffuser.set_steps(100, schedule="ddim")
```
"""
self.diffusion.set_steps(n_steps, **extra_args) # type: ignore
self.diffusion.set_steps(n_steps, schedule=schedule)

def generate(
self,
data_shape: tuple[int, ...],
model_inputs: ModelInput,
data_shape: tuple[int, ...] | None = None,
use_tqdm: bool = True,
clamp_x: bool = False,
guidance_scale: float = 0,
**kwargs: dict[str, Any],
) -> Tensor:
sampler_args: dict[str, Any] = {},
return_intermediates: bool = False,
return_latents: bool = False,
) -> SamplingOutput:
"""
Generates a new sample using the diffusion model.
This method delegates to the underlying diffusion model's denoise method to
generate a sample from the diffusion process. It can handle conditional generation
with guidance when appropriate settings are provided.
Args:
data_shape (tuple[int, ...]): Shape of the data to generate, typically (batch_size, channels, height, width).
If a vision tower is used, this should be the shape of the latent space.
model_inputs (ModelInput): A dictionary containing inputs for the model, such as initial noise,
conditional information, or labels. If 'x' is not provided, random noise will be generated.
data_shape (tuple[int, ...], optional): The shape of the data to generate. If a vision tower is used,
this shape should correspond to the latent space shape.
Required if 'x' is not in model_inputs. Defaults to None.
use_tqdm (bool, optional): Whether to display a progress bar during generation. Defaults to True.
clamp_x (bool, optional): Whether to clamp the generated values to [-1, 1] range. Defaults to False.
guidance_scale (float, optional): Scale for classifier or classifier-free guidance.
Values greater than 0 enable guidance. Defaults to 0.
**kwargs (dict[str, Any]): Additional arguments to pass to the diffusion model's denoise method.
These may include parameters like 'classifier', 'classifier_free', etc.
return_latents (bool, optional): Whether to return the latent representation when using a
vision tower instead of decoded data. Defaults to False.
Returns:
Tensor: The generated data tensor.
Example:
Expand All @@ -177,24 +190,27 @@ def generate(
```
"""
if self.vision_tower:
z = self.diffusion.denoise(
sampling_output = self.diffusion.denoise(
self.denoiser,
data_shape,
model_inputs,
model_inputs=model_inputs,
data_shape=data_shape,
use_tqdm=use_tqdm,
clamp_x=clamp_x,
guidance_scale=guidance_scale,
**kwargs,
sampler_args=sampler_args,
return_intermediates=return_intermediates,
)
z = z / self.latent_scale
return self.vision_tower.decode(z)
if not return_latents:
sampling_output["x"] = self.vision_tower.decode(sampling_output["x"] / self.latent_scale)
return sampling_output

return self.diffusion.denoise(
self.denoiser,
data_shape,
model_inputs,
model_inputs=model_inputs,
data_shape=data_shape,
use_tqdm=use_tqdm,
clamp_x=clamp_x,
guidance_scale=guidance_scale,
**kwargs,
sampler_args=sampler_args,
return_intermediates=return_intermediates,
)
Loading
Loading