Merge remote-tracking branch 'origin/local' into local
This commit is contained in:
@@ -1,9 +1,13 @@
|
||||
"""Callback Handler that add on_chain_end function to record Token usage."""
|
||||
from typing import Any, Dict
|
||||
|
||||
from langchain.callbacks import OpenAICallbackHandler
|
||||
from langchain_community.callbacks.openai_info import OpenAICallbackHandler
|
||||
from langchain.schema import LLMResult
|
||||
from langchain.callbacks.openai_info import standardize_model_name, MODEL_COST_PER_1K_TOKENS, get_openai_token_cost_for_model
|
||||
from langchain_community.callbacks.openai_info import standardize_model_name, MODEL_COST_PER_1K_TOKENS, \
|
||||
get_openai_token_cost_for_model
|
||||
|
||||
|
||||
# from langchain.callbacks.openai_info import standardize_model_name, MODEL_COST_PER_1K_TOKENS, get_openai_token_cost_for_model
|
||||
|
||||
|
||||
class OpenAITokenRecordCallbackHandler(OpenAICallbackHandler):
|
||||
|
||||
@@ -0,0 +1,28 @@
|
||||
from typing import Dict
|
||||
|
||||
from dashscope.api_entities.dashscope_response import GenerationUsage
|
||||
|
||||
|
||||
class QWenCallbackHandler:
|
||||
|
||||
total_tokens: int = 0
|
||||
input_tokens: int = 0
|
||||
output_tokens: int = 0
|
||||
total_cost: float = 0.0
|
||||
|
||||
def on_llm_end(self, response: GenerationUsage) -> None:
|
||||
"""Collect token usage."""
|
||||
|
||||
|
||||
self.input_tokens += response.input_tokens
|
||||
self.output_tokens += response.output_tokens
|
||||
self.total_tokens = self.input_tokens + self.output_tokens
|
||||
self.total_cost = 0.04 * self.input_tokens / 1000 + 0.12 * self.output_tokens / 1000
|
||||
|
||||
def on_chain_end(self, outputs: Dict ) -> None:
|
||||
"""Write token usage to redis."""
|
||||
outputs["total_tokens"] = self.total_tokens
|
||||
outputs["total_cost"] = self.total_cost
|
||||
outputs["prompt_tokens"] = self.input_tokens
|
||||
outputs["completion_tokens"] = self.output_tokens
|
||||
print("input_tokens : {} \noutput_tokens : {}".format(outputs["prompt_tokens"] , outputs["completion_tokens"]))
|
||||
@@ -3,7 +3,7 @@ import json
|
||||
|
||||
from sqlalchemy import text
|
||||
# from langchain import SQLDatabase
|
||||
from langchain.utilities import SQLDatabase
|
||||
from langchain_community.utilities import SQLDatabase
|
||||
|
||||
|
||||
class CustomDatabase(SQLDatabase):
|
||||
|
||||
@@ -1,23 +1,23 @@
|
||||
import json
|
||||
import logging
|
||||
|
||||
from langchain_community.chat_models import ChatTongyi
|
||||
from loguru import logger
|
||||
from langchain.agents import Tool
|
||||
from langchain.callbacks import FileCallbackHandler
|
||||
from langchain.chat_models import ChatOpenAI
|
||||
from langchain.llms.openai import OpenAI
|
||||
from langchain.utilities import SerpAPIWrapper
|
||||
from langchain.prompts.chat import ChatPromptTemplate, HumanMessagePromptTemplate, MessagesPlaceholder
|
||||
from langchain.schema import SystemMessage, AIMessage
|
||||
from langchain.utilities import SerpAPIWrapper
|
||||
from loguru import logger
|
||||
|
||||
from app.core.config import *
|
||||
from langchain.chat_models import ChatOpenAI
|
||||
from langchain.llms.openai import OpenAI
|
||||
from langchain.callbacks import FileCallbackHandler
|
||||
from app.service.chat_robot.script.agents import CustomAgentExecutor, ConversationalFunctionsAgent
|
||||
from app.service.chat_robot.script.callbacks import OpenAITokenRecordCallbackHandler
|
||||
from app.service.chat_robot.script.database import CustomDatabase
|
||||
from app.service.chat_robot.script.memory import UserConversationBufferWindowMemory
|
||||
from app.service.chat_robot.script.prompt import FASHION_CHAT_BOT_PREFIX, TOOLS_FUNCTIONS_SUFFIX
|
||||
from app.service.chat_robot.script.service import CallQWen
|
||||
from app.service.chat_robot.script.tools import (QuerySQLDataBaseTool, InfoSQLDatabaseTool, QuerySQLCheckerTool, ListSQLDatabaseTool)
|
||||
from app.service.chat_robot.script.tools.tutorial_tool import CustomTutorialTool
|
||||
from app.core.config import *
|
||||
|
||||
# os.environ["http_proxy"] = "http://127.0.0.1:7890"
|
||||
# os.environ["https_proxy"] = "http://127.0.0.1:7890"
|
||||
@@ -27,10 +27,12 @@ logger.add(logfile, colorize=True, enqueue=True)
|
||||
log_handler = FileCallbackHandler(logfile)
|
||||
|
||||
# Initiate our LLM 'gpt-3.5-turbo'
|
||||
llm = ChatOpenAI(temperature=0.1,
|
||||
openai_api_key=OPENAI_API_KEY,
|
||||
# callbacks=[OpenAICallbackHandler()]
|
||||
)
|
||||
# llm = ChatOpenAI(temperature=0.1,
|
||||
# openai_api_key=OPENAI_API_KEY,
|
||||
# # callbacks=[OpenAICallbackHandler()]
|
||||
# )
|
||||
|
||||
llm = ChatTongyi(api_key="sk-7658298c6b99443c98184a5e634fe6ab")
|
||||
|
||||
search = SerpAPIWrapper()
|
||||
db = CustomDatabase.from_uri(f'mysql+pymysql://{DB_USERNAME}:{DB_PASSWORD}@{DB_HOST}:{DB_PORT}/attribute_retrieval_V3',
|
||||
@@ -46,14 +48,15 @@ tools = [
|
||||
QuerySQLDataBaseTool(db=db, return_direct=False),
|
||||
InfoSQLDatabaseTool(db=db),
|
||||
ListSQLDatabaseTool(db=db),
|
||||
QuerySQLCheckerTool(db=db, llm=OpenAI(temperature=0, openai_api_key=OPENAI_API_KEY)),
|
||||
Tool(
|
||||
name="tutorial_tool",
|
||||
description="Utilize this tool to retrieve specific statements related to user guidance tutorials."
|
||||
"Input is an empty string",
|
||||
func=CustomTutorialTool(),
|
||||
return_direct=True
|
||||
)
|
||||
# QuerySQLCheckerTool(db=db, llm=OpenAI(temperature=0, openai_api_key=OPENAI_API_KEY)),
|
||||
QuerySQLCheckerTool(db=db, llm = ChatTongyi(temperature=0, api_key="sk-7658298c6b99443c98184a5e634fe6ab")),
|
||||
# Tool(
|
||||
# name="tutorial_tool",
|
||||
# description="Utilize this tool to retrieve specific statements related to user guidance tutorials."
|
||||
# "Input is an empty string",
|
||||
# func=CustomTutorialTool(),
|
||||
# return_direct=True
|
||||
# )
|
||||
]
|
||||
|
||||
messages = [
|
||||
@@ -91,25 +94,47 @@ def chat(post_data):
|
||||
input_message = post_data.message
|
||||
gender = post_data.gender
|
||||
|
||||
final_outputs = agent_executor(
|
||||
{"input": input_message, "gender": gender},
|
||||
callbacks=[OpenAITokenRecordCallbackHandler(), log_handler],
|
||||
session_key=f"buffer:{user_id}:{session_id}",
|
||||
)
|
||||
# final_outputs = agent_executor(
|
||||
# {"input": input_message, "gender": gender},
|
||||
# callbacks=[OpenAITokenRecordCallbackHandler(), log_handler],
|
||||
# session_key=f"buffer:{user_id}:{session_id}",
|
||||
# )
|
||||
|
||||
final_outputs = CallQWen.call_with_messages(input_message)
|
||||
# api_response = {
|
||||
# 'user_id': user_id,
|
||||
# 'session_id': session_id,
|
||||
# # 'message_id': message_id,
|
||||
# # 'create_time': created_time,
|
||||
# 'input': final_outputs['input'],
|
||||
# # 'conversion': messages,
|
||||
# 'output': final_outputs['output'],
|
||||
# # 'gpt_response_time': gpt_response_time,
|
||||
# 'total_tokens': final_outputs['total_tokens'],
|
||||
# 'total_cost': final_outputs['total_cost'],
|
||||
# 'prompt_tokens': final_outputs['prompt_tokens'],
|
||||
# 'completion_tokens': final_outputs['completion_tokens'],
|
||||
# 'response_type': final_outputs['response_type']
|
||||
# }
|
||||
# if final_outputs["output"].startswith("["):
|
||||
# final_str = final_outputs["output"].replace("\\", "")
|
||||
# else:
|
||||
# final_str = final_outputs["output"]
|
||||
|
||||
api_response = {
|
||||
'user_id': user_id,
|
||||
'session_id': session_id,
|
||||
# 'message_id': message_id,
|
||||
# 'create_time': created_time,
|
||||
'input': final_outputs['input'],
|
||||
'input': input_message,
|
||||
# 'conversion': messages,
|
||||
'output': final_outputs['output'],
|
||||
'output': final_outputs["output"],
|
||||
# 'gpt_response_time': gpt_response_time,
|
||||
'total_tokens': final_outputs['total_tokens'],
|
||||
'total_cost': final_outputs['total_cost'],
|
||||
'prompt_tokens': final_outputs['prompt_tokens'],
|
||||
'completion_tokens': final_outputs['completion_tokens'],
|
||||
'response_type': final_outputs['response_type']
|
||||
'response_type': final_outputs["response_type"]
|
||||
}
|
||||
logging.info(json.dumps(api_response))
|
||||
return api_response
|
||||
|
||||
@@ -4,7 +4,6 @@ import json
|
||||
|
||||
import redis
|
||||
from redis import Redis
|
||||
from langchain.memory import RedisChatMessageHistory
|
||||
from langchain.memory.chat_memory import BaseChatMemory
|
||||
from langchain.schema.messages import BaseMessage, get_buffer_string, HumanMessage, AIMessage
|
||||
from langchain.schema.messages import _message_to_dict, messages_from_dict
|
||||
|
||||
235
app/service/chat_robot/script/service/CallQWen.py
Normal file
235
app/service/chat_robot/script/service/CallQWen.py
Normal file
@@ -0,0 +1,235 @@
|
||||
import json
|
||||
from typing import Dict, Any
|
||||
|
||||
from dashscope import Generation
|
||||
|
||||
from app.core.config import *
|
||||
from app.service.chat_robot.script.callbacks.qwen_callback_handler import QWenCallbackHandler
|
||||
from app.service.chat_robot.script.database import CustomDatabase
|
||||
from app.service.chat_robot.script.prompt import FASHION_CHAT_BOT_PREFIX, TOOLS_FUNCTIONS_SUFFIX
|
||||
|
||||
|
||||
get_database_table_description = "Input is an empty string, output is a comma separated list of tables in the database."
|
||||
|
||||
get_table_info_description = (
|
||||
"Input to this tool is a comma-separated list of tables, output is the schema and sample rows for those tables."
|
||||
"There are eight tables covering eight fashion categories: female_top, female_pants, female_dress,"
|
||||
"female_skirt, female_outwear, male_bottom, male_top, and male_outwear."
|
||||
|
||||
"Example Input: 'female_outwear, male_top'"
|
||||
)
|
||||
|
||||
query_database_description = (
|
||||
"The input of this tool is a detailed and correct SQL select query statement, "
|
||||
"and the output is the result of the database, and it can only return up to 4 results."
|
||||
"If the query is not correct, an error message will be returned."
|
||||
"If an error is returned, rewrite the query, check the query, and try again."
|
||||
"If you encounter an issue with Unknown column 'xxxx' in 'field list' or Table 'attribute_retrieval.xxxx' doesn't exist,"
|
||||
"use get_table_info to query the correct table fields."
|
||||
|
||||
"Example Input: 'SELECT img_name FROM female_skirt WHERE opening_type = 'Button' ORDER BY RAND() LIMIT 1'"
|
||||
"Example Input 2: 'SELECT img_name FROM female_top WHERE sleeve_length = 'Long' AND type = 'Blouse' "
|
||||
"order by rand() LIMIT 2'"
|
||||
)
|
||||
|
||||
tools = [
|
||||
# 工具一
|
||||
# {
|
||||
# "type": "function",
|
||||
# "function": {
|
||||
# "name": "search_from_internet",
|
||||
# "description": "从网络搜索结果。",
|
||||
# "parameters": {
|
||||
# "type" : "object",
|
||||
# "properties" : {
|
||||
# "user_input" : {
|
||||
# "type" : "string",
|
||||
# "description" : "用户输入。比如 : 2025年的时尚潮流趋势是什么?"
|
||||
# }
|
||||
# }
|
||||
# }
|
||||
# }
|
||||
# },
|
||||
# 工具二
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_database_table",
|
||||
"description": get_database_table_description,
|
||||
"parameters": {
|
||||
}
|
||||
}
|
||||
},
|
||||
# 工具三
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_table_info",
|
||||
"description": get_table_info_description,
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"table_names": {
|
||||
"type": "list",
|
||||
"description": "需要查询表结构的表名"
|
||||
}
|
||||
}
|
||||
},
|
||||
"required": ["table_names"]
|
||||
}
|
||||
},
|
||||
# 工具四
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "query_database",
|
||||
"description": query_database_description,
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"sql_string": {
|
||||
"type": "string",
|
||||
"description": "由模型生成的sql语句"
|
||||
}
|
||||
}
|
||||
},
|
||||
"required": ["sql_string"]
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
db = CustomDatabase.from_uri(f'mysql+pymysql://{DB_USERNAME}:{DB_PASSWORD}@{DB_HOST}:{DB_PORT}/attribute_retrieval_V3',
|
||||
include_tables=['female_top', 'female_skirt', 'female_pants', 'female_dress',
|
||||
'female_outwear', 'male_bottom', 'male_top', 'male_outwear'],
|
||||
engine_args={"pool_recycle": 7200})
|
||||
qwen = QWenCallbackHandler()
|
||||
|
||||
def search_from_internet(message):
|
||||
response = Generation.call(
|
||||
model='qwen-turbo',
|
||||
api_key='sk-7658298c6b99443c98184a5e634fe6ab',
|
||||
messages=message,
|
||||
tools=tools,
|
||||
# seed=random.randint(1, 10000), # 设置随机数种子seed,如果没有设置,则随机数种子默认为1234
|
||||
result_format='message', # 将输出设置为message形式
|
||||
enable_search='True'
|
||||
)
|
||||
return response
|
||||
|
||||
def get_database_table():
|
||||
return 'female_top, female_skirt, female_pants, female_dress, female_outwear, male_bottom, male_top, male_outwear'
|
||||
|
||||
def get_table_info(table_names):
|
||||
return CustomDatabase.get_table_info(db, table_names)
|
||||
|
||||
def query_database(sql_string):
|
||||
return CustomDatabase.run(db, sql_string)
|
||||
|
||||
|
||||
def get_response(messages):
|
||||
response = Generation.call(
|
||||
model='qwen-max',
|
||||
api_key='sk-7658298c6b99443c98184a5e634fe6ab',
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
# seed=random.randint(1, 10000), # 设置随机数种子seed,如果没有设置,则随机数种子默认为1234
|
||||
result_format='message', # 将输出设置为message形式
|
||||
enable_search='True'
|
||||
)
|
||||
return response
|
||||
|
||||
|
||||
def call_with_messages(message):
|
||||
print('\n')
|
||||
# messages = [
|
||||
# {
|
||||
# "content": input('请输入:'), # 提问示例:"现在几点了?" "一个小时后几点" "北京天气如何?"
|
||||
# "role": "user"
|
||||
# }
|
||||
# ]
|
||||
|
||||
messages = [
|
||||
{
|
||||
"content": FASHION_CHAT_BOT_PREFIX, # 系统message
|
||||
"role": "system"
|
||||
},
|
||||
{
|
||||
# "content": input('请输入:'), # 用户message
|
||||
"content": message, # 用户message
|
||||
"role": "user"
|
||||
},
|
||||
{
|
||||
"content": TOOLS_FUNCTIONS_SUFFIX, # ai message
|
||||
"role": "assistant"
|
||||
}
|
||||
]
|
||||
|
||||
# 模型的第一轮调用
|
||||
# first_response = get_response(messages)
|
||||
# assistant_output = first_response.output.choices[0].message
|
||||
# print(f"\n大模型第一轮输出信息:{first_response}\n")
|
||||
# messages.append(assistant_output)
|
||||
flag = True
|
||||
count = 1
|
||||
result_content = "我是一个时尚AI助手,请问有什么可以帮您"
|
||||
response_type = "chat"
|
||||
|
||||
while flag and count <= 3:
|
||||
first_response = get_response(messages)
|
||||
assistant_output = first_response.output.choices[0].message
|
||||
QWenCallbackHandler.on_llm_end(qwen, first_response.usage)
|
||||
print(f"\n大模型第 {count} 轮输出信息:{first_response}\n")
|
||||
messages.append(assistant_output)
|
||||
|
||||
if 'tool_calls' not in assistant_output: # 如果模型判断无需调用工具,则将assistant的回复直接打印出来,无需进行模型的第二轮调用
|
||||
print(f"最终答案:{assistant_output.content}") # 此处直接返回模型的回复,您可以根据您的业务,选择当无需调用工具时最终回复的内容
|
||||
result_content = assistant_output.content
|
||||
break
|
||||
# 如果模型选择的工具是search_from_internet
|
||||
# elif assistant_output.tool_calls[0]['function']['name'] == 'search_from_internet':
|
||||
# tool_info = {"name": "search_from_internet", "role": "tool"}
|
||||
# user_input = json.loads(assistant_output.tool_calls[0]['function']['arguments'])['user_input']
|
||||
# tool_info['content'] = search_from_internet(user_input)
|
||||
# 如果模型选择的工具是get_database_table
|
||||
elif assistant_output.tool_calls[0]['function']['name'] == 'get_database_table':
|
||||
tool_info = {"name": "get_database_table", "role": "tool", 'content': get_database_table()}
|
||||
# 如果模型选择的工具是get_table_info
|
||||
elif assistant_output.tool_calls[0]['function']['name'] == 'get_table_info':
|
||||
tool_info = {"name": "get_table_info", "role": "tool"}
|
||||
table_names = json.loads(assistant_output.tool_calls[0]['function']['arguments'])['table_names']
|
||||
tool_info['content'] = get_table_info(table_names)
|
||||
# 如果模型选择的工具是query_database
|
||||
elif assistant_output.tool_calls[0]['function']['name'] == 'query_database':
|
||||
tool_info = {"name": "query_database", "role": "tool"}
|
||||
sql_string = json.loads(assistant_output.tool_calls[0]['function']['arguments'])['sql_string']
|
||||
tool_info['content'] = query_database(sql_string)
|
||||
flag = False
|
||||
result_content = tool_info['content']
|
||||
response_type = "image"
|
||||
|
||||
print(f"工具输出信息:{tool_info['content']}\n")
|
||||
messages.append(tool_info)
|
||||
count += 1
|
||||
|
||||
final_output = {"output": result_content}
|
||||
final_output["response_type"] = response_type
|
||||
QWenCallbackHandler.on_chain_end(qwen, final_output)
|
||||
|
||||
|
||||
|
||||
# 模型的第二轮调用,对工具的输出进行总结
|
||||
# if flag :
|
||||
# second_response = get_response(messages)
|
||||
# print(f"大模型第二轮输出信息:{second_response}\n")
|
||||
# print(f"最终答案:{second_response.output.choices[0].message['content']}")
|
||||
# result_content = second_response.output.choices[0].message['content']
|
||||
|
||||
return final_output
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
call_with_messages()
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# flake8: noqa
|
||||
"""Tools for interacting with a SQL database."""
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any, Dict, Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Extra, Field, root_validator
|
||||
|
||||
@@ -12,9 +12,11 @@ from langchain.callbacks.manager import (
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.prompts import PromptTemplate
|
||||
# from langchain.sql_database import SQLDatabase
|
||||
from langchain.utilities import SQLDatabase
|
||||
from langchain_community.utilities import SQLDatabase
|
||||
from langchain.tools.base import BaseTool
|
||||
from langchain.tools.sql_database.prompt import QUERY_CHECKER
|
||||
from langchain_community.tools.sql_database.prompt import QUERY_CHECKER
|
||||
|
||||
from langchain_community.tools.sql_database.tool import QuerySQLCheckerTool, _QuerySQLCheckerToolInput
|
||||
|
||||
|
||||
class BaseSQLDatabaseTool(BaseModel):
|
||||
@@ -135,32 +137,70 @@ class ListSQLDatabaseTool(BaseSQLDatabaseTool, BaseTool):
|
||||
raise NotImplementedError("ListTablesSqlDbTool does not support async")
|
||||
|
||||
|
||||
# class QuerySQLCheckerTool(BaseSQLDatabaseTool, BaseTool):
|
||||
# """Use an LLM to check if a query is correct.
|
||||
# Adapted from https://www.patterns.app/blog/2023/01/18/crunchbot-sql-analyst-gpt/"""
|
||||
#
|
||||
# template: str = QUERY_CHECKER
|
||||
# llm: BaseLanguageModel
|
||||
# llm_chain: LLMChain = Field(init=False)
|
||||
# name = "sql_db_query_checker"
|
||||
# description = (
|
||||
# "Use this tools to double check if your query is correct before executing it."
|
||||
# "Always use this tools before executing a query with sql_db_query!"
|
||||
# )
|
||||
#
|
||||
# @root_validator(pre=True)
|
||||
# def initialize_llm_chain(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
# if "llm_chain" not in values:
|
||||
# values["llm_chain"] = LLMChain(
|
||||
# llm=values.get("llm"),
|
||||
# prompt=PromptTemplate(
|
||||
# template=QUERY_CHECKER,
|
||||
# input_variables=["query", "dialect"]
|
||||
# ),
|
||||
# )
|
||||
#
|
||||
# if values["llm_chain"].prompt.input_variables != ["dialect", "query"]:
|
||||
# # if values["llm_chain"].prompt.input_variables != ["query", "dialect"]:
|
||||
# raise ValueError(
|
||||
# "LLM chain for QueryCheckerTool must have input variables ['query', 'dialect']"
|
||||
# )
|
||||
#
|
||||
# return values
|
||||
class QuerySQLCheckerTool(BaseSQLDatabaseTool, BaseTool):
|
||||
"""Use an LLM to check if a query is correct.
|
||||
Adapted from https://www.patterns.app/blog/2023/01/18/crunchbot-sql-analyst-gpt/"""
|
||||
|
||||
template: str = QUERY_CHECKER
|
||||
llm: BaseLanguageModel
|
||||
llm_chain: LLMChain = Field(init=False)
|
||||
name = "sql_db_query_checker"
|
||||
description = (
|
||||
"Use this tools to double check if your query is correct before executing it."
|
||||
"Always use this tools before executing a query with sql_db_query!"
|
||||
)
|
||||
llm_chain: Any = Field(init=False)
|
||||
name: str = "sql_db_query_checker"
|
||||
description: str = """
|
||||
Use this tool to double check if your query is correct before executing it.
|
||||
Always use this tool before executing a query with sql_db_query!
|
||||
"""
|
||||
args_schema: Type[BaseModel] = _QuerySQLCheckerToolInput
|
||||
|
||||
@root_validator(pre=True)
|
||||
def initialize_llm_chain(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
if "llm_chain" not in values:
|
||||
values["llm_chain"] = LLMChain(
|
||||
llm=values.get("llm"),
|
||||
prompt=PromptTemplate(
|
||||
template=QUERY_CHECKER,
|
||||
input_variables=["query", "dialect"]
|
||||
),
|
||||
)
|
||||
# from langchain.chains.llm import LLMChain
|
||||
|
||||
if values["llm_chain"].prompt.input_variables != ["dialect", "query"]:
|
||||
# if values["llm_chain"].prompt.input_variables != ["query", "dialect"]:
|
||||
llm = values.get("llm") # type: ignore[arg-type]
|
||||
prompt = PromptTemplate(
|
||||
template=QUERY_CHECKER, input_variables=["dialect", "query"]
|
||||
)
|
||||
values["llm_chain"] = prompt | llm
|
||||
# values["llm_chain"] = LLMChain(
|
||||
# llm=values.get("llm"), # type: ignore[arg-type]
|
||||
# prompt=PromptTemplate(
|
||||
# template=QUERY_CHECKER, input_variables=["dialect", "query"]
|
||||
# ),
|
||||
# )
|
||||
|
||||
# if values["llm_chain"].prompt.input_variables != ["dialect", "query"]:
|
||||
if values["llm_chain"].first.input_variables != ["dialect", "query"]:
|
||||
raise ValueError(
|
||||
"LLM chain for QueryCheckerTool must have input variables ['query', 'dialect']"
|
||||
)
|
||||
|
||||
@@ -1,8 +1,11 @@
|
||||
import logging
|
||||
|
||||
from langchain.chains import LLMChain
|
||||
from langchain.chat_models import ChatOpenAI
|
||||
from dashscope import Generation
|
||||
# from langchain.chains import LLMChain
|
||||
from langchain_community.chat_models import QianfanChatEndpoint, ChatTongyi
|
||||
# from langchain.chat_models import ChatOpenAI
|
||||
from langchain_core.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate
|
||||
from langchain_core.runnables import RunnableSequence
|
||||
|
||||
from app.core.config import OPENAI_MODEL, OPENAI_API_KEY
|
||||
|
||||
@@ -10,9 +13,9 @@ from app.core.config import OPENAI_MODEL, OPENAI_API_KEY
|
||||
# os.environ["https_proxy"] = "http://127.0.0.1:7890"
|
||||
|
||||
|
||||
llm = ChatOpenAI(model_name=OPENAI_MODEL,
|
||||
openai_api_key=OPENAI_API_KEY,
|
||||
temperature=0)
|
||||
# llm = ChatOpenAI(model_name=OPENAI_MODEL,
|
||||
# openai_api_key=OPENAI_API_KEY,
|
||||
# temperature=0)
|
||||
|
||||
|
||||
def translate_to_en(text):
|
||||
@@ -24,48 +27,34 @@ def translate_to_en(text):
|
||||
output the input text exactly as it is without any modifications or additions.
|
||||
If there are grammatical errors, correct them and then output the sentence."""
|
||||
)
|
||||
system_message_prompt = SystemMessagePromptTemplate.from_template(template)
|
||||
messages = [
|
||||
{
|
||||
"content": template, # 系统message
|
||||
"role": "system"
|
||||
},
|
||||
{
|
||||
# "content": input('请输入:'), # 用户message
|
||||
"content": text, # 用户message
|
||||
"role": "user"
|
||||
}
|
||||
]
|
||||
first_response = get_response(messages)
|
||||
assistant_output = first_response.output.choices[0].message
|
||||
print("translate result : {}".format(assistant_output))
|
||||
return assistant_output.content
|
||||
|
||||
# 待翻译文本由 Human 角色输入
|
||||
human_template = "User input : {text}"
|
||||
human_message_prompt = HumanMessagePromptTemplate.from_template(input_variables=["text"], template=human_template)
|
||||
|
||||
# 使用 System 和 Human 角色的提示模板构造 ChatPromptTemplate
|
||||
chat_prompt_template = ChatPromptTemplate.from_messages(
|
||||
[system_message_prompt, human_message_prompt]
|
||||
|
||||
def get_response(messages):
|
||||
response = Generation.call(
|
||||
model='qwen-max',
|
||||
api_key='sk-7658298c6b99443c98184a5e634fe6ab',
|
||||
messages=messages,
|
||||
# seed=random.randint(1, 10000), # 设置随机数种子seed,如果没有设置,则随机数种子默认为1234
|
||||
result_format='message', # 将输出设置为message形式
|
||||
enable_search='True'
|
||||
)
|
||||
translate_chain = LLMChain(llm=llm, prompt=chat_prompt_template)
|
||||
|
||||
result = translate_chain.invoke(text)
|
||||
|
||||
logging.info("translate result : " + result.get('text'))
|
||||
# print("translate result : " + result.get('text'))
|
||||
return result.get('text')
|
||||
|
||||
# template = (
|
||||
# """
|
||||
# Input sentence:
|
||||
# {translate}
|
||||
# 1. Based on the input,adjust the input sentence to make it more suitable for prompts for generating images,
|
||||
# ensuring all key nouns or adjectives related to the image are retained.
|
||||
# 2. Simplify complex sentence structures and clarify ambiguous expressions.
|
||||
# 3. Only Output the adjusted English sentence.
|
||||
#
|
||||
# Output :
|
||||
# """
|
||||
# )
|
||||
# # "Based on the input sentence, extract key adjectives and nouns.Only Output extracted key words."
|
||||
# # 1. Check if the input sentence contains any grammatical errors. If there are errors, please correct them before proceeding.
|
||||
#
|
||||
# prompt_template = PromptTemplate(input_variables=["translate"], template=template)
|
||||
# prompt_chain = LLMChain(llm=llm, prompt=prompt_template)
|
||||
#
|
||||
# from langchain.chains import SimpleSequentialChain
|
||||
# overall_chain = SimpleSequentialChain(chains=[translate_chain, prompt_chain], verbose=True)
|
||||
#
|
||||
# response = overall_chain.run(text)
|
||||
# return response
|
||||
|
||||
return response
|
||||
|
||||
def main():
|
||||
"""Main function"""
|
||||
|
||||
BIN
requirements.txt
BIN
requirements.txt
Binary file not shown.
BIN
requirements_2.txt
Normal file
BIN
requirements_2.txt
Normal file
Binary file not shown.
Reference in New Issue
Block a user