feat : 代码梳理 移除所有敏感密钥 通过环境变量方式配置
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
This commit is contained in:
@@ -3,27 +3,20 @@ import json
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional, Union, Tuple
|
||||
|
||||
from langchain.agents import AgentExecutor
|
||||
from langchain.callbacks.manager import Callbacks, CallbackManager
|
||||
from langchain.load.dump import dumpd
|
||||
from langchain.schema import RUN_KEY, RunInfo
|
||||
from langchain_classic.agents import AgentExecutor
|
||||
from langchain_classic.schema import RUN_KEY
|
||||
from langchain_core.agents import AgentAction, AgentFinish
|
||||
from langchain_core.callbacks import Callbacks, CallbackManager
|
||||
from langchain_core.load import dumpd
|
||||
from langchain_core.outputs import RunInfo
|
||||
|
||||
|
||||
class CustomAgentExecutor(AgentExecutor):
|
||||
def __call__(
|
||||
self,
|
||||
inputs: Union[Dict[str, Any], Any],
|
||||
return_only_outputs: bool = False,
|
||||
callbacks: Callbacks = None,
|
||||
session_key: str = "",
|
||||
*,
|
||||
tags: Optional[List[str]] = None,
|
||||
include_run_info: bool = False,
|
||||
) -> Dict[str, Any]:
|
||||
def __call__(self, inputs: Union[Dict[str, Any], Any], return_only_outputs: bool = False, callbacks: Callbacks = None, session_key: str = "", *, tags: Optional[List[str]] = None, include_run_info: bool = False, **kwargs) -> Dict[str, Any]:
|
||||
"""Run the logic of this chain and add to output if desired.
|
||||
|
||||
Args:
|
||||
**kwargs:
|
||||
inputs: Dictionary of inputs, or single input if chain expects
|
||||
only one param.
|
||||
return_only_outputs: boolean for whether to return only outputs in the
|
||||
@@ -72,7 +65,7 @@ class CustomAgentExecutor(AgentExecutor):
|
||||
"""Validate and prep outputs."""
|
||||
self._validate_outputs(outputs)
|
||||
if self.memory is not None and outputs['need_record']:
|
||||
self.memory.save_context(inputs, outputs, session_key)
|
||||
self.memory.save_context(inputs, outputs)
|
||||
if return_only_outputs:
|
||||
return outputs
|
||||
else:
|
||||
@@ -95,7 +88,7 @@ class CustomAgentExecutor(AgentExecutor):
|
||||
)
|
||||
inputs = {list(_input_keys)[0]: inputs}
|
||||
if self.memory is not None:
|
||||
external_context = self.memory.load_memory_variables(inputs, session_key)
|
||||
external_context = self.memory.load_memory_variables(inputs)
|
||||
inputs = dict(inputs, **external_context)
|
||||
self._validate_inputs(inputs)
|
||||
return inputs
|
||||
@@ -119,7 +112,8 @@ class CustomAgentExecutor(AgentExecutor):
|
||||
{return_value_key: observation},
|
||||
"",
|
||||
)
|
||||
except:
|
||||
except Exception as e:
|
||||
print(e)
|
||||
pass
|
||||
|
||||
# Invalid tools won't be in the map, so we return False.
|
||||
|
||||
@@ -1,26 +1,15 @@
|
||||
import json
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from json import JSONDecodeError
|
||||
from typing import List, Tuple, Any, Union
|
||||
from dataclasses import dataclass
|
||||
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.agents import (
|
||||
OpenAIFunctionsAgent,
|
||||
)
|
||||
from langchain.schema import (
|
||||
AgentAction,
|
||||
AgentFinish,
|
||||
BaseMessage,
|
||||
OutputParserException
|
||||
)
|
||||
from langchain.schema.messages import (
|
||||
AIMessage,
|
||||
FunctionMessage
|
||||
)
|
||||
from langchain.tools import BaseTool, StructuredTool
|
||||
# from langchain.tools.convert_to_openai import FunctionDescription
|
||||
from langchain.utils.openai_functions import FunctionDescription
|
||||
from langchain_classic.agents import OpenAIFunctionsAgent
|
||||
from langchain_community.utils.ernie_functions import FunctionDescription
|
||||
from langchain_core.agents import AgentAction, AgentFinish
|
||||
from langchain_core.callbacks import Callbacks
|
||||
from langchain_core.exceptions import OutputParserException
|
||||
from langchain_core.messages import BaseMessage, AIMessage, FunctionMessage
|
||||
from langchain_core.tools import BaseTool
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -76,7 +65,6 @@ def _create_function_message(
|
||||
content = observation
|
||||
return FunctionMessage(
|
||||
name=agent_action.tool,
|
||||
content=content,
|
||||
)
|
||||
|
||||
|
||||
@@ -177,6 +165,7 @@ class ConversationalFunctionsAgent(OpenAIFunctionsAgent):
|
||||
into it.
|
||||
|
||||
Args:
|
||||
callbacks:
|
||||
intermediate_steps: Steps the LLM has taken to date, along with observations
|
||||
**kwargs: User inputs.
|
||||
**kwargs: Including user's input string
|
||||
|
||||
@@ -2,18 +2,16 @@
|
||||
from typing import Any, Dict
|
||||
|
||||
from langchain_community.callbacks.openai_info import OpenAICallbackHandler
|
||||
from langchain.schema import LLMResult
|
||||
from langchain_community.callbacks.openai_info import standardize_model_name, MODEL_COST_PER_1K_TOKENS, \
|
||||
get_openai_token_cost_for_model
|
||||
|
||||
|
||||
# from langchain.callbacks.openai_info import standardize_model_name, MODEL_COST_PER_1K_TOKENS, get_openai_token_cost_for_model
|
||||
from langchain_core.outputs import LLMResult
|
||||
|
||||
|
||||
class OpenAITokenRecordCallbackHandler(OpenAICallbackHandler):
|
||||
need_record: bool = True
|
||||
response_type: str = "string"
|
||||
"""Callback Handler that tracks OpenAI info and write to redis after agent finish"""
|
||||
|
||||
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
|
||||
"""Collect token usage."""
|
||||
if response.llm_output is None:
|
||||
@@ -22,7 +20,7 @@ class OpenAITokenRecordCallbackHandler(OpenAICallbackHandler):
|
||||
if "token_usage" not in response.llm_output:
|
||||
return None
|
||||
if "function_call" in response.generations[0][0].message.additional_kwargs:
|
||||
if response.generations[0][0].message.additional_kwargs["function_call"]["name"] in ["sql_db_query", "sql_db_schema","tutorial_tool"]:
|
||||
if response.generations[0][0].message.additional_kwargs["function_call"]["name"] in ["sql_db_query", "sql_db_schema", "tutorial_tool"]:
|
||||
self.need_record = False
|
||||
if response.generations[0][0].message.additional_kwargs["function_call"]["name"] == "sql_db_query":
|
||||
self.response_type = "image"
|
||||
@@ -39,6 +37,7 @@ class OpenAITokenRecordCallbackHandler(OpenAICallbackHandler):
|
||||
self.total_tokens += token_usage.get("total_tokens", 0)
|
||||
self.prompt_tokens += prompt_tokens
|
||||
self.completion_tokens += completion_tokens
|
||||
return None
|
||||
|
||||
def on_chain_end(self, outputs: Dict, **kwargs: Any) -> None:
|
||||
"""Write token usage to redis."""
|
||||
|
||||
@@ -44,12 +44,17 @@ class CustomDatabase(SQLDatabase):
|
||||
final_str = "\n\n".join(tables)
|
||||
return final_str
|
||||
|
||||
def run(self, command: str, fetch: str = "all") -> str:
|
||||
def run(self, command: str, fetch: str = "all", **kwargs) -> str:
|
||||
"""Execute a SQL command and return a string representing the results.
|
||||
|
||||
If the statement returns rows, a string of the results is returned.
|
||||
If the statement returns no rows, an empty string is returned.
|
||||
|
||||
Args:
|
||||
command:
|
||||
fetch:
|
||||
**kwargs:
|
||||
|
||||
"""
|
||||
with self._engine.begin() as connection:
|
||||
if self._schema is not None:
|
||||
|
||||
@@ -1,15 +1,15 @@
|
||||
import json
|
||||
import logging
|
||||
|
||||
from langchain.agents import Tool
|
||||
from langchain.callbacks import FileCallbackHandler
|
||||
from langchain.prompts.chat import ChatPromptTemplate, HumanMessagePromptTemplate, MessagesPlaceholder
|
||||
from langchain.schema import SystemMessage, AIMessage
|
||||
from langchain.utilities import SerpAPIWrapper
|
||||
from langchain_community.utilities import SerpAPIWrapper
|
||||
from langchain_core.callbacks import FileCallbackHandler
|
||||
from langchain_core.messages import SystemMessage, AIMessage
|
||||
from langchain_core.prompts import MessagesPlaceholder, HumanMessagePromptTemplate, ChatPromptTemplate
|
||||
from langchain_core.tools import Tool
|
||||
from langchain_community.chat_models import ChatTongyi
|
||||
from loguru import logger
|
||||
|
||||
from app.core.config import *
|
||||
from app.core.config import settings
|
||||
from app.service.chat_robot.script.agents import CustomAgentExecutor, ConversationalFunctionsAgent
|
||||
from app.service.chat_robot.script.database import CustomDatabase
|
||||
from app.service.chat_robot.script.memory import UserConversationBufferWindowMemory
|
||||
@@ -30,10 +30,10 @@ log_handler = FileCallbackHandler(logfile)
|
||||
# # callbacks=[OpenAICallbackHandler()]
|
||||
# )
|
||||
|
||||
llm = ChatTongyi(api_key=QWEN_API_KEY)
|
||||
llm = ChatTongyi(api_key=settings.QWEN_API_KEY)
|
||||
|
||||
search = SerpAPIWrapper()
|
||||
db = CustomDatabase.from_uri(f'mysql+pymysql://{DB_USERNAME}:{DB_PASSWORD}@{DB_HOST}:{DB_PORT}/attribute_retrieval_V3',
|
||||
db = CustomDatabase.from_uri(f'mysql+pymysql://{settings.DB_USERNAME}:{settings.DB_PASSWORD}@{settings.DB_HOST}:{settings.DB_PORT}/attribute_retrieval_V3',
|
||||
include_tables=['female_top', 'female_skirt', 'female_pants', 'female_dress',
|
||||
'female_outwear', 'male_bottom', 'male_top', 'male_outwear'],
|
||||
engine_args={"pool_recycle": 7200})
|
||||
@@ -43,11 +43,11 @@ tools = [
|
||||
description="Can be used to perform Internet searches",
|
||||
func=search.run
|
||||
),
|
||||
QuerySQLDataBaseTool(db=db, return_direct=False),
|
||||
QuerySQLDataBaseTool(db=db),
|
||||
InfoSQLDatabaseTool(db=db),
|
||||
ListSQLDatabaseTool(db=db),
|
||||
# QuerySQLCheckerTool(db=db, llm=OpenAI(temperature=0, openai_api_key=OPENAI_API_KEY)),
|
||||
QuerySQLCheckerTool(db=db, llm=ChatTongyi(temperature=0, api_key=QWEN_API_KEY)),
|
||||
QuerySQLCheckerTool(db=db, llm=ChatTongyi(api_key=settings.QWEN_API_KEY)),
|
||||
# Tool(
|
||||
# name="tutorial_tool",
|
||||
# description="Utilize this tool to retrieve specific statements related to user guidance tutorials."
|
||||
@@ -133,5 +133,5 @@ def chat(post_data):
|
||||
'completion_tokens': final_outputs['completion_tokens'],
|
||||
'response_type': final_outputs["response_type"]
|
||||
}
|
||||
logging.info(json.dumps(api_response))
|
||||
logging.info(json.dumps(api_response, indent=4))
|
||||
return api_response
|
||||
|
||||
@@ -3,13 +3,12 @@ from typing import Any, Dict, List, Tuple
|
||||
import json
|
||||
|
||||
import redis
|
||||
from langchain_classic.memory.chat_memory import BaseChatMemory
|
||||
from langchain_classic.memory.utils import get_prompt_input_key
|
||||
from langchain_core.messages import messages_from_dict, get_buffer_string, BaseMessage, HumanMessage, AIMessage, message_to_dict
|
||||
from redis import Redis
|
||||
from langchain.memory.chat_memory import BaseChatMemory
|
||||
from langchain.schema.messages import BaseMessage, get_buffer_string, HumanMessage, AIMessage
|
||||
from langchain.schema.messages import _message_to_dict, messages_from_dict
|
||||
from langchain.memory.utils import get_prompt_input_key
|
||||
|
||||
from app.core.config import *
|
||||
from app.core.config import settings
|
||||
|
||||
|
||||
class UserConversationBufferWindowMemory(BaseChatMemory):
|
||||
@@ -24,8 +23,8 @@ class UserConversationBufferWindowMemory(BaseChatMemory):
|
||||
@classmethod
|
||||
def from_redis(
|
||||
cls,
|
||||
host: str = REDIS_HOST,
|
||||
port: int = REDIS_PORT,
|
||||
host: str = settings.REDIS_HOST,
|
||||
port: int = settings.REDIS_PORT,
|
||||
db: int = 3,
|
||||
**kwargs
|
||||
):
|
||||
@@ -79,7 +78,7 @@ class UserConversationBufferWindowMemory(BaseChatMemory):
|
||||
return inputs[prompt_input_key], outputs[output_key]
|
||||
|
||||
def add_message(self, key: str, message: BaseMessage) -> None:
|
||||
self.redis_client.lpush(key, json.dumps(_message_to_dict(message)))
|
||||
self.redis_client.lpush(key, json.dumps(message_to_dict(message)))
|
||||
|
||||
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str], key: str = "") -> None:
|
||||
"""Save context from this conversation to buffer."""
|
||||
|
||||
@@ -5,10 +5,10 @@ from dashscope import Generation
|
||||
from retry import retry
|
||||
from urllib3.exceptions import NewConnectionError
|
||||
|
||||
from app.core.config import *
|
||||
from app.core.config import settings
|
||||
from app.service.chat_robot.script.callbacks.qwen_callback_handler import QWenCallbackHandler
|
||||
from app.service.chat_robot.script.database import CustomDatabase
|
||||
from app.service.chat_robot.script.prompt import FASHION_CHAT_BOT_PREFIX, TOOLS_FUNCTIONS_SUFFIX, TUTORIAL_TOOL_RETURN, \
|
||||
from app.service.chat_robot.script.prompt import TOOLS_FUNCTIONS_SUFFIX, TUTORIAL_TOOL_RETURN, \
|
||||
GET_LANGUAGE_PREFIX, FASHION_CHAT_BOT_PREFIX_TEMP
|
||||
from app.service.search_image_with_text.service import query
|
||||
|
||||
@@ -149,7 +149,7 @@ tools = [
|
||||
}
|
||||
]
|
||||
|
||||
db = CustomDatabase.from_uri(f'mysql+pymysql://{DB_USERNAME}:{DB_PASSWORD}@{DB_HOST}:{DB_PORT}/attribute_retrieval_V3',
|
||||
db = CustomDatabase.from_uri(f'mysql+pymysql://{settings.MYSQL_USER}:{settings.MYSQL_PASSWORD}@{settings.MYSQL_HOST}:{settings.MYSQL_PORT}/attribute_retrieval_V3',
|
||||
include_tables=['female_top', 'female_skirt', 'female_pants', 'female_dress',
|
||||
'female_outwear', 'male_bottom', 'male_top', 'male_outwear'],
|
||||
engine_args={"pool_recycle": 7200})
|
||||
@@ -159,7 +159,7 @@ qwen = QWenCallbackHandler()
|
||||
def search_from_internet(message):
|
||||
response = Generation.call(
|
||||
model='qwen-turbo',
|
||||
api_key=QWEN_API_KEY,
|
||||
api_key=settings.QWEN_API_KEY,
|
||||
messages=message,
|
||||
prompt='The output must be in English.Keep the final result under 200 words.'
|
||||
# tools=tools,
|
||||
@@ -190,7 +190,7 @@ def get_image_from_vector_db(gender, content):
|
||||
def get_response(messages):
|
||||
response = Generation.call(
|
||||
model='qwen-max',
|
||||
api_key=QWEN_API_KEY,
|
||||
api_key=settings.QWEN_API_KEY,
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
# seed=random.randint(1, 10000), # 设置随机数种子seed,如果没有设置,则随机数种子默认为1234
|
||||
@@ -203,7 +203,7 @@ def get_response(messages):
|
||||
def get_assistant_response(messages):
|
||||
response = Generation.call(
|
||||
model='qwen-max',
|
||||
api_key=QWEN_API_KEY,
|
||||
api_key=settings.QWEN_API_KEY,
|
||||
messages=messages,
|
||||
# seed=random.randint(1, 10000), # 设置随机数种子seed,如果没有设置,则随机数种子默认为1234
|
||||
result_format='message', # 将输出设置为message形式
|
||||
@@ -212,8 +212,10 @@ def get_assistant_response(messages):
|
||||
return response
|
||||
|
||||
|
||||
global tool_info
|
||||
|
||||
|
||||
def call_with_messages(message):
|
||||
global tool_info
|
||||
user_input = message
|
||||
print('\n')
|
||||
|
||||
@@ -241,7 +243,7 @@ def call_with_messages(message):
|
||||
response_type = "chat"
|
||||
|
||||
while flag and count <= 3:
|
||||
first_response = get_response(messages)
|
||||
first_response = get_response
|
||||
assistant_output = first_response.output.choices[0].message
|
||||
QWenCallbackHandler.on_llm_end(qwen, first_response.usage)
|
||||
print(f"\n大模型第 {count} 轮输出信息:{first_response}\n")
|
||||
@@ -260,7 +262,7 @@ def call_with_messages(message):
|
||||
]
|
||||
tool_info['content'] = search_from_internet(message)
|
||||
flag = False
|
||||
result_content = tool_info['content'].output.text
|
||||
result_content = tool_info['content']
|
||||
# 如果模型选择的工具是get_database_table
|
||||
# elif assistant_output.tool_calls[0]['function']['name'] == 'get_database_table':
|
||||
# tool_info = {"name": "get_database_table", "role": "tool", 'content': get_database_table()}
|
||||
|
||||
@@ -2,21 +2,15 @@
|
||||
"""Tools for interacting with a SQL database."""
|
||||
from typing import Any, Dict, Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Extra, Field, root_validator
|
||||
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForToolRun,
|
||||
CallbackManagerForToolRun,
|
||||
)
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.prompts import PromptTemplate
|
||||
from langchain_community.tools.sql_database.prompt import QUERY_CHECKER
|
||||
from langchain_community.tools.sql_database.tool import _QuerySQLCheckerToolInput
|
||||
# from langchain.sql_database import SQLDatabase
|
||||
from langchain_community.utilities import SQLDatabase
|
||||
from langchain.tools.base import BaseTool
|
||||
from langchain_community.tools.sql_database.prompt import QUERY_CHECKER
|
||||
|
||||
from langchain_community.tools.sql_database.tool import QuerySQLCheckerTool, _QuerySQLCheckerToolInput
|
||||
from langchain_core.callbacks import CallbackManagerForToolRun, AsyncCallbackManagerForToolRun
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
from langchain_core.prompts import PromptTemplate
|
||||
from langchain_core.tools import BaseTool
|
||||
from pydantic import BaseModel, Extra, Field, root_validator
|
||||
|
||||
|
||||
class BaseSQLDatabaseTool(BaseModel):
|
||||
@@ -62,7 +56,7 @@ class QuerySQLDataBaseTool(BaseSQLDatabaseTool, BaseTool):
|
||||
"LIMIT 1'"
|
||||
"Example Input 2: 'SELECT img_name FROM top WHERE sleeve_length = 'Long' AND type = 'Blouse' "
|
||||
"order by rand() LIMIT 2'"
|
||||
)
|
||||
)
|
||||
|
||||
def _run(
|
||||
self,
|
||||
@@ -95,9 +89,9 @@ class InfoSQLDatabaseTool(BaseSQLDatabaseTool, BaseTool):
|
||||
"Input to this tool is a comma-separated list of tables, output is the schema and sample rows for those tables."
|
||||
"There are eight tables covering eight fashion categories: female_top, female_pants, female_dress,"
|
||||
"female_skirt, female_outwear, male_bottom, male_top, and male_outwear."
|
||||
|
||||
|
||||
"Example Input: 'female_outwear, male_top'"
|
||||
)
|
||||
)
|
||||
|
||||
def _run(
|
||||
self,
|
||||
@@ -183,11 +177,11 @@ class QuerySQLCheckerTool(BaseSQLDatabaseTool, BaseTool):
|
||||
args_schema: Type[BaseModel] = _QuerySQLCheckerToolInput
|
||||
|
||||
@root_validator(pre=True)
|
||||
def initialize_llm_chain(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
def initialize_llm_chain(self, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
if "llm_chain" not in values:
|
||||
# from langchain.chains.llm import LLMChain
|
||||
|
||||
llm = values.get("llm") # type: ignore[arg-type]
|
||||
llm = values.get("llm") # type: ignore[arg-type]
|
||||
prompt = PromptTemplate(
|
||||
template=QUERY_CHECKER, input_variables=["dialect", "query"]
|
||||
)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from typing import Any
|
||||
|
||||
from langchain.tools.base import BaseTool
|
||||
from langchain_core.tools import BaseTool
|
||||
|
||||
from app.service.chat_robot.script.prompt import TUTORIAL_TOOL_RETURN
|
||||
|
||||
|
||||
Reference in New Issue
Block a user