Files
AiDA_Python/app/service/generate_image/agent_generate.py
zhouchengrong e94d5ac6a3 feat(新功能):
fix(修复bug):
docs(文档变更):
refactor(重构):
test(增加测试): Agent generate test
2025-03-13 12:04:14 +08:00

69 lines
2.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@Project trinity_client
@File agent_generate.py
@Author :周成融
@Date 2023/7/26 12:01:05
@detail
"""
import logging
from datetime import timedelta
import cv2
import numpy as np
import tritonclient.grpc as grpcclient
from minio import Minio
from tritonclient.utils import np_to_triton_dtype
from app.core.config import *
from app.service.utils.oss_client import oss_upload_image
# Module-level logger. getLogger() is called without a name, so this is the root logger.
logger = logging.getLogger()
class GenerateImage:
    """Generate an image from a text prompt with a Triton model and share it via MinIO.

    The instance holds a Triton gRPC client (``GI_MODEL_URL``), a MinIO client
    (``MINIO_*`` settings) and the per-request defaults (batch size, mode and a
    placeholder input image).
    """

    def __init__(self):
        """Create the Triton and MinIO clients and set the request defaults."""
        self.grpc_client = grpcclient.InferenceServerClient(url=GI_MODEL_URL)
        # Random-noise image sent as the "input_image" tensor.
        # NOTE(review): presumably ignored by the model in 'txt2img' mode - confirm.
        self.image = np.random.randint(0, 256, (1024, 1024, 3), dtype=np.uint8)
        self.batch_size = 1
        self.mode = 'txt2img'
        self.minio_client = Minio(
            MINIO_URL,
            access_key=MINIO_ACCESS,
            secret_key=MINIO_SECRET,
            secure=MINIO_SECURE,
        )

    def _build_arrays(self, prompt):
        """Return the ``(prompt, image, mode)`` numpy arrays for one request.

        The prompt and mode are repeated ``self.batch_size`` times and shaped
        ``(batch, 1)`` with object dtype; the placeholder image is cast to
        float16 and shaped ``(batch, 1024, 1024, 3)``.
        """
        batch = self.batch_size
        text_obj = np.array([prompt] * batch, dtype="object").reshape((-1, 1))
        mode_obj = np.array([self.mode] * batch, dtype="object").reshape((-1, 1))
        placeholder = self.image.astype(np.float16)
        image_obj = np.array([placeholder] * batch, dtype=np.float16).reshape((-1, 1024, 1024, 3))
        return text_obj, image_obj, mode_obj

    def get_result(self, prompt, *, bucket='test', object_name='test.jpg', expires=timedelta(hours=2)):
        """Run inference for ``prompt``, upload the JPEG and return a presigned GET URL.

        Args:
            prompt: Text prompt sent to the model as the "prompt" input.
            bucket: Bucket the encoded image is uploaded to and signed for.
            object_name: Object key of the uploaded image. With the default,
                every call writes to the same key.
            expires: Lifetime of the presigned URL.

        Returns:
            The presigned URL produced by the MinIO client.

        Raises:
            RuntimeError: If OpenCV reports that JPEG encoding failed.
        """
        text_obj, image_obj, mode_obj = self._build_arrays(prompt)

        # Pair each Triton input name with its array, in the order the request is sent.
        named_arrays = (
            ("prompt", text_obj),
            ("input_image", image_obj),
            ("mode", mode_obj),
        )
        inputs = []
        for tensor_name, array in named_arrays:
            infer_input = grpcclient.InferInput(tensor_name, array.shape, np_to_triton_dtype(array.dtype))
            infer_input.set_data_from_numpy(array)
            inputs.append(infer_input)

        result = self.grpc_client.infer(model_name=GI_MODEL_NAME, inputs=inputs)
        image = result.as_numpy("generated_image")

        # squeeze() drops the batch axis, which only yields a single image while
        # batch_size is 1. The output is treated as RGB and converted to
        # OpenCV's BGR order before encoding.
        image_result = cv2.cvtColor(np.squeeze(image.astype(np.uint8)), cv2.COLOR_RGB2BGR)
        encoded_ok, img_byte_array = cv2.imencode('.jpg', image_result)
        if not encoded_ok:
            # Previously the flag was discarded and the buffer uploaded regardless.
            raise RuntimeError("Failed to encode generated image as JPEG")

        # NOTE(review): oss_upload_image is given the encoded numpy buffer as-is;
        # its return value is not used here.
        oss_upload_image(bucket=bucket, object_name=object_name, image_bytes=img_byte_array)

        return self.minio_client.get_presigned_url(
            "GET",
            bucket,
            object_name,
            expires=expires,
        )
if __name__ == '__main__':
    # Manual smoke test: build the service, generate one image for a fixed
    # prompt and print the URL that comes back.
    generator = GenerateImage()
    presigned_url = generator.get_result("rabbit")
    print(presigned_url)