Files
FiDA-3D-Trellis/glb2svg.py
2026-04-13 11:21:23 +08:00

556 lines
18 KiB
Python
Executable File
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import math
import secrets
from collections import defaultdict
from datetime import datetime
import os
import subprocess
from pathlib import Path
import logging
from typing import Optional
import imageio
import numpy as np
import torch
from PIL import Image
# ---------- PyTorch3D ----------
from pytorch3d.structures import Meshes
from pytorch3d.renderer import (
RasterizationSettings, MeshRasterizer,
MeshRenderer, BlendParams, SoftSilhouetteShader,
look_at_view_transform, OrthographicCameras
)
from pytorch3d.renderer.mesh.rasterizer import Fragments
# ---------- SVG / Image ----------
import svgwrite
from skimage import measure
# ---------- pythonocc (OCC) ----------
from OCC.Core.STEPControl import STEPControl_Reader
from OCC.Core.TopAbs import TopAbs_FACE
from OCC.Core.TopExp import TopExp_Explorer
from OCC.Core.BRep import BRep_Tool
from OCC.Core.BRepMesh import BRepMesh_IncrementalMesh
from OCC.Core.TopLoc import TopLoc_Location
from OCC.Core.ShapeFix import ShapeFix_Shape
# ---------- Combine SVG ----------
import svgutils.transform as st
from svgutils.compose import Unit
import cairosvg
from trellis.pipelines import TrellisImageTo3DPipeline
from trellis.utils import render_utils, postprocessing_utils
logger = logging.getLogger(__name__)
"""single image to 3D"""
def generate_unique_name(original_name: str) -> str:
    """Build a collision-resistant filename: <stem>_<YYYYmmdd_HHMMSS>_<8 hex chars><ext>."""
    base, suffix = os.path.splitext(original_name)
    now_tag = datetime.now().strftime("%Y%m%d_%H%M%S")
    nonce = secrets.token_hex(4)  # 4 random bytes -> 8 hex characters
    return "".join([base, "_", now_tag, "_", nonce, suffix])
def image_to_3d(image: list[str], out_dir: str = "trellis_out", single_image: bool = False, seed: int = 1,
                steps_sparse: int = 12, cfg_sparse: float = 7.5, steps_slat: int = 12,
                cfg_slat: float = 3.0, simplify: float = 0.95, texture_size: int = 1024,
                save_video: bool = False, fps: int = 30, video_gs_name: str = "sample_gs.mp4", video_rf_name: str = "sample_rf.mp4", video_mesh_name: str = "sample_mesh.mp4"):
    """Run the TRELLIS image-to-3D pipeline and export GLB + PLY assets.

    Args:
        image: list of image file paths. When single_image is True only the
            first path is used; otherwise all paths feed run_multi_image.
        out_dir: output directory (created if missing).
        single_image: selects pipeline.run vs pipeline.run_multi_image.
        seed: sampler seed.
        steps_sparse / cfg_sparse: sparse-structure sampler steps / CFG strength.
        steps_slat / cfg_slat: structured-latent sampler steps / CFG strength.
        simplify: mesh simplification ratio passed to to_glb.
        texture_size: baked texture resolution for the GLB.
        save_video: when True, also render preview videos of the gaussian,
            radiance-field and mesh outputs into out_dir.
        fps, video_*_name: preview video frame rate and file names.

    Returns:
        (glb_path, ply_path): paths of the exported files (unique names).
    """
    os.makedirs(out_dir, exist_ok=True)
    # Optional env: sparse-conv backend selection (only set when unset).
    os.environ.setdefault("SPCONV_ALGO", "native")
    pipeline = TrellisImageTo3DPipeline.from_pretrained("microsoft/TRELLIS-image-large")
    pipeline.cuda()
    if single_image:
        # NOTE(review): `image` is rebound from a list of paths to a PIL image here.
        image = Image.open(image[0])
        outputs = pipeline.run(
            image,
            seed=seed,
            sparse_structure_sampler_params={
                "steps": steps_sparse,
                "cfg_strength": cfg_sparse,
            },
            slat_sampler_params={
                "steps": steps_slat,
                "cfg_strength": cfg_slat,
            },
        )
    else:
        images = [Image.open(p) for p in image]
        outputs = pipeline.run_multi_image(
            images,
            seed=seed,
            sparse_structure_sampler_params={
                "steps": steps_sparse,
                "cfg_strength": cfg_sparse,
            },
            slat_sampler_params={
                "steps": steps_slat,
                "cfg_strength": cfg_slat,
            },
        )
    if save_video:
        # Preview renders: gaussian color, radiance-field color, mesh normals.
        video = render_utils.render_video(outputs["gaussian"][0])["color"]
        imageio.mimsave(os.path.join(out_dir, video_gs_name), video, fps=fps)
        video = render_utils.render_video(outputs["radiance_field"][0])["color"]
        imageio.mimsave(os.path.join(out_dir, video_rf_name), video, fps=fps)
        video = render_utils.render_video(outputs["mesh"][0])["normal"]
        imageio.mimsave(os.path.join(out_dir, video_mesh_name), video, fps=fps)
    glb_path = os.path.join(out_dir, generate_unique_name("sample.glb"))
    ply_path = os.path.join(out_dir, generate_unique_name("sample.ply"))
    glb = postprocessing_utils.to_glb(
        outputs["gaussian"][0],
        outputs["mesh"][0],
        simplify=simplify,
        texture_size=texture_size,
    )
    # Export the textured GLB and the raw gaussian point cloud.
    glb.export(glb_path)
    outputs["gaussian"][0].save_ply(ply_path)
    return glb_path, ply_path
"""
glb_to_obj
"""
def glb_to_obj(glb_input, obj_output_dir="obj_output"):
    """Convert a GLB file to OBJ by driving Blender headlessly.

    Runs `blender -b -P 0_glb_to_obj.py -- --input <glb> --output <obj>` and
    raises RuntimeError (with captured stdout/stderr) on a non-zero exit.
    Returns the path of the generated OBJ file.
    """
    os.makedirs(obj_output_dir, exist_ok=True)
    obj_path = os.path.join(obj_output_dir, generate_unique_name("sample.obj"))
    blender_cmd = [
        "blender", "-b",
        "-P", "0_glb_to_obj.py",
        "--",
        "--input", glb_input,
        "--output", obj_path,
    ]
    result = subprocess.run(blender_cmd, capture_output=True, text=True)
    print(result)
    print("\n")
    if result.returncode != 0:
        raise RuntimeError(
            f"Blender failed\nstdout:\n{result.stdout}\nstderr:\n{result.stderr}"
        )
    return obj_path
"""
obj_to_step
"""
def obj_to_step(
input_obj: str | Path,
output_dir: str | Path,
script_path: str | Path = "1_obj_to_step.py",
timeout: int = 120, # 秒,建议根据实际转换耗时调整
freecadcmd_path: Optional[str] = None
) -> bool:
"""
把 OBJ 转 STEP 的 FreeCAD 命令封装成函数
返回 True 表示成功False 表示失败
"""
input_obj = Path(input_obj).resolve()
output_dir = Path(output_dir).resolve()
script_path = Path(script_path).resolve()
print(f"input_obj : {input_obj}")
print(f"output_dir : {output_dir}")
print(f"script_path : {script_path}")
if not input_obj.exists():
raise FileNotFoundError(f"输入文件不存在: {input_obj}")
if not script_path.exists():
raise FileNotFoundError(f"转换脚本不存在: {script_path}")
# 构造 -c 参数(完全复刻你原来的命令)
python_code = f'''import sys
sys.argv = ["{script_path.name}", "{input_obj}", "{output_dir}"]
exec(open("{script_path}", "r", encoding="utf-8").read())'''
cmd = [
freecadcmd_path or "freecadcmd", # 如果不在 PATH 中,可传入绝对路径
"-c",
python_code
]
try:
result = subprocess.run(
cmd,
capture_output=True,
text=True,
check=True, # 自动抛出 CalledProcessError
timeout=timeout,
cwd=script_path.parent # 让脚本能找到相对路径文件
)
logger.info(f"FreeCAD 转换成功: {input_obj.name}{output_dir}")
print(f"FreeCAD 转换成功: {input_obj.name}{output_dir}")
if result.stdout:
logger.debug(result.stdout)
return result.stdout.split()[-1]
except subprocess.TimeoutExpired:
logger.error(f"FreeCAD 转换超时(>{timeout}s: {input_obj}")
return False
except subprocess.CalledProcessError as e:
logger.error(f"FreeCAD 执行失败: {e.stderr}")
return False
except Exception as e:
logger.exception("未知错误")
return False
"""
step to svg
"""
def autodevice():
    """Pick cuda:0 when CUDA is available, otherwise fall back to the CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda:0")
    return torch.device("cpu")
def read_step_solid(path: str):
    """Load a STEP file and return a healed TopoDS shape.

    Raises RuntimeError when the OCC reader reports anything but success
    (status 1, i.e. IFSelect_RetDone).
    """
    reader = STEPControl_Reader()
    status = reader.ReadFile(path)
    if status != 1:
        raise RuntimeError(f"STEP 读取失败: {path}")
    reader.TransferRoots()
    raw_shape = reader.OneShape()
    healer = ShapeFix_Shape(raw_shape)
    healer.Perform()
    return healer.Shape()
def triangulate(shape, deflection=0.5, angle=0.5):
    """OCC triangulation: returns (V: np.float32 N×3, F: np.int64 M×3).

    Meshes the shape in place, then collects all face triangulations into one
    global vertex/face list. Raises RuntimeError when meshing yields nothing.
    """
    # In-place tessellation; deflection/angle control mesh density.
    BRepMesh_IncrementalMesh(shape, deflection, False, angle, True)
    verts, faces, v_off = [], [], 0
    exp = TopExp_Explorer(shape, TopAbs_FACE)
    while exp.More():
        face = exp.Current()
        loc = TopLoc_Location()
        tri = BRep_Tool.Triangulation(face, loc)
        if tri is not None:
            nb_nodes = tri.NbNodes()
            # Older pythonocc exposes node/triangle arrays via Nodes()/Triangles();
            # newer versions use Node(i)/Triangle(i) accessors — support both.
            has_nodes_arr = hasattr(tri, "Nodes")
            for i in range(1, nb_nodes + 1):  # OCC collections are 1-based
                p = tri.Nodes().Value(i) if has_nodes_arr else tri.Node(i)
                p = p.Transformed(loc.Transformation())  # face-local -> world coords
                verts.append([p.X(), p.Y(), p.Z()])
            nb_tris = tri.NbTriangles()
            has_tris_arr = hasattr(tri, "Triangles")
            for i in range(1, nb_tris + 1):
                t = tri.Triangles().Value(i) if has_tris_arr else tri.Triangle(i)
                a, b, c = t.Get()
                # Shift per-face 1-based indices into the global 0-based list.
                faces.append([v_off + a - 1, v_off + b - 1, v_off + c - 1])
            v_off += nb_nodes
        exp.Next()
    if not verts or not faces:
        raise RuntimeError("三角化为空,尝试减小 --defl")
    return np.asarray(verts, np.float32), np.asarray(faces, np.int64)
def normalize_mesh_np(verts: np.ndarray, unit_scale=1.0):
    """Center vertices at their mean and scale so max |coordinate| == unit_scale."""
    centered = verts - verts.mean(axis=0, keepdims=True)
    extent = np.max(np.abs(centered))
    extent = max(extent, 1e-12)  # guard against an all-zero / degenerate mesh
    return centered / extent * unit_scale
def build_mesh(V_np, F_np, device):
    """Wrap numpy vertex/face arrays into a single-mesh PyTorch3D Meshes object."""
    verts_t = torch.from_numpy(V_np).to(device)
    faces_t = torch.from_numpy(F_np).to(device)
    return Meshes(verts=[verts_t], faces=[faces_t])
def compute_face_normals(verts: torch.Tensor, faces: torch.Tensor):
    """Unit normal per face from the cross product of two edge vectors."""
    corners = verts[faces]              # (F, 3, 3): the three corners of each face
    edge_a = corners[:, 1] - corners[:, 0]
    edge_b = corners[:, 2] - corners[:, 0]
    normals = torch.cross(edge_a, edge_b, dim=1)
    return torch.nn.functional.normalize(normals, dim=1, eps=1e-12)
def extract_feature_edges(faces: torch.Tensor, face_normals: torch.Tensor, angle_deg: float):
    """Edges with dihedral angle >= angle_deg are features; boundary edges always are."""
    cos_thr = math.cos(math.radians(max(0.0, angle_deg)))
    # Map each undirected edge (sorted vertex pair) to its incident face indices.
    edge_faces = defaultdict(list)
    for fi in range(faces.shape[0]):
        a, b, c = faces[fi].tolist()
        for u, v in ((a, b), (b, c), (c, a)):
            key = (u, v) if u < v else (v, u)
            edge_faces[key].append(fi)
    features = []
    for edge, owners in edge_faces.items():
        if len(owners) == 1:
            # Boundary edge: only one incident face — always a feature.
            features.append(edge)
        elif len(owners) == 2:
            dot = (face_normals[owners[0]] * face_normals[owners[1]]).sum()
            if torch.clamp(dot, -1.0, 1.0).item() <= cos_thr:
                features.append(edge)
    return features
def make_camera(view: str, device, dist=2.7):
    """Build an orthographic camera for an engineering view.

    Spherical (elev, azim) parameters for look_at_view_transform:
      top   -> from +Z looking down -Z
      left  -> from -X looking along +X
      front -> from +Y looking along -Y
    """
    view_angles = {
        "top": (90.0, 180.0),
        "left": (0.0, 270.0),
        "front": (0.0, 180.0),
    }
    if view not in view_angles:
        raise ValueError("view must be one of ['top','left','front']")
    elev, azim = view_angles[view]
    R, T = look_at_view_transform(dist=dist, elev=elev, azim=azim, device=device)
    return OrthographicCameras(R=R, T=T, device=device)
def render_silhouette(mesh: Meshes, cameras, image_size=1600):
    """Render a soft silhouette; return the alpha channel as an (H, W) array in [0, 1]."""
    settings = RasterizationSettings(
        image_size=image_size,
        blur_radius=1e-6,
        faces_per_pixel=50,
        cull_backfaces=True
    )
    shader = SoftSilhouetteShader(blend_params=BlendParams(sigma=1e-4, gamma=1e-4))
    renderer = MeshRenderer(
        rasterizer=MeshRasterizer(cameras=cameras, raster_settings=settings),
        shader=shader
    )
    with torch.no_grad():
        rendered = renderer(mesh, cameras=cameras)  # (1, H, W, 4) RGBA
    alpha = rendered[0, ..., 3].detach().cpu().numpy()
    return np.clip(alpha, 0.0, 1.0)
def raster_fragments(mesh: Meshes, cameras, image_size=1600, faces_per_pixel=1):
    """Hard-rasterize the mesh; return (Fragments, nearest-face depth map of shape (H, W))."""
    settings = RasterizationSettings(
        image_size=image_size,
        blur_radius=0.0,
        faces_per_pixel=faces_per_pixel,
        cull_backfaces=True
    )
    rasterizer = MeshRasterizer(cameras=cameras, raster_settings=settings)
    with torch.no_grad():
        frags: Fragments = rasterizer(mesh, cameras=cameras)
    # zbuf[0, ..., 0] is the closest-face depth per pixel of the single mesh.
    return frags, frags.zbuf[0, ..., 0]
def project_points_to_screen(cameras, points_world: torch.Tensor, image_size: int):
    """Project world-space points into pixel coordinates of a square image."""
    batched = points_world[None, ...]  # the cameras API expects a batch dimension
    with torch.no_grad():
        screen = cameras.transform_points_screen(
            batched,
            image_size=((image_size, image_size),)
        )
    return screen[0]
def edge_visibility_split(edges_idx, verts_world, cameras, zbuf, image_size, eps=1e-4):
    """Split edges into visible/hidden 2D segments via a midpoint depth test.

    An edge counts as visible when its midpoint's projected depth is not behind
    the rasterized z-buffer value at that pixel (within eps tolerance).
    Returns (visible, hidden) lists of (x0, y0, x1, y1) screen-space segments.
    """
    H = W = image_size
    screen = project_points_to_screen(cameras, verts_world, image_size)
    depth_map = zbuf.detach().cpu().numpy()
    visible, hidden = [], []
    for i, j in edges_idx:
        a, b = screen[i], screen[j]
        mid = 0.5 * (a + b)
        px = int(np.clip(mid[0].item(), 0, W - 1))
        py = int(np.clip(mid[1].item(), 0, H - 1))
        z_mid = mid[2].item()
        z_ref = depth_map[py, px]
        segment = (a[0].item(), a[1].item(), b[0].item(), b[1].item())
        if np.isfinite(z_ref) and z_mid <= z_ref + eps:
            visible.append(segment)
        else:
            hidden.append(segment)
    return visible, hidden
def trace_silhouettes(alpha: np.ndarray, threshold=0.5, step=1):
    """Extract iso-contours of the alpha mask as a list of (x, y) polylines."""
    polys = []
    for contour in measure.find_contours(alpha, level=threshold):
        if len(contour) < 8:  # discard tiny noise loops
            continue
        sampled = contour[::max(1, step)]
        # find_contours yields (row, col); swap columns to get (x, y) pixel order.
        polys.append(np.stack([sampled[:, 1], sampled[:, 0]], axis=1))
    return polys
def svg_from_view(out_svg, W, H, silhouettes, edges_vis, edges_hid, stroke=2.0, margin=10, fill="none"):
    """Write one view's silhouettes and feature edges to an SVG file.

    Visible edges are drawn solid black; hidden edges are thinner, dashed and
    slightly transparent. Every coordinate is shifted by `margin` pixels.
    """
    total_w = W + 2 * margin
    total_h = H + 2 * margin
    dwg = svgwrite.Drawing(out_svg, size=(total_w, total_h))
    dwg.add(dwg.rect(insert=(0, 0), size=(total_w, total_h), fill="none"))
    # Optional silhouette fills.
    for poly in silhouettes:
        if len(poly) < 3:
            continue
        cmds = [f"M {poly[0, 0] + margin:.2f} {poly[0, 1] + margin:.2f}"]
        cmds.extend(
            f"L {poly[k, 0] + margin:.2f} {poly[k, 1] + margin:.2f}"
            for k in range(1, len(poly))
        )
        cmds.append("Z")
        dwg.add(dwg.path(" ".join(cmds), fill=fill, stroke="none"))
    # Visible edges: solid strokes.
    for (x0, y0, x1, y1) in edges_vis:
        dwg.add(dwg.line(
            (x0 + margin, y0 + margin),
            (x1 + margin, y1 + margin),
            stroke="black",
            stroke_width=stroke
        ))
    # Hidden edges: dashed, thinner, semi-transparent.
    for (x0, y0, x1, y1) in edges_hid:
        dwg.add(dwg.line(
            (x0 + margin, y0 + margin),
            (x1 + margin, y1 + margin),
            stroke="black",
            stroke_width=max(1.0, 0.75 * stroke),
            stroke_dasharray=[6, 6],
            opacity=0.85
        ))
    dwg.save()
def combine_svgs(view1_path, view2_path, view3_path, output_svg, output_image):
    """Tile three SVG views horizontally into one SVG, then also export a PNG."""
    out_dir = os.path.dirname(output_svg)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    view1 = st.fromfile(view1_path)
    view2 = st.fromfile(view2_path)
    view3 = st.fromfile(view3_path)
    r1 = view1.getroot()
    r2 = view2.getroot()
    r3 = view3.getroot()
    # get_size() returns strings such as "532px"; strip the unit suffix.
    w1, h1 = [float(x.replace("px", "")) for x in view1.get_size()]
    w2, h2 = [float(x.replace("px", "")) for x in view2.get_size()]
    w3, h3 = [float(x.replace("px", "")) for x in view3.get_size()]
    max_w = max(w1, w2, w3)
    max_h = max(h1, h2, h3)
    # 10px outer margin on each side plus 20px gaps between the three panels.
    combined_width = max_w * 3 + 40
    combined_height = max_h + 20
    combined = st.SVGFigure(Unit(combined_width), Unit(combined_height))
    x_offset = 10
    y_offset = 10
    r1.moveto(x_offset, y_offset)
    x_offset += max_w + 20
    r2.moveto(x_offset, y_offset)
    x_offset += max_w + 20
    r3.moveto(x_offset, y_offset)
    combined.append([r1, r2, r3])
    combined.save(output_svg)
    # Rasterize the combined SVG to PNG.
    cairosvg.svg2png(url=output_svg, write_to=output_image)
def step_to_svg(step_path, out_dir,
                res: int = 512, defl: float = 1.5,
                ang: float = 65.0, stroke: float = 1.3,
                scale: float = 1.0, no_combine: bool = False,
                combined_svg: str = None, combined_png: str = None):
    """Render a STEP solid into top/left/front engineering-view SVGs.

    Args:
        step_path: input STEP file.
        out_dir: directory for the per-view SVGs (created if missing).
        res: raster resolution (pixels) for silhouettes and depth testing.
        defl: OCC triangulation deflection (smaller = denser mesh).
        ang: dihedral-angle threshold (degrees) for feature edges.
        stroke: SVG stroke width for visible edges.
        scale: unit scale applied during mesh normalization.
        no_combine: when True, only export the three views and return None.
        combined_svg / combined_png: output paths for the combined sheet
            (default: combined_views.svg / .png inside out_dir).

    Returns:
        (combined_svg, combined_png) paths, or None when no_combine is True.
    """
    os.makedirs(out_dir, exist_ok=True)
    device = autodevice()
    print(f"[Device] {device}")
    # STEP -> normalized triangle mesh on the chosen device.
    shape = read_step_solid(step_path)
    V_np, F_np = triangulate(shape, deflection=defl)
    V_np = normalize_mesh_np(V_np, unit_scale=scale)
    mesh = build_mesh(V_np, F_np, device)
    verts = mesh.verts_packed()
    faces = mesh.faces_packed()
    fn = compute_face_normals(verts, faces)
    feat_edges = extract_feature_edges(faces, fn, angle_deg=ang)
    view_paths = {}
    for name in ["top", "left", "front"]:
        print(f"[View] {name}")
        cam = make_camera(name, device)
        alpha = render_silhouette(mesh, cam, image_size=res)
        _, zbuf = raster_fragments(mesh, cam, image_size=res, faces_per_pixel=1)
        e_vis, e_hid = edge_visibility_split(feat_edges, verts, cam, zbuf, res, eps=1e-4)
        polys = trace_silhouettes(alpha, threshold=0.5, step=1)
        out_svg = os.path.join(out_dir, f"{name}.svg")
        svg_from_view(out_svg, res, res, polys, e_vis, e_hid,
                      stroke=stroke, margin=10, fill="none")
        view_paths[name] = out_svg
        print(f" -> {out_svg} (edges: vis={len(e_vis)}, hid={len(e_hid)}, polys={len(polys)})")
    if no_combine:
        print("[Done] 3 views exported (no combine).")
        return None
    combined_svg = combined_svg or os.path.join(out_dir, "combined_views.svg")
    combined_png = combined_png or os.path.join(out_dir, "combined_views.png")
    # Combined sheet panel order: front / top / left.
    combine_svgs(
        view1_path=view_paths["front"],
        view2_path=view_paths["top"],
        view3_path=view_paths["left"],
        output_svg=combined_svg,
        output_image=combined_png
    )
    return combined_svg, combined_png
if __name__ == "__main__":
# step -1
glb_result, ply_result = image_to_3d(image=["assets/whiteboard_exported_image.png"])
print(f"glb_result : {glb_result}", f"ply_result : {ply_result}")
# step 0
obj_result = glb_to_obj(glb_result)
print(f"obj_result : {obj_result}")
# step 1
step_result = obj_to_step(input_obj=obj_result, output_dir="step_output", script_path="1_obj_to_step.py")
# step 2
out_dir = "svg_output"
combined_svg, combined_png = step_to_svg(step_path=step_result, out_dir=out_dir)
print(f"combined_svg : {combined_svg}", f"combined_png : {combined_png}")