Sft #458

21 changes: 18 additions & 3 deletions src/art/backend.py
@@ -1,5 +1,5 @@
import json
from typing import TYPE_CHECKING, AsyncIterator, Literal
from typing import TYPE_CHECKING, AsyncIterator, Iterable, Literal

import httpx
from tqdm import auto as tqdm
@@ -8,8 +8,8 @@
from art.utils.deploy_model import LoRADeploymentJob, LoRADeploymentProvider

from . import dev
from .trajectories import TrajectoryGroup
from .types import TrainConfig
from .trajectories import Trajectory, TrajectoryGroup
from .types import SFTConfig, TrainConfig

if TYPE_CHECKING:
from .model import Model, TrainableModel
@@ -126,6 +126,21 @@ async def _train_model(
if pbar is not None:
pbar.close()

async def _train_sft(
self,
model: "TrainableModel",
trajectories: Iterable[Trajectory],
config: SFTConfig,
dev_config: dev.SFTConfig,
verbose: bool = False,
) -> AsyncIterator[dict[str, float]]:
raise NotImplementedError(
"SFT training is not yet implemented. "
"This method will be available in a future release."
)
# This yield is unreachable but makes this an async generator
yield # type: ignore

# ------------------------------------------------------------------
# Experimental support for S3
# ------------------------------------------------------------------
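Aside (not part of the diff): the stub above relies on the fact that any `yield` in the body, even an unreachable one, turns the function into an async generator, so callers can still write `async for` and receive the `NotImplementedError` on the first iteration. A minimal standalone sketch of the same pattern, with hypothetical names:

```python
import asyncio
from typing import AsyncIterator


async def stub() -> AsyncIterator[dict[str, float]]:
    # Without the (unreachable) yield below this would be a plain coroutine
    # function, and `async for` over its result would fail with a TypeError
    # instead of surfacing the NotImplementedError.
    raise NotImplementedError("SFT training is not yet implemented.")
    yield  # type: ignore  # makes this an async generator function


async def main() -> None:
    try:
        async for _ in stub():
            pass
    except NotImplementedError as exc:
        print(f"caught: {exc}")


asyncio.run(main())
```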
3 changes: 2 additions & 1 deletion src/art/dev/__init__.py
@@ -7,7 +7,7 @@
)
from .openai_server import OpenAIServerConfig, ServerArgs, get_openai_server_config
from .torchtune import TorchtuneArgs
from .train import TrainConfig
from .train import SFTConfig, TrainConfig

__all__ = [
"EngineArgs",
@@ -18,6 +18,7 @@
"get_openai_server_config",
"OpenAIServerConfig",
"ServerArgs",
"SFTConfig",
"TorchtuneArgs",
"TrainConfig",
]
5 changes: 5 additions & 0 deletions src/art/dev/train.py
@@ -22,3 +22,8 @@ class TrainConfig(TypedDict, total=False):
scale_learning_rate_by_reward_std_dev: bool
scale_rewards: bool
truncated_importance_sampling: float | None


class SFTConfig(TypedDict, total=False):
"""Experimental SFT configuration options. Use at your own risk."""
pass
19 changes: 17 additions & 2 deletions src/art/local/backend.py
@@ -5,7 +5,7 @@
import subprocess
from datetime import datetime
from types import TracebackType
from typing import AsyncIterator, Literal, cast
from typing import AsyncIterator, Iterable, Literal, cast

import aiohttp
import numpy as np
@@ -54,7 +54,7 @@
)
from ..preprocessing.tokenize import tokenize_trajectory_groups
from ..trajectories import Trajectory, TrajectoryGroup
from ..types import Message, TrainConfig
from ..types import Message, SFTConfig, TrainConfig
from ..utils import format_message, get_model_step
from .checkpoints import (
delete_checkpoints,
@@ -521,6 +521,21 @@ async def _train_model(
if verbose:
print("_train_model complete")

async def _train_sft(
self,
model: TrainableModel,
trajectories: Iterable[Trajectory],
config: SFTConfig,
dev_config: dev.SFTConfig,
verbose: bool = False,
) -> AsyncIterator[dict[str, float]]:
raise NotImplementedError(
"SFT training is not yet implemented for LocalBackend. "
"Please use the Backend HTTP API or implement this method."
)
# This yield is unreachable but makes this an async generator
yield # type: ignore

def _get_reward_std_dev_learning_rate_multiplier(
self, model: TrainableModel
) -> float:
24 changes: 23 additions & 1 deletion src/art/model.py
@@ -7,7 +7,7 @@

from . import dev
from .trajectories import Trajectory, TrajectoryGroup
from .types import TrainConfig
from .types import SFTConfig, TrainConfig

if TYPE_CHECKING:
from art.backend import Backend
@@ -386,3 +386,25 @@ async def train(
self, list(trajectory_groups), config, _config or {}, verbose
):
pass

async def train_sft(
self,
trajectories: Iterable[Trajectory],
config: SFTConfig,
_config: dev.SFTConfig | None = None,
verbose: bool = False,
) -> None:
"""
Supervised fine-tune the model with an iterable of trajectories.

Args:
trajectories: An iterable of Trajectory objects.
config: SFT configuration including learning_rate and batch_size.
_config: Additional experimental configuration that is subject to change and
not yet part of the public API. Use at your own risk.
verbose: Whether to print verbose output.
"""
async for _ in self.backend()._train_sft(
self, trajectories, config, _config or {}, verbose
):
pass
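A minimal usage sketch of the new `train_sft` API (not part of the diff). It assumes a `TrainableModel` that is already registered with a backend; the `Trajectory` field names and the requirement to pass `reward` are assumptions based on the imports shown above:

```python
from art.trajectories import Trajectory
from art.types import SFTConfig


async def run_sft(model) -> None:
    """Fine-tune `model` (an already registered TrainableModel) on a toy dataset."""
    trajectories = [
        Trajectory(
            messages_and_choices=[
                {"role": "user", "content": "What is 2 + 2?"},
                {"role": "assistant", "content": "4"},
            ],
            reward=0.0,  # assumed to be required by Trajectory; unused for SFT
        )
    ]
    # Defaults from SFTConfig in src/art/types.py: learning_rate=5e-5, batch_size="auto".
    await model.train_sft(trajectories, SFTConfig(learning_rate=1e-5))
```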
168 changes: 168 additions & 0 deletions src/art/preprocessing/tokenize_sft.py
@@ -0,0 +1,168 @@
"""Tokenization utilities for Supervised Fine-Tuning (SFT)."""

import math
from dataclasses import dataclass
from typing import Generator

import torch
from transformers.tokenization_utils_base import PreTrainedTokenizerBase

from ..trajectories import Trajectory


@dataclass
class SFTBatch:
"""A batch of tokenized trajectories for supervised fine-tuning.

Attributes:
trajectory_tensors: List of tensor dictionaries, one per trajectory.
Each dict contains 'input_ids', 'attention_mask', and 'labels'.
learning_rate: Learning rate to use for this batch.
num_trajectories: Number of trajectories in this batch.
num_trainable_tokens: Total number of tokens being trained on (labels != -100).
"""
trajectory_tensors: list[dict[str, torch.Tensor]]
learning_rate: float
num_trajectories: int
num_trainable_tokens: int


def tokenize_sft_batches(
trajectories: list[Trajectory],
batch_size: int,
learning_rates: list[float],
tokenizer: PreTrainedTokenizerBase,
instruction_part: str,
response_part: str,
) -> Generator[SFTBatch, None, None]:
"""
Tokenize trajectories into batches for supervised fine-tuning.

Args:
trajectories: Flat list of trajectories
batch_size: Number of trajectories per batch
learning_rates: Learning rate for each batch
tokenizer: Tokenizer to use for encoding
instruction_part: Instruction template part (e.g., "User:")
response_part: Response template part (e.g., "Assistant:")

Yields:
SFTBatch object containing:
- trajectory_tensors: List of tensors for each trajectory
- learning_rate: Learning rate for this batch
- num_trajectories: Number of trajectories in this batch
- num_trainable_tokens: Total number of trainable tokens
"""
# Validate inputs
num_trajectories = len(trajectories)
num_learning_rates = len(learning_rates)
expected_num_batches = math.ceil(num_trajectories / batch_size)

if num_learning_rates != expected_num_batches:
raise ValueError(
f"Mismatch between trajectories and learning_rates: "
f"{num_trajectories} trajectories with batch_size={batch_size} "
f"yields {expected_num_batches} batches, but got {num_learning_rates} learning_rates"
)

instruction_ids = tokenizer(instruction_part, add_special_tokens=False).input_ids
response_ids = tokenizer(response_part, add_special_tokens=False).input_ids
instruction_length = len(instruction_ids)
response_length = len(response_ids)
max_template_length = max(instruction_length, response_length)

def _train_on_responses_only(input_ids: list[int]) -> list[int]:
labels = [-100] * len(input_ids)
m = len(input_ids) - max_template_length
first_response = response_ids[0]
first_instruction = instruction_ids[0]
j = 0

while j < m:
if input_ids[j] == first_response:
if input_ids[j : j + response_length] == response_ids:
j = j + response_length
start = j
while j < m:
if input_ids[j] == first_instruction and input_ids[j : j + instruction_length] == instruction_ids:
j = j + instruction_length
labels[start : j] = input_ids[start : j]
break
elif j == (m - 1):
j = m
labels[start:] = input_ids[start:]
break
j += 1
j += 1

return labels

# Batch trajectories
for batch_idx, lr in enumerate(learning_rates):
start_idx = batch_idx * batch_size
end_idx = start_idx + batch_size
trajectory_batch = trajectories[start_idx:end_idx]

# First pass: tokenize all trajectories
tokenized_trajectories = []
for trajectory in trajectory_batch:
messages = trajectory.messages_and_choices
tools = trajectory.tools

# Single-step tokenization: apply_chat_template with tokenize=True
input_ids = tokenizer.apply_chat_template(
messages,
tools=tools,
tokenize=True,
add_generation_prompt=False
)

# Create attention mask (all 1s - no padding yet)
attention_mask = [1] * len(input_ids)

labels = _train_on_responses_only(input_ids)

tokenized_trajectories.append({
'input_ids': input_ids,
'attention_mask': attention_mask,
'labels': labels,
})

# Find max length in this batch for padding
max_seq_length = max(len(t['input_ids']) for t in tokenized_trajectories)

# Second pass: pad all trajectories to max_seq_length
trajectory_tensors = []
for tokenized in tokenized_trajectories:
input_ids = tokenized['input_ids']
attention_mask = tokenized['attention_mask']
labels = tokenized['labels']

# Pad to max_seq_length
padding_length = max_seq_length - len(input_ids)
if padding_length > 0:
input_ids = input_ids + [tokenizer.pad_token_id] * padding_length
attention_mask = attention_mask + [0] * padding_length
labels = labels + [-100] * padding_length

trajectory_tensor = {
'input_ids': torch.tensor([input_ids], dtype=torch.long),
'attention_mask': torch.tensor([attention_mask], dtype=torch.long),
'labels': torch.tensor([labels], dtype=torch.long),
}

trajectory_tensors.append(trajectory_tensor)

# Calculate total trainable tokens (labels != -100)
num_trainable_tokens = sum(
(tensor_dict['labels'] != -100).sum().item()
for tensor_dict in trajectory_tensors
)

yield SFTBatch(
trajectory_tensors=trajectory_tensors,
learning_rate=lr,
num_trajectories=len(trajectory_tensors),
num_trainable_tokens=num_trainable_tokens,
)
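A hedged usage sketch for `tokenize_sft_batches` (not part of the diff): four toy trajectories with `batch_size=2` yield two batches, so exactly two learning rates must be supplied. The chat-template markers shown for Qwen are assumptions, as is the `reward` argument on `Trajectory`:

```python
from transformers import AutoTokenizer

from art.preprocessing.tokenize_sft import tokenize_sft_batches
from art.trajectories import Trajectory

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")

trajectories = [
    Trajectory(
        messages_and_choices=[
            {"role": "user", "content": f"What is {i} + {i}?"},
            {"role": "assistant", "content": str(i + i)},
        ],
        reward=0.0,  # assumed to be required by Trajectory; unused here
    )
    for i in range(4)
]

# 4 trajectories with batch_size=2 -> 2 batches, so 2 learning rates are required.
for batch in tokenize_sft_batches(
    trajectories=trajectories,
    batch_size=2,
    learning_rates=[5e-5, 4e-5],
    tokenizer=tokenizer,
    instruction_part="<|im_start|>user\n",    # assumed marker for this template
    response_part="<|im_start|>assistant\n",  # assumed marker for this template
):
    print(batch.learning_rate, batch.num_trajectories, batch.num_trainable_tokens)
```

Only tokens whose labels are not -100 (the assistant responses selected by `_train_on_responses_only`) count toward `num_trainable_tokens`.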

21 changes: 18 additions & 3 deletions src/art/serverless/backend.py
@@ -1,5 +1,5 @@
import asyncio
from typing import TYPE_CHECKING, AsyncIterator, Literal
from typing import TYPE_CHECKING, AsyncIterator, Iterable, Literal

from openai._types import NOT_GIVEN
from tqdm import auto as tqdm
@@ -9,8 +9,8 @@

from .. import dev
from ..backend import Backend
from ..trajectories import TrajectoryGroup
from ..types import TrainConfig
from ..trajectories import Trajectory, TrajectoryGroup
from ..types import SFTConfig, TrainConfig

if TYPE_CHECKING:
from ..model import Model, TrainableModel
@@ -159,6 +159,21 @@ async def _train_model(
raise RuntimeError(f"Training job failed: {error_message}")
after = event.id

async def _train_sft(
self,
model: "TrainableModel",
trajectories: Iterable[Trajectory],
config: SFTConfig,
dev_config: dev.SFTConfig,
verbose: bool = False,
) -> AsyncIterator[dict[str, float]]:
raise NotImplementedError(
"SFT training is not yet implemented for ServerlessBackend. "
"Please use the Backend HTTP API or implement this method."
)
# This yield is unreachable but makes this an async generator
yield # type: ignore

# ------------------------------------------------------------------
# Experimental support for S3
# ------------------------------------------------------------------
8 changes: 7 additions & 1 deletion src/art/types.py
@@ -1,4 +1,4 @@
from typing import Literal
from typing import Iterable, Literal

import pydantic
from openai.types.chat.chat_completion import Choice
Expand All @@ -17,4 +17,10 @@ class TrainConfig(pydantic.BaseModel):
beta: float = 0.0


class SFTConfig(pydantic.BaseModel):
learning_rate: float = 5e-5
batch_size: int | Literal["auto"] = "auto"
custom_lr_schedule: list[float] = []


Verbosity = Literal[0, 1, 2]
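An illustrative sketch (not from the diff) of constructing the new `SFTConfig`, using the defaults defined above; how `custom_lr_schedule` interacts with the scalar `learning_rate` is not specified in this diff and is an assumption:

```python
from art.types import SFTConfig

# Defaults: learning_rate=5e-5, batch_size="auto", custom_lr_schedule=[]
default_cfg = SFTConfig()

scheduled_cfg = SFTConfig(
    learning_rate=1e-5,
    batch_size=8,
    # Presumably one learning rate per batch when non-empty (assumption).
    custom_lr_schedule=[5e-5, 4e-5, 3e-5],
)
```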