2020from torch .utils .data import IterableDataset
2121
2222from torchtitan .components .dataloader import ParallelAwareDataloader
23- from torchtitan .components .tokenizer import BaseTokenizer , HuggingFaceTokenizer
2423from torchtitan .config import JobConfig
2524from torchtitan .datasets import DatasetConfig
2625from torchtitan .tools .logging import logger
2726
28- from ..model . args import SpecialTokens
27+ from ..tokenizer import VLMTokenizer as Tokenizer
2928from .mm_collator_nld import MultiModalCollatorNLD
3029from .utils .image import calculate_image_tokens , process_image
3130from .utils .packing import SamplePacker
3231from .utils .text import process_text_with_images
3332
3433
34+ IGNORE_INDEX = - 100 # Pytorch's default for F.cross_entropy
35+
36+
3537def _process_mm_sample (
3638 texts : list [str ] | str ,
3739 images : list [bytes ] | bytes ,
38- tokenizer : BaseTokenizer ,
40+ tokenizer : Tokenizer ,
3941 patch_size : int ,
4042 max_patch_per_image : int ,
4143 spatial_merge_size : int ,
42- special_tokens : SpecialTokens ,
4344) -> dict [str , Any ] | None :
4445 """Common processing logic for multimodal samples.
4546
@@ -98,7 +99,7 @@ def _process_mm_sample(
9899 processed_images .append (processed_img )
99100 image_dimensions .append ((num_tokens , width , height ))
100101 # Replace None with image token
101- texts_list [idx ] = special_tokens .img_token
102+ texts_list [idx ] = tokenizer .img_token
102103 else :
103104 # Replace None with empty string if processing failed
104105 texts_list [idx ] = ""
@@ -109,7 +110,7 @@ def _process_mm_sample(
109110
110111 # Process all image tokens at once
111112 processed_text = process_text_with_images (
112- texts_list , image_dimensions , tokenizer , special_tokens , add_eos = True
113+ texts_list , image_dimensions , tokenizer , add_eos = True
113114 )
114115
115116 tokens = tokenizer .encode (processed_text )
@@ -120,10 +121,10 @@ def _process_mm_sample(
120121
121122 # Mask special tokens in labels
122123 special_token_ids = torch .tensor (
123- [special_tokens .boi_id , special_tokens .eoi_id , special_tokens .img_id ]
124+ [tokenizer .boi_id , tokenizer .eoi_id , tokenizer .img_id ]
124125 )
125126 labels = torch .where (
126- torch .isin (labels , special_token_ids ), special_tokens . ignore_id , labels
127+ torch .isin (labels , special_token_ids ), IGNORE_INDEX , labels
127128 )
128129
129130 return {
@@ -139,11 +140,10 @@ def _process_mm_sample(
139140
140141def _process_obelics_sample (
141142 sample : dict [str , Any ],
142- tokenizer : HuggingFaceTokenizer ,
143+ tokenizer : Tokenizer ,
143144 patch_size : int ,
144145 spatial_merge_size : int ,
145146 max_patch_per_image : int ,
146- special_tokens : SpecialTokens ,
147147) -> dict [str , Any ] | None :
148148 """Process a sample from the OBELICS dataset."""
149149 return _process_mm_sample (
@@ -153,17 +153,15 @@ def _process_obelics_sample(
153153 patch_size = patch_size ,
154154 spatial_merge_size = spatial_merge_size ,
155155 max_patch_per_image = max_patch_per_image ,
156- special_tokens = special_tokens ,
157156 )
158157
159158
160159def _process_cc12_wd_sample (
161160 sample : dict [str , Any ],
162- tokenizer : BaseTokenizer ,
161+ tokenizer : Tokenizer ,
163162 patch_size : int ,
164163 spatial_merge_size : int ,
165164 max_patch_per_image : int ,
166- special_tokens : SpecialTokens ,
167165) -> dict [str , Any ] | None :
168166 """Process a sample from the CC12-WD dataset.
169167 Transforms CC12-WD format to match Interleaved format:
@@ -184,7 +182,6 @@ def _process_cc12_wd_sample(
184182 patch_size = patch_size ,
185183 spatial_merge_size = spatial_merge_size ,
186184 max_patch_per_image = max_patch_per_image ,
187- special_tokens = special_tokens ,
188185 )
189186
190187
@@ -225,15 +222,14 @@ def __init__(
225222 self ,
226223 dataset_name : str ,
227224 dataset_path : str | None ,
228- tokenizer : BaseTokenizer ,
225+ tokenizer : Tokenizer ,
229226 batch_size : int ,
230227 seq_len : int ,
231228 patch_size : int ,
232229 spatial_merge_size : int ,
233230 max_patches_per_image : int ,
234231 max_images_per_batch : int ,
235232 packing_buffer_size : int ,
236- special_tokens : SpecialTokens ,
237233 dp_rank : int = 0 ,
238234 dp_world_size : int = 1 ,
239235 infinite : bool = False ,
@@ -254,7 +250,6 @@ def __init__(
254250 self .spatial_merge_size = spatial_merge_size
255251 self .max_patches_per_image = max_patches_per_image
256252 self .max_images_per_batch = max_images_per_batch
257- self .special_tokens = special_tokens
258253 self .enable_packing = packing_buffer_size > 0
259254 if self .enable_packing :
260255 self .packer = SamplePacker (
@@ -277,7 +272,6 @@ def __iter__(self):
277272 patch_size = self .patch_size ,
278273 spatial_merge_size = self .spatial_merge_size ,
279274 max_patch_per_image = self .max_patches_per_image ,
280- special_tokens = self .special_tokens ,
281275 )
282276 if processed is None :
283277 continue
@@ -366,7 +360,7 @@ def state_dict(self):
366360def build_mm_dataloader (
367361 dp_world_size : int ,
368362 dp_rank : int ,
369- tokenizer : HuggingFaceTokenizer ,
363+ tokenizer : Tokenizer ,
370364 job_config : JobConfig ,
371365 infinite : bool = True ,
372366) -> ParallelAwareDataloader :
@@ -393,7 +387,6 @@ def build_mm_dataloader(
393387 patch_size = job_config .data .patch_size
394388 spatial_merge_size = job_config .data .spatial_merge_size
395389 packing_buffer_size = job_config .data .packing_buffer_size
396- special_tokens = SpecialTokens .from_tokenizer (tokenizer )
397390
398391 dataset = MultiModalDataset (
399392 dataset_name = job_config .training .dataset ,
@@ -406,7 +399,6 @@ def build_mm_dataloader(
406399 max_patches_per_image = max_patches_per_image ,
407400 max_images_per_batch = max_images_per_batch ,
408401 packing_buffer_size = packing_buffer_size ,
409- special_tokens = special_tokens ,
410402 dp_rank = dp_rank ,
411403 dp_world_size = dp_world_size ,
412404 infinite = infinite ,
@@ -418,7 +410,7 @@ def build_mm_dataloader(
418410 patch_size = patch_size ,
419411 max_images_per_batch = max_images_per_batch ,
420412 max_patches_per_image = max_patches_per_image ,
421- special_tokens = special_tokens ,
413+ tokenizer = tokenizer ,
422414 )
423415
424416 base_dataloader = ParallelAwareDataloader (
0 commit comments