# Uploaded by: lucid-hf
# Commit message: CI: deploy Docker/PDM Space
# Commit: a65508e (verified)
from abc import ABC, abstractmethod
from pathlib import Path
import cv2
from PIL import Image
from ultralytics import YOLO
class BaseModel(ABC):
    """Abstract interface that every model wrapper in this module implements.

    All three methods are abstract, so the class cannot be instantiated
    directly; a concrete subclass must override each of them.
    """

    @abstractmethod
    def __init__(self, *args, **kwargs):
        """Set up the model; must be overridden by concrete subclasses."""

    @abstractmethod
    def predict_image(self, image):
        """Run prediction on a single image; must be overridden."""

    @abstractmethod
    def predict_video(self, video):
        """Run prediction on a video; must be overridden."""
class YOLOModel(BaseModel):
    """Ultralytics YOLO detector with image, video and per-frame helpers."""

    def __init__(self, model_path=None):
        """Load YOLO detection weights.

        Args:
            model_path: Path to a weights file. When ``None``, the weights are
                looked up at ``models/yolov8n.pt`` next to this source file.
        """
        if model_path is None:
            repo_root = Path(__file__).resolve().parent
            weights_path = repo_root / "models" / "yolov8n.pt"
        else:
            weights_path = Path(model_path)
        self.model = YOLO(str(weights_path), task="detect")

    def predict_image(self, image, min_confidence, classes=None, imgsz=800):
        """Detect objects in an image and return the annotated picture.

        The annotated image is also written to ``annotated_image.png`` in the
        current working directory.

        Args:
            image: Image source forwarded unchanged to ``YOLO.predict``.
            min_confidence: Minimum confidence threshold for detections.
            classes: List of class IDs to detect (None for all).
            imgsz: Inference image size (defaults to the previous fixed 800).

        Returns:
            PIL image (RGB) with the detections of the last result drawn on it.
        """
        results = self.model.predict(
            image, save=False, imgsz=imgsz, conf=min_confidence, classes=classes
        )
        annotated_image_filename = "annotated_image.png"

        if not results:
            # Legacy best-effort fallback: nothing was predicted, so hand back
            # whatever annotated file is already on disk. NOTE(review): this
            # may be stale output from an earlier call, or raise if absent.
            return Image.open(annotated_image_filename)

        # Only the last result is returned, so plot and save just that one
        # instead of rewriting the PNG for every result.
        bgr_plot = results[-1].plot()
        annotated = Image.fromarray(bgr_plot[..., ::-1])  # BGR -> RGB PIL image
        annotated.save(annotated_image_filename)
        # Return PIL Image for robust display in Streamlit
        return annotated

    def predict_video(
        self,
        video,
        min_confidence,
        target_dir_name="annotated_video",
        classes=None,
        imgsz=800,
    ):
        """Run detection over a video and let Ultralytics save the output.

        The annotated output is saved by Ultralytics under
        ``./<target_dir_name>`` (``exist_ok=True``, so the directory is reused).

        Args:
            video: Video source forwarded unchanged to ``YOLO.predict``.
            min_confidence: Minimum confidence threshold for detections.
            target_dir_name: Name of the output directory inside ``.``.
            classes: List of class IDs to detect (None for all).
            imgsz: Inference image size (defaults to the previous fixed 800).

        Returns:
            None. The result is the saved output on disk.
        """
        frame_results = self.model.predict(
            video,
            save=True,
            project=".",
            name=target_dir_name,
            exist_ok=True,
            imgsz=imgsz,
            conf=min_confidence,
            classes=classes,
            # Stream results so per-frame outputs are not accumulated in RAM
            # for the whole video.
            stream=True,
        )
        # With stream=True inference is lazy; drain the generator so every
        # frame is actually processed and written.
        for _ in frame_results:
            pass

    @staticmethod
    def draw_yolo_dets(frame_bgr, result, show_score=True):
        """Draw YOLO detection results on a frame.

        Args:
            frame_bgr: Input frame in BGR format (not modified; a copy is drawn on)
            result: YOLO detection result object
            show_score: Whether to show confidence scores

        Returns:
            Annotated frame with bounding boxes and labels
        """
        out = frame_bgr.copy()
        boxes = getattr(result, "boxes", None)
        if boxes is None:
            return out

        names = result.names
        # Move tensors to host memory first: Tensor.numpy() raises for CUDA
        # tensors, while .cpu() is a no-op for tensors already on the CPU.
        cls_ids = boxes.cls.cpu().numpy().astype(int)
        confs = boxes.conf.cpu().numpy()
        xyxy = boxes.xyxy.cpu().numpy()

        box_color = (0, 255, 0)
        font = cv2.FONT_HERSHEY_SIMPLEX
        font_scale = 0.6
        thickness = 2

        for (x1_coord, y1_coord, x2_coord, y2_coord), cls, score in zip(
            xyxy, cls_ids, confs
        ):
            x1, y1, x2, y2 = map(int, (x1_coord, y1_coord, x2_coord, y2_coord))
            label = names.get(int(cls), str(int(cls)))
            if show_score:
                label = f"{label} {score:.2f}"

            cv2.rectangle(out, (x1, y1), (x2, y2), box_color, thickness)

            (tw, th), _ = cv2.getTextSize(label, font, font_scale, thickness)
            # Keep the label strip inside the frame: when the box touches the
            # top edge the strip is shifted down instead of being clipped.
            # For boxes away from the top this equals (y1 - th - 8, y1).
            label_top = max(0, y1 - th - 8)
            label_bottom = label_top + th + 8
            cv2.rectangle(
                out, (x1, label_top), (x1 + tw + 6, label_bottom), box_color, -1
            )
            cv2.putText(
                out,
                label,
                (x1 + 3, label_bottom - 6),
                font,
                font_scale,
                (0, 0, 0),
                thickness,
                cv2.LINE_AA,
            )
        return out

    def predict_and_visualize(
        self, frame, min_confidence, classes=None, show_score=True, imgsz=640
    ):
        """Predict objects in a frame and return the annotated frame.

        Any exception raised during prediction or drawing is reported with
        ``print`` and the un-annotated frame is returned, so one bad frame
        does not propagate an error to the caller.

        Args:
            frame: Input frame (BGR format)
            min_confidence: Minimum confidence threshold
            classes: List of class IDs to detect (None for all)
            show_score: Whether to show confidence scores
            imgsz: Image size for inference

        Returns:
            tuple: (results, annotated_frame)
                - results: YOLO detection results, or ``[]`` when prediction
                  returned nothing or failed
                - annotated_frame: Frame with bounding boxes drawn, or a copy
                  of the input frame in the ``[]`` case
        """
        try:
            results = self.model.predict(
                frame,
                conf=min_confidence,
                iou=0.5,
                verbose=False,
                classes=classes,
                imgsz=imgsz,
            )
            if not results:
                # Predictor produced no result objects at all.
                return [], frame.copy()

            annotated_frame = YOLOModel.draw_yolo_dets(
                frame, results[0], show_score=show_score
            )
            return results, annotated_frame
        except Exception as e:
            print(f"Error in YOLO prediction: {e}")
            # Return original frame on error
            return [], frame.copy()