diff --git a/src/ai/stt/eval_analyze_outputs.py b/src/ai/stt/eval_analyze_outputs.py
new file mode 100644
index 0000000..c471043
--- /dev/null
+++ b/src/ai/stt/eval_analyze_outputs.py
@@ -0,0 +1,313 @@
+from transformers import WhisperProcessor, WhisperForConditionalGeneration, AutoModel
+from datasets import Audio, load_dataset
+import os
+import torchaudio.functional as F
+import torchaudio
+import torch
+import tempfile
+import json
+import requests
+from tqdm import tqdm
+from jiwer import wer, cer
+import numpy as np
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+sample_dataset_size = 10
+
+# classes for the different models
+
+class ASRModel:
+    def __init__(self):
+        raise NotImplementedError
+
+    def infer(self, mp3_path):
+        raise NotImplementedError
+
+
+class WhsiperASRModel(ASRModel):
+    def __init__(self, model_path):
+        self.processor = WhisperProcessor.from_pretrained(model_path)
+        self.model = WhisperForConditionalGeneration.from_pretrained(
+            model_path, torch_dtype=torch.float16
+        ).to(device)
+        self.target_sample_rate = 16000
+
+    def infer(self, mp3_path):
+        wav, sr = torchaudio.load(mp3_path)
+        if sr != self.target_sample_rate:
+            resampler = torchaudio.transforms.Resample(
+                orig_freq=sr, new_freq=self.target_sample_rate
+            )
+            wav = resampler(wav)
+
+        input_features = self.processor(
+            wav.squeeze().numpy(),
+            sampling_rate=self.target_sample_rate,
+            return_tensors="pt",
+        ).input_features
+        input_features = input_features.to(device, dtype=torch.float16)
+        # generate token ids
+        predicted_ids = self.model.generate(
+            input_features, task="transcribe", language="kn"
+        )
+        # decode token ids to text
+        transcription_1 = self.processor.batch_decode(
+            predicted_ids, skip_special_tokens=True
+        )[0]
+        return transcription_1
+
+
+class IndicConformer(ASRModel):
+    def __init__(self, model_path, decoder="rnnt"):
+        self.model = AutoModel.from_pretrained(model_path, trust_remote_code=True).to(
+            device
+        )
+        self.decoder = decoder
+
+    def infer(self, mp3_path):
+        wav, sr = torchaudio.load(mp3_path)
+        wav = torch.mean(wav, dim=0, keepdim=True)
+
+        target_sample_rate = 16000
+        if sr != target_sample_rate:
+            resampler = torchaudio.transforms.Resample(
+                orig_freq=sr, new_freq=target_sample_rate
+            )
+            wav = resampler(wav)
+
+        transcription = self.model(wav, "kn", self.decoder)
+        return transcription
+
+
+class SpeachesASRModel(ASRModel):
+    def __init__(self, base_url, model_name):
+        self.base_url = base_url
+        self.model_name = model_name
+
+    def infer(self, mp3_path):
+        files = {"file": open(mp3_path, "rb")}
+        resp = requests.post(
+            f"{self.base_url}/v1/audio/transcriptions",
+            files=files,
+            # model/language must be sent as multipart form fields (data=), not as a JSON body
+            data={"model": self.model_name, "language": "kn"},
+        )
+        output = resp.json()
+        return output["text"]
+
+# classes for the different datasets
+class ASRDataset:
+    def __init__(self):
+        pass
+
+    def save_samples(self, dir_to_save):
+        raise NotImplementedError
+
+
+class SyntheticYoutubeASRDataset(ASRDataset):
+    def __init__(self, dataset_path):
+        super().__init__()
+        self.hf_dataset = load_dataset(dataset_path, split="test")
+        self.sample_rate = 16000
+        self.hf_dataset = self.hf_dataset.shuffle(seed=42).select(
+            range(0, sample_dataset_size)
+        )
+        self.hf_dataset = self.hf_dataset.cast_column(
+            "audio", Audio(sampling_rate=self.sample_rate)
+        )
+
+    def __repr__(self):
+        return "SyntheticYoutubeASRDataset"
+
+    def save_samples(self, dir_to_save):
+        json_data = []
+        for row in tqdm(self.hf_dataset, desc="processing data"):
+            wav = torch.tensor([row["audio"]["array"]])
+            filename = os.path.join(dir_to_save, row["audio"]["path"])
+            gt = row["prompt"]
+            torchaudio.save(filename, wav, self.sample_rate)
+            json_data.append({"filename": filename, "ground_truth": gt})
+
+        with open(os.path.join(dir_to_save, "metadata.json"), "w", encoding="utf-8") as fp:
+            json.dump(json_data, fp, ensure_ascii=False)
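+
+# NOTE: every ASRDataset.save_samples() implementation writes <dir_to_save>/metadata.json,
+# a list of {"filename": ..., "ground_truth": ...} records that evaluate_model() reads back.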
+
+
+class KathbathASRDataset(ASRDataset):
+    def __init__(self, dataset_path):
+        self.hf_dataset = load_dataset(dataset_path, "kannada", split="valid")
+        self.sample_rate = 16000
+        self.hf_dataset = self.hf_dataset.shuffle(seed=42).select(
+            range(0, sample_dataset_size)
+        )
+        self.hf_dataset = self.hf_dataset.cast_column(
+            "audio", Audio(sampling_rate=self.sample_rate)
+        )
+
+    def __repr__(self):
+        return "KathbathASRDataset"
+
+    def save_samples(self, dir_to_save):
+        json_data = []
+        for row in tqdm(self.hf_dataset, desc="processing data"):
+            wav = torch.tensor([row["audio_filepath"]["array"]])
+            filename = os.path.join(dir_to_save, row["audio_filepath"]["path"])
+            gt = row["text"]
+            torchaudio.save(filename, wav, self.sample_rate)
+            json_data.append({"filename": filename, "ground_truth": gt})
+
+        with open(os.path.join(dir_to_save, "metadata.json"), "w", encoding="utf-8") as fp:
+            json.dump(json_data, fp, ensure_ascii=False)
+
+
+class GoogleTTSASRDataset(ASRDataset):
+    def __init__(self, dataset_path):
+        self.hf_dataset = load_dataset(dataset_path, split="test")
+        self.sample_rate = 16000
+        self.hf_dataset = self.hf_dataset.shuffle(seed=42).select(
+            range(0, sample_dataset_size)
+        )
+        self.hf_dataset = self.hf_dataset.cast_column(
+            "audio", Audio(sampling_rate=self.sample_rate)
+        )
+
+    def __repr__(self):
+        return "GoogleTTSASRDataset"
+
+    def save_samples(self, dir_to_save):
+        json_data = []
+        for row in tqdm(self.hf_dataset, desc="processing data"):
+            wav = torch.tensor([row["audio"]["array"]])
+            filename = os.path.join(dir_to_save, row["audio"]["path"])
+            gt = row["prompt"]
+            torchaudio.save(filename, wav, self.sample_rate)
+            json_data.append({"filename": filename, "ground_truth": gt})
+
+        with open(os.path.join(dir_to_save, "metadata.json"), "w", encoding="utf-8") as fp:
+            json.dump(json_data, fp, ensure_ascii=False)
+
+
+class ARTPARK_IISC_VANIDataset(ASRDataset):
+    dialect_for_vani = [
+        "Karnataka_Bangalore",
+        "Karnataka_Belgaum",
+        "Karnataka_Bellary",
+        "Karnataka_Bidar",
+        "Karnataka_Bijapur",
+        "Karnataka_Chamrajnagar",
+        "Karnataka_DakshinKannada",
+        "Karnataka_Dharwad",
+        "Karnataka_Gulbarga",
+        "Karnataka_Koppal",
+        "Karnataka_Mysore",
+        "Karnataka_Raichur",
+        "Karnataka_Shimoga",
+    ]
+
+    def __init__(self, dataset_path, subset):
+        super().__init__()
+        self.hf_dataset = load_dataset(dataset_path, subset, split="train")
+        self.hf_dataset = self.hf_dataset.filter(
+            lambda row: row["isTranscriptionAvailable"] == "Yes"
+        )
+        self.sample_rate = 16000
+        self.hf_dataset = self.hf_dataset.shuffle(seed=42).select(range(0, 3))
+        self.hf_dataset = self.hf_dataset.cast_column(
+            "audio", Audio(sampling_rate=self.sample_rate)
+        )
+
+    def __repr__(self):
+        return "ARTPARK_IISC_VANIDataset"
+
+    def save_samples(self, dir_to_save):
+        json_data = []
+        for row in tqdm(self.hf_dataset, desc="processing data"):
+            wav = torch.tensor([row["audio"]["array"]])
+            filename = os.path.join(dir_to_save, row["audio"]["path"])
+            gt = row["transcript"]
+            torchaudio.save(filename, wav, self.sample_rate)
+            json_data.append({"filename": filename, "ground_truth": gt})
+
+        with open(os.path.join(dir_to_save, "metadata.json"), "w", encoding="utf-8") as fp:
+            json.dump(json_data, fp, ensure_ascii=False)
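+
+# WER = (substitutions + deletions + insertions) / number of words in the reference,
+# and CER is the same edit-distance ratio computed over characters; both come from jiwer.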
+
+
+def evaluate_model(asr_model, asr_dataset):
+    dir_to_save = "/home/phoenix/tmp_whatever/stt/inference"
+    asr_dataset.save_samples(dir_to_save)
+
+    with open(os.path.join(dir_to_save, "metadata.json")) as f:
+        json_data = json.load(f)
+    errors = []
+    character_errors = []
+    json_out = []
+    for row in tqdm(json_data, desc="running model"):
+        filename = row["filename"]
+        gt = row["ground_truth"]
+        pred = asr_model.infer(filename)
+        row["prediction"] = pred
+        json_out.append(row)
+        error = wer(gt, pred)
+        character_error = cer(gt, pred)
+        # print("gt:", gt)
+        # print("pred:", pred)
+        # print("wer:", error)
+        # print("cer:", character_error)
+        # print('-------')
+        errors.append(error)
+        character_errors.append(character_error)
+    with open(os.path.join(dir_to_save, f"{asr_dataset}_prediction.json"), "w", encoding="utf-8") as out:
+        json.dump(json_out, out, ensure_ascii=False)
+
+    return np.mean(errors), np.mean(character_errors)
+
+
+if __name__ == "__main__":
+    asr_models = {
+        # "finetuned_whispher": WhsiperASRModel(
+        #     model_path="adithyal1998Bhat/whisper-kn"
+        # ),
+        # "indic_conformer": IndicConformer(model_path="ai4bharat/indic-conformer-600m-multilingual"),
+        # # "whisper-large-v3": SpeachesASRModel(
+        # #     base_url="http://100.64.0.7:49827",
+        # #     model_name="Systran/faster-whisper-large-v3",
+        # # ),
+        # "whisper-medium-vaani-kannada": WhsiperASRModel(
+        #     model_path="ARTPARK-IISc/whisper-medium-vaani-kannada"
+        # ),
+        # "whisper-small-vaani-kannada": WhsiperASRModel(
+        #     model_path="ARTPARK-IISc/whisper-small-vaani-kannada"
+        # ),
+        # "base_whisper_50_epochs": WhsiperASRModel(
+        #     model_path="/home/phoenix/tmp_whatever/stt/training/model_output_base_model"
+        # ),
+        # "base_youtube_50_epochs": WhsiperASRModel(
+        #     model_path="/home/phoenix/tmp_whatever/stt/training/model_output1"
+        # ),
+        # "base_youtube_5_epochs": WhsiperASRModel(
+        #     model_path="/home/phoenix/tmp_whatever/stt/training/base_youtube_5_epochs"
+        # ),
+        # "base_whisper_5_epochs": WhsiperASRModel(
+        #     model_path="/home/phoenix/tmp_whatever/stt/training/base_whisper_5_epochs"
+        # ),
+        # "openai/whisper-medium": WhsiperASRModel(
+        #     model_path="openai/whisper-medium"
+        # )
+        # "base_youtube_20_epochs_lr_1e_7": WhsiperASRModel(
+        #     model_path="/home/phoenix/tmp_whatever/stt/training/base_youtube_20_epochs_lr_1e_7"
+        # ),
+        "base_youtube_30_epochs_lr_1e_7": WhsiperASRModel(
+            model_path="/home/phoenix/tmp_whatever/stt/training/base_youtube_30_epochs_lr_1e_7"
+        ),
+    }
+
+    asr_datasets = {
+        "google_tts": GoogleTTSASRDataset("adithyal1998Bhat/tts_synthetic_kn"),
+        "youtube_synthetic": SyntheticYoutubeASRDataset(
+            "adithyal1998Bhat/stt_synthetic_kn-IN_kannada"
+        ),
+        "kathbath": KathbathASRDataset("ai4bharat/Kathbath"),
+    }
+
+    # asr_datasets['youtube_synthetic'].save_samples('recycle/youtube')
+    for dataset_name, asr_dataset in asr_datasets.items():
+        for model_name, asr_model in asr_models.items():
+            print("model", model_name, "dataset", dataset_name)
+            print('========\n' * 3)
+            error = evaluate_model(asr_model, asr_dataset)
+            print(model_name, dataset_name, " wer :", error[0], " cer :", error[1])
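+
+# For every (model, dataset) pair this script prints the mean WER/CER and writes
+# <dataset>_prediction.json (per-sample ground truth and prediction) into dir_to_save
+# so the outputs can be inspected manually.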
diff --git a/src/ai/stt/eval_wer_cer..py b/src/ai/stt/eval_wer_cer..py
new file mode 100644
index 0000000..953afeb
--- /dev/null
+++ b/src/ai/stt/eval_wer_cer..py
@@ -0,0 +1,303 @@
+from transformers import WhisperProcessor, WhisperForConditionalGeneration, AutoModel
+from datasets import Audio, load_dataset
+import os
+import torchaudio.functional as F
+import torchaudio
+import torch
+import tempfile
+import json
+import requests
+from tqdm import tqdm
+from jiwer import wer, cer
+import numpy as np
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+sample_dataset_size = 100
+
+# classes for the different models
+
+class ASRModel:
+    def __init__(self):
+        raise NotImplementedError
+
+    def infer(self, mp3_path):
+        raise NotImplementedError
+
+
+class WhsiperASRModel(ASRModel):
+    def __init__(self, model_path):
+        self.processor = WhisperProcessor.from_pretrained(model_path)
+        self.model = WhisperForConditionalGeneration.from_pretrained(
+            model_path, torch_dtype=torch.float16
+        ).to(device)
+        self.target_sample_rate = 16000
+
+    def infer(self, mp3_path):
+        wav, sr = torchaudio.load(mp3_path)
+        if sr != self.target_sample_rate:
+            resampler = torchaudio.transforms.Resample(
+                orig_freq=sr, new_freq=self.target_sample_rate
+            )
+            wav = resampler(wav)
+
+        input_features = self.processor(
+            wav.squeeze().numpy(),
+            sampling_rate=self.target_sample_rate,
+            return_tensors="pt",
+        ).input_features
+        input_features = input_features.to(device, dtype=torch.float16)
+        # generate token ids
+        predicted_ids = self.model.generate(
+            input_features, task="transcribe", language="kn"
+        )
+        # decode token ids to text
+        transcription_1 = self.processor.batch_decode(
+            predicted_ids, skip_special_tokens=True
+        )[0]
+        return transcription_1
+
+
+class IndicConformer(ASRModel):
+    def __init__(self, model_path, decoder="rnnt"):
+        self.model = AutoModel.from_pretrained(model_path, trust_remote_code=True).to(
+            device
+        )
+        self.decoder = decoder
+
+    def infer(self, mp3_path):
+        wav, sr = torchaudio.load(mp3_path)
+        wav = torch.mean(wav, dim=0, keepdim=True)
+
+        target_sample_rate = 16000
+        if sr != target_sample_rate:
+            resampler = torchaudio.transforms.Resample(
+                orig_freq=sr, new_freq=target_sample_rate
+            )
+            wav = resampler(wav)
+
+        transcription = self.model(wav, "kn", self.decoder)
+        return transcription
+
+
+class SpeachesASRModel(ASRModel):
+    def __init__(self, base_url, model_name):
+        self.base_url = base_url
+        self.model_name = model_name
+
+    def infer(self, mp3_path):
+        files = {"file": open(mp3_path, "rb")}
+        resp = requests.post(
+            f"{self.base_url}/v1/audio/transcriptions",
+            files=files,
+            # model/language must be sent as multipart form fields (data=), not as a JSON body
+            data={"model": self.model_name, "language": "kn"},
+        )
+        output = resp.json()
+        return output["text"]
+
+# classes for the different datasets
+class ASRDataset:
+    def __init__(self):
+        pass
+
+    def save_samples(self, dir_to_save):
+        raise NotImplementedError
+
+
+class SyntheticYoutubeASRDataset(ASRDataset):
+    def __init__(self, dataset_path):
+        super().__init__()
+        self.hf_dataset = load_dataset(dataset_path, split="test")
+        self.sample_rate = 16000
+        self.hf_dataset = self.hf_dataset.shuffle(seed=42).select(
+            range(0, sample_dataset_size)
+        )
+        self.hf_dataset = self.hf_dataset.cast_column(
+            "audio", Audio(sampling_rate=self.sample_rate)
+        )
+
+    def save_samples(self, dir_to_save):
+        json_data = []
+        for row in tqdm(self.hf_dataset, desc="processing data"):
+            wav = torch.tensor([row["audio"]["array"]])
+            filename = os.path.join(dir_to_save, row["audio"]["path"])
+            gt = row["prompt"]
+            torchaudio.save(filename, wav, self.sample_rate)
+            json_data.append({"filename": filename, "ground_truth": gt})
+
+        with open(os.path.join(dir_to_save, "metadata.json"), "w") as fp:
+            json.dump(json_data, fp)
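+
+# Kathbath differs from the synthetic datasets: it is loaded with the "kannada" config and the
+# "valid" split, and its audio is stored in the "audio_filepath" column rather than "audio".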
+
+
+class KathbathASRDataset(ASRDataset):
+    def __init__(self, dataset_path):
+        self.hf_dataset = load_dataset(dataset_path, "kannada", split="valid")
+        self.sample_rate = 16000
+        self.hf_dataset = self.hf_dataset.shuffle(seed=42).select(
+            range(0, sample_dataset_size)
+        )
+        self.hf_dataset = self.hf_dataset.cast_column(
+            "audio", Audio(sampling_rate=self.sample_rate)
+        )
+
+    def save_samples(self, dir_to_save):
+        json_data = []
+        for row in tqdm(self.hf_dataset, desc="processing data"):
+            wav = torch.tensor([row["audio_filepath"]["array"]])
+            filename = os.path.join(dir_to_save, row["audio_filepath"]["path"])
+            gt = row["text"]
+            torchaudio.save(filename, wav, self.sample_rate)
+            json_data.append({"filename": filename, "ground_truth": gt})
+
+        with open(os.path.join(dir_to_save, "metadata.json"), "w") as fp:
+            json.dump(json_data, fp)
+
+
+class GoogleTTSASRDataset(ASRDataset):
+    def __init__(self, dataset_path):
+        self.hf_dataset = load_dataset(dataset_path, split="test")
+        self.sample_rate = 16000
+        self.hf_dataset = self.hf_dataset.shuffle(seed=42).select(
+            range(0, sample_dataset_size)
+        )
+        self.hf_dataset = self.hf_dataset.cast_column(
+            "audio", Audio(sampling_rate=self.sample_rate)
+        )
+
+    def save_samples(self, dir_to_save):
+        json_data = []
+        for row in tqdm(self.hf_dataset, desc="processing data"):
+            wav = torch.tensor([row["audio"]["array"]])
+            filename = os.path.join(dir_to_save, row["audio"]["path"])
+            gt = row["prompt"]
+            torchaudio.save(filename, wav, self.sample_rate)
+            json_data.append({"filename": filename, "ground_truth": gt})
+
+        with open(os.path.join(dir_to_save, "metadata.json"), "w") as fp:
+            json.dump(json_data, fp)
+
+
+class ARTPARK_IISC_VANIDataset(ASRDataset):
+    dialect_for_vani = [
+        "Karnataka_Bangalore",
+        "Karnataka_Belgaum",
+        "Karnataka_Bellary",
+        "Karnataka_Bidar",
+        "Karnataka_Bijapur",
+        "Karnataka_Chamrajnagar",
+        "Karnataka_DakshinKannada",
+        "Karnataka_Dharwad",
+        "Karnataka_Gulbarga",
+        "Karnataka_Koppal",
+        "Karnataka_Mysore",
+        "Karnataka_Raichur",
+        "Karnataka_Shimoga",
+    ]
+
+    def __init__(self, dataset_path, subset):
+        super().__init__()
+        self.hf_dataset = load_dataset(dataset_path, subset, split="train")
+        self.hf_dataset = self.hf_dataset.filter(
+            lambda row: row["isTranscriptionAvailable"] == "Yes"
+        )
+        self.sample_rate = 16000
+        self.hf_dataset = self.hf_dataset.shuffle(seed=42).select(range(0, 3))
+        self.hf_dataset = self.hf_dataset.cast_column(
+            "audio", Audio(sampling_rate=self.sample_rate)
+        )
+
+    def save_samples(self, dir_to_save):
+        json_data = []
+        for row in tqdm(self.hf_dataset, desc="processing data"):
+            wav = torch.tensor([row["audio"]["array"]])
+            filename = os.path.join(dir_to_save, row["audio"]["path"])
+            gt = row["transcript"]
+            torchaudio.save(filename, wav, self.sample_rate)
+            json_data.append({"filename": filename, "ground_truth": gt})
+
+        with open(os.path.join(dir_to_save, "metadata.json"), "w") as fp:
+            json.dump(json_data, fp)
+
+
+def evaluate_model(asr_model, asr_dataset):
+    with tempfile.TemporaryDirectory() as dir_to_save:
+        asr_dataset.save_samples(dir_to_save)
+        with open(os.path.join(dir_to_save, "metadata.json")) as f:
+            json_data = json.load(f)
+
+        errors = []
+        character_errors = []
+        for row in tqdm(json_data, desc="running model"):
+            filename = row["filename"]
+            gt = row["ground_truth"]
+            pred = asr_model.infer(filename)
+            error = wer(gt, pred)
+            character_error = cer(gt, pred)
+            # print("gt:", gt)
+            # print("pred:", pred)
+            # print("wer:", error)
+            # print("cer:", character_error)
+            # print('-------')
+            errors.append(error)
+            character_errors.append(character_error)
+
+        return np.mean(errors), np.mean(character_errors)
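+
+# Note: unlike eval_analyze_outputs.py, the samples here are written to a temporary directory and
+# discarded after scoring; only the aggregate WER/CER numbers are returned.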
"finetuned_whispher": WhsiperASRModel( + # model_path="adithyal1998Bhat/whisper-kn" + # ), + # "indic_conformer": IndicConformer(model_path="ai4bharat/indic-conformer-600m-multilingual"), + # # "whisper-large-v3": SpeachesASRModel( + # # base_url="http://100.64.0.7:49827", + # # model_name="Systran/faster-whisper-large-v3", + # # ), + # "whisper-medium-vaani-kannada": WhsiperASRModel( + # model_path="ARTPARK-IISc/whisper-medium-vaani-kannada" + # ), + # "whisper-small-vaani-kannada": WhsiperASRModel( + # model_path="ARTPARK-IISc/whisper-small-vaani-kannada" + # ), + # "base_whisper_50_epochs": WhsiperASRModel( + # model_path="/home/phoenix/tmp_whatever/stt/training/model_output_base_model" + # ), + # "base_youtube_50_epochs": WhsiperASRModel( + # model_path="/home/phoenix/tmp_whatever/stt/training/model_output1" + # ), + # "base_youtube_5_epochs": WhsiperASRModel( + # model_path="/home/phoenix/tmp_whatever/stt/training/base_youtube_5_epochs" + # ), + # "base_whisper_5_epochs": WhsiperASRModel( + # model_path="/home/phoenix/tmp_whatever/stt/training/base_whisper_5_epochs" + # ), + # "openai/whisper-medium": WhsiperASRModel( + # model_path="openai/whisper-medium" + # ) + # "base_youtube_20_epochs_lr_1e_7": WhsiperASRModel( + # model_path="/home/phoenix/tmp_whatever/stt/training/base_youtube_20_epochs_lr_1e_7" + # ), + # "base_youtube_30_epochs_lr_1e_7": WhsiperASRModel( + # model_path="/home/phoenix/tmp_whatever/stt/training/base_youtube_30_epochs_lr_1e_7" + # ), + "base_whisper_50_epochs_lr_1e_7": WhsiperASRModel( + model_path="/home/phoenix/tmp_whatever/stt/training/base_whisper_50_epochs_lr_1e_7" + ), + + } + + asr_datasets = { + "google_tts": GoogleTTSASRDataset("adithyal1998Bhat/tts_synthetic_kn"), + "youtube_synthetic": SyntheticYoutubeASRDataset( + "adithyal1998Bhat/stt_synthetic_kn-IN_kannada" + ), + "kathbath": KathbathASRDataset("ai4bharat/Kathbath"), + } + + #asr_datasets['youtube_synthetic'].save_samples('recycle/youtube') + for dataset_name, asr_dataset in asr_datasets.items(): + for model_name, asr_model in asr_models.items(): + print("model", model_name, "dataset", dataset_name) + print('========\n'*3) + error = evaluate_model(asr_model, asr_dataset) + print(model_name, dataset_name," wer :", error[0]," cer :" ,error[1]) diff --git a/src/ai/stt/training.py b/src/ai/stt/training.py new file mode 100644 index 0000000..e62e6b9 --- /dev/null +++ b/src/ai/stt/training.py @@ -0,0 +1,198 @@ +import pkg_resources +import os +import sys +pkg_resources.require("transformers==4.48.0") +os.system("pip install -U 'protobuf>=3.4.0'") +import transformers +from huggingface_hub import notebook_login +from datasets import load_dataset, DatasetDict +from transformers import WhisperFeatureExtractor +from transformers import WhisperTokenizer +from transformers import WhisperForConditionalGeneration +import torch +from dataclasses import dataclass +from typing import Any, Dict, List, Union +import evaluate + + +from transformers import ( + Seq2SeqTrainingArguments, + Seq2SeqTrainer, + WhisperTokenizer, + WhisperProcessor, +) +from datasets import Audio +from transformers import pipeline +import gradio as gr + +def make_dirs(path): + if not os.path.exists(path): + os.makedirs(path) + print(path, "created") + + +def sample(ds, percentage=1): + if percentage >= 1: + return ds + + return ds.select(range(int(percentage * len(ds)))) + + +# os.system("huggingface-cli ....." 
+HF_dataset_name = "adithyal1998Bhat/tts_synthetic_kn_single_sentences"
+common_voice = DatasetDict()
+common_voice["train"] = sample(load_dataset(HF_dataset_name, split="train"))
+common_voice["test"] = sample(load_dataset(HF_dataset_name, split="test"), 0.05)
+
+
+base_model_name = "openai/whisper-medium"
+feature_extractor = WhisperFeatureExtractor.from_pretrained(base_model_name)
+
+# the tokenizer maps the target transcription text to label token ids (and back when decoding)
+tokenizer = WhisperTokenizer.from_pretrained(
+    base_model_name, language="Kannada", task="transcribe"
+)
+processor = WhisperProcessor.from_pretrained(
+    base_model_name, language="Kannada", task="transcribe"
+)
+common_voice = common_voice.cast_column("audio", Audio(sampling_rate=16000))
+
+
+def prepare_dataset(batch):
+    audio = batch["audio"]
+    # compute log-Mel input features from the raw audio array
+    # (the Whisper encoder expects log-Mel spectrograms, not raw waveforms)
+    batch["input_features"] = feature_extractor(
+        audio["array"], sampling_rate=audio["sampling_rate"]
+    ).input_features[0]
+
+    # encode target text to label ids
+    batch["labels"] = tokenizer(batch["prompt"]).input_ids
+    return batch
+
+
+common_voice = common_voice.map(prepare_dataset, num_proc=2, writer_batch_size=20)
+
+
+model = WhisperForConditionalGeneration.from_pretrained(base_model_name)
+
+
+def drop_large_sentences(ds):
+    # drop samples whose label sequence is longer than the model's maximum target length
+    max_size = model.config.max_target_positions
+    idxs_small = [idx for idx, d in enumerate(ds["labels"]) if len(d) < max_size]
+    print("dropping {} samples with large sentences".format(len(ds) - len(idxs_small)))
+    return ds.select(idxs_small)
+
+
+common_voice["train"] = drop_large_sentences(common_voice["train"])
+common_voice["test"] = drop_large_sentences(common_voice["test"])
+print(common_voice)
+
+
+model.generation_config.language = "Kannada"
+model.generation_config.task = "transcribe"
+model.generation_config.forced_decoder_ids = None
+
+
+@dataclass
+class DataCollatorSpeechSeq2SeqWithPadding:
+    processor: Any
+    decoder_start_token_id: int
+
+    def __call__(
+        self, features: List[Dict[str, Union[List[int], torch.Tensor]]]
+    ) -> Dict[str, torch.Tensor]:
+        input_features = [
+            {"input_features": feature["input_features"]} for feature in features
+        ]
+        batch = self.processor.feature_extractor.pad(
+            input_features, return_tensors="pt"
+        )
+
+        label_features = [{"input_ids": feature["labels"]} for feature in features]
+        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")
+
+        labels = labels_batch["input_ids"].masked_fill(
+            labels_batch.attention_mask.ne(1), -100
+        )
+
+        # bos = beginning-of-sequence (decoder start) token; if it was already prepended during
+        # tokenization, cut it off here since it is added again during training
+        if (labels[:, 0] == self.decoder_start_token_id).all().cpu().item():
+            labels = labels[:, 1:]
+
+        batch["labels"] = labels
+        return batch
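+
+# Padded label positions are set to -100 above so that they are ignored by the
+# cross-entropy loss during training.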
+
+
+data_collator = DataCollatorSpeechSeq2SeqWithPadding(
+    processor=processor,
+    decoder_start_token_id=model.config.decoder_start_token_id,
+)
+metric = evaluate.load("wer")
+
+
+def compute_metrics(pred):
+    pred_ids = pred.predictions
+    label_ids = pred.label_ids
+
+    label_ids[label_ids == -100] = tokenizer.pad_token_id
+
+    # do not group tokens when computing the metrics
+    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
+    label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)
+
+    wer = 100 * metric.compute(predictions=pred_str, references=label_str)
+    return {"wer": wer}
+
+
+data_dir = sys.argv[1]
+
+make_dirs(data_dir)
+training_args = Seq2SeqTrainingArguments(
+    output_dir=data_dir,
+    per_device_train_batch_size=1,
+    gradient_accumulation_steps=16,  # effective batch size = 1 * 16 = 16
+    learning_rate=1.00e-7,
+    warmup_steps=500,
+    max_steps=10000,
+    gradient_checkpointing=True,
+    fp16=True,
+    eval_strategy="steps",
+    per_device_eval_batch_size=1,
+    predict_with_generate=True,
+    generation_max_length=400,
+    save_steps=10000,
+    eval_steps=1000,
+    logging_steps=25,
+    report_to=["wandb"],
+    load_best_model_at_end=True,
+    metric_for_best_model="wer",
+    greater_is_better=False,
+    push_to_hub=False,
+)
+trainer = Seq2SeqTrainer(
+    args=training_args,
+    model=model,
+    train_dataset=common_voice["train"],
+    eval_dataset=common_voice["test"],
+    data_collator=data_collator,
+    compute_metrics=compute_metrics,
+    tokenizer=processor.feature_extractor,
+)
+
+
+processor.save_pretrained(training_args.output_dir)
+# trainer.train(resume_from_checkpoint=True)
+trainer.train()
+
+# kwargs = {
+#     "dataset_tags": HF_dataset_name,
+#     "dataset": "kannada voices",  # a 'pretty' name for the training dataset
+#     "dataset_args": "config: kn, split: test",
+#     "language": "kn",
+#     "model_name": "Whisper Small kn - Saraswathi",  # a 'pretty' name for our model
+#     "finetuned_from": "openai/whisper-small",
+#     "tasks": "automatic-speech-recognition",
+# }
+
+# trainer.push_to_hub(**kwargs)
\ No newline at end of file
diff --git a/src/ai/tts/README.md b/src/ai/tts/README.md
index fb36ec7..b4fc293 100644
--- a/src/ai/tts/README.md
+++ b/src/ai/tts/README.md
@@ -199,8 +199,23 @@ bash download_wiki.sh https://dumps.wikimedia.org/knwiki/20250620/knwiki-2025062
 
 Here /tmp/tts_data is the directory where the text wikipedia data is downloaded and extracted to.
 
-Prepare the data
+Prepare the data:
 ```
 python3 /synthetic_data/text_cleaning/clean.py /tmp/tts_data/
-```
\ No newline at end of file
+```
+
+This produces the cleaned, final dataset data_set_list_kannada_wiki_final_dataset.json, which is used as the input for audio synthesis.
+
+In your shell, export the path to the .json credentials file used to access the Google TTS API, as shown below:
+
+```
+export GOOGLE_APPLICATION_CREDENTIALS="/home/to/wherever/tts_gcp.json"
+```
+
+Now synthesize the audio and push it to the Hugging Face Hub (guide to Hugging Face login: https://huggingface.co/docs/huggingface_hub/en/guides/cli):
+
+```
+:~/llama.lisp/src/ai/tts$ python synthetic_data/google_tts/generate.py /tmp/tts_data
+```
+
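+Since the synthesized dataset is pushed to the Hub, make sure you are logged in to Hugging Face beforehand, for example with:
+
+```
+huggingface-cli login
+```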