feat 1.增加 建议词 机制 2.对话生图实现

This commit is contained in:
zcr
2026-02-06 11:55:11 +08:00
parent ec195d17e1
commit 3248c45cd4
12 changed files with 655 additions and 85 deletions

View File

View File

@@ -0,0 +1,43 @@
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import JsonOutputParser
from pydantic import BaseModel, Field
# 定义输出结构,保证稳定性
class SuggestionOutput(BaseModel):
suggestions: list[str] = Field(description="A list of 3 short follow-up questions or actions for the user, max 10 chars each.")
async def generate_chat_suggestions(messages, model) -> list[str]:
"""
根据对话历史生成 3 个推荐追问按钮
"""
# 只需要最近的几次交互即可判断意图
recent_msgs = messages[-4:]
parser = JsonOutputParser(pydantic_object=SuggestionOutput)
prompt = ChatPromptTemplate.from_messages([
("system", """
你是家具设计系统的交互助手。请根据用户的对话历史,预测用户接下来最可能想做的 3 件事。
【判断逻辑】
1. 如果用户已经确定了【类型、材质、风格】但还没有生成过草图 -> 必须推荐 "生成设计草图"
2. 如果刚生成了草图 -> 推荐 "调整材质""查看三维视图""下载报价单" 等。
3. 如果用户还在犹豫 -> 推荐具体的风格或材质询问。
请直接输出 JSON 格式,包含 suggestions 字段。按钮文案要简短中文不超过8个字
"""),
("user", "对话历史:{history}"),
])
chain = prompt | model | parser
try:
# 将消息对象转为字符串喂给模型
history_str = "\n".join([f"{m.type}: {m.content}" for m in recent_msgs])
result = await chain.ainvoke({"history": history_str})
return result.get("suggestions", [])
except Exception as e:
print(f"建议生成失败: {e}")
return []

View File

@@ -0,0 +1,65 @@
import io
import logging
from io import BytesIO
import urllib3
from PIL import Image
from minio import Minio
from src.core.config import settings
minio_client = Minio(settings.MINIO_URL, access_key=settings.MINIO_ACCESS, secret_key=settings.MINIO_SECRET, secure=settings.MINIO_SECURE)
# 自定义 Retry 类
class CustomRetry(urllib3.Retry):
def increment(self, method=None, url=None, response=None, error=None, **kwargs):
# 调用父类的 increment 方法
new_retry = super(CustomRetry, self).increment(method, url, response, error, **kwargs)
# 打印重试信息
logger.info(f"重试连接: {method} {url},错误: {error},重试次数: {self.total - new_retry.total}")
return new_retry
logger = logging.getLogger()
timeout = urllib3.Timeout(connect=1, read=10.0) # 连接超时 5 秒,读取超时 10 秒
http_client = urllib3.PoolManager(
num_pools=10, # 设置连接池大小
maxsize=10,
timeout=timeout,
cert_reqs='CERT_REQUIRED', # 需要证书验证
retries=CustomRetry(
total=5,
backoff_factor=0.2,
status_forcelist=[500, 502, 503, 504],
),
)
# 获取图片
def oss_get_image(oss_client, bucket, object_name, data_type):
# cv2 默认全通道读取
image_object = None
try:
image_data = oss_client.get_object(bucket_name=bucket, object_name=object_name)
data_bytes = BytesIO(image_data.read())
image_object = Image.open(data_bytes)
except Exception as e:
logger.warning(f" | 获取图片出现异常 ######: {e}")
return image_object
def oss_upload_image(oss_client, bucket, object_name, image_bytes):
req = None
try:
req = oss_client.put_object(bucket_name=bucket, object_name=object_name, data=io.BytesIO(image_bytes), length=len(image_bytes), content_type='image/png')
except Exception as e:
logger.warning(f" | 上传图片出现异常 ######: {e}")
return req
if __name__ == '__main__':
url = "aida-users/89/sketch/123-89.png"
read_type = "2"
img = oss_get_image(oss_client=minio_client, bucket=url.split('/')[0], object_name=url[url.find('/') + 1:], data_type=read_type)
img.show()
img.save("result.png")