From 202b9468b54cf9fffce7e8ee1be233514a305495 Mon Sep 17 00:00:00 2001 From: Colin Son Date: Wed, 3 Jun 2026 20:40:36 -0500 Subject: [PATCH] Add blueprint repair loop --- .../generation/blueprint_repair.py | 331 ++++++++++++++++++ src/casecrawler/storage/dataset_store.py | 38 ++ tests/test_blueprint_repair_loop.py | 226 ++++++++++++ tests/test_blueprint_storage.py | 22 ++ 4 files changed, 617 insertions(+) create mode 100644 src/casecrawler/generation/blueprint_repair.py create mode 100644 tests/test_blueprint_repair_loop.py diff --git a/src/casecrawler/generation/blueprint_repair.py b/src/casecrawler/generation/blueprint_repair.py new file mode 100644 index 0000000..9859159 --- /dev/null +++ b/src/casecrawler/generation/blueprint_repair.py @@ -0,0 +1,331 @@ +from __future__ import annotations + +import hashlib +import json +import logging +from collections.abc import Callable +from uuid import uuid4 + +from pydantic import BaseModel + +from casecrawler.generation.blueprint_judge import BlueprintJudge +from casecrawler.llm.base import BaseLLMProvider +from casecrawler.llm.factory import get_provider +from casecrawler.models.blueprint import ( + BlueprintGenerationRequest, + ClinicalBlueprint, + GenerationAttempt, + GenerationAttemptStatus, + GenerationRole, + GenerationRolePolicy, + JudgeReport, +) +from casecrawler.storage.dataset_store import DatasetStore + + +ProviderFactory = Callable[[str, str], BaseLLMProvider] +logger = logging.getLogger(__name__) + + +class BlueprintRepairResult(BaseModel): + original_blueprint: ClinicalBlueprint + final_blueprint: ClinicalBlueprint + judge_reports: list[JudgeReport] + repaired_blueprints: list[ClinicalBlueprint] + repair_rounds: int + passed: bool + + +class BlueprintRepairLoop: + def __init__( + self, + *, + provider_factory: ProviderFactory = get_provider, + judge: BlueprintJudge | None = None, + ) -> None: + self._provider_factory = provider_factory + self._judge = judge or BlueprintJudge(provider_factory=provider_factory) + + async def run( + self, + request: BlueprintGenerationRequest, + blueprint: ClinicalBlueprint, + *, + store: DatasetStore | None = None, + ) -> BlueprintRepairResult: + current = blueprint + judge_reports: list[JudgeReport] = [] + repaired_blueprints: list[ClinicalBlueprint] = [] + + for round_index in range(request.max_repair_rounds + 1): + judge_report = await self._judge.evaluate(request, current, store=store) + judge_reports.append(judge_report) + if judge_report.passed: + return BlueprintRepairResult( + original_blueprint=blueprint, + final_blueprint=current, + judge_reports=judge_reports, + repaired_blueprints=repaired_blueprints, + repair_rounds=len(repaired_blueprints), + passed=True, + ) + if round_index >= request.max_repair_rounds: + break + + repair_policy = request.policy_for(GenerationRole.REPAIR) + if repair_policy is None: + raise ValueError( + "A repair role policy is required when judge repair is needed." + ) + + repair_round = round_index + 1 + repair_request_attempt = self._repair_requested_attempt( + blueprint=current, + policy=repair_policy, + judge_report=judge_report, + repair_round=repair_round, + ) + if store is not None: + store.save_generation_attempt(repair_request_attempt) + + current = await self._repair_blueprint( + request, + current, + judge_report, + policy=repair_policy, + repair_round=repair_round, + store=store, + ) + repaired_blueprints.append(current) + + return BlueprintRepairResult( + original_blueprint=blueprint, + final_blueprint=current, + judge_reports=judge_reports, + repaired_blueprints=repaired_blueprints, + repair_rounds=len(repaired_blueprints), + passed=False, + ) + + async def _repair_blueprint( + self, + request: BlueprintGenerationRequest, + blueprint: ClinicalBlueprint, + judge_report: JudgeReport, + *, + policy: GenerationRolePolicy, + repair_round: int, + store: DatasetStore | None, + ) -> ClinicalBlueprint: + provider = self._provider_factory(policy.provider, policy.model) + prompt = self._build_repair_prompt( + request, + blueprint=blueprint, + judge_report=judge_report, + repair_round=repair_round, + ) + prompt_hash = self._prompt_hash(prompt, policy) + + try: + result = await provider.generate_structured( + prompt, + ClinicalBlueprint, + system=_REPAIR_SYSTEM_PROMPT, + temperature=policy.temperature, + ) + repaired = self._canonicalize_repair( + ClinicalBlueprint.model_validate(result.data), + source=blueprint, + judge_report=judge_report, + repair_round=repair_round, + ) + except Exception as err: + if store is not None: + self._save_failed_attempt_best_effort( + store, + self._repair_attempt( + blueprint=blueprint, + policy=policy, + status=GenerationAttemptStatus.FAILED, + prompt_hash=prompt_hash, + repair_round=repair_round, + judge_report=judge_report, + errors=[str(err)], + ), + ) + raise + + if store is not None: + store.save_blueprint_with_attempt( + repaired, + self._repair_attempt( + blueprint=repaired, + policy=policy, + status=GenerationAttemptStatus.SUCCEEDED, + prompt_hash=prompt_hash, + repair_round=repair_round, + judge_report=judge_report, + input_tokens=result.input_tokens, + output_tokens=result.output_tokens, + ) + ) + return repaired + + def _canonicalize_repair( + self, + raw_blueprint: ClinicalBlueprint, + *, + source: ClinicalBlueprint, + judge_report: JudgeReport, + repair_round: int, + ) -> ClinicalBlueprint: + metadata = { + **raw_blueprint.metadata, + "parent_blueprint_id": source.blueprint_id, + "repair_round": repair_round, + "judge_report_id": judge_report.report_id, + } + return ClinicalBlueprint.model_validate( + { + **raw_blueprint.model_dump(), + "blueprint_id": f"bp-{uuid4()}", + "dataset_id": source.dataset_id, + "cohort_plan_id": source.cohort_plan_id, + "archetype_name": source.archetype_name, + "organ_system": source.organ_system, + "setting": source.setting, + "metadata": metadata, + } + ) + + def _build_repair_prompt( + self, + request: BlueprintGenerationRequest, + *, + blueprint: ClinicalBlueprint, + judge_report: JudgeReport, + repair_round: int, + ) -> str: + blueprint_json = json.dumps( + blueprint.model_dump(mode="json"), + sort_keys=True, + separators=(",", ":"), + ) + report_json = json.dumps( + judge_report.model_dump(mode="json"), + sort_keys=True, + separators=(",", ":"), + ) + return "\n".join( + [ + "Repair this clinical blueprint so it can pass independent review.", + f"User request: {request.request}", + f"Repair round: {repair_round}", + f"Blueprint id: {blueprint.blueprint_id}", + f"Judge report id: {judge_report.report_id}", + ( + "Preserve the clinical intent, dataset lineage, and task target. " + "Change only fields needed to address judge findings." + ), + f"Blueprint JSON: {blueprint_json}", + f"JudgeReport JSON: {report_json}", + ] + ) + + def _prompt_hash( + self, + prompt: str, + policy: GenerationRolePolicy, + ) -> str: + payload = { + "model": policy.model, + "provider": policy.provider, + "schema": ClinicalBlueprint.__name__, + "system": _REPAIR_SYSTEM_PROMPT, + "temperature": policy.temperature, + "user": prompt, + } + serialized = json.dumps(payload, sort_keys=True, separators=(",", ":")) + return hashlib.sha256(serialized.encode("utf-8")).hexdigest() + + def _repair_requested_attempt( + self, + *, + blueprint: ClinicalBlueprint, + policy: GenerationRolePolicy, + judge_report: JudgeReport, + repair_round: int, + ) -> GenerationAttempt: + payload = { + "artifact_id": blueprint.blueprint_id, + "judge_report_id": judge_report.report_id, + "repair_round": repair_round, + "status": GenerationAttemptStatus.REPAIR_REQUESTED.value, + } + return GenerationAttempt( + attempt_id=f"attempt-{uuid4()}", + dataset_id=blueprint.dataset_id, + role=GenerationRole.REPAIR, + status=GenerationAttemptStatus.REPAIR_REQUESTED, + provider=policy.provider, + model=policy.model, + prompt_hash=_hash_payload(payload), + artifact_id=blueprint.blueprint_id, + metadata={ + "judge_report_id": judge_report.report_id, + "repair_round": repair_round, + }, + ) + + def _repair_attempt( + self, + *, + blueprint: ClinicalBlueprint, + policy: GenerationRolePolicy, + status: GenerationAttemptStatus, + prompt_hash: str, + repair_round: int, + judge_report: JudgeReport, + input_tokens: int = 0, + output_tokens: int = 0, + errors: list[str] | None = None, + ) -> GenerationAttempt: + return GenerationAttempt( + attempt_id=f"attempt-{uuid4()}", + dataset_id=blueprint.dataset_id, + role=GenerationRole.REPAIR, + status=status, + provider=policy.provider, + model=policy.model, + prompt_hash=prompt_hash, + input_tokens=input_tokens, + output_tokens=output_tokens, + errors=errors or [], + artifact_id=blueprint.blueprint_id, + metadata={ + "judge_report_id": judge_report.report_id, + "repair_round": repair_round, + }, + ) + + def _save_failed_attempt_best_effort( + self, + store: DatasetStore, + attempt: GenerationAttempt, + ) -> None: + try: + store.save_generation_attempt(attempt) + except Exception: + logger.exception("Failed to persist blueprint repair failure audit.") + + +def _hash_payload(payload: dict) -> str: + serialized = json.dumps(payload, sort_keys=True, separators=(",", ":")) + return hashlib.sha256(serialized.encode("utf-8")).hexdigest() + + +_REPAIR_SYSTEM_PROMPT = ( + "You are a clinical blueprint repair model. Return a corrected " + "ClinicalBlueprint as structured data only. Do not create final synthetic " + "case text, and do not add unsupported patient facts." +) diff --git a/src/casecrawler/storage/dataset_store.py b/src/casecrawler/storage/dataset_store.py index 893f12c..25bce96 100644 --- a/src/casecrawler/storage/dataset_store.py +++ b/src/casecrawler/storage/dataset_store.py @@ -294,6 +294,44 @@ def save_blueprint(self, blueprint: ClinicalBlueprint) -> None: ) self._conn.commit() + def save_blueprint_with_attempt( + self, + blueprint: ClinicalBlueprint, + attempt: GenerationAttempt, + ) -> None: + with self._write_lock: + try: + self._conn.execute( + """INSERT OR REPLACE INTO clinical_blueprints + (blueprint_id, dataset_id, cohort_plan_id, archetype_name, + blueprint_json) + VALUES (?, ?, ?, ?, ?)""", + ( + blueprint.blueprint_id, + blueprint.dataset_id, + blueprint.cohort_plan_id, + blueprint.archetype_name, + blueprint.model_dump_json(), + ), + ) + self._conn.execute( + """INSERT OR REPLACE INTO generation_attempts + (attempt_id, dataset_id, role, status, artifact_id, attempt_json) + VALUES (?, ?, ?, ?, ?, ?)""", + ( + attempt.attempt_id, + attempt.dataset_id, + attempt.role.value, + attempt.status.value, + attempt.artifact_id, + attempt.model_dump_json(), + ), + ) + self._conn.commit() + except Exception: + self._conn.rollback() + raise + def get_blueprint(self, blueprint_id: str) -> ClinicalBlueprint | None: row = self._conn.execute( "SELECT blueprint_json FROM clinical_blueprints WHERE blueprint_id = ?", diff --git a/tests/test_blueprint_repair_loop.py b/tests/test_blueprint_repair_loop.py new file mode 100644 index 0000000..6a889c6 --- /dev/null +++ b/tests/test_blueprint_repair_loop.py @@ -0,0 +1,226 @@ +from collections import Counter + +import pytest + +from casecrawler.llm.base import StructuredGenerationResult +from casecrawler.models.blueprint import ( + BlueprintEvidence, + BlueprintGenerationRequest, + ClinicalBlueprint, + GenerationAttemptStatus, + GenerationRole, + GenerationRolePolicy, + JudgeReport, +) +from casecrawler.storage.dataset_store import DatasetStore + + +class SequenceProvider: + def __init__(self, results) -> None: + self.results = list(results) + self.calls = [] + + async def generate_structured(self, prompt, schema, system="", **kwargs): + self.calls.append( + { + "prompt": prompt, + "schema": schema, + "system": system, + "kwargs": kwargs, + } + ) + data = self.results.pop(0) + return StructuredGenerationResult( + data=data, + input_tokens=50, + output_tokens=25, + model="fake", + ) + + +class RoutingProviderFactory: + def __init__(self, *, judge_results, repair_results=()) -> None: + self.judge_provider = SequenceProvider(judge_results) + self.repair_provider = SequenceProvider(repair_results) + + def __call__(self, provider_name, model): + if model == "judge-model": + return self.judge_provider + if model == "repair-model": + return self.repair_provider + raise ValueError(f"Unexpected model {model}") + + +def _request(max_repair_rounds: int = 1) -> BlueprintGenerationRequest: + return BlueprintGenerationRequest( + request="Judge and repair outpatient anticoagulation decision blueprints.", + target_count=1, + max_repair_rounds=max_repair_rounds, + role_policies=[ + GenerationRolePolicy( + role=GenerationRole.JUDGE, + provider="openai", + model="judge-model", + ), + GenerationRolePolicy( + role=GenerationRole.REPAIR, + provider="openai", + model="repair-model", + temperature=0.1, + ), + ], + ) + + +def _blueprint(**overrides) -> ClinicalBlueprint: + payload = { + "blueprint_id": "bp-1", + "dataset_id": "ds-1", + "cohort_plan_id": "plan-1", + "archetype_name": "anticoagulation decision", + "organ_system": "cardiovascular", + "setting": "outpatient", + "patient": {"age": 72, "sex": "female"}, + "chief_concern": "Atrial fibrillation anticoagulation follow-up.", + "diagnoses": [ + { + "name": "atrial fibrillation", + "supporting_findings": ["ECG confirms AF"], + } + ], + "clinical_reasoning_targets": ["Review renal dosing and bleeding risk."], + "evidence": BlueprintEvidence( + supported_claims=["AF anticoagulation requires renal-dose review."], + citations=[{"source": "dailymed", "claim": "renal-dose review"}], + ), + } + payload.update(overrides) + return ClinicalBlueprint(**payload) + + +def _judge_report(*, passed: bool, score: float = 0.8) -> JudgeReport: + return JudgeReport( + report_id="model-controlled", + dataset_id="wrong-dataset", + artifact_id="wrong-artifact", + role=GenerationRole.REPAIR, + score=score, + passed=passed, + rubric="blueprint_plausibility", + findings=[{"criterion": "diagnostic_support", "passed": passed}], + ) + + +@pytest.mark.asyncio +async def test_blueprint_repair_loop_repairs_failed_judge_report(tmp_path): + from casecrawler.generation.blueprint_repair import BlueprintRepairLoop + + store = DatasetStore(db_path=str(tmp_path / "datasets.db")) + original = _blueprint() + repaired_raw = _blueprint( + blueprint_id="model-repaired-id", + dataset_id="wrong-dataset", + chief_concern="Atrial fibrillation follow-up with renal-dose review.", + evidence=BlueprintEvidence( + supported_claims=[ + "AF anticoagulation requires renal-dose review.", + "Renal function changes anticoagulant dosing.", + ], + citations=[ + {"source": "dailymed", "claim": "renal-dose review"}, + {"source": "dailymed", "claim": "renal dosing"}, + ], + ), + ) + factory = RoutingProviderFactory( + judge_results=[ + _judge_report(passed=False, score=0.42), + _judge_report(passed=True, score=0.93), + ], + repair_results=[repaired_raw], + ) + + result = await BlueprintRepairLoop(provider_factory=factory).run( + _request(), + original, + store=store, + ) + + assert result.passed is True + assert result.repair_rounds == 1 + assert result.original_blueprint == original + assert result.final_blueprint.blueprint_id.startswith("bp-") + assert result.final_blueprint.blueprint_id != original.blueprint_id + assert result.final_blueprint.dataset_id == "ds-1" + assert result.final_blueprint.metadata["parent_blueprint_id"] == "bp-1" + assert result.final_blueprint.metadata["repair_round"] == 1 + assert len(result.judge_reports) == 2 + assert len(result.repaired_blueprints) == 1 + assert store.get_blueprint(result.final_blueprint.blueprint_id) == ( + result.final_blueprint + ) + attempts = store.list_generation_attempts(dataset_id="ds-1") + assert Counter((attempt.role, attempt.status) for attempt in attempts) == Counter( + { + (GenerationRole.JUDGE, GenerationAttemptStatus.SUCCEEDED): 2, + (GenerationRole.REPAIR, GenerationAttemptStatus.REPAIR_REQUESTED): 1, + (GenerationRole.REPAIR, GenerationAttemptStatus.SUCCEEDED): 1, + } + ) + repair_attempts = [ + attempt for attempt in attempts if attempt.role == GenerationRole.REPAIR + ] + assert {attempt.metadata["repair_round"] for attempt in repair_attempts} == {1} + assert factory.repair_provider.calls[0]["kwargs"]["temperature"] == 0.1 + assert "bp-1" in factory.repair_provider.calls[0]["prompt"] + + +@pytest.mark.asyncio +async def test_blueprint_repair_loop_does_not_repair_passing_blueprint(tmp_path): + from casecrawler.generation.blueprint_repair import BlueprintRepairLoop + + store = DatasetStore(db_path=str(tmp_path / "datasets.db")) + original = _blueprint() + factory = RoutingProviderFactory( + judge_results=[_judge_report(passed=True, score=0.95)] + ) + + result = await BlueprintRepairLoop(provider_factory=factory).run( + _request(), + original, + store=store, + ) + + assert result.passed is True + assert result.repair_rounds == 0 + assert result.final_blueprint == original + assert result.repaired_blueprints == [] + assert factory.repair_provider.calls == [] + assert len(store.list_judge_reports(artifact_id="bp-1")) == 1 + assert len(store.list_generation_attempts(dataset_id="ds-1")) == 1 + + +@pytest.mark.asyncio +async def test_blueprint_repair_loop_respects_zero_repair_rounds(tmp_path): + from casecrawler.generation.blueprint_repair import BlueprintRepairLoop + + store = DatasetStore(db_path=str(tmp_path / "datasets.db")) + original = _blueprint() + factory = RoutingProviderFactory( + judge_results=[_judge_report(passed=False, score=0.35)] + ) + + result = await BlueprintRepairLoop(provider_factory=factory).run( + _request(max_repair_rounds=0), + original, + store=store, + ) + + assert result.passed is False + assert result.repair_rounds == 0 + assert result.final_blueprint == original + assert result.repaired_blueprints == [] + assert factory.repair_provider.calls == [] + attempts = store.list_generation_attempts(dataset_id="ds-1") + assert len(attempts) == 1 + assert attempts[0].role == GenerationRole.JUDGE diff --git a/tests/test_blueprint_storage.py b/tests/test_blueprint_storage.py index 4b7f5c4..c4cf1fd 100644 --- a/tests/test_blueprint_storage.py +++ b/tests/test_blueprint_storage.py @@ -177,6 +177,28 @@ def test_dataset_store_saves_judge_report_with_attempt(tmp_path): assert store.get_generation_attempt("attempt-1") == attempt +def test_dataset_store_saves_blueprint_with_attempt(tmp_path): + store = DatasetStore(db_path=str(tmp_path / "datasets.db")) + blueprint = _blueprint() + attempt = GenerationAttempt( + attempt_id="attempt-1", + dataset_id="ds-1", + role=GenerationRole.REPAIR, + status=GenerationAttemptStatus.SUCCEEDED, + provider="openai", + model="gpt-4.1-mini", + prompt_hash="abc123", + input_tokens=100, + output_tokens=75, + artifact_id="bp-1", + ) + + store.save_blueprint_with_attempt(blueprint, attempt) + + assert store.get_blueprint("bp-1") == blueprint + assert store.get_generation_attempt("attempt-1") == attempt + + def test_dataset_manifest_includes_blueprint_persistence_counts(tmp_path): store = DatasetStore(db_path=str(tmp_path / "datasets.db")) store.save_record(_record())