收敛用户角色并共享项目库

- 后端限制系统只保留默认 admin 管理员,新建用户固定为标注员,并拒绝观察员或额外管理员角色。

- 将项目、帧、媒体解析、AI 标注、任务、Dashboard 和导出接口改为共享项目库访问,标注员具备同等项目管理和标注能力。

- 前端用户管理移除角色选择和观察员入口,只展示唯一管理员与标注员状态。

- 更新后端/前端测试,覆盖唯一 admin、旧 viewer 归一为标注员、用户删除和共享项目库访问。

- 同步更新 AGENTS 与 doc 文档中的角色权限、共享项目库和测试计划说明。
This commit is contained in:
2026-05-04 05:20:28 +08:00
parent 02635abab1
commit 523beeb446
21 changed files with 214 additions and 172 deletions

View File

@@ -10,7 +10,7 @@ from sqlalchemy.orm import Session
from config import settings
from database import get_db
from models import Annotation, AuditLog, Frame, Mask, ProcessingTask, Project, Template, User
from routers.auth import ensure_default_admin, hash_password, require_admin, write_audit_log
from routers.auth import SUPPORTED_ROLES, ensure_default_admin, hash_password, normalize_user_role, require_admin, write_audit_log
from schemas import (
AdminUserCreate,
AdminUserUpdate,
@@ -30,18 +30,22 @@ from services.default_templates import restore_default_templates
router = APIRouter(prefix="/api/admin", tags=["Admin"])
VALID_ROLES = {"admin", "annotator", "viewer"}
DEMO_RESET_CONFIRMATION = "RESET_DEMO_FACTORY"
DEMO_PROJECT_NAME = DEMO_DICOM_PROJECT_NAME
def _normalize_role(role: str | None) -> str:
normalized = (role or "annotator").strip().lower()
if normalized not in VALID_ROLES:
if normalized not in SUPPORTED_ROLES:
raise HTTPException(status_code=400, detail=f"Unsupported role: {role}")
return normalized
def _assert_non_admin_role(role: str) -> None:
if role == "admin":
raise HTTPException(status_code=400, detail="Only the default admin account can have admin role")
@router.get("/users", response_model=List[UserOut], summary="List users")
def list_users(
db: Session = Depends(get_db),
@@ -49,7 +53,8 @@ def list_users(
) -> List[User]:
"""Return all users for the administrator console."""
_ = admin_user
return db.query(User).order_by(User.id).all()
users = db.query(User).order_by(User.id).all()
return [normalize_user_role(db, user) for user in users]
@router.post(
@@ -69,10 +74,12 @@ def create_user(
raise HTTPException(status_code=400, detail="Username is required")
if len(payload.password) < 6:
raise HTTPException(status_code=400, detail="Password must be at least 6 characters")
role = _normalize_role(payload.role)
_assert_non_admin_role(role)
user = User(
username=username,
password_hash=hash_password(payload.password),
role=_normalize_role(payload.role),
role=role,
is_active=1 if payload.is_active else 0,
)
db.add(user)
@@ -104,6 +111,7 @@ def update_user(
user = db.query(User).filter(User.id == user_id).first()
if not user:
raise HTTPException(status_code=404, detail="User not found")
user = normalize_user_role(db, user)
updates = payload.model_dump(exclude_unset=True)
audit_detail: dict = {"before": {"username": user.username, "role": user.role, "is_active": bool(user.is_active)}}
@@ -111,6 +119,8 @@ def update_user(
username = (updates["username"] or "").strip()
if not username:
raise HTTPException(status_code=400, detail="Username is required")
if user.role == "admin" and username != settings.default_admin_username:
raise HTTPException(status_code=400, detail="Default admin username cannot be changed")
user.username = username
if "password" in updates:
password = updates["password"] or ""
@@ -119,8 +129,11 @@ def update_user(
user.password_hash = hash_password(password)
if "role" in updates:
next_role = _normalize_role(updates["role"])
if user.id == admin_user.id and next_role != "admin":
raise HTTPException(status_code=400, detail="Cannot remove your own admin role")
if user.username == settings.default_admin_username:
if next_role != "admin":
raise HTTPException(status_code=400, detail="Cannot remove the default admin role")
else:
_assert_non_admin_role(next_role)
user.role = next_role
if "is_active" in updates:
if user.id == admin_user.id and not updates["is_active"]:
@@ -158,9 +171,9 @@ def delete_user(
user = db.query(User).filter(User.id == user_id).first()
if not user:
raise HTTPException(status_code=404, detail="User not found")
owned_project_count = db.query(Project).filter(Project.owner_user_id == user_id).count()
if owned_project_count:
raise HTTPException(status_code=409, detail="User owns projects; deactivate or migrate projects first")
user = normalize_user_role(db, user)
if user.role == "admin":
raise HTTPException(status_code=400, detail="Cannot delete the default admin account")
username = user.username
db.delete(user)
db.commit()

View File

@@ -45,21 +45,20 @@ GT_IMPORT_CONTOUR_EPSILON_RATIO = 0.00075
GT_IMPORT_MIN_CONTOUR_EPSILON = 0.35
def _owned_project_or_404(project_id: int, db: Session, current_user: User) -> Project:
project = db.query(Project).filter(
Project.id == project_id,
Project.owner_user_id == current_user.id,
).first()
def _shared_project_or_404(project_id: int, db: Session, current_user: User) -> Project:
_ = current_user
project = db.query(Project).filter(Project.id == project_id).first()
if not project:
raise HTTPException(status_code=404, detail="Project not found")
return project
def _owned_frame_or_404(frame_id: int, db: Session, current_user: User, project_id: int | None = None) -> Frame:
def _shared_frame_or_404(frame_id: int, db: Session, current_user: User, project_id: int | None = None) -> Frame:
_ = current_user
query = (
db.query(Frame)
.join(Project, Project.id == Frame.project_id)
.filter(Frame.id == frame_id, Project.owner_user_id == current_user.id)
.filter(Frame.id == frame_id)
)
if project_id is not None:
query = query.filter(Frame.project_id == project_id)
@@ -480,7 +479,7 @@ def predict(
- **interactive**: `prompt_data` is `{ "box": [...], "points": [[x, y]], "labels": [1, 0] }`.
- **semantic**: disabled in the current SAM 2.1 point/box product flow.
"""
frame = _owned_frame_or_404(payload.image_id, db, current_user)
frame = _shared_frame_or_404(payload.image_id, db, current_user)
image = _load_frame_image(frame)
prompt_type = payload.prompt_type.lower()
@@ -649,7 +648,7 @@ def analyze_mask(
) -> dict:
"""Return backend-computed mask properties for the frontend inspector."""
if payload.frame_id is not None:
_owned_frame_or_404(payload.frame_id, db, current_user)
_shared_frame_or_404(payload.frame_id, db, current_user)
mask_data = payload.mask_data or {}
polygons = mask_data.get("polygons") or []
@@ -705,7 +704,7 @@ def smooth_mask(
to the current mask, then save through the normal annotation endpoint.
"""
if payload.frame_id is not None:
_owned_frame_or_404(payload.frame_id, db, current_user)
_shared_frame_or_404(payload.frame_id, db, current_user)
polygons = payload.mask_data.get("polygons") or []
valid_polygons = _normalize_polygons(polygons)
@@ -751,8 +750,8 @@ def propagate(
raise HTTPException(status_code=400, detail="direction must be forward, backward, or both")
max_frames = max(1, min(int(payload.max_frames or 30), 500))
_owned_project_or_404(payload.project_id, db, current_user)
source_frame = _owned_frame_or_404(payload.frame_id, db, current_user, payload.project_id)
_shared_project_or_404(payload.project_id, db, current_user)
source_frame = _shared_frame_or_404(payload.frame_id, db, current_user, payload.project_id)
seed = payload.seed.model_dump(exclude_none=True)
polygons = seed.get("polygons") or []
@@ -881,8 +880,8 @@ def queue_propagate_task(
current_user: User = Depends(require_editor),
) -> ProcessingTaskOut:
"""Queue multiple seed/direction propagation steps as one background task."""
_owned_project_or_404(payload.project_id, db, current_user)
source_frame = _owned_frame_or_404(payload.frame_id, db, current_user, payload.project_id)
_shared_project_or_404(payload.project_id, db, current_user)
source_frame = _shared_frame_or_404(payload.frame_id, db, current_user, payload.project_id)
if not payload.steps:
raise HTTPException(status_code=400, detail="Propagation task requires at least one step")
@@ -936,7 +935,7 @@ def auto_segment(
current_user: User = Depends(require_editor),
) -> dict:
"""Run automatic mask generation on a frame using a grid of point prompts."""
frame = _owned_frame_or_404(image_id, db, current_user)
frame = _shared_frame_or_404(image_id, db, current_user)
image = _load_frame_image(frame)
try:
@@ -959,10 +958,10 @@ def save_annotation(
current_user: User = Depends(require_editor),
) -> Annotation:
"""Persist an annotation (mask, points, bbox) into the database."""
_owned_project_or_404(payload.project_id, db, current_user)
_shared_project_or_404(payload.project_id, db, current_user)
if payload.frame_id:
_owned_frame_or_404(payload.frame_id, db, current_user, payload.project_id)
_shared_frame_or_404(payload.frame_id, db, current_user, payload.project_id)
if payload.template_id:
_visible_template_or_404(payload.template_id, db, current_user)
@@ -998,8 +997,8 @@ async def import_gt_mask(
the frontend an editable point-region representation instead of a static
bitmap layer.
"""
_owned_project_or_404(project_id, db, current_user)
frame = _owned_frame_or_404(frame_id, db, current_user, project_id)
_shared_project_or_404(project_id, db, current_user)
frame = _shared_frame_or_404(frame_id, db, current_user, project_id)
if unknown_color_policy not in {"discard", "undefined"}:
raise HTTPException(status_code=400, detail="unknown_color_policy must be discard or undefined")
@@ -1143,11 +1142,11 @@ def list_annotations(
current_user: User = Depends(get_current_user),
) -> List[Annotation]:
"""Return persisted annotations for a project, optionally scoped to one frame."""
_owned_project_or_404(project_id, db, current_user)
_shared_project_or_404(project_id, db, current_user)
query = db.query(Annotation).filter(Annotation.project_id == project_id)
if frame_id is not None:
_owned_frame_or_404(frame_id, db, current_user, project_id)
_shared_frame_or_404(frame_id, db, current_user, project_id)
query = query.filter(Annotation.frame_id == frame_id)
return query.order_by(Annotation.id).all()
@@ -1167,7 +1166,7 @@ def update_annotation(
annotation = (
db.query(Annotation)
.join(Project, Project.id == Annotation.project_id)
.filter(Annotation.id == annotation_id, Project.owner_user_id == current_user.id)
.filter(Annotation.id == annotation_id)
.first()
)
if not annotation:
@@ -1200,7 +1199,7 @@ def delete_annotation(
annotation = (
db.query(Annotation)
.join(Project, Project.id == Annotation.project_id)
.filter(Annotation.id == annotation_id, Project.owner_user_id == current_user.id)
.filter(Annotation.id == annotation_id)
.first()
)
if not annotation:

View File

@@ -18,6 +18,7 @@ from schemas import LoginResponse, UserOut
router = APIRouter(prefix="/api/auth", tags=["Auth"])
password_context = CryptContext(schemes=["pbkdf2_sha256"], deprecated="auto")
bearer_scheme = HTTPBearer(auto_error=False)
SUPPORTED_ROLES = {"admin", "annotator"}
class LoginRequest(BaseModel):
@@ -50,9 +51,26 @@ def create_access_token(user: User, expires_delta: timedelta | None = None) -> s
def ensure_default_admin(db: Session) -> User:
"""Create the default development admin if the user table is empty."""
"""Create and enforce the single default administrator account."""
existing = db.query(User).filter(User.username == settings.default_admin_username).first()
if existing:
changed = False
if existing.role != "admin":
existing.role = "admin"
changed = True
if not existing.is_active:
existing.is_active = 1
changed = True
extra_admins = db.query(User).filter(
User.role == "admin",
User.id != existing.id,
).all()
for user in extra_admins:
user.role = "annotator"
changed = True
if changed:
db.commit()
db.refresh(existing)
return existing
user = User(
username=settings.default_admin_username,
@@ -63,6 +81,31 @@ def ensure_default_admin(db: Session) -> User:
db.add(user)
db.commit()
db.refresh(user)
extra_admins = db.query(User).filter(
User.role == "admin",
User.id != user.id,
).all()
if extra_admins:
for extra_user in extra_admins:
extra_user.role = "annotator"
db.commit()
db.refresh(user)
return user
def normalize_user_role(db: Session, user: User) -> User:
"""Keep legacy accounts within the current two-role policy."""
desired_role = "admin" if user.username == settings.default_admin_username else "annotator"
changed = False
if user.role != desired_role:
user.role = desired_role
changed = True
if user.username == settings.default_admin_username and not user.is_active:
user.is_active = 1
changed = True
if changed:
db.commit()
db.refresh(user)
return user
@@ -92,6 +135,8 @@ def get_current_user(
) from exc
user = db.query(User).filter(User.id == user_id).first()
if user:
user = normalize_user_role(db, user)
if not user or not user.is_active:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
@@ -110,7 +155,7 @@ def require_admin(current_user: User = Depends(get_current_user)) -> User:
def require_editor(current_user: User = Depends(get_current_user)) -> User:
"""Require a user role that can modify segmentation data."""
if current_user.role not in {"admin", "annotator"}:
if current_user.role not in SUPPORTED_ROLES:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Edit permission required")
return current_user
@@ -143,6 +188,8 @@ def login(payload: LoginRequest, db: Session = Depends(get_db)) -> dict:
"""Authenticate a user and return a signed JWT."""
ensure_default_admin(db)
user = db.query(User).filter(User.username == payload.username).first()
if user:
user = normalize_user_role(db, user)
if not user or not user.is_active or not verify_password(payload.password, user.password_hash):
write_audit_log(
db,

View File

@@ -58,12 +58,12 @@ def get_dashboard_overview(
current_user: User = Depends(get_current_user),
) -> dict[str, Any]:
"""Return live dashboard data derived from persisted backend records."""
owned_project_ids_query = db.query(Project.id).filter(Project.owner_user_id == current_user.id)
project_count = db.query(func.count(Project.id)).filter(Project.owner_user_id == current_user.id).scalar() or 0
frame_count = db.query(func.count(Frame.id)).filter(Frame.project_id.in_(owned_project_ids_query)).scalar() or 0
shared_project_ids_query = db.query(Project.id)
project_count = db.query(func.count(Project.id)).scalar() or 0
frame_count = db.query(func.count(Frame.id)).filter(Frame.project_id.in_(shared_project_ids_query)).scalar() or 0
annotation_count = (
db.query(func.count(Annotation.id))
.filter(Annotation.project_id.in_(owned_project_ids_query))
.filter(Annotation.project_id.in_(shared_project_ids_query))
.scalar()
or 0
)
@@ -76,7 +76,6 @@ def get_dashboard_overview(
active_task_count = (
db.query(func.count(ProcessingTask.id))
.outerjoin(Project, Project.id == ProcessingTask.project_id)
.filter((ProcessingTask.project_id.is_(None)) | (Project.owner_user_id == current_user.id))
.filter(ProcessingTask.status.in_(ACTIVE_TASK_STATUSES))
.scalar()
or 0
@@ -84,14 +83,12 @@ def get_dashboard_overview(
projects = (
db.query(Project)
.filter(Project.owner_user_id == current_user.id)
.order_by(Project.updated_at.desc())
.all()
)
recent_tasks = (
db.query(ProcessingTask)
.outerjoin(Project, Project.id == ProcessingTask.project_id)
.filter((ProcessingTask.project_id.is_(None)) | (Project.owner_user_id == current_user.id))
.order_by(ProcessingTask.created_at.desc())
.limit(50)
.all()
@@ -120,7 +117,7 @@ def get_dashboard_overview(
recent_annotations = (
db.query(Annotation)
.filter(Annotation.project_id.in_(owned_project_ids_query))
.filter(Annotation.project_id.in_(shared_project_ids_query))
.order_by(Annotation.updated_at.desc())
.limit(10)
.all()

View File

@@ -206,10 +206,8 @@ def _frame_image_extension(frame: Frame) -> str:
def _project_or_404(project_id: int, db: Session, current_user: User) -> Project:
project = db.query(Project).filter(
Project.id == project_id,
Project.owner_user_id == current_user.id,
).first()
_ = current_user
project = db.query(Project).filter(Project.id == project_id).first()
if not project:
raise HTTPException(status_code=404, detail="Project not found")
return project

View File

@@ -72,10 +72,7 @@ async def upload_media(
file_url = get_presigned_url(object_name, expires=3600)
if project_id:
project = db.query(Project).filter(
Project.id == project_id,
Project.owner_user_id == current_user.id,
).first()
project = db.query(Project).filter(Project.id == project_id).first()
if not project:
raise HTTPException(status_code=404, detail="Project not found")
project.video_path = object_name
@@ -141,10 +138,7 @@ async def upload_dicom_batch(
uploaded = []
if project_id:
project = db.query(Project).filter(
Project.id == project_id,
Project.owner_user_id == current_user.id,
).first()
project = db.query(Project).filter(Project.id == project_id).first()
if not project:
raise HTTPException(status_code=404, detail="Project not found")
else:
@@ -202,10 +196,7 @@ def parse_media(
The Celery worker performs the heavy FFmpeg/OpenCV/pydicom work and
updates the persisted task record as it progresses.
"""
project = db.query(Project).filter(
Project.id == project_id,
Project.owner_user_id == current_user.id,
).first()
project = db.query(Project).filter(Project.id == project_id).first()
if not project:
raise HTTPException(status_code=404, detail="Project not found")

View File

@@ -16,12 +16,12 @@ logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/projects", tags=["Projects"])
def _next_project_copy_name(db: Session, owner_user_id: int, source_name: str) -> str:
def _next_project_copy_name(db: Session, source_name: str) -> str:
base_name = f"{source_name} 副本"
existing_names = {
row[0]
for row in db.query(Project.name)
.filter(Project.owner_user_id == owner_user_id, Project.name.like(f"{base_name}%"))
.filter(Project.name.like(f"{base_name}%"))
.all()
}
if base_name not in existing_names:
@@ -76,7 +76,6 @@ def list_projects(
"""Retrieve a paginated list of projects."""
projects = (
db.query(Project)
.filter(Project.owner_user_id == current_user.id)
.offset(skip)
.limit(limit)
.all()
@@ -97,10 +96,7 @@ def get_project(
current_user: User = Depends(get_current_user),
) -> Project:
"""Retrieve a project by its ID."""
project = db.query(Project).filter(
Project.id == project_id,
Project.owner_user_id == current_user.id,
).first()
project = db.query(Project).filter(Project.id == project_id).first()
if not project:
raise HTTPException(status_code=404, detail="Project not found")
return _prepare_project_response(project)
@@ -119,16 +115,13 @@ def copy_project(
current_user: User = Depends(require_editor),
) -> Project:
"""Copy a project. Reset copies media/frame sequence; full also copies annotations and mask metadata."""
source = db.query(Project).filter(
Project.id == project_id,
Project.owner_user_id == current_user.id,
).first()
source = db.query(Project).filter(Project.id == project_id).first()
if not source:
raise HTTPException(status_code=404, detail="Project not found")
next_name = (payload.name or "").strip() if payload.name is not None else ""
if not next_name:
next_name = _next_project_copy_name(db, current_user.id, source.name)
next_name = _next_project_copy_name(db, source.name)
copied = Project(
name=next_name,
@@ -196,10 +189,7 @@ def update_project(
current_user: User = Depends(require_editor),
) -> Project:
"""Update project fields partially."""
project = db.query(Project).filter(
Project.id == project_id,
Project.owner_user_id == current_user.id,
).first()
project = db.query(Project).filter(Project.id == project_id).first()
if not project:
raise HTTPException(status_code=404, detail="Project not found")
@@ -227,10 +217,7 @@ def delete_project(
current_user: User = Depends(require_editor),
) -> None:
"""Delete a project and all related frames and annotations."""
project = db.query(Project).filter(
Project.id == project_id,
Project.owner_user_id == current_user.id,
).first()
project = db.query(Project).filter(Project.id == project_id).first()
if not project:
raise HTTPException(status_code=404, detail="Project not found")
@@ -255,10 +242,7 @@ def create_frame(
current_user: User = Depends(require_editor),
) -> Frame:
"""Register a new frame under a project."""
project = db.query(Project).filter(
Project.id == project_id,
Project.owner_user_id == current_user.id,
).first()
project = db.query(Project).filter(Project.id == project_id).first()
if not project:
raise HTTPException(status_code=404, detail="Project not found")
@@ -282,10 +266,7 @@ def list_frames(
current_user: User = Depends(get_current_user),
) -> List[Frame]:
"""Retrieve all frames belonging to a project."""
project = db.query(Project).filter(
Project.id == project_id,
Project.owner_user_id == current_user.id,
).first()
project = db.query(Project).filter(Project.id == project_id).first()
if not project:
raise HTTPException(status_code=404, detail="Project not found")
@@ -320,7 +301,6 @@ def get_frame(
.filter(
Frame.project_id == project_id,
Frame.id == frame_id,
Project.owner_user_id == current_user.id,
)
.first()
)

View File

@@ -33,13 +33,11 @@ def _now() -> datetime:
def _get_task_or_404(task_id: int, db: Session, current_user: User) -> ProcessingTask:
_ = current_user
task = (
db.query(ProcessingTask)
.outerjoin(Project, Project.id == ProcessingTask.project_id)
.filter(
ProcessingTask.id == task_id,
(ProcessingTask.project_id.is_(None)) | (Project.owner_user_id == current_user.id),
)
.filter(ProcessingTask.id == task_id)
.first()
)
if not task:
@@ -60,9 +58,8 @@ def list_tasks(
current_user: User = Depends(get_current_user),
) -> List[ProcessingTask]:
"""Return recent background processing tasks."""
query = db.query(ProcessingTask).outerjoin(Project, Project.id == ProcessingTask.project_id).filter(
(ProcessingTask.project_id.is_(None)) | (Project.owner_user_id == current_user.id)
)
_ = current_user
query = db.query(ProcessingTask).outerjoin(Project, Project.id == ProcessingTask.project_id)
if project_id is not None:
query = query.filter(ProcessingTask.project_id == project_id)
if status is not None:
@@ -130,10 +127,7 @@ def retry_task(
if previous.project_id is None:
raise HTTPException(status_code=400, detail="Task has no project_id")
project = db.query(Project).filter(
Project.id == previous.project_id,
Project.owner_user_id == current_user.id,
).first()
project = db.query(Project).filter(Project.id == previous.project_id).first()
if not project:
raise HTTPException(status_code=404, detail="Project not found")
is_propagation_task = previous.task_type == "propagate_masks"

View File

@@ -12,14 +12,14 @@ def test_admin_user_management_and_audit_logs(client, db_session):
})
assert created.status_code == 201
user_id = created.json()["id"]
assert created.json()["role"] == "annotator"
updated = client.patch(f"/api/admin/users/{user_id}", json={
"role": "viewer",
"password": "newsecret",
"is_active": False,
})
assert updated.status_code == 200
assert updated.json()["role"] == "viewer"
assert updated.json()["role"] == "annotator"
assert updated.json()["is_active"] == 0
users = client.get("/api/admin/users")
@@ -37,8 +37,41 @@ def test_admin_user_management_and_audit_logs(client, db_session):
assert "admin.user_deleted" in actions
def test_only_default_admin_role_is_supported(client, db_session):
extra_admin = client.post("/api/admin/users", json={
"username": "chief",
"password": "secret123",
"role": "admin",
"is_active": True,
})
assert extra_admin.status_code == 400
viewer = client.post("/api/admin/users", json={
"username": "observer",
"password": "secret123",
"role": "viewer",
"is_active": True,
})
assert viewer.status_code == 400
created = client.post("/api/admin/users", json={
"username": "doctor",
"password": "secret123",
"is_active": True,
})
assert created.status_code == 201
user_id = created.json()["id"]
assert created.json()["role"] == "annotator"
assert client.patch(f"/api/admin/users/{user_id}", json={"role": "admin"}).status_code == 400
assert client.patch(f"/api/admin/users/{user_id}", json={"role": "viewer"}).status_code == 400
admin_id = client.get("/api/auth/me").json()["id"]
assert client.patch(f"/api/admin/users/{admin_id}", json={"role": "annotator"}).status_code == 400
assert client.patch(f"/api/admin/users/{admin_id}", json={"username": "chief"}).status_code == 400
def test_admin_routes_require_admin_role(client, db_session):
user = User(username="viewer", password_hash=hash_password("secret123"), role="viewer", is_active=1)
user = User(username="doctor", password_hash=hash_password("secret123"), role="annotator", is_active=1)
db_session.add(user)
db_session.commit()
db_session.refresh(user)
@@ -51,7 +84,7 @@ def test_admin_routes_require_admin_role(client, db_session):
client.headers.update({"Authorization": original_auth})
def test_viewer_role_is_read_only_for_business_mutations(client, db_session):
def test_legacy_viewer_role_is_promoted_to_annotator(client, db_session):
project = client.post("/api/projects", json={"name": "Readonly Check"}).json()
user = User(username="readonly", password_hash=hash_password("secret123"), role="viewer", is_active=1)
db_session.add(user)
@@ -61,14 +94,14 @@ def test_viewer_role_is_read_only_for_business_mutations(client, db_session):
client.headers.update({"Authorization": f"Bearer {create_access_token(user)}"})
try:
assert client.get("/api/projects").status_code == 200
assert client.post("/api/projects", json={"name": "Nope"}).status_code == 403
assert client.patch(f"/api/projects/{project['id']}", json={"name": "Nope"}).status_code == 403
assert client.post("/api/ai/annotate", json={"project_id": project["id"]}).status_code == 403
assert client.post("/api/projects", json={"name": "Annotator Project"}).status_code == 201
assert client.patch(f"/api/projects/{project['id']}", json={"name": "Shared Edit"}).status_code == 200
assert client.get("/api/auth/me").json()["role"] == "annotator"
finally:
client.headers.update({"Authorization": original_auth})
def test_admin_cannot_delete_self_or_user_with_projects(client, db_session):
def test_admin_cannot_delete_self_but_can_delete_project_author(client, db_session):
me = client.get("/api/auth/me").json()
assert client.delete(f"/api/admin/users/{me['id']}").status_code == 400
@@ -80,7 +113,8 @@ def test_admin_cannot_delete_self_or_user_with_projects(client, db_session):
db_session.commit()
response = client.delete(f"/api/admin/users/{user.id}")
assert response.status_code == 409
assert response.status_code == 204
assert db_session.query(Project).filter(Project.name == "Owned").count() == 1
def test_demo_factory_reset_leaves_admin_and_parsed_demo_dicom(client, db_session, monkeypatch, tmp_path):

View File

@@ -186,7 +186,7 @@ def test_project_and_frame_404s(client):
assert client.get("/api/projects/999/frames/1").status_code == 404
def test_projects_are_scoped_to_authenticated_owner(client, db_session):
def test_projects_are_shared_between_authenticated_users(client, db_session):
owner_project = client.post("/api/projects", json={"name": "Owner Project"}).json()
other_user = User(
username="other",
@@ -203,14 +203,19 @@ def test_projects_are_scoped_to_authenticated_owner(client, db_session):
db_session.refresh(other_project)
listing = client.get("/api/projects")
assert [project["id"] for project in listing.json()] == [owner_project["id"]]
assert client.get(f"/api/projects/{other_project.id}").status_code == 404
assert {project["id"] for project in listing.json()} == {owner_project["id"], other_project.id}
assert client.get(f"/api/projects/{other_project.id}").status_code == 200
original_auth = client.headers["Authorization"]
client.headers.update({"Authorization": f"Bearer {create_access_token(other_user)}"})
try:
other_listing = client.get("/api/projects")
assert [project["id"] for project in other_listing.json()] == [other_project.id]
assert client.get(f"/api/projects/{owner_project['id']}").status_code == 404
assert {project["id"] for project in other_listing.json()} == {owner_project["id"], other_project.id}
assert client.get(f"/api/projects/{owner_project['id']}").status_code == 200
renamed = client.patch(f"/api/projects/{owner_project['id']}", json={"name": "Edited By Other"})
assert renamed.status_code == 200
assert renamed.json()["name"] == "Edited By Other"
finally:
client.headers.update({"Authorization": original_auth})
assert client.get(f"/api/projects/{owner_project['id']}").json()["name"] == "Edited By Other"