Files
AiDA_Python/app/service/chat_robot/script/main.py

116 lines
4.3 KiB
Python
Raw Normal View History

import json
2024-05-29 11:12:59 +08:00
import logging
2024-05-29 11:12:59 +08:00
from langchain.agents import Tool
from langchain.callbacks import FileCallbackHandler
2024-05-29 11:12:59 +08:00
from langchain.chat_models import ChatOpenAI
from langchain.llms.openai import OpenAI
from langchain.prompts.chat import ChatPromptTemplate, HumanMessagePromptTemplate, MessagesPlaceholder
from langchain.schema import SystemMessage, AIMessage
from langchain.utilities import SerpAPIWrapper
from loguru import logger
from app.core.config import *
2024-05-29 11:12:59 +08:00
from app.service.chat_robot.script.agents import CustomAgentExecutor, ConversationalFunctionsAgent
from app.service.chat_robot.script.callbacks import OpenAITokenRecordCallbackHandler
from app.service.chat_robot.script.database import CustomDatabase
from app.service.chat_robot.script.memory import UserConversationBufferWindowMemory
2024-05-29 11:12:59 +08:00
from app.service.chat_robot.script.prompt import FASHION_CHAT_BOT_PREFIX, TOOLS_FUNCTIONS_SUFFIX
from app.service.chat_robot.script.tools import (QuerySQLDataBaseTool, InfoSQLDatabaseTool, QuerySQLCheckerTool, ListSQLDatabaseTool)
from app.service.chat_robot.script.tools.tutorial_tool import CustomTutorialTool
# os.environ["http_proxy"] = "http://127.0.0.1:7890"
# os.environ["https_proxy"] = "http://127.0.0.1:7890"
# Debug log file. The same path is registered as a loguru sink and handed to
# LangChain's FileCallbackHandler, so both write to it. The path is relative,
# so it resolves against the process working directory.
logfile = "logs/chat_debug.log"
logger.add(logfile, enqueue=True, colorize=True)
log_handler = FileCallbackHandler(logfile)
# Chat model driving the agent. No model name is passed here, so the
# ChatOpenAI default applies (an earlier note on this line named
# 'gpt-3.5-turbo' -- TODO confirm against the installed langchain version).
llm = ChatOpenAI(openai_api_key=OPENAI_API_KEY, temperature=0.1)

# Search client; its `run` method backs the "internet_search" tool.
# Constructed without arguments, so any credentials must come from the
# wrapper's own defaults/environment.
search = SerpAPIWrapper()
# Database handle shared by the SQL tools, restricted to the listed tables
# of the `attribute_retrieval_V3` schema.
# NOTE(review): the credentials are interpolated into the URI as-is; a
# password containing characters such as '@' or '/' would need URL-escaping
# unless the config already stores it escaped -- confirm.
db = CustomDatabase.from_uri(
    f'mysql+pymysql://{DB_USERNAME}:{DB_PASSWORD}@{DB_HOST}:{DB_PORT}/attribute_retrieval_V3',
    # presumably forwarded to the SQLAlchemy engine (pool_recycle is in seconds there)
    engine_args={"pool_recycle": 7200},
    include_tables=[
        'female_top', 'female_skirt', 'female_pants', 'female_dress',
        'female_outwear', 'male_bottom', 'male_top', 'male_outwear',
    ],
)
# Tools offered to the agent: one web-search tool, four SQL helpers bound to
# `db`, and a tutorial tool whose result is returned directly to the caller.
tools = [
    Tool(
        name="internet_search",
        description="Can be used to perform Internet searches",
        func=search.run,
    ),
    QuerySQLDataBaseTool(db=db, return_direct=False),
    InfoSQLDatabaseTool(db=db),
    ListSQLDatabaseTool(db=db),
    # The checker gets its own completion model at temperature 0.
    QuerySQLCheckerTool(db=db, llm=OpenAI(temperature=0, openai_api_key=OPENAI_API_KEY)),
    Tool(
        name="tutorial_tool",
        # Adjacent string literals are concatenated verbatim, so the first
        # part must end with a space; otherwise the description reads
        # "...tutorials.Input is an empty string".
        description="Utilize this tool to retrieve specific statements related to user guidance tutorials. "
                    "Input is an empty string",
        # An instance is passed as the callable (assumes CustomTutorialTool
        # defines __call__ -- verify in tutorial_tool.py).
        func=CustomTutorialTool(),
        return_direct=True,
    ),
]
# Prompt layout, in order: the system prefix, the "history" placeholder, the
# new human turn (user input followed by the asker's gender), an AI message
# carrying the tools/functions suffix, and the "agent_scratchpad" placeholder.
messages = [
    SystemMessage(content=FASHION_CHAT_BOT_PREFIX),
    MessagesPlaceholder(variable_name="history"),
    HumanMessagePromptTemplate.from_template("{input} Question from a {gender}."),
    AIMessage(content=TOOLS_FUNCTIONS_SUFFIX),
    MessagesPlaceholder(variable_name="agent_scratchpad"),
]
prompt = ChatPromptTemplate(
    messages=messages,
    input_variables=["input", "gender", "agent_scratchpad", "history"],
)
# Agent combining the chat model, the tool list and the prompt.
agent = ConversationalFunctionsAgent(prompt=prompt, tools=tools, llm=llm)

# Conversation memory created via the class's Redis constructor. It reads the
# 'input' key and stores the 'output' key; k=2 is presumably the number of
# past exchanges kept in the window -- verify in memory.py.
memory = UserConversationBufferWindowMemory.from_redis(
    k=2,
    input_key='input',
    output_key='output',
    return_messages=True,
)

# Executor invoked by the request handler for every chat message.
agent_executor = CustomAgentExecutor.from_agent_and_tools(
    tools=tools,
    agent=agent,
    memory=memory,
    verbose=True,
)
def chat(post_data):
    """Run one chat turn through the agent executor and build the API reply.

    Args:
        post_data: Request object exposing ``user_id``, ``session_id``,
            ``message`` and ``gender`` attributes.

    Returns:
        dict: ``user_id`` and ``session_id`` from the request, followed by
        the ``input``, ``output``, ``total_tokens``, ``total_cost``,
        ``prompt_tokens``, ``completion_tokens`` and ``response_type``
        entries taken from the executor result. A key missing from that
        result propagates as the lookup error.
    """
    uid, sid = post_data.user_id, post_data.session_id

    # The session key combines user and session ids; how the executor uses it
    # is defined in CustomAgentExecutor (presumably to select the memory
    # buffer -- verify there).
    result = agent_executor(
        {"input": post_data.message, "gender": post_data.gender},
        callbacks=[OpenAITokenRecordCallbackHandler(), log_handler],
        session_key=f"buffer:{uid}:{sid}",
    )

    # Entries copied from the executor result, in response order.
    # message_id, create_time, conversion and gpt_response_time existed only
    # as commented-out fields and are not part of the reply.
    forwarded_keys = (
        'input',
        'output',
        'total_tokens',
        'total_cost',
        'prompt_tokens',
        'completion_tokens',
        'response_type',
    )
    api_response = {'user_id': uid, 'session_id': sid}
    api_response.update((key, result[key]) for key in forwarded_keys)

    logging.info(json.dumps(api_response))
    return api_response