#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""Tripo3D command-line client: turn a single image into a 3D model.

Workflow: upload the image, submit an ``image_to_model`` task, poll the task
until it finishes, then save the task response and download its output URLs.
"""
import os
import re
import sys
import json
import time
import argparse
import mimetypes
from pathlib import Path
from typing import Any, Dict, Iterator, Tuple
from urllib.parse import urlparse

import requests

BASE_URL = "https://api.tripo3d.ai/v2/openapi"


class TripoAPIError(RuntimeError):
    """Raised when a Tripo API request, its response, or the generation task fails."""


def build_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the image -> 3D workflow."""
    # Passed as ``description=``: the first positional parameter of
    # ArgumentParser is ``prog``, which would replace the program name shown
    # in the usage line instead of describing the tool.
    p = argparse.ArgumentParser(description="Tripo3D CLI: single image -> 3D")

    # I/O
    p.add_argument("-i", "--image", required=True, help="Input image path")
    p.add_argument("-o", "--out_dir", default="tripo_outputs", help="Output directory")

    # Auth
    # No literal fallback key: a credential committed to source is exposed to
    # everyone who can read the file. main() rejects an empty key explicitly.
    p.add_argument(
        "--api_key",
        default=os.getenv("TRIPO_API_KEY"),
        help="Tripo API key, or set env TRIPO_API_KEY",
    )

    # Model
    p.add_argument(
        "--model_version",
        type=str,
        default="v3.1-20260211",
        help="Model version, e.g. P1-20260311 / v3.1-20260211 / v3.0-20250812 / v2.5-20250123",
    )

    # Network / polling
    p.add_argument("--poll_interval", type=float, default=2.0, help="Polling interval (seconds)")
    p.add_argument("--poll_timeout", type=float, default=1800.0, help="Max polling time (seconds)")
    p.add_argument("--request_timeout", type=float, default=120.0, help="HTTP request timeout (seconds)")

    # Generation options (each boolean has an explicit --no-... counterpart)
    p.add_argument("--texture", dest="texture", action="store_true", default=True)
    p.add_argument("--no-texture", dest="texture", action="store_false")
    p.add_argument("--pbr", dest="pbr", action="store_true", default=True)
    p.add_argument("--no-pbr", dest="pbr", action="store_false")
    p.add_argument(
        "--texture_quality",
        type=str,
        default="standard",
        choices=["standard", "detailed"],
        help="Texture quality",
    )
    p.add_argument(
        "--texture_alignment",
        type=str,
        default="original_image",
        choices=["original_image", "geometry"],
        help="Texture alignment mode",
    )
    p.add_argument(
        "--orientation",
        type=str,
        default="default",
        choices=["default", "align_image"],
        help="Orientation mode",
    )

    # Optional params: only sent to the API when given (see build_generation_payload).
    # NOTE(review): auto_size / quad / generate_parts / smart_low_poly are
    # forwarded as raw strings (e.g. "true"/"false"); confirm the API accepts
    # strings for these rather than requiring JSON booleans.
    p.add_argument("--face_limit", type=int, default=None)
    p.add_argument("--model_seed", type=int, default=None)
    p.add_argument("--texture_seed", type=int, default=None)
    p.add_argument("--auto_size", type=str, default=None)
    p.add_argument("--quad", type=str, default=None)
    p.add_argument("--compress", type=str, default=None)
    p.add_argument("--generate_parts", type=str, default=None)
    p.add_argument("--smart_low_poly", type=str, default=None)

    # Save / download toggles
    p.add_argument("--download_outputs", dest="download_outputs", action="store_true", default=True)
    p.add_argument("--no-download_outputs", dest="download_outputs", action="store_false")
    p.add_argument("--save_task_json", dest="save_task_json", action="store_true", default=True)
    p.add_argument("--no-save_task_json", dest="save_task_json", action="store_false")
    p.add_argument("--print_payload", dest="print_payload", action="store_true", default=False)
    p.add_argument("--print_output", dest="print_output", action="store_true", default=True)
    p.add_argument("--no-print_output", dest="print_output", action="store_false")
    return p


def guess_mime_type(file_path: Path) -> str:
    """Return the MIME type guessed from *file_path*'s extension, or ``application/octet-stream``."""
    mime, _ = mimetypes.guess_type(str(file_path))
    return mime or "application/octet-stream"


def safe_filename(name: str) -> str:
    """Turn *name* into a conservative file name.

    Each run of path-separator / reserved characters (see the first regex)
    and each run of whitespace becomes a single ``_``; leading and trailing
    dots and underscores are then stripped. Returns ``"file"`` if nothing
    usable remains.
    """
    name = re.sub(r'[\\/:*?"<>|]+', "_", name)
    name = re.sub(r"\s+", "_", name).strip("._")
    return name or "file"


# Keys probed, in order, for a human-readable error message.
_ERROR_MESSAGE_KEYS = ("message", "error", "error_message", "detail")


def extract_error_message(payload: Any) -> str:
    """Extract a readable error message from an API response payload.

    For a dict, returns the first non-empty value among
    ``_ERROR_MESSAGE_KEYS`` at the top level, then inside ``payload["data"]``
    (if that is a dict). Otherwise falls back to the JSON dump (dicts) or
    ``str()`` (anything else), truncated to 800 characters.
    """
    if isinstance(payload, dict):
        for source in (payload, payload.get("data")):
            if not isinstance(source, dict):
                continue
            for key in _ERROR_MESSAGE_KEYS:
                if source.get(key):
                    return str(source[key])
        return json.dumps(payload, ensure_ascii=False)[:800]
    return str(payload)[:800]


def request_json(
    session: requests.Session,
    method: str,
    endpoint: str,
    request_timeout: float,
    **kwargs,
) -> Dict[str, Any]:
    """Send ``method`` to ``BASE_URL + endpoint`` and return the decoded JSON object.

    Extra keyword arguments are forwarded to ``session.request``.

    Raises:
        TripoAPIError: on a transport error, a non-2xx status, a body that is
            not valid JSON, or a JSON body that is not an object.
    """
    url = f"{BASE_URL}{endpoint}"
    try:
        resp = session.request(method=method, url=url, timeout=request_timeout, **kwargs)
    except requests.RequestException as e:
        raise TripoAPIError(f"请求失败: {method} {url} | {e}") from e

    if not resp.ok:
        try:
            err_payload = resp.json()
        except ValueError:
            # Error bodies are not guaranteed to be JSON; report the raw text.
            err_payload = resp.text
        raise TripoAPIError(
            f"HTTP {resp.status_code} | {method} {url} | {extract_error_message(err_payload)}"
        )

    try:
        # requests raises a ValueError subclass for undecodable JSON.
        payload = resp.json()
    except ValueError as e:
        raise TripoAPIError(
            f"响应不是合法 JSON: {method} {url}\n原始响应前 500 字符:\n{resp.text[:500]}"
        ) from e

    # Every caller immediately does payload.get(...); fail with a clear API
    # error instead of an AttributeError if the body is e.g. a JSON list.
    if not isinstance(payload, dict):
        raise TripoAPIError(
            f"响应 JSON 不是对象: {method} {url}\n原始响应前 500 字符:\n{resp.text[:500]}"
        )
    return payload


def create_session(api_key: str) -> requests.Session:
    """Return a Session that sends the Bearer API key and ``Accept: application/json`` on every request."""
    session = requests.Session()
    session.headers.update({
        "Authorization": f"Bearer {api_key}",
        "Accept": "application/json",
    })
    return session


def upload_image(session: requests.Session, image_path: Path, request_timeout: float) -> str:
    """Upload *image_path* as multipart field ``file`` to ``/upload``.

    Returns:
        The ``data.image_token`` value from the response.

    Raises:
        FileNotFoundError: if *image_path* is not a regular file.
        TripoAPIError: if the request fails or no ``image_token`` is returned.
    """
    # is_file() rather than exists(): a directory would otherwise pass the
    # check and fail later inside open().
    if not image_path.is_file():
        raise FileNotFoundError(f"找不到图片: {image_path}")

    with image_path.open("rb") as f:
        files = {
            "file": (image_path.name, f, guess_mime_type(image_path))
        }
        payload = request_json(
            session,
            "POST",
            "/upload",
            request_timeout=request_timeout,
            files=files,
        )

    data = payload.get("data") or {}
    file_token = data.get("image_token")
    if not file_token:
        raise TripoAPIError(f"上传成功但未返回 image_token: {json.dumps(payload, ensure_ascii=False)}")
    return file_token


# CLI options copied into the task payload only when the user supplied them.
_OPTIONAL_PAYLOAD_FIELDS = (
    "face_limit",
    "model_seed",
    "texture_seed",
    "auto_size",
    "quad",
    "compress",
    "generate_parts",
    "smart_low_poly",
)


def build_generation_payload(args, file_token: str, image_path: Path) -> Dict[str, Any]:
    """Build the JSON body for an ``image_to_model`` task from parsed CLI *args*.

    ``file.type`` is the image's lower-cased extension without the dot
    (``"png"`` if it has none).
    NOTE(review): the extension is sent verbatim (e.g. ``"jpeg"`` vs
    ``"jpg"``); confirm which spellings the API accepts.
    """
    file_ext = image_path.suffix.lower().lstrip(".") or "png"

    payload: Dict[str, Any] = {
        "type": "image_to_model",
        "model_version": args.model_version,
        "file": {
            "type": file_ext,
            "file_token": file_token,
        },
        "texture": args.texture,
        "pbr": args.pbr,
        "texture_quality": args.texture_quality,
        "texture_alignment": args.texture_alignment,
        "orientation": args.orientation,
    }

    for key in _OPTIONAL_PAYLOAD_FIELDS:
        value = getattr(args, key)
        if value is not None:
            payload[key] = value

    return payload


def create_task(session: requests.Session, payload: Dict[str, Any], request_timeout: float) -> str:
    """POST *payload* to ``/task`` and return the new ``data.task_id``.

    Raises:
        TripoAPIError: if the request fails or no ``task_id`` is returned.
    """
    resp = request_json(
        session,
        "POST",
        "/task",
        request_timeout=request_timeout,
        json=payload,
    )
    data = resp.get("data") or {}
    task_id = data.get("task_id")
    if not task_id:
        raise TripoAPIError(f"提交任务成功但未返回 task_id: {json.dumps(resp, ensure_ascii=False)}")
    return task_id


def get_task(session: requests.Session, task_id: str, request_timeout: float) -> Dict[str, Any]:
    """Fetch the current task response from ``GET /task/{task_id}``."""
    return request_json(
        session,
        "GET",
        f"/task/{task_id}",
        request_timeout=request_timeout,
    )


# Statuses after which a task will not turn into "success". "failed" is what
# this script always handled; the others are assumed from the Tripo task-status
# documentation -- TODO confirm against the current API reference.
_TERMINAL_FAILURE_STATUSES = ("failed", "cancelled", "banned", "expired")


def _finish_status_line() -> None:
    """Terminate the in-place ``\\r`` status line so later output starts on a new line."""
    sys.stdout.write("\n")
    sys.stdout.flush()


def poll_task(
    session: requests.Session,
    task_id: str,
    poll_interval: float,
    poll_timeout: float,
    request_timeout: float,
) -> Dict[str, Any]:
    """Poll the task every *poll_interval* seconds until it finishes.

    A single status line on stdout is rewritten in place (via ``\\r``)
    whenever its text changes.

    Returns:
        The full response of the poll that reported ``success``.

    Raises:
        TripoAPIError: if a poll request fails or the task reaches a status in
            ``_TERMINAL_FAILURE_STATUSES``.
        TimeoutError: if more than *poll_timeout* seconds elapse first.
    """
    start = time.perf_counter()
    last_line = ""

    while True:
        resp = get_task(session, task_id, request_timeout=request_timeout)
        data = resp.get("data") or {}

        status = str(data.get("status", "unknown")).lower()
        # `or 0`: a present-but-null "progress" would otherwise raise TypeError
        # in the `:>3` format spec below.
        progress = data.get("progress") or 0
        elapsed = time.perf_counter() - start

        line = f"\r[状态] {status:<10} | [进度] {progress:>3}% | [已等待] {elapsed:>7.1f}s"
        if line != last_line:
            sys.stdout.write(line)
            sys.stdout.flush()
            last_line = line

        if status == "success":
            _finish_status_line()
            return resp

        if status in _TERMINAL_FAILURE_STATUSES:
            _finish_status_line()
            detail = data.get("error_message") or extract_error_message(resp)
            if status != "failed":
                detail = f"status={status} | {detail}"
            raise TripoAPIError(f"任务失败 | task_id={task_id} | {detail}")

        if elapsed > poll_timeout:
            _finish_status_line()
            raise TimeoutError(f"轮询超时: 已等待 {elapsed:.1f}s,task_id={task_id}")

        time.sleep(poll_interval)


def iter_urls(obj: Any, prefix: str = "output") -> Iterator[Tuple[str, str]]:
    """Recursively yield ``(path, url)`` for every http(s) string inside *obj*.

    *path* records where the URL was found, e.g. ``output.model`` or
    ``output.parts[0]``.
    """
    if isinstance(obj, dict):
        for k, v in obj.items():
            yield from iter_urls(v, f"{prefix}.{k}")
    elif isinstance(obj, list):
        for i, v in enumerate(obj):
            yield from iter_urls(v, f"{prefix}[{i}]")
    elif isinstance(obj, str) and obj.startswith(("http://", "https://")):
        yield prefix, obj


def infer_extension_from_url(url: str) -> str:
    """Return the file suffix of the URL's path (query ignored), or ``.bin`` if it has none."""
    ext = Path(urlparse(url).path).suffix
    return ext if ext else ".bin"
def unique_path(path: Path) -> Path:
    """Return *path* if it does not exist, else the first free ``<stem>_<n><suffix>`` sibling (n = 1, 2, ...)."""
    if not path.exists():
        return path

    stem = path.stem
    suffix = path.suffix
    parent = path.parent
    i = 1
    while True:
        candidate = parent / f"{stem}_{i}{suffix}"
        if not candidate.exists():
            return candidate
        i += 1


def download_file(session: requests.Session, url: str, save_path: Path, request_timeout: float) -> None:
    """Stream *url* into *save_path* in 1 MiB chunks.

    The session's ``Authorization`` header is dropped when *url* is not on the
    API host, so the API key is not sent to whichever host serves the output
    files. If the download fails after *save_path* was opened, the partial
    file is removed.

    Raises:
        TripoAPIError: on any requests error (connection, timeout, non-2xx).
    """
    headers = None
    if urlparse(url).netloc != urlparse(BASE_URL).netloc:
        # A per-request header set to None removes the session-level header.
        headers = {"Authorization": None}

    file_opened = False
    try:
        with session.get(url, stream=True, timeout=request_timeout, headers=headers) as resp:
            resp.raise_for_status()
            with save_path.open("wb") as f:
                file_opened = True
                for chunk in resp.iter_content(chunk_size=1024 * 1024):
                    if chunk:
                        f.write(chunk)
    except requests.RequestException as e:
        if file_opened:
            try:
                save_path.unlink()
            except OSError:
                pass  # best effort; the download error below is what gets reported
        raise TripoAPIError(f"下载失败: {url} | {e}") from e


def save_outputs(
    session: requests.Session,
    task_resp: Dict[str, Any],
    out_dir: Path,
    request_timeout: float,
    save_task_json: bool = True,
    download_outputs: bool = True,
) -> None:
    """Persist a finished task into *out_dir*.

    Optionally writes the whole response to ``<task_id>.json``, then
    (optionally) downloads every http(s) URL found under ``data.output``.
    Each file is named after the URL's location in ``output``, with the
    URL's extension, and never overwrites an existing file.
    """
    out_dir.mkdir(parents=True, exist_ok=True)

    data = task_resp.get("data") or {}
    task_id = data.get("task_id", "unknown_task")
    output = data.get("output") or {}

    if save_task_json:
        meta_path = out_dir / f"{safe_filename(task_id)}.json"
        with meta_path.open("w", encoding="utf-8") as f:
            json.dump(task_resp, f, ensure_ascii=False, indent=2)

    if not output:
        print("⚠️ 任务成功,但 output 为空。")
        return

    if not download_outputs:
        print("ℹ️ 已跳过下载,仅保存任务响应。")
        return

    url_items = list(iter_urls(output))
    if not url_items:
        print("⚠️ output 中没有找到可下载 URL。")
        return

    print("\n📥 开始下载输出文件...")
    key_prefix = "output."
    for logical_key, url in url_items:
        # Strip only the leading "output." (str.replace would also remove
        # later occurrences inside nested key names).
        if logical_key.startswith(key_prefix):
            short_key = logical_key[len(key_prefix):]
        else:
            short_key = logical_key
        ext = infer_extension_from_url(url)
        filename = safe_filename(short_key) + ext
        save_path = unique_path(out_dir / filename)
        print(f" - {short_key} -> {save_path}")
        download_file(session, url, save_path, request_timeout=request_timeout)


def main():
    """Run the full CLI flow: upload -> submit task -> poll -> save/download outputs."""
    args = build_parser().parse_args()

    if not args.api_key:
        raise ValueError("请提供 --api_key 或设置环境变量 TRIPO_API_KEY")

    image_path = Path(args.image)
    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    session = create_session(args.api_key)

    print(f"🚀 启动测试 | 模型: {args.model_version}")
    print(f"🖼️ 输入图片: {image_path}")
    print(f"📁 输出目录: {out_dir.resolve()}")

    start_wall_time = time.perf_counter()

    # 1) Upload the image
    print("\n[1/4] 上传图片...")
    upload_start = time.perf_counter()
    file_token = upload_image(session, image_path, request_timeout=args.request_timeout)
    upload_end = time.perf_counter()
    print(f"✅ 上传完成 | file_token: {file_token}")
    print(f"⏱️ 上传耗时: {upload_end - upload_start:.2f}s")

    # 2) Submit the generation task
    print("\n[2/4] 提交 image_to_model 任务...")
    payload = build_generation_payload(args, file_token, image_path)
    if args.print_payload:
        print(json.dumps(payload, ensure_ascii=False, indent=2))

    task_id = create_task(session, payload, request_timeout=args.request_timeout)
    print(f"✅ 任务提交成功 | task_id: {task_id}")

    # 3) Poll until the task finishes
    print("\n[3/4] 轮询任务状态...")
    gen_start = time.perf_counter()
    task_resp = poll_task(
        session,
        task_id,
        poll_interval=args.poll_interval,
        poll_timeout=args.poll_timeout,
        request_timeout=args.request_timeout,
    )
    gen_end = time.perf_counter()

    data = task_resp.get("data") or {}
    output = data.get("output") or {}
    # Measured before step 4, so the "total" figure excludes downloads.
    total_end = time.perf_counter()

    print("\n🎉 生成成功")
    print("=" * 60)
    print(f"task_id : {task_id}")
    print(f"纯生成耗时 : {gen_end - gen_start:.2f}s")
    print(f"总流程耗时 : {total_end - start_wall_time:.2f}s")
    print(f"最终 status : {data.get('status')}")
    print(f"output keys : {list(output.keys()) if isinstance(output, dict) else type(output)}")
    print("=" * 60)

    if args.print_output:
        print(json.dumps(output, ensure_ascii=False, indent=2))

    # 4) Save the task response and download outputs
    print("\n[4/4] 保存结果...")
    save_outputs(
        session,
        task_resp,
        out_dir=out_dir,
        request_timeout=args.request_timeout,
        save_task_json=args.save_task_json,
        download_outputs=args.download_outputs,
    )

    print("\n✅ 全部完成。")


if __name__ == "__main__":
    try:
        main()
    except KeyboardInterrupt:
        # Ctrl-C (e.g. during a long poll): no traceback; 130 = 128 + SIGINT.
        print("\n❌ 已中断", file=sys.stderr)
        sys.exit(130)
    except Exception as e:
        # Top-level boundary: report the failure on stderr and exit non-zero.
        print(f"\n❌ 程序终止: {e}", file=sys.stderr)
        sys.exit(1)