commit 59570f8812
Author: zcr
Date: 2026-03-17 11:28:52 +08:00
45 changed files with 5308 additions and 0 deletions

.gitignore (vendored, new file, 437 lines)

@@ -0,0 +1,437 @@
## Ignore Visual Studio temporary files, build results, and
## files generated by popular Visual Studio add-ons.
##
## Get latest from https://github.com/github/gitignore/blob/main/VisualStudio.gitignore
# User-specific files
*.rsuser
*.suo
*.user
*.userosscache
*.sln.docstates
# User-specific files (MonoDevelop/Xamarin Studio)
*.userprefs
# Mono auto generated files
mono_crash.*
# Build results
[Dd]ebug/
[Dd]ebugPublic/
[Rr]elease/
[Rr]eleases/
x64/
x86/
[Ww][Ii][Nn]32/
[Aa][Rr][Mm]/
[Aa][Rr][Mm]64/
bld/
[Bb]in/
[Oo]bj/
[Ll]og/
[Ll]ogs/
# Visual Studio 2015/2017 cache/options directory
.vs/
# Uncomment if you have tasks that create the project's static files in wwwroot
#wwwroot/
# Visual Studio 2017 auto generated files
Generated\ Files/
# MSTest test Results
[Tt]est[Rr]esult*/
[Bb]uild[Ll]og.*
# NUnit
*.VisualState.xml
TestResult.xml
nunit-*.xml
# Build Results of an ATL Project
[Dd]ebugPS/
[Rr]eleasePS/
dlldata.c
# Benchmark Results
BenchmarkDotNet.Artifacts/
# .NET Core
project.lock.json
project.fragment.lock.json
artifacts/
# ASP.NET Scaffolding
ScaffoldingReadMe.txt
# StyleCop
StyleCopReport.xml
# Files built by Visual Studio
*_i.c
*_p.c
*_h.h
*.ilk
*.meta
*.obj
*.iobj
*.pch
*.pdb
*.ipdb
*.pgc
*.pgd
*.rsp
*.sbr
*.tlb
*.tli
*.tlh
*.tmp
*.tmp_proj
*_wpftmp.csproj
*.log
*.tlog
*.vspscc
*.vssscc
.builds
*.pidb
*.svclog
*.scc
# Chutzpah Test files
_Chutzpah*
# Visual C++ cache files
ipch/
*.aps
*.ncb
*.opendb
*.opensdf
*.sdf
*.cachefile
*.VC.db
*.VC.VC.opendb
# Visual Studio profiler
*.psess
*.vsp
*.vspx
*.sap
# Visual Studio Trace Files
*.e2e
# TFS 2012 Local Workspace
$tf/
# Guidance Automation Toolkit
*.gpState
# ReSharper is a .NET coding add-in
_ReSharper*/
*.[Rr]e[Ss]harper
*.DotSettings.user
# TeamCity is a build add-in
_TeamCity*
# DotCover is a Code Coverage Tool
*.dotCover
# AxoCover is a Code Coverage Tool
.axoCover/*
!.axoCover/settings.json
# Coverlet is a free, cross platform Code Coverage Tool
coverage*.json
coverage*.xml
coverage*.info
# Visual Studio code coverage results
*.coverage
*.coveragexml
# NCrunch
_NCrunch_*
.*crunch*.local.xml
nCrunchTemp_*
# MightyMoose
*.mm.*
AutoTest.Net/
# Web workbench (sass)
.sass-cache/
# Installshield output folder
[Ee]xpress/
# DocProject is a documentation generator add-in
DocProject/buildhelp/
DocProject/Help/*.HxT
DocProject/Help/*.HxC
DocProject/Help/*.hhc
DocProject/Help/*.hhk
DocProject/Help/*.hhp
DocProject/Help/Html2
DocProject/Help/html
# Click-Once directory
publish/
# Publish Web Output
*.[Pp]ublish.xml
*.azurePubxml
# Note: Comment the next line if you want to checkin your web deploy settings,
# but database connection strings (with potential passwords) will be unencrypted
*.pubxml
*.publishproj
# Microsoft Azure Web App publish settings. Comment the next line if you want to
# checkin your Azure Web App publish settings, but sensitive information contained
# in these scripts will be unencrypted
PublishScripts/
# NuGet Packages
*.nupkg
# NuGet Symbol Packages
*.snupkg
# The packages folder can be ignored because of Package Restore
**/[Pp]ackages/*
# except build/, which is used as an MSBuild target.
!**/[Pp]ackages/build/
# Uncomment if necessary however generally it will be regenerated when needed
#!**/[Pp]ackages/repositories.config
# NuGet v3's project.json files produces more ignorable files
*.nuget.props
*.nuget.targets
# Microsoft Azure Build Output
csx/
*.build.csdef
# Microsoft Azure Emulator
ecf/
rcf/
# Windows Store app package directories and files
AppPackages/
BundleArtifacts/
Package.StoreAssociation.xml
_pkginfo.txt
*.appx
*.appxbundle
*.appxupload
# Visual Studio cache files
# files ending in .cache can be ignored
*.[Cc]ache
# but keep track of directories ending in .cache
!?*.[Cc]ache/
# Others
ClientBin/
~$*
*~
*.dbmdl
*.dbproj.schemaview
*.jfm
*.pfx
*.publishsettings
orleans.codegen.cs
# Including strong name files can present a security risk
# (https://github.com/github/gitignore/pull/2483#issue-259490424)
#*.snk
# Since there are multiple workflows, uncomment next line to ignore bower_components
# (https://github.com/github/gitignore/pull/1529#issuecomment-104372622)
#bower_components/
# RIA/Silverlight projects
Generated_Code/
# Backup & report files from converting an old project file
# to a newer Visual Studio version. Backup files are not needed,
# because we have git ;-)
_UpgradeReport_Files/
Backup*/
UpgradeLog*.XML
UpgradeLog*.htm
ServiceFabricBackup/
*.rptproj.bak
# SQL Server files
*.mdf
*.ldf
*.ndf
# Business Intelligence projects
*.rdl.data
*.bim.layout
*.bim_*.settings
*.rptproj.rsuser
*- [Bb]ackup.rdl
*- [Bb]ackup ([0-9]).rdl
*- [Bb]ackup ([0-9][0-9]).rdl
# Microsoft Fakes
FakesAssemblies/
# GhostDoc plugin setting file
*.GhostDoc.xml
# Node.js Tools for Visual Studio
.ntvs_analysis.dat
node_modules/
# Visual Studio 6 build log
*.plg
# Visual Studio 6 workspace options file
*.opt
# Visual Studio 6 auto-generated workspace file (contains which files were open etc.)
*.vbw
# Visual Studio 6 auto-generated project file (contains which files were open etc.)
*.vbp
# Visual Studio 6 workspace and project file (working project files containing files to include in project)
*.dsw
*.dsp
# Visual Studio 6 technical files
*.ncb
*.aps
# Visual Studio LightSwitch build output
**/*.HTMLClient/GeneratedArtifacts
**/*.DesktopClient/GeneratedArtifacts
**/*.DesktopClient/ModelManifest.xml
**/*.Server/GeneratedArtifacts
**/*.Server/ModelManifest.xml
_Pvt_Extensions
# Paket dependency manager
.paket/paket.exe
paket-files/
# FAKE - F# Make
.fake/
# CodeRush personal settings
.cr/personal
# Python Tools for Visual Studio (PTVS)
__pycache__/
*.pyc
# Cake - Uncomment if you are using it
# tools/**
# !tools/packages.config
# Tabs Studio
*.tss
# Telerik's JustMock configuration file
*.jmconfig
# BizTalk build output
*.btp.cs
*.btm.cs
*.odx.cs
*.xsd.cs
# OpenCover UI analysis results
OpenCover/
# Azure Stream Analytics local run output
ASALocalRun/
# MSBuild Binary and Structured Log
*.binlog
# NVidia Nsight GPU debugger configuration file
*.nvuser
# MFractors (Xamarin productivity tool) working folder
.mfractor/
# Local History for Visual Studio
.localhistory/
# Visual Studio History (VSHistory) files
.vshistory/
# BeatPulse healthcheck temp database
healthchecksdb
# Backup folder for Package Reference Convert tool in Visual Studio 2017
MigrationBackup/
# Ionide (cross platform F# VS Code tools) working folder
.ionide/
# Fody - auto-generated XML schema
FodyWeavers.xsd
# VS Code files for those working on multiple tools
.vscode/*
!.vscode/settings.json
!.vscode/tasks.json
!.vscode/launch.json
!.vscode/extensions.json
*.code-workspace
# Local History for Visual Studio Code
.history/
# Windows Installer files from build outputs
*.cab
*.msi
*.msix
*.msm
*.msp
# JetBrains Rider
*.sln.iml
# Project-specific binary assets
*.glb
*.ply
*.mtl
*.step
*.ipynb
# Model weights (important: keep these out of the repo)
*.pth
*.ckpt
*.bin
*.pt
*.h5
# Archives
*.zip
*.rar
*.7z
*.tar.gz
# Images/videos: ignore unless needed; if needed, track via Git LFS later
*.png
*.svg
*.jpg
*.jpeg
*.mp4
*.avi
# Logs/caches/temp files
*.log
*.tmp
__pycache__/
*.pyc
.DS_Store
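# Hypothetical example, if large media must be versioned after all
# (requires git-lfs to be installed; patterns are illustrative):
#   git lfs install
#   git lfs track "*.png" "*.mp4"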

.gitmodules (vendored, new file, 3 lines)

@@ -0,0 +1,3 @@
[submodule "trellis/representations/mesh/flexicubes"]
path = trellis/representations/mesh/flexicubes
url = https://github.com/MaxtirError/FlexiCubes.git
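After cloning, the FlexiCubes submodule must be fetched explicitly; the standard command is:

    git submodule update --init --recursive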

.idea/.gitignore (generated, vendored, new file, 10 lines)

@@ -0,0 +1,10 @@
# Files ignored by default
/shelf/
/workspace.xml
# Default folder for query files (already ignored)
/queries/
# Datasource local storage ignored files
/dataSources/
/dataSources.local.xml
# Editor-based HTTP Client requests
/httpRequests/

0_glb_to_obj.py (new file, 68 lines)

@@ -0,0 +1,68 @@
import bpy
import sys
import os
import argparse
def parse_args():
# Blender passes script arguments after "--"; everything before that belongs to Blender itself.
argv = sys.argv
argv = argv[argv.index("--") + 1:] if "--" in argv else []
p = argparse.ArgumentParser()
p.add_argument("-i", "--input", required=True)
p.add_argument("-o", "--output", required=True)
return p.parse_args(argv)
def clean():
bpy.ops.object.select_all(action='SELECT')
bpy.ops.object.delete(use_global=False)
def main():
args = parse_args()
clean()
bpy.ops.import_scene.gltf(filepath=args.input) # supports .glb/.gltf
# Keep only mesh objects and join them into one (matches the single-asset assumption)
meshes = [o for o in bpy.context.scene.objects if o.type == "MESH"]
if not meshes:
raise RuntimeError("No mesh objects found in GLB/GLTF.")
bpy.ops.object.select_all(action='DESELECT')
for o in meshes:
o.select_set(True)
bpy.context.view_layer.objects.active = meshes[0]
if len(meshes) > 1:
bpy.ops.object.join()
obj = bpy.context.view_layer.objects.active
obj.select_set(True)
out_dir = os.path.dirname(args.output)
if out_dir:
os.makedirs(out_dir, exist_ok=True)
# Blender 4.x alternative (the legacy exporter below was removed in Blender 4.0):
# bpy.ops.wm.obj_export(
# filepath=args.output,
# export_selected_objects=True,
# export_normals=True,
# export_uv=True,
# )
bpy.ops.export_scene.obj(
filepath=args.output,
use_selection=True, # export only the selected objects
use_normals=True,
use_uvs=True,
use_materials=True, # change as needed (the original flow did not export materials)
axis_forward='-Z', # glTF is typically -Z forward, Y up; adjust as needed
axis_up='Y',
# keep_vertex_order=True, # optional: keep vertex order
)
if __name__ == "__main__":
main()
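Because this script is meant to run inside Blender, a shell invocation goes through `blender --background`; a minimal sketch, assuming `blender` is on PATH and with illustrative paths:

    import subprocess

    subprocess.run(
        [
            "blender", "--background", "--python", "0_glb_to_obj.py",
            "--",  # arguments after "--" reach parse_args() above
            "-i", "trellis_out/sample.glb",
            "-o", "obj_output/sample.obj",
        ],
        check=True,  # raise if Blender exits non-zero
    )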

0_remesh.py (new file, 260 lines)

@@ -0,0 +1,260 @@
import bpy
import sys
import os
import argparse
import tempfile
def clean_scene():
bpy.ops.object.select_all(action='SELECT')
bpy.ops.object.delete(use_global=False)
def _parse_args_blender(argv):
"""Blender 参数在 `--` 之后。"""
if "--" in argv:
return argv[argv.index("--") + 1:]
return []
def _is_glb_gltf(path: str) -> bool:
ext = os.path.splitext(path)[1].lower()
return ext in [".glb", ".gltf"]
def _is_obj(path: str) -> bool:
return os.path.splitext(path)[1].lower() == ".obj"
def import_mesh_any(input_path: str):
"""
Import an OBJ or GLB/GLTF, join all MESH objects in the scene into a single active mesh object, and return it.
"""
clean_scene()
if _is_glb_gltf(input_path):
print(f"[import] GLB/GLTF: {input_path}")
bpy.ops.import_scene.gltf(filepath=input_path)
elif _is_obj(input_path):
print(f"[import] OBJ: {input_path}")
bpy.ops.wm.obj_import(filepath=input_path)
else:
raise RuntimeError(f"Unsupported input format: {input_path}")
meshes = [o for o in bpy.context.scene.objects if o.type == "MESH"]
if not meshes:
raise RuntimeError("No mesh objects found after import.")
bpy.ops.object.select_all(action='DESELECT')
for o in meshes:
o.select_set(True)
bpy.context.view_layer.objects.active = meshes[0]
if len(meshes) > 1:
bpy.ops.object.join()
obj = bpy.context.view_layer.objects.active
obj.select_set(True)
return obj
def export_obj_selected(output_obj: str):
out_dir = os.path.dirname(output_obj)
if out_dir:
os.makedirs(out_dir, exist_ok=True)
bpy.ops.wm.obj_export(
filepath=output_obj,
export_selected_objects=True,
export_normals=True,
export_uv=True,
)
def fix_mesh(obj):
"""修复网格使其更符合 QuadriFlow 的输入要求"""
bpy.context.view_layer.objects.active = obj
obj.select_set(True)
bpy.ops.object.mode_set(mode='EDIT')
bpy.ops.mesh.select_all(action='SELECT')
bpy.ops.mesh.remove_doubles(threshold=0.001)
bpy.ops.mesh.delete_loose()
bpy.ops.mesh.select_all(action='DESELECT')
bpy.ops.mesh.select_non_manifold()
# Edit-mode selection is not written back to obj.data until leaving Edit Mode,
# so toggle to Object Mode to count reliably, then return.
bpy.ops.object.mode_set(mode='OBJECT')
selected_count = sum(1 for v in obj.data.vertices if v.select)
bpy.ops.object.mode_set(mode='EDIT')
if selected_count > 0:
print(f"[fix_mesh] found {selected_count} non-manifold verts, deleting...")
bpy.ops.mesh.delete(type='VERT')
bpy.ops.mesh.select_all(action='SELECT')
bpy.ops.mesh.normals_make_consistent(inside=False)
bpy.ops.mesh.select_all(action='SELECT')
bpy.ops.mesh.fill_holes(sides=0)
bpy.ops.mesh.remove_doubles(threshold=0.001)
bpy.ops.mesh.dissolve_degenerate(threshold=0.0001)
bpy.ops.mesh.select_all(action='SELECT')
bpy.ops.mesh.normals_make_consistent(inside=False)
bpy.ops.object.mode_set(mode='OBJECT')
def quadriflow_remesh_obj(
input_obj: str,
output_obj: str,
face_count: int = 8000,
use_mesh_symmetry: bool = False,
use_preserve_sharp: bool = False,
use_preserve_boundary: bool = True,
use_voxel_preprocess: bool = True,
voxel_size: float = 0.008,
):
"""只处理 OBJ 输入(你的原流程)"""
obj = import_mesh_any(input_obj) # 这里会走 OBJ import
print(f"[import] object={obj.name}, verts={len(obj.data.vertices)}, faces={len(obj.data.polygons)}")
print("[mesh] fixing...")
fix_mesh(obj)
if use_voxel_preprocess:
print(f"[voxel] preprocess voxel_size={voxel_size} ...")
voxel_mod = obj.modifiers.new(name="Voxel", type='REMESH')
voxel_mod.mode = 'VOXEL'
voxel_mod.voxel_size = voxel_size
voxel_mod.use_smooth_shade = True
voxel_mod.use_remove_disconnected = False
bpy.ops.object.modifier_apply(modifier=voxel_mod.name)
print(f"[voxel] after: verts={len(obj.data.vertices)}, faces={len(obj.data.polygons)}")
print("[mesh] fixing after voxel...")
fix_mesh(obj)
print(f"[mesh] after fix2: verts={len(obj.data.vertices)}, faces={len(obj.data.polygons)}")
print(f"[quadriflow] target_faces={face_count} ...")
bpy.context.view_layer.objects.active = obj
obj.select_set(True)
try:
bpy.ops.object.quadriflow_remesh(
use_mesh_symmetry=use_mesh_symmetry,
use_preserve_sharp=use_preserve_sharp,
use_preserve_boundary=use_preserve_boundary,
smooth_normals=False,
mode='FACES',
target_faces=face_count,
seed=0
)
print("[quadriflow] done.")
except Exception as e:
print(f"[quadriflow] error: {e} (continue to export)")
print(f"[export] {output_obj}")
export_obj_selected(output_obj)
print(f"[done] verts={len(obj.data.vertices)}, faces={len(obj.data.polygons)}")
def glb_to_tmp_obj(input_glb: str) -> str:
"""
Import a GLB/GLTF in Blender, then export it to a temporary OBJ (convert to OBJ first, then run the rest of the pipeline).
"""
obj = import_mesh_any(input_glb) # takes the glTF import path
tmp_obj = tempfile.mktemp(suffix=".obj") # mktemp is race-prone, but acceptable for this single-user pipeline
print(f"[glb2obj] export tmp obj -> {tmp_obj}")
export_obj_selected(tmp_obj)
return tmp_obj
def run_pipeline(
input_path: str,
output_obj: str,
face_count: int,
mesh_symmetry: bool,
preserve_sharp: bool,
preserve_boundary: bool,
no_voxel_preprocess: bool,
voxel_size: float,
):
"""
Unified entry point:
- OBJ input: remesh directly
- GLB/GLTF input: convert to a temporary OBJ first, then remesh
"""
tmp_obj = None
try:
if _is_glb_gltf(input_path):
tmp_obj = glb_to_tmp_obj(input_path)
quadriflow_remesh_obj(
input_obj=tmp_obj,
output_obj=output_obj,
face_count=face_count,
use_mesh_symmetry=mesh_symmetry,
use_preserve_sharp=preserve_sharp,
use_preserve_boundary=preserve_boundary,
use_voxel_preprocess=(not no_voxel_preprocess),
voxel_size=voxel_size,
)
elif _is_obj(input_path):
quadriflow_remesh_obj(
input_obj=input_path,
output_obj=output_obj,
face_count=face_count,
use_mesh_symmetry=mesh_symmetry,
use_preserve_sharp=preserve_sharp,
use_preserve_boundary=preserve_boundary,
use_voxel_preprocess=(not no_voxel_preprocess),
voxel_size=voxel_size,
)
else:
raise RuntimeError(f"Unsupported input format: {input_path}")
finally:
if tmp_obj and os.path.exists(tmp_obj):
try:
os.remove(tmp_obj)
print(f"[cleanup] removed tmp obj: {tmp_obj}")
except Exception:
pass
def main():
parser = argparse.ArgumentParser(description="GLB/GLTF/OBJ -> (optional tmp OBJ) -> QuadriFlow remesh -> OBJ")
parser.add_argument("-i", "--input", default="trellis_out/sample.glb", help="Input file path (.obj/.glb/.gltf).")
parser.add_argument("-o", "--output", default="output/test.obj", help="Output OBJ file path.")
parser.add_argument("--face_count", type=int, default=8000, help="Target quad face count.")
parser.add_argument("--mesh_symmetry", action="store_true", help="Enable mesh symmetry.")
parser.add_argument("--preserve_sharp", action="store_true", help="Preserve sharp edges.")
parser.add_argument("--preserve_boundary", action="store_true", help="Preserve boundary edges.")
parser.add_argument("--no_voxel_preprocess", action="store_true", help="Disable voxel preprocess.")
parser.add_argument("--voxel_size", type=float, default=0.008, help="Voxel size for preprocess (smaller=denser).")
args = parser.parse_args(_parse_args_blender(sys.argv))
if not os.path.exists(args.input):
raise FileNotFoundError(f"Input not found: {args.input}")
run_pipeline(
input_path=args.input,
output_obj=args.output,
face_count=args.face_count,
mesh_symmetry=args.mesh_symmetry,
preserve_sharp=args.preserve_sharp,
preserve_boundary=args.preserve_boundary,
no_voxel_preprocess=args.no_voxel_preprocess,
voxel_size=args.voxel_size,
)
if __name__ == "__main__":
try:
main()
except Exception as e:
print(f"[fatal] {e}")
import traceback
traceback.print_exc()
raise
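A typical headless invocation, assuming `blender` is on PATH (the flags map to the argparse options above; paths are placeholders):

    blender --background --python 0_remesh.py -- -i trellis_out/sample.glb -o output/test.obj --face_count 8000 --preserve_boundary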

1_obj_to_step.py (new file, 423 lines)

@@ -0,0 +1,423 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Fast OBJ/GLB/GLTF -> STEP converter for FreeCAD (headless-friendly)
Parameters:
tol      : mesh-to-shape fitting tolerance; larger is faster (default 0.5)
sew_tol  : sewing tolerance (skipped automatically if Sewing is unavailable; default 0.05)
solid    : 1 = try to solidify into a Solid (slow); 0 = export shells only (default 0)
split    : 1 = split/clean components with trimesh (if installed); 0 = off (default 1)
max_comp : maximum number of components kept after splitting (default 64)
Notes:
- If the FreeCAD build lacks MeshPart.meshToShape or Part.Sewing, the script falls back or skips automatically.
- If STEP export still fails, a .brep is written as a fallback so pythonocc/OCP can convert it to STEP later.
"""
import secrets
import sys, os, time, tempfile
from datetime import datetime
import FreeCAD as App
import Part, Mesh
# Try to load MeshPart (some FreeCAD builds do not ship it)
try:
import MeshPart
HAS_MESHPART = True
except Exception:
MeshPart = None
HAS_MESHPART = False
# Optional: trimesh, for more robust splitting/cleaning and GLB/GLTF -> temporary OBJ conversion
try:
import trimesh
HAS_TRIMESH = True
except Exception:
HAS_TRIMESH = False
def log(msg):
try:
App.Console.PrintMessage(str(msg) + "\n")
except Exception:
print(msg, flush=True)
# ---------------------- Mesh cleaning / conversion core ----------------------
def _freecad_mesh_quick_clean(fc_mesh):
"""FreeCAD 轻量清理:不大改拓扑,提升稳健性。"""
for fn in (
"removeDuplicatedFacets",
"removeDuplicatedPoints",
"removeDegeneratedFacets",
"removeNonManifoldEdges",
):
try:
getattr(fc_mesh, fn)()
except Exception:
pass
return fc_mesh
def _mesh_to_shape_fast(fc_mesh, tol=0.5, sew_tol=0.05, to_solid=False):
"""
Mesh -> B-Rep (robust version):
- prefer MeshPart.meshToShape when available (coplanar aggregation, usually faster)
- otherwise fall back to Part.Shape.makeShapeFromMesh
- skip Sewing when unavailable (some builds lack Part.Sewing)
- attempt Solidify only when to_solid=True and the shell is closed (expensive)
"""
# 1) mesh -> shape
shape = None
if HAS_MESHPART and hasattr(MeshPart, "meshToShape"):
try:
shape = MeshPart.meshToShape(fc_mesh, tol)
if isinstance(shape, tuple): # some versions return (shape, mapping)
shape = shape[0]
except Exception as e:
log(f"meshToShape 失败,退回 makeShapeFromMesh: {e}")
if shape is None:
shape = Part.Shape()
shape.makeShapeFromMesh(fc_mesh.Topology, tol)
# 2) sewing (skipped if Part.Sewing is unavailable)
try:
SewingCls = getattr(Part, "Sewing", None)
if SewingCls is not None:
sew = SewingCls()
ok = False
for setter in ("tolerance", "SetTolerance"):
try:
if setter == "tolerance":
sew.tolerance = sew_tol
else:
sew.SetTolerance(sew_tol)
ok = True
break
except Exception:
continue
sew.add(shape)
sew.perform()
new_shape = None
if hasattr(sew, "SewedShape"):
new_shape = sew.SewedShape
try:
if hasattr(new_shape, "isNull") and new_shape.isNull():
new_shape = None
except Exception:
pass
if new_shape is None and hasattr(sew, "sewedShape"):
try:
new_shape = sew.sewedShape()
except Exception:
pass
if new_shape is None and hasattr(sew, "sewShape"):
try:
new_shape = sew.sewShape()
except Exception:
pass
if new_shape is not None:
shape = new_shape
else:
log("跳过 Sewing当前 FreeCAD Part 模块缺少 Sewing 绑定")
except Exception as e:
log(f"Sewing 失败,继续原 shape: {e}")
# 3) 可选固化
if to_solid:
try:
if not getattr(shape, "Solids", []):
shell = Part.Shell(shape.Faces)
if hasattr(shell, "isClosed") and shell.isClosed():
shape = Part.makeSolid(shell)
except Exception as e:
log(f"Solidify 失败,保留壳: {e}")
# 4) 轻量拓扑精简
try:
shape = shape.removeSplitter()
except Exception:
pass
try:
if hasattr(shape, "fixTolerance"):
shape.fixTolerance(sew_tol)
except Exception:
pass
return shape
def _export_shapes_to_step(shapes, step_file_path):
"""
Robust export:
- ensure the output directory exists
- filter out invalid shapes
- try merging into a Compound (often more stable)
- if Part.export fails, try Import.export; as a last resort, write a .brep
"""
out_dir = os.path.dirname(step_file_path) or "."
os.makedirs(out_dir, exist_ok=True)
valid = []
for s in shapes:
try:
if hasattr(s, "isNull") and s.isNull():
continue
if hasattr(s, "Faces") and len(s.Faces) == 0:
continue
valid.append(s)
except Exception:
continue
if not valid:
raise RuntimeError("没有可导出的有效 shape可能网格太脏或全部构造失败")
try:
compound = Part.makeCompound(valid)
except Exception:
compound = valid[0]
doc = App.newDocument("Obj2StepFast")
try:
if hasattr(doc, "suppressRecompute"):
doc.suppressRecompute()
obj = doc.addObject("Part::Feature", "Compound")
obj.Shape = compound
if hasattr(doc, "recompute"):
doc.recompute()
try:
Part.export([obj], step_file_path)
return
except Exception as e1:
log(f"Part.export 失败:{e1}")
try:
import Import
Import.export([obj], step_file_path)
return
except Exception as e2:
log(f"Import.export 也失败:{e2}")
brep_path = os.path.splitext(step_file_path)[0] + ".brep"
try:
obj.Shape.exportBrep(brep_path)
raise RuntimeError(
f"STEP 导出失败,已兜底写入 BREP{brep_path}\n"
f"可用 pythonocc/OCP 将 BREP 转为 STEP。"
)
except Exception as e3:
raise RuntimeError(f"STEP 与 BREP 均导出失败:{e3}")
finally:
App.closeDocument(doc.Name)
# ---------------------- Top-level conversion logic ----------------------
def _glb_gltf_to_tmp_obj_if_needed(in_path):
ext = os.path.splitext(in_path)[1].lower()
if ext in (".glb", ".gltf"):
if not HAS_TRIMESH:
raise RuntimeError("输入为 GLB/GLTF但 FreeCAD Python 环境未安装 trimesh。请安装后再试。")
mesh = trimesh.load(in_path, force="mesh")
if mesh.is_empty:
raise RuntimeError(f"GLB/GLTF 读取失败或为空:{in_path}")
tmp = tempfile.mktemp(suffix=".obj")
mesh.export(tmp)
return tmp, True
return in_path, False
# New: generate a unique filename with timestamp + random suffix (avoids overwriting)
def generate_unique_step_filename(base_name: str, output_dir: str) -> str:
"""
Example of the generated filename:
model_20260316_131245_8f4a2c1d.step
"""
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
random_part = secrets.token_hex(4) # 8 random hex characters
filename = f"{base_name}_{timestamp}_{random_part}.step"
return os.path.join(output_dir, filename)
def obj_to_step_fast(input_path,
step_file_path=None,
tol=0.5,
sew_tol=0.05,
to_solid=False,
split_components=True,
max_components=64):
t0 = time.perf_counter()
if not os.path.exists(input_path):
raise FileNotFoundError(input_path)
base = os.path.splitext(os.path.basename(input_path))[0]
# ── decide the final output path ──
if step_file_path is None:
# no path given -> place a uniquely named file next to the input
out_dir = os.path.dirname(input_path)
step_file_path = generate_unique_step_filename(base, out_dir)
elif os.path.isdir(step_file_path):
# caller passed a directory -> generate a unique filename inside it
step_file_path = generate_unique_step_filename(base, step_file_path)
else:
# caller gave an explicit .step/.stp path -> use it as-is (no random suffix)
if not step_file_path.lower().endswith((".step", ".stp")):
step_file_path += ".step"
# no extra uniqueness handling here; respect the caller's intent (they may want to overwrite)
log(f"目标 STEP 文件:{step_file_path}")
# 下面部分保持原样 ........................................
shapes = []
tmp_created_paths = []
src_for_freecad, is_tmp = _glb_gltf_to_tmp_obj_if_needed(input_path)
if is_tmp:
tmp_created_paths.append(src_for_freecad)
if HAS_TRIMESH and split_components:
m = None
try:
m = trimesh.load(src_for_freecad, force="mesh")
except Exception as e:
log(f"trimesh 读取失败(转纯 FreeCAD 路径):{e}")
m = None
if m is not None and not m.is_empty:
try:
m.remove_duplicate_faces()
m.remove_degenerate_faces()
m.remove_unreferenced_vertices()
m.fix_normals()
except Exception:
pass
parts = m.split(only_watertight=False)
if len(parts) == 0:
parts = [m]
if len(parts) > max_components:
parts = sorted(parts, key=lambda x: x.faces.shape[0], reverse=True)[:max_components]
log(f"Trimesh分块{len(parts)} 个组件")
for i, pm in enumerate(parts):
tmp_obj = tempfile.mktemp(suffix=f"_{i}.obj")
pm.export(tmp_obj)
tmp_created_paths.append(tmp_obj)
fc_mesh = Mesh.Mesh(tmp_obj)
_freecad_mesh_quick_clean(fc_mesh)
shp = _mesh_to_shape_fast(fc_mesh, tol=tol, sew_tol=sew_tol, to_solid=to_solid)
shapes.append(shp)
else:
fc_mesh = Mesh.Mesh(src_for_freecad)
log(f"载入Mesh: {src_for_freecad}, 三角数={len(fc_mesh.Facets)}")
_freecad_mesh_quick_clean(fc_mesh)
shp = _mesh_to_shape_fast(fc_mesh, tol=tol, sew_tol=sew_tol, to_solid=to_solid)
shapes.append(shp)
else:
fc_mesh = Mesh.Mesh(src_for_freecad)
log(f"载入Mesh: {src_for_freecad}, 三角数={len(fc_mesh.Facets)}")
_freecad_mesh_quick_clean(fc_mesh)
shp = _mesh_to_shape_fast(fc_mesh, tol=tol, sew_tol=sew_tol, to_solid=to_solid)
shapes.append(shp)
_export_shapes_to_step(shapes, step_file_path)
for p in tmp_created_paths:
try:
os.remove(p)
except Exception:
pass
t1 = time.perf_counter()
log(f"STEP导出: {step_file_path} (耗时 {t1 - t0:.2f}s")
return step_file_path
def main(obj_file_path,
output_dir=None,
tol=0.5,
sew_tol=0.05,
to_solid=False,
split_components=True,
max_components=64):
if output_dir is None:
output_dir = os.path.dirname(obj_file_path)
os.makedirs(output_dir, exist_ok=True)
base = os.path.splitext(os.path.basename(obj_file_path))[0]
# build step_path with a unique filename
step_path = generate_unique_step_filename(base, output_dir)
return obj_to_step_fast(
obj_file_path,
step_file_path=step_path, # pass the unique path here
tol=tol,
sew_tol=sew_tol,
to_solid=to_solid,
split_components=split_components,
max_components=max_components,
)
# ── CLI adjusted to match ──
def _parse_cli(argv):
in_path = os.path.abspath(argv[1])
out_arg = os.path.abspath(argv[2]) if len(argv) > 2 else None
tol = float(argv[3]) if len(argv) > 3 else 0.5
sew_tol = float(argv[4]) if len(argv) > 4 else 0.05
solid = bool(int(argv[5])) if len(argv) > 5 else False
split = bool(int(argv[6])) if len(argv) > 6 else True
max_comp = int(argv[7]) if len(argv) > 7 else 64
return in_path, out_arg, tol, sew_tol, solid, split, max_comp
# ---------------------- CLI entry ----------------------
if __name__ == "__main__":
in_path, out_arg, tol, sew_tol, solid, split, max_comp = _parse_cli(sys.argv)
# Decide whether the second argument is a directory or a .step file (it may be omitted)
out_is_step = out_arg is not None and out_arg.lower().endswith((".step", ".stp"))
if (not out_is_step) or os.path.isdir(out_arg or ""):
out_dir = out_arg # may be None; main() then falls back to the input's directory
if out_dir:
os.makedirs(out_dir, exist_ok=True)
result = main(
in_path,
output_dir=out_dir,
tol=tol,
sew_tol=sew_tol,
to_solid=solid,
split_components=split,
max_components=max_comp,
)
else:
out_dir = os.path.dirname(out_arg)
if out_dir:
os.makedirs(out_dir, exist_ok=True)
result = obj_to_step_fast(
in_path,
step_file_path=out_arg,
tol=tol,
sew_tol=sew_tol,
to_solid=solid,
split_components=split,
max_components=max_comp,
)
log("完成!")
log(f"STEP: {result}")

2_step_to_svg.py (new file, 386 lines)

@@ -0,0 +1,386 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
STEP -> OCC triangulate -> PyTorch3D orthographic 3 views -> SVG (silhouette + visible/hidden feature edges)
Then combine 3 SVGs into a single SVG and a PNG.
Views:
- top : from +Z towards origin (view direction -Z)
- left : from -X towards origin (view direction +X)
- front: from +Y towards origin (view direction -Y)
Outputs (default in --out_dir):
- top.svg, left.svg, front.svg
- combined_views.svg
- combined_views.png
"""
import os, sys, math, argparse, tempfile
import numpy as np
import torch
from collections import defaultdict
# ---------- PyTorch3D ----------
from pytorch3d.structures import Meshes
from pytorch3d.renderer import (
RasterizationSettings, MeshRasterizer,
MeshRenderer, BlendParams, SoftSilhouetteShader,
look_at_view_transform, OrthographicCameras
)
from pytorch3d.renderer.mesh.rasterizer import Fragments
# ---------- SVG / Image ----------
import svgwrite
from skimage import measure
# ---------- pythonocc (OCC) ----------
from OCC.Core.STEPControl import STEPControl_Reader
from OCC.Core.TopAbs import TopAbs_FACE
from OCC.Core.TopExp import TopExp_Explorer
from OCC.Core.BRep import BRep_Tool
from OCC.Core.BRepMesh import BRepMesh_IncrementalMesh
from OCC.Core.TopLoc import TopLoc_Location
from OCC.Core.ShapeFix import ShapeFix_Shape
# ---------- Combine SVG ----------
import svgutils.transform as st
from svgutils.compose import Unit
import cairosvg
# ---------------- Utils ----------------
def autodevice():
return torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
def read_step_solid(path: str):
reader = STEPControl_Reader()
stat = reader.ReadFile(path)
if stat != 1:
raise RuntimeError(f"STEP 读取失败: {path}")
reader.TransferRoots()
shape = reader.OneShape()
fixer = ShapeFix_Shape(shape)
fixer.Perform()
return fixer.Shape()
def triangulate(shape, deflection=0.5, angle=0.5):
"""OCC 三角化:返回 (V(np.float32 N×3), F(np.int64 M×3))"""
BRepMesh_IncrementalMesh(shape, deflection, False, angle, True)
verts, faces, v_off = [], [], 0
exp = TopExp_Explorer(shape, TopAbs_FACE)
while exp.More():
face = exp.Current()
loc = TopLoc_Location()
tri = BRep_Tool.Triangulation(face, loc)
if tri is not None:
nb_nodes = tri.NbNodes()
has_nodes_arr = hasattr(tri, "Nodes")
for i in range(1, nb_nodes + 1):
p = tri.Nodes().Value(i) if has_nodes_arr else tri.Node(i)
p = p.Transformed(loc.Transformation())
verts.append([p.X(), p.Y(), p.Z()])
nb_tris = tri.NbTriangles()
has_tris_arr = hasattr(tri, "Triangles")
for i in range(1, nb_tris + 1):
t = tri.Triangles().Value(i) if has_tris_arr else tri.Triangle(i)
a, b, c = t.Get()
faces.append([v_off + a - 1, v_off + b - 1, v_off + c - 1])
v_off += nb_nodes
exp.Next()
if not verts or not faces:
raise RuntimeError("三角化为空,尝试减小 --defl")
return np.asarray(verts, np.float32), np.asarray(faces, np.int64)
def normalize_mesh_np(verts: np.ndarray, unit_scale=1.0):
c = verts.mean(axis=0, keepdims=True)
v0 = verts - c
s = np.max(np.abs(v0))
s = max(s, 1e-12)
return v0 / s * unit_scale
def build_mesh(V_np, F_np, device):
V = torch.from_numpy(V_np).to(device)
F = torch.from_numpy(F_np).to(device)
return Meshes(verts=[V], faces=[F])
def compute_face_normals(verts: torch.Tensor, faces: torch.Tensor):
v0 = verts[faces[:, 0]]
v1 = verts[faces[:, 1]]
v2 = verts[faces[:, 2]]
n = torch.cross(v1 - v0, v2 - v0, dim=1)
return torch.nn.functional.normalize(n, dim=1, eps=1e-12)
def extract_feature_edges(faces: torch.Tensor, face_normals: torch.Tensor, angle_deg: float):
"""二面角>=angle_deg 视为特征棱;边界边必取。"""
thr = math.cos(math.radians(max(0.0, angle_deg)))
e2f = defaultdict(list)
F = faces.shape[0]
for f in range(F):
i, j, k = faces[f].tolist()
for a, b in ((i, j), (j, k), (k, i)):
e = (a, b) if a < b else (b, a)
e2f[e].append(f)
feat = []
for e, fl in e2f.items():
if len(fl) == 1:
feat.append(e)
elif len(fl) == 2:
n0, n1 = face_normals[fl[0]], face_normals[fl[1]]
cosv = torch.clamp((n0 * n1).sum(), -1.0, 1.0).item()
if cosv <= thr:
feat.append(e)
return feat
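# For intuition: the default --ang of 65 degrees gives thr = cos(65°) ≈ 0.423, so an
# interior edge is kept when its two face normals satisfy (n0 · n1) <= 0.423, i.e. the
# faces meet at a dihedral angle of 65° or more; lowering --ang keeps more edges.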
def raster_fragments(mesh: Meshes, cameras, image_size=1600, faces_per_pixel=1):
rs = RasterizationSettings(
image_size=image_size,
blur_radius=0.0,
faces_per_pixel=faces_per_pixel,
cull_backfaces=True
)
rast = MeshRasterizer(cameras=cameras, raster_settings=rs)
with torch.no_grad():
frags: Fragments = rast(mesh, cameras=cameras)
return frags, frags.zbuf[0, ..., 0]
def render_silhouette(mesh: Meshes, cameras, image_size=1600):
rs = RasterizationSettings(
image_size=image_size,
blur_radius=1e-6,
faces_per_pixel=50,
cull_backfaces=True
)
renderer = MeshRenderer(
rasterizer=MeshRasterizer(cameras=cameras, raster_settings=rs),
shader=SoftSilhouetteShader(blend_params=BlendParams(sigma=1e-4, gamma=1e-4))
)
with torch.no_grad():
img = renderer(mesh, cameras=cameras) # (1,H,W,4)
return np.clip(img[0, ..., 3].detach().cpu().numpy(), 0.0, 1.0)
def project_points_to_screen(cameras, points_world: torch.Tensor, image_size: int):
with torch.no_grad():
scr = cameras.transform_points_screen(
points_world[None, ...],
image_size=((image_size, image_size),)
)
return scr[0] # (N,3)
def edge_visibility_split(edges_idx, verts_world, cameras, zbuf, image_size, eps=1e-4):
H = W = image_size
vis, hid = [], []
scr = project_points_to_screen(cameras, verts_world, image_size)
zmin = zbuf.detach().cpu().numpy()
for i, j in edges_idx:
p0, p1 = scr[i], scr[j]
m = 0.5 * (p0 + p1)
x = int(np.clip(m[0].item(), 0, W - 1))
y = int(np.clip(m[1].item(), 0, H - 1))
z_proj = m[2].item()
z_ref = zmin[y, x]
seg = (p0[0].item(), p0[1].item(), p1[0].item(), p1[1].item())
if np.isfinite(z_ref) and (z_proj <= z_ref + eps):
vis.append(seg)
else:
hid.append(seg)
return vis, hid
def trace_silhouettes(alpha: np.ndarray, threshold=0.5, step=1):
contours = measure.find_contours(alpha, level=threshold)
polys = []
for cnt in contours:
if len(cnt) < 8:
continue
cnt = cnt[::max(1, step)]
polys.append(np.stack([cnt[:, 1], cnt[:, 0]], axis=1)) # (x,y)
return polys
def svg_from_view(out_svg, W, H, silhouettes, edges_vis, edges_hid, stroke=2.0, margin=10, fill="none"):
dwg = svgwrite.Drawing(out_svg, size=(W + 2 * margin, H + 2 * margin))
dwg.add(dwg.rect(insert=(0, 0), size=(W + 2 * margin, H + 2 * margin), fill="none"))
# silhouette fill (optional)
for poly in silhouettes:
if len(poly) < 3:
continue
path = [f"M {poly[0, 0] + margin:.2f} {poly[0, 1] + margin:.2f}"]
for i in range(1, len(poly)):
path.append(f"L {poly[i, 0] + margin:.2f} {poly[i, 1] + margin:.2f}")
path.append("Z")
dwg.add(dwg.path(" ".join(path), fill=fill, stroke="none"))
# visible edges
for (x0, y0, x1, y1) in edges_vis:
dwg.add(dwg.line(
(x0 + margin, y0 + margin),
(x1 + margin, y1 + margin),
stroke="black",
stroke_width=stroke
))
# hidden edges
for (x0, y0, x1, y1) in edges_hid:
dwg.add(dwg.line(
(x0 + margin, y0 + margin),
(x1 + margin, y1 + margin),
stroke="black",
stroke_width=max(1.0, 0.75 * stroke),
stroke_dasharray=[6, 6],
opacity=0.85
))
dwg.save()
# ---------------- Cameras ----------------
def make_camera(view: str, device, dist=2.7):
"""
look_at_view_transform spherical params for engineering views.
"""
if view == "top":
# from +Z to origin (view -Z)
R, T = look_at_view_transform(dist=dist, elev=90.0, azim=180.0, device=device)
elif view == "left":
# from -X to origin (view +X)
R, T = look_at_view_transform(dist=dist, elev=0.0, azim=270.0, device=device)
elif view == "front":
# from +Y to origin (view -Y)
R, T = look_at_view_transform(dist=dist, elev=0.0, azim=180.0, device=device)
else:
raise ValueError("view must be one of ['top','left','front']")
return OrthographicCameras(R=R, T=T, device=device)
# ---------------- Combine 3 SVGs ----------------
def combine_svgs(view1_path, view2_path, view3_path, output_svg, output_image):
"""
Tile the 3 SVGs horizontally and additionally export a PNG.
"""
out_dir = os.path.dirname(output_svg)
if out_dir:
os.makedirs(out_dir, exist_ok=True)
view1 = st.fromfile(view1_path)
view2 = st.fromfile(view2_path)
view3 = st.fromfile(view3_path)
r1 = view1.getroot()
r2 = view2.getroot()
r3 = view3.getroot()
w1, h1 = [float(x.replace("px", "")) for x in view1.get_size()]
w2, h2 = [float(x.replace("px", "")) for x in view2.get_size()]
w3, h3 = [float(x.replace("px", "")) for x in view3.get_size()]
max_w = max(w1, w2, w3)
max_h = max(h1, h2, h3)
combined_width = max_w * 3 + 40
combined_height = max_h + 20
combined = st.SVGFigure(Unit(combined_width), Unit(combined_height))
x_offset = 10
y_offset = 10
r1.moveto(x_offset, y_offset)
x_offset += max_w + 20
r2.moveto(x_offset, y_offset)
x_offset += max_w + 20
r3.moveto(x_offset, y_offset)
combined.append([r1, r2, r3])
combined.save(output_svg)
# SVG -> PNG
cairosvg.svg2png(url=output_svg, write_to=output_image)
# ---------------- Main ----------------
def main():
ap = argparse.ArgumentParser()
ap.add_argument("--step_path", type=str, required=False, help="输入 STEP 文件", default="output/sample_clean.step")
ap.add_argument("--out_dir", type=str, required=False, help="输出目录", default="output")
ap.add_argument("--res", type=int, default=1400, help="渲染分辨率(正方形)")
ap.add_argument("--defl", type=float, default=1.5, help="三角化线性偏差")
ap.add_argument("--ang", type=float, default=65.0, help="二面角阈值(°)")
ap.add_argument("--stroke", type=float, default=1.3, help="SVG 线宽(px)")
ap.add_argument("--scale", type=float, default=1.0, help="归一化后模型最大边长")
ap.add_argument("--no_combine", action="store_true", help="只导出三视图 SVG不合成")
ap.add_argument("--combined_svg", type=str, default=None, help="合成 SVG 输出路径(默认 out_dir/combined_views.svg")
ap.add_argument("--combined_png", type=str, default=None, help="合成 PNG 输出路径(默认 out_dir/combined_views.png")
args = ap.parse_args()
os.makedirs(args.out_dir, exist_ok=True)
device = autodevice()
print(f"[Device] {device}")
# STEP -> mesh
shape = read_step_solid(args.step_path)
V_np, F_np = triangulate(shape, deflection=args.defl)
V_np = normalize_mesh_np(V_np, unit_scale=args.scale)
mesh = build_mesh(V_np, F_np, device)
verts = mesh.verts_packed()
faces = mesh.faces_packed()
fn = compute_face_normals(verts, faces)
feat_edges = extract_feature_edges(faces, fn, angle_deg=args.ang)
view_paths = {}
for name in ["top", "left", "front"]:
print(f"[View] {name}")
cam = make_camera(name, device)
alpha = render_silhouette(mesh, cam, image_size=args.res)
_, zbuf = raster_fragments(mesh, cam, image_size=args.res, faces_per_pixel=1)
e_vis, e_hid = edge_visibility_split(feat_edges, verts, cam, zbuf, args.res, eps=1e-4)
polys = trace_silhouettes(alpha, threshold=0.5, step=1)
out_svg = os.path.join(args.out_dir, f"{name}.svg")
svg_from_view(out_svg, args.res, args.res, polys, e_vis, e_hid,
stroke=args.stroke, margin=10, fill="none")
view_paths[name] = out_svg
print(f" -> {out_svg} (edges: vis={len(e_vis)}, hid={len(e_hid)}, polys={len(polys)})")
if args.no_combine:
print("[Done] 3 views exported (no combine).")
return
combined_svg = args.combined_svg or os.path.join(args.out_dir, "combined_views.svg")
combined_png = args.combined_png or os.path.join(args.out_dir, "combined_views.png")
# Note: the combine order follows the second example: front / top / left
combine_svgs(
view1_path=view_paths["front"],
view2_path=view_paths["top"],
view3_path=view_paths["left"],
output_svg=combined_svg,
output_image=combined_png
)
print(f"[Combined] SVG: {combined_svg}")
print(f"[Combined] PNG: {combined_png}")
print("[Done] All outputs exported.")
if __name__ == "__main__":
main()
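A typical invocation (paths are placeholders; requires pytorch3d, pythonocc, svgwrite, svgutils, and cairosvg, as imported above):

    python 2_step_to_svg.py --step_path output/sample_clean.step --out_dir output --res 1400 --ang 65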

app.py (new file, 403 lines)

@@ -0,0 +1,403 @@
import gradio as gr
from gradio_litmodel3d import LitModel3D
import os
import shutil
from typing import *
import torch
import numpy as np
import imageio
from easydict import EasyDict as edict
from PIL import Image
from trellis.pipelines import TrellisImageTo3DPipeline
from trellis.representations import Gaussian, MeshExtractResult
from trellis.utils import render_utils, postprocessing_utils
MAX_SEED = np.iinfo(np.int32).max
TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
os.makedirs(TMP_DIR, exist_ok=True)
def start_session(req: gr.Request):
user_dir = os.path.join(TMP_DIR, str(req.session_hash))
os.makedirs(user_dir, exist_ok=True)
def end_session(req: gr.Request):
user_dir = os.path.join(TMP_DIR, str(req.session_hash))
shutil.rmtree(user_dir)
def preprocess_image(image: Image.Image) -> Image.Image:
"""
Preprocess the input image.
Args:
image (Image.Image): The input image.
Returns:
Image.Image: The preprocessed image.
"""
processed_image = pipeline.preprocess_image(image)
return processed_image
def preprocess_images(images: List[Tuple[Image.Image, str]]) -> List[Image.Image]:
"""
Preprocess a list of input images.
Args:
images (List[Tuple[Image.Image, str]]): The input images.
Returns:
List[Image.Image]: The preprocessed images.
"""
images = [image[0] for image in images]
processed_images = [pipeline.preprocess_image(image) for image in images]
return processed_images
def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
return {
'gaussian': {
**gs.init_params,
'_xyz': gs._xyz.cpu().numpy(),
'_features_dc': gs._features_dc.cpu().numpy(),
'_scaling': gs._scaling.cpu().numpy(),
'_rotation': gs._rotation.cpu().numpy(),
'_opacity': gs._opacity.cpu().numpy(),
},
'mesh': {
'vertices': mesh.vertices.cpu().numpy(),
'faces': mesh.faces.cpu().numpy(),
},
}
def unpack_state(state: dict) -> Tuple[Gaussian, edict]:
gs = Gaussian(
aabb=state['gaussian']['aabb'],
sh_degree=state['gaussian']['sh_degree'],
mininum_kernel_size=state['gaussian']['mininum_kernel_size'],
scaling_bias=state['gaussian']['scaling_bias'],
opacity_bias=state['gaussian']['opacity_bias'],
scaling_activation=state['gaussian']['scaling_activation'],
)
gs._xyz = torch.tensor(state['gaussian']['_xyz'], device='cuda')
gs._features_dc = torch.tensor(state['gaussian']['_features_dc'], device='cuda')
gs._scaling = torch.tensor(state['gaussian']['_scaling'], device='cuda')
gs._rotation = torch.tensor(state['gaussian']['_rotation'], device='cuda')
gs._opacity = torch.tensor(state['gaussian']['_opacity'], device='cuda')
mesh = edict(
vertices=torch.tensor(state['mesh']['vertices'], device='cuda'),
faces=torch.tensor(state['mesh']['faces'], device='cuda'),
)
return gs, mesh
def get_seed(randomize_seed: bool, seed: int) -> int:
"""
Get the random seed.
"""
return np.random.randint(0, MAX_SEED) if randomize_seed else seed
def image_to_3d(
image: Image.Image,
multiimages: List[Tuple[Image.Image, str]],
is_multiimage: bool,
seed: int,
ss_guidance_strength: float,
ss_sampling_steps: int,
slat_guidance_strength: float,
slat_sampling_steps: int,
multiimage_algo: Literal["multidiffusion", "stochastic"],
req: gr.Request,
) -> Tuple[dict, str]:
"""
Convert an image to a 3D model.
Args:
image (Image.Image): The input image.
multiimages (List[Tuple[Image.Image, str]]): The input images in multi-image mode.
is_multiimage (bool): Whether multi-image mode is active.
seed (int): The random seed.
ss_guidance_strength (float): The guidance strength for sparse structure generation.
ss_sampling_steps (int): The number of sampling steps for sparse structure generation.
slat_guidance_strength (float): The guidance strength for structured latent generation.
slat_sampling_steps (int): The number of sampling steps for structured latent generation.
multiimage_algo (Literal["multidiffusion", "stochastic"]): The algorithm for multi-image generation.
Returns:
dict: The information of the generated 3D model.
str: The path to the video of the 3D model.
"""
user_dir = os.path.join(TMP_DIR, str(req.session_hash))
if not is_multiimage:
outputs = pipeline.run(
image,
seed=seed,
formats=["gaussian", "mesh"],
preprocess_image=False,
sparse_structure_sampler_params={
"steps": ss_sampling_steps,
"cfg_strength": ss_guidance_strength,
},
slat_sampler_params={
"steps": slat_sampling_steps,
"cfg_strength": slat_guidance_strength,
},
)
else:
outputs = pipeline.run_multi_image(
[image[0] for image in multiimages],
seed=seed,
formats=["gaussian", "mesh"],
preprocess_image=False,
sparse_structure_sampler_params={
"steps": ss_sampling_steps,
"cfg_strength": ss_guidance_strength,
},
slat_sampler_params={
"steps": slat_sampling_steps,
"cfg_strength": slat_guidance_strength,
},
mode=multiimage_algo,
)
video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
video_path = os.path.join(user_dir, 'sample.mp4')
imageio.mimsave(video_path, video, fps=15)
state = pack_state(outputs['gaussian'][0], outputs['mesh'][0])
torch.cuda.empty_cache()
return state, video_path
def extract_glb(
state: dict,
mesh_simplify: float,
texture_size: int,
req: gr.Request,
) -> Tuple[str, str]:
"""
Extract a GLB file from the 3D model.
Args:
state (dict): The state of the generated 3D model.
mesh_simplify (float): The mesh simplification factor.
texture_size (int): The texture resolution.
Returns:
str: The path to the extracted GLB file.
str: The same path again, for the download button.
"""
user_dir = os.path.join(TMP_DIR, str(req.session_hash))
gs, mesh = unpack_state(state)
glb = postprocessing_utils.to_glb(gs, mesh, simplify=mesh_simplify, texture_size=texture_size, verbose=False)
glb_path = os.path.join(user_dir, 'sample.glb')
glb.export(glb_path)
torch.cuda.empty_cache()
return glb_path, glb_path
def extract_gaussian(state: dict, req: gr.Request) -> Tuple[str, str]:
"""
Extract a Gaussian file from the 3D model.
Args:
state (dict): The state of the generated 3D model.
Returns:
str: The path to the extracted Gaussian file.
str: The same path again, for the download button.
"""
user_dir = os.path.join(TMP_DIR, str(req.session_hash))
gs, _ = unpack_state(state)
gaussian_path = os.path.join(user_dir, 'sample.ply')
gs.save_ply(gaussian_path)
torch.cuda.empty_cache()
return gaussian_path, gaussian_path
def prepare_multi_example() -> List[Image.Image]:
multi_case = list(set([i.split('_')[0] for i in os.listdir("assets/example_multi_image")]))
images = []
for case in multi_case:
_images = []
for i in range(1, 4):
img = Image.open(f'assets/example_multi_image/{case}_{i}.png')
W, H = img.size
img = img.resize((int(W / H * 512), 512))
_images.append(np.array(img))
images.append(Image.fromarray(np.concatenate(_images, axis=1)))
return images
def split_image(image: Image.Image) -> List[Image.Image]:
"""
Split an image into multiple views.
"""
image = np.array(image)
alpha = image[..., 3]
alpha = np.any(alpha>0, axis=0)
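# A column where alpha turns from transparent to opaque marks a view's left edge;
# opaque to transparent marks its right edge. Each rising/falling pair delimits one sub-image.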
start_pos = np.where(~alpha[:-1] & alpha[1:])[0].tolist()
end_pos = np.where(alpha[:-1] & ~alpha[1:])[0].tolist()
images = []
for s, e in zip(start_pos, end_pos):
images.append(Image.fromarray(image[:, s:e+1]))
return [preprocess_image(image) for image in images]
with gr.Blocks(delete_cache=(600, 600)) as demo:
gr.Markdown("""
## Image to 3D Asset with [TRELLIS](https://trellis3d.github.io/)
* Upload an image and click "Generate" to create a 3D asset. If the image has an alpha channel, it will be used as the mask. Otherwise, we use `rembg` to remove the background.
* If you find the generated 3D asset satisfactory, click "Extract GLB" to extract the GLB file and download it.
""")
with gr.Row():
with gr.Column():
with gr.Tabs() as input_tabs:
with gr.Tab(label="Single Image", id=0) as single_image_input_tab:
image_prompt = gr.Image(label="Image Prompt", format="png", image_mode="RGBA", type="pil", height=300)
with gr.Tab(label="Multiple Images", id=1) as multiimage_input_tab:
multiimage_prompt = gr.Gallery(label="Image Prompt", format="png", type="pil", height=300, columns=3)
gr.Markdown("""
Input different views of the object in separate images.
*NOTE: this is an experimental algorithm without a specially trained model. It may not produce the best results for all images, especially those with different poses or inconsistent details.*
""")
with gr.Accordion(label="Generation Settings", open=False):
seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
gr.Markdown("Stage 1: Sparse Structure Generation")
with gr.Row():
ss_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
ss_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
gr.Markdown("Stage 2: Structured Latent Generation")
with gr.Row():
slat_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=3.0, step=0.1)
slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
multiimage_algo = gr.Radio(["stochastic", "multidiffusion"], label="Multi-image Algorithm", value="stochastic")
generate_btn = gr.Button("Generate")
with gr.Accordion(label="GLB Extraction Settings", open=False):
mesh_simplify = gr.Slider(0.9, 0.98, label="Simplify", value=0.95, step=0.01)
texture_size = gr.Slider(512, 2048, label="Texture Size", value=1024, step=512)
with gr.Row():
extract_glb_btn = gr.Button("Extract GLB", interactive=False)
extract_gs_btn = gr.Button("Extract Gaussian", interactive=False)
gr.Markdown("""
*NOTE: the Gaussian file can be very large (~50MB), so it may take a while to display and download.*
""")
with gr.Column():
video_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300)
model_output = LitModel3D(label="Extracted GLB/Gaussian", exposure=10.0, height=300)
with gr.Row():
download_glb = gr.DownloadButton(label="Download GLB", interactive=False)
download_gs = gr.DownloadButton(label="Download Gaussian", interactive=False)
is_multiimage = gr.State(False)
output_buf = gr.State()
# Example images at the bottom of the page
with gr.Row() as single_image_example:
examples = gr.Examples(
examples=[
f'assets/example_image/{image}'
for image in os.listdir("assets/example_image")
],
inputs=[image_prompt],
fn=preprocess_image,
outputs=[image_prompt],
run_on_click=True,
examples_per_page=64,
)
with gr.Row(visible=False) as multiimage_example:
examples_multi = gr.Examples(
examples=prepare_multi_example(),
inputs=[image_prompt],
fn=split_image,
outputs=[multiimage_prompt],
run_on_click=True,
examples_per_page=8,
)
# Handlers
demo.load(start_session)
demo.unload(end_session)
single_image_input_tab.select(
lambda: tuple([False, gr.Row.update(visible=True), gr.Row.update(visible=False)]),
outputs=[is_multiimage, single_image_example, multiimage_example]
)
multiimage_input_tab.select(
lambda: tuple([True, gr.Row.update(visible=False), gr.Row.update(visible=True)]),
outputs=[is_multiimage, single_image_example, multiimage_example]
)
image_prompt.upload(
preprocess_image,
inputs=[image_prompt],
outputs=[image_prompt],
)
multiimage_prompt.upload(
preprocess_images,
inputs=[multiimage_prompt],
outputs=[multiimage_prompt],
)
generate_btn.click(
get_seed,
inputs=[randomize_seed, seed],
outputs=[seed],
).then(
image_to_3d,
inputs=[image_prompt, multiimage_prompt, is_multiimage, seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps, multiimage_algo],
outputs=[output_buf, video_output],
).then(
lambda: tuple([gr.Button(interactive=True), gr.Button(interactive=True)]),
outputs=[extract_glb_btn, extract_gs_btn],
)
video_output.clear(
lambda: tuple([gr.Button(interactive=False), gr.Button(interactive=False)]),
outputs=[extract_glb_btn, extract_gs_btn],
)
extract_glb_btn.click(
extract_glb,
inputs=[output_buf, mesh_simplify, texture_size],
outputs=[model_output, download_glb],
).then(
lambda: gr.Button(interactive=True),
outputs=[download_glb],
)
extract_gs_btn.click(
extract_gaussian,
inputs=[output_buf],
outputs=[model_output, download_gs],
).then(
lambda: gr.Button(interactive=True),
outputs=[download_gs],
)
model_output.clear(
lambda: gr.Button(interactive=False),
outputs=[download_glb],
)
# Launch the Gradio app
if __name__ == "__main__":
pipeline = TrellisImageTo3DPipeline.from_pretrained("microsoft/TRELLIS-image-large")
pipeline.cuda()
demo.launch()
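For use outside the Gradio UI, the same pipeline calls can be driven directly; a minimal sketch using only the API already exercised above (the input path and seed are illustrative):

    import imageio
    from PIL import Image
    from trellis.pipelines import TrellisImageTo3DPipeline
    from trellis.utils import render_utils

    pipeline = TrellisImageTo3DPipeline.from_pretrained("microsoft/TRELLIS-image-large")
    pipeline.cuda()
    outputs = pipeline.run(Image.open("input.png"), seed=1, formats=["gaussian", "mesh"])
    # Render a turntable preview of the Gaussian output, as image_to_3d() does above.
    frames = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
    imageio.mimsave("preview.mp4", frames, fps=15)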

app_text.py (new file, 266 lines)

@@ -0,0 +1,266 @@
import gradio as gr
from gradio_litmodel3d import LitModel3D
import os
import shutil
from typing import *
import torch
import numpy as np
import imageio
from easydict import EasyDict as edict
from trellis.pipelines import TrellisTextTo3DPipeline
from trellis.representations import Gaussian, MeshExtractResult
from trellis.utils import render_utils, postprocessing_utils
MAX_SEED = np.iinfo(np.int32).max
TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
os.makedirs(TMP_DIR, exist_ok=True)
def start_session(req: gr.Request):
user_dir = os.path.join(TMP_DIR, str(req.session_hash))
os.makedirs(user_dir, exist_ok=True)
def end_session(req: gr.Request):
user_dir = os.path.join(TMP_DIR, str(req.session_hash))
shutil.rmtree(user_dir)
def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
return {
'gaussian': {
**gs.init_params,
'_xyz': gs._xyz.cpu().numpy(),
'_features_dc': gs._features_dc.cpu().numpy(),
'_scaling': gs._scaling.cpu().numpy(),
'_rotation': gs._rotation.cpu().numpy(),
'_opacity': gs._opacity.cpu().numpy(),
},
'mesh': {
'vertices': mesh.vertices.cpu().numpy(),
'faces': mesh.faces.cpu().numpy(),
},
}
def unpack_state(state: dict) -> Tuple[Gaussian, edict]:
gs = Gaussian(
aabb=state['gaussian']['aabb'],
sh_degree=state['gaussian']['sh_degree'],
mininum_kernel_size=state['gaussian']['mininum_kernel_size'],
scaling_bias=state['gaussian']['scaling_bias'],
opacity_bias=state['gaussian']['opacity_bias'],
scaling_activation=state['gaussian']['scaling_activation'],
)
gs._xyz = torch.tensor(state['gaussian']['_xyz'], device='cuda')
gs._features_dc = torch.tensor(state['gaussian']['_features_dc'], device='cuda')
gs._scaling = torch.tensor(state['gaussian']['_scaling'], device='cuda')
gs._rotation = torch.tensor(state['gaussian']['_rotation'], device='cuda')
gs._opacity = torch.tensor(state['gaussian']['_opacity'], device='cuda')
mesh = edict(
vertices=torch.tensor(state['mesh']['vertices'], device='cuda'),
faces=torch.tensor(state['mesh']['faces'], device='cuda'),
)
return gs, mesh
def get_seed(randomize_seed: bool, seed: int) -> int:
"""
Get the random seed.
"""
return np.random.randint(0, MAX_SEED) if randomize_seed else seed
def text_to_3d(
prompt: str,
seed: int,
ss_guidance_strength: float,
ss_sampling_steps: int,
slat_guidance_strength: float,
slat_sampling_steps: int,
req: gr.Request,
) -> Tuple[dict, str]:
"""
Convert a text prompt to a 3D model.
Args:
prompt (str): The text prompt.
seed (int): The random seed.
ss_guidance_strength (float): The guidance strength for sparse structure generation.
ss_sampling_steps (int): The number of sampling steps for sparse structure generation.
slat_guidance_strength (float): The guidance strength for structured latent generation.
slat_sampling_steps (int): The number of sampling steps for structured latent generation.
Returns:
dict: The information of the generated 3D model.
str: The path to the video of the 3D model.
"""
user_dir = os.path.join(TMP_DIR, str(req.session_hash))
outputs = pipeline.run(
prompt,
seed=seed,
formats=["gaussian", "mesh"],
sparse_structure_sampler_params={
"steps": ss_sampling_steps,
"cfg_strength": ss_guidance_strength,
},
slat_sampler_params={
"steps": slat_sampling_steps,
"cfg_strength": slat_guidance_strength,
},
)
video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
video_path = os.path.join(user_dir, 'sample.mp4')
imageio.mimsave(video_path, video, fps=15)
state = pack_state(outputs['gaussian'][0], outputs['mesh'][0])
torch.cuda.empty_cache()
return state, video_path
def extract_glb(
state: dict,
mesh_simplify: float,
texture_size: int,
req: gr.Request,
) -> Tuple[str, str]:
"""
Extract a GLB file from the 3D model.
Args:
state (dict): The state of the generated 3D model.
mesh_simplify (float): The mesh simplification factor.
texture_size (int): The texture resolution.
Returns:
str: The path to the extracted GLB file.
str: The same path again, for the download button.
"""
user_dir = os.path.join(TMP_DIR, str(req.session_hash))
gs, mesh = unpack_state(state)
glb = postprocessing_utils.to_glb(gs, mesh, simplify=mesh_simplify, texture_size=texture_size, verbose=False)
glb_path = os.path.join(user_dir, 'sample.glb')
glb.export(glb_path)
torch.cuda.empty_cache()
return glb_path, glb_path
def extract_gaussian(state: dict, req: gr.Request) -> Tuple[str, str]:
"""
Extract a Gaussian file from the 3D model.
Args:
state (dict): The state of the generated 3D model.
Returns:
str: The path to the extracted Gaussian file.
str: The same path again, for the download button.
"""
user_dir = os.path.join(TMP_DIR, str(req.session_hash))
gs, _ = unpack_state(state)
gaussian_path = os.path.join(user_dir, 'sample.ply')
gs.save_ply(gaussian_path)
torch.cuda.empty_cache()
return gaussian_path, gaussian_path
with gr.Blocks(delete_cache=(600, 600)) as demo:
gr.Markdown("""
## Text to 3D Asset with [TRELLIS](https://trellis3d.github.io/)
* Type a text prompt and click "Generate" to create a 3D asset.
* If you find the generated 3D asset satisfactory, click "Extract GLB" to extract the GLB file and download it.
""")
with gr.Row():
with gr.Column():
text_prompt = gr.Textbox(label="Text Prompt", lines=5)
with gr.Accordion(label="Generation Settings", open=False):
seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
gr.Markdown("Stage 1: Sparse Structure Generation")
with gr.Row():
ss_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
ss_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=25, step=1)
gr.Markdown("Stage 2: Structured Latent Generation")
with gr.Row():
slat_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=25, step=1)
generate_btn = gr.Button("Generate")
with gr.Accordion(label="GLB Extraction Settings", open=False):
mesh_simplify = gr.Slider(0.9, 0.98, label="Simplify", value=0.95, step=0.01)
texture_size = gr.Slider(512, 2048, label="Texture Size", value=1024, step=512)
with gr.Row():
extract_glb_btn = gr.Button("Extract GLB", interactive=False)
extract_gs_btn = gr.Button("Extract Gaussian", interactive=False)
gr.Markdown("""
*NOTE: the Gaussian file can be very large (~50MB), so it may take a while to display and download.*
""")
with gr.Column():
video_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300)
model_output = LitModel3D(label="Extracted GLB/Gaussian", exposure=10.0, height=300)
with gr.Row():
download_glb = gr.DownloadButton(label="Download GLB", interactive=False)
download_gs = gr.DownloadButton(label="Download Gaussian", interactive=False)
output_buf = gr.State()
# Handlers
demo.load(start_session)
demo.unload(end_session)
generate_btn.click(
get_seed,
inputs=[randomize_seed, seed],
outputs=[seed],
).then(
text_to_3d,
inputs=[text_prompt, seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps],
outputs=[output_buf, video_output],
).then(
lambda: tuple([gr.Button(interactive=True), gr.Button(interactive=True)]),
outputs=[extract_glb_btn, extract_gs_btn],
)
video_output.clear(
lambda: (gr.Button(interactive=False), gr.Button(interactive=False)),
outputs=[extract_glb_btn, extract_gs_btn],
)
extract_glb_btn.click(
extract_glb,
inputs=[output_buf, mesh_simplify, texture_size],
outputs=[model_output, download_glb],
).then(
lambda: gr.Button(interactive=True),
outputs=[download_glb],
)
extract_gs_btn.click(
extract_gaussian,
inputs=[output_buf],
outputs=[model_output, download_gs],
).then(
lambda: gr.Button(interactive=True),
outputs=[download_gs],
)
model_output.clear(
lambda: gr.Button(interactive=False),
outputs=[download_glb],
)
# Launch the Gradio app
if __name__ == "__main__":
pipeline = TrellisTextTo3DPipeline.from_pretrained("microsoft/TRELLIS-text-xlarge")
pipeline.cuda()
demo.launch()
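For reference, the pipeline loaded above can also be driven headlessly. The sketch below is non-authoritative: it assumes `run()` accepts a text prompt and a `seed` keyword (mirroring how the app wires the sliders into the pipeline), and that the returned Gaussian representation exposes `save_ply` as used in `extract_gaussian` above.

import torch
from trellis.pipelines import TrellisTextTo3DPipeline

pipeline = TrellisTextTo3DPipeline.from_pretrained("microsoft/TRELLIS-text-xlarge")
pipeline.cuda()
with torch.no_grad():
    # the prompt-only signature and `seed` keyword are assumptions based on the app code above
    outputs = pipeline.run("a wooden treasure chest", seed=0)
outputs['gaussian'][0].save_ply("sample.ply")  # same call as extract_gaussian above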

48
blender_glb_to_obj.py Normal file
View File

@@ -0,0 +1,48 @@
import bpy
import sys
import os
argv = sys.argv
argv = argv[argv.index("--") + 1:]
glb_input = argv[0]
obj_output = argv[1]
def clean():
bpy.ops.object.select_all(action='SELECT')
bpy.ops.object.delete(use_global=False)
clean()
bpy.ops.import_scene.gltf(filepath=glb_input)
meshes = [o for o in bpy.context.scene.objects if o.type == "MESH"]
if not meshes:
raise RuntimeError("No mesh objects found in GLB/GLTF.")
bpy.ops.object.select_all(action='DESELECT')
for o in meshes:
o.select_set(True)
bpy.context.view_layer.objects.active = meshes[0]
if len(meshes) > 1:
bpy.ops.object.join()
obj = bpy.context.view_layer.objects.active
obj.select_set(True)
out_dir = os.path.dirname(obj_output)
if out_dir:  # os.makedirs('') raises when the output path has no directory part
    os.makedirs(out_dir, exist_ok=True)
bpy.ops.wm.obj_export(
filepath=obj_output,
export_selected_objects=True,
export_normals=True,
export_uv=True,
)
print("OBJ exported:", obj_output)

7
client.py Normal file
View File

@@ -0,0 +1,7 @@
# This file is auto-generated by LitServe.
# Disable auto-generation by setting `generate_client_file=False` in `LitServer.run()`.
import requests
response = requests.post("http://127.0.0.1:8000/predict", json={"input": 4.0})
print(f"Status: {response.status_code}\nResponse:\n {response.text}")

View File

@@ -0,0 +1,285 @@
import os
import shutil
import sys
import time
import importlib
import argparse
import numpy as np
import pandas as pd
from tqdm import tqdm
from easydict import EasyDict as edict
from concurrent.futures import ThreadPoolExecutor
import utils3d
def get_first_directory(path):
with os.scandir(path) as it:
for entry in it:
if entry.is_dir():
return entry.name
return None
def need_process(key):
return key in opt.field or opt.field == ['all']
if __name__ == '__main__':
dataset_utils = importlib.import_module(f'datasets.{sys.argv[1]}')
parser = argparse.ArgumentParser()
parser.add_argument('--output_dir', type=str, required=True,
help='Directory to save the metadata')
parser.add_argument('--field', type=str, default='all',
help='Fields to process, separated by commas')
parser.add_argument('--from_file', action='store_true',
                    help='Build metadata from files instead of from processing records. '
                         'Useful when some processing failed to generate records but the files already exist.')
dataset_utils.add_args(parser)
opt = parser.parse_args(sys.argv[2:])
opt = edict(vars(opt))
os.makedirs(opt.output_dir, exist_ok=True)
os.makedirs(os.path.join(opt.output_dir, 'merged_records'), exist_ok=True)
opt.field = opt.field.split(',')
timestamp = str(int(time.time()))
# get file list
if os.path.exists(os.path.join(opt.output_dir, 'metadata.csv')):
print('Loading previous metadata...')
metadata = pd.read_csv(os.path.join(opt.output_dir, 'metadata.csv'))
else:
metadata = dataset_utils.get_metadata(**opt)
metadata.set_index('sha256', inplace=True)
# merge downloaded
df_files = [f for f in os.listdir(opt.output_dir) if f.startswith('downloaded_') and f.endswith('.csv')]
df_parts = []
for f in df_files:
try:
df_parts.append(pd.read_csv(os.path.join(opt.output_dir, f)))
except Exception:
pass
if len(df_parts) > 0:
df = pd.concat(df_parts)
df.set_index('sha256', inplace=True)
if 'local_path' in metadata.columns:
metadata.update(df, overwrite=True)
else:
metadata = metadata.join(df, on='sha256', how='left')
for f in df_files:
shutil.move(os.path.join(opt.output_dir, f), os.path.join(opt.output_dir, 'merged_records', f'{timestamp}_{f}'))
# detect models
image_models = []
if os.path.exists(os.path.join(opt.output_dir, 'features')):
image_models = os.listdir(os.path.join(opt.output_dir, 'features'))
latent_models = []
if os.path.exists(os.path.join(opt.output_dir, 'latents')):
latent_models = os.listdir(os.path.join(opt.output_dir, 'latents'))
ss_latent_models = []
if os.path.exists(os.path.join(opt.output_dir, 'ss_latents')):
ss_latent_models = os.listdir(os.path.join(opt.output_dir, 'ss_latents'))
print(f'Image models: {image_models}')
print(f'Latent models: {latent_models}')
print(f'Sparse Structure latent models: {ss_latent_models}')
if 'rendered' not in metadata.columns:
metadata['rendered'] = [False] * len(metadata)
if 'voxelized' not in metadata.columns:
metadata['voxelized'] = [False] * len(metadata)
if 'num_voxels' not in metadata.columns:
metadata['num_voxels'] = [0] * len(metadata)
if 'cond_rendered' not in metadata.columns:
metadata['cond_rendered'] = [False] * len(metadata)
for model in image_models:
if f'feature_{model}' not in metadata.columns:
metadata[f'feature_{model}'] = [False] * len(metadata)
for model in latent_models:
if f'latent_{model}' not in metadata.columns:
metadata[f'latent_{model}'] = [False] * len(metadata)
for model in ss_latent_models:
if f'ss_latent_{model}' not in metadata.columns:
metadata[f'ss_latent_{model}'] = [False] * len(metadata)
# merge partial records (rendered, aesthetic scores, voxelized, cond_rendered, and per-model features/latents)
def merge_record_files(prefix):
    """Merge all '<prefix>*.csv' partial records into `metadata`, then archive them."""
    df_files = [f for f in os.listdir(opt.output_dir) if f.startswith(prefix) and f.endswith('.csv')]
    df_parts = []
    for f in df_files:
        try:
            df_parts.append(pd.read_csv(os.path.join(opt.output_dir, f)))
        except Exception:
            pass
    if len(df_parts) > 0:
        df = pd.concat(df_parts)
        df.set_index('sha256', inplace=True)
        metadata.update(df, overwrite=True)
    for f in df_files:
        shutil.move(os.path.join(opt.output_dir, f), os.path.join(opt.output_dir, 'merged_records', f'{timestamp}_{f}'))

for prefix in ['rendered_', 'aesthetic_scores_', 'voxelized_', 'cond_rendered_']:
    merge_record_files(prefix)
for model in image_models:
    merge_record_files(f'feature_{model}_')
for model in latent_models:
    merge_record_files(f'latent_{model}_')
for model in ss_latent_models:
    merge_record_files(f'ss_latent_{model}_')
# build metadata from files
if opt.from_file:
with ThreadPoolExecutor(max_workers=os.cpu_count()) as executor, \
tqdm(total=len(metadata), desc="Building metadata") as pbar:
def worker(sha256):
try:
if need_process('rendered') and metadata.loc[sha256, 'rendered'] == False and \
os.path.exists(os.path.join(opt.output_dir, 'renders', sha256, 'transforms.json')):
metadata.loc[sha256, 'rendered'] = True
if need_process('voxelized') and metadata.loc[sha256, 'rendered'] == True and metadata.loc[sha256, 'voxelized'] == False and \
os.path.exists(os.path.join(opt.output_dir, 'voxels', f'{sha256}.ply')):
try:
pts = utils3d.io.read_ply(os.path.join(opt.output_dir, 'voxels', f'{sha256}.ply'))[0]
metadata.loc[sha256, 'voxelized'] = True
metadata.loc[sha256, 'num_voxels'] = len(pts)
except Exception as e:
pass
if need_process('cond_rendered') and metadata.loc[sha256, 'cond_rendered'] == False and \
os.path.exists(os.path.join(opt.output_dir, 'renders_cond', sha256, 'transforms.json')):
metadata.loc[sha256, 'cond_rendered'] = True
for model in image_models:
if need_process(f'feature_{model}') and \
metadata.loc[sha256, f'feature_{model}'] == False and \
metadata.loc[sha256, 'rendered'] == True and \
metadata.loc[sha256, 'voxelized'] == True and \
os.path.exists(os.path.join(opt.output_dir, 'features', model, f'{sha256}.npz')):
metadata.loc[sha256, f'feature_{model}'] = True
for model in latent_models:
if need_process(f'latent_{model}') and \
metadata.loc[sha256, f'latent_{model}'] == False and \
metadata.loc[sha256, 'rendered'] == True and \
metadata.loc[sha256, 'voxelized'] == True and \
os.path.exists(os.path.join(opt.output_dir, 'latents', model, f'{sha256}.npz')):
metadata.loc[sha256, f'latent_{model}'] = True
for model in ss_latent_models:
if need_process(f'ss_latent_{model}') and \
metadata.loc[sha256, f'ss_latent_{model}'] == False and \
metadata.loc[sha256, 'voxelized'] == True and \
os.path.exists(os.path.join(opt.output_dir, 'ss_latents', model, f'{sha256}.npz')):
metadata.loc[sha256, f'ss_latent_{model}'] = True
pbar.update()
except Exception as e:
print(f'Error processing {sha256}: {e}')
pbar.update()
executor.map(worker, metadata.index)
executor.shutdown(wait=True)
# statistics
metadata.to_csv(os.path.join(opt.output_dir, 'metadata.csv'))
num_downloaded = metadata['local_path'].count() if 'local_path' in metadata.columns else 0
with open(os.path.join(opt.output_dir, 'statistics.txt'), 'w') as f:
f.write('Statistics:\n')
f.write(f' - Number of assets: {len(metadata)}\n')
f.write(f' - Number of assets downloaded: {num_downloaded}\n')
f.write(f' - Number of assets rendered: {metadata["rendered"].sum()}\n')
f.write(f' - Number of assets voxelized: {metadata["voxelized"].sum()}\n')
if len(image_models) != 0:
f.write(f' - Number of assets with image features extracted:\n')
for model in image_models:
f.write(f' - {model}: {metadata[f"feature_{model}"].sum()}\n')
if len(latent_models) != 0:
f.write(f' - Number of assets with latents extracted:\n')
for model in latent_models:
f.write(f' - {model}: {metadata[f"latent_{model}"].sum()}\n')
if len(ss_latent_models) != 0:
f.write(f' - Number of assets with sparse structure latents extracted:\n')
for model in ss_latent_models:
f.write(f' - {model}: {metadata[f"ss_latent_{model}"].sum()}\n')
f.write(f' - Number of assets with captions: {metadata["captions"].count()}\n')
f.write(f' - Number of assets with image conditions: {metadata["cond_rendered"].sum()}\n')
with open(os.path.join(opt.output_dir, 'statistics.txt'), 'r') as f:
print(f.read())
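Invocation sketch for this script: the first positional argument selects the dataset helper module imported via `datasets.{sys.argv[1]}`; the dataset name and paths below are illustrative.

python build_metadata.py ABO --output_dir datasets/ABO --field rendered,voxelized
python build_metadata.py ABO --output_dir datasets/ABO --from_file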

View File

@@ -0,0 +1,102 @@
import os
import argparse
import torch
import torch.nn as nn
from PIL import Image
import open_clip
from os.path import expanduser
from urllib.request import urlretrieve
import pandas as pd
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor
from queue import Queue
def get_aesthetic_model(clip_model="vit_l_14"):
"""load the aethetic model"""
home = expanduser("~")
cache_folder = home + "/.cache/emb_reader"
path_to_model = cache_folder + "/sa_0_4_"+clip_model+"_linear.pth"
if not os.path.exists(path_to_model):
os.makedirs(cache_folder, exist_ok=True)
url_model = (
"https://github.com/LAION-AI/aesthetic-predictor/blob/main/sa_0_4_"+clip_model+"_linear.pth?raw=true"
)
urlretrieve(url_model, path_to_model)
if clip_model == "vit_l_14":
m = nn.Linear(768, 1)
elif clip_model == "vit_b_32":
m = nn.Linear(512, 1)
else:
raise ValueError()
s = torch.load(path_to_model)
m.load_state_dict(s)
m.eval()
return m
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--clip_model", type=str, default="vit_l_14")
parser.add_argument("--output_dir", type=str, required=True)
parser.add_argument("--rank", type=int, default=0)
parser.add_argument("--world_size", type=int, default=1)
opt = parser.parse_args()
amodel = get_aesthetic_model(clip_model=opt.clip_model)  # previously hardcoded, ignoring --clip_model
amodel.eval()
clip_arch = {'vit_l_14': 'ViT-L-14', 'vit_b_32': 'ViT-B-32'}[opt.clip_model]
model, _, preprocess = open_clip.create_model_and_transforms(clip_arch, pretrained='openai')
model = model.cuda()
amodel = amodel.cuda()
metadata = pd.read_csv(os.path.join(opt.output_dir, 'metadata.csv'))
metadata = metadata[metadata['snapshotted'] == 1]
sha256s = metadata['sha256'].values
# filter out objects that are already calculated
if os.path.exists(os.path.join(opt.output_dir, 'aesthetic_scores.csv')):
with open(os.path.join(opt.output_dir, 'aesthetic_scores.csv'), 'r') as f:
old_metadata = pd.read_csv(f)
sha256s = list(set(sha256s) - set(old_metadata['sha256'].values))
sha256s = sorted(sha256s)
sha256s = sha256s[len(sha256s) * opt.rank // opt.world_size: len(sha256s) * (opt.rank + 1) // opt.world_size]
rows = []
with ThreadPoolExecutor(max_workers=os.cpu_count()) as executor:
finished = Queue(maxsize=128)
def load_image(sha256):
try:
files = os.listdir(os.path.join(opt.output_dir, 'snapshots', sha256))
files = [f for f in files if f.endswith('.png')]
processed = []
for file in files:
image = Image.open(os.path.join(opt.output_dir, 'snapshots', sha256, file))
processed.append(preprocess(image))
processed = torch.stack(processed, dim=0)
except Exception as e:
print(e)
processed = None
finished.put((sha256, processed))
executor.map(load_image, sha256s)
for _ in tqdm(range(len(sha256s)), desc='Calculating aesthetic scores'):
sha256, processed = finished.get()
if processed is not None:
with torch.no_grad():
image_features = model.encode_image(processed.cuda())
image_features /= image_features.norm(dim=-1, keepdim=True)
aesthetic_score = amodel(image_features).cpu()
rows.append(pd.DataFrame({
'sha256': [sha256],
'mean': [aesthetic_score.mean().item()],
'std': [aesthetic_score.std().item()],
'min': [aesthetic_score.min().item()],
'max': [aesthetic_score.max().item()],
'median': [aesthetic_score.median().item()]
}))
with open(os.path.join(opt.output_dir, f'aesthetic_scores_{opt.rank}.csv'), 'w') as f:
pd.concat(rows).to_csv(f, index=False)
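The `rank`/`world_size` slicing above partitions the sorted hash list evenly across processes, so shards are disjoint and jointly cover every asset. A self-contained illustration (names local to this sketch):

items = list(range(10))
world_size = 3
shards = [items[len(items) * r // world_size: len(items) * (r + 1) // world_size]
          for r in range(world_size)]
assert shards == [[0, 1, 2], [3, 4, 5], [6, 7, 8, 9]]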

View File

@@ -0,0 +1,97 @@
import os
import re
import argparse
import zipfile
from concurrent.futures import ThreadPoolExecutor
from tqdm import tqdm
import pandas as pd
from utils import get_file_hash
def add_args(parser: argparse.ArgumentParser):
pass
def get_metadata(**kwargs):
metadata = pd.read_csv("hf://datasets/JeffreyXiang/TRELLIS-500K/3D-FUTURE.csv")
return metadata
def download(metadata, output_dir, **kwargs):
os.makedirs(output_dir, exist_ok=True)
if not os.path.exists(os.path.join(output_dir, 'raw', '3D-FUTURE-model.zip')):
print("\033[93m")
print("3D-FUTURE have to be downloaded manually")
print(f"Please download the 3D-FUTURE-model.zip file and place it in the {output_dir}/raw directory")
print("Visit https://tianchi.aliyun.com/specials/promotion/alibaba-3d-future for more information")
print("\033[0m")
raise FileNotFoundError("3D-FUTURE-model.zip not found")
downloaded = {}
metadata = metadata.set_index("file_identifier")
with zipfile.ZipFile(os.path.join(output_dir, 'raw', '3D-FUTURE-model.zip')) as zip_ref:
all_names = zip_ref.namelist()
instances = [instance[:-1] for instance in all_names if re.match(r"^3D-FUTURE-model/[^/]+/$", instance)]
instances = list(filter(lambda x: x in metadata.index, instances))
with ThreadPoolExecutor(max_workers=os.cpu_count()) as executor, \
tqdm(total=len(instances), desc="Extracting") as pbar:
def worker(instance: str) -> str:
try:
instance_files = list(filter(lambda x: x.startswith(f"{instance}/") and not x.endswith("/"), all_names))
zip_ref.extractall(os.path.join(output_dir, 'raw'), members=instance_files)
sha256 = get_file_hash(os.path.join(output_dir, 'raw', f"{instance}/image.jpg"))
pbar.update()
return sha256
except Exception as e:
pbar.update()
print(f"Error extracting for {instance}: {e}")
return None
sha256s = executor.map(worker, instances)
executor.shutdown(wait=True)
for k, sha256 in zip(instances, sha256s):
if sha256 is not None:
if sha256 == metadata.loc[k, "sha256"]:
downloaded[sha256] = os.path.join("raw", f"{k}/raw_model.obj")
else:
print(f"Error downloading {k}: sha256s do not match")
return pd.DataFrame(downloaded.items(), columns=['sha256', 'local_path'])
def foreach_instance(metadata, output_dir, func, max_workers=None, desc='Processing objects') -> pd.DataFrame:
    # load metadata (os, ThreadPoolExecutor and tqdm are already imported at module level)
metadata = metadata.to_dict('records')
# processing objects
records = []
max_workers = max_workers or os.cpu_count()
try:
with ThreadPoolExecutor(max_workers=max_workers) as executor, \
tqdm(total=len(metadata), desc=desc) as pbar:
            def worker(metadatum):
                sha256 = metadatum['sha256']  # bound before the try block so the error message cannot NameError
                try:
                    local_path = metadatum['local_path']
                    file = os.path.join(output_dir, local_path)
                    record = func(file, sha256)
                    if record is not None:
                        records.append(record)
                    pbar.update()
                except Exception as e:
                    print(f"Error processing object {sha256}: {e}")
                    pbar.update()
executor.map(worker, metadata)
executor.shutdown(wait=True)
    except Exception as e:
        print(f"An error occurred during processing: {e}")
return pd.DataFrame.from_records(records)
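A usage sketch for `foreach_instance`: the callback receives the local file path and the asset's sha256 and returns a record dict (or `None` to skip the asset). The statistic computed here is purely illustrative.

import os

def file_size_record(file, sha256):
    # One record per asset; returning None drops the asset from the result.
    if not os.path.exists(file):
        return None
    return {'sha256': sha256, 'size_bytes': os.path.getsize(file)}

# df = foreach_instance(metadata, output_dir, file_size_record, desc='Measuring files')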

View File

@@ -0,0 +1,96 @@
import os
import re
import argparse
import tarfile
from concurrent.futures import ThreadPoolExecutor
from tqdm import tqdm
import pandas as pd
from utils import get_file_hash
def add_args(parser: argparse.ArgumentParser):
pass
def get_metadata(**kwargs):
metadata = pd.read_csv("hf://datasets/JeffreyXiang/TRELLIS-500K/ABO.csv")
return metadata
def download(metadata, output_dir, **kwargs):
os.makedirs(os.path.join(output_dir, 'raw'), exist_ok=True)
if not os.path.exists(os.path.join(output_dir, 'raw', 'abo-3dmodels.tar')):
try:
os.makedirs(os.path.join(output_dir, 'raw'), exist_ok=True)
os.system(f"wget -O {output_dir}/raw/abo-3dmodels.tar https://amazon-berkeley-objects.s3.amazonaws.com/archives/abo-3dmodels.tar")
except:
print("\033[93m")
print("Error downloading ABO dataset. Please check your internet connection and try again.")
print("Or, you can manually download the abo-3dmodels.tar file and place it in the {output_dir}/raw directory")
print("Visit https://amazon-berkeley-objects.s3.amazonaws.com/index.html for more information")
print("\033[0m")
raise FileNotFoundError("Error downloading ABO dataset")
downloaded = {}
metadata = metadata.set_index("file_identifier")
with tarfile.open(os.path.join(output_dir, 'raw', 'abo-3dmodels.tar')) as tar:
with ThreadPoolExecutor(max_workers=1) as executor, \
tqdm(total=len(metadata), desc="Extracting") as pbar:
def worker(instance: str) -> str:
try:
tar.extract(f"3dmodels/original/{instance}", path=os.path.join(output_dir, 'raw'))
sha256 = get_file_hash(os.path.join(output_dir, 'raw/3dmodels/original', instance))
pbar.update()
return sha256
except Exception as e:
pbar.update()
print(f"Error extracting for {instance}: {e}")
return None
sha256s = executor.map(worker, metadata.index)
executor.shutdown(wait=True)
for k, sha256 in zip(metadata.index, sha256s):
if sha256 is not None:
if sha256 == metadata.loc[k, "sha256"]:
downloaded[sha256] = os.path.join('raw/3dmodels/original', k)
else:
print(f"Error downloading {k}: sha256s do not match")
return pd.DataFrame(downloaded.items(), columns=['sha256', 'local_path'])
def foreach_instance(metadata, output_dir, func, max_workers=None, desc='Processing objects') -> pd.DataFrame:
    # load metadata (os, ThreadPoolExecutor and tqdm are already imported at module level)
metadata = metadata.to_dict('records')
# processing objects
records = []
max_workers = max_workers or os.cpu_count()
try:
with ThreadPoolExecutor(max_workers=max_workers) as executor, \
tqdm(total=len(metadata), desc=desc) as pbar:
            def worker(metadatum):
                sha256 = metadatum['sha256']  # bound before the try block so the error message cannot NameError
                try:
                    local_path = metadatum['local_path']
                    file = os.path.join(output_dir, local_path)
                    record = func(file, sha256)
                    if record is not None:
                        records.append(record)
                    pbar.update()
                except Exception as e:
                    print(f"Error processing object {sha256}: {e}")
                    pbar.update()
executor.map(worker, metadata)
executor.shutdown(wait=True)
    except Exception as e:
        print(f"An error occurred during processing: {e}")
return pd.DataFrame.from_records(records)

6
trellis/__init__.py Executable file
View File

@@ -0,0 +1,6 @@
from . import models
from . import modules
from . import pipelines
from . import renderers
from . import representations
from . import utils

View File

@@ -0,0 +1,58 @@
import importlib
__attributes = {
'SparseStructure': 'sparse_structure',
'SparseFeat2Render': 'sparse_feat2render',
'SLat2Render': 'structured_latent2render',
'Slat2RenderGeo': 'structured_latent2render',
'SparseStructureLatent': 'sparse_structure_latent',
'TextConditionedSparseStructureLatent': 'sparse_structure_latent',
'ImageConditionedSparseStructureLatent': 'sparse_structure_latent',
'SLat': 'structured_latent',
'TextConditionedSLat': 'structured_latent',
'ImageConditionedSLat': 'structured_latent',
}
__submodules = []
__all__ = list(__attributes.keys()) + __submodules
def __getattr__(name):
if name not in globals():
if name in __attributes:
module_name = __attributes[name]
module = importlib.import_module(f".{module_name}", __name__)
globals()[name] = getattr(module, name)
elif name in __submodules:
module = importlib.import_module(f".{name}", __name__)
globals()[name] = module
else:
raise AttributeError(f"module {__name__} has no attribute {name}")
return globals()[name]
# For Pylance
if __name__ == '__main__':
from .sparse_structure import SparseStructure
from .sparse_feat2render import SparseFeat2Render
from .structured_latent2render import (
SLat2Render,
Slat2RenderGeo,
)
from .sparse_structure_latent import (
SparseStructureLatent,
TextConditionedSparseStructureLatent,
ImageConditionedSparseStructureLatent,
)
from .structured_latent import (
SLat,
TextConditionedSLat,
ImageConditionedSLat,
)
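This `__getattr__`/`__attributes` scaffolding, repeated in the sibling `__init__.py` modules below, is the PEP 562 module-level `__getattr__` pattern: each exported name is resolved to its defining submodule only on first access, keeping the package import cheap, while the `if __name__ == '__main__':` imports exist only so static analyzers such as Pylance see the names. A minimal standalone sketch of the pattern (hypothetical package and names):

# pkg/__init__.py (standalone sketch)
import importlib

_attributes = {'Thing': 'thing_module'}  # exported name -> defining submodule

def __getattr__(name):
    if name in _attributes:
        module = importlib.import_module(f".{_attributes[name]}", __name__)
        globals()[name] = getattr(module, name)  # cache so __getattr__ runs once per name
        return globals()[name]
    raise AttributeError(f"module {__name__} has no attribute {name}")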

View File

@@ -0,0 +1,96 @@
import importlib
__attributes = {
'SparseStructureEncoder': 'sparse_structure_vae',
'SparseStructureDecoder': 'sparse_structure_vae',
'SparseStructureFlowModel': 'sparse_structure_flow',
'SLatEncoder': 'structured_latent_vae',
'SLatGaussianDecoder': 'structured_latent_vae',
'SLatRadianceFieldDecoder': 'structured_latent_vae',
'SLatMeshDecoder': 'structured_latent_vae',
'ElasticSLatEncoder': 'structured_latent_vae',
'ElasticSLatGaussianDecoder': 'structured_latent_vae',
'ElasticSLatRadianceFieldDecoder': 'structured_latent_vae',
'ElasticSLatMeshDecoder': 'structured_latent_vae',
'SLatFlowModel': 'structured_latent_flow',
'ElasticSLatFlowModel': 'structured_latent_flow',
}
__submodules = []
__all__ = list(__attributes.keys()) + __submodules
def __getattr__(name):
if name not in globals():
if name in __attributes:
module_name = __attributes[name]
module = importlib.import_module(f".{module_name}", __name__)
globals()[name] = getattr(module, name)
elif name in __submodules:
module = importlib.import_module(f".{name}", __name__)
globals()[name] = module
else:
raise AttributeError(f"module {__name__} has no attribute {name}")
return globals()[name]
def from_pretrained(path: str, **kwargs):
"""
Load a model from a pretrained checkpoint.
Args:
path: The path to the checkpoint. Can be either local path or a Hugging Face model name.
NOTE: the config file and the model file should be named f'{path}.json' and f'{path}.safetensors' respectively.
**kwargs: Additional arguments for the model constructor.
"""
import os
import json
from safetensors.torch import load_file
is_local = os.path.exists(f"{path}.json") and os.path.exists(f"{path}.safetensors")
if is_local:
config_file = f"{path}.json"
model_file = f"{path}.safetensors"
else:
from huggingface_hub import hf_hub_download
path_parts = path.split('/')
repo_id = f'{path_parts[0]}/{path_parts[1]}'
model_name = '/'.join(path_parts[2:])
config_file = hf_hub_download(repo_id, f"{model_name}.json")
model_file = hf_hub_download(repo_id, f"{model_name}.safetensors")
with open(config_file, 'r') as f:
config = json.load(f)
model = __getattr__(config['name'])(**config['args'], **kwargs)
model.load_state_dict(load_file(model_file))
return model
# For Pylance
if __name__ == '__main__':
from .sparse_structure_vae import (
SparseStructureEncoder,
SparseStructureDecoder,
)
from .sparse_structure_flow import SparseStructureFlowModel
from .structured_latent_vae import (
SLatEncoder,
SLatGaussianDecoder,
SLatRadianceFieldDecoder,
SLatMeshDecoder,
ElasticSLatEncoder,
ElasticSLatGaussianDecoder,
ElasticSLatRadianceFieldDecoder,
ElasticSLatMeshDecoder,
)
from .structured_latent_flow import (
SLatFlowModel,
ElasticSLatFlowModel,
)
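A usage sketch for `from_pretrained` above. For a Hugging Face path, the first two segments form the repo id and the remainder names the checkpoint, which must exist in the repo as `<name>.json` plus `<name>.safetensors`; both paths below are hypothetical.

from trellis import models

# Local: expects ./ckpts/my_model.json and ./ckpts/my_model.safetensors on disk
model = models.from_pretrained("ckpts/my_model")

# Hub: repo id "org/repo", checkpoint files "ckpts/my_model.{json,safetensors}" in it
model = models.from_pretrained("org/repo/ckpts/my_model")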

View File

@@ -0,0 +1,4 @@
from .encoder import SLatEncoder, ElasticSLatEncoder
from .decoder_gs import SLatGaussianDecoder, ElasticSLatGaussianDecoder
from .decoder_rf import SLatRadianceFieldDecoder, ElasticSLatRadianceFieldDecoder
from .decoder_mesh import SLatMeshDecoder, ElasticSLatMeshDecoder

View File

@@ -0,0 +1,117 @@
from typing import *
import torch
import torch.nn as nn
from ...modules.utils import convert_module_to_f16, convert_module_to_f32
from ...modules import sparse as sp
from ...modules.transformer import AbsolutePositionEmbedder
from ...modules.sparse.transformer import SparseTransformerBlock
def block_attn_config(self):
"""
Return the attention configuration of the model.
"""
for i in range(self.num_blocks):
if self.attn_mode == "shift_window":
yield "serialized", self.window_size, 0, (16 * (i % 2),) * 3, sp.SerializeMode.Z_ORDER
elif self.attn_mode == "shift_sequence":
yield "serialized", self.window_size, self.window_size // 2 * (i % 2), (0, 0, 0), sp.SerializeMode.Z_ORDER
elif self.attn_mode == "shift_order":
yield "serialized", self.window_size, 0, (0, 0, 0), sp.SerializeModes[i % 4]
elif self.attn_mode == "full":
yield "full", None, None, None, None
elif self.attn_mode == "swin":
yield "windowed", self.window_size, None, self.window_size // 2 * (i % 2), None
class SparseTransformerBase(nn.Module):
"""
Sparse Transformer without output layers.
Serve as the base class for encoder and decoder.
"""
def __init__(
self,
in_channels: int,
model_channels: int,
num_blocks: int,
num_heads: Optional[int] = None,
num_head_channels: Optional[int] = 64,
mlp_ratio: float = 4.0,
attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "full",
window_size: Optional[int] = None,
pe_mode: Literal["ape", "rope"] = "ape",
use_fp16: bool = False,
use_checkpoint: bool = False,
qk_rms_norm: bool = False,
):
super().__init__()
self.in_channels = in_channels
self.model_channels = model_channels
self.num_blocks = num_blocks
self.window_size = window_size
self.num_heads = num_heads or model_channels // num_head_channels
self.mlp_ratio = mlp_ratio
self.attn_mode = attn_mode
self.pe_mode = pe_mode
self.use_fp16 = use_fp16
self.use_checkpoint = use_checkpoint
self.qk_rms_norm = qk_rms_norm
self.dtype = torch.float16 if use_fp16 else torch.float32
if pe_mode == "ape":
self.pos_embedder = AbsolutePositionEmbedder(model_channels)
self.input_layer = sp.SparseLinear(in_channels, model_channels)
self.blocks = nn.ModuleList([
SparseTransformerBlock(
model_channels,
num_heads=self.num_heads,
mlp_ratio=self.mlp_ratio,
attn_mode=attn_mode,
window_size=window_size,
shift_sequence=shift_sequence,
shift_window=shift_window,
serialize_mode=serialize_mode,
use_checkpoint=self.use_checkpoint,
use_rope=(pe_mode == "rope"),
qk_rms_norm=self.qk_rms_norm,
)
for attn_mode, window_size, shift_sequence, shift_window, serialize_mode in block_attn_config(self)
])
@property
def device(self) -> torch.device:
"""
Return the device of the model.
"""
return next(self.parameters()).device
def convert_to_fp16(self) -> None:
"""
Convert the torso of the model to float16.
"""
self.blocks.apply(convert_module_to_f16)
def convert_to_fp32(self) -> None:
"""
Convert the torso of the model to float32.
"""
self.blocks.apply(convert_module_to_f32)
def initialize_weights(self) -> None:
# Initialize transformer layers:
def _basic_init(module):
if isinstance(module, nn.Linear):
torch.nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.constant_(module.bias, 0)
self.apply(_basic_init)
def forward(self, x: sp.SparseTensor) -> sp.SparseTensor:
h = self.input_layer(x)
if self.pe_mode == "ape":
h = h + self.pos_embedder(x.coords[:, 1:])
h = h.type(self.dtype)
for block in self.blocks:
h = block(h)
return h

View File

@@ -0,0 +1,36 @@
from typing import *
BACKEND = 'flash_attn'
DEBUG = False
def __from_env():
import os
global BACKEND
global DEBUG
env_attn_backend = os.environ.get('ATTN_BACKEND')
env_attn_debug = os.environ.get('ATTN_DEBUG')
if env_attn_backend is not None and env_attn_backend in ['xformers', 'flash_attn', 'sdpa', 'naive']:
BACKEND = env_attn_backend
if env_attn_debug is not None:
DEBUG = env_attn_debug == '1'
print(f"[ATTENTION] Using backend: {BACKEND}")
__from_env()
def set_backend(backend: Literal['xformers', 'flash_attn']):
global BACKEND
BACKEND = backend
def set_debug(debug: bool):
global DEBUG
DEBUG = debug
from .full_attn import *
from .modules import *
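Because `__from_env()` runs at import time, the attention backend must be selected before this module is first imported; `set_backend` can still switch it afterwards. A sketch (the `trellis.modules.attention` import path is assumed from the package layout):

import os
os.environ['ATTN_BACKEND'] = 'sdpa'  # one of: xformers, flash_attn, sdpa, naive

from trellis.modules import attention  # prints "[ATTENTION] Using backend: sdpa"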

View File

@@ -0,0 +1,102 @@
from typing import *
BACKEND = 'spconv'
DEBUG = False
ATTN = 'flash_attn'
def __from_env():
import os
global BACKEND
global DEBUG
global ATTN
env_sparse_backend = os.environ.get('SPARSE_BACKEND')
env_sparse_debug = os.environ.get('SPARSE_DEBUG')
env_sparse_attn = os.environ.get('SPARSE_ATTN_BACKEND')
if env_sparse_attn is None:
env_sparse_attn = os.environ.get('ATTN_BACKEND')
if env_sparse_backend is not None and env_sparse_backend in ['spconv', 'torchsparse']:
BACKEND = env_sparse_backend
if env_sparse_debug is not None:
DEBUG = env_sparse_debug == '1'
if env_sparse_attn is not None and env_sparse_attn in ['xformers', 'flash_attn']:
ATTN = env_sparse_attn
print(f"[SPARSE] Backend: {BACKEND}, Attention: {ATTN}")
__from_env()
def set_backend(backend: Literal['spconv', 'torchsparse']):
global BACKEND
BACKEND = backend
def set_debug(debug: bool):
global DEBUG
DEBUG = debug
def set_attn(attn: Literal['xformers', 'flash_attn']):
global ATTN
ATTN = attn
import importlib
__attributes = {
'SparseTensor': 'basic',
'sparse_batch_broadcast': 'basic',
'sparse_batch_op': 'basic',
'sparse_cat': 'basic',
'sparse_unbind': 'basic',
'SparseGroupNorm': 'norm',
'SparseLayerNorm': 'norm',
'SparseGroupNorm32': 'norm',
'SparseLayerNorm32': 'norm',
'SparseReLU': 'nonlinearity',
'SparseSiLU': 'nonlinearity',
'SparseGELU': 'nonlinearity',
'SparseActivation': 'nonlinearity',
'SparseLinear': 'linear',
'sparse_scaled_dot_product_attention': 'attention',
'SerializeMode': 'attention',
'sparse_serialized_scaled_dot_product_self_attention': 'attention',
'sparse_windowed_scaled_dot_product_self_attention': 'attention',
'SparseMultiHeadAttention': 'attention',
'SparseConv3d': 'conv',
'SparseInverseConv3d': 'conv',
'SparseDownsample': 'spatial',
'SparseUpsample': 'spatial',
'SparseSubdivide': 'spatial',
}
__submodules = ['transformer']
__all__ = list(__attributes.keys()) + __submodules
def __getattr__(name):
if name not in globals():
if name in __attributes:
module_name = __attributes[name]
module = importlib.import_module(f".{module_name}", __name__)
globals()[name] = getattr(module, name)
elif name in __submodules:
module = importlib.import_module(f".{name}", __name__)
globals()[name] = module
else:
raise AttributeError(f"module {__name__} has no attribute {name}")
return globals()[name]
# For Pylance
if __name__ == '__main__':
from .basic import *
from .norm import *
from .nonlinearity import *
from .linear import *
from .attention import *
from .conv import *
from .spatial import *
from . import transformer

View File

@@ -0,0 +1,4 @@
from .full_attn import *
from .serialized_attn import *
from .windowed_attn import *
from .modules import *

459
trellis/modules/sparse/basic.py Executable file
View File

@@ -0,0 +1,459 @@
from typing import *
import torch
import torch.nn as nn
from . import BACKEND, DEBUG
SparseTensorData = None # Lazy import
__all__ = [
'SparseTensor',
'sparse_batch_broadcast',
'sparse_batch_op',
'sparse_cat',
'sparse_unbind',
]
class SparseTensor:
"""
Sparse tensor with support for both torchsparse and spconv backends.
Parameters:
- feats (torch.Tensor): Features of the sparse tensor.
- coords (torch.Tensor): Coordinates of the sparse tensor.
- shape (torch.Size): Shape of the sparse tensor.
- layout (List[slice]): Layout of the sparse tensor for each batch
- data (SparseTensorData): Sparse tensor data used for convolution
NOTE:
- Data corresponding to the same batch should be contiguous.
- Coords should be in [0, 1023]
"""
@overload
def __init__(self, feats: torch.Tensor, coords: torch.Tensor, shape: Optional[torch.Size] = None, layout: Optional[List[slice]] = None, **kwargs): ...
@overload
def __init__(self, data, shape: Optional[torch.Size] = None, layout: Optional[List[slice]] = None, **kwargs): ...
def __init__(self, *args, **kwargs):
# Lazy import of sparse tensor backend
global SparseTensorData
if SparseTensorData is None:
import importlib
if BACKEND == 'torchsparse':
SparseTensorData = importlib.import_module('torchsparse').SparseTensor
elif BACKEND == 'spconv':
SparseTensorData = importlib.import_module('spconv.pytorch').SparseConvTensor
method_id = 0
if len(args) != 0:
method_id = 0 if isinstance(args[0], torch.Tensor) else 1
else:
method_id = 1 if 'data' in kwargs else 0
if method_id == 0:
feats, coords, shape, layout = args + (None,) * (4 - len(args))
if 'feats' in kwargs:
feats = kwargs['feats']
del kwargs['feats']
if 'coords' in kwargs:
coords = kwargs['coords']
del kwargs['coords']
if 'shape' in kwargs:
shape = kwargs['shape']
del kwargs['shape']
if 'layout' in kwargs:
layout = kwargs['layout']
del kwargs['layout']
if shape is None:
shape = self.__cal_shape(feats, coords)
if layout is None:
layout = self.__cal_layout(coords, shape[0])
if BACKEND == 'torchsparse':
self.data = SparseTensorData(feats, coords, **kwargs)
elif BACKEND == 'spconv':
spatial_shape = list(coords.max(0)[0] + 1)[1:]
self.data = SparseTensorData(feats.reshape(feats.shape[0], -1), coords, spatial_shape, shape[0], **kwargs)
self.data._features = feats
elif method_id == 1:
data, shape, layout = args + (None,) * (3 - len(args))
if 'data' in kwargs:
data = kwargs['data']
del kwargs['data']
if 'shape' in kwargs:
shape = kwargs['shape']
del kwargs['shape']
if 'layout' in kwargs:
layout = kwargs['layout']
del kwargs['layout']
self.data = data
if shape is None:
shape = self.__cal_shape(self.feats, self.coords)
if layout is None:
layout = self.__cal_layout(self.coords, shape[0])
self._shape = shape
self._layout = layout
self._scale = kwargs.get('scale', (1, 1, 1))
self._spatial_cache = kwargs.get('spatial_cache', {})
if DEBUG:
try:
assert self.feats.shape[0] == self.coords.shape[0], f"Invalid feats shape: {self.feats.shape}, coords shape: {self.coords.shape}"
assert self.shape == self.__cal_shape(self.feats, self.coords), f"Invalid shape: {self.shape}"
assert self.layout == self.__cal_layout(self.coords, self.shape[0]), f"Invalid layout: {self.layout}"
for i in range(self.shape[0]):
assert torch.all(self.coords[self.layout[i], 0] == i), f"The data of batch {i} is not contiguous"
except Exception as e:
print('Debugging information:')
print(f"- Shape: {self.shape}")
print(f"- Layout: {self.layout}")
print(f"- Scale: {self._scale}")
print(f"- Coords: {self.coords}")
raise e
def __cal_shape(self, feats, coords):
shape = []
shape.append(coords[:, 0].max().item() + 1)
shape.extend([*feats.shape[1:]])
return torch.Size(shape)
def __cal_layout(self, coords, batch_size):
seq_len = torch.bincount(coords[:, 0], minlength=batch_size)
offset = torch.cumsum(seq_len, dim=0)
layout = [slice((offset[i] - seq_len[i]).item(), offset[i].item()) for i in range(batch_size)]
return layout
@property
def shape(self) -> torch.Size:
return self._shape
def dim(self) -> int:
return len(self.shape)
@property
def layout(self) -> List[slice]:
return self._layout
@property
def feats(self) -> torch.Tensor:
if BACKEND == 'torchsparse':
return self.data.F
elif BACKEND == 'spconv':
return self.data.features
@feats.setter
def feats(self, value: torch.Tensor):
if BACKEND == 'torchsparse':
self.data.F = value
elif BACKEND == 'spconv':
self.data.features = value
@property
def coords(self) -> torch.Tensor:
if BACKEND == 'torchsparse':
return self.data.C
elif BACKEND == 'spconv':
return self.data.indices
@coords.setter
def coords(self, value: torch.Tensor):
if BACKEND == 'torchsparse':
self.data.C = value
elif BACKEND == 'spconv':
self.data.indices = value
@property
def dtype(self):
return self.feats.dtype
@property
def device(self):
return self.feats.device
@overload
def to(self, dtype: torch.dtype) -> 'SparseTensor': ...
@overload
def to(self, device: Optional[Union[str, torch.device]] = None, dtype: Optional[torch.dtype] = None) -> 'SparseTensor': ...
def to(self, *args, **kwargs) -> 'SparseTensor':
device = None
dtype = None
if len(args) == 2:
device, dtype = args
elif len(args) == 1:
if isinstance(args[0], torch.dtype):
dtype = args[0]
else:
device = args[0]
if 'dtype' in kwargs:
assert dtype is None, "to() received multiple values for argument 'dtype'"
dtype = kwargs['dtype']
if 'device' in kwargs:
assert device is None, "to() received multiple values for argument 'device'"
device = kwargs['device']
new_feats = self.feats.to(device=device, dtype=dtype)
new_coords = self.coords.to(device=device)
return self.replace(new_feats, new_coords)
def type(self, dtype):
new_feats = self.feats.type(dtype)
return self.replace(new_feats)
def cpu(self) -> 'SparseTensor':
new_feats = self.feats.cpu()
new_coords = self.coords.cpu()
return self.replace(new_feats, new_coords)
def cuda(self) -> 'SparseTensor':
new_feats = self.feats.cuda()
new_coords = self.coords.cuda()
return self.replace(new_feats, new_coords)
def half(self) -> 'SparseTensor':
new_feats = self.feats.half()
return self.replace(new_feats)
def float(self) -> 'SparseTensor':
new_feats = self.feats.float()
return self.replace(new_feats)
def detach(self) -> 'SparseTensor':
new_coords = self.coords.detach()
new_feats = self.feats.detach()
return self.replace(new_feats, new_coords)
def dense(self) -> torch.Tensor:
if BACKEND == 'torchsparse':
return self.data.dense()
elif BACKEND == 'spconv':
return self.data.dense()
def reshape(self, *shape) -> 'SparseTensor':
new_feats = self.feats.reshape(self.feats.shape[0], *shape)
return self.replace(new_feats)
def unbind(self, dim: int) -> List['SparseTensor']:
return sparse_unbind(self, dim)
def replace(self, feats: torch.Tensor, coords: Optional[torch.Tensor] = None) -> 'SparseTensor':
new_shape = [self.shape[0]]
new_shape.extend(feats.shape[1:])
if BACKEND == 'torchsparse':
new_data = SparseTensorData(
feats=feats,
coords=self.data.coords if coords is None else coords,
stride=self.data.stride,
spatial_range=self.data.spatial_range,
)
new_data._caches = self.data._caches
elif BACKEND == 'spconv':
new_data = SparseTensorData(
self.data.features.reshape(self.data.features.shape[0], -1),
self.data.indices,
self.data.spatial_shape,
self.data.batch_size,
self.data.grid,
self.data.voxel_num,
self.data.indice_dict
)
new_data._features = feats
new_data.benchmark = self.data.benchmark
new_data.benchmark_record = self.data.benchmark_record
new_data.thrust_allocator = self.data.thrust_allocator
new_data._timer = self.data._timer
new_data.force_algo = self.data.force_algo
new_data.int8_scale = self.data.int8_scale
if coords is not None:
new_data.indices = coords
new_tensor = SparseTensor(new_data, shape=torch.Size(new_shape), layout=self.layout, scale=self._scale, spatial_cache=self._spatial_cache)
return new_tensor
@staticmethod
def full(aabb, dim, value, dtype=torch.float32, device=None) -> 'SparseTensor':
N, C = dim
x = torch.arange(aabb[0], aabb[3] + 1)
y = torch.arange(aabb[1], aabb[4] + 1)
z = torch.arange(aabb[2], aabb[5] + 1)
coords = torch.stack(torch.meshgrid(x, y, z, indexing='ij'), dim=-1).reshape(-1, 3)
coords = torch.cat([
torch.arange(N).view(-1, 1).repeat(1, coords.shape[0]).view(-1, 1),
coords.repeat(N, 1),
], dim=1).to(dtype=torch.int32, device=device)
feats = torch.full((coords.shape[0], C), value, dtype=dtype, device=device)
return SparseTensor(feats=feats, coords=coords)
def __merge_sparse_cache(self, other: 'SparseTensor') -> dict:
new_cache = {}
for k in set(list(self._spatial_cache.keys()) + list(other._spatial_cache.keys())):
if k in self._spatial_cache:
new_cache[k] = self._spatial_cache[k]
if k in other._spatial_cache:
if k not in new_cache:
new_cache[k] = other._spatial_cache[k]
else:
new_cache[k].update(other._spatial_cache[k])
return new_cache
def __neg__(self) -> 'SparseTensor':
return self.replace(-self.feats)
def __elemwise__(self, other: Union[torch.Tensor, 'SparseTensor'], op: callable) -> 'SparseTensor':
if isinstance(other, torch.Tensor):
try:
other = torch.broadcast_to(other, self.shape)
other = sparse_batch_broadcast(self, other)
except Exception:
pass
if isinstance(other, SparseTensor):
other = other.feats
new_feats = op(self.feats, other)
new_tensor = self.replace(new_feats)
if isinstance(other, SparseTensor):
new_tensor._spatial_cache = self.__merge_sparse_cache(other)
return new_tensor
def __add__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor':
return self.__elemwise__(other, torch.add)
def __radd__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor':
return self.__elemwise__(other, torch.add)
def __sub__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor':
return self.__elemwise__(other, torch.sub)
def __rsub__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor':
return self.__elemwise__(other, lambda x, y: torch.sub(y, x))
def __mul__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor':
return self.__elemwise__(other, torch.mul)
def __rmul__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor':
return self.__elemwise__(other, torch.mul)
def __truediv__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor':
return self.__elemwise__(other, torch.div)
def __rtruediv__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor':
return self.__elemwise__(other, lambda x, y: torch.div(y, x))
def __getitem__(self, idx):
if isinstance(idx, int):
idx = [idx]
elif isinstance(idx, slice):
idx = range(*idx.indices(self.shape[0]))
elif isinstance(idx, torch.Tensor):
if idx.dtype == torch.bool:
assert idx.shape == (self.shape[0],), f"Invalid index shape: {idx.shape}"
idx = idx.nonzero().squeeze(1)
elif idx.dtype in [torch.int32, torch.int64]:
assert len(idx.shape) == 1, f"Invalid index shape: {idx.shape}"
else:
raise ValueError(f"Unknown index type: {idx.dtype}")
else:
raise ValueError(f"Unknown index type: {type(idx)}")
coords = []
feats = []
for new_idx, old_idx in enumerate(idx):
coords.append(self.coords[self.layout[old_idx]].clone())
coords[-1][:, 0] = new_idx
feats.append(self.feats[self.layout[old_idx]])
coords = torch.cat(coords, dim=0).contiguous()
feats = torch.cat(feats, dim=0).contiguous()
return SparseTensor(feats=feats, coords=coords)
def register_spatial_cache(self, key, value) -> None:
"""
Register a spatial cache.
The spatial cache can be anything you want to cache.
Registration and retrieval of the cache are based on the current scale.
"""
scale_key = str(self._scale)
if scale_key not in self._spatial_cache:
self._spatial_cache[scale_key] = {}
self._spatial_cache[scale_key][key] = value
def get_spatial_cache(self, key=None):
"""
Get a spatial cache.
"""
scale_key = str(self._scale)
cur_scale_cache = self._spatial_cache.get(scale_key, {})
if key is None:
return cur_scale_cache
return cur_scale_cache.get(key, None)
def sparse_batch_broadcast(input: SparseTensor, other: torch.Tensor) -> torch.Tensor:
"""
Broadcast a 1D tensor to a sparse tensor along the batch dimension then perform an operation.
Args:
input (torch.Tensor): 1D tensor to broadcast.
target (SparseTensor): Sparse tensor to broadcast to.
op (callable): Operation to perform after broadcasting. Defaults to torch.add.
"""
coords, feats = input.coords, input.feats
broadcasted = torch.zeros_like(feats)
for k in range(input.shape[0]):
broadcasted[input.layout[k]] = other[k]
return broadcasted
def sparse_batch_op(input: SparseTensor, other: torch.Tensor, op: callable = torch.add) -> SparseTensor:
"""
Broadcast a 1D tensor to a sparse tensor along the batch dimension then perform an operation.
Args:
input (torch.Tensor): 1D tensor to broadcast.
target (SparseTensor): Sparse tensor to broadcast to.
op (callable): Operation to perform after broadcasting. Defaults to torch.add.
"""
return input.replace(op(input.feats, sparse_batch_broadcast(input, other)))
def sparse_cat(inputs: List[SparseTensor], dim: int = 0) -> SparseTensor:
"""
Concatenate a list of sparse tensors.
Args:
inputs (List[SparseTensor]): List of sparse tensors to concatenate.
"""
if dim == 0:
start = 0
coords = []
for input in inputs:
coords.append(input.coords.clone())
coords[-1][:, 0] += start
start += input.shape[0]
coords = torch.cat(coords, dim=0)
feats = torch.cat([input.feats for input in inputs], dim=0)
output = SparseTensor(
coords=coords,
feats=feats,
)
else:
feats = torch.cat([input.feats for input in inputs], dim=dim)
output = inputs[0].replace(feats)
return output
def sparse_unbind(input: SparseTensor, dim: int) -> List[SparseTensor]:
"""
Unbind a sparse tensor along a dimension.
Args:
input (SparseTensor): Sparse tensor to unbind.
dim (int): Dimension to unbind.
"""
if dim == 0:
return [input[i] for i in range(input.shape[0])]
else:
feats = input.feats.unbind(dim)
return [input.replace(f) for f in feats]
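A construction sketch for `SparseTensor`, assuming one of the supported backends (spconv or torchsparse) is installed and the module is imported as `sp`, as elsewhere in this repo. Coordinates are `(N, 4)` with a leading batch index and must be contiguous per batch, per the class docstring.

import torch
from trellis.modules import sparse as sp  # assumed import path

coords = torch.tensor([[0, 0, 0, 0],   # batch 0, voxel (0, 0, 0)
                       [0, 1, 2, 3],   # batch 0, voxel (1, 2, 3)
                       [1, 4, 4, 4]],  # batch 1, voxel (4, 4, 4)
                      dtype=torch.int32)
feats = torch.randn(3, 8)              # one 8-dim feature per active voxel

x = sp.SparseTensor(feats=feats, coords=coords)
y = x * 2.0 + 1.0                      # elementwise ops return new SparseTensors
parts = y.unbind(dim=0)                # one SparseTensor per batch element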

View File

@@ -0,0 +1,21 @@
from .. import BACKEND
SPCONV_ALGO = 'auto' # 'auto', 'implicit_gemm', 'native'
def __from_env():
import os
global SPCONV_ALGO
env_spconv_algo = os.environ.get('SPCONV_ALGO')
if env_spconv_algo is not None and env_spconv_algo in ['auto', 'implicit_gemm', 'native']:
SPCONV_ALGO = env_spconv_algo
print(f"[SPARSE][CONV] spconv algo: {SPCONV_ALGO}")
__from_env()
if BACKEND == 'torchsparse':
from .conv_torchsparse import *
elif BACKEND == 'spconv':
from .conv_spconv import *

View File

@@ -0,0 +1,2 @@
from .blocks import *
from .modulated import *

View File

@@ -0,0 +1,151 @@
from typing import *
import torch
import torch.nn as nn
from ..basic import SparseTensor
from ..linear import SparseLinear
from ..nonlinearity import SparseGELU
from ..attention import SparseMultiHeadAttention, SerializeMode
from ...norm import LayerNorm32
class SparseFeedForwardNet(nn.Module):
def __init__(self, channels: int, mlp_ratio: float = 4.0):
super().__init__()
self.mlp = nn.Sequential(
SparseLinear(channels, int(channels * mlp_ratio)),
SparseGELU(approximate="tanh"),
SparseLinear(int(channels * mlp_ratio), channels),
)
def forward(self, x: SparseTensor) -> SparseTensor:
return self.mlp(x)
class SparseTransformerBlock(nn.Module):
"""
Sparse Transformer block (MSA + FFN).
"""
def __init__(
self,
channels: int,
num_heads: int,
mlp_ratio: float = 4.0,
attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "full",
window_size: Optional[int] = None,
shift_sequence: Optional[int] = None,
shift_window: Optional[Tuple[int, int, int]] = None,
serialize_mode: Optional[SerializeMode] = None,
use_checkpoint: bool = False,
use_rope: bool = False,
qk_rms_norm: bool = False,
qkv_bias: bool = True,
ln_affine: bool = False,
):
super().__init__()
self.use_checkpoint = use_checkpoint
self.norm1 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
self.norm2 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
self.attn = SparseMultiHeadAttention(
channels,
num_heads=num_heads,
attn_mode=attn_mode,
window_size=window_size,
shift_sequence=shift_sequence,
shift_window=shift_window,
serialize_mode=serialize_mode,
qkv_bias=qkv_bias,
use_rope=use_rope,
qk_rms_norm=qk_rms_norm,
)
self.mlp = SparseFeedForwardNet(
channels,
mlp_ratio=mlp_ratio,
)
def _forward(self, x: SparseTensor) -> SparseTensor:
h = x.replace(self.norm1(x.feats))
h = self.attn(h)
x = x + h
h = x.replace(self.norm2(x.feats))
h = self.mlp(h)
x = x + h
return x
def forward(self, x: SparseTensor) -> SparseTensor:
if self.use_checkpoint:
return torch.utils.checkpoint.checkpoint(self._forward, x, use_reentrant=False)
else:
return self._forward(x)
class SparseTransformerCrossBlock(nn.Module):
"""
Sparse Transformer cross-attention block (MSA + MCA + FFN).
"""
def __init__(
self,
channels: int,
ctx_channels: int,
num_heads: int,
mlp_ratio: float = 4.0,
attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "full",
window_size: Optional[int] = None,
shift_sequence: Optional[int] = None,
shift_window: Optional[Tuple[int, int, int]] = None,
serialize_mode: Optional[SerializeMode] = None,
use_checkpoint: bool = False,
use_rope: bool = False,
qk_rms_norm: bool = False,
qk_rms_norm_cross: bool = False,
qkv_bias: bool = True,
ln_affine: bool = False,
):
super().__init__()
self.use_checkpoint = use_checkpoint
self.norm1 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
self.norm2 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
self.norm3 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
self.self_attn = SparseMultiHeadAttention(
channels,
num_heads=num_heads,
type="self",
attn_mode=attn_mode,
window_size=window_size,
shift_sequence=shift_sequence,
shift_window=shift_window,
serialize_mode=serialize_mode,
qkv_bias=qkv_bias,
use_rope=use_rope,
qk_rms_norm=qk_rms_norm,
)
self.cross_attn = SparseMultiHeadAttention(
channels,
ctx_channels=ctx_channels,
num_heads=num_heads,
type="cross",
attn_mode="full",
qkv_bias=qkv_bias,
qk_rms_norm=qk_rms_norm_cross,
)
self.mlp = SparseFeedForwardNet(
channels,
mlp_ratio=mlp_ratio,
)
def _forward(self, x: SparseTensor, mod: torch.Tensor, context: torch.Tensor):
h = x.replace(self.norm1(x.feats))
h = self.self_attn(h)
x = x + h
h = x.replace(self.norm2(x.feats))
h = self.cross_attn(h, context)
x = x + h
h = x.replace(self.norm3(x.feats))
h = self.mlp(h)
x = x + h
return x
def forward(self, x: SparseTensor, context: torch.Tensor):
if self.use_checkpoint:
return torch.utils.checkpoint.checkpoint(self._forward, x, context, use_reentrant=False)
else:
return self._forward(x, context)

View File

@@ -0,0 +1,2 @@
from .blocks import *
from .modulated import *

View File

@@ -0,0 +1,182 @@
from typing import *
import torch
import torch.nn as nn
from ..attention import MultiHeadAttention
from ..norm import LayerNorm32
class AbsolutePositionEmbedder(nn.Module):
"""
Embeds spatial positions into vector representations.
"""
def __init__(self, channels: int, in_channels: int = 3):
super().__init__()
self.channels = channels
self.in_channels = in_channels
self.freq_dim = channels // in_channels // 2
self.freqs = torch.arange(self.freq_dim, dtype=torch.float32) / self.freq_dim
self.freqs = 1.0 / (10000 ** self.freqs)
def _sin_cos_embedding(self, x: torch.Tensor) -> torch.Tensor:
"""
Create sinusoidal position embeddings.
Args:
x: a 1-D Tensor of N indices
Returns:
an (N, D) Tensor of positional embeddings.
"""
self.freqs = self.freqs.to(x.device)
out = torch.outer(x, self.freqs)
out = torch.cat([torch.sin(out), torch.cos(out)], dim=-1)
return out
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Args:
x (torch.Tensor): (N, D) tensor of spatial positions
"""
N, D = x.shape
assert D == self.in_channels, "Input dimension must match number of input channels"
embed = self._sin_cos_embedding(x.reshape(-1))
embed = embed.reshape(N, -1)
if embed.shape[1] < self.channels:
embed = torch.cat([embed, torch.zeros(N, self.channels - embed.shape[1], device=embed.device)], dim=-1)
return embed
class FeedForwardNet(nn.Module):
def __init__(self, channels: int, mlp_ratio: float = 4.0):
super().__init__()
self.mlp = nn.Sequential(
nn.Linear(channels, int(channels * mlp_ratio)),
nn.GELU(approximate="tanh"),
nn.Linear(int(channels * mlp_ratio), channels),
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.mlp(x)
class TransformerBlock(nn.Module):
"""
Transformer block (MSA + FFN).
"""
def __init__(
self,
channels: int,
num_heads: int,
mlp_ratio: float = 4.0,
attn_mode: Literal["full", "windowed"] = "full",
window_size: Optional[int] = None,
shift_window: Optional[int] = None,
use_checkpoint: bool = False,
use_rope: bool = False,
qk_rms_norm: bool = False,
qkv_bias: bool = True,
ln_affine: bool = False,
):
super().__init__()
self.use_checkpoint = use_checkpoint
self.norm1 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
self.norm2 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
self.attn = MultiHeadAttention(
channels,
num_heads=num_heads,
attn_mode=attn_mode,
window_size=window_size,
shift_window=shift_window,
qkv_bias=qkv_bias,
use_rope=use_rope,
qk_rms_norm=qk_rms_norm,
)
self.mlp = FeedForwardNet(
channels,
mlp_ratio=mlp_ratio,
)
def _forward(self, x: torch.Tensor) -> torch.Tensor:
h = self.norm1(x)
h = self.attn(h)
x = x + h
h = self.norm2(x)
h = self.mlp(h)
x = x + h
return x
def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.use_checkpoint:
return torch.utils.checkpoint.checkpoint(self._forward, x, use_reentrant=False)
else:
return self._forward(x)
class TransformerCrossBlock(nn.Module):
"""
Transformer cross-attention block (MSA + MCA + FFN).
"""
def __init__(
self,
channels: int,
ctx_channels: int,
num_heads: int,
mlp_ratio: float = 4.0,
attn_mode: Literal["full", "windowed"] = "full",
window_size: Optional[int] = None,
shift_window: Optional[Tuple[int, int, int]] = None,
use_checkpoint: bool = False,
use_rope: bool = False,
qk_rms_norm: bool = False,
qk_rms_norm_cross: bool = False,
qkv_bias: bool = True,
ln_affine: bool = False,
):
super().__init__()
self.use_checkpoint = use_checkpoint
self.norm1 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
self.norm2 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
self.norm3 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
self.self_attn = MultiHeadAttention(
channels,
num_heads=num_heads,
type="self",
attn_mode=attn_mode,
window_size=window_size,
shift_window=shift_window,
qkv_bias=qkv_bias,
use_rope=use_rope,
qk_rms_norm=qk_rms_norm,
)
self.cross_attn = MultiHeadAttention(
channels,
ctx_channels=ctx_channels,
num_heads=num_heads,
type="cross",
attn_mode="full",
qkv_bias=qkv_bias,
qk_rms_norm=qk_rms_norm_cross,
)
self.mlp = FeedForwardNet(
channels,
mlp_ratio=mlp_ratio,
)
def _forward(self, x: torch.Tensor, context: torch.Tensor):
h = self.norm1(x)
h = self.self_attn(h)
x = x + h
h = self.norm2(x)
h = self.cross_attn(h, context)
x = x + h
h = self.norm3(x)
h = self.mlp(h)
x = x + h
return x
def forward(self, x: torch.Tensor, context: torch.Tensor):
if self.use_checkpoint:
return torch.utils.checkpoint.checkpoint(self._forward, x, context, use_reentrant=False)
else:
return self._forward(x, context)

View File

@@ -0,0 +1,25 @@
from . import samplers
from .trellis_image_to_3d import TrellisImageTo3DPipeline
from .trellis_text_to_3d import TrellisTextTo3DPipeline
def from_pretrained(path: str):
"""
Load a pipeline from a model folder or the Hugging Face Hub.
Args:
path: The path to the model. Can be either a local path or a Hugging Face model name.
"""
import os
import json
is_local = os.path.exists(f"{path}/pipeline.json")
if is_local:
config_file = f"{path}/pipeline.json"
else:
from huggingface_hub import hf_hub_download
config_file = hf_hub_download(path, "pipeline.json")
with open(config_file, 'r') as f:
config = json.load(f)
return globals()[config['name']].from_pretrained(path)
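Typical usage of this dispatcher, assuming a folder (local, or a Hugging Face Hub repo id; the path below is hypothetical) containing a pipeline.json whose name field matches one of the pipeline classes imported above:

from trellis import pipelines

pipeline = pipelines.from_pretrained("path/to/model_folder")  # hypothetical path
pipeline.cuda()  # pipelines expose .cuda()/.cpu()/.to() (see base.py below)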

68
trellis/pipelines/base.py Normal file

@@ -0,0 +1,68 @@
from typing import *
import torch
import torch.nn as nn
from .. import models
class Pipeline:
"""
A base class for pipelines.
"""
def __init__(
self,
models: Optional[dict[str, nn.Module]] = None,
):
if models is None:
return
self.models = models
for model in self.models.values():
model.eval()
@staticmethod
def from_pretrained(path: str) -> "Pipeline":
"""
Load a pretrained pipeline.
"""
import os
import json
is_local = os.path.exists(f"{path}/pipeline.json")
if is_local:
config_file = f"{path}/pipeline.json"
else:
from huggingface_hub import hf_hub_download
config_file = hf_hub_download(path, "pipeline.json")
with open(config_file, 'r') as f:
args = json.load(f)['args']
_models = {}
for k, v in args['models'].items():
try:
_models[k] = models.from_pretrained(f"{path}/{v}")
except Exception:  # fall back to loading v as a standalone model path or hub name
_models[k] = models.from_pretrained(v)
new_pipeline = Pipeline(_models)
new_pipeline._pretrained_args = args
return new_pipeline
@property
def device(self) -> torch.device:
for model in self.models.values():
if hasattr(model, 'device'):
return model.device
for model in self.models.values():
if hasattr(model, 'parameters'):
return next(model.parameters()).device
raise RuntimeError("No device found.")
def to(self, device: torch.device) -> None:
for model in self.models.values():
model.to(device)
def cuda(self) -> None:
self.to(torch.device("cuda"))
def cpu(self) -> None:
self.to(torch.device("cpu"))


@@ -0,0 +1,2 @@
from .base import Sampler
from .flow_euler import FlowEulerSampler, FlowEulerCfgSampler, FlowEulerGuidanceIntervalSampler


@@ -0,0 +1,20 @@
from typing import *
from abc import ABC, abstractmethod
class Sampler(ABC):
"""
A base class for samplers.
"""
@abstractmethod
def sample(
self,
model,
**kwargs
):
"""
Sample from a model.
"""
pass


@@ -0,0 +1,12 @@
from typing import *
class ClassifierFreeGuidanceSamplerMixin:
"""
A mixin class for samplers that apply classifier-free guidance.
"""
def _inference_model(self, model, x_t, t, cond, neg_cond, cfg_strength, **kwargs):
pred = super()._inference_model(model, x_t, t, cond, **kwargs)
neg_pred = super()._inference_model(model, x_t, t, neg_cond, **kwargs)
return (1 + cfg_strength) * pred - cfg_strength * neg_pred
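The returned combination (1 + s) * pred - s * neg_pred is algebraically the same as extrapolating from the negative prediction through the conditional one, pred + s * (pred - neg_pred). A quick numeric check:

import torch

cfg_strength = 3.0
pred = torch.tensor([1.0, 2.0])      # conditional prediction
neg_pred = torch.tensor([0.5, 1.0])  # negative / unconditional prediction

guided = (1 + cfg_strength) * pred - cfg_strength * neg_pred
assert torch.allclose(guided, pred + cfg_strength * (pred - neg_pred))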

31
trellis/renderers/__init__.py Executable file

@@ -0,0 +1,31 @@
import importlib
__attributes = {
'OctreeRenderer': 'octree_renderer',
'GaussianRenderer': 'gaussian_render',
'MeshRenderer': 'mesh_renderer',
}
__submodules = []
__all__ = list(__attributes.keys()) + __submodules
def __getattr__(name):
if name not in globals():
if name in __attributes:
module_name = __attributes[name]
module = importlib.import_module(f".{module_name}", __name__)
globals()[name] = getattr(module, name)
elif name in __submodules:
module = importlib.import_module(f".{name}", __name__)
globals()[name] = module
else:
raise AttributeError(f"module {__name__} has no attribute {name}")
return globals()[name]
# For Pylance
if __name__ == '__main__':
from .octree_renderer import OctreeRenderer
from .gaussian_render import GaussianRenderer
from .mesh_renderer import MeshRenderer
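This is PEP 562's module-level __getattr__ used for lazy imports: a submodule is loaded only on first attribute access, then cached in globals(); trellis/trainers/__init__.py below follows the same pattern. A minimal standalone version, assuming a hypothetical package layout:

# lazy_pkg/__init__.py -- hypothetical package using the same lazy-import trick
import importlib

_attributes = {'Renderer': 'renderer_impl'}  # attribute name -> submodule

def __getattr__(name):
    if name in _attributes:
        module = importlib.import_module(f".{_attributes[name]}", __name__)
        globals()[name] = getattr(module, name)  # cache for later accesses
        return globals()[name]
    raise AttributeError(f"module {__name__} has no attribute {name}")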


@@ -0,0 +1,4 @@
from .radiance_field import Strivec
from .octree import DfsOctree as Octree
from .gaussian import Gaussian
from .mesh import MeshExtractResult


@@ -0,0 +1 @@
from .gaussian_model import Gaussian


@@ -0,0 +1 @@
from .cube2mesh import SparseFeatures2Mesh, MeshExtractResult


@@ -0,0 +1 @@
from .octree_dfs import DfsOctree


@@ -0,0 +1 @@
from .strivec import Strivec


@@ -0,0 +1,63 @@
import importlib
__attributes = {
'BasicTrainer': 'basic',
'SparseStructureVaeTrainer': 'vae.sparse_structure_vae',
'SLatVaeGaussianTrainer': 'vae.structured_latent_vae_gaussian',
'SLatVaeRadianceFieldDecoderTrainer': 'vae.structured_latent_vae_rf_dec',
'SLatVaeMeshDecoderTrainer': 'vae.structured_latent_vae_mesh_dec',
'FlowMatchingTrainer': 'flow_matching.flow_matching',
'FlowMatchingCFGTrainer': 'flow_matching.flow_matching',
'TextConditionedFlowMatchingCFGTrainer': 'flow_matching.flow_matching',
'ImageConditionedFlowMatchingCFGTrainer': 'flow_matching.flow_matching',
'SparseFlowMatchingTrainer': 'flow_matching.sparse_flow_matching',
'SparseFlowMatchingCFGTrainer': 'flow_matching.sparse_flow_matching',
'TextConditionedSparseFlowMatchingCFGTrainer': 'flow_matching.sparse_flow_matching',
'ImageConditionedSparseFlowMatchingCFGTrainer': 'flow_matching.sparse_flow_matching',
}
__submodules = []
__all__ = list(__attributes.keys()) + __submodules
def __getattr__(name):
if name not in globals():
if name in __attributes:
module_name = __attributes[name]
module = importlib.import_module(f".{module_name}", __name__)
globals()[name] = getattr(module, name)
elif name in __submodules:
module = importlib.import_module(f".{name}", __name__)
globals()[name] = module
else:
raise AttributeError(f"module {__name__} has no attribute {name}")
return globals()[name]
# For Pylance
if __name__ == '__main__':
from .basic import BasicTrainer
from .vae.sparse_structure_vae import SparseStructureVaeTrainer
from .vae.structured_latent_vae_gaussian import SLatVaeGaussianTrainer
from .vae.structured_latent_vae_rf_dec import SLatVaeRadianceFieldDecoderTrainer
from .vae.structured_latent_vae_mesh_dec import SLatVaeMeshDecoderTrainer
from .flow_matching.flow_matching import (
FlowMatchingTrainer,
FlowMatchingCFGTrainer,
TextConditionedFlowMatchingCFGTrainer,
ImageConditionedFlowMatchingCFGTrainer,
)
from .flow_matching.sparse_flow_matching import (
SparseFlowMatchingTrainer,
SparseFlowMatchingCFGTrainer,
TextConditionedSparseFlowMatchingCFGTrainer,
ImageConditionedSparseFlowMatchingCFGTrainer,
)

451
trellis/trainers/base.py Normal file

@@ -0,0 +1,451 @@
from abc import abstractmethod
import os
import time
import json
import torch
import torch.distributed as dist
from torch.utils.data import DataLoader
import numpy as np
from torchvision import utils
from torch.utils.tensorboard import SummaryWriter
from .utils import *
from ..utils.general_utils import *
from ..utils.data_utils import recursive_to_device, cycle, ResumableSampler
class Trainer:
"""
Base class for training.
"""
def __init__(self,
models,
dataset,
*,
output_dir,
load_dir,
step,
max_steps,
batch_size=None,
batch_size_per_gpu=None,
batch_split=None,
optimizer={},
lr_scheduler=None,
elastic=None,
grad_clip=None,
ema_rate=0.9999,
fp16_mode='inflat_all',
fp16_scale_growth=1e-3,
finetune_ckpt=None,
log_param_stats=False,
prefetch_data=True,
i_print=1000,
i_log=500,
i_sample=10000,
i_save=10000,
i_ddpcheck=10000,
**kwargs
):
assert batch_size is not None or batch_size_per_gpu is not None, 'Either batch_size or batch_size_per_gpu must be specified.'
self.models = models
self.dataset = dataset
self.batch_split = batch_split if batch_split is not None else 1
self.max_steps = max_steps
self.optimizer_config = optimizer
self.lr_scheduler_config = lr_scheduler
self.elastic_controller_config = elastic
self.grad_clip = grad_clip
self.ema_rate = [ema_rate] if isinstance(ema_rate, float) else ema_rate
self.fp16_mode = fp16_mode
self.fp16_scale_growth = fp16_scale_growth
self.log_param_stats = log_param_stats
self.prefetch_data = prefetch_data
if self.prefetch_data:
self._data_prefetched = None
self.output_dir = output_dir
self.i_print = i_print
self.i_log = i_log
self.i_sample = i_sample
self.i_save = i_save
self.i_ddpcheck = i_ddpcheck
if dist.is_initialized():
# Multi-GPU params
self.world_size = dist.get_world_size()
self.rank = dist.get_rank()
self.local_rank = dist.get_rank() % torch.cuda.device_count()
self.is_master = self.rank == 0
else:
# Single-GPU params
self.world_size = 1
self.rank = 0
self.local_rank = 0
self.is_master = True
self.batch_size = batch_size if batch_size_per_gpu is None else batch_size_per_gpu * self.world_size
self.batch_size_per_gpu = batch_size_per_gpu if batch_size_per_gpu is not None else batch_size // self.world_size
assert self.batch_size % self.world_size == 0, 'Batch size must be divisible by the number of GPUs.'
assert self.batch_size_per_gpu % self.batch_split == 0, 'Batch size per GPU must be divisible by batch split.'
self.init_models_and_more(**kwargs)
self.prepare_dataloader(**kwargs)
# Load checkpoint
self.step = 0
if load_dir is not None and step is not None:
self.load(load_dir, step)
elif finetune_ckpt is not None:
self.finetune_from(finetune_ckpt)
if self.is_master:
os.makedirs(os.path.join(self.output_dir, 'ckpts'), exist_ok=True)
os.makedirs(os.path.join(self.output_dir, 'samples'), exist_ok=True)
self.writer = SummaryWriter(os.path.join(self.output_dir, 'tb_logs'))
if self.world_size > 1:
self.check_ddp()
if self.is_master:
print('\n\nTrainer initialized.')
print(self)
@property
def device(self):
for _, model in self.models.items():
if hasattr(model, 'device'):
return model.device
return next(list(self.models.values())[0].parameters()).device
@abstractmethod
def init_models_and_more(self, **kwargs):
"""
Initialize models and more.
"""
pass
def prepare_dataloader(self, **kwargs):
"""
Prepare dataloader.
"""
self.data_sampler = ResumableSampler(
self.dataset,
shuffle=True,
)
self.dataloader = DataLoader(
self.dataset,
batch_size=self.batch_size_per_gpu,
num_workers=int(np.ceil(os.cpu_count() / torch.cuda.device_count())),
pin_memory=True,
drop_last=True,
persistent_workers=True,
collate_fn=self.dataset.collate_fn if hasattr(self.dataset, 'collate_fn') else None,
sampler=self.data_sampler,
)
self.data_iterator = cycle(self.dataloader)
@abstractmethod
def load(self, load_dir, step=0):
"""
Load a checkpoint.
Should be called by all processes.
"""
pass
@abstractmethod
def save(self):
"""
Save a checkpoint.
Should be called only by the rank 0 process.
"""
pass
@abstractmethod
def finetune_from(self, finetune_ckpt):
"""
Finetune from a checkpoint.
Should be called by all processes.
"""
pass
@abstractmethod
def run_snapshot(self, num_samples, batch_size=4, verbose=False, **kwargs):
"""
Run a snapshot of the model.
"""
pass
@torch.no_grad()
def visualize_sample(self, sample):
"""
Convert a sample to an image.
"""
if hasattr(self.dataset, 'visualize_sample'):
return self.dataset.visualize_sample(sample)
else:
return sample
@torch.no_grad()
def snapshot_dataset(self, num_samples=100):
"""
Sample images from the dataset.
"""
dataloader = torch.utils.data.DataLoader(
self.dataset,
batch_size=num_samples,
num_workers=0,
shuffle=True,
collate_fn=self.dataset.collate_fn if hasattr(self.dataset, 'collate_fn') else None,
)
data = next(iter(dataloader))
data = recursive_to_device(data, self.device)
vis = self.visualize_sample(data)
if isinstance(vis, dict):
save_cfg = [(f'dataset_{k}', v) for k, v in vis.items()]
else:
save_cfg = [('dataset', vis)]
for name, image in save_cfg:
utils.save_image(
image,
os.path.join(self.output_dir, 'samples', f'{name}.jpg'),
nrow=int(np.sqrt(num_samples)),
normalize=True,
value_range=self.dataset.value_range,
)
@torch.no_grad()
def snapshot(self, suffix=None, num_samples=64, batch_size=4, verbose=False):
"""
Sample images from the model.
NOTE: This function should be called by all processes.
"""
if self.is_master:
print(f'\nSampling {num_samples} images...', end='')
if suffix is None:
suffix = f'step{self.step:07d}'
# Assign tasks
num_samples_per_process = int(np.ceil(num_samples / self.world_size))
samples = self.run_snapshot(num_samples_per_process, batch_size=batch_size, verbose=verbose)
# Preprocess images
for key in list(samples.keys()):
if samples[key]['type'] == 'sample':
vis = self.visualize_sample(samples[key]['value'])
if isinstance(vis, dict):
for k, v in vis.items():
samples[f'{key}_{k}'] = {'value': v, 'type': 'image'}
del samples[key]
else:
samples[key] = {'value': vis, 'type': 'image'}
# Gather results
if self.world_size > 1:
for key in samples.keys():
samples[key]['value'] = samples[key]['value'].contiguous()
if self.is_master:
all_images = [torch.empty_like(samples[key]['value']) for _ in range(self.world_size)]
else:
all_images = []
dist.gather(samples[key]['value'], all_images, dst=0)
if self.is_master:
samples[key]['value'] = torch.cat(all_images, dim=0)[:num_samples]
# Save images
if self.is_master:
os.makedirs(os.path.join(self.output_dir, 'samples', suffix), exist_ok=True)
for key in samples.keys():
if samples[key]['type'] == 'image':
utils.save_image(
samples[key]['value'],
os.path.join(self.output_dir, 'samples', suffix, f'{key}_{suffix}.jpg'),
nrow=int(np.sqrt(num_samples)),
normalize=True,
value_range=self.dataset.value_range,
)
elif samples[key]['type'] == 'number':
vmin = samples[key]['value'].min()
vmax = samples[key]['value'].max()
images = (samples[key]['value'] - vmin) / (vmax - vmin)
images = utils.make_grid(
images,
nrow=int(np.sqrt(num_samples)),
normalize=False,
)
save_image_with_notes(
images,
os.path.join(self.output_dir, 'samples', suffix, f'{key}_{suffix}.jpg'),
notes=f'{key} min: {vmin}, max: {vmax}',
)
if self.is_master:
print(' Done.')
@abstractmethod
def update_ema(self):
"""
Update exponential moving average.
Should only be called by the rank 0 process.
"""
pass
@abstractmethod
def check_ddp(self):
"""
Check if DDP is working properly.
Should be called by all processes.
"""
pass
@abstractmethod
def training_losses(self, **mb_data):
"""
Compute training losses.
"""
pass
def load_data(self):
"""
Load data.
"""
if self.prefetch_data:
if self._data_prefetched is None:
self._data_prefetched = recursive_to_device(next(self.data_iterator), self.device, non_blocking=True)
data = self._data_prefetched
self._data_prefetched = recursive_to_device(next(self.data_iterator), self.device, non_blocking=True)
else:
data = recursive_to_device(next(self.data_iterator), self.device, non_blocking=True)
# if the data is a dict, split it into batch_split micro-batch dicts for gradient accumulation
if isinstance(data, dict):
if self.batch_split == 1:
data_list = [data]
else:
batch_size = list(data.values())[0].shape[0]
data_list = [
{k: v[i * batch_size // self.batch_split:(i + 1) * batch_size // self.batch_split] for k, v in data.items()}
for i in range(self.batch_split)
]
elif isinstance(data, list):
data_list = data
else:
raise ValueError('Data must be a dict or a list of dicts.')
return data_list
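For gradient accumulation, a dict batch is sliced along the leading dimension into batch_split equal micro-batches, exactly as in load_data above. A sketch with hypothetical tensors:

import torch

data = {"x": torch.randn(8, 3), "y": torch.randn(8)}  # hypothetical batch
batch_split = 2
batch_size = list(data.values())[0].shape[0]

data_list = [
    {k: v[i * batch_size // batch_split:(i + 1) * batch_size // batch_split]
     for k, v in data.items()}
    for i in range(batch_split)
]
assert all(d["x"].shape[0] == 4 for d in data_list)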
@abstractmethod
def run_step(self, data_list):
"""
Run a training step.
"""
pass
def run(self):
"""
Run training.
"""
if self.is_master:
print('\nStarting training...')
self.snapshot_dataset()
if self.step == 0:
self.snapshot(suffix='init')
else: # resume
self.snapshot(suffix=f'resume_step{self.step:07d}')
log = []
time_last_print = 0.0
time_elapsed = 0.0
while self.step < self.max_steps:
time_start = time.time()
data_list = self.load_data()
step_log = self.run_step(data_list)
time_end = time.time()
time_elapsed += time_end - time_start
self.step += 1
# Print progress
if self.is_master and self.step % self.i_print == 0:
speed = self.i_print / (time_elapsed - time_last_print) * 3600
columns = [
f'Step: {self.step}/{self.max_steps} ({self.step / self.max_steps * 100:.2f}%)',
f'Elapsed: {time_elapsed / 3600:.2f} h',
f'Speed: {speed:.2f} steps/h',
f'ETA: {(self.max_steps - self.step) / speed:.2f} h',
]
print(' | '.join([c.ljust(25) for c in columns]), flush=True)
time_last_print = time_elapsed
# Check ddp
if self.world_size > 1 and self.i_ddpcheck is not None and self.step % self.i_ddpcheck == 0:
self.check_ddp()
# Sample images
if self.step % self.i_sample == 0:
self.snapshot()
if self.is_master:
log.append((self.step, {}))
# Log time
log[-1][1]['time'] = {
'step': time_end - time_start,
'elapsed': time_elapsed,
}
# Log losses
if step_log is not None:
log[-1][1].update(step_log)
# Log scale
if self.fp16_mode == 'amp':
log[-1][1]['scale'] = self.scaler.get_scale()
elif self.fp16_mode == 'inflat_all':
log[-1][1]['log_scale'] = self.log_scale
# Save log
if self.step % self.i_log == 0:
## save to log file
log_str = '\n'.join([
f'{step}: {json.dumps(record)}' for step, record in log
])
with open(os.path.join(self.output_dir, 'log.txt'), 'a') as log_file:
log_file.write(log_str + '\n')
# log averaged scalars to tensorboard
log_show = [l for _, l in log if not dict_any(l, lambda x: np.isnan(x))]
log_show = dict_reduce(log_show, lambda x: np.mean(x))
log_show = dict_flatten(log_show, sep='/')
for key, value in log_show.items():
self.writer.add_scalar(key, value, self.step)
log = []
# Save checkpoint
if self.step % self.i_save == 0:
self.save()
if self.is_master:
self.snapshot(suffix='final')
self.writer.close()
print('Training finished.')
def profile(self, wait=2, warmup=3, active=5):
"""
Profile the training loop.
"""
with torch.profiler.profile(
schedule=torch.profiler.schedule(wait=wait, warmup=warmup, active=active, repeat=1),
on_trace_ready=torch.profiler.tensorboard_trace_handler(os.path.join(self.output_dir, 'profile')),
profile_memory=True,
with_stack=True,
) as prof:
for _ in range(wait + warmup + active):
self.run_step(self.load_data())
prof.step()
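The schedule maps step indices onto profiler actions: the first wait steps are skipped, the next warmup steps run without recording, and the active steps are captured and flushed by the TensorBoard trace handler. The same schedule can be inspected outside the trainer:

import torch

sched = torch.profiler.schedule(wait=2, warmup=3, active=5, repeat=1)
actions = [sched(i) for i in range(10)]
# Steps 0-1 -> NONE, 2-4 -> WARMUP, 5-8 -> RECORD, 9 -> RECORD_AND_SAVE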

438
trellis/trainers/basic.py Normal file

@@ -0,0 +1,438 @@
import os
import copy
from functools import partial
from contextlib import nullcontext
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
import numpy as np
from .utils import *
from .base import Trainer
from ..utils.general_utils import *
from ..utils.dist_utils import *
from ..utils import grad_clip_utils, elastic_utils
class BasicTrainer(Trainer):
"""
Trainer for basic training loop.
Args:
models (dict[str, nn.Module]): Models to train.
dataset (torch.utils.data.Dataset): Dataset.
output_dir (str): Output directory.
load_dir (str): Load directory.
step (int): Step to load.
batch_size (int): Batch size.
batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored.
batch_split (int): Split batch with gradient accumulation.
max_steps (int): Max steps.
optimizer (dict): Optimizer config.
lr_scheduler (dict): Learning rate scheduler config.
elastic (dict): Elastic memory management config.
grad_clip (float or dict): Gradient clip config.
ema_rate (float or list): Exponential moving average rates.
fp16_mode (str): FP16 mode.
- None: No FP16.
- 'inflat_all': Hold an inflated FP32 master copy of all params.
- 'amp': Automatic mixed precision.
fp16_scale_growth (float): Per-step growth of the log loss scale in 'inflat_all' mode.
finetune_ckpt (dict): Finetune checkpoint.
log_param_stats (bool): Log parameter stats.
i_print (int): Print interval.
i_log (int): Log interval.
i_sample (int): Sample interval.
i_save (int): Save interval.
i_ddpcheck (int): DDP check interval.
"""
def __str__(self):
lines = []
lines.append(self.__class__.__name__)
lines.append(f' - Models:')
for name, model in self.models.items():
lines.append(f' - {name}: {model.__class__.__name__}')
lines.append(f' - Dataset: {indent(str(self.dataset), 2)}')
lines.append(f' - Dataloader:')
lines.append(f' - Sampler: {self.dataloader.sampler.__class__.__name__}')
lines.append(f' - Num workers: {self.dataloader.num_workers}')
lines.append(f' - Number of steps: {self.max_steps}')
lines.append(f' - Number of GPUs: {self.world_size}')
lines.append(f' - Batch size: {self.batch_size}')
lines.append(f' - Batch size per GPU: {self.batch_size_per_gpu}')
lines.append(f' - Batch split: {self.batch_split}')
lines.append(f' - Optimizer: {self.optimizer.__class__.__name__}')
lines.append(f' - Learning rate: {self.optimizer.param_groups[0]["lr"]}')
if self.lr_scheduler_config is not None:
lines.append(f' - LR scheduler: {self.lr_scheduler.__class__.__name__}')
if self.elastic_controller_config is not None:
lines.append(f' - Elastic memory: {indent(str(self.elastic_controller), 2)}')
if self.grad_clip is not None:
lines.append(f' - Gradient clip: {indent(str(self.grad_clip), 2)}')
lines.append(f' - EMA rate: {self.ema_rate}')
lines.append(f' - FP16 mode: {self.fp16_mode}')
return '\n'.join(lines)
def init_models_and_more(self, **kwargs):
"""
Initialize models and more.
"""
if self.world_size > 1:
# Prepare distributed data parallel
self.training_models = {
name: DDP(
model,
device_ids=[self.local_rank],
output_device=self.local_rank,
bucket_cap_mb=128,
find_unused_parameters=False
)
for name, model in self.models.items()
}
else:
self.training_models = self.models
# Build master params
self.model_params = sum(
[[p for p in model.parameters() if p.requires_grad] for model in self.models.values()]
, [])
if self.fp16_mode == 'amp':
self.master_params = self.model_params
self.scaler = torch.GradScaler()
elif self.fp16_mode == 'inflat_all':
self.master_params = make_master_params(self.model_params)
self.log_scale = 20.0
elif self.fp16_mode is None:
self.master_params = self.model_params
else:
raise NotImplementedError(f'FP16 mode {self.fp16_mode} is not implemented.')
# Build EMA params
if self.is_master:
self.ema_params = [copy.deepcopy(self.master_params) for _ in self.ema_rate]
# Initialize optimizer
if hasattr(torch.optim, self.optimizer_config['name']):
self.optimizer = getattr(torch.optim, self.optimizer_config['name'])(self.master_params, **self.optimizer_config['args'])
else:
self.optimizer = globals()[self.optimizer_config['name']](self.master_params, **self.optimizer_config['args'])
# Initialize learning rate scheduler
if self.lr_scheduler_config is not None:
if hasattr(torch.optim.lr_scheduler, self.lr_scheduler_config['name']):
self.lr_scheduler = getattr(torch.optim.lr_scheduler, self.lr_scheduler_config['name'])(self.optimizer, **self.lr_scheduler_config['args'])
else:
self.lr_scheduler = globals()[self.lr_scheduler_config['name']](self.optimizer, **self.lr_scheduler_config['args'])
# Initialize elastic memory controller
if self.elastic_controller_config is not None:
assert any([isinstance(model, (elastic_utils.ElasticModule, elastic_utils.ElasticModuleMixin)) for model in self.models.values()]), \
'No elastic module found in models, please inherit from ElasticModule or ElasticModuleMixin'
self.elastic_controller = getattr(elastic_utils, self.elastic_controller_config['name'])(**self.elastic_controller_config['args'])
for model in self.models.values():
if isinstance(model, (elastic_utils.ElasticModule, elastic_utils.ElasticModuleMixin)):
model.register_memory_controller(self.elastic_controller)
# Initialize gradient clipper
if self.grad_clip is not None:
if isinstance(self.grad_clip, (float, int)):
self.grad_clip = float(self.grad_clip)
else:
self.grad_clip = getattr(grad_clip_utils, self.grad_clip['name'])(**self.grad_clip['args'])
def _master_params_to_state_dicts(self, master_params):
"""
Convert master params to dict of state_dicts.
"""
if self.fp16_mode == 'inflat_all':
master_params = unflatten_master_params(self.model_params, master_params)
state_dicts = {name: model.state_dict() for name, model in self.models.items()}
master_params_names = sum(
[[(name, n) for n, p in model.named_parameters() if p.requires_grad] for name, model in self.models.items()]
, [])
for i, (model_name, param_name) in enumerate(master_params_names):
state_dicts[model_name][param_name] = master_params[i]
return state_dicts
def _state_dicts_to_master_params(self, master_params, state_dicts):
"""
Convert a state_dict to master params.
"""
master_params_names = sum(
[[(name, n) for n, p in model.named_parameters() if p.requires_grad] for name, model in self.models.items()]
, [])
params = [state_dicts[name][param_name] for name, param_name in master_params_names]
if self.fp16_mode == 'inflat_all':
model_params_to_master_params(params, master_params)
else:
for i, param in enumerate(params):
master_params[i].data.copy_(param.data)
def load(self, load_dir, step=0):
"""
Load a checkpoint.
Should be called by all processes.
"""
if self.is_master:
print(f'\nLoading checkpoint from step {step}...', end='')
model_ckpts = {}
for name, model in self.models.items():
model_ckpt = torch.load(read_file_dist(os.path.join(load_dir, 'ckpts', f'{name}_step{step:07d}.pt')), map_location=self.device, weights_only=True)
model_ckpts[name] = model_ckpt
model.load_state_dict(model_ckpt)
if self.fp16_mode == 'inflat_all':
model.convert_to_fp16()
self._state_dicts_to_master_params(self.master_params, model_ckpts)
del model_ckpts
if self.is_master:
for i, ema_rate in enumerate(self.ema_rate):
ema_ckpts = {}
for name, model in self.models.items():
ema_ckpt = torch.load(os.path.join(load_dir, 'ckpts', f'{name}_ema{ema_rate}_step{step:07d}.pt'), map_location=self.device, weights_only=True)
ema_ckpts[name] = ema_ckpt
self._state_dicts_to_master_params(self.ema_params[i], ema_ckpts)
del ema_ckpts
misc_ckpt = torch.load(read_file_dist(os.path.join(load_dir, 'ckpts', f'misc_step{step:07d}.pt')), map_location=torch.device('cpu'), weights_only=False)
self.optimizer.load_state_dict(misc_ckpt['optimizer'])
self.step = misc_ckpt['step']
self.data_sampler.load_state_dict(misc_ckpt['data_sampler'])
if self.fp16_mode == 'amp':
self.scaler.load_state_dict(misc_ckpt['scaler'])
elif self.fp16_mode == 'inflat_all':
self.log_scale = misc_ckpt['log_scale']
if self.lr_scheduler_config is not None:
self.lr_scheduler.load_state_dict(misc_ckpt['lr_scheduler'])
if self.elastic_controller_config is not None:
self.elastic_controller.load_state_dict(misc_ckpt['elastic_controller'])
if self.grad_clip is not None and not isinstance(self.grad_clip, float):
self.grad_clip.load_state_dict(misc_ckpt['grad_clip'])
del misc_ckpt
if self.world_size > 1:
dist.barrier()
if self.is_master:
print(' Done.')
if self.world_size > 1:
self.check_ddp()
def save(self):
"""
Save a checkpoint.
Should be called only by the rank 0 process.
"""
assert self.is_master, 'save() should be called only by the rank 0 process.'
print(f'\nSaving checkpoint at step {self.step}...', end='')
model_ckpts = self._master_params_to_state_dicts(self.master_params)
for name, model_ckpt in model_ckpts.items():
torch.save(model_ckpt, os.path.join(self.output_dir, 'ckpts', f'{name}_step{self.step:07d}.pt'))
for i, ema_rate in enumerate(self.ema_rate):
ema_ckpts = self._master_params_to_state_dicts(self.ema_params[i])
for name, ema_ckpt in ema_ckpts.items():
torch.save(ema_ckpt, os.path.join(self.output_dir, 'ckpts', f'{name}_ema{ema_rate}_step{self.step:07d}.pt'))
misc_ckpt = {
'optimizer': self.optimizer.state_dict(),
'step': self.step,
'data_sampler': self.data_sampler.state_dict(),
}
if self.fp16_mode == 'amp':
misc_ckpt['scaler'] = self.scaler.state_dict()
elif self.fp16_mode == 'inflat_all':
misc_ckpt['log_scale'] = self.log_scale
if self.lr_scheduler_config is not None:
misc_ckpt['lr_scheduler'] = self.lr_scheduler.state_dict()
if self.elastic_controller_config is not None:
misc_ckpt['elastic_controller'] = self.elastic_controller.state_dict()
if self.grad_clip is not None and not isinstance(self.grad_clip, float):
misc_ckpt['grad_clip'] = self.grad_clip.state_dict()
torch.save(misc_ckpt, os.path.join(self.output_dir, 'ckpts', f'misc_step{self.step:07d}.pt'))
print(' Done.')
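The on-disk layout this produces is one state dict per model, one per (model, EMA rate) pair, and a single misc file, all keyed by the zero-padded step. A sketch of the naming scheme with hypothetical values:

step, ema_rate, name = 10000, 0.9999, "denoiser"  # hypothetical values
print(f"{name}_step{step:07d}.pt")                # model weights
print(f"{name}_ema{ema_rate}_step{step:07d}.pt")  # EMA weights
print(f"misc_step{step:07d}.pt")                  # optimizer/sampler/scaler state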
def finetune_from(self, finetune_ckpt):
"""
Finetune from a checkpoint.
Should be called by all processes.
"""
if self.is_master:
print('\nFinetuning from:')
for name, path in finetune_ckpt.items():
print(f' - {name}: {path}')
model_ckpts = {}
for name, model in self.models.items():
model_state_dict = model.state_dict()
if name in finetune_ckpt:
model_ckpt = torch.load(read_file_dist(finetune_ckpt[name]), map_location=self.device, weights_only=True)
for k, v in model_ckpt.items():
if model_ckpt[k].shape != model_state_dict[k].shape:
if self.is_master:
print(f'Warning: {k} shape mismatch, {model_ckpt[k].shape} vs {model_state_dict[k].shape}, skipped.')
model_ckpt[k] = model_state_dict[k]
model_ckpts[name] = model_ckpt
model.load_state_dict(model_ckpt)
if self.fp16_mode == 'inflat_all':
model.convert_to_fp16()
else:
if self.is_master:
print(f'Warning: {name} not found in finetune_ckpt, skipped.')
model_ckpts[name] = model_state_dict
self._state_dicts_to_master_params(self.master_params, model_ckpts)
del model_ckpts
if self.world_size > 1:
dist.barrier()
if self.is_master:
print('Done.')
if self.world_size > 1:
self.check_ddp()
def update_ema(self):
"""
Update exponential moving average.
Should only be called by the rank 0 process.
"""
assert self.is_master, 'update_ema() should be called only by the rank 0 process.'
for i, ema_rate in enumerate(self.ema_rate):
for master_param, ema_param in zip(self.master_params, self.ema_params[i]):
ema_param.detach().mul_(ema_rate).add_(master_param, alpha=1.0 - ema_rate)
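Each EMA copy follows the exponential moving average ema <- rate * ema + (1 - rate) * param, applied in place. A one-step numeric check:

import torch

rate = 0.9999
ema, param = torch.zeros(3), torch.ones(3)
ema.mul_(rate).add_(param, alpha=1 - rate)
assert torch.allclose(ema, torch.full((3,), 1e-4))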
def check_ddp(self):
"""
Check if DDP is working properly.
Should be called by all processes.
"""
if self.is_master:
print('\nPerforming DDP check...')
if self.is_master:
print('Checking if parameters are consistent across processes...')
dist.barrier()
try:
for p in self.master_params:
# split to avoid OOM
for i in range(0, p.numel(), 10000000):
sub_size = min(10000000, p.numel() - i)
sub_p = p.detach().view(-1)[i:i+sub_size]
# gather from all processes
sub_p_gather = [torch.empty_like(sub_p) for _ in range(self.world_size)]
dist.all_gather(sub_p_gather, sub_p)
# check if equal
assert all([torch.equal(sub_p, sub_p_gather[i]) for i in range(self.world_size)]), 'parameters are not consistent across processes'
except AssertionError as e:
if self.is_master:
print(f'\n\033[91mError: {e}\033[0m')
print('DDP check failed.')
raise e
dist.barrier()
if self.is_master:
print('Done.')
def run_step(self, data_list):
"""
Run a training step.
"""
step_log = {'loss': {}, 'status': {}}
amp_context = partial(torch.autocast, device_type='cuda') if self.fp16_mode == 'amp' else nullcontext
elastic_controller_context = self.elastic_controller.record if self.elastic_controller_config is not None else nullcontext
# Train
losses = []
statuses = []
elastic_controller_logs = []
zero_grad(self.model_params)
for i, mb_data in enumerate(data_list):
## sync at the end of each batch split
sync_contexts = [self.training_models[name].no_sync for name in self.training_models] if i != len(data_list) - 1 and self.world_size > 1 else [nullcontext]
with nested_contexts(*sync_contexts), elastic_controller_context():
with amp_context():
loss, status = self.training_losses(**mb_data)
l = loss['loss'] / len(data_list)
## backward
if self.fp16_mode == 'amp':
self.scaler.scale(l).backward()
elif self.fp16_mode == 'inflat_all':
scaled_l = l * (2 ** self.log_scale)
scaled_l.backward()
else:
l.backward()
## log
losses.append(dict_foreach(loss, lambda x: x.item() if isinstance(x, torch.Tensor) else x))
statuses.append(dict_foreach(status, lambda x: x.item() if isinstance(x, torch.Tensor) else x))
if self.elastic_controller_config is not None:
elastic_controller_logs.append(self.elastic_controller.log())
## gradient clip
if self.grad_clip is not None:
if self.fp16_mode == 'amp':
self.scaler.unscale_(self.optimizer)
elif self.fp16_mode == 'inflat_all':
# in 'inflat_all' mode the master params are a single flattened fp32 tensor,
# so unscaling its .grad unscales every parameter at once
model_grads_to_master_grads(self.model_params, self.master_params)
self.master_params[0].grad.mul_(1.0 / (2 ** self.log_scale))
if isinstance(self.grad_clip, float):
grad_norm = torch.nn.utils.clip_grad_norm_(self.master_params, self.grad_clip)
else:
grad_norm = self.grad_clip(self.master_params)
if torch.isfinite(grad_norm):
statuses[-1]['grad_norm'] = grad_norm.item()
## step
if self.fp16_mode == 'amp':
prev_scale = self.scaler.get_scale()
self.scaler.step(self.optimizer)
self.scaler.update()
elif self.fp16_mode == 'inflat_all':
prev_scale = 2 ** self.log_scale
if all(p.grad.isfinite().all() for p in self.model_params):
if self.grad_clip is None:
model_grads_to_master_grads(self.model_params, self.master_params)
self.master_params[0].grad.mul_(1.0 / (2 ** self.log_scale))
self.optimizer.step()
master_params_to_model_params(self.model_params, self.master_params)
self.log_scale += self.fp16_scale_growth
else:
self.log_scale -= 1
else:
prev_scale = 1.0
if all(p.grad.isfinite().all() for p in self.model_params):
self.optimizer.step()
else:
print('\n\033[93mWarning: NaN detected in gradients. Skipping update.\033[0m')
## adjust learning rate
if self.lr_scheduler_config is not None:
statuses[-1]['lr'] = self.lr_scheduler.get_last_lr()[0]
self.lr_scheduler.step()
# Logs
step_log['loss'] = dict_reduce(losses, lambda x: np.mean(x))
step_log['status'] = dict_reduce(statuses, lambda x: np.mean(x), special_func={'min': lambda x: np.min(x), 'max': lambda x: np.max(x)})
if self.elastic_controller_config is not None:
step_log['elastic'] = dict_reduce(elastic_controller_logs, lambda x: np.mean(x))
if self.grad_clip is not None:
step_log['grad_clip'] = self.grad_clip if isinstance(self.grad_clip, float) else self.grad_clip.log()
# Check grad and norm of each param
if self.log_param_stats:
param_norms = {}
param_grads = {}
for model_name, model in self.models.items():
for name, param in model.named_parameters():
if param.requires_grad:
param_norms[f'{model_name}.{name}'] = param.norm().item()
if param.grad is not None and torch.isfinite(param.grad).all():
param_grads[f'{model_name}.{name}'] = param.grad.norm().item() / prev_scale
step_log['param_norms'] = param_norms
step_log['param_grads'] = param_grads
# Update exponential moving average
if self.is_master:
self.update_ema()
return step_log
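Under fp16_mode='inflat_all', the loss is multiplied by 2**log_scale before backward; after a step with finite gradients the log scale grows by fp16_scale_growth, while an overflow skips the update and backs the scale off by a full power of two. A sketch of that schedule in isolation:

log_scale, growth = 20.0, 1e-3  # defaults used by this trainer

def update(log_scale, grads_finite):
    # Grow slowly on success, halve the scale (log_scale - 1) on overflow.
    return log_scale + growth if grads_finite else log_scale - 1

log_scale = update(log_scale, grads_finite=True)   # 20.001
log_scale = update(log_scale, grads_finite=False)  # 19.001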


@@ -0,0 +1,59 @@
import torch
import numpy as np
from ....utils.general_utils import dict_foreach
from ....pipelines import samplers
class ClassifierFreeGuidanceMixin:
def __init__(self, *args, p_uncond: float = 0.1, **kwargs):
super().__init__(*args, **kwargs)
self.p_uncond = p_uncond
def get_cond(self, cond, neg_cond=None, **kwargs):
"""
Get the conditioning data.
"""
assert neg_cond is not None, "neg_cond must be provided for classifier-free guidance"
if self.p_uncond > 0:
# randomly drop the class label
def get_batch_size(cond):
if isinstance(cond, torch.Tensor):
return cond.shape[0]
elif isinstance(cond, list):
return len(cond)
else:
raise ValueError(f"Unsupported type of cond: {type(cond)}")
ref_cond = cond if not isinstance(cond, dict) else cond[list(cond.keys())[0]]
B = get_batch_size(ref_cond)
def select(cond, neg_cond, mask):
if isinstance(cond, torch.Tensor):
mask = torch.tensor(mask, device=cond.device).reshape(-1, *[1] * (cond.ndim - 1))
return torch.where(mask, neg_cond, cond)
elif isinstance(cond, list):
return [nc if m else c for c, nc, m in zip(cond, neg_cond, mask)]
else:
raise ValueError(f"Unsupported type of cond: {type(cond)}")
mask = list(np.random.rand(B) < self.p_uncond)
if not isinstance(cond, dict):
cond = select(cond, neg_cond, mask)
else:
cond = dict_foreach([cond, neg_cond], lambda x: select(x[0], x[1], mask))
return cond
def get_inference_cond(self, cond, neg_cond=None, **kwargs):
"""
Get the conditioning data for inference.
"""
assert neg_cond is not None, "neg_cond must be provided for classifier-free guidance"
return {'cond': cond, 'neg_cond': neg_cond, **kwargs}
def get_sampler(self, **kwargs) -> samplers.FlowEulerCfgSampler:
"""
Get the sampler for the diffusion process.
"""
return samplers.FlowEulerCfgSampler(self.sigma_min)
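During training, each sample's condition is independently replaced by the negative condition with probability p_uncond, so the same network learns both conditional and unconditional predictions. A tensor-only sketch of the masking performed by get_cond:

import numpy as np
import torch

p_uncond = 0.1
cond = torch.randn(8, 16)      # hypothetical per-sample conditioning
neg_cond = torch.zeros(8, 16)  # stand-in "unconditional" embedding

mask = torch.tensor(np.random.rand(8) < p_uncond).reshape(-1, 1)
mixed = torch.where(mask, neg_cond, cond)  # ~p_uncond of rows dropped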

0
trellis/utils/__init__.py Executable file