|
|
import os
import shutil
import sys
import tempfile
import traceback

import cv2
import gradio as gr
import numpy as np
import pandas as pd
from huggingface_hub import hf_hub_download
from PIL import Image, ImageDraw, ImageFont
|
|
|
|
|
# ---------------------------------------------------------------------------
# Model bootstrap (runs at import time).
#
# Downloads inference.py and the checkpoint from the Hugging Face Hub, patches
# inference.py to instantiate RFDETRLarge instead of RFDETRBase, makes the
# download directory importable, and builds the module-level `detector` that
# the Gradio callbacks below use.
# ---------------------------------------------------------------------------
print("=" * 60)
print("Setting up RF-DETR SoccerNet Model...")
print("=" * 60)

# Hub repository that hosts both inference.py and the weights file.
repo_id = "julianzu9612/RFDETR-Soccernet"

# Captured before anything can fail so the `finally` below can always restore
# the working directory, even when setup aborts half-way.
original_dir = os.getcwd()

try:
    print("\nDownloading inference.py...")
    inference_path = hf_hub_download(repo_id=repo_id, filename="inference.py")

    with open(inference_path, "r", encoding="utf-8") as f:
        inference_code = f.read()

    print("\n🔧 Patching inference.py...")
    print(" Changing: RFDETRBase() → RFDETRLarge()")

    # Plain text substitution: both the import and the constructor call must
    # be switched for the patched module to load.
    inference_code = inference_code.replace(
        "from rfdetr import RFDETRBase",
        "from rfdetr import RFDETRLarge",
    )
    inference_code = inference_code.replace(
        "self.model = RFDETRBase()",
        "self.model = RFDETRLarge()",
    )

    # NOTE(review): this overwrites the file at the path returned by
    # hf_hub_download in place; confirm that mutating the hub cache is
    # acceptable for this deployment.
    with open(inference_path, "w", encoding="utf-8") as f:
        f.write(inference_code)
    print("✓ Patched inference.py successfully!")

    print("\nDownloading model weights...")
    weights_path = hf_hub_download(
        repo_id=repo_id,
        filename="weights/checkpoint_best_regular.pth",
    )
    print("✓ Downloaded weights")

    # Make the directory holding inference.py importable.
    cache_dir = os.path.dirname(inference_path)
    if cache_dir not in sys.path:
        sys.path.insert(0, cache_dir)

    # The model is constructed with the download directory as CWD.
    # presumably inference.py resolves its weights relative to it — verify.
    os.chdir(cache_dir)

    weights_dir = os.path.join(cache_dir, "weights")
    os.makedirs(weights_dir, exist_ok=True)

    # Ensure the checkpoint sits at <cache_dir>/weights/<name>; copy it there
    # only when the download did not already place it at that location.
    expected_weights = os.path.join(weights_dir, "checkpoint_best_regular.pth")
    if not os.path.exists(expected_weights):
        shutil.copy(weights_path, expected_weights)
        print(f"✓ Weights copied to: {expected_weights}")

    print("\n" + "=" * 60)
    print("Initializing RF-DETR SoccerNet Model...")
    print("=" * 60)

    # Imported late on purpose: the module only exists (and is only patched)
    # after the steps above have run.
    from inference import RFDETRSoccerNet

    detector = RFDETRSoccerNet()
    print("\n✅ Model loaded successfully!")

except Exception as e:
    # Log with full traceback, then re-raise: the app is unusable without
    # a detector, so startup must fail loudly.
    print(f"\n❌ Error: {e}")
    traceback.print_exc()
    raise

finally:
    # Restore the caller's working directory on success and on failure.
    os.chdir(original_dir)
|
|
|
|
|
|
|
|
def draw_detections_on_image(image, df):
    """Draw a labelled bounding box for every detection onto a PIL image.

    The image is drawn on in place and the same object is returned.

    Args:
        image: PIL image to annotate.
        df: DataFrame with one row per detection and the columns ``x1``,
            ``y1``, ``x2``, ``y2``, ``class_name`` and ``confidence``.

    Returns:
        The ``image`` that was passed in, with boxes and labels drawn on it.
    """
    draw = ImageDraw.Draw(image)

    try:
        font = ImageFont.truetype(
            "/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 16
        )
    except (OSError, ImportError):
        # Font file missing/unreadable or FreeType support unavailable:
        # fall back to Pillow's built-in bitmap font.
        font = ImageFont.load_default()

    # RGB outline/label colour per class; unknown classes are drawn in white.
    colors = {
        'ball': (255, 0, 0),
        'player': (0, 255, 0),
        'referee': (255, 255, 0),
        'goalkeeper': (0, 0, 255),
    }

    for det in df.itertuples(index=False):
        color = colors.get(det.class_name, (255, 255, 255))

        draw.rectangle([det.x1, det.y1, det.x2, det.y2], outline=color, width=3)

        text = f"{det.class_name}: {det.confidence:.2f}"
        # The label sits 20px above the box; clamp to the top edge so it stays
        # visible for boxes that start within the first 20 rows of the image.
        label_pos = (det.x1, max(det.y1 - 20, 0))
        bbox = draw.textbbox(label_pos, text, font=font)
        # Filled background (2px padding) keeps the black text readable.
        draw.rectangle(
            [bbox[0] - 2, bbox[1] - 2, bbox[2] + 2, bbox[3] + 2],
            fill=color,
        )
        draw.text(label_pos, text, fill=(0, 0, 0), font=font)

    return image
|
|
|
|
|
def process_image_interface(image, confidence_threshold):
    """Run the detector on one uploaded image and return it annotated.

    Args:
        image: Image from the Gradio input (numpy array or anything
            ``np.array`` accepts), or ``None`` when nothing was uploaded.
        confidence_threshold: Threshold forwarded to
            ``detector.process_image``.

    Returns:
        ``(annotated_image, detections)`` where ``annotated_image`` is a PIL
        image and ``detections`` is the DataFrame returned by the detector.
        ``(None, empty DataFrame)`` when no image was given or processing
        failed; failures are printed with their traceback, never raised.
    """
    if image is None:
        return None, pd.DataFrame()

    temp_path = None
    try:
        pixels = image if isinstance(image, np.ndarray) else np.array(image)
        # JPEG cannot store alpha or palette data, so normalise to RGB first.
        rgb_image = Image.fromarray(pixels).convert("RGB")

        # The detector is called with a file path, so spill the image to disk.
        # mkstemp creates the file atomically (mktemp is deprecated and racy).
        fd, temp_path = tempfile.mkstemp(suffix='.jpg')
        os.close(fd)
        rgb_image.save(temp_path)

        df = detector.process_image(temp_path, confidence_threshold=confidence_threshold)

        # Annotate the in-memory image; no need to re-read the JPEG from disk.
        annotated_img = draw_detections_on_image(rgb_image, df)
        return annotated_img, df

    except Exception as e:
        # UI boundary: report the failure and return empty outputs.
        print(f"Error processing image: {e}")
        traceback.print_exc()
        return None, pd.DataFrame()

    finally:
        # Always remove the temporary file, including on failure.
        if temp_path is not None and os.path.exists(temp_path):
            os.remove(temp_path)
|
|
|
|
|
def process_video_interface(video, confidence_threshold, frame_skip, max_frames):
    """Run the detector on a video and render an annotated copy.

    Args:
        video: Path of the uploaded video, or ``None`` when nothing was
            uploaded.
        confidence_threshold: Threshold forwarded to
            ``detector.process_video``.
        frame_skip: Forwarded (as ``int``) to ``detector.process_video``.
        max_frames: Upper bound forwarded to the detector. ``None``, ``0`` or
            a negative value means "no limit".

    Returns:
        ``(output_path, detections)`` where ``output_path`` is a temporary
        ``.mp4`` containing every frame of the input, with boxes drawn on the
        frames that have detections, and ``detections`` is the DataFrame
        returned by the detector. ``(None, empty DataFrame)`` when no video
        was given or processing failed; failures are printed, never raised.
    """
    if video is None:
        return None, pd.DataFrame()

    cap = None
    out = None
    output_path = None
    try:
        # A cleared gr.Number field arrives as None; treat it like 0.
        max_frames_val = int(max_frames) if max_frames and max_frames > 0 else None

        print(f"Processing video with confidence={confidence_threshold}, frame_skip={frame_skip}, max_frames={max_frames_val}")
        df = detector.process_video(
            video,
            confidence_threshold=confidence_threshold,
            frame_skip=int(frame_skip),
            max_frames=max_frames_val
        )

        cap = cv2.VideoCapture(video)
        # Keep the fractional rate (e.g. 29.97) so playback speed is preserved;
        # fall back to 30 when the container reports no usable value.
        fps = cap.get(cv2.CAP_PROP_FPS)
        if not fps > 0:
            fps = 30.0
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        # mkstemp creates the file atomically (mktemp is deprecated and racy).
        fd, output_path = tempfile.mkstemp(suffix='.mp4')
        os.close(fd)
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))

        # Index detections by frame number once instead of filtering the whole
        # DataFrame for every frame. A result without a 'frame' column (e.g.
        # an empty DataFrame) simply yields an unannotated copy.
        if 'frame' in df.columns:
            detections_by_frame = {
                frame_id: group for frame_id, group in df.groupby('frame')
            }
        else:
            detections_by_frame = {}

        frame_num = 0
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break

            frame_detections = detections_by_frame.get(frame_num)
            if frame_detections is not None:
                # OpenCV frames are BGR; the drawing helper works on PIL images.
                rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                pil_img = Image.fromarray(rgb_frame)
                annotated_pil = draw_detections_on_image(pil_img, frame_detections)
                frame = cv2.cvtColor(np.array(annotated_pil), cv2.COLOR_RGB2BGR)

            out.write(frame)
            frame_num += 1

        return output_path, df

    except Exception as e:
        # UI boundary: report the failure and return empty outputs.
        print(f"Error processing video: {e}")
        traceback.print_exc()
        # Release the writer before deleting its half-written output file.
        if out is not None:
            out.release()
            out = None
        if output_path is not None and os.path.exists(output_path):
            os.remove(output_path)
        return None, pd.DataFrame()

    finally:
        # Release native handles on every path; releasing the writer also
        # finalises the output file before the caller reads it.
        if cap is not None:
            cap.release()
        if out is not None:
            out.release()
|
|
|
|
|
|
|
|
# Gradio UI: an image tab, a video tab and a static "About" tab.
#
# NOTE(review): the emoji in this block were mis-encoded in the source.
# Glyphs whose trailing bytes survived were decoded exactly; the ones for
# "Detect Objects", "Model Performance", "Training Details", "Use Cases",
# "Links" and "Citation" are best guesses — confirm against the original.
# NOTE(review): the source had lost its indentation, so the nesting of the
# result tables, click handlers and Examples (placed at tab level here) is
# reconstructed — confirm the intended layout.
# NOTE(review): the citation URL slug (rf-detr-soccernet) differs from the
# repository id linked elsewhere in this page (RFDETR-Soccernet) — confirm.
with gr.Blocks(title="⚽ Soccer Object Detection", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # ⚽ Soccer Object Detection with RF-DETR

    Professional-grade object detection for soccer videos using RF-DETR-Large model.

    ### Model: [julianzu9612/RFDETR-Soccernet](https://huggingface.co/julianzu9612/RFDETR-Soccernet)
    - **Architecture**: RF-DETR-Large (128M parameters)
    - **Performance**: 85.7% mAP@50, 49.8% mAP
    - **Dataset**: SoccerNet-Tracking 2023 (42,750 images)
    - **Classes**: Ball, Player, Referee, Goalkeeper
    """)

    # ---- Tab 1: single-image detection ------------------------------------
    with gr.Tab("📸 Image Detection"):
        gr.Markdown("### Upload a soccer image to detect objects")

        with gr.Row():
            with gr.Column():
                image_input = gr.Image(label="Upload Soccer Image", type="numpy")
                image_confidence = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=0.5,
                    step=0.05,
                    label="Confidence Threshold",
                    info="Lower values detect more objects but may include false positives"
                )
                image_button = gr.Button("🔍 Detect Objects", variant="primary", size="lg")

            with gr.Column():
                image_output = gr.Image(label="Detected Objects")

        image_detections = gr.Dataframe(
            label="Detection Results",
            wrap=True,
            interactive=False
        )

        image_button.click(
            fn=process_image_interface,
            inputs=[image_input, image_confidence],
            outputs=[image_output, image_detections]
        )

        # No bundled samples; the component is kept as a labelled placeholder.
        gr.Examples(
            examples=[],
            inputs=image_input,
            label="Example Images (Upload your own!)"
        )

    # ---- Tab 2: video detection -------------------------------------------
    with gr.Tab("🎥 Video Detection"):
        gr.Markdown("### Upload a soccer video to track objects frame by frame")

        with gr.Row():
            with gr.Column():
                video_input = gr.Video(label="Upload Soccer Video")
                video_confidence = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=0.5,
                    step=0.05,
                    label="Confidence Threshold"
                )
                video_frame_skip = gr.Slider(
                    minimum=1,
                    maximum=10,
                    value=5,
                    step=1,
                    label="Frame Skip",
                    info="Process every Nth frame (higher = faster but less detections)"
                )
                video_max_frames = gr.Number(
                    value=300,
                    label="Max Frames to Process",
                    info="Set to 0 to process entire video (300 frames ≈ 10 seconds at 30 FPS)"
                )

                gr.Markdown("""
                #### ⚡ Performance Tips:
                - **CPU**: 2-3 FPS (slow) - Use frame_skip=5 and limit frames
                - **GPU**: 12-30 FPS (fast) - Can process full videos
                - **Quick test**: Use 300 frames with frame_skip=5
                """)

                video_button = gr.Button("🎬 Process Video", variant="primary", size="lg")

            with gr.Column():
                video_output = gr.Video(label="Annotated Video")

        video_detections = gr.Dataframe(
            label="Detection Results",
            wrap=True,
            interactive=False
        )

        video_button.click(
            fn=process_video_interface,
            inputs=[video_input, video_confidence, video_frame_skip, video_max_frames],
            outputs=[video_output, video_detections]
        )

    # ---- Tab 3: static model information ----------------------------------
    with gr.Tab("ℹ️ About"):
        gr.Markdown("""
        ## About This Model

        ### 🎯 Detected Classes

        | Class | Color | Precision | Description |
        |-------|-------|-----------|-------------|
        | 🔴 Ball | Red | 78.5% | Soccer ball detection |
        | 🟢 Player | Green | 91.3% | Field players from both teams |
        | 🟡 Referee | Yellow | 85.2% | Match officials |
        | 🔵 Goalkeeper | Blue | 88.9% | Specialized goalkeeper detection |

        ### 📊 Model Performance

        - **mAP@50**: 85.7%
        - **mAP**: 49.8%
        - **mAP@75**: 52.0%
        - **Parameters**: 128M
        - **Training Time**: ~14 hours on NVIDIA A100 40GB

        ### 🎓 Training Details

        - **Dataset**: SoccerNet-Tracking 2023
        - **Images**: 42,750 annotated images
        - **Source**: Professional soccer broadcasts
        - **Input Resolution**: 1280x1280 pixels
        - **Optimizer**: AdamW (lr=1e-4)

        ### 💡 Best Practices

        1. **Confidence Threshold**:
           - Use 0.5 for general detection
           - Use 0.7+ for high-precision applications

        2. **Video Quality**:
           - Works best on 720p+ broadcast footage
           - Standard broadcast camera angles preferred

        3. **Frame Processing**:
           - frame_skip=1: Every frame (best accuracy, slow)
           - frame_skip=5: Every 5th frame (good balance)
           - frame_skip=10: Every 10th frame (fast, lower accuracy)

        ### 🚨 Limitations

        - Optimized for professional broadcast footage
        - May have reduced accuracy in poor lighting
        - Small balls may be missed when heavily occluded
        - Camera angle dependency

        ### 🚀 Use Cases

        - **Sports Analytics**: Player tracking, formation analysis
        - **Broadcast Enhancement**: Automatic highlighting, statistics overlay
        - **Research**: Tactical analysis, computer vision benchmarking
        - **Video Analytics**: Automated video processing pipelines

        ### 🔗 Links

        - [Model on Hugging Face](https://huggingface.co/julianzu9612/RFDETR-Soccernet)
        - [SoccerNet Dataset](https://www.soccer-net.org/)
        - [RF-DETR Paper](https://arxiv.org/abs/2304.08069)

        ### 📄 Citation

        ```bibtex
        @misc{rfdetr-soccernet-2025,
          title={RF-DETR SoccerNet: High-Performance Soccer Object Detection},
          author={Computer Vision Research Team},
          year={2025},
          publisher={Hugging Face},
          url={https://huggingface.co/julianzu9612/rf-detr-soccernet}
        }
        ```

        ---

        **License**: Apache 2.0
        """)
|
|
|
|
|
# Start the Gradio server for the `demo` app defined above.
# NOTE(review): the banner emoji was mis-encoded in the source and its
# trailing bytes were lost; the rocket is a best guess — confirm.
print("\n" + "=" * 60)
print("🚀 Launching Gradio Interface...")
print("=" * 60)

demo.launch()