diff --git a/dataset_toolkits/datasets/HSSD.py b/dataset_toolkits/datasets/HSSD.py
new file mode 100644
index 0000000..465e6a1
--- /dev/null
+++ b/dataset_toolkits/datasets/HSSD.py
@@ -0,0 +1,103 @@
+import os
+import re
+import argparse
+import tarfile
+from concurrent.futures import ThreadPoolExecutor
+from tqdm import tqdm
+import pandas as pd
+import huggingface_hub
+from utils import get_file_hash
+
+
+def add_args(parser: argparse.ArgumentParser):
+    pass
+
+
+def get_metadata(**kwargs):
+    metadata = pd.read_csv("hf://datasets/JeffreyXiang/TRELLIS-500K/HSSD.csv")
+    return metadata
+
+
+def download(metadata, output_dir, **kwargs):
+    os.makedirs(os.path.join(output_dir, 'raw'), exist_ok=True)
+
+    # check login
+    try:
+        huggingface_hub.whoami()
+    except Exception:
+        print("\033[93m")
+        print("You are not logged in to the Hugging Face Hub.")
+        print("Visit https://huggingface.co/settings/tokens to get a token.")
+        print("\033[0m")
+        huggingface_hub.login()
+
+    try:
+        huggingface_hub.hf_hub_download(repo_id="hssd/hssd-models", filename="README.md", repo_type="dataset")
+    except Exception:
+        print("\033[93m")
+        print("Error downloading the HSSD dataset.")
+        print("Check if you have access to the HSSD dataset.")
+        print("Visit https://huggingface.co/datasets/hssd/hssd-models for more information.")
+        print("\033[0m")
+
+    downloaded = {}
+    metadata = metadata.set_index("file_identifier")
+    with ThreadPoolExecutor(max_workers=os.cpu_count()) as executor, \
+        tqdm(total=len(metadata), desc="Downloading") as pbar:
+        def worker(instance: str) -> str:
+            try:
+                huggingface_hub.hf_hub_download(repo_id="hssd/hssd-models", filename=instance, repo_type="dataset", local_dir=os.path.join(output_dir, 'raw'))
+                sha256 = get_file_hash(os.path.join(output_dir, 'raw', instance))
+                pbar.update()
+                return sha256
+            except Exception as e:
+                pbar.update()
+                print(f"Error downloading {instance}: {e}")
+                return None
+
+        sha256s = executor.map(worker, metadata.index)
+        executor.shutdown(wait=True)
+
+    for k, sha256 in zip(metadata.index, sha256s):
+        if sha256 is not None:
+            if sha256 == metadata.loc[k, "sha256"]:
+                downloaded[sha256] = os.path.join('raw', k)
+            else:
+                print(f"Error downloading {k}: SHA-256 checksums do not match")
+
+    return pd.DataFrame(downloaded.items(), columns=['sha256', 'local_path'])
+
+
+def foreach_instance(metadata, output_dir, func, max_workers=None, desc='Processing objects') -> pd.DataFrame:
+    import os
+    from concurrent.futures import ThreadPoolExecutor
+    from tqdm import tqdm
+
+    # load metadata
+    metadata = metadata.to_dict('records')
+
+    # processing objects
+    records = []
+    max_workers = max_workers or os.cpu_count()
+    try:
+        with ThreadPoolExecutor(max_workers=max_workers) as executor, \
+            tqdm(total=len(metadata), desc=desc) as pbar:
+            def worker(metadatum):
+                try:
+                    local_path = metadatum['local_path']
+                    sha256 = metadatum['sha256']
+                    file = os.path.join(output_dir, local_path)
+                    record = func(file, sha256)
+                    if record is not None:
+                        records.append(record)
+                    pbar.update()
+                except Exception as e:
+                    print(f"Error processing object {metadatum['sha256']}: {e}")
+                    pbar.update()
+
+            executor.map(worker, metadata)
+            executor.shutdown(wait=True)
+    except Exception as e:
+        print(f"Error happened during processing: {e}")
+
+    return pd.DataFrame.from_records(records)
diff --git a/glb2svg.py b/glb2svg.py
new file mode 100644
index 0000000..e3f7e7e
--- /dev/null
+++ b/glb2svg.py
@@ -0,0 +1,555 @@
+import math
+import secrets
+from collections import defaultdict
+from datetime import datetime
+
+import os
+import subprocess
+from pathlib import Path
+import logging
+from typing import Optional
+
+import imageio
+import numpy as np
+import torch
+from PIL import Image
+
+# ---------- PyTorch3D ----------
+from pytorch3d.structures import Meshes
+from pytorch3d.renderer import (
+    RasterizationSettings, MeshRasterizer,
+    MeshRenderer, BlendParams, SoftSilhouetteShader,
+    look_at_view_transform, OrthographicCameras
+)
+from pytorch3d.renderer.mesh.rasterizer import Fragments
+
+# ---------- SVG / Image ----------
+import svgwrite
+from skimage import measure
+
+# ---------- pythonocc (OCC) ----------
+from OCC.Core.STEPControl import STEPControl_Reader
+from OCC.Core.TopAbs import TopAbs_FACE
+from OCC.Core.TopExp import TopExp_Explorer
+from OCC.Core.BRep import BRep_Tool
+from OCC.Core.BRepMesh import BRepMesh_IncrementalMesh
+from OCC.Core.TopLoc import TopLoc_Location
+from OCC.Core.ShapeFix import ShapeFix_Shape
+
+# ---------- Combine SVG ----------
+import svgutils.transform as st
+from svgutils.compose import Unit
+import cairosvg
+
+from trellis.pipelines import TrellisImageTo3DPipeline
+from trellis.utils import render_utils, postprocessing_utils
+
+logger = logging.getLogger(__name__)
+
+"""single image to 3D"""
+
+
+def generate_unique_name(original_name: str) -> str:
+    stem, ext = os.path.splitext(original_name)
+    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+    random_part = secrets.token_hex(4)  # 8 random hex characters
+    return f"{stem}_{timestamp}_{random_part}{ext}"
+
+
+def image_to_3d(image: list[str], out_dir: str = "trellis_out", single_image: bool = False, seed: int = 1,
+                steps_sparse: int = 12, cfg_sparse: float = 7.5, steps_slat: int = 12,
+                cfg_slat: float = 3.0, simplify: float = 0.95, texture_size: int = 1024,
+                save_video: bool = False, fps: int = 30, video_gs_name: str = "sample_gs.mp4", video_rf_name: str = "sample_rf.mp4", video_mesh_name: str = "sample_mesh.mp4"):
+    os.makedirs(out_dir, exist_ok=True)
+    # Optional env
+    os.environ.setdefault("SPCONV_ALGO", "native")
+    pipeline = TrellisImageTo3DPipeline.from_pretrained("microsoft/TRELLIS-image-large")
+    pipeline.cuda()
+
+    if single_image:
+        image = Image.open(image[0])
+        outputs = pipeline.run(
+            image,
+            seed=seed,
+            sparse_structure_sampler_params={
+                "steps": steps_sparse,
+                "cfg_strength": cfg_sparse,
+            },
+            slat_sampler_params={
+                "steps": steps_slat,
+                "cfg_strength": cfg_slat,
+            },
+        )
+    else:
+        images = [Image.open(p) for p in image]
+        outputs = pipeline.run_multi_image(
+            images,
+            seed=seed,
+            sparse_structure_sampler_params={
+                "steps": steps_sparse,
+                "cfg_strength": cfg_sparse,
+            },
+            slat_sampler_params={
+                "steps": steps_slat,
+                "cfg_strength": cfg_slat,
+            },
+        )
+
+    if save_video:
+        video = render_utils.render_video(outputs["gaussian"][0])["color"]
+        imageio.mimsave(os.path.join(out_dir, video_gs_name), video, fps=fps)
+
+        video = render_utils.render_video(outputs["radiance_field"][0])["color"]
+        imageio.mimsave(os.path.join(out_dir, video_rf_name), video, fps=fps)
+
+        video = render_utils.render_video(outputs["mesh"][0])["normal"]
+        imageio.mimsave(os.path.join(out_dir, video_mesh_name), video, fps=fps)
+
+    glb_path = os.path.join(out_dir, generate_unique_name("sample.glb"))
+    ply_path = os.path.join(out_dir, generate_unique_name("sample.ply"))
+
+    glb = postprocessing_utils.to_glb(
+        outputs["gaussian"][0],
+        outputs["mesh"][0],
+        simplify=simplify,
+        texture_size=texture_size,
+    )
+
+    # Export the mesh and the Gaussian point cloud to the generated paths
+    glb.export(glb_path)
+    outputs["gaussian"][0].save_ply(ply_path)
+
+    return glb_path, ply_path
+
+
+"""
+glb_to_obj
+"""
+
+
+def glb_to_obj(glb_input, obj_output_dir="obj_output"):
+    obj_path = os.path.join(obj_output_dir, generate_unique_name("sample.obj"))
+    os.makedirs(obj_output_dir, exist_ok=True)
+
+    cmd = [
+        "blender",
+        "-b",
+        "-P",
+        "0_glb_to_obj.py",
+        "--",
+        "--input",
+        glb_input,
+        "--output",
+        obj_path
+    ]
+
+    result = subprocess.run(cmd, capture_output=True, text=True)
+    print(result)
+    print("\n")
+    if result.returncode != 0:
+        raise RuntimeError(
+            f"Blender failed\nstdout:\n{result.stdout}\nstderr:\n{result.stderr}"
+        )
+
+    return obj_path
+
+
+"""
+obj_to_step
+"""
+
+
+def obj_to_step(
+    input_obj: str | Path,
+    output_dir: str | Path,
+    script_path: str | Path = "1_obj_to_step.py",
+    timeout: int = 120,  # seconds; adjust to the actual conversion time
+    freecadcmd_path: Optional[str] = None
+) -> str | bool:
+    """
+    Wrap the FreeCAD command that converts an OBJ file to STEP.
+    Returns the last token of FreeCAD's stdout (the STEP path) on success, False on failure.
+    """
+    input_obj = Path(input_obj).resolve()
+    output_dir = Path(output_dir).resolve()
+    script_path = Path(script_path).resolve()
+
+    print(f"input_obj : {input_obj}")
+    print(f"output_dir : {output_dir}")
+    print(f"script_path : {script_path}")
+
+    if not input_obj.exists():
+        raise FileNotFoundError(f"Input file does not exist: {input_obj}")
+    if not script_path.exists():
+        raise FileNotFoundError(f"Conversion script does not exist: {script_path}")
+
+    # Build the -c payload (replicates the original freecadcmd invocation exactly)
+    python_code = f'''import sys
+sys.argv = ["{script_path.name}", "{input_obj}", "{output_dir}"]
+exec(open("{script_path}", "r", encoding="utf-8").read())'''
+
+    cmd = [
+        freecadcmd_path or "freecadcmd",  # pass an absolute path if freecadcmd is not on PATH
+        "-c",
+        python_code
+    ]
+
+    try:
+        result = subprocess.run(
+            cmd,
+            capture_output=True,
+            text=True,
+            check=True,  # raises CalledProcessError on a non-zero exit code
+            timeout=timeout,
+            cwd=script_path.parent  # so the script can resolve relative paths
+        )
+
+        logger.info(f"FreeCAD conversion succeeded: {input_obj.name} → {output_dir}")
+        print(f"FreeCAD conversion succeeded: {input_obj.name} → {output_dir}")
+        if result.stdout:
+            logger.debug(result.stdout)
+        return result.stdout.split()[-1]
+
+    except subprocess.TimeoutExpired:
+        logger.error(f"FreeCAD conversion timed out (>{timeout}s): {input_obj}")
+        return False
+    except subprocess.CalledProcessError as e:
+        logger.error(f"FreeCAD execution failed: {e.stderr}")
+        return False
+    except Exception as e:
+        logger.exception("Unexpected error")
+        return False
+
+
+"""
+step to svg
+"""
+
+
+def autodevice():
+    return torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+
+def read_step_solid(path: str):
+    reader = STEPControl_Reader()
+    stat = reader.ReadFile(path)
+    if stat != 1:
+        raise RuntimeError(f"Failed to read STEP file: {path}")
+    reader.TransferRoots()
+    shape = reader.OneShape()
+    fixer = ShapeFix_Shape(shape)
+    fixer.Perform()
+    return fixer.Shape()
+
+
+def triangulate(shape, deflection=0.5, angle=0.5):
+    """OCC triangulation: returns (V: np.float32 N×3, F: np.int64 M×3)."""
+    BRepMesh_IncrementalMesh(shape, deflection, False, angle, True)
+    verts, faces, v_off = [], [], 0
+    exp = TopExp_Explorer(shape, TopAbs_FACE)
+    while exp.More():
+        face = exp.Current()
+        loc = TopLoc_Location()
+        tri = BRep_Tool.Triangulation(face, loc)
+        if tri is not None:
+            nb_nodes = tri.NbNodes()
+            has_nodes_arr = hasattr(tri, "Nodes")
+            for i in range(1, nb_nodes + 1):
+                p = tri.Nodes().Value(i) if has_nodes_arr else tri.Node(i)
+                p = p.Transformed(loc.Transformation())
+                verts.append([p.X(), p.Y(), p.Z()])
+            nb_tris = tri.NbTriangles()
+            has_tris_arr = hasattr(tri, "Triangles")
+            for i in range(1, nb_tris + 1):
+                t = tri.Triangles().Value(i) if has_tris_arr else tri.Triangle(i)
+                a, b, c = t.Get()
+                faces.append([v_off + a - 1, v_off + b - 1, v_off + c - 1])
+            v_off += nb_nodes
+        exp.Next()
+    if not verts or not faces:
+        raise RuntimeError("Triangulation is empty; try reducing --defl")
+    return np.asarray(verts, np.float32), np.asarray(faces, np.int64)
+
+
+def normalize_mesh_np(verts: np.ndarray, unit_scale=1.0):
+    c = verts.mean(axis=0, keepdims=True)
+    v0 = verts - c
+    s = np.max(np.abs(v0))
+    s = max(s, 1e-12)
+    return v0 / s * unit_scale
+
+
+def build_mesh(V_np, F_np, device):
+    V = torch.from_numpy(V_np).to(device)
+    F = torch.from_numpy(F_np).to(device)
+    return Meshes(verts=[V], faces=[F])
+
+
+def compute_face_normals(verts: torch.Tensor, faces: torch.Tensor):
+    v0 = verts[faces[:, 0]]
+    v1 = verts[faces[:, 1]]
+    v2 = verts[faces[:, 2]]
+    n = torch.cross(v1 - v0, v2 - v0, dim=1)
+    return torch.nn.functional.normalize(n, dim=1, eps=1e-12)
+
+
+def extract_feature_edges(faces: torch.Tensor, face_normals: torch.Tensor, angle_deg: float):
+    """Edges with a dihedral angle >= angle_deg count as feature edges; boundary edges are always kept."""
+    thr = math.cos(math.radians(max(0.0, angle_deg)))
+    e2f = defaultdict(list)
+    F = faces.shape[0]
+    for f in range(F):
+        i, j, k = faces[f].tolist()
+        for a, b in ((i, j), (j, k), (k, i)):
+            e = (a, b) if a < b else (b, a)
+            e2f[e].append(f)
+    feat = []
+    for e, fl in e2f.items():
+        if len(fl) == 1:
+            feat.append(e)
+        elif len(fl) == 2:
+            n0, n1 = face_normals[fl[0]], face_normals[fl[1]]
+            cosv = torch.clamp((n0 * n1).sum(), -1.0, 1.0).item()
+            if cosv <= thr:
+                feat.append(e)
+    return feat
+
+
+def make_camera(view: str, device, dist=2.7):
+    """
+    look_at_view_transform spherical params for engineering views.
+    """
+    if view == "top":
+        # from +Z to origin (view -Z)
+        R, T = look_at_view_transform(dist=dist, elev=90.0, azim=180.0, device=device)
+    elif view == "left":
+        # from -X to origin (view +X)
+        R, T = look_at_view_transform(dist=dist, elev=0.0, azim=270.0, device=device)
+    elif view == "front":
+        # from +Y to origin (view -Y)
+        R, T = look_at_view_transform(dist=dist, elev=0.0, azim=180.0, device=device)
+    else:
+        raise ValueError("view must be one of ['top','left','front']")
+    return OrthographicCameras(R=R, T=T, device=device)
+
+
+def render_silhouette(mesh: Meshes, cameras, image_size=1600):
+    rs = RasterizationSettings(
+        image_size=image_size,
+        blur_radius=1e-6,
+        faces_per_pixel=50,
+        cull_backfaces=True
+    )
+    renderer = MeshRenderer(
+        rasterizer=MeshRasterizer(cameras=cameras, raster_settings=rs),
+        shader=SoftSilhouetteShader(blend_params=BlendParams(sigma=1e-4, gamma=1e-4))
+    )
+    with torch.no_grad():
+        img = renderer(mesh, cameras=cameras)  # (1,H,W,4)
+    return np.clip(img[0, ..., 3].detach().cpu().numpy(), 0.0, 1.0)
+
+
+def raster_fragments(mesh: Meshes, cameras, image_size=1600, faces_per_pixel=1):
+    rs = RasterizationSettings(
+        image_size=image_size,
+        blur_radius=0.0,
+        faces_per_pixel=faces_per_pixel,
+        cull_backfaces=True
+    )
+    rast = MeshRasterizer(cameras=cameras, raster_settings=rs)
+    with torch.no_grad():
+        frags: Fragments = rast(mesh, cameras=cameras)
+    return frags, frags.zbuf[0, ..., 0]
+
+
+def project_points_to_screen(cameras, points_world: torch.Tensor, image_size: int):
+    with torch.no_grad():
+        scr = cameras.transform_points_screen(
+            points_world[None, ...],
+            image_size=((image_size, image_size),)
+        )
+    return scr[0]
+
+
+def edge_visibility_split(edges_idx, verts_world, cameras, zbuf, image_size, eps=1e-4):
+    H = W = image_size
+    vis, hid = [], []
+    scr = project_points_to_screen(cameras, verts_world, image_size)
+    zmin = zbuf.detach().cpu().numpy()
+    for i, j in edges_idx:
+        p0, p1 = scr[i], scr[j]
+        m = 0.5 * (p0 + p1)
+        x = int(np.clip(m[0].item(), 0, W - 1))
+        y = int(np.clip(m[1].item(), 0, H - 1))
+        z_proj = m[2].item()
+        z_ref = zmin[y, x]
+        seg = (p0[0].item(), p0[1].item(), p1[0].item(), p1[1].item())
+        if np.isfinite(z_ref) and (z_proj <= z_ref + eps):
+            vis.append(seg)
+        else:
+            hid.append(seg)
+    return vis, hid
+
+
+def trace_silhouettes(alpha: np.ndarray, threshold=0.5, step=1):
+    contours = measure.find_contours(alpha, level=threshold)
+    polys = []
+    for cnt in contours:
+        if len(cnt) < 8:
+            continue
+        cnt = cnt[::max(1, step)]
+        polys.append(np.stack([cnt[:, 1], cnt[:, 0]], axis=1))  # (x,y)
+    return polys
+
+
+def svg_from_view(out_svg, W, H, silhouettes, edges_vis, edges_hid, stroke=2.0, margin=10, fill="none"):
+    dwg = svgwrite.Drawing(out_svg, size=(W + 2 * margin, H + 2 * margin))
+    dwg.add(dwg.rect(insert=(0, 0), size=(W + 2 * margin, H + 2 * margin), fill="none"))
+
+    # silhouette fill (optional)
+    for poly in silhouettes:
+        if len(poly) < 3:
+            continue
+        path = [f"M {poly[0, 0] + margin:.2f} {poly[0, 1] + margin:.2f}"]
+        for i in range(1, len(poly)):
+            path.append(f"L {poly[i, 0] + margin:.2f} {poly[i, 1] + margin:.2f}")
+        path.append("Z")
+        dwg.add(dwg.path(" ".join(path), fill=fill, stroke="none"))
+
+    # visible edges
+    for (x0, y0, x1, y1) in edges_vis:
+        dwg.add(dwg.line(
+            (x0 + margin, y0 + margin),
+            (x1 + margin, y1 + margin),
+            stroke="black",
+            stroke_width=stroke
+        ))
+
+    # hidden edges
+    for (x0, y0, x1, y1) in edges_hid:
+        dwg.add(dwg.line(
+            (x0 + margin, y0 + margin),
+            (x1 + margin, y1 + margin),
+            stroke="black",
+            stroke_width=max(1.0, 0.75 * stroke),
+            stroke_dasharray=[6, 6],
+            opacity=0.85
+        ))
+
+    dwg.save()
+
+
+def combine_svgs(view1_path, view2_path, view3_path, output_svg, output_image):
+    """
+    Concatenate the three view SVGs horizontally and additionally export a PNG.
+    """
+    out_dir = os.path.dirname(output_svg)
+    if out_dir:
+        os.makedirs(out_dir, exist_ok=True)
+
+    view1 = st.fromfile(view1_path)
+    view2 = st.fromfile(view2_path)
+    view3 = st.fromfile(view3_path)
+
+    r1 = view1.getroot()
+    r2 = view2.getroot()
+    r3 = view3.getroot()
+
+    w1, h1 = [float(x.replace("px", "")) for x in view1.get_size()]
+    w2, h2 = [float(x.replace("px", "")) for x in view2.get_size()]
+    w3, h3 = [float(x.replace("px", "")) for x in view3.get_size()]
+
+    max_w = max(w1, w2, w3)
+    max_h = max(h1, h2, h3)
+
+    combined_width = max_w * 3 + 40
+    combined_height = max_h + 20
+
+    combined = st.SVGFigure(Unit(combined_width), Unit(combined_height))
+
+    x_offset = 10
+    y_offset = 10
+
+    r1.moveto(x_offset, y_offset)
+    x_offset += max_w + 20
+    r2.moveto(x_offset, y_offset)
+    x_offset += max_w + 20
+    r3.moveto(x_offset, y_offset)
+
+    combined.append([r1, r2, r3])
+    combined.save(output_svg)
+
+    # SVG -> PNG
+    cairosvg.svg2png(url=output_svg, write_to=output_image)
+
+
+def step_to_svg(step_path, out_dir,
+                res: int = 1400, defl: float = 1.5,
+                ang: float = 65.0, stroke: float = 1.3,
+                scale: float = 1.0, no_combine: bool = False,
+                combined_svg: str = None, combined_png: str = None):
+    os.makedirs(out_dir, exist_ok=True)
+    device = autodevice()
+    print(f"[Device] {device}")
+
+    # STEP -> mesh
+    shape = read_step_solid(step_path)
+    V_np, F_np = triangulate(shape, deflection=defl)
+    V_np = normalize_mesh_np(V_np, unit_scale=scale)
+
+    mesh = build_mesh(V_np, F_np, device)
+    verts = mesh.verts_packed()
+    faces = mesh.faces_packed()
+    fn = compute_face_normals(verts, faces)
+    feat_edges = extract_feature_edges(faces, fn, angle_deg=ang)
+
+    view_paths = {}
+    for name in ["top", "left", "front"]:
+        print(f"[View] {name}")
+        cam = make_camera(name, device)
+        alpha = render_silhouette(mesh, cam, image_size=res)
+        _, zbuf = raster_fragments(mesh, cam, image_size=res, faces_per_pixel=1)
+
+        e_vis, e_hid = edge_visibility_split(feat_edges, verts, cam, zbuf, res, eps=1e-4)
+        polys = trace_silhouettes(alpha, threshold=0.5, step=1)
+
+        out_svg = os.path.join(out_dir, f"{name}.svg")
+        svg_from_view(out_svg, res, res, polys, e_vis, e_hid,
+                      stroke=stroke, margin=10, fill="none")
+        view_paths[name] = out_svg
+        print(f" -> {out_svg} (edges: vis={len(e_vis)}, hid={len(e_hid)}, polys={len(polys)})")
+
+    if no_combine:
+        print("[Done] 3 views exported (no combine).")
+        return None
+
+    combined_svg = combined_svg or os.path.join(out_dir, "combined_views.svg")
+    combined_png = combined_png or os.path.join(out_dir, "combined_views.png")
+
+    # Note: views are combined in the order front / top / left
+    combine_svgs(
+        view1_path=view_paths["front"],
+        view2_path=view_paths["top"],
+        view3_path=view_paths["left"],
+        output_svg=combined_svg,
+        output_image=combined_png
+    )
+    return combined_svg, combined_png
+
+
+if __name__ == "__main__":
+    # step -1
+    glb_result, ply_result = image_to_3d(image=["assets/whiteboard_exported_image.png"])
+    print(f"glb_result : {glb_result}", f"ply_result : {ply_result}")
+
+    # step 0
+    obj_result = glb_to_obj(glb_result)
+    print(f"obj_result : {obj_result}")
+
+    # step 1
+    step_result = obj_to_step(input_obj=obj_result, output_dir="step_output", script_path="1_obj_to_step.py")
+
+    # step 2
+    out_dir = "svg_output"
+    combined_svg, combined_png = step_to_svg(step_path=step_result, out_dir=out_dir)
+    print(f"combined_svg : {combined_svg}", f"combined_png : {combined_png}")
diff --git a/trellis/pipelines/samplers/guidance_interval_mixin.py b/trellis/pipelines/samplers/guidance_interval_mixin.py
new file mode 100644
index 0000000..7074a4d
--- /dev/null
+++ b/trellis/pipelines/samplers/guidance_interval_mixin.py
@@ -0,0 +1,15 @@
+from typing import *
+
+
+class GuidanceIntervalSamplerMixin:
+    """
+    A mixin class for samplers that apply classifier-free guidance within a timestep interval.
+    """
+
+    def _inference_model(self, model, x_t, t, cond, neg_cond, cfg_strength, cfg_interval, **kwargs):
+        if cfg_interval[0] <= t <= cfg_interval[1]:
+            pred = super()._inference_model(model, x_t, t, cond, **kwargs)
+            neg_pred = super()._inference_model(model, x_t, t, neg_cond, **kwargs)
+            return (1 + cfg_strength) * pred - cfg_strength * neg_pred
+        else:
+            return super()._inference_model(model, x_t, t, cond, **kwargs)
diff --git a/trellis/renderers/gaussian_render.py b/trellis/renderers/gaussian_render.py
new file mode 100644
index 0000000..57108e3
--- /dev/null
+++ b/trellis/renderers/gaussian_render.py
@@ -0,0 +1,231 @@
+#
+# Copyright (C) 2023, Inria
+# GRAPHDECO research group, https://team.inria.fr/graphdeco
+# All rights reserved.
+#
+# This software is free for non-commercial, research and evaluation use
+# under the terms of the LICENSE.md file.
+#
+# For inquiries contact george.drettakis@inria.fr
+#
+
+import torch
+import math
+from easydict import EasyDict as edict
+import numpy as np
+from ..representations.gaussian import Gaussian
+from .sh_utils import eval_sh
+import torch.nn.functional as F
+
+
+def intrinsics_to_projection(
+        intrinsics: torch.Tensor,
+        near: float,
+        far: float,
+    ) -> torch.Tensor:
+    """
+    OpenCV intrinsics to OpenGL perspective matrix
+
+    Args:
+        intrinsics (torch.Tensor): [3, 3] OpenCV intrinsics matrix
+        near (float): near plane to clip
+        far (float): far plane to clip
+    Returns:
+        (torch.Tensor): [4, 4] OpenGL perspective matrix
+    """
+    fx, fy = intrinsics[0, 0], intrinsics[1, 1]
+    cx, cy = intrinsics[0, 2], intrinsics[1, 2]
+    ret = torch.zeros((4, 4), dtype=intrinsics.dtype, device=intrinsics.device)
+    ret[0, 0] = 2 * fx
+    ret[1, 1] = 2 * fy
+    ret[0, 2] = 2 * cx - 1
+    ret[1, 2] = - 2 * cy + 1
+    ret[2, 2] = far / (far - near)
+    ret[2, 3] = near * far / (near - far)
+    ret[3, 2] = 1.
+    return ret
+
+
+def render(viewpoint_camera, pc : Gaussian, pipe, bg_color : torch.Tensor, scaling_modifier = 1.0, override_color = None):
+    """
+    Render the scene.
+
+    Background tensor (bg_color) must be on GPU!
+    """
+    # lazy import
+    if 'GaussianRasterizer' not in globals():
+        from diff_gaussian_rasterization import GaussianRasterizer, GaussianRasterizationSettings
+
+    # Create zero tensor. We will use it to make pytorch return gradients of the 2D (screen-space) means
+    screenspace_points = torch.zeros_like(pc.get_xyz, dtype=pc.get_xyz.dtype, requires_grad=True, device="cuda") + 0
+    try:
+        screenspace_points.retain_grad()
+    except Exception:
+        pass
+    # Set up rasterization configuration
+    tanfovx = math.tan(viewpoint_camera.FoVx * 0.5)
+    tanfovy = math.tan(viewpoint_camera.FoVy * 0.5)
+
+    kernel_size = pipe.kernel_size
+    subpixel_offset = torch.zeros((int(viewpoint_camera.image_height), int(viewpoint_camera.image_width), 2), dtype=torch.float32, device="cuda")
+
+    raster_settings = GaussianRasterizationSettings(
+        image_height=int(viewpoint_camera.image_height),
+        image_width=int(viewpoint_camera.image_width),
+        tanfovx=tanfovx,
+        tanfovy=tanfovy,
+        kernel_size=kernel_size,
+        subpixel_offset=subpixel_offset,
+        bg=bg_color,
+        scale_modifier=scaling_modifier,
+        viewmatrix=viewpoint_camera.world_view_transform,
+        projmatrix=viewpoint_camera.full_proj_transform,
+        sh_degree=pc.active_sh_degree,
+        campos=viewpoint_camera.camera_center,
+        prefiltered=False,
+        debug=pipe.debug
+    )
+
+    rasterizer = GaussianRasterizer(raster_settings=raster_settings)
+
+    means3D = pc.get_xyz
+    means2D = screenspace_points
+    opacity = pc.get_opacity
+
+    # If precomputed 3d covariance is provided, use it. If not, then it will be computed from
+    # scaling / rotation by the rasterizer.
+    scales = None
+    rotations = None
+    cov3D_precomp = None
+    if pipe.compute_cov3D_python:
+        cov3D_precomp = pc.get_covariance(scaling_modifier)
+    else:
+        scales = pc.get_scaling
+        rotations = pc.get_rotation
+
+    # If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors
+    # from SHs in Python, do it. If not, then SH -> RGB conversion will be done by rasterizer.
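+    # (eval_sh returns zero-centered RGB, so the Python path below adds the 0.5 SH
+    # DC offset and clamps at zero, mirroring the conversion the rasterizer applies.)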
+    shs = None
+    colors_precomp = None
+    if override_color is None:
+        if pipe.convert_SHs_python:
+            shs_view = pc.get_features.transpose(1, 2).view(-1, 3, (pc.max_sh_degree+1)**2)
+            dir_pp = (pc.get_xyz - viewpoint_camera.camera_center.repeat(pc.get_features.shape[0], 1))
+            dir_pp_normalized = dir_pp/dir_pp.norm(dim=1, keepdim=True)
+            sh2rgb = eval_sh(pc.active_sh_degree, shs_view, dir_pp_normalized)
+            colors_precomp = torch.clamp_min(sh2rgb + 0.5, 0.0)
+        else:
+            shs = pc.get_features
+    else:
+        colors_precomp = override_color
+
+    # Rasterize visible Gaussians to image, obtain their radii (on screen).
+    rendered_image, radii = rasterizer(
+        means3D = means3D,
+        means2D = means2D,
+        shs = shs,
+        colors_precomp = colors_precomp,
+        opacities = opacity,
+        scales = scales,
+        rotations = rotations,
+        cov3D_precomp = cov3D_precomp
+    )
+
+    # Those Gaussians that were frustum culled or had a radius of 0 were not visible.
+    # They will be excluded from value updates used in the splitting criteria.
+    return edict({"render": rendered_image,
+                  "viewspace_points": screenspace_points,
+                  "visibility_filter" : radii > 0,
+                  "radii": radii})
+
+
+class GaussianRenderer:
+    """
+    Renderer for the Gaussian representation.
+
+    Args:
+        rendering_options (dict): Rendering options.
+    """
+
+    def __init__(self, rendering_options={}) -> None:
+        self.pipe = edict({
+            "kernel_size": 0.1,
+            "convert_SHs_python": False,
+            "compute_cov3D_python": False,
+            "scale_modifier": 1.0,
+            "debug": False
+        })
+        self.rendering_options = edict({
+            "resolution": None,
+            "near": None,
+            "far": None,
+            "ssaa": 1,
+            "bg_color": 'random',
+        })
+        self.rendering_options.update(rendering_options)
+        self.bg_color = None
+
+    def render(
+        self,
+        gaussian: Gaussian,
+        extrinsics: torch.Tensor,
+        intrinsics: torch.Tensor,
+        colors_overwrite: torch.Tensor = None
+    ) -> edict:
+        """
+        Render the Gaussian representation.
+
+        Args:
+            gaussian (Gaussian): Gaussian model to render
+            extrinsics (torch.Tensor): (4, 4) camera extrinsics
+            intrinsics (torch.Tensor): (3, 3) camera intrinsics
+            colors_overwrite (torch.Tensor): (N, 3) override color
+
+        Returns:
+            edict containing:
+                color (torch.Tensor): (3, H, W) rendered color image
+        """
+        resolution = self.rendering_options["resolution"]
+        near = self.rendering_options["near"]
+        far = self.rendering_options["far"]
+        ssaa = self.rendering_options["ssaa"]
+
+        if self.rendering_options["bg_color"] == 'random':
+            self.bg_color = torch.zeros(3, dtype=torch.float32, device="cuda")
+            if np.random.rand() < 0.5:
+                self.bg_color += 1
+        else:
+            self.bg_color = torch.tensor(self.rendering_options["bg_color"], dtype=torch.float32, device="cuda")
+
+        view = extrinsics
+        perspective = intrinsics_to_projection(intrinsics, near, far)
+        camera = torch.inverse(view)[:3, 3]
+        focalx = intrinsics[0, 0]
+        focaly = intrinsics[1, 1]
+        fovx = 2 * torch.atan(0.5 / focalx)
+        fovy = 2 * torch.atan(0.5 / focaly)
+
+        camera_dict = edict({
+            "image_height": resolution * ssaa,
+            "image_width": resolution * ssaa,
+            "FoVx": fovx,
+            "FoVy": fovy,
+            "znear": near,
+            "zfar": far,
+            "world_view_transform": view.T.contiguous(),
+            "projection_matrix": perspective.T.contiguous(),
+            "full_proj_transform": (perspective @ view).T.contiguous(),
+            "camera_center": camera
+        })
+
+        # Render
+        render_ret = render(camera_dict, gaussian, self.pipe, self.bg_color, override_color=colors_overwrite, scaling_modifier=self.pipe.scale_modifier)
+
+        if ssaa > 1:
+            render_ret.render = F.interpolate(render_ret.render[None], size=(resolution, resolution), mode='bilinear', align_corners=False, antialias=True).squeeze()
+
+        ret = edict({
+            'color': render_ret['render']
+        })
+        return ret
diff --git a/trellis/representations/gaussian/general_utils.py b/trellis/representations/gaussian/general_utils.py
new file mode 100644
index 0000000..541c082
--- /dev/null
+++ b/trellis/representations/gaussian/general_utils.py
@@ -0,0 +1,133 @@
+#
+# Copyright (C) 2023, Inria
+# GRAPHDECO research group, https://team.inria.fr/graphdeco
+# All rights reserved.
+#
+# This software is free for non-commercial, research and evaluation use
+# under the terms of the LICENSE.md file.
+#
+# For inquiries contact george.drettakis@inria.fr
+#
+
+import torch
+import sys
+from datetime import datetime
+import numpy as np
+import random
+
+def inverse_sigmoid(x):
+    return torch.log(x/(1-x))
+
+def PILtoTorch(pil_image, resolution):
+    resized_image_PIL = pil_image.resize(resolution)
+    resized_image = torch.from_numpy(np.array(resized_image_PIL)) / 255.0
+    if len(resized_image.shape) == 3:
+        return resized_image.permute(2, 0, 1)
+    else:
+        return resized_image.unsqueeze(dim=-1).permute(2, 0, 1)
+
+def get_expon_lr_func(
+    lr_init, lr_final, lr_delay_steps=0, lr_delay_mult=1.0, max_steps=1000000
+):
+    """
+    Copied from Plenoxels
+
+    Continuous learning rate decay function. Adapted from JaxNeRF
+    The returned rate is lr_init when step=0 and lr_final when step=max_steps, and
+    is log-linearly interpolated elsewhere (equivalent to exponential decay).
+    If lr_delay_steps>0 then the learning rate will be scaled by some smooth
+    function of lr_delay_mult, such that the initial learning rate is
+    lr_init*lr_delay_mult at the beginning of optimization but will be eased back
+    to the normal learning rate when steps>lr_delay_steps.
+    :param lr_init: float, the initial learning rate (at step 0).
+    :param max_steps: int, the number of steps during optimization.
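+    :param lr_final: float, the final learning rate, reached at step max_steps.
+    :param lr_delay_steps: int, number of initial steps over which the delay factor eases back to 1.
+    :param lr_delay_mult: float, multiplier on the initial learning rate while delayed.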
+ :return HoF which takes step as input + """ + + def helper(step): + if step < 0 or (lr_init == 0.0 and lr_final == 0.0): + # Disable this parameter + return 0.0 + if lr_delay_steps > 0: + # A kind of reverse cosine decay. + delay_rate = lr_delay_mult + (1 - lr_delay_mult) * np.sin( + 0.5 * np.pi * np.clip(step / lr_delay_steps, 0, 1) + ) + else: + delay_rate = 1.0 + t = np.clip(step / max_steps, 0, 1) + log_lerp = np.exp(np.log(lr_init) * (1 - t) + np.log(lr_final) * t) + return delay_rate * log_lerp + + return helper + +def strip_lowerdiag(L): + uncertainty = torch.zeros((L.shape[0], 6), dtype=torch.float, device="cuda") + + uncertainty[:, 0] = L[:, 0, 0] + uncertainty[:, 1] = L[:, 0, 1] + uncertainty[:, 2] = L[:, 0, 2] + uncertainty[:, 3] = L[:, 1, 1] + uncertainty[:, 4] = L[:, 1, 2] + uncertainty[:, 5] = L[:, 2, 2] + return uncertainty + +def strip_symmetric(sym): + return strip_lowerdiag(sym) + +def build_rotation(r): + norm = torch.sqrt(r[:,0]*r[:,0] + r[:,1]*r[:,1] + r[:,2]*r[:,2] + r[:,3]*r[:,3]) + + q = r / norm[:, None] + + R = torch.zeros((q.size(0), 3, 3), device='cuda') + + r = q[:, 0] + x = q[:, 1] + y = q[:, 2] + z = q[:, 3] + + R[:, 0, 0] = 1 - 2 * (y*y + z*z) + R[:, 0, 1] = 2 * (x*y - r*z) + R[:, 0, 2] = 2 * (x*z + r*y) + R[:, 1, 0] = 2 * (x*y + r*z) + R[:, 1, 1] = 1 - 2 * (x*x + z*z) + R[:, 1, 2] = 2 * (y*z - r*x) + R[:, 2, 0] = 2 * (x*z - r*y) + R[:, 2, 1] = 2 * (y*z + r*x) + R[:, 2, 2] = 1 - 2 * (x*x + y*y) + return R + +def build_scaling_rotation(s, r): + L = torch.zeros((s.shape[0], 3, 3), dtype=torch.float, device="cuda") + R = build_rotation(r) + + L[:,0,0] = s[:,0] + L[:,1,1] = s[:,1] + L[:,2,2] = s[:,2] + + L = R @ L + return L + +def safe_state(silent): + old_f = sys.stdout + class F: + def __init__(self, silent): + self.silent = silent + + def write(self, x): + if not self.silent: + if x.endswith("\n"): + old_f.write(x.replace("\n", " [{}]\n".format(str(datetime.now().strftime("%d/%m %H:%M:%S"))))) + else: + old_f.write(x) + + def flush(self): + old_f.flush() + + sys.stdout = F(silent) + + random.seed(0) + np.random.seed(0) + torch.manual_seed(0) + torch.cuda.set_device(torch.device("cuda:0")) diff --git a/trellis/trainers/flow_matching/mixins/image_conditioned.py b/trellis/trainers/flow_matching/mixins/image_conditioned.py new file mode 100644 index 0000000..be18d64 --- /dev/null +++ b/trellis/trainers/flow_matching/mixins/image_conditioned.py @@ -0,0 +1,93 @@ +from typing import * +import torch +import torch.nn.functional as F +from torchvision import transforms +import numpy as np +from PIL import Image + +from ....utils import dist_utils + + +class ImageConditionedMixin: + """ + Mixin for image-conditioned models. + + Args: + image_cond_model: The image conditioning model. + """ + def __init__(self, *args, image_cond_model: str = 'dinov2_vitl14_reg', **kwargs): + super().__init__(*args, **kwargs) + self.image_cond_model_name = image_cond_model + self.image_cond_model = None # the model is init lazily + + @staticmethod + def prepare_for_training(image_cond_model: str, **kwargs): + """ + Prepare for training. + """ + if hasattr(super(ImageConditionedMixin, ImageConditionedMixin), 'prepare_for_training'): + super(ImageConditionedMixin, ImageConditionedMixin).prepare_for_training(**kwargs) + # download the model + torch.hub.load('facebookresearch/dinov2', image_cond_model, pretrained=True) + + def _init_image_cond_model(self): + """ + Initialize the image conditioning model. 
+ """ + with dist_utils.local_master_first(): + dinov2_model = torch.hub.load('facebookresearch/dinov2', self.image_cond_model_name, pretrained=True) + dinov2_model.eval().cuda() + transform = transforms.Compose([ + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + ]) + self.image_cond_model = { + 'model': dinov2_model, + 'transform': transform, + } + + @torch.no_grad() + def encode_image(self, image: Union[torch.Tensor, List[Image.Image]]) -> torch.Tensor: + """ + Encode the image. + """ + if isinstance(image, torch.Tensor): + assert image.ndim == 4, "Image tensor should be batched (B, C, H, W)" + elif isinstance(image, list): + assert all(isinstance(i, Image.Image) for i in image), "Image list should be list of PIL images" + image = [i.resize((518, 518), Image.LANCZOS) for i in image] + image = [np.array(i.convert('RGB')).astype(np.float32) / 255 for i in image] + image = [torch.from_numpy(i).permute(2, 0, 1).float() for i in image] + image = torch.stack(image).cuda() + else: + raise ValueError(f"Unsupported type of image: {type(image)}") + + if self.image_cond_model is None: + self._init_image_cond_model() + image = self.image_cond_model['transform'](image).cuda() + features = self.image_cond_model['model'](image, is_training=True)['x_prenorm'] + patchtokens = F.layer_norm(features, features.shape[-1:]) + return patchtokens + + def get_cond(self, cond, **kwargs): + """ + Get the conditioning data. + """ + cond = self.encode_image(cond) + kwargs['neg_cond'] = torch.zeros_like(cond) + cond = super().get_cond(cond, **kwargs) + return cond + + def get_inference_cond(self, cond, **kwargs): + """ + Get the conditioning data for inference. + """ + cond = self.encode_image(cond) + kwargs['neg_cond'] = torch.zeros_like(cond) + cond = super().get_inference_cond(cond, **kwargs) + return cond + + def vis_cond(self, cond, **kwargs): + """ + Visualize the conditioning data. + """ + return {'image': {'value': cond, 'type': 'image'}} diff --git a/trellis/utils/general_utils.py b/trellis/utils/general_utils.py new file mode 100644 index 0000000..c0d765f --- /dev/null +++ b/trellis/utils/general_utils.py @@ -0,0 +1,202 @@ +import re +import numpy as np +import cv2 +import torch +import contextlib + + +# Dictionary utils +def _dict_merge(dicta, dictb, prefix=''): + """ + Merge two dictionaries. + """ + assert isinstance(dicta, dict), 'input must be a dictionary' + assert isinstance(dictb, dict), 'input must be a dictionary' + dict_ = {} + all_keys = set(dicta.keys()).union(set(dictb.keys())) + for key in all_keys: + if key in dicta.keys() and key in dictb.keys(): + if isinstance(dicta[key], dict) and isinstance(dictb[key], dict): + dict_[key] = _dict_merge(dicta[key], dictb[key], prefix=f'{prefix}.{key}') + else: + raise ValueError(f'Duplicate key {prefix}.{key} found in both dictionaries. Types: {type(dicta[key])}, {type(dictb[key])}') + elif key in dicta.keys(): + dict_[key] = dicta[key] + else: + dict_[key] = dictb[key] + return dict_ + + +def dict_merge(dicta, dictb): + """ + Merge two dictionaries. + """ + return _dict_merge(dicta, dictb, prefix='') + + +def dict_foreach(dic, func, special_func={}): + """ + Recursively apply a function to all non-dictionary leaf values in a dictionary. 
+ """ + assert isinstance(dic, dict), 'input must be a dictionary' + for key in dic.keys(): + if isinstance(dic[key], dict): + dic[key] = dict_foreach(dic[key], func) + else: + if key in special_func.keys(): + dic[key] = special_func[key](dic[key]) + else: + dic[key] = func(dic[key]) + return dic + + +def dict_reduce(dicts, func, special_func={}): + """ + Reduce a list of dictionaries. Leaf values must be scalars. + """ + assert isinstance(dicts, list), 'input must be a list of dictionaries' + assert all([isinstance(d, dict) for d in dicts]), 'input must be a list of dictionaries' + assert len(dicts) > 0, 'input must be a non-empty list of dictionaries' + all_keys = set([key for dict_ in dicts for key in dict_.keys()]) + reduced_dict = {} + for key in all_keys: + vlist = [dict_[key] for dict_ in dicts if key in dict_.keys()] + if isinstance(vlist[0], dict): + reduced_dict[key] = dict_reduce(vlist, func, special_func) + else: + if key in special_func.keys(): + reduced_dict[key] = special_func[key](vlist) + else: + reduced_dict[key] = func(vlist) + return reduced_dict + + +def dict_any(dic, func): + """ + Recursively apply a function to all non-dictionary leaf values in a dictionary. + """ + assert isinstance(dic, dict), 'input must be a dictionary' + for key in dic.keys(): + if isinstance(dic[key], dict): + if dict_any(dic[key], func): + return True + else: + if func(dic[key]): + return True + return False + + +def dict_all(dic, func): + """ + Recursively apply a function to all non-dictionary leaf values in a dictionary. + """ + assert isinstance(dic, dict), 'input must be a dictionary' + for key in dic.keys(): + if isinstance(dic[key], dict): + if not dict_all(dic[key], func): + return False + else: + if not func(dic[key]): + return False + return True + + +def dict_flatten(dic, sep='.'): + """ + Flatten a nested dictionary into a dictionary with no nested dictionaries. 
+ """ + assert isinstance(dic, dict), 'input must be a dictionary' + flat_dict = {} + for key in dic.keys(): + if isinstance(dic[key], dict): + sub_dict = dict_flatten(dic[key], sep=sep) + for sub_key in sub_dict.keys(): + flat_dict[str(key) + sep + str(sub_key)] = sub_dict[sub_key] + else: + flat_dict[key] = dic[key] + return flat_dict + + +# Context utils +@contextlib.contextmanager +def nested_contexts(*contexts): + with contextlib.ExitStack() as stack: + for ctx in contexts: + stack.enter_context(ctx()) + yield + + +# Image utils +def make_grid(images, nrow=None, ncol=None, aspect_ratio=None): + num_images = len(images) + if nrow is None and ncol is None: + if aspect_ratio is not None: + nrow = int(np.round(np.sqrt(num_images / aspect_ratio))) + else: + nrow = int(np.sqrt(num_images)) + ncol = (num_images + nrow - 1) // nrow + elif nrow is None and ncol is not None: + nrow = (num_images + ncol - 1) // ncol + elif nrow is not None and ncol is None: + ncol = (num_images + nrow - 1) // nrow + else: + assert nrow * ncol >= num_images, 'nrow * ncol must be greater than or equal to the number of images' + + if images[0].ndim == 2: + grid = np.zeros((nrow * images[0].shape[0], ncol * images[0].shape[1]), dtype=images[0].dtype) + else: + grid = np.zeros((nrow * images[0].shape[0], ncol * images[0].shape[1], images[0].shape[2]), dtype=images[0].dtype) + for i, img in enumerate(images): + row = i // ncol + col = i % ncol + grid[row * img.shape[0]:(row + 1) * img.shape[0], col * img.shape[1]:(col + 1) * img.shape[1]] = img + return grid + + +def notes_on_image(img, notes=None): + img = np.pad(img, ((0, 32), (0, 0), (0, 0)), 'constant', constant_values=0) + img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) + if notes is not None: + img = cv2.putText(img, notes, (0, img.shape[0] - 4), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 1) + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + return img + + +def save_image_with_notes(img, path, notes=None): + """ + Save an image with notes. + """ + if isinstance(img, torch.Tensor): + img = img.cpu().numpy().transpose(1, 2, 0) + if img.dtype == np.float32 or img.dtype == np.float64: + img = np.clip(img * 255, 0, 255).astype(np.uint8) + img = notes_on_image(img, notes) + cv2.imwrite(path, cv2.cvtColor(img, cv2.COLOR_RGB2BGR)) + + +# debug utils + +def atol(x, y): + """ + Absolute tolerance. + """ + return torch.abs(x - y) + + +def rtol(x, y): + """ + Relative tolerance. + """ + return torch.abs(x - y) / torch.clamp_min(torch.maximum(torch.abs(x), torch.abs(y)), 1e-12) + + +# print utils +def indent(s, n=4): + """ + Indent a string. + """ + lines = s.split('\n') + for i in range(1, len(lines)): + lines[i] = ' ' * n + lines[i] + return '\n'.join(lines) + diff --git a/trellis/utils/grad_clip_utils.py b/trellis/utils/grad_clip_utils.py new file mode 100644 index 0000000..990a435 --- /dev/null +++ b/trellis/utils/grad_clip_utils.py @@ -0,0 +1,81 @@ +from typing import * +import torch +import numpy as np +import torch.utils + + +class AdaptiveGradClipper: + """ + Adaptive gradient clipping for training. 
+ """ + def __init__( + self, + max_norm=None, + clip_percentile=95.0, + buffer_size=1000, + ): + self.max_norm = max_norm + self.clip_percentile = clip_percentile + self.buffer_size = buffer_size + + self._grad_norm = np.zeros(buffer_size, dtype=np.float32) + self._max_norm = max_norm + self._buffer_ptr = 0 + self._buffer_length = 0 + + def __repr__(self): + return f'AdaptiveGradClipper(max_norm={self.max_norm}, clip_percentile={self.clip_percentile})' + + def state_dict(self): + return { + 'grad_norm': self._grad_norm, + 'max_norm': self._max_norm, + 'buffer_ptr': self._buffer_ptr, + 'buffer_length': self._buffer_length, + } + + def load_state_dict(self, state_dict): + self._grad_norm = state_dict['grad_norm'] + self._max_norm = state_dict['max_norm'] + self._buffer_ptr = state_dict['buffer_ptr'] + self._buffer_length = state_dict['buffer_length'] + + def log(self): + return { + 'max_norm': self._max_norm, + } + + def __call__(self, parameters, norm_type=2.0, error_if_nonfinite=False, foreach=None): + """Clip the gradient norm of an iterable of parameters. + + The norm is computed over all gradients together, as if they were + concatenated into a single vector. Gradients are modified in-place. + + Args: + parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a + single Tensor that will have gradients normalized + norm_type (float): type of the used p-norm. Can be ``'inf'`` for + infinity norm. + error_if_nonfinite (bool): if True, an error is thrown if the total + norm of the gradients from :attr:`parameters` is ``nan``, + ``inf``, or ``-inf``. Default: False (will switch to True in the future) + foreach (bool): use the faster foreach-based implementation. + If ``None``, use the foreach implementation for CUDA and CPU native tensors and silently + fall back to the slow implementation for other device types. + Default: ``None`` + + Returns: + Total norm of the parameter gradients (viewed as a single vector). + """ + max_norm = self._max_norm if self._max_norm is not None else float('inf') + grad_norm = torch.nn.utils.clip_grad_norm_(parameters, max_norm=max_norm, norm_type=norm_type, error_if_nonfinite=error_if_nonfinite, foreach=foreach) + + if torch.isfinite(grad_norm): + self._grad_norm[self._buffer_ptr] = grad_norm + self._buffer_ptr = (self._buffer_ptr + 1) % self.buffer_size + self._buffer_length = min(self._buffer_length + 1, self.buffer_size) + if self._buffer_length == self.buffer_size: + self._max_norm = np.percentile(self._grad_norm, self.clip_percentile) + self._max_norm = min(self._max_norm, self.max_norm) if self.max_norm is not None else self._max_norm + + return grad_norm \ No newline at end of file