
Commit a1ef501

[VLM] subclass HFTokenizer to add more special tokens

1 parent 96149f6 · commit a1ef501

11 files changed: +80 −121 lines

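The new `torchtitan/experiments/vlm/tokenizer.py` module itself is not shown on this page, but the rest of the diff pins down its surface: a `VLMTokenizer` subclass exposing `img_token`/`boi_token`/`eoi_token` strings and `img_id`/`boi_id`/`eoi_id`/`pad_id` ids, plus a `build_vlm_tokenizer` factory. A minimal sketch of that shape, assuming the base class wraps a `tokenizers.Tokenizer` as `self.tokenizer` and that the factory reads the asset path from the job config; every detail below is inferred from attribute accesses elsewhere in this commit, not the actual file:

```python
# Hedged reconstruction of the unseen vlm/tokenizer.py.
from torchtitan.components.tokenizer import HuggingFaceTokenizer
from torchtitan.config import JobConfig


class VLMTokenizer(HuggingFaceTokenizer):
    """HF tokenizer subclass that registers VLM special tokens at load time,
    instead of baking them into the tokenizer asset files."""

    def __init__(self, tokenizer_path: str):
        super().__init__(tokenizer_path)
        self.img_token = "<|image|>"
        self.boi_token = "<|begin_of_image|>"
        self.eoi_token = "<|end_of_image|>"
        # _process_special_token can now be called with just the string
        # (see the signature change in torchtitan/components/tokenizer.py).
        self.tokenizer.add_special_tokens(
            [
                self._process_special_token(tok)
                for tok in (self.img_token, self.boi_token, self.eoi_token)
            ]
        )
        self.img_id = self.tokenizer.token_to_id(self.img_token)
        self.boi_id = self.tokenizer.token_to_id(self.boi_token)
        self.eoi_id = self.tokenizer.token_to_id(self.eoi_token)
        # pad_id is assumed to come from the base class via the config's
        # pad_token, which this commit leaves in tokenizer_config.json.


def build_vlm_tokenizer(job_config: JobConfig) -> VLMTokenizer:
    # Config field name assumed; mirrors how build_hf_tokenizer is wired.
    return VLMTokenizer(job_config.model.hf_assets_path)
```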

tests/assets/tokenizer/tokenizer.json
Lines changed: 1 addition & 5 deletions

@@ -2029,11 +2029,7 @@
       "land": 1994,
       "?\n": 1995,
       " respect": 1996,
-      "ances": 1997,
-      "<|image|>": 1998,
-      "<|begin_of_image|>": 1999,
-      "<|end_of_image|>": 2000,
-      "<|pad|>": 2001
+      "ances": 1997
     },
     "merges": [
     ]

tests/assets/tokenizer/tokenizer_config.json
Lines changed: 0 additions & 35 deletions

@@ -15,46 +15,11 @@
       "rstrip": false,
       "single_word": false,
       "special": true
-    },
-    "1998": {
-      "content": "<|image|>",
-      "lstrip": false,
-      "normalized": false,
-      "rstrip": false,
-      "single_word": false,
-      "special": true
-    },
-    "1999": {
-      "content": "<|begin_of_image|>",
-      "lstrip": false,
-      "normalized": false,
-      "rstrip": false,
-      "single_word": false,
-      "special": true
-    },
-    "2000": {
-      "content": "<|end_of_image|>",
-      "lstrip": false,
-      "normalized": false,
-      "rstrip": false,
-      "single_word": false,
-      "special": true
-    },
-    "2001": {
-      "content": "<|pad|>",
-      "lstrip": false,
-      "normalized": false,
-      "rstrip": false,
-      "single_word": false,
-      "special": true
     }
   },
   "bos_token": "<|begin_of_text|>",
   "clean_up_tokenization_spaces": true,
   "eos_token": "<|end_of_text|>",
-  "img_token": "<|image|>",
-  "boi_token": "<|begin_of_image|>",
-  "eoi_token": "<|end_of_image|>",
   "pad_token": "<|pad|>",
   "model_input_names": [
     "input_ids",

torchtitan/components/tokenizer.py
Lines changed: 4 additions & 1 deletion

@@ -190,7 +190,10 @@ def _get_token_from_config(self, config: dict[str, Any], key: str) -> Optional[s
         return token
 
     def _process_special_token(
-        self, token_str: str, token_config: dict, token_id: Optional[int] = None
+        self,
+        token_str: str,
+        token_config: dict | None = None,
+        token_id: int | None = None,
     ) -> AddedToken:
         """
         Process a special token and update BOS/EOS attributes if applicable.
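With `token_config` and `token_id` both optional, a caller that only knows the token string can mint an `AddedToken` directly, which is what the VLM subclass needs for specials that have no entry in `tokenizer_config.json`. A hedged before/after sketch of the call shape (the config dict contents are illustrative):

```python
from tokenizers import AddedToken

# Before this commit: a token_config dict was required (positional).
token: AddedToken = tokenizer._process_special_token(
    "<|image|>", {"lstrip": False, "rstrip": False, "special": True}
)

# After: config and id default to None, so a bare string suffices.
token = tokenizer._process_special_token("<|image|>")
```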

torchtitan/experiments/vlm/__init__.py
Lines changed: 4 additions & 4 deletions

@@ -4,13 +4,13 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-from dataclasses import asdict, replace
+from dataclasses import asdict
 
 from torchtitan.components.loss import build_cross_entropy_loss
 from torchtitan.components.lr_scheduler import build_lr_schedulers
 from torchtitan.components.optimizer import build_optimizers
-from torchtitan.components.tokenizer import build_hf_tokenizer
 from torchtitan.components.validate import build_validator
+from torchtitan.experiments.vlm.tokenizer import build_vlm_tokenizer
 from torchtitan.models.llama3 import llama3_configs
 from torchtitan.protocols.train_spec import TrainSpec
 
@@ -29,7 +29,7 @@
 
 llama3_siglip2_configs = {
     "debugmodel": Llama3Siglip2ModelArgs(
-        **asdict(replace(llama3_configs["debugmodel"], vocab_size=2048)),
+        **asdict(llama3_configs["debugmodel"]),
         encoder=Siglip2ModelArgs(
             dim=128,
             ffn_dim=256,
@@ -50,7 +50,7 @@ def get_train_spec() -> TrainSpec:
         build_optimizers_fn=build_optimizers,
         build_lr_schedulers_fn=build_lr_schedulers,
         build_dataloader_fn=build_mm_dataloader,
-        build_tokenizer_fn=build_hf_tokenizer,
+        build_tokenizer_fn=build_vlm_tokenizer,
         build_loss_fn=build_cross_entropy_loss,
         build_validator_fn=build_validator,
     )
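Two knock-on effects of this wiring: the debug config no longer needs `replace(..., vocab_size=2048)`, which was presumably there to reserve room for the asset-defined special ids up to 2001, since the subclass now appends the specials at load time; and `build_tokenizer_fn` hands the train spec a `VLMTokenizer`, so downstream components (dataloader, collator) can read the special tokens straight off the tokenizer instead of a separate `SpecialTokens` object.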

torchtitan/experiments/vlm/datasets/mm_collator_nld.py
Lines changed: 11 additions & 10 deletions

@@ -12,15 +12,16 @@
 
 from torchtitan.tools.logging import logger
 
-from ..model.args import SpecialTokens
-
+from ..tokenizer import VLMTokenizer
 from .utils.image import (
     convert_to_patches,
     pad_empty_images_to_target_batch_size,
     pad_patches,
 )
 from .utils.text import pad_input_ids_and_labels_to_target_batch_size, pad_text_batch
 
+IGNORE_INDEX = -100
+
 
 @dataclass
 class MultiModalCollatorNLD:
@@ -85,7 +86,7 @@ class MultiModalCollatorNLD:
     max_images_per_batch: int  # Vision Encoder's batch size
     max_patches_per_image: int  # Vision Encoder's sequence length
 
-    special_tokens: SpecialTokens
+    tokenizer: VLMTokenizer
 
     def collate_images(
         self, all_images: list[torch.Tensor]
@@ -145,28 +146,28 @@ def collate_text(
         input_ids = pad_sequence(
             [s["input_ids"] for s in batch],
             batch_first=True,
-            padding_value=self.special_tokens.pad_id,
+            padding_value=self.tokenizer.pad_id,
         )
         labels = pad_sequence(
             [s["labels"] for s in batch],
             batch_first=True,
-            padding_value=self.special_tokens.pad_id,
+            padding_value=self.tokenizer.pad_id,
         )
 
         # Handle sequence length
         input_ids, labels = pad_text_batch(
             input_ids,
             labels,
             self.seq_len + 1,  # Extra token for label shifting
-            padding_idx=self.special_tokens.pad_id,
-            ignore_idx=self.special_tokens.ignore_id,
+            padding_idx=self.tokenizer.pad_id,
+            ignore_idx=IGNORE_INDEX,
         )
         input_ids, labels = pad_input_ids_and_labels_to_target_batch_size(
             input_ids,
             labels,
             self.batch_size,
-            padding_idx=self.special_tokens.pad_id,
-            ignore_idx=self.special_tokens.ignore_id,
+            padding_idx=self.tokenizer.pad_id,
+            ignore_idx=IGNORE_INDEX,
        )
 
         return input_ids[:, :-1], labels[:, 1:]  # Shift for next token prediction
@@ -221,7 +222,7 @@ def __call__(
             "input": input_ids,
             "pixel_values": patches,
             "grid_thw": grids,
-            "special_tokens": self.special_tokens,
+            "img_id": self.tokenizer.img_id,
         }
 
         return input_dict, labels
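The collator now hardcodes `IGNORE_INDEX = -100` rather than reading `special_tokens.ignore_id`. The value is not arbitrary: it matches `F.cross_entropy`'s default `ignore_index`, so any label position set to it contributes nothing to the loss. A quick self-contained check:

```python
import torch
import torch.nn.functional as F

logits = torch.randn(4, 10)                # (num_tokens, vocab_size)
labels = torch.tensor([3, -100, 7, -100])  # padded positions masked out

# -100 positions are ignored by default, so the mean is taken over the
# two real targets only.
loss = F.cross_entropy(logits, labels)
ref = F.cross_entropy(logits[[0, 2]], labels[[0, 2]])
assert torch.allclose(loss, ref)
```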

torchtitan/experiments/vlm/datasets/mm_datasets.py
Lines changed: 14 additions & 22 deletions

@@ -20,26 +20,27 @@
 from torch.utils.data import IterableDataset
 
 from torchtitan.components.dataloader import ParallelAwareDataloader
-from torchtitan.components.tokenizer import BaseTokenizer, HuggingFaceTokenizer
 from torchtitan.config import JobConfig
 from torchtitan.datasets import DatasetConfig
 from torchtitan.tools.logging import logger
 
-from ..model.args import SpecialTokens
+from ..tokenizer import VLMTokenizer as Tokenizer
 from .mm_collator_nld import MultiModalCollatorNLD
 from .utils.image import calculate_image_tokens, process_image
 from .utils.packing import SamplePacker
 from .utils.text import process_text_with_images
 
 
+IGNORE_INDEX = -100  # Pytorch's default for F.cross_entropy
+
+
 def _process_mm_sample(
     texts: list[str] | str,
     images: list[bytes] | bytes,
-    tokenizer: BaseTokenizer,
+    tokenizer: Tokenizer,
     patch_size: int,
     max_patch_per_image: int,
     spatial_merge_size: int,
-    special_tokens: SpecialTokens,
 ) -> dict[str, Any] | None:
     """Common processing logic for multimodal samples.
 
@@ -98,7 +99,7 @@ def _process_mm_sample(
             processed_images.append(processed_img)
             image_dimensions.append((num_tokens, width, height))
             # Replace None with image token
-            texts_list[idx] = special_tokens.img_token
+            texts_list[idx] = tokenizer.img_token
         else:
             # Replace None with empty string if processing failed
             texts_list[idx] = ""
@@ -109,7 +110,7 @@ def _process_mm_sample(
 
     # Process all image tokens at once
     processed_text = process_text_with_images(
-        texts_list, image_dimensions, tokenizer, special_tokens, add_eos=True
+        texts_list, image_dimensions, tokenizer, add_eos=True
     )
 
     tokens = tokenizer.encode(processed_text)
@@ -120,10 +121,10 @@ def _process_mm_sample(
 
     # Mask special tokens in labels
     special_token_ids = torch.tensor(
-        [special_tokens.boi_id, special_tokens.eoi_id, special_tokens.img_id]
+        [tokenizer.boi_id, tokenizer.eoi_id, tokenizer.img_id]
     )
     labels = torch.where(
-        torch.isin(labels, special_token_ids), special_tokens.ignore_id, labels
+        torch.isin(labels, special_token_ids), IGNORE_INDEX, labels
     )
 
     return {
@@ -139,11 +140,10 @@ def _process_mm_sample(
 
 def _process_obelics_sample(
     sample: dict[str, Any],
-    tokenizer: HuggingFaceTokenizer,
+    tokenizer: Tokenizer,
     patch_size: int,
     spatial_merge_size: int,
     max_patch_per_image: int,
-    special_tokens: SpecialTokens,
 ) -> dict[str, Any] | None:
     """Process a sample from the OBELICS dataset."""
     return _process_mm_sample(
@@ -153,17 +153,15 @@ def _process_obelics_sample(
         patch_size=patch_size,
         spatial_merge_size=spatial_merge_size,
         max_patch_per_image=max_patch_per_image,
-        special_tokens=special_tokens,
     )
 
 
 def _process_cc12_wd_sample(
     sample: dict[str, Any],
-    tokenizer: BaseTokenizer,
+    tokenizer: Tokenizer,
     patch_size: int,
     spatial_merge_size: int,
     max_patch_per_image: int,
-    special_tokens: SpecialTokens,
 ) -> dict[str, Any] | None:
     """Process a sample from the CC12-WD dataset.
     Transforms CC12-WD format to match Interleaved format:
@@ -184,7 +182,6 @@ def _process_cc12_wd_sample(
         patch_size=patch_size,
         spatial_merge_size=spatial_merge_size,
         max_patch_per_image=max_patch_per_image,
-        special_tokens=special_tokens,
     )
 
 
@@ -225,15 +222,14 @@ def __init__(
         self,
         dataset_name: str,
         dataset_path: str | None,
-        tokenizer: BaseTokenizer,
+        tokenizer: Tokenizer,
         batch_size: int,
         seq_len: int,
         patch_size: int,
         spatial_merge_size: int,
         max_patches_per_image: int,
         max_images_per_batch: int,
         packing_buffer_size: int,
-        special_tokens: SpecialTokens,
         dp_rank: int = 0,
         dp_world_size: int = 1,
         infinite: bool = False,
@@ -254,7 +250,6 @@ def __init__(
         self.spatial_merge_size = spatial_merge_size
         self.max_patches_per_image = max_patches_per_image
         self.max_images_per_batch = max_images_per_batch
-        self.special_tokens = special_tokens
        self.enable_packing = packing_buffer_size > 0
        if self.enable_packing:
            self.packer = SamplePacker(
@@ -277,7 +272,6 @@ def __iter__(self):
                     patch_size=self.patch_size,
                     spatial_merge_size=self.spatial_merge_size,
                     max_patch_per_image=self.max_patches_per_image,
-                    special_tokens=self.special_tokens,
                 )
                 if processed is None:
                     continue
@@ -366,7 +360,7 @@ def state_dict(self):
 def build_mm_dataloader(
     dp_world_size: int,
     dp_rank: int,
-    tokenizer: HuggingFaceTokenizer,
+    tokenizer: Tokenizer,
     job_config: JobConfig,
     infinite: bool = True,
 ) -> ParallelAwareDataloader:
@@ -393,7 +387,6 @@ def build_mm_dataloader(
     patch_size = job_config.data.patch_size
     spatial_merge_size = job_config.data.spatial_merge_size
     packing_buffer_size = job_config.data.packing_buffer_size
-    special_tokens = SpecialTokens.from_tokenizer(tokenizer)
 
     dataset = MultiModalDataset(
         dataset_name=job_config.training.dataset,
@@ -406,7 +399,6 @@ def build_mm_dataloader(
         max_patches_per_image=max_patches_per_image,
         max_images_per_batch=max_images_per_batch,
         packing_buffer_size=packing_buffer_size,
-        special_tokens=special_tokens,
         dp_rank=dp_rank,
         dp_world_size=dp_world_size,
         infinite=infinite,
@@ -418,7 +410,7 @@ def build_mm_dataloader(
         patch_size=patch_size,
         max_images_per_batch=max_images_per_batch,
         max_patches_per_image=max_patches_per_image,
-        special_tokens=special_tokens,
+        tokenizer=tokenizer,
     )
 
     base_dataloader = ParallelAwareDataloader(
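Worth isolating from the hunks above: image-related specials stay in `input_ids` (the model must see them to align patch embeddings with text) but are replaced by `IGNORE_INDEX` in `labels`, so the model is never trained to emit them. A small standalone demonstration; the ids here are the ones from the old test asset and are otherwise arbitrary:

```python
import torch

IGNORE_INDEX = -100  # Pytorch's default for F.cross_entropy

# img=1998, boi=1999, eoi=2000 (illustrative ids only).
special_token_ids = torch.tensor([1999, 2000, 1998])
labels = torch.tensor([42, 1999, 1998, 1998, 2000, 17])

labels = torch.where(torch.isin(labels, special_token_ids), IGNORE_INDEX, labels)
# labels -> tensor([  42, -100, -100, -100, -100,   17])
```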

torchtitan/experiments/vlm/datasets/utils/text.py
Lines changed: 7 additions & 6 deletions

@@ -6,6 +6,8 @@
 
 import torch
 
+from ...tokenizer import VLMTokenizer as Tokenizer
+
 
 def pad_text_batch(
     input_ids: torch.Tensor,
@@ -97,8 +99,7 @@ def pad_input_ids_and_labels_to_target_batch_size(
 def process_text_with_images(
     text: list[str],
     image_tokens: list[tuple[int, int, int]],  # [(total, width, height), ...]
-    tokenizer,
-    special_tokens,
+    tokenizer: Tokenizer,
     add_eos: bool = True,
 ) -> str:
     """Process text by interleaving image tokens efficiently.
@@ -122,14 +123,14 @@ def process_text_with_images(
     image_idx = 0
 
     for part in text:
-        if part == special_tokens.img_token and image_idx < len(image_tokens):
+        if part == tokenizer.img_token and image_idx < len(image_tokens):
             num_image_tokens, _, _ = image_tokens[image_idx]
 
             parts.extend(
                 [
-                    special_tokens.boi_token,
-                    *([special_tokens.img_token] * num_image_tokens),
-                    special_tokens.eoi_token,
+                    tokenizer.boi_token,
+                    *([tokenizer.img_token] * num_image_tokens),
+                    tokenizer.eoi_token,
                 ]
             )
             image_idx += 1
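For reference, this is how the collapsed `tokenizer` argument is used end to end: each `<|image|>` placeholder in the text list expands to a begin-of-image marker, one `<|image|>` per merged patch group, and an end-of-image marker. A hedged usage sketch with a stub standing in for the tokenizer, exposing only the three attributes the function reads (the exact joining and EOS behavior lives in the elided body):

```python
from dataclasses import dataclass

from torchtitan.experiments.vlm.datasets.utils.text import process_text_with_images


@dataclass
class _StubTokenizer:
    img_token: str = "<|image|>"
    boi_token: str = "<|begin_of_image|>"
    eoi_token: str = "<|end_of_image|>"


tok = _StubTokenizer()
text_parts = ["A photo of a cat: ", tok.img_token, " on a mat."]
image_tokens = [(4, 2, 2)]  # one image worth 4 tokens (2x2 merged patches)

# The single placeholder expands to:
#   <|begin_of_image|> + 4 x <|image|> + <|end_of_image|>
processed = process_text_with_images(text_parts, image_tokens, tok, add_eos=True)
```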
