feat product image 新增product type 参数 ,解决single item 无法检测头部的问题

fix
This commit is contained in:
zhouchengrong
2024-07-04 14:14:57 +08:00
parent 24142a01cc
commit eede159507
13 changed files with 163 additions and 101 deletions

View File

@@ -39,6 +39,7 @@ class GenerateProductImage:
self.category = "product_image"
self.image_strength = request_data.image_strength
self.batch_size = 1
self.product_type = request_data.product_type
self.prompt = request_data.prompt
self.image, self.image_size = pre_processing_image(request_data.image_url)
self.tasks_id = request_data.tasks_id
@@ -54,7 +55,10 @@ class GenerateProductImage:
self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data))
else:
# pil图像转成numpy数组
image = result.as_numpy("generated_inpaint_image")
if self.product_type == "single":
image = result.as_numpy("generated_cnet_image")
else:
image = result.as_numpy("generated_inpaint_image")
image_result = Image.fromarray(np.squeeze(image.astype(np.uint8))).resize(self.image_size)
image_url = upload_SDXL_image(image_result, user_id=self.user_id, category=f"{self.category}", file_name=f"{self.tasks_id}.png")
self.gen_product_data['status'] = "SUCCESS"
@@ -73,9 +77,16 @@ class GenerateProductImage:
self.image = cv2.resize(self.image, (512, 768))
images = [self.image.astype(np.uint8)] * self.batch_size
text_obj = np.array(prompts, dtype="object").reshape(1)
image_obj = np.array(images, dtype=np.uint8).reshape((768, 512, 3))
image_strength_obj = np.array(self.image_strength, dtype=np.float32).reshape((1))
if self.product_type == "single":
text_obj = np.array(prompts, dtype="object").reshape(-1, 1)
image_obj = np.array(images, dtype=np.uint8).reshape((-1, 768, 512, 3))
image_strength_obj = np.array(self.image_strength, dtype=np.float32).reshape(-1, 1)
else:
text_obj = np.array(prompts, dtype="object").reshape(1)
image_obj = np.array(images, dtype=np.uint8).reshape((768, 512, 3))
image_strength_obj = np.array(self.image_strength, dtype=np.float32).reshape((1))
# 假设 prompts、images 和 self.image_strength 已经定义
input_text = grpcclient.InferInput("prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype))
input_image = grpcclient.InferInput("input_image", image_obj.shape, "UINT8")
@@ -86,7 +97,11 @@ class GenerateProductImage:
inputs = [input_text, input_image, input_image_strength]
input_image_strength.set_data_from_numpy(image_strength_obj)
ctx = self.grpc_client.async_infer(model_name=GPI_MODEL_NAME, inputs=inputs, callback=self.callback)
if self.product_type == "single":
ctx = self.grpc_client.async_infer(model_name=GPI_MODEL_NAME_SINGLE, inputs=inputs, callback=self.callback)
else:
ctx = self.grpc_client.async_infer(model_name=GPI_MODEL_NAME_OVERALL, inputs=inputs, callback=self.callback)
time_out = 600
while time_out > 0:
gen_product_data, _ = self.read_tasks_status()
@@ -151,6 +166,7 @@ if __name__ == '__main__':
image_strength=0.9,
# prompt=" the best quality, masterpiece. detailed, high-res, simple background, studio photography, extremely detailed, updo, detailed face, face, close-up, HDR, UHD, 8K realistic, Highly detailed, simple background, Studio lighting",
image_url="aida-results/result_00097282-ebb2-11ee-a822-b48351119060.png",
product_type="single"
)
server = GenerateProductImage(rd)
print(server.get_result())