重构图像生成和搜索工具;更新主代理来处理输入图像
- 更新了“generate_image.py”以接受输入图像以增强图像生成。 - 修改了`pexels_search.py`和`unsplash_search.py`以将日志记录和上传路径从“explorer”更改为“explore”。 - 调整了“print_graph”和“sketch_graph”以提取最新的用户输入并处理输入图像以生成打印和草图图像。 - 重构“generate_print_tool”和“generate_sketch_tool”以接受输入图像。 - 更新了“main_agent.py”以包含状态中的输入图像并调整了图形构建过程。 - 增强了“service.py”来管理输入图像并改进了流媒体期间的事件处理。 - 更新了新软件包和版本的“pyproject.toml”和“uv.lock”中的依赖项。
This commit is contained in:
@@ -40,7 +40,7 @@ class PrintPrompt(BaseModel):
|
||||
|
||||
def extract_input_node(state: PrintState) -> dict:
|
||||
"""从 messages 中提取用户输入"""
|
||||
input_text = state["messages"][0].content if state.get("messages") else ""
|
||||
input_text = state["messages"][-1].content if state.get("messages") else ""
|
||||
return {"input_text": input_text}
|
||||
|
||||
|
||||
@@ -49,13 +49,13 @@ def generate_print_prompt_node(state: PrintState) -> dict:
|
||||
structured_llm = qwen_plus_llm.with_structured_output(PrintPrompt)
|
||||
|
||||
messages = [
|
||||
SystemMessage(content=f"""你是一个专业的印花图案设计师。
|
||||
请根据用户输入,生成用于AI图像生成的印花图案提示词。
|
||||
SystemMessage(content=f"""你是一个专业的印花图案设计师.
|
||||
请根据用户输入,生成用于AI图像生成的印花图案提示词.
|
||||
|
||||
要求:
|
||||
1. 提示词应该详细描述印花图案的样式、元素、颜色、布局
|
||||
要求:
|
||||
1. 提示词应该详细描述印花图案的样式,元素,颜色,布局
|
||||
2. 提示词应该适合用于 Stable Diffusion 图像生成模型
|
||||
3. 提示词应该使用英文,因为图像生成模型对英文理解更好
|
||||
3. 提示词应该使用英文,因为图像生成模型对英文理解更好
|
||||
4. 提示词数量为 {state.get("print_num", 1)}
|
||||
"""),
|
||||
HumanMessage(content=state["input_text"]),
|
||||
@@ -73,17 +73,19 @@ def generate_print_prompt_node(state: PrintState) -> dict:
|
||||
|
||||
|
||||
async def generate_print_img_node(state: PrintState) -> dict:
|
||||
"""根据生成的提示词,生成印花图案"""
|
||||
# 如果 print_prompts 为空,使用 input_text 作为 prompt
|
||||
"""根据生成的提示词,生成印花图案"""
|
||||
# 如果 print_prompts 为空,使用 input_text 作为 prompt
|
||||
if state.get("print_need_prompt_generation", False):
|
||||
prompts = state["print_prompts"] if state["print_prompts"] else [state["input_text"]]
|
||||
else:
|
||||
input_text = state.get("input_text", "")
|
||||
prompts = [input_text]
|
||||
|
||||
input_images = state.get("input_images", [])
|
||||
|
||||
print_img_urls = []
|
||||
for prompt in prompts:
|
||||
image_url = await generate_print_tool.ainvoke({"prompt": prompt})
|
||||
image_url = await generate_print_tool.ainvoke({"prompt": prompt, "input_images": input_images})
|
||||
print_img_urls.append(image_url)
|
||||
logger.info(f"[Print Graph] Generated print image URL: {image_url}")
|
||||
|
||||
@@ -94,7 +96,7 @@ async def generate_print_img_node(state: PrintState) -> dict:
|
||||
|
||||
|
||||
def should_generate_prompt(state: PrintState) -> str:
|
||||
"""条件分支:判断是否需要生成 prompt"""
|
||||
"""条件分支:判断是否需要生成 prompt"""
|
||||
|
||||
logger.info(
|
||||
f"[Print Graph] should_generate_prompt: print_need_prompt_generation={state.get('print_need_prompt_generation')}, print_prompts={state.get('print_prompts')}"
|
||||
@@ -106,7 +108,6 @@ def should_generate_prompt(state: PrintState) -> str:
|
||||
|
||||
|
||||
def build_print_graph():
|
||||
|
||||
workflow = StateGraph(PrintState)
|
||||
workflow.add_node("extract_input", extract_input_node)
|
||||
workflow.add_node("gen_prompt", generate_print_prompt_node)
|
||||
@@ -145,8 +146,8 @@ async def main(test_input, print_need_prompt_generation=True):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 测试示例 1: 需要 prompt 生成(默认)
|
||||
test_input = "我想要一个优雅的花卉印花,适合用于连衣裙,颜色以粉色和白色为主"
|
||||
# 测试示例 1: 需要 prompt 生成(默认)
|
||||
test_input = "我想要一个优雅的花卉印花,适合用于连衣裙,颜色以粉色和白色为主"
|
||||
result = asyncio.run(main(test_input, print_need_prompt_generation=True))
|
||||
print("=== 需要 prompt 生成 ===")
|
||||
print(f"Result: {result}")
|
||||
|
||||
@@ -11,15 +11,16 @@ class GenerateImageToolInput(BaseModel):
|
||||
"""Input schema for the Generate Image Tool."""
|
||||
|
||||
prompt: str = Field(description="Description of the desired image, e.g., 'A cozy living room with warm lighting and natural textures.'")
|
||||
input_images: list[str] = Field(default=[], description="Input images for the generation.")
|
||||
|
||||
|
||||
@tool(args_schema=GenerateImageToolInput)
|
||||
async def generate_print_tool(prompt: str) -> str:
|
||||
async def generate_print_tool(prompt: str, input_images: list[str]) -> str:
|
||||
"""Generate an image based on the provided prompt."""
|
||||
|
||||
bucket_name = "aida-users"
|
||||
object_name = f"agent_generate_print/{uuid7()}.png"
|
||||
image_url = await generate_image(prompt=prompt, bucket_name=bucket_name, object_name=object_name)
|
||||
image_url = await generate_image(prompt=prompt, bucket_name=bucket_name, object_name=object_name, input_images=input_images)
|
||||
return [image_url]
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user