feat: chat robot API migration

zhouchengrong
2024-05-29 11:12:59 +08:00
parent a9dcd444c8
commit 13fec64125
23 changed files with 1139 additions and 1 deletion

app/api/api_chat_robot.py Normal file
View File

@@ -0,0 +1,27 @@
import logging
import time
from fastapi import APIRouter
from app.schemas.chat_robot import ChatRobotModel
from app.service.chat_robot.script.main import chat
router = APIRouter()
logger = logging.getLogger()
@router.post("/chat_robot")
def chat_robot(request_data: ChatRobotModel):
try:
logger.info(f"chat_robot request item is : @@@@@@:{request_data}")
code = 200
message = "access"
start_time = time.time()
data = chat(post_data=request_data)
logger.info(f"chat_robot Run time is @@@@@@:{time.time() - start_time}")
except Exception as e:
code = 400
message = str(e)
data = str(e)
logger.warning(f"chat_robot Run Exception @@@@@@:{e}")
logger.info({"code": code, "message": message, "data": data})
return {"code": code, "message": message, "data": data}

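As a quick reference, here is a minimal client-side sketch for exercising this endpoint; the host and port are assumptions, and the payload fields follow the ChatRobotModel schema added later in this commit.

import requests

payload = {
    "gender": "female",
    "message": "Find me a long-sleeve blouse",
    "session_id": "demo-session",
    "user_id": 1,
}
resp = requests.post("http://localhost:8000/api/chat_robot", json=payload)
print(resp.json())  # {"code": 200, "message": "access", "data": ...} on success
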
app/api/api_prompt_generation.py Normal file
View File

@@ -0,0 +1,28 @@
import logging
import time
from fastapi import APIRouter
from app.schemas.prompt_generation import PromptGenerationImageModel
from app.service.prompt_generation.chatgpt_for_translation import translate_to_en
router = APIRouter()
logger = logging.getLogger()
@router.post("/translateToEN")
def prompt_generation(request_data: PromptGenerationImageModel):
try:
logger.info(f"prompt_translate to English request data : @@@@@@:{request_data}")
code = 200
message = "access"
start_time = time.time()
data = translate_to_en(request_data.text)
logger.info(f"prompt_generation Run time is @@@@@@:{time.time() - start_time}")
except Exception as e:
code = 400
message = str(e)
data = str(e)
logger.warning(f"prompt_generation Run Exception @@@@@@:{e}")
logger.info({"code": code, "message": message, "data": data})
return {"code": code, "message": message, "data": data}

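The same pattern works for the translation endpoint; again, localhost:8000 is an assumption and the body matches the PromptGenerationImageModel schema below.

import requests

resp = requests.post(
    "http://localhost:8000/api/translateToEN",
    json={"text": "一件带拉链的运动夹克"},  # any input language; the service normalizes to English
)
print(resp.json()["data"])
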
View File

@@ -5,6 +5,9 @@ from app.api import api_super_resolution
from app.api import api_generate_image
from app.api import api_attribute_retrieve
from app.api import api_design
from app.api import api_chat_robot
from app.api import api_prompt_generation
router = APIRouter()
@@ -13,3 +16,5 @@ router.include_router(api_super_resolution.router, tags=["super_resolution"], pr
router.include_router(api_generate_image.router, tags=["generate_image"], prefix="/api")
router.include_router(api_attribute_retrieve.router, tags=["attribute_retrieve"], prefix="/api")
router.include_router(api_design.router, tags=['design'], prefix="/api")
router.include_router(api_chat_robot.router, tags=['chat_robot'], prefix="/api")
router.include_router(api_prompt_generation.router, tags=['prompt_generation'], prefix="/api")

app/core/config.py
View File

@@ -19,7 +19,7 @@ class Settings(BaseSettings):
LOGGING_CONFIG_FILE = os.path.join(BASE_DIR, 'logging_env.py')
-DEBUG = False
+DEBUG = True
if DEBUG:
LOGS_PATH = "logs/"
CATEGORY_PATH = "service/attribute/config/descriptor/category/category_dis.csv"
@@ -61,6 +61,32 @@ MILVUS_PORT = "19530"
MILVUS_TABLE_KEYPOINT = "keypoint_cache"
MILVUS_TABLE_SEG = "seg_cache"
# MySQL configuration
DB_HOST = '18.167.251.121'  # database host
# DB_PORT = int( 33006)
DB_PORT = 33008  # database port
DB_USERNAME = 'aida_con_python'  # database username
DB_PASSWORD = '123456'  # database password
DB_NAME = 'aida'  # database name
# openai
os.environ['SERPAPI_API_KEY'] = "a793513017b0718db7966207c31703d280d12435c982f1e67bbcbffa52e7632c"
OPENAI_STREAM = True
BUFFER_THRESHOLD = 6 # must be even number
SINGLE_TOKEN_THRESHOLD = 200
TOKEN_THRESHOLD = 600
OPENAI_TEMPERATURE = 0
# OPENAI_API_KEY = "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"
OPENAI_API_KEY = "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"
OPENAI_MODEL = "gpt-3.5-turbo-0613"
OPENAI_MODEL_LIST = {"gpt-3.5-turbo-0613",
"gpt-3.5-turbo-16k-0613",
"gpt-4-0314",
"gpt-4-32k-0314",
"gpt-4-0613",
"gpt-4-32k-0613", }
# attribute service config
ATT_TRITON_URL = "10.1.1.240:10000"

app/schemas/chat_robot.py Normal file
View File

@@ -0,0 +1,8 @@
from pydantic import BaseModel
class ChatRobotModel(BaseModel):
gender: str
message: str
session_id: str
user_id: int

app/schemas/prompt_generation.py Normal file
View File

@@ -0,0 +1,5 @@
from pydantic import BaseModel
class PromptGenerationImageModel(BaseModel):
text: str

app/service/chat_robot/script/agents/__init__.py Normal file
View File

@@ -0,0 +1,7 @@
from .agent_executor import CustomAgentExecutor
from .conversational_functions_agent import ConversationalFunctionsAgent
__all__ = [
"CustomAgentExecutor",
"ConversationalFunctionsAgent"
]

app/service/chat_robot/script/agents/agent_executor.py Normal file
View File

@@ -0,0 +1,132 @@
import inspect
import json
import logging
from typing import Any, Dict, List, Optional, Union, Tuple
from langchain.agents import AgentExecutor
from langchain.callbacks.manager import Callbacks, CallbackManager
from langchain.load.dump import dumpd
from langchain.schema import RUN_KEY, RunInfo
from langchain_core.agents import AgentAction, AgentFinish
class CustomAgentExecutor(AgentExecutor):
def __call__(
self,
inputs: Union[Dict[str, Any], Any],
return_only_outputs: bool = False,
callbacks: Callbacks = None,
session_key: str = "",
*,
tags: Optional[List[str]] = None,
include_run_info: bool = False,
) -> Dict[str, Any]:
"""Run the logic of this chain and add to output if desired.
Args:
inputs: Dictionary of inputs, or single input if chain expects
only one param.
return_only_outputs: boolean for whether to return only outputs in the
response. If True, only new keys generated by this chain will be
returned. If False, both input keys and new keys generated by this
chain will be returned. Defaults to False.
callbacks: Callbacks to use for this chain run. If not provided, will
use the callbacks provided to the chain.
include_run_info: Whether to include run info in the response. Defaults
to False.
"""
inputs = self.prep_inputs(inputs, session_key)
callback_manager = CallbackManager.configure(
callbacks, self.callbacks, self.verbose, tags, self.tags
)
new_arg_supported = inspect.signature(self._call).parameters.get("run_manager")
run_manager = callback_manager.on_chain_start(
dumpd(self),
inputs,
)
try:
outputs = (
self._call(inputs, run_manager=run_manager)
if new_arg_supported
else self._call(inputs)
)
except (KeyboardInterrupt, Exception) as e:
logging.exception(e)
run_manager.on_chain_error(e)
raise e
run_manager.on_chain_end(outputs)
final_outputs: Dict[str, Any] = self.prep_outputs(
inputs, outputs, return_only_outputs, session_key
)
if include_run_info:
final_outputs[RUN_KEY] = RunInfo(run_id=run_manager.run_id)
return final_outputs
def prep_outputs(
self,
inputs: Dict[str, str],
outputs: Dict[str, str],
return_only_outputs: bool = False,
session_key: str = ""
) -> Dict[str, str]:
"""Validate and prep outputs."""
self._validate_outputs(outputs)
if self.memory is not None and outputs['need_record']:
self.memory.save_context(inputs, outputs, session_key)
if return_only_outputs:
return outputs
else:
return {**inputs, **outputs}
def prep_inputs(self, inputs: Union[Dict[str, Any], Any], session_key: str = "") -> Dict[str, str]:
"""Validate and prep inputs."""
if not isinstance(inputs, dict):
_input_keys = set(self.input_keys)
if self.memory is not None:
# If there are multiple input keys, but some get set by memory so that
# only one is not set, we can still figure out which key it is.
_input_keys = _input_keys.difference(self.memory.memory_variables)
if len(_input_keys) != 1:
raise ValueError(
f"A single string input was passed in, but this chain expects "
f"multiple inputs ({_input_keys}). When a chain expects "
f"multiple inputs, please call it by passing in a dictionary, "
"eg `chain({'foo': 1, 'bar': 2})`"
)
inputs = {list(_input_keys)[0]: inputs}
if self.memory is not None:
external_context = self.memory.load_memory_variables(inputs, session_key)
inputs = dict(inputs, **external_context)
self._validate_inputs(inputs)
return inputs
def _get_tool_return(
self, next_step_output: Tuple[AgentAction, str]
) -> Optional[AgentFinish]:
"""Check if the tool is a returning tool."""
agent_action, observation = next_step_output
name_to_tool_map = {tool.name: tool for tool in self.tools}
return_value_key = "output"
if len(self.agent.return_values) > 0:
return_value_key = self.agent.return_values[0]
        try:
            observation_list = json.loads(observation)
            # A non-empty result list from sql_db_query is returned to the user directly.
            if (
                agent_action.tool == "sql_db_query"
                and isinstance(observation_list, list)
                and len(observation_list) != 0
            ):
                return AgentFinish(
                    {return_value_key: observation},
                    "",
                )
        except (json.JSONDecodeError, TypeError):
            pass
# Invalid tools won't be in the map, so we return False.
if agent_action.tool in name_to_tool_map:
if name_to_tool_map[agent_action.tool].return_direct:
return AgentFinish(
{return_value_key: observation},
"",
)
return None

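The sql_db_query short-circuit in _get_tool_return above is the key behavioral change. A stdlib-only sketch of the rule it applies (the file names are made up):

import json

def finishes_directly(tool: str, observation: str) -> bool:
    # Mirrors _get_tool_return: a non-empty JSON list returned by sql_db_query
    # ends the run immediately instead of handing the observation back to the LLM.
    try:
        parsed = json.loads(observation)
        return tool == "sql_db_query" and isinstance(parsed, list) and len(parsed) != 0
    except (json.JSONDecodeError, TypeError):
        return False

print(finishes_directly("sql_db_query", '["dress_001.jpg"]'))  # True  -> wrapped in AgentFinish
print(finishes_directly("sql_db_query", "[]"))                 # False -> normal agent loop continues
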
app/service/chat_robot/script/agents/conversational_functions_agent.py Normal file
View File

@@ -0,0 +1,198 @@
import json
import re
from json import JSONDecodeError
from typing import List, Tuple, Any, Union
from dataclasses import dataclass
from langchain.callbacks.manager import Callbacks
from langchain.agents import (
OpenAIFunctionsAgent,
)
from langchain.schema import (
AgentAction,
AgentFinish,
BaseMessage,
OutputParserException
)
from langchain.schema.messages import (
AIMessage,
FunctionMessage
)
from langchain.tools import BaseTool, StructuredTool
# from langchain.tools.convert_to_openai import FunctionDescription
from langchain.utils.openai_functions import FunctionDescription
@dataclass
class _FunctionsAgentAction(AgentAction):
"""Add message_log to AgentAction class for the _FunctionAgentAction
"""
message_log: List[BaseMessage]
def __init__(
self, tool: str, tool_input: Union[str, dict], log: str, **kwargs: Any
):
"""Override init to support instantiation by position for backward compat."""
super().__init__(tool=tool, tool_input=tool_input, log=log, **kwargs)
def _convert_agent_action_to_messages(
agent_action: AgentAction, observation: str
) -> List[BaseMessage]:
"""Convert an agents action to a message.
This code is used to reconstruct the original AI message from the agents action.
Args:
agent_action: Agent action to convert.
Returns:
AIMessage that corresponds to the original tools invocation.
"""
if isinstance(agent_action, _FunctionsAgentAction):
return agent_action.message_log + [
_create_function_message(agent_action, observation)
]
else:
return [AIMessage(content=agent_action.log)]
def _create_function_message(
agent_action: AgentAction, observation: str
) -> FunctionMessage:
"""Convert agents action and observation into a function message.
Args:
agent_action: the tools invocation request from the agents
observation: the result of the tools invocation
Returns:
FunctionMessage that corresponds to the original tools invocation
"""
if not isinstance(observation, str):
try:
content = json.dumps(observation, ensure_ascii=False)
except Exception:
content = str(observation)
else:
content = observation
return FunctionMessage(
name=agent_action.tool,
content=content,
)
def _format_intermediate_steps(
intermediate_steps: List[Tuple[AgentAction, str]],
) -> List[BaseMessage]:
"""Format intermediate steps.
Args:
intermediate_steps: Steps the LLM has taken to date, along with observations
Returns:
list of messages to send to the LLM for the next prediction
"""
messages = []
for intermediate_step in intermediate_steps:
agent_action, observation = intermediate_step
messages.extend(_convert_agent_action_to_messages(agent_action, observation))
return messages
def _format_tool_to_openai_function(tool: BaseTool) -> FunctionDescription:
"""Format tools into the OpenAI function API."""
parameters = {
"type": "object",
"properties": {
"query": {
"type": "string",
"description": tool.param_description if hasattr(tool, 'param_description') else "",
},
},
"required": ["query"],
}
return {
"name": tool.name,
"description": tool.description,
"parameters": parameters,
}
def _parse_ai_message(message: BaseMessage) -> Union[AgentAction, AgentFinish]:
if not isinstance(message, AIMessage):
raise TypeError(f"Expected an AI message but got {type(message)}")
function_call = message.additional_kwargs.get("function_call", {})
if function_call:
function_call = message.additional_kwargs["function_call"]
function_name = function_call["name"]
try:
_tool_input = json.loads(function_call["arguments"])
except JSONDecodeError:
            raise OutputParserException(
                f"Could not parse tool input: {function_call} because "
                f"the `arguments` is not valid JSON."
            )
if "query" in _tool_input:
tool_input = _tool_input["query"]
else:
tool_input = _tool_input
return _FunctionsAgentAction(
tool=function_name,
tool_input=tool_input,
log=f"\nInvoking: `{function_name}` with `{tool_input}`\n",
message_log=[message]
)
# pattern = r'\((.*?)\)'
# matches = re.findall(pattern, message.content)
# result = []
#
# for match in matches:
# result.append(match)
#
# if result:
# output = result
# else:
# output = message.content
return AgentFinish(return_values={"output": message.content}, log=message.content)
class ConversationalFunctionsAgent(OpenAIFunctionsAgent):
@property
def functions(self) -> List[dict]:
return [dict(_format_tool_to_openai_function(t)) for t in self.tools]
def plan(self,
intermediate_steps: List[Tuple[AgentAction, str]],
callbacks: Callbacks = None,
**kwargs: Any
) -> Union[AgentAction, AgentFinish]:
"""Decide how agents should move after receiving an input. The difference between
OpenAIFunctionsAgent lies in the '_parse_ai_message' function. We add an OutputParser
into it.
Args:
intermediate_steps: Steps the LLM has taken to date, along with observations
**kwargs: User inputs.
**kwargs: Including user's input string
Returns:
Action specifying what tools to use.
"""
agent_scratchpad: List[BaseMessage] = _format_intermediate_steps(intermediate_steps)
selected_inputs = {
k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad"
}
full_inputs = dict(**selected_inputs, agent_scratchpad=agent_scratchpad)
prompt = self.prompt.format_prompt(**full_inputs)
messages: List[BaseMessage] = prompt.to_messages()
predicted_message = self.llm.predict_messages(
messages, functions=self.functions, callbacks=callbacks
)
agent_decision = _parse_ai_message(predicted_message)
return agent_decision

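Because _format_tool_to_openai_function forces every tool into a single string parameter named query, the function description sent to OpenAI for, say, the internet_search tool defined in main.py below would look roughly like this (a sketch; the empty description assumes the tool sets no param_description):

expected_function = {
    "name": "internet_search",
    "description": "Can be used to perform Internet searches",
    "parameters": {
        "type": "object",
        "properties": {
            "query": {"type": "string", "description": ""},
        },
        "required": ["query"],
    },
}
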
app/service/chat_robot/script/callbacks/__init__.py Normal file
View File

@@ -0,0 +1,6 @@
from .openai_token_record_callback import OpenAITokenRecordCallbackHandler
__all__ = [
'OpenAITokenRecordCallbackHandler'
]

app/service/chat_robot/script/callbacks/openai_token_record_callback.py Normal file
View File

@@ -0,0 +1,46 @@
"""Callback Handler that add on_chain_end function to record Token usage."""
from typing import Any, Dict
from langchain.callbacks import OpenAICallbackHandler
from langchain.schema import LLMResult
from langchain.callbacks.openai_info import standardize_model_name, MODEL_COST_PER_1K_TOKENS, get_openai_token_cost_for_model
class OpenAITokenRecordCallbackHandler(OpenAICallbackHandler):
    """Callback handler that tracks OpenAI token usage and copies it into the chain outputs after the agent finishes."""
    need_record: bool = True
    response_type: str = "string"
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
"""Collect token usage."""
if response.llm_output is None:
return None
self.successful_requests += 1
if "token_usage" not in response.llm_output:
return None
if "function_call" in response.generations[0][0].message.additional_kwargs:
if response.generations[0][0].message.additional_kwargs["function_call"]["name"] in ["sql_db_query", "sql_db_schema","tutorial_tool"]:
self.need_record = False
if response.generations[0][0].message.additional_kwargs["function_call"]["name"] == "sql_db_query":
self.response_type = "image"
token_usage = response.llm_output["token_usage"]
completion_tokens = token_usage.get("completion_tokens", 0)
prompt_tokens = token_usage.get("prompt_tokens", 0)
model_name = standardize_model_name(response.llm_output.get("model_name", ""))
if model_name in MODEL_COST_PER_1K_TOKENS:
completion_cost = get_openai_token_cost_for_model(
model_name, completion_tokens, is_completion=True
)
prompt_cost = get_openai_token_cost_for_model(model_name, prompt_tokens)
self.total_cost += prompt_cost + completion_cost
self.total_tokens += token_usage.get("total_tokens", 0)
self.prompt_tokens += prompt_tokens
self.completion_tokens += completion_tokens
def on_chain_end(self, outputs: Dict, **kwargs: Any) -> None:
"""Write token usage to redis."""
outputs["total_tokens"] = self.total_tokens
outputs["total_cost"] = self.total_cost
outputs["prompt_tokens"] = self.prompt_tokens
outputs["completion_tokens"] = self.completion_tokens
outputs["need_record"] = self.need_record
outputs["response_type"] = self.response_type

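A hedged usage sketch: the handler accumulates totals on itself (inherited behavior) and then copies them into the chain outputs via on_chain_end, so callers read token counts straight from the result dict. This assumes the agent_executor defined in main.py below is importable.

from app.service.chat_robot.script.callbacks import OpenAITokenRecordCallbackHandler
from app.service.chat_robot.script.main import agent_executor

handler = OpenAITokenRecordCallbackHandler()
outputs = agent_executor({"input": "hi", "gender": "female"}, callbacks=[handler])
print(outputs["total_tokens"], outputs["total_cost"], outputs["need_record"])
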
View File

@@ -0,0 +1,79 @@
from typing import Optional, List
import json
from sqlalchemy import text
# from langchain import SQLDatabase
from langchain.utilities import SQLDatabase
class CustomDatabase(SQLDatabase):
def get_table_info_no_throw(self, table_names: Optional[List[str]] = None) -> str:
# def get_table_info(self, table_names: Optional[List[str]] = None) -> str:
connection = self._engine.connect()
all_table_names = self.get_usable_table_names()
if table_names is not None:
missing_tables = set(table_names).difference(all_table_names)
if missing_tables:
# raise ValueError(f"table_names {missing_tables} not found in database")
return f"Table {','.join(missing_tables)} can not be found in the database"
all_table_names = table_names
meta_tables = [
tbl
for tbl in self._metadata.sorted_tables
if tbl.name in set(all_table_names)
]
tables = []
for table in meta_tables:
table_name = table.name
column_names = table.columns.keys()
table_info = f"Table: {table_name}\nColumns: \nID, \nimg_name\n"
for column_name in column_names:
if column_name not in ["ID", "img_name"]:
query = text(f"SELECT DISTINCT {column_name} FROM {table_name}")
result = connection.execute(query)
enum_values: List[str] = [row[0] for row in result.fetchall()]
column_info = f"{column_name}: {', '.join(enum_values)}\n"
table_info += column_info
# table_info = f"Table: {table_name}\n"
#
# if self._sample_rows_in_table_info:
# table_info += f"{self._get_sample_rows(table)}\n"
tables.append(table_info)
        final_str = "\n\n".join(tables)
        connection.close()  # release the connection opened at the top of this method
        return final_str
def run(self, command: str, fetch: str = "all") -> str:
"""Execute a SQL command and return a string representing the results.
If the statement returns rows, a string of the results is returned.
If the statement returns no rows, an empty string is returned.
"""
with self._engine.begin() as connection:
if self._schema is not None:
if self.dialect == "snowflake":
connection.exec_driver_sql(
f"ALTER SESSION SET search_path='{self._schema}'"
)
elif self.dialect == "bigquery":
connection.exec_driver_sql(f"SET @@dataset_id='{self._schema}'")
else:
connection.exec_driver_sql(f"SET search_path TO {self._schema}")
cursor = connection.execute(text(command))
if cursor.rowcount:
if fetch == "all":
result = cursor.fetchall()
elif fetch == "one":
result = cursor.fetchone() # type: ignore
else:
raise ValueError("Fetch parameter must be either 'one' or 'all'")
            # Convert column values to strings to avoid issues with sqlalchemy
            # truncating text
if isinstance(result, list):
return json.dumps([r[0] for r in result])
return json.dumps([result[0]])
return ""

app/service/chat_robot/script/main.py Normal file
View File

@@ -0,0 +1,114 @@
import logging
from loguru import logger
from langchain.agents import Tool
from langchain.utilities import SerpAPIWrapper
from langchain.prompts.chat import ChatPromptTemplate, HumanMessagePromptTemplate, MessagesPlaceholder
from langchain.schema import SystemMessage, AIMessage
from langchain.chat_models import ChatOpenAI
from langchain.llms.openai import OpenAI
from langchain.callbacks import FileCallbackHandler
from app.service.chat_robot.script.agents import CustomAgentExecutor, ConversationalFunctionsAgent
from app.service.chat_robot.script.callbacks import OpenAITokenRecordCallbackHandler
from app.service.chat_robot.script.database import CustomDatabase
from app.service.chat_robot.script.prompt import FASHION_CHAT_BOT_PREFIX, TOOLS_FUNCTIONS_SUFFIX
from app.service.chat_robot.script.tools import (QuerySQLDataBaseTool, InfoSQLDatabaseTool, QuerySQLCheckerTool, ListSQLDatabaseTool)
from app.service.chat_robot.script.memory import UserConversationBufferWindowMemory
from app.service.chat_robot.script.tools.tutorial_tool import CustomTutorialTool
from app.core.config import *
import os
# os.environ["http_proxy"] = "http://127.0.0.1:7890"
# os.environ["https_proxy"] = "http://127.0.0.1:7890"
# log callbacks
logfile = "logs/chat_debug.log"
logger.add(logfile, colorize=True, enqueue=True)
log_handler = FileCallbackHandler(logfile)
# Initialize our LLM (gpt-3.5-turbo)
llm = ChatOpenAI(temperature=0.1,
openai_api_key=OPENAI_API_KEY,
# callbacks=[OpenAICallbackHandler()]
)
search = SerpAPIWrapper()
db = CustomDatabase.from_uri(f'mysql+pymysql://{DB_USERNAME}:{DB_PASSWORD}@{DB_HOST}:{DB_PORT}/attribute_retrieval_V3',
include_tables=['female_top', 'female_skirt', 'female_pants', 'female_dress',
'female_outwear', 'male_bottom', 'male_top', 'male_outwear'],
engine_args={"pool_recycle": 7200})
tools = [
Tool(
name="internet_search",
description="Can be used to perform Internet searches",
func=search.run
),
QuerySQLDataBaseTool(db=db, return_direct=False),
InfoSQLDatabaseTool(db=db),
ListSQLDatabaseTool(db=db),
QuerySQLCheckerTool(db=db, llm=OpenAI(temperature=0, openai_api_key=OPENAI_API_KEY)),
Tool(
name="tutorial_tool",
description="Utilize this tool to retrieve specific statements related to user guidance tutorials."
"Input is an empty string",
func=CustomTutorialTool(),
return_direct=True
)
]
messages = [
SystemMessage(content=FASHION_CHAT_BOT_PREFIX),
MessagesPlaceholder(variable_name="history"),
HumanMessagePromptTemplate.from_template(
"{input} "
"Question from a {gender}."
),
AIMessage(content=TOOLS_FUNCTIONS_SUFFIX),
MessagesPlaceholder(variable_name="agent_scratchpad"),
]
prompt = ChatPromptTemplate(input_variables=["input", "gender", "agent_scratchpad", "history"], messages=messages)
agent = ConversationalFunctionsAgent(
llm=llm,
tools=tools,
prompt=prompt
)
memory = UserConversationBufferWindowMemory.from_redis(
return_messages=True, k=2, input_key='input', output_key='output'
)
agent_executor = CustomAgentExecutor.from_agent_and_tools(
agent=agent,
tools=tools,
verbose=True,
memory=memory,
)
def chat(post_data):
user_id = post_data.user_id
session_id = post_data.session_id
input_message = post_data.message
gender = post_data.gender
final_outputs = agent_executor(
{"input": input_message, "gender": gender},
callbacks=[OpenAITokenRecordCallbackHandler(), log_handler],
session_key=f"buffer:{user_id}:{session_id}",
)
api_response = {
'user_id': user_id,
'session_id': session_id,
# 'message_id': message_id,
# 'create_time': created_time,
'input': final_outputs['input'],
# 'conversion': messages,
'output': final_outputs['output'],
# 'gpt_response_time': gpt_response_time,
'total_tokens': final_outputs['total_tokens'],
'total_cost': final_outputs['total_cost'],
'prompt_tokens': final_outputs['prompt_tokens'],
'completion_tokens': final_outputs['completion_tokens'],
'response_type': final_outputs['response_type']
}
logging.info(api_response)
return api_response

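A minimal end-to-end sketch of the chat entry point, assuming the MySQL, Redis, and OpenAI credentials in app.core.config are reachable:

from app.schemas.chat_robot import ChatRobotModel
from app.service.chat_robot.script.main import chat

resp = chat(ChatRobotModel(
    gender="female",
    message="Show me a blouse with long sleeves",
    session_id="demo-session",
    user_id=1,
))
print(resp["output"], resp["response_type"], resp["total_tokens"])
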
app/service/chat_robot/script/memory/__init__.py Normal file
View File

@@ -0,0 +1,3 @@
from .user_buffer_window import UserConversationBufferWindowMemory
__all__ = ['UserConversationBufferWindowMemory']

app/service/chat_robot/script/memory/user_buffer_window.py Normal file
View File

@@ -0,0 +1,93 @@
import logging
from typing import Any, Dict, List, Tuple
import json
import redis
from redis import Redis
from langchain.memory import RedisChatMessageHistory
from langchain.memory.chat_memory import BaseChatMemory
from langchain.schema.messages import BaseMessage, get_buffer_string, HumanMessage, AIMessage
from langchain.schema.messages import _message_to_dict, messages_from_dict
from langchain.memory.utils import get_prompt_input_key
from app.core.config import *
class UserConversationBufferWindowMemory(BaseChatMemory):
"""Buffer for storing conversation memory."""
redis_client: Redis
human_prefix: str = "Human"
ai_prefix: str = "AI"
memory_key: str = "history" #: :meta private:
k: int = 5
@classmethod
def from_redis(
cls,
host: str = REDIS_HOST,
port: int = REDIS_PORT,
db: int = 3,
**kwargs
):
redis_client = Redis(host=host, port=port, db=db)
try:
response = redis_client.ping()
if response:
print("Connect to redis server successfully.")
logging.info("Connect to redis server successfully.")
else:
print("Fail to connect to redis server")
logging.info("Fail to connect to redis server")
except redis.RedisError as e:
logging.info(f"Error occurs when connecting to redis server: {str(e)}")
return cls(redis_client=redis_client, **kwargs)
@property
def memory_variables(self) -> List[str]:
"""Will always return list of memory variables.
:meta private:
"""
return [self.memory_key]
def load_memory_variables(self, inputs: Dict[str, Any], key: str = "") -> Dict[str, str]:
"""Return history buffer."""
        # LRANGE's end index is inclusive, so this fetches the last k exchanges (2k messages).
        _items: Any = self.redis_client.lrange(key, 0, self.k * 2 - 1) if self.k > 0 else []
items = [json.loads(m.decode("utf-8")) for m in _items[::-1]]
buffer = messages_from_dict(items)
if not self.return_messages:
buffer = get_buffer_string(
buffer,
human_prefix=self.human_prefix,
ai_prefix=self.ai_prefix,
)
return {self.memory_key: buffer}
def _get_input_output(
self, inputs: Dict[str, Any], outputs: Dict[str, str]
) -> Tuple[str, str]:
if self.input_key is None:
prompt_input_key = get_prompt_input_key(inputs, self.memory_variables)
else:
prompt_input_key = self.input_key
if self.output_key is None:
if len(outputs) != 1:
raise ValueError(f"One output key expected, got {outputs.keys()}")
output_key = list(outputs.keys())[0]
else:
output_key = self.output_key
return inputs[prompt_input_key], outputs[output_key]
def add_message(self, key: str, message: BaseMessage) -> None:
self.redis_client.lpush(key, json.dumps(_message_to_dict(message)))
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str], key: str = "") -> None:
"""Save context from this conversation to buffer."""
input_str, output_str = self._get_input_output(inputs, outputs)
self.add_message(key, HumanMessage(content=input_str))
self.add_message(key, AIMessage(content=output_str))
# def clear(self, key) -> None:
# """Clear memory contents."""
# self.redis_client.delete(key)

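A minimal round-trip sketch for the memory class, assuming a reachable Redis; the key follows the buffer:{user_id}:{session_id} convention used by main.py:

from app.service.chat_robot.script.memory import UserConversationBufferWindowMemory

memory = UserConversationBufferWindowMemory.from_redis(
    return_messages=True, k=2, input_key="input", output_key="output"
)
key = "buffer:1:demo-session"
memory.save_context({"input": "hi"}, {"output": "hello"}, key)
print(memory.load_memory_variables({}, key))  # {'history': [HumanMessage(...), AIMessage(...)]}
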
View File

@@ -0,0 +1,52 @@
FASHION_CHAT_BOT_PREFIX = """
You are a helpful assistant for fashion designers. You can chat with the users or answer their query as much as you can.
The most crucial aspect is to accurately determine whether the user's inquiry requires an internet search or a database query.
Remember, your answer should be very precise and the final answer should not exceed 20 words.
You may encounter the following types of questions:
1) If the query is related to clothing retrieval, you are an agent designed to interact with a SQL database.
Given an input question, create a syntactically correct MySQL query to run, always fetching random data from tables.
Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most 4 results.
Never query for all the columns from a specific table, only ask for the relevant columns given the question.
You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again.
DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.
If the question does not seem related to the database, just return "I don't know" as the answer.
2) If the query is related to current events, you should use internet_search to seek help from the internet.
3) If the query is just casual conversation, engage in the conversation as a fashion designer assistant.
Be careful when using the tools, since you are actually a chat bot. Tools should only be used when essential.
"""
TOOL_SELECT_SUFFIX = """
Prior to proceeding, it is essential to carefully assess the question and select the appropriate tools or approach accordingly.
For database-related questions, use SQL tools to identify relevant tables and query their schemas.
The use of online resources should be limited to inquiries pertaining to current subjects.
"""
SQL_FUNCTIONS_SUFFIX = """
For database-related questions, use SQL tools to identify relevant tables and query their schemas.
"""
INTERNET_SEARCH_SUFFIX = """
If the question should be answered using internet search tools, I should seek help from the internet.
"""
ANSWER_FORMAT_SUFFIX = """
My final answer is limited to 20 words and is as precise as possible.
"""
TOOLS_FUNCTIONS_SUFFIX = (
"If the input involves clothing queries,"
"I should look at the tables in the database to see what I can query. Then I should query the schema of the most relevant tables."
"All SQL statements must use 'ORDER BY RAND()', for example:"
"Example Input 1: 'SELECT img_name FROM skirt WHERE opening_type = 'Button' ORDER BY RAND() LIMIT 1'"
"Example Input 2: 'SELECT img_name FROM top WHERE sleeve_length = 'Long' AND type = 'Blouse' ORDER BY RAND() LIMIT 2'"
"If the input does not involve clothing queries, "
"I should engage in conversation as an assistant or search from internet with internet_search tool."
"If the database query returns no results, please respond directly with: 'Apologies, I couldn't find any images that match your description. Could you please give me more details about the clothing you're searching for?'"
"Upon mentioning words related to 'tutorial' in the input, I should use tutorial_tool "
)
TUTORIAL_TOOL_RETURN = "Commencing the systematic tutorial guide now."

app/service/chat_robot/script/tools/__init__.py Normal file
View File

@@ -0,0 +1,10 @@
from .sql_tools import (
QuerySQLDataBaseTool,
InfoSQLDatabaseTool,
ListSQLDatabaseTool,
QuerySQLCheckerTool
)
__all__ = [
"QuerySQLCheckerTool", "InfoSQLDatabaseTool", "ListSQLDatabaseTool", "QuerySQLDataBaseTool"
]

app/service/chat_robot/script/tools/sql_tools.py Normal file
View File

@@ -0,0 +1,183 @@
# flake8: noqa
"""Tools for interacting with a SQL database."""
from typing import Any, Dict, Optional
from pydantic import BaseModel, Extra, Field, root_validator
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.manager import (
AsyncCallbackManagerForToolRun,
CallbackManagerForToolRun,
)
from langchain.chains.llm import LLMChain
from langchain.prompts import PromptTemplate
# from langchain.sql_database import SQLDatabase
from langchain.utilities import SQLDatabase
from langchain.tools.base import BaseTool
from langchain.tools.sql_database.prompt import QUERY_CHECKER
class BaseSQLDatabaseTool(BaseModel):
"""Base tools for interacting with a SQL database."""
db: SQLDatabase = Field(exclude=True)
param_description: str = ""
# Override BaseTool.Config to appease mypy
# See https://github.com/pydantic/pydantic/issues/4173
class Config(BaseTool.Config):
"""Configuration for this pydantic object."""
arbitrary_types_allowed = True
extra = Extra.forbid
class QuerySQLDataBaseTool(BaseSQLDatabaseTool, BaseTool):
"""Tool for querying a SQL database."""
name = "sql_db_query"
# description = """
# Before use this tool, another tool named sql_db_schema must be used first to find the schema of interested tables.
# This tool is designed exclusively for generating SELECT queries to retrieve clothing's img_name randomly from a MySQL database.
# You should always use order by rand() to randomly select data.
# If the query is not correct, an error message will be returned.
# If an error is returned, rewrite the query, check the query, and try again.
# Always limit your query to at most 4 results.
# Never query for all the columns from a specific table, only ask for the relevant columns given the question.
# You MUST double check your query before executing it.
# DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.
# """
    description: str = (
        "The input of this tool is a detailed and correct SQL SELECT query statement, "
        "and the output is the result from the database; it can return at most 4 results. "
        "If the query is not correct, an error message will be returned. "
        "If an error is returned, rewrite the query, check the query, and try again. "
        "If you encounter an issue such as Unknown column 'xxxx' in 'field list' or Table 'attribute_retrieval.xxxx' doesn't exist, "
        "use sql_db_schema to query the correct table fields. "
        "Example Input: 'SELECT img_name FROM skirt WHERE opening_type = 'Button' ORDER BY RAND() LIMIT 1' "
        "Example Input 2: 'SELECT img_name FROM top WHERE sleeve_length = 'Long' AND type = 'Blouse' ORDER BY RAND() LIMIT 2'"
    )
def _run(
self,
query: str,
run_manager: Optional[CallbackManagerForToolRun] = None,
) -> str:
"""Execute the query, return the results or an error message."""
result = self.db.run_no_throw(query)
return result
async def _arun(
self,
query: str,
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
) -> str:
raise NotImplementedError("QuerySqlDbTool does not support async")
class InfoSQLDatabaseTool(BaseSQLDatabaseTool, BaseTool):
"""Tool for getting metadata about a SQL database."""
name = "sql_db_schema"
# description = """
# The database contains information of lots of fashion items, such as item name, their fashion attributes.
# There are five tables covering five fashion categories: top, pants, dress, skirt, and outwear.
# Find the most relevant tables with the query, and output the schema of these tables.
# """
    description: str = (
        "Input to this tool is a comma-separated list of tables; output is the schema and sample rows for those tables. "
        "There are eight tables covering eight fashion categories: female_top, female_pants, female_dress, "
        "female_skirt, female_outwear, male_bottom, male_top, and male_outwear. "
        "Example Input: 'female_outwear, male_top'"
    )
def _run(
self,
table_names: str,
run_manager: Optional[CallbackManagerForToolRun] = None,
) -> str:
"""Get the schema for tables in a comma-separated list."""
return self.db.get_table_info_no_throw(table_names.split(", "))
async def _arun(
self,
table_name: str,
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
) -> str:
raise NotImplementedError("SchemaSqlDbTool does not support async")
class ListSQLDatabaseTool(BaseSQLDatabaseTool, BaseTool):
"""Tool for getting tables names."""
name = "sql_db_list_tables"
description = "Input is an empty string, output is a comma separated list of tables in the database."
def _run(
self,
tool_input: str = "",
run_manager: Optional[CallbackManagerForToolRun] = None,
) -> str:
"""Get the schema for a specific table."""
return ", ".join(self.db.get_usable_table_names())
async def _arun(
self,
tool_input: str = "",
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
) -> str:
raise NotImplementedError("ListTablesSqlDbTool does not support async")
class QuerySQLCheckerTool(BaseSQLDatabaseTool, BaseTool):
"""Use an LLM to check if a query is correct.
Adapted from https://www.patterns.app/blog/2023/01/18/crunchbot-sql-analyst-gpt/"""
template: str = QUERY_CHECKER
llm: BaseLanguageModel
llm_chain: LLMChain = Field(init=False)
name = "sql_db_query_checker"
    description = (
        "Use this tool to double-check whether your query is correct before executing it. "
        "Always use this tool before executing a query with sql_db_query!"
    )
@root_validator(pre=True)
def initialize_llm_chain(cls, values: Dict[str, Any]) -> Dict[str, Any]:
if "llm_chain" not in values:
values["llm_chain"] = LLMChain(
llm=values.get("llm"),
prompt=PromptTemplate(
template=QUERY_CHECKER,
input_variables=["query", "dialect"]
),
)
if values["llm_chain"].prompt.input_variables != ["dialect", "query"]:
# if values["llm_chain"].prompt.input_variables != ["query", "dialect"]:
raise ValueError(
"LLM chain for QueryCheckerTool must have input variables ['query', 'dialect']"
)
return values
def _run(
self,
query: str,
run_manager: Optional[CallbackManagerForToolRun] = None,
) -> str:
"""Use the LLM to check the query."""
return self.llm_chain.predict(query=query, dialect=self.db.dialect)
async def _arun(
self,
query: str,
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
) -> str:
return await self.llm_chain.apredict(query=query, dialect=self.db.dialect)

app/service/chat_robot/script/tools/tutorial_tool.py Normal file
View File

@@ -0,0 +1,19 @@
from typing import Any
from langchain.tools.base import BaseTool
from app.service.chat_robot.script.prompt import TUTORIAL_TOOL_RETURN
# Handles input related to the system's guided tutorial
class CustomTutorialTool(BaseTool):
name = "tutorial_tool"
description = ("Utilize this tool to retrieve specific statements related to user guidance tutorials."
"Input is an empty string")
def _run(self, tool_input, **kwargs: Any) -> str:
return TUTORIAL_TOOL_RETURN
async def _arun(self, tool_input, **kwargs: Any) -> str:
raise NotImplementedError("CustomTutorialTool does not support async")

View File

@@ -0,0 +1 @@
from .logger import Logger

View File

@@ -0,0 +1,26 @@
import logging
from logging import handlers
class Logger(object):
level_relations = {
'debug': logging.DEBUG,
'info': logging.INFO,
'warning': logging.WARNING,
'error': logging.ERROR,
'crit': logging.CRITICAL
}
def __init__(self, filename, level='info', when='D', backCount=3,
fmt='%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s'):
self.logger = logging.getLogger(filename)
format_str = logging.Formatter(fmt) # set log format
self.logger.setLevel(self.level_relations.get(level)) # set log level
sh = logging.StreamHandler() # output to terminal
sh.setFormatter(format_str) # set format for terminal log
th = handlers.TimedRotatingFileHandler(filename=filename, when=when, backupCount=backCount,
encoding='utf-8') # log into file
th.setFormatter(format_str) # set format for file log
self.logger.addHandler(sh) # output to terminal
self.logger.addHandler(th) # output to file

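Usage sketch for the Logger wrapper; the log path and messages are arbitrary:

log = Logger("logs/app.log", level="debug")
log.logger.debug("debug detail")
log.logger.info("service started")
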
app/service/prompt_generation/chatgpt_for_translation.py Normal file
View File

@@ -0,0 +1,70 @@
import os
from langchain.chains import LLMChain, SimpleSequentialChain
from langchain.chat_models import ChatOpenAI
from langchain_core.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate, \
PromptTemplate
from app.core.config import OPENAI_MODEL, OPENAI_API_KEY
# os.environ["http_proxy"] = "http://127.0.0.1:7890"
# os.environ["https_proxy"] = "http://127.0.0.1:7890"
llm = ChatOpenAI(model_name=OPENAI_MODEL,
openai_api_key=OPENAI_API_KEY,
temperature=0)
def translate_to_en(text):
template = (
"""You are a translation expert, proficient in various languages.
And can translate various languages into English.
Please translate to grammatically correct English regardless of the input language.
If the input is in English, check for grammatical errors. If there are no errors, simply output the sentence.
If there are grammatical errors, correct them and then output the sentence."""
)
system_message_prompt = SystemMessagePromptTemplate.from_template(template)
    # The text to translate is supplied via the Human role
human_template = "User input : {text}"
    human_message_prompt = HumanMessagePromptTemplate.from_template(human_template)  # input variables are inferred from the template
    # Build the ChatPromptTemplate from the System- and Human-role prompt templates
chat_prompt_template = ChatPromptTemplate.from_messages(
[system_message_prompt, human_message_prompt]
)
translate_chain = LLMChain(llm=llm, prompt=chat_prompt_template)
template = (
"""
Input sentence:
{translate}
1. Based on the input, adjust the sentence to make it more suitable as a prompt for generating images,
ensuring all key nouns and adjectives related to the image are retained.
2. Simplify complex sentence structures and clarify ambiguous expressions.
3. Only output the adjusted English sentence.
Output :
"""
)
# "Based on the input sentence, extract key adjectives and nouns.Only Output extracted key words."
# 1. Check if the input sentence contains any grammatical errors. If there are errors, please correct them before proceeding.
prompt_template = PromptTemplate(input_variables=["translate"], template=template)
prompt_chain = LLMChain(llm=llm, prompt=prompt_template)
overall_chain = SimpleSequentialChain(chains=[translate_chain, prompt_chain], verbose=True)
response = overall_chain.run(text)
return response
def main():
"""Main function"""
    translate_to_en("生成一件运动风格的夹克,带有拉链和口袋,适合休闲穿着")  # "Generate a sport-style jacket with a zipper and pockets, suitable for casual wear"
if __name__ == "__main__":
main()

Binary file not shown.