Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion wren/src/wren/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@

import typer

from wren.context_cli import context_app

app = typer.Typer(name="wren", help="Wren Engine CLI", no_args_is_help=False)

_WREN_HOME = Path(os.environ.get("WREN_HOME", Path.home() / ".wren")).expanduser()
Expand Down Expand Up @@ -43,6 +45,9 @@ def _require_mdl(mdl: str | None) -> str:
def _load_manifest(mdl: str) -> str:
"""Load MDL from a file path or treat as base64 string directly."""
path = Path(mdl).expanduser()
if path.suffix.lower() == ".json" and not path.exists():
typer.echo(f"Error: MDL file not found: {path}", err=True)
raise typer.Exit(1)
if path.exists():
import base64 # noqa: PLC0415

Expand Down Expand Up @@ -540,11 +545,11 @@ def docs_connection_info(

app.add_typer(docs_app)

from wren.context_cli import context_app # noqa: E402, PLC0415
from wren.utils_cli import utils_app # noqa: E402, PLC0415

app.add_typer(context_app)
app.add_typer(utils_app)
app.add_typer(context_app)

try:
import lancedb # noqa: PLC0415, F401
Expand Down
109 changes: 109 additions & 0 deletions wren/src/wren/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -543,3 +543,112 @@ def validate_project(project_path: Path) -> list[ValidationError]:
)

return errors


# ── Semantic validation (view dry-plan + description completeness) ─────────

_VALID_LEVELS = frozenset({"error", "warning", "strict"})

_VALID_LEVELS = frozenset({"error", "warning", "strict"})


def _prop_description(item: dict) -> str | None:
return (item.get("properties") or {}).get("description")


def _check_descriptions(manifest: dict, *, strict: bool = False) -> list[str]:
warnings: list[str] = []

for model in manifest.get("models", []):
name = model.get("name", "<unknown>")
if not _prop_description(model):
warnings.append(
f"Model '{name}' has no description — "
"add properties.description to improve memory search and agent comprehension"
)
if strict:
for col in model.get("columns", []):
col_name = col.get("name", "<unknown>")
if not _prop_description(col):
warnings.append(
f"Column '{col_name}' in model '{name}' has no description"
)

for view in manifest.get("views", []):
view_name = view.get("name", "<unknown>")
if not _prop_description(view):
warnings.append(
f"View '{view_name}' has no description — "
"views with descriptions are indexed as NL-SQL examples in memory"
)

return warnings


def validate_manifest(
manifest_str: str,
data_source: str,
*,
level: str = "warning",
) -> dict:
"""Semantic validation of a compiled MDL manifest (base64-encoded JSON).

Args:
manifest_str: Base64-encoded MDL JSON.
data_source: Target data source (used for view dry-plan dialect).
level: Validation level.
"error" — view SQL dry-plan only (CI/CD)
"warning" — + model/view missing description (default)
"strict" — + column missing description

Returns:
Dict with "errors" (list) and "warnings" (list).
"""
import base64 as _base64 # noqa: PLC0415

from wren.engine import WrenEngine # noqa: PLC0415
from wren.model.data_source import DataSource # noqa: PLC0415

errors: list[str] = []
warnings: list[str] = []

if level not in _VALID_LEVELS:
errors.append(
f"Invalid level '{level}' — must be one of: {', '.join(sorted(_VALID_LEVELS))}"
)
return {"errors": errors, "warnings": warnings}

try:
manifest = json.loads(_base64.b64decode(manifest_str))
except Exception as e:
errors.append(f"Failed to decode manifest: {e}")
return {"errors": errors, "warnings": warnings}

# View SQL dry-plan — always checked (failures are errors)
views = manifest.get("views", [])
if views:
if isinstance(data_source, str):
try:
data_source = DataSource(data_source)
except ValueError:
errors.append(f"Invalid datasource '{data_source}'")
return {"errors": errors, "warnings": warnings}
with WrenEngine(
manifest_str=manifest_str, data_source=data_source, connection_info={}
) as engine:
for view in views:
name = view.get("name", "<unknown>")
stmt = (view.get("statement") or "").strip()
if not stmt:
errors.append(f"View '{name}': empty statement")
continue
try:
engine.dry_plan(stmt)
except Exception as e:
errors.append(f"View '{name}': dry-plan failed — {e}")

# Description checks — only at warning/strict level
if level in ("warning", "strict"):
warnings.extend(_check_descriptions(manifest, strict=(level == "strict")))

return {"errors": errors, "warnings": warnings}
75 changes: 57 additions & 18 deletions wren/src/wren/context_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,17 +134,25 @@ def validate(
bool,
typer.Option("--strict", help="Treat warnings as errors."),
] = False,
level: Annotated[
str,
typer.Option(
"--level",
help="Semantic check depth: error (dry-plan only), warning (+ descriptions), strict (+ columns).",
),
] = "warning",
) -> None:
"""Validate MDL project structure (no database required).
"""Validate MDL project: YAML structure + view SQL dry-plan + description checks."""
import base64 as _b64 # noqa: PLC0415

Checks wren_project.yml, model/view definitions, column types,
relationship integrity, and naming uniqueness.
"""
from wren.context import ( # noqa: PLC0415
build_json,
discover_project_path,
load_models,
load_project_config,
load_relationships,
load_views,
validate_manifest,
validate_project,
)

Expand All @@ -154,27 +162,58 @@ def validate(
typer.echo(str(e), err=True)
raise typer.Exit(1)

errors = validate_project(project_path)
# ── Structural validation ────────────────────────────────────────────
struct_errors = validate_project(project_path)
struct_warnings = [e for e in struct_errors if e.level == "warning"]
struct_hard = [e for e in struct_errors if e.level == "error"]

if struct_errors:
for e in struct_errors:
typer.echo(str(e), err=True)

# ── Semantic validation (dry-plan + description checks) ──────────────
sem_errors: list[str] = []
sem_warnings: list[str] = []
try:
config = load_project_config(project_path)
ds_str = config.get("data_source", "")
manifest_json = build_json(project_path)
manifest_str = _b64.b64encode(
json.dumps(manifest_json, ensure_ascii=False).encode()
).decode()
sem_result = validate_manifest(manifest_str, ds_str, level=level)
sem_errors = sem_result["errors"]
sem_warnings = sem_result["warnings"]
except Exception as e:
sem_errors = [f"Semantic validation failed: {e}"]

if sem_errors:
typer.echo("\nSemantic errors:")
for msg in sem_errors:
typer.echo(f" \u2717 {msg}", err=True)

if sem_warnings:
typer.echo("\nWarnings:")
for msg in sem_warnings:
typer.echo(f" \u26a0 {msg}")

# ── Exit logic ────────────────────────────────────────────────────────
has_hard_error = bool(struct_hard or sem_errors)
has_warning = bool(struct_warnings or sem_warnings)

if has_hard_error or (strict and has_warning):
raise typer.Exit(1)

if not errors:
if not struct_errors and not sem_errors and not sem_warnings:
models = load_models(project_path)
views = load_views(project_path)
rels = load_relationships(project_path)
typer.echo(
f"Valid — {len(models)} models, {len(views)} views, {len(rels)} relationships."
)
return

warnings = [e for e in errors if e.level == "warning"]
hard_errors = [e for e in errors if e.level == "error"]

for e in errors:
typer.echo(str(e), err=True)

if hard_errors or (strict and warnings):
raise typer.Exit(1)

typer.echo(f"\n{len(warnings)} warning(s), 0 errors.")
elif has_warning:
n_warn = len(struct_warnings) + len(sem_warnings)
typer.echo(f"\n{n_warn} warning(s), 0 errors.")


@context_app.command()
Expand Down
132 changes: 132 additions & 0 deletions wren/tests/unit/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -538,3 +538,135 @@ def test_discover_via_config(tmp_path, monkeypatch):
importlib.reload(ctx)
result = ctx.discover_project_path()
assert result == project_dir


# ── Semantic validation tests (view dry-plan + description checks) ─────────

import base64
import json as _json

import orjson
import pytest

from wren.context import validate_manifest
from wren.model.data_source import DataSource


def _b64(manifest: dict) -> str:
return base64.b64encode(orjson.dumps(manifest)).decode()


_SEM_MODEL_WITH_DESC = {
"name": "orders",
"tableReference": {"schema": "main", "table": "orders"},
"columns": [
{"name": "o_orderkey", "type": "integer"},
{"name": "o_custkey", "type": "integer"},
],
"primaryKey": "o_orderkey",
"properties": {"description": "Orders model"},
}

_SEM_MODEL_WITHOUT_DESC = {
"name": "accounts",
"tableReference": {"schema": "main", "table": "accounts"},
"columns": [
{"name": "acct_id", "type": "integer"},
{"name": "plan_cd", "type": "varchar"},
],
"primaryKey": "acct_id",
}

_VALID_VIEW = {
"name": "valid_view",
"statement": 'SELECT o_orderkey FROM "orders"',
"properties": {"description": "A valid view"},
}

_VIEW_WITHOUT_DESC = {
"name": "daily_usage",
"statement": 'SELECT o_orderkey FROM "orders"',
}

_BROKEN_VIEW = {
"name": "stale_report",
"statement": 'SELECT * FROM "deleted_model"',
}

_EMPTY_STMT_VIEW = {
"name": "empty_view",
"statement": "",
}

_SEM_BASE_MANIFEST = {
"catalog": "wren",
"schema": "public",
"models": [_SEM_MODEL_WITH_DESC],
}


@pytest.mark.unit
def test_validate_manifest_view_pass():
manifest = {**_SEM_BASE_MANIFEST, "views": [_VALID_VIEW]}
result = validate_manifest(_b64(manifest), DataSource.duckdb)
assert result["errors"] == []


@pytest.mark.unit
def test_validate_manifest_view_dry_plan_error():
manifest = {**_SEM_BASE_MANIFEST, "views": [_BROKEN_VIEW]}
result = validate_manifest(_b64(manifest), DataSource.duckdb)
assert len(result["errors"]) == 1
assert "stale_report" in result["errors"][0]


@pytest.mark.unit
def test_validate_manifest_empty_statement():
manifest = {**_SEM_BASE_MANIFEST, "views": [_EMPTY_STMT_VIEW]}
result = validate_manifest(_b64(manifest), DataSource.duckdb)
assert any("empty statement" in e for e in result["errors"])


@pytest.mark.unit
def test_validate_manifest_model_no_description():
manifest = {"catalog": "wren", "schema": "public", "models": [_SEM_MODEL_WITHOUT_DESC]}
result = validate_manifest(_b64(manifest), DataSource.duckdb)
assert result["errors"] == []
assert any("accounts" in w for w in result["warnings"])


@pytest.mark.unit
def test_validate_manifest_view_no_description():
manifest = {**_SEM_BASE_MANIFEST, "views": [_VIEW_WITHOUT_DESC]}
result = validate_manifest(_b64(manifest), DataSource.duckdb)
assert result["errors"] == []
assert any("daily_usage" in w for w in result["warnings"])


@pytest.mark.unit
def test_validate_manifest_level_error_suppresses_warnings():
manifest = {"catalog": "wren", "schema": "public", "models": [_SEM_MODEL_WITHOUT_DESC]}
result = validate_manifest(_b64(manifest), DataSource.duckdb, level="error")
assert result["warnings"] == []


@pytest.mark.unit
def test_validate_manifest_strict_column_warnings():
manifest = {"catalog": "wren", "schema": "public", "models": [_SEM_MODEL_WITHOUT_DESC]}
result = validate_manifest(_b64(manifest), DataSource.duckdb, level="strict")
text = " ".join(result["warnings"])
assert "plan_cd" in text
assert "acct_id" in text


@pytest.mark.unit
def test_validate_manifest_invalid_level():
result = validate_manifest(_b64(_SEM_BASE_MANIFEST), DataSource.duckdb, level="nope")
assert any("nope" in e for e in result["errors"])


@pytest.mark.unit
def test_validate_manifest_invalid_datasource():
manifest = {**_SEM_BASE_MANIFEST, "views": [_VALID_VIEW]}
result = validate_manifest(_b64(manifest), "not-a-datasource")
assert len(result["errors"]) == 1
3 changes: 2 additions & 1 deletion wren/tests/unit/test_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,7 +355,8 @@ def test_full_lifecycle(self, wren_memory):
sql_query="SELECT * FROM orders WHERE o_totalprice > 1000",
)
recalled = wren_memory.recall_queries("costly orders")
assert len(recalled) == 1
assert len(recalled) >= 1
assert any(r["nl_query"] == "find expensive orders" for r in recalled)

assert wren_memory.schema_is_current(_MANIFEST)

Expand Down
Loading