commit 59570f8812732e1e483d41dc4f51ea43f644fad3
Author: zcr
Date:   Tue Mar 17 11:28:52 2026 +0800

    1

diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..3e9d1c1
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,437 @@
+## Ignore Visual Studio temporary files, build results, and
+## files generated by popular Visual Studio add-ons.
+##
+## Get latest from https://github.com/github/gitignore/blob/main/VisualStudio.gitignore
+
+# User-specific files
+*.rsuser
+*.suo
+*.user
+*.userosscache
+*.sln.docstates
+
+# User-specific files (MonoDevelop/Xamarin Studio)
+*.userprefs
+
+# Mono auto generated files
+mono_crash.*
+
+# Build results
+[Dd]ebug/
+[Dd]ebugPublic/
+[Rr]elease/
+[Rr]eleases/
+x64/
+x86/
+[Ww][Ii][Nn]32/
+[Aa][Rr][Mm]/
+[Aa][Rr][Mm]64/
+bld/
+[Bb]in/
+[Oo]bj/
+[Ll]og/
+[Ll]ogs/
+
+# Visual Studio 2015/2017 cache/options directory
+.vs/
+# Uncomment if you have tasks that create the project's static files in wwwroot
+#wwwroot/
+
+# Visual Studio 2017 auto generated files
+Generated\ Files/
+
+# MSTest test Results
+[Tt]est[Rr]esult*/
+[Bb]uild[Ll]og.*
+
+# NUnit
+*.VisualState.xml
+TestResult.xml
+nunit-*.xml
+
+# Build Results of an ATL Project
+[Dd]ebugPS/
+[Rr]eleasePS/
+dlldata.c
+
+# Benchmark Results
+BenchmarkDotNet.Artifacts/
+
+# .NET Core
+project.lock.json
+project.fragment.lock.json
+artifacts/
+
+# ASP.NET Scaffolding
+ScaffoldingReadMe.txt
+
+# StyleCop
+StyleCopReport.xml
+
+# Files built by Visual Studio
+*_i.c
+*_p.c
+*_h.h
+*.ilk
+*.meta
+*.obj
+*.iobj
+*.pch
+*.pdb
+*.ipdb
+*.pgc
+*.pgd
+*.rsp
+*.sbr
+*.tlb
+*.tli
+*.tlh
+*.tmp
+*.tmp_proj
+*_wpftmp.csproj
+*.log
+*.tlog
+*.vspscc
+*.vssscc
+.builds
+*.pidb
+*.svclog
+*.scc
+
+# Chutzpah Test files
+_Chutzpah*
+
+# Visual C++ cache files
+ipch/
+*.aps
+*.ncb
+*.opendb
+*.opensdf
+*.sdf
+*.cachefile
+*.VC.db
+*.VC.VC.opendb
+
+# Visual Studio profiler
+*.psess
+*.vsp
+*.vspx
+*.sap
+
+# Visual Studio Trace Files
+*.e2e
+
+# TFS 2012 Local Workspace
+$tf/
+
+# Guidance Automation Toolkit
+*.gpState
+
+# ReSharper is a .NET coding add-in
+_ReSharper*/
+*.[Rr]e[Ss]harper
+*.DotSettings.user
+
+# TeamCity is a build add-in
+_TeamCity*
+
+# DotCover is a Code Coverage Tool
+*.dotCover
+
+# AxoCover is a Code Coverage Tool
+.axoCover/*
+!.axoCover/settings.json
+
+# Coverlet is a free, cross platform Code Coverage Tool
+coverage*.json
+coverage*.xml
+coverage*.info
+
+# Visual Studio code coverage results
+*.coverage
+*.coveragexml
+
+# NCrunch
+_NCrunch_*
+.*crunch*.local.xml
+nCrunchTemp_*
+
+# MightyMoose
+*.mm.*
+AutoTest.Net/
+
+# Web workbench (sass)
+.sass-cache/
+
+# Installshield output folder
+[Ee]xpress/
+
+# DocProject is a documentation generator add-in
+DocProject/buildhelp/
+DocProject/Help/*.HxT
+DocProject/Help/*.HxC
+DocProject/Help/*.hhc
+DocProject/Help/*.hhk
+DocProject/Help/*.hhp
+DocProject/Help/Html2
+DocProject/Help/html
+
+# Click-Once directory
+publish/
+
+# Publish Web Output
+*.[Pp]ublish.xml
+*.azurePubxml
+# Note: Comment the next line if you want to checkin your web deploy settings,
+# but database connection strings (with potential passwords) will be unencrypted
+*.pubxml
+*.publishproj
+
+# Microsoft Azure Web App publish settings. Comment the next line if you want to
+# checkin your Azure Web App publish settings, but sensitive information contained
+# in these scripts will be unencrypted
+PublishScripts/
+
+# NuGet Packages
+*.nupkg
+# NuGet Symbol Packages
+*.snupkg
+# The packages folder can be ignored because of Package Restore
+**/[Pp]ackages/*
+# except build/, which is used as an MSBuild target.
+!**/[Pp]ackages/build/
+# Uncomment if necessary however generally it will be regenerated when needed
+#!**/[Pp]ackages/repositories.config
+# NuGet v3's project.json files produces more ignorable files
+*.nuget.props
+*.nuget.targets
+
+# Microsoft Azure Build Output
+csx/
+*.build.csdef
+
+# Microsoft Azure Emulator
+ecf/
+rcf/
+
+# Windows Store app package directories and files
+AppPackages/
+BundleArtifacts/
+Package.StoreAssociation.xml
+_pkginfo.txt
+*.appx
+*.appxbundle
+*.appxupload
+
+# Visual Studio cache files
+# files ending in .cache can be ignored
+*.[Cc]ache
+# but keep track of directories ending in .cache
+!?*.[Cc]ache/
+
+# Others
+ClientBin/
+~$*
+*~
+*.dbmdl
+*.dbproj.schemaview
+*.jfm
+*.pfx
+*.publishsettings
+orleans.codegen.cs
+
+# Including strong name files can present a security risk
+# (https://github.com/github/gitignore/pull/2483#issue-259490424)
+#*.snk
+
+# Since there are multiple workflows, uncomment next line to ignore bower_components
+# (https://github.com/github/gitignore/pull/1529#issuecomment-104372622)
+#bower_components/
+
+# RIA/Silverlight projects
+Generated_Code/
+
+# Backup & report files from converting an old project file
+# to a newer Visual Studio version. Backup files are not needed,
+# because we have git ;-)
+_UpgradeReport_Files/
+Backup*/
+UpgradeLog*.XML
+UpgradeLog*.htm
+ServiceFabricBackup/
+*.rptproj.bak
+
+# SQL Server files
+*.mdf
+*.ldf
+*.ndf
+
+# Business Intelligence projects
+*.rdl.data
+*.bim.layout
+*.bim_*.settings
+*.rptproj.rsuser
+*- [Bb]ackup.rdl
+*- [Bb]ackup ([0-9]).rdl
+*- [Bb]ackup ([0-9][0-9]).rdl
+
+# Microsoft Fakes
+FakesAssemblies/
+
+# GhostDoc plugin setting file
+*.GhostDoc.xml
+
+# Node.js Tools for Visual Studio
+.ntvs_analysis.dat
+node_modules/
+
+# Visual Studio 6 build log
+*.plg
+
+# Visual Studio 6 workspace options file
+*.opt
+
+# Visual Studio 6 auto-generated workspace file (contains which files were open etc.)
+*.vbw
+
+# Visual Studio 6 auto-generated project file (contains which files were open etc.)
+*.vbp
+
+# Visual Studio 6 workspace and project file (working project files containing files to include in project)
+*.dsw
+*.dsp
+
+# Visual Studio 6 technical files
+*.ncb
+*.aps
+
+# Visual Studio LightSwitch build output
+**/*.HTMLClient/GeneratedArtifacts
+**/*.DesktopClient/GeneratedArtifacts
+**/*.DesktopClient/ModelManifest.xml
+**/*.Server/GeneratedArtifacts
+**/*.Server/ModelManifest.xml
+_Pvt_Extensions
+
+# Paket dependency manager
+.paket/paket.exe
+paket-files/
+
+# FAKE - F# Make
+.fake/
+
+# CodeRush personal settings
+.cr/personal
+
+# Python Tools for Visual Studio (PTVS)
+__pycache__/
+*.pyc
+
+# Cake - Uncomment if you are using it
+# tools/**
+# !tools/packages.config
+
+# Tabs Studio
+*.tss
+
+# Telerik's JustMock configuration file
+*.jmconfig
+
+# BizTalk build output
+*.btp.cs
+*.btm.cs
+*.odx.cs
+*.xsd.cs
+
+# OpenCover UI analysis results
+OpenCover/
+
+# Azure Stream Analytics local run output
+ASALocalRun/
+
+# MSBuild Binary and Structured Log
+*.binlog
+
+# NVidia Nsight GPU debugger configuration file
+*.nvuser
+
+# MFractors (Xamarin productivity tool) working folder
+.mfractor/
+
+# Local History for Visual Studio
+.localhistory/
+
+# Visual Studio History (VSHistory) files
+.vshistory/
+
+# BeatPulse healthcheck temp database
+healthchecksdb
+
+# Backup folder for Package Reference Convert tool in Visual Studio 2017
+MigrationBackup/
+
+# Ionide (cross platform F# VS Code tools) working folder
+.ionide/
+
+# Fody - auto-generated XML schema
+FodyWeavers.xsd
+
+# VS Code files for those working on multiple tools
+.vscode/*
+!.vscode/settings.json
+!.vscode/tasks.json
+!.vscode/launch.json
+!.vscode/extensions.json
+*.code-workspace
+
+# Local History for Visual Studio Code
+.history/
+
+# Windows Installer files from build outputs
+*.cab
+*.msi
+*.msix
+*.msm
+*.msp
+
+# JetBrains Rider
+*.sln.iml
+
+# 3D asset intermediates
+*.glb
+*.ply
+*.mtl
+*.step
+*.ipynb
+
+# Model checkpoints (important to ignore)
+*.pth
+*.ckpt
+*.bin
+*.pt
+*.h5
+
+# Archives
+*.zip
+*.rar
+*.7z
+*.tar.gz
+
+# Images/videos (ignore unless needed; switch to LFS if they must be tracked)
+*.png
+*.svg
+*.jpg
+*.jpeg
+*.mp4
+*.avi
+
+# Logs / caches / temp files
+*.log
+*.tmp
+__pycache__/
+*.pyc
+.DS_Store
diff --git a/.gitmodules b/.gitmodules
new file mode 100644
index 0000000..dd2b5eb
--- /dev/null
+++ b/.gitmodules
@@ -0,0 +1,3 @@
+[submodule "trellis/representations/mesh/flexicubes"]
+	path = trellis/representations/mesh/flexicubes
+	url = https://github.com/MaxtirError/FlexiCubes.git
diff --git a/.idea/.gitignore b/.idea/.gitignore
new file mode 100644
index 0000000..b6b1ecf
--- /dev/null
+++ b/.idea/.gitignore
@@ -0,0 +1,10 @@
+# Default ignored files
+/shelf/
+/workspace.xml
+# Default folder with query files, already ignored
+/queries/
+# Datasource local storage ignored files
+/dataSources/
+/dataSources.local.xml
+# Editor-based HTTP Client requests
+/httpRequests/
diff --git a/0_glb_to_obj.py b/0_glb_to_obj.py
new file mode 100644
index 0000000..b3d585d
--- /dev/null
+++ b/0_glb_to_obj.py
@@ -0,0 +1,68 @@
+import bpy
+import sys
+import os
+import argparse
+
+
+def parse_args():
+    argv = sys.argv
+    argv = argv[argv.index("--") + 1:] if "--" in argv else []
+    p = argparse.ArgumentParser()
+    p.add_argument("-i", "--input", required=True)
+    p.add_argument("-o", "--output", required=True)
+    return p.parse_args(argv)
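+
+
+# Example headless invocation (assumed paths; Blender consumes everything before
+# `--`, while this script parses everything after it):
+#   blender --background --python 0_glb_to_obj.py -- \
+#       -i trellis_out/sample.glb -o obj_output/sample.obj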
+
+
+def clean():
+    bpy.ops.object.select_all(action='SELECT')
+    bpy.ops.object.delete(use_global=False)
+
+
+def main():
+    args = parse_args()
+    clean()
+
+    bpy.ops.import_scene.gltf(filepath=args.input)  # handles both .glb and .gltf
+
+    # Keep only MESH objects and join them into one (matches the single-asset assumption)
+    meshes = [o for o in bpy.context.scene.objects if o.type == "MESH"]
+    if not meshes:
+        raise RuntimeError("No mesh objects found in GLB/GLTF.")
+    bpy.ops.object.select_all(action='DESELECT')
+    for o in meshes:
+        o.select_set(True)
+    bpy.context.view_layer.objects.active = meshes[0]
+    if len(meshes) > 1:
+        bpy.ops.object.join()
+
+    obj = bpy.context.view_layer.objects.active
+    obj.select_set(True)
+
+    out_dir = os.path.dirname(args.output)
+    if out_dir:
+        os.makedirs(out_dir, exist_ok=True)
+
+    # Blender 4.x replacement (bpy.ops.export_scene.obj was removed in 4.0):
+    # bpy.ops.wm.obj_export(
+    #     filepath=args.output,
+    #     export_selected_objects=True,
+    #     export_normals=True,
+    #     export_uv=True,
+    # )
+
+    bpy.ops.export_scene.obj(
+        filepath=args.output,
+        use_selection=True,  # export only the selected objects
+        use_normals=True,
+        use_uvs=True,
+        use_materials=True,  # adjust as needed (materials were not exported originally)
+        axis_forward='-Z',  # glTF is typically -Z forward, Y up; adjust if needed
+        axis_up='Y',
+        # keep_vertex_order=True,  # optional: preserve vertex order
+    )
+
+
+if __name__ == "__main__":
+    main()
diff --git a/0_remesh.py b/0_remesh.py
new file mode 100644
index 0000000..877498a
--- /dev/null
+++ b/0_remesh.py
@@ -0,0 +1,260 @@
+import bpy
+import sys
+import os
+import argparse
+import tempfile
+
+
+def clean_scene():
+    bpy.ops.object.select_all(action='SELECT')
+    bpy.ops.object.delete(use_global=False)
+
+
+def _parse_args_blender(argv):
+    """Blender passes script arguments after `--`."""
+    if "--" in argv:
+        return argv[argv.index("--") + 1:]
+    return []
+
+
+def _is_glb_gltf(path: str) -> bool:
+    ext = os.path.splitext(path)[1].lower()
+    return ext in [".glb", ".gltf"]
+
+
+def _is_obj(path: str) -> bool:
+    return os.path.splitext(path)[1].lower() == ".obj"
+
+
+def import_mesh_any(input_path: str):
+    """
+    Import an OBJ or GLB/GLTF and join every MESH object in the scene into a
+    single active mesh object, which is returned.
+    """
+    clean_scene()
+
+    if _is_glb_gltf(input_path):
+        print(f"[import] GLB/GLTF: {input_path}")
+        bpy.ops.import_scene.gltf(filepath=input_path)
+    elif _is_obj(input_path):
+        print(f"[import] OBJ: {input_path}")
+        bpy.ops.wm.obj_import(filepath=input_path)
+    else:
+        raise RuntimeError(f"Unsupported input format: {input_path}")
+
+    meshes = [o for o in bpy.context.scene.objects if o.type == "MESH"]
+    if not meshes:
+        raise RuntimeError("No mesh objects found after import.")
+
+    bpy.ops.object.select_all(action='DESELECT')
+    for o in meshes:
+        o.select_set(True)
+    bpy.context.view_layer.objects.active = meshes[0]
+
+    if len(meshes) > 1:
+        bpy.ops.object.join()
+
+    obj = bpy.context.view_layer.objects.active
+    obj.select_set(True)
+    return obj
+
+
+def export_obj_selected(output_obj: str):
+    out_dir = os.path.dirname(output_obj)
+    if out_dir:
+        os.makedirs(out_dir, exist_ok=True)
+
+    bpy.ops.wm.obj_export(
+        filepath=output_obj,
+        export_selected_objects=True,
+        export_normals=True,
+        export_uv=True,
+    )
+
+
+def fix_mesh(obj):
+    """Repair the mesh so it better satisfies QuadriFlow's input requirements."""
+    bpy.context.view_layer.objects.active = obj
+    obj.select_set(True)
+
+    bpy.ops.object.mode_set(mode='EDIT')
+    bpy.ops.mesh.select_all(action='SELECT')
+
+    bpy.ops.mesh.remove_doubles(threshold=0.001)
+    bpy.ops.mesh.delete_loose()
+
+    bpy.ops.mesh.select_all(action='DESELECT')
+    bpy.ops.mesh.select_non_manifold()
+
+    obj.update_from_editmode()  # sync the edit-mode selection into obj.data before reading it
+    selected_count = sum(1 for v in obj.data.vertices if v.select)
+    if selected_count > 0:
+        print(f"[fix_mesh] found {selected_count} non-manifold verts, deleting...")
+        bpy.ops.mesh.delete(type='VERT')
+
+    bpy.ops.mesh.select_all(action='SELECT')
+    bpy.ops.mesh.normals_make_consistent(inside=False)
+
+    bpy.ops.mesh.select_all(action='SELECT')
+    bpy.ops.mesh.fill_holes(sides=0)
+
+    bpy.ops.mesh.remove_doubles(threshold=0.001)
+    bpy.ops.mesh.dissolve_degenerate(threshold=0.0001)
+
+    bpy.ops.mesh.select_all(action='SELECT')
+    bpy.ops.mesh.normals_make_consistent(inside=False)
+
+    bpy.ops.object.mode_set(mode='OBJECT')
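+
+
+# A quick manifold check (a sketch using Blender's bmesh API) can verify that
+# fix_mesh() did its job before the mesh is handed to QuadriFlow:
+#
+#   import bmesh
+#   bm = bmesh.new()
+#   bm.from_mesh(obj.data)
+#   bad_edges = [e for e in bm.edges if not e.is_manifold]
+#   bm.free()
+#   print(f"{len(bad_edges)} non-manifold edges remain")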
+
+
+def quadriflow_remesh_obj(
+        input_obj: str,
+        output_obj: str,
+        face_count: int = 8000,
+        use_mesh_symmetry: bool = False,
+        use_preserve_sharp: bool = False,
+        use_preserve_boundary: bool = True,
+        use_voxel_preprocess: bool = True,
+        voxel_size: float = 0.008,
+):
+    """Handles OBJ input only (the original pipeline)."""
+    obj = import_mesh_any(input_obj)  # takes the OBJ import path here
+
+    print(f"[import] object={obj.name}, verts={len(obj.data.vertices)}, faces={len(obj.data.polygons)}")
+    print("[mesh] fixing...")
+    fix_mesh(obj)
+
+    if use_voxel_preprocess:
+        print(f"[voxel] preprocess voxel_size={voxel_size} ...")
+        voxel_mod = obj.modifiers.new(name="Voxel", type='REMESH')
+        voxel_mod.mode = 'VOXEL'
+        voxel_mod.voxel_size = voxel_size
+        voxel_mod.use_smooth_shade = True
+        voxel_mod.use_remove_disconnected = False
+        bpy.ops.object.modifier_apply(modifier=voxel_mod.name)
+
+        print(f"[voxel] after: verts={len(obj.data.vertices)}, faces={len(obj.data.polygons)}")
+        print("[mesh] fixing after voxel...")
+        fix_mesh(obj)
+        print(f"[mesh] after fix2: verts={len(obj.data.vertices)}, faces={len(obj.data.polygons)}")
+
+    print(f"[quadriflow] target_faces={face_count} ...")
+    bpy.context.view_layer.objects.active = obj
+    obj.select_set(True)
+
+    try:
+        bpy.ops.object.quadriflow_remesh(
+            use_mesh_symmetry=use_mesh_symmetry,
+            use_preserve_sharp=use_preserve_sharp,
+            use_preserve_boundary=use_preserve_boundary,
+            smooth_normals=False,
+            mode='FACES',
+            target_faces=face_count,
+            seed=0
+        )
+        print("[quadriflow] done.")
+    except Exception as e:
+        print(f"[quadriflow] error: {e} (continue to export)")
+
+    print(f"[export] {output_obj}")
+    export_obj_selected(output_obj)
+    print(f"[done] verts={len(obj.data.vertices)}, faces={len(obj.data.polygons)}")
+
+
+def glb_to_tmp_obj(input_glb: str) -> str:
+    """
+    Import a GLB/GLTF inside Blender and export it as a temporary OBJ
+    (convert to OBJ first, then run the rest of the pipeline on it).
+    """
+    obj = import_mesh_any(input_glb)  # takes the glTF import path here
+    tmp_obj = tempfile.mktemp(suffix=".obj")  # mktemp is race-prone but acceptable for a single-user batch script
+    print(f"[glb2obj] export tmp obj -> {tmp_obj}")
+    export_obj_selected(tmp_obj)
+    return tmp_obj
+
+
+def run_pipeline(
+        input_path: str,
+        output_obj: str,
+        face_count: int,
+        mesh_symmetry: bool,
+        preserve_sharp: bool,
+        preserve_boundary: bool,
+        no_voxel_preprocess: bool,
+        voxel_size: float,
+):
+    """
+    Unified entry point:
+    - OBJ input: remesh directly
+    - GLB/GLTF input: convert to a temporary OBJ first, then remesh
+    """
+    tmp_obj = None
+    try:
+        if _is_glb_gltf(input_path):
+            tmp_obj = glb_to_tmp_obj(input_path)
+            quadriflow_remesh_obj(
+                input_obj=tmp_obj,
+                output_obj=output_obj,
+                face_count=face_count,
+                use_mesh_symmetry=mesh_symmetry,
+                use_preserve_sharp=preserve_sharp,
+                use_preserve_boundary=preserve_boundary,
+                use_voxel_preprocess=(not no_voxel_preprocess),
+                voxel_size=voxel_size,
+            )
+        elif _is_obj(input_path):
+            quadriflow_remesh_obj(
+                input_obj=input_path,
+                output_obj=output_obj,
+                face_count=face_count,
+                use_mesh_symmetry=mesh_symmetry,
+                use_preserve_sharp=preserve_sharp,
+                use_preserve_boundary=preserve_boundary,
+                use_voxel_preprocess=(not no_voxel_preprocess),
+                voxel_size=voxel_size,
+            )
+        else:
+            raise RuntimeError(f"Unsupported input format: {input_path}")
+    finally:
+        if tmp_obj and os.path.exists(tmp_obj):
+            try:
+                os.remove(tmp_obj)
+                print(f"[cleanup] removed tmp obj: {tmp_obj}")
+            except Exception:
+                pass
+
+
+def main():
+    parser = argparse.ArgumentParser(description="GLB/GLTF/OBJ -> (optional tmp OBJ) -> QuadriFlow remesh -> OBJ")
+    parser.add_argument("-i", "--input", default="trellis_out/sample.glb", help="Input file path (.obj/.glb/.gltf).")
+    parser.add_argument("-o", "--output", default="output/test.obj", help="Output OBJ file path.")
+    parser.add_argument("--face_count", type=int, default=8000, help="Target quad face count.")
+    parser.add_argument("--mesh_symmetry", action="store_true", help="Enable mesh symmetry.")
+    parser.add_argument("--preserve_sharp", action="store_true", help="Preserve sharp edges.")
+    parser.add_argument("--preserve_boundary", action="store_true", help="Preserve boundary edges.")
+    parser.add_argument("--no_voxel_preprocess", action="store_true", help="Disable voxel preprocess.")
+    parser.add_argument("--voxel_size", type=float, default=0.008, help="Voxel size for preprocess (smaller=denser).")
+
+    args = parser.parse_args(_parse_args_blender(sys.argv))
+
+    if not os.path.exists(args.input):
+        raise FileNotFoundError(f"Input not found: {args.input}")
+
+    run_pipeline(
+        input_path=args.input,
+        output_obj=args.output,
+        face_count=args.face_count,
+        mesh_symmetry=args.mesh_symmetry,
+        preserve_sharp=args.preserve_sharp,
+        preserve_boundary=args.preserve_boundary,
+        no_voxel_preprocess=args.no_voxel_preprocess,
+        voxel_size=args.voxel_size,
+    )
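+
+
+# Example headless invocation (assumed paths; requires a Blender build that
+# ships the QuadriFlow remesher, i.e. 2.81 or newer):
+#   blender --background --python 0_remesh.py -- \
+#       -i trellis_out/sample.glb -o output/remeshed.obj --face_count 8000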
+
+
+if __name__ == "__main__":
+    try:
+        main()
+    except Exception as e:
+        print(f"[fatal] {e}")
+        import traceback
+
+        traceback.print_exc()
+        raise
diff --git a/1_obj_to_step.py b/1_obj_to_step.py
new file mode 100644
index 0000000..1976947
--- /dev/null
+++ b/1_obj_to_step.py
@@ -0,0 +1,423 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""
+Fast OBJ/GLB/GLTF -> STEP converter for FreeCAD (headless-friendly)
+
+Parameters:
+    tol      : mesh -> shape fitting tolerance; larger is faster (default 0.5)
+    sew_tol  : sewing tolerance (skipped automatically if the build lacks Sewing), default 0.05
+    solid    : 1 = try to solidify into a Solid (slow), 0 = export shells only (default 0)
+    split    : 1 = split/clean components with trimesh (if available), 0 = off (default 1)
+    max_comp : maximum number of components kept after splitting (default 64)
+
+Notes:
+    - If the FreeCAD build lacks MeshPart.meshToShape or Part.Sewing, the script
+      falls back / skips those steps automatically.
+    - If STEP export still fails, a .brep is written as a fallback so it can be
+      converted to STEP later with pythonocc/OCP.
+"""
+import secrets
+import sys, os, time, tempfile
+from datetime import datetime
+
+import FreeCAD as App
+import Part, Mesh
+
+# Try to load MeshPart (some builds ship without it)
+try:
+    import MeshPart
+
+    HAS_MESHPART = True
+except Exception:
+    MeshPart = None
+    HAS_MESHPART = False
+
+# Optional: trimesh, for more robust splitting/cleaning and GLB/GLTF -> temporary OBJ conversion
+try:
+    import trimesh
+
+    HAS_TRIMESH = True
+except Exception:
+    HAS_TRIMESH = False
+
+
+def log(msg):
+    try:
+        App.Console.PrintMessage(str(msg) + "\n")
+    except Exception:
+        print(msg, flush=True)
+
+
+# ---------------------- Mesh cleaning / conversion core ----------------------
+
+def _freecad_mesh_quick_clean(fc_mesh):
+    """Lightweight FreeCAD cleanup: improves robustness without major topology changes."""
+    for fn in (
+            "removeDuplicatedFacets",
+            "removeDuplicatedPoints",
+            "removeDegeneratedFacets",
+            "removeNonManifoldEdges",
+    ):
+        try:
+            getattr(fc_mesh, fn)()
+        except Exception:
+            pass
+    return fc_mesh
+
+
+def _mesh_to_shape_fast(fc_mesh, tol=0.5, sew_tol=0.05, to_solid=False):
+    """
+    Mesh -> B-Rep (robust version):
+    - Prefer MeshPart.meshToShape when available (coplanar aggregation, usually faster)
+    - Otherwise fall back to Part.Shape.makeShapeFromMesh
+    - Skip sewing when unavailable (some builds lack Part.Sewing)
+    - Attempt Solidify only when to_solid=True and the shell is closed (expensive)
+    """
+    # 1) mesh -> shape
+    shape = None
+    if HAS_MESHPART and hasattr(MeshPart, "meshToShape"):
+        try:
+            shape = MeshPart.meshToShape(fc_mesh, tol)
+            if isinstance(shape, tuple):  # some versions return (shape, mapping)
+                shape = shape[0]
+        except Exception as e:
+            log(f"meshToShape failed, falling back to makeShapeFromMesh: {e}")
+
+    if shape is None:
+        shape = Part.Shape()
+        shape.makeShapeFromMesh(fc_mesh.Topology, tol)
+
+    # 2) sewing (skipped when Part.Sewing does not exist)
+    try:
+        SewingCls = getattr(Part, "Sewing", None)
+        if SewingCls is not None:
+            sew = SewingCls()
+            ok = False
+            for setter in ("tolerance", "SetTolerance"):
+                try:
+                    if setter == "tolerance":
+                        sew.tolerance = sew_tol
+                    else:
+                        sew.SetTolerance(sew_tol)
+                    ok = True
+                    break
+                except Exception:
+                    continue
+
+            sew.add(shape)
+            sew.perform()
+
+            new_shape = None
+            if hasattr(sew, "SewedShape"):
+                new_shape = sew.SewedShape
+                try:
+                    if hasattr(new_shape, "isNull") and new_shape.isNull():
+                        new_shape = None
+                except Exception:
+                    pass
+            if new_shape is None and hasattr(sew, "sewedShape"):
+                try:
+                    new_shape = sew.sewedShape()
+                except Exception:
+                    pass
+            if new_shape is None and hasattr(sew, "sewShape"):
+                try:
+                    new_shape = sew.sewShape()
+                except Exception:
+                    pass
+
+            if new_shape is not None:
+                shape = new_shape
+        else:
+            log("Skipping sewing: this FreeCAD Part module lacks the Sewing binding")
+    except Exception as e:
+        log(f"Sewing failed, keeping the original shape: {e}")
+
+    # 3) optional solidify
+    if to_solid:
+        try:
+            if not getattr(shape, "Solids", []):
+                shell = Part.Shell(shape.Faces)
+                if hasattr(shell, "isClosed") and shell.isClosed():
+                    shape = Part.makeSolid(shell)
+        except Exception as e:
+            log(f"Solidify failed, keeping the shell: {e}")
+
+    # 4) lightweight topology cleanup
+    try:
+        shape = shape.removeSplitter()
+    except Exception:
+        pass
+    try:
+        if hasattr(shape, "fixTolerance"):
+            shape.fixTolerance(sew_tol)
+    except Exception:
+        pass
+
+    return shape
+
+
+def _export_shapes_to_step(shapes, step_file_path):
+    """
+    Robust export:
+    - Ensure the output directory exists
+    - Filter out invalid shapes
+    - Try merging into a Compound (often more stable)
+    - If Part.export fails, try Import.export; as a last resort write a .brep
+    """
+    out_dir = os.path.dirname(step_file_path) or "."
+    os.makedirs(out_dir, exist_ok=True)
+
+    valid = []
+    for s in shapes:
+        try:
+            if hasattr(s, "isNull") and s.isNull():
+                continue
+            if hasattr(s, "Faces") and len(s.Faces) == 0:
+                continue
+            valid.append(s)
+        except Exception:
+            continue
+    if not valid:
+        raise RuntimeError("No valid shapes to export (the mesh may be too dirty, or every construction failed)")
+
+    try:
+        compound = Part.makeCompound(valid)
+    except Exception:
+        compound = valid[0]
+
+    doc = App.newDocument("Obj2StepFast")
+    try:
+        if hasattr(doc, "suppressRecompute"):
+            doc.suppressRecompute()
+
+        obj = doc.addObject("Part::Feature", "Compound")
+        obj.Shape = compound
+
+        if hasattr(doc, "recompute"):
+            doc.recompute()
+
+        try:
+            Part.export([obj], step_file_path)
+            return
+        except Exception as e1:
+            log(f"Part.export failed: {e1}")
+
+        try:
+            import Import
+            Import.export([obj], step_file_path)
+            return
+        except Exception as e2:
+            log(f"Import.export failed as well: {e2}")
+
+        brep_path = os.path.splitext(step_file_path)[0] + ".brep"
+        try:
+            obj.Shape.exportBrep(brep_path)
+        except Exception as e3:
+            raise RuntimeError(f"Both STEP and BREP export failed: {e3}")
+        # raise outside the try block, so this message is not swallowed by the except above
+        raise RuntimeError(
+            f"STEP export failed; wrote a BREP fallback instead: {brep_path}\n"
+            f"Use pythonocc/OCP to convert the BREP to STEP."
+        )
+    finally:
+        App.closeDocument(doc.Name)
+
+
+# ---------------------- Top-level conversion logic ----------------------
+
+def _glb_gltf_to_tmp_obj_if_needed(in_path):
+    ext = os.path.splitext(in_path)[1].lower()
+    if ext in (".glb", ".gltf"):
+        if not HAS_TRIMESH:
+            raise RuntimeError("Input is GLB/GLTF, but trimesh is not installed in the FreeCAD Python environment. Install it and retry.")
+        mesh = trimesh.load(in_path, force="mesh")
+        if mesh.is_empty:
+            raise RuntimeError(f"GLB/GLTF is empty or failed to load: {in_path}")
+        tmp = tempfile.mktemp(suffix=".obj")
+        mesh.export(tmp)
+        return tmp, True
+    return in_path, False
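+
+
+# If only the .brep fallback gets written, it can be turned into STEP outside
+# FreeCAD. A sketch with pythonocc (assumes pythonocc-core is installed; on
+# older versions breptools.Read is spelled breptools_Read):
+#
+#   from OCC.Core.BRep import BRep_Builder
+#   from OCC.Core.BRepTools import breptools
+#   from OCC.Core.TopoDS import TopoDS_Shape
+#   from OCC.Core.STEPControl import STEPControl_Writer, STEPControl_AsIs
+#
+#   shape = TopoDS_Shape()
+#   breptools.Read(shape, "model.brep", BRep_Builder())
+#   writer = STEPControl_Writer()
+#   writer.Transfer(shape, STEPControl_AsIs)
+#   writer.Write("model.step")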
+
+
+# New: generate a unique filename with a timestamp + random suffix (avoids overwriting)
+def generate_unique_step_filename(base_name: str, output_dir: str) -> str:
+    """
+    Example of the generated format:
+        model_20260316_131245_8f4a2c1d.step
+    """
+    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+    random_part = secrets.token_hex(4)  # 8 random hex characters
+    filename = f"{base_name}_{timestamp}_{random_part}.step"
+    return os.path.join(output_dir, filename)
+
+
+def obj_to_step_fast(input_path,
+                     step_file_path=None,
+                     tol=0.5,
+                     sew_tol=0.05,
+                     to_solid=False,
+                     split_components=True,
+                     max_components=64):
+    t0 = time.perf_counter()
+    if not os.path.exists(input_path):
+        raise FileNotFoundError(input_path)
+
+    base = os.path.splitext(os.path.basename(input_path))[0]
+
+    # ── Decide the final output path ──
+    if step_file_path is None:
+        # No path given → use the input's directory + a unique filename
+        out_dir = os.path.dirname(input_path)
+        step_file_path = generate_unique_step_filename(base, out_dir)
+    elif os.path.isdir(step_file_path):
+        # Caller passed a directory → generate a unique filename inside it
+        step_file_path = generate_unique_step_filename(base, step_file_path)
+    else:
+        # Caller gave an explicit .step/.stp path → use it as-is (no random suffix)
+        if not step_file_path.lower().endswith((".step", ".stp")):
+            step_file_path += ".step"
+        # No extra uniqueness handling here; respect the caller's intent (they may want to overwrite)
+
+    log(f"Target STEP file: {step_file_path}")
+
+    # The rest below is unchanged ........................................
+    shapes = []
+    tmp_created_paths = []
+
+    src_for_freecad, is_tmp = _glb_gltf_to_tmp_obj_if_needed(input_path)
+    if is_tmp:
+        tmp_created_paths.append(src_for_freecad)
+
+    if HAS_TRIMESH and split_components:
+        m = None
+        try:
+            m = trimesh.load(src_for_freecad, force="mesh")
+        except Exception as e:
+            log(f"trimesh load failed (falling back to the pure FreeCAD path): {e}")
+            m = None
+
+        if m is not None and not m.is_empty:
+            try:
+                # NOTE: these helpers were removed in newer trimesh releases; failures are ignored on purpose
+                m.remove_duplicate_faces()
+                m.remove_degenerate_faces()
+                m.remove_unreferenced_vertices()
+                m.fix_normals()
+            except Exception:
+                pass
+
+            parts = m.split(only_watertight=False)
+            if len(parts) == 0:
+                parts = [m]
+
+            if len(parts) > max_components:
+                parts = sorted(parts, key=lambda x: x.faces.shape[0], reverse=True)[:max_components]
+
+            log(f"trimesh split: {len(parts)} components")
+
+            for i, pm in enumerate(parts):
+                tmp_obj = tempfile.mktemp(suffix=f"_{i}.obj")
+                pm.export(tmp_obj)
+                tmp_created_paths.append(tmp_obj)
+
+                fc_mesh = Mesh.Mesh(tmp_obj)
+                _freecad_mesh_quick_clean(fc_mesh)
+                shp = _mesh_to_shape_fast(fc_mesh, tol=tol, sew_tol=sew_tol, to_solid=to_solid)
+                shapes.append(shp)
+        else:
+            fc_mesh = Mesh.Mesh(src_for_freecad)
+            log(f"Loaded mesh: {src_for_freecad}, triangles={len(fc_mesh.Facets)}")
+            _freecad_mesh_quick_clean(fc_mesh)
+            shp = _mesh_to_shape_fast(fc_mesh, tol=tol, sew_tol=sew_tol, to_solid=to_solid)
+            shapes.append(shp)
+    else:
+        fc_mesh = Mesh.Mesh(src_for_freecad)
+        log(f"Loaded mesh: {src_for_freecad}, triangles={len(fc_mesh.Facets)}")
+        _freecad_mesh_quick_clean(fc_mesh)
+        shp = _mesh_to_shape_fast(fc_mesh, tol=tol, sew_tol=sew_tol, to_solid=to_solid)
+        shapes.append(shp)
+
+    _export_shapes_to_step(shapes, step_file_path)
+
+    for p in tmp_created_paths:
+        try:
+            os.remove(p)
+        except Exception:
+            pass
+
+    t1 = time.perf_counter()
+    log(f"STEP exported: {step_file_path} (took {t1 - t0:.2f}s)")
+    return step_file_path
+
+
+def main(obj_file_path,
+         output_dir=None,
+         tol=0.5,
+         sew_tol=0.05,
+         to_solid=False,
+         split_components=True,
+         max_components=64):
+    if output_dir is None:
+        output_dir = os.path.dirname(obj_file_path)
+
+    if output_dir:  # guard: dirname is empty when the input sits in the current directory
+        os.makedirs(output_dir, exist_ok=True)
+    base = os.path.splitext(os.path.basename(obj_file_path))[0]
+
+    # Build step_path with a unique filename
+    step_path = generate_unique_step_filename(base, output_dir)
+
+    return obj_to_step_fast(
+        obj_file_path,
+        step_file_path=step_path,  # pass the unique path here
+        tol=tol,
+        sew_tol=sew_tol,
+        to_solid=to_solid,
+        split_components=split_components,
+        max_components=max_components,
+    )
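+
+
+# Example invocation with FreeCAD's headless interpreter (binary name varies by
+# install: freecadcmd, FreeCADCmd, or freecad -c; positional order is
+# input, output, tol, sew_tol, solid, split, max_comp):
+#   freecadcmd 1_obj_to_step.py output/remeshed.obj output/ 0.5 0.05 0 1 64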
+
+
+# ── CLI adjusted accordingly ──
+def _parse_cli(argv):
+    in_path = os.path.abspath(argv[1])
+    out_arg = os.path.abspath(argv[2]) if len(argv) > 2 else None
+
+    tol = float(argv[3]) if len(argv) > 3 else 0.5
+    sew_tol = float(argv[4]) if len(argv) > 4 else 0.05
+    solid = bool(int(argv[5])) if len(argv) > 5 else False
+    split = bool(int(argv[6])) if len(argv) > 6 else True
+    max_comp = int(argv[7]) if len(argv) > 7 else 64
+
+    return in_path, out_arg, tol, sew_tol, solid, split, max_comp
+
+
+# ---------------------- CLI entry ----------------------
+
+if __name__ == "__main__":
+    in_path, out_arg, tol, sew_tol, solid, split, max_comp = _parse_cli(sys.argv)
+
+    # Decide whether the second argument is a directory or a .step file
+    # (out_arg may be None when no output argument was given)
+    out_is_step = out_arg is not None and out_arg.lower().endswith((".step", ".stp"))
+    if (not out_is_step) or os.path.isdir(out_arg):
+        out_dir = out_arg
+        if out_dir:
+            os.makedirs(out_dir, exist_ok=True)
+        result = main(
+            in_path,
+            output_dir=out_dir,
+            tol=tol,
+            sew_tol=sew_tol,
+            to_solid=solid,
+            split_components=split,
+            max_components=max_comp,
+        )
+    else:
+        os.makedirs(os.path.dirname(out_arg), exist_ok=True)
+        result = obj_to_step_fast(
+            in_path,
+            step_file_path=out_arg,
+            tol=tol,
+            sew_tol=sew_tol,
+            to_solid=solid,
+            split_components=split,
+            max_components=max_comp,
+        )
+
+    log("Done!")
+    log(f"STEP: {result}")
diff --git a/2_step_to_svg.py b/2_step_to_svg.py
new file mode 100644
index 0000000..7e6be34
--- /dev/null
+++ b/2_step_to_svg.py
@@ -0,0 +1,386 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+"""
+STEP -> OCC triangulate -> PyTorch3D orthographic 3 views -> SVG (silhouette + visible/hidden feature edges)
+Then combine 3 SVGs into a single SVG and a PNG.
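+
+Example (assumed paths):
+    python 2_step_to_svg.py --step_path output/model.step --out_dir output --res 1400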
+
+Views:
+    - top  : from +Z towards origin (view direction -Z)
+    - left : from -X towards origin (view direction +X)
+    - front: from +Y towards origin (view direction -Y)
+
+Outputs (default in --out_dir):
+    - top.svg, left.svg, front.svg
+    - combined_views.svg
+    - combined_views.png
+"""
+
+import os, sys, math, argparse, tempfile
+import numpy as np
+import torch
+from collections import defaultdict
+
+# ---------- PyTorch3D ----------
+from pytorch3d.structures import Meshes
+from pytorch3d.renderer import (
+    RasterizationSettings, MeshRasterizer,
+    MeshRenderer, BlendParams, SoftSilhouetteShader,
+    look_at_view_transform, OrthographicCameras
+)
+from pytorch3d.renderer.mesh.rasterizer import Fragments
+
+# ---------- SVG / Image ----------
+import svgwrite
+from skimage import measure
+
+# ---------- pythonocc (OCC) ----------
+from OCC.Core.STEPControl import STEPControl_Reader
+from OCC.Core.TopAbs import TopAbs_FACE
+from OCC.Core.TopExp import TopExp_Explorer
+from OCC.Core.BRep import BRep_Tool
+from OCC.Core.BRepMesh import BRepMesh_IncrementalMesh
+from OCC.Core.TopLoc import TopLoc_Location
+from OCC.Core.ShapeFix import ShapeFix_Shape
+
+# ---------- Combine SVG ----------
+import svgutils.transform as st
+from svgutils.compose import Unit
+import cairosvg
+
+
+# ---------------- Utils ----------------
+
+def autodevice():
+    return torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+
+def read_step_solid(path: str):
+    reader = STEPControl_Reader()
+    stat = reader.ReadFile(path)
+    if stat != 1:
+        raise RuntimeError(f"Failed to read STEP: {path}")
+    reader.TransferRoots()
+    shape = reader.OneShape()
+    fixer = ShapeFix_Shape(shape)
+    fixer.Perform()
+    return fixer.Shape()
+
+
+def triangulate(shape, deflection=0.5, angle=0.5):
+    """OCC triangulation: returns (V as np.float32 N×3, F as np.int64 M×3)."""
+    BRepMesh_IncrementalMesh(shape, deflection, False, angle, True)
+    verts, faces, v_off = [], [], 0
+    exp = TopExp_Explorer(shape, TopAbs_FACE)
+    while exp.More():
+        face = exp.Current()
+        loc = TopLoc_Location()
+        tri = BRep_Tool.Triangulation(face, loc)
+        if tri is not None:
+            nb_nodes = tri.NbNodes()
+            has_nodes_arr = hasattr(tri, "Nodes")
+            for i in range(1, nb_nodes + 1):
+                p = tri.Nodes().Value(i) if has_nodes_arr else tri.Node(i)
+                p = p.Transformed(loc.Transformation())
+                verts.append([p.X(), p.Y(), p.Z()])
+            nb_tris = tri.NbTriangles()
+            has_tris_arr = hasattr(tri, "Triangles")
+            for i in range(1, nb_tris + 1):
+                t = tri.Triangles().Value(i) if has_tris_arr else tri.Triangle(i)
+                a, b, c = t.Get()
+                faces.append([v_off + a - 1, v_off + b - 1, v_off + c - 1])
+            v_off += nb_nodes
+        exp.Next()
+    if not verts or not faces:
+        raise RuntimeError("Triangulation is empty; try a smaller --defl")
+    return np.asarray(verts, np.float32), np.asarray(faces, np.int64)
+
+
+def normalize_mesh_np(verts: np.ndarray, unit_scale=1.0):
+    c = verts.mean(axis=0, keepdims=True)
+    v0 = verts - c
+    s = np.max(np.abs(v0))
+    s = max(s, 1e-12)
+    return v0 / s * unit_scale
+
+
+def build_mesh(V_np, F_np, device):
+    V = torch.from_numpy(V_np).to(device)
+    F = torch.from_numpy(F_np).to(device)
+    return Meshes(verts=[V], faces=[F])
+
+
+def compute_face_normals(verts: torch.Tensor, faces: torch.Tensor):
+    v0 = verts[faces[:, 0]]
+    v1 = verts[faces[:, 1]]
+    v2 = verts[faces[:, 2]]
+    n = torch.cross(v1 - v0, v2 - v0, dim=1)
+    return torch.nn.functional.normalize(n, dim=1, eps=1e-12)
+
+
+def extract_feature_edges(faces: torch.Tensor, face_normals: torch.Tensor, angle_deg: float):
+    """Edges whose dihedral angle >= angle_deg are feature edges; boundary edges are always kept."""
+    thr = math.cos(math.radians(max(0.0, angle_deg)))
+    e2f = defaultdict(list)
+    F = faces.shape[0]
+    for f in range(F):
+        i, j, k = faces[f].tolist()
+        for a, b in ((i, j), (j, k), (k, i)):
+            e = (a, b) if a < b else (b, a)
+            e2f[e].append(f)
+    feat = []
+    for e, fl in e2f.items():
+        if len(fl) == 1:
+            feat.append(e)
+        elif len(fl) == 2:
+            n0, n1 = face_normals[fl[0]], face_normals[fl[1]]
+            cosv = torch.clamp((n0 * n1).sum(), -1.0, 1.0).item()
+            if cosv <= thr:
+                feat.append(e)
+    return feat
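+
+
+# Worked example: with angle_deg=65, thr = cos(65°) ≈ 0.4226. A shared edge is
+# kept as a feature edge when dot(n0, n1) <= 0.4226, i.e. when the two face
+# normals differ by at least 65°; edges owned by a single face (boundaries)
+# are always kept.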
+
+
+def raster_fragments(mesh: Meshes, cameras, image_size=1600, faces_per_pixel=1):
+    rs = RasterizationSettings(
+        image_size=image_size,
+        blur_radius=0.0,
+        faces_per_pixel=faces_per_pixel,
+        cull_backfaces=True
+    )
+    rast = MeshRasterizer(cameras=cameras, raster_settings=rs)
+    with torch.no_grad():
+        frags: Fragments = rast(mesh, cameras=cameras)
+    return frags, frags.zbuf[0, ..., 0]
+
+
+def render_silhouette(mesh: Meshes, cameras, image_size=1600):
+    rs = RasterizationSettings(
+        image_size=image_size,
+        blur_radius=1e-6,
+        faces_per_pixel=50,
+        cull_backfaces=True
+    )
+    renderer = MeshRenderer(
+        rasterizer=MeshRasterizer(cameras=cameras, raster_settings=rs),
+        shader=SoftSilhouetteShader(blend_params=BlendParams(sigma=1e-4, gamma=1e-4))
+    )
+    with torch.no_grad():
+        img = renderer(mesh, cameras=cameras)  # (1,H,W,4)
+    return np.clip(img[0, ..., 3].detach().cpu().numpy(), 0.0, 1.0)
+
+
+def project_points_to_screen(cameras, points_world: torch.Tensor, image_size: int):
+    with torch.no_grad():
+        scr = cameras.transform_points_screen(
+            points_world[None, ...],
+            image_size=((image_size, image_size),)
+        )
+    return scr[0]  # (N,3)
+
+
+def edge_visibility_split(edges_idx, verts_world, cameras, zbuf, image_size, eps=1e-4):
+    H = W = image_size
+    vis, hid = [], []
+    scr = project_points_to_screen(cameras, verts_world, image_size)
+    zmin = zbuf.detach().cpu().numpy()
+    for i, j in edges_idx:
+        p0, p1 = scr[i], scr[j]
+        m = 0.5 * (p0 + p1)
+        x = int(np.clip(m[0].item(), 0, W - 1))
+        y = int(np.clip(m[1].item(), 0, H - 1))
+        z_proj = m[2].item()
+        z_ref = zmin[y, x]
+        seg = (p0[0].item(), p0[1].item(), p1[0].item(), p1[1].item())
+        if np.isfinite(z_ref) and (z_proj <= z_ref + eps):
+            vis.append(seg)
+        else:
+            hid.append(seg)
+    return vis, hid
+
+
+def trace_silhouettes(alpha: np.ndarray, threshold=0.5, step=1):
+    contours = measure.find_contours(alpha, level=threshold)
+    polys = []
+    for cnt in contours:
+        if len(cnt) < 8:
+            continue
+        cnt = cnt[::max(1, step)]
+        polys.append(np.stack([cnt[:, 1], cnt[:, 0]], axis=1))  # (x,y)
+    return polys
+
+
+def svg_from_view(out_svg, W, H, silhouettes, edges_vis, edges_hid, stroke=2.0, margin=10, fill="none"):
+    dwg = svgwrite.Drawing(out_svg, size=(W + 2 * margin, H + 2 * margin))
+    dwg.add(dwg.rect(insert=(0, 0), size=(W + 2 * margin, H + 2 * margin), fill="none"))
+
+    # silhouette fill (optional)
+    for poly in silhouettes:
+        if len(poly) < 3:
+            continue
+        path = [f"M {poly[0, 0] + margin:.2f} {poly[0, 1] + margin:.2f}"]
+        for i in range(1, len(poly)):
+            path.append(f"L {poly[i, 0] + margin:.2f} {poly[i, 1] + margin:.2f}")
+        path.append("Z")
+        dwg.add(dwg.path(" ".join(path), fill=fill, stroke="none"))
+
+    # visible edges
+    for (x0, y0, x1, y1) in edges_vis:
+        dwg.add(dwg.line(
+            (x0 + margin, y0 + margin),
+            (x1 + margin, y1 + margin),
+            stroke="black",
+            stroke_width=stroke
+        ))
+
+    # hidden edges
+    for (x0, y0, x1, y1) in edges_hid:
+        dwg.add(dwg.line(
+            (x0 + margin, y0 + margin),
+            (x1 + margin, y1 + margin),
+            stroke="black",
+            stroke_width=max(1.0, 0.75 * stroke),
+            stroke_dasharray=[6, 6],
+            opacity=0.85
+        ))
+
+    dwg.save()
+
+
+# ---------------- Cameras ----------------
+
+def make_camera(view: str, device, dist=2.7):
+    """
+    look_at_view_transform spherical params for engineering views.
+    """
+    if view == "top":
+        # from +Z to origin (view -Z)
+        R, T = look_at_view_transform(dist=dist, elev=90.0, azim=180.0, device=device)
+    elif view == "left":
+        # from -X to origin (view +X)
+        R, T = look_at_view_transform(dist=dist, elev=0.0, azim=270.0, device=device)
+    elif view == "front":
+        # from +Y to origin (view -Y)
+        R, T = look_at_view_transform(dist=dist, elev=0.0, azim=180.0, device=device)
+    else:
+        raise ValueError("view must be one of ['top','left','front']")
+    return OrthographicCameras(R=R, T=T, device=device)
+
+
+# ---------------- Combine 3 SVGs ----------------
+
+def combine_svgs(view1_path, view2_path, view3_path, output_svg, output_image):
+    """
+    Tile three SVGs horizontally, and additionally export a PNG.
+    """
+    out_dir = os.path.dirname(output_svg)
+    if out_dir:
+        os.makedirs(out_dir, exist_ok=True)
+
+    view1 = st.fromfile(view1_path)
+    view2 = st.fromfile(view2_path)
+    view3 = st.fromfile(view3_path)
+
+    r1 = view1.getroot()
+    r2 = view2.getroot()
+    r3 = view3.getroot()
+
+    w1, h1 = [float(x.replace("px", "")) for x in view1.get_size()]
+    w2, h2 = [float(x.replace("px", "")) for x in view2.get_size()]
+    w3, h3 = [float(x.replace("px", "")) for x in view3.get_size()]
+
+    max_w = max(w1, w2, w3)
+    max_h = max(h1, h2, h3)
+
+    combined_width = max_w * 3 + 40
+    combined_height = max_h + 20
+
+    combined = st.SVGFigure(Unit(combined_width), Unit(combined_height))
+
+    x_offset = 10
+    y_offset = 10
+
+    r1.moveto(x_offset, y_offset)
+    x_offset += max_w + 20
+    r2.moveto(x_offset, y_offset)
+    x_offset += max_w + 20
+    r3.moveto(x_offset, y_offset)
+
+    combined.append([r1, r2, r3])
+    combined.save(output_svg)
+
+    # SVG -> PNG
+    cairosvg.svg2png(url=output_svg, write_to=output_image)
+
+
+# ---------------- Main ----------------
+
+def main():
+    ap = argparse.ArgumentParser()
+    ap.add_argument("--step_path", type=str, required=False, help="Input STEP file", default="output/sample_clean.step")
+    ap.add_argument("--out_dir", type=str, required=False, help="Output directory", default="output")
+
+    ap.add_argument("--res", type=int, default=1400, help="Render resolution (square)")
+    ap.add_argument("--defl", type=float, default=1.5, help="Linear deflection for triangulation")
+    ap.add_argument("--ang", type=float, default=65.0, help="Dihedral angle threshold (degrees)")
+    ap.add_argument("--stroke", type=float, default=1.3, help="SVG stroke width (px)")
+    ap.add_argument("--scale", type=float, default=1.0, help="Maximum extent of the model after normalization")
+
+    ap.add_argument("--no_combine", action="store_true", help="Export the three view SVGs only, without combining")
+    ap.add_argument("--combined_svg", type=str, default=None, help="Combined SVG output path (default out_dir/combined_views.svg)")
+    ap.add_argument("--combined_png", type=str, default=None, help="Combined PNG output path (default out_dir/combined_views.png)")
+
+    args = ap.parse_args()
+
+    os.makedirs(args.out_dir, exist_ok=True)
+    device = autodevice()
+    print(f"[Device] {device}")
+
+    # STEP -> mesh
+    shape = read_step_solid(args.step_path)
+    V_np, F_np = triangulate(shape, deflection=args.defl)
+    V_np = normalize_mesh_np(V_np, unit_scale=args.scale)
+
+    mesh = build_mesh(V_np, F_np, device)
+    verts = mesh.verts_packed()
+    faces = mesh.faces_packed()
+    fn = compute_face_normals(verts, faces)
+    feat_edges = extract_feature_edges(faces, fn, angle_deg=args.ang)
+
+    view_paths = {}
+    for name in ["top", "left", "front"]:
+        print(f"[View] {name}")
+        cam = make_camera(name, device)
+        alpha = render_silhouette(mesh, cam, image_size=args.res)
+        _, zbuf = raster_fragments(mesh, cam, image_size=args.res, faces_per_pixel=1)
+
+        e_vis, e_hid = edge_visibility_split(feat_edges, verts, cam, zbuf, args.res, eps=1e-4)
+        polys = trace_silhouettes(alpha, threshold=0.5, step=1)
+
+        out_svg = os.path.join(args.out_dir, f"{name}.svg")
+        svg_from_view(out_svg, args.res, args.res, polys, e_vis, e_hid,
+                      stroke=args.stroke, margin=10, fill="none")
+        view_paths[name] = out_svg
+        print(f" -> {out_svg} (edges: vis={len(e_vis)}, hid={len(e_hid)}, polys={len(polys)})")
+
+    if args.no_combine:
+        print("[Done] 3 views exported (no combine).")
+        return
+
+    combined_svg = args.combined_svg or os.path.join(args.out_dir, "combined_views.svg")
+    combined_png = args.combined_png or os.path.join(args.out_dir, "combined_views.png")
+
+    # Note: views are combined in the order front / top / left
+    combine_svgs(
+        view1_path=view_paths["front"],
+        view2_path=view_paths["top"],
+        view3_path=view_paths["left"],
+        output_svg=combined_svg,
+        output_image=combined_png
+    )
+    print(f"[Combined] SVG: {combined_svg}")
+    print(f"[Combined] PNG: {combined_png}")
+    print("[Done] All outputs exported.")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/app.py b/app.py
new file mode 100644
index 0000000..bd85946
--- /dev/null
+++ b/app.py
@@ -0,0 +1,403 @@
+import gradio as gr
+from gradio_litmodel3d import LitModel3D
+
+import os
+import shutil
+from typing import *
+import torch
+import numpy as np
+import imageio
+from easydict import EasyDict as edict
+from PIL import Image
+from trellis.pipelines import TrellisImageTo3DPipeline
+from trellis.representations import Gaussian, MeshExtractResult
+from trellis.utils import render_utils, postprocessing_utils
+
+
+MAX_SEED = np.iinfo(np.int32).max
+TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
+os.makedirs(TMP_DIR, exist_ok=True)
+
+
+def start_session(req: gr.Request):
+    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
+    os.makedirs(user_dir, exist_ok=True)
+
+
+def end_session(req: gr.Request):
+    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
+    shutil.rmtree(user_dir)
+
+
+def preprocess_image(image: Image.Image) -> Image.Image:
+    """
+    Preprocess the input image.
+
+    Args:
+        image (Image.Image): The input image.
+
+    Returns:
+        Image.Image: The preprocessed image.
+    """
+    processed_image = pipeline.preprocess_image(image)
+    return processed_image
+
+
+def preprocess_images(images: List[Tuple[Image.Image, str]]) -> List[Image.Image]:
+    """
+    Preprocess a list of input images.
+
+    Args:
+        images (List[Tuple[Image.Image, str]]): The input images.
+
+    Returns:
+        List[Image.Image]: The preprocessed images.
+    """
+    images = [image[0] for image in images]
+    processed_images = [pipeline.preprocess_image(image) for image in images]
+    return processed_images
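+
+
+# pack_state/unpack_state shuttle all tensors through plain numpy arrays so the
+# result can sit in a gr.State buffer between requests; unpack_state rebuilds
+# CUDA tensors from that dict on the next call.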
+def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
+    return {
+        'gaussian': {
+            **gs.init_params,
+            '_xyz': gs._xyz.cpu().numpy(),
+            '_features_dc': gs._features_dc.cpu().numpy(),
+            '_scaling': gs._scaling.cpu().numpy(),
+            '_rotation': gs._rotation.cpu().numpy(),
+            '_opacity': gs._opacity.cpu().numpy(),
+        },
+        'mesh': {
+            'vertices': mesh.vertices.cpu().numpy(),
+            'faces': mesh.faces.cpu().numpy(),
+        },
+    }
+
+
+def unpack_state(state: dict) -> Tuple[Gaussian, edict]:
+    gs = Gaussian(
+        aabb=state['gaussian']['aabb'],
+        sh_degree=state['gaussian']['sh_degree'],
+        mininum_kernel_size=state['gaussian']['mininum_kernel_size'],
+        scaling_bias=state['gaussian']['scaling_bias'],
+        opacity_bias=state['gaussian']['opacity_bias'],
+        scaling_activation=state['gaussian']['scaling_activation'],
+    )
+    gs._xyz = torch.tensor(state['gaussian']['_xyz'], device='cuda')
+    gs._features_dc = torch.tensor(state['gaussian']['_features_dc'], device='cuda')
+    gs._scaling = torch.tensor(state['gaussian']['_scaling'], device='cuda')
+    gs._rotation = torch.tensor(state['gaussian']['_rotation'], device='cuda')
+    gs._opacity = torch.tensor(state['gaussian']['_opacity'], device='cuda')
+
+    mesh = edict(
+        vertices=torch.tensor(state['mesh']['vertices'], device='cuda'),
+        faces=torch.tensor(state['mesh']['faces'], device='cuda'),
+    )
+
+    return gs, mesh
+
+
+def get_seed(randomize_seed: bool, seed: int) -> int:
+    """
+    Get the random seed.
+    """
+    return np.random.randint(0, MAX_SEED) if randomize_seed else seed
+
+
+def image_to_3d(
+    image: Image.Image,
+    multiimages: List[Tuple[Image.Image, str]],
+    is_multiimage: bool,
+    seed: int,
+    ss_guidance_strength: float,
+    ss_sampling_steps: int,
+    slat_guidance_strength: float,
+    slat_sampling_steps: int,
+    multiimage_algo: Literal["multidiffusion", "stochastic"],
+    req: gr.Request,
+) -> Tuple[dict, str]:
+    """
+    Convert an image to a 3D model.
+
+    Args:
+        image (Image.Image): The input image.
+        multiimages (List[Tuple[Image.Image, str]]): The input images in multi-image mode.
+        is_multiimage (bool): Whether multi-image mode is active.
+        seed (int): The random seed.
+        ss_guidance_strength (float): The guidance strength for sparse structure generation.
+        ss_sampling_steps (int): The number of sampling steps for sparse structure generation.
+        slat_guidance_strength (float): The guidance strength for structured latent generation.
+        slat_sampling_steps (int): The number of sampling steps for structured latent generation.
+        multiimage_algo (Literal["multidiffusion", "stochastic"]): The algorithm for multi-image generation.
+
+    Returns:
+        dict: The information of the generated 3D model.
+        str: The path to the video of the 3D model.
+    """
+    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
+    if not is_multiimage:
+        outputs = pipeline.run(
+            image,
+            seed=seed,
+            formats=["gaussian", "mesh"],
+            preprocess_image=False,
+            sparse_structure_sampler_params={
+                "steps": ss_sampling_steps,
+                "cfg_strength": ss_guidance_strength,
+            },
+            slat_sampler_params={
+                "steps": slat_sampling_steps,
+                "cfg_strength": slat_guidance_strength,
+            },
+        )
+    else:
+        outputs = pipeline.run_multi_image(
+            [image[0] for image in multiimages],
+            seed=seed,
+            formats=["gaussian", "mesh"],
+            preprocess_image=False,
+            sparse_structure_sampler_params={
+                "steps": ss_sampling_steps,
+                "cfg_strength": ss_guidance_strength,
+            },
+            slat_sampler_params={
+                "steps": slat_sampling_steps,
+                "cfg_strength": slat_guidance_strength,
+            },
+            mode=multiimage_algo,
+        )
+    video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
+    video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
+    video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
+    video_path = os.path.join(user_dir, 'sample.mp4')
+    imageio.mimsave(video_path, video, fps=15)
+    state = pack_state(outputs['gaussian'][0], outputs['mesh'][0])
+    torch.cuda.empty_cache()
+    return state, video_path
+
+
+def extract_glb(
+    state: dict,
+    mesh_simplify: float,
+    texture_size: int,
+    req: gr.Request,
+) -> Tuple[str, str]:
+    """
+    Extract a GLB file from the 3D model.
+
+    Args:
+        state (dict): The state of the generated 3D model.
+        mesh_simplify (float): The mesh simplification factor.
+        texture_size (int): The texture resolution.
+
+    Returns:
+        Tuple[str, str]: The path to the extracted GLB file (returned twice: once
+        for the model viewer and once for the download button).
+    """
+    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
+    gs, mesh = unpack_state(state)
+    glb = postprocessing_utils.to_glb(gs, mesh, simplify=mesh_simplify, texture_size=texture_size, verbose=False)
+    glb_path = os.path.join(user_dir, 'sample.glb')
+    glb.export(glb_path)
+    torch.cuda.empty_cache()
+    return glb_path, glb_path
+
+
+def extract_gaussian(state: dict, req: gr.Request) -> Tuple[str, str]:
+    """
+    Extract a Gaussian file from the 3D model.
+
+    Args:
+        state (dict): The state of the generated 3D model.
+
+    Returns:
+        Tuple[str, str]: The path to the extracted Gaussian file (returned twice:
+        once for the model viewer and once for the download button).
+    """
+    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
+    gs, _ = unpack_state(state)
+    gaussian_path = os.path.join(user_dir, 'sample.ply')
+    gs.save_ply(gaussian_path)
+    torch.cuda.empty_cache()
+    return gaussian_path, gaussian_path
+
+
+def prepare_multi_example() -> List[Image.Image]:
+    multi_case = list(set([i.split('_')[0] for i in os.listdir("assets/example_multi_image")]))
+    images = []
+    for case in multi_case:
+        _images = []
+        for i in range(1, 4):
+            img = Image.open(f'assets/example_multi_image/{case}_{i}.png')
+            W, H = img.size
+            img = img.resize((int(W / H * 512), 512))
+            _images.append(np.array(img))
+        images.append(Image.fromarray(np.concatenate(_images, axis=1)))
+    return images
+
+
+def split_image(image: Image.Image) -> List[Image.Image]:
+    """
+    Split an image into multiple views.
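+
+    The input is assumed to be a horizontally concatenated RGBA sheet; runs of
+    fully transparent columns act as separators between the individual views.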
+    """
+    image = np.array(image)
+    alpha = image[..., 3]
+    alpha = np.any(alpha > 0, axis=0)
+    start_pos = np.where(~alpha[:-1] & alpha[1:])[0].tolist()
+    end_pos = np.where(alpha[:-1] & ~alpha[1:])[0].tolist()
+    images = []
+    for s, e in zip(start_pos, end_pos):
+        images.append(Image.fromarray(image[:, s:e+1]))
+    return [preprocess_image(image) for image in images]
+
+
+with gr.Blocks(delete_cache=(600, 600)) as demo:
+    gr.Markdown("""
+    ## Image to 3D Asset with [TRELLIS](https://trellis3d.github.io/)
+    * Upload an image and click "Generate" to create a 3D asset. If the image has an alpha channel, it will be used as the mask. Otherwise, we use `rembg` to remove the background.
+    * If you find the generated 3D asset satisfactory, click "Extract GLB" to extract the GLB file and download it.
+    """)
+
+    with gr.Row():
+        with gr.Column():
+            with gr.Tabs() as input_tabs:
+                with gr.Tab(label="Single Image", id=0) as single_image_input_tab:
+                    image_prompt = gr.Image(label="Image Prompt", format="png", image_mode="RGBA", type="pil", height=300)
+                with gr.Tab(label="Multiple Images", id=1) as multiimage_input_tab:
+                    multiimage_prompt = gr.Gallery(label="Image Prompt", format="png", type="pil", height=300, columns=3)
+                    gr.Markdown("""
+                        Input different views of the object in separate images.
+
+                        *NOTE: this is an experimental algorithm without a specially trained model. It may not produce the best results for all images, especially those with differing poses or inconsistent details.*
+                    """)
+
+            with gr.Accordion(label="Generation Settings", open=False):
+                seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
+                randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
+                gr.Markdown("Stage 1: Sparse Structure Generation")
+                with gr.Row():
+                    ss_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
+                    ss_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
+                gr.Markdown("Stage 2: Structured Latent Generation")
+                with gr.Row():
+                    slat_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=3.0, step=0.1)
+                    slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
+                multiimage_algo = gr.Radio(["stochastic", "multidiffusion"], label="Multi-image Algorithm", value="stochastic")
+
+            generate_btn = gr.Button("Generate")
+
+            with gr.Accordion(label="GLB Extraction Settings", open=False):
+                mesh_simplify = gr.Slider(0.9, 0.98, label="Simplify", value=0.95, step=0.01)
+                texture_size = gr.Slider(512, 2048, label="Texture Size", value=1024, step=512)
+
+            with gr.Row():
+                extract_glb_btn = gr.Button("Extract GLB", interactive=False)
+                extract_gs_btn = gr.Button("Extract Gaussian", interactive=False)
+            gr.Markdown("""
+                        *NOTE: the Gaussian file can be very large (~50MB), so it may take a while to display and download.*
+                        """)
+
+        with gr.Column():
+            video_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300)
+            model_output = LitModel3D(label="Extracted GLB/Gaussian", exposure=10.0, height=300)
+
+            with gr.Row():
+                download_glb = gr.DownloadButton(label="Download GLB", interactive=False)
+                download_gs = gr.DownloadButton(label="Download Gaussian", interactive=False)
+
+    is_multiimage = gr.State(False)
+    output_buf = gr.State()
+
+    # Example images at the bottom of the page
+    with gr.Row() as single_image_example:
+        examples = gr.Examples(
+            examples=[
+                f'assets/example_image/{image}'
+                for image in os.listdir("assets/example_image")
+            ],
+            inputs=[image_prompt],
+            fn=preprocess_image,
+            outputs=[image_prompt],
+            run_on_click=True,
+            examples_per_page=64,
+        )
+    with gr.Row(visible=False) as multiimage_example:
+        examples_multi = gr.Examples(
+            examples=prepare_multi_example(),
+            inputs=[image_prompt],
+            fn=split_image,
+            outputs=[multiimage_prompt],
+            run_on_click=True,
+            examples_per_page=8,
+        )
+
+    # Handlers
+    demo.load(start_session)
+    demo.unload(end_session)
+
+    single_image_input_tab.select(
+        lambda: tuple([False, gr.Row.update(visible=True), gr.Row.update(visible=False)]),
+        outputs=[is_multiimage, single_image_example, multiimage_example]
+    )
+    multiimage_input_tab.select(
+        lambda: tuple([True, gr.Row.update(visible=False), gr.Row.update(visible=True)]),
+        outputs=[is_multiimage, single_image_example, multiimage_example]
+    )
+
+    image_prompt.upload(
+        preprocess_image,
+        inputs=[image_prompt],
+        outputs=[image_prompt],
+    )
+    multiimage_prompt.upload(
+        preprocess_images,
+        inputs=[multiimage_prompt],
+        outputs=[multiimage_prompt],
+    )
+
+    generate_btn.click(
+        get_seed,
+        inputs=[randomize_seed, seed],
+        outputs=[seed],
+    ).then(
+        image_to_3d,
+        inputs=[image_prompt, multiimage_prompt, is_multiimage, seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps, multiimage_algo],
+        outputs=[output_buf, video_output],
+    ).then(
+        lambda: tuple([gr.Button(interactive=True), gr.Button(interactive=True)]),
+        outputs=[extract_glb_btn, extract_gs_btn],
+    )
+
+    video_output.clear(
+        lambda: tuple([gr.Button(interactive=False), gr.Button(interactive=False)]),
+        outputs=[extract_glb_btn, extract_gs_btn],
+    )
+
+    extract_glb_btn.click(
+        extract_glb,
+        inputs=[output_buf, mesh_simplify, texture_size],
+        outputs=[model_output, download_glb],
+    ).then(
+        lambda: gr.Button(interactive=True),
+        outputs=[download_glb],
+    )
+
+    extract_gs_btn.click(
+        extract_gaussian,
+        inputs=[output_buf],
+        outputs=[model_output, download_gs],
+    ).then(
+        lambda: gr.Button(interactive=True),
+        outputs=[download_gs],
+    )
+
+    model_output.clear(
+        lambda: gr.Button(interactive=False),
+        outputs=[download_glb],
+    )
+
+
+# Launch the Gradio app
+if __name__ == "__main__":
+    pipeline = TrellisImageTo3DPipeline.from_pretrained("microsoft/TRELLIS-image-large")
+    pipeline.cuda()
+    demo.launch()
diff --git a/app_text.py b/app_text.py
new file mode 100644
index 0000000..56b414b
--- /dev/null
+++ b/app_text.py
@@ -0,0 +1,266 @@
+import gradio as gr
+from gradio_litmodel3d import LitModel3D
+
+import os
+import shutil
+from typing import *
+import torch
+import numpy as np
+import imageio
+from easydict import EasyDict as edict
+from trellis.pipelines import TrellisTextTo3DPipeline
+from trellis.representations import Gaussian, MeshExtractResult
+from trellis.utils import render_utils, postprocessing_utils
+
+
+MAX_SEED = np.iinfo(np.int32).max
+TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
+os.makedirs(TMP_DIR, exist_ok=True)
+
+
+def start_session(req: gr.Request):
+    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
+    os.makedirs(user_dir, exist_ok=True)
+
+
+def end_session(req: gr.Request):
+    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
+    shutil.rmtree(user_dir)
+
+
+def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
+    return {
+        'gaussian': {
+            **gs.init_params,
+            '_xyz': gs._xyz.cpu().numpy(),
+            '_features_dc': gs._features_dc.cpu().numpy(),
+            '_scaling': gs._scaling.cpu().numpy(),
+            '_rotation': gs._rotation.cpu().numpy(),
+            '_opacity': gs._opacity.cpu().numpy(),
+        },
+        'mesh': {
+            'vertices': mesh.vertices.cpu().numpy(),
+            'faces': mesh.faces.cpu().numpy(),
+        },
+    }
+
+
+def unpack_state(state: dict) -> Tuple[Gaussian, edict]:
+    gs = Gaussian(
+        aabb=state['gaussian']['aabb'],
+        sh_degree=state['gaussian']['sh_degree'],
+        mininum_kernel_size=state['gaussian']['mininum_kernel_size'],
+        scaling_bias=state['gaussian']['scaling_bias'],
+        opacity_bias=state['gaussian']['opacity_bias'],
+        scaling_activation=state['gaussian']['scaling_activation'],
+    )
+    gs._xyz = torch.tensor(state['gaussian']['_xyz'], device='cuda')
+    gs._features_dc = torch.tensor(state['gaussian']['_features_dc'], device='cuda')
+    gs._scaling = torch.tensor(state['gaussian']['_scaling'], device='cuda')
+    gs._rotation = torch.tensor(state['gaussian']['_rotation'], device='cuda')
+    gs._opacity = torch.tensor(state['gaussian']['_opacity'], device='cuda')
+
+    mesh = edict(
+        vertices=torch.tensor(state['mesh']['vertices'], device='cuda'),
+        faces=torch.tensor(state['mesh']['faces'], device='cuda'),
+    )
+
+    return gs, mesh
+
+
+def get_seed(randomize_seed: bool, seed: int) -> int:
+    """
+    Get the random seed.
+    """
+    return np.random.randint(0, MAX_SEED) if randomize_seed else seed
+
+
+def text_to_3d(
+    prompt: str,
+    seed: int,
+    ss_guidance_strength: float,
+    ss_sampling_steps: int,
+    slat_guidance_strength: float,
+    slat_sampling_steps: int,
+    req: gr.Request,
+) -> Tuple[dict, str]:
+    """
+    Convert a text prompt to a 3D model.
+
+    Args:
+        prompt (str): The text prompt.
+        seed (int): The random seed.
+        ss_guidance_strength (float): The guidance strength for sparse structure generation.
+        ss_sampling_steps (int): The number of sampling steps for sparse structure generation.
+        slat_guidance_strength (float): The guidance strength for structured latent generation.
+        slat_sampling_steps (int): The number of sampling steps for structured latent generation.
+
+    Returns:
+        dict: The information of the generated 3D model.
+        str: The path to the video of the 3D model.
+    """
+    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
+    outputs = pipeline.run(
+        prompt,
+        seed=seed,
+        formats=["gaussian", "mesh"],
+        sparse_structure_sampler_params={
+            "steps": ss_sampling_steps,
+            "cfg_strength": ss_guidance_strength,
+        },
+        slat_sampler_params={
+            "steps": slat_sampling_steps,
+            "cfg_strength": slat_guidance_strength,
+        },
+    )
+    video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
+    video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
+    video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
+    video_path = os.path.join(user_dir, 'sample.mp4')
+    imageio.mimsave(video_path, video, fps=15)
+    state = pack_state(outputs['gaussian'][0], outputs['mesh'][0])
+    torch.cuda.empty_cache()
+    return state, video_path
+
+
+def extract_glb(
+    state: dict,
+    mesh_simplify: float,
+    texture_size: int,
+    req: gr.Request,
+) -> Tuple[str, str]:
+    """
+    Extract a GLB file from the 3D model.
+
+    Args:
+        state (dict): The state of the generated 3D model.
+        mesh_simplify (float): The mesh simplification factor.
+        texture_size (int): The texture resolution.
+
+    Returns:
+        Tuple[str, str]: The path to the extracted GLB file (returned twice: once
+        for the model viewer and once for the download button).
+ """ + user_dir = os.path.join(TMP_DIR, str(req.session_hash)) + gs, mesh = unpack_state(state) + glb = postprocessing_utils.to_glb(gs, mesh, simplify=mesh_simplify, texture_size=texture_size, verbose=False) + glb_path = os.path.join(user_dir, 'sample.glb') + glb.export(glb_path) + torch.cuda.empty_cache() + return glb_path, glb_path + + +def extract_gaussian(state: dict, req: gr.Request) -> Tuple[str, str]: + """ + Extract a Gaussian file from the 3D model. + + Args: + state (dict): The state of the generated 3D model. + + Returns: + str: The path to the extracted Gaussian file. + """ + user_dir = os.path.join(TMP_DIR, str(req.session_hash)) + gs, _ = unpack_state(state) + gaussian_path = os.path.join(user_dir, 'sample.ply') + gs.save_ply(gaussian_path) + torch.cuda.empty_cache() + return gaussian_path, gaussian_path + + +with gr.Blocks(delete_cache=(600, 600)) as demo: + gr.Markdown(""" + ## Text to 3D Asset with [TRELLIS](https://trellis3d.github.io/) + * Type a text prompt and click "Generate" to create a 3D asset. + * If you find the generated 3D asset satisfactory, click "Extract GLB" to extract the GLB file and download it. + """) + + with gr.Row(): + with gr.Column(): + text_prompt = gr.Textbox(label="Text Prompt", lines=5) + + with gr.Accordion(label="Generation Settings", open=False): + seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1) + randomize_seed = gr.Checkbox(label="Randomize Seed", value=True) + gr.Markdown("Stage 1: Sparse Structure Generation") + with gr.Row(): + ss_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=7.5, step=0.1) + ss_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=25, step=1) + gr.Markdown("Stage 2: Structured Latent Generation") + with gr.Row(): + slat_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=7.5, step=0.1) + slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=25, step=1) + + generate_btn = gr.Button("Generate") + + with gr.Accordion(label="GLB Extraction Settings", open=False): + mesh_simplify = gr.Slider(0.9, 0.98, label="Simplify", value=0.95, step=0.01) + texture_size = gr.Slider(512, 2048, label="Texture Size", value=1024, step=512) + + with gr.Row(): + extract_glb_btn = gr.Button("Extract GLB", interactive=False) + extract_gs_btn = gr.Button("Extract Gaussian", interactive=False) + gr.Markdown(""" + *NOTE: Gaussian file can be very large (~50MB), it will take a while to display and download.* + """) + + with gr.Column(): + video_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300) + model_output = LitModel3D(label="Extracted GLB/Gaussian", exposure=10.0, height=300) + + with gr.Row(): + download_glb = gr.DownloadButton(label="Download GLB", interactive=False) + download_gs = gr.DownloadButton(label="Download Gaussian", interactive=False) + + output_buf = gr.State() + + # Handlers + demo.load(start_session) + demo.unload(end_session) + + generate_btn.click( + get_seed, + inputs=[randomize_seed, seed], + outputs=[seed], + ).then( + text_to_3d, + inputs=[text_prompt, seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps], + outputs=[output_buf, video_output], + ).then( + lambda: tuple([gr.Button(interactive=True), gr.Button(interactive=True)]), + outputs=[extract_glb_btn, extract_gs_btn], + ) + + video_output.clear( + lambda: tuple([gr.Button(interactive=False), gr.Button(interactive=False)]), + outputs=[extract_glb_btn, extract_gs_btn], + ) + + extract_glb_btn.click( + 
extract_glb, + inputs=[output_buf, mesh_simplify, texture_size], + outputs=[model_output, download_glb], + ).then( + lambda: gr.Button(interactive=True), + outputs=[download_glb], + ) + + extract_gs_btn.click( + extract_gaussian, + inputs=[output_buf], + outputs=[model_output, download_gs], + ).then( + lambda: gr.Button(interactive=True), + outputs=[download_gs], + ) + + model_output.clear( + lambda: gr.Button(interactive=False), + outputs=[download_glb], + ) + + +# Launch the Gradio app +if __name__ == "__main__": + pipeline = TrellisTextTo3DPipeline.from_pretrained("microsoft/TRELLIS-text-xlarge") + pipeline.cuda() + demo.launch() diff --git a/blender_glb_to_obj.py b/blender_glb_to_obj.py new file mode 100644 index 0000000..ba1c1eb --- /dev/null +++ b/blender_glb_to_obj.py @@ -0,0 +1,48 @@ +import bpy +import sys +import os + +argv = sys.argv +argv = argv[argv.index("--") + 1:] + +glb_input = argv[0] +obj_output = argv[1] + + +def clean(): + bpy.ops.object.select_all(action='SELECT') + bpy.ops.object.delete(use_global=False) + + +clean() + +bpy.ops.import_scene.gltf(filepath=glb_input) + +meshes = [o for o in bpy.context.scene.objects if o.type == "MESH"] + +if not meshes: + raise RuntimeError("No mesh objects found in GLB/GLTF.") + +bpy.ops.object.select_all(action='DESELECT') + +for o in meshes: + o.select_set(True) + +bpy.context.view_layer.objects.active = meshes[0] + +if len(meshes) > 1: + bpy.ops.object.join() + +obj = bpy.context.view_layer.objects.active +obj.select_set(True) + +os.makedirs(os.path.dirname(obj_output), exist_ok=True) + +bpy.ops.wm.obj_export( + filepath=obj_output, + export_selected_objects=True, + export_normals=True, + export_uv=True, +) + +print("OBJ exported:", obj_output) \ No newline at end of file diff --git a/client.py b/client.py new file mode 100644 index 0000000..5edcbe9 --- /dev/null +++ b/client.py @@ -0,0 +1,7 @@ +# This file is auto-generated by LitServe. +# Disable auto-generation by setting `generate_client_file=False` in `LitServer.run()`. + +import requests + +response = requests.post("http://127.0.0.1:8000/predict", json={"input": 4.0}) +print(f"Status: {response.status_code}\nResponse:\n {response.text}") diff --git a/dataset_toolkits/build_metadata.py b/dataset_toolkits/build_metadata.py new file mode 100644 index 0000000..e2a314c --- /dev/null +++ b/dataset_toolkits/build_metadata.py @@ -0,0 +1,285 @@ +import os +import shutil +import sys +import time +import importlib +import argparse +import numpy as np +import pandas as pd +from tqdm import tqdm +from easydict import EasyDict as edict +from concurrent.futures import ThreadPoolExecutor +import utils3d + +def get_first_directory(path): + with os.scandir(path) as it: + for entry in it: + if entry.is_dir(): + return entry.name + return None + +def need_process(key): + return key in opt.field or opt.field == ['all'] + +if __name__ == '__main__': + dataset_utils = importlib.import_module(f'datasets.{sys.argv[1]}') + + parser = argparse.ArgumentParser() + parser.add_argument('--output_dir', type=str, required=True, + help='Directory to save the metadata') + parser.add_argument('--field', type=str, default='all', + help='Fields to process, separated by commas') + parser.add_argument('--from_file', action='store_true', + help='Build metadata from file instead of from records of processings.' 
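+    # Usage sketch (illustrative paths; the positional module name selects
+    # dataset_toolkits/datasets/<name>.py):
+    #   python dataset_toolkits/build_metadata.py ABO --output_dir datasets/ABO --field rendered,voxelized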
+    dataset_utils.add_args(parser)
+    opt = parser.parse_args(sys.argv[2:])
+    opt = edict(vars(opt))
+
+    os.makedirs(opt.output_dir, exist_ok=True)
+    os.makedirs(os.path.join(opt.output_dir, 'merged_records'), exist_ok=True)
+
+    opt.field = opt.field.split(',')
+
+    timestamp = str(int(time.time()))
+
+    # get file list
+    if os.path.exists(os.path.join(opt.output_dir, 'metadata.csv')):
+        print('Loading previous metadata...')
+        metadata = pd.read_csv(os.path.join(opt.output_dir, 'metadata.csv'))
+    else:
+        metadata = dataset_utils.get_metadata(**opt)
+    metadata.set_index('sha256', inplace=True)
+
+    # merge downloaded
+    df_files = [f for f in os.listdir(opt.output_dir) if f.startswith('downloaded_') and f.endswith('.csv')]
+    df_parts = []
+    for f in df_files:
+        try:
+            df_parts.append(pd.read_csv(os.path.join(opt.output_dir, f)))
+        except Exception:
+            pass
+    if len(df_parts) > 0:
+        df = pd.concat(df_parts)
+        df.set_index('sha256', inplace=True)
+        if 'local_path' in metadata.columns:
+            metadata.update(df, overwrite=True)
+        else:
+            metadata = metadata.join(df, on='sha256', how='left')
+        for f in df_files:
+            shutil.move(os.path.join(opt.output_dir, f), os.path.join(opt.output_dir, 'merged_records', f'{timestamp}_{f}'))
+
+    # detect models
+    image_models = []
+    if os.path.exists(os.path.join(opt.output_dir, 'features')):
+        image_models = os.listdir(os.path.join(opt.output_dir, 'features'))
+    latent_models = []
+    if os.path.exists(os.path.join(opt.output_dir, 'latents')):
+        latent_models = os.listdir(os.path.join(opt.output_dir, 'latents'))
+    ss_latent_models = []
+    if os.path.exists(os.path.join(opt.output_dir, 'ss_latents')):
+        ss_latent_models = os.listdir(os.path.join(opt.output_dir, 'ss_latents'))
+    print(f'Image models: {image_models}')
+    print(f'Latent models: {latent_models}')
+    print(f'Sparse Structure latent models: {ss_latent_models}')
+
+    if 'rendered' not in metadata.columns:
+        metadata['rendered'] = [False] * len(metadata)
+    if 'voxelized' not in metadata.columns:
+        metadata['voxelized'] = [False] * len(metadata)
+    if 'num_voxels' not in metadata.columns:
+        metadata['num_voxels'] = [0] * len(metadata)
+    if 'cond_rendered' not in metadata.columns:
+        metadata['cond_rendered'] = [False] * len(metadata)
+    for model in image_models:
+        if f'feature_{model}' not in metadata.columns:
+            metadata[f'feature_{model}'] = [False] * len(metadata)
+    for model in latent_models:
+        if f'latent_{model}' not in metadata.columns:
+            metadata[f'latent_{model}'] = [False] * len(metadata)
+    for model in ss_latent_models:
+        if f'ss_latent_{model}' not in metadata.columns:
+            metadata[f'ss_latent_{model}'] = [False] * len(metadata)
+
+    # merge rendered
+    df_files = [f for f in os.listdir(opt.output_dir) if f.startswith('rendered_') and f.endswith('.csv')]
+    df_parts = []
+    for f in df_files:
+        try:
+            df_parts.append(pd.read_csv(os.path.join(opt.output_dir, f)))
+        except Exception:
+            pass
+    if len(df_parts) > 0:
+        df = pd.concat(df_parts)
+        df.set_index('sha256', inplace=True)
+        metadata.update(df, overwrite=True)
+        for f in df_files:
+            shutil.move(os.path.join(opt.output_dir, f), os.path.join(opt.output_dir, 'merged_records', f'{timestamp}_{f}'))
+
+    # merge aesthetic scores
+    df_files = [f for f in os.listdir(opt.output_dir) if f.startswith('aesthetic_scores_') and f.endswith('.csv')]
+    df_parts = []
+    for f in df_files:
+        try:
+            df_parts.append(pd.read_csv(os.path.join(opt.output_dir, f)))
+        except Exception:
+            pass
+    if len(df_parts) > 0:
+        df = pd.concat(df_parts)
+        df.set_index('sha256', inplace=True)
+        metadata.update(df, overwrite=True)
+        for f in df_files:
+            shutil.move(os.path.join(opt.output_dir, f), os.path.join(opt.output_dir, 'merged_records', f'{timestamp}_{f}'))
+
+    # merge voxelized
+    df_files = [f for f in os.listdir(opt.output_dir) if f.startswith('voxelized_') and f.endswith('.csv')]
+    df_parts = []
+    for f in df_files:
+        try:
+            df_parts.append(pd.read_csv(os.path.join(opt.output_dir, f)))
+        except Exception:
+            pass
+    if len(df_parts) > 0:
+        df = pd.concat(df_parts)
+        df.set_index('sha256', inplace=True)
+        metadata.update(df, overwrite=True)
+        for f in df_files:
+            shutil.move(os.path.join(opt.output_dir, f), os.path.join(opt.output_dir, 'merged_records', f'{timestamp}_{f}'))
+
+    # merge cond_rendered
+    df_files = [f for f in os.listdir(opt.output_dir) if f.startswith('cond_rendered_') and f.endswith('.csv')]
+    df_parts = []
+    for f in df_files:
+        try:
+            df_parts.append(pd.read_csv(os.path.join(opt.output_dir, f)))
+        except Exception:
+            pass
+    if len(df_parts) > 0:
+        df = pd.concat(df_parts)
+        df.set_index('sha256', inplace=True)
+        metadata.update(df, overwrite=True)
+        for f in df_files:
+            shutil.move(os.path.join(opt.output_dir, f), os.path.join(opt.output_dir, 'merged_records', f'{timestamp}_{f}'))
+
+    # merge features
+    for model in image_models:
+        df_files = [f for f in os.listdir(opt.output_dir) if f.startswith(f'feature_{model}_') and f.endswith('.csv')]
+        df_parts = []
+        for f in df_files:
+            try:
+                df_parts.append(pd.read_csv(os.path.join(opt.output_dir, f)))
+            except Exception:
+                pass
+        if len(df_parts) > 0:
+            df = pd.concat(df_parts)
+            df.set_index('sha256', inplace=True)
+            metadata.update(df, overwrite=True)
+            for f in df_files:
+                shutil.move(os.path.join(opt.output_dir, f), os.path.join(opt.output_dir, 'merged_records', f'{timestamp}_{f}'))
+
+    # merge latents
+    for model in latent_models:
+        df_files = [f for f in os.listdir(opt.output_dir) if f.startswith(f'latent_{model}_') and f.endswith('.csv')]
+        df_parts = []
+        for f in df_files:
+            try:
+                df_parts.append(pd.read_csv(os.path.join(opt.output_dir, f)))
+            except Exception:
+                pass
+        if len(df_parts) > 0:
+            df = pd.concat(df_parts)
+            df.set_index('sha256', inplace=True)
+            metadata.update(df, overwrite=True)
+            for f in df_files:
+                shutil.move(os.path.join(opt.output_dir, f), os.path.join(opt.output_dir, 'merged_records', f'{timestamp}_{f}'))
+
+    # merge sparse structure latents
+    for model in ss_latent_models:
+        df_files = [f for f in os.listdir(opt.output_dir) if f.startswith(f'ss_latent_{model}_') and f.endswith('.csv')]
+        df_parts = []
+        for f in df_files:
+            try:
+                df_parts.append(pd.read_csv(os.path.join(opt.output_dir, f)))
+            except Exception:
+                pass
+        if len(df_parts) > 0:
+            df = pd.concat(df_parts)
+            df.set_index('sha256', inplace=True)
+            metadata.update(df, overwrite=True)
+            for f in df_files:
+                shutil.move(os.path.join(opt.output_dir, f), os.path.join(opt.output_dir, 'merged_records', f'{timestamp}_{f}'))
+
+    # build metadata from files
+    if opt.from_file:
+        with ThreadPoolExecutor(max_workers=os.cpu_count()) as executor, \
+            tqdm(total=len(metadata), desc="Building metadata") as pbar:
+            def worker(sha256):
+                try:
+                    if need_process('rendered') and metadata.loc[sha256, 'rendered'] == False and \
+                        os.path.exists(os.path.join(opt.output_dir, 'renders', sha256, 'transforms.json')):
+                        metadata.loc[sha256, 'rendered'] = True
+                    if need_process('voxelized') and metadata.loc[sha256, 'rendered'] == True and metadata.loc[sha256, 'voxelized'] == False and \
+                        os.path.exists(os.path.join(opt.output_dir, 'voxels', f'{sha256}.ply')):
+                        try:
+                            pts = utils3d.io.read_ply(os.path.join(opt.output_dir, 'voxels', f'{sha256}.ply'))[0]
+                            metadata.loc[sha256, 'voxelized'] = True
+                            metadata.loc[sha256, 'num_voxels'] = len(pts)
+                        except Exception:
+                            pass
+                    if need_process('cond_rendered') and metadata.loc[sha256, 'cond_rendered'] == False and \
+                        os.path.exists(os.path.join(opt.output_dir, 'renders_cond', sha256, 'transforms.json')):
+                        metadata.loc[sha256, 'cond_rendered'] = True
+                    for model in image_models:
+                        if need_process(f'feature_{model}') and \
+                            metadata.loc[sha256, f'feature_{model}'] == False and \
+                            metadata.loc[sha256, 'rendered'] == True and \
+                            metadata.loc[sha256, 'voxelized'] == True and \
+                            os.path.exists(os.path.join(opt.output_dir, 'features', model, f'{sha256}.npz')):
+                            metadata.loc[sha256, f'feature_{model}'] = True
+                    for model in latent_models:
+                        if need_process(f'latent_{model}') and \
+                            metadata.loc[sha256, f'latent_{model}'] == False and \
+                            metadata.loc[sha256, 'rendered'] == True and \
+                            metadata.loc[sha256, 'voxelized'] == True and \
+                            os.path.exists(os.path.join(opt.output_dir, 'latents', model, f'{sha256}.npz')):
+                            metadata.loc[sha256, f'latent_{model}'] = True
+                    for model in ss_latent_models:
+                        if need_process(f'ss_latent_{model}') and \
+                            metadata.loc[sha256, f'ss_latent_{model}'] == False and \
+                            metadata.loc[sha256, 'voxelized'] == True and \
+                            os.path.exists(os.path.join(opt.output_dir, 'ss_latents', model, f'{sha256}.npz')):
+                            metadata.loc[sha256, f'ss_latent_{model}'] = True
+                    pbar.update()
+                except Exception as e:
+                    print(f'Error processing {sha256}: {e}')
+                    pbar.update()
+
+            executor.map(worker, metadata.index)
+            executor.shutdown(wait=True)
+
+    # statistics
+    metadata.to_csv(os.path.join(opt.output_dir, 'metadata.csv'))
+    num_downloaded = metadata['local_path'].count() if 'local_path' in metadata.columns else 0
+    with open(os.path.join(opt.output_dir, 'statistics.txt'), 'w') as f:
+        f.write('Statistics:\n')
+        f.write(f'  - Number of assets: {len(metadata)}\n')
+        f.write(f'  - Number of assets downloaded: {num_downloaded}\n')
+        f.write(f'  - Number of assets rendered: {metadata["rendered"].sum()}\n')
+        f.write(f'  - Number of assets voxelized: {metadata["voxelized"].sum()}\n')
+        if len(image_models) != 0:
+            f.write(f'  - Number of assets with image features extracted:\n')
+            for model in image_models:
+                f.write(f'    - {model}: {metadata[f"feature_{model}"].sum()}\n')
+        if len(latent_models) != 0:
+            f.write(f'  - Number of assets with latents extracted:\n')
+            for model in latent_models:
+                f.write(f'    - {model}: {metadata[f"latent_{model}"].sum()}\n')
+        if len(ss_latent_models) != 0:
+            f.write(f'  - Number of assets with sparse structure latents extracted:\n')
+            for model in ss_latent_models:
+                f.write(f'    - {model}: {metadata[f"ss_latent_{model}"].sum()}\n')
+        f.write(f'  - Number of assets with captions: {metadata["captions"].count()}\n')
+        f.write(f'  - Number of assets with image conditions: {metadata["cond_rendered"].sum()}\n')
+
+    with open(os.path.join(opt.output_dir, 'statistics.txt'), 'r') as f:
+        print(f.read())
\ No newline at end of file
diff --git a/dataset_toolkits/calculate_aesthetic_scores.py b/dataset_toolkits/calculate_aesthetic_scores.py
new file mode 100644
index 0000000..92ce30c
--- /dev/null
+++ b/dataset_toolkits/calculate_aesthetic_scores.py
@@ -0,0 +1,102 @@
+import os
+import argparse
+import torch
+import torch.nn as nn
+from PIL import Image
+import open_clip
+from os.path import expanduser
+from urllib.request import urlretrieve
+import pandas as pd
+from tqdm import tqdm
+from concurrent.futures import ThreadPoolExecutor
+from queue import Queue
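+
+# Usage sketch (illustrative paths; --rank/--world_size shard the object list,
+# and results are written to aesthetic_scores_<rank>.csv):
+#   python dataset_toolkits/calculate_aesthetic_scores.py --output_dir datasets/ABO --rank 0 --world_size 1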
+
+
+def get_aesthetic_model(clip_model="vit_l_14"):
+    """load the aesthetic model"""
+    home = expanduser("~")
+    cache_folder = home + "/.cache/emb_reader"
+    path_to_model = cache_folder + "/sa_0_4_"+clip_model+"_linear.pth"
+    if not os.path.exists(path_to_model):
+        os.makedirs(cache_folder, exist_ok=True)
+        url_model = (
+            "https://github.com/LAION-AI/aesthetic-predictor/blob/main/sa_0_4_"+clip_model+"_linear.pth?raw=true"
+        )
+        urlretrieve(url_model, path_to_model)
+    if clip_model == "vit_l_14":
+        m = nn.Linear(768, 1)
+    elif clip_model == "vit_b_32":
+        m = nn.Linear(512, 1)
+    else:
+        raise ValueError()
+    s = torch.load(path_to_model)
+    m.load_state_dict(s)
+    m.eval()
+    return m
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--clip_model", type=str, default="vit_l_14")
+    parser.add_argument("--output_dir", type=str, required=True)
+    parser.add_argument("--rank", type=int, default=0)
+    parser.add_argument("--world_size", type=int, default=1)
+    opt = parser.parse_args()
+
+    amodel = get_aesthetic_model(clip_model="vit_l_14")
+    amodel.eval()
+    model, _, preprocess = open_clip.create_model_and_transforms('ViT-L-14', pretrained='openai')
+    model = model.cuda()
+    amodel = amodel.cuda()
+
+    metadata = pd.read_csv(os.path.join(opt.output_dir, 'metadata.csv'))
+    metadata = metadata[metadata['snapshotted'] == 1]
+    sha256s = metadata['sha256'].values
+
+    # filter out objects that are already calculated
+    if os.path.exists(os.path.join(opt.output_dir, 'aesthetic_scores.csv')):
+        with open(os.path.join(opt.output_dir, 'aesthetic_scores.csv'), 'r') as f:
+            old_metadata = pd.read_csv(f)
+        sha256s = list(set(sha256s) - set(old_metadata['sha256'].values))
+
+    sha256s = sorted(sha256s)
+    sha256s = sha256s[len(sha256s) * opt.rank // opt.world_size: len(sha256s) * (opt.rank + 1) // opt.world_size]
+
+    rows = []
+
+    with ThreadPoolExecutor(max_workers=os.cpu_count()) as executor:
+        finished = Queue(maxsize=128)
+
+        def load_image(sha256):
+            try:
+                files = os.listdir(os.path.join(opt.output_dir, 'snapshots', sha256))
+                files = [f for f in files if f.endswith('.png')]
+                processed = []
+                for file in files:
+                    image = Image.open(os.path.join(opt.output_dir, 'snapshots', sha256, file))
+                    processed.append(preprocess(image))
+                processed = torch.stack(processed, dim=0)
+            except Exception as e:
+                print(e)
+                processed = None
+            finished.put((sha256, processed))
+
+        executor.map(load_image, sha256s)
+        for _ in tqdm(range(len(sha256s)), desc='Calculating aesthetic scores'):
+            sha256, processed = finished.get()
+            if processed is not None:
+                with torch.no_grad():
+                    image_features = model.encode_image(processed.cuda())
+                    image_features /= image_features.norm(dim=-1, keepdim=True)
+                    aesthetic_score = amodel(image_features).cpu()
+                rows.append(pd.DataFrame({
+                    'sha256': [sha256],
+                    'mean': [aesthetic_score.mean().item()],
+                    'std': [aesthetic_score.std().item()],
+                    'min': [aesthetic_score.min().item()],
+                    'max': [aesthetic_score.max().item()],
+                    'median': [aesthetic_score.median().item()]
+                }))
+
+    with open(os.path.join(opt.output_dir, f'aesthetic_scores_{opt.rank}.csv'), 'w') as f:
+        pd.concat(rows).to_csv(f, index=False)
diff --git a/dataset_toolkits/datasets/3D-FUTURE.py b/dataset_toolkits/datasets/3D-FUTURE.py
new file mode 100644
index 0000000..a5ccc63
--- /dev/null
+++ b/dataset_toolkits/datasets/3D-FUTURE.py
@@ -0,0 +1,97 @@
+import os
+import re
+import argparse
+import zipfile
+from concurrent.futures import ThreadPoolExecutor
+from tqdm import tqdm
+import pandas as pd
+from utils import get_file_hash
+
+
+def add_args(parser: argparse.ArgumentParser):
+    pass
+
+
+def get_metadata(**kwargs):
+    metadata = pd.read_csv("hf://datasets/JeffreyXiang/TRELLIS-500K/3D-FUTURE.csv")
+    return metadata
+
+
+def download(metadata, output_dir, **kwargs):
+    os.makedirs(output_dir, exist_ok=True)
+
+    if not os.path.exists(os.path.join(output_dir, 'raw', '3D-FUTURE-model.zip')):
+        print("\033[93m")
+        print("3D-FUTURE has to be downloaded manually")
+        print(f"Please download the 3D-FUTURE-model.zip file and place it in the {output_dir}/raw directory")
+        print("Visit https://tianchi.aliyun.com/specials/promotion/alibaba-3d-future for more information")
+        print("\033[0m")
+        raise FileNotFoundError("3D-FUTURE-model.zip not found")
+
+    downloaded = {}
+    metadata = metadata.set_index("file_identifier")
+    with zipfile.ZipFile(os.path.join(output_dir, 'raw', '3D-FUTURE-model.zip')) as zip_ref:
+        all_names = zip_ref.namelist()
+        instances = [instance[:-1] for instance in all_names if re.match(r"^3D-FUTURE-model/[^/]+/$", instance)]
+        instances = list(filter(lambda x: x in metadata.index, instances))
+
+        with ThreadPoolExecutor(max_workers=os.cpu_count()) as executor, \
+            tqdm(total=len(instances), desc="Extracting") as pbar:
+            def worker(instance: str) -> str:
+                try:
+                    instance_files = list(filter(lambda x: x.startswith(f"{instance}/") and not x.endswith("/"), all_names))
+                    zip_ref.extractall(os.path.join(output_dir, 'raw'), members=instance_files)
+                    sha256 = get_file_hash(os.path.join(output_dir, 'raw', f"{instance}/image.jpg"))
+                    pbar.update()
+                    return sha256
+                except Exception as e:
+                    pbar.update()
+                    print(f"Error extracting for {instance}: {e}")
+                    return None
+
+            sha256s = executor.map(worker, instances)
+            executor.shutdown(wait=True)
+
+    for k, sha256 in zip(instances, sha256s):
+        if sha256 is not None:
+            if sha256 == metadata.loc[k, "sha256"]:
+                downloaded[sha256] = os.path.join("raw", f"{k}/raw_model.obj")
+            else:
+                print(f"Error downloading {k}: sha256s do not match")
+
+    return pd.DataFrame(downloaded.items(), columns=['sha256', 'local_path'])
+
+
+def foreach_instance(metadata, output_dir, func, max_workers=None, desc='Processing objects') -> pd.DataFrame:
+    import os
+    from concurrent.futures import ThreadPoolExecutor
+    from tqdm import tqdm
+
+    # load metadata
+    metadata = metadata.to_dict('records')
+
+    # processing objects
+    records = []
+    max_workers = max_workers or os.cpu_count()
+    try:
+        with ThreadPoolExecutor(max_workers=max_workers) as executor, \
+            tqdm(total=len(metadata), desc=desc) as pbar:
+            def worker(metadatum):
+                try:
+                    local_path = metadatum['local_path']
+                    sha256 = metadatum['sha256']
+                    file = os.path.join(output_dir, local_path)
+                    record = func(file, sha256)
+                    if record is not None:
+                        records.append(record)
+                    pbar.update()
+                except Exception as e:
+                    print(f"Error processing object {sha256}: {e}")
+                    pbar.update()
+
+            executor.map(worker, metadata)
+            executor.shutdown(wait=True)
+    except Exception:
+        print("Error happened during processing.")
+
+    return pd.DataFrame.from_records(records)
diff --git a/dataset_toolkits/datasets/ABO.py b/dataset_toolkits/datasets/ABO.py
new file mode 100644
index 0000000..b0aba22
--- /dev/null
+++ b/dataset_toolkits/datasets/ABO.py
@@ -0,0 +1,96 @@
+import os
+import re
+import argparse
+import tarfile
+from concurrent.futures import ThreadPoolExecutor
+from tqdm import tqdm
+import pandas as pd
+from utils import get_file_hash
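+
+# Like 3D-FUTURE.py above, this module implements the dataset interface consumed
+# by the dataset_toolkits scripts: add_args(), get_metadata(), download(), and
+# foreach_instance().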
+
+
+def add_args(parser: argparse.ArgumentParser):
+    pass
+
+
+def get_metadata(**kwargs):
+    metadata = pd.read_csv("hf://datasets/JeffreyXiang/TRELLIS-500K/ABO.csv")
+    return metadata
+
+
+def download(metadata, output_dir, **kwargs):
+    os.makedirs(os.path.join(output_dir, 'raw'), exist_ok=True)
+
+    if not os.path.exists(os.path.join(output_dir, 'raw', 'abo-3dmodels.tar')):
+        try:
+            os.makedirs(os.path.join(output_dir, 'raw'), exist_ok=True)
+            os.system(f"wget -O {output_dir}/raw/abo-3dmodels.tar https://amazon-berkeley-objects.s3.amazonaws.com/archives/abo-3dmodels.tar")
+        except Exception:
+            print("\033[93m")
+            print("Error downloading ABO dataset. Please check your internet connection and try again.")
+            print(f"Or, you can manually download the abo-3dmodels.tar file and place it in the {output_dir}/raw directory")
+            print("Visit https://amazon-berkeley-objects.s3.amazonaws.com/index.html for more information")
+            print("\033[0m")
+            raise FileNotFoundError("Error downloading ABO dataset")
+
+    downloaded = {}
+    metadata = metadata.set_index("file_identifier")
+    with tarfile.open(os.path.join(output_dir, 'raw', 'abo-3dmodels.tar')) as tar:
+        with ThreadPoolExecutor(max_workers=1) as executor, \
+            tqdm(total=len(metadata), desc="Extracting") as pbar:
+            def worker(instance: str) -> str:
+                try:
+                    tar.extract(f"3dmodels/original/{instance}", path=os.path.join(output_dir, 'raw'))
+                    sha256 = get_file_hash(os.path.join(output_dir, 'raw/3dmodels/original', instance))
+                    pbar.update()
+                    return sha256
+                except Exception as e:
+                    pbar.update()
+                    print(f"Error extracting for {instance}: {e}")
+                    return None
+
+            sha256s = executor.map(worker, metadata.index)
+            executor.shutdown(wait=True)
+
+    for k, sha256 in zip(metadata.index, sha256s):
+        if sha256 is not None:
+            if sha256 == metadata.loc[k, "sha256"]:
+                downloaded[sha256] = os.path.join('raw/3dmodels/original', k)
+            else:
+                print(f"Error downloading {k}: sha256s do not match")
+
+    return pd.DataFrame(downloaded.items(), columns=['sha256', 'local_path'])
+
+
+def foreach_instance(metadata, output_dir, func, max_workers=None, desc='Processing objects') -> pd.DataFrame:
+    import os
+    from concurrent.futures import ThreadPoolExecutor
+    from tqdm import tqdm
+
+    # load metadata
+    metadata = metadata.to_dict('records')
+
+    # processing objects
+    records = []
+    max_workers = max_workers or os.cpu_count()
+    try:
+        with ThreadPoolExecutor(max_workers=max_workers) as executor, \
+            tqdm(total=len(metadata), desc=desc) as pbar:
+            def worker(metadatum):
+                try:
+                    local_path = metadatum['local_path']
+                    sha256 = metadatum['sha256']
+                    file = os.path.join(output_dir, local_path)
+                    record = func(file, sha256)
+                    if record is not None:
+                        records.append(record)
+                    pbar.update()
+                except Exception as e:
+                    print(f"Error processing object {sha256}: {e}")
+                    pbar.update()
+
+            executor.map(worker, metadata)
+            executor.shutdown(wait=True)
+    except Exception:
+        print("Error happened during processing.")
+
+    return pd.DataFrame.from_records(records)
diff --git a/trellis/__init__.py b/trellis/__init__.py
new file mode 100755
index 0000000..20d240a
--- /dev/null
+++ b/trellis/__init__.py
@@ -0,0 +1,6 @@
+from . import models
+from . import modules
+from . import pipelines
+from . import renderers
+from . import representations
+from . import utils
diff --git a/trellis/datasets/__init__.py b/trellis/datasets/__init__.py
new file mode 100644
index 0000000..6798ca0
--- /dev/null
+++ b/trellis/datasets/__init__.py
@@ -0,0 +1,58 @@
+import importlib
+
+__attributes = {
+    'SparseStructure': 'sparse_structure',
+
+    'SparseFeat2Render': 'sparse_feat2render',
+    'SLat2Render': 'structured_latent2render',
+    'Slat2RenderGeo': 'structured_latent2render',
+
+    'SparseStructureLatent': 'sparse_structure_latent',
+    'TextConditionedSparseStructureLatent': 'sparse_structure_latent',
+    'ImageConditionedSparseStructureLatent': 'sparse_structure_latent',
+
+    'SLat': 'structured_latent',
+    'TextConditionedSLat': 'structured_latent',
+    'ImageConditionedSLat': 'structured_latent',
+}
+
+__submodules = []
+
+__all__ = list(__attributes.keys()) + __submodules
+
+def __getattr__(name):
+    if name not in globals():
+        if name in __attributes:
+            module_name = __attributes[name]
+            module = importlib.import_module(f".{module_name}", __name__)
+            globals()[name] = getattr(module, name)
+        elif name in __submodules:
+            module = importlib.import_module(f".{name}", __name__)
+            globals()[name] = module
+        else:
+            raise AttributeError(f"module {__name__} has no attribute {name}")
+    return globals()[name]
+
+
+# For Pylance
+if __name__ == '__main__':
+    from .sparse_structure import SparseStructure
+
+    from .sparse_feat2render import SparseFeat2Render
+    from .structured_latent2render import (
+        SLat2Render,
+        Slat2RenderGeo,
+    )
+
+    from .sparse_structure_latent import (
+        SparseStructureLatent,
+        TextConditionedSparseStructureLatent,
+        ImageConditionedSparseStructureLatent,
+    )
+
+    from .structured_latent import (
+        SLat,
+        TextConditionedSLat,
+        ImageConditionedSLat,
+    )
+
\ No newline at end of file
diff --git a/trellis/models/__init__.py b/trellis/models/__init__.py
new file mode 100644
index 0000000..ae9d17c
--- /dev/null
+++ b/trellis/models/__init__.py
@@ -0,0 +1,96 @@
+import importlib
+
+__attributes = {
+    'SparseStructureEncoder': 'sparse_structure_vae',
+    'SparseStructureDecoder': 'sparse_structure_vae',
+
+    'SparseStructureFlowModel': 'sparse_structure_flow',
+
+    'SLatEncoder': 'structured_latent_vae',
+    'SLatGaussianDecoder': 'structured_latent_vae',
+    'SLatRadianceFieldDecoder': 'structured_latent_vae',
+    'SLatMeshDecoder': 'structured_latent_vae',
+    'ElasticSLatEncoder': 'structured_latent_vae',
+    'ElasticSLatGaussianDecoder': 'structured_latent_vae',
+    'ElasticSLatRadianceFieldDecoder': 'structured_latent_vae',
+    'ElasticSLatMeshDecoder': 'structured_latent_vae',
+
+    'SLatFlowModel': 'structured_latent_flow',
+    'ElasticSLatFlowModel': 'structured_latent_flow',
+}
+
+__submodules = []
+
+__all__ = list(__attributes.keys()) + __submodules
+
+def __getattr__(name):
+    if name not in globals():
+        if name in __attributes:
+            module_name = __attributes[name]
+            module = importlib.import_module(f".{module_name}", __name__)
+            globals()[name] = getattr(module, name)
+        elif name in __submodules:
+            module = importlib.import_module(f".{name}", __name__)
+            globals()[name] = module
+        else:
+            raise AttributeError(f"module {__name__} has no attribute {name}")
+    return globals()[name]
+
+
+def from_pretrained(path: str, **kwargs):
+    """
+    Load a model from a pretrained checkpoint.
+
+    Args:
+        path: The path to the checkpoint. Can be either a local path or a Hugging Face model name.
+              NOTE: config file and model file should take the name f'{path}.json' and f'{path}.safetensors' respectively.
+        **kwargs: Additional arguments for the model constructor.
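+
+    Example (illustrative names, not files shipped in this commit):
+        model = from_pretrained('ckpts/my_model')            # local: loads ckpts/my_model.{json,safetensors}
+        model = from_pretrained('org/repo/ckpts/my_model')   # Hub: repo 'org/repo', files 'ckpts/my_model.*'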
+ """ + import os + import json + from safetensors.torch import load_file + is_local = os.path.exists(f"{path}.json") and os.path.exists(f"{path}.safetensors") + + if is_local: + config_file = f"{path}.json" + model_file = f"{path}.safetensors" + else: + from huggingface_hub import hf_hub_download + path_parts = path.split('/') + repo_id = f'{path_parts[0]}/{path_parts[1]}' + model_name = '/'.join(path_parts[2:]) + config_file = hf_hub_download(repo_id, f"{model_name}.json") + model_file = hf_hub_download(repo_id, f"{model_name}.safetensors") + + with open(config_file, 'r') as f: + config = json.load(f) + model = __getattr__(config['name'])(**config['args'], **kwargs) + model.load_state_dict(load_file(model_file)) + + return model + + +# For Pylance +if __name__ == '__main__': + from .sparse_structure_vae import ( + SparseStructureEncoder, + SparseStructureDecoder, + ) + + from .sparse_structure_flow import SparseStructureFlowModel + + from .structured_latent_vae import ( + SLatEncoder, + SLatGaussianDecoder, + SLatRadianceFieldDecoder, + SLatMeshDecoder, + ElasticSLatEncoder, + ElasticSLatGaussianDecoder, + ElasticSLatRadianceFieldDecoder, + ElasticSLatMeshDecoder, + ) + + from .structured_latent_flow import ( + SLatFlowModel, + ElasticSLatFlowModel, + ) diff --git a/trellis/models/structured_latent_vae/__init__.py b/trellis/models/structured_latent_vae/__init__.py new file mode 100644 index 0000000..4e2ac35 --- /dev/null +++ b/trellis/models/structured_latent_vae/__init__.py @@ -0,0 +1,4 @@ +from .encoder import SLatEncoder, ElasticSLatEncoder +from .decoder_gs import SLatGaussianDecoder, ElasticSLatGaussianDecoder +from .decoder_rf import SLatRadianceFieldDecoder, ElasticSLatRadianceFieldDecoder +from .decoder_mesh import SLatMeshDecoder, ElasticSLatMeshDecoder diff --git a/trellis/models/structured_latent_vae/base.py b/trellis/models/structured_latent_vae/base.py new file mode 100644 index 0000000..ab0bf6a --- /dev/null +++ b/trellis/models/structured_latent_vae/base.py @@ -0,0 +1,117 @@ +from typing import * +import torch +import torch.nn as nn +from ...modules.utils import convert_module_to_f16, convert_module_to_f32 +from ...modules import sparse as sp +from ...modules.transformer import AbsolutePositionEmbedder +from ...modules.sparse.transformer import SparseTransformerBlock + + +def block_attn_config(self): + """ + Return the attention configuration of the model. + """ + for i in range(self.num_blocks): + if self.attn_mode == "shift_window": + yield "serialized", self.window_size, 0, (16 * (i % 2),) * 3, sp.SerializeMode.Z_ORDER + elif self.attn_mode == "shift_sequence": + yield "serialized", self.window_size, self.window_size // 2 * (i % 2), (0, 0, 0), sp.SerializeMode.Z_ORDER + elif self.attn_mode == "shift_order": + yield "serialized", self.window_size, 0, (0, 0, 0), sp.SerializeModes[i % 4] + elif self.attn_mode == "full": + yield "full", None, None, None, None + elif self.attn_mode == "swin": + yield "windowed", self.window_size, None, self.window_size // 2 * (i % 2), None + + +class SparseTransformerBase(nn.Module): + """ + Sparse Transformer without output layers. + Serve as the base class for encoder and decoder. 
+ """ + def __init__( + self, + in_channels: int, + model_channels: int, + num_blocks: int, + num_heads: Optional[int] = None, + num_head_channels: Optional[int] = 64, + mlp_ratio: float = 4.0, + attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "full", + window_size: Optional[int] = None, + pe_mode: Literal["ape", "rope"] = "ape", + use_fp16: bool = False, + use_checkpoint: bool = False, + qk_rms_norm: bool = False, + ): + super().__init__() + self.in_channels = in_channels + self.model_channels = model_channels + self.num_blocks = num_blocks + self.window_size = window_size + self.num_heads = num_heads or model_channels // num_head_channels + self.mlp_ratio = mlp_ratio + self.attn_mode = attn_mode + self.pe_mode = pe_mode + self.use_fp16 = use_fp16 + self.use_checkpoint = use_checkpoint + self.qk_rms_norm = qk_rms_norm + self.dtype = torch.float16 if use_fp16 else torch.float32 + + if pe_mode == "ape": + self.pos_embedder = AbsolutePositionEmbedder(model_channels) + + self.input_layer = sp.SparseLinear(in_channels, model_channels) + self.blocks = nn.ModuleList([ + SparseTransformerBlock( + model_channels, + num_heads=self.num_heads, + mlp_ratio=self.mlp_ratio, + attn_mode=attn_mode, + window_size=window_size, + shift_sequence=shift_sequence, + shift_window=shift_window, + serialize_mode=serialize_mode, + use_checkpoint=self.use_checkpoint, + use_rope=(pe_mode == "rope"), + qk_rms_norm=self.qk_rms_norm, + ) + for attn_mode, window_size, shift_sequence, shift_window, serialize_mode in block_attn_config(self) + ]) + + @property + def device(self) -> torch.device: + """ + Return the device of the model. + """ + return next(self.parameters()).device + + def convert_to_fp16(self) -> None: + """ + Convert the torso of the model to float16. + """ + self.blocks.apply(convert_module_to_f16) + + def convert_to_fp32(self) -> None: + """ + Convert the torso of the model to float32. 
+ """ + self.blocks.apply(convert_module_to_f32) + + def initialize_weights(self) -> None: + # Initialize transformer layers: + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + self.apply(_basic_init) + + def forward(self, x: sp.SparseTensor) -> sp.SparseTensor: + h = self.input_layer(x) + if self.pe_mode == "ape": + h = h + self.pos_embedder(x.coords[:, 1:]) + h = h.type(self.dtype) + for block in self.blocks: + h = block(h) + return h diff --git a/trellis/modules/attention/__init__.py b/trellis/modules/attention/__init__.py new file mode 100755 index 0000000..f452320 --- /dev/null +++ b/trellis/modules/attention/__init__.py @@ -0,0 +1,36 @@ +from typing import * + +BACKEND = 'flash_attn' +DEBUG = False + +def __from_env(): + import os + + global BACKEND + global DEBUG + + env_attn_backend = os.environ.get('ATTN_BACKEND') + env_sttn_debug = os.environ.get('ATTN_DEBUG') + + if env_attn_backend is not None and env_attn_backend in ['xformers', 'flash_attn', 'sdpa', 'naive']: + BACKEND = env_attn_backend + if env_sttn_debug is not None: + DEBUG = env_sttn_debug == '1' + + print(f"[ATTENTION] Using backend: {BACKEND}") + + +__from_env() + + +def set_backend(backend: Literal['xformers', 'flash_attn']): + global BACKEND + BACKEND = backend + +def set_debug(debug: bool): + global DEBUG + DEBUG = debug + + +from .full_attn import * +from .modules import * diff --git a/trellis/modules/sparse/__init__.py b/trellis/modules/sparse/__init__.py new file mode 100755 index 0000000..726756c --- /dev/null +++ b/trellis/modules/sparse/__init__.py @@ -0,0 +1,102 @@ +from typing import * + +BACKEND = 'spconv' +DEBUG = False +ATTN = 'flash_attn' + +def __from_env(): + import os + + global BACKEND + global DEBUG + global ATTN + + env_sparse_backend = os.environ.get('SPARSE_BACKEND') + env_sparse_debug = os.environ.get('SPARSE_DEBUG') + env_sparse_attn = os.environ.get('SPARSE_ATTN_BACKEND') + if env_sparse_attn is None: + env_sparse_attn = os.environ.get('ATTN_BACKEND') + + if env_sparse_backend is not None and env_sparse_backend in ['spconv', 'torchsparse']: + BACKEND = env_sparse_backend + if env_sparse_debug is not None: + DEBUG = env_sparse_debug == '1' + if env_sparse_attn is not None and env_sparse_attn in ['xformers', 'flash_attn']: + ATTN = env_sparse_attn + + print(f"[SPARSE] Backend: {BACKEND}, Attention: {ATTN}") + + +__from_env() + + +def set_backend(backend: Literal['spconv', 'torchsparse']): + global BACKEND + BACKEND = backend + +def set_debug(debug: bool): + global DEBUG + DEBUG = debug + +def set_attn(attn: Literal['xformers', 'flash_attn']): + global ATTN + ATTN = attn + + +import importlib + +__attributes = { + 'SparseTensor': 'basic', + 'sparse_batch_broadcast': 'basic', + 'sparse_batch_op': 'basic', + 'sparse_cat': 'basic', + 'sparse_unbind': 'basic', + 'SparseGroupNorm': 'norm', + 'SparseLayerNorm': 'norm', + 'SparseGroupNorm32': 'norm', + 'SparseLayerNorm32': 'norm', + 'SparseReLU': 'nonlinearity', + 'SparseSiLU': 'nonlinearity', + 'SparseGELU': 'nonlinearity', + 'SparseActivation': 'nonlinearity', + 'SparseLinear': 'linear', + 'sparse_scaled_dot_product_attention': 'attention', + 'SerializeMode': 'attention', + 'sparse_serialized_scaled_dot_product_self_attention': 'attention', + 'sparse_windowed_scaled_dot_product_self_attention': 'attention', + 'SparseMultiHeadAttention': 'attention', + 'SparseConv3d': 'conv', + 'SparseInverseConv3d': 'conv', + 
'SparseDownsample': 'spatial', + 'SparseUpsample': 'spatial', + 'SparseSubdivide' : 'spatial' +} + +__submodules = ['transformer'] + +__all__ = list(__attributes.keys()) + __submodules + +def __getattr__(name): + if name not in globals(): + if name in __attributes: + module_name = __attributes[name] + module = importlib.import_module(f".{module_name}", __name__) + globals()[name] = getattr(module, name) + elif name in __submodules: + module = importlib.import_module(f".{name}", __name__) + globals()[name] = module + else: + raise AttributeError(f"module {__name__} has no attribute {name}") + return globals()[name] + + +# For Pylance +if __name__ == '__main__': + from .basic import * + from .norm import * + from .nonlinearity import * + from .linear import * + from .attention import * + from .conv import * + from .spatial import * + import transformer diff --git a/trellis/modules/sparse/attention/__init__.py b/trellis/modules/sparse/attention/__init__.py new file mode 100755 index 0000000..32b3c2c --- /dev/null +++ b/trellis/modules/sparse/attention/__init__.py @@ -0,0 +1,4 @@ +from .full_attn import * +from .serialized_attn import * +from .windowed_attn import * +from .modules import * diff --git a/trellis/modules/sparse/basic.py b/trellis/modules/sparse/basic.py new file mode 100755 index 0000000..8837f44 --- /dev/null +++ b/trellis/modules/sparse/basic.py @@ -0,0 +1,459 @@ +from typing import * +import torch +import torch.nn as nn +from . import BACKEND, DEBUG +SparseTensorData = None # Lazy import + + +__all__ = [ + 'SparseTensor', + 'sparse_batch_broadcast', + 'sparse_batch_op', + 'sparse_cat', + 'sparse_unbind', +] + + +class SparseTensor: + """ + Sparse tensor with support for both torchsparse and spconv backends. + + Parameters: + - feats (torch.Tensor): Features of the sparse tensor. + - coords (torch.Tensor): Coordinates of the sparse tensor. + - shape (torch.Size): Shape of the sparse tensor. + - layout (List[slice]): Layout of the sparse tensor for each batch + - data (SparseTensorData): Sparse tensor data used for convolusion + + NOTE: + - Data corresponding to a same batch should be contiguous. + - Coords should be in [0, 1023] + """ + @overload + def __init__(self, feats: torch.Tensor, coords: torch.Tensor, shape: Optional[torch.Size] = None, layout: Optional[List[slice]] = None, **kwargs): ... + + @overload + def __init__(self, data, shape: Optional[torch.Size] = None, layout: Optional[List[slice]] = None, **kwargs): ... 
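+
+    # Two construction paths (see the overloads above): either raw (feats, coords)
+    # tensors, or a backend-native object passed as `data` (torchsparse.SparseTensor
+    # or spconv's SparseConvTensor); __init__ dispatches on the type of the first
+    # positional argument.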
+
+    def __init__(self, *args, **kwargs):
+        # Lazy import of sparse tensor backend
+        global SparseTensorData
+        if SparseTensorData is None:
+            import importlib
+            if BACKEND == 'torchsparse':
+                SparseTensorData = importlib.import_module('torchsparse').SparseTensor
+            elif BACKEND == 'spconv':
+                SparseTensorData = importlib.import_module('spconv.pytorch').SparseConvTensor
+
+        method_id = 0
+        if len(args) != 0:
+            method_id = 0 if isinstance(args[0], torch.Tensor) else 1
+        else:
+            method_id = 1 if 'data' in kwargs else 0
+
+        if method_id == 0:
+            feats, coords, shape, layout = args + (None,) * (4 - len(args))
+            if 'feats' in kwargs:
+                feats = kwargs['feats']
+                del kwargs['feats']
+            if 'coords' in kwargs:
+                coords = kwargs['coords']
+                del kwargs['coords']
+            if 'shape' in kwargs:
+                shape = kwargs['shape']
+                del kwargs['shape']
+            if 'layout' in kwargs:
+                layout = kwargs['layout']
+                del kwargs['layout']
+
+            if shape is None:
+                shape = self.__cal_shape(feats, coords)
+            if layout is None:
+                layout = self.__cal_layout(coords, shape[0])
+            if BACKEND == 'torchsparse':
+                self.data = SparseTensorData(feats, coords, **kwargs)
+            elif BACKEND == 'spconv':
+                spatial_shape = list(coords.max(0)[0] + 1)[1:]
+                self.data = SparseTensorData(feats.reshape(feats.shape[0], -1), coords, spatial_shape, shape[0], **kwargs)
+                self.data._features = feats
+        elif method_id == 1:
+            data, shape, layout = args + (None,) * (3 - len(args))
+            if 'data' in kwargs:
+                data = kwargs['data']
+                del kwargs['data']
+            if 'shape' in kwargs:
+                shape = kwargs['shape']
+                del kwargs['shape']
+            if 'layout' in kwargs:
+                layout = kwargs['layout']
+                del kwargs['layout']
+
+            self.data = data
+            if shape is None:
+                shape = self.__cal_shape(self.feats, self.coords)
+            if layout is None:
+                layout = self.__cal_layout(self.coords, shape[0])
+
+        self._shape = shape
+        self._layout = layout
+        self._scale = kwargs.get('scale', (1, 1, 1))
+        self._spatial_cache = kwargs.get('spatial_cache', {})
+
+        if DEBUG:
+            try:
+                assert self.feats.shape[0] == self.coords.shape[0], f"Invalid feats shape: {self.feats.shape}, coords shape: {self.coords.shape}"
+                assert self.shape == self.__cal_shape(self.feats, self.coords), f"Invalid shape: {self.shape}"
+                assert self.layout == self.__cal_layout(self.coords, self.shape[0]), f"Invalid layout: {self.layout}"
+                for i in range(self.shape[0]):
+                    assert torch.all(self.coords[self.layout[i], 0] == i), f"The data of batch {i} is not contiguous"
+            except Exception as e:
+                print('Debugging information:')
+                print(f"- Shape: {self.shape}")
+                print(f"- Layout: {self.layout}")
+                print(f"- Scale: {self._scale}")
+                print(f"- Coords: {self.coords}")
+                raise e
+
+    def __cal_shape(self, feats, coords):
+        shape = []
+        shape.append(coords[:, 0].max().item() + 1)
+        shape.extend([*feats.shape[1:]])
+        return torch.Size(shape)
+
+    def __cal_layout(self, coords, batch_size):
+        seq_len = torch.bincount(coords[:, 0], minlength=batch_size)
+        offset = torch.cumsum(seq_len, dim=0)
+        layout = [slice((offset[i] - seq_len[i]).item(), offset[i].item()) for i in range(batch_size)]
+        return layout
+
+    @property
+    def shape(self) -> torch.Size:
+        return self._shape
+
+    def dim(self) -> int:
+        return len(self.shape)
+
+    @property
+    def layout(self) -> List[slice]:
+        return self._layout
+
+    @property
+    def feats(self) -> torch.Tensor:
+        if BACKEND == 'torchsparse':
+            return self.data.F
+        elif BACKEND == 'spconv':
+            return self.data.features
+
+    @feats.setter
+    def feats(self, value: torch.Tensor):
+        if BACKEND == 'torchsparse':
+            self.data.F = value
+        elif BACKEND == 'spconv':
+            self.data.features = value
+
+    @property
+    def coords(self) -> torch.Tensor:
+        if BACKEND == 'torchsparse':
+            return self.data.C
+        elif BACKEND == 'spconv':
+            return self.data.indices
+
+    @coords.setter
+    def coords(self, value: torch.Tensor):
+        if BACKEND == 'torchsparse':
+            self.data.C = value
+        elif BACKEND == 'spconv':
+            self.data.indices = value
+
+    @property
+    def dtype(self):
+        return self.feats.dtype
+
+    @property
+    def device(self):
+        return self.feats.device
+
+    @overload
+    def to(self, dtype: torch.dtype) -> 'SparseTensor': ...
+
+    @overload
+    def to(self, device: Optional[Union[str, torch.device]] = None, dtype: Optional[torch.dtype] = None) -> 'SparseTensor': ...
+
+    def to(self, *args, **kwargs) -> 'SparseTensor':
+        device = None
+        dtype = None
+        if len(args) == 2:
+            device, dtype = args
+        elif len(args) == 1:
+            if isinstance(args[0], torch.dtype):
+                dtype = args[0]
+            else:
+                device = args[0]
+        if 'dtype' in kwargs:
+            assert dtype is None, "to() received multiple values for argument 'dtype'"
+            dtype = kwargs['dtype']
+        if 'device' in kwargs:
+            assert device is None, "to() received multiple values for argument 'device'"
+            device = kwargs['device']
+
+        new_feats = self.feats.to(device=device, dtype=dtype)
+        new_coords = self.coords.to(device=device)
+        return self.replace(new_feats, new_coords)
+
+    def type(self, dtype):
+        new_feats = self.feats.type(dtype)
+        return self.replace(new_feats)
+
+    def cpu(self) -> 'SparseTensor':
+        new_feats = self.feats.cpu()
+        new_coords = self.coords.cpu()
+        return self.replace(new_feats, new_coords)
+
+    def cuda(self) -> 'SparseTensor':
+        new_feats = self.feats.cuda()
+        new_coords = self.coords.cuda()
+        return self.replace(new_feats, new_coords)
+
+    def half(self) -> 'SparseTensor':
+        new_feats = self.feats.half()
+        return self.replace(new_feats)
+
+    def float(self) -> 'SparseTensor':
+        new_feats = self.feats.float()
+        return self.replace(new_feats)
+
+    def detach(self) -> 'SparseTensor':
+        new_coords = self.coords.detach()
+        new_feats = self.feats.detach()
+        return self.replace(new_feats, new_coords)
+
+    def dense(self) -> torch.Tensor:
+        if BACKEND == 'torchsparse':
+            return self.data.dense()
+        elif BACKEND == 'spconv':
+            return self.data.dense()
+
+    def reshape(self, *shape) -> 'SparseTensor':
+        new_feats = self.feats.reshape(self.feats.shape[0], *shape)
+        return self.replace(new_feats)
+
+    def unbind(self, dim: int) -> List['SparseTensor']:
+        return sparse_unbind(self, dim)
+
+    def replace(self, feats: torch.Tensor, coords: Optional[torch.Tensor] = None) -> 'SparseTensor':
+        new_shape = [self.shape[0]]
+        new_shape.extend(feats.shape[1:])
+        if BACKEND == 'torchsparse':
+            new_data = SparseTensorData(
+                feats=feats,
+                coords=self.data.coords if coords is None else coords,
+                stride=self.data.stride,
+                spatial_range=self.data.spatial_range,
+            )
+            new_data._caches = self.data._caches
+        elif BACKEND == 'spconv':
+            new_data = SparseTensorData(
+                self.data.features.reshape(self.data.features.shape[0], -1),
+                self.data.indices,
+                self.data.spatial_shape,
+                self.data.batch_size,
+                self.data.grid,
+                self.data.voxel_num,
+                self.data.indice_dict
+            )
+            new_data._features = feats
+            new_data.benchmark = self.data.benchmark
+            new_data.benchmark_record = self.data.benchmark_record
+            new_data.thrust_allocator = self.data.thrust_allocator
+            new_data._timer = self.data._timer
+            new_data.force_algo = self.data.force_algo
+            new_data.int8_scale = self.data.int8_scale
+            if coords is not None:
+                new_data.indices = coords
+        new_tensor = SparseTensor(new_data, shape=torch.Size(new_shape), layout=self.layout, scale=self._scale, spatial_cache=self._spatial_cache)
+        return new_tensor
+
+    @staticmethod
+    def full(aabb, dim, value, dtype=torch.float32, device=None) -> 'SparseTensor':
+        N, C = dim
+        x = torch.arange(aabb[0], aabb[3] + 1)
+        y = torch.arange(aabb[1], aabb[4] + 1)
+        z = torch.arange(aabb[2], aabb[5] + 1)
+        coords = torch.stack(torch.meshgrid(x, y, z, indexing='ij'), dim=-1).reshape(-1, 3)
+        coords = torch.cat([
+            torch.arange(N).view(-1, 1).repeat(1, coords.shape[0]).view(-1, 1),
+            coords.repeat(N, 1),
+        ], dim=1).to(dtype=torch.int32, device=device)
+        feats = torch.full((coords.shape[0], C), value, dtype=dtype, device=device)
+        return SparseTensor(feats=feats, coords=coords)
+
+    def __merge_sparse_cache(self, other: 'SparseTensor') -> dict:
+        new_cache = {}
+        for k in set(list(self._spatial_cache.keys()) + list(other._spatial_cache.keys())):
+            if k in self._spatial_cache:
+                new_cache[k] = self._spatial_cache[k]
+            if k in other._spatial_cache:
+                if k not in new_cache:
+                    new_cache[k] = other._spatial_cache[k]
+                else:
+                    new_cache[k].update(other._spatial_cache[k])
+        return new_cache
+
+    def __neg__(self) -> 'SparseTensor':
+        return self.replace(-self.feats)
+
+    def __elemwise__(self, other: Union[torch.Tensor, 'SparseTensor'], op: callable) -> 'SparseTensor':
+        if isinstance(other, torch.Tensor):
+            try:
+                other = torch.broadcast_to(other, self.shape)
+                other = sparse_batch_broadcast(self, other)
+            except Exception:
+                pass
+        if isinstance(other, SparseTensor):
+            other = other.feats
+        new_feats = op(self.feats, other)
+        new_tensor = self.replace(new_feats)
+        if isinstance(other, SparseTensor):
+            new_tensor._spatial_cache = self.__merge_sparse_cache(other)
+        return new_tensor
+
+    def __add__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor':
+        return self.__elemwise__(other, torch.add)
+
+    def __radd__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor':
+        return self.__elemwise__(other, torch.add)
+
+    def __sub__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor':
+        return self.__elemwise__(other, torch.sub)
+
+    def __rsub__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor':
+        return self.__elemwise__(other, lambda x, y: torch.sub(y, x))
+
+    def __mul__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor':
+        return self.__elemwise__(other, torch.mul)
+
+    def __rmul__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor':
+        return self.__elemwise__(other, torch.mul)
+
+    def __truediv__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor':
+        return self.__elemwise__(other, torch.div)
+
+    def __rtruediv__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor':
+        return self.__elemwise__(other, lambda x, y: torch.div(y, x))
+
+    def __getitem__(self, idx):
+        if isinstance(idx, int):
+            idx = [idx]
+        elif isinstance(idx, slice):
+            idx = range(*idx.indices(self.shape[0]))
+        elif isinstance(idx, torch.Tensor):
+            if idx.dtype == torch.bool:
+                assert idx.shape == (self.shape[0],), f"Invalid index shape: {idx.shape}"
+                idx = idx.nonzero().squeeze(1)
+            elif idx.dtype in [torch.int32, torch.int64]:
+                assert len(idx.shape) == 1, f"Invalid index shape: {idx.shape}"
+            else:
+                raise ValueError(f"Unknown index type: {idx.dtype}")
+        else:
+            raise ValueError(f"Unknown index type: {type(idx)}")
+
+        coords = []
+        feats = []
+        for new_idx, old_idx in enumerate(idx):
+            coords.append(self.coords[self.layout[old_idx]].clone())
+            coords[-1][:, 0] = new_idx
+            feats.append(self.feats[self.layout[old_idx]])
+        coords = torch.cat(coords, dim=0).contiguous()
+        feats = torch.cat(feats, dim=0).contiguous()
+        return SparseTensor(feats=feats, coords=coords)
+
+    def register_spatial_cache(self, key, value) -> None:
+        """
+        Register a spatial cache.
+        The spatial cache can be anything you want to cache.
+        Registration and retrieval of the cache are based on the current scale.
+        """
+        scale_key = str(self._scale)
+        if scale_key not in self._spatial_cache:
+            self._spatial_cache[scale_key] = {}
+        self._spatial_cache[scale_key][key] = value
+
+    def get_spatial_cache(self, key=None):
+        """
+        Get a spatial cache.
+        """
+        scale_key = str(self._scale)
+        cur_scale_cache = self._spatial_cache.get(scale_key, {})
+        if key is None:
+            return cur_scale_cache
+        return cur_scale_cache.get(key, None)
+
+
+def sparse_batch_broadcast(input: SparseTensor, other: torch.Tensor) -> torch.Tensor:
+    """
+    Broadcast a tensor to a sparse tensor along the batch dimension.
+
+    Args:
+        input (SparseTensor): Sparse tensor to broadcast to.
+        other (torch.Tensor): Tensor with one entry per batch to broadcast.
+    """
+    coords, feats = input.coords, input.feats
+    broadcasted = torch.zeros_like(feats)
+    for k in range(input.shape[0]):
+        broadcasted[input.layout[k]] = other[k]
+    return broadcasted
+
+
+def sparse_batch_op(input: SparseTensor, other: torch.Tensor, op: callable = torch.add) -> SparseTensor:
+    """
+    Broadcast a tensor to a sparse tensor along the batch dimension, then perform an elementwise operation.
+
+    Args:
+        input (SparseTensor): Sparse tensor to operate on.
+        other (torch.Tensor): Tensor with one entry per batch to broadcast.
+        op (callable): Operation to perform after broadcasting. Defaults to torch.add.
+    """
+    return input.replace(op(input.feats, sparse_batch_broadcast(input, other)))
+
+
+def sparse_cat(inputs: List[SparseTensor], dim: int = 0) -> SparseTensor:
+    """
+    Concatenate a list of sparse tensors.
+
+    Args:
+        inputs (List[SparseTensor]): List of sparse tensors to concatenate.
+    """
+    if dim == 0:
+        start = 0
+        coords = []
+        for input in inputs:
+            coords.append(input.coords.clone())
+            coords[-1][:, 0] += start
+            start += input.shape[0]
+        coords = torch.cat(coords, dim=0)
+        feats = torch.cat([input.feats for input in inputs], dim=0)
+        output = SparseTensor(
+            coords=coords,
+            feats=feats,
+        )
+    else:
+        feats = torch.cat([input.feats for input in inputs], dim=dim)
+        output = inputs[0].replace(feats)
+
+    return output
+
+
+def sparse_unbind(input: SparseTensor, dim: int) -> List[SparseTensor]:
+    """
+    Unbind a sparse tensor along a dimension.
+
+    Args:
+        input (SparseTensor): Sparse tensor to unbind.
+        dim (int): Dimension to unbind.
+    """
+    if dim == 0:
+        return [input[i] for i in range(input.shape[0])]
+    else:
+        feats = input.feats.unbind(dim)
+        return [input.replace(f) for f in feats]
diff --git a/trellis/modules/sparse/conv/__init__.py b/trellis/modules/sparse/conv/__init__.py
new file mode 100755
index 0000000..340a871
--- /dev/null
+++ b/trellis/modules/sparse/conv/__init__.py
@@ -0,0 +1,21 @@
+from .. import BACKEND
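+
+# NOTE: 'auto' lets spconv benchmark convolution algorithms on first use; per the
+# upstream TRELLIS README (not part of this commit), SPCONV_ALGO=native avoids
+# that startup benchmarking for single runs.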
import BACKEND + + +SPCONV_ALGO = 'auto' # 'auto', 'implicit_gemm', 'native' + +def __from_env(): + import os + + global SPCONV_ALGO + env_spconv_algo = os.environ.get('SPCONV_ALGO') + if env_spconv_algo is not None and env_spconv_algo in ['auto', 'implicit_gemm', 'native']: + SPCONV_ALGO = env_spconv_algo + print(f"[SPARSE][CONV] spconv algo: {SPCONV_ALGO}") + + +__from_env() + +if BACKEND == 'torchsparse': + from .conv_torchsparse import * +elif BACKEND == 'spconv': + from .conv_spconv import * diff --git a/trellis/modules/sparse/transformer/__init__.py b/trellis/modules/sparse/transformer/__init__.py new file mode 100644 index 0000000..b08b0d4 --- /dev/null +++ b/trellis/modules/sparse/transformer/__init__.py @@ -0,0 +1,2 @@ +from .blocks import * +from .modulated import * \ No newline at end of file diff --git a/trellis/modules/sparse/transformer/blocks.py b/trellis/modules/sparse/transformer/blocks.py new file mode 100644 index 0000000..9d037a4 --- /dev/null +++ b/trellis/modules/sparse/transformer/blocks.py @@ -0,0 +1,151 @@ +from typing import * +import torch +import torch.nn as nn +from ..basic import SparseTensor +from ..linear import SparseLinear +from ..nonlinearity import SparseGELU +from ..attention import SparseMultiHeadAttention, SerializeMode +from ...norm import LayerNorm32 + + +class SparseFeedForwardNet(nn.Module): + def __init__(self, channels: int, mlp_ratio: float = 4.0): + super().__init__() + self.mlp = nn.Sequential( + SparseLinear(channels, int(channels * mlp_ratio)), + SparseGELU(approximate="tanh"), + SparseLinear(int(channels * mlp_ratio), channels), + ) + + def forward(self, x: SparseTensor) -> SparseTensor: + return self.mlp(x) + + +class SparseTransformerBlock(nn.Module): + """ + Sparse Transformer block (MSA + FFN). + """ + def __init__( + self, + channels: int, + num_heads: int, + mlp_ratio: float = 4.0, + attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "full", + window_size: Optional[int] = None, + shift_sequence: Optional[int] = None, + shift_window: Optional[Tuple[int, int, int]] = None, + serialize_mode: Optional[SerializeMode] = None, + use_checkpoint: bool = False, + use_rope: bool = False, + qk_rms_norm: bool = False, + qkv_bias: bool = True, + ln_affine: bool = False, + ): + super().__init__() + self.use_checkpoint = use_checkpoint + self.norm1 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) + self.norm2 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) + self.attn = SparseMultiHeadAttention( + channels, + num_heads=num_heads, + attn_mode=attn_mode, + window_size=window_size, + shift_sequence=shift_sequence, + shift_window=shift_window, + serialize_mode=serialize_mode, + qkv_bias=qkv_bias, + use_rope=use_rope, + qk_rms_norm=qk_rms_norm, + ) + self.mlp = SparseFeedForwardNet( + channels, + mlp_ratio=mlp_ratio, + ) + + def _forward(self, x: SparseTensor) -> SparseTensor: + h = x.replace(self.norm1(x.feats)) + h = self.attn(h) + x = x + h + h = x.replace(self.norm2(x.feats)) + h = self.mlp(h) + x = x + h + return x + + def forward(self, x: SparseTensor) -> SparseTensor: + if self.use_checkpoint: + return torch.utils.checkpoint.checkpoint(self._forward, x, use_reentrant=False) + else: + return self._forward(x) + + +class SparseTransformerCrossBlock(nn.Module): + """ + Sparse Transformer cross-attention block (MSA + MCA + FFN). 
+ """ + def __init__( + self, + channels: int, + ctx_channels: int, + num_heads: int, + mlp_ratio: float = 4.0, + attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "full", + window_size: Optional[int] = None, + shift_sequence: Optional[int] = None, + shift_window: Optional[Tuple[int, int, int]] = None, + serialize_mode: Optional[SerializeMode] = None, + use_checkpoint: bool = False, + use_rope: bool = False, + qk_rms_norm: bool = False, + qk_rms_norm_cross: bool = False, + qkv_bias: bool = True, + ln_affine: bool = False, + ): + super().__init__() + self.use_checkpoint = use_checkpoint + self.norm1 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) + self.norm2 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) + self.norm3 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) + self.self_attn = SparseMultiHeadAttention( + channels, + num_heads=num_heads, + type="self", + attn_mode=attn_mode, + window_size=window_size, + shift_sequence=shift_sequence, + shift_window=shift_window, + serialize_mode=serialize_mode, + qkv_bias=qkv_bias, + use_rope=use_rope, + qk_rms_norm=qk_rms_norm, + ) + self.cross_attn = SparseMultiHeadAttention( + channels, + ctx_channels=ctx_channels, + num_heads=num_heads, + type="cross", + attn_mode="full", + qkv_bias=qkv_bias, + qk_rms_norm=qk_rms_norm_cross, + ) + self.mlp = SparseFeedForwardNet( + channels, + mlp_ratio=mlp_ratio, + ) + + def _forward(self, x: SparseTensor, mod: torch.Tensor, context: torch.Tensor): + h = x.replace(self.norm1(x.feats)) + h = self.self_attn(h) + x = x + h + h = x.replace(self.norm2(x.feats)) + h = self.cross_attn(h, context) + x = x + h + h = x.replace(self.norm3(x.feats)) + h = self.mlp(h) + x = x + h + return x + + def forward(self, x: SparseTensor, context: torch.Tensor): + if self.use_checkpoint: + return torch.utils.checkpoint.checkpoint(self._forward, x, context, use_reentrant=False) + else: + return self._forward(x, context) diff --git a/trellis/modules/transformer/__init__.py b/trellis/modules/transformer/__init__.py new file mode 100644 index 0000000..b08b0d4 --- /dev/null +++ b/trellis/modules/transformer/__init__.py @@ -0,0 +1,2 @@ +from .blocks import * +from .modulated import * \ No newline at end of file diff --git a/trellis/modules/transformer/blocks.py b/trellis/modules/transformer/blocks.py new file mode 100644 index 0000000..c37eb7e --- /dev/null +++ b/trellis/modules/transformer/blocks.py @@ -0,0 +1,182 @@ +from typing import * +import torch +import torch.nn as nn +from ..attention import MultiHeadAttention +from ..norm import LayerNorm32 + + +class AbsolutePositionEmbedder(nn.Module): + """ + Embeds spatial positions into vector representations. + """ + def __init__(self, channels: int, in_channels: int = 3): + super().__init__() + self.channels = channels + self.in_channels = in_channels + self.freq_dim = channels // in_channels // 2 + self.freqs = torch.arange(self.freq_dim, dtype=torch.float32) / self.freq_dim + self.freqs = 1.0 / (10000 ** self.freqs) + + def _sin_cos_embedding(self, x: torch.Tensor) -> torch.Tensor: + """ + Create sinusoidal position embeddings. + + Args: + x: a 1-D Tensor of N indices + + Returns: + an (N, D) Tensor of positional embeddings. 
+ """ + self.freqs = self.freqs.to(x.device) + out = torch.outer(x, self.freqs) + out = torch.cat([torch.sin(out), torch.cos(out)], dim=-1) + return out + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + x (torch.Tensor): (N, D) tensor of spatial positions + """ + N, D = x.shape + assert D == self.in_channels, "Input dimension must match number of input channels" + embed = self._sin_cos_embedding(x.reshape(-1)) + embed = embed.reshape(N, -1) + if embed.shape[1] < self.channels: + embed = torch.cat([embed, torch.zeros(N, self.channels - embed.shape[1], device=embed.device)], dim=-1) + return embed + + +class FeedForwardNet(nn.Module): + def __init__(self, channels: int, mlp_ratio: float = 4.0): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(channels, int(channels * mlp_ratio)), + nn.GELU(approximate="tanh"), + nn.Linear(int(channels * mlp_ratio), channels), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.mlp(x) + + +class TransformerBlock(nn.Module): + """ + Transformer block (MSA + FFN). + """ + def __init__( + self, + channels: int, + num_heads: int, + mlp_ratio: float = 4.0, + attn_mode: Literal["full", "windowed"] = "full", + window_size: Optional[int] = None, + shift_window: Optional[int] = None, + use_checkpoint: bool = False, + use_rope: bool = False, + qk_rms_norm: bool = False, + qkv_bias: bool = True, + ln_affine: bool = False, + ): + super().__init__() + self.use_checkpoint = use_checkpoint + self.norm1 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) + self.norm2 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) + self.attn = MultiHeadAttention( + channels, + num_heads=num_heads, + attn_mode=attn_mode, + window_size=window_size, + shift_window=shift_window, + qkv_bias=qkv_bias, + use_rope=use_rope, + qk_rms_norm=qk_rms_norm, + ) + self.mlp = FeedForwardNet( + channels, + mlp_ratio=mlp_ratio, + ) + + def _forward(self, x: torch.Tensor) -> torch.Tensor: + h = self.norm1(x) + h = self.attn(h) + x = x + h + h = self.norm2(x) + h = self.mlp(h) + x = x + h + return x + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.use_checkpoint: + return torch.utils.checkpoint.checkpoint(self._forward, x, use_reentrant=False) + else: + return self._forward(x) + + +class TransformerCrossBlock(nn.Module): + """ + Transformer cross-attention block (MSA + MCA + FFN). 
+ """ + def __init__( + self, + channels: int, + ctx_channels: int, + num_heads: int, + mlp_ratio: float = 4.0, + attn_mode: Literal["full", "windowed"] = "full", + window_size: Optional[int] = None, + shift_window: Optional[Tuple[int, int, int]] = None, + use_checkpoint: bool = False, + use_rope: bool = False, + qk_rms_norm: bool = False, + qk_rms_norm_cross: bool = False, + qkv_bias: bool = True, + ln_affine: bool = False, + ): + super().__init__() + self.use_checkpoint = use_checkpoint + self.norm1 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) + self.norm2 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) + self.norm3 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) + self.self_attn = MultiHeadAttention( + channels, + num_heads=num_heads, + type="self", + attn_mode=attn_mode, + window_size=window_size, + shift_window=shift_window, + qkv_bias=qkv_bias, + use_rope=use_rope, + qk_rms_norm=qk_rms_norm, + ) + self.cross_attn = MultiHeadAttention( + channels, + ctx_channels=ctx_channels, + num_heads=num_heads, + type="cross", + attn_mode="full", + qkv_bias=qkv_bias, + qk_rms_norm=qk_rms_norm_cross, + ) + self.mlp = FeedForwardNet( + channels, + mlp_ratio=mlp_ratio, + ) + + def _forward(self, x: torch.Tensor, context: torch.Tensor): + h = self.norm1(x) + h = self.self_attn(h) + x = x + h + h = self.norm2(x) + h = self.cross_attn(h, context) + x = x + h + h = self.norm3(x) + h = self.mlp(h) + x = x + h + return x + + def forward(self, x: torch.Tensor, context: torch.Tensor): + if self.use_checkpoint: + return torch.utils.checkpoint.checkpoint(self._forward, x, context, use_reentrant=False) + else: + return self._forward(x, context) + \ No newline at end of file diff --git a/trellis/pipelines/__init__.py b/trellis/pipelines/__init__.py new file mode 100644 index 0000000..0e33c6e --- /dev/null +++ b/trellis/pipelines/__init__.py @@ -0,0 +1,25 @@ +from . import samplers +from .trellis_image_to_3d import TrellisImageTo3DPipeline +from .trellis_text_to_3d import TrellisTextTo3DPipeline + + +def from_pretrained(path: str): + """ + Load a pipeline from a model folder or a Hugging Face model hub. + + Args: + path: The path to the model. Can be either local path or a Hugging Face model name. + """ + import os + import json + is_local = os.path.exists(f"{path}/pipeline.json") + + if is_local: + config_file = f"{path}/pipeline.json" + else: + from huggingface_hub import hf_hub_download + config_file = hf_hub_download(path, "pipeline.json") + + with open(config_file, 'r') as f: + config = json.load(f) + return globals()[config['name']].from_pretrained(path) diff --git a/trellis/pipelines/base.py b/trellis/pipelines/base.py new file mode 100644 index 0000000..9d214e4 --- /dev/null +++ b/trellis/pipelines/base.py @@ -0,0 +1,68 @@ +from typing import * +import torch +import torch.nn as nn +from .. import models + + +class Pipeline: + """ + A base class for pipelines. + """ + def __init__( + self, + models: dict[str, nn.Module] = None, + ): + if models is None: + return + self.models = models + for model in self.models.values(): + model.eval() + + @staticmethod + def from_pretrained(path: str) -> "Pipeline": + """ + Load a pretrained model. 
+ """ + import os + import json + is_local = os.path.exists(f"{path}/pipeline.json") + + if is_local: + config_file = f"{path}/pipeline.json" + else: + from huggingface_hub import hf_hub_download + config_file = hf_hub_download(path, "pipeline.json") + + with open(config_file, 'r') as f: + args = json.load(f)['args'] + + _models = {} + for k, v in args['models'].items(): + try: + _models[k] = models.from_pretrained(f"{path}/{v}") + except: + _models[k] = models.from_pretrained(v) + + new_pipeline = Pipeline(_models) + new_pipeline._pretrained_args = args + return new_pipeline + + @property + def device(self) -> torch.device: + for model in self.models.values(): + if hasattr(model, 'device'): + return model.device + for model in self.models.values(): + if hasattr(model, 'parameters'): + return next(model.parameters()).device + raise RuntimeError("No device found.") + + def to(self, device: torch.device) -> None: + for model in self.models.values(): + model.to(device) + + def cuda(self) -> None: + self.to(torch.device("cuda")) + + def cpu(self) -> None: + self.to(torch.device("cpu")) diff --git a/trellis/pipelines/samplers/__init__.py b/trellis/pipelines/samplers/__init__.py new file mode 100755 index 0000000..54d412f --- /dev/null +++ b/trellis/pipelines/samplers/__init__.py @@ -0,0 +1,2 @@ +from .base import Sampler +from .flow_euler import FlowEulerSampler, FlowEulerCfgSampler, FlowEulerGuidanceIntervalSampler \ No newline at end of file diff --git a/trellis/pipelines/samplers/base.py b/trellis/pipelines/samplers/base.py new file mode 100644 index 0000000..1966ce7 --- /dev/null +++ b/trellis/pipelines/samplers/base.py @@ -0,0 +1,20 @@ +from typing import * +from abc import ABC, abstractmethod + + +class Sampler(ABC): + """ + A base class for samplers. + """ + + @abstractmethod + def sample( + self, + model, + **kwargs + ): + """ + Sample from a model. + """ + pass + \ No newline at end of file diff --git a/trellis/pipelines/samplers/classifier_free_guidance_mixin.py b/trellis/pipelines/samplers/classifier_free_guidance_mixin.py new file mode 100644 index 0000000..5701b25 --- /dev/null +++ b/trellis/pipelines/samplers/classifier_free_guidance_mixin.py @@ -0,0 +1,12 @@ +from typing import * + + +class ClassifierFreeGuidanceSamplerMixin: + """ + A mixin class for samplers that apply classifier-free guidance. 
+ """ + + def _inference_model(self, model, x_t, t, cond, neg_cond, cfg_strength, **kwargs): + pred = super()._inference_model(model, x_t, t, cond, **kwargs) + neg_pred = super()._inference_model(model, x_t, t, neg_cond, **kwargs) + return (1 + cfg_strength) * pred - cfg_strength * neg_pred diff --git a/trellis/renderers/__init__.py b/trellis/renderers/__init__.py new file mode 100755 index 0000000..0339355 --- /dev/null +++ b/trellis/renderers/__init__.py @@ -0,0 +1,31 @@ +import importlib + +__attributes = { + 'OctreeRenderer': 'octree_renderer', + 'GaussianRenderer': 'gaussian_render', + 'MeshRenderer': 'mesh_renderer', +} + +__submodules = [] + +__all__ = list(__attributes.keys()) + __submodules + +def __getattr__(name): + if name not in globals(): + if name in __attributes: + module_name = __attributes[name] + module = importlib.import_module(f".{module_name}", __name__) + globals()[name] = getattr(module, name) + elif name in __submodules: + module = importlib.import_module(f".{name}", __name__) + globals()[name] = module + else: + raise AttributeError(f"module {__name__} has no attribute {name}") + return globals()[name] + + +# For Pylance +if __name__ == '__main__': + from .octree_renderer import OctreeRenderer + from .gaussian_render import GaussianRenderer + from .mesh_renderer import MeshRenderer \ No newline at end of file diff --git a/trellis/representations/__init__.py b/trellis/representations/__init__.py new file mode 100755 index 0000000..549ffdb --- /dev/null +++ b/trellis/representations/__init__.py @@ -0,0 +1,4 @@ +from .radiance_field import Strivec +from .octree import DfsOctree as Octree +from .gaussian import Gaussian +from .mesh import MeshExtractResult diff --git a/trellis/representations/gaussian/__init__.py b/trellis/representations/gaussian/__init__.py new file mode 100755 index 0000000..e3de6e1 --- /dev/null +++ b/trellis/representations/gaussian/__init__.py @@ -0,0 +1 @@ +from .gaussian_model import Gaussian \ No newline at end of file diff --git a/trellis/representations/mesh/__init__.py b/trellis/representations/mesh/__init__.py new file mode 100644 index 0000000..38cf35c --- /dev/null +++ b/trellis/representations/mesh/__init__.py @@ -0,0 +1 @@ +from .cube2mesh import SparseFeatures2Mesh, MeshExtractResult diff --git a/trellis/representations/octree/__init__.py b/trellis/representations/octree/__init__.py new file mode 100755 index 0000000..f66a39a --- /dev/null +++ b/trellis/representations/octree/__init__.py @@ -0,0 +1 @@ +from .octree_dfs import DfsOctree \ No newline at end of file diff --git a/trellis/representations/radiance_field/__init__.py b/trellis/representations/radiance_field/__init__.py new file mode 100755 index 0000000..b72a1b7 --- /dev/null +++ b/trellis/representations/radiance_field/__init__.py @@ -0,0 +1 @@ +from .strivec import Strivec \ No newline at end of file diff --git a/trellis/trainers/__init__.py b/trellis/trainers/__init__.py new file mode 100644 index 0000000..29b54f9 --- /dev/null +++ b/trellis/trainers/__init__.py @@ -0,0 +1,63 @@ +import importlib + +__attributes = { + 'BasicTrainer': 'basic', + + 'SparseStructureVaeTrainer': 'vae.sparse_structure_vae', + + 'SLatVaeGaussianTrainer': 'vae.structured_latent_vae_gaussian', + 'SLatVaeRadianceFieldDecoderTrainer': 'vae.structured_latent_vae_rf_dec', + 'SLatVaeMeshDecoderTrainer': 'vae.structured_latent_vae_mesh_dec', + + 'FlowMatchingTrainer': 'flow_matching.flow_matching', + 'FlowMatchingCFGTrainer': 'flow_matching.flow_matching', + 'TextConditionedFlowMatchingCFGTrainer': 
'flow_matching.flow_matching', + 'ImageConditionedFlowMatchingCFGTrainer': 'flow_matching.flow_matching', + + 'SparseFlowMatchingTrainer': 'flow_matching.sparse_flow_matching', + 'SparseFlowMatchingCFGTrainer': 'flow_matching.sparse_flow_matching', + 'TextConditionedSparseFlowMatchingCFGTrainer': 'flow_matching.sparse_flow_matching', + 'ImageConditionedSparseFlowMatchingCFGTrainer': 'flow_matching.sparse_flow_matching', +} + +__submodules = [] + +__all__ = list(__attributes.keys()) + __submodules + +def __getattr__(name): + if name not in globals(): + if name in __attributes: + module_name = __attributes[name] + module = importlib.import_module(f".{module_name}", __name__) + globals()[name] = getattr(module, name) + elif name in __submodules: + module = importlib.import_module(f".{name}", __name__) + globals()[name] = module + else: + raise AttributeError(f"module {__name__} has no attribute {name}") + return globals()[name] + + +# For Pylance +if __name__ == '__main__': + from .basic import BasicTrainer + + from .vae.sparse_structure_vae import SparseStructureVaeTrainer + + from .vae.structured_latent_vae_gaussian import SLatVaeGaussianTrainer + from .vae.structured_latent_vae_rf_dec import SLatVaeRadianceFieldDecoderTrainer + from .vae.structured_latent_vae_mesh_dec import SLatVaeMeshDecoderTrainer + + from .flow_matching.flow_matching import ( + FlowMatchingTrainer, + FlowMatchingCFGTrainer, + TextConditionedFlowMatchingCFGTrainer, + ImageConditionedFlowMatchingCFGTrainer, + ) + + from .flow_matching.sparse_flow_matching import ( + SparseFlowMatchingTrainer, + SparseFlowMatchingCFGTrainer, + TextConditionedSparseFlowMatchingCFGTrainer, + ImageConditionedSparseFlowMatchingCFGTrainer, + ) diff --git a/trellis/trainers/base.py b/trellis/trainers/base.py new file mode 100644 index 0000000..15463a0 --- /dev/null +++ b/trellis/trainers/base.py @@ -0,0 +1,451 @@ +from abc import abstractmethod +import os +import time +import json + +import torch +import torch.distributed as dist +from torch.utils.data import DataLoader +import numpy as np + +from torchvision import utils +from torch.utils.tensorboard import SummaryWriter + +from .utils import * +from ..utils.general_utils import * +from ..utils.data_utils import recursive_to_device, cycle, ResumableSampler + + +class Trainer: + """ + Base class for training. + """ + def __init__(self, + models, + dataset, + *, + output_dir, + load_dir, + step, + max_steps, + batch_size=None, + batch_size_per_gpu=None, + batch_split=None, + optimizer={}, + lr_scheduler=None, + elastic=None, + grad_clip=None, + ema_rate=0.9999, + fp16_mode='inflat_all', + fp16_scale_growth=1e-3, + finetune_ckpt=None, + log_param_stats=False, + prefetch_data=True, + i_print=1000, + i_log=500, + i_sample=10000, + i_save=10000, + i_ddpcheck=10000, + **kwargs + ): + assert batch_size is not None or batch_size_per_gpu is not None, 'Either batch_size or batch_size_per_gpu must be specified.' 
+ + self.models = models + self.dataset = dataset + self.batch_split = batch_split if batch_split is not None else 1 + self.max_steps = max_steps + self.optimizer_config = optimizer + self.lr_scheduler_config = lr_scheduler + self.elastic_controller_config = elastic + self.grad_clip = grad_clip + self.ema_rate = [ema_rate] if isinstance(ema_rate, float) else ema_rate + self.fp16_mode = fp16_mode + self.fp16_scale_growth = fp16_scale_growth + self.log_param_stats = log_param_stats + self.prefetch_data = prefetch_data + if self.prefetch_data: + self._data_prefetched = None + + self.output_dir = output_dir + self.i_print = i_print + self.i_log = i_log + self.i_sample = i_sample + self.i_save = i_save + self.i_ddpcheck = i_ddpcheck + + if dist.is_initialized(): + # Multi-GPU params + self.world_size = dist.get_world_size() + self.rank = dist.get_rank() + self.local_rank = dist.get_rank() % torch.cuda.device_count() + self.is_master = self.rank == 0 + else: + # Single-GPU params + self.world_size = 1 + self.rank = 0 + self.local_rank = 0 + self.is_master = True + + self.batch_size = batch_size if batch_size_per_gpu is None else batch_size_per_gpu * self.world_size + self.batch_size_per_gpu = batch_size_per_gpu if batch_size_per_gpu is not None else batch_size // self.world_size + assert self.batch_size % self.world_size == 0, 'Batch size must be divisible by the number of GPUs.' + assert self.batch_size_per_gpu % self.batch_split == 0, 'Batch size per GPU must be divisible by batch split.' + + self.init_models_and_more(**kwargs) + self.prepare_dataloader(**kwargs) + + # Load checkpoint + self.step = 0 + if load_dir is not None and step is not None: + self.load(load_dir, step) + elif finetune_ckpt is not None: + self.finetune_from(finetune_ckpt) + + if self.is_master: + os.makedirs(os.path.join(self.output_dir, 'ckpts'), exist_ok=True) + os.makedirs(os.path.join(self.output_dir, 'samples'), exist_ok=True) + self.writer = SummaryWriter(os.path.join(self.output_dir, 'tb_logs')) + + if self.world_size > 1: + self.check_ddp() + + if self.is_master: + print('\n\nTrainer initialized.') + print(self) + + @property + def device(self): + for _, model in self.models.items(): + if hasattr(model, 'device'): + return model.device + return next(list(self.models.values())[0].parameters()).device + + @abstractmethod + def init_models_and_more(self, **kwargs): + """ + Initialize models and more. + """ + pass + + def prepare_dataloader(self, **kwargs): + """ + Prepare dataloader. + """ + self.data_sampler = ResumableSampler( + self.dataset, + shuffle=True, + ) + self.dataloader = DataLoader( + self.dataset, + batch_size=self.batch_size_per_gpu, + num_workers=int(np.ceil(os.cpu_count() / torch.cuda.device_count())), + pin_memory=True, + drop_last=True, + persistent_workers=True, + collate_fn=self.dataset.collate_fn if hasattr(self.dataset, 'collate_fn') else None, + sampler=self.data_sampler, + ) + self.data_iterator = cycle(self.dataloader) + + @abstractmethod + def load(self, load_dir, step=0): + """ + Load a checkpoint. + Should be called by all processes. + """ + pass + + @abstractmethod + def save(self): + """ + Save a checkpoint. + Should be called only by the rank 0 process. + """ + pass + + @abstractmethod + def finetune_from(self, finetune_ckpt): + """ + Finetune from a checkpoint. + Should be called by all processes. + """ + pass + + @abstractmethod + def run_snapshot(self, num_samples, batch_size=4, verbose=False, **kwargs): + """ + Run a snapshot of the model. 
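+
+        Implementations should return a dict mapping names to entries of the
+        form {'value': tensor, 'type': 'sample' | 'image' | 'number'}, which is
+        the structure consumed by snapshot() below.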
+ """ + pass + + @torch.no_grad() + def visualize_sample(self, sample): + """ + Convert a sample to an image. + """ + if hasattr(self.dataset, 'visualize_sample'): + return self.dataset.visualize_sample(sample) + else: + return sample + + @torch.no_grad() + def snapshot_dataset(self, num_samples=100): + """ + Sample images from the dataset. + """ + dataloader = torch.utils.data.DataLoader( + self.dataset, + batch_size=num_samples, + num_workers=0, + shuffle=True, + collate_fn=self.dataset.collate_fn if hasattr(self.dataset, 'collate_fn') else None, + ) + data = next(iter(dataloader)) + data = recursive_to_device(data, self.device) + vis = self.visualize_sample(data) + if isinstance(vis, dict): + save_cfg = [(f'dataset_{k}', v) for k, v in vis.items()] + else: + save_cfg = [('dataset', vis)] + for name, image in save_cfg: + utils.save_image( + image, + os.path.join(self.output_dir, 'samples', f'{name}.jpg'), + nrow=int(np.sqrt(num_samples)), + normalize=True, + value_range=self.dataset.value_range, + ) + + @torch.no_grad() + def snapshot(self, suffix=None, num_samples=64, batch_size=4, verbose=False): + """ + Sample images from the model. + NOTE: This function should be called by all processes. + """ + if self.is_master: + print(f'\nSampling {num_samples} images...', end='') + + if suffix is None: + suffix = f'step{self.step:07d}' + + # Assign tasks + num_samples_per_process = int(np.ceil(num_samples / self.world_size)) + samples = self.run_snapshot(num_samples_per_process, batch_size=batch_size, verbose=verbose) + + # Preprocess images + for key in list(samples.keys()): + if samples[key]['type'] == 'sample': + vis = self.visualize_sample(samples[key]['value']) + if isinstance(vis, dict): + for k, v in vis.items(): + samples[f'{key}_{k}'] = {'value': v, 'type': 'image'} + del samples[key] + else: + samples[key] = {'value': vis, 'type': 'image'} + + # Gather results + if self.world_size > 1: + for key in samples.keys(): + samples[key]['value'] = samples[key]['value'].contiguous() + if self.is_master: + all_images = [torch.empty_like(samples[key]['value']) for _ in range(self.world_size)] + else: + all_images = [] + dist.gather(samples[key]['value'], all_images, dst=0) + if self.is_master: + samples[key]['value'] = torch.cat(all_images, dim=0)[:num_samples] + + # Save images + if self.is_master: + os.makedirs(os.path.join(self.output_dir, 'samples', suffix), exist_ok=True) + for key in samples.keys(): + if samples[key]['type'] == 'image': + utils.save_image( + samples[key]['value'], + os.path.join(self.output_dir, 'samples', suffix, f'{key}_{suffix}.jpg'), + nrow=int(np.sqrt(num_samples)), + normalize=True, + value_range=self.dataset.value_range, + ) + elif samples[key]['type'] == 'number': + min = samples[key]['value'].min() + max = samples[key]['value'].max() + images = (samples[key]['value'] - min) / (max - min) + images = utils.make_grid( + images, + nrow=int(np.sqrt(num_samples)), + normalize=False, + ) + save_image_with_notes( + images, + os.path.join(self.output_dir, 'samples', suffix, f'{key}_{suffix}.jpg'), + notes=f'{key} min: {min}, max: {max}', + ) + + if self.is_master: + print(' Done.') + + @abstractmethod + def update_ema(self): + """ + Update exponential moving average. + Should only be called by the rank 0 process. + """ + pass + + @abstractmethod + def check_ddp(self): + """ + Check if DDP is working properly. + Should be called by all process. + """ + pass + + @abstractmethod + def training_losses(**mb_data): + """ + Compute training losses. 
+ """ + pass + + def load_data(self): + """ + Load data. + """ + if self.prefetch_data: + if self._data_prefetched is None: + self._data_prefetched = recursive_to_device(next(self.data_iterator), self.device, non_blocking=True) + data = self._data_prefetched + self._data_prefetched = recursive_to_device(next(self.data_iterator), self.device, non_blocking=True) + else: + data = recursive_to_device(next(self.data_iterator), self.device, non_blocking=True) + + # if the data is a dict, we need to split it into multiple dicts with batch_size_per_gpu + if isinstance(data, dict): + if self.batch_split == 1: + data_list = [data] + else: + batch_size = list(data.values())[0].shape[0] + data_list = [ + {k: v[i * batch_size // self.batch_split:(i + 1) * batch_size // self.batch_split] for k, v in data.items()} + for i in range(self.batch_split) + ] + elif isinstance(data, list): + data_list = data + else: + raise ValueError('Data must be a dict or a list of dicts.') + + return data_list + + @abstractmethod + def run_step(self, data_list): + """ + Run a training step. + """ + pass + + def run(self): + """ + Run training. + """ + if self.is_master: + print('\nStarting training...') + self.snapshot_dataset() + if self.step == 0: + self.snapshot(suffix='init') + else: # resume + self.snapshot(suffix=f'resume_step{self.step:07d}') + + log = [] + time_last_print = 0.0 + time_elapsed = 0.0 + while self.step < self.max_steps: + time_start = time.time() + + data_list = self.load_data() + step_log = self.run_step(data_list) + + time_end = time.time() + time_elapsed += time_end - time_start + + self.step += 1 + + # Print progress + if self.is_master and self.step % self.i_print == 0: + speed = self.i_print / (time_elapsed - time_last_print) * 3600 + columns = [ + f'Step: {self.step}/{self.max_steps} ({self.step / self.max_steps * 100:.2f}%)', + f'Elapsed: {time_elapsed / 3600:.2f} h', + f'Speed: {speed:.2f} steps/h', + f'ETA: {(self.max_steps - self.step) / speed:.2f} h', + ] + print(' | '.join([c.ljust(25) for c in columns]), flush=True) + time_last_print = time_elapsed + + # Check ddp + if self.world_size > 1 and self.i_ddpcheck is not None and self.step % self.i_ddpcheck == 0: + self.check_ddp() + + # Sample images + if self.step % self.i_sample == 0: + self.snapshot() + + if self.is_master: + log.append((self.step, {})) + + # Log time + log[-1][1]['time'] = { + 'step': time_end - time_start, + 'elapsed': time_elapsed, + } + + # Log losses + if step_log is not None: + log[-1][1].update(step_log) + + # Log scale + if self.fp16_mode == 'amp': + log[-1][1]['scale'] = self.scaler.get_scale() + elif self.fp16_mode == 'inflat_all': + log[-1][1]['log_scale'] = self.log_scale + + # Save log + if self.step % self.i_log == 0: + ## save to log file + log_str = '\n'.join([ + f'{step}: {json.dumps(log)}' for step, log in log + ]) + with open(os.path.join(self.output_dir, 'log.txt'), 'a') as log_file: + log_file.write(log_str + '\n') + + # show with mlflow + log_show = [l for _, l in log if not dict_any(l, lambda x: np.isnan(x))] + log_show = dict_reduce(log_show, lambda x: np.mean(x)) + log_show = dict_flatten(log_show, sep='/') + for key, value in log_show.items(): + self.writer.add_scalar(key, value, self.step) + log = [] + + # Save checkpoint + if self.step % self.i_save == 0: + self.save() + + if self.is_master: + self.snapshot(suffix='final') + self.writer.close() + print('Training finished.') + + def profile(self, wait=2, warmup=3, active=5): + """ + Profile the training loop. 
+ """ + with torch.profiler.profile( + schedule=torch.profiler.schedule(wait=wait, warmup=warmup, active=active, repeat=1), + on_trace_ready=torch.profiler.tensorboard_trace_handler(os.path.join(self.output_dir, 'profile')), + profile_memory=True, + with_stack=True, + ) as prof: + for _ in range(wait + warmup + active): + self.run_step() + prof.step() + \ No newline at end of file diff --git a/trellis/trainers/basic.py b/trellis/trainers/basic.py new file mode 100644 index 0000000..f5cf304 --- /dev/null +++ b/trellis/trainers/basic.py @@ -0,0 +1,438 @@ +import os +import copy +from functools import partial +from contextlib import nullcontext + +import torch +import torch.distributed as dist +from torch.nn.parallel import DistributedDataParallel as DDP +import numpy as np + +from .utils import * +from .base import Trainer +from ..utils.general_utils import * +from ..utils.dist_utils import * +from ..utils import grad_clip_utils, elastic_utils + + +class BasicTrainer(Trainer): + """ + Trainer for basic training loop. + + Args: + models (dict[str, nn.Module]): Models to train. + dataset (torch.utils.data.Dataset): Dataset. + output_dir (str): Output directory. + load_dir (str): Load directory. + step (int): Step to load. + batch_size (int): Batch size. + batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored. + batch_split (int): Split batch with gradient accumulation. + max_steps (int): Max steps. + optimizer (dict): Optimizer config. + lr_scheduler (dict): Learning rate scheduler config. + elastic (dict): Elastic memory management config. + grad_clip (float or dict): Gradient clip config. + ema_rate (float or list): Exponential moving average rates. + fp16_mode (str): FP16 mode. + - None: No FP16. + - 'inflat_all': Hold a inflated fp32 master param for all params. + - 'amp': Automatic mixed precision. + fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation. + finetune_ckpt (dict): Finetune checkpoint. + log_param_stats (bool): Log parameter stats. + i_print (int): Print interval. + i_log (int): Log interval. + i_sample (int): Sample interval. + i_save (int): Save interval. + i_ddpcheck (int): DDP check interval. 
+ """ + + def __str__(self): + lines = [] + lines.append(self.__class__.__name__) + lines.append(f' - Models:') + for name, model in self.models.items(): + lines.append(f' - {name}: {model.__class__.__name__}') + lines.append(f' - Dataset: {indent(str(self.dataset), 2)}') + lines.append(f' - Dataloader:') + lines.append(f' - Sampler: {self.dataloader.sampler.__class__.__name__}') + lines.append(f' - Num workers: {self.dataloader.num_workers}') + lines.append(f' - Number of steps: {self.max_steps}') + lines.append(f' - Number of GPUs: {self.world_size}') + lines.append(f' - Batch size: {self.batch_size}') + lines.append(f' - Batch size per GPU: {self.batch_size_per_gpu}') + lines.append(f' - Batch split: {self.batch_split}') + lines.append(f' - Optimizer: {self.optimizer.__class__.__name__}') + lines.append(f' - Learning rate: {self.optimizer.param_groups[0]["lr"]}') + if self.lr_scheduler_config is not None: + lines.append(f' - LR scheduler: {self.lr_scheduler.__class__.__name__}') + if self.elastic_controller_config is not None: + lines.append(f' - Elastic memory: {indent(str(self.elastic_controller), 2)}') + if self.grad_clip is not None: + lines.append(f' - Gradient clip: {indent(str(self.grad_clip), 2)}') + lines.append(f' - EMA rate: {self.ema_rate}') + lines.append(f' - FP16 mode: {self.fp16_mode}') + return '\n'.join(lines) + + def init_models_and_more(self, **kwargs): + """ + Initialize models and more. + """ + if self.world_size > 1: + # Prepare distributed data parallel + self.training_models = { + name: DDP( + model, + device_ids=[self.local_rank], + output_device=self.local_rank, + bucket_cap_mb=128, + find_unused_parameters=False + ) + for name, model in self.models.items() + } + else: + self.training_models = self.models + + # Build master params + self.model_params = sum( + [[p for p in model.parameters() if p.requires_grad] for model in self.models.values()] + , []) + if self.fp16_mode == 'amp': + self.master_params = self.model_params + self.scaler = torch.GradScaler() if self.fp16_mode == 'amp' else None + elif self.fp16_mode == 'inflat_all': + self.master_params = make_master_params(self.model_params) + self.fp16_scale_growth = self.fp16_scale_growth + self.log_scale = 20.0 + elif self.fp16_mode is None: + self.master_params = self.model_params + else: + raise NotImplementedError(f'FP16 mode {self.fp16_mode} is not implemented.') + + # Build EMA params + if self.is_master: + self.ema_params = [copy.deepcopy(self.master_params) for _ in self.ema_rate] + + # Initialize optimizer + if hasattr(torch.optim, self.optimizer_config['name']): + self.optimizer = getattr(torch.optim, self.optimizer_config['name'])(self.master_params, **self.optimizer_config['args']) + else: + self.optimizer = globals()[self.optimizer_config['name']](self.master_params, **self.optimizer_config['args']) + + # Initalize learning rate scheduler + if self.lr_scheduler_config is not None: + if hasattr(torch.optim.lr_scheduler, self.lr_scheduler_config['name']): + self.lr_scheduler = getattr(torch.optim.lr_scheduler, self.lr_scheduler_config['name'])(self.optimizer, **self.lr_scheduler_config['args']) + else: + self.lr_scheduler = globals()[self.lr_scheduler_config['name']](self.optimizer, **self.lr_scheduler_config['args']) + + # Initialize elastic memory controller + if self.elastic_controller_config is not None: + assert any([isinstance(model, (elastic_utils.ElasticModule, elastic_utils.ElasticModuleMixin)) for model in self.models.values()]), \ + 'No elastic module found in models, please inherit from 
ElasticModule or ElasticModuleMixin' + self.elastic_controller = getattr(elastic_utils, self.elastic_controller_config['name'])(**self.elastic_controller_config['args']) + for model in self.models.values(): + if isinstance(model, (elastic_utils.ElasticModule, elastic_utils.ElasticModuleMixin)): + model.register_memory_controller(self.elastic_controller) + + # Initialize gradient clipper + if self.grad_clip is not None: + if isinstance(self.grad_clip, (float, int)): + self.grad_clip = float(self.grad_clip) + else: + self.grad_clip = getattr(grad_clip_utils, self.grad_clip['name'])(**self.grad_clip['args']) + + def _master_params_to_state_dicts(self, master_params): + """ + Convert master params to dict of state_dicts. + """ + if self.fp16_mode == 'inflat_all': + master_params = unflatten_master_params(self.model_params, master_params) + state_dicts = {name: model.state_dict() for name, model in self.models.items()} + master_params_names = sum( + [[(name, n) for n, p in model.named_parameters() if p.requires_grad] for name, model in self.models.items()] + , []) + for i, (model_name, param_name) in enumerate(master_params_names): + state_dicts[model_name][param_name] = master_params[i] + return state_dicts + + def _state_dicts_to_master_params(self, master_params, state_dicts): + """ + Convert a state_dict to master params. + """ + master_params_names = sum( + [[(name, n) for n, p in model.named_parameters() if p.requires_grad] for name, model in self.models.items()] + , []) + params = [state_dicts[name][param_name] for name, param_name in master_params_names] + if self.fp16_mode == 'inflat_all': + model_params_to_master_params(params, master_params) + else: + for i, param in enumerate(params): + master_params[i].data.copy_(param.data) + + def load(self, load_dir, step=0): + """ + Load a checkpoint. + Should be called by all processes. 
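+
+        Expects the checkpoints written by save() under load_dir/ckpts:
+        {name}_step{step:07d}.pt, {name}_ema{rate}_step{step:07d}.pt and
+        misc_step{step:07d}.pt.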
+ """ + if self.is_master: + print(f'\nLoading checkpoint from step {step}...', end='') + + model_ckpts = {} + for name, model in self.models.items(): + model_ckpt = torch.load(read_file_dist(os.path.join(load_dir, 'ckpts', f'{name}_step{step:07d}.pt')), map_location=self.device, weights_only=True) + model_ckpts[name] = model_ckpt + model.load_state_dict(model_ckpt) + if self.fp16_mode == 'inflat_all': + model.convert_to_fp16() + self._state_dicts_to_master_params(self.master_params, model_ckpts) + del model_ckpts + + if self.is_master: + for i, ema_rate in enumerate(self.ema_rate): + ema_ckpts = {} + for name, model in self.models.items(): + ema_ckpt = torch.load(os.path.join(load_dir, 'ckpts', f'{name}_ema{ema_rate}_step{step:07d}.pt'), map_location=self.device, weights_only=True) + ema_ckpts[name] = ema_ckpt + self._state_dicts_to_master_params(self.ema_params[i], ema_ckpts) + del ema_ckpts + + misc_ckpt = torch.load(read_file_dist(os.path.join(load_dir, 'ckpts', f'misc_step{step:07d}.pt')), map_location=torch.device('cpu'), weights_only=False) + self.optimizer.load_state_dict(misc_ckpt['optimizer']) + self.step = misc_ckpt['step'] + self.data_sampler.load_state_dict(misc_ckpt['data_sampler']) + if self.fp16_mode == 'amp': + self.scaler.load_state_dict(misc_ckpt['scaler']) + elif self.fp16_mode == 'inflat_all': + self.log_scale = misc_ckpt['log_scale'] + if self.lr_scheduler_config is not None: + self.lr_scheduler.load_state_dict(misc_ckpt['lr_scheduler']) + if self.elastic_controller_config is not None: + self.elastic_controller.load_state_dict(misc_ckpt['elastic_controller']) + if self.grad_clip is not None and not isinstance(self.grad_clip, float): + self.grad_clip.load_state_dict(misc_ckpt['grad_clip']) + del misc_ckpt + + if self.world_size > 1: + dist.barrier() + if self.is_master: + print(' Done.') + + if self.world_size > 1: + self.check_ddp() + + def save(self): + """ + Save a checkpoint. + Should be called only by the rank 0 process. + """ + assert self.is_master, 'save() should be called only by the rank 0 process.' + print(f'\nSaving checkpoint at step {self.step}...', end='') + + model_ckpts = self._master_params_to_state_dicts(self.master_params) + for name, model_ckpt in model_ckpts.items(): + torch.save(model_ckpt, os.path.join(self.output_dir, 'ckpts', f'{name}_step{self.step:07d}.pt')) + + for i, ema_rate in enumerate(self.ema_rate): + ema_ckpts = self._master_params_to_state_dicts(self.ema_params[i]) + for name, ema_ckpt in ema_ckpts.items(): + torch.save(ema_ckpt, os.path.join(self.output_dir, 'ckpts', f'{name}_ema{ema_rate}_step{self.step:07d}.pt')) + + misc_ckpt = { + 'optimizer': self.optimizer.state_dict(), + 'step': self.step, + 'data_sampler': self.data_sampler.state_dict(), + } + if self.fp16_mode == 'amp': + misc_ckpt['scaler'] = self.scaler.state_dict() + elif self.fp16_mode == 'inflat_all': + misc_ckpt['log_scale'] = self.log_scale + if self.lr_scheduler_config is not None: + misc_ckpt['lr_scheduler'] = self.lr_scheduler.state_dict() + if self.elastic_controller_config is not None: + misc_ckpt['elastic_controller'] = self.elastic_controller.state_dict() + if self.grad_clip is not None and not isinstance(self.grad_clip, float): + misc_ckpt['grad_clip'] = self.grad_clip.state_dict() + torch.save(misc_ckpt, os.path.join(self.output_dir, 'ckpts', f'misc_step{self.step:07d}.pt')) + print(' Done.') + + def finetune_from(self, finetune_ckpt): + """ + Finetune from a checkpoint. + Should be called by all processes. 
+ """ + if self.is_master: + print('\nFinetuning from:') + for name, path in finetune_ckpt.items(): + print(f' - {name}: {path}') + + model_ckpts = {} + for name, model in self.models.items(): + model_state_dict = model.state_dict() + if name in finetune_ckpt: + model_ckpt = torch.load(read_file_dist(finetune_ckpt[name]), map_location=self.device, weights_only=True) + for k, v in model_ckpt.items(): + if model_ckpt[k].shape != model_state_dict[k].shape: + if self.is_master: + print(f'Warning: {k} shape mismatch, {model_ckpt[k].shape} vs {model_state_dict[k].shape}, skipped.') + model_ckpt[k] = model_state_dict[k] + model_ckpts[name] = model_ckpt + model.load_state_dict(model_ckpt) + if self.fp16_mode == 'inflat_all': + model.convert_to_fp16() + else: + if self.is_master: + print(f'Warning: {name} not found in finetune_ckpt, skipped.') + model_ckpts[name] = model_state_dict + self._state_dicts_to_master_params(self.master_params, model_ckpts) + del model_ckpts + + if self.world_size > 1: + dist.barrier() + if self.is_master: + print('Done.') + + if self.world_size > 1: + self.check_ddp() + + def update_ema(self): + """ + Update exponential moving average. + Should only be called by the rank 0 process. + """ + assert self.is_master, 'update_ema() should be called only by the rank 0 process.' + for i, ema_rate in enumerate(self.ema_rate): + for master_param, ema_param in zip(self.master_params, self.ema_params[i]): + ema_param.detach().mul_(ema_rate).add_(master_param, alpha=1.0 - ema_rate) + + def check_ddp(self): + """ + Check if DDP is working properly. + Should be called by all process. + """ + if self.is_master: + print('\nPerforming DDP check...') + + if self.is_master: + print('Checking if parameters are consistent across processes...') + dist.barrier() + try: + for p in self.master_params: + # split to avoid OOM + for i in range(0, p.numel(), 10000000): + sub_size = min(10000000, p.numel() - i) + sub_p = p.detach().view(-1)[i:i+sub_size] + # gather from all processes + sub_p_gather = [torch.empty_like(sub_p) for _ in range(self.world_size)] + dist.all_gather(sub_p_gather, sub_p) + # check if equal + assert all([torch.equal(sub_p, sub_p_gather[i]) for i in range(self.world_size)]), 'parameters are not consistent across processes' + except AssertionError as e: + if self.is_master: + print(f'\n\033[91mError: {e}\033[0m') + print('DDP check failed.') + raise e + + dist.barrier() + if self.is_master: + print('Done.') + + def run_step(self, data_list): + """ + Run a training step. 
+ """ + step_log = {'loss': {}, 'status': {}} + amp_context = partial(torch.autocast, device_type='cuda') if self.fp16_mode == 'amp' else nullcontext + elastic_controller_context = self.elastic_controller.record if self.elastic_controller_config is not None else nullcontext + + # Train + losses = [] + statuses = [] + elastic_controller_logs = [] + zero_grad(self.model_params) + for i, mb_data in enumerate(data_list): + ## sync at the end of each batch split + sync_contexts = [self.training_models[name].no_sync for name in self.training_models] if i != len(data_list) - 1 and self.world_size > 1 else [nullcontext] + with nested_contexts(*sync_contexts), elastic_controller_context(): + with amp_context(): + loss, status = self.training_losses(**mb_data) + l = loss['loss'] / len(data_list) + ## backward + if self.fp16_mode == 'amp': + self.scaler.scale(l).backward() + elif self.fp16_mode == 'inflat_all': + scaled_l = l * (2 ** self.log_scale) + scaled_l.backward() + else: + l.backward() + ## log + losses.append(dict_foreach(loss, lambda x: x.item() if isinstance(x, torch.Tensor) else x)) + statuses.append(dict_foreach(status, lambda x: x.item() if isinstance(x, torch.Tensor) else x)) + if self.elastic_controller_config is not None: + elastic_controller_logs.append(self.elastic_controller.log()) + ## gradient clip + if self.grad_clip is not None: + if self.fp16_mode == 'amp': + self.scaler.unscale_(self.optimizer) + elif self.fp16_mode == 'inflat_all': + model_grads_to_master_grads(self.model_params, self.master_params) + self.master_params[0].grad.mul_(1.0 / (2 ** self.log_scale)) + if isinstance(self.grad_clip, float): + grad_norm = torch.nn.utils.clip_grad_norm_(self.master_params, self.grad_clip) + else: + grad_norm = self.grad_clip(self.master_params) + if torch.isfinite(grad_norm): + statuses[-1]['grad_norm'] = grad_norm.item() + ## step + if self.fp16_mode == 'amp': + prev_scale = self.scaler.get_scale() + self.scaler.step(self.optimizer) + self.scaler.update() + elif self.fp16_mode == 'inflat_all': + prev_scale = 2 ** self.log_scale + if not any(not p.grad.isfinite().all() for p in self.model_params): + if self.grad_clip is None: + model_grads_to_master_grads(self.model_params, self.master_params) + self.master_params[0].grad.mul_(1.0 / (2 ** self.log_scale)) + self.optimizer.step() + master_params_to_model_params(self.model_params, self.master_params) + self.log_scale += self.fp16_scale_growth + else: + self.log_scale -= 1 + else: + prev_scale = 1.0 + if not any(not p.grad.isfinite().all() for p in self.model_params): + self.optimizer.step() + else: + print('\n\033[93mWarning: NaN detected in gradients. 
Skipping update.\033[0m')
+            ## adjust learning rate
+            if self.lr_scheduler_config is not None:
+                statuses[-1]['lr'] = self.lr_scheduler.get_last_lr()[0]
+                self.lr_scheduler.step()
+
+        # Logs
+        step_log['loss'] = dict_reduce(losses, lambda x: np.mean(x))
+        step_log['status'] = dict_reduce(statuses, lambda x: np.mean(x), special_func={'min': lambda x: np.min(x), 'max': lambda x: np.max(x)})
+        if self.elastic_controller_config is not None:
+            step_log['elastic'] = dict_reduce(elastic_controller_logs, lambda x: np.mean(x))
+        if self.grad_clip is not None:
+            step_log['grad_clip'] = self.grad_clip if isinstance(self.grad_clip, float) else self.grad_clip.log()
+
+        # Check grad and norm of each param
+        if self.log_param_stats:
+            param_norms = {}
+            param_grads = {}
+            # log stats for every trainable parameter across all models
+            for model_name, model in self.models.items():
+                for name, param in model.named_parameters():
+                    if param.requires_grad:
+                        param_norms[f'{model_name}.{name}'] = param.norm().item()
+                        if param.grad is not None and torch.isfinite(param.grad).all():
+                            param_grads[f'{model_name}.{name}'] = param.grad.norm().item() / prev_scale
+            step_log['param_norms'] = param_norms
+            step_log['param_grads'] = param_grads
+
+        # Update exponential moving average
+        if self.is_master:
+            self.update_ema()
+
+        return step_log
diff --git a/trellis/trainers/flow_matching/mixins/classifier_free_guidance.py b/trellis/trainers/flow_matching/mixins/classifier_free_guidance.py
new file mode 100644
index 0000000..548e007
--- /dev/null
+++ b/trellis/trainers/flow_matching/mixins/classifier_free_guidance.py
@@ -0,0 +1,59 @@
+import torch
+import numpy as np
+from ....utils.general_utils import dict_foreach
+from ....pipelines import samplers
+
+
+class ClassifierFreeGuidanceMixin:
+    def __init__(self, *args, p_uncond: float = 0.1, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.p_uncond = p_uncond
+
+    def get_cond(self, cond, neg_cond=None, **kwargs):
+        """
+        Get the conditioning data.
+        """
+        assert neg_cond is not None, "neg_cond must be provided for classifier-free guidance"
+
+        if self.p_uncond > 0:
+            # randomly drop the class label
+            def get_batch_size(cond):
+                if isinstance(cond, torch.Tensor):
+                    return cond.shape[0]
+                elif isinstance(cond, list):
+                    return len(cond)
+                else:
+                    raise ValueError(f"Unsupported type of cond: {type(cond)}")
+
+            ref_cond = cond if not isinstance(cond, dict) else cond[list(cond.keys())[0]]
+            B = get_batch_size(ref_cond)
+
+            def select(cond, neg_cond, mask):
+                if isinstance(cond, torch.Tensor):
+                    mask = torch.tensor(mask, device=cond.device).reshape(-1, *[1] * (cond.ndim - 1))
+                    return torch.where(mask, neg_cond, cond)
+                elif isinstance(cond, list):
+                    return [nc if m else c for c, nc, m in zip(cond, neg_cond, mask)]
+                else:
+                    raise ValueError(f"Unsupported type of cond: {type(cond)}")
+
+            mask = list(np.random.rand(B) < self.p_uncond)
+            if not isinstance(cond, dict):
+                cond = select(cond, neg_cond, mask)
+            else:
+                cond = dict_foreach([cond, neg_cond], lambda x: select(x[0], x[1], mask))
+
+        return cond
+
+    def get_inference_cond(self, cond, neg_cond=None, **kwargs):
+        """
+        Get the conditioning data for inference.
+        """
+        assert neg_cond is not None, "neg_cond must be provided for classifier-free guidance"
+        return {'cond': cond, 'neg_cond': neg_cond, **kwargs}
+
+    def get_sampler(self, **kwargs) -> samplers.FlowEulerCfgSampler:
+        """
+        Get the sampler for the diffusion process.
+        """
+        return samplers.FlowEulerCfgSampler(self.sigma_min)
diff --git a/trellis/utils/__init__.py b/trellis/utils/__init__.py
new file mode 100755
index 0000000..e69de29
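
A quick sanity check of the conditioning dropout in ClassifierFreeGuidanceMixin
(a minimal sketch; `trainer` stands in for any trainer class that uses the mixin):

    import torch
    cond = torch.randn(8, 77, 1024)     # e.g. per-sample text embeddings
    neg_cond = torch.zeros_like(cond)   # unconditional embeddings
    mixed = trainer.get_cond(cond, neg_cond=neg_cond)
    # with the default p_uncond = 0.1, each sample in `mixed` is replaced by
    # its neg_cond row with probability 0.1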