diff --git a/app/api/api_agent_generate_image.py b/app/api/api_agent_generate_image.py new file mode 100644 index 0000000..d8efbbb --- /dev/null +++ b/app/api/api_agent_generate_image.py @@ -0,0 +1,20 @@ +import logging + +from fastapi import APIRouter, HTTPException + +from app.schemas.response_template import ResponseModel +from app.service.generate_image.agent_generate import GenerateImage + +router = APIRouter() +logger = logging.getLogger() + + +@router.get("/agent_generate_image") +def generate_image(prompt: str): + try: + server = GenerateImage() + data = server.get_result(prompt) + except Exception as e: + logger.warning(f"generate_image Run Exception @@@@@@:{e}") + raise HTTPException(status_code=404, detail=str(e)) + return ResponseModel(data=data) diff --git a/app/api/api_route.py b/app/api/api_route.py index 33e238e..61bd43f 100644 --- a/app/api/api_route.py +++ b/app/api/api_route.py @@ -12,6 +12,8 @@ from app.api import api_mannequins_edit from app.api import api_prompt_generation from app.api import api_recommendation from app.api import api_super_resolution +from app.api import api_agent_generate_image + from app.api import api_test router = APIRouter() @@ -30,3 +32,4 @@ router.include_router(api_query_image.router, tags=['api_query_image'], prefix=" router.include_router(api_brand_dna.router, tags=['api_brand_dna'], prefix="/api") router.include_router(api_recommendation.router, tags=['api_recommendation'], prefix="/api") router.include_router(api_mannequins_edit.router, tags=['api_mannequins_edit'], prefix="/api") +router.include_router(api_agent_generate_image.router, tags=['api_agent_generate_image'], prefix="/api") diff --git a/app/service/generate_image/agent_generate.py b/app/service/generate_image/agent_generate.py new file mode 100644 index 0000000..24623dc --- /dev/null +++ b/app/service/generate_image/agent_generate.py @@ -0,0 +1,68 @@ +#!/usr/bin/env python +# -*- coding: UTF-8 -*- +""" +@Project :trinity_client +@File :service_att_recognition.py +@Author :周成融 +@Date :2023/7/26 12:01:05 +@detail : +""" +import logging +from datetime import timedelta + +import cv2 +import numpy as np +import tritonclient.grpc as grpcclient +from minio import Minio +from tritonclient.utils import np_to_triton_dtype + +from app.core.config import * +from app.service.utils.oss_client import oss_upload_image + +logger = logging.getLogger() + + +class GenerateImage: + def __init__(self): + self.grpc_client = grpcclient.InferenceServerClient(url=GI_MODEL_URL) + self.image = np.random.randint(0, 256, (1024, 1024, 3), dtype=np.uint8) + self.batch_size = 1 + self.mode = 'txt2img' + self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE) + + def get_result(self, prompt): + prompts = [prompt] * self.batch_size + modes = [self.mode] * self.batch_size + images = [self.image.astype(np.float16)] * self.batch_size + + text_obj = np.array(prompts, dtype="object").reshape((-1, 1)) + mode_obj = np.array(modes, dtype="object").reshape((-1, 1)) + image_obj = np.array(images, dtype=np.float16).reshape((-1, 1024, 1024, 3)) + + input_text = grpcclient.InferInput("prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype)) + input_image = grpcclient.InferInput("input_image", image_obj.shape, np_to_triton_dtype(image_obj.dtype)) + input_mode = grpcclient.InferInput("mode", mode_obj.shape, np_to_triton_dtype(mode_obj.dtype)) + + input_text.set_data_from_numpy(text_obj) + input_image.set_data_from_numpy(image_obj) + input_mode.set_data_from_numpy(mode_obj) + + inputs = [input_text, input_image, input_mode] + result = self.grpc_client.infer(model_name=GI_MODEL_NAME, inputs=inputs) + image = result.as_numpy("generated_image") + image_result = cv2.cvtColor(np.squeeze(image.astype(np.uint8)), cv2.COLOR_RGB2BGR) + _, img_byte_array = cv2.imencode('.jpg', image_result) + object_name = f'test.jpg' + req = oss_upload_image(bucket='test', object_name=object_name, image_bytes=img_byte_array) + url = self.minio_client.get_presigned_url( + "GET", + "test", + object_name, + expires=timedelta(hours=2), + ) + return url + + +if __name__ == '__main__': + server = GenerateImage() + print(server.get_result("rabbit"))