"""GravityZone API routes: status, companies, endpoints, quarantine and sweeps.

Every route requires an authenticated user (``get_current_user``) and
delegates the upstream call to the service from ``get_gravityzone_service()``.
"""

import logging
from datetime import datetime, timezone, timedelta
from typing import Optional

from fastapi import APIRouter, Depends, HTTPException, Query, status

from api.middleware.auth import get_current_user
from api.schemas.gravityzone import (
    GZCompanyItem,
    GZEndpointDetail,
    GZEndpointItem,
    GZStatusResponse,
    GZSweepResult,
)
from api.services.gravityzone_service import (
    ACG_COMPANIES_CONTAINER_ID,
    get_gravityzone_service,
)

logger = logging.getLogger(__name__)

router = APIRouter()

# Endpoints whose last check-in is older than this are reported as
# "not seen recently" by the sweep routes.
_SWEEP_STALE_AFTER = timedelta(days=7)


def _raise_on_failure(result, detail_prefix: str = "GravityZone error") -> dict:
    """Unwrap a service result, turning a failed call into an HTTP 502.

    Args:
        result: Service call result exposing ``success``, ``error`` and
            ``data`` attributes.
        detail_prefix: Text prepended to the error message in the HTTP detail.

    Returns:
        ``result.data``, or an empty dict when it is falsy.

    Raises:
        HTTPException: 502 Bad Gateway when ``result.success`` is false.
    """
    if not result.success:
        error = result.error or "unknown error"
        # Record upstream failures server-side too, not only in the client
        # response body.
        logger.warning("%s: %s", detail_prefix, error)
        raise HTTPException(
            status_code=status.HTTP_502_BAD_GATEWAY,
            detail=f"{detail_prefix}: {error}",
        )
    return result.data or {}


# --------------------------------------------------------------------------
# Status
# --------------------------------------------------------------------------


# Return the API key's enabled APIs, creation date, slot usage and expiry.
@router.get(
    "/status",
    response_model=GZStatusResponse,
    summary="GravityZone API key status and license info",
    status_code=status.HTTP_200_OK,
)
async def get_status(current_user: dict = Depends(get_current_user)):
    service = get_gravityzone_service()
    result = await service.get_api_status()
    data = _raise_on_failure(result, "GravityZone status")

    return GZStatusResponse(
        enabled_apis=data.get("enabledApis", []),
        key_created_at=data.get("createdAt"),
        used_slots=data.get("usedSlots"),
        total_slots=data.get("totalSlots"),
        expiry_date=data.get("expiryDate"),
    )


# --------------------------------------------------------------------------
# Companies
# --------------------------------------------------------------------------


# Paginated list of client companies. "total" falls back to the number of
# companies on this page when the upstream total is missing or null.
@router.get(
    "/companies",
    summary="List GravityZone client companies",
    status_code=status.HTTP_200_OK,
)
async def list_companies(
    page: int = Query(1, ge=1),
    per_page: int = Query(100, ge=1, le=500),
    current_user: dict = Depends(get_current_user),
):
    service = get_gravityzone_service()
    result = await service.list_client_companies(page=page, per_page=per_page)
    data = _raise_on_failure(result, "GravityZone companies")

    companies = [
        GZCompanyItem(
            id=item.get("id", ""),
            name=item.get("name", ""),
            type=item.get("type", 1),
        )
        # `or []` also covers an explicit null "items" value.
        for item in data.get("items") or []
    ]
    total = data.get("total")
    return {
        "total": len(companies) if total is None else total,
        "companies": companies,
    }


# --------------------------------------------------------------------------
# Endpoints
# --------------------------------------------------------------------------


# Paginated list of endpoints for one company. "total" falls back to the
# number of endpoints on this page when the upstream total is missing or null.
@router.get(
    "/companies/{company_id}/endpoints",
    summary="List endpoints for a GravityZone company",
    status_code=status.HTTP_200_OK,
)
async def list_endpoints(
    company_id: str,
    page: int = Query(1, ge=1),
    per_page: int = Query(50, ge=1, le=200),
    current_user: dict = Depends(get_current_user),
):
    service = get_gravityzone_service()
    result = await service.list_endpoints(company_id, page=page, per_page=per_page)
    data = _raise_on_failure(result, "GravityZone endpoints")

    endpoints = [
        GZEndpointItem(
            id=item.get("id", ""),
            name=item.get("name", ""),
            fqdn=item.get("fqdn"),
            ip=item.get("ip"),
            os_version=item.get("operatingSystemVersion"),
            is_managed=bool(item.get("isManaged", False)),
            # "policy" may be absent or null; only its name is exposed.
            policy_name=(item.get("policy") or {}).get("name"),
        )
        for item in data.get("items") or []
    ]
    total = data.get("total")
    return {
        "total": len(endpoints) if total is None else total,
        "endpoints": endpoints,
    }


# Detailed view of a single endpoint, flattening the nested "malwareStatus"
# and "agent" objects into GZEndpointDetail fields.
@router.get(
    "/endpoints/{endpoint_id}",
    response_model=GZEndpointDetail,
    summary="Get detailed info for a single GravityZone endpoint",
    status_code=status.HTTP_200_OK,
)
async def get_endpoint(
    endpoint_id: str,
    current_user: dict = Depends(get_current_user),
):
    service = get_gravityzone_service()
    result = await service.get_endpoint_details(endpoint_id)
    data = _raise_on_failure(result, "GravityZone endpoint detail")

    # `or {}` (rather than a .get default) so an explicit null sub-object
    # does not raise AttributeError on the lookups below.
    malware = data.get("malwareStatus") or {}
    agent = data.get("agent") or {}

    return GZEndpointDetail(
        id=data.get("id", endpoint_id),
        name=data.get("name", ""),
        company_id=data.get("companyId"),
        infected=bool(malware.get("infected", False)),
        detection_active=bool(malware.get("detection", False)),
        signature_outdated=bool(agent.get("signatureOutdated", False)),
        product_outdated=bool(agent.get("productOutdated", False)),
        agent_version=agent.get("productVersion"),
        engine_version=agent.get("engineVersion"),
        last_seen=data.get("lastSeen"),
        last_update=agent.get("lastUpdate"),
        state=data.get("state", 0),
        modules=data.get("modules"),
    )


# --------------------------------------------------------------------------
# Quarantine
# --------------------------------------------------------------------------


# Paginated quarantine items for one company, passed through unmodified.
@router.get(
    "/companies/{company_id}/quarantine",
    summary="List quarantine items for a GravityZone company",
    status_code=status.HTTP_200_OK,
)
async def list_quarantine(
    company_id: str,
    page: int = Query(1, ge=1),
    per_page: int = Query(50, ge=1, le=200),
    current_user: dict = Depends(get_current_user),
):
    service = get_gravityzone_service()
    result = await service.list_quarantine_items(
        company_id, page=page, per_page=per_page
    )
    data = _raise_on_failure(result, "GravityZone quarantine")
    return {"total": data.get("total") or 0, "items": data.get("items") or []}


# --------------------------------------------------------------------------
# Security sweep
# --------------------------------------------------------------------------


def _parse_last_seen(value) -> Optional[datetime]:
    """Parse a last-seen timestamp into a timezone-aware datetime.

    Accepts ISO-8601 strings, including a ``Z`` UTC suffix (rewritten to
    ``+00:00`` because ``datetime.fromisoformat`` only accepts ``Z`` natively
    from Python 3.11), and ``datetime`` objects. Naive values are assumed to
    be UTC.

    Returns:
        The parsed datetime, or None when ``value`` is empty or unparseable.
        Unparseable values are logged at debug level instead of raising, so
        one malformed record cannot fail a whole sweep.
    """
    if not value:
        return None
    if isinstance(value, datetime):
        parsed = value
    else:
        try:
            parsed = datetime.fromisoformat(str(value).replace("Z", "+00:00"))
        except ValueError:
            logger.debug("Ignoring unparseable last_seen value: %r", value)
            return None
    if parsed.tzinfo is None:
        parsed = parsed.replace(tzinfo=timezone.utc)
    return parsed


def _build_sweep_result(
    summaries, stale_after: timedelta = _SWEEP_STALE_AFTER
) -> GZSweepResult:
    """Aggregate per-endpoint sweep summaries into a GZSweepResult.

    Args:
        summaries: Iterable of summary objects exposing ``endpoint_id``,
            ``name``, ``company_id``, ``infected``, ``detection_active``,
            ``signature_outdated``, ``product_outdated``, ``last_seen``,
            ``agent_version`` and ``state`` attributes.
        stale_after: Endpoints last seen longer ago than this count towards
            ``not_seen_recently``. Endpoints with a missing or unparseable
            ``last_seen`` are not counted.

    Returns:
        A GZSweepResult with totals per condition and one dict per endpoint.
    """
    # Materialize once: the summaries are counted and iterated several times.
    summaries = list(summaries)
    stale_cutoff = datetime.now(timezone.utc) - stale_after

    not_seen_recently = 0
    for s in summaries:
        last_seen_dt = _parse_last_seen(s.last_seen)
        if last_seen_dt is not None and last_seen_dt < stale_cutoff:
            not_seen_recently += 1

    return GZSweepResult(
        total=len(summaries),
        infected=sum(1 for s in summaries if s.infected),
        signature_outdated=sum(1 for s in summaries if s.signature_outdated),
        product_outdated=sum(1 for s in summaries if s.product_outdated),
        not_seen_recently=not_seen_recently,
        endpoints=[
            {
                "endpoint_id": s.endpoint_id,
                "name": s.name,
                "company_id": s.company_id,
                "infected": s.infected,
                "detection_active": s.detection_active,
                "signature_outdated": s.signature_outdated,
                "product_outdated": s.product_outdated,
                "last_seen": s.last_seen,
                "agent_version": s.agent_version,
                "state": s.state,
            }
            for s in summaries
        ],
    )


# Sweep every endpoint under the given parent ID.
@router.get(
    "/sweep/{parent_id}",
    response_model=GZSweepResult,
    summary="Security sweep for all endpoints under a parent ID",
    status_code=status.HTTP_200_OK,
)
async def sweep_parent(
    parent_id: str,
    current_user: dict = Depends(get_current_user),
):
    service = get_gravityzone_service()
    summaries = await service.security_sweep(parent_id)
    return _build_sweep_result(summaries)


# Sweep across all client companies via the service's all-clients sweep.
@router.get(
    "/sweep",
    response_model=GZSweepResult,
    summary="Security sweep across all ACG client companies",
    status_code=status.HTTP_200_OK,
)
async def sweep_all_clients(
    current_user: dict = Depends(get_current_user),
):
    service = get_gravityzone_service()
    summaries = await service.security_sweep_all_clients()
    return _build_sweep_result(summaries)