|
|
""" |
|
|
Gradio + Plotly point cloud viewer for .xyz, .ply and .obj files with PI3DETR model integration. |
|
|
|
|
|
Features: |
|
|
- Upload .xyz (ASCII): one point per line: "x y z" (extra columns are ignored). |
|
|
- Upload .ply: Standard PLY format point clouds. |
|
|
- Upload .obj: OBJ format with vertices and faces (triangles). |
|
|
- Interactive 3D view: orbit, pan, zoom with mouse. |
|
|
- Display options: downsample for speed, toggle axes, set point size (input is normalized internally before inference).
|
|
- Combined view: input point cloud and predicted curves in a single interactive plot.
|
|
- PI3DETR model integration for curve detection. |
|
|
- Immediate point cloud rendering on upload. |
|
|
""" |
|
|
|
|
|
import io |
|
|
import os |
|
|
from typing import List, Dict, Optional |
|
|
|
|
|
import gradio as gr |
|
|
import numpy as np |
|
|
import plotly.graph_objects as go |
|
|
from plyfile import PlyData |
|
|
|
|
import torch |
|
|
from torch_geometric.data import Data |
|
|
import fpsample |
|
|
import trimesh |
|
|
|
|
|
|
|
|
from pi3detr import ( |
|
|
build_model, |
|
|
build_model_config, |
|
|
load_args, |
|
|
load_weights, |
|
|
) |
|
|
from pi3detr.dataset import normalize_and_scale |
|
|
|
|
|
|
|
|
PI3DETR_MODEL = None |
|
|
MODEL_STATUS = {"loaded": False, "message": "Model not loaded"} |
|
|
|
|
|
HOVER_FONT_SIZE = 16 |
|
|
FIG_TEMPLATE = "plotly_white" |
|
|
PLOT_HEIGHT = 800 |
|
|
|
|
|
|
|
|
DEMO_POINTCLOUDS = { |
|
|
"Demo 1": "demo_inputs/demo1.xyz", |
|
|
"Demo 2": "demo_inputs/demo2.xyz", |
|
|
"Demo 3": "demo_inputs/demo3.xyz", |
|
|
"Demo 4": "demo_inputs/demo4.xyz", |
|
|
"Demo 5": "demo_inputs/demo5.xyz", |
|
|
} |
|
|
|
|
|
|
|
|
def initialize_model(checkpoint_path="model.ckpt", config_path="configs/pi3detr.yaml"): |
|
|
"""Initialize the model at startup and store it in the global cache.""" |
|
|
global PI3DETR_MODEL, MODEL_STATUS |
|
|
try: |
|
|
args = load_args(config_path) if config_path else {} |
|
|
model_config = build_model_config(args) |
|
|
model = build_model(model_config) |
|
|
load_weights(model, checkpoint_path) |
|
|
model.eval() |
|
|
|
|
|
PI3DETR_MODEL = model |
|
|
MODEL_STATUS = {"loaded": True, "message": "Model loaded successfully"} |
|
|
print("PI3DETR model initialized successfully") |
|
|
return True |
|
|
except Exception as e: |
|
|
MODEL_STATUS = {"loaded": False, "message": f"Error loading model: {str(e)}"} |
|
|
print(f"Error initializing PI3DETR model: {e}") |
|
|
return False |
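# Example (illustrative): eager startup with explicit paths. The paths below
# simply repeat the defaults defined above (model.ckpt, configs/pi3detr.yaml).
#
#     if initialize_model(checkpoint_path="model.ckpt",
#                         config_path="configs/pi3detr.yaml"):
#         print(MODEL_STATUS["message"])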
|
|
|
|
|
|
|
|
def read_xyz(file_obj: io.BytesIO) -> np.ndarray: |
|
|
""" |
|
|
Parse a .xyz text file from bytes and return Nx3 float32 array. |
|
|
Lines with fewer than 3 numeric values are skipped. |
|
|
Only the first three numeric columns are used. |
|
|
""" |
|
|
if file_obj is None: |
|
|
return np.zeros((0, 3), dtype=np.float32) |
|
|
|
|
|
|
|
|
raw = file_obj.read() |
|
|
try: |
|
|
text = raw.decode("utf-8", errors="ignore") |
|
|
except Exception: |
|
|
text = raw.decode("latin-1", errors="ignore") |
|
|
|
|
|
pts = [] |
|
|
for line in text.splitlines(): |
|
|
line = line.strip() |
|
|
if not line or line.startswith("#"): |
|
|
continue |
|
|
parts = line.replace(",", " ").split() |
|
|
nums = [] |
|
|
for p in parts: |
|
|
try: |
|
|
nums.append(float(p)) |
|
|
except ValueError: |
|
|
|
|
|
pass |
|
|
if len(nums) == 3: |
|
|
break |
|
|
if len(nums) >= 3: |
|
|
pts.append(nums[:3]) |
|
|
|
|
|
if not pts: |
|
|
return np.zeros((0, 3), dtype=np.float32) |
|
|
|
|
|
return np.asarray(pts, dtype=np.float32) |
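# Example (illustrative): parsing a small in-memory buffer. Commas are treated
# as separators, comment lines are skipped, and extra columns are ignored.
#
#     buf = io.BytesIO(b"0 0 0\n# comment\n1, 2, 3, 255\n")
#     pts = read_xyz(buf)          # -> shape (2, 3)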
|
|
|
|
|
|
|
|
def read_ply(file_obj: io.BytesIO) -> np.ndarray: |
|
|
""" |
|
|
Parse a .ply file from bytes and return Nx3 float32 array of points. |
|
|
""" |
|
|
if file_obj is None: |
|
|
return np.zeros((0, 3), dtype=np.float32) |
|
|
|
|
|
try: |
|
|
ply_data = PlyData.read(file_obj) |
|
|
vertex = ply_data["vertex"] |
|
|
|
|
|
x = np.asarray(vertex["x"]) |
|
|
y = np.asarray(vertex["y"]) |
|
|
z = np.asarray(vertex["z"]) |
|
|
|
|
|
points = np.column_stack([x, y, z]).astype(np.float32) |
|
|
return points |
|
|
except Exception as e: |
|
|
print(f"Error reading PLY file: {e}") |
|
|
return np.zeros((0, 3), dtype=np.float32) |
|
|
|
|
|
|
|
|
def read_obj_and_sample(file_obj: io.BytesIO, display_max_points: int): |
|
|
"""Parse OBJ via trimesh and sample up to display_max_points uniformly over the surface.""" |
|
|
raw = file_obj.read() |
|
|
|
|
|
try: |
|
|
mesh = trimesh.load(io.BytesIO(raw), file_type="obj", force="mesh") |
|
|
except Exception as e: |
|
|
print(f"trimesh load error: {e}") |
|
|
return ( |
|
|
np.zeros((0, 3), dtype=np.float32), |
|
|
np.zeros((0, 3), dtype=np.float32), |
|
|
"OBJ load failure", |
|
|
) |
|
|
|
|
|
if isinstance(mesh, trimesh.Scene): |
|
|
mesh = trimesh.util.concatenate(tuple(g for g in mesh.geometry.values())) |
|
|
if mesh.is_empty or mesh.vertices.shape[0] == 0: |
|
|
return ( |
|
|
np.zeros((0, 3), dtype=np.float32), |
|
|
np.zeros((0, 3), dtype=np.float32), |
|
|
"OBJ: empty mesh", |
|
|
) |
|
|
    sample_n = max(1, int(display_max_points))
|
|
try: |
|
|
sampled = mesh.sample(sample_n) |
|
|
except Exception as e: |
|
|
print(f"Sampling error: {e}") |
|
|
sampled = mesh.vertices |
|
|
if sampled.shape[0] > sample_n: |
|
|
sampled = sampled[:sample_n] |
|
|
sampled = np.asarray(sampled, dtype=np.float32) |
|
|
info = f"OBJ: {mesh.vertices.shape[0]} verts, {len(mesh.faces) if mesh.faces is not None else 0} tris | Surface sampled: {sampled.shape[0]} pts" |
|
|
model_pts = sampled.copy() |
|
|
return model_pts, sampled, info |
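# Example (illustrative): surface-sampling a procedurally generated mesh.
# Uses only trimesh itself; no external .obj file is required.
#
#     obj_text = trimesh.creation.box(extents=(1.0, 1.0, 1.0)).export(file_type="obj")
#     _, display_pts, info = read_obj_and_sample(
#         io.BytesIO(obj_text.encode("utf-8")), display_max_points=2048
#     )
#     # display_pts has shape (2048, 3)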
|
|
|
|
|
|
|
|
def downsample(pts: np.ndarray, max_points: int) -> np.ndarray: |
|
|
if pts.shape[0] <= max_points: |
|
|
return pts |
|
|
rng = np.random.default_rng(42) |
|
|
idx = rng.choice(pts.shape[0], size=max_points, replace=False) |
|
|
return pts[idx] |
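# Example (illustrative): downsampling is random but reproducible (seed 42).
#
#     dense = np.random.rand(50_000, 3).astype(np.float32)
#     sparse = downsample(dense, max_points=10_000)   # -> shape (10_000, 3)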
|
|
|
|
|
|
|
|
def make_figure( |
|
|
pts: np.ndarray, |
|
|
point_size: int = 2, |
|
|
show_axes: bool = True, |
|
|
title: str = "", |
|
|
polylines: Optional[List[Dict]] = None, |
|
|
) -> go.Figure: |
|
|
""" |
|
|
Build a Plotly 3D scatter figure with equal aspect ratio. |
|
|
Optionally includes polylines from model predictions. |
|
|
""" |
|
|
if pts.size == 0 and (polylines is None or len(polylines) == 0): |
|
|
fig = go.Figure() |
|
|
fig.update_layout( |
|
|
title="No data to display", |
|
|
template=FIG_TEMPLATE, |
|
|
scene=dict( |
|
|
xaxis_visible=False, |
|
|
yaxis_visible=False, |
|
|
zaxis_visible=False, |
|
|
), |
|
|
margin=dict(l=0, r=0, t=40, b=0), |
|
|
) |
|
|
return fig |
|
|
|
|
|
fig = go.Figure() |
|
|
|
|
|
|
|
|
if pts.size > 0: |
|
|
x, y, z = pts[:, 0], pts[:, 1], pts[:, 2] |
|
|
fig.add_trace( |
|
|
go.Scatter3d( |
|
|
x=x, |
|
|
y=y, |
|
|
z=z, |
|
|
mode="markers", |
|
|
marker=dict( |
|
|
size=max(1, int(point_size)), color="darkgray", opacity=0.2 |
|
|
), |
|
|
hoverinfo="skip", |
|
|
                name="Points",
|
|
showlegend=False, |
|
|
) |
|
|
) |
|
|
|
|
|
|
|
|
curve_colors = { |
|
|
"Line": "blue", |
|
|
"Circle": "green", |
|
|
"Arc": "red", |
|
|
"BSpline": "purple", |
|
|
} |
|
|
|
|
|
|
|
|
if polylines: |
|
|
for curve in polylines: |
|
|
points = np.array(curve["points"]) |
|
|
if len(points) < 2: |
|
|
continue |
|
|
|
|
|
curve_type = curve["type"] |
|
|
curve_id = curve["id"] |
|
|
score = curve["score"] |
|
|
|
|
|
|
|
|
color = curve.get("display_color") or curve_colors.get(curve_type, "orange") |
|
|
|
|
|
|
|
|
fig.add_trace( |
|
|
go.Scatter3d( |
|
|
x=points[:, 0], |
|
|
y=points[:, 1], |
|
|
z=points[:, 2], |
|
|
mode="lines", |
|
|
line=dict(color=color, width=8), |
|
|
name=f"{curve_type} #{curve_id} ({score:.2f})", |
|
|
visible=curve.get("visible_state", True), |
|
|
hoverinfo="text", |
|
|
text=f"{curve_type} #{curve_id} ({score:.4f})", |
|
|
showlegend=False, |
|
|
) |
|
|
) |
|
|
|
|
|
|
|
|
if pts.size > 0: |
|
|
mins = pts.min(axis=0) |
|
|
maxs = pts.max(axis=0) |
|
|
elif polylines and len(polylines) > 0: |
|
|
|
|
|
all_points = np.vstack([np.array(curve["points"]) for curve in polylines]) |
|
|
mins = all_points.min(axis=0) |
|
|
maxs = all_points.max(axis=0) |
|
|
else: |
|
|
mins = np.array([-1, -1, -1]) |
|
|
maxs = np.array([1, 1, 1]) |
|
|
|
|
|
centers = (mins + maxs) / 2.0 |
|
|
span = (maxs - mins).max() |
|
|
if span <= 0: |
|
|
span = 1.0 |
|
|
half = span / 2.0 |
|
|
xrange = [centers[0] - half, centers[0] + half] |
|
|
yrange = [centers[1] - half, centers[1] + half] |
|
|
zrange = [centers[2] - half, centers[2] + half] |
|
|
|
|
|
scene_axes = dict( |
|
|
xaxis=dict(range=xrange, visible=show_axes, title="x" if show_axes else ""), |
|
|
yaxis=dict(range=yrange, visible=show_axes, title="y" if show_axes else ""), |
|
|
zaxis=dict(range=zrange, visible=show_axes, title="z" if show_axes else ""), |
|
|
aspectmode="cube", |
|
|
) |
|
|
|
|
|
fig.update_layout( |
|
|
title=title, |
|
|
template=FIG_TEMPLATE, |
|
|
showlegend=False, |
|
|
scene=scene_axes, |
|
|
margin=dict(l=0, r=0, t=40, b=0), |
|
|
hoverlabel=dict(font=dict(size=HOVER_FONT_SIZE)), |
|
|
height=PLOT_HEIGHT, |
|
|
) |
|
|
return fig |
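# Example (illustrative): the polyline dicts consumed by make_figure. Only
# "type", "id", "score" and "points" are required; "display_color" and
# "visible_state" are optional overrides.
#
#     rng = np.random.default_rng(0)
#     cloud = rng.random((1000, 3)).astype(np.float32)
#     line = {
#         "type": "Line",
#         "id": 1,
#         "score": 0.97,
#         "points": np.array([[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]], dtype=np.float32),
#     }
#     fig = make_figure(cloud, point_size=2, show_axes=True,
#                       title="example", polylines=[line])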
|
|
|
|
|
|
|
|
def process_model_predictions(data: Data) -> list: |
|
|
""" |
|
|
Process model outputs into a format suitable for visualization. |
|
|
""" |
|
|
class_names = ["None", "BSpline", "Line", "Circle", "Arc"] |
|
|
polylines = data.polylines.cpu().numpy() |
|
|
curves = [] |
|
|
|
|
|
|
|
|
for i, polyline in enumerate(polylines): |
|
|
cls = data.polyline_class[i].item() |
|
|
score = data.polyline_score[i].item() |
|
|
cls_name = class_names[cls] |
|
|
|
|
|
|
|
|
if cls == 0: |
|
|
continue |
|
|
|
|
|
|
|
|
curve_data = { |
|
|
"type": cls_name, |
|
|
"id": i + 1, |
|
|
"index": i, |
|
|
"score": score, |
|
|
"points": polyline, |
|
|
} |
|
|
curves.append(curve_data) |
|
|
|
|
|
return curves |
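# Each curve dict returned above has the following shape (illustrative values):
#
#     {"type": "Arc", "id": 7, "index": 6, "score": 0.93,
#      "points": <np.ndarray of polyline vertices, shape (P, 3)>}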
|
|
|
|
|
|
|
|
def process_data_for_model( |
|
|
points: np.ndarray, |
|
|
sample: int = 32768, |
|
|
sample_mode: str = "fps", |
|
|
) -> Data: |
|
|
""" |
|
|
Process and subsample point cloud data using the same approach as predict_pi3detr.py. |
|
|
|
|
|
Args: |
|
|
points: Input point cloud as numpy array |
|
|
sample: Number of points to sample |
|
|
sample_mode: Sampling method ("fps", "random", "uniform", "all") |
|
|
|
|
|
Returns: |
|
|
Data object ready for model inference |
|
|
""" |
|
|
|
|
|
pos = torch.tensor(points, dtype=torch.float32) |
|
|
|
|
|
|
|
|
if sample_mode == "random": |
|
|
if pos.size(0) > sample: |
|
|
indices = torch.randperm(pos.size(0))[:sample] |
|
|
pos = pos[indices] |
|
|
|
|
|
elif sample_mode == "fps": |
|
|
if pos.size(0) > sample: |
|
|
            indices = fpsample.bucket_fps_kdline_sampling(pos.numpy(), sample, h=6)
            pos = pos[torch.from_numpy(np.asarray(indices, dtype=np.int64))]
|
|
|
|
|
elif sample_mode == "uniform": |
|
|
if pos.size(0) > sample: |
|
|
step = max(1, pos.size(0) // sample) |
|
|
pos = pos[::step][:sample] |
|
|
|
|
|
elif sample_mode == "all": |
|
|
pass |
|
|
|
|
|
|
|
|
data = Data(pos=pos) |
|
|
|
|
|
|
|
|
data.batch = torch.zeros(data.pos.size(0), dtype=torch.long) |
|
|
data.batch_size = 1 |
|
|
|
|
|
|
|
|
data = normalize_and_scale(data) |
|
|
|
|
|
|
|
|
if hasattr(data, "scale") and data.scale.dim() == 0: |
|
|
data.scale = data.scale.unsqueeze(0) |
|
|
if hasattr(data, "center") and data.center.dim() == 1: |
|
|
data.center = data.center.unsqueeze(0) |
|
|
|
|
|
return data |
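# Example (illustrative): preparing a random cloud for inference. The center
# and scale attributes are assumed to be set by normalize_and_scale.
#
#     cloud = np.random.rand(100_000, 3).astype(np.float32)
#     data = process_data_for_model(cloud, sample=32768, sample_mode="random")
#     # data.pos.shape == torch.Size([32768, 3]); data.batch_size == 1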
|
|
|
|
|
|
|
|
@torch.no_grad() |
|
|
def run_model_inference( |
|
|
model, |
|
|
points: np.ndarray, |
|
|
max_points: int = 32768, |
|
|
sample_mode: str = "fps", |
|
|
num_queries: int = 256, |
|
|
) -> list: |
|
|
"""Run model inference on the given point cloud.""" |
|
|
global PI3DETR_MODEL |
|
|
if model is None: |
|
|
model = PI3DETR_MODEL |
|
|
if model is None: |
|
|
return [] |
|
|
try: |
|
|
data = process_data_for_model( |
|
|
points, sample=max_points, sample_mode=sample_mode |
|
|
) |
|
|
device = next(model.parameters()).device |
|
|
data = data.to(device) |
|
|
|
|
|
if model.num_preds != num_queries: |
|
|
model.set_num_preds(num_queries) |
|
|
|
|
|
output = model.predict_step( |
|
|
data, |
|
|
reverse_norm=True, |
|
|
thresholds=None, |
|
|
) |
|
|
result = output[0] |
|
|
curves = process_model_predictions(result) |
|
|
return curves |
|
|
except Exception as e: |
|
|
print(f"Error in model inference: {e}") |
|
|
return [] |
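# Example (illustrative): running inference outside the Gradio UI, assuming
# initialize_model() succeeded and the bundled demo file exists.
#
#     initialize_model()
#     with open("demo_inputs/demo1.xyz", "rb") as f:
#         pts = read_xyz(f)
#     curves = run_model_inference(PI3DETR_MODEL, pts, max_points=32768,
#                                  sample_mode="fps", num_queries=128)
#     for c in curves:
#         print(c["type"], f"{c['score']:.2f}")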
|
|
|
|
|
|
|
|
def load_and_process_pointcloud( |
|
|
file: gr.File, |
|
|
max_points: int, |
|
|
point_size: int, |
|
|
show_axes: bool, |
|
|
): |
|
|
""" |
|
|
    Load and process a point cloud from an .xyz, .ply, or .obj file.
|
|
""" |
|
|
if file is None: |
|
|
empty_fig = make_figure(np.zeros((0, 3))) |
|
|
        return empty_fig, None, None, ""
|
|
|
|
|
|
|
|
file_ext = os.path.splitext(file.name)[1].lower() |
|
|
|
|
|
|
|
|
with open(file.name, "rb") as f: |
|
|
if file_ext == ".xyz": |
|
|
pts = read_xyz(f) |
|
|
mode = "XYZ" |
|
|
elif file_ext == ".ply": |
|
|
pts = read_ply(f) |
|
|
mode = "PLY" |
|
|
elif file_ext == ".obj": |
|
|
model_pts, display_pts, _ = read_obj_and_sample(f, max_points) |
|
|
fig = make_figure( |
|
|
display_pts, |
|
|
point_size=point_size, |
|
|
show_axes=show_axes, |
|
|
title=f"{os.path.basename(file.name)}", |
|
|
) |
|
|
return fig, model_pts, display_pts, os.path.basename(file.name) |
|
|
        else:
            empty_fig = make_figure(
                np.zeros((0, 3)),
                title="Unsupported file type. Please use .xyz, .ply or .obj.",
            )
            return empty_fig, None, None, ""
|
|
|
|
|
original_n = pts.shape[0] |
|
|
|
|
|
|
|
|
model_pts = pts.copy() |
|
|
|
|
|
pts = downsample(pts, max_points=max_points) |
|
|
displayed_n = pts.shape[0] |
|
|
|
|
|
fig = make_figure( |
|
|
pts, |
|
|
point_size=point_size, |
|
|
show_axes=show_axes, |
|
|
title=f"{os.path.basename(file.name)}", |
|
|
) |
|
|
|
|
|
info = f"Loaded ({mode}): {original_n} points" |
|
|
|
|
|
|
|
|
return fig, model_pts, pts, os.path.basename(file.name) |
|
|
|
|
|
|
|
|
def run_model_prediction( |
|
|
model_pts: np.ndarray, |
|
|
point_size: int, |
|
|
show_axes: bool, |
|
|
model_max_points: int, |
|
|
sample_mode: str, |
|
|
th_bspline: float, |
|
|
th_line: float, |
|
|
th_circle: float, |
|
|
th_arc: float, |
|
|
num_queries: int = 256, |
|
|
): |
|
|
|
|
|
|
|
|
return run_model_prediction_unified( |
|
|
model_pts, |
|
|
None, |
|
|
point_size, |
|
|
show_axes, |
|
|
model_max_points, |
|
|
sample_mode, |
|
|
th_bspline, |
|
|
th_line, |
|
|
th_circle, |
|
|
th_arc, |
|
|
"", |
|
|
num_queries, |
|
|
) |
|
|
|
|
|
|
|
|
def run_model_prediction_unified( |
|
|
model_pts: np.ndarray, |
|
|
display_pts: Optional[np.ndarray], |
|
|
point_size: int, |
|
|
show_axes: bool, |
|
|
model_max_points: int, |
|
|
sample_mode: str, |
|
|
th_bspline: float, |
|
|
th_line: float, |
|
|
th_circle: float, |
|
|
th_arc: float, |
|
|
file_name: str = "", |
|
|
num_queries: int = 256, |
|
|
): |
|
|
""" |
|
|
Run model inference and apply initial threshold-based coloring. |
|
|
""" |
|
|
global PI3DETR_MODEL, MODEL_STATUS |
|
|
if model_pts is None: |
|
|
empty_fig = make_figure(np.zeros((0, 3))) |
|
|
return empty_fig, [] |
|
|
|
|
|
|
|
|
curves = [] |
|
|
try: |
|
|
if PI3DETR_MODEL is None and not MODEL_STATUS["loaded"]: |
|
|
|
|
|
initialize_model() |
|
|
|
|
|
if PI3DETR_MODEL is not None: |
|
|
|
|
|
curves = run_model_inference( |
|
|
PI3DETR_MODEL, |
|
|
model_pts, |
|
|
max_points=model_max_points, |
|
|
sample_mode=sample_mode, |
|
|
num_queries=num_queries, |
|
|
) |
|
|
    except Exception as e:
        print(f"Error during model inference: {e}")
|
|
|
|
|
|
|
|
thresholds = { |
|
|
"BSpline": th_bspline, |
|
|
"Line": th_line, |
|
|
"Circle": th_circle, |
|
|
"Arc": th_arc, |
|
|
} |
|
|
colored_curves = [] |
|
|
for c in curves: |
|
|
c_disp = dict(c) |
|
|
if c["score"] < thresholds.get(c["type"], 0.7): |
|
|
c_disp["visible_state"] = "legendonly" |
|
|
colored_curves.append(c_disp) |
|
|
|
|
|
|
|
|
if display_pts is None: |
|
|
display_pts = downsample(model_pts, max_points=100000) |
|
|
title = f"{file_name} (curves)" if curves else f"{file_name} (no curves)" |
|
|
fig = make_figure( |
|
|
display_pts, |
|
|
point_size=point_size, |
|
|
show_axes=show_axes, |
|
|
title=title, |
|
|
polylines=colored_curves, |
|
|
) |
|
|
return fig, curves |
|
|
|
|
|
|
|
|
def apply_pointcloud_display_settings( |
|
|
model_pts: np.ndarray, |
|
|
curves: List[Dict], |
|
|
max_points: int, |
|
|
point_size: int, |
|
|
show_axes: bool, |
|
|
th_bspline: float, |
|
|
th_line: float, |
|
|
th_circle: float, |
|
|
th_arc: float, |
|
|
file_name: str, |
|
|
): |
|
|
""" |
|
|
Apply point cloud display settings without re-running inference. |
|
|
Keeps existing detections and re-applies thresholds. |
|
|
""" |
|
|
if model_pts is None: |
|
|
empty_fig = make_figure(np.zeros((0, 3))) |
|
|
return empty_fig, None |
|
|
display_pts = downsample(model_pts, max_points=max_points) |
|
|
if not curves: |
|
|
fig = make_figure( |
|
|
display_pts, |
|
|
point_size=point_size, |
|
|
show_axes=show_axes, |
|
|
title=file_name or "Point Cloud", |
|
|
) |
|
|
return fig, display_pts |
|
|
thresholds = { |
|
|
"BSpline": th_bspline, |
|
|
"Line": th_line, |
|
|
"Circle": th_circle, |
|
|
"Arc": th_arc, |
|
|
} |
|
|
colored_curves = [] |
|
|
for c in curves: |
|
|
c_disp = dict(c) |
|
|
if c["score"] < thresholds.get(c["type"], 0.7): |
|
|
c_disp["visible_state"] = "legendonly" |
|
|
colored_curves.append(c_disp) |
|
|
fig = make_figure( |
|
|
display_pts, |
|
|
point_size=point_size, |
|
|
show_axes=show_axes, |
|
|
title=(file_name or "Point Cloud") + " (curves)", |
|
|
polylines=colored_curves, |
|
|
) |
|
|
return fig, display_pts |
|
|
|
|
|
|
|
|
def clear_curves( |
|
|
curves: List[Dict], |
|
|
display_pts: Optional[np.ndarray], |
|
|
model_pts: Optional[np.ndarray], |
|
|
point_size: int, |
|
|
show_axes: bool, |
|
|
file_name: str, |
|
|
): |
|
|
""" |
|
|
    Remove all detected curves and redraw the point cloud without re-running inference.
|
|
""" |
|
|
if curves is None or model_pts is None or len(curves) == 0: |
|
|
empty_fig = make_figure( |
|
|
display_pts if display_pts is not None else np.zeros((0, 3)) |
|
|
) |
|
|
return empty_fig, None |
|
|
|
|
|
fig = make_figure( |
|
|
display_pts if display_pts is not None else np.zeros((0, 3)), |
|
|
point_size=point_size, |
|
|
show_axes=show_axes, |
|
|
title=file_name or "Point Cloud", |
|
|
polylines=None, |
|
|
) |
|
|
return fig, None |
|
|
|
|
|
|
|
|
def load_demo_pointcloud( |
|
|
label: str, |
|
|
max_points: int, |
|
|
point_size: int, |
|
|
show_axes: bool, |
|
|
): |
|
|
""" |
|
|
Load one of the predefined demo point clouds. |
|
|
Clears existing detected curves (curves_state -> None). |
|
|
Also returns a value for the file upload component so the filename shows up. |
|
|
""" |
|
|
path = DEMO_POINTCLOUDS.get(label, "") |
|
|
if not path or not os.path.isfile(path): |
|
|
empty_fig = make_figure(np.zeros((0, 3))) |
|
|
return empty_fig, None, None, "", None, None |
|
|
ext = os.path.splitext(path)[1].lower() |
|
|
try: |
|
|
with open(path, "rb") as f: |
|
|
if ext == ".xyz": |
|
|
pts = read_xyz(f) |
|
|
elif ext == ".ply": |
|
|
pts = read_ply(f) |
|
|
elif ext == ".obj": |
|
|
model_pts, display_pts, _ = read_obj_and_sample( |
|
|
f, min(20000, max_points) |
|
|
) |
|
|
fig = make_figure( |
|
|
display_pts, |
|
|
point_size=1, |
|
|
show_axes=show_axes, |
|
|
title=f"{os.path.basename(path)} (demo)", |
|
|
) |
|
|
return fig, model_pts, display_pts, os.path.basename(path), None, path |
|
|
else: |
|
|
empty_fig = make_figure(np.zeros((0, 3))) |
|
|
return empty_fig, None, None, "", None, None |
|
|
except Exception: |
|
|
empty_fig = make_figure(np.zeros((0, 3))) |
|
|
return empty_fig, None, None, "", None, None |
|
|
model_pts = pts.copy() |
|
|
pts = downsample(pts, max_points=max_points) |
|
|
fig = make_figure( |
|
|
pts, |
|
|
point_size=1, |
|
|
show_axes=show_axes, |
|
|
title=f"{os.path.basename(path)} (demo)", |
|
|
) |
|
|
return fig, model_pts, pts, os.path.basename(path), None, path |
|
|
|
|
|
|
|
|
|
|
|
def load_demo1(max_points, point_size, show_axes): |
|
|
return load_demo_pointcloud("Demo 1", max_points, point_size, show_axes) |
|
|
|
|
|
|
|
|
def load_demo2(max_points, point_size, show_axes): |
|
|
return load_demo_pointcloud("Demo 2", max_points, point_size, show_axes) |
|
|
|
|
|
|
|
|
def load_demo3(max_points, point_size, show_axes): |
|
|
return load_demo_pointcloud("Demo 3", max_points, point_size, show_axes) |
|
|
|
|
|
|
|
|
def load_demo4(max_points, point_size, show_axes): |
|
|
return load_demo_pointcloud("Demo 4", max_points, point_size, show_axes) |
|
|
|
|
|
|
|
|
def load_demo5(max_points, point_size, show_axes): |
|
|
return load_demo_pointcloud("Demo 5", max_points, point_size, show_axes) |
|
|
|
|
|
|
|
|
def build_demo_preview(label: str, max_pts: int = 20000) -> go.Figure: |
|
|
"""Create a small preview figure for a demo point cloud (no curves).""" |
|
|
path = DEMO_POINTCLOUDS.get(label, "") |
|
|
if not path or not os.path.isfile(path): |
|
|
return make_figure(np.zeros((0, 3)), title=f"{label}: (missing)") |
|
|
try: |
|
|
ext = os.path.splitext(path)[1].lower() |
|
|
with open(path, "rb") as f: |
|
|
if ext == ".xyz": |
|
|
pts = read_xyz(f) |
|
|
elif ext == ".ply": |
|
|
pts = read_ply(f) |
|
|
elif ext == ".obj": |
|
|
_, pts, _ = read_obj_and_sample(f, max_pts) |
|
|
else: |
|
|
return make_figure(np.zeros((0, 3)), title=f"{label}: (unsupported)") |
|
|
pts = downsample(pts, max_pts) |
|
|
return make_figure(pts, point_size=1, show_axes=False, title=f"{label} preview") |
|
|
    except Exception:
|
|
return make_figure(np.zeros((0, 3)), title=f"{label}: error") |
|
|
|
|
|
|
|
|
def run_model_with_display( |
|
|
model_pts: np.ndarray, |
|
|
max_points: int, |
|
|
point_size: int, |
|
|
show_axes: bool, |
|
|
model_max_points: int, |
|
|
sample_mode: str, |
|
|
th_bspline: float, |
|
|
th_line: float, |
|
|
th_circle: float, |
|
|
th_arc: float, |
|
|
file_name: str = "", |
|
|
num_queries: int = 256, |
|
|
): |
|
|
""" |
|
|
Run inference (if model_pts present) then immediately apply current display |
|
|
    (max_points/point_size/show_axes) and thresholds. Returns:
    figure, curves (list), display_pts.
|
|
""" |
|
|
if model_pts is None: |
|
|
empty = make_figure(np.zeros((0, 3))) |
|
|
return empty, None, None |
|
|
|
|
|
|
|
|
fig_infer, curves = run_model_prediction_unified( |
|
|
model_pts, |
|
|
None, |
|
|
point_size, |
|
|
show_axes, |
|
|
model_max_points, |
|
|
sample_mode, |
|
|
th_bspline, |
|
|
th_line, |
|
|
th_circle, |
|
|
th_arc, |
|
|
file_name, |
|
|
num_queries, |
|
|
) |
|
|
|
|
|
|
|
|
fig_final, display_pts = apply_pointcloud_display_settings( |
|
|
model_pts, |
|
|
curves, |
|
|
max_points, |
|
|
point_size, |
|
|
show_axes, |
|
|
th_bspline, |
|
|
th_line, |
|
|
th_circle, |
|
|
th_arc, |
|
|
file_name, |
|
|
) |
|
|
return fig_final, curves, display_pts |
|
|
|
|
|
|
|
|
with gr.Blocks(title="PI3DETR") as demo: |
|
|
gr.Markdown( |
|
|
""" |
|
|
# 🥧 PI3DETR: Detection of Sharp 3D CAD Edges [CPU-PREVIEW] |
|
|
|
|
|
A novel end-to-end deep learning model for **parametric curve inference** in **3D point clouds** and **meshes**. |
|
|
|
|
|
<div style="margin-top: 10px;"> |
|
|
<a href="https://arxiv.org/pdf/2509.03262" target="_blank" style=" |
|
|
display: inline-block; |
|
|
background-color: #4CAF50; |
|
|
color: white; |
|
|
padding: 8px 16px; |
|
|
text-decoration: none; |
|
|
border-radius: 5px; |
|
|
margin-right: 8px; |
|
|
font-weight: bold; |
|
|
">📄 Paper</a> |
|
|
<a href="https://fafraob.github.io/pi3detr/" target="_blank" style=" |
|
|
display: inline-block; |
|
|
background-color: #2196F3; |
|
|
color: white; |
|
|
padding: 8px 16px; |
|
|
text-decoration: none; |
|
|
border-radius: 5px; |
|
|
margin-right: 8px; |
|
|
font-weight: bold; |
|
|
">🌐 Website</a> |
|
|
<a href="https://github.com/fafraob/pi3detr" target="_blank" style=" |
|
|
display: inline-block; |
|
|
background-color: #333; |
|
|
color: white; |
|
|
padding: 8px 16px; |
|
|
text-decoration: none; |
|
|
border-radius: 5px; |
|
|
font-weight: bold; |
|
|
">🐙 GitHub</a> |
|
|
</div> |
|
|
""" |
|
|
) |
|
|
|
|
|
with gr.Row(): |
|
|
with gr.Column(): |
|
|
            gr.Markdown(
                "### 🧩 Supported Inputs\n"
                "- **Point Clouds:** `.xyz`, `.ply`; **Meshes:** `.obj`\n"
                "- Meshes are surface-sampled using the **Max points (display)** slider."
            )
|
|
with gr.Column(): |
|
|
            gr.Markdown(
                "### ⚙️ Point Cloud Settings\n"
                "- Adjust **Max points**, **point size**, and **axes visibility**.\n"
                "- These settings only affect how the point cloud is displayed."
            )
|
|
with gr.Column(): |
|
|
            gr.Markdown(
                "### 🎯 Confidence Thresholds\n"
                "- Hover over a curve to inspect its score.\n"
                "- Filter curves by **class confidence** interactively."
            )
|
|
with gr.Row(): |
|
|
with gr.Column(): |
|
|
            gr.Markdown(
                "### 🧠 Model Settings\n"
                "- **Sampling Mode:** Choose the downsampling strategy.\n"
                "- **Model Input Size:** Number of points fed to the model.\n"
                "- **Queries:** Transformer decoder queries (upper bound on output curves)."
            )
|
|
with gr.Column(): |
|
|
            gr.Markdown(
                "### ⚡ Performance Notes\n"
                "- Trained on **human-made objects**.\n"
                "- Optimized for **GPU**; this demo runs on **CPU**.\n"
                "- For **full performance**, run it locally:\n"
                "[GitHub → PI3DETR](https://github.com/fafraob/pi3detr)"
            )
|
|
with gr.Column(): |
|
|
            gr.Markdown(
                "### ▶️ Run Inference\n"
                "- Click one of the demo point clouds (from the test set) below.\n"
                "- Press **Run PI3DETR** to run inference and visualize the results."
            )
|
|
|
|
|
model_pts_state = gr.State(None) |
|
|
display_pts_state = gr.State(None) |
|
|
curves_state = gr.State(None) |
|
|
file_name_state = gr.State("demo_inputs/demo2.xyz") |
|
|
with gr.Row(): |
|
|
file_in = gr.File( |
|
|
label="Upload Point Cloud (auto-renders)", |
|
|
file_types=[".xyz", ".ply", ".obj"], |
|
|
type="filepath", |
|
|
) |
|
|
with gr.Row(): |
|
|
with gr.Column(scale=1): |
|
|
gr.Markdown("### Point Cloud Settings") |
|
|
max_points = gr.Slider( |
|
|
0, |
|
|
500_000, |
|
|
value=200_000, |
|
|
step=1_000, |
|
|
label="Max points (display)", |
|
|
) |
|
|
point_size = gr.Slider(1, 8, value=1, step=1, label="Point size") |
|
|
show_axes = gr.Checkbox(value=False, label="Show axes") |
|
|
|
|
|
gr.Markdown("### Model Settings") |
|
|
sample_mode = gr.Radio( |
|
|
["fps", "random", "all"], |
|
|
value="fps", |
|
|
label="Main Sampling Method", |
|
|
) |
|
|
model_max_points = gr.Slider( |
|
|
1_000, |
|
|
100_000, |
|
|
value=32768, |
|
|
step=500, |
|
|
label="Downsample to Model Input Size", |
|
|
) |
|
|
num_queries = gr.Slider( |
|
|
32, |
|
|
512, |
|
|
value=128, |
|
|
step=1, |
|
|
label="Number of Queries", |
|
|
) |
|
|
|
|
|
|
|
|
gr.Markdown("#### Confidence Thresholds (per class)") |
|
|
th_bspline = gr.Slider(0.0, 1.0, value=0.9, step=0.01, label="BSpline ≥") |
|
|
th_line = gr.Slider(0.0, 1.0, value=0.9, step=0.01, label="Line ≥") |
|
|
th_circle = gr.Slider(0.0, 1.0, value=0.9, step=0.01, label="Circle ≥") |
|
|
th_arc = gr.Slider(0.0, 1.0, value=0.9, step=0.01, label="Arc ≥") |
|
|
|
|
|
with gr.Column(scale=1): |
|
|
main_plot = gr.Plot( |
|
|
label="Point Cloud & Curves" |
|
|
) |
|
|
|
|
|
run_model_button = gr.Button("Run PI3DETR", variant="primary") |
|
|
clear_curves_button = gr.Button("Clear Curves", variant="secondary") |
|
|
|
|
|
|
|
|
file_in.change( |
|
|
load_and_process_pointcloud, |
|
|
inputs=[file_in, max_points, point_size, show_axes], |
|
|
outputs=[ |
|
|
main_plot, |
|
|
model_pts_state, |
|
|
display_pts_state, |
|
|
file_name_state, |
|
|
], |
|
|
) |
|
|
|
|
|
run_model_button.click( |
|
|
run_model_with_display, |
|
|
inputs=[ |
|
|
model_pts_state, |
|
|
max_points, |
|
|
point_size, |
|
|
show_axes, |
|
|
model_max_points, |
|
|
sample_mode, |
|
|
th_bspline, |
|
|
th_line, |
|
|
th_circle, |
|
|
th_arc, |
|
|
file_name_state, |
|
|
num_queries, |
|
|
], |
|
|
outputs=[main_plot, curves_state, display_pts_state], |
|
|
) |
|
|
|
|
|
|
|
|
def _apply_display_wrapper( |
|
|
model_pts, |
|
|
curves, |
|
|
max_points, |
|
|
point_size, |
|
|
show_axes, |
|
|
th_bspline, |
|
|
th_line, |
|
|
th_circle, |
|
|
th_arc, |
|
|
file_name, |
|
|
display_pts_state_value, |
|
|
): |
|
|
fig, display_pts = apply_pointcloud_display_settings( |
|
|
model_pts, |
|
|
curves, |
|
|
max_points, |
|
|
point_size, |
|
|
show_axes, |
|
|
th_bspline, |
|
|
th_line, |
|
|
th_circle, |
|
|
th_arc, |
|
|
file_name, |
|
|
) |
|
|
return fig, display_pts |
|
|
|
|
|
|
|
|
for slider in [max_points, point_size]: |
|
|
slider.release( |
|
|
_apply_display_wrapper, |
|
|
inputs=[ |
|
|
model_pts_state, |
|
|
curves_state, |
|
|
max_points, |
|
|
point_size, |
|
|
show_axes, |
|
|
th_bspline, |
|
|
th_line, |
|
|
th_circle, |
|
|
th_arc, |
|
|
file_name_state, |
|
|
display_pts_state, |
|
|
], |
|
|
outputs=[main_plot, display_pts_state], |
|
|
) |
|
|
|
|
|
show_axes.change( |
|
|
_apply_display_wrapper, |
|
|
inputs=[ |
|
|
model_pts_state, |
|
|
curves_state, |
|
|
max_points, |
|
|
point_size, |
|
|
show_axes, |
|
|
th_bspline, |
|
|
th_line, |
|
|
th_circle, |
|
|
th_arc, |
|
|
file_name_state, |
|
|
display_pts_state, |
|
|
], |
|
|
outputs=[main_plot, display_pts_state], |
|
|
) |
|
|
|
|
|
|
|
|
for th in [th_bspline, th_line, th_circle, th_arc]: |
|
|
th.release( |
|
|
_apply_display_wrapper, |
|
|
inputs=[ |
|
|
model_pts_state, |
|
|
curves_state, |
|
|
max_points, |
|
|
point_size, |
|
|
show_axes, |
|
|
th_bspline, |
|
|
th_line, |
|
|
th_circle, |
|
|
th_arc, |
|
|
file_name_state, |
|
|
display_pts_state, |
|
|
], |
|
|
outputs=[main_plot, display_pts_state], |
|
|
) |
|
|
|
|
|
clear_curves_button.click( |
|
|
clear_curves, |
|
|
inputs=[ |
|
|
curves_state, |
|
|
display_pts_state, |
|
|
model_pts_state, |
|
|
point_size, |
|
|
show_axes, |
|
|
file_name_state, |
|
|
], |
|
|
outputs=[main_plot, curves_state], |
|
|
) |
|
|
|
|
|
|
|
|
with gr.Row(): |
|
|
gr.Markdown("### Demo Point Clouds (click an image to load)") |
|
|
with gr.Row(): |
|
|
|
|
|
demo_image_components = {} |
|
|
for label in ["Demo 1", "Demo 2", "Demo 3", "Demo 4", "Demo 5"]: |
|
|
png_path = f"demo_inputs/{label.lower().replace(' ', '')}.png" |
|
|
demo_image_components[label] = gr.Image( |
|
|
value=png_path if os.path.isfile(png_path) else None, |
|
|
label=label, |
|
|
interactive=False, |
|
|
) |
|
|
|
|
|
|
|
|
_demo_loaders = { |
|
|
"Demo 1": load_demo1, |
|
|
"Demo 2": load_demo2, |
|
|
"Demo 3": load_demo3, |
|
|
"Demo 4": load_demo4, |
|
|
"Demo 5": load_demo5, |
|
|
} |
|
|
for label, comp in demo_image_components.items(): |
|
|
comp.select( |
|
|
_demo_loaders[label], |
|
|
inputs=[max_points, point_size, show_axes], |
|
|
outputs=[ |
|
|
main_plot, |
|
|
model_pts_state, |
|
|
display_pts_state, |
|
|
file_name_state, |
|
|
curves_state, |
|
|
file_in, |
|
|
], |
|
|
) |
|
|
|
|
|
|
|
|
demo.load( |
|
|
load_demo2, |
|
|
inputs=[max_points, point_size, show_axes], |
|
|
outputs=[ |
|
|
main_plot, |
|
|
model_pts_state, |
|
|
display_pts_state, |
|
|
file_name_state, |
|
|
curves_state, |
|
|
file_in, |
|
|
], |
|
|
) |
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
|
|
|
initialize_model() |
|
|
demo.launch() |
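    # Illustrative deployment tweak (not used here): enable queuing and bind to
    # all interfaces, e.g. for a container or hosted Space.
    #
    #     demo.queue(max_size=8).launch(server_name="0.0.0.0", server_port=7860)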
|
|
|