diff --git a/docs/getting-started.md b/docs/getting-started.md index 13c8fda..409d1d3 100644 --- a/docs/getting-started.md +++ b/docs/getting-started.md @@ -26,15 +26,40 @@ pip install -e .[dev] ## Configure Environment Variables -Create an `.env` file (or export in your shell) with credentials for the LLM providers you use. A minimal configuration: +DELM requires API keys for the LLM providers you use. You are responsible for loading these environment variables in whatever way works best for your workflow. -```env -OPENAI_API_KEY=sk-... -ANTHROPIC_API_KEY=... -TOGETHER_API_KEY=... +### Required Environment Variables by Provider + +- **OpenAI**: `OPENAI_API_KEY` +- **Anthropic**: `ANTHROPIC_API_KEY` +- **Google**: `GOOGLE_API_KEY` +- **Groq**: `GROQ_API_KEY` +- **Together AI**: `TOGETHER_API_KEY` +- **Fireworks AI**: `FIREWORKS_API_KEY` + +### Option 1: Export in Your Shell + +```bash +export OPENAI_API_KEY="sk-..." +export ANTHROPIC_API_KEY="..." +``` + +### Option 2: Use python-dotenv (Optional) + +If you prefer using `.env` files, install and use `python-dotenv`: + +```bash +pip install python-dotenv +``` + +Then in your script: + +```python +from dotenv import load_dotenv +load_dotenv() # Load from .env file in current directory ``` -Replace the values with your credentials. DELM only loads providers that have available keys. +**Note**: You only need to set the API key for the provider you're using. DELM accesses environment variables directly via the LLM client libraries (OpenAI, Anthropic, etc.). ## Create Your Pipeline Configuration diff --git a/example.env b/example.env index 3118527..47c59e6 100644 --- a/example.env +++ b/example.env @@ -1,6 +1,47 @@ +# Environment Variable Reference for DELM +# ======================================== +# +# DELM requires certain environment variables to be set depending on which LLM provider you use. +# These variables should be loaded into your environment before running DELM. 
+# +# How to set environment variables: +# --------------------------------- +# +# Option 1: Export in your shell +# export OPENAI_API_KEY="your-openai-key" +# +# Option 2: Use a .env file with python-dotenv (user's choice) +# pip install python-dotenv +# Then in your script: from dotenv import load_dotenv; load_dotenv() +# +# Option 3: Set in Docker/container environment +# docker run -e OPENAI_API_KEY="your-key" ... +# +# Option 4: Use cloud secrets manager (AWS Secrets Manager, GCP Secret Manager, etc.) +# +# Option 5: Set in your IDE/development environment +# +# Required Environment Variables by Provider: +# ------------------------------------------- + +# OpenAI (for gpt-4, gpt-3.5-turbo, etc.) OPENAI_API_KEY="your-openai-key" + +# Anthropic (for claude-3-*, claude-2, etc.) ANTHROPIC_API_KEY="your-anthropic-key" + +# Google (for gemini-*, palm-*, etc.) GOOGLE_API_KEY="your-google-key" + +# Groq (for llama-*, mixtral-*, etc.) GROQ_API_KEY="your-groq-key" + +# Together AI TOGETHER_API_KEY="your-together-key" -FIREWORKS_API_KEY="your-fireworks-key" \ No newline at end of file + +# Fireworks AI +FIREWORKS_API_KEY="your-fireworks-key" + +# Note: You only need to set the API key for the provider you're using. +# DELM no longer automatically loads .env files - you are responsible for +# ensuring the appropriate environment variables are set before running DELM. \ No newline at end of file diff --git a/examples/prompt_optimization/prompt_optimization.py b/examples/prompt_optimization/prompt_optimization.py index 959b043..db01531 100644 --- a/examples/prompt_optimization/prompt_optimization.py +++ b/examples/prompt_optimization/prompt_optimization.py @@ -1,5 +1,4 @@ -"""Implements LLM-In-the-Loop PRompt Optimization (LILPRO) using DELM. 
-""" +"""Implements LLM-In-the-Loop PRompt Optimization (LILPRO) using DELM.""" from __future__ import annotations @@ -54,6 +53,7 @@ # helpers # ---------------------------------------------------------------------------- + def build_expected_df(record_labeled_df: pd.DataFrame) -> pd.DataFrame: """Create nested expected JSON per id, aggregating duplicates. @@ -86,7 +86,9 @@ def build_expected_df(record_labeled_df: pd.DataFrame) -> pd.DataFrame: .reset_index(name="items") ) - grouped["expected_json"] = grouped["items"].apply(lambda items: {CONTAINER_NAME: items}) + grouped["expected_json"] = grouped["items"].apply( + lambda items: {CONTAINER_NAME: items} + ) return grouped[["id", "expected_json"]] @@ -94,8 +96,16 @@ def _count_price_expectation(items: List[Dict[str, Any]] | None) -> Tuple[int, i """Return counts of True/False for price_expectation across items.""" if not items: return 0, 0 - true_count = sum(1 for it in items if isinstance(it, dict) and it.get("price_expectation") is True) - false_count = sum(1 for it in items if isinstance(it, dict) and it.get("price_expectation") is False) + true_count = sum( + 1 + for it in items + if isinstance(it, dict) and it.get("price_expectation") is True + ) + false_count = sum( + 1 + for it in items + if isinstance(it, dict) and it.get("price_expectation") is False + ) return true_count, false_count @@ -103,7 +113,9 @@ def _extract_items(d: Dict[str, Any] | None) -> List[Dict[str, Any]]: if not isinstance(d, dict): return [] items = d.get(CONTAINER_NAME) - return [it for it in items if isinstance(it, dict)] if isinstance(items, list) else [] + return ( + [it for it in items if isinstance(it, dict)] if isinstance(items, list) else [] + ) def _normalize_good(value: Any) -> str: @@ -212,14 +224,24 @@ def annotate_price_expectation_counts(record_pairs_df: pd.DataFrame) -> pd.DataF ), axis=1, ) - counts_df = pd.DataFrame(list(counts), columns=["expected_counts", "predicted_counts"], index=df.index) + counts_df = pd.DataFrame( 
+ list(counts), columns=["expected_counts", "predicted_counts"], index=df.index + ) out = pd.DataFrame( { "id": df["id"].tolist(), - "exp_true": counts_df["expected_counts"].apply(lambda x: int(x[0]) if isinstance(x, tuple) else 0), - "exp_false": counts_df["expected_counts"].apply(lambda x: int(x[1]) if isinstance(x, tuple) else 0), - "pred_true": counts_df["predicted_counts"].apply(lambda x: int(x[0]) if isinstance(x, tuple) else 0), - "pred_false": counts_df["predicted_counts"].apply(lambda x: int(x[1]) if isinstance(x, tuple) else 0), + "exp_true": counts_df["expected_counts"].apply( + lambda x: int(x[0]) if isinstance(x, tuple) else 0 + ), + "exp_false": counts_df["expected_counts"].apply( + lambda x: int(x[1]) if isinstance(x, tuple) else 0 + ), + "pred_true": counts_df["predicted_counts"].apply( + lambda x: int(x[0]) if isinstance(x, tuple) else 0 + ), + "pred_false": counts_df["predicted_counts"].apply( + lambda x: int(x[1]) if isinstance(x, tuple) else 0 + ), } ) return out @@ -259,7 +281,12 @@ def compute_batch_stats( # Total extractions (total predicted items across all records) n_extractions = int( - sum(len(_extract_items(d)) for d in record_pairs_df.get("extracted_dict", pd.Series([{}] * len(record_pairs_df))) ) + sum( + len(_extract_items(d)) + for d in record_pairs_df.get( + "extracted_dict", pd.Series([{}] * len(record_pairs_df)) + ) + ) ) # Wrong price_expectation among matched (id+good) pairs (boolean inequality) @@ -291,7 +318,9 @@ def append_metrics_row(csv_path: Path, row: Dict[str, Any]) -> None: df.to_csv(csv_path, mode="a", header=header, index=False) -def save_precision_plot(csv_path: Path, out_path: Path, series: str = "presence") -> None: +def save_precision_plot( + csv_path: Path, out_path: Path, series: str = "presence" +) -> None: """Render precision-vs-batch plot from CSV with dynamic y-limits. 
series: "presence" to plot estimator precision; "matched" to plot matched_precision @@ -305,18 +334,20 @@ def save_precision_plot(csv_path: Path, out_path: Path, series: str = "presence" return # ICLR-friendly style similar to cost_vs_coverage sns.set_theme(style="whitegrid", font_scale=1.2) - plt.rcParams.update({ - "figure.figsize": (3.0, 2.0), - "font.size": 8, - "axes.labelsize": 8, - "axes.titlesize": 9, - "legend.fontsize": 7, - "xtick.labelsize": 7, - "ytick.labelsize": 7, - "savefig.bbox": "tight", - "savefig.pad_inches": 0.02, - "pdf.fonttype": 42, - }) + plt.rcParams.update( + { + "figure.figsize": (3.0, 2.0), + "font.size": 8, + "axes.labelsize": 8, + "axes.titlesize": 9, + "legend.fontsize": 7, + "xtick.labelsize": 7, + "ytick.labelsize": 7, + "savefig.bbox": "tight", + "savefig.pad_inches": 0.02, + "pdf.fonttype": 42, + } + ) plt.figure() if series == "presence": y = df["precision"] @@ -395,10 +426,14 @@ def set_price_expectation_description(schema_path: Path, new_description: str) - changed = True break if changed: - schema_path.write_text(yaml.safe_dump(spec, sort_keys=False, allow_unicode=True)) + schema_path.write_text( + yaml.safe_dump(spec, sort_keys=False, allow_unicode=True) + ) -def run_optimizer_and_get_guidance(current_definition: str, examples_text: str) -> Dict[str, Any]: +def run_optimizer_and_get_guidance( + current_definition: str, examples_text: str +) -> Dict[str, Any]: """Run optimizer to produce a refined definition from wrong examples.""" cfg = DELMConfig.from_yaml(OPTIMIZER_CONFIG_PATH) cfg.schema.spec_path = OPTIMIZER_SCHEMA_PATH @@ -441,6 +476,7 @@ def run_optimizer_and_get_guidance(current_definition: str, examples_text: str) # main flow # ---------------------------------------------------------------------------- + def main() -> None: """Run iterative optimization and plot precision across batches.""" random.seed(RANDOM_SEED) @@ -465,7 +501,9 @@ def main() -> None: metrics_csv_path = EXPERIMENT_ROOT_DIR / 
"precision_by_batch.csv" # Determine 10% evaluation sample size (at least 1 record) - eval_record_sample_size = max(1, int(np.ceil(EVAL_SAMPLE_RATIO * len(record_expected_df)))) + eval_record_sample_size = max( + 1, int(np.ceil(EVAL_SAMPLE_RATIO * len(record_expected_df))) + ) for batch_idx in tqdm(range(NUM_BATCHES + 1), desc="batches", leave=True): cfg = DELMConfig.from_dict(base_cfg.to_serialized_config_dict()) @@ -501,12 +539,24 @@ def main() -> None: ) # Append to in-memory list for reference - batch_records.append({"batch": batch_idx, "precision": precision, "matched_precision": matched_precision, **stats}) + batch_records.append( + { + "batch": batch_idx, + "precision": precision, + "matched_precision": matched_precision, + **stats, + } + ) # Persist/update the metrics CSV after each batch append_metrics_row( metrics_csv_path, - {"batch": batch_idx, "precision": precision, "matched_precision": matched_precision, **stats}, + { + "batch": batch_idx, + "precision": precision, + "matched_precision": matched_precision, + **stats, + }, ) # Save the per-record trace for price_expectation counts @@ -518,13 +568,31 @@ def main() -> None: json.dump(metrics_dict, fh, ensure_ascii=False, indent=2) record_pairs_out_df = record_pairs_df.copy() - record_pairs_out_df.to_json(exp_dir / "record_pairs.json", orient="records", force_ascii=False, indent=2) + record_pairs_out_df.to_json( + exp_dir / "record_pairs.json", orient="records", force_ascii=False, indent=2 + ) # Save or update the precision plots incrementally (PNG + PDF) - save_precision_plot(metrics_csv_path, EXPERIMENT_ROOT_DIR / "precision_vs_batch_presence.png", series="presence") - save_precision_plot(metrics_csv_path, EXPERIMENT_ROOT_DIR / "precision_vs_batch_presence.pdf", series="presence") - save_precision_plot(metrics_csv_path, EXPERIMENT_ROOT_DIR / "precision_vs_batch_matched.png", series="matched") - save_precision_plot(metrics_csv_path, EXPERIMENT_ROOT_DIR / "precision_vs_batch_matched.pdf", 
series="matched") + save_precision_plot( + metrics_csv_path, + EXPERIMENT_ROOT_DIR / "precision_vs_batch_presence.png", + series="presence", + ) + save_precision_plot( + metrics_csv_path, + EXPERIMENT_ROOT_DIR / "precision_vs_batch_presence.pdf", + series="presence", + ) + save_precision_plot( + metrics_csv_path, + EXPERIMENT_ROOT_DIR / "precision_vs_batch_matched.png", + series="matched", + ) + save_precision_plot( + metrics_csv_path, + EXPERIMENT_ROOT_DIR / "precision_vs_batch_matched.pdf", + series="matched", + ) if batch_idx < NUM_BATCHES: wrong_df = find_wrong_price_expectation_records(record_pairs_df) @@ -551,10 +619,26 @@ def main() -> None: set_price_expectation_description(BASE_SCHEMA_PATH, new_def) # Final plot refresh from accumulated CSV (PNG + PDF) - save_precision_plot(metrics_csv_path, EXPERIMENT_ROOT_DIR / "precision_vs_batch_presence.png", series="presence") - save_precision_plot(metrics_csv_path, EXPERIMENT_ROOT_DIR / "precision_vs_batch_presence.pdf", series="presence") - save_precision_plot(metrics_csv_path, EXPERIMENT_ROOT_DIR / "precision_vs_batch_matched.png", series="matched") - save_precision_plot(metrics_csv_path, EXPERIMENT_ROOT_DIR / "precision_vs_batch_matched.pdf", series="matched") + save_precision_plot( + metrics_csv_path, + EXPERIMENT_ROOT_DIR / "precision_vs_batch_presence.png", + series="presence", + ) + save_precision_plot( + metrics_csv_path, + EXPERIMENT_ROOT_DIR / "precision_vs_batch_presence.pdf", + series="presence", + ) + save_precision_plot( + metrics_csv_path, + EXPERIMENT_ROOT_DIR / "precision_vs_batch_matched.png", + series="matched", + ) + save_precision_plot( + metrics_csv_path, + EXPERIMENT_ROOT_DIR / "precision_vs_batch_matched.pdf", + series="matched", + ) if __name__ == "__main__": diff --git a/pyproject.toml b/pyproject.toml index 2f8f768..431a9cc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,7 +30,6 @@ dependencies = [ "instructor>=0.4.0", "pydantic>=2.0.0", "pyyaml>=6.0", - "python-dotenv>=1.0.0", 
"tqdm>=4.64.0", "rapidfuzz>=3.0.0", "beautifulsoup4>=4.11.0", diff --git a/src/delm/__init__.py b/src/delm/__init__.py index 4016d03..1c18328 100644 --- a/src/delm/__init__.py +++ b/src/delm/__init__.py @@ -7,16 +7,19 @@ # Library-local logger log = logging.getLogger(__name__) -log.addHandler(logging.NullHandler()) # avoids spurious warnings +log.addHandler(logging.NullHandler()) # avoids spurious warnings from delm.delm import DELM from delm.logging import configure as configure_logging -from delm.config import DELMConfig, LLMExtractionConfig, DataPreprocessingConfig, SchemaConfig, SplittingConfig, ScoringConfig -from delm.exceptions import ( - DELMError, - ExperimentManagementError, - InstructorError +from delm.config import ( + DELMConfig, + LLMExtractionConfig, + DataPreprocessingConfig, + SchemaConfig, + SplittingConfig, + ScoringConfig, ) +from delm.exceptions import DELMError, ExperimentManagementError, InstructorError from .constants import ( # LLM/API Configuration DEFAULT_PROVIDER, @@ -28,28 +31,22 @@ DEFAULT_BASE_DELAY, DEFAULT_TRACK_COST, DEFAULT_MAX_BUDGET, - DEFAULT_DOTENV_PATH, - # Data Processing DEFAULT_DROP_TARGET_COLUMN, DEFAULT_PANDAS_SCORE_FILTER, - # Schema Configuration DEFAULT_SCHEMA_PATH, DEFAULT_PROMPT_TEMPLATE, DEFAULT_SYSTEM_PROMPT, - # Experiment Management DEFAULT_EXPERIMENT_DIR, DEFAULT_OVERWRITE_EXPERIMENT, DEFAULT_AUTO_CHECKPOINT_AND_RESUME, - # Semantic Cache DEFAULT_SEMANTIC_CACHE_BACKEND, DEFAULT_SEMANTIC_CACHE_PATH, DEFAULT_SEMANTIC_CACHE_MAX_SIZE_MB, DEFAULT_SEMANTIC_CACHE_SYNCHRONOUS, - # System Constants SYSTEM_FILE_NAME_COLUMN, SYSTEM_RAW_DATA_COLUMN, @@ -61,10 +58,8 @@ SYSTEM_ERRORS_COLUMN, SYSTEM_EXTRACTED_DATA_JSON_COLUMN, SYSTEM_RANDOM_SEED, - # File and Directory Constants DATA_DIR_NAME, - CACHE_DIR_NAME, PROCESSING_CACHE_DIR_NAME, BATCH_FILE_PREFIX, BATCH_FILE_SUFFIX, @@ -76,7 +71,6 @@ PREPROCESSED_DATA_SUFFIX, META_DATA_PREFIX, META_DATA_SUFFIX, - # Utility Constants IGNORE_FILES, ) @@ -93,12 +87,10 @@ "SchemaConfig", 
"SplittingConfig", "ScoringConfig", - # Exceptions "DELMError", "ExperimentManagementError", "InstructorError", - # LLM/API Configuration "DEFAULT_PROVIDER", "DEFAULT_MODEL_NAME", @@ -109,28 +101,22 @@ "DEFAULT_BASE_DELAY", "DEFAULT_TRACK_COST", "DEFAULT_MAX_BUDGET", - "DEFAULT_DOTENV_PATH", - # Data Processing "DEFAULT_DROP_TARGET_COLUMN", "DEFAULT_PANDAS_SCORE_FILTER", - # Schema Configuration "DEFAULT_SCHEMA_PATH", "DEFAULT_PROMPT_TEMPLATE", "DEFAULT_SYSTEM_PROMPT", - # Experiment Management "DEFAULT_EXPERIMENT_DIR", "DEFAULT_OVERWRITE_EXPERIMENT", "DEFAULT_AUTO_CHECKPOINT_AND_RESUME", - # Semantic Cache "DEFAULT_SEMANTIC_CACHE_BACKEND", "DEFAULT_SEMANTIC_CACHE_PATH", "DEFAULT_SEMANTIC_CACHE_MAX_SIZE_MB", "DEFAULT_SEMANTIC_CACHE_SYNCHRONOUS", - # System Constants "SYSTEM_FILE_NAME_COLUMN", "SYSTEM_RAW_DATA_COLUMN", @@ -142,10 +128,8 @@ "SYSTEM_ERRORS_COLUMN", "SYSTEM_EXTRACTED_DATA_JSON_COLUMN", "SYSTEM_RANDOM_SEED", - # File and Directory Constants "DATA_DIR_NAME", - "CACHE_DIR_NAME", "PROCESSING_CACHE_DIR_NAME", "BATCH_FILE_PREFIX", "BATCH_FILE_SUFFIX", @@ -157,10 +141,8 @@ "PREPROCESSED_DATA_SUFFIX", "META_DATA_PREFIX", "META_DATA_SUFFIX", - # Utility Constants "IGNORE_FILES", - # Logging "configure_logging", -] \ No newline at end of file +] diff --git a/src/delm/config.py b/src/delm/config.py index ae07054..c6e4f47 100644 --- a/src/delm/config.py +++ b/src/delm/config.py @@ -12,43 +12,37 @@ from typing import Any, Dict, Optional, Union, TypeVar import yaml -T = TypeVar('T', bound='BaseConfig') +T = TypeVar("T", bound="BaseConfig") from delm.strategies import RelevanceScorer, KeywordScorer, FuzzyScorer from delm.strategies import SplitStrategy, ParagraphSplit, FixedWindowSplit, RegexSplit from delm.constants import ( # LLM/API Configuration - DEFAULT_PROVIDER, - DEFAULT_MODEL_NAME, + DEFAULT_PROVIDER, + DEFAULT_MODEL_NAME, DEFAULT_TEMPERATURE, - DEFAULT_MAX_RETRIES, - DEFAULT_BASE_DELAY, - DEFAULT_BATCH_SIZE, + DEFAULT_MAX_RETRIES, + DEFAULT_BASE_DELAY, + 
DEFAULT_BATCH_SIZE, DEFAULT_MAX_WORKERS, - DEFAULT_TRACK_COST, - DEFAULT_MAX_BUDGET, - DEFAULT_DOTENV_PATH, - + DEFAULT_TRACK_COST, + DEFAULT_MAX_BUDGET, # Data Processing # Splitting DEFAULT_FIXED_WINDOW_SIZE, DEFAULT_FIXED_WINDOW_STRIDE, DEFAULT_REGEX_PATTERN, - - DEFAULT_DROP_TARGET_COLUMN, - DEFAULT_PANDAS_SCORE_FILTER, - + DEFAULT_DROP_TARGET_COLUMN, + DEFAULT_PANDAS_SCORE_FILTER, # Schema Configuration - DEFAULT_SCHEMA_PATH, - DEFAULT_PROMPT_TEMPLATE, + DEFAULT_SCHEMA_PATH, + DEFAULT_PROMPT_TEMPLATE, DEFAULT_SYSTEM_PROMPT, - # Semantic Cache - DEFAULT_SEMANTIC_CACHE_BACKEND, + DEFAULT_SEMANTIC_CACHE_BACKEND, DEFAULT_SEMANTIC_CACHE_PATH, - DEFAULT_SEMANTIC_CACHE_MAX_SIZE_MB, + DEFAULT_SEMANTIC_CACHE_MAX_SIZE_MB, DEFAULT_SEMANTIC_CACHE_SYNCHRONOUS, - # System Constants SYSTEM_RAW_DATA_COLUMN, DEFAULT_FIXED_WINDOW_SIZE, @@ -63,18 +57,18 @@ class BaseConfig: Subclasses should implement ``validate`` and ``to_dict`` to provide strict validation and stable serialization. """ - + def validate(self): """Validate configuration. Subclasses should raise ``ValueError`` when fields are invalid. 
""" pass - + def to_dict(self) -> dict: """Convert configuration to a serializable dictionary.""" return {} - + @classmethod def from_dict(cls: type[T], data: Dict[str, Any]) -> T: """Create configuration instance from a dictionary.""" @@ -84,6 +78,7 @@ def from_dict(cls: type[T], data: Dict[str, Any]) -> T: @dataclass class LLMExtractionConfig(BaseConfig): """Configuration for the LLM extraction process.""" + provider: str = DEFAULT_PROVIDER name: str = DEFAULT_MODEL_NAME temperature: float = DEFAULT_TEMPERATURE @@ -91,7 +86,6 @@ class LLMExtractionConfig(BaseConfig): batch_size: int = DEFAULT_BATCH_SIZE max_workers: int = DEFAULT_MAX_WORKERS base_delay: float = DEFAULT_BASE_DELAY - dotenv_path: Optional[Union[str, Path]] = DEFAULT_DOTENV_PATH track_cost: bool = DEFAULT_TRACK_COST max_budget: Optional[float] = DEFAULT_MAX_BUDGET model_input_cost_per_1M_tokens: Optional[float] = None @@ -139,10 +133,6 @@ def validate(self): raise ValueError( f"base_delay must be non-negative. base_delay: {self.base_delay}, Suggestion: Use a non-negative float" ) - if self.dotenv_path is not None and not Path(self.dotenv_path).exists(): - raise ValueError( - f"dotenv_path does not exist: {self.dotenv_path}, Suggestion: Check the file path or create the .env file" - ) if not isinstance(self.track_cost, bool): raise ValueError( f"track_cost must be a boolean. 
track_cost: {self.track_cost}, Suggestion: Use True or False" @@ -166,7 +156,6 @@ def to_dict(self) -> dict: "batch_size": self.batch_size, "max_workers": self.max_workers, "base_delay": self.base_delay, - "dotenv_path": str(self.dotenv_path) if self.dotenv_path else None, "track_cost": self.track_cost, "max_budget": self.max_budget, "model_input_cost_per_1M_tokens": self.model_input_cost_per_1M_tokens, @@ -177,6 +166,7 @@ def to_dict(self) -> dict: @dataclass class SplittingConfig(BaseConfig): """Configuration for text splitting strategy.""" + strategy: Optional[SplitStrategy] = field(default=None) def validate(self): @@ -226,12 +216,15 @@ def _create_strategy(cfg: Dict[str, Any]) -> Optional[SplitStrategy]: """ if cfg == {} or cfg is None: return None - + split_type = cfg.get("type", None) if split_type == "ParagraphSplit": return ParagraphSplit() elif split_type == "FixedWindowSplit": - return FixedWindowSplit(cfg.get("window", DEFAULT_FIXED_WINDOW_SIZE), cfg.get("stride", DEFAULT_FIXED_WINDOW_STRIDE)) + return FixedWindowSplit( + cfg.get("window", DEFAULT_FIXED_WINDOW_SIZE), + cfg.get("stride", DEFAULT_FIXED_WINDOW_STRIDE), + ) elif split_type == "RegexSplit": return RegexSplit(cfg.get("pattern", DEFAULT_REGEX_PATTERN)) elif split_type in ("None", None): @@ -239,13 +232,17 @@ def _create_strategy(cfg: Dict[str, Any]) -> Optional[SplitStrategy]: else: raise ValueError( f"Unknown split strategy: {split_type}", - {"split_type": split_type, "suggestion": "Use 'ParagraphSplit', 'FixedWindowSplit', 'RegexSplit', or 'None'"} + { + "split_type": split_type, + "suggestion": "Use 'ParagraphSplit', 'FixedWindowSplit', 'RegexSplit', or 'None'", + }, ) @dataclass class ScoringConfig(BaseConfig): """Configuration for relevance scoring strategy.""" + scorer: Optional[RelevanceScorer] = field(default=None) def validate(self): @@ -291,7 +288,7 @@ def _create_scorer(cfg: Dict[str, Any]) -> Optional[RelevanceScorer]: """ if cfg == {} or cfg is None: return None - + scorer_type = 
cfg.get("type", None) if scorer_type == "KeywordScorer": keywords = cfg.get("keywords", []) @@ -318,10 +315,15 @@ def _create_scorer(cfg: Dict[str, Any]) -> Optional[RelevanceScorer]: @dataclass class DataPreprocessingConfig(BaseConfig): """Configuration for the data preprocessing pipeline.""" + target_column: str = SYSTEM_RAW_DATA_COLUMN drop_target_column: bool = DEFAULT_DROP_TARGET_COLUMN - splitting: SplittingConfig = field(default_factory=SplittingConfig) # use default factory because these types are mutable - scoring: ScoringConfig = field(default_factory=ScoringConfig) # use default factory because these types are mutable + splitting: SplittingConfig = field( + default_factory=SplittingConfig + ) # use default factory because these types are mutable + scoring: ScoringConfig = field( + default_factory=ScoringConfig + ) # use default factory because these types are mutable pandas_score_filter: Optional[str] = DEFAULT_PANDAS_SCORE_FILTER preprocessed_data_path: Optional[str] = None _explicitly_set_fields: set = field(default_factory=set, init=False) @@ -337,7 +339,7 @@ def validate(self): self._validate_preprocessed_data_path() self._validate_no_conflicts_with_preprocessed_data() return - + self._validate_basic_fields() self.splitting.validate() self.scoring.validate() @@ -350,18 +352,22 @@ def _validate_preprocessed_data_path(self): """ if self.preprocessed_data_path is None: return - + if not self.preprocessed_data_path.endswith(".feather"): raise ValueError( f"preprocessed_data_path must be a feather file. 
preprocessed_data_path: {self.preprocessed_data_path}, Suggestion: Provide a valid feather file path" ) - + # Verify file has correct columns import pandas as pd from .constants import SYSTEM_CHUNK_COLUMN, SYSTEM_CHUNK_ID_COLUMN + try: df = pd.read_feather(self.preprocessed_data_path) - if not all(col in df.columns for col in [SYSTEM_CHUNK_COLUMN, SYSTEM_CHUNK_ID_COLUMN]): + if not all( + col in df.columns + for col in [SYSTEM_CHUNK_COLUMN, SYSTEM_CHUNK_ID_COLUMN] + ): raise ValueError( f"preprocessed_data_path must have the correct columns. preprocessed_data_path: {self.preprocessed_data_path}, Suggestion: Provide a valid feather file path with the correct columns" ) @@ -387,7 +393,7 @@ def _validate_no_conflicts_with_preprocessed_data(self): conflicting.append("splitting") if self.scoring.scorer is not None: conflicting.append("scoring") - + if conflicting: raise ValueError( f"Cannot specify {', '.join(conflicting)} when preprocessed_data_path is set. preprocessed_data_path: {self.preprocessed_data_path}, Suggestion: Remove other data fields when using preprocessed_data_path." 
@@ -415,6 +421,7 @@ def _validate_basic_fields(self): # Validate pandas query syntax import pandas as pd from .constants import SYSTEM_SCORE_COLUMN + try: pd.DataFrame({SYSTEM_SCORE_COLUMN: [1]}).query(self.pandas_score_filter) except Exception as e: @@ -430,7 +437,7 @@ def to_dict(self) -> dict: """ if self.preprocessed_data_path: return {"preprocessed_data_path": self.preprocessed_data_path} - + return { "target_column": self.target_column, "drop_target_column": self.drop_target_column, @@ -454,13 +461,17 @@ def from_dict(cls, data: Dict[str, Any]) -> "DataPreprocessingConfig": """ # Track explicitly set fields explicitly_set_fields = set(data.keys()) - + instance = cls( target_column=data.get("target_column", SYSTEM_RAW_DATA_COLUMN), - drop_target_column=data.get("drop_target_column", DEFAULT_DROP_TARGET_COLUMN), + drop_target_column=data.get( + "drop_target_column", DEFAULT_DROP_TARGET_COLUMN + ), splitting=SplittingConfig.from_dict(data.get("splitting", {})), scoring=ScoringConfig.from_dict(data.get("scoring", {})), - pandas_score_filter=data.get("pandas_score_filter", DEFAULT_PANDAS_SCORE_FILTER), + pandas_score_filter=data.get( + "pandas_score_filter", DEFAULT_PANDAS_SCORE_FILTER + ), preprocessed_data_path=data.get("preprocessed_data_path", None), ) instance._explicitly_set_fields = explicitly_set_fields @@ -478,6 +489,7 @@ class SchemaConfig(BaseConfig): The actual schema definition (including container_name) is stored in the separate schema_spec.yaml file. 
""" + spec_path: Optional[Union[str, Path]] = DEFAULT_SCHEMA_PATH prompt_template: str = DEFAULT_PROMPT_TEMPLATE system_prompt: str = DEFAULT_SYSTEM_PROMPT @@ -522,21 +534,22 @@ def from_dict(cls, data: Dict[str, Any]) -> "SchemaConfig": """Construct a ``SchemaConfig`` from a mapping.""" if data is None: data = {} - + spec_path = data.get("spec_path", "") if isinstance(spec_path, str): spec_path = Path(spec_path) - + return cls( spec_path=spec_path, prompt_template=data.get("prompt_template", DEFAULT_PROMPT_TEMPLATE), - system_prompt=data.get("system_prompt", DEFAULT_SYSTEM_PROMPT) + system_prompt=data.get("system_prompt", DEFAULT_SYSTEM_PROMPT), ) @dataclass class SemanticCacheConfig(BaseConfig): """Persistent semantic‑cache settings.""" + backend: str = DEFAULT_SEMANTIC_CACHE_BACKEND path: Union[str, Path] = DEFAULT_SEMANTIC_CACHE_PATH max_size_mb: int = DEFAULT_SEMANTIC_CACHE_MAX_SIZE_MB @@ -579,7 +592,7 @@ def from_dict(cls, data: Dict[str, Any]) -> "SemanticCacheConfig": """Construct a ``SemanticCacheConfig`` from a mapping.""" if data is None: data = {} - + return cls( backend=data.get("backend", DEFAULT_SEMANTIC_CACHE_BACKEND), path=data.get("path", DEFAULT_SEMANTIC_CACHE_PATH), @@ -600,6 +613,7 @@ class DELMConfig(BaseConfig): - A single pipeline config file (config.yaml) that references a schema file - Separate pipeline config and schema spec files """ + llm_extraction: LLMExtractionConfig data_preprocessing: DataPreprocessingConfig schema: SchemaConfig @@ -625,46 +639,50 @@ def to_serialized_schema_spec_dict(self) -> dict: """Load and return the schema spec as a dictionary (schema_spec.yaml).""" import yaml import json - + path = self.schema.spec_path if path is None: raise ValueError("Schema spec path is None") - + if isinstance(path, str): path = Path(path) - + if not path.exists(): raise FileNotFoundError(f"Schema spec file does not exist: {path}") - + if path.suffix.lower() in {".yml", ".yaml"}: return yaml.safe_load(path.read_text()) or {} elif 
path.suffix.lower() == ".json": return json.loads(path.read_text()) else: raise ValueError(f"Unsupported schema file format: {path.suffix}") - # Backward compatibility aliases def to_dict(self) -> dict: """Alias for ``to_serialized_config_dict`` for backward compatibility.""" return self.to_serialized_config_dict() - @classmethod def from_dict(cls, data: Dict[str, Any]) -> "DELMConfig": """Create ``DELMConfig`` from a mapping.""" if data is None: data = {} - + return cls( - llm_extraction=LLMExtractionConfig.from_dict(data.get("llm_extraction", {})), - data_preprocessing=DataPreprocessingConfig.from_dict(data.get("data_preprocessing", {})), + llm_extraction=LLMExtractionConfig.from_dict( + data.get("llm_extraction", {}) + ), + data_preprocessing=DataPreprocessingConfig.from_dict( + data.get("data_preprocessing", {}) + ), schema=SchemaConfig.from_dict(data.get("schema", {})), - semantic_cache=SemanticCacheConfig.from_dict(data.get("semantic_cache", {})), + semantic_cache=SemanticCacheConfig.from_dict( + data.get("semantic_cache", {}) + ), ) @classmethod - def from_yaml(cls, path: Path) -> "DELMConfig": + def from_yaml(cls, path: Union[str, Path]) -> "DELMConfig": """Create ``DELMConfig`` from a pipeline config YAML file. Args: @@ -676,14 +694,14 @@ def from_yaml(cls, path: Path) -> "DELMConfig": Raises: FileNotFoundError: If the file does not exist. """ + if isinstance(path, str): + path = Path(path) if not path.exists(): - raise FileNotFoundError( - f"YAML config file does not exist: {path}" - ) - - with open(path, "r") as f: + raise FileNotFoundError(f"YAML config file does not exist: {path}") + + with path.open("r") as f: data = yaml.safe_load(f) - + return cls.from_dict(data) @staticmethod @@ -708,4 +726,6 @@ def from_any( elif isinstance(config_like, dict): return DELMConfig.from_dict(config_like) else: - raise ValueError(f"config must be a DELMConfig, dict, or path to YAML. 
config_type: {type(config_like).__name__}") \ No newline at end of file + raise ValueError( + f"config must be a DELMConfig, dict, or path to YAML. config_type: {type(config_like).__name__}" + ) diff --git a/src/delm/constants.py b/src/delm/constants.py index 1418a0a..a923175 100644 --- a/src/delm/constants.py +++ b/src/delm/constants.py @@ -13,24 +13,21 @@ # ============================================================================= # Provider and Model Settings -DEFAULT_PROVIDER = "openai" # LLM provider (openai, anthropic, google, etc.) -DEFAULT_MODEL_NAME = "gpt-4o-mini" # LLM model name -DEFAULT_TEMPERATURE = 0.0 # Temperature for LLM responses (0.0 = deterministic) +DEFAULT_PROVIDER = "openai" # LLM provider (openai, anthropic, google, etc.) +DEFAULT_MODEL_NAME = "gpt-4o-mini" # LLM model name +DEFAULT_TEMPERATURE = 0.0 # Temperature for LLM responses (0.0 = deterministic) # API Request Settings -DEFAULT_MAX_RETRIES = 3 # Maximum retry attempts for failed API calls -DEFAULT_BASE_DELAY = 1.0 # Base delay between retries (seconds) +DEFAULT_MAX_RETRIES = 3 # Maximum retry attempts for failed API calls +DEFAULT_BASE_DELAY = 1.0 # Base delay between retries (seconds) # Processing Settings -DEFAULT_BATCH_SIZE = 10 # Number of records to process in each batch -DEFAULT_MAX_WORKERS = 1 # Number of concurrent worker processes +DEFAULT_BATCH_SIZE = 10 # Number of records to process in each batch +DEFAULT_MAX_WORKERS = 1 # Number of concurrent worker processes # Cost and Budget Settings -DEFAULT_TRACK_COST = True # Whether to track API call costs -DEFAULT_MAX_BUDGET = None # Maximum budget limit (None = no limit) - -# Environment Settings -DEFAULT_DOTENV_PATH = None # Path to .env file +DEFAULT_TRACK_COST = True # Whether to track API call costs +DEFAULT_MAX_BUDGET = None # Maximum budget limit (None = no limit) # ============================================================================= # DATA PROCESSING DEFAULTS @@ -38,14 +35,16 @@ ## Splitting Defaults # 
FixedWindowSplit -DEFAULT_FIXED_WINDOW_SIZE = 5 # Number of sentences per chunk -DEFAULT_FIXED_WINDOW_STRIDE = 5 # Number of sentences to overlap +DEFAULT_FIXED_WINDOW_SIZE = 5 # Number of sentences per chunk +DEFAULT_FIXED_WINDOW_STRIDE = 5 # Number of sentences to overlap # RegexSplit -DEFAULT_REGEX_PATTERN = "\n\n" # Regex pattern to split on +DEFAULT_REGEX_PATTERN = "\n\n" # Regex pattern to split on # Column and Data Settings -DEFAULT_DROP_TARGET_COLUMN = False # Whether to drop the target column after processing -DEFAULT_PANDAS_SCORE_FILTER = None # Pandas query string for filtering by score (None = no filter) +DEFAULT_DROP_TARGET_COLUMN = False # Whether to drop the target column after processing +DEFAULT_PANDAS_SCORE_FILTER = ( + None # Pandas query string for filtering by score (None = no filter) +) # Extraction Settings DEFAULT_EXPLODE_JSON_RESULTS = False # Whether to convert extracted JSON to DataFrame @@ -55,7 +54,7 @@ # ============================================================================= # Schema File Settings -DEFAULT_SCHEMA_PATH = None # Default path to schema specification file +DEFAULT_SCHEMA_PATH = None # Default path to schema specification file # Prompt Settings DEFAULT_PROMPT_TEMPLATE = """Extract the following information from the text: @@ -73,19 +72,25 @@ # EXPERIMENT MANAGEMENT DEFAULTS # ============================================================================= -DEFAULT_EXPERIMENT_DIR = Path("delm_experiments") # Default directory for experiment outputs -DEFAULT_OVERWRITE_EXPERIMENT = False # Whether to overwrite existing experiments -DEFAULT_AUTO_CHECKPOINT_AND_RESUME = True # Whether to automatically checkpoint and resume +DEFAULT_EXPERIMENT_DIR = Path( + "delm_experiments" +) # Default directory for experiment outputs +DEFAULT_OVERWRITE_EXPERIMENT = False # Whether to overwrite existing experiments +DEFAULT_AUTO_CHECKPOINT_AND_RESUME = ( + True # Whether to automatically checkpoint and resume +) # 
============================================================================= # SEMANTIC CACHE DEFAULTS # ============================================================================= # Cache Backend Settings -DEFAULT_SEMANTIC_CACHE_BACKEND = "sqlite" # Cache backend: "sqlite" | "lmdb" | "filesystem" -DEFAULT_SEMANTIC_CACHE_PATH = ".delm_cache" # Cache directory path -DEFAULT_SEMANTIC_CACHE_MAX_SIZE_MB = 512 # Maximum cache size before pruning -DEFAULT_SEMANTIC_CACHE_SYNCHRONOUS = "normal" # SQLite sync mode: "normal" | "full" +DEFAULT_SEMANTIC_CACHE_BACKEND = ( + "sqlite" # Cache backend: "sqlite" | "lmdb" | "filesystem" +) +DEFAULT_SEMANTIC_CACHE_PATH = ".delm_cache" # Cache directory path +DEFAULT_SEMANTIC_CACHE_MAX_SIZE_MB = 512 # Maximum cache size before pruning +DEFAULT_SEMANTIC_CACHE_SYNCHRONOUS = "normal" # SQLite sync mode: "normal" | "full" # ============================================================================= # SYSTEM CONSTANTS (Internal Use Only) @@ -94,58 +99,61 @@ # They should NOT be used in user data or configuration. 
# System Column Names -SYSTEM_FILE_NAME_COLUMN = "delm_file_name" # Column for source file names -SYSTEM_RAW_DATA_COLUMN = "delm_raw_data" # Column for original text data -SYSTEM_RECORD_ID_COLUMN = "delm_record_id" # Column for internal unique record IDs -SYSTEM_CHUNK_COLUMN = "delm_text_chunk" # Column for text chunks -SYSTEM_CHUNK_ID_COLUMN = "delm_chunk_id" # Column for internal chunk IDs -SYSTEM_SCORE_COLUMN = "delm_score" # Column for relevance scores -SYSTEM_BATCH_ID_COLUMN = "delm_batch_id" # Column for batch IDs -SYSTEM_ERRORS_COLUMN = "delm_errors" # Column for error messages +SYSTEM_FILE_NAME_COLUMN = "delm_file_name" # Column for source file names +SYSTEM_RAW_DATA_COLUMN = "delm_raw_data" # Column for original text data +SYSTEM_RECORD_ID_COLUMN = "delm_record_id" # Column for internal unique record IDs +SYSTEM_CHUNK_COLUMN = "delm_text_chunk" # Column for text chunks +SYSTEM_CHUNK_ID_COLUMN = "delm_chunk_id" # Column for internal chunk IDs +SYSTEM_SCORE_COLUMN = "delm_score" # Column for relevance scores +SYSTEM_BATCH_ID_COLUMN = "delm_batch_id" # Column for batch IDs +SYSTEM_ERRORS_COLUMN = "delm_errors" # Column for error messages # Data Storage Columns -SYSTEM_EXTRACTED_DATA_JSON_COLUMN = "delm_extracted_data_json" # Column for extracted JSON data +SYSTEM_EXTRACTED_DATA_JSON_COLUMN = ( + "delm_extracted_data_json" # Column for extracted JSON data +) # System Behavior Constants -SYSTEM_RANDOM_SEED = 42 # Random seed for reproducibility +SYSTEM_RANDOM_SEED = 42 # Random seed for reproducibility # ============================================================================= # FILE AND DIRECTORY CONSTANTS # ============================================================================= # Directory Names -DATA_DIR_NAME = "delm_data" # Name of data directory -CACHE_DIR_NAME = ".delm_cache" # Name of cache directory -PROCESSING_CACHE_DIR_NAME = "llm_processing" # Name of processing cache subdirectory +DATA_DIR_NAME = "delm_data" # Name of data directory 
+PROCESSING_CACHE_DIR_NAME = ( + "delm_llm_processing" # Name of processing cache subdirectory +) # File Naming Patterns -BATCH_FILE_PREFIX = "batch_" # Prefix for batch files -BATCH_FILE_SUFFIX = ".feather" # Suffix for batch files -BATCH_FILE_DIGITS = 6 # Number of digits in batch file names +BATCH_FILE_PREFIX = "batch_" # Prefix for batch files +BATCH_FILE_SUFFIX = ".feather" # Suffix for batch files +BATCH_FILE_DIGITS = 6 # Number of digits in batch file names # State and Result Files -STATE_FILE_NAME = "state.json" # Name of state file -CONSOLIDATED_RESULT_PREFIX = "extraction_result_" # Prefix for consolidated results -CONSOLIDATED_RESULT_SUFFIX = ".feather" # Suffix for consolidated results +STATE_FILE_NAME = "state.json" # Name of state file +CONSOLIDATED_RESULT_PREFIX = "extraction_result_" # Prefix for consolidated results +CONSOLIDATED_RESULT_SUFFIX = ".feather" # Suffix for consolidated results # Preprocessed Data Files -PREPROCESSED_DATA_PREFIX = "preprocessed_" # Prefix for preprocessed data files -PREPROCESSED_DATA_SUFFIX = ".feather" # Suffix for preprocessed data files +PREPROCESSED_DATA_PREFIX = "preprocessed_" # Prefix for preprocessed data files +PREPROCESSED_DATA_SUFFIX = ".feather" # Suffix for preprocessed data files # Metadata Files -META_DATA_PREFIX = "meta_data_" # Prefix for metadata files -META_DATA_SUFFIX = ".feather" # Suffix for metadata files +META_DATA_PREFIX = "meta_data_" # Prefix for metadata files +META_DATA_SUFFIX = ".feather" # Suffix for metadata files # ============================================================================= # LOGGING CONSTANTS # ============================================================================= # Logging Settings -DEFAULT_LOG_DIR = "delm_logs" # Default directory for log files -SYSTEM_LOG_FILE_PREFIX = "delm_" # Default prefix for log files -SYSTEM_LOG_FILE_SUFFIX = ".log" # Default suffix for log files -DEFAULT_CONSOLE_LOG_LEVEL = "INFO" # Default console log level -DEFAULT_FILE_LOG_LEVEL = 
"DEBUG" # Default file log level +DEFAULT_LOG_DIR = "delm_logs" # Default directory for log files +SYSTEM_LOG_FILE_PREFIX = "delm_" # Default prefix for log files +SYSTEM_LOG_FILE_SUFFIX = ".log" # Default suffix for log files +DEFAULT_CONSOLE_LOG_LEVEL = "INFO" # Default console log level +DEFAULT_FILE_LOG_LEVEL = "DEBUG" # Default file log level # ============================================================================= # UTILITY CONSTANTS @@ -153,7 +161,7 @@ # Files to Ignore IGNORE_FILES = [ - ".DS_Store", # macOS system files + ".DS_Store", # macOS system files ] LLM_NULL_WORDS_LOWERCASE = [ @@ -162,4 +170,4 @@ "unknown", "n/a", "", -] \ No newline at end of file +] diff --git a/src/delm/core/experiment_manager.py b/src/delm/core/experiment_manager.py index 1adcb02..8332004 100644 --- a/src/delm/core/experiment_manager.py +++ b/src/delm/core/experiment_manager.py @@ -21,7 +21,6 @@ from delm.config import DELMConfig from delm.constants import ( DATA_DIR_NAME, - CACHE_DIR_NAME, PROCESSING_CACHE_DIR_NAME, BATCH_FILE_PREFIX, BATCH_FILE_SUFFIX, @@ -242,7 +241,7 @@ def data_dir(self) -> Path: @property def cache_dir(self) -> Path: - d = self.experiment_dir / CACHE_DIR_NAME / PROCESSING_CACHE_DIR_NAME + d = self.experiment_dir / PROCESSING_CACHE_DIR_NAME d.mkdir(parents=True, exist_ok=True) return d diff --git a/src/delm/delm.py b/src/delm/delm.py index c0fced3..cd1ddf9 100644 --- a/src/delm/delm.py +++ b/src/delm/delm.py @@ -6,7 +6,6 @@ import logging import time from pathlib import Path -import dotenv import pandas as pd # Module-level logger @@ -319,19 +318,30 @@ def get_cost_summary(self) -> dict[str, Any]: log.debug("Cost summary retrieved: %s", cost_summary) return cost_summary + def preview_prompt( + self, + text: Optional[str] = None, + ) -> str: + """Preview the compiled prompt for the extraction schema. + + Returns: + A string containing the compiled prompt. 
+ """ + target_column_name = self.config.data_preprocessing.target_column + if text is None: + text = f"<{target_column_name}>" + prompt = self.schema_manager.extraction_schema.create_prompt( + text=text, + prompt_template=self.schema_manager.prompt_template, + ) + return prompt + ## ------------------------------ Private API ------------------------------- ## def _initialize_components(self) -> None: """Initialize all components using composition.""" log.debug("Initializing DELM components") - # Environment & secrets -------------------------------------------- # - if self.config.llm_extraction.dotenv_path: - log.debug( - "Loading environment from %s", self.config.llm_extraction.dotenv_path - ) - dotenv.load_dotenv(self.config.llm_extraction.dotenv_path) - # Initialize components log.debug("Initializing data processor") self.data_processor = DataProcessor(self.config.data_preprocessing) diff --git a/tests/unit/config/test_config.py b/tests/unit/config/test_config.py index 5c0c4d9..02470b8 100644 --- a/tests/unit/config/test_config.py +++ b/tests/unit/config/test_config.py @@ -38,7 +38,6 @@ DEFAULT_BASE_DELAY, DEFAULT_TRACK_COST, DEFAULT_MAX_BUDGET, - DEFAULT_DOTENV_PATH, DEFAULT_FIXED_WINDOW_SIZE, DEFAULT_FIXED_WINDOW_STRIDE, DEFAULT_REGEX_PATTERN, @@ -58,9 +57,6 @@ ) - - - class TestLLMExtractionConfig: """Test LLM extraction configuration.""" @@ -74,7 +70,6 @@ def test_initialization_defaults(self): assert config.batch_size == DEFAULT_BATCH_SIZE assert config.max_workers == DEFAULT_MAX_WORKERS assert config.base_delay == DEFAULT_BASE_DELAY - assert config.dotenv_path == DEFAULT_DOTENV_PATH assert config.track_cost == DEFAULT_TRACK_COST assert config.max_budget == DEFAULT_MAX_BUDGET @@ -183,7 +178,9 @@ def test_validate_invalid_track_cost(self): def test_validate_max_budget_without_track_cost(self): """Test validation when max_budget is set but track_cost is False.""" config = LLMExtractionConfig(track_cost=False, max_budget=100.0) - with 
pytest.raises(ValueError, match="track_cost must be True if max_budget is specified"): + with pytest.raises( + ValueError, match="track_cost must be True if max_budget is specified" + ): config.validate() def test_validate_invalid_max_budget(self): @@ -216,7 +213,6 @@ def test_to_dict(self): "batch_size": 5, "max_workers": 2, "base_delay": 0.5, - "dotenv_path": None, "track_cost": True, "max_budget": 50.0, "model_input_cost_per_1M_tokens": 10.0, @@ -274,7 +270,9 @@ def test_validate_valid_config(self): def test_validate_invalid_strategy(self): """Test validation with invalid strategy.""" config = SplittingConfig(strategy="invalid") - with pytest.raises(ValueError, match="strategy must be a SplitStrategy instance"): + with pytest.raises( + ValueError, match="strategy must be a SplitStrategy instance" + ): config.validate() def test_to_dict_no_strategy(self): @@ -379,7 +377,9 @@ def test_validate_valid_config(self): def test_validate_invalid_scorer(self): """Test validation with invalid scorer.""" config = ScoringConfig(scorer="invalid") - with pytest.raises(ValueError, match="scorer must be a RelevanceScorer instance"): + with pytest.raises( + ValueError, match="scorer must be a RelevanceScorer instance" + ): config.validate() def test_to_dict_no_scorer(self): @@ -423,7 +423,9 @@ def test_from_dict_keyword_scorer_empty_keywords(self): "type": "KeywordScorer", "keywords": [], } - with pytest.raises(ValueError, match="KeywordScorer requires a non-empty keywords list"): + with pytest.raises( + ValueError, match="KeywordScorer requires a non-empty keywords list" + ): ScoringConfig.from_dict(data) def test_from_dict_fuzzy_scorer(self): @@ -442,7 +444,9 @@ def test_from_dict_fuzzy_scorer_empty_keywords(self): "type": "FuzzyScorer", "keywords": [], } - with pytest.raises(ValueError, match="FuzzyScorer requires a non-empty keywords list"): + with pytest.raises( + ValueError, match="FuzzyScorer requires a non-empty keywords list" + ): ScoringConfig.from_dict(data) def 
test_from_dict_unknown_scorer(self): @@ -486,11 +490,15 @@ def test_validate_valid_config(self): def test_validate_invalid_target_column(self): """Test validation with invalid target_column.""" config = DataPreprocessingConfig(target_column="") - with pytest.raises(ValueError, match="target_column must be a non-empty string"): + with pytest.raises( + ValueError, match="target_column must be a non-empty string" + ): config.validate() config = DataPreprocessingConfig(target_column=123) - with pytest.raises(ValueError, match="target_column must be a non-empty string"): + with pytest.raises( + ValueError, match="target_column must be a non-empty string" + ): config.validate() def test_validate_invalid_drop_target_column(self): @@ -502,73 +510,90 @@ def test_validate_invalid_drop_target_column(self): def test_validate_invalid_pandas_score_filter(self): """Test validation with invalid pandas_score_filter.""" config = DataPreprocessingConfig(pandas_score_filter=123) - with pytest.raises(ValueError, match="pandas_score_filter must be a string or None"): + with pytest.raises( + ValueError, match="pandas_score_filter must be a string or None" + ): config.validate() def test_validate_invalid_pandas_query(self): """Test validation with invalid pandas query.""" config = DataPreprocessingConfig(pandas_score_filter="invalid query") - with pytest.raises(ValueError, match="pandas_score_filter is not a valid pandas query"): + with pytest.raises( + ValueError, match="pandas_score_filter is not a valid pandas query" + ): config.validate() def test_validate_valid_pandas_query(self): """Test validation with valid pandas query.""" - config = DataPreprocessingConfig(pandas_score_filter=f"{SYSTEM_SCORE_COLUMN} > 0.5") + config = DataPreprocessingConfig( + pandas_score_filter=f"{SYSTEM_SCORE_COLUMN} > 0.5" + ) config.validate() # Should not raise def test_validate_preprocessed_data_path_not_feather(self): """Test validation with non-feather preprocessed data path.""" config = 
DataPreprocessingConfig(preprocessed_data_path="data.csv") - with pytest.raises(ValueError, match="preprocessed_data_path must be a feather file"): + with pytest.raises( + ValueError, match="preprocessed_data_path must be a feather file" + ): config.validate() def test_validate_preprocessed_data_path_missing_columns(self): """Test validation with preprocessed data missing required columns.""" # Create a temporary feather file with wrong columns import pandas as pd - + with tempfile.NamedTemporaryFile(suffix=".feather", delete=False) as f: temp_path = f.name - + try: # Create a DataFrame with wrong columns and save as feather df = pd.DataFrame({"wrong_column": [1]}) df.to_feather(temp_path) - + config = DataPreprocessingConfig(preprocessed_data_path=temp_path) - with pytest.raises(ValueError, match="Failed to read preprocessed data file"): + with pytest.raises( + ValueError, match="Failed to read preprocessed data file" + ): config.validate() finally: Path(temp_path).unlink() - @patch('pandas.read_feather') + @patch("pandas.read_feather") def test_validate_preprocessed_data_path_valid(self, mock_read_feather): """Test validation with valid preprocessed data path.""" - mock_df = pd.DataFrame({ - SYSTEM_CHUNK_COLUMN: ["chunk1"], - SYSTEM_CHUNK_ID_COLUMN: [1], - }) + mock_df = pd.DataFrame( + { + SYSTEM_CHUNK_COLUMN: ["chunk1"], + SYSTEM_CHUNK_ID_COLUMN: [1], + } + ) mock_read_feather.return_value = mock_df - + config = DataPreprocessingConfig(preprocessed_data_path="data.feather") config.validate() # Should not raise - @patch('pandas.read_feather') + @patch("pandas.read_feather") def test_validate_preprocessed_data_conflicts(self, mock_read_feather): """Test validation when preprocessed data conflicts with other settings.""" - mock_df = pd.DataFrame({ - SYSTEM_CHUNK_COLUMN: ["chunk1"], - SYSTEM_CHUNK_ID_COLUMN: [1], - }) + mock_df = pd.DataFrame( + { + SYSTEM_CHUNK_COLUMN: ["chunk1"], + SYSTEM_CHUNK_ID_COLUMN: [1], + } + ) mock_read_feather.return_value = mock_df - + 
config = DataPreprocessingConfig( preprocessed_data_path="data.feather", target_column="custom_column", ) config._explicitly_set_fields = {"target_column"} - - with pytest.raises(ValueError, match="Cannot specify target_column when preprocessed_data_path is set"): + + with pytest.raises( + ValueError, + match="Cannot specify target_column when preprocessed_data_path is set", + ): config.validate() def test_to_dict_with_preprocessed_data(self): @@ -638,7 +663,7 @@ def test_validate_valid_config(self): with tempfile.NamedTemporaryFile(suffix=".yaml", delete=False) as f: f.write(b"test: data") temp_path = f.name - + try: config = SchemaConfig(spec_path=temp_path) config.validate() # Should not raise @@ -648,11 +673,15 @@ def test_validate_valid_config(self): def test_validate_invalid_spec_path(self): """Test validation with invalid spec_path.""" config = SchemaConfig(spec_path="") - with pytest.raises(ValueError, match="spec_path must be a valid Path or string"): + with pytest.raises( + ValueError, match="spec_path must be a valid Path or string" + ): config.validate() config = SchemaConfig(spec_path=123) - with pytest.raises(ValueError, match="spec_path must be a valid Path or string"): + with pytest.raises( + ValueError, match="spec_path must be a valid Path or string" + ): config.validate() def test_validate_nonexistent_file(self): @@ -666,7 +695,7 @@ def test_validate_invalid_prompt_template(self): with tempfile.NamedTemporaryFile(suffix=".yaml", delete=False) as f: f.write(b"test: data") temp_path = f.name - + try: config = SchemaConfig(spec_path=temp_path, prompt_template=123) with pytest.raises(ValueError, match="prompt_template must be a string"): @@ -679,7 +708,7 @@ def test_validate_invalid_system_prompt(self): with tempfile.NamedTemporaryFile(suffix=".yaml", delete=False) as f: f.write(b"test: data") temp_path = f.name - + try: config = SchemaConfig(spec_path=temp_path, system_prompt=123) with pytest.raises(ValueError, match="system_prompt must be a 
string"): @@ -761,27 +790,37 @@ def test_validate_valid_config(self): def test_validate_invalid_backend(self): """Test validation with invalid backend.""" config = SemanticCacheConfig(backend="invalid") - with pytest.raises(ValueError, match="cache.backend must be 'sqlite', 'lmdb', or 'filesystem'"): + with pytest.raises( + ValueError, match="cache.backend must be 'sqlite', 'lmdb', or 'filesystem'" + ): config.validate() def test_validate_invalid_max_size_mb(self): """Test validation with invalid max_size_mb.""" config = SemanticCacheConfig(max_size_mb=0) - with pytest.raises(ValueError, match="cache.max_size_mb must be a positive integer"): + with pytest.raises( + ValueError, match="cache.max_size_mb must be a positive integer" + ): config.validate() config = SemanticCacheConfig(max_size_mb=-1) - with pytest.raises(ValueError, match="cache.max_size_mb must be a positive integer"): + with pytest.raises( + ValueError, match="cache.max_size_mb must be a positive integer" + ): config.validate() config = SemanticCacheConfig(max_size_mb="100") - with pytest.raises(ValueError, match="cache.max_size_mb must be a positive integer"): + with pytest.raises( + ValueError, match="cache.max_size_mb must be a positive integer" + ): config.validate() def test_validate_invalid_synchronous_sqlite(self): """Test validation with invalid synchronous for SQLite.""" config = SemanticCacheConfig(backend="sqlite", synchronous="invalid") - with pytest.raises(ValueError, match="cache.synchronous must be 'normal' or 'full' for SQLite"): + with pytest.raises( + ValueError, match="cache.synchronous must be 'normal' or 'full' for SQLite" + ): config.validate() def test_validate_valid_synchronous_sqlite(self): @@ -841,14 +880,14 @@ def test_initialization(self): data_config = DataPreprocessingConfig() schema_config = SchemaConfig() cache_config = SemanticCacheConfig() - + config = DELMConfig( llm_extraction=llm_config, data_preprocessing=data_config, schema=schema_config, 
semantic_cache=cache_config, ) - + assert config.llm_extraction == llm_config assert config.data_preprocessing == data_config assert config.schema == schema_config @@ -860,17 +899,17 @@ def test_validate(self): data_config = DataPreprocessingConfig() schema_config = SchemaConfig() cache_config = SemanticCacheConfig() - + config = DELMConfig( llm_extraction=llm_config, data_preprocessing=data_config, schema=schema_config, semantic_cache=cache_config, ) - + # Should not raise if all sub-configs are valid # Note: This will fail if schema.spec_path doesn't exist, so we'll mock it - with patch.object(schema_config, 'validate'): + with patch.object(schema_config, "validate"): config.validate() def test_to_serialized_config_dict(self): @@ -879,14 +918,14 @@ def test_to_serialized_config_dict(self): data_config = DataPreprocessingConfig(target_column="custom_column") schema_config = SchemaConfig(spec_path="test.yaml") cache_config = SemanticCacheConfig(backend="sqlite") - + config = DELMConfig( llm_extraction=llm_config, data_preprocessing=data_config, schema=schema_config, semantic_cache=cache_config, ) - + result = config.to_serialized_config_dict() expected = { "llm_extraction": llm_config.to_dict(), @@ -899,11 +938,11 @@ def test_to_serialized_config_dict(self): def test_to_serialized_schema_spec_dict_yaml(self): """Test to_serialized_schema_spec_dict with YAML file.""" schema_data = {"test": "data", "nested": {"key": "value"}} - - with tempfile.NamedTemporaryFile(suffix=".yaml", mode='w', delete=False) as f: + + with tempfile.NamedTemporaryFile(suffix=".yaml", mode="w", delete=False) as f: yaml.dump(schema_data, f) temp_path = f.name - + try: schema_config = SchemaConfig(spec_path=temp_path) config = DELMConfig( @@ -912,7 +951,7 @@ def test_to_serialized_schema_spec_dict_yaml(self): schema=schema_config, semantic_cache=SemanticCacheConfig(), ) - + result = config.to_serialized_schema_spec_dict() assert result == schema_data finally: @@ -921,12 +960,13 @@ def 
test_to_serialized_schema_spec_dict_yaml(self): def test_to_serialized_schema_spec_dict_json(self): """Test to_serialized_schema_spec_dict with JSON file.""" import json + schema_data = {"test": "data", "nested": {"key": "value"}} - - with tempfile.NamedTemporaryFile(suffix=".json", mode='w', delete=False) as f: + + with tempfile.NamedTemporaryFile(suffix=".json", mode="w", delete=False) as f: json.dump(schema_data, f) temp_path = f.name - + try: schema_config = SchemaConfig(spec_path=temp_path) config = DELMConfig( @@ -935,7 +975,7 @@ def test_to_serialized_schema_spec_dict_json(self): schema=schema_config, semantic_cache=SemanticCacheConfig(), ) - + result = config.to_serialized_schema_spec_dict() assert result == schema_data finally: @@ -950,7 +990,7 @@ def test_to_serialized_schema_spec_dict_none_path(self): schema=schema_config, semantic_cache=SemanticCacheConfig(), ) - + with pytest.raises(ValueError, match="Schema spec path is None"): config.to_serialized_schema_spec_dict() @@ -963,7 +1003,7 @@ def test_to_serialized_schema_spec_dict_nonexistent_file(self): schema=schema_config, semantic_cache=SemanticCacheConfig(), ) - + with pytest.raises(FileNotFoundError, match="Schema spec file does not exist"): config.to_serialized_schema_spec_dict() @@ -972,7 +1012,7 @@ def test_to_serialized_schema_spec_dict_unsupported_format(self): with tempfile.NamedTemporaryFile(suffix=".txt", delete=False) as f: f.write(b"test data") temp_path = f.name - + try: schema_config = SchemaConfig(spec_path=temp_path) config = DELMConfig( @@ -981,7 +1021,7 @@ def test_to_serialized_schema_spec_dict_unsupported_format(self): schema=schema_config, semantic_cache=SemanticCacheConfig(), ) - + with pytest.raises(ValueError, match="Unsupported schema file format"): config.to_serialized_schema_spec_dict() finally: @@ -995,7 +1035,7 @@ def test_to_dict_alias(self): schema=SchemaConfig(), semantic_cache=SemanticCacheConfig(), ) - + result = config.to_dict() expected = 
config.to_serialized_config_dict() assert result == expected @@ -1017,7 +1057,7 @@ def test_from_dict(self): "backend": "sqlite", }, } - + config = DELMConfig.from_dict(data) assert config.llm_extraction.provider == "openai" assert config.llm_extraction.name == "gpt-4" @@ -1050,11 +1090,11 @@ def test_from_yaml(self): "backend": "lmdb", }, } - - with tempfile.NamedTemporaryFile(suffix=".yaml", mode='w', delete=False) as f: + + with tempfile.NamedTemporaryFile(suffix=".yaml", mode="w", delete=False) as f: yaml.dump(config_data, f) temp_path = f.name - + try: config = DELMConfig.from_yaml(Path(temp_path)) assert config.llm_extraction.provider == "anthropic" @@ -1078,7 +1118,7 @@ def test_from_any_delm_config(self): schema=SchemaConfig(), semantic_cache=SemanticCacheConfig(), ) - + result = DELMConfig.from_any(original_config) assert result is original_config @@ -1090,7 +1130,7 @@ def test_from_any_dict(self): "schema": {"spec_path": "schema.yaml"}, "semantic_cache": {"backend": "sqlite"}, } - + config = DELMConfig.from_any(data) assert isinstance(config, DELMConfig) assert config.llm_extraction.provider == "openai" @@ -1103,11 +1143,11 @@ def test_from_any_yaml_path(self): "schema": {"spec_path": "schema.yaml"}, "semantic_cache": {"backend": "lmdb"}, } - - with tempfile.NamedTemporaryFile(suffix=".yaml", mode='w', delete=False) as f: + + with tempfile.NamedTemporaryFile(suffix=".yaml", mode="w", delete=False) as f: yaml.dump(config_data, f) temp_path = f.name - + try: config = DELMConfig.from_any(temp_path) assert isinstance(config, DELMConfig) @@ -1117,5 +1157,7 @@ def test_from_any_yaml_path(self): def test_from_any_invalid_type(self): """Test from_any with invalid type.""" - with pytest.raises(ValueError, match="config must be a DELMConfig, dict, or path to YAML"): - DELMConfig.from_any(123) \ No newline at end of file + with pytest.raises( + ValueError, match="config must be a DELMConfig, dict, or path to YAML" + ): + DELMConfig.from_any(123) diff --git 
a/tests/unit/delm_class/__init__.py b/tests/unit/delm_class/__init__.py new file mode 100644 index 0000000..1cdd8ab --- /dev/null +++ b/tests/unit/delm_class/__init__.py @@ -0,0 +1,2 @@ +"""Unit tests for DELM main class.""" + diff --git a/tests/unit/delm_class/test_delm.py b/tests/unit/delm_class/test_delm.py new file mode 100644 index 0000000..34d8b1c --- /dev/null +++ b/tests/unit/delm_class/test_delm.py @@ -0,0 +1,340 @@ +""" +Unit tests for DELM main class. +""" + +import pytest +from pathlib import Path +from unittest.mock import Mock, patch, MagicMock + +from delm import DELM +from delm.config import ( + DELMConfig, + LLMExtractionConfig, + DataPreprocessingConfig, + SchemaConfig, + SemanticCacheConfig, +) + + +class TestDELMPreviewPrompt: + """Test the DELM.preview_prompt method.""" + + @pytest.fixture + def mock_config(self): + """Create a mock DELMConfig.""" + config = Mock(spec=DELMConfig) + + # Mock data_preprocessing config + data_preprocessing = Mock(spec=DataPreprocessingConfig) + data_preprocessing.target_column = "text_column" + config.data_preprocessing = data_preprocessing + + # Mock llm_extraction config + llm_extraction = Mock(spec=LLMExtractionConfig) + llm_extraction.provider = "openai" + llm_extraction.name = "gpt-4" + llm_extraction.track_cost = False + llm_extraction.batch_size = 32 + config.llm_extraction = llm_extraction + + # Mock schema config + schema = Mock(spec=SchemaConfig) + schema.spec_path = "tests/unit/schemas/test_data/simple_schema.yaml" + schema.prompt_template = "Extract the following from {text}: {fields}" + schema.system_prompt = "You are a helpful assistant." 
+ config.schema = schema + + # Mock semantic_cache config + semantic_cache = Mock(spec=SemanticCacheConfig) + semantic_cache.backend = "none" + config.semantic_cache = semantic_cache + + # Mock validate method + config.validate = Mock() + + return config + + @pytest.fixture + def mock_schema_manager(self): + """Create a mock SchemaManager.""" + schema_manager = Mock() + + # Mock extraction schema with create_prompt method + extraction_schema = Mock() + extraction_schema.create_prompt = Mock(return_value="Mocked compiled prompt") + schema_manager.extraction_schema = extraction_schema + + # Mock prompt_template + schema_manager.prompt_template = "Extract the following from {text}: {fields}" + + return schema_manager + + def test_preview_prompt_with_text(self, mock_config, mock_schema_manager): + """Test preview_prompt with custom text provided.""" + with patch("delm.delm.DataProcessor"), patch( + "delm.delm.SchemaManager", return_value=mock_schema_manager + ), patch("delm.delm.DiskExperimentManager"), patch( + "delm.delm.CostTracker" + ), patch( + "delm.delm.SemanticCacheFactory" + ), patch( + "delm.delm.ExtractionManager" + ), patch( + "delm.delm._configure_logging" + ): + + delm = DELM( + config=mock_config, + experiment_name="test_experiment", + experiment_directory=Path("/tmp/test_experiment"), + override_logging=False, + ) + + # Test with custom text + custom_text = "This is my custom text for extraction" + result = delm.preview_prompt(text=custom_text) + + # Verify create_prompt was called with the custom text + mock_schema_manager.extraction_schema.create_prompt.assert_called_once_with( + text=custom_text, + prompt_template=mock_schema_manager.prompt_template, + ) + + # Verify the result + assert result == "Mocked compiled prompt" + + def test_preview_prompt_without_text(self, mock_config, mock_schema_manager): + """Test preview_prompt without text (should use placeholder).""" + with patch("delm.delm.DataProcessor"), patch( + "delm.delm.SchemaManager", 
return_value=mock_schema_manager + ), patch("delm.delm.DiskExperimentManager"), patch( + "delm.delm.CostTracker" + ), patch( + "delm.delm.SemanticCacheFactory" + ), patch( + "delm.delm.ExtractionManager" + ), patch( + "delm.delm._configure_logging" + ): + + delm = DELM( + config=mock_config, + experiment_name="test_experiment", + experiment_directory=Path("/tmp/test_experiment"), + override_logging=False, + ) + + # Test without text (should use placeholder) + result = delm.preview_prompt() + + # Verify create_prompt was called with placeholder text + expected_placeholder = f"<{mock_config.data_preprocessing.target_column}>" + mock_schema_manager.extraction_schema.create_prompt.assert_called_once_with( + text=expected_placeholder, + prompt_template=mock_schema_manager.prompt_template, + ) + + # Verify the result + assert result == "Mocked compiled prompt" + + def test_preview_prompt_with_none_text(self, mock_config, mock_schema_manager): + """Test preview_prompt with explicit None text.""" + with patch("delm.delm.DataProcessor"), patch( + "delm.delm.SchemaManager", return_value=mock_schema_manager + ), patch("delm.delm.DiskExperimentManager"), patch( + "delm.delm.CostTracker" + ), patch( + "delm.delm.SemanticCacheFactory" + ), patch( + "delm.delm.ExtractionManager" + ), patch( + "delm.delm._configure_logging" + ): + + delm = DELM( + config=mock_config, + experiment_name="test_experiment", + experiment_directory=Path("/tmp/test_experiment"), + override_logging=False, + ) + + # Test with explicit None + result = delm.preview_prompt(text=None) + + # Verify create_prompt was called with placeholder text + expected_placeholder = f"<{mock_config.data_preprocessing.target_column}>" + mock_schema_manager.extraction_schema.create_prompt.assert_called_once_with( + text=expected_placeholder, + prompt_template=mock_schema_manager.prompt_template, + ) + + # Verify the result + assert result == "Mocked compiled prompt" + + def test_preview_prompt_with_empty_string(self, 
mock_config, mock_schema_manager): + """Test preview_prompt with empty string (should use empty string, not placeholder).""" + with patch("delm.delm.DataProcessor"), patch( + "delm.delm.SchemaManager", return_value=mock_schema_manager + ), patch("delm.delm.DiskExperimentManager"), patch( + "delm.delm.CostTracker" + ), patch( + "delm.delm.SemanticCacheFactory" + ), patch( + "delm.delm.ExtractionManager" + ), patch( + "delm.delm._configure_logging" + ): + + delm = DELM( + config=mock_config, + experiment_name="test_experiment", + experiment_directory=Path("/tmp/test_experiment"), + override_logging=False, + ) + + # Test with empty string + result = delm.preview_prompt(text="") + + # Verify create_prompt was called with empty string (not placeholder) + mock_schema_manager.extraction_schema.create_prompt.assert_called_once_with( + text="", + prompt_template=mock_schema_manager.prompt_template, + ) + + # Verify the result + assert result == "Mocked compiled prompt" + + def test_preview_prompt_with_multiline_text(self, mock_config, mock_schema_manager): + """Test preview_prompt with multiline text.""" + with patch("delm.delm.DataProcessor"), patch( + "delm.delm.SchemaManager", return_value=mock_schema_manager + ), patch("delm.delm.DiskExperimentManager"), patch( + "delm.delm.CostTracker" + ), patch( + "delm.delm.SemanticCacheFactory" + ), patch( + "delm.delm.ExtractionManager" + ), patch( + "delm.delm._configure_logging" + ): + + delm = DELM( + config=mock_config, + experiment_name="test_experiment", + experiment_directory=Path("/tmp/test_experiment"), + override_logging=False, + ) + + # Test with multiline text + multiline_text = """This is line 1 +This is line 2 +This is line 3""" + result = delm.preview_prompt(text=multiline_text) + + # Verify create_prompt was called with multiline text + mock_schema_manager.extraction_schema.create_prompt.assert_called_once_with( + text=multiline_text, + prompt_template=mock_schema_manager.prompt_template, + ) + + # Verify the 
result + assert result == "Mocked compiled prompt" + + def test_preview_prompt_with_special_characters( + self, mock_config, mock_schema_manager + ): + """Test preview_prompt with special characters in text.""" + with patch("delm.delm.DataProcessor"), patch( + "delm.delm.SchemaManager", return_value=mock_schema_manager + ), patch("delm.delm.DiskExperimentManager"), patch( + "delm.delm.CostTracker" + ), patch( + "delm.delm.SemanticCacheFactory" + ), patch( + "delm.delm.ExtractionManager" + ), patch( + "delm.delm._configure_logging" + ): + + delm = DELM( + config=mock_config, + experiment_name="test_experiment", + experiment_directory=Path("/tmp/test_experiment"), + override_logging=False, + ) + + # Test with special characters + special_text = "Text with special chars: @#$%^&*()_+-={}[]|\\:;<>?,./~`" + result = delm.preview_prompt(text=special_text) + + # Verify create_prompt was called with special characters + mock_schema_manager.extraction_schema.create_prompt.assert_called_once_with( + text=special_text, + prompt_template=mock_schema_manager.prompt_template, + ) + + # Verify the result + assert result == "Mocked compiled prompt" + + def test_preview_prompt_uses_correct_target_column( + self, mock_config, mock_schema_manager + ): + """Test that preview_prompt uses the correct target column from config.""" + # Set a specific target column name + mock_config.data_preprocessing.target_column = "my_custom_column" + + with patch("delm.delm.DataProcessor"), patch( + "delm.delm.SchemaManager", return_value=mock_schema_manager + ), patch("delm.delm.DiskExperimentManager"), patch( + "delm.delm.CostTracker" + ), patch( + "delm.delm.SemanticCacheFactory" + ), patch( + "delm.delm.ExtractionManager" + ), patch( + "delm.delm._configure_logging" + ): + + delm = DELM( + config=mock_config, + experiment_name="test_experiment", + experiment_directory=Path("/tmp/test_experiment"), + override_logging=False, + ) + + # Test without text - should use custom target column in placeholder 
+ result = delm.preview_prompt() + + # Verify create_prompt was called with correct placeholder + expected_placeholder = "<my_custom_column>" + mock_schema_manager.extraction_schema.create_prompt.assert_called_once_with( + text=expected_placeholder, + prompt_template=mock_schema_manager.prompt_template, + ) + + def test_preview_prompt_returns_string(self, mock_config, mock_schema_manager): + """Test that preview_prompt returns a string.""" + with patch("delm.delm.DataProcessor"), patch( + "delm.delm.SchemaManager", return_value=mock_schema_manager + ), patch("delm.delm.DiskExperimentManager"), patch( + "delm.delm.CostTracker" + ), patch( + "delm.delm.SemanticCacheFactory" + ), patch( + "delm.delm.ExtractionManager" + ), patch( + "delm.delm._configure_logging" + ): + + delm = DELM( + config=mock_config, + experiment_name="test_experiment", + experiment_directory=Path("/tmp/test_experiment"), + override_logging=False, + ) + + result = delm.preview_prompt(text="Test text") + + # Verify result is a string + assert isinstance(result, str) + assert result == "Mocked compiled prompt"