diff --git a/app/core/config.py b/app/core/config.py index 405f5f8..0a8c0bf 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -87,6 +87,11 @@ def similarity_threshold(self) -> float: """声纹相似度阈值""" return self.voiceprint.get("similarity_threshold", 0.2) + @property + def model_id(self)-> str: + """模型ID""" + return self.voiceprint.get("model_id","iic/speech_campplus_sv_zh-cn_3dspeaker_16k") + @property def target_sample_rate(self) -> int: """目标音频采样率""" diff --git a/app/services/voiceprint_service.py b/app/services/voiceprint_service.py index afdcfb8..d071939 100644 --- a/app/services/voiceprint_service.py +++ b/app/services/voiceprint_service.py @@ -19,6 +19,7 @@ class VoiceprintService: def __init__(self): self._pipeline = None self.similarity_threshold = settings.similarity_threshold + self.model_id = settings.model_id self._pipeline_lock = threading.Lock() # 添加线程锁 self._init_pipeline() self._warmup_model() # 添加模型预热 @@ -37,10 +38,10 @@ def _init_pipeline(self) -> None: device = "cpu" logger.info("使用CPU设备") - logger.info("开始加载模型: iic/speech_campplus_sv_zh-cn_3dspeaker_16k") + logger.info(f"开始加载模型: {self.model_id}") self._pipeline = pipeline( task=Tasks.speaker_verification, - model="iic/speech_campplus_sv_zh-cn_3dspeaker_16k", + model=self.model_id, device=device, ) diff --git a/docker-compose.yml b/docker-compose.yml index 28a5e76..0f1e92d 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -18,5 +18,12 @@ services: volumes: # 配置文件目录 - ./data:/app/data + deploy: + resources: + reservations: + devices: + - driver: nvidia + count: all + capabilities: [gpu] networks: default: diff --git a/requirements.txt b/requirements.txt index 82cbf26..a38b6cf 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,17 +1,21 @@ -modelscope==1.13.0 -datasets==2.14.5 -numpy==1.23.5 -packaging==21.3 +modelscope==1.26.0 +datasets==3.2.0 +numpy==1.26.4 +packaging==24.1 addict==2.4.0 transformers==4.52.4 -torch==2.2.2 +torch==2.6.0 sentencepiece==0.2.0 soundfile==0.13.1 -torchaudio==2.2.2 +torchaudio==2.6.0 pyyaml==6.0.1 fastapi==0.110.2 uvicorn==0.29.0 PyMySQL==1.1.0 python-multipart==0.0.9 librosa==0.10.1 -loguru==0.7.2 \ No newline at end of file +loguru==0.7.2 +pyarrow==20.0.0 +Pillow==10.4.0 +simplejson==3.20.1 +sortedcontainers==2.4.0