"""
Test script for credential scanner and importer.

This script demonstrates the credential scanner functionality including:
- Creating sample credential files
- Scanning for credential files
- Parsing credential data
- Importing credentials to database

Usage:
    python test_credential_scanner.py

NOTE(review): tests 3 and 4 open a real session via ``SessionLocal`` and
pass the sample credentials to the importer with ``user_id="test_user"``.
Nothing in this script removes what gets imported; whether rows persist
depends on whether the importer commits (presumably it does -- confirm).
Run against a disposable database.
"""

import logging
import os
import tempfile
from pathlib import Path
from typing import List

from api.database import SessionLocal
from api.utils.credential_scanner import (
    scan_for_credential_files,
    parse_credential_file,
    import_credentials_to_db,
    scan_and_import_credentials,
)

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Width of the "=====" separator lines that frame each test banner.
_BANNER_WIDTH = 60


def _log_banner(title: str, *, leading_newline: bool = True) -> None:
    """Log *title* at INFO level framed by two separator lines.

    Args:
        title: Text logged between the separators.
        leading_newline: Prefix the first separator with a newline so the
            banner is visually detached from the previous test's output.
    """
    rule = "=" * _BANNER_WIDTH
    logger.info("\n" + rule if leading_newline else rule)
    logger.info(title)
    logger.info(rule)


def create_sample_credential_files(temp_dir: str) -> List[str]:
    """Create sample credential files for testing.

    Writes three files containing fake credentials into *temp_dir*:
    ``credentials.md`` (markdown ``## Section`` + ``Key: value`` lines),
    ``.env`` (``KEY=value`` lines) and ``passwords.txt`` (markdown-style
    sections).

    Args:
        temp_dir: Existing directory to write the files into.

    Returns:
        Paths of the created files as strings, in the order
        credentials.md, .env, passwords.txt.
    """
    # Create credentials.md
    credentials_md = Path(temp_dir) / "credentials.md"
    credentials_md.write_text("""# Sample Credentials

## Gitea Admin
Username: admin
Password: GitSecurePass123!
URL: https://git.example.com
Notes: Main admin account

## Database Server
Type: connection_string
Connection String: mysql://dbuser:dbpass@192.168.1.50:3306/mydb
Notes: Production database

## OpenAI API
API Key: sk-1234567890abcdefghijklmnopqrstuvwxyz
Notes: Production API key
""")

    # Create .env file
    env_file = Path(temp_dir) / ".env"
    env_file.write_text("""# Environment Variables
DATABASE_URL=postgresql://user:pass@localhost:5432/testdb
API_TOKEN=ghp_abc123def456ghi789jkl012mno345pqr678
SECRET_KEY=super_secret_key_12345
""")

    # Create passwords.txt
    passwords_txt = Path(temp_dir) / "passwords.txt"
    passwords_txt.write_text("""# Server Passwords

## Web Server
Username: webadmin
Password: Web@dmin2024!
Host: 192.168.1.100
Port: 22

## Backup Server
Username: backup
Password: BackupSecure789
Host: 10.0.0.50
""")

    logger.info(f"Created sample credential files in: {temp_dir}")
    return [str(credentials_md), str(env_file), str(passwords_txt)]


def test_scan_for_credential_files():
    """Test credential file scanning.

    Asserts that the scanner finds exactly the three sample files.

    Returns:
        The list returned by ``scan_for_credential_files``. These paths
        point into a temporary directory that has already been deleted
        when this function returns. Under pytest, a non-None return from a
        test function triggers a PytestReturnNotNoneWarning.
    """
    _log_banner("TEST 1: Scan for Credential Files", leading_newline=False)

    with tempfile.TemporaryDirectory() as temp_dir:
        # Create sample files
        create_sample_credential_files(temp_dir)

        # Scan for files
        found_files = scan_for_credential_files(temp_dir)

        logger.info(f"\nFound {len(found_files)} credential file(s):")
        for file_path in found_files:
            logger.info(f"  - {file_path}")

        assert len(found_files) == 3, "Should find 3 credential files"
        logger.info("\n[PASS] Test 1 passed")

    return found_files


def test_parse_credential_file():
    """Test credential file parsing.

    Parses each sample file and logs the parsed fields. Secret values
    (password, API key, connection string) are never logged, only their
    presence. Asserts that at least one credential was parsed overall.
    """
    _log_banner("TEST 2: Parse Credential Files")

    with tempfile.TemporaryDirectory() as temp_dir:
        # Create sample files
        sample_files = create_sample_credential_files(temp_dir)

        total_credentials = 0
        for file_path in sample_files:
            credentials = parse_credential_file(file_path)
            total_credentials += len(credentials)

            logger.info(f"\nParsed from {Path(file_path).name}:")
            for cred in credentials:
                logger.info(f"  Service: {cred.get('service_name')}")
                logger.info(f"  Type: {cred.get('credential_type')}")
                if cred.get('username'):
                    logger.info(f"  Username: {cred.get('username')}")
                # Don't log actual passwords/keys
                if cred.get('password'):
                    logger.info("  Password: [REDACTED]")
                if cred.get('api_key'):
                    logger.info("  API Key: [REDACTED]")
                if cred.get('connection_string'):
                    logger.info("  Connection String: [REDACTED]")
                logger.info("")

        logger.info(f"Total credentials parsed: {total_credentials}")
        assert total_credentials > 0, "Should parse at least one credential"
        logger.info("[PASS] Test 2 passed")


def test_import_credentials_to_db():
    """Test importing credentials to database.

    Parses the first sample file (credentials.md) and hands the result to
    ``import_credentials_to_db`` using a real ``SessionLocal`` session.
    The session is always closed. Errors are logged and re-raised.

    NOTE(review): no assertion is made on the imported count, and nothing
    here removes the imported rows; confirm whether the importer commits.

    Returns:
        The value returned by ``import_credentials_to_db``. Under pytest,
        a non-None return triggers a PytestReturnNotNoneWarning.
    """
    _log_banner("TEST 3: Import Credentials to Database")

    db = SessionLocal()

    try:
        with tempfile.TemporaryDirectory() as temp_dir:
            # Create sample files
            sample_files = create_sample_credential_files(temp_dir)

            # Parse first file
            credentials = parse_credential_file(sample_files[0])
            logger.info(f"\nParsed {len(credentials)} credential(s) from file")

            # Import to database
            imported_count = import_credentials_to_db(
                db=db,
                credentials=credentials,
                client_id=None,  # No client association for test
                user_id="test_user",
                ip_address="127.0.0.1"
            )

            logger.info(f"\n[OK] Successfully imported {imported_count} credential(s)")
            logger.info("[PASS] Test 3 passed")

            return imported_count

    except Exception as e:
        logger.error(f"Import failed: {e}")
        raise
    finally:
        db.close()


def test_full_workflow():
    """Test complete scan and import workflow.

    Runs ``scan_and_import_credentials`` over a directory of sample files
    using a real ``SessionLocal`` session. Asserts that files were found
    and credentials parsed. The imported count is only logged: it may
    legitimately be lower if earlier runs already imported the same data
    (an assumption about the importer -- confirm). The session is always
    closed. Errors are logged and re-raised.
    """
    _log_banner("TEST 4: Full Scan and Import Workflow")

    db = SessionLocal()

    try:
        with tempfile.TemporaryDirectory() as temp_dir:
            # Create sample files
            create_sample_credential_files(temp_dir)

            # Run full workflow
            results = scan_and_import_credentials(
                base_path=temp_dir,
                db=db,
                client_id=None,
                user_id="test_user",
                ip_address="127.0.0.1"
            )

            logger.info("\nWorkflow Results:")
            logger.info(f"  Files found: {results['files_found']}")
            logger.info(f"  Credentials parsed: {results['credentials_parsed']}")
            logger.info(f"  Credentials imported: {results['credentials_imported']}")

            assert results['files_found'] > 0, "Should find files"
            assert results['credentials_parsed'] > 0, "Should parse credentials"

            logger.info("\n[PASS] Test 4 passed")

    except Exception as e:
        logger.error(f"Workflow failed: {e}")
        raise
    finally:
        db.close()


def test_markdown_parsing():
    """Test markdown credential parsing with various formats.

    Covers:
    - single- and double-hash headers;
    - ``User``/``Pass`` aliases;
    - underscore-style keys (``API_Key``, ``Connection_String``);
    - an explicit ``Type`` line.

    Asserts that at least three credentials are recognised.
    """
    _log_banner("TEST 5: Markdown Format Variations")

    with tempfile.TemporaryDirectory() as temp_dir:
        # Create file with various markdown formats
        test_file = Path(temp_dir) / "test_variations.md"
        test_file.write_text("""
# Single Hash Header
Username: user1
Password: pass1

## Double Hash Header
User: user2
Pass: pass2

## API Service
API_Key: sk-123456789
Type: api_key

## Database Connection
Connection_String: mysql://user:pass@host/db
""")

        credentials = parse_credential_file(str(test_file))

        logger.info(f"\nParsed {len(credentials)} credential(s):")
        for cred in credentials:
            logger.info(f"  - {cred.get('service_name')} ({cred.get('credential_type')})")

        assert len(credentials) >= 3, "Should parse multiple variations"
        logger.info("\n[PASS] Test 5 passed")


def main():
    """Run all tests in sequence.

    Stops at the first failure. On failure, logs a banner and the error
    with traceback, then re-raises so the process exits non-zero.
    """
    _log_banner("CREDENTIAL SCANNER TEST SUITE")

    try:
        # Run tests
        test_scan_for_credential_files()
        test_parse_credential_file()
        test_markdown_parsing()
        test_import_credentials_to_db()
        test_full_workflow()

        _log_banner("ALL TESTS PASSED!")

    except Exception as e:
        rule = "=" * _BANNER_WIDTH
        logger.error("\n" + rule)
        logger.error("TEST FAILED!")
        logger.error(rule)
        logger.error(f"Error: {e}", exc_info=True)
        raise


if __name__ == "__main__":
    main()