diff --git a/.env.example b/.env.example
index c7e1031..38aca7b 100644
--- a/.env.example
+++ b/.env.example
@@ -14,3 +14,8 @@ OPENAI_API_KEY=
 OPENAI_MODEL=whisper-1
 OPENAI_API_ENDPOINT=https://api.openai.com/v1/audio/transcriptions
 OPENAI_MODEL_NAME_CHAT=gpt-4o
+
+CLOUDFLARE_ACCOUNT_ID=
+CLOUDFLARE_API_TOKEN=
+CLOUDFLARE_WHISPER_MODEL=@cf/openai/whisper
+CLOUDFLARE_MAX_FILE_SIZE_MB=25
diff --git a/README.md b/README.md
index e5b7168..391c8ec 100644
--- a/README.md
+++ b/README.md
@@ -1,21 +1,22 @@
 # Video Transcription Tool
 
-This tool automates the process of transcribing video files using multiple transcription services: Azure OpenAI, Groq, and OpenAI APIs. It converts video files to MP3 format, transcribes the audio, and saves the transcription as a text file.
+This tool automates the process of transcribing video files using multiple transcription services: Azure OpenAI, Cloudflare Workers AI, Groq, and OpenAI APIs. It converts video files to MP3 format, transcribes the audio, and saves the transcription as a text file.
 
 ## Features
 
 - Converts video files to MP3 format using ffmpeg
-- Supports transcription using Azure OpenAI, Groq, and OpenAI APIs
+- Supports transcription using Azure OpenAI, Cloudflare Workers AI, Groq, and OpenAI APIs
 - Supports processing of individual video files or entire directories
 - Cleans up temporary MP3 files after transcription
 - Provides flexibility in selecting the transcription service via configuration
 
 ## Prerequisites
 
-- Python 3.6+
+- Python 3.8+
 - ffmpeg installed and available in the system PATH
 - API access for one or more supported services:
   - Azure OpenAI API
+  - Cloudflare Workers AI API
   - Groq Cloud API
   - OpenAI API
 
@@ -53,6 +54,12 @@ This tool automates the process of transcribing video files using multiple trans
    OPENAI_MODEL=whisper-1
    OPENAI_API_ENDPOINT=https://api.openai.com/v1/audio/transcriptions
    OPENAI_MODEL_NAME_CHAT=gpt-4o
+
+   # Cloudflare Workers AI
+   CLOUDFLARE_ACCOUNT_ID=your_cloudflare_account_id_here
+   CLOUDFLARE_API_TOKEN=your_cloudflare_api_token_here
+   CLOUDFLARE_WHISPER_MODEL=@cf/openai/whisper
+   CLOUDFLARE_MAX_FILE_SIZE_MB=25
    ```
 
 ## Building and Installing the Package
@@ -113,6 +120,7 @@ sapat [--language ] [--prompt ] [--t
 - `--quality`: Quality of the MP3 audio: 'L' for low, 'M' for medium, and 'H' for high (default: 'M').
 - `--api`: Specify the API to use for transcription.
   - `--api azure` for Azure OpenAI API
+  - `--api cloudflare` for Cloudflare Workers AI
   - `--api groq` for Groq Cloud API
   - `--api openai` for OpenAI API
 
@@ -122,6 +130,12 @@ Example:
 sapat my_video.mp4 --quality H --language es --prompt "This is a test prompt" --temperature 0.5 --api groq
 ```
 
+Cloudflare Workers AI example:
+
+```
+sapat my_video.mp4 --quality M --api cloudflare
+```
+
 - If a file is provided, it will process that single file.
 - If a directory is provided, it will process all `.mp4` files in that directory.
 
@@ -129,7 +143,7 @@ The script will create a `.txt` file with the same name as the input video file,
 
 ## Note
 
-This tool is designed for use with multiple APIs (Azure OpenAI, Groq, and OpenAI). Ensure you have valid API credentials configured in the `.env` file and the necessary permissions and credits for the API service you plan to use.
+This tool is designed for use with multiple APIs (Azure OpenAI, Cloudflare Workers AI, Groq, and OpenAI). Ensure you have valid API credentials configured in the `.env` file and the necessary permissions and credits for the API service you plan to use.
 
 ## License
 
diff --git a/pyproject.toml b/pyproject.toml
index 1db16a2..f8d1625 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -2,7 +2,7 @@
 name = "sapat"
 version = "0.1.2"
 description = "Video Transcription Tool using different APIs"
-requires-python = ">=3.6"
+requires-python = ">=3.8"
 dependencies = [
     "click>=8.1.8",
     "requests>=2.32.3",
diff --git a/src/sapat/script.py b/src/sapat/script.py
index ed22032..c631532 100644
--- a/src/sapat/script.py
+++ b/src/sapat/script.py
@@ -2,6 +2,7 @@
 from pathlib import Path
 from .transcription.groq import GroqCloudTranscription
 from .transcription.azure import AzureTranscription
+from .transcription.cloudflare import CloudflareTranscription
 from .transcription.openai import OpenAITranscription
 
 @click.command()
@@ -11,7 +12,7 @@
 @click.option("--temperature", "-t", type=float, default=0.3, help="Sampling temperature (default: 0.3)")
 @click.option("--quality", "-q", type=click.Choice(['L', 'M', 'H'], case_sensitive=False), default='M', help="Quality of the MP3 audio: 'L' for low, 'M' for medium, and 'H' for high (default: 'M')")
 @click.option("--correct", is_flag=True, help="Use LLM to correct the transcript")
-@click.option("--api", "-a", type=click.Choice(['openai', 'groq', 'azure'], case_sensitive=True), required=True, help="API to use for the transcription ('openai', 'groq' or 'azure')")
+@click.option("--api", "-a", type=click.Choice(['openai', 'groq', 'azure', 'cloudflare'], case_sensitive=True), required=True, help="API to use for the transcription ('openai', 'groq', 'azure' or 'cloudflare')")
 def main(input_path, language, prompt, temperature, quality, correct, api):
     """
     Transcribe video files using different APIs.
@@ -26,6 +27,8 @@ def main(input_path, language, prompt, temperature, quality, correct, api):
         transcriber = GroqCloudTranscription(temperature=temperature)
     elif api.lower() == "azure":
         transcriber = AzureTranscription(temperature=temperature)
+    elif api.lower() == "cloudflare":
+        transcriber = CloudflareTranscription(temperature=temperature)
     elif api.lower() == "openai":
         transcriber = OpenAITranscription(temperature=temperature)
     else:
diff --git a/src/sapat/transcription/cloudflare.py b/src/sapat/transcription/cloudflare.py
new file mode 100644
index 0000000..fad9eb1
--- /dev/null
+++ b/src/sapat/transcription/cloudflare.py
@@ -0,0 +1,132 @@
+import os
+from pathlib import Path
+
+import requests
+from dotenv import load_dotenv
+
+from .base import TranscriptionBase
+
+# Load environment variables
+load_dotenv(".env")
+
+# requests never times out by default; bound the wait so a stalled
+# connection cannot hang the CLI indefinitely.
+REQUEST_TIMEOUT_SECONDS = 300
+
+
+class CloudflareTranscription(TranscriptionBase):
+    """
+    Cloudflare Workers AI implementation for transcription.
+    """
+
+    def __init__(self, temperature: float, response_format: str = "json"):
+        """
+        Initializes the CloudflareTranscription class.
+
+        Parameters:
+        - temperature (float): Default temperature value for transcription.
+        - response_format (str): Default response format for transcription.
+
+        Both values are stored on the instance only; transcribe_audio sends the
+        raw audio bytes and does not forward them to Cloudflare.
+        """
+        self.api_token = os.getenv("CLOUDFLARE_API_TOKEN")
+        self.account_id = os.getenv("CLOUDFLARE_ACCOUNT_ID")
+        self.model = os.getenv("CLOUDFLARE_WHISPER_MODEL", "@cf/openai/whisper")
+        self.endpoint = os.getenv("CLOUDFLARE_API_ENDPOINT")
+        self.temperature = temperature
+        self.response_format = response_format
+        self.max_file_size_mb = int(os.getenv("CLOUDFLARE_MAX_FILE_SIZE_MB", "25"))
+
+    def transcribe_audio(self, audio_file: str, **kwargs):
+        """
+        Transcribes an audio file using Cloudflare Workers AI Whisper.
+
+        Parameters:
+        - audio_file (str): Path to the audio file.
+        - kwargs: Additional parameters accepted for CLI compatibility.
+
+        Returns:
+        - dict or str: The transcription result.
+
+        Raises:
+        - ValueError: If the file or the configuration is invalid.
+        - Exception: If the file is too large or the request fails.
+        """
+        self._validate_audio_file(audio_file)
+        self._validate_configuration()
+
+        headers = {
+            "Authorization": f"Bearer {self.api_token}",
+            "Content-Type": self._content_type(audio_file),
+        }
+
+        with open(audio_file, "rb") as f:
+            response = requests.post(
+                self._transcription_endpoint(),
+                headers=headers,
+                data=f.read(),
+                timeout=REQUEST_TIMEOUT_SECONDS,
+            )
+
+        if response.status_code != 200:
+            raise Exception(f"Transcription failed: {response.text}")
+
+        payload = response.json()
+        if isinstance(payload, dict):
+            # The Cloudflare API envelope carries a "success" flag; surface a
+            # reported failure instead of returning an empty result.
+            if payload.get("success") is False:
+                raise Exception(f"Transcription failed: {payload.get('errors')}")
+            if "result" in payload:
+                return payload["result"]
+        return payload
+
+    def _transcription_endpoint(self):
+        """Returns the configured endpoint, or the Workers AI run URL for the model."""
+        if self.endpoint:
+            return self.endpoint
+
+        return f"https://api.cloudflare.com/client/v4/accounts/{self.account_id}/ai/run/{self.model}"
+
+    def _validate_configuration(self):
+        """Raises ValueError when the API token or the account/endpoint is missing."""
+        if not self.api_token:
+            raise ValueError("CLOUDFLARE_API_TOKEN is required for Cloudflare transcription.")
+
+        if not self.account_id and not self.endpoint:
+            raise ValueError("CLOUDFLARE_ACCOUNT_ID is required unless CLOUDFLARE_API_ENDPOINT is set.")
+
+    def _validate_audio_file(self, audio_file: str):
+        """
+        Validates the audio file for size and format.
+
+        Parameters:
+        - audio_file (str): Path to the audio file.
+
+        Raises:
+        - ValueError: If the file is missing or has an unsupported extension.
+        - Exception: If the file exceeds the configured size limit.
+        """
+        if not os.path.exists(audio_file):
+            raise ValueError(f"File {audio_file} does not exist.")
+
+        file_size_mb = os.path.getsize(audio_file) / (1024 * 1024)
+        if file_size_mb > self.max_file_size_mb:
+            raise Exception(f"File size exceeds the maximum limit of {self.max_file_size_mb} MB.")
+
+        valid_extensions = [".mp3", ".wav", ".flac"]
+        # Compare case-insensitively so "talk.MP3" passes, matching _content_type.
+        if not str(audio_file).lower().endswith(tuple(valid_extensions)):
+            raise ValueError(f"Unsupported audio file format: {audio_file}. Supported formats are {valid_extensions}.")
+
+    @staticmethod
+    def _content_type(audio_file: str):
+        """Maps the file extension to a MIME type, defaulting to octet-stream."""
+        extension = Path(audio_file).suffix.lower()
+        content_types = {
+            ".flac": "audio/flac",
+            ".mp3": "audio/mpeg",
+            ".wav": "audio/wav",
+        }
+        return content_types.get(extension, "application/octet-stream")
diff --git a/tests/test_cloudflare_transcription.py b/tests/test_cloudflare_transcription.py
new file mode 100644
index 0000000..3bd3ba3
--- /dev/null
+++ b/tests/test_cloudflare_transcription.py
@@ -0,0 +1,115 @@
+import os
+import tempfile
+import unittest
+from unittest.mock import Mock, patch
+
+from sapat.transcription.cloudflare import (
+    REQUEST_TIMEOUT_SECONDS,
+    CloudflareTranscription,
+)
+
+
+class CloudflareTranscriptionTest(unittest.TestCase):
+    def setUp(self):
+        self.previous_env = {
+            "CLOUDFLARE_ACCOUNT_ID": os.environ.get("CLOUDFLARE_ACCOUNT_ID"),
+            "CLOUDFLARE_API_TOKEN": os.environ.get("CLOUDFLARE_API_TOKEN"),
+            "CLOUDFLARE_API_ENDPOINT": os.environ.get("CLOUDFLARE_API_ENDPOINT"),
+            "CLOUDFLARE_WHISPER_MODEL": os.environ.get("CLOUDFLARE_WHISPER_MODEL"),
+            "CLOUDFLARE_MAX_FILE_SIZE_MB": os.environ.get("CLOUDFLARE_MAX_FILE_SIZE_MB"),
+        }
+        os.environ["CLOUDFLARE_ACCOUNT_ID"] = "account-123"
+        os.environ["CLOUDFLARE_API_TOKEN"] = "token-abc"
+        os.environ.pop("CLOUDFLARE_API_ENDPOINT", None)
+        os.environ.pop("CLOUDFLARE_WHISPER_MODEL", None)
+        os.environ.pop("CLOUDFLARE_MAX_FILE_SIZE_MB", None)
+
+    def tearDown(self):
+        for key, value in self.previous_env.items():
+            if value is None:
+                os.environ.pop(key, None)
+            else:
+                os.environ[key] = value
+
+    @patch("sapat.transcription.cloudflare.requests.post")
+    def test_transcribe_audio_posts_binary_audio_to_workers_ai(self, post):
+        post.return_value = Mock(
+            status_code=200,
+            text="ok",
+            json=Mock(return_value={"success": True, "result": {"text": "hello"}}),
+        )
+
+        with tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as audio:
+            audio.write(b"audio-bytes")
+            path = audio.name
+
+        try:
+            result = CloudflareTranscription(temperature=0.3).transcribe_audio(path)
+        finally:
+            os.unlink(path)
+
+        self.assertEqual({"text": "hello"}, result)
+        post.assert_called_once_with(
+            "https://api.cloudflare.com/client/v4/accounts/account-123/ai/run/@cf/openai/whisper",
+            headers={
+                "Authorization": "Bearer token-abc",
+                "Content-Type": "audio/mpeg",
+            },
+            data=b"audio-bytes",
+            timeout=REQUEST_TIMEOUT_SECONDS,
+        )
+
+    @patch("sapat.transcription.cloudflare.requests.post")
+    def test_uppercase_extension_is_accepted(self, post):
+        post.return_value = Mock(
+            status_code=200,
+            text="ok",
+            json=Mock(return_value={"success": True, "result": {"text": "hello"}}),
+        )
+
+        with tempfile.NamedTemporaryFile(suffix=".MP3", delete=False) as audio:
+            audio.write(b"audio-bytes")
+            path = audio.name
+
+        try:
+            result = CloudflareTranscription(temperature=0.3).transcribe_audio(path)
+        finally:
+            os.unlink(path)
+
+        self.assertEqual({"text": "hello"}, result)
+        self.assertEqual("audio/mpeg", post.call_args.kwargs["headers"]["Content-Type"])
+
+    @patch("sapat.transcription.cloudflare.requests.post")
+    def test_unsuccessful_envelope_raises(self, post):
+        post.return_value = Mock(
+            status_code=200,
+            text="ok",
+            json=Mock(return_value={"success": False, "errors": [{"message": "boom"}]}),
+        )
+
+        with tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as audio:
+            audio.write(b"audio-bytes")
+            path = audio.name
+
+        try:
+            with self.assertRaisesRegex(Exception, "boom"):
+                CloudflareTranscription(temperature=0.3).transcribe_audio(path)
+        finally:
+            os.unlink(path)
+
+    def test_missing_token_raises_actionable_error(self):
+        os.environ.pop("CLOUDFLARE_API_TOKEN", None)
+
+        with tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as audio:
+            audio.write(b"audio-bytes")
+            path = audio.name
+
+        try:
+            with self.assertRaisesRegex(ValueError, "CLOUDFLARE_API_TOKEN"):
+                CloudflareTranscription(temperature=0.3).transcribe_audio(path)
+        finally:
+            os.unlink(path)
+
+
+if __name__ == "__main__":
+    unittest.main()