# Source file: FiDA-3D-Trellis/2_step_to_svg.py
# (387 lines, 13 KiB; retrieved 2026-03-17 11:28:52 +08:00)
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
STEP -> OCC triangulate -> PyTorch3D orthographic 3 views -> SVG (silhouette + visible/hidden feature edges)
Then combine 3 SVGs into a single SVG and a PNG.
Views:
- top : from +Z towards origin (view direction -Z)
- left : from -X towards origin (view direction +X)
- front: from +Y towards origin (view direction -Y)
Outputs (default in --out_dir):
- top.svg, left.svg, front.svg
- combined_views.svg
- combined_views.png
"""
import os, sys, math, argparse, tempfile
import numpy as np
import torch
from collections import defaultdict
# ---------- PyTorch3D ----------
from pytorch3d.structures import Meshes
from pytorch3d.renderer import (
RasterizationSettings, MeshRasterizer,
MeshRenderer, BlendParams, SoftSilhouetteShader,
look_at_view_transform, OrthographicCameras
)
from pytorch3d.renderer.mesh.rasterizer import Fragments
# ---------- SVG / Image ----------
import svgwrite
from skimage import measure
# ---------- pythonocc (OCC) ----------
from OCC.Core.STEPControl import STEPControl_Reader
from OCC.Core.TopAbs import TopAbs_FACE
from OCC.Core.TopExp import TopExp_Explorer
from OCC.Core.BRep import BRep_Tool
from OCC.Core.BRepMesh import BRepMesh_IncrementalMesh
from OCC.Core.TopLoc import TopLoc_Location
from OCC.Core.ShapeFix import ShapeFix_Shape
# ---------- Combine SVG ----------
import svgutils.transform as st
from svgutils.compose import Unit
import cairosvg
# ---------------- Utils ----------------
def autodevice():
    """Return the first CUDA device when available, otherwise the CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda:0")
    return torch.device("cpu")
def read_step_solid(path: str):
    """Load a STEP file and return the (shape-fixed) OCC shape.

    Raises RuntimeError when the STEP reader reports a non-OK status.
    """
    reader = STEPControl_Reader()
    status = reader.ReadFile(path)
    if status != 1:
        raise RuntimeError(f"STEP 读取失败: {path}")
    reader.TransferRoots()
    raw_shape = reader.OneShape()
    # Run OCC's shape healer before handing the solid downstream.
    healer = ShapeFix_Shape(raw_shape)
    healer.Perform()
    return healer.Shape()
def triangulate(shape, deflection=0.5, angle=0.5):
    """OCC triangulation: returns (V as float32 N x 3, F as int64 M x 3).

    Vertices from all faces are concatenated; face indices are rebased with a
    running vertex offset and converted from OCC's 1-based to 0-based indexing.
    Raises RuntimeError when meshing produced nothing.
    """
    BRepMesh_IncrementalMesh(shape, deflection, False, angle, True)
    all_verts = []
    all_faces = []
    offset = 0
    explorer = TopExp_Explorer(shape, TopAbs_FACE)
    while explorer.More():
        occ_face = explorer.Current()
        location = TopLoc_Location()
        triangulation = BRep_Tool.Triangulation(occ_face, location)
        if triangulation is not None:
            n_nodes = triangulation.NbNodes()
            # Older pythonocc exposes Nodes()/Triangles() arrays; newer
            # releases provide Node(i)/Triangle(i) accessors instead.
            legacy_nodes = hasattr(triangulation, "Nodes")
            for idx in range(1, n_nodes + 1):
                pnt = triangulation.Nodes().Value(idx) if legacy_nodes else triangulation.Node(idx)
                pnt = pnt.Transformed(location.Transformation())
                all_verts.append([pnt.X(), pnt.Y(), pnt.Z()])
            n_tris = triangulation.NbTriangles()
            legacy_tris = hasattr(triangulation, "Triangles")
            for idx in range(1, n_tris + 1):
                tri = triangulation.Triangles().Value(idx) if legacy_tris else triangulation.Triangle(idx)
                a, b, c = tri.Get()
                all_faces.append([offset + a - 1, offset + b - 1, offset + c - 1])
            offset += n_nodes
        explorer.Next()
    if not all_verts or not all_faces:
        raise RuntimeError("三角化为空,尝试减小 --defl")
    return np.asarray(all_verts, np.float32), np.asarray(all_faces, np.int64)
def normalize_mesh_np(verts: np.ndarray, unit_scale=1.0):
    """Center verts on their mean and scale so max |coordinate| == unit_scale."""
    centered = verts - verts.mean(axis=0, keepdims=True)
    # Guard against a degenerate (all-identical) point set.
    extent = max(np.max(np.abs(centered)), 1e-12)
    return centered / extent * unit_scale
def build_mesh(V_np, F_np, device):
    """Wrap numpy verts/faces into a single-element PyTorch3D Meshes batch."""
    verts_t = torch.from_numpy(V_np).to(device)
    faces_t = torch.from_numpy(F_np).to(device)
    return Meshes(verts=[verts_t], faces=[faces_t])
def compute_face_normals(verts: torch.Tensor, faces: torch.Tensor):
    """Per-triangle unit normals via the cross product of two edge vectors."""
    a = verts[faces[:, 0]]
    b = verts[faces[:, 1]]
    c = verts[faces[:, 2]]
    raw = torch.cross(b - a, c - a, dim=1)
    return torch.nn.functional.normalize(raw, dim=1, eps=1e-12)
def extract_feature_edges(faces: torch.Tensor, face_normals: torch.Tensor, angle_deg: float):
    """Collect feature edges: dihedral angle >= angle_deg, plus all boundary edges.

    Args:
        faces: (M, 3) integer tensor of vertex indices per triangle.
        face_normals: (M, 3) unit normals aligned with `faces` rows.
        angle_deg: dihedral-angle threshold in degrees (negative values clamp to 0).

    Returns:
        List of (i, j) vertex-index tuples with i < j, in first-seen order.
    """
    thr = math.cos(math.radians(max(0.0, angle_deg)))
    # Map each undirected edge -> indices of its incident faces.
    e2f = defaultdict(list)
    # Hoist the tensor -> Python conversion out of the loop: one bulk
    # .tolist() instead of a per-face transfer (and possible device sync).
    for f, (i, j, k) in enumerate(faces.tolist()):
        for a, b in ((i, j), (j, k), (k, i)):
            e = (a, b) if a < b else (b, a)
            e2f[e].append(f)
    # Batch the normal dot products for all interior (two-face) edges at once
    # instead of one .item() round trip per edge.
    interior = [(e, fl) for e, fl in e2f.items() if len(fl) == 2]
    sharp = set()
    if interior:
        f0 = torch.tensor([fl[0] for _, fl in interior], device=face_normals.device)
        f1 = torch.tensor([fl[1] for _, fl in interior], device=face_normals.device)
        cosv = torch.clamp((face_normals[f0] * face_normals[f1]).sum(dim=1), -1.0, 1.0)
        sharp = {e for (e, _), c in zip(interior, cosv.tolist()) if c <= thr}
    feat = []
    for e, fl in e2f.items():
        if len(fl) == 1:
            # Boundary edge: always a feature.
            feat.append(e)
        elif len(fl) == 2 and e in sharp:
            feat.append(e)
    return feat
def raster_fragments(mesh: Meshes, cameras, image_size=1600, faces_per_pixel=1):
    """Hard-rasterize the mesh; return (Fragments, first-hit z-buffer HxW)."""
    settings = RasterizationSettings(
        image_size=image_size,
        blur_radius=0.0,
        faces_per_pixel=faces_per_pixel,
        cull_backfaces=True,
    )
    rasterizer = MeshRasterizer(cameras=cameras, raster_settings=settings)
    with torch.no_grad():
        fragments: Fragments = rasterizer(mesh, cameras=cameras)
    # zbuf[0, ..., 0] is the nearest-face depth for the single mesh in the batch.
    return fragments, fragments.zbuf[0, ..., 0]
def render_silhouette(mesh: Meshes, cameras, image_size=1600):
    """Soft-rasterize the mesh silhouette; returns an HxW alpha map in [0, 1]."""
    settings = RasterizationSettings(
        image_size=image_size,
        blur_radius=1e-6,
        faces_per_pixel=50,
        cull_backfaces=True,
    )
    shader = SoftSilhouetteShader(blend_params=BlendParams(sigma=1e-4, gamma=1e-4))
    renderer = MeshRenderer(
        rasterizer=MeshRasterizer(cameras=cameras, raster_settings=settings),
        shader=shader,
    )
    with torch.no_grad():
        rendered = renderer(mesh, cameras=cameras)  # (1, H, W, 4) RGBA
    alpha = rendered[0, ..., 3].detach().cpu().numpy()
    return np.clip(alpha, 0.0, 1.0)
def project_points_to_screen(cameras, points_world: torch.Tensor, image_size: int):
    """Project world-space points to pixel space; returns (N, 3) screen xyz."""
    with torch.no_grad():
        screen = cameras.transform_points_screen(
            points_world.unsqueeze(0),
            image_size=((image_size, image_size),),
        )
    return screen[0]
def edge_visibility_split(edges_idx, verts_world, cameras, zbuf, image_size, eps=1e-4):
    """Split feature edges into visible/hidden 2D segments via a z-buffer test.

    Each edge's projected midpoint is compared against the rasterized depth at
    that pixel; the edge counts as visible when the midpoint is at or in front
    of the surface (within eps). Returns (visible, hidden), each a list of
    (x0, y0, x1, y1) tuples in screen coordinates.
    """
    side = image_size
    screen = project_points_to_screen(cameras, verts_world, image_size)
    depth = zbuf.detach().cpu().numpy()
    visible, hidden = [], []
    for i, j in edges_idx:
        a, b = screen[i], screen[j]
        mid = 0.5 * (a + b)
        # Clamp the midpoint into the image so the depth lookup never indexes out.
        px = int(np.clip(mid[0].item(), 0, side - 1))
        py = int(np.clip(mid[1].item(), 0, side - 1))
        z_mid = mid[2].item()
        z_surf = depth[py, px]
        segment = (a[0].item(), a[1].item(), b[0].item(), b[1].item())
        if np.isfinite(z_surf) and z_mid <= z_surf + eps:
            visible.append(segment)
        else:
            hidden.append(segment)
    return visible, hidden
def trace_silhouettes(alpha: np.ndarray, threshold=0.5, step=1):
    """Trace iso-contours of the alpha mask; returns a list of (x, y) polylines."""
    polys = []
    for contour in measure.find_contours(alpha, level=threshold):
        # Skip tiny speckle contours.
        if len(contour) < 8:
            continue
        contour = contour[::max(1, step)]
        # find_contours yields (row, col); swap to (x, y).
        polys.append(np.stack([contour[:, 1], contour[:, 0]], axis=1))
    return polys
def svg_from_view(out_svg, W, H, silhouettes, edges_vis, edges_hid, stroke=2.0, margin=10, fill="none"):
    """Write one view to an SVG file.

    Draws silhouette paths (optional fill, no stroke), solid lines for visible
    feature edges and dashed, thinner, slightly transparent lines for hidden
    ones. All coordinates are shifted by `margin` on both axes.
    """
    canvas_w = W + 2 * margin
    canvas_h = H + 2 * margin
    dwg = svgwrite.Drawing(out_svg, size=(canvas_w, canvas_h))
    dwg.add(dwg.rect(insert=(0, 0), size=(canvas_w, canvas_h), fill="none"))
    # Silhouette outlines.
    for poly in silhouettes:
        if len(poly) < 3:
            continue
        cmds = [f"M {poly[0, 0] + margin:.2f} {poly[0, 1] + margin:.2f}"]
        cmds.extend(
            f"L {poly[k, 0] + margin:.2f} {poly[k, 1] + margin:.2f}"
            for k in range(1, len(poly))
        )
        cmds.append("Z")
        dwg.add(dwg.path(" ".join(cmds), fill=fill, stroke="none"))
    # Visible feature edges: solid black.
    for (x0, y0, x1, y1) in edges_vis:
        dwg.add(dwg.line(
            (x0 + margin, y0 + margin),
            (x1 + margin, y1 + margin),
            stroke="black",
            stroke_width=stroke,
        ))
    # Hidden feature edges: dashed, thinner, slightly transparent.
    for (x0, y0, x1, y1) in edges_hid:
        dwg.add(dwg.line(
            (x0 + margin, y0 + margin),
            (x1 + margin, y1 + margin),
            stroke="black",
            stroke_width=max(1.0, 0.75 * stroke),
            stroke_dasharray=[6, 6],
            opacity=0.85,
        ))
    dwg.save()
# ---------------- Cameras ----------------
def make_camera(view: str, device, dist=2.7):
    """Build an orthographic camera for one engineering view.

    Spherical parameters for look_at_view_transform:
      top   -> eye on +Z, looking along -Z
      left  -> eye on -X, looking along +X
      front -> eye on +Y, looking along -Y
    """
    spherical = {
        "top": (90.0, 180.0),
        "left": (0.0, 270.0),
        "front": (0.0, 180.0),
    }
    if view not in spherical:
        raise ValueError("view must be one of ['top','left','front']")
    elev, azim = spherical[view]
    R, T = look_at_view_transform(dist=dist, elev=elev, azim=azim, device=device)
    return OrthographicCameras(R=R, T=T, device=device)
# ---------------- Combine 3 SVGs ----------------
def combine_svgs(view1_path, view2_path, view3_path, output_svg, output_image):
    """Tile three SVG views horizontally into one SVG, then rasterize to PNG.

    Views are laid out left-to-right with a 10px outer margin and 20px gaps,
    each slot sized to the largest input view.
    """
    parent = os.path.dirname(output_svg)
    if parent:
        os.makedirs(parent, exist_ok=True)
    figures = [st.fromfile(p) for p in (view1_path, view2_path, view3_path)]
    roots = [fig.getroot() for fig in figures]
    sizes = [
        tuple(float(dim.replace("px", "")) for dim in fig.get_size())
        for fig in figures
    ]
    max_w = max(w for w, _ in sizes)
    max_h = max(h for _, h in sizes)
    total_w = max_w * 3 + 40
    total_h = max_h + 20
    combined = st.SVGFigure(Unit(total_w), Unit(total_h))
    x = 10
    for root in roots:
        root.moveto(x, 10)
        x += max_w + 20
    combined.append(roots)
    combined.save(output_svg)
    # Rasterize the combined SVG to a PNG alongside it.
    cairosvg.svg2png(url=output_svg, write_to=output_image)
# ---------------- Main ----------------
def main():
    """CLI entry point: STEP file -> normalized mesh -> three orthographic
    view SVGs, optionally combined into a single SVG + PNG."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--step_path", type=str, required=False, help="输入 STEP 文件", default="output/sample_clean.step")
    parser.add_argument("--out_dir", type=str, required=False, help="输出目录", default="output")
    parser.add_argument("--res", type=int, default=1400, help="渲染分辨率(正方形)")
    parser.add_argument("--defl", type=float, default=1.5, help="三角化线性偏差")
    parser.add_argument("--ang", type=float, default=65.0, help="二面角阈值(°)")
    parser.add_argument("--stroke", type=float, default=1.3, help="SVG 线宽(px)")
    parser.add_argument("--scale", type=float, default=1.0, help="归一化后模型最大边长")
    parser.add_argument("--no_combine", action="store_true", help="只导出三视图 SVG不合成")
    parser.add_argument("--combined_svg", type=str, default=None, help="合成 SVG 输出路径(默认 out_dir/combined_views.svg")
    parser.add_argument("--combined_png", type=str, default=None, help="合成 PNG 输出路径(默认 out_dir/combined_views.png")
    args = parser.parse_args()

    os.makedirs(args.out_dir, exist_ok=True)
    device = autodevice()
    print(f"[Device] {device}")

    # STEP -> triangulated, normalized mesh on the chosen device.
    shape = read_step_solid(args.step_path)
    V_np, F_np = triangulate(shape, deflection=args.defl)
    V_np = normalize_mesh_np(V_np, unit_scale=args.scale)
    mesh = build_mesh(V_np, F_np, device)
    verts = mesh.verts_packed()
    faces = mesh.faces_packed()
    normals = compute_face_normals(verts, faces)
    feat_edges = extract_feature_edges(faces, normals, angle_deg=args.ang)

    # Render each engineering view and export it as a standalone SVG.
    view_paths = {}
    for name in ("top", "left", "front"):
        print(f"[View] {name}")
        camera = make_camera(name, device)
        alpha = render_silhouette(mesh, camera, image_size=args.res)
        _, zbuf = raster_fragments(mesh, camera, image_size=args.res, faces_per_pixel=1)
        e_vis, e_hid = edge_visibility_split(feat_edges, verts, camera, zbuf, args.res, eps=1e-4)
        polys = trace_silhouettes(alpha, threshold=0.5, step=1)
        out_svg = os.path.join(args.out_dir, f"{name}.svg")
        svg_from_view(out_svg, args.res, args.res, polys, e_vis, e_hid,
                      stroke=args.stroke, margin=10, fill="none")
        view_paths[name] = out_svg
        print(f" -> {out_svg} (edges: vis={len(e_vis)}, hid={len(e_hid)}, polys={len(polys)})")

    if args.no_combine:
        print("[Done] 3 views exported (no combine).")
        return

    combined_svg = args.combined_svg or os.path.join(args.out_dir, "combined_views.svg")
    combined_png = args.combined_png or os.path.join(args.out_dir, "combined_views.png")
    # Combine in the order front / top / left (matches the reference layout).
    combine_svgs(
        view1_path=view_paths["front"],
        view2_path=view_paths["top"],
        view3_path=view_paths["left"],
        output_svg=combined_svg,
        output_image=combined_png,
    )
    print(f"[Combined] SVG: {combined_svg}")
    print(f"[Combined] PNG: {combined_png}")
    print("[Done] All outputs exported.")


if __name__ == "__main__":
    main()