#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
STEP -> OCC triangulate -> PyTorch3D orthographic 3 views -> SVG (silhouette + visible/hidden feature edges)
Then combine 3 SVGs into a single SVG and a PNG.
Views:
- top : from +Z towards origin (view direction -Z)
- left : from -X towards origin (view direction +X)
- front: from +Y towards origin (view direction -Y)
Outputs (default in --out_dir):
- top.svg, left.svg, front.svg
- combined_views.svg
- combined_views.png
"""
import os, sys, math, argparse, tempfile
import numpy as np
import torch
from collections import defaultdict
# ---------- PyTorch3D ----------
from pytorch3d.structures import Meshes
from pytorch3d.renderer import (
RasterizationSettings, MeshRasterizer,
MeshRenderer, BlendParams, SoftSilhouetteShader,
look_at_view_transform, OrthographicCameras
)
from pytorch3d.renderer.mesh.rasterizer import Fragments
# ---------- SVG / Image ----------
import svgwrite
from skimage import measure
# ---------- pythonocc (OCC) ----------
from OCC.Core.STEPControl import STEPControl_Reader
from OCC.Core.TopAbs import TopAbs_FACE
from OCC.Core.TopExp import TopExp_Explorer
from OCC.Core.BRep import BRep_Tool
from OCC.Core.BRepMesh import BRepMesh_IncrementalMesh
from OCC.Core.TopLoc import TopLoc_Location
from OCC.Core.ShapeFix import ShapeFix_Shape
# ---------- Combine SVG ----------
import svgutils.transform as st
from svgutils.compose import Unit
import cairosvg
# ---------------- Utils ----------------
def autodevice():
    """Return the first CUDA device if one is available, otherwise the CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda:0")
    return torch.device("cpu")
def read_step_solid(path: str):
    """Read a STEP file, heal the resulting shape with ShapeFix, and return it.

    Raises RuntimeError when the STEP reader reports a non-OK status.
    """
    reader = STEPControl_Reader()
    status = reader.ReadFile(path)
    if status != 1:
        raise RuntimeError(f"STEP 读取失败: {path}")
    reader.TransferRoots()
    raw_shape = reader.OneShape()
    healer = ShapeFix_Shape(raw_shape)
    healer.Perform()
    return healer.Shape()
def triangulate(shape, deflection=0.5, angle=0.5):
    """Tessellate an OCC shape into a triangle soup.

    Returns (V, F): V is an (N, 3) float32 vertex array, F is an (M, 3)
    int64 triangle-index array (0-based, concatenated over all faces).
    Raises RuntimeError when tessellation produced no geometry.
    """
    BRepMesh_IncrementalMesh(shape, deflection, False, angle, True)
    all_verts = []
    all_tris = []
    base = 0  # running vertex offset across faces
    explorer = TopExp_Explorer(shape, TopAbs_FACE)
    while explorer.More():
        face = explorer.Current()
        loc = TopLoc_Location()
        tri = BRep_Tool.Triangulation(face, loc)
        if tri is not None:
            n_nodes = tri.NbNodes()
            # Older pythonocc builds expose Nodes()/Triangles() arrays;
            # newer ones expose Node(i)/Triangle(i) accessors instead.
            old_nodes_api = hasattr(tri, "Nodes")
            for idx in range(1, n_nodes + 1):
                pnt = tri.Nodes().Value(idx) if old_nodes_api else tri.Node(idx)
                pnt = pnt.Transformed(loc.Transformation())
                all_verts.append([pnt.X(), pnt.Y(), pnt.Z()])
            old_tris_api = hasattr(tri, "Triangles")
            for idx in range(1, tri.NbTriangles() + 1):
                t = tri.Triangles().Value(idx) if old_tris_api else tri.Triangle(idx)
                a, b, c = t.Get()
                # OCC node indices are 1-based; shift to 0-based + face offset.
                all_tris.append([base + a - 1, base + b - 1, base + c - 1])
            base += n_nodes
        explorer.Next()
    if not all_verts or not all_tris:
        raise RuntimeError("三角化为空,尝试减小 --defl")
    return np.asarray(all_verts, np.float32), np.asarray(all_tris, np.int64)
def normalize_mesh_np(verts: np.ndarray, unit_scale=1.0):
    """Center vertices at their mean and scale the largest |coordinate| to unit_scale."""
    centered = verts - verts.mean(axis=0, keepdims=True)
    extent = max(float(np.max(np.abs(centered))), 1e-12)  # guard degenerate meshes
    return centered / extent * unit_scale
def build_mesh(V_np, F_np, device):
    """Wrap numpy vertex/face arrays into a one-element PyTorch3D Meshes batch."""
    verts_t = torch.from_numpy(V_np).to(device)
    faces_t = torch.from_numpy(F_np).to(device)
    return Meshes(verts=[verts_t], faces=[faces_t])
def compute_face_normals(verts: torch.Tensor, faces: torch.Tensor):
    """Per-face unit normals from the cross product of two triangle edges."""
    a = verts[faces[:, 0]]
    b = verts[faces[:, 1]]
    c = verts[faces[:, 2]]
    raw = torch.cross(b - a, c - a, dim=1)
    return torch.nn.functional.normalize(raw, dim=1, eps=1e-12)
def extract_feature_edges(faces: torch.Tensor, face_normals: torch.Tensor, angle_deg: float):
    """Collect feature edges: dihedral angle >= angle_deg, plus every boundary edge.

    Returns a list of (i, j) vertex-index pairs with i < j.
    """
    cos_thr = math.cos(math.radians(max(0.0, angle_deg)))
    edge_owners = defaultdict(list)
    for fi in range(faces.shape[0]):
        a, b, c = faces[fi].tolist()
        for u, v in ((a, b), (b, c), (c, a)):
            key = (u, v) if u < v else (v, u)
            edge_owners[key].append(fi)
    selected = []
    for key, owners in edge_owners.items():
        if len(owners) == 1:
            # Boundary edge: always a feature.
            selected.append(key)
        elif len(owners) == 2:
            dot = (face_normals[owners[0]] * face_normals[owners[1]]).sum()
            if torch.clamp(dot, -1.0, 1.0).item() <= cos_thr:
                selected.append(key)
    return selected
def raster_fragments(mesh: Meshes, cameras, image_size=1600, faces_per_pixel=1):
    """Hard-rasterize the mesh; return (Fragments, nearest-depth map of batch item 0)."""
    settings = RasterizationSettings(
        image_size=image_size,
        blur_radius=0.0,
        faces_per_pixel=faces_per_pixel,
        cull_backfaces=True
    )
    rasterizer = MeshRasterizer(cameras=cameras, raster_settings=settings)
    with torch.no_grad():
        fragments: Fragments = rasterizer(mesh, cameras=cameras)
    # zbuf[0, ..., 0] is the closest-face depth per pixel for the first view.
    return fragments, fragments.zbuf[0, ..., 0]
def render_silhouette(mesh: Meshes, cameras, image_size=1600):
    """Soft-rasterize an alpha silhouette; returns an (H, W) float array in [0, 1]."""
    settings = RasterizationSettings(
        image_size=image_size,
        blur_radius=1e-6,
        faces_per_pixel=50,
        cull_backfaces=True
    )
    renderer = MeshRenderer(
        rasterizer=MeshRasterizer(cameras=cameras, raster_settings=settings),
        shader=SoftSilhouetteShader(blend_params=BlendParams(sigma=1e-4, gamma=1e-4))
    )
    with torch.no_grad():
        rgba = renderer(mesh, cameras=cameras)  # (1, H, W, 4)
    # Keep only the alpha channel, clamped to the valid [0, 1] range.
    return np.clip(rgba[0, ..., 3].detach().cpu().numpy(), 0.0, 1.0)
def project_points_to_screen(cameras, points_world: torch.Tensor, image_size: int):
    """Project (N, 3) world points to pixel space; returns (N, 3) as (x, y, depth)."""
    with torch.no_grad():
        screen = cameras.transform_points_screen(
            points_world[None, ...],
            image_size=((image_size, image_size),)
        )
    # Drop the singleton batch dimension added above.
    return screen[0]
def edge_visibility_split(edges_idx, verts_world, cameras, zbuf, image_size, eps=1e-4):
    """Split edges into visible/hidden 2D segments via a midpoint depth test.

    An edge is visible when its midpoint's projected depth is not behind the
    z-buffer sample at that pixel (within eps). Returns (visible, hidden)
    lists of (x0, y0, x1, y1) screen-space segments.
    """
    side = image_size
    screen = project_points_to_screen(cameras, verts_world, image_size)
    depth = zbuf.detach().cpu().numpy()
    visible, hidden = [], []
    for i, j in edges_idx:
        a, b = screen[i], screen[j]
        mid = 0.5 * (a + b)
        # Clamp the midpoint pixel to the image bounds before sampling.
        px = int(np.clip(mid[0].item(), 0, side - 1))
        py = int(np.clip(mid[1].item(), 0, side - 1))
        seg = (a[0].item(), a[1].item(), b[0].item(), b[1].item())
        z_ref = depth[py, px]
        if np.isfinite(z_ref) and mid[2].item() <= z_ref + eps:
            visible.append(seg)
        else:
            hidden.append(seg)
    return visible, hidden
def trace_silhouettes(alpha: np.ndarray, threshold=0.5, step=1):
    """Extract iso-contours of the alpha mask as (x, y) polylines.

    Contours shorter than 8 points are dropped; `step` subsamples points.
    """
    polygons = []
    for contour in measure.find_contours(alpha, level=threshold):
        if len(contour) < 8:
            continue
        sub = contour[::max(1, step)]
        # find_contours yields (row, col); swap columns to get (x, y).
        polygons.append(np.stack([sub[:, 1], sub[:, 0]], axis=1))
    return polygons
def svg_from_view(out_svg, W, H, silhouettes, edges_vis, edges_hid, stroke=2.0, margin=10, fill="none"):
    """Write one view as SVG: silhouette paths, solid visible edges, dashed hidden edges."""
    total_w = W + 2 * margin
    total_h = H + 2 * margin
    dwg = svgwrite.Drawing(out_svg, size=(total_w, total_h))
    dwg.add(dwg.rect(insert=(0, 0), size=(total_w, total_h), fill="none"))
    # Silhouette outlines (filled only when `fill` is not "none").
    for poly in silhouettes:
        if len(poly) < 3:
            continue
        cmds = [f"M {poly[0, 0] + margin:.2f} {poly[0, 1] + margin:.2f}"]
        cmds.extend(
            f"L {poly[k, 0] + margin:.2f} {poly[k, 1] + margin:.2f}"
            for k in range(1, len(poly))
        )
        cmds.append("Z")
        dwg.add(dwg.path(" ".join(cmds), fill=fill, stroke="none"))
    # Visible feature edges: solid strokes.
    for (x0, y0, x1, y1) in edges_vis:
        dwg.add(dwg.line(
            (x0 + margin, y0 + margin),
            (x1 + margin, y1 + margin),
            stroke="black",
            stroke_width=stroke
        ))
    # Hidden feature edges: thinner, dashed, slightly transparent.
    for (x0, y0, x1, y1) in edges_hid:
        dwg.add(dwg.line(
            (x0 + margin, y0 + margin),
            (x1 + margin, y1 + margin),
            stroke="black",
            stroke_width=max(1.0, 0.75 * stroke),
            stroke_dasharray=[6, 6],
            opacity=0.85
        ))
    dwg.save()
# ---------------- Cameras ----------------
def make_camera(view: str, device, dist=2.7):
    """Build an orthographic camera for a standard engineering view.

    Raises ValueError for an unknown view name.
    """
    # (elev, azim) spherical parameters for look_at_view_transform.
    angles = {
        "top": (90.0, 180.0),   # camera at +Z, viewing along -Z
        "left": (0.0, 270.0),   # camera at -X, viewing along +X
        "front": (0.0, 180.0),  # camera at +Y, viewing along -Y
    }
    if view not in angles:
        raise ValueError("view must be one of ['top','left','front']")
    elev, azim = angles[view]
    R, T = look_at_view_transform(dist=dist, elev=elev, azim=azim, device=device)
    return OrthographicCameras(R=R, T=T, device=device)
# ---------------- Combine 3 SVGs ----------------
def combine_svgs(view1_path, view2_path, view3_path, output_svg, output_image):
    """Lay out three SVGs side by side into one SVG, then rasterize it to PNG."""
    target_dir = os.path.dirname(output_svg)
    if target_dir:
        os.makedirs(target_dir, exist_ok=True)
    figures = [st.fromfile(p) for p in (view1_path, view2_path, view3_path)]
    roots = [fig.getroot() for fig in figures]
    sizes = [
        tuple(float(dim.replace("px", "")) for dim in fig.get_size())
        for fig in figures
    ]
    max_w = max(w for w, _ in sizes)
    max_h = max(h for _, h in sizes)
    # 10 px outer margin plus 20 px gaps between the three panels.
    combined = st.SVGFigure(Unit(max_w * 3 + 40), Unit(max_h + 20))
    x = 10
    for root in roots:
        root.moveto(x, 10)
        x += max_w + 20
    combined.append(roots)
    combined.save(output_svg)
    # Rasterize the combined SVG to a PNG alongside it.
    cairosvg.svg2png(url=output_svg, write_to=output_image)
# ---------------- Main ----------------
def main() -> None:
    """CLI entry point: STEP file -> three orthographic-view SVGs, optionally combined into one SVG + PNG."""
    ap = argparse.ArgumentParser()
    ap.add_argument("--step_path", type=str, required=False, help="输入 STEP 文件", default="output/sample_clean.step")
    ap.add_argument("--out_dir", type=str, required=False, help="输出目录", default="output")
    ap.add_argument("--res", type=int, default=1400, help="渲染分辨率(正方形)")
    ap.add_argument("--defl", type=float, default=1.5, help="三角化线性偏差")
    ap.add_argument("--ang", type=float, default=65.0, help="二面角阈值(°)")
    ap.add_argument("--stroke", type=float, default=1.3, help="SVG 线宽(px)")
    ap.add_argument("--scale", type=float, default=1.0, help="归一化后模型最大边长")
    ap.add_argument("--no_combine", action="store_true", help="只导出三视图 SVG不合成")
    ap.add_argument("--combined_svg", type=str, default=None, help="合成 SVG 输出路径(默认 out_dir/combined_views.svg")
    ap.add_argument("--combined_png", type=str, default=None, help="合成 PNG 输出路径(默认 out_dir/combined_views.png")
    args = ap.parse_args()
    os.makedirs(args.out_dir, exist_ok=True)
    device = autodevice()
    print(f"[Device] {device}")
    # STEP -> triangle mesh: read, tessellate, normalize, wrap for PyTorch3D.
    shape = read_step_solid(args.step_path)
    V_np, F_np = triangulate(shape, deflection=args.defl)
    V_np = normalize_mesh_np(V_np, unit_scale=args.scale)
    mesh = build_mesh(V_np, F_np, device)
    verts = mesh.verts_packed()
    faces = mesh.faces_packed()
    # Feature edges are view-independent, so compute them once up front.
    fn = compute_face_normals(verts, faces)
    feat_edges = extract_feature_edges(faces, fn, angle_deg=args.ang)
    view_paths = {}
    for name in ["top", "left", "front"]:
        print(f"[View] {name}")
        cam = make_camera(name, device)
        # Silhouette mask (soft) and depth buffer (hard) for this view.
        alpha = render_silhouette(mesh, cam, image_size=args.res)
        _, zbuf = raster_fragments(mesh, cam, image_size=args.res, faces_per_pixel=1)
        # Classify feature edges as visible/hidden against the depth buffer.
        e_vis, e_hid = edge_visibility_split(feat_edges, verts, cam, zbuf, args.res, eps=1e-4)
        polys = trace_silhouettes(alpha, threshold=0.5, step=1)
        out_svg = os.path.join(args.out_dir, f"{name}.svg")
        svg_from_view(out_svg, args.res, args.res, polys, e_vis, e_hid,
                      stroke=args.stroke, margin=10, fill="none")
        view_paths[name] = out_svg
        print(f" -> {out_svg} (edges: vis={len(e_vis)}, hid={len(e_hid)}, polys={len(polys)})")
    if args.no_combine:
        print("[Done] 3 views exported (no combine).")
        return
    combined_svg = args.combined_svg or os.path.join(args.out_dir, "combined_views.svg")
    combined_png = args.combined_png or os.path.join(args.out_dir, "combined_views.png")
    # NOTE: panel order in the combined sheet is deliberately front / top / left.
    combine_svgs(
        view1_path=view_paths["front"],
        view2_path=view_paths["top"],
        view3_path=view_paths["left"],
        output_svg=combined_svg,
        output_image=combined_png
    )
    print(f"[Combined] SVG: {combined_svg}")
    print(f"[Combined] PNG: {combined_png}")
    print("[Done] All outputs exported.")


if __name__ == "__main__":
    main()