98 lines
3.1 KiB
Python
98 lines
3.1 KiB
Python
"""Claude Agent SDK wrapper for per-thread Discord conversations."""
|
|
from __future__ import annotations

import asyncio
import logging
from pathlib import Path
from typing import AsyncIterator, Awaitable, Callable, Optional

from claude_agent_sdk import (
    AssistantMessage,
    ClaudeAgentOptions,
    ClaudeSDKClient,
    ResultMessage,
    TextBlock,
    ToolUseBlock,
)

from bot.config import settings
|
|
|
|
logger = logging.getLogger(__name__)


def _load_system_prompt() -> str:
    """Read the shared system prompt from ``<claudetools_root>/.claude/CLAUDE.md``.

    The file is decoded as UTF-8. Any error from reading it (e.g. a missing
    file) propagates to the caller unchanged.
    """
    prompt_file = settings.claudetools_root.joinpath(".claude", "CLAUDE.md")
    return prompt_file.read_text(encoding="utf-8")
|
|
|
|
|
|
class ThreadAgent:
    """One persistent Claude Code session bound to a Discord thread.

    Lifecycle: ``start()`` opens the SDK client, ``send()`` runs one
    conversational turn (any number of times), and ``stop()`` closes it.
    """

    def __init__(self, system_prompt: str, cwd: Path, model: str) -> None:
        self._options = ClaudeAgentOptions(
            system_prompt=system_prompt,
            cwd=str(cwd),
            model=model,
        )
        self._client: Optional[ClaudeSDKClient] = None
        # Serializes turns. query() + receive_response() read the client's
        # shared message stream, so two overlapping send() calls would
        # otherwise interleave and consume each other's responses.
        self._turn_lock = asyncio.Lock()

    async def start(self) -> None:
        """Connect a fresh SDK client, replacing any client already open.

        A previously started client is disconnected first, rather than being
        leaked. The new client is only stored once ``connect()`` succeeds. On
        failure it is disconnected best-effort and the error is re-raised,
        leaving the agent unstarted.
        """
        await self.stop()

        client = ClaudeSDKClient(options=self._options)
        try:
            await client.connect()
        except BaseException:
            # connect() may have partially set up the session before failing;
            # try to tear it down, but never mask the original error.
            try:
                await client.disconnect()
            except Exception:
                logger.debug("Cleanup after failed connect() also failed", exc_info=True)
            raise
        self._client = client

    async def stop(self) -> None:
        """Disconnect the client, if any.

        The reference is cleared before disconnecting, so a failing
        ``disconnect()`` still leaves the agent in a clean "not started" state.
        The error itself is propagated.
        """
        client, self._client = self._client, None
        if client is not None:
            await client.disconnect()

    async def send(
        self,
        user_message: str,
        on_text: Callable[[str], Awaitable[None]],
        on_tool_use: Optional[Callable[[str], Awaitable[None]]] = None,
    ) -> str:
        """Run one turn of the conversation and return the assistant's text.

        Args:
            user_message: Prompt forwarded to the session via ``query()``.
            on_text: Awaited with the text of each ``TextBlock`` as it arrives.
            on_tool_use: If given, awaited with the tool name of each
                ``ToolUseBlock``.

        Returns:
            The texts of all ``TextBlock``s of this turn, concatenated in
            arrival order.

        Raises:
            RuntimeError: if the agent is not started (``start()`` was never
                called, or ``stop()`` was called since).
        """
        async with self._turn_lock:
            client = self._client
            if client is None:
                raise RuntimeError("ThreadAgent.send() called before start()")

            await client.query(user_message)

            chunks: list[str] = []
            async for message in client.receive_response():
                if isinstance(message, AssistantMessage):
                    for block in message.content:
                        if isinstance(block, TextBlock):
                            chunks.append(block.text)
                            await on_text(block.text)
                        elif isinstance(block, ToolUseBlock) and on_tool_use is not None:
                            await on_tool_use(block.name)
                elif isinstance(message, ResultMessage):
                    # The result message ends the turn. Surface error results
                    # in the log instead of silently returning partial text.
                    if message.is_error:
                        logger.warning(
                            "[WARNING] Agent turn ended with an error result (%s)",
                            message.subtype,
                        )
                    break

            return "".join(chunks)
|
|
|
|
|
|
class ClaudeAgentManager:
    """Owns one ThreadAgent per Discord thread id.

    Sessions are created lazily by ``get_or_create()`` and torn down together
    by ``shutdown()``.
    """

    def __init__(self) -> None:
        # Loaded once; every session created by this manager shares it.
        self._system_prompt = _load_system_prompt()
        self._cwd = settings.claudetools_root
        self._model = settings.claude_model
        self._agents: dict[int, ThreadAgent] = {}
        # One lock per thread id, guarding session creation. Without it, two
        # concurrent calls for a new thread both see "no agent" across the
        # `await agent.start()` gap, both start a session, and the second
        # overwrites (and orphans) the first. Per-thread locks keep start-up
        # of different threads from blocking each other.
        self._start_locks: dict[int, asyncio.Lock] = {}

    async def get_or_create(self, thread_id: int) -> ThreadAgent:
        """Return the started agent for ``thread_id``, starting one if needed.

        Any error from ``ThreadAgent.start()`` propagates. In that case
        nothing is registered, so a later call will retry.
        """
        agent = self._agents.get(thread_id)
        if agent is not None:
            return agent

        lock = self._start_locks.setdefault(thread_id, asyncio.Lock())
        async with lock:
            # Re-check: another task may have finished creating it while we waited.
            agent = self._agents.get(thread_id)
            if agent is None:
                logger.info("[INFO] Starting new agent session for thread %d", thread_id)
                agent = ThreadAgent(self._system_prompt, self._cwd, self._model)
                await agent.start()
                self._agents[thread_id] = agent
        return agent

    async def shutdown(self) -> None:
        """Stop every tracked agent, logging (not raising) individual failures."""
        # Detach the current set first, so agents created while we are
        # awaiting stop() calls stay tracked instead of being cleared unstopped.
        agents = list(self._agents.items())
        self._agents.clear()
        for thread_id, agent in agents:
            try:
                await agent.stop()
            except Exception as e:
                logger.warning("[WARNING] Failed to stop agent %d: %s", thread_id, e)
|