1
This commit is contained in:
285
dataset_toolkits/build_metadata.py
Normal file
285
dataset_toolkits/build_metadata.py
Normal file
@@ -0,0 +1,285 @@
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
import time
|
||||
import importlib
|
||||
import argparse
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from tqdm import tqdm
|
||||
from easydict import EasyDict as edict
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
import utils3d
|
||||
|
||||
def get_first_directory(path):
    """Return the name of the first sub-directory found directly in ``path``.

    Entries are examined in the order ``os.scandir`` yields them, so which
    directory counts as "first" is up to the operating system.

    Returns:
        The directory's name (not its full path), or ``None`` when ``path``
        contains no directory at all.
    """
    with os.scandir(path) as entries:
        return next((item.name for item in entries if item.is_dir()), None)
|
||||
|
||||
def need_process(key):
    """Tell whether the metadata field ``key`` was selected for processing.

    Reads the module-level ``opt.field`` list. The one-element list
    ``['all']`` selects every field; otherwise ``key`` must be listed.
    """
    requested = opt.field
    if requested == ['all']:
        return True
    return key in requested
|
||||
|
||||
if __name__ == '__main__':
    # Usage: build_metadata.py <dataset> --output_dir DIR [--field a,b] [--from_file]
    # The first positional argument names the module under `datasets/` that
    # supplies `add_args` and `get_metadata`; the rest is parsed by argparse.
    dataset_utils = importlib.import_module(f'datasets.{sys.argv[1]}')

    parser = argparse.ArgumentParser()
    parser.add_argument('--output_dir', type=str, required=True,
                        help='Directory to save the metadata')
    parser.add_argument('--field', type=str, default='all',
                        help='Fields to process, separated by commas')
    parser.add_argument('--from_file', action='store_true',
                        help='Build metadata from file instead of from records of processings.' +
                             'Useful when some processing fail to generate records but file already exists.')
    dataset_utils.add_args(parser)
    opt = parser.parse_args(sys.argv[2:])
    opt = edict(vars(opt))

    os.makedirs(opt.output_dir, exist_ok=True)
    os.makedirs(os.path.join(opt.output_dir, 'merged_records'), exist_ok=True)

    opt.field = opt.field.split(',')

    # Prefix given to every record file archived by this run.
    timestamp = str(int(time.time()))

    def _read_records(prefix):
        """Read every '<prefix>*.csv' record file found in the output directory.

        Returns ``(files, frame)``: ``files`` lists the names that were read
        successfully and ``frame`` is their concatenation indexed by
        ``sha256`` (``None`` when nothing could be read). A file that cannot
        be parsed is reported and left out of ``files`` so that it is not
        archived as if it had been merged.
        """
        candidates = [f for f in os.listdir(opt.output_dir)
                      if f.startswith(prefix) and f.endswith('.csv')]
        files, parts = [], []
        for f in candidates:
            try:
                parts.append(pd.read_csv(os.path.join(opt.output_dir, f)))
            except Exception as e:
                print(f'Warning: could not read record file {f}, leaving it in place: {e}')
            else:
                files.append(f)
        if not parts:
            return files, None
        frame = pd.concat(parts)
        frame.set_index('sha256', inplace=True)
        return files, frame

    def _archive_records(files):
        """Move merged record files into 'merged_records', prefixed with this run's timestamp."""
        for f in files:
            shutil.move(os.path.join(opt.output_dir, f),
                        os.path.join(opt.output_dir, 'merged_records', f'{timestamp}_{f}'))

    def _list_models(subdir):
        """Return the entries of '<output_dir>/<subdir>', or an empty list if it does not exist."""
        path = os.path.join(opt.output_dir, subdir)
        return os.listdir(path) if os.path.exists(path) else []

    # get file list
    if os.path.exists(os.path.join(opt.output_dir, 'metadata.csv')):
        print('Loading previous metadata...')
        metadata = pd.read_csv(os.path.join(opt.output_dir, 'metadata.csv'))
    else:
        metadata = dataset_utils.get_metadata(**opt)
    metadata.set_index('sha256', inplace=True)

    # merge downloaded
    df_files, df = _read_records('downloaded_')
    if df is not None:
        if 'local_path' in metadata.columns:
            metadata.update(df, overwrite=True)
        else:
            # First download records: the table has no `local_path` column yet.
            metadata = metadata.join(df, on='sha256', how='left')
    _archive_records(df_files)

    # detect models
    image_models = _list_models('features')
    latent_models = _list_models('latents')
    ss_latent_models = _list_models('ss_latents')
    print(f'Image models: {image_models}')
    print(f'Latent models: {latent_models}')
    print(f'Sparse Structure latent models: {ss_latent_models}')

    # Make sure every status column exists before records are merged into it.
    default_columns = {
        'rendered': False,
        'voxelized': False,
        'num_voxels': 0,
        'cond_rendered': False,
    }
    default_columns.update({f'feature_{model}': False for model in image_models})
    default_columns.update({f'latent_{model}': False for model in latent_models})
    default_columns.update({f'ss_latent_{model}': False for model in ss_latent_models})
    for column, default in default_columns.items():
        if column not in metadata.columns:
            metadata[column] = [default] * len(metadata)

    # merge rendered, aesthetic scores, voxelized, cond_rendered, then the
    # per-model feature / latent / sparse structure latent records
    record_prefixes = ['rendered_', 'aesthetic_scores_', 'voxelized_', 'cond_rendered_']
    record_prefixes += [f'feature_{model}_' for model in image_models]
    record_prefixes += [f'latent_{model}_' for model in latent_models]
    record_prefixes += [f'ss_latent_{model}_' for model in ss_latent_models]
    for prefix in record_prefixes:
        df_files, df = _read_records(prefix)
        if df is not None:
            metadata.update(df, overwrite=True)
        _archive_records(df_files)

    # build metadata from files
    if opt.from_file:
        with ThreadPoolExecutor(max_workers=os.cpu_count()) as executor, \
            tqdm(total=len(metadata), desc="Building metadata") as pbar:
            # NOTE(review): workers write to the shared `metadata` frame from
            # several threads at once; confirm this is safe for the pandas
            # version in use.
            def worker(sha256):
                """Flip the status flags of one object according to the files present on disk."""
                try:
                    if need_process('rendered') and metadata.loc[sha256, 'rendered'] == False and \
                        os.path.exists(os.path.join(opt.output_dir, 'renders', sha256, 'transforms.json')):
                        metadata.loc[sha256, 'rendered'] = True
                    if need_process('voxelized') and metadata.loc[sha256, 'rendered'] == True and metadata.loc[sha256, 'voxelized'] == False and \
                        os.path.exists(os.path.join(opt.output_dir, 'voxels', f'{sha256}.ply')):
                        try:
                            pts = utils3d.io.read_ply(os.path.join(opt.output_dir, 'voxels', f'{sha256}.ply'))[0]
                            metadata.loc[sha256, 'voxelized'] = True
                            metadata.loc[sha256, 'num_voxels'] = len(pts)
                        except Exception as e:
                            # An unreadable voxel file leaves the object marked as not voxelized.
                            print(f'Warning: could not read voxels of {sha256}: {e}')
                    if need_process('cond_rendered') and metadata.loc[sha256, 'cond_rendered'] == False and \
                        os.path.exists(os.path.join(opt.output_dir, 'renders_cond', sha256, 'transforms.json')):
                        metadata.loc[sha256, 'cond_rendered'] = True
                    for model in image_models:
                        if need_process(f'feature_{model}') and \
                            metadata.loc[sha256, f'feature_{model}'] == False and \
                            metadata.loc[sha256, 'rendered'] == True and \
                            metadata.loc[sha256, 'voxelized'] == True and \
                            os.path.exists(os.path.join(opt.output_dir, 'features', model, f'{sha256}.npz')):
                            metadata.loc[sha256, f'feature_{model}'] = True
                    for model in latent_models:
                        if need_process(f'latent_{model}') and \
                            metadata.loc[sha256, f'latent_{model}'] == False and \
                            metadata.loc[sha256, 'rendered'] == True and \
                            metadata.loc[sha256, 'voxelized'] == True and \
                            os.path.exists(os.path.join(opt.output_dir, 'latents', model, f'{sha256}.npz')):
                            metadata.loc[sha256, f'latent_{model}'] = True
                    for model in ss_latent_models:
                        if need_process(f'ss_latent_{model}') and \
                            metadata.loc[sha256, f'ss_latent_{model}'] == False and \
                            metadata.loc[sha256, 'voxelized'] == True and \
                            os.path.exists(os.path.join(opt.output_dir, 'ss_latents', model, f'{sha256}.npz')):
                            metadata.loc[sha256, f'ss_latent_{model}'] = True
                except Exception as e:
                    print(f'Error processing {sha256}: {e}')
                finally:
                    pbar.update()

            executor.map(worker, metadata.index)
            executor.shutdown(wait=True)

    # statistics
    metadata.to_csv(os.path.join(opt.output_dir, 'metadata.csv'))
    num_downloaded = metadata['local_path'].count() if 'local_path' in metadata.columns else 0
    # Guarded like `local_path`: a metadata table without captions must not
    # abort the run after metadata.csv has already been written.
    num_captioned = metadata['captions'].count() if 'captions' in metadata.columns else 0
    with open(os.path.join(opt.output_dir, 'statistics.txt'), 'w') as f:
        f.write('Statistics:\n')
        f.write(f'  - Number of assets: {len(metadata)}\n')
        f.write(f'  - Number of assets downloaded: {num_downloaded}\n')
        f.write(f'  - Number of assets rendered: {metadata["rendered"].sum()}\n')
        f.write(f'  - Number of assets voxelized: {metadata["voxelized"].sum()}\n')
        if len(image_models) != 0:
            f.write(f'  - Number of assets with image features extracted:\n')
            for model in image_models:
                f.write(f'    - {model}: {metadata[f"feature_{model}"].sum()}\n')
        if len(latent_models) != 0:
            f.write(f'  - Number of assets with latents extracted:\n')
            for model in latent_models:
                f.write(f'    - {model}: {metadata[f"latent_{model}"].sum()}\n')
        if len(ss_latent_models) != 0:
            f.write(f'  - Number of assets with sparse structure latents extracted:\n')
            for model in ss_latent_models:
                f.write(f'    - {model}: {metadata[f"ss_latent_{model}"].sum()}\n')
        f.write(f'  - Number of assets with captions: {num_captioned}\n')
        f.write(f'  - Number of assets with image conditions: {metadata["cond_rendered"].sum()}\n')

    # Echo the report that was just written.
    with open(os.path.join(opt.output_dir, 'statistics.txt'), 'r') as f:
        print(f.read())
|
||||
102
dataset_toolkits/calculate_aesthetic_scores.py
Normal file
102
dataset_toolkits/calculate_aesthetic_scores.py
Normal file
@@ -0,0 +1,102 @@
|
||||
import os
|
||||
import argparse
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from PIL import Image
|
||||
import open_clip
|
||||
from os.path import expanduser
|
||||
from urllib.request import urlretrieve
|
||||
import pandas as pd
|
||||
from tqdm import tqdm
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from queue import Queue
|
||||
|
||||
|
||||
def get_aesthetic_model(clip_model="vit_l_14"):
    """Load the linear aesthetic predictor that matches a CLIP embedding size.

    The weights are cached as ``~/.cache/emb_reader/sa_0_4_<clip_model>_linear.pth``
    and fetched from the LAION-AI/aesthetic-predictor GitHub repository when
    the cache file is missing.

    Args:
        clip_model: ``"vit_l_14"`` (768-d input) or ``"vit_b_32"`` (512-d input).

    Returns:
        An ``nn.Linear(dim, 1)`` module with the weights loaded, on CPU and in
        eval mode.

    Raises:
        ValueError: if ``clip_model`` is not one of the supported names. This
            is checked before anything is downloaded or written to the cache.
    """
    embedding_dims = {"vit_l_14": 768, "vit_b_32": 512}
    if clip_model not in embedding_dims:
        raise ValueError(
            f"Unsupported clip_model {clip_model!r}; expected one of {sorted(embedding_dims)}"
        )

    home = expanduser("~")
    cache_folder = home + "/.cache/emb_reader"
    path_to_model = cache_folder + "/sa_0_4_" + clip_model + "_linear.pth"
    if not os.path.exists(path_to_model):
        os.makedirs(cache_folder, exist_ok=True)
        url_model = (
            "https://github.com/LAION-AI/aesthetic-predictor/blob/main/sa_0_4_"+clip_model+"_linear.pth?raw=true"
        )
        # Download next to the target and rename on success, so an interrupted
        # transfer never leaves a truncated file that later runs would load.
        partial_path = path_to_model + ".part"
        urlretrieve(url_model, partial_path)
        os.replace(partial_path, path_to_model)

    m = nn.Linear(embedding_dims[clip_model], 1)
    # Load onto CPU regardless of where the checkpoint was saved; callers
    # move the module to the device they need.
    s = torch.load(path_to_model, map_location="cpu")
    m.load_state_dict(s)
    m.eval()
    return m
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Scores the snapshot images of every object flagged `snapshotted` in
    # <output_dir>/metadata.csv and writes per-object statistics to
    # <output_dir>/aesthetic_scores_<rank>.csv.
    parser = argparse.ArgumentParser()
    parser.add_argument("--clip_model", type=str, default="vit_l_14")
    parser.add_argument("--output_dir", type=str, required=True)
    parser.add_argument("--rank", type=int, default=0)
    parser.add_argument("--world_size", type=int, default=1)
    opt = parser.parse_args()

    # The aesthetic head must be paired with the CLIP encoder it was trained
    # on, so both are selected from the same --clip_model option.
    open_clip_archs = {"vit_l_14": "ViT-L-14", "vit_b_32": "ViT-B-32"}
    if opt.clip_model not in open_clip_archs:
        parser.error(f"--clip_model must be one of {sorted(open_clip_archs)}")

    amodel = get_aesthetic_model(clip_model=opt.clip_model)
    amodel.eval()
    model, _, preprocess = open_clip.create_model_and_transforms(open_clip_archs[opt.clip_model], pretrained='openai')
    model.eval()  # inference only
    model = model.cuda()
    amodel = amodel.cuda()

    metadata = pd.read_csv(os.path.join(opt.output_dir, 'metadata.csv'))
    metadata = metadata[metadata['snapshotted'] == 1]
    sha256s = metadata['sha256'].values

    # filter out objects that are already calculated
    if os.path.exists(os.path.join(opt.output_dir, 'aesthetic_scores.csv')):
        with open(os.path.join(opt.output_dir, 'aesthetic_scores.csv'), 'r') as f:
            old_metadata = pd.read_csv(f)
        sha256s = list(set(sha256s) - set(old_metadata['sha256'].values))

    # Sort for a deterministic order, then keep this rank's contiguous shard.
    sha256s = sorted(sha256s)
    sha256s = sha256s[len(sha256s) * opt.rank // opt.world_size: len(sha256s) * (opt.rank + 1) // opt.world_size]

    score_columns = ['sha256', 'mean', 'std', 'min', 'max', 'median']
    rows = []

    with ThreadPoolExecutor(max_workers=os.cpu_count()) as executor:
        # Bounded hand-off queue: loader threads block on `put` once 128
        # objects are waiting to be scored.
        finished = Queue(maxsize=128)

        def load_image(sha256):
            """Preprocess all PNG snapshots of one object and queue the stacked batch.

            Always queues exactly one item per object; the batch is ``None``
            when loading failed (including when no PNG was found).
            """
            try:
                files = os.listdir(os.path.join(opt.output_dir, 'snapshots', sha256))
                files = [f for f in files if f.endswith('.png')]
                processed = []
                for file in files:
                    with Image.open(os.path.join(opt.output_dir, 'snapshots', sha256, file)) as image:
                        processed.append(preprocess(image))
                processed = torch.stack(processed, dim=0)
            except Exception as e:
                print(e)
                processed = None
            finished.put((sha256, processed))

        executor.map(load_image, sha256s)
        # One queue item arrives per object, so exactly len(sha256s) gets.
        for _ in tqdm(range(len(sha256s)), desc='Calculating aesthetic scores'):
            sha256, processed = finished.get()
            if processed is not None:
                with torch.no_grad():
                    image_features = model.encode_image(processed.cuda())
                    image_features /= image_features.norm(dim=-1, keepdim=True)
                    aesthetic_score = amodel(image_features).cpu()
                rows.append({
                    'sha256': sha256,
                    'mean': aesthetic_score.mean().item(),
                    'std': aesthetic_score.std().item(),
                    'min': aesthetic_score.min().item(),
                    'max': aesthetic_score.max().item(),
                    'median': aesthetic_score.median().item(),
                })

    # Building the frame from the row list also covers the empty case (empty
    # shard or nothing loadable): a header-only CSV is written instead of
    # `pd.concat` failing on an empty list.
    with open(os.path.join(opt.output_dir, f'aesthetic_scores_{opt.rank}.csv'), 'w') as f:
        pd.DataFrame(rows, columns=score_columns).to_csv(f, index=False)
|
||||
97
dataset_toolkits/datasets/3D-FUTURE.py
Normal file
97
dataset_toolkits/datasets/3D-FUTURE.py
Normal file
@@ -0,0 +1,97 @@
|
||||
import os
|
||||
import re
|
||||
import argparse
|
||||
import zipfile
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from tqdm import tqdm
|
||||
import pandas as pd
|
||||
from utils import get_file_hash
|
||||
|
||||
|
||||
def add_args(parser: argparse.ArgumentParser):
    """Hook for registering dataset-specific command-line options.

    This dataset defines none, so ``parser`` is left untouched.
    """
    return None
|
||||
|
||||
|
||||
def get_metadata(**kwargs):
    """Read the 3D-FUTURE table of the TRELLIS-500K dataset into a DataFrame.

    The CSV is read straight from its ``hf://`` location. Keyword arguments
    are accepted and ignored.
    """
    source = "hf://datasets/JeffreyXiang/TRELLIS-500K/3D-FUTURE.csv"
    return pd.read_csv(source)
|
||||
|
||||
|
||||
def download(metadata, output_dir, **kwargs):
    """Extract the requested 3D-FUTURE instances from a manually downloaded archive.

    The archive cannot be fetched automatically; it must already be present
    as ``<output_dir>/raw/3D-FUTURE-model.zip``.

    Args:
        metadata: DataFrame with a ``file_identifier`` column (compared
            against archive directory names such as ``3D-FUTURE-model/<id>``)
            and a ``sha256`` column.
        output_dir: dataset root; files are extracted below ``<output_dir>/raw``.
        **kwargs: ignored.

    Returns:
        DataFrame with columns ``sha256`` and ``local_path`` (path of
        ``raw_model.obj`` relative to ``output_dir``), one row per instance
        whose hash matched the metadata.

    Raises:
        FileNotFoundError: if the archive is missing.
    """
    os.makedirs(output_dir, exist_ok=True)

    archive_path = os.path.join(output_dir, 'raw', '3D-FUTURE-model.zip')
    if not os.path.exists(archive_path):
        print("\033[93m")
        print("3D-FUTURE have to be downloaded manually")
        print(f"Please download the 3D-FUTURE-model.zip file and place it in the {output_dir}/raw directory")
        print("Visit https://tianchi.aliyun.com/specials/promotion/alibaba-3d-future for more information")
        print("\033[0m")
        raise FileNotFoundError("3D-FUTURE-model.zip not found")

    downloaded = {}
    metadata = metadata.set_index("file_identifier")
    with zipfile.ZipFile(archive_path) as zip_ref:
        all_names = zip_ref.namelist()
        # Instance directories appear as "3D-FUTURE-model/<id>/"; drop the
        # trailing slash and keep only those listed in the metadata.
        instances = [name[:-1] for name in all_names if re.match(r"^3D-FUTURE-model/[^/]+/$", name)]
        instances = [instance for instance in instances if instance in metadata.index]

        # Group the archive's file entries by their two leading path
        # components once, instead of rescanning every name per instance.
        files_by_instance = {}
        for name in all_names:
            if name.endswith("/"):
                continue
            parts = name.split("/", 2)
            if len(parts) == 3:
                files_by_instance.setdefault(parts[0] + "/" + parts[1], []).append(name)

        with ThreadPoolExecutor(max_workers=os.cpu_count()) as executor, \
            tqdm(total=len(instances), desc="Extracting") as pbar:
            def worker(instance: str):
                """Extract one instance and return the hash of its image.jpg, or None on failure."""
                try:
                    instance_files = files_by_instance.get(instance, [])
                    zip_ref.extractall(os.path.join(output_dir, 'raw'), members=instance_files)
                    # The identity check below hashes the instance's image.jpg.
                    return get_file_hash(os.path.join(output_dir, 'raw', f"{instance}/image.jpg"))
                except Exception as e:
                    print(f"Error extracting for {instance}: {e}")
                    return None
                finally:
                    pbar.update()

            sha256s = executor.map(worker, instances)
            executor.shutdown(wait=True)

    for k, sha256 in zip(instances, sha256s):
        if sha256 is not None:
            if sha256 == metadata.loc[k, "sha256"]:
                downloaded[sha256] = os.path.join("raw", f"{k}/raw_model.obj")
            else:
                print(f"Error downloading {k}: sha256s do not match")

    return pd.DataFrame(downloaded.items(), columns=['sha256', 'local_path'])
|
||||
|
||||
|
||||
def foreach_instance(metadata, output_dir, func, max_workers=None, desc='Processing objects') -> pd.DataFrame:
    """Apply ``func`` to every object listed in ``metadata`` using a thread pool.

    Args:
        metadata: DataFrame whose rows provide ``local_path`` and ``sha256``.
        output_dir: directory that each ``local_path`` is joined onto.
        func: callable ``func(file, sha256)`` returning a record or ``None``.
        max_workers: thread count; defaults to ``os.cpu_count()``.
        desc: progress-bar label.

    Returns:
        DataFrame built from the non-``None`` records, in completion order.
        A failure on one object is printed and skipped. If the run as a
        whole is interrupted, the records gathered so far are returned.
    """
    import os
    from concurrent.futures import ThreadPoolExecutor
    from tqdm import tqdm

    # load metadata
    metadata = metadata.to_dict('records')

    # processing objects
    records = []
    max_workers = max_workers or os.cpu_count()
    try:
        with ThreadPoolExecutor(max_workers=max_workers) as executor, \
            tqdm(total=len(metadata), desc=desc) as pbar:
            def worker(metadatum):
                # Looked up before the try block so the error message below can
                # always name the object, even when a key is missing.
                sha256 = metadatum.get('sha256')
                try:
                    file = os.path.join(output_dir, metadatum['local_path'])
                    record = func(file, metadatum['sha256'])
                    if record is not None:
                        records.append(record)
                except Exception as e:
                    print(f"Error processing object {sha256}: {e}")
                finally:
                    pbar.update()

            executor.map(worker, metadata)
            executor.shutdown(wait=True)
    except BaseException as e:
        # Kept as broad as the original bare `except` so that partial results
        # are still returned (presumably also on Ctrl-C — confirm), but the
        # cause is now reported.
        print(f"Error happened during processing: {e!r}")

    return pd.DataFrame.from_records(records)
|
||||
96
dataset_toolkits/datasets/ABO.py
Normal file
96
dataset_toolkits/datasets/ABO.py
Normal file
@@ -0,0 +1,96 @@
|
||||
import os
|
||||
import re
|
||||
import argparse
|
||||
import tarfile
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from tqdm import tqdm
|
||||
import pandas as pd
|
||||
from utils import get_file_hash
|
||||
|
||||
|
||||
def add_args(parser: argparse.ArgumentParser):
    """Register dataset-specific command-line options on ``parser``.

    Nothing is registered for this dataset; the parser is not modified.
    """
    return
|
||||
|
||||
|
||||
def get_metadata(**kwargs):
    """Read the ABO table of the TRELLIS-500K dataset into a DataFrame.

    The CSV is read straight from its ``hf://`` location. Keyword arguments
    are accepted and ignored.
    """
    source = "hf://datasets/JeffreyXiang/TRELLIS-500K/ABO.csv"
    return pd.read_csv(source)
|
||||
|
||||
|
||||
def download(metadata, output_dir, **kwargs):
    """Fetch the ABO 3D-model archive if needed and extract the requested models.

    Args:
        metadata: DataFrame with a ``file_identifier`` column (path of a model
            below ``3dmodels/original`` in the archive) and a ``sha256`` column.
        output_dir: dataset root; the archive and the extracted files live
            under ``<output_dir>/raw``.
        **kwargs: ignored.

    Returns:
        DataFrame with columns ``sha256`` and ``local_path`` (relative to
        ``output_dir``), one row per model whose hash matched the metadata.

    Raises:
        FileNotFoundError: if the archive is missing and could not be
            downloaded with ``wget``.
    """
    import subprocess

    raw_dir = os.path.join(output_dir, 'raw')
    tar_path = os.path.join(raw_dir, 'abo-3dmodels.tar')
    os.makedirs(raw_dir, exist_ok=True)

    if not os.path.exists(tar_path):
        # Download to a side file and rename on success: `wget -O` creates the
        # target even when the transfer fails, and a truncated archive would
        # otherwise be mistaken for a finished download on the next run.
        partial_path = tar_path + '.part'
        try:
            # Argument list instead of a shell string, so paths containing
            # spaces or shell metacharacters are passed through unchanged.
            # check=True makes a non-zero wget exit status raise.
            subprocess.run(
                ["wget", "-O", partial_path,
                 "https://amazon-berkeley-objects.s3.amazonaws.com/archives/abo-3dmodels.tar"],
                check=True,
            )
            os.replace(partial_path, tar_path)
        except (OSError, subprocess.CalledProcessError) as e:
            # OSError also covers `wget` not being installed.
            if os.path.exists(partial_path):
                os.remove(partial_path)
            print("\033[93m")
            print("Error downloading ABO dataset. Please check your internet connection and try again.")
            print(f"Or, you can manually download the abo-3dmodels.tar file and place it in the {output_dir}/raw directory")
            print("Visit https://amazon-berkeley-objects.s3.amazonaws.com/index.html for more information")
            print("\033[0m")
            raise FileNotFoundError("Error downloading ABO dataset") from e

    downloaded = {}
    metadata = metadata.set_index("file_identifier")
    with tarfile.open(tar_path) as tar:
        # A single worker: all extractions go through the one shared tar handle.
        with ThreadPoolExecutor(max_workers=1) as executor, \
            tqdm(total=len(metadata), desc="Extracting") as pbar:
            def worker(instance: str):
                """Extract one model and return the hash of the extracted file, or None on failure."""
                try:
                    tar.extract(f"3dmodels/original/{instance}", path=os.path.join(output_dir, 'raw'))
                    return get_file_hash(os.path.join(output_dir, 'raw/3dmodels/original', instance))
                except Exception as e:
                    print(f"Error extracting for {instance}: {e}")
                    return None
                finally:
                    pbar.update()

            sha256s = executor.map(worker, metadata.index)
            executor.shutdown(wait=True)

    for k, sha256 in zip(metadata.index, sha256s):
        if sha256 is not None:
            if sha256 == metadata.loc[k, "sha256"]:
                downloaded[sha256] = os.path.join('raw/3dmodels/original', k)
            else:
                print(f"Error downloading {k}: sha256s do not match")

    return pd.DataFrame(downloaded.items(), columns=['sha256', 'local_path'])
|
||||
|
||||
|
||||
def foreach_instance(metadata, output_dir, func, max_workers=None, desc='Processing objects') -> pd.DataFrame:
    """Run ``func`` on every object listed in ``metadata`` using a thread pool.

    Args:
        metadata: DataFrame whose rows provide ``local_path`` and ``sha256``.
        output_dir: directory that each ``local_path`` is joined onto.
        func: callable ``func(file, sha256)`` returning a record or ``None``.
        max_workers: thread count; defaults to ``os.cpu_count()``.
        desc: progress-bar label.

    Returns:
        DataFrame built from the non-``None`` records, in completion order.
        A failure on one object is printed and skipped. If the run as a
        whole is interrupted, the records gathered so far are returned.
    """
    import os
    from concurrent.futures import ThreadPoolExecutor
    from tqdm import tqdm

    # load metadata
    metadata = metadata.to_dict('records')

    # processing objects
    records = []
    max_workers = max_workers or os.cpu_count()
    try:
        with ThreadPoolExecutor(max_workers=max_workers) as executor, \
            tqdm(total=len(metadata), desc=desc) as pbar:
            def worker(metadatum):
                # Resolved outside the try block: the handler below must be able
                # to name the object even if a required key is absent.
                sha256 = metadatum.get('sha256')
                try:
                    file = os.path.join(output_dir, metadatum['local_path'])
                    record = func(file, metadatum['sha256'])
                    if record is not None:
                        records.append(record)
                except Exception as e:
                    print(f"Error processing object {sha256}: {e}")
                finally:
                    pbar.update()

            executor.map(worker, metadata)
            executor.shutdown(wait=True)
    except BaseException as e:
        # As broad as the original bare `except`, so partial results are still
        # returned (presumably also on Ctrl-C — confirm); the cause is now
        # included in the message.
        print(f"Error happened during processing: {e!r}")

    return pd.DataFrame.from_records(records)
|
||||
Reference in New Issue
Block a user