387 lines
13 KiB
Python
387 lines
13 KiB
Python
|
|
#!/usr/bin/env python3
|
|||
|
|
# -*- coding: utf-8 -*-
|
|||
|
|
|
|||
|
|
"""
STEP -> OCC triangulate -> PyTorch3D orthographic 3 views -> SVG (silhouette + visible/hidden feature edges)
Then combine the 3 SVGs into a single SVG and a PNG.

Views:
- top  : from +Z towards origin (view direction -Z)
- left : from -X towards origin (view direction +X)
- front: from +Y towards origin (view direction -Y)

Outputs (default in --out_dir):
- top.svg, left.svg, front.svg
- combined_views.svg
- combined_views.png
"""
|
|||
|
|
|
|||
|
|
import os, sys, math, argparse, tempfile
from collections import defaultdict

import numpy as np
import torch

# ---------- PyTorch3D ----------
from pytorch3d.structures import Meshes
from pytorch3d.renderer import (
    RasterizationSettings, MeshRasterizer,
    MeshRenderer, BlendParams, SoftSilhouetteShader,
    look_at_view_transform, OrthographicCameras
)
from pytorch3d.renderer.mesh.rasterizer import Fragments

# ---------- SVG / Image ----------
import svgwrite
from skimage import measure

# ---------- pythonocc (OCC) ----------
from OCC.Core.STEPControl import STEPControl_Reader
from OCC.Core.IFSelect import IFSelect_RetDone
from OCC.Core.TopAbs import TopAbs_FACE, TopAbs_REVERSED
from OCC.Core.TopExp import TopExp_Explorer
from OCC.Core.BRep import BRep_Tool
from OCC.Core.BRepMesh import BRepMesh_IncrementalMesh
from OCC.Core.TopLoc import TopLoc_Location
from OCC.Core.ShapeFix import ShapeFix_Shape

# ---------- Combine SVG ----------
import svgutils.transform as st
from svgutils.compose import Unit
import cairosvg
|
|||
|
|
|
|||
|
|
|
|||
|
|
# ---------------- Utils ----------------
|
|||
|
|
|
|||
|
|
def autodevice():
    """Return the first CUDA device when one is available, otherwise the CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda:0")
    return torch.device("cpu")
|
|||
|
|
|
|||
|
|
|
|||
|
|
def read_step_solid(path: str):
    """Read a STEP file and return its healed top-level shape.

    Args:
        path: Path to the STEP file.

    Returns:
        The shape returned by ``STEPControl_Reader.OneShape()`` after a
        ``ShapeFix_Shape`` healing pass.

    Raises:
        RuntimeError: if the file cannot be read, or if the transfer produced
            an empty (null) shape.
    """
    reader = STEPControl_Reader()
    # Compare against the named return status instead of the magic number 1.
    if reader.ReadFile(path) != IFSelect_RetDone:
        raise RuntimeError(f"STEP 读取失败: {path}")

    reader.TransferRoots()
    shape = reader.OneShape()
    # A file can parse successfully yet contain no transferable geometry;
    # fail here with a clear message rather than deep inside meshing.
    if shape is None or shape.IsNull():
        raise RuntimeError(f"STEP file contains no transferable shape: {path}")

    fixer = ShapeFix_Shape(shape)
    fixer.Perform()
    return fixer.Shape()
|
|||
|
|
|
|||
|
|
|
|||
|
|
def triangulate(shape, deflection=0.5, angle=0.5):
    """Triangulate an OCC shape into a vertex/face soup.

    Args:
        shape: OCC shape to mesh. ``BRepMesh_IncrementalMesh`` stores the
            triangulation on the shape's faces.
        deflection: Linear deflection for ``BRepMesh_IncrementalMesh``. It is
            passed with ``isRelative=False``, so it is absolute.
        angle: Angular deflection for ``BRepMesh_IncrementalMesh``.

    Returns:
        (V, F): ``V`` is an ``np.float32`` array of shape (N, 3). ``F`` is an
        ``np.int64`` array of shape (M, 3) holding 0-based indices into ``V``.
        Vertices are not shared between B-rep faces; each face appends its
        own copy of its nodes.

    Raises:
        RuntimeError: if no vertices or triangles were produced.
    """
    BRepMesh_IncrementalMesh(shape, deflection, False, angle, True)

    verts, faces, v_off = [], [], 0
    exp = TopExp_Explorer(shape, TopAbs_FACE)
    while exp.More():
        face = exp.Current()
        loc = TopLoc_Location()
        tri = BRep_Tool.Triangulation(face, loc)
        if tri is not None:
            trsf = loc.Transformation()
            # Depending on the OCC version, the triangulation exposes either
            # whole-array accessors (Nodes()/Triangles()) or only the
            # per-index ones (Node(i)/Triangle(i)).
            nodes = tri.Nodes() if hasattr(tri, "Nodes") else None
            tris = tri.Triangles() if hasattr(tri, "Triangles") else None

            nb_nodes = tri.NbNodes()
            for i in range(1, nb_nodes + 1):
                p = nodes.Value(i) if nodes is not None else tri.Node(i)
                p = p.Transformed(trsf)
                verts.append([p.X(), p.Y(), p.Z()])

            # Triangle winding follows the underlying surface. A REVERSED face
            # must be flipped so its normals point outward. Otherwise its
            # triangles are removed by back-face culling and get inward normals.
            flip = face.Orientation() == TopAbs_REVERSED
            for i in range(1, tri.NbTriangles() + 1):
                t = tris.Value(i) if tris is not None else tri.Triangle(i)
                a, b, c = t.Get()
                if flip:
                    b, c = c, b
                # OCC node indices are 1-based.
                faces.append([v_off + a - 1, v_off + b - 1, v_off + c - 1])
            v_off += nb_nodes
        exp.Next()

    if not verts or not faces:
        raise RuntimeError("三角化为空,尝试减小 --defl")
    return np.asarray(verts, np.float32), np.asarray(faces, np.int64)
|
|||
|
|
|
|||
|
|
|
|||
|
|
def normalize_mesh_np(verts: np.ndarray, unit_scale=1.0):
    """Center vertices on their bounding-box center and scale them uniformly.

    After normalization, the largest absolute coordinate equals ``unit_scale``
    (for non-degenerate input). The bounding-box center is used rather than
    the vertex mean. The mean is pulled toward densely tessellated (curved)
    regions, which leaves the model off-center and smaller than necessary.

    Args:
        verts: (N, 3) vertex array. The dtype is preserved for float input.
        unit_scale: Target value of the largest absolute coordinate.

    Returns:
        The normalized (N, 3) array. A single-point input collapses to zeros.
    """
    lo = verts.min(axis=0, keepdims=True)
    hi = verts.max(axis=0, keepdims=True)
    v0 = verts - 0.5 * (lo + hi)
    # Guard against division by zero for degenerate (single-point) input.
    s = max(float(np.max(np.abs(v0))), 1e-12)
    return v0 / s * unit_scale
|
|||
|
|
|
|||
|
|
|
|||
|
|
def build_mesh(V_np, F_np, device):
    """Wrap NumPy vertex/face arrays in a single-mesh PyTorch3D ``Meshes`` on ``device``."""
    verts_t, faces_t = (torch.from_numpy(arr).to(device) for arr in (V_np, F_np))
    return Meshes(verts=[verts_t], faces=[faces_t])
|
|||
|
|
|
|||
|
|
|
|||
|
|
def compute_face_normals(verts: torch.Tensor, faces: torch.Tensor):
    """Return one unit normal per face, using the right-hand rule on (v0, v1, v2).

    Degenerate (zero-area) faces produce a zero vector instead of NaN, because
    ``normalize`` divides by ``max(norm, eps)``.
    """
    corner_a, corner_b, corner_c = (verts[faces[:, k]] for k in range(3))
    raw = torch.cross(corner_b - corner_a, corner_c - corner_a, dim=1)
    return torch.nn.functional.normalize(raw, dim=1, eps=1e-12)
|
|||
|
|
|
|||
|
|
|
|||
|
|
def extract_feature_edges(faces: torch.Tensor, face_normals: torch.Tensor, angle_deg: float):
    """Return mesh edges that are boundaries or sharp creases.

    An edge shared by exactly two faces is a feature edge when the angle
    between the two face normals is >= ``angle_deg``. An edge used by exactly
    one face (a boundary edge) is always included. Edges shared by more than
    two faces are ignored.

    NOTE(review): if the mesh does not share vertices between B-rep faces,
    every face border is a boundary edge. In that case ``angle_deg`` only
    affects edges inside a single face.

    Args:
        faces: (M, 3) integer tensor of vertex indices.
        face_normals: (M, 3) tensor of unit face normals.
        angle_deg: Dihedral threshold in degrees; negative values act as 0.

    Returns:
        List of ``(a, b)`` vertex-index tuples with ``a < b``, in the order the
        edges are first encountered when walking ``faces``.
    """
    thr = math.cos(math.radians(max(0.0, angle_deg)))

    # Map each undirected edge to the faces that use it. Transfer to host once:
    # per-face .tolist() on a CUDA tensor syncs the device on every call.
    e2f = defaultdict(list)
    for f, (i, j, k) in enumerate(faces.detach().cpu().tolist()):
        for a, b in ((i, j), (j, k), (k, i)):
            e2f[(a, b) if a < b else (b, a)].append(f)

    # Evaluate all two-face dihedral tests in a single vectorized pass.
    manifold = [(e, fl) for e, fl in e2f.items() if len(fl) == 2]
    sharp = set()
    if manifold:
        pairs = torch.tensor([fl for _, fl in manifold], dtype=torch.long)
        normals = face_normals.detach().cpu()
        cosv = (normals[pairs[:, 0]] * normals[pairs[:, 1]]).sum(dim=1).clamp(-1.0, 1.0)
        sharp = {e for (e, _), c in zip(manifold, cosv.tolist()) if c <= thr}

    return [e for e, fl in e2f.items() if len(fl) == 1 or e in sharp]
|
|||
|
|
|
|||
|
|
|
|||
|
|
def raster_fragments(mesh: Meshes, cameras, image_size=1600, faces_per_pixel=1):
    """Hard-rasterize ``mesh`` with back-face culling.

    Returns:
        (fragments, depth): the raw PyTorch3D ``Fragments``, and ``zbuf`` for
        the first batch element and the nearest face slot (a 2-D per-pixel
        depth map).
    """
    settings = RasterizationSettings(
        image_size=image_size,
        blur_radius=0.0,
        faces_per_pixel=faces_per_pixel,
        cull_backfaces=True,
    )
    rasterizer = MeshRasterizer(cameras=cameras, raster_settings=settings)
    with torch.no_grad():
        fragments: Fragments = rasterizer(mesh, cameras=cameras)
    depth = fragments.zbuf[0, ..., 0]
    return fragments, depth
|
|||
|
|
|
|||
|
|
|
|||
|
|
def render_silhouette(mesh: Meshes, cameras, image_size=1600):
    """Render the silhouette coverage of ``mesh`` as a NumPy array clipped to [0, 1].

    Uses ``SoftSilhouetteShader`` with very small sigma/gamma and blur radius,
    presumably to get a near-binary mask. Back faces are culled.
    NOTE(review): faces_per_pixel=50 at large ``image_size`` allocates large
    fragment buffers; confirm it is needed with such a tiny blur radius.
    """
    raster_cfg = RasterizationSettings(
        image_size=image_size,
        blur_radius=1e-6,
        faces_per_pixel=50,
        cull_backfaces=True,
    )
    blend = BlendParams(sigma=1e-4, gamma=1e-4)
    silhouette_renderer = MeshRenderer(
        rasterizer=MeshRasterizer(cameras=cameras, raster_settings=raster_cfg),
        shader=SoftSilhouetteShader(blend_params=blend),
    )
    with torch.no_grad():
        rgba = silhouette_renderer(mesh, cameras=cameras)  # (1, H, W, 4)
    alpha = rgba[0, ..., 3].detach().cpu().numpy()
    return np.clip(alpha, 0.0, 1.0)
|
|||
|
|
|
|||
|
|
|
|||
|
|
def project_points_to_screen(cameras, points_world: torch.Tensor, image_size: int):
    """Project world-space points into the pixel space of a square image.

    Returns:
        (N, 3) tensor from ``cameras.transform_points_screen``: screen x, screen y,
        and the camera's projected z, for each input point.
    """
    square = ((image_size, image_size),)
    with torch.no_grad():
        projected = cameras.transform_points_screen(
            points_world.unsqueeze(0),
            image_size=square,
        )
    return projected[0]
|
|||
|
|
|
|||
|
|
|
|||
|
|
def edge_visibility_split(edges_idx, verts_world, cameras, zbuf, image_size, eps=1e-4):
    """Split feature edges into visible and hidden sets by depth-testing their midpoints.

    Each edge ``(i, j)`` is projected to screen space. Its midpoint depth is
    compared with ``zbuf`` at the pixel containing the midpoint, with the pixel
    clamped to the image. The edge is visible when nothing lies in front of it
    there: either the pixel is background, or ``z_mid <= zbuf + eps``.

    Args:
        edges_idx: Iterable of ``(i, j)`` vertex-index pairs.
        verts_world: (N, 3) tensor of vertex positions.
        cameras: Camera that produced ``zbuf``.
        zbuf: 2-D depth map of the nearest surface. PyTorch3D writes -1 to
            pixels that no face covers.
        image_size: Side length of the square image, in pixels.
        eps: Depth tolerance for the visibility test.

    Returns:
        (vis, hid): lists of ``(x0, y0, x1, y1)`` screen-space segments (floats).
    """
    H = W = image_size
    edges = np.asarray(list(edges_idx), dtype=np.int64).reshape(-1, 2)
    if edges.shape[0] == 0:
        return [], []

    # Transfer once, instead of calling .item() per edge, which forces a
    # device sync for every value on CUDA.
    scr = project_points_to_screen(cameras, verts_world, image_size).detach().cpu().numpy()
    zmin = zbuf.detach().cpu().numpy()

    p0 = scr[edges[:, 0]]
    p1 = scr[edges[:, 1]]
    mid = 0.5 * (p0 + p1)
    xs = np.clip(mid[:, 0], 0, W - 1).astype(np.int64)
    ys = np.clip(mid[:, 1], 0, H - 1).astype(np.int64)
    z_ref = zmin[ys, xs]

    # Background pixels (zbuf == -1) hold no occluder, so an edge whose
    # midpoint falls there is visible. Comparing against -1 as if it were a
    # real depth would wrongly mark such edges (typically on the outline) hidden.
    background = z_ref < 0
    in_front = np.isfinite(z_ref) & (mid[:, 2] <= z_ref + eps)
    visible = background | in_front

    segs = np.concatenate([p0[:, :2], p1[:, :2]], axis=1).tolist()
    vis = [tuple(s) for s, v in zip(segs, visible) if v]
    hid = [tuple(s) for s, v in zip(segs, visible) if not v]
    return vis, hid
|
|||
|
|
|
|||
|
|
|
|||
|
|
def trace_silhouettes(alpha: np.ndarray, threshold=0.5, step=1):
    """Trace iso-contours of ``alpha`` at ``threshold`` and return them as (x, y) polylines.

    Contours with fewer than 8 points are dropped. The remaining contours keep
    every ``step``-th point (at least every point). ``find_contours`` returns
    (row, col) pairs, so the columns are swapped to give (x, y).
    """
    stride = max(1, step)
    return [
        np.stack([c[::stride][:, 1], c[::stride][:, 0]], axis=1)  # (x,y)
        for c in measure.find_contours(alpha, level=threshold)
        if len(c) >= 8
    ]
|
|||
|
|
|
|||
|
|
|
|||
|
|
def svg_from_view(out_svg, W, H, silhouettes, edges_vis, edges_hid, stroke=2.0, margin=10, fill="none",
                  outline=True):
    """Write one orthographic view to an SVG file.

    Args:
        out_svg: Output path for ``svgwrite.Drawing``.
        W, H: Drawing area in pixels. The canvas adds ``margin`` on every side.
        silhouettes: Polylines, each a (K, 2) array of (x, y) pixel coordinates.
            Each is emitted as a closed path.
        edges_vis: Visible segments as ``(x0, y0, x1, y1)``; drawn solid.
        edges_hid: Hidden segments as ``(x0, y0, x1, y1)``; drawn dashed.
        stroke: Line width for visible edges and silhouette outlines.
        margin: Offset added to every coordinate.
        fill: Fill colour for silhouette regions ("none" disables filling).
        outline: Whether to stroke the silhouette contours.
    """
    dwg = svgwrite.Drawing(out_svg, size=(W + 2 * margin, H + 2 * margin))
    dwg.add(dwg.rect(insert=(0, 0), size=(W + 2 * margin, H + 2 * margin), fill="none"))

    # Silhouettes: optional fill plus outline. With stroke="none" and the
    # default fill="none", the traced silhouettes never rendered. Curved
    # profiles (e.g. the side of a cylinder), which have no feature edge,
    # were then missing from the drawing.
    sil_stroke = "black" if outline else "none"
    for poly in silhouettes:
        if len(poly) < 3:
            continue
        cmds = [f"M {poly[0, 0] + margin:.2f} {poly[0, 1] + margin:.2f}"]
        cmds.extend(f"L {x + margin:.2f} {y + margin:.2f}" for x, y in poly[1:])
        cmds.append("Z")
        dwg.add(dwg.path(" ".join(cmds), fill=fill, stroke=sil_stroke, stroke_width=stroke))

    # Visible edges: solid.
    for (x0, y0, x1, y1) in edges_vis:
        dwg.add(dwg.line(
            (x0 + margin, y0 + margin),
            (x1 + margin, y1 + margin),
            stroke="black",
            stroke_width=stroke,
        ))

    # Hidden edges: thinner, dashed and slightly transparent.
    for (x0, y0, x1, y1) in edges_hid:
        dwg.add(dwg.line(
            (x0 + margin, y0 + margin),
            (x1 + margin, y1 + margin),
            stroke="black",
            stroke_width=max(1.0, 0.75 * stroke),
            stroke_dasharray=[6, 6],
            opacity=0.85,
        ))

    dwg.save()
|
|||
|
|
|
|||
|
|
|
|||
|
|
# ---------------- Cameras ----------------
|
|||
|
|
|
|||
|
|
def make_camera(view: str, device, dist=2.7):
    """Build an orthographic camera for one engineering view of a Z-up model.

    Views (all look at the origin from distance ``dist``):
      - "top":   eye at +Z (view direction -Z); world +Y is up on screen.
      - "left":  eye at -X (view direction +X); world +Z is up on screen.
      - "front": eye at +Y (view direction -Y); world +Z is up on screen.

    Args:
        view: One of "top", "left", "front".
        device: Torch device for the camera tensors.
        dist: Eye distance from the origin.

    Raises:
        ValueError: for an unknown ``view``.
    """
    # Explicit eye/up vectors are used instead of elev/azim. look_at_view_transform's
    # spherical angles assume a Y-up world: elev=90 puts the eye at +Y, and
    # azim=180 with elev=0 puts it at -Z. That contradicts the +Z / +Y views
    # documented above.
    eye_and_up = {
        "top": ((0.0, 0.0, dist), (0.0, 1.0, 0.0)),
        "left": ((-dist, 0.0, 0.0), (0.0, 0.0, 1.0)),
        "front": ((0.0, dist, 0.0), (0.0, 0.0, 1.0)),
    }
    if view not in eye_and_up:
        raise ValueError("view must be one of ['top','left','front']")
    eye, up = eye_and_up[view]
    R, T = look_at_view_transform(eye=(eye,), at=((0.0, 0.0, 0.0),), up=(up,), device=device)
    return OrthographicCameras(R=R, T=T, device=device)
|
|||
|
|
|
|||
|
|
|
|||
|
|
# ---------------- Combine 3 SVGs ----------------
|
|||
|
|
|
|||
|
|
def combine_svgs(view1_path, view2_path, view3_path, output_svg, output_image):
    """Place three SVG files side by side in one SVG, and also export it as PNG.

    Views are laid out left to right in the argument order. Each view gets a
    column as wide as the widest view, columns are separated by ``gap``
    pixels, and ``pad`` pixels of padding surround the whole sheet.

    Args:
        view1_path, view2_path, view3_path: Input SVG files.
        output_svg: Path of the combined SVG.
        output_image: Path of the PNG rendered from the combined SVG.
    """
    pad, gap = 10, 20

    for target in (output_svg, output_image):
        out_dir = os.path.dirname(target)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)

    figures = [st.fromfile(p) for p in (view1_path, view2_path, view3_path)]
    sizes = [[float(x.replace("px", "")) for x in fig.get_size()] for fig in figures]
    max_w = max(w for w, _ in sizes)
    max_h = max(h for _, h in sizes)

    # Left pad + n columns + (n - 1) gaps + right pad. The former width
    # (3 * max_w + 40) was 20 px short of this. The third view was clipped by
    # 10 px and the sheet had no right padding.
    n = len(figures)
    combined_width = 2 * pad + n * max_w + (n - 1) * gap
    combined_height = 2 * pad + max_h

    combined = st.SVGFigure(Unit(combined_width), Unit(combined_height))
    roots = []
    for col, fig in enumerate(figures):
        root = fig.getroot()
        root.moveto(pad + col * (max_w + gap), pad)
        roots.append(root)
    combined.append(roots)
    combined.save(output_svg)

    # SVG -> PNG
    cairosvg.svg2png(url=output_svg, write_to=output_image)
|
|||
|
|
|
|||
|
|
|
|||
|
|
# ---------------- Main ----------------
|
|||
|
|
|
|||
|
|
def main():
    """Command-line entry point.

    Reads a STEP file, triangulates and normalizes it, and extracts feature
    edges. It then renders the "top", "left" and "front" views to
    ``<out_dir>/<view>.svg``. Unless ``--no_combine`` is given, the views are
    also merged (front / top / left, left to right) into one SVG and one PNG.
    """
    parser = argparse.ArgumentParser()
    # Input / output locations.
    parser.add_argument("--step_path", type=str, required=False, help="输入 STEP 文件", default="output/sample_clean.step")
    parser.add_argument("--out_dir", type=str, required=False, help="输出目录", default="output")
    # Meshing / rendering parameters.
    parser.add_argument("--res", type=int, default=1400, help="渲染分辨率(正方形)")
    parser.add_argument("--defl", type=float, default=1.5, help="三角化线性偏差")
    parser.add_argument("--ang", type=float, default=65.0, help="二面角阈值(°)")
    parser.add_argument("--stroke", type=float, default=1.3, help="SVG 线宽(px)")
    parser.add_argument("--scale", type=float, default=1.0, help="归一化后模型最大边长")
    # Combination step.
    parser.add_argument("--no_combine", action="store_true", help="只导出三视图 SVG,不合成")
    parser.add_argument("--combined_svg", type=str, default=None, help="合成 SVG 输出路径(默认 out_dir/combined_views.svg)")
    parser.add_argument("--combined_png", type=str, default=None, help="合成 PNG 输出路径(默认 out_dir/combined_views.png)")
    args = parser.parse_args()

    os.makedirs(args.out_dir, exist_ok=True)
    device = autodevice()
    print(f"[Device] {device}")

    # STEP -> normalized triangle mesh on `device`.
    V_np, F_np = triangulate(read_step_solid(args.step_path), deflection=args.defl)
    mesh = build_mesh(normalize_mesh_np(V_np, unit_scale=args.scale), F_np, device)

    verts, faces = mesh.verts_packed(), mesh.faces_packed()
    feat_edges = extract_feature_edges(faces, compute_face_normals(verts, faces), angle_deg=args.ang)

    view_paths = {}
    for name in ("top", "left", "front"):
        print(f"[View] {name}")
        cam = make_camera(name, device)
        alpha = render_silhouette(mesh, cam, image_size=args.res)
        zbuf = raster_fragments(mesh, cam, image_size=args.res, faces_per_pixel=1)[1]

        e_vis, e_hid = edge_visibility_split(feat_edges, verts, cam, zbuf, args.res, eps=1e-4)
        polys = trace_silhouettes(alpha, threshold=0.5, step=1)

        out_svg = os.path.join(args.out_dir, f"{name}.svg")
        svg_from_view(out_svg, args.res, args.res, polys, e_vis, e_hid,
                      stroke=args.stroke, margin=10, fill="none")
        view_paths[name] = out_svg
        print(f" -> {out_svg} (edges: vis={len(e_vis)}, hid={len(e_hid)}, polys={len(polys)})")

    if args.no_combine:
        print("[Done] 3 views exported (no combine).")
        return

    combined_svg = args.combined_svg or os.path.join(args.out_dir, "combined_views.svg")
    combined_png = args.combined_png or os.path.join(args.out_dir, "combined_views.png")

    # Combination order, left to right: front / top / left.
    combine_svgs(
        view1_path=view_paths["front"],
        view2_path=view_paths["top"],
        view3_path=view_paths["left"],
        output_svg=combined_svg,
        output_image=combined_png,
    )
    print(f"[Combined] SVG: {combined_svg}")
    print(f"[Combined] PNG: {combined_png}")
    print("[Done] All outputs exported.")
|
|||
|
|
|
|||
|
|
|
|||
|
|
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()
|