"""Claude Agent SDK wrapper for per-thread Discord conversations.""" from __future__ import annotations import logging from pathlib import Path from typing import AsyncIterator, Awaitable, Callable, Optional from claude_agent_sdk import ( AssistantMessage, ClaudeAgentOptions, ClaudeSDKClient, ResultMessage, TextBlock, ToolUseBlock, ) from bot.config import settings logger = logging.getLogger(__name__) def _load_system_prompt() -> str: prompt_path = settings.claudetools_root / settings.discord_system_prompt if prompt_path.exists(): return prompt_path.read_text(encoding="utf-8") logger.warning( "[WARNING] Discord system prompt not found at %s — falling back to CLAUDE.md", prompt_path, ) return (settings.claudetools_root / ".claude" / "CLAUDE.md").read_text(encoding="utf-8") class ThreadAgent: """One persistent Claude Code session bound to a Discord thread.""" def __init__( self, system_prompt: str, cwd: Path, model: str, env: Optional[dict[str, str]] = None, ) -> None: # `env` is per-session (per Discord thread), so concurrent threads carry # their own requester attribution without colliding. It reaches the # Bash tool (and thus whoami-block.sh / sync.sh) via the SDK subprocess. self._options = ClaudeAgentOptions( system_prompt=system_prompt, cwd=str(cwd), model=model, env=env or {}, ) self._client: Optional[ClaudeSDKClient] = None async def start(self) -> None: self._client = ClaudeSDKClient(options=self._options) await self._client.connect() async def stop(self) -> None: if self._client is not None: await self._client.disconnect() self._client = None async def send( self, user_message: str, on_text: Callable[[str], Awaitable[None]], on_tool_use: Optional[Callable[[str], Awaitable[None]]] = None, ) -> str: if self._client is None: raise RuntimeError("ThreadAgent.send() called before start()") try: return await self._query_once(user_message, on_text, on_tool_use) except Exception as e: # noqa: BLE001 — recover a dead SDK session, then retry once if "session is closed" in str(e).lower() or "session closed" in str(e).lower(): logger.warning( "[WARNING] SDK session was closed; reconnecting and retrying once" ) await self._reconnect() return await self._query_once(user_message, on_text, on_tool_use) raise async def _reconnect(self) -> None: """Tear down and re-establish the SDK session (it can close on idle).""" try: if self._client is not None: await self._client.disconnect() except Exception as e: # noqa: BLE001 logger.warning("[WARNING] disconnect during reconnect failed: %s", e) self._client = ClaudeSDKClient(options=self._options) await self._client.connect() async def _query_once( self, user_message: str, on_text: Callable[[str], Awaitable[None]], on_tool_use: Optional[Callable[[str], Awaitable[None]]], ) -> str: assert self._client is not None await self._client.query(user_message) full_text = "" result_text: Optional[str] = None result_subtype: Optional[str] = None async for message in self._client.receive_response(): if isinstance(message, AssistantMessage): for block in message.content: if isinstance(block, TextBlock): full_text += block.text await on_text(block.text) elif isinstance(block, ToolUseBlock) and on_tool_use is not None: await on_tool_use(block.name) elif isinstance(message, ResultMessage): # The SDK delivers the final answer here; capture it as the # fallback when no TextBlock streamed (the cause of "(no response)"). result_text = message.result result_subtype = message.subtype break if full_text.strip(): return full_text if result_text and result_text.strip(): return result_text # Genuinely nothing — never leave the user with a blank "no response": # explain why so it's actionable. logger.warning( "[WARNING] empty turn: no text blocks and no result (subtype=%s)", result_subtype, ) return ( f"[INFO] I finished without a text reply (subtype={result_subtype}). " "I may have only run tools or hit a turn limit — ask me to summarize " "what I found, or rephrase the question." ) class ClaudeAgentManager: """Owns one ThreadAgent per Discord thread id.""" def __init__(self) -> None: self._system_prompt = _load_system_prompt() self._cwd = settings.claudetools_root self._model = settings.claude_model self._agents: dict[int, ThreadAgent] = {} async def get_or_create( self, thread_id: int, env: Optional[dict[str, str]] = None ) -> ThreadAgent: # `env` is applied only when the thread's session is first created, so # attribution pins to the thread opener (the SDK bakes env at session # spawn and cannot change it per turn without losing context). Follow-up # turns reuse the opener's attribution by design. agent = self._agents.get(thread_id) if agent is None: logger.info("[INFO] Starting new agent session for thread %d", thread_id) agent = ThreadAgent(self._system_prompt, self._cwd, self._model, env=env) await agent.start() self._agents[thread_id] = agent return agent async def shutdown(self) -> None: for thread_id, agent in list(self._agents.items()): try: await agent.stop() except Exception as e: logger.warning("[WARNING] Failed to stop agent %d: %s", thread_id, e) self._agents.clear()