Files
AiDA_Python/app/service/chat_robot/script/main.py
zhouchengrong 91ed45e978 feat(新功能): qwen api key 修改
fix(修复bug):
docs(文档变更):
refactor(重构):
test(增加测试):
2025-03-31 17:08:52 +08:00

139 lines
5.2 KiB
Python

import json
import logging
import urllib.parse

from langchain.agents import Tool
from langchain.callbacks import FileCallbackHandler
from langchain.prompts.chat import ChatPromptTemplate, HumanMessagePromptTemplate, MessagesPlaceholder
from langchain.schema import SystemMessage, AIMessage
from langchain.utilities import SerpAPIWrapper
from langchain_community.chat_models import ChatTongyi
from loguru import logger

from app.core.config import *
from app.service.chat_robot.script.agents import CustomAgentExecutor, ConversationalFunctionsAgent
from app.service.chat_robot.script.database import CustomDatabase
from app.service.chat_robot.script.memory import UserConversationBufferWindowMemory
from app.service.chat_robot.script.prompt import FASHION_CHAT_BOT_PREFIX, TOOLS_FUNCTIONS_SUFFIX
from app.service.chat_robot.script.service import CallQWen
from app.service.chat_robot.script.tools import (QuerySQLDataBaseTool, InfoSQLDatabaseTool, QuerySQLCheckerTool, ListSQLDatabaseTool)
# os.environ["http_proxy"] = "http://127.0.0.1:7890"
# os.environ["https_proxy"] = "http://127.0.0.1:7890"
# Debug logging for the chat bot. Both writers below target the same file path.
logfile = "logs/chat_debug.log"
# Register a loguru sink for that file (runs once, when this module is first imported).
# enqueue=True passes records through loguru's queue before writing; colorize=True keeps
# ANSI color codes in the file output.
logger.add(logfile, colorize=True, enqueue=True)
# LangChain callback handler that appends chain/agent traces to the same file.
# NOTE(review): presumably relies on the loguru sink above having already created the
# logs/ directory, so keep this line after logger.add. Not used by chat()'s active code
# path in this file; it may be imported elsewhere — confirm before removing.
log_handler = FileCallbackHandler(logfile)
# Chat model that drives the agent: Alibaba Tongyi (Qwen) via LangChain's ChatTongyi.
llm = ChatTongyi(api_key=QWEN_API_KEY)

# Backend for the "internet_search" tool.
# NOTE(review): SerpAPIWrapper presumably validates its API key (SERPAPI_API_KEY) when
# constructed, so a missing key fails this module's import — confirm it is set.
search = SerpAPIWrapper()

# Product-attribute database exposed to the SQL tools, restricted to the garment tables.
# The username and password are percent-encoded because the driver string is parsed as a
# URL: a raw '@', ':', '/', '#' or '%' in either would corrupt or misroute the URI.
# quote(..., safe="") encodes every reserved character (including '/'), and spaces as %20.
# NOTE: if the configured password was previously hand-encoded to work around this,
# store it unencoded now to avoid double encoding.
_db_user = urllib.parse.quote(str(DB_USERNAME), safe="")
_db_password = urllib.parse.quote(str(DB_PASSWORD), safe="")
db = CustomDatabase.from_uri(
    f"mysql+pymysql://{_db_user}:{_db_password}@{DB_HOST}:{DB_PORT}/attribute_retrieval_V3",
    include_tables=['female_top', 'female_skirt', 'female_pants', 'female_dress',
                    'female_outwear', 'male_bottom', 'male_top', 'male_outwear'],
    # Recycle pooled connections after 7200 s (2 h), presumably to stay below the MySQL
    # server's idle-connection timeout — confirm against the server's wait_timeout.
    engine_args={"pool_recycle": 7200},
)
# The encoded credentials are only needed to build the URI; drop them from the module.
del _db_user, _db_password
# Tools offered to the function-calling agent, in the order the agent sees them.
# First: general web lookup backed by the SerpAPI wrapper.
tools = [
    Tool(name="internet_search",
         description="Can be used to perform Internet searches",
         func=search.run),
]

# Then the SQL toolkit over `db`: run a query (result is handed back to the agent rather
# than returned directly), describe tables, list tables, and LLM-check a query.
# NOTE(review): confirm ChatTongyi honours `temperature`; if it is not a declared model
# field it may be silently ignored (extra DashScope options usually go in `model_kwargs`).
tools.extend([
    QuerySQLDataBaseTool(db=db, return_direct=False),
    InfoSQLDatabaseTool(db=db),
    ListSQLDatabaseTool(db=db),
    QuerySQLCheckerTool(db=db, llm=ChatTongyi(temperature=0, api_key=QWEN_API_KEY)),
])
# Template for the user's turn: the raw question plus the asker's gender.
_USER_TURN_TEMPLATE = "{input} Question from a {gender}."

# Prompt layout, top to bottom: system persona, prior conversation ("history"), the
# current user turn, a fixed assistant instruction about tool use, then the agent's
# intermediate steps ("agent_scratchpad").
messages = [
    SystemMessage(content=FASHION_CHAT_BOT_PREFIX),
    MessagesPlaceholder(variable_name="history"),
    HumanMessagePromptTemplate.from_template(_USER_TURN_TEMPLATE),
    AIMessage(content=TOOLS_FUNCTIONS_SUFFIX),
    MessagesPlaceholder(variable_name="agent_scratchpad"),
]
prompt = ChatPromptTemplate(
    messages=messages,
    input_variables=["input", "gender", "agent_scratchpad", "history"],
)
# Function-calling agent that reasons over `prompt` and may invoke any of `tools`.
agent = ConversationalFunctionsAgent(prompt=prompt, tools=tools, llm=llm)

# Windowed per-user conversation memory built via `from_redis` (presumably Redis-backed —
# confirm in the memory module). k=2 is the window size; input_key / output_key presumably
# select which chain input/output fields are stored, as in LangChain's BaseChatMemory.
memory = UserConversationBufferWindowMemory.from_redis(
    k=2,
    return_messages=True,
    input_key="input",
    output_key="output",
)

# Executor running the agent/tool loop with memory attached; verbose=True is LangChain's
# step-echo flag. NOTE(review): chat() in this file does not currently invoke it.
agent_executor = CustomAgentExecutor.from_agent_and_tools(
    agent=agent, tools=tools, memory=memory, verbose=True
)
def chat(post_data):
    """Answer one chat message with Qwen and return a loggable response record.

    Args:
        post_data: request object exposing ``user_id``, ``session_id``, ``message``
            and ``gender`` attributes (presumably the API request model — confirm
            against the caller).

    Returns:
        dict with keys ``user_id``, ``session_id``, ``input``, ``output``,
        ``total_tokens``, ``total_cost``, ``prompt_tokens``, ``completion_tokens``
        and ``response_type``; all but the first three are copied from the Qwen result.

    Raises:
        Whatever ``CallQWen.call_with_messages`` raises; indexing its result raises
        ``KeyError`` if a dict result lacks one of the keys copied below.
    """
    user_id = post_data.user_id
    session_id = post_data.session_id
    input_message = post_data.message
    gender = post_data.gender

    # The LangChain `agent_executor` path (tools + windowed memory) is currently bypassed;
    # the message is answered by a direct Qwen call instead.
    final_outputs = CallQWen.call_with_messages(input_message, gender)

    api_response = {
        'user_id': user_id,
        'session_id': session_id,
        'input': input_message,
        'output': final_outputs["output"],
        'total_tokens': final_outputs['total_tokens'],
        'total_cost': final_outputs['total_cost'],
        'prompt_tokens': final_outputs['prompt_tokens'],
        'completion_tokens': final_outputs['completion_tokens'],
        'response_type': final_outputs["response_type"],
    }
    # Log through the module's loguru `logger` so the record reaches the
    # logs/chat_debug.log sink registered at import time. The stdlib root logger
    # defaults to WARNING and would drop this INFO record unless configured elsewhere.
    # ensure_ascii=False keeps non-ASCII (e.g. Chinese) text readable in the log.
    logger.info(json.dumps(api_response, ensure_ascii=False))
    return api_response