Files
claudetools/api/routers/conversation_contexts.py
Mike Swanson a534a72a0f Fix recall endpoint: Add search_term, input validation, and proper contexts array return
- Add search_term parameter with regex validation (alphanumeric + punctuation)
- Add tag validation to prevent SQL injection
- Change return format from {context: string} to {total, contexts: array}
- Use ConversationContextResponse schema for proper serialization
- Improves security and provides structured data for clients

Related: Context Recall System fixes (COMPLETE_SYSTEM_SUMMARY.md)
2026-01-18 14:08:15 -07:00

313 lines
9.6 KiB
Python

"""
ConversationContext API router for ClaudeTools.
Defines all REST API endpoints for managing conversation contexts,
including context recall functionality for Claude's memory system.
"""
import re
from typing import List, Optional
from uuid import UUID

from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlalchemy.orm import Session

from api.database import get_db
from api.middleware.auth import get_current_user
from api.schemas.conversation_context import (
    ConversationContextCreate,
    ConversationContextResponse,
    ConversationContextUpdate,
)
from api.services import conversation_context_service
# Router for all conversation-context endpoints. No prefix or tags are set here;
# presumably they are supplied where this router is included (TODO confirm).
router = APIRouter()
@router.get(
    "",
    response_model=dict,
    summary="List all conversation contexts",
    description="Retrieve a paginated list of all conversation contexts with optional filtering",
    status_code=status.HTTP_200_OK,
)
def list_conversation_contexts(
    skip: int = Query(
        default=0,
        ge=0,
        description="Number of records to skip for pagination",
    ),
    limit: int = Query(
        default=100,
        ge=1,
        le=1000,
        description="Maximum number of records to return (max 1000)",
    ),
    db: Session = Depends(get_db),
    current_user: dict = Depends(get_current_user),
):
    """
    List conversation contexts with pagination.

    Ordering is decided by ``conversation_context_service.get_conversation_contexts``
    (presumably relevance score, then recency -- confirm in the service).

    Args:
        skip: Number of records to skip (>= 0).
        limit: Maximum number of records to return (1-1000).
        db: Database session supplied by ``get_db``.
        current_user: Result of ``get_current_user``; not read here, the
            dependency presumably enforces authentication.

    Returns:
        A dict with ``total``, ``skip``, ``limit`` and ``contexts`` (a list of
        ``ConversationContextResponse`` objects).

    Raises:
        HTTPException: Propagated unchanged if raised by the service layer;
            any other error is reported as a 500.
    """
    try:
        contexts, total = conversation_context_service.get_conversation_contexts(db, skip, limit)
        return {
            "total": total,
            "skip": skip,
            "limit": limit,
            "contexts": [ConversationContextResponse.model_validate(ctx) for ctx in contexts],
        }
    except HTTPException:
        # Deliberate HTTP errors from the service keep their original status
        # instead of being re-wrapped as a generic 500.
        raise
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to retrieve conversation contexts: {str(e)}",
        ) from e
# Allowed characters for a single tag. Checked with ``fullmatch`` because
# ``re.match`` with a trailing ``$`` also accepts a final newline ("tag\n").
_TAG_PATTERN = re.compile(r"[a-zA-Z0-9\-_]+")


def _validate_tags(tags: Optional[List[str]]) -> None:
    """Raise a 400 ``HTTPException`` for the first tag containing disallowed characters."""
    for tag in tags or ():
        if _TAG_PATTERN.fullmatch(tag) is None:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"Invalid tag format: '{tag}'. Tags must be alphanumeric with hyphens or underscores only."
            )


@router.get(
    "/recall",
    response_model=dict,
    summary="Retrieve relevant contexts for injection",
    description="Get token-efficient context formatted for Claude prompt injection",
    status_code=status.HTTP_200_OK,
)
def recall_context(
    search_term: Optional[str] = Query(
        None,
        max_length=200,
        pattern=r'^[a-zA-Z0-9\s\-_.,!?()]+$',
        description="Full-text search term (alphanumeric, spaces, and basic punctuation only)",
    ),
    project_id: Optional[UUID] = Query(None, description="Filter by project ID"),
    tags: Optional[List[str]] = Query(
        None,
        description="Filter by tags (OR logic)",
        # ``max_length`` (not the Pydantic v1 ``max_items``) is what actually
        # limits the list length under Pydantic v2; ``max_items`` was silently
        # ignored and only ended up in the JSON schema.
        max_length=20,
    ),
    limit: int = Query(
        default=10,
        ge=1,
        le=50,
        description="Maximum number of contexts to retrieve (max 50)",
    ),
    min_relevance_score: float = Query(
        default=5.0,
        ge=0.0,
        le=10.0,
        description="Minimum relevance score threshold (0.0-10.0)",
    ),
    db: Session = Depends(get_db),
    current_user: dict = Depends(get_current_user),
):
    """
    Retrieve contexts matching the recall criteria as a structured JSON payload.

    Query Parameters:
        - search_term: Free-text search term (matching is implemented by
          ``conversation_context_service.get_recall_context``).
        - project_id: Restrict results to one project.
        - tags: Up to 20 tags; each must match ``[a-zA-Z0-9_-]+`` exactly.
        - limit: Maximum number of contexts to retrieve (1-50).
        - min_relevance_score: Minimum relevance score threshold (0.0-10.0).

    Returns:
        A dict echoing the filters (``limit``, ``search_term``, ``project_id``
        as a string or None, ``tags``, ``min_relevance_score``) plus ``total``
        and ``contexts`` (a list of ``ConversationContextResponse`` objects).

    Raises:
        HTTPException: 400 for a malformed tag; service-raised HTTPExceptions
            propagate unchanged; any other failure becomes a 500.
    """
    # Reject malformed tags before they reach the query layer.
    _validate_tags(tags)
    try:
        contexts, total = conversation_context_service.get_recall_context(
            db=db,
            search_term=search_term,
            project_id=project_id,
            tags=tags,
            limit=limit,
            min_relevance_score=min_relevance_score,
        )
        return {
            "total": total,
            "limit": limit,
            "search_term": search_term,
            "project_id": str(project_id) if project_id else None,
            "tags": tags,
            "min_relevance_score": min_relevance_score,
            "contexts": [ConversationContextResponse.model_validate(ctx) for ctx in contexts],
        }
    except HTTPException:
        # Preserve deliberate HTTP errors from the service instead of masking them as 500.
        raise
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to retrieve recall context: {str(e)}",
        ) from e
@router.get(
    "/by-project/{project_id}",
    response_model=dict,
    summary="Get conversation contexts by project",
    description="Retrieve all conversation contexts for a specific project",
    status_code=status.HTTP_200_OK,
)
def get_conversation_contexts_by_project(
    project_id: UUID,
    skip: int = Query(default=0, ge=0),
    limit: int = Query(default=100, ge=1, le=1000),
    db: Session = Depends(get_db),
    current_user: dict = Depends(get_current_user),
):
    """
    List the conversation contexts belonging to one project, paginated.

    Args:
        project_id: Project whose contexts are listed.
        skip: Number of records to skip (>= 0).
        limit: Maximum number of records to return (1-1000).
        db: Database session supplied by ``get_db``.
        current_user: Result of ``get_current_user``; not read here, the
            dependency presumably enforces authentication.

    Returns:
        A dict with ``total``, ``skip``, ``limit``, ``project_id`` (as a string)
        and ``contexts`` (a list of ``ConversationContextResponse`` objects).

    Raises:
        HTTPException: Propagated unchanged if raised by the service (e.g. a
            possible 404 for an unknown project); any other error becomes a 500.
    """
    try:
        contexts, total = conversation_context_service.get_conversation_contexts_by_project(
            db, project_id, skip, limit
        )
        return {
            "total": total,
            "skip": skip,
            "limit": limit,
            "project_id": str(project_id),
            "contexts": [ConversationContextResponse.model_validate(ctx) for ctx in contexts],
        }
    except HTTPException:
        # Keep the service's chosen status code rather than converting it to a 500.
        raise
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to retrieve conversation contexts: {str(e)}",
        ) from e
@router.get(
    "/by-session/{session_id}",
    response_model=dict,
    summary="Get conversation contexts by session",
    description="Retrieve all conversation contexts for a specific session",
    status_code=status.HTTP_200_OK,
)
def get_conversation_contexts_by_session(
    session_id: UUID,
    skip: int = Query(default=0, ge=0),
    limit: int = Query(default=100, ge=1, le=1000),
    db: Session = Depends(get_db),
    current_user: dict = Depends(get_current_user),
):
    """
    List the conversation contexts recorded for one session, paginated.

    Args:
        session_id: Session whose contexts are listed.
        skip: Number of records to skip (>= 0).
        limit: Maximum number of records to return (1-1000).
        db: Database session supplied by ``get_db``.
        current_user: Result of ``get_current_user``; not read here, the
            dependency presumably enforces authentication.

    Returns:
        A dict with ``total``, ``skip``, ``limit``, ``session_id`` (as a string)
        and ``contexts`` (a list of ``ConversationContextResponse`` objects).

    Raises:
        HTTPException: Propagated unchanged if raised by the service; any
            other error becomes a 500.
    """
    try:
        contexts, total = conversation_context_service.get_conversation_contexts_by_session(
            db, session_id, skip, limit
        )
        return {
            "total": total,
            "skip": skip,
            "limit": limit,
            "session_id": str(session_id),
            "contexts": [ConversationContextResponse.model_validate(ctx) for ctx in contexts],
        }
    except HTTPException:
        # Keep the service's chosen status code rather than converting it to a 500.
        raise
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to retrieve conversation contexts: {str(e)}",
        ) from e
@router.get(
    "/{context_id}",
    summary="Get conversation context by ID",
    description="Retrieve a single conversation context by its unique identifier",
    response_model=ConversationContextResponse,
    status_code=status.HTTP_200_OK,
)
def get_conversation_context(
    context_id: UUID,
    db: Session = Depends(get_db),
    current_user: dict = Depends(get_current_user),
):
    """
    Fetch one conversation context by its UUID.

    The lookup is delegated to ``conversation_context_service.get_conversation_context_by_id``.
    No ``None`` check happens here, so a missing record is presumably reported
    by the service itself (TODO confirm).

    Args:
        context_id: Identifier of the context to fetch.
        db: Database session supplied by ``get_db``.
        current_user: Result of ``get_current_user``; not read in this body.

    Returns:
        The record serialized as a ``ConversationContextResponse``.
    """
    record = conversation_context_service.get_conversation_context_by_id(db, context_id)
    serialized = ConversationContextResponse.model_validate(record)
    return serialized
@router.post(
    "",
    status_code=status.HTTP_201_CREATED,
    response_model=ConversationContextResponse,
    summary="Create new conversation context",
    description="Create a new conversation context with the provided details",
)
def create_conversation_context(
    context_data: ConversationContextCreate,
    db: Session = Depends(get_db),
    current_user: dict = Depends(get_current_user),
):
    """
    Persist a new conversation context and return its serialized form.

    Args:
        context_data: Validated request body describing the new context.
        db: Database session supplied by ``get_db``.
        current_user: Result of ``get_current_user``; not read here. The
            dependency presumably gates access (auth requirements: TODO confirm).

    Returns:
        The created record as a ``ConversationContextResponse`` (HTTP 201).
    """
    created = conversation_context_service.create_conversation_context(db, context_data)
    return ConversationContextResponse.model_validate(created)
@router.put(
    "/{context_id}",
    status_code=status.HTTP_200_OK,
    response_model=ConversationContextResponse,
    summary="Update conversation context",
    description="Update an existing conversation context's details",
)
def update_conversation_context(
    context_id: UUID,
    context_data: ConversationContextUpdate,
    db: Session = Depends(get_db),
    current_user: dict = Depends(get_current_user),
):
    """
    Apply a partial update to an existing conversation context.

    Which fields are applied is decided by
    ``conversation_context_service.update_conversation_context``; the schema is
    presumably all-optional so only supplied fields change (TODO confirm).

    Args:
        context_id: Identifier of the context to modify.
        context_data: Validated update payload.
        db: Database session supplied by ``get_db``.
        current_user: Result of ``get_current_user``; not read in this body.

    Returns:
        The updated record as a ``ConversationContextResponse``.
    """
    updated = conversation_context_service.update_conversation_context(db, context_id, context_data)
    return ConversationContextResponse.model_validate(updated)
@router.delete(
    "/{context_id}",
    status_code=status.HTTP_200_OK,
    response_model=dict,
    summary="Delete conversation context",
    description="Delete a conversation context by its ID",
)
def delete_conversation_context(
    context_id: UUID,
    db: Session = Depends(get_db),
    current_user: dict = Depends(get_current_user),
):
    """
    Remove a conversation context.

    The service's return value is passed straight back to the client (declared
    as a ``dict`` response). Deletion is described as permanent; whether a soft
    delete is used is up to the service (TODO confirm).

    Args:
        context_id: Identifier of the context to delete.
        db: Database session supplied by ``get_db``.
        current_user: Result of ``get_current_user``; not read in this body.
    """
    outcome = conversation_context_service.delete_conversation_context(db, context_id)
    return outcome