1
This commit is contained in:
28
trellis/representations/radiance_field/strivec.py
Normal file
28
trellis/representations/radiance_field/strivec.py
Normal file
@@ -0,0 +1,28 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
from ..octree import DfsOctree as Octree
|
||||
|
||||
|
||||
class Strivec(Octree):
|
||||
def __init__(
|
||||
self,
|
||||
resolution: int,
|
||||
aabb: list,
|
||||
sh_degree: int = 0,
|
||||
rank: int = 8,
|
||||
dim: int = 8,
|
||||
device: str = "cuda",
|
||||
):
|
||||
assert np.log2(resolution) % 1 == 0, "Resolution must be a power of 2"
|
||||
self.resolution = resolution
|
||||
depth = int(np.round(np.log2(resolution)))
|
||||
super().__init__(
|
||||
depth=depth,
|
||||
aabb=aabb,
|
||||
sh_degree=sh_degree,
|
||||
primitive="trivec",
|
||||
primitive_config={"rank": rank, "dim": dim},
|
||||
device=device,
|
||||
)
|
||||
Reference in New Issue
Block a user