aida agent (基础版)搭建完成
This commit is contained in:
159
app/service/fashion_agent/graph_node/design_graph/graph.py
Normal file
159
app/service/fashion_agent/graph_node/design_graph/graph.py
Normal file
@@ -0,0 +1,159 @@
|
||||
import logging
|
||||
from typing import Annotated, Required, TypedDict
|
||||
|
||||
from langchain_core.messages import AIMessage, AnyMessage
|
||||
from langgraph.graph import END, START, StateGraph
|
||||
from langgraph.graph.message import add_messages
|
||||
|
||||
from app.service.fashion_agent.graph_node.design_graph.tools import design_tool # noqa: E402
|
||||
|
||||
logger = logging.getLogger()
|
||||
|
||||
|
||||
"""定义状态"""
|
||||
|
||||
|
||||
class DesignState(TypedDict):
|
||||
messages: Required[Annotated[list[AnyMessage], add_messages]]
|
||||
|
||||
design_request_data: dict = {}
|
||||
request_objects: list[dict] = []
|
||||
results: list[dict] = []
|
||||
|
||||
|
||||
"""节点"""
|
||||
|
||||
|
||||
def run_design_node(state: DesignState) -> dict:
|
||||
"""调用 design_tool 执行设计任务,逐个推送结果"""
|
||||
from langgraph.config import get_stream_writer
|
||||
|
||||
writer = get_stream_writer()
|
||||
|
||||
request_data = state.get("design_request_data")
|
||||
request_objects = request_data.get("objects")
|
||||
|
||||
results = []
|
||||
for item in design_tool.invoke({"objects": request_objects}):
|
||||
logger.info(f"design result: {item}")
|
||||
results.append(item)
|
||||
writer({"event": "tool-output-delta", "tool_name": "design_tool", "type": "design_result", "data": item})
|
||||
|
||||
writer({"event": "tool-finished", "tool_name": "design_tool", "type": "design_result", "data": results})
|
||||
result_text = f"设计完成,共处理 {len(results)} 个对象"
|
||||
return {"results": results, "messages": [AIMessage(content=result_text)]}
|
||||
|
||||
|
||||
"""构建 Graph"""
|
||||
|
||||
|
||||
def build_design_graph():
|
||||
"""构建 design graph"""
|
||||
workflow = StateGraph(DesignState)
|
||||
|
||||
workflow.add_node("run_design", run_design_node)
|
||||
|
||||
workflow.add_edge(START, "run_design")
|
||||
workflow.add_edge("run_design", END)
|
||||
|
||||
return workflow.compile()
|
||||
|
||||
|
||||
design_graph = build_design_graph()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
request_data = {
|
||||
"objects": [
|
||||
{
|
||||
"basic": {
|
||||
"body_point_test": {
|
||||
"waistband_right": [203, 249],
|
||||
"hand_point_right": [229, 343],
|
||||
"waistband_left": [119, 248],
|
||||
"hand_point_left": [97, 343],
|
||||
"shoulder_left": [108, 107],
|
||||
"relation_type": "System",
|
||||
"shoulder_right": [212, 107],
|
||||
"relation_id": 1020356,
|
||||
},
|
||||
"layer_order": False,
|
||||
"scale_bag": 0.7,
|
||||
"scale_earrings": 0.16,
|
||||
"self_template": False,
|
||||
"single_overall": "overall",
|
||||
"switch_category": "",
|
||||
},
|
||||
"items": [
|
||||
{
|
||||
"color": "209 196 171",
|
||||
"image_id": 84093,
|
||||
"offset": [1, 1],
|
||||
"path": "aida-users/89/sketchboard/female/Outwear/0943d209-7ce0-408c-bc61-83f15da94138.png",
|
||||
"print": {
|
||||
"element": {"element_angle_list": [], "element_path_list": [], "element_scale_list": [], "location": []},
|
||||
"overall": {
|
||||
"location": [[0.0, 0.0]],
|
||||
"print_angle_list": [0.0, 0.0],
|
||||
"print_path_list": [],
|
||||
"print_scale_list": [[0.0, 0.0]],
|
||||
},
|
||||
"single": {"location": [], "print_angle_list": [], "print_path_list": [], "print_scale_list": []},
|
||||
},
|
||||
"resize_scale": [1.0, 1.0],
|
||||
"type": "Outwear",
|
||||
},
|
||||
{
|
||||
"color": "63 71 73",
|
||||
"image_id": 100496,
|
||||
"offset": [1, 1],
|
||||
"path": "aida-sys-image/images/female/blouse/0628001684.jpg",
|
||||
"print": {
|
||||
"element": {"element_angle_list": [], "element_path_list": [], "element_scale_list": [], "location": []},
|
||||
"overall": {
|
||||
"location": [[0.0, 0.0]],
|
||||
"print_angle_list": [0.0, 0.0],
|
||||
"print_path_list": [],
|
||||
"print_scale_list": [[0.0, 0.0]],
|
||||
},
|
||||
"single": {"location": [], "print_angle_list": [], "print_path_list": [], "print_scale_list": []},
|
||||
},
|
||||
"resize_scale": [1.0, 1.0],
|
||||
"type": "Blouse",
|
||||
},
|
||||
{
|
||||
"color": "111 78 63",
|
||||
"gradient": "aida-gradient/f69b98e8-4248-4f7a-98a2-21bac41bf3e0.png",
|
||||
"image_id": 92193,
|
||||
"offset": [1, 1],
|
||||
"path": "aida-sys-image/images/female/trousers/0825001160.jpg",
|
||||
"print": {
|
||||
"element": {"element_angle_list": [], "element_path_list": [], "element_scale_list": [], "location": []},
|
||||
"overall": {
|
||||
"location": [[0.0, 0.0]],
|
||||
"print_angle_list": [0.0, 0.0],
|
||||
"print_path_list": [],
|
||||
"print_scale_list": [[0.0, 0.0]],
|
||||
},
|
||||
"single": {"location": [], "print_angle_list": [], "print_path_list": [], "print_scale_list": []},
|
||||
},
|
||||
"resize_scale": [1.0, 1.0],
|
||||
"type": "Trousers",
|
||||
},
|
||||
{
|
||||
"body_path": "aida-sys-image/models/female/2e4815b9-1191-419d-94ed-5771239ca4a5.png",
|
||||
"image_id": 67277,
|
||||
"offset": [1, 1],
|
||||
"resize_scale": [1.0, 1.0],
|
||||
"type": "Body",
|
||||
},
|
||||
],
|
||||
"objectSign": "65830966",
|
||||
}
|
||||
],
|
||||
"process_id": "4802946666428422",
|
||||
"requestId": "1d1e7641-0d62-4da2-adc0-b4404910723c",
|
||||
"callback_url": "https://api.aida.com.hk/api/third/party/receiveDesignResults",
|
||||
}
|
||||
result = design_graph.invoke({"design_request_data": request_data})
|
||||
print(result)
|
||||
206
app/service/fashion_agent/graph_node/design_graph/tools.py
Normal file
206
app/service/fashion_agent/graph_node/design_graph/tools.py
Normal file
@@ -0,0 +1,206 @@
|
||||
import logging
|
||||
import queue
|
||||
import threading
|
||||
|
||||
from langchain.tools import tool
|
||||
from pydantic import BaseModel, Field
|
||||
from app.service.design_fast.design_generate import process_item, process_layer
|
||||
from app.service.design_fast.utils.synthesis_item import synthesis, synthesis_single, update_base_size_priority
|
||||
|
||||
logger = logging.getLogger()
|
||||
|
||||
|
||||
class DesignModel(BaseModel):
|
||||
objects: list[dict] = Field(description="")
|
||||
|
||||
|
||||
@tool(args_schema=DesignModel, description="design tool")
|
||||
def design_tool(objects: list[dict]):
|
||||
"""design tool"""
|
||||
|
||||
result_queue = queue.Queue()
|
||||
|
||||
def process_object(obj):
|
||||
basic = obj["basic"]
|
||||
design_type = basic.get("design_type", "default")
|
||||
items_response = {
|
||||
"layers": [],
|
||||
"objectSign": obj["objectSign"] if "objectSign" in obj.keys() else "",
|
||||
}
|
||||
if basic["single_overall"] == "overall":
|
||||
item_results = []
|
||||
for item in obj["items"]:
|
||||
item_results.append(process_item(item, basic, design_type))
|
||||
layers = []
|
||||
for item in item_results:
|
||||
process_layer(item, layers)
|
||||
layers = sorted(layers, key=lambda s: s.get("priority", float("inf")))
|
||||
|
||||
layers, new_size = update_base_size_priority(layers)
|
||||
|
||||
for lay in layers:
|
||||
items_response["layers"].append(
|
||||
{
|
||||
"image_category": "body" if lay["name"] == "mannequin" else lay["name"],
|
||||
"position": lay["position"],
|
||||
"priority": lay.get("priority", None),
|
||||
"resize_scale": lay["resize_scale"] if "resize_scale" in lay.keys() else None,
|
||||
"image_size": lay["image"] if lay["image"] is None else lay["image"].size,
|
||||
"gradient_string": lay["gradient_string"] if "gradient_string" in lay.keys() else "",
|
||||
"mask_url": lay["mask_url"],
|
||||
"image_url": lay["image_url"] if "image_url" in lay.keys() else None,
|
||||
"pattern_overall_image_url": (
|
||||
lay["pattern_overall_image_url"] if "pattern_overall_image_url" in lay.keys() else None
|
||||
),
|
||||
"pattern_print_image_url": lay["pattern_print_image_url"] if "pattern_print_image_url" in lay.keys() else None,
|
||||
}
|
||||
)
|
||||
items_response["synthesis_url"] = synthesis(layers, new_size, basic)
|
||||
else:
|
||||
item_result = process_item(obj["items"][0], basic, design_type)
|
||||
items_response["layers"].append(
|
||||
{
|
||||
"image_category": f"{item_result['name']}_front",
|
||||
"image_size": item_result["back_image"].size if item_result["back_image"] else None,
|
||||
"position": None,
|
||||
"priority": 0,
|
||||
"image_url": item_result["front_image_url"],
|
||||
"mask_url": item_result["mask_url"],
|
||||
"gradient_string": item_result["gradient_string"] if "gradient_string" in item_result.keys() else "",
|
||||
"pattern_overall_image_url": (
|
||||
item_result["pattern_overall_image_url"] if "pattern_overall_image_url" in item_result.keys() else None
|
||||
),
|
||||
"pattern_print_image_url": (
|
||||
item_result["pattern_print_image_url"] if "pattern_print_image_url" in item_result.keys() else None
|
||||
),
|
||||
}
|
||||
)
|
||||
items_response["layers"].append(
|
||||
{
|
||||
"image_category": f"{item_result['name']}_back",
|
||||
"image_size": item_result["front_image"].size if item_result["front_image"] else None,
|
||||
"position": None,
|
||||
"priority": 0,
|
||||
"image_url": item_result["back_image_url"],
|
||||
"mask_url": item_result["mask_url"],
|
||||
"gradient_string": item_result["gradient_string"] if "gradient_string" in item_result.keys() else "",
|
||||
"pattern_overall_image_url": (
|
||||
item_result["pattern_overall_image_url"] if "pattern_overall_image_url" in item_result.keys() else None
|
||||
),
|
||||
"pattern_print_image_url": (
|
||||
item_result["pattern_print_image_url"] if "pattern_print_image_url" in item_result.keys() else None
|
||||
),
|
||||
}
|
||||
)
|
||||
items_response["synthesis_url"] = synthesis_single(item_result["front_image"], item_result["back_image"])
|
||||
logger.info(items_response)
|
||||
result_queue.put(items_response)
|
||||
|
||||
# 启动所有线程
|
||||
threads = []
|
||||
for obj in objects:
|
||||
t = threading.Thread(target=process_object, args=(obj,))
|
||||
threads.append(t)
|
||||
t.start()
|
||||
|
||||
# 主线程逐个取出结果 yield
|
||||
finished = 0
|
||||
total = len(objects)
|
||||
while finished < total:
|
||||
result = result_queue.get()
|
||||
yield result
|
||||
finished += 1
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
request_objects = [
|
||||
{
|
||||
"basic": {
|
||||
"body_point_test": {
|
||||
"waistband_right": [203, 249],
|
||||
"hand_point_right": [229, 343],
|
||||
"waistband_left": [119, 248],
|
||||
"hand_point_left": [97, 343],
|
||||
"shoulder_left": [108, 107],
|
||||
"relation_type": "System",
|
||||
"shoulder_right": [212, 107],
|
||||
"relation_id": 1020356,
|
||||
},
|
||||
"layer_order": False,
|
||||
"scale_bag": 0.7,
|
||||
"scale_earrings": 0.16,
|
||||
"self_template": False,
|
||||
"single_overall": "overall",
|
||||
"switch_category": "",
|
||||
},
|
||||
"items": [
|
||||
{
|
||||
"color": "209 196 171",
|
||||
"image_id": 84093,
|
||||
"offset": [1, 1],
|
||||
"path": "aida-users/89/sketchboard/female/Outwear/0943d209-7ce0-408c-bc61-83f15da94138.png",
|
||||
"print": {
|
||||
"element": {"element_angle_list": [], "element_path_list": [], "element_scale_list": [], "location": []},
|
||||
"overall": {
|
||||
"location": [[0.0, 0.0]],
|
||||
"print_angle_list": [0.0, 0.0],
|
||||
"print_path_list": [],
|
||||
"print_scale_list": [[0.0, 0.0]],
|
||||
},
|
||||
"single": {"location": [], "print_angle_list": [], "print_path_list": [], "print_scale_list": []},
|
||||
},
|
||||
"resize_scale": [1.0, 1.0],
|
||||
"type": "Outwear",
|
||||
},
|
||||
{
|
||||
"color": "63 71 73",
|
||||
"image_id": 100496,
|
||||
"offset": [1, 1],
|
||||
"path": "aida-sys-image/images/female/blouse/0628001684.jpg",
|
||||
"print": {
|
||||
"element": {"element_angle_list": [], "element_path_list": [], "element_scale_list": [], "location": []},
|
||||
"overall": {
|
||||
"location": [[0.0, 0.0]],
|
||||
"print_angle_list": [0.0, 0.0],
|
||||
"print_path_list": [],
|
||||
"print_scale_list": [[0.0, 0.0]],
|
||||
},
|
||||
"single": {"location": [], "print_angle_list": [], "print_path_list": [], "print_scale_list": []},
|
||||
},
|
||||
"resize_scale": [1.0, 1.0],
|
||||
"type": "Blouse",
|
||||
},
|
||||
{
|
||||
"color": "111 78 63",
|
||||
"gradient": "aida-gradient/f69b98e8-4248-4f7a-98a2-21bac41bf3e0.png",
|
||||
"image_id": 92193,
|
||||
"offset": [1, 1],
|
||||
"path": "aida-sys-image/images/female/trousers/0825001160.jpg",
|
||||
"print": {
|
||||
"element": {"element_angle_list": [], "element_path_list": [], "element_scale_list": [], "location": []},
|
||||
"overall": {
|
||||
"location": [[0.0, 0.0]],
|
||||
"print_angle_list": [0.0, 0.0],
|
||||
"print_path_list": [],
|
||||
"print_scale_list": [[0.0, 0.0]],
|
||||
},
|
||||
"single": {"location": [], "print_angle_list": [], "print_path_list": [], "print_scale_list": []},
|
||||
},
|
||||
"resize_scale": [1.0, 1.0],
|
||||
"type": "Trousers",
|
||||
},
|
||||
{
|
||||
"body_path": "aida-sys-image/models/female/2e4815b9-1191-419d-94ed-5771239ca4a5.png",
|
||||
"image_id": 67277,
|
||||
"offset": [1, 1],
|
||||
"resize_scale": [1.0, 1.0],
|
||||
"type": "Body",
|
||||
},
|
||||
],
|
||||
"objectSign": "65830966",
|
||||
}
|
||||
]
|
||||
|
||||
result = design_tool.invoke({"objects": request_objects})
|
||||
for item in result:
|
||||
print(item)
|
||||
138
app/service/fashion_agent/graph_node/explorer_graph/graph.py
Normal file
138
app/service/fashion_agent/graph_node/explorer_graph/graph.py
Normal file
@@ -0,0 +1,138 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Annotated, Required, TypedDict
|
||||
|
||||
from langchain.tools import ToolRuntime
|
||||
from langchain_core.messages import AIMessage, AnyMessage, HumanMessage, SystemMessage
|
||||
from langchain_qwq import ChatQwen
|
||||
from langgraph.graph import END, START, StateGraph
|
||||
from langgraph.graph.message import add_messages
|
||||
from pydantic import BaseModel, Field
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from app.service.fashion_agent.graph_node.explorer_graph.tools import explor_tool
|
||||
from app.service.fashion_agent.init_llm import build_llm
|
||||
|
||||
logger = logging.getLogger()
|
||||
|
||||
|
||||
"""定义状态"""
|
||||
|
||||
|
||||
class ExplorerState(TypedDict):
|
||||
messages: Required[Annotated[list[AnyMessage], add_messages]]
|
||||
input_text: str
|
||||
search_query: str
|
||||
image_results: list[dict] # 每项包含 image_url 和 minio_path
|
||||
provider: str = "unsplash" # 图片源: "pexels" 或 "unsplash"
|
||||
|
||||
|
||||
"""节点"""
|
||||
|
||||
|
||||
def extract_input_node(state: ExplorerState) -> dict:
|
||||
"""从 messages 中提取用户输入"""
|
||||
input_text = state["messages"][0].content if state.get("messages") else ""
|
||||
return {"input_text": input_text}
|
||||
|
||||
|
||||
class SearchQuery(BaseModel):
|
||||
"""搜索关键词"""
|
||||
|
||||
query: str = Field(description="用于搜索灵感图片的英文关键词,简洁有力")
|
||||
|
||||
|
||||
# TODO 要考虑搜索图片失败或者图片不存在的情况, 搜索不到 需要调整搜索词或者拆分搜索词,最终失败的话调用mood board生成工具生成, 保证绝对有图片
|
||||
async def generate_query_node(state: ExplorerState) -> dict:
|
||||
"""使用 LLM 分析用户输入,生成搜索关键词"""
|
||||
input_text = state["input_text"]
|
||||
logger.info(f"[Explorer] 用户输入: {input_text}")
|
||||
llm = build_llm()
|
||||
|
||||
structured_llm = llm.with_structured_output(SearchQuery)
|
||||
|
||||
messages = [
|
||||
SystemMessage(content="""你是一个专业的服装设计师助手。
|
||||
根据用户输入,生成一个英文搜索关键词,用于在图片库中搜索服装设计灵感图片(moodboard)。
|
||||
|
||||
要求:
|
||||
1. 使用英文,简洁有力
|
||||
2. 适合搜索高质量的设计灵感图片
|
||||
|
||||
例如:
|
||||
用户输入:"夏季连衣裙,清新风格"
|
||||
输出:summer dress fresh style"""),
|
||||
HumanMessage(content=input_text),
|
||||
]
|
||||
|
||||
result = structured_llm.invoke(messages)
|
||||
logger.info(f"[Explorer] LLM 生成的搜索关键词: {result.query}")
|
||||
return {"search_query": result.query}
|
||||
|
||||
|
||||
async def search_and_upload_node(state: ExplorerState, config: RunnableConfig) -> dict:
|
||||
"""使用搜索关键词获取图片并上传到 minio"""
|
||||
query = state.get("search_query", "")
|
||||
user_id = state.get("user_id", "agent")
|
||||
provider = state.get("provider", "unsplash")
|
||||
|
||||
try:
|
||||
results = await explor_tool.ainvoke({"query": query, "per_page": 4, "user_id": user_id, "method": provider}, config=config)
|
||||
except Exception as e:
|
||||
logger.error(f"[Explorer] 搜索失败 '{query}': {e}")
|
||||
results = []
|
||||
|
||||
return {"image_results": results}
|
||||
|
||||
|
||||
def summarize_node(state: ExplorerState) -> dict:
|
||||
"""汇总结果"""
|
||||
input_text = state.get("input_text", "")
|
||||
query = state.get("search_query", "")
|
||||
results = state.get("image_results", [])
|
||||
|
||||
result_text = f"【灵感探索 Moodboard】\n\n"
|
||||
result_text += f"基于您的需求:「{input_text}」\n"
|
||||
result_text += f"搜索关键词:{query}\n\n"
|
||||
result_text += f"已为您找到 {len(results)} 张灵感图片:\n"
|
||||
|
||||
for i, item in enumerate(results, 1):
|
||||
result_text += f" {i}. 原图: {item.get('image_url', '')}\n"
|
||||
result_text += f" Minio: {item.get('minio_path', '')}\n"
|
||||
|
||||
return {"messages": [AIMessage(content=result_text)]}
|
||||
|
||||
|
||||
"""构建图"""
|
||||
|
||||
|
||||
def build_explorer_graph():
|
||||
"""构建灵感探索图"""
|
||||
workflow = StateGraph(ExplorerState)
|
||||
|
||||
workflow.add_node("extract_input", extract_input_node)
|
||||
workflow.add_node("generate_query", generate_query_node)
|
||||
workflow.add_node("search_and_upload", search_and_upload_node)
|
||||
workflow.add_node("summarize", summarize_node)
|
||||
|
||||
workflow.add_edge(START, "extract_input")
|
||||
workflow.add_edge("extract_input", "generate_query")
|
||||
workflow.add_edge("generate_query", "search_and_upload")
|
||||
workflow.add_edge("search_and_upload", "summarize")
|
||||
workflow.add_edge("summarize", END)
|
||||
|
||||
return workflow.compile()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
async def test():
|
||||
graph = build_explorer_graph()
|
||||
result = await graph.ainvoke(
|
||||
{
|
||||
"messages": [HumanMessage(content="夏季连衣裙,清新自然风格")],
|
||||
"provider": "unsplash",
|
||||
}
|
||||
)
|
||||
print(result["messages"][-1].content)
|
||||
|
||||
asyncio.run(test())
|
||||
54
app/service/fashion_agent/graph_node/explorer_graph/tools.py
Normal file
54
app/service/fashion_agent/graph_node/explorer_graph/tools.py
Normal file
@@ -0,0 +1,54 @@
|
||||
from langchain.tools import ToolRuntime, tool
|
||||
from pydantic import BaseModel, Field
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from app.service.fashion_agent.graph_node.node_tools.pexels_search import search_photos
|
||||
from app.service.fashion_agent.graph_node.node_tools.unsplash_search import get_random_photos
|
||||
|
||||
|
||||
class SearchInput(BaseModel):
|
||||
"""Input schema for Pexels Search Tool."""
|
||||
|
||||
query: str = Field(description="Search query for Pexels, e.g., 'minimalist fashion moodboard', 'summer dress inspiration'")
|
||||
per_page: int = Field(description="Number of images to return (1-80)", default=4)
|
||||
user_id: str = Field(description="User ID for image storage", default="agent")
|
||||
method: str = Field(description="", default="unsplash")
|
||||
|
||||
|
||||
@tool(args_schema=SearchInput)
|
||||
async def explor_tool(
|
||||
query: str, per_page: int = 4, user_id: str = "agent", method: str = "unsplash", config: RunnableConfig = None
|
||||
) -> list[dict]:
|
||||
"""Search for fashion inspiration images on Unsplash and upload to minio. Returns a list of dicts with image_url and minio_path."""
|
||||
if config:
|
||||
# 方式 1:从 configurable 获取
|
||||
user_id = config.get("configurable", {}).get("user_id", "agent")
|
||||
|
||||
if method == "unsplash":
|
||||
return await get_random_photos(query, count=per_page, user_id=user_id)
|
||||
elif method == "pexels":
|
||||
return await search_photos(query, per_page=per_page, user_id=user_id)
|
||||
else:
|
||||
pass
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
|
||||
async def test():
|
||||
urls = await get_random_photos("summer dress fresh natural style", count=4)
|
||||
print(f"Uploaded {len(urls)} images to minio:")
|
||||
for url in urls:
|
||||
print(f" {url}")
|
||||
|
||||
asyncio.run(test())
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
|
||||
async def test():
|
||||
urls = await search_photos("minimalist fashion moodboard", per_page=4)
|
||||
print(f"Uploaded {len(urls)} images to minio:")
|
||||
for url in urls:
|
||||
print(f" {url}")
|
||||
|
||||
asyncio.run(test())
|
||||
152
app/service/fashion_agent/graph_node/logo_graph/graph.py
Normal file
152
app/service/fashion_agent/graph_node/logo_graph/graph.py
Normal file
@@ -0,0 +1,152 @@
|
||||
import asyncio
|
||||
from typing import Annotated, Required, TypedDict
|
||||
from langchain_core.messages import AIMessage, AnyMessage
|
||||
from langchain_qwq import ChatQwen
|
||||
from langchain_core.messages import HumanMessage, SystemMessage
|
||||
from langgraph.graph import END, START, StateGraph
|
||||
from langgraph.graph.message import add_messages
|
||||
from pydantic import BaseModel, Field
|
||||
from app.service.fashion_agent.graph_node.node_tools.generate_logo import generate_logo_tool
|
||||
from app.service.fashion_agent.init_llm import qwen_plus_llm
|
||||
|
||||
"""初始化 LLM TODO 将 API Key 替换为环境变量或者配置文件中的值,避免在代码中硬编码敏感信息"""
|
||||
|
||||
|
||||
"""定义状态"""
|
||||
|
||||
|
||||
class LogoState(TypedDict):
|
||||
messages: Required[Annotated[list[AnyMessage], add_messages]]
|
||||
input_text: str
|
||||
user_id: str = "agent"
|
||||
role: str = ""
|
||||
gender: str = ""
|
||||
style: str = ""
|
||||
need_prompt_generation: bool = True # 是否需要使用 prompt 生成节点
|
||||
|
||||
logo_num: int = 1
|
||||
|
||||
logo_prompts: list[str] = []
|
||||
logo_img_urls: list[str] = []
|
||||
|
||||
|
||||
"""生成 Logo 的提示词节点"""
|
||||
|
||||
|
||||
# 定义输出结构
|
||||
class LogoPrompt(BaseModel):
|
||||
"""生成的 Logo 图像提示词"""
|
||||
|
||||
prompts: list[str] = Field(description="用于生成 Logo 的详细提示词")
|
||||
|
||||
|
||||
def extract_input_node(state: LogoState) -> dict:
|
||||
"""从 messages 中提取用户输入"""
|
||||
input_text = state["messages"][0].content if state.get("messages") else ""
|
||||
return {"input_text": input_text}
|
||||
|
||||
|
||||
def generate_logo_prompt_node(state: LogoState) -> dict:
|
||||
"""根据用户输入生成 Logo 的图像生成提示词"""
|
||||
structured_llm = qwen_plus_llm.with_structured_output(LogoPrompt)
|
||||
|
||||
messages = [
|
||||
SystemMessage(content="""从用户输入中提取核心主题词,只输出一个简单的英文单词。
|
||||
例如:
|
||||
- "我想要一个猫咪图案" -> "cat"
|
||||
- "设计一个花朵" -> "flower"
|
||||
- "可爱的狗" -> "dog"
|
||||
只输出单词,不要其他内容。"""),
|
||||
HumanMessage(content=state["input_text"]),
|
||||
]
|
||||
|
||||
result = structured_llm.invoke(messages)
|
||||
prompts = result.prompts
|
||||
|
||||
return {
|
||||
"logo_prompts": prompts,
|
||||
}
|
||||
|
||||
|
||||
"""生成 Logo 图案节点"""
|
||||
|
||||
|
||||
async def generate_logo_img_node(state: LogoState) -> dict:
|
||||
"""根据生成的提示词,生成 Logo 图案"""
|
||||
# 如果 logo_prompts 为空,使用 input_text 作为 prompt
|
||||
prompts = state["logo_prompts"] if state["logo_prompts"] else [state["input_text"]]
|
||||
|
||||
logo_img_urls = []
|
||||
for i in range(state.get("logo_num", 1)):
|
||||
image_url = await generate_logo_tool.ainvoke({"prompt": prompts[i], "user_id": state.get("user_id", "agent")})
|
||||
logo_img_urls.append(image_url)
|
||||
|
||||
result_text = f"Logo 生成完成,共生成 {len(logo_img_urls)} 张图片:\n"
|
||||
return {"logo_img_urls": logo_img_urls, "messages": [AIMessage(content=result_text)]}
|
||||
|
||||
|
||||
"""条件分支 判断是否需要生成 prompt"""
|
||||
|
||||
|
||||
def should_generate_prompt(state: LogoState) -> str:
|
||||
"""条件分支:判断是否需要生成 prompt"""
|
||||
if state.get("need_prompt_generation", True):
|
||||
return "gen_prompt"
|
||||
else:
|
||||
return "gen_logo"
|
||||
|
||||
|
||||
def build_logo_graph():
|
||||
"""构建独立的画像收集 Graph"""
|
||||
|
||||
workflow = StateGraph(LogoState)
|
||||
workflow.add_node("extract_input", extract_input_node)
|
||||
workflow.add_node("gen_prompt", generate_logo_prompt_node)
|
||||
workflow.add_node("gen_logo", generate_logo_img_node)
|
||||
|
||||
# 添加边
|
||||
workflow.add_edge(START, "extract_input")
|
||||
workflow.add_conditional_edges(
|
||||
"extract_input",
|
||||
should_generate_prompt,
|
||||
{
|
||||
"gen_prompt": "gen_prompt",
|
||||
"gen_logo": "gen_logo",
|
||||
},
|
||||
)
|
||||
workflow.add_edge("gen_prompt", "gen_logo")
|
||||
workflow.add_edge("gen_logo", END)
|
||||
|
||||
graph = workflow.compile()
|
||||
|
||||
return graph
|
||||
|
||||
|
||||
async def main(test_input, user_id="agent", need_prompt_generation=True):
|
||||
graph = build_logo_graph()
|
||||
result = await graph.ainvoke(
|
||||
{
|
||||
"input_text": test_input,
|
||||
"user_id": user_id,
|
||||
"logo_prompts": [] if need_prompt_generation else [test_input],
|
||||
"need_prompt_generation": need_prompt_generation,
|
||||
"role": "",
|
||||
"gender": "",
|
||||
"style": "",
|
||||
}
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 测试示例 1: 需要 prompt 生成(默认)- 简单关键词输入
|
||||
test_input = "我想要一个金毛图案"
|
||||
result = asyncio.run(main(test_input, need_prompt_generation=True))
|
||||
print("=== 需要 prompt 生成 ===")
|
||||
print(f"Result: {result}")
|
||||
|
||||
# 测试示例 2: 直接使用用户提供的 prompt
|
||||
user_prompt = "golden retriever"
|
||||
result = asyncio.run(main(user_prompt, need_prompt_generation=False))
|
||||
print("\n=== 直接使用 prompt ===")
|
||||
print(f"Result: {result}")
|
||||
@@ -0,0 +1,27 @@
|
||||
import httpx
|
||||
|
||||
|
||||
async def generate_image(
|
||||
bucket_name="fida-public-bucket",
|
||||
object_name=f"furniture/sketches/123456.png",
|
||||
prompt="Generate a modern minimalist dining chair made of light "
|
||||
"oak wood and white leather, with slim metal legs, photographed "
|
||||
"in a bright Scandinavian living room with natural sunlight, high detail, "
|
||||
"8k resolution.",
|
||||
):
|
||||
request_data = {
|
||||
"input_image_paths": [],
|
||||
"prompt": prompt,
|
||||
"bucket_name": bucket_name,
|
||||
"object_name": object_name,
|
||||
"width": 1024,
|
||||
"height": 1024,
|
||||
}
|
||||
async with httpx.AsyncClient(timeout=120) as client:
|
||||
resp = await client.post(
|
||||
f"http://20.1.1.33:14202/predict",
|
||||
json=request_data,
|
||||
)
|
||||
result = resp.json()
|
||||
image_url = result.get("output_path", None)
|
||||
return image_url
|
||||
@@ -0,0 +1,79 @@
|
||||
import asyncio
|
||||
import concurrent.futures
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
import tritonclient.grpc as grpcclient
|
||||
from langchain.tools import tool
|
||||
from PIL import Image
|
||||
from pydantic import BaseModel, Field
|
||||
from tritonclient.utils import np_to_triton_dtype
|
||||
from uuid_utils import uuid7
|
||||
from app.core.config import settings
|
||||
from app.service.generate_image.utils.upload_sd_image import upload_SDXL_image
|
||||
|
||||
# 模型配置
|
||||
GSL_MODEL_URL = f"{settings.B_4_X_4090_SERVICE_HOST}:10041"
|
||||
GSL_MODEL_NAME = "stable_diffusion_xl_transparent"
|
||||
|
||||
# 线程池用于执行同步推理
|
||||
executor = concurrent.futures.ThreadPoolExecutor(max_workers=4)
|
||||
|
||||
|
||||
def _generate_logo_sync(prompt: str) -> Image.Image:
|
||||
"""同步生成 Logo 的内部函数"""
|
||||
seed = random.randint(0, 2**32 - 1)
|
||||
grpc_client = grpcclient.InferenceServerClient(url=GSL_MODEL_URL)
|
||||
|
||||
# 准备输入
|
||||
prompts = [prompt]
|
||||
text_obj = np.array(prompts, dtype="object").reshape((-1, 1))
|
||||
input_text = grpcclient.InferInput("prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype))
|
||||
input_text.set_data_from_numpy(text_obj)
|
||||
|
||||
negative_prompts = "bad, ugly"
|
||||
text_obj_neg = np.array(negative_prompts, dtype="object").reshape((-1, 1))
|
||||
input_text_neg = grpcclient.InferInput("negative_prompt", text_obj_neg.shape, np_to_triton_dtype(text_obj_neg.dtype))
|
||||
input_text_neg.set_data_from_numpy(text_obj_neg)
|
||||
|
||||
seed_input = np.array(seed, dtype="object").reshape((-1, 1))
|
||||
input_seed = grpcclient.InferInput("seed", seed_input.shape, np_to_triton_dtype(seed_input.dtype))
|
||||
input_seed.set_data_from_numpy(seed_input)
|
||||
|
||||
inputs = [input_text, input_text_neg, input_seed]
|
||||
|
||||
# 同步推理
|
||||
result = grpc_client.infer(model_name=GSL_MODEL_NAME, inputs=inputs)
|
||||
image = result.as_numpy("generated_image")
|
||||
return Image.fromarray(np.squeeze(image.astype(np.uint8)))
|
||||
|
||||
|
||||
async def generate_logo(prompt: str) -> Image.Image:
|
||||
"""异步生成透明背景的 Logo 图片"""
|
||||
loop = asyncio.get_event_loop()
|
||||
return await loop.run_in_executor(executor, _generate_logo_sync, prompt)
|
||||
|
||||
|
||||
class GenerateLogoToolInput(BaseModel):
|
||||
"""Input schema for the Generate Logo Tool."""
|
||||
|
||||
prompt: str = Field(description="Simple keyword for logo generation, e.g., 'cat', 'flower', 'dog'")
|
||||
user_id: str = Field(description="User ID for image storage", default="agent")
|
||||
|
||||
|
||||
@tool(args_schema=GenerateLogoToolInput)
|
||||
async def generate_logo_tool(prompt: str, user_id: str = "agent") -> str:
|
||||
"""Generate a transparent background logo image based on a simple keyword."""
|
||||
|
||||
image = await generate_logo(prompt=prompt)
|
||||
|
||||
# 上传到 minio(使用线程池避免阻塞事件循环)
|
||||
file_name = f"{uuid7()}.png"
|
||||
loop = asyncio.get_event_loop()
|
||||
image_url = await loop.run_in_executor(executor, upload_SDXL_image, image, user_id, "logo", file_name)
|
||||
return image_url
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
result = asyncio.run(generate_logo_tool.ainvoke({"prompt": "golden retriever"}))
|
||||
print(f"Logo saved to: {result}")
|
||||
@@ -0,0 +1,72 @@
|
||||
import asyncio
|
||||
import concurrent.futures
|
||||
import io
|
||||
import logging
|
||||
import os
|
||||
|
||||
import httpx
|
||||
from dotenv import load_dotenv
|
||||
from PIL import Image
|
||||
from uuid_utils import uuid7
|
||||
from app.service.generate_image.utils.upload_sd_image import upload_SDXL_image
|
||||
|
||||
load_dotenv()
|
||||
|
||||
logger = logging.getLogger()
|
||||
PEXELS_API_KEY = os.environ.get("PEXELS_API_KEY", "")
|
||||
PEXELS_BASE_URL = os.environ.get("PEXELS_BASE_URL", "")
|
||||
|
||||
# 线程池用于执行同步上传
|
||||
executor = concurrent.futures.ThreadPoolExecutor(max_workers=4)
|
||||
|
||||
|
||||
async def search_photos(query: str, per_page: int = 4, user_id: str = "agent") -> list[dict]:
|
||||
"""从 Pexels 搜索图片并上传到 minio
|
||||
|
||||
Args:
|
||||
query: 搜索关键词
|
||||
per_page: 返回图片数量 (1-80)
|
||||
user_id: 用户 ID
|
||||
|
||||
Returns:
|
||||
图片信息列表,每项包含 image_url 和 minio_path
|
||||
"""
|
||||
# 搜索图片
|
||||
async with httpx.AsyncClient(timeout=30) as client:
|
||||
response = await client.get(
|
||||
f"{PEXELS_BASE_URL}/search",
|
||||
headers={"Authorization": PEXELS_API_KEY},
|
||||
params={
|
||||
"query": query,
|
||||
"per_page": per_page,
|
||||
"orientation": "square",
|
||||
"size": "medium",
|
||||
},
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise Exception(f"Pexels API error: {response.status_code} - {response.text}")
|
||||
|
||||
data = response.json()
|
||||
photos = data.get("photos", [])
|
||||
|
||||
# 下载并上传到 minio
|
||||
results = []
|
||||
for photo in photos:
|
||||
try:
|
||||
# 下载图片(使用 large 尺寸)
|
||||
image_url = photo["src"]["original"]
|
||||
async with httpx.AsyncClient(timeout=60) as dl_client:
|
||||
dl_response = await dl_client.get(image_url)
|
||||
image = Image.open(io.BytesIO(dl_response.content))
|
||||
|
||||
# 上传到 minio(使用线程池避免阻塞事件循环)
|
||||
file_name = f"{uuid7()}.jpg"
|
||||
loop = asyncio.get_event_loop()
|
||||
minio_url = await loop.run_in_executor(executor, upload_SDXL_image, image, user_id, "explorer", file_name)
|
||||
results.append({"image_url": image_url, "minio_path": minio_url})
|
||||
logger.info(f"[Explorer] 上传成功: {minio_url}")
|
||||
except Exception as e:
|
||||
logger.error(f"[Explorer] 上传失败: {e}")
|
||||
|
||||
return results
|
||||
@@ -0,0 +1,90 @@
|
||||
import asyncio
|
||||
import concurrent.futures
|
||||
import io
|
||||
import logging
|
||||
import os
|
||||
import httpx
|
||||
|
||||
from PIL import Image
|
||||
from uuid_utils import uuid7
|
||||
from dotenv import load_dotenv
|
||||
from app.service.generate_image.utils.upload_sd_image import upload_SDXL_image
|
||||
|
||||
load_dotenv()
|
||||
|
||||
# Unsplash API 配置
|
||||
UNSPLASH_ACCESS_KEY = os.environ.get("UNSPLASH_ACCESS_KEY", "")
|
||||
UNSPLASH_BASE_URL = os.environ.get("UNSPLASH_BASE_URL", "")
|
||||
logger = logging.getLogger()
|
||||
# 线程池用于执行同步上传
|
||||
executor = concurrent.futures.ThreadPoolExecutor(max_workers=4)
|
||||
|
||||
|
||||
async def get_random_photos(query: str, count: int = 4, user_id: str = "agent") -> list[dict]:
|
||||
"""从 Unsplash 获取随机图片并上传到 minio
|
||||
|
||||
Args:
|
||||
query: 搜索关键词
|
||||
count: 返回图片数量 (1-30)
|
||||
user_id: 用户 ID
|
||||
|
||||
Returns:
|
||||
图片信息列表,每项包含 image_url 和 minio_path
|
||||
"""
|
||||
# 获取随机图片
|
||||
async with httpx.AsyncClient(timeout=30) as client:
|
||||
response = await client.get(
|
||||
f"{UNSPLASH_BASE_URL}/search/photos",
|
||||
headers={"Authorization": f"Client-ID {UNSPLASH_ACCESS_KEY}"},
|
||||
params={
|
||||
"query": query,
|
||||
"per_page": count,
|
||||
"page": 1,
|
||||
},
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise Exception(f"Unsplash API error: {response.status_code} - {response.text}")
|
||||
|
||||
data = response.json()
|
||||
# /search/photos 返回 {"results": [...], "total": ...}
|
||||
photos = data.get("results", [])
|
||||
|
||||
# 下载并上传到 minio
|
||||
results = []
|
||||
for photo in photos:
|
||||
try:
|
||||
# 下载图片
|
||||
image_url = photo["urls"]["raw"]
|
||||
async with httpx.AsyncClient(timeout=60) as dl_client:
|
||||
dl_response = await dl_client.get(image_url)
|
||||
image = Image.open(io.BytesIO(dl_response.content))
|
||||
|
||||
# 上传到 minio(使用线程池避免阻塞事件循环)
|
||||
file_name = f"{uuid7()}.jpg"
|
||||
loop = asyncio.get_event_loop()
|
||||
minio_url = await loop.run_in_executor(executor, upload_SDXL_image, image, user_id, "explorer", file_name)
|
||||
results.append({"image_url": image_url, "minio_path": minio_url})
|
||||
logger.info(f"[Explorer] 上传成功: {minio_url}")
|
||||
except Exception as e:
|
||||
logger.error(f"[Explorer] 上传失败: {e}")
|
||||
|
||||
return results
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
|
||||
async def test():
|
||||
"""测试 Unsplash 搜索"""
|
||||
query = "summer dress fresh natural style"
|
||||
print(f"搜索关键词: {query}")
|
||||
print("=" * 50)
|
||||
|
||||
results = await get_random_photos(query, count=4, user_id="test")
|
||||
print(f"\n找到 {len(results)} 张图片:")
|
||||
for i, item in enumerate(results, 1):
|
||||
print(f" {i}. 原图: {item.get('image_url', '')}")
|
||||
print(f" Minio: {item.get('minio_path', '')}")
|
||||
|
||||
asyncio.run(test())
|
||||
158
app/service/fashion_agent/graph_node/print_graph/graph.py
Normal file
158
app/service/fashion_agent/graph_node/print_graph/graph.py
Normal file
@@ -0,0 +1,158 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Annotated, Required, TypedDict
|
||||
|
||||
from langchain_qwq import ChatQwen
|
||||
from langchain_core.messages import AnyMessage, HumanMessage, SystemMessage
|
||||
from langgraph.graph import END, START, StateGraph
|
||||
from langgraph.graph.message import add_messages
|
||||
from pydantic import BaseModel, Field
|
||||
from app.service.fashion_agent.init_llm import qwen_plus_llm
|
||||
from app.service.fashion_agent.graph_node.print_graph.tools import generate_print_tool, test
|
||||
|
||||
logger = logging.getLogger()
|
||||
"""定义状态"""
|
||||
|
||||
|
||||
class PrintState(TypedDict):
|
||||
messages: Required[Annotated[list[AnyMessage], add_messages]]
|
||||
input_text: str
|
||||
role: str = ""
|
||||
gender: str = ""
|
||||
style: str = ""
|
||||
print_need_prompt_generation: bool = False # 是否需要使用 prompt 生成节点
|
||||
|
||||
print_num: int = 1
|
||||
|
||||
print_prompts: list[str] = []
|
||||
print_img_urls: list[str] = []
|
||||
|
||||
|
||||
"""生成印花图案的提示词节点"""
|
||||
|
||||
|
||||
# 定义输出结构
|
||||
class PrintPrompt(BaseModel):
|
||||
"""生成的印花图像提示词"""
|
||||
|
||||
prompts: list[str] = Field(description="用于生成印花图案的详细提示词")
|
||||
|
||||
|
||||
def extract_input_node(state: PrintState) -> dict:
|
||||
"""从 messages 中提取用户输入"""
|
||||
input_text = state["messages"][0].content if state.get("messages") else ""
|
||||
return {"input_text": input_text}
|
||||
|
||||
|
||||
def generate_print_prompt_node(state: PrintState) -> dict:
|
||||
"""根据用户输入生成印花图案的图像生成提示词"""
|
||||
structured_llm = qwen_plus_llm.with_structured_output(PrintPrompt)
|
||||
|
||||
messages = [
|
||||
SystemMessage(content=f"""你是一个专业的印花图案设计师。
|
||||
请根据用户输入,生成用于AI图像生成的印花图案提示词。
|
||||
|
||||
要求:
|
||||
1. 提示词应该详细描述印花图案的样式、元素、颜色、布局
|
||||
2. 提示词应该适合用于 Stable Diffusion 图像生成模型
|
||||
3. 提示词应该使用英文,因为图像生成模型对英文理解更好
|
||||
4. 提示词数量为 {state.get("print_num", 1)}
|
||||
"""),
|
||||
HumanMessage(content=state["input_text"]),
|
||||
]
|
||||
|
||||
result = structured_llm.invoke(messages)
|
||||
prompts = result.prompts
|
||||
logger.info(f"[Print Graph] Generated print prompts: {prompts}")
|
||||
return {
|
||||
"print_prompts": prompts,
|
||||
}
|
||||
|
||||
|
||||
"""生成印花图案节点"""
|
||||
|
||||
|
||||
async def generate_print_img_node(state: PrintState) -> dict:
|
||||
"""根据生成的提示词,生成印花图案"""
|
||||
# 如果 print_prompts 为空,使用 input_text 作为 prompt
|
||||
if state.get("print_need_prompt_generation", False):
|
||||
prompts = state["print_prompts"] if state["print_prompts"] else [state["input_text"]]
|
||||
else:
|
||||
input_text = state.get("input_text", "")
|
||||
prompts = [input_text]
|
||||
|
||||
print_img_urls = []
|
||||
for prompt in prompts:
|
||||
image_url = await generate_print_tool.ainvoke({"prompt": prompt})
|
||||
print_img_urls.append(image_url)
|
||||
logger.info(f"[Print Graph] Generated print image URL: {image_url}")
|
||||
|
||||
return {"print_img_urls": print_img_urls}
|
||||
|
||||
|
||||
"""条件分支 判断是否需要生成 prompt"""
|
||||
|
||||
|
||||
def should_generate_prompt(state: PrintState) -> str:
|
||||
"""条件分支:判断是否需要生成 prompt"""
|
||||
|
||||
logger.info(
|
||||
f"[Print Graph] should_generate_prompt: print_need_prompt_generation={state.get('print_need_prompt_generation')}, print_prompts={state.get('print_prompts')}"
|
||||
)
|
||||
if state.get("print_need_prompt_generation", True):
|
||||
return "gen_prompt"
|
||||
else:
|
||||
return "gen_print"
|
||||
|
||||
|
||||
def build_print_graph():
|
||||
|
||||
workflow = StateGraph(PrintState)
|
||||
workflow.add_node("extract_input", extract_input_node)
|
||||
workflow.add_node("gen_prompt", generate_print_prompt_node)
|
||||
workflow.add_node("gen_print", generate_print_img_node)
|
||||
|
||||
# 添加边
|
||||
workflow.add_edge(START, "extract_input")
|
||||
workflow.add_conditional_edges(
|
||||
"extract_input",
|
||||
should_generate_prompt,
|
||||
{
|
||||
"gen_prompt": "gen_prompt",
|
||||
"gen_print": "gen_print",
|
||||
},
|
||||
)
|
||||
workflow.add_edge("gen_prompt", "gen_print")
|
||||
workflow.add_edge("gen_print", END)
|
||||
|
||||
graph = workflow.compile()
|
||||
return graph
|
||||
|
||||
|
||||
async def main(test_input, print_need_prompt_generation=True):
|
||||
graph = build_print_graph()
|
||||
result = await graph.ainvoke(
|
||||
{
|
||||
"input_text": test_input,
|
||||
"print_prompts": [] if print_need_prompt_generation else [test_input],
|
||||
"print_need_prompt_generation": print_need_prompt_generation,
|
||||
"role": "",
|
||||
"gender": "",
|
||||
"style": "",
|
||||
}
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 测试示例 1: 需要 prompt 生成(默认)
|
||||
test_input = "我想要一个优雅的花卉印花,适合用于连衣裙,颜色以粉色和白色为主"
|
||||
result = asyncio.run(main(test_input, print_need_prompt_generation=True))
|
||||
print("=== 需要 prompt 生成 ===")
|
||||
print(f"Result: {result}")
|
||||
|
||||
# 测试示例 2: 直接使用用户提供的 prompt
|
||||
user_prompt = "Elegant floral print pattern, pink and white colors, suitable for dress fabric, seamless tileable design"
|
||||
result = asyncio.run(main(user_prompt, print_need_prompt_generation=False))
|
||||
print("\n=== 直接使用 prompt ===")
|
||||
print(f"Result: {result}")
|
||||
39
app/service/fashion_agent/graph_node/print_graph/tools.py
Normal file
39
app/service/fashion_agent/graph_node/print_graph/tools.py
Normal file
@@ -0,0 +1,39 @@
|
||||
import asyncio
|
||||
|
||||
from langchain.tools import tool
|
||||
from langsmith import uuid7
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.service.fashion_agent.graph_node.node_tools.generate_image import generate_image
|
||||
|
||||
|
||||
class GenerateImageToolInput(BaseModel):
|
||||
"""Input schema for the Generate Image Tool."""
|
||||
|
||||
prompt: str = Field(description="Description of the desired image, e.g., 'A cozy living room with warm lighting and natural textures.'")
|
||||
|
||||
|
||||
@tool(args_schema=GenerateImageToolInput)
|
||||
async def generate_print_tool(prompt: str) -> str:
|
||||
"""Generate an image based on the provided prompt."""
|
||||
|
||||
bucket_name = "aida-users"
|
||||
object_name = f"agent_generate_print/{uuid7()}.png"
|
||||
image_url = await generate_image(prompt=prompt, bucket_name=bucket_name, object_name=object_name)
|
||||
return image_url
|
||||
|
||||
|
||||
@tool
|
||||
async def test(text: str):
|
||||
"""测试工具函数,返回固定字符串"""
|
||||
return text
|
||||
|
||||
|
||||
async def run_test():
|
||||
result = await generate_print_tool.ainvoke({"prompt": "A cozy living room with warm lighting and natural textures."})
|
||||
return result
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
result = asyncio.run(run_test())
|
||||
print(result)
|
||||
178
app/service/fashion_agent/graph_node/sketch_graph/graph.py
Normal file
178
app/service/fashion_agent/graph_node/sketch_graph/graph.py
Normal file
@@ -0,0 +1,178 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Annotated, Required, TypedDict
|
||||
from langchain_qwq import ChatQwen
|
||||
from langchain_core.messages import AnyMessage, HumanMessage, SystemMessage, AIMessage
|
||||
from langgraph.graph import END, START, StateGraph
|
||||
from langgraph.graph.message import add_messages
|
||||
from pydantic import BaseModel, Field
|
||||
from app.service.fashion_agent.init_llm import qwen_plus_llm
|
||||
|
||||
from app.service.fashion_agent.graph_node.sketch_graph.tools import generate_sketch_tool
|
||||
|
||||
logger = logging.getLogger()
|
||||
|
||||
|
||||
"""定义状态"""
|
||||
|
||||
|
||||
class SketchState(TypedDict):
|
||||
messages: Required[Annotated[list[AnyMessage], add_messages]]
|
||||
input_text: str
|
||||
role: str = ""
|
||||
gender: str = ""
|
||||
style: str = ""
|
||||
sketch_need_prompt_generation: bool = False # 是否需要使用 prompt 生成节点
|
||||
|
||||
sketch_num: int = 1
|
||||
|
||||
sketch_prompts: list[str] = []
|
||||
sketch_img_urls: list[str] = []
|
||||
|
||||
|
||||
"""生成服装草图的提示词节点"""
|
||||
|
||||
|
||||
# 定义输出结构
|
||||
class SketchPrompt(BaseModel):
|
||||
"""生成的印花图像提示词"""
|
||||
|
||||
prompts: list[str] = Field(description="用于生成服装草图的详细提示词")
|
||||
|
||||
|
||||
def extract_input_node(state: SketchState) -> dict:
|
||||
"""从 messages 中提取用户输入"""
|
||||
input_text = state["messages"][0].content if state.get("messages") else ""
|
||||
return {"input_text": input_text}
|
||||
|
||||
|
||||
def generate_sketch_prompt_node(state: SketchState) -> dict:
|
||||
"""根据用户输入生成服装草图的图像生成提示词"""
|
||||
structured_llm = qwen_plus_llm.with_structured_output(SketchPrompt)
|
||||
|
||||
messages = [
|
||||
SystemMessage(content=f"""你是一个专业的服装设计师。
|
||||
请根据用户输入,生成用于AI图像生成的服装草图提示词。
|
||||
|
||||
要求:
|
||||
1. 提示词必须包含:clean black and white line drawing only, pure white background, centered composition
|
||||
2. 提示词应该详细描述服装的廓形、结构、细节
|
||||
3. 提示词应该适合用于 Stable Diffusion 图像生成模型
|
||||
4. 提示词应该使用英文,因为图像生成模型对英文理解更好
|
||||
5. 草图风格必须是黑白线稿,不要添加颜色
|
||||
6. 提示词数量为 {state.get("sketch_num", 1)}
|
||||
"""),
|
||||
HumanMessage(content=state["input_text"]),
|
||||
]
|
||||
|
||||
result = structured_llm.invoke(messages)
|
||||
prompts = result.prompts
|
||||
|
||||
return {
|
||||
"sketch_prompts": prompts,
|
||||
}
|
||||
|
||||
|
||||
"""生成服装草图节点"""
|
||||
|
||||
|
||||
async def generate_sketch_img_node(state: SketchState) -> dict:
|
||||
"""根据生成的提示词,生成服装草图"""
|
||||
# 如果 sketch_need_prompt_generation=False 且 sketch_prompts 为空,使用模板生成 prompt
|
||||
# if not state.get("sketch_need_prompt_generation", False) and not state.get("sketch_prompts"):
|
||||
|
||||
# input_text = state.get("input_text", "")
|
||||
# prompts = [build_sketch_template_prompt(input_text)]
|
||||
# else:
|
||||
# prompts = state["sketch_prompts"] if state["sketch_prompts"] else [state["input_text"]]
|
||||
|
||||
# sketch_img_urls = []
|
||||
# for prompt in prompts:
|
||||
# image_url = await generate_sketch_tool.ainvoke({"prompt": prompt})
|
||||
# sketch_img_urls.append(image_url)
|
||||
|
||||
# result_text = f"服装草图生成完成,共生成 {len(sketch_img_urls)} 张图片:\n" + "\n".join(sketch_img_urls)
|
||||
# return {"sketch_img_urls": sketch_img_urls, "messages": [AIMessage(content=result_text)]}
|
||||
return {"messages": [AIMessage(content="hello")]}
|
||||
|
||||
|
||||
"""条件分支 判断是否需要生成 prompt"""
|
||||
|
||||
|
||||
def should_generate_prompt(state: SketchState) -> str:
|
||||
"""条件分支:判断是否需要生成 prompt"""
|
||||
if state.get("sketch_need_prompt_generation", False):
|
||||
return "gen_prompt"
|
||||
else:
|
||||
return "gen_sketch"
|
||||
|
||||
|
||||
def build_sketch_graph():
|
||||
workflow = StateGraph(SketchState)
|
||||
workflow.add_node("gen_sketch", generate_sketch_img_node)
|
||||
workflow.add_edge(START, "gen_sketch")
|
||||
workflow.add_edge("gen_sketch", END)
|
||||
graph = workflow.compile()
|
||||
return graph
|
||||
|
||||
# workflow = StateGraph(SketchState)
|
||||
# workflow.add_node("extract_input", extract_input_node)
|
||||
# workflow.add_node("gen_prompt", generate_sketch_prompt_node)
|
||||
# workflow.add_node("gen_sketch", generate_sketch_img_node)
|
||||
|
||||
# # 添加边
|
||||
# workflow.add_edge(START, "extract_input")
|
||||
# workflow.add_conditional_edges(
|
||||
# "extract_input",
|
||||
# should_generate_prompt,
|
||||
# {
|
||||
# "gen_prompt": "gen_prompt",
|
||||
# "gen_sketch": "gen_sketch",
|
||||
# },
|
||||
# )
|
||||
# workflow.add_edge("gen_prompt", "gen_sketch")
|
||||
# workflow.add_edge("gen_sketch", END)
|
||||
|
||||
# graph = workflow.compile()
|
||||
# return graph
|
||||
|
||||
|
||||
def build_sketch_template_prompt(input_text: str) -> str:
|
||||
"""构建 sketch prompt 模板"""
|
||||
return f"{input_text}, clean black and white line drawing only, pure white background, centered composition, fashion sketch style"
|
||||
|
||||
|
||||
async def main(test_input, sketch_need_prompt_generation=False):
|
||||
graph = build_sketch_graph()
|
||||
|
||||
# 如果不需要 LLM 生成 prompt,使用模板
|
||||
if not sketch_need_prompt_generation:
|
||||
sketch_prompts = [build_sketch_template_prompt(test_input)]
|
||||
else:
|
||||
sketch_prompts = []
|
||||
|
||||
result = await graph.ainvoke(
|
||||
{
|
||||
"input_text": test_input,
|
||||
"sketch_prompts": sketch_prompts,
|
||||
"sketch_need_prompt_generation": sketch_need_prompt_generation,
|
||||
"role": "",
|
||||
"gender": "",
|
||||
"style": "",
|
||||
}
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 测试示例 1: 直接使用模板 prompt(默认)
|
||||
test_input = "dress"
|
||||
result = asyncio.run(main(test_input, sketch_need_prompt_generation=False))
|
||||
print("=== 使用模板 prompt ===")
|
||||
print(f"Result: {result}")
|
||||
|
||||
# # 测试示例 2: 使用 LLM 生成 prompt
|
||||
# test_input = "设计一条优雅的A字廓形连衣裙,V领设计,收腰,裙摆到膝盖,适合日常穿着"
|
||||
# result = asyncio.run(main(test_input, sketch_need_prompt_generation=True))
|
||||
# print("\n=== 使用 LLM 生成 prompt ===")
|
||||
# print(f"Result: {result}")
|
||||
33
app/service/fashion_agent/graph_node/sketch_graph/tools.py
Normal file
33
app/service/fashion_agent/graph_node/sketch_graph/tools.py
Normal file
@@ -0,0 +1,33 @@
|
||||
import asyncio
|
||||
|
||||
from langchain.tools import tool
|
||||
from langsmith import uuid7
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.service.fashion_agent.graph_node.node_tools.generate_image import generate_image
|
||||
|
||||
|
||||
class GenerateImageToolInput(BaseModel):
|
||||
"""Input schema for the Generate Image Tool."""
|
||||
|
||||
prompt: str = Field(description="Description of the desired image, e.g., 'A cozy living room with warm lighting and natural textures.'")
|
||||
|
||||
|
||||
@tool(args_schema=GenerateImageToolInput)
|
||||
async def generate_sketch_tool(prompt: str) -> str:
|
||||
"""Generate an image based on the provided prompt."""
|
||||
|
||||
bucket_name = "fida-public-bucket"
|
||||
object_name = f"test/{uuid7()}.png"
|
||||
image_url = await generate_image(prompt=prompt, bucket_name=bucket_name, object_name=object_name)
|
||||
return image_url
|
||||
|
||||
|
||||
async def run_test():
|
||||
result = await generate_sketch_tool.ainvoke({"prompt": "A cozy living room with warm lighting and natural textures."})
|
||||
return result
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
result = asyncio.run(run_test())
|
||||
print(result)
|
||||
@@ -0,0 +1,69 @@
|
||||
import asyncio
|
||||
from typing import Annotated, Required, TypedDict
|
||||
|
||||
from langchain_core.messages import AIMessage, AnyMessage, HumanMessage
|
||||
from langgraph.graph import END, START, StateGraph
|
||||
from langgraph.graph.message import add_messages
|
||||
|
||||
"""定义状态"""
|
||||
|
||||
|
||||
class TrendingState(TypedDict):
|
||||
messages: Required[Annotated[list[AnyMessage], add_messages]]
|
||||
input_text: str
|
||||
|
||||
|
||||
"""节点"""
|
||||
|
||||
|
||||
def extract_input_node(state: TrendingState) -> dict:
|
||||
"""从 messages 中提取用户输入"""
|
||||
input_text = state["messages"][0].content if state.get("messages") else ""
|
||||
return {"input_text": input_text}
|
||||
|
||||
|
||||
async def trending_node(state: TrendingState) -> dict:
|
||||
"""趋势分析节点(占位)"""
|
||||
input_text = state.get("input_text", "")
|
||||
|
||||
# TODO: 接入真实的趋势分析逻辑
|
||||
result_text = (
|
||||
f"【趋势分析】\n基于您的输入「{input_text}」,以下是当前服装设计趋势:\n\n"
|
||||
"1. 极简主义持续流行,黑白灰为主色调\n"
|
||||
"2. 可持续时尚成为主流,环保面料受青睐\n"
|
||||
"3. 复古风格回潮,90年代元素重新流行\n"
|
||||
"4. 功能性与美学结合,运动休闲风持续升温"
|
||||
)
|
||||
|
||||
return {"messages": [AIMessage(content=result_text)]}
|
||||
|
||||
|
||||
"""构建图"""
|
||||
|
||||
|
||||
def build_trending_graph():
|
||||
"""构建趋势分析图"""
|
||||
workflow = StateGraph(TrendingState)
|
||||
|
||||
workflow.add_node("extract_input", extract_input_node)
|
||||
workflow.add_node("trending", trending_node)
|
||||
|
||||
workflow.add_edge(START, "extract_input")
|
||||
workflow.add_edge("extract_input", "trending")
|
||||
workflow.add_edge("trending", END)
|
||||
|
||||
return workflow.compile()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
async def test():
|
||||
graph = build_trending_graph()
|
||||
result = await graph.ainvoke(
|
||||
{
|
||||
"messages": [HumanMessage(content="女装连衣裙")],
|
||||
}
|
||||
)
|
||||
print(result["messages"][-1].content)
|
||||
|
||||
asyncio.run(test())
|
||||
Reference in New Issue
Block a user