"""Claude Agent SDK wrapper for per-thread Discord conversations.""" from __future__ import annotations import logging from pathlib import Path from typing import AsyncIterator, Awaitable, Callable, Optional from claude_agent_sdk import ( AssistantMessage, ClaudeAgentOptions, ClaudeSDKClient, ResultMessage, TextBlock, ToolUseBlock, ) from bot.config import settings logger = logging.getLogger(__name__) def _load_system_prompt() -> str: claude_md = settings.claudetools_root / ".claude" / "CLAUDE.md" return claude_md.read_text(encoding="utf-8") class ThreadAgent: """One persistent Claude Code session bound to a Discord thread.""" def __init__(self, system_prompt: str, cwd: Path, model: str) -> None: self._options = ClaudeAgentOptions( system_prompt=system_prompt, cwd=str(cwd), model=model, ) self._client: Optional[ClaudeSDKClient] = None async def start(self) -> None: self._client = ClaudeSDKClient(options=self._options) await self._client.connect() async def stop(self) -> None: if self._client is not None: await self._client.disconnect() self._client = None async def send( self, user_message: str, on_text: Callable[[str], Awaitable[None]], on_tool_use: Optional[Callable[[str], Awaitable[None]]] = None, ) -> str: if self._client is None: raise RuntimeError("ThreadAgent.send() called before start()") await self._client.query(user_message) full_text = "" async for message in self._client.receive_response(): if isinstance(message, AssistantMessage): for block in message.content: if isinstance(block, TextBlock): full_text += block.text await on_text(block.text) elif isinstance(block, ToolUseBlock) and on_tool_use is not None: await on_tool_use(block.name) elif isinstance(message, ResultMessage): break return full_text class ClaudeAgentManager: """Owns one ThreadAgent per Discord thread id.""" def __init__(self) -> None: self._system_prompt = _load_system_prompt() self._cwd = settings.claudetools_root self._model = settings.claude_model self._agents: dict[int, ThreadAgent] = {} async def get_or_create(self, thread_id: int) -> ThreadAgent: agent = self._agents.get(thread_id) if agent is None: logger.info("[INFO] Starting new agent session for thread %d", thread_id) agent = ThreadAgent(self._system_prompt, self._cwd, self._model) await agent.start() self._agents[thread_id] = agent return agent async def shutdown(self) -> None: for thread_id, agent in list(self._agents.items()): try: await agent.stop() except Exception as e: logger.warning("[WARNING] Failed to stop agent %d: %s", thread_id, e) self._agents.clear()