"""Unit tests for ``services.sam2_engine.SAM2Engine``.

The real SAM2 predictor is replaced by ``_FakePredictor`` so these tests run
without model weights and can inspect the keyword arguments the engine passes
to ``predict``.
"""

import numpy as np

from services.sam2_engine import DEFAULT_SAM2_MODEL_ID, SAM2Engine


class _FakePredictor:
    """Stand-in predictor that returns canned masks and scores.

    Every call to ``predict`` appends its keyword arguments to ``calls`` so a
    test can assert how the engine invoked the predictor.
    """

    def __init__(self, masks, scores):
        self.masks = masks
        self.scores = scores
        self.calls = []

    def set_image(self, _image):
        # The image is ignored: outputs are fixed at construction time.
        pass

    def predict(self, **kwargs):
        self.calls.append(kwargs)
        # Third return value is a placeholder; the engine under test must
        # tolerate ``None`` in that slot.
        return self.masks, self.scores, None


def _mask(offset=0):
    """Return a 32x32 uint8 mask containing one rectangle of ones.

    ``offset`` shifts the rectangle down and right by that many pixels, so
    different offsets yield distinct (overlapping) masks.
    """
    mask = np.zeros((32, 32), dtype=np.uint8)
    mask[4 + offset:20 + offset, 5 + offset:22 + offset] = 1
    return mask


def _ready_engine(monkeypatch, predictor):
    """Build a ``SAM2Engine`` whose default model is pre-"loaded" with ``predictor``.

    ``SAM2_AVAILABLE`` is forced to ``True`` and the engine's internal
    loaded-model and predictor maps are populated directly, bypassing any
    real checkpoint loading.
    """
    monkeypatch.setattr("services.sam2_engine.SAM2_AVAILABLE", True)
    engine = SAM2Engine()
    engine._model_loaded[DEFAULT_SAM2_MODEL_ID] = True
    engine._predictors[DEFAULT_SAM2_MODEL_ID] = predictor
    return engine


def test_sam2_point_prediction_requests_single_best_mask(monkeypatch):
    """Point prompts ask the predictor for one mask and return its score."""
    predictor = _FakePredictor(
        np.array([_mask()], dtype=np.uint8),
        np.array([0.92], dtype=np.float32),
    )
    engine = _ready_engine(monkeypatch, predictor)

    polygons, scores = engine.predict_points(
        DEFAULT_SAM2_MODEL_ID,
        np.zeros((32, 32, 3), dtype=np.uint8),
        [[0.5, 0.5]],
        [1],
    )

    assert predictor.calls[0]["multimask_output"] is False
    assert len(polygons) == 1
    # The score is the float32 value widened to a Python float, so compare
    # against the same round-trip rather than a magic literal.
    assert scores == [float(np.float32(0.92))]


def test_sam2_auto_prediction_keeps_single_best_mask(monkeypatch):
    """Automatic prediction keeps only the first/highest-scoring mask."""
    predictor = _FakePredictor(
        np.array([_mask(0), _mask(2), _mask(4)], dtype=np.uint8),
        np.array([0.8, 0.7, 0.6], dtype=np.float32),
    )
    engine = _ready_engine(monkeypatch, predictor)

    polygons, scores = engine.predict_auto(
        DEFAULT_SAM2_MODEL_ID,
        np.zeros((32, 32, 3), dtype=np.uint8),
    )

    assert predictor.calls[0]["multimask_output"] is False
    assert len(polygons) == 1
    assert scores == [float(np.float32(0.8))]


def test_sam2_status_exposes_selectable_variants(monkeypatch, tmp_path):
    """``status`` reports metadata for a non-default variant found beside the configured checkpoint."""
    checkpoint = tmp_path / "sam2.1_hiera_small.pt"
    checkpoint.write_bytes(b"model")
    # Point the configured model path at a (non-existent) tiny checkpoint in
    # the same directory as the small checkpoint written above.
    monkeypatch.setattr(
        "services.sam2_engine.settings.sam_model_path",
        str(tmp_path / "sam2.1_hiera_tiny.pt"),
    )
    engine = SAM2Engine()

    status = engine.status("sam2.1_hiera_small")

    assert engine.normalize_model_id("sam2") == DEFAULT_SAM2_MODEL_ID
    assert "sam2.1_hiera_small" in engine.variant_ids()
    assert status["id"] == "sam2.1_hiera_small"
    assert status["label"] == "SAM 2.1 Small"
    assert status["checkpoint_exists"] is True
    assert status["checkpoint_path"].endswith("sam2.1_hiera_small.pt")