feat 1.增加 建议词 机制 2.对话生图实现
This commit is contained in:
@@ -7,13 +7,21 @@ class Settings(BaseSettings):
|
||||
应用配置类。Pydantic Settings 会自动从环境变量和 .env 文件中加载这些值。
|
||||
"""
|
||||
model_config = SettingsConfigDict(
|
||||
env_file='.env',
|
||||
env_file='.env_local',
|
||||
env_file_encoding='utf-8',
|
||||
extra='ignore' # 忽略环境变量中多余的键
|
||||
)
|
||||
# --- google api 配置信息 ---
|
||||
GOOGLE_GENAI_USE_VERTEXAI: str = Field(default="", description="")
|
||||
GOOGLE_API_KEY: str = Field(default="", description="")
|
||||
GOOGLE_CLOUD_PROJECT: str = Field(default="", description="")
|
||||
GOOGLE_CLOUD_LOCATION: str = Field(default="", description="")
|
||||
|
||||
# --- minio 配置信息 ---
|
||||
MINIO_URL: str = Field(default='', description="")
|
||||
MINIO_ACCESS: str = Field(default='', description="")
|
||||
MINIO_SECRET: str = Field(default='', description="")
|
||||
MINIO_SECURE: bool = Field(default=True, description="")
|
||||
|
||||
# --- mongodb配置信息 ---
|
||||
MONGODB_USERNAME: str = Field(default="", description="")
|
||||
|
||||
@@ -4,7 +4,7 @@ from fastapi import APIRouter
|
||||
from fastapi.responses import StreamingResponse
|
||||
from src.schemas.chat import ChatRequest, HistoryResponse, HistoryItem
|
||||
from src.server.agent.graph import app # 导入已经 compile 好的 graph
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.messages import HumanMessage, SystemMessage
|
||||
|
||||
router = APIRouter(prefix="/chat", tags=["Furniture Design Chat"])
|
||||
|
||||
@@ -57,38 +57,114 @@ async def chat_stream(request: ChatRequest):
|
||||
# 如果是回溯操作,我们生成一个新的 ID,或者由前端传入一个新的 target_thread_id
|
||||
is_branching = source_thread_id and checkpoint_id
|
||||
target_thread_id = str(uuid.uuid4())[:8] if is_branching else (source_thread_id or str(uuid.uuid4())[:8])
|
||||
# 2. 获取配置参数
|
||||
temp = request.config_params.temperature if request.config_params else 0.7
|
||||
|
||||
# 2. 如果是分叉请求,我们需要先“搬家”状态
|
||||
# 构建基础 Config
|
||||
current_config = {
|
||||
"configurable": {
|
||||
"thread_id": target_thread_id,
|
||||
"llm_temperature": temp
|
||||
}
|
||||
}
|
||||
# 3. 处理状态初始化与分支
|
||||
initial_messages = []
|
||||
|
||||
# 如果是全新的对话(没有 source_thread_id),或者明确要求分叉
|
||||
if not source_thread_id or is_branching:
|
||||
# 如果用户传了标签,构造 SystemMessage 注入上下文
|
||||
if request.config_params:
|
||||
cp = request.config_params
|
||||
system_prompt = (
|
||||
f"Current furniture design background settings:\n"
|
||||
f"- type: {cp.type}\n"
|
||||
f"- space/region: {cp.region}\n"
|
||||
f"- style tendency: {cp.style}\n"
|
||||
f"Please strictly follow the above settings in subsequent conversations。"
|
||||
)
|
||||
initial_messages.append(SystemMessage(content=system_prompt))
|
||||
|
||||
# 4. 执行分叉逻辑(搬运旧数据)
|
||||
if is_branching:
|
||||
# 获取旧状态
|
||||
source_config = {"configurable": {"thread_id": source_thread_id, "checkpoint_id": checkpoint_id}}
|
||||
source_config = {
|
||||
"configurable": {
|
||||
"thread_id": source_thread_id,
|
||||
"checkpoint_id": checkpoint_id
|
||||
}
|
||||
}
|
||||
older_state = await app.aget_state(source_config)
|
||||
|
||||
# 将旧状态的消息,作为新 thread 的初始值注入
|
||||
# 注意:这里我们手动把旧消息塞给新 thread
|
||||
new_config = {"configurable": {"thread_id": target_thread_id}}
|
||||
await app.aupdate_state(new_config, older_state.values)
|
||||
# 将旧消息和我们新定义的 SystemMessage 合并
|
||||
# update_state 会将这些消息推送到新 thread 的存储中
|
||||
combined_values = older_state.values.copy()
|
||||
if initial_messages:
|
||||
combined_values["messages"] = list(combined_values["messages"]) + initial_messages
|
||||
|
||||
# 现在的 config 指向新 Thread
|
||||
current_config = new_config
|
||||
else:
|
||||
current_config = {"configurable": {"thread_id": target_thread_id}}
|
||||
await app.aupdate_state(current_config, combined_values)
|
||||
|
||||
async def event_generator():
|
||||
# 告诉前端:现在是在哪个 Thread 上工作(如果是分叉,前端需要更新本地存储的 ID)
|
||||
yield f"data: {json.dumps({'thread_id': target_thread_id, 'is_branch': is_branching}, ensure_ascii=False)}\n\n"
|
||||
# 初始推送状态信息
|
||||
yield f"data: {json.dumps({'thread_id': target_thread_id, 'is_branch': is_branching, 'status': 'start'}, ensure_ascii=False)}\n\n"
|
||||
|
||||
# 构造本次请求的输入
|
||||
# 如果是第一次开始,且有 initial_messages,则连同 user message 一起发送
|
||||
# --- 核心逻辑:构造本次请求的消息列表 ---
|
||||
new_messages = []
|
||||
if not source_thread_id and initial_messages:
|
||||
new_messages.extend(initial_messages)
|
||||
# 添加用户消息
|
||||
new_messages.append(HumanMessage(content=request.message))
|
||||
|
||||
# --- 新增:强制绘图指令注入 ---
|
||||
# if request.force_sketch:
|
||||
# force_instruction = HumanMessage(
|
||||
# content="[SYSTEM_DIRECTIVE]: 用户点击了强制生成按钮。请立即根据当前上下文调用 generate_furniture_sketch 工具生成草图,无需确认。"
|
||||
# )
|
||||
# new_messages.append(force_instruction)
|
||||
|
||||
input_data = {"messages": new_messages}
|
||||
|
||||
async for event in app.astream(
|
||||
{"messages": [HumanMessage(content=request.message)]},
|
||||
input_data,
|
||||
current_config,
|
||||
stream_mode="updates"
|
||||
):
|
||||
# ... 发送流式内容的逻辑保持不变 ...
|
||||
for node_name, output in event.items():
|
||||
if "messages" in output:
|
||||
msg = output["messages"][-1]
|
||||
# 获取最新 state 以获取 checkpoint_id
|
||||
state = await app.aget_state(current_config)
|
||||
yield f"data: {json.dumps({'node': node_name, 'content': msg.content, 'checkpoint_id': state.config['configurable']['checkpoint_id']}, ensure_ascii=False)}\n\n"
|
||||
current_cp_id = state.config["configurable"].get("checkpoint_id")
|
||||
|
||||
# 遍历本次 update 产生的所有消息
|
||||
for msg in output["messages"]:
|
||||
payload = {
|
||||
"node": node_name,
|
||||
"content": "",
|
||||
"image_url": None,
|
||||
"checkpoint_id": current_cp_id,
|
||||
"suggestions": []
|
||||
}
|
||||
|
||||
# --- 核心改动:提取建议按钮 ---
|
||||
# 无论是不是 Suggester 节点,只要消息里带了建议就提取
|
||||
if hasattr(msg, "additional_kwargs") and "suggestions" in msg.additional_kwargs:
|
||||
payload["suggestions"] = msg.additional_kwargs["suggestions"]
|
||||
|
||||
content = msg.content
|
||||
# 逻辑判断:MinIO 图片处理
|
||||
if node_name == "Visualizer" and str(content).endswith("png") and "furniture/sketches" in str(content):
|
||||
payload["image_url"] = content
|
||||
payload["content"] = "已为您生成设计草图"
|
||||
else:
|
||||
payload["content"] = content
|
||||
|
||||
# 如果消息既没有文本、也没有图片、也没有建议(比如中间的 ToolCall 消息),则跳过
|
||||
if not payload["content"] and not payload["image_url"] and not payload["suggestions"]:
|
||||
continue
|
||||
|
||||
yield f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"
|
||||
|
||||
yield f"data: {json.dumps({'status': 'end'}, ensure_ascii=False)}\n\n"
|
||||
|
||||
return StreamingResponse(event_generator(), media_type="text/event-stream")
|
||||
|
||||
@@ -147,11 +223,11 @@ async def get_chat_history(thread_id: str):
|
||||
last_msg = msgs[-1]
|
||||
# 获取内容并做摘要截断
|
||||
content = getattr(last_msg, "content", str(last_msg))
|
||||
msg_content = content[:50] + ("..." if len(content) > 50 else "")
|
||||
msg_content = content
|
||||
|
||||
history_data.append(HistoryItem(
|
||||
checkpoint_id=state.config["configurable"]["checkpoint_id"],
|
||||
last_message=msg_content[:50],
|
||||
last_message=msg_content,
|
||||
node=state.metadata.get("source"),
|
||||
timestamp=state.metadata.get("step")
|
||||
))
|
||||
|
||||
@@ -1,11 +1,20 @@
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, confloat
|
||||
from typing import Optional, List, Dict, Any
|
||||
|
||||
|
||||
class AgentConfig(BaseModel):
|
||||
type: str = Field(..., description="家具类型,如:沙发、餐桌")
|
||||
region: str = Field(..., description="地区/空间,如:客厅、卧室、户外")
|
||||
style: str = Field(..., description="设计风格,如:极简、工业风、中式")
|
||||
temperature: confloat(ge=0, le=2.0) = Field(default=0.7, description="模型温度")
|
||||
|
||||
|
||||
class ChatRequest(BaseModel):
|
||||
message: str = Field(..., description="用户的输入指令")
|
||||
thread_id: Optional[str] = Field(None, description="会话线程ID,不传则开启新会话")
|
||||
checkpoint_id: Optional[str] = Field(None, description="回溯点的ID,用于从历史点开启新对话")
|
||||
config_params: Optional[AgentConfig] = None
|
||||
# force_sketch: bool = False # 新增:是否强制绘图
|
||||
|
||||
|
||||
class HistoryItem(BaseModel):
|
||||
|
||||
@@ -1,86 +1,117 @@
|
||||
import os
|
||||
|
||||
from google.oauth2 import service_account
|
||||
from langchain_core.messages import HumanMessage, SystemMessage
|
||||
from langchain_core.messages import HumanMessage, SystemMessage, ToolMessage, AIMessage
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langchain_google_genai import ChatGoogleGenerativeAI
|
||||
from src.server.agent.state import AgentState
|
||||
from src.server.agent.tools import generate_2025_report_tool, generate_furniture_sketch
|
||||
from src.server.agent.config_loader import get_agent_prompt
|
||||
from src.core.config import settings
|
||||
from src.server.utils.generate_suggestion import generate_chat_suggestions
|
||||
|
||||
creds = service_account.Credentials.from_service_account_file(
|
||||
settings.GOOGLE_GENAI_USE_VERTEXAI,
|
||||
scopes=["https://www.googleapis.com/auth/cloud-platform"],
|
||||
)
|
||||
# 初始化 Gemini 模型 (使用 Flash 以保证速度)
|
||||
llm = ChatGoogleGenerativeAI(
|
||||
model="gemini-2.0-flash", temperature=0.5, credentials=creds,
|
||||
project="aida-461108", location='us-central1', vertexai=True, api_key=settings.GOOGLE_API_KEY
|
||||
)
|
||||
|
||||
|
||||
# 辅助函数:根据配置动态获取 LLM
|
||||
def get_model(config: RunnableConfig):
|
||||
# 从 configurable 中获取温度,默认为 0.5 (对应你之前的设置)
|
||||
# 这个 key 必须与你在 chat_stream 路由里定义的 "llm_temperature" 一致
|
||||
temp = config["configurable"].get("llm_temperature", 0.5)
|
||||
|
||||
return ChatGoogleGenerativeAI(
|
||||
model="gemini-2.0-flash",
|
||||
temperature=temp,
|
||||
credentials=creds,
|
||||
project=settings.GOOGLE_CLOUD_PROJECT,
|
||||
location=settings.GOOGLE_CLOUD_LOCATION,
|
||||
vertexai=True,
|
||||
api_key=settings.GOOGLE_API_KEY
|
||||
)
|
||||
|
||||
|
||||
# --- 1. Designer Agent (设计顾问) ---
|
||||
def designer_node(state: AgentState):
|
||||
async def designer_node(state: AgentState, config: RunnableConfig):
|
||||
"""负责细化设计需求,提供专业参数"""
|
||||
model = get_model(config) # 获取带动态温度的模型
|
||||
|
||||
messages = state["messages"]
|
||||
system_text = get_agent_prompt("designer") or """
|
||||
你是一位资深的家具设计师。你的职责是:
|
||||
1. 从用户的模糊描述中提取或补充具体的设计参数(尺寸、材质、人体工学数据)。
|
||||
2. 如果用户想画图,不要直接画,而是先描述清楚细节,然后让 Visualizer 去画。
|
||||
请以专业的口吻回复。
|
||||
"""
|
||||
system_text = get_agent_prompt("designer")
|
||||
|
||||
system_prompt = SystemMessage(content=system_text)
|
||||
response = llm.invoke([system_prompt] + messages)
|
||||
# 改为异步调用 ainvoke
|
||||
response = await model.ainvoke([system_prompt] + messages)
|
||||
return {"messages": [response]}
|
||||
|
||||
|
||||
# --- 2. Researcher Agent (情报专家) ---
|
||||
def researcher_node(state: AgentState):
|
||||
async def researcher_node(state: AgentState, config: RunnableConfig):
|
||||
"""负责调用报告生成工具"""
|
||||
# 绑定工具给 LLM
|
||||
model = get_model(config)
|
||||
tools = [generate_2025_report_tool]
|
||||
llm_with_tools = llm.bind_tools(tools)
|
||||
llm_with_tools = model.bind_tools(tools)
|
||||
|
||||
messages = state["messages"]
|
||||
system_text = get_agent_prompt("researcher") or "你是情报专家,负责检索与整理参考资料并生成报告。"
|
||||
system_text = get_agent_prompt("researcher")
|
||||
system_prompt = SystemMessage(content=system_text)
|
||||
response = llm_with_tools.invoke([system_prompt] + messages)
|
||||
response = await llm_with_tools.ainvoke([system_prompt] + messages)
|
||||
|
||||
# 如果模型决定调用工具
|
||||
if response.tool_calls:
|
||||
# 这里为了简化,直接在节点内执行工具(LangGraph也可以用 ToolNode)
|
||||
tool_call = response.tool_calls[0]
|
||||
if tool_call["name"] == "generate_2025_report_tool":
|
||||
result = generate_2025_report_tool.invoke(tool_call["args"])
|
||||
# 这里的工具调用如果也是异步的,建议加 await
|
||||
result = await generate_2025_report_tool.ainvoke(tool_call["args"])
|
||||
return {"messages": [response, HumanMessage(content=str(result))]}
|
||||
|
||||
return {"messages": [response]}
|
||||
|
||||
|
||||
# --- 3. Visualizer Agent (视觉专家) ---
|
||||
def visualizer_node(state: AgentState):
|
||||
async def visualizer_node(state: AgentState, config: RunnableConfig):
|
||||
"""负责将自然语言转化为绘图 Prompt 并调用绘图工具"""
|
||||
model = get_model(config)
|
||||
tools = [generate_furniture_sketch]
|
||||
llm_with_tools = llm.bind_tools(tools)
|
||||
llm_with_tools = model.bind_tools(tools)
|
||||
|
||||
messages = state["messages"]
|
||||
system_text = get_agent_prompt("visualizer") or """
|
||||
你是视觉专家。你的目标是生成高质量的家具草图。
|
||||
步骤:
|
||||
1. 根据上下文,编写一个详细的 Stable Diffusion 风格的英文 Prompt。
|
||||
2. 必须调用 generate_furniture_sketch 工具来生成图片。
|
||||
"""
|
||||
system_text = get_agent_prompt("visualizer")
|
||||
|
||||
# 强制它尝试调用工具
|
||||
system_prompt = SystemMessage(content=system_text)
|
||||
response = llm_with_tools.invoke([system_prompt] + messages)
|
||||
response = await llm_with_tools.ainvoke([system_prompt] + messages)
|
||||
|
||||
if response.tool_calls:
|
||||
tool_call = response.tool_calls[0]
|
||||
if tool_call["name"] == "generate_furniture_sketch":
|
||||
result = generate_furniture_sketch.invoke(tool_call["args"])
|
||||
# 返回工具结果给 LLM,让它生成最终回复
|
||||
final_msg = f"已为您生成草图,链接如下:{result}"
|
||||
return {"messages": [response, HumanMessage(content=final_msg)]}
|
||||
img_url = await generate_furniture_sketch.ainvoke(tool_call["args"])
|
||||
return {
|
||||
"messages": [
|
||||
response,
|
||||
ToolMessage(content=img_url, tool_call_id=tool_call["id"]) # 标记这是一个图片结果
|
||||
]
|
||||
}
|
||||
|
||||
return {"messages": [response]}
|
||||
|
||||
|
||||
# --- 4. Suggester Agent (推荐对话专家) ---
|
||||
async def suggester_node(state: AgentState, config: RunnableConfig):
|
||||
"""专门生成追问建议的节点,作为流程终点"""
|
||||
model = get_model(config)
|
||||
messages = state["messages"]
|
||||
|
||||
# 只需要分析最近的对话
|
||||
suggestions = await generate_chat_suggestions(messages, model)
|
||||
|
||||
# 返回一个特殊消息,前端通过解析 additional_kwargs 获取按钮内容
|
||||
return {
|
||||
"messages": [
|
||||
AIMessage(
|
||||
content="",
|
||||
additional_kwargs={"suggestions": suggestions},
|
||||
name="Suggester"
|
||||
)
|
||||
]
|
||||
}
|
||||
|
||||
@@ -10,14 +10,15 @@ from pymongo import MongoClient
|
||||
|
||||
from src.core.config import settings, MONGO_URI
|
||||
from src.server.agent.state import AgentState
|
||||
from src.server.agent.agents import designer_node, researcher_node, visualizer_node
|
||||
from src.server.agent.agents import designer_node, researcher_node, visualizer_node, suggester_node
|
||||
from langgraph.checkpoint.mongodb import MongoDBSaver
|
||||
|
||||
|
||||
# --- Supervisor (路由逻辑) ---
|
||||
# 定义路由的输出结构,强制 LLM 选择一个
|
||||
class RouteResponse(BaseModel):
|
||||
next: Literal["Designer", "Researcher", "Visualizer", "FINISH"]
|
||||
# 将 FINISH 替换或增加 Suggester
|
||||
next: Literal["Designer", "Researcher", "Visualizer", "Suggester"]
|
||||
|
||||
|
||||
creds = service_account.Credentials.from_service_account_file(
|
||||
@@ -34,30 +35,23 @@ llm_supervisor = ChatGoogleGenerativeAI(
|
||||
def supervisor_node(state: AgentState):
|
||||
messages = state["messages"]
|
||||
if not messages:
|
||||
return {"next": "FINISH"}
|
||||
return {"next": "Suggester"}
|
||||
|
||||
last_message = messages[-1]
|
||||
|
||||
# --- 改进的拦截逻辑 ---
|
||||
# 如果最后一条消息是 AI 产生的(且没有调用工具),说明专家已经回复完了用户
|
||||
# 此时我们才拦截并结束,否则会导致专家没机会说话
|
||||
# --- 拦截逻辑修改 ---
|
||||
# 如果专家已经回复完了(AIMessage 且无工具调用),则交给 Suggester 生成按钮
|
||||
if isinstance(last_message, AIMessage) and not last_message.tool_calls:
|
||||
return {"next": "FINISH"}
|
||||
return {"next": "Suggester"}
|
||||
|
||||
# 如果最后一条是 HumanMessage,说明用户刚说完,Supervisor 必须派发任务
|
||||
system_prompt = """
|
||||
你是家具设计团队的主管(Supervisor)。
|
||||
请根据用户的意图,选择最合适的专家:
|
||||
- Designer: 设计建议、参数细化、闲聊、问候。
|
||||
- Visualizer: 绘图、看草图。
|
||||
- Researcher: 市场报告、趋势。
|
||||
|
||||
只需输出专家名称。
|
||||
system_prompt = """你是家具设计主管。分配任务给专家:
|
||||
- Designer: 设计建议、参数细化。
|
||||
- Visualizer: 绘图需求。
|
||||
- Researcher: 市场报告。
|
||||
"""
|
||||
|
||||
chain = llm_supervisor.with_structured_output(RouteResponse)
|
||||
decision = chain.invoke([{"role": "system", "content": system_prompt}] + messages)
|
||||
|
||||
return {"next": decision.next}
|
||||
|
||||
|
||||
@@ -68,10 +62,11 @@ workflow.add_node("Supervisor", supervisor_node)
|
||||
workflow.add_node("Designer", designer_node)
|
||||
workflow.add_node("Researcher", researcher_node)
|
||||
workflow.add_node("Visualizer", visualizer_node)
|
||||
workflow.add_node("Suggester", suggester_node) # 新增节点
|
||||
|
||||
workflow.add_edge(START, "Supervisor")
|
||||
|
||||
# 这里的逻辑是关键:Supervisor 决定去向
|
||||
# 修改条件边映射
|
||||
workflow.add_conditional_edges(
|
||||
"Supervisor",
|
||||
lambda state: state["next"],
|
||||
@@ -79,16 +74,18 @@ workflow.add_conditional_edges(
|
||||
"Designer": "Designer",
|
||||
"Researcher": "Researcher",
|
||||
"Visualizer": "Visualizer",
|
||||
"FINISH": END
|
||||
"Suggester": "Suggester" # 原本的 FINISH 现在指向 Suggester
|
||||
}
|
||||
)
|
||||
|
||||
# 重点修改:专家执行完后,必须回到 Supervisor 进行状态检查
|
||||
# 如果 Supervisor 发现专家刚说完话,它会触发上面的逻辑返回 FINISH
|
||||
# 专家执行完依然回到 Supervisor
|
||||
workflow.add_edge("Designer", "Supervisor")
|
||||
workflow.add_edge("Researcher", "Supervisor")
|
||||
workflow.add_edge("Visualizer", "Supervisor")
|
||||
|
||||
# 重点:Suggester 是整个流程的终点
|
||||
workflow.add_edge("Suggester", END)
|
||||
|
||||
client = MongoClient(MONGO_URI)
|
||||
checkpointer = MongoDBSaver(
|
||||
client=client["furniture_agent_db"],
|
||||
|
||||
@@ -1,4 +1,30 @@
|
||||
import base64
|
||||
import uuid
|
||||
from google.oauth2 import service_account
|
||||
from langchain_core.tools import tool
|
||||
from google import genai
|
||||
from google.genai.types import GenerateContentConfig, Modality
|
||||
from PIL import Image
|
||||
from io import BytesIO
|
||||
|
||||
from minio import Minio
|
||||
|
||||
from src.core.config import settings
|
||||
from src.server.utils.new_oss_client import oss_upload_image
|
||||
|
||||
# 初始化全局凭证和客户端
|
||||
creds = service_account.Credentials.from_service_account_file(
|
||||
settings.GOOGLE_GENAI_USE_VERTEXAI,
|
||||
scopes=["https://www.googleapis.com/auth/cloud-platform"],
|
||||
)
|
||||
|
||||
minio_client = Minio(settings.MINIO_URL, access_key=settings.MINIO_ACCESS, secret_key=settings.MINIO_SECRET, secure=settings.MINIO_SECURE)
|
||||
client = genai.Client(
|
||||
credentials=creds,
|
||||
project=settings.GOOGLE_CLOUD_PROJECT,
|
||||
location=settings.GOOGLE_CLOUD_LOCATION,
|
||||
vertexai=True
|
||||
)
|
||||
|
||||
|
||||
# --- 模拟你已经开发好的报告生成功能 ---
|
||||
@@ -13,13 +39,76 @@ def generate_2025_report_tool(topic: str) -> str:
|
||||
return f"【报告生成成功】已生成关于 {topic} 的 PDF 报告。核心洞察:2025年趋势倾向于生物嗜好设计(Biophilic Design)和可持续软木材质。"
|
||||
|
||||
|
||||
# --- 绘图工具 ---
|
||||
# --- 2. 绘图工具 (接入 Nano Banana 逻辑) ---
|
||||
@tool
|
||||
def generate_furniture_sketch(prompt: str) -> str:
|
||||
"""
|
||||
用于生成家具草图。输入必须是详细的英文绘画提示词(Prompt)。
|
||||
使用 Gemini 图像生成模型根据详细的英文提示词生成家具设计草图。
|
||||
"""
|
||||
print(f"\n[系统日志] 正在调用 Gemini/Imagen 绘图 API,Prompt: {prompt}...")
|
||||
# 在真实场景中,这里调用 Google Imagen API 或 Midjourney API
|
||||
# 示例返回一个模拟的图片链接
|
||||
return "https://furniture-design-db.com/generated_sketch_v1.jpg"
|
||||
print(f"\n[系统日志] 正在调用 Nano Banana (Gemini Image Gen) ...")
|
||||
|
||||
try:
|
||||
response = client.models.generate_content(
|
||||
model="gemini-2.5-flash-image",
|
||||
contents=(f"Generate a professional furniture design sketch: {prompt}"),
|
||||
config=GenerateContentConfig(
|
||||
response_modalities=[Modality.TEXT, Modality.IMAGE],
|
||||
),
|
||||
)
|
||||
|
||||
image_bytes = None
|
||||
for part in response.candidates[0].content.parts:
|
||||
if part.inline_data:
|
||||
image_bytes = part.inline_data.data
|
||||
break
|
||||
|
||||
if not image_bytes:
|
||||
return "未能生成图像数据。"
|
||||
object_name = f"furniture/sketches/{uuid.uuid4()}.png"
|
||||
bucket = "fida-test" # 替换为你的 bucket 名称
|
||||
# 3. 调用你的上传函数
|
||||
upload_res = oss_upload_image(
|
||||
oss_client=minio_client,
|
||||
bucket=bucket,
|
||||
object_name=object_name,
|
||||
image_bytes=image_bytes
|
||||
)
|
||||
|
||||
if upload_res:
|
||||
# 4. 构造访问链接 (如果是私有 bucket,需使用 presigned_get_object)
|
||||
# 这里简单示例为直接访问地址
|
||||
image_url = f"{bucket}/{object_name}"
|
||||
return image_url
|
||||
else:
|
||||
return "图片生成成功,但上传至存储服务器失败。"
|
||||
except Exception as e:
|
||||
return f"绘图流程异常: {str(e)}"
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
print(generate_furniture_sketch("椅子"))
|
||||
# creds = service_account.Credentials.from_service_account_file(
|
||||
# settings.GOOGLE_GENAI_USE_VERTEXAI,
|
||||
# scopes=["https://www.googleapis.com/auth/cloud-platform"],
|
||||
# )
|
||||
# client = genai.Client(
|
||||
# credentials=creds,
|
||||
# project=settings.GOOGLE_CLOUD_PROJECT,
|
||||
# location=settings.GOOGLE_CLOUD_LOCATION,
|
||||
# vertexai=True
|
||||
# )
|
||||
#
|
||||
# response = client.models.generate_content(
|
||||
# model="gemini-2.5-flash-image",
|
||||
# contents=("Generate an image of the Eiffel tower with fireworks in the background."),
|
||||
# config=GenerateContentConfig(
|
||||
# response_modalities=[Modality.TEXT, Modality.IMAGE],
|
||||
# ),
|
||||
# )
|
||||
#
|
||||
# for part in response.candidates[0].content.parts:
|
||||
# if part.text:
|
||||
# print(part.text)
|
||||
# elif part.inline_data:
|
||||
# image = Image.open(BytesIO((part.inline_data.data)))
|
||||
# image.save("example-image-eiffel-tower.png")
|
||||
|
||||
0
src/server/utils/__init__.py
Normal file
0
src/server/utils/__init__.py
Normal file
43
src/server/utils/generate_suggestion.py
Normal file
43
src/server/utils/generate_suggestion.py
Normal file
@@ -0,0 +1,43 @@
|
||||
from langchain_core.prompts import ChatPromptTemplate
|
||||
from langchain_core.output_parsers import JsonOutputParser
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
# 定义输出结构,保证稳定性
|
||||
class SuggestionOutput(BaseModel):
|
||||
suggestions: list[str] = Field(description="A list of 3 short follow-up questions or actions for the user, max 10 chars each.")
|
||||
|
||||
|
||||
async def generate_chat_suggestions(messages, model) -> list[str]:
|
||||
"""
|
||||
根据对话历史生成 3 个推荐追问按钮
|
||||
"""
|
||||
# 只需要最近的几次交互即可判断意图
|
||||
recent_msgs = messages[-4:]
|
||||
|
||||
parser = JsonOutputParser(pydantic_object=SuggestionOutput)
|
||||
|
||||
prompt = ChatPromptTemplate.from_messages([
|
||||
("system", """
|
||||
你是家具设计系统的交互助手。请根据用户的对话历史,预测用户接下来最可能想做的 3 件事。
|
||||
|
||||
【判断逻辑】
|
||||
1. 如果用户已经确定了【类型、材质、风格】但还没有生成过草图 -> 必须推荐 "生成设计草图"。
|
||||
2. 如果刚生成了草图 -> 推荐 "调整材质"、"查看三维视图"、"下载报价单" 等。
|
||||
3. 如果用户还在犹豫 -> 推荐具体的风格或材质询问。
|
||||
|
||||
请直接输出 JSON 格式,包含 suggestions 字段。按钮文案要简短(中文,不超过8个字)。
|
||||
"""),
|
||||
("user", "对话历史:{history}"),
|
||||
])
|
||||
|
||||
chain = prompt | model | parser
|
||||
|
||||
try:
|
||||
# 将消息对象转为字符串喂给模型
|
||||
history_str = "\n".join([f"{m.type}: {m.content}" for m in recent_msgs])
|
||||
result = await chain.ainvoke({"history": history_str})
|
||||
return result.get("suggestions", [])
|
||||
except Exception as e:
|
||||
print(f"建议生成失败: {e}")
|
||||
return []
|
||||
65
src/server/utils/new_oss_client.py
Normal file
65
src/server/utils/new_oss_client.py
Normal file
@@ -0,0 +1,65 @@
|
||||
import io
|
||||
import logging
|
||||
from io import BytesIO
|
||||
import urllib3
|
||||
from PIL import Image
|
||||
from minio import Minio
|
||||
|
||||
from src.core.config import settings
|
||||
|
||||
minio_client = Minio(settings.MINIO_URL, access_key=settings.MINIO_ACCESS, secret_key=settings.MINIO_SECRET, secure=settings.MINIO_SECURE)
|
||||
|
||||
|
||||
# 自定义 Retry 类
|
||||
class CustomRetry(urllib3.Retry):
|
||||
def increment(self, method=None, url=None, response=None, error=None, **kwargs):
|
||||
# 调用父类的 increment 方法
|
||||
new_retry = super(CustomRetry, self).increment(method, url, response, error, **kwargs)
|
||||
# 打印重试信息
|
||||
logger.info(f"重试连接: {method} {url},错误: {error},重试次数: {self.total - new_retry.total}")
|
||||
return new_retry
|
||||
|
||||
|
||||
logger = logging.getLogger()
|
||||
timeout = urllib3.Timeout(connect=1, read=10.0) # 连接超时 5 秒,读取超时 10 秒
|
||||
http_client = urllib3.PoolManager(
|
||||
num_pools=10, # 设置连接池大小
|
||||
maxsize=10,
|
||||
timeout=timeout,
|
||||
cert_reqs='CERT_REQUIRED', # 需要证书验证
|
||||
retries=CustomRetry(
|
||||
total=5,
|
||||
backoff_factor=0.2,
|
||||
status_forcelist=[500, 502, 503, 504],
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
# 获取图片
|
||||
def oss_get_image(oss_client, bucket, object_name, data_type):
|
||||
# cv2 默认全通道读取
|
||||
image_object = None
|
||||
try:
|
||||
image_data = oss_client.get_object(bucket_name=bucket, object_name=object_name)
|
||||
data_bytes = BytesIO(image_data.read())
|
||||
image_object = Image.open(data_bytes)
|
||||
except Exception as e:
|
||||
logger.warning(f" | 获取图片出现异常 ######: {e}")
|
||||
return image_object
|
||||
|
||||
|
||||
def oss_upload_image(oss_client, bucket, object_name, image_bytes):
|
||||
req = None
|
||||
try:
|
||||
req = oss_client.put_object(bucket_name=bucket, object_name=object_name, data=io.BytesIO(image_bytes), length=len(image_bytes), content_type='image/png')
|
||||
except Exception as e:
|
||||
logger.warning(f" | 上传图片出现异常 ######: {e}")
|
||||
return req
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
url = "aida-users/89/sketch/123-89.png"
|
||||
read_type = "2"
|
||||
img = oss_get_image(oss_client=minio_client, bucket=url.split('/')[0], object_name=url[url.find('/') + 1:], data_type=read_type)
|
||||
img.show()
|
||||
img.save("result.png")
|
||||
Reference in New Issue
Block a user