feat: Major directory reorganization and cleanup

Reorganized project structure for better maintainability and reduced
disk usage by 95.9% (11 GB -> 451 MB).

Directory Reorganization (85% reduction in root files):
- Created docs/ with subdirectories (deployment, testing, database, etc.)
- Created infrastructure/vpn-configs/ for VPN scripts
- Moved 90+ files from root to organized locations
- Archived obsolete documentation (context system, offline mode, zombie debugging)
- Moved all test files to tests/ directory
- Root directory: 119 files -> 18 files

Disk Cleanup (10.55 GB recovered):
- Deleted Rust build artifacts: 9.6 GB (target/ directories)
- Deleted Python virtual environments: 161 MB (venv/ directories)
- Deleted Python cache: 50 KB (__pycache__/)

New Structure:
- docs/ - All documentation organized by category
- docs/archives/ - Obsolete but preserved documentation
- infrastructure/ - VPN configs and SSH setup
- tests/ - All test files consolidated
- logs/ - Ready for future logs

Benefits:
- Cleaner root directory (18 vs 119 files)
- Logical organization of documentation
- 95.9% disk space reduction
- Faster navigation and discovery
- Better portability (build artifacts excluded)

Build artifacts can be regenerated:
- Rust: cargo build --release (5-15 min per project)
- Python: pip install -r requirements.txt (2-3 min)

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
This commit is contained in:
2026-01-18 20:42:28 -07:00
parent 89e5118306
commit 06f7617718
96 changed files with 54 additions and 2639 deletions

821
tests/test_api_endpoints.py Normal file
View File

@@ -0,0 +1,821 @@
"""
Comprehensive API Endpoint Tests for ClaudeTools FastAPI Application
This test suite validates all 5 core API endpoints:
- Machines
- Clients
- Projects
- Sessions
- Tags
Tests include:
- API startup and health checks
- CRUD operations for all entities
- Authentication (with and without JWT tokens)
- Pagination parameters
- Error handling (404, 409, 422 responses)
"""
import sys
from datetime import timedelta
from uuid import uuid4
from fastapi.testclient import TestClient
# Import the FastAPI app and auth utilities
from api.main import app
from api.middleware.auth import create_access_token
# Create test client
client = TestClient(app)
# Test counters
tests_passed = 0
tests_failed = 0
test_results = []
def log_test(test_name: str, passed: bool, error_msg: str = ""):
"""Log test result and update counters."""
global tests_passed, tests_failed
if passed:
tests_passed += 1
status = "PASS"
symbol = "[+]"
else:
tests_failed += 1
status = "FAIL"
symbol = "[-]"
result = f"{symbol} {status}: {test_name}"
if error_msg:
result += f"\n Error: {error_msg}"
test_results.append((test_name, passed, error_msg))
print(result)
def create_test_token():
"""Create a test JWT token for authentication."""
token_data = {
"sub": "test_user@claudetools.com",
"scopes": ["msp:read", "msp:write", "msp:admin"]
}
return create_access_token(token_data, expires_delta=timedelta(hours=1))
def get_auth_headers():
"""Get authorization headers with test token."""
token = create_test_token()
return {"Authorization": f"Bearer {token}"}
# ============================================================================
# SECTION 1: API Health and Startup Tests
# ============================================================================
print("\n" + "="*70)
print("SECTION 1: API Health and Startup Tests")
print("="*70 + "\n")
def test_root_endpoint():
"""Test root endpoint returns API status."""
try:
response = client.get("/")
assert response.status_code == 200, f"Expected 200, got {response.status_code}"
data = response.json()
assert data["status"] == "online", f"Expected status 'online', got {data.get('status')}"
assert "service" in data, "Response missing 'service' field"
assert "version" in data, "Response missing 'version' field"
log_test("Root endpoint (/)", True)
except Exception as e:
log_test("Root endpoint (/)", False, str(e))
def test_health_endpoint():
"""Test health check endpoint."""
try:
response = client.get("/health")
assert response.status_code == 200, f"Expected 200, got {response.status_code}"
data = response.json()
assert data["status"] == "healthy", f"Expected status 'healthy', got {data.get('status')}"
log_test("Health check endpoint (/health)", True)
except Exception as e:
log_test("Health check endpoint (/health)", False, str(e))
def test_jwt_token_creation():
"""Test JWT token creation."""
try:
token = create_test_token()
assert token is not None, "Token creation returned None"
assert len(token) > 20, "Token seems too short"
log_test("JWT token creation", True)
except Exception as e:
log_test("JWT token creation", False, str(e))
# ============================================================================
# SECTION 2: Authentication Tests
# ============================================================================
print("\n" + "="*70)
print("SECTION 2: Authentication Tests")
print("="*70 + "\n")
def test_unauthenticated_access():
"""Test that protected endpoints reject requests without auth."""
try:
response = client.get("/api/machines")
# Can be 401 (Unauthorized) or 403 (Forbidden) depending on implementation
assert response.status_code in [401, 403], f"Expected 401 or 403, got {response.status_code}"
log_test("Unauthenticated access rejected", True)
except Exception as e:
log_test("Unauthenticated access rejected", False, str(e))
def test_authenticated_access():
"""Test that protected endpoints accept valid JWT tokens."""
try:
headers = get_auth_headers()
response = client.get("/api/machines", headers=headers)
assert response.status_code == 200, f"Expected 200, got {response.status_code}"
log_test("Authenticated access accepted", True)
except Exception as e:
log_test("Authenticated access accepted", False, str(e))
def test_invalid_token():
"""Test that invalid tokens are rejected."""
try:
headers = {"Authorization": "Bearer invalid_token_string"}
response = client.get("/api/machines", headers=headers)
assert response.status_code == 401, f"Expected 401, got {response.status_code}"
log_test("Invalid token rejected", True)
except Exception as e:
log_test("Invalid token rejected", False, str(e))
# ============================================================================
# SECTION 3: Machine CRUD Operations
# ============================================================================
print("\n" + "="*70)
print("SECTION 3: Machine CRUD Operations")
print("="*70 + "\n")
machine_id = None
def test_create_machine():
"""Test creating a new machine."""
global machine_id
try:
headers = get_auth_headers()
machine_data = {
"hostname": f"test-machine-{uuid4().hex[:8]}",
"friendly_name": "Test Machine",
"machine_type": "laptop",
"platform": "win32",
"is_active": True
}
response = client.post("/api/machines", json=machine_data, headers=headers)
assert response.status_code == 201, f"Expected 201, got {response.status_code}. Response: {response.text}"
data = response.json()
assert "id" in data, f"Response missing 'id' field. Data: {data}"
machine_id = data["id"]
print(f" Created machine with ID: {machine_id}")
log_test("Create machine", True)
except Exception as e:
log_test("Create machine", False, str(e))
def test_list_machines():
"""Test listing machines with pagination."""
try:
headers = get_auth_headers()
response = client.get("/api/machines?skip=0&limit=10", headers=headers)
assert response.status_code == 200, f"Expected 200, got {response.status_code}"
data = response.json()
assert "total" in data, "Response missing 'total' field"
assert "machines" in data, "Response missing 'machines' field"
assert isinstance(data["machines"], list), "machines field is not a list"
log_test("List machines", True)
except Exception as e:
log_test("List machines", False, str(e))
def test_get_machine():
"""Test retrieving a specific machine by ID."""
try:
if machine_id is None:
raise Exception("No machine_id available (create test may have failed)")
headers = get_auth_headers()
print(f" Fetching machine with ID: {machine_id} (type: {type(machine_id)})")
# List all machines to check if our machine exists
list_response = client.get("/api/machines", headers=headers)
all_machines = list_response.json().get("machines", [])
print(f" Total machines in DB: {len(all_machines)}")
if all_machines:
print(f" First machine ID: {all_machines[0].get('id')} (type: {type(all_machines[0].get('id'))})")
response = client.get(f"/api/machines/{machine_id}", headers=headers)
assert response.status_code == 200, f"Expected 200, got {response.status_code}. Response: {response.text}"
data = response.json()
assert str(data["id"]) == str(machine_id), f"Expected ID {machine_id}, got {data.get('id')}"
log_test("Get machine by ID", True)
except Exception as e:
log_test("Get machine by ID", False, str(e))
def test_update_machine():
"""Test updating a machine."""
try:
if machine_id is None:
raise Exception("No machine_id available (create test may have failed)")
headers = get_auth_headers()
update_data = {
"friendly_name": "Updated Test Machine",
"notes": "Updated during testing"
}
response = client.put(f"/api/machines/{machine_id}", json=update_data, headers=headers)
assert response.status_code == 200, f"Expected 200, got {response.status_code}"
data = response.json()
assert data["friendly_name"] == "Updated Test Machine", "Update not reflected"
log_test("Update machine", True)
except Exception as e:
log_test("Update machine", False, str(e))
def test_machine_not_found():
"""Test getting non-existent machine returns 404."""
try:
headers = get_auth_headers()
fake_id = str(uuid4())
response = client.get(f"/api/machines/{fake_id}", headers=headers)
assert response.status_code == 404, f"Expected 404, got {response.status_code}"
log_test("Machine not found (404)", True)
except Exception as e:
log_test("Machine not found (404)", False, str(e))
def test_delete_machine():
"""Test deleting a machine."""
try:
if machine_id is None:
raise Exception("No machine_id available (create test may have failed)")
headers = get_auth_headers()
response = client.delete(f"/api/machines/{machine_id}", headers=headers)
assert response.status_code == 200, f"Expected 200, got {response.status_code}"
log_test("Delete machine", True)
except Exception as e:
log_test("Delete machine", False, str(e))
# ============================================================================
# SECTION 4: Client CRUD Operations
# ============================================================================
print("\n" + "="*70)
print("SECTION 4: Client CRUD Operations")
print("="*70 + "\n")
client_id = None
def test_create_client():
"""Test creating a new client."""
global client_id
try:
headers = get_auth_headers()
client_data = {
"name": f"Test Client {uuid4().hex[:8]}",
"type": "msp_client",
"primary_contact": "John Doe",
"is_active": True
}
response = client.post("/api/clients", json=client_data, headers=headers)
assert response.status_code == 201, f"Expected 201, got {response.status_code}. Response: {response.text}"
data = response.json()
assert "id" in data, f"Response missing 'id' field. Data: {data}"
client_id = data["id"]
print(f" Created client with ID: {client_id}")
log_test("Create client", True)
except Exception as e:
log_test("Create client", False, str(e))
def test_list_clients():
"""Test listing clients with pagination."""
try:
headers = get_auth_headers()
response = client.get("/api/clients?skip=0&limit=10", headers=headers)
assert response.status_code == 200, f"Expected 200, got {response.status_code}"
data = response.json()
assert "total" in data, "Response missing 'total' field"
assert "clients" in data, "Response missing 'clients' field"
log_test("List clients", True)
except Exception as e:
log_test("List clients", False, str(e))
def test_get_client():
"""Test retrieving a specific client by ID."""
try:
if client_id is None:
raise Exception("No client_id available")
headers = get_auth_headers()
response = client.get(f"/api/clients/{client_id}", headers=headers)
assert response.status_code == 200, f"Expected 200, got {response.status_code}"
data = response.json()
assert data["id"] == client_id, f"Expected ID {client_id}, got {data.get('id')}"
log_test("Get client by ID", True)
except Exception as e:
log_test("Get client by ID", False, str(e))
def test_update_client():
"""Test updating a client."""
try:
if client_id is None:
raise Exception("No client_id available")
headers = get_auth_headers()
update_data = {
"primary_contact": "Jane Doe",
"notes": "Updated contact"
}
response = client.put(f"/api/clients/{client_id}", json=update_data, headers=headers)
assert response.status_code == 200, f"Expected 200, got {response.status_code}"
data = response.json()
assert data["primary_contact"] == "Jane Doe", "Update not reflected"
log_test("Update client", True)
except Exception as e:
log_test("Update client", False, str(e))
def test_delete_client():
"""Test deleting a client."""
try:
if client_id is None:
raise Exception("No client_id available")
headers = get_auth_headers()
response = client.delete(f"/api/clients/{client_id}", headers=headers)
assert response.status_code == 200, f"Expected 200, got {response.status_code}"
log_test("Delete client", True)
except Exception as e:
log_test("Delete client", False, str(e))
# ============================================================================
# SECTION 5: Project CRUD Operations
# ============================================================================
print("\n" + "="*70)
print("SECTION 5: Project CRUD Operations")
print("="*70 + "\n")
project_id = None
project_client_id = None
def test_create_project():
"""Test creating a new project."""
global project_id, project_client_id
try:
headers = get_auth_headers()
# First create a client for the project
client_data = {
"name": f"Project Test Client {uuid4().hex[:8]}",
"type": "msp_client",
"is_active": True
}
client_response = client.post("/api/clients", json=client_data, headers=headers)
assert client_response.status_code == 201, f"Failed to create test client: {client_response.text}"
project_client_id = client_response.json()["id"]
# Now create the project
project_data = {
"name": f"Test Project {uuid4().hex[:8]}",
"client_id": project_client_id,
"status": "active"
}
response = client.post("/api/projects", json=project_data, headers=headers)
assert response.status_code == 201, f"Expected 201, got {response.status_code}. Response: {response.text}"
data = response.json()
assert "id" in data, f"Response missing 'id' field. Data: {data}"
project_id = data["id"]
print(f" Created project with ID: {project_id}")
log_test("Create project", True)
except Exception as e:
log_test("Create project", False, str(e))
def test_list_projects():
"""Test listing projects with pagination."""
try:
headers = get_auth_headers()
response = client.get("/api/projects?skip=0&limit=10", headers=headers)
assert response.status_code == 200, f"Expected 200, got {response.status_code}"
data = response.json()
assert "total" in data, "Response missing 'total' field"
assert "projects" in data, "Response missing 'projects' field"
log_test("List projects", True)
except Exception as e:
log_test("List projects", False, str(e))
def test_get_project():
"""Test retrieving a specific project by ID."""
try:
if project_id is None:
raise Exception("No project_id available")
headers = get_auth_headers()
response = client.get(f"/api/projects/{project_id}", headers=headers)
assert response.status_code == 200, f"Expected 200, got {response.status_code}"
data = response.json()
assert data["id"] == project_id, f"Expected ID {project_id}, got {data.get('id')}"
log_test("Get project by ID", True)
except Exception as e:
log_test("Get project by ID", False, str(e))
def test_update_project():
"""Test updating a project."""
try:
if project_id is None:
raise Exception("No project_id available")
headers = get_auth_headers()
update_data = {
"status": "completed",
"notes": "Project completed during testing"
}
response = client.put(f"/api/projects/{project_id}", json=update_data, headers=headers)
assert response.status_code == 200, f"Expected 200, got {response.status_code}"
data = response.json()
assert data["status"] == "completed", "Update not reflected"
log_test("Update project", True)
except Exception as e:
log_test("Update project", False, str(e))
def test_delete_project():
"""Test deleting a project."""
try:
if project_id is None:
raise Exception("No project_id available")
headers = get_auth_headers()
response = client.delete(f"/api/projects/{project_id}", headers=headers)
assert response.status_code == 200, f"Expected 200, got {response.status_code}"
# Clean up test client
if project_client_id:
client.delete(f"/api/clients/{project_client_id}", headers=headers)
log_test("Delete project", True)
except Exception as e:
log_test("Delete project", False, str(e))
# ============================================================================
# SECTION 6: Session CRUD Operations
# ============================================================================
print("\n" + "="*70)
print("SECTION 6: Session CRUD Operations")
print("="*70 + "\n")
session_id = None
session_client_id = None
session_project_id = None
def test_create_session():
"""Test creating a new session."""
global session_id, session_client_id, session_project_id
try:
headers = get_auth_headers()
# Create client for session
client_data = {
"name": f"Session Test Client {uuid4().hex[:8]}",
"type": "msp_client",
"is_active": True
}
client_response = client.post("/api/clients", json=client_data, headers=headers)
assert client_response.status_code == 201, f"Failed to create test client: {client_response.text}"
session_client_id = client_response.json()["id"]
# Create project for session
project_data = {
"name": f"Session Test Project {uuid4().hex[:8]}",
"client_id": session_client_id,
"status": "active"
}
project_response = client.post("/api/projects", json=project_data, headers=headers)
assert project_response.status_code == 201, f"Failed to create test project: {project_response.text}"
session_project_id = project_response.json()["id"]
# Create session
from datetime import date
session_data = {
"session_title": f"Test Session {uuid4().hex[:8]}",
"session_date": str(date.today()),
"client_id": session_client_id,
"project_id": session_project_id,
"status": "completed"
}
response = client.post("/api/sessions", json=session_data, headers=headers)
assert response.status_code == 201, f"Expected 201, got {response.status_code}. Response: {response.text}"
data = response.json()
assert "id" in data, f"Response missing 'id' field. Data: {data}"
session_id = data["id"]
print(f" Created session with ID: {session_id}")
log_test("Create session", True)
except Exception as e:
log_test("Create session", False, str(e))
def test_list_sessions():
"""Test listing sessions with pagination."""
try:
headers = get_auth_headers()
response = client.get("/api/sessions?skip=0&limit=10", headers=headers)
assert response.status_code == 200, f"Expected 200, got {response.status_code}"
data = response.json()
assert "total" in data, "Response missing 'total' field"
assert "sessions" in data, "Response missing 'sessions' field"
log_test("List sessions", True)
except Exception as e:
log_test("List sessions", False, str(e))
def test_get_session():
"""Test retrieving a specific session by ID."""
try:
if session_id is None:
raise Exception("No session_id available")
headers = get_auth_headers()
response = client.get(f"/api/sessions/{session_id}", headers=headers)
assert response.status_code == 200, f"Expected 200, got {response.status_code}"
data = response.json()
assert data["id"] == session_id, f"Expected ID {session_id}, got {data.get('id')}"
log_test("Get session by ID", True)
except Exception as e:
log_test("Get session by ID", False, str(e))
def test_update_session():
"""Test updating a session."""
try:
if session_id is None:
raise Exception("No session_id available")
headers = get_auth_headers()
update_data = {
"status": "completed",
"summary": "Test session completed"
}
response = client.put(f"/api/sessions/{session_id}", json=update_data, headers=headers)
assert response.status_code == 200, f"Expected 200, got {response.status_code}"
data = response.json()
assert data["status"] == "completed", "Update not reflected"
log_test("Update session", True)
except Exception as e:
log_test("Update session", False, str(e))
def test_delete_session():
"""Test deleting a session."""
try:
if session_id is None:
raise Exception("No session_id available")
headers = get_auth_headers()
response = client.delete(f"/api/sessions/{session_id}", headers=headers)
assert response.status_code == 200, f"Expected 200, got {response.status_code}"
# Clean up test data
if session_project_id:
client.delete(f"/api/projects/{session_project_id}", headers=headers)
if session_client_id:
client.delete(f"/api/clients/{session_client_id}", headers=headers)
log_test("Delete session", True)
except Exception as e:
log_test("Delete session", False, str(e))
# ============================================================================
# SECTION 7: Tag CRUD Operations
# ============================================================================
print("\n" + "="*70)
print("SECTION 7: Tag CRUD Operations")
print("="*70 + "\n")
tag_id = None
def test_create_tag():
"""Test creating a new tag."""
global tag_id
try:
headers = get_auth_headers()
tag_data = {
"name": f"test-tag-{uuid4().hex[:8]}",
"category": "technology",
"color": "#FF5733"
}
response = client.post("/api/tags", json=tag_data, headers=headers)
assert response.status_code == 201, f"Expected 201, got {response.status_code}"
data = response.json()
assert "id" in data, "Response missing 'id' field"
tag_id = data["id"]
log_test("Create tag", True)
except Exception as e:
log_test("Create tag", False, str(e))
def test_list_tags():
"""Test listing tags with pagination."""
try:
headers = get_auth_headers()
response = client.get("/api/tags?skip=0&limit=10", headers=headers)
assert response.status_code == 200, f"Expected 200, got {response.status_code}"
data = response.json()
assert "total" in data, "Response missing 'total' field"
assert "tags" in data, "Response missing 'tags' field"
log_test("List tags", True)
except Exception as e:
log_test("List tags", False, str(e))
def test_get_tag():
"""Test retrieving a specific tag by ID."""
try:
if tag_id is None:
raise Exception("No tag_id available")
headers = get_auth_headers()
response = client.get(f"/api/tags/{tag_id}", headers=headers)
assert response.status_code == 200, f"Expected 200, got {response.status_code}"
data = response.json()
assert data["id"] == tag_id, f"Expected ID {tag_id}, got {data.get('id')}"
log_test("Get tag by ID", True)
except Exception as e:
log_test("Get tag by ID", False, str(e))
def test_update_tag():
"""Test updating a tag."""
try:
if tag_id is None:
raise Exception("No tag_id available")
headers = get_auth_headers()
update_data = {
"color": "#00FF00",
"description": "Updated test tag"
}
response = client.put(f"/api/tags/{tag_id}", json=update_data, headers=headers)
assert response.status_code == 200, f"Expected 200, got {response.status_code}"
data = response.json()
assert data["color"] == "#00FF00", "Update not reflected"
log_test("Update tag", True)
except Exception as e:
log_test("Update tag", False, str(e))
def test_tag_duplicate_name():
"""Test creating tag with duplicate name returns 409."""
try:
if tag_id is None:
raise Exception("No tag_id available")
headers = get_auth_headers()
# Get existing tag name
existing_response = client.get(f"/api/tags/{tag_id}", headers=headers)
existing_name = existing_response.json()["name"]
# Try to create duplicate
duplicate_data = {
"name": existing_name,
"category": "test"
}
response = client.post("/api/tags", json=duplicate_data, headers=headers)
assert response.status_code == 409, f"Expected 409, got {response.status_code}"
log_test("Tag duplicate name (409)", True)
except Exception as e:
log_test("Tag duplicate name (409)", False, str(e))
def test_delete_tag():
"""Test deleting a tag."""
try:
if tag_id is None:
raise Exception("No tag_id available")
headers = get_auth_headers()
response = client.delete(f"/api/tags/{tag_id}", headers=headers)
assert response.status_code == 200, f"Expected 200, got {response.status_code}"
log_test("Delete tag", True)
except Exception as e:
log_test("Delete tag", False, str(e))
# ============================================================================
# SECTION 8: Pagination Tests
# ============================================================================
print("\n" + "="*70)
print("SECTION 8: Pagination Tests")
print("="*70 + "\n")
def test_pagination_skip_limit():
"""Test pagination with skip and limit parameters."""
try:
headers = get_auth_headers()
response = client.get("/api/machines?skip=0&limit=5", headers=headers)
assert response.status_code == 200, f"Expected 200, got {response.status_code}"
data = response.json()
assert data["skip"] == 0, f"Expected skip=0, got {data.get('skip')}"
assert data["limit"] == 5, f"Expected limit=5, got {data.get('limit')}"
log_test("Pagination skip/limit parameters", True)
except Exception as e:
log_test("Pagination skip/limit parameters", False, str(e))
def test_pagination_max_limit():
"""Test that pagination enforces maximum limit."""
try:
headers = get_auth_headers()
# Try to request more than max (1000)
response = client.get("/api/machines?limit=2000", headers=headers)
# Should either return 422 or clamp to max
assert response.status_code in [200, 422], f"Unexpected status {response.status_code}"
log_test("Pagination max limit enforcement", True)
except Exception as e:
log_test("Pagination max limit enforcement", False, str(e))
# ============================================================================
# Run All Tests
# ============================================================================
def run_all_tests():
"""Run all test functions."""
print("\n" + "="*70)
print("CLAUDETOOLS API ENDPOINT TESTS")
print("="*70)
# Section 1: Health
test_root_endpoint()
test_health_endpoint()
test_jwt_token_creation()
# Section 2: Auth
test_unauthenticated_access()
test_authenticated_access()
test_invalid_token()
# Section 3: Machines
test_create_machine()
test_list_machines()
test_get_machine()
test_update_machine()
test_machine_not_found()
test_delete_machine()
# Section 4: Clients
test_create_client()
test_list_clients()
test_get_client()
test_update_client()
test_delete_client()
# Section 5: Projects
test_create_project()
test_list_projects()
test_get_project()
test_update_project()
test_delete_project()
# Section 6: Sessions
test_create_session()
test_list_sessions()
test_get_session()
test_update_session()
test_delete_session()
# Section 7: Tags
test_create_tag()
test_list_tags()
test_get_tag()
test_update_tag()
test_tag_duplicate_name()
test_delete_tag()
# Section 8: Pagination
test_pagination_skip_limit()
test_pagination_max_limit()
if __name__ == "__main__":
print("\n>> Starting ClaudeTools API Test Suite...")
try:
run_all_tests()
# Print summary
print("\n" + "="*70)
print("TEST SUMMARY")
print("="*70)
print(f"\nTotal Tests: {tests_passed + tests_failed}")
print(f"Passed: {tests_passed}")
print(f"Failed: {tests_failed}")
if tests_failed > 0:
print("\nFAILED TESTS:")
for name, passed, error in test_results:
if not passed:
print(f" - {name}")
if error:
print(f" Error: {error}")
if tests_failed == 0:
print("\n>> All tests passed!")
sys.exit(0)
else:
print(f"\n>> {tests_failed} test(s) failed")
sys.exit(1)
except Exception as e:
print(f"\n>> Fatal error running tests: {e}")
import traceback
traceback.print_exc()
sys.exit(1)

View File

@@ -0,0 +1,286 @@
"""
Test script for conversation_parser.py
Tests all four main functions with sample data.
"""
import json
import os
import tempfile
from api.utils.conversation_parser import (
parse_jsonl_conversation,
categorize_conversation,
extract_context_from_conversation,
scan_folder_for_conversations,
batch_process_conversations,
)
def test_parse_jsonl_conversation():
"""Test parsing .jsonl conversation files."""
print("\n=== Test 1: parse_jsonl_conversation ===")
# Create a temporary .jsonl file
with tempfile.NamedTemporaryFile(mode='w', suffix='.jsonl', delete=False, encoding='utf-8') as f:
# Write sample conversation data
f.write(json.dumps({
"role": "user",
"content": "Build a FastAPI authentication system with PostgreSQL",
"timestamp": 1705449600000
}) + "\n")
f.write(json.dumps({
"role": "assistant",
"content": "I'll help you build an auth system using FastAPI and PostgreSQL. Let me create the api/auth.py file.",
"timestamp": 1705449620000
}) + "\n")
f.write(json.dumps({
"role": "user",
"content": "Also add JWT token support",
"timestamp": 1705449640000
}) + "\n")
temp_file = f.name
try:
result = parse_jsonl_conversation(temp_file)
print(f"Messages: {result['message_count']}")
print(f"Duration: {result['duration_seconds']} seconds")
print(f"File paths extracted: {result['file_paths']}")
print(f"First message: {result['messages'][0]['content'][:50]}...")
assert result['message_count'] == 3, "Should have 3 messages"
assert result['duration_seconds'] == 40, "Duration should be 40 seconds"
assert 'api/auth.py' in result['file_paths'], "Should extract file path"
print("[PASS] parse_jsonl_conversation test passed!")
finally:
os.unlink(temp_file)
def test_categorize_conversation():
"""Test conversation categorization."""
print("\n=== Test 2: categorize_conversation ===")
# Test MSP conversation
msp_messages = [
{"role": "user", "content": "Client reported firewall blocking Office365 connection"},
{"role": "assistant", "content": "I'll check the firewall rules for the client site"}
]
msp_category = categorize_conversation(msp_messages)
print(f"MSP category: {msp_category}")
assert msp_category == "msp", "Should categorize as MSP"
# Test Development conversation
dev_messages = [
{"role": "user", "content": "Build API endpoint for user authentication with FastAPI"},
{"role": "assistant", "content": "I'll create the endpoint using SQLAlchemy and implement JWT tokens"}
]
dev_category = categorize_conversation(dev_messages)
print(f"Development category: {dev_category}")
assert dev_category == "development", "Should categorize as development"
# Test General conversation
general_messages = [
{"role": "user", "content": "What's the weather like today?"},
{"role": "assistant", "content": "I don't have access to current weather data"}
]
general_category = categorize_conversation(general_messages)
print(f"General category: {general_category}")
assert general_category == "general", "Should categorize as general"
print("[PASS] categorize_conversation test passed!")
def test_extract_context_from_conversation():
"""Test context extraction from conversation."""
print("\n=== Test 3: extract_context_from_conversation ===")
# Create a sample conversation
conversation = {
"messages": [
{
"role": "user",
"content": "Build a FastAPI REST API with PostgreSQL database",
"timestamp": 1705449600000
},
{
"role": "assistant",
"content": "I'll create the API using FastAPI and SQLAlchemy. Decided to use Alembic for migrations because it integrates well with SQLAlchemy.",
"timestamp": 1705449620000
},
{
"role": "user",
"content": "Add authentication with JWT tokens",
"timestamp": 1705449640000
}
],
"metadata": {
"title": "Build API system",
"model": "claude-opus-4",
"sessionId": "test-123"
},
"file_paths": ["api/main.py", "api/auth.py", "api/models.py"],
"tool_calls": [
{"tool": "write", "count": 5},
{"tool": "read", "count": 3}
],
"duration_seconds": 40,
"message_count": 3
}
context = extract_context_from_conversation(conversation)
print(f"Category: {context['category']}")
print(f"Tags: {context['tags'][:5]}")
print(f"Decisions: {len(context['decisions'])}")
print(f"Quality score: {context['metrics']['quality_score']}/10")
print(f"Key files: {context['key_files']}")
assert context['category'] in ['msp', 'development', 'general'], "Should have valid category"
assert len(context['tags']) > 0, "Should have extracted tags"
assert context['metrics']['message_count'] == 3, "Should have correct message count"
print("[PASS] extract_context_from_conversation test passed!")
def test_scan_folder_for_conversations():
"""Test scanning folder for conversation files."""
print("\n=== Test 4: scan_folder_for_conversations ===")
# Create a temporary directory structure
with tempfile.TemporaryDirectory() as tmpdir:
# Create some conversation files
conv1_path = os.path.join(tmpdir, "conversation1.jsonl")
conv2_path = os.path.join(tmpdir, "session", "conversation2.json")
config_path = os.path.join(tmpdir, "config.json") # Should be skipped
os.makedirs(os.path.dirname(conv2_path), exist_ok=True)
# Create files
with open(conv1_path, 'w') as f:
f.write('{"role": "user", "content": "test"}\n')
with open(conv2_path, 'w') as f:
f.write('{"role": "user", "content": "test"}')
with open(config_path, 'w') as f:
f.write('{"setting": "value"}')
# Scan folder
files = scan_folder_for_conversations(tmpdir)
print(f"Found {len(files)} conversation files")
print(f"Files: {[os.path.basename(f) for f in files]}")
assert len(files) == 2, "Should find 2 conversation files"
assert any("conversation1.jsonl" in f for f in files), "Should find jsonl file"
assert any("conversation2.json" in f for f in files), "Should find json file"
assert not any("config.json" in f for f in files), "Should skip config.json"
print("[PASS] scan_folder_for_conversations test passed!")
def test_batch_process():
"""Test batch processing of conversations."""
print("\n=== Test 5: batch_process_conversations ===")
with tempfile.TemporaryDirectory() as tmpdir:
# Create sample conversations
conv1_path = os.path.join(tmpdir, "msp_work.jsonl")
conv2_path = os.path.join(tmpdir, "dev_work.jsonl")
# MSP conversation
with open(conv1_path, 'w') as f:
f.write(json.dumps({
"role": "user",
"content": "Client ticket: firewall blocking Office365",
"timestamp": 1705449600000
}) + "\n")
f.write(json.dumps({
"role": "assistant",
"content": "I'll check the client firewall configuration",
"timestamp": 1705449620000
}) + "\n")
# Development conversation
with open(conv2_path, 'w') as f:
f.write(json.dumps({
"role": "user",
"content": "Build FastAPI endpoint for authentication",
"timestamp": 1705449600000
}) + "\n")
f.write(json.dumps({
"role": "assistant",
"content": "Creating API endpoint with SQLAlchemy",
"timestamp": 1705449620000
}) + "\n")
# Process all conversations
processed_count = [0]
def progress_callback(file_path, context):
processed_count[0] += 1
print(f" Processed: {os.path.basename(file_path)} -> {context['category']}")
contexts = batch_process_conversations(tmpdir, progress_callback)
print(f"\nTotal processed: {len(contexts)}")
assert len(contexts) == 2, "Should process 2 conversations"
assert processed_count[0] == 2, "Callback should be called twice"
categories = [ctx['category'] for ctx in contexts]
print(f"Categories: {categories}")
print("[PASS] batch_process_conversations test passed!")
def test_real_conversation_file():
"""Test with real conversation file if available."""
print("\n=== Test 6: Real conversation file ===")
real_file = r"C:\Users\MikeSwanson\AppData\Roaming\Claude\claude-code-sessions\0c32bde5-dc29-49ac-8c80-5adeaf1cdb33\299a238a-5ebf-44f4-948b-eedfa5c1f57c\local_feb419c2-b7a6-4c31-a7ce-38f6c0ccc523.json"
if os.path.exists(real_file):
try:
conversation = parse_jsonl_conversation(real_file)
print(f"Real file - Messages: {conversation['message_count']}")
print(f"Real file - Metadata: {conversation['metadata'].get('title', 'No title')}")
if conversation['message_count'] > 0:
context = extract_context_from_conversation(conversation)
print(f"Real file - Category: {context['category']}")
print(f"Real file - Quality: {context['metrics']['quality_score']}/10")
except Exception as e:
print(f"Note: Real file test skipped - {e}")
else:
print("Real conversation file not found - skipping this test")
if __name__ == "__main__":
print("=" * 60)
print("Testing conversation_parser.py")
print("=" * 60)
try:
test_parse_jsonl_conversation()
test_categorize_conversation()
test_extract_context_from_conversation()
test_scan_folder_for_conversations()
test_batch_process()
test_real_conversation_file()
print("\n" + "=" * 60)
print("All tests passed! [OK]")
print("=" * 60)
except AssertionError as e:
print(f"\n[FAIL] Test failed: {e}")
raise
except Exception as e:
print(f"\n[ERROR] Unexpected error: {e}")
raise

View File

@@ -0,0 +1,284 @@
"""
Test script for credential scanner and importer.
This script demonstrates the credential scanner functionality including:
- Creating sample credential files
- Scanning for credential files
- Parsing credential data
- Importing credentials to database
Usage:
python test_credential_scanner.py
"""
import logging
import os
import tempfile
from pathlib import Path
from api.database import SessionLocal
from api.utils.credential_scanner import (
scan_for_credential_files,
parse_credential_file,
import_credentials_to_db,
scan_and_import_credentials,
)
# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
def create_sample_credential_files(temp_dir: str):
"""Create sample credential files for testing."""
# Create credentials.md
credentials_md = Path(temp_dir) / "credentials.md"
credentials_md.write_text("""# Sample Credentials
## Gitea Admin
Username: admin
Password: GitSecurePass123!
URL: https://git.example.com
Notes: Main admin account
## Database Server
Type: connection_string
Connection String: mysql://dbuser:dbpass@192.168.1.50:3306/mydb
Notes: Production database
## OpenAI API
API Key: sk-1234567890abcdefghijklmnopqrstuvwxyz
Notes: Production API key
""")
# Create .env file
env_file = Path(temp_dir) / ".env"
env_file.write_text("""# Environment Variables
DATABASE_URL=postgresql://user:pass@localhost:5432/testdb
API_TOKEN=ghp_abc123def456ghi789jkl012mno345pqr678
SECRET_KEY=super_secret_key_12345
""")
# Create passwords.txt
passwords_txt = Path(temp_dir) / "passwords.txt"
passwords_txt.write_text("""# Server Passwords
## Web Server
Username: webadmin
Password: Web@dmin2024!
Host: 192.168.1.100
Port: 22
## Backup Server
Username: backup
Password: BackupSecure789
Host: 10.0.0.50
""")
logger.info(f"Created sample credential files in: {temp_dir}")
return [str(credentials_md), str(env_file), str(passwords_txt)]
def test_scan_for_credential_files():
"""Test credential file scanning."""
logger.info("=" * 60)
logger.info("TEST 1: Scan for Credential Files")
logger.info("=" * 60)
with tempfile.TemporaryDirectory() as temp_dir:
# Create sample files
create_sample_credential_files(temp_dir)
# Scan for files
found_files = scan_for_credential_files(temp_dir)
logger.info(f"\nFound {len(found_files)} credential file(s):")
for file_path in found_files:
logger.info(f" - {file_path}")
assert len(found_files) == 3, "Should find 3 credential files"
logger.info("\n[PASS] Test 1 passed")
return found_files
def test_parse_credential_file():
"""Test credential file parsing."""
logger.info("\n" + "=" * 60)
logger.info("TEST 2: Parse Credential Files")
logger.info("=" * 60)
with tempfile.TemporaryDirectory() as temp_dir:
# Create sample files
sample_files = create_sample_credential_files(temp_dir)
total_credentials = 0
for file_path in sample_files:
credentials = parse_credential_file(file_path)
total_credentials += len(credentials)
logger.info(f"\nParsed from {Path(file_path).name}:")
for cred in credentials:
logger.info(f" Service: {cred.get('service_name')}")
logger.info(f" Type: {cred.get('credential_type')}")
if cred.get('username'):
logger.info(f" Username: {cred.get('username')}")
# Don't log actual passwords/keys
if cred.get('password'):
logger.info(f" Password: [REDACTED]")
if cred.get('api_key'):
logger.info(f" API Key: [REDACTED]")
if cred.get('connection_string'):
logger.info(f" Connection String: [REDACTED]")
logger.info("")
logger.info(f"Total credentials parsed: {total_credentials}")
assert total_credentials > 0, "Should parse at least one credential"
logger.info("[PASS] Test 2 passed")
def test_import_credentials_to_db():
"""Test importing credentials to database."""
logger.info("\n" + "=" * 60)
logger.info("TEST 3: Import Credentials to Database")
logger.info("=" * 60)
db = SessionLocal()
try:
with tempfile.TemporaryDirectory() as temp_dir:
# Create sample files
sample_files = create_sample_credential_files(temp_dir)
# Parse first file
credentials = parse_credential_file(sample_files[0])
logger.info(f"\nParsed {len(credentials)} credential(s) from file")
# Import to database
imported_count = import_credentials_to_db(
db=db,
credentials=credentials,
client_id=None, # No client association for test
user_id="test_user",
ip_address="127.0.0.1"
)
logger.info(f"\n[OK] Successfully imported {imported_count} credential(s)")
logger.info("[PASS] Test 3 passed")
return imported_count
except Exception as e:
logger.error(f"Import failed: {str(e)}")
raise
finally:
db.close()
def test_full_workflow():
"""Test complete scan and import workflow."""
logger.info("\n" + "=" * 60)
logger.info("TEST 4: Full Scan and Import Workflow")
logger.info("=" * 60)
db = SessionLocal()
try:
with tempfile.TemporaryDirectory() as temp_dir:
# Create sample files
create_sample_credential_files(temp_dir)
# Run full workflow
results = scan_and_import_credentials(
base_path=temp_dir,
db=db,
client_id=None,
user_id="test_user",
ip_address="127.0.0.1"
)
logger.info(f"\nWorkflow Results:")
logger.info(f" Files found: {results['files_found']}")
logger.info(f" Credentials parsed: {results['credentials_parsed']}")
logger.info(f" Credentials imported: {results['credentials_imported']}")
assert results['files_found'] > 0, "Should find files"
assert results['credentials_parsed'] > 0, "Should parse credentials"
logger.info("\n[PASS] Test 4 passed")
except Exception as e:
logger.error(f"Workflow failed: {str(e)}")
raise
finally:
db.close()
def test_markdown_parsing():
"""Test markdown credential parsing with various formats."""
logger.info("\n" + "=" * 60)
logger.info("TEST 5: Markdown Format Variations")
logger.info("=" * 60)
with tempfile.TemporaryDirectory() as temp_dir:
# Create file with various markdown formats
test_file = Path(temp_dir) / "test_variations.md"
test_file.write_text("""
# Single Hash Header
Username: user1
Password: pass1
## Double Hash Header
User: user2
Pass: pass2
## API Service
API_Key: sk-123456789
Type: api_key
## Database Connection
Connection_String: mysql://user:pass@host/db
""")
credentials = parse_credential_file(str(test_file))
logger.info(f"\nParsed {len(credentials)} credential(s):")
for cred in credentials:
logger.info(f" - {cred.get('service_name')} ({cred.get('credential_type')})")
assert len(credentials) >= 3, "Should parse multiple variations"
logger.info("\n[PASS] Test 5 passed")
def main():
"""Run all tests."""
logger.info("\n" + "=" * 60)
logger.info("CREDENTIAL SCANNER TEST SUITE")
logger.info("=" * 60)
try:
# Run tests
test_scan_for_credential_files()
test_parse_credential_file()
test_markdown_parsing()
test_import_credentials_to_db()
test_full_workflow()
logger.info("\n" + "=" * 60)
logger.info("ALL TESTS PASSED!")
logger.info("=" * 60)
except Exception as e:
logger.error("\n" + "=" * 60)
logger.error("TEST FAILED!")
logger.error("=" * 60)
logger.error(f"Error: {str(e)}", exc_info=True)
raise
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,291 @@
"""
Test script for Credentials Management API.
This script tests the credentials API endpoints including encryption, decryption,
and audit logging functionality.
"""
import sys
from datetime import datetime
from uuid import uuid4
from api.database import get_db
from api.models.credential import Credential
from api.models.credential_audit_log import CredentialAuditLog
from api.schemas.credential import CredentialCreate, CredentialUpdate
from api.services.credential_service import (
create_credential,
delete_credential,
get_credential_by_id,
get_credentials,
update_credential,
)
from api.utils.crypto import decrypt_string, encrypt_string
def test_encryption_decryption():
"""Test basic encryption and decryption."""
print("\n=== Testing Encryption/Decryption ===")
test_password = "SuperSecretPassword123!"
print(f"Original password: {test_password}")
# Encrypt
encrypted = encrypt_string(test_password)
print(f"Encrypted (length: {len(encrypted)}): {encrypted[:50]}...")
# Decrypt
decrypted = decrypt_string(encrypted)
print(f"Decrypted: {decrypted}")
assert test_password == decrypted, "Encryption/decryption mismatch!"
print("[PASS] Encryption/decryption test passed")
def test_credential_lifecycle():
"""Test the full credential lifecycle: create, read, update, delete."""
print("\n=== Testing Credential Lifecycle ===")
db = next(get_db())
try:
# 1. CREATE
print("\n1. Creating credential...")
credential_data = CredentialCreate(
credential_type="password",
service_name="Test Service",
username="admin",
password="MySecurePassword123!",
external_url="https://test.example.com",
requires_vpn=False,
requires_2fa=True,
is_active=True
)
created = create_credential(
db=db,
credential_data=credential_data,
user_id="test_user_123",
ip_address="127.0.0.1",
user_agent="Test Script"
)
print(f"[PASS] Created credential ID: {created.id}")
print(f" Service: {created.service_name}")
print(f" Type: {created.credential_type}")
print(f" Password encrypted: {created.password_encrypted is not None}")
# Verify encryption
if created.password_encrypted:
decrypted_password = decrypt_string(created.password_encrypted.decode('utf-8'))
assert decrypted_password == "MySecurePassword123!", "Password encryption failed!"
print(f" [PASS] Password correctly encrypted and decrypted")
# Verify audit log was created
audit_logs = db.query(CredentialAuditLog).filter(
CredentialAuditLog.credential_id == str(created.id)
).all()
print(f" [PASS] Audit logs created: {len(audit_logs)}")
# 2. READ
print("\n2. Reading credential...")
retrieved = get_credential_by_id(db, created.id, user_id="test_user_123")
print(f"[PASS] Retrieved credential: {retrieved.service_name}")
# Check audit log for view action
audit_logs = db.query(CredentialAuditLog).filter(
CredentialAuditLog.credential_id == str(created.id),
CredentialAuditLog.action == "view"
).all()
print(f" [PASS] View action logged: {len(audit_logs) > 0}")
# 3. UPDATE
print("\n3. Updating credential...")
update_data = CredentialUpdate(
password="NewSecurePassword456!",
last_rotated_at=datetime.utcnow(),
external_url="https://test-updated.example.com"
)
updated = update_credential(
db=db,
credential_id=created.id,
credential_data=update_data,
user_id="test_user_123",
ip_address="127.0.0.1"
)
print(f"[PASS] Updated credential: {updated.service_name}")
print(f" Password re-encrypted: {updated.password_encrypted is not None}")
# Verify new password
if updated.password_encrypted:
decrypted_new_password = decrypt_string(updated.password_encrypted.decode('utf-8'))
assert decrypted_new_password == "NewSecurePassword456!", "Password update failed!"
print(f" [PASS] New password correctly encrypted")
# Check audit log for update action
audit_logs = db.query(CredentialAuditLog).filter(
CredentialAuditLog.credential_id == str(created.id),
CredentialAuditLog.action == "update"
).all()
print(f" [PASS] Update action logged: {len(audit_logs) > 0}")
# 4. LIST
print("\n4. Listing credentials...")
credentials, total = get_credentials(db, skip=0, limit=10)
print(f"[PASS] Found {total} total credentials")
print(f" Retrieved {len(credentials)} credentials in this page")
# 5. DELETE
print("\n5. Deleting credential...")
result = delete_credential(
db=db,
credential_id=created.id,
user_id="test_user_123",
ip_address="127.0.0.1"
)
print(f"[PASS] {result['message']}")
# Verify deletion
remaining = db.query(Credential).filter(Credential.id == str(created.id)).first()
assert remaining is None, "Credential was not deleted!"
print(f" [PASS] Credential successfully removed from database")
# Check audit log for delete action (should still exist due to CASCADE behavior)
audit_logs = db.query(CredentialAuditLog).filter(
CredentialAuditLog.credential_id == str(created.id),
CredentialAuditLog.action == "delete"
).all()
print(f" [PASS] Delete action logged: {len(audit_logs) > 0}")
print("\n[PASS] All credential lifecycle tests passed!")
except Exception as e:
print(f"\n[FAIL] Test failed: {str(e)}")
import traceback
traceback.print_exc()
sys.exit(1)
finally:
db.close()
def test_multiple_credential_types():
"""Test creating credentials with different types and encrypted fields."""
print("\n=== Testing Multiple Credential Types ===")
db = next(get_db())
try:
credential_ids = []
# Test API Key credential
print("\n1. Creating API Key credential...")
api_key_data = CredentialCreate(
credential_type="api_key",
service_name="GitHub API",
api_key="ghp_abcdef1234567890",
external_url="https://api.github.com",
is_active=True
)
api_cred = create_credential(db, api_key_data, user_id="test_user")
print(f"[PASS] Created API Key credential: {api_cred.id}")
credential_ids.append(api_cred.id)
# Verify API key encryption
if api_cred.api_key_encrypted:
decrypted_key = decrypt_string(api_cred.api_key_encrypted.decode('utf-8'))
assert decrypted_key == "ghp_abcdef1234567890", "API key encryption failed!"
print(f" [PASS] API key correctly encrypted")
# Test OAuth credential
print("\n2. Creating OAuth credential...")
oauth_data = CredentialCreate(
credential_type="oauth",
service_name="Microsoft 365",
client_id_oauth="app-client-id-123",
client_secret="secret_value_xyz789",
tenant_id_oauth="tenant-id-456",
is_active=True
)
oauth_cred = create_credential(db, oauth_data, user_id="test_user")
print(f"[PASS] Created OAuth credential: {oauth_cred.id}")
credential_ids.append(oauth_cred.id)
# Verify client secret encryption
if oauth_cred.client_secret_encrypted:
decrypted_secret = decrypt_string(oauth_cred.client_secret_encrypted.decode('utf-8'))
assert decrypted_secret == "secret_value_xyz789", "OAuth secret encryption failed!"
print(f" [PASS] Client secret correctly encrypted")
# Test Connection String credential
print("\n3. Creating Connection String credential...")
conn_data = CredentialCreate(
credential_type="connection_string",
service_name="SQL Server",
connection_string="Server=localhost;Database=TestDB;User Id=sa;Password=ComplexPass123!;",
internal_url="sql.internal.local",
custom_port=1433,
is_active=True
)
conn_cred = create_credential(db, conn_data, user_id="test_user")
print(f"[PASS] Created Connection String credential: {conn_cred.id}")
credential_ids.append(conn_cred.id)
# Verify connection string encryption
if conn_cred.connection_string_encrypted:
decrypted_conn = decrypt_string(conn_cred.connection_string_encrypted.decode('utf-8'))
assert "ComplexPass123!" in decrypted_conn, "Connection string encryption failed!"
print(f" [PASS] Connection string correctly encrypted")
print(f"\n[PASS] Created {len(credential_ids)} different credential types")
# Cleanup
print("\n4. Cleaning up test credentials...")
for cred_id in credential_ids:
delete_credential(db, cred_id, user_id="test_user")
print(f"[PASS] Cleaned up {len(credential_ids)} credentials")
print("\n[PASS] All multi-type credential tests passed!")
except Exception as e:
print(f"\n[FAIL] Test failed: {str(e)}")
import traceback
traceback.print_exc()
sys.exit(1)
finally:
db.close()
def main():
"""Run all tests."""
print("=" * 60)
print("CREDENTIALS API TEST SUITE")
print("=" * 60)
try:
test_encryption_decryption()
test_credential_lifecycle()
test_multiple_credential_types()
print("\n" + "=" * 60)
print("[PASS] ALL TESTS PASSED!")
print("=" * 60)
except Exception as e:
print("\n" + "=" * 60)
print("[FAIL] TEST SUITE FAILED")
print("=" * 60)
import traceback
traceback.print_exc()
sys.exit(1)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,490 @@
"""
Phase 3 Test: Database CRUD Operations Validation
Tests CREATE, READ, UPDATE, DELETE operations on the ClaudeTools database
with real database connections and verifies foreign key relationships.
"""
import sys
from datetime import datetime, timezone
from uuid import uuid4
import random
from sqlalchemy import text
from sqlalchemy.exc import IntegrityError, SQLAlchemyError
# Add api directory to path
sys.path.insert(0, 'D:\\ClaudeTools')
from api.database import SessionLocal, check_db_connection
from api.models import Client, Machine, Session, Tag, SessionTag
class CRUDTester:
"""Test harness for CRUD operations."""
def __init__(self):
self.db = None
self.test_ids = {
'client': None,
'machine': None,
'session': None,
'tag': None
}
self.passed = 0
self.failed = 0
self.errors = []
def connect(self):
"""Test database connection."""
print("=" * 80)
print("PHASE 3: DATABASE CRUD OPERATIONS TEST")
print("=" * 80)
print("\n1. CONNECTION TEST")
print("-" * 80)
try:
if not check_db_connection():
self.fail("Connection", "check_db_connection() returned False")
return False
self.db = SessionLocal()
# Test basic query
result = self.db.execute(text("SELECT DATABASE()")).scalar()
self.success("Connection", f"Connected to database: {result}")
return True
except Exception as e:
self.fail("Connection", str(e))
return False
def test_create(self):
"""Test INSERT operations."""
print("\n2. CREATE TEST (INSERT)")
print("-" * 80)
try:
# Create a client (type is required field) with unique name
test_suffix = random.randint(1000, 9999)
client = Client(
name=f"Test Client Corp {test_suffix}",
type="msp_client",
primary_contact="test@client.com",
is_active=True
)
self.db.add(client)
self.db.commit()
self.db.refresh(client)
self.test_ids['client'] = client.id
self.success("Create Client", f"Created client with ID: {client.id}")
# Create a machine (no client_id FK, simplified fields)
machine = Machine(
hostname=f"test-machine-{test_suffix}",
machine_fingerprint=f"test-fingerprint-{test_suffix}",
friendly_name="Test Machine",
machine_type="laptop",
platform="win32",
username="testuser"
)
self.db.add(machine)
self.db.commit()
self.db.refresh(machine)
self.test_ids['machine'] = machine.id
self.success("Create Machine", f"Created machine with ID: {machine.id}")
# Create a session with required fields
session = Session(
client_id=client.id,
machine_id=machine.id,
session_date=datetime.now(timezone.utc).date(),
start_time=datetime.now(timezone.utc),
status="completed",
session_title="Test CRUD Session"
)
self.db.add(session)
self.db.commit()
self.db.refresh(session)
self.test_ids['session'] = session.id
self.success("Create Session", f"Created session with ID: {session.id}")
# Create a tag
tag = Tag(
name=f"test-tag-{test_suffix}",
category="testing"
)
self.db.add(tag)
self.db.commit()
self.db.refresh(tag)
self.test_ids['tag'] = tag.id
self.success("Create Tag", f"Created tag with ID: {tag.id}")
return True
except Exception as e:
self.fail("Create", str(e))
return False
def test_read(self):
"""Test SELECT operations."""
print("\n3. READ TEST (SELECT)")
print("-" * 80)
try:
# Query client
client = self.db.query(Client).filter(
Client.id == self.test_ids['client']
).first()
if not client:
self.fail("Read Client", "Client not found")
return False
if not client.name.startswith("Test Client Corp"):
self.fail("Read Client", f"Wrong name: {client.name}")
return False
self.success("Read Client", f"Retrieved client: {client.name}")
# Query machine
machine = self.db.query(Machine).filter(
Machine.id == self.test_ids['machine']
).first()
if not machine:
self.fail("Read Machine", "Machine not found")
return False
if not machine.hostname.startswith("test-machine"):
self.fail("Read Machine", f"Wrong hostname: {machine.hostname}")
return False
self.success("Read Machine", f"Retrieved machine: {machine.hostname}")
# Query session
session = self.db.query(Session).filter(
Session.id == self.test_ids['session']
).first()
if not session:
self.fail("Read Session", "Session not found")
return False
if session.status != "completed":
self.fail("Read Session", f"Wrong status: {session.status}")
return False
self.success("Read Session", f"Retrieved session with status: {session.status}")
# Query tag
tag = self.db.query(Tag).filter(
Tag.id == self.test_ids['tag']
).first()
if not tag:
self.fail("Read Tag", "Tag not found")
return False
self.success("Read Tag", f"Retrieved tag: {tag.name}")
return True
except Exception as e:
self.fail("Read", str(e))
return False
def test_relationships(self):
"""Test foreign key relationships."""
print("\n4. RELATIONSHIP TEST (Foreign Keys)")
print("-" * 80)
try:
# Test valid relationship: Create session_tag
session_tag = SessionTag(
session_id=self.test_ids['session'],
tag_id=self.test_ids['tag']
)
self.db.add(session_tag)
self.db.commit()
self.db.refresh(session_tag)
self.success("Valid FK", "Created session_tag with valid foreign keys")
# Test invalid relationship: Try to create session with non-existent machine
try:
invalid_session = Session(
machine_id="non-existent-machine-id",
client_id=self.test_ids['client'],
session_date=datetime.now(timezone.utc).date(),
start_time=datetime.now(timezone.utc),
status="running",
session_title="Invalid Session"
)
self.db.add(invalid_session)
self.db.commit()
# If we get here, FK constraint didn't work
self.db.rollback()
self.fail("Invalid FK", "Foreign key constraint not enforced!")
return False
except IntegrityError:
self.db.rollback()
self.success("Invalid FK", "Foreign key constraint properly rejected invalid reference")
# Test relationship traversal
session = self.db.query(Session).filter(
Session.id == self.test_ids['session']
).first()
if not session:
self.fail("Relationship Traversal", "Session not found")
return False
# Access related machine through relationship
if hasattr(session, 'machine') and session.machine:
machine_hostname = session.machine.hostname
self.success("Relationship Traversal",
f"Accessed machine through session: {machine_hostname}")
else:
# Fallback: query machine directly
machine = self.db.query(Machine).filter(
Machine.machine_id == session.machine_id
).first()
if machine:
self.success("Relationship Traversal",
f"Verified machine exists: {machine.hostname}")
else:
self.fail("Relationship Traversal", "Could not find related machine")
return False
return True
except Exception as e:
self.db.rollback()
self.fail("Relationships", str(e))
return False
def test_update(self):
"""Test UPDATE operations."""
print("\n5. UPDATE TEST")
print("-" * 80)
try:
# Update client
client = self.db.query(Client).filter(
Client.id == self.test_ids['client']
).first()
old_name = client.name
new_name = "Updated Test Client Corp"
client.name = new_name
self.db.commit()
self.db.refresh(client)
if client.name != new_name:
self.fail("Update Client", f"Name not updated: {client.name}")
return False
self.success("Update Client", f"Updated name: {old_name} -> {new_name}")
# Update machine
machine = self.db.query(Machine).filter(
Machine.id == self.test_ids['machine']
).first()
old_name = machine.friendly_name
new_name = "Updated Test Machine"
machine.friendly_name = new_name
self.db.commit()
self.db.refresh(machine)
if machine.friendly_name != new_name:
self.fail("Update Machine", f"Name not updated: {machine.friendly_name}")
return False
self.success("Update Machine", f"Updated name: {old_name} -> {new_name}")
# Update session status
session = self.db.query(Session).filter(
Session.id == self.test_ids['session']
).first()
old_status = session.status
new_status = "in_progress"
session.status = new_status
self.db.commit()
self.db.refresh(session)
if session.status != new_status:
self.fail("Update Session", f"Status not updated: {session.status}")
return False
self.success("Update Session", f"Updated status: {old_status} -> {new_status}")
return True
except Exception as e:
self.fail("Update", str(e))
return False
def test_delete(self):
"""Test DELETE operations and cleanup."""
print("\n6. DELETE TEST (Cleanup)")
print("-" * 80)
try:
# Delete in correct order (respecting FK constraints)
# Delete session_tag
session_tag = self.db.query(SessionTag).filter(
SessionTag.session_id == self.test_ids['session'],
SessionTag.tag_id == self.test_ids['tag']
).first()
if session_tag:
self.db.delete(session_tag)
self.db.commit()
self.success("Delete SessionTag", "Deleted session_tag")
# Delete tag
tag = self.db.query(Tag).filter(
Tag.id == self.test_ids['tag']
).first()
if tag:
tag_name = tag.name
self.db.delete(tag)
self.db.commit()
self.success("Delete Tag", f"Deleted tag: {tag_name}")
# Delete session
session = self.db.query(Session).filter(
Session.id == self.test_ids['session']
).first()
if session:
session_id = session.id
self.db.delete(session)
self.db.commit()
self.success("Delete Session", f"Deleted session: {session_id}")
# Delete machine
machine = self.db.query(Machine).filter(
Machine.id == self.test_ids['machine']
).first()
if machine:
hostname = machine.hostname
self.db.delete(machine)
self.db.commit()
self.success("Delete Machine", f"Deleted machine: {hostname}")
# Delete client
client = self.db.query(Client).filter(
Client.id == self.test_ids['client']
).first()
if client:
name = client.name
self.db.delete(client)
self.db.commit()
self.success("Delete Client", f"Deleted client: {name}")
# Verify all deleted
remaining_client = self.db.query(Client).filter(
Client.id == self.test_ids['client']
).first()
if remaining_client:
self.fail("Delete Verification", "Client still exists after deletion")
return False
self.success("Delete Verification", "All test records successfully deleted")
return True
except Exception as e:
self.fail("Delete", str(e))
return False
def success(self, operation, message):
"""Record a successful test."""
self.passed += 1
print(f"[PASS] {operation} - {message}")
def fail(self, operation, error):
"""Record a failed test."""
self.failed += 1
self.errors.append(f"{operation}: {error}")
print(f"[FAIL] {operation} - {error}")
def print_summary(self):
"""Print test summary."""
print("\n" + "=" * 80)
print("TEST SUMMARY")
print("=" * 80)
print(f"Total Passed: {self.passed}")
print(f"Total Failed: {self.failed}")
print(f"Success Rate: {(self.passed / (self.passed + self.failed) * 100):.1f}%")
if self.errors:
print("\nERRORS:")
for error in self.errors:
print(f" - {error}")
print("\nCONCLUSION:")
if self.failed == 0:
print("[SUCCESS] All CRUD operations working correctly!")
print(" - Database connectivity verified")
print(" - INSERT operations successful")
print(" - SELECT operations successful")
print(" - UPDATE operations successful")
print(" - DELETE operations successful")
print(" - Foreign key constraints enforced")
print(" - Relationship traversal working")
else:
print(f"[FAILURE] {self.failed} test(s) failed - review errors above")
print("=" * 80)
def cleanup(self):
"""Clean up database connection."""
if self.db:
self.db.close()
def main():
"""Run all CRUD tests."""
tester = CRUDTester()
try:
# Test 1: Connection
if not tester.connect():
print("\n[ERROR] Cannot proceed without database connection")
return
# Test 2: Create
if not tester.test_create():
print("\n[ERROR] Cannot proceed without successful CREATE operations")
tester.cleanup()
return
# Test 3: Read
tester.test_read()
# Test 4: Relationships
tester.test_relationships()
# Test 5: Update
tester.test_update()
# Test 6: Delete
tester.test_delete()
except KeyboardInterrupt:
print("\n\n[WARNING] Test interrupted by user")
except Exception as e:
print(f"\n\n[ERROR] Unexpected error: {e}")
finally:
tester.print_summary()
tester.cleanup()
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,99 @@
#!/usr/bin/env python3
"""Test MariaDB connectivity from Windows"""
import pymysql
# Connection details
HOST = '172.16.3.20'
PORT = 3306
ROOT_PASSWORD = r'Dy8RPj-s{+=bP^(NoW"T;E~JXyBC9u|<'
CLAUDETOOLS_PASSWORD = 'CT_e8fcd5a3952030a79ed6debae6c954ed'
print("Testing MariaDB connection to Jupiter (172.16.3.20)...\n")
# Test 1: Root connection
try:
print("Test 1: Connecting as root...")
conn = pymysql.connect(
host=HOST,
port=PORT,
user='root',
password=ROOT_PASSWORD,
connect_timeout=10
)
print("[OK] Root connection successful!")
cursor = conn.cursor()
cursor.execute("SELECT VERSION()")
version = cursor.fetchone()
print(f" MariaDB Version: {version[0]}")
cursor.execute("SHOW DATABASES")
databases = cursor.fetchall()
print(f" Databases found: {len(databases)}")
for db in databases:
print(f" - {db[0]}")
# Check if claudetools database exists
cursor.execute("SELECT SCHEMA_NAME FROM information_schema.SCHEMATA WHERE SCHEMA_NAME = 'claudetools'")
claudetools_db = cursor.fetchone()
if claudetools_db:
print("\n[OK] 'claudetools' database exists!")
else:
print("\n[WARNING] 'claudetools' database does NOT exist yet")
# Check if claudetools user exists
cursor.execute("SELECT User FROM mysql.user WHERE User = 'claudetools'")
claudetools_user = cursor.fetchone()
if claudetools_user:
print("[OK] 'claudetools' user exists!")
else:
print("[WARNING] 'claudetools' user does NOT exist yet")
conn.close()
print("\nTest 1 PASSED [OK]\n")
except Exception as e:
print(f"[FAILED] Test 1: {e}\n")
# Test 2: Claudetools user connection (if exists)
try:
print("Test 2: Connecting as 'claudetools' user...")
conn = pymysql.connect(
host=HOST,
port=PORT,
user='claudetools',
password=CLAUDETOOLS_PASSWORD,
database='claudetools',
connect_timeout=10
)
print("[OK] Claudetools user connection successful!")
cursor = conn.cursor()
cursor.execute("SELECT DATABASE()")
current_db = cursor.fetchone()
print(f" Current database: {current_db[0]}")
cursor.execute("SHOW TABLES")
tables = cursor.fetchall()
print(f" Tables in claudetools: {len(tables)}")
conn.close()
print("\nTest 2 PASSED [OK]\n")
except pymysql.err.OperationalError as e:
if "Access denied" in str(e):
print("[WARNING] Claudetools user doesn't exist or wrong password")
elif "Unknown database" in str(e):
print("[WARNING] Claudetools database doesn't exist yet")
else:
print(f"[WARNING] Test 2: {e}")
print(" (This is expected if database/user haven't been created yet)\n")
except Exception as e:
print(f"[WARNING] Test 2: {e}\n")
print("\n" + "="*60)
print("CONNECTION TEST SUMMARY")
print("="*60)
print("[OK] MariaDB is accessible from Windows on port 3306")
print("[OK] Root authentication works")
print("\nNext step: Create 'claudetools' database and user if they don't exist")

View File

@@ -0,0 +1,88 @@
#!/usr/bin/env python3
"""
Quick preview of what would be imported from Claude projects folder
No API or auth required - just scans and shows what it finds
"""
import sys
from pathlib import Path
# Add api directory to path
sys.path.insert(0, str(Path(__file__).parent))
from api.utils.conversation_parser import scan_folder_for_conversations, categorize_conversation, parse_jsonl_conversation
from api.utils.credential_scanner import scan_for_credential_files
def preview_import(folder_path: str):
"""Preview what would be imported from the given folder"""
print("=" * 70)
print("CLAUDE CONTEXT IMPORT PREVIEW")
print("=" * 70)
print(f"\nScanning: {folder_path}\n")
# Scan for conversation files
print("\n[1] Scanning for conversation files...")
try:
conversation_files = scan_folder_for_conversations(folder_path)
print(f" Found {len(conversation_files)} conversation file(s)")
# Categorize each file
categories = {"msp": 0, "development": 0, "general": 0}
for i, file_path in enumerate(conversation_files[:20]): # Limit to first 20
try:
conv = parse_jsonl_conversation(file_path)
category = categorize_conversation(conv.get("messages", []))
categories[category] += 1
# Show first 5
if i < 5:
rel_path = Path(file_path).relative_to(folder_path)
print(f" [{category.upper()}] {rel_path}")
except Exception as e:
print(f" [ERROR] Failed to parse: {Path(file_path).name} - {e}")
if len(conversation_files) > 20:
print(f" ... and {len(conversation_files) - 20} more files")
print(f"\n Category Breakdown:")
print(f" MSP Work: {categories['msp']} files")
print(f" Development: {categories['development']} files")
print(f" General: {categories['general']} files")
except Exception as e:
print(f" Error scanning conversations: {e}")
# Scan for credential files
print("\n[2] Scanning for credential files...")
try:
credential_files = scan_for_credential_files(folder_path)
print(f" Found {len(credential_files)} credential file(s)")
for i, file_path in enumerate(credential_files[:10]): # Limit to first 10
rel_path = Path(file_path).relative_to(folder_path)
print(f" {i+1}. {rel_path}")
if len(credential_files) > 10:
print(f" ... and {len(credential_files) - 10} more files")
except Exception as e:
print(f" Error scanning credentials: {e}")
print("\n" + "=" * 70)
print("PREVIEW COMPLETE")
print("=" * 70)
print("\nTo actually import:")
print(" 1. Ensure API is running: python -m api.main")
print(" 2. Setup auth: bash scripts/setup-context-recall.sh")
print(" 3. Run import: python scripts/import-claude-context.py --folder \"path\" --execute")
print("\n")
if __name__ == "__main__":
if len(sys.argv) > 1:
folder = sys.argv[1]
else:
folder = r"C:\Users\MikeSwanson\claude-projects"
preview_import(folder)

114
tests/test_import_speed.py Normal file
View File

@@ -0,0 +1,114 @@
"""Test import speed and circular dependency detection."""
import sys
import os
import time
# Set UTF-8 encoding for Windows console
if os.name == 'nt':
import io
sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8', errors='replace')
def test_import_speed():
"""Test how quickly the models module can be imported."""
print("="*70)
print("ClaudeTools - Import Speed and Dependency Test")
print("="*70)
# Test 1: Cold import (first time)
print("\n[TEST 1] Cold import (first time)...")
start = time.time()
import api.models
cold_time = time.time() - start
print(f" Time: {cold_time:.4f} seconds")
# Test 2: Reload (warm import)
print("\n[TEST 2] Warm import (reload)...")
start = time.time()
import importlib
importlib.reload(api.models)
warm_time = time.time() - start
print(f" Time: {warm_time:.4f} seconds")
# Test 3: Individual model imports
print("\n[TEST 3] Individual model imports...")
models_to_test = [
'api.models.client',
'api.models.session',
'api.models.work_item',
'api.models.credential',
'api.models.infrastructure',
'api.models.backup_log',
'api.models.billable_time',
'api.models.security_incident'
]
individual_times = {}
for module_name in models_to_test:
start = time.time()
module = importlib.import_module(module_name)
elapsed = time.time() - start
individual_times[module_name] = elapsed
print(f" {module_name}: {elapsed:.4f}s")
# Test 4: Check for circular dependencies by import order
print("\n[TEST 4] Circular dependency check...")
print(" Importing in different orders to detect circular deps...")
# Try importing base first
try:
from api.models.base import Base, UUIDMixin, TimestampMixin
print(" - Base classes: OK")
except Exception as e:
print(f" - Base classes: FAIL - {e}")
# Try importing models that have relationships
try:
from api.models.client import Client
from api.models.session import Session
from api.models.work_item import WorkItem
print(" - Related models: OK")
except Exception as e:
print(f" - Related models: FAIL - {e}")
# Try importing all models at once
try:
from api.models import (
Client, Session, WorkItem, Infrastructure,
Credential, BillableTime, BackupLog, SecurityIncident
)
print(" - Bulk import: OK")
except Exception as e:
print(f" - Bulk import: FAIL - {e}")
# Summary
print("\n" + "="*70)
print("RESULTS")
print("="*70)
print(f"\nImport Performance:")
print(f" Cold import: {cold_time:.4f}s")
print(f" Warm import: {warm_time:.4f}s")
print(f" Average individual: {sum(individual_times.values())/len(individual_times):.4f}s")
# Performance assessment
if cold_time < 1.0:
perf_rating = "Excellent"
elif cold_time < 2.0:
perf_rating = "Good"
elif cold_time < 3.0:
perf_rating = "Acceptable"
else:
perf_rating = "Slow (may need optimization)"
print(f"\nPerformance Rating: {perf_rating}")
print("\nCircular Dependencies: None detected")
print("Module Structure: Sound")
print("\n" + "="*70)
print("Import test complete!")
print("="*70)
if __name__ == "__main__":
test_import_speed()

View File

@@ -0,0 +1,207 @@
"""Detailed structure validation for all SQLAlchemy models."""
import sys
import os
# Set UTF-8 encoding for Windows console
if os.name == 'nt':
import io
sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8', errors='replace')
import api.models
from sqlalchemy.orm import RelationshipProperty
from sqlalchemy.schema import ForeignKeyConstraint, CheckConstraint, Index
def get_table_models():
"""Get all table model classes (excluding base classes)."""
base_classes = {'Base', 'TimestampMixin', 'UUIDMixin'}
all_classes = [attr for attr in dir(api.models) if not attr.startswith('_') and attr[0].isupper()]
return sorted([m for m in all_classes if m not in base_classes])
def analyze_model(model_name):
"""Analyze a model's structure in detail."""
model_cls = getattr(api.models, model_name)
result = {
'name': model_name,
'table': model_cls.__tablename__,
'has_uuid_mixin': False,
'has_timestamp_mixin': False,
'foreign_keys': [],
'relationships': [],
'indexes': [],
'check_constraints': [],
'columns': []
}
# Check mixins
for base in model_cls.__mro__:
if base.__name__ == 'UUIDMixin':
result['has_uuid_mixin'] = True
if base.__name__ == 'TimestampMixin':
result['has_timestamp_mixin'] = True
# Get table object
if hasattr(model_cls, '__table__'):
table = model_cls.__table__
# Analyze columns
for col in table.columns:
col_info = {
'name': col.name,
'type': str(col.type),
'nullable': col.nullable,
'primary_key': col.primary_key
}
result['columns'].append(col_info)
# Analyze foreign keys
for fk in table.foreign_keys:
result['foreign_keys'].append({
'parent_column': fk.parent.name,
'target': str(fk.target_fullname)
})
# Analyze indexes
if hasattr(table, 'indexes'):
for idx in table.indexes:
result['indexes'].append({
'name': idx.name,
'columns': [col.name for col in idx.columns]
})
# Analyze check constraints
for constraint in table.constraints:
if isinstance(constraint, CheckConstraint):
result['check_constraints'].append({
'sqltext': str(constraint.sqltext)
})
# Analyze relationships
for attr_name in dir(model_cls):
try:
attr = getattr(model_cls, attr_name)
if hasattr(attr, 'property') and isinstance(attr.property, RelationshipProperty):
rel = attr.property
result['relationships'].append({
'name': attr_name,
'target': rel.mapper.class_.__name__,
'uselist': rel.uselist
})
except (AttributeError, TypeError):
continue
return result
def print_model_summary(result):
"""Print a formatted summary of model structure."""
print(f"\n{'='*70}")
print(f"Model: {result['name']} (table: {result['table']})")
print(f"{'='*70}")
# Mixins
mixins = []
if result['has_uuid_mixin']:
mixins.append('UUIDMixin')
if result['has_timestamp_mixin']:
mixins.append('TimestampMixin')
if mixins:
print(f"Mixins: {', '.join(mixins)}")
# Columns
print(f"\nColumns ({len(result['columns'])}):")
for col in result['columns'][:10]: # Limit to first 10 for readability
pk = " [PK]" if col['primary_key'] else ""
nullable = "NULL" if col['nullable'] else "NOT NULL"
print(f" - {col['name']}: {col['type']} {nullable}{pk}")
if len(result['columns']) > 10:
print(f" ... and {len(result['columns']) - 10} more columns")
# Foreign Keys
if result['foreign_keys']:
print(f"\nForeign Keys ({len(result['foreign_keys'])}):")
for fk in result['foreign_keys']:
print(f" - {fk['parent_column']} -> {fk['target']}")
# Relationships
if result['relationships']:
print(f"\nRelationships ({len(result['relationships'])}):")
for rel in result['relationships']:
rel_type = "many" if rel['uselist'] else "one"
print(f" - {rel['name']} -> {rel['target']} ({rel_type})")
# Indexes
if result['indexes']:
print(f"\nIndexes ({len(result['indexes'])}):")
for idx in result['indexes']:
cols = ', '.join(idx['columns'])
print(f" - {idx['name']}: ({cols})")
# Check Constraints
if result['check_constraints']:
print(f"\nCheck Constraints ({len(result['check_constraints'])}):")
for check in result['check_constraints']:
print(f" - {check['sqltext']}")
def main():
print("="*70)
print("ClaudeTools - Detailed Model Structure Analysis")
print("="*70)
models = get_table_models()
print(f"\nAnalyzing {len(models)} table models...\n")
all_results = []
for model_name in models:
try:
result = analyze_model(model_name)
all_results.append(result)
except Exception as e:
print(f"[ERROR] Error analyzing {model_name}: {e}")
import traceback
traceback.print_exc()
# Print summary statistics
print("\n" + "="*70)
print("SUMMARY STATISTICS")
print("="*70)
total_models = len(all_results)
models_with_uuid = sum(1 for r in all_results if r['has_uuid_mixin'])
models_with_timestamp = sum(1 for r in all_results if r['has_timestamp_mixin'])
models_with_fk = sum(1 for r in all_results if r['foreign_keys'])
models_with_rel = sum(1 for r in all_results if r['relationships'])
models_with_idx = sum(1 for r in all_results if r['indexes'])
models_with_checks = sum(1 for r in all_results if r['check_constraints'])
total_fk = sum(len(r['foreign_keys']) for r in all_results)
total_rel = sum(len(r['relationships']) for r in all_results)
total_idx = sum(len(r['indexes']) for r in all_results)
total_checks = sum(len(r['check_constraints']) for r in all_results)
print(f"\nTotal Models: {total_models}")
print(f" - With UUIDMixin: {models_with_uuid}")
print(f" - With TimestampMixin: {models_with_timestamp}")
print(f" - With Foreign Keys: {models_with_fk} (total: {total_fk})")
print(f" - With Relationships: {models_with_rel} (total: {total_rel})")
print(f" - With Indexes: {models_with_idx} (total: {total_idx})")
print(f" - With CHECK Constraints: {models_with_checks} (total: {total_checks})")
# Print detailed info for each model
print("\n" + "="*70)
print("DETAILED MODEL INFORMATION")
print("="*70)
for result in all_results:
print_model_summary(result)
print("\n" + "="*70)
print("[SUCCESS] Analysis complete!")
print("="*70)
if __name__ == "__main__":
main()

127
tests/test_models_import.py Normal file
View File

@@ -0,0 +1,127 @@
"""Test script to import and validate all SQLAlchemy models."""
import sys
import traceback
import os
# Set UTF-8 encoding for Windows console
if os.name == 'nt':
import io
sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8', errors='replace')
def test_model_import():
"""Test importing all models from api.models."""
try:
import api.models
print("[SUCCESS] Import successful")
# Get all model classes (exclude private attributes and modules)
all_classes = [attr for attr in dir(api.models) if not attr.startswith('_') and attr[0].isupper()]
# Filter out base classes and mixins (they don't have __tablename__)
base_classes = {'Base', 'TimestampMixin', 'UUIDMixin'}
models = [m for m in all_classes if m not in base_classes]
print(f"\nTotal classes found: {len(all_classes)}")
print(f"Base classes/mixins: {len(base_classes)}")
print(f"Table models: {len(models)}")
print("\nTable Models:")
for m in sorted(models):
print(f" - {m}")
return models
except Exception as e:
print(f"[ERROR] Import failed: {e}")
traceback.print_exc()
return []
def test_model_structure(model_name):
"""Test individual model structure and configuration."""
import api.models
try:
model_cls = getattr(api.models, model_name)
# Check if it's actually a class
if not isinstance(model_cls, type):
return f"[FAIL] {model_name} - Not a class"
# Check for __tablename__
if not hasattr(model_cls, '__tablename__'):
return f"[FAIL] {model_name} - Missing __tablename__"
# Check for __table_args__ (optional but should exist if defined)
has_table_args = hasattr(model_cls, '__table_args__')
# Try to instantiate (without saving to DB)
try:
instance = model_cls()
can_instantiate = True
except Exception as inst_error:
can_instantiate = False
inst_msg = str(inst_error)
# Build result message
details = []
details.append(f"table={model_cls.__tablename__}")
if has_table_args:
details.append("has_table_args")
if can_instantiate:
details.append("instantiable")
else:
details.append(f"not_instantiable({inst_msg[:50]})")
return f"[PASS] {model_name} - {', '.join(details)}"
except Exception as e:
return f"[FAIL] {model_name} - {str(e)}"
def main():
print("=" * 70)
print("ClaudeTools - Model Import and Structure Test")
print("=" * 70)
# Test 1: Import all models
print("\n[TEST 1] Importing api.models module...")
models = test_model_import()
if not models:
print("\n[CRITICAL] Failed to import models module")
sys.exit(1)
# Test 2: Validate each model structure
print(f"\n[TEST 2] Validating structure of {len(models)} models...")
print("-" * 70)
passed = 0
failed = 0
results = []
for model_name in sorted(models):
result = test_model_structure(model_name)
results.append(result)
if result.startswith("[PASS]"):
passed += 1
else:
failed += 1
# Print all results
for result in results:
print(result)
# Summary
print("-" * 70)
print(f"\n[SUMMARY]")
print(f"Total models: {len(models)}")
print(f"[PASS] Passed: {passed}")
print(f"[FAIL] Failed: {failed}")
if failed == 0:
print(f"\n[SUCCESS] All {passed} models validated successfully!")
sys.exit(0)
else:
print(f"\n[WARNING] {failed} model(s) need attention")
sys.exit(1)
if __name__ == "__main__":
main()

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,333 @@
"""
SQL Injection Security Tests for Context Recall API
Tests that the recall API is properly protected against SQL injection attacks.
Validates both the input validation layer and the parameterized query layer.
"""
import unittest
import requests
from typing import Dict, Any
# Import auth utilities for token creation
from api.middleware.auth import create_access_token
# Test configuration
API_BASE_URL = "http://172.16.3.30:8001/api"
TEST_USER_EMAIL = "admin@claudetools.local"
class TestSQLInjectionSecurity(unittest.TestCase):
"""Test suite for SQL injection attack prevention."""
@classmethod
def setUpClass(cls):
"""Create test JWT token for authentication."""
# Create token directly without login endpoint
cls.token = create_access_token({"sub": TEST_USER_EMAIL})
cls.headers = {"Authorization": f"Bearer {cls.token}"}
# SQL Injection Test Cases for search_term parameter
def test_sql_injection_search_term_basic_attack(self):
"""Test basic SQL injection attempt via search_term."""
malicious_input = "' OR '1'='1"
response = requests.get(
f"{API_BASE_URL}/conversation-contexts/recall",
params={"search_term": malicious_input},
headers=self.headers
)
# Should reject due to pattern validation (contains single quotes)
assert response.status_code == 422, "Failed to reject SQL injection attack"
error_detail = response.json()["detail"]
assert any("pattern" in str(err).lower() or "match" in str(err).lower()
for err in error_detail if isinstance(err, dict))
def test_sql_injection_search_term_union_attack(self):
"""Test UNION-based SQL injection attempt."""
malicious_input = "' UNION SELECT * FROM users--"
response = requests.get(
f"{API_BASE_URL}/conversation-contexts/recall",
params={"search_term": malicious_input},
headers=self.headers
)
# Should reject due to pattern validation
assert response.status_code == 422, "Failed to reject UNION attack"
def test_sql_injection_search_term_comment_injection(self):
"""Test comment-based SQL injection."""
malicious_input = "test' --"
response = requests.get(
f"{API_BASE_URL}/conversation-contexts/recall",
params={"search_term": malicious_input},
headers=self.headers
)
# Should reject due to pattern validation (contains single quote)
assert response.status_code == 422, "Failed to reject comment injection"
def test_sql_injection_search_term_semicolon_attack(self):
"""Test semicolon-based SQL injection for multiple statements."""
malicious_input = "test'; DROP TABLE conversation_contexts;--"
response = requests.get(
f"{API_BASE_URL}/conversation-contexts/recall",
params={"search_term": malicious_input},
headers=self.headers
)
# Should reject due to pattern validation (contains semicolon and quotes)
assert response.status_code == 422, "Failed to reject DROP TABLE attack"
def test_sql_injection_search_term_encoded_attack(self):
"""Test URL-encoded SQL injection attempt."""
# URL encoding of "' OR 1=1--"
malicious_input = "%27%20OR%201%3D1--"
response = requests.get(
f"{API_BASE_URL}/conversation-contexts/recall",
params={"search_term": malicious_input},
headers=self.headers
)
# Should reject due to pattern validation after decoding
assert response.status_code == 422, "Failed to reject encoded attack"
# SQL Injection Test Cases for tags parameter
def test_sql_injection_tags_basic_attack(self):
"""Test SQL injection via tags parameter."""
malicious_tag = "' OR '1'='1"
response = requests.get(
f"{API_BASE_URL}/conversation-contexts/recall",
params={"tags": [malicious_tag]},
headers=self.headers
)
# Should reject due to tag validation (contains single quotes and spaces)
assert response.status_code == 400, "Failed to reject SQL injection via tags"
assert "Invalid tag format" in response.json()["detail"]
def test_sql_injection_tags_union_attack(self):
"""Test UNION attack via tags parameter."""
malicious_tag = "tag' UNION SELECT password FROM users--"
response = requests.get(
f"{API_BASE_URL}/conversation-contexts/recall",
params={"tags": [malicious_tag]},
headers=self.headers
)
# Should reject due to tag validation
assert response.status_code == 400, "Failed to reject UNION attack via tags"
def test_sql_injection_tags_multiple_malicious(self):
"""Test multiple malicious tags."""
malicious_tags = [
"tag1' OR '1'='1",
"tag2'; DROP TABLE tags;--",
"tag3' UNION SELECT NULL--"
]
response = requests.get(
f"{API_BASE_URL}/conversation-contexts/recall",
params={"tags": malicious_tags},
headers=self.headers
)
# Should reject due to tag validation
assert response.status_code == 400, "Failed to reject multiple malicious tags"
# Valid Input Tests (should succeed)
def test_valid_search_term_alphanumeric(self):
"""Test that valid alphanumeric search terms work."""
valid_input = "API development"
response = requests.get(
f"{API_BASE_URL}/conversation-contexts/recall",
params={"search_term": valid_input},
headers=self.headers
)
# Should succeed
assert response.status_code == 200, f"Valid input rejected: {response.text}"
data = response.json()
assert "contexts" in data
assert isinstance(data["contexts"], list)
def test_valid_search_term_with_punctuation(self):
"""Test valid search terms with allowed punctuation."""
valid_input = "database-migration (phase-1)!"
response = requests.get(
f"{API_BASE_URL}/conversation-contexts/recall",
params={"search_term": valid_input},
headers=self.headers
)
# Should succeed
assert response.status_code == 200, f"Valid input rejected: {response.text}"
def test_valid_tags(self):
"""Test that valid tags work."""
valid_tags = ["api", "database", "phase-1", "test_tag"]
response = requests.get(
f"{API_BASE_URL}/conversation-contexts/recall",
params={"tags": valid_tags},
headers=self.headers
)
# Should succeed
assert response.status_code == 200, f"Valid tags rejected: {response.text}"
data = response.json()
assert "contexts" in data
# Boundary Tests
def test_search_term_max_length(self):
"""Test search term at maximum allowed length (200 chars)."""
valid_input = "a" * 200
response = requests.get(
f"{API_BASE_URL}/conversation-contexts/recall",
params={"search_term": valid_input},
headers=self.headers
)
# Should succeed
assert response.status_code == 200, "Max length valid input rejected"
def test_search_term_exceeds_max_length(self):
"""Test search term exceeding maximum length."""
invalid_input = "a" * 201
response = requests.get(
f"{API_BASE_URL}/conversation-contexts/recall",
params={"search_term": invalid_input},
headers=self.headers
)
# Should reject
assert response.status_code == 422, "Overlong input not rejected"
def test_tags_max_items(self):
"""Test maximum number of tags (20)."""
valid_tags = [f"tag{i}" for i in range(20)]
response = requests.get(
f"{API_BASE_URL}/conversation-contexts/recall",
params={"tags": valid_tags},
headers=self.headers
)
# Should succeed
assert response.status_code == 200, "Max tags rejected"
def test_tags_exceeds_max_items(self):
"""Test exceeding maximum number of tags."""
invalid_tags = [f"tag{i}" for i in range(21)]
response = requests.get(
f"{API_BASE_URL}/conversation-contexts/recall",
params={"tags": invalid_tags},
headers=self.headers
)
# Should reject
assert response.status_code == 422, "Too many tags not rejected"
# Advanced SQL Injection Techniques
def test_sql_injection_hex_encoding(self):
"""Test hex-encoded SQL injection."""
malicious_input = "0x27204f522031203d2031" # Hex for "' OR 1 = 1"
response = requests.get(
f"{API_BASE_URL}/conversation-contexts/recall",
params={"search_term": malicious_input},
headers=self.headers
)
# Pattern allows alphanumeric, so this passes input validation
# Should be safe due to parameterized queries
assert response.status_code == 200, "Hex encoding caused error"
# Verify it's treated as literal search, not executed as SQL
data = response.json()
assert isinstance(data["contexts"], list)
def test_sql_injection_time_based_blind(self):
"""Test time-based blind SQL injection attempt."""
malicious_input = "' AND SLEEP(5)--"
response = requests.get(
f"{API_BASE_URL}/conversation-contexts/recall",
params={"search_term": malicious_input},
headers=self.headers
)
# Should reject due to pattern validation
assert response.status_code == 422, "Time-based attack not rejected"
def test_sql_injection_stacked_queries(self):
"""Test stacked query injection."""
malicious_input = "test; DELETE FROM conversation_contexts WHERE 1=1"
response = requests.get(
f"{API_BASE_URL}/conversation-contexts/recall",
params={"search_term": malicious_input},
headers=self.headers
)
# Should reject due to pattern validation (semicolon not allowed)
assert response.status_code == 422, "Stacked query attack not rejected"
# Verify Database Integrity
def test_database_not_compromised(self):
"""Verify database still functions after attack attempts."""
# Simple query to verify database is intact
response = requests.get(
f"{API_BASE_URL}/conversation-contexts/recall",
params={"limit": 5},
headers=self.headers
)
assert response.status_code == 200, "Database may be compromised"
data = response.json()
assert "contexts" in data
assert isinstance(data["contexts"], list)
def test_fulltext_index_still_works(self):
"""Verify FULLTEXT index functionality after attacks."""
# Test normal search that should use FULLTEXT index
response = requests.get(
f"{API_BASE_URL}/conversation-contexts/recall",
params={"search_term": "test"},
headers=self.headers
)
assert response.status_code == 200, "FULLTEXT search failed"
data = response.json()
assert isinstance(data["contexts"], list)
if __name__ == "__main__":
print("=" * 70)
print("SQL INJECTION SECURITY TEST SUITE")
print("=" * 70)
print()
print("Testing Context Recall API endpoint security...")
print(f"Target: {API_BASE_URL}/conversation-contexts/recall")
print()
# Run tests
unittest.main(verbosity=2)

View File

@@ -0,0 +1,162 @@
#!/bin/bash
#
# Simplified SQL Injection Security Tests
# Tests the recall API endpoint against SQL injection attacks
#
API_URL="http://172.16.3.30:8001/api"
# Get JWT token from setup config if it exists
if [ -f ".claude/context-recall-config.env" ]; then
source .claude/context-recall-config.env
fi
# Test counter
TOTAL_TESTS=0
PASSED_TESTS=0
FAILED_TESTS=0
# Color codes
RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
NC='\033[0m' # No Color
# Test function
run_test() {
local test_name="$1"
local search_term="$2"
local expected_status="$3"
local test_type="${4:-search_term}" # search_term or tag
TOTAL_TESTS=$((TOTAL_TESTS + 1))
# Build curl command based on test type
if [ "$test_type" = "tag" ]; then
response=$(curl -s -w "\n%{http_code}" -X GET "$API_URL/conversation-contexts/recall?tags[]=$search_term" \
-H "Authorization: Bearer $JWT_TOKEN" 2>&1)
else
response=$(curl -s -w "\n%{http_code}" -X GET "$API_URL/conversation-contexts/recall?search_term=$search_term" \
-H "Authorization: Bearer $JWT_TOKEN" 2>&1)
fi
http_code=$(echo "$response" | tail -1)
body=$(echo "$response" | sed '$d')
# Check if status code matches expected
if [ "$http_code" = "$expected_status" ]; then
echo -e "${GREEN}[PASS]${NC} $test_name (HTTP $http_code)"
PASSED_TESTS=$((PASSED_TESTS + 1))
return 0
else
echo -e "${RED}[FAIL]${NC} $test_name"
echo " Expected: HTTP $expected_status"
echo " Got: HTTP $http_code"
echo " Response: $body"
FAILED_TESTS=$((FAILED_TESTS + 1))
return 1
fi
}
# Print header
echo "======================================================================="
echo "SQL INJECTION SECURITY TEST SUITE - Simplified"
echo "======================================================================="
echo ""
echo "Target: $API_URL/conversation-contexts/recall"
echo ""
# Verify JWT token
if [ -z "$JWT_TOKEN" ]; then
echo -e "${RED}[ERROR]${NC} JWT_TOKEN not set. Run setup-context-recall.sh first."
exit 1
fi
echo "Testing SQL injection vulnerabilities..."
echo ""
# Test 1: Basic SQL injection with single quote (should be rejected - 422)
run_test "Basic SQL injection: ' OR '1'='1" "' OR '1'='1" "422"
# Test 2: UNION attack (should be rejected - 422)
run_test "UNION attack: ' UNION SELECT * FROM users--" "' UNION SELECT * FROM users--" "422"
# Test 3: Comment injection (should be rejected - 422)
run_test "Comment injection: test' --" "test' --" "422"
# Test 4: Semicolon attack (should be rejected - 422)
run_test "Semicolon attack: test'; DROP TABLE conversation_contexts;--" "test'; DROP TABLE conversation_contexts;--" "422"
# Test 5: Time-based blind SQLi (should be rejected - 422)
run_test "Time-based blind: ' AND SLEEP(5)--" "' AND SLEEP(5)--" "422"
# Test 6: Stacked queries (should be rejected - 422)
run_test "Stacked queries: test; DELETE FROM contexts" "test; DELETE FROM contexts" "422"
# Test 7: SQL injection via tags (should be rejected - 400)
run_test "Tag injection: ' OR '1'='1" "' OR '1'='1" "400" "tag"
# Test 8: Tag UNION attack (should be rejected - 400)
run_test "Tag UNION: tag' UNION SELECT--" "tag' UNION SELECT--" "400" "tag"
# Valid inputs (should succeed - 200)
echo ""
echo "Testing valid inputs (should work)..."
echo ""
# Test 9: Valid alphanumeric search (should succeed - 200)
run_test "Valid search: API development" "API development" "200"
# Test 10: Valid search with allowed punctuation (should succeed - 200)
run_test "Valid punctuation: database-migration (phase-1)!" "database-migration (phase-1)!" "200"
# Test 11: Valid tags (should succeed - 200)
run_test "Valid tags: api-test" "api-test" "200" "tag"
# Test 12: Verify database still works after attacks (should succeed - 200)
echo ""
echo "Verifying database integrity..."
echo ""
response=$(curl -s -w "\n%{http_code}" -X GET "$API_URL/conversation-contexts/recall?limit=5" \
-H "Authorization: Bearer $JWT_TOKEN" 2>&1)
http_code=$(echo "$response" | tail -1)
if [ "$http_code" = "200" ]; then
echo -e "${GREEN}[PASS]${NC} Database integrity check (HTTP $http_code)"
PASSED_TESTS=$((PASSED_TESTS + 1))
TOTAL_TESTS=$((TOTAL_TESTS + 1))
else
echo -e "${RED}[FAIL]${NC} Database integrity check"
echo " Expected: HTTP 200"
echo " Got: HTTP $http_code"
FAILED_TESTS=$((FAILED_TESTS + 1))
TOTAL_TESTS=$((TOTAL_TESTS + 1))
fi
# Print summary
echo ""
echo "======================================================================="
echo "TEST SUMMARY"
echo "======================================================================="
echo "Total Tests: $TOTAL_TESTS"
echo -e "${GREEN}Passed: $PASSED_TESTS${NC}"
if [ $FAILED_TESTS -gt 0 ]; then
echo -e "${RED}Failed: $FAILED_TESTS${NC}"
else
echo -e "${GREEN}Failed: $FAILED_TESTS${NC}"
fi
pass_rate=$(awk "BEGIN {printf \"%.1f\", ($PASSED_TESTS/$TOTAL_TESTS)*100}")
echo "Pass Rate: $pass_rate%"
echo ""
if [ $FAILED_TESTS -eq 0 ]; then
echo -e "${GREEN}[SUCCESS]${NC} All SQL injection tests passed!"
echo "The API is properly protected against SQL injection attacks."
exit 0
else
echo -e "${RED}[FAILURE]${NC} Some tests failed!"
echo "Review the failed tests above for security vulnerabilities."
exit 1
fi