This commit is contained in:
zcr
2026-03-17 11:38:02 +08:00
parent 046be2c797
commit 0571f65793
8 changed files with 1413 additions and 0 deletions

View File

@@ -0,0 +1,103 @@
import os
import re
import argparse
import tarfile
from concurrent.futures import ThreadPoolExecutor
from tqdm import tqdm
import pandas as pd
import huggingface_hub
from utils import get_file_hash
def add_args(parser: argparse.ArgumentParser):
pass
def get_metadata(**kwargs):
metadata = pd.read_csv("hf://datasets/JeffreyXiang/TRELLIS-500K/HSSD.csv")
return metadata
def download(metadata, output_dir, **kwargs):
os.makedirs(os.path.join(output_dir, 'raw'), exist_ok=True)
# check login
try:
huggingface_hub.whoami()
except Exception:
print("\033[93m")
print("You are not logged in to the Hugging Face Hub.")
print("Visit https://huggingface.co/settings/tokens to get a token.")
print("\033[0m")
huggingface_hub.login()
try:
huggingface_hub.hf_hub_download(repo_id="hssd/hssd-models", filename="README.md", repo_type="dataset")
except Exception:
print("\033[93m")
print("Error downloading HSSD dataset.")
print("Check if you have access to the HSSD dataset.")
print("Visit https://huggingface.co/datasets/hssd/hssd-models for more information")
print("\033[0m")
downloaded = {}
metadata = metadata.set_index("file_identifier")
with ThreadPoolExecutor(max_workers=os.cpu_count()) as executor, \
tqdm(total=len(metadata), desc="Downloading") as pbar:
def worker(instance: str) -> str:
try:
huggingface_hub.hf_hub_download(repo_id="hssd/hssd-models", filename=instance, repo_type="dataset", local_dir=os.path.join(output_dir, 'raw'))
sha256 = get_file_hash(os.path.join(output_dir, 'raw', instance))
pbar.update()
return sha256
except Exception as e:
pbar.update()
print(f"Error extracting for {instance}: {e}")
return None
sha256s = executor.map(worker, metadata.index)
executor.shutdown(wait=True)
for k, sha256 in zip(metadata.index, sha256s):
if sha256 is not None:
if sha256 == metadata.loc[k, "sha256"]:
downloaded[sha256] = os.path.join('raw', k)
else:
print(f"Error downloading {k}: sha256s do not match")
return pd.DataFrame(downloaded.items(), columns=['sha256', 'local_path'])
def foreach_instance(metadata, output_dir, func, max_workers=None, desc='Processing objects') -> pd.DataFrame:
# load metadata
metadata = metadata.to_dict('records')
# processing objects
records = []
max_workers = max_workers or os.cpu_count()
try:
with ThreadPoolExecutor(max_workers=max_workers) as executor, \
tqdm(total=len(metadata), desc=desc) as pbar:
def worker(metadatum):
try:
local_path = metadatum['local_path']
sha256 = metadatum['sha256']
file = os.path.join(output_dir, local_path)
record = func(file, sha256)
if record is not None:
records.append(record)
pbar.update()
except Exception as e:
print(f"Error processing object {metadatum.get('sha256', 'unknown')}: {e}")
pbar.update()
executor.map(worker, metadata)
executor.shutdown(wait=True)
except Exception as e:
print(f"Error happened during processing: {e}")
return pd.DataFrame.from_records(records)
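# Illustrative usage (hypothetical callback): func receives the local file path
# and its sha256 and returns a record dict (or None to skip the row).
# def count_bytes(file, sha256):
#     return {'sha256': sha256, 'size': os.path.getsize(file)}
# df = foreach_instance(metadata, output_dir, count_bytes, desc='Measuring files')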

glb2svg.py Normal file
View File

@@ -0,0 +1,555 @@
import math
import secrets
from collections import defaultdict
from datetime import datetime
import os
import subprocess
from pathlib import Path
import logging
from typing import Optional
import imageio
import numpy as np
import torch
from PIL import Image
# ---------- PyTorch3D ----------
from pytorch3d.structures import Meshes
from pytorch3d.renderer import (
RasterizationSettings, MeshRasterizer,
MeshRenderer, BlendParams, SoftSilhouetteShader,
look_at_view_transform, OrthographicCameras
)
from pytorch3d.renderer.mesh.rasterizer import Fragments
# ---------- SVG / Image ----------
import svgwrite
from skimage import measure
# ---------- pythonocc (OCC) ----------
from OCC.Core.STEPControl import STEPControl_Reader
from OCC.Core.TopAbs import TopAbs_FACE
from OCC.Core.TopExp import TopExp_Explorer
from OCC.Core.BRep import BRep_Tool
from OCC.Core.BRepMesh import BRepMesh_IncrementalMesh
from OCC.Core.TopLoc import TopLoc_Location
from OCC.Core.ShapeFix import ShapeFix_Shape
# ---------- Combine SVG ----------
import svgutils.transform as st
from svgutils.compose import Unit
import cairosvg
from trellis.pipelines import TrellisImageTo3DPipeline
from trellis.utils import render_utils, postprocessing_utils
logger = logging.getLogger(__name__)
"""single image to 3D"""
def generate_unique_name(original_name: str) -> str:
stem, ext = os.path.splitext(original_name)
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
random_part = secrets.token_hex(4)  # 8 random hex characters
return f"{stem}_{timestamp}_{random_part}{ext}"
def image_to_3d(image: list[str], out_dir: str = "trellis_out", single_image: bool = False, seed: int = 1,
steps_sparse: int = 12, cfg_sparse: float = 7.5, steps_slat: int = 12,
cfg_slat: float = 3.0, simplify: float = 0.95, texture_size: int = 1024,
save_video: bool = False, fps: int = 30, video_gs_name: str = "sample_gs.mp4", video_rf_name: str = "sample_rf.mp4", video_mesh_name: str = "sample_mesh.mp4"):
os.makedirs(out_dir, exist_ok=True)
# Optional env
os.environ.setdefault("SPCONV_ALGO", "native")
pipeline = TrellisImageTo3DPipeline.from_pretrained("microsoft/TRELLIS-image-large")
pipeline.cuda()
if single_image:
image = Image.open(image[0])
outputs = pipeline.run(
image,
seed=seed,
sparse_structure_sampler_params={
"steps": steps_sparse,
"cfg_strength": cfg_sparse,
},
slat_sampler_params={
"steps": steps_slat,
"cfg_strength": cfg_slat,
},
)
else:
images = [Image.open(p) for p in image]
outputs = pipeline.run_multi_image(
images,
seed=seed,
sparse_structure_sampler_params={
"steps": steps_sparse,
"cfg_strength": cfg_sparse,
},
slat_sampler_params={
"steps": steps_slat,
"cfg_strength": cfg_slat,
},
)
if save_video:
video = render_utils.render_video(outputs["gaussian"][0])["color"]
imageio.mimsave(os.path.join(out_dir, video_gs_name), video, fps=fps)
video = render_utils.render_video(outputs["radiance_field"][0])["color"]
imageio.mimsave(os.path.join(out_dir, video_rf_name), video, fps=fps)
video = render_utils.render_video(outputs["mesh"][0])["normal"]
imageio.mimsave(os.path.join(out_dir, video_mesh_name), video, fps=fps)
glb_path = os.path.join(out_dir, generate_unique_name("sample.glb"))
ply_path = os.path.join(out_dir, generate_unique_name("sample.ply"))
glb = postprocessing_utils.to_glb(
outputs["gaussian"][0],
outputs["mesh"][0],
simplify=simplify,
texture_size=texture_size,
)
# export to the generated output paths
glb.export(glb_path)
outputs["gaussian"][0].save_ply(ply_path)
return glb_path, ply_path
"""
glb_to_obj
"""
def glb_to_obj(glb_input, obj_output_dir="obj_output"):
obj_path = os.path.join(obj_output_dir, generate_unique_name("sample.obj"))
os.makedirs(obj_output_dir, exist_ok=True)
cmd = [
"blender",
"-b",
"-P",
"0_glb_to_obj.py",
"--",
"--input",
glb_input,
"--output",
obj_path
]
result = subprocess.run(cmd, capture_output=True, text=True)
print(result)
print("\n")
if result.returncode != 0:
raise RuntimeError(
f"Blender failed\nstdout:\n{result.stdout}\nstderr:\n{result.stderr}"
)
return obj_path
"""
obj_to_step
"""
def obj_to_step(
input_obj: str | Path,
output_dir: str | Path,
script_path: str | Path = "1_obj_to_step.py",
timeout: int = 120,  # seconds; adjust to match the actual conversion time
freecadcmd_path: Optional[str] = None
) -> str | bool:
"""
Wrap the FreeCAD command-line call that converts an OBJ file to STEP.
Returns the output path reported by the conversion script on success, or False on failure.
"""
input_obj = Path(input_obj).resolve()
output_dir = Path(output_dir).resolve()
script_path = Path(script_path).resolve()
print(f"input_obj : {input_obj}")
print(f"output_dir : {output_dir}")
print(f"script_path : {script_path}")
if not input_obj.exists():
raise FileNotFoundError(f"输入文件不存在: {input_obj}")
if not script_path.exists():
raise FileNotFoundError(f"转换脚本不存在: {script_path}")
# Build the -c payload (replicates the original command line exactly)
python_code = f'''import sys
sys.argv = ["{script_path.name}", "{input_obj}", "{output_dir}"]
exec(open("{script_path}", "r", encoding="utf-8").read())'''
cmd = [
freecadcmd_path or "freecadcmd", # 如果不在 PATH 中,可传入绝对路径
"-c",
python_code
]
try:
result = subprocess.run(
cmd,
capture_output=True,
text=True,
check=True,  # raises CalledProcessError on non-zero exit
timeout=timeout,
cwd=script_path.parent  # so the script can resolve files by relative path
)
logger.info(f"FreeCAD 转换成功: {input_obj.name}{output_dir}")
print(f"FreeCAD 转换成功: {input_obj.name}{output_dir}")
if result.stdout:
logger.debug(result.stdout)
return result.stdout.split()[-1]  # the last stdout token is the STEP path printed by the script
except subprocess.TimeoutExpired:
logger.error(f"FreeCAD 转换超时(>{timeout}s: {input_obj}")
return False
except subprocess.CalledProcessError as e:
logger.error(f"FreeCAD 执行失败: {e.stderr}")
return False
except Exception as e:
logger.exception("未知错误")
return False
"""
step to svg
"""
def autodevice():
return torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
def read_step_solid(path: str):
reader = STEPControl_Reader()
stat = reader.ReadFile(path)
if stat != 1:
raise RuntimeError(f"STEP 读取失败: {path}")
reader.TransferRoots()
shape = reader.OneShape()
fixer = ShapeFix_Shape(shape)
fixer.Perform()
return fixer.Shape()
def triangulate(shape, deflection=0.5, angle=0.5):
"""OCC 三角化:返回 (V(np.float32 N×3), F(np.int64 M×3))"""
BRepMesh_IncrementalMesh(shape, deflection, False, angle, True)
verts, faces, v_off = [], [], 0
exp = TopExp_Explorer(shape, TopAbs_FACE)
while exp.More():
face = exp.Current()
loc = TopLoc_Location()
tri = BRep_Tool.Triangulation(face, loc)
if tri is not None:
nb_nodes = tri.NbNodes()
has_nodes_arr = hasattr(tri, "Nodes")
for i in range(1, nb_nodes + 1):
p = tri.Nodes().Value(i) if has_nodes_arr else tri.Node(i)
p = p.Transformed(loc.Transformation())
verts.append([p.X(), p.Y(), p.Z()])
nb_tris = tri.NbTriangles()
has_tris_arr = hasattr(tri, "Triangles")
for i in range(1, nb_tris + 1):
t = tri.Triangles().Value(i) if has_tris_arr else tri.Triangle(i)
a, b, c = t.Get()
faces.append([v_off + a - 1, v_off + b - 1, v_off + c - 1])
v_off += nb_nodes
exp.Next()
if not verts or not faces:
raise RuntimeError("三角化为空,尝试减小 --defl")
return np.asarray(verts, np.float32), np.asarray(faces, np.int64)
def normalize_mesh_np(verts: np.ndarray, unit_scale=1.0):
c = verts.mean(axis=0, keepdims=True)
v0 = verts - c
s = np.max(np.abs(v0))
s = max(s, 1e-12)
return v0 / s * unit_scale
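# Editorial note: normalize_mesh_np centers the vertices and divides by the max
# absolute coordinate, so the result fits inside [-unit_scale, unit_scale]^3, e.g.:
# >>> normalize_mesh_np(np.array([[0., 0., 0.], [2., 0., 0.]], np.float32))
# array([[-1.,  0.,  0.],
#        [ 1.,  0.,  0.]], dtype=float32)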
def build_mesh(V_np, F_np, device):
V = torch.from_numpy(V_np).to(device)
F = torch.from_numpy(F_np).to(device)
return Meshes(verts=[V], faces=[F])
def compute_face_normals(verts: torch.Tensor, faces: torch.Tensor):
v0 = verts[faces[:, 0]]
v1 = verts[faces[:, 1]]
v2 = verts[faces[:, 2]]
n = torch.cross(v1 - v0, v2 - v0, dim=1)
return torch.nn.functional.normalize(n, dim=1, eps=1e-12)
def extract_feature_edges(faces: torch.Tensor, face_normals: torch.Tensor, angle_deg: float):
"""二面角>=angle_deg 视为特征棱;边界边必取。"""
thr = math.cos(math.radians(max(0.0, angle_deg)))
e2f = defaultdict(list)
F = faces.shape[0]
for f in range(F):
i, j, k = faces[f].tolist()
for a, b in ((i, j), (j, k), (k, i)):
e = (a, b) if a < b else (b, a)
e2f[e].append(f)
feat = []
for e, fl in e2f.items():
if len(fl) == 1:
feat.append(e)
elif len(fl) == 2:
n0, n1 = face_normals[fl[0]], face_normals[fl[1]]
cosv = torch.clamp((n0 * n1).sum(), -1.0, 1.0).item()
if cosv <= thr:
feat.append(e)
return feat
def make_camera(view: str, device, dist=2.7):
"""
look_at_view_transform spherical params for engineering views.
"""
if view == "top":
# from +Z to origin (view -Z)
R, T = look_at_view_transform(dist=dist, elev=90.0, azim=180.0, device=device)
elif view == "left":
# from -X to origin (view +X)
R, T = look_at_view_transform(dist=dist, elev=0.0, azim=270.0, device=device)
elif view == "front":
# from +Y to origin (view -Y)
R, T = look_at_view_transform(dist=dist, elev=0.0, azim=180.0, device=device)
else:
raise ValueError("view must be one of ['top','left','front']")
return OrthographicCameras(R=R, T=T, device=device)
def render_silhouette(mesh: Meshes, cameras, image_size=1600):
rs = RasterizationSettings(
image_size=image_size,
blur_radius=1e-6,
faces_per_pixel=50,
cull_backfaces=True
)
renderer = MeshRenderer(
rasterizer=MeshRasterizer(cameras=cameras, raster_settings=rs),
shader=SoftSilhouetteShader(blend_params=BlendParams(sigma=1e-4, gamma=1e-4))
)
with torch.no_grad():
img = renderer(mesh, cameras=cameras) # (1,H,W,4)
return np.clip(img[0, ..., 3].detach().cpu().numpy(), 0.0, 1.0)
def raster_fragments(mesh: Meshes, cameras, image_size=1600, faces_per_pixel=1):
rs = RasterizationSettings(
image_size=image_size,
blur_radius=0.0,
faces_per_pixel=faces_per_pixel,
cull_backfaces=True
)
rast = MeshRasterizer(cameras=cameras, raster_settings=rs)
with torch.no_grad():
frags: Fragments = rast(mesh, cameras=cameras)
return frags, frags.zbuf[0, ..., 0]
def project_points_to_screen(cameras, points_world: torch.Tensor, image_size: int):
with torch.no_grad():
scr = cameras.transform_points_screen(
points_world[None, ...],
image_size=((image_size, image_size),)
)
return scr[0]
def edge_visibility_split(edges_idx, verts_world, cameras, zbuf, image_size, eps=1e-4):
H = W = image_size
vis, hid = [], []
scr = project_points_to_screen(cameras, verts_world, image_size)
zmin = zbuf.detach().cpu().numpy()
for i, j in edges_idx:
p0, p1 = scr[i], scr[j]
m = 0.5 * (p0 + p1)
x = int(np.clip(m[0].item(), 0, W - 1))
y = int(np.clip(m[1].item(), 0, H - 1))
z_proj = m[2].item()
z_ref = zmin[y, x]
seg = (p0[0].item(), p0[1].item(), p1[0].item(), p1[1].item())
if np.isfinite(z_ref) and (z_proj <= z_ref + eps):
vis.append(seg)
else:
hid.append(seg)
return vis, hid
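# Editorial note: visibility here is a midpoint heuristic -- each edge is
# classified by comparing the depth of its screen-space midpoint against the
# rasterized z-buffer, so long edges that are only partially occluded are not
# subdivided into visible and hidden sub-segments.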
def trace_silhouettes(alpha: np.ndarray, threshold=0.5, step=1):
contours = measure.find_contours(alpha, level=threshold)
polys = []
for cnt in contours:
if len(cnt) < 8:
continue
cnt = cnt[::max(1, step)]
polys.append(np.stack([cnt[:, 1], cnt[:, 0]], axis=1)) # (x,y)
return polys
def svg_from_view(out_svg, W, H, silhouettes, edges_vis, edges_hid, stroke=2.0, margin=10, fill="none"):
dwg = svgwrite.Drawing(out_svg, size=(W + 2 * margin, H + 2 * margin))
dwg.add(dwg.rect(insert=(0, 0), size=(W + 2 * margin, H + 2 * margin), fill="none"))
# silhouette fill (optional)
for poly in silhouettes:
if len(poly) < 3:
continue
path = [f"M {poly[0, 0] + margin:.2f} {poly[0, 1] + margin:.2f}"]
for i in range(1, len(poly)):
path.append(f"L {poly[i, 0] + margin:.2f} {poly[i, 1] + margin:.2f}")
path.append("Z")
dwg.add(dwg.path(" ".join(path), fill=fill, stroke="none"))
# visible edges
for (x0, y0, x1, y1) in edges_vis:
dwg.add(dwg.line(
(x0 + margin, y0 + margin),
(x1 + margin, y1 + margin),
stroke="black",
stroke_width=stroke
))
# hidden edges
for (x0, y0, x1, y1) in edges_hid:
dwg.add(dwg.line(
(x0 + margin, y0 + margin),
(x1 + margin, y1 + margin),
stroke="black",
stroke_width=max(1.0, 0.75 * stroke),
stroke_dasharray=[6, 6],
opacity=0.85
))
dwg.save()
def combine_svgs(view1_path, view2_path, view3_path, output_svg, output_image):
"""
Horizontally tile three SVGs and additionally export a PNG.
"""
out_dir = os.path.dirname(output_svg)
if out_dir:
os.makedirs(out_dir, exist_ok=True)
view1 = st.fromfile(view1_path)
view2 = st.fromfile(view2_path)
view3 = st.fromfile(view3_path)
r1 = view1.getroot()
r2 = view2.getroot()
r3 = view3.getroot()
w1, h1 = [float(x.replace("px", "")) for x in view1.get_size()]
w2, h2 = [float(x.replace("px", "")) for x in view2.get_size()]
w3, h3 = [float(x.replace("px", "")) for x in view3.get_size()]
max_w = max(w1, w2, w3)
max_h = max(h1, h2, h3)
combined_width = max_w * 3 + 40
combined_height = max_h + 20
combined = st.SVGFigure(Unit(combined_width), Unit(combined_height))
x_offset = 10
y_offset = 10
r1.moveto(x_offset, y_offset)
x_offset += max_w + 20
r2.moveto(x_offset, y_offset)
x_offset += max_w + 20
r3.moveto(x_offset, y_offset)
combined.append([r1, r2, r3])
combined.save(output_svg)
# SVG -> PNG
cairosvg.svg2png(url=output_svg, write_to=output_image)
def step_to_svg(step_path, out_dir,
res: int = 1400, defl: float = 1.5,
ang: float = 65.0, stroke: float = 1.3,
scale: float = 1.0, no_combine: bool = False,
combined_svg: Optional[str] = None, combined_png: Optional[str] = None):
os.makedirs(out_dir, exist_ok=True)
device = autodevice()
print(f"[Device] {device}")
# STEP -> mesh
shape = read_step_solid(step_path)
V_np, F_np = triangulate(shape, deflection=defl)
V_np = normalize_mesh_np(V_np, unit_scale=scale)
mesh = build_mesh(V_np, F_np, device)
verts = mesh.verts_packed()
faces = mesh.faces_packed()
fn = compute_face_normals(verts, faces)
feat_edges = extract_feature_edges(faces, fn, angle_deg=ang)
view_paths = {}
for name in ["top", "left", "front"]:
print(f"[View] {name}")
cam = make_camera(name, device)
alpha = render_silhouette(mesh, cam, image_size=res)
_, zbuf = raster_fragments(mesh, cam, image_size=res, faces_per_pixel=1)
e_vis, e_hid = edge_visibility_split(feat_edges, verts, cam, zbuf, res, eps=1e-4)
polys = trace_silhouettes(alpha, threshold=0.5, step=1)
out_svg = os.path.join(out_dir, f"{name}.svg")
svg_from_view(out_svg, res, res, polys, e_vis, e_hid,
stroke=stroke, margin=10, fill="none")
view_paths[name] = out_svg
print(f" -> {out_svg} (edges: vis={len(e_vis)}, hid={len(e_hid)}, polys={len(polys)})")
if no_combine:
print("[Done] 3 views exported (no combine).")
return None
combined_svg = combined_svg or os.path.join(out_dir, "combined_views.svg")
combined_png = combined_png or os.path.join(out_dir, "combined_views.png")
# Note: the combine order follows the example layout: front / top / left
combine_svgs(
view1_path=view_paths["front"],
view2_path=view_paths["top"],
view3_path=view_paths["left"],
output_svg=combined_svg,
output_image=combined_png
)
return combined_svg, combined_png
if __name__ == "__main__":
# step -1
glb_result, ply_result = image_to_3d(image=["assets/whiteboard_exported_image.png"])
print(f"glb_result : {glb_result}", f"ply_result : {ply_result}")
# step 0
obj_result = glb_to_obj(glb_result)
print(f"obj_result : {obj_result}")
# step 1
step_result = obj_to_step(input_obj=obj_result, output_dir="step_output", script_path="1_obj_to_step.py")
# step 2
out_dir = "svg_output"
combined_svg, combined_png = step_to_svg(step_path=step_result, out_dir=out_dir)
print(f"combined_svg : {combined_svg}", f"combined_png : {combined_png}")

View File

@@ -0,0 +1,15 @@
from typing import *
class GuidanceIntervalSamplerMixin:
"""
A mixin class for samplers that apply classifier-free guidance with interval.
"""
def _inference_model(self, model, x_t, t, cond, neg_cond, cfg_strength, cfg_interval, **kwargs):
if cfg_interval[0] <= t <= cfg_interval[1]:
pred = super()._inference_model(model, x_t, t, cond, **kwargs)
neg_pred = super()._inference_model(model, x_t, t, neg_cond, **kwargs)
return (1 + cfg_strength) * pred - cfg_strength * neg_pred
else:
return super()._inference_model(model, x_t, t, cond, **kwargs)
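# Editorial note: inside the interval this is standard classifier-free guidance,
# pred + cfg_strength * (pred - neg_pred). For example, with cfg_strength = 3.0:
# >>> pred, neg_pred = 2.0, 1.0
# >>> (1 + 3.0) * pred - 3.0 * neg_pred
# 5.0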

View File

@@ -0,0 +1,231 @@
#
# Copyright (C) 2023, Inria
# GRAPHDECO research group, https://team.inria.fr/graphdeco
# All rights reserved.
#
# This software is free for non-commercial, research and evaluation use
# under the terms of the LICENSE.md file.
#
# For inquiries contact george.drettakis@inria.fr
#
import torch
import math
from easydict import EasyDict as edict
import numpy as np
from ..representations.gaussian import Gaussian
from .sh_utils import eval_sh
import torch.nn.functional as F
def intrinsics_to_projection(
intrinsics: torch.Tensor,
near: float,
far: float,
) -> torch.Tensor:
"""
OpenCV intrinsics to OpenGL perspective matrix
Args:
intrinsics (torch.Tensor): [3, 3] OpenCV intrinsics matrix
near (float): near plane to clip
far (float): far plane to clip
Returns:
(torch.Tensor): [4, 4] OpenGL perspective matrix
"""
fx, fy = intrinsics[0, 0], intrinsics[1, 1]
cx, cy = intrinsics[0, 2], intrinsics[1, 2]
ret = torch.zeros((4, 4), dtype=intrinsics.dtype, device=intrinsics.device)
ret[0, 0] = 2 * fx
ret[1, 1] = 2 * fy
ret[0, 2] = 2 * cx - 1
ret[1, 2] = - 2 * cy + 1
ret[2, 2] = far / (far - near)
ret[2, 3] = near * far / (near - far)
ret[3, 2] = 1.
return ret
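# Illustrative usage (assumes intrinsics normalized to [0, 1] image coordinates,
# as the 2*fx and 2*cx - 1 terms imply):
# >>> K = torch.tensor([[1.0, 0.0, 0.5],
# ...                   [0.0, 1.0, 0.5],
# ...                   [0.0, 0.0, 1.0]])
# >>> intrinsics_to_projection(K, near=0.1, far=100.0)[0, 0]
# tensor(2.)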
def render(viewpoint_camera, pc : Gaussian, pipe, bg_color : torch.Tensor, scaling_modifier = 1.0, override_color = None):
"""
Render the scene.
Background tensor (bg_color) must be on GPU!
"""
# lazy import
if 'GaussianRasterizer' not in globals():
from diff_gaussian_rasterization import GaussianRasterizer, GaussianRasterizationSettings
# Create zero tensor. We will use it to make pytorch return gradients of the 2D (screen-space) means
screenspace_points = torch.zeros_like(pc.get_xyz, dtype=pc.get_xyz.dtype, requires_grad=True, device="cuda") + 0
try:
screenspace_points.retain_grad()
except Exception:
pass
# Set up rasterization configuration
tanfovx = math.tan(viewpoint_camera.FoVx * 0.5)
tanfovy = math.tan(viewpoint_camera.FoVy * 0.5)
kernel_size = pipe.kernel_size
subpixel_offset = torch.zeros((int(viewpoint_camera.image_height), int(viewpoint_camera.image_width), 2), dtype=torch.float32, device="cuda")
raster_settings = GaussianRasterizationSettings(
image_height=int(viewpoint_camera.image_height),
image_width=int(viewpoint_camera.image_width),
tanfovx=tanfovx,
tanfovy=tanfovy,
kernel_size=kernel_size,
subpixel_offset=subpixel_offset,
bg=bg_color,
scale_modifier=scaling_modifier,
viewmatrix=viewpoint_camera.world_view_transform,
projmatrix=viewpoint_camera.full_proj_transform,
sh_degree=pc.active_sh_degree,
campos=viewpoint_camera.camera_center,
prefiltered=False,
debug=pipe.debug
)
rasterizer = GaussianRasterizer(raster_settings=raster_settings)
means3D = pc.get_xyz
means2D = screenspace_points
opacity = pc.get_opacity
# If precomputed 3d covariance is provided, use it. If not, then it will be computed from
# scaling / rotation by the rasterizer.
scales = None
rotations = None
cov3D_precomp = None
if pipe.compute_cov3D_python:
cov3D_precomp = pc.get_covariance(scaling_modifier)
else:
scales = pc.get_scaling
rotations = pc.get_rotation
# If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors
# from SHs in Python, do it. If not, then SH -> RGB conversion will be done by rasterizer.
shs = None
colors_precomp = None
if override_color is None:
if pipe.convert_SHs_python:
shs_view = pc.get_features.transpose(1, 2).view(-1, 3, (pc.max_sh_degree+1)**2)
dir_pp = (pc.get_xyz - viewpoint_camera.camera_center.repeat(pc.get_features.shape[0], 1))
dir_pp_normalized = dir_pp/dir_pp.norm(dim=1, keepdim=True)
sh2rgb = eval_sh(pc.active_sh_degree, shs_view, dir_pp_normalized)
colors_precomp = torch.clamp_min(sh2rgb + 0.5, 0.0)
else:
shs = pc.get_features
else:
colors_precomp = override_color
# Rasterize visible Gaussians to image, obtain their radii (on screen).
rendered_image, radii = rasterizer(
means3D = means3D,
means2D = means2D,
shs = shs,
colors_precomp = colors_precomp,
opacities = opacity,
scales = scales,
rotations = rotations,
cov3D_precomp = cov3D_precomp
)
# Those Gaussians that were frustum culled or had a radius of 0 were not visible.
# They will be excluded from value updates used in the splitting criteria.
return edict({"render": rendered_image,
"viewspace_points": screenspace_points,
"visibility_filter" : radii > 0,
"radii": radii})
class GaussianRenderer:
"""
Renderer for the Gaussian representation.
Args:
rendering_options (dict): Rendering options.
"""
def __init__(self, rendering_options=None) -> None:
self.pipe = edict({
"kernel_size": 0.1,
"convert_SHs_python": False,
"compute_cov3D_python": False,
"scale_modifier": 1.0,
"debug": False
})
self.rendering_options = edict({
"resolution": None,
"near": None,
"far": None,
"ssaa": 1,
"bg_color": 'random',
})
self.rendering_options.update(rendering_options or {})
self.bg_color = None
def render(
self,
gaussian: Gaussian,
extrinsics: torch.Tensor,
intrinsics: torch.Tensor,
colors_overwrite: torch.Tensor = None
) -> edict:
"""
Render the Gaussian.
Args:
gaussian (Gaussian): the Gaussian module to render
extrinsics (torch.Tensor): (4, 4) camera extrinsics
intrinsics (torch.Tensor): (3, 3) camera intrinsics
colors_overwrite (torch.Tensor): (N, 3) override color
Returns:
edict containing:
color (torch.Tensor): (3, H, W) rendered color image
"""
resolution = self.rendering_options["resolution"]
near = self.rendering_options["near"]
far = self.rendering_options["far"]
ssaa = self.rendering_options["ssaa"]
if self.rendering_options["bg_color"] == 'random':
self.bg_color = torch.zeros(3, dtype=torch.float32, device="cuda")
if np.random.rand() < 0.5:
self.bg_color += 1
else:
self.bg_color = torch.tensor(self.rendering_options["bg_color"], dtype=torch.float32, device="cuda")
view = extrinsics
perspective = intrinsics_to_projection(intrinsics, near, far)
camera = torch.inverse(view)[:3, 3]
focalx = intrinsics[0, 0]
focaly = intrinsics[1, 1]
fovx = 2 * torch.atan(0.5 / focalx)
fovy = 2 * torch.atan(0.5 / focaly)
camera_dict = edict({
"image_height": resolution * ssaa,
"image_width": resolution * ssaa,
"FoVx": fovx,
"FoVy": fovy,
"znear": near,
"zfar": far,
"world_view_transform": view.T.contiguous(),
"projection_matrix": perspective.T.contiguous(),
"full_proj_transform": (perspective @ view).T.contiguous(),
"camera_center": camera
})
# Render
render_ret = render(camera_dict, gaussian, self.pipe, self.bg_color, override_color=colors_overwrite, scaling_modifier=self.pipe.scale_modifier)
if ssaa > 1:
render_ret.render = F.interpolate(render_ret.render[None], size=(resolution, resolution), mode='bilinear', align_corners=False, antialias=True).squeeze()
ret = edict({
'color': render_ret['render']
})
return ret
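# Illustrative usage (hypothetical values): resolution, near and far must be set
# in rendering_options before calling render.
# renderer = GaussianRenderer({'resolution': 512, 'near': 0.8, 'far': 1.6, 'bg_color': (0, 0, 0)})
# out = renderer.render(gaussian, extrinsics, intrinsics)  # out.color: (3, 512, 512)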

View File

@@ -0,0 +1,133 @@
#
# Copyright (C) 2023, Inria
# GRAPHDECO research group, https://team.inria.fr/graphdeco
# All rights reserved.
#
# This software is free for non-commercial, research and evaluation use
# under the terms of the LICENSE.md file.
#
# For inquiries contact george.drettakis@inria.fr
#
import torch
import sys
from datetime import datetime
import numpy as np
import random
def inverse_sigmoid(x):
return torch.log(x/(1-x))
def PILtoTorch(pil_image, resolution):
resized_image_PIL = pil_image.resize(resolution)
resized_image = torch.from_numpy(np.array(resized_image_PIL)) / 255.0
if len(resized_image.shape) == 3:
return resized_image.permute(2, 0, 1)
else:
return resized_image.unsqueeze(dim=-1).permute(2, 0, 1)
def get_expon_lr_func(
lr_init, lr_final, lr_delay_steps=0, lr_delay_mult=1.0, max_steps=1000000
):
"""
Copied from Plenoxels
Continuous learning rate decay function. Adapted from JaxNeRF
The returned rate is lr_init when step=0 and lr_final when step=max_steps, and
is log-linearly interpolated elsewhere (equivalent to exponential decay).
If lr_delay_steps>0 then the learning rate will be scaled by some smooth
function of lr_delay_mult, such that the initial learning rate is
lr_init*lr_delay_mult at the beginning of optimization but will be eased back
to the normal learning rate when steps>lr_delay_steps.
:param max_steps: int, the number of steps during optimization.
:return HoF which takes step as input
"""
def helper(step):
if step < 0 or (lr_init == 0.0 and lr_final == 0.0):
# Disable this parameter
return 0.0
if lr_delay_steps > 0:
# A kind of reverse cosine decay.
delay_rate = lr_delay_mult + (1 - lr_delay_mult) * np.sin(
0.5 * np.pi * np.clip(step / lr_delay_steps, 0, 1)
)
else:
delay_rate = 1.0
t = np.clip(step / max_steps, 0, 1)
log_lerp = np.exp(np.log(lr_init) * (1 - t) + np.log(lr_final) * t)
return delay_rate * log_lerp
return helper
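# Illustrative usage: the schedule is log-linear between the endpoints, so the
# midpoint is the geometric mean of lr_init and lr_final.
# >>> lr_fn = get_expon_lr_func(lr_init=1e-2, lr_final=1e-4, max_steps=1000)
# >>> lr_fn(0), lr_fn(500), lr_fn(1000)   # approx (1e-2, 1e-3, 1e-4)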
def strip_lowerdiag(L):
uncertainty = torch.zeros((L.shape[0], 6), dtype=torch.float, device="cuda")
uncertainty[:, 0] = L[:, 0, 0]
uncertainty[:, 1] = L[:, 0, 1]
uncertainty[:, 2] = L[:, 0, 2]
uncertainty[:, 3] = L[:, 1, 1]
uncertainty[:, 4] = L[:, 1, 2]
uncertainty[:, 5] = L[:, 2, 2]
return uncertainty
def strip_symmetric(sym):
return strip_lowerdiag(sym)
def build_rotation(r):
norm = torch.sqrt(r[:,0]*r[:,0] + r[:,1]*r[:,1] + r[:,2]*r[:,2] + r[:,3]*r[:,3])
q = r / norm[:, None]
R = torch.zeros((q.size(0), 3, 3), device='cuda')
r = q[:, 0]
x = q[:, 1]
y = q[:, 2]
z = q[:, 3]
R[:, 0, 0] = 1 - 2 * (y*y + z*z)
R[:, 0, 1] = 2 * (x*y - r*z)
R[:, 0, 2] = 2 * (x*z + r*y)
R[:, 1, 0] = 2 * (x*y + r*z)
R[:, 1, 1] = 1 - 2 * (x*x + z*z)
R[:, 1, 2] = 2 * (y*z - r*x)
R[:, 2, 0] = 2 * (x*z - r*y)
R[:, 2, 1] = 2 * (y*z + r*x)
R[:, 2, 2] = 1 - 2 * (x*x + y*y)
return R
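# Editorial note: build_rotation above interprets quaternions as (w, x, y, z),
# with the scalar part first, matching the usual 3D Gaussian Splatting convention.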
def build_scaling_rotation(s, r):
L = torch.zeros((s.shape[0], 3, 3), dtype=torch.float, device="cuda")
R = build_rotation(r)
L[:,0,0] = s[:,0]
L[:,1,1] = s[:,1]
L[:,2,2] = s[:,2]
L = R @ L
return L
def safe_state(silent):
old_f = sys.stdout
class F:
def __init__(self, silent):
self.silent = silent
def write(self, x):
if not self.silent:
if x.endswith("\n"):
old_f.write(x.replace("\n", " [{}]\n".format(str(datetime.now().strftime("%d/%m %H:%M:%S")))))
else:
old_f.write(x)
def flush(self):
old_f.flush()
sys.stdout = F(silent)
random.seed(0)
np.random.seed(0)
torch.manual_seed(0)
torch.cuda.set_device(torch.device("cuda:0"))

View File

@@ -0,0 +1,93 @@
from typing import *
import torch
import torch.nn.functional as F
from torchvision import transforms
import numpy as np
from PIL import Image
from ....utils import dist_utils
class ImageConditionedMixin:
"""
Mixin for image-conditioned models.
Args:
image_cond_model: The image conditioning model.
"""
def __init__(self, *args, image_cond_model: str = 'dinov2_vitl14_reg', **kwargs):
super().__init__(*args, **kwargs)
self.image_cond_model_name = image_cond_model
self.image_cond_model = None  # the model is initialized lazily
@staticmethod
def prepare_for_training(image_cond_model: str, **kwargs):
"""
Prepare for training.
"""
if hasattr(super(ImageConditionedMixin, ImageConditionedMixin), 'prepare_for_training'):
super(ImageConditionedMixin, ImageConditionedMixin).prepare_for_training(**kwargs)
# download the model
torch.hub.load('facebookresearch/dinov2', image_cond_model, pretrained=True)
def _init_image_cond_model(self):
"""
Initialize the image conditioning model.
"""
with dist_utils.local_master_first():
dinov2_model = torch.hub.load('facebookresearch/dinov2', self.image_cond_model_name, pretrained=True)
dinov2_model.eval().cuda()
transform = transforms.Compose([
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
self.image_cond_model = {
'model': dinov2_model,
'transform': transform,
}
@torch.no_grad()
def encode_image(self, image: Union[torch.Tensor, List[Image.Image]]) -> torch.Tensor:
"""
Encode the image.
"""
if isinstance(image, torch.Tensor):
assert image.ndim == 4, "Image tensor should be batched (B, C, H, W)"
elif isinstance(image, list):
assert all(isinstance(i, Image.Image) for i in image), "Image list should be list of PIL images"
image = [i.resize((518, 518), Image.LANCZOS) for i in image]
image = [np.array(i.convert('RGB')).astype(np.float32) / 255 for i in image]
image = [torch.from_numpy(i).permute(2, 0, 1).float() for i in image]
image = torch.stack(image).cuda()
else:
raise ValueError(f"Unsupported type of image: {type(image)}")
if self.image_cond_model is None:
self._init_image_cond_model()
image = self.image_cond_model['transform'](image).cuda()
features = self.image_cond_model['model'](image, is_training=True)['x_prenorm']
patchtokens = F.layer_norm(features, features.shape[-1:])
return patchtokens
def get_cond(self, cond, **kwargs):
"""
Get the conditioning data.
"""
cond = self.encode_image(cond)
kwargs['neg_cond'] = torch.zeros_like(cond)
cond = super().get_cond(cond, **kwargs)
return cond
def get_inference_cond(self, cond, **kwargs):
"""
Get the conditioning data for inference.
"""
cond = self.encode_image(cond)
kwargs['neg_cond'] = torch.zeros_like(cond)
cond = super().get_inference_cond(cond, **kwargs)
return cond
def vis_cond(self, cond, **kwargs):
"""
Visualize the conditioning data.
"""
return {'image': {'value': cond, 'type': 'image'}}
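# Illustrative usage (hypothetical base class): the mixin is composed with a
# model/trainer class that defines get_cond and get_inference_cond.
# class ImageConditionedTrainer(ImageConditionedMixin, BaseTrainer):  # BaseTrainer is hypothetical
#     pass
# For 518x518 inputs, encode_image returns DINOv2 ViT-L/14 prenorm tokens of
# shape roughly (B, 1374, 1024) (class + 4 register + 37x37 patch tokens).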

View File

@@ -0,0 +1,202 @@
import re
import numpy as np
import cv2
import torch
import contextlib
# Dictionary utils
def _dict_merge(dicta, dictb, prefix=''):
"""
Merge two dictionaries.
"""
assert isinstance(dicta, dict), 'input must be a dictionary'
assert isinstance(dictb, dict), 'input must be a dictionary'
dict_ = {}
all_keys = set(dicta.keys()).union(set(dictb.keys()))
for key in all_keys:
if key in dicta.keys() and key in dictb.keys():
if isinstance(dicta[key], dict) and isinstance(dictb[key], dict):
dict_[key] = _dict_merge(dicta[key], dictb[key], prefix=f'{prefix}.{key}')
else:
raise ValueError(f'Duplicate key {prefix}.{key} found in both dictionaries. Types: {type(dicta[key])}, {type(dictb[key])}')
elif key in dicta.keys():
dict_[key] = dicta[key]
else:
dict_[key] = dictb[key]
return dict_
def dict_merge(dicta, dictb):
"""
Merge two dictionaries.
"""
return _dict_merge(dicta, dictb, prefix='')
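# Illustrative usage (key order may vary; duplicate non-dict keys raise ValueError):
# >>> dict_merge({'a': 1, 'b': {'c': 2}}, {'b': {'d': 3}})
# {'a': 1, 'b': {'c': 2, 'd': 3}}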
def dict_foreach(dic, func, special_func={}):
"""
Recursively apply a function to all non-dictionary leaf values in a dictionary.
"""
assert isinstance(dic, dict), 'input must be a dictionary'
for key in dic.keys():
if isinstance(dic[key], dict):
dic[key] = dict_foreach(dic[key], func, special_func)
else:
if key in special_func.keys():
dic[key] = special_func[key](dic[key])
else:
dic[key] = func(dic[key])
return dic
def dict_reduce(dicts, func, special_func={}):
"""
Reduce a list of dictionaries. Leaf values must be scalars.
"""
assert isinstance(dicts, list), 'input must be a list of dictionaries'
assert all([isinstance(d, dict) for d in dicts]), 'input must be a list of dictionaries'
assert len(dicts) > 0, 'input must be a non-empty list of dictionaries'
all_keys = set([key for dict_ in dicts for key in dict_.keys()])
reduced_dict = {}
for key in all_keys:
vlist = [dict_[key] for dict_ in dicts if key in dict_.keys()]
if isinstance(vlist[0], dict):
reduced_dict[key] = dict_reduce(vlist, func, special_func)
else:
if key in special_func.keys():
reduced_dict[key] = special_func[key](vlist)
else:
reduced_dict[key] = func(vlist)
return reduced_dict
def dict_any(dic, func):
"""
Return True if func evaluates to True for any non-dictionary leaf value.
"""
assert isinstance(dic, dict), 'input must be a dictionary'
for key in dic.keys():
if isinstance(dic[key], dict):
if dict_any(dic[key], func):
return True
else:
if func(dic[key]):
return True
return False
def dict_all(dic, func):
"""
Return True if func evaluates to True for all non-dictionary leaf values.
"""
assert isinstance(dic, dict), 'input must be a dictionary'
for key in dic.keys():
if isinstance(dic[key], dict):
if not dict_all(dic[key], func):
return False
else:
if not func(dic[key]):
return False
return True
def dict_flatten(dic, sep='.'):
"""
Flatten a nested dictionary into a dictionary with no nested dictionaries.
"""
assert isinstance(dic, dict), 'input must be a dictionary'
flat_dict = {}
for key in dic.keys():
if isinstance(dic[key], dict):
sub_dict = dict_flatten(dic[key], sep=sep)
for sub_key in sub_dict.keys():
flat_dict[str(key) + sep + str(sub_key)] = sub_dict[sub_key]
else:
flat_dict[key] = dic[key]
return flat_dict
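# Illustrative usage:
# >>> dict_flatten({'a': {'b': 1, 'c': {'d': 2}}})
# {'a.b': 1, 'a.c.d': 2}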
# Context utils
@contextlib.contextmanager
def nested_contexts(*contexts):
with contextlib.ExitStack() as stack:
for ctx in contexts:
stack.enter_context(ctx())
yield
# Image utils
def make_grid(images, nrow=None, ncol=None, aspect_ratio=None):
num_images = len(images)
if nrow is None and ncol is None:
if aspect_ratio is not None:
nrow = int(np.round(np.sqrt(num_images / aspect_ratio)))
else:
nrow = int(np.sqrt(num_images))
ncol = (num_images + nrow - 1) // nrow
elif nrow is None and ncol is not None:
nrow = (num_images + ncol - 1) // ncol
elif nrow is not None and ncol is None:
ncol = (num_images + nrow - 1) // nrow
else:
assert nrow * ncol >= num_images, 'nrow * ncol must be greater than or equal to the number of images'
if images[0].ndim == 2:
grid = np.zeros((nrow * images[0].shape[0], ncol * images[0].shape[1]), dtype=images[0].dtype)
else:
grid = np.zeros((nrow * images[0].shape[0], ncol * images[0].shape[1], images[0].shape[2]), dtype=images[0].dtype)
for i, img in enumerate(images):
row = i // ncol
col = i % ncol
grid[row * img.shape[0]:(row + 1) * img.shape[0], col * img.shape[1]:(col + 1) * img.shape[1]] = img
return grid
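# Illustrative usage: with 5 images, nrow = int(sqrt(5)) = 2 and
# ncol = ceil(5 / 2) = 3, so the grid has 2 rows and 3 columns.
# >>> imgs = [np.full((4, 4, 3), i, np.uint8) for i in range(5)]
# >>> make_grid(imgs).shape
# (8, 12, 3)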
def notes_on_image(img, notes=None):
img = np.pad(img, ((0, 32), (0, 0), (0, 0)), 'constant', constant_values=0)
img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
if notes is not None:
img = cv2.putText(img, notes, (0, img.shape[0] - 4), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 1)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
return img
def save_image_with_notes(img, path, notes=None):
"""
Save an image with notes.
"""
if isinstance(img, torch.Tensor):
img = img.cpu().numpy().transpose(1, 2, 0)
if img.dtype == np.float32 or img.dtype == np.float64:
img = np.clip(img * 255, 0, 255).astype(np.uint8)
img = notes_on_image(img, notes)
cv2.imwrite(path, cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
# debug utils
def atol(x, y):
"""
Absolute difference between x and y.
"""
return torch.abs(x - y)
def rtol(x, y):
"""
Relative difference between x and y.
"""
return torch.abs(x - y) / torch.clamp_min(torch.maximum(torch.abs(x), torch.abs(y)), 1e-12)
# print utils
def indent(s, n=4):
"""
Indent a string.
"""
lines = s.split('\n')
for i in range(1, len(lines)):
lines[i] = ' ' * n + lines[i]
return '\n'.join(lines)

View File

@@ -0,0 +1,81 @@
from typing import *
import torch
import numpy as np
import torch.utils
class AdaptiveGradClipper:
"""
Adaptive gradient clipping for training.
"""
def __init__(
self,
max_norm=None,
clip_percentile=95.0,
buffer_size=1000,
):
self.max_norm = max_norm
self.clip_percentile = clip_percentile
self.buffer_size = buffer_size
self._grad_norm = np.zeros(buffer_size, dtype=np.float32)
self._max_norm = max_norm
self._buffer_ptr = 0
self._buffer_length = 0
def __repr__(self):
return f'AdaptiveGradClipper(max_norm={self.max_norm}, clip_percentile={self.clip_percentile})'
def state_dict(self):
return {
'grad_norm': self._grad_norm,
'max_norm': self._max_norm,
'buffer_ptr': self._buffer_ptr,
'buffer_length': self._buffer_length,
}
def load_state_dict(self, state_dict):
self._grad_norm = state_dict['grad_norm']
self._max_norm = state_dict['max_norm']
self._buffer_ptr = state_dict['buffer_ptr']
self._buffer_length = state_dict['buffer_length']
def log(self):
return {
'max_norm': self._max_norm,
}
def __call__(self, parameters, norm_type=2.0, error_if_nonfinite=False, foreach=None):
"""Clip the gradient norm of an iterable of parameters.
The norm is computed over all gradients together, as if they were
concatenated into a single vector. Gradients are modified in-place.
Args:
parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
single Tensor that will have gradients normalized
norm_type (float): type of the used p-norm. Can be ``'inf'`` for
infinity norm.
error_if_nonfinite (bool): if True, an error is thrown if the total
norm of the gradients from :attr:`parameters` is ``nan``,
``inf``, or ``-inf``. Default: False (will switch to True in the future)
foreach (bool): use the faster foreach-based implementation.
If ``None``, use the foreach implementation for CUDA and CPU native tensors and silently
fall back to the slow implementation for other device types.
Default: ``None``
Returns:
Total norm of the parameter gradients (viewed as a single vector).
"""
max_norm = self._max_norm if self._max_norm is not None else float('inf')
grad_norm = torch.nn.utils.clip_grad_norm_(parameters, max_norm=max_norm, norm_type=norm_type, error_if_nonfinite=error_if_nonfinite, foreach=foreach)
if torch.isfinite(grad_norm):
self._grad_norm[self._buffer_ptr] = grad_norm
self._buffer_ptr = (self._buffer_ptr + 1) % self.buffer_size
self._buffer_length = min(self._buffer_length + 1, self.buffer_size)
if self._buffer_length == self.buffer_size:
self._max_norm = np.percentile(self._grad_norm, self.clip_percentile)
self._max_norm = min(self._max_norm, self.max_norm) if self.max_norm is not None else self._max_norm
return grad_norm
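# Illustrative training-loop usage (hypothetical names): once the buffer holds
# buffer_size finite norms, the clip threshold tracks the chosen percentile.
# clipper = AdaptiveGradClipper(max_norm=10.0, clip_percentile=95.0)
# loss.backward()
# grad_norm = clipper(model.parameters())
# optimizer.step()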