Files
FiDA-3D-Trellis/server.py
2026-04-13 14:48:52 +08:00

443 lines
14 KiB
Python
Executable File
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import mimetypes
import os
import secrets
import subprocess
import tempfile
import uuid
from datetime import datetime
from io import BytesIO
import imageio
import numpy as np
import trimesh
import litserve as ls
from minio import Minio
from glb2svg import glb_to_obj, obj_to_step, step_to_svg
from trellis.pipelines import TrellisImageTo3DPipeline
from trellis.utils import render_utils, postprocessing_utils
from utils.new_oss_client import MINIO_URL, MINIO_ACCESS, MINIO_SECRET, MINIO_SECURE, minio_get_image, upload_local_file, download_from_minio
# Shared MinIO client used by every endpoint in this module; connection
# settings come from utils.new_oss_client.
minio_client = Minio(
    MINIO_URL,
    access_key=MINIO_ACCESS,
    secret_key=MINIO_SECRET,
    secure=MINIO_SECURE,
)
def generate_unique_name(original_name: str) -> str:
    """Return *original_name* with a timestamp and random tag inserted before its extension.

    The result has the form ``<stem>_<YYYYmmdd_HHMMSS>_<8 hex chars><ext>``,
    e.g. ``sample.glb`` -> ``sample_20240101_120000_1a2b3c4d.glb``.
    """
    base, suffix = os.path.splitext(original_name)
    pieces = (
        base,
        datetime.now().strftime("%Y%m%d_%H%M%S"),
        secrets.token_hex(4),  # 4 random bytes -> 8 hexadecimal characters
    )
    return "_".join(pieces) + suffix
def load_mesh(file_path):
    """Load a ``.obj``, ``.glb`` or ``.gltf`` model and return its vertices.

    Args:
        file_path: Path to the model file. The format is chosen from the
            file extension (case-insensitive).

    Returns:
        numpy.ndarray: Vertex array of shape ``(N, 3)``. When the file loads
        as a ``trimesh.Scene``, the vertices of every geometry in
        ``scene.geometry`` are stacked as stored there; the scene graph's node
        transforms are not applied.

    Raises:
        ValueError: If the extension is unsupported or no vertices are found.
    """
    file_ext = os.path.splitext(file_path)[1].lower()
    if file_ext not in ('.obj', '.glb', '.gltf'):
        raise ValueError(f"不支持的文件格式: {file_ext},仅支持.obj和.glb/.gltf")

    # Pass the real format to trimesh: a .gltf file is JSON (plus external
    # buffers), so forcing the binary 'glb' loader on it would fail.
    mesh = trimesh.load(file_path, file_type=file_ext[1:])

    if isinstance(mesh, trimesh.Scene):
        # A scene may hold several geometries; merge all of their vertices.
        parts = [np.asarray(geom.vertices) for geom in mesh.geometry.values()]
        # np.vstack([]) raises an unhelpful error for an empty scene; use an
        # empty array so the explicit "no vertices" check below fires instead.
        vertices = np.vstack(parts) if parts else np.empty((0, 3))
    else:
        # A single geometry.
        vertices = mesh.vertices

    if len(vertices) == 0:
        raise ValueError("文件中未找到顶点数据")
    return vertices
def analyze_mesh(file_path):
    """Load a 3D model and summarise its vertex geometry.

    Args:
        file_path: Path to the model file (``.obj``, ``.glb`` or ``.gltf``;
            see ``load_mesh``).

    Returns:
        dict: Keys ``file_format``, ``vertex_count``, ``centroid`` (mean of
        the vertices), ``bounding_box_min``, ``bounding_box_max``, ``size``
        (bounding-box extent per axis), ``size_ratio`` (each extent divided by
        the sum of extents) and ``size_ratio_percentage``. Array values are
        converted with ``.tolist()`` so the dict holds plain Python lists.

    Raises:
        ValueError: Propagated from ``load_mesh`` for unsupported formats or
            files without vertices.
    """
    vertices = load_mesh(file_path)

    # Axis-aligned bounding box: per-axis minimum and maximum.
    min_coords = vertices.min(axis=0)
    max_coords = vertices.max(axis=0)
    centroid = vertices.mean(axis=0)
    size = max_coords - min_coords

    # Share of each axis in the summed extent.
    total_size = np.sum(size)
    if total_size != 0:
        size_ratio = size / total_size
    else:
        # Degenerate model (all vertices coincide). The previous fallback was a
        # plain list, which crashed on .tolist() and made `* 100` repeat the
        # list instead of scaling it; a numpy zero array behaves like `size`.
        size_ratio = np.zeros_like(size, dtype=float)

    return {
        "file_format": os.path.splitext(file_path)[1].lower(),
        "vertex_count": len(vertices),
        "centroid": centroid.tolist(),
        "bounding_box_min": min_coords.tolist(),
        "bounding_box_max": max_coords.tolist(),
        "size": size.tolist(),
        "size_ratio": size_ratio.tolist(),
        "size_ratio_percentage": (size_ratio * 100).tolist(),
    }
def render_glb_preview(glb_path, output_path, *, blender_bin="blender",
                       script_path="render_model.py", timeout=None):
    """Render a still preview image of a model with headless Blender.

    Runs ``<blender_bin> --background --python-exit-code 1 --python
    <script_path> -- <glb_path> <output_path>``. The render script receives
    the two paths after ``--``.

    Args:
        glb_path: Path of the model to render.
        output_path: Where the render script should write the image. Its
            parent directory is created if needed.
        blender_bin: Blender executable. Default: ``"blender"`` from PATH.
        script_path: Blender Python script that performs the render. A
            relative path is resolved against the current working directory.
        timeout: Seconds to wait for Blender. ``None`` waits indefinitely.

    Returns:
        str: ``output_path``.

    Raises:
        RuntimeError: Blender exits non-zero, or exits cleanly without
            creating ``output_path``.
        subprocess.TimeoutExpired: ``timeout`` elapsed before Blender exited.
    """
    out_dir = os.path.dirname(output_path)
    if out_dir:  # os.makedirs("") raises for a bare file name
        os.makedirs(out_dir, exist_ok=True)

    cmd = [
        blender_bin,
        "--background",
        # Without this flag Blender exits with status 0 even when the Python
        # script raises, so render failures would go unnoticed.
        # It must come before --python.
        "--python-exit-code", "1",
        "--python", script_path,
        "--",
        glb_path,
        output_path,
    ]
    # errors="replace": Blender's console output is not guaranteed to be valid
    # in the locale encoding, and a decode error here would mask the real failure.
    result = subprocess.run(cmd, capture_output=True, text=True,
                            errors="replace", timeout=timeout)
    if result.returncode != 0:
        raise RuntimeError(
            f"Blender render failed\nstdout:\n{result.stdout}\nstderr:\n{result.stderr}"
        )
    if not os.path.exists(output_path):
        raise RuntimeError(
            f"Blender render produced no output at {output_path}\n"
            f"stdout:\n{result.stdout}\nstderr:\n{result.stderr}"
        )
    return output_path
class TrellisAPI(ls.LitAPI):
    """LitServe endpoint: image(s) in MinIO -> TRELLIS 3D model -> GLB + preview PNG in MinIO.

    Request keys read by ``decode_request``:
        image_paths: non-empty list of ``"bucket/object"`` paths of input images.
        bucket_name, user_id: destination bucket and per-user object prefix.
        model: ``"single"`` (default) uses only the first image; any other
            value passes all images to ``run_multi_image``.
        seed, steps_sparse, cfg_sparse, steps_slat, cfg_slat, simplify,
        texture_size, fps: optional generation parameters.

    Response: ``{"glb_path", "glb_static_img_path", "glb_info"}``.
    """

    def setup(self, device):
        """Load the TRELLIS image-to-3D pipeline and move it to *device*."""
        # NOTE(review): upstream TRELLIS appears to read SPCONV_ALGO when its
        # sparse modules are imported. trellis is imported at the top of this
        # file, so setting the variable here may have no effect. Confirm, and
        # if so, set it before the trellis import.
        os.environ.setdefault("SPCONV_ALGO", "native")
        self.pipeline = TrellisImageTo3DPipeline.from_pretrained(
            "microsoft/TRELLIS-image-large"
        )
        self.pipeline.to(device)

    def decode_request(self, request):
        """Fetch the input images from MinIO and collect the generation parameters.

        Returns:
            tuple: ``(images, params)``. ``images`` is in request order, as
            returned by ``minio_get_image``. ``params`` is the dict consumed by
            ``predict``; a fresh random ``file_name`` is generated for the output.

        Raises:
            KeyError: A required request key is missing.
            ValueError: ``image_paths`` is empty, or a path is not ``bucket/object``.
        """
        image_paths = request["image_paths"]
        bucket_name = request["bucket_name"]
        user_id = request["user_id"]
        if not image_paths:
            # predict() indexes images[0]; fail early with a clear message.
            raise ValueError("image_paths must contain at least one 'bucket/object' path")

        images = []
        for path in image_paths:
            bucket, sep, object_name = path.partition('/')
            if not (sep and bucket and object_name):
                raise ValueError(f"invalid image path {path!r}: expected 'bucket/object'")
            images.append(minio_get_image(minio_client, bucket, object_name))

        params = {
            "file_name": uuid.uuid4().hex,
            "model": request.get("model", "single"),
            "seed": request.get("seed", 1),
            "steps_sparse": request.get("steps_sparse", 12),
            "cfg_sparse": request.get("cfg_sparse", 7.5),
            "steps_slat": request.get("steps_slat", 12),
            "cfg_slat": request.get("cfg_slat", 3.0),
            "simplify": request.get("simplify", 0.95),
            "texture_size": request.get("texture_size", 1024),
            "fps": request.get("fps", 30),
            "bucket_name": bucket_name,
            "user_id": user_id,
        }
        return images, params

    def predict(self, inputs):
        """Generate a 3D model, upload it as GLB, analyse it and upload a preview render.

        Args:
            inputs: ``(images, params)`` as returned by ``decode_request``.

        Returns:
            dict: ``glb_path`` (MinIO path of the GLB), ``glb_static_img_path``
            (MinIO path of the preview PNG, or None if that upload failed) and
            ``glb_info`` (see ``analyze_mesh``).
        """
        images, params = inputs
        sampler_kwargs = dict(
            seed=params["seed"],
            sparse_structure_sampler_params={
                "steps": params["steps_sparse"],
                "cfg_strength": params["cfg_sparse"],
            },
            slat_sampler_params={
                "steps": params["steps_slat"],
                "cfg_strength": params["cfg_slat"],
            },
        )
        if params["model"] == "single":
            outputs = self.pipeline.run(images[0], **sampler_kwargs)
        else:
            outputs = self.pipeline.run_multi_image(images, **sampler_kwargs)

        minio_glb_path, local_glb_path = self.upload_glb(outputs, params)
        try:
            glb_info = analyze_mesh(local_glb_path)
            static_model_image = self.get_static_model_image(
                model_path=local_glb_path, params=params
            )
        finally:
            # The local GLB is only needed for analysis and the preview render.
            # Remove it so glb_output/ does not grow without bound.
            self._remove_local_file(local_glb_path)
        return {
            "glb_path": minio_glb_path,
            "glb_static_img_path": static_model_image,
            "glb_info": glb_info,
        }

    def encode_response(self, output):
        """Return the ``predict`` result unchanged."""
        return output

    # Video (upload_video) and PLY (upload_ply) export helpers were previously
    # kept here as commented-out code that called an undefined `upload_bytes`.
    # They were removed; recover them from version control if needed.

    def upload_glb(self, outputs, params):
        """Export the generated model as GLB locally and upload it to MinIO.

        Args:
            outputs: Pipeline result. ``outputs["gaussian"][0]`` and
                ``outputs["mesh"][0]`` are passed to ``postprocessing_utils.to_glb``.
            params: Dict from ``decode_request`` (uses bucket_name, user_id,
                file_name, simplify, texture_size).

        Returns:
            tuple: ``(minio_glb_path, local_glb_path)``. The first value is
            whatever the module-level ``upload_local_file`` helper returns.
        """
        minio_path = f"{params['bucket_name']}/{params['user_id']}/3d_result/{params['file_name']}.glb"
        local_glb_path = os.path.join("glb_output", generate_unique_name("sample.glb"))
        os.makedirs(os.path.dirname(local_glb_path), exist_ok=True)

        glb = postprocessing_utils.to_glb(
            outputs["gaussian"][0],
            outputs["mesh"][0],
            simplify=params['simplify'],
            texture_size=params['texture_size'],
        )
        glb.export(file_obj=local_glb_path, file_type="glb")

        # Module-level helper imported from utils.new_oss_client. It is not
        # the method of the same name defined on this class.
        glb_path = upload_local_file(
            file_path=local_glb_path,
            minio_path=minio_path,
            content_type="application/octet-stream",
        )
        return glb_path, local_glb_path

    def get_static_model_image(self, model_path, params):
        """Render a PNG preview of *model_path* with Blender and upload it.

        Returns:
            str | None: MinIO ``bucket/object`` path of the preview, or None
            if the upload failed (see ``upload_local_file``).

        Raises:
            RuntimeError: Propagated from ``render_glb_preview`` when rendering fails.
        """
        local_static_model_image_path = os.path.join(
            "glb_output", generate_unique_name("static_model_image.png")
        )
        print(f"model_path : {model_path}")
        print(f"local_static_model_image_path :{local_static_model_image_path}")
        try:
            output_path = render_glb_preview(model_path, local_static_model_image_path)
            static_model_image = self.upload_local_file(
                output_path,
                params['bucket_name'],
                params['user_id'],
            )
        finally:
            # The rendered PNG is uploaded (or failed); do not keep it on disk.
            if os.path.exists(local_static_model_image_path):
                self._remove_local_file(local_static_model_image_path)
        print(f"Saved to {static_model_image}")
        return static_model_image

    def upload_local_file(self, local_path, bucket_name, user_id):
        """Upload *local_path* to ``<bucket_name>/<user_id>/3d_result/<uuid><ext>``.

        The object extension comes from the local file (``.png`` if it has
        none). The Content-Type is guessed from the file name, falling back to
        ``application/octet-stream``. Works for SVG, PNG, OBJ, etc.

        Returns:
            str | None: ``"bucket/object"`` on success. None if the file does
            not exist or the upload fails; failures are printed, not raised
            (best-effort upload).
        """
        if not os.path.exists(local_path):
            print(f"错误: 文件 {local_path} 不存在")
            return None
        # Previously every object was named .png regardless of its content.
        ext = os.path.splitext(local_path)[1] or ".png"
        object_name = f"{user_id}/3d_result/{uuid.uuid4().hex}{ext}"

        # Infer the Content-Type from the extension, e.g. .svg -> image/svg+xml, .png -> image/png.
        content_type, _ = mimetypes.guess_type(local_path)
        if content_type is None:
            content_type = "application/octet-stream"
        try:
            minio_client.fput_object(
                bucket_name=bucket_name,
                object_name=object_name,
                file_path=local_path,
                content_type=content_type,
            )
        except Exception as e:
            print(f"上传失败: {e}")
            return None
        print(f"成功上传 [{content_type}]: {object_name}")
        return f"{bucket_name}/{object_name}"

    @staticmethod
    def _remove_local_file(path):
        """Best-effort delete of a temporary local file; failures are only printed."""
        try:
            os.remove(path)
        except OSError as e:
            print(f"warning: could not remove temporary file {path}: {e}")
class ModelToThreeViews(ls.LitAPI):
    """LitServe endpoint: GLB in MinIO -> OBJ -> STEP -> three-view drawing uploaded to MinIO.

    Request keys read by ``predict``: ``minio_glb_path`` (passed as-is to
    ``download_from_minio``), ``bucket_name`` and ``user_id``.
    Response: ``{"minio_svg_path": <"bucket/object" path or None>}``.
    """

    def setup(self, device):
        """Nothing to load; the conversion steps are external helper functions."""
        pass

    def upload_local_file(self, local_path, type, bucket_name, user_id):
        """Upload *local_path* to ``<bucket_name>/<user_id>/3d_result/<type>/<uuid><ext>``.

        Args:
            local_path: File to upload (SVG, PNG, OBJ, ...).
            type: Folder name under ``3d_result/``. It is also the fallback
                extension when *local_path* has none. (The name shadows the
                builtin; it is kept for compatibility with keyword callers.)
            bucket_name, user_id: Destination bucket and per-user prefix.

        Returns:
            str | None: ``"bucket/object"`` on success. None if the file does
            not exist or the upload fails; failures are printed, not raised
            (best-effort upload).
        """
        if not os.path.exists(local_path):
            print(f"错误: 文件 {local_path} 不存在")
            return None
        # Take the extension from the file actually uploaded, so the object
        # name agrees with its content and Content-Type. Previously a PNG could
        # be stored as "<uuid>.svg".
        ext = os.path.splitext(local_path)[1] or f".{type}"
        object_name = f"{user_id}/3d_result/{type}/{uuid.uuid4().hex}{ext}"

        # Infer the Content-Type from the extension, e.g. .svg -> image/svg+xml, .png -> image/png.
        content_type, _ = mimetypes.guess_type(local_path)
        if content_type is None:
            content_type = "application/octet-stream"
        try:
            minio_client.fput_object(
                bucket_name=bucket_name,
                object_name=object_name,
                file_path=local_path,
                content_type=content_type,
            )
        except Exception as e:
            print(f"上传失败: {e}")
            return None
        print(f"成功上传 [{content_type}]: {object_name}")
        return f"{bucket_name}/{object_name}"

    def predict(self, request):
        """Convert a GLB stored in MinIO into a three-view drawing and upload it.

        Pipeline: download GLB -> glb_to_obj -> obj_to_step -> step_to_svg -> upload.
        Intermediate files live in a per-request scratch directory under
        ``glb_to_obj/``, which is removed when the request finishes.

        Returns:
            dict: ``{"minio_svg_path": ...}``, the ``bucket/object`` path of the
            uploaded drawing, or None if the upload failed.
        """
        import shutil  # local import: only this method needs it

        minio_glb_path = request['minio_glb_path']
        bucket_name = request["bucket_name"]
        user_id = request["user_id"]

        # One scratch directory per request. step/ and svg/ used to be shared
        # by all requests, which could let concurrent conversions overwrite
        # each other's files if the helpers use fixed output names.
        work_dir = os.path.join("glb_to_obj", uuid.uuid4().hex)
        glb_path = os.path.join(work_dir, f"model{uuid.uuid4().hex}.glb")
        step_dir = os.path.join(work_dir, "step")
        svg_dir = os.path.join(work_dir, "svg")
        os.makedirs(step_dir, exist_ok=True)
        os.makedirs(svg_dir, exist_ok=True)
        print(
            f"\n入参阶段:\ninput glb-obj minio-path:{minio_glb_path},\n"
            f"work_dir : {work_dir},glb_path : {glb_path},"
            f"step_dir : {step_dir},svg_dir : {svg_dir}\n"
        )

        try:
            # 1. Download the GLB from MinIO.
            print("=" * 10)
            print(" 第一阶段 下载glb文件: ")
            glb_result = download_from_minio(object_path=minio_glb_path, local_path=glb_path)
            print(f" 下载结果 : {glb_result} \n")

            # 2. GLB -> OBJ.
            print("=" * 10)
            print(" 第二阶段 glb -> obj: ")
            obj_result = glb_to_obj(glb_result)
            print(f" glb -> obj 结果 : {obj_result} \n")

            # 3. OBJ -> STEP, via the external conversion script.
            print("=" * 10)
            print(" 第三阶段 obj -> step: ")
            step_result = obj_to_step(
                input_obj=obj_result,
                output_dir=step_dir,
                script_path="1_obj_to_step.py",
            )
            print(f" obj -> step 结果 : {step_result} \n")

            # 4. STEP -> combined three-view SVG plus a PNG rendering of it.
            print("=" * 10)
            print(" 第四阶段 step -> svg: ")
            combined_svg, combined_png = step_to_svg(
                step_path=step_result,
                out_dir=svg_dir,
            )
            print(f" step -> svg 结果 : {combined_svg} \n")
            print("=" * 10)

            # 5. Upload.
            # NOTE(review): this uploads the PNG rendering, although the folder
            # argument and the response key say "svg". Kept as-is for client
            # compatibility; confirm which artifact clients expect.
            minio_svg_path = self.upload_local_file(combined_png, "svg", bucket_name, user_id)
        finally:
            # Drop all intermediates (GLB/OBJ/STEP/SVG/PNG) of this request.
            shutil.rmtree(work_dir, ignore_errors=True)
        return {"minio_svg_path": minio_svg_path}
if __name__ == "__main__":
    # Serve both endpoints from a single LitServer on port 8122.
    endpoints = [
        TrellisAPI(api_path="/canvas/img_to_3D"),
        ModelToThreeViews(api_path="/canvas/3d_to_3views"),
    ]
    server = ls.LitServer(endpoints)
    server.run(port=8122)