Skip to content

Commit 6a8ea38

Browse files
goldmedalclaude
andauthored
feat(wren): add context validate with view dry-plan and description checks (#1515)
Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent 1c0d3e4 commit 6a8ea38

5 files changed

Lines changed: 306 additions & 20 deletions

File tree

wren/src/wren/cli.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99

1010
import typer
1111

12+
from wren.context_cli import context_app
13+
1214
app = typer.Typer(name="wren", help="Wren Engine CLI", no_args_is_help=False)
1315

1416
_WREN_HOME = Path(os.environ.get("WREN_HOME", Path.home() / ".wren")).expanduser()
@@ -43,6 +45,9 @@ def _require_mdl(mdl: str | None) -> str:
4345
def _load_manifest(mdl: str) -> str:
4446
"""Load MDL from a file path or treat as base64 string directly."""
4547
path = Path(mdl).expanduser()
48+
if path.suffix.lower() == ".json" and not path.exists():
49+
typer.echo(f"Error: MDL file not found: {path}", err=True)
50+
raise typer.Exit(1)
4651
if path.exists():
4752
import base64 # noqa: PLC0415
4853

@@ -540,11 +545,11 @@ def docs_connection_info(
540545

541546
app.add_typer(docs_app)
542547

543-
from wren.context_cli import context_app # noqa: E402, PLC0415
544548
from wren.utils_cli import utils_app # noqa: E402, PLC0415
545549

546550
app.add_typer(context_app)
547551
app.add_typer(utils_app)
552+
app.add_typer(context_app)
548553

549554
try:
550555
import lancedb # noqa: PLC0415, F401

wren/src/wren/context.py

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -543,3 +543,112 @@ def validate_project(project_path: Path) -> list[ValidationError]:
543543
)
544544

545545
return errors
546+
547+
548+
# ── Semantic validation (view dry-plan + description completeness) ─────────
549+
550+
_VALID_LEVELS = frozenset({"error", "warning", "strict"})
551+
552+
_VALID_LEVELS = frozenset({"error", "warning", "strict"})
553+
554+
555+
def _prop_description(item: dict) -> str | None:
556+
return (item.get("properties") or {}).get("description")
557+
558+
559+
def _check_descriptions(manifest: dict, *, strict: bool = False) -> list[str]:
560+
warnings: list[str] = []
561+
562+
for model in manifest.get("models", []):
563+
name = model.get("name", "<unknown>")
564+
if not _prop_description(model):
565+
warnings.append(
566+
f"Model '{name}' has no description — "
567+
"add properties.description to improve memory search and agent comprehension"
568+
)
569+
if strict:
570+
for col in model.get("columns", []):
571+
col_name = col.get("name", "<unknown>")
572+
if not _prop_description(col):
573+
warnings.append(
574+
f"Column '{col_name}' in model '{name}' has no description"
575+
)
576+
577+
for view in manifest.get("views", []):
578+
view_name = view.get("name", "<unknown>")
579+
if not _prop_description(view):
580+
warnings.append(
581+
f"View '{view_name}' has no description — "
582+
"views with descriptions are indexed as NL-SQL examples in memory"
583+
)
584+
585+
return warnings
586+
587+
588+
def validate_manifest(
589+
manifest_str: str,
590+
data_source: str,
591+
*,
592+
level: str = "warning",
593+
) -> dict:
594+
"""Semantic validation of a compiled MDL manifest (base64-encoded JSON).
595+
596+
Args:
597+
manifest_str: Base64-encoded MDL JSON.
598+
data_source: Target data source (used for view dry-plan dialect).
599+
level: Validation level.
600+
"error" — view SQL dry-plan only (CI/CD)
601+
"warning" — + model/view missing description (default)
602+
"strict" — + column missing description
603+
604+
Returns:
605+
Dict with "errors" (list) and "warnings" (list).
606+
"""
607+
import base64 as _base64 # noqa: PLC0415
608+
609+
from wren.engine import WrenEngine # noqa: PLC0415
610+
from wren.model.data_source import DataSource # noqa: PLC0415
611+
612+
errors: list[str] = []
613+
warnings: list[str] = []
614+
615+
if level not in _VALID_LEVELS:
616+
errors.append(
617+
f"Invalid level '{level}' — must be one of: {', '.join(sorted(_VALID_LEVELS))}"
618+
)
619+
return {"errors": errors, "warnings": warnings}
620+
621+
try:
622+
manifest = json.loads(_base64.b64decode(manifest_str))
623+
except Exception as e:
624+
errors.append(f"Failed to decode manifest: {e}")
625+
return {"errors": errors, "warnings": warnings}
626+
627+
# View SQL dry-plan — always checked (failures are errors)
628+
views = manifest.get("views", [])
629+
if views:
630+
if isinstance(data_source, str):
631+
try:
632+
data_source = DataSource(data_source)
633+
except ValueError:
634+
errors.append(f"Invalid datasource '{data_source}'")
635+
return {"errors": errors, "warnings": warnings}
636+
with WrenEngine(
637+
manifest_str=manifest_str, data_source=data_source, connection_info={}
638+
) as engine:
639+
for view in views:
640+
name = view.get("name", "<unknown>")
641+
stmt = (view.get("statement") or "").strip()
642+
if not stmt:
643+
errors.append(f"View '{name}': empty statement")
644+
continue
645+
try:
646+
engine.dry_plan(stmt)
647+
except Exception as e:
648+
errors.append(f"View '{name}': dry-plan failed — {e}")
649+
650+
# Description checks — only at warning/strict level
651+
if level in ("warning", "strict"):
652+
warnings.extend(_check_descriptions(manifest, strict=(level == "strict")))
653+
654+
return {"errors": errors, "warnings": warnings}

wren/src/wren/context_cli.py

Lines changed: 57 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -134,17 +134,25 @@ def validate(
134134
bool,
135135
typer.Option("--strict", help="Treat warnings as errors."),
136136
] = False,
137+
level: Annotated[
138+
str,
139+
typer.Option(
140+
"--level",
141+
help="Semantic check depth: error (dry-plan only), warning (+ descriptions), strict (+ columns).",
142+
),
143+
] = "warning",
137144
) -> None:
138-
"""Validate MDL project structure (no database required).
145+
"""Validate MDL project: YAML structure + view SQL dry-plan + description checks."""
146+
import base64 as _b64 # noqa: PLC0415
139147

140-
Checks wren_project.yml, model/view definitions, column types,
141-
relationship integrity, and naming uniqueness.
142-
"""
143148
from wren.context import ( # noqa: PLC0415
149+
build_json,
144150
discover_project_path,
145151
load_models,
152+
load_project_config,
146153
load_relationships,
147154
load_views,
155+
validate_manifest,
148156
validate_project,
149157
)
150158

@@ -154,27 +162,58 @@ def validate(
154162
typer.echo(str(e), err=True)
155163
raise typer.Exit(1)
156164

157-
errors = validate_project(project_path)
165+
# ── Structural validation ────────────────────────────────────────────
166+
struct_errors = validate_project(project_path)
167+
struct_warnings = [e for e in struct_errors if e.level == "warning"]
168+
struct_hard = [e for e in struct_errors if e.level == "error"]
169+
170+
if struct_errors:
171+
for e in struct_errors:
172+
typer.echo(str(e), err=True)
173+
174+
# ── Semantic validation (dry-plan + description checks) ──────────────
175+
sem_errors: list[str] = []
176+
sem_warnings: list[str] = []
177+
try:
178+
config = load_project_config(project_path)
179+
ds_str = config.get("data_source", "")
180+
manifest_json = build_json(project_path)
181+
manifest_str = _b64.b64encode(
182+
json.dumps(manifest_json, ensure_ascii=False).encode()
183+
).decode()
184+
sem_result = validate_manifest(manifest_str, ds_str, level=level)
185+
sem_errors = sem_result["errors"]
186+
sem_warnings = sem_result["warnings"]
187+
except Exception as e:
188+
sem_errors = [f"Semantic validation failed: {e}"]
189+
190+
if sem_errors:
191+
typer.echo("\nSemantic errors:")
192+
for msg in sem_errors:
193+
typer.echo(f" \u2717 {msg}", err=True)
194+
195+
if sem_warnings:
196+
typer.echo("\nWarnings:")
197+
for msg in sem_warnings:
198+
typer.echo(f" \u26a0 {msg}")
199+
200+
# ── Exit logic ────────────────────────────────────────────────────────
201+
has_hard_error = bool(struct_hard or sem_errors)
202+
has_warning = bool(struct_warnings or sem_warnings)
203+
204+
if has_hard_error or (strict and has_warning):
205+
raise typer.Exit(1)
158206

159-
if not errors:
207+
if not struct_errors and not sem_errors and not sem_warnings:
160208
models = load_models(project_path)
161209
views = load_views(project_path)
162210
rels = load_relationships(project_path)
163211
typer.echo(
164212
f"Valid — {len(models)} models, {len(views)} views, {len(rels)} relationships."
165213
)
166-
return
167-
168-
warnings = [e for e in errors if e.level == "warning"]
169-
hard_errors = [e for e in errors if e.level == "error"]
170-
171-
for e in errors:
172-
typer.echo(str(e), err=True)
173-
174-
if hard_errors or (strict and warnings):
175-
raise typer.Exit(1)
176-
177-
typer.echo(f"\n{len(warnings)} warning(s), 0 errors.")
214+
elif has_warning:
215+
n_warn = len(struct_warnings) + len(sem_warnings)
216+
typer.echo(f"\n{n_warn} warning(s), 0 errors.")
178217

179218

180219
@context_app.command()

wren/tests/unit/test_context.py

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -538,3 +538,135 @@ def test_discover_via_config(tmp_path, monkeypatch):
538538
importlib.reload(ctx)
539539
result = ctx.discover_project_path()
540540
assert result == project_dir
541+
542+
543+
# ── Semantic validation tests (view dry-plan + description checks) ─────────
544+
545+
import base64
546+
import json as _json
547+
548+
import orjson
549+
import pytest
550+
551+
from wren.context import validate_manifest
552+
from wren.model.data_source import DataSource
553+
554+
555+
def _b64(manifest: dict) -> str:
556+
return base64.b64encode(orjson.dumps(manifest)).decode()
557+
558+
559+
_SEM_MODEL_WITH_DESC = {
560+
"name": "orders",
561+
"tableReference": {"schema": "main", "table": "orders"},
562+
"columns": [
563+
{"name": "o_orderkey", "type": "integer"},
564+
{"name": "o_custkey", "type": "integer"},
565+
],
566+
"primaryKey": "o_orderkey",
567+
"properties": {"description": "Orders model"},
568+
}
569+
570+
_SEM_MODEL_WITHOUT_DESC = {
571+
"name": "accounts",
572+
"tableReference": {"schema": "main", "table": "accounts"},
573+
"columns": [
574+
{"name": "acct_id", "type": "integer"},
575+
{"name": "plan_cd", "type": "varchar"},
576+
],
577+
"primaryKey": "acct_id",
578+
}
579+
580+
_VALID_VIEW = {
581+
"name": "valid_view",
582+
"statement": 'SELECT o_orderkey FROM "orders"',
583+
"properties": {"description": "A valid view"},
584+
}
585+
586+
_VIEW_WITHOUT_DESC = {
587+
"name": "daily_usage",
588+
"statement": 'SELECT o_orderkey FROM "orders"',
589+
}
590+
591+
_BROKEN_VIEW = {
592+
"name": "stale_report",
593+
"statement": 'SELECT * FROM "deleted_model"',
594+
}
595+
596+
_EMPTY_STMT_VIEW = {
597+
"name": "empty_view",
598+
"statement": "",
599+
}
600+
601+
_SEM_BASE_MANIFEST = {
602+
"catalog": "wren",
603+
"schema": "public",
604+
"models": [_SEM_MODEL_WITH_DESC],
605+
}
606+
607+
608+
@pytest.mark.unit
609+
def test_validate_manifest_view_pass():
610+
manifest = {**_SEM_BASE_MANIFEST, "views": [_VALID_VIEW]}
611+
result = validate_manifest(_b64(manifest), DataSource.duckdb)
612+
assert result["errors"] == []
613+
614+
615+
@pytest.mark.unit
616+
def test_validate_manifest_view_dry_plan_error():
617+
manifest = {**_SEM_BASE_MANIFEST, "views": [_BROKEN_VIEW]}
618+
result = validate_manifest(_b64(manifest), DataSource.duckdb)
619+
assert len(result["errors"]) == 1
620+
assert "stale_report" in result["errors"][0]
621+
622+
623+
@pytest.mark.unit
624+
def test_validate_manifest_empty_statement():
625+
manifest = {**_SEM_BASE_MANIFEST, "views": [_EMPTY_STMT_VIEW]}
626+
result = validate_manifest(_b64(manifest), DataSource.duckdb)
627+
assert any("empty statement" in e for e in result["errors"])
628+
629+
630+
@pytest.mark.unit
631+
def test_validate_manifest_model_no_description():
632+
manifest = {"catalog": "wren", "schema": "public", "models": [_SEM_MODEL_WITHOUT_DESC]}
633+
result = validate_manifest(_b64(manifest), DataSource.duckdb)
634+
assert result["errors"] == []
635+
assert any("accounts" in w for w in result["warnings"])
636+
637+
638+
@pytest.mark.unit
639+
def test_validate_manifest_view_no_description():
640+
manifest = {**_SEM_BASE_MANIFEST, "views": [_VIEW_WITHOUT_DESC]}
641+
result = validate_manifest(_b64(manifest), DataSource.duckdb)
642+
assert result["errors"] == []
643+
assert any("daily_usage" in w for w in result["warnings"])
644+
645+
646+
@pytest.mark.unit
647+
def test_validate_manifest_level_error_suppresses_warnings():
648+
manifest = {"catalog": "wren", "schema": "public", "models": [_SEM_MODEL_WITHOUT_DESC]}
649+
result = validate_manifest(_b64(manifest), DataSource.duckdb, level="error")
650+
assert result["warnings"] == []
651+
652+
653+
@pytest.mark.unit
654+
def test_validate_manifest_strict_column_warnings():
655+
manifest = {"catalog": "wren", "schema": "public", "models": [_SEM_MODEL_WITHOUT_DESC]}
656+
result = validate_manifest(_b64(manifest), DataSource.duckdb, level="strict")
657+
text = " ".join(result["warnings"])
658+
assert "plan_cd" in text
659+
assert "acct_id" in text
660+
661+
662+
@pytest.mark.unit
663+
def test_validate_manifest_invalid_level():
664+
result = validate_manifest(_b64(_SEM_BASE_MANIFEST), DataSource.duckdb, level="nope")
665+
assert any("nope" in e for e in result["errors"])
666+
667+
668+
@pytest.mark.unit
669+
def test_validate_manifest_invalid_datasource():
670+
manifest = {**_SEM_BASE_MANIFEST, "views": [_VALID_VIEW]}
671+
result = validate_manifest(_b64(manifest), "not-a-datasource")
672+
assert len(result["errors"]) == 1

wren/tests/unit/test_memory.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -355,7 +355,8 @@ def test_full_lifecycle(self, wren_memory):
355355
sql_query="SELECT * FROM orders WHERE o_totalprice > 1000",
356356
)
357357
recalled = wren_memory.recall_queries("costly orders")
358-
assert len(recalled) == 1
358+
assert len(recalled) >= 1
359+
assert any(r["nl_query"] == "find expensive orders" for r in recalled)
359360

360361
assert wren_memory.schema_is_current(_MANIFEST)
361362

0 commit comments

Comments
 (0)