Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 22 additions & 16 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ We present LFM2-Audio-1.5B, [Liquid AI](https://www.liquid.ai/)'s first end-to-e

LFM2-Audio supports two generation modes, interleaved and sequential, to maximize performance and quality across different tasks. Interleaved generation outputs text and audio tokens in a fixed interleaved pattern. This approach minimizes time to first audio output and number of tokens generated, making it ideal for naturally flowing real-time speech-to-speech interactions on resource constrained devices. Sequential generation mode, where the model decides when to switch modalities via special tokens, is suitable for non-conversational tasks, such as speech-to-text (ASR) or text-to-speech (TTS).

### Updates
- [LFM2.5-Audio-1.5B](https://huggingface.co/LiquidAI/LFM2.5-Audio-1.5B) is released! This model is based on the stronger LFM2.5-1.2B base, and comes with a lightning fast LFM2 based audio detokenizer, stronger ASR, and better TTS voices. To use the new detokenizer, simply use `processor.decode`, see the examples below for more details. For the improved TTS voices, see the [TTS](#tts) section.

## Installation
The package can be installed via `pip`
```bash
Expand Down Expand Up @@ -61,7 +64,7 @@ import torchaudio
from liquid_audio import LFM2AudioModel, LFM2AudioProcessor, ChatState, LFMModality

# Load models
HF_REPO = "LiquidAI/LFM2-Audio-1.5B"
HF_REPO = "LiquidAI/LFM2.5-Audio-1.5B"

processor = LFM2AudioProcessor.from_pretrained(HF_REPO).eval()
model = LFM2AudioModel.from_pretrained(HF_REPO).eval()
Expand Down Expand Up @@ -97,9 +100,8 @@ for t in model.generate_interleaved(**chat, max_new_tokens=512, audio_temperatur

# Detokenize audio, removing the last "end-of-audio" codes
# Mimi returns audio at 24kHz
mimi_codes = torch.stack(audio_out[:-1], 1).unsqueeze(0)
with torch.no_grad():
waveform = processor.mimi.decode(mimi_codes)[0]
audio_codes = torch.stack(audio_out[:-1], 1).unsqueeze(0)
waveform = processor.decode(audio_codes)
torchaudio.save("answer1.wav", waveform.cpu(), 24_000)

# Append newly generated tokens to chat history
Expand Down Expand Up @@ -128,9 +130,8 @@ for t in model.generate_interleaved(**chat, max_new_tokens=512, audio_temperatur
# output: Sure thing! How about “Comfortable Chairs, Crafted with Care” or “Elegant Seats, Handcrafted for You”? Let me know if you’d like a few more options.

# Detokenize second turn audio, removing the last "end-of-audio" codes
mimi_codes = torch.stack(audio_out[:-1], 1).unsqueeze(0)
with torch.no_grad():
waveform = processor.mimi.decode(mimi_codes)[0]
audio_codes = torch.stack(audio_out[:-1], 1).unsqueeze(0)
waveform = processor.decode(audio_codes)
torchaudio.save("answer2.wav", waveform.cpu(), 24_000)
```

Expand All @@ -154,7 +155,7 @@ import torchaudio
from liquid_audio import LFM2AudioModel, LFM2AudioProcessor, ChatState, LFMModality

# Load models
HF_REPO = "LiquidAI/LFM2-Audio-1.5B"
HF_REPO = "LiquidAI/LFM2.5-Audio-1.5B"

processor = LFM2AudioProcessor.from_pretrained(HF_REPO).eval()
model = LFM2AudioModel.from_pretrained(HF_REPO).eval()
Expand Down Expand Up @@ -182,19 +183,25 @@ for t in model.generate_sequential(**chat, max_new_tokens=512):
```

### TTS
For TTS, we also use sequential generation, with the fixed system prompt `Perform TTS.`. In addition, we can prompt the voice and a style using a natural language description.
For TTS, we also use sequential generation. We support four pre-defined voices, which can be selected by choosing one of the four system prompts below
```
Perform TTS. Use the US male voice.
Perform TTS. Use the US female voice.
Perform TTS. Use the UK male voice.
Perform TTS. Use the UK female voice.
```

<details>

<summary>TTS Sample</summary>

**Voice description**: A male speaker delivers his lines with a low-pitched voice and an animated tone. The recording is of excellent quality with almost no noise and a very close-sounding atmosphere.
**System prompt**: Perform TTS. Use the UK male voice.

**Input sentence**: What is this obsession people have with books? They put them in their houses—like they're trophies. What do you need it for after you read it?

**Output audio**

https://github.com/user-attachments/assets/2fa953cf-d8a8-477a-b841-c4f18d9266e6
https://github.com/user-attachments/assets/8d57c184-b92e-4e1a-983b-d1f9d16d0d92

</details>

Expand All @@ -204,7 +211,7 @@ import torchaudio
from liquid_audio import LFM2AudioModel, LFM2AudioProcessor, ChatState, LFMModality

# Load models
HF_REPO = "LiquidAI/LFM2-Audio-1.5B"
HF_REPO = "LiquidAI/LFM2.5-Audio-1.5B"

processor = LFM2AudioProcessor.from_pretrained(HF_REPO).eval()
model = LFM2AudioModel.from_pretrained(HF_REPO).eval()
Expand All @@ -213,7 +220,7 @@ model = LFM2AudioModel.from_pretrained(HF_REPO).eval()
chat = ChatState(processor)

chat.new_turn("system")
chat.add_text("Perform TTS.\nUse the following voice: A male speaker delivers his lines with a low-pitched voice and an animated tone. The recording is of excellent quality with almost no noise and a very close-sounding atmosphere.")
chat.add_text("Perform TTS. Use the UK male voice.")
chat.end_turn()

chat.new_turn("user")
Expand All @@ -229,9 +236,8 @@ for t in model.generate_sequential(**chat, max_new_tokens=512, audio_temperature
audio_out.append(t)

# Detokenize audio
mimi_codes = torch.stack(audio_out[:-1], 1).unsqueeze(0)
with torch.no_grad():
waveform = processor.mimi.decode(mimi_codes)[0]
audio_codes = torch.stack(audio_out[:-1], 1).unsqueeze(0)
waveform = processor.decode(audio_codes)
torchaudio.save("tts.wav", waveform.cpu(), 24_000)
```

Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "liquid-audio"
version = "1.0.0"
version = "1.1.0"
description = "Liquid Audio - Speech-to-Speech audio models"
readme = "README.md"
authors = [
Expand All @@ -16,6 +16,7 @@ dependencies = [
"sentencepiece>=0.2.1",
"torch>=2.8.0",
"torchaudio>=2.8.0",
"torchcodec>=0.9.1",
"transformers>=4.55.4",
]
keywords = ["Liquid AI", "LFM", "LFM2", "Audio", "Speech-to-Speech"]
Expand Down
3 changes: 2 additions & 1 deletion src/liquid_audio/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from liquid_audio.detokenizer import LFM2AudioDetokenizer
from liquid_audio.model.lfm2_audio import LFM2AudioModel
from liquid_audio.processor import ChatState, LFM2AudioProcessor
from liquid_audio.utils import LFMModality

__all__ = ["ChatState", "LFM2AudioModel", "LFM2AudioProcessor", "LFMModality"]
__all__ = ["ChatState", "LFM2AudioDetokenizer", "LFM2AudioModel", "LFM2AudioProcessor", "LFMModality"]
2 changes: 1 addition & 1 deletion src/liquid_audio/demo/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

__all__ = ["lfm2_audio", "mimi", "proc"]

HF_DIR = "LiquidAI/LFM2-Audio-1.5B"
HF_DIR = "LiquidAI/LFM2.5-Audio-1.5B"

logging.info("Loading processor")
proc = LFM2AudioProcessor.from_pretrained(HF_DIR).eval()
Expand Down
136 changes: 136 additions & 0 deletions src/liquid_audio/detokenizer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
import torch
from torch import nn
from transformers import Lfm2Config, Lfm2Model


class FusedEmbedding(nn.Module):
"""Turn codes into embeddings"""

def __init__(
self,
dim: int,
codeboooks: int = 8,
vocab_size: int = 2048,
):
super().__init__()
self.emb = nn.Embedding(codeboooks * vocab_size, dim)

self.codeboooks = codeboooks
self.vocab_size = vocab_size

def forward(self, x: torch.Tensor) -> torch.Tensor:
offsets = torch.arange(self.codeboooks, device=x.device) * self.vocab_size # TODO: buffer?
offset_x = offsets[:, None] + x
return self.emb(offset_x).mean(1) # B L D


class ISTFT(nn.Module):
"""
Custom implementation of ISTFT since torch.istft doesn't allow custom padding (other than `center=True`) with
windowing. This is because the NOLA (Nonzero Overlap Add) check fails at the edges.
See issue: https://github.com/pytorch/pytorch/issues/62323
Specifically, in the context of neural vocoding we are interested in "same" padding analogous to CNNs.
The NOLA constraint is met as we trim padded samples anyway.

Adapted from Vocos: https://github.com/gemelo-ai/vocos/blob/c859e3b7b534f3776a357983029d34170ddd6fc3/vocos/spectral_ops.py#L7
Args:
n_fft (int): Size of Fourier transform.
hop_length (int): The distance between neighboring sliding window frames.
win_length (int): The size of window frame and STFT filter.
padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
"""

def __init__(self, n_fft: int, hop_length: int, win_length: int, padding: str = "same"):
super().__init__()
if padding not in ["center", "same"]:
raise ValueError("Padding must be 'center' or 'same'.")
self.padding = padding
self.n_fft = n_fft
self.hop_length = hop_length
self.win_length = win_length
window = torch.hann_window(win_length)
self.register_buffer("window", window)

def forward(self, spec: torch.Tensor) -> torch.Tensor:
"""
Compute the Inverse Short Time Fourier Transform (ISTFT) of a complex spectrogram.
Args:
spec (Tensor): Input complex spectrogram of shape (B, N, T), where B is the batch size,
N is the number of frequency bins, and T is the number of time frames.
Returns:
Tensor: Reconstructed time-domain signal of shape (B, L), where L is the length of the output signal.
"""
if self.padding == "center":
# Fallback to pytorch native implementation
return torch.istft(
spec,
self.n_fft,
self.hop_length,
self.win_length,
self.window, # type: ignore[arg-type]
center=True,
)
elif self.padding == "same":
pad = (self.win_length - self.hop_length) // 2
else:
raise ValueError("Padding must be 'center' or 'same'.")

assert spec.dim() == 3, "Expected a 3D tensor as input"
_B, _N, T = spec.shape

# Inverse FFT
ifft = torch.fft.irfft(spec, self.n_fft, dim=1, norm="backward")
ifft = ifft * self.window[None, :, None] # type: ignore[index]

# Overlap and Add
output_size = (T - 1) * self.hop_length + self.win_length
y = torch.nn.functional.fold(
ifft,
output_size=(1, output_size),
kernel_size=(1, self.win_length),
stride=(1, self.hop_length),
)[:, 0, 0, pad:-pad]

# Window envelope
window_sq = self.window.square().expand(1, T, -1).transpose(1, 2) # type: ignore[operator]
window_envelope = torch.nn.functional.fold(
window_sq,
output_size=(1, output_size),
kernel_size=(1, self.win_length),
stride=(1, self.hop_length),
).squeeze()[pad:-pad]

# Normalize
assert (window_envelope > 1e-11).all()
y = y / window_envelope

return y


class LFM2AudioDetokenizer(nn.Module):
def __init__(self, backbone_config: Lfm2Config):
super().__init__()
self.emb = FusedEmbedding(512)
self.lfm = Lfm2Model(backbone_config)
self.lin = nn.Linear(512, 1282) # half are log-magnitude, half are angle

self.istft = ISTFT(1280, 320, 1280, padding="same")
self.sliding_window_size = getattr(backbone_config, "sliding_window", 30)

def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.emb(x)
upsample_size = 6 * x.shape[1]
x = nn.functional.interpolate(x.mT, upsample_size, mode="nearest-exact").mT

# Set attn mask
idx = torch.arange(x.shape[1], device=x.device)
d_idx = idx - idx[:, None]
mask = torch.logical_and(d_idx <= 0, d_idx > -self.sliding_window_size)[None, None, ...]

x = self.lfm(inputs_embeds=x, attention_mask=mask, use_cache=False).last_hidden_state
x = self.lin(x)

log_abs, angle = torch.chunk(x.mT.contiguous(), 2, 1)
y = torch.polar(log_abs.exp(), angle)

return self.istft(y)
Loading
Loading