Files
ISISeg/backend/segmentation.py

233 lines
8.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from __future__ import annotations
from dataclasses import dataclass
from typing import Callable
import cv2
import numpy as np
from skimage.filters import frangi, threshold_otsu
from skimage.morphology import remove_small_objects, skeletonize
# Display metadata for every method name the backend knows about.
# Each entry carries a UI "label", a longer "description", and a
# "uses_temporal" flag marking methods that make use of the previous frame.
METHOD_DESCRIPTIONS = {
    # Hessian / Frangi thin-line enhancement: multi-scale Hessian tubular
    # response, suited to extracting low-contrast thin guidewire candidates.
    "hessian_ridge": dict(
        label="Hessian / Frangi 细线增强",
        description="多尺度 Hessian 管状结构响应,适合低对比细导丝候选提取。",
        uses_temporal=False,
    ),
    # Edges + morphology: CLAHE, black-hat enhancement, Canny edges and
    # linear morphological linking.
    "edge_morphology": dict(
        label="边缘 + 形态学",
        description="CLAHE、黑帽增强、Canny 边缘与线性形态学连接。",
        uses_temporal=False,
    ),
    # Video temporal differencing: motion candidates between adjacent frames
    # suppress static bone and background structures.
    "temporal_difference": dict(
        label="视频时序差分",
        description="利用相邻帧运动候选抑制静态骨骼和背景结构。",
        uses_temporal=True,
    ),
    # Fusion mode: combines Hessian, edge-morphology and temporal differencing
    # as the default robust output.
    "fusion": dict(
        label="融合模式",
        description="融合 Hessian、边缘形态学和时序差分作为默认稳健输出。",
        uses_temporal=True,
    ),
    # Multi-method comparison: runs several methods on the same frame for
    # parameter tuning and side-by-side comparison.
    "compare": dict(
        label="多方法对比",
        description="对同一帧同时运行多种方法,便于调参和方案比较。",
        uses_temporal=True,
    ),
}
@dataclass(frozen=True)
class SegmentationOutput:
    """Immutable bundle holding one segmentation result.

    Instances are built by ``segment_frame``. ``frozen=True`` prevents
    rebinding the attributes, but the contained arrays themselves remain
    mutable.
    """

    # Method name the result was produced with (e.g. "fusion").
    method: str
    # Binary mask returned by the chosen segmenter.
    mask: np.ndarray
    # Visualisation of ``mask`` blended onto the input frame.
    overlay: np.ndarray
    # Summary statistics of ``mask`` (coverage, pixel counts, ...).
    metrics: dict[str, float | int]
def normalize01(
    image: np.ndarray,
    low_percentile: float = 1.0,
    high_percentile: float = 99.0,
) -> np.ndarray:
    """Robustly rescale *image* to float32 values in ``[0, 1]``.

    The ``low_percentile`` intensity maps to 0 and the ``high_percentile``
    intensity maps to 1. Everything outside that window is clipped, so a few
    extreme pixels cannot compress the useful range.

    Args:
        image: Numeric array of any shape.
        low_percentile: Percentile mapped to 0 (default 1, as before).
        high_percentile: Percentile mapped to 1 (default 99, as before).

    Returns:
        A float32 array with the same shape as ``image``. An empty input
        gives an empty array. An input whose percentile window collapses
        (e.g. a constant image) gives all zeros.
    """
    image = np.asarray(image, dtype=np.float32)
    if image.size == 0:
        # np.percentile raises on empty input; there is nothing to rescale.
        return np.zeros_like(image)
    low, high = (float(v) for v in np.percentile(image, [low_percentile, high_percentile]))
    if high <= low:
        return np.zeros_like(image, dtype=np.float32)
    return np.clip((image - low) / (high - low), 0.0, 1.0)
def to_gray(frame: np.ndarray) -> np.ndarray:
    """Return a single-channel ``(H, W)`` version of *frame*.

    - 2-D frames are returned unchanged (same object).
    - Frames stored as ``(H, W, 1)`` are squeezed to ``(H, W)``.
      ``cv2.cvtColor`` with ``COLOR_BGR2GRAY`` rejects one-channel input,
      so these are not sent through it.
    - Any other frame is converted with ``cv2.COLOR_BGR2GRAY``, i.e. it is
      treated as OpenCV BGR-ordered colour.
    """
    if frame.ndim == 2:
        return frame
    if frame.ndim == 3 and frame.shape[2] == 1:
        return frame[:, :, 0]
    return cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
def clahe_gray(frame: np.ndarray) -> np.ndarray:
    """Convert *frame* to grayscale and apply CLAHE local contrast equalisation.

    A new CLAHE operator (clip limit 2.2, 8x8 tile grid) is created per call.
    NOTE(review): OpenCV's CLAHE expects 8- or 16-bit single-channel input;
    confirm the dtype of frames reaching this function.
    """
    equalizer = cv2.createCLAHE(clipLimit=2.2, tileGridSize=(8, 8))
    return equalizer.apply(to_gray(frame))
def _adaptive_cutoff(response: np.ndarray, sensitivity: float) -> float:
    """Choose a threshold for *response* from its strictly positive values.

    How the threshold is chosen:

    - The response is first rescaled with ``normalize01``. Only values > 0
      take part.
    - ``sensitivity`` is clamped to ``[0.05, 0.95]`` and selects the
      percentile ``99.2 - 22 * sensitivity``. Higher sensitivity means a
      lower percentile, so more pixels pass.
    - The result is that percentile value (capped at 0.98), unless 72 % of
      the Otsu threshold of the same values is larger.

    Returns ``1.0`` when fewer than 16 positive values exist. Callers compare
    ``response >= cutoff``, so pixels exactly at 1.0 still pass.

    NOTE(review): the cutoff lives on the ``normalize01`` scale. Every caller
    in this module passes a response that is already ``normalize01`` output.
    """
    scaled = normalize01(response)
    positives = scaled[scaled > 0]
    if positives.size < 16:
        return 1.0

    clamped = float(np.clip(sensitivity, 0.05, 0.95))
    percentile_cut = float(np.percentile(positives, 99.2 - clamped * 22.0))

    try:
        otsu_cut = float(threshold_otsu(positives))
    except ValueError:
        # Fall back to the percentile cut if Otsu cannot be computed.
        otsu_cut = percentile_cut

    capped = min(percentile_cut, 0.98)
    return max(capped, otsu_cut * 0.72)
def clean_mask(mask: np.ndarray, min_area: int = 12) -> np.ndarray:
    """Remove tiny connected blobs from *mask*, then bridge 1-pixel gaps.

    Args:
        mask: Array whose nonzero entries mark foreground.
        min_area: Forwarded (floored at 1) to ``remove_small_objects`` as
            ``max_size``.

    Returns:
        uint8 mask with values 0/255, after one morphological closing with
        a 3x3 ellipse.
    """
    size_limit = max(1, int(min_area))
    # NOTE(review): if scikit-image's ``max_size`` removes objects of that
    # size *or smaller* (as its docs describe), a blob of exactly
    # ``min_area`` pixels is dropped. Confirm this matches what
    # ``min_area`` is meant to express.
    kept = remove_small_objects(mask.astype(bool), max_size=size_limit)
    ellipse = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
    return cv2.morphologyEx(kept.astype(np.uint8) * 255, cv2.MORPH_CLOSE, ellipse, iterations=1)
def hessian_ridge_mask(frame: np.ndarray, sensitivity: float = 0.56) -> np.ndarray:
    """Segment thin dark line structures with a multi-scale Frangi filter.

    Pipeline:

    1. Enhance the frame with ``clahe_gray``.
    2. Invert it so dark lines become bright ridges, then normalise.
    3. Filter with ``frangi(black_ridges=False)`` at sigmas 0.7-2.3.
    4. Normalise the response again and threshold it with
       ``_adaptive_cutoff``.
    5. Clean the result with ``clean_mask(min_area=10)``.

    Args:
        frame: Grayscale or BGR frame accepted by ``clahe_gray``.
        sensitivity: Higher values lower the threshold (see
            ``_adaptive_cutoff``).

    Returns:
        uint8 mask with values 0/255.
    """
    enhanced = clahe_gray(frame)
    # Invert against the dtype's full range. For uint8 this is exactly
    # ``255 - enhanced``. For 16-bit frames a hard-coded 255 would wrap
    # around for every pixel brighter than 255.
    inverted = np.iinfo(enhanced.dtype).max - enhanced
    normalized = normalize01(inverted)
    response = frangi(
        normalized,
        sigmas=(0.7, 1.1, 1.7, 2.3),
        alpha=0.55,
        beta=0.55,
        gamma=12,
        black_ridges=False,
    )
    response = normalize01(response)
    cutoff = _adaptive_cutoff(response, sensitivity)
    mask = response >= cutoff
    return clean_mask(mask, min_area=10)
def edge_morphology_mask(frame: np.ndarray, sensitivity: float = 0.56) -> np.ndarray:
    """Segment dark thin lines by combining a black-hat response with Canny edges.

    Pipeline:

    1. Build a candidate mask:
       - blur the CLAHE-enhanced frame with a 3x3 Gaussian;
       - apply a 15x15 black-hat;
       - threshold the normalised result with ``_adaptive_cutoff``, using
         sensitivity + 0.1 capped at 0.95.
    2. Run Canny on the blurred frame. Its hysteresis thresholds shrink as
       ``sensitivity`` grows. Keep only edges inside the dilated candidate.
    3. Close the remaining edges with 5-pixel horizontal and vertical line
       kernels.
    4. OR them with a slightly eroded candidate mask.
    5. Clean the result with ``clean_mask(min_area=8)``.

    Returns:
        uint8 mask with values 0/255.
    """
    smoothed = cv2.GaussianBlur(clahe_gray(frame), (3, 3), 0)

    box15 = cv2.getStructuringElement(cv2.MORPH_RECT, (15, 15))
    dark_line = normalize01(cv2.morphologyEx(smoothed, cv2.MORPH_BLACKHAT, box15))
    threshold = _adaptive_cutoff(dark_line, min(0.95, sensitivity + 0.1))
    candidate = (dark_line >= threshold).astype(np.uint8) * 255

    slack = 1.0 - sensitivity
    edges = cv2.Canny(smoothed, int(20 + slack * 65), int(70 + slack * 120))

    ellipse3 = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
    candidate = cv2.dilate(candidate, ellipse3, iterations=1)
    edges = cv2.bitwise_and(edges, candidate)

    # (5, 1) is a horizontal line kernel, (1, 5) a vertical one.
    horizontal = cv2.getStructuringElement(cv2.MORPH_RECT, (5, 1))
    vertical = cv2.getStructuringElement(cv2.MORPH_RECT, (1, 5))
    joined = cv2.morphologyEx(edges, cv2.MORPH_CLOSE, horizontal, iterations=1)
    joined = cv2.morphologyEx(joined, cv2.MORPH_CLOSE, vertical, iterations=1)

    core = cv2.erode(candidate, cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (2, 2)))
    return clean_mask(cv2.bitwise_or(joined, core), min_area=8)
def _gate_ridge_by_motion(
    ridge: np.ndarray,
    frame: np.ndarray,
    previous_frame: np.ndarray | None,
    sensitivity: float,
) -> np.ndarray:
    """Restrict a precomputed ridge mask to pixels near inter-frame motion.

    ``ridge`` is returned as-is when there is no previous frame. Otherwise
    the motion gating is applied. If fewer than 8 gated pixels remain, the
    union of ridge and motion masks is used instead. The result is cleaned
    with ``clean_mask(min_area=8)``.
    """
    if previous_frame is None:
        return ridge
    current = cv2.GaussianBlur(clahe_gray(frame), (5, 5), 0)
    previous = cv2.GaussianBlur(clahe_gray(previous_frame), (5, 5), 0)
    diff = normalize01(cv2.absdiff(current, previous))
    cutoff = _adaptive_cutoff(diff, min(0.92, sensitivity + 0.12))
    moving = (diff >= cutoff).astype(np.uint8) * 255
    moving = cv2.dilate(moving, cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3)), iterations=1)
    # Dilation only grows a 0/255 mask, so (ridge & moving) is a subset of
    # (ridge & dilate(moving)). Their OR, as previously written, is just
    # the latter. ``None`` selects OpenCV's default 3x3 rectangular kernel.
    blended = cv2.bitwise_and(ridge, cv2.dilate(moving, None))
    if int(np.count_nonzero(blended)) < 8:
        # Too little overlap: fall back to ridge OR motion.
        blended = cv2.bitwise_or(ridge, moving)
    return clean_mask(blended, min_area=8)


def temporal_difference_mask(
    frame: np.ndarray,
    previous_frame: np.ndarray | None,
    sensitivity: float = 0.56,
) -> np.ndarray:
    """Return the Hessian ridge mask gated by motion against *previous_frame*.

    Args:
        frame: Current frame.
        previous_frame: Preceding frame, or ``None``. With ``None``, the plain
            ``hessian_ridge_mask`` result is returned.
        sensitivity: Passed to the ridge detector. Raised by 0.12 (capped at
            0.92) for the motion threshold.

    Returns:
        uint8 mask with values 0/255.
    """
    ridge = hessian_ridge_mask(frame, sensitivity=sensitivity)
    return _gate_ridge_by_motion(ridge, frame, previous_frame, sensitivity)


def fusion_mask(
    frame: np.ndarray,
    previous_frame: np.ndarray | None = None,
    sensitivity: float = 0.56,
) -> np.ndarray:
    """Majority-vote fusion of the ridge, edge-morphology and temporal masks.

    A pixel is kept when at least two of the three masks mark it. If that
    leaves fewer than 8 pixels, any single vote suffices. The Frangi ridge
    mask is computed once and reused for the temporal vote, rather than
    being recomputed by ``temporal_difference_mask``.

    With ``previous_frame=None`` the temporal vote equals the ridge vote, so
    the two-vote rule reduces to the ridge mask itself, unless that mask has
    fewer than 8 pixels.

    Returns:
        uint8 mask with values 0/255 (after ``clean_mask(min_area=10)``).
    """
    ridge = hessian_ridge_mask(frame, sensitivity=sensitivity)
    edge = edge_morphology_mask(frame, sensitivity=sensitivity)
    temporal = _gate_ridge_by_motion(ridge, frame, previous_frame, sensitivity)
    votes = (
        (ridge > 0).astype(np.uint8)
        + (edge > 0).astype(np.uint8)
        + (temporal > 0).astype(np.uint8)
    )
    fused = votes >= 2
    if int(np.count_nonzero(fused)) < 8:
        fused = votes >= 1
    return clean_mask(fused, min_area=10)
def overlay_mask(frame: np.ndarray, mask: np.ndarray, color: tuple[int, int, int] = (0, 220, 255)) -> np.ndarray:
    """Blend *mask* onto *frame* as a coloured overlay.

    Grayscale frames are promoted to BGR; other frames are copied. Pixels
    where ``mask > 0`` get ``color`` in a separate layer. The result is
    ``cv2.addWeighted(base, 0.78, layer, 0.72, 0)``. The weights sum to more
    than 1, so masked pixels are brightened (OpenCV saturates the result)
    and every unmasked pixel is scaled by 0.78.
    """
    base = cv2.cvtColor(frame, cv2.COLOR_GRAY2BGR) if frame.ndim == 2 else frame.copy()
    layer = np.zeros_like(base)
    layer[mask > 0] = color
    return cv2.addWeighted(base, 0.78, layer, 0.72, 0)
def mask_metrics(mask: np.ndarray) -> dict[str, float | int]:
    """Summarise a mask with simple size and topology statistics.

    Returns a dict with:
        coverage: fraction of pixels with ``mask > 0``, rounded to 6 places
            (0.0 for an empty array).
        mask_pixels: number of foreground pixels.
        skeleton_length: pixel count of the skimage skeleton of the mask.
        components: number of foreground connected components, using
            ``cv2.connectedComponents`` with its default connectivity
            (background label excluded).
    """
    foreground = mask > 0
    pixel_count = int(np.count_nonzero(foreground))
    total = foreground.size
    coverage = float(pixel_count / total) if total else 0.0
    skeleton_pixels = int(np.count_nonzero(skeletonize(foreground)))
    label_count, _ = cv2.connectedComponents(foreground.astype(np.uint8))
    return {
        "coverage": round(coverage, 6),
        "mask_pixels": pixel_count,
        "skeleton_length": skeleton_pixels,
        "components": max(0, int(label_count) - 1),
    }
def segment_frame(
    frame: np.ndarray,
    method: str = "fusion",
    previous_frame: np.ndarray | None = None,
    sensitivity: float = 0.56,
) -> SegmentationOutput:
    """Run one segmentation *method* on *frame* and package the result.

    Args:
        frame: Frame to segment.
        method: One of "hessian_ridge", "edge_morphology",
            "temporal_difference" or "fusion".
        previous_frame: Preceding frame. Forwarded only to the two methods
            that accept it ("temporal_difference" and "fusion").
        sensitivity: Forwarded to the chosen method.

    Returns:
        A ``SegmentationOutput`` with the mask, its overlay and its metrics.

    Raises:
        ValueError: If *method* is not one of the four names above.
            "compare" is served by ``compare_frame``, not here.
    """
    single_frame: dict[str, Callable[..., np.ndarray]] = {
        "hessian_ridge": hessian_ridge_mask,
        "edge_morphology": edge_morphology_mask,
    }
    two_frame: dict[str, Callable[..., np.ndarray]] = {
        "temporal_difference": temporal_difference_mask,
        "fusion": fusion_mask,
    }
    if method in two_frame:
        mask = two_frame[method](frame, previous_frame, sensitivity)
    elif method in single_frame:
        mask = single_frame[method](frame, sensitivity)
    else:
        raise ValueError(f"Unknown segmentation method: {method}")

    return SegmentationOutput(
        method=method,
        mask=mask,
        overlay=overlay_mask(frame, mask),
        metrics=mask_metrics(mask),
    )
def compare_frame(
    frame: np.ndarray,
    previous_frame: np.ndarray | None = None,
    sensitivity: float = 0.56,
) -> list[SegmentationOutput]:
    """Run every single segmentation method on *frame* for side-by-side review.

    Results come back in the order hessian_ridge, edge_morphology,
    temporal_difference, fusion. Each is produced by an independent
    ``segment_frame`` call, so no intermediate results are shared between
    methods.
    """
    methods = ("hessian_ridge", "edge_morphology", "temporal_difference", "fusion")
    return [segment_frame(frame, name, previous_frame, sensitivity) for name in methods]