1
This commit is contained in:
52
dataset_toolkits/download.py
Normal file
52
dataset_toolkits/download.py
Normal file
@@ -0,0 +1,52 @@
|
||||
import os
|
||||
import copy
|
||||
import sys
|
||||
import importlib
|
||||
import argparse
|
||||
import pandas as pd
|
||||
from easydict import EasyDict as edict
|
||||
|
||||
if __name__ == '__main__':
|
||||
dataset_utils = importlib.import_module(f'datasets.{sys.argv[1]}')
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--output_dir', type=str, required=True,
|
||||
help='Directory to save the metadata')
|
||||
parser.add_argument('--filter_low_aesthetic_score', type=float, default=None,
|
||||
help='Filter objects with aesthetic score lower than this value')
|
||||
parser.add_argument('--instances', type=str, default=None,
|
||||
help='Instances to process')
|
||||
dataset_utils.add_args(parser)
|
||||
parser.add_argument('--rank', type=int, default=0)
|
||||
parser.add_argument('--world_size', type=int, default=1)
|
||||
opt = parser.parse_args(sys.argv[2:])
|
||||
opt = edict(vars(opt))
|
||||
|
||||
os.makedirs(opt.output_dir, exist_ok=True)
|
||||
|
||||
# get file list
|
||||
if not os.path.exists(os.path.join(opt.output_dir, 'metadata.csv')):
|
||||
raise ValueError('metadata.csv not found')
|
||||
metadata = pd.read_csv(os.path.join(opt.output_dir, 'metadata.csv'))
|
||||
if opt.instances is None:
|
||||
if opt.filter_low_aesthetic_score is not None:
|
||||
metadata = metadata[metadata['aesthetic_score'] >= opt.filter_low_aesthetic_score]
|
||||
if 'local_path' in metadata.columns:
|
||||
metadata = metadata[metadata['local_path'].isna()]
|
||||
else:
|
||||
if os.path.exists(opt.instances):
|
||||
with open(opt.instances, 'r') as f:
|
||||
instances = f.read().splitlines()
|
||||
else:
|
||||
instances = opt.instances.split(',')
|
||||
metadata = metadata[metadata['sha256'].isin(instances)]
|
||||
|
||||
start = len(metadata) * opt.rank // opt.world_size
|
||||
end = len(metadata) * (opt.rank + 1) // opt.world_size
|
||||
metadata = metadata[start:end]
|
||||
|
||||
print(f'Processing {len(metadata)} objects...')
|
||||
|
||||
# process objects
|
||||
downloaded = dataset_utils.download(metadata, **opt)
|
||||
downloaded.to_csv(os.path.join(opt.output_dir, f'downloaded_{opt.rank}.csv'), index=False)
|
||||
127
dataset_toolkits/encode_latent.py
Normal file
127
dataset_toolkits/encode_latent.py
Normal file
@@ -0,0 +1,127 @@
|
||||
import os
|
||||
import sys
|
||||
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
|
||||
import copy
|
||||
import json
|
||||
import argparse
|
||||
import torch
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from tqdm import tqdm
|
||||
from easydict import EasyDict as edict
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from queue import Queue
|
||||
|
||||
import trellis.models as models
|
||||
import trellis.modules.sparse as sp
|
||||
|
||||
|
||||
torch.set_grad_enabled(False)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--output_dir', type=str, required=True,
|
||||
help='Directory to save the metadata')
|
||||
parser.add_argument('--filter_low_aesthetic_score', type=float, default=None,
|
||||
help='Filter objects with aesthetic score lower than this value')
|
||||
parser.add_argument('--feat_model', type=str, default='dinov2_vitl14_reg',
|
||||
help='Feature model')
|
||||
parser.add_argument('--enc_pretrained', type=str, default='microsoft/TRELLIS-image-large/ckpts/slat_enc_swin8_B_64l8_fp16',
|
||||
help='Pretrained encoder model')
|
||||
parser.add_argument('--model_root', type=str, default='results',
|
||||
help='Root directory of models')
|
||||
parser.add_argument('--enc_model', type=str, default=None,
|
||||
help='Encoder model. if specified, use this model instead of pretrained model')
|
||||
parser.add_argument('--ckpt', type=str, default=None,
|
||||
help='Checkpoint to load')
|
||||
parser.add_argument('--instances', type=str, default=None,
|
||||
help='Instances to process')
|
||||
parser.add_argument('--rank', type=int, default=0)
|
||||
parser.add_argument('--world_size', type=int, default=1)
|
||||
opt = parser.parse_args()
|
||||
opt = edict(vars(opt))
|
||||
|
||||
if opt.enc_model is None:
|
||||
latent_name = f'{opt.feat_model}_{opt.enc_pretrained.split("/")[-1]}'
|
||||
encoder = models.from_pretrained(opt.enc_pretrained).eval().cuda()
|
||||
else:
|
||||
latent_name = f'{opt.feat_model}_{opt.enc_model}_{opt.ckpt}'
|
||||
cfg = edict(json.load(open(os.path.join(opt.model_root, opt.enc_model, 'config.json'), 'r')))
|
||||
encoder = getattr(models, cfg.models.encoder.name)(**cfg.models.encoder.args).cuda()
|
||||
ckpt_path = os.path.join(opt.model_root, opt.enc_model, 'ckpts', f'encoder_{opt.ckpt}.pt')
|
||||
encoder.load_state_dict(torch.load(ckpt_path), strict=False)
|
||||
encoder.eval()
|
||||
print(f'Loaded model from {ckpt_path}')
|
||||
|
||||
os.makedirs(os.path.join(opt.output_dir, 'latents', latent_name), exist_ok=True)
|
||||
|
||||
# get file list
|
||||
if os.path.exists(os.path.join(opt.output_dir, 'metadata.csv')):
|
||||
metadata = pd.read_csv(os.path.join(opt.output_dir, 'metadata.csv'))
|
||||
else:
|
||||
raise ValueError('metadata.csv not found')
|
||||
if opt.instances is not None:
|
||||
with open(opt.instances, 'r') as f:
|
||||
sha256s = [line.strip() for line in f]
|
||||
metadata = metadata[metadata['sha256'].isin(sha256s)]
|
||||
else:
|
||||
if opt.filter_low_aesthetic_score is not None:
|
||||
metadata = metadata[metadata['aesthetic_score'] >= opt.filter_low_aesthetic_score]
|
||||
metadata = metadata[metadata[f'feature_{opt.feat_model}'] == True]
|
||||
if f'latent_{latent_name}' in metadata.columns:
|
||||
metadata = metadata[metadata[f'latent_{latent_name}'] == False]
|
||||
|
||||
start = len(metadata) * opt.rank // opt.world_size
|
||||
end = len(metadata) * (opt.rank + 1) // opt.world_size
|
||||
metadata = metadata[start:end]
|
||||
records = []
|
||||
|
||||
# filter out objects that are already processed
|
||||
sha256s = list(metadata['sha256'].values)
|
||||
for sha256 in copy.copy(sha256s):
|
||||
if os.path.exists(os.path.join(opt.output_dir, 'latents', latent_name, f'{sha256}.npz')):
|
||||
records.append({'sha256': sha256, f'latent_{latent_name}': True})
|
||||
sha256s.remove(sha256)
|
||||
|
||||
# encode latents
|
||||
load_queue = Queue(maxsize=4)
|
||||
try:
|
||||
with ThreadPoolExecutor(max_workers=32) as loader_executor, \
|
||||
ThreadPoolExecutor(max_workers=32) as saver_executor:
|
||||
def loader(sha256):
|
||||
try:
|
||||
feats = np.load(os.path.join(opt.output_dir, 'features', opt.feat_model, f'{sha256}.npz'))
|
||||
load_queue.put((sha256, feats))
|
||||
except Exception as e:
|
||||
print(f"Error loading features for {sha256}: {e}")
|
||||
loader_executor.map(loader, sha256s)
|
||||
|
||||
def saver(sha256, pack):
|
||||
save_path = os.path.join(opt.output_dir, 'latents', latent_name, f'{sha256}.npz')
|
||||
np.savez_compressed(save_path, **pack)
|
||||
records.append({'sha256': sha256, f'latent_{latent_name}': True})
|
||||
|
||||
for _ in tqdm(range(len(sha256s)), desc="Extracting latents"):
|
||||
sha256, feats = load_queue.get()
|
||||
feats = sp.SparseTensor(
|
||||
feats = torch.from_numpy(feats['patchtokens']).float(),
|
||||
coords = torch.cat([
|
||||
torch.zeros(feats['patchtokens'].shape[0], 1).int(),
|
||||
torch.from_numpy(feats['indices']).int(),
|
||||
], dim=1),
|
||||
).cuda()
|
||||
latent = encoder(feats, sample_posterior=False)
|
||||
assert torch.isfinite(latent.feats).all(), "Non-finite latent"
|
||||
pack = {
|
||||
'feats': latent.feats.cpu().numpy().astype(np.float32),
|
||||
'coords': latent.coords[:, 1:].cpu().numpy().astype(np.uint8),
|
||||
}
|
||||
saver_executor.submit(saver, sha256, pack)
|
||||
|
||||
saver_executor.shutdown(wait=True)
|
||||
except:
|
||||
print("Error happened during processing.")
|
||||
|
||||
records = pd.DataFrame.from_records(records)
|
||||
records.to_csv(os.path.join(opt.output_dir, f'latent_{latent_name}_{opt.rank}.csv'), index=False)
|
||||
128
dataset_toolkits/encode_ss_latent.py
Normal file
128
dataset_toolkits/encode_ss_latent.py
Normal file
@@ -0,0 +1,128 @@
|
||||
import os
|
||||
import sys
|
||||
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
|
||||
import copy
|
||||
import json
|
||||
import argparse
|
||||
import torch
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import utils3d
|
||||
from tqdm import tqdm
|
||||
from easydict import EasyDict as edict
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from queue import Queue
|
||||
|
||||
import trellis.models as models
|
||||
|
||||
|
||||
torch.set_grad_enabled(False)
|
||||
|
||||
|
||||
def get_voxels(instance):
|
||||
position = utils3d.io.read_ply(os.path.join(opt.output_dir, 'voxels', f'{instance}.ply'))[0]
|
||||
coords = ((torch.tensor(position) + 0.5) * opt.resolution).int().contiguous()
|
||||
ss = torch.zeros(1, opt.resolution, opt.resolution, opt.resolution, dtype=torch.long)
|
||||
ss[:, coords[:, 0], coords[:, 1], coords[:, 2]] = 1
|
||||
return ss
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--output_dir', type=str, required=True,
|
||||
help='Directory to save the metadata')
|
||||
parser.add_argument('--filter_low_aesthetic_score', type=float, default=None,
|
||||
help='Filter objects with aesthetic score lower than this value')
|
||||
parser.add_argument('--enc_pretrained', type=str, default='microsoft/TRELLIS-image-large/ckpts/ss_enc_conv3d_16l8_fp16',
|
||||
help='Pretrained encoder model')
|
||||
parser.add_argument('--model_root', type=str, default='results',
|
||||
help='Root directory of models')
|
||||
parser.add_argument('--enc_model', type=str, default=None,
|
||||
help='Encoder model. if specified, use this model instead of pretrained model')
|
||||
parser.add_argument('--ckpt', type=str, default=None,
|
||||
help='Checkpoint to load')
|
||||
parser.add_argument('--resolution', type=int, default=64,
|
||||
help='Resolution')
|
||||
parser.add_argument('--instances', type=str, default=None,
|
||||
help='Instances to process')
|
||||
parser.add_argument('--rank', type=int, default=0)
|
||||
parser.add_argument('--world_size', type=int, default=1)
|
||||
opt = parser.parse_args()
|
||||
opt = edict(vars(opt))
|
||||
|
||||
if opt.enc_model is None:
|
||||
latent_name = f'{opt.enc_pretrained.split("/")[-1]}'
|
||||
encoder = models.from_pretrained(opt.enc_pretrained).eval().cuda()
|
||||
else:
|
||||
latent_name = f'{opt.enc_model}_{opt.ckpt}'
|
||||
cfg = edict(json.load(open(os.path.join(opt.model_root, opt.enc_model, 'config.json'), 'r')))
|
||||
encoder = getattr(models, cfg.models.encoder.name)(**cfg.models.encoder.args).cuda()
|
||||
ckpt_path = os.path.join(opt.model_root, opt.enc_model, 'ckpts', f'encoder_{opt.ckpt}.pt')
|
||||
encoder.load_state_dict(torch.load(ckpt_path), strict=False)
|
||||
encoder.eval()
|
||||
print(f'Loaded model from {ckpt_path}')
|
||||
|
||||
os.makedirs(os.path.join(opt.output_dir, 'ss_latents', latent_name), exist_ok=True)
|
||||
|
||||
# get file list
|
||||
if os.path.exists(os.path.join(opt.output_dir, 'metadata.csv')):
|
||||
metadata = pd.read_csv(os.path.join(opt.output_dir, 'metadata.csv'))
|
||||
else:
|
||||
raise ValueError('metadata.csv not found')
|
||||
if opt.instances is not None:
|
||||
with open(opt.instances, 'r') as f:
|
||||
instances = f.read().splitlines()
|
||||
metadata = metadata[metadata['sha256'].isin(instances)]
|
||||
else:
|
||||
if opt.filter_low_aesthetic_score is not None:
|
||||
metadata = metadata[metadata['aesthetic_score'] >= opt.filter_low_aesthetic_score]
|
||||
metadata = metadata[metadata['voxelized'] == True]
|
||||
if f'ss_latent_{latent_name}' in metadata.columns:
|
||||
metadata = metadata[metadata[f'ss_latent_{latent_name}'] == False]
|
||||
|
||||
start = len(metadata) * opt.rank // opt.world_size
|
||||
end = len(metadata) * (opt.rank + 1) // opt.world_size
|
||||
metadata = metadata[start:end]
|
||||
records = []
|
||||
|
||||
# filter out objects that are already processed
|
||||
sha256s = list(metadata['sha256'].values)
|
||||
for sha256 in copy.copy(sha256s):
|
||||
if os.path.exists(os.path.join(opt.output_dir, 'ss_latents', latent_name, f'{sha256}.npz')):
|
||||
records.append({'sha256': sha256, f'ss_latent_{latent_name}': True})
|
||||
sha256s.remove(sha256)
|
||||
|
||||
# encode latents
|
||||
load_queue = Queue(maxsize=4)
|
||||
try:
|
||||
with ThreadPoolExecutor(max_workers=32) as loader_executor, \
|
||||
ThreadPoolExecutor(max_workers=32) as saver_executor:
|
||||
def loader(sha256):
|
||||
try:
|
||||
ss = get_voxels(sha256)[None].float()
|
||||
load_queue.put((sha256, ss))
|
||||
except Exception as e:
|
||||
print(f"Error loading features for {sha256}: {e}")
|
||||
loader_executor.map(loader, sha256s)
|
||||
|
||||
def saver(sha256, pack):
|
||||
save_path = os.path.join(opt.output_dir, 'ss_latents', latent_name, f'{sha256}.npz')
|
||||
np.savez_compressed(save_path, **pack)
|
||||
records.append({'sha256': sha256, f'ss_latent_{latent_name}': True})
|
||||
|
||||
for _ in tqdm(range(len(sha256s)), desc="Extracting latents"):
|
||||
sha256, ss = load_queue.get()
|
||||
ss = ss.cuda().float()
|
||||
latent = encoder(ss, sample_posterior=False)
|
||||
assert torch.isfinite(latent).all(), "Non-finite latent"
|
||||
pack = {
|
||||
'mean': latent[0].cpu().numpy(),
|
||||
}
|
||||
saver_executor.submit(saver, sha256, pack)
|
||||
|
||||
saver_executor.shutdown(wait=True)
|
||||
except:
|
||||
print("Error happened during processing.")
|
||||
|
||||
records = pd.DataFrame.from_records(records)
|
||||
records.to_csv(os.path.join(opt.output_dir, f'ss_latent_{latent_name}_{opt.rank}.csv'), index=False)
|
||||
179
dataset_toolkits/extract_feature.py
Normal file
179
dataset_toolkits/extract_feature.py
Normal file
@@ -0,0 +1,179 @@
|
||||
import os
|
||||
import copy
|
||||
import sys
|
||||
import json
|
||||
import importlib
|
||||
import argparse
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import utils3d
|
||||
from tqdm import tqdm
|
||||
from easydict import EasyDict as edict
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from queue import Queue
|
||||
from torchvision import transforms
|
||||
from PIL import Image
|
||||
|
||||
|
||||
torch.set_grad_enabled(False)
|
||||
|
||||
|
||||
def get_data(frames, sha256):
|
||||
with ThreadPoolExecutor(max_workers=16) as executor:
|
||||
def worker(view):
|
||||
image_path = os.path.join(opt.output_dir, 'renders', sha256, view['file_path'])
|
||||
try:
|
||||
image = Image.open(image_path)
|
||||
except:
|
||||
print(f"Error loading image {image_path}")
|
||||
return None
|
||||
image = image.resize((518, 518), Image.Resampling.LANCZOS)
|
||||
image = np.array(image).astype(np.float32) / 255
|
||||
image = image[:, :, :3] * image[:, :, 3:]
|
||||
image = torch.from_numpy(image).permute(2, 0, 1).float()
|
||||
|
||||
c2w = torch.tensor(view['transform_matrix'])
|
||||
c2w[:3, 1:3] *= -1
|
||||
extrinsics = torch.inverse(c2w)
|
||||
fov = view['camera_angle_x']
|
||||
intrinsics = utils3d.torch.intrinsics_from_fov_xy(torch.tensor(fov), torch.tensor(fov))
|
||||
|
||||
return {
|
||||
'image': image,
|
||||
'extrinsics': extrinsics,
|
||||
'intrinsics': intrinsics
|
||||
}
|
||||
|
||||
datas = executor.map(worker, frames)
|
||||
for data in datas:
|
||||
if data is not None:
|
||||
yield data
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--output_dir', type=str, required=True,
|
||||
help='Directory to save the metadata')
|
||||
parser.add_argument('--filter_low_aesthetic_score', type=float, default=None,
|
||||
help='Filter objects with aesthetic score lower than this value')
|
||||
parser.add_argument('--model', type=str, default='dinov2_vitl14_reg',
|
||||
help='Feature extraction model')
|
||||
parser.add_argument('--instances', type=str, default=None,
|
||||
help='Instances to process')
|
||||
parser.add_argument('--batch_size', type=int, default=16)
|
||||
parser.add_argument('--rank', type=int, default=0)
|
||||
parser.add_argument('--world_size', type=int, default=1)
|
||||
opt = parser.parse_args()
|
||||
opt = edict(vars(opt))
|
||||
|
||||
feature_name = opt.model
|
||||
os.makedirs(os.path.join(opt.output_dir, 'features', feature_name), exist_ok=True)
|
||||
|
||||
# load model
|
||||
dinov2_model = torch.hub.load('facebookresearch/dinov2', opt.model)
|
||||
dinov2_model.eval().cuda()
|
||||
transform = transforms.Compose([
|
||||
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
||||
])
|
||||
n_patch = 518 // 14
|
||||
|
||||
# get file list
|
||||
if os.path.exists(os.path.join(opt.output_dir, 'metadata.csv')):
|
||||
metadata = pd.read_csv(os.path.join(opt.output_dir, 'metadata.csv'))
|
||||
else:
|
||||
raise ValueError('metadata.csv not found')
|
||||
if opt.instances is not None:
|
||||
with open(opt.instances, 'r') as f:
|
||||
instances = f.read().splitlines()
|
||||
metadata = metadata[metadata['sha256'].isin(instances)]
|
||||
else:
|
||||
if opt.filter_low_aesthetic_score is not None:
|
||||
metadata = metadata[metadata['aesthetic_score'] >= opt.filter_low_aesthetic_score]
|
||||
if f'feature_{feature_name}' in metadata.columns:
|
||||
metadata = metadata[metadata[f'feature_{feature_name}'] == False]
|
||||
metadata = metadata[metadata['voxelized'] == True]
|
||||
metadata = metadata[metadata['rendered'] == True]
|
||||
|
||||
start = len(metadata) * opt.rank // opt.world_size
|
||||
end = len(metadata) * (opt.rank + 1) // opt.world_size
|
||||
metadata = metadata[start:end]
|
||||
records = []
|
||||
|
||||
# filter out objects that are already processed
|
||||
sha256s = list(metadata['sha256'].values)
|
||||
for sha256 in copy.copy(sha256s):
|
||||
if os.path.exists(os.path.join(opt.output_dir, 'features', feature_name, f'{sha256}.npz')):
|
||||
records.append({'sha256': sha256, f'feature_{feature_name}' : True})
|
||||
sha256s.remove(sha256)
|
||||
|
||||
# extract features
|
||||
load_queue = Queue(maxsize=4)
|
||||
try:
|
||||
with ThreadPoolExecutor(max_workers=8) as loader_executor, \
|
||||
ThreadPoolExecutor(max_workers=8) as saver_executor:
|
||||
def loader(sha256):
|
||||
try:
|
||||
with open(os.path.join(opt.output_dir, 'renders', sha256, 'transforms.json'), 'r') as f:
|
||||
metadata = json.load(f)
|
||||
frames = metadata['frames']
|
||||
data = []
|
||||
for datum in get_data(frames, sha256):
|
||||
datum['image'] = transform(datum['image'])
|
||||
data.append(datum)
|
||||
positions = utils3d.io.read_ply(os.path.join(opt.output_dir, 'voxels', f'{sha256}.ply'))[0]
|
||||
load_queue.put((sha256, data, positions))
|
||||
except Exception as e:
|
||||
print(f"Error loading data for {sha256}: {e}")
|
||||
|
||||
loader_executor.map(loader, sha256s)
|
||||
|
||||
def saver(sha256, pack, patchtokens, uv):
|
||||
pack['patchtokens'] = F.grid_sample(
|
||||
patchtokens,
|
||||
uv.unsqueeze(1),
|
||||
mode='bilinear',
|
||||
align_corners=False,
|
||||
).squeeze(2).permute(0, 2, 1).cpu().numpy()
|
||||
pack['patchtokens'] = np.mean(pack['patchtokens'], axis=0).astype(np.float16)
|
||||
save_path = os.path.join(opt.output_dir, 'features', feature_name, f'{sha256}.npz')
|
||||
np.savez_compressed(save_path, **pack)
|
||||
records.append({'sha256': sha256, f'feature_{feature_name}' : True})
|
||||
|
||||
for _ in tqdm(range(len(sha256s)), desc="Extracting features"):
|
||||
sha256, data, positions = load_queue.get()
|
||||
positions = torch.from_numpy(positions).float().cuda()
|
||||
indices = ((positions + 0.5) * 64).long()
|
||||
assert torch.all(indices >= 0) and torch.all(indices < 64), "Some vertices are out of bounds"
|
||||
n_views = len(data)
|
||||
N = positions.shape[0]
|
||||
pack = {
|
||||
'indices': indices.cpu().numpy().astype(np.uint8),
|
||||
}
|
||||
patchtokens_lst = []
|
||||
uv_lst = []
|
||||
for i in range(0, n_views, opt.batch_size):
|
||||
batch_data = data[i:i+opt.batch_size]
|
||||
bs = len(batch_data)
|
||||
batch_images = torch.stack([d['image'] for d in batch_data]).cuda()
|
||||
batch_extrinsics = torch.stack([d['extrinsics'] for d in batch_data]).cuda()
|
||||
batch_intrinsics = torch.stack([d['intrinsics'] for d in batch_data]).cuda()
|
||||
features = dinov2_model(batch_images, is_training=True)
|
||||
uv = utils3d.torch.project_cv(positions, batch_extrinsics, batch_intrinsics)[0] * 2 - 1
|
||||
patchtokens = features['x_prenorm'][:, dinov2_model.num_register_tokens + 1:].permute(0, 2, 1).reshape(bs, 1024, n_patch, n_patch)
|
||||
patchtokens_lst.append(patchtokens)
|
||||
uv_lst.append(uv)
|
||||
patchtokens = torch.cat(patchtokens_lst, dim=0)
|
||||
uv = torch.cat(uv_lst, dim=0)
|
||||
|
||||
# save features
|
||||
saver_executor.submit(saver, sha256, pack, patchtokens, uv)
|
||||
|
||||
saver_executor.shutdown(wait=True)
|
||||
except:
|
||||
print("Error happened during processing.")
|
||||
|
||||
records = pd.DataFrame.from_records(records)
|
||||
records.to_csv(os.path.join(opt.output_dir, f'feature_{feature_name}_{opt.rank}.csv'), index=False)
|
||||
|
||||
Reference in New Issue
Block a user