From b6d353f0519b86056dbdb97cbdb5531715a1bb97 Mon Sep 17 00:00:00 2001 From: Ishaan Gupta Date: Sun, 15 Feb 2026 01:39:28 +0530 Subject: [PATCH] setup pinecone with central logging,retry,exceptions --- src/config/__init__.py | 72 ++- src/config/logging.py | 434 ++++++++++++++++ src/config/settings.py | 17 + src/storage/__init__.py | 105 ++++ src/storage/base.py | 582 +++++++++++++++++++++ src/storage/pinecone.py | 1088 +++++++++++++++++++++++++++++++++++++++ src/utils/__init__.py | 55 ++ src/utils/exceptions.py | 408 +++++++++++++++ src/utils/retry.py | 416 +++++++++++++++ 9 files changed, 3175 insertions(+), 2 deletions(-) create mode 100644 src/utils/exceptions.py diff --git a/src/config/__init__.py b/src/config/__init__.py index 1675602..165a484 100644 --- a/src/config/__init__.py +++ b/src/config/__init__.py @@ -1,3 +1,71 @@ -from src.config.settings import Settings +""" +================================================================================ +CONFIG PACKAGE - Centralized Configuration +================================================================================ + +This package provides all configuration-related functionality: +- Settings: Environment variables and configuration values +- Logging: Logging setup and utilities + +USAGE: +------ + # Import settings singleton + from src.config import settings + + # Access configuration values + api_key = settings.pinecone_api_key + model = settings.embedding_model + + # Import logging utilities + from src.config import get_logger, setup_logging + + # Setup logging at app startup + setup_logging() + + # Get logger in any module + logger = get_logger(__name__) + logger.info("Hello from my module") + +================================================================================ +""" + +# Import Settings class and create singleton instance +from .settings import Settings + +# Create the settings singleton +# This is instantiated once when the config package is first imported +# All subsequent imports 
get the same instance settings = Settings() -__all__ = ["settings", "Settings"] \ No newline at end of file + +# Import logging utilities +from .logging import ( + # Setup function (call once at app startup) + setup_logging, + + # Get logger (call in each module) + get_logger, + + # Runtime log level control + set_log_level, + disable_logging, + enable_logging, + + # Configuration classes + LogConfig, + LogLevel, +) + +__all__ = [ + # Settings + "Settings", + "settings", + + # Logging + "setup_logging", + "get_logger", + "set_log_level", + "disable_logging", + "enable_logging", + "LogConfig", + "LogLevel", +] diff --git a/src/config/logging.py b/src/config/logging.py index e69de29..1e3f70c 100644 --- a/src/config/logging.py +++ b/src/config/logging.py @@ -0,0 +1,434 @@ +""" +================================================================================ +LOGGING CONFIGURATION - Centralized Logging Setup +================================================================================ + +LOG LEVELS (from least to most severe): +--------------------------------------- + DEBUG - Detailed information for debugging + INFO - General operational messages + WARNING - Something unexpected but not critical + ERROR - Something failed but app continues + CRITICAL - App may not be able to continue + +USAGE: +------ + # In any module: + from src.config import get_logger + + logger = get_logger(__name__) + + logger.debug("Detailed debug info") + logger.info("Operation completed successfully") + logger.warning("Something unusual happened") + logger.error("Operation failed", exc_info=True) + logger.critical("Application cannot continue") + +STRUCTURED LOGGING: +------------------- + # Include context in log messages + logger.info( + "Processing request", + extra={ + "user_id": user_id, + "request_id": request_id, + "operation": "add_memory" + } + ) + +================================================================================ +""" + +import logging + +from logging.handlers 
class LogLevel(str, Enum):
    """
    Type-safe names for the standard logging levels.

    Mixing in ``str`` lets members compare equal to plain strings
    (``LogLevel.INFO == "INFO"``) and be passed straight to ``Logger.setLevel``.
    """

    DEBUG = "DEBUG"
    INFO = "INFO"
    WARNING = "WARNING"
    ERROR = "ERROR"
    CRITICAL = "CRITICAL"


@dataclass
class LogConfig:
    """
    Configuration for logging behavior.

    ATTRIBUTES:
    -----------
        level : LogLevel
            Minimum log level to output (default: INFO)
        format : str
            ``logging``-style format string used by the non-JSON formatters
        date_format : str
            Date/time format for %(asctime)s
        enable_console : bool
            Attach a stdout handler (default: True)
        enable_file : bool
            Attach a rotating file handler (default: False)
        log_file : Optional[str]
            Path to log file; ``setup_logging`` falls back to logs/xmem.log
            when this is None
        max_file_size : int
            Size in bytes at which the file handler rotates (default: 10MB)
        backup_count : int
            Number of rotated files to keep (default: 5)
        enable_json : bool
            Use JSONFormatter on every enabled handler (default: False)
    """

    level: LogLevel = LogLevel.INFO

    # Standard format: timestamp | logger name | level | message
    format: str = "%(asctime)s | %(name)-30s | %(levelname)-8s | %(message)s"
    date_format: str = "%Y-%m-%d %H:%M:%S"

    # Console logging (stdout)
    enable_console: bool = True

    # File logging
    enable_file: bool = False
    log_file: Optional[str] = None
    max_file_size: int = 10 * 1024 * 1024  # 10 MB
    backup_count: int = 5

    # JSON structured logging (useful for log aggregation)
    enable_json: bool = False


# Default configuration
DEFAULT_LOG_CONFIG = LogConfig()

# Name of the logger that every "src.*" module logger inherits from.
_APP_LOGGER_NAME = "src"

# File used when file logging is enabled without an explicit path.
_DEFAULT_LOG_FILE = "logs/xmem.log"

# Attributes the logging machinery itself puts on a LogRecord. Derived from a
# real record instead of a hand-written list so attributes added by newer
# Python versions (e.g. "taskName" in 3.12) are not mistaken for user extras.
# "message" and "asctime" are added lazily by Formatter.format().
_RESERVED_RECORD_ATTRS = frozenset(
    logging.LogRecord("", logging.NOTSET, "", 0, "", (), None).__dict__
) | {"message", "asctime"}


def _level_name(level) -> str:
    """Return the stdlib level name for a LogLevel member or a string like "debug"."""
    if isinstance(level, LogLevel):
        return level.value
    try:
        return LogLevel(str(level).upper()).value
    except ValueError:
        raise ValueError(f"Unknown log level: {level!r}") from None


class ColoredFormatter(logging.Formatter):
    """
    Formatter that wraps each formatted record in an ANSI color by level.

        DEBUG: cyan, INFO: green, WARNING: yellow, ERROR: red,
        CRITICAL: red background

    Records whose level name is not in COLORS are returned uncolored.
    Note: Colors only work in terminals that support ANSI codes.
    """

    # ANSI color codes
    COLORS = {
        "DEBUG": "\033[36m",     # Cyan
        "INFO": "\033[32m",      # Green
        "WARNING": "\033[33m",   # Yellow
        "ERROR": "\033[31m",     # Red
        "CRITICAL": "\033[41m",  # Red background
    }
    RESET = "\033[0m"  # Reset to default

    def __init__(self, fmt: str, datefmt: str):
        """Initialize with the message format and the date format."""
        super().__init__(fmt, datefmt)

    def format(self, record: logging.LogRecord) -> str:
        """
        Format the record, then surround the whole text with color codes.

        Args:
            record: The log record to format

        Returns:
            Formatted string, colored when the level has an entry in COLORS
        """
        message = super().format(record)
        color = self.COLORS.get(record.levelname, "")
        if color:
            return f"{color}{message}{self.RESET}"
        return message


class JSONFormatter(logging.Formatter):
    """
    Formatter that outputs each record as one JSON object.

    Output format (fields passed via ``extra=`` are merged at the top level):
        {
            "timestamp": "2024-01-15T10:30:00+00:00",
            "level": "INFO",
            "logger": "src.storage.pinecone",
            "message": "Added 100 vectors",
            "exception": "Traceback ...",   # only when exc_info is set
            "user_id": "u1"                 # example extra field
        }
    """

    def format(self, record: logging.LogRecord) -> str:
        """
        Format the log record as JSON.

        Args:
            record: The log record to format

        Returns:
            JSON string representation
        """
        import json
        from datetime import datetime, timezone

        log_entry = {
            # Use the time the event was created (record.created), not the
            # time it happens to be formatted, and make the UTC offset explicit.
            "timestamp": datetime.fromtimestamp(
                record.created, tz=timezone.utc
            ).isoformat(),
            "level": record.levelname,
            "logger": record.name,
            "message": record.getMessage(),
        }

        # Add exception info if present
        if record.exc_info:
            log_entry["exception"] = self.formatException(record.exc_info)

        # Extra fields are added via: logger.info("msg", extra={"key": "value"})
        for key, value in record.__dict__.items():
            if key not in _RESERVED_RECORD_ATTRS:
                log_entry[key] = value

        # default=str is deliberate: an extra value that is not JSON-native
        # (Path, datetime, exception, ...) is stringified instead of raising,
        # because a formatter that raises makes the log line disappear.
        return json.dumps(log_entry, default=str)


def _make_formatter(config: LogConfig, colored: bool) -> logging.Formatter:
    """Pick the JSON, colored, or plain formatter described by *config*."""
    if config.enable_json:
        return JSONFormatter()
    formatter_cls = ColoredFormatter if colored else logging.Formatter
    return formatter_cls(fmt=config.format, datefmt=config.date_format)


def _make_file_handler(config: LogConfig) -> logging.Handler:
    """Create the size-rotating file handler, creating its directory if needed."""
    log_file_path = config.log_file or _DEFAULT_LOG_FILE
    Path(log_file_path).parent.mkdir(parents=True, exist_ok=True)
    return RotatingFileHandler(
        filename=log_file_path,
        maxBytes=config.max_file_size,
        backupCount=config.backup_count,
        encoding="utf-8",
    )


def setup_logging(
    config: Optional[LogConfig] = None,
    level: Optional[LogLevel] = None,
    enable_console: Optional[bool] = None,
    enable_file: Optional[bool] = None,
    log_file: Optional[str] = None,
    enable_json: Optional[bool] = None,
) -> logging.Logger:
    """
    Set up logging for the entire application.

    Configures the "src" logger, which every "src.*" module logger inherits
    from. Safe to call again: handlers from the previous call are detached
    and closed before the new ones are attached.

    Args:
        config: LogConfig instance (overrides individual params)
        level: Log level override (LogLevel member or its name as a string)
        enable_console: Console output override
        enable_file: File output override
        log_file: Log file path override
        enable_json: JSON format override

    Returns:
        The configured "src" logger

    Raises:
        ValueError: If the level is not one of the LogLevel names

    USAGE:
    ------
        setup_logging()
        setup_logging(level=LogLevel.DEBUG, enable_file=True, log_file="logs/xmem.log")
        setup_logging(config=LogConfig(level=LogLevel.DEBUG, enable_file=True))
    """
    from dataclasses import replace

    if config is not None:
        effective_config = config
    else:
        # Start from DEFAULT_LOG_CONFIG and override only what was passed, so
        # fields without a parameter (format, rotation sizes) follow it too.
        candidates = {
            "level": level,
            "enable_console": enable_console,
            "enable_file": enable_file,
            "log_file": log_file,
            "enable_json": enable_json,
        }
        overrides = {k: v for k, v in candidates.items() if v is not None}
        effective_config = replace(DEFAULT_LOG_CONFIG, **overrides)

    level_name = _level_name(effective_config.level)

    root_logger = logging.getLogger(_APP_LOGGER_NAME)
    root_logger.setLevel(level_name)

    # Detach AND close previous handlers. Merely clearing the list prevents
    # duplicate output but leaves the old log file's stream open.
    for stale_handler in list(root_logger.handlers):
        root_logger.removeHandler(stale_handler)
        stale_handler.close()

    if effective_config.enable_console:
        console_handler = logging.StreamHandler(sys.stdout)
        console_handler.setLevel(level_name)
        console_handler.setFormatter(_make_formatter(effective_config, colored=True))
        root_logger.addHandler(console_handler)

    if effective_config.enable_file:
        file_handler = _make_file_handler(effective_config)
        file_handler.setLevel(level_name)
        # No ANSI colors in files; JSON is still honored via enable_json.
        file_handler.setFormatter(_make_formatter(effective_config, colored=False))
        root_logger.addHandler(file_handler)

    root_logger.debug(
        f"Logging configured: level={level_name}, "
        f"console={effective_config.enable_console}, "
        f"file={effective_config.enable_file}"
    )

    return root_logger


def get_logger(name: str) -> logging.Logger:
    """
    Get a logger for the given module name.

    Args:
        name: Logger name (typically __name__ for module name)

    Returns:
        Logger instance

    USAGE:
    ------
        from src.config import get_logger

        logger = get_logger(__name__)
        logger.info("Operation completed")
    """
    return logging.getLogger(name)


def set_log_level(level: LogLevel) -> None:
    """
    Change the log level of the "src" logger and its handlers at runtime.

    Args:
        level: New log level (LogLevel member or its name as a string)

    Raises:
        ValueError: If the level is not one of the LogLevel names
    """
    level_name = _level_name(level)

    root_logger = logging.getLogger(_APP_LOGGER_NAME)
    root_logger.setLevel(level_name)

    # Handlers carry their own threshold, so they must follow the logger.
    for handler in root_logger.handlers:
        handler.setLevel(level_name)

    root_logger.info(f"Log level changed to {level_name}")


def disable_logging() -> None:
    """
    Disable all logging (useful for tests).

    Setting ``disabled`` on the "src" logger alone is not enough: a child
    logger such as "src.storage.pinecone" hands its records directly to the
    handlers of its ancestors without consulting their ``disabled`` flag.
    ``logging.disable`` suppresses records of every logger in the process up
    to and including CRITICAL.
    """
    logging.getLogger(_APP_LOGGER_NAME).disabled = True
    logging.disable(logging.CRITICAL)


def enable_logging() -> None:
    """
    Re-enable logging after ``disable_logging``.

    Note: this resets the process-wide ``logging.disable`` threshold to
    NOTSET, including a threshold set by other code.
    """
    logging.disable(logging.NOTSET)
    logging.getLogger(_APP_LOGGER_NAME).disabled = False
+================================================================================ +STORAGE PACKAGE - Vector Store Implementations for XMem +================================================================================ + +This package provides vector store functionality using the Strategy Pattern. +The base class defines the interface; concrete classes implement specific backends. + +ARCHITECTURE: +------------- + BaseVectorStore (Abstract Interface) + │ + ├── PineconeVectorStore (Managed cloud service) + ├── QdrantVectorStore (Future: open-source) + ├── ChromaVectorStore (Future: embedded) + └── MockVectorStore (Future: for testing) + +USAGE: +------ + # Import the interface and a specific implementation + from src.storage import BaseVectorStore, PineconeVectorStore + + # Create a store instance (uses settings by default) + store = PineconeVectorStore() + + # Or with custom configuration + store = PineconeVectorStore( + index_name="my-index", + namespace="production" + ) + + # Use the store + ids = store.add(texts=["Hello"], embeddings=[[0.1, 0.2, ...]]) + results = store.search(query_embedding=[0.1, 0.2, ...], top_k=5) + +DEPENDENCY INJECTION: +--------------------- + # Your business logic depends on the INTERFACE, not concrete implementation + def process_memories(store: BaseVectorStore): + # This works with ANY vector store implementation + results = store.search(query_embedding, top_k=10) + return results + + # At app startup, inject the concrete implementation + store = PineconeVectorStore() # or QdrantVectorStore() in the future + process_memories(store) + +================================================================================ +""" + +# ============================================================================ +# IMPORTS FROM SUBMODULES +# ============================================================================ + +# Import base class and data structures +from .base import ( + # Abstract base class (interface) + BaseVectorStore, + + # Data 
classes for structured data + SearchResult, + VectorDocument, + IndexStats, + + # Enums + DistanceMetric, +) + +# Import concrete implementations +from .pinecone import PineconeVectorStore + +# Re-export exceptions for convenience +# (They're also available from src.utils.exceptions) +from ..utils.exceptions import ( + VectorStoreError, + VectorStoreConnectionError, + VectorStoreValidationError, + VectorNotFoundError, +) + +# ============================================================================ +# __all__ - Defines what's exported with: from src.storage import * +# ============================================================================ + +__all__ = [ + # Base class (interface) + "BaseVectorStore", + + # Data classes + "SearchResult", + "VectorDocument", + "IndexStats", + + # Enums + "DistanceMetric", + + # Concrete implementations + "PineconeVectorStore", + + # Exceptions (for convenience) + "VectorStoreError", + "VectorStoreConnectionError", + "VectorStoreValidationError", + "VectorNotFoundError", +] diff --git a/src/storage/base.py b/src/storage/base.py index e69de29..6eb5a8c 100644 --- a/src/storage/base.py +++ b/src/storage/base.py @@ -0,0 +1,582 @@ +""" +================================================================================ +BASE VECTOR STORE - ABSTRACT BASE CLASS (Interface Definition) +================================================================================ + +WHY THIS FILE EXISTS (Design Pattern Explanation): +-------------------------------------------------- +This file defines an INTERFACE (contract) that all vector store implementations +must follow. This is called the "Strategy Pattern" or "Dependency Inversion Principle". + +WHAT IS AN ABSTRACT BASE CLASS (ABC)? 
+------------------------------------- +- An ABC is a class that CANNOT be instantiated directly +- It defines METHOD SIGNATURES that child classes MUST implement +- If a child class doesn't implement all @abstractmethod methods, Python raises TypeError + +WHY DEFINE add() HERE AND ALSO IN pinecone.py? +---------------------------------------------- +- base.py: Defines WHAT methods exist and their signatures (the contract/interface) +- pinecone.py: Defines HOW those methods actually work with Pinecone specifically + +BENEFITS OF THIS PATTERN: +------------------------- +1. SWAPPABILITY: Switch from Pinecone to Qdrant/Chroma without changing app code +2. TESTABILITY: Mock the interface in unit tests without real database +3. TYPE SAFETY: IDE autocomplete works; type checkers catch errors +4. DOCUMENTATION: Interface documents expected behavior for all implementations +5. LOOSE COUPLING: App code depends on abstract interface, not concrete implementation + +EXAMPLE USAGE: +-------------- + # In your app code, depend on the INTERFACE, not the concrete class: + def process_memories(store: BaseVectorStore): # <- Takes ANY vector store + store.add(texts, embeddings) # Works with Pinecone, Qdrant, Chroma, etc. 
+ + # At startup, inject the concrete implementation: + store = PineconeVectorStore() # or QdrantVectorStore() or ChromaVectorStore() + process_memories(store) + +================================================================================ +""" + +# ============================================================================ +# IMPORTS +# ============================================================================ + +# abc module: Provides Abstract Base Class functionality +# ABC: Base class that makes this class abstract (cannot be instantiated) +# abstractmethod: Decorator that marks methods that MUST be implemented by children +from abc import ABC, abstractmethod + +# typing module: Provides type hints for better code documentation and IDE support +# List: Type hint for list, e.g., List[str] means "list of strings" +# Dict: Type hint for dictionary, e.g., Dict[str, Any] means "dict with string keys" +# Any: Type hint meaning "any type is allowed" +# Optional: Type hint meaning "this value can be None", e.g., Optional[str] = str | None +# Tuple: Type hint for tuple, e.g., Tuple[str, int] means "(string, integer)" +from typing import List, Dict, Any, Optional, Tuple + +# dataclasses module: Provides @dataclass decorator for automatic __init__, __repr__, etc. 
class DistanceMetric(str, Enum):
    """
    Similarity metrics a vector index can be configured with.

    Members are also plain strings (``DistanceMetric.COSINE == "cosine"``).
    """

    # Cosine similarity: measures the angle between vectors (most common)
    COSINE = "cosine"
    # Euclidean distance: straight-line distance in vector space
    EUCLIDEAN = "euclidean"
    # Dot product: sum of element-wise products
    DOT_PRODUCT = "dotproduct"


@dataclass
class SearchResult:
    """
    One hit returned by a vector store search.

    ``@dataclass`` generates ``__init__``, ``__repr__`` and ``__eq__`` from
    the annotated fields below.

    USAGE:
    ------
        result = SearchResult(id="123", content="hello", score=0.95, metadata={})
        print(result.score)  # 0.95
    """

    # Unique identifier of the document/vector
    id: str

    # The text content that was embedded
    content: str

    # Similarity score; values outside [0, 1] only trigger a warning
    score: float

    # Extra data stored with the vector. default_factory gives every instance
    # its own dict instead of one shared mutable default.
    metadata: Dict[str, Any] = field(default_factory=dict)

    def __post_init__(self):
        """Warn (without raising) when the score is outside [0, 1]."""
        score_in_unit_interval = 0.0 <= self.score <= 1.0
        if score_in_unit_interval:
            return

        logger.warning(
            f"SearchResult score {self.score} is outside normal range [0, 1]. "
            "This may indicate a different similarity metric is being used."
        )


@dataclass
class VectorDocument:
    """
    A text plus its embedding, bundled for batch operations.

    Groups what would otherwise be parallel lists (texts, embeddings, ids,
    metadata) into one object per document.
    """

    # The text content to store
    text: str

    # The embedding vector for the text
    embedding: List[float]

    # Caller-supplied ID; None when no ID was provided
    id: Optional[str] = None

    # Additional metadata to store with the vector (fresh dict per instance)
    metadata: Dict[str, Any] = field(default_factory=dict)


@dataclass
class IndexStats:
    """
    Statistics about a vector store index.

    Useful for monitoring, debugging, and capacity planning.
    """

    # Total number of vectors in the index
    total_vector_count: int

    # Dimension of vectors in this index
    dimension: int

    # Vector count per namespace (fresh dict per instance)
    namespaces: Dict[str, int] = field(default_factory=dict)

    # Index fullness percentage; None when not provided
    fullness_percentage: Optional[float] = None
TESTABILITY: + You can create a MockVectorStore for unit tests without needing real Pinecone. + + class MockVectorStore(BaseVectorStore): + def add(self, ...): return ["mock-id"] + def search(self, ...): return [SearchResult(...)] + + # In tests: + store = MockVectorStore() + result = my_function(store) # Tests without network calls! + + 3. DOCUMENTATION: + This class documents the expected behavior of ALL vector stores. + Anyone implementing a new backend knows exactly what methods to implement. + + 4. TYPE CHECKING: + IDEs and type checkers (mypy, pyright) can verify you're using valid methods. + If you call store.invalid_method(), the type checker will catch it. + + SUPPORTED IMPLEMENTATIONS: + -------------------------- + - PineconeVectorStore: Managed vector database, serverless, scalable + - QdrantVectorStore: Open-source, self-hosted or cloud + - ChromaVectorStore: Lightweight, embedded, great for development + - PGVectorStore: PostgreSQL extension, good if you already use Postgres + - WeaviateVectorStore: Open-source, GraphQL API + + ================================================================================ + """ + + @abstractmethod + def add( + self, + texts: List[str], + embeddings: List[List[float]], + ids: Optional[List[str]] = None, + metadata: Optional[List[Dict[str, Any]]] = None + ) -> List[str]: + """ + Add texts with their corresponding embeddings to the vector store. + + WHAT IS @abstractmethod? + ------------------------ + - Decorator from abc module that marks this method as REQUIRED + - Any class inheriting from BaseVectorStore MUST implement this method + - If not implemented, Python raises TypeError when you try to instantiate + + WHY pass AT THE END? 
+ -------------------- + - Abstract methods have no implementation (the base class doesn't know HOW) + - 'pass' is a no-op statement that satisfies Python's requirement for a body + - The docstring serves as documentation for implementers + + Args: + texts: List of text strings to store + Example: ["Hello world", "How are you?"] + + embeddings: List of embedding vectors, one per text + Each embedding is a list of floats (typically 384-1536 dimensions) + Example: [[0.1, 0.2, ...], [0.3, 0.4, ...]] + + ids: Optional list of unique IDs for each text + If not provided, the implementation should generate UUIDs + Example: ["doc-001", "doc-002"] + + metadata: Optional list of metadata dicts, one per text + Example: [{"user_id": "u1", "type": "note"}, {"user_id": "u1"}] + + Returns: + List of IDs for the added documents (generated if not provided) + + Raises: + VectorStoreValidationError: If inputs are invalid (dimension mismatch, empty lists) + VectorStoreConnectionError: If connection to store fails + VectorStoreError: For other storage-related errors + + Example: + ids = store.add( + texts=["Memory 1", "Memory 2"], + embeddings=[[0.1, 0.2, ...], [0.3, 0.4, ...]], + metadata=[{"user": "alice"}, {"user": "alice"}] + ) + print(ids) # ["abc-123", "def-456"] + """ + pass # No implementation - concrete classes must provide this + + @abstractmethod + def search( + self, + query_embedding: List[float], + top_k: int = 5, + filters: Optional[Dict[str, Any]] = None + ) -> List[SearchResult]: + """ + Search for similar documents using vector similarity. + + This is the core retrieval operation - given a query vector, + find the most similar stored vectors. + + Args: + query_embedding: The embedding vector of the search query + Must have same dimension as stored embeddings + Example: [0.1, 0.2, 0.3, ...] 
+ + top_k: Number of results to return (default: 5) + Higher values = more results but slower + Example: top_k=10 returns 10 most similar documents + + filters: Optional metadata filters to narrow search + Format varies by implementation, typically key-value pairs + Example: {"user_id": "alice", "type": "note"} + + Returns: + List of SearchResult objects, sorted by similarity (highest first) + Each result contains: id, content, score, metadata + + Raises: + VectorStoreValidationError: If query embedding dimension is wrong + VectorStoreConnectionError: If connection to store fails + VectorStoreError: For other search-related errors + + Example: + results = store.search( + query_embedding=[0.1, 0.2, ...], + top_k=3, + filters={"user_id": "alice"} + ) + for r in results: + print(f"{r.content} (score: {r.score})") + """ + pass + + @abstractmethod + def update( + self, + id: str, + text: Optional[str] = None, + embedding: Optional[List[float]] = None, + metadata: Optional[Dict[str, Any]] = None + ) -> bool: + """ + Update an existing document in the vector store. + + Allows partial updates - only provided fields are updated. 
+ + Args: + id: The unique ID of the document to update + + text: New text content (optional) + If provided, should also provide new embedding + + embedding: New embedding vector (optional) + Required if text is changed + + metadata: New/updated metadata fields (optional) + Merged with existing metadata + + Returns: + True if update was successful, False if document not found + + Raises: + VectorNotFoundError: If the document ID doesn't exist (alternative to returning False) + VectorStoreValidationError: If embedding dimension is wrong + VectorStoreError: For other update-related errors + + Example: + success = store.update( + id="doc-123", + text="Updated content", + embedding=[0.5, 0.6, ...], + metadata={"edited": True, "edit_time": "2024-01-15"} + ) + """ + pass + + @abstractmethod + def delete(self, ids: List[str]) -> bool: + """ + Delete documents from the vector store by their IDs. + + Args: + ids: List of document IDs to delete + Example: ["doc-001", "doc-002"] + + Returns: + True if deletion was successful + (Note: returns True even if some IDs didn't exist - idempotent) + + Raises: + VectorStoreConnectionError: If connection to store fails + VectorStoreError: For other deletion-related errors + + Example: + store.delete(ids=["doc-001", "doc-002"]) + """ + pass + + @abstractmethod + def get(self, ids: List[str]) -> List[Dict[str, Any]]: + """ + Retrieve documents by their IDs (exact lookup, not similarity search). 
def validate_embeddings(
    self,
    embeddings: List[List[float]],
    expected_dimension: Optional[int] = None
) -> Tuple[bool, Optional[str]]:
    """
    Check that a batch of embeddings is non-empty and uniformly sized.

    Concrete helper shared by every store implementation. It never raises
    for an invalid batch; the outcome is reported through the return value
    so the caller decides which exception to raise.

    Args:
        embeddings: Embedding vectors to check
        expected_dimension: Required vector length; when None only internal
            consistency (all vectors the same length) is checked

    Returns:
        ``(True, None)`` when the batch is valid, otherwise
        ``(False, message)`` where ``message`` describes the first problem

    Example:
        is_valid, error = store.validate_embeddings(embeddings, expected_dimension=384)
        if not is_valid:
            raise VectorStoreValidationError(error)
    """
    if not embeddings:
        return False, "Embeddings list cannot be empty"

    # The first vector defines the dimension every other vector must share.
    reference_dim = len(embeddings[0])

    # Locate the first vector whose length deviates from the reference.
    mismatch = next(
        (
            (position, len(vector))
            for position, vector in enumerate(embeddings)
            if len(vector) != reference_dim
        ),
        None,
    )
    if mismatch is not None:
        position, actual_dim = mismatch
        return False, f"Embedding {position} has dimension {actual_dim}, expected {reference_dim}"

    # The batch is consistent; it only remains to compare with the caller's
    # requirement, if one was given.
    if expected_dimension is None or reference_dim == expected_dimension:
        return True, None

    return False, f"Embedding dimension {reference_dim} doesn't match expected {expected_dimension}"
+ + Args: + texts: List of text strings + embeddings: List of embedding vectors + ids: Optional list of IDs + metadata: Optional list of metadata dicts + + Raises: + VectorStoreValidationError: If any validation check fails + """ + # Check texts is not empty + if not texts: + raise VectorStoreValidationError( + "Texts list cannot be empty", + operation="add" + ) + + # Check embeddings is not empty + if not embeddings: + raise VectorStoreValidationError( + "Embeddings list cannot be empty", + operation="add" + ) + + # Check lengths match + if len(texts) != len(embeddings): + raise VectorStoreValidationError( + f"Length mismatch: {len(texts)} texts vs {len(embeddings)} embeddings", + operation="add", + details={"texts_count": len(texts), "embeddings_count": len(embeddings)} + ) + + # Check IDs length if provided + if ids is not None and len(ids) != len(texts): + raise VectorStoreValidationError( + f"Length mismatch: {len(texts)} texts vs {len(ids)} ids", + operation="add" + ) + + # Check metadata length if provided + if metadata is not None and len(metadata) != len(texts): + raise VectorStoreValidationError( + f"Length mismatch: {len(texts)} texts vs {len(metadata)} metadata entries", + operation="add" + ) + + # Validate embedding dimensions are consistent + is_valid, error = self.validate_embeddings(embeddings) + if not is_valid: + raise VectorStoreValidationError(error, operation="add") diff --git a/src/storage/pinecone.py b/src/storage/pinecone.py index e69de29..8d9c29c 100644 --- a/src/storage/pinecone.py +++ b/src/storage/pinecone.py @@ -0,0 +1,1088 @@ +""" +================================================================================ +PINECONE VECTOR STORE - CONCRETE IMPLEMENTATION +================================================================================ + +This file provides the CONCRETE implementation of BaseVectorStore for Pinecone. +While base.py defines WHAT methods exist, this file defines HOW they work with Pinecone. 
+ + +RELATIONSHIP TO base.py: +------------------------ +- base.py: Defines the interface (abstract methods) +- pinecone.py: Implements the interface for Pinecone specifically + + BaseVectorStore (Abstract) + │ + ├── PineconeVectorStore (this file) + ├── QdrantVectorStore (future) + ├── ChromaVectorStore (future) + └── PGVectorStore (future) + +================================================================================ +""" +from typing import List, Dict, Any, Optional, Final +import uuid + +# ---------------------------------------------------------------------------- +# THIRD-PARTY IMPORTS (with graceful degradation) +# ---------------------------------------------------------------------------- + +# Try to import Pinecone SDK +# If not installed, set flag to False so we can raise helpful error later +try: + # Pinecone: Main client class for interacting with Pinecone service + # ServerlessSpec: Configuration for serverless index deployment + from pinecone import Pinecone, ServerlessSpec + + # Set flag indicating Pinecone is available + PINECONE_AVAILABLE: Final[bool] = True + +except ImportError: + # ImportError: Raised when an import statement fails to find the module + # This happens if pinecone-client is not installed + PINECONE_AVAILABLE: Final[bool] = False + +from .base import ( + BaseVectorStore, + SearchResult, + IndexStats, +) +from ..config import settings, get_logger +from ..utils.exceptions import ( + VectorStoreError, + VectorStoreConnectionError, + VectorStoreValidationError, + VectorNotFoundError, +) +from ..utils.retry import with_retry, RetryConfig +logger = get_logger(__name__) + +# Final[int] = type hint indicating this should never be reassigned +PINECONE_BATCH_SIZE: Final[int] = 100 # Recommended by Pinecone for upsert operations +PINECONE_RETRY_CONFIG: Final[RetryConfig] = RetryConfig( + max_retries=3, # Retry up to 3 times + delay=1.0, # Start with 1 second delay + backoff_multiplier=2.0, # Double the delay each retry: 1s, 2s, 4s + 
def __init__(
    self,
    api_key: Optional[str] = None,
    index_name: Optional[str] = None,
    dimension: Optional[int] = None,
    metric: Optional[str] = None,
    cloud: Optional[str] = None,
    region: Optional[str] = None,
    namespace: Optional[str] = None,
    create_if_not_exists: bool = True
) -> None:
    """
    Initialize the Pinecone vector store and open a handle to its index.

    Every connection parameter is optional. A parameter that is omitted
    falls back to the matching attribute on ``settings``, which allows:
    - Default usage: PineconeVectorStore()
    - Override for testing: PineconeVectorStore(api_key="test-key")
    - Partial override: PineconeVectorStore(namespace="test")

    Args:
        api_key: Pinecone API key (falls back to settings.pinecone_api_key)
        index_name: Name of the index (falls back to settings.pinecone_index_name)
        dimension: Vector dimension (falls back to settings.pinecone_dimension)
        metric: Distance metric (falls back to settings.pinecone_metric)
        cloud: Cloud provider (falls back to settings.pinecone_cloud)
        region: Cloud region (falls back to settings.pinecone_region)
        namespace: Namespace for vector isolation (falls back to
            settings.pinecone_namespace). An explicit empty string is
            honoured rather than replaced by the settings value.
        create_if_not_exists: If True, create the index when it is not
            among the indexes listed by the client

    Raises:
        ImportError: If the Pinecone SDK could not be imported
        VectorStoreValidationError: If no API key could be resolved
        VectorStoreConnectionError: If creating the client, listing or
            creating the index, or opening the index handle fails

    Example:
        # Use all defaults from settings
        store = PineconeVectorStore()

        # Override specific settings
        store = PineconeVectorStore(
            namespace="testing",
            create_if_not_exists=False
        )
    """
    # The module-level import guard sets this flag; fail early with
    # installation instructions instead of a NameError further down.
    if not PINECONE_AVAILABLE:
        raise ImportError(
            "Pinecone SDK is not installed. "
            "Install it with: pip install pinecone-client"
        )

    # Explicit parameters win over settings. For these values an empty or
    # zero argument is never meaningful, so a falsy argument also falls
    # back to settings.
    self._api_key: str = api_key or settings.pinecone_api_key
    self._index_name: str = index_name or settings.pinecone_index_name
    self._dimension: int = dimension or settings.pinecone_dimension
    self._metric: str = metric or settings.pinecone_metric
    self._cloud: str = cloud or settings.pinecone_cloud
    self._region: str = region or settings.pinecone_region

    # The namespace is different: "" is a legitimate explicit choice
    # (count() treats an empty namespace as "whole index"), so only None
    # triggers the settings fallback.
    self._namespace: str = (
        namespace if namespace is not None
        else settings.pinecone_namespace
    )

    if not self._api_key:
        raise VectorStoreValidationError(
            "Pinecone API key is required. "
            "Set PINECONE_API_KEY environment variable or pass api_key parameter.",
            operation="init"
        )

    logger.info(
        f"Initializing PineconeVectorStore: "
        f"index={self._index_name}, namespace={self._namespace}, dimension={self._dimension}"
    )

    try:
        # self._pc: the Pinecone client used for index management calls
        self._pc: Pinecone = Pinecone(api_key=self._api_key)
    except Exception as e:
        # Chain the cause so the SDK traceback is not lost.
        raise VectorStoreConnectionError(
            f"Failed to initialize Pinecone client: {e}",
            operation="init",
            details={"error_type": type(e).__name__}
        ) from e

    if create_if_not_exists:
        self._ensure_index_exists()

    try:
        # self._index: handle used for all vector operations (upsert, query, ...)
        self._index = self._pc.Index(self._index_name)
        logger.info(f"Connected to Pinecone index: {self._index_name}")
    except Exception as e:
        raise VectorStoreConnectionError(
            f"Failed to connect to Pinecone index: {e}",
            operation="init",
            details={"index_name": self._index_name}
        ) from e

def _ensure_index_exists(self) -> None:
    """Create the configured serverless index unless the client already lists it."""
    try:
        existing_indexes: List[str] = [
            idx.name for idx in self._pc.list_indexes()
        ]
    except Exception as e:
        # Listing is the first real API call, so network or auth problems
        # surface here; report them like every other init failure.
        raise VectorStoreConnectionError(
            f"Failed to list Pinecone indexes: {e}",
            operation="init",
            details={"error_type": type(e).__name__}
        ) from e

    if self._index_name in existing_indexes:
        return

    logger.info(
        f"Creating new Pinecone index: {self._index_name} "
        f"(dimension={self._dimension}, metric={self._metric})"
    )

    try:
        self._pc.create_index(
            name=self._index_name,
            dimension=self._dimension,
            metric=self._metric,
            spec=ServerlessSpec(
                cloud=self._cloud,
                region=self._region
            )
        )
    except Exception as e:
        raise VectorStoreConnectionError(
            f"Failed to create Pinecone index: {e}",
            operation="init",
            details={
                "index_name": self._index_name,
                "dimension": self._dimension
            }
        ) from e

    logger.info(f"Successfully created index: {self._index_name}")
def add(
    self,
    texts: List[str],
    embeddings: List[List[float]],
    ids: Optional[List[str]] = None,
    metadata: Optional[List[Dict[str, Any]]] = None
) -> List[str]:
    """
    Add texts with their embeddings to the Pinecone index.

    This implements the abstract add() method from BaseVectorStore.

    HOW IT WORKS:
    -------------
    - The text of each entry is stored in its metadata under the "content"
      key; the text always wins over a caller-supplied "content" key
    - Vectors are sent in batches of PINECONE_BATCH_SIZE
    - Entries are written with upsert(), so an existing ID is overwritten

    RETRY SCOPE:
    ------------
    Retrying is applied per batch (see _upsert_batch), not to the whole
    call. IDs are therefore generated exactly once, and a retry never
    re-sends batches that already succeeded under freshly generated IDs.
    Validation errors are raised immediately and are never retried.

    Args:
        texts: List of text strings to store
        embeddings: Corresponding embedding vectors
        ids: Optional custom IDs (random UUID4 strings if not provided)
        metadata: Optional metadata dict for each text

    Returns:
        List of IDs for the added/updated vectors, in input order

    Raises:
        VectorStoreValidationError: If input validation fails
        VectorStoreConnectionError: If Pinecone API call fails
            (NOTE(review): depends on how with_retry reports exhausted
            retries - confirm against utils.retry)

    Example:
        ids = store.add(
            texts=["Memory one", "Memory two"],
            embeddings=[[0.1, 0.2, ...], [0.3, 0.4, ...]],
            metadata=[{"user": "alice"}, {"user": "alice"}]
        )
    """
    # Shape checks shared by all stores: non-empty inputs, matching
    # lengths, consistent embedding dimensions.
    self.validate_inputs(texts, embeddings, ids, metadata)

    # Store-specific check: vectors must match the index dimension.
    is_valid, error = self.validate_embeddings(
        embeddings,
        expected_dimension=self._dimension
    )
    if not is_valid:
        raise VectorStoreValidationError(error, operation="add")

    if ids is None:
        ids = [str(uuid.uuid4()) for _ in texts]

    if metadata is None:
        metadata = [{} for _ in texts]

    # Build the upsert payload. A new dict is created per entry so the
    # caller's metadata dicts are not mutated.
    vectors: List[Dict[str, Any]] = [
        {
            "id": vec_id,
            "values": embedding,
            "metadata": {**meta, "content": text},
        }
        for text, embedding, vec_id, meta in zip(texts, embeddings, ids, metadata)
    ]

    logger.info(f"Adding {len(vectors)} vectors to namespace '{self._namespace}'")

    # Ceiling division; computed once because it does not change per batch.
    total_batches = (len(vectors) + PINECONE_BATCH_SIZE - 1) // PINECONE_BATCH_SIZE

    for batch_num, start in enumerate(
        range(0, len(vectors), PINECONE_BATCH_SIZE), start=1
    ):
        batch: List[Dict[str, Any]] = vectors[start:start + PINECONE_BATCH_SIZE]

        logger.debug(f"Upserting batch {batch_num}/{total_batches} ({len(batch)} vectors)")

        self._upsert_batch(batch)

    logger.info(f"Successfully added {len(vectors)} vectors")

    return ids

@with_retry(config=PINECONE_RETRY_CONFIG)
def _upsert_batch(self, batch: List[Dict[str, Any]]) -> None:
    """Upsert one prepared batch into the current namespace, retried on failure."""
    self._index.upsert(
        vectors=batch,
        namespace=self._namespace
    )
+ + PINECONE SPECIFICS: + ------------------- + - Pinecone uses a specific filter syntax with operators like $eq, $and + - We convert simple dict filters to Pinecone format + - Results include similarity scores (0-1 for cosine) + + Args: + query_embedding: The embedding vector to search for + top_k: Number of results to return + filters: Optional metadata filters + Example: {"user_id": "alice", "type": "note"} + + Returns: + List of SearchResult objects sorted by similarity (highest first) + + Raises: + VectorStoreValidationError: If query embedding dimension is wrong + VectorStoreConnectionError: If Pinecone API call fails + + Example: + results = store.search( + query_embedding=[0.1, 0.2, ...], + top_k=5, + filters={"user_id": "alice"} + ) + for r in results: + print(f"Score: {r.score}, Content: {r.content}") + """ + + # -------------------------------------------------------------------- + # STEP 1: Validate query embedding dimension + # -------------------------------------------------------------------- + + if len(query_embedding) != self._dimension: + raise VectorStoreValidationError( + f"Query embedding dimension {len(query_embedding)} " + f"doesn't match index dimension {self._dimension}", + operation="search", + details={ + "query_dimension": len(query_embedding), + "index_dimension": self._dimension + } + ) + + # -------------------------------------------------------------------- + # STEP 2: Build Pinecone filter from dict + # -------------------------------------------------------------------- + + # Convert simple filters to Pinecone format + pinecone_filter: Optional[Dict[str, Any]] = self._build_filter(filters) + + # -------------------------------------------------------------------- + # STEP 3: Query Pinecone + # -------------------------------------------------------------------- + + logger.debug( + f"Searching namespace '{self._namespace}' " + f"(top_k={top_k}, filter={pinecone_filter})" + ) + + # Call Pinecone query API + results = 
@with_retry(config=PINECONE_RETRY_CONFIG)
def update(
    self,
    id: str,
    text: Optional[str] = None,
    embedding: Optional[List[float]] = None,
    metadata: Optional[Dict[str, Any]] = None
) -> bool:
    """
    Update an existing vector in the Pinecone index.

    This implements the abstract update() method from BaseVectorStore.

    HOW IT WORKS:
    -------------
    - The update is a read-modify-write: fetch existing -> merge changes
      -> upsert
    - Fields that are not provided keep their stored value
    - When both ``text`` and a "content" key inside ``metadata`` are
      given, ``text`` wins; this matches add(), where the text is always
      written to "content" last

    Args:
        id: The ID of the vector to update
        text: New text content (optional), stored under metadata "content"
        embedding: New embedding vector (optional)
        metadata: Metadata to merge into the existing metadata (optional)

    Returns:
        True if update succeeded, False if vector not found

    Raises:
        VectorStoreValidationError: If new embedding dimension is wrong
        VectorStoreConnectionError: If Pinecone API call fails
            (NOTE(review): depends on how with_retry reports exhausted
            retries - confirm against utils.retry)

    Example:
        success = store.update(
            id="vec-123",
            text="Updated content",
            embedding=[0.5, 0.6, ...],
            metadata={"edited": True}
        )
    """
    # Reject a wrongly sized vector before touching the index.
    if embedding is not None and len(embedding) != self._dimension:
        raise VectorStoreValidationError(
            f"Embedding dimension {len(embedding)} "
            f"doesn't match index dimension {self._dimension}",
            operation="update"
        )

    logger.debug(f"Updating vector {id} in namespace '{self._namespace}'")

    fetch_result = self._index.fetch(
        ids=[id],
        namespace=self._namespace
    )

    # fetch_result.vectors is keyed by vector ID; a missing key means the
    # vector does not exist in this namespace.
    if id not in fetch_result.vectors:
        logger.warning(f"Vector {id} not found for update")
        return False

    existing = fetch_result.vectors[id]

    # Keep the stored values unless a replacement embedding was supplied.
    new_values: List[float] = (
        embedding if embedding is not None
        else existing.values
    )

    # Work on a copy so the fetched record is not mutated.
    new_metadata: Dict[str, Any] = (
        existing.metadata.copy() if existing.metadata
        else {}
    )

    # Merge caller metadata first ...
    if metadata is not None:
        new_metadata.update(metadata)

    # ... then write the text last, so a stray "content" key in `metadata`
    # cannot silently discard the new text.
    if text is not None:
        new_metadata["content"] = text

    self._index.upsert(
        vectors=[{
            "id": id,
            "values": new_values,
            "metadata": new_metadata
        }],
        namespace=self._namespace
    )

    logger.info(f"Successfully updated vector {id}")

    return True
@with_retry(config=PINECONE_RETRY_CONFIG)
def get(self, ids: List[str]) -> List[Dict[str, Any]]:
    """
    Look up vectors in the Pinecone index by exact ID.

    This implements the abstract get() method from BaseVectorStore.

    Args:
        ids: List of vector IDs to retrieve

    Returns:
        One dict per vector present in the fetch response, with keys:
        - id: The vector ID
        - content: The text stored under metadata "content" ("" if absent)
        - metadata: The remaining metadata, without the "content" key
        - embedding: The embedding vector
        An empty input list yields an empty result without an API call.

    Raises:
        VectorStoreConnectionError: If Pinecone API call fails

    Example:
        docs = store.get(ids=["vec-001", "vec-002"])
        for doc in docs:
            print(f"{doc['id']}: {doc['content']}")
    """
    # Nothing requested: skip the round trip entirely.
    if not ids:
        logger.warning("get() called with empty IDs list")
        return []

    logger.debug(f"Fetching {len(ids)} vectors from namespace '{self._namespace}'")

    response = self._index.fetch(
        ids=ids,
        namespace=self._namespace
    )

    def _as_document(vec_id: str, record: Any) -> Dict[str, Any]:
        # The text lives inside the metadata; split it back out.
        remaining: Dict[str, Any] = record.metadata or {}
        text: str = remaining.pop("content", "")
        return {
            "id": vec_id,
            "content": text,
            "metadata": remaining,
            "embedding": record.values
        }

    documents: List[Dict[str, Any]] = [
        _as_document(vec_id, record)
        for vec_id, record in response.vectors.items()
    ]

    logger.debug(f"Retrieved {len(documents)} documents")

    return documents
@with_retry(config=PINECONE_RETRY_CONFIG)
def count(self) -> int:
    """
    Get the number of vectors in the current namespace.

    Note: This is Pinecone-specific (not in BaseVectorStore interface).

    Returns:
        The "vector_count" reported for the current namespace by
        describe_index_stats(), or 0 when the namespace is not listed.
        When the store's namespace is empty, the total count across the
        index is returned instead.

    Example:
        print(f"Namespace has {store.count()} vectors")
    """
    stats = self._index.describe_index_stats()

    # No namespace configured: report the whole index.
    if not self._namespace:
        return stats.total_vector_count

    # get_stats() treats a missing/empty `namespaces` mapping as "no
    # namespaces"; apply the same guard here instead of calling .get() on
    # a value that may be None.
    namespace_stats = (stats.namespaces or {}).get(self._namespace)
    if namespace_stats is None:
        return 0

    return namespace_stats.get("vector_count", 0)
+ + Returns: + True if clear succeeded + + Example: + if confirm_deletion(): + store.clear() + """ + + logger.warning(f"Clearing all vectors from namespace '{self._namespace}'") + + # Pinecone requires delete_all=True for clearing namespace + self._index.delete( + delete_all=True, + namespace=self._namespace + ) + + logger.info(f"Successfully cleared namespace '{self._namespace}'") + + return True + + def delete_index(self) -> None: + """ + Delete the entire Pinecone index. + + WARNING: This is extremely destructive! + - Deletes ALL data in ALL namespaces + - Cannot be undone + - Requires re-creating index to use again + + Use only for cleanup/teardown scenarios. + + Example: + # In test cleanup: + store.delete_index() + """ + + logger.warning(f"DELETING ENTIRE INDEX: {self._index_name}") + + # Delete the index via the Pinecone client + self._pc.delete_index(self._index_name) + + logger.info(f"Successfully deleted index: {self._index_name}") + + # ======================================================================== + # PRIVATE HELPER METHODS + # ======================================================================== + + def _build_filter( + self, + filters: Optional[Dict[str, Any]] + ) -> Optional[Dict[str, Any]]: + """ + Convert simple filter dict to Pinecone filter format. + + WHY IS THIS METHOD NAMED WITH UNDERSCORE? + ------------------------------------------ + Leading underscore (_build_filter) is a Python convention meaning: + "This is an internal/private method, don't call it from outside the class" + + It's not enforced by Python (you CAN still call it), but it signals intent. + + PINECONE FILTER FORMAT: + ----------------------- + Simple filter: {"field": {"$eq": "value"}} + Multiple filters: {"$and": [{"f1": {"$eq": "v1"}}, {"f2": {"$eq": "v2"}}]} + + This method converts simple dict {"key": "value"} to Pinecone format. 
+ + Args: + filters: Simple dict like {"user_id": "alice", "type": "note"} + + Returns: + Pinecone-formatted filter dict, or None if no filters + + Example: + self._build_filter({"user_id": "alice"}) + # Returns: {"user_id": {"$eq": "alice"}} + + self._build_filter({"user_id": "alice", "type": "note"}) + # Returns: {"$and": [ + # {"user_id": {"$eq": "alice"}}, + # {"type": {"$eq": "note"}} + # ]} + """ + + # Return None if no filters provided + if not filters: + return None + + # Single filter: convert directly + if len(filters) == 1: + # Get the single key-value pair + # list(filters.items())[0] gets first (only) tuple + key, value = list(filters.items())[0] + + # Return Pinecone format + return {key: {"$eq": value}} + + # Multiple filters: combine with $and + # List comprehension builds list of filter conditions + filter_conditions: List[Dict[str, Any]] = [ + {k: {"$eq": v}} for k, v in filters.items() + ] + + return {"$and": filter_conditions} + + # ======================================================================== + # CONTEXT MANAGER SUPPORT (Optional) + # ======================================================================== + + def __enter__(self) -> "PineconeVectorStore": + """ + Support for 'with' statement (context manager). + + WHAT IS A CONTEXT MANAGER? + -------------------------- + Allows using the class with 'with' statement for automatic cleanup: + + with PineconeVectorStore() as store: + store.add(...) + # Automatic cleanup when exiting 'with' block + + __enter__ is called when entering the 'with' block. + + Returns: + self (the store instance) + """ + return self + + def __exit__(self, exc_type, exc_val, exc_tb) -> None: + """ + Exit the context manager. + + Called when exiting the 'with' block (even if exception occurred). 
+ + Args: + exc_type: Exception type if one occurred, else None + exc_val: Exception value if one occurred, else None + exc_tb: Exception traceback if one occurred, else None + + Note: + Pinecone client doesn't require explicit cleanup, + but this method is here for consistency and future use. + """ + # Log if there was an exception + if exc_type is not None: + logger.error(f"Error in context manager: {exc_type.__name__}: {exc_val}") + + # Pinecone doesn't need explicit cleanup + # If it did, we'd call cleanup methods here + pass + + # ======================================================================== + # STRING REPRESENTATION + # ======================================================================== + + def __repr__(self) -> str: + """ + Return string representation for debugging. + + WHAT IS __repr__? + ----------------- + Called when you do repr(object) or in debugger. + Should return a string that could recreate the object. + + Returns: + String like "PineconeVectorStore(index='my-index', namespace='default', dimension=1536)" + """ + return ( + f"PineconeVectorStore(" + f"index='{self._index_name}', " + f"namespace='{self._namespace}', " + f"dimension={self._dimension})" + ) + + def __str__(self) -> str: + """ + Return user-friendly string representation. + + WHAT IS __str__? + ---------------- + Called when you do str(object) or print(object). + Should return human-readable description. 
+ + Returns: + Friendly string like "Pinecone: my-index (default)" + """ + return f"Pinecone: {self._index_name} ({self._namespace})" diff --git a/src/utils/__init__.py b/src/utils/__init__.py index e69de29..8e86fea 100644 --- a/src/utils/__init__.py +++ b/src/utils/__init__.py @@ -0,0 +1,55 @@ +from .exceptions import ( + # Base exception + XMemError, + + # Configuration errors + ConfigurationError, + + # Validation errors + ValidationError, + + # Storage/Vector store errors + VectorStoreError, + VectorStoreConnectionError, + VectorStoreValidationError, + VectorNotFoundError, + + # Database errors + DatabaseError, + DatabaseConnectionError, + + # LLM/API errors + LLMError, + LLMRateLimitError, + LLMContextLengthError, + + # Embedding errors + EmbeddingError, +) + +# Import retry utilities +from .retry import ( + with_retry, + RetryConfig, +) + +__all__ = [ + # Exceptions + "XMemError", + "ConfigurationError", + "ValidationError", + "VectorStoreError", + "VectorStoreConnectionError", + "VectorStoreValidationError", + "VectorNotFoundError", + "DatabaseError", + "DatabaseConnectionError", + "LLMError", + "LLMRateLimitError", + "LLMContextLengthError", + "EmbeddingError", + + # Retry utilities + "with_retry", + "RetryConfig", +] diff --git a/src/utils/exceptions.py b/src/utils/exceptions.py new file mode 100644 index 0000000..3f6f011 --- /dev/null +++ b/src/utils/exceptions.py @@ -0,0 +1,408 @@ +""" +================================================================================ +CUSTOM EXCEPTIONS - Centralized Error Handling for XMem +================================================================================ + +EXCEPTION HIERARCHY: +-------------------- + XMemError (base for all XMem errors) + │ + ├── ConfigurationError + │ └── Missing API keys, invalid settings + │ + ├── ValidationError + │ └── Invalid input data, schema violations + │ + ├── VectorStoreError + │ ├── VectorStoreConnectionError + │ ├── VectorStoreValidationError + │ └── VectorNotFoundError + │ 
+ ├── DatabaseError + │ └── DatabaseConnectionError + │ + ├── LLMError + │ ├── LLMRateLimitError + │ └── LLMContextLengthError + │ + └── EmbeddingError + +""" + +from typing import Dict, Any, Optional + + +# ============================================================================ +# BASE EXCEPTION +# ============================================================================ + +class XMemError(Exception): + """ + Base exception class for all XMem errors. + + All custom exceptions in XMem should inherit from this class. + This allows catching all XMem-related errors with a single except clause. + + ATTRIBUTES: + ----------- + message : str + Human-readable error description + operation : Optional[str] + Name of the operation that failed (e.g., "add", "search", "connect") + details : Dict[str, Any] + Additional context for debugging (IDs, counts, parameters) + + USAGE: + ------ + try: + do_something() + except XMemError as e: + # Catches ALL XMem errors + logger.error(f"XMem error in {e.operation}: {e}") + logger.debug(f"Details: {e.details}") + + Example: + raise XMemError( + "Failed to process memory", + operation="process_memory", + details={"user_id": "123", "memory_count": 5} + ) + """ + + def __init__( + self, + message: str, + operation: Optional[str] = None, + details: Optional[Dict[str, Any]] = None + ) -> None: + """ + Initialize the exception with context. + + Args: + message: Human-readable description of what went wrong + operation: Name of the operation that failed (for logging/debugging) + details: Dictionary of additional context (IDs, counts, etc.) 
+ """ + # Call parent Exception.__init__ with the message + # super().__init__(message) sets self.args = (message,) + super().__init__(message) + + # Store the original message for access + self.message: str = message + + # Store operation name (e.g., "add", "search", "delete") + self.operation: Optional[str] = operation + + # Store additional details (e.g., {"ids": ["id1", "id2"]}) + # 'or {}' ensures details is never None, always a dict + self.details: Dict[str, Any] = details or {} + + def __str__(self) -> str: + """ + Return string representation of the exception. + + Called when you do str(exception) or print(exception). + Includes operation name if available for better debugging. + + Returns: + Formatted error message string + """ + # Include operation context if available + if self.operation: + return f"[{self.operation}] {self.message}" + return self.message + + def __repr__(self) -> str: + """ + Return detailed representation for debugging. + + Called by repr(exception) and in debugger. + Shows all attributes for complete debugging info. + + Returns: + Detailed string representation + """ + return ( + f"{self.__class__.__name__}(" + f"message={self.message!r}, " + f"operation={self.operation!r}, " + f"details={self.details!r})" + ) + + def to_dict(self) -> Dict[str, Any]: + """ + Convert exception to dictionary for JSON serialization. + + Useful for API error responses. + + Returns: + Dictionary representation of the error + + Example: + except XMemError as e: + return jsonify(e.to_dict()), 500 + """ + return { + "error": self.__class__.__name__, + "message": self.message, + "operation": self.operation, + "details": self.details, + } + + +# ============================================================================ +# CONFIGURATION ERRORS +# ============================================================================ + +class ConfigurationError(XMemError): + """ + Raised when there's a configuration or settings error. 
+ + Examples: + - Missing required API key + - Invalid configuration value + - Incompatible settings combination + + Usage: + if not settings.pinecone_api_key: + raise ConfigurationError( + "Pinecone API key is required", + operation="init", + details={"setting": "PINECONE_API_KEY"} + ) + """ + pass # Inherits everything from XMemError + + +# ============================================================================ +# VALIDATION ERRORS +# ============================================================================ + +class ValidationError(XMemError): + """ + Raised when input validation fails. + + Examples: + - Invalid data format + - Schema validation failure + - Constraint violation + + Usage: + if len(texts) != len(embeddings): + raise ValidationError( + f"Length mismatch: {len(texts)} texts vs {len(embeddings)} embeddings", + operation="add", + details={"texts_count": len(texts), "embeddings_count": len(embeddings)} + ) + """ + pass + + +# ============================================================================ +# VECTOR STORE ERRORS +# ============================================================================ + +class VectorStoreError(XMemError): + """ + Base exception for all vector store related errors. + + Inherit from this for specific vector store errors. + Allows catching all storage errors with one except clause. + + Usage: + try: + store.add(texts, embeddings) + except VectorStoreError as e: + # Handles connection, validation, and not-found errors + logger.error(f"Storage operation failed: {e}") + """ + pass + + +class VectorStoreConnectionError(VectorStoreError): + """ + Raised when connection to the vector store fails. 
+ + Examples: + - Network timeout + - Invalid credentials + - Service unavailable + - Rate limiting + + Usage: + try: + client = Pinecone(api_key=api_key) + except Exception as e: + raise VectorStoreConnectionError( + f"Failed to connect to Pinecone: {e}", + operation="init", + details={"original_error": str(e)} + ) + """ + pass + + +class VectorStoreValidationError(VectorStoreError): + """ + Raised when vector store input validation fails. + + Examples: + - Embedding dimension mismatch + - Empty input lists + - Invalid metadata format + + Usage: + if len(embedding) != expected_dim: + raise VectorStoreValidationError( + f"Dimension mismatch: got {len(embedding)}, expected {expected_dim}", + operation="search" + ) + """ + pass + + +class VectorNotFoundError(VectorStoreError): + """ + Raised when a requested vector ID doesn't exist. + + Examples: + - Trying to update non-existent vector + - Trying to get vector by invalid ID + + Usage: + if id not in fetch_result.vectors: + raise VectorNotFoundError( + f"Vector with ID '{id}' not found", + operation="update", + details={"id": id} + ) + """ + pass + + +# ============================================================================ +# DATABASE ERRORS +# ============================================================================ + +class DatabaseError(XMemError): + """ + Base exception for all database related errors. + + Use for MongoDB, Neo4j, or any other database operations. + + Usage: + try: + collection.insert_one(document) + except PyMongoError as e: + raise DatabaseError( + f"Failed to insert document: {e}", + operation="insert" + ) + """ + pass + + +class DatabaseConnectionError(DatabaseError): + """ + Raised when database connection fails. 
+ + Examples: + - Cannot connect to MongoDB + - Neo4j authentication failed + - Connection pool exhausted + + Usage: + try: + client = MongoClient(uri) + except ConnectionFailure as e: + raise DatabaseConnectionError( + f"Failed to connect to MongoDB: {e}", + operation="connect", + details={"uri": uri} + ) + """ + pass + + +# ============================================================================ +# LLM ERRORS +# ============================================================================ + +class LLMError(XMemError): + """ + Base exception for all LLM (Language Model) related errors. + + Use for OpenAI, Anthropic, Google, or any LLM API operations. + + Usage: + try: + response = client.chat.completions.create(...) + except OpenAIError as e: + raise LLMError( + f"LLM call failed: {e}", + operation="generate" + ) + """ + pass + + +class LLMRateLimitError(LLMError): + """ + Raised when LLM API rate limit is exceeded. + + This is often retryable after waiting. + + Usage: + except RateLimitError as e: + raise LLMRateLimitError( + "Rate limit exceeded, please retry later", + operation="generate", + details={"retry_after": e.retry_after} + ) + """ + pass + + +class LLMContextLengthError(LLMError): + """ + Raised when input exceeds LLM's context length limit. + + Examples: + - Prompt too long for model + - Combined input + output exceeds limit + + Usage: + if token_count > model_limit: + raise LLMContextLengthError( + f"Input exceeds context limit: {token_count} > {model_limit}", + operation="generate", + details={"token_count": token_count, "limit": model_limit} + ) + """ + pass + + +# ============================================================================ +# EMBEDDING ERRORS +# ============================================================================ + +class EmbeddingError(XMemError): + """ + Raised when embedding generation fails. 
+ + Examples: + - Model loading failed + - Text too long for embedding model + - Encoding error + + Usage: + try: + embedding = model.encode(text) + except Exception as e: + raise EmbeddingError( + f"Failed to generate embedding: {e}", + operation="embed", + details={"text_length": len(text)} + ) + """ + pass diff --git a/src/utils/retry.py b/src/utils/retry.py index e69de29..09e83e9 100644 --- a/src/utils/retry.py +++ b/src/utils/retry.py @@ -0,0 +1,416 @@ +""" +================================================================================ +RETRY UTILITIES - Exponential Backoff for Transient Failures +================================================================================ + +WHY RETRY LOGIC? +---------------- +Network calls can fail temporarily due to: +- Network timeouts +- Rate limiting (429 errors) +- Service temporarily unavailable (503 errors) +- Connection resets + +Instead of immediately failing, we retry with exponential backoff: +- 1st retry: wait 1 second +- 2nd retry: wait 2 seconds +- 3rd retry: wait 4 seconds + +This gives the service time to recover without overwhelming it. 
+ +USAGE: +------ + from src.utils.retry import with_retry, RetryConfig + + # Basic usage with defaults + @with_retry() + def call_external_api(): + return requests.get("https://api.example.com") + + # Custom configuration + @with_retry(max_retries=5, delay=0.5, backoff_multiplier=2.0) + def call_api_with_custom_retry(): + return api.call() + + # Using RetryConfig for reusable settings + api_retry_config = RetryConfig(max_retries=3, delay=1.0) + + @with_retry(config=api_retry_config) + def another_api_call(): + return api.call() + +================================================================================ +""" + +from typing import TypeVar, Callable, Optional, Type, Tuple, Any + +# functools: Higher-order functions +# wraps: Preserves function metadata when creating decorators +from functools import wraps +import time +import logging +from dataclasses import dataclass, field +from .exceptions import XMemError, ValidationError + +logger = logging.getLogger(__name__) +T = TypeVar("T") + +@dataclass +class RetryConfig: + """ + Configuration for retry behavior. + + Using a dataclass allows: + - Reusable configurations across multiple functions + - Clear documentation of all options + - Easy modification and testing + + ATTRIBUTES: + ----------- + max_retries : int + Maximum number of retry attempts (default: 3) + Total attempts = max_retries + 1 (initial + retries) + + delay : float + Initial delay between retries in seconds (default: 1.0) + Actual delay = delay * (backoff_multiplier ^ attempt) + + backoff_multiplier : float + Multiplier for exponential backoff (default: 2.0) + delay=1, backoff_multiplier=2: waits 1s, 2s, 4s, 8s... + + max_delay : float + Maximum delay cap in seconds (default: 60.0) + Prevents extremely long waits + + retryable_exceptions : Tuple[Type[Exception], ...] + Exception types that should trigger retry + Default: (Exception,) - retry all exceptions + Example: (ConnectionError, TimeoutError) + + non_retryable_exceptions : Tuple[Type[Exception], ...] 
+ Exception types that should NOT be retried + Default: (ValidationError,) - don't retry validation errors + + USAGE: + ------ + # Create reusable config + api_config = RetryConfig( + max_retries=5, + delay=0.5, + retryable_exceptions=(ConnectionError, TimeoutError) + ) + + @with_retry(config=api_config) + def call_api(): + ... + """ + + max_retries: int = 3 + delay: float = 1.0 + backoff_multiplier: float = 2.0 + max_delay: float = 60.0 + retryable_exceptions: Tuple[Type[Exception], ...] = field( + default_factory=lambda: (Exception,) + ) + non_retryable_exceptions: Tuple[Type[Exception], ...] = field( + default_factory=lambda: (ValidationError,) + ) + +# Default configuration instance +DEFAULT_RETRY_CONFIG = RetryConfig() + +def with_retry( + max_retries: Optional[int] = None, + delay: Optional[float] = None, + backoff_multiplier: Optional[float] = None, + max_delay: Optional[float] = None, + retryable_exceptions: Optional[Tuple[Type[Exception], ...]] = None, + non_retryable_exceptions: Optional[Tuple[Type[Exception], ...]] = None, + config: Optional[RetryConfig] = None, +) -> Callable[[Callable[..., T]], Callable[..., T]]: + """ + Decorator that adds retry logic with exponential backoff. + + WHAT IS A DECORATOR? + -------------------- + A decorator wraps a function to add behavior. The @decorator syntax + is equivalent to: function = decorator(function) + + WHAT IS EXPONENTIAL BACKOFF? + ---------------------------- + Instead of retrying immediately, wait longer after each failure: + - 1st retry: wait 1s (delay * 2^0) + - 2nd retry: wait 2s (delay * 2^1) + - 3rd retry: wait 4s (delay * 2^2) + + This prevents overwhelming a struggling service. + + DECORATOR FACTORY PATTERN: + -------------------------- + This is a "decorator factory" - it returns a decorator. 
+ + @with_retry(max_retries=3) # Factory call returns decorator + def my_func(): # Decorator wraps this function + pass + + Args: + max_retries: Maximum retry attempts (default: 3) + delay: Initial delay in seconds (default: 1.0) + backoff_multiplier: Backoff multiplier (default: 2.0) + max_delay: Maximum delay cap (default: 60.0) + retryable_exceptions: Exception types to retry + non_retryable_exceptions: Exception types to NOT retry + config: RetryConfig instance (overrides individual params) + + Returns: + Decorator function that wraps the target function + + USAGE: + ------ + # Basic usage + @with_retry() + def call_api(): + return requests.get(url) + + # Custom parameters + @with_retry(max_retries=5, delay=0.5) + def call_api_custom(): + return api.call() + + # With config object + config = RetryConfig(max_retries=3) + @with_retry(config=config) + def call_api_with_config(): + return api.call() + + Example with full flow: + @with_retry(max_retries=3, delay=1.0) + def fetch_data(): + return external_api.get() + + # If fetch_data() fails: + # - Attempt 1: fails → wait 1s + # - Attempt 2: fails → wait 2s + # - Attempt 3: fails → wait 4s + # - Attempt 4: fails → raise exception + """ + + # ======================================================================== + # STEP 1: Resolve configuration + # ======================================================================== + + # Use provided config or build from individual params + if config is not None: + # Use the provided config object + effective_config = config + else: + # Build config from individual params, falling back to defaults + effective_config = RetryConfig( + max_retries=max_retries if max_retries is not None else DEFAULT_RETRY_CONFIG.max_retries, + delay=delay if delay is not None else DEFAULT_RETRY_CONFIG.delay, + backoff_multiplier=backoff_multiplier if backoff_multiplier is not None else DEFAULT_RETRY_CONFIG.backoff_multiplier, + max_delay=max_delay if max_delay is not None else 
DEFAULT_RETRY_CONFIG.max_delay, + retryable_exceptions=retryable_exceptions if retryable_exceptions is not None else DEFAULT_RETRY_CONFIG.retryable_exceptions, + non_retryable_exceptions=non_retryable_exceptions if non_retryable_exceptions is not None else DEFAULT_RETRY_CONFIG.non_retryable_exceptions, + ) + + # ======================================================================== + # STEP 2: Create the actual decorator + # ======================================================================== + + def decorator(func: Callable[..., T]) -> Callable[..., T]: + """ + The actual decorator that wraps the function. + + Args: + func: The function to wrap with retry logic + + Returns: + Wrapped function with retry behavior + """ + + # @wraps(func) preserves the original function's metadata + # Without this, __name__, __doc__, etc. would be lost + @wraps(func) + def wrapper(*args: Any, **kwargs: Any) -> T: + """ + Wrapper function that implements retry logic. + + *args: Positional arguments passed to the wrapped function + **kwargs: Keyword arguments passed to the wrapped function + + Returns: + Result from the wrapped function + + Raises: + The last exception if all retries are exhausted + """ + + # Track the last exception for re-raising if all retries fail + last_exception: Optional[Exception] = None + + # Total attempts = initial attempt + retries + total_attempts = effective_config.max_retries + 1 + + # ================================================================ + # RETRY LOOP + # ================================================================ + + for attempt in range(total_attempts): + try: + # Call the original function with all its arguments + # If successful, return immediately + return func(*args, **kwargs) + + except effective_config.non_retryable_exceptions as e: + # These exceptions should not be retried + # (e.g., validation errors won't succeed on retry) + logger.debug( + f"Non-retryable exception in {func.__name__}: " + f"{type(e).__name__}: {e}" + 
) + raise + + except effective_config.retryable_exceptions as e: + # Store the exception in case we need to re-raise it + last_exception = e + + # Check if we have retries left + if attempt < effective_config.max_retries: + # Calculate delay with exponential backoff + # delay * (backoff ^ attempt) = 1, 2, 4, 8, ... + current_delay = effective_config.delay * ( + effective_config.backoff_multiplier ** attempt + ) + + # Cap the delay at max_delay + current_delay = min(current_delay, effective_config.max_delay) + + # Log the retry attempt + logger.warning( + f"Attempt {attempt + 1}/{total_attempts} failed for " + f"{func.__name__}: {type(e).__name__}: {e}. " + f"Retrying in {current_delay:.1f}s..." + ) + + # Wait before retrying + time.sleep(current_delay) + + else: + # No more retries - log the final failure + logger.error( + f"All {total_attempts} attempts failed for " + f"{func.__name__}: {type(e).__name__}: {e}" + ) + + # ================================================================ + # ALL RETRIES EXHAUSTED + # ================================================================ + + # If we get here, all retries failed + # Re-raise the last exception + if last_exception is not None: + raise last_exception + + # This should never happen, but satisfy type checker + raise RuntimeError(f"Unexpected state in retry logic for {func.__name__}") + + # Return the wrapper function + return wrapper + + # Return the decorator + return decorator + +def with_async_retry( + max_retries: Optional[int] = None, + delay: Optional[float] = None, + backoff_multiplier: Optional[float] = None, + max_delay: Optional[float] = None, + retryable_exceptions: Optional[Tuple[Type[Exception], ...]] = None, + non_retryable_exceptions: Optional[Tuple[Type[Exception], ...]] = None, + config: Optional[RetryConfig] = None, +) -> Callable[[Callable[..., T]], Callable[..., T]]: + """ + Async version of with_retry for async/await functions. + + Uses asyncio.sleep instead of time.sleep to avoid blocking. 
+ + USAGE: + ------ + @with_async_retry(max_retries=3) + async def fetch_data(): + async with aiohttp.ClientSession() as session: + return await session.get(url) + """ + + # Import asyncio here to avoid import if not using async + import asyncio + + # Resolve configuration (same as sync version) + if config is not None: + effective_config = config + else: + effective_config = RetryConfig( + max_retries=max_retries if max_retries is not None else DEFAULT_RETRY_CONFIG.max_retries, + delay=delay if delay is not None else DEFAULT_RETRY_CONFIG.delay, + backoff_multiplier=backoff_multiplier if backoff_multiplier is not None else DEFAULT_RETRY_CONFIG.backoff_multiplier, + max_delay=max_delay if max_delay is not None else DEFAULT_RETRY_CONFIG.max_delay, + retryable_exceptions=retryable_exceptions if retryable_exceptions is not None else DEFAULT_RETRY_CONFIG.retryable_exceptions, + non_retryable_exceptions=non_retryable_exceptions if non_retryable_exceptions is not None else DEFAULT_RETRY_CONFIG.non_retryable_exceptions, + ) + + def decorator(func: Callable[..., T]) -> Callable[..., T]: + """Async decorator wrapper.""" + + @wraps(func) + async def wrapper(*args: Any, **kwargs: Any) -> T: + """Async wrapper with retry logic.""" + + last_exception: Optional[Exception] = None + total_attempts = effective_config.max_retries + 1 + + for attempt in range(total_attempts): + try: + # Await the async function + return await func(*args, **kwargs) + + except effective_config.non_retryable_exceptions as e: + logger.debug( + f"Non-retryable exception in {func.__name__}: " + f"{type(e).__name__}: {e}" + ) + raise + + except effective_config.retryable_exceptions as e: + last_exception = e + + if attempt < effective_config.max_retries: + current_delay = effective_config.delay * ( + effective_config.backoff_multiplier ** attempt + ) + current_delay = min(current_delay, effective_config.max_delay) + + logger.warning( + f"Attempt {attempt + 1}/{total_attempts} failed for " + 
f"{func.__name__}: {type(e).__name__}: {e}. " + f"Retrying in {current_delay:.1f}s..." + ) + + # Use asyncio.sleep for non-blocking wait + await asyncio.sleep(current_delay) + + else: + logger.error( + f"All {total_attempts} attempts failed for " + f"{func.__name__}: {type(e).__name__}: {e}" + ) + + if last_exception is not None: + raise last_exception + + raise RuntimeError(f"Unexpected state in async retry logic for {func.__name__}") + + return wrapper + + return decorator