diff --git a/src/liquid_audio/demo/chat.py b/src/liquid_audio/demo/chat.py index b186529..2b3db20 100644 --- a/src/liquid_audio/demo/chat.py +++ b/src/liquid_audio/demo/chat.py @@ -8,7 +8,7 @@ from liquid_audio import ChatState, LFMModality -from .model import lfm2_audio, mimi, proc +from .model import device, lfm2_audio, mimi, proc def chat_producer( @@ -91,7 +91,7 @@ def chat_response(audio: tuple[int, np.ndarray], _id: str, chat: ChatState, temp chat.append( text=torch.stack(out_text, 1), audio_out=torch.stack(out_audio, 1), - modality_flag=torch.tensor(out_modality, device="cuda"), + modality_flag=torch.tensor(out_modality, device=device), ) chat.end_turn() @@ -122,7 +122,7 @@ def clear(): webrtc.stream( ReplyOnPause( chat_response, # type: ignore[arg-type] - input_sample_rate=24_000, + input_sample_rate=48_000, output_sample_rate=24_000, can_interrupt=False, ), diff --git a/src/liquid_audio/demo/model.py b/src/liquid_audio/demo/model.py index 8b27506..739b4d4 100644 --- a/src/liquid_audio/demo/model.py +++ b/src/liquid_audio/demo/model.py @@ -8,19 +8,26 @@ logger = logging.getLogger(__name__) -__all__ = ["lfm2_audio", "mimi", "proc"] +__all__ = ["lfm2_audio", "mimi", "proc", "device"] HF_DIR = "LiquidAI/LFM2.5-Audio-1.5B" +if torch.cuda.is_available(): + device = "cuda" +elif torch.backends.mps.is_available(): + device = "mps" +else: + device = "cpu" + logging.info("Loading processor") -proc = LFM2AudioProcessor.from_pretrained(HF_DIR).eval() +proc = LFM2AudioProcessor.from_pretrained(HF_DIR, device=device).eval() logging.info("Loading model") -lfm2_audio = LFM2AudioModel.from_pretrained(HF_DIR).eval() +lfm2_audio = LFM2AudioModel.from_pretrained(HF_DIR, device=device).eval() logging.info("Loading tokenizer") mimi = proc.mimi.eval() logging.info("Warmup tokenizer") with mimi.streaming(1), torch.no_grad(): for _ in range(5): - x = torch.randint(2048, (1, 8, 1), device="cuda") + x = torch.randint(2048, (1, 8, 1), device=device) mimi.decode(x) diff --git a/src/liquid_audio/processor.py b/src/liquid_audio/processor.py index 7c60fed..c602eb2 100644 --- a/src/liquid_audio/processor.py +++ b/src/liquid_audio/processor.py @@ -111,7 +111,9 @@ def mimi(self) -> MimiModel: from safetensors.torch import load_file mimi_model = moshi.models.loaders.get_mimi(None, device=self.device) - mimi_weights = load_file(self.mimi_weights_path, device=str(self.device)) + # safetensors only supports cpu/cuda as load targets + load_device = str(self.device) if self.device.type in ("cpu", "cuda") else "cpu" + mimi_weights = load_file(self.mimi_weights_path, device=load_device) mimi_model.load_state_dict(mimi_weights, strict=True) self._mimi = mimi_model @@ -148,7 +150,7 @@ def rename_layer( assert isinstance(detok_config.layer_types, list) detok_config.layer_types = [rename_layer(layer) for layer in detok_config.layer_types] # type: ignore[arg-type] - detok = LFM2AudioDetokenizer(detok_config).eval().cuda() + detok = LFM2AudioDetokenizer(detok_config).eval().to(self.device) detok_weights_path = Path(self.detokenizer_path) / "model.safetensors" from safetensors.torch import load_file