openai 替换为 通义千问

This commit is contained in:
2024-07-08 18:50:01 +08:00
parent d772adcd7a
commit 8ad3e8ac0f
8 changed files with 412 additions and 89 deletions

View File

@@ -1,16 +1,19 @@
import logging
from langchain_community.chat_models import ChatTongyi
from loguru import logger
from langchain.agents import Tool
from langchain.utilities import SerpAPIWrapper
from langchain_community.utilities import SerpAPIWrapper
from langchain.prompts.chat import ChatPromptTemplate, HumanMessagePromptTemplate, MessagesPlaceholder
from langchain.schema import SystemMessage, AIMessage
from langchain.chat_models import ChatOpenAI
from langchain.llms.openai import OpenAI
# from langchain_community.chat_models import ChatOpenAI
# from langchain_community.llms import OpenAI
from langchain.callbacks import FileCallbackHandler
from app.service.chat_robot.script.agents import CustomAgentExecutor, ConversationalFunctionsAgent
from app.service.chat_robot.script.callbacks import OpenAITokenRecordCallbackHandler
from app.service.chat_robot.script.database import CustomDatabase
from app.service.chat_robot.script.prompt import FASHION_CHAT_BOT_PREFIX, TOOLS_FUNCTIONS_SUFFIX
from app.service.chat_robot.script.service import CallQWen
from app.service.chat_robot.script.tools import (QuerySQLDataBaseTool, InfoSQLDatabaseTool, QuerySQLCheckerTool, ListSQLDatabaseTool)
from app.service.chat_robot.script.memory import UserConversationBufferWindowMemory
from app.service.chat_robot.script.tools.tutorial_tool import CustomTutorialTool
@@ -26,10 +29,12 @@ logger.add(logfile, colorize=True, enqueue=True)
log_handler = FileCallbackHandler(logfile)
# Initiate our LLM 'gpt-3.5-turbo'
llm = ChatOpenAI(temperature=0.1,
openai_api_key=OPENAI_API_KEY,
# callbacks=[OpenAICallbackHandler()]
)
# llm = ChatOpenAI(temperature=0.1,
# openai_api_key=OPENAI_API_KEY,
# # callbacks=[OpenAICallbackHandler()]
# )
llm = ChatTongyi(api_key="sk-7658298c6b99443c98184a5e634fe6ab")
search = SerpAPIWrapper()
db = CustomDatabase.from_uri(f'mysql+pymysql://{DB_USERNAME}:{DB_PASSWORD}@{DB_HOST}:{DB_PORT}/attribute_retrieval_V3',
@@ -45,14 +50,15 @@ tools = [
QuerySQLDataBaseTool(db=db, return_direct=False),
InfoSQLDatabaseTool(db=db),
ListSQLDatabaseTool(db=db),
QuerySQLCheckerTool(db=db, llm=OpenAI(temperature=0, openai_api_key=OPENAI_API_KEY)),
Tool(
name="tutorial_tool",
description="Utilize this tool to retrieve specific statements related to user guidance tutorials."
"Input is an empty string",
func=CustomTutorialTool(),
return_direct=True
)
# QuerySQLCheckerTool(db=db, llm=OpenAI(temperature=0, openai_api_key=OPENAI_API_KEY)),
QuerySQLCheckerTool(db=db, llm = ChatTongyi(temperature=0, api_key="sk-7658298c6b99443c98184a5e634fe6ab")),
# Tool(
# name="tutorial_tool",
# description="Utilize this tool to retrieve specific statements related to user guidance tutorials."
# "Input is an empty string",
# func=CustomTutorialTool(),
# return_direct=True
# )
]
messages = [
@@ -90,25 +96,47 @@ def chat(post_data):
input_message = post_data.message
gender = post_data.gender
final_outputs = agent_executor(
{"input": input_message, "gender": gender},
callbacks=[OpenAITokenRecordCallbackHandler(), log_handler],
session_key=f"buffer:{user_id}:{session_id}",
)
# final_outputs = agent_executor(
# {"input": input_message, "gender": gender},
# callbacks=[OpenAITokenRecordCallbackHandler(), log_handler],
# session_key=f"buffer:{user_id}:{session_id}",
# )
final_outputs = CallQWen.call_with_messages(input_message)
# api_response = {
# 'user_id': user_id,
# 'session_id': session_id,
# # 'message_id': message_id,
# # 'create_time': created_time,
# 'input': final_outputs['input'],
# # 'conversion': messages,
# 'output': final_outputs['output'],
# # 'gpt_response_time': gpt_response_time,
# 'total_tokens': final_outputs['total_tokens'],
# 'total_cost': final_outputs['total_cost'],
# 'prompt_tokens': final_outputs['prompt_tokens'],
# 'completion_tokens': final_outputs['completion_tokens'],
# 'response_type': final_outputs['response_type']
# }
# if final_outputs["output"].startswith("["):
# final_str = final_outputs["output"].replace("\\", "")
# else:
# final_str = final_outputs["output"]
api_response = {
'user_id': user_id,
'session_id': session_id,
# 'message_id': message_id,
# 'create_time': created_time,
'input': final_outputs['input'],
'input': input_message,
# 'conversion': messages,
'output': final_outputs['output'],
'output': final_outputs["output"],
# 'gpt_response_time': gpt_response_time,
'total_tokens': final_outputs['total_tokens'],
'total_cost': final_outputs['total_cost'],
'prompt_tokens': final_outputs['prompt_tokens'],
'completion_tokens': final_outputs['completion_tokens'],
'response_type': final_outputs['response_type']
'response_type': final_outputs["response_type"]
}
logging.info(api_response)
return api_response