117 lines
4.0 KiB
Python
117 lines
4.0 KiB
Python
import asyncio
import mimetypes
import os
from typing import List, Optional

from google.cloud import aiplatform
from google.cloud import storage
from google.oauth2 import service_account
from vertexai.generative_models import Part

# GCS bucket and destination "folder" (object-name prefix) for uploads; replace with your own.
GCS_BUCKET_NAME = "aida-test-vertex-ai-bucket"
GCS_BLOB_FOLDER = "user_456/"
# Service-account key file that MyVertexAIClient loads to authenticate its GCS client.
KEY_FILE_PATH = "/workspace/lc_stylist_agent/request.json"

# Vertex AI project and region passed to aiplatform.init below.
GCP_PROJECT_ID = "aida-461108"
GCP_LOCATION = "us-central1"

# Initialise the Vertex AI SDK for this project/region (runs at import time).
# NOTE(review): no ``credentials`` argument is passed, so the SDK resolves credentials from
# the environment rather than from KEY_FILE_PATH, while the GCS client in MyVertexAIClient
# authenticates with the key file. Confirm this split is intended.
aiplatform.init(
    project=GCP_PROJECT_ID,
    location=GCP_LOCATION,
)


class MyVertexAIClient:
    """Upload local files to Google Cloud Storage and wrap them as Vertex AI ``Part`` objects.

    The GCS client is authenticated with the service-account key at ``KEY_FILE_PATH``.
    Every ``gs://`` URI produced by :meth:`upload_file_to_vertex_part` is recorded in
    ``uploaded_gcs_uris`` so the caller can clean the objects up later.
    """

    def __init__(self):
        """Load service-account credentials and build an authenticated GCS client.

        :raises RuntimeError: if the key file at ``KEY_FILE_PATH`` cannot be loaded;
            the underlying exception is chained as ``__cause__``.
        """
        try:
            self.credentials = service_account.Credentials.from_service_account_file(KEY_FILE_PATH)
        except Exception as e:
            # Broad catch is deliberate: any failure (missing file, malformed JSON, bad key)
            # is surfaced uniformly as RuntimeError. Chain the cause so its traceback survives.
            raise RuntimeError(f"Failed to load credentials from file {KEY_FILE_PATH}: {e}") from e

        # The storage client uses the service account's own project and credentials.
        self.gcs_client = storage.Client(
            project=self.credentials.project_id,
            credentials=self.credentials,
        )
        # gs:// URIs uploaded by this instance, kept for later cleanup.
        self.uploaded_gcs_uris: List[str] = []

    # --- Helper: blocking GCS upload (run in a worker thread by the async method) ---

    def _upload_to_gcs_sync(self, bucket_name: str, blob_name: str, local_file_path: str) -> str:
        """Upload ``local_file_path`` to ``bucket_name/blob_name`` and return its ``gs://`` URI.

        This call blocks on network I/O; async callers should run it in an executor.
        """
        blob = self.gcs_client.bucket(bucket_name).blob(blob_name)
        blob.upload_from_filename(local_file_path)
        return f"gs://{bucket_name}/{blob_name}"

    # --- Replacement for gemini_client.aio.files.upload ---

    async def upload_file_to_vertex_part(self, merged_image_path: str) -> Part:
        """Upload a local file to GCS without blocking the event loop and return a Vertex AI ``Part``.

        :param merged_image_path: path of the local file to upload.
        :return: a ``Part`` that references the uploaded object by its ``gs://`` URI.
        """
        # 1. Destination object name. NOTE: only the base file name is used, so two local
        # files with the same name overwrite each other in GCS; add a unique component
        # (e.g. a UUID) if that matters.
        file_name = os.path.basename(merged_image_path)
        blob_name = GCS_BLOB_FOLDER + file_name

        # 2. The GCS client is synchronous, so run the upload in the default thread pool.
        # get_running_loop() is the preferred way to obtain the loop inside a coroutine.
        loop = asyncio.get_running_loop()
        gcs_uri = await loop.run_in_executor(
            None,  # default ThreadPoolExecutor
            self._upload_to_gcs_sync,
            GCS_BUCKET_NAME,
            blob_name,
            merged_image_path,
        )

        # 3. Remember the URI so the object can be cleaned up later.
        self.uploaded_gcs_uris.append(gcs_uri)

        # 4. Infer the MIME type from the file extension; fall back to a generic binary type.
        mime_type, _ = mimetypes.guess_type(merged_image_path)
        if not mime_type:
            mime_type = "application/octet-stream"

        # 5. Reference the uploaded object by URI instead of inlining its bytes.
        return Part.from_uri(uri=gcs_uri, mime_type=mime_type)

    def _list_clear_object(self, bucket_name: str, prefix: str, delimiter: Optional[str] = None) -> List[str]:
        """Delete every object in ``bucket_name`` whose name starts with ``prefix``.

        :param bucket_name: bucket to clean.
        :param prefix: object-name prefix to match. This is a plain string prefix, so
            ``"user_456"`` also matches ``"user_4567/..."``; include the trailing ``/``
            to restrict deletion to one folder. An empty prefix matches the whole bucket.
        :param delimiter: forwarded to ``list_blobs``; when set, objects nested below the
            delimiter are reported as prefixes rather than blobs, so they are not deleted.
        :return: names of the deleted objects, in listing order.
        """
        blobs = self.gcs_client.list_blobs(bucket_name, prefix=prefix, delimiter=delimiter)

        deleted_names: List[str] = []
        for blob in blobs:
            blob.delete()
            deleted_names.append(blob.name)
            print(f" ✅ 已删除: {blob.name}")

        # Stop tracking URIs whose objects no longer exist, so a later cleanup pass does not
        # try to delete them again. Mutate in place to keep any outside reference valid.
        deleted_uris = {f"gs://{bucket_name}/{name}" for name in deleted_names}
        self.uploaded_gcs_uris[:] = [uri for uri in self.uploaded_gcs_uris if uri not in deleted_uris]
        return deleted_names


# --- Migrating call sites from the Gemini Files API ---
# Where ``self.vertex_client`` is a MyVertexAIClient instance, the original call
#     myfile = await self.gemini_client.aio.files.upload(file=merged_image_path)
#     content_parts.append(myfile)
# is rewritten as
#     vertex_part = await self.vertex_client.upload_file_to_vertex_part(merged_image_path)
#     content_parts.append(vertex_part)

# Sample local image for trying out the upload path by hand.
merged_image_path = "/workspace/lc_stylist_agent/app/core/data/outfit_output/35e8626c-943f-4a3f-a0cc-1280d8bcf84d.jpg"


async def run():
    """Manual smoke test: delete every object under ``GCS_BLOB_FOLDER`` in ``GCS_BUCKET_NAME``.

    To exercise the upload path instead, call
    ``await vertex_ai_server.upload_file_to_vertex_part(merged_image_path)`` and append the
    returned ``Part`` to the request's content parts.
    """
    vertex_ai_server = MyVertexAIClient()

    # Use the folder constant (with its trailing "/") instead of the bare "user_456":
    # GCS prefixes are plain string prefixes, so "user_456" would also delete objects
    # under sibling prefixes such as "user_4567/".
    vertex_ai_server._list_clear_object(GCS_BUCKET_NAME, GCS_BLOB_FOLDER)


# Run the destructive smoke test only when executed as a script, never on import.
if __name__ == "__main__":
    asyncio.run(run())