feat(新功能):
fix(修复bug): docs(文档变更): refactor(重构): test(增加测试): Agent generate test
This commit is contained in:
20
app/api/api_agent_generate_image.py
Normal file
20
app/api/api_agent_generate_image.py
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
import logging
|
||||||
|
|
||||||
|
from fastapi import APIRouter, HTTPException
|
||||||
|
|
||||||
|
from app.schemas.response_template import ResponseModel
|
||||||
|
from app.service.generate_image.agent_generate import GenerateImage
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
logger = logging.getLogger()
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/agent_generate_image")
|
||||||
|
def generate_image(prompt: str):
|
||||||
|
try:
|
||||||
|
server = GenerateImage()
|
||||||
|
data = server.get_result(prompt)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"generate_image Run Exception @@@@@@:{e}")
|
||||||
|
raise HTTPException(status_code=404, detail=str(e))
|
||||||
|
return ResponseModel(data=data)
|
||||||
@@ -12,6 +12,8 @@ from app.api import api_mannequins_edit
|
|||||||
from app.api import api_prompt_generation
|
from app.api import api_prompt_generation
|
||||||
from app.api import api_recommendation
|
from app.api import api_recommendation
|
||||||
from app.api import api_super_resolution
|
from app.api import api_super_resolution
|
||||||
|
from app.api import api_agent_generate_image
|
||||||
|
|
||||||
from app.api import api_test
|
from app.api import api_test
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
@@ -30,3 +32,4 @@ router.include_router(api_query_image.router, tags=['api_query_image'], prefix="
|
|||||||
router.include_router(api_brand_dna.router, tags=['api_brand_dna'], prefix="/api")
|
router.include_router(api_brand_dna.router, tags=['api_brand_dna'], prefix="/api")
|
||||||
router.include_router(api_recommendation.router, tags=['api_recommendation'], prefix="/api")
|
router.include_router(api_recommendation.router, tags=['api_recommendation'], prefix="/api")
|
||||||
router.include_router(api_mannequins_edit.router, tags=['api_mannequins_edit'], prefix="/api")
|
router.include_router(api_mannequins_edit.router, tags=['api_mannequins_edit'], prefix="/api")
|
||||||
|
router.include_router(api_agent_generate_image.router, tags=['api_agent_generate_image'], prefix="/api")
|
||||||
|
|||||||
68
app/service/generate_image/agent_generate.py
Normal file
68
app/service/generate_image/agent_generate.py
Normal file
@@ -0,0 +1,68 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
# -*- coding: UTF-8 -*-
|
||||||
|
"""
|
||||||
|
@Project :trinity_client
|
||||||
|
@File :service_att_recognition.py
|
||||||
|
@Author :周成融
|
||||||
|
@Date :2023/7/26 12:01:05
|
||||||
|
@detail :
|
||||||
|
"""
|
||||||
|
import logging
|
||||||
|
from datetime import timedelta
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
import tritonclient.grpc as grpcclient
|
||||||
|
from minio import Minio
|
||||||
|
from tritonclient.utils import np_to_triton_dtype
|
||||||
|
|
||||||
|
from app.core.config import *
|
||||||
|
from app.service.utils.oss_client import oss_upload_image
|
||||||
|
|
||||||
|
logger = logging.getLogger()
|
||||||
|
|
||||||
|
|
||||||
|
class GenerateImage:
|
||||||
|
def __init__(self):
|
||||||
|
self.grpc_client = grpcclient.InferenceServerClient(url=GI_MODEL_URL)
|
||||||
|
self.image = np.random.randint(0, 256, (1024, 1024, 3), dtype=np.uint8)
|
||||||
|
self.batch_size = 1
|
||||||
|
self.mode = 'txt2img'
|
||||||
|
self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
|
||||||
|
|
||||||
|
def get_result(self, prompt):
|
||||||
|
prompts = [prompt] * self.batch_size
|
||||||
|
modes = [self.mode] * self.batch_size
|
||||||
|
images = [self.image.astype(np.float16)] * self.batch_size
|
||||||
|
|
||||||
|
text_obj = np.array(prompts, dtype="object").reshape((-1, 1))
|
||||||
|
mode_obj = np.array(modes, dtype="object").reshape((-1, 1))
|
||||||
|
image_obj = np.array(images, dtype=np.float16).reshape((-1, 1024, 1024, 3))
|
||||||
|
|
||||||
|
input_text = grpcclient.InferInput("prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype))
|
||||||
|
input_image = grpcclient.InferInput("input_image", image_obj.shape, np_to_triton_dtype(image_obj.dtype))
|
||||||
|
input_mode = grpcclient.InferInput("mode", mode_obj.shape, np_to_triton_dtype(mode_obj.dtype))
|
||||||
|
|
||||||
|
input_text.set_data_from_numpy(text_obj)
|
||||||
|
input_image.set_data_from_numpy(image_obj)
|
||||||
|
input_mode.set_data_from_numpy(mode_obj)
|
||||||
|
|
||||||
|
inputs = [input_text, input_image, input_mode]
|
||||||
|
result = self.grpc_client.infer(model_name=GI_MODEL_NAME, inputs=inputs)
|
||||||
|
image = result.as_numpy("generated_image")
|
||||||
|
image_result = cv2.cvtColor(np.squeeze(image.astype(np.uint8)), cv2.COLOR_RGB2BGR)
|
||||||
|
_, img_byte_array = cv2.imencode('.jpg', image_result)
|
||||||
|
object_name = f'test.jpg'
|
||||||
|
req = oss_upload_image(bucket='test', object_name=object_name, image_bytes=img_byte_array)
|
||||||
|
url = self.minio_client.get_presigned_url(
|
||||||
|
"GET",
|
||||||
|
"test",
|
||||||
|
object_name,
|
||||||
|
expires=timedelta(hours=2),
|
||||||
|
)
|
||||||
|
return url
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
server = GenerateImage()
|
||||||
|
print(server.get_result("rabbit"))
|
||||||
Reference in New Issue
Block a user