"""
Test script for Credentials Management API.

This script tests the credentials API endpoints including encryption,
decryption, and audit logging functionality.

Run it directly to execute every test in order; the process exits with
status 1 on the first failure. The ``test_*`` functions raise on failure
(rather than calling ``sys.exit``) and delete any credentials they created,
even when a step fails part-way through.
"""

import sys
import traceback
from datetime import datetime
from uuid import uuid4

from api.database import get_db
from api.models.credential import Credential
from api.models.credential_audit_log import CredentialAuditLog
from api.schemas.credential import CredentialCreate, CredentialUpdate
from api.services.credential_service import (
    create_credential,
    delete_credential,
    get_credential_by_id,
    get_credentials,
    update_credential,
)
from api.utils.crypto import decrypt_string, encrypt_string


def _count_audit_logs(db, credential_id, action=None):
    """Return how many audit-log rows reference *credential_id*.

    Args:
        db: Database session used for the query.
        credential_id: Credential id; compared as a string, as the audit
            log's ``credential_id`` column is queried with ``str(...)``.
        action: If given, only rows with this ``action`` value are counted.
    """
    query = db.query(CredentialAuditLog).filter(
        CredentialAuditLog.credential_id == str(credential_id)
    )
    if action is not None:
        query = query.filter(CredentialAuditLog.action == action)
    return len(query.all())


def _decrypt_field(encrypted, label):
    """Decrypt an ``*_encrypted`` model attribute and return the plaintext.

    The stored value is UTF-8-decoded before being passed to
    ``decrypt_string``.

    Raises:
        AssertionError: if the attribute is empty/None, i.e. the secret was
            not stored encrypted at all (previously this case was skipped).
    """
    assert encrypted, f"{label} was not stored encrypted!"
    return decrypt_string(encrypted.decode("utf-8"))


def _cleanup_credentials(db, credential_ids, user_id):
    """Best-effort deletion of credentials left behind by a failed test.

    Each failure is reported and rolled back instead of raised, so cleanup
    never masks the exception that caused the test to fail.
    """
    for cred_id in credential_ids:
        try:
            delete_credential(db, cred_id, user_id=user_id)
            print(f" [CLEANUP] Deleted leftover credential {cred_id}")
        except Exception as exc:
            db.rollback()
            print(f" [WARN] Could not clean up credential {cred_id}: {exc}")


def test_encryption_decryption():
    """Test basic encryption and decryption."""
    print("\n=== Testing Encryption/Decryption ===")
    test_password = "SuperSecretPassword123!"
    print(f"Original password: {test_password}")

    # Encrypt
    encrypted = encrypt_string(test_password)
    print(f"Encrypted (length: {len(encrypted)}): {encrypted[:50]}...")
    assert encrypted != test_password, "encrypt_string returned the plaintext!"

    # Decrypt
    decrypted = decrypt_string(encrypted)
    print(f"Decrypted: {decrypted}")

    assert test_password == decrypted, "Encryption/decryption mismatch!"
    print("[PASS] Encryption/decryption test passed")


def test_credential_lifecycle():
    """Test the full credential lifecycle: create, read, update, list, delete.

    Every step asserts its outcome: the encrypted round-trip, the audit-log
    entries for create/view/update, and removal from the database. If any
    step fails, the credential created in step 1 is deleted in ``finally``.
    """
    print("\n=== Testing Credential Lifecycle ===")
    user_id = "test_user_123"
    db = next(get_db())
    # Id of a credential this test created and has not yet deleted itself.
    pending_id = None

    try:
        # 1. CREATE
        print("\n1. Creating credential...")
        credential_data = CredentialCreate(
            credential_type="password",
            service_name="Test Service",
            username="admin",
            password="MySecurePassword123!",
            external_url="https://test.example.com",
            requires_vpn=False,
            requires_2fa=True,
            is_active=True,
        )
        created = create_credential(
            db=db,
            credential_data=credential_data,
            user_id=user_id,
            ip_address="127.0.0.1",
            user_agent="Test Script",
        )
        pending_id = created.id
        print(f"[PASS] Created credential ID: {created.id}")
        print(f" Service: {created.service_name}")
        print(f" Type: {created.credential_type}")
        print(f" Password encrypted: {created.password_encrypted is not None}")

        # Verify encryption (fails if the column was never populated)
        decrypted_password = _decrypt_field(created.password_encrypted, "Password")
        assert decrypted_password == "MySecurePassword123!", "Password encryption failed!"
        print(" [PASS] Password correctly encrypted and decrypted")

        # Verify audit log was created
        create_logs = _count_audit_logs(db, created.id)
        assert create_logs > 0, "No audit log written on create!"
        print(f" [PASS] Audit logs created: {create_logs}")

        # 2. READ
        print("\n2. Reading credential...")
        retrieved = get_credential_by_id(db, created.id, user_id=user_id)
        assert retrieved is not None, "Created credential could not be read back!"
        print(f"[PASS] Retrieved credential: {retrieved.service_name}")

        view_logs = _count_audit_logs(db, created.id, action="view")
        assert view_logs > 0, "View action was not logged!"
        print(f" [PASS] View action logged: {view_logs} entr(y/ies)")

        # 3. UPDATE
        print("\n3. Updating credential...")
        update_data = CredentialUpdate(
            password="NewSecurePassword456!",
            # Naive UTC timestamp, as before; presumably matches the column
            # type -- confirm before switching to an aware datetime.
            last_rotated_at=datetime.utcnow(),
            external_url="https://test-updated.example.com",
        )
        updated = update_credential(
            db=db,
            credential_id=created.id,
            credential_data=update_data,
            user_id=user_id,
            ip_address="127.0.0.1",
        )
        print(f"[PASS] Updated credential: {updated.service_name}")
        print(f" Password re-encrypted: {updated.password_encrypted is not None}")

        decrypted_new_password = _decrypt_field(updated.password_encrypted, "Password")
        assert decrypted_new_password == "NewSecurePassword456!", "Password update failed!"
        print(" [PASS] New password correctly encrypted")

        update_logs = _count_audit_logs(db, created.id, action="update")
        assert update_logs > 0, "Update action was not logged!"
        print(f" [PASS] Update action logged: {update_logs} entr(y/ies)")

        # 4. LIST
        print("\n4. Listing credentials...")
        credentials, total = get_credentials(db, skip=0, limit=10)
        # The credential created above still exists, so the total cannot be 0.
        assert total >= 1, "Listing returned no credentials!"
        print(f"[PASS] Found {total} total credentials")
        print(f" Retrieved {len(credentials)} credentials in this page")

        # 5. DELETE
        print("\n5. Deleting credential...")
        result = delete_credential(
            db=db,
            credential_id=created.id,
            user_id=user_id,
            ip_address="127.0.0.1",
        )
        print(f"[PASS] {result['message']}")

        remaining = db.query(Credential).filter(Credential.id == str(created.id)).first()
        assert remaining is None, "Credential was not deleted!"
        pending_id = None  # deleted by the test itself; nothing to clean up
        print(" [PASS] Credential successfully removed from database")

        # Whether the delete audit row survives depends on the audit log's
        # foreign-key ondelete rule, which is not visible here, so this is
        # reported rather than asserted.
        delete_logs = _count_audit_logs(db, created.id, action="delete")
        print(f" [INFO] Delete action logged: {delete_logs > 0}")

        print("\n[PASS] All credential lifecycle tests passed!")
    finally:
        if pending_id is not None:
            # A failed DB call can leave the session unusable; reset it first.
            db.rollback()
            _cleanup_credentials(db, [pending_id], user_id)
        db.close()


def test_multiple_credential_types():
    """Test creating credentials with different types and encrypted fields.

    Creates an API-key, an OAuth and a connection-string credential, checks
    that each secret round-trips through encryption, then deletes them.
    Credentials still present when a step fails are deleted in ``finally``.
    """
    print("\n=== Testing Multiple Credential Types ===")
    user_id = "test_user"
    db = next(get_db())
    # Ids created by this test and not yet deleted.
    credential_ids = []

    try:
        # Test API Key credential
        print("\n1. Creating API Key credential...")
        api_key_data = CredentialCreate(
            credential_type="api_key",
            service_name="GitHub API",
            api_key="ghp_abcdef1234567890",
            external_url="https://api.github.com",
            is_active=True,
        )
        api_cred = create_credential(db, api_key_data, user_id=user_id)
        credential_ids.append(api_cred.id)
        print(f"[PASS] Created API Key credential: {api_cred.id}")

        decrypted_key = _decrypt_field(api_cred.api_key_encrypted, "API key")
        assert decrypted_key == "ghp_abcdef1234567890", "API key encryption failed!"
        print(" [PASS] API key correctly encrypted")

        # Test OAuth credential
        print("\n2. Creating OAuth credential...")
        oauth_data = CredentialCreate(
            credential_type="oauth",
            service_name="Microsoft 365",
            client_id_oauth="app-client-id-123",
            client_secret="secret_value_xyz789",
            tenant_id_oauth="tenant-id-456",
            is_active=True,
        )
        oauth_cred = create_credential(db, oauth_data, user_id=user_id)
        credential_ids.append(oauth_cred.id)
        print(f"[PASS] Created OAuth credential: {oauth_cred.id}")

        decrypted_secret = _decrypt_field(oauth_cred.client_secret_encrypted, "Client secret")
        assert decrypted_secret == "secret_value_xyz789", "OAuth secret encryption failed!"
        print(" [PASS] Client secret correctly encrypted")

        # Test Connection String credential
        print("\n3. Creating Connection String credential...")
        conn_data = CredentialCreate(
            credential_type="connection_string",
            service_name="SQL Server",
            connection_string="Server=localhost;Database=TestDB;User Id=sa;Password=ComplexPass123!;",
            internal_url="sql.internal.local",
            custom_port=1433,
            is_active=True,
        )
        conn_cred = create_credential(db, conn_data, user_id=user_id)
        credential_ids.append(conn_cred.id)
        print(f"[PASS] Created Connection String credential: {conn_cred.id}")

        decrypted_conn = _decrypt_field(conn_cred.connection_string_encrypted, "Connection string")
        assert "ComplexPass123!" in decrypted_conn, "Connection string encryption failed!"
        print(" [PASS] Connection string correctly encrypted")

        print(f"\n[PASS] Created {len(credential_ids)} different credential types")

        # Cleanup -- ids are removed from the list as they are deleted so
        # that ``finally`` only retries the ones that are still present.
        print("\n4. Cleaning up test credentials...")
        cleaned = len(credential_ids)
        for cred_id in list(credential_ids):
            delete_credential(db, cred_id, user_id=user_id)
            credential_ids.remove(cred_id)
        print(f"[PASS] Cleaned up {cleaned} credentials")

        print("\n[PASS] All multi-type credential tests passed!")
    finally:
        if credential_ids:
            # A failed DB call can leave the session unusable; reset it first.
            db.rollback()
            _cleanup_credentials(db, credential_ids, user_id)
        db.close()


def main():
    """Run all tests in order; print a summary and exit with status 1 on failure."""
    print("=" * 60)
    print("CREDENTIALS API TEST SUITE")
    print("=" * 60)

    try:
        test_encryption_decryption()
        test_credential_lifecycle()
        test_multiple_credential_types()
    except Exception:
        print("\n" + "=" * 60)
        print("[FAIL] TEST SUITE FAILED")
        print("=" * 60)
        traceback.print_exc()
        sys.exit(1)
    else:
        print("\n" + "=" * 60)
        print("[PASS] ALL TESTS PASSED!")
        print("=" * 60)


if __name__ == "__main__":
    main()