diff --git a/src/liquid_audio/data/mapper.py b/src/liquid_audio/data/mapper.py index ef639a4..c1804d3 100644 --- a/src/liquid_audio/data/mapper.py +++ b/src/liquid_audio/data/mapper.py @@ -13,9 +13,18 @@ class LFM2AudioChatMapper: """Map a chat into an LFM2 training sample.""" - def __init__(self, processor: LFM2AudioProcessor, *, codebooks: int = 8) -> None: + def __init__( + self, + processor: LFM2AudioProcessor, + *, + codebooks: int = 8, + interleaved_text_tokens: int = 6, + interleaved_audio_tokens: int = 12, + ) -> None: self.processor = processor self.codebooks = codebooks + self.interleaved_text_tokens = interleaved_text_tokens + self.interleaved_audio_tokens = interleaved_audio_tokens def __call__(self, messages: list[ChatMessage]) -> LFM2AudioTrainingSample: text_parts: list[torch.Tensor] = [] @@ -136,8 +145,8 @@ def _append_interleaved_out( audio_out = self._encode_audio_out(wav=wav, sampling_rate=sampling_rate) audio_out_parts.append(audio_out) - n_text = 6 - n_audio = 12 + n_text = self.interleaved_text_tokens + n_audio = self.interleaved_audio_tokens text_left = int(text_tokens.shape[0]) audio_left = int(audio_out.shape[1]) while text_left > 0 or audio_left > 0: