diff --git a/app/api/api_chat_robot.py b/app/api/api_chat_robot.py new file mode 100644 index 0000000..c394046 --- /dev/null +++ b/app/api/api_chat_robot.py @@ -0,0 +1,27 @@ +import logging +import time +from fastapi import APIRouter + +from app.schemas.chat_robot import ChatRobotModel +from app.service.chat_robot.script.main import chat + +router = APIRouter() +logger = logging.getLogger() + + +@router.post("/chat_robot") +def chat_robot(request_data: ChatRobotModel): + try: + logger.info(f"chat_robot request item is : @@@@@@:{request_data}") + code = 200 + message = "access" + start_time = time.time() + data = chat(post_data=request_data) + logger.info(f"chat_robot Run time is @@@@@@:{time.time() - start_time}") + except Exception as e: + code = 400 + message = str(e) + data = str(e) + logger.warning(f"chat_robot Run Exception @@@@@@:{e}") + logger.info({"code": code, "message": message, "data": data}) + return {"code": code, "message": message, "data": data} diff --git a/app/api/api_prompt_generation.py b/app/api/api_prompt_generation.py new file mode 100644 index 0000000..5e71eec --- /dev/null +++ b/app/api/api_prompt_generation.py @@ -0,0 +1,28 @@ +import logging +import time + +from fastapi import APIRouter + +from app.schemas.prompt_generation import PromptGenerationImageModel +from app.service.prompt_generation.chatgpt_for_translation import translate_to_en + +router = APIRouter() +logger = logging.getLogger() + + +@router.post("/translateToEN") +def prompt_generation(request_data: PromptGenerationImageModel): + try: + logger.info(f"prompt_translate to English request data : @@@@@@:{request_data}") + code = 200 + message = "access" + start_time = time.time() + data = translate_to_en(request_data.text) + logger.info(f"prompt_generation Run time is @@@@@@:{time.time() - start_time}") + except Exception as e: + code = 400 + message = str(e) + data = str(e) + logger.warning(f"prompt_generation Run Exception @@@@@@:{e}") + logger.info({"code": code, "message": message, "data": data}) + return {"code": code, "message": message, "data": data} diff --git a/app/api/api_route.py b/app/api/api_route.py index ff21b34..c1add93 100644 --- a/app/api/api_route.py +++ b/app/api/api_route.py @@ -5,6 +5,9 @@ from app.api import api_super_resolution from app.api import api_generate_image from app.api import api_attribute_retrieve from app.api import api_design +from app.api import api_chat_robot +from app.api import api_prompt_generation + router = APIRouter() @@ -13,3 +16,5 @@ router.include_router(api_super_resolution.router, tags=["super_resolution"], pr router.include_router(api_generate_image.router, tags=["generate_image"], prefix="/api") router.include_router(api_attribute_retrieve.router, tags=["attribute_retrieve"], prefix="/api") router.include_router(api_design.router, tags=['design'], prefix="/api") +router.include_router(api_chat_robot.router, tags=['chat_robot'], prefix="/api") +router.include_router(api_prompt_generation.router, tags=['prompt_generation'], prefix="/api") diff --git a/app/core/config.py b/app/core/config.py index 6e22adc..5744dec 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -19,7 +19,7 @@ class Settings(BaseSettings): LOGGING_CONFIG_FILE = os.path.join(BASE_DIR, 'logging_env.py') -DEBUG = False +DEBUG = True if DEBUG: LOGS_PATH = "logs/" CATEGORY_PATH = "service/attribute/config/descriptor/category/category_dis.csv" @@ -61,6 +61,32 @@ MILVUS_PORT = "19530" MILVUS_TABLE_KEYPOINT = "keypoint_cache" MILVUS_TABLE_SEG = "seg_cache" +# Mysql 配置 +DB_HOST = 
'18.167.251.121' # database host address +# DB_PORT = int( 33006) +DB_PORT = 33008 # database port +DB_USERNAME = 'aida_con_python' # database username +DB_PASSWORD = '123456' # database password +DB_NAME = 'aida' # database name + +# openai +os.environ['SERPAPI_API_KEY'] = "a793513017b0718db7966207c31703d280d12435c982f1e67bbcbffa52e7632c" +OPENAI_STREAM = True +BUFFER_THRESHOLD = 6 # must be even number +SINGLE_TOKEN_THRESHOLD = 200 +TOKEN_THRESHOLD = 600 +OPENAI_TEMPERATURE = 0 + +# OPENAI_API_KEY = "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" +OPENAI_API_KEY = "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" +OPENAI_MODEL = "gpt-3.5-turbo-0613" +OPENAI_MODEL_LIST = {"gpt-3.5-turbo-0613", + "gpt-3.5-turbo-16k-0613", + "gpt-4-0314", + "gpt-4-32k-0314", + "gpt-4-0613", + "gpt-4-32k-0613", } + # attribute service config ATT_TRITON_URL = "10.1.1.240:10000" diff --git a/app/schemas/chat_robot.py b/app/schemas/chat_robot.py new file mode 100644 index 0000000..cebf74a --- /dev/null +++ b/app/schemas/chat_robot.py @@ -0,0 +1,8 @@ +from pydantic import BaseModel + + +class ChatRobotModel(BaseModel): + gender: str + message: str + session_id: str + user_id: int diff --git a/app/schemas/prompt_generation.py b/app/schemas/prompt_generation.py new file mode 100644 index 0000000..195291b --- /dev/null +++ b/app/schemas/prompt_generation.py @@ -0,0 +1,5 @@ +from pydantic import BaseModel + + +class PromptGenerationImageModel(BaseModel): + text: str diff --git a/app/service/chat_robot/script/agents/__init__.py b/app/service/chat_robot/script/agents/__init__.py new file mode 100644 index 0000000..30c40f9 --- /dev/null +++ b/app/service/chat_robot/script/agents/__init__.py @@ -0,0 +1,7 @@ +from .agent_executor import CustomAgentExecutor +from .conversational_functions_agent import ConversationalFunctionsAgent + +__all__ = [ + "CustomAgentExecutor", + "ConversationalFunctionsAgent" +] diff --git a/app/service/chat_robot/script/agents/agent_executor.py b/app/service/chat_robot/script/agents/agent_executor.py new file mode 100644 index 0000000..cc69936 --- /dev/null +++ b/app/service/chat_robot/script/agents/agent_executor.py @@ -0,0 +1,132 @@ +import inspect +import json +import logging +from typing import Any, Dict, List, Optional, Union, Tuple + +from langchain.agents import AgentExecutor +from langchain.callbacks.manager import Callbacks, CallbackManager +from langchain.load.dump import dumpd +from langchain.schema import RUN_KEY, RunInfo +from langchain_core.agents import AgentAction, AgentFinish + + +class CustomAgentExecutor(AgentExecutor): + def __call__( + self, + inputs: Union[Dict[str, Any], Any], + return_only_outputs: bool = False, + callbacks: Callbacks = None, + session_key: str = "", + *, + tags: Optional[List[str]] = None, + include_run_info: bool = False, + ) -> Dict[str, Any]: + """Run the logic of this chain and add to output if desired. + + Args: + inputs: Dictionary of inputs, or single input if chain expects + only one param. + return_only_outputs: boolean for whether to return only outputs in the + response. If True, only new keys generated by this chain will be + returned. If False, both input keys and new keys generated by this + chain will be returned. Defaults to False. + callbacks: Callbacks to use for this chain run. If not provided, will + use the callbacks provided to the chain. + include_run_info: Whether to include run info in the response. Defaults + to False. 
+ """ + inputs = self.prep_inputs(inputs, session_key) + callback_manager = CallbackManager.configure( + callbacks, self.callbacks, self.verbose, tags, self.tags + ) + new_arg_supported = inspect.signature(self._call).parameters.get("run_manager") + run_manager = callback_manager.on_chain_start( + dumpd(self), + inputs, + ) + try: + outputs = ( + self._call(inputs, run_manager=run_manager) + if new_arg_supported + else self._call(inputs) + ) + except (KeyboardInterrupt, Exception) as e: + logging.exception(e) + run_manager.on_chain_error(e) + raise e + run_manager.on_chain_end(outputs) + final_outputs: Dict[str, Any] = self.prep_outputs( + inputs, outputs, return_only_outputs, session_key + ) + if include_run_info: + final_outputs[RUN_KEY] = RunInfo(run_id=run_manager.run_id) + return final_outputs + + def prep_outputs( + self, + inputs: Dict[str, str], + outputs: Dict[str, str], + return_only_outputs: bool = False, + session_key: str = "" + ) -> Dict[str, str]: + """Validate and prep outputs.""" + self._validate_outputs(outputs) + if self.memory is not None and outputs['need_record']: + self.memory.save_context(inputs, outputs, session_key) + if return_only_outputs: + return outputs + else: + return {**inputs, **outputs} + + def prep_inputs(self, inputs: Union[Dict[str, Any], Any], session_key: str = "") -> Dict[str, str]: + """Validate and prep inputs.""" + if not isinstance(inputs, dict): + _input_keys = set(self.input_keys) + if self.memory is not None: + # If there are multiple input keys, but some get set by memory so that + # only one is not set, we can still figure out which key it is. + _input_keys = _input_keys.difference(self.memory.memory_variables) + if len(_input_keys) != 1: + raise ValueError( + f"A single string input was passed in, but this chain expects " + f"multiple inputs ({_input_keys}). When a chain expects " + f"multiple inputs, please call it by passing in a dictionary, " + "eg `chain({'foo': 1, 'bar': 2})`" + ) + inputs = {list(_input_keys)[0]: inputs} + if self.memory is not None: + external_context = self.memory.load_memory_variables(inputs, session_key) + inputs = dict(inputs, **external_context) + self._validate_inputs(inputs) + return inputs + + def _get_tool_return( + self, next_step_output: Tuple[AgentAction, str] + ) -> Optional[AgentFinish]: + """Check if the tool is a returning tool.""" + agent_action, observation = next_step_output + name_to_tool_map = {tool.name: tool for tool in self.tools} + return_value_key = "output" + + if len(self.agent.return_values) > 0: + return_value_key = self.agent.return_values[0] + + try: + observation_list = json.loads(observation) + if agent_action.tool == "sql_db_query" and isinstance(observation_list, + list) and observation_list.__len__() != 0: + return AgentFinish( + {return_value_key: observation}, + "", + ) + except: + pass + + # Invalid tools won't be in the map, so we return False. 
+ if agent_action.tool in name_to_tool_map: + if name_to_tool_map[agent_action.tool].return_direct: + return AgentFinish( + {return_value_key: observation}, + "", + ) + return None diff --git a/app/service/chat_robot/script/agents/conversational_functions_agent.py b/app/service/chat_robot/script/agents/conversational_functions_agent.py new file mode 100644 index 0000000..eb362a7 --- /dev/null +++ b/app/service/chat_robot/script/agents/conversational_functions_agent.py @@ -0,0 +1,198 @@ +import json +import re +from json import JSONDecodeError +from typing import List, Tuple, Any, Union +from dataclasses import dataclass + +from langchain.callbacks.manager import Callbacks +from langchain.agents import ( + OpenAIFunctionsAgent, +) +from langchain.schema import ( + AgentAction, + AgentFinish, + BaseMessage, + OutputParserException +) +from langchain.schema.messages import ( + AIMessage, + FunctionMessage +) +from langchain.tools import BaseTool, StructuredTool +# from langchain.tools.convert_to_openai import FunctionDescription +from langchain.utils.openai_functions import FunctionDescription + + +@dataclass +class _FunctionsAgentAction(AgentAction): + """Add message_log to AgentAction class for the _FunctionAgentAction + """ + message_log: List[BaseMessage] + + def __init__( + self, tool: str, tool_input: Union[str, dict], log: str, **kwargs: Any + ): + """Override init to support instantiation by position for backward compat.""" + super().__init__(tool=tool, tool_input=tool_input, log=log, **kwargs) + + +def _convert_agent_action_to_messages( + agent_action: AgentAction, observation: str +) -> List[BaseMessage]: + """Convert an agents action to a message. + + This code is used to reconstruct the original AI message from the agents action. + + Args: + agent_action: Agent action to convert. + + Returns: + AIMessage that corresponds to the original tools invocation. + """ + if isinstance(agent_action, _FunctionsAgentAction): + return agent_action.message_log + [ + _create_function_message(agent_action, observation) + ] + else: + return [AIMessage(content=agent_action.log)] + + +def _create_function_message( + agent_action: AgentAction, observation: str +) -> FunctionMessage: + """Convert agents action and observation into a function message. + Args: + agent_action: the tools invocation request from the agents + observation: the result of the tools invocation + Returns: + FunctionMessage that corresponds to the original tools invocation + """ + if not isinstance(observation, str): + try: + content = json.dumps(observation, ensure_ascii=False) + except Exception: + content = str(observation) + else: + content = observation + return FunctionMessage( + name=agent_action.tool, + content=content, + ) + + +def _format_intermediate_steps( + intermediate_steps: List[Tuple[AgentAction, str]], +) -> List[BaseMessage]: + """Format intermediate steps. 
+ Args: + intermediate_steps: Steps the LLM has taken to date, along with observations + Returns: + list of messages to send to the LLM for the next prediction + """ + messages = [] + + for intermediate_step in intermediate_steps: + agent_action, observation = intermediate_step + messages.extend(_convert_agent_action_to_messages(agent_action, observation)) + + return messages + + +def _format_tool_to_openai_function(tool: BaseTool) -> FunctionDescription: + """Format tools into the OpenAI function API.""" + parameters = { + "type": "object", + "properties": { + "query": { + "type": "string", + "description": tool.param_description if hasattr(tool, 'param_description') else "", + }, + }, + "required": ["query"], + } + + return { + "name": tool.name, + "description": tool.description, + "parameters": parameters, + } + + +def _parse_ai_message(message: BaseMessage) -> Union[AgentAction, AgentFinish]: + if not isinstance(message, AIMessage): + raise TypeError(f"Expected an AI message but got {type(message)}") + + function_call = message.additional_kwargs.get("function_call", {}) + + if function_call: + function_call = message.additional_kwargs["function_call"] + function_name = function_call["name"] + try: + _tool_input = json.loads(function_call["arguments"]) + except JSONDecodeError: + raise OutputParserException( + f"Could not parse tools input: {function_call} because" + f"the `arguments` is not valid JSON." + ) + + if "query" in _tool_input: + tool_input = _tool_input["query"] + else: + tool_input = _tool_input + + return _FunctionsAgentAction( + tool=function_name, + tool_input=tool_input, + log=f"\nInvoking: `{function_name}` with `{tool_input}`\n", + message_log=[message] + ) + + # pattern = r'\((.*?)\)' + # matches = re.findall(pattern, message.content) + # result = [] + # + # for match in matches: + # result.append(match) + # + # if result: + # output = result + # else: + # output = message.content + + return AgentFinish(return_values={"output": message.content}, log=message.content) + + +class ConversationalFunctionsAgent(OpenAIFunctionsAgent): + @property + def functions(self) -> List[dict]: + return [dict(_format_tool_to_openai_function(t)) for t in self.tools] + + def plan(self, + intermediate_steps: List[Tuple[AgentAction, str]], + callbacks: Callbacks = None, + **kwargs: Any + ) -> Union[AgentAction, AgentFinish]: + """Decide how agents should move after receiving an input. The difference between + OpenAIFunctionsAgent lies in the '_parse_ai_message' function. We add an OutputParser + into it. + + Args: + intermediate_steps: Steps the LLM has taken to date, along with observations + **kwargs: User inputs. + **kwargs: Including user's input string + + Returns: + Action specifying what tools to use. 
+ """ + agent_scratchpad: List[BaseMessage] = _format_intermediate_steps(intermediate_steps) + selected_inputs = { + k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad" + } + full_inputs = dict(**selected_inputs, agent_scratchpad=agent_scratchpad) + prompt = self.prompt.format_prompt(**full_inputs) + messages: List[BaseMessage] = prompt.to_messages() + predicted_message = self.llm.predict_messages( + messages, functions=self.functions, callbacks=callbacks + ) + agent_decision = _parse_ai_message(predicted_message) + return agent_decision diff --git a/app/service/chat_robot/script/callbacks/__init__.py b/app/service/chat_robot/script/callbacks/__init__.py new file mode 100644 index 0000000..8f644bd --- /dev/null +++ b/app/service/chat_robot/script/callbacks/__init__.py @@ -0,0 +1,6 @@ +from .openai_token_record_callback import OpenAITokenRecordCallbackHandler + + +__all__ = [ + 'OpenAITokenRecordCallbackHandler' +] diff --git a/app/service/chat_robot/script/callbacks/openai_token_record_callback.py b/app/service/chat_robot/script/callbacks/openai_token_record_callback.py new file mode 100644 index 0000000..64ed7f4 --- /dev/null +++ b/app/service/chat_robot/script/callbacks/openai_token_record_callback.py @@ -0,0 +1,46 @@ +"""Callback Handler that add on_chain_end function to record Token usage.""" +from typing import Any, Dict + +from langchain.callbacks import OpenAICallbackHandler +from langchain.schema import LLMResult +from langchain.callbacks.openai_info import standardize_model_name, MODEL_COST_PER_1K_TOKENS, get_openai_token_cost_for_model + + +class OpenAITokenRecordCallbackHandler(OpenAICallbackHandler): + need_record: bool = True + response_type: str = "string" + """Callback Handler that tracks OpenAI info and write to redis after agent finish""" + def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: + """Collect token usage.""" + if response.llm_output is None: + return None + self.successful_requests += 1 + if "token_usage" not in response.llm_output: + return None + if "function_call" in response.generations[0][0].message.additional_kwargs: + if response.generations[0][0].message.additional_kwargs["function_call"]["name"] in ["sql_db_query", "sql_db_schema","tutorial_tool"]: + self.need_record = False + if response.generations[0][0].message.additional_kwargs["function_call"]["name"] == "sql_db_query": + self.response_type = "image" + token_usage = response.llm_output["token_usage"] + completion_tokens = token_usage.get("completion_tokens", 0) + prompt_tokens = token_usage.get("prompt_tokens", 0) + model_name = standardize_model_name(response.llm_output.get("model_name", "")) + if model_name in MODEL_COST_PER_1K_TOKENS: + completion_cost = get_openai_token_cost_for_model( + model_name, completion_tokens, is_completion=True + ) + prompt_cost = get_openai_token_cost_for_model(model_name, prompt_tokens) + self.total_cost += prompt_cost + completion_cost + self.total_tokens += token_usage.get("total_tokens", 0) + self.prompt_tokens += prompt_tokens + self.completion_tokens += completion_tokens + + def on_chain_end(self, outputs: Dict, **kwargs: Any) -> None: + """Write token usage to redis.""" + outputs["total_tokens"] = self.total_tokens + outputs["total_cost"] = self.total_cost + outputs["prompt_tokens"] = self.prompt_tokens + outputs["completion_tokens"] = self.completion_tokens + outputs["need_record"] = self.need_record + outputs["response_type"] = self.response_type diff --git a/app/service/chat_robot/script/database.py 
b/app/service/chat_robot/script/database.py new file mode 100644 index 0000000..8a5dfdb --- /dev/null +++ b/app/service/chat_robot/script/database.py @@ -0,0 +1,79 @@ +from typing import Optional, List +import json + +from sqlalchemy import text +# from langchain import SQLDatabase +from langchain.utilities import SQLDatabase + + +class CustomDatabase(SQLDatabase): + def get_table_info_no_throw(self, table_names: Optional[List[str]] = None) -> str: + # def get_table_info(self, table_names: Optional[List[str]] = None) -> str: + connection = self._engine.connect() + all_table_names = self.get_usable_table_names() + if table_names is not None: + missing_tables = set(table_names).difference(all_table_names) + if missing_tables: + # raise ValueError(f"table_names {missing_tables} not found in database") + return f"Table {','.join(missing_tables)} can not be found in the database" + all_table_names = table_names + meta_tables = [ + tbl + for tbl in self._metadata.sorted_tables + if tbl.name in set(all_table_names) + ] + + tables = [] + for table in meta_tables: + table_name = table.name + column_names = table.columns.keys() + table_info = f"Table: {table_name}\nColumns: \nID, \nimg_name\n" + for column_name in column_names: + if column_name not in ["ID", "img_name"]: + query = text(f"SELECT DISTINCT {column_name} FROM {table_name}") + result = connection.execute(query) + enum_values: List[str] = [row[0] for row in result.fetchall()] + column_info = f"{column_name}: {', '.join(enum_values)}\n" + table_info += column_info + + # table_info = f"Table: {table_name}\n" + # + # if self._sample_rows_in_table_info: + # table_info += f"{self._get_sample_rows(table)}\n" + tables.append(table_info) + final_str = "\n\n".join(tables) + return final_str + + def run(self, command: str, fetch: str = "all") -> str: + """Execute a SQL command and return a string representing the results. + + If the statement returns rows, a string of the results is returned. + If the statement returns no rows, an empty string is returned. 
+ + """ + with self._engine.begin() as connection: + if self._schema is not None: + if self.dialect == "snowflake": + connection.exec_driver_sql( + f"ALTER SESSION SET search_path='{self._schema}'" + ) + elif self.dialect == "bigquery": + connection.exec_driver_sql(f"SET @@dataset_id='{self._schema}'") + else: + connection.exec_driver_sql(f"SET search_path TO {self._schema}") + cursor = connection.execute(text(command)) + if cursor.rowcount: + if fetch == "all": + result = cursor.fetchall() + elif fetch == "one": + result = cursor.fetchone() # type: ignore + else: + raise ValueError("Fetch parameter must be either 'one' or 'all'") + + # Convert columns values to string to avoid issues with sqlalchmey + # trunacating text + if isinstance(result, list): + return json.dumps([r[0] for r in result]) + + return json.dumps([result[0]]) + return "" diff --git a/app/service/chat_robot/script/main.py b/app/service/chat_robot/script/main.py new file mode 100644 index 0000000..2a62664 --- /dev/null +++ b/app/service/chat_robot/script/main.py @@ -0,0 +1,114 @@ +import logging +from loguru import logger +from langchain.agents import Tool +from langchain.utilities import SerpAPIWrapper +from langchain.prompts.chat import ChatPromptTemplate, HumanMessagePromptTemplate, MessagesPlaceholder +from langchain.schema import SystemMessage, AIMessage +from langchain.chat_models import ChatOpenAI +from langchain.llms.openai import OpenAI +from langchain.callbacks import FileCallbackHandler +from app.service.chat_robot.script.agents import CustomAgentExecutor, ConversationalFunctionsAgent +from app.service.chat_robot.script.callbacks import OpenAITokenRecordCallbackHandler +from app.service.chat_robot.script.database import CustomDatabase +from app.service.chat_robot.script.prompt import FASHION_CHAT_BOT_PREFIX, TOOLS_FUNCTIONS_SUFFIX +from app.service.chat_robot.script.tools import (QuerySQLDataBaseTool, InfoSQLDatabaseTool, QuerySQLCheckerTool, ListSQLDatabaseTool) +from app.service.chat_robot.script.memory import UserConversationBufferWindowMemory +from app.service.chat_robot.script.tools.tutorial_tool import CustomTutorialTool +from app.core.config import * + +import os + +# os.environ["http_proxy"] = "http://127.0.0.1:7890" +# os.environ["https_proxy"] = "http://127.0.0.1:7890" +# log callbacks +logfile = "logs/chat_debug.log" +logger.add(logfile, colorize=True, enqueue=True) +log_handler = FileCallbackHandler(logfile) + +# Initiate our LLM 'gpt-3.5-turbo' +llm = ChatOpenAI(temperature=0.1, + openai_api_key=OPENAI_API_KEY, + # callbacks=[OpenAICallbackHandler()] + ) + +search = SerpAPIWrapper() +db = CustomDatabase.from_uri(f'mysql+pymysql://{DB_USERNAME}:{DB_PASSWORD}@{DB_HOST}:{DB_PORT}/attribute_retrieval_V3', + include_tables=['female_top', 'female_skirt', 'female_pants', 'female_dress', + 'female_outwear', 'male_bottom', 'male_top', 'male_outwear'], + engine_args={"pool_recycle": 7200}) +tools = [ + Tool( + name="internet_search", + description="Can be used to perform Internet searches", + func=search.run + ), + QuerySQLDataBaseTool(db=db, return_direct=False), + InfoSQLDatabaseTool(db=db), + ListSQLDatabaseTool(db=db), + QuerySQLCheckerTool(db=db, llm=OpenAI(temperature=0, openai_api_key=OPENAI_API_KEY)), + Tool( + name="tutorial_tool", + description="Utilize this tool to retrieve specific statements related to user guidance tutorials." 
+ "Input is an empty string", + func=CustomTutorialTool(), + return_direct=True + ) +] + +messages = [ + SystemMessage(content=FASHION_CHAT_BOT_PREFIX), + MessagesPlaceholder(variable_name="history"), + HumanMessagePromptTemplate.from_template( + "{input} " + "Question from a {gender}." + ), + AIMessage(content=TOOLS_FUNCTIONS_SUFFIX), + MessagesPlaceholder(variable_name="agent_scratchpad"), +] + +prompt = ChatPromptTemplate(input_variables=["input", "gender", "agent_scratchpad", "history"], messages=messages) +agent = ConversationalFunctionsAgent( + llm=llm, + tools=tools, + prompt=prompt +) + +memory = UserConversationBufferWindowMemory.from_redis( + return_messages=True, k=2, input_key='input', output_key='output' +) +agent_executor = CustomAgentExecutor.from_agent_and_tools( + agent=agent, + tools=tools, + verbose=True, + memory=memory, +) + + +def chat(post_data): + user_id = post_data.user_id + session_id = post_data.session_id + input_message = post_data.message + gender = post_data.gender + + final_outputs = agent_executor( + {"input": input_message, "gender": gender}, + callbacks=[OpenAITokenRecordCallbackHandler(), log_handler], + session_key=f"buffer:{user_id}:{session_id}", + ) + api_response = { + 'user_id': user_id, + 'session_id': session_id, + # 'message_id': message_id, + # 'create_time': created_time, + 'input': final_outputs['input'], + # 'conversion': messages, + 'output': final_outputs['output'], + # 'gpt_response_time': gpt_response_time, + 'total_tokens': final_outputs['total_tokens'], + 'total_cost': final_outputs['total_cost'], + 'prompt_tokens': final_outputs['prompt_tokens'], + 'completion_tokens': final_outputs['completion_tokens'], + 'response_type': final_outputs['response_type'] + } + logging.info(api_response) + return api_response diff --git a/app/service/chat_robot/script/memory/__init__.py b/app/service/chat_robot/script/memory/__init__.py new file mode 100644 index 0000000..9586157 --- /dev/null +++ b/app/service/chat_robot/script/memory/__init__.py @@ -0,0 +1,3 @@ +from .user_buffer_window import UserConversationBufferWindowMemory + +__all__ = ['UserConversationBufferWindowMemory'] diff --git a/app/service/chat_robot/script/memory/user_buffer_window.py b/app/service/chat_robot/script/memory/user_buffer_window.py new file mode 100644 index 0000000..9fbc2d6 --- /dev/null +++ b/app/service/chat_robot/script/memory/user_buffer_window.py @@ -0,0 +1,93 @@ +import logging +from typing import Any, Dict, List, Tuple +import json + +import redis +from redis import Redis +from langchain.memory import RedisChatMessageHistory +from langchain.memory.chat_memory import BaseChatMemory +from langchain.schema.messages import BaseMessage, get_buffer_string, HumanMessage, AIMessage +from langchain.schema.messages import _message_to_dict, messages_from_dict +from langchain.memory.utils import get_prompt_input_key + +from app.core.config import * + + +class UserConversationBufferWindowMemory(BaseChatMemory): + """Buffer for storing conversation memory.""" + + redis_client: Redis + human_prefix: str = "Human" + ai_prefix: str = "AI" + memory_key: str = "history" #: :meta private: + k: int = 5 + + @classmethod + def from_redis( + cls, + host: str = REDIS_HOST, + port: int = REDIS_PORT, + db: int = 3, + **kwargs + ): + redis_client = Redis(host=host, port=port, db=db) + try: + response = redis_client.ping() + if response: + print("Connect to redis server successfully.") + logging.info("Connect to redis server successfully.") + else: + print("Fail to connect to redis server") + 
logging.info("Fail to connect to redis server") + except redis.RedisError as e: + logging.info(f"Error occurs when connecting to redis server: {str(e)}") + return cls(redis_client=redis_client, **kwargs) + + @property + def memory_variables(self) -> List[str]: + """Will always return list of memory variables. + + :meta private: + """ + return [self.memory_key] + + def load_memory_variables(self, inputs: Dict[str, Any], key: str = "") -> Dict[str, str]: + """Return history buffer.""" + _items: Any = self.redis_client.lrange(key, 0, self.k * 2) if self.k > 0 else [] + items = [json.loads(m.decode("utf-8")) for m in _items[::-1]] + buffer = messages_from_dict(items) + if not self.return_messages: + buffer = get_buffer_string( + buffer, + human_prefix=self.human_prefix, + ai_prefix=self.ai_prefix, + ) + return {self.memory_key: buffer} + + def _get_input_output( + self, inputs: Dict[str, Any], outputs: Dict[str, str] + ) -> Tuple[str, str]: + if self.input_key is None: + prompt_input_key = get_prompt_input_key(inputs, self.memory_variables) + else: + prompt_input_key = self.input_key + if self.output_key is None: + if len(outputs) != 1: + raise ValueError(f"One output key expected, got {outputs.keys()}") + output_key = list(outputs.keys())[0] + else: + output_key = self.output_key + return inputs[prompt_input_key], outputs[output_key] + + def add_message(self, key: str, message: BaseMessage) -> None: + self.redis_client.lpush(key, json.dumps(_message_to_dict(message))) + + def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str], key: str = "") -> None: + """Save context from this conversation to buffer.""" + input_str, output_str = self._get_input_output(inputs, outputs) + self.add_message(key, HumanMessage(content=input_str)) + self.add_message(key, AIMessage(content=output_str)) + + # def clear(self, key) -> None: + # """Clear memory contents.""" + # self.redis_client.delete(key) diff --git a/app/service/chat_robot/script/prompt.py b/app/service/chat_robot/script/prompt.py new file mode 100644 index 0000000..a88044d --- /dev/null +++ b/app/service/chat_robot/script/prompt.py @@ -0,0 +1,52 @@ +FASHION_CHAT_BOT_PREFIX = """ +You are a helpful assistant for fashion designers. You can chat with users or answer their queries as best you can. +The most crucial aspect is to accurately determine whether the user's inquiry requires an internet search or querying the database. +Remember, your answer should be very precise and the final output answer should not exceed 20 words. + +You may encounter the following types of questions: +1) If the query is related to clothing retrieval, you are an agent designed to interact with a SQL database. +Given an input question, create a syntactically correct MySQL query to run, always fetching random data from tables. +Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most 4 results. +Never query for all the columns from a specific table, only ask for the relevant columns given the question. +You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again. +DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database. +If the question does not seem related to the database, just return "I don't know" as the answer. + +2) If the query is related to current events, you should use internet_search to seek help from the internet. 
+ + 3) If the query is just casual conversation, engage in the conversation as a fashion designer assistant. + +Be careful when using the tools, since you are primarily a chat bot. Tools should only be used when essential. +""" + +TOOL_SELECT_SUFFIX = """ +Prior to proceeding, it is essential to carefully assess the question and select the appropriate tools or approach accordingly. +For database-related questions, use SQL tools to identify relevant tables and query their schemas. +The use of online resources should be limited to inquiries pertaining to current subjects. +""" + +SQL_FUNCTIONS_SUFFIX = """ +For database-related questions, use SQL tools to identify relevant tables and query their schemas. +""" + +INTERNET_SEARCH_SUFFIX = """ +If the question should be answered using internet search tools, I should seek help from the internet. +""" + +ANSWER_FORMAT_SUFFIX = """ +My final answer is limited to 20 words and is as precise as possible. +""" + +TOOLS_FUNCTIONS_SUFFIX = ( + "If the input involves clothing queries, " + "I should look at the tables in the database to see what I can query. Then I should query the schema of the most relevant tables. " + "All SQL statements must use 'ORDER BY RAND()', for example: " + "Example Input 1: 'SELECT img_name FROM skirt WHERE opening_type = 'Button' ORDER BY RAND() LIMIT 1' " + "Example Input 2: 'SELECT img_name FROM top WHERE sleeve_length = 'Long' AND type = 'Blouse' ORDER BY RAND() LIMIT 2' " + "If the input does not involve clothing queries, " + "I should engage in conversation as an assistant or search the internet with the internet_search tool. " + "If the database query returns no results, please respond directly with: 'Apologies, I couldn't find any images that match your description. Could you please give me more details about the clothing you're searching for?' " + "Upon mentioning words related to 'tutorial' in the input, I should use tutorial_tool " +) + +TUTORIAL_TOOL_RETURN = "Commencing the systematic tutorial guide now." 
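For reference, a minimal sketch of how a client might call the new chat endpoint once this branch is deployed. The host, port and sample field values are assumptions (they are not defined in this diff); only the ChatRobotModel payload shape and the {"code", "message", "data"} envelope come from the code above.

# Hypothetical client call against a locally running instance of the service.
import requests

payload = {
    "gender": "female",                          # ChatRobotModel.gender
    "message": "Show me a long-sleeve blouse",   # ChatRobotModel.message
    "session_id": "demo-session-001",            # ChatRobotModel.session_id
    "user_id": 1,                                # ChatRobotModel.user_id
}
resp = requests.post("http://localhost:8000/api/chat_robot", json=payload)
body = resp.json()
# On success: {"code": 200, "message": "access", "data": {...}}, where "data" carries
# user_id, session_id, input, output, token counts, total_cost and response_type.
print(body["code"], body["data"]["output"])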
diff --git a/app/service/chat_robot/script/tools/__init__.py b/app/service/chat_robot/script/tools/__init__.py new file mode 100644 index 0000000..4a40a33 --- /dev/null +++ b/app/service/chat_robot/script/tools/__init__.py @@ -0,0 +1,10 @@ +from .sql_tools import ( + QuerySQLDataBaseTool, + InfoSQLDatabaseTool, + ListSQLDatabaseTool, + QuerySQLCheckerTool +) + +__all__ = [ + "QuerySQLCheckerTool", "InfoSQLDatabaseTool", "ListSQLDatabaseTool", "QuerySQLDataBaseTool" +] diff --git a/app/service/chat_robot/script/tools/sql_tools.py b/app/service/chat_robot/script/tools/sql_tools.py new file mode 100644 index 0000000..92b8003 --- /dev/null +++ b/app/service/chat_robot/script/tools/sql_tools.py @@ -0,0 +1,183 @@ +# flake8: noqa +"""Tools for interacting with a SQL database.""" +from typing import Any, Dict, Optional + +from pydantic import BaseModel, Extra, Field, root_validator + +from langchain.base_language import BaseLanguageModel +from langchain.callbacks.manager import ( + AsyncCallbackManagerForToolRun, + CallbackManagerForToolRun, +) +from langchain.chains.llm import LLMChain +from langchain.prompts import PromptTemplate +# from langchain.sql_database import SQLDatabase +from langchain.utilities import SQLDatabase +from langchain.tools.base import BaseTool +from langchain.tools.sql_database.prompt import QUERY_CHECKER + + +class BaseSQLDatabaseTool(BaseModel): + """Base tools for interacting with a SQL database.""" + + db: SQLDatabase = Field(exclude=True) + param_description: str = "" + + # Override BaseTool.Config to appease mypy + # See https://github.com/pydantic/pydantic/issues/4173 + class Config(BaseTool.Config): + """Configuration for this pydantic object.""" + + arbitrary_types_allowed = True + extra = Extra.forbid + + +class QuerySQLDataBaseTool(BaseSQLDatabaseTool, BaseTool): + """Tool for querying a SQL database.""" + + name = "sql_db_query" + # description = """ + # Before use this tool, another tool named sql_db_schema must be used first to find the schema of interested tables. + # This tool is designed exclusively for generating SELECT queries to retrieve clothing's img_name randomly from a MySQL database. + # You should always use ‘order by rand()’ to randomly select data. + # If the query is not correct, an error message will be returned. + # If an error is returned, rewrite the query, check the query, and try again. + # Always limit your query to at most 4 results. + # Never query for all the columns from a specific table, only ask for the relevant columns given the question. + # You MUST double check your query before executing it. + # DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database. + # """ + + description: str = ( + "The input of this tool is a detailed and correct SQL select query statement, " + "and the output is the result of the database, and it can only return up to 4 results." + "If the query is not correct, an error message will be returned." + "If an error is returned, rewrite the query, check the query, and try again." + "If you encounter an issue with Unknown column 'xxxx' in 'field list' or Table 'attribute_retrieval.xxxx' doesn't exist," + "use sql_db_schema to query the correct table fields." 
+ + "Example Input: 'SELECT img_name FROM skirt WHERE opening_type = 'Button' ORDER BY RAND() " + "LIMIT 1' " + "Example Input 2: 'SELECT img_name FROM top WHERE sleeve_length = 'Long' AND type = 'Blouse' " + "ORDER BY RAND() LIMIT 2'" + ) + + def _run( + self, + query: str, + run_manager: Optional[CallbackManagerForToolRun] = None, + ) -> str: + """Execute the query, return the results or an error message.""" + result = self.db.run_no_throw(query) + return result + + async def _arun( + self, + query: str, + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + ) -> str: + raise NotImplementedError("QuerySqlDbTool does not support async") + + +class InfoSQLDatabaseTool(BaseSQLDatabaseTool, BaseTool): + """Tool for getting metadata about a SQL database.""" + + name = "sql_db_schema" + # description = """ + # The database contains information of lots of fashion items, such as item name, their fashion attributes. + # There are five tables covering five fashion categories: top, pants, dress, skirt, and outwear. + # Find the most relevant tables with the query, and output the schema of these tables. + # """ + + description: str = ( + "Input to this tool is a comma-separated list of tables, output is the schema and sample rows for those tables. " + "There are eight tables covering eight fashion categories: female_top, female_pants, female_dress, " + "female_skirt, female_outwear, male_bottom, male_top, and male_outwear. " + + "Example Input: 'female_outwear, male_top'" + ) + + def _run( + self, + table_names: str, + run_manager: Optional[CallbackManagerForToolRun] = None, + ) -> str: + """Get the schema for tables in a comma-separated list.""" + return self.db.get_table_info_no_throw(table_names.split(", ")) + + async def _arun( + self, + table_name: str, + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + ) -> str: + raise NotImplementedError("SchemaSqlDbTool does not support async") + + +class ListSQLDatabaseTool(BaseSQLDatabaseTool, BaseTool): + """Tool for getting table names.""" + + name = "sql_db_list_tables" + description = "Input is an empty string, output is a comma-separated list of tables in the database." + + def _run( + self, + tool_input: str = "", + run_manager: Optional[CallbackManagerForToolRun] = None, + ) -> str: + """Return a comma-separated list of the usable table names.""" + return ", ".join(self.db.get_usable_table_names()) + + async def _arun( + self, + tool_input: str = "", + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + ) -> str: + raise NotImplementedError("ListTablesSqlDbTool does not support async") + + +class QuerySQLCheckerTool(BaseSQLDatabaseTool, BaseTool): + """Use an LLM to check if a query is correct. + Adapted from https://www.patterns.app/blog/2023/01/18/crunchbot-sql-analyst-gpt/""" + + template: str = QUERY_CHECKER + llm: BaseLanguageModel + llm_chain: LLMChain = Field(init=False) + name = "sql_db_query_checker" + description = ( + "Use this tool to double check if your query is correct before executing it. " + "Always use this tool before executing a query with sql_db_query!" 
+ ) + + @root_validator(pre=True) + def initialize_llm_chain(cls, values: Dict[str, Any]) -> Dict[str, Any]: + if "llm_chain" not in values: + values["llm_chain"] = LLMChain( + llm=values.get("llm"), + prompt=PromptTemplate( + template=QUERY_CHECKER, + input_variables=["query", "dialect"] + ), + ) + + if values["llm_chain"].prompt.input_variables != ["dialect", "query"]: + # if values["llm_chain"].prompt.input_variables != ["query", "dialect"]: + raise ValueError( + "LLM chain for QueryCheckerTool must have input variables ['query', 'dialect']" + ) + + return values + + def _run( + self, + query: str, + run_manager: Optional[CallbackManagerForToolRun] = None, + ) -> str: + """Use the LLM to check the query.""" + return self.llm_chain.predict(query=query, dialect=self.db.dialect) + + async def _arun( + self, + query: str, + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + ) -> str: + return await self.llm_chain.apredict(query=query, dialect=self.db.dialect) diff --git a/app/service/chat_robot/script/tools/tutorial_tool.py b/app/service/chat_robot/script/tools/tutorial_tool.py new file mode 100644 index 0000000..c08eb9d --- /dev/null +++ b/app/service/chat_robot/script/tools/tutorial_tool.py @@ -0,0 +1,19 @@ +from typing import Any + +from langchain.tools.base import BaseTool + +from app.service.chat_robot.script.prompt import TUTORIAL_TOOL_RETURN + + +# Handles input related to the system's guided tutorial +class CustomTutorialTool(BaseTool): + name = "tutorial_tool" + + description = ("Utilize this tool to retrieve specific statements related to user guidance tutorials. " + "Input is an empty string") + + def _run(self, tool_input, **kwargs: Any) -> str: + return TUTORIAL_TOOL_RETURN + + async def _arun(self, tool_input, **kwargs: Any) -> str: + raise NotImplementedError("CustomTutorialTool does not support async") diff --git a/app/service/chat_robot/script/utils/__init__.py b/app/service/chat_robot/script/utils/__init__.py new file mode 100644 index 0000000..92a2f16 --- /dev/null +++ b/app/service/chat_robot/script/utils/__init__.py @@ -0,0 +1 @@ +from .logger import Logger diff --git a/app/service/chat_robot/script/utils/logger.py b/app/service/chat_robot/script/utils/logger.py new file mode 100644 index 0000000..cb52c18 --- /dev/null +++ b/app/service/chat_robot/script/utils/logger.py @@ -0,0 +1,26 @@ +import logging +from logging import handlers + + +class Logger(object): + level_relations = { + 'debug': logging.DEBUG, + 'info': logging.INFO, + 'warning': logging.WARNING, + 'error': logging.ERROR, + 'crit': logging.CRITICAL + } + + def __init__(self, filename, level='info', when='D', backCount=3, + fmt='%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s'): + self.logger = logging.getLogger(filename) + format_str = logging.Formatter(fmt) # set log format + self.logger.setLevel(self.level_relations.get(level)) # set log level + sh = logging.StreamHandler() # output to terminal + sh.setFormatter(format_str) # set format for terminal log + th = handlers.TimedRotatingFileHandler(filename=filename, when=when, backupCount=backCount, + encoding='utf-8') # log into file + + th.setFormatter(format_str) # set format for file log + self.logger.addHandler(sh) # output to terminal + self.logger.addHandler(th) # output to file diff --git a/app/service/prompt_generation/chatgpt_for_translation.py b/app/service/prompt_generation/chatgpt_for_translation.py new file mode 100644 index 0000000..b9c2c80 --- /dev/null +++ b/app/service/prompt_generation/chatgpt_for_translation.py @@ -0,0 +1,70 @@ +import os + 
+from langchain.chains import LLMChain +from langchain.chat_models import ChatOpenAI +from langchain_core.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate, \ + PromptTemplate + +from app.core.config import OPENAI_MODEL, OPENAI_API_KEY + +# os.environ["http_proxy"] = "http://127.0.0.1:7890" +# os.environ["https_proxy"] = "http://127.0.0.1:7890" + + +llm = ChatOpenAI(model_name=OPENAI_MODEL, + openai_api_key=OPENAI_API_KEY, + temperature=0) + + +def translate_to_en(text): + template = ( + """You are a translation expert, proficient in various languages, + and can translate them into English. + Please translate to grammatically correct English regardless of the input language. + If the input is in English, check for grammatical errors. If there are no errors, simply output the sentence. + If there are grammatical errors, correct them and then output the sentence.""" + ) + system_message_prompt = SystemMessagePromptTemplate.from_template(template) + + # The text to be translated is provided via the Human role + human_template = "User input : {text}" + human_message_prompt = HumanMessagePromptTemplate.from_template(input_variables=["text"], template=human_template) + + # Build the ChatPromptTemplate from the System and Human prompt templates + chat_prompt_template = ChatPromptTemplate.from_messages( + [system_message_prompt, human_message_prompt] + ) + translate_chain = LLMChain(llm=llm, prompt=chat_prompt_template) + + template = ( + """ + Input sentence: + {translate} + 1. Based on the input, adjust the input sentence to make it more suitable for prompts for generating images, + ensuring all key nouns or adjectives related to the image are retained. + 2. Simplify complex sentence structures and clarify ambiguous expressions. + 3. Only output the adjusted English sentence. + + Output : + """ + ) + # "Based on the input sentence, extract key adjectives and nouns.Only Output extracted key words." + # 1. Check if the input sentence contains any grammatical errors. If there are errors, please correct them before proceeding. + + prompt_template = PromptTemplate(input_variables=["translate"], template=template) + prompt_chain = LLMChain(llm=llm, prompt=prompt_template) + + from langchain.chains import SimpleSequentialChain + overall_chain = SimpleSequentialChain(chains=[translate_chain, prompt_chain], verbose=True) + + response = overall_chain.run(text) + return response + + +def main(): + """Main function""" + translate_to_en("生成一件运动风格的夹克,带有拉链和口袋,适合休闲穿着") + + +if __name__ == "__main__": + main() diff --git a/requirements.txt b/requirements.txt index 1529082..e3f2934 100644 Binary files a/requirements.txt and b/requirements.txt differ
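For reference, a minimal client sketch for the new translateToEN endpoint as well. The host and port are assumptions not defined in this diff; the payload shape (PromptGenerationImageModel) and the {"code", "message", "data"} envelope come from api_prompt_generation.py above.

# Hypothetical client call; reuses the sample sentence from main() above.
import requests

resp = requests.post(
    "http://localhost:8000/api/translateToEN",
    json={"text": "生成一件运动风格的夹克,带有拉链和口袋,适合休闲穿着"},
)
body = resp.json()
# "data" holds the output of translate_to_en: the input translated to English and then
# rewritten by the second chain into an image-generation style prompt.
print(body["code"], body["data"])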