feat : 代码梳理 移除所有敏感密钥 通过环境变量方式配置
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped

This commit is contained in:
zcr
2025-12-30 16:49:08 +08:00
parent 1be716e414
commit 18024a2d70
167 changed files with 5283 additions and 10464 deletions

View File

@@ -3,27 +3,20 @@ import json
import logging
from typing import Any, Dict, List, Optional, Union, Tuple
from langchain.agents import AgentExecutor
from langchain.callbacks.manager import Callbacks, CallbackManager
from langchain.load.dump import dumpd
from langchain.schema import RUN_KEY, RunInfo
from langchain_classic.agents import AgentExecutor
from langchain_classic.schema import RUN_KEY
from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.callbacks import Callbacks, CallbackManager
from langchain_core.load import dumpd
from langchain_core.outputs import RunInfo
class CustomAgentExecutor(AgentExecutor):
def __call__(
self,
inputs: Union[Dict[str, Any], Any],
return_only_outputs: bool = False,
callbacks: Callbacks = None,
session_key: str = "",
*,
tags: Optional[List[str]] = None,
include_run_info: bool = False,
) -> Dict[str, Any]:
def __call__(self, inputs: Union[Dict[str, Any], Any], return_only_outputs: bool = False, callbacks: Callbacks = None, session_key: str = "", *, tags: Optional[List[str]] = None, include_run_info: bool = False, **kwargs) -> Dict[str, Any]:
"""Run the logic of this chain and add to output if desired.
Args:
**kwargs:
inputs: Dictionary of inputs, or single input if chain expects
only one param.
return_only_outputs: boolean for whether to return only outputs in the
@@ -72,7 +65,7 @@ class CustomAgentExecutor(AgentExecutor):
"""Validate and prep outputs."""
self._validate_outputs(outputs)
if self.memory is not None and outputs['need_record']:
self.memory.save_context(inputs, outputs, session_key)
self.memory.save_context(inputs, outputs)
if return_only_outputs:
return outputs
else:
@@ -95,7 +88,7 @@ class CustomAgentExecutor(AgentExecutor):
)
inputs = {list(_input_keys)[0]: inputs}
if self.memory is not None:
external_context = self.memory.load_memory_variables(inputs, session_key)
external_context = self.memory.load_memory_variables(inputs)
inputs = dict(inputs, **external_context)
self._validate_inputs(inputs)
return inputs
@@ -119,7 +112,8 @@ class CustomAgentExecutor(AgentExecutor):
{return_value_key: observation},
"",
)
except:
except Exception as e:
print(e)
pass
# Invalid tools won't be in the map, so we return False.

View File

@@ -1,26 +1,15 @@
import json
import re
from dataclasses import dataclass
from json import JSONDecodeError
from typing import List, Tuple, Any, Union
from dataclasses import dataclass
from langchain.callbacks.manager import Callbacks
from langchain.agents import (
OpenAIFunctionsAgent,
)
from langchain.schema import (
AgentAction,
AgentFinish,
BaseMessage,
OutputParserException
)
from langchain.schema.messages import (
AIMessage,
FunctionMessage
)
from langchain.tools import BaseTool, StructuredTool
# from langchain.tools.convert_to_openai import FunctionDescription
from langchain.utils.openai_functions import FunctionDescription
from langchain_classic.agents import OpenAIFunctionsAgent
from langchain_community.utils.ernie_functions import FunctionDescription
from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.callbacks import Callbacks
from langchain_core.exceptions import OutputParserException
from langchain_core.messages import BaseMessage, AIMessage, FunctionMessage
from langchain_core.tools import BaseTool
@dataclass
@@ -76,7 +65,6 @@ def _create_function_message(
content = observation
return FunctionMessage(
name=agent_action.tool,
content=content,
)
@@ -177,6 +165,7 @@ class ConversationalFunctionsAgent(OpenAIFunctionsAgent):
into it.
Args:
callbacks:
intermediate_steps: Steps the LLM has taken to date, along with observations
**kwargs: User inputs.
**kwargs: Including user's input string

View File

@@ -2,18 +2,16 @@
from typing import Any, Dict
from langchain_community.callbacks.openai_info import OpenAICallbackHandler
from langchain.schema import LLMResult
from langchain_community.callbacks.openai_info import standardize_model_name, MODEL_COST_PER_1K_TOKENS, \
get_openai_token_cost_for_model
# from langchain.callbacks.openai_info import standardize_model_name, MODEL_COST_PER_1K_TOKENS, get_openai_token_cost_for_model
from langchain_core.outputs import LLMResult
class OpenAITokenRecordCallbackHandler(OpenAICallbackHandler):
need_record: bool = True
response_type: str = "string"
"""Callback Handler that tracks OpenAI info and write to redis after agent finish"""
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
"""Collect token usage."""
if response.llm_output is None:
@@ -22,7 +20,7 @@ class OpenAITokenRecordCallbackHandler(OpenAICallbackHandler):
if "token_usage" not in response.llm_output:
return None
if "function_call" in response.generations[0][0].message.additional_kwargs:
if response.generations[0][0].message.additional_kwargs["function_call"]["name"] in ["sql_db_query", "sql_db_schema","tutorial_tool"]:
if response.generations[0][0].message.additional_kwargs["function_call"]["name"] in ["sql_db_query", "sql_db_schema", "tutorial_tool"]:
self.need_record = False
if response.generations[0][0].message.additional_kwargs["function_call"]["name"] == "sql_db_query":
self.response_type = "image"
@@ -39,6 +37,7 @@ class OpenAITokenRecordCallbackHandler(OpenAICallbackHandler):
self.total_tokens += token_usage.get("total_tokens", 0)
self.prompt_tokens += prompt_tokens
self.completion_tokens += completion_tokens
return None
def on_chain_end(self, outputs: Dict, **kwargs: Any) -> None:
"""Write token usage to redis."""

View File

@@ -44,12 +44,17 @@ class CustomDatabase(SQLDatabase):
final_str = "\n\n".join(tables)
return final_str
def run(self, command: str, fetch: str = "all") -> str:
def run(self, command: str, fetch: str = "all", **kwargs) -> str:
"""Execute a SQL command and return a string representing the results.
If the statement returns rows, a string of the results is returned.
If the statement returns no rows, an empty string is returned.
Args:
command:
fetch:
**kwargs:
"""
with self._engine.begin() as connection:
if self._schema is not None:

View File

@@ -1,15 +1,15 @@
import json
import logging
from langchain.agents import Tool
from langchain.callbacks import FileCallbackHandler
from langchain.prompts.chat import ChatPromptTemplate, HumanMessagePromptTemplate, MessagesPlaceholder
from langchain.schema import SystemMessage, AIMessage
from langchain.utilities import SerpAPIWrapper
from langchain_community.utilities import SerpAPIWrapper
from langchain_core.callbacks import FileCallbackHandler
from langchain_core.messages import SystemMessage, AIMessage
from langchain_core.prompts import MessagesPlaceholder, HumanMessagePromptTemplate, ChatPromptTemplate
from langchain_core.tools import Tool
from langchain_community.chat_models import ChatTongyi
from loguru import logger
from app.core.config import *
from app.core.config import settings
from app.service.chat_robot.script.agents import CustomAgentExecutor, ConversationalFunctionsAgent
from app.service.chat_robot.script.database import CustomDatabase
from app.service.chat_robot.script.memory import UserConversationBufferWindowMemory
@@ -30,10 +30,10 @@ log_handler = FileCallbackHandler(logfile)
# # callbacks=[OpenAICallbackHandler()]
# )
llm = ChatTongyi(api_key=QWEN_API_KEY)
llm = ChatTongyi(api_key=settings.QWEN_API_KEY)
search = SerpAPIWrapper()
db = CustomDatabase.from_uri(f'mysql+pymysql://{DB_USERNAME}:{DB_PASSWORD}@{DB_HOST}:{DB_PORT}/attribute_retrieval_V3',
db = CustomDatabase.from_uri(f'mysql+pymysql://{settings.DB_USERNAME}:{settings.DB_PASSWORD}@{settings.DB_HOST}:{settings.DB_PORT}/attribute_retrieval_V3',
include_tables=['female_top', 'female_skirt', 'female_pants', 'female_dress',
'female_outwear', 'male_bottom', 'male_top', 'male_outwear'],
engine_args={"pool_recycle": 7200})
@@ -43,11 +43,11 @@ tools = [
description="Can be used to perform Internet searches",
func=search.run
),
QuerySQLDataBaseTool(db=db, return_direct=False),
QuerySQLDataBaseTool(db=db),
InfoSQLDatabaseTool(db=db),
ListSQLDatabaseTool(db=db),
# QuerySQLCheckerTool(db=db, llm=OpenAI(temperature=0, openai_api_key=OPENAI_API_KEY)),
QuerySQLCheckerTool(db=db, llm=ChatTongyi(temperature=0, api_key=QWEN_API_KEY)),
QuerySQLCheckerTool(db=db, llm=ChatTongyi(api_key=settings.QWEN_API_KEY)),
# Tool(
# name="tutorial_tool",
# description="Utilize this tool to retrieve specific statements related to user guidance tutorials."
@@ -133,5 +133,5 @@ def chat(post_data):
'completion_tokens': final_outputs['completion_tokens'],
'response_type': final_outputs["response_type"]
}
logging.info(json.dumps(api_response))
logging.info(json.dumps(api_response, indent=4))
return api_response

View File

@@ -3,13 +3,12 @@ from typing import Any, Dict, List, Tuple
import json
import redis
from langchain_classic.memory.chat_memory import BaseChatMemory
from langchain_classic.memory.utils import get_prompt_input_key
from langchain_core.messages import messages_from_dict, get_buffer_string, BaseMessage, HumanMessage, AIMessage, message_to_dict
from redis import Redis
from langchain.memory.chat_memory import BaseChatMemory
from langchain.schema.messages import BaseMessage, get_buffer_string, HumanMessage, AIMessage
from langchain.schema.messages import _message_to_dict, messages_from_dict
from langchain.memory.utils import get_prompt_input_key
from app.core.config import *
from app.core.config import settings
class UserConversationBufferWindowMemory(BaseChatMemory):
@@ -24,8 +23,8 @@ class UserConversationBufferWindowMemory(BaseChatMemory):
@classmethod
def from_redis(
cls,
host: str = REDIS_HOST,
port: int = REDIS_PORT,
host: str = settings.REDIS_HOST,
port: int = settings.REDIS_PORT,
db: int = 3,
**kwargs
):
@@ -79,7 +78,7 @@ class UserConversationBufferWindowMemory(BaseChatMemory):
return inputs[prompt_input_key], outputs[output_key]
def add_message(self, key: str, message: BaseMessage) -> None:
self.redis_client.lpush(key, json.dumps(_message_to_dict(message)))
self.redis_client.lpush(key, json.dumps(message_to_dict(message)))
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str], key: str = "") -> None:
"""Save context from this conversation to buffer."""

View File

@@ -5,10 +5,10 @@ from dashscope import Generation
from retry import retry
from urllib3.exceptions import NewConnectionError
from app.core.config import *
from app.core.config import settings
from app.service.chat_robot.script.callbacks.qwen_callback_handler import QWenCallbackHandler
from app.service.chat_robot.script.database import CustomDatabase
from app.service.chat_robot.script.prompt import FASHION_CHAT_BOT_PREFIX, TOOLS_FUNCTIONS_SUFFIX, TUTORIAL_TOOL_RETURN, \
from app.service.chat_robot.script.prompt import TOOLS_FUNCTIONS_SUFFIX, TUTORIAL_TOOL_RETURN, \
GET_LANGUAGE_PREFIX, FASHION_CHAT_BOT_PREFIX_TEMP
from app.service.search_image_with_text.service import query
@@ -149,7 +149,7 @@ tools = [
}
]
db = CustomDatabase.from_uri(f'mysql+pymysql://{DB_USERNAME}:{DB_PASSWORD}@{DB_HOST}:{DB_PORT}/attribute_retrieval_V3',
db = CustomDatabase.from_uri(f'mysql+pymysql://{settings.MYSQL_USER}:{settings.MYSQL_PASSWORD}@{settings.MYSQL_HOST}:{settings.MYSQL_PORT}/attribute_retrieval_V3',
include_tables=['female_top', 'female_skirt', 'female_pants', 'female_dress',
'female_outwear', 'male_bottom', 'male_top', 'male_outwear'],
engine_args={"pool_recycle": 7200})
@@ -159,7 +159,7 @@ qwen = QWenCallbackHandler()
def search_from_internet(message):
response = Generation.call(
model='qwen-turbo',
api_key=QWEN_API_KEY,
api_key=settings.QWEN_API_KEY,
messages=message,
prompt='The output must be in English.Keep the final result under 200 words.'
# tools=tools,
@@ -190,7 +190,7 @@ def get_image_from_vector_db(gender, content):
def get_response(messages):
response = Generation.call(
model='qwen-max',
api_key=QWEN_API_KEY,
api_key=settings.QWEN_API_KEY,
messages=messages,
tools=tools,
# seed=random.randint(1, 10000), # 设置随机数种子seed如果没有设置则随机数种子默认为1234
@@ -203,7 +203,7 @@ def get_response(messages):
def get_assistant_response(messages):
response = Generation.call(
model='qwen-max',
api_key=QWEN_API_KEY,
api_key=settings.QWEN_API_KEY,
messages=messages,
# seed=random.randint(1, 10000), # 设置随机数种子seed如果没有设置则随机数种子默认为1234
result_format='message', # 将输出设置为message形式
@@ -212,8 +212,10 @@ def get_assistant_response(messages):
return response
global tool_info
def call_with_messages(message):
global tool_info
user_input = message
print('\n')
@@ -241,7 +243,7 @@ def call_with_messages(message):
response_type = "chat"
while flag and count <= 3:
first_response = get_response(messages)
first_response = get_response
assistant_output = first_response.output.choices[0].message
QWenCallbackHandler.on_llm_end(qwen, first_response.usage)
print(f"\n大模型第 {count} 轮输出信息:{first_response}\n")
@@ -260,7 +262,7 @@ def call_with_messages(message):
]
tool_info['content'] = search_from_internet(message)
flag = False
result_content = tool_info['content'].output.text
result_content = tool_info['content']
# 如果模型选择的工具是get_database_table
# elif assistant_output.tool_calls[0]['function']['name'] == 'get_database_table':
# tool_info = {"name": "get_database_table", "role": "tool", 'content': get_database_table()}

View File

@@ -2,21 +2,15 @@
"""Tools for interacting with a SQL database."""
from typing import Any, Dict, Optional, Type
from pydantic import BaseModel, Extra, Field, root_validator
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.manager import (
AsyncCallbackManagerForToolRun,
CallbackManagerForToolRun,
)
from langchain.chains.llm import LLMChain
from langchain.prompts import PromptTemplate
from langchain_community.tools.sql_database.prompt import QUERY_CHECKER
from langchain_community.tools.sql_database.tool import _QuerySQLCheckerToolInput
# from langchain.sql_database import SQLDatabase
from langchain_community.utilities import SQLDatabase
from langchain.tools.base import BaseTool
from langchain_community.tools.sql_database.prompt import QUERY_CHECKER
from langchain_community.tools.sql_database.tool import QuerySQLCheckerTool, _QuerySQLCheckerToolInput
from langchain_core.callbacks import CallbackManagerForToolRun, AsyncCallbackManagerForToolRun
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts import PromptTemplate
from langchain_core.tools import BaseTool
from pydantic import BaseModel, Extra, Field, root_validator
class BaseSQLDatabaseTool(BaseModel):
@@ -62,7 +56,7 @@ class QuerySQLDataBaseTool(BaseSQLDatabaseTool, BaseTool):
"LIMIT 1'"
"Example Input 2: 'SELECT img_name FROM top WHERE sleeve_length = 'Long' AND type = 'Blouse' "
"order by rand() LIMIT 2'"
)
)
def _run(
self,
@@ -95,9 +89,9 @@ class InfoSQLDatabaseTool(BaseSQLDatabaseTool, BaseTool):
"Input to this tool is a comma-separated list of tables, output is the schema and sample rows for those tables."
"There are eight tables covering eight fashion categories: female_top, female_pants, female_dress,"
"female_skirt, female_outwear, male_bottom, male_top, and male_outwear."
"Example Input: 'female_outwear, male_top'"
)
)
def _run(
self,
@@ -183,11 +177,11 @@ class QuerySQLCheckerTool(BaseSQLDatabaseTool, BaseTool):
args_schema: Type[BaseModel] = _QuerySQLCheckerToolInput
@root_validator(pre=True)
def initialize_llm_chain(cls, values: Dict[str, Any]) -> Dict[str, Any]:
def initialize_llm_chain(self, values: Dict[str, Any]) -> Dict[str, Any]:
if "llm_chain" not in values:
# from langchain.chains.llm import LLMChain
llm = values.get("llm") # type: ignore[arg-type]
llm = values.get("llm") # type: ignore[arg-type]
prompt = PromptTemplate(
template=QUERY_CHECKER, input_variables=["dialect", "query"]
)

View File

@@ -1,6 +1,6 @@
from typing import Any
from langchain.tools.base import BaseTool
from langchain_core.tools import BaseTool
from app.service.chat_robot.script.prompt import TUTORIAL_TOOL_RETURN