pi3detr / app.py
fafraob's picture
add links & remove post-processing selection
925407b
"""
Gradio + Plotly point cloud viewer for .xyz, .ply and .obj files with PI3DETR model integration.
Features:
- Upload .xyz (ASCII): one point per line: "x y z" (extra columns are ignored).
- Upload .ply: Standard PLY format point clouds.
- Upload .obj: OBJ format with vertices and faces (triangles).
- Interactive 3D view: orbit, pan, zoom with mouse.
- Optional: downsample for speed, toggle axes, set point size.
- Single view: input point cloud with the model's predicted curves overlaid.
- PI3DETR model integration for curve detection.
- Immediate point cloud rendering on upload.
"""
import io
import os
from typing import List, Dict, Optional
import gradio as gr
import numpy as np
import plotly.graph_objects as go
from plyfile import PlyData
import pandas
import torch
from torch_geometric.data import Data
import fpsample
import trimesh # NEW: for robust mesh loading & surface sampling
# Import PI3DETR modules
from pi3detr import (
build_model,
build_model_config,
load_args,
load_weights,
)
from pi3detr.dataset import normalize_and_scale
# Figure styling shared by every plot built in make_figure().
FIG_TEMPLATE = "plotly_white"  # Plotly template name
HOVER_FONT_SIZE = 16  # hover-label font size (enlarged for readability)
PLOT_HEIGHT = 800  # height of the 3D plot in pixels

# Module-level model cache, filled by initialize_model() and read by the
# inference helpers.
PI3DETR_MODEL = None
MODEL_STATUS = {"loaded": False, "message": "Model not loaded"}

# Demo inputs selectable from the UI: label -> path of the point cloud file.
DEMO_POINTCLOUDS = {
    "Demo 1": "demo_inputs/demo1.xyz",
    "Demo 2": "demo_inputs/demo2.xyz",
    "Demo 3": "demo_inputs/demo3.xyz",
    "Demo 4": "demo_inputs/demo4.xyz",
    "Demo 5": "demo_inputs/demo5.xyz",
}
def initialize_model(checkpoint_path="model.ckpt", config_path="configs/pi3detr.yaml"):
    """
    Build PI3DETR, load its weights and store it in the module-level cache.

    On success ``PI3DETR_MODEL`` holds the model (switched to eval mode) and
    ``MODEL_STATUS`` reports it as loaded. Any exception is caught: its text is
    recorded in ``MODEL_STATUS["message"]`` and printed, and ``PI3DETR_MODEL``
    is left as it was.

    Args:
        checkpoint_path: Weights file handed to ``load_weights``.
        config_path: Config file handed to ``load_args``; a falsy value means
            an empty argument dict is used instead.

    Returns:
        ``True`` if the model was loaded, ``False`` otherwise.
    """
    global PI3DETR_MODEL, MODEL_STATUS
    try:
        cfg_args = {} if not config_path else load_args(config_path)
        net = build_model(build_model_config(cfg_args))
        load_weights(net, checkpoint_path)
        net.eval()
        PI3DETR_MODEL = net
        MODEL_STATUS = {"loaded": True, "message": "Model loaded successfully"}
        print("PI3DETR model initialized successfully")
        return True
    except Exception as err:
        MODEL_STATUS = {"loaded": False, "message": f"Error loading model: {str(err)}"}
        print(f"Error initializing PI3DETR model: {err}")
        return False
def read_xyz(file_obj: io.BytesIO) -> np.ndarray:
    """
    Parse a .xyz text file and return an Nx3 float32 array.

    Each non-empty line that does not start with ``#`` contributes one point
    made of its first three numeric tokens. Commas are treated as whitespace,
    non-numeric tokens are skipped, and lines with fewer than three numeric
    values are ignored (extra columns are ignored too).

    Args:
        file_obj: File-like object whose ``read()`` returns bytes (a text
            stream returning ``str`` is accepted as well), or ``None``.

    Returns:
        ``(N, 3)`` float32 array; ``(0, 3)`` when no valid point was found.
    """
    empty = np.zeros((0, 3), dtype=np.float32)
    if file_obj is None:
        return empty

    raw = file_obj.read()
    if isinstance(raw, str):
        text = raw
    else:
        # Decode strictly so a non-UTF-8 file falls back to latin-1 instead of
        # silently dropping bytes (which could merge b"1\xe92" into "12").
        # "utf-8-sig" also strips a BOM that would otherwise corrupt the first
        # coordinate of the file.
        try:
            text = raw.decode("utf-8-sig")
        except UnicodeDecodeError:
            text = raw.decode("latin-1")  # never fails: every byte is valid

    pts = []
    for line in text.splitlines():
        line = line.strip()
        if not line or line.startswith("#"):
            continue
        nums = []
        for token in line.replace(",", " ").split():
            try:
                nums.append(float(token))
            except ValueError:
                continue  # skip non-numeric tokens
            if len(nums) == 3:
                break
        if len(nums) == 3:
            pts.append(nums)

    if not pts:
        return empty
    return np.asarray(pts, dtype=np.float32)
def read_ply(file_obj: io.BytesIO) -> np.ndarray:
    """
    Read the vertex positions of a .ply file.

    Args:
        file_obj: Binary file-like object holding a PLY file, or ``None``.

    Returns:
        ``(N, 3)`` float32 array built from the ``x``/``y``/``z`` properties of
        the ``vertex`` element. An empty ``(0, 3)`` array is returned for
        ``None`` input or when parsing fails (the error is printed).
    """
    if file_obj is None:
        return np.zeros((0, 3), dtype=np.float32)
    try:
        vertices = PlyData.read(file_obj)["vertex"]
        coords = [np.asarray(vertices[axis]) for axis in ("x", "y", "z")]
        return np.column_stack(coords).astype(np.float32)
    except Exception as err:
        print(f"Error reading PLY file: {err}")
        return np.zeros((0, 3), dtype=np.float32)
def read_obj_and_sample(file_obj: io.BytesIO, display_max_points: int):
    """
    Load an OBJ mesh with trimesh and sample points on its surface.

    Multi-object files (loaded as a ``trimesh.Scene``) are merged into one mesh
    first. If surface sampling fails, the raw mesh vertices are used instead,
    truncated to the requested count.

    Args:
        file_obj: Binary file-like object containing the OBJ data, or ``None``.
        display_max_points: Number of surface samples to draw; values below 1
            are clamped to 1.

    Returns:
        Tuple ``(model_pts, display_pts, info)``: two float32 ``(N, 3)`` arrays
        holding the same sampled points (independent copies) and a short status
        string. On failure both arrays are empty ``(0, 3)``.
    """

    def _failure(message):
        empty = np.zeros((0, 3), dtype=np.float32)
        return empty, empty.copy(), message

    if file_obj is None:
        return _failure("OBJ: no file")

    raw = file_obj.read()
    try:
        mesh = trimesh.load(io.BytesIO(raw), file_type="obj", force="mesh")
    except Exception as e:
        print(f"trimesh load error: {e}")
        return _failure("OBJ load failure")

    # Merge scenes into a single mesh; an empty scene has nothing to merge.
    if isinstance(mesh, trimesh.Scene):
        geometries = tuple(mesh.geometry.values())
        if not geometries:
            return _failure("OBJ: empty mesh")
        mesh = trimesh.util.concatenate(geometries)

    if mesh.is_empty or mesh.vertices.shape[0] == 0:
        return _failure("OBJ: empty mesh")

    # The display slider can be 0; always request at least one sample.
    sample_n = max(1, int(display_max_points))
    try:
        sampled = mesh.sample(sample_n)
    except Exception as e:
        print(f"Sampling error: {e}")
        sampled = mesh.vertices
    sampled = np.asarray(sampled, dtype=np.float32)[:sample_n]

    n_faces = len(mesh.faces) if mesh.faces is not None else 0
    info = (
        f"OBJ: {mesh.vertices.shape[0]} verts, {n_faces} tris | "
        f"Surface sampled: {sampled.shape[0]} pts"
    )
    model_pts = sampled.copy()
    return model_pts, sampled, info
def downsample(pts: np.ndarray, max_points: int) -> np.ndarray:
    """
    Randomly subsample ``pts`` to at most ``max_points`` rows.

    A fixed seed (42) makes the chosen subset deterministic for a given input
    size. Inputs that already fit are returned as-is (same object, no copy).
    """
    n_rows = pts.shape[0]
    if n_rows > max_points:
        keep = np.random.default_rng(42).choice(n_rows, size=max_points, replace=False)
        return pts[keep]
    return pts
def make_figure(
    pts: np.ndarray,
    point_size: int = 2,
    show_axes: bool = True,
    title: str = "",
    polylines: Optional[List[Dict]] = None,
) -> go.Figure:
    """
    Build a Plotly 3D figure of a point cloud and optional predicted curves.

    The scene is a cube: every axis spans the same length (the largest extent
    of the point cloud, or of the drawn curves when there are no points),
    centred on the data.

    Args:
        pts: ``(N, 3)`` points drawn as semi-transparent grey markers; may be
            empty.
        point_size: Marker size (clamped to at least 1).
        show_axes: Whether axis lines and labels are visible.
        title: Figure title. For an empty figure it is suffixed with
            "(no data to display)"; without a title that text is used alone.
        polylines: Optional curve dicts with keys ``points`` (``(M, 3)``),
            ``type``, ``id``, ``score`` and optionally ``display_color``
            (overrides the per-type colour) and ``visible_state`` (Plotly
            ``visible`` value, e.g. ``"legendonly"`` to hide the curve).
            Curves with fewer than two points are not drawn.

    Returns:
        The assembled ``go.Figure``.
    """
    polylines = polylines or []

    if pts.size == 0 and not polylines:
        # Placeholder figure: keep the caller's title and the regular plot
        # height so the layout does not jump between empty and filled states.
        fig = go.Figure()
        fig.update_layout(
            title=f"{title} (no data to display)" if title else "No data to display",
            template=FIG_TEMPLATE,
            scene=dict(
                xaxis_visible=False,
                yaxis_visible=False,
                zaxis_visible=False,
            ),
            margin=dict(l=0, r=0, t=40, b=0),
            height=PLOT_HEIGHT,
        )
        return fig

    fig = go.Figure()

    if pts.size > 0:
        fig.add_trace(
            go.Scatter3d(
                x=pts[:, 0],
                y=pts[:, 1],
                z=pts[:, 2],
                mode="markers",
                marker=dict(
                    size=max(1, int(point_size)), color="darkgray", opacity=0.2
                ),
                hoverinfo="skip",
                name="Curves",
                showlegend=False,  # legend hidden
            )
        )

    # Default colour per curve class; "orange" for unknown classes.
    curve_colors = {
        "Line": "blue",
        "Circle": "green",
        "Arc": "red",
        "BSpline": "purple",
    }

    drawn_curves = []  # point arrays of the curves actually added (for bounds)
    for curve in polylines:
        points = np.asarray(curve["points"])
        if points.ndim != 2 or len(points) < 2:
            continue
        curve_type = curve["type"]
        curve_id = curve["id"]
        score = curve["score"]
        color = curve.get("display_color") or curve_colors.get(curve_type, "orange")
        fig.add_trace(
            go.Scatter3d(
                x=points[:, 0],
                y=points[:, 1],
                z=points[:, 2],
                mode="lines",
                line=dict(color=color, width=8),
                name=f"{curve_type} #{curve_id} ({score:.2f})",
                # "legendonly" hides below-threshold curves (legend is off).
                visible=curve.get("visible_state", True),
                hoverinfo="text",
                text=f"{curve_type} #{curve_id} ({score:.4f})",
                showlegend=False,  # hide individual curve legend entries
            )
        )
        drawn_curves.append(points)

    # Equal aspect ratio: one shared span around the data centre.
    if pts.size > 0:
        extent_pts = pts
    elif drawn_curves:
        extent_pts = np.vstack(drawn_curves)
    else:
        extent_pts = None

    if extent_pts is not None:
        mins = extent_pts.min(axis=0)
        maxs = extent_pts.max(axis=0)
    else:
        mins = np.array([-1.0, -1.0, -1.0])
        maxs = np.array([1.0, 1.0, 1.0])

    centers = (mins + maxs) / 2.0
    span = float((maxs - mins).max())
    if not np.isfinite(span) or span <= 0:
        span = 1.0
    half = span / 2.0
    xrange = [centers[0] - half, centers[0] + half]
    yrange = [centers[1] - half, centers[1] + half]
    zrange = [centers[2] - half, centers[2] + half]

    scene_axes = dict(
        xaxis=dict(range=xrange, visible=show_axes, title="x" if show_axes else ""),
        yaxis=dict(range=yrange, visible=show_axes, title="y" if show_axes else ""),
        zaxis=dict(range=zrange, visible=show_axes, title="z" if show_axes else ""),
        aspectmode="cube",
    )
    fig.update_layout(
        title=title,
        template=FIG_TEMPLATE,
        showlegend=False,
        scene=scene_axes,
        margin=dict(l=0, r=0, t=40, b=0),
        hoverlabel=dict(font=dict(size=HOVER_FONT_SIZE)),
        height=PLOT_HEIGHT,
    )
    return fig
def process_model_predictions(data: Data) -> list:
    """
    Convert one model output sample into curve dicts for plotting.

    Reads ``data.polylines`` (iterated row by row), ``data.polyline_class``
    and ``data.polyline_score``. Entries of class 0 ("None") are dropped; no
    score filtering happens here.

    Returns:
        List of dicts with keys ``type`` (class name), ``id`` (1-based query
        number), ``index`` (0-based query index), ``score`` (Python float) and
        ``points`` (the polyline as a NumPy array).
    """
    names = ("None", "BSpline", "Line", "Circle", "Arc")
    sampled = data.polylines.cpu().numpy()
    detections = []
    for idx, line_pts in enumerate(sampled):
        label = data.polyline_class[idx].item()
        conf = data.polyline_score[idx].item()
        label_name = names[label]
        if label == 0:
            continue  # "None" class: no curve predicted for this query
        detections.append(
            {
                "type": label_name,
                "id": idx + 1,
                "index": idx,
                "score": conf,
                "points": line_pts,
            }
        )
    return detections
def process_data_for_model(
    points: np.ndarray,
    sample: int = 32768,
    sample_mode: str = "fps",
) -> Data:
    """
    Subsample a point cloud and wrap it in a normalized single-sample ``Data``.

    Args:
        points: ``(N, 3)`` input points.
        sample: Maximum number of points kept; clouds with at most this many
            points are not subsampled.
        sample_mode: ``"fps"`` (bucket farthest-point sampling), ``"random"``,
            ``"uniform"`` (every k-th point) or ``"all"``. Unrecognised modes
            keep every point, like ``"all"``.

    Returns:
        ``Data`` with ``pos``, an all-zero ``batch`` vector and
        ``batch_size = 1``, normalized by ``normalize_and_scale``. ``scale`` and
        ``center`` (when present) are given a leading batch dimension.
    """
    sample = int(sample)  # UI sliders may deliver floats
    pos = torch.tensor(points, dtype=torch.float32)
    n_points = pos.size(0)

    if n_points > sample:
        if sample_mode == "random":
            pos = pos[torch.randperm(n_points)[:sample]]
        elif sample_mode == "fps":
            # fpsample operates on NumPy arrays; its index dtype may be
            # unsigned, which torch indexing does not accept -> use int64.
            indices = fpsample.bucket_fps_kdline_sampling(pos.numpy(), sample, h=6)
            pos = pos[torch.from_numpy(np.asarray(indices, dtype=np.int64))]
        elif sample_mode == "uniform":
            step = max(1, n_points // sample)
            pos = pos[::step][:sample]
        # "all" (or any other mode): keep every point.

    data = Data(pos=pos)
    # Batch information for a single point cloud, set before normalization.
    data.batch = torch.zeros(data.pos.size(0), dtype=torch.long)
    data.batch_size = 1

    data = normalize_and_scale(data)

    # Ensure scale and center are proper batch tensors.
    if hasattr(data, "scale") and data.scale.dim() == 0:
        data.scale = data.scale.unsqueeze(0)
    if hasattr(data, "center") and data.center.dim() == 1:
        data.center = data.center.unsqueeze(0)
    return data
@torch.no_grad()
def run_model_inference(
    model,
    points: np.ndarray,
    max_points: int = 32768,
    sample_mode: str = "fps",
    num_queries: int = 256,
) -> list:
    """
    Run PI3DETR on a point cloud and return the detected curves.

    Args:
        model: Model to use; falls back to the cached ``PI3DETR_MODEL`` when
            ``None``.
        points: ``(N, 3)`` input points.
        max_points: Number of points fed to the model after sampling.
        sample_mode: Sampling strategy passed to ``process_data_for_model``.
        num_queries: Number of decoder queries (maximum number of curves).

    Returns:
        Curve dicts from ``process_model_predictions``. Returns an empty list if
        no model is available or inference fails; the error and its traceback
        are printed in that case.
    """
    if model is None:
        model = PI3DETR_MODEL
    if model is None:
        return []

    try:
        data = process_data_for_model(
            points, sample=int(max_points), sample_mode=sample_mode
        )
        device = next(model.parameters()).device
        data = data.to(device)

        num_queries = int(num_queries)  # UI sliders may deliver floats
        if model.num_preds != num_queries:
            model.set_num_preds(num_queries)

        output = model.predict_step(
            data,
            reverse_norm=True,
            thresholds=None,
        )
        return process_model_predictions(output[0])
    except Exception as e:
        # Best-effort: the UI falls back to "no curves", but keep the full
        # traceback in the server log so failures can be diagnosed.
        import traceback

        print(f"Error in model inference: {e}")
        traceback.print_exc()
        return []
def load_and_process_pointcloud(
    file: gr.File,
    max_points: int,
    point_size: int,
    show_axes: bool,
):
    """
    Load an uploaded .xyz, .ply or .obj file and render it.

    Args:
        file: Value of the ``gr.File`` component: a path string, or an object
            exposing the path as ``.name``; ``None`` when the upload is cleared.
        max_points: Maximum number of plotted points. For .obj meshes this is
            also the number of surface samples.
        point_size: Marker size of the plot.
        show_axes: Whether to show plot axes.

    Returns:
        ``(figure, model_pts, display_pts, file_name)``, one value per output
        wired to the upload event. ``model_pts`` holds all loaded points (used
        for inference) and ``display_pts`` the plotted subset. Both are ``None``
        for a missing or unsupported file, and ``file_name`` is then ``""``.
    """
    if file is None:
        return make_figure(np.zeros((0, 3))), None, None, ""

    # Accept both plain path strings and objects exposing the path as `.name`.
    path = getattr(file, "name", file)
    file_name = os.path.basename(path)
    file_ext = os.path.splitext(path)[1].lower()

    if file_ext not in (".xyz", ".ply", ".obj"):
        empty_fig = make_figure(np.zeros((0, 3)))
        empty_fig.update_layout(
            title="Unsupported file type. Please use .xyz, .ply or .obj."
        )
        return empty_fig, None, None, ""

    with open(path, "rb") as f:
        if file_ext == ".xyz":
            pts = read_xyz(f)
        elif file_ext == ".ply":
            pts = read_ply(f)
        else:
            # .obj: the mesh is surface-sampled directly at display resolution,
            # so the sampled points serve both the model and the plot.
            model_pts, display_pts, _ = read_obj_and_sample(f, max_points)
            fig = make_figure(
                display_pts,
                point_size=point_size,
                show_axes=show_axes,
                title=file_name,
            )
            return fig, model_pts, display_pts, file_name

    # Keep every point for the model; only the plot is downsampled.
    model_pts = pts.copy()
    display_pts = downsample(pts, max_points=max_points)
    fig = make_figure(
        display_pts,
        point_size=point_size,
        show_axes=show_axes,
        title=file_name,
    )
    return fig, model_pts, display_pts, file_name
def run_model_prediction(
    model_pts: np.ndarray,
    point_size: int,
    show_axes: bool,
    model_max_points: int,
    sample_mode: str,
    th_bspline: float,
    th_line: float,
    th_circle: float,
    th_arc: float,
    num_queries: int = 256,
):
    """
    Backward-compatible wrapper around ``run_model_prediction_unified``.

    Kept for callers that predate the display-points state. It passes no
    display subset, so one is derived from ``model_pts``, and an empty file
    name.

    Returns:
        ``(figure, curves)`` exactly as ``run_model_prediction_unified`` does.
    """
    return run_model_prediction_unified(  # type: ignore
        model_pts=model_pts,
        display_pts=None,
        point_size=point_size,
        show_axes=show_axes,
        model_max_points=model_max_points,
        sample_mode=sample_mode,
        th_bspline=th_bspline,
        th_line=th_line,
        th_circle=th_circle,
        th_arc=th_arc,
        file_name="",
        num_queries=num_queries,
    )
def run_model_prediction_unified(
    model_pts: np.ndarray,
    display_pts: Optional[np.ndarray],
    point_size: int,
    show_axes: bool,
    model_max_points: int,
    sample_mode: str,
    th_bspline: float,
    th_line: float,
    th_circle: float,
    th_arc: float,
    file_name: str = "",
    num_queries: int = 256,
):
    """
    Run PI3DETR on ``model_pts`` and plot the detected curves.

    The model is loaded lazily if it is not cached and has not been loaded
    before. Curves scoring below the threshold of their class (0.7 for classes
    without a threshold) are hidden (``visible_state="legendonly"``) in the
    figure but still returned.

    Args:
        model_pts: Points fed to the model, or ``None``.
        display_pts: Points to plot; when ``None`` a subset of at most 100k
            points is derived from ``model_pts``.
        point_size, show_axes: Plot settings.
        model_max_points, sample_mode, num_queries: Inference settings.
        th_bspline, th_line, th_circle, th_arc: Per-class score thresholds.
        file_name: Used in the figure title.

    Returns:
        ``(figure, curves)`` where ``curves`` is the raw list of curve dicts
        (without display flags); ``(empty_figure, [])`` when ``model_pts`` is
        ``None``.
    """
    if model_pts is None:
        return make_figure(np.zeros((0, 3))), []

    curves = []
    try:
        if PI3DETR_MODEL is None and not MODEL_STATUS["loaded"]:
            initialize_model()  # lazy load; updates the module globals
        if PI3DETR_MODEL is not None:
            curves = run_model_inference(
                PI3DETR_MODEL,
                model_pts,
                max_points=model_max_points,
                sample_mode=sample_mode,
                num_queries=num_queries,
            )
    except Exception as e:
        # Best-effort: still render the point cloud, but don't hide the failure.
        print(f"Error running PI3DETR prediction: {e}")
        curves = []

    thresholds = {
        "BSpline": th_bspline,
        "Line": th_line,
        "Circle": th_circle,
        "Arc": th_arc,
    }
    colored_curves = []
    for c in curves:
        c_disp = dict(c)  # copy so the returned raw curves stay unflagged
        if c["score"] < thresholds.get(c["type"], 0.7):
            c_disp["visible_state"] = "legendonly"
        colored_curves.append(c_disp)

    if display_pts is None:
        display_pts = downsample(model_pts, max_points=100000)

    title = f"{file_name} (curves)" if curves else f"{file_name} (no curves)"
    fig = make_figure(
        display_pts,
        point_size=point_size,
        show_axes=show_axes,
        title=title,
        polylines=colored_curves,
    )
    return fig, curves
def apply_pointcloud_display_settings(
    model_pts: np.ndarray,
    curves: List[Dict],
    max_points: int,
    point_size: int,
    show_axes: bool,
    th_bspline: float,
    th_line: float,
    th_circle: float,
    th_arc: float,
    file_name: str,
):
    """
    Re-render the plot with new display settings, without re-running inference.

    The displayed subset is recomputed from ``model_pts`` (at most
    ``max_points`` points). Existing curves are kept and the per-class
    thresholds are re-applied: curves below the threshold of their class (0.7
    for unknown classes) are hidden via ``visible_state="legendonly"``.

    Returns:
        ``(figure, display_pts)``; ``(empty_figure, None)`` when ``model_pts``
        is ``None``.
    """
    if model_pts is None:
        return make_figure(np.zeros((0, 3))), None

    display_pts = downsample(model_pts, max_points=max_points)
    base_title = file_name or "Point Cloud"

    if not curves:
        no_curve_fig = make_figure(
            display_pts,
            point_size=point_size,
            show_axes=show_axes,
            title=base_title,
        )
        return no_curve_fig, display_pts

    cutoff = {
        "BSpline": th_bspline,
        "Line": th_line,
        "Circle": th_circle,
        "Arc": th_arc,
    }
    shown = [
        {**curve, "visible_state": "legendonly"}
        if curve["score"] < cutoff.get(curve["type"], 0.7)
        else dict(curve)
        for curve in curves
    ]
    curve_fig = make_figure(
        display_pts,
        point_size=point_size,
        show_axes=show_axes,
        title=base_title + " (curves)",
        polylines=shown,
    )
    return curve_fig, display_pts
def clear_curves(
    curves: List[Dict],
    display_pts: Optional[np.ndarray],
    model_pts: Optional[np.ndarray],
    point_size: int,
    show_axes: bool,
    file_name: str,
):
    """
    Remove detected curves from the plot without re-running inference.

    The currently displayed points are re-rendered with the current point
    size, axes setting and title, whether or not curves existed. ``curves``
    and ``model_pts`` are accepted for interface compatibility but not needed.

    Returns:
        ``(figure, None)``; the ``None`` resets the curves state.
    """
    pts = display_pts if display_pts is not None else np.zeros((0, 3))
    fig = make_figure(
        pts,
        point_size=point_size,
        show_axes=show_axes,
        title=file_name or "Point Cloud",
        polylines=None,
    )
    return fig, None
def load_demo_pointcloud(
    label: str,
    max_points: int,
    point_size: int,
    show_axes: bool,
):
    """
    Load one of the predefined demo point clouds.

    Args:
        label: Key of ``DEMO_POINTCLOUDS``.
        max_points: Maximum number of plotted points (for .obj demos the surface
            sample count, capped at 20,000).
        point_size: Marker size of the plot.
        show_axes: Whether to show plot axes.

    Returns:
        ``(figure, model_pts, display_pts, file_name, curves, file_value)``.
        ``curves`` is always ``None``, which clears previously detected curves.
        ``file_value`` is the demo path, so the upload component shows the file.
        Returns ``(empty_figure, None, None, "", None, None)`` when the demo is
        missing, unsupported or fails to load.
    """

    def _failure():
        return make_figure(np.zeros((0, 3))), None, None, "", None, None

    path = DEMO_POINTCLOUDS.get(label, "")
    if not path or not os.path.isfile(path):
        return _failure()

    name = os.path.basename(path)
    title = f"{name} (demo)"
    ext = os.path.splitext(path)[1].lower()
    try:
        with open(path, "rb") as f:
            if ext == ".xyz":
                pts = read_xyz(f)
            elif ext == ".ply":
                pts = read_ply(f)
            elif ext == ".obj":
                # Demo meshes are capped at 20,000 surface samples.
                model_pts, display_pts, _ = read_obj_and_sample(
                    f, min(20000, max_points)
                )
                fig = make_figure(
                    display_pts,
                    point_size=point_size,
                    show_axes=show_axes,
                    title=title,
                )
                return fig, model_pts, display_pts, name, None, path
            else:
                return _failure()
    except Exception as e:
        print(f"Error loading demo point cloud {path}: {e}")
        return _failure()

    model_pts = pts.copy()
    display_pts = downsample(pts, max_points=max_points)
    fig = make_figure(
        display_pts,
        point_size=point_size,
        show_axes=show_axes,
        title=title,
    )
    return fig, model_pts, display_pts, name, None, path
# Named per-demo loaders (plain functions rather than lambdas, for clarity).
# Each forwards the current slider/checkbox values to load_demo_pointcloud.
def load_demo1(max_points, point_size, show_axes):
    """Load the "Demo 1" point cloud via ``load_demo_pointcloud``."""
    return load_demo_pointcloud(
        label="Demo 1",
        max_points=max_points,
        point_size=point_size,
        show_axes=show_axes,
    )


def load_demo2(max_points, point_size, show_axes):
    """Load the "Demo 2" point cloud via ``load_demo_pointcloud``."""
    return load_demo_pointcloud(
        label="Demo 2",
        max_points=max_points,
        point_size=point_size,
        show_axes=show_axes,
    )


def load_demo3(max_points, point_size, show_axes):
    """Load the "Demo 3" point cloud via ``load_demo_pointcloud``."""
    return load_demo_pointcloud(
        label="Demo 3",
        max_points=max_points,
        point_size=point_size,
        show_axes=show_axes,
    )


def load_demo4(max_points, point_size, show_axes):
    """Load the "Demo 4" point cloud via ``load_demo_pointcloud``."""
    return load_demo_pointcloud(
        label="Demo 4",
        max_points=max_points,
        point_size=point_size,
        show_axes=show_axes,
    )


def load_demo5(max_points, point_size, show_axes):
    """Load the "Demo 5" point cloud via ``load_demo_pointcloud``."""
    return load_demo_pointcloud(
        label="Demo 5",
        max_points=max_points,
        point_size=point_size,
        show_axes=show_axes,
    )
def build_demo_preview(label: str, max_pts: int = 20000) -> go.Figure:
    """
    Render a small, curve-free preview figure of one demo point cloud.

    At most ``max_pts`` points are shown, without axes. An empty figure titled
    with the reason is returned when the demo file is missing, has an
    unsupported extension, or fails to load.
    """
    path = DEMO_POINTCLOUDS.get(label, "")
    if not (path and os.path.isfile(path)):
        return make_figure(np.zeros((0, 3)), title=f"{label}: (missing)")

    try:
        suffix = os.path.splitext(path)[1].lower()
        with open(path, "rb") as handle:
            if suffix == ".obj":
                pts = read_obj_and_sample(handle, max_pts)[1]
            elif suffix == ".ply":
                pts = read_ply(handle)
            elif suffix == ".xyz":
                pts = read_xyz(handle)
            else:
                return make_figure(np.zeros((0, 3)), title=f"{label}: (unsupported)")
        preview_pts = downsample(pts, max_pts)
        return make_figure(
            preview_pts, point_size=1, show_axes=False, title=f"{label} preview"
        )
    except Exception:
        return make_figure(np.zeros((0, 3)), title=f"{label}: error")
def run_model_with_display(
    model_pts: np.ndarray,
    max_points: int,
    point_size: int,
    show_axes: bool,
    model_max_points: int,
    sample_mode: str,
    th_bspline: float,
    th_line: float,
    th_circle: float,
    th_arc: float,
    file_name: str = "",
    num_queries: int = 256,
):
    """
    Run inference on ``model_pts`` and render the result with the current
    display settings (max_points, point_size, show_axes) and thresholds.

    Returns:
        ``(figure, curves, display_pts)``: ``curves`` is the raw list of curve
        dicts (thresholds only affect visibility in the figure) and
        ``display_pts`` the plotted subset. Returns ``(empty_figure, None,
        None)`` when ``model_pts`` is ``None``.
    """
    if model_pts is None:
        return make_figure(np.zeros((0, 3))), None, None

    # Only the curves are used from this call. An empty display set keeps it
    # from building a throwaway figure over a (up to 100k point) subset.
    _, curves = run_model_prediction_unified(
        model_pts,
        np.zeros((0, 3), dtype=np.float32),
        point_size,
        show_axes,
        model_max_points,
        sample_mode,
        th_bspline,
        th_line,
        th_circle,
        th_arc,
        file_name,
        num_queries,
    )

    # Render with the current display settings & thresholds (no re-inference).
    fig, display_pts = apply_pointcloud_display_settings(
        model_pts,
        curves,
        max_points,
        point_size,
        show_axes,
        th_bspline,
        th_line,
        th_circle,
        th_arc,
        file_name,
    )
    return fig, curves, display_pts
with gr.Blocks(title="PI3DETR") as demo:
    # --- Header: title, description and external links ----------------------
    gr.Markdown(
        """
# 🥧 PI3DETR: Detection of Sharp 3D CAD Edges [CPU-PREVIEW]
A novel end-to-end deep learning model for **parametric curve inference** in **3D point clouds** and **meshes**.
<div style="margin-top: 10px;">
<a href="https://arxiv.org/pdf/2509.03262" target="_blank" style="
display: inline-block;
background-color: #4CAF50;
color: white;
padding: 8px 16px;
text-decoration: none;
border-radius: 5px;
margin-right: 8px;
font-weight: bold;
">📄 Paper</a>
<a href="https://fafraob.github.io/pi3detr/" target="_blank" style="
display: inline-block;
background-color: #2196F3;
color: white;
padding: 8px 16px;
text-decoration: none;
border-radius: 5px;
margin-right: 8px;
font-weight: bold;
">🌐 Website</a>
<a href="https://github.com/fafraob/pi3detr" target="_blank" style="
display: inline-block;
background-color: #333;
color: white;
padding: 8px 16px;
text-decoration: none;
border-radius: 5px;
font-weight: bold;
">🐙 GitHub</a>
</div>
"""
    )

    # --- Usage notes: two rows of three columns ------------------------------
    with gr.Row():
        with gr.Column():
            gr.Markdown(
                "### 🧩 Supported Inputs\n"
                "- **Point Clouds:** `.xyz`, `.ply`; **Meshes:** `.obj`\n"
                "- `Mesh` is surface-sampled using **Max Points (display)** slider."
            )
        with gr.Column():
            gr.Markdown(
                "### ⚙️ Point Cloud Settings\n"
                "- Adjust **Max Points**, **point size**, and **axes visibility**.\n"
                "- Controls visualization of point cloud."
            )
        with gr.Column():
            gr.Markdown(
                "### 🎯 Confidence Thresholds\n"
                "- Hover to inspect scores.\n"
                "- Filter curves by **class confidence** interactively"
            )
    with gr.Row():
        with gr.Column():
            gr.Markdown(
                "### 🧠 Model Settings\n"
                "- **Sampling Mode:** Choose downsampling strategy.\n"
                "- **Model Input Size:** Number of model input points.\n"
                "- **Queries:** Transformer decoder queries (max. output curves)."
            )
        with gr.Column():
            gr.Markdown(
                "### ⚡ Performance Notes\n"
                "- Trained on **human-made objects**.\n"
                "- Optimized for **GPU**; this demo runs on **CPU**.\n"
                "- For **full qualitative performance**: \n"
                "[GitHub → PI3DETR](https://github.com/fafraob/pi3detr)"
            )
        with gr.Column():
            gr.Markdown(
                "### ▶️ Run Inference\n"
                "- Click on demo point clouds (from test set) below.\n"
                "- Press **Run PI3DETR** to execute inference and visualize results."
            )

    # --- Per-session state ---------------------------------------------------
    model_pts_state = gr.State(None)  # all loaded points (inference input)
    display_pts_state = gr.State(None)  # downsampled points currently plotted
    curves_state = gr.State(None)  # raw curves from the last inference run
    file_name_state = gr.State("demo_inputs/demo2.xyz")  # used in plot titles

    with gr.Row():
        file_in = gr.File(
            label="Upload Point Cloud (auto-renders)",
            file_types=[".xyz", ".ply", ".obj"],
            type="filepath",
        )

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### Point Cloud Settings")
            max_points = gr.Slider(
                0,
                500_000,
                value=200_000,
                step=1_000,
                label="Max points (display)",
            )
            point_size = gr.Slider(1, 8, value=1, step=1, label="Point size")
            show_axes = gr.Checkbox(value=False, label="Show axes")

            gr.Markdown("### Model Settings")
            sample_mode = gr.Radio(
                ["fps", "random", "all"],
                value="fps",
                label="Main Sampling Method",
            )
            model_max_points = gr.Slider(
                1_000,
                100_000,
                value=32768,
                step=500,
                label="Downsample to Model Input Size",
            )
            num_queries = gr.Slider(
                32,
                512,
                value=128,
                step=1,
                label="Number of Queries",
            )

            # Threshold sliders: applied on release, never re-run inference.
            gr.Markdown("#### Confidence Thresholds (per class)")
            th_bspline = gr.Slider(0.0, 1.0, value=0.9, step=0.01, label="BSpline ≥")
            th_line = gr.Slider(0.0, 1.0, value=0.9, step=0.01, label="Line ≥")
            th_circle = gr.Slider(0.0, 1.0, value=0.9, step=0.01, label="Circle ≥")
            th_arc = gr.Slider(0.0, 1.0, value=0.9, step=0.01, label="Arc ≥")
        with gr.Column(scale=1):
            # Plot height comes from the figure layout (PLOT_HEIGHT).
            main_plot = gr.Plot(label="Point Cloud & Curves")
            run_model_button = gr.Button("Run PI3DETR", variant="primary")
            clear_curves_button = gr.Button("Clear Curves", variant="secondary")

    # --- Event wiring --------------------------------------------------------
    def _reset_curves():
        """Drop curves that were detected on a previously loaded point cloud."""
        return None

    # Auto-render the point cloud on upload. Curves belong to the previous
    # cloud, so reset them; otherwise the next slider change would overlay the
    # old curves on the new cloud.
    file_in.change(
        load_and_process_pointcloud,
        inputs=[file_in, max_points, point_size, show_axes],
        outputs=[
            main_plot,
            model_pts_state,
            display_pts_state,
            file_name_state,
        ],
    ).then(
        _reset_curves,
        inputs=None,
        outputs=[curves_state],
    )

    run_model_button.click(
        run_model_with_display,
        inputs=[
            model_pts_state,
            max_points,
            point_size,
            show_axes,
            model_max_points,
            sample_mode,
            th_bspline,
            th_line,
            th_circle,
            th_arc,
            file_name_state,
            num_queries,
        ],
        outputs=[main_plot, curves_state, display_pts_state],
    )

    def _apply_display_wrapper(
        model_pts,
        curves,
        max_points,
        point_size,
        show_axes,
        th_bspline,
        th_line,
        th_circle,
        th_arc,
        file_name,
        display_pts_state_value,
    ):
        """
        Re-render with the current display settings and thresholds (no inference).

        ``display_pts_state_value`` is unused; it is accepted so the signature
        matches the wired inputs. The displayed subset is recomputed from
        ``model_pts``.
        """
        fig, display_pts = apply_pointcloud_display_settings(
            model_pts,
            curves,
            max_points,
            point_size,
            show_axes,
            th_bspline,
            th_line,
            th_circle,
            th_arc,
            file_name,
        )
        return fig, display_pts

    # Inputs/outputs shared by every "re-render without inference" trigger.
    display_inputs = [
        model_pts_state,
        curves_state,
        max_points,
        point_size,
        show_axes,
        th_bspline,
        th_line,
        th_circle,
        th_arc,
        file_name_state,
        display_pts_state,
    ]
    display_outputs = [main_plot, display_pts_state]

    # Point cloud sliders react on release, the checkbox on change.
    for slider in [max_points, point_size]:
        slider.release(
            _apply_display_wrapper,
            inputs=display_inputs,
            outputs=display_outputs,
        )
    show_axes.change(
        _apply_display_wrapper,
        inputs=display_inputs,
        outputs=display_outputs,
    )
    # Threshold sliders (apply on release).
    for th in [th_bspline, th_line, th_circle, th_arc]:
        th.release(
            _apply_display_wrapper,
            inputs=display_inputs,
            outputs=display_outputs,
        )

    clear_curves_button.click(
        clear_curves,
        inputs=[
            curves_state,
            display_pts_state,
            model_pts_state,
            point_size,
            show_axes,
            file_name_state,
        ],
        outputs=[main_plot, curves_state],
    )

    # --- Demo gallery: clickable preview images -------------------------------
    with gr.Row():
        gr.Markdown("### Demo Point Clouds (click an image to load)")
    with gr.Row():
        demo_image_components = {}
        for label in ["Demo 1", "Demo 2", "Demo 3", "Demo 4", "Demo 5"]:
            png_path = f"demo_inputs/{label.lower().replace(' ', '')}.png"
            demo_image_components[label] = gr.Image(
                value=png_path if os.path.isfile(png_path) else None,
                label=label,
                interactive=False,
            )

    _demo_loaders = {
        "Demo 1": load_demo1,
        "Demo 2": load_demo2,
        "Demo 3": load_demo3,
        "Demo 4": load_demo4,
        "Demo 5": load_demo5,
    }
    demo_outputs = [
        main_plot,
        model_pts_state,
        display_pts_state,
        file_name_state,
        curves_state,
        file_in,
    ]
    for label, comp in demo_image_components.items():
        comp.select(
            _demo_loaders[label],
            inputs=[max_points, point_size, show_axes],
            outputs=demo_outputs,
        )

    # Show Demo 2 as soon as the app is opened.
    demo.load(
        load_demo2,
        inputs=[max_points, point_size, show_axes],
        outputs=demo_outputs,
    )
if __name__ == "__main__":
    # Load the model eagerly when run as a script. If this fails, the
    # prediction path retries lazily because MODEL_STATUS stays "not loaded".
    initialize_model()
    demo.launch()