475 lines
14 KiB
Python
475 lines
14 KiB
Python
#!/usr/bin/env python3
|
||
# -*- coding: utf-8 -*-
|
||
|
||
import os
|
||
import re
|
||
import sys
|
||
import json
|
||
import time
|
||
import argparse
|
||
import mimetypes
|
||
from pathlib import Path
|
||
from typing import Any, Dict, Iterator, Tuple
|
||
from urllib.parse import urlparse
|
||
|
||
import requests
|
||
|
||
# Base URL of the Tripo3D v2 OpenAPI. request_json() prefixes every endpoint path
# used in this script ("/upload", "/task", "/task/{task_id}") with it.
BASE_URL = "https://api.tripo3d.ai/v2/openapi"
|
||
|
||
|
||
class TripoAPIError(RuntimeError):
    """Raised when a Tripo3D API request, a generation task, or an output download fails."""
|
||
|
||
|
||
def _parse_bool(value: str) -> bool:
    """Parse a CLI boolean such as "true"/"false", "1"/"0" or "yes"/"no" (case-insensitive).

    Raises:
        argparse.ArgumentTypeError: if *value* is not a recognised boolean spelling, so argparse
            reports a usage error instead of forwarding an arbitrary string to the API.
    """
    normalized = value.strip().lower()
    if normalized in ("1", "true", "t", "yes", "y", "on"):
        return True
    if normalized in ("0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean (true/false), got {value!r}")


def build_parser():
    """Build the argument parser for the single-image -> 3D CLI.

    The API key defaults to the TRIPO_API_KEY environment variable (None when unset).
    Secrets must never be hard-coded as defaults here.
    """
    # The first positional parameter of ArgumentParser is `prog`; pass the text as a description.
    p = argparse.ArgumentParser(description="Tripo3D CLI: single image -> 3D")

    # I/O
    p.add_argument("-i", "--image", required=True, help="Input image path")
    p.add_argument("-o", "--out_dir", default="tripo_outputs", help="Output directory")

    # Auth
    p.add_argument(
        "--api_key",
        default=os.getenv("TRIPO_API_KEY"),
        help="Tripo API key (defaults to env TRIPO_API_KEY)",
    )

    # Model
    p.add_argument(
        "--model_version",
        type=str,
        default="v3.1-20260211",
        help="Model version, e.g. P1-20260311 / v3.1-20260211 / v3.0-20250812 / v2.5-20250123",
    )

    # Network / polling
    p.add_argument("--poll_interval", type=float, default=2.0, help="Polling interval (seconds)")
    p.add_argument("--poll_timeout", type=float, default=1800.0, help="Max polling time (seconds)")
    p.add_argument("--request_timeout", type=float, default=120.0, help="HTTP request timeout (seconds)")

    # Generation options (paired on/off flags; both default to on)
    p.add_argument("--texture", dest="texture", action="store_true", default=True,
                   help="Generate texture (default)")
    p.add_argument("--no-texture", dest="texture", action="store_false", help="Disable texture")

    p.add_argument("--pbr", dest="pbr", action="store_true", default=True,
                   help="Generate PBR materials (default)")
    p.add_argument("--no-pbr", dest="pbr", action="store_false", help="Disable PBR materials")

    p.add_argument(
        "--texture_quality",
        type=str,
        default="standard",
        choices=["standard", "detailed"],
        help="Texture quality",
    )
    p.add_argument(
        "--texture_alignment",
        type=str,
        default="original_image",
        choices=["original_image", "geometry"],
        help="Texture alignment mode",
    )
    p.add_argument(
        "--orientation",
        type=str,
        default="default",
        choices=["default", "align_image"],
        help="Orientation mode",
    )

    # Optional params: left as None (and therefore omitted from the payload) unless given.
    p.add_argument("--face_limit", type=int, default=None)
    p.add_argument("--model_seed", type=int, default=None)
    p.add_argument("--texture_seed", type=int, default=None)
    # Boolean API fields: parse "true"/"false" into real bools so the JSON payload carries
    # booleans rather than strings (the string "false" would otherwise be truthy).
    p.add_argument("--auto_size", type=_parse_bool, default=None, help="true/false")
    p.add_argument("--quad", type=_parse_bool, default=None, help="true/false")
    p.add_argument("--compress", type=str, default=None)
    p.add_argument("--generate_parts", type=_parse_bool, default=None, help="true/false")
    p.add_argument("--smart_low_poly", type=_parse_bool, default=None, help="true/false")

    # Save / download toggles
    p.add_argument("--download_outputs", dest="download_outputs", action="store_true", default=True,
                   help="Download output files (default)")
    p.add_argument("--no-download_outputs", dest="download_outputs", action="store_false",
                   help="Only save the task JSON, skip downloads")

    p.add_argument("--save_task_json", dest="save_task_json", action="store_true", default=True,
                   help="Save the final task response as JSON (default)")
    p.add_argument("--no-save_task_json", dest="save_task_json", action="store_false",
                   help="Do not save the task response JSON")

    p.add_argument("--print_payload", dest="print_payload", action="store_true", default=False,
                   help="Print the generation request payload")
    p.add_argument("--print_output", dest="print_output", action="store_true", default=True,
                   help="Print the task output JSON (default)")
    p.add_argument("--no-print_output", dest="print_output", action="store_false",
                   help="Do not print the task output JSON")

    return p
|
||
|
||
|
||
def guess_mime_type(file_path: Path) -> str:
    """Return the MIME type implied by *file_path*'s name, or "application/octet-stream"."""
    guessed = mimetypes.guess_type(str(file_path))[0]
    if guessed:
        return guessed
    return "application/octet-stream"
|
||
|
||
|
||
def safe_filename(name: str) -> str:
    """Make *name* safe to use as a file name.

    Runs of characters reserved on common file systems become a single "_". Every whitespace run
    also becomes "_". Leading/trailing "." and "_" are trimmed. Returns "file" if nothing is left.
    """
    cleaned = re.sub(r'[\\/:*?"<>|]+', "_", name)
    cleaned = re.sub(r"\s+", "_", cleaned)
    cleaned = cleaned.strip("._")
    return cleaned if cleaned else "file"
|
||
|
||
|
||
def extract_error_message(payload: Any) -> str:
    """Pull a human-readable error message out of an API response body.

    For a dict, the first truthy value among "message", "error", "error_message" and "detail" is
    returned. Top-level keys are checked first, then the same keys inside a dict-valued "data".
    If none is found, the dict's JSON (truncated to 800 chars) is returned. Non-dict payloads are
    stringified and truncated to 800 chars.
    """
    if not isinstance(payload, dict):
        return str(payload)[:800]

    sources = [payload]
    nested = payload.get("data")
    if isinstance(nested, dict):
        sources.append(nested)

    for source in sources:
        for field in ("message", "error", "error_message", "detail"):
            value = source.get(field)
            if value:
                return str(value)

    return json.dumps(payload, ensure_ascii=False)[:800]
|
||
|
||
|
||
def request_json(
    session: requests.Session,
    method: str,
    endpoint: str,
    request_timeout: float,
    **kwargs,
) -> Dict[str, Any]:
    """Send a request to ``BASE_URL + endpoint`` and return the decoded JSON object.

    Args:
        session: session carrying auth headers.
        method: HTTP method, e.g. "GET" / "POST".
        endpoint: path appended to BASE_URL, e.g. "/task".
        request_timeout: per-request timeout in seconds.
        **kwargs: forwarded to ``session.request`` (e.g. ``json=``, ``files=``).

    Returns:
        The response body parsed as a JSON object.

    Raises:
        TripoAPIError: on a transport error, a non-2xx status, a body that is not valid JSON,
            or a JSON body that is not an object.
    """
    url = f"{BASE_URL}{endpoint}"

    try:
        resp = session.request(method=method, url=url, timeout=request_timeout, **kwargs)
    except requests.RequestException as e:
        raise TripoAPIError(f"请求失败: {method} {url} | {e}") from e

    if not resp.ok:
        try:
            err_payload = resp.json()
        except ValueError:
            # Error bodies are not always JSON (e.g. an HTML page from a proxy); use raw text.
            err_payload = resp.text
        raise TripoAPIError(
            f"HTTP {resp.status_code} | {method} {url} | {extract_error_message(err_payload)}"
        )

    # requests raises a ValueError subclass when the body is not valid JSON.
    try:
        payload = resp.json()
    except ValueError as e:
        raise TripoAPIError(
            f"响应不是合法 JSON: {method} {url}\n原始响应前 500 字符:\n{resp.text[:500]}"
        ) from e

    # Every caller does payload.get(...); a JSON array/scalar would otherwise surface later as an
    # AttributeError far from the request that produced it.
    if not isinstance(payload, dict):
        raise TripoAPIError(
            f"响应 JSON 不是对象: {method} {url}\n原始响应前 500 字符:\n{resp.text[:500]}"
        )

    return payload
|
||
|
||
|
||
def create_session(api_key: str) -> requests.Session:
    """Return a requests session that sends a Bearer token and accepts JSON on every call."""
    default_headers = {
        "Authorization": f"Bearer {api_key}",
        "Accept": "application/json",
    }
    http = requests.Session()
    http.headers.update(default_headers)
    return http
|
||
|
||
|
||
def upload_image(session: requests.Session, image_path: Path, request_timeout: float) -> str:
    """Upload a local image via ``POST /upload`` and return the ``image_token`` from the response.

    Args:
        session: authenticated session.
        image_path: local image file, sent as multipart field "file" with a guessed MIME type.
        request_timeout: HTTP timeout in seconds.

    Raises:
        FileNotFoundError: if *image_path* is not an existing regular file.
        TripoAPIError: if the request fails or the response carries no ``data.image_token``.
    """
    # is_file() rather than exists(): a directory would pass exists() and then fail inside open()
    # with a less helpful IsADirectoryError/PermissionError.
    if not image_path.is_file():
        raise FileNotFoundError(f"找不到图片: {image_path}")

    # The file handle must stay open while the multipart body is sent.
    with image_path.open("rb") as f:
        files = {"file": (image_path.name, f, guess_mime_type(image_path))}
        payload = request_json(
            session,
            "POST",
            "/upload",
            request_timeout=request_timeout,
            files=files,
        )

    data = payload.get("data")
    if not isinstance(data, dict):
        data = {}
    file_token = data.get("image_token")

    if not file_token:
        raise TripoAPIError(f"上传成功但未返回 image_token: {json.dumps(payload, ensure_ascii=False)}")

    return file_token
|
||
|
||
|
||
def build_generation_payload(args, file_token: str, image_path: Path) -> Dict[str, Any]:
    """Assemble the JSON body for an ``image_to_model`` task.

    Args:
        args: parsed CLI namespace providing the generation options.
        file_token: token returned by the upload step.
        image_path: the uploaded image. Its lowercase suffix without the dot becomes
            ``file.type``, falling back to "png" when there is no suffix.

    Returns:
        The payload dict. Each optional option is included only when its value is not None.
    """
    suffix = image_path.suffix.lower().lstrip(".")

    payload: Dict[str, Any] = {
        "type": "image_to_model",
        "model_version": args.model_version,
        "file": {"type": suffix if suffix else "png", "file_token": file_token},
        "texture": args.texture,
        "pbr": args.pbr,
        "texture_quality": args.texture_quality,
        "texture_alignment": args.texture_alignment,
        "orientation": args.orientation,
    }

    optional_values = {
        name: getattr(args, name)
        for name in (
            "face_limit",
            "model_seed",
            "texture_seed",
            "auto_size",
            "quad",
            "compress",
            "generate_parts",
            "smart_low_poly",
        )
    }
    payload.update({name: value for name, value in optional_values.items() if value is not None})
    return payload
|
||
|
||
|
||
def create_task(session: requests.Session, payload: Dict[str, Any], request_timeout: float) -> str:
    """Submit a generation task via ``POST /task`` and return its ``task_id``.

    Raises:
        TripoAPIError: if the request fails or the response has no ``data.task_id``.
    """
    response = request_json(
        session, "POST", "/task", request_timeout=request_timeout, json=payload
    )
    task_id = (response.get("data") or {}).get("task_id")
    if task_id:
        return task_id
    raise TripoAPIError(f"提交任务成功但未返回 task_id: {json.dumps(response, ensure_ascii=False)}")
|
||
|
||
|
||
def get_task(session: requests.Session, task_id: str, request_timeout: float) -> Dict[str, Any]:
    """Fetch the current task record via ``GET /task/{task_id}`` and return the raw response."""
    endpoint = f"/task/{task_id}"
    return request_json(session, "GET", endpoint, request_timeout=request_timeout)
|
||
|
||
|
||
def poll_task(
    session: requests.Session,
    task_id: str,
    poll_interval: float,
    poll_timeout: float,
    request_timeout: float,
) -> Dict[str, Any]:
    """Poll a task until it succeeds, reaches a failure status, or *poll_timeout* elapses.

    Status, progress and elapsed time are shown on a single console line refreshed with "\\r".

    Returns:
        The full response of the poll that reported ``status == "success"``.

    Raises:
        TripoAPIError: if a poll request fails or the task ends in a failure status.
        TimeoutError: if more than *poll_timeout* seconds pass without success or failure.
    """
    # Statuses from which a task can no longer succeed. Previously only "failed" was checked, so
    # e.g. a cancelled task was polled until poll_timeout.
    # NOTE(review): list follows Tripo's documented task statuses — confirm against API docs.
    failure_statuses = ("failed", "cancelled", "banned", "expired")

    start = time.perf_counter()
    last_line = ""

    while True:
        resp = get_task(session, task_id, request_timeout=request_timeout)
        data = resp.get("data") or {}

        status = str(data.get("status", "unknown")).lower()
        progress = data.get("progress")
        if progress is None:
            # The field may be missing or explicitly null; None cannot be formatted with ">3".
            progress = 0
        elapsed = time.perf_counter() - start

        line = f"\r[状态] {status:<10} | [进度] {progress:>3}% | [已等待] {elapsed:>7.1f}s"
        # Only redraw when something visible changed, to avoid flooding the terminal.
        if line != last_line:
            sys.stdout.write(line)
            sys.stdout.flush()
            last_line = line

        if status == "success":
            sys.stdout.write("\n")
            sys.stdout.flush()
            return resp

        if status in failure_statuses:
            sys.stdout.write("\n")
            sys.stdout.flush()
            error_message = data.get("error_message") or extract_error_message(resp)
            raise TripoAPIError(f"任务失败 | task_id={task_id} | status={status} | {error_message}")

        if elapsed > poll_timeout:
            sys.stdout.write("\n")
            sys.stdout.flush()
            raise TimeoutError(f"轮询超时: 已等待 {elapsed:.1f}s,task_id={task_id}")

        # Do not sleep past the deadline. When less than poll_interval remains, poll once more
        # at the deadline instead of overshooting it by a full interval.
        time.sleep(min(poll_interval, max(0.0, poll_timeout - elapsed)))
|
||
|
||
|
||
def iter_urls(obj: Any, prefix: str = "output") -> Iterator[Tuple[str, str]]:
    """Recursively yield ``(key_path, url)`` for every http(s) string nested in *obj*.

    Key paths start at *prefix*. They use ".key" for dict members and "[i]" for list items,
    e.g. "output.pbr_model" or "output.images[0]". Other value types are ignored.
    """
    if isinstance(obj, str):
        if obj.startswith(("http://", "https://")):
            yield prefix, obj
        return

    if isinstance(obj, dict):
        children = ((f"{prefix}.{key}", value) for key, value in obj.items())
    elif isinstance(obj, list):
        children = ((f"{prefix}[{idx}]", value) for idx, value in enumerate(obj))
    else:
        return

    for child_prefix, child in children:
        yield from iter_urls(child, child_prefix)
|
||
|
||
|
||
def infer_extension_from_url(url: str) -> str:
    """Return the suffix of the URL's path component (query/fragment ignored), or ".bin"."""
    suffix = Path(urlparse(url).path).suffix
    if not suffix:
        return ".bin"
    return suffix
|
||
|
||
|
||
def unique_path(path: Path) -> Path:
    """Return *path* if it does not exist yet.

    Otherwise return the first non-existing sibling named ``<stem>_<n><suffix>`` for n = 1, 2, ...
    """
    candidate = path
    counter = 0
    while candidate.exists():
        counter += 1
        candidate = path.parent / f"{path.stem}_{counter}{path.suffix}"
    return candidate
|
||
|
||
|
||
def download_file(session: requests.Session, url: str, save_path: Path, request_timeout: float) -> None:
    """Stream *url* to *save_path* in 1 MiB chunks.

    The body is written to a sibling ``<name>.part`` file, which is renamed onto *save_path* only
    after the whole body has arrived. A failed or interrupted download therefore never leaves a
    truncated file under the final name.

    Raises:
        TripoAPIError: on any requests error, including a non-2xx status.
        OSError: if the file cannot be written or renamed.
    """
    tmp_path = save_path.with_name(save_path.name + ".part")
    try:
        with session.get(url, stream=True, timeout=request_timeout) as resp:
            resp.raise_for_status()
            with tmp_path.open("wb") as f:
                for chunk in resp.iter_content(chunk_size=1024 * 1024):
                    if chunk:  # skip empty keep-alive chunks
                        f.write(chunk)
        tmp_path.replace(save_path)
    except requests.RequestException as e:
        raise TripoAPIError(f"下载失败: {url} | {e}") from e
    finally:
        # After a successful replace() the .part file is gone; otherwise discard partial data.
        if tmp_path.exists():
            try:
                tmp_path.unlink()
            except OSError:
                # Best-effort cleanup: never mask the original download error.
                pass
|
||
|
||
|
||
def save_outputs(
    session: requests.Session,
    task_resp: Dict[str, Any],
    out_dir: Path,
    request_timeout: float,
    save_task_json: bool = True,
    download_outputs: bool = True,
) -> None:
    """Persist a finished task: optionally its raw JSON, then every URL found in ``data.output``.

    Files are written into *out_dir*, which is created if missing.
    - The task response is saved as ``<task_id>.json`` when *save_task_json* is set.
    - When *download_outputs* is set, each URL is downloaded to a file named after its key path
      inside ``output`` plus the URL's extension.
    - A numeric suffix is added when a name is already taken.
    """
    out_dir.mkdir(parents=True, exist_ok=True)

    data = task_resp.get("data") or {}
    # str(): safe_filename() needs a str and the API might return a non-string id.
    task_id = str(data.get("task_id", "unknown_task"))
    output = data.get("output") or {}

    if save_task_json:
        meta_path = out_dir / f"{safe_filename(task_id)}.json"
        with meta_path.open("w", encoding="utf-8") as f:
            json.dump(task_resp, f, ensure_ascii=False, indent=2)

    if not output:
        print("⚠️ 任务成功,但 output 为空。")
        return

    if not download_outputs:
        print("ℹ️ 已跳过下载,仅保存任务响应。")
        return

    url_items = list(iter_urls(output))
    if not url_items:
        print("⚠️ output 中没有找到可下载 URL。")
        return

    print("\n📥 开始下载输出文件...")
    root_prefix = "output."
    for logical_key, url in url_items:
        # Strip only the leading root prefix. str.replace() would also delete "output."
        # occurring inside nested key names.
        if logical_key.startswith(root_prefix):
            short_key = logical_key[len(root_prefix):]
        else:
            short_key = logical_key
        ext = infer_extension_from_url(url)
        filename = safe_filename(short_key) + ext
        save_path = unique_path(out_dir / filename)

        print(f" - {short_key} -> {save_path}")
        download_file(session, url, save_path, request_timeout=request_timeout)
|
||
|
||
|
||
def main():
    """Run the CLI pipeline: upload image -> submit task -> poll until done -> save outputs.

    Prints progress and timing for each stage.

    Raises:
        ValueError: if no API key was provided via --api_key or TRIPO_API_KEY.
        Exceptions raised by the individual stages propagate unchanged.
    """
    args = build_parser().parse_args()
    if not args.api_key:
        raise ValueError("请提供 --api_key 或设置环境变量 TRIPO_API_KEY")

    image_path, out_dir = Path(args.image), Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    session = create_session(args.api_key)
    timeout = args.request_timeout

    print(f"🚀 启动测试 | 模型: {args.model_version}")
    print(f"🖼️ 输入图片: {image_path}")
    print(f"📁 输出目录: {out_dir.resolve()}")

    t_begin = time.perf_counter()

    # Stage 1: upload the image and obtain a file token.
    print("\n[1/4] 上传图片...")
    t_stage = time.perf_counter()
    file_token = upload_image(session, image_path, request_timeout=timeout)
    upload_elapsed = time.perf_counter() - t_stage
    print(f"✅ 上传完成 | file_token: {file_token}")
    print(f"⏱️ 上传耗时: {upload_elapsed:.2f}s")

    # Stage 2: build and submit the image_to_model task.
    print("\n[2/4] 提交 image_to_model 任务...")
    payload = build_generation_payload(args, file_token, image_path)
    if args.print_payload:
        print(json.dumps(payload, ensure_ascii=False, indent=2))
    task_id = create_task(session, payload, request_timeout=timeout)
    print(f"✅ 任务提交成功 | task_id: {task_id}")

    # Stage 3: wait for the task to finish.
    print("\n[3/4] 轮询任务状态...")
    t_stage = time.perf_counter()
    task_resp = poll_task(
        session,
        task_id,
        poll_interval=args.poll_interval,
        poll_timeout=args.poll_timeout,
        request_timeout=timeout,
    )
    gen_elapsed = time.perf_counter() - t_stage

    data = task_resp.get("data") or {}
    output = data.get("output") or {}
    total_elapsed = time.perf_counter() - t_begin

    keys_repr = list(output.keys()) if isinstance(output, dict) else type(output)
    separator = "=" * 60
    summary_lines = [
        "\n🎉 生成成功",
        separator,
        f"task_id : {task_id}",
        f"纯生成耗时 : {gen_elapsed:.2f}s",
        f"总流程耗时 : {total_elapsed:.2f}s",
        f"最终 status : {data.get('status')}",
        f"output keys : {keys_repr}",
        separator,
    ]
    print("\n".join(summary_lines))

    if args.print_output:
        print(json.dumps(output, ensure_ascii=False, indent=2))

    # Stage 4: save the task JSON and download output files.
    print("\n[4/4] 保存结果...")
    save_outputs(
        session,
        task_resp,
        out_dir=out_dir,
        request_timeout=timeout,
        save_task_json=args.save_task_json,
        download_outputs=args.download_outputs,
    )

    print("\n✅ 全部完成。")
|
||
|
||
|
||
if __name__ == "__main__":
    try:
        main()
    except KeyboardInterrupt:
        # Ctrl-C (typically during polling): end the progress line and use the
        # conventional SIGINT exit status instead of dumping a traceback.
        print("\n⛔ 已被用户中断", file=sys.stderr)
        sys.exit(130)
    except Exception as e:
        # Top-level boundary: report failures as a one-line message on stderr so they do not
        # mix with normal output on stdout. argparse's SystemExit is not caught here.
        print(f"\n❌ 程序终止: {e}", file=sys.stderr)
        sys.exit(1)
|