diff --git a/src/art/backend.py b/src/art/backend.py index 9fa95c0e..473681a0 100644 --- a/src/art/backend.py +++ b/src/art/backend.py @@ -1,5 +1,5 @@ import json -from typing import TYPE_CHECKING, AsyncIterator, Literal +from typing import TYPE_CHECKING, AsyncIterator, Iterable, Literal import httpx from tqdm import auto as tqdm @@ -8,8 +8,8 @@ from art.utils.deploy_model import LoRADeploymentJob, LoRADeploymentProvider from . import dev -from .trajectories import TrajectoryGroup -from .types import TrainConfig +from .trajectories import Trajectory, TrajectoryGroup +from .types import SFTConfig, TrainConfig if TYPE_CHECKING: from .model import Model, TrainableModel @@ -126,6 +126,21 @@ async def _train_model( if pbar is not None: pbar.close() + async def _train_sft( + self, + model: "TrainableModel", + trajectories: Iterable[Trajectory], + config: SFTConfig, + dev_config: dev.SFTConfig, + verbose: bool = False, + ) -> AsyncIterator[dict[str, float]]: + raise NotImplementedError( + "SFT training is not yet implemented. " + "This method will be available in a future release." + ) + # This yield is unreachable but makes this an async generator + yield # type: ignore + # ------------------------------------------------------------------ # Experimental support for S3 # ------------------------------------------------------------------ diff --git a/src/art/dev/__init__.py b/src/art/dev/__init__.py index b60525d9..6257135f 100644 --- a/src/art/dev/__init__.py +++ b/src/art/dev/__init__.py @@ -7,7 +7,7 @@ ) from .openai_server import OpenAIServerConfig, ServerArgs, get_openai_server_config from .torchtune import TorchtuneArgs -from .train import TrainConfig +from .train import SFTConfig, TrainConfig __all__ = [ "EngineArgs", @@ -18,6 +18,7 @@ "get_openai_server_config", "OpenAIServerConfig", "ServerArgs", + "SFTConfig", "TorchtuneArgs", "TrainConfig", ] diff --git a/src/art/dev/train.py b/src/art/dev/train.py index f6491b15..6a540a9f 100644 --- a/src/art/dev/train.py +++ b/src/art/dev/train.py @@ -22,3 +22,8 @@ class TrainConfig(TypedDict, total=False): scale_learning_rate_by_reward_std_dev: bool scale_rewards: bool truncated_importance_sampling: float | None + + +class SFTConfig(TypedDict, total=False): + """Experimental SFT configuration options. Use at your own risk.""" + pass diff --git a/src/art/local/backend.py b/src/art/local/backend.py index 13a906b4..ef1e2e3a 100644 --- a/src/art/local/backend.py +++ b/src/art/local/backend.py @@ -5,7 +5,7 @@ import subprocess from datetime import datetime from types import TracebackType -from typing import AsyncIterator, Literal, cast +from typing import AsyncIterator, Iterable, Literal, cast import aiohttp import numpy as np @@ -54,7 +54,7 @@ ) from ..preprocessing.tokenize import tokenize_trajectory_groups from ..trajectories import Trajectory, TrajectoryGroup -from ..types import Message, TrainConfig +from ..types import Message, SFTConfig, TrainConfig from ..utils import format_message, get_model_step from .checkpoints import ( delete_checkpoints, @@ -521,6 +521,21 @@ async def _train_model( if verbose: print("_train_model complete") + async def _train_sft( + self, + model: TrainableModel, + trajectories: Iterable[Trajectory], + config: SFTConfig, + dev_config: dev.SFTConfig, + verbose: bool = False, + ) -> AsyncIterator[dict[str, float]]: + raise NotImplementedError( + "SFT training is not yet implemented for LocalBackend. " + "Please use the Backend HTTP API or implement this method." 
+ ) + # This yield is unreachable but makes this an async generator + yield # type: ignore + def _get_reward_std_dev_learning_rate_multiplier( self, model: TrainableModel ) -> float: diff --git a/src/art/model.py b/src/art/model.py index 43c519b2..ba88601d 100644 --- a/src/art/model.py +++ b/src/art/model.py @@ -7,7 +7,7 @@ from . import dev from .trajectories import Trajectory, TrajectoryGroup -from .types import TrainConfig +from .types import SFTConfig, TrainConfig if TYPE_CHECKING: from art.backend import Backend @@ -386,3 +386,25 @@ async def train( self, list(trajectory_groups), config, _config or {}, verbose ): pass + + async def train_sft( + self, + trajectories: Iterable[Trajectory], + config: SFTConfig, + _config: dev.SFTConfig | None = None, + verbose: bool = False, + ) -> None: + """ + Supervised fine-tune the model with an iterable of trajectories. + + Args: + trajectories: An iterable of Trajectory objects. + config: SFT configuration including learning_rates and batch_size. + _config: Additional experimental configuration that is subject to change and + not yet part of the public API. Use at your own risk. + verbose: Whether to print verbose output. + """ + async for _ in self.backend()._train_sft( + self, trajectories, config, _config or {}, verbose + ): + pass diff --git a/src/art/preprocessing/tokenize_sft.py b/src/art/preprocessing/tokenize_sft.py new file mode 100644 index 00000000..f7194219 --- /dev/null +++ b/src/art/preprocessing/tokenize_sft.py @@ -0,0 +1,168 @@ +"""Tokenization utilities for Supervised Fine-Tuning (SFT).""" + +import math +from dataclasses import dataclass +from typing import Generator + +import torch +from transformers.tokenization_utils_base import PreTrainedTokenizerBase + +from ..trajectories import Trajectory + + +@dataclass +class SFTBatch: + """A batch of tokenized trajectories for supervised fine-tuning. + + Attributes: + trajectory_tensors: List of tensor dictionaries, one per trajectory. + Each dict contains 'input_ids', 'attention_mask', and 'labels'. + learning_rate: Learning rate to use for this batch. + num_trajectories: Number of trajectories in this batch. + num_trainable_tokens: Total number of tokens being trained on (labels != -100). + """ + trajectory_tensors: list[dict[str, torch.Tensor]] + learning_rate: float + num_trajectories: int + num_trainable_tokens: int + + +def tokenize_sft_batches( + trajectories: list[Trajectory], + batch_size: int, + learning_rates: list[float], + tokenizer: PreTrainedTokenizerBase, + instruction_part: str, + response_part: str, +) -> Generator[SFTBatch, None, None]: + """ + Tokenize trajectories into batches for supervised fine-tuning. 
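+
+    Each trajectory is rendered with the tokenizer's chat template and padded to
+    the longest sequence in its batch; padded positions are masked out in both
+    attention_mask and labels. Labels follow a train-on-responses-only scheme:
+    everything is set to -100 except the tokens that follow response_part, up to
+    the next instruction_part (or the end of the sequence). For example, with
+    instruction_part="User:" and response_part="Assistant:", only the assistant
+    turns contribute to the loss.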
+ + Args: + trajectories: Flat list of trajectories + batch_size: Number of trajectories per batch + learning_rates: Learning rate for each batch + tokenizer: Tokenizer to use for encoding + instruction_part: Instruction template part (e.g., "User:") + response_part: Response template part (e.g., "Assistant:") + + Yields: + SFTBatch object containing: + - trajectory_tensors: List of tensors for each trajectory + - learning_rate: Learning rate for this batch + - num_trajectories: Number of trajectories in this batch + - num_trainable_tokens: Total number of trainable tokens + """ + # Validate inputs + num_trajectories = len(trajectories) + num_learning_rates = len(learning_rates) + expected_num_batches = math.ceil(num_trajectories / batch_size) + + if num_learning_rates != expected_num_batches: + raise ValueError( + f"Mismatch between trajectories and learning_rates: " + f"{num_trajectories} trajectories with batch_size={batch_size} " + f"yields {expected_num_batches} batches, but got {num_learning_rates} learning_rates" + ) + + instruction_ids = tokenizer(instruction_part, add_special_tokens=False).input_ids + response_ids = tokenizer(response_part, add_special_tokens=False).input_ids + instruction_length = len(instruction_ids) + response_length = len(response_ids) + max_template_length = max(instruction_length, response_length) + + def _train_on_responses_only(input_ids: list[int]) -> list[int]: + labels = [-100] * len(input_ids) + m = len(input_ids) - max_template_length + first_response = response_ids[0] + first_instruction = instruction_ids[0] + j = 0 + + while j < m: + if input_ids[j] == first_response: + if input_ids[j : j + response_length] == response_ids: + j = j + response_length + start = j + while j < m: + if input_ids[j] == first_instruction and input_ids[j : j + instruction_length] == instruction_ids: + j = j + instruction_length + labels[start : j] = input_ids[start : j] + break + elif j == (m - 1): + j = m + labels[start:] = input_ids[start:] + break + j += 1 + j += 1 + + return labels + + # Batch trajectories + for batch_idx, lr in enumerate(learning_rates): + start_idx = batch_idx * batch_size + end_idx = start_idx + batch_size + trajectory_batch = trajectories[start_idx:end_idx] + + # First pass: tokenize all trajectories + tokenized_trajectories = [] + for trajectory in trajectory_batch: + messages = trajectory.messages_and_choices + tools = trajectory.tools + + # Single-step tokenization: apply_chat_template with tokenize=True + input_ids = tokenizer.apply_chat_template( + messages, + tools=tools, + tokenize=True, + add_generation_prompt=False + ) + + # Create attention mask (all 1s - no padding yet) + attention_mask = [1] * len(input_ids) + + labels = _train_on_responses_only(input_ids) + + tokenized_trajectories.append({ + 'input_ids': input_ids, + 'attention_mask': attention_mask, + 'labels': labels, + }) + + # Find max length in this batch for padding + max_seq_length = max(len(t['input_ids']) for t in tokenized_trajectories) + + # Second pass: pad all trajectories to max_seq_length + trajectory_tensors = [] + for tokenized in tokenized_trajectories: + input_ids = tokenized['input_ids'] + attention_mask = tokenized['attention_mask'] + labels = tokenized['labels'] + + # Pad to max_seq_length + padding_length = max_seq_length - len(input_ids) + if padding_length > 0: + input_ids = input_ids + [tokenizer.pad_token_id] * padding_length + attention_mask = attention_mask + [0] * padding_length + labels = labels + [-100] * padding_length + + trajectory_tensor = { + 
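+                # Shape (1, seq_len): each trajectory is its own micro-batch of
+                # size 1; gradients are accumulated across the batch's trajectories.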
'input_ids': torch.tensor([input_ids], dtype=torch.long), + 'attention_mask': torch.tensor([attention_mask], dtype=torch.long), + 'labels': torch.tensor([labels], dtype=torch.long), + } + + trajectory_tensors.append(trajectory_tensor) + + # Calculate total trainable tokens (labels != -100) + num_trainable_tokens = sum( + (tensor_dict['labels'] != -100).sum().item() + for tensor_dict in trajectory_tensors + ) + + yield SFTBatch( + trajectory_tensors=trajectory_tensors, + learning_rate=lr, + num_trajectories=len(trajectory_tensors), + num_trainable_tokens=num_trainable_tokens, + ) + diff --git a/src/art/serverless/backend.py b/src/art/serverless/backend.py index 604faea5..a07ae789 100644 --- a/src/art/serverless/backend.py +++ b/src/art/serverless/backend.py @@ -1,5 +1,5 @@ import asyncio -from typing import TYPE_CHECKING, AsyncIterator, Literal +from typing import TYPE_CHECKING, AsyncIterator, Iterable, Literal from openai._types import NOT_GIVEN from tqdm import auto as tqdm @@ -9,8 +9,8 @@ from .. import dev from ..backend import Backend -from ..trajectories import TrajectoryGroup -from ..types import TrainConfig +from ..trajectories import Trajectory, TrajectoryGroup +from ..types import SFTConfig, TrainConfig if TYPE_CHECKING: from ..model import Model, TrainableModel @@ -159,6 +159,21 @@ async def _train_model( raise RuntimeError(f"Training job failed: {error_message}") after = event.id + async def _train_sft( + self, + model: "TrainableModel", + trajectories: Iterable[Trajectory], + config: SFTConfig, + dev_config: dev.SFTConfig, + verbose: bool = False, + ) -> AsyncIterator[dict[str, float]]: + raise NotImplementedError( + "SFT training is not yet implemented for ServerlessBackend. " + "Please use the Backend HTTP API or implement this method." + ) + # This yield is unreachable but makes this an async generator + yield # type: ignore + # ------------------------------------------------------------------ # Experimental support for S3 # ------------------------------------------------------------------ diff --git a/src/art/types.py b/src/art/types.py index fd1bb272..6dbb9b24 100644 --- a/src/art/types.py +++ b/src/art/types.py @@ -1,4 +1,4 @@ -from typing import Literal +from typing import Iterable, Literal import pydantic from openai.types.chat.chat_completion import Choice @@ -17,4 +17,10 @@ class TrainConfig(pydantic.BaseModel): beta: float = 0.0 +class SFTConfig(pydantic.BaseModel): + learning_rate: float = 5e-5 + batch_size: int | Literal["auto"] = "auto" + custom_lr_schedule: list[float] = [] + + Verbosity = Literal[0, 1, 2] diff --git a/src/art/unsloth/train_sft.py b/src/art/unsloth/train_sft.py new file mode 100644 index 00000000..6c5b175c --- /dev/null +++ b/src/art/unsloth/train_sft.py @@ -0,0 +1,141 @@ +"""Training utilities for Supervised Fine-Tuning (SFT).""" + +import asyncio +from collections import defaultdict +from typing import TYPE_CHECKING, Callable, Iterator + +import nest_asyncio +import torch +from trl import SFTTrainer + +if TYPE_CHECKING: + from ..preprocessing.tokenize_sft import SFTBatch + +nest_asyncio.apply() + + +async def train_sft( + trainer: SFTTrainer, + input_queue: asyncio.Queue["SFTBatch"], + results_queue: asyncio.Queue[dict[str, float]], +) -> None: + """ + Train an SFT model using batches from a queue. 
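+
+    While training runs, trainer.get_batch_samples is swapped out so that each
+    step pulls an SFTBatch from input_queue, applying that batch's learning rate
+    and setting gradient_accumulation_steps to its trajectory count, and
+    trainer.log is swapped out so that averaged metrics are pushed to
+    results_queue. Both are restored when training finishes.
+
+    Rough producer-side sketch (illustrative; trainer construction omitted):
+
+        input_queue: asyncio.Queue = asyncio.Queue()
+        results_queue: asyncio.Queue = asyncio.Queue()
+        for batch in tokenize_sft_batches(...):  # art.preprocessing.tokenize_sft
+            input_queue.put_nowait(batch)
+        await train_sft(trainer, input_queue, results_queue)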
+ + Args: + trainer: TRL SFTTrainer instance + input_queue: Queue containing SFTBatch objects + results_queue: Queue for training metrics/results + """ + _get_batch_samples = trainer.get_batch_samples + _log = trainer.log + + trainer.get_batch_samples = get_batch_samples_fn(trainer, input_queue) + trainer.log = get_log_fn(trainer, results_queue) + + # Ensure we have a metrics container in the expected format + try: + is_dict = isinstance(getattr(trainer, "_metrics", None), dict) + is_train_dict = is_dict and isinstance(trainer._metrics.get("train"), dict) + except Exception: + is_train_dict = False + if not is_train_dict: + trainer._metrics = {"train": defaultdict(list)} + + try: + trainer.train() + finally: + trainer.get_batch_samples = _get_batch_samples + trainer.log = _log + + +def get_batch_samples_fn( + trainer: SFTTrainer, + input_queue: asyncio.Queue["SFTBatch"], +) -> Callable[..., tuple[list[dict[str, torch.Tensor]], torch.Tensor]]: + """ + Create a get_batch_samples function that: + 1. Reads SFTBatch from queue + 2. Sets learning rate from batch + 3. Sets gradient accumulation steps + 4. Returns batch samples and num_items_in_batch as tensor + """ + + def get_batch_samples( + epoch_iterator: Iterator, + num_batches: int, + device: torch.device | str | None = None, + ) -> tuple[list[dict[str, torch.Tensor]], torch.Tensor]: + """ + Override get_batch_samples to read from queue instead of epoch_iterator. + + Returns: + tuple of (batch_samples, num_items_in_batch as tensor int) + """ + # Read SFTBatch from queue asynchronously + async def get_sft_batch() -> "SFTBatch": + return await input_queue.get() + + # Get the batch from queue + sft_batch: "SFTBatch" = asyncio.run(get_sft_batch()) + + # Set learning rate for this batch + if optimizer := trainer.optimizer: + optimizer = getattr(optimizer, "optimizer", optimizer) + if param_groups := getattr(optimizer, "param_groups"): + for param_group in param_groups: + param_group["lr"] = sft_batch.learning_rate + + # Set gradient accumulation steps to number of trajectories + # We're doing micro-batch size 1, so accumulate across all trajectories + if hasattr(trainer.args, "gradient_accumulation_steps"): + trainer.args.gradient_accumulation_steps = sft_batch.num_trajectories + + # Convert each trajectory to a separate sample for micro-batching + # Trainer will process each sample individually and accumulate gradients + batch_samples = [] + for trajectory_tensor in sft_batch.trajectory_tensors: + # Move each trajectory's tensors to device + sample = { + key: tensor.to(device) + for key, tensor in trajectory_tensor.items() + } + batch_samples.append(sample) + + # Return batch samples and num_items_in_batch as tensor (on device) + num_items_in_batch = torch.tensor( + sft_batch.num_trajectories, + dtype=torch.long, + device=device + ) + + return batch_samples, num_items_in_batch + + return get_batch_samples + + +def get_log_fn( + trainer: SFTTrainer, + results_queue: asyncio.Queue[dict[str, float]], +) -> Callable[..., None]: + """ + Create a logging function that sends metrics to the results queue. + Same pattern as GRPO trainer. + """ + def log(logs: dict[str, float], start_time: float | None = None) -> None: + """Log metrics and send to results queue.""" + metrics = { + key: sum(val) / len(val) for key, val in trainer._metrics["train"].items() + } # average the metrics + + # This method can be called both in training and evaluation. When called in evaluation, the keys in `logs` + # start with "eval_". 
We need to add the prefix "eval_" to the keys in `metrics` to match the format. + if next(iter(logs.keys())).startswith("eval_"): + metrics = {f"eval_{key}": val for key, val in metrics.items()} + + logs = {**logs, **metrics} + logs.pop("learning_rate", None) + results_queue.put_nowait(logs) + trainer._metrics["train"].clear() + + return log \ No newline at end of file diff --git a/src/art/utils/sft.py b/src/art/utils/sft.py new file mode 100644 index 00000000..74e5e971 --- /dev/null +++ b/src/art/utils/sft.py @@ -0,0 +1,430 @@ +"""Utilities for supervised fine-tuning (SFT).""" + +import json +import math +import random +from dataclasses import dataclass +from typing import TYPE_CHECKING, Generator, List, Literal + +from tqdm.auto import tqdm + +if TYPE_CHECKING: + from art.model import TrainableModel + from art.trajectories import Trajectory + from art.types import SFTConfig + + +@dataclass +class SFTDatasetChunk: + """Container for SFT dataset chunk with trajectories, config, and step information.""" + + trajectories: List["Trajectory"] + config: "SFTConfig" + step: int + epoch: int + epoch_step: int + +def get_file_row_count(file_path: str) -> int: + """ + Count the number of non-empty rows in a JSONL file. + + Args: + file_path: Path to JSONL file + + Returns: + Number of non-empty lines in the file + + Raises: + ValueError: If file_path does not end with .jsonl + + Example: + count = get_file_row_count("data.jsonl") + print(f"Dataset has {count} items") + """ + if not file_path.endswith(".jsonl"): + raise ValueError(f"Only JSONL files are supported. Got: {file_path}") + + count = 0 + with open(file_path, "r") as f: + for line in f: + if line.strip(): + count += 1 + return count + + +def create_lr_schedule( + total_steps: int, + peak_lr: float, + method: Literal["cosine", "linear", "constant"] = "linear", + warmup_steps: int = 0, + min_lr: float = 0.0, +) -> List[float]: + """ + Create learning rate schedule for training with optional warmup. + + Args: + total_steps: Total number of training steps + peak_lr: Peak learning rate + method: Learning rate schedule method. 
Options: + - "cosine": Cosine annealing from peak_lr to min_lr + - "linear": Linear decay from peak_lr to min_lr + - "constant": Constant learning rate (peak_lr for all steps) + warmup_steps: Number of warmup steps (linear warmup from 0 to peak_lr) + min_lr: Minimum learning rate (floor for decay schedules) + + Returns: + List of learning rates for each step + + Example: + # Cosine schedule with warmup + lrs = create_lr_schedule(100, 1e-4, method="cosine", warmup_steps=10) + + # Use with training loop + for step, chunk in enumerate(chunk_trajectories(...)): + train_sft(chunk, learning_rate=lrs[step]) + """ + learning_rates = [] + + for step in range(total_steps): + # Warmup phase: linear warmup from 0 to peak_lr + if step < warmup_steps: + lr = peak_lr * (step / warmup_steps) + else: + # Main schedule phase + # Adjust step to be relative to post-warmup period + adjusted_step = step - warmup_steps + adjusted_total = total_steps - warmup_steps + + if method == "cosine": + # Cosine annealing: lr = min_lr + (peak_lr - min_lr) * 0.5 * (1 + cos(pi * t)) + lr = min_lr + (peak_lr - min_lr) * 0.5 * ( + 1 + math.cos(math.pi * adjusted_step / adjusted_total) + ) + elif method == "linear": + # Linear decay: lr = peak_lr - (peak_lr - min_lr) * (t / total) + lr = peak_lr - (peak_lr - min_lr) * (adjusted_step / adjusted_total) + elif method == "constant": + # Constant learning rate + lr = peak_lr + else: + raise ValueError( + f"Unknown method: {method}. Choose from: cosine, linear, constant" + ) + + learning_rates.append(lr) + + return learning_rates + + +def create_sft_dataset_iterator( + trajectories: List["Trajectory"], + epochs: int = 1, + batch_size: int = 1, + chunk_size: int = 50, + peak_lr: float = 2e-4, + schedule_type: Literal["cosine", "linear", "constant"] = "linear", + warmup_ratio: float = 0.1, + initial_step: int = 0, + use_tqdm: bool = True, +) -> Generator[SFTDatasetChunk, None, None]: + """ + Create an iterator that yields SFT dataset chunks with trajectories, config, and step info. + + Combines trajectory batching with learning rate scheduling. Yields SFTDatasetChunk objects + containing flattened trajectories, SFTConfig with learning rates, and step tracking info. + + Args: + trajectories: List of Trajectory objects to train on + epochs: Number of times to iterate over the trajectories. Default: 1 + batch_size: Number of trajectories per batch. Default: 1 + chunk_size: Number of batches per chunk. Default: 50 + peak_lr: Peak learning rate. Default: 5e-5 + schedule_type: Learning rate schedule type ("cosine", "linear", "constant"). Default: "linear" + warmup_ratio: Ratio of total steps to use for warmup (0.0 to 1.0). Default: 0.1 + initial_step: The global chunk step to start from. Default: 0. + Useful for resuming training. + use_tqdm: Whether to display a progress bar. 
Default: True + + Yields: + SFTDatasetChunk containing: + - trajectories: Flattened list of trajectories (chunk_size * batch_size trajectories) + - config: SFTConfig with custom_lr_schedule containing learning rates for each batch + - step: Global step number across all epochs + - epoch: Current epoch number (0-indexed) + - epoch_step: Step number within current epoch (0-indexed) + + Example: + trajectories = [traj1, traj2, ..., traj100] + + # Create SFT dataset iterator with linear schedule + for chunk in create_sft_dataset_iterator( + trajectories=trajectories, + epochs=3, + batch_size=4, + chunk_size=10, + peak_lr=1e-4, + schedule_type="linear", + warmup_ratio=0.1, + ): + # chunk.trajectories is a flat list of 40 trajectories (10 batches * 4 per batch) + # chunk.config.custom_lr_schedule is a list of 10 learning rates (one per batch) + # chunk.config.batch_size is 4 + # chunk.step is global step number + # chunk.epoch is current epoch + # chunk.epoch_step is step within epoch + train_sft(chunk.trajectories, chunk.config) + + # Resume from chunk step 5 + for chunk in create_sft_dataset_iterator( + trajectories=trajectories, + epochs=3, + batch_size=4, + chunk_size=10, + initial_step=5, + ): + # Starts from chunk step 5 + pass + """ + from art.types import SFTConfig + + dataset_size = len(trajectories) + if dataset_size == 0: + return + + # Calculate total batch steps (one step per batch) + batches_per_epoch = math.ceil(dataset_size / batch_size) + total_batch_steps = batches_per_epoch * epochs + + # Calculate warmup steps + warmup_steps = int(total_batch_steps * warmup_ratio) + + # Create learning rate schedule (one LR per batch) + custom_lr_schedule = create_lr_schedule( + total_steps=total_batch_steps, + peak_lr=peak_lr, + method=schedule_type, + warmup_steps=warmup_steps, + min_lr=0.0, + ) + + # Calculate chunk iteration parameters + items_per_chunk = batch_size * chunk_size + chunks_per_epoch = math.ceil(dataset_size / items_per_chunk) + total_steps = chunks_per_epoch * epochs + + progress_bar = None + if use_tqdm: + progress_bar = tqdm( + initial=initial_step, + total=total_steps, + desc="Training SFT", + unit="chunk", + ) + + for epoch in range(epochs): + # Create indices and shuffle deterministically based on epoch + indices = list(range(dataset_size)) + random.seed(epoch) + random.shuffle(indices) + + for chunk_idx in range(chunks_per_epoch): + # Calculate step numbers + epoch_step = chunk_idx + global_step = epoch * chunks_per_epoch + chunk_idx + + # Skip if before initial_step + if global_step < initial_step: + continue + + # Get indices for this chunk + chunk_start = chunk_idx * items_per_chunk + chunk_end = min(chunk_start + items_per_chunk, dataset_size) + step_indices = indices[chunk_start:chunk_end] + + # Flatten trajectories for this chunk + chunk_trajectories: List["Trajectory"] = [ + trajectories[idx] for idx in step_indices + ] + + # Calculate learning rates for each batch in this chunk + chunk_lrs: List[float] = [] + num_batches_in_chunk = math.ceil(len(step_indices) / batch_size) + + for batch_idx in range(num_batches_in_chunk): + # Calculate global batch step + global_batch_step = epoch * batches_per_epoch + (chunk_start // batch_size) + batch_idx + chunk_lrs.append(custom_lr_schedule[global_batch_step]) + + # Create SFTConfig with custom learning rate schedule + config = SFTConfig( + batch_size=batch_size, + custom_lr_schedule=chunk_lrs, + ) + + yield SFTDatasetChunk( + trajectories=chunk_trajectories, + config=config, + step=global_step, + epoch=epoch, + 
epoch_step=epoch_step, + ) + + # Update progress bar after yielding + if progress_bar: + progress_bar.update(1) + + if progress_bar: + progress_bar.close() + +def iterate_file( + file_path: str, + epochs: int, + shuffle: bool = True, + shuffle_buffer_size: int = 10000, + seed: int | None = 42, +) -> Generator["Trajectory", None, None]: + """ + Read JSONL file for each epoch, yielding individual Trajectory objects. + + Completes reading the entire file for one epoch before starting the next epoch. + This ensures all trajectories from epoch N are yielded before any from epoch N+1. + + Each line should contain a dict with: + - messages: List of chat messages + - tools: Optional list of tools + - reward: Optional reward (defaults to 0.0) + - split: Optional split name (stored in metadata) + - Any other fields will be stored in metadata + + Args: + file_path: Path to JSONL file (one JSON object per line) + epochs: Number of times to read through the file + shuffle: Whether to shuffle trajectories. Defaults to True. + shuffle_buffer_size: Size of shuffle buffer for streaming shuffle. Default: 10000. + Only used if shuffle=True. + seed: Random seed for deterministic shuffling. Default: 42. + Only used if shuffle=True. + + Yields: + Individual Trajectory objects + + Raises: + ValueError: If file_path does not end with .jsonl + + Example: + # With shuffle + for trajectory in iterate_file("data.jsonl", epochs=3, shuffle=True): + # trajectory is a single Trajectory object + process(trajectory) + + # No shuffle + for trajectory in iterate_file("data.jsonl", epochs=3, shuffle=False): + process(trajectory) + """ + from art.trajectories import Trajectory + + if not file_path.endswith(".jsonl"): + raise ValueError(f"Only JSONL files are supported. Got: {file_path}") + + for epoch in range(epochs): + if shuffle and seed is not None: + random.seed(seed + epoch) + + if shuffle: + # Streaming shuffle with buffer + shuffle_buffer: List["Trajectory"] = [] + + with open(file_path, "r") as f: + for line in f: + if not line.strip(): + continue + + data = json.loads(line) + messages = data.get("messages", []) + tools = data.get("tools", None) + + traj = Trajectory( + messages_and_choices=messages, + tools=tools if tools else None, + reward=0.0 + ) + + shuffle_buffer.append(traj) + + # Once buffer is full, start yielding randomly + if len(shuffle_buffer) >= shuffle_buffer_size: + idx = random.randint(0, len(shuffle_buffer) - 1) + yield shuffle_buffer.pop(idx) + + # Flush remaining items in shuffle buffer at end of epoch + random.shuffle(shuffle_buffer) + for traj in shuffle_buffer: + yield traj + else: + # No shuffle - sequential reading + with open(file_path, "r") as f: + for line in f: + if not line.strip(): + continue + + data = json.loads(line) + messages = data.get("messages", []) + tools = data.get("tools", None) + + yield Trajectory( + messages_and_choices=messages, + tools=tools if tools else None, + reward=0.0 + ) + + +async def train_sft_from_file( + model: "TrainableModel", + file_path: str, + epochs: int, + learning_rate: float, + batch_size: int = 8, +) -> None: + """ + Convenience function to train a model with SFT from a JSONL file. + + Args: + model: TrainableModel to train + file_path: Path to JSONL file containing trajectories + epochs: Number of epochs to train + learning_rate: Peak learning rate (uses cosine schedule) + batch_size: Number of trajectories per batch/step. Defaults to 8. 
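+
+    Note:
+        Learning rates come from create_lr_schedule over
+        total_steps = ceil(num_trajectories * epochs / batch_size) steps,
+        with warmup_steps = min(total_steps // 10, 1000).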
+ + Example: + await train_sft_from_file( + model=model, + file_path="data.jsonl", + epochs=3, + learning_rate=1e-5, + ) + """ + from art.types import SFTConfig + + # Calculate total steps - batches carry over across epochs + num_trajectories = get_file_row_count(file_path) + total_steps = math.ceil((num_trajectories * epochs) / batch_size) + + # Set warmup steps: 10% of total steps, capped at 1000 + warmup_steps = min(total_steps // 10, 1000) + + # Create cosine learning rate schedule with warmup + custom_lr_schedule = create_lr_schedule( + total_steps=total_steps, + peak_lr=learning_rate, + method="linear", + warmup_steps=warmup_steps, + ) + + # Create SFT config with shuffling enabled + config = SFTConfig(custom_lr_schedule=custom_lr_schedule, batch_size=batch_size) + + # Train the model + await model.train_sft( + trajectories=iterate_file(file_path, epochs=epochs), + config=config + ) diff --git a/tests/unit/test_sft.py b/tests/unit/test_sft.py new file mode 100644 index 00000000..43e0c66c --- /dev/null +++ b/tests/unit/test_sft.py @@ -0,0 +1,182 @@ +"""Unit tests for SFT utilities.""" + +import json +import math +import tempfile +from pathlib import Path +from typing import Iterable, List + +import pytest + +from art.trajectories import Trajectory +from art.types import SFTConfig +from art.utils.iterate_dataset import iterate_file, iterate_trajectories +from art.utils.sft import create_lr_schedule + + +# Helper to create dummy trajectories +def create_dummy_trajectory(idx: int) -> Trajectory: + """Create a dummy trajectory with a unique identifier.""" + return Trajectory( + messages_and_choices=[ + {"role": "user", "content": f"Message {idx}"}, + {"role": "assistant", "content": f"Response {idx}"}, + ], + reward=float(idx), + ) + + +# Helper to create a temporary JSONL file +def create_temp_jsonl(num_trajectories: int) -> Path: + """Create a temporary JSONL file with dummy trajectories.""" + temp_file = tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) + for i in range(num_trajectories): + data = { + "messages": [ + {"role": "user", "content": f"Message {i}"}, + {"role": "assistant", "content": f"Response {i}"}, + ], + } + temp_file.write(json.dumps(data) + "\n") + temp_file.close() + return Path(temp_file.name) + + +# Dummy train_sft for integration testing +def dummy_train_sft( + trajectories: Iterable[List[Trajectory]], + config: SFTConfig, +) -> dict: + """ + Dummy train_sft function that collects batches and learning rates. 
+ + Args: + trajectories: Iterable of trajectory batches + config: SFT configuration with learning rates + + Returns: + dict with: + - num_batches: number of batches processed + - total_trajectories: total number of trajectories seen + - learning_rates_used: list of learning rates used + """ + num_batches = 0 + total_trajectories = 0 + + for batch in trajectories: + num_batches += 1 + total_trajectories += len(batch) + + return { + "num_batches": num_batches, + "total_trajectories": total_trajectories + } + + +# ============================================================================ +# Integration tests +# ============================================================================ + +def test_integration_iterate_trajectories_with_train_sft(): + """Test using iterate_trajectories chunks with train_sft.""" + trajectories = [create_dummy_trajectory(i) for i in range(20)] + + # batch_size=8, chunk_size=2 means each chunk has up to 2 batches of 8 trajectories + # With 20 trajectories per epoch: + # - Items per chunk: 8 * 2 = 16 + # - Chunks per epoch: ceil(20/16) = 2 (one with 16 trajs, one with 4 trajs) + # With 3 epochs: 2 * 3 = 6 chunks total + + # Create LR schedule for up to 2 batches per chunk + lrs_per_chunk = create_lr_schedule(2, peak_lr=1e-4, method="linear") + + # Manually iterate over chunks and train on each + results = [] + for chunk in iterate_trajectories( + trajectories, + epochs=3, + batch_size=8, # 8 trajectories per batch + chunk_size=2, # 2 batches per chunk + ): + print(f"Chunk: {chunk}") + # chunk is List[List[Trajectory]] which is an Iterable[List[Trajectory]] + result = dummy_train_sft( + trajectories=chunk, + config=SFTConfig(learning_rate=lrs_per_chunk), + ) + results.append(result) + + # Should have 6 chunks total (2 per epoch * 3 epochs) + assert len(results) == 6 + # Pattern repeats for each epoch: full chunk (2 batches), partial chunk (1 batch) + assert results[0]["num_batches"] == 2 # Epoch 1, chunk 1 + assert results[0]["total_trajectories"] == 16 + assert results[1]["num_batches"] == 1 # Epoch 1, chunk 2 (partial) + assert results[1]["total_trajectories"] == 4 + assert results[2]["num_batches"] == 2 # Epoch 2, chunk 1 + assert results[2]["total_trajectories"] == 16 + assert results[3]["num_batches"] == 1 # Epoch 2, chunk 2 (partial) + assert results[3]["total_trajectories"] == 4 + assert results[4]["num_batches"] == 2 # Epoch 3, chunk 1 + assert results[4]["total_trajectories"] == 16 + assert results[5]["num_batches"] == 1 # Epoch 3, chunk 2 (partial) + assert results[5]["total_trajectories"] == 4 + +def test_integration_iterate_file_with_train_sft(): + """Test using iterate_file directly with train_sft.""" + jsonl_file = create_temp_jsonl(100) + + try: + # Create learning rate schedule + total_steps = math.ceil((100 * 2) / 3) # 10 trajectories, 2 epochs, batch_size=3 + lrs = create_lr_schedule(total_steps, peak_lr=1e-4, method="constant") + + config = SFTConfig(learning_rate=lrs) + + # Pass iterate_file directly to train_sft + result = dummy_train_sft( + trajectories=iterate_file( + str(jsonl_file), + epochs=2, + batch_size=3, + shuffle=True, + ), + config=config, + ) + + # Should process 7 batches: [3, 3, 3, 3, 3, 3, 2] + assert result["num_batches"] == 67 + assert result["total_trajectories"] == 200 + finally: + jsonl_file.unlink() + +# def test_total_steps_calculation(): +# """Test that total steps calculation matches actual batches.""" +# num_trajectories = 105 +# epochs = 3 +# batch_size = 8 + +# # This is how train_sft_from_file calculates 
total_steps +# expected_total_steps = math.ceil((num_trajectories * epochs) / batch_size) + +# # Create file and count actual batches +# jsonl_file = create_temp_jsonl(num_trajectories) + +# try: +# batches = list(iterate_file( +# str(jsonl_file), +# epochs=epochs, +# batch_size=batch_size, +# shuffle=False, +# )) + +# actual_batches = len(batches) + +# # Should match +# assert actual_batches == expected_total_steps +# finally: +# jsonl_file.unlink() + + +if __name__ == "__main__": + pytest.main([__file__, "-v"])
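For reference, a minimal end-to-end sketch of the SFT surface added in this diff (reviewer illustration, not part of the change). It assumes `model` is an already-constructed art.TrainableModel registered with a backend; note that every backend's _train_sft currently raises NotImplementedError, so the final call will fail until one of them is implemented.

    from art.trajectories import Trajectory
    from art.utils.sft import create_sft_dataset_iterator


    def make_trajectories(n: int) -> list[Trajectory]:
        # Tiny in-memory dataset; real callers could stream JSONL via iterate_file().
        return [
            Trajectory(
                messages_and_choices=[
                    {"role": "user", "content": f"Question {i}"},
                    {"role": "assistant", "content": f"Answer {i}"},
                ],
                reward=0.0,
            )
            for i in range(n)
        ]


    async def run_sft(model) -> None:
        # Each chunk bundles chunk_size batches of batch_size trajectories plus an
        # SFTConfig whose custom_lr_schedule holds one learning rate per batch.
        for chunk in create_sft_dataset_iterator(
            trajectories=make_trajectories(64),
            epochs=2,
            batch_size=4,
            chunk_size=4,
            peak_lr=1e-4,
            schedule_type="linear",
            warmup_ratio=0.1,
        ):
            await model.train_sft(chunk.trajectories, chunk.config)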