387 lines
13 KiB
Python
Executable File
387 lines
13 KiB
Python
Executable File
#!/usr/bin/env python3
|
||
# -*- coding: utf-8 -*-
|
||
|
||
"""
|
||
STEP -> OCC triangulation -> PyTorch3D orthographic 3 views -> SVG (silhouette + visible/hidden feature edges).

Then combine the 3 SVGs into a single SVG and a PNG.

Views (the model is assumed to be Z-up):
- top  : from +Z towards origin (view direction -Z)
- left : from -X towards origin (view direction +X)
- front: from +Y towards origin (view direction -Y)

Outputs (default in --out_dir):
- top.svg, left.svg, front.svg
- combined_views.svg (views laid out left to right as front / top / left)
- combined_views.png
|
||
"""
|
||
|
||
import argparse
import math
import os
import sys
import tempfile
from collections import defaultdict

import numpy as np
import torch

# ---------- PyTorch3D ----------
from pytorch3d.structures import Meshes
from pytorch3d.renderer import (
    RasterizationSettings, MeshRasterizer,
    MeshRenderer, BlendParams, SoftSilhouetteShader,
    look_at_view_transform, OrthographicCameras
)
from pytorch3d.renderer.mesh.rasterizer import Fragments

# ---------- SVG / Image ----------
import svgwrite
from skimage import measure

# ---------- pythonocc (OCC) ----------
from OCC.Core.STEPControl import STEPControl_Reader
from OCC.Core.TopAbs import TopAbs_FACE, TopAbs_REVERSED
from OCC.Core.TopExp import TopExp_Explorer
from OCC.Core.BRep import BRep_Tool
from OCC.Core.BRepMesh import BRepMesh_IncrementalMesh
from OCC.Core.TopLoc import TopLoc_Location
from OCC.Core.ShapeFix import ShapeFix_Shape

# ---------- Combine SVG ----------
import svgutils.transform as st
from svgutils.compose import Unit
import cairosvg
|
||
|
||
|
||
# ---------------- Utils ----------------
|
||
|
||
def autodevice():
    """Return the first CUDA device if CUDA is available, otherwise the CPU device."""
    if torch.cuda.is_available():
        return torch.device("cuda:0")
    return torch.device("cpu")
|
||
|
||
|
||
def read_step_solid(path: str):
    """Read a STEP file and return its shape after one ShapeFix healing pass.

    Raises:
        RuntimeError: if ``ReadFile`` returns anything other than 1
            (presumably ``IFSelect_RetDone``; confirm against the OCC version in use).
    """
    step_reader = STEPControl_Reader()
    if step_reader.ReadFile(path) != 1:
        raise RuntimeError(f"STEP 读取失败: {path}")
    step_reader.TransferRoots()
    healer = ShapeFix_Shape(step_reader.OneShape())
    healer.Perform()
    return healer.Shape()
|
||
|
||
|
||
def triangulate(shape, deflection=0.5, angle=0.5):
    """Tessellate every face of an OCC shape and merge the results into one triangle soup.

    Args:
        shape: OCC shape to mesh. ``BRepMesh_IncrementalMesh`` stores the
            triangulation on the shape's faces.
        deflection: linear deflection for the mesher (smaller -> finer mesh).
        angle: angular deflection for the mesher.

    Returns:
        (V, F): ``V`` is an (N, 3) float32 array of vertex positions with each
        face's location transform applied. ``F`` is an (M, 3) int64 array of
        0-based vertex indices. Each face contributes its own vertices, so
        vertices on boundaries between faces are duplicated.

    Raises:
        RuntimeError: if no vertices or triangles were produced.
    """
    BRepMesh_IncrementalMesh(shape, deflection, False, angle, True)
    verts, faces, v_off = [], [], 0
    exp = TopExp_Explorer(shape, TopAbs_FACE)
    while exp.More():
        face = exp.Current()
        loc = TopLoc_Location()
        tri = BRep_Tool.Triangulation(face, loc)
        if tri is not None:
            # The location transform is the same for every node of this face.
            trsf = loc.Transformation()
            nb_nodes = tri.NbNodes()
            # Older pythonocc builds expose Nodes()/Triangles() arrays; newer
            # ones only provide per-index Node(i)/Triangle(i).
            nodes = tri.Nodes() if hasattr(tri, "Nodes") else None
            for i in range(1, nb_nodes + 1):
                p = nodes.Value(i) if nodes is not None else tri.Node(i)
                p = p.Transformed(trsf)
                verts.append([p.X(), p.Y(), p.Z()])

            # A face with REVERSED orientation has its outward normal opposite to
            # the winding stored in its triangulation. Swap two corners so every
            # triangle faces outward. Otherwise back-face culling during
            # rasterization drops these triangles and their normals are inverted.
            flip = face.Orientation() == TopAbs_REVERSED
            tris = tri.Triangles() if hasattr(tri, "Triangles") else None
            for i in range(1, tri.NbTriangles() + 1):
                t = tris.Value(i) if tris is not None else tri.Triangle(i)
                a, b, c = t.Get()
                if flip:
                    b, c = c, b
                # OCC indices are 1-based and local to this face.
                faces.append([v_off + a - 1, v_off + b - 1, v_off + c - 1])
            v_off += nb_nodes
        exp.Next()
    if not verts or not faces:
        raise RuntimeError("三角化为空,尝试减小 --defl")
    return np.asarray(verts, np.float32), np.asarray(faces, np.int64)
|
||
|
||
|
||
def normalize_mesh_np(verts: np.ndarray, unit_scale=1.0):
    """Center vertices on their axis-aligned bounding box and scale uniformly.

    After normalization the largest absolute coordinate equals ``unit_scale``,
    so with ``unit_scale=1`` the model fits inside [-1, 1]^3 and the longest
    bounding-box side is ``2 * unit_scale``.

    The bounding-box center is used rather than the vertex mean because the
    mean depends on tessellation density. Finely meshed regions such as fillets
    pull the mean, which would put the model off-center in every rendered view.

    Args:
        verts: (N, 3) vertex array. float32 input yields float32 output.
        unit_scale: target value of the largest absolute coordinate.

    Raises:
        ValueError: if ``verts`` is empty.
    """
    verts = np.asarray(verts)
    if verts.size == 0:
        raise ValueError("normalize_mesh_np: empty vertex array")
    c = 0.5 * (verts.min(axis=0, keepdims=True) + verts.max(axis=0, keepdims=True))
    v0 = verts - c
    # Guard against a degenerate (single-point) mesh.
    s = max(float(np.max(np.abs(v0))), 1e-12)
    return v0 / s * unit_scale
|
||
|
||
|
||
def build_mesh(V_np, F_np, device):
    """Move numpy vertex/face arrays onto ``device`` and wrap them in a one-item ``Meshes`` batch."""
    verts_t, faces_t = (torch.from_numpy(arr).to(device) for arr in (V_np, F_np))
    return Meshes(verts=[verts_t], faces=[faces_t])
|
||
|
||
|
||
def compute_face_normals(verts: torch.Tensor, faces: torch.Tensor):
    """Return one unit normal per triangle, oriented by its vertex winding.

    The normal is ``(v1 - v0) x (v2 - v0)`` normalized along the last axis.
    Zero-area triangles produce zero vectors because of the ``eps`` guard.
    """
    corners = verts[faces]  # [face, corner, xyz]
    edge_a = corners[:, 1] - corners[:, 0]
    edge_b = corners[:, 2] - corners[:, 0]
    raw = torch.cross(edge_a, edge_b, dim=1)
    return torch.nn.functional.normalize(raw, dim=1, eps=1e-12)
|
||
|
||
|
||
def extract_feature_edges(faces: torch.Tensor, face_normals: torch.Tensor, angle_deg: float):
    """Collect sharp edges and open-boundary edges of a triangle mesh.

    An edge shared by exactly two faces is a feature edge when the angle
    between their normals is >= ``angle_deg``, i.e. when cos <= cos(angle_deg).
    An edge used by a single face (open boundary) is always a feature edge.
    Edges shared by more than two faces (non-manifold) are not reported.

    Args:
        faces: (F, 3) integer vertex indices.
        face_normals: (F, 3) unit normals, one per face.
        angle_deg: dihedral threshold in degrees (negative values act as 0).

    Returns:
        List of ``(a, b)`` vertex-index tuples with ``a < b``, in first-seen order.
    """
    thr = math.cos(math.radians(max(0.0, angle_deg)))

    # A single device->host transfer. Per-row ``.tolist()`` or per-edge
    # ``.item()`` would each force a separate GPU sync on CUDA tensors.
    face_rows = faces.detach().cpu().tolist()
    e2f = defaultdict(list)
    for f, (i, j, k) in enumerate(face_rows):
        for a, b in ((i, j), (j, k), (k, i)):
            e2f[(a, b) if a < b else (b, a)].append(f)

    edges = list(e2f.items())

    # Evaluate every two-face dihedral cosine in one batched operation.
    pair_pos = [n for n, (_, fl) in enumerate(edges) if len(fl) == 2]
    cos_at = {}
    if pair_pos:
        normals = face_normals.detach().cpu()
        f0 = torch.tensor([edges[n][1][0] for n in pair_pos], dtype=torch.long)
        f1 = torch.tensor([edges[n][1][1] for n in pair_pos], dtype=torch.long)
        cosv = torch.clamp((normals[f0] * normals[f1]).sum(dim=1), -1.0, 1.0)
        cos_at = dict(zip(pair_pos, cosv.tolist()))

    feat = []
    for n, (e, fl) in enumerate(edges):
        if len(fl) == 1 or (len(fl) == 2 and cos_at[n] <= thr):
            feat.append(e)
    return feat
|
||
|
||
|
||
def raster_fragments(mesh: Meshes, cameras, image_size=1600, faces_per_pixel=1):
    """Hard-rasterize ``mesh`` from ``cameras`` without gradients.

    Back faces are culled and no blur is applied.

    Returns:
        (fragments, depth): the full ``Fragments`` object and
        ``fragments.zbuf[0, ..., 0]``, the nearest-layer depth map of the first
        mesh in the batch. Per the PyTorch3D docs, uncovered pixels hold -1.
    """
    settings = RasterizationSettings(
        image_size=image_size,
        blur_radius=0.0,
        faces_per_pixel=faces_per_pixel,
        cull_backfaces=True,
    )
    rasterizer = MeshRasterizer(cameras=cameras, raster_settings=settings)
    with torch.no_grad():
        fragments: Fragments = rasterizer(mesh, cameras=cameras)
    nearest_depth = fragments.zbuf[0, ..., 0]
    return fragments, nearest_depth
|
||
|
||
|
||
def render_silhouette(mesh: Meshes, cameras, image_size=1600):
    """Render a soft silhouette of ``mesh`` and return its alpha channel as a numpy array in [0, 1]."""
    shader = SoftSilhouetteShader(blend_params=BlendParams(sigma=1e-4, gamma=1e-4))
    rasterizer = MeshRasterizer(
        cameras=cameras,
        raster_settings=RasterizationSettings(
            image_size=image_size,
            blur_radius=1e-6,
            faces_per_pixel=50,
            cull_backfaces=True,
        ),
    )
    renderer = MeshRenderer(rasterizer=rasterizer, shader=shader)
    with torch.no_grad():
        rgba = renderer(mesh, cameras=cameras)  # (1, H, W, 4)
    alpha = rgba[0, ..., 3].detach().cpu().numpy()
    return np.clip(alpha, 0.0, 1.0)
|
||
|
||
|
||
def project_points_to_screen(cameras, points_world: torch.Tensor, image_size: int):
    """Project (N, 3) world points into the screen space of a square ``image_size`` image.

    Runs without gradients and returns the un-batched (N, 3) result of
    ``cameras.transform_points_screen``.
    """
    batched = points_world.unsqueeze(0)
    size = ((image_size, image_size),)
    with torch.no_grad():
        projected = cameras.transform_points_screen(batched, image_size=size)
    return projected[0]
|
||
|
||
|
||
def edge_visibility_split(edges_idx, verts_world, cameras, zbuf, image_size, eps=1e-4):
    """Classify feature edges as visible or hidden by depth-testing their midpoints.

    Each edge is projected to screen space. Its midpoint depth is compared with
    the rasterized nearest depth ``zbuf`` at the pixel containing the midpoint
    (clamped to the image), and that single sample labels the whole edge.

    Args:
        edges_idx: iterable of (i, j) vertex-index pairs.
        verts_world: (N, 3) vertex tensor in world space.
        cameras: camera that produced ``zbuf``.
        zbuf: (H, W) nearest-face depth map from the same camera.
        image_size: side length of the square image, in pixels.
        eps: depth tolerance; an edge is visible when its depth <= zbuf + eps.

    Returns:
        (visible, hidden): lists of screen-space segments ``(x0, y0, x1, y1)``.
    """
    H = W = image_size
    vis, hid = [], []
    # One host copy. Per-coordinate ``.item()`` calls on a CUDA tensor sync each time.
    scr = project_points_to_screen(cameras, verts_world, image_size).detach().cpu().numpy()
    zmin = zbuf.detach().cpu().numpy()
    for i, j in edges_idx:
        p0, p1 = scr[i], scr[j]
        m = 0.5 * (p0 + p1)
        x = int(np.clip(m[0], 0, W - 1))
        y = int(np.clip(m[1], 0, H - 1))
        z_proj = float(m[2])
        z_ref = float(zmin[y, x])
        seg = (float(p0[0]), float(p0[1]), float(p1[0]), float(p1[1]))
        # PyTorch3D pads pixels that no face covers with zbuf = -1, not inf.
        # Nothing occludes the edge there. This is common for silhouette
        # edges whose midpoint falls just outside the rasterized coverage, so
        # such edges count as visible instead of being compared against the -1
        # sentinel. Assumes covered pixels have non-negative depth (model in
        # front of the camera).
        uncovered = z_ref < 0
        if np.isfinite(z_ref) and (uncovered or z_proj <= z_ref + eps):
            vis.append(seg)
        else:
            hid.append(seg)
    return vis, hid
|
||
|
||
|
||
def trace_silhouettes(alpha: np.ndarray, threshold=0.5, step=1):
    """Trace iso-contours of ``alpha`` at ``threshold`` and return them as (x, y) polylines.

    Contours with fewer than 8 points are discarded. The remaining contours keep
    every ``step``-th point (``step`` is treated as at least 1). skimage reports
    points as (row, col), so the columns are swapped to give (x, y) = (col, row).
    """
    stride = max(1, step)
    return [
        np.column_stack((contour[::stride, 1], contour[::stride, 0]))
        for contour in measure.find_contours(alpha, level=threshold)
        if len(contour) >= 8
    ]
|
||
|
||
|
||
def svg_from_view(out_svg, W, H, silhouettes, edges_vis, edges_hid, stroke=2.0, margin=10, fill="none",
                  silhouette_stroke="black"):
    """Write one orthographic view to an SVG file.

    The canvas is (W + 2*margin) x (H + 2*margin) and every coordinate is
    shifted by ``margin``. Layers, from bottom to top:

    1. silhouette polygons as closed paths, filled with ``fill`` and outlined
       with ``silhouette_stroke`` at width ``stroke``;
    2. visible feature edges: solid black lines of width ``stroke``;
    3. hidden feature edges: dashed black lines of width max(1, 0.75*stroke)
       at opacity 0.85.

    Args:
        out_svg: output file path.
        W, H: drawing size in pixels (the render resolution).
        silhouettes: iterable of (K, 2) arrays of (x, y) points. Polygons with
            fewer than 3 points are skipped.
        edges_vis, edges_hid: iterables of (x0, y0, x1, y1) segments.
        stroke: base line width.
        margin: padding around the drawing.
        fill: silhouette fill colour ("none" draws the outline only).
        silhouette_stroke: silhouette outline colour. Pass "none" to suppress it.
    """
    total = (W + 2 * margin, H + 2 * margin)
    dwg = svgwrite.Drawing(out_svg, size=total)
    dwg.add(dwg.rect(insert=(0, 0), size=total, fill="none"))

    # Silhouette outlines. The traced contours carry the outline of curved
    # surfaces, which dihedral-angle feature edges cannot capture. They must be
    # stroked: with fill="none" and no stroke they would be invisible.
    for poly in silhouettes:
        if len(poly) < 3:
            continue
        pts = [f"{x + margin:.2f} {y + margin:.2f}" for x, y in poly]
        d = "M " + " L ".join(pts) + " Z"
        dwg.add(dwg.path(d, fill=fill, stroke=silhouette_stroke, stroke_width=stroke))

    # Visible feature edges: solid.
    for (x0, y0, x1, y1) in edges_vis:
        dwg.add(dwg.line(
            (x0 + margin, y0 + margin),
            (x1 + margin, y1 + margin),
            stroke="black",
            stroke_width=stroke
        ))

    # Hidden feature edges: thinner, dashed, slightly transparent.
    for (x0, y0, x1, y1) in edges_hid:
        dwg.add(dwg.line(
            (x0 + margin, y0 + margin),
            (x1 + margin, y1 + margin),
            stroke="black",
            stroke_width=max(1.0, 0.75 * stroke),
            stroke_dasharray=[6, 6],
            opacity=0.85
        ))

    dwg.save()
|
||
|
||
|
||
# ---------------- Cameras ----------------
|
||
|
||
def make_camera(view: str, device, dist=2.7):
    """Build an orthographic camera for one engineering view of a Z-up model.

    PyTorch3D's spherical ``elev``/``azim`` parameters assume Y-up: elev=90
    puts the eye on +Y and azim=180 puts it on -Z. That does not match the
    Z-up views this script documents, so each camera is placed with an
    explicit eye position and up vector:

    - top  : eye on +Z, looking along -Z, image up = +Y
    - left : eye on -X, looking along +X, image up = +Z
    - front: eye on +Y, looking along -Y, image up = +Z

    Args:
        view: one of "top", "left", "front".
        device: torch device for the camera tensors.
        dist: distance from the origin to the eye.

    Raises:
        ValueError: for an unknown view name.
    """
    # view name -> (unit direction from origin to eye, up vector)
    views = {
        "top": ((0.0, 0.0, 1.0), (0.0, 1.0, 0.0)),
        "left": ((-1.0, 0.0, 0.0), (0.0, 0.0, 1.0)),
        "front": ((0.0, 1.0, 0.0), (0.0, 0.0, 1.0)),
    }
    if view not in views:
        raise ValueError("view must be one of ['top','left','front']")
    direction, up = views[view]
    eye = tuple(dist * c for c in direction)
    R, T = look_at_view_transform(eye=(eye,), at=((0.0, 0.0, 0.0),), up=(up,), device=device)
    return OrthographicCameras(R=R, T=T, device=device)
|
||
|
||
|
||
# ---------------- Combine 3 SVGs ----------------
|
||
|
||
def _hstack_layout(sizes, gap=20.0, margin=10.0):
    """Return (width, height, x_offsets) for placing views side by side in equal-width columns."""
    max_w = max(w for w, _ in sizes)
    max_h = max(h for _, h in sizes)
    n = len(sizes)
    x_offsets = [margin + k * (max_w + gap) for k in range(n)]
    # Outer margin on both sides plus n columns and (n - 1) gaps, so the last
    # column ends exactly one margin before the right edge.
    width = 2 * margin + n * max_w + (n - 1) * gap
    height = 2 * margin + max_h
    return width, height, x_offsets


def combine_svgs(view1_path, view2_path, view3_path, output_svg, output_image):
    """Place 3 SVG files side by side in one SVG and also export it as a PNG.

    Views appear left to right in argument order. Each occupies a column as
    wide as the widest view, with a 10 px outer margin and 20 px gaps between
    columns. Parent directories of both outputs are created if missing.
    """
    for target in (output_svg, output_image):
        out_dir = os.path.dirname(target)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)

    figs = [st.fromfile(p) for p in (view1_path, view2_path, view3_path)]
    # get_size() returns width/height strings, possibly suffixed with "px".
    sizes = [tuple(float(s.replace("px", "")) for s in fig.get_size()) for fig in figs]

    margin = 10
    width, height, x_offsets = _hstack_layout(sizes, gap=20, margin=margin)

    combined = st.SVGFigure(Unit(width), Unit(height))
    roots = []
    for fig, x in zip(figs, x_offsets):
        root = fig.getroot()
        root.moveto(x, margin)
        roots.append(root)
    combined.append(roots)
    combined.save(output_svg)

    # SVG -> PNG
    cairosvg.svg2png(url=output_svg, write_to=output_image)
|
||
|
||
|
||
# ---------------- Main ----------------
|
||
|
||
def _build_arg_parser():
    """Declare the command-line interface of this script."""
    parser = argparse.ArgumentParser()
    # input / output locations
    parser.add_argument("--step_path", type=str, required=False, help="输入 STEP 文件", default="output/sample_clean.step")
    parser.add_argument("--out_dir", type=str, required=False, help="输出目录", default="output")
    # meshing / rendering knobs
    parser.add_argument("--res", type=int, default=1400, help="渲染分辨率(正方形)")
    parser.add_argument("--defl", type=float, default=1.5, help="三角化线性偏差")
    parser.add_argument("--ang", type=float, default=65.0, help="二面角阈值(°)")
    parser.add_argument("--stroke", type=float, default=1.3, help="SVG 线宽(px)")
    parser.add_argument("--scale", type=float, default=1.0, help="归一化后模型最大边长")
    # combination step
    parser.add_argument("--no_combine", action="store_true", help="只导出三视图 SVG,不合成")
    parser.add_argument("--combined_svg", type=str, default=None, help="合成 SVG 输出路径(默认 out_dir/combined_views.svg)")
    parser.add_argument("--combined_png", type=str, default=None, help="合成 PNG 输出路径(默认 out_dir/combined_views.png)")
    return parser


def _export_view(name, mesh, verts, feat_edges, device, args):
    """Render one named view, split its feature edges by visibility, and write ``<out_dir>/<name>.svg``."""
    print(f"[View] {name}")
    cam = make_camera(name, device)
    alpha = render_silhouette(mesh, cam, image_size=args.res)
    _, zbuf = raster_fragments(mesh, cam, image_size=args.res, faces_per_pixel=1)
    e_vis, e_hid = edge_visibility_split(feat_edges, verts, cam, zbuf, args.res, eps=1e-4)
    polys = trace_silhouettes(alpha, threshold=0.5, step=1)
    out_svg = os.path.join(args.out_dir, f"{name}.svg")
    svg_from_view(out_svg, args.res, args.res, polys, e_vis, e_hid,
                  stroke=args.stroke, margin=10, fill="none")
    print(f" -> {out_svg} (edges: vis={len(e_vis)}, hid={len(e_hid)}, polys={len(polys)})")
    return out_svg


def main():
    """Entry point: STEP -> normalized mesh -> 3 view SVGs -> optional combined SVG/PNG."""
    args = _build_arg_parser().parse_args()
    os.makedirs(args.out_dir, exist_ok=True)

    device = autodevice()
    print(f"[Device] {device}")

    # STEP -> triangle mesh normalized into the camera's view volume
    shape = read_step_solid(args.step_path)
    V_np, F_np = triangulate(shape, deflection=args.defl)
    mesh = build_mesh(normalize_mesh_np(V_np, unit_scale=args.scale), F_np, device)
    verts, faces = mesh.verts_packed(), mesh.faces_packed()
    feat_edges = extract_feature_edges(faces, compute_face_normals(verts, faces), angle_deg=args.ang)

    view_paths = {}
    for name in ("top", "left", "front"):
        view_paths[name] = _export_view(name, mesh, verts, feat_edges, device, args)

    if args.no_combine:
        print("[Done] 3 views exported (no combine).")
        return

    combined_svg = args.combined_svg or os.path.join(args.out_dir, "combined_views.svg")
    combined_png = args.combined_png or os.path.join(args.out_dir, "combined_views.png")

    # Views are laid out left to right in the order front / top / left.
    combine_svgs(
        view_paths["front"],
        view_paths["top"],
        view_paths["left"],
        combined_svg,
        combined_png,
    )
    print(f"[Combined] SVG: {combined_svg}")
    print(f"[Combined] PNG: {combined_png}")
    print("[Done] All outputs exported.")


if __name__ == "__main__":
    main()
|