111
This commit is contained in:
@@ -0,0 +1,170 @@
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
|
||||
import requests
|
||||
from dashscope import Generation
|
||||
from requests import RequestException
|
||||
from retry import retry
|
||||
|
||||
from app.core.config import settings
|
||||
from app.service.chat_robot.script.prompt import GET_LANGUAGE_PREFIX
|
||||
from app.service.prompt_generation.util import minio_util
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_assistant_response(messages):
|
||||
response = Generation.call(
|
||||
model='qwen-max',
|
||||
api_key=settings.QWEN_API_KEY,
|
||||
messages=messages,
|
||||
# seed=random.randint(1, 10000), # 设置随机数种子seed,如果没有设置,则随机数种子默认为1234
|
||||
result_format='message', # 将输出设置为message形式
|
||||
enable_search='false'
|
||||
)
|
||||
return response
|
||||
|
||||
|
||||
def get_language(message: str) -> str:
|
||||
messages = [
|
||||
{
|
||||
"content": GET_LANGUAGE_PREFIX, # ai message
|
||||
"role": "system"
|
||||
},
|
||||
{
|
||||
"content": "Tree", # 用户message
|
||||
"role": "user"
|
||||
},
|
||||
{
|
||||
"content": "English", # 用户message
|
||||
"role": "assistant"
|
||||
},
|
||||
{
|
||||
"content": "玩具", # 用户message
|
||||
"role": "user"
|
||||
},
|
||||
{
|
||||
"content": "Chinese", # 用户message
|
||||
"role": "assistant"
|
||||
},
|
||||
{
|
||||
"content": message, # 用户message
|
||||
"role": "user"
|
||||
}
|
||||
]
|
||||
|
||||
first_response = get_assistant_response(messages)
|
||||
assistant_output = first_response.output.choices[0].message.content
|
||||
logging.info(f"大模型输出信息:{first_response}\n判断用户输入的语言为:{assistant_output}")
|
||||
# print(f"大模型输出信息:{first_response}\n判断用户输入的语言为:{assistant_output}")
|
||||
return assistant_output
|
||||
|
||||
|
||||
@retry(exceptions=RequestException, tries=3, delay=1)
|
||||
def get_response(messages):
|
||||
response = Generation.call(
|
||||
model='qwen-turbo',
|
||||
api_key=settings.QWEN_API_KEY,
|
||||
messages=messages,
|
||||
# seed=random.randint(1, 10000), # 设置随机数种子seed,如果没有设置,则随机数种子默认为1234
|
||||
result_format='message', # 将输出设置为message形式
|
||||
enable_search='True'
|
||||
)
|
||||
return response
|
||||
|
||||
|
||||
def get_translation_from_llama3(text):
|
||||
start_time = time.time()
|
||||
url = f"http://{settings.A6000_SERVICE_HOST}:12434/api/generate"
|
||||
# 先获取用户输入文本的语言
|
||||
language = get_language(text)
|
||||
|
||||
if 'English' in language:
|
||||
return text
|
||||
|
||||
# 创建请求的负载 translator是自定义的翻译模型
|
||||
payload = {
|
||||
"model": "AiDA-translator:latest",
|
||||
"prompt": f"[{text}]",
|
||||
"stream": False
|
||||
}
|
||||
# 将负载转换为 JSON 格式
|
||||
headers = {'Content-Type': 'application/json'}
|
||||
response = requests.post(url, data=json.dumps(payload), headers=headers)
|
||||
# 处理响应
|
||||
if response.status_code == 200:
|
||||
# print("Response from server:")
|
||||
# print(response.json())
|
||||
resp = json.loads(response.content).get("response")
|
||||
logger.info(f"translation server runtime is {time.time() - start_time} , response is {resp}")
|
||||
print("input : {}, translate result : {}".format(text, resp))
|
||||
return resp
|
||||
else:
|
||||
logger.info(f"translation server runtime is {time.time() - start_time} , response is {response.content}")
|
||||
print(f"Request failed with status code {response.status_code}")
|
||||
print(response.text)
|
||||
return ""
|
||||
|
||||
|
||||
# 在llama3中创建一个翻译模型
|
||||
# def create_model_with_llama(text):
|
||||
# url = "http://localhost:11434/api/create"
|
||||
# # url = "http://20.1.1.43:1143/api/generate"
|
||||
#
|
||||
# # prompt = f"System: {prefix_for_llama}\nUser:[{text}]"
|
||||
#
|
||||
# # 创建翻译器的配置文件
|
||||
# payload = {
|
||||
# "model": "translator",
|
||||
# "modelfile": "FROM llama3\nSYSTEM Translate everything within the brackets [] into English."
|
||||
# "Never translate or modify any English input."
|
||||
# "The input must be fully translated into coherent English sentences."
|
||||
# }
|
||||
#
|
||||
# # 将负载转换为 JSON 格式
|
||||
# headers = {'Content-Type': 'application/json'}
|
||||
# response = requests.post(url, data=json.dumps(payload), headers=headers)
|
||||
|
||||
|
||||
def get_prompt_from_image(image_path, text):
|
||||
start_time = time.time()
|
||||
# url = "http://localhost:11434/api/generate"
|
||||
url = f"http://{settings.B_4_X_4090_SERVICE_HOST}:11434/api/generate"
|
||||
|
||||
image_base64 = minio_util.minio_url_to_base64(image_path.img)
|
||||
# image_base64 = minio_url_to_base64(image_path)
|
||||
|
||||
# 创建请求的负载 translator是自定义的翻译模型
|
||||
payload = {
|
||||
"model": "llama3.2-vision",
|
||||
"images": [image_base64],
|
||||
"prompt": f"{text}",
|
||||
"stream": False
|
||||
}
|
||||
# 将负载转换为 JSON 格式
|
||||
headers = {'Content-Type': 'application/json'}
|
||||
response = requests.post(url, data=json.dumps(payload), headers=headers)
|
||||
# 处理响应
|
||||
if response.status_code == 200:
|
||||
# print("Response from server:")
|
||||
# print(response.json())
|
||||
resp = json.loads(response.content).get("response")
|
||||
logger.info(f"sketch re-generate server runtime is {time.time() - start_time} \n, response is {resp}")
|
||||
# print("input : {}, sketch re-generate result : {}".format(text, resp))
|
||||
return resp
|
||||
else:
|
||||
logger.info(f"sketch re-generate server runtime is {time.time() - start_time} , response is {response.content}")
|
||||
print(f"Request failed with status code {response.status_code}")
|
||||
print(response.text)
|
||||
return ""
|
||||
|
||||
|
||||
def main():
|
||||
"""Main function"""
|
||||
text = get_translation_from_llama3("[火焰]")
|
||||
print(text)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user