"""Authentication endpoints and dependencies."""

from datetime import datetime, timedelta, timezone
from typing import Any

from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from jose import JWTError, jwt
from passlib.context import CryptContext
from pydantic import BaseModel
from sqlalchemy.orm import Session

from config import settings
from database import get_db
from models import AuditLog, User
from schemas import LoginResponse, UserOut

router = APIRouter(prefix="/api/auth", tags=["Auth"])
password_context = CryptContext(schemes=["pbkdf2_sha256"], deprecated="auto")
# auto_error=False: a missing Authorization header yields None here, so
# get_current_user can raise its own 401 with a WWW-Authenticate header.
bearer_scheme = HTTPBearer(auto_error=False)
SUPPORTED_ROLES = {"admin", "annotator"}


# Credentials submitted to POST /api/auth/login.
class LoginRequest(BaseModel):
    username: str
    password: str


def hash_password(password: str) -> str:
    """Hash a plain password for storage."""
    return password_context.hash(password)


def verify_password(password: str, password_hash: str) -> bool:
    """Verify a plain password against a stored hash."""
    return password_context.verify(password, password_hash)


def create_access_token(user: User, expires_delta: timedelta | None = None) -> str:
    """Create a signed JWT access token for a user.

    Args:
        user: Account the token is issued for; its id (as ``sub``), username
            and role are embedded in the payload.
        expires_delta: Token lifetime. ``None`` uses
            ``settings.access_token_expire_minutes``; an explicit value
            (including ``timedelta(0)``) is honoured as given.

    Returns:
        The encoded JWT, signed with ``settings.jwt_secret_key``.
    """
    # Compare against None explicitly: timedelta(0) is falsy and must not be
    # silently replaced by the default lifetime.
    if expires_delta is None:
        expires_delta = timedelta(minutes=settings.access_token_expire_minutes)
    expire = datetime.now(timezone.utc) + expires_delta
    payload: dict[str, Any] = {
        "sub": str(user.id),
        "username": user.username,
        "role": user.role,
        "exp": expire,
    }
    return jwt.encode(payload, settings.jwt_secret_key, algorithm=settings.jwt_algorithm)


def _demote_other_admins(db: Session, keep_user_id: int) -> bool:
    """Set every admin except ``keep_user_id`` to annotator; return True if any changed."""
    other_admins = db.query(User).filter(
        User.role == "admin",
        User.id != keep_user_id,
    ).all()
    for other in other_admins:
        other.role = "annotator"
    return bool(other_admins)


def ensure_default_admin(db: Session) -> User:
    """Create and enforce the single default administrator account.

    - If no user named ``settings.default_admin_username`` exists, create it
      with ``settings.default_admin_password``, role ``admin`` and active.
    - If it exists, force its role to ``admin`` and mark it active.
    - Every other account holding the ``admin`` role is downgraded to
      ``annotator``.

    All changes are committed in a single transaction, and only when
    something actually changed.

    Returns:
        The default administrator ``User``.
    """
    admin = db.query(User).filter(User.username == settings.default_admin_username).first()
    changed = False
    if admin is None:
        admin = User(
            username=settings.default_admin_username,
            password_hash=hash_password(settings.default_admin_password),
            role="admin",
            is_active=1,
        )
        db.add(admin)
        # Flush (not commit) so the new row gets its primary key, which is
        # needed to exclude it from the demotion query below.
        db.flush()
        changed = True
    else:
        if admin.role != "admin":
            admin.role = "admin"
            changed = True
        if not admin.is_active:
            admin.is_active = 1
            changed = True
    if _demote_other_admins(db, admin.id):
        changed = True
    if changed:
        db.commit()
        db.refresh(admin)
    return admin


def normalize_user_role(db: Session, user: User) -> User:
    """Keep legacy accounts within the current two-role policy.

    The default admin username always maps to ``admin`` and is kept active;
    every other account maps to ``annotator``. Commits only if a field changed.
    """
    desired_role = "admin" if user.username == settings.default_admin_username else "annotator"
    changed = False
    if user.role != desired_role:
        user.role = desired_role
        changed = True
    if user.username == settings.default_admin_username and not user.is_active:
        user.is_active = 1
        changed = True
    if changed:
        db.commit()
        db.refresh(user)
    return user


def get_current_user(
    credentials: HTTPAuthorizationCredentials | None = Depends(bearer_scheme),
    db: Session = Depends(get_db),
) -> User:
    """Resolve and validate the current user from the Bearer token.

    Only the ``sub`` claim (user id) is trusted; role and active status are
    always re-read from the database and normalized.

    Raises:
        HTTPException: 401 when the header is missing or not Bearer, when the
            token is invalid, expired or lacks a numeric ``sub``, or when the
            user is missing or inactive.
    """
    if credentials is None or credentials.scheme.lower() != "bearer":
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Not authenticated",
            headers={"WWW-Authenticate": "Bearer"},
        )
    try:
        payload = jwt.decode(
            credentials.credentials,
            settings.jwt_secret_key,
            algorithms=[settings.jwt_algorithm],
        )
        # int(None) -> TypeError (no sub); int("abc") -> ValueError.
        user_id = int(payload.get("sub"))
    except (JWTError, TypeError, ValueError) as exc:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid or expired token",
            headers={"WWW-Authenticate": "Bearer"},
        ) from exc

    user = db.query(User).filter(User.id == user_id).first()
    if user:
        user = normalize_user_role(db, user)
    if not user or not user.is_active:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Inactive or missing user",
            headers={"WWW-Authenticate": "Bearer"},
        )
    return user


def require_admin(current_user: User = Depends(get_current_user)) -> User:
    """Require the current user to have the administrator role (else 403)."""
    if current_user.role != "admin":
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Admin permission required")
    return current_user


def require_editor(current_user: User = Depends(get_current_user)) -> User:
    """Require a user role that can modify segmentation data (else 403).

    Any role in ``SUPPORTED_ROLES`` qualifies.
    """
    if current_user.role not in SUPPORTED_ROLES:
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Edit permission required")
    return current_user


def write_audit_log(
    db: Session,
    *,
    actor: User | None,
    action: str,
    target_type: str | None = None,
    target_id: str | int | None = None,
    detail: dict[str, Any] | None = None,
) -> AuditLog:
    """Persist a compact audit event.

    ``target_id`` is stored as a string. Note that this commits the session,
    so any other pending changes on ``db`` are committed with the log row.
    """
    log = AuditLog(
        actor_user_id=actor.id if actor else None,
        action=action,
        target_type=target_type,
        target_id=str(target_id) if target_id is not None else None,
        detail=detail or {},
    )
    db.add(log)
    db.commit()
    db.refresh(log)
    return log


@router.post("/login", response_model=LoginResponse)
def login(payload: LoginRequest, db: Session = Depends(get_db)) -> dict:
    """Authenticate a user and return a signed JWT.

    Unknown usernames, inactive accounts and wrong passwords all produce the
    same 401 response, and all of them perform password-hashing work, so
    response latency does not reveal which usernames exist.
    """
    ensure_default_admin(db)

    user = db.query(User).filter(User.username == payload.username).first()
    if user:
        user = normalize_user_role(db, user)
        password_ok = verify_password(payload.password, user.password_hash)
    else:
        # Spend comparable hashing time when no account matches; otherwise
        # the fast failure path enables username enumeration.
        password_context.dummy_verify()
        password_ok = False

    if not user or not user.is_active or not password_ok:
        write_audit_log(
            db,
            actor=None,
            action="auth.login_failed",
            target_type="user",
            target_id=payload.username,
            detail={"username": payload.username},
        )
        raise HTTPException(status_code=401, detail="Invalid credentials")

    write_audit_log(
        db,
        actor=user,
        action="auth.login_success",
        target_type="user",
        target_id=user.id,
        detail={"username": user.username},
    )
    return {
        "token": create_access_token(user),
        "token_type": "bearer",
        "username": user.username,
        "user": user,
    }
@router.get("/me", response_model=UserOut)
def read_current_user(current_user: User = Depends(get_current_user)) -> User:
    """Echo back the profile of the caller identified by the Bearer token.

    Token validation, role normalization and the active-account check are all
    performed by the ``get_current_user`` dependency. This handler only hands
    the resolved user back, serialized through ``UserOut``.
    """
    authenticated_user = current_user
    return authenticated_user