Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
331 changes: 331 additions & 0 deletions src/casecrawler/generation/blueprint_repair.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,331 @@
from __future__ import annotations

import hashlib
import json
import logging
from collections.abc import Callable
from uuid import uuid4

from pydantic import BaseModel

from casecrawler.generation.blueprint_judge import BlueprintJudge
from casecrawler.llm.base import BaseLLMProvider
from casecrawler.llm.factory import get_provider
from casecrawler.models.blueprint import (
BlueprintGenerationRequest,
ClinicalBlueprint,
GenerationAttempt,
GenerationAttemptStatus,
GenerationRole,
GenerationRolePolicy,
JudgeReport,
)
from casecrawler.storage.dataset_store import DatasetStore


# Signature of the callable that builds an LLM provider. In this module it is
# invoked as provider_factory(policy.provider, policy.model).
ProviderFactory = Callable[[str, str], BaseLLMProvider]
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


class BlueprintRepairResult(BaseModel):
    """Outcome of one judge/repair loop, as returned by ``BlueprintRepairLoop.run``."""

    # The blueprint that was handed to the loop, untouched.
    original_blueprint: ClinicalBlueprint
    # The last blueprint the loop produced (equals the original when no repair ran).
    final_blueprint: ClinicalBlueprint
    # One report per judge evaluation, in evaluation order.
    judge_reports: list[JudgeReport]
    # Every repaired blueprint produced, in repair order.
    repaired_blueprints: list[ClinicalBlueprint]
    # Number of repairs performed; the loop sets this to len(repaired_blueprints).
    repair_rounds: int
    # True only when a judge report passed before the loop stopped.
    passed: bool


class BlueprintRepairLoop:
def __init__(
    self,
    *,
    provider_factory: ProviderFactory = get_provider,
    judge: BlueprintJudge | None = None,
) -> None:
    """Store the provider factory and the judge the loop will use.

    Args:
        provider_factory: Callable invoked with a provider name and a model
            name to obtain the LLM provider for a repair call.
        judge: Judge used to evaluate blueprints. When it is omitted (or
            falsy), a ``BlueprintJudge`` sharing ``provider_factory`` is
            created instead.
    """
    self._provider_factory = provider_factory
    if judge:
        self._judge = judge
    else:
        self._judge = BlueprintJudge(provider_factory=provider_factory)

async def run(
    self,
    request: BlueprintGenerationRequest,
    blueprint: ClinicalBlueprint,
    *,
    store: DatasetStore | None = None,
) -> BlueprintRepairResult:
    """Judge ``blueprint`` and repair it until it passes or rounds run out.

    The blueprint is evaluated up to ``request.max_repair_rounds + 1`` times.
    After each failing evaluation (except the last allowed one) a repair is
    requested and the repaired blueprint becomes the next candidate.

    Args:
        request: Generation request; supplies ``max_repair_rounds`` and the
            repair role policy.
        blueprint: The blueprint to evaluate first.
        store: Optional store. When given, a REPAIR_REQUESTED attempt is
            saved before each repair and is passed on to the judge and the
            repair step.

    Returns:
        A ``BlueprintRepairResult`` holding every judge report, every
        repaired blueprint, the final candidate, and whether it passed.

    Raises:
        ValueError: If a repair is needed but the request has no policy for
            the repair role.
    """
    current = blueprint
    judge_reports: list[JudgeReport] = []
    repaired_blueprints: list[ClinicalBlueprint] = []
    passed = False

    for round_index in range(request.max_repair_rounds + 1):
        judge_report = await self._judge.evaluate(request, current, store=store)
        judge_reports.append(judge_report)
        if judge_report.passed:
            passed = True
            break
        # The final evaluation is judge-only: no repair budget remains.
        if round_index >= request.max_repair_rounds:
            break

        # Looked up lazily so a missing repair policy only fails when a
        # repair is actually required.
        repair_policy = request.policy_for(GenerationRole.REPAIR)
        if repair_policy is None:
            raise ValueError(
                "A repair role policy is required when judge repair is needed."
            )

        repair_round = round_index + 1
        repair_request_attempt = self._repair_requested_attempt(
            blueprint=current,
            policy=repair_policy,
            judge_report=judge_report,
            repair_round=repair_round,
        )
        if store is not None:
            # Stamp the request row with the hash of the real repair prompt.
            # _repair_blueprint hashes the same prompt (same inputs) for its
            # SUCCEEDED/FAILED row, so both rows of a round share one
            # prompt_hash and can be correlated.
            repair_prompt = self._build_repair_prompt(
                request,
                blueprint=current,
                judge_report=judge_report,
                repair_round=repair_round,
            )
            repair_request_attempt = repair_request_attempt.model_copy(
                update={
                    "prompt_hash": self._prompt_hash(repair_prompt, repair_policy),
                }
            )
            store.save_generation_attempt(repair_request_attempt)

        current = await self._repair_blueprint(
            request,
            current,
            judge_report,
            policy=repair_policy,
            repair_round=repair_round,
            store=store,
        )
        repaired_blueprints.append(current)

    return BlueprintRepairResult(
        original_blueprint=blueprint,
        final_blueprint=current,
        judge_reports=judge_reports,
        repaired_blueprints=repaired_blueprints,
        repair_rounds=len(repaired_blueprints),
        passed=passed,
    )

async def _repair_blueprint(
    self,
    request: BlueprintGenerationRequest,
    blueprint: ClinicalBlueprint,
    judge_report: JudgeReport,
    *,
    policy: GenerationRolePolicy,
    repair_round: int,
    store: DatasetStore | None,
) -> ClinicalBlueprint:
    """Ask the repair model for a corrected blueprint and record the attempt.

    Args:
        request: The originating generation request (used for the prompt).
        blueprint: The blueprint that failed review.
        judge_report: The failing report the repair should address.
        policy: Repair role policy naming provider, model and temperature.
        repair_round: 1-based number of this repair.
        store: Optional store for the attempt record and repaired blueprint.

    Returns:
        The repaired blueprint after canonicalization.

    Raises:
        Exception: Whatever generation, validation or canonicalization
            raised. A FAILED attempt is saved best-effort first when a
            store is available.
    """
    llm = self._provider_factory(policy.provider, policy.model)
    repair_prompt = self._build_repair_prompt(
        request,
        blueprint=blueprint,
        judge_report=judge_report,
        repair_round=repair_round,
    )
    digest = self._prompt_hash(repair_prompt, policy)
    # Fields shared by the FAILED and the SUCCEEDED attempt record.
    audit_fields = {
        "policy": policy,
        "prompt_hash": digest,
        "repair_round": repair_round,
        "judge_report": judge_report,
    }

    try:
        generation = await llm.generate_structured(
            repair_prompt,
            ClinicalBlueprint,
            system=_REPAIR_SYSTEM_PROMPT,
            temperature=policy.temperature,
        )
        candidate = ClinicalBlueprint.model_validate(generation.data)
        repaired = self._canonicalize_repair(
            candidate,
            source=blueprint,
            judge_report=judge_report,
            repair_round=repair_round,
        )
    except Exception as err:
        if store is not None:
            # The failure row points at the source blueprint, since no
            # repaired artifact exists.
            failure = self._repair_attempt(
                blueprint=blueprint,
                status=GenerationAttemptStatus.FAILED,
                errors=[str(err)],
                **audit_fields,
            )
            self._save_failed_attempt_best_effort(store, failure)
        raise

    if store is None:
        return repaired

    success = self._repair_attempt(
        blueprint=repaired,
        status=GenerationAttemptStatus.SUCCEEDED,
        input_tokens=generation.input_tokens,
        output_tokens=generation.output_tokens,
        **audit_fields,
    )
    store.save_blueprint_with_attempt(repaired, success)
    return repaired

def _canonicalize_repair(
    self,
    raw_blueprint: ClinicalBlueprint,
    *,
    source: ClinicalBlueprint,
    judge_report: JudgeReport,
    repair_round: int,
) -> ClinicalBlueprint:
    """Re-stamp a model-produced blueprint with trusted identity and lineage.

    The returned blueprint keeps the model's content but gets a fresh
    ``bp-<uuid4>`` id, copies the dataset, cohort plan, archetype, organ
    system and setting from ``source``, and records the parent blueprint id,
    repair round and judge report id in its metadata.

    Args:
        raw_blueprint: The blueprint as returned by the repair model.
        source: The blueprint that was repaired.
        judge_report: The report that triggered the repair.
        repair_round: 1-based number of this repair.

    Returns:
        A newly validated ``ClinicalBlueprint``.
    """
    lineage_metadata = {**raw_blueprint.metadata}
    lineage_metadata["parent_blueprint_id"] = source.blueprint_id
    lineage_metadata["repair_round"] = repair_round
    lineage_metadata["judge_report_id"] = judge_report.report_id

    fields = raw_blueprint.model_dump()
    # Values from the source override whatever the model wrote for them.
    fields.update(
        blueprint_id=f"bp-{uuid4()}",
        dataset_id=source.dataset_id,
        cohort_plan_id=source.cohort_plan_id,
        archetype_name=source.archetype_name,
        organ_system=source.organ_system,
        setting=source.setting,
        metadata=lineage_metadata,
    )
    return ClinicalBlueprint.model_validate(fields)

def _build_repair_prompt(
self,
request: BlueprintGenerationRequest,
*,
blueprint: ClinicalBlueprint,
judge_report: JudgeReport,
repair_round: int,
) -> str:
blueprint_json = json.dumps(
blueprint.model_dump(mode="json"),
sort_keys=True,
separators=(",", ":"),
)
report_json = json.dumps(
judge_report.model_dump(mode="json"),
sort_keys=True,
separators=(",", ":"),
)
return "\n".join(
[
"Repair this clinical blueprint so it can pass independent review.",
f"User request: {request.request}",
f"Repair round: {repair_round}",
f"Blueprint id: {blueprint.blueprint_id}",
f"Judge report id: {judge_report.report_id}",
(
"Preserve the clinical intent, dataset lineage, and task target. "
"Change only fields needed to address judge findings."
),
f"Blueprint JSON: {blueprint_json}",
f"JudgeReport JSON: {report_json}",
]
)

def _prompt_hash(
    self,
    prompt: str,
    policy: GenerationRolePolicy,
) -> str:
    """Fingerprint a repair call: prompt text plus everything that shapes it.

    Args:
        prompt: The user prompt sent to the repair model.
        policy: Repair policy; its provider, model and temperature are
            included in the fingerprint.

    Returns:
        The SHA-256 hex digest of the compact, key-sorted JSON encoding of
        the call parameters (including the system prompt and schema name).
    """
    call_signature = dict(
        model=policy.model,
        provider=policy.provider,
        schema=ClinicalBlueprint.__name__,
        system=_REPAIR_SYSTEM_PROMPT,
        temperature=policy.temperature,
        user=prompt,
    )
    # _hash_payload applies the same canonical JSON + SHA-256 recipe.
    return _hash_payload(call_signature)

def _repair_requested_attempt(
    self,
    *,
    blueprint: ClinicalBlueprint,
    policy: GenerationRolePolicy,
    judge_report: JudgeReport,
    repair_round: int,
) -> GenerationAttempt:
    """Build the REPAIR_REQUESTED attempt recorded before a repair runs.

    The attempt's ``prompt_hash`` here is a hash of the identifying fields
    (artifact id, judge report id, round, status), not of the repair prompt.

    Args:
        blueprint: The blueprint about to be repaired.
        policy: Repair policy supplying provider and model names.
        judge_report: The failing report that triggered the repair.
        repair_round: 1-based number of the upcoming repair.

    Returns:
        A new ``GenerationAttempt`` with a fresh ``attempt-<uuid4>`` id.
    """
    requested = GenerationAttemptStatus.REPAIR_REQUESTED
    lineage = {
        "judge_report_id": judge_report.report_id,
        "repair_round": repair_round,
    }
    # Key order is irrelevant: _hash_payload sorts keys before hashing.
    identity = {
        "artifact_id": blueprint.blueprint_id,
        "status": requested.value,
        **lineage,
    }
    return GenerationAttempt(
        attempt_id=f"attempt-{uuid4()}",
        dataset_id=blueprint.dataset_id,
        role=GenerationRole.REPAIR,
        status=requested,
        provider=policy.provider,
        model=policy.model,
        prompt_hash=_hash_payload(identity),
        artifact_id=blueprint.blueprint_id,
        metadata=lineage,
    )
Comment on lines +251 to +278

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Use the real repair prompt hash for REPAIR_REQUESTED.

_repair_requested_attempt() hashes a local payload instead of the actual repair prompt, while the later SUCCEEDED/FAILED attempt for the same round uses _prompt_hash(prompt, policy). That makes the preflight record impossible to correlate with its terminal record via prompt_hash, which weakens the attempt lineage this PR is adding. Compute the prompt/hash once before persisting the request row, then pass that same hash through both attempt builders.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@src/casecrawler/generation/blueprint_repair.py` around lines 251 - 278,
_repair_requested_attempt currently computes prompt_hash from a local payload
using _hash_payload, which diverges from the terminal attempts that use
_prompt_hash(prompt, policy); change the flow so the repair prompt is hashed
once (via _prompt_hash with the actual repair prompt and policy) before
creating/persisting the REPAIR_REQUESTED GenerationAttempt and pass that same
prompt_hash into _repair_requested_attempt (or add a prompt_hash parameter) so
both the REPAIR_REQUESTED and subsequent SUCCEEDED/FAILED attempts use the
identical prompt_hash value.


def _repair_attempt(
    self,
    *,
    blueprint: ClinicalBlueprint,
    policy: GenerationRolePolicy,
    status: GenerationAttemptStatus,
    prompt_hash: str,
    repair_round: int,
    judge_report: JudgeReport,
    input_tokens: int = 0,
    output_tokens: int = 0,
    errors: list[str] | None = None,
) -> GenerationAttempt:
    """Build the attempt record for a finished repair call.

    Args:
        blueprint: Blueprint the attempt refers to; supplies the dataset id
            and the artifact id.
        policy: Repair policy supplying provider and model names.
        status: Outcome status to record.
        prompt_hash: Hash identifying the prompt that was sent.
        repair_round: 1-based number of the repair.
        judge_report: The report that triggered the repair.
        input_tokens: Prompt tokens consumed; defaults to 0.
        output_tokens: Completion tokens produced; defaults to 0.
        errors: Error messages to record; ``None`` or empty becomes ``[]``.

    Returns:
        A new ``GenerationAttempt`` with a fresh ``attempt-<uuid4>`` id.
    """
    recorded_errors = errors if errors else []
    lineage = {
        "judge_report_id": judge_report.report_id,
        "repair_round": repair_round,
    }
    attempt = GenerationAttempt(
        attempt_id=f"attempt-{uuid4()}",
        dataset_id=blueprint.dataset_id,
        role=GenerationRole.REPAIR,
        status=status,
        provider=policy.provider,
        model=policy.model,
        prompt_hash=prompt_hash,
        input_tokens=input_tokens,
        output_tokens=output_tokens,
        errors=recorded_errors,
        artifact_id=blueprint.blueprint_id,
        metadata=lineage,
    )
    return attempt

def _save_failed_attempt_best_effort(
self,
store: DatasetStore,
attempt: GenerationAttempt,
) -> None:
try:
store.save_generation_attempt(attempt)
except Exception:
logger.exception("Failed to persist blueprint repair failure audit.")


def _hash_payload(payload: dict) -> str:
serialized = json.dumps(payload, sort_keys=True, separators=(",", ":"))
return hashlib.sha256(serialized.encode("utf-8")).hexdigest()


# System prompt sent with every repair generation call. It is also part of the
# payload hashed by _prompt_hash, so editing this text changes recorded hashes.
_REPAIR_SYSTEM_PROMPT = (
    "You are a clinical blueprint repair model. Return a corrected "
    "ClinicalBlueprint as structured data only. Do not create final synthetic "
    "case text, and do not add unsupported patient facts."
)
38 changes: 38 additions & 0 deletions src/casecrawler/storage/dataset_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,44 @@ def save_blueprint(self, blueprint: ClinicalBlueprint) -> None:
)
self._conn.commit()

def save_blueprint_with_attempt(
    self,
    blueprint: ClinicalBlueprint,
    attempt: GenerationAttempt,
) -> None:
    """Persist a blueprint and its generation attempt in one transaction.

    Both rows are written with ``INSERT OR REPLACE`` and committed together;
    if either write fails the transaction is rolled back and the error is
    re-raised.

    Args:
        blueprint: The blueprint to store.
        attempt: The attempt that produced ``blueprint``.

    Raises:
        ValueError: If ``attempt`` does not belong to ``blueprint`` (its
            ``dataset_id`` or ``artifact_id`` does not match).
        Exception: Any error raised by the underlying connection, after
            rollback.
    """
    # Reject mismatched pairs before taking the lock or opening a
    # transaction: an atomic write of unrelated rows would still corrupt
    # the attempt lineage.
    if attempt.dataset_id != blueprint.dataset_id:
        raise ValueError("Attempt dataset_id must match blueprint dataset_id.")
    if attempt.artifact_id != blueprint.blueprint_id:
        raise ValueError("Attempt artifact_id must match blueprint blueprint_id.")
    with self._write_lock:
        try:
            self._conn.execute(
                """INSERT OR REPLACE INTO clinical_blueprints
                (blueprint_id, dataset_id, cohort_plan_id, archetype_name,
                blueprint_json)
                VALUES (?, ?, ?, ?, ?)""",
                (
                    blueprint.blueprint_id,
                    blueprint.dataset_id,
                    blueprint.cohort_plan_id,
                    blueprint.archetype_name,
                    blueprint.model_dump_json(),
                ),
            )
            self._conn.execute(
                """INSERT OR REPLACE INTO generation_attempts
                (attempt_id, dataset_id, role, status, artifact_id, attempt_json)
                VALUES (?, ?, ?, ?, ?, ?)""",
                (
                    attempt.attempt_id,
                    attempt.dataset_id,
                    attempt.role.value,
                    attempt.status.value,
                    attempt.artifact_id,
                    attempt.model_dump_json(),
                ),
            )
            self._conn.commit()
        except Exception:
            # Undo the partial write so neither row is left behind alone.
            self._conn.rollback()
            raise
Comment on lines +297 to +333

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Reject mismatched blueprint/attempt pairs before the transaction.

This method never validates that the attempt belongs to the blueprint being saved. A caller can currently commit a ClinicalBlueprint for one artifact and a GenerationAttempt for another, which silently corrupts repair lineage despite the write being "atomic".

Suggested fix
 def save_blueprint_with_attempt(
     self,
     blueprint: ClinicalBlueprint,
     attempt: GenerationAttempt,
 ) -> None:
+    if attempt.dataset_id != blueprint.dataset_id:
+        raise ValueError("Attempt dataset_id must match blueprint dataset_id.")
+    if attempt.artifact_id != blueprint.blueprint_id:
+        raise ValueError("Attempt artifact_id must match blueprint blueprint_id.")
     with self._write_lock:
         try:
             self._conn.execute(
                 """INSERT OR REPLACE INTO clinical_blueprints
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
def save_blueprint_with_attempt(
self,
blueprint: ClinicalBlueprint,
attempt: GenerationAttempt,
) -> None:
with self._write_lock:
try:
self._conn.execute(
"""INSERT OR REPLACE INTO clinical_blueprints
(blueprint_id, dataset_id, cohort_plan_id, archetype_name,
blueprint_json)
VALUES (?, ?, ?, ?, ?)""",
(
blueprint.blueprint_id,
blueprint.dataset_id,
blueprint.cohort_plan_id,
blueprint.archetype_name,
blueprint.model_dump_json(),
),
)
self._conn.execute(
"""INSERT OR REPLACE INTO generation_attempts
(attempt_id, dataset_id, role, status, artifact_id, attempt_json)
VALUES (?, ?, ?, ?, ?, ?)""",
(
attempt.attempt_id,
attempt.dataset_id,
attempt.role.value,
attempt.status.value,
attempt.artifact_id,
attempt.model_dump_json(),
),
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
def save_blueprint_with_attempt(
self,
blueprint: ClinicalBlueprint,
attempt: GenerationAttempt,
) -> None:
if attempt.dataset_id != blueprint.dataset_id:
raise ValueError("Attempt dataset_id must match blueprint dataset_id.")
if attempt.artifact_id != blueprint.blueprint_id:
raise ValueError("Attempt artifact_id must match blueprint blueprint_id.")
with self._write_lock:
try:
self._conn.execute(
"""INSERT OR REPLACE INTO clinical_blueprints
(blueprint_id, dataset_id, cohort_plan_id, archetype_name,
blueprint_json)
VALUES (?, ?, ?, ?, ?)""",
(
blueprint.blueprint_id,
blueprint.dataset_id,
blueprint.cohort_plan_id,
blueprint.archetype_name,
blueprint.model_dump_json(),
),
)
self._conn.execute(
"""INSERT OR REPLACE INTO generation_attempts
(attempt_id, dataset_id, role, status, artifact_id, attempt_json)
VALUES (?, ?, ?, ?, ?, ?)""",
(
attempt.attempt_id,
attempt.dataset_id,
attempt.role.value,
attempt.status.value,
attempt.artifact_id,
attempt.model_dump_json(),
),
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@src/casecrawler/storage/dataset_store.py` around lines 297 - 333, Before
performing the DB writes in save_blueprint_with_attempt, validate that the
provided GenerationAttempt actually belongs to the ClinicalBlueprint: check
attempt.artifact_id == blueprint.blueprint_id and attempt.dataset_id ==
blueprint.dataset_id (or any other domain-specific linkage between attempt and
blueprint), and raise a clear exception (e.g., ValueError) if they do not match;
perform this validation before acquiring the write lock / before executing the
INSERTs so mismatched pairs are rejected prior to the transaction.


def get_blueprint(self, blueprint_id: str) -> ClinicalBlueprint | None:
row = self._conn.execute(
"SELECT blueprint_json FROM clinical_blueprints WHERE blueprint_id = ?",
Expand Down
Loading
Loading