Merge remote-tracking branch 'origin/local' into local

This commit is contained in:
zhouchengrong
2024-07-09 10:18:46 +08:00
10 changed files with 414 additions and 94 deletions

View File

@@ -1,9 +1,13 @@
"""Callback Handler that add on_chain_end function to record Token usage.""" """Callback Handler that add on_chain_end function to record Token usage."""
from typing import Any, Dict from typing import Any, Dict
from langchain.callbacks import OpenAICallbackHandler from langchain_community.callbacks.openai_info import OpenAICallbackHandler
from langchain.schema import LLMResult from langchain.schema import LLMResult
from langchain.callbacks.openai_info import standardize_model_name, MODEL_COST_PER_1K_TOKENS, get_openai_token_cost_for_model from langchain_community.callbacks.openai_info import standardize_model_name, MODEL_COST_PER_1K_TOKENS, \
get_openai_token_cost_for_model
# from langchain.callbacks.openai_info import standardize_model_name, MODEL_COST_PER_1K_TOKENS, get_openai_token_cost_for_model
class OpenAITokenRecordCallbackHandler(OpenAICallbackHandler): class OpenAITokenRecordCallbackHandler(OpenAICallbackHandler):

View File

@@ -0,0 +1,28 @@
from typing import Dict
from dashscope.api_entities.dashscope_response import GenerationUsage
class QWenCallbackHandler:
    """Callback handler that accumulates Qwen (DashScope) token usage and cost.

    Mirrors the OpenAI callback-handler contract used elsewhere in the
    project: ``on_llm_end`` accumulates usage after each model call and
    ``on_chain_end`` copies the running totals into the chain's output dict.

    NOTE: the counters are never reset, so a single handler instance
    accumulates usage across every call it observes.
    """

    # Pricing per 1K tokens for qwen-max (was hard-coded inline as
    # 0.04 / 0.12); class attributes so deployments can override them
    # without touching the accumulation logic.
    INPUT_COST_PER_1K: float = 0.04
    OUTPUT_COST_PER_1K: float = 0.12

    # Running totals (class-level defaults; `+=` creates per-instance copies).
    total_tokens: int = 0
    input_tokens: int = 0
    output_tokens: int = 0
    total_cost: float = 0.0

    def on_llm_end(self, response) -> None:
        """Collect token usage from a DashScope ``GenerationUsage``-like object.

        ``response`` only needs ``input_tokens`` and ``output_tokens``
        integer attributes.
        """
        self.input_tokens += response.input_tokens
        self.output_tokens += response.output_tokens
        self.total_tokens = self.input_tokens + self.output_tokens
        # Same arithmetic as the original inline expression:
        # 0.04 * in / 1000 + 0.12 * out / 1000
        self.total_cost = (self.INPUT_COST_PER_1K * self.input_tokens
                           + self.OUTPUT_COST_PER_1K * self.output_tokens) / 1000

    def on_chain_end(self, outputs: Dict) -> None:
        """Write the accumulated token usage into ``outputs`` (mutated in place)."""
        outputs["total_tokens"] = self.total_tokens
        outputs["total_cost"] = self.total_cost
        outputs["prompt_tokens"] = self.input_tokens
        outputs["completion_tokens"] = self.output_tokens
        print("input_tokens : {} \noutput_tokens : {}".format(
            outputs["prompt_tokens"], outputs["completion_tokens"]))

View File

@@ -3,7 +3,7 @@ import json
from sqlalchemy import text from sqlalchemy import text
# from langchain import SQLDatabase # from langchain import SQLDatabase
from langchain.utilities import SQLDatabase from langchain_community.utilities import SQLDatabase
class CustomDatabase(SQLDatabase): class CustomDatabase(SQLDatabase):

View File

@@ -1,23 +1,23 @@
import json import json
import logging import logging
from langchain_community.chat_models import ChatTongyi
from loguru import logger
from langchain.agents import Tool from langchain.agents import Tool
from langchain.callbacks import FileCallbackHandler from langchain.callbacks import FileCallbackHandler
from langchain.chat_models import ChatOpenAI from langchain.utilities import SerpAPIWrapper
from langchain.llms.openai import OpenAI
from langchain.prompts.chat import ChatPromptTemplate, HumanMessagePromptTemplate, MessagesPlaceholder from langchain.prompts.chat import ChatPromptTemplate, HumanMessagePromptTemplate, MessagesPlaceholder
from langchain.schema import SystemMessage, AIMessage from langchain.schema import SystemMessage, AIMessage
from langchain.utilities import SerpAPIWrapper from langchain.chat_models import ChatOpenAI
from loguru import logger from langchain.llms.openai import OpenAI
from langchain.callbacks import FileCallbackHandler
from app.core.config import *
from app.service.chat_robot.script.agents import CustomAgentExecutor, ConversationalFunctionsAgent from app.service.chat_robot.script.agents import CustomAgentExecutor, ConversationalFunctionsAgent
from app.service.chat_robot.script.callbacks import OpenAITokenRecordCallbackHandler from app.service.chat_robot.script.callbacks import OpenAITokenRecordCallbackHandler
from app.service.chat_robot.script.database import CustomDatabase from app.service.chat_robot.script.database import CustomDatabase
from app.service.chat_robot.script.memory import UserConversationBufferWindowMemory from app.service.chat_robot.script.memory import UserConversationBufferWindowMemory
from app.service.chat_robot.script.prompt import FASHION_CHAT_BOT_PREFIX, TOOLS_FUNCTIONS_SUFFIX from app.service.chat_robot.script.prompt import FASHION_CHAT_BOT_PREFIX, TOOLS_FUNCTIONS_SUFFIX
from app.service.chat_robot.script.service import CallQWen
from app.service.chat_robot.script.tools import (QuerySQLDataBaseTool, InfoSQLDatabaseTool, QuerySQLCheckerTool, ListSQLDatabaseTool) from app.service.chat_robot.script.tools import (QuerySQLDataBaseTool, InfoSQLDatabaseTool, QuerySQLCheckerTool, ListSQLDatabaseTool)
from app.service.chat_robot.script.tools.tutorial_tool import CustomTutorialTool from app.core.config import *
# os.environ["http_proxy"] = "http://127.0.0.1:7890" # os.environ["http_proxy"] = "http://127.0.0.1:7890"
# os.environ["https_proxy"] = "http://127.0.0.1:7890" # os.environ["https_proxy"] = "http://127.0.0.1:7890"
@@ -27,10 +27,12 @@ logger.add(logfile, colorize=True, enqueue=True)
log_handler = FileCallbackHandler(logfile) log_handler = FileCallbackHandler(logfile)
# Initiate our LLM 'gpt-3.5-turbo' # Initiate our LLM 'gpt-3.5-turbo'
llm = ChatOpenAI(temperature=0.1, # llm = ChatOpenAI(temperature=0.1,
openai_api_key=OPENAI_API_KEY, # openai_api_key=OPENAI_API_KEY,
# callbacks=[OpenAICallbackHandler()] # # callbacks=[OpenAICallbackHandler()]
) # )
llm = ChatTongyi(api_key="sk-7658298c6b99443c98184a5e634fe6ab")
search = SerpAPIWrapper() search = SerpAPIWrapper()
db = CustomDatabase.from_uri(f'mysql+pymysql://{DB_USERNAME}:{DB_PASSWORD}@{DB_HOST}:{DB_PORT}/attribute_retrieval_V3', db = CustomDatabase.from_uri(f'mysql+pymysql://{DB_USERNAME}:{DB_PASSWORD}@{DB_HOST}:{DB_PORT}/attribute_retrieval_V3',
@@ -46,14 +48,15 @@ tools = [
QuerySQLDataBaseTool(db=db, return_direct=False), QuerySQLDataBaseTool(db=db, return_direct=False),
InfoSQLDatabaseTool(db=db), InfoSQLDatabaseTool(db=db),
ListSQLDatabaseTool(db=db), ListSQLDatabaseTool(db=db),
QuerySQLCheckerTool(db=db, llm=OpenAI(temperature=0, openai_api_key=OPENAI_API_KEY)), # QuerySQLCheckerTool(db=db, llm=OpenAI(temperature=0, openai_api_key=OPENAI_API_KEY)),
Tool( QuerySQLCheckerTool(db=db, llm = ChatTongyi(temperature=0, api_key="sk-7658298c6b99443c98184a5e634fe6ab")),
name="tutorial_tool", # Tool(
description="Utilize this tool to retrieve specific statements related to user guidance tutorials." # name="tutorial_tool",
"Input is an empty string", # description="Utilize this tool to retrieve specific statements related to user guidance tutorials."
func=CustomTutorialTool(), # "Input is an empty string",
return_direct=True # func=CustomTutorialTool(),
) # return_direct=True
# )
] ]
messages = [ messages = [
@@ -91,25 +94,47 @@ def chat(post_data):
input_message = post_data.message input_message = post_data.message
gender = post_data.gender gender = post_data.gender
final_outputs = agent_executor( # final_outputs = agent_executor(
{"input": input_message, "gender": gender}, # {"input": input_message, "gender": gender},
callbacks=[OpenAITokenRecordCallbackHandler(), log_handler], # callbacks=[OpenAITokenRecordCallbackHandler(), log_handler],
session_key=f"buffer:{user_id}:{session_id}", # session_key=f"buffer:{user_id}:{session_id}",
) # )
final_outputs = CallQWen.call_with_messages(input_message)
# api_response = {
# 'user_id': user_id,
# 'session_id': session_id,
# # 'message_id': message_id,
# # 'create_time': created_time,
# 'input': final_outputs['input'],
# # 'conversion': messages,
# 'output': final_outputs['output'],
# # 'gpt_response_time': gpt_response_time,
# 'total_tokens': final_outputs['total_tokens'],
# 'total_cost': final_outputs['total_cost'],
# 'prompt_tokens': final_outputs['prompt_tokens'],
# 'completion_tokens': final_outputs['completion_tokens'],
# 'response_type': final_outputs['response_type']
# }
# if final_outputs["output"].startswith("["):
# final_str = final_outputs["output"].replace("\\", "")
# else:
# final_str = final_outputs["output"]
api_response = { api_response = {
'user_id': user_id, 'user_id': user_id,
'session_id': session_id, 'session_id': session_id,
# 'message_id': message_id, # 'message_id': message_id,
# 'create_time': created_time, # 'create_time': created_time,
'input': final_outputs['input'], 'input': input_message,
# 'conversion': messages, # 'conversion': messages,
'output': final_outputs['output'], 'output': final_outputs["output"],
# 'gpt_response_time': gpt_response_time, # 'gpt_response_time': gpt_response_time,
'total_tokens': final_outputs['total_tokens'], 'total_tokens': final_outputs['total_tokens'],
'total_cost': final_outputs['total_cost'], 'total_cost': final_outputs['total_cost'],
'prompt_tokens': final_outputs['prompt_tokens'], 'prompt_tokens': final_outputs['prompt_tokens'],
'completion_tokens': final_outputs['completion_tokens'], 'completion_tokens': final_outputs['completion_tokens'],
'response_type': final_outputs['response_type'] 'response_type': final_outputs["response_type"]
} }
logging.info(json.dumps(api_response)) logging.info(json.dumps(api_response))
return api_response return api_response

View File

@@ -4,7 +4,6 @@ import json
import redis import redis
from redis import Redis from redis import Redis
from langchain.memory import RedisChatMessageHistory
from langchain.memory.chat_memory import BaseChatMemory from langchain.memory.chat_memory import BaseChatMemory
from langchain.schema.messages import BaseMessage, get_buffer_string, HumanMessage, AIMessage from langchain.schema.messages import BaseMessage, get_buffer_string, HumanMessage, AIMessage
from langchain.schema.messages import _message_to_dict, messages_from_dict from langchain.schema.messages import _message_to_dict, messages_from_dict

View File

@@ -0,0 +1,235 @@
import json
from typing import Dict, Any
from dashscope import Generation
from app.core.config import *
from app.service.chat_robot.script.callbacks.qwen_callback_handler import QWenCallbackHandler
from app.service.chat_robot.script.database import CustomDatabase
from app.service.chat_robot.script.prompt import FASHION_CHAT_BOT_PREFIX, TOOLS_FUNCTIONS_SUFFIX
# Natural-language tool descriptions injected into the function schemas below.
get_database_table_description = "Input is an empty string, output is a comma separated list of tables in the database."
get_table_info_description = (
    "Input to this tool is a comma-separated list of tables, output is the schema and sample rows for those tables."
    "There are eight tables covering eight fashion categories: female_top, female_pants, female_dress,"
    "female_skirt, female_outwear, male_bottom, male_top, and male_outwear."
    "Example Input: 'female_outwear, male_top'"
)
query_database_description = (
    "The input of this tool is a detailed and correct SQL select query statement, "
    "and the output is the result of the database, and it can only return up to 4 results."
    "If the query is not correct, an error message will be returned."
    "If an error is returned, rewrite the query, check the query, and try again."
    "If you encounter an issue with Unknown column 'xxxx' in 'field list' or Table 'attribute_retrieval.xxxx' doesn't exist,"
    "use get_table_info to query the correct table fields."
    "Example Input: 'SELECT img_name FROM female_skirt WHERE opening_type = 'Button' ORDER BY RAND() LIMIT 1'"
    "Example Input 2: 'SELECT img_name FROM female_top WHERE sleeve_length = 'Long' AND type = 'Blouse' "
    "order by rand() LIMIT 2'"
)

# DashScope function-calling schemas (OpenAI-style "function" tools) passed
# to every Generation.call in this module.
tools = [
    # Tool 1 (disabled): internet search
    # {
    #     "type": "function",
    #     "function": {
    #         "name": "search_from_internet",
    #         "description": "从网络搜索结果。",
    #         "parameters": {
    #             "type" : "object",
    #             "properties" : {
    #                 "user_input" : {
    #                     "type" : "string",
    #                     "description" : "用户输入。比如 2025年的时尚潮流趋势是什么"
    #                 }
    #             }
    #         }
    #     }
    # },
    # Tool 2: list the queryable tables (takes no arguments)
    {
        "type": "function",
        "function": {
            "name": "get_database_table",
            "description": get_database_table_description,
            "parameters": {
            }
        }
    },
    # Tool 3: fetch schema and sample rows for named tables
    {
        "type": "function",
        "function": {
            "name": "get_table_info",
            "description": get_table_info_description,
            "parameters": {
                "type": "object",
                "properties": {
                    "table_names": {
                        # NOTE(review): "list" is not a standard JSON-Schema
                        # type ("array" is) — confirm DashScope accepts it.
                        "type": "list",
                        "description": "需要查询表结构的表名"
                    }
                }
            },
            # NOTE(review): "required" sits outside the "parameters" object
            # here; JSON-Schema convention puts it inside — verify intent.
            "required": ["table_names"]
        }
    },
    # Tool 4: execute a model-generated SQL SELECT
    {
        "type": "function",
        "function": {
            "name": "query_database",
            "description": query_database_description,
            "parameters": {
                "type": "object",
                "properties": {
                    "sql_string": {
                        "type": "string",
                        "description": "由模型生成的sql语句"
                    }
                }
            },
            "required": ["sql_string"]
        }
    }
]
# Shared SQL handle over the eight fashion-attribute tables; pool_recycle
# keeps MySQL connections fresh across long-lived processes.
db = CustomDatabase.from_uri(f'mysql+pymysql://{DB_USERNAME}:{DB_PASSWORD}@{DB_HOST}:{DB_PORT}/attribute_retrieval_V3',
                             include_tables=['female_top', 'female_skirt', 'female_pants', 'female_dress',
                                             'female_outwear', 'male_bottom', 'male_top', 'male_outwear'],
                             engine_args={"pool_recycle": 7200})
# Module-level token-usage tracker.
# NOTE(review): a single handler instance accumulates token counts across
# every request served by this module — confirm cross-request totals are
# intended rather than per-request accounting.
qwen = QWenCallbackHandler()
def search_from_internet(message):
    """Call qwen-turbo with web search enabled and return the raw DashScope response.

    ``message`` is a DashScope-style messages list; the module-level ``tools``
    schemas are forwarded so the model may request tool calls.
    """
    # SECURITY(review): API key is hard-coded here — move it to config/env
    # and rotate the exposed key.
    response = Generation.call(
        model='qwen-turbo',
        api_key='sk-7658298c6b99443c98184a5e634fe6ab',
        messages=message,
        tools=tools,
        # seed=random.randint(1, 10000),  # random seed; defaults to 1234 when unset
        result_format='message',  # return message-formatted output
        # NOTE(review): string 'True' rather than bool — confirm the API
        # expects a string here.
        enable_search='True'
    )
    return response
def get_database_table():
    """Return the fixed, comma-separated list of queryable fashion tables."""
    table_names = (
        'female_top', 'female_skirt', 'female_pants', 'female_dress',
        'female_outwear', 'male_bottom', 'male_top', 'male_outwear',
    )
    return ', '.join(table_names)
def get_table_info(table_names):
    """Return schema and sample rows for ``table_names`` from the shared db.

    Equivalent to the unbound form ``CustomDatabase.get_table_info(db, ...)``:
    ``db`` is a ``CustomDatabase`` instance, so the bound call resolves to the
    same method.
    """
    return db.get_table_info(table_names)
def query_database(sql_string):
    """Execute ``sql_string`` against the shared db and return the raw result.

    Equivalent to the unbound form ``CustomDatabase.run(db, sql_string)``:
    ``db`` is a ``CustomDatabase`` instance, so the bound call resolves to the
    same method.
    """
    return db.run(sql_string)
def get_response(messages):
    """Call qwen-max with the function-calling ``tools`` and return the raw response.

    NOTE(review): near-duplicate of ``search_from_internet`` (only the model
    name differs) — consider consolidating into one parameterized helper.
    """
    # SECURITY(review): hard-coded API key — move to config/env and rotate
    # the exposed key.
    response = Generation.call(
        model='qwen-max',
        api_key='sk-7658298c6b99443c98184a5e634fe6ab',
        messages=messages,
        tools=tools,
        # seed=random.randint(1, 10000),  # random seed; defaults to 1234 when unset
        result_format='message',  # return message-formatted output
        # NOTE(review): string 'True' rather than bool — confirm API contract.
        enable_search='True'
    )
    return response
def call_with_messages(message):
    """Run one chat turn against qwen-max with function-calling tools.

    Performs up to three model rounds. Each round either yields a direct
    assistant answer (loop ends) or a tool call whose result is appended to
    the conversation for the next round. A ``query_database`` call terminates
    the loop and its raw result becomes the output with response_type
    ``"image"``; otherwise the type is ``"chat"``.

    Returns a dict with ``"output"`` and ``"response_type"``; the shared
    module-level ``qwen`` handler then adds the token-usage fields via
    ``on_chain_end``.
    """
    print('\n')
    # messages = [
    #     {
    #         "content": input('请输入:'), # 提问示例:"现在几点了?" "一个小时后几点" "北京天气如何?"
    #         "role": "user"
    #     }
    # ]
    messages = [
        {
            "content": FASHION_CHAT_BOT_PREFIX,  # system message
            "role": "system"
        },
        {
            # "content": input('请输入:'), # 用户message
            "content": message,  # user message
            "role": "user"
        },
        {
            "content": TOOLS_FUNCTIONS_SUFFIX,  # ai message
            "role": "assistant"
        }
    ]
    # First model round (superseded by the loop below):
    # first_response = get_response(messages)
    # assistant_output = first_response.output.choices[0].message
    # print(f"\n大模型第一轮输出信息:{first_response}\n")
    # messages.append(assistant_output)
    flag = True
    count = 1
    # Fallback answer when no round yields a usable reply.
    result_content = "我是一个时尚AI助手请问有什么可以帮您"
    response_type = "chat"
    while flag and count <= 3:
        first_response = get_response(messages)
        assistant_output = first_response.output.choices[0].message
        # Record token usage on the shared handler.
        # NOTE(review): unbound-style call; equivalent to
        # qwen.on_llm_end(first_response.usage). Totals accumulate across
        # calls because `qwen` is module-level.
        QWenCallbackHandler.on_llm_end(qwen, first_response.usage)
        print(f"\n大模型第 {count} 轮输出信息:{first_response}\n")
        messages.append(assistant_output)
        if 'tool_calls' not in assistant_output:  # model requested no tool: its reply is final
            print(f"最终答案:{assistant_output.content}")  # direct model answer
            result_content = assistant_output.content
            break
        # Disabled branch for the search_from_internet tool:
        # elif assistant_output.tool_calls[0]['function']['name'] == 'search_from_internet':
        #     tool_info = {"name": "search_from_internet", "role": "tool"}
        #     user_input = json.loads(assistant_output.tool_calls[0]['function']['arguments'])['user_input']
        #     tool_info['content'] = search_from_internet(user_input)
        # Model chose get_database_table
        elif assistant_output.tool_calls[0]['function']['name'] == 'get_database_table':
            tool_info = {"name": "get_database_table", "role": "tool", 'content': get_database_table()}
        # Model chose get_table_info
        elif assistant_output.tool_calls[0]['function']['name'] == 'get_table_info':
            tool_info = {"name": "get_table_info", "role": "tool"}
            table_names = json.loads(assistant_output.tool_calls[0]['function']['arguments'])['table_names']
            tool_info['content'] = get_table_info(table_names)
        # Model chose query_database: its result ends the conversation.
        elif assistant_output.tool_calls[0]['function']['name'] == 'query_database':
            tool_info = {"name": "query_database", "role": "tool"}
            sql_string = json.loads(assistant_output.tool_calls[0]['function']['arguments'])['sql_string']
            tool_info['content'] = query_database(sql_string)
            flag = False
            result_content = tool_info['content']
            response_type = "image"
        # NOTE(review): if the model names a tool not handled above,
        # `tool_info` is unbound here and this raises NameError — confirm
        # that failure mode is acceptable.
        print(f"工具输出信息:{tool_info['content']}\n")
        messages.append(tool_info)
        count += 1
    final_output = {"output": result_content}
    final_output["response_type"] = response_type
    # Attach accumulated token usage to the payload (also prints the counts).
    QWenCallbackHandler.on_chain_end(qwen, final_output)
    # Second model round summarizing the tool output (disabled):
    # if flag :
    #     second_response = get_response(messages)
    #     print(f"大模型第二轮输出信息:{second_response}\n")
    #     print(f"最终答案:{second_response.output.choices[0].message['content']}")
    #     result_content = second_response.output.choices[0].message['content']
    return final_output
if __name__ == '__main__':
    # Manual smoke test. The original invoked call_with_messages() with no
    # argument, which always raised TypeError because `message` is required;
    # prompt for the user message instead (mirrors the commented-out usage
    # inside call_with_messages).
    call_with_messages(input('请输入:'))

View File

@@ -1,6 +1,6 @@
# flake8: noqa # flake8: noqa
"""Tools for interacting with a SQL database.""" """Tools for interacting with a SQL database."""
from typing import Any, Dict, Optional from typing import Any, Dict, Optional, Type
from pydantic import BaseModel, Extra, Field, root_validator from pydantic import BaseModel, Extra, Field, root_validator
@@ -12,9 +12,11 @@ from langchain.callbacks.manager import (
from langchain.chains.llm import LLMChain from langchain.chains.llm import LLMChain
from langchain.prompts import PromptTemplate from langchain.prompts import PromptTemplate
# from langchain.sql_database import SQLDatabase # from langchain.sql_database import SQLDatabase
from langchain.utilities import SQLDatabase from langchain_community.utilities import SQLDatabase
from langchain.tools.base import BaseTool from langchain.tools.base import BaseTool
from langchain.tools.sql_database.prompt import QUERY_CHECKER from langchain_community.tools.sql_database.prompt import QUERY_CHECKER
from langchain_community.tools.sql_database.tool import QuerySQLCheckerTool, _QuerySQLCheckerToolInput
class BaseSQLDatabaseTool(BaseModel): class BaseSQLDatabaseTool(BaseModel):
@@ -135,32 +137,70 @@ class ListSQLDatabaseTool(BaseSQLDatabaseTool, BaseTool):
raise NotImplementedError("ListTablesSqlDbTool does not support async") raise NotImplementedError("ListTablesSqlDbTool does not support async")
# class QuerySQLCheckerTool(BaseSQLDatabaseTool, BaseTool):
# """Use an LLM to check if a query is correct.
# Adapted from https://www.patterns.app/blog/2023/01/18/crunchbot-sql-analyst-gpt/"""
#
# template: str = QUERY_CHECKER
# llm: BaseLanguageModel
# llm_chain: LLMChain = Field(init=False)
# name = "sql_db_query_checker"
# description = (
# "Use this tools to double check if your query is correct before executing it."
# "Always use this tools before executing a query with sql_db_query!"
# )
#
# @root_validator(pre=True)
# def initialize_llm_chain(cls, values: Dict[str, Any]) -> Dict[str, Any]:
# if "llm_chain" not in values:
# values["llm_chain"] = LLMChain(
# llm=values.get("llm"),
# prompt=PromptTemplate(
# template=QUERY_CHECKER,
# input_variables=["query", "dialect"]
# ),
# )
#
# if values["llm_chain"].prompt.input_variables != ["dialect", "query"]:
# # if values["llm_chain"].prompt.input_variables != ["query", "dialect"]:
# raise ValueError(
# "LLM chain for QueryCheckerTool must have input variables ['query', 'dialect']"
# )
#
# return values
class QuerySQLCheckerTool(BaseSQLDatabaseTool, BaseTool): class QuerySQLCheckerTool(BaseSQLDatabaseTool, BaseTool):
"""Use an LLM to check if a query is correct. """Use an LLM to check if a query is correct.
Adapted from https://www.patterns.app/blog/2023/01/18/crunchbot-sql-analyst-gpt/""" Adapted from https://www.patterns.app/blog/2023/01/18/crunchbot-sql-analyst-gpt/"""
template: str = QUERY_CHECKER template: str = QUERY_CHECKER
llm: BaseLanguageModel llm: BaseLanguageModel
llm_chain: LLMChain = Field(init=False) llm_chain: Any = Field(init=False)
name = "sql_db_query_checker" name: str = "sql_db_query_checker"
description = ( description: str = """
"Use this tools to double check if your query is correct before executing it." Use this tool to double check if your query is correct before executing it.
"Always use this tools before executing a query with sql_db_query!" Always use this tool before executing a query with sql_db_query!
) """
args_schema: Type[BaseModel] = _QuerySQLCheckerToolInput
@root_validator(pre=True) @root_validator(pre=True)
def initialize_llm_chain(cls, values: Dict[str, Any]) -> Dict[str, Any]: def initialize_llm_chain(cls, values: Dict[str, Any]) -> Dict[str, Any]:
if "llm_chain" not in values: if "llm_chain" not in values:
values["llm_chain"] = LLMChain( # from langchain.chains.llm import LLMChain
llm=values.get("llm"),
prompt=PromptTemplate(
template=QUERY_CHECKER,
input_variables=["query", "dialect"]
),
)
if values["llm_chain"].prompt.input_variables != ["dialect", "query"]: llm = values.get("llm") # type: ignore[arg-type]
# if values["llm_chain"].prompt.input_variables != ["query", "dialect"]: prompt = PromptTemplate(
template=QUERY_CHECKER, input_variables=["dialect", "query"]
)
values["llm_chain"] = prompt | llm
# values["llm_chain"] = LLMChain(
# llm=values.get("llm"), # type: ignore[arg-type]
# prompt=PromptTemplate(
# template=QUERY_CHECKER, input_variables=["dialect", "query"]
# ),
# )
# if values["llm_chain"].prompt.input_variables != ["dialect", "query"]:
if values["llm_chain"].first.input_variables != ["dialect", "query"]:
raise ValueError( raise ValueError(
"LLM chain for QueryCheckerTool must have input variables ['query', 'dialect']" "LLM chain for QueryCheckerTool must have input variables ['query', 'dialect']"
) )

View File

@@ -1,8 +1,11 @@
import logging import logging
from langchain.chains import LLMChain from dashscope import Generation
from langchain.chat_models import ChatOpenAI # from langchain.chains import LLMChain
from langchain_community.chat_models import QianfanChatEndpoint, ChatTongyi
# from langchain.chat_models import ChatOpenAI
from langchain_core.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate from langchain_core.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate
from langchain_core.runnables import RunnableSequence
from app.core.config import OPENAI_MODEL, OPENAI_API_KEY from app.core.config import OPENAI_MODEL, OPENAI_API_KEY
@@ -10,9 +13,9 @@ from app.core.config import OPENAI_MODEL, OPENAI_API_KEY
# os.environ["https_proxy"] = "http://127.0.0.1:7890" # os.environ["https_proxy"] = "http://127.0.0.1:7890"
llm = ChatOpenAI(model_name=OPENAI_MODEL, # llm = ChatOpenAI(model_name=OPENAI_MODEL,
openai_api_key=OPENAI_API_KEY, # openai_api_key=OPENAI_API_KEY,
temperature=0) # temperature=0)
def translate_to_en(text): def translate_to_en(text):
@@ -24,48 +27,34 @@ def translate_to_en(text):
output the input text exactly as it is without any modifications or additions. output the input text exactly as it is without any modifications or additions.
If there are grammatical errors, correct them and then output the sentence.""" If there are grammatical errors, correct them and then output the sentence."""
) )
system_message_prompt = SystemMessagePromptTemplate.from_template(template) messages = [
{
"content": template, # 系统message
"role": "system"
},
{
# "content": input('请输入:'), # 用户message
"content": text, # 用户message
"role": "user"
}
]
first_response = get_response(messages)
assistant_output = first_response.output.choices[0].message
print("translate result : {}".format(assistant_output))
return assistant_output.content
# 待翻译文本由 Human 角色输入
human_template = "User input : {text}"
human_message_prompt = HumanMessagePromptTemplate.from_template(input_variables=["text"], template=human_template)
# 使用 System 和 Human 角色的提示模板构造 ChatPromptTemplate
chat_prompt_template = ChatPromptTemplate.from_messages( def get_response(messages):
[system_message_prompt, human_message_prompt] response = Generation.call(
model='qwen-max',
api_key='sk-7658298c6b99443c98184a5e634fe6ab',
messages=messages,
# seed=random.randint(1, 10000), # 设置随机数种子seed如果没有设置则随机数种子默认为1234
result_format='message', # 将输出设置为message形式
enable_search='True'
) )
translate_chain = LLMChain(llm=llm, prompt=chat_prompt_template) return response
result = translate_chain.invoke(text)
logging.info("translate result : " + result.get('text'))
# print("translate result : " + result.get('text'))
return result.get('text')
# template = (
# """
# Input sentence:
# {translate}
# 1. Based on the input,adjust the input sentence to make it more suitable for prompts for generating images,
# ensuring all key nouns or adjectives related to the image are retained.
# 2. Simplify complex sentence structures and clarify ambiguous expressions.
# 3. Only Output the adjusted English sentence.
#
# Output :
# """
# )
# # "Based on the input sentence, extract key adjectives and nouns.Only Output extracted key words."
# # 1. Check if the input sentence contains any grammatical errors. If there are errors, please correct them before proceeding.
#
# prompt_template = PromptTemplate(input_variables=["translate"], template=template)
# prompt_chain = LLMChain(llm=llm, prompt=prompt_template)
#
# from langchain.chains import SimpleSequentialChain
# overall_chain = SimpleSequentialChain(chains=[translate_chain, prompt_chain], verbose=True)
#
# response = overall_chain.run(text)
# return response
def main(): def main():
"""Main function""" """Main function"""

Binary file not shown.

BIN
requirements_2.txt Normal file

Binary file not shown.