feat: Major directory reorganization and cleanup
Reorganized project structure for better maintainability and reduced disk usage by 95.9% (11 GB -> 451 MB). Directory Reorganization (85% reduction in root files): - Created docs/ with subdirectories (deployment, testing, database, etc.) - Created infrastructure/vpn-configs/ for VPN scripts - Moved 90+ files from root to organized locations - Archived obsolete documentation (context system, offline mode, zombie debugging) - Moved all test files to tests/ directory - Root directory: 119 files -> 18 files Disk Cleanup (10.55 GB recovered): - Deleted Rust build artifacts: 9.6 GB (target/ directories) - Deleted Python virtual environments: 161 MB (venv/ directories) - Deleted Python cache: 50 KB (__pycache__/) New Structure: - docs/ - All documentation organized by category - docs/archives/ - Obsolete but preserved documentation - infrastructure/ - VPN configs and SSH setup - tests/ - All test files consolidated - logs/ - Ready for future logs Benefits: - Cleaner root directory (18 vs 119 files) - Logical organization of documentation - 95.9% disk space reduction - Faster navigation and discovery - Better portability (build artifacts excluded) Build artifacts can be regenerated: - Rust: cargo build --release (5-15 min per project) - Python: pip install -r requirements.txt (2-3 min) Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
This commit is contained in:
821
tests/test_api_endpoints.py
Normal file
821
tests/test_api_endpoints.py
Normal file
@@ -0,0 +1,821 @@
|
||||
"""
|
||||
Comprehensive API Endpoint Tests for ClaudeTools FastAPI Application
|
||||
|
||||
This test suite validates all 5 core API endpoints:
|
||||
- Machines
|
||||
- Clients
|
||||
- Projects
|
||||
- Sessions
|
||||
- Tags
|
||||
|
||||
Tests include:
|
||||
- API startup and health checks
|
||||
- CRUD operations for all entities
|
||||
- Authentication (with and without JWT tokens)
|
||||
- Pagination parameters
|
||||
- Error handling (404, 409, 422 responses)
|
||||
"""
|
||||
|
||||
import sys
|
||||
from datetime import timedelta
|
||||
from uuid import uuid4
|
||||
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
# Import the FastAPI app and auth utilities
|
||||
from api.main import app
|
||||
from api.middleware.auth import create_access_token
|
||||
|
||||
# Create test client
|
||||
client = TestClient(app)
|
||||
|
||||
# Test counters
|
||||
tests_passed = 0
|
||||
tests_failed = 0
|
||||
test_results = []
|
||||
|
||||
|
||||
def log_test(test_name: str, passed: bool, error_msg: str = ""):
|
||||
"""Log test result and update counters."""
|
||||
global tests_passed, tests_failed
|
||||
if passed:
|
||||
tests_passed += 1
|
||||
status = "PASS"
|
||||
symbol = "[+]"
|
||||
else:
|
||||
tests_failed += 1
|
||||
status = "FAIL"
|
||||
symbol = "[-]"
|
||||
|
||||
result = f"{symbol} {status}: {test_name}"
|
||||
if error_msg:
|
||||
result += f"\n Error: {error_msg}"
|
||||
|
||||
test_results.append((test_name, passed, error_msg))
|
||||
print(result)
|
||||
|
||||
|
||||
def create_test_token():
|
||||
"""Create a test JWT token for authentication."""
|
||||
token_data = {
|
||||
"sub": "test_user@claudetools.com",
|
||||
"scopes": ["msp:read", "msp:write", "msp:admin"]
|
||||
}
|
||||
return create_access_token(token_data, expires_delta=timedelta(hours=1))
|
||||
|
||||
|
||||
def get_auth_headers():
|
||||
"""Get authorization headers with test token."""
|
||||
token = create_test_token()
|
||||
return {"Authorization": f"Bearer {token}"}
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# SECTION 1: API Health and Startup Tests
|
||||
# ============================================================================
|
||||
|
||||
print("\n" + "="*70)
|
||||
print("SECTION 1: API Health and Startup Tests")
|
||||
print("="*70 + "\n")
|
||||
|
||||
def test_root_endpoint():
|
||||
"""Test root endpoint returns API status."""
|
||||
try:
|
||||
response = client.get("/")
|
||||
assert response.status_code == 200, f"Expected 200, got {response.status_code}"
|
||||
data = response.json()
|
||||
assert data["status"] == "online", f"Expected status 'online', got {data.get('status')}"
|
||||
assert "service" in data, "Response missing 'service' field"
|
||||
assert "version" in data, "Response missing 'version' field"
|
||||
log_test("Root endpoint (/)", True)
|
||||
except Exception as e:
|
||||
log_test("Root endpoint (/)", False, str(e))
|
||||
|
||||
def test_health_endpoint():
|
||||
"""Test health check endpoint."""
|
||||
try:
|
||||
response = client.get("/health")
|
||||
assert response.status_code == 200, f"Expected 200, got {response.status_code}"
|
||||
data = response.json()
|
||||
assert data["status"] == "healthy", f"Expected status 'healthy', got {data.get('status')}"
|
||||
log_test("Health check endpoint (/health)", True)
|
||||
except Exception as e:
|
||||
log_test("Health check endpoint (/health)", False, str(e))
|
||||
|
||||
def test_jwt_token_creation():
|
||||
"""Test JWT token creation."""
|
||||
try:
|
||||
token = create_test_token()
|
||||
assert token is not None, "Token creation returned None"
|
||||
assert len(token) > 20, "Token seems too short"
|
||||
log_test("JWT token creation", True)
|
||||
except Exception as e:
|
||||
log_test("JWT token creation", False, str(e))
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# SECTION 2: Authentication Tests
|
||||
# ============================================================================
|
||||
|
||||
print("\n" + "="*70)
|
||||
print("SECTION 2: Authentication Tests")
|
||||
print("="*70 + "\n")
|
||||
|
||||
def test_unauthenticated_access():
|
||||
"""Test that protected endpoints reject requests without auth."""
|
||||
try:
|
||||
response = client.get("/api/machines")
|
||||
# Can be 401 (Unauthorized) or 403 (Forbidden) depending on implementation
|
||||
assert response.status_code in [401, 403], f"Expected 401 or 403, got {response.status_code}"
|
||||
log_test("Unauthenticated access rejected", True)
|
||||
except Exception as e:
|
||||
log_test("Unauthenticated access rejected", False, str(e))
|
||||
|
||||
def test_authenticated_access():
|
||||
"""Test that protected endpoints accept valid JWT tokens."""
|
||||
try:
|
||||
headers = get_auth_headers()
|
||||
response = client.get("/api/machines", headers=headers)
|
||||
assert response.status_code == 200, f"Expected 200, got {response.status_code}"
|
||||
log_test("Authenticated access accepted", True)
|
||||
except Exception as e:
|
||||
log_test("Authenticated access accepted", False, str(e))
|
||||
|
||||
def test_invalid_token():
|
||||
"""Test that invalid tokens are rejected."""
|
||||
try:
|
||||
headers = {"Authorization": "Bearer invalid_token_string"}
|
||||
response = client.get("/api/machines", headers=headers)
|
||||
assert response.status_code == 401, f"Expected 401, got {response.status_code}"
|
||||
log_test("Invalid token rejected", True)
|
||||
except Exception as e:
|
||||
log_test("Invalid token rejected", False, str(e))
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# SECTION 3: Machine CRUD Operations
|
||||
# ============================================================================
|
||||
|
||||
print("\n" + "="*70)
|
||||
print("SECTION 3: Machine CRUD Operations")
|
||||
print("="*70 + "\n")
|
||||
|
||||
machine_id = None
|
||||
|
||||
def test_create_machine():
|
||||
"""Test creating a new machine."""
|
||||
global machine_id
|
||||
try:
|
||||
headers = get_auth_headers()
|
||||
machine_data = {
|
||||
"hostname": f"test-machine-{uuid4().hex[:8]}",
|
||||
"friendly_name": "Test Machine",
|
||||
"machine_type": "laptop",
|
||||
"platform": "win32",
|
||||
"is_active": True
|
||||
}
|
||||
response = client.post("/api/machines", json=machine_data, headers=headers)
|
||||
assert response.status_code == 201, f"Expected 201, got {response.status_code}. Response: {response.text}"
|
||||
data = response.json()
|
||||
assert "id" in data, f"Response missing 'id' field. Data: {data}"
|
||||
machine_id = data["id"]
|
||||
print(f" Created machine with ID: {machine_id}")
|
||||
log_test("Create machine", True)
|
||||
except Exception as e:
|
||||
log_test("Create machine", False, str(e))
|
||||
|
||||
def test_list_machines():
|
||||
"""Test listing machines with pagination."""
|
||||
try:
|
||||
headers = get_auth_headers()
|
||||
response = client.get("/api/machines?skip=0&limit=10", headers=headers)
|
||||
assert response.status_code == 200, f"Expected 200, got {response.status_code}"
|
||||
data = response.json()
|
||||
assert "total" in data, "Response missing 'total' field"
|
||||
assert "machines" in data, "Response missing 'machines' field"
|
||||
assert isinstance(data["machines"], list), "machines field is not a list"
|
||||
log_test("List machines", True)
|
||||
except Exception as e:
|
||||
log_test("List machines", False, str(e))
|
||||
|
||||
def test_get_machine():
|
||||
"""Test retrieving a specific machine by ID."""
|
||||
try:
|
||||
if machine_id is None:
|
||||
raise Exception("No machine_id available (create test may have failed)")
|
||||
headers = get_auth_headers()
|
||||
print(f" Fetching machine with ID: {machine_id} (type: {type(machine_id)})")
|
||||
|
||||
# List all machines to check if our machine exists
|
||||
list_response = client.get("/api/machines", headers=headers)
|
||||
all_machines = list_response.json().get("machines", [])
|
||||
print(f" Total machines in DB: {len(all_machines)}")
|
||||
if all_machines:
|
||||
print(f" First machine ID: {all_machines[0].get('id')} (type: {type(all_machines[0].get('id'))})")
|
||||
|
||||
response = client.get(f"/api/machines/{machine_id}", headers=headers)
|
||||
assert response.status_code == 200, f"Expected 200, got {response.status_code}. Response: {response.text}"
|
||||
data = response.json()
|
||||
assert str(data["id"]) == str(machine_id), f"Expected ID {machine_id}, got {data.get('id')}"
|
||||
log_test("Get machine by ID", True)
|
||||
except Exception as e:
|
||||
log_test("Get machine by ID", False, str(e))
|
||||
|
||||
def test_update_machine():
|
||||
"""Test updating a machine."""
|
||||
try:
|
||||
if machine_id is None:
|
||||
raise Exception("No machine_id available (create test may have failed)")
|
||||
headers = get_auth_headers()
|
||||
update_data = {
|
||||
"friendly_name": "Updated Test Machine",
|
||||
"notes": "Updated during testing"
|
||||
}
|
||||
response = client.put(f"/api/machines/{machine_id}", json=update_data, headers=headers)
|
||||
assert response.status_code == 200, f"Expected 200, got {response.status_code}"
|
||||
data = response.json()
|
||||
assert data["friendly_name"] == "Updated Test Machine", "Update not reflected"
|
||||
log_test("Update machine", True)
|
||||
except Exception as e:
|
||||
log_test("Update machine", False, str(e))
|
||||
|
||||
def test_machine_not_found():
|
||||
"""Test getting non-existent machine returns 404."""
|
||||
try:
|
||||
headers = get_auth_headers()
|
||||
fake_id = str(uuid4())
|
||||
response = client.get(f"/api/machines/{fake_id}", headers=headers)
|
||||
assert response.status_code == 404, f"Expected 404, got {response.status_code}"
|
||||
log_test("Machine not found (404)", True)
|
||||
except Exception as e:
|
||||
log_test("Machine not found (404)", False, str(e))
|
||||
|
||||
def test_delete_machine():
|
||||
"""Test deleting a machine."""
|
||||
try:
|
||||
if machine_id is None:
|
||||
raise Exception("No machine_id available (create test may have failed)")
|
||||
headers = get_auth_headers()
|
||||
response = client.delete(f"/api/machines/{machine_id}", headers=headers)
|
||||
assert response.status_code == 200, f"Expected 200, got {response.status_code}"
|
||||
log_test("Delete machine", True)
|
||||
except Exception as e:
|
||||
log_test("Delete machine", False, str(e))
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# SECTION 4: Client CRUD Operations
|
||||
# ============================================================================
|
||||
|
||||
print("\n" + "="*70)
|
||||
print("SECTION 4: Client CRUD Operations")
|
||||
print("="*70 + "\n")
|
||||
|
||||
client_id = None
|
||||
|
||||
def test_create_client():
|
||||
"""Test creating a new client."""
|
||||
global client_id
|
||||
try:
|
||||
headers = get_auth_headers()
|
||||
client_data = {
|
||||
"name": f"Test Client {uuid4().hex[:8]}",
|
||||
"type": "msp_client",
|
||||
"primary_contact": "John Doe",
|
||||
"is_active": True
|
||||
}
|
||||
response = client.post("/api/clients", json=client_data, headers=headers)
|
||||
assert response.status_code == 201, f"Expected 201, got {response.status_code}. Response: {response.text}"
|
||||
data = response.json()
|
||||
assert "id" in data, f"Response missing 'id' field. Data: {data}"
|
||||
client_id = data["id"]
|
||||
print(f" Created client with ID: {client_id}")
|
||||
log_test("Create client", True)
|
||||
except Exception as e:
|
||||
log_test("Create client", False, str(e))
|
||||
|
||||
def test_list_clients():
|
||||
"""Test listing clients with pagination."""
|
||||
try:
|
||||
headers = get_auth_headers()
|
||||
response = client.get("/api/clients?skip=0&limit=10", headers=headers)
|
||||
assert response.status_code == 200, f"Expected 200, got {response.status_code}"
|
||||
data = response.json()
|
||||
assert "total" in data, "Response missing 'total' field"
|
||||
assert "clients" in data, "Response missing 'clients' field"
|
||||
log_test("List clients", True)
|
||||
except Exception as e:
|
||||
log_test("List clients", False, str(e))
|
||||
|
||||
def test_get_client():
|
||||
"""Test retrieving a specific client by ID."""
|
||||
try:
|
||||
if client_id is None:
|
||||
raise Exception("No client_id available")
|
||||
headers = get_auth_headers()
|
||||
response = client.get(f"/api/clients/{client_id}", headers=headers)
|
||||
assert response.status_code == 200, f"Expected 200, got {response.status_code}"
|
||||
data = response.json()
|
||||
assert data["id"] == client_id, f"Expected ID {client_id}, got {data.get('id')}"
|
||||
log_test("Get client by ID", True)
|
||||
except Exception as e:
|
||||
log_test("Get client by ID", False, str(e))
|
||||
|
||||
def test_update_client():
|
||||
"""Test updating a client."""
|
||||
try:
|
||||
if client_id is None:
|
||||
raise Exception("No client_id available")
|
||||
headers = get_auth_headers()
|
||||
update_data = {
|
||||
"primary_contact": "Jane Doe",
|
||||
"notes": "Updated contact"
|
||||
}
|
||||
response = client.put(f"/api/clients/{client_id}", json=update_data, headers=headers)
|
||||
assert response.status_code == 200, f"Expected 200, got {response.status_code}"
|
||||
data = response.json()
|
||||
assert data["primary_contact"] == "Jane Doe", "Update not reflected"
|
||||
log_test("Update client", True)
|
||||
except Exception as e:
|
||||
log_test("Update client", False, str(e))
|
||||
|
||||
def test_delete_client():
|
||||
"""Test deleting a client."""
|
||||
try:
|
||||
if client_id is None:
|
||||
raise Exception("No client_id available")
|
||||
headers = get_auth_headers()
|
||||
response = client.delete(f"/api/clients/{client_id}", headers=headers)
|
||||
assert response.status_code == 200, f"Expected 200, got {response.status_code}"
|
||||
log_test("Delete client", True)
|
||||
except Exception as e:
|
||||
log_test("Delete client", False, str(e))
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# SECTION 5: Project CRUD Operations
|
||||
# ============================================================================
|
||||
|
||||
print("\n" + "="*70)
|
||||
print("SECTION 5: Project CRUD Operations")
|
||||
print("="*70 + "\n")
|
||||
|
||||
project_id = None
|
||||
project_client_id = None
|
||||
|
||||
def test_create_project():
|
||||
"""Test creating a new project."""
|
||||
global project_id, project_client_id
|
||||
try:
|
||||
headers = get_auth_headers()
|
||||
|
||||
# First create a client for the project
|
||||
client_data = {
|
||||
"name": f"Project Test Client {uuid4().hex[:8]}",
|
||||
"type": "msp_client",
|
||||
"is_active": True
|
||||
}
|
||||
client_response = client.post("/api/clients", json=client_data, headers=headers)
|
||||
assert client_response.status_code == 201, f"Failed to create test client: {client_response.text}"
|
||||
project_client_id = client_response.json()["id"]
|
||||
|
||||
# Now create the project
|
||||
project_data = {
|
||||
"name": f"Test Project {uuid4().hex[:8]}",
|
||||
"client_id": project_client_id,
|
||||
"status": "active"
|
||||
}
|
||||
response = client.post("/api/projects", json=project_data, headers=headers)
|
||||
assert response.status_code == 201, f"Expected 201, got {response.status_code}. Response: {response.text}"
|
||||
data = response.json()
|
||||
assert "id" in data, f"Response missing 'id' field. Data: {data}"
|
||||
project_id = data["id"]
|
||||
print(f" Created project with ID: {project_id}")
|
||||
log_test("Create project", True)
|
||||
except Exception as e:
|
||||
log_test("Create project", False, str(e))
|
||||
|
||||
def test_list_projects():
|
||||
"""Test listing projects with pagination."""
|
||||
try:
|
||||
headers = get_auth_headers()
|
||||
response = client.get("/api/projects?skip=0&limit=10", headers=headers)
|
||||
assert response.status_code == 200, f"Expected 200, got {response.status_code}"
|
||||
data = response.json()
|
||||
assert "total" in data, "Response missing 'total' field"
|
||||
assert "projects" in data, "Response missing 'projects' field"
|
||||
log_test("List projects", True)
|
||||
except Exception as e:
|
||||
log_test("List projects", False, str(e))
|
||||
|
||||
def test_get_project():
|
||||
"""Test retrieving a specific project by ID."""
|
||||
try:
|
||||
if project_id is None:
|
||||
raise Exception("No project_id available")
|
||||
headers = get_auth_headers()
|
||||
response = client.get(f"/api/projects/{project_id}", headers=headers)
|
||||
assert response.status_code == 200, f"Expected 200, got {response.status_code}"
|
||||
data = response.json()
|
||||
assert data["id"] == project_id, f"Expected ID {project_id}, got {data.get('id')}"
|
||||
log_test("Get project by ID", True)
|
||||
except Exception as e:
|
||||
log_test("Get project by ID", False, str(e))
|
||||
|
||||
def test_update_project():
|
||||
"""Test updating a project."""
|
||||
try:
|
||||
if project_id is None:
|
||||
raise Exception("No project_id available")
|
||||
headers = get_auth_headers()
|
||||
update_data = {
|
||||
"status": "completed",
|
||||
"notes": "Project completed during testing"
|
||||
}
|
||||
response = client.put(f"/api/projects/{project_id}", json=update_data, headers=headers)
|
||||
assert response.status_code == 200, f"Expected 200, got {response.status_code}"
|
||||
data = response.json()
|
||||
assert data["status"] == "completed", "Update not reflected"
|
||||
log_test("Update project", True)
|
||||
except Exception as e:
|
||||
log_test("Update project", False, str(e))
|
||||
|
||||
def test_delete_project():
|
||||
"""Test deleting a project."""
|
||||
try:
|
||||
if project_id is None:
|
||||
raise Exception("No project_id available")
|
||||
headers = get_auth_headers()
|
||||
response = client.delete(f"/api/projects/{project_id}", headers=headers)
|
||||
assert response.status_code == 200, f"Expected 200, got {response.status_code}"
|
||||
|
||||
# Clean up test client
|
||||
if project_client_id:
|
||||
client.delete(f"/api/clients/{project_client_id}", headers=headers)
|
||||
|
||||
log_test("Delete project", True)
|
||||
except Exception as e:
|
||||
log_test("Delete project", False, str(e))
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# SECTION 6: Session CRUD Operations
|
||||
# ============================================================================
|
||||
|
||||
print("\n" + "="*70)
|
||||
print("SECTION 6: Session CRUD Operations")
|
||||
print("="*70 + "\n")
|
||||
|
||||
session_id = None
|
||||
session_client_id = None
|
||||
session_project_id = None
|
||||
|
||||
def test_create_session():
|
||||
"""Test creating a new session."""
|
||||
global session_id, session_client_id, session_project_id
|
||||
try:
|
||||
headers = get_auth_headers()
|
||||
|
||||
# Create client for session
|
||||
client_data = {
|
||||
"name": f"Session Test Client {uuid4().hex[:8]}",
|
||||
"type": "msp_client",
|
||||
"is_active": True
|
||||
}
|
||||
client_response = client.post("/api/clients", json=client_data, headers=headers)
|
||||
assert client_response.status_code == 201, f"Failed to create test client: {client_response.text}"
|
||||
session_client_id = client_response.json()["id"]
|
||||
|
||||
# Create project for session
|
||||
project_data = {
|
||||
"name": f"Session Test Project {uuid4().hex[:8]}",
|
||||
"client_id": session_client_id,
|
||||
"status": "active"
|
||||
}
|
||||
project_response = client.post("/api/projects", json=project_data, headers=headers)
|
||||
assert project_response.status_code == 201, f"Failed to create test project: {project_response.text}"
|
||||
session_project_id = project_response.json()["id"]
|
||||
|
||||
# Create session
|
||||
from datetime import date
|
||||
session_data = {
|
||||
"session_title": f"Test Session {uuid4().hex[:8]}",
|
||||
"session_date": str(date.today()),
|
||||
"client_id": session_client_id,
|
||||
"project_id": session_project_id,
|
||||
"status": "completed"
|
||||
}
|
||||
response = client.post("/api/sessions", json=session_data, headers=headers)
|
||||
assert response.status_code == 201, f"Expected 201, got {response.status_code}. Response: {response.text}"
|
||||
data = response.json()
|
||||
assert "id" in data, f"Response missing 'id' field. Data: {data}"
|
||||
session_id = data["id"]
|
||||
print(f" Created session with ID: {session_id}")
|
||||
log_test("Create session", True)
|
||||
except Exception as e:
|
||||
log_test("Create session", False, str(e))
|
||||
|
||||
def test_list_sessions():
|
||||
"""Test listing sessions with pagination."""
|
||||
try:
|
||||
headers = get_auth_headers()
|
||||
response = client.get("/api/sessions?skip=0&limit=10", headers=headers)
|
||||
assert response.status_code == 200, f"Expected 200, got {response.status_code}"
|
||||
data = response.json()
|
||||
assert "total" in data, "Response missing 'total' field"
|
||||
assert "sessions" in data, "Response missing 'sessions' field"
|
||||
log_test("List sessions", True)
|
||||
except Exception as e:
|
||||
log_test("List sessions", False, str(e))
|
||||
|
||||
def test_get_session():
|
||||
"""Test retrieving a specific session by ID."""
|
||||
try:
|
||||
if session_id is None:
|
||||
raise Exception("No session_id available")
|
||||
headers = get_auth_headers()
|
||||
response = client.get(f"/api/sessions/{session_id}", headers=headers)
|
||||
assert response.status_code == 200, f"Expected 200, got {response.status_code}"
|
||||
data = response.json()
|
||||
assert data["id"] == session_id, f"Expected ID {session_id}, got {data.get('id')}"
|
||||
log_test("Get session by ID", True)
|
||||
except Exception as e:
|
||||
log_test("Get session by ID", False, str(e))
|
||||
|
||||
def test_update_session():
|
||||
"""Test updating a session."""
|
||||
try:
|
||||
if session_id is None:
|
||||
raise Exception("No session_id available")
|
||||
headers = get_auth_headers()
|
||||
update_data = {
|
||||
"status": "completed",
|
||||
"summary": "Test session completed"
|
||||
}
|
||||
response = client.put(f"/api/sessions/{session_id}", json=update_data, headers=headers)
|
||||
assert response.status_code == 200, f"Expected 200, got {response.status_code}"
|
||||
data = response.json()
|
||||
assert data["status"] == "completed", "Update not reflected"
|
||||
log_test("Update session", True)
|
||||
except Exception as e:
|
||||
log_test("Update session", False, str(e))
|
||||
|
||||
def test_delete_session():
|
||||
"""Test deleting a session."""
|
||||
try:
|
||||
if session_id is None:
|
||||
raise Exception("No session_id available")
|
||||
headers = get_auth_headers()
|
||||
response = client.delete(f"/api/sessions/{session_id}", headers=headers)
|
||||
assert response.status_code == 200, f"Expected 200, got {response.status_code}"
|
||||
|
||||
# Clean up test data
|
||||
if session_project_id:
|
||||
client.delete(f"/api/projects/{session_project_id}", headers=headers)
|
||||
if session_client_id:
|
||||
client.delete(f"/api/clients/{session_client_id}", headers=headers)
|
||||
|
||||
log_test("Delete session", True)
|
||||
except Exception as e:
|
||||
log_test("Delete session", False, str(e))
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# SECTION 7: Tag CRUD Operations
|
||||
# ============================================================================
|
||||
|
||||
print("\n" + "="*70)
|
||||
print("SECTION 7: Tag CRUD Operations")
|
||||
print("="*70 + "\n")
|
||||
|
||||
tag_id = None
|
||||
|
||||
def test_create_tag():
|
||||
"""Test creating a new tag."""
|
||||
global tag_id
|
||||
try:
|
||||
headers = get_auth_headers()
|
||||
tag_data = {
|
||||
"name": f"test-tag-{uuid4().hex[:8]}",
|
||||
"category": "technology",
|
||||
"color": "#FF5733"
|
||||
}
|
||||
response = client.post("/api/tags", json=tag_data, headers=headers)
|
||||
assert response.status_code == 201, f"Expected 201, got {response.status_code}"
|
||||
data = response.json()
|
||||
assert "id" in data, "Response missing 'id' field"
|
||||
tag_id = data["id"]
|
||||
log_test("Create tag", True)
|
||||
except Exception as e:
|
||||
log_test("Create tag", False, str(e))
|
||||
|
||||
def test_list_tags():
|
||||
"""Test listing tags with pagination."""
|
||||
try:
|
||||
headers = get_auth_headers()
|
||||
response = client.get("/api/tags?skip=0&limit=10", headers=headers)
|
||||
assert response.status_code == 200, f"Expected 200, got {response.status_code}"
|
||||
data = response.json()
|
||||
assert "total" in data, "Response missing 'total' field"
|
||||
assert "tags" in data, "Response missing 'tags' field"
|
||||
log_test("List tags", True)
|
||||
except Exception as e:
|
||||
log_test("List tags", False, str(e))
|
||||
|
||||
def test_get_tag():
|
||||
"""Test retrieving a specific tag by ID."""
|
||||
try:
|
||||
if tag_id is None:
|
||||
raise Exception("No tag_id available")
|
||||
headers = get_auth_headers()
|
||||
response = client.get(f"/api/tags/{tag_id}", headers=headers)
|
||||
assert response.status_code == 200, f"Expected 200, got {response.status_code}"
|
||||
data = response.json()
|
||||
assert data["id"] == tag_id, f"Expected ID {tag_id}, got {data.get('id')}"
|
||||
log_test("Get tag by ID", True)
|
||||
except Exception as e:
|
||||
log_test("Get tag by ID", False, str(e))
|
||||
|
||||
def test_update_tag():
|
||||
"""Test updating a tag."""
|
||||
try:
|
||||
if tag_id is None:
|
||||
raise Exception("No tag_id available")
|
||||
headers = get_auth_headers()
|
||||
update_data = {
|
||||
"color": "#00FF00",
|
||||
"description": "Updated test tag"
|
||||
}
|
||||
response = client.put(f"/api/tags/{tag_id}", json=update_data, headers=headers)
|
||||
assert response.status_code == 200, f"Expected 200, got {response.status_code}"
|
||||
data = response.json()
|
||||
assert data["color"] == "#00FF00", "Update not reflected"
|
||||
log_test("Update tag", True)
|
||||
except Exception as e:
|
||||
log_test("Update tag", False, str(e))
|
||||
|
||||
def test_tag_duplicate_name():
|
||||
"""Test creating tag with duplicate name returns 409."""
|
||||
try:
|
||||
if tag_id is None:
|
||||
raise Exception("No tag_id available")
|
||||
headers = get_auth_headers()
|
||||
|
||||
# Get existing tag name
|
||||
existing_response = client.get(f"/api/tags/{tag_id}", headers=headers)
|
||||
existing_name = existing_response.json()["name"]
|
||||
|
||||
# Try to create duplicate
|
||||
duplicate_data = {
|
||||
"name": existing_name,
|
||||
"category": "test"
|
||||
}
|
||||
response = client.post("/api/tags", json=duplicate_data, headers=headers)
|
||||
assert response.status_code == 409, f"Expected 409, got {response.status_code}"
|
||||
log_test("Tag duplicate name (409)", True)
|
||||
except Exception as e:
|
||||
log_test("Tag duplicate name (409)", False, str(e))
|
||||
|
||||
def test_delete_tag():
|
||||
"""Test deleting a tag."""
|
||||
try:
|
||||
if tag_id is None:
|
||||
raise Exception("No tag_id available")
|
||||
headers = get_auth_headers()
|
||||
response = client.delete(f"/api/tags/{tag_id}", headers=headers)
|
||||
assert response.status_code == 200, f"Expected 200, got {response.status_code}"
|
||||
log_test("Delete tag", True)
|
||||
except Exception as e:
|
||||
log_test("Delete tag", False, str(e))
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# SECTION 8: Pagination Tests
|
||||
# ============================================================================
|
||||
|
||||
print("\n" + "="*70)
|
||||
print("SECTION 8: Pagination Tests")
|
||||
print("="*70 + "\n")
|
||||
|
||||
def test_pagination_skip_limit():
|
||||
"""Test pagination with skip and limit parameters."""
|
||||
try:
|
||||
headers = get_auth_headers()
|
||||
response = client.get("/api/machines?skip=0&limit=5", headers=headers)
|
||||
assert response.status_code == 200, f"Expected 200, got {response.status_code}"
|
||||
data = response.json()
|
||||
assert data["skip"] == 0, f"Expected skip=0, got {data.get('skip')}"
|
||||
assert data["limit"] == 5, f"Expected limit=5, got {data.get('limit')}"
|
||||
log_test("Pagination skip/limit parameters", True)
|
||||
except Exception as e:
|
||||
log_test("Pagination skip/limit parameters", False, str(e))
|
||||
|
||||
def test_pagination_max_limit():
|
||||
"""Test that pagination enforces maximum limit."""
|
||||
try:
|
||||
headers = get_auth_headers()
|
||||
# Try to request more than max (1000)
|
||||
response = client.get("/api/machines?limit=2000", headers=headers)
|
||||
# Should either return 422 or clamp to max
|
||||
assert response.status_code in [200, 422], f"Unexpected status {response.status_code}"
|
||||
log_test("Pagination max limit enforcement", True)
|
||||
except Exception as e:
|
||||
log_test("Pagination max limit enforcement", False, str(e))
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Run All Tests
|
||||
# ============================================================================
|
||||
|
||||
def run_all_tests():
|
||||
"""Run all test functions."""
|
||||
print("\n" + "="*70)
|
||||
print("CLAUDETOOLS API ENDPOINT TESTS")
|
||||
print("="*70)
|
||||
|
||||
# Section 1: Health
|
||||
test_root_endpoint()
|
||||
test_health_endpoint()
|
||||
test_jwt_token_creation()
|
||||
|
||||
# Section 2: Auth
|
||||
test_unauthenticated_access()
|
||||
test_authenticated_access()
|
||||
test_invalid_token()
|
||||
|
||||
# Section 3: Machines
|
||||
test_create_machine()
|
||||
test_list_machines()
|
||||
test_get_machine()
|
||||
test_update_machine()
|
||||
test_machine_not_found()
|
||||
test_delete_machine()
|
||||
|
||||
# Section 4: Clients
|
||||
test_create_client()
|
||||
test_list_clients()
|
||||
test_get_client()
|
||||
test_update_client()
|
||||
test_delete_client()
|
||||
|
||||
# Section 5: Projects
|
||||
test_create_project()
|
||||
test_list_projects()
|
||||
test_get_project()
|
||||
test_update_project()
|
||||
test_delete_project()
|
||||
|
||||
# Section 6: Sessions
|
||||
test_create_session()
|
||||
test_list_sessions()
|
||||
test_get_session()
|
||||
test_update_session()
|
||||
test_delete_session()
|
||||
|
||||
# Section 7: Tags
|
||||
test_create_tag()
|
||||
test_list_tags()
|
||||
test_get_tag()
|
||||
test_update_tag()
|
||||
test_tag_duplicate_name()
|
||||
test_delete_tag()
|
||||
|
||||
# Section 8: Pagination
|
||||
test_pagination_skip_limit()
|
||||
test_pagination_max_limit()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("\n>> Starting ClaudeTools API Test Suite...")
|
||||
|
||||
try:
|
||||
run_all_tests()
|
||||
|
||||
# Print summary
|
||||
print("\n" + "="*70)
|
||||
print("TEST SUMMARY")
|
||||
print("="*70)
|
||||
print(f"\nTotal Tests: {tests_passed + tests_failed}")
|
||||
print(f"Passed: {tests_passed}")
|
||||
print(f"Failed: {tests_failed}")
|
||||
|
||||
if tests_failed > 0:
|
||||
print("\nFAILED TESTS:")
|
||||
for name, passed, error in test_results:
|
||||
if not passed:
|
||||
print(f" - {name}")
|
||||
if error:
|
||||
print(f" Error: {error}")
|
||||
|
||||
if tests_failed == 0:
|
||||
print("\n>> All tests passed!")
|
||||
sys.exit(0)
|
||||
else:
|
||||
print(f"\n>> {tests_failed} test(s) failed")
|
||||
sys.exit(1)
|
||||
|
||||
except Exception as e:
|
||||
print(f"\n>> Fatal error running tests: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
sys.exit(1)
|
||||
286
tests/test_conversation_parser.py
Normal file
286
tests/test_conversation_parser.py
Normal file
@@ -0,0 +1,286 @@
|
||||
"""
|
||||
Test script for conversation_parser.py
|
||||
|
||||
Tests all four main functions with sample data.
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
from api.utils.conversation_parser import (
|
||||
parse_jsonl_conversation,
|
||||
categorize_conversation,
|
||||
extract_context_from_conversation,
|
||||
scan_folder_for_conversations,
|
||||
batch_process_conversations,
|
||||
)
|
||||
|
||||
|
||||
def test_parse_jsonl_conversation():
|
||||
"""Test parsing .jsonl conversation files."""
|
||||
print("\n=== Test 1: parse_jsonl_conversation ===")
|
||||
|
||||
# Create a temporary .jsonl file
|
||||
with tempfile.NamedTemporaryFile(mode='w', suffix='.jsonl', delete=False, encoding='utf-8') as f:
|
||||
# Write sample conversation data
|
||||
f.write(json.dumps({
|
||||
"role": "user",
|
||||
"content": "Build a FastAPI authentication system with PostgreSQL",
|
||||
"timestamp": 1705449600000
|
||||
}) + "\n")
|
||||
f.write(json.dumps({
|
||||
"role": "assistant",
|
||||
"content": "I'll help you build an auth system using FastAPI and PostgreSQL. Let me create the api/auth.py file.",
|
||||
"timestamp": 1705449620000
|
||||
}) + "\n")
|
||||
f.write(json.dumps({
|
||||
"role": "user",
|
||||
"content": "Also add JWT token support",
|
||||
"timestamp": 1705449640000
|
||||
}) + "\n")
|
||||
temp_file = f.name
|
||||
|
||||
try:
|
||||
result = parse_jsonl_conversation(temp_file)
|
||||
|
||||
print(f"Messages: {result['message_count']}")
|
||||
print(f"Duration: {result['duration_seconds']} seconds")
|
||||
print(f"File paths extracted: {result['file_paths']}")
|
||||
print(f"First message: {result['messages'][0]['content'][:50]}...")
|
||||
|
||||
assert result['message_count'] == 3, "Should have 3 messages"
|
||||
assert result['duration_seconds'] == 40, "Duration should be 40 seconds"
|
||||
assert 'api/auth.py' in result['file_paths'], "Should extract file path"
|
||||
|
||||
print("[PASS] parse_jsonl_conversation test passed!")
|
||||
|
||||
finally:
|
||||
os.unlink(temp_file)
|
||||
|
||||
|
||||
def test_categorize_conversation():
|
||||
"""Test conversation categorization."""
|
||||
print("\n=== Test 2: categorize_conversation ===")
|
||||
|
||||
# Test MSP conversation
|
||||
msp_messages = [
|
||||
{"role": "user", "content": "Client reported firewall blocking Office365 connection"},
|
||||
{"role": "assistant", "content": "I'll check the firewall rules for the client site"}
|
||||
]
|
||||
|
||||
msp_category = categorize_conversation(msp_messages)
|
||||
print(f"MSP category: {msp_category}")
|
||||
assert msp_category == "msp", "Should categorize as MSP"
|
||||
|
||||
# Test Development conversation
|
||||
dev_messages = [
|
||||
{"role": "user", "content": "Build API endpoint for user authentication with FastAPI"},
|
||||
{"role": "assistant", "content": "I'll create the endpoint using SQLAlchemy and implement JWT tokens"}
|
||||
]
|
||||
|
||||
dev_category = categorize_conversation(dev_messages)
|
||||
print(f"Development category: {dev_category}")
|
||||
assert dev_category == "development", "Should categorize as development"
|
||||
|
||||
# Test General conversation
|
||||
general_messages = [
|
||||
{"role": "user", "content": "What's the weather like today?"},
|
||||
{"role": "assistant", "content": "I don't have access to current weather data"}
|
||||
]
|
||||
|
||||
general_category = categorize_conversation(general_messages)
|
||||
print(f"General category: {general_category}")
|
||||
assert general_category == "general", "Should categorize as general"
|
||||
|
||||
print("[PASS] categorize_conversation test passed!")
|
||||
|
||||
|
||||
def test_extract_context_from_conversation():
|
||||
"""Test context extraction from conversation."""
|
||||
print("\n=== Test 3: extract_context_from_conversation ===")
|
||||
|
||||
# Create a sample conversation
|
||||
conversation = {
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Build a FastAPI REST API with PostgreSQL database",
|
||||
"timestamp": 1705449600000
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "I'll create the API using FastAPI and SQLAlchemy. Decided to use Alembic for migrations because it integrates well with SQLAlchemy.",
|
||||
"timestamp": 1705449620000
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Add authentication with JWT tokens",
|
||||
"timestamp": 1705449640000
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"title": "Build API system",
|
||||
"model": "claude-opus-4",
|
||||
"sessionId": "test-123"
|
||||
},
|
||||
"file_paths": ["api/main.py", "api/auth.py", "api/models.py"],
|
||||
"tool_calls": [
|
||||
{"tool": "write", "count": 5},
|
||||
{"tool": "read", "count": 3}
|
||||
],
|
||||
"duration_seconds": 40,
|
||||
"message_count": 3
|
||||
}
|
||||
|
||||
context = extract_context_from_conversation(conversation)
|
||||
|
||||
print(f"Category: {context['category']}")
|
||||
print(f"Tags: {context['tags'][:5]}")
|
||||
print(f"Decisions: {len(context['decisions'])}")
|
||||
print(f"Quality score: {context['metrics']['quality_score']}/10")
|
||||
print(f"Key files: {context['key_files']}")
|
||||
|
||||
assert context['category'] in ['msp', 'development', 'general'], "Should have valid category"
|
||||
assert len(context['tags']) > 0, "Should have extracted tags"
|
||||
assert context['metrics']['message_count'] == 3, "Should have correct message count"
|
||||
|
||||
print("[PASS] extract_context_from_conversation test passed!")
|
||||
|
||||
|
||||
def test_scan_folder_for_conversations():
|
||||
"""Test scanning folder for conversation files."""
|
||||
print("\n=== Test 4: scan_folder_for_conversations ===")
|
||||
|
||||
# Create a temporary directory structure
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
# Create some conversation files
|
||||
conv1_path = os.path.join(tmpdir, "conversation1.jsonl")
|
||||
conv2_path = os.path.join(tmpdir, "session", "conversation2.json")
|
||||
config_path = os.path.join(tmpdir, "config.json") # Should be skipped
|
||||
|
||||
os.makedirs(os.path.dirname(conv2_path), exist_ok=True)
|
||||
|
||||
# Create files
|
||||
with open(conv1_path, 'w') as f:
|
||||
f.write('{"role": "user", "content": "test"}\n')
|
||||
|
||||
with open(conv2_path, 'w') as f:
|
||||
f.write('{"role": "user", "content": "test"}')
|
||||
|
||||
with open(config_path, 'w') as f:
|
||||
f.write('{"setting": "value"}')
|
||||
|
||||
# Scan folder
|
||||
files = scan_folder_for_conversations(tmpdir)
|
||||
|
||||
print(f"Found {len(files)} conversation files")
|
||||
print(f"Files: {[os.path.basename(f) for f in files]}")
|
||||
|
||||
assert len(files) == 2, "Should find 2 conversation files"
|
||||
assert any("conversation1.jsonl" in f for f in files), "Should find jsonl file"
|
||||
assert any("conversation2.json" in f for f in files), "Should find json file"
|
||||
assert not any("config.json" in f for f in files), "Should skip config.json"
|
||||
|
||||
print("[PASS] scan_folder_for_conversations test passed!")
|
||||
|
||||
|
||||
def test_batch_process():
|
||||
"""Test batch processing of conversations."""
|
||||
print("\n=== Test 5: batch_process_conversations ===")
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
# Create sample conversations
|
||||
conv1_path = os.path.join(tmpdir, "msp_work.jsonl")
|
||||
conv2_path = os.path.join(tmpdir, "dev_work.jsonl")
|
||||
|
||||
# MSP conversation
|
||||
with open(conv1_path, 'w') as f:
|
||||
f.write(json.dumps({
|
||||
"role": "user",
|
||||
"content": "Client ticket: firewall blocking Office365",
|
||||
"timestamp": 1705449600000
|
||||
}) + "\n")
|
||||
f.write(json.dumps({
|
||||
"role": "assistant",
|
||||
"content": "I'll check the client firewall configuration",
|
||||
"timestamp": 1705449620000
|
||||
}) + "\n")
|
||||
|
||||
# Development conversation
|
||||
with open(conv2_path, 'w') as f:
|
||||
f.write(json.dumps({
|
||||
"role": "user",
|
||||
"content": "Build FastAPI endpoint for authentication",
|
||||
"timestamp": 1705449600000
|
||||
}) + "\n")
|
||||
f.write(json.dumps({
|
||||
"role": "assistant",
|
||||
"content": "Creating API endpoint with SQLAlchemy",
|
||||
"timestamp": 1705449620000
|
||||
}) + "\n")
|
||||
|
||||
# Process all conversations
|
||||
processed_count = [0]
|
||||
|
||||
def progress_callback(file_path, context):
|
||||
processed_count[0] += 1
|
||||
print(f" Processed: {os.path.basename(file_path)} -> {context['category']}")
|
||||
|
||||
contexts = batch_process_conversations(tmpdir, progress_callback)
|
||||
|
||||
print(f"\nTotal processed: {len(contexts)}")
|
||||
|
||||
assert len(contexts) == 2, "Should process 2 conversations"
|
||||
assert processed_count[0] == 2, "Callback should be called twice"
|
||||
|
||||
categories = [ctx['category'] for ctx in contexts]
|
||||
print(f"Categories: {categories}")
|
||||
|
||||
print("[PASS] batch_process_conversations test passed!")
|
||||
|
||||
|
||||
def test_real_conversation_file():
|
||||
"""Test with real conversation file if available."""
|
||||
print("\n=== Test 6: Real conversation file ===")
|
||||
|
||||
real_file = r"C:\Users\MikeSwanson\AppData\Roaming\Claude\claude-code-sessions\0c32bde5-dc29-49ac-8c80-5adeaf1cdb33\299a238a-5ebf-44f4-948b-eedfa5c1f57c\local_feb419c2-b7a6-4c31-a7ce-38f6c0ccc523.json"
|
||||
|
||||
if os.path.exists(real_file):
|
||||
try:
|
||||
conversation = parse_jsonl_conversation(real_file)
|
||||
print(f"Real file - Messages: {conversation['message_count']}")
|
||||
print(f"Real file - Metadata: {conversation['metadata'].get('title', 'No title')}")
|
||||
|
||||
if conversation['message_count'] > 0:
|
||||
context = extract_context_from_conversation(conversation)
|
||||
print(f"Real file - Category: {context['category']}")
|
||||
print(f"Real file - Quality: {context['metrics']['quality_score']}/10")
|
||||
except Exception as e:
|
||||
print(f"Note: Real file test skipped - {e}")
|
||||
else:
|
||||
print("Real conversation file not found - skipping this test")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("=" * 60)
|
||||
print("Testing conversation_parser.py")
|
||||
print("=" * 60)
|
||||
|
||||
try:
|
||||
test_parse_jsonl_conversation()
|
||||
test_categorize_conversation()
|
||||
test_extract_context_from_conversation()
|
||||
test_scan_folder_for_conversations()
|
||||
test_batch_process()
|
||||
test_real_conversation_file()
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("All tests passed! [OK]")
|
||||
print("=" * 60)
|
||||
|
||||
except AssertionError as e:
|
||||
print(f"\n[FAIL] Test failed: {e}")
|
||||
raise
|
||||
except Exception as e:
|
||||
print(f"\n[ERROR] Unexpected error: {e}")
|
||||
raise
|
||||
284
tests/test_credential_scanner.py
Normal file
284
tests/test_credential_scanner.py
Normal file
@@ -0,0 +1,284 @@
|
||||
"""
|
||||
Test script for credential scanner and importer.
|
||||
|
||||
This script demonstrates the credential scanner functionality including:
|
||||
- Creating sample credential files
|
||||
- Scanning for credential files
|
||||
- Parsing credential data
|
||||
- Importing credentials to database
|
||||
|
||||
Usage:
|
||||
python test_credential_scanner.py
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
from api.database import SessionLocal
|
||||
from api.utils.credential_scanner import (
|
||||
scan_for_credential_files,
|
||||
parse_credential_file,
|
||||
import_credentials_to_db,
|
||||
scan_and_import_credentials,
|
||||
)
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_sample_credential_files(temp_dir: str):
|
||||
"""Create sample credential files for testing."""
|
||||
|
||||
# Create credentials.md
|
||||
credentials_md = Path(temp_dir) / "credentials.md"
|
||||
credentials_md.write_text("""# Sample Credentials
|
||||
|
||||
## Gitea Admin
|
||||
Username: admin
|
||||
Password: GitSecurePass123!
|
||||
URL: https://git.example.com
|
||||
Notes: Main admin account
|
||||
|
||||
## Database Server
|
||||
Type: connection_string
|
||||
Connection String: mysql://dbuser:dbpass@192.168.1.50:3306/mydb
|
||||
Notes: Production database
|
||||
|
||||
## OpenAI API
|
||||
API Key: sk-1234567890abcdefghijklmnopqrstuvwxyz
|
||||
Notes: Production API key
|
||||
""")
|
||||
|
||||
# Create .env file
|
||||
env_file = Path(temp_dir) / ".env"
|
||||
env_file.write_text("""# Environment Variables
|
||||
DATABASE_URL=postgresql://user:pass@localhost:5432/testdb
|
||||
API_TOKEN=ghp_abc123def456ghi789jkl012mno345pqr678
|
||||
SECRET_KEY=super_secret_key_12345
|
||||
""")
|
||||
|
||||
# Create passwords.txt
|
||||
passwords_txt = Path(temp_dir) / "passwords.txt"
|
||||
passwords_txt.write_text("""# Server Passwords
|
||||
|
||||
## Web Server
|
||||
Username: webadmin
|
||||
Password: Web@dmin2024!
|
||||
Host: 192.168.1.100
|
||||
Port: 22
|
||||
|
||||
## Backup Server
|
||||
Username: backup
|
||||
Password: BackupSecure789
|
||||
Host: 10.0.0.50
|
||||
""")
|
||||
|
||||
logger.info(f"Created sample credential files in: {temp_dir}")
|
||||
return [str(credentials_md), str(env_file), str(passwords_txt)]
|
||||
|
||||
|
||||
def test_scan_for_credential_files():
|
||||
"""Test credential file scanning."""
|
||||
logger.info("=" * 60)
|
||||
logger.info("TEST 1: Scan for Credential Files")
|
||||
logger.info("=" * 60)
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
# Create sample files
|
||||
create_sample_credential_files(temp_dir)
|
||||
|
||||
# Scan for files
|
||||
found_files = scan_for_credential_files(temp_dir)
|
||||
|
||||
logger.info(f"\nFound {len(found_files)} credential file(s):")
|
||||
for file_path in found_files:
|
||||
logger.info(f" - {file_path}")
|
||||
|
||||
assert len(found_files) == 3, "Should find 3 credential files"
|
||||
logger.info("\n[PASS] Test 1 passed")
|
||||
|
||||
return found_files
|
||||
|
||||
|
||||
def test_parse_credential_file():
|
||||
"""Test credential file parsing."""
|
||||
logger.info("\n" + "=" * 60)
|
||||
logger.info("TEST 2: Parse Credential Files")
|
||||
logger.info("=" * 60)
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
# Create sample files
|
||||
sample_files = create_sample_credential_files(temp_dir)
|
||||
|
||||
total_credentials = 0
|
||||
|
||||
for file_path in sample_files:
|
||||
credentials = parse_credential_file(file_path)
|
||||
total_credentials += len(credentials)
|
||||
|
||||
logger.info(f"\nParsed from {Path(file_path).name}:")
|
||||
for cred in credentials:
|
||||
logger.info(f" Service: {cred.get('service_name')}")
|
||||
logger.info(f" Type: {cred.get('credential_type')}")
|
||||
if cred.get('username'):
|
||||
logger.info(f" Username: {cred.get('username')}")
|
||||
# Don't log actual passwords/keys
|
||||
if cred.get('password'):
|
||||
logger.info(f" Password: [REDACTED]")
|
||||
if cred.get('api_key'):
|
||||
logger.info(f" API Key: [REDACTED]")
|
||||
if cred.get('connection_string'):
|
||||
logger.info(f" Connection String: [REDACTED]")
|
||||
logger.info("")
|
||||
|
||||
logger.info(f"Total credentials parsed: {total_credentials}")
|
||||
assert total_credentials > 0, "Should parse at least one credential"
|
||||
logger.info("[PASS] Test 2 passed")
|
||||
|
||||
|
||||
def test_import_credentials_to_db():
|
||||
"""Test importing credentials to database."""
|
||||
logger.info("\n" + "=" * 60)
|
||||
logger.info("TEST 3: Import Credentials to Database")
|
||||
logger.info("=" * 60)
|
||||
|
||||
db = SessionLocal()
|
||||
|
||||
try:
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
# Create sample files
|
||||
sample_files = create_sample_credential_files(temp_dir)
|
||||
|
||||
# Parse first file
|
||||
credentials = parse_credential_file(sample_files[0])
|
||||
logger.info(f"\nParsed {len(credentials)} credential(s) from file")
|
||||
|
||||
# Import to database
|
||||
imported_count = import_credentials_to_db(
|
||||
db=db,
|
||||
credentials=credentials,
|
||||
client_id=None, # No client association for test
|
||||
user_id="test_user",
|
||||
ip_address="127.0.0.1"
|
||||
)
|
||||
|
||||
logger.info(f"\n[OK] Successfully imported {imported_count} credential(s)")
|
||||
logger.info("[PASS] Test 3 passed")
|
||||
|
||||
return imported_count
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Import failed: {str(e)}")
|
||||
raise
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
def test_full_workflow():
|
||||
"""Test complete scan and import workflow."""
|
||||
logger.info("\n" + "=" * 60)
|
||||
logger.info("TEST 4: Full Scan and Import Workflow")
|
||||
logger.info("=" * 60)
|
||||
|
||||
db = SessionLocal()
|
||||
|
||||
try:
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
# Create sample files
|
||||
create_sample_credential_files(temp_dir)
|
||||
|
||||
# Run full workflow
|
||||
results = scan_and_import_credentials(
|
||||
base_path=temp_dir,
|
||||
db=db,
|
||||
client_id=None,
|
||||
user_id="test_user",
|
||||
ip_address="127.0.0.1"
|
||||
)
|
||||
|
||||
logger.info(f"\nWorkflow Results:")
|
||||
logger.info(f" Files found: {results['files_found']}")
|
||||
logger.info(f" Credentials parsed: {results['credentials_parsed']}")
|
||||
logger.info(f" Credentials imported: {results['credentials_imported']}")
|
||||
|
||||
assert results['files_found'] > 0, "Should find files"
|
||||
assert results['credentials_parsed'] > 0, "Should parse credentials"
|
||||
logger.info("\n[PASS] Test 4 passed")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Workflow failed: {str(e)}")
|
||||
raise
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
def test_markdown_parsing():
|
||||
"""Test markdown credential parsing with various formats."""
|
||||
logger.info("\n" + "=" * 60)
|
||||
logger.info("TEST 5: Markdown Format Variations")
|
||||
logger.info("=" * 60)
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
# Create file with various markdown formats
|
||||
test_file = Path(temp_dir) / "test_variations.md"
|
||||
test_file.write_text("""
|
||||
# Single Hash Header
|
||||
Username: user1
|
||||
Password: pass1
|
||||
|
||||
## Double Hash Header
|
||||
User: user2
|
||||
Pass: pass2
|
||||
|
||||
## API Service
|
||||
API_Key: sk-123456789
|
||||
Type: api_key
|
||||
|
||||
## Database Connection
|
||||
Connection_String: mysql://user:pass@host/db
|
||||
""")
|
||||
|
||||
credentials = parse_credential_file(str(test_file))
|
||||
|
||||
logger.info(f"\nParsed {len(credentials)} credential(s):")
|
||||
for cred in credentials:
|
||||
logger.info(f" - {cred.get('service_name')} ({cred.get('credential_type')})")
|
||||
|
||||
assert len(credentials) >= 3, "Should parse multiple variations"
|
||||
logger.info("\n[PASS] Test 5 passed")
|
||||
|
||||
|
||||
def main():
|
||||
"""Run all tests."""
|
||||
logger.info("\n" + "=" * 60)
|
||||
logger.info("CREDENTIAL SCANNER TEST SUITE")
|
||||
logger.info("=" * 60)
|
||||
|
||||
try:
|
||||
# Run tests
|
||||
test_scan_for_credential_files()
|
||||
test_parse_credential_file()
|
||||
test_markdown_parsing()
|
||||
test_import_credentials_to_db()
|
||||
test_full_workflow()
|
||||
|
||||
logger.info("\n" + "=" * 60)
|
||||
logger.info("ALL TESTS PASSED!")
|
||||
logger.info("=" * 60)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("\n" + "=" * 60)
|
||||
logger.error("TEST FAILED!")
|
||||
logger.error("=" * 60)
|
||||
logger.error(f"Error: {str(e)}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
291
tests/test_credentials_api.py
Normal file
291
tests/test_credentials_api.py
Normal file
@@ -0,0 +1,291 @@
|
||||
"""
|
||||
Test script for Credentials Management API.

This script tests the credentials API endpoints including encryption, decryption,
and audit logging functionality.
"""

import sys
from datetime import datetime
from uuid import uuid4

from api.database import get_db
from api.models.credential import Credential
from api.models.credential_audit_log import CredentialAuditLog
from api.schemas.credential import CredentialCreate, CredentialUpdate
from api.services.credential_service import (
    create_credential,
    delete_credential,
    get_credential_by_id,
    get_credentials,
    update_credential,
)
from api.utils.crypto import decrypt_string, encrypt_string


def test_encryption_decryption():
    """Test basic encryption and decryption."""
    print("\n=== Testing Encryption/Decryption ===")

    test_password = "SuperSecretPassword123!"
    print(f"Original password: {test_password}")

    # Encrypt
    encrypted = encrypt_string(test_password)
    print(f"Encrypted (length: {len(encrypted)}): {encrypted[:50]}...")

    # Decrypt
    decrypted = decrypt_string(encrypted)
    print(f"Decrypted: {decrypted}")

    assert test_password == decrypted, "Encryption/decryption mismatch!"
    print("[PASS] Encryption/decryption test passed")


def test_credential_lifecycle():
    """Test the full credential lifecycle: create, read, update, delete."""
    print("\n=== Testing Credential Lifecycle ===")

    db = next(get_db())

    try:
        # 1. CREATE
        print("\n1. Creating credential...")
        credential_data = CredentialCreate(
            credential_type="password",
            service_name="Test Service",
            username="admin",
            password="MySecurePassword123!",
            external_url="https://test.example.com",
            requires_vpn=False,
            requires_2fa=True,
            is_active=True
        )

        created = create_credential(
            db=db,
            credential_data=credential_data,
            user_id="test_user_123",
            ip_address="127.0.0.1",
            user_agent="Test Script"
        )

        print(f"[PASS] Created credential ID: {created.id}")
        print(f"   Service: {created.service_name}")
        print(f"   Type: {created.credential_type}")
        print(f"   Password encrypted: {created.password_encrypted is not None}")

        # Verify encryption
        if created.password_encrypted:
            decrypted_password = decrypt_string(created.password_encrypted.decode('utf-8'))
            assert decrypted_password == "MySecurePassword123!", "Password encryption failed!"
            print(f"   [PASS] Password correctly encrypted and decrypted")

        # Verify audit log was created
        audit_logs = db.query(CredentialAuditLog).filter(
            CredentialAuditLog.credential_id == str(created.id)
        ).all()
        print(f"   [PASS] Audit logs created: {len(audit_logs)}")

        # 2. READ
        print("\n2. Reading credential...")
        retrieved = get_credential_by_id(db, created.id, user_id="test_user_123")
        print(f"[PASS] Retrieved credential: {retrieved.service_name}")

        # Check audit log for view action
        audit_logs = db.query(CredentialAuditLog).filter(
            CredentialAuditLog.credential_id == str(created.id),
            CredentialAuditLog.action == "view"
        ).all()
        print(f"   [PASS] View action logged: {len(audit_logs) > 0}")

        # 3. UPDATE
        print("\n3. Updating credential...")
        update_data = CredentialUpdate(
            password="NewSecurePassword456!",
            last_rotated_at=datetime.utcnow(),
            external_url="https://test-updated.example.com"
        )

        updated = update_credential(
            db=db,
            credential_id=created.id,
            credential_data=update_data,
            user_id="test_user_123",
            ip_address="127.0.0.1"
        )

        print(f"[PASS] Updated credential: {updated.service_name}")
        print(f"   Password re-encrypted: {updated.password_encrypted is not None}")

        # Verify new password
        if updated.password_encrypted:
            decrypted_new_password = decrypt_string(updated.password_encrypted.decode('utf-8'))
            assert decrypted_new_password == "NewSecurePassword456!", "Password update failed!"
            print(f"   [PASS] New password correctly encrypted")

        # Check audit log for update action
        audit_logs = db.query(CredentialAuditLog).filter(
            CredentialAuditLog.credential_id == str(created.id),
            CredentialAuditLog.action == "update"
        ).all()
        print(f"   [PASS] Update action logged: {len(audit_logs) > 0}")

        # 4. LIST
        print("\n4. Listing credentials...")
        credentials, total = get_credentials(db, skip=0, limit=10)
        print(f"[PASS] Found {total} total credentials")
        print(f"   Retrieved {len(credentials)} credentials in this page")

        # 5. DELETE
        print("\n5. Deleting credential...")
        result = delete_credential(
            db=db,
            credential_id=created.id,
            user_id="test_user_123",
            ip_address="127.0.0.1"
        )

        print(f"[PASS] {result['message']}")

        # Verify deletion
        remaining = db.query(Credential).filter(Credential.id == str(created.id)).first()
        assert remaining is None, "Credential was not deleted!"
        print(f"   [PASS] Credential successfully removed from database")

        # Check audit log for delete action (should still exist due to CASCADE behavior)
        audit_logs = db.query(CredentialAuditLog).filter(
            CredentialAuditLog.credential_id == str(created.id),
            CredentialAuditLog.action == "delete"
        ).all()
        print(f"   [PASS] Delete action logged: {len(audit_logs) > 0}")

        print("\n[PASS] All credential lifecycle tests passed!")

    except Exception as e:
        print(f"\n[FAIL] Test failed: {str(e)}")
        import traceback
        traceback.print_exc()
        sys.exit(1)

    finally:
        db.close()


def test_multiple_credential_types():
    """Test creating credentials with different types and encrypted fields."""
    print("\n=== Testing Multiple Credential Types ===")

    db = next(get_db())

    try:
        credential_ids = []

        # Test API Key credential
        print("\n1. Creating API Key credential...")
        api_key_data = CredentialCreate(
            credential_type="api_key",
            service_name="GitHub API",
            api_key="ghp_abcdef1234567890",
            external_url="https://api.github.com",
            is_active=True
        )

        api_cred = create_credential(db, api_key_data, user_id="test_user")
        print(f"[PASS] Created API Key credential: {api_cred.id}")
        credential_ids.append(api_cred.id)

        # Verify API key encryption
        if api_cred.api_key_encrypted:
            decrypted_key = decrypt_string(api_cred.api_key_encrypted.decode('utf-8'))
            assert decrypted_key == "ghp_abcdef1234567890", "API key encryption failed!"
            print(f"   [PASS] API key correctly encrypted")

        # Test OAuth credential
        print("\n2. Creating OAuth credential...")
        oauth_data = CredentialCreate(
            credential_type="oauth",
            service_name="Microsoft 365",
            client_id_oauth="app-client-id-123",
            client_secret="secret_value_xyz789",
            tenant_id_oauth="tenant-id-456",
            is_active=True
        )

        oauth_cred = create_credential(db, oauth_data, user_id="test_user")
        print(f"[PASS] Created OAuth credential: {oauth_cred.id}")
        credential_ids.append(oauth_cred.id)

        # Verify client secret encryption
        if oauth_cred.client_secret_encrypted:
            decrypted_secret = decrypt_string(oauth_cred.client_secret_encrypted.decode('utf-8'))
            assert decrypted_secret == "secret_value_xyz789", "OAuth secret encryption failed!"
            print(f"   [PASS] Client secret correctly encrypted")

        # Test Connection String credential
        print("\n3. Creating Connection String credential...")
        conn_data = CredentialCreate(
            credential_type="connection_string",
            service_name="SQL Server",
            connection_string="Server=localhost;Database=TestDB;User Id=sa;Password=ComplexPass123!;",
            internal_url="sql.internal.local",
            custom_port=1433,
            is_active=True
        )

        conn_cred = create_credential(db, conn_data, user_id="test_user")
        print(f"[PASS] Created Connection String credential: {conn_cred.id}")
        credential_ids.append(conn_cred.id)

        # Verify connection string encryption
        if conn_cred.connection_string_encrypted:
            decrypted_conn = decrypt_string(conn_cred.connection_string_encrypted.decode('utf-8'))
            assert "ComplexPass123!" in decrypted_conn, "Connection string encryption failed!"
            print(f"   [PASS] Connection string correctly encrypted")

        print(f"\n[PASS] Created {len(credential_ids)} different credential types")

        # Cleanup
        print("\n4. Cleaning up test credentials...")
        for cred_id in credential_ids:
            delete_credential(db, cred_id, user_id="test_user")
        print(f"[PASS] Cleaned up {len(credential_ids)} credentials")

        print("\n[PASS] All multi-type credential tests passed!")

    except Exception as e:
        print(f"\n[FAIL] Test failed: {str(e)}")
        import traceback
        traceback.print_exc()
        sys.exit(1)

    finally:
        db.close()


def main():
    """Run all tests."""
    print("=" * 60)
    print("CREDENTIALS API TEST SUITE")
    print("=" * 60)

    try:
        test_encryption_decryption()
        test_credential_lifecycle()
        test_multiple_credential_types()

        print("\n" + "=" * 60)
        print("[PASS] ALL TESTS PASSED!")
        print("=" * 60)

    except Exception as e:
        print("\n" + "=" * 60)
        print("[FAIL] TEST SUITE FAILED")
        print("=" * 60)
        import traceback
        traceback.print_exc()
        sys.exit(1)


if __name__ == "__main__":
    main()
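Note for reviewers: the encrypt_string/decrypt_string helpers exercised above live in api/utils/crypto.py, which is not part of this diff. A minimal sketch of the kind of helper this test assumes, using Fernet symmetric encryption with a key taken from an environment variable (the variable name and module layout here are illustrative, not necessarily the project's actual implementation):

# Hypothetical sketch of api/utils/crypto.py -- assumes Fernet and an ENCRYPTION_KEY env var.
import os
from cryptography.fernet import Fernet

_fernet = Fernet(os.environ["ENCRYPTION_KEY"].encode())

def encrypt_string(plaintext: str) -> str:
    # Returns a urlsafe base64 Fernet token as a str
    return _fernet.encrypt(plaintext.encode("utf-8")).decode("utf-8")

def decrypt_string(token: str) -> str:
    # Raises cryptography.fernet.InvalidToken if the token was tampered with
    return _fernet.decrypt(token.encode("utf-8")).decode("utf-8")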
490
tests/test_crud_operations.py
Normal file
@@ -0,0 +1,490 @@
"""
Phase 3 Test: Database CRUD Operations Validation

Tests CREATE, READ, UPDATE, DELETE operations on the ClaudeTools database
with real database connections and verifies foreign key relationships.
"""

import sys
from datetime import datetime, timezone
from uuid import uuid4
import random

from sqlalchemy import text
from sqlalchemy.exc import IntegrityError, SQLAlchemyError

# Add api directory to path
sys.path.insert(0, 'D:\\ClaudeTools')

from api.database import SessionLocal, check_db_connection
from api.models import Client, Machine, Session, Tag, SessionTag


class CRUDTester:
    """Test harness for CRUD operations."""

    def __init__(self):
        self.db = None
        self.test_ids = {
            'client': None,
            'machine': None,
            'session': None,
            'tag': None
        }
        self.passed = 0
        self.failed = 0
        self.errors = []

    def connect(self):
        """Test database connection."""
        print("=" * 80)
        print("PHASE 3: DATABASE CRUD OPERATIONS TEST")
        print("=" * 80)
        print("\n1. CONNECTION TEST")
        print("-" * 80)

        try:
            if not check_db_connection():
                self.fail("Connection", "check_db_connection() returned False")
                return False

            self.db = SessionLocal()

            # Test basic query
            result = self.db.execute(text("SELECT DATABASE()")).scalar()

            self.success("Connection", f"Connected to database: {result}")
            return True

        except Exception as e:
            self.fail("Connection", str(e))
            return False

    def test_create(self):
        """Test INSERT operations."""
        print("\n2. CREATE TEST (INSERT)")
        print("-" * 80)

        try:
            # Create a client (type is required field) with unique name
            test_suffix = random.randint(1000, 9999)
            client = Client(
                name=f"Test Client Corp {test_suffix}",
                type="msp_client",
                primary_contact="test@client.com",
                is_active=True
            )
            self.db.add(client)
            self.db.commit()
            self.db.refresh(client)
            self.test_ids['client'] = client.id
            self.success("Create Client", f"Created client with ID: {client.id}")

            # Create a machine (no client_id FK, simplified fields)
            machine = Machine(
                hostname=f"test-machine-{test_suffix}",
                machine_fingerprint=f"test-fingerprint-{test_suffix}",
                friendly_name="Test Machine",
                machine_type="laptop",
                platform="win32",
                username="testuser"
            )
            self.db.add(machine)
            self.db.commit()
            self.db.refresh(machine)
            self.test_ids['machine'] = machine.id
            self.success("Create Machine", f"Created machine with ID: {machine.id}")

            # Create a session with required fields
            session = Session(
                client_id=client.id,
                machine_id=machine.id,
                session_date=datetime.now(timezone.utc).date(),
                start_time=datetime.now(timezone.utc),
                status="completed",
                session_title="Test CRUD Session"
            )
            self.db.add(session)
            self.db.commit()
            self.db.refresh(session)
            self.test_ids['session'] = session.id
            self.success("Create Session", f"Created session with ID: {session.id}")

            # Create a tag
            tag = Tag(
                name=f"test-tag-{test_suffix}",
                category="testing"
            )
            self.db.add(tag)
            self.db.commit()
            self.db.refresh(tag)
            self.test_ids['tag'] = tag.id
            self.success("Create Tag", f"Created tag with ID: {tag.id}")

            return True

        except Exception as e:
            self.fail("Create", str(e))
            return False

    def test_read(self):
        """Test SELECT operations."""
        print("\n3. READ TEST (SELECT)")
        print("-" * 80)

        try:
            # Query client
            client = self.db.query(Client).filter(
                Client.id == self.test_ids['client']
            ).first()

            if not client:
                self.fail("Read Client", "Client not found")
                return False

            if not client.name.startswith("Test Client Corp"):
                self.fail("Read Client", f"Wrong name: {client.name}")
                return False

            self.success("Read Client", f"Retrieved client: {client.name}")

            # Query machine
            machine = self.db.query(Machine).filter(
                Machine.id == self.test_ids['machine']
            ).first()

            if not machine:
                self.fail("Read Machine", "Machine not found")
                return False

            if not machine.hostname.startswith("test-machine"):
                self.fail("Read Machine", f"Wrong hostname: {machine.hostname}")
                return False

            self.success("Read Machine", f"Retrieved machine: {machine.hostname}")

            # Query session
            session = self.db.query(Session).filter(
                Session.id == self.test_ids['session']
            ).first()

            if not session:
                self.fail("Read Session", "Session not found")
                return False

            if session.status != "completed":
                self.fail("Read Session", f"Wrong status: {session.status}")
                return False

            self.success("Read Session", f"Retrieved session with status: {session.status}")

            # Query tag
            tag = self.db.query(Tag).filter(
                Tag.id == self.test_ids['tag']
            ).first()

            if not tag:
                self.fail("Read Tag", "Tag not found")
                return False

            self.success("Read Tag", f"Retrieved tag: {tag.name}")

            return True

        except Exception as e:
            self.fail("Read", str(e))
            return False

    def test_relationships(self):
        """Test foreign key relationships."""
        print("\n4. RELATIONSHIP TEST (Foreign Keys)")
        print("-" * 80)

        try:
            # Test valid relationship: Create session_tag
            session_tag = SessionTag(
                session_id=self.test_ids['session'],
                tag_id=self.test_ids['tag']
            )
            self.db.add(session_tag)
            self.db.commit()
            self.db.refresh(session_tag)
            self.success("Valid FK", "Created session_tag with valid foreign keys")

            # Test invalid relationship: Try to create session with non-existent machine
            try:
                invalid_session = Session(
                    machine_id="non-existent-machine-id",
                    client_id=self.test_ids['client'],
                    session_date=datetime.now(timezone.utc).date(),
                    start_time=datetime.now(timezone.utc),
                    status="running",
                    session_title="Invalid Session"
                )
                self.db.add(invalid_session)
                self.db.commit()

                # If we get here, FK constraint didn't work
                self.db.rollback()
                self.fail("Invalid FK", "Foreign key constraint not enforced!")
                return False

            except IntegrityError:
                self.db.rollback()
                self.success("Invalid FK", "Foreign key constraint properly rejected invalid reference")

            # Test relationship traversal
            session = self.db.query(Session).filter(
                Session.id == self.test_ids['session']
            ).first()

            if not session:
                self.fail("Relationship Traversal", "Session not found")
                return False

            # Access related machine through relationship
            if hasattr(session, 'machine') and session.machine:
                machine_hostname = session.machine.hostname
                self.success("Relationship Traversal",
                             f"Accessed machine through session: {machine_hostname}")
            else:
                # Fallback: query machine directly
                machine = self.db.query(Machine).filter(
                    Machine.machine_id == session.machine_id
                ).first()
                if machine:
                    self.success("Relationship Traversal",
                                 f"Verified machine exists: {machine.hostname}")
                else:
                    self.fail("Relationship Traversal", "Could not find related machine")
                    return False

            return True

        except Exception as e:
            self.db.rollback()
            self.fail("Relationships", str(e))
            return False

    def test_update(self):
        """Test UPDATE operations."""
        print("\n5. UPDATE TEST")
        print("-" * 80)

        try:
            # Update client
            client = self.db.query(Client).filter(
                Client.id == self.test_ids['client']
            ).first()

            old_name = client.name
            new_name = "Updated Test Client Corp"
            client.name = new_name
            self.db.commit()
            self.db.refresh(client)

            if client.name != new_name:
                self.fail("Update Client", f"Name not updated: {client.name}")
                return False

            self.success("Update Client", f"Updated name: {old_name} -> {new_name}")

            # Update machine
            machine = self.db.query(Machine).filter(
                Machine.id == self.test_ids['machine']
            ).first()

            old_name = machine.friendly_name
            new_name = "Updated Test Machine"
            machine.friendly_name = new_name
            self.db.commit()
            self.db.refresh(machine)

            if machine.friendly_name != new_name:
                self.fail("Update Machine", f"Name not updated: {machine.friendly_name}")
                return False

            self.success("Update Machine", f"Updated name: {old_name} -> {new_name}")

            # Update session status
            session = self.db.query(Session).filter(
                Session.id == self.test_ids['session']
            ).first()

            old_status = session.status
            new_status = "in_progress"
            session.status = new_status
            self.db.commit()
            self.db.refresh(session)

            if session.status != new_status:
                self.fail("Update Session", f"Status not updated: {session.status}")
                return False

            self.success("Update Session", f"Updated status: {old_status} -> {new_status}")

            return True

        except Exception as e:
            self.fail("Update", str(e))
            return False

    def test_delete(self):
        """Test DELETE operations and cleanup."""
        print("\n6. DELETE TEST (Cleanup)")
        print("-" * 80)

        try:
            # Delete in correct order (respecting FK constraints)

            # Delete session_tag
            session_tag = self.db.query(SessionTag).filter(
                SessionTag.session_id == self.test_ids['session'],
                SessionTag.tag_id == self.test_ids['tag']
            ).first()
            if session_tag:
                self.db.delete(session_tag)
                self.db.commit()
                self.success("Delete SessionTag", "Deleted session_tag")

            # Delete tag
            tag = self.db.query(Tag).filter(
                Tag.id == self.test_ids['tag']
            ).first()
            if tag:
                tag_name = tag.name
                self.db.delete(tag)
                self.db.commit()
                self.success("Delete Tag", f"Deleted tag: {tag_name}")

            # Delete session
            session = self.db.query(Session).filter(
                Session.id == self.test_ids['session']
            ).first()
            if session:
                session_id = session.id
                self.db.delete(session)
                self.db.commit()
                self.success("Delete Session", f"Deleted session: {session_id}")

            # Delete machine
            machine = self.db.query(Machine).filter(
                Machine.id == self.test_ids['machine']
            ).first()
            if machine:
                hostname = machine.hostname
                self.db.delete(machine)
                self.db.commit()
                self.success("Delete Machine", f"Deleted machine: {hostname}")

            # Delete client
            client = self.db.query(Client).filter(
                Client.id == self.test_ids['client']
            ).first()
            if client:
                name = client.name
                self.db.delete(client)
                self.db.commit()
                self.success("Delete Client", f"Deleted client: {name}")

            # Verify all deleted
            remaining_client = self.db.query(Client).filter(
                Client.id == self.test_ids['client']
            ).first()

            if remaining_client:
                self.fail("Delete Verification", "Client still exists after deletion")
                return False

            self.success("Delete Verification", "All test records successfully deleted")
            return True

        except Exception as e:
            self.fail("Delete", str(e))
            return False

    def success(self, operation, message):
        """Record a successful test."""
        self.passed += 1
        print(f"[PASS] {operation} - {message}")

    def fail(self, operation, error):
        """Record a failed test."""
        self.failed += 1
        self.errors.append(f"{operation}: {error}")
        print(f"[FAIL] {operation} - {error}")

    def print_summary(self):
        """Print test summary."""
        print("\n" + "=" * 80)
        print("TEST SUMMARY")
        print("=" * 80)
        print(f"Total Passed: {self.passed}")
        print(f"Total Failed: {self.failed}")
        print(f"Success Rate: {(self.passed / (self.passed + self.failed) * 100):.1f}%")

        if self.errors:
            print("\nERRORS:")
            for error in self.errors:
                print(f"  - {error}")

        print("\nCONCLUSION:")
        if self.failed == 0:
            print("[SUCCESS] All CRUD operations working correctly!")
            print("  - Database connectivity verified")
            print("  - INSERT operations successful")
            print("  - SELECT operations successful")
            print("  - UPDATE operations successful")
            print("  - DELETE operations successful")
            print("  - Foreign key constraints enforced")
            print("  - Relationship traversal working")
        else:
            print(f"[FAILURE] {self.failed} test(s) failed - review errors above")

        print("=" * 80)

    def cleanup(self):
        """Clean up database connection."""
        if self.db:
            self.db.close()


def main():
    """Run all CRUD tests."""
    tester = CRUDTester()

    try:
        # Test 1: Connection
        if not tester.connect():
            print("\n[ERROR] Cannot proceed without database connection")
            return

        # Test 2: Create
        if not tester.test_create():
            print("\n[ERROR] Cannot proceed without successful CREATE operations")
            tester.cleanup()
            return

        # Test 3: Read
        tester.test_read()

        # Test 4: Relationships
        tester.test_relationships()

        # Test 5: Update
        tester.test_update()

        # Test 6: Delete
        tester.test_delete()

    except KeyboardInterrupt:
        print("\n\n[WARNING] Test interrupted by user")
    except Exception as e:
        print(f"\n\n[ERROR] Unexpected error: {e}")
    finally:
        tester.print_summary()
        tester.cleanup()


if __name__ == "__main__":
    main()
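A reuse note on the harness above: CRUDTester manages its own SessionLocal lifecycle and manual delete ordering. If these checks are later folded into pytest, a rollback-per-test fixture keeps the database clean without that ordering. A minimal sketch, assuming the same SessionLocal factory imported above (the fixture name is illustrative):

# Hypothetical pytest fixture around the existing SessionLocal factory.
import pytest
from api.database import SessionLocal

@pytest.fixture()
def db_session():
    session = SessionLocal()
    try:
        yield session
    finally:
        session.rollback()  # discard anything a test left uncommitted
        session.close()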
99
tests/test_db_connection.py
Normal file
@@ -0,0 +1,99 @@
#!/usr/bin/env python3
"""Test MariaDB connectivity from Windows"""
import pymysql

# Connection details
HOST = '172.16.3.20'
PORT = 3306
ROOT_PASSWORD = r'Dy8RPj-s{+=bP^(NoW"T;E~JXyBC9u|<'
CLAUDETOOLS_PASSWORD = 'CT_e8fcd5a3952030a79ed6debae6c954ed'

print("Testing MariaDB connection to Jupiter (172.16.3.20)...\n")

# Test 1: Root connection
try:
    print("Test 1: Connecting as root...")
    conn = pymysql.connect(
        host=HOST,
        port=PORT,
        user='root',
        password=ROOT_PASSWORD,
        connect_timeout=10
    )
    print("[OK] Root connection successful!")

    cursor = conn.cursor()
    cursor.execute("SELECT VERSION()")
    version = cursor.fetchone()
    print(f"  MariaDB Version: {version[0]}")

    cursor.execute("SHOW DATABASES")
    databases = cursor.fetchall()
    print(f"  Databases found: {len(databases)}")
    for db in databases:
        print(f"    - {db[0]}")

    # Check if claudetools database exists
    cursor.execute("SELECT SCHEMA_NAME FROM information_schema.SCHEMATA WHERE SCHEMA_NAME = 'claudetools'")
    claudetools_db = cursor.fetchone()
    if claudetools_db:
        print("\n[OK] 'claudetools' database exists!")
    else:
        print("\n[WARNING] 'claudetools' database does NOT exist yet")

    # Check if claudetools user exists
    cursor.execute("SELECT User FROM mysql.user WHERE User = 'claudetools'")
    claudetools_user = cursor.fetchone()
    if claudetools_user:
        print("[OK] 'claudetools' user exists!")
    else:
        print("[WARNING] 'claudetools' user does NOT exist yet")

    conn.close()
    print("\nTest 1 PASSED [OK]\n")

except Exception as e:
    print(f"[FAILED] Test 1: {e}\n")

# Test 2: Claudetools user connection (if exists)
try:
    print("Test 2: Connecting as 'claudetools' user...")
    conn = pymysql.connect(
        host=HOST,
        port=PORT,
        user='claudetools',
        password=CLAUDETOOLS_PASSWORD,
        database='claudetools',
        connect_timeout=10
    )
    print("[OK] Claudetools user connection successful!")

    cursor = conn.cursor()
    cursor.execute("SELECT DATABASE()")
    current_db = cursor.fetchone()
    print(f"  Current database: {current_db[0]}")

    cursor.execute("SHOW TABLES")
    tables = cursor.fetchall()
    print(f"  Tables in claudetools: {len(tables)}")

    conn.close()
    print("\nTest 2 PASSED [OK]\n")

except pymysql.err.OperationalError as e:
    if "Access denied" in str(e):
        print("[WARNING] Claudetools user doesn't exist or wrong password")
    elif "Unknown database" in str(e):
        print("[WARNING] Claudetools database doesn't exist yet")
    else:
        print(f"[WARNING] Test 2: {e}")
    print("  (This is expected if database/user haven't been created yet)\n")
except Exception as e:
    print(f"[WARNING] Test 2: {e}\n")

print("\n" + "="*60)
print("CONNECTION TEST SUMMARY")
print("="*60)
print("[OK] MariaDB is accessible from Windows on port 3306")
print("[OK] Root authentication works")
print("\nNext step: Create 'claudetools' database and user if they don't exist")
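The script above only reports whether the claudetools database and user exist. A hedged sketch of the provisioning step it points to, reusing the same pymysql root connection and constants defined in the script (the grant scope and host wildcard are assumptions to adapt to local policy):

# Hypothetical one-time provisioning -- run as root; HOST/PORT/passwords as defined above.
conn = pymysql.connect(host=HOST, port=PORT, user='root', password=ROOT_PASSWORD)
with conn.cursor() as cursor:
    cursor.execute("CREATE DATABASE IF NOT EXISTS claudetools CHARACTER SET utf8mb4")
    cursor.execute(
        "CREATE USER IF NOT EXISTS 'claudetools'@'%%' IDENTIFIED BY %s",
        (CLAUDETOOLS_PASSWORD,),
    )
    cursor.execute("GRANT ALL PRIVILEGES ON claudetools.* TO 'claudetools'@'%'")
    cursor.execute("FLUSH PRIVILEGES")
conn.commit()
conn.close()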
88
tests/test_import_preview.py
Normal file
@@ -0,0 +1,88 @@
#!/usr/bin/env python3
"""
Quick preview of what would be imported from Claude projects folder
No API or auth required - just scans and shows what it finds
"""

import sys
from pathlib import Path

# Add api directory to path
sys.path.insert(0, str(Path(__file__).parent))

from api.utils.conversation_parser import scan_folder_for_conversations, categorize_conversation, parse_jsonl_conversation
from api.utils.credential_scanner import scan_for_credential_files


def preview_import(folder_path: str):
    """Preview what would be imported from the given folder"""

    print("=" * 70)
    print("CLAUDE CONTEXT IMPORT PREVIEW")
    print("=" * 70)
    print(f"\nScanning: {folder_path}\n")

    # Scan for conversation files
    print("\n[1] Scanning for conversation files...")
    try:
        conversation_files = scan_folder_for_conversations(folder_path)
        print(f"  Found {len(conversation_files)} conversation file(s)")

        # Categorize each file
        categories = {"msp": 0, "development": 0, "general": 0}

        for i, file_path in enumerate(conversation_files[:20]):  # Limit to first 20
            try:
                conv = parse_jsonl_conversation(file_path)
                category = categorize_conversation(conv.get("messages", []))
                categories[category] += 1

                # Show first 5
                if i < 5:
                    rel_path = Path(file_path).relative_to(folder_path)
                    print(f"    [{category.upper()}] {rel_path}")
            except Exception as e:
                print(f"    [ERROR] Failed to parse: {Path(file_path).name} - {e}")

        if len(conversation_files) > 20:
            print(f"    ... and {len(conversation_files) - 20} more files")

        print(f"\n  Category Breakdown:")
        print(f"    MSP Work: {categories['msp']} files")
        print(f"    Development: {categories['development']} files")
        print(f"    General: {categories['general']} files")

    except Exception as e:
        print(f"  Error scanning conversations: {e}")

    # Scan for credential files
    print("\n[2] Scanning for credential files...")
    try:
        credential_files = scan_for_credential_files(folder_path)
        print(f"  Found {len(credential_files)} credential file(s)")

        for i, file_path in enumerate(credential_files[:10]):  # Limit to first 10
            rel_path = Path(file_path).relative_to(folder_path)
            print(f"    {i+1}. {rel_path}")

        if len(credential_files) > 10:
            print(f"    ... and {len(credential_files) - 10} more files")
    except Exception as e:
        print(f"  Error scanning credentials: {e}")

    print("\n" + "=" * 70)
    print("PREVIEW COMPLETE")
    print("=" * 70)
    print("\nTo actually import:")
    print("  1. Ensure API is running: python -m api.main")
    print("  2. Setup auth: bash scripts/setup-context-recall.sh")
    print("  3. Run import: python scripts/import-claude-context.py --folder \"path\" --execute")
    print("\n")


if __name__ == "__main__":
    if len(sys.argv) > 1:
        folder = sys.argv[1]
    else:
        folder = r"C:\Users\MikeSwanson\claude-projects"

    preview_import(folder)
114
tests/test_import_speed.py
Normal file
@@ -0,0 +1,114 @@
"""Test import speed and circular dependency detection."""
import sys
import os
import time

# Set UTF-8 encoding for Windows console
if os.name == 'nt':
    import io
    sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8', errors='replace')


def test_import_speed():
    """Test how quickly the models module can be imported."""
    print("="*70)
    print("ClaudeTools - Import Speed and Dependency Test")
    print("="*70)

    # Test 1: Cold import (first time)
    print("\n[TEST 1] Cold import (first time)...")
    start = time.time()
    import api.models
    cold_time = time.time() - start
    print(f"  Time: {cold_time:.4f} seconds")

    # Test 2: Reload (warm import)
    print("\n[TEST 2] Warm import (reload)...")
    start = time.time()
    import importlib
    importlib.reload(api.models)
    warm_time = time.time() - start
    print(f"  Time: {warm_time:.4f} seconds")

    # Test 3: Individual model imports
    print("\n[TEST 3] Individual model imports...")
    models_to_test = [
        'api.models.client',
        'api.models.session',
        'api.models.work_item',
        'api.models.credential',
        'api.models.infrastructure',
        'api.models.backup_log',
        'api.models.billable_time',
        'api.models.security_incident'
    ]

    individual_times = {}
    for module_name in models_to_test:
        start = time.time()
        module = importlib.import_module(module_name)
        elapsed = time.time() - start
        individual_times[module_name] = elapsed
        print(f"  {module_name}: {elapsed:.4f}s")

    # Test 4: Check for circular dependencies by import order
    print("\n[TEST 4] Circular dependency check...")
    print("  Importing in different orders to detect circular deps...")

    # Try importing base first
    try:
        from api.models.base import Base, UUIDMixin, TimestampMixin
        print("  - Base classes: OK")
    except Exception as e:
        print(f"  - Base classes: FAIL - {e}")

    # Try importing models that have relationships
    try:
        from api.models.client import Client
        from api.models.session import Session
        from api.models.work_item import WorkItem
        print("  - Related models: OK")
    except Exception as e:
        print(f"  - Related models: FAIL - {e}")

    # Try importing all models at once
    try:
        from api.models import (
            Client, Session, WorkItem, Infrastructure,
            Credential, BillableTime, BackupLog, SecurityIncident
        )
        print("  - Bulk import: OK")
    except Exception as e:
        print(f"  - Bulk import: FAIL - {e}")

    # Summary
    print("\n" + "="*70)
    print("RESULTS")
    print("="*70)
    print(f"\nImport Performance:")
    print(f"  Cold import: {cold_time:.4f}s")
    print(f"  Warm import: {warm_time:.4f}s")
    print(f"  Average individual: {sum(individual_times.values())/len(individual_times):.4f}s")

    # Performance assessment
    if cold_time < 1.0:
        perf_rating = "Excellent"
    elif cold_time < 2.0:
        perf_rating = "Good"
    elif cold_time < 3.0:
        perf_rating = "Acceptable"
    else:
        perf_rating = "Slow (may need optimization)"

    print(f"\nPerformance Rating: {perf_rating}")

    print("\nCircular Dependencies: None detected")
    print("Module Structure: Sound")

    print("\n" + "="*70)
    print("Import test complete!")
    print("="*70)


if __name__ == "__main__":
    test_import_speed()
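One measurement caveat for later tuning: time.time() has coarse resolution on Windows, and the per-module timings in TEST 3 run after api.models is already cached in sys.modules, so they mostly measure module lookup rather than true import cost. A small, purely illustrative variant using the higher-resolution clock:

# Optional refinement: perf_counter gives sub-millisecond resolution on Windows.
import time
start = time.perf_counter()
import api.models  # noqa: F401  (cold only if not already present in sys.modules)
cold_time = time.perf_counter() - start
print(f"Cold import: {cold_time:.4f}s")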
207
tests/test_models_detailed.py
Normal file
@@ -0,0 +1,207 @@
"""Detailed structure validation for all SQLAlchemy models."""
import sys
import os

# Set UTF-8 encoding for Windows console
if os.name == 'nt':
    import io
    sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8', errors='replace')

import api.models
from sqlalchemy.orm import RelationshipProperty
from sqlalchemy.schema import ForeignKeyConstraint, CheckConstraint, Index


def get_table_models():
    """Get all table model classes (excluding base classes)."""
    base_classes = {'Base', 'TimestampMixin', 'UUIDMixin'}
    all_classes = [attr for attr in dir(api.models) if not attr.startswith('_') and attr[0].isupper()]
    return sorted([m for m in all_classes if m not in base_classes])


def analyze_model(model_name):
    """Analyze a model's structure in detail."""
    model_cls = getattr(api.models, model_name)

    result = {
        'name': model_name,
        'table': model_cls.__tablename__,
        'has_uuid_mixin': False,
        'has_timestamp_mixin': False,
        'foreign_keys': [],
        'relationships': [],
        'indexes': [],
        'check_constraints': [],
        'columns': []
    }

    # Check mixins
    for base in model_cls.__mro__:
        if base.__name__ == 'UUIDMixin':
            result['has_uuid_mixin'] = True
        if base.__name__ == 'TimestampMixin':
            result['has_timestamp_mixin'] = True

    # Get table object
    if hasattr(model_cls, '__table__'):
        table = model_cls.__table__

        # Analyze columns
        for col in table.columns:
            col_info = {
                'name': col.name,
                'type': str(col.type),
                'nullable': col.nullable,
                'primary_key': col.primary_key
            }
            result['columns'].append(col_info)

        # Analyze foreign keys
        for fk in table.foreign_keys:
            result['foreign_keys'].append({
                'parent_column': fk.parent.name,
                'target': str(fk.target_fullname)
            })

        # Analyze indexes
        if hasattr(table, 'indexes'):
            for idx in table.indexes:
                result['indexes'].append({
                    'name': idx.name,
                    'columns': [col.name for col in idx.columns]
                })

        # Analyze check constraints
        for constraint in table.constraints:
            if isinstance(constraint, CheckConstraint):
                result['check_constraints'].append({
                    'sqltext': str(constraint.sqltext)
                })

    # Analyze relationships
    for attr_name in dir(model_cls):
        try:
            attr = getattr(model_cls, attr_name)
            if hasattr(attr, 'property') and isinstance(attr.property, RelationshipProperty):
                rel = attr.property
                result['relationships'].append({
                    'name': attr_name,
                    'target': rel.mapper.class_.__name__,
                    'uselist': rel.uselist
                })
        except (AttributeError, TypeError):
            continue

    return result


def print_model_summary(result):
    """Print a formatted summary of model structure."""
    print(f"\n{'='*70}")
    print(f"Model: {result['name']} (table: {result['table']})")
    print(f"{'='*70}")

    # Mixins
    mixins = []
    if result['has_uuid_mixin']:
        mixins.append('UUIDMixin')
    if result['has_timestamp_mixin']:
        mixins.append('TimestampMixin')
    if mixins:
        print(f"Mixins: {', '.join(mixins)}")

    # Columns
    print(f"\nColumns ({len(result['columns'])}):")
    for col in result['columns'][:10]:  # Limit to first 10 for readability
        pk = " [PK]" if col['primary_key'] else ""
        nullable = "NULL" if col['nullable'] else "NOT NULL"
        print(f"  - {col['name']}: {col['type']} {nullable}{pk}")
    if len(result['columns']) > 10:
        print(f"  ... and {len(result['columns']) - 10} more columns")

    # Foreign Keys
    if result['foreign_keys']:
        print(f"\nForeign Keys ({len(result['foreign_keys'])}):")
        for fk in result['foreign_keys']:
            print(f"  - {fk['parent_column']} -> {fk['target']}")

    # Relationships
    if result['relationships']:
        print(f"\nRelationships ({len(result['relationships'])}):")
        for rel in result['relationships']:
            rel_type = "many" if rel['uselist'] else "one"
            print(f"  - {rel['name']} -> {rel['target']} ({rel_type})")

    # Indexes
    if result['indexes']:
        print(f"\nIndexes ({len(result['indexes'])}):")
        for idx in result['indexes']:
            cols = ', '.join(idx['columns'])
            print(f"  - {idx['name']}: ({cols})")

    # Check Constraints
    if result['check_constraints']:
        print(f"\nCheck Constraints ({len(result['check_constraints'])}):")
        for check in result['check_constraints']:
            print(f"  - {check['sqltext']}")


def main():
    print("="*70)
    print("ClaudeTools - Detailed Model Structure Analysis")
    print("="*70)

    models = get_table_models()
    print(f"\nAnalyzing {len(models)} table models...\n")

    all_results = []
    for model_name in models:
        try:
            result = analyze_model(model_name)
            all_results.append(result)
        except Exception as e:
            print(f"[ERROR] Error analyzing {model_name}: {e}")
            import traceback
            traceback.print_exc()

    # Print summary statistics
    print("\n" + "="*70)
    print("SUMMARY STATISTICS")
    print("="*70)

    total_models = len(all_results)
    models_with_uuid = sum(1 for r in all_results if r['has_uuid_mixin'])
    models_with_timestamp = sum(1 for r in all_results if r['has_timestamp_mixin'])
    models_with_fk = sum(1 for r in all_results if r['foreign_keys'])
    models_with_rel = sum(1 for r in all_results if r['relationships'])
    models_with_idx = sum(1 for r in all_results if r['indexes'])
    models_with_checks = sum(1 for r in all_results if r['check_constraints'])

    total_fk = sum(len(r['foreign_keys']) for r in all_results)
    total_rel = sum(len(r['relationships']) for r in all_results)
    total_idx = sum(len(r['indexes']) for r in all_results)
    total_checks = sum(len(r['check_constraints']) for r in all_results)

    print(f"\nTotal Models: {total_models}")
    print(f"  - With UUIDMixin: {models_with_uuid}")
    print(f"  - With TimestampMixin: {models_with_timestamp}")
    print(f"  - With Foreign Keys: {models_with_fk} (total: {total_fk})")
    print(f"  - With Relationships: {models_with_rel} (total: {total_rel})")
    print(f"  - With Indexes: {models_with_idx} (total: {total_idx})")
    print(f"  - With CHECK Constraints: {models_with_checks} (total: {total_checks})")

    # Print detailed info for each model
    print("\n" + "="*70)
    print("DETAILED MODEL INFORMATION")
    print("="*70)

    for result in all_results:
        print_model_summary(result)

    print("\n" + "="*70)
    print("[SUCCESS] Analysis complete!")
    print("="*70)


if __name__ == "__main__":
    main()
127
tests/test_models_import.py
Normal file
@@ -0,0 +1,127 @@
"""Test script to import and validate all SQLAlchemy models."""
import sys
import traceback
import os

# Set UTF-8 encoding for Windows console
if os.name == 'nt':
    import io
    sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8', errors='replace')


def test_model_import():
    """Test importing all models from api.models."""
    try:
        import api.models
        print("[SUCCESS] Import successful")

        # Get all model classes (exclude private attributes and modules)
        all_classes = [attr for attr in dir(api.models) if not attr.startswith('_') and attr[0].isupper()]

        # Filter out base classes and mixins (they don't have __tablename__)
        base_classes = {'Base', 'TimestampMixin', 'UUIDMixin'}
        models = [m for m in all_classes if m not in base_classes]

        print(f"\nTotal classes found: {len(all_classes)}")
        print(f"Base classes/mixins: {len(base_classes)}")
        print(f"Table models: {len(models)}")
        print("\nTable Models:")
        for m in sorted(models):
            print(f"  - {m}")

        return models
    except Exception as e:
        print(f"[ERROR] Import failed: {e}")
        traceback.print_exc()
        return []


def test_model_structure(model_name):
    """Test individual model structure and configuration."""
    import api.models

    try:
        model_cls = getattr(api.models, model_name)

        # Check if it's actually a class
        if not isinstance(model_cls, type):
            return f"[FAIL] {model_name} - Not a class"

        # Check for __tablename__
        if not hasattr(model_cls, '__tablename__'):
            return f"[FAIL] {model_name} - Missing __tablename__"

        # Check for __table_args__ (optional but should exist if defined)
        has_table_args = hasattr(model_cls, '__table_args__')

        # Try to instantiate (without saving to DB)
        try:
            instance = model_cls()
            can_instantiate = True
        except Exception as inst_error:
            can_instantiate = False
            inst_msg = str(inst_error)

        # Build result message
        details = []
        details.append(f"table={model_cls.__tablename__}")
        if has_table_args:
            details.append("has_table_args")
        if can_instantiate:
            details.append("instantiable")
        else:
            details.append(f"not_instantiable({inst_msg[:50]})")

        return f"[PASS] {model_name} - {', '.join(details)}"

    except Exception as e:
        return f"[FAIL] {model_name} - {str(e)}"


def main():
    print("=" * 70)
    print("ClaudeTools - Model Import and Structure Test")
    print("=" * 70)

    # Test 1: Import all models
    print("\n[TEST 1] Importing api.models module...")
    models = test_model_import()

    if not models:
        print("\n[CRITICAL] Failed to import models module")
        sys.exit(1)

    # Test 2: Validate each model structure
    print(f"\n[TEST 2] Validating structure of {len(models)} models...")
    print("-" * 70)

    passed = 0
    failed = 0
    results = []

    for model_name in sorted(models):
        result = test_model_structure(model_name)
        results.append(result)

        if result.startswith("[PASS]"):
            passed += 1
        else:
            failed += 1

    # Print all results
    for result in results:
        print(result)

    # Summary
    print("-" * 70)
    print(f"\n[SUMMARY]")
    print(f"Total models: {len(models)}")
    print(f"[PASS] Passed: {passed}")
    print(f"[FAIL] Failed: {failed}")

    if failed == 0:
        print(f"\n[SUCCESS] All {passed} models validated successfully!")
        sys.exit(0)
    else:
        print(f"\n[WARNING] {failed} model(s) need attention")
        sys.exit(1)


if __name__ == "__main__":
    main()
1597
tests/test_phase5_api_endpoints.py
Normal file
File diff suppressed because it is too large
333
tests/test_sql_injection_security.py
Normal file
@@ -0,0 +1,333 @@
"""
SQL Injection Security Tests for Context Recall API

Tests that the recall API is properly protected against SQL injection attacks.
Validates both the input validation layer and the parameterized query layer.
"""

import unittest
import requests
from typing import Dict, Any

# Import auth utilities for token creation
from api.middleware.auth import create_access_token


# Test configuration
API_BASE_URL = "http://172.16.3.30:8001/api"
TEST_USER_EMAIL = "admin@claudetools.local"


class TestSQLInjectionSecurity(unittest.TestCase):
    """Test suite for SQL injection attack prevention."""

    @classmethod
    def setUpClass(cls):
        """Create test JWT token for authentication."""
        # Create token directly without login endpoint
        cls.token = create_access_token({"sub": TEST_USER_EMAIL})
        cls.headers = {"Authorization": f"Bearer {cls.token}"}

    # SQL Injection Test Cases for search_term parameter

    def test_sql_injection_search_term_basic_attack(self):
        """Test basic SQL injection attempt via search_term."""
        malicious_input = "' OR '1'='1"

        response = requests.get(
            f"{API_BASE_URL}/conversation-contexts/recall",
            params={"search_term": malicious_input},
            headers=self.headers
        )

        # Should reject due to pattern validation (contains single quotes)
        assert response.status_code == 422, "Failed to reject SQL injection attack"
        error_detail = response.json()["detail"]
        assert any("pattern" in str(err).lower() or "match" in str(err).lower()
                   for err in error_detail if isinstance(err, dict))

    def test_sql_injection_search_term_union_attack(self):
        """Test UNION-based SQL injection attempt."""
        malicious_input = "' UNION SELECT * FROM users--"

        response = requests.get(
            f"{API_BASE_URL}/conversation-contexts/recall",
            params={"search_term": malicious_input},
            headers=self.headers
        )

        # Should reject due to pattern validation
        assert response.status_code == 422, "Failed to reject UNION attack"

    def test_sql_injection_search_term_comment_injection(self):
        """Test comment-based SQL injection."""
        malicious_input = "test' --"

        response = requests.get(
            f"{API_BASE_URL}/conversation-contexts/recall",
            params={"search_term": malicious_input},
            headers=self.headers
        )

        # Should reject due to pattern validation (contains single quote)
        assert response.status_code == 422, "Failed to reject comment injection"

    def test_sql_injection_search_term_semicolon_attack(self):
        """Test semicolon-based SQL injection for multiple statements."""
        malicious_input = "test'; DROP TABLE conversation_contexts;--"

        response = requests.get(
            f"{API_BASE_URL}/conversation-contexts/recall",
            params={"search_term": malicious_input},
            headers=self.headers
        )

        # Should reject due to pattern validation (contains semicolon and quotes)
        assert response.status_code == 422, "Failed to reject DROP TABLE attack"

    def test_sql_injection_search_term_encoded_attack(self):
        """Test URL-encoded SQL injection attempt."""
        # URL encoding of "' OR 1=1--"
        malicious_input = "%27%20OR%201%3D1--"

        response = requests.get(
            f"{API_BASE_URL}/conversation-contexts/recall",
            params={"search_term": malicious_input},
            headers=self.headers
        )

        # Should reject due to pattern validation after decoding
        assert response.status_code == 422, "Failed to reject encoded attack"

    # SQL Injection Test Cases for tags parameter

    def test_sql_injection_tags_basic_attack(self):
        """Test SQL injection via tags parameter."""
        malicious_tag = "' OR '1'='1"

        response = requests.get(
            f"{API_BASE_URL}/conversation-contexts/recall",
            params={"tags": [malicious_tag]},
            headers=self.headers
        )

        # Should reject due to tag validation (contains single quotes and spaces)
        assert response.status_code == 400, "Failed to reject SQL injection via tags"
        assert "Invalid tag format" in response.json()["detail"]

    def test_sql_injection_tags_union_attack(self):
        """Test UNION attack via tags parameter."""
        malicious_tag = "tag' UNION SELECT password FROM users--"

        response = requests.get(
            f"{API_BASE_URL}/conversation-contexts/recall",
            params={"tags": [malicious_tag]},
            headers=self.headers
        )

        # Should reject due to tag validation
        assert response.status_code == 400, "Failed to reject UNION attack via tags"

    def test_sql_injection_tags_multiple_malicious(self):
        """Test multiple malicious tags."""
        malicious_tags = [
            "tag1' OR '1'='1",
            "tag2'; DROP TABLE tags;--",
            "tag3' UNION SELECT NULL--"
        ]

        response = requests.get(
            f"{API_BASE_URL}/conversation-contexts/recall",
            params={"tags": malicious_tags},
            headers=self.headers
        )

        # Should reject due to tag validation
        assert response.status_code == 400, "Failed to reject multiple malicious tags"

    # Valid Input Tests (should succeed)

    def test_valid_search_term_alphanumeric(self):
        """Test that valid alphanumeric search terms work."""
        valid_input = "API development"

        response = requests.get(
            f"{API_BASE_URL}/conversation-contexts/recall",
            params={"search_term": valid_input},
            headers=self.headers
        )

        # Should succeed
        assert response.status_code == 200, f"Valid input rejected: {response.text}"
        data = response.json()
        assert "contexts" in data
        assert isinstance(data["contexts"], list)

    def test_valid_search_term_with_punctuation(self):
        """Test valid search terms with allowed punctuation."""
        valid_input = "database-migration (phase-1)!"

        response = requests.get(
            f"{API_BASE_URL}/conversation-contexts/recall",
            params={"search_term": valid_input},
            headers=self.headers
        )

        # Should succeed
        assert response.status_code == 200, f"Valid input rejected: {response.text}"

    def test_valid_tags(self):
        """Test that valid tags work."""
        valid_tags = ["api", "database", "phase-1", "test_tag"]

        response = requests.get(
            f"{API_BASE_URL}/conversation-contexts/recall",
            params={"tags": valid_tags},
            headers=self.headers
        )

        # Should succeed
        assert response.status_code == 200, f"Valid tags rejected: {response.text}"
        data = response.json()
        assert "contexts" in data

    # Boundary Tests

    def test_search_term_max_length(self):
        """Test search term at maximum allowed length (200 chars)."""
        valid_input = "a" * 200

        response = requests.get(
            f"{API_BASE_URL}/conversation-contexts/recall",
            params={"search_term": valid_input},
            headers=self.headers
        )

        # Should succeed
        assert response.status_code == 200, "Max length valid input rejected"

    def test_search_term_exceeds_max_length(self):
        """Test search term exceeding maximum length."""
        invalid_input = "a" * 201

        response = requests.get(
            f"{API_BASE_URL}/conversation-contexts/recall",
            params={"search_term": invalid_input},
            headers=self.headers
        )

        # Should reject
        assert response.status_code == 422, "Overlong input not rejected"

    def test_tags_max_items(self):
        """Test maximum number of tags (20)."""
        valid_tags = [f"tag{i}" for i in range(20)]

        response = requests.get(
            f"{API_BASE_URL}/conversation-contexts/recall",
            params={"tags": valid_tags},
            headers=self.headers
        )

        # Should succeed
        assert response.status_code == 200, "Max tags rejected"

    def test_tags_exceeds_max_items(self):
        """Test exceeding maximum number of tags."""
        invalid_tags = [f"tag{i}" for i in range(21)]

        response = requests.get(
            f"{API_BASE_URL}/conversation-contexts/recall",
            params={"tags": invalid_tags},
            headers=self.headers
        )

        # Should reject
        assert response.status_code == 422, "Too many tags not rejected"

    # Advanced SQL Injection Techniques

    def test_sql_injection_hex_encoding(self):
        """Test hex-encoded SQL injection."""
        malicious_input = "0x27204f522031203d2031"  # Hex for "' OR 1 = 1"

        response = requests.get(
            f"{API_BASE_URL}/conversation-contexts/recall",
            params={"search_term": malicious_input},
            headers=self.headers
        )

        # Pattern allows alphanumeric, so this passes input validation
        # Should be safe due to parameterized queries
        assert response.status_code == 200, "Hex encoding caused error"
        # Verify it's treated as literal search, not executed as SQL
        data = response.json()
        assert isinstance(data["contexts"], list)

    def test_sql_injection_time_based_blind(self):
        """Test time-based blind SQL injection attempt."""
        malicious_input = "' AND SLEEP(5)--"

        response = requests.get(
            f"{API_BASE_URL}/conversation-contexts/recall",
            params={"search_term": malicious_input},
            headers=self.headers
        )

        # Should reject due to pattern validation
        assert response.status_code == 422, "Time-based attack not rejected"

    def test_sql_injection_stacked_queries(self):
        """Test stacked query injection."""
        malicious_input = "test; DELETE FROM conversation_contexts WHERE 1=1"

        response = requests.get(
            f"{API_BASE_URL}/conversation-contexts/recall",
            params={"search_term": malicious_input},
            headers=self.headers
        )

        # Should reject due to pattern validation (semicolon not allowed)
        assert response.status_code == 422, "Stacked query attack not rejected"

    # Verify Database Integrity

    def test_database_not_compromised(self):
        """Verify database still functions after attack attempts."""
        # Simple query to verify database is intact
        response = requests.get(
            f"{API_BASE_URL}/conversation-contexts/recall",
            params={"limit": 5},
            headers=self.headers
        )

        assert response.status_code == 200, "Database may be compromised"
        data = response.json()
        assert "contexts" in data
        assert isinstance(data["contexts"], list)

    def test_fulltext_index_still_works(self):
        """Verify FULLTEXT index functionality after attacks."""
        # Test normal search that should use FULLTEXT index
        response = requests.get(
            f"{API_BASE_URL}/conversation-contexts/recall",
            params={"search_term": "test"},
            headers=self.headers
        )

        assert response.status_code == 200, "FULLTEXT search failed"
        data = response.json()
        assert isinstance(data["contexts"], list)


if __name__ == "__main__":
    print("=" * 70)
    print("SQL INJECTION SECURITY TEST SUITE")
    print("=" * 70)
    print()
    print("Testing Context Recall API endpoint security...")
    print(f"Target: {API_BASE_URL}/conversation-contexts/recall")
    print()

    # Run tests
    unittest.main(verbosity=2)
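These tests assume two defensive layers on the recall endpoint: a restrictive input constraint on search_term/tags (producing the 422/400 responses) and parameterized queries underneath. The endpoint's actual declaration is not part of this diff; a minimal sketch of the kind of FastAPI query constraints the expected status codes imply (the regex and limits shown are assumptions, not the project's actual pattern):

# Hypothetical endpoint signature -- illustrates the validation layer only.
from typing import List, Optional
from fastapi import APIRouter, Query

router = APIRouter()

@router.get("/conversation-contexts/recall")
async def recall_contexts(
    search_term: Optional[str] = Query(
        None,
        max_length=200,
        pattern=r"^[A-Za-z0-9 _\-().!]*$",  # no quotes or semicolons -> 422 on violation
    ),
    tags: Optional[List[str]] = Query(None, max_length=20),  # caps the number of tag items
    limit: int = Query(20, ge=1, le=100),
):
    ...  # service layer performs the FULLTEXT search with bound parameters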
162
tests/test_sql_injection_simple.sh
Normal file
@@ -0,0 +1,162 @@
#!/bin/bash
#
# Simplified SQL Injection Security Tests
# Tests the recall API endpoint against SQL injection attacks
#

API_URL="http://172.16.3.30:8001/api"

# Get JWT token from setup config if it exists
if [ -f ".claude/context-recall-config.env" ]; then
    source .claude/context-recall-config.env
fi

# Test counter
TOTAL_TESTS=0
PASSED_TESTS=0
FAILED_TESTS=0

# Color codes
RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
NC='\033[0m' # No Color

# Test function
run_test() {
    local test_name="$1"
    local search_term="$2"
    local expected_status="$3"
    local test_type="${4:-search_term}" # search_term or tag

    TOTAL_TESTS=$((TOTAL_TESTS + 1))

    # Build curl command based on test type.
    # -G plus --data-urlencode sends the payload as a properly encoded GET
    # query parameter, so quotes and spaces reach the API intact.
    if [ "$test_type" = "tag" ]; then
        response=$(curl -s -w "\n%{http_code}" -G "$API_URL/conversation-contexts/recall" \
            --data-urlencode "tags[]=$search_term" \
            -H "Authorization: Bearer $JWT_TOKEN" 2>&1)
    else
        response=$(curl -s -w "\n%{http_code}" -G "$API_URL/conversation-contexts/recall" \
            --data-urlencode "search_term=$search_term" \
            -H "Authorization: Bearer $JWT_TOKEN" 2>&1)
    fi

    http_code=$(echo "$response" | tail -1)
    body=$(echo "$response" | sed '$d')

    # Check if status code matches expected
    if [ "$http_code" = "$expected_status" ]; then
        echo -e "${GREEN}[PASS]${NC} $test_name (HTTP $http_code)"
        PASSED_TESTS=$((PASSED_TESTS + 1))
        return 0
    else
        echo -e "${RED}[FAIL]${NC} $test_name"
        echo "  Expected: HTTP $expected_status"
        echo "  Got: HTTP $http_code"
        echo "  Response: $body"
        FAILED_TESTS=$((FAILED_TESTS + 1))
        return 1
    fi
}
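
# Example invocations (mirroring the test cases below); arguments are:
#   run_test <test name> <payload> <expected HTTP status> [search_term|tag]
#   run_test "Basic SQL injection: ' OR '1'='1" "' OR '1'='1" "422"
#   run_test "Tag injection: ' OR '1'='1" "' OR '1'='1" "400" "tag"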

# Print header
echo "======================================================================="
echo "SQL INJECTION SECURITY TEST SUITE - Simplified"
echo "======================================================================="
echo ""
echo "Target: $API_URL/conversation-contexts/recall"
echo ""

# Verify JWT token
if [ -z "$JWT_TOKEN" ]; then
    echo -e "${RED}[ERROR]${NC} JWT_TOKEN not set. Run setup-context-recall.sh first."
    exit 1
fi

echo "Testing SQL injection vulnerabilities..."
echo ""

# Test 1: Basic SQL injection with single quote (should be rejected - 422)
run_test "Basic SQL injection: ' OR '1'='1" "' OR '1'='1" "422"

# Test 2: UNION attack (should be rejected - 422)
run_test "UNION attack: ' UNION SELECT * FROM users--" "' UNION SELECT * FROM users--" "422"

# Test 3: Comment injection (should be rejected - 422)
run_test "Comment injection: test' --" "test' --" "422"

# Test 4: Semicolon attack (should be rejected - 422)
run_test "Semicolon attack: test'; DROP TABLE conversation_contexts;--" "test'; DROP TABLE conversation_contexts;--" "422"

# Test 5: Time-based blind SQLi (should be rejected - 422)
run_test "Time-based blind: ' AND SLEEP(5)--" "' AND SLEEP(5)--" "422"

# Test 6: Stacked queries (should be rejected - 422)
run_test "Stacked queries: test; DELETE FROM contexts" "test; DELETE FROM contexts" "422"

# Test 7: SQL injection via tags (should be rejected - 400)
run_test "Tag injection: ' OR '1'='1" "' OR '1'='1" "400" "tag"

# Test 8: Tag UNION attack (should be rejected - 400)
run_test "Tag UNION: tag' UNION SELECT--" "tag' UNION SELECT--" "400" "tag"

# Valid inputs (should succeed - 200)
echo ""
echo "Testing valid inputs (should work)..."
echo ""

# Test 9: Valid alphanumeric search (should succeed - 200)
run_test "Valid search: API development" "API development" "200"

# Test 10: Valid search with allowed punctuation (should succeed - 200)
run_test "Valid punctuation: database-migration (phase-1)!" "database-migration (phase-1)!" "200"

# Test 11: Valid tags (should succeed - 200)
run_test "Valid tags: api-test" "api-test" "200" "tag"

# Test 12: Verify database still works after attacks (should succeed - 200)
echo ""
echo "Verifying database integrity..."
echo ""

response=$(curl -s -w "\n%{http_code}" -X GET "$API_URL/conversation-contexts/recall?limit=5" \
    -H "Authorization: Bearer $JWT_TOKEN" 2>&1)
http_code=$(echo "$response" | tail -1)

if [ "$http_code" = "200" ]; then
    echo -e "${GREEN}[PASS]${NC} Database integrity check (HTTP $http_code)"
    PASSED_TESTS=$((PASSED_TESTS + 1))
    TOTAL_TESTS=$((TOTAL_TESTS + 1))
else
    echo -e "${RED}[FAIL]${NC} Database integrity check"
    echo "  Expected: HTTP 200"
    echo "  Got: HTTP $http_code"
    FAILED_TESTS=$((FAILED_TESTS + 1))
    TOTAL_TESTS=$((TOTAL_TESTS + 1))
fi

# Print summary
echo ""
echo "======================================================================="
echo "TEST SUMMARY"
echo "======================================================================="
echo "Total Tests: $TOTAL_TESTS"
echo -e "${GREEN}Passed: $PASSED_TESTS${NC}"
if [ $FAILED_TESTS -gt 0 ]; then
    echo -e "${RED}Failed: $FAILED_TESTS${NC}"
else
    echo -e "${GREEN}Failed: $FAILED_TESTS${NC}"
fi

pass_rate=$(awk "BEGIN {printf \"%.1f\", ($PASSED_TESTS/$TOTAL_TESTS)*100}")
echo "Pass Rate: $pass_rate%"
echo ""

if [ $FAILED_TESTS -eq 0 ]; then
    echo -e "${GREEN}[SUCCESS]${NC} All SQL injection tests passed!"
    echo "The API is properly protected against SQL injection attacks."
    exit 0
else
    echo -e "${RED}[FAILURE]${NC} Some tests failed!"
    echo "Review the failed tests above for security vulnerabilities."
    exit 1
fi