556 lines
18 KiB
Python
Executable File
556 lines
18 KiB
Python
Executable File
import math
|
||
import secrets
|
||
from collections import defaultdict
|
||
from datetime import datetime
|
||
|
||
import os
|
||
import subprocess
|
||
from pathlib import Path
|
||
import logging
|
||
from typing import Optional
|
||
|
||
import imageio
|
||
import numpy as np
|
||
import torch
|
||
from PIL import Image
|
||
|
||
# ---------- PyTorch3D ----------
|
||
from pytorch3d.structures import Meshes
|
||
from pytorch3d.renderer import (
|
||
RasterizationSettings, MeshRasterizer,
|
||
MeshRenderer, BlendParams, SoftSilhouetteShader,
|
||
look_at_view_transform, OrthographicCameras
|
||
)
|
||
from pytorch3d.renderer.mesh.rasterizer import Fragments
|
||
|
||
# ---------- SVG / Image ----------
|
||
import svgwrite
|
||
from skimage import measure
|
||
|
||
# ---------- pythonocc (OCC) ----------
|
||
from OCC.Core.STEPControl import STEPControl_Reader
|
||
from OCC.Core.TopAbs import TopAbs_FACE
|
||
from OCC.Core.TopExp import TopExp_Explorer
|
||
from OCC.Core.BRep import BRep_Tool
|
||
from OCC.Core.BRepMesh import BRepMesh_IncrementalMesh
|
||
from OCC.Core.TopLoc import TopLoc_Location
|
||
from OCC.Core.ShapeFix import ShapeFix_Shape
|
||
|
||
# ---------- Combine SVG ----------
|
||
import svgutils.transform as st
|
||
from svgutils.compose import Unit
|
||
import cairosvg
|
||
|
||
from trellis.pipelines import TrellisImageTo3DPipeline
|
||
from trellis.utils import render_utils, postprocessing_utils
|
||
|
||
logger = logging.getLogger(__name__)


# ---------- single image to 3D ----------


def generate_unique_name(original_name: str) -> str:
    """Return *original_name* with a timestamp and a random tag inserted before its extension.

    For example ``"sample.glb"`` becomes ``"sample_20240101_120000_1a2b3c4d.glb"``.
    The tag is the local time formatted as ``%Y%m%d_%H%M%S`` plus 4 random bytes
    rendered as 8 hex characters, which makes collisions between calls in the
    same second unlikely.
    """
    base, suffix = os.path.splitext(original_name)
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    tag = secrets.token_hex(4)  # 4 random bytes -> 8 hex characters
    return "{}_{}_{}{}".format(base, stamp, tag, suffix)
|
||
|
||
|
||
def _load_image(path) -> Image.Image:
    """Open *path* with Pillow and decode it immediately.

    ``Image.open`` is lazy and keeps the file handle open until the pixel data
    is read; calling ``load()`` now reads it, which for single-frame formats also
    lets Pillow close the underlying file.
    """
    img = Image.open(path)
    img.load()
    return img


def image_to_3d(image: list[str] | str, out_dir: str = "trellis_out", single_image: bool = False, seed: int = 1,
                steps_sparse: int = 12, cfg_sparse: float = 7.5, steps_slat: int = 12,
                cfg_slat: float = 3.0, simplify: float = 0.95, texture_size: int = 1024,
                save_video: bool = False, fps: int = 30, video_gs_name: str = "sample_gs.mp4",
                video_rf_name: str = "sample_rf.mp4", video_mesh_name: str = "sample_mesh.mp4"):
    """Reconstruct a 3D asset from one or more images with TRELLIS and export it as GLB + PLY.

    Args:
        image: Paths of the input image(s). A single path (str / PathLike) is also accepted.
        out_dir: Output directory; created if missing.
        single_image: If True only ``image[0]`` is used, via ``pipeline.run``;
            otherwise every image is passed to ``pipeline.run_multi_image``.
        seed: Sampling seed forwarded to the pipeline.
        steps_sparse, cfg_sparse: Steps / CFG strength of the sparse-structure sampler.
        steps_slat, cfg_slat: Steps / CFG strength of the SLAT sampler.
        simplify, texture_size: Forwarded to ``postprocessing_utils.to_glb``.
        save_video: If True, also render videos of the Gaussian (color),
            radiance-field (color) and mesh (normal) outputs.
        fps: Frame rate of those videos.
        video_gs_name, video_rf_name, video_mesh_name: Video file names inside *out_dir*.

    Returns:
        ``(glb_path, ply_path)`` -- uniquely named files inside *out_dir*.

    Raises:
        ValueError: If no image path is given.
    """
    # A bare path used to be iterated character by character (or image[0] taken as
    # its first character); wrap it so both modes see a list of paths.
    if isinstance(image, (str, os.PathLike)):
        image = [image]
    if not image:
        raise ValueError("image_to_3d needs at least one input image path")

    os.makedirs(out_dir, exist_ok=True)

    # Optional env.
    # NOTE(review): spconv may read SPCONV_ALGO when trellis is imported at module
    # top, in which case setting it here is too late -- confirm.
    os.environ.setdefault("SPCONV_ALGO", "native")

    pipeline = TrellisImageTo3DPipeline.from_pretrained("microsoft/TRELLIS-image-large")
    pipeline.cuda()

    # Same sampler settings for the single- and multi-image entry points.
    sampler_kwargs = dict(
        seed=seed,
        sparse_structure_sampler_params={
            "steps": steps_sparse,
            "cfg_strength": cfg_sparse,
        },
        slat_sampler_params={
            "steps": steps_slat,
            "cfg_strength": cfg_slat,
        },
    )

    if single_image:
        outputs = pipeline.run(_load_image(image[0]), **sampler_kwargs)
    else:
        images = [_load_image(p) for p in image]
        outputs = pipeline.run_multi_image(images, **sampler_kwargs)

    if save_video:
        # (output key, rendered channel, file name)
        for key, channel, file_name in (
            ("gaussian", "color", video_gs_name),
            ("radiance_field", "color", video_rf_name),
            ("mesh", "normal", video_mesh_name),
        ):
            frames = render_utils.render_video(outputs[key][0])[channel]
            imageio.mimsave(os.path.join(out_dir, file_name), frames, fps=fps)

    glb_path = os.path.join(out_dir, generate_unique_name("sample.glb"))
    ply_path = os.path.join(out_dir, generate_unique_name("sample.ply"))

    glb = postprocessing_utils.to_glb(
        outputs["gaussian"][0],
        outputs["mesh"][0],
        simplify=simplify,
        texture_size=texture_size,
    )

    # Files are written directly to their final, uniquely named paths.
    glb.export(glb_path)
    outputs["gaussian"][0].save_ply(ply_path)

    return glb_path, ply_path
|
||
|
||
|
||
# ---------- glb_to_obj ----------


def glb_to_obj(glb_input, obj_output_dir="obj_output", script_path="0_glb_to_obj.py",
               blender_path=None, timeout=None):
    """Convert a GLB file to OBJ by running a Blender script headlessly.

    Runs ``blender -b --python-exit-code 1 -P <script_path> -- --input <glb> --output <obj>``.

    Args:
        glb_input: Path of the GLB file to convert.
        obj_output_dir: Output directory (created if missing). The OBJ gets a
            unique name derived from ``"sample.obj"``.
        script_path: Blender-side conversion script. A relative path is resolved
            against the current working directory (unchanged from before).
        blender_path: Blender executable; defaults to ``"blender"`` on PATH.
        timeout: Optional timeout in seconds for the Blender process; on expiry
            ``subprocess.TimeoutExpired`` propagates.

    Returns:
        Path of the written OBJ file.

    Raises:
        RuntimeError: If Blender exits non-zero or the OBJ file was not written.
    """
    os.makedirs(obj_output_dir, exist_ok=True)
    obj_path = os.path.join(obj_output_dir, generate_unique_name("sample.obj"))

    cmd = [
        blender_path or "blender",
        "-b",
        # Without this Blender exits 0 even if the -P script raises; it must
        # precede -P to apply to that script.
        "--python-exit-code", "1",
        "-P", os.fspath(script_path),
        "--",
        "--input", os.fspath(glb_input),
        "--output", obj_path,
    ]

    result = subprocess.run(cmd, capture_output=True, text=True, timeout=timeout)
    logger.debug("Blender exited with %s\nstdout:\n%s\nstderr:\n%s",
                 result.returncode, result.stdout, result.stderr)

    if result.returncode != 0:
        raise RuntimeError(
            f"Blender failed\nstdout:\n{result.stdout}\nstderr:\n{result.stderr}"
        )
    if not os.path.isfile(obj_path):
        raise RuntimeError(
            f"Blender finished but did not write {obj_path}\n"
            f"stdout:\n{result.stdout}\nstderr:\n{result.stderr}"
        )

    return obj_path
|
||
|
||
|
||
"""
|
||
obj_to_step
|
||
"""
|
||
|
||
|
||
def obj_to_step(
|
||
input_obj: str | Path,
|
||
output_dir: str | Path,
|
||
script_path: str | Path = "1_obj_to_step.py",
|
||
timeout: int = 120, # 秒,建议根据实际转换耗时调整
|
||
freecadcmd_path: Optional[str] = None
|
||
) -> bool:
|
||
"""
|
||
把 OBJ 转 STEP 的 FreeCAD 命令封装成函数
|
||
返回 True 表示成功,False 表示失败
|
||
"""
|
||
input_obj = Path(input_obj).resolve()
|
||
output_dir = Path(output_dir).resolve()
|
||
script_path = Path(script_path).resolve()
|
||
|
||
print(f"input_obj : {input_obj}")
|
||
print(f"output_dir : {output_dir}")
|
||
print(f"script_path : {script_path}")
|
||
|
||
if not input_obj.exists():
|
||
raise FileNotFoundError(f"输入文件不存在: {input_obj}")
|
||
if not script_path.exists():
|
||
raise FileNotFoundError(f"转换脚本不存在: {script_path}")
|
||
|
||
# 构造 -c 参数(完全复刻你原来的命令)
|
||
python_code = f'''import sys
|
||
sys.argv = ["{script_path.name}", "{input_obj}", "{output_dir}"]
|
||
exec(open("{script_path}", "r", encoding="utf-8").read())'''
|
||
|
||
cmd = [
|
||
freecadcmd_path or "freecadcmd", # 如果不在 PATH 中,可传入绝对路径
|
||
"-c",
|
||
python_code
|
||
]
|
||
|
||
try:
|
||
result = subprocess.run(
|
||
cmd,
|
||
capture_output=True,
|
||
text=True,
|
||
check=True, # 自动抛出 CalledProcessError
|
||
timeout=timeout,
|
||
cwd=script_path.parent # 让脚本能找到相对路径文件
|
||
)
|
||
|
||
logger.info(f"FreeCAD 转换成功: {input_obj.name} → {output_dir}")
|
||
print(f"FreeCAD 转换成功: {input_obj.name} → {output_dir}")
|
||
if result.stdout:
|
||
logger.debug(result.stdout)
|
||
return result.stdout.split()[-1]
|
||
|
||
except subprocess.TimeoutExpired:
|
||
logger.error(f"FreeCAD 转换超时(>{timeout}s): {input_obj}")
|
||
return False
|
||
except subprocess.CalledProcessError as e:
|
||
logger.error(f"FreeCAD 执行失败: {e.stderr}")
|
||
return False
|
||
except Exception as e:
|
||
logger.exception("未知错误")
|
||
return False
|
||
|
||
|
||
# ---------- step to svg ----------


def autodevice():
    """Return the first CUDA device when PyTorch reports one available, else the CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda:0")
    return torch.device("cpu")
|
||
|
||
|
||
def read_step_solid(path: str):
    """Read a STEP file and return all its roots as one shape, healed by ShapeFix_Shape.

    Args:
        path: STEP file path (str or os.PathLike).

    Returns:
        ``ShapeFix_Shape(...).Shape()`` applied to ``reader.OneShape()``.

    Raises:
        RuntimeError: If the file cannot be read or yields no shape.
        TypeError: If *path* is not a path-like value (e.g. a ``False`` failure flag).
    """
    from OCC.Core.IFSelect import IFSelect_RetDone

    path = os.fspath(path)
    reader = STEPControl_Reader()
    status = reader.ReadFile(path)
    # Compare with the named status instead of the literal 1.
    # NOTE(review): newer pythonocc releases may expose return statuses as enum
    # members, for which `== 1` need not hold -- the named constant works either way.
    if status != IFSelect_RetDone:
        raise RuntimeError(f"STEP 读取失败: {path}")

    reader.TransferRoots()
    shape = reader.OneShape()
    if shape.IsNull():
        raise RuntimeError(f"STEP file contains no transferable shape: {path}")

    fixer = ShapeFix_Shape(shape)
    fixer.Perform()
    return fixer.Shape()
|
||
|
||
|
||
def triangulate(shape, deflection=0.5, angle=0.5):
    """Mesh *shape* with OCC and return the triangles as numpy arrays.

    Args:
        shape: OCC shape; BRepMesh_IncrementalMesh meshes it in place.
        deflection: Linear deflection for BRepMesh_IncrementalMesh (smaller = finer).
        angle: Angular deflection for BRepMesh_IncrementalMesh.

    Returns:
        ``(V, F)``: ``V`` is float32 (N, 3) vertex positions with each face's
        location applied; ``F`` is int64 (M, 3) 0-based indices into ``V``.
        Every B-rep face contributes its own vertices, so points on edges shared
        by two faces are duplicated rather than merged.

    Raises:
        RuntimeError: If no face produced a triangulation.
    """
    from OCC.Core.TopAbs import TopAbs_REVERSED

    BRepMesh_IncrementalMesh(shape, deflection, False, angle, True)

    verts, faces, v_off = [], [], 0
    exp = TopExp_Explorer(shape, TopAbs_FACE)
    while exp.More():
        face = exp.Current()
        loc = TopLoc_Location()
        tri = BRep_Tool.Triangulation(face, loc)
        if tri is not None:
            trsf = loc.Transformation()  # identical for every node of this face

            # Some OCC / pythonocc versions expose Nodes()/Triangles() arrays,
            # others only Node(i)/Triangle(i); support both (1-based indices).
            get_node = tri.Nodes().Value if hasattr(tri, "Nodes") else tri.Node
            get_tri = tri.Triangles().Value if hasattr(tri, "Triangles") else tri.Triangle

            nb_nodes = tri.NbNodes()
            for i in range(1, nb_nodes + 1):
                p = get_node(i).Transformed(trsf)
                verts.append([p.X(), p.Y(), p.Z()])

            # A face's triangulation follows its underlying surface's
            # parametrization; for a REVERSED face two indices must be swapped so
            # the triangle normal points to the face's outer side. Otherwise
            # back-face culling and dihedral-angle tests see inverted normals.
            flip = face.Orientation() == TopAbs_REVERSED
            for i in range(1, tri.NbTriangles() + 1):
                a, b, c = get_tri(i).Get()
                if flip:
                    b, c = c, b
                faces.append([v_off + a - 1, v_off + b - 1, v_off + c - 1])
            v_off += nb_nodes
        exp.Next()

    if not verts or not faces:
        raise RuntimeError("三角化为空,尝试减小 --defl")
    return np.asarray(verts, np.float32), np.asarray(faces, np.int64)
|
||
|
||
|
||
def normalize_mesh_np(verts: np.ndarray, unit_scale=1.0):
    """Center *verts* on their mean, then scale so the largest |coordinate| equals *unit_scale*.

    The divisor is floored at 1e-12, so a degenerate cloud (all points equal)
    maps to zeros instead of dividing by zero.
    """
    centered = verts - verts.mean(axis=0, keepdims=True)
    extent = max(np.abs(centered).max(), 1e-12)
    return centered / extent * unit_scale
|
||
|
||
|
||
def build_mesh(V_np, F_np, device):
    """Wrap numpy vertex and face arrays into a single-mesh PyTorch3D ``Meshes`` on *device*."""
    verts_t, faces_t = (torch.from_numpy(arr).to(device) for arr in (V_np, F_np))
    return Meshes(verts=[verts_t], faces=[faces_t])
|
||
|
||
|
||
def compute_face_normals(verts: torch.Tensor, faces: torch.Tensor):
    """Unit normal per triangle, oriented by the right-hand rule over (v0, v1, v2).

    Degenerate triangles get a zero vector: their cross product is zero and the
    eps in ``normalize`` keeps the division finite.
    """
    a, b, c = (verts[faces[:, corner]] for corner in range(3))
    normals = torch.cross(b - a, c - a, dim=1)
    return torch.nn.functional.normalize(normals, dim=1, eps=1e-12)
|
||
|
||
|
||
def extract_feature_edges(faces: torch.Tensor, face_normals: torch.Tensor, angle_deg: float):
    """Return the undirected mesh edges to draw as feature lines.

    An edge ``(i, j)`` (stored with ``i < j``) is kept when it
      * belongs to exactly one face (open boundary), or
      * is shared by exactly two faces whose normals are at least *angle_deg*
        apart, i.e. ``cos(angle between normals) <= cos(angle_deg)``.
    Edges shared by three or more faces (non-manifold) are skipped. A negative
    *angle_deg* is treated as 0.

    Args:
        faces: (M, 3) integer vertex indices.
        face_normals: (M, 3) unit normals, e.g. from ``compute_face_normals``.
        angle_deg: Feature threshold in degrees.

    Returns:
        List of ``(i, j)`` int tuples in first-seen order.
    """
    thr = math.cos(math.radians(max(0.0, angle_deg)))

    # One host transfer instead of a .tolist() per face, which forces a device
    # sync per face when `faces` lives on the GPU.
    e2f = defaultdict(list)
    for f, (i, j, k) in enumerate(faces.detach().cpu().tolist()):
        for a, b in ((i, j), (j, k), (k, i)):
            e2f[(a, b) if a < b else (b, a)].append(f)

    # Dihedral cosines for all two-face edges in one batched op (same
    # iteration order as the loop below, so they can be consumed in sequence).
    pairs = [fl for fl in e2f.values() if len(fl) == 2]
    cos_vals = []
    if pairs:
        idx = torch.tensor(pairs, dtype=torch.long, device=face_normals.device)
        dots = (face_normals[idx[:, 0]] * face_normals[idx[:, 1]]).sum(dim=1)
        cos_vals = torch.clamp(dots, -1.0, 1.0).tolist()
    cos_iter = iter(cos_vals)

    feat = []
    for e, fl in e2f.items():
        if len(fl) == 1:
            feat.append(e)
        elif len(fl) == 2 and next(cos_iter) <= thr:
            feat.append(e)
    return feat
|
||
|
||
|
||
def make_camera(view: str, device, dist=2.7):
    """Build an orthographic PyTorch3D camera for one drawing view, looking at the origin.

    Each view maps to an (elev, azim) pair in degrees for
    ``look_at_view_transform`` at distance *dist*. Under PyTorch3D's documented
    spherical convention (camera at d*(cos e*sin a, sin e, cos e*cos a), +Y up),
    these place the camera at:
        top   -> elev 90, azim 180 : on +Y, looking along -Y
        left  -> elev 0,  azim 270 : on -X, looking along +X
        front -> elev 0,  azim 180 : on -Z, looking along +Z
    NOTE(review): confirm this matches the up axis of the STEP models being drawn.

    Raises:
        ValueError: If *view* is not one of 'top', 'left', 'front'.
    """
    table = (
        ("top", (90.0, 180.0)),
        ("left", (0.0, 270.0)),
        ("front", (0.0, 180.0)),
    )
    for name, (elev, azim) in table:
        if view == name:
            break
    else:
        raise ValueError("view must be one of ['top','left','front']")

    R, T = look_at_view_transform(dist=dist, elev=elev, azim=azim, device=device)
    return OrthographicCameras(R=R, T=T, device=device)
|
||
|
||
|
||
def render_silhouette(mesh: Meshes, cameras, image_size=1600):
    """Render a soft silhouette of *mesh* and return its alpha channel clipped to [0, 1].

    The tiny blur radius and sigma/gamma keep the silhouette nearly hard-edged;
    back-facing triangles are culled. The renderer output is (1, H, W, 4), and
    the alpha channel of the first image is returned as a numpy array.
    """
    settings = RasterizationSettings(
        image_size=image_size,
        blur_radius=1e-6,
        faces_per_pixel=50,
        cull_backfaces=True,
    )
    shader = SoftSilhouetteShader(blend_params=BlendParams(sigma=1e-4, gamma=1e-4))
    renderer = MeshRenderer(
        rasterizer=MeshRasterizer(cameras=cameras, raster_settings=settings),
        shader=shader,
    )

    with torch.no_grad():
        rgba = renderer(mesh, cameras=cameras)  # (1, H, W, 4)

    alpha = rgba[0, ..., 3].detach().cpu().numpy()
    return np.clip(alpha, 0.0, 1.0)
|
||
|
||
|
||
def raster_fragments(mesh: Meshes, cameras, image_size=1600, faces_per_pixel=1):
    """Hard-rasterize *mesh* (no blur, back faces culled).

    Returns:
        ``(fragments, zbuf)`` where ``zbuf`` is ``fragments.zbuf[0, ..., 0]``: the
        first per-pixel depth slot of the first mesh in the batch (PyTorch3D
        sorts faces per pixel nearest-first and pads uncovered pixels with -1).
    """
    rasterizer = MeshRasterizer(
        cameras=cameras,
        raster_settings=RasterizationSettings(
            image_size=image_size,
            blur_radius=0.0,
            faces_per_pixel=faces_per_pixel,
            cull_backfaces=True,
        ),
    )
    with torch.no_grad():
        fragments: Fragments = rasterizer(mesh, cameras=cameras)
    nearest_depth = fragments.zbuf[0, ..., 0]
    return fragments, nearest_depth
|
||
|
||
|
||
def project_points_to_screen(cameras, points_world: torch.Tensor, image_size: int):
    """Project world-space points to screen space for a square image_size x image_size image.

    *points_world* is presumably (P, 3), e.g. ``mesh.verts_packed()``; a batch
    dimension is added for ``cameras.transform_points_screen`` and removed again.
    """
    with torch.no_grad():
        batched = points_world.unsqueeze(0)
        projected = cameras.transform_points_screen(
            batched,
            image_size=((image_size, image_size),),
        )
    return projected[0]
|
||
|
||
|
||
def edge_visibility_split(edges_idx, verts_world, cameras, zbuf, image_size, eps=1e-4):
    """Split feature edges into visible and hidden 2D segments via a midpoint depth test.

    Each edge ``(i, j)`` is projected to screen space. Its midpoint's depth is
    compared with ``zbuf`` at the midpoint pixel, clamped into the image. The
    edge is visible when that pixel is covered by geometry and the midpoint is
    at most *eps* farther than the stored depth; otherwise it is hidden.

    Args:
        edges_idx: Iterable of ``(i, j)`` vertex-index pairs.
        verts_world: World-space vertices indexed by ``edges_idx``.
        cameras: Camera that produced ``zbuf``.
        zbuf: (H, W) per-pixel depth, e.g. from ``raster_fragments``.
        image_size: H == W of ``zbuf``.
        eps: Depth tolerance.

    Returns:
        ``(vis, hid)``: lists of ``(x0, y0, x1, y1)`` screen-space segments.
    """
    H = W = image_size
    vis, hid = [], []

    # One device->host copy up front; per-edge .item() calls each forced a GPU sync.
    scr = project_points_to_screen(cameras, verts_world, image_size).detach().cpu().numpy()
    zmin = zbuf.detach().cpu().numpy()

    for i, j in edges_idx:
        p0, p1 = scr[i], scr[j]
        m = 0.5 * (p0 + p1)
        x = int(np.clip(m[0], 0, W - 1))
        y = int(np.clip(m[1], 0, H - 1))
        z_ref = zmin[y, x]
        seg = (float(p0[0]), float(p0[1]), float(p1[0]), float(p1[1]))

        # PyTorch3D pads uncovered pixels with -1 (not inf), so test for that
        # explicitly: an edge whose midpoint falls on background is hidden.
        covered = np.isfinite(z_ref) and z_ref >= 0
        if covered and m[2] <= z_ref + eps:
            vis.append(seg)
        else:
            hid.append(seg)
    return vis, hid
|
||
|
||
|
||
def trace_silhouettes(alpha: np.ndarray, threshold=0.5, step=1):
    """Trace iso-contours of *alpha* at *threshold* as (N, 2) arrays of (x, y) points.

    Contours with fewer than 8 points are dropped; the rest keep every *step*-th
    point (a step below 1 is treated as 1). ``find_contours`` yields (row, col)
    points, so the columns are swapped to get (x, y).
    """
    stride = max(1, step)
    polys = []
    for contour in measure.find_contours(alpha, level=threshold):
        if len(contour) >= 8:
            thinned = contour[::stride]
            polys.append(np.stack([thinned[:, 1], thinned[:, 0]], axis=1))  # (x, y)
    return polys
|
||
|
||
|
||
def svg_from_view(out_svg, W, H, silhouettes, edges_vis, edges_hid, stroke=2.0, margin=10, fill="none"):
    """Write one drawing view to *out_svg*.

    The canvas is (W + 2*margin) x (H + 2*margin) and every coordinate is
    shifted by *margin*. Drawing order:
      1. a transparent background rectangle,
      2. silhouette polygons (filled with *fill*, no outline; polygons with
         fewer than 3 points are skipped),
      3. visible edges: solid black, width *stroke*,
      4. hidden edges: black, dashed 6/6, width max(1, 0.75*stroke), opacity 0.85.
    """
    canvas = (W + 2 * margin, H + 2 * margin)
    dwg = svgwrite.Drawing(out_svg, size=canvas)
    dwg.add(dwg.rect(insert=(0, 0), size=canvas, fill="none"))

    def shifted_line(segment, **style):
        # Build a line element for an (x0, y0, x1, y1) segment offset by the margin.
        x0, y0, x1, y1 = segment
        return dwg.line((x0 + margin, y0 + margin), (x1 + margin, y1 + margin), **style)

    # Silhouette polygons: fill only.
    for poly in silhouettes:
        if len(poly) >= 3:
            commands = " ".join(
                f"{'M' if n == 0 else 'L'} {poly[n, 0] + margin:.2f} {poly[n, 1] + margin:.2f}"
                for n in range(len(poly))
            )
            dwg.add(dwg.path(commands + " Z", fill=fill, stroke="none"))

    # Visible feature edges.
    for segment in edges_vis:
        dwg.add(shifted_line(segment, stroke="black", stroke_width=stroke))

    # Hidden feature edges: thinner, dashed, slightly transparent.
    hidden_width = max(1.0, 0.75 * stroke)
    for segment in edges_hid:
        dwg.add(shifted_line(
            segment,
            stroke="black",
            stroke_width=hidden_width,
            stroke_dasharray=[6, 6],
            opacity=0.85,
        ))

    dwg.save()
|
||
|
||
|
||
def combine_svgs(view1_path, view2_path, view3_path, output_svg, output_image, *, margin=10, gap=20):
    """Place three SVG files side by side in one SVG and also rasterize it to PNG.

    Every view is top-aligned at y = *margin*. Horizontally each view gets a
    slot as wide as the widest view, separated by *gap*, with *margin* on both
    outer sides. The canvas is therefore 3*max_w + 2*gap + 2*margin wide and
    max_h + 2*margin tall.

    Args:
        view1_path, view2_path, view3_path: Input SVGs, placed left to right.
        output_svg: Combined SVG path; its parent directory is created if needed.
        output_image: PNG written by cairosvg from *output_svg*.
        margin: Outer margin in px (default 10, as before).
        gap: Horizontal spacing between views in px (default 20, as before).
    """
    out_dir = os.path.dirname(output_svg)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    figures = [st.fromfile(p) for p in (view1_path, view2_path, view3_path)]
    roots = [fig.getroot() for fig in figures]
    sizes = [[float(v.replace("px", "")) for v in fig.get_size()] for fig in figures]
    max_w = max(w for w, _ in sizes)
    max_h = max(h for _, h in sizes)

    # The last view ends at margin + n*max_w + (n-1)*gap; add the right margin so
    # it is not clipped (the old 3*max_w + 40 was 10 px too narrow).
    n = len(figures)
    combined_width = n * max_w + (n - 1) * gap + 2 * margin
    combined_height = max_h + 2 * margin

    combined = st.SVGFigure(Unit(combined_width), Unit(combined_height))

    x_offset = margin
    for root in roots:
        root.moveto(x_offset, margin)
        x_offset += max_w + gap

    combined.append(roots)
    combined.save(output_svg)

    # SVG -> PNG
    cairosvg.svg2png(url=output_svg, write_to=output_image)
|
||
|
||
|
||
def step_to_svg(step_path, out_dir,
                res: int = 512, defl: float = 1.5,
                ang: float = 65.0, stroke: float = 1.3,
                scale: float = 1.0, no_combine: bool = False,
                combined_svg: Optional[str] = None, combined_png: Optional[str] = None):
    """Render top / left / front line-drawing views of a STEP model as SVG (and a combined sheet).

    Pipeline:
      1. Read and heal the STEP file, triangulate it with OCC, and center and
         scale it to a PyTorch3D mesh.
      2. Extract feature edges (boundary edges, and edges whose adjacent face
         normals differ by >= *ang* degrees).
      3. For each view, split those edges into visible / hidden with a depth
         test, trace the silhouette from a soft-silhouette render, and write
         ``<view>.svg``.
      4. Unless *no_combine*, lay the views out as front / top / left in one
         SVG and a PNG.

    Args:
        step_path: STEP file (str or os.PathLike).
        out_dir: Directory for top.svg, left.svg and front.svg (created if missing).
        res: Square render resolution in px (also each view's SVG size before margins).
        defl: OCC linear deflection used for triangulation.
        ang: Feature-edge angle threshold in degrees.
        stroke: Visible-edge stroke width.
        scale: Largest |coordinate| after normalization.
        no_combine: If True, stop after the three per-view SVGs.
        combined_svg, combined_png: Combined outputs; default to
            ``combined_views.svg`` / ``combined_views.png`` in *out_dir*.

    Returns:
        ``None`` when *no_combine* is True, otherwise ``(combined_svg, combined_png)``.
    """
    os.makedirs(out_dir, exist_ok=True)
    device = autodevice()
    print(f"[Device] {device}")

    # STEP -> normalized triangle mesh on `device`
    shape = read_step_solid(os.fspath(step_path))
    V_np, F_np = triangulate(shape, deflection=defl)
    V_np = normalize_mesh_np(V_np, unit_scale=scale)

    mesh = build_mesh(V_np, F_np, device)
    verts = mesh.verts_packed()
    faces = mesh.faces_packed()
    fn = compute_face_normals(verts, faces)
    feat_edges = extract_feature_edges(faces, fn, angle_deg=ang)

    view_paths = {}
    for name in ["top", "left", "front"]:
        print(f"[View] {name}")
        cam = make_camera(name, device)
        alpha = render_silhouette(mesh, cam, image_size=res)
        _, zbuf = raster_fragments(mesh, cam, image_size=res, faces_per_pixel=1)

        e_vis, e_hid = edge_visibility_split(feat_edges, verts, cam, zbuf, res, eps=1e-4)
        polys = trace_silhouettes(alpha, threshold=0.5, step=1)

        out_svg = os.path.join(out_dir, f"{name}.svg")
        svg_from_view(out_svg, res, res, polys, e_vis, e_hid,
                      stroke=stroke, margin=10, fill="none")
        view_paths[name] = out_svg
        print(f" -> {out_svg} (edges: vis={len(e_vis)}, hid={len(e_hid)}, polys={len(polys)})")

    if no_combine:
        print("[Done] 3 views exported (no combine).")
        return None

    combined_svg = combined_svg or os.path.join(out_dir, "combined_views.svg")
    combined_png = combined_png or os.path.join(out_dir, "combined_views.png")

    # Sheet layout, left to right: front / top / left.
    combine_svgs(
        view1_path=view_paths["front"],
        view2_path=view_paths["top"],
        view3_path=view_paths["left"],
        output_svg=combined_svg,
        output_image=combined_png
    )
    return combined_svg, combined_png
|
||
|
||
|
||
if __name__ == "__main__":
    # Step -1: image(s) -> 3D asset (GLB + PLY) with TRELLIS
    glb_result, ply_result = image_to_3d(image=["assets/whiteboard_exported_image.png"])
    print(f"glb_result : {glb_result}", f"ply_result : {ply_result}")

    # Step 0: GLB -> OBJ via Blender
    obj_result = glb_to_obj(glb_result)
    print(f"obj_result : {obj_result}")

    # Step 1: OBJ -> STEP via FreeCAD. obj_to_step reports failure by returning a
    # falsy value instead of raising, so stop here rather than feed it onward.
    step_result = obj_to_step(input_obj=obj_result, output_dir="step_output", script_path="1_obj_to_step.py")
    if not step_result:
        raise SystemExit("OBJ -> STEP conversion failed; see the logged FreeCAD error.")
    print(f"step_result : {step_result}")

    # Step 2: STEP -> three-view SVG sheet (+ PNG)
    out_dir = "svg_output"
    combined_svg, combined_png = step_to_svg(step_path=step_result, out_dir=out_dir)
    print(f"combined_svg : {combined_svg}", f"combined_png : {combined_png}")
|