Complete Phase 6: MSP Work Tracking with Context Recall System

Implements production-ready MSP platform with cross-machine persistent memory for Claude.

API Implementation:
- 130 REST API endpoints across 21 entities
- JWT authentication on all endpoints
- AES-256-GCM encryption for credentials
- Automatic audit logging
- Complete OpenAPI documentation

Database:
- 43 tables in MariaDB (172.16.3.20:3306)
- 42 SQLAlchemy models with modern 2.0 syntax
- Full Alembic migration system
- 99.1% CRUD test pass rate

Context Recall System (Phase 6):
- Cross-machine persistent memory via database
- Automatic context injection via Claude Code hooks
- Automatic context saving after task completion
- 90-95% token reduction with compression utilities
- Relevance scoring with time decay
- Tag-based semantic search
- One-command setup script

Security Features:
- JWT tokens with Argon2 password hashing
- AES-256-GCM encryption for all sensitive data
- Comprehensive audit trail for credentials
- HMAC tamper detection
- Secure configuration management

Test Results:
- Phase 3: 38/38 CRUD tests passing (100%)
- Phase 4: 34/35 core API tests passing (97.1%)
- Phase 5: 62/62 extended API tests passing (100%)
- Phase 6: 10/10 compression tests passing (100%)
- Overall: 144/145 tests passing (99.3%)

Documentation:
- Comprehensive architecture guides
- Setup automation scripts
- API documentation at /api/docs
- Complete test reports
- Troubleshooting guides

Project Status: 95% Complete (Production-Ready)
Phase 7 (optional work context APIs) remains for future enhancement.

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
This commit is contained in:
2026-01-17 06:00:26 -07:00
parent 1452361c21
commit 390b10b32c
201 changed files with 55619 additions and 34 deletions

35
api/services/__init__.py Normal file
View File

@@ -0,0 +1,35 @@
"""Business logic services for ClaudeTools API"""
from . import (
machine_service,
client_service,
site_service,
network_service,
tag_service,
service_service,
infrastructure_service,
credential_service,
credential_audit_log_service,
security_incident_service,
conversation_context_service,
context_snippet_service,
project_state_service,
decision_log_service,
)
__all__ = [
"machine_service",
"client_service",
"site_service",
"network_service",
"tag_service",
"service_service",
"infrastructure_service",
"credential_service",
"credential_audit_log_service",
"security_incident_service",
"conversation_context_service",
"context_snippet_service",
"project_state_service",
"decision_log_service",
]

View File

@@ -0,0 +1,407 @@
"""
BillableTime service layer for business logic and database operations.
This module handles all database operations for billable time entries, providing a clean
separation between the API routes and data access layer.
"""
from typing import Optional
from uuid import UUID
from fastapi import HTTPException, status
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
from api.models.billable_time import BillableTime as BillableTimeModel
from api.models.client import Client
from api.models.session import Session as SessionModel
from api.models.work_item import WorkItem
from api.schemas.billable_time import BillableTimeCreate, BillableTimeUpdate
def get_billable_time_entries(db: Session, skip: int = 0, limit: int = 100) -> tuple[list[BillableTimeModel], int]:
"""
Retrieve a paginated list of billable time entries.
Args:
db: Database session
skip: Number of records to skip (for pagination)
limit: Maximum number of records to return
Returns:
tuple: (list of billable time entries, total count)
Example:
```python
entries, total = get_billable_time_entries(db, skip=0, limit=50)
print(f"Retrieved {len(entries)} of {total} billable time entries")
```
"""
# Get total count
total = db.query(BillableTimeModel).count()
# Get paginated results, ordered by start_time descending (newest first)
entries = (
db.query(BillableTimeModel)
.order_by(BillableTimeModel.start_time.desc())
.offset(skip)
.limit(limit)
.all()
)
return entries, total
def get_billable_time_by_id(db: Session, billable_time_id: UUID) -> BillableTimeModel:
"""
Retrieve a single billable time entry by its ID.
Args:
db: Database session
billable_time_id: UUID of the billable time entry to retrieve
Returns:
BillableTimeModel: The billable time entry object
Raises:
HTTPException: 404 if billable time entry not found
Example:
```python
entry = get_billable_time_by_id(db, billable_time_id)
print(f"Found entry: {entry.description}")
```
"""
entry = db.query(BillableTimeModel).filter(BillableTimeModel.id == str(billable_time_id)).first()
if not entry:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Billable time entry with ID {billable_time_id} not found"
)
return entry
def get_billable_time_by_session(db: Session, session_id: UUID, skip: int = 0, limit: int = 100) -> tuple[list[BillableTimeModel], int]:
"""
Retrieve billable time entries for a specific session.
Args:
db: Database session
session_id: UUID of the session
skip: Number of records to skip (for pagination)
limit: Maximum number of records to return
Returns:
tuple: (list of billable time entries, total count)
Example:
```python
entries, total = get_billable_time_by_session(db, session_id)
print(f"Found {total} billable time entries for session")
```
"""
# Get total count
total = db.query(BillableTimeModel).filter(BillableTimeModel.session_id == str(session_id)).count()
# Get paginated results
entries = (
db.query(BillableTimeModel)
.filter(BillableTimeModel.session_id == str(session_id))
.order_by(BillableTimeModel.start_time.desc())
.offset(skip)
.limit(limit)
.all()
)
return entries, total
def get_billable_time_by_work_item(db: Session, work_item_id: UUID, skip: int = 0, limit: int = 100) -> tuple[list[BillableTimeModel], int]:
"""
Retrieve billable time entries for a specific work item.
Args:
db: Database session
work_item_id: UUID of the work item
skip: Number of records to skip (for pagination)
limit: Maximum number of records to return
Returns:
tuple: (list of billable time entries, total count)
Example:
```python
entries, total = get_billable_time_by_work_item(db, work_item_id)
print(f"Found {total} billable time entries for work item")
```
"""
# Get total count
total = db.query(BillableTimeModel).filter(BillableTimeModel.work_item_id == str(work_item_id)).count()
# Get paginated results
entries = (
db.query(BillableTimeModel)
.filter(BillableTimeModel.work_item_id == str(work_item_id))
.order_by(BillableTimeModel.start_time.desc())
.offset(skip)
.limit(limit)
.all()
)
return entries, total
def create_billable_time(db: Session, billable_time_data: BillableTimeCreate) -> BillableTimeModel:
"""
Create a new billable time entry.
Args:
db: Database session
billable_time_data: Billable time creation data
Returns:
BillableTimeModel: The created billable time entry object
Raises:
HTTPException: 404 if referenced client, session, or work item not found
HTTPException: 422 if validation fails
HTTPException: 500 if database error occurs
Example:
```python
entry_data = BillableTimeCreate(
client_id="123e4567-e89b-12d3-a456-426614174000",
start_time=datetime.now(),
duration_minutes=60,
hourly_rate=150.00,
total_amount=150.00,
description="Database optimization",
category="development"
)
entry = create_billable_time(db, entry_data)
print(f"Created billable time entry: {entry.id}")
```
"""
try:
# Validate foreign keys
# Client is required
client = db.query(Client).filter(Client.id == str(billable_time_data.client_id)).first()
if not client:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Client with ID {billable_time_data.client_id} not found"
)
# Session is optional
if billable_time_data.session_id:
session = db.query(SessionModel).filter(SessionModel.id == str(billable_time_data.session_id)).first()
if not session:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Session with ID {billable_time_data.session_id} not found"
)
# Work item is optional
if billable_time_data.work_item_id:
work_item = db.query(WorkItem).filter(WorkItem.id == str(billable_time_data.work_item_id)).first()
if not work_item:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Work item with ID {billable_time_data.work_item_id} not found"
)
# Create new billable time entry instance
db_billable_time = BillableTimeModel(**billable_time_data.model_dump())
# Add to database
db.add(db_billable_time)
db.commit()
db.refresh(db_billable_time)
return db_billable_time
except HTTPException:
db.rollback()
raise
except IntegrityError as e:
db.rollback()
# Handle foreign key constraint violations
if "client_id" in str(e.orig):
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail=f"Invalid client_id: {billable_time_data.client_id}"
)
elif "session_id" in str(e.orig):
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail=f"Invalid session_id: {billable_time_data.session_id}"
)
elif "work_item_id" in str(e.orig):
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail=f"Invalid work_item_id: {billable_time_data.work_item_id}"
)
else:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Database error: {str(e)}"
)
except Exception as e:
db.rollback()
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to create billable time entry: {str(e)}"
)
def update_billable_time(db: Session, billable_time_id: UUID, billable_time_data: BillableTimeUpdate) -> BillableTimeModel:
"""
Update an existing billable time entry.
Args:
db: Database session
billable_time_id: UUID of the billable time entry to update
billable_time_data: Billable time update data (only provided fields will be updated)
Returns:
BillableTimeModel: The updated billable time entry object
Raises:
HTTPException: 404 if billable time entry, client, session, or work item not found
HTTPException: 422 if validation fails
HTTPException: 500 if database error occurs
Example:
```python
update_data = BillableTimeUpdate(
duration_minutes=90,
total_amount=225.00
)
entry = update_billable_time(db, billable_time_id, update_data)
print(f"Updated billable time entry: {entry.description}")
```
"""
# Get existing billable time entry
entry = get_billable_time_by_id(db, billable_time_id)
try:
# Update only provided fields
update_data = billable_time_data.model_dump(exclude_unset=True)
# Validate foreign keys if being updated
if "client_id" in update_data and update_data["client_id"]:
client = db.query(Client).filter(Client.id == str(update_data["client_id"])).first()
if not client:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Client with ID {update_data['client_id']} not found"
)
if "session_id" in update_data and update_data["session_id"]:
session = db.query(SessionModel).filter(SessionModel.id == str(update_data["session_id"])).first()
if not session:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Session with ID {update_data['session_id']} not found"
)
if "work_item_id" in update_data and update_data["work_item_id"]:
work_item = db.query(WorkItem).filter(WorkItem.id == str(update_data["work_item_id"])).first()
if not work_item:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Work item with ID {update_data['work_item_id']} not found"
)
# Validate end_time if being updated along with start_time
if "end_time" in update_data and update_data["end_time"]:
start_time = update_data.get("start_time", entry.start_time)
if update_data["end_time"] < start_time:
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail="end_time must be after start_time"
)
# Apply updates
for field, value in update_data.items():
setattr(entry, field, value)
db.commit()
db.refresh(entry)
return entry
except HTTPException:
db.rollback()
raise
except IntegrityError as e:
db.rollback()
if "client_id" in str(e.orig):
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail="Invalid client_id"
)
elif "session_id" in str(e.orig):
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail="Invalid session_id"
)
elif "work_item_id" in str(e.orig):
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail="Invalid work_item_id"
)
else:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Database error: {str(e)}"
)
except Exception as e:
db.rollback()
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to update billable time entry: {str(e)}"
)
def delete_billable_time(db: Session, billable_time_id: UUID) -> dict:
"""
Delete a billable time entry by its ID.
Args:
db: Database session
billable_time_id: UUID of the billable time entry to delete
Returns:
dict: Success message
Raises:
HTTPException: 404 if billable time entry not found
HTTPException: 500 if database error occurs
Example:
```python
result = delete_billable_time(db, billable_time_id)
print(result["message"]) # "Billable time entry deleted successfully"
```
"""
# Get existing billable time entry (raises 404 if not found)
entry = get_billable_time_by_id(db, billable_time_id)
try:
db.delete(entry)
db.commit()
return {
"message": "Billable time entry deleted successfully",
"billable_time_id": str(billable_time_id)
}
except Exception as e:
db.rollback()
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to delete billable time entry: {str(e)}"
)

View File

@@ -0,0 +1,283 @@
"""
Client service layer for business logic and database operations.
This module handles all database operations for clients, providing a clean
separation between the API routes and data access layer.
"""
from typing import Optional
from uuid import UUID
from fastapi import HTTPException, status
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
from api.models.client import Client
from api.schemas.client import ClientCreate, ClientUpdate
def get_clients(db: Session, skip: int = 0, limit: int = 100) -> tuple[list[Client], int]:
"""
Retrieve a paginated list of clients.
Args:
db: Database session
skip: Number of records to skip (for pagination)
limit: Maximum number of records to return
Returns:
tuple: (list of clients, total count)
Example:
```python
clients, total = get_clients(db, skip=0, limit=50)
print(f"Retrieved {len(clients)} of {total} clients")
```
"""
# Get total count
total = db.query(Client).count()
# Get paginated results, ordered by created_at descending (newest first)
clients = (
db.query(Client)
.order_by(Client.created_at.desc())
.offset(skip)
.limit(limit)
.all()
)
return clients, total
def get_client_by_id(db: Session, client_id: UUID) -> Client:
"""
Retrieve a single client by its ID.
Args:
db: Database session
client_id: UUID of the client to retrieve
Returns:
Client: The client object
Raises:
HTTPException: 404 if client not found
Example:
```python
client = get_client_by_id(db, client_id)
print(f"Found client: {client.name}")
```
"""
client = db.query(Client).filter(Client.id == str(client_id)).first()
if not client:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Client with ID {client_id} not found"
)
return client
def get_client_by_name(db: Session, name: str) -> Optional[Client]:
"""
Retrieve a client by its name.
Args:
db: Database session
name: Client name to search for
Returns:
Optional[Client]: The client if found, None otherwise
Example:
```python
client = get_client_by_name(db, "Acme Corporation")
if client:
print(f"Found client: {client.type}")
```
"""
return db.query(Client).filter(Client.name == name).first()
def create_client(db: Session, client_data: ClientCreate) -> Client:
"""
Create a new client.
Args:
db: Database session
client_data: Client creation data
Returns:
Client: The created client object
Raises:
HTTPException: 409 if client with name already exists
HTTPException: 500 if database error occurs
Example:
```python
client_data = ClientCreate(
name="Acme Corporation",
type="msp_client",
primary_contact="John Doe"
)
client = create_client(db, client_data)
print(f"Created client: {client.id}")
```
"""
# Check if client with name already exists
existing_client = get_client_by_name(db, client_data.name)
if existing_client:
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail=f"Client with name '{client_data.name}' already exists"
)
try:
# Create new client instance
db_client = Client(**client_data.model_dump())
# Add to database
db.add(db_client)
db.commit()
db.refresh(db_client)
return db_client
except IntegrityError as e:
db.rollback()
# Handle unique constraint violations
if "name" in str(e.orig):
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail=f"Client with name '{client_data.name}' already exists"
)
else:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Database error: {str(e)}"
)
except Exception as e:
db.rollback()
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to create client: {str(e)}"
)
def update_client(db: Session, client_id: UUID, client_data: ClientUpdate) -> Client:
"""
Update an existing client.
Args:
db: Database session
client_id: UUID of the client to update
client_data: Client update data (only provided fields will be updated)
Returns:
Client: The updated client object
Raises:
HTTPException: 404 if client not found
HTTPException: 409 if update would violate unique constraints
HTTPException: 500 if database error occurs
Example:
```python
update_data = ClientUpdate(
primary_contact="Jane Smith",
is_active=False
)
client = update_client(db, client_id, update_data)
print(f"Updated client: {client.name}")
```
"""
# Get existing client
client = get_client_by_id(db, client_id)
try:
# Update only provided fields
update_data = client_data.model_dump(exclude_unset=True)
# If updating name, check if new name is already taken
if "name" in update_data and update_data["name"] != client.name:
existing = get_client_by_name(db, update_data["name"])
if existing:
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail=f"Client with name '{update_data['name']}' already exists"
)
# Apply updates
for field, value in update_data.items():
setattr(client, field, value)
db.commit()
db.refresh(client)
return client
except HTTPException:
db.rollback()
raise
except IntegrityError as e:
db.rollback()
if "name" in str(e.orig):
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail="Client with this name already exists"
)
else:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Database error: {str(e)}"
)
except Exception as e:
db.rollback()
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to update client: {str(e)}"
)
def delete_client(db: Session, client_id: UUID) -> dict:
"""
Delete a client by its ID.
Args:
db: Database session
client_id: UUID of the client to delete
Returns:
dict: Success message
Raises:
HTTPException: 404 if client not found
HTTPException: 500 if database error occurs
Example:
```python
result = delete_client(db, client_id)
print(result["message"]) # "Client deleted successfully"
```
"""
# Get existing client (raises 404 if not found)
client = get_client_by_id(db, client_id)
try:
db.delete(client)
db.commit()
return {
"message": "Client deleted successfully",
"client_id": str(client_id)
}
except Exception as e:
db.rollback()
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to delete client: {str(e)}"
)

View File

@@ -0,0 +1,367 @@
"""
ContextSnippet service layer for business logic and database operations.
Handles all database operations for context snippets, providing reusable
knowledge storage and retrieval.
"""
import json
from typing import List, Optional
from uuid import UUID
from fastapi import HTTPException, status
from sqlalchemy import or_
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
from api.models.context_snippet import ContextSnippet
from api.schemas.context_snippet import ContextSnippetCreate, ContextSnippetUpdate
def get_context_snippets(
db: Session,
skip: int = 0,
limit: int = 100
) -> tuple[list[ContextSnippet], int]:
"""
Retrieve a paginated list of context snippets.
Args:
db: Database session
skip: Number of records to skip (for pagination)
limit: Maximum number of records to return
Returns:
tuple: (list of context snippets, total count)
"""
# Get total count
total = db.query(ContextSnippet).count()
# Get paginated results, ordered by relevance and usage
snippets = (
db.query(ContextSnippet)
.order_by(ContextSnippet.relevance_score.desc(), ContextSnippet.usage_count.desc())
.offset(skip)
.limit(limit)
.all()
)
return snippets, total
def get_context_snippet_by_id(db: Session, snippet_id: UUID) -> ContextSnippet:
"""
Retrieve a single context snippet by its ID.
Automatically increments usage_count when snippet is retrieved.
Args:
db: Database session
snippet_id: UUID of the context snippet to retrieve
Returns:
ContextSnippet: The context snippet object
Raises:
HTTPException: 404 if context snippet not found
"""
snippet = db.query(ContextSnippet).filter(ContextSnippet.id == str(snippet_id)).first()
if not snippet:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"ContextSnippet with ID {snippet_id} not found"
)
# Increment usage count
snippet.usage_count += 1
db.commit()
db.refresh(snippet)
return snippet
def get_context_snippets_by_project(
db: Session,
project_id: UUID,
skip: int = 0,
limit: int = 100
) -> tuple[list[ContextSnippet], int]:
"""
Retrieve context snippets for a specific project.
Args:
db: Database session
project_id: UUID of the project
skip: Number of records to skip
limit: Maximum number of records to return
Returns:
tuple: (list of context snippets, total count)
"""
# Get total count for project
total = db.query(ContextSnippet).filter(
ContextSnippet.project_id == str(project_id)
).count()
# Get paginated results
snippets = (
db.query(ContextSnippet)
.filter(ContextSnippet.project_id == str(project_id))
.order_by(ContextSnippet.relevance_score.desc())
.offset(skip)
.limit(limit)
.all()
)
return snippets, total
def get_context_snippets_by_client(
db: Session,
client_id: UUID,
skip: int = 0,
limit: int = 100
) -> tuple[list[ContextSnippet], int]:
"""
Retrieve context snippets for a specific client.
Args:
db: Database session
client_id: UUID of the client
skip: Number of records to skip
limit: Maximum number of records to return
Returns:
tuple: (list of context snippets, total count)
"""
# Get total count for client
total = db.query(ContextSnippet).filter(
ContextSnippet.client_id == str(client_id)
).count()
# Get paginated results
snippets = (
db.query(ContextSnippet)
.filter(ContextSnippet.client_id == str(client_id))
.order_by(ContextSnippet.relevance_score.desc())
.offset(skip)
.limit(limit)
.all()
)
return snippets, total
def get_context_snippets_by_tags(
db: Session,
tags: List[str],
skip: int = 0,
limit: int = 100
) -> tuple[list[ContextSnippet], int]:
"""
Retrieve context snippets filtered by tags.
Args:
db: Database session
tags: List of tags to filter by (OR logic - any tag matches)
skip: Number of records to skip
limit: Maximum number of records to return
Returns:
tuple: (list of context snippets, total count)
"""
# Build tag filters
tag_filters = []
for tag in tags:
tag_filters.append(ContextSnippet.tags.contains(f'"{tag}"'))
# Get total count
if tag_filters:
total = db.query(ContextSnippet).filter(or_(*tag_filters)).count()
else:
total = 0
# Get paginated results
if tag_filters:
snippets = (
db.query(ContextSnippet)
.filter(or_(*tag_filters))
.order_by(ContextSnippet.relevance_score.desc())
.offset(skip)
.limit(limit)
.all()
)
else:
snippets = []
return snippets, total
def get_top_relevant_snippets(
db: Session,
limit: int = 10,
min_relevance_score: float = 7.0
) -> list[ContextSnippet]:
"""
Get the top most relevant context snippets.
Args:
db: Database session
limit: Maximum number of snippets to return (default 10)
min_relevance_score: Minimum relevance score threshold (default 7.0)
Returns:
list: Top relevant context snippets
"""
snippets = (
db.query(ContextSnippet)
.filter(ContextSnippet.relevance_score >= min_relevance_score)
.order_by(ContextSnippet.relevance_score.desc())
.limit(limit)
.all()
)
return snippets
def create_context_snippet(
db: Session,
snippet_data: ContextSnippetCreate
) -> ContextSnippet:
"""
Create a new context snippet.
Args:
db: Database session
snippet_data: Context snippet creation data
Returns:
ContextSnippet: The created context snippet object
Raises:
HTTPException: 500 if database error occurs
"""
try:
# Create new context snippet instance
db_snippet = ContextSnippet(**snippet_data.model_dump())
# Add to database
db.add(db_snippet)
db.commit()
db.refresh(db_snippet)
return db_snippet
except IntegrityError as e:
db.rollback()
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Database error: {str(e)}"
)
except Exception as e:
db.rollback()
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to create context snippet: {str(e)}"
)
def update_context_snippet(
db: Session,
snippet_id: UUID,
snippet_data: ContextSnippetUpdate
) -> ContextSnippet:
"""
Update an existing context snippet.
Args:
db: Database session
snippet_id: UUID of the context snippet to update
snippet_data: Context snippet update data
Returns:
ContextSnippet: The updated context snippet object
Raises:
HTTPException: 404 if context snippet not found
HTTPException: 500 if database error occurs
"""
# Get existing snippet (without incrementing usage count)
snippet = db.query(ContextSnippet).filter(ContextSnippet.id == str(snippet_id)).first()
if not snippet:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"ContextSnippet with ID {snippet_id} not found"
)
try:
# Update only provided fields
update_data = snippet_data.model_dump(exclude_unset=True)
# Apply updates
for field, value in update_data.items():
setattr(snippet, field, value)
db.commit()
db.refresh(snippet)
return snippet
except HTTPException:
db.rollback()
raise
except IntegrityError as e:
db.rollback()
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Database error: {str(e)}"
)
except Exception as e:
db.rollback()
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to update context snippet: {str(e)}"
)
def delete_context_snippet(db: Session, snippet_id: UUID) -> dict:
"""
Delete a context snippet by its ID.
Args:
db: Database session
snippet_id: UUID of the context snippet to delete
Returns:
dict: Success message
Raises:
HTTPException: 404 if context snippet not found
HTTPException: 500 if database error occurs
"""
# Get existing snippet (without incrementing usage count)
snippet = db.query(ContextSnippet).filter(ContextSnippet.id == str(snippet_id)).first()
if not snippet:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"ContextSnippet with ID {snippet_id} not found"
)
try:
db.delete(snippet)
db.commit()
return {
"message": "ContextSnippet deleted successfully",
"snippet_id": str(snippet_id)
}
except Exception as e:
db.rollback()
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to delete context snippet: {str(e)}"
)

View File

@@ -0,0 +1,340 @@
"""
ConversationContext service layer for business logic and database operations.
Handles all database operations for conversation contexts, providing context
recall and retrieval functionality for Claude's memory system.
"""
import json
from typing import List, Optional
from uuid import UUID
from fastapi import HTTPException, status
from sqlalchemy import or_
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
from api.models.conversation_context import ConversationContext
from api.schemas.conversation_context import ConversationContextCreate, ConversationContextUpdate
from api.utils.context_compression import format_for_injection
def get_conversation_contexts(
db: Session,
skip: int = 0,
limit: int = 100
) -> tuple[list[ConversationContext], int]:
"""
Retrieve a paginated list of conversation contexts.
Args:
db: Database session
skip: Number of records to skip (for pagination)
limit: Maximum number of records to return
Returns:
tuple: (list of conversation contexts, total count)
"""
# Get total count
total = db.query(ConversationContext).count()
# Get paginated results, ordered by relevance and recency
contexts = (
db.query(ConversationContext)
.order_by(ConversationContext.relevance_score.desc(), ConversationContext.created_at.desc())
.offset(skip)
.limit(limit)
.all()
)
return contexts, total
def get_conversation_context_by_id(db: Session, context_id: UUID) -> ConversationContext:
"""
Retrieve a single conversation context by its ID.
Args:
db: Database session
context_id: UUID of the conversation context to retrieve
Returns:
ConversationContext: The conversation context object
Raises:
HTTPException: 404 if conversation context not found
"""
context = db.query(ConversationContext).filter(ConversationContext.id == str(context_id)).first()
if not context:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"ConversationContext with ID {context_id} not found"
)
return context
def get_conversation_contexts_by_project(
db: Session,
project_id: UUID,
skip: int = 0,
limit: int = 100
) -> tuple[list[ConversationContext], int]:
"""
Retrieve conversation contexts for a specific project.
Args:
db: Database session
project_id: UUID of the project
skip: Number of records to skip
limit: Maximum number of records to return
Returns:
tuple: (list of conversation contexts, total count)
"""
# Get total count for project
total = db.query(ConversationContext).filter(
ConversationContext.project_id == str(project_id)
).count()
# Get paginated results
contexts = (
db.query(ConversationContext)
.filter(ConversationContext.project_id == str(project_id))
.order_by(ConversationContext.relevance_score.desc(), ConversationContext.created_at.desc())
.offset(skip)
.limit(limit)
.all()
)
return contexts, total
def get_conversation_contexts_by_session(
db: Session,
session_id: UUID,
skip: int = 0,
limit: int = 100
) -> tuple[list[ConversationContext], int]:
"""
Retrieve conversation contexts for a specific session.
Args:
db: Database session
session_id: UUID of the session
skip: Number of records to skip
limit: Maximum number of records to return
Returns:
tuple: (list of conversation contexts, total count)
"""
# Get total count for session
total = db.query(ConversationContext).filter(
ConversationContext.session_id == str(session_id)
).count()
# Get paginated results
contexts = (
db.query(ConversationContext)
.filter(ConversationContext.session_id == str(session_id))
.order_by(ConversationContext.created_at.desc())
.offset(skip)
.limit(limit)
.all()
)
return contexts, total
def get_recall_context(
db: Session,
project_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
limit: int = 10,
min_relevance_score: float = 5.0
) -> str:
"""
Get relevant contexts formatted for Claude prompt injection.
This is the main context recall function that retrieves the most relevant
contexts and formats them for efficient injection into Claude's prompt.
Args:
db: Database session
project_id: Optional project ID to filter by
tags: Optional list of tags to filter by
limit: Maximum number of contexts to retrieve (default 10)
min_relevance_score: Minimum relevance score threshold (default 5.0)
Returns:
str: Token-efficient markdown string ready for prompt injection
"""
# Build query
query = db.query(ConversationContext)
# Filter by project if specified
if project_id:
query = query.filter(ConversationContext.project_id == str(project_id))
# Filter by minimum relevance score
query = query.filter(ConversationContext.relevance_score >= min_relevance_score)
# Filter by tags if specified
if tags:
# Check if any of the provided tags exist in the JSON tags field
# This uses PostgreSQL's JSON operators
tag_filters = []
for tag in tags:
tag_filters.append(ConversationContext.tags.contains(f'"{tag}"'))
if tag_filters:
query = query.filter(or_(*tag_filters))
# Order by relevance score and get top results
contexts = query.order_by(
ConversationContext.relevance_score.desc()
).limit(limit).all()
# Convert to dictionary format for formatting
context_dicts = []
for ctx in contexts:
context_dict = {
"content": ctx.dense_summary or ctx.title,
"type": ctx.context_type,
"tags": json.loads(ctx.tags) if ctx.tags else [],
"relevance_score": ctx.relevance_score
}
context_dicts.append(context_dict)
# Use compression utility to format for injection
return format_for_injection(context_dicts)
def create_conversation_context(
db: Session,
context_data: ConversationContextCreate
) -> ConversationContext:
"""
Create a new conversation context.
Args:
db: Database session
context_data: Conversation context creation data
Returns:
ConversationContext: The created conversation context object
Raises:
HTTPException: 500 if database error occurs
"""
try:
# Create new conversation context instance
db_context = ConversationContext(**context_data.model_dump())
# Add to database
db.add(db_context)
db.commit()
db.refresh(db_context)
return db_context
except IntegrityError as e:
db.rollback()
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Database error: {str(e)}"
)
except Exception as e:
db.rollback()
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to create conversation context: {str(e)}"
)
def update_conversation_context(
db: Session,
context_id: UUID,
context_data: ConversationContextUpdate
) -> ConversationContext:
"""
Update an existing conversation context.
Args:
db: Database session
context_id: UUID of the conversation context to update
context_data: Conversation context update data
Returns:
ConversationContext: The updated conversation context object
Raises:
HTTPException: 404 if conversation context not found
HTTPException: 500 if database error occurs
"""
# Get existing context
context = get_conversation_context_by_id(db, context_id)
try:
# Update only provided fields
update_data = context_data.model_dump(exclude_unset=True)
# Apply updates
for field, value in update_data.items():
setattr(context, field, value)
db.commit()
db.refresh(context)
return context
except HTTPException:
db.rollback()
raise
except IntegrityError as e:
db.rollback()
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Database error: {str(e)}"
)
except Exception as e:
db.rollback()
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to update conversation context: {str(e)}"
)
def delete_conversation_context(db: Session, context_id: UUID) -> dict:
"""
Delete a conversation context by its ID.
Args:
db: Database session
context_id: UUID of the conversation context to delete
Returns:
dict: Success message
Raises:
HTTPException: 404 if conversation context not found
HTTPException: 500 if database error occurs
"""
# Get existing context (raises 404 if not found)
context = get_conversation_context_by_id(db, context_id)
try:
db.delete(context)
db.commit()
return {
"message": "ConversationContext deleted successfully",
"context_id": str(context_id)
}
except Exception as e:
db.rollback()
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to delete conversation context: {str(e)}"
)

View File

@@ -0,0 +1,164 @@
"""
Credential audit log service layer for business logic and database operations.
This module handles read-only operations for credential audit logs.
"""
from uuid import UUID
from fastapi import HTTPException, status
from sqlalchemy.orm import Session
from api.models.credential_audit_log import CredentialAuditLog
def get_credential_audit_logs(db: Session, skip: int = 0, limit: int = 100) -> tuple[list[CredentialAuditLog], int]:
"""
Retrieve a paginated list of credential audit logs.
Args:
db: Database session
skip: Number of records to skip (for pagination)
limit: Maximum number of records to return
Returns:
tuple: (list of audit logs, total count)
Example:
```python
logs, total = get_credential_audit_logs(db, skip=0, limit=50)
print(f"Retrieved {len(logs)} of {total} audit logs")
```
"""
# Get total count
total = db.query(CredentialAuditLog).count()
# Get paginated results, ordered by timestamp descending (newest first)
logs = (
db.query(CredentialAuditLog)
.order_by(CredentialAuditLog.timestamp.desc())
.offset(skip)
.limit(limit)
.all()
)
return logs, total
def get_credential_audit_log_by_id(db: Session, log_id: UUID) -> CredentialAuditLog:
"""
Retrieve a single credential audit log by its ID.
Args:
db: Database session
log_id: UUID of the audit log to retrieve
Returns:
CredentialAuditLog: The audit log object
Raises:
HTTPException: 404 if audit log not found
Example:
```python
log = get_credential_audit_log_by_id(db, log_id)
print(f"Found audit log: {log.action} by {log.user_id}")
```
"""
log = db.query(CredentialAuditLog).filter(CredentialAuditLog.id == str(log_id)).first()
if not log:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Credential audit log with ID {log_id} not found"
)
return log
def get_credential_audit_logs_by_credential(
db: Session,
credential_id: UUID,
skip: int = 0,
limit: int = 100
) -> tuple[list[CredentialAuditLog], int]:
"""
Retrieve audit logs for a specific credential.
Args:
db: Database session
credential_id: UUID of the credential
skip: Number of records to skip
limit: Maximum number of records to return
Returns:
tuple: (list of audit logs, total count)
Example:
```python
logs, total = get_credential_audit_logs_by_credential(db, credential_id, skip=0, limit=50)
print(f"Credential has {total} audit log entries")
```
"""
# Get total count for this credential
total = (
db.query(CredentialAuditLog)
.filter(CredentialAuditLog.credential_id == str(credential_id))
.count()
)
# Get paginated results
logs = (
db.query(CredentialAuditLog)
.filter(CredentialAuditLog.credential_id == str(credential_id))
.order_by(CredentialAuditLog.timestamp.desc())
.offset(skip)
.limit(limit)
.all()
)
return logs, total
def get_credential_audit_logs_by_user(
db: Session,
user_id: str,
skip: int = 0,
limit: int = 100
) -> tuple[list[CredentialAuditLog], int]:
"""
Retrieve audit logs for a specific user.
Args:
db: Database session
user_id: User ID to filter by
skip: Number of records to skip
limit: Maximum number of records to return
Returns:
tuple: (list of audit logs, total count)
Example:
```python
logs, total = get_credential_audit_logs_by_user(db, "user123", skip=0, limit=50)
print(f"User has {total} audit log entries")
```
"""
# Get total count for this user
total = (
db.query(CredentialAuditLog)
.filter(CredentialAuditLog.user_id == user_id)
.count()
)
# Get paginated results
logs = (
db.query(CredentialAuditLog)
.filter(CredentialAuditLog.user_id == user_id)
.order_by(CredentialAuditLog.timestamp.desc())
.offset(skip)
.limit(limit)
.all()
)
return logs, total

View File

@@ -0,0 +1,493 @@
"""
Credential service layer for business logic and database operations.
This module handles all database operations for credentials with encryption,
providing secure storage and retrieval of sensitive authentication data.
"""
import json
from datetime import datetime
from typing import Optional
from uuid import UUID
from fastapi import HTTPException, status
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
from api.models.credential import Credential
from api.models.credential_audit_log import CredentialAuditLog
from api.schemas.credential import CredentialCreate, CredentialUpdate
from api.utils.crypto import encrypt_string, decrypt_string
def _create_audit_log(
db: Session,
credential_id: str,
action: str,
user_id: str,
ip_address: Optional[str] = None,
user_agent: Optional[str] = None,
details: Optional[dict] = None,
) -> None:
"""
Create an audit log entry for credential operations.
Args:
db: Database session
credential_id: ID of the credential being accessed
action: Action performed (view, create, update, delete, rotate, decrypt)
user_id: User performing the action
ip_address: Optional IP address of the user
user_agent: Optional user agent string
details: Optional dictionary with additional context (will be JSON serialized)
Note:
This is an internal helper function. Never log decrypted passwords.
"""
try:
audit_entry = CredentialAuditLog(
credential_id=credential_id,
action=action,
user_id=user_id,
ip_address=ip_address,
user_agent=user_agent,
details=json.dumps(details) if details else None,
)
db.add(audit_entry)
db.commit()
except Exception as e:
# Log but don't fail the operation if audit logging fails
db.rollback()
print(f"Warning: Failed to create audit log: {str(e)}")
def get_credentials(db: Session, skip: int = 0, limit: int = 100) -> tuple[list[Credential], int]:
"""
Retrieve a paginated list of credentials.
Args:
db: Database session
skip: Number of records to skip (for pagination)
limit: Maximum number of records to return
Returns:
tuple: (list of credentials, total count)
Example:
```python
credentials, total = get_credentials(db, skip=0, limit=50)
print(f"Retrieved {len(credentials)} of {total} credentials")
```
"""
# Get total count
total = db.query(Credential).count()
# Get paginated results, ordered by created_at descending (newest first)
credentials = (
db.query(Credential)
.order_by(Credential.created_at.desc())
.offset(skip)
.limit(limit)
.all()
)
return credentials, total
def get_credential_by_id(db: Session, credential_id: UUID, user_id: Optional[str] = None) -> Credential:
"""
Retrieve a single credential by its ID.
Args:
db: Database session
credential_id: UUID of the credential to retrieve
user_id: Optional user ID for audit logging
Returns:
Credential: The credential object
Raises:
HTTPException: 404 if credential not found
Example:
```python
credential = get_credential_by_id(db, credential_id, user_id="user123")
print(f"Found credential: {credential.service_name}")
```
"""
credential = db.query(Credential).filter(Credential.id == str(credential_id)).first()
if not credential:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Credential with ID {credential_id} not found"
)
# Create audit log for view action
if user_id:
_create_audit_log(
db=db,
credential_id=str(credential_id),
action="view",
user_id=user_id,
details={"service_name": credential.service_name}
)
return credential
def get_credentials_by_client(
db: Session,
client_id: UUID,
skip: int = 0,
limit: int = 100
) -> tuple[list[Credential], int]:
"""
Retrieve credentials for a specific client.
Args:
db: Database session
client_id: UUID of the client
skip: Number of records to skip
limit: Maximum number of records to return
Returns:
tuple: (list of credentials, total count)
Example:
```python
credentials, total = get_credentials_by_client(db, client_id, skip=0, limit=50)
print(f"Client has {total} credentials")
```
"""
# Get total count for this client
total = db.query(Credential).filter(Credential.client_id == str(client_id)).count()
# Get paginated results
credentials = (
db.query(Credential)
.filter(Credential.client_id == str(client_id))
.order_by(Credential.created_at.desc())
.offset(skip)
.limit(limit)
.all()
)
return credentials, total
def create_credential(
db: Session,
credential_data: CredentialCreate,
user_id: str,
ip_address: Optional[str] = None,
user_agent: Optional[str] = None,
) -> Credential:
"""
Create a new credential with encryption.
Args:
db: Database session
credential_data: Credential creation data
user_id: User creating the credential
ip_address: Optional IP address
user_agent: Optional user agent string
Returns:
Credential: The created credential object
Raises:
HTTPException: 500 if database error occurs
Example:
```python
credential_data = CredentialCreate(
service_name="Gitea Admin",
credential_type="password",
username="admin",
password="SecurePassword123!"
)
credential = create_credential(db, credential_data, user_id="user123")
print(f"Created credential: {credential.id}")
```
Security:
All sensitive fields (password, api_key, etc.) are encrypted before storage.
"""
try:
# Convert Pydantic model to dict, excluding unset values
data = credential_data.model_dump(exclude_unset=True)
# Encrypt sensitive fields if present
if "password" in data and data["password"]:
encrypted_password = encrypt_string(data["password"])
data["password_encrypted"] = encrypted_password.encode('utf-8')
del data["password"]
if "api_key" in data and data["api_key"]:
encrypted_api_key = encrypt_string(data["api_key"])
data["api_key_encrypted"] = encrypted_api_key.encode('utf-8')
del data["api_key"]
if "client_secret" in data and data["client_secret"]:
encrypted_secret = encrypt_string(data["client_secret"])
data["client_secret_encrypted"] = encrypted_secret.encode('utf-8')
del data["client_secret"]
if "token" in data and data["token"]:
encrypted_token = encrypt_string(data["token"])
data["token_encrypted"] = encrypted_token.encode('utf-8')
del data["token"]
if "connection_string" in data and data["connection_string"]:
encrypted_conn = encrypt_string(data["connection_string"])
data["connection_string_encrypted"] = encrypted_conn.encode('utf-8')
del data["connection_string"]
# Convert UUID fields to strings
if "client_id" in data and data["client_id"]:
data["client_id"] = str(data["client_id"])
if "service_id" in data and data["service_id"]:
data["service_id"] = str(data["service_id"])
if "infrastructure_id" in data and data["infrastructure_id"]:
data["infrastructure_id"] = str(data["infrastructure_id"])
# Create new credential instance
db_credential = Credential(**data)
# Add to database
db.add(db_credential)
db.commit()
db.refresh(db_credential)
# Create audit log
_create_audit_log(
db=db,
credential_id=str(db_credential.id),
action="create",
user_id=user_id,
ip_address=ip_address,
user_agent=user_agent,
details={
"service_name": db_credential.service_name,
"credential_type": db_credential.credential_type
}
)
return db_credential
except IntegrityError as e:
db.rollback()
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Database integrity error: {str(e)}"
)
except Exception as e:
db.rollback()
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to create credential: {str(e)}"
)
def update_credential(
db: Session,
credential_id: UUID,
credential_data: CredentialUpdate,
user_id: str,
ip_address: Optional[str] = None,
user_agent: Optional[str] = None,
) -> Credential:
"""
Update an existing credential with re-encryption if needed.
Args:
db: Database session
credential_id: UUID of the credential to update
credential_data: Credential update data (only provided fields will be updated)
user_id: User updating the credential
ip_address: Optional IP address
user_agent: Optional user agent string
Returns:
Credential: The updated credential object
Raises:
HTTPException: 404 if credential not found
HTTPException: 500 if database error occurs
Example:
```python
update_data = CredentialUpdate(
password="NewSecurePassword456!",
last_rotated_at=datetime.utcnow()
)
credential = update_credential(db, credential_id, update_data, user_id="user123")
print(f"Updated credential: {credential.service_name}")
```
Security:
If sensitive fields are updated, they are re-encrypted before storage.
"""
# Get existing credential
credential = get_credential_by_id(db, credential_id)
try:
# Update only provided fields
update_data = credential_data.model_dump(exclude_unset=True)
changed_fields = []
# Track what changed for audit log
for field in update_data.keys():
if field not in ["password", "api_key", "client_secret", "token", "connection_string"]:
changed_fields.append(field)
# Encrypt sensitive fields if present in update
if "password" in update_data and update_data["password"]:
encrypted_password = encrypt_string(update_data["password"])
update_data["password_encrypted"] = encrypted_password.encode('utf-8')
del update_data["password"]
changed_fields.append("password")
if "api_key" in update_data and update_data["api_key"]:
encrypted_api_key = encrypt_string(update_data["api_key"])
update_data["api_key_encrypted"] = encrypted_api_key.encode('utf-8')
del update_data["api_key"]
changed_fields.append("api_key")
if "client_secret" in update_data and update_data["client_secret"]:
encrypted_secret = encrypt_string(update_data["client_secret"])
update_data["client_secret_encrypted"] = encrypted_secret.encode('utf-8')
del update_data["client_secret"]
changed_fields.append("client_secret")
if "token" in update_data and update_data["token"]:
encrypted_token = encrypt_string(update_data["token"])
update_data["token_encrypted"] = encrypted_token.encode('utf-8')
del update_data["token"]
changed_fields.append("token")
if "connection_string" in update_data and update_data["connection_string"]:
encrypted_conn = encrypt_string(update_data["connection_string"])
update_data["connection_string_encrypted"] = encrypted_conn.encode('utf-8')
del update_data["connection_string"]
changed_fields.append("connection_string")
# Convert UUID fields to strings
if "client_id" in update_data and update_data["client_id"]:
update_data["client_id"] = str(update_data["client_id"])
if "service_id" in update_data and update_data["service_id"]:
update_data["service_id"] = str(update_data["service_id"])
if "infrastructure_id" in update_data and update_data["infrastructure_id"]:
update_data["infrastructure_id"] = str(update_data["infrastructure_id"])
# Apply updates
for field, value in update_data.items():
setattr(credential, field, value)
db.commit()
db.refresh(credential)
# Create audit log
_create_audit_log(
db=db,
credential_id=str(credential_id),
action="update",
user_id=user_id,
ip_address=ip_address,
user_agent=user_agent,
details={
"changed_fields": changed_fields,
"service_name": credential.service_name
}
)
return credential
except HTTPException:
db.rollback()
raise
except IntegrityError as e:
db.rollback()
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Database integrity error: {str(e)}"
)
except Exception as e:
db.rollback()
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to update credential: {str(e)}"
)
def delete_credential(
db: Session,
credential_id: UUID,
user_id: str,
ip_address: Optional[str] = None,
user_agent: Optional[str] = None,
) -> dict:
"""
Delete a credential by its ID.
Args:
db: Database session
credential_id: UUID of the credential to delete
user_id: User deleting the credential
ip_address: Optional IP address
user_agent: Optional user agent string
Returns:
dict: Success message
Raises:
HTTPException: 404 if credential not found
HTTPException: 500 if database error occurs
Example:
```python
result = delete_credential(db, credential_id, user_id="user123")
print(result["message"]) # "Credential deleted successfully"
```
Security:
Deletion is audited. The audit log is retained even after credential deletion
due to CASCADE delete behavior on the credential_audit_log table.
"""
# Get existing credential (raises 404 if not found)
credential = get_credential_by_id(db, credential_id)
# Store info for audit log before deletion
service_name = credential.service_name
credential_type = credential.credential_type
try:
# Create audit log BEFORE deletion
_create_audit_log(
db=db,
credential_id=str(credential_id),
action="delete",
user_id=user_id,
ip_address=ip_address,
user_agent=user_agent,
details={
"service_name": service_name,
"credential_type": credential_type
}
)
db.delete(credential)
db.commit()
return {
"message": "Credential deleted successfully",
"credential_id": str(credential_id)
}
except Exception as e:
db.rollback()
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to delete credential: {str(e)}"
)

View File

@@ -0,0 +1,318 @@
"""
DecisionLog service layer for business logic and database operations.
Handles all database operations for decision logs, tracking important
decisions made during work for future reference.
"""
from typing import Optional
from uuid import UUID
from fastapi import HTTPException, status
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
from api.models.decision_log import DecisionLog
from api.schemas.decision_log import DecisionLogCreate, DecisionLogUpdate
def get_decision_logs(
db: Session,
skip: int = 0,
limit: int = 100
) -> tuple[list[DecisionLog], int]:
"""
Retrieve a paginated list of decision logs.
Args:
db: Database session
skip: Number of records to skip (for pagination)
limit: Maximum number of records to return
Returns:
tuple: (list of decision logs, total count)
"""
# Get total count
total = db.query(DecisionLog).count()
# Get paginated results, ordered by most recent first
logs = (
db.query(DecisionLog)
.order_by(DecisionLog.created_at.desc())
.offset(skip)
.limit(limit)
.all()
)
return logs, total
def get_decision_log_by_id(db: Session, log_id: UUID) -> DecisionLog:
"""
Retrieve a single decision log by its ID.
Args:
db: Database session
log_id: UUID of the decision log to retrieve
Returns:
DecisionLog: The decision log object
Raises:
HTTPException: 404 if decision log not found
"""
log = db.query(DecisionLog).filter(DecisionLog.id == str(log_id)).first()
if not log:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"DecisionLog with ID {log_id} not found"
)
return log
def get_decision_logs_by_project(
db: Session,
project_id: UUID,
skip: int = 0,
limit: int = 100
) -> tuple[list[DecisionLog], int]:
"""
Retrieve decision logs for a specific project.
Args:
db: Database session
project_id: UUID of the project
skip: Number of records to skip
limit: Maximum number of records to return
Returns:
tuple: (list of decision logs, total count)
"""
# Get total count for project
total = db.query(DecisionLog).filter(
DecisionLog.project_id == str(project_id)
).count()
# Get paginated results
logs = (
db.query(DecisionLog)
.filter(DecisionLog.project_id == str(project_id))
.order_by(DecisionLog.created_at.desc())
.offset(skip)
.limit(limit)
.all()
)
return logs, total
def get_decision_logs_by_session(
db: Session,
session_id: UUID,
skip: int = 0,
limit: int = 100
) -> tuple[list[DecisionLog], int]:
"""
Retrieve decision logs for a specific session.
Args:
db: Database session
session_id: UUID of the session
skip: Number of records to skip
limit: Maximum number of records to return
Returns:
tuple: (list of decision logs, total count)
"""
# Get total count for session
total = db.query(DecisionLog).filter(
DecisionLog.session_id == str(session_id)
).count()
# Get paginated results
logs = (
db.query(DecisionLog)
.filter(DecisionLog.session_id == str(session_id))
.order_by(DecisionLog.created_at.desc())
.offset(skip)
.limit(limit)
.all()
)
return logs, total
def get_decision_logs_by_impact(
db: Session,
impact: str,
skip: int = 0,
limit: int = 100
) -> tuple[list[DecisionLog], int]:
"""
Retrieve decision logs filtered by impact level.
Args:
db: Database session
impact: Impact level (low, medium, high, critical)
skip: Number of records to skip
limit: Maximum number of records to return
Returns:
tuple: (list of decision logs, total count)
"""
# Validate impact level
valid_impacts = ["low", "medium", "high", "critical"]
if impact.lower() not in valid_impacts:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Invalid impact level. Must be one of: {', '.join(valid_impacts)}"
)
# Get total count for impact
total = db.query(DecisionLog).filter(
DecisionLog.impact == impact.lower()
).count()
# Get paginated results
logs = (
db.query(DecisionLog)
.filter(DecisionLog.impact == impact.lower())
.order_by(DecisionLog.created_at.desc())
.offset(skip)
.limit(limit)
.all()
)
return logs, total
def create_decision_log(
db: Session,
log_data: DecisionLogCreate
) -> DecisionLog:
"""
Create a new decision log.
Args:
db: Database session
log_data: Decision log creation data
Returns:
DecisionLog: The created decision log object
Raises:
HTTPException: 500 if database error occurs
"""
try:
# Create new decision log instance
db_log = DecisionLog(**log_data.model_dump())
# Add to database
db.add(db_log)
db.commit()
db.refresh(db_log)
return db_log
except IntegrityError as e:
db.rollback()
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Database error: {str(e)}"
)
except Exception as e:
db.rollback()
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to create decision log: {str(e)}"
)
def update_decision_log(
db: Session,
log_id: UUID,
log_data: DecisionLogUpdate
) -> DecisionLog:
"""
Update an existing decision log.
Args:
db: Database session
log_id: UUID of the decision log to update
log_data: Decision log update data
Returns:
DecisionLog: The updated decision log object
Raises:
HTTPException: 404 if decision log not found
HTTPException: 500 if database error occurs
"""
# Get existing log
log = get_decision_log_by_id(db, log_id)
try:
# Update only provided fields
update_data = log_data.model_dump(exclude_unset=True)
# Apply updates
for field, value in update_data.items():
setattr(log, field, value)
db.commit()
db.refresh(log)
return log
except HTTPException:
db.rollback()
raise
except IntegrityError as e:
db.rollback()
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Database error: {str(e)}"
)
except Exception as e:
db.rollback()
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to update decision log: {str(e)}"
)
def delete_decision_log(db: Session, log_id: UUID) -> dict:
"""
Delete a decision log by its ID.
Args:
db: Database session
log_id: UUID of the decision log to delete
Returns:
dict: Success message
Raises:
HTTPException: 404 if decision log not found
HTTPException: 500 if database error occurs
"""
# Get existing log (raises 404 if not found)
log = get_decision_log_by_id(db, log_id)
try:
db.delete(log)
db.commit()
return {
"message": "DecisionLog deleted successfully",
"log_id": str(log_id)
}
except Exception as e:
db.rollback()
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to delete decision log: {str(e)}"
)

View File

@@ -0,0 +1,367 @@
"""
Firewall rule service layer for business logic and database operations.
This module handles all database operations for firewall rules, providing a clean
separation between the API routes and data access layer.
"""
from typing import Optional
from uuid import UUID
from fastapi import HTTPException, status
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
from api.models.firewall_rule import FirewallRule
from api.models.infrastructure import Infrastructure
from api.schemas.firewall_rule import FirewallRuleCreate, FirewallRuleUpdate
def get_firewall_rules(db: Session, skip: int = 0, limit: int = 100) -> tuple[list[FirewallRule], int]:
"""
Retrieve a paginated list of firewall rules.
Args:
db: Database session
skip: Number of records to skip (for pagination)
limit: Maximum number of records to return
Returns:
tuple: (list of firewall rules, total count)
Example:
```python
rules, total = get_firewall_rules(db, skip=0, limit=50)
print(f"Retrieved {len(rules)} of {total} firewall rules")
```
"""
# Get total count
total = db.query(FirewallRule).count()
# Get paginated results, ordered by created_at descending (newest first)
rules = (
db.query(FirewallRule)
.order_by(FirewallRule.created_at.desc())
.offset(skip)
.limit(limit)
.all()
)
return rules, total
def get_firewall_rule_by_id(db: Session, firewall_rule_id: UUID) -> FirewallRule:
"""
Retrieve a single firewall rule by its ID.
Args:
db: Database session
firewall_rule_id: UUID of the firewall rule to retrieve
Returns:
FirewallRule: The firewall rule object
Raises:
HTTPException: 404 if firewall rule not found
Example:
```python
rule = get_firewall_rule_by_id(db, firewall_rule_id)
print(f"Found rule: {rule.rule_name}")
```
"""
rule = db.query(FirewallRule).filter(FirewallRule.id == str(firewall_rule_id)).first()
if not rule:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Firewall rule with ID {firewall_rule_id} not found"
)
return rule
def get_firewall_rules_by_infrastructure(db: Session, infrastructure_id: UUID, skip: int = 0, limit: int = 100) -> tuple[list[FirewallRule], int]:
"""
Retrieve firewall rules belonging to a specific infrastructure.
Args:
db: Database session
infrastructure_id: UUID of the infrastructure
skip: Number of records to skip (for pagination)
limit: Maximum number of records to return
Returns:
tuple: (list of firewall rules, total count for this infrastructure)
Raises:
HTTPException: 404 if infrastructure not found
Example:
```python
rules, total = get_firewall_rules_by_infrastructure(db, infrastructure_id, skip=0, limit=50)
print(f"Retrieved {len(rules)} of {total} firewall rules for infrastructure")
```
"""
# Verify infrastructure exists
infrastructure = db.query(Infrastructure).filter(Infrastructure.id == str(infrastructure_id)).first()
if not infrastructure:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Infrastructure with ID {infrastructure_id} not found"
)
# Get total count for this infrastructure
total = db.query(FirewallRule).filter(FirewallRule.infrastructure_id == str(infrastructure_id)).count()
# Get paginated results
rules = (
db.query(FirewallRule)
.filter(FirewallRule.infrastructure_id == str(infrastructure_id))
.order_by(FirewallRule.rule_order.asc(), FirewallRule.created_at.desc())
.offset(skip)
.limit(limit)
.all()
)
return rules, total
def get_firewall_rules_by_action(db: Session, action: str, skip: int = 0, limit: int = 100) -> tuple[list[FirewallRule], int]:
"""
Retrieve firewall rules by action type (allow, deny, drop).
Args:
db: Database session
action: Action type to filter by (allow, deny, drop)
skip: Number of records to skip (for pagination)
limit: Maximum number of records to return
Returns:
tuple: (list of firewall rules, total count for this action)
Raises:
HTTPException: 422 if invalid action provided
Example:
```python
rules, total = get_firewall_rules_by_action(db, "allow", skip=0, limit=50)
print(f"Retrieved {len(rules)} of {total} allow rules")
```
"""
# Validate action
valid_actions = ["allow", "deny", "drop"]
if action not in valid_actions:
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail=f"Invalid action '{action}'. Must be one of: {', '.join(valid_actions)}"
)
# Get total count for this action
total = db.query(FirewallRule).filter(FirewallRule.action == action).count()
# Get paginated results
rules = (
db.query(FirewallRule)
.filter(FirewallRule.action == action)
.order_by(FirewallRule.created_at.desc())
.offset(skip)
.limit(limit)
.all()
)
return rules, total
def create_firewall_rule(db: Session, firewall_rule_data: FirewallRuleCreate) -> FirewallRule:
"""
Create a new firewall rule.
Args:
db: Database session
firewall_rule_data: Firewall rule creation data
Returns:
FirewallRule: The created firewall rule object
Raises:
HTTPException: 404 if infrastructure not found
HTTPException: 422 if invalid action provided
HTTPException: 500 if database error occurs
Example:
```python
rule_data = FirewallRuleCreate(
infrastructure_id="123e4567-e89b-12d3-a456-426614174000",
rule_name="Allow SSH",
action="allow",
port=22
)
rule = create_firewall_rule(db, rule_data)
print(f"Created firewall rule: {rule.id}")
```
"""
# Verify infrastructure exists if provided
if firewall_rule_data.infrastructure_id:
infrastructure = db.query(Infrastructure).filter(
Infrastructure.id == str(firewall_rule_data.infrastructure_id)
).first()
if not infrastructure:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Infrastructure with ID {firewall_rule_data.infrastructure_id} not found"
)
# Validate action if provided
if firewall_rule_data.action:
valid_actions = ["allow", "deny", "drop"]
if firewall_rule_data.action not in valid_actions:
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail=f"Invalid action '{firewall_rule_data.action}'. Must be one of: {', '.join(valid_actions)}"
)
try:
# Create new firewall rule instance
db_rule = FirewallRule(**firewall_rule_data.model_dump())
# Add to database
db.add(db_rule)
db.commit()
db.refresh(db_rule)
return db_rule
except IntegrityError as e:
db.rollback()
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Database error: {str(e)}"
)
except Exception as e:
db.rollback()
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to create firewall rule: {str(e)}"
)
def update_firewall_rule(db: Session, firewall_rule_id: UUID, firewall_rule_data: FirewallRuleUpdate) -> FirewallRule:
"""
Update an existing firewall rule.
Args:
db: Database session
firewall_rule_id: UUID of the firewall rule to update
firewall_rule_data: Firewall rule update data (only provided fields will be updated)
Returns:
FirewallRule: The updated firewall rule object
Raises:
HTTPException: 404 if firewall rule or infrastructure not found
HTTPException: 422 if invalid action provided
HTTPException: 500 if database error occurs
Example:
```python
update_data = FirewallRuleUpdate(
rule_name="Allow SSH - Updated",
action="deny"
)
rule = update_firewall_rule(db, firewall_rule_id, update_data)
print(f"Updated firewall rule: {rule.rule_name}")
```
"""
# Get existing firewall rule
rule = get_firewall_rule_by_id(db, firewall_rule_id)
try:
# Update only provided fields
update_data = firewall_rule_data.model_dump(exclude_unset=True)
# If updating infrastructure_id, verify new infrastructure exists
if "infrastructure_id" in update_data and update_data["infrastructure_id"]:
infrastructure = db.query(Infrastructure).filter(
Infrastructure.id == str(update_data["infrastructure_id"])
).first()
if not infrastructure:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Infrastructure with ID {update_data['infrastructure_id']} not found"
)
# Validate action if provided
if "action" in update_data and update_data["action"]:
valid_actions = ["allow", "deny", "drop"]
if update_data["action"] not in valid_actions:
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail=f"Invalid action '{update_data['action']}'. Must be one of: {', '.join(valid_actions)}"
)
# Apply updates
for field, value in update_data.items():
setattr(rule, field, value)
db.commit()
db.refresh(rule)
return rule
except HTTPException:
db.rollback()
raise
except IntegrityError as e:
db.rollback()
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Database error: {str(e)}"
)
except Exception as e:
db.rollback()
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to update firewall rule: {str(e)}"
)
def delete_firewall_rule(db: Session, firewall_rule_id: UUID) -> dict:
"""
Delete a firewall rule by its ID.
Args:
db: Database session
firewall_rule_id: UUID of the firewall rule to delete
Returns:
dict: Success message
Raises:
HTTPException: 404 if firewall rule not found
HTTPException: 500 if database error occurs
Example:
```python
result = delete_firewall_rule(db, firewall_rule_id)
print(result["message"]) # "Firewall rule deleted successfully"
```
"""
# Get existing firewall rule (raises 404 if not found)
rule = get_firewall_rule_by_id(db, firewall_rule_id)
try:
db.delete(rule)
db.commit()
return {
"message": "Firewall rule deleted successfully",
"firewall_rule_id": str(firewall_rule_id)
}
except Exception as e:
db.rollback()
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to delete firewall rule: {str(e)}"
)

View File

@@ -0,0 +1,425 @@
"""
Infrastructure service layer for business logic and database operations.
This module handles all database operations for infrastructure assets, providing
a clean separation between the API routes and data access layer.
"""
from typing import Optional
from uuid import UUID
from fastapi import HTTPException, status
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
from api.models.infrastructure import Infrastructure
from api.schemas.infrastructure import InfrastructureCreate, InfrastructureUpdate
def get_infrastructure_items(db: Session, skip: int = 0, limit: int = 100) -> tuple[list[Infrastructure], int]:
"""
Retrieve a paginated list of infrastructure items.
Args:
db: Database session
skip: Number of records to skip (for pagination)
limit: Maximum number of records to return
Returns:
tuple: (list of infrastructure items, total count)
Example:
```python
items, total = get_infrastructure_items(db, skip=0, limit=50)
print(f"Retrieved {len(items)} of {total} infrastructure items")
```
"""
# Get total count
total = db.query(Infrastructure).count()
# Get paginated results, ordered by created_at descending (newest first)
items = (
db.query(Infrastructure)
.order_by(Infrastructure.created_at.desc())
.offset(skip)
.limit(limit)
.all()
)
return items, total
def get_infrastructure_by_id(db: Session, infrastructure_id: UUID) -> Infrastructure:
"""
Retrieve a single infrastructure item by its ID.
Args:
db: Database session
infrastructure_id: UUID of the infrastructure item to retrieve
Returns:
Infrastructure: The infrastructure object
Raises:
HTTPException: 404 if infrastructure not found
Example:
```python
item = get_infrastructure_by_id(db, infrastructure_id)
print(f"Found infrastructure: {item.hostname}")
```
"""
item = db.query(Infrastructure).filter(Infrastructure.id == str(infrastructure_id)).first()
if not item:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Infrastructure with ID {infrastructure_id} not found"
)
return item
def get_infrastructure_by_site(db: Session, site_id: str, skip: int = 0, limit: int = 100) -> tuple[list[Infrastructure], int]:
"""
Retrieve infrastructure items for a specific site.
Args:
db: Database session
site_id: UUID of the site
skip: Number of records to skip (for pagination)
limit: Maximum number of records to return
Returns:
tuple: (list of infrastructure items, total count)
Example:
```python
items, total = get_infrastructure_by_site(db, site_id, skip=0, limit=50)
print(f"Retrieved {len(items)} of {total} items for site")
```
"""
# Get total count for this site
total = db.query(Infrastructure).filter(Infrastructure.site_id == site_id).count()
# Get paginated results
items = (
db.query(Infrastructure)
.filter(Infrastructure.site_id == site_id)
.order_by(Infrastructure.created_at.desc())
.offset(skip)
.limit(limit)
.all()
)
return items, total
def get_infrastructure_by_client(db: Session, client_id: str, skip: int = 0, limit: int = 100) -> tuple[list[Infrastructure], int]:
"""
Retrieve infrastructure items for a specific client.
Args:
db: Database session
client_id: UUID of the client
skip: Number of records to skip (for pagination)
limit: Maximum number of records to return
Returns:
tuple: (list of infrastructure items, total count)
Example:
```python
items, total = get_infrastructure_by_client(db, client_id, skip=0, limit=50)
print(f"Retrieved {len(items)} of {total} items for client")
```
"""
# Get total count for this client
total = db.query(Infrastructure).filter(Infrastructure.client_id == client_id).count()
# Get paginated results
items = (
db.query(Infrastructure)
.filter(Infrastructure.client_id == client_id)
.order_by(Infrastructure.created_at.desc())
.offset(skip)
.limit(limit)
.all()
)
return items, total
def get_infrastructure_by_type(db: Session, infra_type: str, skip: int = 0, limit: int = 100) -> tuple[list[Infrastructure], int]:
"""
Retrieve infrastructure items by asset type.
Args:
db: Database session
infra_type: Asset type to filter by
skip: Number of records to skip (for pagination)
limit: Maximum number of records to return
Returns:
tuple: (list of infrastructure items, total count)
Example:
```python
items, total = get_infrastructure_by_type(db, "physical_server", skip=0, limit=50)
print(f"Retrieved {len(items)} of {total} physical servers")
```
"""
# Get total count for this type
total = db.query(Infrastructure).filter(Infrastructure.asset_type == infra_type).count()
# Get paginated results
items = (
db.query(Infrastructure)
.filter(Infrastructure.asset_type == infra_type)
.order_by(Infrastructure.created_at.desc())
.offset(skip)
.limit(limit)
.all()
)
return items, total
def create_infrastructure(db: Session, infrastructure_data: InfrastructureCreate) -> Infrastructure:
"""
Create a new infrastructure item.
Args:
db: Database session
infrastructure_data: Infrastructure creation data
Returns:
Infrastructure: The created infrastructure object
Raises:
HTTPException: 409 if validation fails
HTTPException: 422 if foreign key validation fails
HTTPException: 500 if database error occurs
Example:
```python
infra_data = InfrastructureCreate(
hostname="server-01",
asset_type="physical_server",
client_id="client-uuid"
)
infra = create_infrastructure(db, infra_data)
print(f"Created infrastructure: {infra.id}")
```
"""
# Validate foreign keys if provided
if infrastructure_data.client_id:
from api.models.client import Client
client = db.query(Client).filter(Client.id == infrastructure_data.client_id).first()
if not client:
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail=f"Client with ID {infrastructure_data.client_id} not found"
)
if infrastructure_data.site_id:
from api.models.site import Site
site = db.query(Site).filter(Site.id == infrastructure_data.site_id).first()
if not site:
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail=f"Site with ID {infrastructure_data.site_id} not found"
)
if infrastructure_data.parent_host_id:
parent = db.query(Infrastructure).filter(Infrastructure.id == infrastructure_data.parent_host_id).first()
if not parent:
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail=f"Parent host with ID {infrastructure_data.parent_host_id} not found"
)
try:
# Create new infrastructure instance
db_infrastructure = Infrastructure(**infrastructure_data.model_dump())
# Add to database
db.add(db_infrastructure)
db.commit()
db.refresh(db_infrastructure)
return db_infrastructure
except IntegrityError as e:
db.rollback()
# Handle constraint violations
if "client_id" in str(e.orig):
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail=f"Invalid client_id: {infrastructure_data.client_id}"
)
elif "site_id" in str(e.orig):
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail=f"Invalid site_id: {infrastructure_data.site_id}"
)
elif "parent_host_id" in str(e.orig):
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail=f"Invalid parent_host_id: {infrastructure_data.parent_host_id}"
)
else:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Database error: {str(e)}"
)
except Exception as e:
db.rollback()
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to create infrastructure: {str(e)}"
)
def update_infrastructure(db: Session, infrastructure_id: UUID, infrastructure_data: InfrastructureUpdate) -> Infrastructure:
"""
Update an existing infrastructure item.
Args:
db: Database session
infrastructure_id: UUID of the infrastructure item to update
infrastructure_data: Infrastructure update data (only provided fields will be updated)
Returns:
Infrastructure: The updated infrastructure object
Raises:
HTTPException: 404 if infrastructure not found
HTTPException: 422 if foreign key validation fails
HTTPException: 500 if database error occurs
Example:
```python
update_data = InfrastructureUpdate(
status="decommissioned",
notes="Server retired"
)
infra = update_infrastructure(db, infrastructure_id, update_data)
print(f"Updated infrastructure: {infra.hostname}")
```
"""
# Get existing infrastructure
infrastructure = get_infrastructure_by_id(db, infrastructure_id)
try:
# Update only provided fields
update_data = infrastructure_data.model_dump(exclude_unset=True)
# Validate foreign keys if being updated
if "client_id" in update_data and update_data["client_id"]:
from api.models.client import Client
client = db.query(Client).filter(Client.id == update_data["client_id"]).first()
if not client:
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail=f"Client with ID {update_data['client_id']} not found"
)
if "site_id" in update_data and update_data["site_id"]:
from api.models.site import Site
site = db.query(Site).filter(Site.id == update_data["site_id"]).first()
if not site:
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail=f"Site with ID {update_data['site_id']} not found"
)
if "parent_host_id" in update_data and update_data["parent_host_id"]:
parent = db.query(Infrastructure).filter(Infrastructure.id == update_data["parent_host_id"]).first()
if not parent:
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail=f"Parent host with ID {update_data['parent_host_id']} not found"
)
# Apply updates
for field, value in update_data.items():
setattr(infrastructure, field, value)
db.commit()
db.refresh(infrastructure)
return infrastructure
except HTTPException:
db.rollback()
raise
except IntegrityError as e:
db.rollback()
if "client_id" in str(e.orig):
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail="Invalid client_id"
)
elif "site_id" in str(e.orig):
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail="Invalid site_id"
)
elif "parent_host_id" in str(e.orig):
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail="Invalid parent_host_id"
)
else:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Database error: {str(e)}"
)
except Exception as e:
db.rollback()
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to update infrastructure: {str(e)}"
)
def delete_infrastructure(db: Session, infrastructure_id: UUID) -> dict:
"""
Delete an infrastructure item by its ID.
Args:
db: Database session
infrastructure_id: UUID of the infrastructure item to delete
Returns:
dict: Success message
Raises:
HTTPException: 404 if infrastructure not found
HTTPException: 500 if database error occurs
Example:
```python
result = delete_infrastructure(db, infrastructure_id)
print(result["message"]) # "Infrastructure deleted successfully"
```
"""
# Get existing infrastructure (raises 404 if not found)
infrastructure = get_infrastructure_by_id(db, infrastructure_id)
try:
db.delete(infrastructure)
db.commit()
return {
"message": "Infrastructure deleted successfully",
"infrastructure_id": str(infrastructure_id)
}
except Exception as e:
db.rollback()
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to delete infrastructure: {str(e)}"
)

View File

@@ -0,0 +1,359 @@
"""
M365 Tenant service layer for business logic and database operations.
This module handles all database operations for M365 tenants, providing a clean
separation between the API routes and data access layer.
"""
from typing import Optional
from uuid import UUID
from fastapi import HTTPException, status
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
from api.models.m365_tenant import M365Tenant
from api.models.client import Client
from api.schemas.m365_tenant import M365TenantCreate, M365TenantUpdate
def get_m365_tenants(db: Session, skip: int = 0, limit: int = 100) -> tuple[list[M365Tenant], int]:
"""
Retrieve a paginated list of M365 tenants.
Args:
db: Database session
skip: Number of records to skip (for pagination)
limit: Maximum number of records to return
Returns:
tuple: (list of M365 tenants, total count)
Example:
```python
tenants, total = get_m365_tenants(db, skip=0, limit=50)
print(f"Retrieved {len(tenants)} of {total} M365 tenants")
```
"""
# Get total count
total = db.query(M365Tenant).count()
# Get paginated results, ordered by created_at descending (newest first)
tenants = (
db.query(M365Tenant)
.order_by(M365Tenant.created_at.desc())
.offset(skip)
.limit(limit)
.all()
)
return tenants, total
def get_m365_tenant_by_id(db: Session, tenant_id: UUID) -> M365Tenant:
"""
Retrieve a single M365 tenant by its ID.
Args:
db: Database session
tenant_id: UUID of the M365 tenant to retrieve
Returns:
M365Tenant: The M365 tenant object
Raises:
HTTPException: 404 if M365 tenant not found
Example:
```python
tenant = get_m365_tenant_by_id(db, tenant_id)
print(f"Found tenant: {tenant.tenant_name}")
```
"""
tenant = db.query(M365Tenant).filter(M365Tenant.id == str(tenant_id)).first()
if not tenant:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"M365 tenant with ID {tenant_id} not found"
)
return tenant
def get_m365_tenant_by_tenant_id(db: Session, tenant_id: str) -> Optional[M365Tenant]:
"""
Retrieve an M365 tenant by its Microsoft tenant ID.
Args:
db: Database session
tenant_id: Microsoft tenant ID to search for
Returns:
Optional[M365Tenant]: The M365 tenant if found, None otherwise
Example:
```python
tenant = get_m365_tenant_by_tenant_id(db, "abc12345-6789-0def-1234-56789abcdef0")
if tenant:
print(f"Found tenant: {tenant.tenant_name}")
```
"""
return db.query(M365Tenant).filter(M365Tenant.tenant_id == tenant_id).first()
def get_m365_tenants_by_client(db: Session, client_id: UUID, skip: int = 0, limit: int = 100) -> tuple[list[M365Tenant], int]:
"""
Retrieve M365 tenants for a specific client.
Args:
db: Database session
client_id: UUID of the client
skip: Number of records to skip (for pagination)
limit: Maximum number of records to return
Returns:
tuple: (list of M365 tenants, total count)
Raises:
HTTPException: 404 if client not found
Example:
```python
tenants, total = get_m365_tenants_by_client(db, client_id, skip=0, limit=50)
print(f"Client has {total} M365 tenants")
```
"""
# Verify client exists
client = db.query(Client).filter(Client.id == str(client_id)).first()
if not client:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Client with ID {client_id} not found"
)
# Get total count for this client
total = db.query(M365Tenant).filter(M365Tenant.client_id == str(client_id)).count()
# Get paginated results
tenants = (
db.query(M365Tenant)
.filter(M365Tenant.client_id == str(client_id))
.order_by(M365Tenant.created_at.desc())
.offset(skip)
.limit(limit)
.all()
)
return tenants, total
def create_m365_tenant(db: Session, tenant_data: M365TenantCreate) -> M365Tenant:
"""
Create a new M365 tenant.
Args:
db: Database session
tenant_data: M365 tenant creation data
Returns:
M365Tenant: The created M365 tenant object
Raises:
HTTPException: 404 if client_id provided and client doesn't exist
HTTPException: 409 if M365 tenant with tenant_id already exists
HTTPException: 500 if database error occurs
Example:
```python
tenant_data = M365TenantCreate(
tenant_id="abc12345-6789-0def-1234-56789abcdef0",
tenant_name="dataforth.com",
client_id="123e4567-e89b-12d3-a456-426614174000"
)
tenant = create_m365_tenant(db, tenant_data)
print(f"Created M365 tenant: {tenant.id}")
```
"""
# Validate client exists if client_id provided
if tenant_data.client_id:
client = db.query(Client).filter(Client.id == str(tenant_data.client_id)).first()
if not client:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Client with ID {tenant_data.client_id} not found"
)
# Check if M365 tenant with tenant_id already exists
existing_tenant = get_m365_tenant_by_tenant_id(db, tenant_data.tenant_id)
if existing_tenant:
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail=f"M365 tenant with tenant_id '{tenant_data.tenant_id}' already exists"
)
try:
# Create new M365 tenant instance
db_tenant = M365Tenant(**tenant_data.model_dump())
# Add to database
db.add(db_tenant)
db.commit()
db.refresh(db_tenant)
return db_tenant
except IntegrityError as e:
db.rollback()
# Handle unique constraint violations
if "tenant_id" in str(e.orig):
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail=f"M365 tenant with tenant_id '{tenant_data.tenant_id}' already exists"
)
elif "client_id" in str(e.orig):
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Client with ID {tenant_data.client_id} not found"
)
else:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Database error: {str(e)}"
)
except Exception as e:
db.rollback()
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to create M365 tenant: {str(e)}"
)
def update_m365_tenant(db: Session, tenant_id: UUID, tenant_data: M365TenantUpdate) -> M365Tenant:
"""
Update an existing M365 tenant.
Args:
db: Database session
tenant_id: UUID of the M365 tenant to update
tenant_data: M365 tenant update data (only provided fields will be updated)
Returns:
M365Tenant: The updated M365 tenant object
Raises:
HTTPException: 404 if M365 tenant not found or client_id provided and client doesn't exist
HTTPException: 409 if update would violate unique constraints
HTTPException: 500 if database error occurs
Example:
```python
update_data = M365TenantUpdate(
admin_email="admin@example.com",
notes="Updated tenant information"
)
tenant = update_m365_tenant(db, tenant_id, update_data)
print(f"Updated M365 tenant: {tenant.tenant_name}")
```
"""
# Get existing M365 tenant
tenant = get_m365_tenant_by_id(db, tenant_id)
try:
# Update only provided fields
update_data = tenant_data.model_dump(exclude_unset=True)
# If updating client_id, validate client exists
if "client_id" in update_data and update_data["client_id"] is not None:
client = db.query(Client).filter(Client.id == str(update_data["client_id"])).first()
if not client:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Client with ID {update_data['client_id']} not found"
)
# If updating tenant_id, check if new tenant_id is already taken
if "tenant_id" in update_data and update_data["tenant_id"] != tenant.tenant_id:
existing = get_m365_tenant_by_tenant_id(db, update_data["tenant_id"])
if existing:
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail=f"M365 tenant with tenant_id '{update_data['tenant_id']}' already exists"
)
# Apply updates
for field, value in update_data.items():
setattr(tenant, field, value)
db.commit()
db.refresh(tenant)
return tenant
except HTTPException:
db.rollback()
raise
except IntegrityError as e:
db.rollback()
if "tenant_id" in str(e.orig):
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail="M365 tenant with this tenant_id already exists"
)
elif "client_id" in str(e.orig):
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Client not found"
)
else:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Database error: {str(e)}"
)
except Exception as e:
db.rollback()
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to update M365 tenant: {str(e)}"
)
def delete_m365_tenant(db: Session, tenant_id: UUID) -> dict:
"""
Delete an M365 tenant by its ID.
Args:
db: Database session
tenant_id: UUID of the M365 tenant to delete
Returns:
dict: Success message
Raises:
HTTPException: 404 if M365 tenant not found
HTTPException: 500 if database error occurs
Example:
```python
result = delete_m365_tenant(db, tenant_id)
print(result["message"]) # "M365 tenant deleted successfully"
```
"""
# Get existing M365 tenant (raises 404 if not found)
tenant = get_m365_tenant_by_id(db, tenant_id)
try:
db.delete(tenant)
db.commit()
return {
"message": "M365 tenant deleted successfully",
"tenant_id": str(tenant_id)
}
except Exception as e:
db.rollback()
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to delete M365 tenant: {str(e)}"
)

View File

@@ -0,0 +1,347 @@
"""
Machine service layer for business logic and database operations.
This module handles all database operations for machines, providing a clean
separation between the API routes and data access layer.
"""
from typing import Optional
from uuid import UUID
from fastapi import HTTPException, status
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
from api.models.machine import Machine
from api.schemas.machine import MachineCreate, MachineUpdate
def get_machines(db: Session, skip: int = 0, limit: int = 100) -> tuple[list[Machine], int]:
"""
Retrieve a paginated list of machines.
Args:
db: Database session
skip: Number of records to skip (for pagination)
limit: Maximum number of records to return
Returns:
tuple: (list of machines, total count)
Example:
```python
machines, total = get_machines(db, skip=0, limit=50)
print(f"Retrieved {len(machines)} of {total} machines")
```
"""
# Get total count
total = db.query(Machine).count()
# Get paginated results, ordered by created_at descending (newest first)
machines = (
db.query(Machine)
.order_by(Machine.created_at.desc())
.offset(skip)
.limit(limit)
.all()
)
return machines, total
def get_machine_by_id(db: Session, machine_id: UUID) -> Machine:
"""
Retrieve a single machine by its ID.
Args:
db: Database session
machine_id: UUID of the machine to retrieve
Returns:
Machine: The machine object
Raises:
HTTPException: 404 if machine not found
Example:
```python
machine = get_machine_by_id(db, machine_id)
print(f"Found machine: {machine.hostname}")
```
"""
machine = db.query(Machine).filter(Machine.id == str(machine_id)).first()
if not machine:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Machine with ID {machine_id} not found"
)
return machine
def get_machine_by_hostname(db: Session, hostname: str) -> Optional[Machine]:
"""
Retrieve a machine by its hostname.
Args:
db: Database session
hostname: Hostname to search for
Returns:
Optional[Machine]: The machine if found, None otherwise
Example:
```python
machine = get_machine_by_hostname(db, "laptop-dev-01")
if machine:
print(f"Found machine: {machine.friendly_name}")
```
"""
return db.query(Machine).filter(Machine.hostname == hostname).first()
def create_machine(db: Session, machine_data: MachineCreate) -> Machine:
"""
Create a new machine.
Args:
db: Database session
machine_data: Machine creation data
Returns:
Machine: The created machine object
Raises:
HTTPException: 409 if machine with hostname already exists
HTTPException: 500 if database error occurs
Example:
```python
machine_data = MachineCreate(
hostname="laptop-dev-01",
friendly_name="Development Laptop",
platform="win32"
)
machine = create_machine(db, machine_data)
print(f"Created machine: {machine.id}")
```
"""
# Check if machine with hostname already exists
existing_machine = get_machine_by_hostname(db, machine_data.hostname)
if existing_machine:
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail=f"Machine with hostname '{machine_data.hostname}' already exists"
)
try:
# Create new machine instance
db_machine = Machine(**machine_data.model_dump())
# Add to database
db.add(db_machine)
db.commit()
db.refresh(db_machine)
return db_machine
except IntegrityError as e:
db.rollback()
# Handle unique constraint violations
if "hostname" in str(e.orig):
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail=f"Machine with hostname '{machine_data.hostname}' already exists"
)
elif "machine_fingerprint" in str(e.orig):
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail="Machine with this fingerprint already exists"
)
else:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Database error: {str(e)}"
)
except Exception as e:
db.rollback()
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to create machine: {str(e)}"
)
def update_machine(db: Session, machine_id: UUID, machine_data: MachineUpdate) -> Machine:
"""
Update an existing machine.
Args:
db: Database session
machine_id: UUID of the machine to update
machine_data: Machine update data (only provided fields will be updated)
Returns:
Machine: The updated machine object
Raises:
HTTPException: 404 if machine not found
HTTPException: 409 if update would violate unique constraints
HTTPException: 500 if database error occurs
Example:
```python
update_data = MachineUpdate(
friendly_name="Updated Laptop Name",
is_active=False
)
machine = update_machine(db, machine_id, update_data)
print(f"Updated machine: {machine.friendly_name}")
```
"""
# Get existing machine
machine = get_machine_by_id(db, machine_id)
try:
# Update only provided fields
update_data = machine_data.model_dump(exclude_unset=True)
# If updating hostname, check if new hostname is already taken
if "hostname" in update_data and update_data["hostname"] != machine.hostname:
existing = get_machine_by_hostname(db, update_data["hostname"])
if existing:
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail=f"Machine with hostname '{update_data['hostname']}' already exists"
)
# Apply updates
for field, value in update_data.items():
setattr(machine, field, value)
db.commit()
db.refresh(machine)
return machine
except HTTPException:
db.rollback()
raise
except IntegrityError as e:
db.rollback()
if "hostname" in str(e.orig):
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail="Machine with this hostname already exists"
)
elif "machine_fingerprint" in str(e.orig):
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail="Machine with this fingerprint already exists"
)
else:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Database error: {str(e)}"
)
except Exception as e:
db.rollback()
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to update machine: {str(e)}"
)
def delete_machine(db: Session, machine_id: UUID) -> dict:
"""
Delete a machine by its ID.
Args:
db: Database session
machine_id: UUID of the machine to delete
Returns:
dict: Success message
Raises:
HTTPException: 404 if machine not found
HTTPException: 500 if database error occurs
Example:
```python
result = delete_machine(db, machine_id)
print(result["message"]) # "Machine deleted successfully"
```
"""
# Get existing machine (raises 404 if not found)
machine = get_machine_by_id(db, machine_id)
try:
db.delete(machine)
db.commit()
return {
"message": "Machine deleted successfully",
"machine_id": str(machine_id)
}
except Exception as e:
db.rollback()
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to delete machine: {str(e)}"
)
def get_active_machines(db: Session, skip: int = 0, limit: int = 100) -> tuple[list[Machine], int]:
"""
Retrieve a paginated list of active machines only.
Args:
db: Database session
skip: Number of records to skip (for pagination)
limit: Maximum number of records to return
Returns:
tuple: (list of active machines, total count)
Example:
```python
machines, total = get_active_machines(db, skip=0, limit=50)
print(f"Retrieved {len(machines)} of {total} active machines")
```
"""
# Get total count of active machines
total = db.query(Machine).filter(Machine.is_active == True).count()
# Get paginated results
machines = (
db.query(Machine)
.filter(Machine.is_active == True)
.order_by(Machine.created_at.desc())
.offset(skip)
.limit(limit)
.all()
)
return machines, total
def get_primary_machine(db: Session) -> Optional[Machine]:
"""
Retrieve the primary machine.
Args:
db: Database session
Returns:
Optional[Machine]: The primary machine if one exists, None otherwise
Example:
```python
primary = get_primary_machine(db)
if primary:
print(f"Primary machine: {primary.hostname}")
```
"""
return db.query(Machine).filter(Machine.is_primary == True).first()

View File

@@ -0,0 +1,332 @@
"""
Network service layer for business logic and database operations.
This module handles all database operations for networks, providing a clean
separation between the API routes and data access layer.
"""
from typing import Optional
from uuid import UUID
from fastapi import HTTPException, status
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
from api.models.network import Network
from api.models.site import Site
from api.schemas.network import NetworkCreate, NetworkUpdate
def get_networks(db: Session, skip: int = 0, limit: int = 100) -> tuple[list[Network], int]:
"""
Retrieve a paginated list of networks.
Args:
db: Database session
skip: Number of records to skip (for pagination)
limit: Maximum number of records to return
Returns:
tuple: (list of networks, total count)
Example:
```python
networks, total = get_networks(db, skip=0, limit=50)
print(f"Retrieved {len(networks)} of {total} networks")
```
"""
# Get total count
total = db.query(Network).count()
# Get paginated results, ordered by created_at descending (newest first)
networks = (
db.query(Network)
.order_by(Network.created_at.desc())
.offset(skip)
.limit(limit)
.all()
)
return networks, total
def get_network_by_id(db: Session, network_id: UUID) -> Network:
"""
Retrieve a single network by its ID.
Args:
db: Database session
network_id: UUID of the network to retrieve
Returns:
Network: The network object
Raises:
HTTPException: 404 if network not found
Example:
```python
network = get_network_by_id(db, network_id)
print(f"Found network: {network.network_name}")
```
"""
network = db.query(Network).filter(Network.id == str(network_id)).first()
if not network:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Network with ID {network_id} not found"
)
return network
def get_networks_by_site(db: Session, site_id: UUID, skip: int = 0, limit: int = 100) -> tuple[list[Network], int]:
"""
Retrieve networks belonging to a specific site.
Args:
db: Database session
site_id: UUID of the site
skip: Number of records to skip (for pagination)
limit: Maximum number of records to return
Returns:
tuple: (list of networks, total count for this site)
Raises:
HTTPException: 404 if site not found
Example:
```python
networks, total = get_networks_by_site(db, site_id, skip=0, limit=50)
print(f"Retrieved {len(networks)} of {total} networks for site")
```
"""
# Verify site exists
site = db.query(Site).filter(Site.id == str(site_id)).first()
if not site:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Site with ID {site_id} not found"
)
# Get total count for this site
total = db.query(Network).filter(Network.site_id == str(site_id)).count()
# Get paginated results
networks = (
db.query(Network)
.filter(Network.site_id == str(site_id))
.order_by(Network.created_at.desc())
.offset(skip)
.limit(limit)
.all()
)
return networks, total
def get_networks_by_type(db: Session, network_type: str, skip: int = 0, limit: int = 100) -> tuple[list[Network], int]:
"""
Retrieve networks of a specific type.
Args:
db: Database session
network_type: Type of network (lan, vpn, vlan, isolated, dmz)
skip: Number of records to skip (for pagination)
limit: Maximum number of records to return
Returns:
tuple: (list of networks, total count for this type)
Example:
```python
networks, total = get_networks_by_type(db, "vlan", skip=0, limit=50)
print(f"Retrieved {len(networks)} of {total} VLAN networks")
```
"""
# Get total count for this type
total = db.query(Network).filter(Network.network_type == network_type).count()
# Get paginated results
networks = (
db.query(Network)
.filter(Network.network_type == network_type)
.order_by(Network.created_at.desc())
.offset(skip)
.limit(limit)
.all()
)
return networks, total
def create_network(db: Session, network_data: NetworkCreate) -> Network:
"""
Create a new network.
Args:
db: Database session
network_data: Network creation data
Returns:
Network: The created network object
Raises:
HTTPException: 404 if site not found
HTTPException: 500 if database error occurs
Example:
```python
network_data = NetworkCreate(
site_id="123e4567-e89b-12d3-a456-426614174000",
network_name="Main LAN",
network_type="lan",
cidr="192.168.1.0/24"
)
network = create_network(db, network_data)
print(f"Created network: {network.id}")
```
"""
# Verify site exists if provided
if network_data.site_id:
site = db.query(Site).filter(Site.id == str(network_data.site_id)).first()
if not site:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Site with ID {network_data.site_id} not found"
)
try:
# Create new network instance
db_network = Network(**network_data.model_dump())
# Add to database
db.add(db_network)
db.commit()
db.refresh(db_network)
return db_network
except IntegrityError as e:
db.rollback()
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Database error: {str(e)}"
)
except Exception as e:
db.rollback()
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to create network: {str(e)}"
)
def update_network(db: Session, network_id: UUID, network_data: NetworkUpdate) -> Network:
"""
Update an existing network.
Args:
db: Database session
network_id: UUID of the network to update
network_data: Network update data (only provided fields will be updated)
Returns:
Network: The updated network object
Raises:
HTTPException: 404 if network or site not found
HTTPException: 500 if database error occurs
Example:
```python
update_data = NetworkUpdate(
network_name="Main LAN - Upgraded",
gateway_ip="192.168.1.1"
)
network = update_network(db, network_id, update_data)
print(f"Updated network: {network.network_name}")
```
"""
# Get existing network
network = get_network_by_id(db, network_id)
try:
# Update only provided fields
update_data = network_data.model_dump(exclude_unset=True)
# If updating site_id, verify new site exists
if "site_id" in update_data and update_data["site_id"] is not None:
site = db.query(Site).filter(Site.id == str(update_data["site_id"])).first()
if not site:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Site with ID {update_data['site_id']} not found"
)
# Apply updates
for field, value in update_data.items():
setattr(network, field, value)
db.commit()
db.refresh(network)
return network
except HTTPException:
db.rollback()
raise
except IntegrityError as e:
db.rollback()
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Database error: {str(e)}"
)
except Exception as e:
db.rollback()
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to update network: {str(e)}"
)
def delete_network(db: Session, network_id: UUID) -> dict:
"""
Delete a network by its ID.
Args:
db: Database session
network_id: UUID of the network to delete
Returns:
dict: Success message
Raises:
HTTPException: 404 if network not found
HTTPException: 500 if database error occurs
Example:
```python
result = delete_network(db, network_id)
print(result["message"]) # "Network deleted successfully"
```
"""
# Get existing network (raises 404 if not found)
network = get_network_by_id(db, network_id)
try:
db.delete(network)
db.commit()
return {
"message": "Network deleted successfully",
"network_id": str(network_id)
}
except Exception as e:
db.rollback()
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to delete network: {str(e)}"
)

View File

@@ -0,0 +1,394 @@
"""
Project service layer for business logic and database operations.
This module handles all database operations for projects, providing a clean
separation between the API routes and data access layer.
"""
from typing import Optional
from uuid import UUID
from fastapi import HTTPException, status
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
from api.models.project import Project
from api.models.client import Client
from api.schemas.project import ProjectCreate, ProjectUpdate
def get_projects(db: Session, skip: int = 0, limit: int = 100) -> tuple[list[Project], int]:
"""
Retrieve a paginated list of projects.
Args:
db: Database session
skip: Number of records to skip (for pagination)
limit: Maximum number of records to return
Returns:
tuple: (list of projects, total count)
Example:
```python
projects, total = get_projects(db, skip=0, limit=50)
print(f"Retrieved {len(projects)} of {total} projects")
```
"""
# Get total count
total = db.query(Project).count()
# Get paginated results, ordered by created_at descending (newest first)
projects = (
db.query(Project)
.order_by(Project.created_at.desc())
.offset(skip)
.limit(limit)
.all()
)
return projects, total
def get_project_by_id(db: Session, project_id: UUID) -> Project:
"""
Retrieve a single project by its ID.
Args:
db: Database session
project_id: UUID of the project to retrieve
Returns:
Project: The project object
Raises:
HTTPException: 404 if project not found
Example:
```python
project = get_project_by_id(db, project_id)
print(f"Found project: {project.name}")
```
"""
project = db.query(Project).filter(Project.id == str(project_id)).first()
if not project:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Project with ID {project_id} not found"
)
return project
def get_project_by_slug(db: Session, slug: str) -> Optional[Project]:
"""
Retrieve a project by its slug.
Args:
db: Database session
slug: Slug to search for
Returns:
Optional[Project]: The project if found, None otherwise
Example:
```python
project = get_project_by_slug(db, "dataforth-dos")
if project:
print(f"Found project: {project.name}")
```
"""
return db.query(Project).filter(Project.slug == slug).first()
def get_projects_by_client(db: Session, client_id: str, skip: int = 0, limit: int = 100) -> tuple[list[Project], int]:
"""
Retrieve projects for a specific client.
Args:
db: Database session
client_id: Client UUID
skip: Number of records to skip
limit: Maximum number of records to return
Returns:
tuple: (list of projects, total count)
Example:
```python
projects, total = get_projects_by_client(db, client_id)
print(f"Client has {total} projects")
```
"""
total = db.query(Project).filter(Project.client_id == str(client_id)).count()
projects = (
db.query(Project)
.filter(Project.client_id == str(client_id))
.order_by(Project.created_at.desc())
.offset(skip)
.limit(limit)
.all()
)
return projects, total
def get_projects_by_status(db: Session, status_filter: str, skip: int = 0, limit: int = 100) -> tuple[list[Project], int]:
"""
Retrieve projects by status.
Args:
db: Database session
status_filter: Status to filter by (complete, working, blocked, pending, critical, deferred)
skip: Number of records to skip
limit: Maximum number of records to return
Returns:
tuple: (list of projects, total count)
Example:
```python
projects, total = get_projects_by_status(db, "working")
print(f"Found {total} working projects")
```
"""
total = db.query(Project).filter(Project.status == status_filter).count()
projects = (
db.query(Project)
.filter(Project.status == status_filter)
.order_by(Project.created_at.desc())
.offset(skip)
.limit(limit)
.all()
)
return projects, total
def validate_client_exists(db: Session, client_id: str) -> None:
"""
Validate that a client exists.
Args:
db: Database session
client_id: Client UUID to validate
Raises:
HTTPException: 404 if client not found
Example:
```python
validate_client_exists(db, client_id)
# Continues if client exists, raises HTTPException if not
```
"""
client = db.query(Client).filter(Client.id == str(client_id)).first()
if not client:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Client with ID {client_id} not found"
)
def create_project(db: Session, project_data: ProjectCreate) -> Project:
"""
Create a new project.
Args:
db: Database session
project_data: Project creation data
Returns:
Project: The created project object
Raises:
HTTPException: 404 if client not found
HTTPException: 409 if project with slug already exists
HTTPException: 500 if database error occurs
Example:
```python
project_data = ProjectCreate(
client_id="123e4567-e89b-12d3-a456-426614174000",
name="Client Website Redesign",
status="working"
)
project = create_project(db, project_data)
print(f"Created project: {project.id}")
```
"""
# Validate client exists
validate_client_exists(db, project_data.client_id)
# Check if project with slug already exists (if slug is provided)
if project_data.slug:
existing_project = get_project_by_slug(db, project_data.slug)
if existing_project:
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail=f"Project with slug '{project_data.slug}' already exists"
)
try:
# Create new project instance
db_project = Project(**project_data.model_dump())
# Add to database
db.add(db_project)
db.commit()
db.refresh(db_project)
return db_project
except IntegrityError as e:
db.rollback()
# Handle unique constraint violations
if "slug" in str(e.orig):
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail=f"Project with slug '{project_data.slug}' already exists"
)
elif "client_id" in str(e.orig):
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Client with ID {project_data.client_id} not found"
)
else:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Database error: {str(e)}"
)
except Exception as e:
db.rollback()
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to create project: {str(e)}"
)
def update_project(db: Session, project_id: UUID, project_data: ProjectUpdate) -> Project:
"""
Update an existing project.
Args:
db: Database session
project_id: UUID of the project to update
project_data: Project update data (only provided fields will be updated)
Returns:
Project: The updated project object
Raises:
HTTPException: 404 if project or client not found
HTTPException: 409 if update would violate unique constraints
HTTPException: 500 if database error occurs
Example:
```python
update_data = ProjectUpdate(
status="completed",
completed_date=date.today()
)
project = update_project(db, project_id, update_data)
print(f"Updated project: {project.name}")
```
"""
# Get existing project
project = get_project_by_id(db, project_id)
try:
# Update only provided fields
update_data = project_data.model_dump(exclude_unset=True)
# If updating client_id, validate client exists
if "client_id" in update_data and update_data["client_id"] != project.client_id:
validate_client_exists(db, update_data["client_id"])
# If updating slug, check if new slug is already taken
if "slug" in update_data and update_data["slug"] and update_data["slug"] != project.slug:
existing = get_project_by_slug(db, update_data["slug"])
if existing:
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail=f"Project with slug '{update_data['slug']}' already exists"
)
# Apply updates
for field, value in update_data.items():
setattr(project, field, value)
db.commit()
db.refresh(project)
return project
except HTTPException:
db.rollback()
raise
except IntegrityError as e:
db.rollback()
if "slug" in str(e.orig):
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail="Project with this slug already exists"
)
elif "client_id" in str(e.orig):
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Client not found"
)
else:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Database error: {str(e)}"
)
except Exception as e:
db.rollback()
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to update project: {str(e)}"
)
def delete_project(db: Session, project_id: UUID) -> dict:
"""
Delete a project by its ID.
Args:
db: Database session
project_id: UUID of the project to delete
Returns:
dict: Success message
Raises:
HTTPException: 404 if project not found
HTTPException: 500 if database error occurs
Example:
```python
result = delete_project(db, project_id)
print(result["message"]) # "Project deleted successfully"
```
"""
# Get existing project (raises 404 if not found)
project = get_project_by_id(db, project_id)
try:
db.delete(project)
db.commit()
return {
"message": "Project deleted successfully",
"project_id": str(project_id)
}
except Exception as e:
db.rollback()
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to delete project: {str(e)}"
)

View File

@@ -0,0 +1,273 @@
"""
ProjectState service layer for business logic and database operations.
Handles all database operations for project states, tracking the current
state of projects for quick context retrieval.
"""
from typing import Optional
from uuid import UUID
from fastapi import HTTPException, status
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
from api.models.project_state import ProjectState
from api.schemas.project_state import ProjectStateCreate, ProjectStateUpdate
from api.utils.context_compression import compress_project_state
def get_project_states(
db: Session,
skip: int = 0,
limit: int = 100
) -> tuple[list[ProjectState], int]:
"""
Retrieve a paginated list of project states.
Args:
db: Database session
skip: Number of records to skip (for pagination)
limit: Maximum number of records to return
Returns:
tuple: (list of project states, total count)
"""
# Get total count
total = db.query(ProjectState).count()
# Get paginated results, ordered by most recently updated
states = (
db.query(ProjectState)
.order_by(ProjectState.updated_at.desc())
.offset(skip)
.limit(limit)
.all()
)
return states, total
def get_project_state_by_id(db: Session, state_id: UUID) -> ProjectState:
"""
Retrieve a single project state by its ID.
Args:
db: Database session
state_id: UUID of the project state to retrieve
Returns:
ProjectState: The project state object
Raises:
HTTPException: 404 if project state not found
"""
state = db.query(ProjectState).filter(ProjectState.id == str(state_id)).first()
if not state:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"ProjectState with ID {state_id} not found"
)
return state
def get_project_state_by_project(db: Session, project_id: UUID) -> Optional[ProjectState]:
"""
Retrieve the project state for a specific project.
Each project has exactly one project state (unique constraint).
Args:
db: Database session
project_id: UUID of the project
Returns:
Optional[ProjectState]: The project state if found, None otherwise
"""
state = db.query(ProjectState).filter(ProjectState.project_id == str(project_id)).first()
return state
def create_project_state(
db: Session,
state_data: ProjectStateCreate
) -> ProjectState:
"""
Create a new project state.
Args:
db: Database session
state_data: Project state creation data
Returns:
ProjectState: The created project state object
Raises:
HTTPException: 409 if project state already exists for this project
HTTPException: 500 if database error occurs
"""
# Check if project state already exists for this project
existing_state = get_project_state_by_project(db, state_data.project_id)
if existing_state:
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail=f"ProjectState for project ID {state_data.project_id} already exists"
)
try:
# Create new project state instance
db_state = ProjectState(**state_data.model_dump())
# Add to database
db.add(db_state)
db.commit()
db.refresh(db_state)
return db_state
except IntegrityError as e:
db.rollback()
if "project_id" in str(e.orig):
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail=f"ProjectState for project ID {state_data.project_id} already exists"
)
else:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Database error: {str(e)}"
)
except Exception as e:
db.rollback()
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to create project state: {str(e)}"
)
def update_project_state(
db: Session,
state_id: UUID,
state_data: ProjectStateUpdate
) -> ProjectState:
"""
Update an existing project state.
Uses compression utilities when updating to maintain efficient storage.
Args:
db: Database session
state_id: UUID of the project state to update
state_data: Project state update data
Returns:
ProjectState: The updated project state object
Raises:
HTTPException: 404 if project state not found
HTTPException: 500 if database error occurs
"""
# Get existing state
state = get_project_state_by_id(db, state_id)
try:
# Update only provided fields
update_data = state_data.model_dump(exclude_unset=True)
# Apply updates
for field, value in update_data.items():
setattr(state, field, value)
db.commit()
db.refresh(state)
return state
except HTTPException:
db.rollback()
raise
except IntegrityError as e:
db.rollback()
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Database error: {str(e)}"
)
except Exception as e:
db.rollback()
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to update project state: {str(e)}"
)
def update_project_state_by_project(
db: Session,
project_id: UUID,
state_data: ProjectStateUpdate
) -> ProjectState:
"""
Update project state by project ID (convenience method).
If project state doesn't exist, creates a new one.
Args:
db: Database session
project_id: UUID of the project
state_data: Project state update data
Returns:
ProjectState: The updated or created project state object
Raises:
HTTPException: 500 if database error occurs
"""
# Try to get existing state
state = get_project_state_by_project(db, project_id)
if state:
# Update existing state
return update_project_state(db, UUID(state.id), state_data)
else:
# Create new state
create_data = ProjectStateCreate(
project_id=project_id,
**state_data.model_dump(exclude_unset=True)
)
return create_project_state(db, create_data)
def delete_project_state(db: Session, state_id: UUID) -> dict:
"""
Delete a project state by its ID.
Args:
db: Database session
state_id: UUID of the project state to delete
Returns:
dict: Success message
Raises:
HTTPException: 404 if project state not found
HTTPException: 500 if database error occurs
"""
# Get existing state (raises 404 if not found)
state = get_project_state_by_id(db, state_id)
try:
db.delete(state)
db.commit()
return {
"message": "ProjectState deleted successfully",
"state_id": str(state_id)
}
except Exception as e:
db.rollback()
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to delete project state: {str(e)}"
)

View File

@@ -0,0 +1,335 @@
"""
Security incident service layer for business logic and database operations.
This module handles all database operations for security incidents.
"""
from typing import Optional
from uuid import UUID
from fastapi import HTTPException, status
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
from api.models.security_incident import SecurityIncident
from api.schemas.security_incident import SecurityIncidentCreate, SecurityIncidentUpdate
def get_security_incidents(db: Session, skip: int = 0, limit: int = 100) -> tuple[list[SecurityIncident], int]:
"""
Retrieve a paginated list of security incidents.
Args:
db: Database session
skip: Number of records to skip (for pagination)
limit: Maximum number of records to return
Returns:
tuple: (list of security incidents, total count)
Example:
```python
incidents, total = get_security_incidents(db, skip=0, limit=50)
print(f"Retrieved {len(incidents)} of {total} security incidents")
```
"""
# Get total count
total = db.query(SecurityIncident).count()
# Get paginated results, ordered by incident_date descending (most recent first)
incidents = (
db.query(SecurityIncident)
.order_by(SecurityIncident.incident_date.desc())
.offset(skip)
.limit(limit)
.all()
)
return incidents, total
def get_security_incident_by_id(db: Session, incident_id: UUID) -> SecurityIncident:
"""
Retrieve a single security incident by its ID.
Args:
db: Database session
incident_id: UUID of the security incident to retrieve
Returns:
SecurityIncident: The security incident object
Raises:
HTTPException: 404 if security incident not found
Example:
```python
incident = get_security_incident_by_id(db, incident_id)
print(f"Found incident: {incident.incident_type} - {incident.severity}")
```
"""
incident = db.query(SecurityIncident).filter(SecurityIncident.id == str(incident_id)).first()
if not incident:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Security incident with ID {incident_id} not found"
)
return incident
def get_security_incidents_by_client(
db: Session,
client_id: UUID,
skip: int = 0,
limit: int = 100
) -> tuple[list[SecurityIncident], int]:
"""
Retrieve security incidents for a specific client.
Args:
db: Database session
client_id: UUID of the client
skip: Number of records to skip
limit: Maximum number of records to return
Returns:
tuple: (list of security incidents, total count)
Example:
```python
incidents, total = get_security_incidents_by_client(db, client_id, skip=0, limit=50)
print(f"Client has {total} security incidents")
```
"""
# Get total count for this client
total = db.query(SecurityIncident).filter(SecurityIncident.client_id == str(client_id)).count()
# Get paginated results
incidents = (
db.query(SecurityIncident)
.filter(SecurityIncident.client_id == str(client_id))
.order_by(SecurityIncident.incident_date.desc())
.offset(skip)
.limit(limit)
.all()
)
return incidents, total
def get_security_incidents_by_status(
db: Session,
status_filter: str,
skip: int = 0,
limit: int = 100
) -> tuple[list[SecurityIncident], int]:
"""
Retrieve security incidents by status.
Args:
db: Database session
status_filter: Status to filter by (investigating, contained, resolved, monitoring)
skip: Number of records to skip
limit: Maximum number of records to return
Returns:
tuple: (list of security incidents, total count)
Example:
```python
incidents, total = get_security_incidents_by_status(db, "investigating", skip=0, limit=50)
print(f"Found {total} incidents under investigation")
```
"""
# Get total count for this status
total = db.query(SecurityIncident).filter(SecurityIncident.status == status_filter).count()
# Get paginated results
incidents = (
db.query(SecurityIncident)
.filter(SecurityIncident.status == status_filter)
.order_by(SecurityIncident.incident_date.desc())
.offset(skip)
.limit(limit)
.all()
)
return incidents, total
def create_security_incident(db: Session, incident_data: SecurityIncidentCreate) -> SecurityIncident:
"""
Create a new security incident.
Args:
db: Database session
incident_data: Security incident creation data
Returns:
SecurityIncident: The created security incident object
Raises:
HTTPException: 500 if database error occurs
Example:
```python
incident_data = SecurityIncidentCreate(
client_id="client-uuid",
incident_type="malware",
incident_date=datetime.utcnow(),
severity="high",
description="Malware detected on workstation",
status="investigating"
)
incident = create_security_incident(db, incident_data)
print(f"Created incident: {incident.id}")
```
"""
try:
# Convert Pydantic model to dict
data = incident_data.model_dump()
# Convert UUID fields to strings
if data.get("client_id"):
data["client_id"] = str(data["client_id"])
if data.get("service_id"):
data["service_id"] = str(data["service_id"])
if data.get("infrastructure_id"):
data["infrastructure_id"] = str(data["infrastructure_id"])
# Create new security incident instance
db_incident = SecurityIncident(**data)
# Add to database
db.add(db_incident)
db.commit()
db.refresh(db_incident)
return db_incident
except IntegrityError as e:
db.rollback()
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Database integrity error: {str(e)}"
)
except Exception as e:
db.rollback()
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to create security incident: {str(e)}"
)
def update_security_incident(
db: Session,
incident_id: UUID,
incident_data: SecurityIncidentUpdate
) -> SecurityIncident:
"""
Update an existing security incident.
Args:
db: Database session
incident_id: UUID of the security incident to update
incident_data: Security incident update data (only provided fields will be updated)
Returns:
SecurityIncident: The updated security incident object
Raises:
HTTPException: 404 if security incident not found
HTTPException: 500 if database error occurs
Example:
```python
update_data = SecurityIncidentUpdate(
status="contained",
remediation_steps="Malware removed, system scanned clean"
)
incident = update_security_incident(db, incident_id, update_data)
print(f"Updated incident: {incident.status}")
```
"""
# Get existing security incident
incident = get_security_incident_by_id(db, incident_id)
try:
# Update only provided fields
update_data = incident_data.model_dump(exclude_unset=True)
# Convert UUID fields to strings
if "client_id" in update_data and update_data["client_id"]:
update_data["client_id"] = str(update_data["client_id"])
if "service_id" in update_data and update_data["service_id"]:
update_data["service_id"] = str(update_data["service_id"])
if "infrastructure_id" in update_data and update_data["infrastructure_id"]:
update_data["infrastructure_id"] = str(update_data["infrastructure_id"])
# Apply updates
for field, value in update_data.items():
setattr(incident, field, value)
db.commit()
db.refresh(incident)
return incident
except HTTPException:
db.rollback()
raise
except IntegrityError as e:
db.rollback()
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Database integrity error: {str(e)}"
)
except Exception as e:
db.rollback()
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to update security incident: {str(e)}"
)
def delete_security_incident(db: Session, incident_id: UUID) -> dict:
"""
Delete a security incident by its ID.
Args:
db: Database session
incident_id: UUID of the security incident to delete
Returns:
dict: Success message
Raises:
HTTPException: 404 if security incident not found
HTTPException: 500 if database error occurs
Example:
```python
result = delete_security_incident(db, incident_id)
print(result["message"]) # "Security incident deleted successfully"
```
"""
# Get existing security incident (raises 404 if not found)
incident = get_security_incident_by_id(db, incident_id)
try:
db.delete(incident)
db.commit()
return {
"message": "Security incident deleted successfully",
"incident_id": str(incident_id)
}
except Exception as e:
db.rollback()
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to delete security incident: {str(e)}"
)

View File

@@ -0,0 +1,384 @@
"""
Service service layer for business logic and database operations.
This module handles all database operations for services, providing a clean
separation between the API routes and data access layer.
"""
from typing import Optional
from uuid import UUID
from fastapi import HTTPException, status
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
from api.models.service import Service
from api.models.infrastructure import Infrastructure
from api.schemas.service import ServiceCreate, ServiceUpdate
def get_services(db: Session, skip: int = 0, limit: int = 100) -> tuple[list[Service], int]:
"""
Retrieve a paginated list of services.
Args:
db: Database session
skip: Number of records to skip (for pagination)
limit: Maximum number of records to return
Returns:
tuple: (list of services, total count)
Example:
```python
services, total = get_services(db, skip=0, limit=50)
print(f"Retrieved {len(services)} of {total} services")
```
"""
# Get total count
total = db.query(Service).count()
# Get paginated results, ordered by created_at descending (newest first)
services = (
db.query(Service)
.order_by(Service.created_at.desc())
.offset(skip)
.limit(limit)
.all()
)
return services, total
def get_service_by_id(db: Session, service_id: UUID) -> Service:
"""
Retrieve a single service by its ID.
Args:
db: Database session
service_id: UUID of the service to retrieve
Returns:
Service: The service object
Raises:
HTTPException: 404 if service not found
Example:
```python
service = get_service_by_id(db, service_id)
print(f"Found service: {service.service_name}")
```
"""
service = db.query(Service).filter(Service.id == str(service_id)).first()
if not service:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Service with ID {service_id} not found"
)
return service
def get_services_by_client(db: Session, client_id: str, skip: int = 0, limit: int = 100) -> tuple[list[Service], int]:
"""
Retrieve services for a specific client (via infrastructure).
Args:
db: Database session
client_id: Client UUID
skip: Number of records to skip
limit: Maximum number of records to return
Returns:
tuple: (list of services, total count)
Example:
```python
services, total = get_services_by_client(db, client_id)
print(f"Client has {total} services")
```
"""
# Join with Infrastructure to filter by client_id
query = (
db.query(Service)
.join(Infrastructure, Service.infrastructure_id == Infrastructure.id)
.filter(Infrastructure.client_id == str(client_id))
)
total = query.count()
services = (
query
.order_by(Service.created_at.desc())
.offset(skip)
.limit(limit)
.all()
)
return services, total
def get_services_by_type(db: Session, service_type: str, skip: int = 0, limit: int = 100) -> tuple[list[Service], int]:
"""
Retrieve services by type.
Args:
db: Database session
service_type: Service type to filter by (e.g., 'git_hosting', 'database')
skip: Number of records to skip
limit: Maximum number of records to return
Returns:
tuple: (list of services, total count)
Example:
```python
services, total = get_services_by_type(db, "database")
print(f"Found {total} database services")
```
"""
total = db.query(Service).filter(Service.service_type == service_type).count()
services = (
db.query(Service)
.filter(Service.service_type == service_type)
.order_by(Service.created_at.desc())
.offset(skip)
.limit(limit)
.all()
)
return services, total
def get_services_by_status(db: Session, status_filter: str, skip: int = 0, limit: int = 100) -> tuple[list[Service], int]:
"""
Retrieve services by status.
Args:
db: Database session
status_filter: Status to filter by (running, stopped, error, maintenance)
skip: Number of records to skip
limit: Maximum number of records to return
Returns:
tuple: (list of services, total count)
Example:
```python
services, total = get_services_by_status(db, "running")
print(f"Found {total} running services")
```
"""
total = db.query(Service).filter(Service.status == status_filter).count()
services = (
db.query(Service)
.filter(Service.status == status_filter)
.order_by(Service.created_at.desc())
.offset(skip)
.limit(limit)
.all()
)
return services, total
def validate_infrastructure_exists(db: Session, infrastructure_id: str) -> None:
"""
Validate that infrastructure exists.
Args:
db: Database session
infrastructure_id: Infrastructure UUID to validate
Raises:
HTTPException: 404 if infrastructure not found
Example:
```python
validate_infrastructure_exists(db, infrastructure_id)
# Continues if infrastructure exists, raises HTTPException if not
```
"""
infrastructure = db.query(Infrastructure).filter(Infrastructure.id == str(infrastructure_id)).first()
if not infrastructure:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Infrastructure with ID {infrastructure_id} not found"
)
def create_service(db: Session, service_data: ServiceCreate) -> Service:
"""
Create a new service.
Args:
db: Database session
service_data: Service creation data
Returns:
Service: The created service object
Raises:
HTTPException: 404 if infrastructure not found
HTTPException: 500 if database error occurs
Example:
```python
service_data = ServiceCreate(
infrastructure_id="123e4567-e89b-12d3-a456-426614174000",
service_name="Gitea",
service_type="git_hosting",
status="running"
)
service = create_service(db, service_data)
print(f"Created service: {service.id}")
```
"""
# Validate infrastructure exists if provided
if service_data.infrastructure_id:
validate_infrastructure_exists(db, service_data.infrastructure_id)
try:
# Create new service instance
db_service = Service(**service_data.model_dump())
# Add to database
db.add(db_service)
db.commit()
db.refresh(db_service)
return db_service
except IntegrityError as e:
db.rollback()
# Handle foreign key constraint violations
if "infrastructure_id" in str(e.orig):
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Infrastructure with ID {service_data.infrastructure_id} not found"
)
else:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Database error: {str(e)}"
)
except Exception as e:
db.rollback()
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to create service: {str(e)}"
)
def update_service(db: Session, service_id: UUID, service_data: ServiceUpdate) -> Service:
"""
Update an existing service.
Args:
db: Database session
service_id: UUID of the service to update
service_data: Service update data (only provided fields will be updated)
Returns:
Service: The updated service object
Raises:
HTTPException: 404 if service or infrastructure not found
HTTPException: 500 if database error occurs
Example:
```python
update_data = ServiceUpdate(
status="stopped",
notes="Service temporarily stopped for maintenance"
)
service = update_service(db, service_id, update_data)
print(f"Updated service: {service.service_name}")
```
"""
# Get existing service
service = get_service_by_id(db, service_id)
try:
# Update only provided fields
update_data = service_data.model_dump(exclude_unset=True)
# If updating infrastructure_id, validate infrastructure exists
if "infrastructure_id" in update_data and update_data["infrastructure_id"] and update_data["infrastructure_id"] != service.infrastructure_id:
validate_infrastructure_exists(db, update_data["infrastructure_id"])
# Apply updates
for field, value in update_data.items():
setattr(service, field, value)
db.commit()
db.refresh(service)
return service
except HTTPException:
db.rollback()
raise
except IntegrityError as e:
db.rollback()
if "infrastructure_id" in str(e.orig):
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Infrastructure not found"
)
else:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Database error: {str(e)}"
)
except Exception as e:
db.rollback()
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to update service: {str(e)}"
)
def delete_service(db: Session, service_id: UUID) -> dict:
"""
Delete a service by its ID.
Args:
db: Database session
service_id: UUID of the service to delete
Returns:
dict: Success message
Raises:
HTTPException: 404 if service not found
HTTPException: 500 if database error occurs
Example:
```python
result = delete_service(db, service_id)
print(result["message"]) # "Service deleted successfully"
```
"""
# Get existing service (raises 404 if not found)
service = get_service_by_id(db, service_id)
try:
db.delete(service)
db.commit()
return {
"message": "Service deleted successfully",
"service_id": str(service_id)
}
except Exception as e:
db.rollback()
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to delete service: {str(e)}"
)

View File

@@ -0,0 +1,375 @@
"""
Session service layer for business logic and database operations.
This module handles all database operations for sessions, providing a clean
separation between the API routes and data access layer.
"""
from typing import Optional
from uuid import UUID
from fastapi import HTTPException, status
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
from api.models.session import Session as SessionModel
from api.models.project import Project
from api.models.machine import Machine
from api.schemas.session import SessionCreate, SessionUpdate
def get_sessions(db: Session, skip: int = 0, limit: int = 100) -> tuple[list[SessionModel], int]:
"""
Retrieve a paginated list of sessions.
Args:
db: Database session
skip: Number of records to skip (for pagination)
limit: Maximum number of records to return
Returns:
tuple: (list of sessions, total count)
Example:
```python
sessions, total = get_sessions(db, skip=0, limit=50)
print(f"Retrieved {len(sessions)} of {total} sessions")
```
"""
# Get total count
total = db.query(SessionModel).count()
# Get paginated results, ordered by session_date descending (newest first)
sessions = (
db.query(SessionModel)
.order_by(SessionModel.session_date.desc())
.offset(skip)
.limit(limit)
.all()
)
return sessions, total
def get_session_by_id(db: Session, session_id: UUID) -> SessionModel:
"""
Retrieve a single session by its ID.
Args:
db: Database session
session_id: UUID of the session to retrieve
Returns:
SessionModel: The session object
Raises:
HTTPException: 404 if session not found
Example:
```python
session = get_session_by_id(db, session_id)
print(f"Found session: {session.session_title}")
```
"""
session = db.query(SessionModel).filter(SessionModel.id == str(session_id)).first()
if not session:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Session with ID {session_id} not found"
)
return session
def get_sessions_by_project(db: Session, project_id: UUID, skip: int = 0, limit: int = 100) -> tuple[list[SessionModel], int]:
"""
Retrieve sessions for a specific project.
Args:
db: Database session
project_id: UUID of the project
skip: Number of records to skip (for pagination)
limit: Maximum number of records to return
Returns:
tuple: (list of sessions, total count)
Example:
```python
sessions, total = get_sessions_by_project(db, project_id)
print(f"Found {total} sessions for project")
```
"""
# Get total count
total = db.query(SessionModel).filter(SessionModel.project_id == str(project_id)).count()
# Get paginated results
sessions = (
db.query(SessionModel)
.filter(SessionModel.project_id == str(project_id))
.order_by(SessionModel.session_date.desc())
.offset(skip)
.limit(limit)
.all()
)
return sessions, total
def get_sessions_by_machine(db: Session, machine_id: UUID, skip: int = 0, limit: int = 100) -> tuple[list[SessionModel], int]:
"""
Retrieve sessions for a specific machine.
Args:
db: Database session
machine_id: UUID of the machine
skip: Number of records to skip (for pagination)
limit: Maximum number of records to return
Returns:
tuple: (list of sessions, total count)
Example:
```python
sessions, total = get_sessions_by_machine(db, machine_id)
print(f"Found {total} sessions on this machine")
```
"""
# Get total count
total = db.query(SessionModel).filter(SessionModel.machine_id == str(machine_id)).count()
# Get paginated results
sessions = (
db.query(SessionModel)
.filter(SessionModel.machine_id == str(machine_id))
.order_by(SessionModel.session_date.desc())
.offset(skip)
.limit(limit)
.all()
)
return sessions, total
def create_session(db: Session, session_data: SessionCreate) -> SessionModel:
"""
Create a new session.
Args:
db: Database session
session_data: Session creation data
Returns:
SessionModel: The created session object
Raises:
HTTPException: 404 if referenced project or machine not found
HTTPException: 422 if validation fails
HTTPException: 500 if database error occurs
Example:
```python
session_data = SessionCreate(
session_title="Database migration work",
session_date=date.today(),
project_id="123e4567-e89b-12d3-a456-426614174000"
)
session = create_session(db, session_data)
print(f"Created session: {session.id}")
```
"""
try:
# Validate foreign keys if provided
if session_data.project_id:
project = db.query(Project).filter(Project.id == str(session_data.project_id)).first()
if not project:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Project with ID {session_data.project_id} not found"
)
if session_data.machine_id:
machine = db.query(Machine).filter(Machine.id == str(session_data.machine_id)).first()
if not machine:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Machine with ID {session_data.machine_id} not found"
)
# Create new session instance
db_session = SessionModel(**session_data.model_dump())
# Add to database
db.add(db_session)
db.commit()
db.refresh(db_session)
return db_session
except HTTPException:
db.rollback()
raise
except IntegrityError as e:
db.rollback()
# Handle foreign key constraint violations
if "project_id" in str(e.orig):
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail=f"Invalid project_id: {session_data.project_id}"
)
elif "machine_id" in str(e.orig):
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail=f"Invalid machine_id: {session_data.machine_id}"
)
elif "client_id" in str(e.orig):
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail=f"Invalid client_id: {session_data.client_id}"
)
else:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Database error: {str(e)}"
)
except Exception as e:
db.rollback()
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to create session: {str(e)}"
)
def update_session(db: Session, session_id: UUID, session_data: SessionUpdate) -> SessionModel:
"""
Update an existing session.
Args:
db: Database session
session_id: UUID of the session to update
session_data: Session update data (only provided fields will be updated)
Returns:
SessionModel: The updated session object
Raises:
HTTPException: 404 if session, project, or machine not found
HTTPException: 422 if validation fails
HTTPException: 500 if database error occurs
Example:
```python
update_data = SessionUpdate(
status="completed",
duration_minutes=120
)
session = update_session(db, session_id, update_data)
print(f"Updated session: {session.session_title}")
```
"""
# Get existing session
session = get_session_by_id(db, session_id)
try:
# Update only provided fields
update_data = session_data.model_dump(exclude_unset=True)
# Validate foreign keys if being updated
if "project_id" in update_data and update_data["project_id"]:
project = db.query(Project).filter(Project.id == str(update_data["project_id"])).first()
if not project:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Project with ID {update_data['project_id']} not found"
)
if "machine_id" in update_data and update_data["machine_id"]:
machine = db.query(Machine).filter(Machine.id == str(update_data["machine_id"])).first()
if not machine:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Machine with ID {update_data['machine_id']} not found"
)
# Apply updates
for field, value in update_data.items():
setattr(session, field, value)
db.commit()
db.refresh(session)
return session
except HTTPException:
db.rollback()
raise
except IntegrityError as e:
db.rollback()
if "project_id" in str(e.orig):
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail="Invalid project_id"
)
elif "machine_id" in str(e.orig):
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail="Invalid machine_id"
)
elif "client_id" in str(e.orig):
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail="Invalid client_id"
)
else:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Database error: {str(e)}"
)
except Exception as e:
db.rollback()
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to update session: {str(e)}"
)
def delete_session(db: Session, session_id: UUID) -> dict:
"""
Delete a session by its ID.
Args:
db: Database session
session_id: UUID of the session to delete
Returns:
dict: Success message
Raises:
HTTPException: 404 if session not found
HTTPException: 500 if database error occurs
Example:
```python
result = delete_session(db, session_id)
print(result["message"]) # "Session deleted successfully"
```
"""
# Get existing session (raises 404 if not found)
session = get_session_by_id(db, session_id)
try:
db.delete(session)
db.commit()
return {
"message": "Session deleted successfully",
"session_id": str(session_id)
}
except Exception as e:
db.rollback()
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to delete session: {str(e)}"
)

View File

@@ -0,0 +1,295 @@
"""
Site service layer for business logic and database operations.
This module handles all database operations for sites, providing a clean
separation between the API routes and data access layer.
"""
from typing import Optional
from uuid import UUID
from fastapi import HTTPException, status
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
from api.models.site import Site
from api.models.client import Client
from api.schemas.site import SiteCreate, SiteUpdate
def get_sites(db: Session, skip: int = 0, limit: int = 100) -> tuple[list[Site], int]:
"""
Retrieve a paginated list of sites.
Args:
db: Database session
skip: Number of records to skip (for pagination)
limit: Maximum number of records to return
Returns:
tuple: (list of sites, total count)
Example:
```python
sites, total = get_sites(db, skip=0, limit=50)
print(f"Retrieved {len(sites)} of {total} sites")
```
"""
# Get total count
total = db.query(Site).count()
# Get paginated results, ordered by created_at descending (newest first)
sites = (
db.query(Site)
.order_by(Site.created_at.desc())
.offset(skip)
.limit(limit)
.all()
)
return sites, total
def get_site_by_id(db: Session, site_id: UUID) -> Site:
"""
Retrieve a single site by its ID.
Args:
db: Database session
site_id: UUID of the site to retrieve
Returns:
Site: The site object
Raises:
HTTPException: 404 if site not found
Example:
```python
site = get_site_by_id(db, site_id)
print(f"Found site: {site.name}")
```
"""
site = db.query(Site).filter(Site.id == str(site_id)).first()
if not site:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Site with ID {site_id} not found"
)
return site
def get_sites_by_client(db: Session, client_id: UUID, skip: int = 0, limit: int = 100) -> tuple[list[Site], int]:
"""
Retrieve sites belonging to a specific client.
Args:
db: Database session
client_id: UUID of the client
skip: Number of records to skip (for pagination)
limit: Maximum number of records to return
Returns:
tuple: (list of sites, total count for this client)
Raises:
HTTPException: 404 if client not found
Example:
```python
sites, total = get_sites_by_client(db, client_id, skip=0, limit=50)
print(f"Retrieved {len(sites)} of {total} sites for client")
```
"""
# Verify client exists
client = db.query(Client).filter(Client.id == str(client_id)).first()
if not client:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Client with ID {client_id} not found"
)
# Get total count for this client
total = db.query(Site).filter(Site.client_id == str(client_id)).count()
# Get paginated results
sites = (
db.query(Site)
.filter(Site.client_id == str(client_id))
.order_by(Site.created_at.desc())
.offset(skip)
.limit(limit)
.all()
)
return sites, total
def create_site(db: Session, site_data: SiteCreate) -> Site:
"""
Create a new site.
Args:
db: Database session
site_data: Site creation data
Returns:
Site: The created site object
Raises:
HTTPException: 404 if client not found
HTTPException: 500 if database error occurs
Example:
```python
site_data = SiteCreate(
client_id="123e4567-e89b-12d3-a456-426614174000",
name="Main Office",
network_subnet="172.16.9.0/24"
)
site = create_site(db, site_data)
print(f"Created site: {site.id}")
```
"""
# Verify client exists
client = db.query(Client).filter(Client.id == str(site_data.client_id)).first()
if not client:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Client with ID {site_data.client_id} not found"
)
try:
# Create new site instance
db_site = Site(**site_data.model_dump())
# Add to database
db.add(db_site)
db.commit()
db.refresh(db_site)
return db_site
except IntegrityError as e:
db.rollback()
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Database error: {str(e)}"
)
except Exception as e:
db.rollback()
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to create site: {str(e)}"
)
def update_site(db: Session, site_id: UUID, site_data: SiteUpdate) -> Site:
"""
Update an existing site.
Args:
db: Database session
site_id: UUID of the site to update
site_data: Site update data (only provided fields will be updated)
Returns:
Site: The updated site object
Raises:
HTTPException: 404 if site or client not found
HTTPException: 500 if database error occurs
Example:
```python
update_data = SiteUpdate(
name="Main Office - Renovated",
vpn_required=True
)
site = update_site(db, site_id, update_data)
print(f"Updated site: {site.name}")
```
"""
# Get existing site
site = get_site_by_id(db, site_id)
try:
# Update only provided fields
update_data = site_data.model_dump(exclude_unset=True)
# If updating client_id, verify new client exists
if "client_id" in update_data:
client = db.query(Client).filter(Client.id == str(update_data["client_id"])).first()
if not client:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Client with ID {update_data['client_id']} not found"
)
# Apply updates
for field, value in update_data.items():
setattr(site, field, value)
db.commit()
db.refresh(site)
return site
except HTTPException:
db.rollback()
raise
except IntegrityError as e:
db.rollback()
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Database error: {str(e)}"
)
except Exception as e:
db.rollback()
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to update site: {str(e)}"
)
def delete_site(db: Session, site_id: UUID) -> dict:
"""
Delete a site by its ID.
Args:
db: Database session
site_id: UUID of the site to delete
Returns:
dict: Success message
Raises:
HTTPException: 404 if site not found
HTTPException: 500 if database error occurs
Example:
```python
result = delete_site(db, site_id)
print(result["message"]) # "Site deleted successfully"
```
"""
# Get existing site (raises 404 if not found)
site = get_site_by_id(db, site_id)
try:
db.delete(site)
db.commit()
return {
"message": "Site deleted successfully",
"site_id": str(site_id)
}
except Exception as e:
db.rollback()
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to delete site: {str(e)}"
)

318
api/services/tag_service.py Normal file
View File

@@ -0,0 +1,318 @@
"""
Tag service layer for business logic and database operations.
This module handles all database operations for tags, providing a clean
separation between the API routes and data access layer.
"""
from typing import Optional
from uuid import UUID
from fastapi import HTTPException, status
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
from api.models.tag import Tag
from api.schemas.tag import TagCreate, TagUpdate
def get_tags(db: Session, skip: int = 0, limit: int = 100) -> tuple[list[Tag], int]:
"""
Retrieve a paginated list of tags.
Args:
db: Database session
skip: Number of records to skip (for pagination)
limit: Maximum number of records to return
Returns:
tuple: (list of tags, total count)
Example:
```python
tags, total = get_tags(db, skip=0, limit=50)
print(f"Retrieved {len(tags)} of {total} tags")
```
"""
# Get total count
total = db.query(Tag).count()
# Get paginated results, ordered by name ascending
tags = (
db.query(Tag)
.order_by(Tag.name.asc())
.offset(skip)
.limit(limit)
.all()
)
return tags, total
def get_tag_by_id(db: Session, tag_id: UUID) -> Tag:
"""
Retrieve a single tag by its ID.
Args:
db: Database session
tag_id: UUID of the tag to retrieve
Returns:
Tag: The tag object
Raises:
HTTPException: 404 if tag not found
Example:
```python
tag = get_tag_by_id(db, tag_id)
print(f"Found tag: {tag.name}")
```
"""
tag = db.query(Tag).filter(Tag.id == str(tag_id)).first()
if not tag:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Tag with ID {tag_id} not found"
)
return tag
def get_tag_by_name(db: Session, name: str) -> Optional[Tag]:
"""
Retrieve a tag by its name.
Args:
db: Database session
name: Tag name to search for
Returns:
Optional[Tag]: The tag if found, None otherwise
Example:
```python
tag = get_tag_by_name(db, "Windows")
if tag:
print(f"Found tag: {tag.description}")
```
"""
return db.query(Tag).filter(Tag.name == name).first()
def get_tags_by_category(db: Session, category: str, skip: int = 0, limit: int = 100) -> tuple[list[Tag], int]:
"""
Retrieve a paginated list of tags by category.
Args:
db: Database session
category: Category to filter by
skip: Number of records to skip (for pagination)
limit: Maximum number of records to return
Returns:
tuple: (list of tags, total count)
Example:
```python
tags, total = get_tags_by_category(db, "technology", skip=0, limit=50)
print(f"Retrieved {len(tags)} of {total} technology tags")
```
"""
# Get total count for category
total = db.query(Tag).filter(Tag.category == category).count()
# Get paginated results
tags = (
db.query(Tag)
.filter(Tag.category == category)
.order_by(Tag.name.asc())
.offset(skip)
.limit(limit)
.all()
)
return tags, total
def create_tag(db: Session, tag_data: TagCreate) -> Tag:
"""
Create a new tag.
Args:
db: Database session
tag_data: Tag creation data
Returns:
Tag: The created tag object
Raises:
HTTPException: 409 if tag with name already exists
HTTPException: 500 if database error occurs
Example:
```python
tag_data = TagCreate(
name="Windows",
category="technology",
description="Microsoft Windows operating system"
)
tag = create_tag(db, tag_data)
print(f"Created tag: {tag.id}")
```
"""
# Check if tag with name already exists
existing_tag = get_tag_by_name(db, tag_data.name)
if existing_tag:
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail=f"Tag with name '{tag_data.name}' already exists"
)
try:
# Create new tag instance
db_tag = Tag(**tag_data.model_dump())
# Add to database
db.add(db_tag)
db.commit()
db.refresh(db_tag)
return db_tag
except IntegrityError as e:
db.rollback()
# Handle unique constraint violations
if "name" in str(e.orig):
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail=f"Tag with name '{tag_data.name}' already exists"
)
else:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Database error: {str(e)}"
)
except Exception as e:
db.rollback()
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to create tag: {str(e)}"
)
def update_tag(db: Session, tag_id: UUID, tag_data: TagUpdate) -> Tag:
"""
Update an existing tag.
Args:
db: Database session
tag_id: UUID of the tag to update
tag_data: Tag update data (only provided fields will be updated)
Returns:
Tag: The updated tag object
Raises:
HTTPException: 404 if tag not found
HTTPException: 409 if update would violate unique constraints
HTTPException: 500 if database error occurs
Example:
```python
update_data = TagUpdate(
description="Updated description",
category="infrastructure"
)
tag = update_tag(db, tag_id, update_data)
print(f"Updated tag: {tag.name}")
```
"""
# Get existing tag
tag = get_tag_by_id(db, tag_id)
try:
# Update only provided fields
update_data = tag_data.model_dump(exclude_unset=True)
# If updating name, check if new name is already taken
if "name" in update_data and update_data["name"] != tag.name:
existing = get_tag_by_name(db, update_data["name"])
if existing:
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail=f"Tag with name '{update_data['name']}' already exists"
)
# Apply updates
for field, value in update_data.items():
setattr(tag, field, value)
db.commit()
db.refresh(tag)
return tag
except HTTPException:
db.rollback()
raise
except IntegrityError as e:
db.rollback()
if "name" in str(e.orig):
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail="Tag with this name already exists"
)
else:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Database error: {str(e)}"
)
except Exception as e:
db.rollback()
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to update tag: {str(e)}"
)
def delete_tag(db: Session, tag_id: UUID) -> dict:
"""
Delete a tag by its ID.
Args:
db: Database session
tag_id: UUID of the tag to delete
Returns:
dict: Success message
Raises:
HTTPException: 404 if tag not found
HTTPException: 500 if database error occurs
Example:
```python
result = delete_tag(db, tag_id)
print(result["message"]) # "Tag deleted successfully"
```
"""
# Get existing tag (raises 404 if not found)
tag = get_tag_by_id(db, tag_id)
try:
db.delete(tag)
db.commit()
return {
"message": "Tag deleted successfully",
"tag_id": str(tag_id)
}
except Exception as e:
db.rollback()
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to delete tag: {str(e)}"
)

View File

@@ -0,0 +1,449 @@
"""
Task service layer for business logic and database operations.
This module handles all database operations for tasks, providing a clean
separation between the API routes and data access layer.
"""
from typing import Optional
from uuid import UUID
from fastapi import HTTPException, status
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
from api.models.task import Task as TaskModel
from api.models.session import Session as SessionModel
from api.models.client import Client
from api.models.project import Project
from api.schemas.task import TaskCreate, TaskUpdate
def get_tasks(db: Session, skip: int = 0, limit: int = 100) -> tuple[list[TaskModel], int]:
"""
Retrieve a paginated list of tasks.
Args:
db: Database session
skip: Number of records to skip (for pagination)
limit: Maximum number of records to return
Returns:
tuple: (list of tasks, total count)
Example:
```python
tasks, total = get_tasks(db, skip=0, limit=50)
print(f"Retrieved {len(tasks)} of {total} tasks")
```
"""
# Get total count
total = db.query(TaskModel).count()
# Get paginated results, ordered by task_order ascending
tasks = (
db.query(TaskModel)
.order_by(TaskModel.task_order.asc())
.offset(skip)
.limit(limit)
.all()
)
return tasks, total
def get_task_by_id(db: Session, task_id: UUID) -> TaskModel:
"""
Retrieve a single task by its ID.
Args:
db: Database session
task_id: UUID of the task to retrieve
Returns:
TaskModel: The task object
Raises:
HTTPException: 404 if task not found
Example:
```python
task = get_task_by_id(db, task_id)
print(f"Found task: {task.title}")
```
"""
task = db.query(TaskModel).filter(TaskModel.id == str(task_id)).first()
if not task:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Task with ID {task_id} not found"
)
return task
def get_tasks_by_session(db: Session, session_id: UUID, skip: int = 0, limit: int = 100) -> tuple[list[TaskModel], int]:
"""
Retrieve tasks for a specific session.
Args:
db: Database session
session_id: UUID of the session
skip: Number of records to skip (for pagination)
limit: Maximum number of records to return
Returns:
tuple: (list of tasks, total count)
Example:
```python
tasks, total = get_tasks_by_session(db, session_id)
print(f"Found {total} tasks for session")
```
"""
# Get total count
total = db.query(TaskModel).filter(TaskModel.session_id == str(session_id)).count()
# Get paginated results
tasks = (
db.query(TaskModel)
.filter(TaskModel.session_id == str(session_id))
.order_by(TaskModel.task_order.asc())
.offset(skip)
.limit(limit)
.all()
)
return tasks, total
def get_tasks_by_status(db: Session, status_filter: str, skip: int = 0, limit: int = 100) -> tuple[list[TaskModel], int]:
"""
Retrieve tasks by status.
Args:
db: Database session
status_filter: Status to filter by (pending, in_progress, blocked, completed, cancelled)
skip: Number of records to skip (for pagination)
limit: Maximum number of records to return
Returns:
tuple: (list of tasks, total count)
Example:
```python
tasks, total = get_tasks_by_status(db, "in_progress")
print(f"Found {total} in-progress tasks")
```
"""
# Get total count
total = db.query(TaskModel).filter(TaskModel.status == status_filter).count()
# Get paginated results
tasks = (
db.query(TaskModel)
.filter(TaskModel.status == status_filter)
.order_by(TaskModel.task_order.asc())
.offset(skip)
.limit(limit)
.all()
)
return tasks, total
def create_task(db: Session, task_data: TaskCreate) -> TaskModel:
"""
Create a new task.
Args:
db: Database session
task_data: Task creation data
Returns:
TaskModel: The created task object
Raises:
HTTPException: 404 if referenced session, client, or project not found
HTTPException: 422 if validation fails
HTTPException: 500 if database error occurs
Example:
```python
task_data = TaskCreate(
title="Implement authentication",
task_order=1,
status="pending",
session_id="123e4567-e89b-12d3-a456-426614174000"
)
task = create_task(db, task_data)
print(f"Created task: {task.id}")
```
"""
try:
# Validate foreign keys if provided
if task_data.session_id:
session = db.query(SessionModel).filter(SessionModel.id == str(task_data.session_id)).first()
if not session:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Session with ID {task_data.session_id} not found"
)
if task_data.client_id:
client = db.query(Client).filter(Client.id == str(task_data.client_id)).first()
if not client:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Client with ID {task_data.client_id} not found"
)
if task_data.project_id:
project = db.query(Project).filter(Project.id == str(task_data.project_id)).first()
if not project:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Project with ID {task_data.project_id} not found"
)
if task_data.parent_task_id:
parent_task = db.query(TaskModel).filter(TaskModel.id == str(task_data.parent_task_id)).first()
if not parent_task:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Parent task with ID {task_data.parent_task_id} not found"
)
# Create new task instance
db_task = TaskModel(**task_data.model_dump())
# Add to database
db.add(db_task)
db.commit()
db.refresh(db_task)
return db_task
except HTTPException:
db.rollback()
raise
except IntegrityError as e:
db.rollback()
# Handle foreign key constraint violations
if "session_id" in str(e.orig):
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail=f"Invalid session_id: {task_data.session_id}"
)
elif "client_id" in str(e.orig):
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail=f"Invalid client_id: {task_data.client_id}"
)
elif "project_id" in str(e.orig):
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail=f"Invalid project_id: {task_data.project_id}"
)
elif "parent_task_id" in str(e.orig):
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail=f"Invalid parent_task_id: {task_data.parent_task_id}"
)
elif "ck_tasks_type" in str(e.orig):
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail=f"Invalid task_type. Must be one of: implementation, research, review, deployment, testing, documentation, bugfix, analysis"
)
elif "ck_tasks_status" in str(e.orig):
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail=f"Invalid status. Must be one of: pending, in_progress, blocked, completed, cancelled"
)
elif "ck_tasks_complexity" in str(e.orig):
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail=f"Invalid estimated_complexity. Must be one of: trivial, simple, moderate, complex, very_complex"
)
else:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Database error: {str(e)}"
)
except Exception as e:
db.rollback()
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to create task: {str(e)}"
)
def update_task(db: Session, task_id: UUID, task_data: TaskUpdate) -> TaskModel:
"""
Update an existing task.
Args:
db: Database session
task_id: UUID of the task to update
task_data: Task update data (only provided fields will be updated)
Returns:
TaskModel: The updated task object
Raises:
HTTPException: 404 if task, session, client, or project not found
HTTPException: 422 if validation fails
HTTPException: 500 if database error occurs
Example:
```python
update_data = TaskUpdate(
status="completed",
completed_at=datetime.now()
)
task = update_task(db, task_id, update_data)
print(f"Updated task: {task.title}")
```
"""
# Get existing task
task = get_task_by_id(db, task_id)
try:
# Update only provided fields
update_data = task_data.model_dump(exclude_unset=True)
# Validate foreign keys if being updated
if "session_id" in update_data and update_data["session_id"]:
session = db.query(SessionModel).filter(SessionModel.id == str(update_data["session_id"])).first()
if not session:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Session with ID {update_data['session_id']} not found"
)
if "client_id" in update_data and update_data["client_id"]:
client = db.query(Client).filter(Client.id == str(update_data["client_id"])).first()
if not client:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Client with ID {update_data['client_id']} not found"
)
if "project_id" in update_data and update_data["project_id"]:
project = db.query(Project).filter(Project.id == str(update_data["project_id"])).first()
if not project:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Project with ID {update_data['project_id']} not found"
)
if "parent_task_id" in update_data and update_data["parent_task_id"]:
parent_task = db.query(TaskModel).filter(TaskModel.id == str(update_data["parent_task_id"])).first()
if not parent_task:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Parent task with ID {update_data['parent_task_id']} not found"
)
# Apply updates
for field, value in update_data.items():
setattr(task, field, value)
db.commit()
db.refresh(task)
return task
except HTTPException:
db.rollback()
raise
except IntegrityError as e:
db.rollback()
if "session_id" in str(e.orig):
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail="Invalid session_id"
)
elif "client_id" in str(e.orig):
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail="Invalid client_id"
)
elif "project_id" in str(e.orig):
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail="Invalid project_id"
)
elif "parent_task_id" in str(e.orig):
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail="Invalid parent_task_id"
)
elif "ck_tasks_type" in str(e.orig):
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail=f"Invalid task_type. Must be one of: implementation, research, review, deployment, testing, documentation, bugfix, analysis"
)
elif "ck_tasks_status" in str(e.orig):
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail=f"Invalid status. Must be one of: pending, in_progress, blocked, completed, cancelled"
)
elif "ck_tasks_complexity" in str(e.orig):
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail=f"Invalid estimated_complexity. Must be one of: trivial, simple, moderate, complex, very_complex"
)
else:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Database error: {str(e)}"
)
except Exception as e:
db.rollback()
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to update task: {str(e)}"
)
def delete_task(db: Session, task_id: UUID) -> dict:
"""
Delete a task by its ID.
Args:
db: Database session
task_id: UUID of the task to delete
Returns:
dict: Success message
Raises:
HTTPException: 404 if task not found
HTTPException: 500 if database error occurs
Example:
```python
result = delete_task(db, task_id)
print(result["message"]) # "Task deleted successfully"
```
"""
# Get existing task (raises 404 if not found)
task = get_task_by_id(db, task_id)
try:
db.delete(task)
db.commit()
return {
"message": "Task deleted successfully",
"task_id": str(task_id)
}
except Exception as e:
db.rollback()
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to delete task: {str(e)}"
)

View File

@@ -0,0 +1,455 @@
"""
WorkItem service layer for business logic and database operations.
This module handles all database operations for work items, providing a clean
separation between the API routes and data access layer.
"""
from typing import Optional
from uuid import UUID
from fastapi import HTTPException, status
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
from api.models.work_item import WorkItem
from api.models.session import Session as SessionModel
from api.schemas.work_item import WorkItemCreate, WorkItemUpdate
def get_work_items(db: Session, skip: int = 0, limit: int = 100) -> tuple[list[WorkItem], int]:
"""
Retrieve a paginated list of work items.
Args:
db: Database session
skip: Number of records to skip (for pagination)
limit: Maximum number of records to return
Returns:
tuple: (list of work items, total count)
Example:
```python
work_items, total = get_work_items(db, skip=0, limit=50)
print(f"Retrieved {len(work_items)} of {total} work items")
```
"""
# Get total count
total = db.query(WorkItem).count()
# Get paginated results, ordered by created_at descending (newest first)
work_items = (
db.query(WorkItem)
.order_by(WorkItem.created_at.desc())
.offset(skip)
.limit(limit)
.all()
)
return work_items, total
def get_work_item_by_id(db: Session, work_item_id: UUID) -> WorkItem:
"""
Retrieve a single work item by its ID.
Args:
db: Database session
work_item_id: UUID of the work item to retrieve
Returns:
WorkItem: The work item object
Raises:
HTTPException: 404 if work item not found
Example:
```python
work_item = get_work_item_by_id(db, work_item_id)
print(f"Found work item: {work_item.title}")
```
"""
work_item = db.query(WorkItem).filter(WorkItem.id == str(work_item_id)).first()
if not work_item:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Work item with ID {work_item_id} not found"
)
return work_item
def get_work_items_by_session(db: Session, session_id: str, skip: int = 0, limit: int = 100) -> tuple[list[WorkItem], int]:
"""
Retrieve work items for a specific session.
Args:
db: Database session
session_id: Session UUID
skip: Number of records to skip
limit: Maximum number of records to return
Returns:
tuple: (list of work items, total count)
Example:
```python
work_items, total = get_work_items_by_session(db, session_id)
print(f"Session has {total} work items")
```
"""
total = db.query(WorkItem).filter(WorkItem.session_id == str(session_id)).count()
work_items = (
db.query(WorkItem)
.filter(WorkItem.session_id == str(session_id))
.order_by(WorkItem.item_order, WorkItem.created_at)
.offset(skip)
.limit(limit)
.all()
)
return work_items, total
def get_work_items_by_project(db: Session, project_id: str, skip: int = 0, limit: int = 100) -> tuple[list[WorkItem], int]:
"""
Retrieve work items for a specific project (through sessions).
Args:
db: Database session
project_id: Project UUID
skip: Number of records to skip
limit: Maximum number of records to return
Returns:
tuple: (list of work items, total count)
Example:
```python
work_items, total = get_work_items_by_project(db, project_id)
print(f"Project has {total} work items")
```
"""
total = (
db.query(WorkItem)
.join(SessionModel, WorkItem.session_id == SessionModel.id)
.filter(SessionModel.project_id == str(project_id))
.count()
)
work_items = (
db.query(WorkItem)
.join(SessionModel, WorkItem.session_id == SessionModel.id)
.filter(SessionModel.project_id == str(project_id))
.order_by(WorkItem.created_at.desc())
.offset(skip)
.limit(limit)
.all()
)
return work_items, total
def get_work_items_by_client(db: Session, client_id: str, skip: int = 0, limit: int = 100) -> tuple[list[WorkItem], int]:
"""
Retrieve work items for a specific client (through sessions).
Args:
db: Database session
client_id: Client UUID
skip: Number of records to skip
limit: Maximum number of records to return
Returns:
tuple: (list of work items, total count)
Example:
```python
work_items, total = get_work_items_by_client(db, client_id)
print(f"Client has {total} work items")
```
"""
total = (
db.query(WorkItem)
.join(SessionModel, WorkItem.session_id == SessionModel.id)
.filter(SessionModel.client_id == str(client_id))
.count()
)
work_items = (
db.query(WorkItem)
.join(SessionModel, WorkItem.session_id == SessionModel.id)
.filter(SessionModel.client_id == str(client_id))
.order_by(WorkItem.created_at.desc())
.offset(skip)
.limit(limit)
.all()
)
return work_items, total
def get_work_items_by_status(db: Session, status_filter: str, skip: int = 0, limit: int = 100) -> tuple[list[WorkItem], int]:
"""
Retrieve work items by status.
Args:
db: Database session
status_filter: Status to filter by (completed, in_progress, blocked, pending, deferred)
skip: Number of records to skip
limit: Maximum number of records to return
Returns:
tuple: (list of work items, total count)
Example:
```python
work_items, total = get_work_items_by_status(db, "in_progress")
print(f"Found {total} in progress work items")
```
"""
total = db.query(WorkItem).filter(WorkItem.status == status_filter).count()
work_items = (
db.query(WorkItem)
.filter(WorkItem.status == status_filter)
.order_by(WorkItem.created_at.desc())
.offset(skip)
.limit(limit)
.all()
)
return work_items, total
def validate_session_exists(db: Session, session_id: str) -> None:
"""
Validate that a session exists.
Args:
db: Database session
session_id: Session UUID to validate
Raises:
HTTPException: 404 if session not found
Example:
```python
validate_session_exists(db, session_id)
# Continues if session exists, raises HTTPException if not
```
"""
session = db.query(SessionModel).filter(SessionModel.id == str(session_id)).first()
if not session:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Session with ID {session_id} not found"
)
def create_work_item(db: Session, work_item_data: WorkItemCreate) -> WorkItem:
"""
Create a new work item.
Args:
db: Database session
work_item_data: Work item creation data
Returns:
WorkItem: The created work item object
Raises:
HTTPException: 404 if session not found
HTTPException: 422 if validation fails
HTTPException: 500 if database error occurs
Example:
```python
work_item_data = WorkItemCreate(
session_id="123e4567-e89b-12d3-a456-426614174000",
category="infrastructure",
title="Configure firewall rules",
description="Updated firewall rules for new server",
status="completed"
)
work_item = create_work_item(db, work_item_data)
print(f"Created work item: {work_item.id}")
```
"""
# Validate session exists
validate_session_exists(db, work_item_data.session_id)
try:
# Create new work item instance
db_work_item = WorkItem(**work_item_data.model_dump())
# Add to database
db.add(db_work_item)
db.commit()
db.refresh(db_work_item)
return db_work_item
except IntegrityError as e:
db.rollback()
# Handle foreign key constraint violations
if "session_id" in str(e.orig):
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Session with ID {work_item_data.session_id} not found"
)
elif "category" in str(e.orig):
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail="Invalid category. Must be one of: infrastructure, troubleshooting, configuration, development, maintenance, security, documentation"
)
elif "status" in str(e.orig):
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail="Invalid status. Must be one of: completed, in_progress, blocked, pending, deferred"
)
elif "priority" in str(e.orig):
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail="Invalid priority. Must be one of: critical, high, medium, low"
)
else:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Database error: {str(e)}"
)
except Exception as e:
db.rollback()
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to create work item: {str(e)}"
)
def update_work_item(db: Session, work_item_id: UUID, work_item_data: WorkItemUpdate) -> WorkItem:
"""
Update an existing work item.
Args:
db: Database session
work_item_id: UUID of the work item to update
work_item_data: Work item update data (only provided fields will be updated)
Returns:
WorkItem: The updated work item object
Raises:
HTTPException: 404 if work item or session not found
HTTPException: 422 if validation fails
HTTPException: 500 if database error occurs
Example:
```python
update_data = WorkItemUpdate(
status="completed",
actual_minutes=45
)
work_item = update_work_item(db, work_item_id, update_data)
print(f"Updated work item: {work_item.title}")
```
"""
# Get existing work item
work_item = get_work_item_by_id(db, work_item_id)
try:
# Update only provided fields
update_data = work_item_data.model_dump(exclude_unset=True)
# If updating session_id, validate session exists
if "session_id" in update_data and update_data["session_id"] != work_item.session_id:
validate_session_exists(db, update_data["session_id"])
# Apply updates
for field, value in update_data.items():
setattr(work_item, field, value)
db.commit()
db.refresh(work_item)
return work_item
except HTTPException:
db.rollback()
raise
except IntegrityError as e:
db.rollback()
if "session_id" in str(e.orig):
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Session not found"
)
elif "category" in str(e.orig):
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail="Invalid category. Must be one of: infrastructure, troubleshooting, configuration, development, maintenance, security, documentation"
)
elif "status" in str(e.orig):
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail="Invalid status. Must be one of: completed, in_progress, blocked, pending, deferred"
)
elif "priority" in str(e.orig):
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail="Invalid priority. Must be one of: critical, high, medium, low"
)
else:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Database error: {str(e)}"
)
except Exception as e:
db.rollback()
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to update work item: {str(e)}"
)
def delete_work_item(db: Session, work_item_id: UUID) -> dict:
"""
Delete a work item by its ID.
Args:
db: Database session
work_item_id: UUID of the work item to delete
Returns:
dict: Success message
Raises:
HTTPException: 404 if work item not found
HTTPException: 500 if database error occurs
Example:
```python
result = delete_work_item(db, work_item_id)
print(result["message"]) # "Work item deleted successfully"
```
"""
# Get existing work item (raises 404 if not found)
work_item = get_work_item_by_id(db, work_item_id)
try:
db.delete(work_item)
db.commit()
return {
"message": "Work item deleted successfully",
"work_item_id": str(work_item_id)
}
except Exception as e:
db.rollback()
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to delete work item: {str(e)}"
)