feat 新增接口描述 docs页面 ,新增S3 图片get upload 操作,整理代码

fix
This commit is contained in:
zhouchengrong
2024-06-25 16:58:17 +08:00
parent 558d86b312
commit db3d86204f
17 changed files with 1087 additions and 996 deletions

View File

@@ -10,15 +10,17 @@
import json
import logging
import time
import cv2
import minio
import numpy as np
import redis
import tritonclient.grpc as grpcclient
import numpy as np
from tritonclient.utils import np_to_triton_dtype
from app.core.config import *
from app.schemas.generate_image import GenerateImageModel
from app.service.generate_image.utils.image_processing import remove_background, stain_detection, generate_category_recognition, autoLevels, luminance_adjust, face_detect_pic
from app.service.generate_image.utils.image_processing import remove_background, stain_detection, generate_category_recognition, autoLevels, luminance_adjust
from app.service.generate_image.utils.upload_sd_image import upload_png_sd
from app.service.utils.oss_client import oss_get_image
@@ -120,13 +122,6 @@ class GenerateImage:
status_data = self.redis_client.get(self.tasks_id)
return json.loads(status_data), status_data
def infer(self, inputs):
return self.grpc_client.async_infer(
model_name=GI_MODEL_NAME,
inputs=inputs,
callback=self.callback
)
def get_result(self):
try:
prompts = [self.prompt] * self.batch_size
@@ -146,7 +141,7 @@ class GenerateImage:
input_mode.set_data_from_numpy(mode_obj)
inputs = [input_text, input_image, input_mode]
ctx = self.infer(inputs)
ctx = self.grpc_client.async_infer(model_name=GI_MODEL_NAME, inputs=inputs, callback=self.callback)
time_out = 600
generate_data = None
while time_out > 0:
@@ -186,10 +181,10 @@ if __name__ == '__main__':
rd = GenerateImageModel(
tasks_id="123-89",
prompt='skeleton sitting by the side of a river looking soulful, concert poster, 4k, artistic',
image_url="",
image_url="aida-collection-element/87/Printboard/842c09cf-7297-42d9-9e6e-9c17d4a13cb5.jpg",
mode='txt2img',
category="test",
gender="male"
)
server = GenerateImage(rd)
print(server.get_result())
print(server.get_result())