#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
STEP -> OCC triangulate -> PyTorch3D orthographic 3 views -> SVG
(silhouette + visible/hidden feature edges).
Then combine the 3 SVGs into a single SVG and a PNG.

Views (world axes of the STEP model):
  - top  : from +Z towards origin (view direction -Z)
  - left : from -X towards origin (view direction +X)
  - front: from +Y towards origin (view direction -Y)

Outputs (default in --out_dir):
  - top.svg, left.svg, front.svg
  - combined_views.svg
  - combined_views.png
"""
import argparse
import math
import os
import sys
import tempfile
from collections import defaultdict

import numpy as np
import torch

# ---------- PyTorch3D ----------
from pytorch3d.structures import Meshes
from pytorch3d.renderer import (
    RasterizationSettings,
    MeshRasterizer,
    MeshRenderer,
    BlendParams,
    SoftSilhouetteShader,
    look_at_view_transform,
    OrthographicCameras,
)
from pytorch3d.renderer.mesh.rasterizer import Fragments

# ---------- SVG / Image ----------
import svgwrite
from skimage import measure

# ---------- pythonocc (OCC) ----------
from OCC.Core.STEPControl import STEPControl_Reader
from OCC.Core.TopAbs import TopAbs_FACE, TopAbs_REVERSED
from OCC.Core.TopExp import TopExp_Explorer
from OCC.Core.BRep import BRep_Tool
from OCC.Core.BRepMesh import BRepMesh_IncrementalMesh
from OCC.Core.TopLoc import TopLoc_Location
from OCC.Core.ShapeFix import ShapeFix_Shape

# ---------- Combine SVG ----------
import svgutils.transform as st
from svgutils.compose import Unit
import cairosvg


# ---------------- Utils ----------------
def autodevice():
    """Return ``cuda:0`` when CUDA is available, otherwise the CPU device."""
    return torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


def read_step_solid(path: str):
    """Read a STEP file, transfer all roots and return the healed shape.

    Raises:
        RuntimeError: if ``ReadFile`` does not report success
            (1 == IFSelect_RetDone).
    """
    reader = STEPControl_Reader()
    stat = reader.ReadFile(path)
    if stat != 1:
        raise RuntimeError(f"STEP 读取失败: {path}")
    reader.TransferRoots()
    shape = reader.OneShape()
    # ShapeFix repairs common STEP defects before meshing.
    fixer = ShapeFix_Shape(shape)
    fixer.Perform()
    return fixer.Shape()


def triangulate(shape, deflection=0.5, angle=0.5):
    """Tessellate every face of *shape* with OCC.

    Returns:
        (V, F): V is an (N, 3) float32 vertex array, F an (M, 3) int64 array
        of 0-based vertex indices. Vertices are NOT welded across faces:
        every face contributes its own node list.

    Triangles of faces whose orientation is TopAbs_REVERSED get their winding
    flipped. OCC triangulates in the face's parametric orientation, so without
    this the computed normals point inwards on such faces and back-face
    culling removes them from the rendered views.

    Raises:
        RuntimeError: if no triangle was produced.
    """
    BRepMesh_IncrementalMesh(shape, deflection, False, angle, True)
    verts, faces, v_off = [], [], 0
    exp = TopExp_Explorer(shape, TopAbs_FACE)
    while exp.More():
        face = exp.Current()
        loc = TopLoc_Location()
        tri = BRep_Tool.Triangulation(face, loc)
        if tri is not None:
            trsf = loc.Transformation()
            nb_nodes = tri.NbNodes()
            # Older pythonocc exposes the Nodes()/Triangles() arrays; newer
            # versions only the per-index Node(i)/Triangle(i) accessors.
            nodes = tri.Nodes() if hasattr(tri, "Nodes") else None
            for i in range(1, nb_nodes + 1):
                p = nodes.Value(i) if nodes is not None else tri.Node(i)
                p = p.Transformed(trsf)
                verts.append([p.X(), p.Y(), p.Z()])

            flip = face.Orientation() == TopAbs_REVERSED
            tris = tri.Triangles() if hasattr(tri, "Triangles") else None
            for i in range(1, tri.NbTriangles() + 1):
                t = tris.Value(i) if tris is not None else tri.Triangle(i)
                a, b, c = t.Get()
                if flip:
                    b, c = c, b
                # OCC node indices are 1-based.
                faces.append([v_off + a - 1, v_off + b - 1, v_off + c - 1])
            v_off += nb_nodes
        exp.Next()

    if not verts or not faces:
        raise RuntimeError("三角化为空,尝试减小 --defl")
    return np.asarray(verts, np.float32), np.asarray(faces, np.int64)


def normalize_mesh_np(verts: np.ndarray, unit_scale=1.0):
    """Center *verts* on their bounding-box center and scale uniformly.

    After normalisation the largest absolute coordinate equals *unit_scale*,
    so the longest bounding-box side is ``2 * unit_scale``. The bounding-box
    center (rather than the vertex mean) is used so that densely tessellated
    regions do not pull the part off-center in the rendered views.
    Empty input is returned unchanged (as a copy).
    """
    verts = np.asarray(verts)
    if verts.size == 0:
        return verts.copy()
    lo = verts.min(axis=0, keepdims=True)
    hi = verts.max(axis=0, keepdims=True)
    v0 = verts - 0.5 * (lo + hi)
    s = max(np.max(np.abs(v0)), 1e-12)  # guard against a degenerate point cloud
    return v0 / s * unit_scale


def build_mesh(V_np, F_np, device):
    """Wrap numpy vertices/faces into a single-item PyTorch3D ``Meshes``."""
    V = torch.from_numpy(V_np).to(device)
    F = torch.from_numpy(F_np).to(device)
    return Meshes(verts=[V], faces=[F])


def compute_face_normals(verts: torch.Tensor, faces: torch.Tensor):
    """Unit normals (v1 - v0) x (v2 - v0) per triangle, shape (F, 3)."""
    v0 = verts[faces[:, 0]]
    v1 = verts[faces[:, 1]]
    v2 = verts[faces[:, 2]]
    n = torch.cross(v1 - v0, v2 - v0, dim=1)
    return torch.nn.functional.normalize(n, dim=1, eps=1e-12)


def extract_feature_edges(faces: torch.Tensor, face_normals: torch.Tensor, angle_deg: float):
    """Return feature edges as ``(i, j)`` vertex-index tuples with ``i < j``.

    An edge is a feature edge when
      * it is a boundary edge (used by exactly one triangle),
      * it is non-manifold (used by more than two triangles), or
      * the angle between the normals of its two triangles is >= angle_deg.

    Edges are returned in first-encounter order (triangle order, then the
    (i, j), (j, k), (k, i) edges of each triangle).
    """
    thr = math.cos(math.radians(max(0.0, angle_deg)))
    # One device->host transfer instead of one per triangle / per edge
    # (each .tolist()/.item() on a CUDA tensor forces a synchronisation).
    face_list = faces.detach().cpu().tolist()
    normals = face_normals.detach().cpu().numpy()

    e2f = defaultdict(list)
    for f, (i, j, k) in enumerate(face_list):
        for a, b in ((i, j), (j, k), (k, i)):
            e2f[(a, b) if a < b else (b, a)].append(f)

    feat = []
    for e, fl in e2f.items():
        if len(fl) == 2:
            cosv = float(np.dot(normals[fl[0]], normals[fl[1]]))
            cosv = min(1.0, max(-1.0, cosv))
            if cosv <= thr:
                feat.append(e)
        else:
            # Boundary (1 triangle) or non-manifold (>2 triangles): always kept.
            feat.append(e)
    return feat
def raster_fragments(mesh: Meshes, cameras, image_size=1600, faces_per_pixel=1):
    """Rasterise *mesh* with back faces culled and return ``(fragments, zbuf)``.

    ``zbuf`` is ``frags.zbuf[0, ..., 0]``, the nearest-face depth map of shape
    (image_size, image_size). PyTorch3D writes -1 where no face covers a pixel.
    """
    rs = RasterizationSettings(
        image_size=image_size,
        blur_radius=0.0,
        faces_per_pixel=faces_per_pixel,
        cull_backfaces=True,
    )
    rast = MeshRasterizer(cameras=cameras, raster_settings=rs)
    with torch.no_grad():
        frags: Fragments = rast(mesh, cameras=cameras)
    return frags, frags.zbuf[0, ..., 0]


def render_silhouette(mesh: Meshes, cameras, image_size=1600):
    """Render a (near-binary) silhouette and return its alpha channel.

    Returns an (image_size, image_size) float numpy array clipped to [0, 1].

    NOTE(review): faces_per_pixel=50 makes the rasteriser keep 50 layers per
    pixel at full resolution; with this tiny blur radius fewer layers would
    likely suffice. Confirm before lowering.
    """
    rs = RasterizationSettings(
        image_size=image_size,
        blur_radius=1e-6,
        faces_per_pixel=50,
        cull_backfaces=True,
    )
    renderer = MeshRenderer(
        rasterizer=MeshRasterizer(cameras=cameras, raster_settings=rs),
        shader=SoftSilhouetteShader(blend_params=BlendParams(sigma=1e-4, gamma=1e-4)),
    )
    with torch.no_grad():
        img = renderer(mesh, cameras=cameras)  # (1, H, W, 4)
    return np.clip(img[0, ..., 3].detach().cpu().numpy(), 0.0, 1.0)


def project_points_to_screen(cameras, points_world: torch.Tensor, image_size: int):
    """Project world points to PyTorch3D screen space; returns an (N, 3) tensor.

    The result holds (x, y) in pixels and the camera depth in the last column.
    """
    with torch.no_grad():
        scr = cameras.transform_points_screen(
            points_world[None, ...], image_size=((image_size, image_size),)
        )
    return scr[0]  # (N, 3)


def _zbuf_pixel_slope(zmin, ys, xs):
    """Estimate the depth change over one pixel (|dz/dx| + |dz/dy|).

    The estimate is taken at pixels (ys, xs) of the depth map *zmin*, using
    only covered neighbours (zbuf >= 0). Per axis, the smaller of the two
    one-sided differences is used, so a depth jump on one side (an occluder
    outline) does not inflate the estimate. Axes without a usable neighbour
    contribute 0.
    """
    H, W = zmin.shape
    center = zmin[ys, xs].astype(np.float64)
    slope = np.zeros(center.shape, dtype=np.float64)
    for dy, dx in ((0, 1), (1, 0)):
        best = np.full(center.shape, np.inf)
        for sgn in (-1, 1):
            ny = ys + sgn * dy
            nx = xs + sgn * dx
            inside = (ny >= 0) & (ny < H) & (nx >= 0) & (nx < W)
            nz = np.full(center.shape, -1.0)
            nz[inside] = zmin[ny[inside], nx[inside]]
            usable = inside & (center >= 0) & (nz >= 0)
            best = np.where(usable, np.minimum(best, np.abs(nz - center)), best)
        slope += np.where(np.isfinite(best), best, 0.0)
    return slope


def edge_visibility_split(edges_idx, verts_world, cameras, zbuf, image_size, eps=1e-4):
    """Split feature edges into visible and hidden screen-space segments.

    Each edge is tested at its midpoint. The midpoint's projected depth is
    compared with the z-buffer at the pixel containing it. Screen space has
    (0, 0) at the top-left corner of the top-left pixel, so truncation gives
    the pixel index.

    * Pixels covered by no face (zbuf < 0, PyTorch3D's background value) cannot
      occlude anything, so such edges are visible. This typically happens for
      edges lying exactly on the outline.
    * The depth tolerance is ``eps`` plus the local z-buffer slope over one
      pixel. The z-buffer is sampled at the pixel center, while the midpoint
      may lie up to half a pixel away; on inclined faces a fixed tiny ``eps``
      would mark edges lying on the visible surface as hidden. This is a
      heuristic.

    Returns:
        (vis, hid): lists of ``(x0, y0, x1, y1)`` float tuples in screen pixels.
    """
    H = W = image_size
    edges = np.asarray(edges_idx, dtype=np.int64).reshape(-1, 2)
    if edges.shape[0] == 0:
        return [], []

    # Single device->host transfer instead of per-edge .item() calls.
    scr = project_points_to_screen(cameras, verts_world, image_size).detach().cpu().numpy()
    zmin = zbuf.detach().cpu().numpy()

    p0 = scr[edges[:, 0]]
    p1 = scr[edges[:, 1]]
    mid = 0.5 * (p0 + p1)
    xs = np.clip(mid[:, 0], 0, W - 1).astype(np.int64)
    ys = np.clip(mid[:, 1], 0, H - 1).astype(np.int64)
    z_proj = mid[:, 2].astype(np.float64)
    z_ref = zmin[ys, xs].astype(np.float64)
    tol = eps + _zbuf_pixel_slope(zmin, ys, xs)

    finite = np.isfinite(z_ref)
    background = finite & (z_ref < 0)
    in_front = finite & (z_proj <= z_ref + tol)
    visible = background | in_front

    segs = np.concatenate([p0[:, :2], p1[:, :2]], axis=1).tolist()
    vis = [tuple(s) for s, v in zip(segs, visible) if v]
    hid = [tuple(s) for s, v in zip(segs, visible) if not v]
    return vis, hid


def trace_silhouettes(alpha: np.ndarray, threshold=0.5, step=1):
    """Trace iso-contours of *alpha* as (x, y) polylines in screen space.

    The image is zero-padded by one pixel, so shapes touching the image border
    still produce closed contours. ``find_contours`` returns coordinates with
    pixel centers at integers, whereas the projected edges use PyTorch3D
    screen space (pixel centers at i + 0.5). The contours are therefore
    shifted by +0.5 to line up with the edges.
    Contours with fewer than 8 points are dropped; ``step`` subsamples points.
    """
    padded = np.pad(alpha, 1, mode="constant", constant_values=0.0)
    contours = measure.find_contours(padded, level=threshold)
    polys = []
    for cnt in contours:
        if len(cnt) < 8:
            continue
        # -1 undoes the padding, +0.5 converts pixel-center to pixel-corner origin.
        cnt = cnt[::max(1, step)] - 0.5
        polys.append(np.stack([cnt[:, 1], cnt[:, 0]], axis=1))  # (x, y)
    return polys


def svg_from_view(out_svg, W, H, silhouettes, edges_vis, edges_hid,
                  stroke=2.0, margin=10, fill="none", outline_color="black"):
    """Write one view as an SVG of size (W + 2*margin) x (H + 2*margin).

    Args:
        silhouettes: (x, y) polylines; each is drawn as a closed path filled
            with *fill* and stroked with *outline_color*. Pass
            ``outline_color=None`` to leave the outline unstroked. Silhouette
            outlines carry curved limbs (e.g. cylinder sides) that are not
            feature edges.
        edges_vis / edges_hid: (x0, y0, x1, y1) segments drawn solid / dashed.
    """
    size = (W + 2 * margin, H + 2 * margin)
    dwg = svgwrite.Drawing(out_svg, size=size)
    dwg.add(dwg.rect(insert=(0, 0), size=size, fill="none"))

    # Silhouette (outline, optional fill). Each contour is drawn separately.
    outline = outline_color if outline_color else "none"
    for poly in silhouettes:
        if len(poly) < 3:
            continue
        cmds = [f"M {poly[0, 0] + margin:.2f} {poly[0, 1] + margin:.2f}"]
        cmds.extend(f"L {x + margin:.2f} {y + margin:.2f}" for x, y in poly[1:])
        cmds.append("Z")
        dwg.add(dwg.path(" ".join(cmds), fill=fill, stroke=outline,
                         stroke_width=stroke, stroke_linejoin="round"))

    # Hidden edges first, so that visible lines are painted on top.
    for (x0, y0, x1, y1) in edges_hid:
        dwg.add(dwg.line(
            (x0 + margin, y0 + margin), (x1 + margin, y1 + margin),
            stroke="black", stroke_width=max(1.0, 0.75 * stroke),
            # SVG expects a string; a Python list would serialise as "[6, 6]".
            stroke_dasharray="6,6", opacity=0.85,
        ))

    for (x0, y0, x1, y1) in edges_vis:
        dwg.add(dwg.line(
            (x0 + margin, y0 + margin), (x1 + margin, y1 + margin),
            stroke="black", stroke_width=stroke,
        ))

    dwg.save()


# ---------------- Cameras ----------------
def make_camera(view: str, device, dist=2.7):
    """Build an orthographic camera for one engineering view.

    The camera sits at ``dist`` on a world axis and looks at the origin:
      - top  : eye on +Z (view direction -Z), world +Y up in the image
      - left : eye on -X (view direction +X), world +Z up
      - front: eye on +Y (view direction -Y), world +Z up

    Explicit eye/up vectors are used because PyTorch3D's elev/azim parameters
    are Y-up: elev=90 places the camera on +Y, and elev=0/azim=180 places it
    on -Z. Neither matches the views above.

    Raises:
        ValueError: for an unknown view name.
    """
    eye_up = {
        "top": ((0.0, 0.0, dist), (0.0, 1.0, 0.0)),
        "left": ((-dist, 0.0, 0.0), (0.0, 0.0, 1.0)),
        "front": ((0.0, dist, 0.0), (0.0, 0.0, 1.0)),
    }
    if view not in eye_up:
        raise ValueError("view must be one of ['top','left','front']")
    eye, up = eye_up[view]
    R, T = look_at_view_transform(eye=(eye,), at=((0.0, 0.0, 0.0),), up=(up,), device=device)
    return OrthographicCameras(R=R, T=T, device=device)


# ---------------- Combine 3 SVGs ----------------
def combine_svgs(view1_path, view2_path, view3_path, output_svg, output_image):
    """Place 3 SVGs side by side (argument order, left to right) and export a PNG.

    Layout: a 10 px outer margin on every side and 20 px between views. Each
    view gets a slot as wide as the widest view and is top-aligned. Parent
    directories of both outputs are created if needed.
    """
    pad, gap = 10, 20
    for path in (output_svg, output_image):
        out_dir = os.path.dirname(path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)

    figures = [st.fromfile(p) for p in (view1_path, view2_path, view3_path)]
    sizes = [[float(v.replace("px", "")) for v in fig.get_size()] for fig in figures]
    max_w = max(w for w, _ in sizes)
    max_h = max(h for _, h in sizes)

    # 3 slots + 2 gaps + left/right margins (previously 20 px short on the right).
    combined_width = 3 * max_w + 2 * gap + 2 * pad
    combined_height = max_h + 2 * pad
    combined = st.SVGFigure(Unit(combined_width), Unit(combined_height))

    roots = []
    for idx, fig in enumerate(figures):
        root = fig.getroot()
        root.moveto(pad + idx * (max_w + gap), pad)
        roots.append(root)
    combined.append(roots)
    combined.save(output_svg)

    # SVG -> PNG
    cairosvg.svg2png(url=output_svg, write_to=output_image)


# ---------------- Main ----------------
def main():
    """CLI entry point: export the three view SVGs and optionally combine them."""
    ap = argparse.ArgumentParser()
    ap.add_argument("--step_path", type=str, required=False, help="输入 STEP 文件", default="output/sample_clean.step")
    ap.add_argument("--out_dir", type=str, required=False, help="输出目录", default="output")
    ap.add_argument("--res", type=int, default=1400, help="渲染分辨率(正方形)")
    ap.add_argument("--defl", type=float, default=1.5, help="三角化线性偏差")
    ap.add_argument("--ang", type=float, default=65.0, help="二面角阈值(°)")
    ap.add_argument("--stroke", type=float, default=1.3, help="SVG 线宽(px)")
    ap.add_argument("--scale", type=float, default=1.0, help="归一化后模型最大边长")
    ap.add_argument("--no_combine", action="store_true", help="只导出三视图 SVG,不合成")
    ap.add_argument("--combined_svg", type=str, default=None, help="合成 SVG 输出路径(默认 out_dir/combined_views.svg)")
    ap.add_argument("--combined_png", type=str, default=None, help="合成 PNG 输出路径(默认 out_dir/combined_views.png)")
    args = ap.parse_args()

    os.makedirs(args.out_dir, exist_ok=True)
    device = autodevice()
    print(f"[Device] {device}")

    # STEP -> mesh
    shape = read_step_solid(args.step_path)
    V_np, F_np = triangulate(shape, deflection=args.defl)
    V_np = normalize_mesh_np(V_np, unit_scale=args.scale)
    mesh = build_mesh(V_np, F_np, device)
    verts = mesh.verts_packed()
    faces = mesh.faces_packed()

    fn = compute_face_normals(verts, faces)
    feat_edges = extract_feature_edges(faces, fn, angle_deg=args.ang)

    view_paths = {}
    for name in ["top", "left", "front"]:
        print(f"[View] {name}")
        cam = make_camera(name, device)
        alpha = render_silhouette(mesh, cam, image_size=args.res)
        _, zbuf = raster_fragments(mesh, cam, image_size=args.res, faces_per_pixel=1)
        e_vis, e_hid = edge_visibility_split(feat_edges, verts, cam, zbuf, args.res, eps=1e-4)
        polys = trace_silhouettes(alpha, threshold=0.5, step=1)
        out_svg = os.path.join(args.out_dir, f"{name}.svg")
        svg_from_view(out_svg, args.res, args.res, polys, e_vis, e_hid,
                      stroke=args.stroke, margin=10, fill="none")
        view_paths[name] = out_svg
        print(f" -> {out_svg} (edges: vis={len(e_vis)}, hid={len(e_hid)}, polys={len(polys)})")

    if args.no_combine:
        print("[Done] 3 views exported (no combine).")
        return

    combined_svg = args.combined_svg or os.path.join(args.out_dir, "combined_views.svg")
    combined_png = args.combined_png or os.path.join(args.out_dir, "combined_views.png")

    # Combination order follows the reference example: front / top / left.
    combine_svgs(
        view1_path=view_paths["front"],
        view2_path=view_paths["top"],
        view3_path=view_paths["left"],
        output_svg=combined_svg,
        output_image=combined_png,
    )
    print(f"[Combined] SVG: {combined_svg}")
    print(f"[Combined] PNG: {combined_png}")
    print("[Done] All outputs exported.")


if __name__ == "__main__":
    main()