"""Service layer for CoordSessionLock.""" from datetime import datetime, timedelta, timezone from typing import Optional from uuid import UUID from fastapi import HTTPException, status from sqlalchemy import and_, or_ from sqlalchemy.orm import Session from api.models.coord_session_lock import CoordSessionLock from api.schemas.coord_session_lock import CoordSessionLockCreate def _active_filter(q): """Apply the 'lock is currently active' predicate to a query.""" now = datetime.now(timezone.utc).replace(tzinfo=None) return q.filter( CoordSessionLock.released_at.is_(None), or_( CoordSessionLock.expires_at.is_(None), CoordSessionLock.expires_at > now, ), ) def get_active_locks( db: Session, project_key: Optional[str] = None, session_id: Optional[str] = None, skip: int = 0, limit: int = 100, ) -> tuple[list[CoordSessionLock], int]: """Return currently active locks with optional filters.""" q = db.query(CoordSessionLock) if project_key: q = q.filter(CoordSessionLock.project_key == project_key) if session_id: q = q.filter(CoordSessionLock.session_id == session_id) q = _active_filter(q) total = q.count() locks = q.order_by(CoordSessionLock.acquired_at.desc()).offset(skip).limit(limit).all() return locks, total def check_resource_locked( db: Session, project_key: str, resource: str ) -> Optional[CoordSessionLock]: """Return the active lock on a resource, or None if unlocked.""" q = db.query(CoordSessionLock).filter( CoordSessionLock.project_key == project_key, CoordSessionLock.resource == resource, ) return _active_filter(q).first() def claim_lock(db: Session, data: CoordSessionLockCreate) -> CoordSessionLock: """Claim a resource lock, computing expires_at from ttl_hours.""" expires_at: Optional[datetime] = None if data.ttl_hours > 0: expires_at = datetime.now(timezone.utc).replace(tzinfo=None) + timedelta(hours=data.ttl_hours) try: lock = CoordSessionLock( project_key=data.project_key, session_id=data.session_id, resource=data.resource, description=data.description, acquired_at=datetime.now(timezone.utc).replace(tzinfo=None), expires_at=expires_at, released_at=None, ) db.add(lock) db.commit() db.refresh(lock) return lock except Exception as e: db.rollback() raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to claim lock: {e}" ) def release_lock(db: Session, lock_id: UUID, session_id: str) -> CoordSessionLock: """Release a specific lock; only the owning session may release it.""" lock = db.query(CoordSessionLock).filter(CoordSessionLock.id == str(lock_id)).first() if not lock: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=f"Lock {lock_id} not found" ) if lock.session_id != session_id: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail="Only the session that claimed this lock may release it" ) try: lock.released_at = datetime.now(timezone.utc).replace(tzinfo=None) db.commit() db.refresh(lock) return lock except Exception as e: db.rollback() raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to release lock: {e}" ) def release_all_session_locks(db: Session, session_id: str) -> dict: """Release all active locks held by a session (cleanup on session end).""" now = datetime.now(timezone.utc).replace(tzinfo=None) try: q = db.query(CoordSessionLock).filter( CoordSessionLock.session_id == session_id, CoordSessionLock.released_at.is_(None), ) count = q.count() q.update({"released_at": now}, synchronize_session=False) db.commit() return {"message": f"Released {count} lock(s) for session '{session_id}'", "count": count} except Exception as e: db.rollback() raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to release session locks: {e}" )