103
dataset_toolkits/datasets/HSSD.py
Normal file
@@ -0,0 +1,103 @@
import os
import re
import argparse
import tarfile
from concurrent.futures import ThreadPoolExecutor
from tqdm import tqdm
import pandas as pd
import huggingface_hub

from utils import get_file_hash


def add_args(parser: argparse.ArgumentParser):
    pass


def get_metadata(**kwargs):
    metadata = pd.read_csv("hf://datasets/JeffreyXiang/TRELLIS-500K/HSSD.csv")
    return metadata


def download(metadata, output_dir, **kwargs):
    os.makedirs(os.path.join(output_dir, 'raw'), exist_ok=True)

    # Check that the user is logged in to the Hugging Face Hub.
    try:
        huggingface_hub.whoami()
    except Exception:
        print("\033[93m")
        print("Haven't logged in to the Hugging Face Hub.")
        print("Visit https://huggingface.co/settings/tokens to get a token.")
        print("\033[0m")
        huggingface_hub.login()

    # Check that the user has access to the gated HSSD dataset.
    try:
        huggingface_hub.hf_hub_download(repo_id="hssd/hssd-models", filename="README.md", repo_type="dataset")
    except Exception:
        print("\033[93m")
        print("Error downloading HSSD dataset.")
        print("Check if you have access to the HSSD dataset.")
        print("Visit https://huggingface.co/datasets/hssd/hssd-models for more information")
        print("\033[0m")

    downloaded = {}
    metadata = metadata.set_index("file_identifier")
    with ThreadPoolExecutor(max_workers=os.cpu_count()) as executor, \
        tqdm(total=len(metadata), desc="Downloading") as pbar:
        def worker(instance: str) -> str:
            try:
                huggingface_hub.hf_hub_download(repo_id="hssd/hssd-models", filename=instance, repo_type="dataset", local_dir=os.path.join(output_dir, 'raw'))
                sha256 = get_file_hash(os.path.join(output_dir, 'raw', instance))
                pbar.update()
                return sha256
            except Exception as e:
                pbar.update()
                print(f"Error downloading {instance}: {e}")
                return None

        sha256s = executor.map(worker, metadata.index)
        executor.shutdown(wait=True)

    for k, sha256 in zip(metadata.index, sha256s):
        if sha256 is not None:
            if sha256 == metadata.loc[k, "sha256"]:
                downloaded[sha256] = os.path.join('raw', k)
            else:
                print(f"Error downloading {k}: sha256s do not match")

    return pd.DataFrame(downloaded.items(), columns=['sha256', 'local_path'])
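
# Usage sketch (illustrative, not part of the toolkit; assumes Hub access to
# JeffreyXiang/TRELLIS-500K and the gated hssd/hssd-models dataset):
#   metadata = get_metadata()
#   downloaded = download(metadata, output_dir="datasets/HSSD")
#   downloaded.to_csv("datasets/HSSD/downloaded.csv", index=False)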


def foreach_instance(metadata, output_dir, func, max_workers=None, desc='Processing objects') -> pd.DataFrame:
    # load metadata
    metadata = metadata.to_dict('records')

    # process objects
    records = []
    max_workers = max_workers or os.cpu_count()
    try:
        with ThreadPoolExecutor(max_workers=max_workers) as executor, \
            tqdm(total=len(metadata), desc=desc) as pbar:
            def worker(metadatum):
                try:
                    local_path = metadatum['local_path']
                    sha256 = metadatum['sha256']
                    file = os.path.join(output_dir, local_path)
                    record = func(file, sha256)
                    if record is not None:
                        records.append(record)
                    pbar.update()
                except Exception as e:
                    # sha256 may be unbound if the lookup itself failed
                    print(f"Error processing object {metadatum.get('sha256', '<unknown>')}: {e}")
                    pbar.update()

            executor.map(worker, metadata)
            executor.shutdown(wait=True)
    except Exception:
        print("Error happened during processing.")

    return pd.DataFrame.from_records(records)
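
# Example (illustrative): collect the on-disk size of every downloaded object.
#   sizes = foreach_instance(
#       downloaded, "datasets/HSSD",
#       lambda file, sha256: {'sha256': sha256, 'bytes': os.path.getsize(file)},
#   )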
555
glb2svg.py
Normal file
@@ -0,0 +1,555 @@
import math
import secrets
from collections import defaultdict
from datetime import datetime

import os
import subprocess
from pathlib import Path
import logging
from typing import Optional

import imageio
import numpy as np
import torch
from PIL import Image

# ---------- PyTorch3D ----------
from pytorch3d.structures import Meshes
from pytorch3d.renderer import (
    RasterizationSettings, MeshRasterizer,
    MeshRenderer, BlendParams, SoftSilhouetteShader,
    look_at_view_transform, OrthographicCameras
)
from pytorch3d.renderer.mesh.rasterizer import Fragments

# ---------- SVG / Image ----------
import svgwrite
from skimage import measure

# ---------- pythonocc (OCC) ----------
from OCC.Core.STEPControl import STEPControl_Reader
from OCC.Core.TopAbs import TopAbs_FACE
from OCC.Core.TopExp import TopExp_Explorer
from OCC.Core.BRep import BRep_Tool
from OCC.Core.BRepMesh import BRepMesh_IncrementalMesh
from OCC.Core.TopLoc import TopLoc_Location
from OCC.Core.ShapeFix import ShapeFix_Shape

# ---------- Combine SVG ----------
import svgutils.transform as st
from svgutils.compose import Unit
import cairosvg

from trellis.pipelines import TrellisImageTo3DPipeline
from trellis.utils import render_utils, postprocessing_utils

logger = logging.getLogger(__name__)

"""single image to 3D"""


def generate_unique_name(original_name: str) -> str:
    stem, ext = os.path.splitext(original_name)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    random_part = secrets.token_hex(4)  # 8 random hex characters
    return f"{stem}_{timestamp}_{random_part}{ext}"


def image_to_3d(image: list[str], out_dir: str = "trellis_out", single_image: bool = False, seed: int = 1,
                steps_sparse: int = 12, cfg_sparse: float = 7.5, steps_slat: int = 12,
                cfg_slat: float = 3.0, simplify: float = 0.95, texture_size: int = 1024,
                save_video: bool = False, fps: int = 30, video_gs_name: str = "sample_gs.mp4", video_rf_name: str = "sample_rf.mp4", video_mesh_name: str = "sample_mesh.mp4"):
    os.makedirs(out_dir, exist_ok=True)
    # Optional env
    os.environ.setdefault("SPCONV_ALGO", "native")
    pipeline = TrellisImageTo3DPipeline.from_pretrained("microsoft/TRELLIS-image-large")
    pipeline.cuda()

    if single_image:
        image = Image.open(image[0])
        outputs = pipeline.run(
            image,
            seed=seed,
            sparse_structure_sampler_params={
                "steps": steps_sparse,
                "cfg_strength": cfg_sparse,
            },
            slat_sampler_params={
                "steps": steps_slat,
                "cfg_strength": cfg_slat,
            },
        )
    else:
        images = [Image.open(p) for p in image]
        outputs = pipeline.run_multi_image(
            images,
            seed=seed,
            sparse_structure_sampler_params={
                "steps": steps_sparse,
                "cfg_strength": cfg_sparse,
            },
            slat_sampler_params={
                "steps": steps_slat,
                "cfg_strength": cfg_slat,
            },
        )

    if save_video:
        video = render_utils.render_video(outputs["gaussian"][0])["color"]
        imageio.mimsave(os.path.join(out_dir, video_gs_name), video, fps=fps)

        video = render_utils.render_video(outputs["radiance_field"][0])["color"]
        imageio.mimsave(os.path.join(out_dir, video_rf_name), video, fps=fps)

        video = render_utils.render_video(outputs["mesh"][0])["normal"]
        imageio.mimsave(os.path.join(out_dir, video_mesh_name), video, fps=fps)

    glb_path = os.path.join(out_dir, generate_unique_name("sample.glb"))
    ply_path = os.path.join(out_dir, generate_unique_name("sample.ply"))

    glb = postprocessing_utils.to_glb(
        outputs["gaussian"][0],
        outputs["mesh"][0],
        simplify=simplify,
        texture_size=texture_size,
    )

    # Write the results out under the unique names generated above.
    glb.export(glb_path)
    outputs["gaussian"][0].save_ply(ply_path)

    return glb_path, ply_path
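
# Minimal usage sketch (assumes a CUDA device, since pipeline.cuda() is called,
# and that the microsoft/TRELLIS-image-large checkpoint can be downloaded):
#   glb_path, ply_path = image_to_3d(["assets/example.png"], single_image=True, save_video=True)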


"""
glb_to_obj
"""


def glb_to_obj(glb_input, obj_output_dir="obj_output"):
    obj_path = os.path.join(obj_output_dir, generate_unique_name("sample.obj"))
    os.makedirs(obj_output_dir, exist_ok=True)

    cmd = [
        "blender",
        "-b",
        "-P",
        "0_glb_to_obj.py",
        "--",
        "--input",
        glb_input,
        "--output",
        obj_path
    ]

    result = subprocess.run(cmd, capture_output=True, text=True)
    print(result)
    print("\n")
    if result.returncode != 0:
        raise RuntimeError(
            f"Blender failed\nstdout:\n{result.stdout}\nstderr:\n{result.stderr}"
        )

    return obj_path
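
# The subprocess call above is equivalent to the shell command:
#   blender -b -P 0_glb_to_obj.py -- --input <input.glb> --output <output.obj>
# 0_glb_to_obj.py is an external Blender script referenced by relative path,
# so this is expected to run from the directory that contains it.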


"""
obj_to_step
"""


def obj_to_step(
    input_obj: str | Path,
    output_dir: str | Path,
    script_path: str | Path = "1_obj_to_step.py",
    timeout: int = 120,  # seconds; adjust to the actual conversion time
    freecadcmd_path: Optional[str] = None
) -> str | bool:
    """
    Wrap the FreeCAD command that converts an OBJ file to STEP.
    On success, returns the last token of FreeCAD's stdout (the path of the
    generated STEP file); returns False on failure.
    """
    input_obj = Path(input_obj).resolve()
    output_dir = Path(output_dir).resolve()
    script_path = Path(script_path).resolve()

    print(f"input_obj : {input_obj}")
    print(f"output_dir : {output_dir}")
    print(f"script_path : {script_path}")

    if not input_obj.exists():
        raise FileNotFoundError(f"Input file does not exist: {input_obj}")
    if not script_path.exists():
        raise FileNotFoundError(f"Conversion script does not exist: {script_path}")

    # Build the -c argument (replicating the original command line).
    python_code = f'''import sys
sys.argv = ["{script_path.name}", "{input_obj}", "{output_dir}"]
exec(open("{script_path}", "r", encoding="utf-8").read())'''

    cmd = [
        freecadcmd_path or "freecadcmd",  # pass an absolute path if freecadcmd is not on PATH
        "-c",
        python_code
    ]

    try:
        result = subprocess.run(
            cmd,
            capture_output=True,
            text=True,
            check=True,  # raises CalledProcessError on failure
            timeout=timeout,
            cwd=script_path.parent  # so the script can resolve files by relative path
        )

        logger.info(f"FreeCAD conversion succeeded: {input_obj.name} → {output_dir}")
        print(f"FreeCAD conversion succeeded: {input_obj.name} → {output_dir}")
        if result.stdout:
            logger.debug(result.stdout)
        return result.stdout.split()[-1]

    except subprocess.TimeoutExpired:
        logger.error(f"FreeCAD conversion timed out (>{timeout}s): {input_obj}")
        return False
    except subprocess.CalledProcessError as e:
        logger.error(f"FreeCAD execution failed: {e.stderr}")
        return False
    except Exception:
        logger.exception("Unexpected error")
        return False
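
# The constructed command is equivalent to running, from script_path.parent:
#   freecadcmd -c 'import sys; sys.argv = ["1_obj_to_step.py", "<in.obj>", "<out_dir>"]; exec(open("1_obj_to_step.py").read())'
# i.e. the conversion script is executed inside FreeCAD's bundled Python.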


"""
step to svg
"""


def autodevice():
    return torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


def read_step_solid(path: str):
    reader = STEPControl_Reader()
    stat = reader.ReadFile(path)
    if stat != 1:
        raise RuntimeError(f"Failed to read STEP file: {path}")
    reader.TransferRoots()
    shape = reader.OneShape()
    fixer = ShapeFix_Shape(shape)
    fixer.Perform()
    return fixer.Shape()


def triangulate(shape, deflection=0.5, angle=0.5):
    """OCC triangulation: returns (V: np.float32 (N, 3), F: np.int64 (M, 3))."""
    BRepMesh_IncrementalMesh(shape, deflection, False, angle, True)
    verts, faces, v_off = [], [], 0
    exp = TopExp_Explorer(shape, TopAbs_FACE)
    while exp.More():
        face = exp.Current()
        loc = TopLoc_Location()
        tri = BRep_Tool.Triangulation(face, loc)
        if tri is not None:
            nb_nodes = tri.NbNodes()
            # Older pythonocc builds expose Nodes()/Triangles() arrays; newer
            # ones expose Node(i)/Triangle(i) accessors. Support both.
            has_nodes_arr = hasattr(tri, "Nodes")
            for i in range(1, nb_nodes + 1):
                p = tri.Nodes().Value(i) if has_nodes_arr else tri.Node(i)
                p = p.Transformed(loc.Transformation())
                verts.append([p.X(), p.Y(), p.Z()])
            nb_tris = tri.NbTriangles()
            has_tris_arr = hasattr(tri, "Triangles")
            for i in range(1, nb_tris + 1):
                t = tri.Triangles().Value(i) if has_tris_arr else tri.Triangle(i)
                a, b, c = t.Get()
                faces.append([v_off + a - 1, v_off + b - 1, v_off + c - 1])
            v_off += nb_nodes
        exp.Next()
    if not verts or not faces:
        raise RuntimeError("Triangulation is empty; try a smaller --defl")
    return np.asarray(verts, np.float32), np.asarray(faces, np.int64)
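
# Note: vertices are duplicated per OCC face (v_off advances one block per
# face), so triangles on either side of a B-rep face border never share vertex
# indices. In extract_feature_edges() below, those edges therefore look like
# boundary edges (one incident face) and are always kept as feature edges.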


def normalize_mesh_np(verts: np.ndarray, unit_scale=1.0):
    # Center at the origin and scale the largest |coordinate| to unit_scale.
    c = verts.mean(axis=0, keepdims=True)
    v0 = verts - c
    s = np.max(np.abs(v0))
    s = max(s, 1e-12)
    return v0 / s * unit_scale


def build_mesh(V_np, F_np, device):
    V = torch.from_numpy(V_np).to(device)
    F = torch.from_numpy(F_np).to(device)
    return Meshes(verts=[V], faces=[F])


def compute_face_normals(verts: torch.Tensor, faces: torch.Tensor):
    v0 = verts[faces[:, 0]]
    v1 = verts[faces[:, 1]]
    v2 = verts[faces[:, 2]]
    n = torch.cross(v1 - v0, v2 - v0, dim=1)
    return torch.nn.functional.normalize(n, dim=1, eps=1e-12)


def extract_feature_edges(faces: torch.Tensor, face_normals: torch.Tensor, angle_deg: float):
    """Edges with dihedral angle >= angle_deg are feature edges; boundary edges are always kept."""
    thr = math.cos(math.radians(max(0.0, angle_deg)))
    e2f = defaultdict(list)
    F = faces.shape[0]
    for f in range(F):
        i, j, k = faces[f].tolist()
        for a, b in ((i, j), (j, k), (k, i)):
            e = (a, b) if a < b else (b, a)
            e2f[e].append(f)
    feat = []
    for e, fl in e2f.items():
        if len(fl) == 1:
            # boundary edge
            feat.append(e)
        elif len(fl) == 2:
            n0, n1 = face_normals[fl[0]], face_normals[fl[1]]
            cosv = torch.clamp((n0 * n1).sum(), -1.0, 1.0).item()
            if cosv <= thr:
                feat.append(e)
    return feat


def make_camera(view: str, device, dist=2.7):
    """
    look_at_view_transform spherical params for engineering views.
    """
    if view == "top":
        # from +Z to origin (view -Z)
        R, T = look_at_view_transform(dist=dist, elev=90.0, azim=180.0, device=device)
    elif view == "left":
        # from -X to origin (view +X)
        R, T = look_at_view_transform(dist=dist, elev=0.0, azim=270.0, device=device)
    elif view == "front":
        # from +Y to origin (view -Y)
        R, T = look_at_view_transform(dist=dist, elev=0.0, azim=180.0, device=device)
    else:
        raise ValueError("view must be one of ['top','left','front']")
    return OrthographicCameras(R=R, T=T, device=device)


def render_silhouette(mesh: Meshes, cameras, image_size=1600):
    rs = RasterizationSettings(
        image_size=image_size,
        blur_radius=1e-6,
        faces_per_pixel=50,
        cull_backfaces=True
    )
    renderer = MeshRenderer(
        rasterizer=MeshRasterizer(cameras=cameras, raster_settings=rs),
        shader=SoftSilhouetteShader(blend_params=BlendParams(sigma=1e-4, gamma=1e-4))
    )
    with torch.no_grad():
        img = renderer(mesh, cameras=cameras)  # (1, H, W, 4)
    return np.clip(img[0, ..., 3].detach().cpu().numpy(), 0.0, 1.0)


def raster_fragments(mesh: Meshes, cameras, image_size=1600, faces_per_pixel=1):
    rs = RasterizationSettings(
        image_size=image_size,
        blur_radius=0.0,
        faces_per_pixel=faces_per_pixel,
        cull_backfaces=True
    )
    rast = MeshRasterizer(cameras=cameras, raster_settings=rs)
    with torch.no_grad():
        frags: Fragments = rast(mesh, cameras=cameras)
    return frags, frags.zbuf[0, ..., 0]


def project_points_to_screen(cameras, points_world: torch.Tensor, image_size: int):
    with torch.no_grad():
        scr = cameras.transform_points_screen(
            points_world[None, ...],
            image_size=((image_size, image_size),)
        )
    return scr[0]


def edge_visibility_split(edges_idx, verts_world, cameras, zbuf, image_size, eps=1e-4):
    H = W = image_size
    vis, hid = [], []
    scr = project_points_to_screen(cameras, verts_world, image_size)
    zmin = zbuf.detach().cpu().numpy()
    for i, j in edges_idx:
        p0, p1 = scr[i], scr[j]
        m = 0.5 * (p0 + p1)
        x = int(np.clip(m[0].item(), 0, W - 1))
        y = int(np.clip(m[1].item(), 0, H - 1))
        z_proj = m[2].item()
        z_ref = zmin[y, x]
        seg = (p0[0].item(), p0[1].item(), p1[0].item(), p1[1].item())
        if np.isfinite(z_ref) and (z_proj <= z_ref + eps):
            vis.append(seg)
        else:
            hid.append(seg)
    return vis, hid
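
# Note: visibility is decided by a single depth test at the segment midpoint,
# so a partially occluded edge is classified as a whole. Subdividing each edge
# and testing every sub-segment would give finer hidden-line results.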


def trace_silhouettes(alpha: np.ndarray, threshold=0.5, step=1):
    contours = measure.find_contours(alpha, level=threshold)
    polys = []
    for cnt in contours:
        if len(cnt) < 8:
            continue
        cnt = cnt[::max(1, step)]
        polys.append(np.stack([cnt[:, 1], cnt[:, 0]], axis=1))  # (x, y)
    return polys


def svg_from_view(out_svg, W, H, silhouettes, edges_vis, edges_hid, stroke=2.0, margin=10, fill="none"):
    dwg = svgwrite.Drawing(out_svg, size=(W + 2 * margin, H + 2 * margin))
    dwg.add(dwg.rect(insert=(0, 0), size=(W + 2 * margin, H + 2 * margin), fill="none"))

    # silhouette fill (optional)
    for poly in silhouettes:
        if len(poly) < 3:
            continue
        path = [f"M {poly[0, 0] + margin:.2f} {poly[0, 1] + margin:.2f}"]
        for i in range(1, len(poly)):
            path.append(f"L {poly[i, 0] + margin:.2f} {poly[i, 1] + margin:.2f}")
        path.append("Z")
        dwg.add(dwg.path(" ".join(path), fill=fill, stroke="none"))

    # visible edges
    for (x0, y0, x1, y1) in edges_vis:
        dwg.add(dwg.line(
            (x0 + margin, y0 + margin),
            (x1 + margin, y1 + margin),
            stroke="black",
            stroke_width=stroke
        ))

    # hidden edges
    for (x0, y0, x1, y1) in edges_hid:
        dwg.add(dwg.line(
            (x0 + margin, y0 + margin),
            (x1 + margin, y1 + margin),
            stroke="black",
            stroke_width=max(1.0, 0.75 * stroke),
            stroke_dasharray=[6, 6],
            opacity=0.85
        ))

    dwg.save()


def combine_svgs(view1_path, view2_path, view3_path, output_svg, output_image):
    """
    Concatenate three SVGs side by side and additionally export a PNG.
    """
    out_dir = os.path.dirname(output_svg)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    view1 = st.fromfile(view1_path)
    view2 = st.fromfile(view2_path)
    view3 = st.fromfile(view3_path)

    r1 = view1.getroot()
    r2 = view2.getroot()
    r3 = view3.getroot()

    w1, h1 = [float(x.replace("px", "")) for x in view1.get_size()]
    w2, h2 = [float(x.replace("px", "")) for x in view2.get_size()]
    w3, h3 = [float(x.replace("px", "")) for x in view3.get_size()]

    max_w = max(w1, w2, w3)
    max_h = max(h1, h2, h3)

    combined_width = max_w * 3 + 40
    combined_height = max_h + 20

    combined = st.SVGFigure(Unit(combined_width), Unit(combined_height))

    x_offset = 10
    y_offset = 10

    r1.moveto(x_offset, y_offset)
    x_offset += max_w + 20
    r2.moveto(x_offset, y_offset)
    x_offset += max_w + 20
    r3.moveto(x_offset, y_offset)

    combined.append([r1, r2, r3])
    combined.save(output_svg)

    # SVG -> PNG
    cairosvg.svg2png(url=output_svg, write_to=output_image)


def step_to_svg(step_path, out_dir,
                res: int = 1400, defl: float = 1.5,
                ang: float = 65.0, stroke: float = 1.3,
                scale: float = 1.0, no_combine: bool = False,
                combined_svg: str = None, combined_png: str = None):
    os.makedirs(out_dir, exist_ok=True)
    device = autodevice()
    print(f"[Device] {device}")

    # STEP -> mesh
    shape = read_step_solid(step_path)
    V_np, F_np = triangulate(shape, deflection=defl)
    V_np = normalize_mesh_np(V_np, unit_scale=scale)

    mesh = build_mesh(V_np, F_np, device)
    verts = mesh.verts_packed()
    faces = mesh.faces_packed()
    fn = compute_face_normals(verts, faces)
    feat_edges = extract_feature_edges(faces, fn, angle_deg=ang)

    view_paths = {}
    for name in ["top", "left", "front"]:
        print(f"[View] {name}")
        cam = make_camera(name, device)
        alpha = render_silhouette(mesh, cam, image_size=res)
        _, zbuf = raster_fragments(mesh, cam, image_size=res, faces_per_pixel=1)

        e_vis, e_hid = edge_visibility_split(feat_edges, verts, cam, zbuf, res, eps=1e-4)
        polys = trace_silhouettes(alpha, threshold=0.5, step=1)

        out_svg = os.path.join(out_dir, f"{name}.svg")
        svg_from_view(out_svg, res, res, polys, e_vis, e_hid,
                      stroke=stroke, margin=10, fill="none")
        view_paths[name] = out_svg
        print(f" -> {out_svg} (edges: vis={len(e_vis)}, hid={len(e_hid)}, polys={len(polys)})")

    if no_combine:
        print("[Done] 3 views exported (no combine).")
        return None

    combined_svg = combined_svg or os.path.join(out_dir, "combined_views.svg")
    combined_png = combined_png or os.path.join(out_dir, "combined_views.png")

    # Note: the views are combined in the order front / top / left.
    combine_svgs(
        view1_path=view_paths["front"],
        view2_path=view_paths["top"],
        view3_path=view_paths["left"],
        output_svg=combined_svg,
        output_image=combined_png
    )
    return combined_svg, combined_png


if __name__ == "__main__":
    # step -1: single image -> textured GLB + Gaussian PLY
    glb_result, ply_result = image_to_3d(image=["assets/whiteboard_exported_image.png"])
    print(f"glb_result : {glb_result}", f"ply_result : {ply_result}")

    # step 0: GLB -> OBJ via Blender
    obj_result = glb_to_obj(glb_result)
    print(f"obj_result : {obj_result}")

    # step 1: OBJ -> STEP via FreeCAD
    step_result = obj_to_step(input_obj=obj_result, output_dir="step_output", script_path="1_obj_to_step.py")
    assert step_result, "obj_to_step failed; see the log above"

    # step 2: STEP -> three orthographic SVG views (+ combined PNG)
    out_dir = "svg_output"
    combined_svg, combined_png = step_to_svg(step_path=step_result, out_dir=out_dir)
    print(f"combined_svg : {combined_svg}", f"combined_png : {combined_png}")
15
trellis/pipelines/samplers/guidance_interval_mixin.py
Normal file
@@ -0,0 +1,15 @@
from typing import *


class GuidanceIntervalSamplerMixin:
    """
    A mixin class for samplers that apply classifier-free guidance within an interval.
    """

    def _inference_model(self, model, x_t, t, cond, neg_cond, cfg_strength, cfg_interval, **kwargs):
        if cfg_interval[0] <= t <= cfg_interval[1]:
            pred = super()._inference_model(model, x_t, t, cond, **kwargs)
            neg_pred = super()._inference_model(model, x_t, t, neg_cond, **kwargs)
            # classifier-free guidance: (1 + s) * cond_pred - s * uncond_pred
            return (1 + cfg_strength) * pred - cfg_strength * neg_pred
        else:
            return super()._inference_model(model, x_t, t, cond, **kwargs)
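
# Usage sketch (illustrative; the base sampler name is an assumption):
#   class FlowEulerGuidanceIntervalSampler(GuidanceIntervalSamplerMixin, FlowEulerSampler):
#       pass
# The mixin intercepts _inference_model() and applies CFG only while the
# timestep t lies inside cfg_interval; outside it, the conditional prediction
# is returned unchanged, saving one model evaluation per step.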
231
trellis/renderers/gaussian_render.py
Normal file
@@ -0,0 +1,231 @@
#
# Copyright (C) 2023, Inria
# GRAPHDECO research group, https://team.inria.fr/graphdeco
# All rights reserved.
#
# This software is free for non-commercial, research and evaluation use
# under the terms of the LICENSE.md file.
#
# For inquiries contact george.drettakis@inria.fr
#

import torch
import math
from easydict import EasyDict as edict
import numpy as np
from ..representations.gaussian import Gaussian
from .sh_utils import eval_sh
import torch.nn.functional as F


def intrinsics_to_projection(
    intrinsics: torch.Tensor,
    near: float,
    far: float,
) -> torch.Tensor:
    """
    OpenCV intrinsics to OpenGL perspective matrix

    Args:
        intrinsics (torch.Tensor): [3, 3] OpenCV intrinsics matrix
        near (float): near plane to clip
        far (float): far plane to clip
    Returns:
        (torch.Tensor): [4, 4] OpenGL perspective matrix
    """
    fx, fy = intrinsics[0, 0], intrinsics[1, 1]
    cx, cy = intrinsics[0, 2], intrinsics[1, 2]
    ret = torch.zeros((4, 4), dtype=intrinsics.dtype, device=intrinsics.device)
    ret[0, 0] = 2 * fx
    ret[1, 1] = 2 * fy
    ret[0, 2] = 2 * cx - 1
    ret[1, 2] = -2 * cy + 1
    ret[2, 2] = far / (far - near)
    ret[2, 3] = near * far / (near - far)
    ret[3, 2] = 1.
    return ret
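
# With normalized intrinsics (fx, fy, cx, cy expressed in units of the image
# size), the matrix assembled above is
#   [ 2*fx   0      2*cx - 1    0           ]
#   [ 0      2*fy   1 - 2*cy    0           ]
#   [ 0      0      f/(f-n)     n*f/(n-f)   ]
#   [ 0      0      1           0           ]
# After the perspective divide, depth lands in [0, 1] between the near and far
# planes (z' = 0 at z = n, z' = 1 at z = f).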


def render(viewpoint_camera, pc: Gaussian, pipe, bg_color: torch.Tensor, scaling_modifier=1.0, override_color=None):
    """
    Render the scene.

    Background tensor (bg_color) must be on GPU!
    """
    # lazy import
    if 'GaussianRasterizer' not in globals():
        from diff_gaussian_rasterization import GaussianRasterizer, GaussianRasterizationSettings

    # Create zero tensor. We will use it to make pytorch return gradients of the 2D (screen-space) means
    screenspace_points = torch.zeros_like(pc.get_xyz, dtype=pc.get_xyz.dtype, requires_grad=True, device="cuda") + 0
    try:
        screenspace_points.retain_grad()
    except Exception:
        pass
    # Set up rasterization configuration
    tanfovx = math.tan(viewpoint_camera.FoVx * 0.5)
    tanfovy = math.tan(viewpoint_camera.FoVy * 0.5)

    kernel_size = pipe.kernel_size
    subpixel_offset = torch.zeros((int(viewpoint_camera.image_height), int(viewpoint_camera.image_width), 2), dtype=torch.float32, device="cuda")

    raster_settings = GaussianRasterizationSettings(
        image_height=int(viewpoint_camera.image_height),
        image_width=int(viewpoint_camera.image_width),
        tanfovx=tanfovx,
        tanfovy=tanfovy,
        kernel_size=kernel_size,
        subpixel_offset=subpixel_offset,
        bg=bg_color,
        scale_modifier=scaling_modifier,
        viewmatrix=viewpoint_camera.world_view_transform,
        projmatrix=viewpoint_camera.full_proj_transform,
        sh_degree=pc.active_sh_degree,
        campos=viewpoint_camera.camera_center,
        prefiltered=False,
        debug=pipe.debug
    )

    rasterizer = GaussianRasterizer(raster_settings=raster_settings)

    means3D = pc.get_xyz
    means2D = screenspace_points
    opacity = pc.get_opacity

    # If precomputed 3d covariance is provided, use it. If not, then it will be computed from
    # scaling / rotation by the rasterizer.
    scales = None
    rotations = None
    cov3D_precomp = None
    if pipe.compute_cov3D_python:
        cov3D_precomp = pc.get_covariance(scaling_modifier)
    else:
        scales = pc.get_scaling
        rotations = pc.get_rotation

    # If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors
    # from SHs in Python, do it. If not, then SH -> RGB conversion will be done by rasterizer.
    shs = None
    colors_precomp = None
    if override_color is None:
        if pipe.convert_SHs_python:
            shs_view = pc.get_features.transpose(1, 2).view(-1, 3, (pc.max_sh_degree + 1) ** 2)
            dir_pp = (pc.get_xyz - viewpoint_camera.camera_center.repeat(pc.get_features.shape[0], 1))
            dir_pp_normalized = dir_pp / dir_pp.norm(dim=1, keepdim=True)
            sh2rgb = eval_sh(pc.active_sh_degree, shs_view, dir_pp_normalized)
            colors_precomp = torch.clamp_min(sh2rgb + 0.5, 0.0)
        else:
            shs = pc.get_features
    else:
        colors_precomp = override_color

    # Rasterize visible Gaussians to image, obtain their radii (on screen).
    rendered_image, radii = rasterizer(
        means3D=means3D,
        means2D=means2D,
        shs=shs,
        colors_precomp=colors_precomp,
        opacities=opacity,
        scales=scales,
        rotations=rotations,
        cov3D_precomp=cov3D_precomp
    )

    # Those Gaussians that were frustum culled or had a radius of 0 were not visible.
    # They will be excluded from value updates used in the splitting criteria.
    return edict({"render": rendered_image,
                  "viewspace_points": screenspace_points,
                  "visibility_filter": radii > 0,
                  "radii": radii})


class GaussianRenderer:
    """
    Renderer for the Gaussian representation.

    Args:
        rendering_options (dict): Rendering options.
    """

    def __init__(self, rendering_options={}) -> None:
        self.pipe = edict({
            "kernel_size": 0.1,
            "convert_SHs_python": False,
            "compute_cov3D_python": False,
            "scale_modifier": 1.0,
            "debug": False
        })
        self.rendering_options = edict({
            "resolution": None,
            "near": None,
            "far": None,
            "ssaa": 1,
            "bg_color": 'random',
        })
        self.rendering_options.update(rendering_options)
        self.bg_color = None

    def render(
        self,
        gaussian: Gaussian,
        extrinsics: torch.Tensor,
        intrinsics: torch.Tensor,
        colors_overwrite: torch.Tensor = None
    ) -> edict:
        """
        Render the Gaussian.

        Args:
            gaussian: Gaussian model
            extrinsics (torch.Tensor): (4, 4) camera extrinsics
            intrinsics (torch.Tensor): (3, 3) camera intrinsics
            colors_overwrite (torch.Tensor): (N, 3) override color

        Returns:
            edict containing:
                color (torch.Tensor): (3, H, W) rendered color image
        """
        resolution = self.rendering_options["resolution"]
        near = self.rendering_options["near"]
        far = self.rendering_options["far"]
        ssaa = self.rendering_options["ssaa"]

        if self.rendering_options["bg_color"] == 'random':
            # randomly pick a black or white background
            self.bg_color = torch.zeros(3, dtype=torch.float32, device="cuda")
            if np.random.rand() < 0.5:
                self.bg_color += 1
        else:
            self.bg_color = torch.tensor(self.rendering_options["bg_color"], dtype=torch.float32, device="cuda")

        view = extrinsics
        perspective = intrinsics_to_projection(intrinsics, near, far)
        camera = torch.inverse(view)[:3, 3]
        focalx = intrinsics[0, 0]
        focaly = intrinsics[1, 1]
        fovx = 2 * torch.atan(0.5 / focalx)
        fovy = 2 * torch.atan(0.5 / focaly)

        camera_dict = edict({
            "image_height": resolution * ssaa,
            "image_width": resolution * ssaa,
            "FoVx": fovx,
            "FoVy": fovy,
            "znear": near,
            "zfar": far,
            "world_view_transform": view.T.contiguous(),
            "projection_matrix": perspective.T.contiguous(),
            "full_proj_transform": (perspective @ view).T.contiguous(),
            "camera_center": camera
        })

        # Render
        render_ret = render(camera_dict, gaussian, self.pipe, self.bg_color, override_color=colors_overwrite, scaling_modifier=self.pipe.scale_modifier)

        if ssaa > 1:
            # downsample the supersampled image back to the target resolution
            render_ret.render = F.interpolate(render_ret.render[None], size=(resolution, resolution), mode='bilinear', align_corners=False, antialias=True).squeeze()

        ret = edict({
            'color': render_ret['render']
        })
        return ret
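
# Usage sketch (illustrative; extrinsics is a (4, 4) world-to-camera matrix and
# intrinsics a (3, 3) matrix with normalized focal lengths, both CUDA tensors):
#   renderer = GaussianRenderer({'resolution': 512, 'near': 0.8, 'far': 1.6, 'bg_color': (1, 1, 1)})
#   out = renderer.render(gaussian, extrinsics, intrinsics)
#   image = out.color  # (3, 512, 512)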
133
trellis/representations/gaussian/general_utils.py
Normal file
@@ -0,0 +1,133 @@
#
# Copyright (C) 2023, Inria
# GRAPHDECO research group, https://team.inria.fr/graphdeco
# All rights reserved.
#
# This software is free for non-commercial, research and evaluation use
# under the terms of the LICENSE.md file.
#
# For inquiries contact george.drettakis@inria.fr
#

import torch
import sys
from datetime import datetime
import numpy as np
import random


def inverse_sigmoid(x):
    return torch.log(x / (1 - x))


def PILtoTorch(pil_image, resolution):
    resized_image_PIL = pil_image.resize(resolution)
    resized_image = torch.from_numpy(np.array(resized_image_PIL)) / 255.0
    if len(resized_image.shape) == 3:
        return resized_image.permute(2, 0, 1)
    else:
        return resized_image.unsqueeze(dim=-1).permute(2, 0, 1)


def get_expon_lr_func(
    lr_init, lr_final, lr_delay_steps=0, lr_delay_mult=1.0, max_steps=1000000
):
    """
    Copied from Plenoxels

    Continuous learning rate decay function. Adapted from JaxNeRF
    The returned rate is lr_init when step=0 and lr_final when step=max_steps, and
    is log-linearly interpolated elsewhere (equivalent to exponential decay).
    If lr_delay_steps>0 then the learning rate will be scaled by some smooth
    function of lr_delay_mult, such that the initial learning rate is
    lr_init*lr_delay_mult at the beginning of optimization but will be eased back
    to the normal learning rate when steps>lr_delay_steps.
    :param conf: config subtree 'lr' or similar
    :param max_steps: int, the number of steps during optimization.
    :return HoF which takes step as input
    """

    def helper(step):
        if step < 0 or (lr_init == 0.0 and lr_final == 0.0):
            # Disable this parameter
            return 0.0
        if lr_delay_steps > 0:
            # A kind of reverse cosine decay.
            delay_rate = lr_delay_mult + (1 - lr_delay_mult) * np.sin(
                0.5 * np.pi * np.clip(step / lr_delay_steps, 0, 1)
            )
        else:
            delay_rate = 1.0
        t = np.clip(step / max_steps, 0, 1)
        log_lerp = np.exp(np.log(lr_init) * (1 - t) + np.log(lr_final) * t)
        return delay_rate * log_lerp

    return helper
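
# The schedule evaluated by helper(step) is, with t = clip(step / max_steps, 0, 1):
#   lr(step) = delay_rate * exp((1 - t) * log(lr_init) + t * log(lr_final))
# e.g. lr_init=1e-2, lr_final=1e-4, max_steps=1000 gives lr(500) = 1e-3.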


def strip_lowerdiag(L):
    # Pack the upper-triangular part of a batch of symmetric 3x3 matrices into 6 values.
    uncertainty = torch.zeros((L.shape[0], 6), dtype=torch.float, device="cuda")

    uncertainty[:, 0] = L[:, 0, 0]
    uncertainty[:, 1] = L[:, 0, 1]
    uncertainty[:, 2] = L[:, 0, 2]
    uncertainty[:, 3] = L[:, 1, 1]
    uncertainty[:, 4] = L[:, 1, 2]
    uncertainty[:, 5] = L[:, 2, 2]
    return uncertainty


def strip_symmetric(sym):
    return strip_lowerdiag(sym)


def build_rotation(r):
    norm = torch.sqrt(r[:, 0] * r[:, 0] + r[:, 1] * r[:, 1] + r[:, 2] * r[:, 2] + r[:, 3] * r[:, 3])

    q = r / norm[:, None]

    R = torch.zeros((q.size(0), 3, 3), device='cuda')

    r = q[:, 0]
    x = q[:, 1]
    y = q[:, 2]
    z = q[:, 3]

    R[:, 0, 0] = 1 - 2 * (y * y + z * z)
    R[:, 0, 1] = 2 * (x * y - r * z)
    R[:, 0, 2] = 2 * (x * z + r * y)
    R[:, 1, 0] = 2 * (x * y + r * z)
    R[:, 1, 1] = 1 - 2 * (x * x + z * z)
    R[:, 1, 2] = 2 * (y * z - r * x)
    R[:, 2, 0] = 2 * (x * z - r * y)
    R[:, 2, 1] = 2 * (y * z + r * x)
    R[:, 2, 2] = 1 - 2 * (x * x + y * y)
    return R
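
# Note: build_rotation expects quaternions in (w, x, y, z) order; q[:, 0] is the
# scalar part, and inputs are normalized before the matrix is assembled.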


def build_scaling_rotation(s, r):
    L = torch.zeros((s.shape[0], 3, 3), dtype=torch.float, device="cuda")
    R = build_rotation(r)

    L[:, 0, 0] = s[:, 0]
    L[:, 1, 1] = s[:, 1]
    L[:, 2, 2] = s[:, 2]

    L = R @ L
    return L


def safe_state(silent):
    # Redirect stdout so that every printed line is timestamped (or swallowed
    # when silent=True), then fix all random seeds for reproducibility.
    old_f = sys.stdout

    class F:
        def __init__(self, silent):
            self.silent = silent

        def write(self, x):
            if not self.silent:
                if x.endswith("\n"):
                    old_f.write(x.replace("\n", " [{}]\n".format(str(datetime.now().strftime("%d/%m %H:%M:%S")))))
                else:
                    old_f.write(x)

        def flush(self):
            old_f.flush()

    sys.stdout = F(silent)

    random.seed(0)
    np.random.seed(0)
    torch.manual_seed(0)
    torch.cuda.set_device(torch.device("cuda:0"))
93
trellis/trainers/flow_matching/mixins/image_conditioned.py
Normal file
@@ -0,0 +1,93 @@
from typing import *
import torch
import torch.nn.functional as F
from torchvision import transforms
import numpy as np
from PIL import Image

from ....utils import dist_utils


class ImageConditionedMixin:
    """
    Mixin for image-conditioned models.

    Args:
        image_cond_model: The image conditioning model.
    """
    def __init__(self, *args, image_cond_model: str = 'dinov2_vitl14_reg', **kwargs):
        super().__init__(*args, **kwargs)
        self.image_cond_model_name = image_cond_model
        self.image_cond_model = None  # the model is initialized lazily

    @staticmethod
    def prepare_for_training(image_cond_model: str, **kwargs):
        """
        Prepare for training.
        """
        if hasattr(super(ImageConditionedMixin, ImageConditionedMixin), 'prepare_for_training'):
            super(ImageConditionedMixin, ImageConditionedMixin).prepare_for_training(**kwargs)
        # download the model
        torch.hub.load('facebookresearch/dinov2', image_cond_model, pretrained=True)

    def _init_image_cond_model(self):
        """
        Initialize the image conditioning model.
        """
        with dist_utils.local_master_first():
            dinov2_model = torch.hub.load('facebookresearch/dinov2', self.image_cond_model_name, pretrained=True)
        dinov2_model.eval().cuda()
        transform = transforms.Compose([
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        self.image_cond_model = {
            'model': dinov2_model,
            'transform': transform,
        }

    @torch.no_grad()
    def encode_image(self, image: Union[torch.Tensor, List[Image.Image]]) -> torch.Tensor:
        """
        Encode the image.
        """
        if isinstance(image, torch.Tensor):
            assert image.ndim == 4, "Image tensor should be batched (B, C, H, W)"
        elif isinstance(image, list):
            assert all(isinstance(i, Image.Image) for i in image), "Image list should be list of PIL images"
            image = [i.resize((518, 518), Image.LANCZOS) for i in image]
            image = [np.array(i.convert('RGB')).astype(np.float32) / 255 for i in image]
            image = [torch.from_numpy(i).permute(2, 0, 1).float() for i in image]
            image = torch.stack(image).cuda()
        else:
            raise ValueError(f"Unsupported type of image: {type(image)}")

        if self.image_cond_model is None:
            self._init_image_cond_model()
        image = self.image_cond_model['transform'](image).cuda()
        features = self.image_cond_model['model'](image, is_training=True)['x_prenorm']
        patchtokens = F.layer_norm(features, features.shape[-1:])
        return patchtokens
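
    # Note: for 518x518 inputs, dinov2_vitl14_reg should produce a 37x37 grid of
    # patch tokens (1369) plus one class token and 4 register tokens (assuming
    # the standard DINOv2 ViT-L/14 configuration); 'x_prenorm' keeps them all,
    # so the conditioning sequence is 1374 tokens of width 1024.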

    def get_cond(self, cond, **kwargs):
        """
        Get the conditioning data.
        """
        cond = self.encode_image(cond)
        kwargs['neg_cond'] = torch.zeros_like(cond)
        cond = super().get_cond(cond, **kwargs)
        return cond

    def get_inference_cond(self, cond, **kwargs):
        """
        Get the conditioning data for inference.
        """
        cond = self.encode_image(cond)
        kwargs['neg_cond'] = torch.zeros_like(cond)
        cond = super().get_inference_cond(cond, **kwargs)
        return cond

    def vis_cond(self, cond, **kwargs):
        """
        Visualize the conditioning data.
        """
        return {'image': {'value': cond, 'type': 'image'}}
202
trellis/utils/general_utils.py
Normal file
@@ -0,0 +1,202 @@
import re
import numpy as np
import cv2
import torch
import contextlib


# Dictionary utils
def _dict_merge(dicta, dictb, prefix=''):
    """
    Merge two dictionaries.
    """
    assert isinstance(dicta, dict), 'input must be a dictionary'
    assert isinstance(dictb, dict), 'input must be a dictionary'
    dict_ = {}
    all_keys = set(dicta.keys()).union(set(dictb.keys()))
    for key in all_keys:
        if key in dicta.keys() and key in dictb.keys():
            if isinstance(dicta[key], dict) and isinstance(dictb[key], dict):
                dict_[key] = _dict_merge(dicta[key], dictb[key], prefix=f'{prefix}.{key}')
            else:
                raise ValueError(f'Duplicate key {prefix}.{key} found in both dictionaries. Types: {type(dicta[key])}, {type(dictb[key])}')
        elif key in dicta.keys():
            dict_[key] = dicta[key]
        else:
            dict_[key] = dictb[key]
    return dict_


def dict_merge(dicta, dictb):
    """
    Merge two dictionaries.
    """
    return _dict_merge(dicta, dictb, prefix='')


def dict_foreach(dic, func, special_func={}):
    """
    Recursively apply a function to all non-dictionary leaf values in a dictionary.
    """
    assert isinstance(dic, dict), 'input must be a dictionary'
    for key in dic.keys():
        if isinstance(dic[key], dict):
            # pass special_func through to nested dictionaries as well
            dic[key] = dict_foreach(dic[key], func, special_func)
        else:
            if key in special_func.keys():
                dic[key] = special_func[key](dic[key])
            else:
                dic[key] = func(dic[key])
    return dic


def dict_reduce(dicts, func, special_func={}):
    """
    Reduce a list of dictionaries. Leaf values must be scalars.
    """
    assert isinstance(dicts, list), 'input must be a list of dictionaries'
    assert all([isinstance(d, dict) for d in dicts]), 'input must be a list of dictionaries'
    assert len(dicts) > 0, 'input must be a non-empty list of dictionaries'
    all_keys = set([key for dict_ in dicts for key in dict_.keys()])
    reduced_dict = {}
    for key in all_keys:
        vlist = [dict_[key] for dict_ in dicts if key in dict_.keys()]
        if isinstance(vlist[0], dict):
            reduced_dict[key] = dict_reduce(vlist, func, special_func)
        else:
            if key in special_func.keys():
                reduced_dict[key] = special_func[key](vlist)
            else:
                reduced_dict[key] = func(vlist)
    return reduced_dict
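
# Example: averaging per-step logs.
#   dict_reduce([{'loss': 1.0, 'acc': 0.5}, {'loss': 3.0, 'acc': 0.7}],
#               lambda v: sum(v) / len(v))
#   == {'loss': 2.0, 'acc': 0.6}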


def dict_any(dic, func):
    """
    Return True if func is True for any non-dictionary leaf value.
    """
    assert isinstance(dic, dict), 'input must be a dictionary'
    for key in dic.keys():
        if isinstance(dic[key], dict):
            if dict_any(dic[key], func):
                return True
        else:
            if func(dic[key]):
                return True
    return False


def dict_all(dic, func):
    """
    Return True only if func is True for every non-dictionary leaf value.
    """
    assert isinstance(dic, dict), 'input must be a dictionary'
    for key in dic.keys():
        if isinstance(dic[key], dict):
            if not dict_all(dic[key], func):
                return False
        else:
            if not func(dic[key]):
                return False
    return True


def dict_flatten(dic, sep='.'):
    """
    Flatten a nested dictionary into a dictionary with no nested dictionaries.
    """
    assert isinstance(dic, dict), 'input must be a dictionary'
    flat_dict = {}
    for key in dic.keys():
        if isinstance(dic[key], dict):
            sub_dict = dict_flatten(dic[key], sep=sep)
            for sub_key in sub_dict.keys():
                flat_dict[str(key) + sep + str(sub_key)] = sub_dict[sub_key]
        else:
            flat_dict[key] = dic[key]
    return flat_dict
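
# Example: dict_flatten({'a': {'b': 1, 'c': 2}, 'd': 3}) == {'a.b': 1, 'a.c': 2, 'd': 3}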


# Context utils
@contextlib.contextmanager
def nested_contexts(*contexts):
    # Each argument is a zero-argument callable returning a context manager;
    # all of them are entered in order and exited in reverse order.
    with contextlib.ExitStack() as stack:
        for ctx in contexts:
            stack.enter_context(ctx())
        yield
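
# Example: with nested_contexts(torch.no_grad, lambda: torch.autocast('cuda')): ...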


# Image utils
def make_grid(images, nrow=None, ncol=None, aspect_ratio=None):
    num_images = len(images)
    if nrow is None and ncol is None:
        if aspect_ratio is not None:
            nrow = int(np.round(np.sqrt(num_images / aspect_ratio)))
        else:
            nrow = int(np.sqrt(num_images))
        ncol = (num_images + nrow - 1) // nrow
    elif nrow is None and ncol is not None:
        nrow = (num_images + ncol - 1) // ncol
    elif nrow is not None and ncol is None:
        ncol = (num_images + nrow - 1) // nrow
    else:
        assert nrow * ncol >= num_images, 'nrow * ncol must be greater than or equal to the number of images'

    if images[0].ndim == 2:
        grid = np.zeros((nrow * images[0].shape[0], ncol * images[0].shape[1]), dtype=images[0].dtype)
    else:
        grid = np.zeros((nrow * images[0].shape[0], ncol * images[0].shape[1], images[0].shape[2]), dtype=images[0].dtype)
    for i, img in enumerate(images):
        row = i // ncol
        col = i % ncol
        grid[row * img.shape[0]:(row + 1) * img.shape[0], col * img.shape[1]:(col + 1) * img.shape[1]] = img
    return grid
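
# Example: 10 equally-sized images with ncol=4 yield a 3x4 grid; the two unused
# cells stay zero-filled (black).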


def notes_on_image(img, notes=None):
    # Reserve a 32-pixel strip at the bottom of the image for the note text.
    img = np.pad(img, ((0, 32), (0, 0), (0, 0)), 'constant', constant_values=0)
    img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
    if notes is not None:
        img = cv2.putText(img, notes, (0, img.shape[0] - 4), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 1)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    return img


def save_image_with_notes(img, path, notes=None):
    """
    Save an image with notes.
    """
    if isinstance(img, torch.Tensor):
        img = img.cpu().numpy().transpose(1, 2, 0)
    if img.dtype == np.float32 or img.dtype == np.float64:
        img = np.clip(img * 255, 0, 255).astype(np.uint8)
    img = notes_on_image(img, notes)
    cv2.imwrite(path, cv2.cvtColor(img, cv2.COLOR_RGB2BGR))


# debug utils

def atol(x, y):
    """
    Absolute tolerance.
    """
    return torch.abs(x - y)


def rtol(x, y):
    """
    Relative tolerance.
    """
    return torch.abs(x - y) / torch.clamp_min(torch.maximum(torch.abs(x), torch.abs(y)), 1e-12)


# print utils
def indent(s, n=4):
    """
    Indent a string (every line after the first).
    """
    lines = s.split('\n')
    for i in range(1, len(lines)):
        lines[i] = ' ' * n + lines[i]
    return '\n'.join(lines)
81
trellis/utils/grad_clip_utils.py
Normal file
@@ -0,0 +1,81 @@
from typing import *
import torch
import numpy as np
import torch.utils


class AdaptiveGradClipper:
    """
    Adaptive gradient clipping for training.

    The clipping threshold tracks a percentile of recently observed gradient
    norms instead of being fixed.
    """
    def __init__(
        self,
        max_norm=None,
        clip_percentile=95.0,
        buffer_size=1000,
    ):
        self.max_norm = max_norm
        self.clip_percentile = clip_percentile
        self.buffer_size = buffer_size

        self._grad_norm = np.zeros(buffer_size, dtype=np.float32)
        self._max_norm = max_norm
        self._buffer_ptr = 0
        self._buffer_length = 0

    def __repr__(self):
        return f'AdaptiveGradClipper(max_norm={self.max_norm}, clip_percentile={self.clip_percentile})'

    def state_dict(self):
        return {
            'grad_norm': self._grad_norm,
            'max_norm': self._max_norm,
            'buffer_ptr': self._buffer_ptr,
            'buffer_length': self._buffer_length,
        }

    def load_state_dict(self, state_dict):
        self._grad_norm = state_dict['grad_norm']
        self._max_norm = state_dict['max_norm']
        self._buffer_ptr = state_dict['buffer_ptr']
        self._buffer_length = state_dict['buffer_length']

    def log(self):
        return {
            'max_norm': self._max_norm,
        }

    def __call__(self, parameters, norm_type=2.0, error_if_nonfinite=False, foreach=None):
        """Clip the gradient norm of an iterable of parameters.

        The norm is computed over all gradients together, as if they were
        concatenated into a single vector. Gradients are modified in-place.

        Args:
            parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
                single Tensor that will have gradients normalized
            norm_type (float): type of the used p-norm. Can be ``'inf'`` for
                infinity norm.
            error_if_nonfinite (bool): if True, an error is thrown if the total
                norm of the gradients from :attr:`parameters` is ``nan``,
                ``inf``, or ``-inf``. Default: False (will switch to True in the future)
            foreach (bool): use the faster foreach-based implementation.
                If ``None``, use the foreach implementation for CUDA and CPU native tensors and silently
                fall back to the slow implementation for other device types.
                Default: ``None``

        Returns:
            Total norm of the parameter gradients (viewed as a single vector).
        """
        max_norm = self._max_norm if self._max_norm is not None else float('inf')
        grad_norm = torch.nn.utils.clip_grad_norm_(parameters, max_norm=max_norm, norm_type=norm_type, error_if_nonfinite=error_if_nonfinite, foreach=foreach)

        if torch.isfinite(grad_norm):
            # Record the observed norm; once the buffer is full, set the clip
            # threshold to the configured percentile of recent norms (capped by
            # the user-provided max_norm, if any).
            self._grad_norm[self._buffer_ptr] = grad_norm
            self._buffer_ptr = (self._buffer_ptr + 1) % self.buffer_size
            self._buffer_length = min(self._buffer_length + 1, self.buffer_size)
            if self._buffer_length == self.buffer_size:
                self._max_norm = np.percentile(self._grad_norm, self.clip_percentile)
                self._max_norm = min(self._max_norm, self.max_norm) if self.max_norm is not None else self._max_norm

        return grad_norm
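
# Usage sketch (illustrative training loop):
#   clipper = AdaptiveGradClipper(max_norm=10.0, clip_percentile=95.0)
#   for batch in loader:
#       loss = model(batch).mean()
#       loss.backward()
#       clipper(model.parameters())  # clips to the running 95th-percentile norm
#       optimizer.step()
#       optimizer.zero_grad()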