1
This commit is contained in:
437
.gitignore
vendored
Normal file
437
.gitignore
vendored
Normal file
@@ -0,0 +1,437 @@
|
|||||||
|
## Ignore Visual Studio temporary files, build results, and
|
||||||
|
## files generated by popular Visual Studio add-ons.
|
||||||
|
##
|
||||||
|
## Get latest from https://github.com/github/gitignore/blob/main/VisualStudio.gitignore
|
||||||
|
|
||||||
|
# User-specific files
|
||||||
|
*.rsuser
|
||||||
|
*.suo
|
||||||
|
*.user
|
||||||
|
*.userosscache
|
||||||
|
*.sln.docstates
|
||||||
|
|
||||||
|
# User-specific files (MonoDevelop/Xamarin Studio)
|
||||||
|
*.userprefs
|
||||||
|
|
||||||
|
# Mono auto generated files
|
||||||
|
mono_crash.*
|
||||||
|
|
||||||
|
# Build results
|
||||||
|
[Dd]ebug/
|
||||||
|
[Dd]ebugPublic/
|
||||||
|
[Rr]elease/
|
||||||
|
[Rr]eleases/
|
||||||
|
x64/
|
||||||
|
x86/
|
||||||
|
[Ww][Ii][Nn]32/
|
||||||
|
[Aa][Rr][Mm]/
|
||||||
|
[Aa][Rr][Mm]64/
|
||||||
|
bld/
|
||||||
|
[Bb]in/
|
||||||
|
[Oo]bj/
|
||||||
|
[Ll]og/
|
||||||
|
[Ll]ogs/
|
||||||
|
|
||||||
|
# Visual Studio 2015/2017 cache/options directory
|
||||||
|
.vs/
|
||||||
|
# Uncomment if you have tasks that create the project's static files in wwwroot
|
||||||
|
#wwwroot/
|
||||||
|
|
||||||
|
# Visual Studio 2017 auto generated files
|
||||||
|
Generated\ Files/
|
||||||
|
|
||||||
|
# MSTest test Results
|
||||||
|
[Tt]est[Rr]esult*/
|
||||||
|
[Bb]uild[Ll]og.*
|
||||||
|
|
||||||
|
# NUnit
|
||||||
|
*.VisualState.xml
|
||||||
|
TestResult.xml
|
||||||
|
nunit-*.xml
|
||||||
|
|
||||||
|
# Build Results of an ATL Project
|
||||||
|
[Dd]ebugPS/
|
||||||
|
[Rr]eleasePS/
|
||||||
|
dlldata.c
|
||||||
|
|
||||||
|
# Benchmark Results
|
||||||
|
BenchmarkDotNet.Artifacts/
|
||||||
|
|
||||||
|
# .NET Core
|
||||||
|
project.lock.json
|
||||||
|
project.fragment.lock.json
|
||||||
|
artifacts/
|
||||||
|
|
||||||
|
# ASP.NET Scaffolding
|
||||||
|
ScaffoldingReadMe.txt
|
||||||
|
|
||||||
|
# StyleCop
|
||||||
|
StyleCopReport.xml
|
||||||
|
|
||||||
|
# Files built by Visual Studio
|
||||||
|
*_i.c
|
||||||
|
*_p.c
|
||||||
|
*_h.h
|
||||||
|
*.ilk
|
||||||
|
*.meta
|
||||||
|
*.obj
|
||||||
|
*.iobj
|
||||||
|
*.pch
|
||||||
|
*.pdb
|
||||||
|
*.ipdb
|
||||||
|
*.pgc
|
||||||
|
*.pgd
|
||||||
|
*.rsp
|
||||||
|
*.sbr
|
||||||
|
*.tlb
|
||||||
|
*.tli
|
||||||
|
*.tlh
|
||||||
|
*.tmp
|
||||||
|
*.tmp_proj
|
||||||
|
*_wpftmp.csproj
|
||||||
|
*.log
|
||||||
|
*.tlog
|
||||||
|
*.vspscc
|
||||||
|
*.vssscc
|
||||||
|
.builds
|
||||||
|
*.pidb
|
||||||
|
*.svclog
|
||||||
|
*.scc
|
||||||
|
|
||||||
|
# Chutzpah Test files
|
||||||
|
_Chutzpah*
|
||||||
|
|
||||||
|
# Visual C++ cache files
|
||||||
|
ipch/
|
||||||
|
*.aps
|
||||||
|
*.ncb
|
||||||
|
*.opendb
|
||||||
|
*.opensdf
|
||||||
|
*.sdf
|
||||||
|
*.cachefile
|
||||||
|
*.VC.db
|
||||||
|
*.VC.VC.opendb
|
||||||
|
|
||||||
|
# Visual Studio profiler
|
||||||
|
*.psess
|
||||||
|
*.vsp
|
||||||
|
*.vspx
|
||||||
|
*.sap
|
||||||
|
|
||||||
|
# Visual Studio Trace Files
|
||||||
|
*.e2e
|
||||||
|
|
||||||
|
# TFS 2012 Local Workspace
|
||||||
|
$tf/
|
||||||
|
|
||||||
|
# Guidance Automation Toolkit
|
||||||
|
*.gpState
|
||||||
|
|
||||||
|
# ReSharper is a .NET coding add-in
|
||||||
|
_ReSharper*/
|
||||||
|
*.[Rr]e[Ss]harper
|
||||||
|
*.DotSettings.user
|
||||||
|
|
||||||
|
# TeamCity is a build add-in
|
||||||
|
_TeamCity*
|
||||||
|
|
||||||
|
# DotCover is a Code Coverage Tool
|
||||||
|
*.dotCover
|
||||||
|
|
||||||
|
# AxoCover is a Code Coverage Tool
|
||||||
|
.axoCover/*
|
||||||
|
!.axoCover/settings.json
|
||||||
|
|
||||||
|
# Coverlet is a free, cross platform Code Coverage Tool
|
||||||
|
coverage*.json
|
||||||
|
coverage*.xml
|
||||||
|
coverage*.info
|
||||||
|
|
||||||
|
# Visual Studio code coverage results
|
||||||
|
*.coverage
|
||||||
|
*.coveragexml
|
||||||
|
|
||||||
|
# NCrunch
|
||||||
|
_NCrunch_*
|
||||||
|
.*crunch*.local.xml
|
||||||
|
nCrunchTemp_*
|
||||||
|
|
||||||
|
# MightyMoose
|
||||||
|
*.mm.*
|
||||||
|
AutoTest.Net/
|
||||||
|
|
||||||
|
# Web workbench (sass)
|
||||||
|
.sass-cache/
|
||||||
|
|
||||||
|
# Installshield output folder
|
||||||
|
[Ee]xpress/
|
||||||
|
|
||||||
|
# DocProject is a documentation generator add-in
|
||||||
|
DocProject/buildhelp/
|
||||||
|
DocProject/Help/*.HxT
|
||||||
|
DocProject/Help/*.HxC
|
||||||
|
DocProject/Help/*.hhc
|
||||||
|
DocProject/Help/*.hhk
|
||||||
|
DocProject/Help/*.hhp
|
||||||
|
DocProject/Help/Html2
|
||||||
|
DocProject/Help/html
|
||||||
|
|
||||||
|
# Click-Once directory
|
||||||
|
publish/
|
||||||
|
|
||||||
|
# Publish Web Output
|
||||||
|
*.[Pp]ublish.xml
|
||||||
|
*.azurePubxml
|
||||||
|
# Note: Comment the next line if you want to checkin your web deploy settings,
|
||||||
|
# but database connection strings (with potential passwords) will be unencrypted
|
||||||
|
*.pubxml
|
||||||
|
*.publishproj
|
||||||
|
|
||||||
|
# Microsoft Azure Web App publish settings. Comment the next line if you want to
|
||||||
|
# checkin your Azure Web App publish settings, but sensitive information contained
|
||||||
|
# in these scripts will be unencrypted
|
||||||
|
PublishScripts/
|
||||||
|
|
||||||
|
# NuGet Packages
|
||||||
|
*.nupkg
|
||||||
|
# NuGet Symbol Packages
|
||||||
|
*.snupkg
|
||||||
|
# The packages folder can be ignored because of Package Restore
|
||||||
|
**/[Pp]ackages/*
|
||||||
|
# except build/, which is used as an MSBuild target.
|
||||||
|
!**/[Pp]ackages/build/
|
||||||
|
# Uncomment if necessary however generally it will be regenerated when needed
|
||||||
|
#!**/[Pp]ackages/repositories.config
|
||||||
|
# NuGet v3's project.json files produces more ignorable files
|
||||||
|
*.nuget.props
|
||||||
|
*.nuget.targets
|
||||||
|
|
||||||
|
# Microsoft Azure Build Output
|
||||||
|
csx/
|
||||||
|
*.build.csdef
|
||||||
|
|
||||||
|
# Microsoft Azure Emulator
|
||||||
|
ecf/
|
||||||
|
rcf/
|
||||||
|
|
||||||
|
# Windows Store app package directories and files
|
||||||
|
AppPackages/
|
||||||
|
BundleArtifacts/
|
||||||
|
Package.StoreAssociation.xml
|
||||||
|
_pkginfo.txt
|
||||||
|
*.appx
|
||||||
|
*.appxbundle
|
||||||
|
*.appxupload
|
||||||
|
|
||||||
|
# Visual Studio cache files
|
||||||
|
# files ending in .cache can be ignored
|
||||||
|
*.[Cc]ache
|
||||||
|
# but keep track of directories ending in .cache
|
||||||
|
!?*.[Cc]ache/
|
||||||
|
|
||||||
|
# Others
|
||||||
|
ClientBin/
|
||||||
|
~$*
|
||||||
|
*~
|
||||||
|
*.dbmdl
|
||||||
|
*.dbproj.schemaview
|
||||||
|
*.jfm
|
||||||
|
*.pfx
|
||||||
|
*.publishsettings
|
||||||
|
orleans.codegen.cs
|
||||||
|
|
||||||
|
# Including strong name files can present a security risk
|
||||||
|
# (https://github.com/github/gitignore/pull/2483#issue-259490424)
|
||||||
|
#*.snk
|
||||||
|
|
||||||
|
# Since there are multiple workflows, uncomment next line to ignore bower_components
|
||||||
|
# (https://github.com/github/gitignore/pull/1529#issuecomment-104372622)
|
||||||
|
#bower_components/
|
||||||
|
|
||||||
|
# RIA/Silverlight projects
|
||||||
|
Generated_Code/
|
||||||
|
|
||||||
|
# Backup & report files from converting an old project file
|
||||||
|
# to a newer Visual Studio version. Backup files are not needed,
|
||||||
|
# because we have git ;-)
|
||||||
|
_UpgradeReport_Files/
|
||||||
|
Backup*/
|
||||||
|
UpgradeLog*.XML
|
||||||
|
UpgradeLog*.htm
|
||||||
|
ServiceFabricBackup/
|
||||||
|
*.rptproj.bak
|
||||||
|
|
||||||
|
# SQL Server files
|
||||||
|
*.mdf
|
||||||
|
*.ldf
|
||||||
|
*.ndf
|
||||||
|
|
||||||
|
# Business Intelligence projects
|
||||||
|
*.rdl.data
|
||||||
|
*.bim.layout
|
||||||
|
*.bim_*.settings
|
||||||
|
*.rptproj.rsuser
|
||||||
|
*- [Bb]ackup.rdl
|
||||||
|
*- [Bb]ackup ([0-9]).rdl
|
||||||
|
*- [Bb]ackup ([0-9][0-9]).rdl
|
||||||
|
|
||||||
|
# Microsoft Fakes
|
||||||
|
FakesAssemblies/
|
||||||
|
|
||||||
|
# GhostDoc plugin setting file
|
||||||
|
*.GhostDoc.xml
|
||||||
|
|
||||||
|
# Node.js Tools for Visual Studio
|
||||||
|
.ntvs_analysis.dat
|
||||||
|
node_modules/
|
||||||
|
|
||||||
|
# Visual Studio 6 build log
|
||||||
|
*.plg
|
||||||
|
|
||||||
|
# Visual Studio 6 workspace options file
|
||||||
|
*.opt
|
||||||
|
|
||||||
|
# Visual Studio 6 auto-generated workspace file (contains which files were open etc.)
|
||||||
|
*.vbw
|
||||||
|
|
||||||
|
# Visual Studio 6 auto-generated project file (contains which files were open etc.)
|
||||||
|
*.vbp
|
||||||
|
|
||||||
|
# Visual Studio 6 workspace and project file (working project files containing files to include in project)
|
||||||
|
*.dsw
|
||||||
|
*.dsp
|
||||||
|
|
||||||
|
# Visual Studio 6 technical files
|
||||||
|
*.ncb
|
||||||
|
*.aps
|
||||||
|
|
||||||
|
# Visual Studio LightSwitch build output
|
||||||
|
**/*.HTMLClient/GeneratedArtifacts
|
||||||
|
**/*.DesktopClient/GeneratedArtifacts
|
||||||
|
**/*.DesktopClient/ModelManifest.xml
|
||||||
|
**/*.Server/GeneratedArtifacts
|
||||||
|
**/*.Server/ModelManifest.xml
|
||||||
|
_Pvt_Extensions
|
||||||
|
|
||||||
|
# Paket dependency manager
|
||||||
|
.paket/paket.exe
|
||||||
|
paket-files/
|
||||||
|
|
||||||
|
# FAKE - F# Make
|
||||||
|
.fake/
|
||||||
|
|
||||||
|
# CodeRush personal settings
|
||||||
|
.cr/personal
|
||||||
|
|
||||||
|
# Python Tools for Visual Studio (PTVS)
|
||||||
|
__pycache__/
|
||||||
|
*.pyc
|
||||||
|
|
||||||
|
# Cake - Uncomment if you are using it
|
||||||
|
# tools/**
|
||||||
|
# !tools/packages.config
|
||||||
|
|
||||||
|
# Tabs Studio
|
||||||
|
*.tss
|
||||||
|
|
||||||
|
# Telerik's JustMock configuration file
|
||||||
|
*.jmconfig
|
||||||
|
|
||||||
|
# BizTalk build output
|
||||||
|
*.btp.cs
|
||||||
|
*.btm.cs
|
||||||
|
*.odx.cs
|
||||||
|
*.xsd.cs
|
||||||
|
|
||||||
|
# OpenCover UI analysis results
|
||||||
|
OpenCover/
|
||||||
|
|
||||||
|
# Azure Stream Analytics local run output
|
||||||
|
ASALocalRun/
|
||||||
|
|
||||||
|
# MSBuild Binary and Structured Log
|
||||||
|
*.binlog
|
||||||
|
|
||||||
|
# NVidia Nsight GPU debugger configuration file
|
||||||
|
*.nvuser
|
||||||
|
|
||||||
|
# MFractors (Xamarin productivity tool) working folder
|
||||||
|
.mfractor/
|
||||||
|
|
||||||
|
# Local History for Visual Studio
|
||||||
|
.localhistory/
|
||||||
|
|
||||||
|
# Visual Studio History (VSHistory) files
|
||||||
|
.vshistory/
|
||||||
|
|
||||||
|
# BeatPulse healthcheck temp database
|
||||||
|
healthchecksdb
|
||||||
|
|
||||||
|
# Backup folder for Package Reference Convert tool in Visual Studio 2017
|
||||||
|
MigrationBackup/
|
||||||
|
|
||||||
|
# Ionide (cross platform F# VS Code tools) working folder
|
||||||
|
.ionide/
|
||||||
|
|
||||||
|
# Fody - auto-generated XML schema
|
||||||
|
FodyWeavers.xsd
|
||||||
|
|
||||||
|
# VS Code files for those working on multiple tools
|
||||||
|
.vscode/*
|
||||||
|
!.vscode/settings.json
|
||||||
|
!.vscode/tasks.json
|
||||||
|
!.vscode/launch.json
|
||||||
|
!.vscode/extensions.json
|
||||||
|
*.code-workspace
|
||||||
|
|
||||||
|
# Local History for Visual Studio Code
|
||||||
|
.history/
|
||||||
|
|
||||||
|
# Windows Installer files from build outputs
|
||||||
|
*.cab
|
||||||
|
*.msi
|
||||||
|
*.msix
|
||||||
|
*.msm
|
||||||
|
*.msp
|
||||||
|
|
||||||
|
# JetBrains Rider
|
||||||
|
*.sln.iml
|
||||||
|
|
||||||
|
|
||||||
|
*.glb
|
||||||
|
*.ply
|
||||||
|
*.mtl
|
||||||
|
*.step
|
||||||
|
*.mp4
|
||||||
|
*.svg
|
||||||
|
*.png
|
||||||
|
*.ipynb
|
||||||
|
*.zip
|
||||||
|
|
||||||
|
# 模型文件(重点忽略)
|
||||||
|
*.pth
|
||||||
|
*.ckpt
|
||||||
|
*.bin
|
||||||
|
*.pt
|
||||||
|
*.h5
|
||||||
|
|
||||||
|
# 压缩包
|
||||||
|
*.zip
|
||||||
|
*.rar
|
||||||
|
*.7z
|
||||||
|
*.tar.gz
|
||||||
|
|
||||||
|
# 图片/视频(非必要则忽略,必要则后续用LFS)
|
||||||
|
*.png
|
||||||
|
*.svg
|
||||||
|
*.jpg
|
||||||
|
*.jpeg
|
||||||
|
*.mp4
|
||||||
|
*.avi
|
||||||
|
|
||||||
|
# 日志/缓存/临时文件
|
||||||
|
*.log
|
||||||
|
*.tmp
|
||||||
|
__pycache__/
|
||||||
|
*.pyc
|
||||||
|
.DS_Store
|
||||||
3
.gitmodules
vendored
Normal file
3
.gitmodules
vendored
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
[submodule "trellis/representations/mesh/flexicubes"]
|
||||||
|
path = trellis/representations/mesh/flexicubes
|
||||||
|
url = https://github.com/MaxtirError/FlexiCubes.git
|
||||||
10
.idea/.gitignore
generated
vendored
Normal file
10
.idea/.gitignore
generated
vendored
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
# 默认忽略的文件
|
||||||
|
/shelf/
|
||||||
|
/workspace.xml
|
||||||
|
# 已忽略包含查询文件的默认文件夹
|
||||||
|
/queries/
|
||||||
|
# Datasource local storage ignored files
|
||||||
|
/dataSources/
|
||||||
|
/dataSources.local.xml
|
||||||
|
# 基于编辑器的 HTTP 客户端请求
|
||||||
|
/httpRequests/
|
||||||
68
0_glb_to_obj.py
Normal file
68
0_glb_to_obj.py
Normal file
@@ -0,0 +1,68 @@
|
|||||||
|
import bpy
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
glb_input = "trellis_out/sample.ply"
|
||||||
|
obj_output = "obj_output/sample.obj"
|
||||||
|
|
||||||
|
|
||||||
|
def parse_args():
|
||||||
|
argv = sys.argv
|
||||||
|
argv = argv[argv.index("--") + 1:] if "--" in argv else []
|
||||||
|
p = argparse.ArgumentParser()
|
||||||
|
p.add_argument("-i", "--input", default=True)
|
||||||
|
p.add_argument("-o", "--output", required=True)
|
||||||
|
return p.parse_args(argv)
|
||||||
|
|
||||||
|
|
||||||
|
def clean():
|
||||||
|
bpy.ops.object.select_all(action='SELECT')
|
||||||
|
bpy.ops.object.delete(use_global=False)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
args = parse_args()
|
||||||
|
clean()
|
||||||
|
|
||||||
|
bpy.ops.import_scene.gltf(filepath=args.input) # 支持 .glb/.gltf
|
||||||
|
|
||||||
|
# 只保留 mesh,并合并成一个(符合你“单资产”假设)
|
||||||
|
meshes = [o for o in bpy.context.scene.objects if o.type == "MESH"]
|
||||||
|
if not meshes:
|
||||||
|
raise RuntimeError("No mesh objects found in GLB/GLTF.")
|
||||||
|
bpy.ops.object.select_all(action='DESELECT')
|
||||||
|
for o in meshes:
|
||||||
|
o.select_set(True)
|
||||||
|
bpy.context.view_layer.objects.active = meshes[0]
|
||||||
|
if len(meshes) > 1:
|
||||||
|
bpy.ops.object.join()
|
||||||
|
|
||||||
|
obj = bpy.context.view_layer.objects.active
|
||||||
|
obj.select_set(True)
|
||||||
|
|
||||||
|
out_dir = os.path.dirname(args.output)
|
||||||
|
if out_dir:
|
||||||
|
os.makedirs(out_dir, exist_ok=True)
|
||||||
|
|
||||||
|
# bpy.ops.wm.obj_export(
|
||||||
|
# filepath=args.output,
|
||||||
|
# export_selected_objects=True,
|
||||||
|
# export_normals=True,
|
||||||
|
# export_uv=True,
|
||||||
|
# )
|
||||||
|
|
||||||
|
bpy.ops.export_scene.obj(
|
||||||
|
filepath=args.output,
|
||||||
|
use_selection=True, # 只导出选中的对象
|
||||||
|
use_normals=True,
|
||||||
|
use_uvs=True,
|
||||||
|
use_materials=True, # 根据需要改(你原来没导出材质)
|
||||||
|
axis_forward='-Z', # glTF 通常是 -Z forward,Y up,可根据需要调整
|
||||||
|
axis_up='Y',
|
||||||
|
# keep_vertex_order=True, # 可选:保持顶点顺序
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
260
0_remesh.py
Normal file
260
0_remesh.py
Normal file
@@ -0,0 +1,260 @@
|
|||||||
|
import bpy
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
import argparse
|
||||||
|
import tempfile
|
||||||
|
|
||||||
|
|
||||||
|
def clean_scene():
|
||||||
|
bpy.ops.object.select_all(action='SELECT')
|
||||||
|
bpy.ops.object.delete(use_global=False)
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_args_blender(argv):
|
||||||
|
"""Blender 参数在 `--` 之后。"""
|
||||||
|
if "--" in argv:
|
||||||
|
return argv[argv.index("--") + 1:]
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
def _is_glb_gltf(path: str) -> bool:
|
||||||
|
ext = os.path.splitext(path)[1].lower()
|
||||||
|
return ext in [".glb", ".gltf"]
|
||||||
|
|
||||||
|
|
||||||
|
def _is_obj(path: str) -> bool:
|
||||||
|
return os.path.splitext(path)[1].lower() == ".obj"
|
||||||
|
|
||||||
|
|
||||||
|
def import_mesh_any(input_path: str):
|
||||||
|
"""
|
||||||
|
导入 OBJ 或 GLB/GLTF,并将场景中所有 MESH 合并为一个 active mesh 对象返回。
|
||||||
|
"""
|
||||||
|
clean_scene()
|
||||||
|
|
||||||
|
if _is_glb_gltf(input_path):
|
||||||
|
print(f"[import] GLB/GLTF: {input_path}")
|
||||||
|
bpy.ops.import_scene.gltf(filepath=input_path)
|
||||||
|
elif _is_obj(input_path):
|
||||||
|
print(f"[import] OBJ: {input_path}")
|
||||||
|
bpy.ops.wm.obj_import(filepath=input_path)
|
||||||
|
else:
|
||||||
|
raise RuntimeError(f"Unsupported input format: {input_path}")
|
||||||
|
|
||||||
|
meshes = [o for o in bpy.context.scene.objects if o.type == "MESH"]
|
||||||
|
if not meshes:
|
||||||
|
raise RuntimeError("No mesh objects found after import.")
|
||||||
|
|
||||||
|
bpy.ops.object.select_all(action='DESELECT')
|
||||||
|
for o in meshes:
|
||||||
|
o.select_set(True)
|
||||||
|
bpy.context.view_layer.objects.active = meshes[0]
|
||||||
|
|
||||||
|
if len(meshes) > 1:
|
||||||
|
bpy.ops.object.join()
|
||||||
|
|
||||||
|
obj = bpy.context.view_layer.objects.active
|
||||||
|
obj.select_set(True)
|
||||||
|
return obj
|
||||||
|
|
||||||
|
|
||||||
|
def export_obj_selected(output_obj: str):
|
||||||
|
out_dir = os.path.dirname(output_obj)
|
||||||
|
if out_dir:
|
||||||
|
os.makedirs(out_dir, exist_ok=True)
|
||||||
|
|
||||||
|
bpy.ops.wm.obj_export(
|
||||||
|
filepath=output_obj,
|
||||||
|
export_selected_objects=True,
|
||||||
|
export_normals=True,
|
||||||
|
export_uv=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def fix_mesh(obj):
|
||||||
|
"""修复网格使其更符合 QuadriFlow 的输入要求"""
|
||||||
|
bpy.context.view_layer.objects.active = obj
|
||||||
|
obj.select_set(True)
|
||||||
|
|
||||||
|
bpy.ops.object.mode_set(mode='EDIT')
|
||||||
|
bpy.ops.mesh.select_all(action='SELECT')
|
||||||
|
|
||||||
|
bpy.ops.mesh.remove_doubles(threshold=0.001)
|
||||||
|
bpy.ops.mesh.delete_loose()
|
||||||
|
|
||||||
|
bpy.ops.mesh.select_all(action='DESELECT')
|
||||||
|
bpy.ops.mesh.select_non_manifold()
|
||||||
|
|
||||||
|
selected_count = sum(1 for v in obj.data.vertices if v.select)
|
||||||
|
if selected_count > 0:
|
||||||
|
print(f"[fix_mesh] found {selected_count} non-manifold verts, deleting...")
|
||||||
|
bpy.ops.mesh.delete(type='VERT')
|
||||||
|
|
||||||
|
bpy.ops.mesh.select_all(action='SELECT')
|
||||||
|
bpy.ops.mesh.normals_make_consistent(inside=False)
|
||||||
|
|
||||||
|
bpy.ops.mesh.select_all(action='SELECT')
|
||||||
|
bpy.ops.mesh.fill_holes(sides=0)
|
||||||
|
|
||||||
|
bpy.ops.mesh.remove_doubles(threshold=0.001)
|
||||||
|
bpy.ops.mesh.dissolve_degenerate(threshold=0.0001)
|
||||||
|
|
||||||
|
bpy.ops.mesh.select_all(action='SELECT')
|
||||||
|
bpy.ops.mesh.normals_make_consistent(inside=False)
|
||||||
|
|
||||||
|
bpy.ops.object.mode_set(mode='OBJECT')
|
||||||
|
|
||||||
|
|
||||||
|
def quadriflow_remesh_obj(
|
||||||
|
input_obj: str,
|
||||||
|
output_obj: str,
|
||||||
|
face_count: int = 8000,
|
||||||
|
use_mesh_symmetry: bool = False,
|
||||||
|
use_preserve_sharp: bool = False,
|
||||||
|
use_preserve_boundary: bool = True,
|
||||||
|
use_voxel_preprocess: bool = True,
|
||||||
|
voxel_size: float = 0.008,
|
||||||
|
):
|
||||||
|
"""只处理 OBJ 输入(你的原流程)"""
|
||||||
|
obj = import_mesh_any(input_obj) # 这里会走 OBJ import
|
||||||
|
|
||||||
|
print(f"[import] object={obj.name}, verts={len(obj.data.vertices)}, faces={len(obj.data.polygons)}")
|
||||||
|
print("[mesh] fixing...")
|
||||||
|
fix_mesh(obj)
|
||||||
|
|
||||||
|
if use_voxel_preprocess:
|
||||||
|
print(f"[voxel] preprocess voxel_size={voxel_size} ...")
|
||||||
|
voxel_mod = obj.modifiers.new(name="Voxel", type='REMESH')
|
||||||
|
voxel_mod.mode = 'VOXEL'
|
||||||
|
voxel_mod.voxel_size = voxel_size
|
||||||
|
voxel_mod.use_smooth_shade = True
|
||||||
|
voxel_mod.use_remove_disconnected = False
|
||||||
|
bpy.ops.object.modifier_apply(modifier=voxel_mod.name)
|
||||||
|
|
||||||
|
print(f"[voxel] after: verts={len(obj.data.vertices)}, faces={len(obj.data.polygons)}")
|
||||||
|
print("[mesh] fixing after voxel...")
|
||||||
|
fix_mesh(obj)
|
||||||
|
print(f"[mesh] after fix2: verts={len(obj.data.vertices)}, faces={len(obj.data.polygons)}")
|
||||||
|
|
||||||
|
print(f"[quadriflow] target_faces={face_count} ...")
|
||||||
|
bpy.context.view_layer.objects.active = obj
|
||||||
|
obj.select_set(True)
|
||||||
|
|
||||||
|
try:
|
||||||
|
bpy.ops.object.quadriflow_remesh(
|
||||||
|
use_mesh_symmetry=use_mesh_symmetry,
|
||||||
|
use_preserve_sharp=use_preserve_sharp,
|
||||||
|
use_preserve_boundary=use_preserve_boundary,
|
||||||
|
smooth_normals=False,
|
||||||
|
mode='FACES',
|
||||||
|
target_faces=face_count,
|
||||||
|
seed=0
|
||||||
|
)
|
||||||
|
print("[quadriflow] done.")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[quadriflow] error: {e} (continue to export)")
|
||||||
|
|
||||||
|
print(f"[export] {output_obj}")
|
||||||
|
export_obj_selected(output_obj)
|
||||||
|
print(f"[done] verts={len(obj.data.vertices)}, faces={len(obj.data.polygons)}")
|
||||||
|
|
||||||
|
|
||||||
|
def glb_to_tmp_obj(input_glb: str) -> str:
|
||||||
|
"""
|
||||||
|
在 Blender 内把 GLB/GLTF 导入后导出为临时 OBJ(按你的要求“先转OBJ再走后续”)。
|
||||||
|
"""
|
||||||
|
obj = import_mesh_any(input_glb) # 会走 gltf import
|
||||||
|
tmp_obj = tempfile.mktemp(suffix=".obj")
|
||||||
|
print(f"[glb2obj] export tmp obj -> {tmp_obj}")
|
||||||
|
export_obj_selected(tmp_obj)
|
||||||
|
return tmp_obj
|
||||||
|
|
||||||
|
|
||||||
|
def run_pipeline(
|
||||||
|
input_path: str,
|
||||||
|
output_obj: str,
|
||||||
|
face_count: int,
|
||||||
|
mesh_symmetry: bool,
|
||||||
|
preserve_sharp: bool,
|
||||||
|
preserve_boundary: bool,
|
||||||
|
no_voxel_preprocess: bool,
|
||||||
|
voxel_size: float,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
统一入口:
|
||||||
|
- 输入 OBJ:直接 remesh
|
||||||
|
- 输入 GLB/GLTF:先转临时 OBJ,再 remesh
|
||||||
|
"""
|
||||||
|
tmp_obj = None
|
||||||
|
try:
|
||||||
|
if _is_glb_gltf(input_path):
|
||||||
|
tmp_obj = glb_to_tmp_obj(input_path)
|
||||||
|
quadriflow_remesh_obj(
|
||||||
|
input_obj=tmp_obj,
|
||||||
|
output_obj=output_obj,
|
||||||
|
face_count=face_count,
|
||||||
|
use_mesh_symmetry=mesh_symmetry,
|
||||||
|
use_preserve_sharp=preserve_sharp,
|
||||||
|
use_preserve_boundary=preserve_boundary,
|
||||||
|
use_voxel_preprocess=(not no_voxel_preprocess),
|
||||||
|
voxel_size=voxel_size,
|
||||||
|
)
|
||||||
|
elif _is_obj(input_path):
|
||||||
|
quadriflow_remesh_obj(
|
||||||
|
input_obj=input_path,
|
||||||
|
output_obj=output_obj,
|
||||||
|
face_count=face_count,
|
||||||
|
use_mesh_symmetry=mesh_symmetry,
|
||||||
|
use_preserve_sharp=preserve_sharp,
|
||||||
|
use_preserve_boundary=preserve_boundary,
|
||||||
|
use_voxel_preprocess=(not no_voxel_preprocess),
|
||||||
|
voxel_size=voxel_size,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise RuntimeError(f"Unsupported input format: {input_path}")
|
||||||
|
finally:
|
||||||
|
if tmp_obj and os.path.exists(tmp_obj):
|
||||||
|
try:
|
||||||
|
os.remove(tmp_obj)
|
||||||
|
print(f"[cleanup] removed tmp obj: {tmp_obj}")
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(description="GLB/GLTF/OBJ -> (optional tmp OBJ) -> QuadriFlow remesh -> OBJ")
|
||||||
|
parser.add_argument("-i", "--input", default="trellis_out/sample.glb", help="Input file path (.obj/.glb/.gltf).")
|
||||||
|
parser.add_argument("-o", "--output", default="output/test.obj", help="Output OBJ file path.")
|
||||||
|
parser.add_argument("--face_count", type=int, default=8000, help="Target quad face count.")
|
||||||
|
parser.add_argument("--mesh_symmetry", action="store_true", help="Enable mesh symmetry.")
|
||||||
|
parser.add_argument("--preserve_sharp", action="store_true", help="Preserve sharp edges.")
|
||||||
|
parser.add_argument("--preserve_boundary", action="store_true", help="Preserve boundary edges.")
|
||||||
|
parser.add_argument("--no_voxel_preprocess", action="store_true", help="Disable voxel preprocess.")
|
||||||
|
parser.add_argument("--voxel_size", type=float, default=0.008, help="Voxel size for preprocess (smaller=denser).")
|
||||||
|
|
||||||
|
args = parser.parse_args(_parse_args_blender(sys.argv))
|
||||||
|
|
||||||
|
if not os.path.exists(args.input):
|
||||||
|
raise FileNotFoundError(f"Input not found: {args.input}")
|
||||||
|
|
||||||
|
run_pipeline(
|
||||||
|
input_path=args.input,
|
||||||
|
output_obj=args.output,
|
||||||
|
face_count=args.face_count,
|
||||||
|
mesh_symmetry=args.mesh_symmetry,
|
||||||
|
preserve_sharp=args.preserve_sharp,
|
||||||
|
preserve_boundary=args.preserve_boundary,
|
||||||
|
no_voxel_preprocess=args.no_voxel_preprocess,
|
||||||
|
voxel_size=args.voxel_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
try:
|
||||||
|
main()
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[fatal] {e}")
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
traceback.print_exc()
|
||||||
|
raise
|
||||||
423
1_obj_to_step.py
Normal file
423
1_obj_to_step.py
Normal file
@@ -0,0 +1,423 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""
|
||||||
|
Fast OBJ/GLB/GLTF -> STEP converter for FreeCAD (headless-friendly)
|
||||||
|
|
||||||
|
参数说明:
|
||||||
|
tol : 网格→形状拟合公差,越大越快(默认 0.5)
|
||||||
|
sew_tol : Sewing 缝合公差(环境缺少 Sewing 会自动跳过),默认 0.05
|
||||||
|
solid : 1 尝试固化为 Solid(耗时大),0 仅导出壳(默认 0)
|
||||||
|
split : 1 使用 trimesh 分块清理(若环境有 trimesh),0 关闭(默认 1)
|
||||||
|
max_comp : 分块后最多保留的组件数(默认 64)
|
||||||
|
|
||||||
|
注意:
|
||||||
|
- 若 FreeCAD 发行版缺少 MeshPart.meshToShape 或 Part.Sewing,脚本会自动回退/跳过。
|
||||||
|
- 若导出 STEP 仍失败,会兜底导出 .brep,方便后续用 pythonocc/OCP 再转 STEP。
|
||||||
|
"""
|
||||||
|
import secrets
|
||||||
|
import sys, os, time, tempfile
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
import FreeCAD as App
|
||||||
|
import Part, Mesh
|
||||||
|
|
||||||
|
# 尝试加载 MeshPart(某些构建可能没有)
|
||||||
|
try:
|
||||||
|
import MeshPart
|
||||||
|
|
||||||
|
HAS_MESHPART = True
|
||||||
|
except Exception:
|
||||||
|
MeshPart = None
|
||||||
|
HAS_MESHPART = False
|
||||||
|
|
||||||
|
# 可选:trimesh 用于更稳的分块/清理与 GLB/GLTF 转临时 OBJ
|
||||||
|
try:
|
||||||
|
import trimesh
|
||||||
|
|
||||||
|
HAS_TRIMESH = True
|
||||||
|
except Exception:
|
||||||
|
HAS_TRIMESH = False
|
||||||
|
|
||||||
|
|
||||||
|
def log(msg):
|
||||||
|
try:
|
||||||
|
App.Console.PrintMessage(str(msg) + "\n")
|
||||||
|
except Exception:
|
||||||
|
print(msg, flush=True)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------- 网格清理/转换核心 ----------------------
|
||||||
|
|
||||||
|
def _freecad_mesh_quick_clean(fc_mesh):
|
||||||
|
"""FreeCAD 轻量清理:不大改拓扑,提升稳健性。"""
|
||||||
|
for fn in (
|
||||||
|
"removeDuplicatedFacets",
|
||||||
|
"removeDuplicatedPoints",
|
||||||
|
"removeDegeneratedFacets",
|
||||||
|
"removeNonManifoldEdges",
|
||||||
|
):
|
||||||
|
try:
|
||||||
|
getattr(fc_mesh, fn)()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
return fc_mesh
|
||||||
|
|
||||||
|
|
||||||
|
def _mesh_to_shape_fast(fc_mesh, tol=0.5, sew_tol=0.05, to_solid=False):
|
||||||
|
"""
|
||||||
|
Mesh -> B-Rep(更稳健版本):
|
||||||
|
- 若有 MeshPart.meshToShape 优先使用(共面聚合,通常更快)
|
||||||
|
- 否则退回 Part.Shape.makeShapeFromMesh
|
||||||
|
- Sewing 若不可用则跳过(部分构建缺少 Part.Sewing)
|
||||||
|
- 仅在 to_solid=True 且闭壳时尝试 Solidify(耗时较大)
|
||||||
|
"""
|
||||||
|
# 1) mesh -> shape
|
||||||
|
shape = None
|
||||||
|
if HAS_MESHPART and hasattr(MeshPart, "meshToShape"):
|
||||||
|
try:
|
||||||
|
shape = MeshPart.meshToShape(fc_mesh, tol)
|
||||||
|
if isinstance(shape, tuple): # 某些版本返回 (shape, mapping)
|
||||||
|
shape = shape[0]
|
||||||
|
except Exception as e:
|
||||||
|
log(f"meshToShape 失败,退回 makeShapeFromMesh: {e}")
|
||||||
|
|
||||||
|
if shape is None:
|
||||||
|
shape = Part.Shape()
|
||||||
|
shape.makeShapeFromMesh(fc_mesh.Topology, tol)
|
||||||
|
|
||||||
|
# 2) sewing(若 Part.Sewing 不存在则跳过)
|
||||||
|
try:
|
||||||
|
SewingCls = getattr(Part, "Sewing", None)
|
||||||
|
if SewingCls is not None:
|
||||||
|
sew = SewingCls()
|
||||||
|
ok = False
|
||||||
|
for setter in ("tolerance", "SetTolerance"):
|
||||||
|
try:
|
||||||
|
if setter == "tolerance":
|
||||||
|
sew.tolerance = sew_tol
|
||||||
|
else:
|
||||||
|
sew.SetTolerance(sew_tol)
|
||||||
|
ok = True
|
||||||
|
break
|
||||||
|
except Exception:
|
||||||
|
continue
|
||||||
|
|
||||||
|
sew.add(shape)
|
||||||
|
sew.perform()
|
||||||
|
|
||||||
|
new_shape = None
|
||||||
|
if hasattr(sew, "SewedShape"):
|
||||||
|
new_shape = sew.SewedShape
|
||||||
|
try:
|
||||||
|
if hasattr(new_shape, "isNull") and new_shape.isNull():
|
||||||
|
new_shape = None
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
if new_shape is None and hasattr(sew, "sewedShape"):
|
||||||
|
try:
|
||||||
|
new_shape = sew.sewedShape()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
if new_shape is None and hasattr(sew, "sewShape"):
|
||||||
|
try:
|
||||||
|
new_shape = sew.sewShape()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
if new_shape is not None:
|
||||||
|
shape = new_shape
|
||||||
|
else:
|
||||||
|
log("跳过 Sewing:当前 FreeCAD Part 模块缺少 Sewing 绑定")
|
||||||
|
except Exception as e:
|
||||||
|
log(f"Sewing 失败,继续原 shape: {e}")
|
||||||
|
|
||||||
|
# 3) 可选固化
|
||||||
|
if to_solid:
|
||||||
|
try:
|
||||||
|
if not getattr(shape, "Solids", []):
|
||||||
|
shell = Part.Shell(shape.Faces)
|
||||||
|
if hasattr(shell, "isClosed") and shell.isClosed():
|
||||||
|
shape = Part.makeSolid(shell)
|
||||||
|
except Exception as e:
|
||||||
|
log(f"Solidify 失败,保留壳: {e}")
|
||||||
|
|
||||||
|
# 4) 轻量拓扑精简
|
||||||
|
try:
|
||||||
|
shape = shape.removeSplitter()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
try:
|
||||||
|
if hasattr(shape, "fixTolerance"):
|
||||||
|
shape.fixTolerance(sew_tol)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return shape
|
||||||
|
|
||||||
|
|
||||||
|
def _export_shapes_to_step(shapes, step_file_path):
|
||||||
|
"""
|
||||||
|
稳健导出:
|
||||||
|
- 确保输出目录存在
|
||||||
|
- 过滤无效 shape
|
||||||
|
- 尝试合并 Compound(有时更稳定)
|
||||||
|
- Part.export 失败则 Import.export;再不行兜底写 .brep
|
||||||
|
"""
|
||||||
|
out_dir = os.path.dirname(step_file_path) or "."
|
||||||
|
os.makedirs(out_dir, exist_ok=True)
|
||||||
|
|
||||||
|
valid = []
|
||||||
|
for s in shapes:
|
||||||
|
try:
|
||||||
|
if hasattr(s, "isNull") and s.isNull():
|
||||||
|
continue
|
||||||
|
if hasattr(s, "Faces") and len(s.Faces) == 0:
|
||||||
|
continue
|
||||||
|
valid.append(s)
|
||||||
|
except Exception:
|
||||||
|
continue
|
||||||
|
if not valid:
|
||||||
|
raise RuntimeError("没有可导出的有效 shape(可能网格太脏或全部构造失败)")
|
||||||
|
|
||||||
|
try:
|
||||||
|
compound = Part.makeCompound(valid)
|
||||||
|
except Exception:
|
||||||
|
compound = valid[0]
|
||||||
|
|
||||||
|
doc = App.newDocument("Obj2StepFast")
|
||||||
|
try:
|
||||||
|
if hasattr(doc, "suppressRecompute"):
|
||||||
|
doc.suppressRecompute()
|
||||||
|
|
||||||
|
obj = doc.addObject("Part::Feature", "Compound")
|
||||||
|
obj.Shape = compound
|
||||||
|
|
||||||
|
if hasattr(doc, "recompute"):
|
||||||
|
doc.recompute()
|
||||||
|
|
||||||
|
try:
|
||||||
|
Part.export([obj], step_file_path)
|
||||||
|
return
|
||||||
|
except Exception as e1:
|
||||||
|
log(f"Part.export 失败:{e1}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
import Import
|
||||||
|
Import.export([obj], step_file_path)
|
||||||
|
return
|
||||||
|
except Exception as e2:
|
||||||
|
log(f"Import.export 也失败:{e2}")
|
||||||
|
|
||||||
|
brep_path = os.path.splitext(step_file_path)[0] + ".brep"
|
||||||
|
try:
|
||||||
|
obj.Shape.exportBrep(brep_path)
|
||||||
|
raise RuntimeError(
|
||||||
|
f"STEP 导出失败,已兜底写入 BREP:{brep_path}\n"
|
||||||
|
f"可用 pythonocc/OCP 将 BREP 转为 STEP。"
|
||||||
|
)
|
||||||
|
except Exception as e3:
|
||||||
|
raise RuntimeError(f"STEP 与 BREP 均导出失败:{e3}")
|
||||||
|
finally:
|
||||||
|
App.closeDocument(doc.Name)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------- 顶层转换逻辑 ----------------------
|
||||||
|
|
||||||
|
def _glb_gltf_to_tmp_obj_if_needed(in_path):
|
||||||
|
ext = os.path.splitext(in_path)[1].lower()
|
||||||
|
if ext in (".glb", ".gltf"):
|
||||||
|
if not HAS_TRIMESH:
|
||||||
|
raise RuntimeError("输入为 GLB/GLTF,但 FreeCAD Python 环境未安装 trimesh。请安装后再试。")
|
||||||
|
mesh = trimesh.load(in_path, force="mesh")
|
||||||
|
if mesh.is_empty:
|
||||||
|
raise RuntimeError(f"GLB/GLTF 读取失败或为空:{in_path}")
|
||||||
|
tmp = tempfile.mktemp(suffix=".obj")
|
||||||
|
mesh.export(tmp)
|
||||||
|
return tmp, True
|
||||||
|
return in_path, False
|
||||||
|
|
||||||
|
|
||||||
|
# 新增:生成带时间戳+随机串的唯一文件名(避免覆盖)
|
||||||
|
def generate_unique_step_filename(base_name: str, output_dir: str) -> str:
|
||||||
|
"""
|
||||||
|
生成格式示例:
|
||||||
|
model_20260316_131245_8f4a2c1d.step
|
||||||
|
"""
|
||||||
|
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||||
|
random_part = secrets.token_hex(4) # 8 位随机十六进制
|
||||||
|
filename = f"{base_name}_{timestamp}_{random_part}.step"
|
||||||
|
return os.path.join(output_dir, filename)
|
||||||
|
|
||||||
|
|
||||||
|
def obj_to_step_fast(input_path,
|
||||||
|
step_file_path=None,
|
||||||
|
tol=0.5,
|
||||||
|
sew_tol=0.05,
|
||||||
|
to_solid=False,
|
||||||
|
split_components=True,
|
||||||
|
max_components=64):
|
||||||
|
t0 = time.perf_counter()
|
||||||
|
if not os.path.exists(input_path):
|
||||||
|
raise FileNotFoundError(input_path)
|
||||||
|
|
||||||
|
base = os.path.splitext(os.path.basename(input_path))[0]
|
||||||
|
|
||||||
|
# ── 决定最终输出路径 ──
|
||||||
|
if step_file_path is None:
|
||||||
|
# 没有指定路径 → 用 output_dir 或 input 同目录 + 唯一文件名
|
||||||
|
out_dir = os.path.dirname(input_path)
|
||||||
|
step_file_path = generate_unique_step_filename(base, out_dir)
|
||||||
|
elif os.path.isdir(step_file_path):
|
||||||
|
# 用户传的是目录 → 在该目录生成唯一文件名
|
||||||
|
step_file_path = generate_unique_step_filename(base, step_file_path)
|
||||||
|
else:
|
||||||
|
# 用户明确指定了 .step / .stp 文件路径 → 直接使用(不加随机后缀)
|
||||||
|
if not step_file_path.lower().endswith((".step", ".stp")):
|
||||||
|
step_file_path += ".step"
|
||||||
|
# 这里不做额外唯一性处理,尊重用户意图(可能想覆盖)
|
||||||
|
|
||||||
|
log(f"目标 STEP 文件:{step_file_path}")
|
||||||
|
|
||||||
|
# 下面部分保持原样 ........................................
|
||||||
|
shapes = []
|
||||||
|
tmp_created_paths = []
|
||||||
|
|
||||||
|
src_for_freecad, is_tmp = _glb_gltf_to_tmp_obj_if_needed(input_path)
|
||||||
|
if is_tmp:
|
||||||
|
tmp_created_paths.append(src_for_freecad)
|
||||||
|
|
||||||
|
if HAS_TRIMESH and split_components:
|
||||||
|
m = None
|
||||||
|
try:
|
||||||
|
m = trimesh.load(src_for_freecad, force="mesh")
|
||||||
|
except Exception as e:
|
||||||
|
log(f"trimesh 读取失败(转纯 FreeCAD 路径):{e}")
|
||||||
|
m = None
|
||||||
|
|
||||||
|
if m is not None and not m.is_empty:
|
||||||
|
try:
|
||||||
|
m.remove_duplicate_faces()
|
||||||
|
m.remove_degenerate_faces()
|
||||||
|
m.remove_unreferenced_vertices()
|
||||||
|
m.fix_normals()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
parts = m.split(only_watertight=False)
|
||||||
|
if len(parts) == 0:
|
||||||
|
parts = [m]
|
||||||
|
|
||||||
|
if len(parts) > max_components:
|
||||||
|
parts = sorted(parts, key=lambda x: x.faces.shape[0], reverse=True)[:max_components]
|
||||||
|
|
||||||
|
log(f"Trimesh分块:{len(parts)} 个组件")
|
||||||
|
|
||||||
|
for i, pm in enumerate(parts):
|
||||||
|
tmp_obj = tempfile.mktemp(suffix=f"_{i}.obj")
|
||||||
|
pm.export(tmp_obj)
|
||||||
|
tmp_created_paths.append(tmp_obj)
|
||||||
|
|
||||||
|
fc_mesh = Mesh.Mesh(tmp_obj)
|
||||||
|
_freecad_mesh_quick_clean(fc_mesh)
|
||||||
|
shp = _mesh_to_shape_fast(fc_mesh, tol=tol, sew_tol=sew_tol, to_solid=to_solid)
|
||||||
|
shapes.append(shp)
|
||||||
|
else:
|
||||||
|
fc_mesh = Mesh.Mesh(src_for_freecad)
|
||||||
|
log(f"载入Mesh: {src_for_freecad}, 三角数={len(fc_mesh.Facets)}")
|
||||||
|
_freecad_mesh_quick_clean(fc_mesh)
|
||||||
|
shp = _mesh_to_shape_fast(fc_mesh, tol=tol, sew_tol=sew_tol, to_solid=to_solid)
|
||||||
|
shapes.append(shp)
|
||||||
|
else:
|
||||||
|
fc_mesh = Mesh.Mesh(src_for_freecad)
|
||||||
|
log(f"载入Mesh: {src_for_freecad}, 三角数={len(fc_mesh.Facets)}")
|
||||||
|
_freecad_mesh_quick_clean(fc_mesh)
|
||||||
|
shp = _mesh_to_shape_fast(fc_mesh, tol=tol, sew_tol=sew_tol, to_solid=to_solid)
|
||||||
|
shapes.append(shp)
|
||||||
|
|
||||||
|
_export_shapes_to_step(shapes, step_file_path)
|
||||||
|
|
||||||
|
for p in tmp_created_paths:
|
||||||
|
try:
|
||||||
|
os.remove(p)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
t1 = time.perf_counter()
|
||||||
|
log(f"STEP导出: {step_file_path} (耗时 {t1 - t0:.2f}s)")
|
||||||
|
return step_file_path
|
||||||
|
|
||||||
|
|
||||||
|
def main(obj_file_path,
|
||||||
|
output_dir=None,
|
||||||
|
tol=0.5,
|
||||||
|
sew_tol=0.05,
|
||||||
|
to_solid=False,
|
||||||
|
split_components=True,
|
||||||
|
max_components=64):
|
||||||
|
if output_dir is None:
|
||||||
|
output_dir = os.path.dirname(obj_file_path)
|
||||||
|
|
||||||
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
|
base = os.path.splitext(os.path.basename(obj_file_path))[0]
|
||||||
|
|
||||||
|
# 使用唯一文件名生成 step_path
|
||||||
|
step_path = generate_unique_step_filename(base, output_dir)
|
||||||
|
|
||||||
|
return obj_to_step_fast(
|
||||||
|
obj_file_path,
|
||||||
|
step_file_path=step_path, # 这里传唯一路径
|
||||||
|
tol=tol,
|
||||||
|
sew_tol=sew_tol,
|
||||||
|
to_solid=to_solid,
|
||||||
|
split_components=split_components,
|
||||||
|
max_components=max_components,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── CLI 部分也做相应调整 ──
|
||||||
|
def _parse_cli(argv):
|
||||||
|
in_path = os.path.abspath(argv[1])
|
||||||
|
out_arg = os.path.abspath(argv[2]) if len(argv) > 2 else None
|
||||||
|
|
||||||
|
tol = float(argv[3]) if len(argv) > 3 else 0.5
|
||||||
|
sew_tol = float(argv[4]) if len(argv) > 4 else 0.05
|
||||||
|
solid = bool(int(argv[5])) if len(argv) > 5 else False
|
||||||
|
split = bool(int(argv[6])) if len(argv) > 6 else True
|
||||||
|
max_comp = int(argv[7]) if len(argv) > 7 else 64
|
||||||
|
|
||||||
|
return in_path, out_arg, tol, sew_tol, solid, split, max_comp
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------- CLI entry ----------------------
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
in_path, out_arg, tol, sew_tol, solid, split, max_comp = _parse_cli(sys.argv)
|
||||||
|
|
||||||
|
# 判断第二个参数是目录还是 step 文件
|
||||||
|
out_is_step = out_arg.lower().endswith((".step", ".stp"))
|
||||||
|
if (not out_is_step) or os.path.isdir(out_arg):
|
||||||
|
out_dir = out_arg
|
||||||
|
os.makedirs(out_dir, exist_ok=True)
|
||||||
|
result = main(
|
||||||
|
in_path,
|
||||||
|
output_dir=out_dir,
|
||||||
|
tol=tol,
|
||||||
|
sew_tol=sew_tol,
|
||||||
|
to_solid=solid,
|
||||||
|
split_components=split,
|
||||||
|
max_components=max_comp,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
os.makedirs(os.path.dirname(out_arg), exist_ok=True)
|
||||||
|
result = obj_to_step_fast(
|
||||||
|
in_path,
|
||||||
|
step_file_path=out_arg,
|
||||||
|
tol=tol,
|
||||||
|
sew_tol=sew_tol,
|
||||||
|
to_solid=solid,
|
||||||
|
split_components=split,
|
||||||
|
max_components=max_comp,
|
||||||
|
)
|
||||||
|
|
||||||
|
log("完成!")
|
||||||
|
log(f"STEP: {result}")
|
||||||
386
2_step_to_svg.py
Normal file
386
2_step_to_svg.py
Normal file
@@ -0,0 +1,386 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
|
||||||
|
"""
|
||||||
|
STEP -> OCC triangulate -> PyTorch3D orthographic 3 views -> SVG (silhouette + visible/hidden feature edges)
|
||||||
|
Then combine 3 SVGs into a single SVG and a PNG.
|
||||||
|
|
||||||
|
Views:
|
||||||
|
- top : from +Z towards origin (view direction -Z)
|
||||||
|
- left : from -X towards origin (view direction +X)
|
||||||
|
- front: from +Y towards origin (view direction -Y)
|
||||||
|
|
||||||
|
Outputs (default in --out_dir):
|
||||||
|
- top.svg, left.svg, front.svg
|
||||||
|
- combined_views.svg
|
||||||
|
- combined_views.png
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os, sys, math, argparse, tempfile
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from collections import defaultdict
|
||||||
|
|
||||||
|
# ---------- PyTorch3D ----------
|
||||||
|
from pytorch3d.structures import Meshes
|
||||||
|
from pytorch3d.renderer import (
|
||||||
|
RasterizationSettings, MeshRasterizer,
|
||||||
|
MeshRenderer, BlendParams, SoftSilhouetteShader,
|
||||||
|
look_at_view_transform, OrthographicCameras
|
||||||
|
)
|
||||||
|
from pytorch3d.renderer.mesh.rasterizer import Fragments
|
||||||
|
|
||||||
|
# ---------- SVG / Image ----------
|
||||||
|
import svgwrite
|
||||||
|
from skimage import measure
|
||||||
|
|
||||||
|
# ---------- pythonocc (OCC) ----------
|
||||||
|
from OCC.Core.STEPControl import STEPControl_Reader
|
||||||
|
from OCC.Core.TopAbs import TopAbs_FACE
|
||||||
|
from OCC.Core.TopExp import TopExp_Explorer
|
||||||
|
from OCC.Core.BRep import BRep_Tool
|
||||||
|
from OCC.Core.BRepMesh import BRepMesh_IncrementalMesh
|
||||||
|
from OCC.Core.TopLoc import TopLoc_Location
|
||||||
|
from OCC.Core.ShapeFix import ShapeFix_Shape
|
||||||
|
|
||||||
|
# ---------- Combine SVG ----------
|
||||||
|
import svgutils.transform as st
|
||||||
|
from svgutils.compose import Unit
|
||||||
|
import cairosvg
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------- Utils ----------------
|
||||||
|
|
||||||
|
def autodevice():
|
||||||
|
return torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||||
|
|
||||||
|
|
||||||
|
def read_step_solid(path: str):
|
||||||
|
reader = STEPControl_Reader()
|
||||||
|
stat = reader.ReadFile(path)
|
||||||
|
if stat != 1:
|
||||||
|
raise RuntimeError(f"STEP 读取失败: {path}")
|
||||||
|
reader.TransferRoots()
|
||||||
|
shape = reader.OneShape()
|
||||||
|
fixer = ShapeFix_Shape(shape)
|
||||||
|
fixer.Perform()
|
||||||
|
return fixer.Shape()
|
||||||
|
|
||||||
|
|
||||||
|
def triangulate(shape, deflection=0.5, angle=0.5):
|
||||||
|
"""OCC 三角化:返回 (V(np.float32 N×3), F(np.int64 M×3))"""
|
||||||
|
BRepMesh_IncrementalMesh(shape, deflection, False, angle, True)
|
||||||
|
verts, faces, v_off = [], [], 0
|
||||||
|
exp = TopExp_Explorer(shape, TopAbs_FACE)
|
||||||
|
while exp.More():
|
||||||
|
face = exp.Current()
|
||||||
|
loc = TopLoc_Location()
|
||||||
|
tri = BRep_Tool.Triangulation(face, loc)
|
||||||
|
if tri is not None:
|
||||||
|
nb_nodes = tri.NbNodes()
|
||||||
|
has_nodes_arr = hasattr(tri, "Nodes")
|
||||||
|
for i in range(1, nb_nodes + 1):
|
||||||
|
p = tri.Nodes().Value(i) if has_nodes_arr else tri.Node(i)
|
||||||
|
p = p.Transformed(loc.Transformation())
|
||||||
|
verts.append([p.X(), p.Y(), p.Z()])
|
||||||
|
nb_tris = tri.NbTriangles()
|
||||||
|
has_tris_arr = hasattr(tri, "Triangles")
|
||||||
|
for i in range(1, nb_tris + 1):
|
||||||
|
t = tri.Triangles().Value(i) if has_tris_arr else tri.Triangle(i)
|
||||||
|
a, b, c = t.Get()
|
||||||
|
faces.append([v_off + a - 1, v_off + b - 1, v_off + c - 1])
|
||||||
|
v_off += nb_nodes
|
||||||
|
exp.Next()
|
||||||
|
if not verts or not faces:
|
||||||
|
raise RuntimeError("三角化为空,尝试减小 --defl")
|
||||||
|
return np.asarray(verts, np.float32), np.asarray(faces, np.int64)
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_mesh_np(verts: np.ndarray, unit_scale=1.0):
|
||||||
|
c = verts.mean(axis=0, keepdims=True)
|
||||||
|
v0 = verts - c
|
||||||
|
s = np.max(np.abs(v0))
|
||||||
|
s = max(s, 1e-12)
|
||||||
|
return v0 / s * unit_scale
|
||||||
|
|
||||||
|
|
||||||
|
def build_mesh(V_np, F_np, device):
|
||||||
|
V = torch.from_numpy(V_np).to(device)
|
||||||
|
F = torch.from_numpy(F_np).to(device)
|
||||||
|
return Meshes(verts=[V], faces=[F])
|
||||||
|
|
||||||
|
|
||||||
|
def compute_face_normals(verts: torch.Tensor, faces: torch.Tensor):
|
||||||
|
v0 = verts[faces[:, 0]];
|
||||||
|
v1 = verts[faces[:, 1]];
|
||||||
|
v2 = verts[faces[:, 2]]
|
||||||
|
n = torch.cross(v1 - v0, v2 - v0, dim=1)
|
||||||
|
return torch.nn.functional.normalize(n, dim=1, eps=1e-12)
|
||||||
|
|
||||||
|
|
||||||
|
def extract_feature_edges(faces: torch.Tensor, face_normals: torch.Tensor, angle_deg: float):
|
||||||
|
"""二面角>=angle_deg 视为特征棱;边界边必取。"""
|
||||||
|
thr = math.cos(math.radians(max(0.0, angle_deg)))
|
||||||
|
e2f = defaultdict(list)
|
||||||
|
F = faces.shape[0]
|
||||||
|
for f in range(F):
|
||||||
|
i, j, k = faces[f].tolist()
|
||||||
|
for a, b in ((i, j), (j, k), (k, i)):
|
||||||
|
e = (a, b) if a < b else (b, a)
|
||||||
|
e2f[e].append(f)
|
||||||
|
feat = []
|
||||||
|
for e, fl in e2f.items():
|
||||||
|
if len(fl) == 1:
|
||||||
|
feat.append(e)
|
||||||
|
elif len(fl) == 2:
|
||||||
|
n0, n1 = face_normals[fl[0]], face_normals[fl[1]]
|
||||||
|
cosv = torch.clamp((n0 * n1).sum(), -1.0, 1.0).item()
|
||||||
|
if cosv <= thr:
|
||||||
|
feat.append(e)
|
||||||
|
return feat
|
||||||
|
|
||||||
|
|
||||||
|
def raster_fragments(mesh: Meshes, cameras, image_size=1600, faces_per_pixel=1):
|
||||||
|
rs = RasterizationSettings(
|
||||||
|
image_size=image_size,
|
||||||
|
blur_radius=0.0,
|
||||||
|
faces_per_pixel=faces_per_pixel,
|
||||||
|
cull_backfaces=True
|
||||||
|
)
|
||||||
|
rast = MeshRasterizer(cameras=cameras, raster_settings=rs)
|
||||||
|
with torch.no_grad():
|
||||||
|
frags: Fragments = rast(mesh, cameras=cameras)
|
||||||
|
return frags, frags.zbuf[0, ..., 0]
|
||||||
|
|
||||||
|
|
||||||
|
def render_silhouette(mesh: Meshes, cameras, image_size=1600):
|
||||||
|
rs = RasterizationSettings(
|
||||||
|
image_size=image_size,
|
||||||
|
blur_radius=1e-6,
|
||||||
|
faces_per_pixel=50,
|
||||||
|
cull_backfaces=True
|
||||||
|
)
|
||||||
|
renderer = MeshRenderer(
|
||||||
|
rasterizer=MeshRasterizer(cameras=cameras, raster_settings=rs),
|
||||||
|
shader=SoftSilhouetteShader(blend_params=BlendParams(sigma=1e-4, gamma=1e-4))
|
||||||
|
)
|
||||||
|
with torch.no_grad():
|
||||||
|
img = renderer(mesh, cameras=cameras) # (1,H,W,4)
|
||||||
|
return np.clip(img[0, ..., 3].detach().cpu().numpy(), 0.0, 1.0)
|
||||||
|
|
||||||
|
|
||||||
|
def project_points_to_screen(cameras, points_world: torch.Tensor, image_size: int):
|
||||||
|
with torch.no_grad():
|
||||||
|
scr = cameras.transform_points_screen(
|
||||||
|
points_world[None, ...],
|
||||||
|
image_size=((image_size, image_size),)
|
||||||
|
)
|
||||||
|
return scr[0] # (N,3)
|
||||||
|
|
||||||
|
|
||||||
|
def edge_visibility_split(edges_idx, verts_world, cameras, zbuf, image_size, eps=1e-4):
|
||||||
|
H = W = image_size
|
||||||
|
vis, hid = [], []
|
||||||
|
scr = project_points_to_screen(cameras, verts_world, image_size)
|
||||||
|
zmin = zbuf.detach().cpu().numpy()
|
||||||
|
for i, j in edges_idx:
|
||||||
|
p0, p1 = scr[i], scr[j]
|
||||||
|
m = 0.5 * (p0 + p1)
|
||||||
|
x = int(np.clip(m[0].item(), 0, W - 1))
|
||||||
|
y = int(np.clip(m[1].item(), 0, H - 1))
|
||||||
|
z_proj = m[2].item()
|
||||||
|
z_ref = zmin[y, x]
|
||||||
|
seg = (p0[0].item(), p0[1].item(), p1[0].item(), p1[1].item())
|
||||||
|
if np.isfinite(z_ref) and (z_proj <= z_ref + eps):
|
||||||
|
vis.append(seg)
|
||||||
|
else:
|
||||||
|
hid.append(seg)
|
||||||
|
return vis, hid
|
||||||
|
|
||||||
|
|
||||||
|
def trace_silhouettes(alpha: np.ndarray, threshold=0.5, step=1):
|
||||||
|
contours = measure.find_contours(alpha, level=threshold)
|
||||||
|
polys = []
|
||||||
|
for cnt in contours:
|
||||||
|
if len(cnt) < 8:
|
||||||
|
continue
|
||||||
|
cnt = cnt[::max(1, step)]
|
||||||
|
polys.append(np.stack([cnt[:, 1], cnt[:, 0]], axis=1)) # (x,y)
|
||||||
|
return polys
|
||||||
|
|
||||||
|
|
||||||
|
def svg_from_view(out_svg, W, H, silhouettes, edges_vis, edges_hid, stroke=2.0, margin=10, fill="none"):
|
||||||
|
dwg = svgwrite.Drawing(out_svg, size=(W + 2 * margin, H + 2 * margin))
|
||||||
|
dwg.add(dwg.rect(insert=(0, 0), size=(W + 2 * margin, H + 2 * margin), fill="none"))
|
||||||
|
|
||||||
|
# silhouette fill (optional)
|
||||||
|
for poly in silhouettes:
|
||||||
|
if len(poly) < 3:
|
||||||
|
continue
|
||||||
|
path = [f"M {poly[0, 0] + margin:.2f} {poly[0, 1] + margin:.2f}"]
|
||||||
|
for i in range(1, len(poly)):
|
||||||
|
path.append(f"L {poly[i, 0] + margin:.2f} {poly[i, 1] + margin:.2f}")
|
||||||
|
path.append("Z")
|
||||||
|
dwg.add(dwg.path(" ".join(path), fill=fill, stroke="none"))
|
||||||
|
|
||||||
|
# visible edges
|
||||||
|
for (x0, y0, x1, y1) in edges_vis:
|
||||||
|
dwg.add(dwg.line(
|
||||||
|
(x0 + margin, y0 + margin),
|
||||||
|
(x1 + margin, y1 + margin),
|
||||||
|
stroke="black",
|
||||||
|
stroke_width=stroke
|
||||||
|
))
|
||||||
|
|
||||||
|
# hidden edges
|
||||||
|
for (x0, y0, x1, y1) in edges_hid:
|
||||||
|
dwg.add(dwg.line(
|
||||||
|
(x0 + margin, y0 + margin),
|
||||||
|
(x1 + margin, y1 + margin),
|
||||||
|
stroke="black",
|
||||||
|
stroke_width=max(1.0, 0.75 * stroke),
|
||||||
|
stroke_dasharray=[6, 6],
|
||||||
|
opacity=0.85
|
||||||
|
))
|
||||||
|
|
||||||
|
dwg.save()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------- Cameras ----------------
|
||||||
|
|
||||||
|
def make_camera(view: str, device, dist=2.7):
|
||||||
|
"""
|
||||||
|
look_at_view_transform spherical params for engineering views.
|
||||||
|
"""
|
||||||
|
if view == "top":
|
||||||
|
# from +Z to origin (view -Z)
|
||||||
|
R, T = look_at_view_transform(dist=dist, elev=90.0, azim=180.0, device=device)
|
||||||
|
elif view == "left":
|
||||||
|
# from -X to origin (view +X)
|
||||||
|
R, T = look_at_view_transform(dist=dist, elev=0.0, azim=270.0, device=device)
|
||||||
|
elif view == "front":
|
||||||
|
# from +Y to origin (view -Y)
|
||||||
|
R, T = look_at_view_transform(dist=dist, elev=0.0, azim=180.0, device=device)
|
||||||
|
else:
|
||||||
|
raise ValueError("view must be one of ['top','left','front']")
|
||||||
|
return OrthographicCameras(R=R, T=T, device=device)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------- Combine 3 SVGs ----------------
|
||||||
|
|
||||||
|
def combine_svgs(view1_path, view2_path, view3_path, output_svg, output_image):
|
||||||
|
"""
|
||||||
|
横向拼接 3 张 SVG,并额外导出 PNG。
|
||||||
|
"""
|
||||||
|
out_dir = os.path.dirname(output_svg)
|
||||||
|
if out_dir:
|
||||||
|
os.makedirs(out_dir, exist_ok=True)
|
||||||
|
|
||||||
|
view1 = st.fromfile(view1_path)
|
||||||
|
view2 = st.fromfile(view2_path)
|
||||||
|
view3 = st.fromfile(view3_path)
|
||||||
|
|
||||||
|
r1 = view1.getroot()
|
||||||
|
r2 = view2.getroot()
|
||||||
|
r3 = view3.getroot()
|
||||||
|
|
||||||
|
w1, h1 = [float(x.replace("px", "")) for x in view1.get_size()]
|
||||||
|
w2, h2 = [float(x.replace("px", "")) for x in view2.get_size()]
|
||||||
|
w3, h3 = [float(x.replace("px", "")) for x in view3.get_size()]
|
||||||
|
|
||||||
|
max_w = max(w1, w2, w3)
|
||||||
|
max_h = max(h1, h2, h3)
|
||||||
|
|
||||||
|
combined_width = max_w * 3 + 40
|
||||||
|
combined_height = max_h + 20
|
||||||
|
|
||||||
|
combined = st.SVGFigure(Unit(combined_width), Unit(combined_height))
|
||||||
|
|
||||||
|
x_offset = 10
|
||||||
|
y_offset = 10
|
||||||
|
|
||||||
|
r1.moveto(x_offset, y_offset)
|
||||||
|
x_offset += max_w + 20
|
||||||
|
r2.moveto(x_offset, y_offset)
|
||||||
|
x_offset += max_w + 20
|
||||||
|
r3.moveto(x_offset, y_offset)
|
||||||
|
|
||||||
|
combined.append([r1, r2, r3])
|
||||||
|
combined.save(output_svg)
|
||||||
|
|
||||||
|
# SVG -> PNG
|
||||||
|
cairosvg.svg2png(url=output_svg, write_to=output_image)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------- Main ----------------
|
||||||
|
|
||||||
|
def main():
|
||||||
|
ap = argparse.ArgumentParser()
|
||||||
|
ap.add_argument("--step_path", type=str, required=False, help="输入 STEP 文件", default="output/sample_clean.step")
|
||||||
|
ap.add_argument("--out_dir", type=str, required=False, help="输出目录", default="output")
|
||||||
|
|
||||||
|
ap.add_argument("--res", type=int, default=1400, help="渲染分辨率(正方形)")
|
||||||
|
ap.add_argument("--defl", type=float, default=1.5, help="三角化线性偏差")
|
||||||
|
ap.add_argument("--ang", type=float, default=65.0, help="二面角阈值(°)")
|
||||||
|
ap.add_argument("--stroke", type=float, default=1.3, help="SVG 线宽(px)")
|
||||||
|
ap.add_argument("--scale", type=float, default=1.0, help="归一化后模型最大边长")
|
||||||
|
|
||||||
|
ap.add_argument("--no_combine", action="store_true", help="只导出三视图 SVG,不合成")
|
||||||
|
ap.add_argument("--combined_svg", type=str, default=None, help="合成 SVG 输出路径(默认 out_dir/combined_views.svg)")
|
||||||
|
ap.add_argument("--combined_png", type=str, default=None, help="合成 PNG 输出路径(默认 out_dir/combined_views.png)")
|
||||||
|
|
||||||
|
args = ap.parse_args()
|
||||||
|
|
||||||
|
os.makedirs(args.out_dir, exist_ok=True)
|
||||||
|
device = autodevice()
|
||||||
|
print(f"[Device] {device}")
|
||||||
|
|
||||||
|
# STEP -> mesh
|
||||||
|
shape = read_step_solid(args.step_path)
|
||||||
|
V_np, F_np = triangulate(shape, deflection=args.defl)
|
||||||
|
V_np = normalize_mesh_np(V_np, unit_scale=args.scale)
|
||||||
|
|
||||||
|
mesh = build_mesh(V_np, F_np, device)
|
||||||
|
verts = mesh.verts_packed()
|
||||||
|
faces = mesh.faces_packed()
|
||||||
|
fn = compute_face_normals(verts, faces)
|
||||||
|
feat_edges = extract_feature_edges(faces, fn, angle_deg=args.ang)
|
||||||
|
|
||||||
|
view_paths = {}
|
||||||
|
for name in ["top", "left", "front"]:
|
||||||
|
print(f"[View] {name}")
|
||||||
|
cam = make_camera(name, device)
|
||||||
|
alpha = render_silhouette(mesh, cam, image_size=args.res)
|
||||||
|
_, zbuf = raster_fragments(mesh, cam, image_size=args.res, faces_per_pixel=1)
|
||||||
|
|
||||||
|
e_vis, e_hid = edge_visibility_split(feat_edges, verts, cam, zbuf, args.res, eps=1e-4)
|
||||||
|
polys = trace_silhouettes(alpha, threshold=0.5, step=1)
|
||||||
|
|
||||||
|
out_svg = os.path.join(args.out_dir, f"{name}.svg")
|
||||||
|
svg_from_view(out_svg, args.res, args.res, polys, e_vis, e_hid,
|
||||||
|
stroke=args.stroke, margin=10, fill="none")
|
||||||
|
view_paths[name] = out_svg
|
||||||
|
print(f" -> {out_svg} (edges: vis={len(e_vis)}, hid={len(e_hid)}, polys={len(polys)})")
|
||||||
|
|
||||||
|
if args.no_combine:
|
||||||
|
print("[Done] 3 views exported (no combine).")
|
||||||
|
return
|
||||||
|
|
||||||
|
combined_svg = args.combined_svg or os.path.join(args.out_dir, "combined_views.svg")
|
||||||
|
combined_png = args.combined_png or os.path.join(args.out_dir, "combined_views.png")
|
||||||
|
|
||||||
|
# 注意:这里组合顺序按你第二段示例:front / top / left
|
||||||
|
combine_svgs(
|
||||||
|
view1_path=view_paths["front"],
|
||||||
|
view2_path=view_paths["top"],
|
||||||
|
view3_path=view_paths["left"],
|
||||||
|
output_svg=combined_svg,
|
||||||
|
output_image=combined_png
|
||||||
|
)
|
||||||
|
print(f"[Combined] SVG: {combined_svg}")
|
||||||
|
print(f"[Combined] PNG: {combined_png}")
|
||||||
|
print("[Done] All outputs exported.")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
403
app.py
Normal file
403
app.py
Normal file
@@ -0,0 +1,403 @@
|
|||||||
|
import gradio as gr
|
||||||
|
from gradio_litmodel3d import LitModel3D
|
||||||
|
|
||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
from typing import *
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
import imageio
|
||||||
|
from easydict import EasyDict as edict
|
||||||
|
from PIL import Image
|
||||||
|
from trellis.pipelines import TrellisImageTo3DPipeline
|
||||||
|
from trellis.representations import Gaussian, MeshExtractResult
|
||||||
|
from trellis.utils import render_utils, postprocessing_utils
|
||||||
|
|
||||||
|
|
||||||
|
MAX_SEED = np.iinfo(np.int32).max
|
||||||
|
TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
|
||||||
|
os.makedirs(TMP_DIR, exist_ok=True)
|
||||||
|
|
||||||
|
|
||||||
|
def start_session(req: gr.Request):
|
||||||
|
user_dir = os.path.join(TMP_DIR, str(req.session_hash))
|
||||||
|
os.makedirs(user_dir, exist_ok=True)
|
||||||
|
|
||||||
|
|
||||||
|
def end_session(req: gr.Request):
|
||||||
|
user_dir = os.path.join(TMP_DIR, str(req.session_hash))
|
||||||
|
shutil.rmtree(user_dir)
|
||||||
|
|
||||||
|
|
||||||
|
def preprocess_image(image: Image.Image) -> Image.Image:
|
||||||
|
"""
|
||||||
|
Preprocess the input image.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image (Image.Image): The input image.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Image.Image: The preprocessed image.
|
||||||
|
"""
|
||||||
|
processed_image = pipeline.preprocess_image(image)
|
||||||
|
return processed_image
|
||||||
|
|
||||||
|
|
||||||
|
def preprocess_images(images: List[Tuple[Image.Image, str]]) -> List[Image.Image]:
|
||||||
|
"""
|
||||||
|
Preprocess a list of input images.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
images (List[Tuple[Image.Image, str]]): The input images.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[Image.Image]: The preprocessed images.
|
||||||
|
"""
|
||||||
|
images = [image[0] for image in images]
|
||||||
|
processed_images = [pipeline.preprocess_image(image) for image in images]
|
||||||
|
return processed_images
|
||||||
|
|
||||||
|
|
||||||
|
def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
|
||||||
|
return {
|
||||||
|
'gaussian': {
|
||||||
|
**gs.init_params,
|
||||||
|
'_xyz': gs._xyz.cpu().numpy(),
|
||||||
|
'_features_dc': gs._features_dc.cpu().numpy(),
|
||||||
|
'_scaling': gs._scaling.cpu().numpy(),
|
||||||
|
'_rotation': gs._rotation.cpu().numpy(),
|
||||||
|
'_opacity': gs._opacity.cpu().numpy(),
|
||||||
|
},
|
||||||
|
'mesh': {
|
||||||
|
'vertices': mesh.vertices.cpu().numpy(),
|
||||||
|
'faces': mesh.faces.cpu().numpy(),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def unpack_state(state: dict) -> Tuple[Gaussian, edict, str]:
|
||||||
|
gs = Gaussian(
|
||||||
|
aabb=state['gaussian']['aabb'],
|
||||||
|
sh_degree=state['gaussian']['sh_degree'],
|
||||||
|
mininum_kernel_size=state['gaussian']['mininum_kernel_size'],
|
||||||
|
scaling_bias=state['gaussian']['scaling_bias'],
|
||||||
|
opacity_bias=state['gaussian']['opacity_bias'],
|
||||||
|
scaling_activation=state['gaussian']['scaling_activation'],
|
||||||
|
)
|
||||||
|
gs._xyz = torch.tensor(state['gaussian']['_xyz'], device='cuda')
|
||||||
|
gs._features_dc = torch.tensor(state['gaussian']['_features_dc'], device='cuda')
|
||||||
|
gs._scaling = torch.tensor(state['gaussian']['_scaling'], device='cuda')
|
||||||
|
gs._rotation = torch.tensor(state['gaussian']['_rotation'], device='cuda')
|
||||||
|
gs._opacity = torch.tensor(state['gaussian']['_opacity'], device='cuda')
|
||||||
|
|
||||||
|
mesh = edict(
|
||||||
|
vertices=torch.tensor(state['mesh']['vertices'], device='cuda'),
|
||||||
|
faces=torch.tensor(state['mesh']['faces'], device='cuda'),
|
||||||
|
)
|
||||||
|
|
||||||
|
return gs, mesh
|
||||||
|
|
||||||
|
|
||||||
|
def get_seed(randomize_seed: bool, seed: int) -> int:
|
||||||
|
"""
|
||||||
|
Get the random seed.
|
||||||
|
"""
|
||||||
|
return np.random.randint(0, MAX_SEED) if randomize_seed else seed
|
||||||
|
|
||||||
|
|
||||||
|
def image_to_3d(
|
||||||
|
image: Image.Image,
|
||||||
|
multiimages: List[Tuple[Image.Image, str]],
|
||||||
|
is_multiimage: bool,
|
||||||
|
seed: int,
|
||||||
|
ss_guidance_strength: float,
|
||||||
|
ss_sampling_steps: int,
|
||||||
|
slat_guidance_strength: float,
|
||||||
|
slat_sampling_steps: int,
|
||||||
|
multiimage_algo: Literal["multidiffusion", "stochastic"],
|
||||||
|
req: gr.Request,
|
||||||
|
) -> Tuple[dict, str]:
|
||||||
|
"""
|
||||||
|
Convert an image to a 3D model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image (Image.Image): The input image.
|
||||||
|
multiimages (List[Tuple[Image.Image, str]]): The input images in multi-image mode.
|
||||||
|
is_multiimage (bool): Whether is in multi-image mode.
|
||||||
|
seed (int): The random seed.
|
||||||
|
ss_guidance_strength (float): The guidance strength for sparse structure generation.
|
||||||
|
ss_sampling_steps (int): The number of sampling steps for sparse structure generation.
|
||||||
|
slat_guidance_strength (float): The guidance strength for structured latent generation.
|
||||||
|
slat_sampling_steps (int): The number of sampling steps for structured latent generation.
|
||||||
|
multiimage_algo (Literal["multidiffusion", "stochastic"]): The algorithm for multi-image generation.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: The information of the generated 3D model.
|
||||||
|
str: The path to the video of the 3D model.
|
||||||
|
"""
|
||||||
|
user_dir = os.path.join(TMP_DIR, str(req.session_hash))
|
||||||
|
if not is_multiimage:
|
||||||
|
outputs = pipeline.run(
|
||||||
|
image,
|
||||||
|
seed=seed,
|
||||||
|
formats=["gaussian", "mesh"],
|
||||||
|
preprocess_image=False,
|
||||||
|
sparse_structure_sampler_params={
|
||||||
|
"steps": ss_sampling_steps,
|
||||||
|
"cfg_strength": ss_guidance_strength,
|
||||||
|
},
|
||||||
|
slat_sampler_params={
|
||||||
|
"steps": slat_sampling_steps,
|
||||||
|
"cfg_strength": slat_guidance_strength,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
outputs = pipeline.run_multi_image(
|
||||||
|
[image[0] for image in multiimages],
|
||||||
|
seed=seed,
|
||||||
|
formats=["gaussian", "mesh"],
|
||||||
|
preprocess_image=False,
|
||||||
|
sparse_structure_sampler_params={
|
||||||
|
"steps": ss_sampling_steps,
|
||||||
|
"cfg_strength": ss_guidance_strength,
|
||||||
|
},
|
||||||
|
slat_sampler_params={
|
||||||
|
"steps": slat_sampling_steps,
|
||||||
|
"cfg_strength": slat_guidance_strength,
|
||||||
|
},
|
||||||
|
mode=multiimage_algo,
|
||||||
|
)
|
||||||
|
video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
|
||||||
|
video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
|
||||||
|
video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
|
||||||
|
video_path = os.path.join(user_dir, 'sample.mp4')
|
||||||
|
imageio.mimsave(video_path, video, fps=15)
|
||||||
|
state = pack_state(outputs['gaussian'][0], outputs['mesh'][0])
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
return state, video_path
|
||||||
|
|
||||||
|
|
||||||
|
def extract_glb(
|
||||||
|
state: dict,
|
||||||
|
mesh_simplify: float,
|
||||||
|
texture_size: int,
|
||||||
|
req: gr.Request,
|
||||||
|
) -> Tuple[str, str]:
|
||||||
|
"""
|
||||||
|
Extract a GLB file from the 3D model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state (dict): The state of the generated 3D model.
|
||||||
|
mesh_simplify (float): The mesh simplification factor.
|
||||||
|
texture_size (int): The texture resolution.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: The path to the extracted GLB file.
|
||||||
|
"""
|
||||||
|
user_dir = os.path.join(TMP_DIR, str(req.session_hash))
|
||||||
|
gs, mesh = unpack_state(state)
|
||||||
|
glb = postprocessing_utils.to_glb(gs, mesh, simplify=mesh_simplify, texture_size=texture_size, verbose=False)
|
||||||
|
glb_path = os.path.join(user_dir, 'sample.glb')
|
||||||
|
glb.export(glb_path)
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
return glb_path, glb_path
|
||||||
|
|
||||||
|
|
||||||
|
def extract_gaussian(state: dict, req: gr.Request) -> Tuple[str, str]:
|
||||||
|
"""
|
||||||
|
Extract a Gaussian file from the 3D model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state (dict): The state of the generated 3D model.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: The path to the extracted Gaussian file.
|
||||||
|
"""
|
||||||
|
user_dir = os.path.join(TMP_DIR, str(req.session_hash))
|
||||||
|
gs, _ = unpack_state(state)
|
||||||
|
gaussian_path = os.path.join(user_dir, 'sample.ply')
|
||||||
|
gs.save_ply(gaussian_path)
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
return gaussian_path, gaussian_path
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_multi_example() -> List[Image.Image]:
|
||||||
|
multi_case = list(set([i.split('_')[0] for i in os.listdir("assets/example_multi_image")]))
|
||||||
|
images = []
|
||||||
|
for case in multi_case:
|
||||||
|
_images = []
|
||||||
|
for i in range(1, 4):
|
||||||
|
img = Image.open(f'assets/example_multi_image/{case}_{i}.png')
|
||||||
|
W, H = img.size
|
||||||
|
img = img.resize((int(W / H * 512), 512))
|
||||||
|
_images.append(np.array(img))
|
||||||
|
images.append(Image.fromarray(np.concatenate(_images, axis=1)))
|
||||||
|
return images
|
||||||
|
|
||||||
|
|
||||||
|
def split_image(image: Image.Image) -> List[Image.Image]:
|
||||||
|
"""
|
||||||
|
Split an image into multiple views.
|
||||||
|
"""
|
||||||
|
image = np.array(image)
|
||||||
|
alpha = image[..., 3]
|
||||||
|
alpha = np.any(alpha>0, axis=0)
|
||||||
|
start_pos = np.where(~alpha[:-1] & alpha[1:])[0].tolist()
|
||||||
|
end_pos = np.where(alpha[:-1] & ~alpha[1:])[0].tolist()
|
||||||
|
images = []
|
||||||
|
for s, e in zip(start_pos, end_pos):
|
||||||
|
images.append(Image.fromarray(image[:, s:e+1]))
|
||||||
|
return [preprocess_image(image) for image in images]
|
||||||
|
|
||||||
|
|
||||||
|
with gr.Blocks(delete_cache=(600, 600)) as demo:
|
||||||
|
gr.Markdown("""
|
||||||
|
## Image to 3D Asset with [TRELLIS](https://trellis3d.github.io/)
|
||||||
|
* Upload an image and click "Generate" to create a 3D asset. If the image has alpha channel, it be used as the mask. Otherwise, we use `rembg` to remove the background.
|
||||||
|
* If you find the generated 3D asset satisfactory, click "Extract GLB" to extract the GLB file and download it.
|
||||||
|
""")
|
||||||
|
|
||||||
|
with gr.Row():
|
||||||
|
with gr.Column():
|
||||||
|
with gr.Tabs() as input_tabs:
|
||||||
|
with gr.Tab(label="Single Image", id=0) as single_image_input_tab:
|
||||||
|
image_prompt = gr.Image(label="Image Prompt", format="png", image_mode="RGBA", type="pil", height=300)
|
||||||
|
with gr.Tab(label="Multiple Images", id=1) as multiimage_input_tab:
|
||||||
|
multiimage_prompt = gr.Gallery(label="Image Prompt", format="png", type="pil", height=300, columns=3)
|
||||||
|
gr.Markdown("""
|
||||||
|
Input different views of the object in separate images.
|
||||||
|
|
||||||
|
*NOTE: this is an experimental algorithm without training a specialized model. It may not produce the best results for all images, especially those having different poses or inconsistent details.*
|
||||||
|
""")
|
||||||
|
|
||||||
|
with gr.Accordion(label="Generation Settings", open=False):
|
||||||
|
seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
|
||||||
|
randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
|
||||||
|
gr.Markdown("Stage 1: Sparse Structure Generation")
|
||||||
|
with gr.Row():
|
||||||
|
ss_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
|
||||||
|
ss_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
|
||||||
|
gr.Markdown("Stage 2: Structured Latent Generation")
|
||||||
|
with gr.Row():
|
||||||
|
slat_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=3.0, step=0.1)
|
||||||
|
slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
|
||||||
|
multiimage_algo = gr.Radio(["stochastic", "multidiffusion"], label="Multi-image Algorithm", value="stochastic")
|
||||||
|
|
||||||
|
generate_btn = gr.Button("Generate")
|
||||||
|
|
||||||
|
with gr.Accordion(label="GLB Extraction Settings", open=False):
|
||||||
|
mesh_simplify = gr.Slider(0.9, 0.98, label="Simplify", value=0.95, step=0.01)
|
||||||
|
texture_size = gr.Slider(512, 2048, label="Texture Size", value=1024, step=512)
|
||||||
|
|
||||||
|
with gr.Row():
|
||||||
|
extract_glb_btn = gr.Button("Extract GLB", interactive=False)
|
||||||
|
extract_gs_btn = gr.Button("Extract Gaussian", interactive=False)
|
||||||
|
gr.Markdown("""
|
||||||
|
*NOTE: Gaussian file can be very large (~50MB), it will take a while to display and download.*
|
||||||
|
""")
|
||||||
|
|
||||||
|
with gr.Column():
|
||||||
|
video_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300)
|
||||||
|
model_output = LitModel3D(label="Extracted GLB/Gaussian", exposure=10.0, height=300)
|
||||||
|
|
||||||
|
with gr.Row():
|
||||||
|
download_glb = gr.DownloadButton(label="Download GLB", interactive=False)
|
||||||
|
download_gs = gr.DownloadButton(label="Download Gaussian", interactive=False)
|
||||||
|
|
||||||
|
is_multiimage = gr.State(False)
|
||||||
|
output_buf = gr.State()
|
||||||
|
|
||||||
|
# Example images at the bottom of the page
|
||||||
|
with gr.Row() as single_image_example:
|
||||||
|
examples = gr.Examples(
|
||||||
|
examples=[
|
||||||
|
f'assets/example_image/{image}'
|
||||||
|
for image in os.listdir("assets/example_image")
|
||||||
|
],
|
||||||
|
inputs=[image_prompt],
|
||||||
|
fn=preprocess_image,
|
||||||
|
outputs=[image_prompt],
|
||||||
|
run_on_click=True,
|
||||||
|
examples_per_page=64,
|
||||||
|
)
|
||||||
|
with gr.Row(visible=False) as multiimage_example:
|
||||||
|
examples_multi = gr.Examples(
|
||||||
|
examples=prepare_multi_example(),
|
||||||
|
inputs=[image_prompt],
|
||||||
|
fn=split_image,
|
||||||
|
outputs=[multiimage_prompt],
|
||||||
|
run_on_click=True,
|
||||||
|
examples_per_page=8,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Handlers
|
||||||
|
demo.load(start_session)
|
||||||
|
demo.unload(end_session)
|
||||||
|
|
||||||
|
single_image_input_tab.select(
|
||||||
|
lambda: tuple([False, gr.Row.update(visible=True), gr.Row.update(visible=False)]),
|
||||||
|
outputs=[is_multiimage, single_image_example, multiimage_example]
|
||||||
|
)
|
||||||
|
multiimage_input_tab.select(
|
||||||
|
lambda: tuple([True, gr.Row.update(visible=False), gr.Row.update(visible=True)]),
|
||||||
|
outputs=[is_multiimage, single_image_example, multiimage_example]
|
||||||
|
)
|
||||||
|
|
||||||
|
image_prompt.upload(
|
||||||
|
preprocess_image,
|
||||||
|
inputs=[image_prompt],
|
||||||
|
outputs=[image_prompt],
|
||||||
|
)
|
||||||
|
multiimage_prompt.upload(
|
||||||
|
preprocess_images,
|
||||||
|
inputs=[multiimage_prompt],
|
||||||
|
outputs=[multiimage_prompt],
|
||||||
|
)
|
||||||
|
|
||||||
|
generate_btn.click(
|
||||||
|
get_seed,
|
||||||
|
inputs=[randomize_seed, seed],
|
||||||
|
outputs=[seed],
|
||||||
|
).then(
|
||||||
|
image_to_3d,
|
||||||
|
inputs=[image_prompt, multiimage_prompt, is_multiimage, seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps, multiimage_algo],
|
||||||
|
outputs=[output_buf, video_output],
|
||||||
|
).then(
|
||||||
|
lambda: tuple([gr.Button(interactive=True), gr.Button(interactive=True)]),
|
||||||
|
outputs=[extract_glb_btn, extract_gs_btn],
|
||||||
|
)
|
||||||
|
|
||||||
|
video_output.clear(
|
||||||
|
lambda: tuple([gr.Button(interactive=False), gr.Button(interactive=False)]),
|
||||||
|
outputs=[extract_glb_btn, extract_gs_btn],
|
||||||
|
)
|
||||||
|
|
||||||
|
extract_glb_btn.click(
|
||||||
|
extract_glb,
|
||||||
|
inputs=[output_buf, mesh_simplify, texture_size],
|
||||||
|
outputs=[model_output, download_glb],
|
||||||
|
).then(
|
||||||
|
lambda: gr.Button(interactive=True),
|
||||||
|
outputs=[download_glb],
|
||||||
|
)
|
||||||
|
|
||||||
|
extract_gs_btn.click(
|
||||||
|
extract_gaussian,
|
||||||
|
inputs=[output_buf],
|
||||||
|
outputs=[model_output, download_gs],
|
||||||
|
).then(
|
||||||
|
lambda: gr.Button(interactive=True),
|
||||||
|
outputs=[download_gs],
|
||||||
|
)
|
||||||
|
|
||||||
|
model_output.clear(
|
||||||
|
lambda: gr.Button(interactive=False),
|
||||||
|
outputs=[download_glb],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Launch the Gradio app
|
||||||
|
if __name__ == "__main__":
|
||||||
|
pipeline = TrellisImageTo3DPipeline.from_pretrained("microsoft/TRELLIS-image-large")
|
||||||
|
pipeline.cuda()
|
||||||
|
demo.launch()
|
||||||
266
app_text.py
Normal file
266
app_text.py
Normal file
@@ -0,0 +1,266 @@
|
|||||||
|
import gradio as gr
|
||||||
|
from gradio_litmodel3d import LitModel3D
|
||||||
|
|
||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
from typing import *
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
import imageio
|
||||||
|
from easydict import EasyDict as edict
|
||||||
|
from trellis.pipelines import TrellisTextTo3DPipeline
|
||||||
|
from trellis.representations import Gaussian, MeshExtractResult
|
||||||
|
from trellis.utils import render_utils, postprocessing_utils
|
||||||
|
|
||||||
|
|
||||||
|
MAX_SEED = np.iinfo(np.int32).max
|
||||||
|
TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
|
||||||
|
os.makedirs(TMP_DIR, exist_ok=True)
|
||||||
|
|
||||||
|
|
||||||
|
def start_session(req: gr.Request):
|
||||||
|
user_dir = os.path.join(TMP_DIR, str(req.session_hash))
|
||||||
|
os.makedirs(user_dir, exist_ok=True)
|
||||||
|
|
||||||
|
|
||||||
|
def end_session(req: gr.Request):
|
||||||
|
user_dir = os.path.join(TMP_DIR, str(req.session_hash))
|
||||||
|
shutil.rmtree(user_dir)
|
||||||
|
|
||||||
|
|
||||||
|
def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
|
||||||
|
return {
|
||||||
|
'gaussian': {
|
||||||
|
**gs.init_params,
|
||||||
|
'_xyz': gs._xyz.cpu().numpy(),
|
||||||
|
'_features_dc': gs._features_dc.cpu().numpy(),
|
||||||
|
'_scaling': gs._scaling.cpu().numpy(),
|
||||||
|
'_rotation': gs._rotation.cpu().numpy(),
|
||||||
|
'_opacity': gs._opacity.cpu().numpy(),
|
||||||
|
},
|
||||||
|
'mesh': {
|
||||||
|
'vertices': mesh.vertices.cpu().numpy(),
|
||||||
|
'faces': mesh.faces.cpu().numpy(),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def unpack_state(state: dict) -> Tuple[Gaussian, edict, str]:
|
||||||
|
gs = Gaussian(
|
||||||
|
aabb=state['gaussian']['aabb'],
|
||||||
|
sh_degree=state['gaussian']['sh_degree'],
|
||||||
|
mininum_kernel_size=state['gaussian']['mininum_kernel_size'],
|
||||||
|
scaling_bias=state['gaussian']['scaling_bias'],
|
||||||
|
opacity_bias=state['gaussian']['opacity_bias'],
|
||||||
|
scaling_activation=state['gaussian']['scaling_activation'],
|
||||||
|
)
|
||||||
|
gs._xyz = torch.tensor(state['gaussian']['_xyz'], device='cuda')
|
||||||
|
gs._features_dc = torch.tensor(state['gaussian']['_features_dc'], device='cuda')
|
||||||
|
gs._scaling = torch.tensor(state['gaussian']['_scaling'], device='cuda')
|
||||||
|
gs._rotation = torch.tensor(state['gaussian']['_rotation'], device='cuda')
|
||||||
|
gs._opacity = torch.tensor(state['gaussian']['_opacity'], device='cuda')
|
||||||
|
|
||||||
|
mesh = edict(
|
||||||
|
vertices=torch.tensor(state['mesh']['vertices'], device='cuda'),
|
||||||
|
faces=torch.tensor(state['mesh']['faces'], device='cuda'),
|
||||||
|
)
|
||||||
|
|
||||||
|
return gs, mesh
|
||||||
|
|
||||||
|
|
||||||
|
def get_seed(randomize_seed: bool, seed: int) -> int:
|
||||||
|
"""
|
||||||
|
Get the random seed.
|
||||||
|
"""
|
||||||
|
return np.random.randint(0, MAX_SEED) if randomize_seed else seed
|
||||||
|
|
||||||
|
|
||||||
|
def text_to_3d(
|
||||||
|
prompt: str,
|
||||||
|
seed: int,
|
||||||
|
ss_guidance_strength: float,
|
||||||
|
ss_sampling_steps: int,
|
||||||
|
slat_guidance_strength: float,
|
||||||
|
slat_sampling_steps: int,
|
||||||
|
req: gr.Request,
|
||||||
|
) -> Tuple[dict, str]:
|
||||||
|
"""
|
||||||
|
Convert an text prompt to a 3D model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt (str): The text prompt.
|
||||||
|
seed (int): The random seed.
|
||||||
|
ss_guidance_strength (float): The guidance strength for sparse structure generation.
|
||||||
|
ss_sampling_steps (int): The number of sampling steps for sparse structure generation.
|
||||||
|
slat_guidance_strength (float): The guidance strength for structured latent generation.
|
||||||
|
slat_sampling_steps (int): The number of sampling steps for structured latent generation.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: The information of the generated 3D model.
|
||||||
|
str: The path to the video of the 3D model.
|
||||||
|
"""
|
||||||
|
user_dir = os.path.join(TMP_DIR, str(req.session_hash))
|
||||||
|
outputs = pipeline.run(
|
||||||
|
prompt,
|
||||||
|
seed=seed,
|
||||||
|
formats=["gaussian", "mesh"],
|
||||||
|
sparse_structure_sampler_params={
|
||||||
|
"steps": ss_sampling_steps,
|
||||||
|
"cfg_strength": ss_guidance_strength,
|
||||||
|
},
|
||||||
|
slat_sampler_params={
|
||||||
|
"steps": slat_sampling_steps,
|
||||||
|
"cfg_strength": slat_guidance_strength,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
|
||||||
|
video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
|
||||||
|
video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
|
||||||
|
video_path = os.path.join(user_dir, 'sample.mp4')
|
||||||
|
imageio.mimsave(video_path, video, fps=15)
|
||||||
|
state = pack_state(outputs['gaussian'][0], outputs['mesh'][0])
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
return state, video_path
|
||||||
|
|
||||||
|
|
||||||
|
def extract_glb(
|
||||||
|
state: dict,
|
||||||
|
mesh_simplify: float,
|
||||||
|
texture_size: int,
|
||||||
|
req: gr.Request,
|
||||||
|
) -> Tuple[str, str]:
|
||||||
|
"""
|
||||||
|
Extract a GLB file from the 3D model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state (dict): The state of the generated 3D model.
|
||||||
|
mesh_simplify (float): The mesh simplification factor.
|
||||||
|
texture_size (int): The texture resolution.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: The path to the extracted GLB file.
|
||||||
|
"""
|
||||||
|
user_dir = os.path.join(TMP_DIR, str(req.session_hash))
|
||||||
|
gs, mesh = unpack_state(state)
|
||||||
|
glb = postprocessing_utils.to_glb(gs, mesh, simplify=mesh_simplify, texture_size=texture_size, verbose=False)
|
||||||
|
glb_path = os.path.join(user_dir, 'sample.glb')
|
||||||
|
glb.export(glb_path)
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
return glb_path, glb_path
|
||||||
|
|
||||||
|
|
||||||
|
def extract_gaussian(state: dict, req: gr.Request) -> Tuple[str, str]:
|
||||||
|
"""
|
||||||
|
Extract a Gaussian file from the 3D model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state (dict): The state of the generated 3D model.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: The path to the extracted Gaussian file.
|
||||||
|
"""
|
||||||
|
user_dir = os.path.join(TMP_DIR, str(req.session_hash))
|
||||||
|
gs, _ = unpack_state(state)
|
||||||
|
gaussian_path = os.path.join(user_dir, 'sample.ply')
|
||||||
|
gs.save_ply(gaussian_path)
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
return gaussian_path, gaussian_path
|
||||||
|
|
||||||
|
|
||||||
|
with gr.Blocks(delete_cache=(600, 600)) as demo:
|
||||||
|
gr.Markdown("""
|
||||||
|
## Text to 3D Asset with [TRELLIS](https://trellis3d.github.io/)
|
||||||
|
* Type a text prompt and click "Generate" to create a 3D asset.
|
||||||
|
* If you find the generated 3D asset satisfactory, click "Extract GLB" to extract the GLB file and download it.
|
||||||
|
""")
|
||||||
|
|
||||||
|
with gr.Row():
|
||||||
|
with gr.Column():
|
||||||
|
text_prompt = gr.Textbox(label="Text Prompt", lines=5)
|
||||||
|
|
||||||
|
with gr.Accordion(label="Generation Settings", open=False):
|
||||||
|
seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
|
||||||
|
randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
|
||||||
|
gr.Markdown("Stage 1: Sparse Structure Generation")
|
||||||
|
with gr.Row():
|
||||||
|
ss_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
|
||||||
|
ss_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=25, step=1)
|
||||||
|
gr.Markdown("Stage 2: Structured Latent Generation")
|
||||||
|
with gr.Row():
|
||||||
|
slat_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
|
||||||
|
slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=25, step=1)
|
||||||
|
|
||||||
|
generate_btn = gr.Button("Generate")
|
||||||
|
|
||||||
|
with gr.Accordion(label="GLB Extraction Settings", open=False):
|
||||||
|
mesh_simplify = gr.Slider(0.9, 0.98, label="Simplify", value=0.95, step=0.01)
|
||||||
|
texture_size = gr.Slider(512, 2048, label="Texture Size", value=1024, step=512)
|
||||||
|
|
||||||
|
with gr.Row():
|
||||||
|
extract_glb_btn = gr.Button("Extract GLB", interactive=False)
|
||||||
|
extract_gs_btn = gr.Button("Extract Gaussian", interactive=False)
|
||||||
|
gr.Markdown("""
|
||||||
|
*NOTE: Gaussian file can be very large (~50MB), it will take a while to display and download.*
|
||||||
|
""")
|
||||||
|
|
||||||
|
with gr.Column():
|
||||||
|
video_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300)
|
||||||
|
model_output = LitModel3D(label="Extracted GLB/Gaussian", exposure=10.0, height=300)
|
||||||
|
|
||||||
|
with gr.Row():
|
||||||
|
download_glb = gr.DownloadButton(label="Download GLB", interactive=False)
|
||||||
|
download_gs = gr.DownloadButton(label="Download Gaussian", interactive=False)
|
||||||
|
|
||||||
|
output_buf = gr.State()
|
||||||
|
|
||||||
|
# Handlers
|
||||||
|
demo.load(start_session)
|
||||||
|
demo.unload(end_session)
|
||||||
|
|
||||||
|
generate_btn.click(
|
||||||
|
get_seed,
|
||||||
|
inputs=[randomize_seed, seed],
|
||||||
|
outputs=[seed],
|
||||||
|
).then(
|
||||||
|
text_to_3d,
|
||||||
|
inputs=[text_prompt, seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps],
|
||||||
|
outputs=[output_buf, video_output],
|
||||||
|
).then(
|
||||||
|
lambda: tuple([gr.Button(interactive=True), gr.Button(interactive=True)]),
|
||||||
|
outputs=[extract_glb_btn, extract_gs_btn],
|
||||||
|
)
|
||||||
|
|
||||||
|
video_output.clear(
|
||||||
|
lambda: tuple([gr.Button(interactive=False), gr.Button(interactive=False)]),
|
||||||
|
outputs=[extract_glb_btn, extract_gs_btn],
|
||||||
|
)
|
||||||
|
|
||||||
|
extract_glb_btn.click(
|
||||||
|
extract_glb,
|
||||||
|
inputs=[output_buf, mesh_simplify, texture_size],
|
||||||
|
outputs=[model_output, download_glb],
|
||||||
|
).then(
|
||||||
|
lambda: gr.Button(interactive=True),
|
||||||
|
outputs=[download_glb],
|
||||||
|
)
|
||||||
|
|
||||||
|
extract_gs_btn.click(
|
||||||
|
extract_gaussian,
|
||||||
|
inputs=[output_buf],
|
||||||
|
outputs=[model_output, download_gs],
|
||||||
|
).then(
|
||||||
|
lambda: gr.Button(interactive=True),
|
||||||
|
outputs=[download_gs],
|
||||||
|
)
|
||||||
|
|
||||||
|
model_output.clear(
|
||||||
|
lambda: gr.Button(interactive=False),
|
||||||
|
outputs=[download_glb],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Launch the Gradio app
|
||||||
|
if __name__ == "__main__":
|
||||||
|
pipeline = TrellisTextTo3DPipeline.from_pretrained("microsoft/TRELLIS-text-xlarge")
|
||||||
|
pipeline.cuda()
|
||||||
|
demo.launch()
|
||||||
48
blender_glb_to_obj.py
Normal file
48
blender_glb_to_obj.py
Normal file
@@ -0,0 +1,48 @@
|
|||||||
|
import bpy
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
|
||||||
|
argv = sys.argv
|
||||||
|
argv = argv[argv.index("--") + 1:]
|
||||||
|
|
||||||
|
glb_input = argv[0]
|
||||||
|
obj_output = argv[1]
|
||||||
|
|
||||||
|
|
||||||
|
def clean():
|
||||||
|
bpy.ops.object.select_all(action='SELECT')
|
||||||
|
bpy.ops.object.delete(use_global=False)
|
||||||
|
|
||||||
|
|
||||||
|
clean()
|
||||||
|
|
||||||
|
bpy.ops.import_scene.gltf(filepath=glb_input)
|
||||||
|
|
||||||
|
meshes = [o for o in bpy.context.scene.objects if o.type == "MESH"]
|
||||||
|
|
||||||
|
if not meshes:
|
||||||
|
raise RuntimeError("No mesh objects found in GLB/GLTF.")
|
||||||
|
|
||||||
|
bpy.ops.object.select_all(action='DESELECT')
|
||||||
|
|
||||||
|
for o in meshes:
|
||||||
|
o.select_set(True)
|
||||||
|
|
||||||
|
bpy.context.view_layer.objects.active = meshes[0]
|
||||||
|
|
||||||
|
if len(meshes) > 1:
|
||||||
|
bpy.ops.object.join()
|
||||||
|
|
||||||
|
obj = bpy.context.view_layer.objects.active
|
||||||
|
obj.select_set(True)
|
||||||
|
|
||||||
|
os.makedirs(os.path.dirname(obj_output), exist_ok=True)
|
||||||
|
|
||||||
|
bpy.ops.wm.obj_export(
|
||||||
|
filepath=obj_output,
|
||||||
|
export_selected_objects=True,
|
||||||
|
export_normals=True,
|
||||||
|
export_uv=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
print("OBJ exported:", obj_output)
|
||||||
7
client.py
Normal file
7
client.py
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
# This file is auto-generated by LitServe.
|
||||||
|
# Disable auto-generation by setting `generate_client_file=False` in `LitServer.run()`.
|
||||||
|
|
||||||
|
import requests
|
||||||
|
|
||||||
|
response = requests.post("http://127.0.0.1:8000/predict", json={"input": 4.0})
|
||||||
|
print(f"Status: {response.status_code}\nResponse:\n {response.text}")
|
||||||
285
dataset_toolkits/build_metadata.py
Normal file
285
dataset_toolkits/build_metadata.py
Normal file
@@ -0,0 +1,285 @@
|
|||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
import importlib
|
||||||
|
import argparse
|
||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
|
from tqdm import tqdm
|
||||||
|
from easydict import EasyDict as edict
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
import utils3d
|
||||||
|
|
||||||
|
def get_first_directory(path):
|
||||||
|
with os.scandir(path) as it:
|
||||||
|
for entry in it:
|
||||||
|
if entry.is_dir():
|
||||||
|
return entry.name
|
||||||
|
return None
|
||||||
|
|
||||||
|
def need_process(key):
|
||||||
|
return key in opt.field or opt.field == ['all']
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
dataset_utils = importlib.import_module(f'datasets.{sys.argv[1]}')
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument('--output_dir', type=str, required=True,
|
||||||
|
help='Directory to save the metadata')
|
||||||
|
parser.add_argument('--field', type=str, default='all',
|
||||||
|
help='Fields to process, separated by commas')
|
||||||
|
parser.add_argument('--from_file', action='store_true',
|
||||||
|
help='Build metadata from file instead of from records of processings.' +
|
||||||
|
'Useful when some processing fail to generate records but file already exists.')
|
||||||
|
dataset_utils.add_args(parser)
|
||||||
|
opt = parser.parse_args(sys.argv[2:])
|
||||||
|
opt = edict(vars(opt))
|
||||||
|
|
||||||
|
os.makedirs(opt.output_dir, exist_ok=True)
|
||||||
|
os.makedirs(os.path.join(opt.output_dir, 'merged_records'), exist_ok=True)
|
||||||
|
|
||||||
|
opt.field = opt.field.split(',')
|
||||||
|
|
||||||
|
timestamp = str(int(time.time()))
|
||||||
|
|
||||||
|
# get file list
|
||||||
|
if os.path.exists(os.path.join(opt.output_dir, 'metadata.csv')):
|
||||||
|
print('Loading previous metadata...')
|
||||||
|
metadata = pd.read_csv(os.path.join(opt.output_dir, 'metadata.csv'))
|
||||||
|
else:
|
||||||
|
metadata = dataset_utils.get_metadata(**opt)
|
||||||
|
metadata.set_index('sha256', inplace=True)
|
||||||
|
|
||||||
|
# merge downloaded
|
||||||
|
df_files = [f for f in os.listdir(opt.output_dir) if f.startswith('downloaded_') and f.endswith('.csv')]
|
||||||
|
df_parts = []
|
||||||
|
for f in df_files:
|
||||||
|
try:
|
||||||
|
df_parts.append(pd.read_csv(os.path.join(opt.output_dir, f)))
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
if len(df_parts) > 0:
|
||||||
|
df = pd.concat(df_parts)
|
||||||
|
df.set_index('sha256', inplace=True)
|
||||||
|
if 'local_path' in metadata.columns:
|
||||||
|
metadata.update(df, overwrite=True)
|
||||||
|
else:
|
||||||
|
metadata = metadata.join(df, on='sha256', how='left')
|
||||||
|
for f in df_files:
|
||||||
|
shutil.move(os.path.join(opt.output_dir, f), os.path.join(opt.output_dir, 'merged_records', f'{timestamp}_{f}'))
|
||||||
|
|
||||||
|
# detect models
|
||||||
|
image_models = []
|
||||||
|
if os.path.exists(os.path.join(opt.output_dir, 'features')):
|
||||||
|
image_models = os.listdir(os.path.join(opt.output_dir, 'features'))
|
||||||
|
latent_models = []
|
||||||
|
if os.path.exists(os.path.join(opt.output_dir, 'latents')):
|
||||||
|
latent_models = os.listdir(os.path.join(opt.output_dir, 'latents'))
|
||||||
|
ss_latent_models = []
|
||||||
|
if os.path.exists(os.path.join(opt.output_dir, 'ss_latents')):
|
||||||
|
ss_latent_models = os.listdir(os.path.join(opt.output_dir, 'ss_latents'))
|
||||||
|
print(f'Image models: {image_models}')
|
||||||
|
print(f'Latent models: {latent_models}')
|
||||||
|
print(f'Sparse Structure latent models: {ss_latent_models}')
|
||||||
|
|
||||||
|
if 'rendered' not in metadata.columns:
|
||||||
|
metadata['rendered'] = [False] * len(metadata)
|
||||||
|
if 'voxelized' not in metadata.columns:
|
||||||
|
metadata['voxelized'] = [False] * len(metadata)
|
||||||
|
if 'num_voxels' not in metadata.columns:
|
||||||
|
metadata['num_voxels'] = [0] * len(metadata)
|
||||||
|
if 'cond_rendered' not in metadata.columns:
|
||||||
|
metadata['cond_rendered'] = [False] * len(metadata)
|
||||||
|
for model in image_models:
|
||||||
|
if f'feature_{model}' not in metadata.columns:
|
||||||
|
metadata[f'feature_{model}'] = [False] * len(metadata)
|
||||||
|
for model in latent_models:
|
||||||
|
if f'latent_{model}' not in metadata.columns:
|
||||||
|
metadata[f'latent_{model}'] = [False] * len(metadata)
|
||||||
|
for model in ss_latent_models:
|
||||||
|
if f'ss_latent_{model}' not in metadata.columns:
|
||||||
|
metadata[f'ss_latent_{model}'] = [False] * len(metadata)
|
||||||
|
|
||||||
|
# merge rendered
|
||||||
|
df_files = [f for f in os.listdir(opt.output_dir) if f.startswith('rendered_') and f.endswith('.csv')]
|
||||||
|
df_parts = []
|
||||||
|
for f in df_files:
|
||||||
|
try:
|
||||||
|
df_parts.append(pd.read_csv(os.path.join(opt.output_dir, f)))
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
if len(df_parts) > 0:
|
||||||
|
df = pd.concat(df_parts)
|
||||||
|
df.set_index('sha256', inplace=True)
|
||||||
|
metadata.update(df, overwrite=True)
|
||||||
|
for f in df_files:
|
||||||
|
shutil.move(os.path.join(opt.output_dir, f), os.path.join(opt.output_dir, 'merged_records', f'{timestamp}_{f}'))
|
||||||
|
|
||||||
|
# merge aesthetic scores
|
||||||
|
df_files = [f for f in os.listdir(opt.output_dir) if f.startswith('aesthetic_scores_') and f.endswith('.csv')]
|
||||||
|
df_parts = []
|
||||||
|
for f in df_files:
|
||||||
|
try:
|
||||||
|
df_parts.append(pd.read_csv(os.path.join(opt.output_dir, f)))
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
if len(df_parts) > 0:
|
||||||
|
df = pd.concat(df_parts)
|
||||||
|
df.set_index('sha256', inplace=True)
|
||||||
|
metadata.update(df, overwrite=True)
|
||||||
|
for f in df_files:
|
||||||
|
shutil.move(os.path.join(opt.output_dir, f), os.path.join(opt.output_dir, 'merged_records', f'{timestamp}_{f}'))
|
||||||
|
|
||||||
|
# merge voxelized
|
||||||
|
df_files = [f for f in os.listdir(opt.output_dir) if f.startswith('voxelized_') and f.endswith('.csv')]
|
||||||
|
df_parts = []
|
||||||
|
for f in df_files:
|
||||||
|
try:
|
||||||
|
df_parts.append(pd.read_csv(os.path.join(opt.output_dir, f)))
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
if len(df_parts) > 0:
|
||||||
|
df = pd.concat(df_parts)
|
||||||
|
df.set_index('sha256', inplace=True)
|
||||||
|
metadata.update(df, overwrite=True)
|
||||||
|
for f in df_files:
|
||||||
|
shutil.move(os.path.join(opt.output_dir, f), os.path.join(opt.output_dir, 'merged_records', f'{timestamp}_{f}'))
|
||||||
|
|
||||||
|
# merge cond_rendered
|
||||||
|
df_files = [f for f in os.listdir(opt.output_dir) if f.startswith('cond_rendered_') and f.endswith('.csv')]
|
||||||
|
df_parts = []
|
||||||
|
for f in df_files:
|
||||||
|
try:
|
||||||
|
df_parts.append(pd.read_csv(os.path.join(opt.output_dir, f)))
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
if len(df_parts) > 0:
|
||||||
|
df = pd.concat(df_parts)
|
||||||
|
df.set_index('sha256', inplace=True)
|
||||||
|
metadata.update(df, overwrite=True)
|
||||||
|
for f in df_files:
|
||||||
|
shutil.move(os.path.join(opt.output_dir, f), os.path.join(opt.output_dir, 'merged_records', f'{timestamp}_{f}'))
|
||||||
|
|
||||||
|
# merge features
|
||||||
|
for model in image_models:
|
||||||
|
df_files = [f for f in os.listdir(opt.output_dir) if f.startswith(f'feature_{model}_') and f.endswith('.csv')]
|
||||||
|
df_parts = []
|
||||||
|
for f in df_files:
|
||||||
|
try:
|
||||||
|
df_parts.append(pd.read_csv(os.path.join(opt.output_dir, f)))
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
if len(df_parts) > 0:
|
||||||
|
df = pd.concat(df_parts)
|
||||||
|
df.set_index('sha256', inplace=True)
|
||||||
|
metadata.update(df, overwrite=True)
|
||||||
|
for f in df_files:
|
||||||
|
shutil.move(os.path.join(opt.output_dir, f), os.path.join(opt.output_dir, 'merged_records', f'{timestamp}_{f}'))
|
||||||
|
|
||||||
|
# merge latents
|
||||||
|
for model in latent_models:
|
||||||
|
df_files = [f for f in os.listdir(opt.output_dir) if f.startswith(f'latent_{model}_') and f.endswith('.csv')]
|
||||||
|
df_parts = []
|
||||||
|
for f in df_files:
|
||||||
|
try:
|
||||||
|
df_parts.append(pd.read_csv(os.path.join(opt.output_dir, f)))
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
if len(df_parts) > 0:
|
||||||
|
df = pd.concat(df_parts)
|
||||||
|
df.set_index('sha256', inplace=True)
|
||||||
|
metadata.update(df, overwrite=True)
|
||||||
|
for f in df_files:
|
||||||
|
shutil.move(os.path.join(opt.output_dir, f), os.path.join(opt.output_dir, 'merged_records', f'{timestamp}_{f}'))
|
||||||
|
|
||||||
|
# merge sparse structure latents
|
||||||
|
for model in ss_latent_models:
|
||||||
|
df_files = [f for f in os.listdir(opt.output_dir) if f.startswith(f'ss_latent_{model}_') and f.endswith('.csv')]
|
||||||
|
df_parts = []
|
||||||
|
for f in df_files:
|
||||||
|
try:
|
||||||
|
df_parts.append(pd.read_csv(os.path.join(opt.output_dir, f)))
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
if len(df_parts) > 0:
|
||||||
|
df = pd.concat(df_parts)
|
||||||
|
df.set_index('sha256', inplace=True)
|
||||||
|
metadata.update(df, overwrite=True)
|
||||||
|
for f in df_files:
|
||||||
|
shutil.move(os.path.join(opt.output_dir, f), os.path.join(opt.output_dir, 'merged_records', f'{timestamp}_{f}'))
|
||||||
|
|
||||||
|
# build metadata from files
|
||||||
|
if opt.from_file:
|
||||||
|
with ThreadPoolExecutor(max_workers=os.cpu_count()) as executor, \
|
||||||
|
tqdm(total=len(metadata), desc="Building metadata") as pbar:
|
||||||
|
def worker(sha256):
|
||||||
|
try:
|
||||||
|
if need_process('rendered') and metadata.loc[sha256, 'rendered'] == False and \
|
||||||
|
os.path.exists(os.path.join(opt.output_dir, 'renders', sha256, 'transforms.json')):
|
||||||
|
metadata.loc[sha256, 'rendered'] = True
|
||||||
|
if need_process('voxelized') and metadata.loc[sha256, 'rendered'] == True and metadata.loc[sha256, 'voxelized'] == False and \
|
||||||
|
os.path.exists(os.path.join(opt.output_dir, 'voxels', f'{sha256}.ply')):
|
||||||
|
try:
|
||||||
|
pts = utils3d.io.read_ply(os.path.join(opt.output_dir, 'voxels', f'{sha256}.ply'))[0]
|
||||||
|
metadata.loc[sha256, 'voxelized'] = True
|
||||||
|
metadata.loc[sha256, 'num_voxels'] = len(pts)
|
||||||
|
except Exception as e:
|
||||||
|
pass
|
||||||
|
if need_process('cond_rendered') and metadata.loc[sha256, 'cond_rendered'] == False and \
|
||||||
|
os.path.exists(os.path.join(opt.output_dir, 'renders_cond', sha256, 'transforms.json')):
|
||||||
|
metadata.loc[sha256, 'cond_rendered'] = True
|
||||||
|
for model in image_models:
|
||||||
|
if need_process(f'feature_{model}') and \
|
||||||
|
metadata.loc[sha256, f'feature_{model}'] == False and \
|
||||||
|
metadata.loc[sha256, 'rendered'] == True and \
|
||||||
|
metadata.loc[sha256, 'voxelized'] == True and \
|
||||||
|
os.path.exists(os.path.join(opt.output_dir, 'features', model, f'{sha256}.npz')):
|
||||||
|
metadata.loc[sha256, f'feature_{model}'] = True
|
||||||
|
for model in latent_models:
|
||||||
|
if need_process(f'latent_{model}') and \
|
||||||
|
metadata.loc[sha256, f'latent_{model}'] == False and \
|
||||||
|
metadata.loc[sha256, 'rendered'] == True and \
|
||||||
|
metadata.loc[sha256, 'voxelized'] == True and \
|
||||||
|
os.path.exists(os.path.join(opt.output_dir, 'latents', model, f'{sha256}.npz')):
|
||||||
|
metadata.loc[sha256, f'latent_{model}'] = True
|
||||||
|
for model in ss_latent_models:
|
||||||
|
if need_process(f'ss_latent_{model}') and \
|
||||||
|
metadata.loc[sha256, f'ss_latent_{model}'] == False and \
|
||||||
|
metadata.loc[sha256, 'voxelized'] == True and \
|
||||||
|
os.path.exists(os.path.join(opt.output_dir, 'ss_latents', model, f'{sha256}.npz')):
|
||||||
|
metadata.loc[sha256, f'ss_latent_{model}'] = True
|
||||||
|
pbar.update()
|
||||||
|
except Exception as e:
|
||||||
|
print(f'Error processing {sha256}: {e}')
|
||||||
|
pbar.update()
|
||||||
|
|
||||||
|
executor.map(worker, metadata.index)
|
||||||
|
executor.shutdown(wait=True)
|
||||||
|
|
||||||
|
# statistics
|
||||||
|
metadata.to_csv(os.path.join(opt.output_dir, 'metadata.csv'))
|
||||||
|
num_downloaded = metadata['local_path'].count() if 'local_path' in metadata.columns else 0
|
||||||
|
with open(os.path.join(opt.output_dir, 'statistics.txt'), 'w') as f:
|
||||||
|
f.write('Statistics:\n')
|
||||||
|
f.write(f' - Number of assets: {len(metadata)}\n')
|
||||||
|
f.write(f' - Number of assets downloaded: {num_downloaded}\n')
|
||||||
|
f.write(f' - Number of assets rendered: {metadata["rendered"].sum()}\n')
|
||||||
|
f.write(f' - Number of assets voxelized: {metadata["voxelized"].sum()}\n')
|
||||||
|
if len(image_models) != 0:
|
||||||
|
f.write(f' - Number of assets with image features extracted:\n')
|
||||||
|
for model in image_models:
|
||||||
|
f.write(f' - {model}: {metadata[f"feature_{model}"].sum()}\n')
|
||||||
|
if len(latent_models) != 0:
|
||||||
|
f.write(f' - Number of assets with latents extracted:\n')
|
||||||
|
for model in latent_models:
|
||||||
|
f.write(f' - {model}: {metadata[f"latent_{model}"].sum()}\n')
|
||||||
|
if len(ss_latent_models) != 0:
|
||||||
|
f.write(f' - Number of assets with sparse structure latents extracted:\n')
|
||||||
|
for model in ss_latent_models:
|
||||||
|
f.write(f' - {model}: {metadata[f"ss_latent_{model}"].sum()}\n')
|
||||||
|
f.write(f' - Number of assets with captions: {metadata["captions"].count()}\n')
|
||||||
|
f.write(f' - Number of assets with image conditions: {metadata["cond_rendered"].sum()}\n')
|
||||||
|
|
||||||
|
with open(os.path.join(opt.output_dir, 'statistics.txt'), 'r') as f:
|
||||||
|
print(f.read())
|
||||||
102
dataset_toolkits/calculate_aesthetic_scores.py
Normal file
102
dataset_toolkits/calculate_aesthetic_scores.py
Normal file
@@ -0,0 +1,102 @@
|
|||||||
|
import os
|
||||||
|
import argparse
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from PIL import Image
|
||||||
|
import open_clip
|
||||||
|
from os.path import expanduser
|
||||||
|
from urllib.request import urlretrieve
|
||||||
|
import pandas as pd
|
||||||
|
from tqdm import tqdm
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
from queue import Queue
|
||||||
|
|
||||||
|
|
||||||
|
def get_aesthetic_model(clip_model="vit_l_14"):
|
||||||
|
"""load the aethetic model"""
|
||||||
|
home = expanduser("~")
|
||||||
|
cache_folder = home + "/.cache/emb_reader"
|
||||||
|
path_to_model = cache_folder + "/sa_0_4_"+clip_model+"_linear.pth"
|
||||||
|
if not os.path.exists(path_to_model):
|
||||||
|
os.makedirs(cache_folder, exist_ok=True)
|
||||||
|
url_model = (
|
||||||
|
"https://github.com/LAION-AI/aesthetic-predictor/blob/main/sa_0_4_"+clip_model+"_linear.pth?raw=true"
|
||||||
|
)
|
||||||
|
urlretrieve(url_model, path_to_model)
|
||||||
|
if clip_model == "vit_l_14":
|
||||||
|
m = nn.Linear(768, 1)
|
||||||
|
elif clip_model == "vit_b_32":
|
||||||
|
m = nn.Linear(512, 1)
|
||||||
|
else:
|
||||||
|
raise ValueError()
|
||||||
|
s = torch.load(path_to_model)
|
||||||
|
m.load_state_dict(s)
|
||||||
|
m.eval()
|
||||||
|
return m
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("--clip_model", type=str, default="vit_l_14")
|
||||||
|
parser.add_argument("--output_dir", type=str, required=True)
|
||||||
|
parser.add_argument("--rank", type=int, default=0)
|
||||||
|
parser.add_argument("--world_size", type=int, default=1)
|
||||||
|
opt = parser.parse_args()
|
||||||
|
|
||||||
|
amodel = get_aesthetic_model(clip_model="vit_l_14")
|
||||||
|
amodel.eval()
|
||||||
|
model, _, preprocess = open_clip.create_model_and_transforms('ViT-L-14', pretrained='openai')
|
||||||
|
model = model.cuda()
|
||||||
|
amodel = amodel.cuda()
|
||||||
|
|
||||||
|
metadata = pd.read_csv(os.path.join(opt.output_dir, 'metadata.csv'))
|
||||||
|
metadata = metadata[metadata['snapshotted'] == 1]
|
||||||
|
sha256s = metadata['sha256'].values
|
||||||
|
|
||||||
|
# filter out objects that are already calculated
|
||||||
|
if os.path.exists(os.path.join(opt.output_dir, 'aesthetic_scores.csv')):
|
||||||
|
with open(os.path.join(opt.output_dir, 'aesthetic_scores.csv'), 'r') as f:
|
||||||
|
old_metadata = pd.read_csv(f)
|
||||||
|
sha256s = list(set(sha256s) - set(old_metadata['sha256'].values))
|
||||||
|
|
||||||
|
sha256s = sorted(sha256s)
|
||||||
|
sha256s = sha256s[len(sha256s) * opt.rank // opt.world_size: len(sha256s) * (opt.rank + 1) // opt.world_size]
|
||||||
|
|
||||||
|
rows = []
|
||||||
|
|
||||||
|
with ThreadPoolExecutor(max_workers=os.cpu_count()) as executor:
|
||||||
|
finished = Queue(maxsize=128)
|
||||||
|
|
||||||
|
def load_image(sha256):
|
||||||
|
try:
|
||||||
|
files = os.listdir(os.path.join(opt.output_dir, 'snapshots', sha256))
|
||||||
|
files = [f for f in files if f.endswith('.png')]
|
||||||
|
processed = []
|
||||||
|
for file in files:
|
||||||
|
image = Image.open(os.path.join(opt.output_dir, 'snapshots', sha256, file))
|
||||||
|
processed.append(preprocess(image))
|
||||||
|
processed = torch.stack(processed, dim=0)
|
||||||
|
except Exception as e:
|
||||||
|
print(e)
|
||||||
|
processed = None
|
||||||
|
finished.put((sha256, processed))
|
||||||
|
|
||||||
|
executor.map(load_image, sha256s)
|
||||||
|
for _ in tqdm(range(len(sha256s)), desc='Calculating aesthetic scores'):
|
||||||
|
sha256, processed = finished.get()
|
||||||
|
if processed is not None:
|
||||||
|
with torch.no_grad():
|
||||||
|
image_features = model.encode_image(processed.cuda())
|
||||||
|
image_features /= image_features.norm(dim=-1, keepdim=True)
|
||||||
|
aesthetic_score = amodel(image_features).cpu()
|
||||||
|
rows.append(pd.DataFrame({
|
||||||
|
'sha256': [sha256],
|
||||||
|
'mean': [aesthetic_score.mean().item()],
|
||||||
|
'std': [aesthetic_score.std().item()],
|
||||||
|
'min': [aesthetic_score.min().item()],
|
||||||
|
'max': [aesthetic_score.max().item()],
|
||||||
|
'median': [aesthetic_score.median().item()]
|
||||||
|
}))
|
||||||
|
|
||||||
|
with open(os.path.join(opt.output_dir, f'aesthetic_scores_{opt.rank}.csv'), 'w') as f:
|
||||||
|
pd.concat(rows).to_csv(f, index=False)
|
||||||
97
dataset_toolkits/datasets/3D-FUTURE.py
Normal file
97
dataset_toolkits/datasets/3D-FUTURE.py
Normal file
@@ -0,0 +1,97 @@
|
|||||||
|
import os
|
||||||
|
import re
|
||||||
|
import argparse
|
||||||
|
import zipfile
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
from tqdm import tqdm
|
||||||
|
import pandas as pd
|
||||||
|
from utils import get_file_hash
|
||||||
|
|
||||||
|
|
||||||
|
def add_args(parser: argparse.ArgumentParser):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def get_metadata(**kwargs):
|
||||||
|
metadata = pd.read_csv("hf://datasets/JeffreyXiang/TRELLIS-500K/3D-FUTURE.csv")
|
||||||
|
return metadata
|
||||||
|
|
||||||
|
|
||||||
|
def download(metadata, output_dir, **kwargs):
|
||||||
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
|
|
||||||
|
if not os.path.exists(os.path.join(output_dir, 'raw', '3D-FUTURE-model.zip')):
|
||||||
|
print("\033[93m")
|
||||||
|
print("3D-FUTURE have to be downloaded manually")
|
||||||
|
print(f"Please download the 3D-FUTURE-model.zip file and place it in the {output_dir}/raw directory")
|
||||||
|
print("Visit https://tianchi.aliyun.com/specials/promotion/alibaba-3d-future for more information")
|
||||||
|
print("\033[0m")
|
||||||
|
raise FileNotFoundError("3D-FUTURE-model.zip not found")
|
||||||
|
|
||||||
|
downloaded = {}
|
||||||
|
metadata = metadata.set_index("file_identifier")
|
||||||
|
with zipfile.ZipFile(os.path.join(output_dir, 'raw', '3D-FUTURE-model.zip')) as zip_ref:
|
||||||
|
all_names = zip_ref.namelist()
|
||||||
|
instances = [instance[:-1] for instance in all_names if re.match(r"^3D-FUTURE-model/[^/]+/$", instance)]
|
||||||
|
instances = list(filter(lambda x: x in metadata.index, instances))
|
||||||
|
|
||||||
|
with ThreadPoolExecutor(max_workers=os.cpu_count()) as executor, \
|
||||||
|
tqdm(total=len(instances), desc="Extracting") as pbar:
|
||||||
|
def worker(instance: str) -> str:
|
||||||
|
try:
|
||||||
|
instance_files = list(filter(lambda x: x.startswith(f"{instance}/") and not x.endswith("/"), all_names))
|
||||||
|
zip_ref.extractall(os.path.join(output_dir, 'raw'), members=instance_files)
|
||||||
|
sha256 = get_file_hash(os.path.join(output_dir, 'raw', f"{instance}/image.jpg"))
|
||||||
|
pbar.update()
|
||||||
|
return sha256
|
||||||
|
except Exception as e:
|
||||||
|
pbar.update()
|
||||||
|
print(f"Error extracting for {instance}: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
sha256s = executor.map(worker, instances)
|
||||||
|
executor.shutdown(wait=True)
|
||||||
|
|
||||||
|
for k, sha256 in zip(instances, sha256s):
|
||||||
|
if sha256 is not None:
|
||||||
|
if sha256 == metadata.loc[k, "sha256"]:
|
||||||
|
downloaded[sha256] = os.path.join("raw", f"{k}/raw_model.obj")
|
||||||
|
else:
|
||||||
|
print(f"Error downloading {k}: sha256s do not match")
|
||||||
|
|
||||||
|
return pd.DataFrame(downloaded.items(), columns=['sha256', 'local_path'])
|
||||||
|
|
||||||
|
|
||||||
|
def foreach_instance(metadata, output_dir, func, max_workers=None, desc='Processing objects') -> pd.DataFrame:
|
||||||
|
import os
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
# load metadata
|
||||||
|
metadata = metadata.to_dict('records')
|
||||||
|
|
||||||
|
# processing objects
|
||||||
|
records = []
|
||||||
|
max_workers = max_workers or os.cpu_count()
|
||||||
|
try:
|
||||||
|
with ThreadPoolExecutor(max_workers=max_workers) as executor, \
|
||||||
|
tqdm(total=len(metadata), desc=desc) as pbar:
|
||||||
|
def worker(metadatum):
|
||||||
|
try:
|
||||||
|
local_path = metadatum['local_path']
|
||||||
|
sha256 = metadatum['sha256']
|
||||||
|
file = os.path.join(output_dir, local_path)
|
||||||
|
record = func(file, sha256)
|
||||||
|
if record is not None:
|
||||||
|
records.append(record)
|
||||||
|
pbar.update()
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error processing object {sha256}: {e}")
|
||||||
|
pbar.update()
|
||||||
|
|
||||||
|
executor.map(worker, metadata)
|
||||||
|
executor.shutdown(wait=True)
|
||||||
|
except:
|
||||||
|
print("Error happened during processing.")
|
||||||
|
|
||||||
|
return pd.DataFrame.from_records(records)
|
||||||
96
dataset_toolkits/datasets/ABO.py
Normal file
96
dataset_toolkits/datasets/ABO.py
Normal file
@@ -0,0 +1,96 @@
|
|||||||
|
import os
|
||||||
|
import re
|
||||||
|
import argparse
|
||||||
|
import tarfile
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
from tqdm import tqdm
|
||||||
|
import pandas as pd
|
||||||
|
from utils import get_file_hash
|
||||||
|
|
||||||
|
|
||||||
|
def add_args(parser: argparse.ArgumentParser):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def get_metadata(**kwargs):
|
||||||
|
metadata = pd.read_csv("hf://datasets/JeffreyXiang/TRELLIS-500K/ABO.csv")
|
||||||
|
return metadata
|
||||||
|
|
||||||
|
|
||||||
|
def download(metadata, output_dir, **kwargs):
|
||||||
|
os.makedirs(os.path.join(output_dir, 'raw'), exist_ok=True)
|
||||||
|
|
||||||
|
if not os.path.exists(os.path.join(output_dir, 'raw', 'abo-3dmodels.tar')):
|
||||||
|
try:
|
||||||
|
os.makedirs(os.path.join(output_dir, 'raw'), exist_ok=True)
|
||||||
|
os.system(f"wget -O {output_dir}/raw/abo-3dmodels.tar https://amazon-berkeley-objects.s3.amazonaws.com/archives/abo-3dmodels.tar")
|
||||||
|
except:
|
||||||
|
print("\033[93m")
|
||||||
|
print("Error downloading ABO dataset. Please check your internet connection and try again.")
|
||||||
|
print("Or, you can manually download the abo-3dmodels.tar file and place it in the {output_dir}/raw directory")
|
||||||
|
print("Visit https://amazon-berkeley-objects.s3.amazonaws.com/index.html for more information")
|
||||||
|
print("\033[0m")
|
||||||
|
raise FileNotFoundError("Error downloading ABO dataset")
|
||||||
|
|
||||||
|
downloaded = {}
|
||||||
|
metadata = metadata.set_index("file_identifier")
|
||||||
|
with tarfile.open(os.path.join(output_dir, 'raw', 'abo-3dmodels.tar')) as tar:
|
||||||
|
with ThreadPoolExecutor(max_workers=1) as executor, \
|
||||||
|
tqdm(total=len(metadata), desc="Extracting") as pbar:
|
||||||
|
def worker(instance: str) -> str:
|
||||||
|
try:
|
||||||
|
tar.extract(f"3dmodels/original/{instance}", path=os.path.join(output_dir, 'raw'))
|
||||||
|
sha256 = get_file_hash(os.path.join(output_dir, 'raw/3dmodels/original', instance))
|
||||||
|
pbar.update()
|
||||||
|
return sha256
|
||||||
|
except Exception as e:
|
||||||
|
pbar.update()
|
||||||
|
print(f"Error extracting for {instance}: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
sha256s = executor.map(worker, metadata.index)
|
||||||
|
executor.shutdown(wait=True)
|
||||||
|
|
||||||
|
for k, sha256 in zip(metadata.index, sha256s):
|
||||||
|
if sha256 is not None:
|
||||||
|
if sha256 == metadata.loc[k, "sha256"]:
|
||||||
|
downloaded[sha256] = os.path.join('raw/3dmodels/original', k)
|
||||||
|
else:
|
||||||
|
print(f"Error downloading {k}: sha256s do not match")
|
||||||
|
|
||||||
|
return pd.DataFrame(downloaded.items(), columns=['sha256', 'local_path'])
|
||||||
|
|
||||||
|
|
||||||
|
def foreach_instance(metadata, output_dir, func, max_workers=None, desc='Processing objects') -> pd.DataFrame:
|
||||||
|
import os
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
# load metadata
|
||||||
|
metadata = metadata.to_dict('records')
|
||||||
|
|
||||||
|
# processing objects
|
||||||
|
records = []
|
||||||
|
max_workers = max_workers or os.cpu_count()
|
||||||
|
try:
|
||||||
|
with ThreadPoolExecutor(max_workers=max_workers) as executor, \
|
||||||
|
tqdm(total=len(metadata), desc=desc) as pbar:
|
||||||
|
def worker(metadatum):
|
||||||
|
try:
|
||||||
|
local_path = metadatum['local_path']
|
||||||
|
sha256 = metadatum['sha256']
|
||||||
|
file = os.path.join(output_dir, local_path)
|
||||||
|
record = func(file, sha256)
|
||||||
|
if record is not None:
|
||||||
|
records.append(record)
|
||||||
|
pbar.update()
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error processing object {sha256}: {e}")
|
||||||
|
pbar.update()
|
||||||
|
|
||||||
|
executor.map(worker, metadata)
|
||||||
|
executor.shutdown(wait=True)
|
||||||
|
except:
|
||||||
|
print("Error happened during processing.")
|
||||||
|
|
||||||
|
return pd.DataFrame.from_records(records)
|
||||||
6
trellis/__init__.py
Executable file
6
trellis/__init__.py
Executable file
@@ -0,0 +1,6 @@
|
|||||||
|
from . import models
|
||||||
|
from . import modules
|
||||||
|
from . import pipelines
|
||||||
|
from . import renderers
|
||||||
|
from . import representations
|
||||||
|
from . import utils
|
||||||
58
trellis/datasets/__init__.py
Normal file
58
trellis/datasets/__init__.py
Normal file
@@ -0,0 +1,58 @@
|
|||||||
|
import importlib
|
||||||
|
|
||||||
|
__attributes = {
|
||||||
|
'SparseStructure': 'sparse_structure',
|
||||||
|
|
||||||
|
'SparseFeat2Render': 'sparse_feat2render',
|
||||||
|
'SLat2Render':'structured_latent2render',
|
||||||
|
'Slat2RenderGeo':'structured_latent2render',
|
||||||
|
|
||||||
|
'SparseStructureLatent': 'sparse_structure_latent',
|
||||||
|
'TextConditionedSparseStructureLatent': 'sparse_structure_latent',
|
||||||
|
'ImageConditionedSparseStructureLatent': 'sparse_structure_latent',
|
||||||
|
|
||||||
|
'SLat': 'structured_latent',
|
||||||
|
'TextConditionedSLat': 'structured_latent',
|
||||||
|
'ImageConditionedSLat': 'structured_latent',
|
||||||
|
}
|
||||||
|
|
||||||
|
__submodules = []
|
||||||
|
|
||||||
|
__all__ = list(__attributes.keys()) + __submodules
|
||||||
|
|
||||||
|
def __getattr__(name):
|
||||||
|
if name not in globals():
|
||||||
|
if name in __attributes:
|
||||||
|
module_name = __attributes[name]
|
||||||
|
module = importlib.import_module(f".{module_name}", __name__)
|
||||||
|
globals()[name] = getattr(module, name)
|
||||||
|
elif name in __submodules:
|
||||||
|
module = importlib.import_module(f".{name}", __name__)
|
||||||
|
globals()[name] = module
|
||||||
|
else:
|
||||||
|
raise AttributeError(f"module {__name__} has no attribute {name}")
|
||||||
|
return globals()[name]
|
||||||
|
|
||||||
|
|
||||||
|
# For Pylance
|
||||||
|
if __name__ == '__main__':
|
||||||
|
from .sparse_structure import SparseStructure
|
||||||
|
|
||||||
|
from .sparse_feat2render import SparseFeat2Render
|
||||||
|
from .structured_latent2render import (
|
||||||
|
SLat2Render,
|
||||||
|
Slat2RenderGeo,
|
||||||
|
)
|
||||||
|
|
||||||
|
from .sparse_structure_latent import (
|
||||||
|
SparseStructureLatent,
|
||||||
|
TextConditionedSparseStructureLatent,
|
||||||
|
ImageConditionedSparseStructureLatent,
|
||||||
|
)
|
||||||
|
|
||||||
|
from .structured_latent import (
|
||||||
|
SLat,
|
||||||
|
TextConditionedSLat,
|
||||||
|
ImageConditionedSLat,
|
||||||
|
)
|
||||||
|
|
||||||
96
trellis/models/__init__.py
Normal file
96
trellis/models/__init__.py
Normal file
@@ -0,0 +1,96 @@
|
|||||||
|
import importlib
|
||||||
|
|
||||||
|
__attributes = {
|
||||||
|
'SparseStructureEncoder': 'sparse_structure_vae',
|
||||||
|
'SparseStructureDecoder': 'sparse_structure_vae',
|
||||||
|
|
||||||
|
'SparseStructureFlowModel': 'sparse_structure_flow',
|
||||||
|
|
||||||
|
'SLatEncoder': 'structured_latent_vae',
|
||||||
|
'SLatGaussianDecoder': 'structured_latent_vae',
|
||||||
|
'SLatRadianceFieldDecoder': 'structured_latent_vae',
|
||||||
|
'SLatMeshDecoder': 'structured_latent_vae',
|
||||||
|
'ElasticSLatEncoder': 'structured_latent_vae',
|
||||||
|
'ElasticSLatGaussianDecoder': 'structured_latent_vae',
|
||||||
|
'ElasticSLatRadianceFieldDecoder': 'structured_latent_vae',
|
||||||
|
'ElasticSLatMeshDecoder': 'structured_latent_vae',
|
||||||
|
|
||||||
|
'SLatFlowModel': 'structured_latent_flow',
|
||||||
|
'ElasticSLatFlowModel': 'structured_latent_flow',
|
||||||
|
}
|
||||||
|
|
||||||
|
__submodules = []
|
||||||
|
|
||||||
|
__all__ = list(__attributes.keys()) + __submodules
|
||||||
|
|
||||||
|
def __getattr__(name):
|
||||||
|
if name not in globals():
|
||||||
|
if name in __attributes:
|
||||||
|
module_name = __attributes[name]
|
||||||
|
module = importlib.import_module(f".{module_name}", __name__)
|
||||||
|
globals()[name] = getattr(module, name)
|
||||||
|
elif name in __submodules:
|
||||||
|
module = importlib.import_module(f".{name}", __name__)
|
||||||
|
globals()[name] = module
|
||||||
|
else:
|
||||||
|
raise AttributeError(f"module {__name__} has no attribute {name}")
|
||||||
|
return globals()[name]
|
||||||
|
|
||||||
|
|
||||||
|
def from_pretrained(path: str, **kwargs):
|
||||||
|
"""
|
||||||
|
Load a model from a pretrained checkpoint.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
path: The path to the checkpoint. Can be either local path or a Hugging Face model name.
|
||||||
|
NOTE: config file and model file should take the name f'{path}.json' and f'{path}.safetensors' respectively.
|
||||||
|
**kwargs: Additional arguments for the model constructor.
|
||||||
|
"""
|
||||||
|
import os
|
||||||
|
import json
|
||||||
|
from safetensors.torch import load_file
|
||||||
|
is_local = os.path.exists(f"{path}.json") and os.path.exists(f"{path}.safetensors")
|
||||||
|
|
||||||
|
if is_local:
|
||||||
|
config_file = f"{path}.json"
|
||||||
|
model_file = f"{path}.safetensors"
|
||||||
|
else:
|
||||||
|
from huggingface_hub import hf_hub_download
|
||||||
|
path_parts = path.split('/')
|
||||||
|
repo_id = f'{path_parts[0]}/{path_parts[1]}'
|
||||||
|
model_name = '/'.join(path_parts[2:])
|
||||||
|
config_file = hf_hub_download(repo_id, f"{model_name}.json")
|
||||||
|
model_file = hf_hub_download(repo_id, f"{model_name}.safetensors")
|
||||||
|
|
||||||
|
with open(config_file, 'r') as f:
|
||||||
|
config = json.load(f)
|
||||||
|
model = __getattr__(config['name'])(**config['args'], **kwargs)
|
||||||
|
model.load_state_dict(load_file(model_file))
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
# For Pylance
|
||||||
|
if __name__ == '__main__':
|
||||||
|
from .sparse_structure_vae import (
|
||||||
|
SparseStructureEncoder,
|
||||||
|
SparseStructureDecoder,
|
||||||
|
)
|
||||||
|
|
||||||
|
from .sparse_structure_flow import SparseStructureFlowModel
|
||||||
|
|
||||||
|
from .structured_latent_vae import (
|
||||||
|
SLatEncoder,
|
||||||
|
SLatGaussianDecoder,
|
||||||
|
SLatRadianceFieldDecoder,
|
||||||
|
SLatMeshDecoder,
|
||||||
|
ElasticSLatEncoder,
|
||||||
|
ElasticSLatGaussianDecoder,
|
||||||
|
ElasticSLatRadianceFieldDecoder,
|
||||||
|
ElasticSLatMeshDecoder,
|
||||||
|
)
|
||||||
|
|
||||||
|
from .structured_latent_flow import (
|
||||||
|
SLatFlowModel,
|
||||||
|
ElasticSLatFlowModel,
|
||||||
|
)
|
||||||
4
trellis/models/structured_latent_vae/__init__.py
Normal file
4
trellis/models/structured_latent_vae/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
from .encoder import SLatEncoder, ElasticSLatEncoder
|
||||||
|
from .decoder_gs import SLatGaussianDecoder, ElasticSLatGaussianDecoder
|
||||||
|
from .decoder_rf import SLatRadianceFieldDecoder, ElasticSLatRadianceFieldDecoder
|
||||||
|
from .decoder_mesh import SLatMeshDecoder, ElasticSLatMeshDecoder
|
||||||
117
trellis/models/structured_latent_vae/base.py
Normal file
117
trellis/models/structured_latent_vae/base.py
Normal file
@@ -0,0 +1,117 @@
|
|||||||
|
from typing import *
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from ...modules.utils import convert_module_to_f16, convert_module_to_f32
|
||||||
|
from ...modules import sparse as sp
|
||||||
|
from ...modules.transformer import AbsolutePositionEmbedder
|
||||||
|
from ...modules.sparse.transformer import SparseTransformerBlock
|
||||||
|
|
||||||
|
|
||||||
|
def block_attn_config(self):
|
||||||
|
"""
|
||||||
|
Return the attention configuration of the model.
|
||||||
|
"""
|
||||||
|
for i in range(self.num_blocks):
|
||||||
|
if self.attn_mode == "shift_window":
|
||||||
|
yield "serialized", self.window_size, 0, (16 * (i % 2),) * 3, sp.SerializeMode.Z_ORDER
|
||||||
|
elif self.attn_mode == "shift_sequence":
|
||||||
|
yield "serialized", self.window_size, self.window_size // 2 * (i % 2), (0, 0, 0), sp.SerializeMode.Z_ORDER
|
||||||
|
elif self.attn_mode == "shift_order":
|
||||||
|
yield "serialized", self.window_size, 0, (0, 0, 0), sp.SerializeModes[i % 4]
|
||||||
|
elif self.attn_mode == "full":
|
||||||
|
yield "full", None, None, None, None
|
||||||
|
elif self.attn_mode == "swin":
|
||||||
|
yield "windowed", self.window_size, None, self.window_size // 2 * (i % 2), None
|
||||||
|
|
||||||
|
|
||||||
|
class SparseTransformerBase(nn.Module):
|
||||||
|
"""
|
||||||
|
Sparse Transformer without output layers.
|
||||||
|
Serve as the base class for encoder and decoder.
|
||||||
|
"""
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels: int,
|
||||||
|
model_channels: int,
|
||||||
|
num_blocks: int,
|
||||||
|
num_heads: Optional[int] = None,
|
||||||
|
num_head_channels: Optional[int] = 64,
|
||||||
|
mlp_ratio: float = 4.0,
|
||||||
|
attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "full",
|
||||||
|
window_size: Optional[int] = None,
|
||||||
|
pe_mode: Literal["ape", "rope"] = "ape",
|
||||||
|
use_fp16: bool = False,
|
||||||
|
use_checkpoint: bool = False,
|
||||||
|
qk_rms_norm: bool = False,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.in_channels = in_channels
|
||||||
|
self.model_channels = model_channels
|
||||||
|
self.num_blocks = num_blocks
|
||||||
|
self.window_size = window_size
|
||||||
|
self.num_heads = num_heads or model_channels // num_head_channels
|
||||||
|
self.mlp_ratio = mlp_ratio
|
||||||
|
self.attn_mode = attn_mode
|
||||||
|
self.pe_mode = pe_mode
|
||||||
|
self.use_fp16 = use_fp16
|
||||||
|
self.use_checkpoint = use_checkpoint
|
||||||
|
self.qk_rms_norm = qk_rms_norm
|
||||||
|
self.dtype = torch.float16 if use_fp16 else torch.float32
|
||||||
|
|
||||||
|
if pe_mode == "ape":
|
||||||
|
self.pos_embedder = AbsolutePositionEmbedder(model_channels)
|
||||||
|
|
||||||
|
self.input_layer = sp.SparseLinear(in_channels, model_channels)
|
||||||
|
self.blocks = nn.ModuleList([
|
||||||
|
SparseTransformerBlock(
|
||||||
|
model_channels,
|
||||||
|
num_heads=self.num_heads,
|
||||||
|
mlp_ratio=self.mlp_ratio,
|
||||||
|
attn_mode=attn_mode,
|
||||||
|
window_size=window_size,
|
||||||
|
shift_sequence=shift_sequence,
|
||||||
|
shift_window=shift_window,
|
||||||
|
serialize_mode=serialize_mode,
|
||||||
|
use_checkpoint=self.use_checkpoint,
|
||||||
|
use_rope=(pe_mode == "rope"),
|
||||||
|
qk_rms_norm=self.qk_rms_norm,
|
||||||
|
)
|
||||||
|
for attn_mode, window_size, shift_sequence, shift_window, serialize_mode in block_attn_config(self)
|
||||||
|
])
|
||||||
|
|
||||||
|
@property
|
||||||
|
def device(self) -> torch.device:
|
||||||
|
"""
|
||||||
|
Return the device of the model.
|
||||||
|
"""
|
||||||
|
return next(self.parameters()).device
|
||||||
|
|
||||||
|
def convert_to_fp16(self) -> None:
|
||||||
|
"""
|
||||||
|
Convert the torso of the model to float16.
|
||||||
|
"""
|
||||||
|
self.blocks.apply(convert_module_to_f16)
|
||||||
|
|
||||||
|
def convert_to_fp32(self) -> None:
|
||||||
|
"""
|
||||||
|
Convert the torso of the model to float32.
|
||||||
|
"""
|
||||||
|
self.blocks.apply(convert_module_to_f32)
|
||||||
|
|
||||||
|
def initialize_weights(self) -> None:
|
||||||
|
# Initialize transformer layers:
|
||||||
|
def _basic_init(module):
|
||||||
|
if isinstance(module, nn.Linear):
|
||||||
|
torch.nn.init.xavier_uniform_(module.weight)
|
||||||
|
if module.bias is not None:
|
||||||
|
nn.init.constant_(module.bias, 0)
|
||||||
|
self.apply(_basic_init)
|
||||||
|
|
||||||
|
def forward(self, x: sp.SparseTensor) -> sp.SparseTensor:
|
||||||
|
h = self.input_layer(x)
|
||||||
|
if self.pe_mode == "ape":
|
||||||
|
h = h + self.pos_embedder(x.coords[:, 1:])
|
||||||
|
h = h.type(self.dtype)
|
||||||
|
for block in self.blocks:
|
||||||
|
h = block(h)
|
||||||
|
return h
|
||||||
36
trellis/modules/attention/__init__.py
Executable file
36
trellis/modules/attention/__init__.py
Executable file
@@ -0,0 +1,36 @@
|
|||||||
|
from typing import *
|
||||||
|
|
||||||
|
BACKEND = 'flash_attn'
|
||||||
|
DEBUG = False
|
||||||
|
|
||||||
|
def __from_env():
|
||||||
|
import os
|
||||||
|
|
||||||
|
global BACKEND
|
||||||
|
global DEBUG
|
||||||
|
|
||||||
|
env_attn_backend = os.environ.get('ATTN_BACKEND')
|
||||||
|
env_sttn_debug = os.environ.get('ATTN_DEBUG')
|
||||||
|
|
||||||
|
if env_attn_backend is not None and env_attn_backend in ['xformers', 'flash_attn', 'sdpa', 'naive']:
|
||||||
|
BACKEND = env_attn_backend
|
||||||
|
if env_sttn_debug is not None:
|
||||||
|
DEBUG = env_sttn_debug == '1'
|
||||||
|
|
||||||
|
print(f"[ATTENTION] Using backend: {BACKEND}")
|
||||||
|
|
||||||
|
|
||||||
|
__from_env()
|
||||||
|
|
||||||
|
|
||||||
|
def set_backend(backend: Literal['xformers', 'flash_attn']):
|
||||||
|
global BACKEND
|
||||||
|
BACKEND = backend
|
||||||
|
|
||||||
|
def set_debug(debug: bool):
|
||||||
|
global DEBUG
|
||||||
|
DEBUG = debug
|
||||||
|
|
||||||
|
|
||||||
|
from .full_attn import *
|
||||||
|
from .modules import *
|
||||||
102
trellis/modules/sparse/__init__.py
Executable file
102
trellis/modules/sparse/__init__.py
Executable file
@@ -0,0 +1,102 @@
|
|||||||
|
from typing import *
|
||||||
|
|
||||||
|
BACKEND = 'spconv'
|
||||||
|
DEBUG = False
|
||||||
|
ATTN = 'flash_attn'
|
||||||
|
|
||||||
|
def __from_env():
|
||||||
|
import os
|
||||||
|
|
||||||
|
global BACKEND
|
||||||
|
global DEBUG
|
||||||
|
global ATTN
|
||||||
|
|
||||||
|
env_sparse_backend = os.environ.get('SPARSE_BACKEND')
|
||||||
|
env_sparse_debug = os.environ.get('SPARSE_DEBUG')
|
||||||
|
env_sparse_attn = os.environ.get('SPARSE_ATTN_BACKEND')
|
||||||
|
if env_sparse_attn is None:
|
||||||
|
env_sparse_attn = os.environ.get('ATTN_BACKEND')
|
||||||
|
|
||||||
|
if env_sparse_backend is not None and env_sparse_backend in ['spconv', 'torchsparse']:
|
||||||
|
BACKEND = env_sparse_backend
|
||||||
|
if env_sparse_debug is not None:
|
||||||
|
DEBUG = env_sparse_debug == '1'
|
||||||
|
if env_sparse_attn is not None and env_sparse_attn in ['xformers', 'flash_attn']:
|
||||||
|
ATTN = env_sparse_attn
|
||||||
|
|
||||||
|
print(f"[SPARSE] Backend: {BACKEND}, Attention: {ATTN}")
|
||||||
|
|
||||||
|
|
||||||
|
__from_env()
|
||||||
|
|
||||||
|
|
||||||
|
def set_backend(backend: Literal['spconv', 'torchsparse']):
|
||||||
|
global BACKEND
|
||||||
|
BACKEND = backend
|
||||||
|
|
||||||
|
def set_debug(debug: bool):
|
||||||
|
global DEBUG
|
||||||
|
DEBUG = debug
|
||||||
|
|
||||||
|
def set_attn(attn: Literal['xformers', 'flash_attn']):
|
||||||
|
global ATTN
|
||||||
|
ATTN = attn
|
||||||
|
|
||||||
|
|
||||||
|
import importlib
|
||||||
|
|
||||||
|
__attributes = {
|
||||||
|
'SparseTensor': 'basic',
|
||||||
|
'sparse_batch_broadcast': 'basic',
|
||||||
|
'sparse_batch_op': 'basic',
|
||||||
|
'sparse_cat': 'basic',
|
||||||
|
'sparse_unbind': 'basic',
|
||||||
|
'SparseGroupNorm': 'norm',
|
||||||
|
'SparseLayerNorm': 'norm',
|
||||||
|
'SparseGroupNorm32': 'norm',
|
||||||
|
'SparseLayerNorm32': 'norm',
|
||||||
|
'SparseReLU': 'nonlinearity',
|
||||||
|
'SparseSiLU': 'nonlinearity',
|
||||||
|
'SparseGELU': 'nonlinearity',
|
||||||
|
'SparseActivation': 'nonlinearity',
|
||||||
|
'SparseLinear': 'linear',
|
||||||
|
'sparse_scaled_dot_product_attention': 'attention',
|
||||||
|
'SerializeMode': 'attention',
|
||||||
|
'sparse_serialized_scaled_dot_product_self_attention': 'attention',
|
||||||
|
'sparse_windowed_scaled_dot_product_self_attention': 'attention',
|
||||||
|
'SparseMultiHeadAttention': 'attention',
|
||||||
|
'SparseConv3d': 'conv',
|
||||||
|
'SparseInverseConv3d': 'conv',
|
||||||
|
'SparseDownsample': 'spatial',
|
||||||
|
'SparseUpsample': 'spatial',
|
||||||
|
'SparseSubdivide' : 'spatial'
|
||||||
|
}
|
||||||
|
|
||||||
|
__submodules = ['transformer']
|
||||||
|
|
||||||
|
__all__ = list(__attributes.keys()) + __submodules
|
||||||
|
|
||||||
|
def __getattr__(name):
|
||||||
|
if name not in globals():
|
||||||
|
if name in __attributes:
|
||||||
|
module_name = __attributes[name]
|
||||||
|
module = importlib.import_module(f".{module_name}", __name__)
|
||||||
|
globals()[name] = getattr(module, name)
|
||||||
|
elif name in __submodules:
|
||||||
|
module = importlib.import_module(f".{name}", __name__)
|
||||||
|
globals()[name] = module
|
||||||
|
else:
|
||||||
|
raise AttributeError(f"module {__name__} has no attribute {name}")
|
||||||
|
return globals()[name]
|
||||||
|
|
||||||
|
|
||||||
|
# For Pylance
|
||||||
|
if __name__ == '__main__':
|
||||||
|
from .basic import *
|
||||||
|
from .norm import *
|
||||||
|
from .nonlinearity import *
|
||||||
|
from .linear import *
|
||||||
|
from .attention import *
|
||||||
|
from .conv import *
|
||||||
|
from .spatial import *
|
||||||
|
import transformer
|
||||||
4
trellis/modules/sparse/attention/__init__.py
Executable file
4
trellis/modules/sparse/attention/__init__.py
Executable file
@@ -0,0 +1,4 @@
|
|||||||
|
from .full_attn import *
|
||||||
|
from .serialized_attn import *
|
||||||
|
from .windowed_attn import *
|
||||||
|
from .modules import *
|
||||||
459
trellis/modules/sparse/basic.py
Executable file
459
trellis/modules/sparse/basic.py
Executable file
@@ -0,0 +1,459 @@
|
|||||||
|
from typing import *
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from . import BACKEND, DEBUG
|
||||||
|
SparseTensorData = None # Lazy import
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
'SparseTensor',
|
||||||
|
'sparse_batch_broadcast',
|
||||||
|
'sparse_batch_op',
|
||||||
|
'sparse_cat',
|
||||||
|
'sparse_unbind',
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class SparseTensor:
|
||||||
|
"""
|
||||||
|
Sparse tensor with support for both torchsparse and spconv backends.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
- feats (torch.Tensor): Features of the sparse tensor.
|
||||||
|
- coords (torch.Tensor): Coordinates of the sparse tensor.
|
||||||
|
- shape (torch.Size): Shape of the sparse tensor.
|
||||||
|
- layout (List[slice]): Layout of the sparse tensor for each batch
|
||||||
|
- data (SparseTensorData): Sparse tensor data used for convolusion
|
||||||
|
|
||||||
|
NOTE:
|
||||||
|
- Data corresponding to a same batch should be contiguous.
|
||||||
|
- Coords should be in [0, 1023]
|
||||||
|
"""
|
||||||
|
@overload
|
||||||
|
def __init__(self, feats: torch.Tensor, coords: torch.Tensor, shape: Optional[torch.Size] = None, layout: Optional[List[slice]] = None, **kwargs): ...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def __init__(self, data, shape: Optional[torch.Size] = None, layout: Optional[List[slice]] = None, **kwargs): ...
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
# Lazy import of sparse tensor backend
|
||||||
|
global SparseTensorData
|
||||||
|
if SparseTensorData is None:
|
||||||
|
import importlib
|
||||||
|
if BACKEND == 'torchsparse':
|
||||||
|
SparseTensorData = importlib.import_module('torchsparse').SparseTensor
|
||||||
|
elif BACKEND == 'spconv':
|
||||||
|
SparseTensorData = importlib.import_module('spconv.pytorch').SparseConvTensor
|
||||||
|
|
||||||
|
method_id = 0
|
||||||
|
if len(args) != 0:
|
||||||
|
method_id = 0 if isinstance(args[0], torch.Tensor) else 1
|
||||||
|
else:
|
||||||
|
method_id = 1 if 'data' in kwargs else 0
|
||||||
|
|
||||||
|
if method_id == 0:
|
||||||
|
feats, coords, shape, layout = args + (None,) * (4 - len(args))
|
||||||
|
if 'feats' in kwargs:
|
||||||
|
feats = kwargs['feats']
|
||||||
|
del kwargs['feats']
|
||||||
|
if 'coords' in kwargs:
|
||||||
|
coords = kwargs['coords']
|
||||||
|
del kwargs['coords']
|
||||||
|
if 'shape' in kwargs:
|
||||||
|
shape = kwargs['shape']
|
||||||
|
del kwargs['shape']
|
||||||
|
if 'layout' in kwargs:
|
||||||
|
layout = kwargs['layout']
|
||||||
|
del kwargs['layout']
|
||||||
|
|
||||||
|
if shape is None:
|
||||||
|
shape = self.__cal_shape(feats, coords)
|
||||||
|
if layout is None:
|
||||||
|
layout = self.__cal_layout(coords, shape[0])
|
||||||
|
if BACKEND == 'torchsparse':
|
||||||
|
self.data = SparseTensorData(feats, coords, **kwargs)
|
||||||
|
elif BACKEND == 'spconv':
|
||||||
|
spatial_shape = list(coords.max(0)[0] + 1)[1:]
|
||||||
|
self.data = SparseTensorData(feats.reshape(feats.shape[0], -1), coords, spatial_shape, shape[0], **kwargs)
|
||||||
|
self.data._features = feats
|
||||||
|
elif method_id == 1:
|
||||||
|
data, shape, layout = args + (None,) * (3 - len(args))
|
||||||
|
if 'data' in kwargs:
|
||||||
|
data = kwargs['data']
|
||||||
|
del kwargs['data']
|
||||||
|
if 'shape' in kwargs:
|
||||||
|
shape = kwargs['shape']
|
||||||
|
del kwargs['shape']
|
||||||
|
if 'layout' in kwargs:
|
||||||
|
layout = kwargs['layout']
|
||||||
|
del kwargs['layout']
|
||||||
|
|
||||||
|
self.data = data
|
||||||
|
if shape is None:
|
||||||
|
shape = self.__cal_shape(self.feats, self.coords)
|
||||||
|
if layout is None:
|
||||||
|
layout = self.__cal_layout(self.coords, shape[0])
|
||||||
|
|
||||||
|
self._shape = shape
|
||||||
|
self._layout = layout
|
||||||
|
self._scale = kwargs.get('scale', (1, 1, 1))
|
||||||
|
self._spatial_cache = kwargs.get('spatial_cache', {})
|
||||||
|
|
||||||
|
if DEBUG:
|
||||||
|
try:
|
||||||
|
assert self.feats.shape[0] == self.coords.shape[0], f"Invalid feats shape: {self.feats.shape}, coords shape: {self.coords.shape}"
|
||||||
|
assert self.shape == self.__cal_shape(self.feats, self.coords), f"Invalid shape: {self.shape}"
|
||||||
|
assert self.layout == self.__cal_layout(self.coords, self.shape[0]), f"Invalid layout: {self.layout}"
|
||||||
|
for i in range(self.shape[0]):
|
||||||
|
assert torch.all(self.coords[self.layout[i], 0] == i), f"The data of batch {i} is not contiguous"
|
||||||
|
except Exception as e:
|
||||||
|
print('Debugging information:')
|
||||||
|
print(f"- Shape: {self.shape}")
|
||||||
|
print(f"- Layout: {self.layout}")
|
||||||
|
print(f"- Scale: {self._scale}")
|
||||||
|
print(f"- Coords: {self.coords}")
|
||||||
|
raise e
|
||||||
|
|
||||||
|
def __cal_shape(self, feats, coords):
|
||||||
|
shape = []
|
||||||
|
shape.append(coords[:, 0].max().item() + 1)
|
||||||
|
shape.extend([*feats.shape[1:]])
|
||||||
|
return torch.Size(shape)
|
||||||
|
|
||||||
|
def __cal_layout(self, coords, batch_size):
|
||||||
|
seq_len = torch.bincount(coords[:, 0], minlength=batch_size)
|
||||||
|
offset = torch.cumsum(seq_len, dim=0)
|
||||||
|
layout = [slice((offset[i] - seq_len[i]).item(), offset[i].item()) for i in range(batch_size)]
|
||||||
|
return layout
|
||||||
|
|
||||||
|
@property
|
||||||
|
def shape(self) -> torch.Size:
|
||||||
|
return self._shape
|
||||||
|
|
||||||
|
def dim(self) -> int:
|
||||||
|
return len(self.shape)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def layout(self) -> List[slice]:
|
||||||
|
return self._layout
|
||||||
|
|
||||||
|
@property
|
||||||
|
def feats(self) -> torch.Tensor:
|
||||||
|
if BACKEND == 'torchsparse':
|
||||||
|
return self.data.F
|
||||||
|
elif BACKEND == 'spconv':
|
||||||
|
return self.data.features
|
||||||
|
|
||||||
|
@feats.setter
|
||||||
|
def feats(self, value: torch.Tensor):
|
||||||
|
if BACKEND == 'torchsparse':
|
||||||
|
self.data.F = value
|
||||||
|
elif BACKEND == 'spconv':
|
||||||
|
self.data.features = value
|
||||||
|
|
||||||
|
@property
|
||||||
|
def coords(self) -> torch.Tensor:
|
||||||
|
if BACKEND == 'torchsparse':
|
||||||
|
return self.data.C
|
||||||
|
elif BACKEND == 'spconv':
|
||||||
|
return self.data.indices
|
||||||
|
|
||||||
|
@coords.setter
|
||||||
|
def coords(self, value: torch.Tensor):
|
||||||
|
if BACKEND == 'torchsparse':
|
||||||
|
self.data.C = value
|
||||||
|
elif BACKEND == 'spconv':
|
||||||
|
self.data.indices = value
|
||||||
|
|
||||||
|
@property
|
||||||
|
def dtype(self):
|
||||||
|
return self.feats.dtype
|
||||||
|
|
||||||
|
@property
|
||||||
|
def device(self):
|
||||||
|
return self.feats.device
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def to(self, dtype: torch.dtype) -> 'SparseTensor': ...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def to(self, device: Optional[Union[str, torch.device]] = None, dtype: Optional[torch.dtype] = None) -> 'SparseTensor': ...
|
||||||
|
|
||||||
|
def to(self, *args, **kwargs) -> 'SparseTensor':
|
||||||
|
device = None
|
||||||
|
dtype = None
|
||||||
|
if len(args) == 2:
|
||||||
|
device, dtype = args
|
||||||
|
elif len(args) == 1:
|
||||||
|
if isinstance(args[0], torch.dtype):
|
||||||
|
dtype = args[0]
|
||||||
|
else:
|
||||||
|
device = args[0]
|
||||||
|
if 'dtype' in kwargs:
|
||||||
|
assert dtype is None, "to() received multiple values for argument 'dtype'"
|
||||||
|
dtype = kwargs['dtype']
|
||||||
|
if 'device' in kwargs:
|
||||||
|
assert device is None, "to() received multiple values for argument 'device'"
|
||||||
|
device = kwargs['device']
|
||||||
|
|
||||||
|
new_feats = self.feats.to(device=device, dtype=dtype)
|
||||||
|
new_coords = self.coords.to(device=device)
|
||||||
|
return self.replace(new_feats, new_coords)
|
||||||
|
|
||||||
|
def type(self, dtype):
|
||||||
|
new_feats = self.feats.type(dtype)
|
||||||
|
return self.replace(new_feats)
|
||||||
|
|
||||||
|
def cpu(self) -> 'SparseTensor':
|
||||||
|
new_feats = self.feats.cpu()
|
||||||
|
new_coords = self.coords.cpu()
|
||||||
|
return self.replace(new_feats, new_coords)
|
||||||
|
|
||||||
|
def cuda(self) -> 'SparseTensor':
|
||||||
|
new_feats = self.feats.cuda()
|
||||||
|
new_coords = self.coords.cuda()
|
||||||
|
return self.replace(new_feats, new_coords)
|
||||||
|
|
||||||
|
def half(self) -> 'SparseTensor':
|
||||||
|
new_feats = self.feats.half()
|
||||||
|
return self.replace(new_feats)
|
||||||
|
|
||||||
|
def float(self) -> 'SparseTensor':
|
||||||
|
new_feats = self.feats.float()
|
||||||
|
return self.replace(new_feats)
|
||||||
|
|
||||||
|
def detach(self) -> 'SparseTensor':
|
||||||
|
new_coords = self.coords.detach()
|
||||||
|
new_feats = self.feats.detach()
|
||||||
|
return self.replace(new_feats, new_coords)
|
||||||
|
|
||||||
|
def dense(self) -> torch.Tensor:
|
||||||
|
if BACKEND == 'torchsparse':
|
||||||
|
return self.data.dense()
|
||||||
|
elif BACKEND == 'spconv':
|
||||||
|
return self.data.dense()
|
||||||
|
|
||||||
|
def reshape(self, *shape) -> 'SparseTensor':
|
||||||
|
new_feats = self.feats.reshape(self.feats.shape[0], *shape)
|
||||||
|
return self.replace(new_feats)
|
||||||
|
|
||||||
|
def unbind(self, dim: int) -> List['SparseTensor']:
|
||||||
|
return sparse_unbind(self, dim)
|
||||||
|
|
||||||
|
def replace(self, feats: torch.Tensor, coords: Optional[torch.Tensor] = None) -> 'SparseTensor':
|
||||||
|
new_shape = [self.shape[0]]
|
||||||
|
new_shape.extend(feats.shape[1:])
|
||||||
|
if BACKEND == 'torchsparse':
|
||||||
|
new_data = SparseTensorData(
|
||||||
|
feats=feats,
|
||||||
|
coords=self.data.coords if coords is None else coords,
|
||||||
|
stride=self.data.stride,
|
||||||
|
spatial_range=self.data.spatial_range,
|
||||||
|
)
|
||||||
|
new_data._caches = self.data._caches
|
||||||
|
elif BACKEND == 'spconv':
|
||||||
|
new_data = SparseTensorData(
|
||||||
|
self.data.features.reshape(self.data.features.shape[0], -1),
|
||||||
|
self.data.indices,
|
||||||
|
self.data.spatial_shape,
|
||||||
|
self.data.batch_size,
|
||||||
|
self.data.grid,
|
||||||
|
self.data.voxel_num,
|
||||||
|
self.data.indice_dict
|
||||||
|
)
|
||||||
|
new_data._features = feats
|
||||||
|
new_data.benchmark = self.data.benchmark
|
||||||
|
new_data.benchmark_record = self.data.benchmark_record
|
||||||
|
new_data.thrust_allocator = self.data.thrust_allocator
|
||||||
|
new_data._timer = self.data._timer
|
||||||
|
new_data.force_algo = self.data.force_algo
|
||||||
|
new_data.int8_scale = self.data.int8_scale
|
||||||
|
if coords is not None:
|
||||||
|
new_data.indices = coords
|
||||||
|
new_tensor = SparseTensor(new_data, shape=torch.Size(new_shape), layout=self.layout, scale=self._scale, spatial_cache=self._spatial_cache)
|
||||||
|
return new_tensor
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def full(aabb, dim, value, dtype=torch.float32, device=None) -> 'SparseTensor':
|
||||||
|
N, C = dim
|
||||||
|
x = torch.arange(aabb[0], aabb[3] + 1)
|
||||||
|
y = torch.arange(aabb[1], aabb[4] + 1)
|
||||||
|
z = torch.arange(aabb[2], aabb[5] + 1)
|
||||||
|
coords = torch.stack(torch.meshgrid(x, y, z, indexing='ij'), dim=-1).reshape(-1, 3)
|
||||||
|
coords = torch.cat([
|
||||||
|
torch.arange(N).view(-1, 1).repeat(1, coords.shape[0]).view(-1, 1),
|
||||||
|
coords.repeat(N, 1),
|
||||||
|
], dim=1).to(dtype=torch.int32, device=device)
|
||||||
|
feats = torch.full((coords.shape[0], C), value, dtype=dtype, device=device)
|
||||||
|
return SparseTensor(feats=feats, coords=coords)
|
||||||
|
|
||||||
|
def __merge_sparse_cache(self, other: 'SparseTensor') -> dict:
|
||||||
|
new_cache = {}
|
||||||
|
for k in set(list(self._spatial_cache.keys()) + list(other._spatial_cache.keys())):
|
||||||
|
if k in self._spatial_cache:
|
||||||
|
new_cache[k] = self._spatial_cache[k]
|
||||||
|
if k in other._spatial_cache:
|
||||||
|
if k not in new_cache:
|
||||||
|
new_cache[k] = other._spatial_cache[k]
|
||||||
|
else:
|
||||||
|
new_cache[k].update(other._spatial_cache[k])
|
||||||
|
return new_cache
|
||||||
|
|
||||||
|
def __neg__(self) -> 'SparseTensor':
|
||||||
|
return self.replace(-self.feats)
|
||||||
|
|
||||||
|
def __elemwise__(self, other: Union[torch.Tensor, 'SparseTensor'], op: callable) -> 'SparseTensor':
|
||||||
|
if isinstance(other, torch.Tensor):
|
||||||
|
try:
|
||||||
|
other = torch.broadcast_to(other, self.shape)
|
||||||
|
other = sparse_batch_broadcast(self, other)
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
if isinstance(other, SparseTensor):
|
||||||
|
other = other.feats
|
||||||
|
new_feats = op(self.feats, other)
|
||||||
|
new_tensor = self.replace(new_feats)
|
||||||
|
if isinstance(other, SparseTensor):
|
||||||
|
new_tensor._spatial_cache = self.__merge_sparse_cache(other)
|
||||||
|
return new_tensor
|
||||||
|
|
||||||
|
def __add__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor':
|
||||||
|
return self.__elemwise__(other, torch.add)
|
||||||
|
|
||||||
|
def __radd__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor':
|
||||||
|
return self.__elemwise__(other, torch.add)
|
||||||
|
|
||||||
|
def __sub__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor':
|
||||||
|
return self.__elemwise__(other, torch.sub)
|
||||||
|
|
||||||
|
def __rsub__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor':
|
||||||
|
return self.__elemwise__(other, lambda x, y: torch.sub(y, x))
|
||||||
|
|
||||||
|
def __mul__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor':
|
||||||
|
return self.__elemwise__(other, torch.mul)
|
||||||
|
|
||||||
|
def __rmul__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor':
|
||||||
|
return self.__elemwise__(other, torch.mul)
|
||||||
|
|
||||||
|
def __truediv__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor':
|
||||||
|
return self.__elemwise__(other, torch.div)
|
||||||
|
|
||||||
|
def __rtruediv__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor':
|
||||||
|
return self.__elemwise__(other, lambda x, y: torch.div(y, x))
|
||||||
|
|
||||||
|
def __getitem__(self, idx):
|
||||||
|
if isinstance(idx, int):
|
||||||
|
idx = [idx]
|
||||||
|
elif isinstance(idx, slice):
|
||||||
|
idx = range(*idx.indices(self.shape[0]))
|
||||||
|
elif isinstance(idx, torch.Tensor):
|
||||||
|
if idx.dtype == torch.bool:
|
||||||
|
assert idx.shape == (self.shape[0],), f"Invalid index shape: {idx.shape}"
|
||||||
|
idx = idx.nonzero().squeeze(1)
|
||||||
|
elif idx.dtype in [torch.int32, torch.int64]:
|
||||||
|
assert len(idx.shape) == 1, f"Invalid index shape: {idx.shape}"
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown index type: {idx.dtype}")
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown index type: {type(idx)}")
|
||||||
|
|
||||||
|
coords = []
|
||||||
|
feats = []
|
||||||
|
for new_idx, old_idx in enumerate(idx):
|
||||||
|
coords.append(self.coords[self.layout[old_idx]].clone())
|
||||||
|
coords[-1][:, 0] = new_idx
|
||||||
|
feats.append(self.feats[self.layout[old_idx]])
|
||||||
|
coords = torch.cat(coords, dim=0).contiguous()
|
||||||
|
feats = torch.cat(feats, dim=0).contiguous()
|
||||||
|
return SparseTensor(feats=feats, coords=coords)
|
||||||
|
|
||||||
|
def register_spatial_cache(self, key, value) -> None:
|
||||||
|
"""
|
||||||
|
Register a spatial cache.
|
||||||
|
The spatial cache can be any thing you want to cache.
|
||||||
|
The registery and retrieval of the cache is based on current scale.
|
||||||
|
"""
|
||||||
|
scale_key = str(self._scale)
|
||||||
|
if scale_key not in self._spatial_cache:
|
||||||
|
self._spatial_cache[scale_key] = {}
|
||||||
|
self._spatial_cache[scale_key][key] = value
|
||||||
|
|
||||||
|
def get_spatial_cache(self, key=None):
|
||||||
|
"""
|
||||||
|
Get a spatial cache.
|
||||||
|
"""
|
||||||
|
scale_key = str(self._scale)
|
||||||
|
cur_scale_cache = self._spatial_cache.get(scale_key, {})
|
||||||
|
if key is None:
|
||||||
|
return cur_scale_cache
|
||||||
|
return cur_scale_cache.get(key, None)
|
||||||
|
|
||||||
|
|
||||||
|
def sparse_batch_broadcast(input: SparseTensor, other: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Broadcast a 1D tensor to a sparse tensor along the batch dimension then perform an operation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input (torch.Tensor): 1D tensor to broadcast.
|
||||||
|
target (SparseTensor): Sparse tensor to broadcast to.
|
||||||
|
op (callable): Operation to perform after broadcasting. Defaults to torch.add.
|
||||||
|
"""
|
||||||
|
coords, feats = input.coords, input.feats
|
||||||
|
broadcasted = torch.zeros_like(feats)
|
||||||
|
for k in range(input.shape[0]):
|
||||||
|
broadcasted[input.layout[k]] = other[k]
|
||||||
|
return broadcasted
|
||||||
|
|
||||||
|
|
||||||
|
def sparse_batch_op(input: SparseTensor, other: torch.Tensor, op: callable = torch.add) -> SparseTensor:
|
||||||
|
"""
|
||||||
|
Broadcast a 1D tensor to a sparse tensor along the batch dimension then perform an operation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input (torch.Tensor): 1D tensor to broadcast.
|
||||||
|
target (SparseTensor): Sparse tensor to broadcast to.
|
||||||
|
op (callable): Operation to perform after broadcasting. Defaults to torch.add.
|
||||||
|
"""
|
||||||
|
return input.replace(op(input.feats, sparse_batch_broadcast(input, other)))
|
||||||
|
|
||||||
|
|
||||||
|
def sparse_cat(inputs: List[SparseTensor], dim: int = 0) -> SparseTensor:
|
||||||
|
"""
|
||||||
|
Concatenate a list of sparse tensors.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
inputs (List[SparseTensor]): List of sparse tensors to concatenate.
|
||||||
|
"""
|
||||||
|
if dim == 0:
|
||||||
|
start = 0
|
||||||
|
coords = []
|
||||||
|
for input in inputs:
|
||||||
|
coords.append(input.coords.clone())
|
||||||
|
coords[-1][:, 0] += start
|
||||||
|
start += input.shape[0]
|
||||||
|
coords = torch.cat(coords, dim=0)
|
||||||
|
feats = torch.cat([input.feats for input in inputs], dim=0)
|
||||||
|
output = SparseTensor(
|
||||||
|
coords=coords,
|
||||||
|
feats=feats,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
feats = torch.cat([input.feats for input in inputs], dim=dim)
|
||||||
|
output = inputs[0].replace(feats)
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
def sparse_unbind(input: SparseTensor, dim: int) -> List[SparseTensor]:
|
||||||
|
"""
|
||||||
|
Unbind a sparse tensor along a dimension.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input (SparseTensor): Sparse tensor to unbind.
|
||||||
|
dim (int): Dimension to unbind.
|
||||||
|
"""
|
||||||
|
if dim == 0:
|
||||||
|
return [input[i] for i in range(input.shape[0])]
|
||||||
|
else:
|
||||||
|
feats = input.feats.unbind(dim)
|
||||||
|
return [input.replace(f) for f in feats]
|
||||||
21
trellis/modules/sparse/conv/__init__.py
Executable file
21
trellis/modules/sparse/conv/__init__.py
Executable file
@@ -0,0 +1,21 @@
|
|||||||
|
from .. import BACKEND
|
||||||
|
|
||||||
|
|
||||||
|
SPCONV_ALGO = 'auto' # 'auto', 'implicit_gemm', 'native'
|
||||||
|
|
||||||
|
def __from_env():
|
||||||
|
import os
|
||||||
|
|
||||||
|
global SPCONV_ALGO
|
||||||
|
env_spconv_algo = os.environ.get('SPCONV_ALGO')
|
||||||
|
if env_spconv_algo is not None and env_spconv_algo in ['auto', 'implicit_gemm', 'native']:
|
||||||
|
SPCONV_ALGO = env_spconv_algo
|
||||||
|
print(f"[SPARSE][CONV] spconv algo: {SPCONV_ALGO}")
|
||||||
|
|
||||||
|
|
||||||
|
__from_env()
|
||||||
|
|
||||||
|
if BACKEND == 'torchsparse':
|
||||||
|
from .conv_torchsparse import *
|
||||||
|
elif BACKEND == 'spconv':
|
||||||
|
from .conv_spconv import *
|
||||||
2
trellis/modules/sparse/transformer/__init__.py
Normal file
2
trellis/modules/sparse/transformer/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
from .blocks import *
|
||||||
|
from .modulated import *
|
||||||
151
trellis/modules/sparse/transformer/blocks.py
Normal file
151
trellis/modules/sparse/transformer/blocks.py
Normal file
@@ -0,0 +1,151 @@
|
|||||||
|
from typing import *
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from ..basic import SparseTensor
|
||||||
|
from ..linear import SparseLinear
|
||||||
|
from ..nonlinearity import SparseGELU
|
||||||
|
from ..attention import SparseMultiHeadAttention, SerializeMode
|
||||||
|
from ...norm import LayerNorm32
|
||||||
|
|
||||||
|
|
||||||
|
class SparseFeedForwardNet(nn.Module):
|
||||||
|
def __init__(self, channels: int, mlp_ratio: float = 4.0):
|
||||||
|
super().__init__()
|
||||||
|
self.mlp = nn.Sequential(
|
||||||
|
SparseLinear(channels, int(channels * mlp_ratio)),
|
||||||
|
SparseGELU(approximate="tanh"),
|
||||||
|
SparseLinear(int(channels * mlp_ratio), channels),
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x: SparseTensor) -> SparseTensor:
|
||||||
|
return self.mlp(x)
|
||||||
|
|
||||||
|
|
||||||
|
class SparseTransformerBlock(nn.Module):
|
||||||
|
"""
|
||||||
|
Sparse Transformer block (MSA + FFN).
|
||||||
|
"""
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
channels: int,
|
||||||
|
num_heads: int,
|
||||||
|
mlp_ratio: float = 4.0,
|
||||||
|
attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "full",
|
||||||
|
window_size: Optional[int] = None,
|
||||||
|
shift_sequence: Optional[int] = None,
|
||||||
|
shift_window: Optional[Tuple[int, int, int]] = None,
|
||||||
|
serialize_mode: Optional[SerializeMode] = None,
|
||||||
|
use_checkpoint: bool = False,
|
||||||
|
use_rope: bool = False,
|
||||||
|
qk_rms_norm: bool = False,
|
||||||
|
qkv_bias: bool = True,
|
||||||
|
ln_affine: bool = False,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.use_checkpoint = use_checkpoint
|
||||||
|
self.norm1 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
|
||||||
|
self.norm2 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
|
||||||
|
self.attn = SparseMultiHeadAttention(
|
||||||
|
channels,
|
||||||
|
num_heads=num_heads,
|
||||||
|
attn_mode=attn_mode,
|
||||||
|
window_size=window_size,
|
||||||
|
shift_sequence=shift_sequence,
|
||||||
|
shift_window=shift_window,
|
||||||
|
serialize_mode=serialize_mode,
|
||||||
|
qkv_bias=qkv_bias,
|
||||||
|
use_rope=use_rope,
|
||||||
|
qk_rms_norm=qk_rms_norm,
|
||||||
|
)
|
||||||
|
self.mlp = SparseFeedForwardNet(
|
||||||
|
channels,
|
||||||
|
mlp_ratio=mlp_ratio,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _forward(self, x: SparseTensor) -> SparseTensor:
|
||||||
|
h = x.replace(self.norm1(x.feats))
|
||||||
|
h = self.attn(h)
|
||||||
|
x = x + h
|
||||||
|
h = x.replace(self.norm2(x.feats))
|
||||||
|
h = self.mlp(h)
|
||||||
|
x = x + h
|
||||||
|
return x
|
||||||
|
|
||||||
|
def forward(self, x: SparseTensor) -> SparseTensor:
|
||||||
|
if self.use_checkpoint:
|
||||||
|
return torch.utils.checkpoint.checkpoint(self._forward, x, use_reentrant=False)
|
||||||
|
else:
|
||||||
|
return self._forward(x)
|
||||||
|
|
||||||
|
|
||||||
|
class SparseTransformerCrossBlock(nn.Module):
|
||||||
|
"""
|
||||||
|
Sparse Transformer cross-attention block (MSA + MCA + FFN).
|
||||||
|
"""
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
channels: int,
|
||||||
|
ctx_channels: int,
|
||||||
|
num_heads: int,
|
||||||
|
mlp_ratio: float = 4.0,
|
||||||
|
attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "full",
|
||||||
|
window_size: Optional[int] = None,
|
||||||
|
shift_sequence: Optional[int] = None,
|
||||||
|
shift_window: Optional[Tuple[int, int, int]] = None,
|
||||||
|
serialize_mode: Optional[SerializeMode] = None,
|
||||||
|
use_checkpoint: bool = False,
|
||||||
|
use_rope: bool = False,
|
||||||
|
qk_rms_norm: bool = False,
|
||||||
|
qk_rms_norm_cross: bool = False,
|
||||||
|
qkv_bias: bool = True,
|
||||||
|
ln_affine: bool = False,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.use_checkpoint = use_checkpoint
|
||||||
|
self.norm1 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
|
||||||
|
self.norm2 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
|
||||||
|
self.norm3 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
|
||||||
|
self.self_attn = SparseMultiHeadAttention(
|
||||||
|
channels,
|
||||||
|
num_heads=num_heads,
|
||||||
|
type="self",
|
||||||
|
attn_mode=attn_mode,
|
||||||
|
window_size=window_size,
|
||||||
|
shift_sequence=shift_sequence,
|
||||||
|
shift_window=shift_window,
|
||||||
|
serialize_mode=serialize_mode,
|
||||||
|
qkv_bias=qkv_bias,
|
||||||
|
use_rope=use_rope,
|
||||||
|
qk_rms_norm=qk_rms_norm,
|
||||||
|
)
|
||||||
|
self.cross_attn = SparseMultiHeadAttention(
|
||||||
|
channels,
|
||||||
|
ctx_channels=ctx_channels,
|
||||||
|
num_heads=num_heads,
|
||||||
|
type="cross",
|
||||||
|
attn_mode="full",
|
||||||
|
qkv_bias=qkv_bias,
|
||||||
|
qk_rms_norm=qk_rms_norm_cross,
|
||||||
|
)
|
||||||
|
self.mlp = SparseFeedForwardNet(
|
||||||
|
channels,
|
||||||
|
mlp_ratio=mlp_ratio,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _forward(self, x: SparseTensor, mod: torch.Tensor, context: torch.Tensor):
|
||||||
|
h = x.replace(self.norm1(x.feats))
|
||||||
|
h = self.self_attn(h)
|
||||||
|
x = x + h
|
||||||
|
h = x.replace(self.norm2(x.feats))
|
||||||
|
h = self.cross_attn(h, context)
|
||||||
|
x = x + h
|
||||||
|
h = x.replace(self.norm3(x.feats))
|
||||||
|
h = self.mlp(h)
|
||||||
|
x = x + h
|
||||||
|
return x
|
||||||
|
|
||||||
|
def forward(self, x: SparseTensor, context: torch.Tensor):
|
||||||
|
if self.use_checkpoint:
|
||||||
|
return torch.utils.checkpoint.checkpoint(self._forward, x, context, use_reentrant=False)
|
||||||
|
else:
|
||||||
|
return self._forward(x, context)
|
||||||
2
trellis/modules/transformer/__init__.py
Normal file
2
trellis/modules/transformer/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
from .blocks import *
|
||||||
|
from .modulated import *
|
||||||
182
trellis/modules/transformer/blocks.py
Normal file
182
trellis/modules/transformer/blocks.py
Normal file
@@ -0,0 +1,182 @@
|
|||||||
|
from typing import *
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from ..attention import MultiHeadAttention
|
||||||
|
from ..norm import LayerNorm32
|
||||||
|
|
||||||
|
|
||||||
|
class AbsolutePositionEmbedder(nn.Module):
|
||||||
|
"""
|
||||||
|
Embeds spatial positions into vector representations.
|
||||||
|
"""
|
||||||
|
def __init__(self, channels: int, in_channels: int = 3):
|
||||||
|
super().__init__()
|
||||||
|
self.channels = channels
|
||||||
|
self.in_channels = in_channels
|
||||||
|
self.freq_dim = channels // in_channels // 2
|
||||||
|
self.freqs = torch.arange(self.freq_dim, dtype=torch.float32) / self.freq_dim
|
||||||
|
self.freqs = 1.0 / (10000 ** self.freqs)
|
||||||
|
|
||||||
|
def _sin_cos_embedding(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Create sinusoidal position embeddings.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: a 1-D Tensor of N indices
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
an (N, D) Tensor of positional embeddings.
|
||||||
|
"""
|
||||||
|
self.freqs = self.freqs.to(x.device)
|
||||||
|
out = torch.outer(x, self.freqs)
|
||||||
|
out = torch.cat([torch.sin(out), torch.cos(out)], dim=-1)
|
||||||
|
return out
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
x (torch.Tensor): (N, D) tensor of spatial positions
|
||||||
|
"""
|
||||||
|
N, D = x.shape
|
||||||
|
assert D == self.in_channels, "Input dimension must match number of input channels"
|
||||||
|
embed = self._sin_cos_embedding(x.reshape(-1))
|
||||||
|
embed = embed.reshape(N, -1)
|
||||||
|
if embed.shape[1] < self.channels:
|
||||||
|
embed = torch.cat([embed, torch.zeros(N, self.channels - embed.shape[1], device=embed.device)], dim=-1)
|
||||||
|
return embed
|
||||||
|
|
||||||
|
|
||||||
|
class FeedForwardNet(nn.Module):
|
||||||
|
def __init__(self, channels: int, mlp_ratio: float = 4.0):
|
||||||
|
super().__init__()
|
||||||
|
self.mlp = nn.Sequential(
|
||||||
|
nn.Linear(channels, int(channels * mlp_ratio)),
|
||||||
|
nn.GELU(approximate="tanh"),
|
||||||
|
nn.Linear(int(channels * mlp_ratio), channels),
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
return self.mlp(x)
|
||||||
|
|
||||||
|
|
||||||
|
class TransformerBlock(nn.Module):
|
||||||
|
"""
|
||||||
|
Transformer block (MSA + FFN).
|
||||||
|
"""
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
channels: int,
|
||||||
|
num_heads: int,
|
||||||
|
mlp_ratio: float = 4.0,
|
||||||
|
attn_mode: Literal["full", "windowed"] = "full",
|
||||||
|
window_size: Optional[int] = None,
|
||||||
|
shift_window: Optional[int] = None,
|
||||||
|
use_checkpoint: bool = False,
|
||||||
|
use_rope: bool = False,
|
||||||
|
qk_rms_norm: bool = False,
|
||||||
|
qkv_bias: bool = True,
|
||||||
|
ln_affine: bool = False,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.use_checkpoint = use_checkpoint
|
||||||
|
self.norm1 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
|
||||||
|
self.norm2 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
|
||||||
|
self.attn = MultiHeadAttention(
|
||||||
|
channels,
|
||||||
|
num_heads=num_heads,
|
||||||
|
attn_mode=attn_mode,
|
||||||
|
window_size=window_size,
|
||||||
|
shift_window=shift_window,
|
||||||
|
qkv_bias=qkv_bias,
|
||||||
|
use_rope=use_rope,
|
||||||
|
qk_rms_norm=qk_rms_norm,
|
||||||
|
)
|
||||||
|
self.mlp = FeedForwardNet(
|
||||||
|
channels,
|
||||||
|
mlp_ratio=mlp_ratio,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
h = self.norm1(x)
|
||||||
|
h = self.attn(h)
|
||||||
|
x = x + h
|
||||||
|
h = self.norm2(x)
|
||||||
|
h = self.mlp(h)
|
||||||
|
x = x + h
|
||||||
|
return x
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
if self.use_checkpoint:
|
||||||
|
return torch.utils.checkpoint.checkpoint(self._forward, x, use_reentrant=False)
|
||||||
|
else:
|
||||||
|
return self._forward(x)
|
||||||
|
|
||||||
|
|
||||||
|
class TransformerCrossBlock(nn.Module):
|
||||||
|
"""
|
||||||
|
Transformer cross-attention block (MSA + MCA + FFN).
|
||||||
|
"""
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
channels: int,
|
||||||
|
ctx_channels: int,
|
||||||
|
num_heads: int,
|
||||||
|
mlp_ratio: float = 4.0,
|
||||||
|
attn_mode: Literal["full", "windowed"] = "full",
|
||||||
|
window_size: Optional[int] = None,
|
||||||
|
shift_window: Optional[Tuple[int, int, int]] = None,
|
||||||
|
use_checkpoint: bool = False,
|
||||||
|
use_rope: bool = False,
|
||||||
|
qk_rms_norm: bool = False,
|
||||||
|
qk_rms_norm_cross: bool = False,
|
||||||
|
qkv_bias: bool = True,
|
||||||
|
ln_affine: bool = False,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.use_checkpoint = use_checkpoint
|
||||||
|
self.norm1 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
|
||||||
|
self.norm2 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
|
||||||
|
self.norm3 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
|
||||||
|
self.self_attn = MultiHeadAttention(
|
||||||
|
channels,
|
||||||
|
num_heads=num_heads,
|
||||||
|
type="self",
|
||||||
|
attn_mode=attn_mode,
|
||||||
|
window_size=window_size,
|
||||||
|
shift_window=shift_window,
|
||||||
|
qkv_bias=qkv_bias,
|
||||||
|
use_rope=use_rope,
|
||||||
|
qk_rms_norm=qk_rms_norm,
|
||||||
|
)
|
||||||
|
self.cross_attn = MultiHeadAttention(
|
||||||
|
channels,
|
||||||
|
ctx_channels=ctx_channels,
|
||||||
|
num_heads=num_heads,
|
||||||
|
type="cross",
|
||||||
|
attn_mode="full",
|
||||||
|
qkv_bias=qkv_bias,
|
||||||
|
qk_rms_norm=qk_rms_norm_cross,
|
||||||
|
)
|
||||||
|
self.mlp = FeedForwardNet(
|
||||||
|
channels,
|
||||||
|
mlp_ratio=mlp_ratio,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _forward(self, x: torch.Tensor, context: torch.Tensor):
|
||||||
|
h = self.norm1(x)
|
||||||
|
h = self.self_attn(h)
|
||||||
|
x = x + h
|
||||||
|
h = self.norm2(x)
|
||||||
|
h = self.cross_attn(h, context)
|
||||||
|
x = x + h
|
||||||
|
h = self.norm3(x)
|
||||||
|
h = self.mlp(h)
|
||||||
|
x = x + h
|
||||||
|
return x
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor, context: torch.Tensor):
|
||||||
|
if self.use_checkpoint:
|
||||||
|
return torch.utils.checkpoint.checkpoint(self._forward, x, context, use_reentrant=False)
|
||||||
|
else:
|
||||||
|
return self._forward(x, context)
|
||||||
|
|
||||||
25
trellis/pipelines/__init__.py
Normal file
25
trellis/pipelines/__init__.py
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
from . import samplers
|
||||||
|
from .trellis_image_to_3d import TrellisImageTo3DPipeline
|
||||||
|
from .trellis_text_to_3d import TrellisTextTo3DPipeline
|
||||||
|
|
||||||
|
|
||||||
|
def from_pretrained(path: str):
|
||||||
|
"""
|
||||||
|
Load a pipeline from a model folder or a Hugging Face model hub.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
path: The path to the model. Can be either local path or a Hugging Face model name.
|
||||||
|
"""
|
||||||
|
import os
|
||||||
|
import json
|
||||||
|
is_local = os.path.exists(f"{path}/pipeline.json")
|
||||||
|
|
||||||
|
if is_local:
|
||||||
|
config_file = f"{path}/pipeline.json"
|
||||||
|
else:
|
||||||
|
from huggingface_hub import hf_hub_download
|
||||||
|
config_file = hf_hub_download(path, "pipeline.json")
|
||||||
|
|
||||||
|
with open(config_file, 'r') as f:
|
||||||
|
config = json.load(f)
|
||||||
|
return globals()[config['name']].from_pretrained(path)
|
||||||
68
trellis/pipelines/base.py
Normal file
68
trellis/pipelines/base.py
Normal file
@@ -0,0 +1,68 @@
|
|||||||
|
from typing import *
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from .. import models
|
||||||
|
|
||||||
|
|
||||||
|
class Pipeline:
|
||||||
|
"""
|
||||||
|
A base class for pipelines.
|
||||||
|
"""
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
models: dict[str, nn.Module] = None,
|
||||||
|
):
|
||||||
|
if models is None:
|
||||||
|
return
|
||||||
|
self.models = models
|
||||||
|
for model in self.models.values():
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_pretrained(path: str) -> "Pipeline":
|
||||||
|
"""
|
||||||
|
Load a pretrained model.
|
||||||
|
"""
|
||||||
|
import os
|
||||||
|
import json
|
||||||
|
is_local = os.path.exists(f"{path}/pipeline.json")
|
||||||
|
|
||||||
|
if is_local:
|
||||||
|
config_file = f"{path}/pipeline.json"
|
||||||
|
else:
|
||||||
|
from huggingface_hub import hf_hub_download
|
||||||
|
config_file = hf_hub_download(path, "pipeline.json")
|
||||||
|
|
||||||
|
with open(config_file, 'r') as f:
|
||||||
|
args = json.load(f)['args']
|
||||||
|
|
||||||
|
_models = {}
|
||||||
|
for k, v in args['models'].items():
|
||||||
|
try:
|
||||||
|
_models[k] = models.from_pretrained(f"{path}/{v}")
|
||||||
|
except:
|
||||||
|
_models[k] = models.from_pretrained(v)
|
||||||
|
|
||||||
|
new_pipeline = Pipeline(_models)
|
||||||
|
new_pipeline._pretrained_args = args
|
||||||
|
return new_pipeline
|
||||||
|
|
||||||
|
@property
|
||||||
|
def device(self) -> torch.device:
|
||||||
|
for model in self.models.values():
|
||||||
|
if hasattr(model, 'device'):
|
||||||
|
return model.device
|
||||||
|
for model in self.models.values():
|
||||||
|
if hasattr(model, 'parameters'):
|
||||||
|
return next(model.parameters()).device
|
||||||
|
raise RuntimeError("No device found.")
|
||||||
|
|
||||||
|
def to(self, device: torch.device) -> None:
|
||||||
|
for model in self.models.values():
|
||||||
|
model.to(device)
|
||||||
|
|
||||||
|
def cuda(self) -> None:
|
||||||
|
self.to(torch.device("cuda"))
|
||||||
|
|
||||||
|
def cpu(self) -> None:
|
||||||
|
self.to(torch.device("cpu"))
|
||||||
2
trellis/pipelines/samplers/__init__.py
Executable file
2
trellis/pipelines/samplers/__init__.py
Executable file
@@ -0,0 +1,2 @@
|
|||||||
|
from .base import Sampler
|
||||||
|
from .flow_euler import FlowEulerSampler, FlowEulerCfgSampler, FlowEulerGuidanceIntervalSampler
|
||||||
20
trellis/pipelines/samplers/base.py
Normal file
20
trellis/pipelines/samplers/base.py
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
from typing import *
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
|
|
||||||
|
class Sampler(ABC):
|
||||||
|
"""
|
||||||
|
A base class for samplers.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def sample(
|
||||||
|
self,
|
||||||
|
model,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Sample from a model.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
12
trellis/pipelines/samplers/classifier_free_guidance_mixin.py
Normal file
12
trellis/pipelines/samplers/classifier_free_guidance_mixin.py
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
from typing import *
|
||||||
|
|
||||||
|
|
||||||
|
class ClassifierFreeGuidanceSamplerMixin:
|
||||||
|
"""
|
||||||
|
A mixin class for samplers that apply classifier-free guidance.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _inference_model(self, model, x_t, t, cond, neg_cond, cfg_strength, **kwargs):
|
||||||
|
pred = super()._inference_model(model, x_t, t, cond, **kwargs)
|
||||||
|
neg_pred = super()._inference_model(model, x_t, t, neg_cond, **kwargs)
|
||||||
|
return (1 + cfg_strength) * pred - cfg_strength * neg_pred
|
||||||
31
trellis/renderers/__init__.py
Executable file
31
trellis/renderers/__init__.py
Executable file
@@ -0,0 +1,31 @@
|
|||||||
|
import importlib
|
||||||
|
|
||||||
|
__attributes = {
|
||||||
|
'OctreeRenderer': 'octree_renderer',
|
||||||
|
'GaussianRenderer': 'gaussian_render',
|
||||||
|
'MeshRenderer': 'mesh_renderer',
|
||||||
|
}
|
||||||
|
|
||||||
|
__submodules = []
|
||||||
|
|
||||||
|
__all__ = list(__attributes.keys()) + __submodules
|
||||||
|
|
||||||
|
def __getattr__(name):
|
||||||
|
if name not in globals():
|
||||||
|
if name in __attributes:
|
||||||
|
module_name = __attributes[name]
|
||||||
|
module = importlib.import_module(f".{module_name}", __name__)
|
||||||
|
globals()[name] = getattr(module, name)
|
||||||
|
elif name in __submodules:
|
||||||
|
module = importlib.import_module(f".{name}", __name__)
|
||||||
|
globals()[name] = module
|
||||||
|
else:
|
||||||
|
raise AttributeError(f"module {__name__} has no attribute {name}")
|
||||||
|
return globals()[name]
|
||||||
|
|
||||||
|
|
||||||
|
# For Pylance
|
||||||
|
if __name__ == '__main__':
|
||||||
|
from .octree_renderer import OctreeRenderer
|
||||||
|
from .gaussian_render import GaussianRenderer
|
||||||
|
from .mesh_renderer import MeshRenderer
|
||||||
4
trellis/representations/__init__.py
Executable file
4
trellis/representations/__init__.py
Executable file
@@ -0,0 +1,4 @@
|
|||||||
|
from .radiance_field import Strivec
|
||||||
|
from .octree import DfsOctree as Octree
|
||||||
|
from .gaussian import Gaussian
|
||||||
|
from .mesh import MeshExtractResult
|
||||||
1
trellis/representations/gaussian/__init__.py
Executable file
1
trellis/representations/gaussian/__init__.py
Executable file
@@ -0,0 +1 @@
|
|||||||
|
from .gaussian_model import Gaussian
|
||||||
1
trellis/representations/mesh/__init__.py
Normal file
1
trellis/representations/mesh/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
from .cube2mesh import SparseFeatures2Mesh, MeshExtractResult
|
||||||
1
trellis/representations/octree/__init__.py
Executable file
1
trellis/representations/octree/__init__.py
Executable file
@@ -0,0 +1 @@
|
|||||||
|
from .octree_dfs import DfsOctree
|
||||||
1
trellis/representations/radiance_field/__init__.py
Executable file
1
trellis/representations/radiance_field/__init__.py
Executable file
@@ -0,0 +1 @@
|
|||||||
|
from .strivec import Strivec
|
||||||
63
trellis/trainers/__init__.py
Normal file
63
trellis/trainers/__init__.py
Normal file
@@ -0,0 +1,63 @@
|
|||||||
|
import importlib
|
||||||
|
|
||||||
|
__attributes = {
|
||||||
|
'BasicTrainer': 'basic',
|
||||||
|
|
||||||
|
'SparseStructureVaeTrainer': 'vae.sparse_structure_vae',
|
||||||
|
|
||||||
|
'SLatVaeGaussianTrainer': 'vae.structured_latent_vae_gaussian',
|
||||||
|
'SLatVaeRadianceFieldDecoderTrainer': 'vae.structured_latent_vae_rf_dec',
|
||||||
|
'SLatVaeMeshDecoderTrainer': 'vae.structured_latent_vae_mesh_dec',
|
||||||
|
|
||||||
|
'FlowMatchingTrainer': 'flow_matching.flow_matching',
|
||||||
|
'FlowMatchingCFGTrainer': 'flow_matching.flow_matching',
|
||||||
|
'TextConditionedFlowMatchingCFGTrainer': 'flow_matching.flow_matching',
|
||||||
|
'ImageConditionedFlowMatchingCFGTrainer': 'flow_matching.flow_matching',
|
||||||
|
|
||||||
|
'SparseFlowMatchingTrainer': 'flow_matching.sparse_flow_matching',
|
||||||
|
'SparseFlowMatchingCFGTrainer': 'flow_matching.sparse_flow_matching',
|
||||||
|
'TextConditionedSparseFlowMatchingCFGTrainer': 'flow_matching.sparse_flow_matching',
|
||||||
|
'ImageConditionedSparseFlowMatchingCFGTrainer': 'flow_matching.sparse_flow_matching',
|
||||||
|
}
|
||||||
|
|
||||||
|
__submodules = []
|
||||||
|
|
||||||
|
__all__ = list(__attributes.keys()) + __submodules
|
||||||
|
|
||||||
|
def __getattr__(name):
|
||||||
|
if name not in globals():
|
||||||
|
if name in __attributes:
|
||||||
|
module_name = __attributes[name]
|
||||||
|
module = importlib.import_module(f".{module_name}", __name__)
|
||||||
|
globals()[name] = getattr(module, name)
|
||||||
|
elif name in __submodules:
|
||||||
|
module = importlib.import_module(f".{name}", __name__)
|
||||||
|
globals()[name] = module
|
||||||
|
else:
|
||||||
|
raise AttributeError(f"module {__name__} has no attribute {name}")
|
||||||
|
return globals()[name]
|
||||||
|
|
||||||
|
|
||||||
|
# For Pylance
|
||||||
|
if __name__ == '__main__':
|
||||||
|
from .basic import BasicTrainer
|
||||||
|
|
||||||
|
from .vae.sparse_structure_vae import SparseStructureVaeTrainer
|
||||||
|
|
||||||
|
from .vae.structured_latent_vae_gaussian import SLatVaeGaussianTrainer
|
||||||
|
from .vae.structured_latent_vae_rf_dec import SLatVaeRadianceFieldDecoderTrainer
|
||||||
|
from .vae.structured_latent_vae_mesh_dec import SLatVaeMeshDecoderTrainer
|
||||||
|
|
||||||
|
from .flow_matching.flow_matching import (
|
||||||
|
FlowMatchingTrainer,
|
||||||
|
FlowMatchingCFGTrainer,
|
||||||
|
TextConditionedFlowMatchingCFGTrainer,
|
||||||
|
ImageConditionedFlowMatchingCFGTrainer,
|
||||||
|
)
|
||||||
|
|
||||||
|
from .flow_matching.sparse_flow_matching import (
|
||||||
|
SparseFlowMatchingTrainer,
|
||||||
|
SparseFlowMatchingCFGTrainer,
|
||||||
|
TextConditionedSparseFlowMatchingCFGTrainer,
|
||||||
|
ImageConditionedSparseFlowMatchingCFGTrainer,
|
||||||
|
)
|
||||||
451
trellis/trainers/base.py
Normal file
451
trellis/trainers/base.py
Normal file
@@ -0,0 +1,451 @@
|
|||||||
|
from abc import abstractmethod
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
import json
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from torchvision import utils
|
||||||
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
|
||||||
|
from .utils import *
|
||||||
|
from ..utils.general_utils import *
|
||||||
|
from ..utils.data_utils import recursive_to_device, cycle, ResumableSampler
|
||||||
|
|
||||||
|
|
||||||
|
class Trainer:
|
||||||
|
"""
|
||||||
|
Base class for training.
|
||||||
|
"""
|
||||||
|
def __init__(self,
|
||||||
|
models,
|
||||||
|
dataset,
|
||||||
|
*,
|
||||||
|
output_dir,
|
||||||
|
load_dir,
|
||||||
|
step,
|
||||||
|
max_steps,
|
||||||
|
batch_size=None,
|
||||||
|
batch_size_per_gpu=None,
|
||||||
|
batch_split=None,
|
||||||
|
optimizer={},
|
||||||
|
lr_scheduler=None,
|
||||||
|
elastic=None,
|
||||||
|
grad_clip=None,
|
||||||
|
ema_rate=0.9999,
|
||||||
|
fp16_mode='inflat_all',
|
||||||
|
fp16_scale_growth=1e-3,
|
||||||
|
finetune_ckpt=None,
|
||||||
|
log_param_stats=False,
|
||||||
|
prefetch_data=True,
|
||||||
|
i_print=1000,
|
||||||
|
i_log=500,
|
||||||
|
i_sample=10000,
|
||||||
|
i_save=10000,
|
||||||
|
i_ddpcheck=10000,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
assert batch_size is not None or batch_size_per_gpu is not None, 'Either batch_size or batch_size_per_gpu must be specified.'
|
||||||
|
|
||||||
|
self.models = models
|
||||||
|
self.dataset = dataset
|
||||||
|
self.batch_split = batch_split if batch_split is not None else 1
|
||||||
|
self.max_steps = max_steps
|
||||||
|
self.optimizer_config = optimizer
|
||||||
|
self.lr_scheduler_config = lr_scheduler
|
||||||
|
self.elastic_controller_config = elastic
|
||||||
|
self.grad_clip = grad_clip
|
||||||
|
self.ema_rate = [ema_rate] if isinstance(ema_rate, float) else ema_rate
|
||||||
|
self.fp16_mode = fp16_mode
|
||||||
|
self.fp16_scale_growth = fp16_scale_growth
|
||||||
|
self.log_param_stats = log_param_stats
|
||||||
|
self.prefetch_data = prefetch_data
|
||||||
|
if self.prefetch_data:
|
||||||
|
self._data_prefetched = None
|
||||||
|
|
||||||
|
self.output_dir = output_dir
|
||||||
|
self.i_print = i_print
|
||||||
|
self.i_log = i_log
|
||||||
|
self.i_sample = i_sample
|
||||||
|
self.i_save = i_save
|
||||||
|
self.i_ddpcheck = i_ddpcheck
|
||||||
|
|
||||||
|
if dist.is_initialized():
|
||||||
|
# Multi-GPU params
|
||||||
|
self.world_size = dist.get_world_size()
|
||||||
|
self.rank = dist.get_rank()
|
||||||
|
self.local_rank = dist.get_rank() % torch.cuda.device_count()
|
||||||
|
self.is_master = self.rank == 0
|
||||||
|
else:
|
||||||
|
# Single-GPU params
|
||||||
|
self.world_size = 1
|
||||||
|
self.rank = 0
|
||||||
|
self.local_rank = 0
|
||||||
|
self.is_master = True
|
||||||
|
|
||||||
|
self.batch_size = batch_size if batch_size_per_gpu is None else batch_size_per_gpu * self.world_size
|
||||||
|
self.batch_size_per_gpu = batch_size_per_gpu if batch_size_per_gpu is not None else batch_size // self.world_size
|
||||||
|
assert self.batch_size % self.world_size == 0, 'Batch size must be divisible by the number of GPUs.'
|
||||||
|
assert self.batch_size_per_gpu % self.batch_split == 0, 'Batch size per GPU must be divisible by batch split.'
|
||||||
|
|
||||||
|
self.init_models_and_more(**kwargs)
|
||||||
|
self.prepare_dataloader(**kwargs)
|
||||||
|
|
||||||
|
# Load checkpoint
|
||||||
|
self.step = 0
|
||||||
|
if load_dir is not None and step is not None:
|
||||||
|
self.load(load_dir, step)
|
||||||
|
elif finetune_ckpt is not None:
|
||||||
|
self.finetune_from(finetune_ckpt)
|
||||||
|
|
||||||
|
if self.is_master:
|
||||||
|
os.makedirs(os.path.join(self.output_dir, 'ckpts'), exist_ok=True)
|
||||||
|
os.makedirs(os.path.join(self.output_dir, 'samples'), exist_ok=True)
|
||||||
|
self.writer = SummaryWriter(os.path.join(self.output_dir, 'tb_logs'))
|
||||||
|
|
||||||
|
if self.world_size > 1:
|
||||||
|
self.check_ddp()
|
||||||
|
|
||||||
|
if self.is_master:
|
||||||
|
print('\n\nTrainer initialized.')
|
||||||
|
print(self)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def device(self):
|
||||||
|
for _, model in self.models.items():
|
||||||
|
if hasattr(model, 'device'):
|
||||||
|
return model.device
|
||||||
|
return next(list(self.models.values())[0].parameters()).device
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def init_models_and_more(self, **kwargs):
|
||||||
|
"""
|
||||||
|
Initialize models and more.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def prepare_dataloader(self, **kwargs):
|
||||||
|
"""
|
||||||
|
Prepare dataloader.
|
||||||
|
"""
|
||||||
|
self.data_sampler = ResumableSampler(
|
||||||
|
self.dataset,
|
||||||
|
shuffle=True,
|
||||||
|
)
|
||||||
|
self.dataloader = DataLoader(
|
||||||
|
self.dataset,
|
||||||
|
batch_size=self.batch_size_per_gpu,
|
||||||
|
num_workers=int(np.ceil(os.cpu_count() / torch.cuda.device_count())),
|
||||||
|
pin_memory=True,
|
||||||
|
drop_last=True,
|
||||||
|
persistent_workers=True,
|
||||||
|
collate_fn=self.dataset.collate_fn if hasattr(self.dataset, 'collate_fn') else None,
|
||||||
|
sampler=self.data_sampler,
|
||||||
|
)
|
||||||
|
self.data_iterator = cycle(self.dataloader)
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def load(self, load_dir, step=0):
|
||||||
|
"""
|
||||||
|
Load a checkpoint.
|
||||||
|
Should be called by all processes.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def save(self):
|
||||||
|
"""
|
||||||
|
Save a checkpoint.
|
||||||
|
Should be called only by the rank 0 process.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def finetune_from(self, finetune_ckpt):
|
||||||
|
"""
|
||||||
|
Finetune from a checkpoint.
|
||||||
|
Should be called by all processes.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def run_snapshot(self, num_samples, batch_size=4, verbose=False, **kwargs):
|
||||||
|
"""
|
||||||
|
Run a snapshot of the model.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def visualize_sample(self, sample):
|
||||||
|
"""
|
||||||
|
Convert a sample to an image.
|
||||||
|
"""
|
||||||
|
if hasattr(self.dataset, 'visualize_sample'):
|
||||||
|
return self.dataset.visualize_sample(sample)
|
||||||
|
else:
|
||||||
|
return sample
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def snapshot_dataset(self, num_samples=100):
|
||||||
|
"""
|
||||||
|
Sample images from the dataset.
|
||||||
|
"""
|
||||||
|
dataloader = torch.utils.data.DataLoader(
|
||||||
|
self.dataset,
|
||||||
|
batch_size=num_samples,
|
||||||
|
num_workers=0,
|
||||||
|
shuffle=True,
|
||||||
|
collate_fn=self.dataset.collate_fn if hasattr(self.dataset, 'collate_fn') else None,
|
||||||
|
)
|
||||||
|
data = next(iter(dataloader))
|
||||||
|
data = recursive_to_device(data, self.device)
|
||||||
|
vis = self.visualize_sample(data)
|
||||||
|
if isinstance(vis, dict):
|
||||||
|
save_cfg = [(f'dataset_{k}', v) for k, v in vis.items()]
|
||||||
|
else:
|
||||||
|
save_cfg = [('dataset', vis)]
|
||||||
|
for name, image in save_cfg:
|
||||||
|
utils.save_image(
|
||||||
|
image,
|
||||||
|
os.path.join(self.output_dir, 'samples', f'{name}.jpg'),
|
||||||
|
nrow=int(np.sqrt(num_samples)),
|
||||||
|
normalize=True,
|
||||||
|
value_range=self.dataset.value_range,
|
||||||
|
)
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def snapshot(self, suffix=None, num_samples=64, batch_size=4, verbose=False):
|
||||||
|
"""
|
||||||
|
Sample images from the model.
|
||||||
|
NOTE: This function should be called by all processes.
|
||||||
|
"""
|
||||||
|
if self.is_master:
|
||||||
|
print(f'\nSampling {num_samples} images...', end='')
|
||||||
|
|
||||||
|
if suffix is None:
|
||||||
|
suffix = f'step{self.step:07d}'
|
||||||
|
|
||||||
|
# Assign tasks
|
||||||
|
num_samples_per_process = int(np.ceil(num_samples / self.world_size))
|
||||||
|
samples = self.run_snapshot(num_samples_per_process, batch_size=batch_size, verbose=verbose)
|
||||||
|
|
||||||
|
# Preprocess images
|
||||||
|
for key in list(samples.keys()):
|
||||||
|
if samples[key]['type'] == 'sample':
|
||||||
|
vis = self.visualize_sample(samples[key]['value'])
|
||||||
|
if isinstance(vis, dict):
|
||||||
|
for k, v in vis.items():
|
||||||
|
samples[f'{key}_{k}'] = {'value': v, 'type': 'image'}
|
||||||
|
del samples[key]
|
||||||
|
else:
|
||||||
|
samples[key] = {'value': vis, 'type': 'image'}
|
||||||
|
|
||||||
|
# Gather results
|
||||||
|
if self.world_size > 1:
|
||||||
|
for key in samples.keys():
|
||||||
|
samples[key]['value'] = samples[key]['value'].contiguous()
|
||||||
|
if self.is_master:
|
||||||
|
all_images = [torch.empty_like(samples[key]['value']) for _ in range(self.world_size)]
|
||||||
|
else:
|
||||||
|
all_images = []
|
||||||
|
dist.gather(samples[key]['value'], all_images, dst=0)
|
||||||
|
if self.is_master:
|
||||||
|
samples[key]['value'] = torch.cat(all_images, dim=0)[:num_samples]
|
||||||
|
|
||||||
|
# Save images
|
||||||
|
if self.is_master:
|
||||||
|
os.makedirs(os.path.join(self.output_dir, 'samples', suffix), exist_ok=True)
|
||||||
|
for key in samples.keys():
|
||||||
|
if samples[key]['type'] == 'image':
|
||||||
|
utils.save_image(
|
||||||
|
samples[key]['value'],
|
||||||
|
os.path.join(self.output_dir, 'samples', suffix, f'{key}_{suffix}.jpg'),
|
||||||
|
nrow=int(np.sqrt(num_samples)),
|
||||||
|
normalize=True,
|
||||||
|
value_range=self.dataset.value_range,
|
||||||
|
)
|
||||||
|
elif samples[key]['type'] == 'number':
|
||||||
|
min = samples[key]['value'].min()
|
||||||
|
max = samples[key]['value'].max()
|
||||||
|
images = (samples[key]['value'] - min) / (max - min)
|
||||||
|
images = utils.make_grid(
|
||||||
|
images,
|
||||||
|
nrow=int(np.sqrt(num_samples)),
|
||||||
|
normalize=False,
|
||||||
|
)
|
||||||
|
save_image_with_notes(
|
||||||
|
images,
|
||||||
|
os.path.join(self.output_dir, 'samples', suffix, f'{key}_{suffix}.jpg'),
|
||||||
|
notes=f'{key} min: {min}, max: {max}',
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.is_master:
|
||||||
|
print(' Done.')
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def update_ema(self):
|
||||||
|
"""
|
||||||
|
Update exponential moving average.
|
||||||
|
Should only be called by the rank 0 process.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def check_ddp(self):
|
||||||
|
"""
|
||||||
|
Check if DDP is working properly.
|
||||||
|
Should be called by all process.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def training_losses(**mb_data):
|
||||||
|
"""
|
||||||
|
Compute training losses.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def load_data(self):
|
||||||
|
"""
|
||||||
|
Load data.
|
||||||
|
"""
|
||||||
|
if self.prefetch_data:
|
||||||
|
if self._data_prefetched is None:
|
||||||
|
self._data_prefetched = recursive_to_device(next(self.data_iterator), self.device, non_blocking=True)
|
||||||
|
data = self._data_prefetched
|
||||||
|
self._data_prefetched = recursive_to_device(next(self.data_iterator), self.device, non_blocking=True)
|
||||||
|
else:
|
||||||
|
data = recursive_to_device(next(self.data_iterator), self.device, non_blocking=True)
|
||||||
|
|
||||||
|
# if the data is a dict, we need to split it into multiple dicts with batch_size_per_gpu
|
||||||
|
if isinstance(data, dict):
|
||||||
|
if self.batch_split == 1:
|
||||||
|
data_list = [data]
|
||||||
|
else:
|
||||||
|
batch_size = list(data.values())[0].shape[0]
|
||||||
|
data_list = [
|
||||||
|
{k: v[i * batch_size // self.batch_split:(i + 1) * batch_size // self.batch_split] for k, v in data.items()}
|
||||||
|
for i in range(self.batch_split)
|
||||||
|
]
|
||||||
|
elif isinstance(data, list):
|
||||||
|
data_list = data
|
||||||
|
else:
|
||||||
|
raise ValueError('Data must be a dict or a list of dicts.')
|
||||||
|
|
||||||
|
return data_list
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def run_step(self, data_list):
|
||||||
|
"""
|
||||||
|
Run a training step.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def run(self):
|
||||||
|
"""
|
||||||
|
Run training.
|
||||||
|
"""
|
||||||
|
if self.is_master:
|
||||||
|
print('\nStarting training...')
|
||||||
|
self.snapshot_dataset()
|
||||||
|
if self.step == 0:
|
||||||
|
self.snapshot(suffix='init')
|
||||||
|
else: # resume
|
||||||
|
self.snapshot(suffix=f'resume_step{self.step:07d}')
|
||||||
|
|
||||||
|
log = []
|
||||||
|
time_last_print = 0.0
|
||||||
|
time_elapsed = 0.0
|
||||||
|
while self.step < self.max_steps:
|
||||||
|
time_start = time.time()
|
||||||
|
|
||||||
|
data_list = self.load_data()
|
||||||
|
step_log = self.run_step(data_list)
|
||||||
|
|
||||||
|
time_end = time.time()
|
||||||
|
time_elapsed += time_end - time_start
|
||||||
|
|
||||||
|
self.step += 1
|
||||||
|
|
||||||
|
# Print progress
|
||||||
|
if self.is_master and self.step % self.i_print == 0:
|
||||||
|
speed = self.i_print / (time_elapsed - time_last_print) * 3600
|
||||||
|
columns = [
|
||||||
|
f'Step: {self.step}/{self.max_steps} ({self.step / self.max_steps * 100:.2f}%)',
|
||||||
|
f'Elapsed: {time_elapsed / 3600:.2f} h',
|
||||||
|
f'Speed: {speed:.2f} steps/h',
|
||||||
|
f'ETA: {(self.max_steps - self.step) / speed:.2f} h',
|
||||||
|
]
|
||||||
|
print(' | '.join([c.ljust(25) for c in columns]), flush=True)
|
||||||
|
time_last_print = time_elapsed
|
||||||
|
|
||||||
|
# Check ddp
|
||||||
|
if self.world_size > 1 and self.i_ddpcheck is not None and self.step % self.i_ddpcheck == 0:
|
||||||
|
self.check_ddp()
|
||||||
|
|
||||||
|
# Sample images
|
||||||
|
if self.step % self.i_sample == 0:
|
||||||
|
self.snapshot()
|
||||||
|
|
||||||
|
if self.is_master:
|
||||||
|
log.append((self.step, {}))
|
||||||
|
|
||||||
|
# Log time
|
||||||
|
log[-1][1]['time'] = {
|
||||||
|
'step': time_end - time_start,
|
||||||
|
'elapsed': time_elapsed,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Log losses
|
||||||
|
if step_log is not None:
|
||||||
|
log[-1][1].update(step_log)
|
||||||
|
|
||||||
|
# Log scale
|
||||||
|
if self.fp16_mode == 'amp':
|
||||||
|
log[-1][1]['scale'] = self.scaler.get_scale()
|
||||||
|
elif self.fp16_mode == 'inflat_all':
|
||||||
|
log[-1][1]['log_scale'] = self.log_scale
|
||||||
|
|
||||||
|
# Save log
|
||||||
|
if self.step % self.i_log == 0:
|
||||||
|
## save to log file
|
||||||
|
log_str = '\n'.join([
|
||||||
|
f'{step}: {json.dumps(log)}' for step, log in log
|
||||||
|
])
|
||||||
|
with open(os.path.join(self.output_dir, 'log.txt'), 'a') as log_file:
|
||||||
|
log_file.write(log_str + '\n')
|
||||||
|
|
||||||
|
# show with mlflow
|
||||||
|
log_show = [l for _, l in log if not dict_any(l, lambda x: np.isnan(x))]
|
||||||
|
log_show = dict_reduce(log_show, lambda x: np.mean(x))
|
||||||
|
log_show = dict_flatten(log_show, sep='/')
|
||||||
|
for key, value in log_show.items():
|
||||||
|
self.writer.add_scalar(key, value, self.step)
|
||||||
|
log = []
|
||||||
|
|
||||||
|
# Save checkpoint
|
||||||
|
if self.step % self.i_save == 0:
|
||||||
|
self.save()
|
||||||
|
|
||||||
|
if self.is_master:
|
||||||
|
self.snapshot(suffix='final')
|
||||||
|
self.writer.close()
|
||||||
|
print('Training finished.')
|
||||||
|
|
||||||
|
def profile(self, wait=2, warmup=3, active=5):
|
||||||
|
"""
|
||||||
|
Profile the training loop.
|
||||||
|
"""
|
||||||
|
with torch.profiler.profile(
|
||||||
|
schedule=torch.profiler.schedule(wait=wait, warmup=warmup, active=active, repeat=1),
|
||||||
|
on_trace_ready=torch.profiler.tensorboard_trace_handler(os.path.join(self.output_dir, 'profile')),
|
||||||
|
profile_memory=True,
|
||||||
|
with_stack=True,
|
||||||
|
) as prof:
|
||||||
|
for _ in range(wait + warmup + active):
|
||||||
|
self.run_step()
|
||||||
|
prof.step()
|
||||||
|
|
||||||
438
trellis/trainers/basic.py
Normal file
438
trellis/trainers/basic.py
Normal file
@@ -0,0 +1,438 @@
|
|||||||
|
import os
|
||||||
|
import copy
|
||||||
|
from functools import partial
|
||||||
|
from contextlib import nullcontext
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
|
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from .utils import *
|
||||||
|
from .base import Trainer
|
||||||
|
from ..utils.general_utils import *
|
||||||
|
from ..utils.dist_utils import *
|
||||||
|
from ..utils import grad_clip_utils, elastic_utils
|
||||||
|
|
||||||
|
|
||||||
|
class BasicTrainer(Trainer):
|
||||||
|
"""
|
||||||
|
Trainer for basic training loop.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
models (dict[str, nn.Module]): Models to train.
|
||||||
|
dataset (torch.utils.data.Dataset): Dataset.
|
||||||
|
output_dir (str): Output directory.
|
||||||
|
load_dir (str): Load directory.
|
||||||
|
step (int): Step to load.
|
||||||
|
batch_size (int): Batch size.
|
||||||
|
batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored.
|
||||||
|
batch_split (int): Split batch with gradient accumulation.
|
||||||
|
max_steps (int): Max steps.
|
||||||
|
optimizer (dict): Optimizer config.
|
||||||
|
lr_scheduler (dict): Learning rate scheduler config.
|
||||||
|
elastic (dict): Elastic memory management config.
|
||||||
|
grad_clip (float or dict): Gradient clip config.
|
||||||
|
ema_rate (float or list): Exponential moving average rates.
|
||||||
|
fp16_mode (str): FP16 mode.
|
||||||
|
- None: No FP16.
|
||||||
|
- 'inflat_all': Hold a inflated fp32 master param for all params.
|
||||||
|
- 'amp': Automatic mixed precision.
|
||||||
|
fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation.
|
||||||
|
finetune_ckpt (dict): Finetune checkpoint.
|
||||||
|
log_param_stats (bool): Log parameter stats.
|
||||||
|
i_print (int): Print interval.
|
||||||
|
i_log (int): Log interval.
|
||||||
|
i_sample (int): Sample interval.
|
||||||
|
i_save (int): Save interval.
|
||||||
|
i_ddpcheck (int): DDP check interval.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
lines = []
|
||||||
|
lines.append(self.__class__.__name__)
|
||||||
|
lines.append(f' - Models:')
|
||||||
|
for name, model in self.models.items():
|
||||||
|
lines.append(f' - {name}: {model.__class__.__name__}')
|
||||||
|
lines.append(f' - Dataset: {indent(str(self.dataset), 2)}')
|
||||||
|
lines.append(f' - Dataloader:')
|
||||||
|
lines.append(f' - Sampler: {self.dataloader.sampler.__class__.__name__}')
|
||||||
|
lines.append(f' - Num workers: {self.dataloader.num_workers}')
|
||||||
|
lines.append(f' - Number of steps: {self.max_steps}')
|
||||||
|
lines.append(f' - Number of GPUs: {self.world_size}')
|
||||||
|
lines.append(f' - Batch size: {self.batch_size}')
|
||||||
|
lines.append(f' - Batch size per GPU: {self.batch_size_per_gpu}')
|
||||||
|
lines.append(f' - Batch split: {self.batch_split}')
|
||||||
|
lines.append(f' - Optimizer: {self.optimizer.__class__.__name__}')
|
||||||
|
lines.append(f' - Learning rate: {self.optimizer.param_groups[0]["lr"]}')
|
||||||
|
if self.lr_scheduler_config is not None:
|
||||||
|
lines.append(f' - LR scheduler: {self.lr_scheduler.__class__.__name__}')
|
||||||
|
if self.elastic_controller_config is not None:
|
||||||
|
lines.append(f' - Elastic memory: {indent(str(self.elastic_controller), 2)}')
|
||||||
|
if self.grad_clip is not None:
|
||||||
|
lines.append(f' - Gradient clip: {indent(str(self.grad_clip), 2)}')
|
||||||
|
lines.append(f' - EMA rate: {self.ema_rate}')
|
||||||
|
lines.append(f' - FP16 mode: {self.fp16_mode}')
|
||||||
|
return '\n'.join(lines)
|
||||||
|
|
||||||
|
def init_models_and_more(self, **kwargs):
|
||||||
|
"""
|
||||||
|
Initialize models and more.
|
||||||
|
"""
|
||||||
|
if self.world_size > 1:
|
||||||
|
# Prepare distributed data parallel
|
||||||
|
self.training_models = {
|
||||||
|
name: DDP(
|
||||||
|
model,
|
||||||
|
device_ids=[self.local_rank],
|
||||||
|
output_device=self.local_rank,
|
||||||
|
bucket_cap_mb=128,
|
||||||
|
find_unused_parameters=False
|
||||||
|
)
|
||||||
|
for name, model in self.models.items()
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
self.training_models = self.models
|
||||||
|
|
||||||
|
# Build master params
|
||||||
|
self.model_params = sum(
|
||||||
|
[[p for p in model.parameters() if p.requires_grad] for model in self.models.values()]
|
||||||
|
, [])
|
||||||
|
if self.fp16_mode == 'amp':
|
||||||
|
self.master_params = self.model_params
|
||||||
|
self.scaler = torch.GradScaler() if self.fp16_mode == 'amp' else None
|
||||||
|
elif self.fp16_mode == 'inflat_all':
|
||||||
|
self.master_params = make_master_params(self.model_params)
|
||||||
|
self.fp16_scale_growth = self.fp16_scale_growth
|
||||||
|
self.log_scale = 20.0
|
||||||
|
elif self.fp16_mode is None:
|
||||||
|
self.master_params = self.model_params
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(f'FP16 mode {self.fp16_mode} is not implemented.')
|
||||||
|
|
||||||
|
# Build EMA params
|
||||||
|
if self.is_master:
|
||||||
|
self.ema_params = [copy.deepcopy(self.master_params) for _ in self.ema_rate]
|
||||||
|
|
||||||
|
# Initialize optimizer
|
||||||
|
if hasattr(torch.optim, self.optimizer_config['name']):
|
||||||
|
self.optimizer = getattr(torch.optim, self.optimizer_config['name'])(self.master_params, **self.optimizer_config['args'])
|
||||||
|
else:
|
||||||
|
self.optimizer = globals()[self.optimizer_config['name']](self.master_params, **self.optimizer_config['args'])
|
||||||
|
|
||||||
|
# Initalize learning rate scheduler
|
||||||
|
if self.lr_scheduler_config is not None:
|
||||||
|
if hasattr(torch.optim.lr_scheduler, self.lr_scheduler_config['name']):
|
||||||
|
self.lr_scheduler = getattr(torch.optim.lr_scheduler, self.lr_scheduler_config['name'])(self.optimizer, **self.lr_scheduler_config['args'])
|
||||||
|
else:
|
||||||
|
self.lr_scheduler = globals()[self.lr_scheduler_config['name']](self.optimizer, **self.lr_scheduler_config['args'])
|
||||||
|
|
||||||
|
# Initialize elastic memory controller
|
||||||
|
if self.elastic_controller_config is not None:
|
||||||
|
assert any([isinstance(model, (elastic_utils.ElasticModule, elastic_utils.ElasticModuleMixin)) for model in self.models.values()]), \
|
||||||
|
'No elastic module found in models, please inherit from ElasticModule or ElasticModuleMixin'
|
||||||
|
self.elastic_controller = getattr(elastic_utils, self.elastic_controller_config['name'])(**self.elastic_controller_config['args'])
|
||||||
|
for model in self.models.values():
|
||||||
|
if isinstance(model, (elastic_utils.ElasticModule, elastic_utils.ElasticModuleMixin)):
|
||||||
|
model.register_memory_controller(self.elastic_controller)
|
||||||
|
|
||||||
|
# Initialize gradient clipper
|
||||||
|
if self.grad_clip is not None:
|
||||||
|
if isinstance(self.grad_clip, (float, int)):
|
||||||
|
self.grad_clip = float(self.grad_clip)
|
||||||
|
else:
|
||||||
|
self.grad_clip = getattr(grad_clip_utils, self.grad_clip['name'])(**self.grad_clip['args'])
|
||||||
|
|
||||||
|
def _master_params_to_state_dicts(self, master_params):
|
||||||
|
"""
|
||||||
|
Convert master params to dict of state_dicts.
|
||||||
|
"""
|
||||||
|
if self.fp16_mode == 'inflat_all':
|
||||||
|
master_params = unflatten_master_params(self.model_params, master_params)
|
||||||
|
state_dicts = {name: model.state_dict() for name, model in self.models.items()}
|
||||||
|
master_params_names = sum(
|
||||||
|
[[(name, n) for n, p in model.named_parameters() if p.requires_grad] for name, model in self.models.items()]
|
||||||
|
, [])
|
||||||
|
for i, (model_name, param_name) in enumerate(master_params_names):
|
||||||
|
state_dicts[model_name][param_name] = master_params[i]
|
||||||
|
return state_dicts
|
||||||
|
|
||||||
|
def _state_dicts_to_master_params(self, master_params, state_dicts):
|
||||||
|
"""
|
||||||
|
Convert a state_dict to master params.
|
||||||
|
"""
|
||||||
|
master_params_names = sum(
|
||||||
|
[[(name, n) for n, p in model.named_parameters() if p.requires_grad] for name, model in self.models.items()]
|
||||||
|
, [])
|
||||||
|
params = [state_dicts[name][param_name] for name, param_name in master_params_names]
|
||||||
|
if self.fp16_mode == 'inflat_all':
|
||||||
|
model_params_to_master_params(params, master_params)
|
||||||
|
else:
|
||||||
|
for i, param in enumerate(params):
|
||||||
|
master_params[i].data.copy_(param.data)
|
||||||
|
|
||||||
|
def load(self, load_dir, step=0):
|
||||||
|
"""
|
||||||
|
Load a checkpoint.
|
||||||
|
Should be called by all processes.
|
||||||
|
"""
|
||||||
|
if self.is_master:
|
||||||
|
print(f'\nLoading checkpoint from step {step}...', end='')
|
||||||
|
|
||||||
|
model_ckpts = {}
|
||||||
|
for name, model in self.models.items():
|
||||||
|
model_ckpt = torch.load(read_file_dist(os.path.join(load_dir, 'ckpts', f'{name}_step{step:07d}.pt')), map_location=self.device, weights_only=True)
|
||||||
|
model_ckpts[name] = model_ckpt
|
||||||
|
model.load_state_dict(model_ckpt)
|
||||||
|
if self.fp16_mode == 'inflat_all':
|
||||||
|
model.convert_to_fp16()
|
||||||
|
self._state_dicts_to_master_params(self.master_params, model_ckpts)
|
||||||
|
del model_ckpts
|
||||||
|
|
||||||
|
if self.is_master:
|
||||||
|
for i, ema_rate in enumerate(self.ema_rate):
|
||||||
|
ema_ckpts = {}
|
||||||
|
for name, model in self.models.items():
|
||||||
|
ema_ckpt = torch.load(os.path.join(load_dir, 'ckpts', f'{name}_ema{ema_rate}_step{step:07d}.pt'), map_location=self.device, weights_only=True)
|
||||||
|
ema_ckpts[name] = ema_ckpt
|
||||||
|
self._state_dicts_to_master_params(self.ema_params[i], ema_ckpts)
|
||||||
|
del ema_ckpts
|
||||||
|
|
||||||
|
misc_ckpt = torch.load(read_file_dist(os.path.join(load_dir, 'ckpts', f'misc_step{step:07d}.pt')), map_location=torch.device('cpu'), weights_only=False)
|
||||||
|
self.optimizer.load_state_dict(misc_ckpt['optimizer'])
|
||||||
|
self.step = misc_ckpt['step']
|
||||||
|
self.data_sampler.load_state_dict(misc_ckpt['data_sampler'])
|
||||||
|
if self.fp16_mode == 'amp':
|
||||||
|
self.scaler.load_state_dict(misc_ckpt['scaler'])
|
||||||
|
elif self.fp16_mode == 'inflat_all':
|
||||||
|
self.log_scale = misc_ckpt['log_scale']
|
||||||
|
if self.lr_scheduler_config is not None:
|
||||||
|
self.lr_scheduler.load_state_dict(misc_ckpt['lr_scheduler'])
|
||||||
|
if self.elastic_controller_config is not None:
|
||||||
|
self.elastic_controller.load_state_dict(misc_ckpt['elastic_controller'])
|
||||||
|
if self.grad_clip is not None and not isinstance(self.grad_clip, float):
|
||||||
|
self.grad_clip.load_state_dict(misc_ckpt['grad_clip'])
|
||||||
|
del misc_ckpt
|
||||||
|
|
||||||
|
if self.world_size > 1:
|
||||||
|
dist.barrier()
|
||||||
|
if self.is_master:
|
||||||
|
print(' Done.')
|
||||||
|
|
||||||
|
if self.world_size > 1:
|
||||||
|
self.check_ddp()
|
||||||
|
|
||||||
|
def save(self):
|
||||||
|
"""
|
||||||
|
Save a checkpoint.
|
||||||
|
Should be called only by the rank 0 process.
|
||||||
|
"""
|
||||||
|
assert self.is_master, 'save() should be called only by the rank 0 process.'
|
||||||
|
print(f'\nSaving checkpoint at step {self.step}...', end='')
|
||||||
|
|
||||||
|
model_ckpts = self._master_params_to_state_dicts(self.master_params)
|
||||||
|
for name, model_ckpt in model_ckpts.items():
|
||||||
|
torch.save(model_ckpt, os.path.join(self.output_dir, 'ckpts', f'{name}_step{self.step:07d}.pt'))
|
||||||
|
|
||||||
|
for i, ema_rate in enumerate(self.ema_rate):
|
||||||
|
ema_ckpts = self._master_params_to_state_dicts(self.ema_params[i])
|
||||||
|
for name, ema_ckpt in ema_ckpts.items():
|
||||||
|
torch.save(ema_ckpt, os.path.join(self.output_dir, 'ckpts', f'{name}_ema{ema_rate}_step{self.step:07d}.pt'))
|
||||||
|
|
||||||
|
misc_ckpt = {
|
||||||
|
'optimizer': self.optimizer.state_dict(),
|
||||||
|
'step': self.step,
|
||||||
|
'data_sampler': self.data_sampler.state_dict(),
|
||||||
|
}
|
||||||
|
if self.fp16_mode == 'amp':
|
||||||
|
misc_ckpt['scaler'] = self.scaler.state_dict()
|
||||||
|
elif self.fp16_mode == 'inflat_all':
|
||||||
|
misc_ckpt['log_scale'] = self.log_scale
|
||||||
|
if self.lr_scheduler_config is not None:
|
||||||
|
misc_ckpt['lr_scheduler'] = self.lr_scheduler.state_dict()
|
||||||
|
if self.elastic_controller_config is not None:
|
||||||
|
misc_ckpt['elastic_controller'] = self.elastic_controller.state_dict()
|
||||||
|
if self.grad_clip is not None and not isinstance(self.grad_clip, float):
|
||||||
|
misc_ckpt['grad_clip'] = self.grad_clip.state_dict()
|
||||||
|
torch.save(misc_ckpt, os.path.join(self.output_dir, 'ckpts', f'misc_step{self.step:07d}.pt'))
|
||||||
|
print(' Done.')
|
||||||
|
|
||||||
|
def finetune_from(self, finetune_ckpt):
|
||||||
|
"""
|
||||||
|
Finetune from a checkpoint.
|
||||||
|
Should be called by all processes.
|
||||||
|
"""
|
||||||
|
if self.is_master:
|
||||||
|
print('\nFinetuning from:')
|
||||||
|
for name, path in finetune_ckpt.items():
|
||||||
|
print(f' - {name}: {path}')
|
||||||
|
|
||||||
|
model_ckpts = {}
|
||||||
|
for name, model in self.models.items():
|
||||||
|
model_state_dict = model.state_dict()
|
||||||
|
if name in finetune_ckpt:
|
||||||
|
model_ckpt = torch.load(read_file_dist(finetune_ckpt[name]), map_location=self.device, weights_only=True)
|
||||||
|
for k, v in model_ckpt.items():
|
||||||
|
if model_ckpt[k].shape != model_state_dict[k].shape:
|
||||||
|
if self.is_master:
|
||||||
|
print(f'Warning: {k} shape mismatch, {model_ckpt[k].shape} vs {model_state_dict[k].shape}, skipped.')
|
||||||
|
model_ckpt[k] = model_state_dict[k]
|
||||||
|
model_ckpts[name] = model_ckpt
|
||||||
|
model.load_state_dict(model_ckpt)
|
||||||
|
if self.fp16_mode == 'inflat_all':
|
||||||
|
model.convert_to_fp16()
|
||||||
|
else:
|
||||||
|
if self.is_master:
|
||||||
|
print(f'Warning: {name} not found in finetune_ckpt, skipped.')
|
||||||
|
model_ckpts[name] = model_state_dict
|
||||||
|
self._state_dicts_to_master_params(self.master_params, model_ckpts)
|
||||||
|
del model_ckpts
|
||||||
|
|
||||||
|
if self.world_size > 1:
|
||||||
|
dist.barrier()
|
||||||
|
if self.is_master:
|
||||||
|
print('Done.')
|
||||||
|
|
||||||
|
if self.world_size > 1:
|
||||||
|
self.check_ddp()
|
||||||
|
|
||||||
|
def update_ema(self):
|
||||||
|
"""
|
||||||
|
Update exponential moving average.
|
||||||
|
Should only be called by the rank 0 process.
|
||||||
|
"""
|
||||||
|
assert self.is_master, 'update_ema() should be called only by the rank 0 process.'
|
||||||
|
for i, ema_rate in enumerate(self.ema_rate):
|
||||||
|
for master_param, ema_param in zip(self.master_params, self.ema_params[i]):
|
||||||
|
ema_param.detach().mul_(ema_rate).add_(master_param, alpha=1.0 - ema_rate)
|
||||||
|
|
||||||
|
def check_ddp(self):
|
||||||
|
"""
|
||||||
|
Check if DDP is working properly.
|
||||||
|
Should be called by all process.
|
||||||
|
"""
|
||||||
|
if self.is_master:
|
||||||
|
print('\nPerforming DDP check...')
|
||||||
|
|
||||||
|
if self.is_master:
|
||||||
|
print('Checking if parameters are consistent across processes...')
|
||||||
|
dist.barrier()
|
||||||
|
try:
|
||||||
|
for p in self.master_params:
|
||||||
|
# split to avoid OOM
|
||||||
|
for i in range(0, p.numel(), 10000000):
|
||||||
|
sub_size = min(10000000, p.numel() - i)
|
||||||
|
sub_p = p.detach().view(-1)[i:i+sub_size]
|
||||||
|
# gather from all processes
|
||||||
|
sub_p_gather = [torch.empty_like(sub_p) for _ in range(self.world_size)]
|
||||||
|
dist.all_gather(sub_p_gather, sub_p)
|
||||||
|
# check if equal
|
||||||
|
assert all([torch.equal(sub_p, sub_p_gather[i]) for i in range(self.world_size)]), 'parameters are not consistent across processes'
|
||||||
|
except AssertionError as e:
|
||||||
|
if self.is_master:
|
||||||
|
print(f'\n\033[91mError: {e}\033[0m')
|
||||||
|
print('DDP check failed.')
|
||||||
|
raise e
|
||||||
|
|
||||||
|
dist.barrier()
|
||||||
|
if self.is_master:
|
||||||
|
print('Done.')
|
||||||
|
|
||||||
|
def run_step(self, data_list):
|
||||||
|
"""
|
||||||
|
Run a training step.
|
||||||
|
"""
|
||||||
|
step_log = {'loss': {}, 'status': {}}
|
||||||
|
amp_context = partial(torch.autocast, device_type='cuda') if self.fp16_mode == 'amp' else nullcontext
|
||||||
|
elastic_controller_context = self.elastic_controller.record if self.elastic_controller_config is not None else nullcontext
|
||||||
|
|
||||||
|
# Train
|
||||||
|
losses = []
|
||||||
|
statuses = []
|
||||||
|
elastic_controller_logs = []
|
||||||
|
zero_grad(self.model_params)
|
||||||
|
for i, mb_data in enumerate(data_list):
|
||||||
|
## sync at the end of each batch split
|
||||||
|
sync_contexts = [self.training_models[name].no_sync for name in self.training_models] if i != len(data_list) - 1 and self.world_size > 1 else [nullcontext]
|
||||||
|
with nested_contexts(*sync_contexts), elastic_controller_context():
|
||||||
|
with amp_context():
|
||||||
|
loss, status = self.training_losses(**mb_data)
|
||||||
|
l = loss['loss'] / len(data_list)
|
||||||
|
## backward
|
||||||
|
if self.fp16_mode == 'amp':
|
||||||
|
self.scaler.scale(l).backward()
|
||||||
|
elif self.fp16_mode == 'inflat_all':
|
||||||
|
scaled_l = l * (2 ** self.log_scale)
|
||||||
|
scaled_l.backward()
|
||||||
|
else:
|
||||||
|
l.backward()
|
||||||
|
## log
|
||||||
|
losses.append(dict_foreach(loss, lambda x: x.item() if isinstance(x, torch.Tensor) else x))
|
||||||
|
statuses.append(dict_foreach(status, lambda x: x.item() if isinstance(x, torch.Tensor) else x))
|
||||||
|
if self.elastic_controller_config is not None:
|
||||||
|
elastic_controller_logs.append(self.elastic_controller.log())
|
||||||
|
## gradient clip
|
||||||
|
if self.grad_clip is not None:
|
||||||
|
if self.fp16_mode == 'amp':
|
||||||
|
self.scaler.unscale_(self.optimizer)
|
||||||
|
elif self.fp16_mode == 'inflat_all':
|
||||||
|
model_grads_to_master_grads(self.model_params, self.master_params)
|
||||||
|
self.master_params[0].grad.mul_(1.0 / (2 ** self.log_scale))
|
||||||
|
if isinstance(self.grad_clip, float):
|
||||||
|
grad_norm = torch.nn.utils.clip_grad_norm_(self.master_params, self.grad_clip)
|
||||||
|
else:
|
||||||
|
grad_norm = self.grad_clip(self.master_params)
|
||||||
|
if torch.isfinite(grad_norm):
|
||||||
|
statuses[-1]['grad_norm'] = grad_norm.item()
|
||||||
|
## step
|
||||||
|
if self.fp16_mode == 'amp':
|
||||||
|
prev_scale = self.scaler.get_scale()
|
||||||
|
self.scaler.step(self.optimizer)
|
||||||
|
self.scaler.update()
|
||||||
|
elif self.fp16_mode == 'inflat_all':
|
||||||
|
prev_scale = 2 ** self.log_scale
|
||||||
|
if not any(not p.grad.isfinite().all() for p in self.model_params):
|
||||||
|
if self.grad_clip is None:
|
||||||
|
model_grads_to_master_grads(self.model_params, self.master_params)
|
||||||
|
self.master_params[0].grad.mul_(1.0 / (2 ** self.log_scale))
|
||||||
|
self.optimizer.step()
|
||||||
|
master_params_to_model_params(self.model_params, self.master_params)
|
||||||
|
self.log_scale += self.fp16_scale_growth
|
||||||
|
else:
|
||||||
|
self.log_scale -= 1
|
||||||
|
else:
|
||||||
|
prev_scale = 1.0
|
||||||
|
if not any(not p.grad.isfinite().all() for p in self.model_params):
|
||||||
|
self.optimizer.step()
|
||||||
|
else:
|
||||||
|
print('\n\033[93mWarning: NaN detected in gradients. Skipping update.\033[0m')
|
||||||
|
## adjust learning rate
|
||||||
|
if self.lr_scheduler_config is not None:
|
||||||
|
statuses[-1]['lr'] = self.lr_scheduler.get_last_lr()[0]
|
||||||
|
self.lr_scheduler.step()
|
||||||
|
|
||||||
|
# Logs
|
||||||
|
step_log['loss'] = dict_reduce(losses, lambda x: np.mean(x))
|
||||||
|
step_log['status'] = dict_reduce(statuses, lambda x: np.mean(x), special_func={'min': lambda x: np.min(x), 'max': lambda x: np.max(x)})
|
||||||
|
if self.elastic_controller_config is not None:
|
||||||
|
step_log['elastic'] = dict_reduce(elastic_controller_logs, lambda x: np.mean(x))
|
||||||
|
if self.grad_clip is not None:
|
||||||
|
step_log['grad_clip'] = self.grad_clip if isinstance(self.grad_clip, float) else self.grad_clip.log()
|
||||||
|
|
||||||
|
# Check grad and norm of each param
|
||||||
|
if self.log_param_stats:
|
||||||
|
param_norms = {}
|
||||||
|
param_grads = {}
|
||||||
|
for name, param in self.backbone.named_parameters():
|
||||||
|
if param.requires_grad:
|
||||||
|
param_norms[name] = param.norm().item()
|
||||||
|
if param.grad is not None and torch.isfinite(param.grad).all():
|
||||||
|
param_grads[name] = param.grad.norm().item() / prev_scale
|
||||||
|
step_log['param_norms'] = param_norms
|
||||||
|
step_log['param_grads'] = param_grads
|
||||||
|
|
||||||
|
# Update exponential moving average
|
||||||
|
if self.is_master:
|
||||||
|
self.update_ema()
|
||||||
|
|
||||||
|
return step_log
|
||||||
@@ -0,0 +1,59 @@
|
|||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
from ....utils.general_utils import dict_foreach
|
||||||
|
from ....pipelines import samplers
|
||||||
|
|
||||||
|
|
||||||
|
class ClassifierFreeGuidanceMixin:
|
||||||
|
def __init__(self, *args, p_uncond: float = 0.1, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.p_uncond = p_uncond
|
||||||
|
|
||||||
|
def get_cond(self, cond, neg_cond=None, **kwargs):
|
||||||
|
"""
|
||||||
|
Get the conditioning data.
|
||||||
|
"""
|
||||||
|
assert neg_cond is not None, "neg_cond must be provided for classifier-free guidance"
|
||||||
|
|
||||||
|
if self.p_uncond > 0:
|
||||||
|
# randomly drop the class label
|
||||||
|
def get_batch_size(cond):
|
||||||
|
if isinstance(cond, torch.Tensor):
|
||||||
|
return cond.shape[0]
|
||||||
|
elif isinstance(cond, list):
|
||||||
|
return len(cond)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported type of cond: {type(cond)}")
|
||||||
|
|
||||||
|
ref_cond = cond if not isinstance(cond, dict) else cond[list(cond.keys())[0]]
|
||||||
|
B = get_batch_size(ref_cond)
|
||||||
|
|
||||||
|
def select(cond, neg_cond, mask):
|
||||||
|
if isinstance(cond, torch.Tensor):
|
||||||
|
mask = torch.tensor(mask, device=cond.device).reshape(-1, *[1] * (cond.ndim - 1))
|
||||||
|
return torch.where(mask, neg_cond, cond)
|
||||||
|
elif isinstance(cond, list):
|
||||||
|
return [nc if m else c for c, nc, m in zip(cond, neg_cond, mask)]
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported type of cond: {type(cond)}")
|
||||||
|
|
||||||
|
mask = list(np.random.rand(B) < self.p_uncond)
|
||||||
|
if not isinstance(cond, dict):
|
||||||
|
cond = select(cond, neg_cond, mask)
|
||||||
|
else:
|
||||||
|
cond = dict_foreach([cond, neg_cond], lambda x: select(x[0], x[1], mask))
|
||||||
|
|
||||||
|
return cond
|
||||||
|
|
||||||
|
def get_inference_cond(self, cond, neg_cond=None, **kwargs):
|
||||||
|
"""
|
||||||
|
Get the conditioning data for inference.
|
||||||
|
"""
|
||||||
|
assert neg_cond is not None, "neg_cond must be provided for classifier-free guidance"
|
||||||
|
return {'cond': cond, 'neg_cond': neg_cond, **kwargs}
|
||||||
|
|
||||||
|
def get_sampler(self, **kwargs) -> samplers.FlowEulerCfgSampler:
|
||||||
|
"""
|
||||||
|
Get the sampler for the diffusion process.
|
||||||
|
"""
|
||||||
|
return samplers.FlowEulerCfgSampler(self.sigma_min)
|
||||||
0
trellis/utils/__init__.py
Executable file
0
trellis/utils/__init__.py
Executable file
Reference in New Issue
Block a user