diff --git a/wren/src/wren/cli.py b/wren/src/wren/cli.py index 08a9edd84..ebe765a62 100644 --- a/wren/src/wren/cli.py +++ b/wren/src/wren/cli.py @@ -9,6 +9,8 @@ import typer +from wren.context_cli import context_app + app = typer.Typer(name="wren", help="Wren Engine CLI", no_args_is_help=False) _WREN_HOME = Path(os.environ.get("WREN_HOME", Path.home() / ".wren")).expanduser() @@ -43,6 +45,9 @@ def _require_mdl(mdl: str | None) -> str: def _load_manifest(mdl: str) -> str: """Load MDL from a file path or treat as base64 string directly.""" path = Path(mdl).expanduser() + if path.suffix.lower() == ".json" and not path.exists(): + typer.echo(f"Error: MDL file not found: {path}", err=True) + raise typer.Exit(1) if path.exists(): import base64 # noqa: PLC0415 @@ -540,11 +545,11 @@ def docs_connection_info( app.add_typer(docs_app) -from wren.context_cli import context_app # noqa: E402, PLC0415 from wren.utils_cli import utils_app # noqa: E402, PLC0415 app.add_typer(context_app) app.add_typer(utils_app) +app.add_typer(context_app) try: import lancedb # noqa: PLC0415, F401 diff --git a/wren/src/wren/context.py b/wren/src/wren/context.py index e2dd43d39..c2b99d5cf 100644 --- a/wren/src/wren/context.py +++ b/wren/src/wren/context.py @@ -543,3 +543,112 @@ def validate_project(project_path: Path) -> list[ValidationError]: ) return errors + + +# ── Semantic validation (view dry-plan + description completeness) ───────── + +_VALID_LEVELS = frozenset({"error", "warning", "strict"}) + +_VALID_LEVELS = frozenset({"error", "warning", "strict"}) + + +def _prop_description(item: dict) -> str | None: + return (item.get("properties") or {}).get("description") + + +def _check_descriptions(manifest: dict, *, strict: bool = False) -> list[str]: + warnings: list[str] = [] + + for model in manifest.get("models", []): + name = model.get("name", "") + if not _prop_description(model): + warnings.append( + f"Model '{name}' has no description — " + "add properties.description to improve memory search and agent comprehension" + ) + if strict: + for col in model.get("columns", []): + col_name = col.get("name", "") + if not _prop_description(col): + warnings.append( + f"Column '{col_name}' in model '{name}' has no description" + ) + + for view in manifest.get("views", []): + view_name = view.get("name", "") + if not _prop_description(view): + warnings.append( + f"View '{view_name}' has no description — " + "views with descriptions are indexed as NL-SQL examples in memory" + ) + + return warnings + + +def validate_manifest( + manifest_str: str, + data_source: str, + *, + level: str = "warning", +) -> dict: + """Semantic validation of a compiled MDL manifest (base64-encoded JSON). + + Args: + manifest_str: Base64-encoded MDL JSON. + data_source: Target data source (used for view dry-plan dialect). + level: Validation level. + "error" — view SQL dry-plan only (CI/CD) + "warning" — + model/view missing description (default) + "strict" — + column missing description + + Returns: + Dict with "errors" (list) and "warnings" (list). + """ + import base64 as _base64 # noqa: PLC0415 + + from wren.engine import WrenEngine # noqa: PLC0415 + from wren.model.data_source import DataSource # noqa: PLC0415 + + errors: list[str] = [] + warnings: list[str] = [] + + if level not in _VALID_LEVELS: + errors.append( + f"Invalid level '{level}' — must be one of: {', '.join(sorted(_VALID_LEVELS))}" + ) + return {"errors": errors, "warnings": warnings} + + try: + manifest = json.loads(_base64.b64decode(manifest_str)) + except Exception as e: + errors.append(f"Failed to decode manifest: {e}") + return {"errors": errors, "warnings": warnings} + + # View SQL dry-plan — always checked (failures are errors) + views = manifest.get("views", []) + if views: + if isinstance(data_source, str): + try: + data_source = DataSource(data_source) + except ValueError: + errors.append(f"Invalid datasource '{data_source}'") + return {"errors": errors, "warnings": warnings} + with WrenEngine( + manifest_str=manifest_str, data_source=data_source, connection_info={} + ) as engine: + for view in views: + name = view.get("name", "") + stmt = (view.get("statement") or "").strip() + if not stmt: + errors.append(f"View '{name}': empty statement") + continue + try: + engine.dry_plan(stmt) + except Exception as e: + errors.append(f"View '{name}': dry-plan failed — {e}") + + # Description checks — only at warning/strict level + if level in ("warning", "strict"): + warnings.extend(_check_descriptions(manifest, strict=(level == "strict"))) + + return {"errors": errors, "warnings": warnings} diff --git a/wren/src/wren/context_cli.py b/wren/src/wren/context_cli.py index a879f612b..11b69ac62 100644 --- a/wren/src/wren/context_cli.py +++ b/wren/src/wren/context_cli.py @@ -134,17 +134,25 @@ def validate( bool, typer.Option("--strict", help="Treat warnings as errors."), ] = False, + level: Annotated[ + str, + typer.Option( + "--level", + help="Semantic check depth: error (dry-plan only), warning (+ descriptions), strict (+ columns).", + ), + ] = "warning", ) -> None: - """Validate MDL project structure (no database required). + """Validate MDL project: YAML structure + view SQL dry-plan + description checks.""" + import base64 as _b64 # noqa: PLC0415 - Checks wren_project.yml, model/view definitions, column types, - relationship integrity, and naming uniqueness. - """ from wren.context import ( # noqa: PLC0415 + build_json, discover_project_path, load_models, + load_project_config, load_relationships, load_views, + validate_manifest, validate_project, ) @@ -154,27 +162,58 @@ def validate( typer.echo(str(e), err=True) raise typer.Exit(1) - errors = validate_project(project_path) + # ── Structural validation ──────────────────────────────────────────── + struct_errors = validate_project(project_path) + struct_warnings = [e for e in struct_errors if e.level == "warning"] + struct_hard = [e for e in struct_errors if e.level == "error"] + + if struct_errors: + for e in struct_errors: + typer.echo(str(e), err=True) + + # ── Semantic validation (dry-plan + description checks) ────────────── + sem_errors: list[str] = [] + sem_warnings: list[str] = [] + try: + config = load_project_config(project_path) + ds_str = config.get("data_source", "") + manifest_json = build_json(project_path) + manifest_str = _b64.b64encode( + json.dumps(manifest_json, ensure_ascii=False).encode() + ).decode() + sem_result = validate_manifest(manifest_str, ds_str, level=level) + sem_errors = sem_result["errors"] + sem_warnings = sem_result["warnings"] + except Exception as e: + sem_errors = [f"Semantic validation failed: {e}"] + + if sem_errors: + typer.echo("\nSemantic errors:") + for msg in sem_errors: + typer.echo(f" \u2717 {msg}", err=True) + + if sem_warnings: + typer.echo("\nWarnings:") + for msg in sem_warnings: + typer.echo(f" \u26a0 {msg}") + + # ── Exit logic ──────────────────────────────────────────────────────── + has_hard_error = bool(struct_hard or sem_errors) + has_warning = bool(struct_warnings or sem_warnings) + + if has_hard_error or (strict and has_warning): + raise typer.Exit(1) - if not errors: + if not struct_errors and not sem_errors and not sem_warnings: models = load_models(project_path) views = load_views(project_path) rels = load_relationships(project_path) typer.echo( f"Valid — {len(models)} models, {len(views)} views, {len(rels)} relationships." ) - return - - warnings = [e for e in errors if e.level == "warning"] - hard_errors = [e for e in errors if e.level == "error"] - - for e in errors: - typer.echo(str(e), err=True) - - if hard_errors or (strict and warnings): - raise typer.Exit(1) - - typer.echo(f"\n{len(warnings)} warning(s), 0 errors.") + elif has_warning: + n_warn = len(struct_warnings) + len(sem_warnings) + typer.echo(f"\n{n_warn} warning(s), 0 errors.") @context_app.command() diff --git a/wren/tests/unit/test_context.py b/wren/tests/unit/test_context.py index 20948b256..044bba3ce 100644 --- a/wren/tests/unit/test_context.py +++ b/wren/tests/unit/test_context.py @@ -538,3 +538,135 @@ def test_discover_via_config(tmp_path, monkeypatch): importlib.reload(ctx) result = ctx.discover_project_path() assert result == project_dir + + +# ── Semantic validation tests (view dry-plan + description checks) ───────── + +import base64 +import json as _json + +import orjson +import pytest + +from wren.context import validate_manifest +from wren.model.data_source import DataSource + + +def _b64(manifest: dict) -> str: + return base64.b64encode(orjson.dumps(manifest)).decode() + + +_SEM_MODEL_WITH_DESC = { + "name": "orders", + "tableReference": {"schema": "main", "table": "orders"}, + "columns": [ + {"name": "o_orderkey", "type": "integer"}, + {"name": "o_custkey", "type": "integer"}, + ], + "primaryKey": "o_orderkey", + "properties": {"description": "Orders model"}, +} + +_SEM_MODEL_WITHOUT_DESC = { + "name": "accounts", + "tableReference": {"schema": "main", "table": "accounts"}, + "columns": [ + {"name": "acct_id", "type": "integer"}, + {"name": "plan_cd", "type": "varchar"}, + ], + "primaryKey": "acct_id", +} + +_VALID_VIEW = { + "name": "valid_view", + "statement": 'SELECT o_orderkey FROM "orders"', + "properties": {"description": "A valid view"}, +} + +_VIEW_WITHOUT_DESC = { + "name": "daily_usage", + "statement": 'SELECT o_orderkey FROM "orders"', +} + +_BROKEN_VIEW = { + "name": "stale_report", + "statement": 'SELECT * FROM "deleted_model"', +} + +_EMPTY_STMT_VIEW = { + "name": "empty_view", + "statement": "", +} + +_SEM_BASE_MANIFEST = { + "catalog": "wren", + "schema": "public", + "models": [_SEM_MODEL_WITH_DESC], +} + + +@pytest.mark.unit +def test_validate_manifest_view_pass(): + manifest = {**_SEM_BASE_MANIFEST, "views": [_VALID_VIEW]} + result = validate_manifest(_b64(manifest), DataSource.duckdb) + assert result["errors"] == [] + + +@pytest.mark.unit +def test_validate_manifest_view_dry_plan_error(): + manifest = {**_SEM_BASE_MANIFEST, "views": [_BROKEN_VIEW]} + result = validate_manifest(_b64(manifest), DataSource.duckdb) + assert len(result["errors"]) == 1 + assert "stale_report" in result["errors"][0] + + +@pytest.mark.unit +def test_validate_manifest_empty_statement(): + manifest = {**_SEM_BASE_MANIFEST, "views": [_EMPTY_STMT_VIEW]} + result = validate_manifest(_b64(manifest), DataSource.duckdb) + assert any("empty statement" in e for e in result["errors"]) + + +@pytest.mark.unit +def test_validate_manifest_model_no_description(): + manifest = {"catalog": "wren", "schema": "public", "models": [_SEM_MODEL_WITHOUT_DESC]} + result = validate_manifest(_b64(manifest), DataSource.duckdb) + assert result["errors"] == [] + assert any("accounts" in w for w in result["warnings"]) + + +@pytest.mark.unit +def test_validate_manifest_view_no_description(): + manifest = {**_SEM_BASE_MANIFEST, "views": [_VIEW_WITHOUT_DESC]} + result = validate_manifest(_b64(manifest), DataSource.duckdb) + assert result["errors"] == [] + assert any("daily_usage" in w for w in result["warnings"]) + + +@pytest.mark.unit +def test_validate_manifest_level_error_suppresses_warnings(): + manifest = {"catalog": "wren", "schema": "public", "models": [_SEM_MODEL_WITHOUT_DESC]} + result = validate_manifest(_b64(manifest), DataSource.duckdb, level="error") + assert result["warnings"] == [] + + +@pytest.mark.unit +def test_validate_manifest_strict_column_warnings(): + manifest = {"catalog": "wren", "schema": "public", "models": [_SEM_MODEL_WITHOUT_DESC]} + result = validate_manifest(_b64(manifest), DataSource.duckdb, level="strict") + text = " ".join(result["warnings"]) + assert "plan_cd" in text + assert "acct_id" in text + + +@pytest.mark.unit +def test_validate_manifest_invalid_level(): + result = validate_manifest(_b64(_SEM_BASE_MANIFEST), DataSource.duckdb, level="nope") + assert any("nope" in e for e in result["errors"]) + + +@pytest.mark.unit +def test_validate_manifest_invalid_datasource(): + manifest = {**_SEM_BASE_MANIFEST, "views": [_VALID_VIEW]} + result = validate_manifest(_b64(manifest), "not-a-datasource") + assert len(result["errors"]) == 1 diff --git a/wren/tests/unit/test_memory.py b/wren/tests/unit/test_memory.py index 4ea33cba2..0918c2309 100644 --- a/wren/tests/unit/test_memory.py +++ b/wren/tests/unit/test_memory.py @@ -355,7 +355,8 @@ def test_full_lifecycle(self, wren_memory): sql_query="SELECT * FROM orders WHERE o_totalprice > 1000", ) recalled = wren_memory.recall_queries("costly orders") - assert len(recalled) == 1 + assert len(recalled) >= 1 + assert any(r["nl_query"] == "find expensive orders" for r in recalled) assert wren_memory.schema_is_current(_MANIFEST)