更新图形生成工具,优化返回格式并添加新功能
This commit is contained in:
@@ -23,12 +23,14 @@ async def explor_tool(
|
|||||||
# 方式 1:从 configurable 获取
|
# 方式 1:从 configurable 获取
|
||||||
user_id = config.get("configurable", {}).get("user_id", "agent")
|
user_id = config.get("configurable", {}).get("user_id", "agent")
|
||||||
|
|
||||||
|
results = []
|
||||||
if method == "unsplash":
|
if method == "unsplash":
|
||||||
return await get_random_photos(query, count=per_page, user_id=user_id)
|
results = await get_random_photos(query, count=per_page, user_id=user_id)
|
||||||
elif method == "pexels":
|
elif method == "pexels":
|
||||||
return await search_photos(query, per_page=per_page, user_id=user_id)
|
results = await search_photos(query, per_page=per_page, user_id=user_id)
|
||||||
else:
|
else:
|
||||||
pass
|
results = []
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ from langchain_core.messages import HumanMessage, SystemMessage
|
|||||||
from langgraph.graph import END, START, StateGraph
|
from langgraph.graph import END, START, StateGraph
|
||||||
from langgraph.graph.message import add_messages
|
from langgraph.graph.message import add_messages
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from app.service.fashion_agent.graph_node.node_tools.generate_logo import generate_logo_tool
|
from app.service.fashion_agent.graph_node.logo_graph.tools import generate_logo_tool
|
||||||
from app.service.fashion_agent.init_llm import qwen_plus_llm
|
from app.service.fashion_agent.init_llm import qwen_plus_llm
|
||||||
|
|
||||||
"""初始化 LLM TODO 将 API Key 替换为环境变量或者配置文件中的值,避免在代码中硬编码敏感信息"""
|
"""初始化 LLM TODO 将 API Key 替换为环境变量或者配置文件中的值,避免在代码中硬编码敏感信息"""
|
||||||
|
|||||||
@@ -71,7 +71,7 @@ async def generate_logo_tool(prompt: str, user_id: str = "agent") -> str:
|
|||||||
file_name = f"{uuid7()}.png"
|
file_name = f"{uuid7()}.png"
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
image_url = await loop.run_in_executor(executor, upload_SDXL_image, image, user_id, "logo", file_name)
|
image_url = await loop.run_in_executor(executor, upload_SDXL_image, image, user_id, "logo", file_name)
|
||||||
return image_url
|
return [image_url]
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
@@ -20,7 +20,7 @@ async def generate_print_tool(prompt: str) -> str:
|
|||||||
bucket_name = "aida-users"
|
bucket_name = "aida-users"
|
||||||
object_name = f"agent_generate_print/{uuid7()}.png"
|
object_name = f"agent_generate_print/{uuid7()}.png"
|
||||||
image_url = await generate_image(prompt=prompt, bucket_name=bucket_name, object_name=object_name)
|
image_url = await generate_image(prompt=prompt, bucket_name=bucket_name, object_name=object_name)
|
||||||
return image_url
|
return [image_url]
|
||||||
|
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
|
|||||||
@@ -79,21 +79,19 @@ def generate_sketch_prompt_node(state: SketchState) -> dict:
|
|||||||
async def generate_sketch_img_node(state: SketchState) -> dict:
|
async def generate_sketch_img_node(state: SketchState) -> dict:
|
||||||
"""根据生成的提示词,生成服装草图"""
|
"""根据生成的提示词,生成服装草图"""
|
||||||
# 如果 sketch_need_prompt_generation=False 且 sketch_prompts 为空,使用模板生成 prompt
|
# 如果 sketch_need_prompt_generation=False 且 sketch_prompts 为空,使用模板生成 prompt
|
||||||
# if not state.get("sketch_need_prompt_generation", False) and not state.get("sketch_prompts"):
|
if not state.get("sketch_need_prompt_generation", False) and not state.get("sketch_prompts"):
|
||||||
|
|
||||||
# input_text = state.get("input_text", "")
|
input_text = state.get("input_text", "")
|
||||||
# prompts = [build_sketch_template_prompt(input_text)]
|
prompts = [build_sketch_template_prompt(input_text)]
|
||||||
# else:
|
else:
|
||||||
# prompts = state["sketch_prompts"] if state["sketch_prompts"] else [state["input_text"]]
|
prompts = state["sketch_prompts"] if state["sketch_prompts"] else [state["input_text"]]
|
||||||
|
|
||||||
# sketch_img_urls = []
|
sketch_img_urls = []
|
||||||
# for prompt in prompts:
|
for prompt in prompts:
|
||||||
# image_url = await generate_sketch_tool.ainvoke({"prompt": prompt})
|
image_url = await generate_sketch_tool.ainvoke({"prompt": prompt})
|
||||||
# sketch_img_urls.append(image_url)
|
sketch_img_urls.append(image_url)
|
||||||
|
|
||||||
# result_text = f"服装草图生成完成,共生成 {len(sketch_img_urls)} 张图片:\n" + "\n".join(sketch_img_urls)
|
return {"sketch_img_urls": sketch_img_urls}
|
||||||
# return {"sketch_img_urls": sketch_img_urls, "messages": [AIMessage(content=result_text)]}
|
|
||||||
return {"messages": [AIMessage(content="hello")]}
|
|
||||||
|
|
||||||
|
|
||||||
"""条件分支 判断是否需要生成 prompt"""
|
"""条件分支 判断是否需要生成 prompt"""
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ async def generate_sketch_tool(prompt: str) -> str:
|
|||||||
bucket_name = "fida-public-bucket"
|
bucket_name = "fida-public-bucket"
|
||||||
object_name = f"test/{uuid7()}.png"
|
object_name = f"test/{uuid7()}.png"
|
||||||
image_url = await generate_image(prompt=prompt, bucket_name=bucket_name, object_name=object_name)
|
image_url = await generate_image(prompt=prompt, bucket_name=bucket_name, object_name=object_name)
|
||||||
return image_url
|
return [image_url]
|
||||||
|
|
||||||
|
|
||||||
async def run_test():
|
async def run_test():
|
||||||
|
|||||||
@@ -29,3 +29,5 @@ qwen_plus_llm = ChatQwen(
|
|||||||
top_p=0.8,
|
top_p=0.8,
|
||||||
api_key=QWEN_API_KEY_INTL,
|
api_key=QWEN_API_KEY_INTL,
|
||||||
)
|
)
|
||||||
|
# response = qwen_plus_llm.invoke("你好")
|
||||||
|
# print(response)
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ from app.service.fashion_agent.graph_node.design_graph.graph import build_design
|
|||||||
from app.service.fashion_agent.graph_node.design_graph.tools import design_tool
|
from app.service.fashion_agent.graph_node.design_graph.tools import design_tool
|
||||||
from app.service.fashion_agent.graph_node.explorer_graph.tools import explor_tool
|
from app.service.fashion_agent.graph_node.explorer_graph.tools import explor_tool
|
||||||
from app.service.fashion_agent.graph_node.logo_graph.graph import build_logo_graph
|
from app.service.fashion_agent.graph_node.logo_graph.graph import build_logo_graph
|
||||||
from app.service.fashion_agent.graph_node.node_tools.generate_logo import generate_logo_tool
|
from app.service.fashion_agent.graph_node.logo_graph.tools import generate_logo_tool
|
||||||
from app.service.fashion_agent.graph_node.print_graph.graph import build_print_graph
|
from app.service.fashion_agent.graph_node.print_graph.graph import build_print_graph
|
||||||
from app.service.fashion_agent.graph_node.print_graph.tools import generate_print_tool
|
from app.service.fashion_agent.graph_node.print_graph.tools import generate_print_tool
|
||||||
from app.service.fashion_agent.graph_node.sketch_graph.graph import build_sketch_graph
|
from app.service.fashion_agent.graph_node.sketch_graph.graph import build_sketch_graph
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ from langgraph.stream import ProtocolEvent, StreamChannel, StreamTransformer
|
|||||||
from app.service.fashion_agent.main_agent import build_main_graph
|
from app.service.fashion_agent.main_agent import build_main_graph
|
||||||
from langgraph.prebuilt import ToolCallTransformer
|
from langgraph.prebuilt import ToolCallTransformer
|
||||||
from typing import AsyncGenerator, TypedDict
|
from typing import AsyncGenerator, TypedDict
|
||||||
from langchain_core.messages import HumanMessage
|
from langchain_core.messages import HumanMessage, ToolMessage
|
||||||
from app.schemas.fashion_agent import FashionAgentRequest
|
from app.schemas.fashion_agent import FashionAgentRequest
|
||||||
|
|
||||||
logger = logging.getLogger()
|
logger = logging.getLogger()
|
||||||
@@ -71,6 +71,8 @@ class FashionAgentService:
|
|||||||
|
|
||||||
data["tool_name"] = tool_name
|
data["tool_name"] = tool_name
|
||||||
|
|
||||||
|
if isinstance(data["output"], ToolMessage):
|
||||||
|
data["output"] = json.loads(data["output"].content)
|
||||||
response_event = {"event_type": "tool", "data": data}
|
response_event = {"event_type": "tool", "data": data}
|
||||||
yield f"data: {json.dumps(response_event, ensure_ascii=False)}\n\n"
|
yield f"data: {json.dumps(response_event, ensure_ascii=False)}\n\n"
|
||||||
|
|
||||||
|
|||||||
@@ -20,10 +20,12 @@ dependencies = [
|
|||||||
"image>=1.5.33",
|
"image>=1.5.33",
|
||||||
"langchain>=1.2.0",
|
"langchain>=1.2.0",
|
||||||
"langchain-community>=0.4.1",
|
"langchain-community>=0.4.1",
|
||||||
|
"langchain-openai>=1.2.2",
|
||||||
"langchain-qwq>=0.3.5",
|
"langchain-qwq>=0.3.5",
|
||||||
"langgraph>=1.0.5",
|
"langgraph>=1.0.5",
|
||||||
"langgraph-api>=0.4.28",
|
"langgraph-api>=0.4.28",
|
||||||
"langgraph-cli[inmem,redis]<=0.4.26",
|
"langgraph-cli[inmem,redis]<=0.4.26",
|
||||||
|
"langsmith>=0.8.11",
|
||||||
"load>=1.0.14",
|
"load>=1.0.14",
|
||||||
"load-dotenv>=0.1.0",
|
"load-dotenv>=0.1.0",
|
||||||
"loguru>=0.7.3",
|
"loguru>=0.7.3",
|
||||||
|
|||||||
4
uv.lock
generated
4
uv.lock
generated
@@ -3739,10 +3739,12 @@ dependencies = [
|
|||||||
{ name = "image" },
|
{ name = "image" },
|
||||||
{ name = "langchain" },
|
{ name = "langchain" },
|
||||||
{ name = "langchain-community" },
|
{ name = "langchain-community" },
|
||||||
|
{ name = "langchain-openai" },
|
||||||
{ name = "langchain-qwq" },
|
{ name = "langchain-qwq" },
|
||||||
{ name = "langgraph" },
|
{ name = "langgraph" },
|
||||||
{ name = "langgraph-api" },
|
{ name = "langgraph-api" },
|
||||||
{ name = "langgraph-cli", extra = ["inmem"] },
|
{ name = "langgraph-cli", extra = ["inmem"] },
|
||||||
|
{ name = "langsmith" },
|
||||||
{ name = "load" },
|
{ name = "load" },
|
||||||
{ name = "load-dotenv" },
|
{ name = "load-dotenv" },
|
||||||
{ name = "loguru" },
|
{ name = "loguru" },
|
||||||
@@ -3796,10 +3798,12 @@ requires-dist = [
|
|||||||
{ name = "image", specifier = ">=1.5.33" },
|
{ name = "image", specifier = ">=1.5.33" },
|
||||||
{ name = "langchain", specifier = ">=1.2.0" },
|
{ name = "langchain", specifier = ">=1.2.0" },
|
||||||
{ name = "langchain-community", specifier = ">=0.4.1" },
|
{ name = "langchain-community", specifier = ">=0.4.1" },
|
||||||
|
{ name = "langchain-openai", specifier = ">=1.2.2" },
|
||||||
{ name = "langchain-qwq", specifier = ">=0.3.5" },
|
{ name = "langchain-qwq", specifier = ">=0.3.5" },
|
||||||
{ name = "langgraph", specifier = ">=1.0.5" },
|
{ name = "langgraph", specifier = ">=1.0.5" },
|
||||||
{ name = "langgraph-api", specifier = ">=0.4.28" },
|
{ name = "langgraph-api", specifier = ">=0.4.28" },
|
||||||
{ name = "langgraph-cli", extras = ["inmem", "redis"], specifier = "<=0.4.26" },
|
{ name = "langgraph-cli", extras = ["inmem", "redis"], specifier = "<=0.4.26" },
|
||||||
|
{ name = "langsmith", specifier = ">=0.8.11" },
|
||||||
{ name = "load", specifier = ">=1.0.14" },
|
{ name = "load", specifier = ">=1.0.14" },
|
||||||
{ name = "load-dotenv", specifier = ">=0.1.0" },
|
{ name = "load-dotenv", specifier = ">=0.1.0" },
|
||||||
{ name = "loguru", specifier = ">=0.7.3" },
|
{ name = "loguru", specifier = ">=0.7.3" },
|
||||||
|
|||||||
Reference in New Issue
Block a user