diff --git a/pyproject.toml b/pyproject.toml index 77f2855..21f4738 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,6 +34,8 @@ dependencies = [ "isort==5.13.2", "tomli==2.2.1", "claude-agent-sdk>=0.1.0", + "tiktoken==0.12.0", + "genai-prices==0.0.51", ] classifiers = [ "Development Status :: 5 - Production/Stable", diff --git a/python_gpt_po/main.py b/python_gpt_po/main.py index b265c04..2eb4f45 100644 --- a/python_gpt_po/main.py +++ b/python_gpt_po/main.py @@ -11,10 +11,9 @@ import traceback from argparse import Namespace from dataclasses import dataclass -from typing import Dict, List, Optional +from typing import Any, Dict, List, Optional, Tuple from .models.config import TranslationConfig, TranslationFlags -from .models.enums import ModelProvider from .models.provider_clients import ProviderClients from .services.language_detector import LanguageDetector from .services.model_manager import ModelManager @@ -22,6 +21,7 @@ from .utils.cli import (auto_select_provider, create_language_mapping, get_provider_from_args, parse_args, show_help_and_exit, validate_provider_key) from .utils.config_loader import ConfigLoader +from .utils.cost_estimator import CostEstimator def setup_logging(verbose: int = 0, quiet: bool = False): @@ -53,20 +53,11 @@ def setup_logging(verbose: int = 0, quiet: bool = False): logging.getLogger().setLevel(level) -def initialize_provider(args: Namespace) -> tuple[ProviderClients, ModelProvider, str]: +def get_offline_provider_info(args: Namespace) -> Tuple[Any, Any, str]: """ - Initialize the provider client and determine the appropriate model. - - Args: - args: Command line arguments from argparse - - Returns: - tuple: (provider_clients, provider, model) - - Raises: - SystemExit: If no valid provider can be found or initialized + Get provider and model information without making network calls. 
""" - # Initialize provider clients + # Initialize provider clients (reads environment variables and args) provider_clients = ProviderClients() api_keys = provider_clients.initialize_clients(args) @@ -82,40 +73,43 @@ def initialize_provider(args: Namespace) -> tuple[ProviderClients, ModelProvider if not validate_provider_key(provider, api_keys): sys.exit(1) + # Determine model - use CLI arg or default + model = args.model + if not model: + model = ModelManager.get_default_model(provider) + + return provider_clients, provider, model + + +def initialize_provider(args: Namespace, provider_clients: Any, provider: Any, model: str) -> Tuple[Any, Any, str]: + """ + Finalize provider initialization with network validation if needed. + """ # Create model manager for model operations model_manager = ModelManager() - # List models if requested and exit + # List models if requested and exit (this makes network calls) if args.list_models: models = model_manager.get_available_models(provider_clients, provider) print(f"Available models for {provider.value}:") - for model in models: - print(f" - {model}") + for m in models: + print(f" - {m}") sys.exit(0) - # Determine appropriate model - model = get_appropriate_model(provider, provider_clients, model_manager, args.model) + # Validate model (this makes network calls) + final_model = get_appropriate_model(provider, provider_clients, model_manager, model) - return provider_clients, provider, model + return provider_clients, provider, final_model def get_appropriate_model( - provider: ModelProvider, - provider_clients: ProviderClients, - model_manager: ModelManager, + provider: Any, + provider_clients: Any, + model_manager: Any, requested_model: Optional[str] ) -> str: """ Get the appropriate model for the provider. 
- - Args: - provider (ModelProvider): The selected provider - provider_clients (ProviderClients): The initialized provider clients - model_manager (ModelManager): The model manager instance - requested_model (Optional[str]): Model requested by the user - - Returns: - str: The appropriate model ID """ # If a specific model was requested, validate it if requested_model: @@ -143,7 +137,7 @@ def get_appropriate_model( @dataclass class TranslationTask: """Parameters for translation processing.""" - config: TranslationConfig + config: Any folder: str languages: List[str] detail_languages: Dict[str, str] @@ -154,9 +148,6 @@ class TranslationTask: def process_translations(task: TranslationTask): """ Process translations for the given task parameters. - - Args: - task: TranslationTask containing all processing parameters """ # Initialize translation service translation_service = TranslationService(task.config, task.batch_size) @@ -192,12 +183,9 @@ def main(): setup_logging(verbose=args.verbose, quiet=args.quiet) try: - # Initialize provider - provider_clients, provider, model = initialize_provider(args) - - # Get languages - either from args or auto-detect from PO files + # 1. Get languages (Pure logic) try: - respect_gitignore = not args.no_gitignore # Invert the flag + respect_gitignore = not args.no_gitignore languages = LanguageDetector.validate_or_detect_languages( folder=args.folder, lang_arg=args.lang, @@ -208,7 +196,63 @@ def main(): logging.error(str(e)) sys.exit(1) - # Create mapping between language codes and detailed names + # 2. Extract model name for offline estimation (Purely offline) + # Defaults to gpt-4o-mini if not specified. Avoids ModelManager to prevent early side-effects. + estimated_model = args.model or "gpt-4o-mini" + + # 3. 
Estimate cost if requested (Strictly Offline Terminal Flow) + if args.estimate_cost: + estimation = CostEstimator.estimate_cost( + args.folder, + languages, + estimated_model, + fix_fuzzy=args.fix_fuzzy, + respect_gitignore=respect_gitignore + ) + + print(f"\n{'=' * 40}") + print(" OFFLINE TOKEN ESTIMATION REPORT") + print(f"{'=' * 40}") + print(f"Model: {estimation['model']}") + print(f"Rate: {estimation['rate_info']}") + print(f"Unique msgids: {estimation['unique_texts']:,}") + print(f"Total Tokens: {estimation['total_tokens']:,} (estimated expansion included)") + + if estimation['estimated_cost'] is not None: + print(f"Estimated Cost: ${estimation['estimated_cost']:.4f}") + + print("\nPer-language Breakdown:") + for lang, data in estimation['breakdown'].items(): + cost_str = f"${data['cost']:.4f}" if data['cost'] is not None else "unavailable" + print(f" - {lang:5}: {data['tokens']:8,} tokens | {cost_str}") + + print("\nNote: Cost estimates are approximate and may not reflect current provider pricing.") + print(f"{'=' * 40}\n") + + if estimation['total_tokens'] == 0: + logging.info("No entries require translation.") + return + + if not args.yes: + confirm = input("Run actual translation with these settings? (y/n): ").lower() + if confirm != 'y': + logging.info("Cancelled by user.") + return + + # Issue #57: Hard exit after estimation to ensure zero side effects. + # Estimation is a terminal dry-run. This prevents "Registered provider" logs + # or connection attempts from leaking into the audit output. + print( + "\n[Audit Successful] To proceed with actual translation, " + "run the command again WITHOUT --estimate-cost." + ) + return + + # 4. Initialize providers (Online Execution Path Starts Here) + provider_clients, provider, final_model_id = get_offline_provider_info(args) + provider_clients, provider, model = initialize_provider(args, provider_clients, provider, final_model_id) + + # 5. 
import os
import shutil
import tempfile
import unittest

import polib

from python_gpt_po.utils.cost_estimator import CostEstimator


class TestCostEstimatorMinimal(unittest.TestCase):
    """Offline unit tests for CostEstimator — no network access required."""

    def setUp(self):
        # A unique temp directory avoids the collisions and repo pollution a
        # fixed cwd-relative path ("test_cost_est_minimal") causes when tests
        # run in parallel or a previous run was interrupted.
        self.test_dir = tempfile.mkdtemp(prefix="test_cost_est_minimal_")

    def tearDown(self):
        # ignore_errors: teardown must never mask the test's own failure.
        shutil.rmtree(self.test_dir, ignore_errors=True)

    def _make_po(self, msgid, msgstr=""):
        """Write a single-entry PO file into the fixture directory."""
        po_path = os.path.join(self.test_dir, "test.po")
        po = polib.POFile()
        po.append(polib.POEntry(msgid=msgid, msgstr=msgstr))
        po.save(po_path)

    def test_minimal_token_math(self):
        """Verify tokenize once and multiply by languages."""
        # "Hello" is approx 1-2 tokens.
        self._make_po("Hello")

        # 1 language
        est1 = CostEstimator.estimate_cost(self.test_dir, ["fr"], "gpt-4o-mini")
        t1 = est1['total_tokens']

        # 3 languages
        est3 = CostEstimator.estimate_cost(self.test_dir, ["fr", "es", "de"], "gpt-4o-mini")
        t3 = est3['total_tokens']

        self.assertEqual(t3, t1 * 3)

    def test_pricing_lookup(self):
        """Verify dynamic pricing lookup via genai-prices."""
        self._make_po("Test")

        # Known model -> a concrete cost estimate.
        est_known = CostEstimator.estimate_cost(self.test_dir, ["fr"], "gpt-4o-mini")
        self.assertIsNotNone(est_known['estimated_cost'])

        # Unknown model -> pricing degrades to None instead of raising.
        est_unknown = CostEstimator.estimate_cost(self.test_dir, ["fr"], "unknown-model")
        self.assertIsNone(est_unknown['estimated_cost'])

    def test_zero_work(self):
        """Verify zero tokens when everything is translated."""
        self._make_po("Hello", "Bonjour")

        est = CostEstimator.estimate_cost(self.test_dir, ["fr"], "gpt-4o-mini")
        self.assertEqual(est['total_tokens'], 0)


if __name__ == '__main__':
    unittest.main()
action="store_true", + help="Estimate token usage and cost before translating" + ) + advanced_group.add_argument( + "-y", "--yes", + action="store_true", + help="Skip confirmation prompt when using --estimate-cost" + ) fuzzy_group.add_argument( "--fuzzy", action="store_true", diff --git a/python_gpt_po/utils/cost_estimator.py b/python_gpt_po/utils/cost_estimator.py new file mode 100644 index 0000000..b18ed17 --- /dev/null +++ b/python_gpt_po/utils/cost_estimator.py @@ -0,0 +1,128 @@ +import logging +import os +from typing import Dict, List, Optional, Tuple + +import genai_prices +import polib +from genai_prices.types import Usage + +try: + import tiktoken +except ImportError: + tiktoken = None + + +from .gitignore import create_gitignore_parser +from .po_entry_helpers import is_entry_untranslated + + +class CostEstimator: + # Conservative estimate for translation expansion + OUTPUT_MULTIPLIER = 1.3 + + @classmethod + def estimate_cost( + cls, + folder: str, + languages: List[str], + model: str, + fix_fuzzy: bool = False, + respect_gitignore: bool = True + ) -> Dict: + """ + Estimate token usage and cost for Issue #57. + Algorithm: tokenize(unique msgids once) * count(target languages) * pricing + """ + + # 1. Collect all untranslated msgids once (Offline scan) + unique_msgids = set() + gitignore_parser = create_gitignore_parser(folder, respect_gitignore) + + for root, dirs, files in os.walk(folder): + dirs[:], files = gitignore_parser.filter_walk_results(root, dirs, files) + for file in files: + if file.endswith('.po'): + file_path = os.path.join(root, file) + try: + po = polib.pofile(file_path) + for entry in po: + if is_entry_untranslated(entry) or (fix_fuzzy and 'fuzzy' in entry.flags): + if entry.msgid: + unique_msgids.add(entry.msgid) + except Exception as e: + logging.warning("Error reading %s for estimation: %s", file_path, e) + + # 2. 
Tokenize the entire source content once + combined_text = "".join(unique_msgids) + source_tokens = cls._get_token_count(combined_text, model) + + # 3. Calculate total tokens (including expansion) + total_input_tokens = source_tokens * len(languages) + total_output_tokens = int(total_input_tokens * cls.OUTPUT_MULTIPLIER) + total_tokens = total_input_tokens + total_output_tokens + + # 4. Lookup price for model + pricing_data = cls._get_pricing(model.lower()) + estimated_cost = None + rate_info = "unavailable" + + if pricing_data: + in_p, out_p = pricing_data + cost_in = (total_input_tokens / 1000) * in_p + cost_out = (total_output_tokens / 1000) * out_p + estimated_cost = cost_in + cost_out + rate_info = f"${in_p:.5f} (in) / ${out_p:.5f} (out) per 1K tokens" + + # 5. Calculate breakdown + breakdown = {} + for lang in languages: + lang_in = source_tokens + lang_out = int(source_tokens * cls.OUTPUT_MULTIPLIER) + lang_total = lang_in + lang_out + lang_cost = None + if pricing_data: + lang_cost = ((lang_in / 1000) * in_p) + ((lang_out / 1000) * out_p) + breakdown[lang] = { + "tokens": lang_total, + "cost": lang_cost + } + + return { + "total_tokens": total_tokens, + "estimated_cost": estimated_cost, + "rate_info": rate_info, + "model": model, + "num_languages": len(languages), + "unique_texts": len(unique_msgids), + "breakdown": breakdown + } + + @staticmethod + def _get_token_count(text: str, model: str) -> int: + """Approximate token count using tiktoken or heuristic.""" + if not text: + return 0 + if tiktoken: + try: + try: + encoding = tiktoken.encoding_for_model(model) + except KeyError: + encoding = tiktoken.get_encoding("cl100k_base") + return len(encoding.encode(text)) + except Exception: + pass + # Fallback heuristic + return max(1, len(text) // 4) + + @classmethod + def _get_pricing(cls, model: str) -> Optional[Tuple[float, float]]: + """Lookup pricing for the active model name using genai-prices.""" + try: + + # Exact match only. 
genai-prices raises LookupError if not found. + # We use a unit usage of 1000 tokens to get the price per 1k tokens. + usage = Usage(input_tokens=1000, output_tokens=1000) + price_detail = genai_prices.calc_price(usage, model) + return float(price_detail.input_price), float(price_detail.output_price) + except (ImportError, LookupError, Exception): + return None diff --git a/requirements.txt b/requirements.txt index e1f8066..98f58b0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,3 +10,5 @@ requests==2.32.3 responses==0.25.6 isort==6.0.1 tomli==2.2.1 +tiktoken==0.12.0 +genai-prices==0.0.51