# (Code-viewer header from the original paste: 117 lines, 4.0 KiB, Python.)
import asyncio
import mimetypes
import os
from typing import List, Optional

from google.cloud import aiplatform
from google.cloud import storage
from google.oauth2 import service_account
from vertexai.generative_models import Part

# Upload destination: replace with your own bucket name and target "folder"
# (an object-name prefix inside the bucket).
GCS_BUCKET_NAME = "aida-test-vertex-ai-bucket"
GCS_BLOB_FOLDER = "user_456/"

# Service-account key file; MyVertexAIClient loads it to authenticate its GCS client.
KEY_FILE_PATH = "/workspace/lc_stylist_agent/request.json"

# Initialise the Vertex AI SDK for this project/region as soon as the module is imported.
# NOTE(review): no `credentials=` argument is passed, so KEY_FILE_PATH is NOT used here;
# per the aiplatform.init docs the SDK resolves credentials from the environment instead.
_VERTEX_PROJECT = 'aida-461108'
_VERTEX_LOCATION = 'us-central1'
aiplatform.init(project=_VERTEX_PROJECT, location=_VERTEX_LOCATION)

class MyVertexAIClient:
    """Upload local files to Google Cloud Storage and wrap them as Vertex AI ``Part`` objects.

    The client authenticates with the service-account key at ``KEY_FILE_PATH``.
    The GCS client is created with those credentials and with the key's project id.
    """

    def __init__(self):
        """Load the service-account credentials and create the GCS client.

        :raises RuntimeError: if the key file at ``KEY_FILE_PATH`` cannot be loaded.
            The original exception is chained as ``__cause__``.
        """
        try:
            self.credentials = service_account.Credentials.from_service_account_file(
                KEY_FILE_PATH
            )
        except Exception as e:
            # Adjust this error handling to suit the actual deployment.
            raise RuntimeError(
                f"Failed to load credentials from file {KEY_FILE_PATH}: {e}"
            ) from e

        # One GCS client per instance, reused by every upload/cleanup call.
        self.gcs_client = storage.Client(
            project=self.credentials.project_id,
            credentials=self.credentials,
        )
        # gs:// URIs uploaded by this instance, kept so they can be cleaned up later.
        self.uploaded_gcs_uris: List[str] = []

    # --- Helper: blocking GCS upload (run it off the event loop) ---
    def _upload_to_gcs_sync(self, bucket_name: str, blob_name: str, local_file_path: str) -> str:
        """Upload ``local_file_path`` to ``gs://{bucket_name}/{blob_name}``. This call blocks.

        :param bucket_name: destination bucket.
        :param blob_name: destination object name inside the bucket.
        :param local_file_path: path of the local file to upload.
        :return: the ``gs://`` URI of the uploaded object.
        """
        bucket = self.gcs_client.bucket(bucket_name)
        blob = bucket.blob(blob_name)
        blob.upload_from_filename(local_file_path)
        return f"gs://{bucket_name}/{blob_name}"

    # --- Replacement for gemini_client.aio.files.upload ---
    async def upload_file_to_vertex_part(self, merged_image_path: str) -> Part:
        """Upload a local file to GCS and return a Vertex AI ``Part`` that references it.

        :param merged_image_path: path of the local file to upload.
        :return: a ``Part`` built from the object's ``gs://`` URI and the file's guessed MIME type.
        """
        # 1. Work out the GCS object name. Only the file's base name is used (no unique id),
        #    so two files with the same base name map to the same object name.
        file_name = os.path.basename(merged_image_path)
        blob_name = GCS_BLOB_FOLDER + file_name

        # 2. The GCS client is blocking, so run the upload in the default thread pool
        #    to keep the event loop responsive. get_running_loop() is the coroutine-safe API.
        loop = asyncio.get_running_loop()
        gcs_uri = await loop.run_in_executor(
            None,
            self._upload_to_gcs_sync,
            GCS_BUCKET_NAME,
            blob_name,
            merged_image_path,
        )

        # 3. Remember the URI for later cleanup.
        self.uploaded_gcs_uris.append(gcs_uri)

        # 4. Guess the MIME type from the file extension, with a generic fallback.
        mime_type, _ = mimetypes.guess_type(merged_image_path)
        if not mime_type:
            mime_type = 'application/octet-stream'

        # 5. Build a Part that references the uploaded object by URI.
        return Part.from_uri(uri=gcs_uri, mime_type=mime_type)

    def _list_clear_object(self, bucket_name: str, prefix: str, delimiter: Optional[str] = None):
        """Delete every object in ``bucket_name`` whose name starts with ``prefix``. This call blocks.

        GCS prefixes are plain string prefixes. ``"user_456"`` also matches ``"user_4567/..."``,
        so include the trailing ``/`` to restrict deletion to a single folder.

        :param bucket_name: bucket to clean.
        :param prefix: object-name prefix to match.
        :param delimiter: forwarded to ``list_blobs``. With ``"/"``, objects in nested
            "sub-folders" are not listed and therefore not deleted.
        """
        blobs = self.gcs_client.list_blobs(
            bucket_name, prefix=prefix, delimiter=delimiter
        )
        deleted_uris = set()
        for blob in blobs:
            blob.delete()
            print(f" ✅ 已删除: {blob.name}")
            deleted_uris.add(f"gs://{bucket_name}/{blob.name}")

        # Keep the upload record consistent: drop URIs whose objects no longer exist.
        if deleted_uris:
            self.uploaded_gcs_uris[:] = [
                uri for uri in self.uploaded_gcs_uris if uri not in deleted_uris
            ]

# --- Migrating the original call site ---
# Before (Gemini client file upload):
#     myfile = await self.gemini_client.aio.files.upload(file=merged_image_path)
#     content_parts.append(myfile)
# After (assuming self.vertex_client is a MyVertexAIClient instance):
#     vertex_part = await self.vertex_client.upload_file_to_vertex_part(merged_image_path)
#     content_parts.append(vertex_part)

# Sample local file used by the demo below.
merged_image_path: str = "/workspace/lc_stylist_agent/app/core/data/outfit_output/35e8626c-943f-4a3f-a0cc-1280d8bcf84d.jpg"

async def run():
    """Demo entry point: delete every object under ``GCS_BLOB_FOLDER`` in ``GCS_BUCKET_NAME``.

    To upload a file instead, and get a ``Part`` to append to ``content_parts``:
        vertex_part = await vertex_ai_server.upload_file_to_vertex_part(merged_image_path)
    """
    vertex_ai_server = MyVertexAIClient()

    # Use GCS_BLOB_FOLDER (with its trailing "/") as the prefix. A bare "user_456" would
    # also match, and delete, objects belonging to e.g. "user_4567/...".
    # The GCS client is blocking, so the cleanup runs in the default executor.
    loop = asyncio.get_running_loop()
    await loop.run_in_executor(
        None,
        vertex_ai_server._list_clear_object,
        GCS_BUCKET_NAME,
        GCS_BLOB_FOLDER,
    )


if __name__ == "__main__":
    # Guarded so that merely importing this module does not delete objects from the bucket.
    asyncio.run(run())