重构图像生成和搜索工具;更新主代理来处理输入图像

- 更新了“generate_image.py”以接受输入图像以增强图像生成。
- 修改了`pexels_search.py​​`和`unsplash_search.py​​`以将日志记录和上传路径从“explorer”更改为“explore”。
- 调整了“print_graph”和“sketch_graph”以提取最新的用户输入并处理输入图像以生成打印和草图图像。
- 重构“generate_print_tool”和“generate_sketch_tool”以接受输入图像。
- 更新了“main_agent.py”以包含状态中的输入图像并调整了图形构建过程。
- 增强了“service.py”来管理输入图像并改进了流媒体期间的事件处理。
- 更新了新软件包和版本的“pyproject.toml”和“uv.lock”中的依赖项。
This commit is contained in:
zcr
2026-06-17 11:56:53 +08:00
parent 35e791b4e2
commit b9163f0b46
16 changed files with 296 additions and 212 deletions

View File

@@ -40,7 +40,7 @@ class PrintPrompt(BaseModel):
def extract_input_node(state: PrintState) -> dict:
"""从 messages 中提取用户输入"""
input_text = state["messages"][0].content if state.get("messages") else ""
input_text = state["messages"][-1].content if state.get("messages") else ""
return {"input_text": input_text}
@@ -49,13 +49,13 @@ def generate_print_prompt_node(state: PrintState) -> dict:
structured_llm = qwen_plus_llm.with_structured_output(PrintPrompt)
messages = [
SystemMessage(content=f"""你是一个专业的印花图案设计师
请根据用户输入生成用于AI图像生成的印花图案提示词
SystemMessage(content=f"""你是一个专业的印花图案设计师.
请根据用户输入,生成用于AI图像生成的印花图案提示词.
要求
1. 提示词应该详细描述印花图案的样式元素颜色布局
要求:
1. 提示词应该详细描述印花图案的样式,元素,颜色,布局
2. 提示词应该适合用于 Stable Diffusion 图像生成模型
3. 提示词应该使用英文因为图像生成模型对英文理解更好
3. 提示词应该使用英文,因为图像生成模型对英文理解更好
4. 提示词数量为 {state.get("print_num", 1)}
"""),
HumanMessage(content=state["input_text"]),
@@ -73,17 +73,19 @@ def generate_print_prompt_node(state: PrintState) -> dict:
async def generate_print_img_node(state: PrintState) -> dict:
"""根据生成的提示词生成印花图案"""
# 如果 print_prompts 为空使用 input_text 作为 prompt
"""根据生成的提示词,生成印花图案"""
# 如果 print_prompts 为空,使用 input_text 作为 prompt
if state.get("print_need_prompt_generation", False):
prompts = state["print_prompts"] if state["print_prompts"] else [state["input_text"]]
else:
input_text = state.get("input_text", "")
prompts = [input_text]
input_images = state.get("input_images", [])
print_img_urls = []
for prompt in prompts:
image_url = await generate_print_tool.ainvoke({"prompt": prompt})
image_url = await generate_print_tool.ainvoke({"prompt": prompt, "input_images": input_images})
print_img_urls.append(image_url)
logger.info(f"[Print Graph] Generated print image URL: {image_url}")
@@ -94,7 +96,7 @@ async def generate_print_img_node(state: PrintState) -> dict:
def should_generate_prompt(state: PrintState) -> str:
"""条件分支判断是否需要生成 prompt"""
"""条件分支:判断是否需要生成 prompt"""
logger.info(
f"[Print Graph] should_generate_prompt: print_need_prompt_generation={state.get('print_need_prompt_generation')}, print_prompts={state.get('print_prompts')}"
@@ -106,7 +108,6 @@ def should_generate_prompt(state: PrintState) -> str:
def build_print_graph():
workflow = StateGraph(PrintState)
workflow.add_node("extract_input", extract_input_node)
workflow.add_node("gen_prompt", generate_print_prompt_node)
@@ -145,8 +146,8 @@ async def main(test_input, print_need_prompt_generation=True):
if __name__ == "__main__":
# 测试示例 1: 需要 prompt 生成默认
test_input = "我想要一个优雅的花卉印花适合用于连衣裙颜色以粉色和白色为主"
# 测试示例 1: 需要 prompt 生成(默认)
test_input = "我想要一个优雅的花卉印花,适合用于连衣裙,颜色以粉色和白色为主"
result = asyncio.run(main(test_input, print_need_prompt_generation=True))
print("=== 需要 prompt 生成 ===")
print(f"Result: {result}")

View File

@@ -11,15 +11,16 @@ class GenerateImageToolInput(BaseModel):
"""Input schema for the Generate Image Tool."""
prompt: str = Field(description="Description of the desired image, e.g., 'A cozy living room with warm lighting and natural textures.'")
input_images: list[str] = Field(default=[], description="Input images for the generation.")
@tool(args_schema=GenerateImageToolInput)
async def generate_print_tool(prompt: str) -> str:
async def generate_print_tool(prompt: str, input_images: list[str]) -> str:
"""Generate an image based on the provided prompt."""
bucket_name = "aida-users"
object_name = f"agent_generate_print/{uuid7()}.png"
image_url = await generate_image(prompt=prompt, bucket_name=bucket_name, object_name=object_name)
image_url = await generate_image(prompt=prompt, bucket_name=bucket_name, object_name=object_name, input_images=input_images)
return [image_url]