调用llama3.2-vision,自动识别图片,输出prompt
This commit is contained in:
@@ -4,9 +4,10 @@ import time
|
|||||||
|
|
||||||
from fastapi import APIRouter, HTTPException
|
from fastapi import APIRouter, HTTPException
|
||||||
|
|
||||||
from app.schemas.prompt_generation import PromptGenerationImageModel
|
from app.schemas.prompt_generation import PromptGenerationImageModel, ImageRequest
|
||||||
from app.schemas.response_template import ResponseModel
|
from app.schemas.response_template import ResponseModel
|
||||||
from app.service.prompt_generation.chatgpt_for_translation import translate_to_en, get_translation_from_llama3
|
from app.service.prompt_generation.chatgpt_for_translation import get_translation_from_llama3, \
|
||||||
|
get_prompt_from_image
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
logger = logging.getLogger()
|
logger = logging.getLogger()
|
||||||
@@ -32,3 +33,20 @@ def prompt_generation(request_data: PromptGenerationImageModel):
|
|||||||
logger.warning(f"prompt_generation Run Exception @@@@@@:{e}")
|
logger.warning(f"prompt_generation Run Exception @@@@@@:{e}")
|
||||||
raise HTTPException(status_code=404, detail=str(e))
|
raise HTTPException(status_code=404, detail=str(e))
|
||||||
return ResponseModel(data=data)
|
return ResponseModel(data=data)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/img2prompt")
def get_prompt_from_img(img: ImageRequest):
    """
    Automatically describe an image and return a text prompt for it.

    :param img: request body carrying the MinIO address of the image
    :return: textual description of the image produced by the vision model
    :raises HTTPException: 404 when description generation fails
    """
    # Fixed instruction sent to the vision model together with the image.
    text = ("Please describe the clothing in the image and provide a line art description of the outfit. "
            "The description should allow for the reconstruction of the corresponding line art based on the details "
            "given.")
    logger.info(f"get_prompt_from_img request item is : @@@@@@:{img}")
    try:
        description = get_prompt_from_image(img, text)
    except Exception as e:
        # Mirror the error handling of the sibling prompt_generation endpoint:
        # log the failure and surface it as an HTTP error instead of a bare 500.
        logger.warning(f"get_prompt_from_img Run Exception @@@@@@:{e}")
        raise HTTPException(status_code=404, detail=str(e))
    logger.info(f"生成的图片描述 response @@@@@@:{description}")
    return description
|
||||||
|
|||||||
@@ -3,3 +3,7 @@ from pydantic import BaseModel
|
|||||||
|
|
||||||
class PromptGenerationImageModel(BaseModel):
    """Request body for the prompt-generation endpoint."""

    # Input text used to generate the prompt.
    text: str
|
||||||
|
|
||||||
|
|
||||||
|
class ImageRequest(BaseModel):
    """Request body for the /img2prompt endpoint."""

    # MinIO address of the image — presumably a "bucket/object" path,
    # since minio_util.minio_url_to_base64 splits it on the first "/".
    img: str
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ from retry import retry
|
|||||||
|
|
||||||
from app.core.config import QWEN_API_KEY
|
from app.core.config import QWEN_API_KEY
|
||||||
from app.service.chat_robot.script.service.CallQWen import get_language
|
from app.service.chat_robot.script.service.CallQWen import get_language
|
||||||
|
from app.service.prompt_generation.util import minio_util
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -143,6 +144,38 @@ def get_translation_from_llama3(text):
|
|||||||
# response = requests.post(url, data=json.dumps(payload), headers=headers)
|
# response = requests.post(url, data=json.dumps(payload), headers=headers)
|
||||||
|
|
||||||
|
|
||||||
|
def get_prompt_from_image(image_path, text, url="http://10.1.1.243:11434/api/generate", timeout=300):
    """
    Ask the llama3.2-vision model (served via an Ollama endpoint) to describe an image.

    :param image_path: object with an ``img`` attribute holding the MinIO address of the image
    :param text: instruction/prompt sent to the model alongside the image
    :param url: Ollama ``/api/generate`` endpoint; default kept for backward compatibility
    :param timeout: request timeout in seconds — without one a stuck server hangs the caller forever
    :return: the model's textual description, or ``None`` when the server answers non-200
    """
    start_time = time.time()

    # Download the image from MinIO and base64-encode it for the Ollama API.
    image_base64 = minio_util.minio_url_to_base64(image_path.img)

    payload = {
        "model": "llama3.2-vision",
        "images": [image_base64],
        "prompt": text,  # plain string; the previous f"{text}" was a no-op
        "stream": False,
    }
    headers = {'Content-Type': 'application/json'}
    response = requests.post(url, data=json.dumps(payload), headers=headers, timeout=timeout)

    if response.status_code == 200:
        resp = json.loads(response.content).get("response")
        logger.info(f"sketch re-generate server runtime is {time.time() - start_time} \n, response is {resp}")
        return resp

    # Failure path: log through the module logger instead of print(), and
    # return None explicitly rather than falling off the end of the function.
    logger.info(f"sketch re-generate server runtime is {time.time() - start_time} , response is {response.content}")
    logger.error(f"Request failed with status code {response.status_code}: {response.text}")
    return None
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
"""Main function"""
|
"""Main function"""
|
||||||
text = get_translation_from_llama3("[火焰]")
|
text = get_translation_from_llama3("[火焰]")
|
||||||
|
|||||||
21
app/service/prompt_generation/util/minio_util.py
Normal file
21
app/service/prompt_generation/util/minio_util.py
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
import base64
|
||||||
|
|
||||||
|
from minio import Minio
|
||||||
|
|
||||||
|
from app.core.config import MINIO_URL, MINIO_ACCESS, MINIO_SECRET, MINIO_SECURE
|
||||||
|
|
||||||
|
minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
|
||||||
|
|
||||||
|
|
||||||
|
def minio_url_to_base64(minio_url: str) -> str:
    """
    Fetch an object from MinIO and return its content base64-encoded.

    :param minio_url: "bucket/object" path; split on the first "/" only,
        so object names may themselves contain slashes
    :return: base64-encoded object bytes as a UTF-8 string
    :raises RuntimeError: when the object cannot be retrieved
    """
    bucket_name, object_name = minio_url.split("/", 1)

    response = None  # sentinel instead of the fragile `'response' in locals()` check
    try:
        response = minio_client.get_object(bucket_name, object_name)
        image_data = response.read()
        return base64.b64encode(image_data).decode('utf-8')
    except Exception as e:
        # Keep the original cause attached for debugging.
        raise RuntimeError(f"Failed to get object: {e}") from e
    finally:
        if response is not None:
            response.close()
            # MinIO's client docs require releasing the urllib3 connection
            # back to the pool after get_object().
            response.release_conn()
|
||||||
Reference in New Issue
Block a user