Files
FiDA_Python/src/server/canvas_generate_3D/triop3d_api.py
2026-04-13 12:11:34 +08:00

475 lines
14 KiB
Python
Raw Blame History

This file contains invisible Unicode characters
This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import os
import re
import sys
import json
import time
import argparse
import mimetypes
from pathlib import Path
from typing import Any, Dict, Iterator, Tuple
from urllib.parse import urlparse
import requests
# Root of the Tripo3D OpenAPI v2; endpoint paths such as "/upload" and
# "/task" are appended to it verbatim.
BASE_URL = "https://api.tripo3d.ai/v2/openapi"


class TripoAPIError(RuntimeError):
    """Raised when a Tripo3D request, response, download, or task fails."""
def build_parser():
    """Build the command-line parser for the single-image -> 3D workflow.

    Returns:
        argparse.ArgumentParser: parser whose namespace drives the upload,
        task payload, polling and output-saving steps.

    Notes:
        * The API key comes from ``--api_key`` or the ``TRIPO_API_KEY``
          environment variable. There is deliberately no hard-coded fallback:
          a key embedded in source control is a leaked credential (any key
          that was ever committed should be rotated). When neither is given
          the default is ``None``.
        * ``--auto_size``, ``--quad``, ``--generate_parts`` and
          ``--smart_low_poly`` accept textual booleans (true/false, 1/0,
          yes/no, on/off) and are converted to real ``bool`` values, so the
          JSON payload carries booleans instead of strings.
    """

    def _str_to_bool(text: str) -> bool:
        # argparse `type=` converter; an ArgumentTypeError is reported by
        # argparse as a regular usage error with this message.
        lowered = text.strip().lower()
        if lowered in ("1", "true", "yes", "y", "on"):
            return True
        if lowered in ("0", "false", "no", "n", "off"):
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean (true/false), got {text!r}")

    # The first positional parameter of ArgumentParser is `prog`; the text is
    # a description, so pass it by keyword.
    p = argparse.ArgumentParser(description="Tripo3D CLI: single image -> 3D")

    # I/O
    p.add_argument("-i", "--image", required=True, help="Input image path")
    p.add_argument("-o", "--out_dir", default="tripo_outputs", help="Output directory")

    # Auth: only the environment is consulted; never ship a real key as default.
    p.add_argument(
        "--api_key",
        default=os.getenv("TRIPO_API_KEY"),
        help="Tripo API key, or set env TRIPO_API_KEY",
    )

    # Model
    p.add_argument(
        "--model_version",
        type=str,
        default="v3.1-20260211",
        help="Model version, e.g. P1-20260311 / v3.1-20260211 / v3.0-20250812 / v2.5-20250123",
    )

    # Network / polling
    p.add_argument("--poll_interval", type=float, default=2.0, help="Polling interval (seconds)")
    p.add_argument("--poll_timeout", type=float, default=1800.0, help="Max polling time (seconds)")
    p.add_argument("--request_timeout", type=float, default=120.0, help="HTTP request timeout (seconds)")

    # Generation options (paired --x / --no-x switches; both default to True)
    p.add_argument("--texture", dest="texture", action="store_true", default=True)
    p.add_argument("--no-texture", dest="texture", action="store_false")
    p.add_argument("--pbr", dest="pbr", action="store_true", default=True)
    p.add_argument("--no-pbr", dest="pbr", action="store_false")
    p.add_argument(
        "--texture_quality",
        type=str,
        default="standard",
        choices=["standard", "detailed"],
        help="Texture quality",
    )
    p.add_argument(
        "--texture_alignment",
        type=str,
        default="original_image",
        choices=["original_image", "geometry"],
        help="Texture alignment mode",
    )
    p.add_argument(
        "--orientation",
        type=str,
        default="default",
        choices=["default", "align_image"],
        help="Orientation mode",
    )

    # Optional params: a value of None means "not given" and the field is
    # left out of the task payload.
    p.add_argument("--face_limit", type=int, default=None)
    p.add_argument("--model_seed", type=int, default=None)
    p.add_argument("--texture_seed", type=int, default=None)
    p.add_argument("--auto_size", type=_str_to_bool, default=None, help="true/false")
    p.add_argument("--quad", type=_str_to_bool, default=None, help="true/false")
    p.add_argument("--compress", type=str, default=None)
    p.add_argument("--generate_parts", type=_str_to_bool, default=None, help="true/false")
    p.add_argument("--smart_low_poly", type=_str_to_bool, default=None, help="true/false")

    # Save / download toggles
    p.add_argument("--download_outputs", dest="download_outputs", action="store_true", default=True)
    p.add_argument("--no-download_outputs", dest="download_outputs", action="store_false")
    p.add_argument("--save_task_json", dest="save_task_json", action="store_true", default=True)
    p.add_argument("--no-save_task_json", dest="save_task_json", action="store_false")
    p.add_argument("--print_payload", dest="print_payload", action="store_true", default=False)
    p.add_argument("--print_output", dest="print_output", action="store_true", default=True)
    p.add_argument("--no-print_output", dest="print_output", action="store_false")
    return p
def guess_mime_type(file_path: Path) -> str:
    """Return the MIME type implied by *file_path*'s extension.

    Falls back to ``application/octet-stream`` when ``mimetypes`` cannot
    determine a type.
    """
    guessed = mimetypes.guess_type(str(file_path))[0]
    if guessed:
        return guessed
    return "application/octet-stream"
def safe_filename(name: str) -> str:
    """Sanitize *name* so it can be used as a single file name.

    * each run of characters reserved on Windows (backslash, slash and
      ``: * ? " < > |``) becomes one underscore;
    * each run of whitespace becomes one underscore;
    * leading/trailing dots and underscores are stripped.

    Returns ``"file"`` when nothing is left.
    """
    substitutions = (
        (r'[\\/:*?"<>|]+', "_"),  # Windows-reserved characters
        (r"\s+", "_"),  # whitespace
    )
    cleaned = name
    for pattern, replacement in substitutions:
        cleaned = re.sub(pattern, replacement, cleaned)
    cleaned = cleaned.strip("._")
    return cleaned if cleaned else "file"
def extract_error_message(payload: Any) -> str:
    """Pull a human-readable error message out of an API error payload.

    For a dict, return the first truthy value among ``message``, ``error``,
    ``error_message`` and ``detail``. The top level is searched first, then
    a nested ``data`` dict if there is one. If nothing matches, the whole
    dict is JSON-encoded. A non-dict payload is converted with ``str``. Both
    fallbacks are truncated to 800 characters.
    """
    if not isinstance(payload, dict):
        return str(payload)[:800]

    candidate_keys = ("message", "error", "error_message", "detail")
    scopes = [payload]
    nested = payload.get("data")
    if isinstance(nested, dict):
        scopes.append(nested)

    for scope in scopes:
        for key in candidate_keys:
            value = scope.get(key)
            if value:
                return str(value)
    return json.dumps(payload, ensure_ascii=False)[:800]
def request_json(
    session: requests.Session,
    method: str,
    endpoint: str,
    request_timeout: float,
    **kwargs,
) -> Dict[str, Any]:
    """Call a Tripo3D endpoint and return its decoded JSON object.

    Args:
        session: session used to send the request (carries auth headers).
        method: HTTP method, e.g. ``"GET"`` or ``"POST"``.
        endpoint: path appended verbatim to ``BASE_URL`` (callers pass
            values such as ``"/upload"`` or ``"/task"``).
        request_timeout: per-request timeout in seconds.
        **kwargs: forwarded to ``session.request`` (``json=``, ``files=`` ...).

    Returns:
        The response body decoded as a JSON object (dict).

    Raises:
        TripoAPIError: on transport errors, a non-2xx status, a body that is
            not valid JSON, or a JSON body that is not an object.
    """
    url = f"{BASE_URL}{endpoint}"
    try:
        resp = session.request(method=method, url=url, timeout=request_timeout, **kwargs)
    except requests.RequestException as e:
        raise TripoAPIError(f"请求失败: {method} {url} | {e}") from e

    if not resp.ok:
        try:
            err_payload = resp.json()
        except ValueError:
            # Error bodies are not always JSON (e.g. proxy/gateway HTML pages).
            err_payload = resp.text
        raise TripoAPIError(
            f"HTTP {resp.status_code} | {method} {url} | {extract_error_message(err_payload)}"
        )

    try:
        payload = resp.json()
    except ValueError as e:
        # requests' JSON decode errors subclass ValueError in all versions;
        # catching only that keeps genuine bugs from being masked.
        raise TripoAPIError(
            f"响应不是合法 JSON: {method} {url}\n原始响应前 500 字符:\n{resp.text[:500]}"
        ) from e

    if not isinstance(payload, dict):
        # Callers immediately do payload.get("data"); fail here with a clear
        # message instead of an AttributeError further down.
        raise TripoAPIError(
            f"响应不是 JSON 对象: {method} {url}\n原始响应前 500 字符:\n{resp.text[:500]}"
        )
    return payload
def create_session(api_key: str) -> requests.Session:
    """Create a ``requests.Session`` preconfigured for the Tripo3D API.

    The session's default headers carry a Bearer token built from *api_key*
    and ask for a JSON response.
    """
    default_headers = {
        "Authorization": f"Bearer {api_key}",
        "Accept": "application/json",
    }
    http = requests.Session()
    http.headers.update(default_headers)
    return http
def upload_image(session: requests.Session, image_path: Path, request_timeout: float) -> str:
    """Upload a local image via ``POST /upload`` and return its ``image_token``.

    Args:
        session: API session (see ``create_session``).
        image_path: image file to upload as multipart field ``file``.
        request_timeout: per-request timeout in seconds.

    Raises:
        FileNotFoundError: *image_path* does not exist.
        TripoAPIError: the request fails or the response lacks ``data.image_token``.
    """
    if not image_path.exists():
        raise FileNotFoundError(f"找不到图片: {image_path}")

    mime_type = guess_mime_type(image_path)
    with image_path.open("rb") as fh:
        # multipart/form-data part: (filename, open stream, content type)
        upload_resp = request_json(
            session,
            "POST",
            "/upload",
            request_timeout=request_timeout,
            files={"file": (image_path.name, fh, mime_type)},
        )

    token = (upload_resp.get("data") or {}).get("image_token")
    if token:
        return token
    raise TripoAPIError(f"上传成功但未返回 image_token: {json.dumps(upload_resp, ensure_ascii=False)}")
def build_generation_payload(args, file_token: str, image_path: Path) -> Dict[str, Any]:
    """Assemble the JSON body of an ``image_to_model`` task.

    Args:
        args: parsed CLI namespace providing the generation options.
        file_token: token returned by the upload step.
        image_path: local image; its extension (lower-cased, without the dot)
            is sent as ``file.type``, defaulting to ``"png"`` when absent.

    Returns:
        Payload dict. Fields in the optional list whose value is ``None`` are
        left out of the payload; every other field is always present.
    """
    # NOTE(review): Tripo's docs suggest pbr=True forces texturing on, so
    # `--no-texture` alone may not disable textures - confirm with the API.
    extension = image_path.suffix.lower().lstrip(".")
    body: Dict[str, Any] = {
        "type": "image_to_model",
        "model_version": args.model_version,
        "file": {"type": extension or "png", "file_token": file_token},
    }
    for flag in ("texture", "pbr", "texture_quality", "texture_alignment", "orientation"):
        body[flag] = getattr(args, flag)

    optional_names = (
        "face_limit",
        "model_seed",
        "texture_seed",
        "auto_size",
        "quad",
        "compress",
        "generate_parts",
        "smart_low_poly",
    )
    present = {name: getattr(args, name) for name in optional_names}
    body.update({name: value for name, value in present.items() if value is not None})
    return body
def create_task(session: requests.Session, payload: Dict[str, Any], request_timeout: float) -> str:
    """Submit *payload* via ``POST /task`` and return the reported ``task_id``.

    Raises:
        TripoAPIError: the request fails or the response lacks ``data.task_id``.
    """
    submitted = request_json(session, "POST", "/task", request_timeout=request_timeout, json=payload)
    new_task_id = (submitted.get("data") or {}).get("task_id")
    if not new_task_id:
        raise TripoAPIError(f"提交任务成功但未返回 task_id: {json.dumps(submitted, ensure_ascii=False)}")
    return new_task_id
def get_task(session: requests.Session, task_id: str, request_timeout: float) -> Dict[str, Any]:
    """Fetch the raw API response for task *task_id* via ``GET /task/{task_id}``."""
    task_endpoint = f"/task/{task_id}"
    return request_json(session, "GET", task_endpoint, request_timeout=request_timeout)
def poll_task(
    session: requests.Session,
    task_id: str,
    poll_interval: float,
    poll_timeout: float,
    request_timeout: float,
) -> Dict[str, Any]:
    """Poll task *task_id* until it reaches a terminal state.

    A single status line (status, progress, elapsed time) is redrawn in place
    with a carriage return whenever its text changes.

    Args:
        session: API session.
        task_id: id returned by task creation.
        poll_interval: seconds to sleep between polls.
        poll_timeout: give up once this many seconds have elapsed.
        request_timeout: per-request timeout in seconds.

    Returns:
        The full API response of the first poll whose status is ``success``.

    Raises:
        TripoAPIError: the task ends in a failure state (``failed``,
            ``cancelled``, ``banned`` or ``expired``), or a request fails.
        TimeoutError: no terminal state was reached within *poll_timeout*.
    """
    # Terminal states other than "success" (per Tripo's task status list).
    # Previously only "failed" was handled, so a cancelled/banned/expired task
    # was polled until poll_timeout (the CLI default is 1800 s).
    failure_statuses = {"failed", "cancelled", "banned", "expired"}
    start = time.perf_counter()
    last_line = ""

    def _end_status_line() -> None:
        # Terminate the in-place status line so later output starts fresh.
        sys.stdout.write("\n")
        sys.stdout.flush()

    while True:
        resp = get_task(session, task_id, request_timeout=request_timeout)
        data = resp.get("data") or {}
        status = str(data.get("status", "unknown")).lower()
        # The API may send "progress": null; format(None, ">3") raises TypeError.
        progress = data.get("progress") or 0
        elapsed = time.perf_counter() - start

        line = f"\r[状态] {status:<10} | [进度] {progress:>3}% | [已等待] {elapsed:>7.1f}s"
        if line != last_line:
            sys.stdout.write(line)
            sys.stdout.flush()
            last_line = line

        if status == "success":
            _end_status_line()
            return resp
        if status in failure_statuses:
            _end_status_line()
            error_message = data.get("error_message") or extract_error_message(resp)
            raise TripoAPIError(f"任务失败 | task_id={task_id} | status={status} | {error_message}")
        if elapsed > poll_timeout:
            _end_status_line()
            raise TimeoutError(f"轮询超时: 已等待 {elapsed:.1f}s | task_id={task_id}")
        time.sleep(poll_interval)
def iter_urls(obj: Any, prefix: str = "output") -> Iterator[Tuple[str, str]]:
    """Yield ``(key_path, url)`` for every http(s) string nested in *obj*.

    Dicts and lists are walked depth-first in iteration order. Dict keys
    extend the path as ``.key`` and list positions as ``[i]``, starting from
    *prefix*. Any other value type is ignored.
    """
    if isinstance(obj, str):
        if obj.startswith(("http://", "https://")):
            yield prefix, obj
        return

    if isinstance(obj, dict):
        children = ((f"{prefix}.{key}", value) for key, value in obj.items())
    elif isinstance(obj, list):
        children = ((f"{prefix}[{idx}]", item) for idx, item in enumerate(obj))
    else:
        return

    for child_prefix, child in children:
        yield from iter_urls(child, child_prefix)
def infer_extension_from_url(url: str) -> str:
    """Return the extension (with leading dot) of the path part of *url*.

    Query string and fragment are ignored. ``".bin"`` is returned when the
    path has no extension.
    """
    suffix = Path(urlparse(url).path).suffix
    if not suffix:
        return ".bin"
    return suffix
def unique_path(path: Path) -> Path:
    """Return *path* if it is free, else the first free ``<stem>_<n><suffix>``.

    The counter starts at 1 and the candidate stays in the same directory.
    Only existence is checked: the returned path is not reserved, so another
    writer could still create it before the caller does.
    """
    candidate = path
    counter = 0
    while candidate.exists():
        counter += 1
        candidate = path.parent / f"{path.stem}_{counter}{path.suffix}"
    return candidate
def download_file(session: requests.Session, url: str, save_path: Path, request_timeout: float) -> None:
    """Stream *url* to *save_path* in 1 MiB chunks.

    The body is written to a sibling ``<name>.part`` file, which is renamed
    onto *save_path* only after the whole response has been received. A
    failed download therefore never leaves a truncated file under the final
    name.

    The session's ``Authorization`` header (the Tripo API key) is dropped
    when *url* is not on the API host. Output URLs presumably point at a
    storage/CDN host (TODO confirm), and the API key must not be sent there.

    Raises:
        TripoAPIError: on any HTTP or transport error.
    """
    api_host = urlparse(BASE_URL).netloc
    # requests omits a session-level header when the per-request value is None.
    headers = None if urlparse(url).netloc == api_host else {"Authorization": None}
    part_path = save_path.with_name(save_path.name + ".part")
    try:
        with session.get(url, stream=True, timeout=request_timeout, headers=headers) as resp:
            resp.raise_for_status()
            with part_path.open("wb") as f:
                for chunk in resp.iter_content(chunk_size=1024 * 1024):
                    if chunk:  # skip keep-alive empty chunks
                        f.write(chunk)
        part_path.replace(save_path)
    except requests.RequestException as e:
        raise TripoAPIError(f"下载失败: {url} | {e}") from e
    finally:
        # Remove the leftovers of an interrupted download; after a successful
        # rename the .part file no longer exists and this is a no-op.
        if part_path.exists():
            part_path.unlink()
def save_outputs(
    session: requests.Session,
    task_resp: Dict[str, Any],
    out_dir: Path,
    request_timeout: float,
    save_task_json: bool = True,
    download_outputs: bool = True,
) -> None:
    """Persist a finished task's response and, optionally, its output files.

    Steps, in order:
    1. create *out_dir* (and parents) if needed;
    2. if *save_task_json*, dump *task_resp* to ``<task_id>.json``;
    3. stop, with a console message, if ``data.output`` is empty, downloads
       are disabled, or ``data.output`` contains no http(s) URL;
    4. otherwise download each URL to a file named after its key path plus
       the URL's extension, without overwriting existing files.
    """
    out_dir.mkdir(parents=True, exist_ok=True)
    task_data = task_resp.get("data") or {}
    task_id = task_data.get("task_id", "unknown_task")
    output = task_data.get("output") or {}

    if save_task_json:
        json_path = out_dir / f"{safe_filename(task_id)}.json"
        with json_path.open("w", encoding="utf-8") as fp:
            json.dump(task_resp, fp, ensure_ascii=False, indent=2)

    if not output:
        print("⚠️ 任务成功,但 output 为空。")
        return
    if not download_outputs:
        print(" 已跳过下载,仅保存任务响应。")
        return

    downloads = list(iter_urls(output))
    if not downloads:
        print("⚠️ output 中没有找到可下载 URL。")
        return

    print("\n📥 开始下载输出文件...")
    for key_path, url in downloads:
        # e.g. "output.pbr_model" -> "pbr_model"
        label = key_path.replace("output.", "")
        target = unique_path(out_dir / (safe_filename(label) + infer_extension_from_url(url)))
        print(f" - {label} -> {target}")
        download_file(session, url, target, request_timeout=request_timeout)
def main():
    """CLI entry point: upload an image, run image_to_model, save the results.

    Four stages are announced on stdout:
    [1/4] upload, [2/4] submit task, [3/4] poll until done,
    [4/4] save task JSON / download outputs.

    Raises:
        ValueError: no API key was supplied.
        Any exception raised by the individual stages propagates unchanged.
    """
    args = build_parser().parse_args()
    if not args.api_key:
        raise ValueError("请提供 --api_key 或设置环境变量 TRIPO_API_KEY")

    image_path = Path(args.image)
    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    session = create_session(args.api_key)

    print(f"🚀 启动测试 | 模型: {args.model_version}")
    print(f"🖼️ 输入图片: {image_path}")
    print(f"📁 输出目录: {out_dir.resolve()}")
    t_begin = time.perf_counter()

    # Stage 1: upload the source image and obtain its file token.
    print("\n[1/4] 上传图片...")
    t_upload = time.perf_counter()
    file_token = upload_image(session, image_path, request_timeout=args.request_timeout)
    upload_secs = time.perf_counter() - t_upload
    print(f"✅ 上传完成 | file_token: {file_token}")
    print(f"⏱️ 上传耗时: {upload_secs:.2f}s")

    # Stage 2: build and submit the image_to_model task.
    print("\n[2/4] 提交 image_to_model 任务...")
    payload = build_generation_payload(args, file_token, image_path)
    if args.print_payload:
        print(json.dumps(payload, ensure_ascii=False, indent=2))
    task_id = create_task(session, payload, request_timeout=args.request_timeout)
    print(f"✅ 任务提交成功 | task_id: {task_id}")

    # Stage 3: wait for the task to reach a terminal state.
    print("\n[3/4] 轮询任务状态...")
    t_generate = time.perf_counter()
    task_resp = poll_task(
        session,
        task_id,
        poll_interval=args.poll_interval,
        poll_timeout=args.poll_timeout,
        request_timeout=args.request_timeout,
    )
    generate_secs = time.perf_counter() - t_generate
    data = task_resp.get("data") or {}
    output = data.get("output") or {}
    total_secs = time.perf_counter() - t_begin

    if isinstance(output, dict):
        output_keys = list(output.keys())
    else:
        output_keys = type(output)
    separator = "=" * 60

    print("\n🎉 生成成功")
    print(separator)
    print(f"task_id : {task_id}")
    print(f"纯生成耗时 : {generate_secs:.2f}s")
    print(f"总流程耗时 : {total_secs:.2f}s")
    print(f"最终 status : {data.get('status')}")
    print(f"output keys : {output_keys}")
    print(separator)
    if args.print_output:
        print(json.dumps(output, ensure_ascii=False, indent=2))

    # Stage 4: persist the task response and download output files.
    print("\n[4/4] 保存结果...")
    save_outputs(
        session,
        task_resp,
        out_dir=out_dir,
        request_timeout=args.request_timeout,
        save_task_json=args.save_task_json,
        download_outputs=args.download_outputs,
    )
    print("\n✅ 全部完成。")
if __name__ == "__main__":
    exit_code = 0
    try:
        main()
    except Exception as exc:  # top-level boundary: report briefly, exit non-zero
        print(f"\n❌ 程序终止: {exc}")
        exit_code = 1
    if exit_code:
        sys.exit(exit_code)