push gitignore
This commit is contained in:
129
app/server/ChatbotAgent/agent_server.py
Normal file
129
app/server/ChatbotAgent/agent_server.py
Normal file
@@ -0,0 +1,129 @@
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from typing import List, Dict
|
||||
|
||||
import litserve as ls
|
||||
from pydantic import BaseModel
|
||||
from app.core.config import settings
|
||||
from app.core.data_structure import Message, Role
|
||||
from app.core.llm_interface import AsyncGeminiLLM
|
||||
from app.core.redis_manager import RedisManager
|
||||
from app.core.stylist_agent_server import AsyncStylistAgent
|
||||
from app.core.system_prompt import SUMMARY_PROMPT
|
||||
from app.core.vector_database import VectorDatabase
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AgentRequestModel(BaseModel):
|
||||
user_id: str
|
||||
num_outfits: int
|
||||
stylist_path: str
|
||||
|
||||
|
||||
class LCAgent(ls.LitAPI):
|
||||
def setup(self, device):
|
||||
self.llm = AsyncGeminiLLM(model_name=settings.LLM_MODEL_NAME)
|
||||
self.redis = RedisManager(
|
||||
host=settings.REDIS_HOST,
|
||||
port=settings.REDIS_PORT,
|
||||
db=settings.REDIS_DB,
|
||||
key_prefix=settings.REDIS_HISTORY_KEY_PREFIX
|
||||
)
|
||||
self.vector_db = VectorDatabase(
|
||||
vector_db_dir=settings.VECTOR_DB_DIR,
|
||||
collection_name=settings.COLLECTION_NAME,
|
||||
embedding_model_name=settings.EMBEDDING_MODEL_NAME
|
||||
)
|
||||
self.stylist_agent_kwages = {
|
||||
'local_db': self.vector_db,
|
||||
'max_len': 5,
|
||||
'gemini_model_name': settings.LLM_MODEL_NAME
|
||||
}
|
||||
|
||||
async def decode_request(self, request: AgentRequestModel):
|
||||
logger.info(f"request: {request.model_dump()}")
|
||||
return request
|
||||
|
||||
async def predict(self, request):
|
||||
asyncio.create_task(self.background_run(request))
|
||||
return {"status": "Task initiated in background."}
|
||||
|
||||
async def encode_response(self, output):
|
||||
return output
|
||||
|
||||
async def background_run(self, request: AgentRequestModel):
|
||||
# 1. 根据用户ID查询对话历史,总结对话内容
|
||||
request_summary = await self.get_conversation_summary(request.user_id)
|
||||
logger.info(f"request_summary: {request_summary}")
|
||||
|
||||
# 2.根据对话总结推荐搭配
|
||||
recommendation_results = await self.recommend_outfit(request_summary=request_summary,
|
||||
stylist_name=request.stylist_path,
|
||||
start_outfit=[],
|
||||
num_outfits=request.num_outfits,
|
||||
user_id=request.user_id)
|
||||
|
||||
logger.info("\n--- Final Recommendation Results ---")
|
||||
for i, path in enumerate(recommendation_results.get("successful_outfits", [])):
|
||||
logger.info(f"✅ Outfit {i + 1} saved to: {path}")
|
||||
for error in recommendation_results.get("failed_outfits", []):
|
||||
logger.error(f"❌ {error}")
|
||||
|
||||
async def get_conversation_summary(self, user_id: str) -> str:
|
||||
"""
|
||||
分析用户的完整会话历史,并打包成一个简洁的需求总结。
|
||||
|
||||
这个总结可以直接作为输入 Prompt 传递给 Stylist Agent。`
|
||||
"""
|
||||
history_messages = self.redis.get_history(user_id)
|
||||
input_message = "\n".join([f"{msg.role.value}: {msg.content}" for msg in history_messages])
|
||||
# 临时调用 LLM 或使用本地逻辑生成总结
|
||||
summary = await self.llm.generate_response(history=[Message(role=Role.USER, content=input_message)], system_prompt=SUMMARY_PROMPT)
|
||||
return summary
|
||||
|
||||
async def recommend_outfit(self, request_summary: str, stylist_name: str, start_outfit=None, num_outfits: int = 1, user_id: str = "test"):
|
||||
"""
|
||||
基于用户的对话历史和需求,推荐一套搭配。
|
||||
|
||||
Args:
|
||||
request_summary: 用户的request
|
||||
start_outfit: 可选的初始搭配列表,每个元素包含 'item_id' 和 'category'。
|
||||
"""
|
||||
if start_outfit is None:
|
||||
start_outfit = []
|
||||
tasks = []
|
||||
for _ in range(num_outfits):
|
||||
agent = AsyncStylistAgent(**self.stylist_agent_kwages)
|
||||
task = agent.run_styling_process(
|
||||
request_summary=request_summary,
|
||||
stylist_path=stylist_name,
|
||||
start_outfit=start_outfit,
|
||||
user_id=user_id
|
||||
)
|
||||
tasks.append(task)
|
||||
print(f"--- Starting {num_outfits} concurrent outfit generation tasks. ---")
|
||||
|
||||
try:
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
successful_outfits = []
|
||||
failed_outfits = []
|
||||
for result in results:
|
||||
if isinstance(result, Exception):
|
||||
# 任务执行中发生异常
|
||||
failed_outfits.append(f"Failed: {result}")
|
||||
else:
|
||||
# 任务成功,result 是 run_styling_process 返回的图片路径
|
||||
successful_outfits.append(result)
|
||||
|
||||
return {
|
||||
"successful_outfits": successful_outfits,
|
||||
"failed_outfits": failed_outfits
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
print(f"An unexpected error occurred during concurrent recommendation: {e}")
|
||||
return {"error": str(e)}
|
||||
188
app/server/ChatbotAgent/chatbot_server.py
Normal file
188
app/server/ChatbotAgent/chatbot_server.py
Normal file
@@ -0,0 +1,188 @@
|
||||
# server.py
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from google import genai
|
||||
import litserve as ls
|
||||
from pydantic import BaseModel
|
||||
from app.core.config import settings
|
||||
from app.core.data_structure import Role, Message
|
||||
from app.core.llm_interface import AsyncGeminiLLM
|
||||
from app.core.redis_manager import RedisManager
|
||||
from app.core.stylist_agent import AsyncStylistAgent
|
||||
from app.core.system_prompt import BASIC_PROMPT, SUMMARY_PROMPT
|
||||
from app.core.vector_database import VectorDatabase
|
||||
from google.genai import types
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PredictRequest(BaseModel):
|
||||
user_id: str # 用戶ID
|
||||
user_message: str # 用戶輸入
|
||||
|
||||
|
||||
class LCChatBot(ls.LitAPI):
|
||||
def setup(self, device):
|
||||
self.llm = AsyncGeminiLLM(model_name=settings.LLM_MODEL_NAME)
|
||||
self.redis = RedisManager(
|
||||
host=settings.REDIS_HOST,
|
||||
port=settings.REDIS_PORT,
|
||||
db=settings.REDIS_DB,
|
||||
key_prefix=settings.REDIS_HISTORY_KEY_PREFIX
|
||||
)
|
||||
self.vector_db = VectorDatabase(
|
||||
vector_db_dir=settings.VECTOR_DB_DIR,
|
||||
collection_name=settings.COLLECTION_NAME,
|
||||
embedding_model_name=settings.EMBEDDING_MODEL_NAME
|
||||
)
|
||||
self.stylist_agent_kwages = {
|
||||
'local_db': self.vector_db,
|
||||
'max_len': 5,
|
||||
'outfits_root': settings.OUTFIT_OUTPUT_DIR,
|
||||
'image_dir': settings.IMAGE_DIR,
|
||||
'stylist_guide_dir': settings.STYLIST_GUIDE_DIR,
|
||||
'gemini_model_name': settings.LLM_MODEL_NAME
|
||||
}
|
||||
self.gemini_client = genai.Client(
|
||||
vertexai=True, project='aida-461108', location='us-central1'
|
||||
)
|
||||
|
||||
async def decode_request(self, request: PredictRequest):
|
||||
return request
|
||||
|
||||
async def predict(self, request) -> AsyncGenerator[str, None]:
|
||||
# 添加用户消息到历史
|
||||
user_message = request.user_message
|
||||
user_id = request.user_id
|
||||
user_msg = Message(role=Role.USER, content=user_message)
|
||||
chat_history = self.redis.get_history(user_id)
|
||||
chat_history.append(user_msg)
|
||||
|
||||
contents = []
|
||||
|
||||
for msg in chat_history:
|
||||
gemini_role = "user" if msg.role == Role.USER else "model"
|
||||
content = types.Content(
|
||||
role=gemini_role,
|
||||
parts=[types.Part.from_text(text=msg.content)]
|
||||
)
|
||||
contents.append(content)
|
||||
|
||||
response_stream = await self.gemini_client.aio.models.generate_content_stream(
|
||||
model='gemini-2.5-flash',
|
||||
contents=contents,
|
||||
config=types.GenerateContentConfig(
|
||||
system_instruction=BASIC_PROMPT,
|
||||
# temperature=0.3,
|
||||
)
|
||||
)
|
||||
full_response_text = ""
|
||||
# 异步迭代流
|
||||
async for chunk in response_stream:
|
||||
if chunk:
|
||||
logger.info(chunk.text)
|
||||
# 2. 实时 yield 文本块给 encode_response
|
||||
yield chunk.text
|
||||
|
||||
# 3. 累加文本块以保存完整的历史记录
|
||||
full_response_text += chunk.text
|
||||
|
||||
# 添加助手消息到历史
|
||||
if full_response_text:
|
||||
assistant_msg = Message(role=Role.ASSISTANT, content=full_response_text)
|
||||
else:
|
||||
assistant_msg = Message(role=Role.ASSISTANT, content="No response generated. Try again later.")
|
||||
|
||||
self.redis.save_message(user_id, user_msg)
|
||||
self.redis.save_message(user_id, assistant_msg)
|
||||
|
||||
async def encode_response(self, output):
|
||||
# The for-loop must have async keyword here since output is an AsyncGenerator
|
||||
async for out in output:
|
||||
yield {"output": out}
|
||||
|
||||
async def process_query(self, user_id: str, user_message: str) -> str:
|
||||
"""
|
||||
处理用户的最新输入,调用 LLM, 并更新历史记录。
|
||||
"""
|
||||
|
||||
# 添加用户消息到历史
|
||||
user_msg = Message(role=Role.USER, content=user_message)
|
||||
chat_history = self.redis.get_history(user_id)
|
||||
chat_history.append(user_msg)
|
||||
|
||||
# 生成 LLM 回复
|
||||
try:
|
||||
response_text = await self.llm.generate_response(chat_history, system_prompt=BASIC_PROMPT)
|
||||
except Exception as e:
|
||||
logger("\n--- Final Recommendation Results ---")
|
||||
|
||||
logger.error(f"LLM 调用失败: {e}")
|
||||
response_text = "抱歉,系统暂时无法响应,请稍后再试。"
|
||||
|
||||
# 添加助手消息到历史
|
||||
if response_text:
|
||||
assistant_msg = Message(role=Role.ASSISTANT, content=response_text)
|
||||
else:
|
||||
assistant_msg = Message(role=Role.ASSISTANT, content="No response generated. Try again later.")
|
||||
|
||||
self.redis.save_message(user_id, user_msg)
|
||||
self.redis.save_message(user_id, assistant_msg)
|
||||
|
||||
return response_text
|
||||
|
||||
async def get_conversation_summary(self, user_id: str) -> str:
|
||||
"""
|
||||
分析用户的完整会话历史,并打包成一个简洁的需求总结。
|
||||
|
||||
这个总结可以直接作为输入 Prompt 传递给 Stylist Agent。`
|
||||
"""
|
||||
history_messages = self.redis.get_history(user_id)
|
||||
input_message = "\n".join([f"{msg.role.value}: {msg.content}" for msg in history_messages])
|
||||
|
||||
# 临时调用 LLM 或使用本地逻辑生成总结
|
||||
summary = await self.llm.generate_response(history=[Message(role=Role.USER, content=input_message)], system_prompt=SUMMARY_PROMPT)
|
||||
|
||||
return summary
|
||||
|
||||
async def recommend_outfit(self, request_summary: str, stylist_name: str, start_outfit=None, num_outfits: int = 1):
|
||||
"""
|
||||
基于用户的对话历史和需求,推荐一套搭配。
|
||||
|
||||
Args:
|
||||
request_summary: 用户的request
|
||||
start_outfit: 可选的初始搭配列表,每个元素包含 'item_id' 和 'category'。
|
||||
"""
|
||||
if start_outfit is None:
|
||||
start_outfit = []
|
||||
tasks = []
|
||||
for _ in range(num_outfits):
|
||||
agent = AsyncStylistAgent(**self.stylist_agent_kwages)
|
||||
task = agent.run_styling_process(request_summary, stylist_name, start_outfit)
|
||||
tasks.append(task)
|
||||
logger.info(f"--- Starting {num_outfits} concurrent outfit generation tasks. ---")
|
||||
|
||||
try:
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
successful_outfits = []
|
||||
failed_outfits = []
|
||||
for result in results:
|
||||
if isinstance(result, Exception):
|
||||
# 任务执行中发生异常
|
||||
failed_outfits.append(f"Failed: {result}")
|
||||
else:
|
||||
# 任务成功,result 是 run_styling_process 返回的图片路径
|
||||
successful_outfits.append(result)
|
||||
|
||||
return {
|
||||
"successful_outfits": successful_outfits,
|
||||
"failed_outfits": failed_outfits
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"An unexpected error occurred during concurrent recommendation: {e}")
|
||||
return {"error": str(e)}
|
||||
125
app/server/utils/minio_client.py
Normal file
125
app/server/utils/minio_client.py
Normal file
@@ -0,0 +1,125 @@
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
from io import BytesIO
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import urllib3
|
||||
from PIL import Image
|
||||
from minio import Minio
|
||||
|
||||
from app.server.utils.minio_config import MINIO_ACCESS, MINIO_SECRET, MINIO_URL, MINIO_SECURE
|
||||
|
||||
minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
|
||||
|
||||
|
||||
# 自定义 Retry 类
|
||||
class CustomRetry(urllib3.Retry):
|
||||
def increment(self, method=None, url=None, response=None, error=None, **kwargs):
|
||||
# 调用父类的 increment 方法
|
||||
new_retry = super(CustomRetry, self).increment(method, url, response, error, **kwargs)
|
||||
# 打印重试信息
|
||||
logger.info(f"重试连接: {method} {url},错误: {error},重试次数: {self.total - new_retry.total}")
|
||||
return new_retry
|
||||
|
||||
|
||||
logger = logging.getLogger()
|
||||
timeout = urllib3.Timeout(connect=1, read=10.0) # 连接超时 5 秒,读取超时 10 秒
|
||||
http_client = urllib3.PoolManager(
|
||||
num_pools=10, # 设置连接池大小
|
||||
maxsize=10,
|
||||
timeout=timeout,
|
||||
cert_reqs='CERT_REQUIRED', # 需要证书验证
|
||||
retries=CustomRetry(
|
||||
total=5,
|
||||
backoff_factor=0.2,
|
||||
status_forcelist=[500, 502, 503, 504],
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
# 获取图片
|
||||
def oss_get_image(oss_client, path, data_type):
|
||||
# cv2 默认全通道读取
|
||||
bucket = path.split("/", 1)[0]
|
||||
object_name = path.split("/", 1)[1]
|
||||
image_object = None
|
||||
try:
|
||||
image_data = oss_client.get_object(bucket_name=bucket, object_name=object_name)
|
||||
if data_type == "cv2":
|
||||
image_bytes = image_data.read()
|
||||
image_array = np.frombuffer(image_bytes, np.uint8) # 转成8位无符号整型
|
||||
image_object = cv2.imdecode(image_array, cv2.IMREAD_UNCHANGED)
|
||||
if image_object.dtype == np.uint16:
|
||||
image_object = (image_object / 256).astype('uint8')
|
||||
else:
|
||||
data_bytes = BytesIO(image_data.read())
|
||||
image_object = Image.open(data_bytes)
|
||||
except Exception as e:
|
||||
logger.warning(f"获取图片出现异常 ######: {e}")
|
||||
return image_object
|
||||
|
||||
|
||||
def oss_upload_image(oss_client, bucket, object_name, image_bytes):
|
||||
req = None
|
||||
try:
|
||||
req = oss_client.put_object(bucket_name=bucket, object_name=object_name, data=io.BytesIO(image_bytes), length=len(image_bytes), content_type='image/png')
|
||||
except Exception as e:
|
||||
logger.warning(f"上传图片出现异常 ######: {e}")
|
||||
return req
|
||||
|
||||
|
||||
# def upload_json_to_minio_sync(
|
||||
# minio_client: Minio,
|
||||
# bucket_name: str,
|
||||
# object_name: str,
|
||||
# data: list
|
||||
# ) -> str:
|
||||
# """
|
||||
# 将 Python 字典转换为 JSON 字符串,并上传到 MinIO。
|
||||
#
|
||||
# :param minio_client: 已初始化的 MinIO 客户端实例。
|
||||
# :param bucket_name: 目标 Bucket 名称。
|
||||
# :param object_name: 目标文件路径/名称 (e.g., 'data/report.json')。
|
||||
# :param data: 要上传的 Python 字典数据。
|
||||
# :return: 成功返回 True,失败返回 False。
|
||||
# """
|
||||
# try:
|
||||
# # 1. 将 Python 字典序列化为 JSON 字符串
|
||||
# json_string = json.dumps(data, ensure_ascii=False, indent=2)
|
||||
# # 2. 将 JSON 字符串编码为字节流 (bytes)
|
||||
# json_bytes = json_string.encode('utf-8')
|
||||
#
|
||||
# # 3. 创建 BytesIO 对象,用于从内存读取数据
|
||||
# data_stream = io.BytesIO(json_bytes)
|
||||
#
|
||||
# # 4. 使用 put_object 上传数据流
|
||||
# minio_client.put_object(
|
||||
# bucket_name,
|
||||
# object_name,
|
||||
# data_stream,
|
||||
# length=len(json_bytes),
|
||||
# content_type='application/json; charset=utf-8' # 设置正确的 MIME 类型
|
||||
# )
|
||||
# logger.info(f"✅ JSON file uploaded successfully to {bucket_name}/{object_name}")
|
||||
# return f'{bucket_name}/{object_name}'
|
||||
#
|
||||
# except Exception as e:
|
||||
# logger.error(f"❌ An unexpected error occurred: {e}")
|
||||
# return False
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# url = "lanecarford/lc_stylist_agent_outfit_items/string/7fed1c7b-9efd-41fa-a335-182c310ea611.jpg"
|
||||
# url = "lanecarford/lc_stylist_agent_outfit_items/string/5de155d0-56a6-43e8-a2f1-7538fce86220.jpg"
|
||||
# url = "lanecarford/lc_stylist_agent_outfit_items/string/1cd1803c-5f51-4961-a4f2-2acd3e0d8294.jpg"
|
||||
url = 'lanecarford/lc_stylist_agent_outfit_items/string/99cd8cc0-856a-487d-bb21-5684855ef48f.jpg'
|
||||
read_type = "1"
|
||||
img = oss_get_image(oss_client=minio_client, path=url, data_type=read_type)
|
||||
if read_type == "cv2":
|
||||
cv2.imshow("", img)
|
||||
cv2.waitKey(0)
|
||||
else:
|
||||
img.show()
|
||||
img.save("4.png")
|
||||
6
app/server/utils/minio_config.py
Normal file
6
app/server/utils/minio_config.py
Normal file
@@ -0,0 +1,6 @@
|
||||
# minio 配置
|
||||
MINIO_URL = "www.minio-api.aida.com.hk"
|
||||
MINIO_ACCESS = 'vXKFLSJkYeEq2DrSZvkB'
|
||||
MINIO_SECRET = 'uKTZT3x7C43WvPN9QTc99DiRkwddWZrG9Uh3JVlR'
|
||||
MINIO_SECURE = True
|
||||
MINIO_LC_DATA_PATH = "lanecarford/lc_image_data"
|
||||
57
app/server/utils/request_post.py
Normal file
57
app/server/utils/request_post.py
Normal file
@@ -0,0 +1,57 @@
|
||||
import json
|
||||
import time
|
||||
|
||||
import requests
|
||||
|
||||
|
||||
def post_request(url, data=None, json_data=None, headers=None, auth=None, timeout=5):
|
||||
"""
|
||||
发送POST请求的封装函数
|
||||
|
||||
:param url: 接口的URL地址
|
||||
:param data: 要发送的数据(字典形式,用于表单数据等,会自动编码)
|
||||
:param json_data: 要发送的JSON数据(字典形式,会自动转换为JSON字符串)
|
||||
:param headers: 请求头字典
|
||||
:param auth: 认证信息(如 ('username', 'password') 形式用于基本认证)
|
||||
:param timeout: 超时时间,单位为秒
|
||||
:return: 返回接口的响应对象
|
||||
"""
|
||||
try:
|
||||
response = requests.post(
|
||||
url,
|
||||
data=data,
|
||||
json=json_data,
|
||||
headers=headers,
|
||||
auth=auth,
|
||||
timeout=timeout
|
||||
)
|
||||
response.raise_for_status() # 如果请求失败,抛出异常
|
||||
return response
|
||||
except requests.RequestException as e:
|
||||
print(f"POST请求出错: {e}")
|
||||
return None
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
url = 'https://83aa2db8e006.ngrok-free.app/api/style/callback'
|
||||
|
||||
object_data = {
|
||||
'outfit_id': "test",
|
||||
"status": "test",
|
||||
"path": "test",
|
||||
"items": [
|
||||
"test"
|
||||
]
|
||||
}
|
||||
|
||||
headers = {
|
||||
'Accept': "*/*",
|
||||
'Accept-Encoding': "gzip, deflate, br",
|
||||
'User-Agent': "PostmanRuntime-ApipostRuntime/1.1.0",
|
||||
'Connection': "keep-alive",
|
||||
'Content-Type': "application/json"
|
||||
}
|
||||
start_time = time.time()
|
||||
X = post_request(url=url, data=json.dumps(object_data), headers=headers)
|
||||
print(time.time() - start_time)
|
||||
print(X)
|
||||
Reference in New Issue
Block a user