diff --git a/.gitignore b/.gitignore index c50630d6..9e3f32f8 100644 --- a/.gitignore +++ b/.gitignore @@ -46,4 +46,4 @@ htmlcov/ .mypy_cache/ .dmypy.json dmypy.json -.pyre/ \ No newline at end of file +.pyre/ diff --git a/examples/session_persist/README.md b/examples/session_persist/README.md new file mode 100644 index 00000000..c73ee4e2 --- /dev/null +++ b/examples/session_persist/README.md @@ -0,0 +1,165 @@ +# Session Persistence Examples + +This directory contains examples demonstrating the `SessionPersistentClient` that provides automatic session persistence by wrapping `ClaudeSDKClient`. + +## Design Philosophy + +The `SessionPersistentClient` follows a clean wrapper pattern: + +1. **Wraps ClaudeSDKClient**: Does not modify the core client, just wraps it +2. **Uses receive_messages()**: Extracts session data from actual message responses +3. **Server-Generated IDs**: Uses session IDs from Claude's server, not client-generated UUIDs +4. **Automatic Persistence**: Saves conversation data transparently in the background + +## Examples + +### 1. `simple_persist.py` +**Purpose**: Basic demonstration of automatic session persistence + +**Key Features**: +- Shows how SessionPersistentClient wraps ClaudeSDKClient +- Demonstrates automatic session ID extraction from messages +- Shows session data inspection capabilities +- Simple conversation with automatic saving + +**Run it**: +```bash +python examples/session_persist/simple_persist.py +``` + +### 2. `multi_turn_conversation.py` +**Purpose**: Multi-turn conversation with session resumption demonstration + +**Key Features**: +- Multi-turn conversation that maintains context +- Session disconnect and resumption workflow +- Resume session using `start_or_resume_session()` +- Shows both local data loading and CLI --resume functionality +- Demonstrates conversation context continuity across disconnect/resume +- Context reference across multiple turns spanning session boundaries + +**Demo Flow**: +1. **Phase 1**: Initial conversation (Turns 1-2) with automatic persistence +2. **Disconnect**: Session is saved and connection closed +3. **Phase 2**: Resume session with new client instance +4. **Continue**: Turns 3-4 with preserved context from earlier turns + +**Run it**: +```bash +python examples/session_persist/multi_turn_conversation.py +``` + +## How SessionPersistentClient Works + +### Architecture + +```python +SessionPersistentClient +├── ClaudeSDKClient (wrapped) # Handles all Claude interactions +├── SessionPersistence # Manages file storage +└── Message Processing # Extracts session data from messages +``` + +### Key Methods + +```python +# Initialize with automatic persistence +client = SessionPersistentClient( + options=ClaudeCodeOptions(), + storage_path="./my_sessions" +) + +# All ClaudeSDKClient methods are available: +await client.connect() +await client.query("Hello") +async for message in client.receive_response(): + # Session data is automatically extracted and saved + print(message) + +# Session management: +await client.start_or_resume_session(id) # Resume existing session (local + server) +session_id = client.get_current_session_id() # Server-generated ID +sessions = await client.list_sessions() # List all saved sessions +session_data = await client.load_session(id) # Load session for inspection +await client.delete_session(session_id) # Delete a session +``` + +### Automatic Session Extraction + +The client automatically extracts session data from messages: + +1. **Session ID Detection**: Looks for `session_id` in message metadata +2. **Message Conversion**: Converts Claude messages to `ConversationMessage` format +3. **Automatic Saving**: Saves session data after each message +4. **Context Preservation**: Maintains conversation history and metadata + +### File Structure + +Sessions are saved as JSON files: +``` +~/.claude_sdk/sessions/ +├── 2aecab00-6512-4e29-9da3-9321cac6eb2.json +├── 7b8f3c45-2d19-4e7a-b6c1-f9d2e8a3c7b5.json +└── ... +``` + +Each file contains: +```json +{ + "session_id": "server-generated-uuid", + "start_time": "2025-07-30T15:25:31.069594", + "last_activity": "2025-07-30T15:27:45.123456", + "conversation_history": [...], + "working_directory": "/path/to/working/dir", + "options": {...} +} +``` + +## Benefits of This Design + +### ✅ Clean Separation +- Core `ClaudeSDKClient` remains unchanged +- Persistence is an optional wrapper layer +- No mixing of concerns + +### ✅ Server-Driven +- Uses actual session IDs from Claude's server +- No client-side UUID generation +- Matches Claude's internal session management + +### ✅ Automatic Operation +- No manual session management required +- Transparent persistence in background +- Works with all `ClaudeSDKClient` features + +### ✅ Easy Migration +- Existing `ClaudeSDKClient` code works unchanged +- Just replace `ClaudeSDKClient` with `SessionPersistentClient` +- All methods and features preserved + +## Usage Pattern + +### Old (no persistence): +```python +async with ClaudeSDKClient() as client: + await client.query("Hello") + async for message in client.receive_response(): + print(message) +``` + +### New (with automatic persistence): +```python +async with SessionPersistentClient() as client: + await client.query("Hello") + async for message in client.receive_response(): + print(message) # Same code, automatic persistence! +``` + +## Dependencies + +These examples use `trio` for async operations. Install with: +```bash +pip install trio +``` + +Or use standard `asyncio` by replacing `trio.run()` with `asyncio.run()`. \ No newline at end of file diff --git a/examples/session_persist/multi_turn_conversation.py b/examples/session_persist/multi_turn_conversation.py new file mode 100644 index 00000000..d9e216c3 --- /dev/null +++ b/examples/session_persist/multi_turn_conversation.py @@ -0,0 +1,144 @@ +#!/usr/bin/env python3 +""" +Multi-turn conversation with session resumption example using SessionPersistentClient. + +This demonstrates how the SessionPersistentClient automatically captures +session context across multiple turns of conversation, and how to resume +sessions after disconnecting. The demo shows: + +1. Initial conversation (Turns 1-2) with automatic session persistence +2. Disconnect from the session +3. Resume the session using start_or_resume_session() +4. Continue conversation (Turns 3-4) with preserved context + +Key features demonstrated: +- Local session data loading from storage +- CLI --resume option for server-side session continuity +- Context retention across disconnect/resume cycles +- Seamless multi-turn conversation flow +""" + +import trio +from pathlib import Path + +from claude_code_sdk import ( + AssistantMessage, + ClaudeCodeOptions, + SessionPersistentClient, + TextBlock, +) + + +async def multi_turn_demo(): + """Demonstrate multi-turn conversation with session resumption.""" + print("=== Multi-Turn Conversation with Session Resumption ===") + + storage_path = Path("./conversation_sessions") + + # Phase 1: Initial conversation (Turns 1-2) + print("\n🚀 Phase 1: Initial Conversation") + session_id = None + + async with SessionPersistentClient( + options=ClaudeCodeOptions(), + storage_path=storage_path + ) as client: + + # Turn 1: Introduction + print("\n💬 Turn 1: Introduction") + intro_msg = "Hello! My name is Alex and I'm a software developer." + await client.query(intro_msg) + print(f"User: {intro_msg}") + + async for message in client.receive_response(): + if isinstance(message, AssistantMessage): + for block in message.content: + if isinstance(block, TextBlock): + print(f"Claude: {block.text}") + + # Turn 2: Ask about a topic + print("\n💬 Turn 2: Technical question") + await client.query("What's the difference between async and sync programming?") + + async for message in client.receive_response(): + if isinstance(message, AssistantMessage): + for block in message.content: + if isinstance(block, TextBlock): + print(f"Claude: {block.text}") + + # Capture session ID for resumption + session_id = client.get_current_session_id() + session_data = await client.load_session(session_id) + + print(f"\n📋 Session after 2 turns:") + print(f" ID: {session_id}") + print(f" Total Messages: {len(session_data.conversation_history) if session_data else 'Unknown'}") + print(f" 🔌 Disconnecting to demonstrate session resumption...") + + # Phase 2: Resume session (Turns 3-4) + print(f"\n🔄 Phase 2: Resuming Session {session_id}") + + # Create new client instance and resume the session + client = SessionPersistentClient( + options=ClaudeCodeOptions(), + storage_path=storage_path + ) + + try: + # Resume the session - this loads local data AND configures CLI --resume + await client.start_or_resume_session(session_id) + print(f"✅ Session resumed. Local data: {len(client._session_data.conversation_history) if client._session_data else 0} messages") + + # Connect and continue the conversation + await client.connect() + + # Turn 3: Follow-up question (tests context retention across disconnect/resume) + print("\n💬 Turn 3: Follow-up question") + await client.query("Can you give me a Python example of what you just explained?") + + async for message in client.receive_response(): + if isinstance(message, AssistantMessage): + for block in message.content: + if isinstance(block, TextBlock): + print(f"Claude: {block.text}") + + # Turn 4: Reference earlier context + print("\n💬 Turn 4: Reference earlier context") + await client.query("Thanks! What was my name again that I mentioned at the beginning?") + + async for message in client.receive_response(): + if isinstance(message, AssistantMessage): + for block in message.content: + if isinstance(block, TextBlock): + print(f"Claude: {block.text}") + + # Show final session info + final_session_id = client.get_current_session_id() + session_data = await client.load_session(final_session_id) + + print(f"\n📋 Final Session Summary:") + print(f" Original ID: {session_id}") + print(f" Current ID: {final_session_id}") + print(f" Total Messages: {len(session_data.conversation_history) if session_data else 'Unknown'}") + print(f" Duration: {session_data.last_activity - session_data.start_time if session_data else 'Unknown'}") + print(f" Auto-saved to: {storage_path.absolute()}") + + finally: + await client.disconnect() + + +async def main(): + """Run the multi-turn conversation with session resumption demo.""" + await multi_turn_demo() + + print("\n" + "="*60) + print("✅ Session resumption demo complete!") + print("📝 Key features demonstrated:") + print(" • Local session data is loaded from storage") + print(" • CLI --resume option is set for server-side continuity") + print(" • Conversation context is maintained across disconnect/resume") + print(" • Multi-turn conversations work seamlessly across sessions") + + +if __name__ == "__main__": + trio.run(main) \ No newline at end of file diff --git a/src/claude_code_sdk/__init__.py b/src/claude_code_sdk/__init__.py index 8ac162e2..6e43c2c8 100644 --- a/src/claude_code_sdk/__init__.py +++ b/src/claude_code_sdk/__init__.py @@ -9,6 +9,7 @@ ) from .client import ClaudeSDKClient from .query import query +from .session_persistent_client import SessionPersistentClient from .types import ( AssistantMessage, ClaudeCodeOptions, @@ -30,6 +31,7 @@ # Main exports "query", "ClaudeSDKClient", + "SessionPersistentClient", # Types "PermissionMode", "McpServerConfig", diff --git a/src/claude_code_sdk/_internal/session_storage.py b/src/claude_code_sdk/_internal/session_storage.py new file mode 100644 index 00000000..109e8efc --- /dev/null +++ b/src/claude_code_sdk/_internal/session_storage.py @@ -0,0 +1,205 @@ +"""Session storage and persistence utilities.""" + +import json +from dataclasses import dataclass, field +from datetime import datetime +from pathlib import Path +from typing import Any + +from ..types import ClaudeCodeOptions, Message + + +@dataclass +class SessionData: + """Session data for persistence.""" + + session_id: str + start_time: datetime + last_activity: datetime + conversation_history: list[Message] = field(default_factory=list) + working_directory: str = "" + options: ClaudeCodeOptions | None = None + + def add_message(self, message: Message) -> None: + """Add a message to the conversation history.""" + self.conversation_history.append(message) + self.last_activity = datetime.now() + + def to_dict(self) -> dict[str, Any]: + """Convert to dictionary for JSON serialization.""" + from ..types import UserMessage, AssistantMessage, SystemMessage, ResultMessage + + # Convert messages to serializable format + serialized_messages = [] + for msg in self.conversation_history: + msg_dict = { + "message_type": type(msg).__name__, + } + + if isinstance(msg, UserMessage): + msg_dict.update({ + "content": msg.content, + }) + elif isinstance(msg, AssistantMessage): + msg_dict.update({ + "content": [ + { + "type": type(block).__name__, + "text": getattr(block, 'text', None), + "id": getattr(block, 'id', None), + "name": getattr(block, 'name', None), + "input": getattr(block, 'input', None), + "tool_use_id": getattr(block, 'tool_use_id', None), + "is_error": getattr(block, 'is_error', None), + } + for block in msg.content + ], + }) + elif isinstance(msg, SystemMessage): + msg_dict.update({ + "subtype": msg.subtype, + "data": msg.data, + }) + elif isinstance(msg, ResultMessage): + msg_dict.update({ + "subtype": msg.subtype, + "duration_ms": msg.duration_ms, + "duration_api_ms": msg.duration_api_ms, + "is_error": msg.is_error, + "num_turns": msg.num_turns, + "session_id": msg.session_id, + "total_cost_usd": msg.total_cost_usd, + "usage": msg.usage, + "result": msg.result, + }) + + serialized_messages.append(msg_dict) + + return { + "session_id": self.session_id, + "start_time": self.start_time.isoformat(), + "last_activity": self.last_activity.isoformat(), + "conversation_history": serialized_messages, + "working_directory": self.working_directory, + "options": { + "model": self.options.model if self.options else None, + "allowed_tools": self.options.allowed_tools if self.options else [], + "permission_mode": self.options.permission_mode if self.options else None, + } if self.options else None, + } + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "SessionData": + """Create from dictionary (JSON deserialization).""" + from ..types import UserMessage, AssistantMessage, SystemMessage, ResultMessage, TextBlock, ToolUseBlock, ToolResultBlock + + # Parse conversation history + conversation_history = [] + for msg_data in data.get("conversation_history", []): + message_type = msg_data.get("message_type", "") + + if message_type == "UserMessage": + message = UserMessage(content=msg_data["content"]) + elif message_type == "AssistantMessage": + # Reconstruct content blocks + content_blocks = [] + for block_data in msg_data.get("content", []): + block_type = block_data.get("type", "") + if block_type == "TextBlock": + content_blocks.append(TextBlock(text=block_data.get("text", ""))) + elif block_type == "ToolUseBlock": + content_blocks.append(ToolUseBlock( + id=block_data.get("id", ""), + name=block_data.get("name", ""), + input=block_data.get("input", {}) + )) + elif block_type == "ToolResultBlock": + content_blocks.append(ToolResultBlock( + tool_use_id=block_data.get("tool_use_id", ""), + content=block_data.get("content"), + is_error=block_data.get("is_error") + )) + message = AssistantMessage(content=content_blocks) + elif message_type == "SystemMessage": + message = SystemMessage( + subtype=msg_data.get("subtype", ""), + data=msg_data.get("data", {}) + ) + elif message_type == "ResultMessage": + message = ResultMessage( + subtype=msg_data.get("subtype", ""), + duration_ms=msg_data.get("duration_ms", 0), + duration_api_ms=msg_data.get("duration_api_ms", 0), + is_error=msg_data.get("is_error", False), + num_turns=msg_data.get("num_turns", 0), + session_id=msg_data.get("session_id", ""), + total_cost_usd=msg_data.get("total_cost_usd"), + usage=msg_data.get("usage"), + result=msg_data.get("result") + ) + else: + continue # Skip unknown message types + + conversation_history.append(message) + + # Parse options + options = None + if data.get("options"): + options = ClaudeCodeOptions( + model=data["options"].get("model"), + allowed_tools=data["options"].get("allowed_tools", []), + permission_mode=data["options"].get("permission_mode"), + ) + + return cls( + session_id=data["session_id"], + start_time=datetime.fromisoformat(data["start_time"]), + last_activity=datetime.fromisoformat(data["last_activity"]), + conversation_history=conversation_history, + working_directory=data.get("working_directory", ""), + options=options, + ) + + +class SimpleSessionPersistence: + """Simple file-based session persistence.""" + + def __init__(self, storage_path: Path | str | None = None): + if storage_path is None: + storage_path = Path.home() / ".claude_sdk" / "sessions" + self._storage_path = Path(storage_path) + self._storage_path.mkdir(parents=True, exist_ok=True) + + async def save_session(self, session_data: SessionData) -> None: + """Save session data to file.""" + file_path = self._storage_path / f"{session_data.session_id}.json" + with file_path.open("w", encoding="utf-8") as f: + json.dump(session_data.to_dict(), f, indent=2, ensure_ascii=False) + + async def load_session(self, session_id: str) -> SessionData | None: + """Load session data from file.""" + file_path = self._storage_path / f"{session_id}.json" + if not file_path.exists(): + return None + + try: + with file_path.open("r", encoding="utf-8") as f: + data = json.load(f) + return SessionData.from_dict(data) + except (json.JSONDecodeError, KeyError, ValueError): + return None + + async def list_sessions(self) -> list[str]: + """List all session IDs.""" + session_ids = [] + for file_path in self._storage_path.glob("*.json"): + session_ids.append(file_path.stem) # filename without .json extension + return sorted(session_ids) + + async def delete_session(self, session_id: str) -> bool: + """Delete a session file.""" + file_path = self._storage_path / f"{session_id}.json" + if file_path.exists(): + file_path.unlink() + return True + return False \ No newline at end of file diff --git a/src/claude_code_sdk/session_persistent_client.py b/src/claude_code_sdk/session_persistent_client.py new file mode 100644 index 00000000..0050675a --- /dev/null +++ b/src/claude_code_sdk/session_persistent_client.py @@ -0,0 +1,264 @@ +"""Session Persistent Client that wraps ClaudeSDKClient for automatic session persistence.""" + +from collections.abc import AsyncIterable, AsyncIterator +from datetime import datetime +from pathlib import Path +from typing import Any + +from ._internal.session_storage import SessionData, SimpleSessionPersistence +from .client import ClaudeSDKClient +from .types import ClaudeCodeOptions, Message, ResultMessage + + +class SessionPersistentClient: + """ + A wrapper around ClaudeSDKClient that provides automatic session persistence. + + This client saves all conversation messages and metadata to files automatically, + allowing you to resume conversations and inspect session history later. + + Key features: + - Wraps ClaudeSDKClient for all Claude interactions + - Automatically extracts session IDs from received messages + - Saves conversation history to JSON files + - Provides session management (list, delete, inspect) + - Uses server-generated session IDs (no client-side UUID generation) + - Handles server-side session ID changes gracefully while preserving conversation history + + Session ID Handling: + The Claude server may change session IDs during a conversation. This client handles + such changes correctly by: + - Preserving all conversation history when session ID changes + - Moving the session data to the new session ID + - Cleaning up old session files automatically + - Maintaining session continuity and start times + + Example: + ```python + async with SessionPersistentClient() as client: + await client.query("Hello, remember my name is Alice") + + async for message in client.receive_response(): + print(message) + + # Session is automatically saved with current session ID + session_id = client.get_current_session_id() + print(f"Session saved as: {session_id}") + + # Even if server changes session ID, history is preserved + ``` + """ + + def __init__( + self, + options: ClaudeCodeOptions | None = None, + storage_path: Path | str | None = None, + ): + """ + Initialize the session persistent client. + + Args: + options: Claude Code options to pass to underlying client + storage_path: Directory to store session files. + This path is passed to SimpleSessionPersistence which creates the directory + if it doesn't exist and stores session JSON files there. + """ + self._client = ClaudeSDKClient(options) + self._persistence = SimpleSessionPersistence(storage_path) + self._current_session_id: str | None = None + self._session_data: SessionData | None = None + + @property + def client(self) -> ClaudeSDKClient: + """Access to the underlying ClaudeSDKClient.""" + return self._client + + async def connect( + self, prompt: str | AsyncIterable[dict[str, Any]] | None = None + ) -> None: + """Connect to Claude with a prompt or message stream.""" + await self._client.connect(prompt) + + async def query( + self, prompt: str | AsyncIterable[dict[str, Any]], session_id: str = "default" + ) -> None: + """Send a new request in streaming mode.""" + await self._client.query(prompt, session_id) + + async def interrupt(self) -> None: + """Send interrupt signal (only works with streaming mode).""" + await self._client.interrupt() + + async def receive_messages(self) -> AsyncIterator[Message]: + """ + Receive all messages from Claude with automatic session persistence. + + This method wraps the underlying ClaudeSDKClient.receive_messages() and + automatically extracts session data for persistence. + """ + async for message in self._client.receive_messages(): + # Extract session data from message for persistence + await self._handle_message_persistence(message) + yield message + + async def receive_response(self) -> AsyncIterator[Message]: + """ + Receive messages from Claude until ResultMessage with automatic persistence. + + This method wraps the underlying ClaudeSDKClient.receive_response() and + automatically extracts session data for persistence. + """ + async for message in self._client.receive_response(): + # Extract session data from message for persistence + await self._handle_message_persistence(message) + yield message + + async def start_or_resume_session(self, session_id: str | None = None) -> None: + """ + Start a new session or resume an existing one. + + Args: + session_id: If provided, attempts to resume the session with this ID. + If None, starts a new session. + + Note: + This method configures the underlying ClaudeSDKClient to use --resume + when connecting to Claude CLI, which allows resuming server-side conversation state. + """ + if session_id: + # Load existing session data if available + self._session_data = await self._persistence.load_session(session_id) + if self._session_data: + self._current_session_id = session_id + + # Configure the underlying client to resume the session + if self._client.options is None: + from .types import ClaudeCodeOptions + self._client.options = ClaudeCodeOptions() + self._client.options.resume = session_id + else: + # Starting new session - clear any resume option + if self._client.options: + self._client.options.resume = None + self._current_session_id = None + self._session_data = None + + async def disconnect(self) -> None: + """Disconnect from Claude and finalize session persistence.""" + # Update final session metadata before disconnecting + if self._session_data: + self._session_data.last_activity = datetime.now() + await self._persistence.save_session(self._session_data) + + await self._client.disconnect() + + async def __aenter__(self) -> "SessionPersistentClient": + """Enter async context - automatically connects.""" + await self.connect() + return self + + async def __aexit__(self, _exc_type: Any, _exc_val: Any, _exc_tb: Any) -> bool: + """Exit async context - automatically disconnects and saves session.""" + await self.disconnect() + return False + + # Session Management Methods + + def get_current_session_id(self) -> str | None: + """Get the current session ID (server-generated).""" + return self._current_session_id + + async def list_sessions(self) -> list[str]: + """List all saved session IDs.""" + return await self._persistence.list_sessions() + + async def delete_session(self, session_id: str) -> bool: + """ + Delete a saved session. + + Args: + session_id: The session ID to delete + + Returns: + bool: True if session was deleted, False if not found + """ + return await self._persistence.delete_session(session_id) + + async def load_session(self, session_id: str) -> SessionData | None: + """ + Load session data for inspection. + + Args: + session_id: The session ID to load + + Returns: + SessionData | None: Session data if found, None otherwise + """ + return await self._persistence.load_session(session_id) + + # Private Methods + + async def _handle_message_persistence(self, message: Message) -> None: + """ + Handle automatic message persistence based on received messages. + + This method handles server-side session ID changes correctly: + - When a session_id is first received, a new session is created + - When a session_id changes during an active session, the existing + conversation history is preserved and moved to the new session_id + - Old session files are cleaned up when session_id changes + + Args: + message: Message received from ClaudeSDKClient + """ + # Extract session ID from message metadata if available + session_id = getattr(message, 'session_id', None) + + # For ResultMessages, check if they have session_id + if isinstance(message, ResultMessage) and hasattr(message, 'session_id'): + session_id = message.session_id + + # Handle session ID updates from server + if session_id: + if session_id != self._current_session_id: + # Session ID changed - this could be: + # 1. First time getting a session ID (self._current_session_id is None) + # 2. Server updated the session ID for the same logical session + + old_session_id = self._current_session_id + self._current_session_id = session_id + + if self._session_data is not None: + # We have existing session data, so this is a session ID update + # Update the session ID in the existing session data + self._session_data.session_id = session_id + self._session_data.last_activity = datetime.now() + + # Save the session under the new session ID + await self._persistence.save_session(self._session_data) + + # Clean up the old session file if it exists + if old_session_id: + await self._persistence.delete_session(old_session_id) + + else: + # No existing session data - try to load session with new session_id first + # (in case this is a resumed session), otherwise create new one + self._session_data = await self._persistence.load_session(session_id) + if not self._session_data: + self._session_data = SessionData( + session_id=session_id, + start_time=datetime.now(), + last_activity=datetime.now(), + conversation_history=[], + working_directory=str(Path.cwd()), + options=self._client.options, + ) + + # Add message directly to session (Message objects are already the right type) + if self._session_data: + self._session_data.add_message(message) + self._session_data.last_activity = datetime.now() + + # Save session after each message + await self._persistence.save_session(self._session_data) \ No newline at end of file diff --git a/tests/test_session_persistent_client.py b/tests/test_session_persistent_client.py new file mode 100644 index 00000000..cd4c94b8 --- /dev/null +++ b/tests/test_session_persistent_client.py @@ -0,0 +1,640 @@ +""" +Unit tests for SessionPersistentClient. + +This test suite covers: +- SessionData class: serialization, deserialization, message management +- SimpleSessionPersistence class: file-based storage operations +- SessionPersistentClient class: wrapper functionality and automatic persistence + +Key test scenarios: +- Automatic session creation when receiving messages with session_id +- Message persistence throughout conversation flow +- Session management operations (list, load, delete) +- Proper delegation to underlying ClaudeSDKClient +- Error handling for corrupted files and missing sessions +- Context manager behavior and cleanup +""" + +import tempfile +from datetime import datetime +from pathlib import Path +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from claude_code_sdk import ( + AssistantMessage, + ClaudeCodeOptions, + ResultMessage, + SystemMessage, + TextBlock, + UserMessage, +) +from claude_code_sdk._internal.session_storage import SessionData, SimpleSessionPersistence +from claude_code_sdk.session_persistent_client import SessionPersistentClient + + +class TestSessionData: + """Test SessionData class.""" + + def test_session_data_initialization(self): + """Test SessionData initialization.""" + session_id = "test-session-123" + start_time = datetime.now() + + session_data = SessionData( + session_id=session_id, + start_time=start_time, + last_activity=start_time, + working_directory="/test/dir", + ) + + assert session_data.session_id == session_id + assert session_data.start_time == start_time + assert session_data.last_activity == start_time + assert session_data.working_directory == "/test/dir" + assert len(session_data.conversation_history) == 0 + + def test_add_message(self): + """Test adding messages to session data.""" + session_data = SessionData( + session_id="test-session", + start_time=datetime.now(), + last_activity=datetime.now(), + ) + + # Add a user message + user_msg = UserMessage(content="Hello") + session_data.add_message(user_msg) + + assert len(session_data.conversation_history) == 1 + assert session_data.conversation_history[0] == user_msg + + # Add an assistant message + assistant_msg = AssistantMessage(content=[TextBlock(text="Hi there!")]) + session_data.add_message(assistant_msg) + + assert len(session_data.conversation_history) == 2 + assert session_data.conversation_history[1] == assistant_msg + + def test_to_dict_and_from_dict(self): + """Test serialization and deserialization.""" + # Create session data with various message types + start_time = datetime(2025, 1, 1, 12, 0, 0) + last_activity = datetime(2025, 1, 1, 12, 5, 0) + + session_data = SessionData( + session_id="test-session-456", + start_time=start_time, + last_activity=last_activity, + working_directory="/test/path", + options=ClaudeCodeOptions(model="claude-3-5-sonnet-20241022"), + ) + + # Add different message types (manually to control timing) + session_data.conversation_history.append(UserMessage(content="Test user message")) + session_data.conversation_history.append(AssistantMessage(content=[TextBlock(text="Test assistant message")])) + session_data.conversation_history.append(SystemMessage(subtype="init", data={"tool": "test"})) + session_data.conversation_history.append(ResultMessage( + subtype="success", + duration_ms=1000, + duration_api_ms=800, + is_error=False, + num_turns=3, + session_id="test-session-456", + result="Test completed" + )) + + # Convert to dict + data_dict = session_data.to_dict() + + # Verify dict structure + assert data_dict["session_id"] == "test-session-456" + assert data_dict["start_time"] == "2025-01-01T12:00:00" + assert data_dict["last_activity"] == "2025-01-01T12:05:00" + assert data_dict["working_directory"] == "/test/path" + assert len(data_dict["conversation_history"]) == 4 + assert data_dict["options"]["model"] == "claude-3-5-sonnet-20241022" + + # Convert back from dict + restored_session = SessionData.from_dict(data_dict) + + # Verify restoration + assert restored_session.session_id == session_data.session_id + assert restored_session.start_time == session_data.start_time + assert restored_session.last_activity == session_data.last_activity + assert restored_session.working_directory == session_data.working_directory + assert len(restored_session.conversation_history) == 4 + assert restored_session.options.model == "claude-3-5-sonnet-20241022" + + +class TestSimpleSessionPersistence: + """Test SimpleSessionPersistence class.""" + + @pytest.fixture + def temp_storage(self): + """Create temporary storage directory.""" + with tempfile.TemporaryDirectory() as temp_dir: + yield Path(temp_dir) + + @pytest.fixture + def persistence(self, temp_storage): + """Create SimpleSessionPersistence with temp storage.""" + return SimpleSessionPersistence(temp_storage) + + @pytest.fixture + def sample_session_data(self): + """Create sample session data.""" + session_data = SessionData( + session_id="test-session-789", + start_time=datetime.now(), + last_activity=datetime.now(), + working_directory="/test", + ) + session_data.add_message(UserMessage(content="Test message")) + return session_data + + async def test_save_and_load_session(self, persistence, sample_session_data): + """Test saving and loading sessions.""" + # Save session + await persistence.save_session(sample_session_data) + + # Load session + loaded_session = await persistence.load_session(sample_session_data.session_id) + + assert loaded_session is not None + assert loaded_session.session_id == sample_session_data.session_id + assert len(loaded_session.conversation_history) == 1 + + async def test_load_nonexistent_session(self, persistence): + """Test loading nonexistent session returns None.""" + result = await persistence.load_session("nonexistent-session-id") + assert result is None + + async def test_list_sessions(self, persistence, sample_session_data): + """Test listing sessions.""" + # Initially empty + sessions = await persistence.list_sessions() + assert len(sessions) == 0 + + # Save a session + await persistence.save_session(sample_session_data) + + # Should now have one session + sessions = await persistence.list_sessions() + assert len(sessions) == 1 + assert sessions[0] == sample_session_data.session_id + + async def test_delete_session(self, persistence, sample_session_data): + """Test deleting sessions.""" + # Save session first + await persistence.save_session(sample_session_data) + + # Verify it exists + sessions = await persistence.list_sessions() + assert len(sessions) == 1 + + # Delete session + deleted = await persistence.delete_session(sample_session_data.session_id) + assert deleted is True + + # Verify it's gone + sessions = await persistence.list_sessions() + assert len(sessions) == 0 + + # Try to delete again (should return False) + deleted = await persistence.delete_session(sample_session_data.session_id) + assert deleted is False + + async def test_corrupted_session_file(self, persistence, temp_storage): + """Test handling corrupted session files.""" + # Create a corrupted JSON file + corrupted_file = temp_storage / "corrupted-session.json" + corrupted_file.write_text("invalid json content") + + # Should return None for corrupted file + result = await persistence.load_session("corrupted-session") + assert result is None + + +class TestSessionPersistentClient: + """Test SessionPersistentClient class.""" + + @pytest.fixture + def temp_storage(self): + """Create temporary storage directory.""" + with tempfile.TemporaryDirectory() as temp_dir: + yield Path(temp_dir) + + @pytest.fixture + def mock_client(self): + """Create mock ClaudeSDKClient.""" + with patch('claude_code_sdk.session_persistent_client.ClaudeSDKClient') as mock: + client_instance = AsyncMock() + # Set up options as a real ClaudeCodeOptions object to avoid serialization issues + client_instance.options = ClaudeCodeOptions(model="claude-3-5-sonnet-20241022") + mock.return_value = client_instance + yield client_instance + + @pytest.fixture + def persistent_client(self, mock_client, temp_storage): + """Create SessionPersistentClient with mocked dependencies.""" + options = ClaudeCodeOptions(model="claude-3-5-sonnet-20241022") + return SessionPersistentClient(options=options, storage_path=temp_storage) + + def test_initialization(self, persistent_client, mock_client): + """Test SessionPersistentClient initialization.""" + assert persistent_client._client == mock_client + assert persistent_client._current_session_id is None + assert persistent_client._session_data is None + + def test_client_property(self, persistent_client, mock_client): + """Test client property access.""" + assert persistent_client.client == mock_client + + async def test_connect(self, persistent_client, mock_client): + """Test connect method.""" + await persistent_client.connect("test prompt") + mock_client.connect.assert_called_once_with("test prompt") + + async def test_query(self, persistent_client, mock_client): + """Test query method.""" + await persistent_client.query("test query", "test-session") + mock_client.query.assert_called_once_with("test query", "test-session") + + async def test_interrupt(self, persistent_client, mock_client): + """Test interrupt method.""" + await persistent_client.interrupt() + mock_client.interrupt.assert_called_once() + + async def test_disconnect(self, persistent_client, mock_client): + """Test disconnect method.""" + await persistent_client.disconnect() + mock_client.disconnect.assert_called_once() + + async def test_context_manager(self, persistent_client, mock_client): + """Test async context manager functionality.""" + async with persistent_client as client: + assert client == persistent_client + mock_client.connect.assert_called_once_with(None) + + mock_client.disconnect.assert_called_once() + + async def test_message_persistence(self, persistent_client): + """Test automatic message persistence.""" + # Create mock messages + system_msg = SystemMessage(subtype="init", data={}) + assistant_msg = AssistantMessage(content=[TextBlock(text="Hello!")]) + result_msg = ResultMessage( + subtype="success", + duration_ms=1000, + duration_api_ms=800, + is_error=False, + num_turns=1, + session_id="test-session-123", + result="Done" + ) + + # Mock the client's receive_messages to yield these messages + async def mock_receive_messages(): + yield system_msg + yield assistant_msg + yield result_msg + + persistent_client._client.receive_messages = mock_receive_messages + + # Collect messages through persistent client + messages = [] + async for message in persistent_client.receive_messages(): + messages.append(message) + + # Verify messages were yielded + assert len(messages) == 3 + assert messages[0] == system_msg + assert messages[1] == assistant_msg + assert messages[2] == result_msg + + # Verify session was created when ResultMessage (with session_id) was processed + assert persistent_client._current_session_id == "test-session-123" + assert persistent_client._session_data is not None + # The session is created when the first message with session_id is processed + # Only that message (and subsequent ones) are saved - previous messages without session context are not + assert len(persistent_client._session_data.conversation_history) == 1 + assert persistent_client._session_data.conversation_history[0] == result_msg + + async def test_receive_response_stops_at_result(self, persistent_client): + """Test receive_response stops at ResultMessage.""" + # Create mock messages + assistant_msg = AssistantMessage(content=[TextBlock(text="Response")]) + result_msg = ResultMessage( + subtype="success", + duration_ms=1000, + duration_api_ms=800, + is_error=False, + num_turns=1, + session_id="test-session-456", + result="Completed" + ) + extra_msg = SystemMessage(subtype="extra", data={}) # This shouldn't be yielded + + # Mock the client's receive_response to properly stop at ResultMessage (like the real implementation does) + async def mock_receive_response(): + yield assistant_msg + yield result_msg + # The real receive_response stops here and doesn't yield extra_msg + + persistent_client._client.receive_response = mock_receive_response + + # Collect messages + messages = [] + async for message in persistent_client.receive_response(): + messages.append(message) + + # Should stop after ResultMessage (as the underlying client's receive_response does) + assert len(messages) == 2 + assert messages[0] == assistant_msg + assert messages[1] == result_msg + + async def test_session_management(self, persistent_client): + """Test session management methods.""" + # Initially no current session + assert persistent_client.get_current_session_id() is None + + # No sessions saved + sessions = await persistent_client.list_sessions() + assert len(sessions) == 0 + + # Simulate processing a message with session ID + result_msg = ResultMessage( + subtype="success", + duration_ms=1000, + duration_api_ms=800, + is_error=False, + num_turns=1, + session_id="test-session-789", + result="Test" + ) + + await persistent_client._handle_message_persistence(result_msg) + + # Should now have current session + assert persistent_client.get_current_session_id() == "test-session-789" + + # Should have one saved session + sessions = await persistent_client.list_sessions() + assert len(sessions) == 1 + assert sessions[0] == "test-session-789" + + # Should be able to load session data + session_data = await persistent_client.load_session("test-session-789") + assert session_data is not None + assert session_data.session_id == "test-session-789" + assert len(session_data.conversation_history) == 1 + + # Should be able to delete session + deleted = await persistent_client.delete_session("test-session-789") + assert deleted is True + + # Session should be gone + sessions = await persistent_client.list_sessions() + assert len(sessions) == 0 + + async def test_session_creation_and_loading(self, persistent_client): + """Test session creation vs loading existing session.""" + session_id = "test-session-999" + + # First message creates new session + first_msg = ResultMessage( + subtype="success", + duration_ms=1000, + duration_api_ms=800, + is_error=False, + num_turns=1, + session_id=session_id, + result="First" + ) + + await persistent_client._handle_message_persistence(first_msg) + original_session = persistent_client._session_data + + # Reset client state to simulate new client instance + persistent_client._current_session_id = None + persistent_client._session_data = None + + # Second message should load existing session + second_msg = ResultMessage( + subtype="success", + duration_ms=1200, + duration_api_ms=900, + is_error=False, + num_turns=2, + session_id=session_id, + result="Second" + ) + + await persistent_client._handle_message_persistence(second_msg) + + # Should have loaded existing session and added new message + assert persistent_client._session_data is not None + assert persistent_client._session_data.session_id == session_id + assert len(persistent_client._session_data.conversation_history) == 2 + + async def test_message_without_session_id(self, persistent_client): + """Test handling messages without session ID.""" + # Message without session_id attribute + system_msg = SystemMessage(subtype="init", data={}) + + await persistent_client._handle_message_persistence(system_msg) + + # Should not create session data + assert persistent_client._current_session_id is None + assert persistent_client._session_data is None + + async def test_final_session_save_on_disconnect(self, persistent_client): + """Test final session save when disconnecting.""" + # Set up session data + persistent_client._session_data = SessionData( + session_id="final-test", + start_time=datetime.now(), + last_activity=datetime.now(), + ) + + # Mock the client disconnect + persistent_client._client.disconnect = AsyncMock() + + # Disconnect should save session + await persistent_client.disconnect() + + # Verify session was saved and client disconnected + persistent_client._client.disconnect.assert_called_once() + + # Session should exist in storage + sessions = await persistent_client.list_sessions() + assert "final-test" in sessions + + async def test_session_id_update_preserves_history(self, persistent_client): + """Test that session ID updates preserve conversation history.""" + # Start with first message that creates session + first_msg = ResultMessage( + subtype="success", + duration_ms=1000, + duration_api_ms=800, + is_error=False, + num_turns=1, + session_id="session-v1", + result="First response" + ) + + await persistent_client._handle_message_persistence(first_msg) + + # Verify initial session setup + assert persistent_client._current_session_id == "session-v1" + assert persistent_client._session_data is not None + assert len(persistent_client._session_data.conversation_history) == 1 + + # Add more messages to build up history + assistant_msg = AssistantMessage(content=[TextBlock(text="Building history")]) + await persistent_client._handle_message_persistence(assistant_msg) + + user_msg = UserMessage(content="More conversation") + await persistent_client._handle_message_persistence(user_msg) + + # Now we have 3 messages in history + assert len(persistent_client._session_data.conversation_history) == 3 + original_history = persistent_client._session_data.conversation_history.copy() + original_start_time = persistent_client._session_data.start_time + + # Server sends message with updated session ID (same logical session) + updated_msg = ResultMessage( + subtype="success", + duration_ms=1200, + duration_api_ms=900, + is_error=False, + num_turns=4, + session_id="session-v2", # NEW session ID from server + result="Updated session response" + ) + + await persistent_client._handle_message_persistence(updated_msg) + + # Verify session ID was updated but history preserved + assert persistent_client._current_session_id == "session-v2" + assert persistent_client._session_data.session_id == "session-v2" + + # History should be preserved + new message added + assert len(persistent_client._session_data.conversation_history) == 4 + assert persistent_client._session_data.conversation_history[:3] == original_history + assert persistent_client._session_data.conversation_history[3] == updated_msg + + # Start time should be preserved (same logical session) + assert persistent_client._session_data.start_time == original_start_time + + # Verify new session ID is saved and old one is cleaned up + sessions = await persistent_client.list_sessions() + assert "session-v2" in sessions + assert "session-v1" not in sessions # Old session should be cleaned up + + # Verify the saved session has all the history + loaded_session = await persistent_client.load_session("session-v2") + assert loaded_session is not None + assert len(loaded_session.conversation_history) == 4 + assert loaded_session.start_time == original_start_time + + async def test_multiple_session_id_updates(self, persistent_client): + """Test multiple session ID updates in sequence.""" + # Start with initial session + msg1 = ResultMessage( + subtype="success", duration_ms=1000, duration_api_ms=800, + is_error=False, num_turns=1, session_id="session-a", result="Response A" + ) + await persistent_client._handle_message_persistence(msg1) + + # First update + msg2 = ResultMessage( + subtype="success", duration_ms=1100, duration_api_ms=850, + is_error=False, num_turns=2, session_id="session-b", result="Response B" + ) + await persistent_client._handle_message_persistence(msg2) + + # Second update + msg3 = ResultMessage( + subtype="success", duration_ms=1200, duration_api_ms=900, + is_error=False, num_turns=3, session_id="session-c", result="Response C" + ) + await persistent_client._handle_message_persistence(msg3) + + # Verify final state + assert persistent_client._current_session_id == "session-c" + assert persistent_client._session_data.session_id == "session-c" + assert len(persistent_client._session_data.conversation_history) == 3 + + # Only the final session should exist + sessions = await persistent_client.list_sessions() + assert "session-c" in sessions + assert "session-a" not in sessions + assert "session-b" not in sessions + + async def test_start_or_resume_session_new(self, persistent_client): + """Test starting a new session.""" + # Start new session (no session_id provided) + await persistent_client.start_or_resume_session() + + # Should clear any existing session state + assert persistent_client._current_session_id is None + assert persistent_client._session_data is None + + # Should clear resume option in client options + if persistent_client._client.options: + assert persistent_client._client.options.resume is None + + async def test_start_or_resume_session_existing(self, persistent_client): + """Test resuming an existing session.""" + # First create a session with some data + session_id = "resume-test-session" + session_data = SessionData( + session_id=session_id, + start_time=datetime.now(), + last_activity=datetime.now(), + working_directory="/test", + ) + session_data.add_message(UserMessage(content="Previous message")) + await persistent_client._persistence.save_session(session_data) + + # Resume the session + await persistent_client.start_or_resume_session(session_id) + + # Should load the existing session data + assert persistent_client._current_session_id == session_id + assert persistent_client._session_data is not None + assert persistent_client._session_data.session_id == session_id + assert len(persistent_client._session_data.conversation_history) == 1 + + # Should set resume option in client options + assert persistent_client._client.options.resume == session_id + + async def test_start_or_resume_session_nonexistent(self, persistent_client): + """Test resuming a nonexistent session.""" + nonexistent_id = "nonexistent-session-id" + + # Try to resume nonexistent session + await persistent_client.start_or_resume_session(nonexistent_id) + + # Should still set the session_id for resume, even if no local data exists + # (the server might have the session even if we don't have local data) + assert persistent_client._current_session_id is None # No local data loaded + assert persistent_client._session_data is None + + # But should still set resume option for server-side resume + assert persistent_client._client.options.resume == nonexistent_id + + async def test_start_or_resume_creates_options_if_needed(self, persistent_client): + """Test that start_or_resume_session creates options if client has none.""" + # Ensure client has no options + persistent_client._client.options = None + + # Resume a session + await persistent_client.start_or_resume_session("test-session") + + # Should create options with resume set + assert persistent_client._client.options is not None + assert persistent_client._client.options.resume == "test-session" \ No newline at end of file