Spaces:
Sleeping
Sleeping
| from abc import ABC, abstractmethod | |
| from pathlib import Path | |
| import cv2 | |
| from PIL import Image | |
| from ultralytics import YOLO | |
class BaseModel(ABC):
    """Common interface for detection model wrappers.

    Every hook here is a no-op placeholder: it accepts its arguments and
    returns None. Concrete wrappers (e.g. a YOLO-backed model) are expected
    to override them with real inference code.
    """

    def __init__(self, *args, **kwargs):
        """Accept and ignore any construction arguments."""

    def predict_image(self, image):
        """Placeholder for single-image inference; always returns None."""
        return None

    def predict_video(self, video):
        """Placeholder for video inference; always returns None."""
        return None
class YOLOModel(BaseModel):
    """Ultralytics YOLO detector wrapper for images, videos and single frames."""

    def __init__(self, model_path=None):
        """Load YOLO detection weights.

        Args:
            model_path: Path to a weights file. When None, ``models/yolov8n.pt``
                located next to this source file is used.
        """
        super().__init__()
        if model_path is None:
            repo_root = Path(__file__).resolve().parent
            weights_path = repo_root / "models" / "yolov8n.pt"
        else:
            weights_path = Path(model_path)
        self.model = YOLO(str(weights_path), task="detect")

    def predict_image(self, image, min_confidence, classes=None, imgsz=800):
        """Run detection on an image and return the annotated result.

        As a side effect, the annotated image is written to
        ``annotated_image.png`` in the current working directory.

        Args:
            image: Any source accepted by ``YOLO.predict``.
            min_confidence: Minimum confidence threshold (``conf``).
            classes: Optional list of class IDs to keep; None keeps all.
            imgsz: Inference image size (defaults to 800).

        Returns:
            PIL.Image.Image: RGB image with boxes and labels drawn, taken from
            the last result returned by ``predict``.
        """
        results = self.model.predict(
            image, save=False, imgsz=imgsz, conf=min_confidence, classes=classes
        )
        annotated_image_filename = "annotated_image.png"
        last_im = None
        for result in results:
            # result.plot() yields a BGR array; reverse the channel axis so
            # PIL receives RGB.
            last_im = Image.fromarray(result.plot()[..., ::-1])
        if last_im is None:
            # NOTE(review): with no results this reopens whatever file a
            # previous call left behind (or raises FileNotFoundError);
            # kept for backward compatibility.
            return Image.open(annotated_image_filename)
        # Save once (previously re-saved on every iteration; the final file
        # content is the same: the last result's image).
        last_im.save(annotated_image_filename)
        # Return PIL Image for robust display in Streamlit
        return last_im

    def predict_video(
        self,
        video,
        min_confidence,
        target_dir_name="annotated_video",
        classes=None,
        imgsz=800,
    ):
        """Run detection on a video and let Ultralytics save the annotated output.

        The output is written under ``./<target_dir_name>/`` (``project="."``,
        ``exist_ok=True`` so reruns overwrite the same directory).

        Args:
            video: Video source accepted by ``YOLO.predict``.
            min_confidence: Minimum confidence threshold (``conf``).
            target_dir_name: Name of the output directory under ``.``.
            classes: Optional list of class IDs to keep; None keeps all.
            imgsz: Inference image size (defaults to 800).
        """
        # stream=True returns a generator. Draining it runs inference on every
        # frame and writes the annotated video exactly as before, but without
        # accumulating each frame's Results object in memory.
        for _ in self.model.predict(
            video,
            save=True,
            project=".",
            name=target_dir_name,
            exist_ok=True,
            imgsz=imgsz,
            conf=min_confidence,
            classes=classes,
            stream=True,
        ):
            pass

    @staticmethod
    def _to_numpy(values):
        """Return ``values`` as a host-side NumPy array.

        Tensors exposing ``.cpu()`` are moved to the CPU first, because
        ``.numpy()`` fails on CUDA tensors. Objects with neither ``.cpu()`` nor
        ``.numpy()`` (e.g. arrays that are already NumPy) are returned unchanged.
        """
        if hasattr(values, "cpu"):
            values = values.cpu()
        if hasattr(values, "numpy"):
            values = values.numpy()
        return values

    @staticmethod
    def draw_yolo_dets(frame_bgr, result, show_score=True):
        """Draw YOLO detection results on a frame.

        Args:
            frame_bgr: Input frame in BGR format (not modified; a copy is drawn on).
            result: YOLO detection result object.
            show_score: Whether to show confidence scores.

        Returns:
            Annotated copy of the frame with bounding boxes and labels.
        """
        out = frame_bgr.copy()
        boxes = getattr(result, "boxes", None)
        if boxes is None:
            return out
        names = result.names
        cls_ids = YOLOModel._to_numpy(boxes.cls).astype(int)
        confs = YOLOModel._to_numpy(boxes.conf)
        xyxy = YOLOModel._to_numpy(boxes.xyxy)
        for (x1_coord, y1_coord, x2_coord, y2_coord), cls, score in zip(
            xyxy, cls_ids, confs
        ):
            x1, y1, x2, y2 = map(int, (x1_coord, y1_coord, x2_coord, y2_coord))
            label = names.get(int(cls), str(int(cls)))
            if show_score:
                label = f"{label} {score:.2f}"
            cv2.rectangle(out, (x1, y1), (x2, y2), (0, 255, 0), 2)
            (tw, th), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.6, 2)
            label_bottom = y1
            if y1 - th - 8 < 0:
                # No room above the box (it touches the top edge): draw the
                # label just inside the box so it is not clipped off-frame.
                label_bottom = y1 + th + 8
            cv2.rectangle(
                out,
                (x1, label_bottom - th - 8),
                (x1 + tw + 6, label_bottom),
                (0, 255, 0),
                -1,
            )
            cv2.putText(
                out,
                label,
                (x1 + 3, label_bottom - 6),
                cv2.FONT_HERSHEY_SIMPLEX,
                0.6,
                (0, 0, 0),
                2,
                cv2.LINE_AA,
            )
        return out

    def predict_and_visualize(
        self, frame, min_confidence, classes=None, show_score=True, imgsz=640
    ):
        """Predict objects in a frame and return the annotated frame.

        Failures are deliberately non-fatal (suited to per-frame loops): any
        exception is printed and the original frame is returned.

        Args:
            frame: Input frame (BGR format).
            min_confidence: Minimum confidence threshold.
            classes: List of class IDs to detect (None for all).
            show_score: Whether to show confidence scores.
            imgsz: Image size for inference.

        Returns:
            tuple: (results, annotated_frame)
                - results: YOLO detection results ([] on error or empty output)
                - annotated_frame: Frame with bounding boxes drawn
        """
        try:
            results = self.model.predict(
                frame,
                conf=min_confidence,
                iou=0.5,
                verbose=False,
                classes=classes,
                imgsz=imgsz,
            )
            if results:
                annotated_frame = self.draw_yolo_dets(
                    frame, results[0], show_score=show_score
                )
                return results, annotated_frame
            # predict returned no Results object at all: return an unmodified copy.
            return [], frame.copy()
        except Exception as e:
            print(f"Error in YOLO prediction: {e}")
            # Return original frame on error
            return [], frame.copy()