feat chat robot 接口迁移
This commit is contained in:
7
app/service/chat_robot/script/agents/__init__.py
Normal file
7
app/service/chat_robot/script/agents/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
from .agent_executor import CustomAgentExecutor
|
||||
from .conversational_functions_agent import ConversationalFunctionsAgent
|
||||
|
||||
__all__ = [
|
||||
"CustomAgentExecutor",
|
||||
"ConversationalFunctionsAgent"
|
||||
]
|
||||
132
app/service/chat_robot/script/agents/agent_executor.py
Normal file
132
app/service/chat_robot/script/agents/agent_executor.py
Normal file
@@ -0,0 +1,132 @@
|
||||
import inspect
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional, Union, Tuple
|
||||
|
||||
from langchain.agents import AgentExecutor
|
||||
from langchain.callbacks.manager import Callbacks, CallbackManager
|
||||
from langchain.load.dump import dumpd
|
||||
from langchain.schema import RUN_KEY, RunInfo
|
||||
from langchain_core.agents import AgentAction, AgentFinish
|
||||
|
||||
|
||||
class CustomAgentExecutor(AgentExecutor):
    """AgentExecutor variant that threads a ``session_key`` through input
    preparation and memory persistence, so conversation history can be stored
    under a per-user/per-session key (see UserConversationBufferWindowMemory).
    """

    def __call__(
        self,
        inputs: Union[Dict[str, Any], Any],
        return_only_outputs: bool = False,
        callbacks: Callbacks = None,
        session_key: str = "",
        *,
        tags: Optional[List[str]] = None,
        include_run_info: bool = False,
    ) -> Dict[str, Any]:
        """Run the logic of this chain and add to output if desired.

        Args:
            inputs: Dictionary of inputs, or single input if chain expects
                only one param.
            return_only_outputs: boolean for whether to return only outputs in
                the response. If True, only new keys generated by this chain
                will be returned. If False, both input keys and new keys
                generated by this chain will be returned. Defaults to False.
            callbacks: Callbacks to use for this chain run. If not provided,
                will use the callbacks provided to the chain.
            session_key: opaque key identifying the conversation; forwarded to
                memory load/save so each session keeps its own history.
            tags: optional tags passed to the callback manager.
            include_run_info: Whether to include run info in the response.
                Defaults to False.
        """
        inputs = self.prep_inputs(inputs, session_key)
        callback_manager = CallbackManager.configure(
            callbacks, self.callbacks, self.verbose, tags, self.tags
        )
        # Older _call implementations may not accept run_manager; detect it.
        new_arg_supported = inspect.signature(self._call).parameters.get("run_manager")
        run_manager = callback_manager.on_chain_start(
            dumpd(self),
            inputs,
        )
        try:
            outputs = (
                self._call(inputs, run_manager=run_manager)
                if new_arg_supported
                else self._call(inputs)
            )
        except (KeyboardInterrupt, Exception) as e:
            logging.exception(e)
            run_manager.on_chain_error(e)
            # Bare `raise` re-raises with the original traceback intact
            # (the original `raise e` re-anchored the traceback here).
            raise
        # NOTE: on_chain_end runs callbacks (e.g. the token-record handler)
        # that mutate `outputs` in place; prep_outputs relies on that below.
        run_manager.on_chain_end(outputs)
        final_outputs: Dict[str, Any] = self.prep_outputs(
            inputs, outputs, return_only_outputs, session_key
        )
        if include_run_info:
            final_outputs[RUN_KEY] = RunInfo(run_id=run_manager.run_id)
        return final_outputs

    def prep_outputs(
        self,
        inputs: Dict[str, str],
        outputs: Dict[str, str],
        return_only_outputs: bool = False,
        session_key: str = ""
    ) -> Dict[str, str]:
        """Validate outputs and persist the turn to memory.

        The turn is saved under ``session_key`` only when ``need_record`` is
        truthy. NOTE(review): ``outputs['need_record']`` is injected by
        OpenAITokenRecordCallbackHandler.on_chain_end (triggered just before
        this method in __call__); a KeyError here means that callback was not
        installed for the run — confirm whether a .get() default is wanted.
        """
        self._validate_outputs(outputs)
        if self.memory is not None and outputs['need_record']:
            self.memory.save_context(inputs, outputs, session_key)
        if return_only_outputs:
            return outputs
        return {**inputs, **outputs}

    def prep_inputs(self, inputs: Union[Dict[str, Any], Any], session_key: str = "") -> Dict[str, str]:
        """Validate and prep inputs, loading session memory into them.

        A bare (non-dict) value is accepted only when exactly one input key
        remains after memory variables are accounted for.
        """
        if not isinstance(inputs, dict):
            _input_keys = set(self.input_keys)
            if self.memory is not None:
                # If there are multiple input keys, but some get set by memory
                # so that only one is not set, we can still figure out which
                # key it is.
                _input_keys = _input_keys.difference(self.memory.memory_variables)
            if len(_input_keys) != 1:
                raise ValueError(
                    f"A single string input was passed in, but this chain expects "
                    f"multiple inputs ({_input_keys}). When a chain expects "
                    f"multiple inputs, please call it by passing in a dictionary, "
                    "eg `chain({'foo': 1, 'bar': 2})`"
                )
            inputs = {list(_input_keys)[0]: inputs}
        if self.memory is not None:
            # Memory is keyed by session so each conversation sees only its
            # own history.
            external_context = self.memory.load_memory_variables(inputs, session_key)
            inputs = dict(inputs, **external_context)
        self._validate_inputs(inputs)
        return inputs

    def _get_tool_return(
        self, next_step_output: Tuple[AgentAction, str]
    ) -> Optional[AgentFinish]:
        """Check if the tool should end the run and return its raw output.

        Beyond the standard ``return_direct`` handling, a successful
        ``sql_db_query`` call whose observation parses to a non-empty JSON
        list finishes immediately so the raw JSON reaches the caller.
        """
        agent_action, observation = next_step_output
        name_to_tool_map = {tool.name: tool for tool in self.tools}
        return_value_key = "output"
        if len(self.agent.return_values) > 0:
            return_value_key = self.agent.return_values[0]

        # Bug fix: the original used a bare `except: pass`, which also
        # swallowed KeyboardInterrupt/SystemExit. Only JSON-parse failures
        # (ValueError covers json.JSONDecodeError; TypeError covers non-str
        # observations) should fall through to the normal path.
        try:
            observation_list = json.loads(observation)
        except (TypeError, ValueError):
            observation_list = None
        if (
            agent_action.tool == "sql_db_query"
            and isinstance(observation_list, list)
            and observation_list
        ):
            return AgentFinish(
                {return_value_key: observation},
                "",
            )

        # Invalid tools won't be in the map, so we return None.
        if agent_action.tool in name_to_tool_map:
            if name_to_tool_map[agent_action.tool].return_direct:
                return AgentFinish(
                    {return_value_key: observation},
                    "",
                )
        return None
|
||||
@@ -0,0 +1,198 @@
|
||||
import json
|
||||
import re
|
||||
from json import JSONDecodeError
|
||||
from typing import List, Tuple, Any, Union
|
||||
from dataclasses import dataclass
|
||||
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.agents import (
|
||||
OpenAIFunctionsAgent,
|
||||
)
|
||||
from langchain.schema import (
|
||||
AgentAction,
|
||||
AgentFinish,
|
||||
BaseMessage,
|
||||
OutputParserException
|
||||
)
|
||||
from langchain.schema.messages import (
|
||||
AIMessage,
|
||||
FunctionMessage
|
||||
)
|
||||
from langchain.tools import BaseTool, StructuredTool
|
||||
# from langchain.tools.convert_to_openai import FunctionDescription
|
||||
from langchain.utils.openai_functions import FunctionDescription
|
||||
|
||||
|
||||
@dataclass
class _FunctionsAgentAction(AgentAction):
    """AgentAction that additionally carries the chat messages which produced
    it, so the original AI turn can be replayed into the scratchpad later
    (see _convert_agent_action_to_messages).
    """
    # Prior messages (typically the single AIMessage holding the
    # function_call) that led to this action.
    message_log: List[BaseMessage]

    def __init__(
        self, tool: str, tool_input: Union[str, dict], log: str, **kwargs: Any
    ):
        """Override init to support instantiation by position for backward compat."""
        # Delegate to the base class by keyword; extra fields such as
        # message_log arrive via **kwargs.
        super().__init__(tool=tool, tool_input=tool_input, log=log, **kwargs)
|
||||
|
||||
|
||||
def _convert_agent_action_to_messages(
    agent_action: AgentAction, observation: str
) -> List[BaseMessage]:
    """Rebuild the chat messages corresponding to one agent action.

    For a _FunctionsAgentAction the original AI message(s) are replayed,
    followed by a FunctionMessage holding the tool's observation; any other
    action is represented as a plain AIMessage of its log text.

    Args:
        agent_action: the action to convert.
        observation: the tool output produced by that action.

    Returns:
        Messages reconstructing the original tool invocation.
    """
    if not isinstance(agent_action, _FunctionsAgentAction):
        return [AIMessage(content=agent_action.log)]
    function_result = _create_function_message(agent_action, observation)
    return [*agent_action.message_log, function_result]
|
||||
|
||||
|
||||
def _create_function_message(
    agent_action: AgentAction, observation: str
) -> FunctionMessage:
    """Wrap a tool invocation's result in a FunctionMessage.

    Args:
        agent_action: the tool invocation request from the agent.
        observation: the result of the tool invocation (usually a string).

    Returns:
        FunctionMessage named after the tool, carrying the observation as a
        string (non-strings are JSON-encoded, falling back to str()).
    """
    if isinstance(observation, str):
        content = observation
    else:
        try:
            content = json.dumps(observation, ensure_ascii=False)
        except Exception:
            # Not JSON-serializable: degrade to its repr-ish string form.
            content = str(observation)
    return FunctionMessage(
        name=agent_action.tool,
        content=content,
    )
|
||||
|
||||
|
||||
def _format_intermediate_steps(
    intermediate_steps: List[Tuple[AgentAction, str]],
) -> List[BaseMessage]:
    """Flatten (action, observation) pairs into scratchpad chat messages.

    Args:
        intermediate_steps: steps the LLM has taken to date, along with
            their observations.

    Returns:
        Messages to send to the LLM for the next prediction, in order.
    """
    return [
        message
        for action, observation in intermediate_steps
        for message in _convert_agent_action_to_messages(action, observation)
    ]
|
||||
|
||||
|
||||
def _format_tool_to_openai_function(tool: BaseTool) -> FunctionDescription:
|
||||
"""Format tools into the OpenAI function API."""
|
||||
parameters = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": tool.param_description if hasattr(tool, 'param_description') else "",
|
||||
},
|
||||
},
|
||||
"required": ["query"],
|
||||
}
|
||||
|
||||
return {
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
"parameters": parameters,
|
||||
}
|
||||
|
||||
|
||||
def _parse_ai_message(message: BaseMessage) -> Union[AgentAction, AgentFinish]:
    """Translate the model's reply into an agent decision.

    A reply carrying a ``function_call`` becomes a _FunctionsAgentAction for
    that tool; anything else finishes the run with the message content as the
    final output.

    Raises:
        TypeError: if *message* is not an AIMessage.
        OutputParserException: if the function_call arguments are not valid
            JSON.
    """
    if not isinstance(message, AIMessage):
        raise TypeError(f"Expected an AI message but got {type(message)}")

    function_call = message.additional_kwargs.get("function_call", {})
    if not function_call:
        # No tool requested: the content itself is the final answer.
        return AgentFinish(return_values={"output": message.content}, log=message.content)

    function_name = function_call["name"]
    try:
        _tool_input = json.loads(function_call["arguments"])
    except JSONDecodeError:
        raise OutputParserException(
            f"Could not parse tools input: {function_call} because"
            f"the `arguments` is not valid JSON."
        )

    # Tools here take a single "query" argument; unwrap it when present,
    # otherwise pass the parsed arguments through unchanged.
    if "query" in _tool_input:
        tool_input = _tool_input["query"]
    else:
        tool_input = _tool_input

    return _FunctionsAgentAction(
        tool=function_name,
        tool_input=tool_input,
        log=f"\nInvoking: `{function_name}` with `{tool_input}`\n",
        message_log=[message]
    )
|
||||
|
||||
|
||||
class ConversationalFunctionsAgent(OpenAIFunctionsAgent):
    """OpenAIFunctionsAgent variant whose function specs and message parsing
    come from this module's helpers: every tool is exposed with a single
    "query" string parameter, and replies are decoded by _parse_ai_message.
    """

    @property
    def functions(self) -> List[dict]:
        # One OpenAI function spec per tool, built by the module-level
        # single-"query"-parameter formatter.
        return [dict(_format_tool_to_openai_function(t)) for t in self.tools]

    def plan(self,
             intermediate_steps: List[Tuple[AgentAction, str]],
             callbacks: Callbacks = None,
             **kwargs: Any
             ) -> Union[AgentAction, AgentFinish]:
        """Decide how the agent should move after receiving an input.

        The difference from OpenAIFunctionsAgent lies in the
        '_parse_ai_message' function, which adds an output-parsing step.

        Args:
            intermediate_steps: Steps the LLM has taken to date, along with
                observations.
            callbacks: Callbacks to pass to the LLM prediction.
            **kwargs: User inputs (must cover the prompt's input variables
                other than "agent_scratchpad").

        Returns:
            Action specifying what tool to use, or an AgentFinish.
        """
        agent_scratchpad: List[BaseMessage] = _format_intermediate_steps(intermediate_steps)
        # Only forward the prompt variables the template expects; the
        # scratchpad is supplied separately below.
        selected_inputs = {
            k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad"
        }
        full_inputs = dict(**selected_inputs, agent_scratchpad=agent_scratchpad)
        prompt = self.prompt.format_prompt(**full_inputs)
        messages: List[BaseMessage] = prompt.to_messages()
        predicted_message = self.llm.predict_messages(
            messages, functions=self.functions, callbacks=callbacks
        )
        agent_decision = _parse_ai_message(predicted_message)
        return agent_decision
|
||||
6
app/service/chat_robot/script/callbacks/__init__.py
Normal file
6
app/service/chat_robot/script/callbacks/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
from .openai_token_record_callback import OpenAITokenRecordCallbackHandler
|
||||
|
||||
|
||||
__all__ = [
|
||||
'OpenAITokenRecordCallbackHandler'
|
||||
]
|
||||
@@ -0,0 +1,46 @@
|
||||
"""Callback Handler that add on_chain_end function to record Token usage."""
|
||||
from typing import Any, Dict
|
||||
|
||||
from langchain.callbacks import OpenAICallbackHandler
|
||||
from langchain.schema import LLMResult
|
||||
from langchain.callbacks.openai_info import standardize_model_name, MODEL_COST_PER_1K_TOKENS, get_openai_token_cost_for_model
|
||||
|
||||
|
||||
class OpenAITokenRecordCallbackHandler(OpenAICallbackHandler):
    """Callback Handler that tracks OpenAI info and write to redis after agent finish"""

    # Whether this turn should be saved to conversation memory; flipped to
    # False when certain tools are invoked (see on_llm_end). Consumed by
    # CustomAgentExecutor.prep_outputs via the outputs dict.
    need_record: bool = True
    # "string" by default; becomes "image" once sql_db_query is called.
    response_type: str = "string"

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Collect token usage."""
        if response.llm_output is None:
            return None
        self.successful_requests += 1
        if "token_usage" not in response.llm_output:
            return None
        # NOTE(review): assumes at least one generation exists and that it is
        # a chat generation exposing `.message` — confirm for non-chat LLMs.
        if "function_call" in response.generations[0][0].message.additional_kwargs:
            # Turns that merely invoke SQL/tutorial tools are not recorded as
            # conversation history.
            if response.generations[0][0].message.additional_kwargs["function_call"]["name"] in ["sql_db_query", "sql_db_schema","tutorial_tool"]:
                self.need_record = False
            # A database query means the final answer is image file names.
            if response.generations[0][0].message.additional_kwargs["function_call"]["name"] == "sql_db_query":
                self.response_type = "image"
        token_usage = response.llm_output["token_usage"]
        completion_tokens = token_usage.get("completion_tokens", 0)
        prompt_tokens = token_usage.get("prompt_tokens", 0)
        model_name = standardize_model_name(response.llm_output.get("model_name", ""))
        if model_name in MODEL_COST_PER_1K_TOKENS:
            completion_cost = get_openai_token_cost_for_model(
                model_name, completion_tokens, is_completion=True
            )
            prompt_cost = get_openai_token_cost_for_model(model_name, prompt_tokens)
            self.total_cost += prompt_cost + completion_cost
        self.total_tokens += token_usage.get("total_tokens", 0)
        self.prompt_tokens += prompt_tokens
        self.completion_tokens += completion_tokens

    def on_chain_end(self, outputs: Dict, **kwargs: Any) -> None:
        """Inject accumulated usage and flags into the chain outputs in place.

        Despite the class docstring, this method does not touch redis itself;
        it only mutates `outputs`, which downstream code (prep_outputs / chat)
        reads. Counters accumulate for the handler's lifetime, so a fresh
        handler should be created per request (as main.chat does).
        """
        outputs["total_tokens"] = self.total_tokens
        outputs["total_cost"] = self.total_cost
        outputs["prompt_tokens"] = self.prompt_tokens
        outputs["completion_tokens"] = self.completion_tokens
        outputs["need_record"] = self.need_record
        outputs["response_type"] = self.response_type
|
||||
79
app/service/chat_robot/script/database.py
Normal file
79
app/service/chat_robot/script/database.py
Normal file
@@ -0,0 +1,79 @@
|
||||
from typing import Optional, List
|
||||
import json
|
||||
|
||||
from sqlalchemy import text
|
||||
# from langchain import SQLDatabase
|
||||
from langchain.utilities import SQLDatabase
|
||||
|
||||
|
||||
class CustomDatabase(SQLDatabase):
    """SQLDatabase subclass specialised for the clothing-attribute schema.

    ``get_table_info_no_throw`` replaces the generic schema dump with a
    per-column listing of the DISTINCT values in each table (an enum-style
    summary an LLM can query against), and ``run`` returns results as a JSON
    array of each row's first column.
    """

    def get_table_info_no_throw(self, table_names: Optional[List[str]] = None) -> str:
        """Return an enum-style description of the requested tables.

        Unknown table names produce an error *string* rather than a raised
        ValueError (hence "no_throw"), so the message can be handed straight
        back to the LLM.
        """
        all_table_names = self.get_usable_table_names()
        if table_names is not None:
            missing_tables = set(table_names).difference(all_table_names)
            if missing_tables:
                return f"Table {','.join(missing_tables)} can not be found in the database"
            all_table_names = table_names
        meta_tables = [
            tbl
            for tbl in self._metadata.sorted_tables
            if tbl.name in set(all_table_names)
        ]

        tables = []
        # Bug fix: the original called self._engine.connect() and never closed
        # the connection, leaking one pooled connection per call; the context
        # manager releases it deterministically, including on error.
        with self._engine.connect() as connection:
            for table in meta_tables:
                table_name = table.name
                column_names = table.columns.keys()
                table_info = f"Table: {table_name}\nColumns: \nID, \nimg_name\n"
                for column_name in column_names:
                    if column_name not in ["ID", "img_name"]:
                        # Identifiers come from SQLAlchemy metadata, not user
                        # input, so f-string interpolation is acceptable here.
                        query = text(f"SELECT DISTINCT {column_name} FROM {table_name}")
                        result = connection.execute(query)
                        enum_values: List[str] = [row[0] for row in result.fetchall()]
                        # NOTE(review): assumes every attribute column holds
                        # strings; str.join raises on non-string values —
                        # confirm against the schema.
                        column_info = f"{column_name}: {', '.join(enum_values)}\n"
                        table_info += column_info
                tables.append(table_info)
        final_str = "\n\n".join(tables)
        return final_str

    def run(self, command: str, fetch: str = "all") -> str:
        """Execute a SQL command and return the results as a JSON string.

        If the statement returns rows, a JSON array of each row's first
        column is returned. If the statement returns no rows, an empty
        string is returned.

        Raises:
            ValueError: if *fetch* is neither "one" nor "all".
        """
        with self._engine.begin() as connection:
            if self._schema is not None:
                # Dialect-specific way of scoping the session to the schema.
                if self.dialect == "snowflake":
                    connection.exec_driver_sql(
                        f"ALTER SESSION SET search_path='{self._schema}'"
                    )
                elif self.dialect == "bigquery":
                    connection.exec_driver_sql(f"SET @@dataset_id='{self._schema}'")
                else:
                    connection.exec_driver_sql(f"SET search_path TO {self._schema}")
            cursor = connection.execute(text(command))
            if cursor.rowcount:
                if fetch == "all":
                    result = cursor.fetchall()
                elif fetch == "one":
                    result = cursor.fetchone()  # type: ignore
                else:
                    raise ValueError("Fetch parameter must be either 'one' or 'all'")

                # Keep only the first column of each row (callers query a
                # single img_name column) and JSON-encode to avoid SQLAlchemy
                # truncating long text representations.
                if isinstance(result, list):
                    return json.dumps([r[0] for r in result])
                return json.dumps([result[0]])
            return ""
|
||||
114
app/service/chat_robot/script/main.py
Normal file
114
app/service/chat_robot/script/main.py
Normal file
@@ -0,0 +1,114 @@
|
||||
import logging
|
||||
from loguru import logger
|
||||
from langchain.agents import Tool
|
||||
from langchain.utilities import SerpAPIWrapper
|
||||
from langchain.prompts.chat import ChatPromptTemplate, HumanMessagePromptTemplate, MessagesPlaceholder
|
||||
from langchain.schema import SystemMessage, AIMessage
|
||||
from langchain.chat_models import ChatOpenAI
|
||||
from langchain.llms.openai import OpenAI
|
||||
from langchain.callbacks import FileCallbackHandler
|
||||
from app.service.chat_robot.script.agents import CustomAgentExecutor, ConversationalFunctionsAgent
|
||||
from app.service.chat_robot.script.callbacks import OpenAITokenRecordCallbackHandler
|
||||
from app.service.chat_robot.script.database import CustomDatabase
|
||||
from app.service.chat_robot.script.prompt import FASHION_CHAT_BOT_PREFIX, TOOLS_FUNCTIONS_SUFFIX
|
||||
from app.service.chat_robot.script.tools import (QuerySQLDataBaseTool, InfoSQLDatabaseTool, QuerySQLCheckerTool, ListSQLDatabaseTool)
|
||||
from app.service.chat_robot.script.memory import UserConversationBufferWindowMemory
|
||||
from app.service.chat_robot.script.tools.tutorial_tool import CustomTutorialTool
|
||||
from app.core.config import *
|
||||
|
||||
import os
|
||||
|
||||
# os.environ["http_proxy"] = "http://127.0.0.1:7890"
# os.environ["https_proxy"] = "http://127.0.0.1:7890"

# log callbacks: loguru and the langchain FileCallbackHandler both write to
# the same debug log file.
logfile = "logs/chat_debug.log"
logger.add(logfile, colorize=True, enqueue=True)
log_handler = FileCallbackHandler(logfile)

# Initiate our LLM 'gpt-3.5-turbo'
llm = ChatOpenAI(temperature=0.1,
                 openai_api_key=OPENAI_API_KEY,
                 # callbacks=[OpenAICallbackHandler()]
                 )

search = SerpAPIWrapper()
# NOTE(review): this connects to MySQL at import time (module-level side
# effect); pool_recycle avoids stale connections on long-lived workers.
db = CustomDatabase.from_uri(f'mysql+pymysql://{DB_USERNAME}:{DB_PASSWORD}@{DB_HOST}:{DB_PORT}/attribute_retrieval_V3',
                             include_tables=['female_top', 'female_skirt', 'female_pants', 'female_dress',
                                             'female_outwear', 'male_bottom', 'male_top', 'male_outwear'],
                             engine_args={"pool_recycle": 7200})
tools = [
    Tool(
        name="internet_search",
        description="Can be used to perform Internet searches",
        func=search.run
    ),
    QuerySQLDataBaseTool(db=db, return_direct=False),
    InfoSQLDatabaseTool(db=db),
    ListSQLDatabaseTool(db=db),
    QuerySQLCheckerTool(db=db, llm=OpenAI(temperature=0, openai_api_key=OPENAI_API_KEY)),
    Tool(
        name="tutorial_tool",
        # NOTE(review): func is given a CustomTutorialTool *instance*, not a
        # plain function — confirm the class is callable (defines __call__).
        description="Utilize this tool to retrieve specific statements related to user guidance tutorials."
                    "Input is an empty string",
        func=CustomTutorialTool(),
        return_direct=True
    )
]

# Prompt layout: system persona, prior history, the user's turn (with
# gender), a fixed AI "reasoning hints" message, then the scratchpad.
messages = [
    SystemMessage(content=FASHION_CHAT_BOT_PREFIX),
    MessagesPlaceholder(variable_name="history"),
    HumanMessagePromptTemplate.from_template(
        "{input} "
        "Question from a {gender}."
    ),
    AIMessage(content=TOOLS_FUNCTIONS_SUFFIX),
    MessagesPlaceholder(variable_name="agent_scratchpad"),
]

prompt = ChatPromptTemplate(input_variables=["input", "gender", "agent_scratchpad", "history"], messages=messages)
agent = ConversationalFunctionsAgent(
    llm=llm,
    tools=tools,
    prompt=prompt
)

# Redis-backed windowed memory: keeps the last k=2 exchanges per session.
memory = UserConversationBufferWindowMemory.from_redis(
    return_messages=True, k=2, input_key='input', output_key='output'
)
agent_executor = CustomAgentExecutor.from_agent_and_tools(
    agent=agent,
    tools=tools,
    verbose=True,
    memory=memory,
)
|
||||
|
||||
|
||||
def chat(post_data):
    """Run one conversation turn through the module-level agent executor.

    Args:
        post_data: request object exposing ``user_id``, ``session_id``,
            ``message`` and ``gender`` attributes.

    Returns:
        dict with the echoed identifiers, the input/output text, and the
        token-accounting fields that OpenAITokenRecordCallbackHandler
        injected into the chain outputs.
    """
    user_id = post_data.user_id
    session_id = post_data.session_id
    input_message = post_data.message
    gender = post_data.gender

    # A fresh token-record handler per request so its counters and the
    # need_record/response_type flags never leak between calls; session_key
    # scopes the redis-backed memory to this user+session.
    final_outputs = agent_executor(
        {"input": input_message, "gender": gender},
        callbacks=[OpenAITokenRecordCallbackHandler(), log_handler],
        session_key=f"buffer:{user_id}:{session_id}",
    )
    api_response = {
        'user_id': user_id,
        'session_id': session_id,
        # 'message_id': message_id,
        # 'create_time': created_time,
        'input': final_outputs['input'],
        # 'conversion': messages,
        'output': final_outputs['output'],
        # 'gpt_response_time': gpt_response_time,
        'total_tokens': final_outputs['total_tokens'],
        'total_cost': final_outputs['total_cost'],
        'prompt_tokens': final_outputs['prompt_tokens'],
        'completion_tokens': final_outputs['completion_tokens'],
        'response_type': final_outputs['response_type']
    }
    logging.info(api_response)
    return api_response
|
||||
3
app/service/chat_robot/script/memory/__init__.py
Normal file
3
app/service/chat_robot/script/memory/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .user_buffer_window import UserConversationBufferWindowMemory
|
||||
|
||||
__all__ = ['UserConversationBufferWindowMemory']
|
||||
93
app/service/chat_robot/script/memory/user_buffer_window.py
Normal file
93
app/service/chat_robot/script/memory/user_buffer_window.py
Normal file
@@ -0,0 +1,93 @@
|
||||
import logging
|
||||
from typing import Any, Dict, List, Tuple
|
||||
import json
|
||||
|
||||
import redis
|
||||
from redis import Redis
|
||||
from langchain.memory import RedisChatMessageHistory
|
||||
from langchain.memory.chat_memory import BaseChatMemory
|
||||
from langchain.schema.messages import BaseMessage, get_buffer_string, HumanMessage, AIMessage
|
||||
from langchain.schema.messages import _message_to_dict, messages_from_dict
|
||||
from langchain.memory.utils import get_prompt_input_key
|
||||
|
||||
from app.core.config import *
|
||||
|
||||
|
||||
class UserConversationBufferWindowMemory(BaseChatMemory):
    """Buffer for storing conversation memory.

    Messages are kept in a redis list per session key (newest first, via
    lpush), and only the most recent ``k`` exchanges are loaded back.
    """

    # Shared redis connection; see from_redis for construction.
    redis_client: Redis
    human_prefix: str = "Human"
    ai_prefix: str = "AI"
    memory_key: str = "history" #: :meta private:
    # Number of conversation exchanges (human+AI pairs) to load.
    k: int = 5

    @classmethod
    def from_redis(
        cls,
        host: str = REDIS_HOST,
        port: int = REDIS_PORT,
        db: int = 3,
        **kwargs
    ):
        """Alternate constructor that builds the redis client and pings it.

        Connection failures are only logged; the instance is returned either
        way, so later redis calls may still fail at use time.
        """
        redis_client = Redis(host=host, port=port, db=db)
        try:
            response = redis_client.ping()
            if response:
                print("Connect to redis server successfully.")
                logging.info("Connect to redis server successfully.")
            else:
                print("Fail to connect to redis server")
                logging.info("Fail to connect to redis server")
        except redis.RedisError as e:
            logging.info(f"Error occurs when connecting to redis server: {str(e)}")
        return cls(redis_client=redis_client, **kwargs)

    @property
    def memory_variables(self) -> List[str]:
        """Will always return list of memory variables.

        :meta private:
        """
        return [self.memory_key]

    def load_memory_variables(self, inputs: Dict[str, Any], key: str = "") -> Dict[str, str]:
        """Return history buffer for the session stored under *key*."""
        # NOTE(review): lrange bounds are inclusive, so this fetches
        # self.k * 2 + 1 entries (one more than k exchanges) — confirm whether
        # the end index should be self.k * 2 - 1.
        _items: Any = self.redis_client.lrange(key, 0, self.k * 2) if self.k > 0 else []
        # lpush stores newest first; reverse to chronological order.
        items = [json.loads(m.decode("utf-8")) for m in _items[::-1]]
        buffer = messages_from_dict(items)
        if not self.return_messages:
            buffer = get_buffer_string(
                buffer,
                human_prefix=self.human_prefix,
                ai_prefix=self.ai_prefix,
            )
        return {self.memory_key: buffer}

    def _get_input_output(
        self, inputs: Dict[str, Any], outputs: Dict[str, str]
    ) -> Tuple[str, str]:
        """Resolve the single input and output strings for this turn,
        using the configured keys or inferring them when unset."""
        if self.input_key is None:
            prompt_input_key = get_prompt_input_key(inputs, self.memory_variables)
        else:
            prompt_input_key = self.input_key
        if self.output_key is None:
            if len(outputs) != 1:
                raise ValueError(f"One output key expected, got {outputs.keys()}")
            output_key = list(outputs.keys())[0]
        else:
            output_key = self.output_key
        return inputs[prompt_input_key], outputs[output_key]

    def add_message(self, key: str, message: BaseMessage) -> None:
        """Prepend one serialized message to the session's redis list."""
        self.redis_client.lpush(key, json.dumps(_message_to_dict(message)))

    def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str], key: str = "") -> None:
        """Save context from this conversation to buffer."""
        input_str, output_str = self._get_input_output(inputs, outputs)
        self.add_message(key, HumanMessage(content=input_str))
        self.add_message(key, AIMessage(content=output_str))

    # def clear(self, key) -> None:
    #     """Clear memory contents."""
    #     self.redis_client.delete(key)
|
||||
52
app/service/chat_robot/script/prompt.py
Normal file
52
app/service/chat_robot/script/prompt.py
Normal file
@@ -0,0 +1,52 @@
|
||||
# System message: persona, tool-selection guidance and answer-length limit
# for the fashion chat bot (used as the SystemMessage in main.py).
FASHION_CHAT_BOT_PREFIX = """
You are a helpful assistant for fashion designers. You can chat with the users or answer their query as much as you can.
The most crucial aspect is to accurately determine whether the user's inquiry requires a internet search or querying the database.
Remember your answer should be very precise and the final output answer should not exceed 20 words.

You may encounter the following types of questions:
1) If the query related to clothing retrieval, you are an agent designed to interact with a SQL database.
Given an input question, create a syntactically correct mysql query to run, always fetching random data from tables.
Unless the user specifies a specific number of examples they wish to obtain,always limit your query to at most 4 results.
Never query for all the columns from a specific table, only ask for the relevant columns given the question.
You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again.
DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.
If the question does not seem related to the database, just return "I don't know" as the answer.

2) If the query related to current events, you should use internet_search to seek help from the internet.

3) If the query is just casual conversation, engage in the conversation as a fashion designer assistant.

Be careful to use the tools, since you are actually a chat bot. Tools can only be used when essential.
"""

# Alternative suffixes; only TOOLS_FUNCTIONS_SUFFIX is referenced by main.py.
TOOL_SELECT_SUFFIX = """
Prior to proceeding, it is essential to carefully assess the question and select the appropriate tools or approach accordingly.
For database-related questions, use SQL tools to identify relevant tables and query their schemas.
The use of online resources should be limited to inquiry pertaining to current subjects.
"""

SQL_FUNCTIONS_SUFFIX = """
For database-related questions, use SQL tools to identify relevant tables and query their schemas.
"""

INTERNET_SEARCH_SUFFIX = """
If the question should be answered using internet search tools, I should seek help from the internet.
"""

ANSWER_FORMAT_SUFFIX = """
My final answer are limited to 20 words and be as much precise as possible.
"""

# Injected as a fixed AIMessage before the scratchpad: first-person hints on
# SQL shape (ORDER BY RAND()), the no-result fallback, and tutorial routing.
TOOLS_FUNCTIONS_SUFFIX = (
    "If the input involves clothing queries,"
    "I should look at the tables in the database to see what I can query. Then I should query the schema of the most relevant tables."
    "All SQL statements must use 'ORDER BY RAND()', for example:"
    "Example Input 1: 'SELECT img_name FROM skirt WHERE opening_type = 'Button' ORDER BY RAND() LIMIT 1'"
    "Example Input 2: 'SELECT img_name FROM top WHERE sleeve_length = 'Long' AND type = 'Blouse' ORDER BY RAND() LIMIT 2'"
    "If the input does not involve clothing queries, "
    "I should engage in conversation as an assistant or search from internet with internet_search tool."
    "If the database query returns no results, please respond directly with: 'Apologies, I couldn't find any images that match your description. Could you please give me more details about the clothing you're searching for?'"
    "Upon mentioning words related to 'tutorial' in the input, I should use tutorial_tool "
)

# Canned reply returned when the tutorial tool fires.
TUTORIAL_TOOL_RETURN = "Commencing the systematic tutorial guide now."
|
||||
10
app/service/chat_robot/script/tools/__init__.py
Normal file
10
app/service/chat_robot/script/tools/__init__.py
Normal file
@@ -0,0 +1,10 @@
|
||||
from .sql_tools import (
|
||||
QuerySQLDataBaseTool,
|
||||
InfoSQLDatabaseTool,
|
||||
ListSQLDatabaseTool,
|
||||
QuerySQLCheckerTool
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"QuerySQLCheckerTool", "InfoSQLDatabaseTool", "ListSQLDatabaseTool", "QuerySQLDataBaseTool"
|
||||
]
|
||||
183
app/service/chat_robot/script/tools/sql_tools.py
Normal file
183
app/service/chat_robot/script/tools/sql_tools.py
Normal file
@@ -0,0 +1,183 @@
|
||||
# flake8: noqa
|
||||
"""Tools for interacting with a SQL database."""
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from pydantic import BaseModel, Extra, Field, root_validator
|
||||
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForToolRun,
|
||||
CallbackManagerForToolRun,
|
||||
)
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.prompts import PromptTemplate
|
||||
# from langchain.sql_database import SQLDatabase
|
||||
from langchain.utilities import SQLDatabase
|
||||
from langchain.tools.base import BaseTool
|
||||
from langchain.tools.sql_database.prompt import QUERY_CHECKER
|
||||
|
||||
|
||||
class BaseSQLDatabaseTool(BaseModel):
    """Common base for the SQL tools: carries the shared database handle."""

    # Live database connection; excluded from pydantic serialization.
    db: SQLDatabase = Field(exclude=True)
    # Optional extra description of tool parameters (empty by default).
    param_description: str = ""

    # Override BaseTool.Config to appease mypy
    # See https://github.com/pydantic/pydantic/issues/4173
    class Config(BaseTool.Config):
        """Pydantic settings: permit the SQLDatabase type, forbid unknown fields."""

        extra = Extra.forbid
        arbitrary_types_allowed = True
class QuerySQLDataBaseTool(BaseSQLDatabaseTool, BaseTool):
    """Tool for querying a SQL database.

    Runs a SELECT statement against the configured database and returns the
    rows as a string. Errors are returned as text (never raised) so the agent
    can read the message, rewrite the query, and retry.
    """

    name = "sql_db_query"
    # NOTE: adjacent string literals are concatenated at compile time; each
    # ends with a space so sentences do not run together in the LLM prompt.
    description: str = (
        "The input of this tool is a detailed and correct SQL select query statement, "
        "and the output is the result of the database, and it can only return up to 4 results. "
        "If the query is not correct, an error message will be returned. "
        "If an error is returned, rewrite the query, check the query, and try again. "
        "If you encounter an issue with Unknown column 'xxxx' in 'field list' or Table 'attribute_retrieval.xxxx' doesn't exist, "
        "use sql_db_schema to query the correct table fields. "
        "Example Input: 'SELECT img_name FROM skirt WHERE opening_type = 'Button' ORDER BY RAND() "
        "LIMIT 1' "
        "Example Input 2: 'SELECT img_name FROM top WHERE sleeve_length = 'Long' AND type = 'Blouse' "
        "order by rand() LIMIT 2'"
    )

    def _run(
        self,
        query: str,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> str:
        """Execute the query, return the results or an error message."""
        # run_no_throw converts DB errors into a returned string.
        return self.db.run_no_throw(query)

    async def _arun(
        self,
        query: str,
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> str:
        raise NotImplementedError("QuerySqlDbTool does not support async")
||||
class InfoSQLDatabaseTool(BaseSQLDatabaseTool, BaseTool):
    """Tool for getting metadata (schema and sample rows) about a SQL database."""

    name = "sql_db_schema"
    # NOTE: adjacent string literals are concatenated at compile time; each
    # ends with a space so sentences do not run together in the LLM prompt.
    description: str = (
        "Input to this tool is a comma-separated list of tables, output is the schema and sample rows for those tables. "
        "There are eight tables covering eight fashion categories: female_top, female_pants, female_dress, "
        "female_skirt, female_outwear, male_bottom, male_top, and male_outwear. "
        "Example Input: 'female_outwear, male_top'"
    )

    def _run(
        self,
        table_names: str,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> str:
        """Get the schema for tables in a comma-separated list.

        Tolerates commas with or without surrounding whitespace (the former
        split(", ") silently produced bogus names for input like 'a,b').
        """
        names = [name.strip() for name in table_names.split(",") if name.strip()]
        return self.db.get_table_info_no_throw(names)

    async def _arun(
        self,
        table_name: str,
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> str:
        raise NotImplementedError("SchemaSqlDbTool does not support async")
||||
class ListSQLDatabaseTool(BaseSQLDatabaseTool, BaseTool):
    """Tool that lists the names of the tables available in the database."""

    name = "sql_db_list_tables"
    description = "Input is an empty string, output is a comma separated list of tables in the database."

    def _run(
        self,
        tool_input: str = "",
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> str:
        """Return every usable table name as a comma-separated string."""
        table_names = self.db.get_usable_table_names()
        return ", ".join(table_names)

    async def _arun(
        self,
        tool_input: str = "",
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> str:
        raise NotImplementedError("ListTablesSqlDbTool does not support async")
||||
class QuerySQLCheckerTool(BaseSQLDatabaseTool, BaseTool):
    """Use an LLM to check if a query is correct.

    Adapted from https://www.patterns.app/blog/2023/01/18/crunchbot-sql-analyst-gpt/"""

    template: str = QUERY_CHECKER
    llm: BaseLanguageModel
    llm_chain: LLMChain = Field(init=False)
    name = "sql_db_query_checker"
    # NOTE: adjacent string literals are concatenated at compile time; the
    # first ends with a space so the two sentences don't run together.
    description = (
        "Use this tools to double check if your query is correct before executing it. "
        "Always use this tools before executing a query with sql_db_query!"
    )

    @root_validator(pre=True)
    def initialize_llm_chain(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Build the checker LLMChain if absent and validate its prompt.

        The prompt MUST be created with input_variables=["dialect", "query"]
        in exactly that order: the guard below compares lists (order matters),
        so constructing it as ["query", "dialect"] made every instantiation
        raise ValueError.
        """
        if "llm_chain" not in values:
            values["llm_chain"] = LLMChain(
                llm=values.get("llm"),
                prompt=PromptTemplate(
                    template=QUERY_CHECKER,
                    input_variables=["dialect", "query"],
                ),
            )

        if values["llm_chain"].prompt.input_variables != ["dialect", "query"]:
            raise ValueError(
                "LLM chain for QueryCheckerTool must have input variables ['query', 'dialect']"
            )

        return values

    def _run(
        self,
        query: str,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> str:
        """Use the LLM to check the query."""
        return self.llm_chain.predict(query=query, dialect=self.db.dialect)

    async def _arun(
        self,
        query: str,
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> str:
        return await self.llm_chain.apredict(query=query, dialect=self.db.dialect)
19
app/service/chat_robot/script/tools/tutorial_tool.py
Normal file
19
app/service/chat_robot/script/tools/tutorial_tool.py
Normal file
@@ -0,0 +1,19 @@
|
||||
from typing import Any
|
||||
|
||||
from langchain.tools.base import BaseTool
|
||||
|
||||
from app.service.chat_robot.script.prompt import TUTORIAL_TOOL_RETURN
|
||||
|
||||
|
||||
# Handles input related to the system's guided tutorial.
class CustomTutorialTool(BaseTool):
    """Tool that always returns the canned tutorial kickoff statement."""

    name = "tutorial_tool"

    description = ("Utilize this tool to retrieve specific statements related to user guidance tutorials."
                   "Input is an empty string")

    def _run(self, tool_input, **kwargs: Any) -> str:
        # The input is ignored; the tool's only job is to emit the fixed line.
        return TUTORIAL_TOOL_RETURN

    async def _arun(self, tool_input, **kwargs: Any) -> str:
        raise NotImplementedError("CustomTutorialTool does not support async")
1
app/service/chat_robot/script/utils/__init__.py
Normal file
1
app/service/chat_robot/script/utils/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from .logger import Logger
|
||||
26
app/service/chat_robot/script/utils/logger.py
Normal file
26
app/service/chat_robot/script/utils/logger.py
Normal file
@@ -0,0 +1,26 @@
|
||||
import logging
|
||||
from logging import handlers
|
||||
|
||||
|
||||
class Logger(object):
    """Logger that writes to both the terminal and a time-rotated file.

    Usage: ``Logger("app.log", level="debug").logger.info("message")``
    """

    # Map from the user-facing level name to the stdlib logging constant.
    level_relations = {
        'debug': logging.DEBUG,
        'info': logging.INFO,
        'warning': logging.WARNING,
        'error': logging.ERROR,
        'crit': logging.CRITICAL
    }

    def __init__(self, filename, level='info', when='D', backCount=3,
                 fmt='%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s'):
        """Create (or fetch) the logger named after *filename*.

        Args:
            filename: log file path; also used as the logger name.
            level: one of level_relations' keys; unknown names fall back to INFO.
            when: rotation interval unit for TimedRotatingFileHandler.
            backCount: number of rotated log files to keep.
            fmt: log record format string.
        """
        self.logger = logging.getLogger(filename)
        format_str = logging.Formatter(fmt)  # set log format
        # .get(level) alone returned None for unknown names, which made
        # setLevel raise TypeError; default to INFO instead.
        self.logger.setLevel(self.level_relations.get(level, logging.INFO))
        # logging.getLogger returns the SAME logger for the same name, so
        # re-constructing Logger used to stack duplicate handlers and emit
        # every record multiple times; attach handlers only once.
        if not self.logger.handlers:
            sh = logging.StreamHandler()  # output to terminal
            sh.setFormatter(format_str)  # set format for terminal log
            th = handlers.TimedRotatingFileHandler(filename=filename, when=when, backupCount=backCount,
                                                   encoding='utf-8')  # log into file
            th.setFormatter(format_str)  # set format for file log
            self.logger.addHandler(sh)  # output to terminal
            self.logger.addHandler(th)  # output to file
Reference in New Issue
Block a user